Detailed changes
@@ -125,6 +125,7 @@ dependencies = [
"workspace",
"workspace-hack",
"zed_actions",
+ "zed_llm_client",
]
[[package]]
@@ -7654,6 +7655,7 @@ dependencies = [
"thiserror 2.0.12",
"util",
"workspace-hack",
+ "zed_llm_client",
]
[[package]]
@@ -90,6 +90,7 @@ uuid.workspace = true
workspace-hack.workspace = true
workspace.workspace = true
zed_actions.workspace = true
+zed_llm_client.workspace = true
[dev-dependencies]
buffer_diff = { workspace = true, features = ["test-support"] }
@@ -31,6 +31,7 @@ use settings::Settings;
use thiserror::Error;
use util::{ResultExt as _, TryFutureExt as _, post_inc};
use uuid::Uuid;
+use zed_llm_client::UsageLimit;
use crate::context::{AssistantContext, ContextId, format_context_as_string};
use crate::thread_store::{
@@ -1070,14 +1071,22 @@ impl Thread {
) {
let pending_completion_id = post_inc(&mut self.completion_count);
let task = cx.spawn(async move |thread, cx| {
- let stream = model.stream_completion(request, &cx);
+ let stream_completion_future = model.stream_completion_with_usage(request, &cx);
let initial_token_usage =
thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
let stream_completion = async {
- let mut events = stream.await?;
+ let (mut events, usage) = stream_completion_future.await?;
let mut stop_reason = StopReason::EndTurn;
let mut current_token_usage = TokenUsage::default();
+ if let Some(usage) = usage {
+ let limit = match usage.limit {
+ UsageLimit::Limited(limit) => limit.to_string(),
+ UsageLimit::Unlimited => "unlimited".to_string(),
+ };
+ log::info!("model request usage: {} / {}", usage.amount, limit);
+ }
+
while let Some(event) = events.next().await {
let event = event?;
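
The agent thread now pulls the event stream and the plan usage out of the same call and logs the latter. A minimal, self-contained sketch of the same formatting logic, with invented numbers, purely for illustration:

```rust
use zed_llm_client::UsageLimit;

// Mirrors the match used when logging above; the values are made up.
fn format_usage(amount: i32, limit: UsageLimit) -> String {
    let limit = match limit {
        UsageLimit::Limited(limit) => limit.to_string(),
        UsageLimit::Unlimited => "unlimited".to_string(),
    };
    format!("model request usage: {amount} / {limit}")
}

fn main() {
    println!("{}", format_usage(42, UsageLimit::Limited(500))); // model request usage: 42 / 500
    println!("{}", format_usage(42, UsageLimit::Unlimited));    // model request usage: 42 / unlimited
}
```
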
@@ -40,6 +40,7 @@ telemetry_events.workspace = true
thiserror.workspace = true
util.workspace = true
workspace-hack.workspace = true
+zed_llm_client.workspace = true
[dev-dependencies]
gpui = { workspace = true, features = ["test-support"] }
@@ -8,11 +8,12 @@ mod telemetry;
#[cfg(any(test, feature = "test-support"))]
pub mod fake_provider;
-use anyhow::Result;
+use anyhow::{Result, anyhow};
use client::Client;
use futures::FutureExt;
use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{AnyElement, AnyView, App, AsyncApp, SharedString, Task, Window};
+use http_client::http::{HeaderMap, HeaderValue};
use icons::IconName;
use parking_lot::Mutex;
use proto::Plan;
@@ -20,9 +21,13 @@ use schemars::JsonSchema;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use std::fmt;
use std::ops::{Add, Sub};
+use std::str::FromStr as _;
use std::sync::Arc;
use thiserror::Error;
use util::serde::is_default;
+use zed_llm_client::{
+ MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME, MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME, UsageLimit,
+};
pub use crate::model::*;
pub use crate::rate_limiter::*;
@@ -83,6 +88,28 @@ pub enum StopReason {
ToolUse,
}
+#[derive(Debug, Clone, Copy)]
+pub struct RequestUsage {
+ pub limit: UsageLimit,
+ pub amount: i32,
+}
+
+impl RequestUsage {
+ pub fn from_headers(headers: &HeaderMap<HeaderValue>) -> Result<Self> {
+ let limit = headers
+ .get(MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME)
+ .ok_or_else(|| anyhow!("missing {MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME:?} header"))?;
+ let limit = UsageLimit::from_str(limit.to_str()?)?;
+
+ let amount = headers
+ .get(MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME)
+ .ok_or_else(|| anyhow!("missing {MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME:?} header"))?;
+ let amount = amount.to_str()?.parse::<i32>()?;
+
+ Ok(Self { limit, amount })
+ }
+}
+
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
pub struct TokenUsage {
#[serde(default, skip_serializing_if = "is_default")]
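
A hedged sketch of exercising the new `RequestUsage::from_headers` helper. It assumes the zed_llm_client header-name constants are plain `&'static str` values (as the implementation above implies) and that the limit header carries either a number or the literal `unlimited`, matching `UsageLimit::from_str`:

```rust
use anyhow::Result;
use http_client::http::{HeaderMap, HeaderValue};
use language_model::RequestUsage;
use zed_llm_client::{
    MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME, MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME, UsageLimit,
};

fn main() -> Result<()> {
    // Hypothetical header values, for illustration only.
    let mut headers = HeaderMap::new();
    headers.insert(MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME, HeaderValue::from_static("500"));
    headers.insert(MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME, HeaderValue::from_static("42"));

    let usage = RequestUsage::from_headers(&headers)?;
    assert_eq!(usage.amount, 42);
    assert!(matches!(usage.limit, UsageLimit::Limited(500)));
    Ok(())
}
```
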
@@ -214,6 +241,22 @@ pub trait LanguageModel: Send + Sync {
cx: &AsyncApp,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>>;
+ fn stream_completion_with_usage(
+ &self,
+ request: LanguageModelRequest,
+ cx: &AsyncApp,
+ ) -> BoxFuture<
+ 'static,
+ Result<(
+ BoxStream<'static, Result<LanguageModelCompletionEvent>>,
+ Option<RequestUsage>,
+ )>,
+ > {
+ self.stream_completion(request, cx)
+ .map(|result| result.map(|stream| (stream, None)))
+ .boxed()
+ }
+
fn stream_completion_text(
&self,
request: LanguageModelRequest,
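
`stream_completion_with_usage` ships with a default implementation, so providers that never report usage keep working unchanged: they delegate to `stream_completion` and surface `None`. A sketch of a generic caller, with `model`, `request`, and `cx` assumed to be supplied elsewhere:

```rust
use std::sync::Arc;

use anyhow::Result;
use futures::StreamExt;
use gpui::AsyncApp;
use language_model::{LanguageModel, LanguageModelRequest};

async fn run(
    model: Arc<dyn LanguageModel>,
    request: LanguageModelRequest,
    cx: &AsyncApp,
) -> Result<()> {
    let (mut events, usage) = model.stream_completion_with_usage(request, cx).await?;

    // Only providers that override the method (in this change, the Zed cloud provider
    // below) populate `usage`; everyone else falls back to the default and reports `None`.
    if let Some(usage) = usage {
        log::info!("model request usage: {}", usage.amount);
    }

    while let Some(event) = events.next().await {
        let _event = event?;
        // ...handle each LanguageModelCompletionEvent...
    }
    Ok(())
}
```
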
@@ -8,6 +8,8 @@ use std::{
task::{Context, Poll},
};
+use crate::RequestUsage;
+
#[derive(Clone)]
pub struct RateLimiter {
semaphore: Arc<Semaphore>,
@@ -67,4 +69,32 @@ impl RateLimiter {
})
}
}
+
+ pub fn stream_with_usage<'a, Fut, T>(
+ &self,
+ future: Fut,
+ ) -> impl 'a
+ + Future<
+ Output = Result<(
+ impl Stream<Item = T::Item> + use<Fut, T>,
+ Option<RequestUsage>,
+ )>,
+ >
+ where
+ Fut: 'a + Future<Output = Result<(T, Option<RequestUsage>)>>,
+ T: Stream,
+ {
+ let guard = self.semaphore.acquire_arc();
+ async move {
+ let guard = guard.await;
+ let (inner, usage) = future.await?;
+ Ok((
+ RateLimitGuard {
+ inner,
+ _guard: guard,
+ },
+ usage,
+ ))
+ }
+ }
}
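
`RateLimiter::stream_with_usage` mirrors the existing `stream` helper but threads an `Option<RequestUsage>` through alongside the rate-limited stream. A self-contained sketch, assuming the `Result` in the signature is `anyhow::Result`; the event values are invented and no real request is issued:

```rust
use anyhow::Result;
use futures::{StreamExt, stream};
use language_model::{RateLimiter, RequestUsage};

async fn example(limiter: &RateLimiter) -> Result<()> {
    let (events, usage) = limiter
        .stream_with_usage(async {
            // A provider would issue its HTTP request here and return the parsed
            // event stream plus any usage extracted from the response headers.
            let events = stream::iter(vec![1, 2, 3]);
            Ok::<_, anyhow::Error>((events, None::<RequestUsage>))
        })
        .await?;

    assert!(usage.is_none());
    let collected: Vec<_> = events.collect().await;
    assert_eq!(collected, vec![1, 2, 3]);
    Ok(())
}
```
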
@@ -13,7 +13,7 @@ use language_model::{
AuthenticateError, CloudModel, LanguageModel, LanguageModelCacheConfiguration, LanguageModelId,
LanguageModelKnownError, LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelProviderTosView, LanguageModelRequest,
- LanguageModelToolSchemaFormat, ModelRequestLimitReachedError, RateLimiter,
+ LanguageModelToolSchemaFormat, ModelRequestLimitReachedError, RateLimiter, RequestUsage,
ZED_CLOUD_PROVIDER_ID,
};
use language_model::{
@@ -518,7 +518,7 @@ impl CloudLanguageModel {
client: Arc<Client>,
llm_api_token: LlmApiToken,
body: CompletionBody,
- ) -> Result<Response<AsyncBody>> {
+ ) -> Result<(Response<AsyncBody>, Option<RequestUsage>)> {
let http_client = &client.http_client();
let mut token = llm_api_token.acquire(&client).await?;
@@ -540,7 +540,9 @@ impl CloudLanguageModel {
let mut response = http_client.send(request).await?;
let status = response.status();
if status.is_success() {
- return Ok(response);
+ let usage = RequestUsage::from_headers(response.headers()).ok();
+
+ return Ok((response, usage));
} else if response
.headers()
.get(EXPIRED_LLM_TOKEN_HEADER_NAME)
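
Note that usage parsing is wrapped in `.ok()`, so a successful completion response that lacks the usage headers still streams normally and simply reports `None` for usage.
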
@@ -708,8 +710,24 @@ impl LanguageModel for CloudLanguageModel {
fn stream_completion(
&self,
request: LanguageModelRequest,
- _cx: &AsyncApp,
+ cx: &AsyncApp,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
+ self.stream_completion_with_usage(request, cx)
+ .map(|result| result.map(|(stream, _)| stream))
+ .boxed()
+ }
+
+ fn stream_completion_with_usage(
+ &self,
+ request: LanguageModelRequest,
+ _cx: &AsyncApp,
+ ) -> BoxFuture<
+ 'static,
+ Result<(
+ BoxStream<'static, Result<LanguageModelCompletionEvent>>,
+ Option<RequestUsage>,
+ )>,
+ > {
match &self.model {
CloudModel::Anthropic(model) => {
let request = into_anthropic(
@@ -721,8 +739,8 @@ impl LanguageModel for CloudLanguageModel {
);
let client = self.client.clone();
let llm_api_token = self.llm_api_token.clone();
- let future = self.request_limiter.stream(async move {
- let response = Self::perform_llm_completion(
+ let future = self.request_limiter.stream_with_usage(async move {
+ let (response, usage) = Self::perform_llm_completion(
client.clone(),
llm_api_token,
CompletionBody {
@@ -748,20 +766,25 @@ impl LanguageModel for CloudLanguageModel {
Err(err) => anyhow!(err),
})?;
- Ok(
+ Ok((
crate::provider::anthropic::map_to_language_model_completion_events(
Box::pin(response_lines(response).map_err(AnthropicError::Other)),
),
- )
+ usage,
+ ))
});
- async move { Ok(future.await?.boxed()) }.boxed()
+ async move {
+ let (stream, usage) = future.await?;
+ Ok((stream.boxed(), usage))
+ }
+ .boxed()
}
CloudModel::OpenAi(model) => {
let client = self.client.clone();
let request = into_open_ai(request, model, model.max_output_tokens());
let llm_api_token = self.llm_api_token.clone();
- let future = self.request_limiter.stream(async move {
- let response = Self::perform_llm_completion(
+ let future = self.request_limiter.stream_with_usage(async move {
+ let (response, usage) = Self::perform_llm_completion(
client.clone(),
llm_api_token,
CompletionBody {
@@ -771,20 +794,25 @@ impl LanguageModel for CloudLanguageModel {
},
)
.await?;
- Ok(
+ Ok((
crate::provider::open_ai::map_to_language_model_completion_events(
Box::pin(response_lines(response)),
),
- )
+ usage,
+ ))
});
- async move { Ok(future.await?.boxed()) }.boxed()
+ async move {
+ let (stream, usage) = future.await?;
+ Ok((stream.boxed(), usage))
+ }
+ .boxed()
}
CloudModel::Google(model) => {
let client = self.client.clone();
let request = into_google(request, model.id().into());
let llm_api_token = self.llm_api_token.clone();
- let future = self.request_limiter.stream(async move {
- let response = Self::perform_llm_completion(
+ let future = self.request_limiter.stream_with_usage(async move {
+ let (response, usage) = Self::perform_llm_completion(
client.clone(),
llm_api_token,
CompletionBody {
@@ -794,13 +822,18 @@ impl LanguageModel for CloudLanguageModel {
},
)
.await?;
- Ok(
+ Ok((
crate::provider::google::map_to_language_model_completion_events(Box::pin(
response_lines(response),
)),
- )
+ usage,
+ ))
});
- async move { Ok(future.await?.boxed()) }.boxed()
+ async move {
+ let (stream, usage) = future.await?;
+ Ok((stream.boxed(), usage))
+ }
+ .boxed()
}
}
}