1package agent
  2
  3import (
  4	"context"
  5	"errors"
  6	"fmt"
  7	"log/slog"
  8	"slices"
  9	"strings"
 10	"time"
 11
 12	"github.com/charmbracelet/catwalk/pkg/catwalk"
 13	"github.com/charmbracelet/crush/internal/config"
 14	"github.com/charmbracelet/crush/internal/csync"
 15	"github.com/charmbracelet/crush/internal/history"
 16	"github.com/charmbracelet/crush/internal/llm/prompt"
 17	"github.com/charmbracelet/crush/internal/llm/provider"
 18	"github.com/charmbracelet/crush/internal/llm/tools"
 19	"github.com/charmbracelet/crush/internal/log"
 20	"github.com/charmbracelet/crush/internal/lsp"
 21	"github.com/charmbracelet/crush/internal/message"
 22	"github.com/charmbracelet/crush/internal/permission"
 23	"github.com/charmbracelet/crush/internal/pubsub"
 24	"github.com/charmbracelet/crush/internal/session"
 25	"github.com/charmbracelet/crush/internal/shell"
 26)
 27
// Common errors
//
// ErrRequestCancelled is the error carried by the final AgentEvent when a
// generation stops because its context was cancelled, or because the assistant
// message ended without a finish reason.
//
// ErrSessionBusy is returned by Run and Summarize when IsSessionBusy reports
// that the session already has an active request.
var (
	ErrRequestCancelled = errors.New("request canceled by user")
	ErrSessionBusy      = errors.New("session is currently processing another request")
)
 33
// AgentEventType identifies the kind of AgentEvent being delivered.
type AgentEventType string

// Event kinds emitted in this file:
//   - AgentEventTypeError: the request or summarization failed; AgentEvent.Error is set.
//   - AgentEventTypeResponse: a generation finished and AgentEvent.Message holds the assistant message.
//   - AgentEventTypeSummarize: a progress or completion notification published by Summarize.
const (
	AgentEventTypeError     AgentEventType = "error"
	AgentEventTypeResponse  AgentEventType = "response"
	AgentEventTypeSummarize AgentEventType = "summarize"
)
 41
// AgentEvent is the value sent on the channel returned by Run and published
// through the agent's broker.
//
// Type selects which of the other fields are meaningful: Message is filled for
// response events and Error for error events. SessionID and Progress are filled
// by Summarize. Done is set on the final response event and on the terminal
// events published by Summarize.
type AgentEvent struct {
	Type    AgentEventType
	Message message.Message
	Error   error

	// When summarizing
	SessionID string
	Progress  string
	Done      bool
}
 52
// Service is the public surface of an agent. The descriptions below reflect
// the agent implementation in this file.
type Service interface {
	// Subscription to the AgentEvents the agent publishes.
	pubsub.Suscriber[AgentEvent]
	// Model returns the model configured for the agent's model type.
	Model() catwalk.Model
	// Run starts a generation for the session in the background and returns a
	// channel that delivers the final event, or ErrSessionBusy if the session
	// already has an active request.
	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
	// Cancel cancels the active generation and/or summarization of the session.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits up to five seconds for
	// the agent to become idle.
	CancelAll()
	// IsSessionBusy reports whether a generation is registered for the session.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any request is currently registered.
	IsBusy() bool
	// Summarize starts a background summarization of the session; progress and
	// results are published as events.
	Summarize(ctx context.Context, sessionID string) error
	// UpdateModel rebuilds the providers whose configured provider ID changed.
	UpdateModel() error
}
 64
// agent is the Service implementation. It embeds a broker for publishing
// AgentEvents and keeps the providers used for generation, title creation and
// summarization.
type agent struct {
	*pubsub.Broker[AgentEvent]
	agentCfg config.Agent
	sessions session.Service
	messages message.Service
	// NOTE(review): this field is not assigned anywhere in the visible code;
	// the tool builder in NewAgent writes to a different identifier of the same
	// name — confirm whether the field is still needed.
	mcpTools []McpTool

	// tools is built on first use by the function passed to csync.NewLazySlice
	// in NewAgent.
	tools *csync.LazySlice[tools.BaseTool]

	// provider serves the main generation loop; providerID is recorded on the
	// messages it produces and compared in UpdateModel.
	provider   provider.Provider
	providerID string

	// titleProvider and summarizeProvider both use the small model;
	// summarizeProviderID is the provider ID they were created from.
	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests maps a session ID (or sessionID+"-summarize") to the
	// cancel function of its in-flight request.
	activeRequests *csync.Map[string, context.CancelFunc]
}
 83
// agentPromptMap maps an agent ID to the system prompt it uses. Agents whose ID
// is not listed fall back to prompt.PromptDefault in NewAgent and UpdateModel.
var agentPromptMap = map[string]prompt.PromptID{
	"coder": prompt.PromptCoder,
	"task":  prompt.PromptTask,
}
 88
 89func NewAgent(
 90	ctx context.Context,
 91	agentCfg config.Agent,
 92	// These services are needed in the tools
 93	permissions permission.Service,
 94	sessions session.Service,
 95	messages message.Service,
 96	history history.Service,
 97	lspClients map[string]*lsp.Client,
 98) (Service, error) {
 99	cfg := config.Get()
100
101	var agentTool tools.BaseTool
102	if agentCfg.ID == "coder" {
103		taskAgentCfg := config.Get().Agents["task"]
104		if taskAgentCfg.ID == "" {
105			return nil, fmt.Errorf("task agent not found in config")
106		}
107		taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients)
108		if err != nil {
109			return nil, fmt.Errorf("failed to create task agent: %w", err)
110		}
111
112		agentTool = NewAgentTool(taskAgent, sessions, messages)
113	}
114
115	providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
116	if providerCfg == nil {
117		return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
118	}
119	model := config.Get().GetModelByType(agentCfg.Model)
120
121	if model == nil {
122		return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
123	}
124
125	promptID := agentPromptMap[agentCfg.ID]
126	if promptID == "" {
127		promptID = prompt.PromptDefault
128	}
129	opts := []provider.ProviderClientOption{
130		provider.WithModel(agentCfg.Model),
131		provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
132	}
133	agentProvider, err := provider.NewProvider(*providerCfg, opts...)
134	if err != nil {
135		return nil, err
136	}
137
138	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
139	var smallModelProviderCfg *config.ProviderConfig
140	if smallModelCfg.Provider == providerCfg.ID {
141		smallModelProviderCfg = providerCfg
142	} else {
143		smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
144
145		if smallModelProviderCfg.ID == "" {
146			return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
147		}
148	}
149	smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
150	if smallModel.ID == "" {
151		return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
152	}
153
154	titleOpts := []provider.ProviderClientOption{
155		provider.WithModel(config.SelectedModelTypeSmall),
156		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
157	}
158	titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
159	if err != nil {
160		return nil, err
161	}
162	summarizeOpts := []provider.ProviderClientOption{
163		provider.WithModel(config.SelectedModelTypeSmall),
164		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
165	}
166	summarizeProvider, err := provider.NewProvider(*smallModelProviderCfg, summarizeOpts...)
167	if err != nil {
168		return nil, err
169	}
170
171	toolFn := func() []tools.BaseTool {
172		slog.Info("Initializing agent tools", "agent", agentCfg.ID)
173		defer func() {
174			slog.Info("Initialized agent tools", "agent", agentCfg.ID)
175		}()
176
177		cwd := cfg.WorkingDir()
178		allTools := []tools.BaseTool{
179			tools.NewBashTool(permissions, cwd),
180			tools.NewDownloadTool(permissions, cwd),
181			tools.NewEditTool(lspClients, permissions, history, cwd),
182			tools.NewMultiEditTool(lspClients, permissions, history, cwd),
183			tools.NewFetchTool(permissions, cwd),
184			tools.NewGlobTool(cwd),
185			tools.NewGrepTool(cwd),
186			tools.NewLsTool(permissions, cwd),
187			tools.NewSourcegraphTool(),
188			tools.NewViewTool(lspClients, permissions, cwd),
189			tools.NewWriteTool(lspClients, permissions, history, cwd),
190		}
191
192		mcpToolsOnce.Do(func() {
193			mcpTools = doGetMCPTools(ctx, permissions, cfg)
194		})
195		allTools = append(allTools, mcpTools...)
196
197		if len(lspClients) > 0 {
198			allTools = append(allTools, tools.NewDiagnosticsTool(lspClients))
199		}
200
201		if agentTool != nil {
202			allTools = append(allTools, agentTool)
203		}
204
205		if agentCfg.AllowedTools == nil {
206			return allTools
207		}
208
209		var filteredTools []tools.BaseTool
210		for _, tool := range allTools {
211			if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
212				filteredTools = append(filteredTools, tool)
213			}
214		}
215		return filteredTools
216	}
217
218	return &agent{
219		Broker:              pubsub.NewBroker[AgentEvent](),
220		agentCfg:            agentCfg,
221		provider:            agentProvider,
222		providerID:          string(providerCfg.ID),
223		messages:            messages,
224		sessions:            sessions,
225		titleProvider:       titleProvider,
226		summarizeProvider:   summarizeProvider,
227		summarizeProviderID: string(smallModelProviderCfg.ID),
228		activeRequests:      csync.NewMap[string, context.CancelFunc](),
229		tools:               csync.NewLazySlice(toolFn),
230	}, nil
231}
232
233func (a *agent) Model() catwalk.Model {
234	return *config.Get().GetModelByType(a.agentCfg.Model)
235}
236
237func (a *agent) Cancel(sessionID string) {
238	// Cancel regular requests
239	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
240		slog.Info("Request cancellation initiated", "session_id", sessionID)
241		cancel()
242	}
243
244	// Also check for summarize requests
245	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
246		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
247		cancel()
248	}
249}
250
251func (a *agent) IsBusy() bool {
252	var busy bool
253	for cancelFunc := range a.activeRequests.Seq() {
254		if cancelFunc != nil {
255			busy = true
256			break
257		}
258	}
259	return busy
260}
261
262func (a *agent) IsSessionBusy(sessionID string) bool {
263	_, busy := a.activeRequests.Get(sessionID)
264	return busy
265}
266
267func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
268	if content == "" {
269		return nil
270	}
271	if a.titleProvider == nil {
272		return nil
273	}
274	session, err := a.sessions.Get(ctx, sessionID)
275	if err != nil {
276		return err
277	}
278	parts := []message.ContentPart{message.TextContent{
279		Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
280	}}
281
282	// Use streaming approach like summarization
283	response := a.titleProvider.StreamResponse(
284		ctx,
285		[]message.Message{
286			{
287				Role:  message.User,
288				Parts: parts,
289			},
290		},
291		nil,
292	)
293
294	var finalResponse *provider.ProviderResponse
295	for r := range response {
296		if r.Error != nil {
297			return r.Error
298		}
299		finalResponse = r.Response
300	}
301
302	if finalResponse == nil {
303		return fmt.Errorf("no response received from title provider")
304	}
305
306	title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
307	if title == "" {
308		return nil
309	}
310
311	session.Title = title
312	_, err = a.sessions.Save(ctx, session)
313	return err
314}
315
316func (a *agent) err(err error) AgentEvent {
317	return AgentEvent{
318		Type:  AgentEventTypeError,
319		Error: err,
320	}
321}
322
323func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
324	if !a.Model().SupportsImages && attachments != nil {
325		attachments = nil
326	}
327	events := make(chan AgentEvent)
328	if a.IsSessionBusy(sessionID) {
329		return nil, ErrSessionBusy
330	}
331
332	genCtx, cancel := context.WithCancel(ctx)
333
334	a.activeRequests.Set(sessionID, cancel)
335	go func() {
336		slog.Debug("Request started", "sessionID", sessionID)
337		defer log.RecoverPanic("agent.Run", func() {
338			events <- a.err(fmt.Errorf("panic while running the agent"))
339		})
340		var attachmentParts []message.ContentPart
341		for _, attachment := range attachments {
342			attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
343		}
344		result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
345		if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
346			slog.Error(result.Error.Error())
347		}
348		slog.Debug("Request completed", "sessionID", sessionID)
349		a.activeRequests.Del(sessionID)
350		cancel()
351		a.Publish(pubsub.CreatedEvent, result)
352		events <- result
353		close(events)
354	}()
355	return events, nil
356}
357
358func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
359	cfg := config.Get()
360	// List existing messages; if none, start title generation asynchronously.
361	msgs, err := a.messages.List(ctx, sessionID)
362	if err != nil {
363		return a.err(fmt.Errorf("failed to list messages: %w", err))
364	}
365	
366	// Implement sliding window to limit message history
367	maxMessages := cfg.Options.MaxMessagesInContext
368	if maxMessages > 0 && len(msgs) > maxMessages {
369		// Keep the first message (usually system/context) and the last N-1 messages
370		msgs = append(msgs[:1], msgs[len(msgs)-maxMessages+1:]...)
371	}
372	if len(msgs) == 0 {
373		go func() {
374			defer log.RecoverPanic("agent.Run", func() {
375				slog.Error("panic while generating title")
376			})
377			titleErr := a.generateTitle(context.Background(), sessionID, content)
378			if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
379				slog.Error("failed to generate title", "error", titleErr)
380			}
381		}()
382	}
383	session, err := a.sessions.Get(ctx, sessionID)
384	if err != nil {
385		return a.err(fmt.Errorf("failed to get session: %w", err))
386	}
387	if session.SummaryMessageID != "" {
388		summaryMsgInex := -1
389		for i, msg := range msgs {
390			if msg.ID == session.SummaryMessageID {
391				summaryMsgInex = i
392				break
393			}
394		}
395		if summaryMsgInex != -1 {
396			msgs = msgs[summaryMsgInex:]
397			msgs[0].Role = message.User
398		}
399	}
400
401	userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
402	if err != nil {
403		return a.err(fmt.Errorf("failed to create user message: %w", err))
404	}
405	// Append the new user message to the conversation history.
406	msgHistory := append(msgs, userMsg)
407
408	for {
409		// Check for cancellation before each iteration
410		select {
411		case <-ctx.Done():
412			return a.err(ctx.Err())
413		default:
414			// Continue processing
415		}
416		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
417		if err != nil {
418			if errors.Is(err, context.Canceled) {
419				agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
420				a.messages.Update(context.Background(), agentMessage)
421				return a.err(ErrRequestCancelled)
422			}
423			return a.err(fmt.Errorf("failed to process events: %w", err))
424		}
425		if cfg.Options.Debug {
426			slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
427		}
428		if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
429			// We are not done, we need to respond with the tool response
430			msgHistory = append(msgHistory, agentMessage, *toolResults)
431			continue
432		}
433		if agentMessage.FinishReason() == "" {
434			// Kujtim: could not track down where this is happening but this means its cancelled
435			agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
436			_ = a.messages.Update(context.Background(), agentMessage)
437			return a.err(ErrRequestCancelled)
438		}
439		return AgentEvent{
440			Type:    AgentEventTypeResponse,
441			Message: agentMessage,
442			Done:    true,
443		}
444	}
445}
446
447func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
448	parts := []message.ContentPart{message.TextContent{Text: content}}
449	parts = append(parts, attachmentParts...)
450	return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
451		Role:  message.User,
452		Parts: parts,
453	})
454}
455
456func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
457	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
458	eventChan := a.provider.StreamResponse(ctx, msgHistory, slices.Collect(a.tools.Seq()))
459
460	assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
461		Role:     message.Assistant,
462		Parts:    []message.ContentPart{},
463		Model:    a.Model().ID,
464		Provider: a.providerID,
465	})
466	if err != nil {
467		return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
468	}
469
470	// Add the session and message ID into the context if needed by tools.
471	ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
472
473	// Process each event in the stream.
474	for event := range eventChan {
475		if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
476			if errors.Is(processErr, context.Canceled) {
477				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
478			} else {
479				a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
480			}
481			return assistantMsg, nil, processErr
482		}
483		if ctx.Err() != nil {
484			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
485			return assistantMsg, nil, ctx.Err()
486		}
487	}
488
489	toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
490	toolCalls := assistantMsg.ToolCalls()
491	for i, toolCall := range toolCalls {
492		select {
493		case <-ctx.Done():
494			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
495			// Make all future tool calls cancelled
496			for j := i; j < len(toolCalls); j++ {
497				toolResults[j] = message.ToolResult{
498					ToolCallID: toolCalls[j].ID,
499					Content:    "Tool execution canceled by user",
500					IsError:    true,
501				}
502			}
503			goto out
504		default:
505			// Continue processing
506			var tool tools.BaseTool
507			for availableTool := range a.tools.Seq() {
508				if availableTool.Info().Name == toolCall.Name {
509					tool = availableTool
510					break
511				}
512			}
513
514			// Tool not found
515			if tool == nil {
516				toolResults[i] = message.ToolResult{
517					ToolCallID: toolCall.ID,
518					Content:    fmt.Sprintf("Tool not found: %s", toolCall.Name),
519					IsError:    true,
520				}
521				continue
522			}
523
524			// Run tool in goroutine to allow cancellation
525			type toolExecResult struct {
526				response tools.ToolResponse
527				err      error
528			}
529			resultChan := make(chan toolExecResult, 1)
530
531			go func() {
532				response, err := tool.Run(ctx, tools.ToolCall{
533					ID:    toolCall.ID,
534					Name:  toolCall.Name,
535					Input: toolCall.Input,
536				})
537				resultChan <- toolExecResult{response: response, err: err}
538			}()
539
540			var toolResponse tools.ToolResponse
541			var toolErr error
542
543			select {
544			case <-ctx.Done():
545				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
546				// Mark remaining tool calls as cancelled
547				for j := i; j < len(toolCalls); j++ {
548					toolResults[j] = message.ToolResult{
549						ToolCallID: toolCalls[j].ID,
550						Content:    "Tool execution canceled by user",
551						IsError:    true,
552					}
553				}
554				goto out
555			case result := <-resultChan:
556				toolResponse = result.response
557				toolErr = result.err
558			}
559
560			if toolErr != nil {
561				slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
562				if errors.Is(toolErr, permission.ErrorPermissionDenied) {
563					toolResults[i] = message.ToolResult{
564						ToolCallID: toolCall.ID,
565						Content:    "Permission denied",
566						IsError:    true,
567					}
568					for j := i + 1; j < len(toolCalls); j++ {
569						toolResults[j] = message.ToolResult{
570							ToolCallID: toolCalls[j].ID,
571							Content:    "Tool execution canceled by user",
572							IsError:    true,
573						}
574					}
575					a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
576					break
577				}
578			}
579			toolResults[i] = message.ToolResult{
580				ToolCallID: toolCall.ID,
581				Content:    toolResponse.Content,
582				Metadata:   toolResponse.Metadata,
583				IsError:    toolResponse.IsError,
584			}
585		}
586	}
587out:
588	if len(toolResults) == 0 {
589		return assistantMsg, nil, nil
590	}
591	parts := make([]message.ContentPart, 0)
592	for _, tr := range toolResults {
593		parts = append(parts, tr)
594	}
595	msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
596		Role:     message.Tool,
597		Parts:    parts,
598		Provider: a.providerID,
599	})
600	if err != nil {
601		return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
602	}
603
604	return assistantMsg, &msg, err
605}
606
607func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
608	msg.AddFinish(finishReason, message, details)
609	_ = a.messages.Update(ctx, *msg)
610}
611
612func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
613	select {
614	case <-ctx.Done():
615		return ctx.Err()
616	default:
617		// Continue processing.
618	}
619
620	switch event.Type {
621	case provider.EventThinkingDelta:
622		assistantMsg.AppendReasoningContent(event.Thinking)
623		return a.messages.Update(ctx, *assistantMsg)
624	case provider.EventSignatureDelta:
625		assistantMsg.AppendReasoningSignature(event.Signature)
626		return a.messages.Update(ctx, *assistantMsg)
627	case provider.EventContentDelta:
628		assistantMsg.FinishThinking()
629		assistantMsg.AppendContent(event.Content)
630		return a.messages.Update(ctx, *assistantMsg)
631	case provider.EventToolUseStart:
632		assistantMsg.FinishThinking()
633		slog.Info("Tool call started", "toolCall", event.ToolCall)
634		assistantMsg.AddToolCall(*event.ToolCall)
635		return a.messages.Update(ctx, *assistantMsg)
636	case provider.EventToolUseDelta:
637		assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
638		return a.messages.Update(ctx, *assistantMsg)
639	case provider.EventToolUseStop:
640		slog.Info("Finished tool call", "toolCall", event.ToolCall)
641		assistantMsg.FinishToolCall(event.ToolCall.ID)
642		return a.messages.Update(ctx, *assistantMsg)
643	case provider.EventError:
644		return event.Error
645	case provider.EventComplete:
646		assistantMsg.FinishThinking()
647		assistantMsg.SetToolCalls(event.Response.ToolCalls)
648		assistantMsg.AddFinish(event.Response.FinishReason, "", "")
649		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
650			return fmt.Errorf("failed to update message: %w", err)
651		}
652		return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
653	}
654
655	return nil
656}
657
658func (a *agent) TrackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
659	sess, err := a.sessions.Get(ctx, sessionID)
660	if err != nil {
661		return fmt.Errorf("failed to get session: %w", err)
662	}
663
664	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
665		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
666		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
667		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
668
669	sess.Cost += cost
670	sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
671	sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
672
673	_, err = a.sessions.Save(ctx, sess)
674	if err != nil {
675		return fmt.Errorf("failed to save session: %w", err)
676	}
677	return nil
678}
679
680func (a *agent) Summarize(ctx context.Context, sessionID string) error {
681	if a.summarizeProvider == nil {
682		return fmt.Errorf("summarize provider not available")
683	}
684
685	// Check if session is busy
686	if a.IsSessionBusy(sessionID) {
687		return ErrSessionBusy
688	}
689
690	// Create a new context with cancellation
691	summarizeCtx, cancel := context.WithCancel(ctx)
692
693	// Store the cancel function in activeRequests to allow cancellation
694	a.activeRequests.Set(sessionID+"-summarize", cancel)
695
696	go func() {
697		defer a.activeRequests.Del(sessionID + "-summarize")
698		defer cancel()
699		event := AgentEvent{
700			Type:     AgentEventTypeSummarize,
701			Progress: "Starting summarization...",
702		}
703
704		a.Publish(pubsub.CreatedEvent, event)
705		// Get all messages from the session
706		msgs, err := a.messages.List(summarizeCtx, sessionID)
707		if err != nil {
708			event = AgentEvent{
709				Type:  AgentEventTypeError,
710				Error: fmt.Errorf("failed to list messages: %w", err),
711				Done:  true,
712			}
713			a.Publish(pubsub.CreatedEvent, event)
714			return
715		}
716		summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
717
718		if len(msgs) == 0 {
719			event = AgentEvent{
720				Type:  AgentEventTypeError,
721				Error: fmt.Errorf("no messages to summarize"),
722				Done:  true,
723			}
724			a.Publish(pubsub.CreatedEvent, event)
725			return
726		}
727
728		event = AgentEvent{
729			Type:     AgentEventTypeSummarize,
730			Progress: "Analyzing conversation...",
731		}
732		a.Publish(pubsub.CreatedEvent, event)
733
734		// Add a system message to guide the summarization
735		summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
736
737		// Create a new message with the summarize prompt
738		promptMsg := message.Message{
739			Role:  message.User,
740			Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
741		}
742
743		// Append the prompt to the messages
744		msgsWithPrompt := append(msgs, promptMsg)
745
746		event = AgentEvent{
747			Type:     AgentEventTypeSummarize,
748			Progress: "Generating summary...",
749		}
750
751		a.Publish(pubsub.CreatedEvent, event)
752
753		// Send the messages to the summarize provider
754		response := a.summarizeProvider.StreamResponse(
755			summarizeCtx,
756			msgsWithPrompt,
757			nil,
758		)
759		var finalResponse *provider.ProviderResponse
760		for r := range response {
761			if r.Error != nil {
762				event = AgentEvent{
763					Type:  AgentEventTypeError,
764					Error: fmt.Errorf("failed to summarize: %w", err),
765					Done:  true,
766				}
767				a.Publish(pubsub.CreatedEvent, event)
768				return
769			}
770			finalResponse = r.Response
771		}
772
773		summary := strings.TrimSpace(finalResponse.Content)
774		if summary == "" {
775			event = AgentEvent{
776				Type:  AgentEventTypeError,
777				Error: fmt.Errorf("empty summary returned"),
778				Done:  true,
779			}
780			a.Publish(pubsub.CreatedEvent, event)
781			return
782		}
783		shell := shell.GetPersistentShell(config.Get().WorkingDir())
784		summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
785		event = AgentEvent{
786			Type:     AgentEventTypeSummarize,
787			Progress: "Creating new session...",
788		}
789
790		a.Publish(pubsub.CreatedEvent, event)
791		oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
792		if err != nil {
793			event = AgentEvent{
794				Type:  AgentEventTypeError,
795				Error: fmt.Errorf("failed to get session: %w", err),
796				Done:  true,
797			}
798
799			a.Publish(pubsub.CreatedEvent, event)
800			return
801		}
802		// Create a message in the new session with the summary
803		msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
804			Role: message.Assistant,
805			Parts: []message.ContentPart{
806				message.TextContent{Text: summary},
807				message.Finish{
808					Reason: message.FinishReasonEndTurn,
809					Time:   time.Now().Unix(),
810				},
811			},
812			Model:    a.summarizeProvider.Model().ID,
813			Provider: a.summarizeProviderID,
814		})
815		if err != nil {
816			event = AgentEvent{
817				Type:  AgentEventTypeError,
818				Error: fmt.Errorf("failed to create summary message: %w", err),
819				Done:  true,
820			}
821
822			a.Publish(pubsub.CreatedEvent, event)
823			return
824		}
825		oldSession.SummaryMessageID = msg.ID
826		oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
827		oldSession.PromptTokens = 0
828		model := a.summarizeProvider.Model()
829		usage := finalResponse.Usage
830		cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
831			model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
832			model.CostPer1MIn/1e6*float64(usage.InputTokens) +
833			model.CostPer1MOut/1e6*float64(usage.OutputTokens)
834		oldSession.Cost += cost
835		_, err = a.sessions.Save(summarizeCtx, oldSession)
836		if err != nil {
837			event = AgentEvent{
838				Type:  AgentEventTypeError,
839				Error: fmt.Errorf("failed to save session: %w", err),
840				Done:  true,
841			}
842			a.Publish(pubsub.CreatedEvent, event)
843		}
844
845		event = AgentEvent{
846			Type:      AgentEventTypeSummarize,
847			SessionID: oldSession.ID,
848			Progress:  "Summary complete",
849			Done:      true,
850		}
851		a.Publish(pubsub.CreatedEvent, event)
852		// Send final success event with the new session ID
853	}()
854
855	return nil
856}
857
858func (a *agent) CancelAll() {
859	if !a.IsBusy() {
860		return
861	}
862	for key := range a.activeRequests.Seq2() {
863		a.Cancel(key) // key is sessionID
864	}
865
866	timeout := time.After(5 * time.Second)
867	for a.IsBusy() {
868		select {
869		case <-timeout:
870			return
871		default:
872			time.Sleep(200 * time.Millisecond)
873		}
874	}
875}
876
877func (a *agent) UpdateModel() error {
878	cfg := config.Get()
879
880	// Get current provider configuration
881	currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
882	if currentProviderCfg == nil || currentProviderCfg.ID == "" {
883		return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
884	}
885
886	// Check if provider has changed
887	if string(currentProviderCfg.ID) != a.providerID {
888		// Provider changed, need to recreate the main provider
889		model := cfg.GetModelByType(a.agentCfg.Model)
890		if model.ID == "" {
891			return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
892		}
893
894		promptID := agentPromptMap[a.agentCfg.ID]
895		if promptID == "" {
896			promptID = prompt.PromptDefault
897		}
898
899		opts := []provider.ProviderClientOption{
900			provider.WithModel(a.agentCfg.Model),
901			provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
902		}
903
904		newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
905		if err != nil {
906			return fmt.Errorf("failed to create new provider: %w", err)
907		}
908
909		// Update the provider and provider ID
910		a.provider = newProvider
911		a.providerID = string(currentProviderCfg.ID)
912	}
913
914	// Check if small model provider has changed (affects title and summarize providers)
915	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
916	var smallModelProviderCfg config.ProviderConfig
917
918	for p := range cfg.Providers.Seq() {
919		if p.ID == smallModelCfg.Provider {
920			smallModelProviderCfg = p
921			break
922		}
923	}
924
925	if smallModelProviderCfg.ID == "" {
926		return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
927	}
928
929	// Check if summarize provider has changed
930	if string(smallModelProviderCfg.ID) != a.summarizeProviderID {
931		smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
932		if smallModel == nil {
933			return fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
934		}
935
936		// Recreate title provider
937		titleOpts := []provider.ProviderClientOption{
938			provider.WithModel(config.SelectedModelTypeSmall),
939			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
940			// We want the title to be short, so we limit the max tokens
941			provider.WithMaxTokens(40),
942		}
943		newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
944		if err != nil {
945			return fmt.Errorf("failed to create new title provider: %w", err)
946		}
947
948		// Recreate summarize provider
949		summarizeOpts := []provider.ProviderClientOption{
950			provider.WithModel(config.SelectedModelTypeSmall),
951			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
952		}
953		newSummarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
954		if err != nil {
955			return fmt.Errorf("failed to create new summarize provider: %w", err)
956		}
957
958		// Update the providers and provider ID
959		a.titleProvider = newTitleProvider
960		a.summarizeProvider = newSummarizeProvider
961		a.summarizeProviderID = string(smallModelProviderCfg.ID)
962	}
963
964	return nil
965}