1package agent
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "log/slog"
8 "slices"
9 "strings"
10 "sync"
11 "time"
12
13 "github.com/charmbracelet/crush/internal/config"
14 fur "github.com/charmbracelet/crush/internal/fur/provider"
15 "github.com/charmbracelet/crush/internal/history"
16 "github.com/charmbracelet/crush/internal/llm/prompt"
17 "github.com/charmbracelet/crush/internal/llm/provider"
18 "github.com/charmbracelet/crush/internal/llm/tools"
19 "github.com/charmbracelet/crush/internal/log"
20 "github.com/charmbracelet/crush/internal/lsp"
21 "github.com/charmbracelet/crush/internal/message"
22 "github.com/charmbracelet/crush/internal/permission"
23 "github.com/charmbracelet/crush/internal/pubsub"
24 "github.com/charmbracelet/crush/internal/session"
25)
26
// Common errors
var (
	// ErrRequestCancelled is reported (inside an error AgentEvent) when the
	// generation context of a Run is cancelled, e.g. through Cancel.
	ErrRequestCancelled = errors.New("request canceled by user")
	// ErrSessionBusy is returned by Run and Summarize when a request for the
	// same session is already in flight.
	ErrSessionBusy = errors.New("session is currently processing another request")
)
32
// AgentEventType identifies the kind of AgentEvent emitted by the agent.
type AgentEventType string

const (
	// AgentEventTypeError carries a failure in AgentEvent.Error.
	AgentEventTypeError AgentEventType = "error"
	// AgentEventTypeResponse carries the final assistant message of a Run.
	AgentEventTypeResponse AgentEventType = "response"
	// AgentEventTypeSummarize reports Summarize progress in AgentEvent.Progress.
	AgentEventTypeSummarize AgentEventType = "summarize"
)
40
// AgentEvent is sent on the channel returned by Run and published to
// subscribers of the agent's broker.
type AgentEvent struct {
	Type    AgentEventType
	Message message.Message // final assistant message for AgentEventTypeResponse
	Error   error           // set for AgentEventTypeError

	// When summarizing
	SessionID string // summarized session; set on the final summarize event
	Progress  string // human-readable summarization progress text
	Done      bool   // set on a successful Run response and on the last event of a Summarize
}
51
// Service is the public API of an agent: it runs chat turns with tools for a
// session, supports cancellation, and can summarize a session's history.
// Implementations also act as a pubsub subscriber source for AgentEvents.
type Service interface {
	pubsub.Suscriber[AgentEvent]
	// Model returns the model currently configured for the agent.
	Model() fur.Model
	// Run starts a new user turn in sessionID and returns a channel that
	// delivers the resulting AgentEvent.
	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
	// Cancel aborts the in-flight request and summarization for sessionID.
	Cancel(sessionID string)
	// CancelAll cancels every in-flight request.
	CancelAll()
	// IsSessionBusy reports whether a Run is in flight for sessionID.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any request is in flight.
	IsBusy() bool
	// Summarize starts summarizing sessionID; progress is published as events.
	Summarize(ctx context.Context, sessionID string) error
	// UpdateModel rebuilds providers after the configured models change.
	UpdateModel() error
}
63
// agent implements Service. It drives the main chat provider with tools and
// uses separate small-model providers for session titles and summaries.
type agent struct {
	*pubsub.Broker[AgentEvent] // publishes AgentEvents to subscribers
	agentCfg                   config.Agent
	sessions                   session.Service
	messages                   message.Service

	tools      []tools.BaseTool  // tools offered to the main provider
	provider   provider.Provider // main chat provider
	providerID string            // ID of the provider config behind provider

	titleProvider       provider.Provider // small-model provider for session titles
	summarizeProvider   provider.Provider // small-model provider used by Summarize
	summarizeProviderID string            // provider config ID behind the small-model providers

	// activeRequests maps sessionID (Run) or sessionID+"-summarize"
	// (Summarize) to the context.CancelFunc of the in-flight request.
	activeRequests sync.Map
}
80
// agentPromptMap selects the system prompt by agent ID; agents not listed here
// fall back to prompt.PromptDefault.
var agentPromptMap = map[string]prompt.PromptID{
	"coder": prompt.PromptCoder,
	"task":  prompt.PromptTask,
}
85
86func NewAgent(
87 agentCfg config.Agent,
88 // These services are needed in the tools
89 permissions permission.Service,
90 sessions session.Service,
91 messages message.Service,
92 history history.Service,
93 lspClients map[string]*lsp.Client,
94) (Service, error) {
95 ctx := context.Background()
96 cfg := config.Get()
97 otherTools := GetMcpTools(ctx, permissions, cfg)
98 if len(lspClients) > 0 {
99 otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients))
100 }
101
102 cwd := cfg.WorkingDir()
103 allTools := []tools.BaseTool{
104 tools.NewBashTool(permissions, cwd),
105 tools.NewEditTool(lspClients, permissions, history, cwd),
106 tools.NewFetchTool(permissions, cwd),
107 tools.NewGlobTool(cwd),
108 tools.NewGrepTool(cwd),
109 tools.NewLsTool(cwd),
110 tools.NewSourcegraphTool(),
111 tools.NewViewTool(lspClients, cwd),
112 tools.NewWriteTool(lspClients, permissions, history, cwd),
113 }
114
115 if agentCfg.ID == "coder" {
116 taskAgentCfg := config.Get().Agents["task"]
117 if taskAgentCfg.ID == "" {
118 return nil, fmt.Errorf("task agent not found in config")
119 }
120 taskAgent, err := NewAgent(taskAgentCfg, permissions, sessions, messages, history, lspClients)
121 if err != nil {
122 return nil, fmt.Errorf("failed to create task agent: %w", err)
123 }
124
125 allTools = append(
126 allTools,
127 NewAgentTool(
128 taskAgent,
129 sessions,
130 messages,
131 ),
132 )
133 }
134
135 allTools = append(allTools, otherTools...)
136 providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
137 if providerCfg == nil {
138 return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
139 }
140 model := config.Get().GetModelByType(agentCfg.Model)
141
142 if model == nil {
143 return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
144 }
145
146 promptID := agentPromptMap[agentCfg.ID]
147 if promptID == "" {
148 promptID = prompt.PromptDefault
149 }
150 opts := []provider.ProviderClientOption{
151 provider.WithModel(agentCfg.Model),
152 provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
153 }
154 agentProvider, err := provider.NewProvider(*providerCfg, opts...)
155 if err != nil {
156 return nil, err
157 }
158
159 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
160 var smallModelProviderCfg *config.ProviderConfig
161 if smallModelCfg.Provider == providerCfg.ID {
162 smallModelProviderCfg = providerCfg
163 } else {
164 smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
165
166 if smallModelProviderCfg.ID == "" {
167 return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
168 }
169 }
170 smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
171 if smallModel.ID == "" {
172 return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
173 }
174
175 titleOpts := []provider.ProviderClientOption{
176 provider.WithModel(config.SelectedModelTypeSmall),
177 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
178 }
179 titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
180 if err != nil {
181 return nil, err
182 }
183 summarizeOpts := []provider.ProviderClientOption{
184 provider.WithModel(config.SelectedModelTypeSmall),
185 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
186 }
187 summarizeProvider, err := provider.NewProvider(*smallModelProviderCfg, summarizeOpts...)
188 if err != nil {
189 return nil, err
190 }
191
192 agentTools := []tools.BaseTool{}
193 if agentCfg.AllowedTools == nil {
194 agentTools = allTools
195 } else {
196 for _, tool := range allTools {
197 if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
198 agentTools = append(agentTools, tool)
199 }
200 }
201 }
202
203 agent := &agent{
204 Broker: pubsub.NewBroker[AgentEvent](),
205 agentCfg: agentCfg,
206 provider: agentProvider,
207 providerID: string(providerCfg.ID),
208 messages: messages,
209 sessions: sessions,
210 tools: agentTools,
211 titleProvider: titleProvider,
212 summarizeProvider: summarizeProvider,
213 summarizeProviderID: string(smallModelProviderCfg.ID),
214 activeRequests: sync.Map{},
215 }
216
217 return agent, nil
218}
219
220func (a *agent) Model() fur.Model {
221 return *config.Get().GetModelByType(a.agentCfg.Model)
222}
223
224func (a *agent) Cancel(sessionID string) {
225 // Cancel regular requests
226 if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
227 if cancel, ok := cancelFunc.(context.CancelFunc); ok {
228 slog.Info(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
229 cancel()
230 }
231 }
232
233 // Also check for summarize requests
234 if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID + "-summarize"); exists {
235 if cancel, ok := cancelFunc.(context.CancelFunc); ok {
236 slog.Info(fmt.Sprintf("Summarize cancellation initiated for session: %s", sessionID))
237 cancel()
238 }
239 }
240}
241
242func (a *agent) IsBusy() bool {
243 busy := false
244 a.activeRequests.Range(func(key, value any) bool {
245 if cancelFunc, ok := value.(context.CancelFunc); ok {
246 if cancelFunc != nil {
247 busy = true
248 return false
249 }
250 }
251 return true
252 })
253 return busy
254}
255
256func (a *agent) IsSessionBusy(sessionID string) bool {
257 _, busy := a.activeRequests.Load(sessionID)
258 return busy
259}
260
261func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
262 if content == "" {
263 return nil
264 }
265 if a.titleProvider == nil {
266 return nil
267 }
268 session, err := a.sessions.Get(ctx, sessionID)
269 if err != nil {
270 return err
271 }
272 parts := []message.ContentPart{message.TextContent{
273 Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
274 }}
275
276 // Use streaming approach like summarization
277 response := a.titleProvider.StreamResponse(
278 ctx,
279 []message.Message{
280 {
281 Role: message.User,
282 Parts: parts,
283 },
284 },
285 make([]tools.BaseTool, 0),
286 )
287
288 var finalResponse *provider.ProviderResponse
289 for r := range response {
290 if r.Error != nil {
291 return r.Error
292 }
293 finalResponse = r.Response
294 }
295
296 if finalResponse == nil {
297 return fmt.Errorf("no response received from title provider")
298 }
299
300 title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
301 if title == "" {
302 return nil
303 }
304
305 session.Title = title
306 _, err = a.sessions.Save(ctx, session)
307 return err
308}
309
310func (a *agent) err(err error) AgentEvent {
311 return AgentEvent{
312 Type: AgentEventTypeError,
313 Error: err,
314 }
315}
316
317func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
318 if !a.Model().SupportsImages && attachments != nil {
319 attachments = nil
320 }
321 events := make(chan AgentEvent)
322 if a.IsSessionBusy(sessionID) {
323 return nil, ErrSessionBusy
324 }
325
326 genCtx, cancel := context.WithCancel(ctx)
327
328 a.activeRequests.Store(sessionID, cancel)
329 go func() {
330 slog.Debug("Request started", "sessionID", sessionID)
331 defer log.RecoverPanic("agent.Run", func() {
332 events <- a.err(fmt.Errorf("panic while running the agent"))
333 })
334 var attachmentParts []message.ContentPart
335 for _, attachment := range attachments {
336 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
337 }
338 result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
339 if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
340 slog.Error(result.Error.Error())
341 }
342 slog.Debug("Request completed", "sessionID", sessionID)
343 a.activeRequests.Delete(sessionID)
344 cancel()
345 a.Publish(pubsub.CreatedEvent, result)
346 events <- result
347 close(events)
348 }()
349 return events, nil
350}
351
// processGeneration runs the conversation loop for a single user turn.
//
// Steps:
//   - Load the session's existing messages, starting asynchronous title
//     generation when there are none.
//   - Trim the history so it starts at the session's summary message, if
//     one is set.
//   - Persist the new user message.
//   - Call streamAndHandleEvents repeatedly, feeding tool results back to
//     the model until the assistant finishes for a reason other than tool use.
//
// It returns either the final response event or an error event.
// Cancellation is reported in two ways:
//   - detected inside streamAndHandleEvents: ErrRequestCancelled;
//   - noticed at the top of the loop: ctx.Err().
func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
	cfg := config.Get()
	// List existing messages; if none, start title generation asynchronously.
	msgs, err := a.messages.List(ctx, sessionID)
	if err != nil {
		return a.err(fmt.Errorf("failed to list messages: %w", err))
	}
	if len(msgs) == 0 {
		go func() {
			defer log.RecoverPanic("agent.Run", func() {
				slog.Error("panic while generating title")
			})
			// A background context keeps title generation independent of the
			// request's own cancellation.
			titleErr := a.generateTitle(context.Background(), sessionID, content)
			if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
				slog.Error(fmt.Sprintf("failed to generate title: %v", titleErr))
			}
		}()
	}
	session, err := a.sessions.Get(ctx, sessionID)
	if err != nil {
		return a.err(fmt.Errorf("failed to get session: %w", err))
	}
	if session.SummaryMessageID != "" {
		// Drop everything before the summary message. The summary is stored
		// as an assistant message by Summarize; it is relabelled as a user
		// message here so it opens the history sent to the model.
		summaryMsgInex := -1
		for i, msg := range msgs {
			if msg.ID == session.SummaryMessageID {
				summaryMsgInex = i
				break
			}
		}
		if summaryMsgInex != -1 {
			msgs = msgs[summaryMsgInex:]
			msgs[0].Role = message.User
		}
	}

	userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
	if err != nil {
		return a.err(fmt.Errorf("failed to create user message: %w", err))
	}
	// Append the new user message to the conversation history.
	msgHistory := append(msgs, userMsg)

	for {
		// Check for cancellation before each iteration
		select {
		case <-ctx.Done():
			return a.err(ctx.Err())
		default:
			// Continue processing
		}
		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
		if err != nil {
			if errors.Is(err, context.Canceled) {
				// Best-effort: persist the cancelled state with a background
				// context (the request context is already done); the update
				// error is ignored.
				agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
				a.messages.Update(context.Background(), agentMessage)
				return a.err(ErrRequestCancelled)
			}
			return a.err(fmt.Errorf("failed to process events: %w", err))
		}
		if cfg.Options.Debug {
			slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
		}
		if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
			// We are not done, we need to respond with the tool response
			msgHistory = append(msgHistory, agentMessage, *toolResults)
			continue
		}
		return AgentEvent{
			Type:    AgentEventTypeResponse,
			Message: agentMessage,
			Done:    true,
		}
	}
}
427
428func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
429 parts := []message.ContentPart{message.TextContent{Text: content}}
430 parts = append(parts, attachmentParts...)
431 return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
432 Role: message.User,
433 Parts: parts,
434 })
435}
436
437func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
438 ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
439 eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)
440
441 assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
442 Role: message.Assistant,
443 Parts: []message.ContentPart{},
444 Model: a.Model().ID,
445 Provider: a.providerID,
446 })
447 if err != nil {
448 return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
449 }
450
451 // Add the session and message ID into the context if needed by tools.
452 ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
453
454 // Process each event in the stream.
455 for event := range eventChan {
456 if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
457 if errors.Is(processErr, context.Canceled) {
458 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
459 } else {
460 a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
461 }
462 return assistantMsg, nil, processErr
463 }
464 if ctx.Err() != nil {
465 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
466 return assistantMsg, nil, ctx.Err()
467 }
468 }
469
470 toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
471 toolCalls := assistantMsg.ToolCalls()
472 for i, toolCall := range toolCalls {
473 select {
474 case <-ctx.Done():
475 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
476 // Make all future tool calls cancelled
477 for j := i; j < len(toolCalls); j++ {
478 toolResults[j] = message.ToolResult{
479 ToolCallID: toolCalls[j].ID,
480 Content: "Tool execution canceled by user",
481 IsError: true,
482 }
483 }
484 goto out
485 default:
486 // Continue processing
487 var tool tools.BaseTool
488 for _, availableTool := range a.tools {
489 if availableTool.Info().Name == toolCall.Name {
490 tool = availableTool
491 break
492 }
493 }
494
495 // Tool not found
496 if tool == nil {
497 toolResults[i] = message.ToolResult{
498 ToolCallID: toolCall.ID,
499 Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
500 IsError: true,
501 }
502 continue
503 }
504 toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
505 ID: toolCall.ID,
506 Name: toolCall.Name,
507 Input: toolCall.Input,
508 })
509 if toolErr != nil {
510 if errors.Is(toolErr, permission.ErrorPermissionDenied) {
511 toolResults[i] = message.ToolResult{
512 ToolCallID: toolCall.ID,
513 Content: "Permission denied",
514 IsError: true,
515 }
516 for j := i + 1; j < len(toolCalls); j++ {
517 toolResults[j] = message.ToolResult{
518 ToolCallID: toolCalls[j].ID,
519 Content: "Tool execution canceled by user",
520 IsError: true,
521 }
522 }
523 a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
524 break
525 }
526 }
527 toolResults[i] = message.ToolResult{
528 ToolCallID: toolCall.ID,
529 Content: toolResult.Content,
530 Metadata: toolResult.Metadata,
531 IsError: toolResult.IsError,
532 }
533 }
534 }
535out:
536 if len(toolResults) == 0 {
537 return assistantMsg, nil, nil
538 }
539 parts := make([]message.ContentPart, 0)
540 for _, tr := range toolResults {
541 parts = append(parts, tr)
542 }
543 msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
544 Role: message.Tool,
545 Parts: parts,
546 Provider: a.providerID,
547 })
548 if err != nil {
549 return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
550 }
551
552 return assistantMsg, &msg, err
553}
554
555func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReson message.FinishReason, message, details string) {
556 msg.AddFinish(finishReson, message, details)
557 _ = a.messages.Update(ctx, *msg)
558}
559
560func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
561 select {
562 case <-ctx.Done():
563 return ctx.Err()
564 default:
565 // Continue processing.
566 }
567
568 switch event.Type {
569 case provider.EventThinkingDelta:
570 assistantMsg.AppendReasoningContent(event.Content)
571 return a.messages.Update(ctx, *assistantMsg)
572 case provider.EventContentDelta:
573 assistantMsg.AppendContent(event.Content)
574 return a.messages.Update(ctx, *assistantMsg)
575 case provider.EventToolUseStart:
576 slog.Info("Tool call started", "toolCall", event.ToolCall)
577 assistantMsg.AddToolCall(*event.ToolCall)
578 return a.messages.Update(ctx, *assistantMsg)
579 case provider.EventToolUseDelta:
580 assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
581 return a.messages.Update(ctx, *assistantMsg)
582 case provider.EventToolUseStop:
583 slog.Info("Finished tool call", "toolCall", event.ToolCall)
584 assistantMsg.FinishToolCall(event.ToolCall.ID)
585 return a.messages.Update(ctx, *assistantMsg)
586 case provider.EventError:
587 return event.Error
588 case provider.EventComplete:
589 assistantMsg.SetToolCalls(event.Response.ToolCalls)
590 assistantMsg.AddFinish(event.Response.FinishReason, "", "")
591 if err := a.messages.Update(ctx, *assistantMsg); err != nil {
592 return fmt.Errorf("failed to update message: %w", err)
593 }
594 return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
595 }
596
597 return nil
598}
599
600func (a *agent) TrackUsage(ctx context.Context, sessionID string, model fur.Model, usage provider.TokenUsage) error {
601 sess, err := a.sessions.Get(ctx, sessionID)
602 if err != nil {
603 return fmt.Errorf("failed to get session: %w", err)
604 }
605
606 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
607 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
608 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
609 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
610
611 sess.Cost += cost
612 sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
613 sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
614
615 _, err = a.sessions.Save(ctx, sess)
616 if err != nil {
617 return fmt.Errorf("failed to save session: %w", err)
618 }
619 return nil
620}
621
622func (a *agent) Summarize(ctx context.Context, sessionID string) error {
623 if a.summarizeProvider == nil {
624 return fmt.Errorf("summarize provider not available")
625 }
626
627 // Check if session is busy
628 if a.IsSessionBusy(sessionID) {
629 return ErrSessionBusy
630 }
631
632 // Create a new context with cancellation
633 summarizeCtx, cancel := context.WithCancel(ctx)
634
635 // Store the cancel function in activeRequests to allow cancellation
636 a.activeRequests.Store(sessionID+"-summarize", cancel)
637
638 go func() {
639 defer a.activeRequests.Delete(sessionID + "-summarize")
640 defer cancel()
641 event := AgentEvent{
642 Type: AgentEventTypeSummarize,
643 Progress: "Starting summarization...",
644 }
645
646 a.Publish(pubsub.CreatedEvent, event)
647 // Get all messages from the session
648 msgs, err := a.messages.List(summarizeCtx, sessionID)
649 if err != nil {
650 event = AgentEvent{
651 Type: AgentEventTypeError,
652 Error: fmt.Errorf("failed to list messages: %w", err),
653 Done: true,
654 }
655 a.Publish(pubsub.CreatedEvent, event)
656 return
657 }
658 summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
659
660 if len(msgs) == 0 {
661 event = AgentEvent{
662 Type: AgentEventTypeError,
663 Error: fmt.Errorf("no messages to summarize"),
664 Done: true,
665 }
666 a.Publish(pubsub.CreatedEvent, event)
667 return
668 }
669
670 event = AgentEvent{
671 Type: AgentEventTypeSummarize,
672 Progress: "Analyzing conversation...",
673 }
674 a.Publish(pubsub.CreatedEvent, event)
675
676 // Add a system message to guide the summarization
677 summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
678
679 // Create a new message with the summarize prompt
680 promptMsg := message.Message{
681 Role: message.User,
682 Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
683 }
684
685 // Append the prompt to the messages
686 msgsWithPrompt := append(msgs, promptMsg)
687
688 event = AgentEvent{
689 Type: AgentEventTypeSummarize,
690 Progress: "Generating summary...",
691 }
692
693 a.Publish(pubsub.CreatedEvent, event)
694
695 // Send the messages to the summarize provider
696 response := a.summarizeProvider.StreamResponse(
697 summarizeCtx,
698 msgsWithPrompt,
699 make([]tools.BaseTool, 0),
700 )
701 var finalResponse *provider.ProviderResponse
702 for r := range response {
703 if r.Error != nil {
704 event = AgentEvent{
705 Type: AgentEventTypeError,
706 Error: fmt.Errorf("failed to summarize: %w", err),
707 Done: true,
708 }
709 a.Publish(pubsub.CreatedEvent, event)
710 return
711 }
712 finalResponse = r.Response
713 }
714
715 summary := strings.TrimSpace(finalResponse.Content)
716 if summary == "" {
717 event = AgentEvent{
718 Type: AgentEventTypeError,
719 Error: fmt.Errorf("empty summary returned"),
720 Done: true,
721 }
722 a.Publish(pubsub.CreatedEvent, event)
723 return
724 }
725 event = AgentEvent{
726 Type: AgentEventTypeSummarize,
727 Progress: "Creating new session...",
728 }
729
730 a.Publish(pubsub.CreatedEvent, event)
731 oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
732 if err != nil {
733 event = AgentEvent{
734 Type: AgentEventTypeError,
735 Error: fmt.Errorf("failed to get session: %w", err),
736 Done: true,
737 }
738
739 a.Publish(pubsub.CreatedEvent, event)
740 return
741 }
742 // Create a message in the new session with the summary
743 msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
744 Role: message.Assistant,
745 Parts: []message.ContentPart{
746 message.TextContent{Text: summary},
747 message.Finish{
748 Reason: message.FinishReasonEndTurn,
749 Time: time.Now().Unix(),
750 },
751 },
752 Model: a.summarizeProvider.Model().ID,
753 Provider: a.summarizeProviderID,
754 })
755 if err != nil {
756 event = AgentEvent{
757 Type: AgentEventTypeError,
758 Error: fmt.Errorf("failed to create summary message: %w", err),
759 Done: true,
760 }
761
762 a.Publish(pubsub.CreatedEvent, event)
763 return
764 }
765 oldSession.SummaryMessageID = msg.ID
766 oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
767 oldSession.PromptTokens = 0
768 model := a.summarizeProvider.Model()
769 usage := finalResponse.Usage
770 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
771 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
772 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
773 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
774 oldSession.Cost += cost
775 _, err = a.sessions.Save(summarizeCtx, oldSession)
776 if err != nil {
777 event = AgentEvent{
778 Type: AgentEventTypeError,
779 Error: fmt.Errorf("failed to save session: %w", err),
780 Done: true,
781 }
782 a.Publish(pubsub.CreatedEvent, event)
783 }
784
785 event = AgentEvent{
786 Type: AgentEventTypeSummarize,
787 SessionID: oldSession.ID,
788 Progress: "Summary complete",
789 Done: true,
790 }
791 a.Publish(pubsub.CreatedEvent, event)
792 // Send final success event with the new session ID
793 }()
794
795 return nil
796}
797
798func (a *agent) CancelAll() {
799 a.activeRequests.Range(func(key, value any) bool {
800 a.Cancel(key.(string)) // key is sessionID
801 return true
802 })
803
804 timeout := time.After(5 * time.Second)
805 for a.IsBusy() {
806 select {
807 case <-timeout:
808 return
809 default:
810 time.Sleep(200 * time.Millisecond)
811 }
812 }
813}
814
815func (a *agent) UpdateModel() error {
816 cfg := config.Get()
817
818 // Get current provider configuration
819 currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
820 if currentProviderCfg.ID == "" {
821 return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
822 }
823
824 // Check if provider has changed
825 if string(currentProviderCfg.ID) != a.providerID {
826 // Provider changed, need to recreate the main provider
827 model := cfg.GetModelByType(a.agentCfg.Model)
828 if model.ID == "" {
829 return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
830 }
831
832 promptID := agentPromptMap[a.agentCfg.ID]
833 if promptID == "" {
834 promptID = prompt.PromptDefault
835 }
836
837 opts := []provider.ProviderClientOption{
838 provider.WithModel(a.agentCfg.Model),
839 provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
840 }
841
842 newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
843 if err != nil {
844 return fmt.Errorf("failed to create new provider: %w", err)
845 }
846
847 // Update the provider and provider ID
848 a.provider = newProvider
849 a.providerID = string(currentProviderCfg.ID)
850 }
851
852 // Check if small model provider has changed (affects title and summarize providers)
853 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
854 var smallModelProviderCfg config.ProviderConfig
855
856 for _, p := range cfg.Providers {
857 if p.ID == smallModelCfg.Provider {
858 smallModelProviderCfg = p
859 break
860 }
861 }
862
863 if smallModelProviderCfg.ID == "" {
864 return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
865 }
866
867 // Check if summarize provider has changed
868 if string(smallModelProviderCfg.ID) != a.summarizeProviderID {
869 smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
870 if smallModel == nil {
871 return fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
872 }
873
874 // Recreate title provider
875 titleOpts := []provider.ProviderClientOption{
876 provider.WithModel(config.SelectedModelTypeSmall),
877 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
878 // We want the title to be short, so we limit the max tokens
879 provider.WithMaxTokens(40),
880 }
881 newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
882 if err != nil {
883 return fmt.Errorf("failed to create new title provider: %w", err)
884 }
885
886 // Recreate summarize provider
887 summarizeOpts := []provider.ProviderClientOption{
888 provider.WithModel(config.SelectedModelTypeSmall),
889 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
890 }
891 newSummarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
892 if err != nil {
893 return fmt.Errorf("failed to create new summarize provider: %w", err)
894 }
895
896 // Update the providers and provider ID
897 a.titleProvider = newTitleProvider
898 a.summarizeProvider = newSummarizeProvider
899 a.summarizeProviderID = string(smallModelProviderCfg.ID)
900 }
901
902 return nil
903}