1package agent
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "log/slog"
8 "slices"
9 "strings"
10 "sync"
11 "time"
12
13 "github.com/charmbracelet/crush/internal/config"
14 fur "github.com/charmbracelet/crush/internal/fur/provider"
15 "github.com/charmbracelet/crush/internal/history"
16 "github.com/charmbracelet/crush/internal/llm/prompt"
17 "github.com/charmbracelet/crush/internal/llm/provider"
18 "github.com/charmbracelet/crush/internal/llm/tools"
19 "github.com/charmbracelet/crush/internal/log"
20 "github.com/charmbracelet/crush/internal/lsp"
21 "github.com/charmbracelet/crush/internal/message"
22 "github.com/charmbracelet/crush/internal/permission"
23 "github.com/charmbracelet/crush/internal/pubsub"
24 "github.com/charmbracelet/crush/internal/session"
25)
26
// ErrRequestCancelled is reported when an in-flight request is aborted via
// Cancel (its context is cancelled).
var ErrRequestCancelled = errors.New("request cancelled by user")

// ErrSessionBusy is returned when a request is started for a session that
// already has one in flight.
var ErrSessionBusy = errors.New("session is currently processing another request")
32
// AgentEventType classifies the events an agent emits.
type AgentEventType string

// Values for AgentEvent.Type.
const (
	// AgentEventTypeError marks an event whose Error field describes a failure.
	AgentEventTypeError AgentEventType = "error"
	// AgentEventTypeResponse marks the final assistant message of a Run.
	AgentEventTypeResponse AgentEventType = "response"
	// AgentEventTypeSummarize marks a Summarize progress or completion event.
	AgentEventTypeSummarize AgentEventType = "summarize"
)
40
// AgentEvent is the value sent on the channel returned by Run and published
// on the agent's broker by Run and Summarize.
type AgentEvent struct {
	// Type tells which of the fields below are populated.
	Type AgentEventType
	// Message holds the final assistant message of a Run
	// (AgentEventTypeResponse events).
	Message message.Message
	// Error is set on AgentEventTypeError events.
	Error error

	// Fields used mainly by summarize events.

	// SessionID is set on the final, successful summarize event.
	SessionID string
	// Progress is a human-readable status line on summarize events.
	Progress string
	// Done is set on the last event of a Summarize (success or error) and
	// also on Run's AgentEventTypeResponse event.
	Done bool
}
51
// Service is the public surface of an agent: running prompts against a
// session, cancellation and busy-state queries, session summarization, model
// updates, and subscription to the AgentEvents it publishes. The method notes
// below describe the agent implementation in this package.
type Service interface {
	pubsub.Suscriber[AgentEvent]
	// Model returns the model currently configured for this agent.
	Model() fur.Model
	// Run processes content (plus optional attachments) as a new user turn in
	// sessionID. The returned channel receives the resulting event and is then
	// closed; ErrSessionBusy is returned if the session already has a Run in
	// flight.
	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
	// Cancel aborts the in-flight Run and/or Summarize for sessionID.
	Cancel(sessionID string)
	// CancelAll aborts every in-flight request.
	CancelAll()
	// IsSessionBusy reports whether a Run is in flight for sessionID.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any request is in flight.
	IsBusy() bool
	// Summarize starts summarizing sessionID in the background; progress and
	// results are delivered as published events.
	Summarize(ctx context.Context, sessionID string) error
	// UpdateModel rebuilds providers after the model configuration changed.
	UpdateModel() error
}
63
// agent implements Service on top of chat providers and the session and
// message stores.
//
// NOTE(review): provider, providerID, titleProvider, summarizeProvider and
// summarizeProviderID are reassigned by UpdateModel without any locking while
// request goroutines may read them — confirm this is acceptable.
type agent struct {
	// Broker supplies Publish and the subscriber side of Service.
	*pubsub.Broker[AgentEvent]
	agentCfg config.Agent
	sessions session.Service
	messages message.Service

	// tools are offered to the main provider; NewAgent has already filtered
	// them by agentCfg.AllowedTools.
	tools      []tools.BaseTool
	provider   provider.Provider
	providerID string

	// The title and summarize providers run on the small model;
	// summarizeProviderID is the ID of the provider config both were built from.
	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests maps a session ID (Run) or a session ID with a
	// "-summarize" suffix (Summarize) to the context.CancelFunc of the
	// in-flight request.
	activeRequests sync.Map
}
80
81var agentPromptMap = map[string]prompt.PromptID{
82 "coder": prompt.PromptCoder,
83 "task": prompt.PromptTask,
84}
85
86func NewAgent(
87 agentCfg config.Agent,
88 // These services are needed in the tools
89 permissions permission.Service,
90 sessions session.Service,
91 messages message.Service,
92 history history.Service,
93 lspClients map[string]*lsp.Client,
94) (Service, error) {
95 ctx := context.Background()
96 cfg := config.Get()
97 otherTools := GetMcpTools(ctx, permissions)
98 if len(lspClients) > 0 {
99 otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients))
100 }
101
102 allTools := []tools.BaseTool{
103 tools.NewBashTool(permissions),
104 tools.NewEditTool(lspClients, permissions, history),
105 tools.NewFetchTool(permissions),
106 tools.NewGlobTool(),
107 tools.NewGrepTool(),
108 tools.NewLsTool(),
109 tools.NewSourcegraphTool(),
110 tools.NewViewTool(lspClients),
111 tools.NewWriteTool(lspClients, permissions, history),
112 }
113
114 if agentCfg.ID == "coder" {
115 taskAgentCfg := config.Get().Agents["task"]
116 if taskAgentCfg.ID == "" {
117 return nil, fmt.Errorf("task agent not found in config")
118 }
119 taskAgent, err := NewAgent(taskAgentCfg, permissions, sessions, messages, history, lspClients)
120 if err != nil {
121 return nil, fmt.Errorf("failed to create task agent: %w", err)
122 }
123
124 allTools = append(
125 allTools,
126 NewAgentTool(
127 taskAgent,
128 sessions,
129 messages,
130 ),
131 )
132 }
133
134 allTools = append(allTools, otherTools...)
135 providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
136 if providerCfg == nil {
137 return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
138 }
139 model := config.Get().GetModelByType(agentCfg.Model)
140
141 if model == nil {
142 return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
143 }
144
145 promptID := agentPromptMap[agentCfg.ID]
146 if promptID == "" {
147 promptID = prompt.PromptDefault
148 }
149 opts := []provider.ProviderClientOption{
150 provider.WithModel(agentCfg.Model),
151 provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID)),
152 }
153 agentProvider, err := provider.NewProvider(*providerCfg, opts...)
154 if err != nil {
155 return nil, err
156 }
157
158 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
159 var smallModelProviderCfg *config.ProviderConfig
160 if smallModelCfg.Provider == providerCfg.ID {
161 smallModelProviderCfg = providerCfg
162 } else {
163 smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
164
165 if smallModelProviderCfg.ID == "" {
166 return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
167 }
168 }
169 smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
170 if smallModel.ID == "" {
171 return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
172 }
173
174 titleOpts := []provider.ProviderClientOption{
175 provider.WithModel(config.SelectedModelTypeSmall),
176 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
177 }
178 titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
179 if err != nil {
180 return nil, err
181 }
182 summarizeOpts := []provider.ProviderClientOption{
183 provider.WithModel(config.SelectedModelTypeSmall),
184 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
185 }
186 summarizeProvider, err := provider.NewProvider(*smallModelProviderCfg, summarizeOpts...)
187 if err != nil {
188 return nil, err
189 }
190
191 agentTools := []tools.BaseTool{}
192 if agentCfg.AllowedTools == nil {
193 agentTools = allTools
194 } else {
195 for _, tool := range allTools {
196 if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
197 agentTools = append(agentTools, tool)
198 }
199 }
200 }
201
202 agent := &agent{
203 Broker: pubsub.NewBroker[AgentEvent](),
204 agentCfg: agentCfg,
205 provider: agentProvider,
206 providerID: string(providerCfg.ID),
207 messages: messages,
208 sessions: sessions,
209 tools: agentTools,
210 titleProvider: titleProvider,
211 summarizeProvider: summarizeProvider,
212 summarizeProviderID: string(smallModelProviderCfg.ID),
213 activeRequests: sync.Map{},
214 }
215
216 return agent, nil
217}
218
219func (a *agent) Model() fur.Model {
220 return *config.Get().GetModelByType(a.agentCfg.Model)
221}
222
223func (a *agent) Cancel(sessionID string) {
224 // Cancel regular requests
225 if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
226 if cancel, ok := cancelFunc.(context.CancelFunc); ok {
227 slog.Info(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
228 cancel()
229 }
230 }
231
232 // Also check for summarize requests
233 if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID + "-summarize"); exists {
234 if cancel, ok := cancelFunc.(context.CancelFunc); ok {
235 slog.Info(fmt.Sprintf("Summarize cancellation initiated for session: %s", sessionID))
236 cancel()
237 }
238 }
239}
240
241func (a *agent) IsBusy() bool {
242 busy := false
243 a.activeRequests.Range(func(key, value any) bool {
244 if cancelFunc, ok := value.(context.CancelFunc); ok {
245 if cancelFunc != nil {
246 busy = true
247 return false
248 }
249 }
250 return true
251 })
252 return busy
253}
254
255func (a *agent) IsSessionBusy(sessionID string) bool {
256 _, busy := a.activeRequests.Load(sessionID)
257 return busy
258}
259
260func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
261 if content == "" {
262 return nil
263 }
264 if a.titleProvider == nil {
265 return nil
266 }
267 session, err := a.sessions.Get(ctx, sessionID)
268 if err != nil {
269 return err
270 }
271 parts := []message.ContentPart{message.TextContent{
272 Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
273 }}
274
275 // Use streaming approach like summarization
276 response := a.titleProvider.StreamResponse(
277 ctx,
278 []message.Message{
279 {
280 Role: message.User,
281 Parts: parts,
282 },
283 },
284 make([]tools.BaseTool, 0),
285 )
286
287 var finalResponse *provider.ProviderResponse
288 for r := range response {
289 if r.Error != nil {
290 return r.Error
291 }
292 finalResponse = r.Response
293 }
294
295 if finalResponse == nil {
296 return fmt.Errorf("no response received from title provider")
297 }
298
299 title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
300 if title == "" {
301 return nil
302 }
303
304 session.Title = title
305 _, err = a.sessions.Save(ctx, session)
306 return err
307}
308
309func (a *agent) err(err error) AgentEvent {
310 return AgentEvent{
311 Type: AgentEventTypeError,
312 Error: err,
313 }
314}
315
316func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
317 if !a.Model().SupportsImages && attachments != nil {
318 attachments = nil
319 }
320 events := make(chan AgentEvent)
321 if a.IsSessionBusy(sessionID) {
322 return nil, ErrSessionBusy
323 }
324
325 genCtx, cancel := context.WithCancel(ctx)
326
327 a.activeRequests.Store(sessionID, cancel)
328 go func() {
329 slog.Debug("Request started", "sessionID", sessionID)
330 defer log.RecoverPanic("agent.Run", func() {
331 events <- a.err(fmt.Errorf("panic while running the agent"))
332 })
333 var attachmentParts []message.ContentPart
334 for _, attachment := range attachments {
335 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
336 }
337 result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
338 if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
339 slog.Error(result.Error.Error())
340 }
341 slog.Debug("Request completed", "sessionID", sessionID)
342 a.activeRequests.Delete(sessionID)
343 cancel()
344 a.Publish(pubsub.CreatedEvent, result)
345 events <- result
346 close(events)
347 }()
348 return events, nil
349}
350
351func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
352 cfg := config.Get()
353 // List existing messages; if none, start title generation asynchronously.
354 msgs, err := a.messages.List(ctx, sessionID)
355 if err != nil {
356 return a.err(fmt.Errorf("failed to list messages: %w", err))
357 }
358 if len(msgs) == 0 {
359 go func() {
360 defer log.RecoverPanic("agent.Run", func() {
361 slog.Error("panic while generating title")
362 })
363 titleErr := a.generateTitle(context.Background(), sessionID, content)
364 if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
365 slog.Error(fmt.Sprintf("failed to generate title: %v", titleErr))
366 }
367 }()
368 }
369 session, err := a.sessions.Get(ctx, sessionID)
370 if err != nil {
371 return a.err(fmt.Errorf("failed to get session: %w", err))
372 }
373 if session.SummaryMessageID != "" {
374 summaryMsgInex := -1
375 for i, msg := range msgs {
376 if msg.ID == session.SummaryMessageID {
377 summaryMsgInex = i
378 break
379 }
380 }
381 if summaryMsgInex != -1 {
382 msgs = msgs[summaryMsgInex:]
383 msgs[0].Role = message.User
384 }
385 }
386
387 userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
388 if err != nil {
389 return a.err(fmt.Errorf("failed to create user message: %w", err))
390 }
391 // Append the new user message to the conversation history.
392 msgHistory := append(msgs, userMsg)
393
394 for {
395 // Check for cancellation before each iteration
396 select {
397 case <-ctx.Done():
398 return a.err(ctx.Err())
399 default:
400 // Continue processing
401 }
402 agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
403 if err != nil {
404 if errors.Is(err, context.Canceled) {
405 agentMessage.AddFinish(message.FinishReasonCanceled)
406 a.messages.Update(context.Background(), agentMessage)
407 return a.err(ErrRequestCancelled)
408 }
409 return a.err(fmt.Errorf("failed to process events: %w", err))
410 }
411 if cfg.Options.Debug {
412 slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
413 }
414 if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
415 // We are not done, we need to respond with the tool response
416 msgHistory = append(msgHistory, agentMessage, *toolResults)
417 continue
418 }
419 return AgentEvent{
420 Type: AgentEventTypeResponse,
421 Message: agentMessage,
422 Done: true,
423 }
424 }
425}
426
427func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
428 parts := []message.ContentPart{message.TextContent{Text: content}}
429 parts = append(parts, attachmentParts...)
430 return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
431 Role: message.User,
432 Parts: parts,
433 })
434}
435
436func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
437 ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
438 eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)
439
440 assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
441 Role: message.Assistant,
442 Parts: []message.ContentPart{},
443 Model: a.Model().ID,
444 Provider: a.providerID,
445 })
446 if err != nil {
447 return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
448 }
449
450 // Add the session and message ID into the context if needed by tools.
451 ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
452
453 // Process each event in the stream.
454 for event := range eventChan {
455 if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
456 a.finishMessage(ctx, &assistantMsg, message.FinishReasonCanceled)
457 return assistantMsg, nil, processErr
458 }
459 if ctx.Err() != nil {
460 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
461 return assistantMsg, nil, ctx.Err()
462 }
463 }
464
465 toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
466 toolCalls := assistantMsg.ToolCalls()
467 for i, toolCall := range toolCalls {
468 select {
469 case <-ctx.Done():
470 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
471 // Make all future tool calls cancelled
472 for j := i; j < len(toolCalls); j++ {
473 toolResults[j] = message.ToolResult{
474 ToolCallID: toolCalls[j].ID,
475 Content: "Tool execution canceled by user",
476 IsError: true,
477 }
478 }
479 goto out
480 default:
481 // Continue processing
482 var tool tools.BaseTool
483 for _, availableTool := range a.tools {
484 if availableTool.Info().Name == toolCall.Name {
485 tool = availableTool
486 break
487 }
488 }
489
490 // Tool not found
491 if tool == nil {
492 toolResults[i] = message.ToolResult{
493 ToolCallID: toolCall.ID,
494 Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
495 IsError: true,
496 }
497 continue
498 }
499 toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
500 ID: toolCall.ID,
501 Name: toolCall.Name,
502 Input: toolCall.Input,
503 })
504 if toolErr != nil {
505 if errors.Is(toolErr, permission.ErrorPermissionDenied) {
506 toolResults[i] = message.ToolResult{
507 ToolCallID: toolCall.ID,
508 Content: "Permission denied",
509 IsError: true,
510 }
511 for j := i + 1; j < len(toolCalls); j++ {
512 toolResults[j] = message.ToolResult{
513 ToolCallID: toolCalls[j].ID,
514 Content: "Tool execution canceled by user",
515 IsError: true,
516 }
517 }
518 a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied)
519 break
520 }
521 }
522 toolResults[i] = message.ToolResult{
523 ToolCallID: toolCall.ID,
524 Content: toolResult.Content,
525 Metadata: toolResult.Metadata,
526 IsError: toolResult.IsError,
527 }
528 }
529 }
530out:
531 if len(toolResults) == 0 {
532 return assistantMsg, nil, nil
533 }
534 parts := make([]message.ContentPart, 0)
535 for _, tr := range toolResults {
536 parts = append(parts, tr)
537 }
538 msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
539 Role: message.Tool,
540 Parts: parts,
541 Provider: a.providerID,
542 })
543 if err != nil {
544 return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
545 }
546
547 return assistantMsg, &msg, err
548}
549
550func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReson message.FinishReason) {
551 msg.AddFinish(finishReson)
552 _ = a.messages.Update(ctx, *msg)
553}
554
555func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
556 select {
557 case <-ctx.Done():
558 return ctx.Err()
559 default:
560 // Continue processing.
561 }
562
563 switch event.Type {
564 case provider.EventThinkingDelta:
565 assistantMsg.AppendReasoningContent(event.Content)
566 return a.messages.Update(ctx, *assistantMsg)
567 case provider.EventContentDelta:
568 assistantMsg.AppendContent(event.Content)
569 return a.messages.Update(ctx, *assistantMsg)
570 case provider.EventToolUseStart:
571 slog.Info("Tool call started", "toolCall", event.ToolCall)
572 assistantMsg.AddToolCall(*event.ToolCall)
573 return a.messages.Update(ctx, *assistantMsg)
574 case provider.EventToolUseDelta:
575 assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
576 return a.messages.Update(ctx, *assistantMsg)
577 case provider.EventToolUseStop:
578 slog.Info("Finished tool call", "toolCall", event.ToolCall)
579 assistantMsg.FinishToolCall(event.ToolCall.ID)
580 return a.messages.Update(ctx, *assistantMsg)
581 case provider.EventError:
582 if errors.Is(event.Error, context.Canceled) {
583 slog.Info(fmt.Sprintf("Event processing canceled for session: %s", sessionID))
584 return context.Canceled
585 }
586 slog.Error(event.Error.Error())
587 return event.Error
588 case provider.EventComplete:
589 assistantMsg.SetToolCalls(event.Response.ToolCalls)
590 assistantMsg.AddFinish(event.Response.FinishReason)
591 if err := a.messages.Update(ctx, *assistantMsg); err != nil {
592 return fmt.Errorf("failed to update message: %w", err)
593 }
594 return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
595 }
596
597 return nil
598}
599
600func (a *agent) TrackUsage(ctx context.Context, sessionID string, model fur.Model, usage provider.TokenUsage) error {
601 sess, err := a.sessions.Get(ctx, sessionID)
602 if err != nil {
603 return fmt.Errorf("failed to get session: %w", err)
604 }
605
606 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
607 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
608 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
609 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
610
611 sess.Cost += cost
612 sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
613 sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
614
615 _, err = a.sessions.Save(ctx, sess)
616 if err != nil {
617 return fmt.Errorf("failed to save session: %w", err)
618 }
619 return nil
620}
621
622func (a *agent) Summarize(ctx context.Context, sessionID string) error {
623 if a.summarizeProvider == nil {
624 return fmt.Errorf("summarize provider not available")
625 }
626
627 // Check if session is busy
628 if a.IsSessionBusy(sessionID) {
629 return ErrSessionBusy
630 }
631
632 // Create a new context with cancellation
633 summarizeCtx, cancel := context.WithCancel(ctx)
634
635 // Store the cancel function in activeRequests to allow cancellation
636 a.activeRequests.Store(sessionID+"-summarize", cancel)
637
638 go func() {
639 defer a.activeRequests.Delete(sessionID + "-summarize")
640 defer cancel()
641 event := AgentEvent{
642 Type: AgentEventTypeSummarize,
643 Progress: "Starting summarization...",
644 }
645
646 a.Publish(pubsub.CreatedEvent, event)
647 // Get all messages from the session
648 msgs, err := a.messages.List(summarizeCtx, sessionID)
649 if err != nil {
650 event = AgentEvent{
651 Type: AgentEventTypeError,
652 Error: fmt.Errorf("failed to list messages: %w", err),
653 Done: true,
654 }
655 a.Publish(pubsub.CreatedEvent, event)
656 return
657 }
658 summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
659
660 if len(msgs) == 0 {
661 event = AgentEvent{
662 Type: AgentEventTypeError,
663 Error: fmt.Errorf("no messages to summarize"),
664 Done: true,
665 }
666 a.Publish(pubsub.CreatedEvent, event)
667 return
668 }
669
670 event = AgentEvent{
671 Type: AgentEventTypeSummarize,
672 Progress: "Analyzing conversation...",
673 }
674 a.Publish(pubsub.CreatedEvent, event)
675
676 // Add a system message to guide the summarization
677 summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
678
679 // Create a new message with the summarize prompt
680 promptMsg := message.Message{
681 Role: message.User,
682 Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
683 }
684
685 // Append the prompt to the messages
686 msgsWithPrompt := append(msgs, promptMsg)
687
688 event = AgentEvent{
689 Type: AgentEventTypeSummarize,
690 Progress: "Generating summary...",
691 }
692
693 a.Publish(pubsub.CreatedEvent, event)
694
695 // Send the messages to the summarize provider
696 response := a.summarizeProvider.StreamResponse(
697 summarizeCtx,
698 msgsWithPrompt,
699 make([]tools.BaseTool, 0),
700 )
701 var finalResponse *provider.ProviderResponse
702 for r := range response {
703 if r.Error != nil {
704 event = AgentEvent{
705 Type: AgentEventTypeError,
706 Error: fmt.Errorf("failed to summarize: %w", err),
707 Done: true,
708 }
709 a.Publish(pubsub.CreatedEvent, event)
710 return
711 }
712 finalResponse = r.Response
713 }
714
715 summary := strings.TrimSpace(finalResponse.Content)
716 if summary == "" {
717 event = AgentEvent{
718 Type: AgentEventTypeError,
719 Error: fmt.Errorf("empty summary returned"),
720 Done: true,
721 }
722 a.Publish(pubsub.CreatedEvent, event)
723 return
724 }
725 event = AgentEvent{
726 Type: AgentEventTypeSummarize,
727 Progress: "Creating new session...",
728 }
729
730 a.Publish(pubsub.CreatedEvent, event)
731 oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
732 if err != nil {
733 event = AgentEvent{
734 Type: AgentEventTypeError,
735 Error: fmt.Errorf("failed to get session: %w", err),
736 Done: true,
737 }
738
739 a.Publish(pubsub.CreatedEvent, event)
740 return
741 }
742 // Create a message in the new session with the summary
743 msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
744 Role: message.Assistant,
745 Parts: []message.ContentPart{
746 message.TextContent{Text: summary},
747 message.Finish{
748 Reason: message.FinishReasonEndTurn,
749 Time: time.Now().Unix(),
750 },
751 },
752 Model: a.summarizeProvider.Model().ID,
753 Provider: a.summarizeProviderID,
754 })
755 if err != nil {
756 event = AgentEvent{
757 Type: AgentEventTypeError,
758 Error: fmt.Errorf("failed to create summary message: %w", err),
759 Done: true,
760 }
761
762 a.Publish(pubsub.CreatedEvent, event)
763 return
764 }
765 oldSession.SummaryMessageID = msg.ID
766 oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
767 oldSession.PromptTokens = 0
768 model := a.summarizeProvider.Model()
769 usage := finalResponse.Usage
770 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
771 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
772 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
773 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
774 oldSession.Cost += cost
775 _, err = a.sessions.Save(summarizeCtx, oldSession)
776 if err != nil {
777 event = AgentEvent{
778 Type: AgentEventTypeError,
779 Error: fmt.Errorf("failed to save session: %w", err),
780 Done: true,
781 }
782 a.Publish(pubsub.CreatedEvent, event)
783 }
784
785 event = AgentEvent{
786 Type: AgentEventTypeSummarize,
787 SessionID: oldSession.ID,
788 Progress: "Summary complete",
789 Done: true,
790 }
791 a.Publish(pubsub.CreatedEvent, event)
792 // Send final success event with the new session ID
793 }()
794
795 return nil
796}
797
798func (a *agent) CancelAll() {
799 a.activeRequests.Range(func(key, value any) bool {
800 a.Cancel(key.(string)) // key is sessionID
801 return true
802 })
803}
804
805func (a *agent) UpdateModel() error {
806 cfg := config.Get()
807
808 // Get current provider configuration
809 currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
810 if currentProviderCfg.ID == "" {
811 return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
812 }
813
814 // Check if provider has changed
815 if string(currentProviderCfg.ID) != a.providerID {
816 // Provider changed, need to recreate the main provider
817 model := cfg.GetModelByType(a.agentCfg.Model)
818 if model.ID == "" {
819 return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
820 }
821
822 promptID := agentPromptMap[a.agentCfg.ID]
823 if promptID == "" {
824 promptID = prompt.PromptDefault
825 }
826
827 opts := []provider.ProviderClientOption{
828 provider.WithModel(a.agentCfg.Model),
829 provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID)),
830 }
831
832 newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
833 if err != nil {
834 return fmt.Errorf("failed to create new provider: %w", err)
835 }
836
837 // Update the provider and provider ID
838 a.provider = newProvider
839 a.providerID = string(currentProviderCfg.ID)
840 }
841
842 // Check if small model provider has changed (affects title and summarize providers)
843 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
844 var smallModelProviderCfg config.ProviderConfig
845
846 for _, p := range cfg.Providers {
847 if p.ID == smallModelCfg.Provider {
848 smallModelProviderCfg = p
849 break
850 }
851 }
852
853 if smallModelProviderCfg.ID == "" {
854 return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
855 }
856
857 // Check if summarize provider has changed
858 if string(smallModelProviderCfg.ID) != a.summarizeProviderID {
859 smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
860 if smallModel == nil {
861 return fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
862 }
863
864 // Recreate title provider
865 titleOpts := []provider.ProviderClientOption{
866 provider.WithModel(config.SelectedModelTypeSmall),
867 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
868 // We want the title to be short, so we limit the max tokens
869 provider.WithMaxTokens(40),
870 }
871 newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
872 if err != nil {
873 return fmt.Errorf("failed to create new title provider: %w", err)
874 }
875
876 // Recreate summarize provider
877 summarizeOpts := []provider.ProviderClientOption{
878 provider.WithModel(config.SelectedModelTypeSmall),
879 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
880 }
881 newSummarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
882 if err != nil {
883 return fmt.Errorf("failed to create new summarize provider: %w", err)
884 }
885
886 // Update the providers and provider ID
887 a.titleProvider = newTitleProvider
888 a.summarizeProvider = newSummarizeProvider
889 a.summarizeProviderID = string(smallModelProviderCfg.ID)
890 }
891
892 return nil
893}