@@ -355,6 +355,7 @@ pub struct Thread {
request_token_usage: Vec<TokenUsage>,
cumulative_token_usage: TokenUsage,
exceeded_window_error: Option<ExceededWindowError>,
+ tool_use_limit_reached: bool,
feedback: Option<ThreadFeedback>,
message_feedback: HashMap<MessageId, ThreadFeedback>,
last_auto_capture_at: Option<Instant>,
@@ -417,6 +418,7 @@ impl Thread {
request_token_usage: Vec::new(),
cumulative_token_usage: TokenUsage::default(),
exceeded_window_error: None,
+ tool_use_limit_reached: false,
feedback: None,
message_feedback: HashMap::default(),
last_auto_capture_at: None,
@@ -524,6 +526,7 @@ impl Thread {
request_token_usage: serialized.request_token_usage,
cumulative_token_usage: serialized.cumulative_token_usage,
exceeded_window_error: None,
+ tool_use_limit_reached: false,
feedback: None,
message_feedback: HashMap::default(),
last_auto_capture_at: None,
@@ -814,6 +817,10 @@ impl Thread {
.unwrap_or(false)
}
+ pub fn tool_use_limit_reached(&self) -> bool {
+ self.tool_use_limit_reached
+ }
+
/// Returns whether all of the tool uses have finished running.
pub fn all_tools_finished(&self) -> bool {
// If the only pending tool uses left are the ones with errors, then
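
Taken together, the hunks above thread a single latched flag through `Thread`: it starts out `false` in both the constructor and the deserialization path, is only ever set from inside the completion stream, and is exposed read-only through the new accessor. A self-contained sketch of that lifecycle, using a stripped-down stand-in for Zed's actual `Thread` (the `start_completion`/`on_limit_event` helpers are illustrative names, not methods from the diff):

```rust
// Stand-in for the latched-flag lifecycle; only the field name and the
// read-only accessor mirror the diff, the rest is illustrative.
struct Thread {
    tool_use_limit_reached: bool,
}

impl Thread {
    fn new() -> Self {
        // false on construction and on deserialization, as in the hunks above
        Self { tool_use_limit_reached: false }
    }

    // Mirrors the reset at the top of stream_completion (next hunk).
    fn start_completion(&mut self) {
        self.tool_use_limit_reached = false;
    }

    // Mirrors the QueueUpdate arm below that latches the flag.
    fn on_limit_event(&mut self) {
        self.tool_use_limit_reached = true;
    }

    pub fn tool_use_limit_reached(&self) -> bool {
        self.tool_use_limit_reached
    }
}

fn main() {
    let mut thread = Thread::new();
    thread.start_completion();
    assert!(!thread.tool_use_limit_reached());
    thread.on_limit_event();
    assert!(thread.tool_use_limit_reached()); // stays set until the next request
    thread.start_completion();
    assert!(!thread.tool_use_limit_reached()); // cleared per request
}
```
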
@@ -1331,6 +1338,8 @@ impl Thread {
window: Option<AnyWindowHandle>,
cx: &mut Context<Self>,
) {
+ self.tool_use_limit_reached = false;
+
let pending_completion_id = post_inc(&mut self.completion_count);
let mut request_callback_parameters = if self.request_callback.is_some() {
Some((request.clone(), Vec::new()))
@@ -1506,17 +1515,27 @@ impl Thread {
});
}
}
- LanguageModelCompletionEvent::QueueUpdate(queue_event) => {
+ LanguageModelCompletionEvent::QueueUpdate(status) => {
if let Some(completion) = thread
.pending_completions
.iter_mut()
.find(|completion| completion.id == pending_completion_id)
{
- completion.queue_state = match queue_event {
- language_model::QueueState::Queued { position } => {
- QueueState::Queued { position }
+ let queue_state = match status {
+ language_model::CompletionRequestStatus::Queued {
+ position,
+ } => Some(QueueState::Queued { position }),
+ language_model::CompletionRequestStatus::Started => {
+ Some(QueueState::Started)
}
- language_model::QueueState::Started => QueueState::Started,
+ language_model::CompletionRequestStatus::ToolUseLimitReached => {
+ thread.tool_use_limit_reached = true;
+ None
+ }
+ };
+
+ if let Some(queue_state) = queue_state {
+ completion.queue_state = queue_state;
}
}
}
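
Two things change in this arm: the event payload is now the broader `CompletionRequestStatus` rather than the queue-only `QueueState`, and the match maps each status to an `Option<QueueState>` so that statuses which do not affect queue position (here, `ToolUseLimitReached`) can side-effect the thread and yield `None`. Judging from the arms above, the enum has at least the following shape; this is a reconstruction for illustration, not the authoritative definition in the `language_model` crate, and the `position` type is an assumption:

```rust
// Reconstructed from the match arms above; the real enum may carry more
// variants and different field types.
enum CompletionRequestStatus {
    Queued { position: usize }, // usize is assumed here
    Started,
    ToolUseLimitReached,
}
```

Because the match is exhaustive, a future status variant forces the compiler to ask both questions at once: does it move the queue indicator, and does it latch any thread state?
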
@@ -9,11 +9,12 @@ use futures::{
use gpui::{AnyElement, AnyView, App, AsyncApp, Context, Entity, Subscription, Task};
use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode};
use language_model::{
- AuthenticateError, CloudModel, LanguageModel, LanguageModelCacheConfiguration,
- LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName,
- LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
- LanguageModelProviderTosView, LanguageModelRequest, LanguageModelToolSchemaFormat,
- ModelRequestLimitReachedError, QueueState, RateLimiter, RequestUsage, ZED_CLOUD_PROVIDER_ID,
+ AuthenticateError, CloudModel, CompletionRequestStatus, LanguageModel,
+ LanguageModelCacheConfiguration, LanguageModelCompletionError, LanguageModelId,
+ LanguageModelKnownError, LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
+ LanguageModelProviderState, LanguageModelProviderTosView, LanguageModelRequest,
+ LanguageModelToolSchemaFormat, ModelRequestLimitReachedError, RateLimiter, RequestUsage,
+ ZED_CLOUD_PROVIDER_ID,
};
use language_model::{
LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider, LlmApiToken,
@@ -38,6 +39,7 @@ use zed_llm_client::{
CURRENT_PLAN_HEADER_NAME, CompletionBody, CountTokensBody, CountTokensResponse,
EXPIRED_LLM_TOKEN_HEADER_NAME, MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME,
MODEL_REQUESTS_RESOURCE_HEADER_VALUE, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME,
+ TOOL_USE_LIMIT_REACHED_HEADER_NAME,
};
use crate::AllLanguageModelSettings;
@@ -511,6 +513,13 @@ pub struct CloudLanguageModel {
request_limiter: RateLimiter,
}
+struct PerformLlmCompletionResponse {
+ response: Response<AsyncBody>,
+ usage: Option<RequestUsage>,
+ tool_use_limit_reached: bool,
+ includes_queue_events: bool,
+}
+
impl CloudLanguageModel {
const MAX_RETRIES: usize = 3;
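
The tuple return of `perform_llm_completion` is replaced with a named struct just as it grows a second boolean, and that is the right moment: two positional `bool`s in a tuple can be transposed at a call site without any compiler complaint, while struct destructuring is checked field by field. A contrived sketch of the hazard the rename prevents (all names here are illustrative):

```rust
// Two positional bools: transposing them still type-checks.
fn as_tuple() -> (bool, bool) {
    (true, false) // meant as (includes_queue_events, tool_use_limit_reached)
}

// Named fields: a destructure is matched by name, so order is irrelevant.
struct Flags {
    includes_queue_events: bool,
    tool_use_limit_reached: bool,
}

fn as_struct() -> Flags {
    Flags { includes_queue_events: true, tool_use_limit_reached: false }
}

fn main() {
    // Bug: the names are swapped relative to the producer; nothing catches it.
    let (tool_use_limit_reached, includes_queue_events) = as_tuple();

    // Safe: fields bind by name regardless of declaration order.
    let Flags { tool_use_limit_reached: latched, includes_queue_events: queueing } = as_struct();

    println!("{tool_use_limit_reached} {includes_queue_events} {latched} {queueing}");
}
```
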
@@ -518,7 +527,7 @@ impl CloudLanguageModel {
client: Arc<Client>,
llm_api_token: LlmApiToken,
body: CompletionBody,
- ) -> Result<(Response<AsyncBody>, Option<RequestUsage>, bool)> {
+ ) -> Result<PerformLlmCompletionResponse> {
let http_client = &client.http_client();
let mut token = llm_api_token.acquire(&client).await?;
@@ -545,9 +554,18 @@ impl CloudLanguageModel {
.headers()
.get("x-zed-server-supports-queueing")
.is_some();
+ let tool_use_limit_reached = response
+ .headers()
+ .get(TOOL_USE_LIMIT_REACHED_HEADER_NAME)
+ .is_some();
let usage = RequestUsage::from_headers(response.headers()).ok();
- return Ok((response, usage, includes_queue_events));
+ return Ok(PerformLlmCompletionResponse {
+ response,
+ usage,
+ includes_queue_events,
+ tool_use_limit_reached,
+ });
} else if response
.headers()
.get(EXPIRED_LLM_TOKEN_HEADER_NAME)
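
Note how the limit is signaled: not in the response body, but as a header whose mere presence is the boolean; the code never reads its value. A minimal sketch of that convention against the `http` crate's `HeaderMap` (the literal header name below is hypothetical; the real one is the `TOOL_USE_LIMIT_REACHED_HEADER_NAME` constant imported from zed_llm_client above):

```rust
use http::header::{HeaderMap, HeaderValue};

// Presence-as-boolean: any value, even an empty one, means "limit reached".
fn tool_use_limit_reached(headers: &HeaderMap) -> bool {
    headers.get("x-zed-tool-use-limit-reached").is_some()
}

fn main() {
    let mut headers = HeaderMap::new();
    assert!(!tool_use_limit_reached(&headers));

    headers.insert(
        "x-zed-tool-use-limit-reached",
        HeaderValue::from_static("true"), // the value itself is ignored
    );
    assert!(tool_use_limit_reached(&headers));
}
```
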
@@ -787,7 +805,12 @@ impl LanguageModel for CloudLanguageModel {
let client = self.client.clone();
let llm_api_token = self.llm_api_token.clone();
let future = self.request_limiter.stream_with_usage(async move {
- let (response, usage, includes_queue_events) = Self::perform_llm_completion(
+ let PerformLlmCompletionResponse {
+ response,
+ usage,
+ includes_queue_events,
+ tool_use_limit_reached,
+ } = Self::perform_llm_completion(
client.clone(),
llm_api_token,
CompletionBody {
@@ -819,7 +842,10 @@ impl LanguageModel for CloudLanguageModel {
let mut mapper = AnthropicEventMapper::new();
Ok((
map_cloud_completion_events(
- Box::pin(response_lines(response, includes_queue_events)),
+ Box::pin(
+ response_lines(response, includes_queue_events)
+ .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
+ ),
move |event| mapper.map_event(event),
),
usage,
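
The synthetic event is appended with `StreamExt::chain`, which only polls its second stream once the first is exhausted; the `ToolUseLimitReached` event is therefore guaranteed to arrive after every line of the SSE body, which is what lets the thread treat it as the final word for the turn. The same wiring is repeated below for the OpenAI and Google paths. A standalone illustration of the ordering guarantee, with integers standing in for completion events:

```rust
use futures::{executor::block_on, stream, StreamExt};

fn main() {
    let body = stream::iter(vec![1, 2, 3]); // stands in for response_lines(..)
    let sentinel = stream::iter(Some(99)); // stands in for the limit event

    let all = block_on(body.chain(sentinel).collect::<Vec<_>>());
    assert_eq!(all, vec![1, 2, 3, 99]); // the sentinel is always last
}
```
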
@@ -836,7 +862,12 @@ impl LanguageModel for CloudLanguageModel {
let request = into_open_ai(request, model, model.max_output_tokens());
let llm_api_token = self.llm_api_token.clone();
let future = self.request_limiter.stream_with_usage(async move {
- let (response, usage, includes_queue_events) = Self::perform_llm_completion(
+ let PerformLlmCompletionResponse {
+ response,
+ usage,
+ includes_queue_events,
+ tool_use_limit_reached,
+ } = Self::perform_llm_completion(
client.clone(),
llm_api_token,
CompletionBody {
@@ -853,7 +884,10 @@ impl LanguageModel for CloudLanguageModel {
let mut mapper = OpenAiEventMapper::new();
Ok((
map_cloud_completion_events(
- Box::pin(response_lines(response, includes_queue_events)),
+ Box::pin(
+ response_lines(response, includes_queue_events)
+ .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
+ ),
move |event| mapper.map_event(event),
),
usage,
@@ -870,7 +904,12 @@ impl LanguageModel for CloudLanguageModel {
let request = into_google(request, model.id().into());
let llm_api_token = self.llm_api_token.clone();
let future = self.request_limiter.stream_with_usage(async move {
- let (response, usage, includes_queue_events) = Self::perform_llm_completion(
+ let PerformLlmCompletionResponse {
+ response,
+ usage,
+ includes_queue_events,
+ tool_use_limit_reached,
+ } = Self::perform_llm_completion(
client.clone(),
llm_api_token,
CompletionBody {
@@ -883,10 +922,14 @@ impl LanguageModel for CloudLanguageModel {
},
)
.await?;
+
let mut mapper = GoogleEventMapper::new();
Ok((
map_cloud_completion_events(
- Box::pin(response_lines(response, includes_queue_events)),
+ Box::pin(
+ response_lines(response, includes_queue_events)
+ .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
+ ),
move |event| mapper.map_event(event),
),
usage,
@@ -905,7 +948,7 @@ impl LanguageModel for CloudLanguageModel {
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CloudCompletionEvent<T> {
- Queue(QueueState),
+ System(CompletionRequestStatus),
Event(T),
}
@@ -925,7 +968,7 @@ where
Err(error) => {
vec![Err(LanguageModelCompletionError::Other(error))]
}
- Ok(CloudCompletionEvent::Queue(event)) => {
+ Ok(CloudCompletionEvent::System(event)) => {
vec![Ok(LanguageModelCompletionEvent::QueueUpdate(event))]
}
Ok(CloudCompletionEvent::Event(event)) => map_callback(event),
@@ -934,6 +977,16 @@ where
.boxed()
}
+fn tool_use_limit_reached_event<T>(
+ tool_use_limit_reached: bool,
+) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
+ futures::stream::iter(tool_use_limit_reached.then(|| {
+ Ok(CloudCompletionEvent::System(
+ CompletionRequestStatus::ToolUseLimitReached,
+ ))
+ }))
+}
+
fn response_lines<T: DeserializeOwned>(
response: Response<AsyncBody>,
includes_queue_events: bool,
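
Finally, `tool_use_limit_reached_event` turns the header flag back into a stream event by exploiting the fact that `Option<T>` is an `IntoIterator` of zero or one items: `stream::iter(flag.then(|| ...))` is empty when the flag is unset and yields exactly one `System(ToolUseLimitReached)` otherwise, so the `.chain(...)` call sites above need no conditionals. A minimal demonstration of the same trick:

```rust
use futures::{executor::block_on, stream, Stream, StreamExt};

// bool::then produces Option<T>, and stream::iter accepts any IntoIterator,
// Option included: Some => a one-item stream, None => an empty one.
fn maybe_one(flag: bool) -> impl Stream<Item = &'static str> {
    stream::iter(flag.then(|| "tool_use_limit_reached"))
}

fn main() {
    block_on(async {
        assert_eq!(
            maybe_one(true).collect::<Vec<_>>().await,
            vec!["tool_use_limit_reached"]
        );
        assert!(maybe_one(false).collect::<Vec<_>>().await.is_empty());
    });
}
```
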