agent.go

  1package agent
  2
  3import (
  4	"context"
  5	"errors"
  6	"fmt"
  7	"log/slog"
  8	"slices"
  9	"strings"
 10	"sync"
 11	"time"
 12
 13	"github.com/charmbracelet/crush/internal/config"
 14	fur "github.com/charmbracelet/crush/internal/fur/provider"
 15	"github.com/charmbracelet/crush/internal/history"
 16	"github.com/charmbracelet/crush/internal/llm/prompt"
 17	"github.com/charmbracelet/crush/internal/llm/provider"
 18	"github.com/charmbracelet/crush/internal/llm/tools"
 19	"github.com/charmbracelet/crush/internal/log"
 20	"github.com/charmbracelet/crush/internal/lsp"
 21	"github.com/charmbracelet/crush/internal/message"
 22	"github.com/charmbracelet/crush/internal/permission"
 23	"github.com/charmbracelet/crush/internal/pubsub"
 24	"github.com/charmbracelet/crush/internal/session"
 25)
 26
 27// Common errors
 28var (
 29	ErrRequestCancelled = errors.New("request canceled by user")
 30	ErrSessionBusy      = errors.New("session is currently processing another request")
 31)
 32
 33type AgentEventType string
 34
 35const (
 36	AgentEventTypeError     AgentEventType = "error"
 37	AgentEventTypeResponse  AgentEventType = "response"
 38	AgentEventTypeSummarize AgentEventType = "summarize"
 39)
 40
 41type AgentEvent struct {
 42	Type    AgentEventType
 43	Message message.Message
 44	Error   error
 45
 46	// When summarizing
 47	SessionID string
 48	Progress  string
 49	Done      bool
 50}
 51
 52type Service interface {
 53	pubsub.Suscriber[AgentEvent]
 54	Model() fur.Model
 55	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
 56	Cancel(sessionID string)
 57	CancelAll()
 58	IsSessionBusy(sessionID string) bool
 59	IsBusy() bool
 60	Summarize(ctx context.Context, sessionID string) error
 61	UpdateModel() error
 62}
 63
 64type agent struct {
 65	*pubsub.Broker[AgentEvent]
 66	agentCfg config.Agent
 67	sessions session.Service
 68	messages message.Service
 69
 70	tools      []tools.BaseTool
 71	provider   provider.Provider
 72	providerID string
 73
 74	titleProvider       provider.Provider
 75	summarizeProvider   provider.Provider
 76	summarizeProviderID string
 77
 78	activeRequests sync.Map
 79}
 80
 81var agentPromptMap = map[string]prompt.PromptID{
 82	"coder": prompt.PromptCoder,
 83	"task":  prompt.PromptTask,
 84}
 85
 86func NewAgent(
 87	agentCfg config.Agent,
 88	// These services are needed in the tools
 89	permissions permission.Service,
 90	sessions session.Service,
 91	messages message.Service,
 92	history history.Service,
 93	lspClients map[string]*lsp.Client,
 94) (Service, error) {
 95	ctx := context.Background()
 96	cfg := config.Get()
 97	otherTools := GetMCPTools(ctx, permissions, cfg)
 98	if len(lspClients) > 0 {
 99		otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients))
100	}
101
102	cwd := cfg.WorkingDir()
103	allTools := []tools.BaseTool{
104		tools.NewBashTool(permissions, cwd),
105		tools.NewEditTool(lspClients, permissions, history, cwd),
106		tools.NewFetchTool(permissions, cwd),
107		tools.NewGlobTool(cwd),
108		tools.NewGrepTool(cwd),
109		tools.NewLsTool(cwd),
110		tools.NewSourcegraphTool(),
111		tools.NewViewTool(lspClients, cwd),
112		tools.NewWriteTool(lspClients, permissions, history, cwd),
113	}
114
115	if agentCfg.ID == "coder" {
116		taskAgentCfg := config.Get().Agents["task"]
117		if taskAgentCfg.ID == "" {
118			return nil, fmt.Errorf("task agent not found in config")
119		}
120		taskAgent, err := NewAgent(taskAgentCfg, permissions, sessions, messages, history, lspClients)
121		if err != nil {
122			return nil, fmt.Errorf("failed to create task agent: %w", err)
123		}
124
125		allTools = append(
126			allTools,
127			NewAgentTool(
128				taskAgent,
129				sessions,
130				messages,
131			),
132		)
133	}
134
135	allTools = append(allTools, otherTools...)
136	providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
137	if providerCfg == nil {
138		return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
139	}
140	model := config.Get().GetModelByType(agentCfg.Model)
141
142	if model == nil {
143		return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
144	}
145
146	promptID := agentPromptMap[agentCfg.ID]
147	if promptID == "" {
148		promptID = prompt.PromptDefault
149	}
150	opts := []provider.ProviderClientOption{
151		provider.WithModel(agentCfg.Model),
152		provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
153	}
154	agentProvider, err := provider.NewProvider(*providerCfg, opts...)
155	if err != nil {
156		return nil, err
157	}
158
159	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
160	var smallModelProviderCfg *config.ProviderConfig
161	if smallModelCfg.Provider == providerCfg.ID {
162		smallModelProviderCfg = providerCfg
163	} else {
164		smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
165
166		if smallModelProviderCfg.ID == "" {
167			return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
168		}
169	}
170	smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
171	if smallModel.ID == "" {
172		return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
173	}
174
175	titleOpts := []provider.ProviderClientOption{
176		provider.WithModel(config.SelectedModelTypeSmall),
177		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
178	}
179	titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
180	if err != nil {
181		return nil, err
182	}
183	summarizeOpts := []provider.ProviderClientOption{
184		provider.WithModel(config.SelectedModelTypeSmall),
185		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
186	}
187	summarizeProvider, err := provider.NewProvider(*smallModelProviderCfg, summarizeOpts...)
188	if err != nil {
189		return nil, err
190	}
191
192	agentTools := []tools.BaseTool{}
193	if agentCfg.AllowedTools == nil {
194		agentTools = allTools
195	} else {
196		for _, tool := range allTools {
197			if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
198				agentTools = append(agentTools, tool)
199			}
200		}
201	}
202
203	agent := &agent{
204		Broker:              pubsub.NewBroker[AgentEvent](),
205		agentCfg:            agentCfg,
206		provider:            agentProvider,
207		providerID:          string(providerCfg.ID),
208		messages:            messages,
209		sessions:            sessions,
210		tools:               agentTools,
211		titleProvider:       titleProvider,
212		summarizeProvider:   summarizeProvider,
213		summarizeProviderID: string(smallModelProviderCfg.ID),
214		activeRequests:      sync.Map{},
215	}
216
217	return agent, nil
218}
219
220func (a *agent) Model() fur.Model {
221	return *config.Get().GetModelByType(a.agentCfg.Model)
222}
223
224func (a *agent) Cancel(sessionID string) {
225	// Cancel regular requests
226	if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
227		if cancel, ok := cancelFunc.(context.CancelFunc); ok {
228			slog.Info(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
229			cancel()
230		}
231	}
232
233	// Also check for summarize requests
234	if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID + "-summarize"); exists {
235		if cancel, ok := cancelFunc.(context.CancelFunc); ok {
236			slog.Info(fmt.Sprintf("Summarize cancellation initiated for session: %s", sessionID))
237			cancel()
238		}
239	}
240}
241
242func (a *agent) IsBusy() bool {
243	busy := false
244	a.activeRequests.Range(func(key, value any) bool {
245		if cancelFunc, ok := value.(context.CancelFunc); ok {
246			if cancelFunc != nil {
247				busy = true
248				return false
249			}
250		}
251		return true
252	})
253	return busy
254}
255
256func (a *agent) IsSessionBusy(sessionID string) bool {
257	_, busy := a.activeRequests.Load(sessionID)
258	return busy
259}
260
261func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
262	if content == "" {
263		return nil
264	}
265	if a.titleProvider == nil {
266		return nil
267	}
268	session, err := a.sessions.Get(ctx, sessionID)
269	if err != nil {
270		return err
271	}
272	parts := []message.ContentPart{message.TextContent{
273		Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
274	}}
275
276	// Use streaming approach like summarization
277	response := a.titleProvider.StreamResponse(
278		ctx,
279		[]message.Message{
280			{
281				Role:  message.User,
282				Parts: parts,
283			},
284		},
285		make([]tools.BaseTool, 0),
286	)
287
288	var finalResponse *provider.ProviderResponse
289	for r := range response {
290		if r.Error != nil {
291			return r.Error
292		}
293		finalResponse = r.Response
294	}
295
296	if finalResponse == nil {
297		return fmt.Errorf("no response received from title provider")
298	}
299
300	title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
301	if title == "" {
302		return nil
303	}
304
305	session.Title = title
306	_, err = a.sessions.Save(ctx, session)
307	return err
308}
309
310func (a *agent) err(err error) AgentEvent {
311	return AgentEvent{
312		Type:  AgentEventTypeError,
313		Error: err,
314	}
315}
316
317func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
318	if !a.Model().SupportsImages && attachments != nil {
319		attachments = nil
320	}
321	events := make(chan AgentEvent)
322	if a.IsSessionBusy(sessionID) {
323		return nil, ErrSessionBusy
324	}
325
326	genCtx, cancel := context.WithCancel(ctx)
327
328	a.activeRequests.Store(sessionID, cancel)
329	go func() {
330		slog.Debug("Request started", "sessionID", sessionID)
331		defer log.RecoverPanic("agent.Run", func() {
332			events <- a.err(fmt.Errorf("panic while running the agent"))
333		})
334		var attachmentParts []message.ContentPart
335		for _, attachment := range attachments {
336			attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
337		}
338		result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
339		if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
340			slog.Error(result.Error.Error())
341		}
342		slog.Debug("Request completed", "sessionID", sessionID)
343		a.activeRequests.Delete(sessionID)
344		cancel()
345		a.Publish(pubsub.CreatedEvent, result)
346		events <- result
347		close(events)
348	}()
349	return events, nil
350}
351
352func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
353	cfg := config.Get()
354	// List existing messages; if none, start title generation asynchronously.
355	msgs, err := a.messages.List(ctx, sessionID)
356	if err != nil {
357		return a.err(fmt.Errorf("failed to list messages: %w", err))
358	}
359	if len(msgs) == 0 {
360		go func() {
361			defer log.RecoverPanic("agent.Run", func() {
362				slog.Error("panic while generating title")
363			})
364			titleErr := a.generateTitle(context.Background(), sessionID, content)
365			if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
366				slog.Error(fmt.Sprintf("failed to generate title: %v", titleErr))
367			}
368		}()
369	}
370	session, err := a.sessions.Get(ctx, sessionID)
371	if err != nil {
372		return a.err(fmt.Errorf("failed to get session: %w", err))
373	}
374	if session.SummaryMessageID != "" {
375		summaryMsgInex := -1
376		for i, msg := range msgs {
377			if msg.ID == session.SummaryMessageID {
378				summaryMsgInex = i
379				break
380			}
381		}
382		if summaryMsgInex != -1 {
383			msgs = msgs[summaryMsgInex:]
384			msgs[0].Role = message.User
385		}
386	}
387
388	userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
389	if err != nil {
390		return a.err(fmt.Errorf("failed to create user message: %w", err))
391	}
392	// Append the new user message to the conversation history.
393	msgHistory := append(msgs, userMsg)
394
395	for {
396		// Check for cancellation before each iteration
397		select {
398		case <-ctx.Done():
399			return a.err(ctx.Err())
400		default:
401			// Continue processing
402		}
403		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
404		if err != nil {
405			if errors.Is(err, context.Canceled) {
406				agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
407				a.messages.Update(context.Background(), agentMessage)
408				return a.err(ErrRequestCancelled)
409			}
410			return a.err(fmt.Errorf("failed to process events: %w", err))
411		}
412		if cfg.Options.Debug {
413			slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
414		}
415		if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
416			// We are not done, we need to respond with the tool response
417			msgHistory = append(msgHistory, agentMessage, *toolResults)
418			continue
419		}
420		return AgentEvent{
421			Type:    AgentEventTypeResponse,
422			Message: agentMessage,
423			Done:    true,
424		}
425	}
426}
427
428func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
429	parts := []message.ContentPart{message.TextContent{Text: content}}
430	parts = append(parts, attachmentParts...)
431	return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
432		Role:  message.User,
433		Parts: parts,
434	})
435}
436
437func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
438	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
439	eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)
440
441	assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
442		Role:     message.Assistant,
443		Parts:    []message.ContentPart{},
444		Model:    a.Model().ID,
445		Provider: a.providerID,
446	})
447	if err != nil {
448		return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
449	}
450
451	// Add the session and message ID into the context if needed by tools.
452	ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
453
454	// Process each event in the stream.
455	for event := range eventChan {
456		if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
457			if errors.Is(processErr, context.Canceled) {
458				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
459			} else {
460				a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
461			}
462			return assistantMsg, nil, processErr
463		}
464		if ctx.Err() != nil {
465			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
466			return assistantMsg, nil, ctx.Err()
467		}
468	}
469
470	toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
471	toolCalls := assistantMsg.ToolCalls()
472	for i, toolCall := range toolCalls {
473		select {
474		case <-ctx.Done():
475			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
476			// Make all future tool calls cancelled
477			for j := i; j < len(toolCalls); j++ {
478				toolResults[j] = message.ToolResult{
479					ToolCallID: toolCalls[j].ID,
480					Content:    "Tool execution canceled by user",
481					IsError:    true,
482				}
483			}
484			goto out
485		default:
486			// Continue processing
487			var tool tools.BaseTool
488			for _, availableTool := range a.tools {
489				if availableTool.Info().Name == toolCall.Name {
490					tool = availableTool
491					break
492				}
493			}
494
495			// Tool not found
496			if tool == nil {
497				toolResults[i] = message.ToolResult{
498					ToolCallID: toolCall.ID,
499					Content:    fmt.Sprintf("Tool not found: %s", toolCall.Name),
500					IsError:    true,
501				}
502				continue
503			}
504
505			// Run tool in goroutine to allow cancellation
506			type toolExecResult struct {
507				response tools.ToolResponse
508				err      error
509			}
510			resultChan := make(chan toolExecResult, 1)
511
512			go func() {
513				response, err := tool.Run(ctx, tools.ToolCall{
514					ID:    toolCall.ID,
515					Name:  toolCall.Name,
516					Input: toolCall.Input,
517				})
518				resultChan <- toolExecResult{response: response, err: err}
519			}()
520
521			var toolResponse tools.ToolResponse
522			var toolErr error
523
524			select {
525			case <-ctx.Done():
526				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
527				// Mark remaining tool calls as cancelled
528				for j := i; j < len(toolCalls); j++ {
529					toolResults[j] = message.ToolResult{
530						ToolCallID: toolCalls[j].ID,
531						Content:    "Tool execution canceled by user",
532						IsError:    true,
533					}
534				}
535				goto out
536			case result := <-resultChan:
537				toolResponse = result.response
538				toolErr = result.err
539			}
540
541			if toolErr != nil {
542				slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
543				if errors.Is(toolErr, permission.ErrorPermissionDenied) {
544					toolResults[i] = message.ToolResult{
545						ToolCallID: toolCall.ID,
546						Content:    "Permission denied",
547						IsError:    true,
548					}
549					for j := i + 1; j < len(toolCalls); j++ {
550						toolResults[j] = message.ToolResult{
551							ToolCallID: toolCalls[j].ID,
552							Content:    "Tool execution canceled by user",
553							IsError:    true,
554						}
555					}
556					a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
557					break
558				}
559			}
560			toolResults[i] = message.ToolResult{
561				ToolCallID: toolCall.ID,
562				Content:    toolResponse.Content,
563				Metadata:   toolResponse.Metadata,
564				IsError:    toolResponse.IsError,
565			}
566		}
567	}
568out:
569	if len(toolResults) == 0 {
570		return assistantMsg, nil, nil
571	}
572	parts := make([]message.ContentPart, 0)
573	for _, tr := range toolResults {
574		parts = append(parts, tr)
575	}
576	msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
577		Role:     message.Tool,
578		Parts:    parts,
579		Provider: a.providerID,
580	})
581	if err != nil {
582		return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
583	}
584
585	return assistantMsg, &msg, err
586}
587
588func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
589	msg.AddFinish(finishReason, message, details)
590	_ = a.messages.Update(ctx, *msg)
591}
592
593func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
594	select {
595	case <-ctx.Done():
596		return ctx.Err()
597	default:
598		// Continue processing.
599	}
600
601	switch event.Type {
602	case provider.EventThinkingDelta:
603		assistantMsg.AppendReasoningContent(event.Content)
604		return a.messages.Update(ctx, *assistantMsg)
605	case provider.EventContentDelta:
606		assistantMsg.AppendContent(event.Content)
607		return a.messages.Update(ctx, *assistantMsg)
608	case provider.EventToolUseStart:
609		slog.Info("Tool call started", "toolCall", event.ToolCall)
610		assistantMsg.AddToolCall(*event.ToolCall)
611		return a.messages.Update(ctx, *assistantMsg)
612	case provider.EventToolUseDelta:
613		assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
614		return a.messages.Update(ctx, *assistantMsg)
615	case provider.EventToolUseStop:
616		slog.Info("Finished tool call", "toolCall", event.ToolCall)
617		assistantMsg.FinishToolCall(event.ToolCall.ID)
618		return a.messages.Update(ctx, *assistantMsg)
619	case provider.EventError:
620		return event.Error
621	case provider.EventComplete:
622		assistantMsg.SetToolCalls(event.Response.ToolCalls)
623		assistantMsg.AddFinish(event.Response.FinishReason, "", "")
624		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
625			return fmt.Errorf("failed to update message: %w", err)
626		}
627		return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
628	}
629
630	return nil
631}
632
633func (a *agent) TrackUsage(ctx context.Context, sessionID string, model fur.Model, usage provider.TokenUsage) error {
634	sess, err := a.sessions.Get(ctx, sessionID)
635	if err != nil {
636		return fmt.Errorf("failed to get session: %w", err)
637	}
638
639	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
640		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
641		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
642		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
643
644	sess.Cost += cost
645	sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
646	sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
647
648	_, err = a.sessions.Save(ctx, sess)
649	if err != nil {
650		return fmt.Errorf("failed to save session: %w", err)
651	}
652	return nil
653}
654
655func (a *agent) Summarize(ctx context.Context, sessionID string) error {
656	if a.summarizeProvider == nil {
657		return fmt.Errorf("summarize provider not available")
658	}
659
660	// Check if session is busy
661	if a.IsSessionBusy(sessionID) {
662		return ErrSessionBusy
663	}
664
665	// Create a new context with cancellation
666	summarizeCtx, cancel := context.WithCancel(ctx)
667
668	// Store the cancel function in activeRequests to allow cancellation
669	a.activeRequests.Store(sessionID+"-summarize", cancel)
670
671	go func() {
672		defer a.activeRequests.Delete(sessionID + "-summarize")
673		defer cancel()
674		event := AgentEvent{
675			Type:     AgentEventTypeSummarize,
676			Progress: "Starting summarization...",
677		}
678
679		a.Publish(pubsub.CreatedEvent, event)
680		// Get all messages from the session
681		msgs, err := a.messages.List(summarizeCtx, sessionID)
682		if err != nil {
683			event = AgentEvent{
684				Type:  AgentEventTypeError,
685				Error: fmt.Errorf("failed to list messages: %w", err),
686				Done:  true,
687			}
688			a.Publish(pubsub.CreatedEvent, event)
689			return
690		}
691		summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
692
693		if len(msgs) == 0 {
694			event = AgentEvent{
695				Type:  AgentEventTypeError,
696				Error: fmt.Errorf("no messages to summarize"),
697				Done:  true,
698			}
699			a.Publish(pubsub.CreatedEvent, event)
700			return
701		}
702
703		event = AgentEvent{
704			Type:     AgentEventTypeSummarize,
705			Progress: "Analyzing conversation...",
706		}
707		a.Publish(pubsub.CreatedEvent, event)
708
709		// Add a system message to guide the summarization
710		summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
711
712		// Create a new message with the summarize prompt
713		promptMsg := message.Message{
714			Role:  message.User,
715			Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
716		}
717
718		// Append the prompt to the messages
719		msgsWithPrompt := append(msgs, promptMsg)
720
721		event = AgentEvent{
722			Type:     AgentEventTypeSummarize,
723			Progress: "Generating summary...",
724		}
725
726		a.Publish(pubsub.CreatedEvent, event)
727
728		// Send the messages to the summarize provider
729		response := a.summarizeProvider.StreamResponse(
730			summarizeCtx,
731			msgsWithPrompt,
732			make([]tools.BaseTool, 0),
733		)
734		var finalResponse *provider.ProviderResponse
735		for r := range response {
736			if r.Error != nil {
737				event = AgentEvent{
738					Type:  AgentEventTypeError,
739					Error: fmt.Errorf("failed to summarize: %w", err),
740					Done:  true,
741				}
742				a.Publish(pubsub.CreatedEvent, event)
743				return
744			}
745			finalResponse = r.Response
746		}
747
748		summary := strings.TrimSpace(finalResponse.Content)
749		if summary == "" {
750			event = AgentEvent{
751				Type:  AgentEventTypeError,
752				Error: fmt.Errorf("empty summary returned"),
753				Done:  true,
754			}
755			a.Publish(pubsub.CreatedEvent, event)
756			return
757		}
758		event = AgentEvent{
759			Type:     AgentEventTypeSummarize,
760			Progress: "Creating new session...",
761		}
762
763		a.Publish(pubsub.CreatedEvent, event)
764		oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
765		if err != nil {
766			event = AgentEvent{
767				Type:  AgentEventTypeError,
768				Error: fmt.Errorf("failed to get session: %w", err),
769				Done:  true,
770			}
771
772			a.Publish(pubsub.CreatedEvent, event)
773			return
774		}
775		// Create a message in the new session with the summary
776		msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
777			Role: message.Assistant,
778			Parts: []message.ContentPart{
779				message.TextContent{Text: summary},
780				message.Finish{
781					Reason: message.FinishReasonEndTurn,
782					Time:   time.Now().Unix(),
783				},
784			},
785			Model:    a.summarizeProvider.Model().ID,
786			Provider: a.summarizeProviderID,
787		})
788		if err != nil {
789			event = AgentEvent{
790				Type:  AgentEventTypeError,
791				Error: fmt.Errorf("failed to create summary message: %w", err),
792				Done:  true,
793			}
794
795			a.Publish(pubsub.CreatedEvent, event)
796			return
797		}
798		oldSession.SummaryMessageID = msg.ID
799		oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
800		oldSession.PromptTokens = 0
801		model := a.summarizeProvider.Model()
802		usage := finalResponse.Usage
803		cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
804			model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
805			model.CostPer1MIn/1e6*float64(usage.InputTokens) +
806			model.CostPer1MOut/1e6*float64(usage.OutputTokens)
807		oldSession.Cost += cost
808		_, err = a.sessions.Save(summarizeCtx, oldSession)
809		if err != nil {
810			event = AgentEvent{
811				Type:  AgentEventTypeError,
812				Error: fmt.Errorf("failed to save session: %w", err),
813				Done:  true,
814			}
815			a.Publish(pubsub.CreatedEvent, event)
816		}
817
818		event = AgentEvent{
819			Type:      AgentEventTypeSummarize,
820			SessionID: oldSession.ID,
821			Progress:  "Summary complete",
822			Done:      true,
823		}
824		a.Publish(pubsub.CreatedEvent, event)
825		// Send final success event with the new session ID
826	}()
827
828	return nil
829}
830
831func (a *agent) CancelAll() {
832	if !a.IsBusy() {
833		return
834	}
835	a.activeRequests.Range(func(key, value any) bool {
836		a.Cancel(key.(string)) // key is sessionID
837		return true
838	})
839
840	timeout := time.After(5 * time.Second)
841	for a.IsBusy() {
842		select {
843		case <-timeout:
844			return
845		default:
846			time.Sleep(200 * time.Millisecond)
847		}
848	}
849}
850
851func (a *agent) UpdateModel() error {
852	cfg := config.Get()
853
854	// Get current provider configuration
855	currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
856	if currentProviderCfg.ID == "" {
857		return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
858	}
859
860	// Check if provider has changed
861	if string(currentProviderCfg.ID) != a.providerID {
862		// Provider changed, need to recreate the main provider
863		model := cfg.GetModelByType(a.agentCfg.Model)
864		if model.ID == "" {
865			return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
866		}
867
868		promptID := agentPromptMap[a.agentCfg.ID]
869		if promptID == "" {
870			promptID = prompt.PromptDefault
871		}
872
873		opts := []provider.ProviderClientOption{
874			provider.WithModel(a.agentCfg.Model),
875			provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
876		}
877
878		newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
879		if err != nil {
880			return fmt.Errorf("failed to create new provider: %w", err)
881		}
882
883		// Update the provider and provider ID
884		a.provider = newProvider
885		a.providerID = string(currentProviderCfg.ID)
886	}
887
888	// Check if small model provider has changed (affects title and summarize providers)
889	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
890	var smallModelProviderCfg config.ProviderConfig
891
892	for _, p := range cfg.Providers {
893		if p.ID == smallModelCfg.Provider {
894			smallModelProviderCfg = p
895			break
896		}
897	}
898
899	if smallModelProviderCfg.ID == "" {
900		return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
901	}
902
903	// Check if summarize provider has changed
904	if string(smallModelProviderCfg.ID) != a.summarizeProviderID {
905		smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
906		if smallModel == nil {
907			return fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
908		}
909
910		// Recreate title provider
911		titleOpts := []provider.ProviderClientOption{
912			provider.WithModel(config.SelectedModelTypeSmall),
913			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
914			// We want the title to be short, so we limit the max tokens
915			provider.WithMaxTokens(40),
916		}
917		newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
918		if err != nil {
919			return fmt.Errorf("failed to create new title provider: %w", err)
920		}
921
922		// Recreate summarize provider
923		summarizeOpts := []provider.ProviderClientOption{
924			provider.WithModel(config.SelectedModelTypeSmall),
925			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
926		}
927		newSummarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
928		if err != nil {
929			return fmt.Errorf("failed to create new summarize provider: %w", err)
930		}
931
932		// Update the providers and provider ID
933		a.titleProvider = newTitleProvider
934		a.summarizeProvider = newSummarizeProvider
935		a.summarizeProviderID = string(smallModelProviderCfg.ID)
936	}
937
938	return nil
939}