Add configurable low-speed timeout for OpenAI provider (#11668)

Created by Marshall Bowers

This PR adds a setting to allow configuring the low-speed timeout for
the Assistant when using the OpenAI provider.

The `low_speed_timeout_in_seconds` setting accepts the number of seconds that
the HTTP client may stay below a minimum speed limit (currently set to 100
bytes/second) before the request times out.

```json
{
  "assistant": {
    "version": "1",
    "provider": { "name": "openai", "low_speed_timeout_in_seconds": 60 }
  }
}
```

This should help in cases where the `openai` provider is used with a local
model that requires longer timeouts.
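
For example, assuming a local OpenAI-compatible server (the `api_url` below is
Ollama's default endpoint; both it and the timeout value are illustrative, not
defaults shipped by this PR), the settings might look like:

```json
{
  "assistant": {
    "version": "1",
    "provider": {
      "name": "openai",
      "api_url": "http://localhost:11434/v1",
      "low_speed_timeout_in_seconds": 600
    }
  }
}
```

Local inference can stall well below 100 bytes/second while a model loads, so
a generous value is reasonable there.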

Issue: https://github.com/zed-industries/zed/issues/9913

Release Notes:

- Added a `low_speed_timeout_in_seconds` setting to the Assistant's
OpenAI provider
([#9913](https://github.com/zed-industries/zed/issues/9913)).

Change summary

Cargo.lock                                          |  1 +
crates/assistant/src/assistant_settings.rs          | 16 ++++++++--
crates/assistant/src/completion_provider.rs         | 13 ++++++++
crates/assistant/src/completion_provider/open_ai.rs | 22 +++++++++++++-
crates/collab/src/rpc.rs                            |  1 +
crates/open_ai/Cargo.toml                           |  1 +
crates/open_ai/src/open_ai.rs                       | 14 +++++++--
7 files changed, 59 insertions(+), 9 deletions(-)

Detailed changes

Cargo.lock

@@ -6826,6 +6826,7 @@ version = "0.1.0"
 dependencies = [
  "anyhow",
  "futures 0.3.28",
+ "isahc",
  "schemars",
  "serde",
  "serde_json",

crates/assistant/src/assistant_settings.rs

@@ -153,6 +153,8 @@ pub enum AssistantProvider {
         default_model: OpenAiModel,
         #[serde(default = "open_ai_url")]
         api_url: String,
+        #[serde(default)]
+        low_speed_timeout_in_seconds: Option<u64>,
     },
 }
 
@@ -222,12 +224,14 @@ impl AssistantSettingsContent {
                     Some(AssistantProvider::OpenAi {
                         default_model: settings.default_open_ai_model.clone().unwrap_or_default(),
                         api_url: open_ai_api_url.clone(),
+                        low_speed_timeout_in_seconds: None,
                     })
                 } else {
                     settings.default_open_ai_model.clone().map(|open_ai_model| {
                         AssistantProvider::OpenAi {
                             default_model: open_ai_model,
                             api_url: open_ai_url(),
+                            low_speed_timeout_in_seconds: None,
                         }
                     })
                 },
@@ -364,14 +368,17 @@ impl Settings for AssistantSettings {
                         AssistantProvider::OpenAi {
                             default_model,
                             api_url,
+                            low_speed_timeout_in_seconds,
                         },
                         AssistantProvider::OpenAi {
                             default_model: default_model_override,
                             api_url: api_url_override,
+                            low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
                         },
                     ) => {
                         *default_model = default_model_override;
                         *api_url = api_url_override;
+                        *low_speed_timeout_in_seconds = low_speed_timeout_in_seconds_override;
                     }
                     (merged, provider_override) => {
                         *merged = provider_override;
@@ -408,7 +415,8 @@ mod tests {
             AssistantSettings::get_global(cx).provider,
             AssistantProvider::OpenAi {
                 default_model: OpenAiModel::FourTurbo,
-                api_url: open_ai_url()
+                api_url: open_ai_url(),
+                low_speed_timeout_in_seconds: None,
             }
         );
 
@@ -429,7 +437,8 @@ mod tests {
             AssistantSettings::get_global(cx).provider,
             AssistantProvider::OpenAi {
                 default_model: OpenAiModel::FourTurbo,
-                api_url: "test-url".into()
+                api_url: "test-url".into(),
+                low_speed_timeout_in_seconds: None,
             }
         );
         cx.update_global::<SettingsStore, _>(|store, cx| {
@@ -448,7 +457,8 @@ mod tests {
             AssistantSettings::get_global(cx).provider,
             AssistantProvider::OpenAi {
                 default_model: OpenAiModel::Four,
-                api_url: open_ai_url()
+                api_url: open_ai_url(),
+                low_speed_timeout_in_seconds: None,
             }
         );
 

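A note on the settings plumbing above: because the new field is an `Option`
marked `#[serde(default)]`, an omitted key deserializes to `None`, so existing
settings files keep working unchanged. A minimal, self-contained sketch of
that behavior (the `Provider` type below mirrors the shape of the PR's enum
but is not the actual Zed type):

```rust
// Illustrative sketch only; not the actual Zed settings type.
use serde::Deserialize;
use std::time::Duration;

#[derive(Debug, Deserialize)]
#[serde(tag = "name")]
enum Provider {
    #[serde(rename = "openai")]
    OpenAi {
        // An omitted key deserializes to `None` thanks to #[serde(default)].
        #[serde(default)]
        low_speed_timeout_in_seconds: Option<u64>,
    },
}

fn main() -> serde_json::Result<()> {
    let with: Provider =
        serde_json::from_str(r#"{ "name": "openai", "low_speed_timeout_in_seconds": 60 }"#)?;
    let without: Provider = serde_json::from_str(r#"{ "name": "openai" }"#)?;
    println!("{without:?}"); // OpenAi { low_speed_timeout_in_seconds: None }

    // The completion provider converts the seconds into a `Duration` before use.
    if let Provider::OpenAi { low_speed_timeout_in_seconds } = with {
        let timeout = low_speed_timeout_in_seconds.map(Duration::from_secs);
        println!("timeout: {timeout:?}"); // timeout: Some(60s)
    }
    Ok(())
}
```
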
crates/assistant/src/completion_provider.rs

@@ -18,6 +18,7 @@ use futures::{future::BoxFuture, stream::BoxStream};
 use gpui::{AnyView, AppContext, BorrowAppContext, Task, WindowContext};
 use settings::{Settings, SettingsStore};
 use std::sync::Arc;
+use std::time::Duration;
 
 pub fn init(client: Arc<Client>, cx: &mut AppContext) {
     let mut settings_version = 0;
@@ -33,10 +34,12 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
         AssistantProvider::OpenAi {
             default_model,
             api_url,
+            low_speed_timeout_in_seconds,
         } => CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
             default_model.clone(),
             api_url.clone(),
             client.http_client(),
+            low_speed_timeout_in_seconds.map(Duration::from_secs),
             settings_version,
         )),
     };
@@ -51,9 +54,15 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
                     AssistantProvider::OpenAi {
                         default_model,
                         api_url,
+                        low_speed_timeout_in_seconds,
                     },
                 ) => {
-                    provider.update(default_model.clone(), api_url.clone(), settings_version);
+                    provider.update(
+                        default_model.clone(),
+                        api_url.clone(),
+                        low_speed_timeout_in_seconds.map(Duration::from_secs),
+                        settings_version,
+                    );
                 }
                 (
                     CompletionProvider::ZedDotDev(provider),
@@ -74,12 +83,14 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
                     AssistantProvider::OpenAi {
                         default_model,
                         api_url,
+                        low_speed_timeout_in_seconds,
                     },
                 ) => {
                     *provider = CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
                         default_model.clone(),
                         api_url.clone(),
                         client.http_client(),
+                        low_speed_timeout_in_seconds.map(Duration::from_secs),
                         settings_version,
                     ));
                 }

crates/assistant/src/completion_provider/open_ai.rs

@@ -7,6 +7,7 @@ use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
 use gpui::{AnyView, AppContext, FontStyle, FontWeight, Task, TextStyle, View, WhiteSpace};
 use open_ai::{stream_completion, Request, RequestMessage, Role as OpenAiRole};
 use settings::Settings;
+use std::time::Duration;
 use std::{env, sync::Arc};
 use theme::ThemeSettings;
 use ui::prelude::*;
@@ -17,6 +18,7 @@ pub struct OpenAiCompletionProvider {
     api_url: String,
     default_model: OpenAiModel,
     http_client: Arc<dyn HttpClient>,
+    low_speed_timeout: Option<Duration>,
     settings_version: usize,
 }
 
@@ -25,6 +27,7 @@ impl OpenAiCompletionProvider {
         default_model: OpenAiModel,
         api_url: String,
         http_client: Arc<dyn HttpClient>,
+        low_speed_timeout: Option<Duration>,
         settings_version: usize,
     ) -> Self {
         Self {
@@ -32,13 +35,21 @@ impl OpenAiCompletionProvider {
             api_url,
             default_model,
             http_client,
+            low_speed_timeout,
             settings_version,
         }
     }
 
-    pub fn update(&mut self, default_model: OpenAiModel, api_url: String, settings_version: usize) {
+    pub fn update(
+        &mut self,
+        default_model: OpenAiModel,
+        api_url: String,
+        low_speed_timeout: Option<Duration>,
+        settings_version: usize,
+    ) {
         self.default_model = default_model;
         self.api_url = api_url;
+        self.low_speed_timeout = low_speed_timeout;
         self.settings_version = settings_version;
     }
 
@@ -112,9 +123,16 @@ impl OpenAiCompletionProvider {
         let http_client = self.http_client.clone();
         let api_key = self.api_key.clone();
         let api_url = self.api_url.clone();
+        let low_speed_timeout = self.low_speed_timeout;
         async move {
             let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
-            let request = stream_completion(http_client.as_ref(), &api_url, &api_key, request);
+            let request = stream_completion(
+                http_client.as_ref(),
+                &api_url,
+                &api_key,
+                request,
+                low_speed_timeout,
+            );
             let response = request.await?;
             let stream = response
                 .filter_map(|response| async move {

crates/collab/src/rpc.rs

@@ -4344,6 +4344,7 @@ async fn complete_with_open_ai(
         OPEN_AI_API_URL,
         &api_key,
         crate::ai::language_model_request_to_open_ai(request)?,
+        None,
     )
     .await
     .context("open_ai::stream_completion request failed within collab")?;

crates/open_ai/Cargo.toml

@@ -15,6 +15,7 @@ schemars = ["dep:schemars"]
 [dependencies]
 anyhow.workspace = true
 futures.workspace = true
+isahc.workspace = true
 schemars = { workspace = true, optional = true }
 serde.workspace = true
 serde_json.workspace = true

crates/open_ai/src/open_ai.rs

@@ -1,7 +1,9 @@
 use anyhow::{anyhow, Context, Result};
 use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
+use isahc::config::Configurable;
 use serde::{Deserialize, Serialize};
 use serde_json::{Map, Value};
+use std::time::Duration;
 use std::{convert::TryFrom, future::Future};
 use util::http::{AsyncBody, HttpClient, Method, Request as HttpRequest};
 
@@ -206,14 +208,20 @@ pub async fn stream_completion(
     api_url: &str,
     api_key: &str,
     request: Request,
+    low_speed_timeout: Option<Duration>,
 ) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
     let uri = format!("{api_url}/chat/completions");
-    let request = HttpRequest::builder()
+    let mut request_builder = HttpRequest::builder()
         .method(Method::POST)
         .uri(uri)
         .header("Content-Type", "application/json")
-        .header("Authorization", format!("Bearer {}", api_key))
-        .body(AsyncBody::from(serde_json::to_string(&request)?))?;
+        .header("Authorization", format!("Bearer {}", api_key));
+
+    if let Some(low_speed_timeout) = low_speed_timeout {
+        request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
+    };
+
+    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
     let mut response = client.send(request).await?;
     if response.status().is_success() {
         let reader = BufReader::new(response.into_body());
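
For reference, `low_speed_timeout` comes from isahc's `Configurable` trait,
which isahc implements for `http::request::Builder`; that is why it can be
called on `request_builder` in the hunk above. If average throughput stays
below the limit (in bytes/second) for the given duration, the transfer aborts
with an error. A standalone sketch using isahc directly (placeholder URL, not
Zed code):

```rust
// Standalone sketch demonstrating isahc's low-speed timeout: if throughput
// stays below 100 bytes/second for 60 seconds, the request fails with a
// timeout error instead of hanging indefinitely.
use isahc::config::Configurable;
use isahc::{ReadResponseExt, Request};
use std::time::Duration;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let request = Request::get("https://example.com") // placeholder URL
        .low_speed_timeout(100, Duration::from_secs(60))
        .body(())?;

    let mut response = isahc::send(request)?;
    println!("status: {}", response.status());
    println!("body bytes: {}", response.text()?.len());
    Ok(())
}
```

Passing `None`, as `collab` does in `rpc.rs` above, skips the builder call
entirely and leaves isahc's default behavior of no low-speed cutoff.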