agent.go

  1package agent
  2
  3import (
  4	"context"
  5	"errors"
  6	"fmt"
  7	"log/slog"
  8	"slices"
  9	"strings"
 10	"sync"
 11	"time"
 12
 13	"github.com/charmbracelet/crush/internal/config"
 14	fur "github.com/charmbracelet/crush/internal/fur/provider"
 15	"github.com/charmbracelet/crush/internal/history"
 16	"github.com/charmbracelet/crush/internal/llm/prompt"
 17	"github.com/charmbracelet/crush/internal/llm/provider"
 18	"github.com/charmbracelet/crush/internal/llm/tools"
 19	"github.com/charmbracelet/crush/internal/log"
 20	"github.com/charmbracelet/crush/internal/lsp"
 21	"github.com/charmbracelet/crush/internal/message"
 22	"github.com/charmbracelet/crush/internal/permission"
 23	"github.com/charmbracelet/crush/internal/pubsub"
 24	"github.com/charmbracelet/crush/internal/session"
 25)
 26
 27// Common errors
 28var (
 29	ErrRequestCancelled = errors.New("request canceled by user")
 30	ErrSessionBusy      = errors.New("session is currently processing another request")
 31)
 32
 33type AgentEventType string
 34
 35const (
 36	AgentEventTypeError     AgentEventType = "error"
 37	AgentEventTypeResponse  AgentEventType = "response"
 38	AgentEventTypeSummarize AgentEventType = "summarize"
 39)
 40
// AgentEvent is the value the agent sends on the channel returned by Run and
// publishes through its broker.
type AgentEvent struct {
	// Type tells which of the remaining fields are meaningful.
	Type AgentEventType
	// Message is the assistant message of a response event.
	Message message.Message
	// Error is set on error events.
	Error error

	// When summarizing
	// SessionID is set on the final summarize event.
	SessionID string
	// Progress is a human-readable status line for summarize events.
	Progress string
	// Done marks the last event of a request or summarization.
	Done bool
}
 51
// Service is the public surface of an agent. The notes on each method describe
// the behavior of the agent implementation in this file.
type Service interface {
	pubsub.Suscriber[AgentEvent]
	// Model returns the model currently configured for the agent.
	Model() fur.Model
	// Run starts processing content for the session in the background and
	// returns a channel that delivers the final event. It returns
	// ErrSessionBusy when the session already has a request in flight.
	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
	// Cancel cancels the running request and any running summarization of
	// the session.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits a bounded time for
	// them to finish.
	CancelAll()
	// IsSessionBusy reports whether a request is registered for the session.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any request is registered at all.
	IsBusy() bool
	// Summarize starts a background summarization of the session; progress
	// and the outcome are published as AgentEvents.
	Summarize(ctx context.Context, sessionID string) error
	// UpdateModel rebuilds the providers whose configured provider changed.
	UpdateModel() error
}
 63
// agent is the Service implementation in this file. It embeds a broker for
// AgentEvent and holds the providers, tools and in-flight request registry
// used by its methods.
type agent struct {
	*pubsub.Broker[AgentEvent]
	agentCfg config.Agent
	sessions session.Service
	messages message.Service

	// tools are passed to the main provider and searched by name when the
	// model requests a tool call.
	tools []tools.BaseTool
	// provider serves the conversation turns; providerID is the ID of the
	// provider config it was built from.
	provider   provider.Provider
	providerID string

	// titleProvider and summarizeProvider are built from the small model's
	// provider config; summarizeProviderID records that config's ID.
	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests maps a session ID (or sessionID+"-summarize") to the
	// context.CancelFunc of the request currently running for it.
	activeRequests sync.Map
}
 80
// agentPromptMap selects the system prompt for an agent by its config ID.
// IDs that are not listed fall back to prompt.PromptDefault at the lookup
// sites.
var agentPromptMap = map[string]prompt.PromptID{
	"coder": prompt.PromptCoder,
	"task":  prompt.PromptTask,
}
 85
 86func NewAgent(
 87	agentCfg config.Agent,
 88	// These services are needed in the tools
 89	permissions permission.Service,
 90	sessions session.Service,
 91	messages message.Service,
 92	history history.Service,
 93	lspClients map[string]*lsp.Client,
 94) (Service, error) {
 95	ctx := context.Background()
 96	cfg := config.Get()
 97	otherTools := GetMcpTools(ctx, permissions, cfg)
 98	if len(lspClients) > 0 {
 99		otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients))
100	}
101
102	cwd := cfg.WorkingDir()
103	allTools := []tools.BaseTool{
104		tools.NewBashTool(permissions, cwd),
105		tools.NewEditTool(lspClients, permissions, history, cwd),
106		tools.NewFetchTool(permissions, cwd),
107		tools.NewGlobTool(cwd),
108		tools.NewGrepTool(cwd),
109		tools.NewLsTool(cwd),
110		tools.NewSourcegraphTool(),
111		tools.NewViewTool(lspClients, cwd),
112		tools.NewWriteTool(lspClients, permissions, history, cwd),
113	}
114
115	if agentCfg.ID == "coder" {
116		taskAgentCfg := config.Get().Agents["task"]
117		if taskAgentCfg.ID == "" {
118			return nil, fmt.Errorf("task agent not found in config")
119		}
120		taskAgent, err := NewAgent(taskAgentCfg, permissions, sessions, messages, history, lspClients)
121		if err != nil {
122			return nil, fmt.Errorf("failed to create task agent: %w", err)
123		}
124
125		allTools = append(
126			allTools,
127			NewAgentTool(
128				taskAgent,
129				sessions,
130				messages,
131			),
132		)
133	}
134
135	allTools = append(allTools, otherTools...)
136	providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
137	if providerCfg == nil {
138		return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
139	}
140	model := config.Get().GetModelByType(agentCfg.Model)
141
142	if model == nil {
143		return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
144	}
145
146	promptID := agentPromptMap[agentCfg.ID]
147	if promptID == "" {
148		promptID = prompt.PromptDefault
149	}
150	opts := []provider.ProviderClientOption{
151		provider.WithModel(agentCfg.Model),
152		provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
153	}
154	agentProvider, err := provider.NewProvider(*providerCfg, opts...)
155	if err != nil {
156		return nil, err
157	}
158
159	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
160	var smallModelProviderCfg *config.ProviderConfig
161	if smallModelCfg.Provider == providerCfg.ID {
162		smallModelProviderCfg = providerCfg
163	} else {
164		smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
165
166		if smallModelProviderCfg.ID == "" {
167			return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
168		}
169	}
170	smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
171	if smallModel.ID == "" {
172		return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
173	}
174
175	titleOpts := []provider.ProviderClientOption{
176		provider.WithModel(config.SelectedModelTypeSmall),
177		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
178	}
179	titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
180	if err != nil {
181		return nil, err
182	}
183	summarizeOpts := []provider.ProviderClientOption{
184		provider.WithModel(config.SelectedModelTypeSmall),
185		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
186	}
187	summarizeProvider, err := provider.NewProvider(*smallModelProviderCfg, summarizeOpts...)
188	if err != nil {
189		return nil, err
190	}
191
192	agentTools := []tools.BaseTool{}
193	if agentCfg.AllowedTools == nil {
194		agentTools = allTools
195	} else {
196		for _, tool := range allTools {
197			if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
198				agentTools = append(agentTools, tool)
199			}
200		}
201	}
202
203	agent := &agent{
204		Broker:              pubsub.NewBroker[AgentEvent](),
205		agentCfg:            agentCfg,
206		provider:            agentProvider,
207		providerID:          string(providerCfg.ID),
208		messages:            messages,
209		sessions:            sessions,
210		tools:               agentTools,
211		titleProvider:       titleProvider,
212		summarizeProvider:   summarizeProvider,
213		summarizeProviderID: string(smallModelProviderCfg.ID),
214		activeRequests:      sync.Map{},
215	}
216
217	return agent, nil
218}
219
220func (a *agent) Model() fur.Model {
221	return *config.Get().GetModelByType(a.agentCfg.Model)
222}
223
224func (a *agent) Cancel(sessionID string) {
225	// Cancel regular requests
226	if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
227		if cancel, ok := cancelFunc.(context.CancelFunc); ok {
228			slog.Info(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
229			cancel()
230		}
231	}
232
233	// Also check for summarize requests
234	if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID + "-summarize"); exists {
235		if cancel, ok := cancelFunc.(context.CancelFunc); ok {
236			slog.Info(fmt.Sprintf("Summarize cancellation initiated for session: %s", sessionID))
237			cancel()
238		}
239	}
240}
241
242func (a *agent) IsBusy() bool {
243	busy := false
244	a.activeRequests.Range(func(key, value any) bool {
245		if cancelFunc, ok := value.(context.CancelFunc); ok {
246			if cancelFunc != nil {
247				busy = true
248				return false
249			}
250		}
251		return true
252	})
253	return busy
254}
255
256func (a *agent) IsSessionBusy(sessionID string) bool {
257	_, busy := a.activeRequests.Load(sessionID)
258	return busy
259}
260
261func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
262	if content == "" {
263		return nil
264	}
265	if a.titleProvider == nil {
266		return nil
267	}
268	session, err := a.sessions.Get(ctx, sessionID)
269	if err != nil {
270		return err
271	}
272	parts := []message.ContentPart{message.TextContent{
273		Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
274	}}
275
276	// Use streaming approach like summarization
277	response := a.titleProvider.StreamResponse(
278		ctx,
279		[]message.Message{
280			{
281				Role:  message.User,
282				Parts: parts,
283			},
284		},
285		make([]tools.BaseTool, 0),
286	)
287
288	var finalResponse *provider.ProviderResponse
289	for r := range response {
290		if r.Error != nil {
291			return r.Error
292		}
293		finalResponse = r.Response
294	}
295
296	if finalResponse == nil {
297		return fmt.Errorf("no response received from title provider")
298	}
299
300	title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
301	if title == "" {
302		return nil
303	}
304
305	session.Title = title
306	_, err = a.sessions.Save(ctx, session)
307	return err
308}
309
310func (a *agent) err(err error) AgentEvent {
311	return AgentEvent{
312		Type:  AgentEventTypeError,
313		Error: err,
314	}
315}
316
317func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
318	if !a.Model().SupportsImages && attachments != nil {
319		attachments = nil
320	}
321	events := make(chan AgentEvent)
322	if a.IsSessionBusy(sessionID) {
323		return nil, ErrSessionBusy
324	}
325
326	genCtx, cancel := context.WithCancel(ctx)
327
328	a.activeRequests.Store(sessionID, cancel)
329	go func() {
330		slog.Debug("Request started", "sessionID", sessionID)
331		defer log.RecoverPanic("agent.Run", func() {
332			events <- a.err(fmt.Errorf("panic while running the agent"))
333		})
334		var attachmentParts []message.ContentPart
335		for _, attachment := range attachments {
336			attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
337		}
338		result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
339		if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
340			slog.Error(result.Error.Error())
341		}
342		slog.Debug("Request completed", "sessionID", sessionID)
343		a.activeRequests.Delete(sessionID)
344		cancel()
345		a.Publish(pubsub.CreatedEvent, result)
346		events <- result
347		close(events)
348	}()
349	return events, nil
350}
351
352func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
353	cfg := config.Get()
354	// List existing messages; if none, start title generation asynchronously.
355	msgs, err := a.messages.List(ctx, sessionID)
356	if err != nil {
357		return a.err(fmt.Errorf("failed to list messages: %w", err))
358	}
359	if len(msgs) == 0 {
360		go func() {
361			defer log.RecoverPanic("agent.Run", func() {
362				slog.Error("panic while generating title")
363			})
364			titleErr := a.generateTitle(context.Background(), sessionID, content)
365			if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
366				slog.Error(fmt.Sprintf("failed to generate title: %v", titleErr))
367			}
368		}()
369	}
370	session, err := a.sessions.Get(ctx, sessionID)
371	if err != nil {
372		return a.err(fmt.Errorf("failed to get session: %w", err))
373	}
374	if session.SummaryMessageID != "" {
375		summaryMsgInex := -1
376		for i, msg := range msgs {
377			if msg.ID == session.SummaryMessageID {
378				summaryMsgInex = i
379				break
380			}
381		}
382		if summaryMsgInex != -1 {
383			msgs = msgs[summaryMsgInex:]
384			msgs[0].Role = message.User
385		}
386	}
387
388	userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
389	if err != nil {
390		return a.err(fmt.Errorf("failed to create user message: %w", err))
391	}
392	// Append the new user message to the conversation history.
393	msgHistory := append(msgs, userMsg)
394
395	for {
396		// Check for cancellation before each iteration
397		select {
398		case <-ctx.Done():
399			return a.err(ctx.Err())
400		default:
401			// Continue processing
402		}
403		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
404		if err != nil {
405			if errors.Is(err, context.Canceled) {
406				agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
407				a.messages.Update(context.Background(), agentMessage)
408				return a.err(ErrRequestCancelled)
409			}
410			return a.err(fmt.Errorf("failed to process events: %w", err))
411		}
412		if cfg.Options.Debug {
413			slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
414		}
415		if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
416			// We are not done, we need to respond with the tool response
417			msgHistory = append(msgHistory, agentMessage, *toolResults)
418			continue
419		}
420		return AgentEvent{
421			Type:    AgentEventTypeResponse,
422			Message: agentMessage,
423			Done:    true,
424		}
425	}
426}
427
428func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
429	parts := []message.ContentPart{message.TextContent{Text: content}}
430	parts = append(parts, attachmentParts...)
431	return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
432		Role:  message.User,
433		Parts: parts,
434	})
435}
436
437func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
438	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
439	eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)
440
441	assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
442		Role:     message.Assistant,
443		Parts:    []message.ContentPart{},
444		Model:    a.Model().ID,
445		Provider: a.providerID,
446	})
447	if err != nil {
448		return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
449	}
450
451	// Add the session and message ID into the context if needed by tools.
452	ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
453
454	// Process each event in the stream.
455	for event := range eventChan {
456		if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
457			if errors.Is(processErr, context.Canceled) {
458				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
459			} else {
460				a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
461			}
462			return assistantMsg, nil, processErr
463		}
464		if ctx.Err() != nil {
465			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
466			return assistantMsg, nil, ctx.Err()
467		}
468	}
469
470	toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
471	toolCalls := assistantMsg.ToolCalls()
472	for i, toolCall := range toolCalls {
473		select {
474		case <-ctx.Done():
475			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
476			// Make all future tool calls cancelled
477			for j := i; j < len(toolCalls); j++ {
478				toolResults[j] = message.ToolResult{
479					ToolCallID: toolCalls[j].ID,
480					Content:    "Tool execution canceled by user",
481					IsError:    true,
482				}
483			}
484			goto out
485		default:
486			// Continue processing
487			var tool tools.BaseTool
488			for _, availableTool := range a.tools {
489				if availableTool.Info().Name == toolCall.Name {
490					tool = availableTool
491					break
492				}
493			}
494
495			// Tool not found
496			if tool == nil {
497				toolResults[i] = message.ToolResult{
498					ToolCallID: toolCall.ID,
499					Content:    fmt.Sprintf("Tool not found: %s", toolCall.Name),
500					IsError:    true,
501				}
502				continue
503			}
504			toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
505				ID:    toolCall.ID,
506				Name:  toolCall.Name,
507				Input: toolCall.Input,
508			})
509			if toolErr != nil {
510				if errors.Is(toolErr, permission.ErrorPermissionDenied) {
511					toolResults[i] = message.ToolResult{
512						ToolCallID: toolCall.ID,
513						Content:    "Permission denied",
514						IsError:    true,
515					}
516					for j := i + 1; j < len(toolCalls); j++ {
517						toolResults[j] = message.ToolResult{
518							ToolCallID: toolCalls[j].ID,
519							Content:    "Tool execution canceled by user",
520							IsError:    true,
521						}
522					}
523					a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
524					break
525				}
526			}
527			toolResults[i] = message.ToolResult{
528				ToolCallID: toolCall.ID,
529				Content:    toolResult.Content,
530				Metadata:   toolResult.Metadata,
531				IsError:    toolResult.IsError,
532			}
533		}
534	}
535out:
536	if len(toolResults) == 0 {
537		return assistantMsg, nil, nil
538	}
539	parts := make([]message.ContentPart, 0)
540	for _, tr := range toolResults {
541		parts = append(parts, tr)
542	}
543	msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
544		Role:     message.Tool,
545		Parts:    parts,
546		Provider: a.providerID,
547	})
548	if err != nil {
549		return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
550	}
551
552	return assistantMsg, &msg, err
553}
554
555func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReson message.FinishReason, message, details string) {
556	msg.AddFinish(finishReson, message, details)
557	_ = a.messages.Update(ctx, *msg)
558}
559
560func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
561	select {
562	case <-ctx.Done():
563		return ctx.Err()
564	default:
565		// Continue processing.
566	}
567
568	switch event.Type {
569	case provider.EventThinkingDelta:
570		assistantMsg.AppendReasoningContent(event.Content)
571		return a.messages.Update(ctx, *assistantMsg)
572	case provider.EventContentDelta:
573		assistantMsg.AppendContent(event.Content)
574		return a.messages.Update(ctx, *assistantMsg)
575	case provider.EventToolUseStart:
576		slog.Info("Tool call started", "toolCall", event.ToolCall)
577		assistantMsg.AddToolCall(*event.ToolCall)
578		return a.messages.Update(ctx, *assistantMsg)
579	case provider.EventToolUseDelta:
580		assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
581		return a.messages.Update(ctx, *assistantMsg)
582	case provider.EventToolUseStop:
583		slog.Info("Finished tool call", "toolCall", event.ToolCall)
584		assistantMsg.FinishToolCall(event.ToolCall.ID)
585		return a.messages.Update(ctx, *assistantMsg)
586	case provider.EventError:
587		return event.Error
588	case provider.EventComplete:
589		assistantMsg.SetToolCalls(event.Response.ToolCalls)
590		assistantMsg.AddFinish(event.Response.FinishReason, "", "")
591		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
592			return fmt.Errorf("failed to update message: %w", err)
593		}
594		return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
595	}
596
597	return nil
598}
599
600func (a *agent) TrackUsage(ctx context.Context, sessionID string, model fur.Model, usage provider.TokenUsage) error {
601	sess, err := a.sessions.Get(ctx, sessionID)
602	if err != nil {
603		return fmt.Errorf("failed to get session: %w", err)
604	}
605
606	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
607		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
608		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
609		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
610
611	sess.Cost += cost
612	sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
613	sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
614
615	_, err = a.sessions.Save(ctx, sess)
616	if err != nil {
617		return fmt.Errorf("failed to save session: %w", err)
618	}
619	return nil
620}
621
622func (a *agent) Summarize(ctx context.Context, sessionID string) error {
623	if a.summarizeProvider == nil {
624		return fmt.Errorf("summarize provider not available")
625	}
626
627	// Check if session is busy
628	if a.IsSessionBusy(sessionID) {
629		return ErrSessionBusy
630	}
631
632	// Create a new context with cancellation
633	summarizeCtx, cancel := context.WithCancel(ctx)
634
635	// Store the cancel function in activeRequests to allow cancellation
636	a.activeRequests.Store(sessionID+"-summarize", cancel)
637
638	go func() {
639		defer a.activeRequests.Delete(sessionID + "-summarize")
640		defer cancel()
641		event := AgentEvent{
642			Type:     AgentEventTypeSummarize,
643			Progress: "Starting summarization...",
644		}
645
646		a.Publish(pubsub.CreatedEvent, event)
647		// Get all messages from the session
648		msgs, err := a.messages.List(summarizeCtx, sessionID)
649		if err != nil {
650			event = AgentEvent{
651				Type:  AgentEventTypeError,
652				Error: fmt.Errorf("failed to list messages: %w", err),
653				Done:  true,
654			}
655			a.Publish(pubsub.CreatedEvent, event)
656			return
657		}
658		summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
659
660		if len(msgs) == 0 {
661			event = AgentEvent{
662				Type:  AgentEventTypeError,
663				Error: fmt.Errorf("no messages to summarize"),
664				Done:  true,
665			}
666			a.Publish(pubsub.CreatedEvent, event)
667			return
668		}
669
670		event = AgentEvent{
671			Type:     AgentEventTypeSummarize,
672			Progress: "Analyzing conversation...",
673		}
674		a.Publish(pubsub.CreatedEvent, event)
675
676		// Add a system message to guide the summarization
677		summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
678
679		// Create a new message with the summarize prompt
680		promptMsg := message.Message{
681			Role:  message.User,
682			Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
683		}
684
685		// Append the prompt to the messages
686		msgsWithPrompt := append(msgs, promptMsg)
687
688		event = AgentEvent{
689			Type:     AgentEventTypeSummarize,
690			Progress: "Generating summary...",
691		}
692
693		a.Publish(pubsub.CreatedEvent, event)
694
695		// Send the messages to the summarize provider
696		response := a.summarizeProvider.StreamResponse(
697			summarizeCtx,
698			msgsWithPrompt,
699			make([]tools.BaseTool, 0),
700		)
701		var finalResponse *provider.ProviderResponse
702		for r := range response {
703			if r.Error != nil {
704				event = AgentEvent{
705					Type:  AgentEventTypeError,
706					Error: fmt.Errorf("failed to summarize: %w", err),
707					Done:  true,
708				}
709				a.Publish(pubsub.CreatedEvent, event)
710				return
711			}
712			finalResponse = r.Response
713		}
714
715		summary := strings.TrimSpace(finalResponse.Content)
716		if summary == "" {
717			event = AgentEvent{
718				Type:  AgentEventTypeError,
719				Error: fmt.Errorf("empty summary returned"),
720				Done:  true,
721			}
722			a.Publish(pubsub.CreatedEvent, event)
723			return
724		}
725		event = AgentEvent{
726			Type:     AgentEventTypeSummarize,
727			Progress: "Creating new session...",
728		}
729
730		a.Publish(pubsub.CreatedEvent, event)
731		oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
732		if err != nil {
733			event = AgentEvent{
734				Type:  AgentEventTypeError,
735				Error: fmt.Errorf("failed to get session: %w", err),
736				Done:  true,
737			}
738
739			a.Publish(pubsub.CreatedEvent, event)
740			return
741		}
742		// Create a message in the new session with the summary
743		msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
744			Role: message.Assistant,
745			Parts: []message.ContentPart{
746				message.TextContent{Text: summary},
747				message.Finish{
748					Reason: message.FinishReasonEndTurn,
749					Time:   time.Now().Unix(),
750				},
751			},
752			Model:    a.summarizeProvider.Model().ID,
753			Provider: a.summarizeProviderID,
754		})
755		if err != nil {
756			event = AgentEvent{
757				Type:  AgentEventTypeError,
758				Error: fmt.Errorf("failed to create summary message: %w", err),
759				Done:  true,
760			}
761
762			a.Publish(pubsub.CreatedEvent, event)
763			return
764		}
765		oldSession.SummaryMessageID = msg.ID
766		oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
767		oldSession.PromptTokens = 0
768		model := a.summarizeProvider.Model()
769		usage := finalResponse.Usage
770		cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
771			model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
772			model.CostPer1MIn/1e6*float64(usage.InputTokens) +
773			model.CostPer1MOut/1e6*float64(usage.OutputTokens)
774		oldSession.Cost += cost
775		_, err = a.sessions.Save(summarizeCtx, oldSession)
776		if err != nil {
777			event = AgentEvent{
778				Type:  AgentEventTypeError,
779				Error: fmt.Errorf("failed to save session: %w", err),
780				Done:  true,
781			}
782			a.Publish(pubsub.CreatedEvent, event)
783		}
784
785		event = AgentEvent{
786			Type:      AgentEventTypeSummarize,
787			SessionID: oldSession.ID,
788			Progress:  "Summary complete",
789			Done:      true,
790		}
791		a.Publish(pubsub.CreatedEvent, event)
792		// Send final success event with the new session ID
793	}()
794
795	return nil
796}
797
798func (a *agent) CancelAll() {
799	a.activeRequests.Range(func(key, value any) bool {
800		a.Cancel(key.(string)) // key is sessionID
801		return true
802	})
803
804	timeout := time.After(5 * time.Second)
805	for a.IsBusy() {
806		select {
807		case <-timeout:
808			return
809		default:
810			time.Sleep(200 * time.Millisecond)
811		}
812	}
813}
814
815func (a *agent) UpdateModel() error {
816	cfg := config.Get()
817
818	// Get current provider configuration
819	currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
820	if currentProviderCfg.ID == "" {
821		return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
822	}
823
824	// Check if provider has changed
825	if string(currentProviderCfg.ID) != a.providerID {
826		// Provider changed, need to recreate the main provider
827		model := cfg.GetModelByType(a.agentCfg.Model)
828		if model.ID == "" {
829			return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
830		}
831
832		promptID := agentPromptMap[a.agentCfg.ID]
833		if promptID == "" {
834			promptID = prompt.PromptDefault
835		}
836
837		opts := []provider.ProviderClientOption{
838			provider.WithModel(a.agentCfg.Model),
839			provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
840		}
841
842		newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
843		if err != nil {
844			return fmt.Errorf("failed to create new provider: %w", err)
845		}
846
847		// Update the provider and provider ID
848		a.provider = newProvider
849		a.providerID = string(currentProviderCfg.ID)
850	}
851
852	// Check if small model provider has changed (affects title and summarize providers)
853	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
854	var smallModelProviderCfg config.ProviderConfig
855
856	for _, p := range cfg.Providers {
857		if p.ID == smallModelCfg.Provider {
858			smallModelProviderCfg = p
859			break
860		}
861	}
862
863	if smallModelProviderCfg.ID == "" {
864		return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
865	}
866
867	// Check if summarize provider has changed
868	if string(smallModelProviderCfg.ID) != a.summarizeProviderID {
869		smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
870		if smallModel == nil {
871			return fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
872		}
873
874		// Recreate title provider
875		titleOpts := []provider.ProviderClientOption{
876			provider.WithModel(config.SelectedModelTypeSmall),
877			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
878			// We want the title to be short, so we limit the max tokens
879			provider.WithMaxTokens(40),
880		}
881		newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
882		if err != nil {
883			return fmt.Errorf("failed to create new title provider: %w", err)
884		}
885
886		// Recreate summarize provider
887		summarizeOpts := []provider.ProviderClientOption{
888			provider.WithModel(config.SelectedModelTypeSmall),
889			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
890		}
891		newSummarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
892		if err != nil {
893			return fmt.Errorf("failed to create new summarize provider: %w", err)
894		}
895
896		// Update the providers and provider ID
897		a.titleProvider = newTitleProvider
898		a.summarizeProvider = newSummarizeProvider
899		a.summarizeProviderID = string(smallModelProviderCfg.ID)
900	}
901
902	return nil
903}