1package agent
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "slices"
8 "strings"
9 "sync"
10 "time"
11
12 "github.com/charmbracelet/crush/internal/config"
13 "github.com/charmbracelet/crush/internal/history"
14 "github.com/charmbracelet/crush/internal/llm/prompt"
15 "github.com/charmbracelet/crush/internal/llm/provider"
16 "github.com/charmbracelet/crush/internal/llm/tools"
17 "github.com/charmbracelet/crush/internal/logging"
18 "github.com/charmbracelet/crush/internal/lsp"
19 "github.com/charmbracelet/crush/internal/message"
20 "github.com/charmbracelet/crush/internal/permission"
21 "github.com/charmbracelet/crush/internal/pubsub"
22 "github.com/charmbracelet/crush/internal/session"
23)
24
// Sentinel errors returned by the agent service.

// ErrRequestCancelled is reported when a generation stops because its
// context was cancelled.
var ErrRequestCancelled = errors.New("request cancelled by user")

// ErrSessionBusy is returned when a request is submitted for a session that
// already has one in flight.
var ErrSessionBusy = errors.New("session is currently processing another request")
30
// AgentEventType identifies the kind of payload carried by an AgentEvent.
type AgentEventType string

// Event kinds published by the agent.
const (
	// AgentEventTypeError marks an event whose Error field is set.
	AgentEventTypeError = AgentEventType("error")
	// AgentEventTypeResponse marks the final assistant message of a run.
	AgentEventTypeResponse = AgentEventType("response")
	// AgentEventTypeSummarize marks a summarization progress update.
	AgentEventTypeSummarize = AgentEventType("summarize")
)
38
39type AgentEvent struct {
40 Type AgentEventType
41 Message message.Message
42 Error error
43
44 // When summarizing
45 SessionID string
46 Progress string
47 Done bool
48}
49
50type Service interface {
51 pubsub.Suscriber[AgentEvent]
52 Model() config.Model
53 Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
54 Cancel(sessionID string)
55 CancelAll()
56 IsSessionBusy(sessionID string) bool
57 IsBusy() bool
58 Summarize(ctx context.Context, sessionID string) error
59 UpdateModel() error
60}
61
62type agent struct {
63 *pubsub.Broker[AgentEvent]
64 agentCfg config.Agent
65 sessions session.Service
66 messages message.Service
67
68 tools []tools.BaseTool
69 provider provider.Provider
70 providerID string
71
72 titleProvider provider.Provider
73 summarizeProvider provider.Provider
74 summarizeProviderID string
75
76 activeRequests sync.Map
77}
78
79var agentPromptMap = map[config.AgentID]prompt.PromptID{
80 config.AgentCoder: prompt.PromptCoder,
81 config.AgentTask: prompt.PromptTask,
82}
83
84func NewAgent(
85 agentCfg config.Agent,
86 // These services are needed in the tools
87 permissions permission.Service,
88 sessions session.Service,
89 messages message.Service,
90 history history.Service,
91 lspClients map[string]*lsp.Client,
92) (Service, error) {
93 ctx := context.Background()
94 cfg := config.Get()
95 otherTools := GetMcpTools(ctx, permissions)
96 if len(lspClients) > 0 {
97 otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients))
98 }
99
100 allTools := []tools.BaseTool{
101 tools.NewBashTool(permissions),
102 tools.NewEditTool(lspClients, permissions, history),
103 tools.NewFetchTool(permissions),
104 tools.NewGlobTool(),
105 tools.NewGrepTool(),
106 tools.NewLsTool(),
107 tools.NewSourcegraphTool(),
108 tools.NewViewTool(lspClients),
109 tools.NewWriteTool(lspClients, permissions, history),
110 }
111
112 if agentCfg.ID == config.AgentCoder {
113 taskAgentCfg := config.Get().Agents[config.AgentTask]
114 if taskAgentCfg.ID == "" {
115 return nil, fmt.Errorf("task agent not found in config")
116 }
117 taskAgent, err := NewAgent(taskAgentCfg, permissions, sessions, messages, history, lspClients)
118 if err != nil {
119 return nil, fmt.Errorf("failed to create task agent: %w", err)
120 }
121
122 allTools = append(
123 allTools,
124 NewAgentTool(
125 taskAgent,
126 sessions,
127 messages,
128 ),
129 )
130 }
131
132 allTools = append(allTools, otherTools...)
133 providerCfg := config.GetAgentProvider(agentCfg.ID)
134 if providerCfg.ID == "" {
135 return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
136 }
137 model := config.GetAgentModel(agentCfg.ID)
138
139 if model.ID == "" {
140 return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
141 }
142
143 promptID := agentPromptMap[agentCfg.ID]
144 if promptID == "" {
145 promptID = prompt.PromptDefault
146 }
147 opts := []provider.ProviderClientOption{
148 provider.WithModel(agentCfg.Model),
149 provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID)),
150 provider.WithMaxTokens(model.DefaultMaxTokens),
151 }
152 agentProvider, err := provider.NewProviderV2(providerCfg, opts...)
153 if err != nil {
154 return nil, err
155 }
156
157 smallModelCfg := cfg.Models.Small
158 var smallModel config.Model
159
160 var smallModelProviderCfg config.ProviderConfig
161 if smallModelCfg.Provider == providerCfg.ID {
162 smallModelProviderCfg = providerCfg
163 } else {
164 for _, p := range cfg.Providers {
165 if p.ID == smallModelCfg.Provider {
166 smallModelProviderCfg = p
167 break
168 }
169 }
170 if smallModelProviderCfg.ID == "" {
171 return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
172 }
173 }
174 for _, m := range smallModelProviderCfg.Models {
175 if m.ID == smallModelCfg.ModelID {
176 smallModel = m
177 break
178 }
179 }
180 if smallModel.ID == "" {
181 return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.ModelID, smallModelProviderCfg.ID)
182 }
183
184 titleOpts := []provider.ProviderClientOption{
185 provider.WithModel(config.SmallModel),
186 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
187 provider.WithMaxTokens(40),
188 }
189 titleProvider, err := provider.NewProviderV2(smallModelProviderCfg, titleOpts...)
190 if err != nil {
191 return nil, err
192 }
193 summarizeOpts := []provider.ProviderClientOption{
194 provider.WithModel(config.SmallModel),
195 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
196 provider.WithMaxTokens(smallModel.DefaultMaxTokens),
197 }
198 summarizeProvider, err := provider.NewProviderV2(smallModelProviderCfg, summarizeOpts...)
199 if err != nil {
200 return nil, err
201 }
202
203 agentTools := []tools.BaseTool{}
204 if agentCfg.AllowedTools == nil {
205 agentTools = allTools
206 } else {
207 for _, tool := range allTools {
208 if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
209 agentTools = append(agentTools, tool)
210 }
211 }
212 }
213
214 agent := &agent{
215 Broker: pubsub.NewBroker[AgentEvent](),
216 agentCfg: agentCfg,
217 provider: agentProvider,
218 providerID: string(providerCfg.ID),
219 messages: messages,
220 sessions: sessions,
221 tools: agentTools,
222 titleProvider: titleProvider,
223 summarizeProvider: summarizeProvider,
224 summarizeProviderID: string(smallModelProviderCfg.ID),
225 activeRequests: sync.Map{},
226 }
227
228 return agent, nil
229}
230
231func (a *agent) Model() config.Model {
232 return config.GetAgentModel(a.agentCfg.ID)
233}
234
235func (a *agent) Cancel(sessionID string) {
236 // Cancel regular requests
237 if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
238 if cancel, ok := cancelFunc.(context.CancelFunc); ok {
239 logging.InfoPersist(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
240 cancel()
241 }
242 }
243
244 // Also check for summarize requests
245 if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID + "-summarize"); exists {
246 if cancel, ok := cancelFunc.(context.CancelFunc); ok {
247 logging.InfoPersist(fmt.Sprintf("Summarize cancellation initiated for session: %s", sessionID))
248 cancel()
249 }
250 }
251}
252
253func (a *agent) IsBusy() bool {
254 busy := false
255 a.activeRequests.Range(func(key, value any) bool {
256 if cancelFunc, ok := value.(context.CancelFunc); ok {
257 if cancelFunc != nil {
258 busy = true
259 return false // Stop iterating
260 }
261 }
262 return true // Continue iterating
263 })
264 return busy
265}
266
267func (a *agent) IsSessionBusy(sessionID string) bool {
268 _, busy := a.activeRequests.Load(sessionID)
269 return busy
270}
271
272func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
273 if content == "" {
274 return nil
275 }
276 if a.titleProvider == nil {
277 return nil
278 }
279 session, err := a.sessions.Get(ctx, sessionID)
280 if err != nil {
281 return err
282 }
283 parts := []message.ContentPart{message.TextContent{Text: content}}
284
285 // Use streaming approach like summarization
286 response := a.titleProvider.StreamResponse(
287 ctx,
288 []message.Message{
289 {
290 Role: message.User,
291 Parts: parts,
292 },
293 },
294 make([]tools.BaseTool, 0),
295 )
296
297 var finalResponse *provider.ProviderResponse
298 for r := range response {
299 if r.Error != nil {
300 return r.Error
301 }
302 finalResponse = r.Response
303 }
304
305 if finalResponse == nil {
306 return fmt.Errorf("no response received from title provider")
307 }
308
309 title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
310 if title == "" {
311 return nil
312 }
313
314 session.Title = title
315 _, err = a.sessions.Save(ctx, session)
316 return err
317}
318
319func (a *agent) err(err error) AgentEvent {
320 return AgentEvent{
321 Type: AgentEventTypeError,
322 Error: err,
323 }
324}
325
326func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
327 if !a.Model().SupportsImages && attachments != nil {
328 attachments = nil
329 }
330 events := make(chan AgentEvent)
331 if a.IsSessionBusy(sessionID) {
332 return nil, ErrSessionBusy
333 }
334
335 genCtx, cancel := context.WithCancel(ctx)
336
337 a.activeRequests.Store(sessionID, cancel)
338 go func() {
339 logging.Debug("Request started", "sessionID", sessionID)
340 defer logging.RecoverPanic("agent.Run", func() {
341 events <- a.err(fmt.Errorf("panic while running the agent"))
342 })
343 var attachmentParts []message.ContentPart
344 for _, attachment := range attachments {
345 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
346 }
347 result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
348 if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
349 logging.ErrorPersist(result.Error.Error())
350 }
351 logging.Debug("Request completed", "sessionID", sessionID)
352 a.activeRequests.Delete(sessionID)
353 cancel()
354 a.Publish(pubsub.CreatedEvent, result)
355 events <- result
356 close(events)
357 }()
358 return events, nil
359}
360
361func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
362 // List existing messages; if none, start title generation asynchronously.
363 msgs, err := a.messages.List(ctx, sessionID)
364 if err != nil {
365 return a.err(fmt.Errorf("failed to list messages: %w", err))
366 }
367 if len(msgs) == 0 {
368 go func() {
369 defer logging.RecoverPanic("agent.Run", func() {
370 logging.ErrorPersist("panic while generating title")
371 })
372 titleErr := a.generateTitle(context.Background(), sessionID, content)
373 if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
374 logging.ErrorPersist(fmt.Sprintf("failed to generate title: %v", titleErr))
375 }
376 }()
377 }
378 session, err := a.sessions.Get(ctx, sessionID)
379 if err != nil {
380 return a.err(fmt.Errorf("failed to get session: %w", err))
381 }
382 if session.SummaryMessageID != "" {
383 summaryMsgInex := -1
384 for i, msg := range msgs {
385 if msg.ID == session.SummaryMessageID {
386 summaryMsgInex = i
387 break
388 }
389 }
390 if summaryMsgInex != -1 {
391 msgs = msgs[summaryMsgInex:]
392 msgs[0].Role = message.User
393 }
394 }
395
396 userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
397 if err != nil {
398 return a.err(fmt.Errorf("failed to create user message: %w", err))
399 }
400 // Append the new user message to the conversation history.
401 msgHistory := append(msgs, userMsg)
402
403 for {
404 // Check for cancellation before each iteration
405 select {
406 case <-ctx.Done():
407 return a.err(ctx.Err())
408 default:
409 // Continue processing
410 }
411 agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
412 if err != nil {
413 if errors.Is(err, context.Canceled) {
414 agentMessage.AddFinish(message.FinishReasonCanceled)
415 a.messages.Update(context.Background(), agentMessage)
416 return a.err(ErrRequestCancelled)
417 }
418 return a.err(fmt.Errorf("failed to process events: %w", err))
419 }
420 logging.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
421 if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
422 // We are not done, we need to respond with the tool response
423 msgHistory = append(msgHistory, agentMessage, *toolResults)
424 continue
425 }
426 return AgentEvent{
427 Type: AgentEventTypeResponse,
428 Message: agentMessage,
429 Done: true,
430 }
431 }
432}
433
434func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
435 parts := []message.ContentPart{message.TextContent{Text: content}}
436 parts = append(parts, attachmentParts...)
437 return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
438 Role: message.User,
439 Parts: parts,
440 })
441}
442
443func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
444 eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)
445
446 assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
447 Role: message.Assistant,
448 Parts: []message.ContentPart{},
449 Model: a.Model().ID,
450 Provider: a.providerID,
451 })
452 if err != nil {
453 return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
454 }
455
456 // Add the session and message ID into the context if needed by tools.
457 ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
458 ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
459
460 // Process each event in the stream.
461 for event := range eventChan {
462 if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
463 a.finishMessage(ctx, &assistantMsg, message.FinishReasonCanceled)
464 return assistantMsg, nil, processErr
465 }
466 if ctx.Err() != nil {
467 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
468 return assistantMsg, nil, ctx.Err()
469 }
470 }
471
472 toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
473 toolCalls := assistantMsg.ToolCalls()
474 for i, toolCall := range toolCalls {
475 select {
476 case <-ctx.Done():
477 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
478 // Make all future tool calls cancelled
479 for j := i; j < len(toolCalls); j++ {
480 toolResults[j] = message.ToolResult{
481 ToolCallID: toolCalls[j].ID,
482 Content: "Tool execution canceled by user",
483 IsError: true,
484 }
485 }
486 goto out
487 default:
488 // Continue processing
489 var tool tools.BaseTool
490 for _, availableTools := range a.tools {
491 if availableTools.Info().Name == toolCall.Name {
492 tool = availableTools
493 }
494 }
495
496 // Tool not found
497 if tool == nil {
498 toolResults[i] = message.ToolResult{
499 ToolCallID: toolCall.ID,
500 Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
501 IsError: true,
502 }
503 continue
504 }
505 toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
506 ID: toolCall.ID,
507 Name: toolCall.Name,
508 Input: toolCall.Input,
509 })
510 if toolErr != nil {
511 if errors.Is(toolErr, permission.ErrorPermissionDenied) {
512 toolResults[i] = message.ToolResult{
513 ToolCallID: toolCall.ID,
514 Content: "Permission denied",
515 IsError: true,
516 }
517 for j := i + 1; j < len(toolCalls); j++ {
518 toolResults[j] = message.ToolResult{
519 ToolCallID: toolCalls[j].ID,
520 Content: "Tool execution canceled by user",
521 IsError: true,
522 }
523 }
524 a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied)
525 break
526 }
527 }
528 toolResults[i] = message.ToolResult{
529 ToolCallID: toolCall.ID,
530 Content: toolResult.Content,
531 Metadata: toolResult.Metadata,
532 IsError: toolResult.IsError,
533 }
534 }
535 }
536out:
537 if len(toolResults) == 0 {
538 return assistantMsg, nil, nil
539 }
540 parts := make([]message.ContentPart, 0)
541 for _, tr := range toolResults {
542 parts = append(parts, tr)
543 }
544 msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
545 Role: message.Tool,
546 Parts: parts,
547 Provider: a.providerID,
548 })
549 if err != nil {
550 return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
551 }
552
553 return assistantMsg, &msg, err
554}
555
556func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReson message.FinishReason) {
557 msg.AddFinish(finishReson)
558 _ = a.messages.Update(ctx, *msg)
559}
560
561func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
562 select {
563 case <-ctx.Done():
564 return ctx.Err()
565 default:
566 // Continue processing.
567 }
568
569 switch event.Type {
570 case provider.EventThinkingDelta:
571 assistantMsg.AppendReasoningContent(event.Content)
572 return a.messages.Update(ctx, *assistantMsg)
573 case provider.EventContentDelta:
574 assistantMsg.AppendContent(event.Content)
575 return a.messages.Update(ctx, *assistantMsg)
576 case provider.EventToolUseStart:
577 logging.Info("Tool call started", "toolCall", event.ToolCall)
578 assistantMsg.AddToolCall(*event.ToolCall)
579 return a.messages.Update(ctx, *assistantMsg)
580 case provider.EventToolUseDelta:
581 assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
582 return a.messages.Update(ctx, *assistantMsg)
583 case provider.EventToolUseStop:
584 logging.Info("Finished tool call", "toolCall", event.ToolCall)
585 assistantMsg.FinishToolCall(event.ToolCall.ID)
586 return a.messages.Update(ctx, *assistantMsg)
587 case provider.EventError:
588 if errors.Is(event.Error, context.Canceled) {
589 logging.InfoPersist(fmt.Sprintf("Event processing canceled for session: %s", sessionID))
590 return context.Canceled
591 }
592 logging.ErrorPersist(event.Error.Error())
593 return event.Error
594 case provider.EventComplete:
595 assistantMsg.SetToolCalls(event.Response.ToolCalls)
596 assistantMsg.AddFinish(event.Response.FinishReason)
597 if err := a.messages.Update(ctx, *assistantMsg); err != nil {
598 return fmt.Errorf("failed to update message: %w", err)
599 }
600 return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
601 }
602
603 return nil
604}
605
606func (a *agent) TrackUsage(ctx context.Context, sessionID string, model config.Model, usage provider.TokenUsage) error {
607 sess, err := a.sessions.Get(ctx, sessionID)
608 if err != nil {
609 return fmt.Errorf("failed to get session: %w", err)
610 }
611
612 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
613 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
614 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
615 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
616
617 sess.Cost += cost
618 sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
619 sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
620
621 _, err = a.sessions.Save(ctx, sess)
622 if err != nil {
623 return fmt.Errorf("failed to save session: %w", err)
624 }
625 return nil
626}
627
628func (a *agent) Summarize(ctx context.Context, sessionID string) error {
629 if a.summarizeProvider == nil {
630 return fmt.Errorf("summarize provider not available")
631 }
632
633 // Check if session is busy
634 if a.IsSessionBusy(sessionID) {
635 return ErrSessionBusy
636 }
637
638 // Create a new context with cancellation
639 summarizeCtx, cancel := context.WithCancel(ctx)
640
641 // Store the cancel function in activeRequests to allow cancellation
642 a.activeRequests.Store(sessionID+"-summarize", cancel)
643
644 go func() {
645 defer a.activeRequests.Delete(sessionID + "-summarize")
646 defer cancel()
647 event := AgentEvent{
648 Type: AgentEventTypeSummarize,
649 Progress: "Starting summarization...",
650 }
651
652 a.Publish(pubsub.CreatedEvent, event)
653 // Get all messages from the session
654 msgs, err := a.messages.List(summarizeCtx, sessionID)
655 if err != nil {
656 event = AgentEvent{
657 Type: AgentEventTypeError,
658 Error: fmt.Errorf("failed to list messages: %w", err),
659 Done: true,
660 }
661 a.Publish(pubsub.CreatedEvent, event)
662 return
663 }
664
665 if len(msgs) == 0 {
666 event = AgentEvent{
667 Type: AgentEventTypeError,
668 Error: fmt.Errorf("no messages to summarize"),
669 Done: true,
670 }
671 a.Publish(pubsub.CreatedEvent, event)
672 return
673 }
674
675 event = AgentEvent{
676 Type: AgentEventTypeSummarize,
677 Progress: "Analyzing conversation...",
678 }
679 a.Publish(pubsub.CreatedEvent, event)
680
681 // Add a system message to guide the summarization
682 summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
683
684 // Create a new message with the summarize prompt
685 promptMsg := message.Message{
686 Role: message.User,
687 Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
688 }
689
690 // Append the prompt to the messages
691 msgsWithPrompt := append(msgs, promptMsg)
692
693 event = AgentEvent{
694 Type: AgentEventTypeSummarize,
695 Progress: "Generating summary...",
696 }
697
698 a.Publish(pubsub.CreatedEvent, event)
699
700 // Send the messages to the summarize provider
701 response := a.summarizeProvider.StreamResponse(
702 summarizeCtx,
703 msgsWithPrompt,
704 make([]tools.BaseTool, 0),
705 )
706 var finalResponse *provider.ProviderResponse
707 for r := range response {
708 if r.Error != nil {
709 event = AgentEvent{
710 Type: AgentEventTypeError,
711 Error: fmt.Errorf("failed to summarize: %w", err),
712 Done: true,
713 }
714 a.Publish(pubsub.CreatedEvent, event)
715 return
716 }
717 finalResponse = r.Response
718 }
719
720 summary := strings.TrimSpace(finalResponse.Content)
721 if summary == "" {
722 event = AgentEvent{
723 Type: AgentEventTypeError,
724 Error: fmt.Errorf("empty summary returned"),
725 Done: true,
726 }
727 a.Publish(pubsub.CreatedEvent, event)
728 return
729 }
730 event = AgentEvent{
731 Type: AgentEventTypeSummarize,
732 Progress: "Creating new session...",
733 }
734
735 a.Publish(pubsub.CreatedEvent, event)
736 oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
737 if err != nil {
738 event = AgentEvent{
739 Type: AgentEventTypeError,
740 Error: fmt.Errorf("failed to get session: %w", err),
741 Done: true,
742 }
743
744 a.Publish(pubsub.CreatedEvent, event)
745 return
746 }
747 // Create a message in the new session with the summary
748 msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
749 Role: message.Assistant,
750 Parts: []message.ContentPart{
751 message.TextContent{Text: summary},
752 message.Finish{
753 Reason: message.FinishReasonEndTurn,
754 Time: time.Now().Unix(),
755 },
756 },
757 Model: a.summarizeProvider.Model().ID,
758 Provider: a.summarizeProviderID,
759 })
760 if err != nil {
761 event = AgentEvent{
762 Type: AgentEventTypeError,
763 Error: fmt.Errorf("failed to create summary message: %w", err),
764 Done: true,
765 }
766
767 a.Publish(pubsub.CreatedEvent, event)
768 return
769 }
770 oldSession.SummaryMessageID = msg.ID
771 oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
772 oldSession.PromptTokens = 0
773 model := a.summarizeProvider.Model()
774 usage := finalResponse.Usage
775 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
776 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
777 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
778 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
779 oldSession.Cost += cost
780 _, err = a.sessions.Save(summarizeCtx, oldSession)
781 if err != nil {
782 event = AgentEvent{
783 Type: AgentEventTypeError,
784 Error: fmt.Errorf("failed to save session: %w", err),
785 Done: true,
786 }
787 a.Publish(pubsub.CreatedEvent, event)
788 }
789
790 event = AgentEvent{
791 Type: AgentEventTypeSummarize,
792 SessionID: oldSession.ID,
793 Progress: "Summary complete",
794 Done: true,
795 }
796 a.Publish(pubsub.CreatedEvent, event)
797 // Send final success event with the new session ID
798 }()
799
800 return nil
801}
802
803func (a *agent) CancelAll() {
804 a.activeRequests.Range(func(key, value any) bool {
805 a.Cancel(key.(string)) // key is sessionID
806 return true
807 })
808}
809
810func (a *agent) UpdateModel() error {
811 cfg := config.Get()
812
813 // Get current provider configuration
814 currentProviderCfg := config.GetAgentProvider(a.agentCfg.ID)
815 if currentProviderCfg.ID == "" {
816 return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
817 }
818
819 // Check if provider has changed
820 if string(currentProviderCfg.ID) != a.providerID {
821 // Provider changed, need to recreate the main provider
822 model := config.GetAgentModel(a.agentCfg.ID)
823 if model.ID == "" {
824 return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
825 }
826
827 promptID := agentPromptMap[a.agentCfg.ID]
828 if promptID == "" {
829 promptID = prompt.PromptDefault
830 }
831
832 opts := []provider.ProviderClientOption{
833 provider.WithModel(a.agentCfg.Model),
834 provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID)),
835 provider.WithMaxTokens(model.DefaultMaxTokens),
836 }
837
838 newProvider, err := provider.NewProviderV2(currentProviderCfg, opts...)
839 if err != nil {
840 return fmt.Errorf("failed to create new provider: %w", err)
841 }
842
843 // Update the provider and provider ID
844 a.provider = newProvider
845 a.providerID = string(currentProviderCfg.ID)
846 }
847
848 // Check if small model provider has changed (affects title and summarize providers)
849 smallModelCfg := cfg.Models.Small
850 var smallModelProviderCfg config.ProviderConfig
851
852 for _, p := range cfg.Providers {
853 if p.ID == smallModelCfg.Provider {
854 smallModelProviderCfg = p
855 break
856 }
857 }
858
859 if smallModelProviderCfg.ID == "" {
860 return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
861 }
862
863 // Check if summarize provider has changed
864 if string(smallModelProviderCfg.ID) != a.summarizeProviderID {
865 var smallModel config.Model
866 for _, m := range smallModelProviderCfg.Models {
867 if m.ID == smallModelCfg.ModelID {
868 smallModel = m
869 break
870 }
871 }
872 if smallModel.ID == "" {
873 return fmt.Errorf("model %s not found in provider %s", smallModelCfg.ModelID, smallModelProviderCfg.ID)
874 }
875
876 // Recreate title provider
877 titleOpts := []provider.ProviderClientOption{
878 provider.WithModel(config.SmallModel),
879 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
880 provider.WithMaxTokens(40),
881 }
882 newTitleProvider, err := provider.NewProviderV2(smallModelProviderCfg, titleOpts...)
883 if err != nil {
884 return fmt.Errorf("failed to create new title provider: %w", err)
885 }
886
887 // Recreate summarize provider
888 summarizeOpts := []provider.ProviderClientOption{
889 provider.WithModel(config.SmallModel),
890 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
891 provider.WithMaxTokens(smallModel.DefaultMaxTokens),
892 }
893 newSummarizeProvider, err := provider.NewProviderV2(smallModelProviderCfg, summarizeOpts...)
894 if err != nil {
895 return fmt.Errorf("failed to create new summarize provider: %w", err)
896 }
897
898 // Update the providers and provider ID
899 a.titleProvider = newTitleProvider
900 a.summarizeProvider = newSummarizeProvider
901 a.summarizeProviderID = string(smallModelProviderCfg.ID)
902 }
903
904 return nil
905}