1package agent
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "strings"
8 "sync"
9 "time"
10
11 "github.com/charmbracelet/crush/internal/config"
12 "github.com/charmbracelet/crush/internal/llm/models"
13 "github.com/charmbracelet/crush/internal/llm/prompt"
14 "github.com/charmbracelet/crush/internal/llm/provider"
15 "github.com/charmbracelet/crush/internal/llm/tools"
16 "github.com/charmbracelet/crush/internal/logging"
17 "github.com/charmbracelet/crush/internal/message"
18 "github.com/charmbracelet/crush/internal/permission"
19 "github.com/charmbracelet/crush/internal/pubsub"
20 "github.com/charmbracelet/crush/internal/session"
21)
22
// Sentinel errors returned by the agent service; compare with errors.Is.
var (
	// ErrSessionBusy is returned by Run and Summarize when the session
	// already has work in flight.
	ErrSessionBusy = errors.New("session is currently processing another request")

	// ErrRequestCancelled is the error carried by Run's result event when the
	// generation's context was cancelled.
	ErrRequestCancelled = errors.New("request cancelled by user")
)
28
// AgentEventType tells subscribers which fields of an AgentEvent are set.
type AgentEventType string

// The event types emitted by the agent.
const (
	// AgentEventTypeResponse carries the final assistant message of a Run.
	AgentEventTypeResponse AgentEventType = "response"
	// AgentEventTypeSummarize reports Summarize progress and completion.
	AgentEventTypeSummarize AgentEventType = "summarize"
	// AgentEventTypeError carries a failure in AgentEvent.Error.
	AgentEventTypeError AgentEventType = "error"
)
36
// AgentEvent is what the agent publishes to subscribers and, for Run, also
// sends on the channel Run returns. Type selects which fields are populated.
type AgentEvent struct {
	// Type identifies the kind of event.
	Type AgentEventType
	// Message is the final assistant message (AgentEventTypeResponse).
	Message message.Message
	// Error is the failure carried by AgentEventTypeError events.
	Error error

	// Summarization-specific fields (Done is also set on a Run's final
	// response event).
	// SessionID is the summarized session, set on the final summarize event.
	SessionID string
	// Progress is a human-readable status line while summarizing.
	Progress string
	// Done marks a terminal event: a Run's final response, or Summarize's
	// completion / error events.
	Done bool
}
47
48type Service interface {
49 pubsub.Suscriber[AgentEvent]
50 Model() models.Model
51 Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
52 Cancel(sessionID string)
53 CancelAll()
54 IsSessionBusy(sessionID string) bool
55 IsBusy() bool
56 Update(agentName config.AgentName, modelID models.ModelID) (models.Model, error)
57 Summarize(ctx context.Context, sessionID string) error
58}
59
// agent is the Service implementation backed by LLM providers.
type agent struct {
	// Broker is used to publish AgentEvents to subscribers.
	*pubsub.Broker[AgentEvent]
	sessions session.Service
	messages message.Service

	// tools are offered to the main provider and executed when it asks.
	tools    []tools.BaseTool
	provider provider.Provider

	// titleProvider and summarizeProvider are only created for the coder
	// agent (see NewAgent); when nil, title generation is skipped and
	// Summarize returns an error.
	titleProvider     provider.Provider
	summarizeProvider provider.Provider

	// activeRequests maps a session ID to the context.CancelFunc of its
	// in-flight Run; summarizations are stored under sessionID+"-summarize".
	activeRequests sync.Map
}
73
74func NewAgent(
75 agentName config.AgentName,
76 sessions session.Service,
77 messages message.Service,
78 agentTools []tools.BaseTool,
79) (Service, error) {
80 agentProvider, err := createAgentProvider(agentName)
81 if err != nil {
82 return nil, err
83 }
84 var titleProvider provider.Provider
85 // Only generate titles for the coder agent
86 if agentName == config.AgentCoder {
87 titleProvider, err = createAgentProvider(config.AgentTitle)
88 if err != nil {
89 return nil, err
90 }
91 }
92 var summarizeProvider provider.Provider
93 if agentName == config.AgentCoder {
94 summarizeProvider, err = createAgentProvider(config.AgentSummarizer)
95 if err != nil {
96 return nil, err
97 }
98 }
99
100 agent := &agent{
101 Broker: pubsub.NewBroker[AgentEvent](),
102 provider: agentProvider,
103 messages: messages,
104 sessions: sessions,
105 tools: agentTools,
106 titleProvider: titleProvider,
107 summarizeProvider: summarizeProvider,
108 activeRequests: sync.Map{},
109 }
110
111 return agent, nil
112}
113
114func (a *agent) Model() models.Model {
115 return a.provider.Model()
116}
117
118func (a *agent) Cancel(sessionID string) {
119 // Cancel regular requests
120 if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
121 if cancel, ok := cancelFunc.(context.CancelFunc); ok {
122 logging.InfoPersist(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
123 cancel()
124 }
125 }
126
127 // Also check for summarize requests
128 if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID + "-summarize"); exists {
129 if cancel, ok := cancelFunc.(context.CancelFunc); ok {
130 logging.InfoPersist(fmt.Sprintf("Summarize cancellation initiated for session: %s", sessionID))
131 cancel()
132 }
133 }
134}
135
136func (a *agent) IsBusy() bool {
137 busy := false
138 a.activeRequests.Range(func(key, value any) bool {
139 if cancelFunc, ok := value.(context.CancelFunc); ok {
140 if cancelFunc != nil {
141 busy = true
142 return false // Stop iterating
143 }
144 }
145 return true // Continue iterating
146 })
147 return busy
148}
149
150func (a *agent) IsSessionBusy(sessionID string) bool {
151 _, busy := a.activeRequests.Load(sessionID)
152 return busy
153}
154
155func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
156 if content == "" {
157 return nil
158 }
159 if a.titleProvider == nil {
160 return nil
161 }
162 session, err := a.sessions.Get(ctx, sessionID)
163 if err != nil {
164 return err
165 }
166 parts := []message.ContentPart{message.TextContent{Text: content}}
167
168 // Use streaming approach like summarization
169 response := a.titleProvider.StreamResponse(
170 ctx,
171 []message.Message{
172 {
173 Role: message.User,
174 Parts: parts,
175 },
176 },
177 make([]tools.BaseTool, 0),
178 )
179
180 var finalResponse *provider.ProviderResponse
181 for r := range response {
182 if r.Error != nil {
183 return r.Error
184 }
185 finalResponse = r.Response
186 }
187
188 if finalResponse == nil {
189 return fmt.Errorf("no response received from title provider")
190 }
191
192 title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
193 if title == "" {
194 return nil
195 }
196
197 session.Title = title
198 _, err = a.sessions.Save(ctx, session)
199 return err
200}
201
202func (a *agent) err(err error) AgentEvent {
203 return AgentEvent{
204 Type: AgentEventTypeError,
205 Error: err,
206 }
207}
208
209func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
210 if !a.provider.Model().SupportsAttachments && attachments != nil {
211 attachments = nil
212 }
213 events := make(chan AgentEvent)
214 if a.IsSessionBusy(sessionID) {
215 return nil, ErrSessionBusy
216 }
217
218 genCtx, cancel := context.WithCancel(ctx)
219
220 a.activeRequests.Store(sessionID, cancel)
221 go func() {
222 logging.Debug("Request started", "sessionID", sessionID)
223 defer logging.RecoverPanic("agent.Run", func() {
224 events <- a.err(fmt.Errorf("panic while running the agent"))
225 })
226 var attachmentParts []message.ContentPart
227 for _, attachment := range attachments {
228 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
229 }
230 result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
231 if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
232 logging.ErrorPersist(result.Error.Error())
233 }
234 logging.Debug("Request completed", "sessionID", sessionID)
235 a.activeRequests.Delete(sessionID)
236 cancel()
237 a.Publish(pubsub.CreatedEvent, result)
238 events <- result
239 close(events)
240 }()
241 return events, nil
242}
243
244func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
245 // List existing messages; if none, start title generation asynchronously.
246 msgs, err := a.messages.List(ctx, sessionID)
247 if err != nil {
248 return a.err(fmt.Errorf("failed to list messages: %w", err))
249 }
250 if len(msgs) == 0 {
251 go func() {
252 defer logging.RecoverPanic("agent.Run", func() {
253 logging.ErrorPersist("panic while generating title")
254 })
255 titleErr := a.generateTitle(context.Background(), sessionID, content)
256 if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
257 logging.ErrorPersist(fmt.Sprintf("failed to generate title: %v", titleErr))
258 }
259 }()
260 }
261 session, err := a.sessions.Get(ctx, sessionID)
262 if err != nil {
263 return a.err(fmt.Errorf("failed to get session: %w", err))
264 }
265 if session.SummaryMessageID != "" {
266 summaryMsgInex := -1
267 for i, msg := range msgs {
268 if msg.ID == session.SummaryMessageID {
269 summaryMsgInex = i
270 break
271 }
272 }
273 if summaryMsgInex != -1 {
274 msgs = msgs[summaryMsgInex:]
275 msgs[0].Role = message.User
276 }
277 }
278
279 userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
280 if err != nil {
281 return a.err(fmt.Errorf("failed to create user message: %w", err))
282 }
283 // Append the new user message to the conversation history.
284 msgHistory := append(msgs, userMsg)
285
286 for {
287 // Check for cancellation before each iteration
288 select {
289 case <-ctx.Done():
290 return a.err(ctx.Err())
291 default:
292 // Continue processing
293 }
294 agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
295 if err != nil {
296 if errors.Is(err, context.Canceled) {
297 agentMessage.AddFinish(message.FinishReasonCanceled)
298 a.messages.Update(context.Background(), agentMessage)
299 return a.err(ErrRequestCancelled)
300 }
301 return a.err(fmt.Errorf("failed to process events: %w", err))
302 }
303 logging.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
304 if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
305 // We are not done, we need to respond with the tool response
306 msgHistory = append(msgHistory, agentMessage, *toolResults)
307 continue
308 }
309 return AgentEvent{
310 Type: AgentEventTypeResponse,
311 Message: agentMessage,
312 Done: true,
313 }
314 }
315}
316
317func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
318 parts := []message.ContentPart{message.TextContent{Text: content}}
319 parts = append(parts, attachmentParts...)
320 return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
321 Role: message.User,
322 Parts: parts,
323 })
324}
325
326func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
327 eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)
328
329 assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
330 Role: message.Assistant,
331 Parts: []message.ContentPart{},
332 Model: a.provider.Model().ID,
333 })
334 if err != nil {
335 return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
336 }
337
338 // Add the session and message ID into the context if needed by tools.
339 ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
340 ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
341
342 // Process each event in the stream.
343 for event := range eventChan {
344 if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
345 a.finishMessage(ctx, &assistantMsg, message.FinishReasonCanceled)
346 return assistantMsg, nil, processErr
347 }
348 if ctx.Err() != nil {
349 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
350 return assistantMsg, nil, ctx.Err()
351 }
352 }
353
354 toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
355 toolCalls := assistantMsg.ToolCalls()
356 for i, toolCall := range toolCalls {
357 select {
358 case <-ctx.Done():
359 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
360 // Make all future tool calls cancelled
361 for j := i; j < len(toolCalls); j++ {
362 toolResults[j] = message.ToolResult{
363 ToolCallID: toolCalls[j].ID,
364 Content: "Tool execution canceled by user",
365 IsError: true,
366 }
367 }
368 goto out
369 default:
370 // Continue processing
371 var tool tools.BaseTool
372 for _, availableTools := range a.tools {
373 if availableTools.Info().Name == toolCall.Name {
374 tool = availableTools
375 }
376 }
377
378 // Tool not found
379 if tool == nil {
380 toolResults[i] = message.ToolResult{
381 ToolCallID: toolCall.ID,
382 Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
383 IsError: true,
384 }
385 continue
386 }
387 toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
388 ID: toolCall.ID,
389 Name: toolCall.Name,
390 Input: toolCall.Input,
391 })
392 if toolErr != nil {
393 if errors.Is(toolErr, permission.ErrorPermissionDenied) {
394 toolResults[i] = message.ToolResult{
395 ToolCallID: toolCall.ID,
396 Content: "Permission denied",
397 IsError: true,
398 }
399 for j := i + 1; j < len(toolCalls); j++ {
400 toolResults[j] = message.ToolResult{
401 ToolCallID: toolCalls[j].ID,
402 Content: "Tool execution canceled by user",
403 IsError: true,
404 }
405 }
406 a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied)
407 break
408 }
409 }
410 toolResults[i] = message.ToolResult{
411 ToolCallID: toolCall.ID,
412 Content: toolResult.Content,
413 Metadata: toolResult.Metadata,
414 IsError: toolResult.IsError,
415 }
416 }
417 }
418out:
419 if len(toolResults) == 0 {
420 return assistantMsg, nil, nil
421 }
422 parts := make([]message.ContentPart, 0)
423 for _, tr := range toolResults {
424 parts = append(parts, tr)
425 }
426 msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
427 Role: message.Tool,
428 Parts: parts,
429 })
430 if err != nil {
431 return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
432 }
433
434 return assistantMsg, &msg, err
435}
436
437func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReson message.FinishReason) {
438 msg.AddFinish(finishReson)
439 _ = a.messages.Update(ctx, *msg)
440}
441
442func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
443 select {
444 case <-ctx.Done():
445 return ctx.Err()
446 default:
447 // Continue processing.
448 }
449
450 switch event.Type {
451 case provider.EventThinkingDelta:
452 assistantMsg.AppendReasoningContent(event.Content)
453 return a.messages.Update(ctx, *assistantMsg)
454 case provider.EventContentDelta:
455 assistantMsg.AppendContent(event.Content)
456 return a.messages.Update(ctx, *assistantMsg)
457 case provider.EventToolUseStart:
458 logging.Info("Tool call started", "toolCall", event.ToolCall)
459 assistantMsg.AddToolCall(*event.ToolCall)
460 return a.messages.Update(ctx, *assistantMsg)
461 case provider.EventToolUseDelta:
462 assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
463 return a.messages.Update(ctx, *assistantMsg)
464 case provider.EventToolUseStop:
465 logging.Info("Finished tool call", "toolCall", event.ToolCall)
466 assistantMsg.FinishToolCall(event.ToolCall.ID)
467 return a.messages.Update(ctx, *assistantMsg)
468 case provider.EventError:
469 if errors.Is(event.Error, context.Canceled) {
470 logging.InfoPersist(fmt.Sprintf("Event processing canceled for session: %s", sessionID))
471 return context.Canceled
472 }
473 logging.ErrorPersist(event.Error.Error())
474 return event.Error
475 case provider.EventComplete:
476 assistantMsg.SetToolCalls(event.Response.ToolCalls)
477 assistantMsg.AddFinish(event.Response.FinishReason)
478 if err := a.messages.Update(ctx, *assistantMsg); err != nil {
479 return fmt.Errorf("failed to update message: %w", err)
480 }
481 return a.TrackUsage(ctx, sessionID, a.provider.Model(), event.Response.Usage)
482 }
483
484 return nil
485}
486
487func (a *agent) TrackUsage(ctx context.Context, sessionID string, model models.Model, usage provider.TokenUsage) error {
488 sess, err := a.sessions.Get(ctx, sessionID)
489 if err != nil {
490 return fmt.Errorf("failed to get session: %w", err)
491 }
492
493 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
494 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
495 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
496 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
497
498 sess.Cost += cost
499 sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
500 sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
501
502 _, err = a.sessions.Save(ctx, sess)
503 if err != nil {
504 return fmt.Errorf("failed to save session: %w", err)
505 }
506 return nil
507}
508
509func (a *agent) Update(agentName config.AgentName, modelID models.ModelID) (models.Model, error) {
510 if a.IsBusy() {
511 return models.Model{}, fmt.Errorf("cannot change model while processing requests")
512 }
513
514 if err := config.UpdateAgentModel(agentName, modelID); err != nil {
515 return models.Model{}, fmt.Errorf("failed to update config: %w", err)
516 }
517
518 provider, err := createAgentProvider(agentName)
519 if err != nil {
520 return models.Model{}, fmt.Errorf("failed to create provider for model %s: %w", modelID, err)
521 }
522
523 a.provider = provider
524
525 return a.provider.Model(), nil
526}
527
528func (a *agent) Summarize(ctx context.Context, sessionID string) error {
529 if a.summarizeProvider == nil {
530 return fmt.Errorf("summarize provider not available")
531 }
532
533 // Check if session is busy
534 if a.IsSessionBusy(sessionID) {
535 return ErrSessionBusy
536 }
537
538 // Create a new context with cancellation
539 summarizeCtx, cancel := context.WithCancel(ctx)
540
541 // Store the cancel function in activeRequests to allow cancellation
542 a.activeRequests.Store(sessionID+"-summarize", cancel)
543
544 go func() {
545 defer a.activeRequests.Delete(sessionID + "-summarize")
546 defer cancel()
547 event := AgentEvent{
548 Type: AgentEventTypeSummarize,
549 Progress: "Starting summarization...",
550 }
551
552 a.Publish(pubsub.CreatedEvent, event)
553 // Get all messages from the session
554 msgs, err := a.messages.List(summarizeCtx, sessionID)
555 if err != nil {
556 event = AgentEvent{
557 Type: AgentEventTypeError,
558 Error: fmt.Errorf("failed to list messages: %w", err),
559 Done: true,
560 }
561 a.Publish(pubsub.CreatedEvent, event)
562 return
563 }
564
565 if len(msgs) == 0 {
566 event = AgentEvent{
567 Type: AgentEventTypeError,
568 Error: fmt.Errorf("no messages to summarize"),
569 Done: true,
570 }
571 a.Publish(pubsub.CreatedEvent, event)
572 return
573 }
574
575 event = AgentEvent{
576 Type: AgentEventTypeSummarize,
577 Progress: "Analyzing conversation...",
578 }
579 a.Publish(pubsub.CreatedEvent, event)
580
581 // Add a system message to guide the summarization
582 summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
583
584 // Create a new message with the summarize prompt
585 promptMsg := message.Message{
586 Role: message.User,
587 Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
588 }
589
590 // Append the prompt to the messages
591 msgsWithPrompt := append(msgs, promptMsg)
592
593 event = AgentEvent{
594 Type: AgentEventTypeSummarize,
595 Progress: "Generating summary...",
596 }
597
598 a.Publish(pubsub.CreatedEvent, event)
599
600 // Send the messages to the summarize provider
601 response := a.summarizeProvider.StreamResponse(
602 summarizeCtx,
603 msgsWithPrompt,
604 make([]tools.BaseTool, 0),
605 )
606 var finalResponse *provider.ProviderResponse
607 for r := range response {
608 if r.Error != nil {
609 event = AgentEvent{
610 Type: AgentEventTypeError,
611 Error: fmt.Errorf("failed to summarize: %w", err),
612 Done: true,
613 }
614 a.Publish(pubsub.CreatedEvent, event)
615 return
616 }
617 finalResponse = r.Response
618 }
619
620 summary := strings.TrimSpace(finalResponse.Content)
621 if summary == "" {
622 event = AgentEvent{
623 Type: AgentEventTypeError,
624 Error: fmt.Errorf("empty summary returned"),
625 Done: true,
626 }
627 a.Publish(pubsub.CreatedEvent, event)
628 return
629 }
630 event = AgentEvent{
631 Type: AgentEventTypeSummarize,
632 Progress: "Creating new session...",
633 }
634
635 a.Publish(pubsub.CreatedEvent, event)
636 oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
637 if err != nil {
638 event = AgentEvent{
639 Type: AgentEventTypeError,
640 Error: fmt.Errorf("failed to get session: %w", err),
641 Done: true,
642 }
643
644 a.Publish(pubsub.CreatedEvent, event)
645 return
646 }
647 // Create a message in the new session with the summary
648 msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
649 Role: message.Assistant,
650 Parts: []message.ContentPart{
651 message.TextContent{Text: summary},
652 message.Finish{
653 Reason: message.FinishReasonEndTurn,
654 Time: time.Now().Unix(),
655 },
656 },
657 Model: a.summarizeProvider.Model().ID,
658 })
659 if err != nil {
660 event = AgentEvent{
661 Type: AgentEventTypeError,
662 Error: fmt.Errorf("failed to create summary message: %w", err),
663 Done: true,
664 }
665
666 a.Publish(pubsub.CreatedEvent, event)
667 return
668 }
669 oldSession.SummaryMessageID = msg.ID
670 oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
671 oldSession.PromptTokens = 0
672 model := a.summarizeProvider.Model()
673 usage := finalResponse.Usage
674 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
675 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
676 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
677 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
678 oldSession.Cost += cost
679 _, err = a.sessions.Save(summarizeCtx, oldSession)
680 if err != nil {
681 event = AgentEvent{
682 Type: AgentEventTypeError,
683 Error: fmt.Errorf("failed to save session: %w", err),
684 Done: true,
685 }
686 a.Publish(pubsub.CreatedEvent, event)
687 }
688
689 event = AgentEvent{
690 Type: AgentEventTypeSummarize,
691 SessionID: oldSession.ID,
692 Progress: "Summary complete",
693 Done: true,
694 }
695 a.Publish(pubsub.CreatedEvent, event)
696 // Send final success event with the new session ID
697 }()
698
699 return nil
700}
701
702func (a *agent) CancelAll() {
703 a.activeRequests.Range(func(key, value any) bool {
704 a.Cancel(key.(string)) // key is sessionID
705 return true
706 })
707}
708
709func createAgentProvider(agentName config.AgentName) (provider.Provider, error) {
710 cfg := config.Get()
711 agentConfig, ok := cfg.Agents[agentName]
712 if !ok {
713 return nil, fmt.Errorf("agent %s not found", agentName)
714 }
715 model, ok := models.SupportedModels[agentConfig.Model]
716 if !ok {
717 return nil, fmt.Errorf("model %s not supported", agentConfig.Model)
718 }
719
720 providerCfg, ok := cfg.Providers[model.Provider]
721 if !ok {
722 return nil, fmt.Errorf("provider %s not supported", model.Provider)
723 }
724 if providerCfg.Disabled {
725 return nil, fmt.Errorf("provider %s is not enabled", model.Provider)
726 }
727 maxTokens := model.DefaultMaxTokens
728 if agentConfig.MaxTokens > 0 {
729 maxTokens = agentConfig.MaxTokens
730 }
731 opts := []provider.ProviderClientOption{
732 provider.WithAPIKey(providerCfg.APIKey),
733 provider.WithModel(model),
734 provider.WithSystemMessage(prompt.GetAgentPrompt(agentName, model.Provider)),
735 provider.WithMaxTokens(maxTokens),
736 }
737 if (model.Provider == models.ProviderOpenAI || model.Provider == models.ProviderLocal) && model.CanReason {
738 opts = append(
739 opts,
740 provider.WithOpenAIOptions(
741 provider.WithReasoningEffort(agentConfig.ReasoningEffort),
742 ),
743 )
744 } else if model.Provider == models.ProviderAnthropic && model.CanReason && agentName == config.AgentCoder {
745 opts = append(
746 opts,
747 provider.WithAnthropicOptions(
748 provider.WithAnthropicShouldThinkFn(provider.DefaultShouldThinkFn),
749 ),
750 )
751 }
752 agentProvider, err := provider.NewProvider(
753 model.Provider,
754 opts...,
755 )
756 if err != nil {
757 return nil, fmt.Errorf("could not create provider: %v", err)
758 }
759
760 return agentProvider, nil
761}