1package agent
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "log/slog"
8 "slices"
9 "strings"
10 "time"
11
12 "github.com/charmbracelet/catwalk/pkg/catwalk"
13 "github.com/charmbracelet/crush/internal/config"
14 "github.com/charmbracelet/crush/internal/csync"
15 "github.com/charmbracelet/crush/internal/history"
16 "github.com/charmbracelet/crush/internal/llm/prompt"
17 "github.com/charmbracelet/crush/internal/llm/provider"
18 "github.com/charmbracelet/crush/internal/llm/tools"
19 "github.com/charmbracelet/crush/internal/log"
20 "github.com/charmbracelet/crush/internal/lsp"
21 "github.com/charmbracelet/crush/internal/message"
22 "github.com/charmbracelet/crush/internal/permission"
23 "github.com/charmbracelet/crush/internal/pubsub"
24 "github.com/charmbracelet/crush/internal/session"
25 "github.com/charmbracelet/crush/internal/shell"
26)
27
// Common errors returned by the agent.
var (
	// ErrRequestCancelled is reported when an in-flight request is canceled.
	ErrRequestCancelled = errors.New("request canceled by user")

	// ErrSessionBusy is returned when a session already has an active request.
	ErrSessionBusy = errors.New("session is currently processing another request")
)
33
// AgentEventType identifies the kind of AgentEvent being delivered.
type AgentEventType string

// Known agent event kinds.
const (
	// AgentEventTypeError events carry a failure in AgentEvent.Error.
	AgentEventTypeError AgentEventType = "error"
	// AgentEventTypeResponse events carry the finished assistant message.
	AgentEventTypeResponse AgentEventType = "response"
	// AgentEventTypeSummarize events report summarization progress.
	AgentEventTypeSummarize AgentEventType = "summarize"
)
41
42type AgentEvent struct {
43 Type AgentEventType
44 Message message.Message
45 Error error
46
47 // When summarizing
48 SessionID string
49 Progress string
50 Done bool
51}
52
53type Service interface {
54 pubsub.Suscriber[AgentEvent]
55 Model() catwalk.Model
56 Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
57 Cancel(sessionID string)
58 CancelAll()
59 IsSessionBusy(sessionID string) bool
60 IsBusy() bool
61 Summarize(ctx context.Context, sessionID string) error
62 UpdateModel() error
63}
64
// agent is the Service implementation constructed by NewAgent.
type agent struct {
	// Embedded broker through which AgentEvents are published.
	*pubsub.Broker[AgentEvent]
	// agentCfg is the configuration this agent was built from.
	agentCfg config.Agent
	sessions session.Service
	messages message.Service
	// NOTE(review): this field is never assigned in this file; the tool
	// builder in NewAgent writes a package-level mcpTools instead. Confirm
	// whether the field is still needed.
	mcpTools []McpTool

	// tools wraps NewAgent's tool builder in a csync.LazySlice (presumably
	// built on first use — confirm against csync).
	tools *csync.LazySlice[tools.BaseTool]

	// provider serves the agent's own model; providerID names its provider config.
	provider   provider.Provider
	providerID string

	// titleProvider generates session titles; summarizeProvider (identified
	// by summarizeProviderID) produces conversation summaries.
	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests maps a session ID (or "<sessionID>-summarize") to the
	// cancel function of the work running for it.
	activeRequests *csync.Map[string, context.CancelFunc]
}
83
84var agentPromptMap = map[string]prompt.PromptID{
85 "coder": prompt.PromptCoder,
86 "task": prompt.PromptTask,
87}
88
89func NewAgent(
90 ctx context.Context,
91 agentCfg config.Agent,
92 // These services are needed in the tools
93 permissions permission.Service,
94 sessions session.Service,
95 messages message.Service,
96 history history.Service,
97 lspClients map[string]*lsp.Client,
98) (Service, error) {
99 cfg := config.Get()
100
101 var agentTool tools.BaseTool
102 if agentCfg.ID == "coder" {
103 taskAgentCfg := config.Get().Agents["task"]
104 if taskAgentCfg.ID == "" {
105 return nil, fmt.Errorf("task agent not found in config")
106 }
107 taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients)
108 if err != nil {
109 return nil, fmt.Errorf("failed to create task agent: %w", err)
110 }
111
112 agentTool = NewAgentTool(taskAgent, sessions, messages)
113 }
114
115 providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
116 if providerCfg == nil {
117 return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
118 }
119 model := config.Get().GetModelByType(agentCfg.Model)
120
121 if model == nil {
122 return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
123 }
124
125 promptID := agentPromptMap[agentCfg.ID]
126 if promptID == "" {
127 promptID = prompt.PromptDefault
128 }
129 opts := []provider.ProviderClientOption{
130 provider.WithModel(agentCfg.Model),
131 provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
132 }
133 agentProvider, err := provider.NewProvider(*providerCfg, opts...)
134 if err != nil {
135 return nil, err
136 }
137
138 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
139 var smallModelProviderCfg *config.ProviderConfig
140 if smallModelCfg.Provider == providerCfg.ID {
141 smallModelProviderCfg = providerCfg
142 } else {
143 smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
144
145 if smallModelProviderCfg.ID == "" {
146 return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
147 }
148 }
149 smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
150 if smallModel.ID == "" {
151 return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
152 }
153
154 titleOpts := []provider.ProviderClientOption{
155 provider.WithModel(config.SelectedModelTypeSmall),
156 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
157 }
158 titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
159 if err != nil {
160 return nil, err
161 }
162
163 summarizeOpts := []provider.ProviderClientOption{
164 provider.WithModel(config.SelectedModelTypeLarge),
165 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, providerCfg.ID)),
166 }
167 summarizeProvider, err := provider.NewProvider(*providerCfg, summarizeOpts...)
168 if err != nil {
169 return nil, err
170 }
171
172 toolFn := func() []tools.BaseTool {
173 slog.Info("Initializing agent tools", "agent", agentCfg.ID)
174 defer func() {
175 slog.Info("Initialized agent tools", "agent", agentCfg.ID)
176 }()
177
178 cwd := cfg.WorkingDir()
179 allTools := []tools.BaseTool{
180 tools.NewBashTool(permissions, cwd),
181 tools.NewDownloadTool(permissions, cwd),
182 tools.NewEditTool(lspClients, permissions, history, cwd),
183 tools.NewMultiEditTool(lspClients, permissions, history, cwd),
184 tools.NewFetchTool(permissions, cwd),
185 tools.NewGlobTool(cwd),
186 tools.NewGrepTool(cwd),
187 tools.NewLsTool(permissions, cwd),
188 tools.NewSourcegraphTool(),
189 tools.NewViewTool(lspClients, permissions, cwd),
190 tools.NewWriteTool(lspClients, permissions, history, cwd),
191 }
192
193 mcpToolsOnce.Do(func() {
194 mcpTools = doGetMCPTools(ctx, permissions, cfg)
195 })
196 allTools = append(allTools, mcpTools...)
197
198 if len(lspClients) > 0 {
199 allTools = append(allTools, tools.NewDiagnosticsTool(lspClients))
200 }
201
202 if agentTool != nil {
203 allTools = append(allTools, agentTool)
204 }
205
206 if agentCfg.AllowedTools == nil {
207 return allTools
208 }
209
210 var filteredTools []tools.BaseTool
211 for _, tool := range allTools {
212 if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
213 filteredTools = append(filteredTools, tool)
214 }
215 }
216 return filteredTools
217 }
218
219 return &agent{
220 Broker: pubsub.NewBroker[AgentEvent](),
221 agentCfg: agentCfg,
222 provider: agentProvider,
223 providerID: string(providerCfg.ID),
224 messages: messages,
225 sessions: sessions,
226 titleProvider: titleProvider,
227 summarizeProvider: summarizeProvider,
228 summarizeProviderID: string(providerCfg.ID),
229 activeRequests: csync.NewMap[string, context.CancelFunc](),
230 tools: csync.NewLazySlice(toolFn),
231 }, nil
232}
233
234func (a *agent) Model() catwalk.Model {
235 return *config.Get().GetModelByType(a.agentCfg.Model)
236}
237
238func (a *agent) Cancel(sessionID string) {
239 // Cancel regular requests
240 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
241 slog.Info("Request cancellation initiated", "session_id", sessionID)
242 cancel()
243 }
244
245 // Also check for summarize requests
246 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
247 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
248 cancel()
249 }
250}
251
252func (a *agent) IsBusy() bool {
253 var busy bool
254 for cancelFunc := range a.activeRequests.Seq() {
255 if cancelFunc != nil {
256 busy = true
257 break
258 }
259 }
260 return busy
261}
262
263func (a *agent) IsSessionBusy(sessionID string) bool {
264 _, busy := a.activeRequests.Get(sessionID)
265 return busy
266}
267
268func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
269 if content == "" {
270 return nil
271 }
272 if a.titleProvider == nil {
273 return nil
274 }
275 session, err := a.sessions.Get(ctx, sessionID)
276 if err != nil {
277 return err
278 }
279 parts := []message.ContentPart{message.TextContent{
280 Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
281 }}
282
283 // Use streaming approach like summarization
284 response := a.titleProvider.StreamResponse(
285 ctx,
286 []message.Message{
287 {
288 Role: message.User,
289 Parts: parts,
290 },
291 },
292 nil,
293 )
294
295 var finalResponse *provider.ProviderResponse
296 for r := range response {
297 if r.Error != nil {
298 return r.Error
299 }
300 finalResponse = r.Response
301 }
302
303 if finalResponse == nil {
304 return fmt.Errorf("no response received from title provider")
305 }
306
307 title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
308 if title == "" {
309 return nil
310 }
311
312 session.Title = title
313 _, err = a.sessions.Save(ctx, session)
314 return err
315}
316
317func (a *agent) err(err error) AgentEvent {
318 return AgentEvent{
319 Type: AgentEventTypeError,
320 Error: err,
321 }
322}
323
324func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
325 if !a.Model().SupportsImages && attachments != nil {
326 attachments = nil
327 }
328 events := make(chan AgentEvent)
329 if a.IsSessionBusy(sessionID) {
330 return nil, ErrSessionBusy
331 }
332
333 genCtx, cancel := context.WithCancel(ctx)
334
335 a.activeRequests.Set(sessionID, cancel)
336 go func() {
337 slog.Debug("Request started", "sessionID", sessionID)
338 defer log.RecoverPanic("agent.Run", func() {
339 events <- a.err(fmt.Errorf("panic while running the agent"))
340 })
341 var attachmentParts []message.ContentPart
342 for _, attachment := range attachments {
343 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
344 }
345 result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
346 if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
347 slog.Error(result.Error.Error())
348 }
349 slog.Debug("Request completed", "sessionID", sessionID)
350 a.activeRequests.Del(sessionID)
351 cancel()
352 a.Publish(pubsub.CreatedEvent, result)
353 events <- result
354 close(events)
355 }()
356 return events, nil
357}
358
359func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
360 cfg := config.Get()
361 // List existing messages; if none, start title generation asynchronously.
362 msgs, err := a.messages.List(ctx, sessionID)
363 if err != nil {
364 return a.err(fmt.Errorf("failed to list messages: %w", err))
365 }
366 if len(msgs) == 0 {
367 go func() {
368 defer log.RecoverPanic("agent.Run", func() {
369 slog.Error("panic while generating title")
370 })
371 titleErr := a.generateTitle(context.Background(), sessionID, content)
372 if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
373 slog.Error("failed to generate title", "error", titleErr)
374 }
375 }()
376 }
377 session, err := a.sessions.Get(ctx, sessionID)
378 if err != nil {
379 return a.err(fmt.Errorf("failed to get session: %w", err))
380 }
381 if session.SummaryMessageID != "" {
382 summaryMsgInex := -1
383 for i, msg := range msgs {
384 if msg.ID == session.SummaryMessageID {
385 summaryMsgInex = i
386 break
387 }
388 }
389 if summaryMsgInex != -1 {
390 msgs = msgs[summaryMsgInex:]
391 msgs[0].Role = message.User
392 }
393 }
394
395 userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
396 if err != nil {
397 return a.err(fmt.Errorf("failed to create user message: %w", err))
398 }
399 // Append the new user message to the conversation history.
400 msgHistory := append(msgs, userMsg)
401
402 for {
403 // Check for cancellation before each iteration
404 select {
405 case <-ctx.Done():
406 return a.err(ctx.Err())
407 default:
408 // Continue processing
409 }
410 agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
411 if err != nil {
412 if errors.Is(err, context.Canceled) {
413 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
414 a.messages.Update(context.Background(), agentMessage)
415 return a.err(ErrRequestCancelled)
416 }
417 return a.err(fmt.Errorf("failed to process events: %w", err))
418 }
419 if cfg.Options.Debug {
420 slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
421 }
422 if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
423 // We are not done, we need to respond with the tool response
424 msgHistory = append(msgHistory, agentMessage, *toolResults)
425 continue
426 }
427 if agentMessage.FinishReason() == "" {
428 // Kujtim: could not track down where this is happening but this means its cancelled
429 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
430 _ = a.messages.Update(context.Background(), agentMessage)
431 return a.err(ErrRequestCancelled)
432 }
433 return AgentEvent{
434 Type: AgentEventTypeResponse,
435 Message: agentMessage,
436 Done: true,
437 }
438 }
439}
440
441func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
442 parts := []message.ContentPart{message.TextContent{Text: content}}
443 parts = append(parts, attachmentParts...)
444 return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
445 Role: message.User,
446 Parts: parts,
447 })
448}
449
450func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
451 ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
452 eventChan := a.provider.StreamResponse(ctx, msgHistory, slices.Collect(a.tools.Seq()))
453
454 assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
455 Role: message.Assistant,
456 Parts: []message.ContentPart{},
457 Model: a.Model().ID,
458 Provider: a.providerID,
459 })
460 if err != nil {
461 return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
462 }
463
464 // Add the session and message ID into the context if needed by tools.
465 ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
466
467 // Process each event in the stream.
468 for event := range eventChan {
469 if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
470 if errors.Is(processErr, context.Canceled) {
471 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
472 } else {
473 a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
474 }
475 return assistantMsg, nil, processErr
476 }
477 if ctx.Err() != nil {
478 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
479 return assistantMsg, nil, ctx.Err()
480 }
481 }
482
483 toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
484 toolCalls := assistantMsg.ToolCalls()
485 for i, toolCall := range toolCalls {
486 select {
487 case <-ctx.Done():
488 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
489 // Make all future tool calls cancelled
490 for j := i; j < len(toolCalls); j++ {
491 toolResults[j] = message.ToolResult{
492 ToolCallID: toolCalls[j].ID,
493 Content: "Tool execution canceled by user",
494 IsError: true,
495 }
496 }
497 goto out
498 default:
499 // Continue processing
500 var tool tools.BaseTool
501 for availableTool := range a.tools.Seq() {
502 if availableTool.Info().Name == toolCall.Name {
503 tool = availableTool
504 break
505 }
506 }
507
508 // Tool not found
509 if tool == nil {
510 toolResults[i] = message.ToolResult{
511 ToolCallID: toolCall.ID,
512 Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
513 IsError: true,
514 }
515 continue
516 }
517
518 // Run tool in goroutine to allow cancellation
519 type toolExecResult struct {
520 response tools.ToolResponse
521 err error
522 }
523 resultChan := make(chan toolExecResult, 1)
524
525 go func() {
526 response, err := tool.Run(ctx, tools.ToolCall{
527 ID: toolCall.ID,
528 Name: toolCall.Name,
529 Input: toolCall.Input,
530 })
531 resultChan <- toolExecResult{response: response, err: err}
532 }()
533
534 var toolResponse tools.ToolResponse
535 var toolErr error
536
537 select {
538 case <-ctx.Done():
539 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
540 // Mark remaining tool calls as cancelled
541 for j := i; j < len(toolCalls); j++ {
542 toolResults[j] = message.ToolResult{
543 ToolCallID: toolCalls[j].ID,
544 Content: "Tool execution canceled by user",
545 IsError: true,
546 }
547 }
548 goto out
549 case result := <-resultChan:
550 toolResponse = result.response
551 toolErr = result.err
552 }
553
554 if toolErr != nil {
555 slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
556 if errors.Is(toolErr, permission.ErrorPermissionDenied) {
557 toolResults[i] = message.ToolResult{
558 ToolCallID: toolCall.ID,
559 Content: "Permission denied",
560 IsError: true,
561 }
562 for j := i + 1; j < len(toolCalls); j++ {
563 toolResults[j] = message.ToolResult{
564 ToolCallID: toolCalls[j].ID,
565 Content: "Tool execution canceled by user",
566 IsError: true,
567 }
568 }
569 a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
570 break
571 }
572 }
573 toolResults[i] = message.ToolResult{
574 ToolCallID: toolCall.ID,
575 Content: toolResponse.Content,
576 Metadata: toolResponse.Metadata,
577 IsError: toolResponse.IsError,
578 }
579 }
580 }
581out:
582 if len(toolResults) == 0 {
583 return assistantMsg, nil, nil
584 }
585 parts := make([]message.ContentPart, 0)
586 for _, tr := range toolResults {
587 parts = append(parts, tr)
588 }
589 msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
590 Role: message.Tool,
591 Parts: parts,
592 Provider: a.providerID,
593 })
594 if err != nil {
595 return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
596 }
597
598 return assistantMsg, &msg, err
599}
600
601func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
602 msg.AddFinish(finishReason, message, details)
603 _ = a.messages.Update(ctx, *msg)
604}
605
606func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
607 select {
608 case <-ctx.Done():
609 return ctx.Err()
610 default:
611 // Continue processing.
612 }
613
614 switch event.Type {
615 case provider.EventThinkingDelta:
616 assistantMsg.AppendReasoningContent(event.Thinking)
617 return a.messages.Update(ctx, *assistantMsg)
618 case provider.EventSignatureDelta:
619 assistantMsg.AppendReasoningSignature(event.Signature)
620 return a.messages.Update(ctx, *assistantMsg)
621 case provider.EventContentDelta:
622 assistantMsg.FinishThinking()
623 assistantMsg.AppendContent(event.Content)
624 return a.messages.Update(ctx, *assistantMsg)
625 case provider.EventToolUseStart:
626 assistantMsg.FinishThinking()
627 slog.Info("Tool call started", "toolCall", event.ToolCall)
628 assistantMsg.AddToolCall(*event.ToolCall)
629 return a.messages.Update(ctx, *assistantMsg)
630 case provider.EventToolUseDelta:
631 assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
632 return a.messages.Update(ctx, *assistantMsg)
633 case provider.EventToolUseStop:
634 slog.Info("Finished tool call", "toolCall", event.ToolCall)
635 assistantMsg.FinishToolCall(event.ToolCall.ID)
636 return a.messages.Update(ctx, *assistantMsg)
637 case provider.EventError:
638 return event.Error
639 case provider.EventComplete:
640 assistantMsg.FinishThinking()
641 assistantMsg.SetToolCalls(event.Response.ToolCalls)
642 assistantMsg.AddFinish(event.Response.FinishReason, "", "")
643 if err := a.messages.Update(ctx, *assistantMsg); err != nil {
644 return fmt.Errorf("failed to update message: %w", err)
645 }
646 return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
647 }
648
649 return nil
650}
651
652func (a *agent) TrackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
653 sess, err := a.sessions.Get(ctx, sessionID)
654 if err != nil {
655 return fmt.Errorf("failed to get session: %w", err)
656 }
657
658 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
659 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
660 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
661 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
662
663 sess.Cost += cost
664 sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
665 sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
666
667 _, err = a.sessions.Save(ctx, sess)
668 if err != nil {
669 return fmt.Errorf("failed to save session: %w", err)
670 }
671 return nil
672}
673
674func (a *agent) Summarize(ctx context.Context, sessionID string) error {
675 if a.summarizeProvider == nil {
676 return fmt.Errorf("summarize provider not available")
677 }
678
679 // Check if session is busy
680 if a.IsSessionBusy(sessionID) {
681 return ErrSessionBusy
682 }
683
684 // Create a new context with cancellation
685 summarizeCtx, cancel := context.WithCancel(ctx)
686
687 // Store the cancel function in activeRequests to allow cancellation
688 a.activeRequests.Set(sessionID+"-summarize", cancel)
689
690 go func() {
691 defer a.activeRequests.Del(sessionID + "-summarize")
692 defer cancel()
693 event := AgentEvent{
694 Type: AgentEventTypeSummarize,
695 Progress: "Starting summarization...",
696 }
697
698 a.Publish(pubsub.CreatedEvent, event)
699 // Get all messages from the session
700 msgs, err := a.messages.List(summarizeCtx, sessionID)
701 if err != nil {
702 event = AgentEvent{
703 Type: AgentEventTypeError,
704 Error: fmt.Errorf("failed to list messages: %w", err),
705 Done: true,
706 }
707 a.Publish(pubsub.CreatedEvent, event)
708 return
709 }
710 summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
711
712 if len(msgs) == 0 {
713 event = AgentEvent{
714 Type: AgentEventTypeError,
715 Error: fmt.Errorf("no messages to summarize"),
716 Done: true,
717 }
718 a.Publish(pubsub.CreatedEvent, event)
719 return
720 }
721
722 event = AgentEvent{
723 Type: AgentEventTypeSummarize,
724 Progress: "Analyzing conversation...",
725 }
726 a.Publish(pubsub.CreatedEvent, event)
727
728 // Add a system message to guide the summarization
729 summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
730
731 // Create a new message with the summarize prompt
732 promptMsg := message.Message{
733 Role: message.User,
734 Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
735 }
736
737 // Append the prompt to the messages
738 msgsWithPrompt := append(msgs, promptMsg)
739
740 event = AgentEvent{
741 Type: AgentEventTypeSummarize,
742 Progress: "Generating summary...",
743 }
744
745 a.Publish(pubsub.CreatedEvent, event)
746
747 // Send the messages to the summarize provider
748 response := a.summarizeProvider.StreamResponse(
749 summarizeCtx,
750 msgsWithPrompt,
751 nil,
752 )
753 var finalResponse *provider.ProviderResponse
754 for r := range response {
755 if r.Error != nil {
756 event = AgentEvent{
757 Type: AgentEventTypeError,
758 Error: fmt.Errorf("failed to summarize: %w", err),
759 Done: true,
760 }
761 a.Publish(pubsub.CreatedEvent, event)
762 return
763 }
764 finalResponse = r.Response
765 }
766
767 summary := strings.TrimSpace(finalResponse.Content)
768 if summary == "" {
769 event = AgentEvent{
770 Type: AgentEventTypeError,
771 Error: fmt.Errorf("empty summary returned"),
772 Done: true,
773 }
774 a.Publish(pubsub.CreatedEvent, event)
775 return
776 }
777 shell := shell.GetPersistentShell(config.Get().WorkingDir())
778 summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
779 event = AgentEvent{
780 Type: AgentEventTypeSummarize,
781 Progress: "Creating new session...",
782 }
783
784 a.Publish(pubsub.CreatedEvent, event)
785 oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
786 if err != nil {
787 event = AgentEvent{
788 Type: AgentEventTypeError,
789 Error: fmt.Errorf("failed to get session: %w", err),
790 Done: true,
791 }
792
793 a.Publish(pubsub.CreatedEvent, event)
794 return
795 }
796 // Create a message in the new session with the summary
797 msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
798 Role: message.Assistant,
799 Parts: []message.ContentPart{
800 message.TextContent{Text: summary},
801 message.Finish{
802 Reason: message.FinishReasonEndTurn,
803 Time: time.Now().Unix(),
804 },
805 },
806 Model: a.summarizeProvider.Model().ID,
807 Provider: a.summarizeProviderID,
808 })
809 if err != nil {
810 event = AgentEvent{
811 Type: AgentEventTypeError,
812 Error: fmt.Errorf("failed to create summary message: %w", err),
813 Done: true,
814 }
815
816 a.Publish(pubsub.CreatedEvent, event)
817 return
818 }
819 oldSession.SummaryMessageID = msg.ID
820 oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
821 oldSession.PromptTokens = 0
822 model := a.summarizeProvider.Model()
823 usage := finalResponse.Usage
824 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
825 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
826 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
827 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
828 oldSession.Cost += cost
829 _, err = a.sessions.Save(summarizeCtx, oldSession)
830 if err != nil {
831 event = AgentEvent{
832 Type: AgentEventTypeError,
833 Error: fmt.Errorf("failed to save session: %w", err),
834 Done: true,
835 }
836 a.Publish(pubsub.CreatedEvent, event)
837 }
838
839 event = AgentEvent{
840 Type: AgentEventTypeSummarize,
841 SessionID: oldSession.ID,
842 Progress: "Summary complete",
843 Done: true,
844 }
845 a.Publish(pubsub.CreatedEvent, event)
846 // Send final success event with the new session ID
847 }()
848
849 return nil
850}
851
852func (a *agent) CancelAll() {
853 if !a.IsBusy() {
854 return
855 }
856 for key := range a.activeRequests.Seq2() {
857 a.Cancel(key) // key is sessionID
858 }
859
860 timeout := time.After(5 * time.Second)
861 for a.IsBusy() {
862 select {
863 case <-timeout:
864 return
865 default:
866 time.Sleep(200 * time.Millisecond)
867 }
868 }
869}
870
871func (a *agent) UpdateModel() error {
872 cfg := config.Get()
873
874 // Get current provider configuration
875 currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
876 if currentProviderCfg == nil || currentProviderCfg.ID == "" {
877 return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
878 }
879
880 // Check if provider has changed
881 if string(currentProviderCfg.ID) != a.providerID {
882 // Provider changed, need to recreate the main provider
883 model := cfg.GetModelByType(a.agentCfg.Model)
884 if model.ID == "" {
885 return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
886 }
887
888 promptID := agentPromptMap[a.agentCfg.ID]
889 if promptID == "" {
890 promptID = prompt.PromptDefault
891 }
892
893 opts := []provider.ProviderClientOption{
894 provider.WithModel(a.agentCfg.Model),
895 provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
896 }
897
898 newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
899 if err != nil {
900 return fmt.Errorf("failed to create new provider: %w", err)
901 }
902
903 // Update the provider and provider ID
904 a.provider = newProvider
905 a.providerID = string(currentProviderCfg.ID)
906 }
907
908 // Check if providers have changed for title (small) and summarize (large)
909 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
910 var smallModelProviderCfg config.ProviderConfig
911 for p := range cfg.Providers.Seq() {
912 if p.ID == smallModelCfg.Provider {
913 smallModelProviderCfg = p
914 break
915 }
916 }
917 if smallModelProviderCfg.ID == "" {
918 return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
919 }
920
921 largeModelCfg := cfg.Models[config.SelectedModelTypeLarge]
922 var largeModelProviderCfg config.ProviderConfig
923 for p := range cfg.Providers.Seq() {
924 if p.ID == largeModelCfg.Provider {
925 largeModelProviderCfg = p
926 break
927 }
928 }
929 if largeModelProviderCfg.ID == "" {
930 return fmt.Errorf("provider %s not found in config", largeModelCfg.Provider)
931 }
932
933 // Recreate title provider
934 titleOpts := []provider.ProviderClientOption{
935 provider.WithModel(config.SelectedModelTypeSmall),
936 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
937 provider.WithMaxTokens(40),
938 }
939 newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
940 if err != nil {
941 return fmt.Errorf("failed to create new title provider: %w", err)
942 }
943 a.titleProvider = newTitleProvider
944
945 // Recreate summarize provider if provider changed (now large model)
946 if string(largeModelProviderCfg.ID) != a.summarizeProviderID {
947 largeModel := cfg.GetModelByType(config.SelectedModelTypeLarge)
948 if largeModel == nil {
949 return fmt.Errorf("model %s not found in provider %s", largeModelCfg.Model, largeModelProviderCfg.ID)
950 }
951 summarizeOpts := []provider.ProviderClientOption{
952 provider.WithModel(config.SelectedModelTypeLarge),
953 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, largeModelProviderCfg.ID)),
954 }
955 newSummarizeProvider, err := provider.NewProvider(largeModelProviderCfg, summarizeOpts...)
956 if err != nil {
957 return fmt.Errorf("failed to create new summarize provider: %w", err)
958 }
959 a.summarizeProvider = newSummarizeProvider
960 a.summarizeProviderID = string(largeModelProviderCfg.ID)
961 }
962
963 return nil
964}