Timeout if completion takes longer than 2s (#23215)

Antonio Scandurra created

Release Notes:

- N/A

Change summary

crates/collab/src/llm.rs | 107 +++++++++++++++++++++++++++--------------
1 file changed, 71 insertions(+), 36 deletions(-)

Detailed changes

crates/collab/src/llm.rs 🔗

@@ -21,7 +21,7 @@ use chrono::{DateTime, Duration, Utc};
 use collections::HashMap;
 use db::TokenUsage;
 use db::{usage_measure::UsageMeasure, ActiveUserCount, LlmDatabase};
-use futures::{Stream, StreamExt as _};
+use futures::{FutureExt, Stream, StreamExt as _};
 use reqwest_client::ReqwestClient;
 use rpc::{
     proto::Plan, LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME,
@@ -475,7 +475,11 @@ async fn predict_edits(
         .replace("<excerpt>", &params.input_excerpt);
 
     let request_start = std::time::Instant::now();
-    let mut response = fireworks::complete(
+    let timeout = state
+        .executor
+        .sleep(std::time::Duration::from_secs(2))
+        .fuse();
+    let response = fireworks::complete(
         &state.http_client,
         api_url,
         api_key,
@@ -490,41 +494,72 @@ async fn predict_edits(
             rewrite_speculation: Some(true),
         },
     )
-    .await?;
-    let duration = request_start.elapsed();
-
-    let choice = response
-        .completion
-        .choices
-        .pop()
-        .context("no output from completion response")?;
-
-    state.executor.spawn_detached({
-        let kinesis_client = state.kinesis_client.clone();
-        let kinesis_stream = state.config.kinesis_stream.clone();
-        let model = model.clone();
-        async move {
-            SnowflakeRow::new(
-                "Fireworks Completion Requested",
-                claims.metrics_id,
-                claims.is_staff,
-                claims.system_id.clone(),
-                json!({
-                    "model": model.to_string(),
-                    "headers": response.headers,
-                    "usage": response.completion.usage,
-                    "duration": duration.as_secs_f64(),
-                }),
-            )
-            .write(&kinesis_client, &kinesis_stream)
-            .await
-            .log_err();
-        }
-    });
+    .fuse();
+    futures::pin_mut!(timeout);
+    futures::pin_mut!(response);
+
+    futures::select! {
+        _ = timeout => {
+            state.executor.spawn_detached({
+                let kinesis_client = state.kinesis_client.clone();
+                let kinesis_stream = state.config.kinesis_stream.clone();
+                let model = model.clone();
+                async move {
+                    SnowflakeRow::new(
+                        "Fireworks Completion Timeout",
+                        claims.metrics_id,
+                        claims.is_staff,
+                        claims.system_id.clone(),
+                        json!({
+                            "model": model.to_string(),
+                            "prompt": prompt,
+                        }),
+                    )
+                    .write(&kinesis_client, &kinesis_stream)
+                    .await
+                    .log_err();
+                }
+            });
+            Err(anyhow!("request timed out"))?
+        },
+        response = response => {
+            let duration = request_start.elapsed();
+
+            let mut response = response?;
+            let choice = response
+                .completion
+                .choices
+                .pop()
+                .context("no output from completion response")?;
+
+            state.executor.spawn_detached({
+                let kinesis_client = state.kinesis_client.clone();
+                let kinesis_stream = state.config.kinesis_stream.clone();
+                let model = model.clone();
+                async move {
+                    SnowflakeRow::new(
+                        "Fireworks Completion Requested",
+                        claims.metrics_id,
+                        claims.is_staff,
+                        claims.system_id.clone(),
+                        json!({
+                            "model": model.to_string(),
+                            "headers": response.headers,
+                            "usage": response.completion.usage,
+                            "duration": duration.as_secs_f64(),
+                        }),
+                    )
+                    .write(&kinesis_client, &kinesis_stream)
+                    .await
+                    .log_err();
+                }
+            });
 
-    Ok(Json(PredictEditsResponse {
-        output_excerpt: choice.text,
-    }))
+            Ok(Json(PredictEditsResponse {
+                output_excerpt: choice.text,
+            }))
+        },
+    }
 }
 
 /// The maximum monthly spending an individual user can reach on the free tier