@@ -21,7 +21,7 @@ use chrono::{DateTime, Duration, Utc};
use collections::HashMap;
use db::TokenUsage;
use db::{usage_measure::UsageMeasure, ActiveUserCount, LlmDatabase};
-use futures::{Stream, StreamExt as _};
+use futures::{FutureExt, Stream, StreamExt as _};
use reqwest_client::ReqwestClient;
use rpc::{
proto::Plan, LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME,
@@ -475,7 +475,11 @@ async fn predict_edits(
.replace("<excerpt>", ¶ms.input_excerpt);
let request_start = std::time::Instant::now();
- let mut response = fireworks::complete(
+ let timeout = state
+ .executor
+ .sleep(std::time::Duration::from_secs(2))
+ .fuse();
+ let response = fireworks::complete(
&state.http_client,
api_url,
api_key,
@@ -490,41 +494,72 @@ async fn predict_edits(
rewrite_speculation: Some(true),
},
)
- .await?;
- let duration = request_start.elapsed();
-
- let choice = response
- .completion
- .choices
- .pop()
- .context("no output from completion response")?;
-
- state.executor.spawn_detached({
- let kinesis_client = state.kinesis_client.clone();
- let kinesis_stream = state.config.kinesis_stream.clone();
- let model = model.clone();
- async move {
- SnowflakeRow::new(
- "Fireworks Completion Requested",
- claims.metrics_id,
- claims.is_staff,
- claims.system_id.clone(),
- json!({
- "model": model.to_string(),
- "headers": response.headers,
- "usage": response.completion.usage,
- "duration": duration.as_secs_f64(),
- }),
- )
- .write(&kinesis_client, &kinesis_stream)
- .await
- .log_err();
- }
- });
+ .fuse();
+ futures::pin_mut!(timeout);
+ futures::pin_mut!(response);
+
+ futures::select! {
+ _ = timeout => {
+ state.executor.spawn_detached({
+ let kinesis_client = state.kinesis_client.clone();
+ let kinesis_stream = state.config.kinesis_stream.clone();
+ let model = model.clone();
+ async move {
+ SnowflakeRow::new(
+ "Fireworks Completion Timeout",
+ claims.metrics_id,
+ claims.is_staff,
+ claims.system_id.clone(),
+ json!({
+ "model": model.to_string(),
+ "prompt": prompt,
+ }),
+ )
+ .write(&kinesis_client, &kinesis_stream)
+ .await
+ .log_err();
+ }
+ });
+ Err(anyhow!("request timed out"))?
+ },
+ response = response => {
+ let duration = request_start.elapsed();
+
+ let mut response = response?;
+ let choice = response
+ .completion
+ .choices
+ .pop()
+ .context("no output from completion response")?;
+
+ state.executor.spawn_detached({
+ let kinesis_client = state.kinesis_client.clone();
+ let kinesis_stream = state.config.kinesis_stream.clone();
+ let model = model.clone();
+ async move {
+ SnowflakeRow::new(
+ "Fireworks Completion Requested",
+ claims.metrics_id,
+ claims.is_staff,
+ claims.system_id.clone(),
+ json!({
+ "model": model.to_string(),
+ "headers": response.headers,
+ "usage": response.completion.usage,
+ "duration": duration.as_secs_f64(),
+ }),
+ )
+ .write(&kinesis_client, &kinesis_stream)
+ .await
+ .log_err();
+ }
+ });
- Ok(Json(PredictEditsResponse {
- output_excerpt: choice.text,
- }))
+ Ok(Json(PredictEditsResponse {
+ output_excerpt: choice.text,
+ }))
+ },
+ }
}
/// The maximum monthly spending an individual user can reach on the free tier