agent.go

  1package agent
  2
  3import (
  4	"context"
  5	"errors"
  6	"fmt"
  7	"log/slog"
  8	"slices"
  9	"strings"
 10	"time"
 11
 12	"github.com/charmbracelet/catwalk/pkg/catwalk"
 13	"github.com/charmbracelet/crush/internal/config"
 14	"github.com/charmbracelet/crush/internal/csync"
 15	"github.com/charmbracelet/crush/internal/history"
 16	"github.com/charmbracelet/crush/internal/llm/prompt"
 17	"github.com/charmbracelet/crush/internal/llm/provider"
 18	"github.com/charmbracelet/crush/internal/llm/tools"
 19	"github.com/charmbracelet/crush/internal/log"
 20	"github.com/charmbracelet/crush/internal/lsp"
 21	"github.com/charmbracelet/crush/internal/message"
 22	"github.com/charmbracelet/crush/internal/permission"
 23	"github.com/charmbracelet/crush/internal/pubsub"
 24	"github.com/charmbracelet/crush/internal/session"
 25	"github.com/charmbracelet/crush/internal/shell"
 26)
 27
 28// Common errors
 29var (
 30	ErrRequestCancelled = errors.New("request canceled by user")
 31	ErrSessionBusy      = errors.New("session is currently processing another request")
 32)
 33
 34type AgentEventType string
 35
 36const (
 37	AgentEventTypeError     AgentEventType = "error"
 38	AgentEventTypeResponse  AgentEventType = "response"
 39	AgentEventTypeSummarize AgentEventType = "summarize"
 40)
 41
// AgentEvent is the value the agent publishes to subscribers and, for Run,
// also delivers on the channel Run returns.
type AgentEvent struct {
	// Type says which of the fields below are meaningful.
	Type    AgentEventType
	// Message is the assistant message of a finished run
	// (AgentEventTypeResponse events).
	Message message.Message
	// Error is the failure carried by AgentEventTypeError events.
	Error   error

	// SessionID and Progress are filled in by summarization events. Done marks
	// the terminal event of a run or a summarization (including error events).
	SessionID string
	Progress  string
	Done      bool
}
 52
 53type Service interface {
 54	pubsub.Suscriber[AgentEvent]
 55	Model() catwalk.Model
 56	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
 57	Cancel(sessionID string)
 58	CancelAll()
 59	IsSessionBusy(sessionID string) bool
 60	IsBusy() bool
 61	Summarize(ctx context.Context, sessionID string) error
 62	UpdateModel() error
 63}
 64
// agent is the Service implementation returned by NewAgent.
type agent struct {
	// Broker publishes AgentEvent values to subscribers.
	*pubsub.Broker[AgentEvent]
	agentCfg config.Agent
	sessions session.Service
	messages message.Service

	// tools is built from toolFn in NewAgent via csync.NewLazySlice
	// (presumably evaluated on first access — verify in csync).
	tools *csync.LazySlice[tools.BaseTool]

	// provider serves the agent's main model; providerID is the ID of the
	// provider config it was built from, compared in UpdateModel to detect
	// changes.
	provider   provider.Provider
	providerID string

	// Providers for the small model: one generates session titles, the other
	// summaries. summarizeProviderID is the ID of their shared provider config.
	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests maps a session ID (Run) or "<sessionID>-summarize"
	// (Summarize) to the cancel function of the in-flight request.
	activeRequests *csync.Map[string, context.CancelFunc]
}
 82
 83var agentPromptMap = map[string]prompt.PromptID{
 84	"coder": prompt.PromptCoder,
 85	"task":  prompt.PromptTask,
 86}
 87
 88func NewAgent(
 89	agentCfg config.Agent,
 90	// These services are needed in the tools
 91	permissions permission.Service,
 92	sessions session.Service,
 93	messages message.Service,
 94	history history.Service,
 95	lspClients map[string]*lsp.Client,
 96) (Service, error) {
 97	ctx := context.Background()
 98	cfg := config.Get()
 99
100	var agentTool tools.BaseTool
101	if agentCfg.ID == "coder" {
102		taskAgentCfg := config.Get().Agents["task"]
103		if taskAgentCfg.ID == "" {
104			return nil, fmt.Errorf("task agent not found in config")
105		}
106		taskAgent, err := NewAgent(taskAgentCfg, permissions, sessions, messages, history, lspClients)
107		if err != nil {
108			return nil, fmt.Errorf("failed to create task agent: %w", err)
109		}
110
111		agentTool = NewAgentTool(taskAgent, sessions, messages)
112	}
113
114	providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
115	if providerCfg == nil {
116		return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
117	}
118	model := config.Get().GetModelByType(agentCfg.Model)
119
120	if model == nil {
121		return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
122	}
123
124	promptID := agentPromptMap[agentCfg.ID]
125	if promptID == "" {
126		promptID = prompt.PromptDefault
127	}
128	opts := []provider.ProviderClientOption{
129		provider.WithModel(agentCfg.Model),
130		provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
131	}
132	agentProvider, err := provider.NewProvider(*providerCfg, opts...)
133	if err != nil {
134		return nil, err
135	}
136
137	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
138	var smallModelProviderCfg *config.ProviderConfig
139	if smallModelCfg.Provider == providerCfg.ID {
140		smallModelProviderCfg = providerCfg
141	} else {
142		smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
143
144		if smallModelProviderCfg.ID == "" {
145			return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
146		}
147	}
148	smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
149	if smallModel.ID == "" {
150		return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
151	}
152
153	titleOpts := []provider.ProviderClientOption{
154		provider.WithModel(config.SelectedModelTypeSmall),
155		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
156	}
157	titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
158	if err != nil {
159		return nil, err
160	}
161	summarizeOpts := []provider.ProviderClientOption{
162		provider.WithModel(config.SelectedModelTypeSmall),
163		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
164	}
165	summarizeProvider, err := provider.NewProvider(*smallModelProviderCfg, summarizeOpts...)
166	if err != nil {
167		return nil, err
168	}
169
170	toolFn := func() []tools.BaseTool {
171		slog.Info("Initializing agent tools", "agent", agentCfg.ID)
172		defer func() {
173			slog.Info("Initialized agent tools", "agent", agentCfg.ID)
174		}()
175
176		cwd := cfg.WorkingDir()
177		allTools := []tools.BaseTool{
178			tools.NewBashTool(permissions, cwd),
179			tools.NewDownloadTool(permissions, cwd),
180			tools.NewEditTool(lspClients, permissions, history, cwd),
181			tools.NewMultiEditTool(lspClients, permissions, history, cwd),
182			tools.NewFetchTool(permissions, cwd),
183			tools.NewGlobTool(cwd),
184			tools.NewGrepTool(cwd),
185			tools.NewLsTool(cwd),
186			tools.NewSourcegraphTool(),
187			tools.NewViewTool(lspClients, cwd),
188			tools.NewWriteTool(lspClients, permissions, history, cwd),
189		}
190
191		mcpTools := GetMCPTools(ctx, permissions, cfg)
192		allTools = append(allTools, mcpTools...)
193
194		if len(lspClients) > 0 {
195			allTools = append(allTools, tools.NewDiagnosticsTool(lspClients))
196		}
197
198		if agentTool != nil {
199			allTools = append(allTools, agentTool)
200		}
201
202		if agentCfg.AllowedTools == nil {
203			return allTools
204		}
205
206		var filteredTools []tools.BaseTool
207		for _, tool := range allTools {
208			if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
209				filteredTools = append(filteredTools, tool)
210			}
211		}
212		return filteredTools
213	}
214
215	return &agent{
216		Broker:              pubsub.NewBroker[AgentEvent](),
217		agentCfg:            agentCfg,
218		provider:            agentProvider,
219		providerID:          string(providerCfg.ID),
220		messages:            messages,
221		sessions:            sessions,
222		titleProvider:       titleProvider,
223		summarizeProvider:   summarizeProvider,
224		summarizeProviderID: string(smallModelProviderCfg.ID),
225		activeRequests:      csync.NewMap[string, context.CancelFunc](),
226		tools:               csync.NewLazySlice(toolFn),
227	}, nil
228}
229
230func (a *agent) Model() catwalk.Model {
231	return *config.Get().GetModelByType(a.agentCfg.Model)
232}
233
234func (a *agent) Cancel(sessionID string) {
235	// Cancel regular requests
236	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
237		slog.Info("Request cancellation initiated", "session_id", sessionID)
238		cancel()
239	}
240
241	// Also check for summarize requests
242	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
243		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
244		cancel()
245	}
246}
247
248func (a *agent) IsBusy() bool {
249	var busy bool
250	for cancelFunc := range a.activeRequests.Seq() {
251		if cancelFunc != nil {
252			busy = true
253			break
254		}
255	}
256	return busy
257}
258
259func (a *agent) IsSessionBusy(sessionID string) bool {
260	_, busy := a.activeRequests.Get(sessionID)
261	return busy
262}
263
264func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
265	if content == "" {
266		return nil
267	}
268	if a.titleProvider == nil {
269		return nil
270	}
271	session, err := a.sessions.Get(ctx, sessionID)
272	if err != nil {
273		return err
274	}
275	parts := []message.ContentPart{message.TextContent{
276		Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
277	}}
278
279	// Use streaming approach like summarization
280	response := a.titleProvider.StreamResponse(
281		ctx,
282		[]message.Message{
283			{
284				Role:  message.User,
285				Parts: parts,
286			},
287		},
288		nil,
289	)
290
291	var finalResponse *provider.ProviderResponse
292	for r := range response {
293		if r.Error != nil {
294			return r.Error
295		}
296		finalResponse = r.Response
297	}
298
299	if finalResponse == nil {
300		return fmt.Errorf("no response received from title provider")
301	}
302
303	title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
304	if title == "" {
305		return nil
306	}
307
308	session.Title = title
309	_, err = a.sessions.Save(ctx, session)
310	return err
311}
312
313func (a *agent) err(err error) AgentEvent {
314	return AgentEvent{
315		Type:  AgentEventTypeError,
316		Error: err,
317	}
318}
319
320func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
321	if !a.Model().SupportsImages && attachments != nil {
322		attachments = nil
323	}
324	events := make(chan AgentEvent)
325	if a.IsSessionBusy(sessionID) {
326		return nil, ErrSessionBusy
327	}
328
329	genCtx, cancel := context.WithCancel(ctx)
330
331	a.activeRequests.Set(sessionID, cancel)
332	go func() {
333		slog.Debug("Request started", "sessionID", sessionID)
334		defer log.RecoverPanic("agent.Run", func() {
335			events <- a.err(fmt.Errorf("panic while running the agent"))
336		})
337		var attachmentParts []message.ContentPart
338		for _, attachment := range attachments {
339			attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
340		}
341		result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
342		if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
343			slog.Error(result.Error.Error())
344		}
345		slog.Debug("Request completed", "sessionID", sessionID)
346		a.activeRequests.Del(sessionID)
347		cancel()
348		a.Publish(pubsub.CreatedEvent, result)
349		events <- result
350		close(events)
351	}()
352	return events, nil
353}
354
355func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
356	cfg := config.Get()
357	// List existing messages; if none, start title generation asynchronously.
358	msgs, err := a.messages.List(ctx, sessionID)
359	if err != nil {
360		return a.err(fmt.Errorf("failed to list messages: %w", err))
361	}
362	if len(msgs) == 0 {
363		go func() {
364			defer log.RecoverPanic("agent.Run", func() {
365				slog.Error("panic while generating title")
366			})
367			titleErr := a.generateTitle(context.Background(), sessionID, content)
368			if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
369				slog.Error("failed to generate title", "error", titleErr)
370			}
371		}()
372	}
373	session, err := a.sessions.Get(ctx, sessionID)
374	if err != nil {
375		return a.err(fmt.Errorf("failed to get session: %w", err))
376	}
377	if session.SummaryMessageID != "" {
378		summaryMsgInex := -1
379		for i, msg := range msgs {
380			if msg.ID == session.SummaryMessageID {
381				summaryMsgInex = i
382				break
383			}
384		}
385		if summaryMsgInex != -1 {
386			msgs = msgs[summaryMsgInex:]
387			msgs[0].Role = message.User
388		}
389	}
390
391	userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
392	if err != nil {
393		return a.err(fmt.Errorf("failed to create user message: %w", err))
394	}
395	// Append the new user message to the conversation history.
396	msgHistory := append(msgs, userMsg)
397
398	for {
399		// Check for cancellation before each iteration
400		select {
401		case <-ctx.Done():
402			return a.err(ctx.Err())
403		default:
404			// Continue processing
405		}
406		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
407		if err != nil {
408			if errors.Is(err, context.Canceled) {
409				agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
410				a.messages.Update(context.Background(), agentMessage)
411				return a.err(ErrRequestCancelled)
412			}
413			return a.err(fmt.Errorf("failed to process events: %w", err))
414		}
415		if cfg.Options.Debug {
416			slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
417		}
418		if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
419			// We are not done, we need to respond with the tool response
420			msgHistory = append(msgHistory, agentMessage, *toolResults)
421			continue
422		}
423		return AgentEvent{
424			Type:    AgentEventTypeResponse,
425			Message: agentMessage,
426			Done:    true,
427		}
428	}
429}
430
431func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
432	parts := []message.ContentPart{message.TextContent{Text: content}}
433	parts = append(parts, attachmentParts...)
434	return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
435		Role:  message.User,
436		Parts: parts,
437	})
438}
439
440func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
441	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
442	eventChan := a.provider.StreamResponse(ctx, msgHistory, slices.Collect(a.tools.Seq()))
443
444	assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
445		Role:     message.Assistant,
446		Parts:    []message.ContentPart{},
447		Model:    a.Model().ID,
448		Provider: a.providerID,
449	})
450	if err != nil {
451		return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
452	}
453
454	// Add the session and message ID into the context if needed by tools.
455	ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
456
457	// Process each event in the stream.
458	for event := range eventChan {
459		if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
460			if errors.Is(processErr, context.Canceled) {
461				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
462			} else {
463				a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
464			}
465			return assistantMsg, nil, processErr
466		}
467		if ctx.Err() != nil {
468			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
469			return assistantMsg, nil, ctx.Err()
470		}
471	}
472
473	toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
474	toolCalls := assistantMsg.ToolCalls()
475	for i, toolCall := range toolCalls {
476		select {
477		case <-ctx.Done():
478			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
479			// Make all future tool calls cancelled
480			for j := i; j < len(toolCalls); j++ {
481				toolResults[j] = message.ToolResult{
482					ToolCallID: toolCalls[j].ID,
483					Content:    "Tool execution canceled by user",
484					IsError:    true,
485				}
486			}
487			goto out
488		default:
489			// Continue processing
490			var tool tools.BaseTool
491			for availableTool := range a.tools.Seq() {
492				if availableTool.Info().Name == toolCall.Name {
493					tool = availableTool
494					break
495				}
496			}
497
498			// Tool not found
499			if tool == nil {
500				toolResults[i] = message.ToolResult{
501					ToolCallID: toolCall.ID,
502					Content:    fmt.Sprintf("Tool not found: %s", toolCall.Name),
503					IsError:    true,
504				}
505				continue
506			}
507
508			// Run tool in goroutine to allow cancellation
509			type toolExecResult struct {
510				response tools.ToolResponse
511				err      error
512			}
513			resultChan := make(chan toolExecResult, 1)
514
515			go func() {
516				response, err := tool.Run(ctx, tools.ToolCall{
517					ID:    toolCall.ID,
518					Name:  toolCall.Name,
519					Input: toolCall.Input,
520				})
521				resultChan <- toolExecResult{response: response, err: err}
522			}()
523
524			var toolResponse tools.ToolResponse
525			var toolErr error
526
527			select {
528			case <-ctx.Done():
529				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
530				// Mark remaining tool calls as cancelled
531				for j := i; j < len(toolCalls); j++ {
532					toolResults[j] = message.ToolResult{
533						ToolCallID: toolCalls[j].ID,
534						Content:    "Tool execution canceled by user",
535						IsError:    true,
536					}
537				}
538				goto out
539			case result := <-resultChan:
540				toolResponse = result.response
541				toolErr = result.err
542			}
543
544			if toolErr != nil {
545				slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
546				if errors.Is(toolErr, permission.ErrorPermissionDenied) {
547					toolResults[i] = message.ToolResult{
548						ToolCallID: toolCall.ID,
549						Content:    "Permission denied",
550						IsError:    true,
551					}
552					for j := i + 1; j < len(toolCalls); j++ {
553						toolResults[j] = message.ToolResult{
554							ToolCallID: toolCalls[j].ID,
555							Content:    "Tool execution canceled by user",
556							IsError:    true,
557						}
558					}
559					a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
560					break
561				}
562			}
563			toolResults[i] = message.ToolResult{
564				ToolCallID: toolCall.ID,
565				Content:    toolResponse.Content,
566				Metadata:   toolResponse.Metadata,
567				IsError:    toolResponse.IsError,
568			}
569		}
570	}
571out:
572	if len(toolResults) == 0 {
573		return assistantMsg, nil, nil
574	}
575	parts := make([]message.ContentPart, 0)
576	for _, tr := range toolResults {
577		parts = append(parts, tr)
578	}
579	msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
580		Role:     message.Tool,
581		Parts:    parts,
582		Provider: a.providerID,
583	})
584	if err != nil {
585		return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
586	}
587
588	return assistantMsg, &msg, err
589}
590
591func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
592	msg.AddFinish(finishReason, message, details)
593	_ = a.messages.Update(ctx, *msg)
594}
595
596func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
597	select {
598	case <-ctx.Done():
599		return ctx.Err()
600	default:
601		// Continue processing.
602	}
603
604	switch event.Type {
605	case provider.EventThinkingDelta:
606		assistantMsg.AppendReasoningContent(event.Thinking)
607		return a.messages.Update(ctx, *assistantMsg)
608	case provider.EventSignatureDelta:
609		assistantMsg.AppendReasoningSignature(event.Signature)
610		return a.messages.Update(ctx, *assistantMsg)
611	case provider.EventContentDelta:
612		assistantMsg.FinishThinking()
613		assistantMsg.AppendContent(event.Content)
614		return a.messages.Update(ctx, *assistantMsg)
615	case provider.EventToolUseStart:
616		assistantMsg.FinishThinking()
617		slog.Info("Tool call started", "toolCall", event.ToolCall)
618		assistantMsg.AddToolCall(*event.ToolCall)
619		return a.messages.Update(ctx, *assistantMsg)
620	case provider.EventToolUseDelta:
621		assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
622		return a.messages.Update(ctx, *assistantMsg)
623	case provider.EventToolUseStop:
624		slog.Info("Finished tool call", "toolCall", event.ToolCall)
625		assistantMsg.FinishToolCall(event.ToolCall.ID)
626		return a.messages.Update(ctx, *assistantMsg)
627	case provider.EventError:
628		return event.Error
629	case provider.EventComplete:
630		assistantMsg.FinishThinking()
631		assistantMsg.SetToolCalls(event.Response.ToolCalls)
632		assistantMsg.AddFinish(event.Response.FinishReason, "", "")
633		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
634			return fmt.Errorf("failed to update message: %w", err)
635		}
636		return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
637	}
638
639	return nil
640}
641
642func (a *agent) TrackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
643	sess, err := a.sessions.Get(ctx, sessionID)
644	if err != nil {
645		return fmt.Errorf("failed to get session: %w", err)
646	}
647
648	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
649		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
650		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
651		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
652
653	sess.Cost += cost
654	sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
655	sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
656
657	_, err = a.sessions.Save(ctx, sess)
658	if err != nil {
659		return fmt.Errorf("failed to save session: %w", err)
660	}
661	return nil
662}
663
664func (a *agent) Summarize(ctx context.Context, sessionID string) error {
665	if a.summarizeProvider == nil {
666		return fmt.Errorf("summarize provider not available")
667	}
668
669	// Check if session is busy
670	if a.IsSessionBusy(sessionID) {
671		return ErrSessionBusy
672	}
673
674	// Create a new context with cancellation
675	summarizeCtx, cancel := context.WithCancel(ctx)
676
677	// Store the cancel function in activeRequests to allow cancellation
678	a.activeRequests.Set(sessionID+"-summarize", cancel)
679
680	go func() {
681		defer a.activeRequests.Del(sessionID + "-summarize")
682		defer cancel()
683		event := AgentEvent{
684			Type:     AgentEventTypeSummarize,
685			Progress: "Starting summarization...",
686		}
687
688		a.Publish(pubsub.CreatedEvent, event)
689		// Get all messages from the session
690		msgs, err := a.messages.List(summarizeCtx, sessionID)
691		if err != nil {
692			event = AgentEvent{
693				Type:  AgentEventTypeError,
694				Error: fmt.Errorf("failed to list messages: %w", err),
695				Done:  true,
696			}
697			a.Publish(pubsub.CreatedEvent, event)
698			return
699		}
700		summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
701
702		if len(msgs) == 0 {
703			event = AgentEvent{
704				Type:  AgentEventTypeError,
705				Error: fmt.Errorf("no messages to summarize"),
706				Done:  true,
707			}
708			a.Publish(pubsub.CreatedEvent, event)
709			return
710		}
711
712		event = AgentEvent{
713			Type:     AgentEventTypeSummarize,
714			Progress: "Analyzing conversation...",
715		}
716		a.Publish(pubsub.CreatedEvent, event)
717
718		// Add a system message to guide the summarization
719		summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
720
721		// Create a new message with the summarize prompt
722		promptMsg := message.Message{
723			Role:  message.User,
724			Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
725		}
726
727		// Append the prompt to the messages
728		msgsWithPrompt := append(msgs, promptMsg)
729
730		event = AgentEvent{
731			Type:     AgentEventTypeSummarize,
732			Progress: "Generating summary...",
733		}
734
735		a.Publish(pubsub.CreatedEvent, event)
736
737		// Send the messages to the summarize provider
738		response := a.summarizeProvider.StreamResponse(
739			summarizeCtx,
740			msgsWithPrompt,
741			nil,
742		)
743		var finalResponse *provider.ProviderResponse
744		for r := range response {
745			if r.Error != nil {
746				event = AgentEvent{
747					Type:  AgentEventTypeError,
748					Error: fmt.Errorf("failed to summarize: %w", err),
749					Done:  true,
750				}
751				a.Publish(pubsub.CreatedEvent, event)
752				return
753			}
754			finalResponse = r.Response
755		}
756
757		summary := strings.TrimSpace(finalResponse.Content)
758		if summary == "" {
759			event = AgentEvent{
760				Type:  AgentEventTypeError,
761				Error: fmt.Errorf("empty summary returned"),
762				Done:  true,
763			}
764			a.Publish(pubsub.CreatedEvent, event)
765			return
766		}
767		shell := shell.GetPersistentShell(config.Get().WorkingDir())
768		summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
769		event = AgentEvent{
770			Type:     AgentEventTypeSummarize,
771			Progress: "Creating new session...",
772		}
773
774		a.Publish(pubsub.CreatedEvent, event)
775		oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
776		if err != nil {
777			event = AgentEvent{
778				Type:  AgentEventTypeError,
779				Error: fmt.Errorf("failed to get session: %w", err),
780				Done:  true,
781			}
782
783			a.Publish(pubsub.CreatedEvent, event)
784			return
785		}
786		// Create a message in the new session with the summary
787		msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
788			Role: message.Assistant,
789			Parts: []message.ContentPart{
790				message.TextContent{Text: summary},
791				message.Finish{
792					Reason: message.FinishReasonEndTurn,
793					Time:   time.Now().Unix(),
794				},
795			},
796			Model:    a.summarizeProvider.Model().ID,
797			Provider: a.summarizeProviderID,
798		})
799		if err != nil {
800			event = AgentEvent{
801				Type:  AgentEventTypeError,
802				Error: fmt.Errorf("failed to create summary message: %w", err),
803				Done:  true,
804			}
805
806			a.Publish(pubsub.CreatedEvent, event)
807			return
808		}
809		oldSession.SummaryMessageID = msg.ID
810		oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
811		oldSession.PromptTokens = 0
812		model := a.summarizeProvider.Model()
813		usage := finalResponse.Usage
814		cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
815			model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
816			model.CostPer1MIn/1e6*float64(usage.InputTokens) +
817			model.CostPer1MOut/1e6*float64(usage.OutputTokens)
818		oldSession.Cost += cost
819		_, err = a.sessions.Save(summarizeCtx, oldSession)
820		if err != nil {
821			event = AgentEvent{
822				Type:  AgentEventTypeError,
823				Error: fmt.Errorf("failed to save session: %w", err),
824				Done:  true,
825			}
826			a.Publish(pubsub.CreatedEvent, event)
827		}
828
829		event = AgentEvent{
830			Type:      AgentEventTypeSummarize,
831			SessionID: oldSession.ID,
832			Progress:  "Summary complete",
833			Done:      true,
834		}
835		a.Publish(pubsub.CreatedEvent, event)
836		// Send final success event with the new session ID
837	}()
838
839	return nil
840}
841
842func (a *agent) CancelAll() {
843	if !a.IsBusy() {
844		return
845	}
846	for key := range a.activeRequests.Seq2() {
847		a.Cancel(key) // key is sessionID
848	}
849
850	timeout := time.After(5 * time.Second)
851	for a.IsBusy() {
852		select {
853		case <-timeout:
854			return
855		default:
856			time.Sleep(200 * time.Millisecond)
857		}
858	}
859}
860
861func (a *agent) UpdateModel() error {
862	cfg := config.Get()
863
864	// Get current provider configuration
865	currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
866	if currentProviderCfg == nil || currentProviderCfg.ID == "" {
867		return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
868	}
869
870	// Check if provider has changed
871	if string(currentProviderCfg.ID) != a.providerID {
872		// Provider changed, need to recreate the main provider
873		model := cfg.GetModelByType(a.agentCfg.Model)
874		if model.ID == "" {
875			return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
876		}
877
878		promptID := agentPromptMap[a.agentCfg.ID]
879		if promptID == "" {
880			promptID = prompt.PromptDefault
881		}
882
883		opts := []provider.ProviderClientOption{
884			provider.WithModel(a.agentCfg.Model),
885			provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
886		}
887
888		newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
889		if err != nil {
890			return fmt.Errorf("failed to create new provider: %w", err)
891		}
892
893		// Update the provider and provider ID
894		a.provider = newProvider
895		a.providerID = string(currentProviderCfg.ID)
896	}
897
898	// Check if small model provider has changed (affects title and summarize providers)
899	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
900	var smallModelProviderCfg config.ProviderConfig
901
902	for p := range cfg.Providers.Seq() {
903		if p.ID == smallModelCfg.Provider {
904			smallModelProviderCfg = p
905			break
906		}
907	}
908
909	if smallModelProviderCfg.ID == "" {
910		return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
911	}
912
913	// Check if summarize provider has changed
914	if string(smallModelProviderCfg.ID) != a.summarizeProviderID {
915		smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
916		if smallModel == nil {
917			return fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
918		}
919
920		// Recreate title provider
921		titleOpts := []provider.ProviderClientOption{
922			provider.WithModel(config.SelectedModelTypeSmall),
923			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
924			// We want the title to be short, so we limit the max tokens
925			provider.WithMaxTokens(40),
926		}
927		newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
928		if err != nil {
929			return fmt.Errorf("failed to create new title provider: %w", err)
930		}
931
932		// Recreate summarize provider
933		summarizeOpts := []provider.ProviderClientOption{
934			provider.WithModel(config.SelectedModelTypeSmall),
935			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
936		}
937		newSummarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
938		if err != nil {
939			return fmt.Errorf("failed to create new summarize provider: %w", err)
940		}
941
942		// Update the providers and provider ID
943		a.titleProvider = newTitleProvider
944		a.summarizeProvider = newSummarizeProvider
945		a.summarizeProviderID = string(smallModelProviderCfg.ID)
946	}
947
948	return nil
949}