agent.go

  1package agent
  2
  3import (
  4	"context"
  5	"errors"
  6	"fmt"
  7	"strings"
  8	"sync"
  9
 10	"github.com/opencode-ai/opencode/internal/config"
 11	"github.com/opencode-ai/opencode/internal/llm/models"
 12	"github.com/opencode-ai/opencode/internal/llm/prompt"
 13	"github.com/opencode-ai/opencode/internal/llm/provider"
 14	"github.com/opencode-ai/opencode/internal/llm/tools"
 15	"github.com/opencode-ai/opencode/internal/logging"
 16	"github.com/opencode-ai/opencode/internal/message"
 17	"github.com/opencode-ai/opencode/internal/permission"
 18	"github.com/opencode-ai/opencode/internal/session"
 19)
 20
 21// Common errors
 22var (
 23	ErrRequestCancelled = errors.New("request cancelled by user")
 24	ErrSessionBusy      = errors.New("session is currently processing another request")
 25)
 26
// AgentEvent is the outcome of one Run request, delivered on the channel that
// Run returns. Within this file an event carries either a failure (only err is
// set, see agent.err) or a success (only message is set, holding the final
// assistant message). Use Err and Response to read it.
type AgentEvent struct {
	message message.Message
	err     error
}
 31
 32func (e *AgentEvent) Err() error {
 33	return e.err
 34}
 35
 36func (e *AgentEvent) Response() message.Message {
 37	return e.message
 38}
 39
 40type Service interface {
 41	Run(ctx context.Context, sessionID string, content string) (<-chan AgentEvent, error)
 42	Cancel(sessionID string)
 43	IsSessionBusy(sessionID string) bool
 44	IsBusy() bool
 45	Update(agentName config.AgentName, modelID models.ModelID) (models.Model, error)
 46}
 47
// agent is the Service implementation.
//
// Fields:
//   - sessions: reads and saves session records (title, cost, token counts).
//   - messages: creates, lists and updates the messages of a session.
//   - tools: offered to the model on every request and executed when the
//     model calls them by name.
//   - provider: the LLM backend for this agent; replaced by Update.
//   - titleProvider: backend used by generateTitle; nil unless the agent was
//     created as config.AgentCoder, in which case titles are not generated.
//   - activeRequests: maps a session ID to the context.CancelFunc of that
//     session's in-flight request; an entry's presence marks the session busy.
type agent struct {
	sessions session.Service
	messages message.Service

	tools    []tools.BaseTool
	provider provider.Provider

	titleProvider provider.Provider

	activeRequests sync.Map
}
 59
 60func NewAgent(
 61	agentName config.AgentName,
 62	sessions session.Service,
 63	messages message.Service,
 64	agentTools []tools.BaseTool,
 65) (Service, error) {
 66	agentProvider, err := createAgentProvider(agentName)
 67	if err != nil {
 68		return nil, err
 69	}
 70	var titleProvider provider.Provider
 71	// Only generate titles for the coder agent
 72	if agentName == config.AgentCoder {
 73		titleProvider, err = createAgentProvider(config.AgentTitle)
 74		if err != nil {
 75			return nil, err
 76		}
 77	}
 78
 79	agent := &agent{
 80		provider:       agentProvider,
 81		messages:       messages,
 82		sessions:       sessions,
 83		tools:          agentTools,
 84		titleProvider:  titleProvider,
 85		activeRequests: sync.Map{},
 86	}
 87
 88	return agent, nil
 89}
 90
 91func (a *agent) Cancel(sessionID string) {
 92	if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
 93		if cancel, ok := cancelFunc.(context.CancelFunc); ok {
 94			logging.InfoPersist(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
 95			cancel()
 96		}
 97	}
 98}
 99
100func (a *agent) IsBusy() bool {
101	busy := false
102	a.activeRequests.Range(func(key, value interface{}) bool {
103		if cancelFunc, ok := value.(context.CancelFunc); ok {
104			if cancelFunc != nil {
105				busy = true
106				return false // Stop iterating
107			}
108		}
109		return true // Continue iterating
110	})
111	return busy
112}
113
114func (a *agent) IsSessionBusy(sessionID string) bool {
115	_, busy := a.activeRequests.Load(sessionID)
116	return busy
117}
118
119func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
120	if a.titleProvider == nil {
121		return nil
122	}
123	session, err := a.sessions.Get(ctx, sessionID)
124	if err != nil {
125		return err
126	}
127	response, err := a.titleProvider.SendMessages(
128		ctx,
129		[]message.Message{
130			{
131				Role: message.User,
132				Parts: []message.ContentPart{
133					message.TextContent{
134						Text: content,
135					},
136				},
137			},
138		},
139		make([]tools.BaseTool, 0),
140	)
141	if err != nil {
142		return err
143	}
144
145	title := strings.TrimSpace(strings.ReplaceAll(response.Content, "\n", " "))
146	if title == "" {
147		return nil
148	}
149
150	session.Title = title
151	_, err = a.sessions.Save(ctx, session)
152	return err
153}
154
155func (a *agent) err(err error) AgentEvent {
156	return AgentEvent{
157		err: err,
158	}
159}
160
161func (a *agent) Run(ctx context.Context, sessionID string, content string) (<-chan AgentEvent, error) {
162	events := make(chan AgentEvent)
163	if a.IsSessionBusy(sessionID) {
164		return nil, ErrSessionBusy
165	}
166
167	genCtx, cancel := context.WithCancel(ctx)
168
169	a.activeRequests.Store(sessionID, cancel)
170	go func() {
171		logging.Debug("Request started", "sessionID", sessionID)
172		defer logging.RecoverPanic("agent.Run", func() {
173			events <- a.err(fmt.Errorf("panic while running the agent"))
174		})
175
176		result := a.processGeneration(genCtx, sessionID, content)
177		if result.Err() != nil && !errors.Is(result.Err(), ErrRequestCancelled) && !errors.Is(result.Err(), context.Canceled) {
178			logging.ErrorPersist(fmt.Sprintf("Generation error for session %s: %v", sessionID, result))
179		}
180		logging.Debug("Request completed", "sessionID", sessionID)
181		a.activeRequests.Delete(sessionID)
182		cancel()
183		events <- result
184		close(events)
185	}()
186	return events, nil
187}
188
189func (a *agent) processGeneration(ctx context.Context, sessionID, content string) AgentEvent {
190	// List existing messages; if none, start title generation asynchronously.
191	msgs, err := a.messages.List(ctx, sessionID)
192	if err != nil {
193		return a.err(fmt.Errorf("failed to list messages: %w", err))
194	}
195	if len(msgs) == 0 {
196		go func() {
197			defer logging.RecoverPanic("agent.Run", func() {
198				logging.ErrorPersist("panic while generating title")
199			})
200			titleErr := a.generateTitle(context.Background(), sessionID, content)
201			if titleErr != nil {
202				logging.ErrorPersist(fmt.Sprintf("failed to generate title: %v", titleErr))
203			}
204		}()
205	}
206
207	userMsg, err := a.createUserMessage(ctx, sessionID, content)
208	if err != nil {
209		return a.err(fmt.Errorf("failed to create user message: %w", err))
210	}
211
212	// Append the new user message to the conversation history.
213	msgHistory := append(msgs, userMsg)
214	for {
215		// Check for cancellation before each iteration
216		select {
217		case <-ctx.Done():
218			return a.err(ctx.Err())
219		default:
220			// Continue processing
221		}
222		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
223		if err != nil {
224			if errors.Is(err, context.Canceled) {
225				agentMessage.AddFinish(message.FinishReasonCanceled)
226				a.messages.Update(context.Background(), agentMessage)
227				return a.err(ErrRequestCancelled)
228			}
229			return a.err(fmt.Errorf("failed to process events: %w", err))
230		}
231		logging.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
232		if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
233			// We are not done, we need to respond with the tool response
234			msgHistory = append(msgHistory, agentMessage, *toolResults)
235			continue
236		}
237		return AgentEvent{
238			message: agentMessage,
239		}
240	}
241}
242
243func (a *agent) createUserMessage(ctx context.Context, sessionID, content string) (message.Message, error) {
244	return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
245		Role: message.User,
246		Parts: []message.ContentPart{
247			message.TextContent{Text: content},
248		},
249	})
250}
251
252func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
253	eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)
254
255	assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
256		Role:  message.Assistant,
257		Parts: []message.ContentPart{},
258		Model: a.provider.Model().ID,
259	})
260	if err != nil {
261		return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
262	}
263
264	// Add the session and message ID into the context if needed by tools.
265	ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
266	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
267
268	// Process each event in the stream.
269	for event := range eventChan {
270		if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
271			a.finishMessage(ctx, &assistantMsg, message.FinishReasonCanceled)
272			return assistantMsg, nil, processErr
273		}
274		if ctx.Err() != nil {
275			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
276			return assistantMsg, nil, ctx.Err()
277		}
278	}
279
280	toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
281	toolCalls := assistantMsg.ToolCalls()
282	for i, toolCall := range toolCalls {
283		select {
284		case <-ctx.Done():
285			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
286			// Make all future tool calls cancelled
287			for j := i; j < len(toolCalls); j++ {
288				toolResults[j] = message.ToolResult{
289					ToolCallID: toolCalls[j].ID,
290					Content:    "Tool execution canceled by user",
291					IsError:    true,
292				}
293			}
294			goto out
295		default:
296			// Continue processing
297			var tool tools.BaseTool
298			for _, availableTools := range a.tools {
299				if availableTools.Info().Name == toolCall.Name {
300					tool = availableTools
301				}
302			}
303
304			// Tool not found
305			if tool == nil {
306				toolResults[i] = message.ToolResult{
307					ToolCallID: toolCall.ID,
308					Content:    fmt.Sprintf("Tool not found: %s", toolCall.Name),
309					IsError:    true,
310				}
311				continue
312			}
313
314			toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
315				ID:    toolCall.ID,
316				Name:  toolCall.Name,
317				Input: toolCall.Input,
318			})
319			if toolErr != nil {
320				if errors.Is(toolErr, permission.ErrorPermissionDenied) {
321					toolResults[i] = message.ToolResult{
322						ToolCallID: toolCall.ID,
323						Content:    "Permission denied",
324						IsError:    true,
325					}
326					for j := i + 1; j < len(toolCalls); j++ {
327						toolResults[j] = message.ToolResult{
328							ToolCallID: toolCalls[j].ID,
329							Content:    "Tool execution canceled by user",
330							IsError:    true,
331						}
332					}
333					a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied)
334					break
335				}
336			}
337			toolResults[i] = message.ToolResult{
338				ToolCallID: toolCall.ID,
339				Content:    toolResult.Content,
340				Metadata:   toolResult.Metadata,
341				IsError:    toolResult.IsError,
342			}
343		}
344	}
345out:
346	if len(toolResults) == 0 {
347		return assistantMsg, nil, nil
348	}
349	parts := make([]message.ContentPart, 0)
350	for _, tr := range toolResults {
351		parts = append(parts, tr)
352	}
353	msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
354		Role:  message.Tool,
355		Parts: parts,
356	})
357	if err != nil {
358		return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
359	}
360
361	return assistantMsg, &msg, err
362}
363
364func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReson message.FinishReason) {
365	msg.AddFinish(finishReson)
366	_ = a.messages.Update(ctx, *msg)
367}
368
369func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
370	select {
371	case <-ctx.Done():
372		return ctx.Err()
373	default:
374		// Continue processing.
375	}
376
377	switch event.Type {
378	case provider.EventThinkingDelta:
379		assistantMsg.AppendReasoningContent(event.Content)
380		return a.messages.Update(ctx, *assistantMsg)
381	case provider.EventContentDelta:
382		assistantMsg.AppendContent(event.Content)
383		return a.messages.Update(ctx, *assistantMsg)
384	case provider.EventToolUseStart:
385		assistantMsg.AddToolCall(*event.ToolCall)
386		return a.messages.Update(ctx, *assistantMsg)
387	// TODO: see how to handle this
388	// case provider.EventToolUseDelta:
389	// 	tm := time.Unix(assistantMsg.UpdatedAt, 0)
390	// 	assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
391	// 	if time.Since(tm) > 1000*time.Millisecond {
392	// 		err := a.messages.Update(ctx, *assistantMsg)
393	// 		assistantMsg.UpdatedAt = time.Now().Unix()
394	// 		return err
395	// 	}
396	case provider.EventToolUseStop:
397		assistantMsg.FinishToolCall(event.ToolCall.ID)
398		return a.messages.Update(ctx, *assistantMsg)
399	case provider.EventError:
400		if errors.Is(event.Error, context.Canceled) {
401			logging.InfoPersist(fmt.Sprintf("Event processing canceled for session: %s", sessionID))
402			return context.Canceled
403		}
404		logging.ErrorPersist(event.Error.Error())
405		return event.Error
406	case provider.EventComplete:
407		assistantMsg.SetToolCalls(event.Response.ToolCalls)
408		assistantMsg.AddFinish(event.Response.FinishReason)
409		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
410			return fmt.Errorf("failed to update message: %w", err)
411		}
412		return a.TrackUsage(ctx, sessionID, a.provider.Model(), event.Response.Usage)
413	}
414
415	return nil
416}
417
418func (a *agent) TrackUsage(ctx context.Context, sessionID string, model models.Model, usage provider.TokenUsage) error {
419	sess, err := a.sessions.Get(ctx, sessionID)
420	if err != nil {
421		return fmt.Errorf("failed to get session: %w", err)
422	}
423
424	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
425		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
426		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
427		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
428
429	sess.Cost += cost
430	sess.CompletionTokens += usage.OutputTokens
431	sess.PromptTokens += usage.InputTokens
432
433	_, err = a.sessions.Save(ctx, sess)
434	if err != nil {
435		return fmt.Errorf("failed to save session: %w", err)
436	}
437	return nil
438}
439
440func (a *agent) Update(agentName config.AgentName, modelID models.ModelID) (models.Model, error) {
441	if a.IsBusy() {
442		return models.Model{}, fmt.Errorf("cannot change model while processing requests")
443	}
444
445	if err := config.UpdateAgentModel(agentName, modelID); err != nil {
446		return models.Model{}, fmt.Errorf("failed to update config: %w", err)
447	}
448
449	provider, err := createAgentProvider(agentName)
450	if err != nil {
451		return models.Model{}, fmt.Errorf("failed to create provider for model %s: %w", modelID, err)
452	}
453
454	a.provider = provider
455
456	return a.provider.Model(), nil
457}
458
459func createAgentProvider(agentName config.AgentName) (provider.Provider, error) {
460	cfg := config.Get()
461	agentConfig, ok := cfg.Agents[agentName]
462	if !ok {
463		return nil, fmt.Errorf("agent %s not found", agentName)
464	}
465	model, ok := models.SupportedModels[agentConfig.Model]
466	if !ok {
467		return nil, fmt.Errorf("model %s not supported", agentConfig.Model)
468	}
469
470	providerCfg, ok := cfg.Providers[model.Provider]
471	if !ok {
472		return nil, fmt.Errorf("provider %s not supported", model.Provider)
473	}
474	if providerCfg.Disabled {
475		return nil, fmt.Errorf("provider %s is not enabled", model.Provider)
476	}
477	maxTokens := model.DefaultMaxTokens
478	if agentConfig.MaxTokens > 0 {
479		maxTokens = agentConfig.MaxTokens
480	}
481	opts := []provider.ProviderClientOption{
482		provider.WithAPIKey(providerCfg.APIKey),
483		provider.WithModel(model),
484		provider.WithSystemMessage(prompt.GetAgentPrompt(agentName, model.Provider)),
485		provider.WithMaxTokens(maxTokens),
486	}
487	if model.Provider == models.ProviderOpenAI && model.CanReason {
488		opts = append(
489			opts,
490			provider.WithOpenAIOptions(
491				provider.WithReasoningEffort(agentConfig.ReasoningEffort),
492			),
493		)
494	} else if model.Provider == models.ProviderAnthropic && model.CanReason && agentName == config.AgentCoder {
495		opts = append(
496			opts,
497			provider.WithAnthropicOptions(
498				provider.WithAnthropicShouldThinkFn(provider.DefaultShouldThinkFn),
499			),
500		)
501	}
502	agentProvider, err := provider.NewProvider(
503		model.Provider,
504		opts...,
505	)
506	if err != nil {
507		return nil, fmt.Errorf("could not create provider: %v", err)
508	}
509
510	return agentProvider, nil
511}