1package agent
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "log/slog"
8 "slices"
9 "strings"
10 "time"
11
12 "github.com/charmbracelet/catwalk/pkg/catwalk"
13 "github.com/charmbracelet/crush/internal/config"
14 "github.com/charmbracelet/crush/internal/csync"
15 "github.com/charmbracelet/crush/internal/history"
16 "github.com/charmbracelet/crush/internal/llm/prompt"
17 "github.com/charmbracelet/crush/internal/llm/provider"
18 "github.com/charmbracelet/crush/internal/llm/tools"
19 "github.com/charmbracelet/crush/internal/log"
20 "github.com/charmbracelet/crush/internal/lsp"
21 "github.com/charmbracelet/crush/internal/message"
22 "github.com/charmbracelet/crush/internal/permission"
23 "github.com/charmbracelet/crush/internal/pubsub"
24 "github.com/charmbracelet/crush/internal/session"
25 "github.com/charmbracelet/crush/internal/shell"
26)
27
// Sentinel errors reported by the agent.

// ErrRequestCancelled is the error carried by the final event of a request
// that ended because it was canceled.
var ErrRequestCancelled = errors.New("request canceled by user")

// ErrSessionBusy is returned when work is requested for a session that
// already has a request in flight.
var ErrSessionBusy = errors.New("session is currently processing another request")
33
// AgentEventType names the kind of event carried by an AgentEvent.
type AgentEventType string

// Event kinds emitted by the agent.
const (
	// AgentEventTypeError marks an event whose Error field is set.
	AgentEventTypeError = AgentEventType("error")
	// AgentEventTypeResponse marks the final event of a completed generation.
	AgentEventTypeResponse = AgentEventType("response")
	// AgentEventTypeSummarize marks progress events of a summarization.
	AgentEventTypeSummarize = AgentEventType("summarize")
)
41
// AgentEvent is the value the agent publishes to its subscribers and sends on
// the channel returned by Run.
type AgentEvent struct {
	// Type says which kind of event this is.
	Type AgentEventType
	// Message is the assistant message of a completed response event.
	Message message.Message
	// Error is the failure carried by an error event.
	Error error

	// When summarizing: SessionID is the summarized session and Progress is a
	// human-readable status line. Done is set on the last event of an
	// operation; it is also set on the final response event of Run.
	SessionID string
	Progress  string
	Done      bool
}
52
// Service is the interface through which callers drive an agent. The method
// notes below describe the behavior of the agent implementation in this file.
type Service interface {
	// Subscription to the AgentEvents the agent publishes.
	pubsub.Suscriber[AgentEvent]
	// Model returns the model selected for this agent in the configuration.
	Model() catwalk.Model
	// Run starts processing content for the session in a background goroutine
	// and returns a channel that receives the final event. It returns
	// ErrSessionBusy when the session already has a request in flight.
	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
	// Cancel cancels the running request and the running summarization, if
	// any, of the given session.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits up to five seconds for
	// the agent to become idle.
	CancelAll()
	// IsSessionBusy reports whether the session has a request in flight.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any request is in flight.
	IsBusy() bool
	// Summarize starts a background summarization of the session; progress and
	// results are delivered as published events.
	Summarize(ctx context.Context, sessionID string) error
	// UpdateModel rebuilds the providers whose configured provider changed.
	UpdateModel() error
}
64
// agent is the Service implementation defined in this file.
type agent struct {
	// Embedded broker used to publish AgentEvents to subscribers.
	*pubsub.Broker[AgentEvent]
	agentCfg config.Agent
	sessions session.Service
	messages message.Service

	// tools is built lazily by the function supplied in NewAgent.
	tools *csync.LazySlice[tools.BaseTool]

	// provider serves the main generation loop; providerID is recorded on the
	// messages it produces.
	provider   provider.Provider
	providerID string

	// titleProvider and summarizeProvider are both created from the small
	// model's provider configuration.
	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests maps a session ID (or session ID + "-summarize") to the
	// cancel function of the request currently running for it.
	activeRequests *csync.Map[string, context.CancelFunc]
}
82
// agentPromptMap maps an agent ID to the prompt used for its system message.
// Agents whose ID is not listed here fall back to prompt.PromptDefault.
var agentPromptMap = map[string]prompt.PromptID{
	"coder": prompt.PromptCoder,
	"task":  prompt.PromptTask,
}
87
88func NewAgent(
89 agentCfg config.Agent,
90 // These services are needed in the tools
91 permissions permission.Service,
92 sessions session.Service,
93 messages message.Service,
94 history history.Service,
95 lspClients map[string]*lsp.Client,
96) (Service, error) {
97 ctx := context.Background()
98 cfg := config.Get()
99
100 var agentTool tools.BaseTool
101 if agentCfg.ID == "coder" {
102 taskAgentCfg := config.Get().Agents["task"]
103 if taskAgentCfg.ID == "" {
104 return nil, fmt.Errorf("task agent not found in config")
105 }
106 taskAgent, err := NewAgent(taskAgentCfg, permissions, sessions, messages, history, lspClients)
107 if err != nil {
108 return nil, fmt.Errorf("failed to create task agent: %w", err)
109 }
110
111 agentTool = NewAgentTool(taskAgent, sessions, messages)
112 }
113
114 providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
115 if providerCfg == nil {
116 return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
117 }
118 model := config.Get().GetModelByType(agentCfg.Model)
119
120 if model == nil {
121 return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
122 }
123
124 promptID := agentPromptMap[agentCfg.ID]
125 if promptID == "" {
126 promptID = prompt.PromptDefault
127 }
128 opts := []provider.ProviderClientOption{
129 provider.WithModel(agentCfg.Model),
130 provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
131 }
132 agentProvider, err := provider.NewProvider(*providerCfg, opts...)
133 if err != nil {
134 return nil, err
135 }
136
137 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
138 var smallModelProviderCfg *config.ProviderConfig
139 if smallModelCfg.Provider == providerCfg.ID {
140 smallModelProviderCfg = providerCfg
141 } else {
142 smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
143
144 if smallModelProviderCfg.ID == "" {
145 return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
146 }
147 }
148 smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
149 if smallModel.ID == "" {
150 return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
151 }
152
153 titleOpts := []provider.ProviderClientOption{
154 provider.WithModel(config.SelectedModelTypeSmall),
155 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
156 }
157 titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
158 if err != nil {
159 return nil, err
160 }
161 summarizeOpts := []provider.ProviderClientOption{
162 provider.WithModel(config.SelectedModelTypeSmall),
163 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
164 }
165 summarizeProvider, err := provider.NewProvider(*smallModelProviderCfg, summarizeOpts...)
166 if err != nil {
167 return nil, err
168 }
169
170 toolFn := func() []tools.BaseTool {
171 slog.Info("Initializing agent tools", "agent", agentCfg.ID)
172 defer func() {
173 slog.Info("Initialized agent tools", "agent", agentCfg.ID)
174 }()
175
176 cwd := cfg.WorkingDir()
177 allTools := []tools.BaseTool{
178 tools.NewBashTool(permissions, cwd),
179 tools.NewDownloadTool(permissions, cwd),
180 tools.NewEditTool(lspClients, permissions, history, cwd),
181 tools.NewMultiEditTool(lspClients, permissions, history, cwd),
182 tools.NewFetchTool(permissions, cwd),
183 tools.NewGlobTool(cwd),
184 tools.NewGrepTool(cwd),
185 tools.NewLsTool(permissions, cwd),
186 tools.NewSourcegraphTool(),
187 tools.NewViewTool(lspClients, permissions, cwd),
188 tools.NewWriteTool(lspClients, permissions, history, cwd),
189 }
190
191 mcpTools := GetMCPTools(ctx, permissions, cfg)
192 allTools = append(allTools, mcpTools...)
193
194 if len(lspClients) > 0 {
195 allTools = append(allTools, tools.NewDiagnosticsTool(lspClients))
196 }
197
198 if agentTool != nil {
199 allTools = append(allTools, agentTool)
200 }
201
202 if agentCfg.AllowedTools == nil {
203 return allTools
204 }
205
206 var filteredTools []tools.BaseTool
207 for _, tool := range allTools {
208 if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
209 filteredTools = append(filteredTools, tool)
210 }
211 }
212 return filteredTools
213 }
214
215 return &agent{
216 Broker: pubsub.NewBroker[AgentEvent](),
217 agentCfg: agentCfg,
218 provider: agentProvider,
219 providerID: string(providerCfg.ID),
220 messages: messages,
221 sessions: sessions,
222 titleProvider: titleProvider,
223 summarizeProvider: summarizeProvider,
224 summarizeProviderID: string(smallModelProviderCfg.ID),
225 activeRequests: csync.NewMap[string, context.CancelFunc](),
226 tools: csync.NewLazySlice(toolFn),
227 }, nil
228}
229
230func (a *agent) Model() catwalk.Model {
231 return *config.Get().GetModelByType(a.agentCfg.Model)
232}
233
234func (a *agent) Cancel(sessionID string) {
235 // Cancel regular requests
236 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
237 slog.Info("Request cancellation initiated", "session_id", sessionID)
238 cancel()
239 }
240
241 // Also check for summarize requests
242 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
243 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
244 cancel()
245 }
246}
247
248func (a *agent) IsBusy() bool {
249 var busy bool
250 for cancelFunc := range a.activeRequests.Seq() {
251 if cancelFunc != nil {
252 busy = true
253 break
254 }
255 }
256 return busy
257}
258
259func (a *agent) IsSessionBusy(sessionID string) bool {
260 _, busy := a.activeRequests.Get(sessionID)
261 return busy
262}
263
264func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
265 if content == "" {
266 return nil
267 }
268 if a.titleProvider == nil {
269 return nil
270 }
271 session, err := a.sessions.Get(ctx, sessionID)
272 if err != nil {
273 return err
274 }
275 parts := []message.ContentPart{message.TextContent{
276 Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
277 }}
278
279 // Use streaming approach like summarization
280 response := a.titleProvider.StreamResponse(
281 ctx,
282 []message.Message{
283 {
284 Role: message.User,
285 Parts: parts,
286 },
287 },
288 nil,
289 )
290
291 var finalResponse *provider.ProviderResponse
292 for r := range response {
293 if r.Error != nil {
294 return r.Error
295 }
296 finalResponse = r.Response
297 }
298
299 if finalResponse == nil {
300 return fmt.Errorf("no response received from title provider")
301 }
302
303 title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
304 if title == "" {
305 return nil
306 }
307
308 session.Title = title
309 _, err = a.sessions.Save(ctx, session)
310 return err
311}
312
313func (a *agent) err(err error) AgentEvent {
314 return AgentEvent{
315 Type: AgentEventTypeError,
316 Error: err,
317 }
318}
319
320func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
321 if !a.Model().SupportsImages && attachments != nil {
322 attachments = nil
323 }
324 events := make(chan AgentEvent)
325 if a.IsSessionBusy(sessionID) {
326 return nil, ErrSessionBusy
327 }
328
329 genCtx, cancel := context.WithCancel(ctx)
330
331 a.activeRequests.Set(sessionID, cancel)
332 go func() {
333 slog.Debug("Request started", "sessionID", sessionID)
334 defer log.RecoverPanic("agent.Run", func() {
335 events <- a.err(fmt.Errorf("panic while running the agent"))
336 })
337 var attachmentParts []message.ContentPart
338 for _, attachment := range attachments {
339 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
340 }
341 result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
342 if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
343 slog.Error(result.Error.Error())
344 }
345 slog.Debug("Request completed", "sessionID", sessionID)
346 a.activeRequests.Del(sessionID)
347 cancel()
348 a.Publish(pubsub.CreatedEvent, result)
349 events <- result
350 close(events)
351 }()
352 return events, nil
353}
354
355func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
356 cfg := config.Get()
357 // List existing messages; if none, start title generation asynchronously.
358 msgs, err := a.messages.List(ctx, sessionID)
359 if err != nil {
360 return a.err(fmt.Errorf("failed to list messages: %w", err))
361 }
362 if len(msgs) == 0 {
363 go func() {
364 defer log.RecoverPanic("agent.Run", func() {
365 slog.Error("panic while generating title")
366 })
367 titleErr := a.generateTitle(context.Background(), sessionID, content)
368 if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
369 slog.Error("failed to generate title", "error", titleErr)
370 }
371 }()
372 }
373 session, err := a.sessions.Get(ctx, sessionID)
374 if err != nil {
375 return a.err(fmt.Errorf("failed to get session: %w", err))
376 }
377 if session.SummaryMessageID != "" {
378 summaryMsgInex := -1
379 for i, msg := range msgs {
380 if msg.ID == session.SummaryMessageID {
381 summaryMsgInex = i
382 break
383 }
384 }
385 if summaryMsgInex != -1 {
386 msgs = msgs[summaryMsgInex:]
387 msgs[0].Role = message.User
388 }
389 }
390
391 userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
392 if err != nil {
393 return a.err(fmt.Errorf("failed to create user message: %w", err))
394 }
395 // Append the new user message to the conversation history.
396 msgHistory := append(msgs, userMsg)
397
398 for {
399 // Check for cancellation before each iteration
400 select {
401 case <-ctx.Done():
402 return a.err(ctx.Err())
403 default:
404 // Continue processing
405 }
406 agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
407 if err != nil {
408 if errors.Is(err, context.Canceled) {
409 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
410 a.messages.Update(context.Background(), agentMessage)
411 return a.err(ErrRequestCancelled)
412 }
413 return a.err(fmt.Errorf("failed to process events: %w", err))
414 }
415 if cfg.Options.Debug {
416 slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
417 }
418 if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
419 // We are not done, we need to respond with the tool response
420 msgHistory = append(msgHistory, agentMessage, *toolResults)
421 continue
422 }
423 if agentMessage.FinishReason() == "" {
424 // Kujtim: could not track down where this is happening but this means its cancelled
425 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
426 _ = a.messages.Update(context.Background(), agentMessage)
427 return a.err(ErrRequestCancelled)
428 }
429 return AgentEvent{
430 Type: AgentEventTypeResponse,
431 Message: agentMessage,
432 Done: true,
433 }
434 }
435}
436
437func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
438 parts := []message.ContentPart{message.TextContent{Text: content}}
439 parts = append(parts, attachmentParts...)
440 return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
441 Role: message.User,
442 Parts: parts,
443 })
444}
445
446func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
447 ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
448 eventChan := a.provider.StreamResponse(ctx, msgHistory, slices.Collect(a.tools.Seq()))
449
450 assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
451 Role: message.Assistant,
452 Parts: []message.ContentPart{},
453 Model: a.Model().ID,
454 Provider: a.providerID,
455 })
456 if err != nil {
457 return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
458 }
459
460 // Add the session and message ID into the context if needed by tools.
461 ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
462
463 // Process each event in the stream.
464 for event := range eventChan {
465 if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
466 if errors.Is(processErr, context.Canceled) {
467 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
468 } else {
469 a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
470 }
471 return assistantMsg, nil, processErr
472 }
473 if ctx.Err() != nil {
474 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
475 return assistantMsg, nil, ctx.Err()
476 }
477 }
478
479 toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
480 toolCalls := assistantMsg.ToolCalls()
481 for i, toolCall := range toolCalls {
482 select {
483 case <-ctx.Done():
484 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
485 // Make all future tool calls cancelled
486 for j := i; j < len(toolCalls); j++ {
487 toolResults[j] = message.ToolResult{
488 ToolCallID: toolCalls[j].ID,
489 Content: "Tool execution canceled by user",
490 IsError: true,
491 }
492 }
493 goto out
494 default:
495 // Continue processing
496 var tool tools.BaseTool
497 for availableTool := range a.tools.Seq() {
498 if availableTool.Info().Name == toolCall.Name {
499 tool = availableTool
500 break
501 }
502 }
503
504 // Tool not found
505 if tool == nil {
506 toolResults[i] = message.ToolResult{
507 ToolCallID: toolCall.ID,
508 Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
509 IsError: true,
510 }
511 continue
512 }
513
514 // Run tool in goroutine to allow cancellation
515 type toolExecResult struct {
516 response tools.ToolResponse
517 err error
518 }
519 resultChan := make(chan toolExecResult, 1)
520
521 go func() {
522 response, err := tool.Run(ctx, tools.ToolCall{
523 ID: toolCall.ID,
524 Name: toolCall.Name,
525 Input: toolCall.Input,
526 })
527 resultChan <- toolExecResult{response: response, err: err}
528 }()
529
530 var toolResponse tools.ToolResponse
531 var toolErr error
532
533 select {
534 case <-ctx.Done():
535 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
536 // Mark remaining tool calls as cancelled
537 for j := i; j < len(toolCalls); j++ {
538 toolResults[j] = message.ToolResult{
539 ToolCallID: toolCalls[j].ID,
540 Content: "Tool execution canceled by user",
541 IsError: true,
542 }
543 }
544 goto out
545 case result := <-resultChan:
546 toolResponse = result.response
547 toolErr = result.err
548 }
549
550 if toolErr != nil {
551 slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
552 if errors.Is(toolErr, permission.ErrorPermissionDenied) {
553 toolResults[i] = message.ToolResult{
554 ToolCallID: toolCall.ID,
555 Content: "Permission denied",
556 IsError: true,
557 }
558 for j := i + 1; j < len(toolCalls); j++ {
559 toolResults[j] = message.ToolResult{
560 ToolCallID: toolCalls[j].ID,
561 Content: "Tool execution canceled by user",
562 IsError: true,
563 }
564 }
565 a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
566 break
567 }
568 }
569 toolResults[i] = message.ToolResult{
570 ToolCallID: toolCall.ID,
571 Content: toolResponse.Content,
572 Metadata: toolResponse.Metadata,
573 IsError: toolResponse.IsError,
574 }
575 }
576 }
577out:
578 if len(toolResults) == 0 {
579 return assistantMsg, nil, nil
580 }
581 parts := make([]message.ContentPart, 0)
582 for _, tr := range toolResults {
583 parts = append(parts, tr)
584 }
585 msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
586 Role: message.Tool,
587 Parts: parts,
588 Provider: a.providerID,
589 })
590 if err != nil {
591 return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
592 }
593
594 return assistantMsg, &msg, err
595}
596
597func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
598 msg.AddFinish(finishReason, message, details)
599 _ = a.messages.Update(ctx, *msg)
600}
601
602func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
603 select {
604 case <-ctx.Done():
605 return ctx.Err()
606 default:
607 // Continue processing.
608 }
609
610 switch event.Type {
611 case provider.EventThinkingDelta:
612 assistantMsg.AppendReasoningContent(event.Thinking)
613 return a.messages.Update(ctx, *assistantMsg)
614 case provider.EventSignatureDelta:
615 assistantMsg.AppendReasoningSignature(event.Signature)
616 return a.messages.Update(ctx, *assistantMsg)
617 case provider.EventContentDelta:
618 assistantMsg.FinishThinking()
619 assistantMsg.AppendContent(event.Content)
620 return a.messages.Update(ctx, *assistantMsg)
621 case provider.EventToolUseStart:
622 assistantMsg.FinishThinking()
623 slog.Info("Tool call started", "toolCall", event.ToolCall)
624 assistantMsg.AddToolCall(*event.ToolCall)
625 return a.messages.Update(ctx, *assistantMsg)
626 case provider.EventToolUseDelta:
627 assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
628 return a.messages.Update(ctx, *assistantMsg)
629 case provider.EventToolUseStop:
630 slog.Info("Finished tool call", "toolCall", event.ToolCall)
631 assistantMsg.FinishToolCall(event.ToolCall.ID)
632 return a.messages.Update(ctx, *assistantMsg)
633 case provider.EventError:
634 return event.Error
635 case provider.EventComplete:
636 assistantMsg.FinishThinking()
637 assistantMsg.SetToolCalls(event.Response.ToolCalls)
638 assistantMsg.AddFinish(event.Response.FinishReason, "", "")
639 if err := a.messages.Update(ctx, *assistantMsg); err != nil {
640 return fmt.Errorf("failed to update message: %w", err)
641 }
642 return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
643 }
644
645 return nil
646}
647
648func (a *agent) TrackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
649 sess, err := a.sessions.Get(ctx, sessionID)
650 if err != nil {
651 return fmt.Errorf("failed to get session: %w", err)
652 }
653
654 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
655 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
656 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
657 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
658
659 sess.Cost += cost
660 sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
661 sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
662
663 _, err = a.sessions.Save(ctx, sess)
664 if err != nil {
665 return fmt.Errorf("failed to save session: %w", err)
666 }
667 return nil
668}
669
670func (a *agent) Summarize(ctx context.Context, sessionID string) error {
671 if a.summarizeProvider == nil {
672 return fmt.Errorf("summarize provider not available")
673 }
674
675 // Check if session is busy
676 if a.IsSessionBusy(sessionID) {
677 return ErrSessionBusy
678 }
679
680 // Create a new context with cancellation
681 summarizeCtx, cancel := context.WithCancel(ctx)
682
683 // Store the cancel function in activeRequests to allow cancellation
684 a.activeRequests.Set(sessionID+"-summarize", cancel)
685
686 go func() {
687 defer a.activeRequests.Del(sessionID + "-summarize")
688 defer cancel()
689 event := AgentEvent{
690 Type: AgentEventTypeSummarize,
691 Progress: "Starting summarization...",
692 }
693
694 a.Publish(pubsub.CreatedEvent, event)
695 // Get all messages from the session
696 msgs, err := a.messages.List(summarizeCtx, sessionID)
697 if err != nil {
698 event = AgentEvent{
699 Type: AgentEventTypeError,
700 Error: fmt.Errorf("failed to list messages: %w", err),
701 Done: true,
702 }
703 a.Publish(pubsub.CreatedEvent, event)
704 return
705 }
706 summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
707
708 if len(msgs) == 0 {
709 event = AgentEvent{
710 Type: AgentEventTypeError,
711 Error: fmt.Errorf("no messages to summarize"),
712 Done: true,
713 }
714 a.Publish(pubsub.CreatedEvent, event)
715 return
716 }
717
718 event = AgentEvent{
719 Type: AgentEventTypeSummarize,
720 Progress: "Analyzing conversation...",
721 }
722 a.Publish(pubsub.CreatedEvent, event)
723
724 // Add a system message to guide the summarization
725 summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
726
727 // Create a new message with the summarize prompt
728 promptMsg := message.Message{
729 Role: message.User,
730 Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
731 }
732
733 // Append the prompt to the messages
734 msgsWithPrompt := append(msgs, promptMsg)
735
736 event = AgentEvent{
737 Type: AgentEventTypeSummarize,
738 Progress: "Generating summary...",
739 }
740
741 a.Publish(pubsub.CreatedEvent, event)
742
743 // Send the messages to the summarize provider
744 response := a.summarizeProvider.StreamResponse(
745 summarizeCtx,
746 msgsWithPrompt,
747 nil,
748 )
749 var finalResponse *provider.ProviderResponse
750 for r := range response {
751 if r.Error != nil {
752 event = AgentEvent{
753 Type: AgentEventTypeError,
754 Error: fmt.Errorf("failed to summarize: %w", err),
755 Done: true,
756 }
757 a.Publish(pubsub.CreatedEvent, event)
758 return
759 }
760 finalResponse = r.Response
761 }
762
763 summary := strings.TrimSpace(finalResponse.Content)
764 if summary == "" {
765 event = AgentEvent{
766 Type: AgentEventTypeError,
767 Error: fmt.Errorf("empty summary returned"),
768 Done: true,
769 }
770 a.Publish(pubsub.CreatedEvent, event)
771 return
772 }
773 shell := shell.GetPersistentShell(config.Get().WorkingDir())
774 summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
775 event = AgentEvent{
776 Type: AgentEventTypeSummarize,
777 Progress: "Creating new session...",
778 }
779
780 a.Publish(pubsub.CreatedEvent, event)
781 oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
782 if err != nil {
783 event = AgentEvent{
784 Type: AgentEventTypeError,
785 Error: fmt.Errorf("failed to get session: %w", err),
786 Done: true,
787 }
788
789 a.Publish(pubsub.CreatedEvent, event)
790 return
791 }
792 // Create a message in the new session with the summary
793 msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
794 Role: message.Assistant,
795 Parts: []message.ContentPart{
796 message.TextContent{Text: summary},
797 message.Finish{
798 Reason: message.FinishReasonEndTurn,
799 Time: time.Now().Unix(),
800 },
801 },
802 Model: a.summarizeProvider.Model().ID,
803 Provider: a.summarizeProviderID,
804 })
805 if err != nil {
806 event = AgentEvent{
807 Type: AgentEventTypeError,
808 Error: fmt.Errorf("failed to create summary message: %w", err),
809 Done: true,
810 }
811
812 a.Publish(pubsub.CreatedEvent, event)
813 return
814 }
815 oldSession.SummaryMessageID = msg.ID
816 oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
817 oldSession.PromptTokens = 0
818 model := a.summarizeProvider.Model()
819 usage := finalResponse.Usage
820 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
821 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
822 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
823 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
824 oldSession.Cost += cost
825 _, err = a.sessions.Save(summarizeCtx, oldSession)
826 if err != nil {
827 event = AgentEvent{
828 Type: AgentEventTypeError,
829 Error: fmt.Errorf("failed to save session: %w", err),
830 Done: true,
831 }
832 a.Publish(pubsub.CreatedEvent, event)
833 }
834
835 event = AgentEvent{
836 Type: AgentEventTypeSummarize,
837 SessionID: oldSession.ID,
838 Progress: "Summary complete",
839 Done: true,
840 }
841 a.Publish(pubsub.CreatedEvent, event)
842 // Send final success event with the new session ID
843 }()
844
845 return nil
846}
847
848func (a *agent) CancelAll() {
849 if !a.IsBusy() {
850 return
851 }
852 for key := range a.activeRequests.Seq2() {
853 a.Cancel(key) // key is sessionID
854 }
855
856 timeout := time.After(5 * time.Second)
857 for a.IsBusy() {
858 select {
859 case <-timeout:
860 return
861 default:
862 time.Sleep(200 * time.Millisecond)
863 }
864 }
865}
866
867func (a *agent) UpdateModel() error {
868 cfg := config.Get()
869
870 // Get current provider configuration
871 currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
872 if currentProviderCfg == nil || currentProviderCfg.ID == "" {
873 return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
874 }
875
876 // Check if provider has changed
877 if string(currentProviderCfg.ID) != a.providerID {
878 // Provider changed, need to recreate the main provider
879 model := cfg.GetModelByType(a.agentCfg.Model)
880 if model.ID == "" {
881 return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
882 }
883
884 promptID := agentPromptMap[a.agentCfg.ID]
885 if promptID == "" {
886 promptID = prompt.PromptDefault
887 }
888
889 opts := []provider.ProviderClientOption{
890 provider.WithModel(a.agentCfg.Model),
891 provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
892 }
893
894 newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
895 if err != nil {
896 return fmt.Errorf("failed to create new provider: %w", err)
897 }
898
899 // Update the provider and provider ID
900 a.provider = newProvider
901 a.providerID = string(currentProviderCfg.ID)
902 }
903
904 // Check if small model provider has changed (affects title and summarize providers)
905 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
906 var smallModelProviderCfg config.ProviderConfig
907
908 for p := range cfg.Providers.Seq() {
909 if p.ID == smallModelCfg.Provider {
910 smallModelProviderCfg = p
911 break
912 }
913 }
914
915 if smallModelProviderCfg.ID == "" {
916 return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
917 }
918
919 // Check if summarize provider has changed
920 if string(smallModelProviderCfg.ID) != a.summarizeProviderID {
921 smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
922 if smallModel == nil {
923 return fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
924 }
925
926 // Recreate title provider
927 titleOpts := []provider.ProviderClientOption{
928 provider.WithModel(config.SelectedModelTypeSmall),
929 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
930 // We want the title to be short, so we limit the max tokens
931 provider.WithMaxTokens(40),
932 }
933 newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
934 if err != nil {
935 return fmt.Errorf("failed to create new title provider: %w", err)
936 }
937
938 // Recreate summarize provider
939 summarizeOpts := []provider.ProviderClientOption{
940 provider.WithModel(config.SelectedModelTypeSmall),
941 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
942 }
943 newSummarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
944 if err != nil {
945 return fmt.Errorf("failed to create new summarize provider: %w", err)
946 }
947
948 // Update the providers and provider ID
949 a.titleProvider = newTitleProvider
950 a.summarizeProvider = newSummarizeProvider
951 a.summarizeProviderID = string(smallModelProviderCfg.ID)
952 }
953
954 return nil
955}