Cargo.lock 🔗
@@ -6826,6 +6826,7 @@ version = "0.1.0"
dependencies = [
"anyhow",
"futures 0.3.28",
+ "isahc",
"schemars",
"serde",
"serde_json",
Marshall Bowers created
This PR adds a setting to allow configuring the low-speed timeout for
the Assistant when using the OpenAI provider.
The `low_speed_timeout_in_seconds` accepts a number of seconds that the
HTTP client can go below a minimum speed limit (currently set to 100
bytes/second) before it times out.
```json
{
"assistant": {
"version": "1",
"provider": { "name": "openai", "low_speed_timeout_in_seconds": 60 }
},
}
```
This should help the case where the `openai` provider is being used with
a local model that requires higher timeouts.
Issue: https://github.com/zed-industries/zed/issues/9913
Release Notes:
- Added a `low_speed_timeout_in_seconds` setting to the Assistant's
OpenAI provider
([#9913](https://github.com/zed-industries/zed/issues/9913)).
Cargo.lock | 1
crates/assistant/src/assistant_settings.rs | 16 ++++++++--
crates/assistant/src/completion_provider.rs | 13 ++++++++
crates/assistant/src/completion_provider/open_ai.rs | 22 +++++++++++++-
crates/collab/src/rpc.rs | 1
crates/open_ai/Cargo.toml | 1
crates/open_ai/src/open_ai.rs | 14 +++++++--
7 files changed, 59 insertions(+), 9 deletions(-)
@@ -6826,6 +6826,7 @@ version = "0.1.0"
dependencies = [
"anyhow",
"futures 0.3.28",
+ "isahc",
"schemars",
"serde",
"serde_json",
@@ -153,6 +153,8 @@ pub enum AssistantProvider {
default_model: OpenAiModel,
#[serde(default = "open_ai_url")]
api_url: String,
+ #[serde(default)]
+ low_speed_timeout_in_seconds: Option<u64>,
},
}
@@ -222,12 +224,14 @@ impl AssistantSettingsContent {
Some(AssistantProvider::OpenAi {
default_model: settings.default_open_ai_model.clone().unwrap_or_default(),
api_url: open_ai_api_url.clone(),
+ low_speed_timeout_in_seconds: None,
})
} else {
settings.default_open_ai_model.clone().map(|open_ai_model| {
AssistantProvider::OpenAi {
default_model: open_ai_model,
api_url: open_ai_url(),
+ low_speed_timeout_in_seconds: None,
}
})
},
@@ -364,14 +368,17 @@ impl Settings for AssistantSettings {
AssistantProvider::OpenAi {
default_model,
api_url,
+ low_speed_timeout_in_seconds,
},
AssistantProvider::OpenAi {
default_model: default_model_override,
api_url: api_url_override,
+ low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
},
) => {
*default_model = default_model_override;
*api_url = api_url_override;
+ *low_speed_timeout_in_seconds = low_speed_timeout_in_seconds_override;
}
(merged, provider_override) => {
*merged = provider_override;
@@ -408,7 +415,8 @@ mod tests {
AssistantSettings::get_global(cx).provider,
AssistantProvider::OpenAi {
default_model: OpenAiModel::FourTurbo,
- api_url: open_ai_url()
+ api_url: open_ai_url(),
+ low_speed_timeout_in_seconds: None,
}
);
@@ -429,7 +437,8 @@ mod tests {
AssistantSettings::get_global(cx).provider,
AssistantProvider::OpenAi {
default_model: OpenAiModel::FourTurbo,
- api_url: "test-url".into()
+ api_url: "test-url".into(),
+ low_speed_timeout_in_seconds: None,
}
);
cx.update_global::<SettingsStore, _>(|store, cx| {
@@ -448,7 +457,8 @@ mod tests {
AssistantSettings::get_global(cx).provider,
AssistantProvider::OpenAi {
default_model: OpenAiModel::Four,
- api_url: open_ai_url()
+ api_url: open_ai_url(),
+ low_speed_timeout_in_seconds: None,
}
);
@@ -18,6 +18,7 @@ use futures::{future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, AppContext, BorrowAppContext, Task, WindowContext};
use settings::{Settings, SettingsStore};
use std::sync::Arc;
+use std::time::Duration;
pub fn init(client: Arc<Client>, cx: &mut AppContext) {
let mut settings_version = 0;
@@ -33,10 +34,12 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
AssistantProvider::OpenAi {
default_model,
api_url,
+ low_speed_timeout_in_seconds,
} => CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
default_model.clone(),
api_url.clone(),
client.http_client(),
+ low_speed_timeout_in_seconds.map(Duration::from_secs),
settings_version,
)),
};
@@ -51,9 +54,15 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
AssistantProvider::OpenAi {
default_model,
api_url,
+ low_speed_timeout_in_seconds,
},
) => {
- provider.update(default_model.clone(), api_url.clone(), settings_version);
+ provider.update(
+ default_model.clone(),
+ api_url.clone(),
+ low_speed_timeout_in_seconds.map(Duration::from_secs),
+ settings_version,
+ );
}
(
CompletionProvider::ZedDotDev(provider),
@@ -74,12 +83,14 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
AssistantProvider::OpenAi {
default_model,
api_url,
+ low_speed_timeout_in_seconds,
},
) => {
*provider = CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
default_model.clone(),
api_url.clone(),
client.http_client(),
+ low_speed_timeout_in_seconds.map(Duration::from_secs),
settings_version,
));
}
@@ -7,6 +7,7 @@ use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{AnyView, AppContext, FontStyle, FontWeight, Task, TextStyle, View, WhiteSpace};
use open_ai::{stream_completion, Request, RequestMessage, Role as OpenAiRole};
use settings::Settings;
+use std::time::Duration;
use std::{env, sync::Arc};
use theme::ThemeSettings;
use ui::prelude::*;
@@ -17,6 +18,7 @@ pub struct OpenAiCompletionProvider {
api_url: String,
default_model: OpenAiModel,
http_client: Arc<dyn HttpClient>,
+ low_speed_timeout: Option<Duration>,
settings_version: usize,
}
@@ -25,6 +27,7 @@ impl OpenAiCompletionProvider {
default_model: OpenAiModel,
api_url: String,
http_client: Arc<dyn HttpClient>,
+ low_speed_timeout: Option<Duration>,
settings_version: usize,
) -> Self {
Self {
@@ -32,13 +35,21 @@ impl OpenAiCompletionProvider {
api_url,
default_model,
http_client,
+ low_speed_timeout,
settings_version,
}
}
- pub fn update(&mut self, default_model: OpenAiModel, api_url: String, settings_version: usize) {
+ pub fn update(
+ &mut self,
+ default_model: OpenAiModel,
+ api_url: String,
+ low_speed_timeout: Option<Duration>,
+ settings_version: usize,
+ ) {
self.default_model = default_model;
self.api_url = api_url;
+ self.low_speed_timeout = low_speed_timeout;
self.settings_version = settings_version;
}
@@ -112,9 +123,16 @@ impl OpenAiCompletionProvider {
let http_client = self.http_client.clone();
let api_key = self.api_key.clone();
let api_url = self.api_url.clone();
+ let low_speed_timeout = self.low_speed_timeout;
async move {
let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
- let request = stream_completion(http_client.as_ref(), &api_url, &api_key, request);
+ let request = stream_completion(
+ http_client.as_ref(),
+ &api_url,
+ &api_key,
+ request,
+ low_speed_timeout,
+ );
let response = request.await?;
let stream = response
.filter_map(|response| async move {
@@ -4344,6 +4344,7 @@ async fn complete_with_open_ai(
OPEN_AI_API_URL,
&api_key,
crate::ai::language_model_request_to_open_ai(request)?,
+ None,
)
.await
.context("open_ai::stream_completion request failed within collab")?;
@@ -15,6 +15,7 @@ schemars = ["dep:schemars"]
[dependencies]
anyhow.workspace = true
futures.workspace = true
+isahc.workspace = true
schemars = { workspace = true, optional = true }
serde.workspace = true
serde_json.workspace = true
@@ -1,7 +1,9 @@
use anyhow::{anyhow, Context, Result};
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
+use isahc::config::Configurable;
use serde::{Deserialize, Serialize};
use serde_json::{Map, Value};
+use std::time::Duration;
use std::{convert::TryFrom, future::Future};
use util::http::{AsyncBody, HttpClient, Method, Request as HttpRequest};
@@ -206,14 +208,20 @@ pub async fn stream_completion(
api_url: &str,
api_key: &str,
request: Request,
+ low_speed_timeout: Option<Duration>,
) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
let uri = format!("{api_url}/chat/completions");
- let request = HttpRequest::builder()
+ let mut request_builder = HttpRequest::builder()
.method(Method::POST)
.uri(uri)
.header("Content-Type", "application/json")
- .header("Authorization", format!("Bearer {}", api_key))
- .body(AsyncBody::from(serde_json::to_string(&request)?))?;
+ .header("Authorization", format!("Bearer {}", api_key));
+
+ if let Some(low_speed_timeout) = low_speed_timeout {
+ request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
+ };
+
+ let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
let mut response = client.send(request).await?;
if response.status().is_success() {
let reader = BufReader::new(response.into_body());