1package agent
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "slices"
8 "strings"
9 "sync"
10 "time"
11
12 "github.com/charmbracelet/crush/internal/config"
13 "github.com/charmbracelet/crush/internal/history"
14 "github.com/charmbracelet/crush/internal/llm/prompt"
15 "github.com/charmbracelet/crush/internal/llm/provider"
16 "github.com/charmbracelet/crush/internal/llm/tools"
17 "github.com/charmbracelet/crush/internal/logging"
18 "github.com/charmbracelet/crush/internal/lsp"
19 "github.com/charmbracelet/crush/internal/message"
20 "github.com/charmbracelet/crush/internal/permission"
21 "github.com/charmbracelet/crush/internal/pubsub"
22 "github.com/charmbracelet/crush/internal/session"
23)
24
// ErrRequestCancelled is carried by the error AgentEvent produced when a
// running request is cancelled before it completes.
var ErrRequestCancelled = errors.New("request cancelled by user")

// ErrSessionBusy is returned by Run and Summarize when the session already
// has a request in flight.
var ErrSessionBusy = errors.New("session is currently processing another request")
30
// AgentEventType identifies which payload an AgentEvent carries.
type AgentEventType string

// AgentEventTypeError marks an event whose Error field describes a failure.
const AgentEventTypeError AgentEventType = "error"

// AgentEventTypeResponse marks an event carrying the final assistant Message.
const AgentEventTypeResponse AgentEventType = "response"

// AgentEventTypeSummarize marks a progress or completion update from Summarize.
const AgentEventTypeSummarize AgentEventType = "summarize"
38
39type AgentEvent struct {
40 Type AgentEventType
41 Message message.Message
42 Error error
43
44 // When summarizing
45 SessionID string
46 Progress string
47 Done bool
48}
49
50type Service interface {
51 pubsub.Suscriber[AgentEvent]
52 Model() config.Model
53 Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
54 Cancel(sessionID string)
55 CancelAll()
56 IsSessionBusy(sessionID string) bool
57 IsBusy() bool
58 Summarize(ctx context.Context, sessionID string) error
59}
60
61type agent struct {
62 *pubsub.Broker[AgentEvent]
63 agentCfg config.Agent
64 sessions session.Service
65 messages message.Service
66
67 tools []tools.BaseTool
68 provider provider.Provider
69 providerID string
70
71 titleProvider provider.Provider
72 summarizeProvider provider.Provider
73 summarizeProviderID string
74
75 activeRequests sync.Map
76}
77
78var agentPromptMap = map[config.AgentID]prompt.PromptID{
79 config.AgentCoder: prompt.PromptCoder,
80 config.AgentTask: prompt.PromptTask,
81}
82
83func NewAgent(
84 agentCfg config.Agent,
85 // These services are needed in the tools
86 permissions permission.Service,
87 sessions session.Service,
88 messages message.Service,
89 history history.Service,
90 lspClients map[string]*lsp.Client,
91) (Service, error) {
92 ctx := context.Background()
93 cfg := config.Get()
94 otherTools := GetMcpTools(ctx, permissions)
95 if len(lspClients) > 0 {
96 otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients))
97 }
98
99 allTools := []tools.BaseTool{
100 tools.NewBashTool(permissions),
101 tools.NewEditTool(lspClients, permissions, history),
102 tools.NewFetchTool(permissions),
103 tools.NewGlobTool(),
104 tools.NewGrepTool(),
105 tools.NewLsTool(),
106 tools.NewSourcegraphTool(),
107 tools.NewViewTool(lspClients),
108 tools.NewWriteTool(lspClients, permissions, history),
109 }
110
111 if agentCfg.ID == config.AgentCoder {
112 taskAgentCfg := config.Get().Agents[config.AgentTask]
113 if taskAgentCfg.ID == "" {
114 return nil, fmt.Errorf("task agent not found in config")
115 }
116 taskAgent, err := NewAgent(taskAgentCfg, permissions, sessions, messages, history, lspClients)
117 if err != nil {
118 return nil, fmt.Errorf("failed to create task agent: %w", err)
119 }
120
121 allTools = append(
122 allTools,
123 NewAgentTool(
124 taskAgent,
125 sessions,
126 messages,
127 ),
128 )
129 }
130
131 allTools = append(allTools, otherTools...)
132 providerCfg := config.GetAgentProvider(agentCfg.ID)
133 if providerCfg.ID == "" {
134 return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
135 }
136 model := config.GetAgentModel(agentCfg.ID)
137
138 if model.ID == "" {
139 return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
140 }
141
142 promptID := agentPromptMap[agentCfg.ID]
143 if promptID == "" {
144 promptID = prompt.PromptDefault
145 }
146 opts := []provider.ProviderClientOption{
147 provider.WithModel(agentCfg.Model),
148 provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID)),
149 provider.WithMaxTokens(model.DefaultMaxTokens),
150 }
151 agentProvider, err := provider.NewProviderV2(providerCfg, opts...)
152 if err != nil {
153 return nil, err
154 }
155
156 smallModelCfg := cfg.Models.Small
157 var smallModel config.Model
158
159 var smallModelProviderCfg config.ProviderConfig
160 if smallModelCfg.Provider == providerCfg.ID {
161 smallModelProviderCfg = providerCfg
162 } else {
163 for _, p := range cfg.Providers {
164 if p.ID == smallModelCfg.Provider {
165 smallModelProviderCfg = p
166 break
167 }
168 }
169 if smallModelProviderCfg.ID == "" {
170 return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
171 }
172 }
173 for _, m := range smallModelProviderCfg.Models {
174 if m.ID == smallModelCfg.ModelID {
175 smallModel = m
176 break
177 }
178 }
179 if smallModel.ID == "" {
180 return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.ModelID, smallModelProviderCfg.ID)
181 }
182
183 titleOpts := []provider.ProviderClientOption{
184 provider.WithModel(config.SmallModel),
185 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
186 provider.WithMaxTokens(40),
187 }
188 titleProvider, err := provider.NewProviderV2(smallModelProviderCfg, titleOpts...)
189 if err != nil {
190 return nil, err
191 }
192 summarizeOpts := []provider.ProviderClientOption{
193 provider.WithModel(config.SmallModel),
194 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
195 provider.WithMaxTokens(smallModel.DefaultMaxTokens),
196 }
197 summarizeProvider, err := provider.NewProviderV2(smallModelProviderCfg, summarizeOpts...)
198 if err != nil {
199 return nil, err
200 }
201
202 agentTools := []tools.BaseTool{}
203 if agentCfg.AllowedTools == nil {
204 agentTools = allTools
205 } else {
206 for _, tool := range allTools {
207 if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
208 agentTools = append(agentTools, tool)
209 }
210 }
211 }
212
213 agent := &agent{
214 Broker: pubsub.NewBroker[AgentEvent](),
215 agentCfg: agentCfg,
216 provider: agentProvider,
217 providerID: string(providerCfg.ID),
218 messages: messages,
219 sessions: sessions,
220 tools: agentTools,
221 titleProvider: titleProvider,
222 summarizeProvider: summarizeProvider,
223 summarizeProviderID: string(smallModelProviderCfg.ID),
224 activeRequests: sync.Map{},
225 }
226
227 return agent, nil
228}
229
230func (a *agent) Model() config.Model {
231 return config.GetAgentModel(a.agentCfg.ID)
232}
233
234func (a *agent) Cancel(sessionID string) {
235 // Cancel regular requests
236 if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
237 if cancel, ok := cancelFunc.(context.CancelFunc); ok {
238 logging.InfoPersist(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
239 cancel()
240 }
241 }
242
243 // Also check for summarize requests
244 if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID + "-summarize"); exists {
245 if cancel, ok := cancelFunc.(context.CancelFunc); ok {
246 logging.InfoPersist(fmt.Sprintf("Summarize cancellation initiated for session: %s", sessionID))
247 cancel()
248 }
249 }
250}
251
252func (a *agent) IsBusy() bool {
253 busy := false
254 a.activeRequests.Range(func(key, value any) bool {
255 if cancelFunc, ok := value.(context.CancelFunc); ok {
256 if cancelFunc != nil {
257 busy = true
258 return false // Stop iterating
259 }
260 }
261 return true // Continue iterating
262 })
263 return busy
264}
265
266func (a *agent) IsSessionBusy(sessionID string) bool {
267 _, busy := a.activeRequests.Load(sessionID)
268 return busy
269}
270
271func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
272 if content == "" {
273 return nil
274 }
275 if a.titleProvider == nil {
276 return nil
277 }
278 session, err := a.sessions.Get(ctx, sessionID)
279 if err != nil {
280 return err
281 }
282 parts := []message.ContentPart{message.TextContent{Text: content}}
283
284 // Use streaming approach like summarization
285 response := a.titleProvider.StreamResponse(
286 ctx,
287 []message.Message{
288 {
289 Role: message.User,
290 Parts: parts,
291 },
292 },
293 make([]tools.BaseTool, 0),
294 )
295
296 var finalResponse *provider.ProviderResponse
297 for r := range response {
298 if r.Error != nil {
299 return r.Error
300 }
301 finalResponse = r.Response
302 }
303
304 if finalResponse == nil {
305 return fmt.Errorf("no response received from title provider")
306 }
307
308 title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
309 if title == "" {
310 return nil
311 }
312
313 session.Title = title
314 _, err = a.sessions.Save(ctx, session)
315 return err
316}
317
318func (a *agent) err(err error) AgentEvent {
319 return AgentEvent{
320 Type: AgentEventTypeError,
321 Error: err,
322 }
323}
324
325func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
326 if !a.Model().SupportsImages && attachments != nil {
327 attachments = nil
328 }
329 events := make(chan AgentEvent)
330 if a.IsSessionBusy(sessionID) {
331 return nil, ErrSessionBusy
332 }
333
334 genCtx, cancel := context.WithCancel(ctx)
335
336 a.activeRequests.Store(sessionID, cancel)
337 go func() {
338 logging.Debug("Request started", "sessionID", sessionID)
339 defer logging.RecoverPanic("agent.Run", func() {
340 events <- a.err(fmt.Errorf("panic while running the agent"))
341 })
342 var attachmentParts []message.ContentPart
343 for _, attachment := range attachments {
344 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
345 }
346 result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
347 if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
348 logging.ErrorPersist(result.Error.Error())
349 }
350 logging.Debug("Request completed", "sessionID", sessionID)
351 a.activeRequests.Delete(sessionID)
352 cancel()
353 a.Publish(pubsub.CreatedEvent, result)
354 events <- result
355 close(events)
356 }()
357 return events, nil
358}
359
360func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
361 // List existing messages; if none, start title generation asynchronously.
362 msgs, err := a.messages.List(ctx, sessionID)
363 if err != nil {
364 return a.err(fmt.Errorf("failed to list messages: %w", err))
365 }
366 if len(msgs) == 0 {
367 go func() {
368 defer logging.RecoverPanic("agent.Run", func() {
369 logging.ErrorPersist("panic while generating title")
370 })
371 titleErr := a.generateTitle(context.Background(), sessionID, content)
372 if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
373 logging.ErrorPersist(fmt.Sprintf("failed to generate title: %v", titleErr))
374 }
375 }()
376 }
377 session, err := a.sessions.Get(ctx, sessionID)
378 if err != nil {
379 return a.err(fmt.Errorf("failed to get session: %w", err))
380 }
381 if session.SummaryMessageID != "" {
382 summaryMsgInex := -1
383 for i, msg := range msgs {
384 if msg.ID == session.SummaryMessageID {
385 summaryMsgInex = i
386 break
387 }
388 }
389 if summaryMsgInex != -1 {
390 msgs = msgs[summaryMsgInex:]
391 msgs[0].Role = message.User
392 }
393 }
394
395 userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
396 if err != nil {
397 return a.err(fmt.Errorf("failed to create user message: %w", err))
398 }
399 // Append the new user message to the conversation history.
400 msgHistory := append(msgs, userMsg)
401
402 for {
403 // Check for cancellation before each iteration
404 select {
405 case <-ctx.Done():
406 return a.err(ctx.Err())
407 default:
408 // Continue processing
409 }
410 agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
411 if err != nil {
412 if errors.Is(err, context.Canceled) {
413 agentMessage.AddFinish(message.FinishReasonCanceled)
414 a.messages.Update(context.Background(), agentMessage)
415 return a.err(ErrRequestCancelled)
416 }
417 return a.err(fmt.Errorf("failed to process events: %w", err))
418 }
419 logging.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
420 if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
421 // We are not done, we need to respond with the tool response
422 msgHistory = append(msgHistory, agentMessage, *toolResults)
423 continue
424 }
425 return AgentEvent{
426 Type: AgentEventTypeResponse,
427 Message: agentMessage,
428 Done: true,
429 }
430 }
431}
432
433func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
434 parts := []message.ContentPart{message.TextContent{Text: content}}
435 parts = append(parts, attachmentParts...)
436 return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
437 Role: message.User,
438 Parts: parts,
439 })
440}
441
442func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
443 eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)
444
445 assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
446 Role: message.Assistant,
447 Parts: []message.ContentPart{},
448 Model: a.Model().ID,
449 Provider: a.providerID,
450 })
451 if err != nil {
452 return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
453 }
454
455 // Add the session and message ID into the context if needed by tools.
456 ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
457 ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
458
459 // Process each event in the stream.
460 for event := range eventChan {
461 if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
462 a.finishMessage(ctx, &assistantMsg, message.FinishReasonCanceled)
463 return assistantMsg, nil, processErr
464 }
465 if ctx.Err() != nil {
466 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
467 return assistantMsg, nil, ctx.Err()
468 }
469 }
470
471 toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
472 toolCalls := assistantMsg.ToolCalls()
473 for i, toolCall := range toolCalls {
474 select {
475 case <-ctx.Done():
476 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
477 // Make all future tool calls cancelled
478 for j := i; j < len(toolCalls); j++ {
479 toolResults[j] = message.ToolResult{
480 ToolCallID: toolCalls[j].ID,
481 Content: "Tool execution canceled by user",
482 IsError: true,
483 }
484 }
485 goto out
486 default:
487 // Continue processing
488 var tool tools.BaseTool
489 for _, availableTools := range a.tools {
490 if availableTools.Info().Name == toolCall.Name {
491 tool = availableTools
492 }
493 }
494
495 // Tool not found
496 if tool == nil {
497 toolResults[i] = message.ToolResult{
498 ToolCallID: toolCall.ID,
499 Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
500 IsError: true,
501 }
502 continue
503 }
504 toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
505 ID: toolCall.ID,
506 Name: toolCall.Name,
507 Input: toolCall.Input,
508 })
509 if toolErr != nil {
510 if errors.Is(toolErr, permission.ErrorPermissionDenied) {
511 toolResults[i] = message.ToolResult{
512 ToolCallID: toolCall.ID,
513 Content: "Permission denied",
514 IsError: true,
515 }
516 for j := i + 1; j < len(toolCalls); j++ {
517 toolResults[j] = message.ToolResult{
518 ToolCallID: toolCalls[j].ID,
519 Content: "Tool execution canceled by user",
520 IsError: true,
521 }
522 }
523 a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied)
524 break
525 }
526 }
527 toolResults[i] = message.ToolResult{
528 ToolCallID: toolCall.ID,
529 Content: toolResult.Content,
530 Metadata: toolResult.Metadata,
531 IsError: toolResult.IsError,
532 }
533 }
534 }
535out:
536 if len(toolResults) == 0 {
537 return assistantMsg, nil, nil
538 }
539 parts := make([]message.ContentPart, 0)
540 for _, tr := range toolResults {
541 parts = append(parts, tr)
542 }
543 msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
544 Role: message.Tool,
545 Parts: parts,
546 Provider: a.providerID,
547 })
548 if err != nil {
549 return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
550 }
551
552 return assistantMsg, &msg, err
553}
554
555func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReson message.FinishReason) {
556 msg.AddFinish(finishReson)
557 _ = a.messages.Update(ctx, *msg)
558}
559
560func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
561 select {
562 case <-ctx.Done():
563 return ctx.Err()
564 default:
565 // Continue processing.
566 }
567
568 switch event.Type {
569 case provider.EventThinkingDelta:
570 assistantMsg.AppendReasoningContent(event.Content)
571 return a.messages.Update(ctx, *assistantMsg)
572 case provider.EventContentDelta:
573 assistantMsg.AppendContent(event.Content)
574 return a.messages.Update(ctx, *assistantMsg)
575 case provider.EventToolUseStart:
576 logging.Info("Tool call started", "toolCall", event.ToolCall)
577 assistantMsg.AddToolCall(*event.ToolCall)
578 return a.messages.Update(ctx, *assistantMsg)
579 case provider.EventToolUseDelta:
580 assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
581 return a.messages.Update(ctx, *assistantMsg)
582 case provider.EventToolUseStop:
583 logging.Info("Finished tool call", "toolCall", event.ToolCall)
584 assistantMsg.FinishToolCall(event.ToolCall.ID)
585 return a.messages.Update(ctx, *assistantMsg)
586 case provider.EventError:
587 if errors.Is(event.Error, context.Canceled) {
588 logging.InfoPersist(fmt.Sprintf("Event processing canceled for session: %s", sessionID))
589 return context.Canceled
590 }
591 logging.ErrorPersist(event.Error.Error())
592 return event.Error
593 case provider.EventComplete:
594 assistantMsg.SetToolCalls(event.Response.ToolCalls)
595 assistantMsg.AddFinish(event.Response.FinishReason)
596 if err := a.messages.Update(ctx, *assistantMsg); err != nil {
597 return fmt.Errorf("failed to update message: %w", err)
598 }
599 return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
600 }
601
602 return nil
603}
604
605func (a *agent) TrackUsage(ctx context.Context, sessionID string, model config.Model, usage provider.TokenUsage) error {
606 sess, err := a.sessions.Get(ctx, sessionID)
607 if err != nil {
608 return fmt.Errorf("failed to get session: %w", err)
609 }
610
611 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
612 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
613 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
614 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
615
616 sess.Cost += cost
617 sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
618 sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
619
620 _, err = a.sessions.Save(ctx, sess)
621 if err != nil {
622 return fmt.Errorf("failed to save session: %w", err)
623 }
624 return nil
625}
626
627func (a *agent) Summarize(ctx context.Context, sessionID string) error {
628 if a.summarizeProvider == nil {
629 return fmt.Errorf("summarize provider not available")
630 }
631
632 // Check if session is busy
633 if a.IsSessionBusy(sessionID) {
634 return ErrSessionBusy
635 }
636
637 // Create a new context with cancellation
638 summarizeCtx, cancel := context.WithCancel(ctx)
639
640 // Store the cancel function in activeRequests to allow cancellation
641 a.activeRequests.Store(sessionID+"-summarize", cancel)
642
643 go func() {
644 defer a.activeRequests.Delete(sessionID + "-summarize")
645 defer cancel()
646 event := AgentEvent{
647 Type: AgentEventTypeSummarize,
648 Progress: "Starting summarization...",
649 }
650
651 a.Publish(pubsub.CreatedEvent, event)
652 // Get all messages from the session
653 msgs, err := a.messages.List(summarizeCtx, sessionID)
654 if err != nil {
655 event = AgentEvent{
656 Type: AgentEventTypeError,
657 Error: fmt.Errorf("failed to list messages: %w", err),
658 Done: true,
659 }
660 a.Publish(pubsub.CreatedEvent, event)
661 return
662 }
663
664 if len(msgs) == 0 {
665 event = AgentEvent{
666 Type: AgentEventTypeError,
667 Error: fmt.Errorf("no messages to summarize"),
668 Done: true,
669 }
670 a.Publish(pubsub.CreatedEvent, event)
671 return
672 }
673
674 event = AgentEvent{
675 Type: AgentEventTypeSummarize,
676 Progress: "Analyzing conversation...",
677 }
678 a.Publish(pubsub.CreatedEvent, event)
679
680 // Add a system message to guide the summarization
681 summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
682
683 // Create a new message with the summarize prompt
684 promptMsg := message.Message{
685 Role: message.User,
686 Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
687 }
688
689 // Append the prompt to the messages
690 msgsWithPrompt := append(msgs, promptMsg)
691
692 event = AgentEvent{
693 Type: AgentEventTypeSummarize,
694 Progress: "Generating summary...",
695 }
696
697 a.Publish(pubsub.CreatedEvent, event)
698
699 // Send the messages to the summarize provider
700 response := a.summarizeProvider.StreamResponse(
701 summarizeCtx,
702 msgsWithPrompt,
703 make([]tools.BaseTool, 0),
704 )
705 var finalResponse *provider.ProviderResponse
706 for r := range response {
707 if r.Error != nil {
708 event = AgentEvent{
709 Type: AgentEventTypeError,
710 Error: fmt.Errorf("failed to summarize: %w", err),
711 Done: true,
712 }
713 a.Publish(pubsub.CreatedEvent, event)
714 return
715 }
716 finalResponse = r.Response
717 }
718
719 summary := strings.TrimSpace(finalResponse.Content)
720 if summary == "" {
721 event = AgentEvent{
722 Type: AgentEventTypeError,
723 Error: fmt.Errorf("empty summary returned"),
724 Done: true,
725 }
726 a.Publish(pubsub.CreatedEvent, event)
727 return
728 }
729 event = AgentEvent{
730 Type: AgentEventTypeSummarize,
731 Progress: "Creating new session...",
732 }
733
734 a.Publish(pubsub.CreatedEvent, event)
735 oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
736 if err != nil {
737 event = AgentEvent{
738 Type: AgentEventTypeError,
739 Error: fmt.Errorf("failed to get session: %w", err),
740 Done: true,
741 }
742
743 a.Publish(pubsub.CreatedEvent, event)
744 return
745 }
746 // Create a message in the new session with the summary
747 msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
748 Role: message.Assistant,
749 Parts: []message.ContentPart{
750 message.TextContent{Text: summary},
751 message.Finish{
752 Reason: message.FinishReasonEndTurn,
753 Time: time.Now().Unix(),
754 },
755 },
756 Model: a.summarizeProvider.Model().ID,
757 Provider: a.summarizeProviderID,
758 })
759 if err != nil {
760 event = AgentEvent{
761 Type: AgentEventTypeError,
762 Error: fmt.Errorf("failed to create summary message: %w", err),
763 Done: true,
764 }
765
766 a.Publish(pubsub.CreatedEvent, event)
767 return
768 }
769 oldSession.SummaryMessageID = msg.ID
770 oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
771 oldSession.PromptTokens = 0
772 model := a.summarizeProvider.Model()
773 usage := finalResponse.Usage
774 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
775 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
776 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
777 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
778 oldSession.Cost += cost
779 _, err = a.sessions.Save(summarizeCtx, oldSession)
780 if err != nil {
781 event = AgentEvent{
782 Type: AgentEventTypeError,
783 Error: fmt.Errorf("failed to save session: %w", err),
784 Done: true,
785 }
786 a.Publish(pubsub.CreatedEvent, event)
787 }
788
789 event = AgentEvent{
790 Type: AgentEventTypeSummarize,
791 SessionID: oldSession.ID,
792 Progress: "Summary complete",
793 Done: true,
794 }
795 a.Publish(pubsub.CreatedEvent, event)
796 // Send final success event with the new session ID
797 }()
798
799 return nil
800}
801
802func (a *agent) CancelAll() {
803 a.activeRequests.Range(func(key, value any) bool {
804 a.Cancel(key.(string)) // key is sessionID
805 return true
806 })
807}