assistant: Fix Google AI provider not respecting `low_speed_timeout_in_seconds` (#17423)

Bennet Bo Fenner created

Release Notes:

- Fixed an issue when using Google Gemini models, where the setting
`low_speed_timeout_in_seconds` was not respected

Change summary

Cargo.lock                                   |  1 
crates/collab/src/llm.rs                     |  1 
crates/collab/src/rpc.rs                     |  1 
crates/google_ai/Cargo.toml                  |  1 
crates/google_ai/src/google_ai.rs            | 32 +++++++++++++++++++--
crates/language_model/src/provider/google.rs | 26 ++++++++++++-----
6 files changed, 50 insertions(+), 12 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -4928,6 +4928,7 @@ dependencies = [
  "anyhow",
  "futures 0.3.30",
  "http_client",
+ "isahc",
  "schemars",
  "serde",
  "serde_json",

crates/collab/src/llm.rs 🔗

@@ -380,6 +380,7 @@ async fn perform_completion(
                 google_ai::API_URL,
                 api_key,
                 serde_json::from_str(&params.provider_request.get())?,
+                None,
             )
             .await?;
 

crates/collab/src/rpc.rs 🔗

@@ -4540,6 +4540,7 @@ async fn count_language_model_tokens(
                 google_ai::API_URL,
                 api_key,
                 serde_json::from_str(&request.request)?,
+                None,
             )
             .await?
         }

crates/google_ai/Cargo.toml 🔗

@@ -18,6 +18,7 @@ schemars = ["dep:schemars"]
 anyhow.workspace = true
 futures.workspace = true
 http_client.workspace = true
+isahc.workspace = true
 schemars = { workspace = true, optional = true }
 serde.workspace = true
 serde_json.workspace = true

crates/google_ai/src/google_ai.rs 🔗

@@ -2,8 +2,10 @@ mod supported_countries;
 
 use anyhow::{anyhow, Result};
 use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
-use http_client::HttpClient;
+use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
+use isahc::config::Configurable;
 use serde::{Deserialize, Serialize};
+use std::time::Duration;
 
 pub use supported_countries::*;
 
@@ -14,6 +16,7 @@ pub async fn stream_generate_content(
     api_url: &str,
     api_key: &str,
     mut request: GenerateContentRequest,
+    low_speed_timeout: Option<Duration>,
 ) -> Result<BoxStream<'static, Result<GenerateContentResponse>>> {
     let uri = format!(
         "{api_url}/v1beta/models/{model}:streamGenerateContent?alt=sse&key={api_key}",
@@ -21,8 +24,17 @@ pub async fn stream_generate_content(
     );
     request.model.clear();
 
-    let request = serde_json::to_string(&request)?;
-    let mut response = client.post_json(&uri, request.into()).await?;
+    let mut request_builder = HttpRequest::builder()
+        .method(Method::POST)
+        .uri(uri)
+        .header("Content-Type", "application/json");
+
+    if let Some(low_speed_timeout) = low_speed_timeout {
+        request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
+    };
+
+    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
+    let mut response = client.send(request).await?;
     if response.status().is_success() {
         let reader = BufReader::new(response.into_body());
         Ok(reader
@@ -59,13 +71,25 @@ pub async fn count_tokens(
     api_url: &str,
     api_key: &str,
     request: CountTokensRequest,
+    low_speed_timeout: Option<Duration>,
 ) -> Result<CountTokensResponse> {
     let uri = format!(
         "{}/v1beta/models/gemini-pro:countTokens?key={}",
         api_url, api_key
     );
     let request = serde_json::to_string(&request)?;
-    let mut response = client.post_json(&uri, request.into()).await?;
+
+    let mut request_builder = HttpRequest::builder()
+        .method(Method::POST)
+        .uri(&uri)
+        .header("Content-Type", "application/json");
+
+    if let Some(low_speed_timeout) = low_speed_timeout {
+        request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
+    }
+
+    let http_request = request_builder.body(AsyncBody::from(request))?;
+    let mut response = client.send(http_request).await?;
     let mut text = String::new();
     response.body_mut().read_to_string(&mut text).await?;
     if response.status().is_success() {

crates/language_model/src/provider/google.rs 🔗

@@ -257,10 +257,10 @@ impl LanguageModel for GoogleLanguageModel {
         let request = request.into_google(self.model.id().to_string());
         let http_client = self.http_client.clone();
         let api_key = self.state.read(cx).api_key.clone();
-        let api_url = AllLanguageModelSettings::get_global(cx)
-            .google
-            .api_url
-            .clone();
+
+        let settings = &AllLanguageModelSettings::get_global(cx).google;
+        let api_url = settings.api_url.clone();
+        let low_speed_timeout = settings.low_speed_timeout;
 
         async move {
             let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
@@ -271,6 +271,7 @@ impl LanguageModel for GoogleLanguageModel {
                 google_ai::CountTokensRequest {
                     contents: request.contents,
                 },
+                low_speed_timeout,
             )
             .await?;
             Ok(response.total_tokens)
@@ -289,17 +290,26 @@ impl LanguageModel for GoogleLanguageModel {
         let request = request.into_google(self.model.id().to_string());
 
         let http_client = self.http_client.clone();
-        let Ok((api_key, api_url)) = cx.read_model(&self.state, |state, cx| {
+        let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, cx| {
             let settings = &AllLanguageModelSettings::get_global(cx).google;
-            (state.api_key.clone(), settings.api_url.clone())
+            (
+                state.api_key.clone(),
+                settings.api_url.clone(),
+                settings.low_speed_timeout,
+            )
         }) else {
             return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
         };
 
         let future = self.rate_limiter.stream(async move {
             let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
-            let response =
-                stream_generate_content(http_client.as_ref(), &api_url, &api_key, request);
+            let response = stream_generate_content(
+                http_client.as_ref(),
+                &api_url,
+                &api_key,
+                request,
+                low_speed_timeout,
+            );
             let events = response.await?;
             Ok(google_ai::extract_text_from_events(events).boxed())
         });