agent.go

  1package agent
  2
  3import (
  4	"context"
  5	"errors"
  6	"fmt"
  7	"slices"
  8	"strings"
  9	"sync"
 10	"time"
 11
 12	"github.com/charmbracelet/crush/internal/config"
 13	"github.com/charmbracelet/crush/internal/history"
 14	"github.com/charmbracelet/crush/internal/llm/prompt"
 15	"github.com/charmbracelet/crush/internal/llm/provider"
 16	"github.com/charmbracelet/crush/internal/llm/tools"
 17	"github.com/charmbracelet/crush/internal/logging"
 18	"github.com/charmbracelet/crush/internal/lsp"
 19	"github.com/charmbracelet/crush/internal/message"
 20	"github.com/charmbracelet/crush/internal/permission"
 21	"github.com/charmbracelet/crush/internal/pubsub"
 22	"github.com/charmbracelet/crush/internal/session"
 23)
 24
 25// Common errors
 26var (
 27	ErrRequestCancelled = errors.New("request cancelled by user")
 28	ErrSessionBusy      = errors.New("session is currently processing another request")
 29)
 30
 31type AgentEventType string
 32
 33const (
 34	AgentEventTypeError     AgentEventType = "error"
 35	AgentEventTypeResponse  AgentEventType = "response"
 36	AgentEventTypeSummarize AgentEventType = "summarize"
 37)
 38
 39type AgentEvent struct {
 40	Type    AgentEventType
 41	Message message.Message
 42	Error   error
 43
 44	// When summarizing
 45	SessionID string
 46	Progress  string
 47	Done      bool
 48}
 49
 50type Service interface {
 51	pubsub.Suscriber[AgentEvent]
 52	Model() config.Model
 53	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
 54	Cancel(sessionID string)
 55	CancelAll()
 56	IsSessionBusy(sessionID string) bool
 57	IsBusy() bool
 58	Summarize(ctx context.Context, sessionID string) error
 59	UpdateModel() error
 60}
 61
 62type agent struct {
 63	*pubsub.Broker[AgentEvent]
 64	agentCfg config.Agent
 65	sessions session.Service
 66	messages message.Service
 67
 68	tools      []tools.BaseTool
 69	provider   provider.Provider
 70	providerID string
 71
 72	titleProvider       provider.Provider
 73	summarizeProvider   provider.Provider
 74	summarizeProviderID string
 75
 76	activeRequests sync.Map
 77}
 78
 79var agentPromptMap = map[config.AgentID]prompt.PromptID{
 80	config.AgentCoder: prompt.PromptCoder,
 81	config.AgentTask:  prompt.PromptTask,
 82}
 83
 84func NewAgent(
 85	agentCfg config.Agent,
 86	// These services are needed in the tools
 87	permissions permission.Service,
 88	sessions session.Service,
 89	messages message.Service,
 90	history history.Service,
 91	lspClients map[string]*lsp.Client,
 92) (Service, error) {
 93	ctx := context.Background()
 94	cfg := config.Get()
 95	otherTools := GetMcpTools(ctx, permissions)
 96	if len(lspClients) > 0 {
 97		otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients))
 98	}
 99
100	allTools := []tools.BaseTool{
101		tools.NewBashTool(permissions),
102		tools.NewEditTool(lspClients, permissions, history),
103		tools.NewFetchTool(permissions),
104		tools.NewGlobTool(),
105		tools.NewGrepTool(),
106		tools.NewLsTool(),
107		tools.NewSourcegraphTool(),
108		tools.NewViewTool(lspClients),
109		tools.NewWriteTool(lspClients, permissions, history),
110	}
111
112	if agentCfg.ID == config.AgentCoder {
113		taskAgentCfg := config.Get().Agents[config.AgentTask]
114		if taskAgentCfg.ID == "" {
115			return nil, fmt.Errorf("task agent not found in config")
116		}
117		taskAgent, err := NewAgent(taskAgentCfg, permissions, sessions, messages, history, lspClients)
118		if err != nil {
119			return nil, fmt.Errorf("failed to create task agent: %w", err)
120		}
121
122		allTools = append(
123			allTools,
124			NewAgentTool(
125				taskAgent,
126				sessions,
127				messages,
128			),
129		)
130	}
131
132	allTools = append(allTools, otherTools...)
133	providerCfg := config.GetAgentProvider(agentCfg.ID)
134	if providerCfg.ID == "" {
135		return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
136	}
137	model := config.GetAgentModel(agentCfg.ID)
138
139	if model.ID == "" {
140		return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
141	}
142
143	promptID := agentPromptMap[agentCfg.ID]
144	if promptID == "" {
145		promptID = prompt.PromptDefault
146	}
147	opts := []provider.ProviderClientOption{
148		provider.WithModel(agentCfg.Model),
149		provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID)),
150	}
151	agentProvider, err := provider.NewProviderV2(providerCfg, opts...)
152	if err != nil {
153		return nil, err
154	}
155
156	smallModelCfg := cfg.Models.Small
157	var smallModel config.Model
158
159	var smallModelProviderCfg config.ProviderConfig
160	if smallModelCfg.Provider == providerCfg.ID {
161		smallModelProviderCfg = providerCfg
162	} else {
163		for _, p := range cfg.Providers {
164			if p.ID == smallModelCfg.Provider {
165				smallModelProviderCfg = p
166				break
167			}
168		}
169		if smallModelProviderCfg.ID == "" {
170			return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
171		}
172	}
173	for _, m := range smallModelProviderCfg.Models {
174		if m.ID == smallModelCfg.ModelID {
175			smallModel = m
176			break
177		}
178	}
179	if smallModel.ID == "" {
180		return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.ModelID, smallModelProviderCfg.ID)
181	}
182
183	titleOpts := []provider.ProviderClientOption{
184		provider.WithModel(config.SmallModel),
185		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
186	}
187	titleProvider, err := provider.NewProviderV2(smallModelProviderCfg, titleOpts...)
188	if err != nil {
189		return nil, err
190	}
191	summarizeOpts := []provider.ProviderClientOption{
192		provider.WithModel(config.SmallModel),
193		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
194	}
195	summarizeProvider, err := provider.NewProviderV2(smallModelProviderCfg, summarizeOpts...)
196	if err != nil {
197		return nil, err
198	}
199
200	agentTools := []tools.BaseTool{}
201	if agentCfg.AllowedTools == nil {
202		agentTools = allTools
203	} else {
204		for _, tool := range allTools {
205			if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
206				agentTools = append(agentTools, tool)
207			}
208		}
209	}
210
211	agent := &agent{
212		Broker:              pubsub.NewBroker[AgentEvent](),
213		agentCfg:            agentCfg,
214		provider:            agentProvider,
215		providerID:          string(providerCfg.ID),
216		messages:            messages,
217		sessions:            sessions,
218		tools:               agentTools,
219		titleProvider:       titleProvider,
220		summarizeProvider:   summarizeProvider,
221		summarizeProviderID: string(smallModelProviderCfg.ID),
222		activeRequests:      sync.Map{},
223	}
224
225	return agent, nil
226}
227
228func (a *agent) Model() config.Model {
229	return config.GetAgentModel(a.agentCfg.ID)
230}
231
232func (a *agent) Cancel(sessionID string) {
233	// Cancel regular requests
234	if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
235		if cancel, ok := cancelFunc.(context.CancelFunc); ok {
236			logging.InfoPersist(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
237			cancel()
238		}
239	}
240
241	// Also check for summarize requests
242	if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID + "-summarize"); exists {
243		if cancel, ok := cancelFunc.(context.CancelFunc); ok {
244			logging.InfoPersist(fmt.Sprintf("Summarize cancellation initiated for session: %s", sessionID))
245			cancel()
246		}
247	}
248}
249
250func (a *agent) IsBusy() bool {
251	busy := false
252	a.activeRequests.Range(func(key, value any) bool {
253		if cancelFunc, ok := value.(context.CancelFunc); ok {
254			if cancelFunc != nil {
255				busy = true
256				return false // Stop iterating
257			}
258		}
259		return true // Continue iterating
260	})
261	return busy
262}
263
264func (a *agent) IsSessionBusy(sessionID string) bool {
265	_, busy := a.activeRequests.Load(sessionID)
266	return busy
267}
268
269func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
270	if content == "" {
271		return nil
272	}
273	if a.titleProvider == nil {
274		return nil
275	}
276	session, err := a.sessions.Get(ctx, sessionID)
277	if err != nil {
278		return err
279	}
280	parts := []message.ContentPart{message.TextContent{Text: content}}
281
282	// Use streaming approach like summarization
283	response := a.titleProvider.StreamResponse(
284		ctx,
285		[]message.Message{
286			{
287				Role:  message.User,
288				Parts: parts,
289			},
290		},
291		make([]tools.BaseTool, 0),
292	)
293
294	var finalResponse *provider.ProviderResponse
295	for r := range response {
296		if r.Error != nil {
297			return r.Error
298		}
299		finalResponse = r.Response
300	}
301
302	if finalResponse == nil {
303		return fmt.Errorf("no response received from title provider")
304	}
305
306	title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
307	if title == "" {
308		return nil
309	}
310
311	session.Title = title
312	_, err = a.sessions.Save(ctx, session)
313	return err
314}
315
316func (a *agent) err(err error) AgentEvent {
317	return AgentEvent{
318		Type:  AgentEventTypeError,
319		Error: err,
320	}
321}
322
323func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
324	if !a.Model().SupportsImages && attachments != nil {
325		attachments = nil
326	}
327	events := make(chan AgentEvent)
328	if a.IsSessionBusy(sessionID) {
329		return nil, ErrSessionBusy
330	}
331
332	genCtx, cancel := context.WithCancel(ctx)
333
334	a.activeRequests.Store(sessionID, cancel)
335	go func() {
336		logging.Debug("Request started", "sessionID", sessionID)
337		defer logging.RecoverPanic("agent.Run", func() {
338			events <- a.err(fmt.Errorf("panic while running the agent"))
339		})
340		var attachmentParts []message.ContentPart
341		for _, attachment := range attachments {
342			attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
343		}
344		result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
345		if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
346			logging.ErrorPersist(result.Error.Error())
347		}
348		logging.Debug("Request completed", "sessionID", sessionID)
349		a.activeRequests.Delete(sessionID)
350		cancel()
351		a.Publish(pubsub.CreatedEvent, result)
352		events <- result
353		close(events)
354	}()
355	return events, nil
356}
357
358func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
359	// List existing messages; if none, start title generation asynchronously.
360	msgs, err := a.messages.List(ctx, sessionID)
361	if err != nil {
362		return a.err(fmt.Errorf("failed to list messages: %w", err))
363	}
364	if len(msgs) == 0 {
365		go func() {
366			defer logging.RecoverPanic("agent.Run", func() {
367				logging.ErrorPersist("panic while generating title")
368			})
369			titleErr := a.generateTitle(context.Background(), sessionID, content)
370			if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
371				logging.ErrorPersist(fmt.Sprintf("failed to generate title: %v", titleErr))
372			}
373		}()
374	}
375	session, err := a.sessions.Get(ctx, sessionID)
376	if err != nil {
377		return a.err(fmt.Errorf("failed to get session: %w", err))
378	}
379	if session.SummaryMessageID != "" {
380		summaryMsgInex := -1
381		for i, msg := range msgs {
382			if msg.ID == session.SummaryMessageID {
383				summaryMsgInex = i
384				break
385			}
386		}
387		if summaryMsgInex != -1 {
388			msgs = msgs[summaryMsgInex:]
389			msgs[0].Role = message.User
390		}
391	}
392
393	userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
394	if err != nil {
395		return a.err(fmt.Errorf("failed to create user message: %w", err))
396	}
397	// Append the new user message to the conversation history.
398	msgHistory := append(msgs, userMsg)
399
400	for {
401		// Check for cancellation before each iteration
402		select {
403		case <-ctx.Done():
404			return a.err(ctx.Err())
405		default:
406			// Continue processing
407		}
408		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
409		if err != nil {
410			if errors.Is(err, context.Canceled) {
411				agentMessage.AddFinish(message.FinishReasonCanceled)
412				a.messages.Update(context.Background(), agentMessage)
413				return a.err(ErrRequestCancelled)
414			}
415			return a.err(fmt.Errorf("failed to process events: %w", err))
416		}
417		logging.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
418		if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
419			// We are not done, we need to respond with the tool response
420			msgHistory = append(msgHistory, agentMessage, *toolResults)
421			continue
422		}
423		return AgentEvent{
424			Type:    AgentEventTypeResponse,
425			Message: agentMessage,
426			Done:    true,
427		}
428	}
429}
430
431func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
432	parts := []message.ContentPart{message.TextContent{Text: content}}
433	parts = append(parts, attachmentParts...)
434	return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
435		Role:  message.User,
436		Parts: parts,
437	})
438}
439
440func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
441	eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)
442
443	assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
444		Role:     message.Assistant,
445		Parts:    []message.ContentPart{},
446		Model:    a.Model().ID,
447		Provider: a.providerID,
448	})
449	if err != nil {
450		return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
451	}
452
453	// Add the session and message ID into the context if needed by tools.
454	ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
455	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
456
457	// Process each event in the stream.
458	for event := range eventChan {
459		if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
460			a.finishMessage(ctx, &assistantMsg, message.FinishReasonCanceled)
461			return assistantMsg, nil, processErr
462		}
463		if ctx.Err() != nil {
464			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
465			return assistantMsg, nil, ctx.Err()
466		}
467	}
468
469	toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
470	toolCalls := assistantMsg.ToolCalls()
471	for i, toolCall := range toolCalls {
472		select {
473		case <-ctx.Done():
474			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
475			// Make all future tool calls cancelled
476			for j := i; j < len(toolCalls); j++ {
477				toolResults[j] = message.ToolResult{
478					ToolCallID: toolCalls[j].ID,
479					Content:    "Tool execution canceled by user",
480					IsError:    true,
481				}
482			}
483			goto out
484		default:
485			// Continue processing
486			var tool tools.BaseTool
487			for _, availableTools := range a.tools {
488				if availableTools.Info().Name == toolCall.Name {
489					tool = availableTools
490				}
491			}
492
493			// Tool not found
494			if tool == nil {
495				toolResults[i] = message.ToolResult{
496					ToolCallID: toolCall.ID,
497					Content:    fmt.Sprintf("Tool not found: %s", toolCall.Name),
498					IsError:    true,
499				}
500				continue
501			}
502			toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
503				ID:    toolCall.ID,
504				Name:  toolCall.Name,
505				Input: toolCall.Input,
506			})
507			if toolErr != nil {
508				if errors.Is(toolErr, permission.ErrorPermissionDenied) {
509					toolResults[i] = message.ToolResult{
510						ToolCallID: toolCall.ID,
511						Content:    "Permission denied",
512						IsError:    true,
513					}
514					for j := i + 1; j < len(toolCalls); j++ {
515						toolResults[j] = message.ToolResult{
516							ToolCallID: toolCalls[j].ID,
517							Content:    "Tool execution canceled by user",
518							IsError:    true,
519						}
520					}
521					a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied)
522					break
523				}
524			}
525			toolResults[i] = message.ToolResult{
526				ToolCallID: toolCall.ID,
527				Content:    toolResult.Content,
528				Metadata:   toolResult.Metadata,
529				IsError:    toolResult.IsError,
530			}
531		}
532	}
533out:
534	if len(toolResults) == 0 {
535		return assistantMsg, nil, nil
536	}
537	parts := make([]message.ContentPart, 0)
538	for _, tr := range toolResults {
539		parts = append(parts, tr)
540	}
541	msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
542		Role:     message.Tool,
543		Parts:    parts,
544		Provider: a.providerID,
545	})
546	if err != nil {
547		return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
548	}
549
550	return assistantMsg, &msg, err
551}
552
553func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReson message.FinishReason) {
554	msg.AddFinish(finishReson)
555	_ = a.messages.Update(ctx, *msg)
556}
557
558func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
559	select {
560	case <-ctx.Done():
561		return ctx.Err()
562	default:
563		// Continue processing.
564	}
565
566	switch event.Type {
567	case provider.EventThinkingDelta:
568		assistantMsg.AppendReasoningContent(event.Content)
569		return a.messages.Update(ctx, *assistantMsg)
570	case provider.EventContentDelta:
571		assistantMsg.AppendContent(event.Content)
572		return a.messages.Update(ctx, *assistantMsg)
573	case provider.EventToolUseStart:
574		logging.Info("Tool call started", "toolCall", event.ToolCall)
575		assistantMsg.AddToolCall(*event.ToolCall)
576		return a.messages.Update(ctx, *assistantMsg)
577	case provider.EventToolUseDelta:
578		assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
579		return a.messages.Update(ctx, *assistantMsg)
580	case provider.EventToolUseStop:
581		logging.Info("Finished tool call", "toolCall", event.ToolCall)
582		assistantMsg.FinishToolCall(event.ToolCall.ID)
583		return a.messages.Update(ctx, *assistantMsg)
584	case provider.EventError:
585		if errors.Is(event.Error, context.Canceled) {
586			logging.InfoPersist(fmt.Sprintf("Event processing canceled for session: %s", sessionID))
587			return context.Canceled
588		}
589		logging.ErrorPersist(event.Error.Error())
590		return event.Error
591	case provider.EventComplete:
592		assistantMsg.SetToolCalls(event.Response.ToolCalls)
593		assistantMsg.AddFinish(event.Response.FinishReason)
594		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
595			return fmt.Errorf("failed to update message: %w", err)
596		}
597		return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
598	}
599
600	return nil
601}
602
603func (a *agent) TrackUsage(ctx context.Context, sessionID string, model config.Model, usage provider.TokenUsage) error {
604	sess, err := a.sessions.Get(ctx, sessionID)
605	if err != nil {
606		return fmt.Errorf("failed to get session: %w", err)
607	}
608
609	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
610		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
611		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
612		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
613
614	sess.Cost += cost
615	sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
616	sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
617
618	_, err = a.sessions.Save(ctx, sess)
619	if err != nil {
620		return fmt.Errorf("failed to save session: %w", err)
621	}
622	return nil
623}
624
625func (a *agent) Summarize(ctx context.Context, sessionID string) error {
626	if a.summarizeProvider == nil {
627		return fmt.Errorf("summarize provider not available")
628	}
629
630	// Check if session is busy
631	if a.IsSessionBusy(sessionID) {
632		return ErrSessionBusy
633	}
634
635	// Create a new context with cancellation
636	summarizeCtx, cancel := context.WithCancel(ctx)
637
638	// Store the cancel function in activeRequests to allow cancellation
639	a.activeRequests.Store(sessionID+"-summarize", cancel)
640
641	go func() {
642		defer a.activeRequests.Delete(sessionID + "-summarize")
643		defer cancel()
644		event := AgentEvent{
645			Type:     AgentEventTypeSummarize,
646			Progress: "Starting summarization...",
647		}
648
649		a.Publish(pubsub.CreatedEvent, event)
650		// Get all messages from the session
651		msgs, err := a.messages.List(summarizeCtx, sessionID)
652		if err != nil {
653			event = AgentEvent{
654				Type:  AgentEventTypeError,
655				Error: fmt.Errorf("failed to list messages: %w", err),
656				Done:  true,
657			}
658			a.Publish(pubsub.CreatedEvent, event)
659			return
660		}
661
662		if len(msgs) == 0 {
663			event = AgentEvent{
664				Type:  AgentEventTypeError,
665				Error: fmt.Errorf("no messages to summarize"),
666				Done:  true,
667			}
668			a.Publish(pubsub.CreatedEvent, event)
669			return
670		}
671
672		event = AgentEvent{
673			Type:     AgentEventTypeSummarize,
674			Progress: "Analyzing conversation...",
675		}
676		a.Publish(pubsub.CreatedEvent, event)
677
678		// Add a system message to guide the summarization
679		summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
680
681		// Create a new message with the summarize prompt
682		promptMsg := message.Message{
683			Role:  message.User,
684			Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
685		}
686
687		// Append the prompt to the messages
688		msgsWithPrompt := append(msgs, promptMsg)
689
690		event = AgentEvent{
691			Type:     AgentEventTypeSummarize,
692			Progress: "Generating summary...",
693		}
694
695		a.Publish(pubsub.CreatedEvent, event)
696
697		// Send the messages to the summarize provider
698		response := a.summarizeProvider.StreamResponse(
699			summarizeCtx,
700			msgsWithPrompt,
701			make([]tools.BaseTool, 0),
702		)
703		var finalResponse *provider.ProviderResponse
704		for r := range response {
705			if r.Error != nil {
706				event = AgentEvent{
707					Type:  AgentEventTypeError,
708					Error: fmt.Errorf("failed to summarize: %w", err),
709					Done:  true,
710				}
711				a.Publish(pubsub.CreatedEvent, event)
712				return
713			}
714			finalResponse = r.Response
715		}
716
717		summary := strings.TrimSpace(finalResponse.Content)
718		if summary == "" {
719			event = AgentEvent{
720				Type:  AgentEventTypeError,
721				Error: fmt.Errorf("empty summary returned"),
722				Done:  true,
723			}
724			a.Publish(pubsub.CreatedEvent, event)
725			return
726		}
727		event = AgentEvent{
728			Type:     AgentEventTypeSummarize,
729			Progress: "Creating new session...",
730		}
731
732		a.Publish(pubsub.CreatedEvent, event)
733		oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
734		if err != nil {
735			event = AgentEvent{
736				Type:  AgentEventTypeError,
737				Error: fmt.Errorf("failed to get session: %w", err),
738				Done:  true,
739			}
740
741			a.Publish(pubsub.CreatedEvent, event)
742			return
743		}
744		// Create a message in the new session with the summary
745		msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
746			Role: message.Assistant,
747			Parts: []message.ContentPart{
748				message.TextContent{Text: summary},
749				message.Finish{
750					Reason: message.FinishReasonEndTurn,
751					Time:   time.Now().Unix(),
752				},
753			},
754			Model:    a.summarizeProvider.Model().ID,
755			Provider: a.summarizeProviderID,
756		})
757		if err != nil {
758			event = AgentEvent{
759				Type:  AgentEventTypeError,
760				Error: fmt.Errorf("failed to create summary message: %w", err),
761				Done:  true,
762			}
763
764			a.Publish(pubsub.CreatedEvent, event)
765			return
766		}
767		oldSession.SummaryMessageID = msg.ID
768		oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
769		oldSession.PromptTokens = 0
770		model := a.summarizeProvider.Model()
771		usage := finalResponse.Usage
772		cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
773			model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
774			model.CostPer1MIn/1e6*float64(usage.InputTokens) +
775			model.CostPer1MOut/1e6*float64(usage.OutputTokens)
776		oldSession.Cost += cost
777		_, err = a.sessions.Save(summarizeCtx, oldSession)
778		if err != nil {
779			event = AgentEvent{
780				Type:  AgentEventTypeError,
781				Error: fmt.Errorf("failed to save session: %w", err),
782				Done:  true,
783			}
784			a.Publish(pubsub.CreatedEvent, event)
785		}
786
787		event = AgentEvent{
788			Type:      AgentEventTypeSummarize,
789			SessionID: oldSession.ID,
790			Progress:  "Summary complete",
791			Done:      true,
792		}
793		a.Publish(pubsub.CreatedEvent, event)
794		// Send final success event with the new session ID
795	}()
796
797	return nil
798}
799
800func (a *agent) CancelAll() {
801	a.activeRequests.Range(func(key, value any) bool {
802		a.Cancel(key.(string)) // key is sessionID
803		return true
804	})
805}
806
807func (a *agent) UpdateModel() error {
808	cfg := config.Get()
809
810	// Get current provider configuration
811	currentProviderCfg := config.GetAgentProvider(a.agentCfg.ID)
812	if currentProviderCfg.ID == "" {
813		return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
814	}
815
816	// Check if provider has changed
817	if string(currentProviderCfg.ID) != a.providerID {
818		// Provider changed, need to recreate the main provider
819		model := config.GetAgentModel(a.agentCfg.ID)
820		if model.ID == "" {
821			return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
822		}
823
824		promptID := agentPromptMap[a.agentCfg.ID]
825		if promptID == "" {
826			promptID = prompt.PromptDefault
827		}
828
829		opts := []provider.ProviderClientOption{
830			provider.WithModel(a.agentCfg.Model),
831			provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID)),
832		}
833
834		newProvider, err := provider.NewProviderV2(currentProviderCfg, opts...)
835		if err != nil {
836			return fmt.Errorf("failed to create new provider: %w", err)
837		}
838
839		// Update the provider and provider ID
840		a.provider = newProvider
841		a.providerID = string(currentProviderCfg.ID)
842	}
843
844	// Check if small model provider has changed (affects title and summarize providers)
845	smallModelCfg := cfg.Models.Small
846	var smallModelProviderCfg config.ProviderConfig
847
848	for _, p := range cfg.Providers {
849		if p.ID == smallModelCfg.Provider {
850			smallModelProviderCfg = p
851			break
852		}
853	}
854
855	if smallModelProviderCfg.ID == "" {
856		return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
857	}
858
859	// Check if summarize provider has changed
860	if string(smallModelProviderCfg.ID) != a.summarizeProviderID {
861		var smallModel config.Model
862		for _, m := range smallModelProviderCfg.Models {
863			if m.ID == smallModelCfg.ModelID {
864				smallModel = m
865				break
866			}
867		}
868		if smallModel.ID == "" {
869			return fmt.Errorf("model %s not found in provider %s", smallModelCfg.ModelID, smallModelProviderCfg.ID)
870		}
871
872		// Recreate title provider
873		titleOpts := []provider.ProviderClientOption{
874			provider.WithModel(config.SmallModel),
875			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
876		}
877		newTitleProvider, err := provider.NewProviderV2(smallModelProviderCfg, titleOpts...)
878		if err != nil {
879			return fmt.Errorf("failed to create new title provider: %w", err)
880		}
881
882		// Recreate summarize provider
883		summarizeOpts := []provider.ProviderClientOption{
884			provider.WithModel(config.SmallModel),
885			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
886		}
887		newSummarizeProvider, err := provider.NewProviderV2(smallModelProviderCfg, summarizeOpts...)
888		if err != nil {
889			return fmt.Errorf("failed to create new summarize provider: %w", err)
890		}
891
892		// Update the providers and provider ID
893		a.titleProvider = newTitleProvider
894		a.summarizeProvider = newSummarizeProvider
895		a.summarizeProviderID = string(smallModelProviderCfg.ID)
896	}
897
898	return nil
899}