1package agent
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "strings"
8 "sync"
9
10 "github.com/opencode-ai/opencode/internal/config"
11 "github.com/opencode-ai/opencode/internal/llm/models"
12 "github.com/opencode-ai/opencode/internal/llm/prompt"
13 "github.com/opencode-ai/opencode/internal/llm/provider"
14 "github.com/opencode-ai/opencode/internal/llm/tools"
15 "github.com/opencode-ai/opencode/internal/logging"
16 "github.com/opencode-ai/opencode/internal/message"
17 "github.com/opencode-ai/opencode/internal/permission"
18 "github.com/opencode-ai/opencode/internal/session"
19)
20
// Sentinel errors reported by the agent; compare against them with errors.Is.
var (
	// ErrRequestCancelled is returned (inside an AgentEvent) when a request's
	// context is cancelled while a response is being streamed.
	ErrRequestCancelled = errors.New("request cancelled by user")

	// ErrSessionBusy is returned by Run when the session already has a
	// request in flight.
	ErrSessionBusy = errors.New("session is currently processing another request")
)
26
27type AgentEvent struct {
28 message message.Message
29 err error
30}
31
32func (e *AgentEvent) Err() error {
33 return e.err
34}
35
36func (e *AgentEvent) Response() message.Message {
37 return e.message
38}
39
40type Service interface {
41 Run(ctx context.Context, sessionID string, content string) (<-chan AgentEvent, error)
42 Cancel(sessionID string)
43 IsSessionBusy(sessionID string) bool
44 IsBusy() bool
45}
46
47type agent struct {
48 sessions session.Service
49 messages message.Service
50
51 tools []tools.BaseTool
52 provider provider.Provider
53
54 titleProvider provider.Provider
55
56 activeRequests sync.Map
57}
58
59func NewAgent(
60 agentName config.AgentName,
61 sessions session.Service,
62 messages message.Service,
63 agentTools []tools.BaseTool,
64) (Service, error) {
65 agentProvider, err := createAgentProvider(agentName)
66 if err != nil {
67 return nil, err
68 }
69 var titleProvider provider.Provider
70 // Only generate titles for the coder agent
71 if agentName == config.AgentCoder {
72 titleProvider, err = createAgentProvider(config.AgentTitle)
73 if err != nil {
74 return nil, err
75 }
76 }
77
78 agent := &agent{
79 provider: agentProvider,
80 messages: messages,
81 sessions: sessions,
82 tools: agentTools,
83 titleProvider: titleProvider,
84 activeRequests: sync.Map{},
85 }
86
87 return agent, nil
88}
89
90func (a *agent) Cancel(sessionID string) {
91 if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
92 if cancel, ok := cancelFunc.(context.CancelFunc); ok {
93 logging.InfoPersist(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
94 cancel()
95 }
96 }
97}
98
99func (a *agent) IsBusy() bool {
100 busy := false
101 a.activeRequests.Range(func(key, value interface{}) bool {
102 if cancelFunc, ok := value.(context.CancelFunc); ok {
103 if cancelFunc != nil {
104 busy = true
105 return false // Stop iterating
106 }
107 }
108 return true // Continue iterating
109 })
110 return busy
111}
112
113func (a *agent) IsSessionBusy(sessionID string) bool {
114 _, busy := a.activeRequests.Load(sessionID)
115 return busy
116}
117
118func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
119 if a.titleProvider == nil {
120 return nil
121 }
122 session, err := a.sessions.Get(ctx, sessionID)
123 if err != nil {
124 return err
125 }
126 response, err := a.titleProvider.SendMessages(
127 ctx,
128 []message.Message{
129 {
130 Role: message.User,
131 Parts: []message.ContentPart{
132 message.TextContent{
133 Text: content,
134 },
135 },
136 },
137 },
138 make([]tools.BaseTool, 0),
139 )
140 if err != nil {
141 return err
142 }
143
144 title := strings.TrimSpace(strings.ReplaceAll(response.Content, "\n", " "))
145 if title == "" {
146 return nil
147 }
148
149 session.Title = title
150 _, err = a.sessions.Save(ctx, session)
151 return err
152}
153
154func (a *agent) err(err error) AgentEvent {
155 return AgentEvent{
156 err: err,
157 }
158}
159
160func (a *agent) Run(ctx context.Context, sessionID string, content string) (<-chan AgentEvent, error) {
161 events := make(chan AgentEvent)
162 if a.IsSessionBusy(sessionID) {
163 return nil, ErrSessionBusy
164 }
165
166 genCtx, cancel := context.WithCancel(ctx)
167
168 a.activeRequests.Store(sessionID, cancel)
169 go func() {
170 logging.Debug("Request started", "sessionID", sessionID)
171 defer logging.RecoverPanic("agent.Run", func() {
172 events <- a.err(fmt.Errorf("panic while running the agent"))
173 })
174
175 result := a.processGeneration(genCtx, sessionID, content)
176 if result.Err() != nil && !errors.Is(result.Err(), ErrRequestCancelled) && !errors.Is(result.Err(), context.Canceled) {
177 logging.ErrorPersist(fmt.Sprintf("Generation error for session %s: %v", sessionID, result))
178 }
179 logging.Debug("Request completed", "sessionID", sessionID)
180 a.activeRequests.Delete(sessionID)
181 cancel()
182 events <- result
183 close(events)
184 }()
185 return events, nil
186}
187
188func (a *agent) processGeneration(ctx context.Context, sessionID, content string) AgentEvent {
189 // List existing messages; if none, start title generation asynchronously.
190 msgs, err := a.messages.List(ctx, sessionID)
191 if err != nil {
192 return a.err(fmt.Errorf("failed to list messages: %w", err))
193 }
194 if len(msgs) == 0 {
195 go func() {
196 defer logging.RecoverPanic("agent.Run", func() {
197 logging.ErrorPersist("panic while generating title")
198 })
199 titleErr := a.generateTitle(context.Background(), sessionID, content)
200 if titleErr != nil {
201 logging.ErrorPersist(fmt.Sprintf("failed to generate title: %v", titleErr))
202 }
203 }()
204 }
205
206 userMsg, err := a.createUserMessage(ctx, sessionID, content)
207 if err != nil {
208 return a.err(fmt.Errorf("failed to create user message: %w", err))
209 }
210
211 // Append the new user message to the conversation history.
212 msgHistory := append(msgs, userMsg)
213 for {
214 // Check for cancellation before each iteration
215 select {
216 case <-ctx.Done():
217 return a.err(ctx.Err())
218 default:
219 // Continue processing
220 }
221 agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
222 if err != nil {
223 if errors.Is(err, context.Canceled) {
224 agentMessage.AddFinish(message.FinishReasonCanceled)
225 a.messages.Update(context.Background(), agentMessage)
226 return a.err(ErrRequestCancelled)
227 }
228 return a.err(fmt.Errorf("failed to process events: %w", err))
229 }
230 logging.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
231 if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
232 // We are not done, we need to respond with the tool response
233 msgHistory = append(msgHistory, agentMessage, *toolResults)
234 continue
235 }
236 return AgentEvent{
237 message: agentMessage,
238 }
239 }
240}
241
242func (a *agent) createUserMessage(ctx context.Context, sessionID, content string) (message.Message, error) {
243 return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
244 Role: message.User,
245 Parts: []message.ContentPart{
246 message.TextContent{Text: content},
247 },
248 })
249}
250
251func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
252 eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)
253
254 assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
255 Role: message.Assistant,
256 Parts: []message.ContentPart{},
257 Model: a.provider.Model().ID,
258 })
259 if err != nil {
260 return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
261 }
262
263 // Add the session and message ID into the context if needed by tools.
264 ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
265 ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
266
267 // Process each event in the stream.
268 for event := range eventChan {
269 if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
270 a.finishMessage(ctx, &assistantMsg, message.FinishReasonCanceled)
271 return assistantMsg, nil, processErr
272 }
273 if ctx.Err() != nil {
274 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
275 return assistantMsg, nil, ctx.Err()
276 }
277 }
278
279 toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
280 toolCalls := assistantMsg.ToolCalls()
281 for i, toolCall := range toolCalls {
282 select {
283 case <-ctx.Done():
284 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
285 // Make all future tool calls cancelled
286 for j := i; j < len(toolCalls); j++ {
287 toolResults[j] = message.ToolResult{
288 ToolCallID: toolCalls[j].ID,
289 Content: "Tool execution canceled by user",
290 IsError: true,
291 }
292 }
293 goto out
294 default:
295 // Continue processing
296 var tool tools.BaseTool
297 for _, availableTools := range a.tools {
298 if availableTools.Info().Name == toolCall.Name {
299 tool = availableTools
300 }
301 }
302
303 // Tool not found
304 if tool == nil {
305 toolResults[i] = message.ToolResult{
306 ToolCallID: toolCall.ID,
307 Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
308 IsError: true,
309 }
310 continue
311 }
312
313 toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
314 ID: toolCall.ID,
315 Name: toolCall.Name,
316 Input: toolCall.Input,
317 })
318 if toolErr != nil {
319 if errors.Is(toolErr, permission.ErrorPermissionDenied) {
320 toolResults[i] = message.ToolResult{
321 ToolCallID: toolCall.ID,
322 Content: "Permission denied",
323 IsError: true,
324 }
325 for j := i + 1; j < len(toolCalls); j++ {
326 toolResults[j] = message.ToolResult{
327 ToolCallID: toolCalls[j].ID,
328 Content: "Tool execution canceled by user",
329 IsError: true,
330 }
331 }
332 a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied)
333 break
334 }
335 }
336 toolResults[i] = message.ToolResult{
337 ToolCallID: toolCall.ID,
338 Content: toolResult.Content,
339 Metadata: toolResult.Metadata,
340 IsError: toolResult.IsError,
341 }
342 }
343 }
344out:
345 if len(toolResults) == 0 {
346 return assistantMsg, nil, nil
347 }
348 parts := make([]message.ContentPart, 0)
349 for _, tr := range toolResults {
350 parts = append(parts, tr)
351 }
352 msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
353 Role: message.Tool,
354 Parts: parts,
355 })
356 if err != nil {
357 return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
358 }
359
360 return assistantMsg, &msg, err
361}
362
363func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReson message.FinishReason) {
364 msg.AddFinish(finishReson)
365 _ = a.messages.Update(ctx, *msg)
366}
367
368func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
369 select {
370 case <-ctx.Done():
371 return ctx.Err()
372 default:
373 // Continue processing.
374 }
375
376 switch event.Type {
377 case provider.EventThinkingDelta:
378 assistantMsg.AppendReasoningContent(event.Content)
379 return a.messages.Update(ctx, *assistantMsg)
380 case provider.EventContentDelta:
381 assistantMsg.AppendContent(event.Content)
382 return a.messages.Update(ctx, *assistantMsg)
383 case provider.EventToolUseStart:
384 assistantMsg.AddToolCall(*event.ToolCall)
385 return a.messages.Update(ctx, *assistantMsg)
386 // TODO: see how to handle this
387 // case provider.EventToolUseDelta:
388 // tm := time.Unix(assistantMsg.UpdatedAt, 0)
389 // assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
390 // if time.Since(tm) > 1000*time.Millisecond {
391 // err := a.messages.Update(ctx, *assistantMsg)
392 // assistantMsg.UpdatedAt = time.Now().Unix()
393 // return err
394 // }
395 case provider.EventToolUseStop:
396 assistantMsg.FinishToolCall(event.ToolCall.ID)
397 return a.messages.Update(ctx, *assistantMsg)
398 case provider.EventError:
399 if errors.Is(event.Error, context.Canceled) {
400 logging.InfoPersist(fmt.Sprintf("Event processing canceled for session: %s", sessionID))
401 return context.Canceled
402 }
403 logging.ErrorPersist(event.Error.Error())
404 return event.Error
405 case provider.EventComplete:
406 assistantMsg.SetToolCalls(event.Response.ToolCalls)
407 assistantMsg.AddFinish(event.Response.FinishReason)
408 if err := a.messages.Update(ctx, *assistantMsg); err != nil {
409 return fmt.Errorf("failed to update message: %w", err)
410 }
411 return a.TrackUsage(ctx, sessionID, a.provider.Model(), event.Response.Usage)
412 }
413
414 return nil
415}
416
417func (a *agent) TrackUsage(ctx context.Context, sessionID string, model models.Model, usage provider.TokenUsage) error {
418 sess, err := a.sessions.Get(ctx, sessionID)
419 if err != nil {
420 return fmt.Errorf("failed to get session: %w", err)
421 }
422
423 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
424 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
425 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
426 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
427
428 sess.Cost += cost
429 sess.CompletionTokens += usage.OutputTokens
430 sess.PromptTokens += usage.InputTokens
431
432 _, err = a.sessions.Save(ctx, sess)
433 if err != nil {
434 return fmt.Errorf("failed to save session: %w", err)
435 }
436 return nil
437}
438
439func createAgentProvider(agentName config.AgentName) (provider.Provider, error) {
440 cfg := config.Get()
441 agentConfig, ok := cfg.Agents[agentName]
442 if !ok {
443 return nil, fmt.Errorf("agent %s not found", agentName)
444 }
445 model, ok := models.SupportedModels[agentConfig.Model]
446 if !ok {
447 return nil, fmt.Errorf("model %s not supported", agentConfig.Model)
448 }
449
450 providerCfg, ok := cfg.Providers[model.Provider]
451 if !ok {
452 return nil, fmt.Errorf("provider %s not supported", model.Provider)
453 }
454 if providerCfg.Disabled {
455 return nil, fmt.Errorf("provider %s is not enabled", model.Provider)
456 }
457 maxTokens := model.DefaultMaxTokens
458 if agentConfig.MaxTokens > 0 {
459 maxTokens = agentConfig.MaxTokens
460 }
461 opts := []provider.ProviderClientOption{
462 provider.WithAPIKey(providerCfg.APIKey),
463 provider.WithModel(model),
464 provider.WithSystemMessage(prompt.GetAgentPrompt(agentName, model.Provider)),
465 provider.WithMaxTokens(maxTokens),
466 }
467 if model.Provider == models.ProviderOpenAI && model.CanReason {
468 opts = append(
469 opts,
470 provider.WithOpenAIOptions(
471 provider.WithReasoningEffort(agentConfig.ReasoningEffort),
472 ),
473 )
474 } else if model.Provider == models.ProviderAnthropic && model.CanReason && agentName == config.AgentCoder {
475 opts = append(
476 opts,
477 provider.WithAnthropicOptions(
478 provider.WithAnthropicShouldThinkFn(provider.DefaultShouldThinkFn),
479 ),
480 )
481 }
482 agentProvider, err := provider.NewProvider(
483 model.Provider,
484 opts...,
485 )
486 if err != nil {
487 return nil, fmt.Errorf("could not create provider: %v", err)
488 }
489
490 return agentProvider, nil
491}