1package agent
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "strings"
8 "sync"
9 "time"
10
11 "github.com/charmbracelet/crush/internal/config"
12 "github.com/charmbracelet/crush/internal/llm/models"
13 "github.com/charmbracelet/crush/internal/llm/prompt"
14 "github.com/charmbracelet/crush/internal/llm/provider"
15 "github.com/charmbracelet/crush/internal/llm/tools"
16 "github.com/charmbracelet/crush/internal/logging"
17 "github.com/charmbracelet/crush/internal/message"
18 "github.com/charmbracelet/crush/internal/permission"
19 "github.com/charmbracelet/crush/internal/pubsub"
20 "github.com/charmbracelet/crush/internal/session"
21)
22
// Common errors
var (
	// ErrRequestCancelled is reported (wrapped in an AgentEventTypeError event)
	// when a generation's context is cancelled, e.g. through Cancel.
	ErrRequestCancelled = errors.New("request cancelled by user")
	// ErrSessionBusy is returned by Run and Summarize when a regular request is
	// already in flight for the session.
	ErrSessionBusy = errors.New("session is currently processing another request")
)
28
// AgentEventType identifies the kind of AgentEvent emitted by the agent.
type AgentEventType string

const (
	// AgentEventTypeError reports a failure; AgentEvent.Error holds the cause.
	AgentEventTypeError AgentEventType = "error"
	// AgentEventTypeResponse carries the final assistant message of a Run.
	AgentEventTypeResponse AgentEventType = "response"
	// AgentEventTypeSummarize reports progress or completion of Summarize.
	AgentEventTypeSummarize AgentEventType = "summarize"
)
36
// AgentEvent is delivered on Run's result channel and published on the
// agent's broker.
type AgentEvent struct {
	// Type selects which of the remaining fields are meaningful.
	Type AgentEventType
	// Message is the final assistant message for AgentEventTypeResponse.
	Message message.Message
	// Error is the failure cause for AgentEventTypeError.
	Error error

	// When summarizing
	// SessionID is set on the final successful summarize event.
	SessionID string
	// Progress is a human-readable status line for summarize events.
	Progress string
	// Done marks the terminal event of an operation.
	Done bool
}
47
// Service is the public interface of an agent. Events produced by Run and
// Summarize are also published to broker subscribers.
type Service interface {
	pubsub.Suscriber[AgentEvent]
	// Model returns the model used by the agent's main provider.
	Model() models.Model
	// Run processes content (plus attachments, if the model supports them)
	// for sessionID in the background and returns a channel on which the
	// result event is delivered.
	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
	// Cancel aborts the in-flight request and/or summarization for sessionID.
	Cancel(sessionID string)
	// IsSessionBusy reports whether a regular (non-summarize) request is
	// running for sessionID.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any request or summarization is running.
	IsBusy() bool
	// Update switches agentName to modelID in the config and rebuilds the
	// main provider; it fails while requests are running.
	Update(agentName config.AgentName, modelID models.ModelID) (models.Model, error)
	// Summarize starts summarizing the session's conversation in the
	// background; progress is reported through published events.
	Summarize(ctx context.Context, sessionID string) error
}
58
// agent implements Service. The embedded broker publishes AgentEvents.
type agent struct {
	*pubsub.Broker[AgentEvent]
	sessions session.Service
	messages message.Service

	// tools are offered to the main provider and executed for its tool calls.
	tools []tools.BaseTool
	// provider serves the main conversation; Update replaces it.
	provider provider.Provider

	// titleProvider and summarizeProvider are only created for the coder
	// agent and are nil otherwise.
	titleProvider     provider.Provider
	summarizeProvider provider.Provider

	// activeRequests maps a sessionID (or sessionID+"-summarize") to the
	// context.CancelFunc of the request running for it.
	activeRequests sync.Map
}
72
73func NewAgent(
74 agentName config.AgentName,
75 sessions session.Service,
76 messages message.Service,
77 agentTools []tools.BaseTool,
78) (Service, error) {
79 agentProvider, err := createAgentProvider(agentName)
80 if err != nil {
81 return nil, err
82 }
83 var titleProvider provider.Provider
84 // Only generate titles for the coder agent
85 if agentName == config.AgentCoder {
86 titleProvider, err = createAgentProvider(config.AgentTitle)
87 if err != nil {
88 return nil, err
89 }
90 }
91 var summarizeProvider provider.Provider
92 if agentName == config.AgentCoder {
93 summarizeProvider, err = createAgentProvider(config.AgentSummarizer)
94 if err != nil {
95 return nil, err
96 }
97 }
98
99 agent := &agent{
100 Broker: pubsub.NewBroker[AgentEvent](),
101 provider: agentProvider,
102 messages: messages,
103 sessions: sessions,
104 tools: agentTools,
105 titleProvider: titleProvider,
106 summarizeProvider: summarizeProvider,
107 activeRequests: sync.Map{},
108 }
109
110 return agent, nil
111}
112
113func (a *agent) Model() models.Model {
114 return a.provider.Model()
115}
116
117func (a *agent) Cancel(sessionID string) {
118 // Cancel regular requests
119 if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
120 if cancel, ok := cancelFunc.(context.CancelFunc); ok {
121 logging.InfoPersist(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
122 cancel()
123 }
124 }
125
126 // Also check for summarize requests
127 if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID + "-summarize"); exists {
128 if cancel, ok := cancelFunc.(context.CancelFunc); ok {
129 logging.InfoPersist(fmt.Sprintf("Summarize cancellation initiated for session: %s", sessionID))
130 cancel()
131 }
132 }
133}
134
135func (a *agent) IsBusy() bool {
136 busy := false
137 a.activeRequests.Range(func(key, value interface{}) bool {
138 if cancelFunc, ok := value.(context.CancelFunc); ok {
139 if cancelFunc != nil {
140 busy = true
141 return false // Stop iterating
142 }
143 }
144 return true // Continue iterating
145 })
146 return busy
147}
148
149func (a *agent) IsSessionBusy(sessionID string) bool {
150 _, busy := a.activeRequests.Load(sessionID)
151 return busy
152}
153
154func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
155 if content == "" {
156 return nil
157 }
158 if a.titleProvider == nil {
159 return nil
160 }
161 session, err := a.sessions.Get(ctx, sessionID)
162 if err != nil {
163 return err
164 }
165 parts := []message.ContentPart{message.TextContent{Text: content}}
166
167 // Use streaming approach like summarization
168 response := a.titleProvider.StreamResponse(
169 ctx,
170 []message.Message{
171 {
172 Role: message.User,
173 Parts: parts,
174 },
175 },
176 make([]tools.BaseTool, 0),
177 )
178
179 var finalResponse *provider.ProviderResponse
180 for r := range response {
181 if r.Error != nil {
182 return r.Error
183 }
184 finalResponse = r.Response
185 }
186
187 if finalResponse == nil {
188 return fmt.Errorf("no response received from title provider")
189 }
190
191 title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
192 if title == "" {
193 return nil
194 }
195
196 session.Title = title
197 _, err = a.sessions.Save(ctx, session)
198 return err
199}
200
201func (a *agent) err(err error) AgentEvent {
202 return AgentEvent{
203 Type: AgentEventTypeError,
204 Error: err,
205 }
206}
207
208func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
209 if !a.provider.Model().SupportsAttachments && attachments != nil {
210 attachments = nil
211 }
212 events := make(chan AgentEvent)
213 if a.IsSessionBusy(sessionID) {
214 return nil, ErrSessionBusy
215 }
216
217 genCtx, cancel := context.WithCancel(ctx)
218
219 a.activeRequests.Store(sessionID, cancel)
220 go func() {
221 logging.Debug("Request started", "sessionID", sessionID)
222 defer logging.RecoverPanic("agent.Run", func() {
223 events <- a.err(fmt.Errorf("panic while running the agent"))
224 })
225 var attachmentParts []message.ContentPart
226 for _, attachment := range attachments {
227 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
228 }
229 result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
230 if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
231 logging.ErrorPersist(result.Error.Error())
232 }
233 logging.Debug("Request completed", "sessionID", sessionID)
234 a.activeRequests.Delete(sessionID)
235 cancel()
236 a.Publish(pubsub.CreatedEvent, result)
237 events <- result
238 close(events)
239 }()
240 return events, nil
241}
242
243func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
244 // List existing messages; if none, start title generation asynchronously.
245 msgs, err := a.messages.List(ctx, sessionID)
246 if err != nil {
247 return a.err(fmt.Errorf("failed to list messages: %w", err))
248 }
249 if len(msgs) == 0 {
250 go func() {
251 defer logging.RecoverPanic("agent.Run", func() {
252 logging.ErrorPersist("panic while generating title")
253 })
254 titleErr := a.generateTitle(context.Background(), sessionID, content)
255 if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
256 logging.ErrorPersist(fmt.Sprintf("failed to generate title: %v", titleErr))
257 }
258 }()
259 }
260 session, err := a.sessions.Get(ctx, sessionID)
261 if err != nil {
262 return a.err(fmt.Errorf("failed to get session: %w", err))
263 }
264 if session.SummaryMessageID != "" {
265 summaryMsgInex := -1
266 for i, msg := range msgs {
267 if msg.ID == session.SummaryMessageID {
268 summaryMsgInex = i
269 break
270 }
271 }
272 if summaryMsgInex != -1 {
273 msgs = msgs[summaryMsgInex:]
274 msgs[0].Role = message.User
275 }
276 }
277
278 userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
279 if err != nil {
280 return a.err(fmt.Errorf("failed to create user message: %w", err))
281 }
282 // Append the new user message to the conversation history.
283 msgHistory := append(msgs, userMsg)
284
285 for {
286 // Check for cancellation before each iteration
287 select {
288 case <-ctx.Done():
289 return a.err(ctx.Err())
290 default:
291 // Continue processing
292 }
293 agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
294 if err != nil {
295 if errors.Is(err, context.Canceled) {
296 agentMessage.AddFinish(message.FinishReasonCanceled)
297 a.messages.Update(context.Background(), agentMessage)
298 return a.err(ErrRequestCancelled)
299 }
300 return a.err(fmt.Errorf("failed to process events: %w", err))
301 }
302 logging.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
303 if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
304 // We are not done, we need to respond with the tool response
305 msgHistory = append(msgHistory, agentMessage, *toolResults)
306 continue
307 }
308 return AgentEvent{
309 Type: AgentEventTypeResponse,
310 Message: agentMessage,
311 Done: true,
312 }
313 }
314}
315
316func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
317 parts := []message.ContentPart{message.TextContent{Text: content}}
318 parts = append(parts, attachmentParts...)
319 return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
320 Role: message.User,
321 Parts: parts,
322 })
323}
324
325func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
326 eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)
327
328 assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
329 Role: message.Assistant,
330 Parts: []message.ContentPart{},
331 Model: a.provider.Model().ID,
332 })
333 if err != nil {
334 return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
335 }
336
337 // Add the session and message ID into the context if needed by tools.
338 ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
339 ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
340
341 // Process each event in the stream.
342 for event := range eventChan {
343 if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
344 a.finishMessage(ctx, &assistantMsg, message.FinishReasonCanceled)
345 return assistantMsg, nil, processErr
346 }
347 if ctx.Err() != nil {
348 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
349 return assistantMsg, nil, ctx.Err()
350 }
351 }
352
353 toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
354 toolCalls := assistantMsg.ToolCalls()
355 for i, toolCall := range toolCalls {
356 select {
357 case <-ctx.Done():
358 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
359 // Make all future tool calls cancelled
360 for j := i; j < len(toolCalls); j++ {
361 toolResults[j] = message.ToolResult{
362 ToolCallID: toolCalls[j].ID,
363 Content: "Tool execution canceled by user",
364 IsError: true,
365 }
366 }
367 goto out
368 default:
369 // Continue processing
370 var tool tools.BaseTool
371 for _, availableTools := range a.tools {
372 if availableTools.Info().Name == toolCall.Name {
373 tool = availableTools
374 }
375 }
376
377 // Tool not found
378 if tool == nil {
379 toolResults[i] = message.ToolResult{
380 ToolCallID: toolCall.ID,
381 Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
382 IsError: true,
383 }
384 continue
385 }
386 toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
387 ID: toolCall.ID,
388 Name: toolCall.Name,
389 Input: toolCall.Input,
390 })
391 if toolErr != nil {
392 if errors.Is(toolErr, permission.ErrorPermissionDenied) {
393 toolResults[i] = message.ToolResult{
394 ToolCallID: toolCall.ID,
395 Content: "Permission denied",
396 IsError: true,
397 }
398 for j := i + 1; j < len(toolCalls); j++ {
399 toolResults[j] = message.ToolResult{
400 ToolCallID: toolCalls[j].ID,
401 Content: "Tool execution canceled by user",
402 IsError: true,
403 }
404 }
405 a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied)
406 break
407 }
408 }
409 toolResults[i] = message.ToolResult{
410 ToolCallID: toolCall.ID,
411 Content: toolResult.Content,
412 Metadata: toolResult.Metadata,
413 IsError: toolResult.IsError,
414 }
415 }
416 }
417out:
418 if len(toolResults) == 0 {
419 return assistantMsg, nil, nil
420 }
421 parts := make([]message.ContentPart, 0)
422 for _, tr := range toolResults {
423 parts = append(parts, tr)
424 }
425 msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
426 Role: message.Tool,
427 Parts: parts,
428 })
429 if err != nil {
430 return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
431 }
432
433 return assistantMsg, &msg, err
434}
435
436func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReson message.FinishReason) {
437 msg.AddFinish(finishReson)
438 _ = a.messages.Update(ctx, *msg)
439}
440
441func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
442 select {
443 case <-ctx.Done():
444 return ctx.Err()
445 default:
446 // Continue processing.
447 }
448
449 switch event.Type {
450 case provider.EventThinkingDelta:
451 assistantMsg.AppendReasoningContent(event.Content)
452 return a.messages.Update(ctx, *assistantMsg)
453 case provider.EventContentDelta:
454 assistantMsg.AppendContent(event.Content)
455 return a.messages.Update(ctx, *assistantMsg)
456 case provider.EventToolUseStart:
457 logging.Info("Tool call started", "toolCall", event.ToolCall)
458 assistantMsg.AddToolCall(*event.ToolCall)
459 return a.messages.Update(ctx, *assistantMsg)
460 case provider.EventToolUseDelta:
461 assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
462 return a.messages.Update(ctx, *assistantMsg)
463 case provider.EventToolUseStop:
464 logging.Info("Finished tool call", "toolCall", event.ToolCall)
465 assistantMsg.FinishToolCall(event.ToolCall.ID)
466 return a.messages.Update(ctx, *assistantMsg)
467 case provider.EventError:
468 if errors.Is(event.Error, context.Canceled) {
469 logging.InfoPersist(fmt.Sprintf("Event processing canceled for session: %s", sessionID))
470 return context.Canceled
471 }
472 logging.ErrorPersist(event.Error.Error())
473 return event.Error
474 case provider.EventComplete:
475 assistantMsg.SetToolCalls(event.Response.ToolCalls)
476 assistantMsg.AddFinish(event.Response.FinishReason)
477 if err := a.messages.Update(ctx, *assistantMsg); err != nil {
478 return fmt.Errorf("failed to update message: %w", err)
479 }
480 return a.TrackUsage(ctx, sessionID, a.provider.Model(), event.Response.Usage)
481 }
482
483 return nil
484}
485
486func (a *agent) TrackUsage(ctx context.Context, sessionID string, model models.Model, usage provider.TokenUsage) error {
487 sess, err := a.sessions.Get(ctx, sessionID)
488 if err != nil {
489 return fmt.Errorf("failed to get session: %w", err)
490 }
491
492 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
493 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
494 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
495 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
496
497 sess.Cost += cost
498 sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
499 sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
500
501 _, err = a.sessions.Save(ctx, sess)
502 if err != nil {
503 return fmt.Errorf("failed to save session: %w", err)
504 }
505 return nil
506}
507
508func (a *agent) Update(agentName config.AgentName, modelID models.ModelID) (models.Model, error) {
509 if a.IsBusy() {
510 return models.Model{}, fmt.Errorf("cannot change model while processing requests")
511 }
512
513 if err := config.UpdateAgentModel(agentName, modelID); err != nil {
514 return models.Model{}, fmt.Errorf("failed to update config: %w", err)
515 }
516
517 provider, err := createAgentProvider(agentName)
518 if err != nil {
519 return models.Model{}, fmt.Errorf("failed to create provider for model %s: %w", modelID, err)
520 }
521
522 a.provider = provider
523
524 return a.provider.Model(), nil
525}
526
527func (a *agent) Summarize(ctx context.Context, sessionID string) error {
528 if a.summarizeProvider == nil {
529 return fmt.Errorf("summarize provider not available")
530 }
531
532 // Check if session is busy
533 if a.IsSessionBusy(sessionID) {
534 return ErrSessionBusy
535 }
536
537 // Create a new context with cancellation
538 summarizeCtx, cancel := context.WithCancel(ctx)
539
540 // Store the cancel function in activeRequests to allow cancellation
541 a.activeRequests.Store(sessionID+"-summarize", cancel)
542
543 go func() {
544 defer a.activeRequests.Delete(sessionID + "-summarize")
545 defer cancel()
546 event := AgentEvent{
547 Type: AgentEventTypeSummarize,
548 Progress: "Starting summarization...",
549 }
550
551 a.Publish(pubsub.CreatedEvent, event)
552 // Get all messages from the session
553 msgs, err := a.messages.List(summarizeCtx, sessionID)
554 if err != nil {
555 event = AgentEvent{
556 Type: AgentEventTypeError,
557 Error: fmt.Errorf("failed to list messages: %w", err),
558 Done: true,
559 }
560 a.Publish(pubsub.CreatedEvent, event)
561 return
562 }
563
564 if len(msgs) == 0 {
565 event = AgentEvent{
566 Type: AgentEventTypeError,
567 Error: fmt.Errorf("no messages to summarize"),
568 Done: true,
569 }
570 a.Publish(pubsub.CreatedEvent, event)
571 return
572 }
573
574 event = AgentEvent{
575 Type: AgentEventTypeSummarize,
576 Progress: "Analyzing conversation...",
577 }
578 a.Publish(pubsub.CreatedEvent, event)
579
580 // Add a system message to guide the summarization
581 summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
582
583 // Create a new message with the summarize prompt
584 promptMsg := message.Message{
585 Role: message.User,
586 Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
587 }
588
589 // Append the prompt to the messages
590 msgsWithPrompt := append(msgs, promptMsg)
591
592 event = AgentEvent{
593 Type: AgentEventTypeSummarize,
594 Progress: "Generating summary...",
595 }
596
597 a.Publish(pubsub.CreatedEvent, event)
598
599 // Send the messages to the summarize provider
600 response := a.summarizeProvider.StreamResponse(
601 summarizeCtx,
602 msgsWithPrompt,
603 make([]tools.BaseTool, 0),
604 )
605 var finalResponse *provider.ProviderResponse
606 for r := range response {
607 if r.Error != nil {
608 event = AgentEvent{
609 Type: AgentEventTypeError,
610 Error: fmt.Errorf("failed to summarize: %w", err),
611 Done: true,
612 }
613 a.Publish(pubsub.CreatedEvent, event)
614 return
615 }
616 finalResponse = r.Response
617 }
618
619 summary := strings.TrimSpace(finalResponse.Content)
620 if summary == "" {
621 event = AgentEvent{
622 Type: AgentEventTypeError,
623 Error: fmt.Errorf("empty summary returned"),
624 Done: true,
625 }
626 a.Publish(pubsub.CreatedEvent, event)
627 return
628 }
629 event = AgentEvent{
630 Type: AgentEventTypeSummarize,
631 Progress: "Creating new session...",
632 }
633
634 a.Publish(pubsub.CreatedEvent, event)
635 oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
636 if err != nil {
637 event = AgentEvent{
638 Type: AgentEventTypeError,
639 Error: fmt.Errorf("failed to get session: %w", err),
640 Done: true,
641 }
642
643 a.Publish(pubsub.CreatedEvent, event)
644 return
645 }
646 // Create a message in the new session with the summary
647 msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
648 Role: message.Assistant,
649 Parts: []message.ContentPart{
650 message.TextContent{Text: summary},
651 message.Finish{
652 Reason: message.FinishReasonEndTurn,
653 Time: time.Now().Unix(),
654 },
655 },
656 Model: a.summarizeProvider.Model().ID,
657 })
658 if err != nil {
659 event = AgentEvent{
660 Type: AgentEventTypeError,
661 Error: fmt.Errorf("failed to create summary message: %w", err),
662 Done: true,
663 }
664
665 a.Publish(pubsub.CreatedEvent, event)
666 return
667 }
668 oldSession.SummaryMessageID = msg.ID
669 oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
670 oldSession.PromptTokens = 0
671 model := a.summarizeProvider.Model()
672 usage := finalResponse.Usage
673 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
674 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
675 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
676 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
677 oldSession.Cost += cost
678 _, err = a.sessions.Save(summarizeCtx, oldSession)
679 if err != nil {
680 event = AgentEvent{
681 Type: AgentEventTypeError,
682 Error: fmt.Errorf("failed to save session: %w", err),
683 Done: true,
684 }
685 a.Publish(pubsub.CreatedEvent, event)
686 }
687
688 event = AgentEvent{
689 Type: AgentEventTypeSummarize,
690 SessionID: oldSession.ID,
691 Progress: "Summary complete",
692 Done: true,
693 }
694 a.Publish(pubsub.CreatedEvent, event)
695 // Send final success event with the new session ID
696 }()
697
698 return nil
699}
700
701func createAgentProvider(agentName config.AgentName) (provider.Provider, error) {
702 cfg := config.Get()
703 agentConfig, ok := cfg.Agents[agentName]
704 if !ok {
705 return nil, fmt.Errorf("agent %s not found", agentName)
706 }
707 model, ok := models.SupportedModels[agentConfig.Model]
708 if !ok {
709 return nil, fmt.Errorf("model %s not supported", agentConfig.Model)
710 }
711
712 providerCfg, ok := cfg.Providers[model.Provider]
713 if !ok {
714 return nil, fmt.Errorf("provider %s not supported", model.Provider)
715 }
716 if providerCfg.Disabled {
717 return nil, fmt.Errorf("provider %s is not enabled", model.Provider)
718 }
719 maxTokens := model.DefaultMaxTokens
720 if agentConfig.MaxTokens > 0 {
721 maxTokens = agentConfig.MaxTokens
722 }
723 opts := []provider.ProviderClientOption{
724 provider.WithAPIKey(providerCfg.APIKey),
725 provider.WithModel(model),
726 provider.WithSystemMessage(prompt.GetAgentPrompt(agentName, model.Provider)),
727 provider.WithMaxTokens(maxTokens),
728 }
729 if model.Provider == models.ProviderOpenAI || model.Provider == models.ProviderLocal && model.CanReason {
730 opts = append(
731 opts,
732 provider.WithOpenAIOptions(
733 provider.WithReasoningEffort(agentConfig.ReasoningEffort),
734 ),
735 )
736 } else if model.Provider == models.ProviderAnthropic && model.CanReason && agentName == config.AgentCoder {
737 opts = append(
738 opts,
739 provider.WithAnthropicOptions(
740 provider.WithAnthropicShouldThinkFn(provider.DefaultShouldThinkFn),
741 ),
742 )
743 }
744 agentProvider, err := provider.NewProvider(
745 model.Provider,
746 opts...,
747 )
748 if err != nil {
749 return nil, fmt.Errorf("could not create provider: %v", err)
750 }
751
752 return agentProvider, nil
753}