1package agent
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "log/slog"
8 "strings"
9 "time"
10
11 "github.com/charmbracelet/catwalk/pkg/catwalk"
12 "github.com/charmbracelet/crush/internal/csync"
13 "github.com/charmbracelet/crush/internal/llm/prompt"
14 "github.com/charmbracelet/crush/internal/llm/provider"
15 "github.com/charmbracelet/crush/internal/llm/tools"
16 "github.com/charmbracelet/crush/internal/log"
17 "github.com/charmbracelet/crush/internal/message"
18 "github.com/charmbracelet/crush/internal/permission"
19 "github.com/charmbracelet/crush/internal/pubsub"
20 "github.com/charmbracelet/crush/internal/session"
21 "github.com/charmbracelet/crush/internal/shell"
22)
23
// Common errors

// ErrRequestCancelled is reported (inside an error AgentEvent) when a running
// request's context is canceled while the model response is being streamed.
var ErrRequestCancelled = errors.New("request canceled by user")

// ErrSessionBusy is returned by Run and Summarize when the session already
// has a request in flight.
var ErrSessionBusy = errors.New("session is currently processing another request")
29
// AgentEventType identifies what kind of AgentEvent is being delivered.
type AgentEventType string

// The kinds of AgentEvent the agent emits.
const (
	// AgentEventTypeError reports a failure; AgentEvent.Error holds the cause.
	AgentEventTypeError AgentEventType = "error"

	// AgentEventTypeResponse carries the final assistant message of a Run.
	AgentEventTypeResponse AgentEventType = "response"

	// AgentEventTypeSummarize reports progress or completion of Summarize.
	AgentEventTypeSummarize AgentEventType = "summarize"
)
37
// AgentEvent is what the agent publishes to subscribers and delivers on the
// channel returned by Run.
type AgentEvent struct {
	// Type selects which of the remaining fields are meaningful.
	Type AgentEventType
	// Message is the final assistant message of a Run (AgentEventTypeResponse).
	Message message.Message
	// Error is the failure cause for AgentEventTypeError events.
	Error error

	// SessionID and Progress are filled in by summarization events. Done is set
	// on the last event of a summarization and also on the final Run response.
	SessionID string
	Progress  string
	Done      bool
}
48
49type Service interface {
50 pubsub.Suscriber[AgentEvent]
51 Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
52 Cancel(sessionID string)
53 CancelAll()
54 IsSessionBusy(sessionID string) bool
55 IsBusy() bool
56 Summarize(ctx context.Context, sessionID string) error
57 SetDebug(debug bool)
58 UpdateModels(large, small Model) error
59 // for now, not really sure how to handle this better
60 WithAgentTool() error
61
62 ModelConfig() Model
63 Model() *catwalk.Model
64 Provider() *provider.Config
65}
66
// Model selects a model from a configured provider together with optional
// per-model overrides. The agent keeps two of them: a large model for the main
// conversation and a small one for title generation.
// NOTE(review): the agent only applies MaxTokens, ReasoningEffort and Think
// from the large model; the small model's overrides are not used.
type Model struct {
	// The model id as used by the provider API.
	// Required.
	Model string `json:"model"`
	// The model provider, same as the key/id used in the providers config.
	// Required.
	Provider string `json:"provider"`

	// Only used by models that use the openai provider and need this set.
	// Passed to the provider only when non-empty.
	ReasoningEffort string `json:"reasoning_effort,omitempty"`

	// Overrides the default model configuration.
	// Passed to the provider only when greater than zero.
	MaxTokens int64 `json:"max_tokens,omitempty"`

	// Used by anthropic models that can reason to indicate if the model should think.
	Think bool `json:"think,omitempty"`
}
84
// agent is the Service implementation. It embeds a pubsub broker so callers
// can subscribe to the AgentEvents it publishes.
type agent struct {
	*pubsub.Broker[AgentEvent]
	// ctx is the context NewAgent received; in this file it is only reused to
	// build the sub-agent in WithAgentTool.
	ctx context.Context
	// cwd is the working directory handed to the persistent shell and the
	// task sub-agent.
	cwd string
	// systemPrompt is installed on the main provider by setProviders.
	systemPrompt string
	// providers maps a provider key (Model.Provider) to its configuration.
	providers map[string]provider.Config

	sessions session.Service
	messages message.Service

	toolsRegistry tools.Registry

	// large drives the main conversation and summarization; small is used
	// for title generation.
	large, small Model
	// provider streams the main conversation (large model).
	provider provider.Provider
	// titleProvider is built from the small model's provider config.
	titleProvider provider.Provider
	// summarizeProvider is built from the large model's provider config.
	summarizeProvider provider.Provider

	// activeRequests holds the cancel func of every in-flight request, keyed
	// by session ID for Run and by session ID + "-summarize" for Summarize.
	activeRequests *csync.Map[string, context.CancelFunc]

	// debug enables extra result logging and is propagated to the providers.
	debug bool
}
106
107func NewAgent(
108 ctx context.Context,
109 cwd string,
110 systemPrompt string,
111 toolsRegistry tools.Registry,
112 providers map[string]provider.Config,
113
114 smallModel Model,
115 largeModel Model,
116
117 sessions session.Service,
118 messages message.Service,
119) (Service, error) {
120 agent := &agent{
121 Broker: pubsub.NewBroker[AgentEvent](),
122 ctx: ctx,
123 providers: providers,
124 cwd: cwd,
125 systemPrompt: systemPrompt,
126 toolsRegistry: toolsRegistry,
127 small: smallModel,
128 large: largeModel,
129 messages: messages,
130 sessions: sessions,
131 activeRequests: csync.NewMap[string, context.CancelFunc](),
132 }
133
134 err := agent.setProviders()
135 return agent, err
136}
137
138func (a *agent) ModelConfig() Model {
139 return a.large
140}
141
142func (a *agent) Model() *catwalk.Model {
143 return a.provider.Model(a.large.Model)
144}
145
146func (a *agent) Provider() *provider.Config {
147 for _, provider := range a.providers {
148 if provider.ID == a.large.Provider {
149 return &provider
150 }
151 }
152 return nil
153}
154
155func (a *agent) Cancel(sessionID string) {
156 // Cancel regular requests
157 if cancel, exists := a.activeRequests.Take(sessionID); exists {
158 slog.Info("Request cancellation initiated", "session_id", sessionID)
159 cancel()
160 }
161
162 // Also check for summarize requests
163 if cancel, exists := a.activeRequests.Take(sessionID + "-summarize"); exists {
164 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
165 cancel()
166 }
167}
168
169func (a *agent) IsBusy() bool {
170 busy := false
171 for cancelFunc := range a.activeRequests.Seq() {
172 if cancelFunc != nil {
173 busy = true
174 break
175 }
176 }
177 return busy
178}
179
180func (a *agent) IsSessionBusy(sessionID string) bool {
181 _, busy := a.activeRequests.Get(sessionID)
182 return busy
183}
184
185func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
186 if content == "" {
187 return nil
188 }
189 if a.titleProvider == nil {
190 return nil
191 }
192 session, err := a.sessions.Get(ctx, sessionID)
193 if err != nil {
194 return err
195 }
196 parts := []message.ContentPart{message.TextContent{
197 Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
198 }}
199
200 response := a.titleProvider.Stream(
201 ctx,
202 a.small.Model,
203 []message.Message{
204 {
205 Role: message.User,
206 Parts: parts,
207 },
208 },
209 nil,
210 )
211
212 var finalResponse *provider.ProviderResponse
213 for r := range response {
214 if r.Error != nil {
215 return r.Error
216 }
217 finalResponse = r.Response
218 }
219
220 if finalResponse == nil {
221 return fmt.Errorf("no response received from title provider")
222 }
223
224 title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
225 if title == "" {
226 return nil
227 }
228
229 session.Title = title
230 _, err = a.sessions.Save(ctx, session)
231 return err
232}
233
234func (a *agent) err(err error) AgentEvent {
235 return AgentEvent{
236 Type: AgentEventTypeError,
237 Error: err,
238 }
239}
240
241func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
242 if !a.Model().SupportsImages && attachments != nil {
243 attachments = nil
244 }
245 events := make(chan AgentEvent)
246 if a.IsSessionBusy(sessionID) {
247 return nil, ErrSessionBusy
248 }
249
250 genCtx, cancel := context.WithCancel(ctx)
251
252 a.activeRequests.Set(sessionID, cancel)
253 go func() {
254 slog.Debug("Request started", "sessionID", sessionID)
255 defer log.RecoverPanic("agent.Run", func() {
256 events <- a.err(fmt.Errorf("panic while running the agent"))
257 })
258 var attachmentParts []message.ContentPart
259 for _, attachment := range attachments {
260 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
261 }
262 result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
263 if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
264 slog.Error(result.Error.Error())
265 }
266 slog.Debug("Request completed", "sessionID", sessionID)
267 a.activeRequests.Del(sessionID)
268 cancel()
269 a.Publish(pubsub.CreatedEvent, result)
270 events <- result
271 close(events)
272 }()
273 return events, nil
274}
275
276func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
277 // List existing messages; if none, start title generation asynchronously.
278 msgs, err := a.messages.List(ctx, sessionID)
279 if err != nil {
280 return a.err(fmt.Errorf("failed to list messages: %w", err))
281 }
282 if len(msgs) == 0 {
283 go func() {
284 defer log.RecoverPanic("agent.Run", func() {
285 slog.Error("panic while generating title")
286 })
287 titleErr := a.generateTitle(context.Background(), sessionID, content)
288 if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
289 slog.Error("failed to generate title", "error", titleErr)
290 }
291 }()
292 }
293 session, err := a.sessions.Get(ctx, sessionID)
294 if err != nil {
295 return a.err(fmt.Errorf("failed to get session: %w", err))
296 }
297 if session.SummaryMessageID != "" {
298 summaryMsgInex := -1
299 for i, msg := range msgs {
300 if msg.ID == session.SummaryMessageID {
301 summaryMsgInex = i
302 break
303 }
304 }
305 if summaryMsgInex != -1 {
306 msgs = msgs[summaryMsgInex:]
307 msgs[0].Role = message.User
308 }
309 }
310
311 userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
312 if err != nil {
313 return a.err(fmt.Errorf("failed to create user message: %w", err))
314 }
315 // Append the new user message to the conversation history.
316 msgHistory := append(msgs, userMsg)
317
318 for {
319 // Check for cancellation before each iteration
320 select {
321 case <-ctx.Done():
322 return a.err(ctx.Err())
323 default:
324 // Continue processing
325 }
326 agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
327 if err != nil {
328 if errors.Is(err, context.Canceled) {
329 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
330 a.messages.Update(context.Background(), agentMessage)
331 return a.err(ErrRequestCancelled)
332 }
333 return a.err(fmt.Errorf("failed to process events: %w", err))
334 }
335 if a.debug {
336 slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
337 }
338 if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
339 // We are not done, we need to respond with the tool response
340 msgHistory = append(msgHistory, agentMessage, *toolResults)
341 continue
342 }
343 return AgentEvent{
344 Type: AgentEventTypeResponse,
345 Message: agentMessage,
346 Done: true,
347 }
348 }
349}
350
351func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
352 parts := []message.ContentPart{message.TextContent{Text: content}}
353 parts = append(parts, attachmentParts...)
354 return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
355 Role: message.User,
356 Parts: parts,
357 })
358}
359
360func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
361 ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
362 eventChan := a.provider.Stream(ctx, a.large.Model, msgHistory, a.toolsRegistry.GetAllTools())
363
364 assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
365 Role: message.Assistant,
366 Parts: []message.ContentPart{},
367 Model: a.large.Model,
368 Provider: a.large.Provider,
369 })
370 if err != nil {
371 return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
372 }
373
374 // Add the session and message ID into the context if needed by tools.
375 ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
376
377 // Process each event in the stream.
378 for event := range eventChan {
379 if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
380 if errors.Is(processErr, context.Canceled) {
381 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
382 } else {
383 a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
384 }
385 return assistantMsg, nil, processErr
386 }
387 if ctx.Err() != nil {
388 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
389 return assistantMsg, nil, ctx.Err()
390 }
391 }
392
393 toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
394 toolCalls := assistantMsg.ToolCalls()
395 for i, toolCall := range toolCalls {
396 select {
397 case <-ctx.Done():
398 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
399 // Make all future tool calls cancelled
400 for j := i; j < len(toolCalls); j++ {
401 toolResults[j] = message.ToolResult{
402 ToolCallID: toolCalls[j].ID,
403 Content: "Tool execution canceled by user",
404 IsError: true,
405 }
406 }
407 goto out
408 default:
409 // Continue processing
410 var tool tools.BaseTool
411 for _, availableTool := range a.toolsRegistry.GetAllTools() {
412 if availableTool.Info().Name == toolCall.Name {
413 tool = availableTool
414 break
415 }
416 }
417
418 // Tool not found
419 if tool == nil {
420 toolResults[i] = message.ToolResult{
421 ToolCallID: toolCall.ID,
422 Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
423 IsError: true,
424 }
425 continue
426 }
427
428 // Run tool in goroutine to allow cancellation
429 type toolExecResult struct {
430 response tools.ToolResponse
431 err error
432 }
433 resultChan := make(chan toolExecResult, 1)
434
435 go func() {
436 response, err := tool.Run(ctx, tools.ToolCall{
437 ID: toolCall.ID,
438 Name: toolCall.Name,
439 Input: toolCall.Input,
440 })
441 resultChan <- toolExecResult{response: response, err: err}
442 }()
443
444 var toolResponse tools.ToolResponse
445 var toolErr error
446
447 select {
448 case <-ctx.Done():
449 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
450 // Mark remaining tool calls as cancelled
451 for j := i; j < len(toolCalls); j++ {
452 toolResults[j] = message.ToolResult{
453 ToolCallID: toolCalls[j].ID,
454 Content: "Tool execution canceled by user",
455 IsError: true,
456 }
457 }
458 goto out
459 case result := <-resultChan:
460 toolResponse = result.response
461 toolErr = result.err
462 }
463
464 if toolErr != nil {
465 slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
466 if errors.Is(toolErr, permission.ErrorPermissionDenied) {
467 toolResults[i] = message.ToolResult{
468 ToolCallID: toolCall.ID,
469 Content: "Permission denied",
470 IsError: true,
471 }
472 for j := i + 1; j < len(toolCalls); j++ {
473 toolResults[j] = message.ToolResult{
474 ToolCallID: toolCalls[j].ID,
475 Content: "Tool execution canceled by user",
476 IsError: true,
477 }
478 }
479 a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
480 break
481 }
482 }
483 toolResults[i] = message.ToolResult{
484 ToolCallID: toolCall.ID,
485 Content: toolResponse.Content,
486 Metadata: toolResponse.Metadata,
487 IsError: toolResponse.IsError,
488 }
489 }
490 }
491out:
492 if len(toolResults) == 0 {
493 return assistantMsg, nil, nil
494 }
495 parts := make([]message.ContentPart, 0)
496 for _, tr := range toolResults {
497 parts = append(parts, tr)
498 }
499 msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
500 Role: message.Tool,
501 Parts: parts,
502 Model: a.large.Model,
503 Provider: a.large.Provider,
504 })
505 if err != nil {
506 return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
507 }
508
509 return assistantMsg, &msg, err
510}
511
512func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
513 msg.AddFinish(finishReason, message, details)
514 _ = a.messages.Update(ctx, *msg)
515}
516
517func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
518 select {
519 case <-ctx.Done():
520 return ctx.Err()
521 default:
522 // Continue processing.
523 }
524
525 switch event.Type {
526 case provider.EventThinkingDelta:
527 assistantMsg.AppendReasoningContent(event.Thinking)
528 return a.messages.Update(ctx, *assistantMsg)
529 case provider.EventSignatureDelta:
530 assistantMsg.AppendReasoningSignature(event.Signature)
531 return a.messages.Update(ctx, *assistantMsg)
532 case provider.EventContentDelta:
533 assistantMsg.FinishThinking()
534 assistantMsg.AppendContent(event.Content)
535 return a.messages.Update(ctx, *assistantMsg)
536 case provider.EventToolUseStart:
537 assistantMsg.FinishThinking()
538 slog.Info("Tool call started", "toolCall", event.ToolCall)
539 assistantMsg.AddToolCall(*event.ToolCall)
540 return a.messages.Update(ctx, *assistantMsg)
541 case provider.EventToolUseDelta:
542 assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
543 return a.messages.Update(ctx, *assistantMsg)
544 case provider.EventToolUseStop:
545 slog.Info("Finished tool call", "toolCall", event.ToolCall)
546 assistantMsg.FinishToolCall(event.ToolCall.ID)
547 return a.messages.Update(ctx, *assistantMsg)
548 case provider.EventError:
549 return event.Error
550 case provider.EventComplete:
551 assistantMsg.FinishThinking()
552 assistantMsg.SetToolCalls(event.Response.ToolCalls)
553 assistantMsg.AddFinish(event.Response.FinishReason, "", "")
554 if err := a.messages.Update(ctx, *assistantMsg); err != nil {
555 return fmt.Errorf("failed to update message: %w", err)
556 }
557 model := a.Model()
558 if model == nil {
559 return nil
560 }
561 return a.TrackUsage(ctx, sessionID, *model, event.Response.Usage)
562 }
563
564 return nil
565}
566
567func (a *agent) TrackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
568 sess, err := a.sessions.Get(ctx, sessionID)
569 if err != nil {
570 return fmt.Errorf("failed to get session: %w", err)
571 }
572
573 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
574 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
575 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
576 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
577
578 sess.Cost += cost
579 sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
580 sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
581
582 _, err = a.sessions.Save(ctx, sess)
583 if err != nil {
584 return fmt.Errorf("failed to save session: %w", err)
585 }
586 return nil
587}
588
589func (a *agent) Summarize(ctx context.Context, sessionID string) error {
590 if a.summarizeProvider == nil {
591 return fmt.Errorf("summarize provider not available")
592 }
593
594 // Check if session is busy
595 if a.IsSessionBusy(sessionID) {
596 return ErrSessionBusy
597 }
598
599 // Create a new context with cancellation
600 summarizeCtx, cancel := context.WithCancel(ctx)
601
602 // Store the cancel function in activeRequests to allow cancellation
603 a.activeRequests.Set(sessionID+"-summarize", cancel)
604
605 go func() {
606 defer a.activeRequests.Del(sessionID + "-summarize")
607 defer cancel()
608 event := AgentEvent{
609 Type: AgentEventTypeSummarize,
610 Progress: "Starting summarization...",
611 }
612
613 a.Publish(pubsub.CreatedEvent, event)
614 // Get all messages from the session
615 msgs, err := a.messages.List(summarizeCtx, sessionID)
616 if err != nil {
617 event = AgentEvent{
618 Type: AgentEventTypeError,
619 Error: fmt.Errorf("failed to list messages: %w", err),
620 Done: true,
621 }
622 a.Publish(pubsub.CreatedEvent, event)
623 return
624 }
625 summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
626
627 if len(msgs) == 0 {
628 event = AgentEvent{
629 Type: AgentEventTypeError,
630 Error: fmt.Errorf("no messages to summarize"),
631 Done: true,
632 }
633 a.Publish(pubsub.CreatedEvent, event)
634 return
635 }
636
637 event = AgentEvent{
638 Type: AgentEventTypeSummarize,
639 Progress: "Analyzing conversation...",
640 }
641 a.Publish(pubsub.CreatedEvent, event)
642
643 // Add a system message to guide the summarization
644 summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
645
646 // Create a new message with the summarize prompt
647 promptMsg := message.Message{
648 Role: message.User,
649 Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
650 }
651
652 // Append the prompt to the messages
653 msgsWithPrompt := append(msgs, promptMsg)
654
655 event = AgentEvent{
656 Type: AgentEventTypeSummarize,
657 Progress: "Generating summary...",
658 }
659
660 a.Publish(pubsub.CreatedEvent, event)
661
662 // Send the messages to the summarize provider
663 response := a.summarizeProvider.Stream(
664 summarizeCtx,
665 a.large.Model,
666 msgsWithPrompt,
667 nil,
668 )
669 var finalResponse *provider.ProviderResponse
670 for r := range response {
671 if r.Error != nil {
672 event = AgentEvent{
673 Type: AgentEventTypeError,
674 Error: fmt.Errorf("failed to summarize: %w", err),
675 Done: true,
676 }
677 a.Publish(pubsub.CreatedEvent, event)
678 return
679 }
680 finalResponse = r.Response
681 }
682
683 summary := strings.TrimSpace(finalResponse.Content)
684 if summary == "" {
685 event = AgentEvent{
686 Type: AgentEventTypeError,
687 Error: fmt.Errorf("empty summary returned"),
688 Done: true,
689 }
690 a.Publish(pubsub.CreatedEvent, event)
691 return
692 }
693 shell := shell.GetPersistentShell(a.cwd)
694 summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
695 event = AgentEvent{
696 Type: AgentEventTypeSummarize,
697 Progress: "Creating new session...",
698 }
699
700 a.Publish(pubsub.CreatedEvent, event)
701 oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
702 if err != nil {
703 event = AgentEvent{
704 Type: AgentEventTypeError,
705 Error: fmt.Errorf("failed to get session: %w", err),
706 Done: true,
707 }
708
709 a.Publish(pubsub.CreatedEvent, event)
710 return
711 }
712 // Create a message in the new session with the summary
713 msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
714 Role: message.Assistant,
715 Parts: []message.ContentPart{
716 message.TextContent{Text: summary},
717 message.Finish{
718 Reason: message.FinishReasonEndTurn,
719 Time: time.Now().Unix(),
720 },
721 },
722 Model: a.large.Model,
723 Provider: a.large.Provider,
724 })
725 if err != nil {
726 event = AgentEvent{
727 Type: AgentEventTypeError,
728 Error: fmt.Errorf("failed to create summary message: %w", err),
729 Done: true,
730 }
731
732 a.Publish(pubsub.CreatedEvent, event)
733 return
734 }
735 oldSession.SummaryMessageID = msg.ID
736 oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
737 oldSession.PromptTokens = 0
738 model := a.summarizeProvider.Model(a.large.Model)
739 usage := finalResponse.Usage
740 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
741 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
742 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
743 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
744 oldSession.Cost += cost
745 _, err = a.sessions.Save(summarizeCtx, oldSession)
746 if err != nil {
747 event = AgentEvent{
748 Type: AgentEventTypeError,
749 Error: fmt.Errorf("failed to save session: %w", err),
750 Done: true,
751 }
752 a.Publish(pubsub.CreatedEvent, event)
753 }
754
755 event = AgentEvent{
756 Type: AgentEventTypeSummarize,
757 SessionID: oldSession.ID,
758 Progress: "Summary complete",
759 Done: true,
760 }
761 a.Publish(pubsub.CreatedEvent, event)
762 // Send final success event with the new session ID
763 }()
764
765 return nil
766}
767
768func (a *agent) CancelAll() {
769 if !a.IsBusy() {
770 return
771 }
772 for key := range a.activeRequests.Seq2() {
773 a.Cancel(key) // key is sessionID
774 }
775
776 timeout := time.After(5 * time.Second)
777 for a.IsBusy() {
778 select {
779 case <-timeout:
780 return
781 default:
782 time.Sleep(200 * time.Millisecond)
783 }
784 }
785}
786
787func (a *agent) UpdateModels(small, large Model) error {
788 a.small = small
789 a.large = large
790 return a.setProviders()
791}
792
793func (a *agent) SetDebug(debug bool) {
794 a.debug = debug
795 if a.provider != nil {
796 a.provider.SetDebug(debug)
797 }
798 if a.titleProvider != nil {
799 a.titleProvider.SetDebug(debug)
800 }
801 if a.summarizeProvider != nil {
802 a.summarizeProvider.SetDebug(debug)
803 }
804}
805
806func (a *agent) setProviders() error {
807 opts := []provider.Option{
808 provider.WithSystemMessage(a.systemPrompt),
809 provider.WithThinking(a.large.Think),
810 }
811
812 if a.large.MaxTokens > 0 {
813 opts = append(opts, provider.WithMaxTokens(a.large.MaxTokens))
814 }
815 if a.large.ReasoningEffort != "" {
816 opts = append(opts, provider.WithReasoningEffort(a.large.ReasoningEffort))
817 }
818
819 providerCfg, ok := a.providers[a.large.Provider]
820 if !ok {
821 return fmt.Errorf("provider %s not found in config", a.large.Provider)
822 }
823 var err error
824 a.provider, err = provider.NewProvider(providerCfg, opts...)
825 if err != nil {
826 return fmt.Errorf("failed to create provider: %w", err)
827 }
828
829 titleOpts := []provider.Option{
830 provider.WithSystemMessage(prompt.TitlePrompt()),
831 provider.WithMaxTokens(40),
832 }
833
834 titleProviderCfg, ok := a.providers[a.small.Provider]
835 if !ok {
836 return fmt.Errorf("small model provider %s not found in config", a.small.Provider)
837 }
838
839 a.titleProvider, err = provider.NewProvider(titleProviderCfg, titleOpts...)
840 if err != nil {
841 return err
842 }
843 summarizeOpts := []provider.Option{
844 provider.WithSystemMessage(prompt.SummarizerPrompt()),
845 }
846 a.summarizeProvider, err = provider.NewProvider(providerCfg, summarizeOpts...)
847 if err != nil {
848 return err
849 }
850
851 if _, ok := a.toolsRegistry.GetTool(AgentToolName); ok {
852 // reset the agent tool
853 a.WithAgentTool()
854 }
855
856 a.SetDebug(a.debug)
857 return nil
858}
859
860func (a *agent) WithAgentTool() error {
861 agent, err := NewAgent(
862 a.ctx,
863 a.cwd,
864 prompt.TaskPrompt(a.cwd),
865 NewTaskTools(a.cwd),
866 a.providers,
867 a.small,
868 a.large,
869 a.sessions,
870 a.messages,
871 )
872 if err != nil {
873 return err
874 }
875
876 agentTool := NewAgentTool(
877 agent,
878 a.sessions,
879 a.messages,
880 )
881
882 a.toolsRegistry.SetTool(AgentToolName, agentTool)
883 return nil
884}