1package agent
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "slices"
8 "strings"
9 "sync"
10 "time"
11
12 "github.com/charmbracelet/crush/internal/config"
13 "github.com/charmbracelet/crush/internal/history"
14 "github.com/charmbracelet/crush/internal/llm/prompt"
15 "github.com/charmbracelet/crush/internal/llm/provider"
16 "github.com/charmbracelet/crush/internal/llm/tools"
17 "github.com/charmbracelet/crush/internal/logging"
18 "github.com/charmbracelet/crush/internal/lsp"
19 "github.com/charmbracelet/crush/internal/message"
20 "github.com/charmbracelet/crush/internal/permission"
21 "github.com/charmbracelet/crush/internal/pubsub"
22 "github.com/charmbracelet/crush/internal/session"
23)
24
// ErrRequestCancelled is the error reported for a request that was stopped
// through cancellation of its context. Match it with errors.Is.
var ErrRequestCancelled = errors.New("request cancelled by user")

// ErrSessionBusy is returned when a session already has an active request, so
// a new one cannot be started. Match it with errors.Is.
var ErrSessionBusy = errors.New("session is currently processing another request")
30
// AgentEventType tells consumers which kind of AgentEvent they received and
// therefore which of its fields are meaningful.
type AgentEventType string

// The kinds of AgentEvent the agent emits.
const (
	// AgentEventTypeError: the event's Error field holds the failure.
	AgentEventTypeError AgentEventType = "error"
	// AgentEventTypeResponse: the event's Message field holds the assistant reply.
	AgentEventTypeResponse AgentEventType = "response"
	// AgentEventTypeSummarize: summarization progress or completion.
	AgentEventTypeSummarize AgentEventType = "summarize"
)
38
39type AgentEvent struct {
40 Type AgentEventType
41 Message message.Message
42 Error error
43
44 // When summarizing
45 SessionID string
46 Progress string
47 Done bool
48}
49
50type Service interface {
51 pubsub.Suscriber[AgentEvent]
52 Model() config.Model
53 Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
54 Cancel(sessionID string)
55 CancelAll()
56 IsSessionBusy(sessionID string) bool
57 IsBusy() bool
58 Summarize(ctx context.Context, sessionID string) error
59 UpdateModel() error
60}
61
62type agent struct {
63 *pubsub.Broker[AgentEvent]
64 agentCfg config.Agent
65 sessions session.Service
66 messages message.Service
67
68 tools []tools.BaseTool
69 provider provider.Provider
70 providerID string
71
72 titleProvider provider.Provider
73 summarizeProvider provider.Provider
74 summarizeProviderID string
75
76 activeRequests sync.Map
77}
78
79var agentPromptMap = map[config.AgentID]prompt.PromptID{
80 config.AgentCoder: prompt.PromptCoder,
81 config.AgentTask: prompt.PromptTask,
82}
83
84func NewAgent(
85 agentCfg config.Agent,
86 // These services are needed in the tools
87 permissions permission.Service,
88 sessions session.Service,
89 messages message.Service,
90 history history.Service,
91 lspClients map[string]*lsp.Client,
92) (Service, error) {
93 ctx := context.Background()
94 cfg := config.Get()
95 otherTools := GetMcpTools(ctx, permissions)
96 if len(lspClients) > 0 {
97 otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients))
98 }
99
100 allTools := []tools.BaseTool{
101 tools.NewBashTool(permissions),
102 tools.NewEditTool(lspClients, permissions, history),
103 tools.NewFetchTool(permissions),
104 tools.NewGlobTool(),
105 tools.NewGrepTool(),
106 tools.NewLsTool(),
107 tools.NewSourcegraphTool(),
108 tools.NewViewTool(lspClients),
109 tools.NewWriteTool(lspClients, permissions, history),
110 }
111
112 if agentCfg.ID == config.AgentCoder {
113 taskAgentCfg := config.Get().Agents[config.AgentTask]
114 if taskAgentCfg.ID == "" {
115 return nil, fmt.Errorf("task agent not found in config")
116 }
117 taskAgent, err := NewAgent(taskAgentCfg, permissions, sessions, messages, history, lspClients)
118 if err != nil {
119 return nil, fmt.Errorf("failed to create task agent: %w", err)
120 }
121
122 allTools = append(
123 allTools,
124 NewAgentTool(
125 taskAgent,
126 sessions,
127 messages,
128 ),
129 )
130 }
131
132 allTools = append(allTools, otherTools...)
133 providerCfg := config.GetAgentProvider(agentCfg.ID)
134 if providerCfg.ID == "" {
135 return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
136 }
137 model := config.GetAgentModel(agentCfg.ID)
138
139 if model.ID == "" {
140 return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
141 }
142
143 promptID := agentPromptMap[agentCfg.ID]
144 if promptID == "" {
145 promptID = prompt.PromptDefault
146 }
147 opts := []provider.ProviderClientOption{
148 provider.WithModel(agentCfg.Model),
149 provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID)),
150 }
151 agentProvider, err := provider.NewProvider(providerCfg, opts...)
152 if err != nil {
153 return nil, err
154 }
155
156 smallModelCfg := cfg.Models.Small
157 var smallModel config.Model
158
159 var smallModelProviderCfg config.ProviderConfig
160 if smallModelCfg.Provider == providerCfg.ID {
161 smallModelProviderCfg = providerCfg
162 } else {
163 for _, p := range cfg.Providers {
164 if p.ID == smallModelCfg.Provider {
165 smallModelProviderCfg = p
166 break
167 }
168 }
169 if smallModelProviderCfg.ID == "" {
170 return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
171 }
172 }
173 for _, m := range smallModelProviderCfg.Models {
174 if m.ID == smallModelCfg.ModelID {
175 smallModel = m
176 break
177 }
178 }
179 if smallModel.ID == "" {
180 return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.ModelID, smallModelProviderCfg.ID)
181 }
182
183 titleOpts := []provider.ProviderClientOption{
184 provider.WithModel(config.SmallModel),
185 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
186 }
187 titleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
188 if err != nil {
189 return nil, err
190 }
191 summarizeOpts := []provider.ProviderClientOption{
192 provider.WithModel(config.SmallModel),
193 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
194 }
195 summarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
196 if err != nil {
197 return nil, err
198 }
199
200 agentTools := []tools.BaseTool{}
201 if agentCfg.AllowedTools == nil {
202 agentTools = allTools
203 } else {
204 for _, tool := range allTools {
205 if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
206 agentTools = append(agentTools, tool)
207 }
208 }
209 }
210
211 agent := &agent{
212 Broker: pubsub.NewBroker[AgentEvent](),
213 agentCfg: agentCfg,
214 provider: agentProvider,
215 providerID: string(providerCfg.ID),
216 messages: messages,
217 sessions: sessions,
218 tools: agentTools,
219 titleProvider: titleProvider,
220 summarizeProvider: summarizeProvider,
221 summarizeProviderID: string(smallModelProviderCfg.ID),
222 activeRequests: sync.Map{},
223 }
224
225 return agent, nil
226}
227
228func (a *agent) Model() config.Model {
229 return config.GetAgentModel(a.agentCfg.ID)
230}
231
232func (a *agent) Cancel(sessionID string) {
233 // Cancel regular requests
234 if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
235 if cancel, ok := cancelFunc.(context.CancelFunc); ok {
236 logging.InfoPersist(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
237 cancel()
238 }
239 }
240
241 // Also check for summarize requests
242 if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID + "-summarize"); exists {
243 if cancel, ok := cancelFunc.(context.CancelFunc); ok {
244 logging.InfoPersist(fmt.Sprintf("Summarize cancellation initiated for session: %s", sessionID))
245 cancel()
246 }
247 }
248}
249
250func (a *agent) IsBusy() bool {
251 busy := false
252 a.activeRequests.Range(func(key, value any) bool {
253 if cancelFunc, ok := value.(context.CancelFunc); ok {
254 if cancelFunc != nil {
255 busy = true
256 return false // Stop iterating
257 }
258 }
259 return true // Continue iterating
260 })
261 return busy
262}
263
264func (a *agent) IsSessionBusy(sessionID string) bool {
265 _, busy := a.activeRequests.Load(sessionID)
266 return busy
267}
268
269func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
270 if content == "" {
271 return nil
272 }
273 if a.titleProvider == nil {
274 return nil
275 }
276 session, err := a.sessions.Get(ctx, sessionID)
277 if err != nil {
278 return err
279 }
280 parts := []message.ContentPart{message.TextContent{
281 Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
282 }}
283
284 // Use streaming approach like summarization
285 response := a.titleProvider.StreamResponse(
286 ctx,
287 []message.Message{
288 {
289 Role: message.User,
290 Parts: parts,
291 },
292 },
293 make([]tools.BaseTool, 0),
294 )
295
296 var finalResponse *provider.ProviderResponse
297 for r := range response {
298 if r.Error != nil {
299 return r.Error
300 }
301 finalResponse = r.Response
302 }
303
304 if finalResponse == nil {
305 return fmt.Errorf("no response received from title provider")
306 }
307
308 title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
309 if title == "" {
310 return nil
311 }
312
313 session.Title = title
314 _, err = a.sessions.Save(ctx, session)
315 return err
316}
317
318func (a *agent) err(err error) AgentEvent {
319 return AgentEvent{
320 Type: AgentEventTypeError,
321 Error: err,
322 }
323}
324
325func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
326 if !a.Model().SupportsImages && attachments != nil {
327 attachments = nil
328 }
329 events := make(chan AgentEvent)
330 if a.IsSessionBusy(sessionID) {
331 return nil, ErrSessionBusy
332 }
333
334 genCtx, cancel := context.WithCancel(ctx)
335
336 a.activeRequests.Store(sessionID, cancel)
337 go func() {
338 logging.Debug("Request started", "sessionID", sessionID)
339 defer logging.RecoverPanic("agent.Run", func() {
340 events <- a.err(fmt.Errorf("panic while running the agent"))
341 })
342 var attachmentParts []message.ContentPart
343 for _, attachment := range attachments {
344 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
345 }
346 result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
347 if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
348 logging.ErrorPersist(result.Error.Error())
349 }
350 logging.Debug("Request completed", "sessionID", sessionID)
351 a.activeRequests.Delete(sessionID)
352 cancel()
353 a.Publish(pubsub.CreatedEvent, result)
354 events <- result
355 close(events)
356 }()
357 return events, nil
358}
359
360func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
361 // List existing messages; if none, start title generation asynchronously.
362 msgs, err := a.messages.List(ctx, sessionID)
363 if err != nil {
364 return a.err(fmt.Errorf("failed to list messages: %w", err))
365 }
366 if len(msgs) == 0 {
367 go func() {
368 defer logging.RecoverPanic("agent.Run", func() {
369 logging.ErrorPersist("panic while generating title")
370 })
371 titleErr := a.generateTitle(context.Background(), sessionID, content)
372 if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
373 logging.ErrorPersist(fmt.Sprintf("failed to generate title: %v", titleErr))
374 }
375 }()
376 }
377 session, err := a.sessions.Get(ctx, sessionID)
378 if err != nil {
379 return a.err(fmt.Errorf("failed to get session: %w", err))
380 }
381 if session.SummaryMessageID != "" {
382 summaryMsgInex := -1
383 for i, msg := range msgs {
384 if msg.ID == session.SummaryMessageID {
385 summaryMsgInex = i
386 break
387 }
388 }
389 if summaryMsgInex != -1 {
390 msgs = msgs[summaryMsgInex:]
391 msgs[0].Role = message.User
392 }
393 }
394
395 userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
396 if err != nil {
397 return a.err(fmt.Errorf("failed to create user message: %w", err))
398 }
399 // Append the new user message to the conversation history.
400 msgHistory := append(msgs, userMsg)
401
402 for {
403 // Check for cancellation before each iteration
404 select {
405 case <-ctx.Done():
406 return a.err(ctx.Err())
407 default:
408 // Continue processing
409 }
410 agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
411 if err != nil {
412 if errors.Is(err, context.Canceled) {
413 agentMessage.AddFinish(message.FinishReasonCanceled)
414 a.messages.Update(context.Background(), agentMessage)
415 return a.err(ErrRequestCancelled)
416 }
417 return a.err(fmt.Errorf("failed to process events: %w", err))
418 }
419 logging.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
420 if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
421 // We are not done, we need to respond with the tool response
422 msgHistory = append(msgHistory, agentMessage, *toolResults)
423 continue
424 }
425 return AgentEvent{
426 Type: AgentEventTypeResponse,
427 Message: agentMessage,
428 Done: true,
429 }
430 }
431}
432
433func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
434 parts := []message.ContentPart{message.TextContent{Text: content}}
435 parts = append(parts, attachmentParts...)
436 return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
437 Role: message.User,
438 Parts: parts,
439 })
440}
441
442func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
443 eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)
444
445 assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
446 Role: message.Assistant,
447 Parts: []message.ContentPart{},
448 Model: a.Model().ID,
449 Provider: a.providerID,
450 })
451 if err != nil {
452 return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
453 }
454
455 // Add the session and message ID into the context if needed by tools.
456 ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
457 ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
458
459 // Process each event in the stream.
460 for event := range eventChan {
461 if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
462 a.finishMessage(ctx, &assistantMsg, message.FinishReasonCanceled)
463 return assistantMsg, nil, processErr
464 }
465 if ctx.Err() != nil {
466 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
467 return assistantMsg, nil, ctx.Err()
468 }
469 }
470
471 toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
472 toolCalls := assistantMsg.ToolCalls()
473 for i, toolCall := range toolCalls {
474 select {
475 case <-ctx.Done():
476 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
477 // Make all future tool calls cancelled
478 for j := i; j < len(toolCalls); j++ {
479 toolResults[j] = message.ToolResult{
480 ToolCallID: toolCalls[j].ID,
481 Content: "Tool execution canceled by user",
482 IsError: true,
483 }
484 }
485 goto out
486 default:
487 // Continue processing
488 var tool tools.BaseTool
489 for _, availableTools := range a.tools {
490 if availableTools.Info().Name == toolCall.Name {
491 tool = availableTools
492 }
493 }
494
495 // Tool not found
496 if tool == nil {
497 toolResults[i] = message.ToolResult{
498 ToolCallID: toolCall.ID,
499 Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
500 IsError: true,
501 }
502 continue
503 }
504 toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
505 ID: toolCall.ID,
506 Name: toolCall.Name,
507 Input: toolCall.Input,
508 })
509 if toolErr != nil {
510 if errors.Is(toolErr, permission.ErrorPermissionDenied) {
511 toolResults[i] = message.ToolResult{
512 ToolCallID: toolCall.ID,
513 Content: "Permission denied",
514 IsError: true,
515 }
516 for j := i + 1; j < len(toolCalls); j++ {
517 toolResults[j] = message.ToolResult{
518 ToolCallID: toolCalls[j].ID,
519 Content: "Tool execution canceled by user",
520 IsError: true,
521 }
522 }
523 a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied)
524 break
525 }
526 }
527 toolResults[i] = message.ToolResult{
528 ToolCallID: toolCall.ID,
529 Content: toolResult.Content,
530 Metadata: toolResult.Metadata,
531 IsError: toolResult.IsError,
532 }
533 }
534 }
535out:
536 if len(toolResults) == 0 {
537 return assistantMsg, nil, nil
538 }
539 parts := make([]message.ContentPart, 0)
540 for _, tr := range toolResults {
541 parts = append(parts, tr)
542 }
543 msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
544 Role: message.Tool,
545 Parts: parts,
546 Provider: a.providerID,
547 })
548 if err != nil {
549 return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
550 }
551
552 return assistantMsg, &msg, err
553}
554
555func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReson message.FinishReason) {
556 msg.AddFinish(finishReson)
557 _ = a.messages.Update(ctx, *msg)
558}
559
560func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
561 select {
562 case <-ctx.Done():
563 return ctx.Err()
564 default:
565 // Continue processing.
566 }
567
568 switch event.Type {
569 case provider.EventThinkingDelta:
570 assistantMsg.AppendReasoningContent(event.Content)
571 return a.messages.Update(ctx, *assistantMsg)
572 case provider.EventContentDelta:
573 assistantMsg.AppendContent(event.Content)
574 return a.messages.Update(ctx, *assistantMsg)
575 case provider.EventToolUseStart:
576 logging.Info("Tool call started", "toolCall", event.ToolCall)
577 assistantMsg.AddToolCall(*event.ToolCall)
578 return a.messages.Update(ctx, *assistantMsg)
579 case provider.EventToolUseDelta:
580 assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
581 return a.messages.Update(ctx, *assistantMsg)
582 case provider.EventToolUseStop:
583 logging.Info("Finished tool call", "toolCall", event.ToolCall)
584 assistantMsg.FinishToolCall(event.ToolCall.ID)
585 return a.messages.Update(ctx, *assistantMsg)
586 case provider.EventError:
587 if errors.Is(event.Error, context.Canceled) {
588 logging.InfoPersist(fmt.Sprintf("Event processing canceled for session: %s", sessionID))
589 return context.Canceled
590 }
591 logging.ErrorPersist(event.Error.Error())
592 return event.Error
593 case provider.EventComplete:
594 assistantMsg.SetToolCalls(event.Response.ToolCalls)
595 assistantMsg.AddFinish(event.Response.FinishReason)
596 if err := a.messages.Update(ctx, *assistantMsg); err != nil {
597 return fmt.Errorf("failed to update message: %w", err)
598 }
599 return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
600 }
601
602 return nil
603}
604
605func (a *agent) TrackUsage(ctx context.Context, sessionID string, model config.Model, usage provider.TokenUsage) error {
606 sess, err := a.sessions.Get(ctx, sessionID)
607 if err != nil {
608 return fmt.Errorf("failed to get session: %w", err)
609 }
610
611 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
612 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
613 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
614 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
615
616 sess.Cost += cost
617 sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
618 sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
619
620 _, err = a.sessions.Save(ctx, sess)
621 if err != nil {
622 return fmt.Errorf("failed to save session: %w", err)
623 }
624 return nil
625}
626
627func (a *agent) Summarize(ctx context.Context, sessionID string) error {
628 if a.summarizeProvider == nil {
629 return fmt.Errorf("summarize provider not available")
630 }
631
632 // Check if session is busy
633 if a.IsSessionBusy(sessionID) {
634 return ErrSessionBusy
635 }
636
637 // Create a new context with cancellation
638 summarizeCtx, cancel := context.WithCancel(ctx)
639
640 // Store the cancel function in activeRequests to allow cancellation
641 a.activeRequests.Store(sessionID+"-summarize", cancel)
642
643 go func() {
644 defer a.activeRequests.Delete(sessionID + "-summarize")
645 defer cancel()
646 event := AgentEvent{
647 Type: AgentEventTypeSummarize,
648 Progress: "Starting summarization...",
649 }
650
651 a.Publish(pubsub.CreatedEvent, event)
652 // Get all messages from the session
653 msgs, err := a.messages.List(summarizeCtx, sessionID)
654 if err != nil {
655 event = AgentEvent{
656 Type: AgentEventTypeError,
657 Error: fmt.Errorf("failed to list messages: %w", err),
658 Done: true,
659 }
660 a.Publish(pubsub.CreatedEvent, event)
661 return
662 }
663
664 if len(msgs) == 0 {
665 event = AgentEvent{
666 Type: AgentEventTypeError,
667 Error: fmt.Errorf("no messages to summarize"),
668 Done: true,
669 }
670 a.Publish(pubsub.CreatedEvent, event)
671 return
672 }
673
674 event = AgentEvent{
675 Type: AgentEventTypeSummarize,
676 Progress: "Analyzing conversation...",
677 }
678 a.Publish(pubsub.CreatedEvent, event)
679
680 // Add a system message to guide the summarization
681 summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
682
683 // Create a new message with the summarize prompt
684 promptMsg := message.Message{
685 Role: message.User,
686 Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
687 }
688
689 // Append the prompt to the messages
690 msgsWithPrompt := append(msgs, promptMsg)
691
692 event = AgentEvent{
693 Type: AgentEventTypeSummarize,
694 Progress: "Generating summary...",
695 }
696
697 a.Publish(pubsub.CreatedEvent, event)
698
699 // Send the messages to the summarize provider
700 response := a.summarizeProvider.StreamResponse(
701 summarizeCtx,
702 msgsWithPrompt,
703 make([]tools.BaseTool, 0),
704 )
705 var finalResponse *provider.ProviderResponse
706 for r := range response {
707 if r.Error != nil {
708 event = AgentEvent{
709 Type: AgentEventTypeError,
710 Error: fmt.Errorf("failed to summarize: %w", err),
711 Done: true,
712 }
713 a.Publish(pubsub.CreatedEvent, event)
714 return
715 }
716 finalResponse = r.Response
717 }
718
719 summary := strings.TrimSpace(finalResponse.Content)
720 if summary == "" {
721 event = AgentEvent{
722 Type: AgentEventTypeError,
723 Error: fmt.Errorf("empty summary returned"),
724 Done: true,
725 }
726 a.Publish(pubsub.CreatedEvent, event)
727 return
728 }
729 event = AgentEvent{
730 Type: AgentEventTypeSummarize,
731 Progress: "Creating new session...",
732 }
733
734 a.Publish(pubsub.CreatedEvent, event)
735 oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
736 if err != nil {
737 event = AgentEvent{
738 Type: AgentEventTypeError,
739 Error: fmt.Errorf("failed to get session: %w", err),
740 Done: true,
741 }
742
743 a.Publish(pubsub.CreatedEvent, event)
744 return
745 }
746 // Create a message in the new session with the summary
747 msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
748 Role: message.Assistant,
749 Parts: []message.ContentPart{
750 message.TextContent{Text: summary},
751 message.Finish{
752 Reason: message.FinishReasonEndTurn,
753 Time: time.Now().Unix(),
754 },
755 },
756 Model: a.summarizeProvider.Model().ID,
757 Provider: a.summarizeProviderID,
758 })
759 if err != nil {
760 event = AgentEvent{
761 Type: AgentEventTypeError,
762 Error: fmt.Errorf("failed to create summary message: %w", err),
763 Done: true,
764 }
765
766 a.Publish(pubsub.CreatedEvent, event)
767 return
768 }
769 oldSession.SummaryMessageID = msg.ID
770 oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
771 oldSession.PromptTokens = 0
772 model := a.summarizeProvider.Model()
773 usage := finalResponse.Usage
774 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
775 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
776 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
777 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
778 oldSession.Cost += cost
779 _, err = a.sessions.Save(summarizeCtx, oldSession)
780 if err != nil {
781 event = AgentEvent{
782 Type: AgentEventTypeError,
783 Error: fmt.Errorf("failed to save session: %w", err),
784 Done: true,
785 }
786 a.Publish(pubsub.CreatedEvent, event)
787 }
788
789 event = AgentEvent{
790 Type: AgentEventTypeSummarize,
791 SessionID: oldSession.ID,
792 Progress: "Summary complete",
793 Done: true,
794 }
795 a.Publish(pubsub.CreatedEvent, event)
796 // Send final success event with the new session ID
797 }()
798
799 return nil
800}
801
802func (a *agent) CancelAll() {
803 a.activeRequests.Range(func(key, value any) bool {
804 a.Cancel(key.(string)) // key is sessionID
805 return true
806 })
807}
808
809func (a *agent) UpdateModel() error {
810 cfg := config.Get()
811
812 // Get current provider configuration
813 currentProviderCfg := config.GetAgentProvider(a.agentCfg.ID)
814 if currentProviderCfg.ID == "" {
815 return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
816 }
817
818 // Check if provider has changed
819 if string(currentProviderCfg.ID) != a.providerID {
820 // Provider changed, need to recreate the main provider
821 model := config.GetAgentModel(a.agentCfg.ID)
822 if model.ID == "" {
823 return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
824 }
825
826 promptID := agentPromptMap[a.agentCfg.ID]
827 if promptID == "" {
828 promptID = prompt.PromptDefault
829 }
830
831 opts := []provider.ProviderClientOption{
832 provider.WithModel(a.agentCfg.Model),
833 provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID)),
834 }
835
836 newProvider, err := provider.NewProvider(currentProviderCfg, opts...)
837 if err != nil {
838 return fmt.Errorf("failed to create new provider: %w", err)
839 }
840
841 // Update the provider and provider ID
842 a.provider = newProvider
843 a.providerID = string(currentProviderCfg.ID)
844 }
845
846 // Check if small model provider has changed (affects title and summarize providers)
847 smallModelCfg := cfg.Models.Small
848 var smallModelProviderCfg config.ProviderConfig
849
850 for _, p := range cfg.Providers {
851 if p.ID == smallModelCfg.Provider {
852 smallModelProviderCfg = p
853 break
854 }
855 }
856
857 if smallModelProviderCfg.ID == "" {
858 return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
859 }
860
861 // Check if summarize provider has changed
862 if string(smallModelProviderCfg.ID) != a.summarizeProviderID {
863 var smallModel config.Model
864 for _, m := range smallModelProviderCfg.Models {
865 if m.ID == smallModelCfg.ModelID {
866 smallModel = m
867 break
868 }
869 }
870 if smallModel.ID == "" {
871 return fmt.Errorf("model %s not found in provider %s", smallModelCfg.ModelID, smallModelProviderCfg.ID)
872 }
873
874 // Recreate title provider
875 titleOpts := []provider.ProviderClientOption{
876 provider.WithModel(config.SmallModel),
877 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
878 // We want the title to be short, so we limit the max tokens
879 provider.WithMaxTokens(40),
880 }
881 newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
882 if err != nil {
883 return fmt.Errorf("failed to create new title provider: %w", err)
884 }
885
886 // Recreate summarize provider
887 summarizeOpts := []provider.ProviderClientOption{
888 provider.WithModel(config.SmallModel),
889 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
890 }
891 newSummarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
892 if err != nil {
893 return fmt.Errorf("failed to create new summarize provider: %w", err)
894 }
895
896 // Update the providers and provider ID
897 a.titleProvider = newTitleProvider
898 a.summarizeProvider = newSummarizeProvider
899 a.summarizeProviderID = string(smallModelProviderCfg.ID)
900 }
901
902 return nil
903}