1package agent
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "log/slog"
8 "slices"
9 "strings"
10 "sync"
11 "time"
12
13 "github.com/charmbracelet/crush/internal/config"
14 fur "github.com/charmbracelet/crush/internal/fur/provider"
15 "github.com/charmbracelet/crush/internal/history"
16 "github.com/charmbracelet/crush/internal/llm/prompt"
17 "github.com/charmbracelet/crush/internal/llm/provider"
18 "github.com/charmbracelet/crush/internal/llm/tools"
19 "github.com/charmbracelet/crush/internal/log"
20 "github.com/charmbracelet/crush/internal/lsp"
21 "github.com/charmbracelet/crush/internal/message"
22 "github.com/charmbracelet/crush/internal/permission"
23 "github.com/charmbracelet/crush/internal/pubsub"
24 "github.com/charmbracelet/crush/internal/session"
25)
26
// ErrRequestCancelled is reported when an in-flight request is aborted
// (through Cancel or a cancelled context) before it produced a response.
var ErrRequestCancelled = errors.New("request canceled by user")

// ErrSessionBusy is returned when work is started for a session that
// already has a request in progress.
var ErrSessionBusy = errors.New("session is currently processing another request")
32
// AgentEventType classifies the events an agent publishes.
type AgentEventType string

// Known AgentEventType values.
const (
	// AgentEventTypeError reports a failure; AgentEvent.Error holds the cause.
	AgentEventTypeError AgentEventType = "error"
	// AgentEventTypeResponse carries the final assistant message of a request.
	AgentEventTypeResponse AgentEventType = "response"
	// AgentEventTypeSummarize reports summarization progress or completion.
	AgentEventTypeSummarize AgentEventType = "summarize"
)
40
// AgentEvent reports the outcome or progress of agent work. It is published
// to subscribers and, for Run, also sent on the channel Run returns.
type AgentEvent struct {
	// Type selects which of the fields below are meaningful.
	Type AgentEventType
	// Message is the final assistant message on AgentEventTypeResponse events.
	Message message.Message
	// Error is set on AgentEventTypeError events.
	Error error

	// Summarization fields (Done is shared with Run, see below).
	// SessionID is set on the final successful summarization event.
	SessionID string
	// Progress is human-readable status text on summarize events.
	Progress string
	// Done marks the last event of a summarization (success or error) and is
	// also set on Run's successful response; Run's error events leave it false.
	Done bool
}
51
52type Service interface {
53 pubsub.Suscriber[AgentEvent]
54 Model() fur.Model
55 Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
56 Cancel(sessionID string)
57 CancelAll()
58 IsSessionBusy(sessionID string) bool
59 IsBusy() bool
60 Summarize(ctx context.Context, sessionID string) error
61 UpdateModel() error
62}
63
// agent is the Service implementation built by NewAgent.
type agent struct {
	// Broker publishes AgentEvents to subscribers.
	*pubsub.Broker[AgentEvent]
	agentCfg config.Agent
	sessions session.Service
	messages message.Service

	// tools are the tools offered to the main provider, already filtered by
	// agentCfg.AllowedTools.
	tools []tools.BaseTool
	// provider serves the agent's main model; providerID is the ID of the
	// provider config it was built from (compared in UpdateModel).
	provider   provider.Provider
	providerID string

	// titleProvider and summarizeProvider both use the small model;
	// summarizeProviderID decides when UpdateModel rebuilds both of them.
	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests maps a session ID (or sessionID+"-summarize") to the
	// context.CancelFunc of the work currently running for it.
	activeRequests sync.Map
}
80
81var agentPromptMap = map[string]prompt.PromptID{
82 "coder": prompt.PromptCoder,
83 "task": prompt.PromptTask,
84}
85
86func NewAgent(
87 agentCfg config.Agent,
88 // These services are needed in the tools
89 permissions permission.Service,
90 sessions session.Service,
91 messages message.Service,
92 history history.Service,
93 lspClients map[string]*lsp.Client,
94) (Service, error) {
95 ctx := context.Background()
96 cfg := config.Get()
97 otherTools := GetMCPTools(ctx, permissions, cfg)
98 if len(lspClients) > 0 {
99 otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients))
100 }
101
102 cwd := cfg.WorkingDir()
103 allTools := []tools.BaseTool{
104 tools.NewBashTool(permissions, cwd),
105 tools.NewDownloadTool(permissions, cwd),
106 tools.NewEditTool(lspClients, permissions, history, cwd),
107 tools.NewFetchTool(permissions, cwd),
108 tools.NewGlobTool(cwd),
109 tools.NewGrepTool(cwd),
110 tools.NewLsTool(cwd),
111 tools.NewSourcegraphTool(),
112 tools.NewViewTool(lspClients, cwd),
113 tools.NewWriteTool(lspClients, permissions, history, cwd),
114 }
115
116 if agentCfg.ID == "coder" {
117 taskAgentCfg := config.Get().Agents["task"]
118 if taskAgentCfg.ID == "" {
119 return nil, fmt.Errorf("task agent not found in config")
120 }
121 taskAgent, err := NewAgent(taskAgentCfg, permissions, sessions, messages, history, lspClients)
122 if err != nil {
123 return nil, fmt.Errorf("failed to create task agent: %w", err)
124 }
125
126 allTools = append(
127 allTools,
128 NewAgentTool(
129 taskAgent,
130 sessions,
131 messages,
132 ),
133 )
134 }
135
136 allTools = append(allTools, otherTools...)
137 providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
138 if providerCfg == nil {
139 return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
140 }
141 model := config.Get().GetModelByType(agentCfg.Model)
142
143 if model == nil {
144 return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
145 }
146
147 promptID := agentPromptMap[agentCfg.ID]
148 if promptID == "" {
149 promptID = prompt.PromptDefault
150 }
151 opts := []provider.ProviderClientOption{
152 provider.WithModel(agentCfg.Model),
153 provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
154 }
155 agentProvider, err := provider.NewProvider(*providerCfg, opts...)
156 if err != nil {
157 return nil, err
158 }
159
160 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
161 var smallModelProviderCfg *config.ProviderConfig
162 if smallModelCfg.Provider == providerCfg.ID {
163 smallModelProviderCfg = providerCfg
164 } else {
165 smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
166
167 if smallModelProviderCfg.ID == "" {
168 return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
169 }
170 }
171 smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
172 if smallModel.ID == "" {
173 return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
174 }
175
176 titleOpts := []provider.ProviderClientOption{
177 provider.WithModel(config.SelectedModelTypeSmall),
178 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
179 }
180 titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
181 if err != nil {
182 return nil, err
183 }
184 summarizeOpts := []provider.ProviderClientOption{
185 provider.WithModel(config.SelectedModelTypeSmall),
186 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
187 }
188 summarizeProvider, err := provider.NewProvider(*smallModelProviderCfg, summarizeOpts...)
189 if err != nil {
190 return nil, err
191 }
192
193 agentTools := []tools.BaseTool{}
194 if agentCfg.AllowedTools == nil {
195 agentTools = allTools
196 } else {
197 for _, tool := range allTools {
198 if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
199 agentTools = append(agentTools, tool)
200 }
201 }
202 }
203
204 agent := &agent{
205 Broker: pubsub.NewBroker[AgentEvent](),
206 agentCfg: agentCfg,
207 provider: agentProvider,
208 providerID: string(providerCfg.ID),
209 messages: messages,
210 sessions: sessions,
211 tools: agentTools,
212 titleProvider: titleProvider,
213 summarizeProvider: summarizeProvider,
214 summarizeProviderID: string(smallModelProviderCfg.ID),
215 activeRequests: sync.Map{},
216 }
217
218 return agent, nil
219}
220
221func (a *agent) Model() fur.Model {
222 return *config.Get().GetModelByType(a.agentCfg.Model)
223}
224
225func (a *agent) Cancel(sessionID string) {
226 // Cancel regular requests
227 if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
228 if cancel, ok := cancelFunc.(context.CancelFunc); ok {
229 slog.Info(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
230 cancel()
231 }
232 }
233
234 // Also check for summarize requests
235 if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID + "-summarize"); exists {
236 if cancel, ok := cancelFunc.(context.CancelFunc); ok {
237 slog.Info(fmt.Sprintf("Summarize cancellation initiated for session: %s", sessionID))
238 cancel()
239 }
240 }
241}
242
243func (a *agent) IsBusy() bool {
244 busy := false
245 a.activeRequests.Range(func(key, value any) bool {
246 if cancelFunc, ok := value.(context.CancelFunc); ok {
247 if cancelFunc != nil {
248 busy = true
249 return false
250 }
251 }
252 return true
253 })
254 return busy
255}
256
257func (a *agent) IsSessionBusy(sessionID string) bool {
258 _, busy := a.activeRequests.Load(sessionID)
259 return busy
260}
261
262func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
263 if content == "" {
264 return nil
265 }
266 if a.titleProvider == nil {
267 return nil
268 }
269 session, err := a.sessions.Get(ctx, sessionID)
270 if err != nil {
271 return err
272 }
273 parts := []message.ContentPart{message.TextContent{
274 Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
275 }}
276
277 // Use streaming approach like summarization
278 response := a.titleProvider.StreamResponse(
279 ctx,
280 []message.Message{
281 {
282 Role: message.User,
283 Parts: parts,
284 },
285 },
286 make([]tools.BaseTool, 0),
287 )
288
289 var finalResponse *provider.ProviderResponse
290 for r := range response {
291 if r.Error != nil {
292 return r.Error
293 }
294 finalResponse = r.Response
295 }
296
297 if finalResponse == nil {
298 return fmt.Errorf("no response received from title provider")
299 }
300
301 title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
302 if title == "" {
303 return nil
304 }
305
306 session.Title = title
307 _, err = a.sessions.Save(ctx, session)
308 return err
309}
310
311func (a *agent) err(err error) AgentEvent {
312 return AgentEvent{
313 Type: AgentEventTypeError,
314 Error: err,
315 }
316}
317
318func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
319 if !a.Model().SupportsImages && attachments != nil {
320 attachments = nil
321 }
322 events := make(chan AgentEvent)
323 if a.IsSessionBusy(sessionID) {
324 return nil, ErrSessionBusy
325 }
326
327 genCtx, cancel := context.WithCancel(ctx)
328
329 a.activeRequests.Store(sessionID, cancel)
330 go func() {
331 slog.Debug("Request started", "sessionID", sessionID)
332 defer log.RecoverPanic("agent.Run", func() {
333 events <- a.err(fmt.Errorf("panic while running the agent"))
334 })
335 var attachmentParts []message.ContentPart
336 for _, attachment := range attachments {
337 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
338 }
339 result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
340 if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
341 slog.Error(result.Error.Error())
342 }
343 slog.Debug("Request completed", "sessionID", sessionID)
344 a.activeRequests.Delete(sessionID)
345 cancel()
346 a.Publish(pubsub.CreatedEvent, result)
347 events <- result
348 close(events)
349 }()
350 return events, nil
351}
352
353func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
354 cfg := config.Get()
355 // List existing messages; if none, start title generation asynchronously.
356 msgs, err := a.messages.List(ctx, sessionID)
357 if err != nil {
358 return a.err(fmt.Errorf("failed to list messages: %w", err))
359 }
360 if len(msgs) == 0 {
361 go func() {
362 defer log.RecoverPanic("agent.Run", func() {
363 slog.Error("panic while generating title")
364 })
365 titleErr := a.generateTitle(context.Background(), sessionID, content)
366 if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
367 slog.Error(fmt.Sprintf("failed to generate title: %v", titleErr))
368 }
369 }()
370 }
371 session, err := a.sessions.Get(ctx, sessionID)
372 if err != nil {
373 return a.err(fmt.Errorf("failed to get session: %w", err))
374 }
375 if session.SummaryMessageID != "" {
376 summaryMsgInex := -1
377 for i, msg := range msgs {
378 if msg.ID == session.SummaryMessageID {
379 summaryMsgInex = i
380 break
381 }
382 }
383 if summaryMsgInex != -1 {
384 msgs = msgs[summaryMsgInex:]
385 msgs[0].Role = message.User
386 }
387 }
388
389 userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
390 if err != nil {
391 return a.err(fmt.Errorf("failed to create user message: %w", err))
392 }
393 // Append the new user message to the conversation history.
394 msgHistory := append(msgs, userMsg)
395
396 for {
397 // Check for cancellation before each iteration
398 select {
399 case <-ctx.Done():
400 return a.err(ctx.Err())
401 default:
402 // Continue processing
403 }
404 agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
405 if err != nil {
406 if errors.Is(err, context.Canceled) {
407 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
408 a.messages.Update(context.Background(), agentMessage)
409 return a.err(ErrRequestCancelled)
410 }
411 return a.err(fmt.Errorf("failed to process events: %w", err))
412 }
413 if cfg.Options.Debug {
414 slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
415 }
416 if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
417 // We are not done, we need to respond with the tool response
418 msgHistory = append(msgHistory, agentMessage, *toolResults)
419 continue
420 }
421 return AgentEvent{
422 Type: AgentEventTypeResponse,
423 Message: agentMessage,
424 Done: true,
425 }
426 }
427}
428
429func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
430 parts := []message.ContentPart{message.TextContent{Text: content}}
431 parts = append(parts, attachmentParts...)
432 return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
433 Role: message.User,
434 Parts: parts,
435 })
436}
437
438func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
439 ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
440 eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)
441
442 assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
443 Role: message.Assistant,
444 Parts: []message.ContentPart{},
445 Model: a.Model().ID,
446 Provider: a.providerID,
447 })
448 if err != nil {
449 return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
450 }
451
452 // Add the session and message ID into the context if needed by tools.
453 ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
454
455 // Process each event in the stream.
456 for event := range eventChan {
457 if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
458 if errors.Is(processErr, context.Canceled) {
459 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
460 } else {
461 a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
462 }
463 return assistantMsg, nil, processErr
464 }
465 if ctx.Err() != nil {
466 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
467 return assistantMsg, nil, ctx.Err()
468 }
469 }
470
471 toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
472 toolCalls := assistantMsg.ToolCalls()
473 for i, toolCall := range toolCalls {
474 select {
475 case <-ctx.Done():
476 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
477 // Make all future tool calls cancelled
478 for j := i; j < len(toolCalls); j++ {
479 toolResults[j] = message.ToolResult{
480 ToolCallID: toolCalls[j].ID,
481 Content: "Tool execution canceled by user",
482 IsError: true,
483 }
484 }
485 goto out
486 default:
487 // Continue processing
488 var tool tools.BaseTool
489 for _, availableTool := range a.tools {
490 if availableTool.Info().Name == toolCall.Name {
491 tool = availableTool
492 break
493 }
494 }
495
496 // Tool not found
497 if tool == nil {
498 toolResults[i] = message.ToolResult{
499 ToolCallID: toolCall.ID,
500 Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
501 IsError: true,
502 }
503 continue
504 }
505
506 // Run tool in goroutine to allow cancellation
507 type toolExecResult struct {
508 response tools.ToolResponse
509 err error
510 }
511 resultChan := make(chan toolExecResult, 1)
512
513 go func() {
514 response, err := tool.Run(ctx, tools.ToolCall{
515 ID: toolCall.ID,
516 Name: toolCall.Name,
517 Input: toolCall.Input,
518 })
519 resultChan <- toolExecResult{response: response, err: err}
520 }()
521
522 var toolResponse tools.ToolResponse
523 var toolErr error
524
525 select {
526 case <-ctx.Done():
527 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
528 // Mark remaining tool calls as cancelled
529 for j := i; j < len(toolCalls); j++ {
530 toolResults[j] = message.ToolResult{
531 ToolCallID: toolCalls[j].ID,
532 Content: "Tool execution canceled by user",
533 IsError: true,
534 }
535 }
536 goto out
537 case result := <-resultChan:
538 toolResponse = result.response
539 toolErr = result.err
540 }
541
542 if toolErr != nil {
543 slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
544 if errors.Is(toolErr, permission.ErrorPermissionDenied) {
545 toolResults[i] = message.ToolResult{
546 ToolCallID: toolCall.ID,
547 Content: "Permission denied",
548 IsError: true,
549 }
550 for j := i + 1; j < len(toolCalls); j++ {
551 toolResults[j] = message.ToolResult{
552 ToolCallID: toolCalls[j].ID,
553 Content: "Tool execution canceled by user",
554 IsError: true,
555 }
556 }
557 a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
558 break
559 }
560 }
561 toolResults[i] = message.ToolResult{
562 ToolCallID: toolCall.ID,
563 Content: toolResponse.Content,
564 Metadata: toolResponse.Metadata,
565 IsError: toolResponse.IsError,
566 }
567 }
568 }
569out:
570 if len(toolResults) == 0 {
571 return assistantMsg, nil, nil
572 }
573 parts := make([]message.ContentPart, 0)
574 for _, tr := range toolResults {
575 parts = append(parts, tr)
576 }
577 msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
578 Role: message.Tool,
579 Parts: parts,
580 Provider: a.providerID,
581 })
582 if err != nil {
583 return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
584 }
585
586 return assistantMsg, &msg, err
587}
588
589func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
590 msg.AddFinish(finishReason, message, details)
591 _ = a.messages.Update(ctx, *msg)
592}
593
594func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
595 select {
596 case <-ctx.Done():
597 return ctx.Err()
598 default:
599 // Continue processing.
600 }
601
602 switch event.Type {
603 case provider.EventThinkingDelta:
604 assistantMsg.AppendReasoningContent(event.Thinking)
605 return a.messages.Update(ctx, *assistantMsg)
606 case provider.EventSignatureDelta:
607 assistantMsg.AppendReasoningSignature(event.Signature)
608 return a.messages.Update(ctx, *assistantMsg)
609 case provider.EventContentDelta:
610 assistantMsg.FinishThinking()
611 assistantMsg.AppendContent(event.Content)
612 return a.messages.Update(ctx, *assistantMsg)
613 case provider.EventToolUseStart:
614 assistantMsg.FinishThinking()
615 slog.Info("Tool call started", "toolCall", event.ToolCall)
616 assistantMsg.AddToolCall(*event.ToolCall)
617 return a.messages.Update(ctx, *assistantMsg)
618 case provider.EventToolUseDelta:
619 assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
620 return a.messages.Update(ctx, *assistantMsg)
621 case provider.EventToolUseStop:
622 slog.Info("Finished tool call", "toolCall", event.ToolCall)
623 assistantMsg.FinishToolCall(event.ToolCall.ID)
624 return a.messages.Update(ctx, *assistantMsg)
625 case provider.EventError:
626 return event.Error
627 case provider.EventComplete:
628 assistantMsg.FinishThinking()
629 assistantMsg.SetToolCalls(event.Response.ToolCalls)
630 assistantMsg.AddFinish(event.Response.FinishReason, "", "")
631 if err := a.messages.Update(ctx, *assistantMsg); err != nil {
632 return fmt.Errorf("failed to update message: %w", err)
633 }
634 return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
635 }
636
637 return nil
638}
639
640func (a *agent) TrackUsage(ctx context.Context, sessionID string, model fur.Model, usage provider.TokenUsage) error {
641 sess, err := a.sessions.Get(ctx, sessionID)
642 if err != nil {
643 return fmt.Errorf("failed to get session: %w", err)
644 }
645
646 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
647 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
648 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
649 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
650
651 sess.Cost += cost
652 sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
653 sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
654
655 _, err = a.sessions.Save(ctx, sess)
656 if err != nil {
657 return fmt.Errorf("failed to save session: %w", err)
658 }
659 return nil
660}
661
662func (a *agent) Summarize(ctx context.Context, sessionID string) error {
663 if a.summarizeProvider == nil {
664 return fmt.Errorf("summarize provider not available")
665 }
666
667 // Check if session is busy
668 if a.IsSessionBusy(sessionID) {
669 return ErrSessionBusy
670 }
671
672 // Create a new context with cancellation
673 summarizeCtx, cancel := context.WithCancel(ctx)
674
675 // Store the cancel function in activeRequests to allow cancellation
676 a.activeRequests.Store(sessionID+"-summarize", cancel)
677
678 go func() {
679 defer a.activeRequests.Delete(sessionID + "-summarize")
680 defer cancel()
681 event := AgentEvent{
682 Type: AgentEventTypeSummarize,
683 Progress: "Starting summarization...",
684 }
685
686 a.Publish(pubsub.CreatedEvent, event)
687 // Get all messages from the session
688 msgs, err := a.messages.List(summarizeCtx, sessionID)
689 if err != nil {
690 event = AgentEvent{
691 Type: AgentEventTypeError,
692 Error: fmt.Errorf("failed to list messages: %w", err),
693 Done: true,
694 }
695 a.Publish(pubsub.CreatedEvent, event)
696 return
697 }
698 summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
699
700 if len(msgs) == 0 {
701 event = AgentEvent{
702 Type: AgentEventTypeError,
703 Error: fmt.Errorf("no messages to summarize"),
704 Done: true,
705 }
706 a.Publish(pubsub.CreatedEvent, event)
707 return
708 }
709
710 event = AgentEvent{
711 Type: AgentEventTypeSummarize,
712 Progress: "Analyzing conversation...",
713 }
714 a.Publish(pubsub.CreatedEvent, event)
715
716 // Add a system message to guide the summarization
717 summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
718
719 // Create a new message with the summarize prompt
720 promptMsg := message.Message{
721 Role: message.User,
722 Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
723 }
724
725 // Append the prompt to the messages
726 msgsWithPrompt := append(msgs, promptMsg)
727
728 event = AgentEvent{
729 Type: AgentEventTypeSummarize,
730 Progress: "Generating summary...",
731 }
732
733 a.Publish(pubsub.CreatedEvent, event)
734
735 // Send the messages to the summarize provider
736 response := a.summarizeProvider.StreamResponse(
737 summarizeCtx,
738 msgsWithPrompt,
739 make([]tools.BaseTool, 0),
740 )
741 var finalResponse *provider.ProviderResponse
742 for r := range response {
743 if r.Error != nil {
744 event = AgentEvent{
745 Type: AgentEventTypeError,
746 Error: fmt.Errorf("failed to summarize: %w", err),
747 Done: true,
748 }
749 a.Publish(pubsub.CreatedEvent, event)
750 return
751 }
752 finalResponse = r.Response
753 }
754
755 summary := strings.TrimSpace(finalResponse.Content)
756 if summary == "" {
757 event = AgentEvent{
758 Type: AgentEventTypeError,
759 Error: fmt.Errorf("empty summary returned"),
760 Done: true,
761 }
762 a.Publish(pubsub.CreatedEvent, event)
763 return
764 }
765 event = AgentEvent{
766 Type: AgentEventTypeSummarize,
767 Progress: "Creating new session...",
768 }
769
770 a.Publish(pubsub.CreatedEvent, event)
771 oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
772 if err != nil {
773 event = AgentEvent{
774 Type: AgentEventTypeError,
775 Error: fmt.Errorf("failed to get session: %w", err),
776 Done: true,
777 }
778
779 a.Publish(pubsub.CreatedEvent, event)
780 return
781 }
782 // Create a message in the new session with the summary
783 msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
784 Role: message.Assistant,
785 Parts: []message.ContentPart{
786 message.TextContent{Text: summary},
787 message.Finish{
788 Reason: message.FinishReasonEndTurn,
789 Time: time.Now().Unix(),
790 },
791 },
792 Model: a.summarizeProvider.Model().ID,
793 Provider: a.summarizeProviderID,
794 })
795 if err != nil {
796 event = AgentEvent{
797 Type: AgentEventTypeError,
798 Error: fmt.Errorf("failed to create summary message: %w", err),
799 Done: true,
800 }
801
802 a.Publish(pubsub.CreatedEvent, event)
803 return
804 }
805 oldSession.SummaryMessageID = msg.ID
806 oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
807 oldSession.PromptTokens = 0
808 model := a.summarizeProvider.Model()
809 usage := finalResponse.Usage
810 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
811 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
812 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
813 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
814 oldSession.Cost += cost
815 _, err = a.sessions.Save(summarizeCtx, oldSession)
816 if err != nil {
817 event = AgentEvent{
818 Type: AgentEventTypeError,
819 Error: fmt.Errorf("failed to save session: %w", err),
820 Done: true,
821 }
822 a.Publish(pubsub.CreatedEvent, event)
823 }
824
825 event = AgentEvent{
826 Type: AgentEventTypeSummarize,
827 SessionID: oldSession.ID,
828 Progress: "Summary complete",
829 Done: true,
830 }
831 a.Publish(pubsub.CreatedEvent, event)
832 // Send final success event with the new session ID
833 }()
834
835 return nil
836}
837
838func (a *agent) CancelAll() {
839 if !a.IsBusy() {
840 return
841 }
842 a.activeRequests.Range(func(key, value any) bool {
843 a.Cancel(key.(string)) // key is sessionID
844 return true
845 })
846
847 timeout := time.After(5 * time.Second)
848 for a.IsBusy() {
849 select {
850 case <-timeout:
851 return
852 default:
853 time.Sleep(200 * time.Millisecond)
854 }
855 }
856}
857
858func (a *agent) UpdateModel() error {
859 cfg := config.Get()
860
861 // Get current provider configuration
862 currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
863 if currentProviderCfg == nil || currentProviderCfg.ID == "" {
864 return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
865 }
866
867 // Check if provider has changed
868 if string(currentProviderCfg.ID) != a.providerID {
869 // Provider changed, need to recreate the main provider
870 model := cfg.GetModelByType(a.agentCfg.Model)
871 if model.ID == "" {
872 return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
873 }
874
875 promptID := agentPromptMap[a.agentCfg.ID]
876 if promptID == "" {
877 promptID = prompt.PromptDefault
878 }
879
880 opts := []provider.ProviderClientOption{
881 provider.WithModel(a.agentCfg.Model),
882 provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
883 }
884
885 newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
886 if err != nil {
887 return fmt.Errorf("failed to create new provider: %w", err)
888 }
889
890 // Update the provider and provider ID
891 a.provider = newProvider
892 a.providerID = string(currentProviderCfg.ID)
893 }
894
895 // Check if small model provider has changed (affects title and summarize providers)
896 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
897 var smallModelProviderCfg config.ProviderConfig
898
899 for _, p := range cfg.Providers {
900 if p.ID == smallModelCfg.Provider {
901 smallModelProviderCfg = p
902 break
903 }
904 }
905
906 if smallModelProviderCfg.ID == "" {
907 return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
908 }
909
910 // Check if summarize provider has changed
911 if string(smallModelProviderCfg.ID) != a.summarizeProviderID {
912 smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
913 if smallModel == nil {
914 return fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
915 }
916
917 // Recreate title provider
918 titleOpts := []provider.ProviderClientOption{
919 provider.WithModel(config.SelectedModelTypeSmall),
920 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
921 // We want the title to be short, so we limit the max tokens
922 provider.WithMaxTokens(40),
923 }
924 newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
925 if err != nil {
926 return fmt.Errorf("failed to create new title provider: %w", err)
927 }
928
929 // Recreate summarize provider
930 summarizeOpts := []provider.ProviderClientOption{
931 provider.WithModel(config.SelectedModelTypeSmall),
932 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
933 }
934 newSummarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
935 if err != nil {
936 return fmt.Errorf("failed to create new summarize provider: %w", err)
937 }
938
939 // Update the providers and provider ID
940 a.titleProvider = newTitleProvider
941 a.summarizeProvider = newSummarizeProvider
942 a.summarizeProviderID = string(smallModelProviderCfg.ID)
943 }
944
945 return nil
946}