agent.go

   1package agent
   2
   3import (
   4	"context"
   5	"errors"
   6	"fmt"
   7	"log/slog"
   8	"slices"
   9	"strings"
  10	"time"
  11
  12	"github.com/charmbracelet/catwalk/pkg/catwalk"
  13	"github.com/charmbracelet/crush/internal/config"
  14	"github.com/charmbracelet/crush/internal/csync"
  15	"github.com/charmbracelet/crush/internal/event"
  16	"github.com/charmbracelet/crush/internal/history"
  17	"github.com/charmbracelet/crush/internal/llm/prompt"
  18	"github.com/charmbracelet/crush/internal/llm/provider"
  19	"github.com/charmbracelet/crush/internal/llm/tools"
  20	"github.com/charmbracelet/crush/internal/log"
  21	"github.com/charmbracelet/crush/internal/lsp"
  22	"github.com/charmbracelet/crush/internal/message"
  23	"github.com/charmbracelet/crush/internal/permission"
  24	"github.com/charmbracelet/crush/internal/pubsub"
  25	"github.com/charmbracelet/crush/internal/session"
  26	"github.com/charmbracelet/crush/internal/shell"
  27)
  28
  29type AgentEventType string
  30
  31const (
  32	AgentEventTypeError     AgentEventType = "error"
  33	AgentEventTypeResponse  AgentEventType = "response"
  34	AgentEventTypeSummarize AgentEventType = "summarize"
  35)
  36
// AgentEvent is published to subscribers and sent on the channel returned by
// Run.
//
// Type selects which fields are meaningful:
//   - AgentEventTypeResponse: Message holds the final assistant message and
//     Done is true.
//   - AgentEventTypeError: Error holds the failure. Summarize also sets Done
//     on its error events.
//   - AgentEventTypeSummarize: Progress holds a human-readable status. The
//     final event sets Done and SessionID.
type AgentEvent struct {
	Type    AgentEventType
	Message message.Message
	Error   error

	// Summarization fields (Done is also set on final response events).
	SessionID string
	Progress  string
	Done      bool
}
  47
  48type Service interface {
  49	pubsub.Suscriber[AgentEvent]
  50	Model() catwalk.Model
  51	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
  52	Cancel(sessionID string)
  53	CancelAll()
  54	IsSessionBusy(sessionID string) bool
  55	IsBusy() bool
  56	Summarize(ctx context.Context, sessionID string) error
  57	UpdateModel() error
  58	QueuedPrompts(sessionID string) int
  59	ClearQueue(sessionID string)
  60}
  61
// agent is the Service implementation built by NewAgent.
//
// Notes on fields:
//   - tools is filled lazily by the tool factory NewAgent installs.
//   - agentToolFn, when non-nil, builds the nested "task" agent tool. It is
//     invoked by getAllTools on every call rather than cached in tools.
//   - provider/providerID serve the main conversation. titleProvider serves
//     title generation and summarizeProvider/summarizeProviderID serve
//     Summarize.
//   - activeRequests maps a session ID (or "<sessionID>-summarize" for a
//     running summarization) to the cancel func of the in-flight request.
//   - promptQueue holds prompts submitted via Run while the session was busy.
//   - NOTE(review): NewAgent in this file never sets mcpTools. Its tool
//     factory assigns a package-level mcpTools instead. Confirm whether this
//     field is still used.
type agent struct {
	*pubsub.Broker[AgentEvent]
	agentCfg    config.Agent
	sessions    session.Service
	messages    message.Service
	permissions permission.Service
	mcpTools    []McpTool

	tools *csync.LazySlice[tools.BaseTool]
	// We need this to be able to update it when model changes
	agentToolFn func() (tools.BaseTool, error)

	provider   provider.Provider
	providerID string

	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	activeRequests *csync.Map[string, context.CancelFunc]
	promptQueue    *csync.Map[string, []string]
}
  84
  85var agentPromptMap = map[string]prompt.PromptID{
  86	"coder": prompt.PromptCoder,
  87	"task":  prompt.PromptTask,
  88}
  89
  90func NewAgent(
  91	ctx context.Context,
  92	agentCfg config.Agent,
  93	// These services are needed in the tools
  94	permissions permission.Service,
  95	sessions session.Service,
  96	messages message.Service,
  97	history history.Service,
  98	lspClients *csync.Map[string, *lsp.Client],
  99) (Service, error) {
 100	cfg := config.Get()
 101
 102	var agentToolFn func() (tools.BaseTool, error)
 103	if agentCfg.ID == "coder" && slices.Contains(agentCfg.AllowedTools, AgentToolName) {
 104		agentToolFn = func() (tools.BaseTool, error) {
 105			taskAgentCfg := config.Get().Agents["task"]
 106			if taskAgentCfg.ID == "" {
 107				return nil, fmt.Errorf("task agent not found in config")
 108			}
 109			taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients)
 110			if err != nil {
 111				return nil, fmt.Errorf("failed to create task agent: %w", err)
 112			}
 113			return NewAgentTool(taskAgent, sessions, messages), nil
 114		}
 115	}
 116
 117	providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
 118	if providerCfg == nil {
 119		return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
 120	}
 121	model := config.Get().GetModelByType(agentCfg.Model)
 122
 123	if model == nil {
 124		return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
 125	}
 126
 127	promptID := agentPromptMap[agentCfg.ID]
 128	if promptID == "" {
 129		promptID = prompt.PromptDefault
 130	}
 131	opts := []provider.ProviderClientOption{
 132		provider.WithModel(agentCfg.Model),
 133		provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
 134	}
 135	agentProvider, err := provider.NewProvider(*providerCfg, opts...)
 136	if err != nil {
 137		return nil, err
 138	}
 139
 140	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
 141	var smallModelProviderCfg *config.ProviderConfig
 142	if smallModelCfg.Provider == providerCfg.ID {
 143		smallModelProviderCfg = providerCfg
 144	} else {
 145		smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
 146
 147		if smallModelProviderCfg.ID == "" {
 148			return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
 149		}
 150	}
 151	smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
 152	if smallModel.ID == "" {
 153		return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
 154	}
 155
 156	titleOpts := []provider.ProviderClientOption{
 157		provider.WithModel(config.SelectedModelTypeSmall),
 158		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
 159	}
 160	titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
 161	if err != nil {
 162		return nil, err
 163	}
 164
 165	summarizeOpts := []provider.ProviderClientOption{
 166		provider.WithModel(config.SelectedModelTypeLarge),
 167		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, providerCfg.ID)),
 168	}
 169	summarizeProvider, err := provider.NewProvider(*providerCfg, summarizeOpts...)
 170	if err != nil {
 171		return nil, err
 172	}
 173
 174	toolFn := func() []tools.BaseTool {
 175		slog.Info("Initializing agent tools", "agent", agentCfg.ID)
 176		defer func() {
 177			slog.Info("Initialized agent tools", "agent", agentCfg.ID)
 178		}()
 179
 180		cwd := cfg.WorkingDir()
 181		allTools := []tools.BaseTool{
 182			tools.NewBashTool(permissions, cwd, cfg.Options.Attribution),
 183			tools.NewDownloadTool(permissions, cwd),
 184			tools.NewEditTool(lspClients, permissions, history, cwd),
 185			tools.NewMultiEditTool(lspClients, permissions, history, cwd),
 186			tools.NewFetchTool(permissions, cwd),
 187			tools.NewGlobTool(cwd),
 188			tools.NewGrepTool(cwd),
 189			tools.NewLsTool(permissions, cwd),
 190			tools.NewSourcegraphTool(),
 191			tools.NewViewTool(lspClients, permissions, cwd),
 192			tools.NewWriteTool(lspClients, permissions, history, cwd),
 193		}
 194
 195		mcpToolsOnce.Do(func() {
 196			mcpTools = doGetMCPTools(ctx, permissions, cfg)
 197		})
 198
 199		withCoderTools := func(t []tools.BaseTool) []tools.BaseTool {
 200			if agentCfg.ID == "coder" {
 201				t = append(t, mcpTools...)
 202				if lspClients.Len() > 0 {
 203					t = append(t, tools.NewDiagnosticsTool(lspClients))
 204				}
 205			}
 206			return t
 207		}
 208
 209		if agentCfg.AllowedTools == nil {
 210			return withCoderTools(allTools)
 211		}
 212
 213		var filteredTools []tools.BaseTool
 214		for _, tool := range allTools {
 215			if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
 216				filteredTools = append(filteredTools, tool)
 217			}
 218		}
 219		return withCoderTools(filteredTools)
 220	}
 221
 222	return &agent{
 223		Broker:              pubsub.NewBroker[AgentEvent](),
 224		agentCfg:            agentCfg,
 225		provider:            agentProvider,
 226		providerID:          string(providerCfg.ID),
 227		messages:            messages,
 228		sessions:            sessions,
 229		titleProvider:       titleProvider,
 230		summarizeProvider:   summarizeProvider,
 231		summarizeProviderID: string(providerCfg.ID),
 232		agentToolFn:         agentToolFn,
 233		activeRequests:      csync.NewMap[string, context.CancelFunc](),
 234		tools:               csync.NewLazySlice(toolFn),
 235		promptQueue:         csync.NewMap[string, []string](),
 236		permissions:         permissions,
 237	}, nil
 238}
 239
 240func (a *agent) Model() catwalk.Model {
 241	return *config.Get().GetModelByType(a.agentCfg.Model)
 242}
 243
 244func (a *agent) Cancel(sessionID string) {
 245	// Cancel regular requests
 246	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
 247		slog.Info("Request cancellation initiated", "session_id", sessionID)
 248		cancel()
 249	}
 250
 251	// Also check for summarize requests
 252	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
 253		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
 254		cancel()
 255	}
 256
 257	if a.QueuedPrompts(sessionID) > 0 {
 258		slog.Info("Clearing queued prompts", "session_id", sessionID)
 259		a.promptQueue.Del(sessionID)
 260	}
 261}
 262
 263func (a *agent) IsBusy() bool {
 264	var busy bool
 265	for cancelFunc := range a.activeRequests.Seq() {
 266		if cancelFunc != nil {
 267			busy = true
 268			break
 269		}
 270	}
 271	return busy
 272}
 273
 274func (a *agent) IsSessionBusy(sessionID string) bool {
 275	_, busy := a.activeRequests.Get(sessionID)
 276	return busy
 277}
 278
 279func (a *agent) QueuedPrompts(sessionID string) int {
 280	l, ok := a.promptQueue.Get(sessionID)
 281	if !ok {
 282		return 0
 283	}
 284	return len(l)
 285}
 286
 287func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
 288	if content == "" {
 289		return nil
 290	}
 291	if a.titleProvider == nil {
 292		return nil
 293	}
 294	session, err := a.sessions.Get(ctx, sessionID)
 295	if err != nil {
 296		return err
 297	}
 298	parts := []message.ContentPart{message.TextContent{
 299		Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
 300	}}
 301
 302	// Use streaming approach like summarization
 303	response := a.titleProvider.StreamResponse(
 304		ctx,
 305		[]message.Message{
 306			{
 307				Role:  message.User,
 308				Parts: parts,
 309			},
 310		},
 311		nil,
 312	)
 313
 314	var finalResponse *provider.ProviderResponse
 315	for r := range response {
 316		if r.Error != nil {
 317			return r.Error
 318		}
 319		finalResponse = r.Response
 320	}
 321
 322	if finalResponse == nil {
 323		return fmt.Errorf("no response received from title provider")
 324	}
 325
 326	title := strings.ReplaceAll(finalResponse.Content, "\n", " ")
 327
 328	if idx := strings.Index(title, "</think>"); idx > 0 {
 329		title = title[idx+len("</think>"):]
 330	}
 331
 332	title = strings.TrimSpace(title)
 333	if title == "" {
 334		return nil
 335	}
 336
 337	session.Title = title
 338	_, err = a.sessions.Save(ctx, session)
 339	return err
 340}
 341
 342func (a *agent) err(err error) AgentEvent {
 343	return AgentEvent{
 344		Type:  AgentEventTypeError,
 345		Error: err,
 346	}
 347}
 348
 349func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
 350	if !a.Model().SupportsImages && attachments != nil {
 351		attachments = nil
 352	}
 353	events := make(chan AgentEvent, 1)
 354	if a.IsSessionBusy(sessionID) {
 355		existing, ok := a.promptQueue.Get(sessionID)
 356		if !ok {
 357			existing = []string{}
 358		}
 359		existing = append(existing, content)
 360		a.promptQueue.Set(sessionID, existing)
 361		return nil, nil
 362	}
 363
 364	genCtx, cancel := context.WithCancel(ctx)
 365	a.activeRequests.Set(sessionID, cancel)
 366	startTime := time.Now()
 367
 368	go func() {
 369		slog.Debug("Request started", "sessionID", sessionID)
 370		defer log.RecoverPanic("agent.Run", func() {
 371			events <- a.err(fmt.Errorf("panic while running the agent"))
 372		})
 373		var attachmentParts []message.ContentPart
 374		for _, attachment := range attachments {
 375			attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
 376		}
 377		result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
 378		if result.Error != nil {
 379			if isCancelledErr(result.Error) {
 380				slog.Error("Request canceled", "sessionID", sessionID)
 381			} else {
 382				slog.Error("Request errored", "sessionID", sessionID, "error", result.Error.Error())
 383				event.Error(result.Error)
 384			}
 385		} else {
 386			slog.Debug("Request completed", "sessionID", sessionID)
 387		}
 388		a.eventPromptResponded(sessionID, time.Since(startTime).Truncate(time.Second))
 389		a.activeRequests.Del(sessionID)
 390		cancel()
 391		a.Publish(pubsub.CreatedEvent, result)
 392		events <- result
 393		close(events)
 394	}()
 395	a.eventPromptSent(sessionID)
 396	return events, nil
 397}
 398
 399func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
 400	cfg := config.Get()
 401	// List existing messages; if none, start title generation asynchronously.
 402	msgs, err := a.messages.List(ctx, sessionID)
 403	if err != nil {
 404		return a.err(fmt.Errorf("failed to list messages: %w", err))
 405	}
 406	if len(msgs) == 0 {
 407		go func() {
 408			defer log.RecoverPanic("agent.Run", func() {
 409				slog.Error("panic while generating title")
 410			})
 411			titleErr := a.generateTitle(ctx, sessionID, content)
 412			if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
 413				slog.Error("failed to generate title", "error", titleErr)
 414			}
 415		}()
 416	}
 417	session, err := a.sessions.Get(ctx, sessionID)
 418	if err != nil {
 419		return a.err(fmt.Errorf("failed to get session: %w", err))
 420	}
 421	if session.SummaryMessageID != "" {
 422		summaryMsgInex := -1
 423		for i, msg := range msgs {
 424			if msg.ID == session.SummaryMessageID {
 425				summaryMsgInex = i
 426				break
 427			}
 428		}
 429		if summaryMsgInex != -1 {
 430			msgs = msgs[summaryMsgInex:]
 431			msgs[0].Role = message.User
 432		}
 433	}
 434
 435	userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
 436	if err != nil {
 437		return a.err(fmt.Errorf("failed to create user message: %w", err))
 438	}
 439	// Append the new user message to the conversation history.
 440	msgHistory := append(msgs, userMsg)
 441
 442	for {
 443		// Check for cancellation before each iteration
 444		select {
 445		case <-ctx.Done():
 446			return a.err(ctx.Err())
 447		default:
 448			// Continue processing
 449		}
 450		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
 451		if err != nil {
 452			if errors.Is(err, context.Canceled) {
 453				agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
 454				a.messages.Update(context.Background(), agentMessage)
 455				return a.err(ErrRequestCancelled)
 456			}
 457			return a.err(fmt.Errorf("failed to process events: %w", err))
 458		}
 459		if cfg.Options.Debug {
 460			slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
 461		}
 462		if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
 463			// We are not done, we need to respond with the tool response
 464			msgHistory = append(msgHistory, agentMessage, *toolResults)
 465			// If there are queued prompts, process the next one
 466			nextPrompt, ok := a.promptQueue.Take(sessionID)
 467			if ok {
 468				for _, prompt := range nextPrompt {
 469					// Create a new user message for the queued prompt
 470					userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
 471					if err != nil {
 472						return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
 473					}
 474					// Append the new user message to the conversation history
 475					msgHistory = append(msgHistory, userMsg)
 476				}
 477			}
 478
 479			continue
 480		} else if agentMessage.FinishReason() == message.FinishReasonEndTurn {
 481			queuePrompts, ok := a.promptQueue.Take(sessionID)
 482			if ok {
 483				for _, prompt := range queuePrompts {
 484					if prompt == "" {
 485						continue
 486					}
 487					userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
 488					if err != nil {
 489						return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
 490					}
 491					msgHistory = append(msgHistory, userMsg)
 492				}
 493				continue
 494			}
 495		}
 496		if agentMessage.FinishReason() == "" {
 497			// Kujtim: could not track down where this is happening but this means its cancelled
 498			agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
 499			_ = a.messages.Update(context.Background(), agentMessage)
 500			return a.err(ErrRequestCancelled)
 501		}
 502		return AgentEvent{
 503			Type:    AgentEventTypeResponse,
 504			Message: agentMessage,
 505			Done:    true,
 506		}
 507	}
 508}
 509
 510func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
 511	parts := []message.ContentPart{message.TextContent{Text: content}}
 512	parts = append(parts, attachmentParts...)
 513	return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 514		Role:  message.User,
 515		Parts: parts,
 516	})
 517}
 518
 519func (a *agent) getAllTools() ([]tools.BaseTool, error) {
 520	allTools := slices.Collect(a.tools.Seq())
 521	if a.agentToolFn != nil {
 522		agentTool, agentToolErr := a.agentToolFn()
 523		if agentToolErr != nil {
 524			return nil, agentToolErr
 525		}
 526		allTools = append(allTools, agentTool)
 527	}
 528	return allTools, nil
 529}
 530
 531func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
 532	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
 533
 534	// Create the assistant message first so the spinner shows immediately
 535	assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 536		Role:     message.Assistant,
 537		Parts:    []message.ContentPart{},
 538		Model:    a.Model().ID,
 539		Provider: a.providerID,
 540	})
 541	if err != nil {
 542		return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
 543	}
 544
 545	allTools, toolsErr := a.getAllTools()
 546	if toolsErr != nil {
 547		return assistantMsg, nil, toolsErr
 548	}
 549	// Now collect tools (which may block on MCP initialization)
 550	eventChan := a.provider.StreamResponse(ctx, msgHistory, allTools)
 551
 552	// Add the session and message ID into the context if needed by tools.
 553	ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
 554
 555loop:
 556	for {
 557		select {
 558		case event, ok := <-eventChan:
 559			if !ok {
 560				break loop
 561			}
 562			if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
 563				if errors.Is(processErr, context.Canceled) {
 564					a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 565				} else {
 566					a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
 567				}
 568				return assistantMsg, nil, processErr
 569			}
 570		case <-ctx.Done():
 571			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 572			return assistantMsg, nil, ctx.Err()
 573		}
 574	}
 575
 576	toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
 577	toolCalls := assistantMsg.ToolCalls()
 578	for i, toolCall := range toolCalls {
 579		select {
 580		case <-ctx.Done():
 581			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 582			// Make all future tool calls cancelled
 583			for j := i; j < len(toolCalls); j++ {
 584				toolResults[j] = message.ToolResult{
 585					ToolCallID: toolCalls[j].ID,
 586					Content:    "Tool execution canceled by user",
 587					IsError:    true,
 588				}
 589			}
 590			goto out
 591		default:
 592			// Continue processing
 593			var tool tools.BaseTool
 594			allTools, _ := a.getAllTools()
 595			for _, availableTool := range allTools {
 596				if availableTool.Info().Name == toolCall.Name {
 597					tool = availableTool
 598					break
 599				}
 600			}
 601
 602			// Tool not found
 603			if tool == nil {
 604				toolResults[i] = message.ToolResult{
 605					ToolCallID: toolCall.ID,
 606					Content:    fmt.Sprintf("Tool not found: %s", toolCall.Name),
 607					IsError:    true,
 608				}
 609				continue
 610			}
 611
 612			// Run tool in goroutine to allow cancellation
 613			type toolExecResult struct {
 614				response tools.ToolResponse
 615				err      error
 616			}
 617			resultChan := make(chan toolExecResult, 1)
 618
 619			go func() {
 620				response, err := tool.Run(ctx, tools.ToolCall{
 621					ID:    toolCall.ID,
 622					Name:  toolCall.Name,
 623					Input: toolCall.Input,
 624				})
 625				resultChan <- toolExecResult{response: response, err: err}
 626			}()
 627
 628			var toolResponse tools.ToolResponse
 629			var toolErr error
 630
 631			select {
 632			case <-ctx.Done():
 633				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 634				// Mark remaining tool calls as cancelled
 635				for j := i; j < len(toolCalls); j++ {
 636					toolResults[j] = message.ToolResult{
 637						ToolCallID: toolCalls[j].ID,
 638						Content:    "Tool execution canceled by user",
 639						IsError:    true,
 640					}
 641				}
 642				goto out
 643			case result := <-resultChan:
 644				toolResponse = result.response
 645				toolErr = result.err
 646			}
 647
 648			if toolErr != nil {
 649				slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
 650				if errors.Is(toolErr, permission.ErrorPermissionDenied) {
 651					toolResults[i] = message.ToolResult{
 652						ToolCallID: toolCall.ID,
 653						Content:    "Permission denied",
 654						IsError:    true,
 655					}
 656					for j := i + 1; j < len(toolCalls); j++ {
 657						toolResults[j] = message.ToolResult{
 658							ToolCallID: toolCalls[j].ID,
 659							Content:    "Tool execution canceled by user",
 660							IsError:    true,
 661						}
 662					}
 663					a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
 664					break
 665				}
 666			}
 667			toolResults[i] = message.ToolResult{
 668				ToolCallID: toolCall.ID,
 669				Content:    toolResponse.Content,
 670				Metadata:   toolResponse.Metadata,
 671				IsError:    toolResponse.IsError,
 672			}
 673		}
 674	}
 675out:
 676	if len(toolResults) == 0 {
 677		return assistantMsg, nil, nil
 678	}
 679	parts := make([]message.ContentPart, 0)
 680	for _, tr := range toolResults {
 681		parts = append(parts, tr)
 682	}
 683	msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
 684		Role:     message.Tool,
 685		Parts:    parts,
 686		Provider: a.providerID,
 687	})
 688	if err != nil {
 689		return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
 690	}
 691
 692	return assistantMsg, &msg, err
 693}
 694
 695func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
 696	msg.AddFinish(finishReason, message, details)
 697	_ = a.messages.Update(ctx, *msg)
 698}
 699
 700func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
 701	select {
 702	case <-ctx.Done():
 703		return ctx.Err()
 704	default:
 705		// Continue processing.
 706	}
 707
 708	switch event.Type {
 709	case provider.EventThinkingDelta:
 710		assistantMsg.AppendReasoningContent(event.Thinking)
 711		return a.messages.Update(ctx, *assistantMsg)
 712	case provider.EventSignatureDelta:
 713		assistantMsg.AppendReasoningSignature(event.Signature)
 714		return a.messages.Update(ctx, *assistantMsg)
 715	case provider.EventContentDelta:
 716		assistantMsg.FinishThinking()
 717		assistantMsg.AppendContent(event.Content)
 718		return a.messages.Update(ctx, *assistantMsg)
 719	case provider.EventToolUseStart:
 720		assistantMsg.FinishThinking()
 721		slog.Info("Tool call started", "toolCall", event.ToolCall)
 722		assistantMsg.AddToolCall(*event.ToolCall)
 723		return a.messages.Update(ctx, *assistantMsg)
 724	case provider.EventToolUseDelta:
 725		assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
 726		return a.messages.Update(ctx, *assistantMsg)
 727	case provider.EventToolUseStop:
 728		slog.Info("Finished tool call", "toolCall", event.ToolCall)
 729		assistantMsg.FinishToolCall(event.ToolCall.ID)
 730		return a.messages.Update(ctx, *assistantMsg)
 731	case provider.EventError:
 732		return event.Error
 733	case provider.EventComplete:
 734		assistantMsg.FinishThinking()
 735		assistantMsg.SetToolCalls(event.Response.ToolCalls)
 736		assistantMsg.AddFinish(event.Response.FinishReason, "", "")
 737		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
 738			return fmt.Errorf("failed to update message: %w", err)
 739		}
 740		return a.trackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
 741	}
 742
 743	return nil
 744}
 745
 746func (a *agent) trackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
 747	sess, err := a.sessions.Get(ctx, sessionID)
 748	if err != nil {
 749		return fmt.Errorf("failed to get session: %w", err)
 750	}
 751
 752	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 753		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 754		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
 755		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
 756
 757	a.eventTokensUsed(sessionID, usage, cost)
 758
 759	sess.Cost += cost
 760	sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
 761	sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
 762
 763	_, err = a.sessions.Save(ctx, sess)
 764	if err != nil {
 765		return fmt.Errorf("failed to save session: %w", err)
 766	}
 767	return nil
 768}
 769
 770func (a *agent) Summarize(ctx context.Context, sessionID string) error {
 771	if a.summarizeProvider == nil {
 772		return fmt.Errorf("summarize provider not available")
 773	}
 774
 775	// Check if session is busy
 776	if a.IsSessionBusy(sessionID) {
 777		return ErrSessionBusy
 778	}
 779
 780	// Create a new context with cancellation
 781	summarizeCtx, cancel := context.WithCancel(ctx)
 782
 783	// Store the cancel function in activeRequests to allow cancellation
 784	a.activeRequests.Set(sessionID+"-summarize", cancel)
 785
 786	go func() {
 787		defer a.activeRequests.Del(sessionID + "-summarize")
 788		defer cancel()
 789		event := AgentEvent{
 790			Type:     AgentEventTypeSummarize,
 791			Progress: "Starting summarization...",
 792		}
 793
 794		a.Publish(pubsub.CreatedEvent, event)
 795		// Get all messages from the session
 796		msgs, err := a.messages.List(summarizeCtx, sessionID)
 797		if err != nil {
 798			event = AgentEvent{
 799				Type:  AgentEventTypeError,
 800				Error: fmt.Errorf("failed to list messages: %w", err),
 801				Done:  true,
 802			}
 803			a.Publish(pubsub.CreatedEvent, event)
 804			return
 805		}
 806		summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
 807
 808		if len(msgs) == 0 {
 809			event = AgentEvent{
 810				Type:  AgentEventTypeError,
 811				Error: fmt.Errorf("no messages to summarize"),
 812				Done:  true,
 813			}
 814			a.Publish(pubsub.CreatedEvent, event)
 815			return
 816		}
 817
 818		event = AgentEvent{
 819			Type:     AgentEventTypeSummarize,
 820			Progress: "Analyzing conversation...",
 821		}
 822		a.Publish(pubsub.CreatedEvent, event)
 823
 824		// Add a system message to guide the summarization
 825		summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
 826
 827		// Create a new message with the summarize prompt
 828		promptMsg := message.Message{
 829			Role:  message.User,
 830			Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
 831		}
 832
 833		// Append the prompt to the messages
 834		msgsWithPrompt := append(msgs, promptMsg)
 835
 836		event = AgentEvent{
 837			Type:     AgentEventTypeSummarize,
 838			Progress: "Generating summary...",
 839		}
 840
 841		a.Publish(pubsub.CreatedEvent, event)
 842
 843		// Send the messages to the summarize provider
 844		response := a.summarizeProvider.StreamResponse(
 845			summarizeCtx,
 846			msgsWithPrompt,
 847			nil,
 848		)
 849		var finalResponse *provider.ProviderResponse
 850		for r := range response {
 851			if r.Error != nil {
 852				event = AgentEvent{
 853					Type:  AgentEventTypeError,
 854					Error: fmt.Errorf("failed to summarize: %w", r.Error),
 855					Done:  true,
 856				}
 857				a.Publish(pubsub.CreatedEvent, event)
 858				return
 859			}
 860			finalResponse = r.Response
 861		}
 862
 863		summary := strings.TrimSpace(finalResponse.Content)
 864		if summary == "" {
 865			event = AgentEvent{
 866				Type:  AgentEventTypeError,
 867				Error: fmt.Errorf("empty summary returned"),
 868				Done:  true,
 869			}
 870			a.Publish(pubsub.CreatedEvent, event)
 871			return
 872		}
 873		shell := shell.GetPersistentShell(config.Get().WorkingDir())
 874		summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
 875		event = AgentEvent{
 876			Type:     AgentEventTypeSummarize,
 877			Progress: "Creating new session...",
 878		}
 879
 880		a.Publish(pubsub.CreatedEvent, event)
 881		oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
 882		if err != nil {
 883			event = AgentEvent{
 884				Type:  AgentEventTypeError,
 885				Error: fmt.Errorf("failed to get session: %w", err),
 886				Done:  true,
 887			}
 888
 889			a.Publish(pubsub.CreatedEvent, event)
 890			return
 891		}
 892		// Create a message in the new session with the summary
 893		msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
 894			Role: message.Assistant,
 895			Parts: []message.ContentPart{
 896				message.TextContent{Text: summary},
 897				message.Finish{
 898					Reason: message.FinishReasonEndTurn,
 899					Time:   time.Now().Unix(),
 900				},
 901			},
 902			Model:    a.summarizeProvider.Model().ID,
 903			Provider: a.summarizeProviderID,
 904		})
 905		if err != nil {
 906			event = AgentEvent{
 907				Type:  AgentEventTypeError,
 908				Error: fmt.Errorf("failed to create summary message: %w", err),
 909				Done:  true,
 910			}
 911
 912			a.Publish(pubsub.CreatedEvent, event)
 913			return
 914		}
 915		oldSession.SummaryMessageID = msg.ID
 916		oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
 917		oldSession.PromptTokens = 0
 918		model := a.summarizeProvider.Model()
 919		usage := finalResponse.Usage
 920		cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 921			model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 922			model.CostPer1MIn/1e6*float64(usage.InputTokens) +
 923			model.CostPer1MOut/1e6*float64(usage.OutputTokens)
 924		oldSession.Cost += cost
 925		_, err = a.sessions.Save(summarizeCtx, oldSession)
 926		if err != nil {
 927			event = AgentEvent{
 928				Type:  AgentEventTypeError,
 929				Error: fmt.Errorf("failed to save session: %w", err),
 930				Done:  true,
 931			}
 932			a.Publish(pubsub.CreatedEvent, event)
 933		}
 934
 935		event = AgentEvent{
 936			Type:      AgentEventTypeSummarize,
 937			SessionID: oldSession.ID,
 938			Progress:  "Summary complete",
 939			Done:      true,
 940		}
 941		a.Publish(pubsub.CreatedEvent, event)
 942		// Send final success event with the new session ID
 943	}()
 944
 945	return nil
 946}
 947
 948func (a *agent) ClearQueue(sessionID string) {
 949	if a.QueuedPrompts(sessionID) > 0 {
 950		slog.Info("Clearing queued prompts", "session_id", sessionID)
 951		a.promptQueue.Del(sessionID)
 952	}
 953}
 954
 955func (a *agent) CancelAll() {
 956	if !a.IsBusy() {
 957		return
 958	}
 959	for key := range a.activeRequests.Seq2() {
 960		a.Cancel(key) // key is sessionID
 961	}
 962
 963	timeout := time.After(5 * time.Second)
 964	for a.IsBusy() {
 965		select {
 966		case <-timeout:
 967			return
 968		default:
 969			time.Sleep(200 * time.Millisecond)
 970		}
 971	}
 972}
 973
 974func (a *agent) UpdateModel() error {
 975	cfg := config.Get()
 976
 977	// Get current provider configuration
 978	currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
 979	if currentProviderCfg == nil || currentProviderCfg.ID == "" {
 980		return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
 981	}
 982
 983	// Check if provider has changed
 984	if string(currentProviderCfg.ID) != a.providerID {
 985		// Provider changed, need to recreate the main provider
 986		model := cfg.GetModelByType(a.agentCfg.Model)
 987		if model.ID == "" {
 988			return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
 989		}
 990
 991		promptID := agentPromptMap[a.agentCfg.ID]
 992		if promptID == "" {
 993			promptID = prompt.PromptDefault
 994		}
 995
 996		opts := []provider.ProviderClientOption{
 997			provider.WithModel(a.agentCfg.Model),
 998			provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
 999		}
1000
1001		newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
1002		if err != nil {
1003			return fmt.Errorf("failed to create new provider: %w", err)
1004		}
1005
1006		// Update the provider and provider ID
1007		a.provider = newProvider
1008		a.providerID = string(currentProviderCfg.ID)
1009	}
1010
1011	// Check if providers have changed for title (small) and summarize (large)
1012	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
1013	var smallModelProviderCfg config.ProviderConfig
1014	for p := range cfg.Providers.Seq() {
1015		if p.ID == smallModelCfg.Provider {
1016			smallModelProviderCfg = p
1017			break
1018		}
1019	}
1020	if smallModelProviderCfg.ID == "" {
1021		return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
1022	}
1023
1024	largeModelCfg := cfg.Models[config.SelectedModelTypeLarge]
1025	var largeModelProviderCfg config.ProviderConfig
1026	for p := range cfg.Providers.Seq() {
1027		if p.ID == largeModelCfg.Provider {
1028			largeModelProviderCfg = p
1029			break
1030		}
1031	}
1032	if largeModelProviderCfg.ID == "" {
1033		return fmt.Errorf("provider %s not found in config", largeModelCfg.Provider)
1034	}
1035
1036	var maxTitleTokens int64 = 40
1037
1038	// if the max output is too low for the gemini provider it won't return anything
1039	if smallModelCfg.Provider == "gemini" {
1040		maxTitleTokens = 1000
1041	}
1042	// Recreate title provider
1043	titleOpts := []provider.ProviderClientOption{
1044		provider.WithModel(config.SelectedModelTypeSmall),
1045		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
1046		provider.WithMaxTokens(maxTitleTokens),
1047	}
1048	newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
1049	if err != nil {
1050		return fmt.Errorf("failed to create new title provider: %w", err)
1051	}
1052	a.titleProvider = newTitleProvider
1053
1054	// Recreate summarize provider if provider changed (now large model)
1055	if string(largeModelProviderCfg.ID) != a.summarizeProviderID {
1056		largeModel := cfg.GetModelByType(config.SelectedModelTypeLarge)
1057		if largeModel == nil {
1058			return fmt.Errorf("model %s not found in provider %s", largeModelCfg.Model, largeModelProviderCfg.ID)
1059		}
1060		summarizeOpts := []provider.ProviderClientOption{
1061			provider.WithModel(config.SelectedModelTypeLarge),
1062			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, largeModelProviderCfg.ID)),
1063		}
1064		newSummarizeProvider, err := provider.NewProvider(largeModelProviderCfg, summarizeOpts...)
1065		if err != nil {
1066			return fmt.Errorf("failed to create new summarize provider: %w", err)
1067		}
1068		a.summarizeProvider = newSummarizeProvider
1069		a.summarizeProviderID = string(largeModelProviderCfg.ID)
1070	}
1071
1072	return nil
1073}