1package agent
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "slices"
8 "strings"
9 "sync"
10 "time"
11
12 "github.com/charmbracelet/crush/internal/config"
13 "github.com/charmbracelet/crush/internal/history"
14 "github.com/charmbracelet/crush/internal/llm/prompt"
15 "github.com/charmbracelet/crush/internal/llm/provider"
16 "github.com/charmbracelet/crush/internal/llm/tools"
17 "github.com/charmbracelet/crush/internal/logging"
18 "github.com/charmbracelet/crush/internal/lsp"
19 "github.com/charmbracelet/crush/internal/message"
20 "github.com/charmbracelet/crush/internal/permission"
21 "github.com/charmbracelet/crush/internal/pubsub"
22 "github.com/charmbracelet/crush/internal/session"
23)
24
// ErrRequestCancelled is reported when a request's context is cancelled while
// the response is being streamed (for example via Cancel).
var ErrRequestCancelled = errors.New("request cancelled by user")

// ErrSessionBusy is returned by Run and Summarize when the session already has
// a request in flight.
var ErrSessionBusy = errors.New("session is currently processing another request")
30
// AgentEventType identifies the kind of an AgentEvent.
type AgentEventType string

const (
	// AgentEventTypeError carries a failure in AgentEvent.Error.
	AgentEventTypeError AgentEventType = "error"
	// AgentEventTypeResponse carries the final assistant message of a Run.
	AgentEventTypeResponse AgentEventType = "response"
	// AgentEventTypeSummarize reports Summarize progress and completion.
	AgentEventTypeSummarize AgentEventType = "summarize"
)

// AgentEvent is delivered on the channel returned by Run and published to
// subscribers of the agent.
type AgentEvent struct {
	Type AgentEventType
	// Message is the assistant message for AgentEventTypeResponse events.
	Message message.Message
	// Error is set for AgentEventTypeError events.
	Error error

	// SessionID and Progress are filled in by summarization events.
	SessionID string
	Progress  string
	// Done marks a terminal event: the Run response, a summarization
	// failure, or summarization completion.
	Done bool
}
49
50type Service interface {
51 pubsub.Suscriber[AgentEvent]
52 Model() config.Model
53 EffectiveMaxTokens() int64
54 Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
55 Cancel(sessionID string)
56 CancelAll()
57 IsSessionBusy(sessionID string) bool
58 IsBusy() bool
59 Summarize(ctx context.Context, sessionID string) error
60 UpdateModel() error
61}
62
// agent is the Service implementation returned by NewAgent.
type agent struct {
	// Embedded broker: provides Publish and the subscriber side of Service.
	*pubsub.Broker[AgentEvent]
	agentCfg config.Agent
	sessions session.Service
	messages message.Service

	// tools are offered to provider on every turn; providerID is the config
	// ID of provider and is compared by UpdateModel to detect changes.
	tools      []tools.BaseTool
	provider   provider.Provider
	providerID string

	// Providers built on the configured small model, used for session titles
	// and summaries. summarizeProviderID is the config ID they were built for.
	// NOTE(review): UpdateModel reassigns provider fields without any locking
	// while Run/Summarize goroutines may read them — confirm callers serialize.
	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests maps a session ID (Run) or "<sessionID>-summarize"
	// (Summarize) to the context.CancelFunc of the in-flight request.
	activeRequests sync.Map
}
79
80var agentPromptMap = map[config.AgentID]prompt.PromptID{
81 config.AgentCoder: prompt.PromptCoder,
82 config.AgentTask: prompt.PromptTask,
83}
84
85func NewAgent(
86 agentCfg config.Agent,
87 // These services are needed in the tools
88 permissions permission.Service,
89 sessions session.Service,
90 messages message.Service,
91 history history.Service,
92 lspClients map[string]*lsp.Client,
93) (Service, error) {
94 ctx := context.Background()
95 cfg := config.Get()
96 otherTools := GetMcpTools(ctx, permissions)
97 if len(lspClients) > 0 {
98 otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients))
99 }
100
101 allTools := []tools.BaseTool{
102 tools.NewBashTool(permissions),
103 tools.NewEditTool(lspClients, permissions, history),
104 tools.NewFetchTool(permissions),
105 tools.NewGlobTool(),
106 tools.NewGrepTool(),
107 tools.NewLsTool(),
108 tools.NewSourcegraphTool(),
109 tools.NewViewTool(lspClients),
110 tools.NewVSCodeDiffTool(permissions),
111 tools.NewWriteTool(lspClients, permissions, history),
112 }
113
114 if agentCfg.ID == config.AgentCoder {
115 taskAgentCfg := config.Get().Agents[config.AgentTask]
116 if taskAgentCfg.ID == "" {
117 return nil, fmt.Errorf("task agent not found in config")
118 }
119 taskAgent, err := NewAgent(taskAgentCfg, permissions, sessions, messages, history, lspClients)
120 if err != nil {
121 return nil, fmt.Errorf("failed to create task agent: %w", err)
122 }
123
124 allTools = append(
125 allTools,
126 NewAgentTool(
127 taskAgent,
128 sessions,
129 messages,
130 ),
131 )
132 }
133
134 allTools = append(allTools, otherTools...)
135 providerCfg := config.GetAgentProvider(agentCfg.ID)
136 if providerCfg.ID == "" {
137 return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
138 }
139 model := config.GetAgentModel(agentCfg.ID)
140
141 if model.ID == "" {
142 return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
143 }
144
145 promptID := agentPromptMap[agentCfg.ID]
146 if promptID == "" {
147 promptID = prompt.PromptDefault
148 }
149 opts := []provider.ProviderClientOption{
150 provider.WithModel(agentCfg.Model),
151 provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID)),
152 }
153 agentProvider, err := provider.NewProvider(providerCfg, opts...)
154 if err != nil {
155 return nil, err
156 }
157
158 smallModelCfg := cfg.Models.Small
159 var smallModel config.Model
160
161 var smallModelProviderCfg config.ProviderConfig
162 if smallModelCfg.Provider == providerCfg.ID {
163 smallModelProviderCfg = providerCfg
164 } else {
165 for _, p := range cfg.Providers {
166 if p.ID == smallModelCfg.Provider {
167 smallModelProviderCfg = p
168 break
169 }
170 }
171 if smallModelProviderCfg.ID == "" {
172 return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
173 }
174 }
175 for _, m := range smallModelProviderCfg.Models {
176 if m.ID == smallModelCfg.ModelID {
177 smallModel = m
178 break
179 }
180 }
181 if smallModel.ID == "" {
182 return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.ModelID, smallModelProviderCfg.ID)
183 }
184
185 titleOpts := []provider.ProviderClientOption{
186 provider.WithModel(config.SmallModel),
187 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
188 }
189 titleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
190 if err != nil {
191 return nil, err
192 }
193 summarizeOpts := []provider.ProviderClientOption{
194 provider.WithModel(config.SmallModel),
195 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
196 }
197 summarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
198 if err != nil {
199 return nil, err
200 }
201
202 agentTools := []tools.BaseTool{}
203 if agentCfg.AllowedTools == nil {
204 agentTools = allTools
205 } else {
206 for _, tool := range allTools {
207 if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
208 agentTools = append(agentTools, tool)
209 }
210 }
211 }
212
213 agent := &agent{
214 Broker: pubsub.NewBroker[AgentEvent](),
215 agentCfg: agentCfg,
216 provider: agentProvider,
217 providerID: string(providerCfg.ID),
218 messages: messages,
219 sessions: sessions,
220 tools: agentTools,
221 titleProvider: titleProvider,
222 summarizeProvider: summarizeProvider,
223 summarizeProviderID: string(smallModelProviderCfg.ID),
224 activeRequests: sync.Map{},
225 }
226
227 return agent, nil
228}
229
230func (a *agent) Model() config.Model {
231 return config.GetAgentModel(a.agentCfg.ID)
232}
233
234func (a *agent) EffectiveMaxTokens() int64 {
235 return config.GetAgentEffectiveMaxTokens(a.agentCfg.ID)
236}
237
238func (a *agent) Cancel(sessionID string) {
239 // Cancel regular requests
240 if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
241 if cancel, ok := cancelFunc.(context.CancelFunc); ok {
242 logging.InfoPersist(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
243 cancel()
244 }
245 }
246
247 // Also check for summarize requests
248 if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID + "-summarize"); exists {
249 if cancel, ok := cancelFunc.(context.CancelFunc); ok {
250 logging.InfoPersist(fmt.Sprintf("Summarize cancellation initiated for session: %s", sessionID))
251 cancel()
252 }
253 }
254}
255
256func (a *agent) IsBusy() bool {
257 busy := false
258 a.activeRequests.Range(func(key, value any) bool {
259 if cancelFunc, ok := value.(context.CancelFunc); ok {
260 if cancelFunc != nil {
261 busy = true
262 return false
263 }
264 }
265 return true
266 })
267 return busy
268}
269
270func (a *agent) IsSessionBusy(sessionID string) bool {
271 _, busy := a.activeRequests.Load(sessionID)
272 return busy
273}
274
275func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
276 if content == "" {
277 return nil
278 }
279 if a.titleProvider == nil {
280 return nil
281 }
282 session, err := a.sessions.Get(ctx, sessionID)
283 if err != nil {
284 return err
285 }
286 parts := []message.ContentPart{message.TextContent{
287 Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
288 }}
289
290 // Use streaming approach like summarization
291 response := a.titleProvider.StreamResponse(
292 ctx,
293 []message.Message{
294 {
295 Role: message.User,
296 Parts: parts,
297 },
298 },
299 make([]tools.BaseTool, 0),
300 )
301
302 var finalResponse *provider.ProviderResponse
303 for r := range response {
304 if r.Error != nil {
305 return r.Error
306 }
307 finalResponse = r.Response
308 }
309
310 if finalResponse == nil {
311 return fmt.Errorf("no response received from title provider")
312 }
313
314 title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
315 if title == "" {
316 return nil
317 }
318
319 session.Title = title
320 _, err = a.sessions.Save(ctx, session)
321 return err
322}
323
324func (a *agent) err(err error) AgentEvent {
325 return AgentEvent{
326 Type: AgentEventTypeError,
327 Error: err,
328 }
329}
330
331func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
332 if !a.Model().SupportsImages && attachments != nil {
333 attachments = nil
334 }
335 events := make(chan AgentEvent)
336 if a.IsSessionBusy(sessionID) {
337 return nil, ErrSessionBusy
338 }
339
340 genCtx, cancel := context.WithCancel(ctx)
341
342 a.activeRequests.Store(sessionID, cancel)
343 go func() {
344 logging.Debug("Request started", "sessionID", sessionID)
345 defer logging.RecoverPanic("agent.Run", func() {
346 events <- a.err(fmt.Errorf("panic while running the agent"))
347 })
348 var attachmentParts []message.ContentPart
349 for _, attachment := range attachments {
350 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
351 }
352 result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
353 if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
354 logging.ErrorPersist(result.Error.Error())
355 }
356 logging.Debug("Request completed", "sessionID", sessionID)
357 a.activeRequests.Delete(sessionID)
358 cancel()
359 a.Publish(pubsub.CreatedEvent, result)
360 events <- result
361 close(events)
362 }()
363 return events, nil
364}
365
366func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
367 cfg := config.Get()
368 // List existing messages; if none, start title generation asynchronously.
369 msgs, err := a.messages.List(ctx, sessionID)
370 if err != nil {
371 return a.err(fmt.Errorf("failed to list messages: %w", err))
372 }
373 if len(msgs) == 0 {
374 go func() {
375 defer logging.RecoverPanic("agent.Run", func() {
376 logging.ErrorPersist("panic while generating title")
377 })
378 titleErr := a.generateTitle(context.Background(), sessionID, content)
379 if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
380 logging.ErrorPersist(fmt.Sprintf("failed to generate title: %v", titleErr))
381 }
382 }()
383 }
384 session, err := a.sessions.Get(ctx, sessionID)
385 if err != nil {
386 return a.err(fmt.Errorf("failed to get session: %w", err))
387 }
388 if session.SummaryMessageID != "" {
389 summaryMsgInex := -1
390 for i, msg := range msgs {
391 if msg.ID == session.SummaryMessageID {
392 summaryMsgInex = i
393 break
394 }
395 }
396 if summaryMsgInex != -1 {
397 msgs = msgs[summaryMsgInex:]
398 msgs[0].Role = message.User
399 }
400 }
401
402 userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
403 if err != nil {
404 return a.err(fmt.Errorf("failed to create user message: %w", err))
405 }
406 // Append the new user message to the conversation history.
407 msgHistory := append(msgs, userMsg)
408
409 for {
410 // Check for cancellation before each iteration
411 select {
412 case <-ctx.Done():
413 return a.err(ctx.Err())
414 default:
415 // Continue processing
416 }
417 agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
418 if err != nil {
419 if errors.Is(err, context.Canceled) {
420 agentMessage.AddFinish(message.FinishReasonCanceled)
421 a.messages.Update(context.Background(), agentMessage)
422 return a.err(ErrRequestCancelled)
423 }
424 return a.err(fmt.Errorf("failed to process events: %w", err))
425 }
426 if cfg.Options.Debug {
427 seqId := (len(msgHistory) + 1) / 2
428 toolResultFilepath := logging.WriteToolResultsJson(sessionID, seqId, toolResults)
429 logging.Info("Result", "message", agentMessage.FinishReason(), "toolResults", "{}", "filepath", toolResultFilepath)
430 } else {
431 logging.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
432 }
433 if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
434 // We are not done, we need to respond with the tool response
435 msgHistory = append(msgHistory, agentMessage, *toolResults)
436 continue
437 }
438 return AgentEvent{
439 Type: AgentEventTypeResponse,
440 Message: agentMessage,
441 Done: true,
442 }
443 }
444}
445
446func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
447 parts := []message.ContentPart{message.TextContent{Text: content}}
448 parts = append(parts, attachmentParts...)
449 return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
450 Role: message.User,
451 Parts: parts,
452 })
453}
454
455func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
456 ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
457 eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)
458
459 assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
460 Role: message.Assistant,
461 Parts: []message.ContentPart{},
462 Model: a.Model().ID,
463 Provider: a.providerID,
464 })
465 if err != nil {
466 return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
467 }
468
469 // Add the session and message ID into the context if needed by tools.
470 ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
471
472 // Process each event in the stream.
473 for event := range eventChan {
474 if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
475 a.finishMessage(ctx, &assistantMsg, message.FinishReasonCanceled)
476 return assistantMsg, nil, processErr
477 }
478 if ctx.Err() != nil {
479 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
480 return assistantMsg, nil, ctx.Err()
481 }
482 }
483
484 toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
485 toolCalls := assistantMsg.ToolCalls()
486 for i, toolCall := range toolCalls {
487 select {
488 case <-ctx.Done():
489 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
490 // Make all future tool calls cancelled
491 for j := i; j < len(toolCalls); j++ {
492 toolResults[j] = message.ToolResult{
493 ToolCallID: toolCalls[j].ID,
494 Content: "Tool execution canceled by user",
495 IsError: true,
496 }
497 }
498 goto out
499 default:
500 // Continue processing
501 var tool tools.BaseTool
502 for _, availableTool := range a.tools {
503 if availableTool.Info().Name == toolCall.Name {
504 tool = availableTool
505 break
506 }
507 }
508
509 // Tool not found
510 if tool == nil {
511 toolResults[i] = message.ToolResult{
512 ToolCallID: toolCall.ID,
513 Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
514 IsError: true,
515 }
516 continue
517 }
518 toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
519 ID: toolCall.ID,
520 Name: toolCall.Name,
521 Input: toolCall.Input,
522 })
523 if toolErr != nil {
524 if errors.Is(toolErr, permission.ErrorPermissionDenied) {
525 toolResults[i] = message.ToolResult{
526 ToolCallID: toolCall.ID,
527 Content: "Permission denied",
528 IsError: true,
529 }
530 for j := i + 1; j < len(toolCalls); j++ {
531 toolResults[j] = message.ToolResult{
532 ToolCallID: toolCalls[j].ID,
533 Content: "Tool execution canceled by user",
534 IsError: true,
535 }
536 }
537 a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied)
538 break
539 }
540 }
541 toolResults[i] = message.ToolResult{
542 ToolCallID: toolCall.ID,
543 Content: toolResult.Content,
544 Metadata: toolResult.Metadata,
545 IsError: toolResult.IsError,
546 }
547 }
548 }
549out:
550 if len(toolResults) == 0 {
551 return assistantMsg, nil, nil
552 }
553 parts := make([]message.ContentPart, 0)
554 for _, tr := range toolResults {
555 parts = append(parts, tr)
556 }
557 msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
558 Role: message.Tool,
559 Parts: parts,
560 Provider: a.providerID,
561 })
562 if err != nil {
563 return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
564 }
565
566 return assistantMsg, &msg, err
567}
568
569func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReson message.FinishReason) {
570 msg.AddFinish(finishReson)
571 _ = a.messages.Update(ctx, *msg)
572}
573
574func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
575 select {
576 case <-ctx.Done():
577 return ctx.Err()
578 default:
579 // Continue processing.
580 }
581
582 switch event.Type {
583 case provider.EventThinkingDelta:
584 assistantMsg.AppendReasoningContent(event.Content)
585 return a.messages.Update(ctx, *assistantMsg)
586 case provider.EventContentDelta:
587 assistantMsg.AppendContent(event.Content)
588 return a.messages.Update(ctx, *assistantMsg)
589 case provider.EventToolUseStart:
590 logging.Info("Tool call started", "toolCall", event.ToolCall)
591 assistantMsg.AddToolCall(*event.ToolCall)
592 return a.messages.Update(ctx, *assistantMsg)
593 case provider.EventToolUseDelta:
594 assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
595 return a.messages.Update(ctx, *assistantMsg)
596 case provider.EventToolUseStop:
597 logging.Info("Finished tool call", "toolCall", event.ToolCall)
598 assistantMsg.FinishToolCall(event.ToolCall.ID)
599 return a.messages.Update(ctx, *assistantMsg)
600 case provider.EventError:
601 if errors.Is(event.Error, context.Canceled) {
602 logging.InfoPersist(fmt.Sprintf("Event processing canceled for session: %s", sessionID))
603 return context.Canceled
604 }
605 logging.ErrorPersist(event.Error.Error())
606 return event.Error
607 case provider.EventComplete:
608 assistantMsg.SetToolCalls(event.Response.ToolCalls)
609 assistantMsg.AddFinish(event.Response.FinishReason)
610 if err := a.messages.Update(ctx, *assistantMsg); err != nil {
611 return fmt.Errorf("failed to update message: %w", err)
612 }
613 return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
614 }
615
616 return nil
617}
618
619func (a *agent) TrackUsage(ctx context.Context, sessionID string, model config.Model, usage provider.TokenUsage) error {
620 sess, err := a.sessions.Get(ctx, sessionID)
621 if err != nil {
622 return fmt.Errorf("failed to get session: %w", err)
623 }
624
625 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
626 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
627 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
628 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
629
630 sess.Cost += cost
631 sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
632 sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
633
634 _, err = a.sessions.Save(ctx, sess)
635 if err != nil {
636 return fmt.Errorf("failed to save session: %w", err)
637 }
638 return nil
639}
640
641func (a *agent) Summarize(ctx context.Context, sessionID string) error {
642 if a.summarizeProvider == nil {
643 return fmt.Errorf("summarize provider not available")
644 }
645
646 // Check if session is busy
647 if a.IsSessionBusy(sessionID) {
648 return ErrSessionBusy
649 }
650
651 // Create a new context with cancellation
652 summarizeCtx, cancel := context.WithCancel(ctx)
653
654 // Store the cancel function in activeRequests to allow cancellation
655 a.activeRequests.Store(sessionID+"-summarize", cancel)
656
657 go func() {
658 defer a.activeRequests.Delete(sessionID + "-summarize")
659 defer cancel()
660 event := AgentEvent{
661 Type: AgentEventTypeSummarize,
662 Progress: "Starting summarization...",
663 }
664
665 a.Publish(pubsub.CreatedEvent, event)
666 // Get all messages from the session
667 msgs, err := a.messages.List(summarizeCtx, sessionID)
668 if err != nil {
669 event = AgentEvent{
670 Type: AgentEventTypeError,
671 Error: fmt.Errorf("failed to list messages: %w", err),
672 Done: true,
673 }
674 a.Publish(pubsub.CreatedEvent, event)
675 return
676 }
677 summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
678
679 if len(msgs) == 0 {
680 event = AgentEvent{
681 Type: AgentEventTypeError,
682 Error: fmt.Errorf("no messages to summarize"),
683 Done: true,
684 }
685 a.Publish(pubsub.CreatedEvent, event)
686 return
687 }
688
689 event = AgentEvent{
690 Type: AgentEventTypeSummarize,
691 Progress: "Analyzing conversation...",
692 }
693 a.Publish(pubsub.CreatedEvent, event)
694
695 // Add a system message to guide the summarization
696 summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
697
698 // Create a new message with the summarize prompt
699 promptMsg := message.Message{
700 Role: message.User,
701 Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
702 }
703
704 // Append the prompt to the messages
705 msgsWithPrompt := append(msgs, promptMsg)
706
707 event = AgentEvent{
708 Type: AgentEventTypeSummarize,
709 Progress: "Generating summary...",
710 }
711
712 a.Publish(pubsub.CreatedEvent, event)
713
714 // Send the messages to the summarize provider
715 response := a.summarizeProvider.StreamResponse(
716 summarizeCtx,
717 msgsWithPrompt,
718 make([]tools.BaseTool, 0),
719 )
720 var finalResponse *provider.ProviderResponse
721 for r := range response {
722 if r.Error != nil {
723 event = AgentEvent{
724 Type: AgentEventTypeError,
725 Error: fmt.Errorf("failed to summarize: %w", err),
726 Done: true,
727 }
728 a.Publish(pubsub.CreatedEvent, event)
729 return
730 }
731 finalResponse = r.Response
732 }
733
734 summary := strings.TrimSpace(finalResponse.Content)
735 if summary == "" {
736 event = AgentEvent{
737 Type: AgentEventTypeError,
738 Error: fmt.Errorf("empty summary returned"),
739 Done: true,
740 }
741 a.Publish(pubsub.CreatedEvent, event)
742 return
743 }
744 event = AgentEvent{
745 Type: AgentEventTypeSummarize,
746 Progress: "Creating new session...",
747 }
748
749 a.Publish(pubsub.CreatedEvent, event)
750 oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
751 if err != nil {
752 event = AgentEvent{
753 Type: AgentEventTypeError,
754 Error: fmt.Errorf("failed to get session: %w", err),
755 Done: true,
756 }
757
758 a.Publish(pubsub.CreatedEvent, event)
759 return
760 }
761 // Create a message in the new session with the summary
762 msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
763 Role: message.Assistant,
764 Parts: []message.ContentPart{
765 message.TextContent{Text: summary},
766 message.Finish{
767 Reason: message.FinishReasonEndTurn,
768 Time: time.Now().Unix(),
769 },
770 },
771 Model: a.summarizeProvider.Model().ID,
772 Provider: a.summarizeProviderID,
773 })
774 if err != nil {
775 event = AgentEvent{
776 Type: AgentEventTypeError,
777 Error: fmt.Errorf("failed to create summary message: %w", err),
778 Done: true,
779 }
780
781 a.Publish(pubsub.CreatedEvent, event)
782 return
783 }
784 oldSession.SummaryMessageID = msg.ID
785 oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
786 oldSession.PromptTokens = 0
787 model := a.summarizeProvider.Model()
788 usage := finalResponse.Usage
789 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
790 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
791 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
792 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
793 oldSession.Cost += cost
794 _, err = a.sessions.Save(summarizeCtx, oldSession)
795 if err != nil {
796 event = AgentEvent{
797 Type: AgentEventTypeError,
798 Error: fmt.Errorf("failed to save session: %w", err),
799 Done: true,
800 }
801 a.Publish(pubsub.CreatedEvent, event)
802 }
803
804 event = AgentEvent{
805 Type: AgentEventTypeSummarize,
806 SessionID: oldSession.ID,
807 Progress: "Summary complete",
808 Done: true,
809 }
810 a.Publish(pubsub.CreatedEvent, event)
811 // Send final success event with the new session ID
812 }()
813
814 return nil
815}
816
817func (a *agent) CancelAll() {
818 a.activeRequests.Range(func(key, value any) bool {
819 a.Cancel(key.(string)) // key is sessionID
820 return true
821 })
822}
823
824func (a *agent) UpdateModel() error {
825 cfg := config.Get()
826
827 // Get current provider configuration
828 currentProviderCfg := config.GetAgentProvider(a.agentCfg.ID)
829 if currentProviderCfg.ID == "" {
830 return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
831 }
832
833 // Check if provider has changed
834 if string(currentProviderCfg.ID) != a.providerID {
835 // Provider changed, need to recreate the main provider
836 model := config.GetAgentModel(a.agentCfg.ID)
837 if model.ID == "" {
838 return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
839 }
840
841 promptID := agentPromptMap[a.agentCfg.ID]
842 if promptID == "" {
843 promptID = prompt.PromptDefault
844 }
845
846 opts := []provider.ProviderClientOption{
847 provider.WithModel(a.agentCfg.Model),
848 provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID)),
849 }
850
851 newProvider, err := provider.NewProvider(currentProviderCfg, opts...)
852 if err != nil {
853 return fmt.Errorf("failed to create new provider: %w", err)
854 }
855
856 // Update the provider and provider ID
857 a.provider = newProvider
858 a.providerID = string(currentProviderCfg.ID)
859 }
860
861 // Check if small model provider has changed (affects title and summarize providers)
862 smallModelCfg := cfg.Models.Small
863 var smallModelProviderCfg config.ProviderConfig
864
865 for _, p := range cfg.Providers {
866 if p.ID == smallModelCfg.Provider {
867 smallModelProviderCfg = p
868 break
869 }
870 }
871
872 if smallModelProviderCfg.ID == "" {
873 return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
874 }
875
876 // Check if summarize provider has changed
877 if string(smallModelProviderCfg.ID) != a.summarizeProviderID {
878 var smallModel config.Model
879 for _, m := range smallModelProviderCfg.Models {
880 if m.ID == smallModelCfg.ModelID {
881 smallModel = m
882 break
883 }
884 }
885 if smallModel.ID == "" {
886 return fmt.Errorf("model %s not found in provider %s", smallModelCfg.ModelID, smallModelProviderCfg.ID)
887 }
888
889 // Recreate title provider
890 titleOpts := []provider.ProviderClientOption{
891 provider.WithModel(config.SmallModel),
892 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
893 // We want the title to be short, so we limit the max tokens
894 provider.WithMaxTokens(40),
895 }
896 newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
897 if err != nil {
898 return fmt.Errorf("failed to create new title provider: %w", err)
899 }
900
901 // Recreate summarize provider
902 summarizeOpts := []provider.ProviderClientOption{
903 provider.WithModel(config.SmallModel),
904 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
905 }
906 newSummarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
907 if err != nil {
908 return fmt.Errorf("failed to create new summarize provider: %w", err)
909 }
910
911 // Update the providers and provider ID
912 a.titleProvider = newTitleProvider
913 a.summarizeProvider = newSummarizeProvider
914 a.summarizeProviderID = string(smallModelProviderCfg.ID)
915 }
916
917 return nil
918}