From d93141bdedd4967637dd49bad224a0c8162560f1 Mon Sep 17 00:00:00 2001
From: Marshall Bowers
Date: Thu, 17 Apr 2025 16:11:07 -0400
Subject: [PATCH] agent: Extract usage information from response headers
 (#29002)

This PR updates the Agent to extract the usage information from the response
headers, if they are present.

For now we just log the information, but we'll be using this soon to populate
some UI.

Release Notes:

- N/A
---
 Cargo.lock                                   |  2 +
 crates/agent/Cargo.toml                      |  1 +
 crates/agent/src/thread.rs                   | 13 +++-
 crates/language_model/Cargo.toml             |  1 +
 crates/language_model/src/language_model.rs  | 45 ++++++++++++-
 crates/language_model/src/rate_limiter.rs    | 30 ++++++++++
 crates/language_models/src/provider/cloud.rs | 71 ++++++++++++++------
 7 files changed, 141 insertions(+), 22 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index b5fa56f45beedabac8936cf86e96d309ed4d94e9..6866438debbea6019179c17a0ec04b43fbde6e78 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -125,6 +125,7 @@ dependencies = [
  "workspace",
  "workspace-hack",
  "zed_actions",
+ "zed_llm_client",
 ]
 
 [[package]]
@@ -7654,6 +7655,7 @@ dependencies = [
  "thiserror 2.0.12",
  "util",
  "workspace-hack",
+ "zed_llm_client",
 ]
 
 [[package]]
diff --git a/crates/agent/Cargo.toml b/crates/agent/Cargo.toml
index ae184a1f386f3df5e63828909d8fb9bc3595fadb..cd8b9af0ee53a214b1f5049beb6010518a563aa5 100644
--- a/crates/agent/Cargo.toml
+++ b/crates/agent/Cargo.toml
@@ -90,6 +90,7 @@ uuid.workspace = true
 workspace-hack.workspace = true
 workspace.workspace = true
 zed_actions.workspace = true
+zed_llm_client.workspace = true
 
 [dev-dependencies]
 buffer_diff = { workspace = true, features = ["test-support"] }
diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs
index 62c43b877e3c247bbb362a7558873bfdfc5005be..94882f8cbd68648ed63660cfb7f7d00d44e81c2d 100644
--- a/crates/agent/src/thread.rs
+++ b/crates/agent/src/thread.rs
@@ -31,6 +31,7 @@ use settings::Settings;
 use thiserror::Error;
 use util::{ResultExt as _, TryFutureExt as _, post_inc};
 use uuid::Uuid;
+use zed_llm_client::UsageLimit;
 
 use crate::context::{AssistantContext, ContextId, format_context_as_string};
 use crate::thread_store::{
@@ -1070,14 +1071,22 @@ impl Thread {
     ) {
         let pending_completion_id = post_inc(&mut self.completion_count);
         let task = cx.spawn(async move |thread, cx| {
-            let stream = model.stream_completion(request, &cx);
+            let stream_completion_future = model.stream_completion_with_usage(request, &cx);
             let initial_token_usage =
                 thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
 
             let stream_completion = async {
-                let mut events = stream.await?;
+                let (mut events, usage) = stream_completion_future.await?;
                 let mut stop_reason = StopReason::EndTurn;
                 let mut current_token_usage = TokenUsage::default();
 
+                if let Some(usage) = usage {
+                    let limit = match usage.limit {
+                        UsageLimit::Limited(limit) => limit.to_string(),
+                        UsageLimit::Unlimited => "unlimited".to_string(),
+                    };
+                    log::info!("model request usage: {} / {}", usage.amount, limit);
+                }
+
                 while let Some(event) = events.next().await {
                     let event = event?;
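The hunk above consumes the new usage value by matching on `UsageLimit` before logging it. A standalone sketch of that formatting logic follows; the `RequestUsage`/`UsageLimit` shapes are mirrored from this patch and the `zed_llm_client` crate so the example compiles on its own, and `format_usage` is a hypothetical helper rather than part of the change:

```rust
// Standalone sketch of the logging added in thread.rs above. RequestUsage and
// UsageLimit are mirrored from this patch / zed_llm_client; `format_usage` is
// a hypothetical helper, not part of the actual change.
#[derive(Debug, Clone, Copy)]
enum UsageLimit {
    Limited(i32),
    Unlimited,
}

#[derive(Debug, Clone, Copy)]
struct RequestUsage {
    limit: UsageLimit,
    amount: i32,
}

/// Formats usage the way the patch logs it: "<amount> / <limit>".
fn format_usage(usage: RequestUsage) -> String {
    let limit = match usage.limit {
        UsageLimit::Limited(limit) => limit.to_string(),
        UsageLimit::Unlimited => "unlimited".to_string(),
    };
    format!("{} / {}", usage.amount, limit)
}

fn main() {
    let limited = RequestUsage {
        limit: UsageLimit::Limited(500),
        amount: 42,
    };
    let unlimited = RequestUsage {
        limit: UsageLimit::Unlimited,
        amount: 7,
    };
    // Mirrors: log::info!("model request usage: {} / {}", usage.amount, limit);
    println!("model request usage: {}", format_usage(limited));
    println!("model request usage: {}", format_usage(unlimited));
}
```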
diff --git a/crates/language_model/Cargo.toml b/crates/language_model/Cargo.toml
index 4580d9f701da5edf000b411125d048c5fe3ec702..c468ff82973e4bada9d3835c0d78938fe66c8e95 100644
--- a/crates/language_model/Cargo.toml
+++ b/crates/language_model/Cargo.toml
@@ -40,6 +40,7 @@ telemetry_events.workspace = true
 thiserror.workspace = true
 util.workspace = true
 workspace-hack.workspace = true
+zed_llm_client.workspace = true
 
 [dev-dependencies]
 gpui = { workspace = true, features = ["test-support"] }
diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs
index 35bf5d60940e8ce6dbe68cda15db8ef96d7b5aa2..88115c43fb67152a382a2855b399ef5cdf53d428 100644
--- a/crates/language_model/src/language_model.rs
+++ b/crates/language_model/src/language_model.rs
@@ -8,11 +8,12 @@ mod telemetry;
 #[cfg(any(test, feature = "test-support"))]
 pub mod fake_provider;
 
-use anyhow::Result;
+use anyhow::{Result, anyhow};
 use client::Client;
 use futures::FutureExt;
 use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
 use gpui::{AnyElement, AnyView, App, AsyncApp, SharedString, Task, Window};
+use http_client::http::{HeaderMap, HeaderValue};
 use icons::IconName;
 use parking_lot::Mutex;
 use proto::Plan;
@@ -20,9 +21,13 @@ use schemars::JsonSchema;
 use serde::{Deserialize, Serialize, de::DeserializeOwned};
 use std::fmt;
 use std::ops::{Add, Sub};
+use std::str::FromStr as _;
 use std::sync::Arc;
 use thiserror::Error;
 use util::serde::is_default;
+use zed_llm_client::{
+    MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME, MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME, UsageLimit,
+};
 
 pub use crate::model::*;
 pub use crate::rate_limiter::*;
@@ -83,6 +88,28 @@ pub enum StopReason {
     ToolUse,
 }
 
+#[derive(Debug, Clone, Copy)]
+pub struct RequestUsage {
+    pub limit: UsageLimit,
+    pub amount: i32,
+}
+
+impl RequestUsage {
+    pub fn from_headers(headers: &HeaderMap<HeaderValue>) -> Result<Self> {
+        let limit = headers
+            .get(MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME)
+            .ok_or_else(|| anyhow!("missing {MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME:?} header"))?;
+        let limit = UsageLimit::from_str(limit.to_str()?)?;
+
+        let amount = headers
+            .get(MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME)
+            .ok_or_else(|| anyhow!("missing {MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME:?} header"))?;
+        let amount = amount.to_str()?.parse::<i32>()?;
+
+        Ok(Self { limit, amount })
+    }
+}
+
 #[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
 pub struct TokenUsage {
     #[serde(default, skip_serializing_if = "is_default")]
@@ -214,6 +241,22 @@ pub trait LanguageModel: Send + Sync {
         cx: &AsyncApp,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>>;
 
+    fn stream_completion_with_usage(
+        &self,
+        request: LanguageModelRequest,
+        cx: &AsyncApp,
+    ) -> BoxFuture<
+        'static,
+        Result<(
+            BoxStream<'static, Result<LanguageModelCompletionEvent>>,
+            Option<RequestUsage>,
+        )>,
+    > {
+        self.stream_completion(request, cx)
+            .map(|result| result.map(|stream| (stream, None)))
+            .boxed()
+    }
+
     fn stream_completion_text(
         &self,
         request: LanguageModelRequest,
diff --git a/crates/language_model/src/rate_limiter.rs b/crates/language_model/src/rate_limiter.rs
index a48d34488b8de7ea0ede49e63c95f868923fa23c..7383dd56c93e6e6795989c3566eeb37e0c02f50c 100644
--- a/crates/language_model/src/rate_limiter.rs
+++ b/crates/language_model/src/rate_limiter.rs
@@ -8,6 +8,8 @@ use std::{
     task::{Context, Poll},
 };
 
+use crate::RequestUsage;
+
 #[derive(Clone)]
 pub struct RateLimiter {
     semaphore: Arc<Semaphore>,
@@ -67,4 +69,32 @@ impl RateLimiter {
             })
         }
     }
+
+    pub fn stream_with_usage<'a, Fut, T>(
+        &self,
+        future: Fut,
+    ) -> impl 'a
+    + Future<
+        Output = Result<(
+            impl Stream<Item = T::Item> + use<Fut, T>,
+            Option<RequestUsage>,
+        )>,
+    >
+    where
+        Fut: 'a + Future<Output = Result<(T, Option<RequestUsage>)>>,
+        T: Stream,
+    {
+        let guard = self.semaphore.acquire_arc();
+        async move {
+            let guard = guard.await;
+            let (inner, usage) = future.await?;
+            Ok((
+                RateLimitGuard {
+                    inner,
+                    _guard: guard,
+                },
+                usage,
+            ))
+        }
+    }
 }
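For reference, `RequestUsage::from_headers` above reads two response headers and parses them into a limit and an amount. Below is a self-contained sketch of that parsing. It assumes the `http` crate in place of Zed's `http_client` re-export, uses placeholder header names (the real names come from the `zed_llm_client` constants, whose string values are not shown in this patch), and guesses a wire format of "unlimited" or a plain integer for the limit:

```rust
// Sketch of header parsing in the spirit of RequestUsage::from_headers.
// Header names and the UsageLimit wire format below are assumptions.
use std::str::FromStr;

use anyhow::{Result, anyhow};
use http::{HeaderMap, HeaderValue};

const USAGE_LIMIT_HEADER: &str = "x-usage-limit"; // placeholder name
const USAGE_AMOUNT_HEADER: &str = "x-usage-amount"; // placeholder name

#[derive(Debug, Clone, Copy)]
enum UsageLimit {
    Limited(i32),
    Unlimited,
}

impl FromStr for UsageLimit {
    type Err = anyhow::Error;

    // Assumed wire format: "unlimited" or a plain integer.
    fn from_str(s: &str) -> Result<Self> {
        match s {
            "unlimited" => Ok(Self::Unlimited),
            other => Ok(Self::Limited(other.parse()?)),
        }
    }
}

#[derive(Debug, Clone, Copy)]
struct RequestUsage {
    limit: UsageLimit,
    amount: i32,
}

impl RequestUsage {
    fn from_headers(headers: &HeaderMap<HeaderValue>) -> Result<Self> {
        let limit = headers
            .get(USAGE_LIMIT_HEADER)
            .ok_or_else(|| anyhow!("missing {USAGE_LIMIT_HEADER} header"))?;
        let limit = UsageLimit::from_str(limit.to_str()?)?;

        let amount = headers
            .get(USAGE_AMOUNT_HEADER)
            .ok_or_else(|| anyhow!("missing {USAGE_AMOUNT_HEADER} header"))?;
        let amount = amount.to_str()?.parse::<i32>()?;

        Ok(Self { limit, amount })
    }
}

fn main() -> Result<()> {
    let mut headers = HeaderMap::new();
    headers.insert(USAGE_LIMIT_HEADER, HeaderValue::from_static("500"));
    headers.insert(USAGE_AMOUNT_HEADER, HeaderValue::from_static("42"));

    let usage = RequestUsage::from_headers(&headers)?;
    println!("{usage:?}");
    Ok(())
}
```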
diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs
index 8286d0a1f2fa3e9c07aeea38a11f96200ef9a2a8..25a048537be2017abf04f457d9dedea897206246 100644
--- a/crates/language_models/src/provider/cloud.rs
+++ b/crates/language_models/src/provider/cloud.rs
@@ -13,7 +13,7 @@ use language_model::{
     AuthenticateError, CloudModel, LanguageModel, LanguageModelCacheConfiguration, LanguageModelId,
     LanguageModelKnownError, LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
     LanguageModelProviderState, LanguageModelProviderTosView, LanguageModelRequest,
-    LanguageModelToolSchemaFormat, ModelRequestLimitReachedError, RateLimiter,
+    LanguageModelToolSchemaFormat, ModelRequestLimitReachedError, RateLimiter, RequestUsage,
     ZED_CLOUD_PROVIDER_ID,
 };
 use language_model::{
@@ -518,7 +518,7 @@ impl CloudLanguageModel {
         client: Arc<Client>,
         llm_api_token: LlmApiToken,
         body: CompletionBody,
-    ) -> Result<Response<AsyncBody>> {
+    ) -> Result<(Response<AsyncBody>, Option<RequestUsage>)> {
         let http_client = &client.http_client();
 
         let mut token = llm_api_token.acquire(&client).await?;
@@ -540,7 +540,9 @@ impl CloudLanguageModel {
         let mut response = http_client.send(request).await?;
         let status = response.status();
         if status.is_success() {
-            return Ok(response);
+            let usage = RequestUsage::from_headers(response.headers()).ok();
+
+            return Ok((response, usage));
         } else if response
             .headers()
             .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
@@ -708,8 +710,24 @@ impl LanguageModel for CloudLanguageModel {
     fn stream_completion(
         &self,
         request: LanguageModelRequest,
-        _cx: &AsyncApp,
+        cx: &AsyncApp,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
+        self.stream_completion_with_usage(request, cx)
+            .map(|result| result.map(|(stream, _)| stream))
+            .boxed()
+    }
+
+    fn stream_completion_with_usage(
+        &self,
+        request: LanguageModelRequest,
+        _cx: &AsyncApp,
+    ) -> BoxFuture<
+        'static,
+        Result<(
+            BoxStream<'static, Result<LanguageModelCompletionEvent>>,
+            Option<RequestUsage>,
+        )>,
+    > {
         match &self.model {
             CloudModel::Anthropic(model) => {
                 let request = into_anthropic(
@@ -721,8 +739,8 @@ impl LanguageModel for CloudLanguageModel {
                 );
                 let client = self.client.clone();
                 let llm_api_token = self.llm_api_token.clone();
-                let future = self.request_limiter.stream(async move {
-                    let response = Self::perform_llm_completion(
+                let future = self.request_limiter.stream_with_usage(async move {
+                    let (response, usage) = Self::perform_llm_completion(
                         client.clone(),
                         llm_api_token,
                         CompletionBody {
@@ -748,20 +766,25 @@ impl LanguageModel for CloudLanguageModel {
                         Err(err) => anyhow!(err),
                     })?;
 
-                    Ok(
+                    Ok((
                         crate::provider::anthropic::map_to_language_model_completion_events(
                             Box::pin(response_lines(response).map_err(AnthropicError::Other)),
                         ),
-                    )
+                        usage,
+                    ))
                 });
-                async move { Ok(future.await?.boxed()) }.boxed()
+                async move {
+                    let (stream, usage) = future.await?;
+                    Ok((stream.boxed(), usage))
+                }
+                .boxed()
             }
             CloudModel::OpenAi(model) => {
                 let client = self.client.clone();
                 let request = into_open_ai(request, model, model.max_output_tokens());
                 let llm_api_token = self.llm_api_token.clone();
-                let future = self.request_limiter.stream(async move {
-                    let response = Self::perform_llm_completion(
+                let future = self.request_limiter.stream_with_usage(async move {
+                    let (response, usage) = Self::perform_llm_completion(
                         client.clone(),
                         llm_api_token,
                         CompletionBody {
@@ -771,20 +794,25 @@ impl LanguageModel for CloudLanguageModel {
                         },
                     )
                     .await?;
-                    Ok(
+                    Ok((
                         crate::provider::open_ai::map_to_language_model_completion_events(
                             Box::pin(response_lines(response)),
                         ),
-                    )
+                        usage,
+                    ))
                 });
-                async move { Ok(future.await?.boxed()) }.boxed()
+                async move {
+                    let (stream, usage) = future.await?;
+                    Ok((stream.boxed(), usage))
+                }
+                .boxed()
             }
             CloudModel::Google(model) => {
                 let client = self.client.clone();
                 let request = into_google(request, model.id().into());
                 let llm_api_token = self.llm_api_token.clone();
-                let future = self.request_limiter.stream(async move {
-                    let response = Self::perform_llm_completion(
+                let future = self.request_limiter.stream_with_usage(async move {
+                    let (response, usage) = Self::perform_llm_completion(
                         client.clone(),
                         llm_api_token,
                         CompletionBody {
@@ -794,13 +822,18 @@ impl LanguageModel for CloudLanguageModel {
                         },
                     )
                     .await?;
-                    Ok(
+                    Ok((
                         crate::provider::google::map_to_language_model_completion_events(Box::pin(
                             response_lines(response),
                         )),
-                    )
+                        usage,
+                    ))
                 });
-                async move { Ok(future.await?.boxed()) }.boxed()
+                async move {
+                    let (stream, usage) = future.await?;
+                    Ok((stream.boxed(), usage))
+                }
+                .boxed()
             }
         }
     }
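The cloud provider changes above all follow one shape: the rate-limited future now resolves to a `(stream, usage)` pair, and each branch unwraps it in an `async move` block before boxing the stream. A minimal sketch of that shape follows; `perform_completion`, the event type, and the usage struct are illustrative stand-ins rather than the real Zed API:

```rust
// Sketch of the (stream, usage) threading used in cloud.rs above. The real
// code goes through RateLimiter::stream_with_usage and the provider-specific
// event mappers; `perform_completion` and the types here are stand-ins.
use futures::{
    StreamExt,
    executor::block_on,
    stream::{self, BoxStream},
};

#[derive(Debug, Clone, Copy)]
struct RequestUsage {
    amount: i32,
}

type EventStream = BoxStream<'static, anyhow::Result<String>>;

// Stand-in for Self::perform_llm_completion: yields the event stream together
// with whatever usage information the response headers carried.
async fn perform_completion() -> anyhow::Result<(EventStream, Option<RequestUsage>)> {
    let events = stream::iter(vec![
        anyhow::Ok("event-1".to_string()),
        anyhow::Ok("event-2".to_string()),
    ]);
    Ok((events.boxed(), Some(RequestUsage { amount: 42 })))
}

fn main() -> anyhow::Result<()> {
    block_on(async {
        // Mirrors the `async move { let (stream, usage) = future.await?; ... }`
        // blocks in the patch: keep the usage alongside the boxed stream.
        let (mut stream, usage) = perform_completion().await?;
        if let Some(usage) = usage {
            println!("model request usage: {}", usage.amount);
        }
        while let Some(event) = stream.next().await {
            println!("{}", event?);
        }
        Ok(())
    })
}
```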
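Finally, the trait change in `language_model.rs` relies on a default method so existing implementors keep compiling: `stream_completion_with_usage` falls back to the plain stream paired with `None`, and only the cloud provider overrides it. A simplified sketch of that pattern, assuming a made-up `CompletionProvider` trait rather than the real `LanguageModel` trait:

```rust
// Sketch of the backward-compatible trait extension: the new *_with_usage
// method has a default body that delegates to the existing method and reports
// no usage. `CompletionProvider` and `EchoProvider` are illustrative only.
use futures::{FutureExt, future::BoxFuture};

#[derive(Debug, Clone, Copy)]
struct RequestUsage {
    amount: i32,
}

trait CompletionProvider: Send + Sync {
    fn complete(&self, prompt: String) -> BoxFuture<'static, anyhow::Result<String>>;

    // Default implementation: same result, no usage information.
    fn complete_with_usage(
        &self,
        prompt: String,
    ) -> BoxFuture<'static, anyhow::Result<(String, Option<RequestUsage>)>> {
        self.complete(prompt)
            .map(|result| result.map(|text| (text, None)))
            .boxed()
    }
}

struct EchoProvider;

impl CompletionProvider for EchoProvider {
    fn complete(&self, prompt: String) -> BoxFuture<'static, anyhow::Result<String>> {
        async move { Ok(format!("echo: {prompt}")) }.boxed()
    }
    // `complete_with_usage` is inherited; a provider like the cloud one would
    // override it to attach usage parsed from response headers.
}

fn main() -> anyhow::Result<()> {
    let provider = EchoProvider;
    let (text, usage) = futures::executor::block_on(provider.complete_with_usage("hi".into()))?;
    println!("{text}, usage: {usage:?}");
    Ok(())
}
```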