fix(agent): queued messages issues

Kujtim Hoxha created

Change summary

internal/agent/agent.go | 34 ++++++++++++++++++----------------
1 file changed, 18 insertions(+), 16 deletions(-)

Detailed changes

internal/agent/agent.go 🔗

@@ -183,21 +183,6 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
 		FrequencyPenalty: call.FrequencyPenalty,
 		// Before each step create the new assistant message
 		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
-			var assistantMsg message.Message
-			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
-				Role:     message.Assistant,
-				Parts:    []message.ContentPart{},
-				Model:    a.largeModel.ModelCfg.Model,
-				Provider: a.largeModel.ModelCfg.Provider,
-			})
-			if err != nil {
-				return callContext, prepared, err
-			}
-
-			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
-
-			currentAssistant = &assistantMsg
-
 			prepared.Messages = options.Messages
 			// reset all cached items
 			for i := range prepared.Messages {
@@ -229,6 +214,19 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
 					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
 				}
 			}
+
+			var assistantMsg message.Message
+			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
+				Role:     message.Assistant,
+				Parts:    []message.ContentPart{},
+				Model:    a.largeModel.ModelCfg.Model,
+				Provider: a.largeModel.ModelCfg.Provider,
+			})
+			if err != nil {
+				return callContext, prepared, err
+			}
+			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
+			currentAssistant = &assistantMsg
 			return callContext, prepared, err
 		},
 		OnReasoningDelta: func(id string, text string) error {
@@ -432,6 +430,10 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
 		}
 	}
 
+	// release active request before processing queued messages
+	a.activeRequests.Del(call.SessionID)
+	cancel()
+
 	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
 	if !ok || len(queuedMessages) == 0 {
 		return result, err
@@ -439,7 +441,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
 	// there are queued messages; restart the loop
 	firstQueuedMessage := queuedMessages[0]
 	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
-	return a.Run(genCtx, firstQueuedMessage)
+	return a.Run(ctx, firstQueuedMessage)
 }
 
 func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {