agent.go

  1package agent
  2
  3import (
  4	"context"
  5	"errors"
  6	"fmt"
  7	"slices"
  8	"strings"
  9	"sync"
 10	"time"
 11
 12	configv2 "github.com/charmbracelet/crush/internal/config"
 13	"github.com/charmbracelet/crush/internal/history"
 14	"github.com/charmbracelet/crush/internal/llm/prompt"
 15	"github.com/charmbracelet/crush/internal/llm/provider"
 16	"github.com/charmbracelet/crush/internal/llm/tools"
 17	"github.com/charmbracelet/crush/internal/logging"
 18	"github.com/charmbracelet/crush/internal/lsp"
 19	"github.com/charmbracelet/crush/internal/message"
 20	"github.com/charmbracelet/crush/internal/permission"
 21	"github.com/charmbracelet/crush/internal/pubsub"
 22	"github.com/charmbracelet/crush/internal/session"
 23)
 24
 25// Common errors
 26var (
 27	ErrRequestCancelled = errors.New("request cancelled by user")
 28	ErrSessionBusy      = errors.New("session is currently processing another request")
 29)
 30
 31type AgentEventType string
 32
 33const (
 34	AgentEventTypeError     AgentEventType = "error"
 35	AgentEventTypeResponse  AgentEventType = "response"
 36	AgentEventTypeSummarize AgentEventType = "summarize"
 37)
 38
 39type AgentEvent struct {
 40	Type    AgentEventType
 41	Message message.Message
 42	Error   error
 43
 44	// When summarizing
 45	SessionID string
 46	Progress  string
 47	Done      bool
 48}
 49
 50type Service interface {
 51	pubsub.Suscriber[AgentEvent]
 52	Model() configv2.Model
 53	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
 54	Cancel(sessionID string)
 55	CancelAll()
 56	IsSessionBusy(sessionID string) bool
 57	IsBusy() bool
 58	Update(model configv2.PreferredModel) (configv2.Model, error)
 59	Summarize(ctx context.Context, sessionID string) error
 60}
 61
 62type agent struct {
 63	*pubsub.Broker[AgentEvent]
 64	agentCfg configv2.Agent
 65	sessions session.Service
 66	messages message.Service
 67
 68	tools      []tools.BaseTool
 69	provider   provider.Provider
 70	providerID string
 71
 72	titleProvider       provider.Provider
 73	summarizeProvider   provider.Provider
 74	summarizeProviderID string
 75
 76	activeRequests sync.Map
 77}
 78
 79var agentPromptMap = map[configv2.AgentID]prompt.PromptID{
 80	configv2.AgentCoder: prompt.PromptCoder,
 81	configv2.AgentTask:  prompt.PromptTask,
 82}
 83
 84func NewAgent(
 85	agentCfg configv2.Agent,
 86	// These services are needed in the tools
 87	permissions permission.Service,
 88	sessions session.Service,
 89	messages message.Service,
 90	history history.Service,
 91	lspClients map[string]*lsp.Client,
 92) (Service, error) {
 93	ctx := context.Background()
 94	cfg := configv2.Get()
 95	otherTools := GetMcpTools(ctx, permissions)
 96	if len(lspClients) > 0 {
 97		otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients))
 98	}
 99
100	allTools := []tools.BaseTool{
101		tools.NewBashTool(permissions),
102		tools.NewEditTool(lspClients, permissions, history),
103		tools.NewFetchTool(permissions),
104		tools.NewGlobTool(),
105		tools.NewGrepTool(),
106		tools.NewLsTool(),
107		tools.NewSourcegraphTool(),
108		tools.NewViewTool(lspClients),
109		tools.NewWriteTool(lspClients, permissions, history),
110	}
111
112	if agentCfg.ID == configv2.AgentCoder {
113		taskAgentCfg := configv2.Get().Agents[configv2.AgentTask]
114		if taskAgentCfg.ID == "" {
115			return nil, fmt.Errorf("task agent not found in config")
116		}
117		taskAgent, err := NewAgent(taskAgentCfg, permissions, sessions, messages, history, lspClients)
118		if err != nil {
119			return nil, fmt.Errorf("failed to create task agent: %w", err)
120		}
121
122		allTools = append(
123			allTools,
124			NewAgentTool(
125				taskAgent,
126				sessions,
127				messages,
128			),
129		)
130	}
131
132	allTools = append(allTools, otherTools...)
133	var providerCfg configv2.ProviderConfig
134	for _, p := range cfg.Providers {
135		if p.ID == agentCfg.Provider {
136			providerCfg = p
137			break
138		}
139	}
140	if providerCfg.ID == "" {
141		return nil, fmt.Errorf("provider %s not found in config", agentCfg.Provider)
142	}
143
144	var model configv2.Model
145	for _, m := range providerCfg.Models {
146		if m.ID == agentCfg.Model {
147			model = m
148			break
149		}
150	}
151	if model.ID == "" {
152		return nil, fmt.Errorf("model %s not found in provider %s", agentCfg.Model, agentCfg.Provider)
153	}
154
155	promptID := agentPromptMap[agentCfg.ID]
156	if promptID == "" {
157		promptID = prompt.PromptDefault
158	}
159	opts := []provider.ProviderClientOption{
160		provider.WithModel(model),
161		provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID)),
162		provider.WithMaxTokens(model.DefaultMaxTokens),
163	}
164	agentProvider, err := provider.NewProviderV2(providerCfg, opts...)
165	if err != nil {
166		return nil, err
167	}
168
169	smallModelCfg := cfg.Models.Small
170	var smallModel configv2.Model
171
172	var smallModelProviderCfg configv2.ProviderConfig
173	if smallModelCfg.Provider == providerCfg.ID {
174		smallModelProviderCfg = providerCfg
175	} else {
176		for _, p := range cfg.Providers {
177			if p.ID == smallModelCfg.Provider {
178				smallModelProviderCfg = p
179				break
180			}
181		}
182		if smallModelProviderCfg.ID == "" {
183			return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
184		}
185	}
186	for _, m := range smallModelProviderCfg.Models {
187		if m.ID == smallModelCfg.ModelID {
188			smallModel = m
189			break
190		}
191	}
192	if smallModel.ID == "" {
193		return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.ModelID, smallModelProviderCfg.ID)
194	}
195
196	titleOpts := []provider.ProviderClientOption{
197		provider.WithModel(smallModel),
198		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
199		provider.WithMaxTokens(40),
200	}
201	titleProvider, err := provider.NewProviderV2(smallModelProviderCfg, titleOpts...)
202	if err != nil {
203		return nil, err
204	}
205	summarizeOpts := []provider.ProviderClientOption{
206		provider.WithModel(smallModel),
207		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
208		provider.WithMaxTokens(smallModel.DefaultMaxTokens),
209	}
210	summarizeProvider, err := provider.NewProviderV2(smallModelProviderCfg, summarizeOpts...)
211	if err != nil {
212		return nil, err
213	}
214
215	agentTools := []tools.BaseTool{}
216	if agentCfg.AllowedTools == nil {
217		agentTools = allTools
218	} else {
219		for _, tool := range allTools {
220			if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
221				agentTools = append(agentTools, tool)
222			}
223		}
224	}
225
226	agent := &agent{
227		Broker:              pubsub.NewBroker[AgentEvent](),
228		agentCfg:            agentCfg,
229		provider:            agentProvider,
230		providerID:          string(providerCfg.ID),
231		messages:            messages,
232		sessions:            sessions,
233		tools:               agentTools,
234		titleProvider:       titleProvider,
235		summarizeProvider:   summarizeProvider,
236		summarizeProviderID: string(smallModelProviderCfg.ID),
237		activeRequests:      sync.Map{},
238	}
239
240	return agent, nil
241}
242
243func (a *agent) Model() configv2.Model {
244	return a.provider.Model()
245}
246
247func (a *agent) Cancel(sessionID string) {
248	// Cancel regular requests
249	if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
250		if cancel, ok := cancelFunc.(context.CancelFunc); ok {
251			logging.InfoPersist(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
252			cancel()
253		}
254	}
255
256	// Also check for summarize requests
257	if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID + "-summarize"); exists {
258		if cancel, ok := cancelFunc.(context.CancelFunc); ok {
259			logging.InfoPersist(fmt.Sprintf("Summarize cancellation initiated for session: %s", sessionID))
260			cancel()
261		}
262	}
263}
264
265func (a *agent) IsBusy() bool {
266	busy := false
267	a.activeRequests.Range(func(key, value any) bool {
268		if cancelFunc, ok := value.(context.CancelFunc); ok {
269			if cancelFunc != nil {
270				busy = true
271				return false // Stop iterating
272			}
273		}
274		return true // Continue iterating
275	})
276	return busy
277}
278
279func (a *agent) IsSessionBusy(sessionID string) bool {
280	_, busy := a.activeRequests.Load(sessionID)
281	return busy
282}
283
284func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
285	if content == "" {
286		return nil
287	}
288	if a.titleProvider == nil {
289		return nil
290	}
291	session, err := a.sessions.Get(ctx, sessionID)
292	if err != nil {
293		return err
294	}
295	parts := []message.ContentPart{message.TextContent{Text: content}}
296
297	// Use streaming approach like summarization
298	response := a.titleProvider.StreamResponse(
299		ctx,
300		[]message.Message{
301			{
302				Role:  message.User,
303				Parts: parts,
304			},
305		},
306		make([]tools.BaseTool, 0),
307	)
308
309	var finalResponse *provider.ProviderResponse
310	for r := range response {
311		if r.Error != nil {
312			return r.Error
313		}
314		finalResponse = r.Response
315	}
316
317	if finalResponse == nil {
318		return fmt.Errorf("no response received from title provider")
319	}
320
321	title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
322	if title == "" {
323		return nil
324	}
325
326	session.Title = title
327	_, err = a.sessions.Save(ctx, session)
328	return err
329}
330
331func (a *agent) err(err error) AgentEvent {
332	return AgentEvent{
333		Type:  AgentEventTypeError,
334		Error: err,
335	}
336}
337
338func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
339	if !a.provider.Model().SupportsImages && attachments != nil {
340		attachments = nil
341	}
342	events := make(chan AgentEvent)
343	if a.IsSessionBusy(sessionID) {
344		return nil, ErrSessionBusy
345	}
346
347	genCtx, cancel := context.WithCancel(ctx)
348
349	a.activeRequests.Store(sessionID, cancel)
350	go func() {
351		logging.Debug("Request started", "sessionID", sessionID)
352		defer logging.RecoverPanic("agent.Run", func() {
353			events <- a.err(fmt.Errorf("panic while running the agent"))
354		})
355		var attachmentParts []message.ContentPart
356		for _, attachment := range attachments {
357			attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
358		}
359		result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
360		if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
361			logging.ErrorPersist(result.Error.Error())
362		}
363		logging.Debug("Request completed", "sessionID", sessionID)
364		a.activeRequests.Delete(sessionID)
365		cancel()
366		a.Publish(pubsub.CreatedEvent, result)
367		events <- result
368		close(events)
369	}()
370	return events, nil
371}
372
373func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
374	// List existing messages; if none, start title generation asynchronously.
375	msgs, err := a.messages.List(ctx, sessionID)
376	if err != nil {
377		return a.err(fmt.Errorf("failed to list messages: %w", err))
378	}
379	if len(msgs) == 0 {
380		go func() {
381			defer logging.RecoverPanic("agent.Run", func() {
382				logging.ErrorPersist("panic while generating title")
383			})
384			titleErr := a.generateTitle(context.Background(), sessionID, content)
385			if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
386				logging.ErrorPersist(fmt.Sprintf("failed to generate title: %v", titleErr))
387			}
388		}()
389	}
390	session, err := a.sessions.Get(ctx, sessionID)
391	if err != nil {
392		return a.err(fmt.Errorf("failed to get session: %w", err))
393	}
394	if session.SummaryMessageID != "" {
395		summaryMsgInex := -1
396		for i, msg := range msgs {
397			if msg.ID == session.SummaryMessageID {
398				summaryMsgInex = i
399				break
400			}
401		}
402		if summaryMsgInex != -1 {
403			msgs = msgs[summaryMsgInex:]
404			msgs[0].Role = message.User
405		}
406	}
407
408	userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
409	if err != nil {
410		return a.err(fmt.Errorf("failed to create user message: %w", err))
411	}
412	// Append the new user message to the conversation history.
413	msgHistory := append(msgs, userMsg)
414
415	for {
416		// Check for cancellation before each iteration
417		select {
418		case <-ctx.Done():
419			return a.err(ctx.Err())
420		default:
421			// Continue processing
422		}
423		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
424		if err != nil {
425			if errors.Is(err, context.Canceled) {
426				agentMessage.AddFinish(message.FinishReasonCanceled)
427				a.messages.Update(context.Background(), agentMessage)
428				return a.err(ErrRequestCancelled)
429			}
430			return a.err(fmt.Errorf("failed to process events: %w", err))
431		}
432		logging.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
433		if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
434			// We are not done, we need to respond with the tool response
435			msgHistory = append(msgHistory, agentMessage, *toolResults)
436			continue
437		}
438		return AgentEvent{
439			Type:    AgentEventTypeResponse,
440			Message: agentMessage,
441			Done:    true,
442		}
443	}
444}
445
446func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
447	parts := []message.ContentPart{message.TextContent{Text: content}}
448	parts = append(parts, attachmentParts...)
449	return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
450		Role:  message.User,
451		Parts: parts,
452	})
453}
454
455func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
456	eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)
457
458	assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
459		Role:     message.Assistant,
460		Parts:    []message.ContentPart{},
461		Model:    a.provider.Model().ID,
462		Provider: a.providerID,
463	})
464	if err != nil {
465		return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
466	}
467
468	// Add the session and message ID into the context if needed by tools.
469	ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
470	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
471
472	// Process each event in the stream.
473	for event := range eventChan {
474		if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
475			a.finishMessage(ctx, &assistantMsg, message.FinishReasonCanceled)
476			return assistantMsg, nil, processErr
477		}
478		if ctx.Err() != nil {
479			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
480			return assistantMsg, nil, ctx.Err()
481		}
482	}
483
484	toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
485	toolCalls := assistantMsg.ToolCalls()
486	for i, toolCall := range toolCalls {
487		select {
488		case <-ctx.Done():
489			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
490			// Make all future tool calls cancelled
491			for j := i; j < len(toolCalls); j++ {
492				toolResults[j] = message.ToolResult{
493					ToolCallID: toolCalls[j].ID,
494					Content:    "Tool execution canceled by user",
495					IsError:    true,
496				}
497			}
498			goto out
499		default:
500			// Continue processing
501			var tool tools.BaseTool
502			for _, availableTools := range a.tools {
503				if availableTools.Info().Name == toolCall.Name {
504					tool = availableTools
505				}
506			}
507
508			// Tool not found
509			if tool == nil {
510				toolResults[i] = message.ToolResult{
511					ToolCallID: toolCall.ID,
512					Content:    fmt.Sprintf("Tool not found: %s", toolCall.Name),
513					IsError:    true,
514				}
515				continue
516			}
517			toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
518				ID:    toolCall.ID,
519				Name:  toolCall.Name,
520				Input: toolCall.Input,
521			})
522			if toolErr != nil {
523				if errors.Is(toolErr, permission.ErrorPermissionDenied) {
524					toolResults[i] = message.ToolResult{
525						ToolCallID: toolCall.ID,
526						Content:    "Permission denied",
527						IsError:    true,
528					}
529					for j := i + 1; j < len(toolCalls); j++ {
530						toolResults[j] = message.ToolResult{
531							ToolCallID: toolCalls[j].ID,
532							Content:    "Tool execution canceled by user",
533							IsError:    true,
534						}
535					}
536					a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied)
537					break
538				}
539			}
540			toolResults[i] = message.ToolResult{
541				ToolCallID: toolCall.ID,
542				Content:    toolResult.Content,
543				Metadata:   toolResult.Metadata,
544				IsError:    toolResult.IsError,
545			}
546		}
547	}
548out:
549	if len(toolResults) == 0 {
550		return assistantMsg, nil, nil
551	}
552	parts := make([]message.ContentPart, 0)
553	for _, tr := range toolResults {
554		parts = append(parts, tr)
555	}
556	msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
557		Role:     message.Tool,
558		Parts:    parts,
559		Provider: a.providerID,
560	})
561	if err != nil {
562		return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
563	}
564
565	return assistantMsg, &msg, err
566}
567
568func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReson message.FinishReason) {
569	msg.AddFinish(finishReson)
570	_ = a.messages.Update(ctx, *msg)
571}
572
573func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
574	select {
575	case <-ctx.Done():
576		return ctx.Err()
577	default:
578		// Continue processing.
579	}
580
581	switch event.Type {
582	case provider.EventThinkingDelta:
583		assistantMsg.AppendReasoningContent(event.Content)
584		return a.messages.Update(ctx, *assistantMsg)
585	case provider.EventContentDelta:
586		assistantMsg.AppendContent(event.Content)
587		return a.messages.Update(ctx, *assistantMsg)
588	case provider.EventToolUseStart:
589		logging.Info("Tool call started", "toolCall", event.ToolCall)
590		assistantMsg.AddToolCall(*event.ToolCall)
591		return a.messages.Update(ctx, *assistantMsg)
592	case provider.EventToolUseDelta:
593		assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
594		return a.messages.Update(ctx, *assistantMsg)
595	case provider.EventToolUseStop:
596		logging.Info("Finished tool call", "toolCall", event.ToolCall)
597		assistantMsg.FinishToolCall(event.ToolCall.ID)
598		return a.messages.Update(ctx, *assistantMsg)
599	case provider.EventError:
600		if errors.Is(event.Error, context.Canceled) {
601			logging.InfoPersist(fmt.Sprintf("Event processing canceled for session: %s", sessionID))
602			return context.Canceled
603		}
604		logging.ErrorPersist(event.Error.Error())
605		return event.Error
606	case provider.EventComplete:
607		assistantMsg.SetToolCalls(event.Response.ToolCalls)
608		assistantMsg.AddFinish(event.Response.FinishReason)
609		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
610			return fmt.Errorf("failed to update message: %w", err)
611		}
612		return a.TrackUsage(ctx, sessionID, a.provider.Model(), event.Response.Usage)
613	}
614
615	return nil
616}
617
618func (a *agent) TrackUsage(ctx context.Context, sessionID string, model configv2.Model, usage provider.TokenUsage) error {
619	sess, err := a.sessions.Get(ctx, sessionID)
620	if err != nil {
621		return fmt.Errorf("failed to get session: %w", err)
622	}
623
624	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
625		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
626		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
627		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
628
629	sess.Cost += cost
630	sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
631	sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
632
633	_, err = a.sessions.Save(ctx, sess)
634	if err != nil {
635		return fmt.Errorf("failed to save session: %w", err)
636	}
637	return nil
638}
639
640func (a *agent) Update(modelCfg configv2.PreferredModel) (configv2.Model, error) {
641	if a.IsBusy() {
642		return configv2.Model{}, fmt.Errorf("cannot change model while processing requests")
643	}
644
645	cfg := configv2.Get()
646	var providerCfg configv2.ProviderConfig
647	for _, p := range cfg.Providers {
648		if p.ID == modelCfg.Provider {
649			providerCfg = p
650			break
651		}
652	}
653	if providerCfg.ID == "" {
654		return configv2.Model{}, fmt.Errorf("provider %s not found in config", modelCfg.Provider)
655	}
656
657	var model configv2.Model
658	for _, m := range providerCfg.Models {
659		if m.ID == modelCfg.ModelID {
660			model = m
661			break
662		}
663	}
664	if model.ID == "" {
665		return configv2.Model{}, fmt.Errorf("model %s not found in provider %s", modelCfg.ModelID, modelCfg.Provider)
666	}
667
668	promptID := agentPromptMap[a.agentCfg.ID]
669	if promptID == "" {
670		promptID = prompt.PromptDefault
671	}
672	opts := []provider.ProviderClientOption{
673		provider.WithModel(model),
674		provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID)),
675		provider.WithMaxTokens(model.DefaultMaxTokens),
676	}
677	agentProvider, err := provider.NewProviderV2(providerCfg, opts...)
678	if err != nil {
679		return configv2.Model{}, err
680	}
681	a.provider = agentProvider
682
683	return a.provider.Model(), nil
684}
685
686func (a *agent) Summarize(ctx context.Context, sessionID string) error {
687	if a.summarizeProvider == nil {
688		return fmt.Errorf("summarize provider not available")
689	}
690
691	// Check if session is busy
692	if a.IsSessionBusy(sessionID) {
693		return ErrSessionBusy
694	}
695
696	// Create a new context with cancellation
697	summarizeCtx, cancel := context.WithCancel(ctx)
698
699	// Store the cancel function in activeRequests to allow cancellation
700	a.activeRequests.Store(sessionID+"-summarize", cancel)
701
702	go func() {
703		defer a.activeRequests.Delete(sessionID + "-summarize")
704		defer cancel()
705		event := AgentEvent{
706			Type:     AgentEventTypeSummarize,
707			Progress: "Starting summarization...",
708		}
709
710		a.Publish(pubsub.CreatedEvent, event)
711		// Get all messages from the session
712		msgs, err := a.messages.List(summarizeCtx, sessionID)
713		if err != nil {
714			event = AgentEvent{
715				Type:  AgentEventTypeError,
716				Error: fmt.Errorf("failed to list messages: %w", err),
717				Done:  true,
718			}
719			a.Publish(pubsub.CreatedEvent, event)
720			return
721		}
722
723		if len(msgs) == 0 {
724			event = AgentEvent{
725				Type:  AgentEventTypeError,
726				Error: fmt.Errorf("no messages to summarize"),
727				Done:  true,
728			}
729			a.Publish(pubsub.CreatedEvent, event)
730			return
731		}
732
733		event = AgentEvent{
734			Type:     AgentEventTypeSummarize,
735			Progress: "Analyzing conversation...",
736		}
737		a.Publish(pubsub.CreatedEvent, event)
738
739		// Add a system message to guide the summarization
740		summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
741
742		// Create a new message with the summarize prompt
743		promptMsg := message.Message{
744			Role:  message.User,
745			Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
746		}
747
748		// Append the prompt to the messages
749		msgsWithPrompt := append(msgs, promptMsg)
750
751		event = AgentEvent{
752			Type:     AgentEventTypeSummarize,
753			Progress: "Generating summary...",
754		}
755
756		a.Publish(pubsub.CreatedEvent, event)
757
758		// Send the messages to the summarize provider
759		response := a.summarizeProvider.StreamResponse(
760			summarizeCtx,
761			msgsWithPrompt,
762			make([]tools.BaseTool, 0),
763		)
764		var finalResponse *provider.ProviderResponse
765		for r := range response {
766			if r.Error != nil {
767				event = AgentEvent{
768					Type:  AgentEventTypeError,
769					Error: fmt.Errorf("failed to summarize: %w", err),
770					Done:  true,
771				}
772				a.Publish(pubsub.CreatedEvent, event)
773				return
774			}
775			finalResponse = r.Response
776		}
777
778		summary := strings.TrimSpace(finalResponse.Content)
779		if summary == "" {
780			event = AgentEvent{
781				Type:  AgentEventTypeError,
782				Error: fmt.Errorf("empty summary returned"),
783				Done:  true,
784			}
785			a.Publish(pubsub.CreatedEvent, event)
786			return
787		}
788		event = AgentEvent{
789			Type:     AgentEventTypeSummarize,
790			Progress: "Creating new session...",
791		}
792
793		a.Publish(pubsub.CreatedEvent, event)
794		oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
795		if err != nil {
796			event = AgentEvent{
797				Type:  AgentEventTypeError,
798				Error: fmt.Errorf("failed to get session: %w", err),
799				Done:  true,
800			}
801
802			a.Publish(pubsub.CreatedEvent, event)
803			return
804		}
805		// Create a message in the new session with the summary
806		msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
807			Role: message.Assistant,
808			Parts: []message.ContentPart{
809				message.TextContent{Text: summary},
810				message.Finish{
811					Reason: message.FinishReasonEndTurn,
812					Time:   time.Now().Unix(),
813				},
814			},
815			Model:    a.summarizeProvider.Model().ID,
816			Provider: a.summarizeProviderID,
817		})
818		if err != nil {
819			event = AgentEvent{
820				Type:  AgentEventTypeError,
821				Error: fmt.Errorf("failed to create summary message: %w", err),
822				Done:  true,
823			}
824
825			a.Publish(pubsub.CreatedEvent, event)
826			return
827		}
828		oldSession.SummaryMessageID = msg.ID
829		oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
830		oldSession.PromptTokens = 0
831		model := a.summarizeProvider.Model()
832		usage := finalResponse.Usage
833		cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
834			model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
835			model.CostPer1MIn/1e6*float64(usage.InputTokens) +
836			model.CostPer1MOut/1e6*float64(usage.OutputTokens)
837		oldSession.Cost += cost
838		_, err = a.sessions.Save(summarizeCtx, oldSession)
839		if err != nil {
840			event = AgentEvent{
841				Type:  AgentEventTypeError,
842				Error: fmt.Errorf("failed to save session: %w", err),
843				Done:  true,
844			}
845			a.Publish(pubsub.CreatedEvent, event)
846		}
847
848		event = AgentEvent{
849			Type:      AgentEventTypeSummarize,
850			SessionID: oldSession.ID,
851			Progress:  "Summary complete",
852			Done:      true,
853		}
854		a.Publish(pubsub.CreatedEvent, event)
855		// Send final success event with the new session ID
856	}()
857
858	return nil
859}
860
861func (a *agent) CancelAll() {
862	a.activeRequests.Range(func(key, value any) bool {
863		a.Cancel(key.(string)) // key is sessionID
864		return true
865	})
866}