agent.go

  1package agent
  2
  3import (
  4	"context"
  5	"errors"
  6	"fmt"
  7	"slices"
  8	"strings"
  9	"sync"
 10	"time"
 11
 12	"github.com/charmbracelet/crush/internal/config"
 13	"github.com/charmbracelet/crush/internal/history"
 14	"github.com/charmbracelet/crush/internal/llm/prompt"
 15	"github.com/charmbracelet/crush/internal/llm/provider"
 16	"github.com/charmbracelet/crush/internal/llm/tools"
 17	"github.com/charmbracelet/crush/internal/logging"
 18	"github.com/charmbracelet/crush/internal/lsp"
 19	"github.com/charmbracelet/crush/internal/message"
 20	"github.com/charmbracelet/crush/internal/permission"
 21	"github.com/charmbracelet/crush/internal/pubsub"
 22	"github.com/charmbracelet/crush/internal/session"
 23)
 24
// Common errors returned by the agent.
var (
	// ErrRequestCancelled is reported (inside an error AgentEvent) when a
	// generation stops because its context was cancelled.
	ErrRequestCancelled = errors.New("request cancelled by user")
	// ErrSessionBusy is returned by Run and Summarize when the session already
	// has a request in flight.
	ErrSessionBusy      = errors.New("session is currently processing another request")
)
 30
// AgentEventType identifies what an AgentEvent reports.
type AgentEventType string

const (
	// AgentEventTypeError reports a failure; AgentEvent.Error holds the cause.
	AgentEventTypeError     AgentEventType = "error"
	// AgentEventTypeResponse carries the final assistant message of a Run.
	AgentEventTypeResponse  AgentEventType = "response"
	// AgentEventTypeSummarize reports progress or completion of Summarize.
	AgentEventTypeSummarize AgentEventType = "summarize"
)
 38
// AgentEvent is published to subscribers (and, for Run, also delivered on the
// returned channel). Type determines which of the other fields are set.
type AgentEvent struct {
	Type    AgentEventType
	Message message.Message
	Error   error

	// SessionID and Progress are only set by summarization events. Done marks
	// a terminal event: the final Run response, a summarization failure, or
	// summarization completion.
	SessionID string
	Progress  string
	Done      bool
}
 49
// Service is the interface through which the rest of the application drives
// an agent. Subscribers receive the AgentEvents the agent publishes.
type Service interface {
	pubsub.Suscriber[AgentEvent]
	// Model returns the model configured for this agent.
	Model() config.Model
	// EffectiveMaxTokens returns the max-token limit configured for this agent.
	EffectiveMaxTokens() int64
	// Run starts processing content (plus optional attachments) for sessionID
	// in the background. The returned channel delivers the request's final
	// AgentEvent and is then closed. ErrSessionBusy is returned when the
	// session already has a request in flight.
	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
	// Cancel cancels the in-flight request and/or summarization of sessionID.
	Cancel(sessionID string)
	// CancelAll cancels every in-flight request and summarization.
	CancelAll()
	// IsSessionBusy reports whether sessionID has a Run request in flight.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any request or summarization is in flight.
	IsBusy() bool
	// Summarize starts summarizing sessionID in the background; progress and
	// the outcome are published as AgentEvents.
	Summarize(ctx context.Context, sessionID string) error
	// UpdateModel rebuilds the agent's providers after a configuration change.
	UpdateModel() error
}
 62
// agent is the default Service implementation.
type agent struct {
	*pubsub.Broker[AgentEvent]
	agentCfg config.Agent
	sessions session.Service
	messages message.Service

	// Tools offered to the main provider (filtered by agentCfg.AllowedTools),
	// the main provider itself, and the ID of its provider config.
	tools      []tools.BaseTool
	provider   provider.Provider
	providerID string

	// Providers built on the configured small model, used to title sessions
	// and to summarize them; summarizeProviderID is their provider config ID.
	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests maps a session ID (Run) or "<sessionID>-summarize"
	// (Summarize) to the context.CancelFunc of the request in flight.
	activeRequests sync.Map
}
 79
// agentPromptMap selects the system prompt for each agent. Agents that are not
// listed fall back to prompt.PromptDefault.
var agentPromptMap = map[config.AgentID]prompt.PromptID{
	config.AgentCoder: prompt.PromptCoder,
	config.AgentTask:  prompt.PromptTask,
}
 84
 85func NewAgent(
 86	agentCfg config.Agent,
 87	// These services are needed in the tools
 88	permissions permission.Service,
 89	sessions session.Service,
 90	messages message.Service,
 91	history history.Service,
 92	lspClients map[string]*lsp.Client,
 93) (Service, error) {
 94	ctx := context.Background()
 95	cfg := config.Get()
 96	otherTools := GetMcpTools(ctx, permissions)
 97	if len(lspClients) > 0 {
 98		otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients))
 99	}
100
101	allTools := []tools.BaseTool{
102		tools.NewBashTool(permissions),
103		tools.NewEditTool(lspClients, permissions, history),
104		tools.NewFetchTool(permissions),
105		tools.NewGlobTool(),
106		tools.NewGrepTool(),
107		tools.NewLsTool(),
108		tools.NewSourcegraphTool(),
109		tools.NewViewTool(lspClients),
110		tools.NewWriteTool(lspClients, permissions, history),
111	}
112
113	if agentCfg.ID == config.AgentCoder {
114		taskAgentCfg := config.Get().Agents[config.AgentTask]
115		if taskAgentCfg.ID == "" {
116			return nil, fmt.Errorf("task agent not found in config")
117		}
118		taskAgent, err := NewAgent(taskAgentCfg, permissions, sessions, messages, history, lspClients)
119		if err != nil {
120			return nil, fmt.Errorf("failed to create task agent: %w", err)
121		}
122
123		allTools = append(
124			allTools,
125			NewAgentTool(
126				taskAgent,
127				sessions,
128				messages,
129			),
130		)
131	}
132
133	allTools = append(allTools, otherTools...)
134	providerCfg := config.GetAgentProvider(agentCfg.ID)
135	if providerCfg.ID == "" {
136		return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
137	}
138	model := config.GetAgentModel(agentCfg.ID)
139
140	if model.ID == "" {
141		return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
142	}
143
144	promptID := agentPromptMap[agentCfg.ID]
145	if promptID == "" {
146		promptID = prompt.PromptDefault
147	}
148	opts := []provider.ProviderClientOption{
149		provider.WithModel(agentCfg.Model),
150		provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID)),
151	}
152	agentProvider, err := provider.NewProvider(providerCfg, opts...)
153	if err != nil {
154		return nil, err
155	}
156
157	smallModelCfg := cfg.Models.Small
158	var smallModel config.Model
159
160	var smallModelProviderCfg config.ProviderConfig
161	if smallModelCfg.Provider == providerCfg.ID {
162		smallModelProviderCfg = providerCfg
163	} else {
164		for _, p := range cfg.Providers {
165			if p.ID == smallModelCfg.Provider {
166				smallModelProviderCfg = p
167				break
168			}
169		}
170		if smallModelProviderCfg.ID == "" {
171			return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
172		}
173	}
174	for _, m := range smallModelProviderCfg.Models {
175		if m.ID == smallModelCfg.ModelID {
176			smallModel = m
177			break
178		}
179	}
180	if smallModel.ID == "" {
181		return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.ModelID, smallModelProviderCfg.ID)
182	}
183
184	titleOpts := []provider.ProviderClientOption{
185		provider.WithModel(config.SmallModel),
186		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
187	}
188	titleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
189	if err != nil {
190		return nil, err
191	}
192	summarizeOpts := []provider.ProviderClientOption{
193		provider.WithModel(config.SmallModel),
194		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
195	}
196	summarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
197	if err != nil {
198		return nil, err
199	}
200
201	agentTools := []tools.BaseTool{}
202	if agentCfg.AllowedTools == nil {
203		agentTools = allTools
204	} else {
205		for _, tool := range allTools {
206			if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
207				agentTools = append(agentTools, tool)
208			}
209		}
210	}
211
212	agent := &agent{
213		Broker:              pubsub.NewBroker[AgentEvent](),
214		agentCfg:            agentCfg,
215		provider:            agentProvider,
216		providerID:          string(providerCfg.ID),
217		messages:            messages,
218		sessions:            sessions,
219		tools:               agentTools,
220		titleProvider:       titleProvider,
221		summarizeProvider:   summarizeProvider,
222		summarizeProviderID: string(smallModelProviderCfg.ID),
223		activeRequests:      sync.Map{},
224	}
225
226	return agent, nil
227}
228
229func (a *agent) Model() config.Model {
230	return config.GetAgentModel(a.agentCfg.ID)
231}
232
233func (a *agent) EffectiveMaxTokens() int64 {
234	return config.GetAgentEffectiveMaxTokens(a.agentCfg.ID)
235}
236
237func (a *agent) Cancel(sessionID string) {
238	// Cancel regular requests
239	if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
240		if cancel, ok := cancelFunc.(context.CancelFunc); ok {
241			logging.InfoPersist(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
242			cancel()
243		}
244	}
245
246	// Also check for summarize requests
247	if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID + "-summarize"); exists {
248		if cancel, ok := cancelFunc.(context.CancelFunc); ok {
249			logging.InfoPersist(fmt.Sprintf("Summarize cancellation initiated for session: %s", sessionID))
250			cancel()
251		}
252	}
253}
254
255func (a *agent) IsBusy() bool {
256	busy := false
257	a.activeRequests.Range(func(key, value any) bool {
258		if cancelFunc, ok := value.(context.CancelFunc); ok {
259			if cancelFunc != nil {
260				busy = true
261				return false
262			}
263		}
264		return true
265	})
266	return busy
267}
268
269func (a *agent) IsSessionBusy(sessionID string) bool {
270	_, busy := a.activeRequests.Load(sessionID)
271	return busy
272}
273
274func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
275	if content == "" {
276		return nil
277	}
278	if a.titleProvider == nil {
279		return nil
280	}
281	session, err := a.sessions.Get(ctx, sessionID)
282	if err != nil {
283		return err
284	}
285	parts := []message.ContentPart{message.TextContent{
286		Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
287	}}
288
289	// Use streaming approach like summarization
290	response := a.titleProvider.StreamResponse(
291		ctx,
292		[]message.Message{
293			{
294				Role:  message.User,
295				Parts: parts,
296			},
297		},
298		make([]tools.BaseTool, 0),
299	)
300
301	var finalResponse *provider.ProviderResponse
302	for r := range response {
303		if r.Error != nil {
304			return r.Error
305		}
306		finalResponse = r.Response
307	}
308
309	if finalResponse == nil {
310		return fmt.Errorf("no response received from title provider")
311	}
312
313	title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
314	if title == "" {
315		return nil
316	}
317
318	session.Title = title
319	_, err = a.sessions.Save(ctx, session)
320	return err
321}
322
323func (a *agent) err(err error) AgentEvent {
324	return AgentEvent{
325		Type:  AgentEventTypeError,
326		Error: err,
327	}
328}
329
330func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
331	if !a.Model().SupportsImages && attachments != nil {
332		attachments = nil
333	}
334	events := make(chan AgentEvent)
335	if a.IsSessionBusy(sessionID) {
336		return nil, ErrSessionBusy
337	}
338
339	genCtx, cancel := context.WithCancel(ctx)
340
341	a.activeRequests.Store(sessionID, cancel)
342	go func() {
343		logging.Debug("Request started", "sessionID", sessionID)
344		defer logging.RecoverPanic("agent.Run", func() {
345			events <- a.err(fmt.Errorf("panic while running the agent"))
346		})
347		var attachmentParts []message.ContentPart
348		for _, attachment := range attachments {
349			attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
350		}
351		result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
352		if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
353			logging.ErrorPersist(result.Error.Error())
354		}
355		logging.Debug("Request completed", "sessionID", sessionID)
356		a.activeRequests.Delete(sessionID)
357		cancel()
358		a.Publish(pubsub.CreatedEvent, result)
359		events <- result
360		close(events)
361	}()
362	return events, nil
363}
364
365func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
366	// List existing messages; if none, start title generation asynchronously.
367	msgs, err := a.messages.List(ctx, sessionID)
368	if err != nil {
369		return a.err(fmt.Errorf("failed to list messages: %w", err))
370	}
371	if len(msgs) == 0 {
372		go func() {
373			defer logging.RecoverPanic("agent.Run", func() {
374				logging.ErrorPersist("panic while generating title")
375			})
376			titleErr := a.generateTitle(context.Background(), sessionID, content)
377			if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
378				logging.ErrorPersist(fmt.Sprintf("failed to generate title: %v", titleErr))
379			}
380		}()
381	}
382	session, err := a.sessions.Get(ctx, sessionID)
383	if err != nil {
384		return a.err(fmt.Errorf("failed to get session: %w", err))
385	}
386	if session.SummaryMessageID != "" {
387		summaryMsgInex := -1
388		for i, msg := range msgs {
389			if msg.ID == session.SummaryMessageID {
390				summaryMsgInex = i
391				break
392			}
393		}
394		if summaryMsgInex != -1 {
395			msgs = msgs[summaryMsgInex:]
396			msgs[0].Role = message.User
397		}
398	}
399
400	userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
401	if err != nil {
402		return a.err(fmt.Errorf("failed to create user message: %w", err))
403	}
404	// Append the new user message to the conversation history.
405	msgHistory := append(msgs, userMsg)
406
407	for {
408		// Check for cancellation before each iteration
409		select {
410		case <-ctx.Done():
411			return a.err(ctx.Err())
412		default:
413			// Continue processing
414		}
415		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
416		if err != nil {
417			if errors.Is(err, context.Canceled) {
418				agentMessage.AddFinish(message.FinishReasonCanceled)
419				a.messages.Update(context.Background(), agentMessage)
420				return a.err(ErrRequestCancelled)
421			}
422			return a.err(fmt.Errorf("failed to process events: %w", err))
423		}
424		logging.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
425		if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
426			// We are not done, we need to respond with the tool response
427			msgHistory = append(msgHistory, agentMessage, *toolResults)
428			continue
429		}
430		return AgentEvent{
431			Type:    AgentEventTypeResponse,
432			Message: agentMessage,
433			Done:    true,
434		}
435	}
436}
437
438func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
439	parts := []message.ContentPart{message.TextContent{Text: content}}
440	parts = append(parts, attachmentParts...)
441	return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
442		Role:  message.User,
443		Parts: parts,
444	})
445}
446
447func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
448	eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)
449
450	assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
451		Role:     message.Assistant,
452		Parts:    []message.ContentPart{},
453		Model:    a.Model().ID,
454		Provider: a.providerID,
455	})
456	if err != nil {
457		return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
458	}
459
460	// Add the session and message ID into the context if needed by tools.
461	ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
462	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
463
464	// Process each event in the stream.
465	for event := range eventChan {
466		if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
467			a.finishMessage(ctx, &assistantMsg, message.FinishReasonCanceled)
468			return assistantMsg, nil, processErr
469		}
470		if ctx.Err() != nil {
471			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
472			return assistantMsg, nil, ctx.Err()
473		}
474	}
475
476	toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
477	toolCalls := assistantMsg.ToolCalls()
478	for i, toolCall := range toolCalls {
479		select {
480		case <-ctx.Done():
481			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
482			// Make all future tool calls cancelled
483			for j := i; j < len(toolCalls); j++ {
484				toolResults[j] = message.ToolResult{
485					ToolCallID: toolCalls[j].ID,
486					Content:    "Tool execution canceled by user",
487					IsError:    true,
488				}
489			}
490			goto out
491		default:
492			// Continue processing
493			var tool tools.BaseTool
494			for _, availableTools := range a.tools {
495				if availableTools.Info().Name == toolCall.Name {
496					tool = availableTools
497				}
498			}
499
500			// Tool not found
501			if tool == nil {
502				toolResults[i] = message.ToolResult{
503					ToolCallID: toolCall.ID,
504					Content:    fmt.Sprintf("Tool not found: %s", toolCall.Name),
505					IsError:    true,
506				}
507				continue
508			}
509			toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
510				ID:    toolCall.ID,
511				Name:  toolCall.Name,
512				Input: toolCall.Input,
513			})
514			if toolErr != nil {
515				if errors.Is(toolErr, permission.ErrorPermissionDenied) {
516					toolResults[i] = message.ToolResult{
517						ToolCallID: toolCall.ID,
518						Content:    "Permission denied",
519						IsError:    true,
520					}
521					for j := i + 1; j < len(toolCalls); j++ {
522						toolResults[j] = message.ToolResult{
523							ToolCallID: toolCalls[j].ID,
524							Content:    "Tool execution canceled by user",
525							IsError:    true,
526						}
527					}
528					a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied)
529					break
530				}
531			}
532			toolResults[i] = message.ToolResult{
533				ToolCallID: toolCall.ID,
534				Content:    toolResult.Content,
535				Metadata:   toolResult.Metadata,
536				IsError:    toolResult.IsError,
537			}
538		}
539	}
540out:
541	if len(toolResults) == 0 {
542		return assistantMsg, nil, nil
543	}
544	parts := make([]message.ContentPart, 0)
545	for _, tr := range toolResults {
546		parts = append(parts, tr)
547	}
548	msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
549		Role:     message.Tool,
550		Parts:    parts,
551		Provider: a.providerID,
552	})
553	if err != nil {
554		return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
555	}
556
557	return assistantMsg, &msg, err
558}
559
560func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReson message.FinishReason) {
561	msg.AddFinish(finishReson)
562	_ = a.messages.Update(ctx, *msg)
563}
564
565func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
566	select {
567	case <-ctx.Done():
568		return ctx.Err()
569	default:
570		// Continue processing.
571	}
572
573	switch event.Type {
574	case provider.EventThinkingDelta:
575		assistantMsg.AppendReasoningContent(event.Content)
576		return a.messages.Update(ctx, *assistantMsg)
577	case provider.EventContentDelta:
578		assistantMsg.AppendContent(event.Content)
579		return a.messages.Update(ctx, *assistantMsg)
580	case provider.EventToolUseStart:
581		logging.Info("Tool call started", "toolCall", event.ToolCall)
582		assistantMsg.AddToolCall(*event.ToolCall)
583		return a.messages.Update(ctx, *assistantMsg)
584	case provider.EventToolUseDelta:
585		assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
586		return a.messages.Update(ctx, *assistantMsg)
587	case provider.EventToolUseStop:
588		logging.Info("Finished tool call", "toolCall", event.ToolCall)
589		assistantMsg.FinishToolCall(event.ToolCall.ID)
590		return a.messages.Update(ctx, *assistantMsg)
591	case provider.EventError:
592		if errors.Is(event.Error, context.Canceled) {
593			logging.InfoPersist(fmt.Sprintf("Event processing canceled for session: %s", sessionID))
594			return context.Canceled
595		}
596		logging.ErrorPersist(event.Error.Error())
597		return event.Error
598	case provider.EventComplete:
599		assistantMsg.SetToolCalls(event.Response.ToolCalls)
600		assistantMsg.AddFinish(event.Response.FinishReason)
601		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
602			return fmt.Errorf("failed to update message: %w", err)
603		}
604		return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
605	}
606
607	return nil
608}
609
610func (a *agent) TrackUsage(ctx context.Context, sessionID string, model config.Model, usage provider.TokenUsage) error {
611	sess, err := a.sessions.Get(ctx, sessionID)
612	if err != nil {
613		return fmt.Errorf("failed to get session: %w", err)
614	}
615
616	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
617		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
618		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
619		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
620
621	sess.Cost += cost
622	sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
623	sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
624
625	_, err = a.sessions.Save(ctx, sess)
626	if err != nil {
627		return fmt.Errorf("failed to save session: %w", err)
628	}
629	return nil
630}
631
632func (a *agent) Summarize(ctx context.Context, sessionID string) error {
633	if a.summarizeProvider == nil {
634		return fmt.Errorf("summarize provider not available")
635	}
636
637	// Check if session is busy
638	if a.IsSessionBusy(sessionID) {
639		return ErrSessionBusy
640	}
641
642	// Create a new context with cancellation
643	summarizeCtx, cancel := context.WithCancel(ctx)
644
645	// Store the cancel function in activeRequests to allow cancellation
646	a.activeRequests.Store(sessionID+"-summarize", cancel)
647
648	go func() {
649		defer a.activeRequests.Delete(sessionID + "-summarize")
650		defer cancel()
651		event := AgentEvent{
652			Type:     AgentEventTypeSummarize,
653			Progress: "Starting summarization...",
654		}
655
656		a.Publish(pubsub.CreatedEvent, event)
657		// Get all messages from the session
658		msgs, err := a.messages.List(summarizeCtx, sessionID)
659		if err != nil {
660			event = AgentEvent{
661				Type:  AgentEventTypeError,
662				Error: fmt.Errorf("failed to list messages: %w", err),
663				Done:  true,
664			}
665			a.Publish(pubsub.CreatedEvent, event)
666			return
667		}
668
669		if len(msgs) == 0 {
670			event = AgentEvent{
671				Type:  AgentEventTypeError,
672				Error: fmt.Errorf("no messages to summarize"),
673				Done:  true,
674			}
675			a.Publish(pubsub.CreatedEvent, event)
676			return
677		}
678
679		event = AgentEvent{
680			Type:     AgentEventTypeSummarize,
681			Progress: "Analyzing conversation...",
682		}
683		a.Publish(pubsub.CreatedEvent, event)
684
685		// Add a system message to guide the summarization
686		summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
687
688		// Create a new message with the summarize prompt
689		promptMsg := message.Message{
690			Role:  message.User,
691			Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
692		}
693
694		// Append the prompt to the messages
695		msgsWithPrompt := append(msgs, promptMsg)
696
697		event = AgentEvent{
698			Type:     AgentEventTypeSummarize,
699			Progress: "Generating summary...",
700		}
701
702		a.Publish(pubsub.CreatedEvent, event)
703
704		// Send the messages to the summarize provider
705		response := a.summarizeProvider.StreamResponse(
706			summarizeCtx,
707			msgsWithPrompt,
708			make([]tools.BaseTool, 0),
709		)
710		var finalResponse *provider.ProviderResponse
711		for r := range response {
712			if r.Error != nil {
713				event = AgentEvent{
714					Type:  AgentEventTypeError,
715					Error: fmt.Errorf("failed to summarize: %w", err),
716					Done:  true,
717				}
718				a.Publish(pubsub.CreatedEvent, event)
719				return
720			}
721			finalResponse = r.Response
722		}
723
724		summary := strings.TrimSpace(finalResponse.Content)
725		if summary == "" {
726			event = AgentEvent{
727				Type:  AgentEventTypeError,
728				Error: fmt.Errorf("empty summary returned"),
729				Done:  true,
730			}
731			a.Publish(pubsub.CreatedEvent, event)
732			return
733		}
734		event = AgentEvent{
735			Type:     AgentEventTypeSummarize,
736			Progress: "Creating new session...",
737		}
738
739		a.Publish(pubsub.CreatedEvent, event)
740		oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
741		if err != nil {
742			event = AgentEvent{
743				Type:  AgentEventTypeError,
744				Error: fmt.Errorf("failed to get session: %w", err),
745				Done:  true,
746			}
747
748			a.Publish(pubsub.CreatedEvent, event)
749			return
750		}
751		// Create a message in the new session with the summary
752		msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
753			Role: message.Assistant,
754			Parts: []message.ContentPart{
755				message.TextContent{Text: summary},
756				message.Finish{
757					Reason: message.FinishReasonEndTurn,
758					Time:   time.Now().Unix(),
759				},
760			},
761			Model:    a.summarizeProvider.Model().ID,
762			Provider: a.summarizeProviderID,
763		})
764		if err != nil {
765			event = AgentEvent{
766				Type:  AgentEventTypeError,
767				Error: fmt.Errorf("failed to create summary message: %w", err),
768				Done:  true,
769			}
770
771			a.Publish(pubsub.CreatedEvent, event)
772			return
773		}
774		oldSession.SummaryMessageID = msg.ID
775		oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
776		oldSession.PromptTokens = 0
777		model := a.summarizeProvider.Model()
778		usage := finalResponse.Usage
779		cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
780			model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
781			model.CostPer1MIn/1e6*float64(usage.InputTokens) +
782			model.CostPer1MOut/1e6*float64(usage.OutputTokens)
783		oldSession.Cost += cost
784		_, err = a.sessions.Save(summarizeCtx, oldSession)
785		if err != nil {
786			event = AgentEvent{
787				Type:  AgentEventTypeError,
788				Error: fmt.Errorf("failed to save session: %w", err),
789				Done:  true,
790			}
791			a.Publish(pubsub.CreatedEvent, event)
792		}
793
794		event = AgentEvent{
795			Type:      AgentEventTypeSummarize,
796			SessionID: oldSession.ID,
797			Progress:  "Summary complete",
798			Done:      true,
799		}
800		a.Publish(pubsub.CreatedEvent, event)
801		// Send final success event with the new session ID
802	}()
803
804	return nil
805}
806
807func (a *agent) CancelAll() {
808	a.activeRequests.Range(func(key, value any) bool {
809		a.Cancel(key.(string)) // key is sessionID
810		return true
811	})
812}
813
814func (a *agent) UpdateModel() error {
815	cfg := config.Get()
816
817	// Get current provider configuration
818	currentProviderCfg := config.GetAgentProvider(a.agentCfg.ID)
819	if currentProviderCfg.ID == "" {
820		return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
821	}
822
823	// Check if provider has changed
824	if string(currentProviderCfg.ID) != a.providerID {
825		// Provider changed, need to recreate the main provider
826		model := config.GetAgentModel(a.agentCfg.ID)
827		if model.ID == "" {
828			return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
829		}
830
831		promptID := agentPromptMap[a.agentCfg.ID]
832		if promptID == "" {
833			promptID = prompt.PromptDefault
834		}
835
836		opts := []provider.ProviderClientOption{
837			provider.WithModel(a.agentCfg.Model),
838			provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID)),
839		}
840
841		newProvider, err := provider.NewProvider(currentProviderCfg, opts...)
842		if err != nil {
843			return fmt.Errorf("failed to create new provider: %w", err)
844		}
845
846		// Update the provider and provider ID
847		a.provider = newProvider
848		a.providerID = string(currentProviderCfg.ID)
849	}
850
851	// Check if small model provider has changed (affects title and summarize providers)
852	smallModelCfg := cfg.Models.Small
853	var smallModelProviderCfg config.ProviderConfig
854
855	for _, p := range cfg.Providers {
856		if p.ID == smallModelCfg.Provider {
857			smallModelProviderCfg = p
858			break
859		}
860	}
861
862	if smallModelProviderCfg.ID == "" {
863		return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
864	}
865
866	// Check if summarize provider has changed
867	if string(smallModelProviderCfg.ID) != a.summarizeProviderID {
868		var smallModel config.Model
869		for _, m := range smallModelProviderCfg.Models {
870			if m.ID == smallModelCfg.ModelID {
871				smallModel = m
872				break
873			}
874		}
875		if smallModel.ID == "" {
876			return fmt.Errorf("model %s not found in provider %s", smallModelCfg.ModelID, smallModelProviderCfg.ID)
877		}
878
879		// Recreate title provider
880		titleOpts := []provider.ProviderClientOption{
881			provider.WithModel(config.SmallModel),
882			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
883			// We want the title to be short, so we limit the max tokens
884			provider.WithMaxTokens(40),
885		}
886		newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
887		if err != nil {
888			return fmt.Errorf("failed to create new title provider: %w", err)
889		}
890
891		// Recreate summarize provider
892		summarizeOpts := []provider.ProviderClientOption{
893			provider.WithModel(config.SmallModel),
894			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
895		}
896		newSummarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
897		if err != nil {
898			return fmt.Errorf("failed to create new summarize provider: %w", err)
899		}
900
901		// Update the providers and provider ID
902		a.titleProvider = newTitleProvider
903		a.summarizeProvider = newSummarizeProvider
904		a.summarizeProviderID = string(smallModelProviderCfg.ID)
905	}
906
907	return nil
908}