agent.go

   1package agent
   2
   3import (
   4	"context"
   5	"errors"
   6	"fmt"
   7	"log/slog"
   8	"slices"
   9	"strings"
  10	"time"
  11
  12	"github.com/charmbracelet/catwalk/pkg/catwalk"
  13	"github.com/charmbracelet/crush/internal/config"
  14	"github.com/charmbracelet/crush/internal/csync"
  15	"github.com/charmbracelet/crush/internal/history"
  16	"github.com/charmbracelet/crush/internal/llm/prompt"
  17	"github.com/charmbracelet/crush/internal/llm/provider"
  18	"github.com/charmbracelet/crush/internal/llm/tools"
  19	"github.com/charmbracelet/crush/internal/log"
  20	"github.com/charmbracelet/crush/internal/lsp"
  21	"github.com/charmbracelet/crush/internal/message"
  22	"github.com/charmbracelet/crush/internal/permission"
  23	"github.com/charmbracelet/crush/internal/pubsub"
  24	"github.com/charmbracelet/crush/internal/session"
  25	"github.com/charmbracelet/crush/internal/shell"
  26)
  27
  28// Common errors
  29var (
  30	ErrRequestCancelled = errors.New("request canceled by user")
  31	ErrSessionBusy      = errors.New("session is currently processing another request")
  32)
  33
  34type AgentEventType string
  35
  36const (
  37	AgentEventTypeError     AgentEventType = "error"
  38	AgentEventTypeResponse  AgentEventType = "response"
  39	AgentEventTypeSummarize AgentEventType = "summarize"
  40)
  41
// AgentEvent is the payload the agent publishes on its broker and sends on
// the channel returned by Run.
type AgentEvent struct {
	// Type tells which of the fields below are meaningful.
	Type AgentEventType
	// Message is the assistant message for response events.
	Message message.Message
	// Error is set for error events.
	Error error

	// SessionID and Progress are filled in by summarization events. Done is
	// set on the terminal event of a summarization and on the final response
	// event of a run.
	SessionID string
	Progress  string
	Done      bool
}
  52
// Service is the public surface of an agent. Events are also observable
// through the embedded subscriber.
type Service interface {
	pubsub.Suscriber[AgentEvent]
	// Model returns the model currently configured for the agent's model type.
	Model() catwalk.Model
	// Run starts processing content for the session and returns a channel that
	// receives the final event. When the session is already busy the prompt is
	// queued instead and both return values are nil.
	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
	// Cancel stops the in-flight request and summarization of the session and
	// drops its queued prompts.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits a bounded time for the
	// agent to become idle.
	CancelAll()
	// IsSessionBusy reports whether the session has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any request is active.
	IsBusy() bool
	// Summarize starts an asynchronous summarization of the session; progress
	// and failures are delivered as published events.
	Summarize(ctx context.Context, sessionID string) error
	// UpdateModel refreshes the agent's providers from the current config.
	UpdateModel() error
	// QueuedPrompts returns how many prompts are waiting for the session.
	QueuedPrompts(sessionID string) int
	// ClearQueue drops the prompts waiting for the session.
	ClearQueue(sessionID string)
}
  66
// agent is the Service implementation. It embeds the broker used to publish
// AgentEvents.
type agent struct {
	*pubsub.Broker[AgentEvent]
	agentCfg config.Agent
	sessions session.Service
	messages message.Service
	// NOTE(review): mcpTools is not referenced in the visible part of this
	// file — confirm whether it is still needed.
	mcpTools []McpTool

	// tools holds the base tool set, built by the function given to
	// csync.NewLazySlice in NewAgent.
	tools *csync.LazySlice[tools.BaseTool]
	// We need this to be able to update it when model changes
	agentToolFn func() (tools.BaseTool, error)

	// provider serves the main conversation; providerID is recorded on the
	// messages the agent creates.
	provider   provider.Provider
	providerID string

	// titleProvider generates session titles; summarizeProvider produces
	// conversation summaries.
	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests maps a session ID (or sessionID+"-summarize") to the
	// cancel function of its in-flight request.
	activeRequests *csync.Map[string, context.CancelFunc]

	// promptQueue holds prompts submitted while their session was busy.
	promptQueue *csync.Map[string, []string]
}
  89
// agentPromptMap maps an agent ID to the system prompt it uses. Agents that
// are not listed fall back to prompt.PromptDefault.
var agentPromptMap = map[string]prompt.PromptID{
	"coder": prompt.PromptCoder,
	"task":  prompt.PromptTask,
}
  94
  95func NewAgent(
  96	ctx context.Context,
  97	agentCfg config.Agent,
  98	// These services are needed in the tools
  99	permissions permission.Service,
 100	sessions session.Service,
 101	messages message.Service,
 102	history history.Service,
 103	lspClients map[string]*lsp.Client,
 104) (Service, error) {
 105	cfg := config.Get()
 106
 107	var agentToolFn func() (tools.BaseTool, error)
 108	if agentCfg.ID == "coder" {
 109		agentToolFn = func() (tools.BaseTool, error) {
 110			taskAgentCfg := config.Get().Agents["task"]
 111			if taskAgentCfg.ID == "" {
 112				return nil, fmt.Errorf("task agent not found in config")
 113			}
 114			taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients)
 115			if err != nil {
 116				return nil, fmt.Errorf("failed to create task agent: %w", err)
 117			}
 118			return NewAgentTool(taskAgent, sessions, messages), nil
 119		}
 120	}
 121
 122	providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
 123	if providerCfg == nil {
 124		return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
 125	}
 126	model := config.Get().GetModelByType(agentCfg.Model)
 127
 128	if model == nil {
 129		return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
 130	}
 131
 132	promptID := agentPromptMap[agentCfg.ID]
 133	if promptID == "" {
 134		promptID = prompt.PromptDefault
 135	}
 136	opts := []provider.ProviderClientOption{
 137		provider.WithModel(agentCfg.Model),
 138		provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
 139	}
 140	agentProvider, err := provider.NewProvider(*providerCfg, opts...)
 141	if err != nil {
 142		return nil, err
 143	}
 144
 145	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
 146	var smallModelProviderCfg *config.ProviderConfig
 147	if smallModelCfg.Provider == providerCfg.ID {
 148		smallModelProviderCfg = providerCfg
 149	} else {
 150		smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
 151
 152		if smallModelProviderCfg.ID == "" {
 153			return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
 154		}
 155	}
 156	smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
 157	if smallModel.ID == "" {
 158		return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
 159	}
 160
 161	titleOpts := []provider.ProviderClientOption{
 162		provider.WithModel(config.SelectedModelTypeSmall),
 163		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
 164	}
 165	titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
 166	if err != nil {
 167		return nil, err
 168	}
 169
 170	summarizeOpts := []provider.ProviderClientOption{
 171		provider.WithModel(config.SelectedModelTypeLarge),
 172		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, providerCfg.ID)),
 173	}
 174	summarizeProvider, err := provider.NewProvider(*providerCfg, summarizeOpts...)
 175	if err != nil {
 176		return nil, err
 177	}
 178
 179	toolFn := func() []tools.BaseTool {
 180		slog.Info("Initializing agent tools", "agent", agentCfg.ID)
 181		defer func() {
 182			slog.Info("Initialized agent tools", "agent", agentCfg.ID)
 183		}()
 184
 185		cwd := cfg.WorkingDir()
 186		allTools := []tools.BaseTool{
 187			tools.NewBashTool(permissions, cwd),
 188			tools.NewDownloadTool(permissions, cwd),
 189			tools.NewEditTool(lspClients, permissions, history, cwd),
 190			tools.NewMultiEditTool(lspClients, permissions, history, cwd),
 191			tools.NewFetchTool(permissions, cwd),
 192			tools.NewGlobTool(cwd),
 193			tools.NewGrepTool(cwd),
 194			tools.NewLsTool(permissions, cwd),
 195			tools.NewSourcegraphTool(),
 196			tools.NewViewTool(lspClients, permissions, cwd),
 197			tools.NewWriteTool(lspClients, permissions, history, cwd),
 198		}
 199
 200		if agentCfg.AllowedTools == nil {
 201			return allTools
 202		}
 203
 204		var filteredTools []tools.BaseTool
 205		for _, tool := range allTools {
 206			if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
 207				filteredTools = append(filteredTools, tool)
 208			}
 209		}
 210
 211		if agentCfg.ID == "coder" {
 212			mcpToolsOnce.Do(func() {
 213				mcpTools = doGetMCPTools(ctx, permissions, cfg)
 214			})
 215			filteredTools = append(filteredTools, mcpTools...)
 216			if len(lspClients) > 0 {
 217				filteredTools = append(filteredTools, tools.NewDiagnosticsTool(lspClients))
 218			}
 219		}
 220		return filteredTools
 221	}
 222
 223	return &agent{
 224		Broker:              pubsub.NewBroker[AgentEvent](),
 225		agentCfg:            agentCfg,
 226		provider:            agentProvider,
 227		providerID:          string(providerCfg.ID),
 228		messages:            messages,
 229		sessions:            sessions,
 230		titleProvider:       titleProvider,
 231		summarizeProvider:   summarizeProvider,
 232		summarizeProviderID: string(providerCfg.ID),
 233		agentToolFn:         agentToolFn,
 234		activeRequests:      csync.NewMap[string, context.CancelFunc](),
 235		tools:               csync.NewLazySlice(toolFn),
 236		promptQueue:         csync.NewMap[string, []string](),
 237	}, nil
 238}
 239
 240func (a *agent) Model() catwalk.Model {
 241	return *config.Get().GetModelByType(a.agentCfg.Model)
 242}
 243
 244func (a *agent) Cancel(sessionID string) {
 245	// Cancel regular requests
 246	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
 247		slog.Info("Request cancellation initiated", "session_id", sessionID)
 248		cancel()
 249	}
 250
 251	// Also check for summarize requests
 252	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
 253		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
 254		cancel()
 255	}
 256
 257	if a.QueuedPrompts(sessionID) > 0 {
 258		slog.Info("Clearing queued prompts", "session_id", sessionID)
 259		a.promptQueue.Del(sessionID)
 260	}
 261}
 262
 263func (a *agent) IsBusy() bool {
 264	var busy bool
 265	for cancelFunc := range a.activeRequests.Seq() {
 266		if cancelFunc != nil {
 267			busy = true
 268			break
 269		}
 270	}
 271	return busy
 272}
 273
 274func (a *agent) IsSessionBusy(sessionID string) bool {
 275	_, busy := a.activeRequests.Get(sessionID)
 276	return busy
 277}
 278
 279func (a *agent) QueuedPrompts(sessionID string) int {
 280	l, ok := a.promptQueue.Get(sessionID)
 281	if !ok {
 282		return 0
 283	}
 284	return len(l)
 285}
 286
 287func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
 288	if content == "" {
 289		return nil
 290	}
 291	if a.titleProvider == nil {
 292		return nil
 293	}
 294	session, err := a.sessions.Get(ctx, sessionID)
 295	if err != nil {
 296		return err
 297	}
 298	parts := []message.ContentPart{message.TextContent{
 299		Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
 300	}}
 301
 302	// Use streaming approach like summarization
 303	response := a.titleProvider.StreamResponse(
 304		ctx,
 305		[]message.Message{
 306			{
 307				Role:  message.User,
 308				Parts: parts,
 309			},
 310		},
 311		nil,
 312	)
 313
 314	var finalResponse *provider.ProviderResponse
 315	for r := range response {
 316		if r.Error != nil {
 317			return r.Error
 318		}
 319		finalResponse = r.Response
 320	}
 321
 322	if finalResponse == nil {
 323		return fmt.Errorf("no response received from title provider")
 324	}
 325
 326	title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
 327	if title == "" {
 328		return nil
 329	}
 330
 331	session.Title = title
 332	_, err = a.sessions.Save(ctx, session)
 333	return err
 334}
 335
 336func (a *agent) err(err error) AgentEvent {
 337	return AgentEvent{
 338		Type:  AgentEventTypeError,
 339		Error: err,
 340	}
 341}
 342
 343func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
 344	if !a.Model().SupportsImages && attachments != nil {
 345		attachments = nil
 346	}
 347	events := make(chan AgentEvent, 1)
 348	if a.IsSessionBusy(sessionID) {
 349		existing, ok := a.promptQueue.Get(sessionID)
 350		if !ok {
 351			existing = []string{}
 352		}
 353		existing = append(existing, content)
 354		a.promptQueue.Set(sessionID, existing)
 355		return nil, nil
 356	}
 357
 358	genCtx, cancel := context.WithCancel(ctx)
 359
 360	a.activeRequests.Set(sessionID, cancel)
 361	go func() {
 362		slog.Debug("Request started", "sessionID", sessionID)
 363		defer log.RecoverPanic("agent.Run", func() {
 364			events <- a.err(fmt.Errorf("panic while running the agent"))
 365		})
 366		var attachmentParts []message.ContentPart
 367		for _, attachment := range attachments {
 368			attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
 369		}
 370		result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
 371		if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
 372			slog.Error(result.Error.Error())
 373		}
 374		slog.Debug("Request completed", "sessionID", sessionID)
 375		a.activeRequests.Del(sessionID)
 376		cancel()
 377		a.Publish(pubsub.CreatedEvent, result)
 378		events <- result
 379		close(events)
 380	}()
 381	return events, nil
 382}
 383
 384func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
 385	cfg := config.Get()
 386	// List existing messages; if none, start title generation asynchronously.
 387	msgs, err := a.messages.List(ctx, sessionID)
 388	if err != nil {
 389		return a.err(fmt.Errorf("failed to list messages: %w", err))
 390	}
 391	if len(msgs) == 0 {
 392		go func() {
 393			defer log.RecoverPanic("agent.Run", func() {
 394				slog.Error("panic while generating title")
 395			})
 396			titleErr := a.generateTitle(ctx, sessionID, content)
 397			if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
 398				slog.Error("failed to generate title", "error", titleErr)
 399			}
 400		}()
 401	}
 402	session, err := a.sessions.Get(ctx, sessionID)
 403	if err != nil {
 404		return a.err(fmt.Errorf("failed to get session: %w", err))
 405	}
 406	if session.SummaryMessageID != "" {
 407		summaryMsgInex := -1
 408		for i, msg := range msgs {
 409			if msg.ID == session.SummaryMessageID {
 410				summaryMsgInex = i
 411				break
 412			}
 413		}
 414		if summaryMsgInex != -1 {
 415			msgs = msgs[summaryMsgInex:]
 416			msgs[0].Role = message.User
 417		}
 418	}
 419
 420	userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
 421	if err != nil {
 422		return a.err(fmt.Errorf("failed to create user message: %w", err))
 423	}
 424	// Append the new user message to the conversation history.
 425	msgHistory := append(msgs, userMsg)
 426
 427	for {
 428		// Check for cancellation before each iteration
 429		select {
 430		case <-ctx.Done():
 431			return a.err(ctx.Err())
 432		default:
 433			// Continue processing
 434		}
 435		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
 436		if err != nil {
 437			if errors.Is(err, context.Canceled) {
 438				agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
 439				a.messages.Update(context.Background(), agentMessage)
 440				return a.err(ErrRequestCancelled)
 441			}
 442			return a.err(fmt.Errorf("failed to process events: %w", err))
 443		}
 444		if cfg.Options.Debug {
 445			slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
 446		}
 447		if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
 448			// We are not done, we need to respond with the tool response
 449			msgHistory = append(msgHistory, agentMessage, *toolResults)
 450			// If there are queued prompts, process the next one
 451			nextPrompt, ok := a.promptQueue.Take(sessionID)
 452			if ok {
 453				for _, prompt := range nextPrompt {
 454					// Create a new user message for the queued prompt
 455					userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
 456					if err != nil {
 457						return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
 458					}
 459					// Append the new user message to the conversation history
 460					msgHistory = append(msgHistory, userMsg)
 461				}
 462			}
 463
 464			continue
 465		} else if agentMessage.FinishReason() == message.FinishReasonEndTurn {
 466			queuePrompts, ok := a.promptQueue.Take(sessionID)
 467			if ok {
 468				for _, prompt := range queuePrompts {
 469					if prompt == "" {
 470						continue
 471					}
 472					userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
 473					if err != nil {
 474						return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
 475					}
 476					msgHistory = append(msgHistory, userMsg)
 477				}
 478				continue
 479			}
 480		}
 481		if agentMessage.FinishReason() == "" {
 482			// Kujtim: could not track down where this is happening but this means its cancelled
 483			agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
 484			_ = a.messages.Update(context.Background(), agentMessage)
 485			return a.err(ErrRequestCancelled)
 486		}
 487		return AgentEvent{
 488			Type:    AgentEventTypeResponse,
 489			Message: agentMessage,
 490			Done:    true,
 491		}
 492	}
 493}
 494
 495func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
 496	parts := []message.ContentPart{message.TextContent{Text: content}}
 497	parts = append(parts, attachmentParts...)
 498	return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 499		Role:  message.User,
 500		Parts: parts,
 501	})
 502}
 503
 504func (a *agent) getAllTools() ([]tools.BaseTool, error) {
 505	allTools := slices.Collect(a.tools.Seq())
 506	if a.agentToolFn != nil {
 507		agentTool, agentToolErr := a.agentToolFn()
 508		if agentToolErr != nil {
 509			return nil, agentToolErr
 510		}
 511		allTools = append(allTools, agentTool)
 512	}
 513	return allTools, nil
 514}
 515
 516func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
 517	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
 518
 519	// Create the assistant message first so the spinner shows immediately
 520	assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 521		Role:     message.Assistant,
 522		Parts:    []message.ContentPart{},
 523		Model:    a.Model().ID,
 524		Provider: a.providerID,
 525	})
 526	if err != nil {
 527		return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
 528	}
 529
 530	allTools, toolsErr := a.getAllTools()
 531	if toolsErr != nil {
 532		return assistantMsg, nil, toolsErr
 533	}
 534	// Now collect tools (which may block on MCP initialization)
 535	eventChan := a.provider.StreamResponse(ctx, msgHistory, allTools)
 536
 537	// Add the session and message ID into the context if needed by tools.
 538	ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
 539
 540	// Process each event in the stream.
 541	for event := range eventChan {
 542		if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
 543			if errors.Is(processErr, context.Canceled) {
 544				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 545			} else {
 546				a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
 547			}
 548			return assistantMsg, nil, processErr
 549		}
 550		if ctx.Err() != nil {
 551			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 552			return assistantMsg, nil, ctx.Err()
 553		}
 554	}
 555
 556	toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
 557	toolCalls := assistantMsg.ToolCalls()
 558	for i, toolCall := range toolCalls {
 559		select {
 560		case <-ctx.Done():
 561			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 562			// Make all future tool calls cancelled
 563			for j := i; j < len(toolCalls); j++ {
 564				toolResults[j] = message.ToolResult{
 565					ToolCallID: toolCalls[j].ID,
 566					Content:    "Tool execution canceled by user",
 567					IsError:    true,
 568				}
 569			}
 570			goto out
 571		default:
 572			// Continue processing
 573			var tool tools.BaseTool
 574			allTools, _ := a.getAllTools()
 575			for _, availableTool := range allTools {
 576				if availableTool.Info().Name == toolCall.Name {
 577					tool = availableTool
 578					break
 579				}
 580			}
 581
 582			// Tool not found
 583			if tool == nil {
 584				toolResults[i] = message.ToolResult{
 585					ToolCallID: toolCall.ID,
 586					Content:    fmt.Sprintf("Tool not found: %s", toolCall.Name),
 587					IsError:    true,
 588				}
 589				continue
 590			}
 591
 592			// Run tool in goroutine to allow cancellation
 593			type toolExecResult struct {
 594				response tools.ToolResponse
 595				err      error
 596			}
 597			resultChan := make(chan toolExecResult, 1)
 598
 599			go func() {
 600				response, err := tool.Run(ctx, tools.ToolCall{
 601					ID:    toolCall.ID,
 602					Name:  toolCall.Name,
 603					Input: toolCall.Input,
 604				})
 605				resultChan <- toolExecResult{response: response, err: err}
 606			}()
 607
 608			var toolResponse tools.ToolResponse
 609			var toolErr error
 610
 611			select {
 612			case <-ctx.Done():
 613				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 614				// Mark remaining tool calls as cancelled
 615				for j := i; j < len(toolCalls); j++ {
 616					toolResults[j] = message.ToolResult{
 617						ToolCallID: toolCalls[j].ID,
 618						Content:    "Tool execution canceled by user",
 619						IsError:    true,
 620					}
 621				}
 622				goto out
 623			case result := <-resultChan:
 624				toolResponse = result.response
 625				toolErr = result.err
 626			}
 627
 628			if toolErr != nil {
 629				slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
 630				if errors.Is(toolErr, permission.ErrorPermissionDenied) {
 631					toolResults[i] = message.ToolResult{
 632						ToolCallID: toolCall.ID,
 633						Content:    "Permission denied",
 634						IsError:    true,
 635					}
 636					for j := i + 1; j < len(toolCalls); j++ {
 637						toolResults[j] = message.ToolResult{
 638							ToolCallID: toolCalls[j].ID,
 639							Content:    "Tool execution canceled by user",
 640							IsError:    true,
 641						}
 642					}
 643					a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
 644					break
 645				}
 646			}
 647			toolResults[i] = message.ToolResult{
 648				ToolCallID: toolCall.ID,
 649				Content:    toolResponse.Content,
 650				Metadata:   toolResponse.Metadata,
 651				IsError:    toolResponse.IsError,
 652			}
 653		}
 654	}
 655out:
 656	if len(toolResults) == 0 {
 657		return assistantMsg, nil, nil
 658	}
 659	parts := make([]message.ContentPart, 0)
 660	for _, tr := range toolResults {
 661		parts = append(parts, tr)
 662	}
 663	msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
 664		Role:     message.Tool,
 665		Parts:    parts,
 666		Provider: a.providerID,
 667	})
 668	if err != nil {
 669		return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
 670	}
 671
 672	return assistantMsg, &msg, err
 673}
 674
 675func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
 676	msg.AddFinish(finishReason, message, details)
 677	_ = a.messages.Update(ctx, *msg)
 678}
 679
 680func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
 681	select {
 682	case <-ctx.Done():
 683		return ctx.Err()
 684	default:
 685		// Continue processing.
 686	}
 687
 688	switch event.Type {
 689	case provider.EventThinkingDelta:
 690		assistantMsg.AppendReasoningContent(event.Thinking)
 691		return a.messages.Update(ctx, *assistantMsg)
 692	case provider.EventSignatureDelta:
 693		assistantMsg.AppendReasoningSignature(event.Signature)
 694		return a.messages.Update(ctx, *assistantMsg)
 695	case provider.EventContentDelta:
 696		assistantMsg.FinishThinking()
 697		assistantMsg.AppendContent(event.Content)
 698		return a.messages.Update(ctx, *assistantMsg)
 699	case provider.EventToolUseStart:
 700		assistantMsg.FinishThinking()
 701		slog.Info("Tool call started", "toolCall", event.ToolCall)
 702		assistantMsg.AddToolCall(*event.ToolCall)
 703		return a.messages.Update(ctx, *assistantMsg)
 704	case provider.EventToolUseDelta:
 705		assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
 706		return a.messages.Update(ctx, *assistantMsg)
 707	case provider.EventToolUseStop:
 708		slog.Info("Finished tool call", "toolCall", event.ToolCall)
 709		assistantMsg.FinishToolCall(event.ToolCall.ID)
 710		return a.messages.Update(ctx, *assistantMsg)
 711	case provider.EventError:
 712		return event.Error
 713	case provider.EventComplete:
 714		assistantMsg.FinishThinking()
 715		assistantMsg.SetToolCalls(event.Response.ToolCalls)
 716		assistantMsg.AddFinish(event.Response.FinishReason, "", "")
 717		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
 718			return fmt.Errorf("failed to update message: %w", err)
 719		}
 720		return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
 721	}
 722
 723	return nil
 724}
 725
 726func (a *agent) TrackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
 727	sess, err := a.sessions.Get(ctx, sessionID)
 728	if err != nil {
 729		return fmt.Errorf("failed to get session: %w", err)
 730	}
 731
 732	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 733		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 734		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
 735		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
 736
 737	sess.Cost += cost
 738	sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
 739	sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
 740
 741	_, err = a.sessions.Save(ctx, sess)
 742	if err != nil {
 743		return fmt.Errorf("failed to save session: %w", err)
 744	}
 745	return nil
 746}
 747
 748func (a *agent) Summarize(ctx context.Context, sessionID string) error {
 749	if a.summarizeProvider == nil {
 750		return fmt.Errorf("summarize provider not available")
 751	}
 752
 753	// Check if session is busy
 754	if a.IsSessionBusy(sessionID) {
 755		return ErrSessionBusy
 756	}
 757
 758	// Create a new context with cancellation
 759	summarizeCtx, cancel := context.WithCancel(ctx)
 760
 761	// Store the cancel function in activeRequests to allow cancellation
 762	a.activeRequests.Set(sessionID+"-summarize", cancel)
 763
 764	go func() {
 765		defer a.activeRequests.Del(sessionID + "-summarize")
 766		defer cancel()
 767		event := AgentEvent{
 768			Type:     AgentEventTypeSummarize,
 769			Progress: "Starting summarization...",
 770		}
 771
 772		a.Publish(pubsub.CreatedEvent, event)
 773		// Get all messages from the session
 774		msgs, err := a.messages.List(summarizeCtx, sessionID)
 775		if err != nil {
 776			event = AgentEvent{
 777				Type:  AgentEventTypeError,
 778				Error: fmt.Errorf("failed to list messages: %w", err),
 779				Done:  true,
 780			}
 781			a.Publish(pubsub.CreatedEvent, event)
 782			return
 783		}
 784		summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
 785
 786		if len(msgs) == 0 {
 787			event = AgentEvent{
 788				Type:  AgentEventTypeError,
 789				Error: fmt.Errorf("no messages to summarize"),
 790				Done:  true,
 791			}
 792			a.Publish(pubsub.CreatedEvent, event)
 793			return
 794		}
 795
 796		event = AgentEvent{
 797			Type:     AgentEventTypeSummarize,
 798			Progress: "Analyzing conversation...",
 799		}
 800		a.Publish(pubsub.CreatedEvent, event)
 801
 802		// Add a system message to guide the summarization
 803		summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
 804
 805		// Create a new message with the summarize prompt
 806		promptMsg := message.Message{
 807			Role:  message.User,
 808			Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
 809		}
 810
 811		// Append the prompt to the messages
 812		msgsWithPrompt := append(msgs, promptMsg)
 813
 814		event = AgentEvent{
 815			Type:     AgentEventTypeSummarize,
 816			Progress: "Generating summary...",
 817		}
 818
 819		a.Publish(pubsub.CreatedEvent, event)
 820
 821		// Send the messages to the summarize provider
 822		response := a.summarizeProvider.StreamResponse(
 823			summarizeCtx,
 824			msgsWithPrompt,
 825			nil,
 826		)
 827		var finalResponse *provider.ProviderResponse
 828		for r := range response {
 829			if r.Error != nil {
 830				event = AgentEvent{
 831					Type:  AgentEventTypeError,
 832					Error: fmt.Errorf("failed to summarize: %w", err),
 833					Done:  true,
 834				}
 835				a.Publish(pubsub.CreatedEvent, event)
 836				return
 837			}
 838			finalResponse = r.Response
 839		}
 840
 841		summary := strings.TrimSpace(finalResponse.Content)
 842		if summary == "" {
 843			event = AgentEvent{
 844				Type:  AgentEventTypeError,
 845				Error: fmt.Errorf("empty summary returned"),
 846				Done:  true,
 847			}
 848			a.Publish(pubsub.CreatedEvent, event)
 849			return
 850		}
 851		shell := shell.GetPersistentShell(config.Get().WorkingDir())
 852		summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
 853		event = AgentEvent{
 854			Type:     AgentEventTypeSummarize,
 855			Progress: "Creating new session...",
 856		}
 857
 858		a.Publish(pubsub.CreatedEvent, event)
 859		oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
 860		if err != nil {
 861			event = AgentEvent{
 862				Type:  AgentEventTypeError,
 863				Error: fmt.Errorf("failed to get session: %w", err),
 864				Done:  true,
 865			}
 866
 867			a.Publish(pubsub.CreatedEvent, event)
 868			return
 869		}
 870		// Create a message in the new session with the summary
 871		msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
 872			Role: message.Assistant,
 873			Parts: []message.ContentPart{
 874				message.TextContent{Text: summary},
 875				message.Finish{
 876					Reason: message.FinishReasonEndTurn,
 877					Time:   time.Now().Unix(),
 878				},
 879			},
 880			Model:    a.summarizeProvider.Model().ID,
 881			Provider: a.summarizeProviderID,
 882		})
 883		if err != nil {
 884			event = AgentEvent{
 885				Type:  AgentEventTypeError,
 886				Error: fmt.Errorf("failed to create summary message: %w", err),
 887				Done:  true,
 888			}
 889
 890			a.Publish(pubsub.CreatedEvent, event)
 891			return
 892		}
 893		oldSession.SummaryMessageID = msg.ID
 894		oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
 895		oldSession.PromptTokens = 0
 896		model := a.summarizeProvider.Model()
 897		usage := finalResponse.Usage
 898		cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 899			model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 900			model.CostPer1MIn/1e6*float64(usage.InputTokens) +
 901			model.CostPer1MOut/1e6*float64(usage.OutputTokens)
 902		oldSession.Cost += cost
 903		_, err = a.sessions.Save(summarizeCtx, oldSession)
 904		if err != nil {
 905			event = AgentEvent{
 906				Type:  AgentEventTypeError,
 907				Error: fmt.Errorf("failed to save session: %w", err),
 908				Done:  true,
 909			}
 910			a.Publish(pubsub.CreatedEvent, event)
 911		}
 912
 913		event = AgentEvent{
 914			Type:      AgentEventTypeSummarize,
 915			SessionID: oldSession.ID,
 916			Progress:  "Summary complete",
 917			Done:      true,
 918		}
 919		a.Publish(pubsub.CreatedEvent, event)
 920		// Send final success event with the new session ID
 921	}()
 922
 923	return nil
 924}
 925
 926func (a *agent) ClearQueue(sessionID string) {
 927	if a.QueuedPrompts(sessionID) > 0 {
 928		slog.Info("Clearing queued prompts", "session_id", sessionID)
 929		a.promptQueue.Del(sessionID)
 930	}
 931}
 932
 933func (a *agent) CancelAll() {
 934	if !a.IsBusy() {
 935		return
 936	}
 937	for key := range a.activeRequests.Seq2() {
 938		a.Cancel(key) // key is sessionID
 939	}
 940
 941	timeout := time.After(5 * time.Second)
 942	for a.IsBusy() {
 943		select {
 944		case <-timeout:
 945			return
 946		default:
 947			time.Sleep(200 * time.Millisecond)
 948		}
 949	}
 950}
 951
 952func (a *agent) UpdateModel() error {
 953	cfg := config.Get()
 954
 955	// Get current provider configuration
 956	currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
 957	if currentProviderCfg == nil || currentProviderCfg.ID == "" {
 958		return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
 959	}
 960
 961	// Check if provider has changed
 962	if string(currentProviderCfg.ID) != a.providerID {
 963		// Provider changed, need to recreate the main provider
 964		model := cfg.GetModelByType(a.agentCfg.Model)
 965		if model.ID == "" {
 966			return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
 967		}
 968
 969		promptID := agentPromptMap[a.agentCfg.ID]
 970		if promptID == "" {
 971			promptID = prompt.PromptDefault
 972		}
 973
 974		opts := []provider.ProviderClientOption{
 975			provider.WithModel(a.agentCfg.Model),
 976			provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
 977		}
 978
 979		newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
 980		if err != nil {
 981			return fmt.Errorf("failed to create new provider: %w", err)
 982		}
 983
 984		// Update the provider and provider ID
 985		a.provider = newProvider
 986		a.providerID = string(currentProviderCfg.ID)
 987	}
 988
 989	// Check if providers have changed for title (small) and summarize (large)
 990	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
 991	var smallModelProviderCfg config.ProviderConfig
 992	for p := range cfg.Providers.Seq() {
 993		if p.ID == smallModelCfg.Provider {
 994			smallModelProviderCfg = p
 995			break
 996		}
 997	}
 998	if smallModelProviderCfg.ID == "" {
 999		return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
1000	}
1001
1002	largeModelCfg := cfg.Models[config.SelectedModelTypeLarge]
1003	var largeModelProviderCfg config.ProviderConfig
1004	for p := range cfg.Providers.Seq() {
1005		if p.ID == largeModelCfg.Provider {
1006			largeModelProviderCfg = p
1007			break
1008		}
1009	}
1010	if largeModelProviderCfg.ID == "" {
1011		return fmt.Errorf("provider %s not found in config", largeModelCfg.Provider)
1012	}
1013
1014	var maxTitleTokens int64 = 40
1015
1016	// if the max output is too low for the gemini provider it won't return anything
1017	if smallModelCfg.Provider == "gemini" {
1018		maxTitleTokens = 1000
1019	}
1020	// Recreate title provider
1021	titleOpts := []provider.ProviderClientOption{
1022		provider.WithModel(config.SelectedModelTypeSmall),
1023		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
1024		provider.WithMaxTokens(maxTitleTokens),
1025	}
1026	newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
1027	if err != nil {
1028		return fmt.Errorf("failed to create new title provider: %w", err)
1029	}
1030	a.titleProvider = newTitleProvider
1031
1032	// Recreate summarize provider if provider changed (now large model)
1033	if string(largeModelProviderCfg.ID) != a.summarizeProviderID {
1034		largeModel := cfg.GetModelByType(config.SelectedModelTypeLarge)
1035		if largeModel == nil {
1036			return fmt.Errorf("model %s not found in provider %s", largeModelCfg.Model, largeModelProviderCfg.ID)
1037		}
1038		summarizeOpts := []provider.ProviderClientOption{
1039			provider.WithModel(config.SelectedModelTypeLarge),
1040			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, largeModelProviderCfg.ID)),
1041		}
1042		newSummarizeProvider, err := provider.NewProvider(largeModelProviderCfg, summarizeOpts...)
1043		if err != nil {
1044			return fmt.Errorf("failed to create new summarize provider: %w", err)
1045		}
1046		a.summarizeProvider = newSummarizeProvider
1047		a.summarizeProviderID = string(largeModelProviderCfg.ID)
1048	}
1049
1050	return nil
1051}