agent.go

  1package agent
  2
  3import (
  4	"context"
  5	"errors"
  6	"fmt"
  7	"strings"
  8	"sync"
  9
 10	"github.com/opencode-ai/opencode/internal/config"
 11	"github.com/opencode-ai/opencode/internal/llm/models"
 12	"github.com/opencode-ai/opencode/internal/llm/prompt"
 13	"github.com/opencode-ai/opencode/internal/llm/provider"
 14	"github.com/opencode-ai/opencode/internal/llm/tools"
 15	"github.com/opencode-ai/opencode/internal/logging"
 16	"github.com/opencode-ai/opencode/internal/message"
 17	"github.com/opencode-ai/opencode/internal/permission"
 18	"github.com/opencode-ai/opencode/internal/session"
 19)
 20
 21// Common errors
 22var (
 23	ErrRequestCancelled = errors.New("request cancelled by user")
 24	ErrSessionBusy      = errors.New("session is currently processing another request")
 25)
 26
 27type AgentEvent struct {
 28	message message.Message
 29	err     error
 30}
 31
 32func (e *AgentEvent) Err() error {
 33	return e.err
 34}
 35
 36func (e *AgentEvent) Response() message.Message {
 37	return e.message
 38}
 39
 40type Service interface {
 41	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
 42	Cancel(sessionID string)
 43	IsSessionBusy(sessionID string) bool
 44	IsBusy() bool
 45	Update(agentName config.AgentName, modelID models.ModelID) (models.Model, error)
 46}
 47
 48type agent struct {
 49	sessions session.Service
 50	messages message.Service
 51
 52	tools    []tools.BaseTool
 53	provider provider.Provider
 54
 55	titleProvider provider.Provider
 56
 57	activeRequests sync.Map
 58}
 59
 60func NewAgent(
 61	agentName config.AgentName,
 62	sessions session.Service,
 63	messages message.Service,
 64	agentTools []tools.BaseTool,
 65) (Service, error) {
 66	agentProvider, err := createAgentProvider(agentName)
 67	if err != nil {
 68		return nil, err
 69	}
 70	var titleProvider provider.Provider
 71	// Only generate titles for the coder agent
 72	if agentName == config.AgentCoder {
 73		titleProvider, err = createAgentProvider(config.AgentTitle)
 74		if err != nil {
 75			return nil, err
 76		}
 77	}
 78
 79	agent := &agent{
 80		provider:       agentProvider,
 81		messages:       messages,
 82		sessions:       sessions,
 83		tools:          agentTools,
 84		titleProvider:  titleProvider,
 85		activeRequests: sync.Map{},
 86	}
 87
 88	return agent, nil
 89}
 90
 91func (a *agent) Cancel(sessionID string) {
 92	if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
 93		if cancel, ok := cancelFunc.(context.CancelFunc); ok {
 94			logging.InfoPersist(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
 95			cancel()
 96		}
 97	}
 98}
 99
100func (a *agent) IsBusy() bool {
101	busy := false
102	a.activeRequests.Range(func(key, value interface{}) bool {
103		if cancelFunc, ok := value.(context.CancelFunc); ok {
104			if cancelFunc != nil {
105				busy = true
106				return false // Stop iterating
107			}
108		}
109		return true // Continue iterating
110	})
111	return busy
112}
113
114func (a *agent) IsSessionBusy(sessionID string) bool {
115	_, busy := a.activeRequests.Load(sessionID)
116	return busy
117}
118
119func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
120	if content == "" {
121		return nil
122	}
123	if a.titleProvider == nil {
124		return nil
125	}
126	session, err := a.sessions.Get(ctx, sessionID)
127	if err != nil {
128		return err
129	}
130	parts := []message.ContentPart{message.TextContent{Text: content}}
131	response, err := a.titleProvider.SendMessages(
132		ctx,
133		[]message.Message{
134			{
135				Role:  message.User,
136				Parts: parts,
137			},
138		},
139		make([]tools.BaseTool, 0),
140	)
141	if err != nil {
142		return err
143	}
144
145	title := strings.TrimSpace(strings.ReplaceAll(response.Content, "\n", " "))
146	if title == "" {
147		return nil
148	}
149
150	session.Title = title
151	_, err = a.sessions.Save(ctx, session)
152	return err
153}
154
155func (a *agent) err(err error) AgentEvent {
156	return AgentEvent{
157		err: err,
158	}
159}
160
161func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
162	if !a.provider.Model().SupportsAttachments && attachments != nil {
163		attachments = nil
164	}
165	events := make(chan AgentEvent)
166	if a.IsSessionBusy(sessionID) {
167		return nil, ErrSessionBusy
168	}
169
170	genCtx, cancel := context.WithCancel(ctx)
171
172	a.activeRequests.Store(sessionID, cancel)
173	go func() {
174		logging.Debug("Request started", "sessionID", sessionID)
175		defer logging.RecoverPanic("agent.Run", func() {
176			events <- a.err(fmt.Errorf("panic while running the agent"))
177		})
178		var attachmentParts []message.ContentPart
179		for _, attachment := range attachments {
180			attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
181		}
182		result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
183		if result.Err() != nil && !errors.Is(result.Err(), ErrRequestCancelled) && !errors.Is(result.Err(), context.Canceled) {
184			logging.ErrorPersist(result.Err().Error())
185		}
186		logging.Debug("Request completed", "sessionID", sessionID)
187		a.activeRequests.Delete(sessionID)
188		cancel()
189		events <- result
190		close(events)
191	}()
192	return events, nil
193}
194
195func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
196	// List existing messages; if none, start title generation asynchronously.
197	msgs, err := a.messages.List(ctx, sessionID)
198	if err != nil {
199		return a.err(fmt.Errorf("failed to list messages: %w", err))
200	}
201	if len(msgs) == 0 {
202		go func() {
203			defer logging.RecoverPanic("agent.Run", func() {
204				logging.ErrorPersist("panic while generating title")
205			})
206			titleErr := a.generateTitle(context.Background(), sessionID, content)
207			if titleErr != nil {
208				logging.ErrorPersist(fmt.Sprintf("failed to generate title: %v", titleErr))
209			}
210		}()
211	}
212
213	userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
214	if err != nil {
215		return a.err(fmt.Errorf("failed to create user message: %w", err))
216	}
217	// Append the new user message to the conversation history.
218	msgHistory := append(msgs, userMsg)
219
220	for {
221		// Check for cancellation before each iteration
222		select {
223		case <-ctx.Done():
224			return a.err(ctx.Err())
225		default:
226			// Continue processing
227		}
228		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
229		if err != nil {
230			if errors.Is(err, context.Canceled) {
231				agentMessage.AddFinish(message.FinishReasonCanceled)
232				a.messages.Update(context.Background(), agentMessage)
233				return a.err(ErrRequestCancelled)
234			}
235			return a.err(fmt.Errorf("failed to process events: %w", err))
236		}
237		logging.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
238		if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
239			// We are not done, we need to respond with the tool response
240			msgHistory = append(msgHistory, agentMessage, *toolResults)
241			continue
242		}
243		return AgentEvent{
244			message: agentMessage,
245		}
246	}
247}
248
249func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
250	parts := []message.ContentPart{message.TextContent{Text: content}}
251	parts = append(parts, attachmentParts...)
252	return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
253		Role:  message.User,
254		Parts: parts,
255	})
256}
257
258func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
259	eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)
260
261	assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
262		Role:  message.Assistant,
263		Parts: []message.ContentPart{},
264		Model: a.provider.Model().ID,
265	})
266	if err != nil {
267		return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
268	}
269
270	// Add the session and message ID into the context if needed by tools.
271	ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
272	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
273
274	// Process each event in the stream.
275	for event := range eventChan {
276		if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
277			a.finishMessage(ctx, &assistantMsg, message.FinishReasonCanceled)
278			return assistantMsg, nil, processErr
279		}
280		if ctx.Err() != nil {
281			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
282			return assistantMsg, nil, ctx.Err()
283		}
284	}
285
286	toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
287	toolCalls := assistantMsg.ToolCalls()
288	for i, toolCall := range toolCalls {
289		select {
290		case <-ctx.Done():
291			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
292			// Make all future tool calls cancelled
293			for j := i; j < len(toolCalls); j++ {
294				toolResults[j] = message.ToolResult{
295					ToolCallID: toolCalls[j].ID,
296					Content:    "Tool execution canceled by user",
297					IsError:    true,
298				}
299			}
300			goto out
301		default:
302			// Continue processing
303			var tool tools.BaseTool
304			for _, availableTools := range a.tools {
305				if availableTools.Info().Name == toolCall.Name {
306					tool = availableTools
307				}
308			}
309
310			// Tool not found
311			if tool == nil {
312				toolResults[i] = message.ToolResult{
313					ToolCallID: toolCall.ID,
314					Content:    fmt.Sprintf("Tool not found: %s", toolCall.Name),
315					IsError:    true,
316				}
317				continue
318			}
319			toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
320				ID:    toolCall.ID,
321				Name:  toolCall.Name,
322				Input: toolCall.Input,
323			})
324			if toolErr != nil {
325				if errors.Is(toolErr, permission.ErrorPermissionDenied) {
326					toolResults[i] = message.ToolResult{
327						ToolCallID: toolCall.ID,
328						Content:    "Permission denied",
329						IsError:    true,
330					}
331					for j := i + 1; j < len(toolCalls); j++ {
332						toolResults[j] = message.ToolResult{
333							ToolCallID: toolCalls[j].ID,
334							Content:    "Tool execution canceled by user",
335							IsError:    true,
336						}
337					}
338					a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied)
339					break
340				}
341			}
342			toolResults[i] = message.ToolResult{
343				ToolCallID: toolCall.ID,
344				Content:    toolResult.Content,
345				Metadata:   toolResult.Metadata,
346				IsError:    toolResult.IsError,
347			}
348		}
349	}
350out:
351	if len(toolResults) == 0 {
352		return assistantMsg, nil, nil
353	}
354	parts := make([]message.ContentPart, 0)
355	for _, tr := range toolResults {
356		parts = append(parts, tr)
357	}
358	msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
359		Role:  message.Tool,
360		Parts: parts,
361	})
362	if err != nil {
363		return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
364	}
365
366	return assistantMsg, &msg, err
367}
368
369func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReson message.FinishReason) {
370	msg.AddFinish(finishReson)
371	_ = a.messages.Update(ctx, *msg)
372}
373
374func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
375	select {
376	case <-ctx.Done():
377		return ctx.Err()
378	default:
379		// Continue processing.
380	}
381
382	switch event.Type {
383	case provider.EventThinkingDelta:
384		assistantMsg.AppendReasoningContent(event.Content)
385		return a.messages.Update(ctx, *assistantMsg)
386	case provider.EventContentDelta:
387		assistantMsg.AppendContent(event.Content)
388		return a.messages.Update(ctx, *assistantMsg)
389	case provider.EventToolUseStart:
390		assistantMsg.AddToolCall(*event.ToolCall)
391		return a.messages.Update(ctx, *assistantMsg)
392	// TODO: see how to handle this
393	// case provider.EventToolUseDelta:
394	// 	tm := time.Unix(assistantMsg.UpdatedAt, 0)
395	// 	assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
396	// 	if time.Since(tm) > 1000*time.Millisecond {
397	// 		err := a.messages.Update(ctx, *assistantMsg)
398	// 		assistantMsg.UpdatedAt = time.Now().Unix()
399	// 		return err
400	// 	}
401	case provider.EventToolUseStop:
402		assistantMsg.FinishToolCall(event.ToolCall.ID)
403		return a.messages.Update(ctx, *assistantMsg)
404	case provider.EventError:
405		if errors.Is(event.Error, context.Canceled) {
406			logging.InfoPersist(fmt.Sprintf("Event processing canceled for session: %s", sessionID))
407			return context.Canceled
408		}
409		logging.ErrorPersist(event.Error.Error())
410		return event.Error
411	case provider.EventComplete:
412		assistantMsg.SetToolCalls(event.Response.ToolCalls)
413		assistantMsg.AddFinish(event.Response.FinishReason)
414		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
415			return fmt.Errorf("failed to update message: %w", err)
416		}
417		return a.TrackUsage(ctx, sessionID, a.provider.Model(), event.Response.Usage)
418	}
419
420	return nil
421}
422
423func (a *agent) TrackUsage(ctx context.Context, sessionID string, model models.Model, usage provider.TokenUsage) error {
424	sess, err := a.sessions.Get(ctx, sessionID)
425	if err != nil {
426		return fmt.Errorf("failed to get session: %w", err)
427	}
428
429	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
430		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
431		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
432		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
433
434	sess.Cost += cost
435	sess.CompletionTokens += usage.OutputTokens
436	sess.PromptTokens += usage.InputTokens
437
438	_, err = a.sessions.Save(ctx, sess)
439	if err != nil {
440		return fmt.Errorf("failed to save session: %w", err)
441	}
442	return nil
443}
444
445func (a *agent) Update(agentName config.AgentName, modelID models.ModelID) (models.Model, error) {
446	if a.IsBusy() {
447		return models.Model{}, fmt.Errorf("cannot change model while processing requests")
448	}
449
450	if err := config.UpdateAgentModel(agentName, modelID); err != nil {
451		return models.Model{}, fmt.Errorf("failed to update config: %w", err)
452	}
453
454	provider, err := createAgentProvider(agentName)
455	if err != nil {
456		return models.Model{}, fmt.Errorf("failed to create provider for model %s: %w", modelID, err)
457	}
458
459	a.provider = provider
460
461	return a.provider.Model(), nil
462}
463
464func createAgentProvider(agentName config.AgentName) (provider.Provider, error) {
465	cfg := config.Get()
466	agentConfig, ok := cfg.Agents[agentName]
467	if !ok {
468		return nil, fmt.Errorf("agent %s not found", agentName)
469	}
470	model, ok := models.SupportedModels[agentConfig.Model]
471	if !ok {
472		return nil, fmt.Errorf("model %s not supported", agentConfig.Model)
473	}
474
475	providerCfg, ok := cfg.Providers[model.Provider]
476	if !ok {
477		return nil, fmt.Errorf("provider %s not supported", model.Provider)
478	}
479	if providerCfg.Disabled {
480		return nil, fmt.Errorf("provider %s is not enabled", model.Provider)
481	}
482	maxTokens := model.DefaultMaxTokens
483	if agentConfig.MaxTokens > 0 {
484		maxTokens = agentConfig.MaxTokens
485	}
486	opts := []provider.ProviderClientOption{
487		provider.WithAPIKey(providerCfg.APIKey),
488		provider.WithModel(model),
489		provider.WithSystemMessage(prompt.GetAgentPrompt(agentName, model.Provider)),
490		provider.WithMaxTokens(maxTokens),
491	}
492	if model.Provider == models.ProviderOpenAI && model.CanReason {
493		opts = append(
494			opts,
495			provider.WithOpenAIOptions(
496				provider.WithReasoningEffort(agentConfig.ReasoningEffort),
497			),
498		)
499	} else if model.Provider == models.ProviderAnthropic && model.CanReason && agentName == config.AgentCoder {
500		opts = append(
501			opts,
502			provider.WithAnthropicOptions(
503				provider.WithAnthropicShouldThinkFn(provider.DefaultShouldThinkFn),
504			),
505		)
506	}
507	agentProvider, err := provider.NewProvider(
508		model.Provider,
509		opts...,
510	)
511	if err != nil {
512		return nil, fmt.Errorf("could not create provider: %v", err)
513	}
514
515	return agentProvider, nil
516}