agent.go

   1package agent
   2
   3import (
   4	"context"
   5	"errors"
   6	"fmt"
   7	"log/slog"
   8	"slices"
   9	"strings"
  10	"time"
  11
  12	"github.com/charmbracelet/catwalk/pkg/catwalk"
  13	"github.com/charmbracelet/crush/internal/config"
  14	"github.com/charmbracelet/crush/internal/csync"
  15	"github.com/charmbracelet/crush/internal/history"
  16	"github.com/charmbracelet/crush/internal/llm/prompt"
  17	"github.com/charmbracelet/crush/internal/llm/provider"
  18	"github.com/charmbracelet/crush/internal/llm/tools"
  19	"github.com/charmbracelet/crush/internal/log"
  20	"github.com/charmbracelet/crush/internal/lsp"
  21	"github.com/charmbracelet/crush/internal/message"
  22	"github.com/charmbracelet/crush/internal/permission"
  23	"github.com/charmbracelet/crush/internal/pubsub"
  24	"github.com/charmbracelet/crush/internal/session"
  25	"github.com/charmbracelet/crush/internal/shell"
  26)
  27
  28// Common errors
  29var (
  30	ErrRequestCancelled = errors.New("request canceled by user")
  31	ErrSessionBusy      = errors.New("session is currently processing another request")
  32)
  33
  34type AgentEventType string
  35
  36const (
  37	AgentEventTypeError     AgentEventType = "error"
  38	AgentEventTypeResponse  AgentEventType = "response"
  39	AgentEventTypeSummarize AgentEventType = "summarize"
  40)
  41
  42type AgentEvent struct {
  43	Type    AgentEventType
  44	Message message.Message
  45	Error   error
  46
  47	// When summarizing
  48	SessionID string
  49	Progress  string
  50	Done      bool
  51}
  52
  53type Service interface {
  54	pubsub.Suscriber[AgentEvent]
  55	Model() catwalk.Model
  56	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
  57	Cancel(sessionID string)
  58	CancelAll()
  59	IsSessionBusy(sessionID string) bool
  60	IsBusy() bool
  61	Summarize(ctx context.Context, sessionID string) error
  62	UpdateModel() error
  63	QueuedPrompts(sessionID string) int
  64	ClearQueue(sessionID string)
  65}
  66
  67type agent struct {
  68	*pubsub.Broker[AgentEvent]
  69	agentCfg config.Agent
  70	sessions session.Service
  71	messages message.Service
  72	mcpTools []McpTool
  73
  74	tools *csync.LazySlice[tools.BaseTool]
  75
  76	provider   provider.Provider
  77	providerID string
  78
  79	titleProvider       provider.Provider
  80	summarizeProvider   provider.Provider
  81	summarizeProviderID string
  82
  83	activeRequests *csync.Map[string, context.CancelFunc]
  84
  85	promptQueue *csync.Map[string, []string]
  86}
  87
  88var agentPromptMap = map[string]prompt.PromptID{
  89	"coder": prompt.PromptCoder,
  90	"task":  prompt.PromptTask,
  91}
  92
  93func NewAgent(
  94	ctx context.Context,
  95	agentCfg config.Agent,
  96	// These services are needed in the tools
  97	permissions permission.Service,
  98	sessions session.Service,
  99	messages message.Service,
 100	history history.Service,
 101	lspClients map[string]*lsp.Client,
 102) (Service, error) {
 103	cfg := config.Get()
 104
 105	var agentTool tools.BaseTool
 106	if agentCfg.ID == "coder" {
 107		taskAgentCfg := config.Get().Agents["task"]
 108		if taskAgentCfg.ID == "" {
 109			return nil, fmt.Errorf("task agent not found in config")
 110		}
 111		taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients)
 112		if err != nil {
 113			return nil, fmt.Errorf("failed to create task agent: %w", err)
 114		}
 115
 116		agentTool = NewAgentTool(taskAgent, sessions, messages)
 117	}
 118
 119	providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
 120	if providerCfg == nil {
 121		return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
 122	}
 123	model := config.Get().GetModelByType(agentCfg.Model)
 124
 125	if model == nil {
 126		return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
 127	}
 128
 129	promptID := agentPromptMap[agentCfg.ID]
 130	if promptID == "" {
 131		promptID = prompt.PromptDefault
 132	}
 133	opts := []provider.ProviderClientOption{
 134		provider.WithModel(agentCfg.Model),
 135		provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
 136	}
 137	agentProvider, err := provider.NewProvider(*providerCfg, opts...)
 138	if err != nil {
 139		return nil, err
 140	}
 141
 142	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
 143	var smallModelProviderCfg *config.ProviderConfig
 144	if smallModelCfg.Provider == providerCfg.ID {
 145		smallModelProviderCfg = providerCfg
 146	} else {
 147		smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
 148
 149		if smallModelProviderCfg.ID == "" {
 150			return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
 151		}
 152	}
 153	smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
 154	if smallModel.ID == "" {
 155		return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
 156	}
 157
 158	titleOpts := []provider.ProviderClientOption{
 159		provider.WithModel(config.SelectedModelTypeSmall),
 160		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
 161	}
 162	titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
 163	if err != nil {
 164		return nil, err
 165	}
 166
 167	summarizeOpts := []provider.ProviderClientOption{
 168		provider.WithModel(config.SelectedModelTypeLarge),
 169		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, providerCfg.ID)),
 170	}
 171	summarizeProvider, err := provider.NewProvider(*providerCfg, summarizeOpts...)
 172	if err != nil {
 173		return nil, err
 174	}
 175
 176	toolFn := func() []tools.BaseTool {
 177		slog.Info("Initializing agent tools", "agent", agentCfg.ID)
 178		defer func() {
 179			slog.Info("Initialized agent tools", "agent", agentCfg.ID)
 180		}()
 181
 182		cwd := cfg.WorkingDir()
 183		allTools := []tools.BaseTool{
 184			tools.NewBashTool(permissions, cwd),
 185			tools.NewDownloadTool(permissions, cwd),
 186			tools.NewEditTool(lspClients, permissions, history, cwd),
 187			tools.NewMultiEditTool(lspClients, permissions, history, cwd),
 188			tools.NewFetchTool(permissions, cwd),
 189			tools.NewGlobTool(cwd),
 190			tools.NewGrepTool(cwd),
 191			tools.NewLsTool(permissions, cwd),
 192			tools.NewSourcegraphTool(),
 193			tools.NewViewTool(lspClients, permissions, cwd),
 194			tools.NewWriteTool(lspClients, permissions, history, cwd),
 195		}
 196
 197		mcpToolsOnce.Do(func() {
 198			mcpTools = doGetMCPTools(ctx, permissions, cfg)
 199		})
 200		allTools = append(allTools, mcpTools...)
 201
 202		if len(lspClients) > 0 {
 203			allTools = append(allTools, tools.NewDiagnosticsTool(lspClients))
 204		}
 205
 206		if agentTool != nil {
 207			allTools = append(allTools, agentTool)
 208		}
 209
 210		if agentCfg.AllowedTools == nil {
 211			return allTools
 212		}
 213
 214		var filteredTools []tools.BaseTool
 215		for _, tool := range allTools {
 216			if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
 217				filteredTools = append(filteredTools, tool)
 218			}
 219		}
 220		return filteredTools
 221	}
 222
 223	return &agent{
 224		Broker:              pubsub.NewBroker[AgentEvent](),
 225		agentCfg:            agentCfg,
 226		provider:            agentProvider,
 227		providerID:          string(providerCfg.ID),
 228		messages:            messages,
 229		sessions:            sessions,
 230		titleProvider:       titleProvider,
 231		summarizeProvider:   summarizeProvider,
 232		summarizeProviderID: string(providerCfg.ID),
 233		activeRequests:      csync.NewMap[string, context.CancelFunc](),
 234		tools:               csync.NewLazySlice(toolFn),
 235		promptQueue:         csync.NewMap[string, []string](),
 236	}, nil
 237}
 238
 239func (a *agent) Model() catwalk.Model {
 240	return *config.Get().GetModelByType(a.agentCfg.Model)
 241}
 242
 243func (a *agent) Cancel(sessionID string) {
 244	// Cancel regular requests
 245	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
 246		slog.Info("Request cancellation initiated", "session_id", sessionID)
 247		cancel()
 248	}
 249
 250	// Also check for summarize requests
 251	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
 252		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
 253		cancel()
 254	}
 255
 256	if a.QueuedPrompts(sessionID) > 0 {
 257		slog.Info("Clearing queued prompts", "session_id", sessionID)
 258		a.promptQueue.Del(sessionID)
 259	}
 260}
 261
 262func (a *agent) IsBusy() bool {
 263	var busy bool
 264	for cancelFunc := range a.activeRequests.Seq() {
 265		if cancelFunc != nil {
 266			busy = true
 267			break
 268		}
 269	}
 270	return busy
 271}
 272
 273func (a *agent) IsSessionBusy(sessionID string) bool {
 274	_, busy := a.activeRequests.Get(sessionID)
 275	return busy
 276}
 277
 278func (a *agent) QueuedPrompts(sessionID string) int {
 279	l, ok := a.promptQueue.Get(sessionID)
 280	if !ok {
 281		return 0
 282	}
 283	return len(l)
 284}
 285
 286func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
 287	if content == "" {
 288		return nil
 289	}
 290	if a.titleProvider == nil {
 291		return nil
 292	}
 293	session, err := a.sessions.Get(ctx, sessionID)
 294	if err != nil {
 295		return err
 296	}
 297	parts := []message.ContentPart{message.TextContent{
 298		Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
 299	}}
 300
 301	// Use streaming approach like summarization
 302	response := a.titleProvider.StreamResponse(
 303		ctx,
 304		[]message.Message{
 305			{
 306				Role:  message.User,
 307				Parts: parts,
 308			},
 309		},
 310		nil,
 311	)
 312
 313	var finalResponse *provider.ProviderResponse
 314	for r := range response {
 315		if r.Error != nil {
 316			return r.Error
 317		}
 318		finalResponse = r.Response
 319	}
 320
 321	if finalResponse == nil {
 322		return fmt.Errorf("no response received from title provider")
 323	}
 324
 325	title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
 326	if title == "" {
 327		return nil
 328	}
 329
 330	session.Title = title
 331	_, err = a.sessions.Save(ctx, session)
 332	return err
 333}
 334
 335func (a *agent) err(err error) AgentEvent {
 336	return AgentEvent{
 337		Type:  AgentEventTypeError,
 338		Error: err,
 339	}
 340}
 341
 342func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
 343	if !a.Model().SupportsImages && attachments != nil {
 344		attachments = nil
 345	}
 346	events := make(chan AgentEvent)
 347	if a.IsSessionBusy(sessionID) {
 348		existing, ok := a.promptQueue.Get(sessionID)
 349		if !ok {
 350			existing = []string{}
 351		}
 352		existing = append(existing, content)
 353		a.promptQueue.Set(sessionID, existing)
 354		return nil, nil
 355	}
 356
 357	genCtx, cancel := context.WithCancel(ctx)
 358
 359	a.activeRequests.Set(sessionID, cancel)
 360	go func() {
 361		slog.Debug("Request started", "sessionID", sessionID)
 362		defer log.RecoverPanic("agent.Run", func() {
 363			events <- a.err(fmt.Errorf("panic while running the agent"))
 364		})
 365		var attachmentParts []message.ContentPart
 366		for _, attachment := range attachments {
 367			attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
 368		}
 369		result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
 370		if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
 371			slog.Error(result.Error.Error())
 372		}
 373		slog.Debug("Request completed", "sessionID", sessionID)
 374		a.activeRequests.Del(sessionID)
 375		cancel()
 376		a.Publish(pubsub.CreatedEvent, result)
 377		select {
 378		case events <- result:
 379		case <-genCtx.Done():
 380		}
 381		close(events)
 382	}()
 383	return events, nil
 384}
 385
 386func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
 387	cfg := config.Get()
 388	// List existing messages; if none, start title generation asynchronously.
 389	msgs, err := a.messages.List(ctx, sessionID)
 390	if err != nil {
 391		return a.err(fmt.Errorf("failed to list messages: %w", err))
 392	}
 393	if len(msgs) == 0 {
 394		go func() {
 395			defer log.RecoverPanic("agent.Run", func() {
 396				slog.Error("panic while generating title")
 397			})
 398			titleErr := a.generateTitle(ctx, sessionID, content)
 399			if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
 400				slog.Error("failed to generate title", "error", titleErr)
 401			}
 402		}()
 403	}
 404	session, err := a.sessions.Get(ctx, sessionID)
 405	if err != nil {
 406		return a.err(fmt.Errorf("failed to get session: %w", err))
 407	}
 408	if session.SummaryMessageID != "" {
 409		summaryMsgInex := -1
 410		for i, msg := range msgs {
 411			if msg.ID == session.SummaryMessageID {
 412				summaryMsgInex = i
 413				break
 414			}
 415		}
 416		if summaryMsgInex != -1 {
 417			msgs = msgs[summaryMsgInex:]
 418			msgs[0].Role = message.User
 419		}
 420	}
 421
 422	userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
 423	if err != nil {
 424		return a.err(fmt.Errorf("failed to create user message: %w", err))
 425	}
 426	// Append the new user message to the conversation history.
 427	msgHistory := append(msgs, userMsg)
 428
 429	for {
 430		// Check for cancellation before each iteration
 431		select {
 432		case <-ctx.Done():
 433			return a.err(ctx.Err())
 434		default:
 435			// Continue processing
 436		}
 437		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
 438		if err != nil {
 439			if errors.Is(err, context.Canceled) {
 440				agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
 441				a.messages.Update(context.Background(), agentMessage)
 442				return a.err(ErrRequestCancelled)
 443			}
 444			return a.err(fmt.Errorf("failed to process events: %w", err))
 445		}
 446		if cfg.Options.Debug {
 447			slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
 448		}
 449		if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
 450			// We are not done, we need to respond with the tool response
 451			msgHistory = append(msgHistory, agentMessage, *toolResults)
 452			// If there are queued prompts, process the next one
 453			nextPrompt, ok := a.promptQueue.Take(sessionID)
 454			if ok {
 455				for _, prompt := range nextPrompt {
 456					// Create a new user message for the queued prompt
 457					userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
 458					if err != nil {
 459						return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
 460					}
 461					// Append the new user message to the conversation history
 462					msgHistory = append(msgHistory, userMsg)
 463				}
 464			}
 465
 466			continue
 467		} else if agentMessage.FinishReason() == message.FinishReasonEndTurn {
 468			queuePrompts, ok := a.promptQueue.Take(sessionID)
 469			if ok {
 470				for _, prompt := range queuePrompts {
 471					if prompt == "" {
 472						continue
 473					}
 474					userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
 475					if err != nil {
 476						return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
 477					}
 478					msgHistory = append(msgHistory, userMsg)
 479				}
 480				continue
 481			}
 482		}
 483		if agentMessage.FinishReason() == "" {
 484			// Kujtim: could not track down where this is happening but this means its cancelled
 485			agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
 486			_ = a.messages.Update(context.Background(), agentMessage)
 487			return a.err(ErrRequestCancelled)
 488		}
 489		return AgentEvent{
 490			Type:    AgentEventTypeResponse,
 491			Message: agentMessage,
 492			Done:    true,
 493		}
 494	}
 495}
 496
 497func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
 498	parts := []message.ContentPart{message.TextContent{Text: content}}
 499	parts = append(parts, attachmentParts...)
 500	return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 501		Role:  message.User,
 502		Parts: parts,
 503	})
 504}
 505
 506func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
 507	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
 508
 509	// Create the assistant message first so the spinner shows immediately
 510	assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 511		Role:     message.Assistant,
 512		Parts:    []message.ContentPart{},
 513		Model:    a.Model().ID,
 514		Provider: a.providerID,
 515	})
 516	if err != nil {
 517		return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
 518	}
 519
 520	// Now collect tools (which may block on MCP initialization)
 521	eventChan := a.provider.StreamResponse(ctx, msgHistory, slices.Collect(a.tools.Seq()))
 522
 523	// Add the session and message ID into the context if needed by tools.
 524	ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
 525
 526	// Process each event in the stream.
 527	for event := range eventChan {
 528		if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
 529			if errors.Is(processErr, context.Canceled) {
 530				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 531			} else {
 532				a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
 533			}
 534			return assistantMsg, nil, processErr
 535		}
 536		if ctx.Err() != nil {
 537			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 538			return assistantMsg, nil, ctx.Err()
 539		}
 540	}
 541
 542	toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
 543	toolCalls := assistantMsg.ToolCalls()
 544	for i, toolCall := range toolCalls {
 545		select {
 546		case <-ctx.Done():
 547			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 548			// Make all future tool calls cancelled
 549			for j := i; j < len(toolCalls); j++ {
 550				toolResults[j] = message.ToolResult{
 551					ToolCallID: toolCalls[j].ID,
 552					Content:    "Tool execution canceled by user",
 553					IsError:    true,
 554				}
 555			}
 556			goto out
 557		default:
 558			// Continue processing
 559			var tool tools.BaseTool
 560			for availableTool := range a.tools.Seq() {
 561				if availableTool.Info().Name == toolCall.Name {
 562					tool = availableTool
 563					break
 564				}
 565			}
 566
 567			// Tool not found
 568			if tool == nil {
 569				toolResults[i] = message.ToolResult{
 570					ToolCallID: toolCall.ID,
 571					Content:    fmt.Sprintf("Tool not found: %s", toolCall.Name),
 572					IsError:    true,
 573				}
 574				continue
 575			}
 576
 577			// Run tool in goroutine to allow cancellation
 578			type toolExecResult struct {
 579				response tools.ToolResponse
 580				err      error
 581			}
 582			resultChan := make(chan toolExecResult, 1)
 583
 584			go func() {
 585				response, err := tool.Run(ctx, tools.ToolCall{
 586					ID:    toolCall.ID,
 587					Name:  toolCall.Name,
 588					Input: toolCall.Input,
 589				})
 590				resultChan <- toolExecResult{response: response, err: err}
 591			}()
 592
 593			var toolResponse tools.ToolResponse
 594			var toolErr error
 595
 596			select {
 597			case <-ctx.Done():
 598				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 599				// Mark remaining tool calls as cancelled
 600				for j := i; j < len(toolCalls); j++ {
 601					toolResults[j] = message.ToolResult{
 602						ToolCallID: toolCalls[j].ID,
 603						Content:    "Tool execution canceled by user",
 604						IsError:    true,
 605					}
 606				}
 607				goto out
 608			case result := <-resultChan:
 609				toolResponse = result.response
 610				toolErr = result.err
 611			}
 612
 613			if toolErr != nil {
 614				slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
 615				if errors.Is(toolErr, permission.ErrorPermissionDenied) {
 616					toolResults[i] = message.ToolResult{
 617						ToolCallID: toolCall.ID,
 618						Content:    "Permission denied",
 619						IsError:    true,
 620					}
 621					for j := i + 1; j < len(toolCalls); j++ {
 622						toolResults[j] = message.ToolResult{
 623							ToolCallID: toolCalls[j].ID,
 624							Content:    "Tool execution canceled by user",
 625							IsError:    true,
 626						}
 627					}
 628					a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
 629					break
 630				}
 631			}
 632			toolResults[i] = message.ToolResult{
 633				ToolCallID: toolCall.ID,
 634				Content:    toolResponse.Content,
 635				Metadata:   toolResponse.Metadata,
 636				IsError:    toolResponse.IsError,
 637			}
 638		}
 639	}
 640out:
 641	if len(toolResults) == 0 {
 642		return assistantMsg, nil, nil
 643	}
 644	parts := make([]message.ContentPart, 0)
 645	for _, tr := range toolResults {
 646		parts = append(parts, tr)
 647	}
 648	msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
 649		Role:     message.Tool,
 650		Parts:    parts,
 651		Provider: a.providerID,
 652	})
 653	if err != nil {
 654		return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
 655	}
 656
 657	return assistantMsg, &msg, err
 658}
 659
 660func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
 661	msg.AddFinish(finishReason, message, details)
 662	_ = a.messages.Update(ctx, *msg)
 663}
 664
 665func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
 666	select {
 667	case <-ctx.Done():
 668		return ctx.Err()
 669	default:
 670		// Continue processing.
 671	}
 672
 673	switch event.Type {
 674	case provider.EventThinkingDelta:
 675		assistantMsg.AppendReasoningContent(event.Thinking)
 676		return a.messages.Update(ctx, *assistantMsg)
 677	case provider.EventSignatureDelta:
 678		assistantMsg.AppendReasoningSignature(event.Signature)
 679		return a.messages.Update(ctx, *assistantMsg)
 680	case provider.EventContentDelta:
 681		assistantMsg.FinishThinking()
 682		assistantMsg.AppendContent(event.Content)
 683		return a.messages.Update(ctx, *assistantMsg)
 684	case provider.EventToolUseStart:
 685		assistantMsg.FinishThinking()
 686		slog.Info("Tool call started", "toolCall", event.ToolCall)
 687		assistantMsg.AddToolCall(*event.ToolCall)
 688		return a.messages.Update(ctx, *assistantMsg)
 689	case provider.EventToolUseDelta:
 690		assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
 691		return a.messages.Update(ctx, *assistantMsg)
 692	case provider.EventToolUseStop:
 693		slog.Info("Finished tool call", "toolCall", event.ToolCall)
 694		assistantMsg.FinishToolCall(event.ToolCall.ID)
 695		return a.messages.Update(ctx, *assistantMsg)
 696	case provider.EventError:
 697		return event.Error
 698	case provider.EventComplete:
 699		assistantMsg.FinishThinking()
 700		assistantMsg.SetToolCalls(event.Response.ToolCalls)
 701		assistantMsg.AddFinish(event.Response.FinishReason, "", "")
 702		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
 703			return fmt.Errorf("failed to update message: %w", err)
 704		}
 705		return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
 706	}
 707
 708	return nil
 709}
 710
 711func (a *agent) TrackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
 712	sess, err := a.sessions.Get(ctx, sessionID)
 713	if err != nil {
 714		return fmt.Errorf("failed to get session: %w", err)
 715	}
 716
 717	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 718		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 719		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
 720		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
 721
 722	sess.Cost += cost
 723	sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
 724	sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
 725
 726	_, err = a.sessions.Save(ctx, sess)
 727	if err != nil {
 728		return fmt.Errorf("failed to save session: %w", err)
 729	}
 730	return nil
 731}
 732
 733func (a *agent) Summarize(ctx context.Context, sessionID string) error {
 734	if a.summarizeProvider == nil {
 735		return fmt.Errorf("summarize provider not available")
 736	}
 737
 738	// Check if session is busy
 739	if a.IsSessionBusy(sessionID) {
 740		return ErrSessionBusy
 741	}
 742
 743	// Create a new context with cancellation
 744	summarizeCtx, cancel := context.WithCancel(ctx)
 745
 746	// Store the cancel function in activeRequests to allow cancellation
 747	a.activeRequests.Set(sessionID+"-summarize", cancel)
 748
 749	go func() {
 750		defer a.activeRequests.Del(sessionID + "-summarize")
 751		defer cancel()
 752		event := AgentEvent{
 753			Type:     AgentEventTypeSummarize,
 754			Progress: "Starting summarization...",
 755		}
 756
 757		a.Publish(pubsub.CreatedEvent, event)
 758		// Get all messages from the session
 759		msgs, err := a.messages.List(summarizeCtx, sessionID)
 760		if err != nil {
 761			event = AgentEvent{
 762				Type:  AgentEventTypeError,
 763				Error: fmt.Errorf("failed to list messages: %w", err),
 764				Done:  true,
 765			}
 766			a.Publish(pubsub.CreatedEvent, event)
 767			return
 768		}
 769		summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
 770
 771		if len(msgs) == 0 {
 772			event = AgentEvent{
 773				Type:  AgentEventTypeError,
 774				Error: fmt.Errorf("no messages to summarize"),
 775				Done:  true,
 776			}
 777			a.Publish(pubsub.CreatedEvent, event)
 778			return
 779		}
 780
 781		event = AgentEvent{
 782			Type:     AgentEventTypeSummarize,
 783			Progress: "Analyzing conversation...",
 784		}
 785		a.Publish(pubsub.CreatedEvent, event)
 786
 787		// Add a system message to guide the summarization
 788		summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
 789
 790		// Create a new message with the summarize prompt
 791		promptMsg := message.Message{
 792			Role:  message.User,
 793			Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
 794		}
 795
 796		// Append the prompt to the messages
 797		msgsWithPrompt := append(msgs, promptMsg)
 798
 799		event = AgentEvent{
 800			Type:     AgentEventTypeSummarize,
 801			Progress: "Generating summary...",
 802		}
 803
 804		a.Publish(pubsub.CreatedEvent, event)
 805
 806		// Send the messages to the summarize provider
 807		response := a.summarizeProvider.StreamResponse(
 808			summarizeCtx,
 809			msgsWithPrompt,
 810			nil,
 811		)
 812		var finalResponse *provider.ProviderResponse
 813		for r := range response {
 814			if r.Error != nil {
 815				event = AgentEvent{
 816					Type:  AgentEventTypeError,
 817					Error: fmt.Errorf("failed to summarize: %w", err),
 818					Done:  true,
 819				}
 820				a.Publish(pubsub.CreatedEvent, event)
 821				return
 822			}
 823			finalResponse = r.Response
 824		}
 825
 826		summary := strings.TrimSpace(finalResponse.Content)
 827		if summary == "" {
 828			event = AgentEvent{
 829				Type:  AgentEventTypeError,
 830				Error: fmt.Errorf("empty summary returned"),
 831				Done:  true,
 832			}
 833			a.Publish(pubsub.CreatedEvent, event)
 834			return
 835		}
 836		shell := shell.GetPersistentShell(config.Get().WorkingDir())
 837		summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
 838		event = AgentEvent{
 839			Type:     AgentEventTypeSummarize,
 840			Progress: "Creating new session...",
 841		}
 842
 843		a.Publish(pubsub.CreatedEvent, event)
 844		oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
 845		if err != nil {
 846			event = AgentEvent{
 847				Type:  AgentEventTypeError,
 848				Error: fmt.Errorf("failed to get session: %w", err),
 849				Done:  true,
 850			}
 851
 852			a.Publish(pubsub.CreatedEvent, event)
 853			return
 854		}
 855		// Create a message in the new session with the summary
 856		msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
 857			Role: message.Assistant,
 858			Parts: []message.ContentPart{
 859				message.TextContent{Text: summary},
 860				message.Finish{
 861					Reason: message.FinishReasonEndTurn,
 862					Time:   time.Now().Unix(),
 863				},
 864			},
 865			Model:    a.summarizeProvider.Model().ID,
 866			Provider: a.summarizeProviderID,
 867		})
 868		if err != nil {
 869			event = AgentEvent{
 870				Type:  AgentEventTypeError,
 871				Error: fmt.Errorf("failed to create summary message: %w", err),
 872				Done:  true,
 873			}
 874
 875			a.Publish(pubsub.CreatedEvent, event)
 876			return
 877		}
 878		oldSession.SummaryMessageID = msg.ID
 879		oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
 880		oldSession.PromptTokens = 0
 881		model := a.summarizeProvider.Model()
 882		usage := finalResponse.Usage
 883		cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 884			model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 885			model.CostPer1MIn/1e6*float64(usage.InputTokens) +
 886			model.CostPer1MOut/1e6*float64(usage.OutputTokens)
 887		oldSession.Cost += cost
 888		_, err = a.sessions.Save(summarizeCtx, oldSession)
 889		if err != nil {
 890			event = AgentEvent{
 891				Type:  AgentEventTypeError,
 892				Error: fmt.Errorf("failed to save session: %w", err),
 893				Done:  true,
 894			}
 895			a.Publish(pubsub.CreatedEvent, event)
 896		}
 897
 898		event = AgentEvent{
 899			Type:      AgentEventTypeSummarize,
 900			SessionID: oldSession.ID,
 901			Progress:  "Summary complete",
 902			Done:      true,
 903		}
 904		a.Publish(pubsub.CreatedEvent, event)
 905		// Send final success event with the new session ID
 906	}()
 907
 908	return nil
 909}
 910
 911func (a *agent) ClearQueue(sessionID string) {
 912	if a.QueuedPrompts(sessionID) > 0 {
 913		slog.Info("Clearing queued prompts", "session_id", sessionID)
 914		a.promptQueue.Del(sessionID)
 915	}
 916}
 917
 918func (a *agent) CancelAll() {
 919	if !a.IsBusy() {
 920		return
 921	}
 922	for key := range a.activeRequests.Seq2() {
 923		a.Cancel(key) // key is sessionID
 924	}
 925
 926	timeout := time.After(5 * time.Second)
 927	for a.IsBusy() {
 928		select {
 929		case <-timeout:
 930			return
 931		default:
 932			time.Sleep(200 * time.Millisecond)
 933		}
 934	}
 935}
 936
 937func (a *agent) UpdateModel() error {
 938	cfg := config.Get()
 939
 940	// Get current provider configuration
 941	currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
 942	if currentProviderCfg == nil || currentProviderCfg.ID == "" {
 943		return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
 944	}
 945
 946	// Check if provider has changed
 947	if string(currentProviderCfg.ID) != a.providerID {
 948		// Provider changed, need to recreate the main provider
 949		model := cfg.GetModelByType(a.agentCfg.Model)
 950		if model.ID == "" {
 951			return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
 952		}
 953
 954		promptID := agentPromptMap[a.agentCfg.ID]
 955		if promptID == "" {
 956			promptID = prompt.PromptDefault
 957		}
 958
 959		opts := []provider.ProviderClientOption{
 960			provider.WithModel(a.agentCfg.Model),
 961			provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
 962		}
 963
 964		newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
 965		if err != nil {
 966			return fmt.Errorf("failed to create new provider: %w", err)
 967		}
 968
 969		// Update the provider and provider ID
 970		a.provider = newProvider
 971		a.providerID = string(currentProviderCfg.ID)
 972	}
 973
 974	// Check if providers have changed for title (small) and summarize (large)
 975	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
 976	var smallModelProviderCfg config.ProviderConfig
 977	for p := range cfg.Providers.Seq() {
 978		if p.ID == smallModelCfg.Provider {
 979			smallModelProviderCfg = p
 980			break
 981		}
 982	}
 983	if smallModelProviderCfg.ID == "" {
 984		return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
 985	}
 986
 987	largeModelCfg := cfg.Models[config.SelectedModelTypeLarge]
 988	var largeModelProviderCfg config.ProviderConfig
 989	for p := range cfg.Providers.Seq() {
 990		if p.ID == largeModelCfg.Provider {
 991			largeModelProviderCfg = p
 992			break
 993		}
 994	}
 995	if largeModelProviderCfg.ID == "" {
 996		return fmt.Errorf("provider %s not found in config", largeModelCfg.Provider)
 997	}
 998
 999	var maxTitleTokens int64 = 40
1000
1001	// if the max output is too low for the gemini provider it won't return anything
1002	if smallModelCfg.Provider == "gemini" {
1003		maxTitleTokens = 1000
1004	}
1005	// Recreate title provider
1006	titleOpts := []provider.ProviderClientOption{
1007		provider.WithModel(config.SelectedModelTypeSmall),
1008		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
1009		provider.WithMaxTokens(maxTitleTokens),
1010	}
1011	newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
1012	if err != nil {
1013		return fmt.Errorf("failed to create new title provider: %w", err)
1014	}
1015	a.titleProvider = newTitleProvider
1016
1017	// Recreate summarize provider if provider changed (now large model)
1018	if string(largeModelProviderCfg.ID) != a.summarizeProviderID {
1019		largeModel := cfg.GetModelByType(config.SelectedModelTypeLarge)
1020		if largeModel == nil {
1021			return fmt.Errorf("model %s not found in provider %s", largeModelCfg.Model, largeModelProviderCfg.ID)
1022		}
1023		summarizeOpts := []provider.ProviderClientOption{
1024			provider.WithModel(config.SelectedModelTypeLarge),
1025			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, largeModelProviderCfg.ID)),
1026		}
1027		newSummarizeProvider, err := provider.NewProvider(largeModelProviderCfg, summarizeOpts...)
1028		if err != nil {
1029			return fmt.Errorf("failed to create new summarize provider: %w", err)
1030		}
1031		a.summarizeProvider = newSummarizeProvider
1032		a.summarizeProviderID = string(largeModelProviderCfg.ID)
1033	}
1034
1035	return nil
1036}