agent.go

  1package agent
  2
  3import (
  4	"context"
  5	"errors"
  6	"fmt"
  7	"log/slog"
  8	"slices"
  9	"strings"
 10	"time"
 11
 12	"github.com/charmbracelet/catwalk/pkg/catwalk"
 13	"github.com/charmbracelet/crush/internal/config"
 14	"github.com/charmbracelet/crush/internal/csync"
 15	"github.com/charmbracelet/crush/internal/history"
 16	"github.com/charmbracelet/crush/internal/llm/prompt"
 17	"github.com/charmbracelet/crush/internal/llm/provider"
 18	"github.com/charmbracelet/crush/internal/llm/tools"
 19	"github.com/charmbracelet/crush/internal/log"
 20	"github.com/charmbracelet/crush/internal/lsp"
 21	"github.com/charmbracelet/crush/internal/message"
 22	"github.com/charmbracelet/crush/internal/permission"
 23	"github.com/charmbracelet/crush/internal/pubsub"
 24	"github.com/charmbracelet/crush/internal/session"
 25	"github.com/charmbracelet/crush/internal/shell"
 26)
 27
 28// Common errors
 29var (
 30	ErrRequestCancelled = errors.New("request canceled by user")
 31	ErrSessionBusy      = errors.New("session is currently processing another request")
 32)
 33
 34type AgentEventType string
 35
 36const (
 37	AgentEventTypeError     AgentEventType = "error"
 38	AgentEventTypeResponse  AgentEventType = "response"
 39	AgentEventTypeSummarize AgentEventType = "summarize"
 40)
 41
// AgentEvent is emitted by the agent: Run delivers one on its returned
// channel, and the agent also publishes events through its pubsub broker.
type AgentEvent struct {
	// Type tells which of the remaining fields are meaningful.
	Type    AgentEventType
	// Message is the final assistant message of a response event.
	Message message.Message
	// Error is the failure carried by an error event.
	Error   error

	// When summarizing
	// SessionID is set on the final "Summary complete" event; Progress is a
	// human-readable status line for summarize events.
	SessionID string
	Progress  string
	// Done marks the last event of an operation (summarize completion or
	// summarize error, and also the final response event of a Run).
	Done      bool
}
 52
// Service is the public interface of an agent. It runs user turns against the
// configured model, exposes cancellation and busy state, and lets callers
// subscribe to the AgentEvents it publishes.
type Service interface {
	pubsub.Suscriber[AgentEvent]
	// Model returns the model currently configured for the agent.
	Model() catwalk.Model
	// Run processes content as a new user turn in sessionID and returns a
	// channel that yields the final AgentEvent. The agent implementation in
	// this file returns ErrSessionBusy if the session already has a request.
	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
	// Cancel aborts the active request and/or summarization of sessionID.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits briefly for them.
	CancelAll()
	// IsSessionBusy reports whether sessionID has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any request is active.
	IsBusy() bool
	// Summarize starts summarizing sessionID in the background; progress and
	// the outcome are published as AgentEvents.
	Summarize(ctx context.Context, sessionID string) error
	// UpdateModel rebuilds providers after the model configuration changed.
	UpdateModel() error
}
 64
// agent implements Service on top of LLM providers built from the global
// configuration.
type agent struct {
	*pubsub.Broker[AgentEvent]
	agentCfg config.Agent
	sessions session.Service
	messages message.Service

	// tools is produced by the toolFn closure built in NewAgent, wrapped in a
	// csync.LazySlice (presumably evaluated on first access — confirm).
	tools *csync.LazySlice[tools.BaseTool]

	// provider serves the main conversation. providerID is the ID of its
	// provider config; UpdateModel compares it to detect a provider switch.
	provider   provider.Provider
	providerID string

	// Providers for the small model: session titles and summaries.
	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests maps a session ID (or "<sessionID>-summarize") to the
	// cancel func of the request currently running for it.
	activeRequests *csync.Map[string, context.CancelFunc]
}
 82
// agentPromptMap selects the system prompt by agent config ID. IDs missing
// from this table fall back to prompt.PromptDefault in NewAgent and
// UpdateModel.
var agentPromptMap = map[string]prompt.PromptID{
	"coder": prompt.PromptCoder,
	"task":  prompt.PromptTask,
}
 87
 88func NewAgent(
 89	agentCfg config.Agent,
 90	// These services are needed in the tools
 91	permissions permission.Service,
 92	sessions session.Service,
 93	messages message.Service,
 94	history history.Service,
 95	lspClients map[string]*lsp.Client,
 96) (Service, error) {
 97	ctx := context.Background()
 98	cfg := config.Get()
 99
100	var agentTool tools.BaseTool
101	if agentCfg.ID == "coder" {
102		taskAgentCfg := config.Get().Agents["task"]
103		if taskAgentCfg.ID == "" {
104			return nil, fmt.Errorf("task agent not found in config")
105		}
106		taskAgent, err := NewAgent(taskAgentCfg, permissions, sessions, messages, history, lspClients)
107		if err != nil {
108			return nil, fmt.Errorf("failed to create task agent: %w", err)
109		}
110
111		agentTool = NewAgentTool(taskAgent, sessions, messages)
112	}
113
114	providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
115	if providerCfg == nil {
116		return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
117	}
118	model := config.Get().GetModelByType(agentCfg.Model)
119
120	if model == nil {
121		return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
122	}
123
124	promptID := agentPromptMap[agentCfg.ID]
125	if promptID == "" {
126		promptID = prompt.PromptDefault
127	}
128	opts := []provider.ProviderClientOption{
129		provider.WithModel(agentCfg.Model),
130		provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
131	}
132	agentProvider, err := provider.NewProvider(*providerCfg, opts...)
133	if err != nil {
134		return nil, err
135	}
136
137	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
138	var smallModelProviderCfg *config.ProviderConfig
139	if smallModelCfg.Provider == providerCfg.ID {
140		smallModelProviderCfg = providerCfg
141	} else {
142		smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
143
144		if smallModelProviderCfg.ID == "" {
145			return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
146		}
147	}
148	smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
149	if smallModel.ID == "" {
150		return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
151	}
152
153	titleOpts := []provider.ProviderClientOption{
154		provider.WithModel(config.SelectedModelTypeSmall),
155		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
156	}
157	titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
158	if err != nil {
159		return nil, err
160	}
161	summarizeOpts := []provider.ProviderClientOption{
162		provider.WithModel(config.SelectedModelTypeSmall),
163		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
164	}
165	summarizeProvider, err := provider.NewProvider(*smallModelProviderCfg, summarizeOpts...)
166	if err != nil {
167		return nil, err
168	}
169
170	toolFn := func() []tools.BaseTool {
171		slog.Info("Initializing agent tools", "agent", agentCfg.ID)
172		defer func() {
173			slog.Info("Initialized agent tools", "agent", agentCfg.ID)
174		}()
175
176		cwd := cfg.WorkingDir()
177		allTools := []tools.BaseTool{
178			tools.NewBashTool(permissions, cwd),
179			tools.NewDownloadTool(permissions, cwd),
180			tools.NewEditTool(lspClients, permissions, history, cwd),
181			tools.NewFetchTool(permissions, cwd),
182			tools.NewGlobTool(cwd),
183			tools.NewGrepTool(cwd),
184			tools.NewLsTool(permissions, cwd),
185			tools.NewSourcegraphTool(),
186			tools.NewViewTool(lspClients, permissions, cwd),
187			tools.NewWriteTool(lspClients, permissions, history, cwd),
188		}
189
190		mcpTools := GetMCPTools(ctx, permissions, cfg)
191		allTools = append(allTools, mcpTools...)
192
193		if len(lspClients) > 0 {
194			allTools = append(allTools, tools.NewDiagnosticsTool(lspClients))
195		}
196
197		if agentTool != nil {
198			allTools = append(allTools, agentTool)
199		}
200
201		if agentCfg.AllowedTools == nil {
202			return allTools
203		}
204
205		var filteredTools []tools.BaseTool
206		for _, tool := range allTools {
207			if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
208				filteredTools = append(filteredTools, tool)
209			}
210		}
211		return filteredTools
212	}
213
214	return &agent{
215		Broker:              pubsub.NewBroker[AgentEvent](),
216		agentCfg:            agentCfg,
217		provider:            agentProvider,
218		providerID:          string(providerCfg.ID),
219		messages:            messages,
220		sessions:            sessions,
221		titleProvider:       titleProvider,
222		summarizeProvider:   summarizeProvider,
223		summarizeProviderID: string(smallModelProviderCfg.ID),
224		activeRequests:      csync.NewMap[string, context.CancelFunc](),
225		tools:               csync.NewLazySlice(toolFn),
226	}, nil
227}
228
229func (a *agent) Model() catwalk.Model {
230	return *config.Get().GetModelByType(a.agentCfg.Model)
231}
232
233func (a *agent) Cancel(sessionID string) {
234	// Cancel regular requests
235	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
236		slog.Info("Request cancellation initiated", "session_id", sessionID)
237		cancel()
238	}
239
240	// Also check for summarize requests
241	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
242		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
243		cancel()
244	}
245}
246
247func (a *agent) IsBusy() bool {
248	var busy bool
249	for cancelFunc := range a.activeRequests.Seq() {
250		if cancelFunc != nil {
251			busy = true
252			break
253		}
254	}
255	return busy
256}
257
258func (a *agent) IsSessionBusy(sessionID string) bool {
259	_, busy := a.activeRequests.Get(sessionID)
260	return busy
261}
262
263func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
264	if content == "" {
265		return nil
266	}
267	if a.titleProvider == nil {
268		return nil
269	}
270	session, err := a.sessions.Get(ctx, sessionID)
271	if err != nil {
272		return err
273	}
274	parts := []message.ContentPart{message.TextContent{
275		Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
276	}}
277
278	// Use streaming approach like summarization
279	response := a.titleProvider.StreamResponse(
280		ctx,
281		[]message.Message{
282			{
283				Role:  message.User,
284				Parts: parts,
285			},
286		},
287		nil,
288	)
289
290	var finalResponse *provider.ProviderResponse
291	for r := range response {
292		if r.Error != nil {
293			return r.Error
294		}
295		finalResponse = r.Response
296	}
297
298	if finalResponse == nil {
299		return fmt.Errorf("no response received from title provider")
300	}
301
302	title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
303	if title == "" {
304		return nil
305	}
306
307	session.Title = title
308	_, err = a.sessions.Save(ctx, session)
309	return err
310}
311
312func (a *agent) err(err error) AgentEvent {
313	return AgentEvent{
314		Type:  AgentEventTypeError,
315		Error: err,
316	}
317}
318
319func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
320	if !a.Model().SupportsImages && attachments != nil {
321		attachments = nil
322	}
323	events := make(chan AgentEvent)
324	if a.IsSessionBusy(sessionID) {
325		return nil, ErrSessionBusy
326	}
327
328	genCtx, cancel := context.WithCancel(ctx)
329
330	a.activeRequests.Set(sessionID, cancel)
331	go func() {
332		slog.Debug("Request started", "sessionID", sessionID)
333		defer log.RecoverPanic("agent.Run", func() {
334			events <- a.err(fmt.Errorf("panic while running the agent"))
335		})
336		var attachmentParts []message.ContentPart
337		for _, attachment := range attachments {
338			attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
339		}
340		result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
341		if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
342			slog.Error(result.Error.Error())
343		}
344		slog.Debug("Request completed", "sessionID", sessionID)
345		a.activeRequests.Del(sessionID)
346		cancel()
347		a.Publish(pubsub.CreatedEvent, result)
348		events <- result
349		close(events)
350	}()
351	return events, nil
352}
353
354func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
355	cfg := config.Get()
356	// List existing messages; if none, start title generation asynchronously.
357	msgs, err := a.messages.List(ctx, sessionID)
358	if err != nil {
359		return a.err(fmt.Errorf("failed to list messages: %w", err))
360	}
361	if len(msgs) == 0 {
362		go func() {
363			defer log.RecoverPanic("agent.Run", func() {
364				slog.Error("panic while generating title")
365			})
366			titleErr := a.generateTitle(context.Background(), sessionID, content)
367			if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
368				slog.Error("failed to generate title", "error", titleErr)
369			}
370		}()
371	}
372	session, err := a.sessions.Get(ctx, sessionID)
373	if err != nil {
374		return a.err(fmt.Errorf("failed to get session: %w", err))
375	}
376	if session.SummaryMessageID != "" {
377		summaryMsgInex := -1
378		for i, msg := range msgs {
379			if msg.ID == session.SummaryMessageID {
380				summaryMsgInex = i
381				break
382			}
383		}
384		if summaryMsgInex != -1 {
385			msgs = msgs[summaryMsgInex:]
386			msgs[0].Role = message.User
387		}
388	}
389
390	userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
391	if err != nil {
392		return a.err(fmt.Errorf("failed to create user message: %w", err))
393	}
394	// Append the new user message to the conversation history.
395	msgHistory := append(msgs, userMsg)
396
397	for {
398		// Check for cancellation before each iteration
399		select {
400		case <-ctx.Done():
401			return a.err(ctx.Err())
402		default:
403			// Continue processing
404		}
405		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
406		if err != nil {
407			if errors.Is(err, context.Canceled) {
408				agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
409				a.messages.Update(context.Background(), agentMessage)
410				return a.err(ErrRequestCancelled)
411			}
412			return a.err(fmt.Errorf("failed to process events: %w", err))
413		}
414		if cfg.Options.Debug {
415			slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
416		}
417		if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
418			// We are not done, we need to respond with the tool response
419			msgHistory = append(msgHistory, agentMessage, *toolResults)
420			continue
421		}
422		return AgentEvent{
423			Type:    AgentEventTypeResponse,
424			Message: agentMessage,
425			Done:    true,
426		}
427	}
428}
429
430func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
431	parts := []message.ContentPart{message.TextContent{Text: content}}
432	parts = append(parts, attachmentParts...)
433	return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
434		Role:  message.User,
435		Parts: parts,
436	})
437}
438
439func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
440	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
441	eventChan := a.provider.StreamResponse(ctx, msgHistory, slices.Collect(a.tools.Seq()))
442
443	assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
444		Role:     message.Assistant,
445		Parts:    []message.ContentPart{},
446		Model:    a.Model().ID,
447		Provider: a.providerID,
448	})
449	if err != nil {
450		return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
451	}
452
453	// Add the session and message ID into the context if needed by tools.
454	ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
455
456	// Process each event in the stream.
457	for event := range eventChan {
458		if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
459			if errors.Is(processErr, context.Canceled) {
460				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
461			} else {
462				a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
463			}
464			return assistantMsg, nil, processErr
465		}
466		if ctx.Err() != nil {
467			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
468			return assistantMsg, nil, ctx.Err()
469		}
470	}
471
472	toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
473	toolCalls := assistantMsg.ToolCalls()
474	for i, toolCall := range toolCalls {
475		select {
476		case <-ctx.Done():
477			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
478			// Make all future tool calls cancelled
479			for j := i; j < len(toolCalls); j++ {
480				toolResults[j] = message.ToolResult{
481					ToolCallID: toolCalls[j].ID,
482					Content:    "Tool execution canceled by user",
483					IsError:    true,
484				}
485			}
486			goto out
487		default:
488			// Continue processing
489			var tool tools.BaseTool
490			for availableTool := range a.tools.Seq() {
491				if availableTool.Info().Name == toolCall.Name {
492					tool = availableTool
493					break
494				}
495			}
496
497			// Tool not found
498			if tool == nil {
499				toolResults[i] = message.ToolResult{
500					ToolCallID: toolCall.ID,
501					Content:    fmt.Sprintf("Tool not found: %s", toolCall.Name),
502					IsError:    true,
503				}
504				continue
505			}
506
507			// Run tool in goroutine to allow cancellation
508			type toolExecResult struct {
509				response tools.ToolResponse
510				err      error
511			}
512			resultChan := make(chan toolExecResult, 1)
513
514			go func() {
515				response, err := tool.Run(ctx, tools.ToolCall{
516					ID:    toolCall.ID,
517					Name:  toolCall.Name,
518					Input: toolCall.Input,
519				})
520				resultChan <- toolExecResult{response: response, err: err}
521			}()
522
523			var toolResponse tools.ToolResponse
524			var toolErr error
525
526			select {
527			case <-ctx.Done():
528				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
529				// Mark remaining tool calls as cancelled
530				for j := i; j < len(toolCalls); j++ {
531					toolResults[j] = message.ToolResult{
532						ToolCallID: toolCalls[j].ID,
533						Content:    "Tool execution canceled by user",
534						IsError:    true,
535					}
536				}
537				goto out
538			case result := <-resultChan:
539				toolResponse = result.response
540				toolErr = result.err
541			}
542
543			if toolErr != nil {
544				slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
545				if errors.Is(toolErr, permission.ErrorPermissionDenied) {
546					toolResults[i] = message.ToolResult{
547						ToolCallID: toolCall.ID,
548						Content:    "Permission denied",
549						IsError:    true,
550					}
551					for j := i + 1; j < len(toolCalls); j++ {
552						toolResults[j] = message.ToolResult{
553							ToolCallID: toolCalls[j].ID,
554							Content:    "Tool execution canceled by user",
555							IsError:    true,
556						}
557					}
558					a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
559					break
560				}
561			}
562			toolResults[i] = message.ToolResult{
563				ToolCallID: toolCall.ID,
564				Content:    toolResponse.Content,
565				Metadata:   toolResponse.Metadata,
566				IsError:    toolResponse.IsError,
567			}
568		}
569	}
570out:
571	if len(toolResults) == 0 {
572		return assistantMsg, nil, nil
573	}
574	parts := make([]message.ContentPart, 0)
575	for _, tr := range toolResults {
576		parts = append(parts, tr)
577	}
578	msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
579		Role:     message.Tool,
580		Parts:    parts,
581		Provider: a.providerID,
582	})
583	if err != nil {
584		return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
585	}
586
587	return assistantMsg, &msg, err
588}
589
590func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
591	msg.AddFinish(finishReason, message, details)
592	_ = a.messages.Update(ctx, *msg)
593}
594
595func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
596	select {
597	case <-ctx.Done():
598		return ctx.Err()
599	default:
600		// Continue processing.
601	}
602
603	switch event.Type {
604	case provider.EventThinkingDelta:
605		assistantMsg.AppendReasoningContent(event.Thinking)
606		return a.messages.Update(ctx, *assistantMsg)
607	case provider.EventSignatureDelta:
608		assistantMsg.AppendReasoningSignature(event.Signature)
609		return a.messages.Update(ctx, *assistantMsg)
610	case provider.EventContentDelta:
611		assistantMsg.FinishThinking()
612		assistantMsg.AppendContent(event.Content)
613		return a.messages.Update(ctx, *assistantMsg)
614	case provider.EventToolUseStart:
615		assistantMsg.FinishThinking()
616		slog.Info("Tool call started", "toolCall", event.ToolCall)
617		assistantMsg.AddToolCall(*event.ToolCall)
618		return a.messages.Update(ctx, *assistantMsg)
619	case provider.EventToolUseDelta:
620		assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
621		return a.messages.Update(ctx, *assistantMsg)
622	case provider.EventToolUseStop:
623		slog.Info("Finished tool call", "toolCall", event.ToolCall)
624		assistantMsg.FinishToolCall(event.ToolCall.ID)
625		return a.messages.Update(ctx, *assistantMsg)
626	case provider.EventError:
627		return event.Error
628	case provider.EventComplete:
629		assistantMsg.FinishThinking()
630		assistantMsg.SetToolCalls(event.Response.ToolCalls)
631		assistantMsg.AddFinish(event.Response.FinishReason, "", "")
632		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
633			return fmt.Errorf("failed to update message: %w", err)
634		}
635		return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
636	}
637
638	return nil
639}
640
641func (a *agent) TrackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
642	sess, err := a.sessions.Get(ctx, sessionID)
643	if err != nil {
644		return fmt.Errorf("failed to get session: %w", err)
645	}
646
647	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
648		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
649		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
650		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
651
652	sess.Cost += cost
653	sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
654	sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
655
656	_, err = a.sessions.Save(ctx, sess)
657	if err != nil {
658		return fmt.Errorf("failed to save session: %w", err)
659	}
660	return nil
661}
662
663func (a *agent) Summarize(ctx context.Context, sessionID string) error {
664	if a.summarizeProvider == nil {
665		return fmt.Errorf("summarize provider not available")
666	}
667
668	// Check if session is busy
669	if a.IsSessionBusy(sessionID) {
670		return ErrSessionBusy
671	}
672
673	// Create a new context with cancellation
674	summarizeCtx, cancel := context.WithCancel(ctx)
675
676	// Store the cancel function in activeRequests to allow cancellation
677	a.activeRequests.Set(sessionID+"-summarize", cancel)
678
679	go func() {
680		defer a.activeRequests.Del(sessionID + "-summarize")
681		defer cancel()
682		event := AgentEvent{
683			Type:     AgentEventTypeSummarize,
684			Progress: "Starting summarization...",
685		}
686
687		a.Publish(pubsub.CreatedEvent, event)
688		// Get all messages from the session
689		msgs, err := a.messages.List(summarizeCtx, sessionID)
690		if err != nil {
691			event = AgentEvent{
692				Type:  AgentEventTypeError,
693				Error: fmt.Errorf("failed to list messages: %w", err),
694				Done:  true,
695			}
696			a.Publish(pubsub.CreatedEvent, event)
697			return
698		}
699		summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
700
701		if len(msgs) == 0 {
702			event = AgentEvent{
703				Type:  AgentEventTypeError,
704				Error: fmt.Errorf("no messages to summarize"),
705				Done:  true,
706			}
707			a.Publish(pubsub.CreatedEvent, event)
708			return
709		}
710
711		event = AgentEvent{
712			Type:     AgentEventTypeSummarize,
713			Progress: "Analyzing conversation...",
714		}
715		a.Publish(pubsub.CreatedEvent, event)
716
717		// Add a system message to guide the summarization
718		summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
719
720		// Create a new message with the summarize prompt
721		promptMsg := message.Message{
722			Role:  message.User,
723			Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
724		}
725
726		// Append the prompt to the messages
727		msgsWithPrompt := append(msgs, promptMsg)
728
729		event = AgentEvent{
730			Type:     AgentEventTypeSummarize,
731			Progress: "Generating summary...",
732		}
733
734		a.Publish(pubsub.CreatedEvent, event)
735
736		// Send the messages to the summarize provider
737		response := a.summarizeProvider.StreamResponse(
738			summarizeCtx,
739			msgsWithPrompt,
740			nil,
741		)
742		var finalResponse *provider.ProviderResponse
743		for r := range response {
744			if r.Error != nil {
745				event = AgentEvent{
746					Type:  AgentEventTypeError,
747					Error: fmt.Errorf("failed to summarize: %w", err),
748					Done:  true,
749				}
750				a.Publish(pubsub.CreatedEvent, event)
751				return
752			}
753			finalResponse = r.Response
754		}
755
756		summary := strings.TrimSpace(finalResponse.Content)
757		if summary == "" {
758			event = AgentEvent{
759				Type:  AgentEventTypeError,
760				Error: fmt.Errorf("empty summary returned"),
761				Done:  true,
762			}
763			a.Publish(pubsub.CreatedEvent, event)
764			return
765		}
766		shell := shell.GetPersistentShell(config.Get().WorkingDir())
767		summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
768		event = AgentEvent{
769			Type:     AgentEventTypeSummarize,
770			Progress: "Creating new session...",
771		}
772
773		a.Publish(pubsub.CreatedEvent, event)
774		oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
775		if err != nil {
776			event = AgentEvent{
777				Type:  AgentEventTypeError,
778				Error: fmt.Errorf("failed to get session: %w", err),
779				Done:  true,
780			}
781
782			a.Publish(pubsub.CreatedEvent, event)
783			return
784		}
785		// Create a message in the new session with the summary
786		msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
787			Role: message.Assistant,
788			Parts: []message.ContentPart{
789				message.TextContent{Text: summary},
790				message.Finish{
791					Reason: message.FinishReasonEndTurn,
792					Time:   time.Now().Unix(),
793				},
794			},
795			Model:    a.summarizeProvider.Model().ID,
796			Provider: a.summarizeProviderID,
797		})
798		if err != nil {
799			event = AgentEvent{
800				Type:  AgentEventTypeError,
801				Error: fmt.Errorf("failed to create summary message: %w", err),
802				Done:  true,
803			}
804
805			a.Publish(pubsub.CreatedEvent, event)
806			return
807		}
808		oldSession.SummaryMessageID = msg.ID
809		oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
810		oldSession.PromptTokens = 0
811		model := a.summarizeProvider.Model()
812		usage := finalResponse.Usage
813		cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
814			model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
815			model.CostPer1MIn/1e6*float64(usage.InputTokens) +
816			model.CostPer1MOut/1e6*float64(usage.OutputTokens)
817		oldSession.Cost += cost
818		_, err = a.sessions.Save(summarizeCtx, oldSession)
819		if err != nil {
820			event = AgentEvent{
821				Type:  AgentEventTypeError,
822				Error: fmt.Errorf("failed to save session: %w", err),
823				Done:  true,
824			}
825			a.Publish(pubsub.CreatedEvent, event)
826		}
827
828		event = AgentEvent{
829			Type:      AgentEventTypeSummarize,
830			SessionID: oldSession.ID,
831			Progress:  "Summary complete",
832			Done:      true,
833		}
834		a.Publish(pubsub.CreatedEvent, event)
835		// Send final success event with the new session ID
836	}()
837
838	return nil
839}
840
841func (a *agent) CancelAll() {
842	if !a.IsBusy() {
843		return
844	}
845	for key := range a.activeRequests.Seq2() {
846		a.Cancel(key) // key is sessionID
847	}
848
849	timeout := time.After(5 * time.Second)
850	for a.IsBusy() {
851		select {
852		case <-timeout:
853			return
854		default:
855			time.Sleep(200 * time.Millisecond)
856		}
857	}
858}
859
860func (a *agent) UpdateModel() error {
861	cfg := config.Get()
862
863	// Get current provider configuration
864	currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
865	if currentProviderCfg == nil || currentProviderCfg.ID == "" {
866		return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
867	}
868
869	// Check if provider has changed
870	if string(currentProviderCfg.ID) != a.providerID {
871		// Provider changed, need to recreate the main provider
872		model := cfg.GetModelByType(a.agentCfg.Model)
873		if model.ID == "" {
874			return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
875		}
876
877		promptID := agentPromptMap[a.agentCfg.ID]
878		if promptID == "" {
879			promptID = prompt.PromptDefault
880		}
881
882		opts := []provider.ProviderClientOption{
883			provider.WithModel(a.agentCfg.Model),
884			provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
885		}
886
887		newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
888		if err != nil {
889			return fmt.Errorf("failed to create new provider: %w", err)
890		}
891
892		// Update the provider and provider ID
893		a.provider = newProvider
894		a.providerID = string(currentProviderCfg.ID)
895	}
896
897	// Check if small model provider has changed (affects title and summarize providers)
898	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
899	var smallModelProviderCfg config.ProviderConfig
900
901	for p := range cfg.Providers.Seq() {
902		if p.ID == smallModelCfg.Provider {
903			smallModelProviderCfg = p
904			break
905		}
906	}
907
908	if smallModelProviderCfg.ID == "" {
909		return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
910	}
911
912	// Check if summarize provider has changed
913	if string(smallModelProviderCfg.ID) != a.summarizeProviderID {
914		smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
915		if smallModel == nil {
916			return fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
917		}
918
919		// Recreate title provider
920		titleOpts := []provider.ProviderClientOption{
921			provider.WithModel(config.SelectedModelTypeSmall),
922			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
923			// We want the title to be short, so we limit the max tokens
924			provider.WithMaxTokens(40),
925		}
926		newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
927		if err != nil {
928			return fmt.Errorf("failed to create new title provider: %w", err)
929		}
930
931		// Recreate summarize provider
932		summarizeOpts := []provider.ProviderClientOption{
933			provider.WithModel(config.SelectedModelTypeSmall),
934			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
935		}
936		newSummarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
937		if err != nil {
938			return fmt.Errorf("failed to create new summarize provider: %w", err)
939		}
940
941		// Update the providers and provider ID
942		a.titleProvider = newTitleProvider
943		a.summarizeProvider = newSummarizeProvider
944		a.summarizeProviderID = string(smallModelProviderCfg.ID)
945	}
946
947	return nil
948}