agent.go

  1package agent
  2
  3import (
  4	"context"
  5	"errors"
  6	"fmt"
  7	"log/slog"
  8	"slices"
  9	"strings"
 10	"sync"
 11	"time"
 12
 13	"github.com/charmbracelet/crush/internal/config"
 14	fur "github.com/charmbracelet/crush/internal/fur/provider"
 15	"github.com/charmbracelet/crush/internal/history"
 16	"github.com/charmbracelet/crush/internal/llm/prompt"
 17	"github.com/charmbracelet/crush/internal/llm/provider"
 18	"github.com/charmbracelet/crush/internal/llm/tools"
 19	"github.com/charmbracelet/crush/internal/log"
 20	"github.com/charmbracelet/crush/internal/lsp"
 21	"github.com/charmbracelet/crush/internal/message"
 22	"github.com/charmbracelet/crush/internal/permission"
 23	"github.com/charmbracelet/crush/internal/pubsub"
 24	"github.com/charmbracelet/crush/internal/session"
 25)
 26
 27// Common errors
 28var (
 29	ErrRequestCancelled = errors.New("request cancelled by user")
 30	ErrSessionBusy      = errors.New("session is currently processing another request")
 31)
 32
 33type AgentEventType string
 34
 35const (
 36	AgentEventTypeError     AgentEventType = "error"
 37	AgentEventTypeResponse  AgentEventType = "response"
 38	AgentEventTypeSummarize AgentEventType = "summarize"
 39)
 40
 41type AgentEvent struct {
 42	Type    AgentEventType
 43	Message message.Message
 44	Error   error
 45
 46	// When summarizing
 47	SessionID string
 48	Progress  string
 49	Done      bool
 50}
 51
 52type Service interface {
 53	pubsub.Suscriber[AgentEvent]
 54	Model() fur.Model
 55	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
 56	Cancel(sessionID string)
 57	CancelAll()
 58	IsSessionBusy(sessionID string) bool
 59	IsBusy() bool
 60	Summarize(ctx context.Context, sessionID string) error
 61	UpdateModel() error
 62}
 63
 64type agent struct {
 65	*pubsub.Broker[AgentEvent]
 66	agentCfg config.Agent
 67	sessions session.Service
 68	messages message.Service
 69
 70	tools      []tools.BaseTool
 71	provider   provider.Provider
 72	providerID string
 73
 74	titleProvider       provider.Provider
 75	summarizeProvider   provider.Provider
 76	summarizeProviderID string
 77
 78	activeRequests sync.Map
 79}
 80
 81var agentPromptMap = map[string]prompt.PromptID{
 82	"coder": prompt.PromptCoder,
 83	"task":  prompt.PromptTask,
 84}
 85
 86func NewAgent(
 87	agentCfg config.Agent,
 88	// These services are needed in the tools
 89	permissions permission.Service,
 90	sessions session.Service,
 91	messages message.Service,
 92	history history.Service,
 93	lspClients map[string]*lsp.Client,
 94) (Service, error) {
 95	ctx := context.Background()
 96	cfg := config.Get()
 97	otherTools := GetMcpTools(ctx, permissions)
 98	if len(lspClients) > 0 {
 99		otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients))
100	}
101
102	allTools := []tools.BaseTool{
103		tools.NewBashTool(permissions),
104		tools.NewEditTool(lspClients, permissions, history),
105		tools.NewFetchTool(permissions),
106		tools.NewGlobTool(),
107		tools.NewGrepTool(),
108		tools.NewLsTool(),
109		tools.NewSourcegraphTool(),
110		tools.NewViewTool(lspClients),
111		tools.NewWriteTool(lspClients, permissions, history),
112	}
113
114	if agentCfg.ID == "coder" {
115		taskAgentCfg := config.Get().Agents["task"]
116		if taskAgentCfg.ID == "" {
117			return nil, fmt.Errorf("task agent not found in config")
118		}
119		taskAgent, err := NewAgent(taskAgentCfg, permissions, sessions, messages, history, lspClients)
120		if err != nil {
121			return nil, fmt.Errorf("failed to create task agent: %w", err)
122		}
123
124		allTools = append(
125			allTools,
126			NewAgentTool(
127				taskAgent,
128				sessions,
129				messages,
130			),
131		)
132	}
133
134	allTools = append(allTools, otherTools...)
135	providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
136	if providerCfg == nil {
137		return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
138	}
139	model := config.Get().GetModelByType(agentCfg.Model)
140
141	if model == nil {
142		return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
143	}
144
145	promptID := agentPromptMap[agentCfg.ID]
146	if promptID == "" {
147		promptID = prompt.PromptDefault
148	}
149	opts := []provider.ProviderClientOption{
150		provider.WithModel(agentCfg.Model),
151		provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID)),
152	}
153	agentProvider, err := provider.NewProvider(*providerCfg, opts...)
154	if err != nil {
155		return nil, err
156	}
157
158	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
159	var smallModelProviderCfg *config.ProviderConfig
160	if smallModelCfg.Provider == providerCfg.ID {
161		smallModelProviderCfg = providerCfg
162	} else {
163		smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
164
165		if smallModelProviderCfg.ID == "" {
166			return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
167		}
168	}
169	smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
170	if smallModel.ID == "" {
171		return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
172	}
173
174	titleOpts := []provider.ProviderClientOption{
175		provider.WithModel(config.SelectedModelTypeSmall),
176		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
177	}
178	titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
179	if err != nil {
180		return nil, err
181	}
182	summarizeOpts := []provider.ProviderClientOption{
183		provider.WithModel(config.SelectedModelTypeSmall),
184		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
185	}
186	summarizeProvider, err := provider.NewProvider(*smallModelProviderCfg, summarizeOpts...)
187	if err != nil {
188		return nil, err
189	}
190
191	agentTools := []tools.BaseTool{}
192	if agentCfg.AllowedTools == nil {
193		agentTools = allTools
194	} else {
195		for _, tool := range allTools {
196			if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
197				agentTools = append(agentTools, tool)
198			}
199		}
200	}
201
202	agent := &agent{
203		Broker:              pubsub.NewBroker[AgentEvent](),
204		agentCfg:            agentCfg,
205		provider:            agentProvider,
206		providerID:          string(providerCfg.ID),
207		messages:            messages,
208		sessions:            sessions,
209		tools:               agentTools,
210		titleProvider:       titleProvider,
211		summarizeProvider:   summarizeProvider,
212		summarizeProviderID: string(smallModelProviderCfg.ID),
213		activeRequests:      sync.Map{},
214	}
215
216	return agent, nil
217}
218
219func (a *agent) Model() fur.Model {
220	return *config.Get().GetModelByType(a.agentCfg.Model)
221}
222
223func (a *agent) Cancel(sessionID string) {
224	// Cancel regular requests
225	if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
226		if cancel, ok := cancelFunc.(context.CancelFunc); ok {
227			slog.Info(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
228			cancel()
229		}
230	}
231
232	// Also check for summarize requests
233	if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID + "-summarize"); exists {
234		if cancel, ok := cancelFunc.(context.CancelFunc); ok {
235			slog.Info(fmt.Sprintf("Summarize cancellation initiated for session: %s", sessionID))
236			cancel()
237		}
238	}
239}
240
241func (a *agent) IsBusy() bool {
242	busy := false
243	a.activeRequests.Range(func(key, value any) bool {
244		if cancelFunc, ok := value.(context.CancelFunc); ok {
245			if cancelFunc != nil {
246				busy = true
247				return false
248			}
249		}
250		return true
251	})
252	return busy
253}
254
255func (a *agent) IsSessionBusy(sessionID string) bool {
256	_, busy := a.activeRequests.Load(sessionID)
257	return busy
258}
259
260func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
261	if content == "" {
262		return nil
263	}
264	if a.titleProvider == nil {
265		return nil
266	}
267	session, err := a.sessions.Get(ctx, sessionID)
268	if err != nil {
269		return err
270	}
271	parts := []message.ContentPart{message.TextContent{
272		Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
273	}}
274
275	// Use streaming approach like summarization
276	response := a.titleProvider.StreamResponse(
277		ctx,
278		[]message.Message{
279			{
280				Role:  message.User,
281				Parts: parts,
282			},
283		},
284		make([]tools.BaseTool, 0),
285	)
286
287	var finalResponse *provider.ProviderResponse
288	for r := range response {
289		if r.Error != nil {
290			return r.Error
291		}
292		finalResponse = r.Response
293	}
294
295	if finalResponse == nil {
296		return fmt.Errorf("no response received from title provider")
297	}
298
299	title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
300	if title == "" {
301		return nil
302	}
303
304	session.Title = title
305	_, err = a.sessions.Save(ctx, session)
306	return err
307}
308
309func (a *agent) err(err error) AgentEvent {
310	return AgentEvent{
311		Type:  AgentEventTypeError,
312		Error: err,
313	}
314}
315
316func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
317	if !a.Model().SupportsImages && attachments != nil {
318		attachments = nil
319	}
320	events := make(chan AgentEvent)
321	if a.IsSessionBusy(sessionID) {
322		return nil, ErrSessionBusy
323	}
324
325	genCtx, cancel := context.WithCancel(ctx)
326
327	a.activeRequests.Store(sessionID, cancel)
328	go func() {
329		slog.Debug("Request started", "sessionID", sessionID)
330		defer log.RecoverPanic("agent.Run", func() {
331			events <- a.err(fmt.Errorf("panic while running the agent"))
332		})
333		var attachmentParts []message.ContentPart
334		for _, attachment := range attachments {
335			attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
336		}
337		result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
338		if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
339			slog.Error(result.Error.Error())
340		}
341		slog.Debug("Request completed", "sessionID", sessionID)
342		a.activeRequests.Delete(sessionID)
343		cancel()
344		a.Publish(pubsub.CreatedEvent, result)
345		events <- result
346		close(events)
347	}()
348	return events, nil
349}
350
351func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
352	cfg := config.Get()
353	// List existing messages; if none, start title generation asynchronously.
354	msgs, err := a.messages.List(ctx, sessionID)
355	if err != nil {
356		return a.err(fmt.Errorf("failed to list messages: %w", err))
357	}
358	if len(msgs) == 0 {
359		go func() {
360			defer log.RecoverPanic("agent.Run", func() {
361				slog.Error("panic while generating title")
362			})
363			titleErr := a.generateTitle(context.Background(), sessionID, content)
364			if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
365				slog.Error(fmt.Sprintf("failed to generate title: %v", titleErr))
366			}
367		}()
368	}
369	session, err := a.sessions.Get(ctx, sessionID)
370	if err != nil {
371		return a.err(fmt.Errorf("failed to get session: %w", err))
372	}
373	if session.SummaryMessageID != "" {
374		summaryMsgInex := -1
375		for i, msg := range msgs {
376			if msg.ID == session.SummaryMessageID {
377				summaryMsgInex = i
378				break
379			}
380		}
381		if summaryMsgInex != -1 {
382			msgs = msgs[summaryMsgInex:]
383			msgs[0].Role = message.User
384		}
385	}
386
387	userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
388	if err != nil {
389		return a.err(fmt.Errorf("failed to create user message: %w", err))
390	}
391	// Append the new user message to the conversation history.
392	msgHistory := append(msgs, userMsg)
393
394	for {
395		// Check for cancellation before each iteration
396		select {
397		case <-ctx.Done():
398			return a.err(ctx.Err())
399		default:
400			// Continue processing
401		}
402		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
403		if err != nil {
404			if errors.Is(err, context.Canceled) {
405				agentMessage.AddFinish(message.FinishReasonCanceled)
406				a.messages.Update(context.Background(), agentMessage)
407				return a.err(ErrRequestCancelled)
408			}
409			return a.err(fmt.Errorf("failed to process events: %w", err))
410		}
411		if cfg.Options.Debug {
412			slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
413		}
414		if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
415			// We are not done, we need to respond with the tool response
416			msgHistory = append(msgHistory, agentMessage, *toolResults)
417			continue
418		}
419		return AgentEvent{
420			Type:    AgentEventTypeResponse,
421			Message: agentMessage,
422			Done:    true,
423		}
424	}
425}
426
427func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
428	parts := []message.ContentPart{message.TextContent{Text: content}}
429	parts = append(parts, attachmentParts...)
430	return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
431		Role:  message.User,
432		Parts: parts,
433	})
434}
435
436func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
437	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
438	eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)
439
440	assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
441		Role:     message.Assistant,
442		Parts:    []message.ContentPart{},
443		Model:    a.Model().ID,
444		Provider: a.providerID,
445	})
446	if err != nil {
447		return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
448	}
449
450	// Add the session and message ID into the context if needed by tools.
451	ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
452
453	// Process each event in the stream.
454	for event := range eventChan {
455		if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
456			a.finishMessage(ctx, &assistantMsg, message.FinishReasonCanceled)
457			return assistantMsg, nil, processErr
458		}
459		if ctx.Err() != nil {
460			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
461			return assistantMsg, nil, ctx.Err()
462		}
463	}
464
465	toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
466	toolCalls := assistantMsg.ToolCalls()
467	for i, toolCall := range toolCalls {
468		select {
469		case <-ctx.Done():
470			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
471			// Make all future tool calls cancelled
472			for j := i; j < len(toolCalls); j++ {
473				toolResults[j] = message.ToolResult{
474					ToolCallID: toolCalls[j].ID,
475					Content:    "Tool execution canceled by user",
476					IsError:    true,
477				}
478			}
479			goto out
480		default:
481			// Continue processing
482			var tool tools.BaseTool
483			for _, availableTool := range a.tools {
484				if availableTool.Info().Name == toolCall.Name {
485					tool = availableTool
486					break
487				}
488			}
489
490			// Tool not found
491			if tool == nil {
492				toolResults[i] = message.ToolResult{
493					ToolCallID: toolCall.ID,
494					Content:    fmt.Sprintf("Tool not found: %s", toolCall.Name),
495					IsError:    true,
496				}
497				continue
498			}
499			toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
500				ID:    toolCall.ID,
501				Name:  toolCall.Name,
502				Input: toolCall.Input,
503			})
504			if toolErr != nil {
505				if errors.Is(toolErr, permission.ErrorPermissionDenied) {
506					toolResults[i] = message.ToolResult{
507						ToolCallID: toolCall.ID,
508						Content:    "Permission denied",
509						IsError:    true,
510					}
511					for j := i + 1; j < len(toolCalls); j++ {
512						toolResults[j] = message.ToolResult{
513							ToolCallID: toolCalls[j].ID,
514							Content:    "Tool execution canceled by user",
515							IsError:    true,
516						}
517					}
518					a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied)
519					break
520				}
521			}
522			toolResults[i] = message.ToolResult{
523				ToolCallID: toolCall.ID,
524				Content:    toolResult.Content,
525				Metadata:   toolResult.Metadata,
526				IsError:    toolResult.IsError,
527			}
528		}
529	}
530out:
531	if len(toolResults) == 0 {
532		return assistantMsg, nil, nil
533	}
534	parts := make([]message.ContentPart, 0)
535	for _, tr := range toolResults {
536		parts = append(parts, tr)
537	}
538	msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
539		Role:     message.Tool,
540		Parts:    parts,
541		Provider: a.providerID,
542	})
543	if err != nil {
544		return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
545	}
546
547	return assistantMsg, &msg, err
548}
549
550func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReson message.FinishReason) {
551	msg.AddFinish(finishReson)
552	_ = a.messages.Update(ctx, *msg)
553}
554
555func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
556	select {
557	case <-ctx.Done():
558		return ctx.Err()
559	default:
560		// Continue processing.
561	}
562
563	switch event.Type {
564	case provider.EventThinkingDelta:
565		assistantMsg.AppendReasoningContent(event.Content)
566		return a.messages.Update(ctx, *assistantMsg)
567	case provider.EventContentDelta:
568		assistantMsg.AppendContent(event.Content)
569		return a.messages.Update(ctx, *assistantMsg)
570	case provider.EventToolUseStart:
571		slog.Info("Tool call started", "toolCall", event.ToolCall)
572		assistantMsg.AddToolCall(*event.ToolCall)
573		return a.messages.Update(ctx, *assistantMsg)
574	case provider.EventToolUseDelta:
575		assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
576		return a.messages.Update(ctx, *assistantMsg)
577	case provider.EventToolUseStop:
578		slog.Info("Finished tool call", "toolCall", event.ToolCall)
579		assistantMsg.FinishToolCall(event.ToolCall.ID)
580		return a.messages.Update(ctx, *assistantMsg)
581	case provider.EventError:
582		if errors.Is(event.Error, context.Canceled) {
583			slog.Info(fmt.Sprintf("Event processing canceled for session: %s", sessionID))
584			return context.Canceled
585		}
586		slog.Error(event.Error.Error())
587		return event.Error
588	case provider.EventComplete:
589		assistantMsg.SetToolCalls(event.Response.ToolCalls)
590		assistantMsg.AddFinish(event.Response.FinishReason)
591		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
592			return fmt.Errorf("failed to update message: %w", err)
593		}
594		return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
595	}
596
597	return nil
598}
599
600func (a *agent) TrackUsage(ctx context.Context, sessionID string, model fur.Model, usage provider.TokenUsage) error {
601	sess, err := a.sessions.Get(ctx, sessionID)
602	if err != nil {
603		return fmt.Errorf("failed to get session: %w", err)
604	}
605
606	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
607		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
608		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
609		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
610
611	sess.Cost += cost
612	sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
613	sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
614
615	_, err = a.sessions.Save(ctx, sess)
616	if err != nil {
617		return fmt.Errorf("failed to save session: %w", err)
618	}
619	return nil
620}
621
622func (a *agent) Summarize(ctx context.Context, sessionID string) error {
623	if a.summarizeProvider == nil {
624		return fmt.Errorf("summarize provider not available")
625	}
626
627	// Check if session is busy
628	if a.IsSessionBusy(sessionID) {
629		return ErrSessionBusy
630	}
631
632	// Create a new context with cancellation
633	summarizeCtx, cancel := context.WithCancel(ctx)
634
635	// Store the cancel function in activeRequests to allow cancellation
636	a.activeRequests.Store(sessionID+"-summarize", cancel)
637
638	go func() {
639		defer a.activeRequests.Delete(sessionID + "-summarize")
640		defer cancel()
641		event := AgentEvent{
642			Type:     AgentEventTypeSummarize,
643			Progress: "Starting summarization...",
644		}
645
646		a.Publish(pubsub.CreatedEvent, event)
647		// Get all messages from the session
648		msgs, err := a.messages.List(summarizeCtx, sessionID)
649		if err != nil {
650			event = AgentEvent{
651				Type:  AgentEventTypeError,
652				Error: fmt.Errorf("failed to list messages: %w", err),
653				Done:  true,
654			}
655			a.Publish(pubsub.CreatedEvent, event)
656			return
657		}
658		summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
659
660		if len(msgs) == 0 {
661			event = AgentEvent{
662				Type:  AgentEventTypeError,
663				Error: fmt.Errorf("no messages to summarize"),
664				Done:  true,
665			}
666			a.Publish(pubsub.CreatedEvent, event)
667			return
668		}
669
670		event = AgentEvent{
671			Type:     AgentEventTypeSummarize,
672			Progress: "Analyzing conversation...",
673		}
674		a.Publish(pubsub.CreatedEvent, event)
675
676		// Add a system message to guide the summarization
677		summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
678
679		// Create a new message with the summarize prompt
680		promptMsg := message.Message{
681			Role:  message.User,
682			Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
683		}
684
685		// Append the prompt to the messages
686		msgsWithPrompt := append(msgs, promptMsg)
687
688		event = AgentEvent{
689			Type:     AgentEventTypeSummarize,
690			Progress: "Generating summary...",
691		}
692
693		a.Publish(pubsub.CreatedEvent, event)
694
695		// Send the messages to the summarize provider
696		response := a.summarizeProvider.StreamResponse(
697			summarizeCtx,
698			msgsWithPrompt,
699			make([]tools.BaseTool, 0),
700		)
701		var finalResponse *provider.ProviderResponse
702		for r := range response {
703			if r.Error != nil {
704				event = AgentEvent{
705					Type:  AgentEventTypeError,
706					Error: fmt.Errorf("failed to summarize: %w", err),
707					Done:  true,
708				}
709				a.Publish(pubsub.CreatedEvent, event)
710				return
711			}
712			finalResponse = r.Response
713		}
714
715		summary := strings.TrimSpace(finalResponse.Content)
716		if summary == "" {
717			event = AgentEvent{
718				Type:  AgentEventTypeError,
719				Error: fmt.Errorf("empty summary returned"),
720				Done:  true,
721			}
722			a.Publish(pubsub.CreatedEvent, event)
723			return
724		}
725		event = AgentEvent{
726			Type:     AgentEventTypeSummarize,
727			Progress: "Creating new session...",
728		}
729
730		a.Publish(pubsub.CreatedEvent, event)
731		oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
732		if err != nil {
733			event = AgentEvent{
734				Type:  AgentEventTypeError,
735				Error: fmt.Errorf("failed to get session: %w", err),
736				Done:  true,
737			}
738
739			a.Publish(pubsub.CreatedEvent, event)
740			return
741		}
742		// Create a message in the new session with the summary
743		msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
744			Role: message.Assistant,
745			Parts: []message.ContentPart{
746				message.TextContent{Text: summary},
747				message.Finish{
748					Reason: message.FinishReasonEndTurn,
749					Time:   time.Now().Unix(),
750				},
751			},
752			Model:    a.summarizeProvider.Model().ID,
753			Provider: a.summarizeProviderID,
754		})
755		if err != nil {
756			event = AgentEvent{
757				Type:  AgentEventTypeError,
758				Error: fmt.Errorf("failed to create summary message: %w", err),
759				Done:  true,
760			}
761
762			a.Publish(pubsub.CreatedEvent, event)
763			return
764		}
765		oldSession.SummaryMessageID = msg.ID
766		oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
767		oldSession.PromptTokens = 0
768		model := a.summarizeProvider.Model()
769		usage := finalResponse.Usage
770		cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
771			model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
772			model.CostPer1MIn/1e6*float64(usage.InputTokens) +
773			model.CostPer1MOut/1e6*float64(usage.OutputTokens)
774		oldSession.Cost += cost
775		_, err = a.sessions.Save(summarizeCtx, oldSession)
776		if err != nil {
777			event = AgentEvent{
778				Type:  AgentEventTypeError,
779				Error: fmt.Errorf("failed to save session: %w", err),
780				Done:  true,
781			}
782			a.Publish(pubsub.CreatedEvent, event)
783		}
784
785		event = AgentEvent{
786			Type:      AgentEventTypeSummarize,
787			SessionID: oldSession.ID,
788			Progress:  "Summary complete",
789			Done:      true,
790		}
791		a.Publish(pubsub.CreatedEvent, event)
792		// Send final success event with the new session ID
793	}()
794
795	return nil
796}
797
798func (a *agent) CancelAll() {
799	a.activeRequests.Range(func(key, value any) bool {
800		a.Cancel(key.(string)) // key is sessionID
801		return true
802	})
803}
804
805func (a *agent) UpdateModel() error {
806	cfg := config.Get()
807
808	// Get current provider configuration
809	currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
810	if currentProviderCfg.ID == "" {
811		return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
812	}
813
814	// Check if provider has changed
815	if string(currentProviderCfg.ID) != a.providerID {
816		// Provider changed, need to recreate the main provider
817		model := cfg.GetModelByType(a.agentCfg.Model)
818		if model.ID == "" {
819			return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
820		}
821
822		promptID := agentPromptMap[a.agentCfg.ID]
823		if promptID == "" {
824			promptID = prompt.PromptDefault
825		}
826
827		opts := []provider.ProviderClientOption{
828			provider.WithModel(a.agentCfg.Model),
829			provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID)),
830		}
831
832		newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
833		if err != nil {
834			return fmt.Errorf("failed to create new provider: %w", err)
835		}
836
837		// Update the provider and provider ID
838		a.provider = newProvider
839		a.providerID = string(currentProviderCfg.ID)
840	}
841
842	// Check if small model provider has changed (affects title and summarize providers)
843	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
844	var smallModelProviderCfg config.ProviderConfig
845
846	for _, p := range cfg.Providers {
847		if p.ID == smallModelCfg.Provider {
848			smallModelProviderCfg = p
849			break
850		}
851	}
852
853	if smallModelProviderCfg.ID == "" {
854		return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
855	}
856
857	// Check if summarize provider has changed
858	if string(smallModelProviderCfg.ID) != a.summarizeProviderID {
859		smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
860		if smallModel == nil {
861			return fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
862		}
863
864		// Recreate title provider
865		titleOpts := []provider.ProviderClientOption{
866			provider.WithModel(config.SelectedModelTypeSmall),
867			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
868			// We want the title to be short, so we limit the max tokens
869			provider.WithMaxTokens(40),
870		}
871		newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
872		if err != nil {
873			return fmt.Errorf("failed to create new title provider: %w", err)
874		}
875
876		// Recreate summarize provider
877		summarizeOpts := []provider.ProviderClientOption{
878			provider.WithModel(config.SelectedModelTypeSmall),
879			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
880		}
881		newSummarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
882		if err != nil {
883			return fmt.Errorf("failed to create new summarize provider: %w", err)
884		}
885
886		// Update the providers and provider ID
887		a.titleProvider = newTitleProvider
888		a.summarizeProvider = newSummarizeProvider
889		a.summarizeProviderID = string(smallModelProviderCfg.ID)
890	}
891
892	return nil
893}