1package agent
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "log/slog"
8 "slices"
9 "strings"
10 "time"
11
12 "github.com/charmbracelet/catwalk/pkg/catwalk"
13 "github.com/charmbracelet/crush/internal/config"
14 "github.com/charmbracelet/crush/internal/csync"
15 "github.com/charmbracelet/crush/internal/history"
16 "github.com/charmbracelet/crush/internal/llm/prompt"
17 "github.com/charmbracelet/crush/internal/llm/provider"
18 "github.com/charmbracelet/crush/internal/llm/tools"
19 "github.com/charmbracelet/crush/internal/log"
20 "github.com/charmbracelet/crush/internal/lsp"
21 "github.com/charmbracelet/crush/internal/message"
22 "github.com/charmbracelet/crush/internal/permission"
23 "github.com/charmbracelet/crush/internal/pubsub"
24 "github.com/charmbracelet/crush/internal/session"
25 "github.com/charmbracelet/crush/internal/shell"
26)
27
// Sentinel errors returned by the agent service.
var (
	// ErrRequestCancelled is reported when a running request ends because it
	// was canceled.
	ErrRequestCancelled = errors.New("request canceled by user")
	// ErrSessionBusy is reported when a request is submitted for a session
	// that already has one in flight.
	ErrSessionBusy = errors.New("session is currently processing another request")
)
33
// AgentEventType names the kind of event carried by an AgentEvent.
type AgentEventType string

// Event kinds published by the agent.
const (
	// AgentEventTypeError marks an event whose Error field is set.
	AgentEventTypeError AgentEventType = "error"
	// AgentEventTypeResponse marks the final assistant response of a run.
	AgentEventTypeResponse AgentEventType = "response"
	// AgentEventTypeSummarize marks a summarization progress update.
	AgentEventTypeSummarize AgentEventType = "summarize"
)
41
42type AgentEvent struct {
43 Type AgentEventType
44 Message message.Message
45 Error error
46
47 // When summarizing
48 SessionID string
49 Progress string
50 Done bool
51}
52
53type Service interface {
54 pubsub.Suscriber[AgentEvent]
55 Model() catwalk.Model
56 Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
57 Cancel(sessionID string)
58 CancelAll()
59 IsSessionBusy(sessionID string) bool
60 IsBusy() bool
61 Summarize(ctx context.Context, sessionID string) error
62 UpdateModel() error
63}
64
65type agent struct {
66 *pubsub.Broker[AgentEvent]
67 agentCfg config.Agent
68 sessions session.Service
69 messages message.Service
70 mcpTools []McpTool
71
72 tools *csync.LazySlice[tools.BaseTool]
73
74 provider provider.Provider
75 providerID string
76
77 titleProvider provider.Provider
78 summarizeProvider provider.Provider
79 summarizeProviderID string
80
81 activeRequests *csync.Map[string, context.CancelFunc]
82}
83
84var agentPromptMap = map[string]prompt.PromptID{
85 "coder": prompt.PromptCoder,
86 "task": prompt.PromptTask,
87}
88
89func NewAgent(
90 ctx context.Context,
91 agentCfg config.Agent,
92 // These services are needed in the tools
93 permissions permission.Service,
94 sessions session.Service,
95 messages message.Service,
96 history history.Service,
97 lspClients map[string]*lsp.Client,
98) (Service, error) {
99 cfg := config.Get()
100
101 var agentTool tools.BaseTool
102 if agentCfg.ID == "coder" {
103 taskAgentCfg := config.Get().Agents["task"]
104 if taskAgentCfg.ID == "" {
105 return nil, fmt.Errorf("task agent not found in config")
106 }
107 taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients)
108 if err != nil {
109 return nil, fmt.Errorf("failed to create task agent: %w", err)
110 }
111
112 agentTool = NewAgentTool(taskAgent, sessions, messages)
113 }
114
115 providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
116 if providerCfg == nil {
117 return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
118 }
119 model := config.Get().GetModelByType(agentCfg.Model)
120
121 if model == nil {
122 return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
123 }
124
125 promptID := agentPromptMap[agentCfg.ID]
126 if promptID == "" {
127 promptID = prompt.PromptDefault
128 }
129 opts := []provider.ProviderClientOption{
130 provider.WithModel(agentCfg.Model),
131 provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
132 }
133 agentProvider, err := provider.NewProvider(*providerCfg, opts...)
134 if err != nil {
135 return nil, err
136 }
137
138 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
139 var smallModelProviderCfg *config.ProviderConfig
140 if smallModelCfg.Provider == providerCfg.ID {
141 smallModelProviderCfg = providerCfg
142 } else {
143 smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
144
145 if smallModelProviderCfg.ID == "" {
146 return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
147 }
148 }
149 smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
150 if smallModel.ID == "" {
151 return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
152 }
153
154 titleOpts := []provider.ProviderClientOption{
155 provider.WithModel(config.SelectedModelTypeSmall),
156 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
157 }
158 titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
159 if err != nil {
160 return nil, err
161 }
162 summarizeOpts := []provider.ProviderClientOption{
163 provider.WithModel(config.SelectedModelTypeSmall),
164 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
165 }
166 summarizeProvider, err := provider.NewProvider(*smallModelProviderCfg, summarizeOpts...)
167 if err != nil {
168 return nil, err
169 }
170
171 toolFn := func() []tools.BaseTool {
172 slog.Info("Initializing agent tools", "agent", agentCfg.ID)
173 defer func() {
174 slog.Info("Initialized agent tools", "agent", agentCfg.ID)
175 }()
176
177 cwd := cfg.WorkingDir()
178 allTools := []tools.BaseTool{
179 tools.NewBashTool(permissions, cwd),
180 tools.NewDownloadTool(permissions, cwd),
181 tools.NewEditTool(lspClients, permissions, history, cwd),
182 tools.NewMultiEditTool(lspClients, permissions, history, cwd),
183 tools.NewFetchTool(permissions, cwd),
184 tools.NewGlobTool(cwd),
185 tools.NewGrepTool(cwd),
186 tools.NewLsTool(permissions, cwd),
187 tools.NewSourcegraphTool(),
188 tools.NewViewTool(lspClients, permissions, cwd),
189 tools.NewWriteTool(lspClients, permissions, history, cwd),
190 }
191
192 mcpTools := GetMCPTools(ctx, permissions, cfg)
193 allTools = append(allTools, mcpTools...)
194
195 if len(lspClients) > 0 {
196 allTools = append(allTools, tools.NewDiagnosticsTool(lspClients))
197 }
198
199 if agentTool != nil {
200 allTools = append(allTools, agentTool)
201 }
202
203 if agentCfg.AllowedTools == nil {
204 return allTools
205 }
206
207 var filteredTools []tools.BaseTool
208 for _, tool := range allTools {
209 if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
210 filteredTools = append(filteredTools, tool)
211 }
212 }
213 return filteredTools
214 }
215
216 return &agent{
217 Broker: pubsub.NewBroker[AgentEvent](),
218 agentCfg: agentCfg,
219 provider: agentProvider,
220 providerID: string(providerCfg.ID),
221 messages: messages,
222 sessions: sessions,
223 titleProvider: titleProvider,
224 summarizeProvider: summarizeProvider,
225 summarizeProviderID: string(smallModelProviderCfg.ID),
226 activeRequests: csync.NewMap[string, context.CancelFunc](),
227 tools: csync.NewLazySlice(toolFn),
228 }, nil
229}
230
231func (a *agent) Model() catwalk.Model {
232 return *config.Get().GetModelByType(a.agentCfg.Model)
233}
234
235func (a *agent) Cancel(sessionID string) {
236 // Cancel regular requests
237 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
238 slog.Info("Request cancellation initiated", "session_id", sessionID)
239 cancel()
240 }
241
242 // Also check for summarize requests
243 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
244 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
245 cancel()
246 }
247}
248
249func (a *agent) IsBusy() bool {
250 var busy bool
251 for cancelFunc := range a.activeRequests.Seq() {
252 if cancelFunc != nil {
253 busy = true
254 break
255 }
256 }
257 return busy
258}
259
260func (a *agent) IsSessionBusy(sessionID string) bool {
261 _, busy := a.activeRequests.Get(sessionID)
262 return busy
263}
264
265func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
266 if content == "" {
267 return nil
268 }
269 if a.titleProvider == nil {
270 return nil
271 }
272 session, err := a.sessions.Get(ctx, sessionID)
273 if err != nil {
274 return err
275 }
276 parts := []message.ContentPart{message.TextContent{
277 Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
278 }}
279
280 // Use streaming approach like summarization
281 response := a.titleProvider.StreamResponse(
282 ctx,
283 []message.Message{
284 {
285 Role: message.User,
286 Parts: parts,
287 },
288 },
289 nil,
290 )
291
292 var finalResponse *provider.ProviderResponse
293 for r := range response {
294 if r.Error != nil {
295 return r.Error
296 }
297 finalResponse = r.Response
298 }
299
300 if finalResponse == nil {
301 return fmt.Errorf("no response received from title provider")
302 }
303
304 title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
305 if title == "" {
306 return nil
307 }
308
309 session.Title = title
310 _, err = a.sessions.Save(ctx, session)
311 return err
312}
313
314func (a *agent) err(err error) AgentEvent {
315 return AgentEvent{
316 Type: AgentEventTypeError,
317 Error: err,
318 }
319}
320
321func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
322 if !a.Model().SupportsImages && attachments != nil {
323 attachments = nil
324 }
325 events := make(chan AgentEvent)
326 if a.IsSessionBusy(sessionID) {
327 return nil, ErrSessionBusy
328 }
329
330 genCtx, cancel := context.WithCancel(ctx)
331
332 a.activeRequests.Set(sessionID, cancel)
333 go func() {
334 slog.Debug("Request started", "sessionID", sessionID)
335 defer log.RecoverPanic("agent.Run", func() {
336 events <- a.err(fmt.Errorf("panic while running the agent"))
337 })
338 var attachmentParts []message.ContentPart
339 for _, attachment := range attachments {
340 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
341 }
342 result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
343 if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
344 slog.Error(result.Error.Error())
345 }
346 slog.Debug("Request completed", "sessionID", sessionID)
347 a.activeRequests.Del(sessionID)
348 cancel()
349 a.Publish(pubsub.CreatedEvent, result)
350 events <- result
351 close(events)
352 }()
353 return events, nil
354}
355
356func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
357 cfg := config.Get()
358 // List existing messages; if none, start title generation asynchronously.
359 msgs, err := a.messages.List(ctx, sessionID)
360 if err != nil {
361 return a.err(fmt.Errorf("failed to list messages: %w", err))
362 }
363 if len(msgs) == 0 {
364 go func() {
365 defer log.RecoverPanic("agent.Run", func() {
366 slog.Error("panic while generating title")
367 })
368 titleErr := a.generateTitle(context.Background(), sessionID, content)
369 if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
370 slog.Error("failed to generate title", "error", titleErr)
371 }
372 }()
373 }
374 session, err := a.sessions.Get(ctx, sessionID)
375 if err != nil {
376 return a.err(fmt.Errorf("failed to get session: %w", err))
377 }
378 if session.SummaryMessageID != "" {
379 summaryMsgInex := -1
380 for i, msg := range msgs {
381 if msg.ID == session.SummaryMessageID {
382 summaryMsgInex = i
383 break
384 }
385 }
386 if summaryMsgInex != -1 {
387 msgs = msgs[summaryMsgInex:]
388 msgs[0].Role = message.User
389 }
390 }
391
392 userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
393 if err != nil {
394 return a.err(fmt.Errorf("failed to create user message: %w", err))
395 }
396 // Append the new user message to the conversation history.
397 msgHistory := append(msgs, userMsg)
398
399 for {
400 // Check for cancellation before each iteration
401 select {
402 case <-ctx.Done():
403 return a.err(ctx.Err())
404 default:
405 // Continue processing
406 }
407 agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
408 if err != nil {
409 if errors.Is(err, context.Canceled) {
410 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
411 a.messages.Update(context.Background(), agentMessage)
412 return a.err(ErrRequestCancelled)
413 }
414 return a.err(fmt.Errorf("failed to process events: %w", err))
415 }
416 if cfg.Options.Debug {
417 slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
418 }
419 if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
420 // We are not done, we need to respond with the tool response
421 msgHistory = append(msgHistory, agentMessage, *toolResults)
422 continue
423 }
424 if agentMessage.FinishReason() == "" {
425 // Kujtim: could not track down where this is happening but this means its cancelled
426 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
427 _ = a.messages.Update(context.Background(), agentMessage)
428 return a.err(ErrRequestCancelled)
429 }
430 return AgentEvent{
431 Type: AgentEventTypeResponse,
432 Message: agentMessage,
433 Done: true,
434 }
435 }
436}
437
438func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
439 parts := []message.ContentPart{message.TextContent{Text: content}}
440 parts = append(parts, attachmentParts...)
441 return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
442 Role: message.User,
443 Parts: parts,
444 })
445}
446
447func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
448 ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
449 eventChan := a.provider.StreamResponse(ctx, msgHistory, slices.Collect(a.tools.Seq()))
450
451 assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
452 Role: message.Assistant,
453 Parts: []message.ContentPart{},
454 Model: a.Model().ID,
455 Provider: a.providerID,
456 })
457 if err != nil {
458 return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
459 }
460
461 // Add the session and message ID into the context if needed by tools.
462 ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
463
464 // Process each event in the stream.
465 for event := range eventChan {
466 if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
467 if errors.Is(processErr, context.Canceled) {
468 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
469 } else {
470 a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
471 }
472 return assistantMsg, nil, processErr
473 }
474 if ctx.Err() != nil {
475 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
476 return assistantMsg, nil, ctx.Err()
477 }
478 }
479
480 toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
481 toolCalls := assistantMsg.ToolCalls()
482 for i, toolCall := range toolCalls {
483 select {
484 case <-ctx.Done():
485 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
486 // Make all future tool calls cancelled
487 for j := i; j < len(toolCalls); j++ {
488 toolResults[j] = message.ToolResult{
489 ToolCallID: toolCalls[j].ID,
490 Content: "Tool execution canceled by user",
491 IsError: true,
492 }
493 }
494 goto out
495 default:
496 // Continue processing
497 var tool tools.BaseTool
498 for availableTool := range a.tools.Seq() {
499 if availableTool.Info().Name == toolCall.Name {
500 tool = availableTool
501 break
502 }
503 }
504
505 // Tool not found
506 if tool == nil {
507 toolResults[i] = message.ToolResult{
508 ToolCallID: toolCall.ID,
509 Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
510 IsError: true,
511 }
512 continue
513 }
514
515 // Run tool in goroutine to allow cancellation
516 type toolExecResult struct {
517 response tools.ToolResponse
518 err error
519 }
520 resultChan := make(chan toolExecResult, 1)
521
522 go func() {
523 response, err := tool.Run(ctx, tools.ToolCall{
524 ID: toolCall.ID,
525 Name: toolCall.Name,
526 Input: toolCall.Input,
527 })
528 resultChan <- toolExecResult{response: response, err: err}
529 }()
530
531 var toolResponse tools.ToolResponse
532 var toolErr error
533
534 select {
535 case <-ctx.Done():
536 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
537 // Mark remaining tool calls as cancelled
538 for j := i; j < len(toolCalls); j++ {
539 toolResults[j] = message.ToolResult{
540 ToolCallID: toolCalls[j].ID,
541 Content: "Tool execution canceled by user",
542 IsError: true,
543 }
544 }
545 goto out
546 case result := <-resultChan:
547 toolResponse = result.response
548 toolErr = result.err
549 }
550
551 if toolErr != nil {
552 slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
553 if errors.Is(toolErr, permission.ErrorPermissionDenied) {
554 toolResults[i] = message.ToolResult{
555 ToolCallID: toolCall.ID,
556 Content: "Permission denied",
557 IsError: true,
558 }
559 for j := i + 1; j < len(toolCalls); j++ {
560 toolResults[j] = message.ToolResult{
561 ToolCallID: toolCalls[j].ID,
562 Content: "Tool execution canceled by user",
563 IsError: true,
564 }
565 }
566 a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
567 break
568 }
569 }
570 toolResults[i] = message.ToolResult{
571 ToolCallID: toolCall.ID,
572 Content: toolResponse.Content,
573 Metadata: toolResponse.Metadata,
574 IsError: toolResponse.IsError,
575 }
576 }
577 }
578out:
579 if len(toolResults) == 0 {
580 return assistantMsg, nil, nil
581 }
582 parts := make([]message.ContentPart, 0)
583 for _, tr := range toolResults {
584 parts = append(parts, tr)
585 }
586 msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
587 Role: message.Tool,
588 Parts: parts,
589 Provider: a.providerID,
590 })
591 if err != nil {
592 return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
593 }
594
595 return assistantMsg, &msg, err
596}
597
598func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
599 msg.AddFinish(finishReason, message, details)
600 _ = a.messages.Update(ctx, *msg)
601}
602
603func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
604 select {
605 case <-ctx.Done():
606 return ctx.Err()
607 default:
608 // Continue processing.
609 }
610
611 switch event.Type {
612 case provider.EventThinkingDelta:
613 assistantMsg.AppendReasoningContent(event.Thinking)
614 return a.messages.Update(ctx, *assistantMsg)
615 case provider.EventSignatureDelta:
616 assistantMsg.AppendReasoningSignature(event.Signature)
617 return a.messages.Update(ctx, *assistantMsg)
618 case provider.EventContentDelta:
619 assistantMsg.FinishThinking()
620 assistantMsg.AppendContent(event.Content)
621 return a.messages.Update(ctx, *assistantMsg)
622 case provider.EventToolUseStart:
623 assistantMsg.FinishThinking()
624 slog.Info("Tool call started", "toolCall", event.ToolCall)
625 assistantMsg.AddToolCall(*event.ToolCall)
626 return a.messages.Update(ctx, *assistantMsg)
627 case provider.EventToolUseDelta:
628 assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
629 return a.messages.Update(ctx, *assistantMsg)
630 case provider.EventToolUseStop:
631 slog.Info("Finished tool call", "toolCall", event.ToolCall)
632 assistantMsg.FinishToolCall(event.ToolCall.ID)
633 return a.messages.Update(ctx, *assistantMsg)
634 case provider.EventError:
635 return event.Error
636 case provider.EventComplete:
637 assistantMsg.FinishThinking()
638 assistantMsg.SetToolCalls(event.Response.ToolCalls)
639 assistantMsg.AddFinish(event.Response.FinishReason, "", "")
640 if err := a.messages.Update(ctx, *assistantMsg); err != nil {
641 return fmt.Errorf("failed to update message: %w", err)
642 }
643 return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
644 }
645
646 return nil
647}
648
649func (a *agent) TrackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
650 sess, err := a.sessions.Get(ctx, sessionID)
651 if err != nil {
652 return fmt.Errorf("failed to get session: %w", err)
653 }
654
655 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
656 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
657 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
658 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
659
660 sess.Cost += cost
661 sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
662 sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
663
664 _, err = a.sessions.Save(ctx, sess)
665 if err != nil {
666 return fmt.Errorf("failed to save session: %w", err)
667 }
668 return nil
669}
670
671func (a *agent) Summarize(ctx context.Context, sessionID string) error {
672 if a.summarizeProvider == nil {
673 return fmt.Errorf("summarize provider not available")
674 }
675
676 // Check if session is busy
677 if a.IsSessionBusy(sessionID) {
678 return ErrSessionBusy
679 }
680
681 // Create a new context with cancellation
682 summarizeCtx, cancel := context.WithCancel(ctx)
683
684 // Store the cancel function in activeRequests to allow cancellation
685 a.activeRequests.Set(sessionID+"-summarize", cancel)
686
687 go func() {
688 defer a.activeRequests.Del(sessionID + "-summarize")
689 defer cancel()
690 event := AgentEvent{
691 Type: AgentEventTypeSummarize,
692 Progress: "Starting summarization...",
693 }
694
695 a.Publish(pubsub.CreatedEvent, event)
696 // Get all messages from the session
697 msgs, err := a.messages.List(summarizeCtx, sessionID)
698 if err != nil {
699 event = AgentEvent{
700 Type: AgentEventTypeError,
701 Error: fmt.Errorf("failed to list messages: %w", err),
702 Done: true,
703 }
704 a.Publish(pubsub.CreatedEvent, event)
705 return
706 }
707 summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
708
709 if len(msgs) == 0 {
710 event = AgentEvent{
711 Type: AgentEventTypeError,
712 Error: fmt.Errorf("no messages to summarize"),
713 Done: true,
714 }
715 a.Publish(pubsub.CreatedEvent, event)
716 return
717 }
718
719 event = AgentEvent{
720 Type: AgentEventTypeSummarize,
721 Progress: "Analyzing conversation...",
722 }
723 a.Publish(pubsub.CreatedEvent, event)
724
725 // Add a system message to guide the summarization
726 summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
727
728 // Create a new message with the summarize prompt
729 promptMsg := message.Message{
730 Role: message.User,
731 Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
732 }
733
734 // Append the prompt to the messages
735 msgsWithPrompt := append(msgs, promptMsg)
736
737 event = AgentEvent{
738 Type: AgentEventTypeSummarize,
739 Progress: "Generating summary...",
740 }
741
742 a.Publish(pubsub.CreatedEvent, event)
743
744 // Send the messages to the summarize provider
745 response := a.summarizeProvider.StreamResponse(
746 summarizeCtx,
747 msgsWithPrompt,
748 nil,
749 )
750 var finalResponse *provider.ProviderResponse
751 for r := range response {
752 if r.Error != nil {
753 event = AgentEvent{
754 Type: AgentEventTypeError,
755 Error: fmt.Errorf("failed to summarize: %w", err),
756 Done: true,
757 }
758 a.Publish(pubsub.CreatedEvent, event)
759 return
760 }
761 finalResponse = r.Response
762 }
763
764 summary := strings.TrimSpace(finalResponse.Content)
765 if summary == "" {
766 event = AgentEvent{
767 Type: AgentEventTypeError,
768 Error: fmt.Errorf("empty summary returned"),
769 Done: true,
770 }
771 a.Publish(pubsub.CreatedEvent, event)
772 return
773 }
774 shell := shell.GetPersistentShell(config.Get().WorkingDir())
775 summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
776 event = AgentEvent{
777 Type: AgentEventTypeSummarize,
778 Progress: "Creating new session...",
779 }
780
781 a.Publish(pubsub.CreatedEvent, event)
782 oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
783 if err != nil {
784 event = AgentEvent{
785 Type: AgentEventTypeError,
786 Error: fmt.Errorf("failed to get session: %w", err),
787 Done: true,
788 }
789
790 a.Publish(pubsub.CreatedEvent, event)
791 return
792 }
793 // Create a message in the new session with the summary
794 msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
795 Role: message.Assistant,
796 Parts: []message.ContentPart{
797 message.TextContent{Text: summary},
798 message.Finish{
799 Reason: message.FinishReasonEndTurn,
800 Time: time.Now().Unix(),
801 },
802 },
803 Model: a.summarizeProvider.Model().ID,
804 Provider: a.summarizeProviderID,
805 })
806 if err != nil {
807 event = AgentEvent{
808 Type: AgentEventTypeError,
809 Error: fmt.Errorf("failed to create summary message: %w", err),
810 Done: true,
811 }
812
813 a.Publish(pubsub.CreatedEvent, event)
814 return
815 }
816 oldSession.SummaryMessageID = msg.ID
817 oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
818 oldSession.PromptTokens = 0
819 model := a.summarizeProvider.Model()
820 usage := finalResponse.Usage
821 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
822 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
823 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
824 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
825 oldSession.Cost += cost
826 _, err = a.sessions.Save(summarizeCtx, oldSession)
827 if err != nil {
828 event = AgentEvent{
829 Type: AgentEventTypeError,
830 Error: fmt.Errorf("failed to save session: %w", err),
831 Done: true,
832 }
833 a.Publish(pubsub.CreatedEvent, event)
834 }
835
836 event = AgentEvent{
837 Type: AgentEventTypeSummarize,
838 SessionID: oldSession.ID,
839 Progress: "Summary complete",
840 Done: true,
841 }
842 a.Publish(pubsub.CreatedEvent, event)
843 // Send final success event with the new session ID
844 }()
845
846 return nil
847}
848
849func (a *agent) CancelAll() {
850 if !a.IsBusy() {
851 return
852 }
853 for key := range a.activeRequests.Seq2() {
854 a.Cancel(key) // key is sessionID
855 }
856
857 timeout := time.After(5 * time.Second)
858 for a.IsBusy() {
859 select {
860 case <-timeout:
861 return
862 default:
863 time.Sleep(200 * time.Millisecond)
864 }
865 }
866}
867
868func (a *agent) UpdateModel() error {
869 cfg := config.Get()
870
871 // Get current provider configuration
872 currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
873 if currentProviderCfg == nil || currentProviderCfg.ID == "" {
874 return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
875 }
876
877 // Check if provider has changed
878 if string(currentProviderCfg.ID) != a.providerID {
879 // Provider changed, need to recreate the main provider
880 model := cfg.GetModelByType(a.agentCfg.Model)
881 if model.ID == "" {
882 return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
883 }
884
885 promptID := agentPromptMap[a.agentCfg.ID]
886 if promptID == "" {
887 promptID = prompt.PromptDefault
888 }
889
890 opts := []provider.ProviderClientOption{
891 provider.WithModel(a.agentCfg.Model),
892 provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
893 }
894
895 newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
896 if err != nil {
897 return fmt.Errorf("failed to create new provider: %w", err)
898 }
899
900 // Update the provider and provider ID
901 a.provider = newProvider
902 a.providerID = string(currentProviderCfg.ID)
903 }
904
905 // Check if small model provider has changed (affects title and summarize providers)
906 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
907 var smallModelProviderCfg config.ProviderConfig
908
909 for p := range cfg.Providers.Seq() {
910 if p.ID == smallModelCfg.Provider {
911 smallModelProviderCfg = p
912 break
913 }
914 }
915
916 if smallModelProviderCfg.ID == "" {
917 return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
918 }
919
920 // Check if summarize provider has changed
921 if string(smallModelProviderCfg.ID) != a.summarizeProviderID {
922 smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
923 if smallModel == nil {
924 return fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
925 }
926
927 // Recreate title provider
928 titleOpts := []provider.ProviderClientOption{
929 provider.WithModel(config.SelectedModelTypeSmall),
930 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
931 // We want the title to be short, so we limit the max tokens
932 provider.WithMaxTokens(40),
933 }
934 newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
935 if err != nil {
936 return fmt.Errorf("failed to create new title provider: %w", err)
937 }
938
939 // Recreate summarize provider
940 summarizeOpts := []provider.ProviderClientOption{
941 provider.WithModel(config.SelectedModelTypeSmall),
942 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
943 }
944 newSummarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
945 if err != nil {
946 return fmt.Errorf("failed to create new summarize provider: %w", err)
947 }
948
949 // Update the providers and provider ID
950 a.titleProvider = newTitleProvider
951 a.summarizeProvider = newSummarizeProvider
952 a.summarizeProviderID = string(smallModelProviderCfg.ID)
953 }
954
955 return nil
956}