agent.go

  1package agent
  2
  3import (
  4	"context"
  5	"errors"
  6	"fmt"
  7	"slices"
  8	"strings"
  9	"sync"
 10	"time"
 11
 12	"github.com/charmbracelet/crush/internal/config"
 13	"github.com/charmbracelet/crush/internal/history"
 14	"github.com/charmbracelet/crush/internal/llm/prompt"
 15	"github.com/charmbracelet/crush/internal/llm/provider"
 16	"github.com/charmbracelet/crush/internal/llm/tools"
 17	"github.com/charmbracelet/crush/internal/logging"
 18	"github.com/charmbracelet/crush/internal/lsp"
 19	"github.com/charmbracelet/crush/internal/message"
 20	"github.com/charmbracelet/crush/internal/permission"
 21	"github.com/charmbracelet/crush/internal/pubsub"
 22	"github.com/charmbracelet/crush/internal/session"
 23)
 24
// Sentinel errors surfaced by the agent.
//
// ErrRequestCancelled is placed in the error event of a generation whose
// provider stream failed with context.Canceled.
//
// ErrSessionBusy is returned by Run and Summarize when IsSessionBusy reports
// that the session already has an active generation.
var (
	ErrRequestCancelled = errors.New("request cancelled by user")
	ErrSessionBusy      = errors.New("session is currently processing another request")
)
 30
// AgentEventType identifies what kind of notification an AgentEvent carries.
type AgentEventType string

// Event kinds emitted in this file:
//   - AgentEventTypeError: a generation or summarization failed; the cause is
//     in AgentEvent.Error.
//   - AgentEventTypeResponse: a generation finished; AgentEvent.Message holds
//     the final assistant message.
//   - AgentEventTypeSummarize: a progress or completion update published while
//     Summarize is running.
const (
	AgentEventTypeError     AgentEventType = "error"
	AgentEventTypeResponse  AgentEventType = "response"
	AgentEventTypeSummarize AgentEventType = "summarize"
)
 38
// AgentEvent is the notification the agent publishes to its subscribers and,
// for Run, also delivers on the returned channel.
//
// Type selects which of the other fields are meaningful: in this file Message
// is filled for response events and Error for error events.
type AgentEvent struct {
	Type    AgentEventType
	Message message.Message
	Error   error

	// Fields mainly used while summarizing. Progress carries the stage text of
	// a summarize event and SessionID is set only on the final
	// "Summary complete" event. Done is true on the final response of a
	// generation and on the success and failure events published by Summarize.
	SessionID string
	Progress  string
	Done      bool
}
 49
// Service is the public surface of an agent. The notes on each method describe
// the behavior of the agent implementation in this file.
type Service interface {
	// Subscription to the AgentEvent values the agent publishes.
	pubsub.Suscriber[AgentEvent]
	// Model returns the model currently configured for this agent.
	Model() config.Model
	// EffectiveMaxTokens returns the configured effective max-token value for
	// this agent.
	EffectiveMaxTokens() int64
	// Run starts a generation for the session in the background. The returned
	// channel receives the final event; ErrSessionBusy is returned when the
	// session already has an active generation.
	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
	// Cancel cancels the active generation and the active summarization, if
	// any, of the session.
	Cancel(sessionID string)
	// CancelAll cancels every active request.
	CancelAll()
	// IsSessionBusy reports whether a generation is registered for the session.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any request is registered as active.
	IsBusy() bool
	// Summarize starts summarizing the session in the background; progress and
	// the outcome are delivered as published events, not via the return value.
	Summarize(ctx context.Context, sessionID string) error
	// UpdateModel rebuilds the providers whose configured provider changed.
	UpdateModel() error
}
 62
// agent is the Service implementation in this file. It embeds a broker used to
// publish AgentEvent values.
type agent struct {
	*pubsub.Broker[AgentEvent]
	agentCfg config.Agent
	sessions session.Service
	messages message.Service

	// tools is the tool set handed to the main provider on every request;
	// providerID is recorded on the messages the agent creates and is compared
	// against the configuration in UpdateModel.
	tools      []tools.BaseTool
	provider   provider.Provider
	providerID string

	// Providers built from the configured small model: one for session titles
	// and one for summaries.
	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests maps a session ID (generation) or session ID +
	// "-summarize" (summarization) to the context.CancelFunc of the request in
	// flight.
	activeRequests sync.Map
}
 79
// agentPromptMap selects the system prompt for each known agent ID. Agents
// that are missing here fall back to prompt.PromptDefault in NewAgent and
// UpdateModel.
var agentPromptMap = map[config.AgentID]prompt.PromptID{
	config.AgentCoder: prompt.PromptCoder,
	config.AgentTask:  prompt.PromptTask,
}
 84
 85func NewAgent(
 86	agentCfg config.Agent,
 87	// These services are needed in the tools
 88	permissions permission.Service,
 89	sessions session.Service,
 90	messages message.Service,
 91	history history.Service,
 92	lspClients map[string]*lsp.Client,
 93) (Service, error) {
 94	ctx := context.Background()
 95	cfg := config.Get()
 96	otherTools := GetMcpTools(ctx, permissions)
 97	if len(lspClients) > 0 {
 98		otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients))
 99	}
100
101	allTools := []tools.BaseTool{
102		tools.NewBashTool(permissions),
103		tools.NewEditTool(lspClients, permissions, history),
104		tools.NewFetchTool(permissions),
105		tools.NewGlobTool(),
106		tools.NewGrepTool(),
107		tools.NewLsTool(),
108		tools.NewSourcegraphTool(),
109		tools.NewViewTool(lspClients),
110		tools.NewWriteTool(lspClients, permissions, history),
111	}
112
113	if agentCfg.ID == config.AgentCoder {
114		taskAgentCfg := config.Get().Agents[config.AgentTask]
115		if taskAgentCfg.ID == "" {
116			return nil, fmt.Errorf("task agent not found in config")
117		}
118		taskAgent, err := NewAgent(taskAgentCfg, permissions, sessions, messages, history, lspClients)
119		if err != nil {
120			return nil, fmt.Errorf("failed to create task agent: %w", err)
121		}
122
123		allTools = append(
124			allTools,
125			NewAgentTool(
126				taskAgent,
127				sessions,
128				messages,
129			),
130		)
131	}
132
133	allTools = append(allTools, otherTools...)
134	providerCfg := config.GetAgentProvider(agentCfg.ID)
135	if providerCfg.ID == "" {
136		return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
137	}
138	model := config.GetAgentModel(agentCfg.ID)
139
140	if model.ID == "" {
141		return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
142	}
143
144	promptID := agentPromptMap[agentCfg.ID]
145	if promptID == "" {
146		promptID = prompt.PromptDefault
147	}
148	opts := []provider.ProviderClientOption{
149		provider.WithModel(agentCfg.Model),
150		provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID)),
151	}
152	agentProvider, err := provider.NewProvider(providerCfg, opts...)
153	if err != nil {
154		return nil, err
155	}
156
157	smallModelCfg := cfg.Models.Small
158	var smallModel config.Model
159
160	var smallModelProviderCfg config.ProviderConfig
161	if smallModelCfg.Provider == providerCfg.ID {
162		smallModelProviderCfg = providerCfg
163	} else {
164		for _, p := range cfg.Providers {
165			if p.ID == smallModelCfg.Provider {
166				smallModelProviderCfg = p
167				break
168			}
169		}
170		if smallModelProviderCfg.ID == "" {
171			return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
172		}
173	}
174	for _, m := range smallModelProviderCfg.Models {
175		if m.ID == smallModelCfg.ModelID {
176			smallModel = m
177			break
178		}
179	}
180	if smallModel.ID == "" {
181		return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.ModelID, smallModelProviderCfg.ID)
182	}
183
184	titleOpts := []provider.ProviderClientOption{
185		provider.WithModel(config.SmallModel),
186		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
187	}
188	titleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
189	if err != nil {
190		return nil, err
191	}
192	summarizeOpts := []provider.ProviderClientOption{
193		provider.WithModel(config.SmallModel),
194		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
195	}
196	summarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
197	if err != nil {
198		return nil, err
199	}
200
201	agentTools := []tools.BaseTool{}
202	if agentCfg.AllowedTools == nil {
203		agentTools = allTools
204	} else {
205		for _, tool := range allTools {
206			if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
207				agentTools = append(agentTools, tool)
208			}
209		}
210	}
211
212	agent := &agent{
213		Broker:              pubsub.NewBroker[AgentEvent](),
214		agentCfg:            agentCfg,
215		provider:            agentProvider,
216		providerID:          string(providerCfg.ID),
217		messages:            messages,
218		sessions:            sessions,
219		tools:               agentTools,
220		titleProvider:       titleProvider,
221		summarizeProvider:   summarizeProvider,
222		summarizeProviderID: string(smallModelProviderCfg.ID),
223		activeRequests:      sync.Map{},
224	}
225
226	return agent, nil
227}
228
229func (a *agent) Model() config.Model {
230	return config.GetAgentModel(a.agentCfg.ID)
231}
232
233func (a *agent) EffectiveMaxTokens() int64 {
234	return config.GetAgentEffectiveMaxTokens(a.agentCfg.ID)
235}
236
237func (a *agent) Cancel(sessionID string) {
238	// Cancel regular requests
239	if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
240		if cancel, ok := cancelFunc.(context.CancelFunc); ok {
241			logging.InfoPersist(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
242			cancel()
243		}
244	}
245
246	// Also check for summarize requests
247	if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID + "-summarize"); exists {
248		if cancel, ok := cancelFunc.(context.CancelFunc); ok {
249			logging.InfoPersist(fmt.Sprintf("Summarize cancellation initiated for session: %s", sessionID))
250			cancel()
251		}
252	}
253}
254
255func (a *agent) IsBusy() bool {
256	busy := false
257	a.activeRequests.Range(func(key, value any) bool {
258		if cancelFunc, ok := value.(context.CancelFunc); ok {
259			if cancelFunc != nil {
260				busy = true
261				return false
262			}
263		}
264		return true
265	})
266	return busy
267}
268
269func (a *agent) IsSessionBusy(sessionID string) bool {
270	_, busy := a.activeRequests.Load(sessionID)
271	return busy
272}
273
274func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
275	if content == "" {
276		return nil
277	}
278	if a.titleProvider == nil {
279		return nil
280	}
281	session, err := a.sessions.Get(ctx, sessionID)
282	if err != nil {
283		return err
284	}
285	parts := []message.ContentPart{message.TextContent{
286		Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
287	}}
288
289	// Use streaming approach like summarization
290	response := a.titleProvider.StreamResponse(
291		ctx,
292		[]message.Message{
293			{
294				Role:  message.User,
295				Parts: parts,
296			},
297		},
298		make([]tools.BaseTool, 0),
299	)
300
301	var finalResponse *provider.ProviderResponse
302	for r := range response {
303		if r.Error != nil {
304			return r.Error
305		}
306		finalResponse = r.Response
307	}
308
309	if finalResponse == nil {
310		return fmt.Errorf("no response received from title provider")
311	}
312
313	title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
314	if title == "" {
315		return nil
316	}
317
318	session.Title = title
319	_, err = a.sessions.Save(ctx, session)
320	return err
321}
322
323func (a *agent) err(err error) AgentEvent {
324	return AgentEvent{
325		Type:  AgentEventTypeError,
326		Error: err,
327	}
328}
329
330func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
331	if !a.Model().SupportsImages && attachments != nil {
332		attachments = nil
333	}
334	events := make(chan AgentEvent)
335	if a.IsSessionBusy(sessionID) {
336		return nil, ErrSessionBusy
337	}
338
339	genCtx, cancel := context.WithCancel(ctx)
340
341	a.activeRequests.Store(sessionID, cancel)
342	go func() {
343		logging.Debug("Request started", "sessionID", sessionID)
344		defer logging.RecoverPanic("agent.Run", func() {
345			events <- a.err(fmt.Errorf("panic while running the agent"))
346		})
347		var attachmentParts []message.ContentPart
348		for _, attachment := range attachments {
349			attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
350		}
351		result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
352		if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
353			logging.ErrorPersist(result.Error.Error())
354		}
355		logging.Debug("Request completed", "sessionID", sessionID)
356		a.activeRequests.Delete(sessionID)
357		cancel()
358		a.Publish(pubsub.CreatedEvent, result)
359		events <- result
360		close(events)
361	}()
362	return events, nil
363}
364
365func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
366	cfg := config.Get()
367	// List existing messages; if none, start title generation asynchronously.
368	msgs, err := a.messages.List(ctx, sessionID)
369	if err != nil {
370		return a.err(fmt.Errorf("failed to list messages: %w", err))
371	}
372	if len(msgs) == 0 {
373		go func() {
374			defer logging.RecoverPanic("agent.Run", func() {
375				logging.ErrorPersist("panic while generating title")
376			})
377			titleErr := a.generateTitle(context.Background(), sessionID, content)
378			if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
379				logging.ErrorPersist(fmt.Sprintf("failed to generate title: %v", titleErr))
380			}
381		}()
382	}
383	session, err := a.sessions.Get(ctx, sessionID)
384	if err != nil {
385		return a.err(fmt.Errorf("failed to get session: %w", err))
386	}
387	if session.SummaryMessageID != "" {
388		summaryMsgInex := -1
389		for i, msg := range msgs {
390			if msg.ID == session.SummaryMessageID {
391				summaryMsgInex = i
392				break
393			}
394		}
395		if summaryMsgInex != -1 {
396			msgs = msgs[summaryMsgInex:]
397			msgs[0].Role = message.User
398		}
399	}
400
401	userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
402	if err != nil {
403		return a.err(fmt.Errorf("failed to create user message: %w", err))
404	}
405	// Append the new user message to the conversation history.
406	msgHistory := append(msgs, userMsg)
407
408	for {
409		// Check for cancellation before each iteration
410		select {
411		case <-ctx.Done():
412			return a.err(ctx.Err())
413		default:
414			// Continue processing
415		}
416		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
417		if err != nil {
418			if errors.Is(err, context.Canceled) {
419				agentMessage.AddFinish(message.FinishReasonCanceled)
420				a.messages.Update(context.Background(), agentMessage)
421				return a.err(ErrRequestCancelled)
422			}
423			return a.err(fmt.Errorf("failed to process events: %w", err))
424		}
425		if cfg.Options.Debug {
426			seqId := (len(msgHistory) + 1) / 2
427			toolResultFilepath := logging.WriteToolResultsJson(sessionID, seqId, toolResults)
428			logging.Info("Result", "message", agentMessage.FinishReason(), "toolResults", "{}", "filepath", toolResultFilepath)
429		} else {
430			logging.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
431		}
432		if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
433			// We are not done, we need to respond with the tool response
434			msgHistory = append(msgHistory, agentMessage, *toolResults)
435			continue
436		}
437		return AgentEvent{
438			Type:    AgentEventTypeResponse,
439			Message: agentMessage,
440			Done:    true,
441		}
442	}
443}
444
445func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
446	parts := []message.ContentPart{message.TextContent{Text: content}}
447	parts = append(parts, attachmentParts...)
448	return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
449		Role:  message.User,
450		Parts: parts,
451	})
452}
453
454func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
455	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
456	eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)
457
458	assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
459		Role:     message.Assistant,
460		Parts:    []message.ContentPart{},
461		Model:    a.Model().ID,
462		Provider: a.providerID,
463	})
464	if err != nil {
465		return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
466	}
467
468	// Add the session and message ID into the context if needed by tools.
469	ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
470
471	// Process each event in the stream.
472	for event := range eventChan {
473		if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
474			a.finishMessage(ctx, &assistantMsg, message.FinishReasonCanceled)
475			return assistantMsg, nil, processErr
476		}
477		if ctx.Err() != nil {
478			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
479			return assistantMsg, nil, ctx.Err()
480		}
481	}
482
483	toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
484	toolCalls := assistantMsg.ToolCalls()
485	for i, toolCall := range toolCalls {
486		select {
487		case <-ctx.Done():
488			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
489			// Make all future tool calls cancelled
490			for j := i; j < len(toolCalls); j++ {
491				toolResults[j] = message.ToolResult{
492					ToolCallID: toolCalls[j].ID,
493					Content:    "Tool execution canceled by user",
494					IsError:    true,
495				}
496			}
497			goto out
498		default:
499			// Continue processing
500			var tool tools.BaseTool
501			for _, availableTool := range a.tools {
502				if availableTool.Info().Name == toolCall.Name {
503					tool = availableTool
504					break
505				}
506				// Monkey patch for Copilot Sonnet-4 tool repetition obfuscation
507				// if strings.HasPrefix(toolCall.Name, availableTool.Info().Name) &&
508				// 	strings.HasPrefix(toolCall.Name, availableTool.Info().Name+availableTool.Info().Name) {
509				// 	tool = availableTool
510				// 	break
511				// }
512			}
513
514			// Tool not found
515			if tool == nil {
516				toolResults[i] = message.ToolResult{
517					ToolCallID: toolCall.ID,
518					Content:    fmt.Sprintf("Tool not found: %s", toolCall.Name),
519					IsError:    true,
520				}
521				continue
522			}
523			toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
524				ID:    toolCall.ID,
525				Name:  toolCall.Name,
526				Input: toolCall.Input,
527			})
528			if toolErr != nil {
529				if errors.Is(toolErr, permission.ErrorPermissionDenied) {
530					toolResults[i] = message.ToolResult{
531						ToolCallID: toolCall.ID,
532						Content:    "Permission denied",
533						IsError:    true,
534					}
535					for j := i + 1; j < len(toolCalls); j++ {
536						toolResults[j] = message.ToolResult{
537							ToolCallID: toolCalls[j].ID,
538							Content:    "Tool execution canceled by user",
539							IsError:    true,
540						}
541					}
542					a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied)
543					break
544				}
545			}
546			toolResults[i] = message.ToolResult{
547				ToolCallID: toolCall.ID,
548				Content:    toolResult.Content,
549				Metadata:   toolResult.Metadata,
550				IsError:    toolResult.IsError,
551			}
552		}
553	}
554out:
555	if len(toolResults) == 0 {
556		return assistantMsg, nil, nil
557	}
558	parts := make([]message.ContentPart, 0)
559	for _, tr := range toolResults {
560		parts = append(parts, tr)
561	}
562	msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
563		Role:     message.Tool,
564		Parts:    parts,
565		Provider: a.providerID,
566	})
567	if err != nil {
568		return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
569	}
570
571	return assistantMsg, &msg, err
572}
573
574func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReson message.FinishReason) {
575	msg.AddFinish(finishReson)
576	_ = a.messages.Update(ctx, *msg)
577}
578
579func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
580	select {
581	case <-ctx.Done():
582		return ctx.Err()
583	default:
584		// Continue processing.
585	}
586
587	switch event.Type {
588	case provider.EventThinkingDelta:
589		assistantMsg.AppendReasoningContent(event.Content)
590		return a.messages.Update(ctx, *assistantMsg)
591	case provider.EventContentDelta:
592		assistantMsg.AppendContent(event.Content)
593		return a.messages.Update(ctx, *assistantMsg)
594	case provider.EventToolUseStart:
595		logging.Info("Tool call started", "toolCall", event.ToolCall)
596		assistantMsg.AddToolCall(*event.ToolCall)
597		return a.messages.Update(ctx, *assistantMsg)
598	case provider.EventToolUseDelta:
599		assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
600		return a.messages.Update(ctx, *assistantMsg)
601	case provider.EventToolUseStop:
602		logging.Info("Finished tool call", "toolCall", event.ToolCall)
603		assistantMsg.FinishToolCall(event.ToolCall.ID)
604		return a.messages.Update(ctx, *assistantMsg)
605	case provider.EventError:
606		if errors.Is(event.Error, context.Canceled) {
607			logging.InfoPersist(fmt.Sprintf("Event processing canceled for session: %s", sessionID))
608			return context.Canceled
609		}
610		logging.ErrorPersist(event.Error.Error())
611		return event.Error
612	case provider.EventComplete:
613		assistantMsg.SetToolCalls(event.Response.ToolCalls)
614		assistantMsg.AddFinish(event.Response.FinishReason)
615		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
616			return fmt.Errorf("failed to update message: %w", err)
617		}
618		return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
619	}
620
621	return nil
622}
623
624func (a *agent) TrackUsage(ctx context.Context, sessionID string, model config.Model, usage provider.TokenUsage) error {
625	sess, err := a.sessions.Get(ctx, sessionID)
626	if err != nil {
627		return fmt.Errorf("failed to get session: %w", err)
628	}
629
630	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
631		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
632		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
633		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
634
635	sess.Cost += cost
636	sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
637	sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
638
639	_, err = a.sessions.Save(ctx, sess)
640	if err != nil {
641		return fmt.Errorf("failed to save session: %w", err)
642	}
643	return nil
644}
645
646func (a *agent) Summarize(ctx context.Context, sessionID string) error {
647	if a.summarizeProvider == nil {
648		return fmt.Errorf("summarize provider not available")
649	}
650
651	// Check if session is busy
652	if a.IsSessionBusy(sessionID) {
653		return ErrSessionBusy
654	}
655
656	// Create a new context with cancellation
657	summarizeCtx, cancel := context.WithCancel(ctx)
658
659	// Store the cancel function in activeRequests to allow cancellation
660	a.activeRequests.Store(sessionID+"-summarize", cancel)
661
662	go func() {
663		defer a.activeRequests.Delete(sessionID + "-summarize")
664		defer cancel()
665		event := AgentEvent{
666			Type:     AgentEventTypeSummarize,
667			Progress: "Starting summarization...",
668		}
669
670		a.Publish(pubsub.CreatedEvent, event)
671		// Get all messages from the session
672		msgs, err := a.messages.List(summarizeCtx, sessionID)
673		if err != nil {
674			event = AgentEvent{
675				Type:  AgentEventTypeError,
676				Error: fmt.Errorf("failed to list messages: %w", err),
677				Done:  true,
678			}
679			a.Publish(pubsub.CreatedEvent, event)
680			return
681		}
682		summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
683
684		if len(msgs) == 0 {
685			event = AgentEvent{
686				Type:  AgentEventTypeError,
687				Error: fmt.Errorf("no messages to summarize"),
688				Done:  true,
689			}
690			a.Publish(pubsub.CreatedEvent, event)
691			return
692		}
693
694		event = AgentEvent{
695			Type:     AgentEventTypeSummarize,
696			Progress: "Analyzing conversation...",
697		}
698		a.Publish(pubsub.CreatedEvent, event)
699
700		// Add a system message to guide the summarization
701		summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
702
703		// Create a new message with the summarize prompt
704		promptMsg := message.Message{
705			Role:  message.User,
706			Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
707		}
708
709		// Append the prompt to the messages
710		msgsWithPrompt := append(msgs, promptMsg)
711
712		event = AgentEvent{
713			Type:     AgentEventTypeSummarize,
714			Progress: "Generating summary...",
715		}
716
717		a.Publish(pubsub.CreatedEvent, event)
718
719		// Send the messages to the summarize provider
720		response := a.summarizeProvider.StreamResponse(
721			summarizeCtx,
722			msgsWithPrompt,
723			make([]tools.BaseTool, 0),
724		)
725		var finalResponse *provider.ProviderResponse
726		for r := range response {
727			if r.Error != nil {
728				event = AgentEvent{
729					Type:  AgentEventTypeError,
730					Error: fmt.Errorf("failed to summarize: %w", err),
731					Done:  true,
732				}
733				a.Publish(pubsub.CreatedEvent, event)
734				return
735			}
736			finalResponse = r.Response
737		}
738
739		summary := strings.TrimSpace(finalResponse.Content)
740		if summary == "" {
741			event = AgentEvent{
742				Type:  AgentEventTypeError,
743				Error: fmt.Errorf("empty summary returned"),
744				Done:  true,
745			}
746			a.Publish(pubsub.CreatedEvent, event)
747			return
748		}
749		event = AgentEvent{
750			Type:     AgentEventTypeSummarize,
751			Progress: "Creating new session...",
752		}
753
754		a.Publish(pubsub.CreatedEvent, event)
755		oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
756		if err != nil {
757			event = AgentEvent{
758				Type:  AgentEventTypeError,
759				Error: fmt.Errorf("failed to get session: %w", err),
760				Done:  true,
761			}
762
763			a.Publish(pubsub.CreatedEvent, event)
764			return
765		}
766		// Create a message in the new session with the summary
767		msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
768			Role: message.Assistant,
769			Parts: []message.ContentPart{
770				message.TextContent{Text: summary},
771				message.Finish{
772					Reason: message.FinishReasonEndTurn,
773					Time:   time.Now().Unix(),
774				},
775			},
776			Model:    a.summarizeProvider.Model().ID,
777			Provider: a.summarizeProviderID,
778		})
779		if err != nil {
780			event = AgentEvent{
781				Type:  AgentEventTypeError,
782				Error: fmt.Errorf("failed to create summary message: %w", err),
783				Done:  true,
784			}
785
786			a.Publish(pubsub.CreatedEvent, event)
787			return
788		}
789		oldSession.SummaryMessageID = msg.ID
790		oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
791		oldSession.PromptTokens = 0
792		model := a.summarizeProvider.Model()
793		usage := finalResponse.Usage
794		cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
795			model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
796			model.CostPer1MIn/1e6*float64(usage.InputTokens) +
797			model.CostPer1MOut/1e6*float64(usage.OutputTokens)
798		oldSession.Cost += cost
799		_, err = a.sessions.Save(summarizeCtx, oldSession)
800		if err != nil {
801			event = AgentEvent{
802				Type:  AgentEventTypeError,
803				Error: fmt.Errorf("failed to save session: %w", err),
804				Done:  true,
805			}
806			a.Publish(pubsub.CreatedEvent, event)
807		}
808
809		event = AgentEvent{
810			Type:      AgentEventTypeSummarize,
811			SessionID: oldSession.ID,
812			Progress:  "Summary complete",
813			Done:      true,
814		}
815		a.Publish(pubsub.CreatedEvent, event)
816		// Send final success event with the new session ID
817	}()
818
819	return nil
820}
821
822func (a *agent) CancelAll() {
823	a.activeRequests.Range(func(key, value any) bool {
824		a.Cancel(key.(string)) // key is sessionID
825		return true
826	})
827}
828
829func (a *agent) UpdateModel() error {
830	cfg := config.Get()
831
832	// Get current provider configuration
833	currentProviderCfg := config.GetAgentProvider(a.agentCfg.ID)
834	if currentProviderCfg.ID == "" {
835		return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
836	}
837
838	// Check if provider has changed
839	if string(currentProviderCfg.ID) != a.providerID {
840		// Provider changed, need to recreate the main provider
841		model := config.GetAgentModel(a.agentCfg.ID)
842		if model.ID == "" {
843			return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
844		}
845
846		promptID := agentPromptMap[a.agentCfg.ID]
847		if promptID == "" {
848			promptID = prompt.PromptDefault
849		}
850
851		opts := []provider.ProviderClientOption{
852			provider.WithModel(a.agentCfg.Model),
853			provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID)),
854		}
855
856		newProvider, err := provider.NewProvider(currentProviderCfg, opts...)
857		if err != nil {
858			return fmt.Errorf("failed to create new provider: %w", err)
859		}
860
861		// Update the provider and provider ID
862		a.provider = newProvider
863		a.providerID = string(currentProviderCfg.ID)
864	}
865
866	// Check if small model provider has changed (affects title and summarize providers)
867	smallModelCfg := cfg.Models.Small
868	var smallModelProviderCfg config.ProviderConfig
869
870	for _, p := range cfg.Providers {
871		if p.ID == smallModelCfg.Provider {
872			smallModelProviderCfg = p
873			break
874		}
875	}
876
877	if smallModelProviderCfg.ID == "" {
878		return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
879	}
880
881	// Check if summarize provider has changed
882	if string(smallModelProviderCfg.ID) != a.summarizeProviderID {
883		var smallModel config.Model
884		for _, m := range smallModelProviderCfg.Models {
885			if m.ID == smallModelCfg.ModelID {
886				smallModel = m
887				break
888			}
889		}
890		if smallModel.ID == "" {
891			return fmt.Errorf("model %s not found in provider %s", smallModelCfg.ModelID, smallModelProviderCfg.ID)
892		}
893
894		// Recreate title provider
895		titleOpts := []provider.ProviderClientOption{
896			provider.WithModel(config.SmallModel),
897			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
898			// We want the title to be short, so we limit the max tokens
899			provider.WithMaxTokens(40),
900		}
901		newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
902		if err != nil {
903			return fmt.Errorf("failed to create new title provider: %w", err)
904		}
905
906		// Recreate summarize provider
907		summarizeOpts := []provider.ProviderClientOption{
908			provider.WithModel(config.SmallModel),
909			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
910		}
911		newSummarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
912		if err != nil {
913			return fmt.Errorf("failed to create new summarize provider: %w", err)
914		}
915
916		// Update the providers and provider ID
917		a.titleProvider = newTitleProvider
918		a.summarizeProvider = newSummarizeProvider
919		a.summarizeProviderID = string(smallModelProviderCfg.ID)
920	}
921
922	return nil
923}