agent.go

  1package agent
  2
  3import (
  4	"context"
  5	"errors"
  6	"fmt"
  7	"log/slog"
  8	"strings"
  9	"time"
 10
 11	"github.com/charmbracelet/catwalk/pkg/catwalk"
 12	"github.com/charmbracelet/crush/internal/csync"
 13	"github.com/charmbracelet/crush/internal/llm/prompt"
 14	"github.com/charmbracelet/crush/internal/llm/provider"
 15	"github.com/charmbracelet/crush/internal/llm/tools"
 16	"github.com/charmbracelet/crush/internal/log"
 17	"github.com/charmbracelet/crush/internal/message"
 18	"github.com/charmbracelet/crush/internal/permission"
 19	"github.com/charmbracelet/crush/internal/pubsub"
 20	"github.com/charmbracelet/crush/internal/session"
 21	"github.com/charmbracelet/crush/internal/shell"
 22)
 23
 24// Common errors
 25var (
 26	ErrRequestCancelled = errors.New("request canceled by user")
 27	ErrSessionBusy      = errors.New("session is currently processing another request")
 28)
 29
 30type AgentEventType string
 31
 32const (
 33	AgentEventTypeError     AgentEventType = "error"
 34	AgentEventTypeResponse  AgentEventType = "response"
 35	AgentEventTypeSummarize AgentEventType = "summarize"
 36)
 37
 38type AgentEvent struct {
 39	Type    AgentEventType
 40	Message message.Message
 41	Error   error
 42
 43	// When summarizing
 44	SessionID string
 45	Progress  string
 46	Done      bool
 47}
 48
 49type Service interface {
 50	pubsub.Suscriber[AgentEvent]
 51	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
 52	Cancel(sessionID string)
 53	CancelAll()
 54	IsSessionBusy(sessionID string) bool
 55	IsBusy() bool
 56	Summarize(ctx context.Context, sessionID string) error
 57	SetDebug(debug bool)
 58	UpdateModels(large, small Model) error
 59	// for now, not really sure how to handle this better
 60	WithAgentTool() error
 61
 62	ModelConfig() Model
 63	Model() *catwalk.Model
 64	Provider() *provider.Config
 65}
 66
 67type Model struct {
 68	// The model id as used by the provider API.
 69	// Required.
 70	Model string `json:"model"`
 71	// The model provider, same as the key/id used in the providers config.
 72	// Required.
 73	Provider string `json:"provider"`
 74
 75	// Only used by models that use the openai provider and need this set.
 76	ReasoningEffort string `json:"reasoning_effort,omitempty"`
 77
 78	// Overrides the default model configuration.
 79	MaxTokens int64 `json:"max_tokens,omitempty"`
 80
 81	// Used by anthropic models that can reason to indicate if the model should think.
 82	Think bool `json:"think,omitempty"`
 83}
 84
 85type agent struct {
 86	*pubsub.Broker[AgentEvent]
 87	ctx          context.Context
 88	cwd          string
 89	systemPrompt string
 90	providers    map[string]provider.Config
 91
 92	sessions session.Service
 93	messages message.Service
 94
 95	toolsRegistry tools.Registry
 96
 97	large, small      Model
 98	provider          provider.Provider
 99	titleProvider     provider.Provider
100	summarizeProvider provider.Provider
101
102	activeRequests *csync.Map[string, context.CancelFunc]
103
104	debug bool
105}
106
107func NewAgent(
108	ctx context.Context,
109	cwd string,
110	systemPrompt string,
111	toolsRegistry tools.Registry,
112	providers map[string]provider.Config,
113
114	smallModel Model,
115	largeModel Model,
116
117	sessions session.Service,
118	messages message.Service,
119) (Service, error) {
120	agent := &agent{
121		Broker:         pubsub.NewBroker[AgentEvent](),
122		ctx:            ctx,
123		providers:      providers,
124		cwd:            cwd,
125		systemPrompt:   systemPrompt,
126		toolsRegistry:  toolsRegistry,
127		small:          smallModel,
128		large:          largeModel,
129		messages:       messages,
130		sessions:       sessions,
131		activeRequests: csync.NewMap[string, context.CancelFunc](),
132	}
133
134	err := agent.setProviders()
135	return agent, err
136}
137
138func (a *agent) ModelConfig() Model {
139	return a.large
140}
141
142func (a *agent) Model() *catwalk.Model {
143	return a.provider.Model(a.large.Model)
144}
145
146func (a *agent) Provider() *provider.Config {
147	for _, provider := range a.providers {
148		if provider.ID == a.large.Provider {
149			return &provider
150		}
151	}
152	return nil
153}
154
155func (a *agent) Cancel(sessionID string) {
156	// Cancel regular requests
157	if cancel, exists := a.activeRequests.Take(sessionID); exists {
158		slog.Info("Request cancellation initiated", "session_id", sessionID)
159		cancel()
160	}
161
162	// Also check for summarize requests
163	if cancel, exists := a.activeRequests.Take(sessionID + "-summarize"); exists {
164		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
165		cancel()
166	}
167}
168
169func (a *agent) IsBusy() bool {
170	busy := false
171	for cancelFunc := range a.activeRequests.Seq() {
172		if cancelFunc != nil {
173			busy = true
174			break
175		}
176	}
177	return busy
178}
179
180func (a *agent) IsSessionBusy(sessionID string) bool {
181	_, busy := a.activeRequests.Get(sessionID)
182	return busy
183}
184
185func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
186	if content == "" {
187		return nil
188	}
189	if a.titleProvider == nil {
190		return nil
191	}
192	session, err := a.sessions.Get(ctx, sessionID)
193	if err != nil {
194		return err
195	}
196	parts := []message.ContentPart{message.TextContent{
197		Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
198	}}
199
200	response := a.titleProvider.Stream(
201		ctx,
202		a.small.Model,
203		[]message.Message{
204			{
205				Role:  message.User,
206				Parts: parts,
207			},
208		},
209		nil,
210	)
211
212	var finalResponse *provider.ProviderResponse
213	for r := range response {
214		if r.Error != nil {
215			return r.Error
216		}
217		finalResponse = r.Response
218	}
219
220	if finalResponse == nil {
221		return fmt.Errorf("no response received from title provider")
222	}
223
224	title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
225	if title == "" {
226		return nil
227	}
228
229	session.Title = title
230	_, err = a.sessions.Save(ctx, session)
231	return err
232}
233
234func (a *agent) err(err error) AgentEvent {
235	return AgentEvent{
236		Type:  AgentEventTypeError,
237		Error: err,
238	}
239}
240
241func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
242	if !a.Model().SupportsImages && attachments != nil {
243		attachments = nil
244	}
245	events := make(chan AgentEvent)
246	if a.IsSessionBusy(sessionID) {
247		return nil, ErrSessionBusy
248	}
249
250	genCtx, cancel := context.WithCancel(ctx)
251
252	a.activeRequests.Set(sessionID, cancel)
253	go func() {
254		slog.Debug("Request started", "sessionID", sessionID)
255		defer log.RecoverPanic("agent.Run", func() {
256			events <- a.err(fmt.Errorf("panic while running the agent"))
257		})
258		var attachmentParts []message.ContentPart
259		for _, attachment := range attachments {
260			attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
261		}
262		result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
263		if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
264			slog.Error(result.Error.Error())
265		}
266		slog.Debug("Request completed", "sessionID", sessionID)
267		a.activeRequests.Del(sessionID)
268		cancel()
269		a.Publish(pubsub.CreatedEvent, result)
270		events <- result
271		close(events)
272	}()
273	return events, nil
274}
275
276func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
277	// List existing messages; if none, start title generation asynchronously.
278	msgs, err := a.messages.List(ctx, sessionID)
279	if err != nil {
280		return a.err(fmt.Errorf("failed to list messages: %w", err))
281	}
282	if len(msgs) == 0 {
283		go func() {
284			defer log.RecoverPanic("agent.Run", func() {
285				slog.Error("panic while generating title")
286			})
287			titleErr := a.generateTitle(context.Background(), sessionID, content)
288			if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
289				slog.Error("failed to generate title", "error", titleErr)
290			}
291		}()
292	}
293	session, err := a.sessions.Get(ctx, sessionID)
294	if err != nil {
295		return a.err(fmt.Errorf("failed to get session: %w", err))
296	}
297	if session.SummaryMessageID != "" {
298		summaryMsgInex := -1
299		for i, msg := range msgs {
300			if msg.ID == session.SummaryMessageID {
301				summaryMsgInex = i
302				break
303			}
304		}
305		if summaryMsgInex != -1 {
306			msgs = msgs[summaryMsgInex:]
307			msgs[0].Role = message.User
308		}
309	}
310
311	userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
312	if err != nil {
313		return a.err(fmt.Errorf("failed to create user message: %w", err))
314	}
315	// Append the new user message to the conversation history.
316	msgHistory := append(msgs, userMsg)
317
318	for {
319		// Check for cancellation before each iteration
320		select {
321		case <-ctx.Done():
322			return a.err(ctx.Err())
323		default:
324			// Continue processing
325		}
326		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
327		if err != nil {
328			if errors.Is(err, context.Canceled) {
329				agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
330				a.messages.Update(context.Background(), agentMessage)
331				return a.err(ErrRequestCancelled)
332			}
333			return a.err(fmt.Errorf("failed to process events: %w", err))
334		}
335		if a.debug {
336			slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
337		}
338		if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
339			// We are not done, we need to respond with the tool response
340			msgHistory = append(msgHistory, agentMessage, *toolResults)
341			continue
342		}
343		return AgentEvent{
344			Type:    AgentEventTypeResponse,
345			Message: agentMessage,
346			Done:    true,
347		}
348	}
349}
350
351func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
352	parts := []message.ContentPart{message.TextContent{Text: content}}
353	parts = append(parts, attachmentParts...)
354	return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
355		Role:  message.User,
356		Parts: parts,
357	})
358}
359
360func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
361	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
362	eventChan := a.provider.Stream(ctx, a.large.Model, msgHistory, a.toolsRegistry.GetAllTools())
363
364	assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
365		Role:     message.Assistant,
366		Parts:    []message.ContentPart{},
367		Model:    a.large.Model,
368		Provider: a.large.Provider,
369	})
370	if err != nil {
371		return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
372	}
373
374	// Add the session and message ID into the context if needed by tools.
375	ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
376
377	// Process each event in the stream.
378	for event := range eventChan {
379		if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
380			if errors.Is(processErr, context.Canceled) {
381				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
382			} else {
383				a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
384			}
385			return assistantMsg, nil, processErr
386		}
387		if ctx.Err() != nil {
388			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
389			return assistantMsg, nil, ctx.Err()
390		}
391	}
392
393	toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
394	toolCalls := assistantMsg.ToolCalls()
395	for i, toolCall := range toolCalls {
396		select {
397		case <-ctx.Done():
398			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
399			// Make all future tool calls cancelled
400			for j := i; j < len(toolCalls); j++ {
401				toolResults[j] = message.ToolResult{
402					ToolCallID: toolCalls[j].ID,
403					Content:    "Tool execution canceled by user",
404					IsError:    true,
405				}
406			}
407			goto out
408		default:
409			// Continue processing
410			var tool tools.BaseTool
411			for _, availableTool := range a.toolsRegistry.GetAllTools() {
412				if availableTool.Info().Name == toolCall.Name {
413					tool = availableTool
414					break
415				}
416			}
417
418			// Tool not found
419			if tool == nil {
420				toolResults[i] = message.ToolResult{
421					ToolCallID: toolCall.ID,
422					Content:    fmt.Sprintf("Tool not found: %s", toolCall.Name),
423					IsError:    true,
424				}
425				continue
426			}
427
428			// Run tool in goroutine to allow cancellation
429			type toolExecResult struct {
430				response tools.ToolResponse
431				err      error
432			}
433			resultChan := make(chan toolExecResult, 1)
434
435			go func() {
436				response, err := tool.Run(ctx, tools.ToolCall{
437					ID:    toolCall.ID,
438					Name:  toolCall.Name,
439					Input: toolCall.Input,
440				})
441				resultChan <- toolExecResult{response: response, err: err}
442			}()
443
444			var toolResponse tools.ToolResponse
445			var toolErr error
446
447			select {
448			case <-ctx.Done():
449				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
450				// Mark remaining tool calls as cancelled
451				for j := i; j < len(toolCalls); j++ {
452					toolResults[j] = message.ToolResult{
453						ToolCallID: toolCalls[j].ID,
454						Content:    "Tool execution canceled by user",
455						IsError:    true,
456					}
457				}
458				goto out
459			case result := <-resultChan:
460				toolResponse = result.response
461				toolErr = result.err
462			}
463
464			if toolErr != nil {
465				slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
466				if errors.Is(toolErr, permission.ErrorPermissionDenied) {
467					toolResults[i] = message.ToolResult{
468						ToolCallID: toolCall.ID,
469						Content:    "Permission denied",
470						IsError:    true,
471					}
472					for j := i + 1; j < len(toolCalls); j++ {
473						toolResults[j] = message.ToolResult{
474							ToolCallID: toolCalls[j].ID,
475							Content:    "Tool execution canceled by user",
476							IsError:    true,
477						}
478					}
479					a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
480					break
481				}
482			}
483			toolResults[i] = message.ToolResult{
484				ToolCallID: toolCall.ID,
485				Content:    toolResponse.Content,
486				Metadata:   toolResponse.Metadata,
487				IsError:    toolResponse.IsError,
488			}
489		}
490	}
491out:
492	if len(toolResults) == 0 {
493		return assistantMsg, nil, nil
494	}
495	parts := make([]message.ContentPart, 0)
496	for _, tr := range toolResults {
497		parts = append(parts, tr)
498	}
499	msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
500		Role:     message.Tool,
501		Parts:    parts,
502		Model:    a.large.Model,
503		Provider: a.large.Provider,
504	})
505	if err != nil {
506		return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
507	}
508
509	return assistantMsg, &msg, err
510}
511
512func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
513	msg.AddFinish(finishReason, message, details)
514	_ = a.messages.Update(ctx, *msg)
515}
516
517func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
518	select {
519	case <-ctx.Done():
520		return ctx.Err()
521	default:
522		// Continue processing.
523	}
524
525	switch event.Type {
526	case provider.EventThinkingDelta:
527		assistantMsg.AppendReasoningContent(event.Thinking)
528		return a.messages.Update(ctx, *assistantMsg)
529	case provider.EventSignatureDelta:
530		assistantMsg.AppendReasoningSignature(event.Signature)
531		return a.messages.Update(ctx, *assistantMsg)
532	case provider.EventContentDelta:
533		assistantMsg.FinishThinking()
534		assistantMsg.AppendContent(event.Content)
535		return a.messages.Update(ctx, *assistantMsg)
536	case provider.EventToolUseStart:
537		assistantMsg.FinishThinking()
538		slog.Info("Tool call started", "toolCall", event.ToolCall)
539		assistantMsg.AddToolCall(*event.ToolCall)
540		return a.messages.Update(ctx, *assistantMsg)
541	case provider.EventToolUseDelta:
542		assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
543		return a.messages.Update(ctx, *assistantMsg)
544	case provider.EventToolUseStop:
545		slog.Info("Finished tool call", "toolCall", event.ToolCall)
546		assistantMsg.FinishToolCall(event.ToolCall.ID)
547		return a.messages.Update(ctx, *assistantMsg)
548	case provider.EventError:
549		return event.Error
550	case provider.EventComplete:
551		assistantMsg.FinishThinking()
552		assistantMsg.SetToolCalls(event.Response.ToolCalls)
553		assistantMsg.AddFinish(event.Response.FinishReason, "", "")
554		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
555			return fmt.Errorf("failed to update message: %w", err)
556		}
557		model := a.Model()
558		if model == nil {
559			return nil
560		}
561		return a.TrackUsage(ctx, sessionID, *model, event.Response.Usage)
562	}
563
564	return nil
565}
566
567func (a *agent) TrackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
568	sess, err := a.sessions.Get(ctx, sessionID)
569	if err != nil {
570		return fmt.Errorf("failed to get session: %w", err)
571	}
572
573	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
574		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
575		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
576		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
577
578	sess.Cost += cost
579	sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
580	sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
581
582	_, err = a.sessions.Save(ctx, sess)
583	if err != nil {
584		return fmt.Errorf("failed to save session: %w", err)
585	}
586	return nil
587}
588
589func (a *agent) Summarize(ctx context.Context, sessionID string) error {
590	if a.summarizeProvider == nil {
591		return fmt.Errorf("summarize provider not available")
592	}
593
594	// Check if session is busy
595	if a.IsSessionBusy(sessionID) {
596		return ErrSessionBusy
597	}
598
599	// Create a new context with cancellation
600	summarizeCtx, cancel := context.WithCancel(ctx)
601
602	// Store the cancel function in activeRequests to allow cancellation
603	a.activeRequests.Set(sessionID+"-summarize", cancel)
604
605	go func() {
606		defer a.activeRequests.Del(sessionID + "-summarize")
607		defer cancel()
608		event := AgentEvent{
609			Type:     AgentEventTypeSummarize,
610			Progress: "Starting summarization...",
611		}
612
613		a.Publish(pubsub.CreatedEvent, event)
614		// Get all messages from the session
615		msgs, err := a.messages.List(summarizeCtx, sessionID)
616		if err != nil {
617			event = AgentEvent{
618				Type:  AgentEventTypeError,
619				Error: fmt.Errorf("failed to list messages: %w", err),
620				Done:  true,
621			}
622			a.Publish(pubsub.CreatedEvent, event)
623			return
624		}
625		summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
626
627		if len(msgs) == 0 {
628			event = AgentEvent{
629				Type:  AgentEventTypeError,
630				Error: fmt.Errorf("no messages to summarize"),
631				Done:  true,
632			}
633			a.Publish(pubsub.CreatedEvent, event)
634			return
635		}
636
637		event = AgentEvent{
638			Type:     AgentEventTypeSummarize,
639			Progress: "Analyzing conversation...",
640		}
641		a.Publish(pubsub.CreatedEvent, event)
642
643		// Add a system message to guide the summarization
644		summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
645
646		// Create a new message with the summarize prompt
647		promptMsg := message.Message{
648			Role:  message.User,
649			Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
650		}
651
652		// Append the prompt to the messages
653		msgsWithPrompt := append(msgs, promptMsg)
654
655		event = AgentEvent{
656			Type:     AgentEventTypeSummarize,
657			Progress: "Generating summary...",
658		}
659
660		a.Publish(pubsub.CreatedEvent, event)
661
662		// Send the messages to the summarize provider
663		response := a.summarizeProvider.Stream(
664			summarizeCtx,
665			a.large.Model,
666			msgsWithPrompt,
667			nil,
668		)
669		var finalResponse *provider.ProviderResponse
670		for r := range response {
671			if r.Error != nil {
672				event = AgentEvent{
673					Type:  AgentEventTypeError,
674					Error: fmt.Errorf("failed to summarize: %w", err),
675					Done:  true,
676				}
677				a.Publish(pubsub.CreatedEvent, event)
678				return
679			}
680			finalResponse = r.Response
681		}
682
683		summary := strings.TrimSpace(finalResponse.Content)
684		if summary == "" {
685			event = AgentEvent{
686				Type:  AgentEventTypeError,
687				Error: fmt.Errorf("empty summary returned"),
688				Done:  true,
689			}
690			a.Publish(pubsub.CreatedEvent, event)
691			return
692		}
693		shell := shell.GetPersistentShell(a.cwd)
694		summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
695		event = AgentEvent{
696			Type:     AgentEventTypeSummarize,
697			Progress: "Creating new session...",
698		}
699
700		a.Publish(pubsub.CreatedEvent, event)
701		oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
702		if err != nil {
703			event = AgentEvent{
704				Type:  AgentEventTypeError,
705				Error: fmt.Errorf("failed to get session: %w", err),
706				Done:  true,
707			}
708
709			a.Publish(pubsub.CreatedEvent, event)
710			return
711		}
712		// Create a message in the new session with the summary
713		msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
714			Role: message.Assistant,
715			Parts: []message.ContentPart{
716				message.TextContent{Text: summary},
717				message.Finish{
718					Reason: message.FinishReasonEndTurn,
719					Time:   time.Now().Unix(),
720				},
721			},
722			Model:    a.large.Model,
723			Provider: a.large.Provider,
724		})
725		if err != nil {
726			event = AgentEvent{
727				Type:  AgentEventTypeError,
728				Error: fmt.Errorf("failed to create summary message: %w", err),
729				Done:  true,
730			}
731
732			a.Publish(pubsub.CreatedEvent, event)
733			return
734		}
735		oldSession.SummaryMessageID = msg.ID
736		oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
737		oldSession.PromptTokens = 0
738		model := a.summarizeProvider.Model(a.large.Model)
739		usage := finalResponse.Usage
740		cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
741			model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
742			model.CostPer1MIn/1e6*float64(usage.InputTokens) +
743			model.CostPer1MOut/1e6*float64(usage.OutputTokens)
744		oldSession.Cost += cost
745		_, err = a.sessions.Save(summarizeCtx, oldSession)
746		if err != nil {
747			event = AgentEvent{
748				Type:  AgentEventTypeError,
749				Error: fmt.Errorf("failed to save session: %w", err),
750				Done:  true,
751			}
752			a.Publish(pubsub.CreatedEvent, event)
753		}
754
755		event = AgentEvent{
756			Type:      AgentEventTypeSummarize,
757			SessionID: oldSession.ID,
758			Progress:  "Summary complete",
759			Done:      true,
760		}
761		a.Publish(pubsub.CreatedEvent, event)
762		// Send final success event with the new session ID
763	}()
764
765	return nil
766}
767
768func (a *agent) CancelAll() {
769	if !a.IsBusy() {
770		return
771	}
772	for key := range a.activeRequests.Seq2() {
773		a.Cancel(key) // key is sessionID
774	}
775
776	timeout := time.After(5 * time.Second)
777	for a.IsBusy() {
778		select {
779		case <-timeout:
780			return
781		default:
782			time.Sleep(200 * time.Millisecond)
783		}
784	}
785}
786
787func (a *agent) UpdateModels(small, large Model) error {
788	a.small = small
789	a.large = large
790	return a.setProviders()
791}
792
793func (a *agent) SetDebug(debug bool) {
794	a.debug = debug
795	if a.provider != nil {
796		a.provider.SetDebug(debug)
797	}
798	if a.titleProvider != nil {
799		a.titleProvider.SetDebug(debug)
800	}
801	if a.summarizeProvider != nil {
802		a.summarizeProvider.SetDebug(debug)
803	}
804}
805
806func (a *agent) setProviders() error {
807	opts := []provider.Option{
808		provider.WithSystemMessage(a.systemPrompt),
809		provider.WithThinking(a.large.Think),
810	}
811
812	if a.large.MaxTokens > 0 {
813		opts = append(opts, provider.WithMaxTokens(a.large.MaxTokens))
814	}
815	if a.large.ReasoningEffort != "" {
816		opts = append(opts, provider.WithReasoningEffort(a.large.ReasoningEffort))
817	}
818
819	providerCfg, ok := a.providers[a.large.Provider]
820	if !ok {
821		return fmt.Errorf("provider %s not found in config", a.large.Provider)
822	}
823	var err error
824	a.provider, err = provider.NewProvider(providerCfg, opts...)
825	if err != nil {
826		return fmt.Errorf("failed to create provider: %w", err)
827	}
828
829	titleOpts := []provider.Option{
830		provider.WithSystemMessage(prompt.TitlePrompt()),
831		provider.WithMaxTokens(40),
832	}
833
834	titleProviderCfg, ok := a.providers[a.small.Provider]
835	if !ok {
836		return fmt.Errorf("small model provider %s not found in config", a.small.Provider)
837	}
838
839	a.titleProvider, err = provider.NewProvider(titleProviderCfg, titleOpts...)
840	if err != nil {
841		return err
842	}
843	summarizeOpts := []provider.Option{
844		provider.WithSystemMessage(prompt.SummarizerPrompt()),
845	}
846	a.summarizeProvider, err = provider.NewProvider(providerCfg, summarizeOpts...)
847	if err != nil {
848		return err
849	}
850
851	if _, ok := a.toolsRegistry.GetTool(AgentToolName); ok {
852		// reset the agent tool
853		a.WithAgentTool()
854	}
855
856	a.SetDebug(a.debug)
857	return nil
858}
859
860func (a *agent) WithAgentTool() error {
861	agent, err := NewAgent(
862		a.ctx,
863		a.cwd,
864		prompt.TaskPrompt(a.cwd),
865		NewTaskTools(a.cwd),
866		a.providers,
867		a.small,
868		a.large,
869		a.sessions,
870		a.messages,
871	)
872	if err != nil {
873		return err
874	}
875
876	agentTool := NewAgentTool(
877		agent,
878		a.sessions,
879		a.messages,
880	)
881
882	a.toolsRegistry.SetTool(AgentToolName, agentTool)
883	return nil
884}