Detailed changes
@@ -62,7 +62,7 @@ jobs:
- name: Run unit evals
shell: bash -euxo pipefail {0}
- run: cargo nextest run --workspace --no-fail-fast --features eval --no-capture -E 'test(::eval_)' --test-threads 1
+ run: cargo nextest run --workspace --no-fail-fast --features eval --no-capture -E 'test(::eval_)'
env:
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
@@ -705,6 +705,7 @@ dependencies = [
"serde_json",
"settings",
"smallvec",
+ "smol",
"streaming_diff",
"strsim",
"task",
@@ -386,8 +386,10 @@ impl CodegenAlternative {
async { Ok(LanguageModelTextStream::default()) }.boxed_local()
} else {
let request = self.build_request(&model, user_prompt, cx)?;
- cx.spawn(async move |_, cx| model.stream_completion_text(request.await, &cx).await)
- .boxed_local()
+ cx.spawn(async move |_, cx| {
+ Ok(model.stream_completion_text(request.await, &cx).await?)
+ })
+ .boxed_local()
};
self.handle_stream(telemetry_id, provider_id.to_string(), api_key, stream, cx);
Ok(())
@@ -1563,6 +1563,9 @@ impl Thread {
Err(LanguageModelCompletionError::Other(error)) => {
return Err(error);
}
+ Err(err @ LanguageModelCompletionError::RateLimit(..)) => {
+ return Err(err.into());
+ }
};
match event {
@@ -1,4 +1,5 @@
use std::str::FromStr;
+use std::time::Duration;
use anyhow::{Context as _, Result, anyhow};
use chrono::{DateTime, Utc};
@@ -406,6 +407,7 @@ impl RateLimit {
/// <https://docs.anthropic.com/en/api/rate-limits#response-headers>
#[derive(Debug)]
pub struct RateLimitInfo {
+ pub retry_after: Option<Duration>,
pub requests: Option<RateLimit>,
pub tokens: Option<RateLimit>,
pub input_tokens: Option<RateLimit>,
@@ -417,10 +419,11 @@ impl RateLimitInfo {
// Check if any rate limit headers exist
let has_rate_limit_headers = headers
.keys()
- .any(|k| k.as_str().starts_with("anthropic-ratelimit-"));
+ .any(|k| k == "retry-after" || k.as_str().starts_with("anthropic-ratelimit-"));
if !has_rate_limit_headers {
return Self {
+ retry_after: None,
requests: None,
tokens: None,
input_tokens: None,
@@ -429,6 +432,11 @@ impl RateLimitInfo {
}
Self {
+ retry_after: headers
+ .get("retry-after")
+ .and_then(|v| v.to_str().ok())
+ .and_then(|v| v.parse::<u64>().ok())
+ .map(Duration::from_secs),
requests: RateLimit::from_headers("requests", headers).ok(),
tokens: RateLimit::from_headers("tokens", headers).ok(),
input_tokens: RateLimit::from_headers("input-tokens", headers).ok(),
@@ -481,8 +489,8 @@ pub async fn stream_completion_with_rate_limit_info(
.send(request)
.await
.context("failed to send request to Anthropic")?;
+ let rate_limits = RateLimitInfo::from_headers(response.headers());
if response.status().is_success() {
- let rate_limits = RateLimitInfo::from_headers(response.headers());
let reader = BufReader::new(response.into_body());
let stream = reader
.lines()
@@ -500,6 +508,8 @@ pub async fn stream_completion_with_rate_limit_info(
})
.boxed();
Ok((stream, Some(rate_limits)))
+ } else if let Some(retry_after) = rate_limits.retry_after {
+ Err(AnthropicError::RateLimit(retry_after))
} else {
let mut body = Vec::new();
response
@@ -769,6 +779,8 @@ pub struct MessageDelta {
#[derive(Error, Debug)]
pub enum AnthropicError {
+ #[error("rate limit exceeded, retry after {0:?}")]
+ RateLimit(Duration),
#[error("an error occurred while interacting with the Anthropic API: {error_type}: {message}", error_type = .0.error_type, message = .0.message)]
ApiError(ApiError),
#[error("{0}")]
@@ -682,11 +682,12 @@ mod tests {
_: &AsyncApp,
) -> BoxFuture<
'static,
- http_client::Result<
+ Result<
BoxStream<
'static,
- http_client::Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
+ Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
>,
+ LanguageModelCompletionError,
>,
> {
unimplemented!()
@@ -80,6 +80,7 @@ rand.workspace = true
pretty_assertions.workspace = true
reqwest_client.workspace = true
settings = { workspace = true, features = ["test-support"] }
+smol.workspace = true
task = { workspace = true, features = ["test-support"]}
tempfile.workspace = true
theme.workspace = true
@@ -11,7 +11,7 @@ use client::{Client, UserStore};
use collections::HashMap;
use fs::FakeFs;
use futures::{FutureExt, future::LocalBoxFuture};
-use gpui::{AppContext, TestAppContext};
+use gpui::{AppContext, TestAppContext, Timer};
use indoc::{formatdoc, indoc};
use language_model::{
LanguageModelRegistry, LanguageModelRequestTool, LanguageModelToolResult,
@@ -1255,9 +1255,12 @@ impl EvalAssertion {
}],
..Default::default()
};
- let mut response = judge
- .stream_completion_text(request, &cx.to_async())
- .await?;
+ let mut response = retry_on_rate_limit(async || {
+ Ok(judge
+ .stream_completion_text(request.clone(), &cx.to_async())
+ .await?)
+ })
+ .await?;
let mut output = String::new();
while let Some(chunk) = response.stream.next().await {
let chunk = chunk?;
@@ -1308,10 +1311,17 @@ fn eval(
run_eval(eval.clone(), tx.clone());
let executor = gpui::background_executor();
+ let semaphore = Arc::new(smol::lock::Semaphore::new(32));
for _ in 1..iterations {
let eval = eval.clone();
let tx = tx.clone();
- executor.spawn(async move { run_eval(eval, tx) }).detach();
+ let semaphore = semaphore.clone();
+ executor
+ .spawn(async move {
+ let _guard = semaphore.acquire().await;
+ run_eval(eval, tx)
+ })
+ .detach();
}
drop(tx);
@@ -1577,21 +1587,31 @@ impl EditAgentTest {
if let Some(input_content) = eval.input_content.as_deref() {
buffer.update(cx, |buffer, cx| buffer.set_text(input_content, cx));
}
- let (edit_output, _) = self.agent.edit(
- buffer.clone(),
- eval.edit_file_input.display_description,
- &conversation,
- &mut cx.to_async(),
- );
- edit_output.await?
+ retry_on_rate_limit(async || {
+ self.agent
+ .edit(
+ buffer.clone(),
+ eval.edit_file_input.display_description.clone(),
+ &conversation,
+ &mut cx.to_async(),
+ )
+ .0
+ .await
+ })
+ .await?
} else {
- let (edit_output, _) = self.agent.overwrite(
- buffer.clone(),
- eval.edit_file_input.display_description,
- &conversation,
- &mut cx.to_async(),
- );
- edit_output.await?
+ retry_on_rate_limit(async || {
+ self.agent
+ .overwrite(
+ buffer.clone(),
+ eval.edit_file_input.display_description.clone(),
+ &conversation,
+ &mut cx.to_async(),
+ )
+ .0
+ .await
+ })
+ .await?
};
let buffer_text = buffer.read_with(cx, |buffer, _| buffer.text());
@@ -1613,6 +1633,26 @@ impl EditAgentTest {
}
}
+async fn retry_on_rate_limit<R>(mut request: impl AsyncFnMut() -> Result<R>) -> Result<R> {
+ loop {
+ match request().await {
+ Ok(result) => return Ok(result),
+ Err(err) => match err.downcast::<LanguageModelCompletionError>() {
+ Ok(err) => match err {
+ LanguageModelCompletionError::RateLimit(duration) => {
+ // Wait until after we are allowed to try again
+ eprintln!("Rate limit exceeded. Waiting for {duration:?}...",);
+ Timer::after(duration).await;
+ continue;
+ }
+ _ => return Err(err.into()),
+ },
+ Err(err) => return Err(err),
+ },
+ }
+ }
+}
+
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
struct EvalAssertionOutcome {
score: usize,
@@ -185,6 +185,7 @@ impl LanguageModel for FakeLanguageModel {
'static,
Result<
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
+ LanguageModelCompletionError,
>,
> {
let (tx, rx) = mpsc::unbounded();
@@ -22,6 +22,7 @@ use std::fmt;
use std::ops::{Add, Sub};
use std::str::FromStr as _;
use std::sync::Arc;
+use std::time::Duration;
use thiserror::Error;
use util::serde::is_default;
use zed_llm_client::{
@@ -74,6 +75,8 @@ pub enum LanguageModelCompletionEvent {
#[derive(Error, Debug)]
pub enum LanguageModelCompletionError {
+ #[error("rate limit exceeded, retry after {0:?}")]
+ RateLimit(Duration),
#[error("received bad input JSON")]
BadInputJson {
id: LanguageModelToolUseId,
@@ -270,6 +273,7 @@ pub trait LanguageModel: Send + Sync {
'static,
Result<
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
+ LanguageModelCompletionError,
>,
>;
@@ -277,7 +281,7 @@ pub trait LanguageModel: Send + Sync {
&self,
request: LanguageModelRequest,
cx: &AsyncApp,
- ) -> BoxFuture<'static, Result<LanguageModelTextStream>> {
+ ) -> BoxFuture<'static, Result<LanguageModelTextStream, LanguageModelCompletionError>> {
let future = self.stream_completion(request, cx);
async move {
@@ -1,4 +1,3 @@
-use anyhow::Result;
use futures::Stream;
use smol::lock::{Semaphore, SemaphoreGuardArc};
use std::{
@@ -8,6 +7,8 @@ use std::{
task::{Context, Poll},
};
+use crate::LanguageModelCompletionError;
+
#[derive(Clone)]
pub struct RateLimiter {
semaphore: Arc<Semaphore>,
@@ -36,9 +37,12 @@ impl RateLimiter {
}
}
- pub fn run<'a, Fut, T>(&self, future: Fut) -> impl 'a + Future<Output = Result<T>>
+ pub fn run<'a, Fut, T>(
+ &self,
+ future: Fut,
+ ) -> impl 'a + Future<Output = Result<T, LanguageModelCompletionError>>
where
- Fut: 'a + Future<Output = Result<T>>,
+ Fut: 'a + Future<Output = Result<T, LanguageModelCompletionError>>,
{
let guard = self.semaphore.acquire_arc();
async move {
@@ -52,9 +56,12 @@ impl RateLimiter {
pub fn stream<'a, Fut, T>(
&self,
future: Fut,
- ) -> impl 'a + Future<Output = Result<impl Stream<Item = T::Item> + use<Fut, T>>>
+ ) -> impl 'a
+ + Future<
+ Output = Result<impl Stream<Item = T::Item> + use<Fut, T>, LanguageModelCompletionError>,
+ >
where
- Fut: 'a + Future<Output = Result<T>>,
+ Fut: 'a + Future<Output = Result<T, LanguageModelCompletionError>>,
T: Stream,
{
let guard = self.semaphore.acquire_arc();
@@ -387,22 +387,34 @@ impl AnthropicModel {
&self,
request: anthropic::Request,
cx: &AsyncApp,
- ) -> BoxFuture<'static, Result<BoxStream<'static, Result<anthropic::Event, AnthropicError>>>>
- {
+ ) -> BoxFuture<
+ 'static,
+ Result<
+ BoxStream<'static, Result<anthropic::Event, AnthropicError>>,
+ LanguageModelCompletionError,
+ >,
+ > {
let http_client = self.http_client.clone();
let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| {
let settings = &AllLanguageModelSettings::get_global(cx).anthropic;
(state.api_key.clone(), settings.api_url.clone())
}) else {
- return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
+ return futures::future::ready(Err(anyhow!("App state dropped").into())).boxed();
};
async move {
let api_key = api_key.context("Missing Anthropic API Key")?;
let request =
anthropic::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
- request.await.context("failed to stream completion")
+ request.await.map_err(|err| match err {
+ AnthropicError::RateLimit(duration) => {
+ LanguageModelCompletionError::RateLimit(duration)
+ }
+ err @ (AnthropicError::ApiError(..) | AnthropicError::Other(..)) => {
+ LanguageModelCompletionError::Other(anthropic_err_to_anyhow(err))
+ }
+ })
}
.boxed()
}
@@ -473,6 +485,7 @@ impl LanguageModel for AnthropicModel {
'static,
Result<
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
+ LanguageModelCompletionError,
>,
> {
let request = into_anthropic(
@@ -484,12 +497,7 @@ impl LanguageModel for AnthropicModel {
);
let request = self.stream_completion(request, cx);
let future = self.request_limiter.stream(async move {
- let response = request
- .await
- .map_err(|err| match err.downcast::<AnthropicError>() {
- Ok(anthropic_err) => anthropic_err_to_anyhow(anthropic_err),
- Err(err) => anyhow!(err),
- })?;
+ let response = request.await?;
Ok(AnthropicEventMapper::new().map_stream(response))
});
async move { Ok(future.await?.boxed()) }.boxed()
@@ -527,6 +527,7 @@ impl LanguageModel for BedrockModel {
'static,
Result<
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
+ LanguageModelCompletionError,
>,
> {
let Ok(region) = cx.read_entity(&self.state, |state, _cx| {
@@ -539,16 +540,13 @@ impl LanguageModel for BedrockModel {
.or(settings_region)
.unwrap_or(String::from("us-east-1"))
}) else {
- return async move {
- anyhow::bail!("App State Dropped");
- }
- .boxed();
+ return async move { Err(anyhow::anyhow!("App State Dropped").into()) }.boxed();
};
let model_id = match self.model.cross_region_inference_id(®ion) {
Ok(s) => s,
Err(e) => {
- return async move { Err(e) }.boxed();
+ return async move { Err(e.into()) }.boxed();
}
};
@@ -560,7 +558,7 @@ impl LanguageModel for BedrockModel {
self.model.mode(),
) {
Ok(request) => request,
- Err(err) => return futures::future::ready(Err(err)).boxed(),
+ Err(err) => return futures::future::ready(Err(err.into())).boxed(),
};
let owned_handle = self.handler.clone();
@@ -807,6 +807,7 @@ impl LanguageModel for CloudLanguageModel {
'static,
Result<
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
+ LanguageModelCompletionError,
>,
> {
let thread_id = request.thread_id.clone();
@@ -848,7 +849,8 @@ impl LanguageModel for CloudLanguageModel {
mode,
provider: zed_llm_client::LanguageModelProvider::Anthropic,
model: request.model.clone(),
- provider_request: serde_json::to_value(&request)?,
+ provider_request: serde_json::to_value(&request)
+ .map_err(|e| anyhow!(e))?,
},
)
.await
@@ -884,7 +886,7 @@ impl LanguageModel for CloudLanguageModel {
let client = self.client.clone();
let model = match open_ai::Model::from_id(&self.model.id.0) {
Ok(model) => model,
- Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
+ Err(err) => return async move { Err(anyhow!(err).into()) }.boxed(),
};
let request = into_open_ai(request, &model, None);
let llm_api_token = self.llm_api_token.clone();
@@ -905,7 +907,8 @@ impl LanguageModel for CloudLanguageModel {
mode,
provider: zed_llm_client::LanguageModelProvider::OpenAi,
model: request.model.clone(),
- provider_request: serde_json::to_value(&request)?,
+ provider_request: serde_json::to_value(&request)
+ .map_err(|e| anyhow!(e))?,
},
)
.await?;
@@ -944,7 +947,8 @@ impl LanguageModel for CloudLanguageModel {
mode,
provider: zed_llm_client::LanguageModelProvider::Google,
model: request.model.model_id.clone(),
- provider_request: serde_json::to_value(&request)?,
+ provider_request: serde_json::to_value(&request)
+ .map_err(|e| anyhow!(e))?,
},
)
.await?;
@@ -265,13 +265,15 @@ impl LanguageModel for CopilotChatLanguageModel {
'static,
Result<
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
+ LanguageModelCompletionError,
>,
> {
if let Some(message) = request.messages.last() {
if message.contents_empty() {
const EMPTY_PROMPT_MSG: &str =
"Empty prompts aren't allowed. Please provide a non-empty prompt.";
- return futures::future::ready(Err(anyhow::anyhow!(EMPTY_PROMPT_MSG))).boxed();
+ return futures::future::ready(Err(anyhow::anyhow!(EMPTY_PROMPT_MSG).into()))
+ .boxed();
}
// Copilot Chat has a restriction that the final message must be from the user.
@@ -279,13 +281,13 @@ impl LanguageModel for CopilotChatLanguageModel {
// and provide a more helpful error message.
if !matches!(message.role, Role::User) {
const USER_ROLE_MSG: &str = "The final message must be from the user. To provide a system prompt, you must provide the system prompt followed by a user prompt.";
- return futures::future::ready(Err(anyhow::anyhow!(USER_ROLE_MSG))).boxed();
+ return futures::future::ready(Err(anyhow::anyhow!(USER_ROLE_MSG).into())).boxed();
}
}
let copilot_request = match into_copilot_chat(&self.model, request) {
Ok(request) => request,
- Err(err) => return futures::future::ready(Err(err)).boxed(),
+ Err(err) => return futures::future::ready(Err(err.into())).boxed(),
};
let is_streaming = copilot_request.stream;
@@ -348,6 +348,7 @@ impl LanguageModel for DeepSeekLanguageModel {
'static,
Result<
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
+ LanguageModelCompletionError,
>,
> {
let request = into_deepseek(request, &self.model, self.max_output_tokens());
@@ -409,6 +409,7 @@ impl LanguageModel for GoogleLanguageModel {
'static,
Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
>,
+ LanguageModelCompletionError,
>,
> {
let request = into_google(
@@ -420,6 +420,7 @@ impl LanguageModel for LmStudioLanguageModel {
'static,
Result<
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
+ LanguageModelCompletionError,
>,
> {
let request = self.to_lmstudio_request(request);
@@ -364,6 +364,7 @@ impl LanguageModel for MistralLanguageModel {
'static,
Result<
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
+ LanguageModelCompletionError,
>,
> {
let request = into_mistral(
@@ -406,6 +406,7 @@ impl LanguageModel for OllamaLanguageModel {
'static,
Result<
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
+ LanguageModelCompletionError,
>,
> {
let request = self.to_ollama_request(request);
@@ -415,7 +416,7 @@ impl LanguageModel for OllamaLanguageModel {
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
settings.api_url.clone()
}) else {
- return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
+ return futures::future::ready(Err(anyhow!("App state dropped").into())).boxed();
};
let future = self.request_limiter.stream(async move {
@@ -339,6 +339,7 @@ impl LanguageModel for OpenAiLanguageModel {
'static,
Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
>,
+ LanguageModelCompletionError,
>,
> {
let request = into_open_ai(request, &self.model, self.max_output_tokens());
@@ -367,6 +367,7 @@ impl LanguageModel for OpenRouterLanguageModel {
'static,
Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
>,
+ LanguageModelCompletionError,
>,
> {
let request = into_open_router(request, &self.model, self.max_output_tokens());