1package agent
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "strings"
8 "sync"
9 "time"
10
11 "github.com/charmbracelet/crush/internal/config"
12 "github.com/charmbracelet/crush/internal/llm/models"
13 "github.com/charmbracelet/crush/internal/llm/prompt"
14 "github.com/charmbracelet/crush/internal/llm/provider"
15 "github.com/charmbracelet/crush/internal/llm/tools"
16 "github.com/charmbracelet/crush/internal/logging"
17 "github.com/charmbracelet/crush/internal/message"
18 "github.com/charmbracelet/crush/internal/permission"
19 "github.com/charmbracelet/crush/internal/pubsub"
20 "github.com/charmbracelet/crush/internal/session"
21)
22
// ErrRequestCancelled is reported (inside an error AgentEvent) when a
// generation is interrupted by context cancellation.
var ErrRequestCancelled = errors.New("request cancelled by user")

// ErrSessionBusy is returned by Run and Summarize when the session already
// has a request in flight.
var ErrSessionBusy = errors.New("session is currently processing another request")
28
// AgentEventType identifies the kind of an AgentEvent.
type AgentEventType string

// Kinds of AgentEvent emitted by the agent.
const (
	// AgentEventTypeError reports a failure; AgentEvent.Error holds the cause.
	AgentEventTypeError AgentEventType = "error"
	// AgentEventTypeResponse carries the final assistant message of a Run.
	AgentEventTypeResponse AgentEventType = "response"
	// AgentEventTypeSummarize reports progress (and completion) of Summarize.
	AgentEventTypeSummarize AgentEventType = "summarize"
)
36
37type AgentEvent struct {
38 Type AgentEventType
39 Message message.Message
40 Error error
41
42 // When summarizing
43 SessionID string
44 Progress string
45 Done bool
46}
47
48type Service interface {
49 pubsub.Suscriber[AgentEvent]
50 Model() models.Model
51 Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
52 Cancel(sessionID string)
53 IsSessionBusy(sessionID string) bool
54 IsBusy() bool
55 Update(agentName config.AgentName, modelID models.ModelID) (models.Model, error)
56 Summarize(ctx context.Context, sessionID string) error
57}
58
59type agent struct {
60 *pubsub.Broker[AgentEvent]
61 sessions session.Service
62 messages message.Service
63
64 tools []tools.BaseTool
65 provider provider.Provider
66
67 titleProvider provider.Provider
68 summarizeProvider provider.Provider
69
70 activeRequests sync.Map
71}
72
73func NewAgent(
74 agentName config.AgentName,
75 sessions session.Service,
76 messages message.Service,
77 agentTools []tools.BaseTool,
78) (Service, error) {
79 agentProvider, err := createAgentProvider(agentName)
80 if err != nil {
81 return nil, err
82 }
83 var titleProvider provider.Provider
84 // Only generate titles for the coder agent
85 if agentName == config.AgentCoder {
86 titleProvider, err = createAgentProvider(config.AgentTitle)
87 if err != nil {
88 return nil, err
89 }
90 }
91 var summarizeProvider provider.Provider
92 if agentName == config.AgentCoder {
93 summarizeProvider, err = createAgentProvider(config.AgentSummarizer)
94 if err != nil {
95 return nil, err
96 }
97 }
98
99 agent := &agent{
100 Broker: pubsub.NewBroker[AgentEvent](),
101 provider: agentProvider,
102 messages: messages,
103 sessions: sessions,
104 tools: agentTools,
105 titleProvider: titleProvider,
106 summarizeProvider: summarizeProvider,
107 activeRequests: sync.Map{},
108 }
109
110 return agent, nil
111}
112
113func (a *agent) Model() models.Model {
114 return a.provider.Model()
115}
116
117func (a *agent) Cancel(sessionID string) {
118 // Cancel regular requests
119 if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
120 if cancel, ok := cancelFunc.(context.CancelFunc); ok {
121 logging.InfoPersist(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
122 cancel()
123 }
124 }
125
126 // Also check for summarize requests
127 if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID + "-summarize"); exists {
128 if cancel, ok := cancelFunc.(context.CancelFunc); ok {
129 logging.InfoPersist(fmt.Sprintf("Summarize cancellation initiated for session: %s", sessionID))
130 cancel()
131 }
132 }
133}
134
135func (a *agent) IsBusy() bool {
136 busy := false
137 a.activeRequests.Range(func(key, value interface{}) bool {
138 if cancelFunc, ok := value.(context.CancelFunc); ok {
139 if cancelFunc != nil {
140 busy = true
141 return false // Stop iterating
142 }
143 }
144 return true // Continue iterating
145 })
146 return busy
147}
148
149func (a *agent) IsSessionBusy(sessionID string) bool {
150 _, busy := a.activeRequests.Load(sessionID)
151 return busy
152}
153
154func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
155 if content == "" {
156 return nil
157 }
158 if a.titleProvider == nil {
159 return nil
160 }
161 session, err := a.sessions.Get(ctx, sessionID)
162 if err != nil {
163 return err
164 }
165 parts := []message.ContentPart{message.TextContent{Text: content}}
166 response, err := a.titleProvider.SendMessages(
167 ctx,
168 []message.Message{
169 {
170 Role: message.User,
171 Parts: parts,
172 },
173 },
174 make([]tools.BaseTool, 0),
175 )
176 if err != nil {
177 return err
178 }
179
180 title := strings.TrimSpace(strings.ReplaceAll(response.Content, "\n", " "))
181 if title == "" {
182 return nil
183 }
184
185 session.Title = title
186 _, err = a.sessions.Save(ctx, session)
187 return err
188}
189
190func (a *agent) err(err error) AgentEvent {
191 return AgentEvent{
192 Type: AgentEventTypeError,
193 Error: err,
194 }
195}
196
197func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
198 if !a.provider.Model().SupportsAttachments && attachments != nil {
199 attachments = nil
200 }
201 events := make(chan AgentEvent)
202 if a.IsSessionBusy(sessionID) {
203 return nil, ErrSessionBusy
204 }
205
206 genCtx, cancel := context.WithCancel(ctx)
207
208 a.activeRequests.Store(sessionID, cancel)
209 go func() {
210 logging.Debug("Request started", "sessionID", sessionID)
211 defer logging.RecoverPanic("agent.Run", func() {
212 events <- a.err(fmt.Errorf("panic while running the agent"))
213 })
214 var attachmentParts []message.ContentPart
215 for _, attachment := range attachments {
216 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
217 }
218 result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
219 if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
220 logging.ErrorPersist(result.Error.Error())
221 }
222 logging.Debug("Request completed", "sessionID", sessionID)
223 a.activeRequests.Delete(sessionID)
224 cancel()
225 a.Publish(pubsub.CreatedEvent, result)
226 events <- result
227 close(events)
228 }()
229 return events, nil
230}
231
232func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
233 // List existing messages; if none, start title generation asynchronously.
234 msgs, err := a.messages.List(ctx, sessionID)
235 if err != nil {
236 return a.err(fmt.Errorf("failed to list messages: %w", err))
237 }
238 if len(msgs) == 0 {
239 go func() {
240 defer logging.RecoverPanic("agent.Run", func() {
241 logging.ErrorPersist("panic while generating title")
242 })
243 titleErr := a.generateTitle(context.Background(), sessionID, content)
244 if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
245 logging.ErrorPersist(fmt.Sprintf("failed to generate title: %v", titleErr))
246 }
247 }()
248 }
249 session, err := a.sessions.Get(ctx, sessionID)
250 if err != nil {
251 return a.err(fmt.Errorf("failed to get session: %w", err))
252 }
253 if session.SummaryMessageID != "" {
254 summaryMsgInex := -1
255 for i, msg := range msgs {
256 if msg.ID == session.SummaryMessageID {
257 summaryMsgInex = i
258 break
259 }
260 }
261 if summaryMsgInex != -1 {
262 msgs = msgs[summaryMsgInex:]
263 msgs[0].Role = message.User
264 }
265 }
266
267 userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
268 if err != nil {
269 return a.err(fmt.Errorf("failed to create user message: %w", err))
270 }
271 // Append the new user message to the conversation history.
272 msgHistory := append(msgs, userMsg)
273
274 for {
275 // Check for cancellation before each iteration
276 select {
277 case <-ctx.Done():
278 return a.err(ctx.Err())
279 default:
280 // Continue processing
281 }
282 agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
283 if err != nil {
284 if errors.Is(err, context.Canceled) {
285 agentMessage.AddFinish(message.FinishReasonCanceled)
286 a.messages.Update(context.Background(), agentMessage)
287 return a.err(ErrRequestCancelled)
288 }
289 return a.err(fmt.Errorf("failed to process events: %w", err))
290 }
291 logging.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
292 if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
293 // We are not done, we need to respond with the tool response
294 msgHistory = append(msgHistory, agentMessage, *toolResults)
295 continue
296 }
297 return AgentEvent{
298 Type: AgentEventTypeResponse,
299 Message: agentMessage,
300 Done: true,
301 }
302 }
303}
304
305func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
306 parts := []message.ContentPart{message.TextContent{Text: content}}
307 parts = append(parts, attachmentParts...)
308 return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
309 Role: message.User,
310 Parts: parts,
311 })
312}
313
314func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
315 eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)
316
317 assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
318 Role: message.Assistant,
319 Parts: []message.ContentPart{},
320 Model: a.provider.Model().ID,
321 })
322 if err != nil {
323 return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
324 }
325
326 // Add the session and message ID into the context if needed by tools.
327 ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
328 ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
329
330 // Process each event in the stream.
331 for event := range eventChan {
332 if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
333 a.finishMessage(ctx, &assistantMsg, message.FinishReasonCanceled)
334 return assistantMsg, nil, processErr
335 }
336 if ctx.Err() != nil {
337 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
338 return assistantMsg, nil, ctx.Err()
339 }
340 }
341
342 toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
343 toolCalls := assistantMsg.ToolCalls()
344 for i, toolCall := range toolCalls {
345 select {
346 case <-ctx.Done():
347 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
348 // Make all future tool calls cancelled
349 for j := i; j < len(toolCalls); j++ {
350 toolResults[j] = message.ToolResult{
351 ToolCallID: toolCalls[j].ID,
352 Content: "Tool execution canceled by user",
353 IsError: true,
354 }
355 }
356 goto out
357 default:
358 // Continue processing
359 var tool tools.BaseTool
360 for _, availableTools := range a.tools {
361 if availableTools.Info().Name == toolCall.Name {
362 tool = availableTools
363 }
364 }
365
366 // Tool not found
367 if tool == nil {
368 toolResults[i] = message.ToolResult{
369 ToolCallID: toolCall.ID,
370 Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
371 IsError: true,
372 }
373 continue
374 }
375 toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
376 ID: toolCall.ID,
377 Name: toolCall.Name,
378 Input: toolCall.Input,
379 })
380 if toolErr != nil {
381 if errors.Is(toolErr, permission.ErrorPermissionDenied) {
382 toolResults[i] = message.ToolResult{
383 ToolCallID: toolCall.ID,
384 Content: "Permission denied",
385 IsError: true,
386 }
387 for j := i + 1; j < len(toolCalls); j++ {
388 toolResults[j] = message.ToolResult{
389 ToolCallID: toolCalls[j].ID,
390 Content: "Tool execution canceled by user",
391 IsError: true,
392 }
393 }
394 a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied)
395 break
396 }
397 }
398 toolResults[i] = message.ToolResult{
399 ToolCallID: toolCall.ID,
400 Content: toolResult.Content,
401 Metadata: toolResult.Metadata,
402 IsError: toolResult.IsError,
403 }
404 }
405 }
406out:
407 if len(toolResults) == 0 {
408 return assistantMsg, nil, nil
409 }
410 parts := make([]message.ContentPart, 0)
411 for _, tr := range toolResults {
412 parts = append(parts, tr)
413 }
414 msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
415 Role: message.Tool,
416 Parts: parts,
417 })
418 if err != nil {
419 return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
420 }
421
422 return assistantMsg, &msg, err
423}
424
425func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReson message.FinishReason) {
426 msg.AddFinish(finishReson)
427 _ = a.messages.Update(ctx, *msg)
428}
429
430func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
431 select {
432 case <-ctx.Done():
433 return ctx.Err()
434 default:
435 // Continue processing.
436 }
437
438 switch event.Type {
439 case provider.EventThinkingDelta:
440 assistantMsg.AppendReasoningContent(event.Content)
441 return a.messages.Update(ctx, *assistantMsg)
442 case provider.EventContentDelta:
443 assistantMsg.AppendContent(event.Content)
444 return a.messages.Update(ctx, *assistantMsg)
445 case provider.EventToolUseStart:
446 logging.Info("Tool call started", "toolCall", event.ToolCall)
447 assistantMsg.AddToolCall(*event.ToolCall)
448 return a.messages.Update(ctx, *assistantMsg)
449 case provider.EventToolUseDelta:
450 assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
451 return a.messages.Update(ctx, *assistantMsg)
452 case provider.EventToolUseStop:
453 logging.Info("Finished tool call", "toolCall", event.ToolCall)
454 assistantMsg.FinishToolCall(event.ToolCall.ID)
455 return a.messages.Update(ctx, *assistantMsg)
456 case provider.EventError:
457 if errors.Is(event.Error, context.Canceled) {
458 logging.InfoPersist(fmt.Sprintf("Event processing canceled for session: %s", sessionID))
459 return context.Canceled
460 }
461 logging.ErrorPersist(event.Error.Error())
462 return event.Error
463 case provider.EventComplete:
464 assistantMsg.SetToolCalls(event.Response.ToolCalls)
465 assistantMsg.AddFinish(event.Response.FinishReason)
466 if err := a.messages.Update(ctx, *assistantMsg); err != nil {
467 return fmt.Errorf("failed to update message: %w", err)
468 }
469 return a.TrackUsage(ctx, sessionID, a.provider.Model(), event.Response.Usage)
470 }
471
472 return nil
473}
474
475func (a *agent) TrackUsage(ctx context.Context, sessionID string, model models.Model, usage provider.TokenUsage) error {
476 sess, err := a.sessions.Get(ctx, sessionID)
477 if err != nil {
478 return fmt.Errorf("failed to get session: %w", err)
479 }
480
481 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
482 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
483 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
484 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
485
486 sess.Cost += cost
487 sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
488 sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
489
490 _, err = a.sessions.Save(ctx, sess)
491 if err != nil {
492 return fmt.Errorf("failed to save session: %w", err)
493 }
494 return nil
495}
496
497func (a *agent) Update(agentName config.AgentName, modelID models.ModelID) (models.Model, error) {
498 if a.IsBusy() {
499 return models.Model{}, fmt.Errorf("cannot change model while processing requests")
500 }
501
502 if err := config.UpdateAgentModel(agentName, modelID); err != nil {
503 return models.Model{}, fmt.Errorf("failed to update config: %w", err)
504 }
505
506 provider, err := createAgentProvider(agentName)
507 if err != nil {
508 return models.Model{}, fmt.Errorf("failed to create provider for model %s: %w", modelID, err)
509 }
510
511 a.provider = provider
512
513 return a.provider.Model(), nil
514}
515
516func (a *agent) Summarize(ctx context.Context, sessionID string) error {
517 if a.summarizeProvider == nil {
518 return fmt.Errorf("summarize provider not available")
519 }
520
521 // Check if session is busy
522 if a.IsSessionBusy(sessionID) {
523 return ErrSessionBusy
524 }
525
526 // Create a new context with cancellation
527 summarizeCtx, cancel := context.WithCancel(ctx)
528
529 // Store the cancel function in activeRequests to allow cancellation
530 a.activeRequests.Store(sessionID+"-summarize", cancel)
531
532 go func() {
533 defer a.activeRequests.Delete(sessionID + "-summarize")
534 defer cancel()
535 event := AgentEvent{
536 Type: AgentEventTypeSummarize,
537 Progress: "Starting summarization...",
538 }
539
540 a.Publish(pubsub.CreatedEvent, event)
541 // Get all messages from the session
542 msgs, err := a.messages.List(summarizeCtx, sessionID)
543 if err != nil {
544 event = AgentEvent{
545 Type: AgentEventTypeError,
546 Error: fmt.Errorf("failed to list messages: %w", err),
547 Done: true,
548 }
549 a.Publish(pubsub.CreatedEvent, event)
550 return
551 }
552
553 if len(msgs) == 0 {
554 event = AgentEvent{
555 Type: AgentEventTypeError,
556 Error: fmt.Errorf("no messages to summarize"),
557 Done: true,
558 }
559 a.Publish(pubsub.CreatedEvent, event)
560 return
561 }
562
563 event = AgentEvent{
564 Type: AgentEventTypeSummarize,
565 Progress: "Analyzing conversation...",
566 }
567 a.Publish(pubsub.CreatedEvent, event)
568
569 // Add a system message to guide the summarization
570 summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
571
572 // Create a new message with the summarize prompt
573 promptMsg := message.Message{
574 Role: message.User,
575 Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
576 }
577
578 // Append the prompt to the messages
579 msgsWithPrompt := append(msgs, promptMsg)
580
581 event = AgentEvent{
582 Type: AgentEventTypeSummarize,
583 Progress: "Generating summary...",
584 }
585
586 a.Publish(pubsub.CreatedEvent, event)
587
588 // Send the messages to the summarize provider
589 response := a.summarizeProvider.StreamResponse(
590 summarizeCtx,
591 msgsWithPrompt,
592 make([]tools.BaseTool, 0),
593 )
594 var finalResponse *provider.ProviderResponse
595 for r := range response {
596 if r.Error != nil {
597 event = AgentEvent{
598 Type: AgentEventTypeError,
599 Error: fmt.Errorf("failed to summarize: %w", err),
600 Done: true,
601 }
602 a.Publish(pubsub.CreatedEvent, event)
603 return
604 }
605 finalResponse = r.Response
606 }
607
608 summary := strings.TrimSpace(finalResponse.Content)
609 if summary == "" {
610 event = AgentEvent{
611 Type: AgentEventTypeError,
612 Error: fmt.Errorf("empty summary returned"),
613 Done: true,
614 }
615 a.Publish(pubsub.CreatedEvent, event)
616 return
617 }
618 event = AgentEvent{
619 Type: AgentEventTypeSummarize,
620 Progress: "Creating new session...",
621 }
622
623 a.Publish(pubsub.CreatedEvent, event)
624 oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
625 if err != nil {
626 event = AgentEvent{
627 Type: AgentEventTypeError,
628 Error: fmt.Errorf("failed to get session: %w", err),
629 Done: true,
630 }
631
632 a.Publish(pubsub.CreatedEvent, event)
633 return
634 }
635 // Create a message in the new session with the summary
636 msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
637 Role: message.Assistant,
638 Parts: []message.ContentPart{
639 message.TextContent{Text: summary},
640 message.Finish{
641 Reason: message.FinishReasonEndTurn,
642 Time: time.Now().Unix(),
643 },
644 },
645 Model: a.summarizeProvider.Model().ID,
646 })
647 if err != nil {
648 event = AgentEvent{
649 Type: AgentEventTypeError,
650 Error: fmt.Errorf("failed to create summary message: %w", err),
651 Done: true,
652 }
653
654 a.Publish(pubsub.CreatedEvent, event)
655 return
656 }
657 oldSession.SummaryMessageID = msg.ID
658 oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
659 oldSession.PromptTokens = 0
660 model := a.summarizeProvider.Model()
661 usage := finalResponse.Usage
662 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
663 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
664 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
665 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
666 oldSession.Cost += cost
667 _, err = a.sessions.Save(summarizeCtx, oldSession)
668 if err != nil {
669 event = AgentEvent{
670 Type: AgentEventTypeError,
671 Error: fmt.Errorf("failed to save session: %w", err),
672 Done: true,
673 }
674 a.Publish(pubsub.CreatedEvent, event)
675 }
676
677 event = AgentEvent{
678 Type: AgentEventTypeSummarize,
679 SessionID: oldSession.ID,
680 Progress: "Summary complete",
681 Done: true,
682 }
683 a.Publish(pubsub.CreatedEvent, event)
684 // Send final success event with the new session ID
685 }()
686
687 return nil
688}
689
690func createAgentProvider(agentName config.AgentName) (provider.Provider, error) {
691 cfg := config.Get()
692 agentConfig, ok := cfg.Agents[agentName]
693 if !ok {
694 return nil, fmt.Errorf("agent %s not found", agentName)
695 }
696 model, ok := models.SupportedModels[agentConfig.Model]
697 if !ok {
698 return nil, fmt.Errorf("model %s not supported", agentConfig.Model)
699 }
700
701 providerCfg, ok := cfg.Providers[model.Provider]
702 if !ok {
703 return nil, fmt.Errorf("provider %s not supported", model.Provider)
704 }
705 if providerCfg.Disabled {
706 return nil, fmt.Errorf("provider %s is not enabled", model.Provider)
707 }
708 maxTokens := model.DefaultMaxTokens
709 if agentConfig.MaxTokens > 0 {
710 maxTokens = agentConfig.MaxTokens
711 }
712 opts := []provider.ProviderClientOption{
713 provider.WithAPIKey(providerCfg.APIKey),
714 provider.WithModel(model),
715 provider.WithSystemMessage(prompt.GetAgentPrompt(agentName, model.Provider)),
716 provider.WithMaxTokens(maxTokens),
717 }
718 if model.Provider == models.ProviderOpenAI || model.Provider == models.ProviderLocal && model.CanReason {
719 opts = append(
720 opts,
721 provider.WithOpenAIOptions(
722 provider.WithReasoningEffort(agentConfig.ReasoningEffort),
723 ),
724 )
725 } else if model.Provider == models.ProviderAnthropic && model.CanReason && agentName == config.AgentCoder {
726 opts = append(
727 opts,
728 provider.WithAnthropicOptions(
729 provider.WithAnthropicShouldThinkFn(provider.DefaultShouldThinkFn),
730 ),
731 )
732 }
733 agentProvider, err := provider.NewProvider(
734 model.Provider,
735 opts...,
736 )
737 if err != nil {
738 return nil, fmt.Errorf("could not create provider: %v", err)
739 }
740
741 return agentProvider, nil
742}