agent.go

  1package agent
  2
  3import (
  4	"context"
  5	"errors"
  6	"fmt"
  7	"strings"
  8	"sync"
  9
 10	"github.com/kujtimiihoxha/opencode/internal/config"
 11	"github.com/kujtimiihoxha/opencode/internal/llm/models"
 12	"github.com/kujtimiihoxha/opencode/internal/llm/prompt"
 13	"github.com/kujtimiihoxha/opencode/internal/llm/provider"
 14	"github.com/kujtimiihoxha/opencode/internal/llm/tools"
 15	"github.com/kujtimiihoxha/opencode/internal/logging"
 16	"github.com/kujtimiihoxha/opencode/internal/message"
 17	"github.com/kujtimiihoxha/opencode/internal/permission"
 18	"github.com/kujtimiihoxha/opencode/internal/session"
 19)
 20
 21// Common errors
 22var (
 23	ErrRequestCancelled = errors.New("request cancelled by user")
 24	ErrSessionBusy      = errors.New("session is currently processing another request")
 25)
 26
 27type AgentEvent struct {
 28	message message.Message
 29	err     error
 30}
 31
 32func (e *AgentEvent) Err() error {
 33	return e.err
 34}
 35
 36func (e *AgentEvent) Response() message.Message {
 37	return e.message
 38}
 39
 40type Service interface {
 41	Run(ctx context.Context, sessionID string, content string) (<-chan AgentEvent, error)
 42	Cancel(sessionID string)
 43	IsSessionBusy(sessionID string) bool
 44	IsBusy() bool
 45}
 46
 47type agent struct {
 48	sessions session.Service
 49	messages message.Service
 50
 51	tools    []tools.BaseTool
 52	provider provider.Provider
 53
 54	titleProvider provider.Provider
 55
 56	activeRequests sync.Map
 57}
 58
 59func NewAgent(
 60	agentName config.AgentName,
 61	sessions session.Service,
 62	messages message.Service,
 63	agentTools []tools.BaseTool,
 64) (Service, error) {
 65	agentProvider, err := createAgentProvider(agentName)
 66	if err != nil {
 67		return nil, err
 68	}
 69	var titleProvider provider.Provider
 70	// Only generate titles for the coder agent
 71	if agentName == config.AgentCoder {
 72		titleProvider, err = createAgentProvider(config.AgentTitle)
 73		if err != nil {
 74			return nil, err
 75		}
 76	}
 77
 78	agent := &agent{
 79		provider:       agentProvider,
 80		messages:       messages,
 81		sessions:       sessions,
 82		tools:          agentTools,
 83		titleProvider:  titleProvider,
 84		activeRequests: sync.Map{},
 85	}
 86
 87	return agent, nil
 88}
 89
 90func (a *agent) Cancel(sessionID string) {
 91	if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
 92		if cancel, ok := cancelFunc.(context.CancelFunc); ok {
 93			logging.InfoPersist(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
 94			cancel()
 95		}
 96	}
 97}
 98
 99func (a *agent) IsBusy() bool {
100	busy := false
101	a.activeRequests.Range(func(key, value interface{}) bool {
102		if cancelFunc, ok := value.(context.CancelFunc); ok {
103			if cancelFunc != nil {
104				busy = true
105				return false // Stop iterating
106			}
107		}
108		return true // Continue iterating
109	})
110	return busy
111}
112
113func (a *agent) IsSessionBusy(sessionID string) bool {
114	_, busy := a.activeRequests.Load(sessionID)
115	return busy
116}
117
118func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
119	if a.titleProvider == nil {
120		return nil
121	}
122	session, err := a.sessions.Get(ctx, sessionID)
123	if err != nil {
124		return err
125	}
126	response, err := a.titleProvider.SendMessages(
127		ctx,
128		[]message.Message{
129			{
130				Role: message.User,
131				Parts: []message.ContentPart{
132					message.TextContent{
133						Text: content,
134					},
135				},
136			},
137		},
138		make([]tools.BaseTool, 0),
139	)
140	if err != nil {
141		return err
142	}
143
144	title := strings.TrimSpace(strings.ReplaceAll(response.Content, "\n", " "))
145	if title == "" {
146		return nil
147	}
148
149	session.Title = title
150	_, err = a.sessions.Save(ctx, session)
151	return err
152}
153
154func (a *agent) err(err error) AgentEvent {
155	return AgentEvent{
156		err: err,
157	}
158}
159
160func (a *agent) Run(ctx context.Context, sessionID string, content string) (<-chan AgentEvent, error) {
161	events := make(chan AgentEvent)
162	if a.IsSessionBusy(sessionID) {
163		return nil, ErrSessionBusy
164	}
165
166	genCtx, cancel := context.WithCancel(ctx)
167
168	a.activeRequests.Store(sessionID, cancel)
169	go func() {
170		logging.Debug("Request started", "sessionID", sessionID)
171		defer logging.RecoverPanic("agent.Run", func() {
172			events <- a.err(fmt.Errorf("panic while running the agent"))
173		})
174
175		result := a.processGeneration(genCtx, sessionID, content)
176		if result.Err() != nil && !errors.Is(result.Err(), ErrRequestCancelled) && !errors.Is(result.Err(), context.Canceled) {
177			logging.ErrorPersist(fmt.Sprintf("Generation error for session %s: %v", sessionID, result))
178		}
179		logging.Debug("Request completed", "sessionID", sessionID)
180		a.activeRequests.Delete(sessionID)
181		cancel()
182		events <- result
183		close(events)
184	}()
185	return events, nil
186}
187
188func (a *agent) processGeneration(ctx context.Context, sessionID, content string) AgentEvent {
189	// List existing messages; if none, start title generation asynchronously.
190	msgs, err := a.messages.List(ctx, sessionID)
191	if err != nil {
192		return a.err(fmt.Errorf("failed to list messages: %w", err))
193	}
194	if len(msgs) == 0 {
195		go func() {
196			defer logging.RecoverPanic("agent.Run", func() {
197				logging.ErrorPersist("panic while generating title")
198			})
199			titleErr := a.generateTitle(context.Background(), sessionID, content)
200			if titleErr != nil {
201				logging.ErrorPersist(fmt.Sprintf("failed to generate title: %v", titleErr))
202			}
203		}()
204	}
205
206	userMsg, err := a.createUserMessage(ctx, sessionID, content)
207	if err != nil {
208		return a.err(fmt.Errorf("failed to create user message: %w", err))
209	}
210
211	// Append the new user message to the conversation history.
212	msgHistory := append(msgs, userMsg)
213	for {
214		// Check for cancellation before each iteration
215		select {
216		case <-ctx.Done():
217			return a.err(ctx.Err())
218		default:
219			// Continue processing
220		}
221		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
222		if err != nil {
223			if errors.Is(err, context.Canceled) {
224				return a.err(ErrRequestCancelled)
225			}
226			return a.err(fmt.Errorf("failed to process events: %w", err))
227		}
228		logging.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
229		if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
230			// We are not done, we need to respond with the tool response
231			msgHistory = append(msgHistory, agentMessage, *toolResults)
232			continue
233		}
234		return AgentEvent{
235			message: agentMessage,
236		}
237	}
238}
239
240func (a *agent) createUserMessage(ctx context.Context, sessionID, content string) (message.Message, error) {
241	return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
242		Role: message.User,
243		Parts: []message.ContentPart{
244			message.TextContent{Text: content},
245		},
246	})
247}
248
249func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
250	eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)
251
252	assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
253		Role:  message.Assistant,
254		Parts: []message.ContentPart{},
255		Model: a.provider.Model().ID,
256	})
257	if err != nil {
258		return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
259	}
260
261	// Add the session and message ID into the context if needed by tools.
262	ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
263	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
264
265	// Process each event in the stream.
266	for event := range eventChan {
267		if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
268			a.finishMessage(ctx, &assistantMsg, message.FinishReasonCanceled)
269			return assistantMsg, nil, processErr
270		}
271		if ctx.Err() != nil {
272			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
273			return assistantMsg, nil, ctx.Err()
274		}
275	}
276
277	toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
278	toolCalls := assistantMsg.ToolCalls()
279	for i, toolCall := range toolCalls {
280		select {
281		case <-ctx.Done():
282			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
283			// Make all future tool calls cancelled
284			for j := i; j < len(toolCalls); j++ {
285				toolResults[j] = message.ToolResult{
286					ToolCallID: toolCalls[j].ID,
287					Content:    "Tool execution canceled by user",
288					IsError:    true,
289				}
290			}
291			goto out
292		default:
293			// Continue processing
294			var tool tools.BaseTool
295			for _, availableTools := range a.tools {
296				if availableTools.Info().Name == toolCall.Name {
297					tool = availableTools
298				}
299			}
300
301			// Tool not found
302			if tool == nil {
303				toolResults[i] = message.ToolResult{
304					ToolCallID: toolCall.ID,
305					Content:    fmt.Sprintf("Tool not found: %s", toolCall.Name),
306					IsError:    true,
307				}
308				continue
309			}
310
311			toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
312				ID:    toolCall.ID,
313				Name:  toolCall.Name,
314				Input: toolCall.Input,
315			})
316			if toolErr != nil {
317				if errors.Is(toolErr, permission.ErrorPermissionDenied) {
318					toolResults[i] = message.ToolResult{
319						ToolCallID: toolCall.ID,
320						Content:    "Permission denied",
321						IsError:    true,
322					}
323					for j := i + 1; j < len(toolCalls); j++ {
324						toolResults[j] = message.ToolResult{
325							ToolCallID: toolCalls[j].ID,
326							Content:    "Tool execution canceled by user",
327							IsError:    true,
328						}
329					}
330					a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied)
331					break
332				}
333			}
334			toolResults[i] = message.ToolResult{
335				ToolCallID: toolCall.ID,
336				Content:    toolResult.Content,
337				Metadata:   toolResult.Metadata,
338				IsError:    toolResult.IsError,
339			}
340		}
341	}
342out:
343	if len(toolResults) == 0 {
344		return assistantMsg, nil, nil
345	}
346	parts := make([]message.ContentPart, 0)
347	for _, tr := range toolResults {
348		parts = append(parts, tr)
349	}
350	msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
351		Role:  message.Tool,
352		Parts: parts,
353	})
354	if err != nil {
355		return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
356	}
357
358	return assistantMsg, &msg, err
359}
360
361func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReson message.FinishReason) {
362	msg.AddFinish(finishReson)
363	_ = a.messages.Update(ctx, *msg)
364}
365
366func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
367	select {
368	case <-ctx.Done():
369		return ctx.Err()
370	default:
371		// Continue processing.
372	}
373
374	switch event.Type {
375	case provider.EventThinkingDelta:
376		assistantMsg.AppendReasoningContent(event.Content)
377		return a.messages.Update(ctx, *assistantMsg)
378	case provider.EventContentDelta:
379		assistantMsg.AppendContent(event.Content)
380		return a.messages.Update(ctx, *assistantMsg)
381	case provider.EventError:
382		if errors.Is(event.Error, context.Canceled) {
383			logging.InfoPersist(fmt.Sprintf("Event processing canceled for session: %s", sessionID))
384			return context.Canceled
385		}
386		logging.ErrorPersist(event.Error.Error())
387		return event.Error
388	case provider.EventComplete:
389		assistantMsg.SetToolCalls(event.Response.ToolCalls)
390		assistantMsg.AddFinish(event.Response.FinishReason)
391		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
392			return fmt.Errorf("failed to update message: %w", err)
393		}
394		return a.TrackUsage(ctx, sessionID, a.provider.Model(), event.Response.Usage)
395	}
396
397	return nil
398}
399
400func (a *agent) TrackUsage(ctx context.Context, sessionID string, model models.Model, usage provider.TokenUsage) error {
401	sess, err := a.sessions.Get(ctx, sessionID)
402	if err != nil {
403		return fmt.Errorf("failed to get session: %w", err)
404	}
405
406	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
407		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
408		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
409		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
410
411	sess.Cost += cost
412	sess.CompletionTokens += usage.OutputTokens
413	sess.PromptTokens += usage.InputTokens
414
415	_, err = a.sessions.Save(ctx, sess)
416	if err != nil {
417		return fmt.Errorf("failed to save session: %w", err)
418	}
419	return nil
420}
421
422func createAgentProvider(agentName config.AgentName) (provider.Provider, error) {
423	cfg := config.Get()
424	agentConfig, ok := cfg.Agents[agentName]
425	if !ok {
426		return nil, fmt.Errorf("agent %s not found", agentName)
427	}
428	model, ok := models.SupportedModels[agentConfig.Model]
429	if !ok {
430		return nil, fmt.Errorf("model %s not supported", agentConfig.Model)
431	}
432
433	providerCfg, ok := cfg.Providers[model.Provider]
434	if !ok {
435		return nil, fmt.Errorf("provider %s not supported", model.Provider)
436	}
437	if providerCfg.Disabled {
438		return nil, fmt.Errorf("provider %s is not enabled", model.Provider)
439	}
440	maxTokens := model.DefaultMaxTokens
441	if agentConfig.MaxTokens > 0 {
442		maxTokens = agentConfig.MaxTokens
443	}
444	opts := []provider.ProviderClientOption{
445		provider.WithAPIKey(providerCfg.APIKey),
446		provider.WithModel(model),
447		provider.WithSystemMessage(prompt.GetAgentPrompt(agentName, model.Provider)),
448		provider.WithMaxTokens(maxTokens),
449	}
450	if model.Provider == models.ProviderOpenAI && model.CanReason {
451		opts = append(
452			opts,
453			provider.WithOpenAIOptions(
454				provider.WithReasoningEffort(agentConfig.ReasoningEffort),
455			),
456		)
457	}
458	agentProvider, err := provider.NewProvider(
459		model.Provider,
460		opts...,
461	)
462	if err != nil {
463		return nil, fmt.Errorf("could not create provider: %v", err)
464	}
465
466	return agentProvider, nil
467}