1package agent
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "log/slog"
8 "slices"
9 "strings"
10 "time"
11
12 "github.com/charmbracelet/catwalk/pkg/catwalk"
13 "github.com/charmbracelet/crush/internal/config"
14 "github.com/charmbracelet/crush/internal/csync"
15 "github.com/charmbracelet/crush/internal/history"
16 "github.com/charmbracelet/crush/internal/llm/prompt"
17 "github.com/charmbracelet/crush/internal/llm/provider"
18 "github.com/charmbracelet/crush/internal/llm/tools"
19 "github.com/charmbracelet/crush/internal/log"
20 "github.com/charmbracelet/crush/internal/lsp"
21 "github.com/charmbracelet/crush/internal/message"
22 "github.com/charmbracelet/crush/internal/permission"
23 "github.com/charmbracelet/crush/internal/pubsub"
24 "github.com/charmbracelet/crush/internal/session"
25 "github.com/charmbracelet/crush/internal/shell"
26)
27
// ErrRequestCancelled is reported (as an AgentEvent error) when a request is
// aborted because its context was cancelled.
var ErrRequestCancelled = errors.New("request canceled by user")

// ErrSessionBusy is returned by Run and Summarize when the session already has
// a request in flight.
var ErrSessionBusy = errors.New("session is currently processing another request")
33
// AgentEventType tags an AgentEvent with the kind of information it carries.
type AgentEventType string

const (
	// AgentEventTypeError marks an event whose Error field is set.
	AgentEventTypeError AgentEventType = "error"
	// AgentEventTypeResponse marks the final assistant response of a Run.
	AgentEventTypeResponse AgentEventType = "response"
	// AgentEventTypeSummarize marks a summarization progress/completion event.
	AgentEventTypeSummarize AgentEventType = "summarize"
)
41
// AgentEvent is delivered on the channel returned by Run and published on the
// agent's broker (for both Run results and summarization progress).
type AgentEvent struct {
	Type    AgentEventType
	Message message.Message // set on AgentEventTypeResponse events
	Error   error           // set on AgentEventTypeError events

	// When summarizing
	SessionID string // session that was summarized; set on the final summarize event
	Progress  string // human-readable progress text for summarize events
	Done      bool   // true on the last event of a Run response or of a summarization (including its errors)
}
52
53type Service interface {
54 pubsub.Suscriber[AgentEvent]
55 Model() catwalk.Model
56 Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
57 Cancel(sessionID string)
58 CancelAll()
59 IsSessionBusy(sessionID string) bool
60 IsBusy() bool
61 Summarize(ctx context.Context, sessionID string) error
62 UpdateModel() error
63}
64
// agent is the Service implementation in this file. It drives a provider
// through tool-using conversation turns and publishes AgentEvents through the
// embedded broker.
type agent struct {
	*pubsub.Broker[AgentEvent]
	agentCfg config.Agent
	sessions session.Service
	messages message.Service
	// NOTE(review): this field is not read in this file (NewAgent's toolFn uses
	// a package-level mcpTools instead) — confirm whether it is still needed.
	mcpTools []McpTool

	// tools is populated by the toolFn closure built in NewAgent (wrapped in
	// csync.NewLazySlice).
	tools *csync.LazySlice[tools.BaseTool]

	// provider serves the agent's main model; providerID identifies its config.
	provider   provider.Provider
	providerID string

	// titleProvider generates session titles (small model);
	// summarizeProvider produces conversation summaries.
	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests maps a session ID (Run) or "<sessionID>-summarize"
	// (Summarize) to the cancel function of the in-flight request.
	activeRequests *csync.Map[string, context.CancelFunc]
}
83
84var agentPromptMap = map[string]prompt.PromptID{
85 "coder": prompt.PromptCoder,
86 "task": prompt.PromptTask,
87}
88
89func NewAgent(
90 ctx context.Context,
91 agentCfg config.Agent,
92 // These services are needed in the tools
93 permissions permission.Service,
94 sessions session.Service,
95 messages message.Service,
96 history history.Service,
97 lspClients map[string]*lsp.Client,
98) (Service, error) {
99 cfg := config.Get()
100
101 var agentTool tools.BaseTool
102 if agentCfg.ID == "coder" {
103 taskAgentCfg := config.Get().Agents["task"]
104 if taskAgentCfg.ID == "" {
105 return nil, fmt.Errorf("task agent not found in config")
106 }
107 taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients)
108 if err != nil {
109 return nil, fmt.Errorf("failed to create task agent: %w", err)
110 }
111
112 agentTool = NewAgentTool(taskAgent, sessions, messages)
113 }
114
115 providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
116 if providerCfg == nil {
117 return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
118 }
119 model := config.Get().GetModelByType(agentCfg.Model)
120
121 if model == nil {
122 return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
123 }
124
125 promptID := agentPromptMap[agentCfg.ID]
126 if promptID == "" {
127 promptID = prompt.PromptDefault
128 }
129 opts := []provider.ProviderClientOption{
130 provider.WithModel(agentCfg.Model),
131 provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
132 }
133 agentProvider, err := provider.NewProvider(*providerCfg, opts...)
134 if err != nil {
135 return nil, err
136 }
137
138 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
139 var smallModelProviderCfg *config.ProviderConfig
140 if smallModelCfg.Provider == providerCfg.ID {
141 smallModelProviderCfg = providerCfg
142 } else {
143 smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
144
145 if smallModelProviderCfg.ID == "" {
146 return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
147 }
148 }
149 smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
150 if smallModel.ID == "" {
151 return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
152 }
153
154 titleOpts := []provider.ProviderClientOption{
155 provider.WithModel(config.SelectedModelTypeSmall),
156 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
157 }
158 titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
159 if err != nil {
160 return nil, err
161 }
162
163 summarizeOpts := []provider.ProviderClientOption{
164 provider.WithModel(config.SelectedModelTypeLarge),
165 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, providerCfg.ID)),
166 }
167 summarizeProvider, err := provider.NewProvider(*providerCfg, summarizeOpts...)
168 if err != nil {
169 return nil, err
170 }
171
172 toolFn := func() []tools.BaseTool {
173 slog.Info("Initializing agent tools", "agent", agentCfg.ID)
174 defer func() {
175 slog.Info("Initialized agent tools", "agent", agentCfg.ID)
176 }()
177
178 cwd := cfg.WorkingDir()
179 allTools := []tools.BaseTool{
180 tools.NewBashTool(permissions, cwd),
181 tools.NewDownloadTool(permissions, cwd),
182 tools.NewEditTool(lspClients, permissions, history, cwd),
183 tools.NewMultiEditTool(lspClients, permissions, history, cwd),
184 tools.NewFetchTool(permissions, cwd),
185 tools.NewGlobTool(cwd),
186 tools.NewGrepTool(cwd),
187 tools.NewLsTool(permissions, cwd),
188 tools.NewSourcegraphTool(),
189 tools.NewViewTool(lspClients, permissions, cwd),
190 tools.NewWriteTool(lspClients, permissions, history, cwd),
191 }
192
193 mcpToolsOnce.Do(func() {
194 mcpTools = doGetMCPTools(ctx, permissions, cfg)
195 })
196 allTools = append(allTools, mcpTools...)
197
198 if len(lspClients) > 0 {
199 allTools = append(allTools, tools.NewDiagnosticsTool(lspClients))
200 }
201
202 if agentTool != nil {
203 allTools = append(allTools, agentTool)
204 }
205
206 if agentCfg.AllowedTools == nil {
207 return allTools
208 }
209
210 var filteredTools []tools.BaseTool
211 for _, tool := range allTools {
212 if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
213 filteredTools = append(filteredTools, tool)
214 }
215 }
216 return filteredTools
217 }
218
219 return &agent{
220 Broker: pubsub.NewBroker[AgentEvent](),
221 agentCfg: agentCfg,
222 provider: agentProvider,
223 providerID: string(providerCfg.ID),
224 messages: messages,
225 sessions: sessions,
226 titleProvider: titleProvider,
227 summarizeProvider: summarizeProvider,
228 summarizeProviderID: string(providerCfg.ID),
229 activeRequests: csync.NewMap[string, context.CancelFunc](),
230 tools: csync.NewLazySlice(toolFn),
231 }, nil
232}
233
234func (a *agent) Model() catwalk.Model {
235 return *config.Get().GetModelByType(a.agentCfg.Model)
236}
237
238func (a *agent) Cancel(sessionID string) {
239 // Cancel regular requests
240 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
241 slog.Info("Request cancellation initiated", "session_id", sessionID)
242 cancel()
243 }
244
245 // Also check for summarize requests
246 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
247 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
248 cancel()
249 }
250}
251
252func (a *agent) IsBusy() bool {
253 var busy bool
254 for cancelFunc := range a.activeRequests.Seq() {
255 if cancelFunc != nil {
256 busy = true
257 break
258 }
259 }
260 return busy
261}
262
263func (a *agent) IsSessionBusy(sessionID string) bool {
264 _, busy := a.activeRequests.Get(sessionID)
265 return busy
266}
267
268func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
269 if content == "" {
270 return nil
271 }
272 if a.titleProvider == nil {
273 return nil
274 }
275 session, err := a.sessions.Get(ctx, sessionID)
276 if err != nil {
277 return err
278 }
279 parts := []message.ContentPart{message.TextContent{
280 Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
281 }}
282
283 // Use streaming approach like summarization
284 response := a.titleProvider.StreamResponse(
285 ctx,
286 []message.Message{
287 {
288 Role: message.User,
289 Parts: parts,
290 },
291 },
292 nil,
293 )
294
295 var finalResponse *provider.ProviderResponse
296 for r := range response {
297 if r.Error != nil {
298 return r.Error
299 }
300 finalResponse = r.Response
301 }
302
303 if finalResponse == nil {
304 return fmt.Errorf("no response received from title provider")
305 }
306
307 title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
308 if title == "" {
309 return nil
310 }
311
312 session.Title = title
313 _, err = a.sessions.Save(ctx, session)
314 return err
315}
316
317func (a *agent) err(err error) AgentEvent {
318 return AgentEvent{
319 Type: AgentEventTypeError,
320 Error: err,
321 }
322}
323
324func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
325 if !a.Model().SupportsImages && attachments != nil {
326 attachments = nil
327 }
328 events := make(chan AgentEvent)
329 if a.IsSessionBusy(sessionID) {
330 return nil, ErrSessionBusy
331 }
332
333 genCtx, cancel := context.WithCancel(ctx)
334
335 a.activeRequests.Set(sessionID, cancel)
336 go func() {
337 slog.Debug("Request started", "sessionID", sessionID)
338 defer log.RecoverPanic("agent.Run", func() {
339 events <- a.err(fmt.Errorf("panic while running the agent"))
340 })
341 var attachmentParts []message.ContentPart
342 for _, attachment := range attachments {
343 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
344 }
345 result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
346 if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
347 slog.Error(result.Error.Error())
348 }
349 slog.Debug("Request completed", "sessionID", sessionID)
350 a.activeRequests.Del(sessionID)
351 cancel()
352 a.Publish(pubsub.CreatedEvent, result)
353 events <- result
354 close(events)
355 }()
356 return events, nil
357}
358
359func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
360 cfg := config.Get()
361 // List existing messages; if none, start title generation asynchronously.
362 msgs, err := a.messages.List(ctx, sessionID)
363 if err != nil {
364 return a.err(fmt.Errorf("failed to list messages: %w", err))
365 }
366 if len(msgs) == 0 {
367 go func() {
368 defer log.RecoverPanic("agent.Run", func() {
369 slog.Error("panic while generating title")
370 })
371 titleErr := a.generateTitle(context.Background(), sessionID, content)
372 if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
373 slog.Error("failed to generate title", "error", titleErr)
374 }
375 }()
376 }
377 session, err := a.sessions.Get(ctx, sessionID)
378 if err != nil {
379 return a.err(fmt.Errorf("failed to get session: %w", err))
380 }
381 if session.SummaryMessageID != "" {
382 summaryMsgInex := -1
383 for i, msg := range msgs {
384 if msg.ID == session.SummaryMessageID {
385 summaryMsgInex = i
386 break
387 }
388 }
389 if summaryMsgInex != -1 {
390 msgs = msgs[summaryMsgInex:]
391 msgs[0].Role = message.User
392 }
393 }
394
395 userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
396 if err != nil {
397 return a.err(fmt.Errorf("failed to create user message: %w", err))
398 }
399 // Append the new user message to the conversation history.
400 msgHistory := append(msgs, userMsg)
401
402 for {
403 // Check for cancellation before each iteration
404 select {
405 case <-ctx.Done():
406 return a.err(ctx.Err())
407 default:
408 // Continue processing
409 }
410 agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
411 if err != nil {
412 if errors.Is(err, context.Canceled) {
413 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
414 a.messages.Update(context.Background(), agentMessage)
415 return a.err(ErrRequestCancelled)
416 }
417 return a.err(fmt.Errorf("failed to process events: %w", err))
418 }
419 if cfg.Options.Debug {
420 slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
421 }
422 if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
423 // We are not done, we need to respond with the tool response
424 msgHistory = append(msgHistory, agentMessage, *toolResults)
425 continue
426 }
427 if agentMessage.FinishReason() == "" {
428 // Kujtim: could not track down where this is happening but this means its cancelled
429 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
430 _ = a.messages.Update(context.Background(), agentMessage)
431 return a.err(ErrRequestCancelled)
432 }
433 return AgentEvent{
434 Type: AgentEventTypeResponse,
435 Message: agentMessage,
436 Done: true,
437 }
438 }
439}
440
441func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
442 parts := []message.ContentPart{message.TextContent{Text: content}}
443 parts = append(parts, attachmentParts...)
444 return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
445 Role: message.User,
446 Parts: parts,
447 })
448}
449
450func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
451 ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
452
453 // Create the assistant message first so the spinner shows immediately
454 assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
455 Role: message.Assistant,
456 Parts: []message.ContentPart{},
457 Model: a.Model().ID,
458 Provider: a.providerID,
459 })
460 if err != nil {
461 return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
462 }
463
464 // Now collect tools (which may block on MCP initialization)
465 eventChan := a.provider.StreamResponse(ctx, msgHistory, slices.Collect(a.tools.Seq()))
466
467 // Add the session and message ID into the context if needed by tools.
468 ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
469
470 // Process each event in the stream.
471 for event := range eventChan {
472 if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
473 if errors.Is(processErr, context.Canceled) {
474 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
475 } else {
476 a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
477 }
478 return assistantMsg, nil, processErr
479 }
480 if ctx.Err() != nil {
481 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
482 return assistantMsg, nil, ctx.Err()
483 }
484 }
485
486 toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
487 toolCalls := assistantMsg.ToolCalls()
488 for i, toolCall := range toolCalls {
489 select {
490 case <-ctx.Done():
491 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
492 // Make all future tool calls cancelled
493 for j := i; j < len(toolCalls); j++ {
494 toolResults[j] = message.ToolResult{
495 ToolCallID: toolCalls[j].ID,
496 Content: "Tool execution canceled by user",
497 IsError: true,
498 }
499 }
500 goto out
501 default:
502 // Continue processing
503 var tool tools.BaseTool
504 for availableTool := range a.tools.Seq() {
505 if availableTool.Info().Name == toolCall.Name {
506 tool = availableTool
507 break
508 }
509 }
510
511 // Tool not found
512 if tool == nil {
513 toolResults[i] = message.ToolResult{
514 ToolCallID: toolCall.ID,
515 Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
516 IsError: true,
517 }
518 continue
519 }
520
521 // Run tool in goroutine to allow cancellation
522 type toolExecResult struct {
523 response tools.ToolResponse
524 err error
525 }
526 resultChan := make(chan toolExecResult, 1)
527
528 go func() {
529 response, err := tool.Run(ctx, tools.ToolCall{
530 ID: toolCall.ID,
531 Name: toolCall.Name,
532 Input: toolCall.Input,
533 })
534 resultChan <- toolExecResult{response: response, err: err}
535 }()
536
537 var toolResponse tools.ToolResponse
538 var toolErr error
539
540 select {
541 case <-ctx.Done():
542 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
543 // Mark remaining tool calls as cancelled
544 for j := i; j < len(toolCalls); j++ {
545 toolResults[j] = message.ToolResult{
546 ToolCallID: toolCalls[j].ID,
547 Content: "Tool execution canceled by user",
548 IsError: true,
549 }
550 }
551 goto out
552 case result := <-resultChan:
553 toolResponse = result.response
554 toolErr = result.err
555 }
556
557 if toolErr != nil {
558 slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
559 if errors.Is(toolErr, permission.ErrorPermissionDenied) {
560 toolResults[i] = message.ToolResult{
561 ToolCallID: toolCall.ID,
562 Content: "Permission denied",
563 IsError: true,
564 }
565 for j := i + 1; j < len(toolCalls); j++ {
566 toolResults[j] = message.ToolResult{
567 ToolCallID: toolCalls[j].ID,
568 Content: "Tool execution canceled by user",
569 IsError: true,
570 }
571 }
572 a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
573 break
574 }
575 }
576 toolResults[i] = message.ToolResult{
577 ToolCallID: toolCall.ID,
578 Content: toolResponse.Content,
579 Metadata: toolResponse.Metadata,
580 IsError: toolResponse.IsError,
581 }
582 }
583 }
584out:
585 if len(toolResults) == 0 {
586 return assistantMsg, nil, nil
587 }
588 parts := make([]message.ContentPart, 0)
589 for _, tr := range toolResults {
590 parts = append(parts, tr)
591 }
592 msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
593 Role: message.Tool,
594 Parts: parts,
595 Provider: a.providerID,
596 })
597 if err != nil {
598 return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
599 }
600
601 return assistantMsg, &msg, err
602}
603
604func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
605 msg.AddFinish(finishReason, message, details)
606 _ = a.messages.Update(ctx, *msg)
607}
608
609func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
610 select {
611 case <-ctx.Done():
612 return ctx.Err()
613 default:
614 // Continue processing.
615 }
616
617 switch event.Type {
618 case provider.EventThinkingDelta:
619 assistantMsg.AppendReasoningContent(event.Thinking)
620 return a.messages.Update(ctx, *assistantMsg)
621 case provider.EventSignatureDelta:
622 assistantMsg.AppendReasoningSignature(event.Signature)
623 return a.messages.Update(ctx, *assistantMsg)
624 case provider.EventContentDelta:
625 assistantMsg.FinishThinking()
626 assistantMsg.AppendContent(event.Content)
627 return a.messages.Update(ctx, *assistantMsg)
628 case provider.EventToolUseStart:
629 assistantMsg.FinishThinking()
630 slog.Info("Tool call started", "toolCall", event.ToolCall)
631 assistantMsg.AddToolCall(*event.ToolCall)
632 return a.messages.Update(ctx, *assistantMsg)
633 case provider.EventToolUseDelta:
634 assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
635 return a.messages.Update(ctx, *assistantMsg)
636 case provider.EventToolUseStop:
637 slog.Info("Finished tool call", "toolCall", event.ToolCall)
638 assistantMsg.FinishToolCall(event.ToolCall.ID)
639 return a.messages.Update(ctx, *assistantMsg)
640 case provider.EventError:
641 return event.Error
642 case provider.EventComplete:
643 assistantMsg.FinishThinking()
644 assistantMsg.SetToolCalls(event.Response.ToolCalls)
645 assistantMsg.AddFinish(event.Response.FinishReason, "", "")
646 if err := a.messages.Update(ctx, *assistantMsg); err != nil {
647 return fmt.Errorf("failed to update message: %w", err)
648 }
649 return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
650 }
651
652 return nil
653}
654
655func (a *agent) TrackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
656 sess, err := a.sessions.Get(ctx, sessionID)
657 if err != nil {
658 return fmt.Errorf("failed to get session: %w", err)
659 }
660
661 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
662 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
663 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
664 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
665
666 sess.Cost += cost
667 sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
668 sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
669
670 _, err = a.sessions.Save(ctx, sess)
671 if err != nil {
672 return fmt.Errorf("failed to save session: %w", err)
673 }
674 return nil
675}
676
677func (a *agent) Summarize(ctx context.Context, sessionID string) error {
678 if a.summarizeProvider == nil {
679 return fmt.Errorf("summarize provider not available")
680 }
681
682 // Check if session is busy
683 if a.IsSessionBusy(sessionID) {
684 return ErrSessionBusy
685 }
686
687 // Create a new context with cancellation
688 summarizeCtx, cancel := context.WithCancel(ctx)
689
690 // Store the cancel function in activeRequests to allow cancellation
691 a.activeRequests.Set(sessionID+"-summarize", cancel)
692
693 go func() {
694 defer a.activeRequests.Del(sessionID + "-summarize")
695 defer cancel()
696 event := AgentEvent{
697 Type: AgentEventTypeSummarize,
698 Progress: "Starting summarization...",
699 }
700
701 a.Publish(pubsub.CreatedEvent, event)
702 // Get all messages from the session
703 msgs, err := a.messages.List(summarizeCtx, sessionID)
704 if err != nil {
705 event = AgentEvent{
706 Type: AgentEventTypeError,
707 Error: fmt.Errorf("failed to list messages: %w", err),
708 Done: true,
709 }
710 a.Publish(pubsub.CreatedEvent, event)
711 return
712 }
713 summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
714
715 if len(msgs) == 0 {
716 event = AgentEvent{
717 Type: AgentEventTypeError,
718 Error: fmt.Errorf("no messages to summarize"),
719 Done: true,
720 }
721 a.Publish(pubsub.CreatedEvent, event)
722 return
723 }
724
725 event = AgentEvent{
726 Type: AgentEventTypeSummarize,
727 Progress: "Analyzing conversation...",
728 }
729 a.Publish(pubsub.CreatedEvent, event)
730
731 // Add a system message to guide the summarization
732 summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
733
734 // Create a new message with the summarize prompt
735 promptMsg := message.Message{
736 Role: message.User,
737 Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
738 }
739
740 // Append the prompt to the messages
741 msgsWithPrompt := append(msgs, promptMsg)
742
743 event = AgentEvent{
744 Type: AgentEventTypeSummarize,
745 Progress: "Generating summary...",
746 }
747
748 a.Publish(pubsub.CreatedEvent, event)
749
750 // Send the messages to the summarize provider
751 response := a.summarizeProvider.StreamResponse(
752 summarizeCtx,
753 msgsWithPrompt,
754 nil,
755 )
756 var finalResponse *provider.ProviderResponse
757 for r := range response {
758 if r.Error != nil {
759 event = AgentEvent{
760 Type: AgentEventTypeError,
761 Error: fmt.Errorf("failed to summarize: %w", err),
762 Done: true,
763 }
764 a.Publish(pubsub.CreatedEvent, event)
765 return
766 }
767 finalResponse = r.Response
768 }
769
770 summary := strings.TrimSpace(finalResponse.Content)
771 if summary == "" {
772 event = AgentEvent{
773 Type: AgentEventTypeError,
774 Error: fmt.Errorf("empty summary returned"),
775 Done: true,
776 }
777 a.Publish(pubsub.CreatedEvent, event)
778 return
779 }
780 shell := shell.GetPersistentShell(config.Get().WorkingDir())
781 summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
782 event = AgentEvent{
783 Type: AgentEventTypeSummarize,
784 Progress: "Creating new session...",
785 }
786
787 a.Publish(pubsub.CreatedEvent, event)
788 oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
789 if err != nil {
790 event = AgentEvent{
791 Type: AgentEventTypeError,
792 Error: fmt.Errorf("failed to get session: %w", err),
793 Done: true,
794 }
795
796 a.Publish(pubsub.CreatedEvent, event)
797 return
798 }
799 // Create a message in the new session with the summary
800 msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
801 Role: message.Assistant,
802 Parts: []message.ContentPart{
803 message.TextContent{Text: summary},
804 message.Finish{
805 Reason: message.FinishReasonEndTurn,
806 Time: time.Now().Unix(),
807 },
808 },
809 Model: a.summarizeProvider.Model().ID,
810 Provider: a.summarizeProviderID,
811 })
812 if err != nil {
813 event = AgentEvent{
814 Type: AgentEventTypeError,
815 Error: fmt.Errorf("failed to create summary message: %w", err),
816 Done: true,
817 }
818
819 a.Publish(pubsub.CreatedEvent, event)
820 return
821 }
822 oldSession.SummaryMessageID = msg.ID
823 oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
824 oldSession.PromptTokens = 0
825 model := a.summarizeProvider.Model()
826 usage := finalResponse.Usage
827 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
828 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
829 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
830 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
831 oldSession.Cost += cost
832 _, err = a.sessions.Save(summarizeCtx, oldSession)
833 if err != nil {
834 event = AgentEvent{
835 Type: AgentEventTypeError,
836 Error: fmt.Errorf("failed to save session: %w", err),
837 Done: true,
838 }
839 a.Publish(pubsub.CreatedEvent, event)
840 }
841
842 event = AgentEvent{
843 Type: AgentEventTypeSummarize,
844 SessionID: oldSession.ID,
845 Progress: "Summary complete",
846 Done: true,
847 }
848 a.Publish(pubsub.CreatedEvent, event)
849 // Send final success event with the new session ID
850 }()
851
852 return nil
853}
854
855func (a *agent) CancelAll() {
856 if !a.IsBusy() {
857 return
858 }
859 for key := range a.activeRequests.Seq2() {
860 a.Cancel(key) // key is sessionID
861 }
862
863 timeout := time.After(5 * time.Second)
864 for a.IsBusy() {
865 select {
866 case <-timeout:
867 return
868 default:
869 time.Sleep(200 * time.Millisecond)
870 }
871 }
872}
873
874func (a *agent) UpdateModel() error {
875 cfg := config.Get()
876
877 // Get current provider configuration
878 currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
879 if currentProviderCfg == nil || currentProviderCfg.ID == "" {
880 return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
881 }
882
883 // Check if provider has changed
884 if string(currentProviderCfg.ID) != a.providerID {
885 // Provider changed, need to recreate the main provider
886 model := cfg.GetModelByType(a.agentCfg.Model)
887 if model.ID == "" {
888 return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
889 }
890
891 promptID := agentPromptMap[a.agentCfg.ID]
892 if promptID == "" {
893 promptID = prompt.PromptDefault
894 }
895
896 opts := []provider.ProviderClientOption{
897 provider.WithModel(a.agentCfg.Model),
898 provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
899 }
900
901 newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
902 if err != nil {
903 return fmt.Errorf("failed to create new provider: %w", err)
904 }
905
906 // Update the provider and provider ID
907 a.provider = newProvider
908 a.providerID = string(currentProviderCfg.ID)
909 }
910
911 // Check if providers have changed for title (small) and summarize (large)
912 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
913 var smallModelProviderCfg config.ProviderConfig
914 for p := range cfg.Providers.Seq() {
915 if p.ID == smallModelCfg.Provider {
916 smallModelProviderCfg = p
917 break
918 }
919 }
920 if smallModelProviderCfg.ID == "" {
921 return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
922 }
923
924 largeModelCfg := cfg.Models[config.SelectedModelTypeLarge]
925 var largeModelProviderCfg config.ProviderConfig
926 for p := range cfg.Providers.Seq() {
927 if p.ID == largeModelCfg.Provider {
928 largeModelProviderCfg = p
929 break
930 }
931 }
932 if largeModelProviderCfg.ID == "" {
933 return fmt.Errorf("provider %s not found in config", largeModelCfg.Provider)
934 }
935
936 // Recreate title provider
937 titleOpts := []provider.ProviderClientOption{
938 provider.WithModel(config.SelectedModelTypeSmall),
939 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
940 provider.WithMaxTokens(40),
941 }
942 newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
943 if err != nil {
944 return fmt.Errorf("failed to create new title provider: %w", err)
945 }
946 a.titleProvider = newTitleProvider
947
948 // Recreate summarize provider if provider changed (now large model)
949 if string(largeModelProviderCfg.ID) != a.summarizeProviderID {
950 largeModel := cfg.GetModelByType(config.SelectedModelTypeLarge)
951 if largeModel == nil {
952 return fmt.Errorf("model %s not found in provider %s", largeModelCfg.Model, largeModelProviderCfg.ID)
953 }
954 summarizeOpts := []provider.ProviderClientOption{
955 provider.WithModel(config.SelectedModelTypeLarge),
956 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, largeModelProviderCfg.ID)),
957 }
958 newSummarizeProvider, err := provider.NewProvider(largeModelProviderCfg, summarizeOpts...)
959 if err != nil {
960 return fmt.Errorf("failed to create new summarize provider: %w", err)
961 }
962 a.summarizeProvider = newSummarizeProvider
963 a.summarizeProviderID = string(largeModelProviderCfg.ID)
964 }
965
966 return nil
967}