@@ -555,6 +555,9 @@ func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msg
ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
// Process each event in the stream.
+ timer := time.NewTimer(streamChunkTimeout)
+ defer timer.Stop()
+
loop:
for {
select {
@@ -562,6 +565,9 @@ loop:
if !ok {
break loop
}
+ // Reset the timeout timer since we received a chunk
+ timer.Reset(streamChunkTimeout)
+
if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
if errors.Is(processErr, context.Canceled) {
a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
@@ -570,7 +576,7 @@ loop:
}
return assistantMsg, nil, processErr
}
- case <-time.After(streamChunkTimeout):
+ case <-timer.C:
a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "Stream timeout", "No chunk received within timeout")
return assistantMsg, nil, fmt.Errorf("stream chunk timeout")
case <-ctx.Done():