agent.go

  1package agent
  2
  3import (
  4	"context"
  5	"errors"
  6	"fmt"
  7	"log/slog"
  8	"slices"
  9	"strings"
 10	"time"
 11
 12	"github.com/charmbracelet/catwalk/pkg/catwalk"
 13	"github.com/charmbracelet/crush/internal/config"
 14	"github.com/charmbracelet/crush/internal/csync"
 15	"github.com/charmbracelet/crush/internal/history"
 16	"github.com/charmbracelet/crush/internal/llm/prompt"
 17	"github.com/charmbracelet/crush/internal/llm/provider"
 18	"github.com/charmbracelet/crush/internal/llm/tools"
 19	"github.com/charmbracelet/crush/internal/log"
 20	"github.com/charmbracelet/crush/internal/lsp"
 21	"github.com/charmbracelet/crush/internal/message"
 22	"github.com/charmbracelet/crush/internal/permission"
 23	"github.com/charmbracelet/crush/internal/pubsub"
 24	"github.com/charmbracelet/crush/internal/session"
 25	"github.com/charmbracelet/crush/internal/shell"
 26)
 27
 // ErrRequestCancelled is returned when an in-flight request is cancelled.
var ErrRequestCancelled = errors.New("request canceled by user")

// ErrSessionBusy is returned when a new request (or summarization) is started
// for a session that already has a request in flight.
var ErrSessionBusy = errors.New("session is currently processing another request")
 33
 // AgentEventType tells subscribers which kind of AgentEvent they received.
type AgentEventType string

// AgentEventTypeError marks an event that carries a failure in its Error field.
const AgentEventTypeError AgentEventType = "error"

// AgentEventTypeResponse marks the final event of a Run, carrying the
// assistant message.
const AgentEventTypeResponse AgentEventType = "response"

// AgentEventTypeSummarize marks progress and completion events emitted while
// a session is being summarized.
const AgentEventTypeSummarize AgentEventType = "summarize"
 41
 42type AgentEvent struct {
 43	Type    AgentEventType
 44	Message message.Message
 45	Error   error
 46
 47	// When summarizing
 48	SessionID string
 49	Progress  string
 50	Done      bool
 51}
 52
 53type Service interface {
 54	pubsub.Suscriber[AgentEvent]
 55	Model() catwalk.Model
 56	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
 57	Cancel(sessionID string)
 58	CancelAll()
 59	IsSessionBusy(sessionID string) bool
 60	IsBusy() bool
 61	Summarize(ctx context.Context, sessionID string) error
 62	UpdateModel() error
 63}
 64
 65type agent struct {
 66	*pubsub.Broker[AgentEvent]
 67	agentCfg config.Agent
 68	sessions session.Service
 69	messages message.Service
 70	mcpTools []McpTool
 71
 72	tools *csync.LazySlice[tools.BaseTool]
 73
 74	provider   provider.Provider
 75	providerID string
 76
 77	titleProvider       provider.Provider
 78	summarizeProvider   provider.Provider
 79	summarizeProviderID string
 80
 81	activeRequests *csync.Map[string, context.CancelFunc]
 82}
 83
 84var agentPromptMap = map[string]prompt.PromptID{
 85	"coder": prompt.PromptCoder,
 86	"task":  prompt.PromptTask,
 87}
 88
 89func NewAgent(
 90	ctx context.Context,
 91	agentCfg config.Agent,
 92	// These services are needed in the tools
 93	permissions permission.Service,
 94	sessions session.Service,
 95	messages message.Service,
 96	history history.Service,
 97	lspClients map[string]*lsp.Client,
 98) (Service, error) {
 99	cfg := config.Get()
100
101	var agentTool tools.BaseTool
102	if agentCfg.ID == "coder" {
103		taskAgentCfg := config.Get().Agents["task"]
104		if taskAgentCfg.ID == "" {
105			return nil, fmt.Errorf("task agent not found in config")
106		}
107		taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients)
108		if err != nil {
109			return nil, fmt.Errorf("failed to create task agent: %w", err)
110		}
111
112		agentTool = NewAgentTool(taskAgent, sessions, messages)
113	}
114
115	providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
116	if providerCfg == nil {
117		return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
118	}
119	model := config.Get().GetModelByType(agentCfg.Model)
120
121	if model == nil {
122		return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
123	}
124
125	promptID := agentPromptMap[agentCfg.ID]
126	if promptID == "" {
127		promptID = prompt.PromptDefault
128	}
129	opts := []provider.ProviderClientOption{
130		provider.WithModel(agentCfg.Model),
131		provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
132	}
133	agentProvider, err := provider.NewProvider(*providerCfg, opts...)
134	if err != nil {
135		return nil, err
136	}
137
138	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
139	var smallModelProviderCfg *config.ProviderConfig
140	if smallModelCfg.Provider == providerCfg.ID {
141		smallModelProviderCfg = providerCfg
142	} else {
143		smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
144
145		if smallModelProviderCfg.ID == "" {
146			return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
147		}
148	}
149	smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
150	if smallModel.ID == "" {
151		return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
152	}
153
154	titleOpts := []provider.ProviderClientOption{
155		provider.WithModel(config.SelectedModelTypeSmall),
156		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
157	}
158	titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
159	if err != nil {
160		return nil, err
161	}
162	summarizeOpts := []provider.ProviderClientOption{
163		provider.WithModel(config.SelectedModelTypeSmall),
164		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
165	}
166	summarizeProvider, err := provider.NewProvider(*smallModelProviderCfg, summarizeOpts...)
167	if err != nil {
168		return nil, err
169	}
170
171	toolFn := func() []tools.BaseTool {
172		slog.Info("Initializing agent tools", "agent", agentCfg.ID)
173		defer func() {
174			slog.Info("Initialized agent tools", "agent", agentCfg.ID)
175		}()
176
177		cwd := cfg.WorkingDir()
178		allTools := []tools.BaseTool{
179			tools.NewBashTool(permissions, cwd),
180			tools.NewDownloadTool(permissions, cwd),
181			tools.NewEditTool(lspClients, permissions, history, cwd),
182			tools.NewMultiEditTool(lspClients, permissions, history, cwd),
183			tools.NewFetchTool(permissions, cwd),
184			tools.NewGlobTool(cwd),
185			tools.NewGrepTool(cwd),
186			tools.NewLsTool(permissions, cwd),
187			tools.NewSourcegraphTool(),
188			tools.NewViewTool(lspClients, permissions, cwd),
189			tools.NewWriteTool(lspClients, permissions, history, cwd),
190		}
191
192		mcpToolsOnce.Do(func() {
193			mcpTools = doGetMCPTools(ctx, permissions, cfg)
194		})
195		allTools = append(allTools, mcpTools...)
196
197		if len(lspClients) > 0 {
198			allTools = append(allTools, tools.NewDiagnosticsTool(lspClients))
199		}
200
201		if agentTool != nil {
202			allTools = append(allTools, agentTool)
203		}
204
205		if agentCfg.AllowedTools == nil {
206			return allTools
207		}
208
209		var filteredTools []tools.BaseTool
210		for _, tool := range allTools {
211			if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
212				filteredTools = append(filteredTools, tool)
213			}
214		}
215		return filteredTools
216	}
217
218	return &agent{
219		Broker:              pubsub.NewBroker[AgentEvent](),
220		agentCfg:            agentCfg,
221		provider:            agentProvider,
222		providerID:          string(providerCfg.ID),
223		messages:            messages,
224		sessions:            sessions,
225		titleProvider:       titleProvider,
226		summarizeProvider:   summarizeProvider,
227		summarizeProviderID: string(smallModelProviderCfg.ID),
228		activeRequests:      csync.NewMap[string, context.CancelFunc](),
229		tools:               csync.NewLazySlice(toolFn),
230	}, nil
231}
232
233func (a *agent) Model() catwalk.Model {
234	return *config.Get().GetModelByType(a.agentCfg.Model)
235}
236
237func (a *agent) Cancel(sessionID string) {
238	// Cancel regular requests
239	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
240		slog.Info("Request cancellation initiated", "session_id", sessionID)
241		cancel()
242	}
243
244	// Also check for summarize requests
245	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
246		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
247		cancel()
248	}
249}
250
251func (a *agent) IsBusy() bool {
252	var busy bool
253	for cancelFunc := range a.activeRequests.Seq() {
254		if cancelFunc != nil {
255			busy = true
256			break
257		}
258	}
259	return busy
260}
261
262func (a *agent) IsSessionBusy(sessionID string) bool {
263	_, busy := a.activeRequests.Get(sessionID)
264	return busy
265}
266
267func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
268	if content == "" {
269		return nil
270	}
271	if a.titleProvider == nil {
272		return nil
273	}
274	session, err := a.sessions.Get(ctx, sessionID)
275	if err != nil {
276		return err
277	}
278	parts := []message.ContentPart{message.TextContent{
279		Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
280	}}
281
282	// Use streaming approach like summarization
283	response := a.titleProvider.StreamResponse(
284		ctx,
285		[]message.Message{
286			{
287				Role:  message.User,
288				Parts: parts,
289			},
290		},
291		nil,
292	)
293
294	var finalResponse *provider.ProviderResponse
295	for r := range response {
296		if r.Error != nil {
297			return r.Error
298		}
299		finalResponse = r.Response
300	}
301
302	if finalResponse == nil {
303		return fmt.Errorf("no response received from title provider")
304	}
305
306	title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
307	if title == "" {
308		return nil
309	}
310
311	session.Title = title
312	_, err = a.sessions.Save(ctx, session)
313	return err
314}
315
316func (a *agent) err(err error) AgentEvent {
317	return AgentEvent{
318		Type:  AgentEventTypeError,
319		Error: err,
320	}
321}
322
323func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
324	if !a.Model().SupportsImages && attachments != nil {
325		attachments = nil
326	}
327	events := make(chan AgentEvent)
328	if a.IsSessionBusy(sessionID) {
329		return nil, ErrSessionBusy
330	}
331
332	genCtx, cancel := context.WithCancel(ctx)
333
334	a.activeRequests.Set(sessionID, cancel)
335	go func() {
336		slog.Debug("Request started", "sessionID", sessionID)
337		defer log.RecoverPanic("agent.Run", func() {
338			events <- a.err(fmt.Errorf("panic while running the agent"))
339		})
340		var attachmentParts []message.ContentPart
341		for _, attachment := range attachments {
342			attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
343		}
344		result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
345		if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
346			slog.Error(result.Error.Error())
347		}
348		slog.Debug("Request completed", "sessionID", sessionID)
349		a.activeRequests.Del(sessionID)
350		cancel()
351		a.Publish(pubsub.CreatedEvent, result)
352		events <- result
353		close(events)
354	}()
355	return events, nil
356}
357
358func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
359	cfg := config.Get()
360	// List existing messages; if none, start title generation asynchronously.
361	msgs, err := a.messages.List(ctx, sessionID)
362	if err != nil {
363		return a.err(fmt.Errorf("failed to list messages: %w", err))
364	}
365	if len(msgs) == 0 {
366		go func() {
367			defer log.RecoverPanic("agent.Run", func() {
368				slog.Error("panic while generating title")
369			})
370			titleErr := a.generateTitle(context.Background(), sessionID, content)
371			if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
372				slog.Error("failed to generate title", "error", titleErr)
373			}
374		}()
375	}
376	session, err := a.sessions.Get(ctx, sessionID)
377	if err != nil {
378		return a.err(fmt.Errorf("failed to get session: %w", err))
379	}
380	if session.SummaryMessageID != "" {
381		summaryMsgInex := -1
382		for i, msg := range msgs {
383			if msg.ID == session.SummaryMessageID {
384				summaryMsgInex = i
385				break
386			}
387		}
388		if summaryMsgInex != -1 {
389			msgs = msgs[summaryMsgInex:]
390			msgs[0].Role = message.User
391		}
392	}
393
394	userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
395	if err != nil {
396		return a.err(fmt.Errorf("failed to create user message: %w", err))
397	}
398	// Append the new user message to the conversation history.
399	msgHistory := append(msgs, userMsg)
400
401	for {
402		// Check for cancellation before each iteration
403		select {
404		case <-ctx.Done():
405			return a.err(ctx.Err())
406		default:
407			// Continue processing
408		}
409		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
410		if err != nil {
411			if errors.Is(err, context.Canceled) {
412				agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
413				a.messages.Update(context.Background(), agentMessage)
414				return a.err(ErrRequestCancelled)
415			}
416			return a.err(fmt.Errorf("failed to process events: %w", err))
417		}
418		if cfg.Options.Debug {
419			slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
420		}
421		if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
422			// We are not done, we need to respond with the tool response
423			msgHistory = append(msgHistory, agentMessage, *toolResults)
424			continue
425		}
426		if agentMessage.FinishReason() == "" {
427			// Kujtim: could not track down where this is happening but this means its cancelled
428			agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
429			_ = a.messages.Update(context.Background(), agentMessage)
430			return a.err(ErrRequestCancelled)
431		}
432		return AgentEvent{
433			Type:    AgentEventTypeResponse,
434			Message: agentMessage,
435			Done:    true,
436		}
437	}
438}
439
440func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
441	parts := []message.ContentPart{message.TextContent{Text: content}}
442	parts = append(parts, attachmentParts...)
443	return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
444		Role:  message.User,
445		Parts: parts,
446	})
447}
448
449func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
450	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
451	eventChan := a.provider.StreamResponse(ctx, msgHistory, slices.Collect(a.tools.Seq()))
452
453	assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
454		Role:     message.Assistant,
455		Parts:    []message.ContentPart{},
456		Model:    a.Model().ID,
457		Provider: a.providerID,
458	})
459	if err != nil {
460		return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
461	}
462
463	// Add the session and message ID into the context if needed by tools.
464	ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
465
466	// Process each event in the stream.
467	for event := range eventChan {
468		if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
469			if errors.Is(processErr, context.Canceled) {
470				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
471			} else {
472				a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
473			}
474			return assistantMsg, nil, processErr
475		}
476		if ctx.Err() != nil {
477			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
478			return assistantMsg, nil, ctx.Err()
479		}
480	}
481
482	toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
483	toolCalls := assistantMsg.ToolCalls()
484	for i, toolCall := range toolCalls {
485		select {
486		case <-ctx.Done():
487			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
488			// Make all future tool calls cancelled
489			for j := i; j < len(toolCalls); j++ {
490				toolResults[j] = message.ToolResult{
491					ToolCallID: toolCalls[j].ID,
492					Content:    "Tool execution canceled by user",
493					IsError:    true,
494				}
495			}
496			goto out
497		default:
498			// Continue processing
499			var tool tools.BaseTool
500			for availableTool := range a.tools.Seq() {
501				if availableTool.Info().Name == toolCall.Name {
502					tool = availableTool
503					break
504				}
505			}
506
507			// Tool not found
508			if tool == nil {
509				toolResults[i] = message.ToolResult{
510					ToolCallID: toolCall.ID,
511					Content:    fmt.Sprintf("Tool not found: %s", toolCall.Name),
512					IsError:    true,
513				}
514				continue
515			}
516
517			// Run tool in goroutine to allow cancellation
518			type toolExecResult struct {
519				response tools.ToolResponse
520				err      error
521			}
522			resultChan := make(chan toolExecResult, 1)
523
524			go func() {
525				response, err := tool.Run(ctx, tools.ToolCall{
526					ID:    toolCall.ID,
527					Name:  toolCall.Name,
528					Input: toolCall.Input,
529				})
530				resultChan <- toolExecResult{response: response, err: err}
531			}()
532
533			var toolResponse tools.ToolResponse
534			var toolErr error
535
536			select {
537			case <-ctx.Done():
538				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
539				// Mark remaining tool calls as cancelled
540				for j := i; j < len(toolCalls); j++ {
541					toolResults[j] = message.ToolResult{
542						ToolCallID: toolCalls[j].ID,
543						Content:    "Tool execution canceled by user",
544						IsError:    true,
545					}
546				}
547				goto out
548			case result := <-resultChan:
549				toolResponse = result.response
550				toolErr = result.err
551			}
552
553			if toolErr != nil {
554				slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
555				if errors.Is(toolErr, permission.ErrorPermissionDenied) {
556					toolResults[i] = message.ToolResult{
557						ToolCallID: toolCall.ID,
558						Content:    "Permission denied",
559						IsError:    true,
560					}
561					for j := i + 1; j < len(toolCalls); j++ {
562						toolResults[j] = message.ToolResult{
563							ToolCallID: toolCalls[j].ID,
564							Content:    "Tool execution canceled by user",
565							IsError:    true,
566						}
567					}
568					a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
569					break
570				}
571			}
572			toolResults[i] = message.ToolResult{
573				ToolCallID: toolCall.ID,
574				Content:    toolResponse.Content,
575				Metadata:   toolResponse.Metadata,
576				IsError:    toolResponse.IsError,
577			}
578		}
579	}
580out:
581	if len(toolResults) == 0 {
582		return assistantMsg, nil, nil
583	}
584	parts := make([]message.ContentPart, 0)
585	for _, tr := range toolResults {
586		parts = append(parts, tr)
587	}
588	msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
589		Role:     message.Tool,
590		Parts:    parts,
591		Provider: a.providerID,
592	})
593	if err != nil {
594		return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
595	}
596
597	return assistantMsg, &msg, err
598}
599
600func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
601	msg.AddFinish(finishReason, message, details)
602	_ = a.messages.Update(ctx, *msg)
603}
604
605func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
606	select {
607	case <-ctx.Done():
608		return ctx.Err()
609	default:
610		// Continue processing.
611	}
612
613	switch event.Type {
614	case provider.EventThinkingDelta:
615		assistantMsg.AppendReasoningContent(event.Thinking)
616		return a.messages.Update(ctx, *assistantMsg)
617	case provider.EventSignatureDelta:
618		assistantMsg.AppendReasoningSignature(event.Signature)
619		return a.messages.Update(ctx, *assistantMsg)
620	case provider.EventContentDelta:
621		assistantMsg.FinishThinking()
622		assistantMsg.AppendContent(event.Content)
623		return a.messages.Update(ctx, *assistantMsg)
624	case provider.EventToolUseStart:
625		assistantMsg.FinishThinking()
626		slog.Info("Tool call started", "toolCall", event.ToolCall)
627		assistantMsg.AddToolCall(*event.ToolCall)
628		return a.messages.Update(ctx, *assistantMsg)
629	case provider.EventToolUseDelta:
630		assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
631		return a.messages.Update(ctx, *assistantMsg)
632	case provider.EventToolUseStop:
633		slog.Info("Finished tool call", "toolCall", event.ToolCall)
634		assistantMsg.FinishToolCall(event.ToolCall.ID)
635		return a.messages.Update(ctx, *assistantMsg)
636	case provider.EventError:
637		assistantMsg.SetRetrying(false)
638		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
639			return fmt.Errorf("failed to update message: %w", err)
640		}
641		return event.Error
642	case provider.EventRetry:
643		errMsg := ""
644		if event.Error != nil {
645			errMsg = event.Error.Error()
646		}
647		assistantMsg.SetRetrying(false)
648		assistantMsg.AddRetry(errMsg, event.Retry)
649		return a.messages.Update(ctx, *assistantMsg)
650	case provider.EventRetrying:
651		assistantMsg.SetRetrying(true)
652		return a.messages.Update(ctx, *assistantMsg)
653
654	case provider.EventComplete:
655		assistantMsg.FinishThinking()
656		assistantMsg.SetToolCalls(event.Response.ToolCalls)
657		assistantMsg.AddFinish(event.Response.FinishReason, "", "")
658		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
659			return fmt.Errorf("failed to update message: %w", err)
660		}
661		return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
662	}
663
664	return nil
665}
666
667func (a *agent) TrackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
668	sess, err := a.sessions.Get(ctx, sessionID)
669	if err != nil {
670		return fmt.Errorf("failed to get session: %w", err)
671	}
672
673	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
674		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
675		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
676		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
677
678	sess.Cost += cost
679	sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
680	sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
681
682	_, err = a.sessions.Save(ctx, sess)
683	if err != nil {
684		return fmt.Errorf("failed to save session: %w", err)
685	}
686	return nil
687}
688
689func (a *agent) Summarize(ctx context.Context, sessionID string) error {
690	if a.summarizeProvider == nil {
691		return fmt.Errorf("summarize provider not available")
692	}
693
694	// Check if session is busy
695	if a.IsSessionBusy(sessionID) {
696		return ErrSessionBusy
697	}
698
699	// Create a new context with cancellation
700	summarizeCtx, cancel := context.WithCancel(ctx)
701
702	// Store the cancel function in activeRequests to allow cancellation
703	a.activeRequests.Set(sessionID+"-summarize", cancel)
704
705	go func() {
706		defer a.activeRequests.Del(sessionID + "-summarize")
707		defer cancel()
708		event := AgentEvent{
709			Type:     AgentEventTypeSummarize,
710			Progress: "Starting summarization...",
711		}
712
713		a.Publish(pubsub.CreatedEvent, event)
714		// Get all messages from the session
715		msgs, err := a.messages.List(summarizeCtx, sessionID)
716		if err != nil {
717			event = AgentEvent{
718				Type:  AgentEventTypeError,
719				Error: fmt.Errorf("failed to list messages: %w", err),
720				Done:  true,
721			}
722			a.Publish(pubsub.CreatedEvent, event)
723			return
724		}
725		summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
726
727		if len(msgs) == 0 {
728			event = AgentEvent{
729				Type:  AgentEventTypeError,
730				Error: fmt.Errorf("no messages to summarize"),
731				Done:  true,
732			}
733			a.Publish(pubsub.CreatedEvent, event)
734			return
735		}
736
737		event = AgentEvent{
738			Type:     AgentEventTypeSummarize,
739			Progress: "Analyzing conversation...",
740		}
741		a.Publish(pubsub.CreatedEvent, event)
742
743		// Add a system message to guide the summarization
744		summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
745
746		// Create a new message with the summarize prompt
747		promptMsg := message.Message{
748			Role:  message.User,
749			Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
750		}
751
752		// Append the prompt to the messages
753		msgsWithPrompt := append(msgs, promptMsg)
754
755		event = AgentEvent{
756			Type:     AgentEventTypeSummarize,
757			Progress: "Generating summary...",
758		}
759
760		a.Publish(pubsub.CreatedEvent, event)
761
762		// Send the messages to the summarize provider
763		response := a.summarizeProvider.StreamResponse(
764			summarizeCtx,
765			msgsWithPrompt,
766			nil,
767		)
768		var finalResponse *provider.ProviderResponse
769		for r := range response {
770			if r.Error != nil {
771				event = AgentEvent{
772					Type:  AgentEventTypeError,
773					Error: fmt.Errorf("failed to summarize: %w", err),
774					Done:  true,
775				}
776				a.Publish(pubsub.CreatedEvent, event)
777				return
778			}
779			finalResponse = r.Response
780		}
781
782		summary := strings.TrimSpace(finalResponse.Content)
783		if summary == "" {
784			event = AgentEvent{
785				Type:  AgentEventTypeError,
786				Error: fmt.Errorf("empty summary returned"),
787				Done:  true,
788			}
789			a.Publish(pubsub.CreatedEvent, event)
790			return
791		}
792		shell := shell.GetPersistentShell(config.Get().WorkingDir())
793		summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
794		event = AgentEvent{
795			Type:     AgentEventTypeSummarize,
796			Progress: "Creating new session...",
797		}
798
799		a.Publish(pubsub.CreatedEvent, event)
800		oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
801		if err != nil {
802			event = AgentEvent{
803				Type:  AgentEventTypeError,
804				Error: fmt.Errorf("failed to get session: %w", err),
805				Done:  true,
806			}
807
808			a.Publish(pubsub.CreatedEvent, event)
809			return
810		}
811		// Create a message in the new session with the summary
812		msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
813			Role: message.Assistant,
814			Parts: []message.ContentPart{
815				message.TextContent{Text: summary},
816				message.Finish{
817					Reason: message.FinishReasonEndTurn,
818					Time:   time.Now().Unix(),
819				},
820			},
821			Model:    a.summarizeProvider.Model().ID,
822			Provider: a.summarizeProviderID,
823		})
824		if err != nil {
825			event = AgentEvent{
826				Type:  AgentEventTypeError,
827				Error: fmt.Errorf("failed to create summary message: %w", err),
828				Done:  true,
829			}
830
831			a.Publish(pubsub.CreatedEvent, event)
832			return
833		}
834		oldSession.SummaryMessageID = msg.ID
835		oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
836		oldSession.PromptTokens = 0
837		model := a.summarizeProvider.Model()
838		usage := finalResponse.Usage
839		cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
840			model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
841			model.CostPer1MIn/1e6*float64(usage.InputTokens) +
842			model.CostPer1MOut/1e6*float64(usage.OutputTokens)
843		oldSession.Cost += cost
844		_, err = a.sessions.Save(summarizeCtx, oldSession)
845		if err != nil {
846			event = AgentEvent{
847				Type:  AgentEventTypeError,
848				Error: fmt.Errorf("failed to save session: %w", err),
849				Done:  true,
850			}
851			a.Publish(pubsub.CreatedEvent, event)
852		}
853
854		event = AgentEvent{
855			Type:      AgentEventTypeSummarize,
856			SessionID: oldSession.ID,
857			Progress:  "Summary complete",
858			Done:      true,
859		}
860		a.Publish(pubsub.CreatedEvent, event)
861		// Send final success event with the new session ID
862	}()
863
864	return nil
865}
866
867func (a *agent) CancelAll() {
868	if !a.IsBusy() {
869		return
870	}
871	for key := range a.activeRequests.Seq2() {
872		a.Cancel(key) // key is sessionID
873	}
874
875	timeout := time.After(5 * time.Second)
876	for a.IsBusy() {
877		select {
878		case <-timeout:
879			return
880		default:
881			time.Sleep(200 * time.Millisecond)
882		}
883	}
884}
885
886func (a *agent) UpdateModel() error {
887	cfg := config.Get()
888
889	// Get current provider configuration
890	currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
891	if currentProviderCfg == nil || currentProviderCfg.ID == "" {
892		return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
893	}
894
895	// Check if provider has changed
896	if string(currentProviderCfg.ID) != a.providerID {
897		// Provider changed, need to recreate the main provider
898		model := cfg.GetModelByType(a.agentCfg.Model)
899		if model.ID == "" {
900			return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
901		}
902
903		promptID := agentPromptMap[a.agentCfg.ID]
904		if promptID == "" {
905			promptID = prompt.PromptDefault
906		}
907
908		opts := []provider.ProviderClientOption{
909			provider.WithModel(a.agentCfg.Model),
910			provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
911		}
912
913		newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
914		if err != nil {
915			return fmt.Errorf("failed to create new provider: %w", err)
916		}
917
918		// Update the provider and provider ID
919		a.provider = newProvider
920		a.providerID = string(currentProviderCfg.ID)
921	}
922
923	// Check if small model provider has changed (affects title and summarize providers)
924	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
925	var smallModelProviderCfg config.ProviderConfig
926
927	for p := range cfg.Providers.Seq() {
928		if p.ID == smallModelCfg.Provider {
929			smallModelProviderCfg = p
930			break
931		}
932	}
933
934	if smallModelProviderCfg.ID == "" {
935		return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
936	}
937
938	// Check if summarize provider has changed
939	if string(smallModelProviderCfg.ID) != a.summarizeProviderID {
940		smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
941		if smallModel == nil {
942			return fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
943		}
944
945		// Recreate title provider
946		titleOpts := []provider.ProviderClientOption{
947			provider.WithModel(config.SelectedModelTypeSmall),
948			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
949			// We want the title to be short, so we limit the max tokens
950			provider.WithMaxTokens(40),
951		}
952		newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
953		if err != nil {
954			return fmt.Errorf("failed to create new title provider: %w", err)
955		}
956
957		// Recreate summarize provider
958		summarizeOpts := []provider.ProviderClientOption{
959			provider.WithModel(config.SelectedModelTypeSmall),
960			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
961		}
962		newSummarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
963		if err != nil {
964			return fmt.Errorf("failed to create new summarize provider: %w", err)
965		}
966
967		// Update the providers and provider ID
968		a.titleProvider = newTitleProvider
969		a.summarizeProvider = newSummarizeProvider
970		a.summarizeProviderID = string(smallModelProviderCfg.ID)
971	}
972
973	return nil
974}