Reuse OpenAI low_speed_timeout setting for zed.dev provider (#18144)

jvmncs created this pull request

Release Notes:

- N/A

Change summary

Cargo.lock                                  |  1 +
crates/language_model/Cargo.toml            |  1 +
crates/language_model/src/provider/cloud.rs | 22 ++++++++++++++++++++--
crates/language_model/src/settings.rs       |  9 +++++++++
4 files changed, 31 insertions(+), 2 deletions(-)

Detailed changes

Cargo.lock

@@ -6285,6 +6285,7 @@ dependencies = [
  "http_client",
  "image",
  "inline_completion_button",
+ "isahc",
  "language",
  "log",
  "menu",

crates/language_model/Cargo.toml

@@ -32,6 +32,7 @@ futures.workspace = true
 google_ai = { workspace = true, features = ["schemars"] }
 gpui.workspace = true
 http_client.workspace = true
+isahc.workspace = true
 inline_completion_button.workspace = true
 log.workspace = true
 menu.workspace = true

crates/language_model/src/provider/cloud.rs

@@ -19,6 +19,7 @@ use gpui::{
     Subscription, Task,
 };
 use http_client::{AsyncBody, HttpClient, Method, Response};
+use isahc::config::Configurable;
 use schemars::JsonSchema;
 use serde::{de::DeserializeOwned, Deserialize, Serialize};
 use serde_json::value::RawValue;
@@ -27,6 +28,7 @@ use smol::{
     io::{AsyncReadExt, BufReader},
     lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard},
 };
+use std::time::Duration;
 use std::{
     future,
     sync::{Arc, LazyLock},
@@ -56,6 +58,7 @@ fn zed_cloud_provider_additional_models() -> &'static [AvailableModel] {
 #[derive(Default, Clone, Debug, PartialEq)]
 pub struct ZedDotDevSettings {
     pub available_models: Vec<AvailableModel>,
+    pub low_speed_timeout: Option<Duration>,
 }
 
 #[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
@@ -380,6 +383,7 @@ impl CloudLanguageModel {
         client: Arc<Client>,
         llm_api_token: LlmApiToken,
         body: PerformCompletionParams,
+        low_speed_timeout: Option<Duration>,
     ) -> Result<Response<AsyncBody>> {
         let http_client = &client.http_client();
 
@@ -387,7 +391,11 @@ impl CloudLanguageModel {
         let mut did_retry = false;
 
         let response = loop {
-            let request = http_client::Request::builder()
+            let mut request_builder = http_client::Request::builder();
+            if let Some(low_speed_timeout) = low_speed_timeout {
+                request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
+            };
+            let request = request_builder
                 .method(Method::POST)
                 .uri(http_client.build_zed_llm_url("/completion", &[])?.as_ref())
                 .header("Content-Type", "application/json")
@@ -501,8 +509,11 @@ impl LanguageModel for CloudLanguageModel {
     fn stream_completion(
         &self,
         request: LanguageModelRequest,
-        _cx: &AsyncAppContext,
+        cx: &AsyncAppContext,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
+        let openai_low_speed_timeout =
+            AllLanguageModelSettings::try_read_global(cx, |s| s.openai.low_speed_timeout.unwrap());
+
         match &self.model {
             CloudModel::Anthropic(model) => {
                 let request = request.into_anthropic(model.id().into(), model.max_output_tokens());
@@ -519,6 +530,7 @@ impl LanguageModel for CloudLanguageModel {
                                 &request,
                             )?)?,
                         },
+                        None,
                     )
                     .await?;
                     Ok(map_to_language_model_completion_events(Box::pin(
@@ -542,6 +554,7 @@ impl LanguageModel for CloudLanguageModel {
                                 &request,
                             )?)?,
                         },
+                        openai_low_speed_timeout,
                     )
                     .await?;
                     Ok(open_ai::extract_text_from_events(response_lines(response)))
@@ -569,6 +582,7 @@ impl LanguageModel for CloudLanguageModel {
                                 &request,
                             )?)?,
                         },
+                        None,
                     )
                     .await?;
                     Ok(google_ai::extract_text_from_events(response_lines(
@@ -599,6 +613,7 @@ impl LanguageModel for CloudLanguageModel {
                                 &request,
                             )?)?,
                         },
+                        None,
                     )
                     .await?;
                     Ok(open_ai::extract_text_from_events(response_lines(response)))
@@ -650,6 +665,7 @@ impl LanguageModel for CloudLanguageModel {
                                     &request,
                                 )?)?,
                             },
+                            None,
                         )
                         .await?;
 
@@ -694,6 +710,7 @@ impl LanguageModel for CloudLanguageModel {
                                     &request,
                                 )?)?,
                             },
+                            None,
                         )
                         .await?;
 
@@ -741,6 +758,7 @@ impl LanguageModel for CloudLanguageModel {
                                     &request,
                                 )?)?,
                             },
+                            None,
                         )
                         .await?;
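
On the caller side, stream_completion now resolves the timeout once from the global settings and threads it into perform_llm_completion_request. Per the hunks above, only the OpenAI-backed branch passes the resolved value; the Anthropic, Google, and remaining branches pass None. A hedged sketch of that dispatch, using hypothetical, simplified stand-ins for the settings lookup and the model enum rather than the real Zed API:

use std::time::Duration;

// Hypothetical stand-in for the provider enum matched on in
// `stream_completion`; the real variants carry model data.
enum CloudModel {
    Anthropic,
    OpenAi,
    Google,
}

// Stand-in for reading the OpenAI `low_speed_timeout` out of
// `AllLanguageModelSettings`; the real code goes through a global settings
// read that can itself fail, so the result is optional.
fn openai_low_speed_timeout() -> Option<Duration> {
    Some(Duration::from_secs(600)) // assumed example value
}

// Only the OpenAI-backed branch reuses the OpenAI timeout; the other
// providers keep the default behaviour (no low-speed timeout).
fn timeout_for(model: &CloudModel) -> Option<Duration> {
    match model {
        CloudModel::OpenAi => openai_low_speed_timeout(),
        CloudModel::Anthropic | CloudModel::Google => None,
    }
}

fn main() {
    assert!(timeout_for(&CloudModel::OpenAi).is_some());
    assert!(timeout_for(&CloudModel::Anthropic).is_none());
    assert!(timeout_for(&CloudModel::Google).is_none());
}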
 

crates/language_model/src/settings.rs

@@ -231,6 +231,7 @@ pub struct GoogleSettingsContent {
 #[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
 pub struct ZedDotDevSettingsContent {
     available_models: Option<Vec<cloud::AvailableModel>>,
+    pub low_speed_timeout_in_seconds: Option<u64>,
 }
 
 #[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
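
In the settings schema, the user-facing knob stays an integer number of seconds and is converted to a Duration while the settings are merged (see the next hunk); when the field is absent, the previously resolved value is left untouched. A small standalone sketch of that conversion, with simplified stand-ins for the content/resolved structs named in the diff:

use std::time::Duration;

// Simplified stand-in for the serialized settings content in the diff.
struct ZedDotDevSettingsContent {
    low_speed_timeout_in_seconds: Option<u64>,
}

// Simplified stand-in for the resolved settings struct.
#[derive(Default)]
struct ZedDotDevSettings {
    low_speed_timeout: Option<Duration>,
}

// Mirrors the merge logic in the hunk below: only override the resolved
// value when the user actually set the field.
fn apply(content: &ZedDotDevSettingsContent, settings: &mut ZedDotDevSettings) {
    if let Some(seconds) = content.low_speed_timeout_in_seconds {
        settings.low_speed_timeout = Some(Duration::from_secs(seconds));
    }
}

fn main() {
    let content = ZedDotDevSettingsContent {
        low_speed_timeout_in_seconds: Some(600),
    };
    let mut settings = ZedDotDevSettings::default();
    apply(&content, &mut settings);
    assert_eq!(settings.low_speed_timeout, Some(Duration::from_secs(600)));
}
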
@@ -333,6 +334,14 @@ impl settings::Settings for AllLanguageModelSettings {
                     .as_ref()
                     .and_then(|s| s.available_models.clone()),
             );
+            if let Some(low_speed_timeout_in_seconds) = value
+                .zed_dot_dev
+                .as_ref()
+                .and_then(|s| s.low_speed_timeout_in_seconds)
+            {
+                settings.zed_dot_dev.low_speed_timeout =
+                    Some(Duration::from_secs(low_speed_timeout_in_seconds));
+            }
 
             merge(
                 &mut settings.google.api_url,