1package agent
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "log/slog"
8 "slices"
9 "strings"
10 "time"
11
12 "github.com/charmbracelet/catwalk/pkg/catwalk"
13 "github.com/charmbracelet/crush/internal/config"
14 "github.com/charmbracelet/crush/internal/csync"
15 "github.com/charmbracelet/crush/internal/history"
16 "github.com/charmbracelet/crush/internal/llm/prompt"
17 "github.com/charmbracelet/crush/internal/llm/provider"
18 "github.com/charmbracelet/crush/internal/llm/tools"
19 "github.com/charmbracelet/crush/internal/log"
20 "github.com/charmbracelet/crush/internal/lsp"
21 "github.com/charmbracelet/crush/internal/message"
22 "github.com/charmbracelet/crush/internal/permission"
23 "github.com/charmbracelet/crush/internal/pubsub"
24 "github.com/charmbracelet/crush/internal/session"
25 "github.com/charmbracelet/crush/internal/shell"
26)
27
// Common errors
var (
	// ErrRequestCancelled is the error carried by the final event of a
	// generation that stopped because its context was cancelled.
	ErrRequestCancelled = errors.New("request canceled by user")
	// ErrSessionBusy is returned by Run and Summarize when the session
	// already has an active request registered.
	ErrSessionBusy = errors.New("session is currently processing another request")
)
33
// AgentEventType identifies the kind of an AgentEvent.
type AgentEventType string

const (
	// AgentEventTypeError marks an event whose Error field describes a failure.
	AgentEventTypeError AgentEventType = "error"
	// AgentEventTypeResponse marks the final response event produced by Run.
	AgentEventTypeResponse AgentEventType = "response"
	// AgentEventTypeSummarize marks the progress and completion events
	// published by Summarize.
	AgentEventTypeSummarize AgentEventType = "summarize"
)
41
// AgentEvent is the value delivered on the channel returned by Run and
// published through the agent's broker.
type AgentEvent struct {
	// Type tells which of the remaining fields are meaningful.
	Type AgentEventType
	// Message is the assistant message of a response event.
	Message message.Message
	// Error is set on error events.
	Error error

	// When summarizing
	// SessionID is the session the finished summary belongs to.
	SessionID string
	// Progress is a human-readable status line.
	Progress string
	// Done reports that no further events follow for the operation.
	Done bool
}
52
// Service is the public surface of an agent. The notes below describe the
// behavior of the agent implementation in this file.
type Service interface {
	pubsub.Suscriber[AgentEvent]
	// Model returns the model currently configured for the agent's model type.
	Model() catwalk.Model
	// Run starts a generation for the session in the background and returns a
	// channel that receives its final event. It returns ErrSessionBusy when
	// the session already has an active request.
	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
	// Cancel cancels the session's active request and its summarize request,
	// if any.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits up to five seconds for
	// the agent to become idle.
	CancelAll()
	// IsSessionBusy reports whether the session has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any request is active.
	IsBusy() bool
	// Summarize starts summarizing the session in the background; progress
	// and the outcome are published as events.
	Summarize(ctx context.Context, sessionID string) error
	// UpdateModel recreates the provider clients whose configured provider
	// has changed.
	UpdateModel() error
}
64
// agent is the Service implementation. Events are published through the
// embedded broker.
type agent struct {
	*pubsub.Broker[AgentEvent]
	agentCfg config.Agent     // configuration this agent was created from
	sessions session.Service  // session storage
	messages message.Service  // message storage

	// tools holds the tools offered to the model; see NewAgent for how the
	// list is built.
	tools *csync.LazySlice[tools.BaseTool]

	// provider is the main model client; providerID is the ID of the provider
	// configuration it was created from.
	provider   provider.Provider
	providerID string

	// titleProvider and summarizeProvider both use the small model;
	// summarizeProviderID is the ID of the provider configuration they were
	// created from.
	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests maps a session ID (or session ID + "-summarize") to the
	// cancel function of the request currently running for it.
	activeRequests *csync.Map[string, context.CancelFunc]
}
82
// agentPromptMap maps an agent ID to the prompt used as its system message.
// Agent IDs missing from the map fall back to prompt.PromptDefault at the
// lookup sites.
var agentPromptMap = map[string]prompt.PromptID{
	"coder": prompt.PromptCoder,
	"task":  prompt.PromptTask,
}
87
88func NewAgent(
89 agentCfg config.Agent,
90 // These services are needed in the tools
91 permissions permission.Service,
92 sessions session.Service,
93 messages message.Service,
94 history history.Service,
95 lspClients map[string]*lsp.Client,
96) (Service, error) {
97 ctx := context.Background()
98 cfg := config.Get()
99
100 var agentTool tools.BaseTool
101 if agentCfg.ID == "coder" {
102 taskAgentCfg := config.Get().Agents["task"]
103 if taskAgentCfg.ID == "" {
104 return nil, fmt.Errorf("task agent not found in config")
105 }
106 taskAgent, err := NewAgent(taskAgentCfg, permissions, sessions, messages, history, lspClients)
107 if err != nil {
108 return nil, fmt.Errorf("failed to create task agent: %w", err)
109 }
110
111 agentTool = NewAgentTool(taskAgent, sessions, messages)
112 }
113
114 providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
115 if providerCfg == nil {
116 return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
117 }
118 model := config.Get().GetModelByType(agentCfg.Model)
119
120 if model == nil {
121 return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
122 }
123
124 promptID := agentPromptMap[agentCfg.ID]
125 if promptID == "" {
126 promptID = prompt.PromptDefault
127 }
128 opts := []provider.ProviderClientOption{
129 provider.WithModel(agentCfg.Model),
130 provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
131 }
132 agentProvider, err := provider.NewProvider(*providerCfg, opts...)
133 if err != nil {
134 return nil, err
135 }
136
137 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
138 var smallModelProviderCfg *config.ProviderConfig
139 if smallModelCfg.Provider == providerCfg.ID {
140 smallModelProviderCfg = providerCfg
141 } else {
142 smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
143
144 if smallModelProviderCfg.ID == "" {
145 return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
146 }
147 }
148 smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
149 if smallModel.ID == "" {
150 return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
151 }
152
153 titleOpts := []provider.ProviderClientOption{
154 provider.WithModel(config.SelectedModelTypeSmall),
155 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
156 }
157 titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
158 if err != nil {
159 return nil, err
160 }
161 summarizeOpts := []provider.ProviderClientOption{
162 provider.WithModel(config.SelectedModelTypeSmall),
163 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
164 }
165 summarizeProvider, err := provider.NewProvider(*smallModelProviderCfg, summarizeOpts...)
166 if err != nil {
167 return nil, err
168 }
169
170 toolFn := func() []tools.BaseTool {
171 slog.Info("Initializing agent tools", "agent", agentCfg.ID)
172 defer func() {
173 slog.Info("Initialized agent tools", "agent", agentCfg.ID)
174 }()
175
176 cwd := cfg.WorkingDir()
177 allTools := []tools.BaseTool{
178 tools.NewBashTool(permissions, cwd),
179 tools.NewDownloadTool(permissions, cwd),
180 tools.NewEditTool(lspClients, permissions, history, cwd),
181 tools.NewFetchTool(permissions, cwd),
182 tools.NewGlobTool(cwd),
183 tools.NewGrepTool(cwd),
184 tools.NewLsTool(permissions, cwd),
185 tools.NewSourcegraphTool(),
186 tools.NewViewTool(lspClients, permissions, cwd),
187 tools.NewWriteTool(lspClients, permissions, history, cwd),
188 }
189
190 mcpTools := GetMCPTools(ctx, permissions, cfg)
191 allTools = append(allTools, mcpTools...)
192
193 if len(lspClients) > 0 {
194 allTools = append(allTools, tools.NewDiagnosticsTool(lspClients))
195 }
196
197 if agentTool != nil {
198 allTools = append(allTools, agentTool)
199 }
200
201 if agentCfg.AllowedTools == nil {
202 return allTools
203 }
204
205 var filteredTools []tools.BaseTool
206 for _, tool := range allTools {
207 if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
208 filteredTools = append(filteredTools, tool)
209 }
210 }
211 return filteredTools
212 }
213
214 return &agent{
215 Broker: pubsub.NewBroker[AgentEvent](),
216 agentCfg: agentCfg,
217 provider: agentProvider,
218 providerID: string(providerCfg.ID),
219 messages: messages,
220 sessions: sessions,
221 titleProvider: titleProvider,
222 summarizeProvider: summarizeProvider,
223 summarizeProviderID: string(smallModelProviderCfg.ID),
224 activeRequests: csync.NewMap[string, context.CancelFunc](),
225 tools: csync.NewLazySlice(toolFn),
226 }, nil
227}
228
229func (a *agent) Model() catwalk.Model {
230 return *config.Get().GetModelByType(a.agentCfg.Model)
231}
232
233func (a *agent) Cancel(sessionID string) {
234 // Cancel regular requests
235 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
236 slog.Info("Request cancellation initiated", "session_id", sessionID)
237 cancel()
238 }
239
240 // Also check for summarize requests
241 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
242 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
243 cancel()
244 }
245}
246
247func (a *agent) IsBusy() bool {
248 var busy bool
249 for cancelFunc := range a.activeRequests.Seq() {
250 if cancelFunc != nil {
251 busy = true
252 break
253 }
254 }
255 return busy
256}
257
258func (a *agent) IsSessionBusy(sessionID string) bool {
259 _, busy := a.activeRequests.Get(sessionID)
260 return busy
261}
262
263func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
264 if content == "" {
265 return nil
266 }
267 if a.titleProvider == nil {
268 return nil
269 }
270 session, err := a.sessions.Get(ctx, sessionID)
271 if err != nil {
272 return err
273 }
274 parts := []message.ContentPart{message.TextContent{
275 Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
276 }}
277
278 // Use streaming approach like summarization
279 response := a.titleProvider.StreamResponse(
280 ctx,
281 []message.Message{
282 {
283 Role: message.User,
284 Parts: parts,
285 },
286 },
287 nil,
288 )
289
290 var finalResponse *provider.ProviderResponse
291 for r := range response {
292 if r.Error != nil {
293 return r.Error
294 }
295 finalResponse = r.Response
296 }
297
298 if finalResponse == nil {
299 return fmt.Errorf("no response received from title provider")
300 }
301
302 title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
303 if title == "" {
304 return nil
305 }
306
307 session.Title = title
308 _, err = a.sessions.Save(ctx, session)
309 return err
310}
311
312func (a *agent) err(err error) AgentEvent {
313 return AgentEvent{
314 Type: AgentEventTypeError,
315 Error: err,
316 }
317}
318
319func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
320 if !a.Model().SupportsImages && attachments != nil {
321 attachments = nil
322 }
323 events := make(chan AgentEvent)
324 if a.IsSessionBusy(sessionID) {
325 return nil, ErrSessionBusy
326 }
327
328 genCtx, cancel := context.WithCancel(ctx)
329
330 a.activeRequests.Set(sessionID, cancel)
331 go func() {
332 slog.Debug("Request started", "sessionID", sessionID)
333 defer log.RecoverPanic("agent.Run", func() {
334 events <- a.err(fmt.Errorf("panic while running the agent"))
335 })
336 var attachmentParts []message.ContentPart
337 for _, attachment := range attachments {
338 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
339 }
340 result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
341 if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
342 slog.Error(result.Error.Error())
343 }
344 slog.Debug("Request completed", "sessionID", sessionID)
345 a.activeRequests.Del(sessionID)
346 cancel()
347 a.Publish(pubsub.CreatedEvent, result)
348 events <- result
349 close(events)
350 }()
351 return events, nil
352}
353
354func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
355 cfg := config.Get()
356 // List existing messages; if none, start title generation asynchronously.
357 msgs, err := a.messages.List(ctx, sessionID)
358 if err != nil {
359 return a.err(fmt.Errorf("failed to list messages: %w", err))
360 }
361 if len(msgs) == 0 {
362 go func() {
363 defer log.RecoverPanic("agent.Run", func() {
364 slog.Error("panic while generating title")
365 })
366 titleErr := a.generateTitle(context.Background(), sessionID, content)
367 if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
368 slog.Error("failed to generate title", "error", titleErr)
369 }
370 }()
371 }
372 session, err := a.sessions.Get(ctx, sessionID)
373 if err != nil {
374 return a.err(fmt.Errorf("failed to get session: %w", err))
375 }
376 if session.SummaryMessageID != "" {
377 summaryMsgInex := -1
378 for i, msg := range msgs {
379 if msg.ID == session.SummaryMessageID {
380 summaryMsgInex = i
381 break
382 }
383 }
384 if summaryMsgInex != -1 {
385 msgs = msgs[summaryMsgInex:]
386 msgs[0].Role = message.User
387 }
388 }
389
390 userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
391 if err != nil {
392 return a.err(fmt.Errorf("failed to create user message: %w", err))
393 }
394 // Append the new user message to the conversation history.
395 msgHistory := append(msgs, userMsg)
396
397 for {
398 // Check for cancellation before each iteration
399 select {
400 case <-ctx.Done():
401 return a.err(ctx.Err())
402 default:
403 // Continue processing
404 }
405 agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
406 if err != nil {
407 if errors.Is(err, context.Canceled) {
408 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
409 a.messages.Update(context.Background(), agentMessage)
410 return a.err(ErrRequestCancelled)
411 }
412 return a.err(fmt.Errorf("failed to process events: %w", err))
413 }
414 if cfg.Options.Debug {
415 slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
416 }
417 if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
418 // We are not done, we need to respond with the tool response
419 msgHistory = append(msgHistory, agentMessage, *toolResults)
420 continue
421 }
422 return AgentEvent{
423 Type: AgentEventTypeResponse,
424 Message: agentMessage,
425 Done: true,
426 }
427 }
428}
429
430func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
431 parts := []message.ContentPart{message.TextContent{Text: content}}
432 parts = append(parts, attachmentParts...)
433 return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
434 Role: message.User,
435 Parts: parts,
436 })
437}
438
439func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
440 ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
441 eventChan := a.provider.StreamResponse(ctx, msgHistory, slices.Collect(a.tools.Seq()))
442
443 assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
444 Role: message.Assistant,
445 Parts: []message.ContentPart{},
446 Model: a.Model().ID,
447 Provider: a.providerID,
448 })
449 if err != nil {
450 return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
451 }
452
453 // Add the session and message ID into the context if needed by tools.
454 ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
455
456 // Process each event in the stream.
457 for event := range eventChan {
458 if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
459 if errors.Is(processErr, context.Canceled) {
460 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
461 } else {
462 a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
463 }
464 return assistantMsg, nil, processErr
465 }
466 if ctx.Err() != nil {
467 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
468 return assistantMsg, nil, ctx.Err()
469 }
470 }
471
472 toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
473 toolCalls := assistantMsg.ToolCalls()
474 for i, toolCall := range toolCalls {
475 select {
476 case <-ctx.Done():
477 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
478 // Make all future tool calls cancelled
479 for j := i; j < len(toolCalls); j++ {
480 toolResults[j] = message.ToolResult{
481 ToolCallID: toolCalls[j].ID,
482 Content: "Tool execution canceled by user",
483 IsError: true,
484 }
485 }
486 goto out
487 default:
488 // Continue processing
489 var tool tools.BaseTool
490 for availableTool := range a.tools.Seq() {
491 if availableTool.Info().Name == toolCall.Name {
492 tool = availableTool
493 break
494 }
495 }
496
497 // Tool not found
498 if tool == nil {
499 toolResults[i] = message.ToolResult{
500 ToolCallID: toolCall.ID,
501 Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
502 IsError: true,
503 }
504 continue
505 }
506
507 // Run tool in goroutine to allow cancellation
508 type toolExecResult struct {
509 response tools.ToolResponse
510 err error
511 }
512 resultChan := make(chan toolExecResult, 1)
513
514 go func() {
515 response, err := tool.Run(ctx, tools.ToolCall{
516 ID: toolCall.ID,
517 Name: toolCall.Name,
518 Input: toolCall.Input,
519 })
520 resultChan <- toolExecResult{response: response, err: err}
521 }()
522
523 var toolResponse tools.ToolResponse
524 var toolErr error
525
526 select {
527 case <-ctx.Done():
528 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
529 // Mark remaining tool calls as cancelled
530 for j := i; j < len(toolCalls); j++ {
531 toolResults[j] = message.ToolResult{
532 ToolCallID: toolCalls[j].ID,
533 Content: "Tool execution canceled by user",
534 IsError: true,
535 }
536 }
537 goto out
538 case result := <-resultChan:
539 toolResponse = result.response
540 toolErr = result.err
541 }
542
543 if toolErr != nil {
544 slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
545 if errors.Is(toolErr, permission.ErrorPermissionDenied) {
546 toolResults[i] = message.ToolResult{
547 ToolCallID: toolCall.ID,
548 Content: "Permission denied",
549 IsError: true,
550 }
551 for j := i + 1; j < len(toolCalls); j++ {
552 toolResults[j] = message.ToolResult{
553 ToolCallID: toolCalls[j].ID,
554 Content: "Tool execution canceled by user",
555 IsError: true,
556 }
557 }
558 a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
559 break
560 }
561 }
562 toolResults[i] = message.ToolResult{
563 ToolCallID: toolCall.ID,
564 Content: toolResponse.Content,
565 Metadata: toolResponse.Metadata,
566 IsError: toolResponse.IsError,
567 }
568 }
569 }
570out:
571 if len(toolResults) == 0 {
572 return assistantMsg, nil, nil
573 }
574 parts := make([]message.ContentPart, 0)
575 for _, tr := range toolResults {
576 parts = append(parts, tr)
577 }
578 msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
579 Role: message.Tool,
580 Parts: parts,
581 Provider: a.providerID,
582 })
583 if err != nil {
584 return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
585 }
586
587 return assistantMsg, &msg, err
588}
589
590func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
591 msg.AddFinish(finishReason, message, details)
592 _ = a.messages.Update(ctx, *msg)
593}
594
595func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
596 select {
597 case <-ctx.Done():
598 return ctx.Err()
599 default:
600 // Continue processing.
601 }
602
603 switch event.Type {
604 case provider.EventThinkingDelta:
605 assistantMsg.AppendReasoningContent(event.Thinking)
606 return a.messages.Update(ctx, *assistantMsg)
607 case provider.EventSignatureDelta:
608 assistantMsg.AppendReasoningSignature(event.Signature)
609 return a.messages.Update(ctx, *assistantMsg)
610 case provider.EventContentDelta:
611 assistantMsg.FinishThinking()
612 assistantMsg.AppendContent(event.Content)
613 return a.messages.Update(ctx, *assistantMsg)
614 case provider.EventToolUseStart:
615 assistantMsg.FinishThinking()
616 slog.Info("Tool call started", "toolCall", event.ToolCall)
617 assistantMsg.AddToolCall(*event.ToolCall)
618 return a.messages.Update(ctx, *assistantMsg)
619 case provider.EventToolUseDelta:
620 assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
621 return a.messages.Update(ctx, *assistantMsg)
622 case provider.EventToolUseStop:
623 slog.Info("Finished tool call", "toolCall", event.ToolCall)
624 assistantMsg.FinishToolCall(event.ToolCall.ID)
625 return a.messages.Update(ctx, *assistantMsg)
626 case provider.EventError:
627 return event.Error
628 case provider.EventComplete:
629 assistantMsg.FinishThinking()
630 assistantMsg.SetToolCalls(event.Response.ToolCalls)
631 assistantMsg.AddFinish(event.Response.FinishReason, "", "")
632 if err := a.messages.Update(ctx, *assistantMsg); err != nil {
633 return fmt.Errorf("failed to update message: %w", err)
634 }
635 return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
636 }
637
638 return nil
639}
640
641func (a *agent) TrackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
642 sess, err := a.sessions.Get(ctx, sessionID)
643 if err != nil {
644 return fmt.Errorf("failed to get session: %w", err)
645 }
646
647 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
648 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
649 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
650 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
651
652 sess.Cost += cost
653 sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
654 sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
655
656 _, err = a.sessions.Save(ctx, sess)
657 if err != nil {
658 return fmt.Errorf("failed to save session: %w", err)
659 }
660 return nil
661}
662
663func (a *agent) Summarize(ctx context.Context, sessionID string) error {
664 if a.summarizeProvider == nil {
665 return fmt.Errorf("summarize provider not available")
666 }
667
668 // Check if session is busy
669 if a.IsSessionBusy(sessionID) {
670 return ErrSessionBusy
671 }
672
673 // Create a new context with cancellation
674 summarizeCtx, cancel := context.WithCancel(ctx)
675
676 // Store the cancel function in activeRequests to allow cancellation
677 a.activeRequests.Set(sessionID+"-summarize", cancel)
678
679 go func() {
680 defer a.activeRequests.Del(sessionID + "-summarize")
681 defer cancel()
682 event := AgentEvent{
683 Type: AgentEventTypeSummarize,
684 Progress: "Starting summarization...",
685 }
686
687 a.Publish(pubsub.CreatedEvent, event)
688 // Get all messages from the session
689 msgs, err := a.messages.List(summarizeCtx, sessionID)
690 if err != nil {
691 event = AgentEvent{
692 Type: AgentEventTypeError,
693 Error: fmt.Errorf("failed to list messages: %w", err),
694 Done: true,
695 }
696 a.Publish(pubsub.CreatedEvent, event)
697 return
698 }
699 summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
700
701 if len(msgs) == 0 {
702 event = AgentEvent{
703 Type: AgentEventTypeError,
704 Error: fmt.Errorf("no messages to summarize"),
705 Done: true,
706 }
707 a.Publish(pubsub.CreatedEvent, event)
708 return
709 }
710
711 event = AgentEvent{
712 Type: AgentEventTypeSummarize,
713 Progress: "Analyzing conversation...",
714 }
715 a.Publish(pubsub.CreatedEvent, event)
716
717 // Add a system message to guide the summarization
718 summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
719
720 // Create a new message with the summarize prompt
721 promptMsg := message.Message{
722 Role: message.User,
723 Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
724 }
725
726 // Append the prompt to the messages
727 msgsWithPrompt := append(msgs, promptMsg)
728
729 event = AgentEvent{
730 Type: AgentEventTypeSummarize,
731 Progress: "Generating summary...",
732 }
733
734 a.Publish(pubsub.CreatedEvent, event)
735
736 // Send the messages to the summarize provider
737 response := a.summarizeProvider.StreamResponse(
738 summarizeCtx,
739 msgsWithPrompt,
740 nil,
741 )
742 var finalResponse *provider.ProviderResponse
743 for r := range response {
744 if r.Error != nil {
745 event = AgentEvent{
746 Type: AgentEventTypeError,
747 Error: fmt.Errorf("failed to summarize: %w", err),
748 Done: true,
749 }
750 a.Publish(pubsub.CreatedEvent, event)
751 return
752 }
753 finalResponse = r.Response
754 }
755
756 summary := strings.TrimSpace(finalResponse.Content)
757 if summary == "" {
758 event = AgentEvent{
759 Type: AgentEventTypeError,
760 Error: fmt.Errorf("empty summary returned"),
761 Done: true,
762 }
763 a.Publish(pubsub.CreatedEvent, event)
764 return
765 }
766 shell := shell.GetPersistentShell(config.Get().WorkingDir())
767 summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
768 event = AgentEvent{
769 Type: AgentEventTypeSummarize,
770 Progress: "Creating new session...",
771 }
772
773 a.Publish(pubsub.CreatedEvent, event)
774 oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
775 if err != nil {
776 event = AgentEvent{
777 Type: AgentEventTypeError,
778 Error: fmt.Errorf("failed to get session: %w", err),
779 Done: true,
780 }
781
782 a.Publish(pubsub.CreatedEvent, event)
783 return
784 }
785 // Create a message in the new session with the summary
786 msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
787 Role: message.Assistant,
788 Parts: []message.ContentPart{
789 message.TextContent{Text: summary},
790 message.Finish{
791 Reason: message.FinishReasonEndTurn,
792 Time: time.Now().Unix(),
793 },
794 },
795 Model: a.summarizeProvider.Model().ID,
796 Provider: a.summarizeProviderID,
797 })
798 if err != nil {
799 event = AgentEvent{
800 Type: AgentEventTypeError,
801 Error: fmt.Errorf("failed to create summary message: %w", err),
802 Done: true,
803 }
804
805 a.Publish(pubsub.CreatedEvent, event)
806 return
807 }
808 oldSession.SummaryMessageID = msg.ID
809 oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
810 oldSession.PromptTokens = 0
811 model := a.summarizeProvider.Model()
812 usage := finalResponse.Usage
813 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
814 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
815 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
816 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
817 oldSession.Cost += cost
818 _, err = a.sessions.Save(summarizeCtx, oldSession)
819 if err != nil {
820 event = AgentEvent{
821 Type: AgentEventTypeError,
822 Error: fmt.Errorf("failed to save session: %w", err),
823 Done: true,
824 }
825 a.Publish(pubsub.CreatedEvent, event)
826 }
827
828 event = AgentEvent{
829 Type: AgentEventTypeSummarize,
830 SessionID: oldSession.ID,
831 Progress: "Summary complete",
832 Done: true,
833 }
834 a.Publish(pubsub.CreatedEvent, event)
835 // Send final success event with the new session ID
836 }()
837
838 return nil
839}
840
841func (a *agent) CancelAll() {
842 if !a.IsBusy() {
843 return
844 }
845 for key := range a.activeRequests.Seq2() {
846 a.Cancel(key) // key is sessionID
847 }
848
849 timeout := time.After(5 * time.Second)
850 for a.IsBusy() {
851 select {
852 case <-timeout:
853 return
854 default:
855 time.Sleep(200 * time.Millisecond)
856 }
857 }
858}
859
860func (a *agent) UpdateModel() error {
861 cfg := config.Get()
862
863 // Get current provider configuration
864 currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
865 if currentProviderCfg == nil || currentProviderCfg.ID == "" {
866 return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
867 }
868
869 // Check if provider has changed
870 if string(currentProviderCfg.ID) != a.providerID {
871 // Provider changed, need to recreate the main provider
872 model := cfg.GetModelByType(a.agentCfg.Model)
873 if model.ID == "" {
874 return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
875 }
876
877 promptID := agentPromptMap[a.agentCfg.ID]
878 if promptID == "" {
879 promptID = prompt.PromptDefault
880 }
881
882 opts := []provider.ProviderClientOption{
883 provider.WithModel(a.agentCfg.Model),
884 provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
885 }
886
887 newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
888 if err != nil {
889 return fmt.Errorf("failed to create new provider: %w", err)
890 }
891
892 // Update the provider and provider ID
893 a.provider = newProvider
894 a.providerID = string(currentProviderCfg.ID)
895 }
896
897 // Check if small model provider has changed (affects title and summarize providers)
898 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
899 var smallModelProviderCfg config.ProviderConfig
900
901 for p := range cfg.Providers.Seq() {
902 if p.ID == smallModelCfg.Provider {
903 smallModelProviderCfg = p
904 break
905 }
906 }
907
908 if smallModelProviderCfg.ID == "" {
909 return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
910 }
911
912 // Check if summarize provider has changed
913 if string(smallModelProviderCfg.ID) != a.summarizeProviderID {
914 smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
915 if smallModel == nil {
916 return fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
917 }
918
919 // Recreate title provider
920 titleOpts := []provider.ProviderClientOption{
921 provider.WithModel(config.SelectedModelTypeSmall),
922 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
923 // We want the title to be short, so we limit the max tokens
924 provider.WithMaxTokens(40),
925 }
926 newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
927 if err != nil {
928 return fmt.Errorf("failed to create new title provider: %w", err)
929 }
930
931 // Recreate summarize provider
932 summarizeOpts := []provider.ProviderClientOption{
933 provider.WithModel(config.SelectedModelTypeSmall),
934 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
935 }
936 newSummarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
937 if err != nil {
938 return fmt.Errorf("failed to create new summarize provider: %w", err)
939 }
940
941 // Update the providers and provider ID
942 a.titleProvider = newTitleProvider
943 a.summarizeProvider = newSummarizeProvider
944 a.summarizeProviderID = string(smallModelProviderCfg.ID)
945 }
946
947 return nil
948}