1package agent
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "strings"
8 "sync"
9
10 "github.com/kujtimiihoxha/opencode/internal/config"
11 "github.com/kujtimiihoxha/opencode/internal/llm/models"
12 "github.com/kujtimiihoxha/opencode/internal/llm/prompt"
13 "github.com/kujtimiihoxha/opencode/internal/llm/provider"
14 "github.com/kujtimiihoxha/opencode/internal/llm/tools"
15 "github.com/kujtimiihoxha/opencode/internal/logging"
16 "github.com/kujtimiihoxha/opencode/internal/message"
17 "github.com/kujtimiihoxha/opencode/internal/permission"
18 "github.com/kujtimiihoxha/opencode/internal/session"
19)
20
// ErrRequestCancelled signals that an in-flight request was cancelled before
// it produced a response.
var ErrRequestCancelled = errors.New("request cancelled by user")

// ErrSessionBusy is returned by Run when the session already has a request in
// flight.
var ErrSessionBusy = errors.New("session is currently processing another request")
26
27type AgentEvent struct {
28 message message.Message
29 err error
30}
31
32func (e *AgentEvent) Err() error {
33 return e.err
34}
35
36func (e *AgentEvent) Response() message.Message {
37 return e.message
38}
39
40type Service interface {
41 Run(ctx context.Context, sessionID string, content string) (<-chan AgentEvent, error)
42 Cancel(sessionID string)
43 IsSessionBusy(sessionID string) bool
44 IsBusy() bool
45}
46
47type agent struct {
48 sessions session.Service
49 messages message.Service
50
51 tools []tools.BaseTool
52 provider provider.Provider
53
54 titleProvider provider.Provider
55
56 activeRequests sync.Map
57}
58
59func NewAgent(
60 agentName config.AgentName,
61 sessions session.Service,
62 messages message.Service,
63 agentTools []tools.BaseTool,
64) (Service, error) {
65 agentProvider, err := createAgentProvider(agentName)
66 if err != nil {
67 return nil, err
68 }
69 var titleProvider provider.Provider
70 // Only generate titles for the coder agent
71 if agentName == config.AgentCoder {
72 titleProvider, err = createAgentProvider(config.AgentTitle)
73 if err != nil {
74 return nil, err
75 }
76 }
77
78 agent := &agent{
79 provider: agentProvider,
80 messages: messages,
81 sessions: sessions,
82 tools: agentTools,
83 titleProvider: titleProvider,
84 activeRequests: sync.Map{},
85 }
86
87 return agent, nil
88}
89
90func (a *agent) Cancel(sessionID string) {
91 if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
92 if cancel, ok := cancelFunc.(context.CancelFunc); ok {
93 logging.InfoPersist(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
94 cancel()
95 }
96 }
97}
98
99func (a *agent) IsBusy() bool {
100 busy := false
101 a.activeRequests.Range(func(key, value interface{}) bool {
102 if cancelFunc, ok := value.(context.CancelFunc); ok {
103 if cancelFunc != nil {
104 busy = true
105 return false // Stop iterating
106 }
107 }
108 return true // Continue iterating
109 })
110 return busy
111}
112
113func (a *agent) IsSessionBusy(sessionID string) bool {
114 _, busy := a.activeRequests.Load(sessionID)
115 return busy
116}
117
118func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
119 if a.titleProvider == nil {
120 return nil
121 }
122 session, err := a.sessions.Get(ctx, sessionID)
123 if err != nil {
124 return err
125 }
126 response, err := a.titleProvider.SendMessages(
127 ctx,
128 []message.Message{
129 {
130 Role: message.User,
131 Parts: []message.ContentPart{
132 message.TextContent{
133 Text: content,
134 },
135 },
136 },
137 },
138 make([]tools.BaseTool, 0),
139 )
140 if err != nil {
141 return err
142 }
143
144 title := strings.TrimSpace(strings.ReplaceAll(response.Content, "\n", " "))
145 if title == "" {
146 return nil
147 }
148
149 session.Title = title
150 _, err = a.sessions.Save(ctx, session)
151 return err
152}
153
154func (a *agent) err(err error) AgentEvent {
155 return AgentEvent{
156 err: err,
157 }
158}
159
160func (a *agent) Run(ctx context.Context, sessionID string, content string) (<-chan AgentEvent, error) {
161 events := make(chan AgentEvent)
162 if a.IsSessionBusy(sessionID) {
163 return nil, ErrSessionBusy
164 }
165
166 genCtx, cancel := context.WithCancel(ctx)
167
168 a.activeRequests.Store(sessionID, cancel)
169 go func() {
170 logging.Debug("Request started", "sessionID", sessionID)
171 defer logging.RecoverPanic("agent.Run", func() {
172 events <- a.err(fmt.Errorf("panic while running the agent"))
173 })
174
175 result := a.processGeneration(genCtx, sessionID, content)
176 if result.Err() != nil && !errors.Is(result.Err(), ErrRequestCancelled) && !errors.Is(result.Err(), context.Canceled) {
177 logging.ErrorPersist(fmt.Sprintf("Generation error for session %s: %v", sessionID, result))
178 }
179 logging.Debug("Request completed", "sessionID", sessionID)
180 a.activeRequests.Delete(sessionID)
181 cancel()
182 events <- result
183 close(events)
184 }()
185 return events, nil
186}
187
188func (a *agent) processGeneration(ctx context.Context, sessionID, content string) AgentEvent {
189 // List existing messages; if none, start title generation asynchronously.
190 msgs, err := a.messages.List(ctx, sessionID)
191 if err != nil {
192 return a.err(fmt.Errorf("failed to list messages: %w", err))
193 }
194 if len(msgs) == 0 {
195 go func() {
196 defer logging.RecoverPanic("agent.Run", func() {
197 logging.ErrorPersist("panic while generating title")
198 })
199 titleErr := a.generateTitle(context.Background(), sessionID, content)
200 if titleErr != nil {
201 logging.ErrorPersist(fmt.Sprintf("failed to generate title: %v", titleErr))
202 }
203 }()
204 }
205
206 userMsg, err := a.createUserMessage(ctx, sessionID, content)
207 if err != nil {
208 return a.err(fmt.Errorf("failed to create user message: %w", err))
209 }
210
211 // Append the new user message to the conversation history.
212 msgHistory := append(msgs, userMsg)
213 for {
214 // Check for cancellation before each iteration
215 select {
216 case <-ctx.Done():
217 return a.err(ctx.Err())
218 default:
219 // Continue processing
220 }
221 agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
222 if err != nil {
223 if errors.Is(err, context.Canceled) {
224 return a.err(ErrRequestCancelled)
225 }
226 return a.err(fmt.Errorf("failed to process events: %w", err))
227 }
228 logging.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
229 if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
230 // We are not done, we need to respond with the tool response
231 msgHistory = append(msgHistory, agentMessage, *toolResults)
232 continue
233 }
234 return AgentEvent{
235 message: agentMessage,
236 }
237 }
238}
239
240func (a *agent) createUserMessage(ctx context.Context, sessionID, content string) (message.Message, error) {
241 return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
242 Role: message.User,
243 Parts: []message.ContentPart{
244 message.TextContent{Text: content},
245 },
246 })
247}
248
249func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
250 eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)
251
252 assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
253 Role: message.Assistant,
254 Parts: []message.ContentPart{},
255 Model: a.provider.Model().ID,
256 })
257 if err != nil {
258 return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
259 }
260
261 // Add the session and message ID into the context if needed by tools.
262 ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
263 ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
264
265 // Process each event in the stream.
266 for event := range eventChan {
267 if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
268 a.finishMessage(ctx, &assistantMsg, message.FinishReasonCanceled)
269 return assistantMsg, nil, processErr
270 }
271 if ctx.Err() != nil {
272 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
273 return assistantMsg, nil, ctx.Err()
274 }
275 }
276
277 toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
278 toolCalls := assistantMsg.ToolCalls()
279 for i, toolCall := range toolCalls {
280 select {
281 case <-ctx.Done():
282 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
283 // Make all future tool calls cancelled
284 for j := i; j < len(toolCalls); j++ {
285 toolResults[j] = message.ToolResult{
286 ToolCallID: toolCalls[j].ID,
287 Content: "Tool execution canceled by user",
288 IsError: true,
289 }
290 }
291 goto out
292 default:
293 // Continue processing
294 var tool tools.BaseTool
295 for _, availableTools := range a.tools {
296 if availableTools.Info().Name == toolCall.Name {
297 tool = availableTools
298 }
299 }
300
301 // Tool not found
302 if tool == nil {
303 toolResults[i] = message.ToolResult{
304 ToolCallID: toolCall.ID,
305 Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
306 IsError: true,
307 }
308 continue
309 }
310
311 toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
312 ID: toolCall.ID,
313 Name: toolCall.Name,
314 Input: toolCall.Input,
315 })
316 if toolErr != nil {
317 if errors.Is(toolErr, permission.ErrorPermissionDenied) {
318 toolResults[i] = message.ToolResult{
319 ToolCallID: toolCall.ID,
320 Content: "Permission denied",
321 IsError: true,
322 }
323 for j := i + 1; j < len(toolCalls); j++ {
324 toolResults[j] = message.ToolResult{
325 ToolCallID: toolCalls[j].ID,
326 Content: "Tool execution canceled by user",
327 IsError: true,
328 }
329 }
330 a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied)
331 break
332 }
333 }
334 toolResults[i] = message.ToolResult{
335 ToolCallID: toolCall.ID,
336 Content: toolResult.Content,
337 Metadata: toolResult.Metadata,
338 IsError: toolResult.IsError,
339 }
340 }
341 }
342out:
343 if len(toolResults) == 0 {
344 return assistantMsg, nil, nil
345 }
346 parts := make([]message.ContentPart, 0)
347 for _, tr := range toolResults {
348 parts = append(parts, tr)
349 }
350 msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
351 Role: message.Tool,
352 Parts: parts,
353 })
354 if err != nil {
355 return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
356 }
357
358 return assistantMsg, &msg, err
359}
360
361func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReson message.FinishReason) {
362 msg.AddFinish(finishReson)
363 _ = a.messages.Update(ctx, *msg)
364}
365
366func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
367 select {
368 case <-ctx.Done():
369 return ctx.Err()
370 default:
371 // Continue processing.
372 }
373
374 switch event.Type {
375 case provider.EventThinkingDelta:
376 assistantMsg.AppendReasoningContent(event.Content)
377 return a.messages.Update(ctx, *assistantMsg)
378 case provider.EventContentDelta:
379 assistantMsg.AppendContent(event.Content)
380 return a.messages.Update(ctx, *assistantMsg)
381 case provider.EventError:
382 if errors.Is(event.Error, context.Canceled) {
383 logging.InfoPersist(fmt.Sprintf("Event processing canceled for session: %s", sessionID))
384 return context.Canceled
385 }
386 logging.ErrorPersist(event.Error.Error())
387 return event.Error
388 case provider.EventComplete:
389 assistantMsg.SetToolCalls(event.Response.ToolCalls)
390 assistantMsg.AddFinish(event.Response.FinishReason)
391 if err := a.messages.Update(ctx, *assistantMsg); err != nil {
392 return fmt.Errorf("failed to update message: %w", err)
393 }
394 return a.TrackUsage(ctx, sessionID, a.provider.Model(), event.Response.Usage)
395 }
396
397 return nil
398}
399
400func (a *agent) TrackUsage(ctx context.Context, sessionID string, model models.Model, usage provider.TokenUsage) error {
401 sess, err := a.sessions.Get(ctx, sessionID)
402 if err != nil {
403 return fmt.Errorf("failed to get session: %w", err)
404 }
405
406 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
407 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
408 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
409 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
410
411 sess.Cost += cost
412 sess.CompletionTokens += usage.OutputTokens
413 sess.PromptTokens += usage.InputTokens
414
415 _, err = a.sessions.Save(ctx, sess)
416 if err != nil {
417 return fmt.Errorf("failed to save session: %w", err)
418 }
419 return nil
420}
421
422func createAgentProvider(agentName config.AgentName) (provider.Provider, error) {
423 cfg := config.Get()
424 agentConfig, ok := cfg.Agents[agentName]
425 if !ok {
426 return nil, fmt.Errorf("agent %s not found", agentName)
427 }
428 model, ok := models.SupportedModels[agentConfig.Model]
429 if !ok {
430 return nil, fmt.Errorf("model %s not supported", agentConfig.Model)
431 }
432
433 providerCfg, ok := cfg.Providers[model.Provider]
434 if !ok {
435 return nil, fmt.Errorf("provider %s not supported", model.Provider)
436 }
437 if providerCfg.Disabled {
438 return nil, fmt.Errorf("provider %s is not enabled", model.Provider)
439 }
440 maxTokens := model.DefaultMaxTokens
441 if agentConfig.MaxTokens > 0 {
442 maxTokens = agentConfig.MaxTokens
443 }
444 opts := []provider.ProviderClientOption{
445 provider.WithAPIKey(providerCfg.APIKey),
446 provider.WithModel(model),
447 provider.WithSystemMessage(prompt.GetAgentPrompt(agentName, model.Provider)),
448 provider.WithMaxTokens(maxTokens),
449 }
450 if model.Provider == models.ProviderOpenAI && model.CanReason {
451 opts = append(
452 opts,
453 provider.WithOpenAIOptions(
454 provider.WithReasoningEffort(agentConfig.ReasoningEffort),
455 ),
456 )
457 }
458 agentProvider, err := provider.NewProvider(
459 model.Provider,
460 opts...,
461 )
462 if err != nil {
463 return nil, fmt.Errorf("could not create provider: %v", err)
464 }
465
466 return agentProvider, nil
467}