eval: retry in more scenarios

Ben Brandt created

Change summary

crates/assistant_tools/src/edit_agent/evals.rs | 101 ++++++++++++-------
1 file changed, 61 insertions(+), 40 deletions(-)

Detailed changes

crates/assistant_tools/src/edit_agent/evals.rs 🔗

@@ -1663,47 +1663,68 @@ async fn retry_on_rate_limit<R>(mut request: impl AsyncFnMut() -> Result<R>) ->
         attempt += 1;
         match request().await {
             Ok(result) => return Ok(result),
-            Err(err) => match err.downcast::<LanguageModelCompletionError>() {
-                Ok(err) => match &err {
-                    LanguageModelCompletionError::RateLimitExceeded { retry_after, .. }
-                    | LanguageModelCompletionError::ServerOverloaded { retry_after, .. } => {
-                        let retry_after = retry_after.unwrap_or(Duration::from_secs(5));
-                        // Wait for the duration supplied, with some jitter to avoid all requests being made at the same time.
-                        let jitter = retry_after.mul_f64(rand::thread_rng().gen_range(0.0..1.0));
-                        eprintln!(
-                            "Attempt #{attempt}: {err}. Retry after {retry_after:?} + jitter of {jitter:?}"
-                        );
-                        Timer::after(retry_after + jitter).await;
-                        continue;
-                    }
-                    LanguageModelCompletionError::UpstreamProviderError {
-                        status,
-                        retry_after,
-                        ..
-                    } => {
-                        // Only retry for specific status codes
-                        let should_retry = matches!(
-                            *status,
-                            StatusCode::TOO_MANY_REQUESTS | StatusCode::SERVICE_UNAVAILABLE
-                        ) || status.as_u16() == 529;
-
-                        if !should_retry {
-                            return Err(err.into());
-                        }
+            Err(err) => {
+                if attempt > 20 {
+                    return Err(err);
+                }
 
-                        // Use server-provided retry_after if available, otherwise use default
-                        let retry_after = retry_after.unwrap_or(Duration::from_secs(5));
-                        let jitter = retry_after.mul_f64(rand::thread_rng().gen_range(0.0..1.0));
-                        eprintln!(
-                            "Attempt #{attempt}: {err}. Retry after {retry_after:?} + jitter of {jitter:?}"
-                        );
-                        Timer::after(retry_after + jitter).await;
-                        continue;
-                    }
-                    _ => return Err(err.into()),
-                },
-                Err(err) => return Err(err),
-            },
+                match err.downcast::<LanguageModelCompletionError>() {
+                    Ok(err) => match &err {
+                        LanguageModelCompletionError::RateLimitExceeded { retry_after, .. }
+                        | LanguageModelCompletionError::ServerOverloaded { retry_after, .. } => {
+                            let retry_after = retry_after.unwrap_or(Duration::from_secs(5));
+                            // Wait for the duration supplied, with some jitter to avoid all requests being made at the same time.
+                            let jitter =
+                                retry_after.mul_f64(rand::thread_rng().gen_range(0.0..1.0));
+                            eprintln!(
+                                "Attempt #{attempt}: {err}. Retry after {retry_after:?} + jitter of {jitter:?}"
+                            );
+                            Timer::after(retry_after + jitter).await;
+                            continue;
+                        }
+                        LanguageModelCompletionError::UpstreamProviderError {
+                            status,
+                            retry_after,
+                            ..
+                        } => {
+                            // Only retry for specific status codes
+                            let should_retry = matches!(
+                                *status,
+                                StatusCode::TOO_MANY_REQUESTS | StatusCode::SERVICE_UNAVAILABLE
+                            ) || status.as_u16() == 529;
+
+                            if !should_retry {
+                                return Err(err.into());
+                            }
+
+                            // Use server-provided retry_after if available, otherwise use default
+                            let retry_after = retry_after.unwrap_or(Duration::from_secs(5));
+                            let jitter =
+                                retry_after.mul_f64(rand::thread_rng().gen_range(0.0..1.0));
+                            eprintln!(
+                                "Attempt #{attempt}: {err}. Retry after {retry_after:?} + jitter of {jitter:?}"
+                            );
+                            Timer::after(retry_after + jitter).await;
+                            continue;
+                        }
+                        LanguageModelCompletionError::ApiInternalServerError { .. }
+                        | LanguageModelCompletionError::ApiReadResponseError { .. }
+                        | LanguageModelCompletionError::DeserializeResponse { .. }
+                        | LanguageModelCompletionError::HttpSend { .. } => {
+                            let retry_after = Duration::from_secs(attempt);
+                            let jitter =
+                                retry_after.mul_f64(rand::thread_rng().gen_range(0.0..1.0));
+                            eprintln!(
+                                "Attempt #{attempt}: {err}. Retry after {retry_after:?} + jitter of {jitter:?}"
+                            );
+                            Timer::after(retry_after + jitter).await;
+                            continue;
+                        }
+                        _ => return Err(err.into()),
+                    },
+                    Err(err) => return Err(err),
+                }
+            }
         }
     }
 }