1package agent
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "slices"
8 "strings"
9 "sync"
10 "time"
11
12 configv2 "github.com/charmbracelet/crush/internal/config"
13 "github.com/charmbracelet/crush/internal/history"
14 "github.com/charmbracelet/crush/internal/llm/prompt"
15 "github.com/charmbracelet/crush/internal/llm/provider"
16 "github.com/charmbracelet/crush/internal/llm/tools"
17 "github.com/charmbracelet/crush/internal/logging"
18 "github.com/charmbracelet/crush/internal/lsp"
19 "github.com/charmbracelet/crush/internal/message"
20 "github.com/charmbracelet/crush/internal/permission"
21 "github.com/charmbracelet/crush/internal/pubsub"
22 "github.com/charmbracelet/crush/internal/session"
23)
24
// Sentinel errors surfaced by the agent; compare with errors.Is.

// ErrRequestCancelled is reported when a running generation is stopped
// because its context was cancelled.
var ErrRequestCancelled = errors.New("request cancelled by user")

// ErrSessionBusy is returned when a request is started for a session that
// already has one in flight.
var ErrSessionBusy = errors.New("session is currently processing another request")
30
// AgentEventType names the kind of event carried by an AgentEvent.
type AgentEventType string

// Event kinds emitted by the agent.
const (
	// AgentEventTypeError marks an event whose Error field is set.
	AgentEventTypeError AgentEventType = "error"
	// AgentEventTypeResponse marks the final assistant response of a run.
	AgentEventTypeResponse AgentEventType = "response"
	// AgentEventTypeSummarize marks a progress or completion event of a
	// summarization.
	AgentEventTypeSummarize AgentEventType = "summarize"
)
38
39type AgentEvent struct {
40 Type AgentEventType
41 Message message.Message
42 Error error
43
44 // When summarizing
45 SessionID string
46 Progress string
47 Done bool
48}
49
50type Service interface {
51 pubsub.Suscriber[AgentEvent]
52 Model() configv2.Model
53 Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
54 Cancel(sessionID string)
55 CancelAll()
56 IsSessionBusy(sessionID string) bool
57 IsBusy() bool
58 Update(model configv2.PreferredModel) (configv2.Model, error)
59 Summarize(ctx context.Context, sessionID string) error
60}
61
62type agent struct {
63 *pubsub.Broker[AgentEvent]
64 agentCfg configv2.Agent
65 sessions session.Service
66 messages message.Service
67
68 tools []tools.BaseTool
69 provider provider.Provider
70 providerID string
71
72 titleProvider provider.Provider
73 summarizeProvider provider.Provider
74 summarizeProviderID string
75
76 activeRequests sync.Map
77}
78
79var agentPromptMap = map[configv2.AgentID]prompt.PromptID{
80 configv2.AgentCoder: prompt.PromptCoder,
81 configv2.AgentTask: prompt.PromptTask,
82}
83
84func NewAgent(
85 agentCfg configv2.Agent,
86 // These services are needed in the tools
87 permissions permission.Service,
88 sessions session.Service,
89 messages message.Service,
90 history history.Service,
91 lspClients map[string]*lsp.Client,
92) (Service, error) {
93 ctx := context.Background()
94 cfg := configv2.Get()
95 otherTools := GetMcpTools(ctx, permissions)
96 if len(lspClients) > 0 {
97 otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients))
98 }
99
100 allTools := []tools.BaseTool{
101 tools.NewBashTool(permissions),
102 tools.NewEditTool(lspClients, permissions, history),
103 tools.NewFetchTool(permissions),
104 tools.NewGlobTool(),
105 tools.NewGrepTool(),
106 tools.NewLsTool(),
107 tools.NewSourcegraphTool(),
108 tools.NewViewTool(lspClients),
109 tools.NewWriteTool(lspClients, permissions, history),
110 }
111
112 if agentCfg.ID == configv2.AgentCoder {
113 taskAgentCfg := configv2.Get().Agents[configv2.AgentTask]
114 if taskAgentCfg.ID == "" {
115 return nil, fmt.Errorf("task agent not found in config")
116 }
117 taskAgent, err := NewAgent(taskAgentCfg, permissions, sessions, messages, history, lspClients)
118 if err != nil {
119 return nil, fmt.Errorf("failed to create task agent: %w", err)
120 }
121
122 allTools = append(
123 allTools,
124 NewAgentTool(
125 taskAgent,
126 sessions,
127 messages,
128 ),
129 )
130 }
131
132 allTools = append(allTools, otherTools...)
133 var providerCfg configv2.ProviderConfig
134 for _, p := range cfg.Providers {
135 if p.ID == agentCfg.Provider {
136 providerCfg = p
137 break
138 }
139 }
140 if providerCfg.ID == "" {
141 return nil, fmt.Errorf("provider %s not found in config", agentCfg.Provider)
142 }
143
144 var model configv2.Model
145 for _, m := range providerCfg.Models {
146 if m.ID == agentCfg.Model {
147 model = m
148 break
149 }
150 }
151 if model.ID == "" {
152 return nil, fmt.Errorf("model %s not found in provider %s", agentCfg.Model, agentCfg.Provider)
153 }
154
155 promptID := agentPromptMap[agentCfg.ID]
156 if promptID == "" {
157 promptID = prompt.PromptDefault
158 }
159 opts := []provider.ProviderClientOption{
160 provider.WithModel(model),
161 provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID)),
162 provider.WithMaxTokens(model.DefaultMaxTokens),
163 }
164 agentProvider, err := provider.NewProviderV2(providerCfg, opts...)
165 if err != nil {
166 return nil, err
167 }
168
169 smallModelCfg := cfg.Models.Small
170 var smallModel configv2.Model
171
172 var smallModelProviderCfg configv2.ProviderConfig
173 if smallModelCfg.Provider == providerCfg.ID {
174 smallModelProviderCfg = providerCfg
175 } else {
176 for _, p := range cfg.Providers {
177 if p.ID == smallModelCfg.Provider {
178 smallModelProviderCfg = p
179 break
180 }
181 }
182 if smallModelProviderCfg.ID == "" {
183 return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
184 }
185 }
186 for _, m := range smallModelProviderCfg.Models {
187 if m.ID == smallModelCfg.ModelID {
188 smallModel = m
189 break
190 }
191 }
192 if smallModel.ID == "" {
193 return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.ModelID, smallModelProviderCfg.ID)
194 }
195
196 titleOpts := []provider.ProviderClientOption{
197 provider.WithModel(smallModel),
198 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
199 provider.WithMaxTokens(40),
200 }
201 titleProvider, err := provider.NewProviderV2(smallModelProviderCfg, titleOpts...)
202 if err != nil {
203 return nil, err
204 }
205 summarizeOpts := []provider.ProviderClientOption{
206 provider.WithModel(smallModel),
207 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
208 provider.WithMaxTokens(smallModel.DefaultMaxTokens),
209 }
210 summarizeProvider, err := provider.NewProviderV2(smallModelProviderCfg, summarizeOpts...)
211 if err != nil {
212 return nil, err
213 }
214
215 agentTools := []tools.BaseTool{}
216 if agentCfg.AllowedTools == nil {
217 agentTools = allTools
218 } else {
219 for _, tool := range allTools {
220 if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
221 agentTools = append(agentTools, tool)
222 }
223 }
224 }
225
226 agent := &agent{
227 Broker: pubsub.NewBroker[AgentEvent](),
228 agentCfg: agentCfg,
229 provider: agentProvider,
230 providerID: string(providerCfg.ID),
231 messages: messages,
232 sessions: sessions,
233 tools: agentTools,
234 titleProvider: titleProvider,
235 summarizeProvider: summarizeProvider,
236 summarizeProviderID: string(smallModelProviderCfg.ID),
237 activeRequests: sync.Map{},
238 }
239
240 return agent, nil
241}
242
243func (a *agent) Model() configv2.Model {
244 return a.provider.Model()
245}
246
247func (a *agent) Cancel(sessionID string) {
248 // Cancel regular requests
249 if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
250 if cancel, ok := cancelFunc.(context.CancelFunc); ok {
251 logging.InfoPersist(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
252 cancel()
253 }
254 }
255
256 // Also check for summarize requests
257 if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID + "-summarize"); exists {
258 if cancel, ok := cancelFunc.(context.CancelFunc); ok {
259 logging.InfoPersist(fmt.Sprintf("Summarize cancellation initiated for session: %s", sessionID))
260 cancel()
261 }
262 }
263}
264
265func (a *agent) IsBusy() bool {
266 busy := false
267 a.activeRequests.Range(func(key, value any) bool {
268 if cancelFunc, ok := value.(context.CancelFunc); ok {
269 if cancelFunc != nil {
270 busy = true
271 return false // Stop iterating
272 }
273 }
274 return true // Continue iterating
275 })
276 return busy
277}
278
279func (a *agent) IsSessionBusy(sessionID string) bool {
280 _, busy := a.activeRequests.Load(sessionID)
281 return busy
282}
283
284func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
285 if content == "" {
286 return nil
287 }
288 if a.titleProvider == nil {
289 return nil
290 }
291 session, err := a.sessions.Get(ctx, sessionID)
292 if err != nil {
293 return err
294 }
295 parts := []message.ContentPart{message.TextContent{Text: content}}
296
297 // Use streaming approach like summarization
298 response := a.titleProvider.StreamResponse(
299 ctx,
300 []message.Message{
301 {
302 Role: message.User,
303 Parts: parts,
304 },
305 },
306 make([]tools.BaseTool, 0),
307 )
308
309 var finalResponse *provider.ProviderResponse
310 for r := range response {
311 if r.Error != nil {
312 return r.Error
313 }
314 finalResponse = r.Response
315 }
316
317 if finalResponse == nil {
318 return fmt.Errorf("no response received from title provider")
319 }
320
321 title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
322 if title == "" {
323 return nil
324 }
325
326 session.Title = title
327 _, err = a.sessions.Save(ctx, session)
328 return err
329}
330
331func (a *agent) err(err error) AgentEvent {
332 return AgentEvent{
333 Type: AgentEventTypeError,
334 Error: err,
335 }
336}
337
338func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
339 if !a.provider.Model().SupportsImages && attachments != nil {
340 attachments = nil
341 }
342 events := make(chan AgentEvent)
343 if a.IsSessionBusy(sessionID) {
344 return nil, ErrSessionBusy
345 }
346
347 genCtx, cancel := context.WithCancel(ctx)
348
349 a.activeRequests.Store(sessionID, cancel)
350 go func() {
351 logging.Debug("Request started", "sessionID", sessionID)
352 defer logging.RecoverPanic("agent.Run", func() {
353 events <- a.err(fmt.Errorf("panic while running the agent"))
354 })
355 var attachmentParts []message.ContentPart
356 for _, attachment := range attachments {
357 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
358 }
359 result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
360 if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
361 logging.ErrorPersist(result.Error.Error())
362 }
363 logging.Debug("Request completed", "sessionID", sessionID)
364 a.activeRequests.Delete(sessionID)
365 cancel()
366 a.Publish(pubsub.CreatedEvent, result)
367 events <- result
368 close(events)
369 }()
370 return events, nil
371}
372
373func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
374 // List existing messages; if none, start title generation asynchronously.
375 msgs, err := a.messages.List(ctx, sessionID)
376 if err != nil {
377 return a.err(fmt.Errorf("failed to list messages: %w", err))
378 }
379 if len(msgs) == 0 {
380 go func() {
381 defer logging.RecoverPanic("agent.Run", func() {
382 logging.ErrorPersist("panic while generating title")
383 })
384 titleErr := a.generateTitle(context.Background(), sessionID, content)
385 if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
386 logging.ErrorPersist(fmt.Sprintf("failed to generate title: %v", titleErr))
387 }
388 }()
389 }
390 session, err := a.sessions.Get(ctx, sessionID)
391 if err != nil {
392 return a.err(fmt.Errorf("failed to get session: %w", err))
393 }
394 if session.SummaryMessageID != "" {
395 summaryMsgInex := -1
396 for i, msg := range msgs {
397 if msg.ID == session.SummaryMessageID {
398 summaryMsgInex = i
399 break
400 }
401 }
402 if summaryMsgInex != -1 {
403 msgs = msgs[summaryMsgInex:]
404 msgs[0].Role = message.User
405 }
406 }
407
408 userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
409 if err != nil {
410 return a.err(fmt.Errorf("failed to create user message: %w", err))
411 }
412 // Append the new user message to the conversation history.
413 msgHistory := append(msgs, userMsg)
414
415 for {
416 // Check for cancellation before each iteration
417 select {
418 case <-ctx.Done():
419 return a.err(ctx.Err())
420 default:
421 // Continue processing
422 }
423 agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
424 if err != nil {
425 if errors.Is(err, context.Canceled) {
426 agentMessage.AddFinish(message.FinishReasonCanceled)
427 a.messages.Update(context.Background(), agentMessage)
428 return a.err(ErrRequestCancelled)
429 }
430 return a.err(fmt.Errorf("failed to process events: %w", err))
431 }
432 logging.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
433 if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
434 // We are not done, we need to respond with the tool response
435 msgHistory = append(msgHistory, agentMessage, *toolResults)
436 continue
437 }
438 return AgentEvent{
439 Type: AgentEventTypeResponse,
440 Message: agentMessage,
441 Done: true,
442 }
443 }
444}
445
446func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
447 parts := []message.ContentPart{message.TextContent{Text: content}}
448 parts = append(parts, attachmentParts...)
449 return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
450 Role: message.User,
451 Parts: parts,
452 })
453}
454
455func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
456 eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)
457
458 assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
459 Role: message.Assistant,
460 Parts: []message.ContentPart{},
461 Model: a.provider.Model().ID,
462 Provider: a.providerID,
463 })
464 if err != nil {
465 return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
466 }
467
468 // Add the session and message ID into the context if needed by tools.
469 ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
470 ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
471
472 // Process each event in the stream.
473 for event := range eventChan {
474 if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
475 a.finishMessage(ctx, &assistantMsg, message.FinishReasonCanceled)
476 return assistantMsg, nil, processErr
477 }
478 if ctx.Err() != nil {
479 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
480 return assistantMsg, nil, ctx.Err()
481 }
482 }
483
484 toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
485 toolCalls := assistantMsg.ToolCalls()
486 for i, toolCall := range toolCalls {
487 select {
488 case <-ctx.Done():
489 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
490 // Make all future tool calls cancelled
491 for j := i; j < len(toolCalls); j++ {
492 toolResults[j] = message.ToolResult{
493 ToolCallID: toolCalls[j].ID,
494 Content: "Tool execution canceled by user",
495 IsError: true,
496 }
497 }
498 goto out
499 default:
500 // Continue processing
501 var tool tools.BaseTool
502 for _, availableTools := range a.tools {
503 if availableTools.Info().Name == toolCall.Name {
504 tool = availableTools
505 }
506 }
507
508 // Tool not found
509 if tool == nil {
510 toolResults[i] = message.ToolResult{
511 ToolCallID: toolCall.ID,
512 Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
513 IsError: true,
514 }
515 continue
516 }
517 toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
518 ID: toolCall.ID,
519 Name: toolCall.Name,
520 Input: toolCall.Input,
521 })
522 if toolErr != nil {
523 if errors.Is(toolErr, permission.ErrorPermissionDenied) {
524 toolResults[i] = message.ToolResult{
525 ToolCallID: toolCall.ID,
526 Content: "Permission denied",
527 IsError: true,
528 }
529 for j := i + 1; j < len(toolCalls); j++ {
530 toolResults[j] = message.ToolResult{
531 ToolCallID: toolCalls[j].ID,
532 Content: "Tool execution canceled by user",
533 IsError: true,
534 }
535 }
536 a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied)
537 break
538 }
539 }
540 toolResults[i] = message.ToolResult{
541 ToolCallID: toolCall.ID,
542 Content: toolResult.Content,
543 Metadata: toolResult.Metadata,
544 IsError: toolResult.IsError,
545 }
546 }
547 }
548out:
549 if len(toolResults) == 0 {
550 return assistantMsg, nil, nil
551 }
552 parts := make([]message.ContentPart, 0)
553 for _, tr := range toolResults {
554 parts = append(parts, tr)
555 }
556 msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
557 Role: message.Tool,
558 Parts: parts,
559 Provider: a.providerID,
560 })
561 if err != nil {
562 return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
563 }
564
565 return assistantMsg, &msg, err
566}
567
568func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReson message.FinishReason) {
569 msg.AddFinish(finishReson)
570 _ = a.messages.Update(ctx, *msg)
571}
572
573func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
574 select {
575 case <-ctx.Done():
576 return ctx.Err()
577 default:
578 // Continue processing.
579 }
580
581 switch event.Type {
582 case provider.EventThinkingDelta:
583 assistantMsg.AppendReasoningContent(event.Content)
584 return a.messages.Update(ctx, *assistantMsg)
585 case provider.EventContentDelta:
586 assistantMsg.AppendContent(event.Content)
587 return a.messages.Update(ctx, *assistantMsg)
588 case provider.EventToolUseStart:
589 logging.Info("Tool call started", "toolCall", event.ToolCall)
590 assistantMsg.AddToolCall(*event.ToolCall)
591 return a.messages.Update(ctx, *assistantMsg)
592 case provider.EventToolUseDelta:
593 assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
594 return a.messages.Update(ctx, *assistantMsg)
595 case provider.EventToolUseStop:
596 logging.Info("Finished tool call", "toolCall", event.ToolCall)
597 assistantMsg.FinishToolCall(event.ToolCall.ID)
598 return a.messages.Update(ctx, *assistantMsg)
599 case provider.EventError:
600 if errors.Is(event.Error, context.Canceled) {
601 logging.InfoPersist(fmt.Sprintf("Event processing canceled for session: %s", sessionID))
602 return context.Canceled
603 }
604 logging.ErrorPersist(event.Error.Error())
605 return event.Error
606 case provider.EventComplete:
607 assistantMsg.SetToolCalls(event.Response.ToolCalls)
608 assistantMsg.AddFinish(event.Response.FinishReason)
609 if err := a.messages.Update(ctx, *assistantMsg); err != nil {
610 return fmt.Errorf("failed to update message: %w", err)
611 }
612 return a.TrackUsage(ctx, sessionID, a.provider.Model(), event.Response.Usage)
613 }
614
615 return nil
616}
617
618func (a *agent) TrackUsage(ctx context.Context, sessionID string, model configv2.Model, usage provider.TokenUsage) error {
619 sess, err := a.sessions.Get(ctx, sessionID)
620 if err != nil {
621 return fmt.Errorf("failed to get session: %w", err)
622 }
623
624 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
625 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
626 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
627 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
628
629 sess.Cost += cost
630 sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
631 sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
632
633 _, err = a.sessions.Save(ctx, sess)
634 if err != nil {
635 return fmt.Errorf("failed to save session: %w", err)
636 }
637 return nil
638}
639
640func (a *agent) Update(modelCfg configv2.PreferredModel) (configv2.Model, error) {
641 if a.IsBusy() {
642 return configv2.Model{}, fmt.Errorf("cannot change model while processing requests")
643 }
644
645 cfg := configv2.Get()
646 var providerCfg configv2.ProviderConfig
647 for _, p := range cfg.Providers {
648 if p.ID == modelCfg.Provider {
649 providerCfg = p
650 break
651 }
652 }
653 if providerCfg.ID == "" {
654 return configv2.Model{}, fmt.Errorf("provider %s not found in config", modelCfg.Provider)
655 }
656
657 var model configv2.Model
658 for _, m := range providerCfg.Models {
659 if m.ID == modelCfg.ModelID {
660 model = m
661 break
662 }
663 }
664 if model.ID == "" {
665 return configv2.Model{}, fmt.Errorf("model %s not found in provider %s", modelCfg.ModelID, modelCfg.Provider)
666 }
667
668 promptID := agentPromptMap[a.agentCfg.ID]
669 if promptID == "" {
670 promptID = prompt.PromptDefault
671 }
672 opts := []provider.ProviderClientOption{
673 provider.WithModel(model),
674 provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID)),
675 provider.WithMaxTokens(model.DefaultMaxTokens),
676 }
677 agentProvider, err := provider.NewProviderV2(providerCfg, opts...)
678 if err != nil {
679 return configv2.Model{}, err
680 }
681 a.provider = agentProvider
682
683 return a.provider.Model(), nil
684}
685
686func (a *agent) Summarize(ctx context.Context, sessionID string) error {
687 if a.summarizeProvider == nil {
688 return fmt.Errorf("summarize provider not available")
689 }
690
691 // Check if session is busy
692 if a.IsSessionBusy(sessionID) {
693 return ErrSessionBusy
694 }
695
696 // Create a new context with cancellation
697 summarizeCtx, cancel := context.WithCancel(ctx)
698
699 // Store the cancel function in activeRequests to allow cancellation
700 a.activeRequests.Store(sessionID+"-summarize", cancel)
701
702 go func() {
703 defer a.activeRequests.Delete(sessionID + "-summarize")
704 defer cancel()
705 event := AgentEvent{
706 Type: AgentEventTypeSummarize,
707 Progress: "Starting summarization...",
708 }
709
710 a.Publish(pubsub.CreatedEvent, event)
711 // Get all messages from the session
712 msgs, err := a.messages.List(summarizeCtx, sessionID)
713 if err != nil {
714 event = AgentEvent{
715 Type: AgentEventTypeError,
716 Error: fmt.Errorf("failed to list messages: %w", err),
717 Done: true,
718 }
719 a.Publish(pubsub.CreatedEvent, event)
720 return
721 }
722
723 if len(msgs) == 0 {
724 event = AgentEvent{
725 Type: AgentEventTypeError,
726 Error: fmt.Errorf("no messages to summarize"),
727 Done: true,
728 }
729 a.Publish(pubsub.CreatedEvent, event)
730 return
731 }
732
733 event = AgentEvent{
734 Type: AgentEventTypeSummarize,
735 Progress: "Analyzing conversation...",
736 }
737 a.Publish(pubsub.CreatedEvent, event)
738
739 // Add a system message to guide the summarization
740 summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
741
742 // Create a new message with the summarize prompt
743 promptMsg := message.Message{
744 Role: message.User,
745 Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
746 }
747
748 // Append the prompt to the messages
749 msgsWithPrompt := append(msgs, promptMsg)
750
751 event = AgentEvent{
752 Type: AgentEventTypeSummarize,
753 Progress: "Generating summary...",
754 }
755
756 a.Publish(pubsub.CreatedEvent, event)
757
758 // Send the messages to the summarize provider
759 response := a.summarizeProvider.StreamResponse(
760 summarizeCtx,
761 msgsWithPrompt,
762 make([]tools.BaseTool, 0),
763 )
764 var finalResponse *provider.ProviderResponse
765 for r := range response {
766 if r.Error != nil {
767 event = AgentEvent{
768 Type: AgentEventTypeError,
769 Error: fmt.Errorf("failed to summarize: %w", err),
770 Done: true,
771 }
772 a.Publish(pubsub.CreatedEvent, event)
773 return
774 }
775 finalResponse = r.Response
776 }
777
778 summary := strings.TrimSpace(finalResponse.Content)
779 if summary == "" {
780 event = AgentEvent{
781 Type: AgentEventTypeError,
782 Error: fmt.Errorf("empty summary returned"),
783 Done: true,
784 }
785 a.Publish(pubsub.CreatedEvent, event)
786 return
787 }
788 event = AgentEvent{
789 Type: AgentEventTypeSummarize,
790 Progress: "Creating new session...",
791 }
792
793 a.Publish(pubsub.CreatedEvent, event)
794 oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
795 if err != nil {
796 event = AgentEvent{
797 Type: AgentEventTypeError,
798 Error: fmt.Errorf("failed to get session: %w", err),
799 Done: true,
800 }
801
802 a.Publish(pubsub.CreatedEvent, event)
803 return
804 }
805 // Create a message in the new session with the summary
806 msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
807 Role: message.Assistant,
808 Parts: []message.ContentPart{
809 message.TextContent{Text: summary},
810 message.Finish{
811 Reason: message.FinishReasonEndTurn,
812 Time: time.Now().Unix(),
813 },
814 },
815 Model: a.summarizeProvider.Model().ID,
816 Provider: a.summarizeProviderID,
817 })
818 if err != nil {
819 event = AgentEvent{
820 Type: AgentEventTypeError,
821 Error: fmt.Errorf("failed to create summary message: %w", err),
822 Done: true,
823 }
824
825 a.Publish(pubsub.CreatedEvent, event)
826 return
827 }
828 oldSession.SummaryMessageID = msg.ID
829 oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
830 oldSession.PromptTokens = 0
831 model := a.summarizeProvider.Model()
832 usage := finalResponse.Usage
833 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
834 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
835 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
836 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
837 oldSession.Cost += cost
838 _, err = a.sessions.Save(summarizeCtx, oldSession)
839 if err != nil {
840 event = AgentEvent{
841 Type: AgentEventTypeError,
842 Error: fmt.Errorf("failed to save session: %w", err),
843 Done: true,
844 }
845 a.Publish(pubsub.CreatedEvent, event)
846 }
847
848 event = AgentEvent{
849 Type: AgentEventTypeSummarize,
850 SessionID: oldSession.ID,
851 Progress: "Summary complete",
852 Done: true,
853 }
854 a.Publish(pubsub.CreatedEvent, event)
855 // Send final success event with the new session ID
856 }()
857
858 return nil
859}
860
861func (a *agent) CancelAll() {
862 a.activeRequests.Range(func(key, value any) bool {
863 a.Cancel(key.(string)) // key is sessionID
864 return true
865 })
866}