From 7e801dccb0e7d2296120dbf277f63ecfecd223d3 Mon Sep 17 00:00:00 2001 From: Michael Sloan Date: Fri, 20 Jun 2025 15:28:48 -0600 Subject: [PATCH] agent: Fix issues with usage display sometimes showing initially fetched usage (#33125) Having `Thread::last_usage` as an override of the initially fetched usage could cause the initial usage to be displayed when the current thread is empty or in text threads. Fix is to just store last usage info in `UserStore` and not have these overrides Release Notes: - Agent: Fixed request usage display to always include the most recently known usage - there were some cases where it would show the initially requested usage. --- Cargo.lock | 3 + crates/agent/src/agent_panel.rs | 26 +--- crates/agent/src/debug.rs | 10 +- crates/agent/src/message_editor.rs | 25 +-- crates/agent/src/thread.rs | 39 ++--- crates/agent/src/ui/preview/usage_callouts.rs | 27 ++-- crates/client/Cargo.toml | 2 + crates/client/src/user.rs | 142 +++++++++++++++--- crates/inline_completion/Cargo.toml | 3 +- .../src/inline_completion.rs | 40 +---- crates/language_model/src/language_model.rs | 35 +---- crates/language_models/src/provider/cloud.rs | 10 +- crates/zeta/src/zeta.rs | 39 ++--- 13 files changed, 189 insertions(+), 212 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index d3a623f7ef6ba32097e8d6f8e256d5eb4a830637..bfa9b2396e778ac3c1f8a7c229efdebc0b837e7a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2821,6 +2821,7 @@ dependencies = [ "cocoa 0.26.0", "collections", "credentials_provider", + "derive_more", "feature_flags", "fs", "futures 0.3.31", @@ -2859,6 +2860,7 @@ dependencies = [ "windows 0.61.1", "workspace-hack", "worktree", + "zed_llm_client", ] [[package]] @@ -8159,6 +8161,7 @@ name = "inline_completion" version = "0.1.0" dependencies = [ "anyhow", + "client", "gpui", "language", "project", diff --git a/crates/agent/src/agent_panel.rs b/crates/agent/src/agent_panel.rs index ee76045d2114688e4b0c131ce21921dbcd7b398d..10c2db37c40e7b24402523adc38f9e3135cf20ae 100644 --- a/crates/agent/src/agent_panel.rs +++ b/crates/agent/src/agent_panel.rs @@ -29,8 +29,7 @@ use gpui::{ }; use language::LanguageRegistry; use language_model::{ - ConfigurationError, LanguageModelProviderTosView, LanguageModelRegistry, RequestUsage, - ZED_CLOUD_PROVIDER_ID, + ConfigurationError, LanguageModelProviderTosView, LanguageModelRegistry, ZED_CLOUD_PROVIDER_ID, }; use project::{Project, ProjectPath, Worktree}; use prompt_store::{PromptBuilder, PromptStore, UserPromptId}; @@ -45,7 +44,7 @@ use ui::{ Banner, CheckboxWithLabel, ContextMenu, ElevationIndex, KeyBinding, PopoverMenu, PopoverMenuHandle, ProgressBar, Tab, Tooltip, Vector, VectorName, prelude::*, }; -use util::{ResultExt as _, maybe}; +use util::ResultExt as _; use workspace::dock::{DockPosition, Panel, PanelEvent}; use workspace::{ CollaboratorId, DraggedSelection, DraggedTab, ToggleZoom, ToolbarItemView, Workspace, @@ -1682,24 +1681,7 @@ impl AgentPanel { let thread_id = thread.id().clone(); let is_empty = active_thread.is_empty(); let editor_empty = self.message_editor.read(cx).is_editor_fully_empty(cx); - let last_usage = active_thread.thread().read(cx).last_usage().or_else(|| { - maybe!({ - let amount = user_store.model_request_usage_amount()?; - let limit = user_store.model_request_usage_limit()?.variant?; - - Some(RequestUsage { - amount: amount as i32, - limit: match limit { - proto::usage_limit::Variant::Limited(limited) => { - zed_llm_client::UsageLimit::Limited(limited.limit as i32) - } - proto::usage_limit::Variant::Unlimited(_) => { - zed_llm_client::UsageLimit::Unlimited - } - }, - }) - }) - }); + let usage = user_store.model_request_usage(); let account_url = zed_urls::account_url(cx); @@ -1820,7 +1802,7 @@ impl AgentPanel { .action("Add Custom Server…", Box::new(AddContextServer)) .separator(); - if let Some(usage) = last_usage { + if let Some(usage) = usage { menu = menu .header_with_link("Prompt Usage", "Manage", account_url.clone()) .custom_entry( diff --git a/crates/agent/src/debug.rs b/crates/agent/src/debug.rs index 7bd52e5a96561bc293477a18c86fabf84c81ca00..ff6538dc85a45f0072b805b033952da78255f8b7 100644 --- a/crates/agent/src/debug.rs +++ b/crates/agent/src/debug.rs @@ -1,7 +1,7 @@ #![allow(unused, dead_code)] +use client::{ModelRequestUsage, RequestUsage}; use gpui::Global; -use language_model::RequestUsage; use std::ops::{Deref, DerefMut}; use ui::prelude::*; use zed_llm_client::{Plan, UsageLimit}; @@ -17,7 +17,7 @@ pub struct DebugAccountState { pub enabled: bool, pub trial_expired: bool, pub plan: Plan, - pub custom_prompt_usage: RequestUsage, + pub custom_prompt_usage: ModelRequestUsage, pub usage_based_billing_enabled: bool, pub monthly_spending_cap: i32, pub custom_edit_prediction_usage: UsageLimit, @@ -43,7 +43,7 @@ impl DebugAccountState { self } - pub fn set_custom_prompt_usage(&mut self, custom_prompt_usage: RequestUsage) -> &mut Self { + pub fn set_custom_prompt_usage(&mut self, custom_prompt_usage: ModelRequestUsage) -> &mut Self { self.custom_prompt_usage = custom_prompt_usage; self } @@ -76,10 +76,10 @@ impl Default for DebugAccountState { enabled: false, trial_expired: false, plan: Plan::ZedFree, - custom_prompt_usage: RequestUsage { + custom_prompt_usage: ModelRequestUsage(RequestUsage { limit: UsageLimit::Unlimited, amount: 0, - }, + }), usage_based_billing_enabled: false, // $50.00 monthly_spending_cap: 5000, diff --git a/crates/agent/src/message_editor.rs b/crates/agent/src/message_editor.rs index c8d127aa28e8a97a9cedeeca082306c8bddc373d..ec0a01e8af01f29041a6f4a116fca4761bdc2a8e 100644 --- a/crates/agent/src/message_editor.rs +++ b/crates/agent/src/message_editor.rs @@ -29,8 +29,7 @@ use gpui::{ }; use language::{Buffer, Language, Point}; use language_model::{ - ConfiguredModel, LanguageModelRequestMessage, MessageContent, RequestUsage, - ZED_CLOUD_PROVIDER_ID, + ConfiguredModel, LanguageModelRequestMessage, MessageContent, ZED_CLOUD_PROVIDER_ID, }; use multi_buffer; use project::Project; @@ -42,7 +41,7 @@ use theme::ThemeSettings; use ui::{ Callout, Disclosure, Divider, DividerColor, KeyBinding, PopoverMenuHandle, Tooltip, prelude::*, }; -use util::{ResultExt as _, maybe}; +use util::ResultExt as _; use workspace::{CollaboratorId, Workspace}; use zed_llm_client::CompletionIntent; @@ -1257,24 +1256,8 @@ impl MessageEditor { Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial, }) .unwrap_or(zed_llm_client::Plan::ZedFree); - let usage = self.thread.read(cx).last_usage().or_else(|| { - maybe!({ - let amount = user_store.model_request_usage_amount()?; - let limit = user_store.model_request_usage_limit()?.variant?; - - Some(RequestUsage { - amount: amount as i32, - limit: match limit { - proto::usage_limit::Variant::Limited(limited) => { - zed_llm_client::UsageLimit::Limited(limited.limit as i32) - } - proto::usage_limit::Variant::Unlimited(_) => { - zed_llm_client::UsageLimit::Unlimited - } - }, - }) - }) - })?; + + let usage = user_store.model_request_usage()?; Some( div() diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index edc0ea1152883ef16b6b33f273216724912957b7..1a6b9604b597f26f1c3b76e02c9e4f39ef5ceef1 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -7,6 +7,7 @@ use agent_settings::{AgentProfileId, AgentSettings, CompletionMode}; use anyhow::{Result, anyhow}; use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet}; use chrono::{DateTime, Utc}; +use client::{ModelRequestUsage, RequestUsage}; use collections::HashMap; use editor::display_map::CreaseMetadata; use feature_flags::{self, FeatureFlagAppExt}; @@ -22,8 +23,8 @@ use language_model::{ LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolUseId, MessageContent, - ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, SelectedModel, - StopReason, TokenUsage, + ModelRequestLimitReachedError, PaymentRequiredError, Role, SelectedModel, StopReason, + TokenUsage, }; use postage::stream::Stream as _; use project::Project; @@ -38,7 +39,7 @@ use ui::Window; use util::{ResultExt as _, post_inc}; use uuid::Uuid; -use zed_llm_client::{CompletionIntent, CompletionRequestStatus}; +use zed_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit}; use crate::ThreadStore; use crate::agent_profile::AgentProfile; @@ -350,7 +351,6 @@ pub struct Thread { request_token_usage: Vec, cumulative_token_usage: TokenUsage, exceeded_window_error: Option, - last_usage: Option, tool_use_limit_reached: bool, feedback: Option, message_feedback: HashMap, @@ -443,7 +443,6 @@ impl Thread { request_token_usage: Vec::new(), cumulative_token_usage: TokenUsage::default(), exceeded_window_error: None, - last_usage: None, tool_use_limit_reached: false, feedback: None, message_feedback: HashMap::default(), @@ -568,7 +567,6 @@ impl Thread { request_token_usage: serialized.request_token_usage, cumulative_token_usage: serialized.cumulative_token_usage, exceeded_window_error: None, - last_usage: None, tool_use_limit_reached: serialized.tool_use_limit_reached, feedback: None, message_feedback: HashMap::default(), @@ -875,10 +873,6 @@ impl Thread { .unwrap_or(false) } - pub fn last_usage(&self) -> Option { - self.last_usage - } - pub fn tool_use_limit_reached(&self) -> bool { self.tool_use_limit_reached } @@ -1658,9 +1652,7 @@ impl Thread { CompletionRequestStatus::UsageUpdated { amount, limit } => { - let usage = RequestUsage { limit, amount: amount as i32 }; - - thread.last_usage = Some(usage); + thread.update_model_request_usage(amount as u32, limit, cx); } CompletionRequestStatus::ToolUseLimitReached => { thread.tool_use_limit_reached = true; @@ -1871,11 +1863,8 @@ impl Thread { LanguageModelCompletionEvent::StatusUpdate( CompletionRequestStatus::UsageUpdated { amount, limit }, ) => { - this.update(cx, |thread, _cx| { - thread.last_usage = Some(RequestUsage { - limit, - amount: amount as i32, - }); + this.update(cx, |thread, cx| { + thread.update_model_request_usage(amount as u32, limit, cx); })?; continue; } @@ -2757,6 +2746,20 @@ impl Thread { } } + fn update_model_request_usage(&self, amount: u32, limit: UsageLimit, cx: &mut Context) { + self.project.update(cx, |project, cx| { + project.user_store().update(cx, |user_store, cx| { + user_store.update_model_request_usage( + ModelRequestUsage(RequestUsage { + amount: amount as i32, + limit, + }), + cx, + ) + }) + }); + } + pub fn deny_tool_use( &mut self, tool_use_id: LanguageModelToolUseId, diff --git a/crates/agent/src/ui/preview/usage_callouts.rs b/crates/agent/src/ui/preview/usage_callouts.rs index 62e29094612186280c2be4696feedfe03df35589..45af41395b52afc8655c7cdd748a3228868b2d0f 100644 --- a/crates/agent/src/ui/preview/usage_callouts.rs +++ b/crates/agent/src/ui/preview/usage_callouts.rs @@ -1,18 +1,17 @@ -use client::zed_urls; +use client::{ModelRequestUsage, RequestUsage, zed_urls}; use component::{empty_example, example_group_with_title, single_example}; use gpui::{AnyElement, App, IntoElement, RenderOnce, Window}; -use language_model::RequestUsage; use ui::{Callout, prelude::*}; use zed_llm_client::{Plan, UsageLimit}; #[derive(IntoElement, RegisterComponent)] pub struct UsageCallout { plan: Plan, - usage: RequestUsage, + usage: ModelRequestUsage, } impl UsageCallout { - pub fn new(plan: Plan, usage: RequestUsage) -> Self { + pub fn new(plan: Plan, usage: ModelRequestUsage) -> Self { Self { plan, usage } } } @@ -128,10 +127,10 @@ impl Component for UsageCallout { "Approaching limit (90%)", UsageCallout::new( Plan::ZedFree, - RequestUsage { + ModelRequestUsage(RequestUsage { limit: UsageLimit::Limited(50), amount: 45, // 90% of limit - }, + }), ) .into_any_element(), ), @@ -139,10 +138,10 @@ impl Component for UsageCallout { "Limit reached (100%)", UsageCallout::new( Plan::ZedFree, - RequestUsage { + ModelRequestUsage(RequestUsage { limit: UsageLimit::Limited(50), amount: 50, // 100% of limit - }, + }), ) .into_any_element(), ), @@ -156,10 +155,10 @@ impl Component for UsageCallout { "Approaching limit (90%)", UsageCallout::new( Plan::ZedProTrial, - RequestUsage { + ModelRequestUsage(RequestUsage { limit: UsageLimit::Limited(150), amount: 135, // 90% of limit - }, + }), ) .into_any_element(), ), @@ -167,10 +166,10 @@ impl Component for UsageCallout { "Limit reached (100%)", UsageCallout::new( Plan::ZedProTrial, - RequestUsage { + ModelRequestUsage(RequestUsage { limit: UsageLimit::Limited(150), amount: 150, // 100% of limit - }, + }), ) .into_any_element(), ), @@ -184,10 +183,10 @@ impl Component for UsageCallout { "Limit reached (100%)", UsageCallout::new( Plan::ZedPro, - RequestUsage { + ModelRequestUsage(RequestUsage { limit: UsageLimit::Limited(500), amount: 500, // 100% of limit - }, + }), ) .into_any_element(), ), diff --git a/crates/client/Cargo.toml b/crates/client/Cargo.toml index 0d65b7ef215e79f529db5727fa3308e4773b062f..b741f515fd1681a721048d831d01db8ec0f889e6 100644 --- a/crates/client/Cargo.toml +++ b/crates/client/Cargo.toml @@ -24,6 +24,7 @@ chrono = { workspace = true, features = ["serde"] } clock.workspace = true collections.workspace = true credentials_provider.workspace = true +derive_more.workspace = true feature_flags.workspace = true futures.workspace = true gpui.workspace = true @@ -57,6 +58,7 @@ worktree.workspace = true telemetry.workspace = true tokio.workspace = true workspace-hack.workspace = true +zed_llm_client.workspace = true [dev-dependencies] clock = { workspace = true, features = ["test-support"] } diff --git a/crates/client/src/user.rs b/crates/client/src/user.rs index 1be8d71e85cba51d95e750ae8b871bdcedefe279..61e3064eb496b59910ce8ab25797b9b4b4848201 100644 --- a/crates/client/src/user.rs +++ b/crates/client/src/user.rs @@ -2,16 +2,25 @@ use super::{Client, Status, TypedEnvelope, proto}; use anyhow::{Context as _, Result, anyhow}; use chrono::{DateTime, Utc}; use collections::{HashMap, HashSet, hash_map::Entry}; +use derive_more::Deref; use feature_flags::FeatureFlagAppExt; use futures::{Future, StreamExt, channel::mpsc}; use gpui::{ App, AsyncApp, Context, Entity, EventEmitter, SharedString, SharedUri, Task, WeakEntity, }; +use http_client::http::{HeaderMap, HeaderValue}; use postage::{sink::Sink, watch}; use rpc::proto::{RequestMessage, UsersResponse}; -use std::sync::{Arc, Weak}; +use std::{ + str::FromStr as _, + sync::{Arc, Weak}, +}; use text::ReplicaId; use util::{TryFutureExt as _, maybe}; +use zed_llm_client::{ + EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME, EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME, + MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME, MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME, UsageLimit, +}; pub type UserId = u64; @@ -104,10 +113,8 @@ pub struct UserStore { current_plan: Option, subscription_period: Option<(DateTime, DateTime)>, trial_started_at: Option>, - model_request_usage_amount: Option, - model_request_usage_limit: Option, - edit_predictions_usage_amount: Option, - edit_predictions_usage_limit: Option, + model_request_usage: Option, + edit_prediction_usage: Option, is_usage_based_billing_enabled: Option, account_too_young: Option, has_overdue_invoices: Option, @@ -155,6 +162,18 @@ enum UpdateContacts { Clear(postage::barrier::Sender), } +#[derive(Debug, Clone, Copy, Deref)] +pub struct ModelRequestUsage(pub RequestUsage); + +#[derive(Debug, Clone, Copy, Deref)] +pub struct EditPredictionUsage(pub RequestUsage); + +#[derive(Debug, Clone, Copy)] +pub struct RequestUsage { + pub limit: UsageLimit, + pub amount: i32, +} + impl UserStore { pub fn new(client: Arc, cx: &Context) -> Self { let (mut current_user_tx, current_user_rx) = watch::channel(); @@ -172,10 +191,8 @@ impl UserStore { current_plan: None, subscription_period: None, trial_started_at: None, - model_request_usage_amount: None, - model_request_usage_limit: None, - edit_predictions_usage_amount: None, - edit_predictions_usage_limit: None, + model_request_usage: None, + edit_prediction_usage: None, is_usage_based_billing_enabled: None, account_too_young: None, has_overdue_invoices: None, @@ -356,10 +373,19 @@ impl UserStore { this.has_overdue_invoices = message.payload.has_overdue_invoices; if let Some(usage) = message.payload.usage { - this.model_request_usage_amount = Some(usage.model_requests_usage_amount); - this.model_request_usage_limit = usage.model_requests_usage_limit; - this.edit_predictions_usage_amount = Some(usage.edit_predictions_usage_amount); - this.edit_predictions_usage_limit = usage.edit_predictions_usage_limit; + // limits are always present even though they are wrapped in Option + this.model_request_usage = usage + .model_requests_usage_limit + .and_then(|limit| { + RequestUsage::from_proto(usage.model_requests_usage_amount, limit) + }) + .map(ModelRequestUsage); + this.edit_prediction_usage = usage + .edit_predictions_usage_limit + .and_then(|limit| { + RequestUsage::from_proto(usage.model_requests_usage_amount, limit) + }) + .map(EditPredictionUsage); } cx.notify(); @@ -367,6 +393,20 @@ impl UserStore { Ok(()) } + pub fn update_model_request_usage(&mut self, usage: ModelRequestUsage, cx: &mut Context) { + self.model_request_usage = Some(usage); + cx.notify(); + } + + pub fn update_edit_prediction_usage( + &mut self, + usage: EditPredictionUsage, + cx: &mut Context, + ) { + self.edit_prediction_usage = Some(usage); + cx.notify(); + } + fn update_contacts(&mut self, message: UpdateContacts, cx: &Context) -> Task> { match message { UpdateContacts::Wait(barrier) => { @@ -739,20 +779,12 @@ impl UserStore { self.is_usage_based_billing_enabled } - pub fn model_request_usage_amount(&self) -> Option { - self.model_request_usage_amount - } - - pub fn model_request_usage_limit(&self) -> Option { - self.model_request_usage_limit.clone() - } - - pub fn edit_predictions_usage_amount(&self) -> Option { - self.edit_predictions_usage_amount + pub fn model_request_usage(&self) -> Option { + self.model_request_usage } - pub fn edit_predictions_usage_limit(&self) -> Option { - self.edit_predictions_usage_limit.clone() + pub fn edit_prediction_usage(&self) -> Option { + self.edit_prediction_usage } pub fn watch_current_user(&self) -> watch::Receiver>> { @@ -917,3 +949,63 @@ impl Collaborator { }) } } + +impl RequestUsage { + pub fn over_limit(&self) -> bool { + match self.limit { + UsageLimit::Limited(limit) => self.amount >= limit, + UsageLimit::Unlimited => false, + } + } + + pub fn from_proto(amount: u32, limit: proto::UsageLimit) -> Option { + let limit = match limit.variant? { + proto::usage_limit::Variant::Limited(limited) => { + UsageLimit::Limited(limited.limit as i32) + } + proto::usage_limit::Variant::Unlimited(_) => UsageLimit::Unlimited, + }; + Some(RequestUsage { + limit, + amount: amount as i32, + }) + } + + fn from_headers( + limit_name: &str, + amount_name: &str, + headers: &HeaderMap, + ) -> Result { + let limit = headers + .get(limit_name) + .with_context(|| format!("missing {limit_name:?} header"))?; + let limit = UsageLimit::from_str(limit.to_str()?)?; + + let amount = headers + .get(amount_name) + .with_context(|| format!("missing {amount_name:?} header"))?; + let amount = amount.to_str()?.parse::()?; + + Ok(Self { limit, amount }) + } +} + +impl ModelRequestUsage { + pub fn from_headers(headers: &HeaderMap) -> Result { + Ok(Self(RequestUsage::from_headers( + MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME, + MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME, + headers, + )?)) + } +} + +impl EditPredictionUsage { + pub fn from_headers(headers: &HeaderMap) -> Result { + Ok(Self(RequestUsage::from_headers( + EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME, + EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME, + headers, + )?)) + } +} diff --git a/crates/inline_completion/Cargo.toml b/crates/inline_completion/Cargo.toml index 0094385e16d150d002679efecafafd934d739760..3a90875def1a8ce491765c24c18f432807292dc9 100644 --- a/crates/inline_completion/Cargo.toml +++ b/crates/inline_completion/Cargo.toml @@ -12,9 +12,8 @@ workspace = true path = "src/inline_completion.rs" [dependencies] -anyhow.workspace = true +client.workspace = true gpui.workspace = true language.workspace = true project.workspace = true workspace-hack.workspace = true -zed_llm_client.workspace = true diff --git a/crates/inline_completion/src/inline_completion.rs b/crates/inline_completion/src/inline_completion.rs index 7acfea72b2610a514b8fd274f1ea03bc46d581ce..c8f35bf16a116294edb5d1d2f5359733828e6995 100644 --- a/crates/inline_completion/src/inline_completion.rs +++ b/crates/inline_completion/src/inline_completion.rs @@ -1,14 +1,9 @@ use std::ops::Range; -use std::str::FromStr as _; -use anyhow::{Context as _, Result}; -use gpui::http_client::http::{HeaderMap, HeaderValue}; +use client::EditPredictionUsage; use gpui::{App, Context, Entity, SharedString}; use language::Buffer; use project::Project; -use zed_llm_client::{ - EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME, EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME, UsageLimit, -}; // TODO: Find a better home for `Direction`. // @@ -59,39 +54,6 @@ impl DataCollectionState { } } -#[derive(Debug, Clone, Copy)] -pub struct EditPredictionUsage { - pub limit: UsageLimit, - pub amount: i32, -} - -impl EditPredictionUsage { - pub fn from_headers(headers: &HeaderMap) -> Result { - let limit = headers - .get(EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME) - .with_context(|| { - format!("missing {EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME:?} header") - })?; - let limit = UsageLimit::from_str(limit.to_str()?)?; - - let amount = headers - .get(EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME) - .with_context(|| { - format!("missing {EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME:?} header") - })?; - let amount = amount.to_str()?.parse::()?; - - Ok(Self { limit, amount }) - } - - pub fn over_limit(&self) -> bool { - match self.limit { - UsageLimit::Limited(limit) => self.amount >= limit, - UsageLimit::Unlimited => false, - } - } -} - pub trait EditPredictionProvider: 'static + Sized { fn name() -> &'static str; fn display_name() -> &'static str; diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs index c411593213b7921e2a01b839b478b56978d953d5..900d7f6f39e9c9fa4a17dff6db0c31c70c53c0b4 100644 --- a/crates/language_model/src/language_model.rs +++ b/crates/language_model/src/language_model.rs @@ -8,27 +8,22 @@ mod telemetry; #[cfg(any(test, feature = "test-support"))] pub mod fake_provider; -use anyhow::{Context as _, Result}; +use anyhow::Result; use client::Client; use futures::FutureExt; use futures::{StreamExt, future::BoxFuture, stream::BoxStream}; use gpui::{AnyElement, AnyView, App, AsyncApp, SharedString, Task, Window}; -use http_client::http::{HeaderMap, HeaderValue}; use icons::IconName; use parking_lot::Mutex; use schemars::JsonSchema; use serde::{Deserialize, Serialize, de::DeserializeOwned}; use std::fmt; use std::ops::{Add, Sub}; -use std::str::FromStr as _; use std::sync::Arc; use std::time::Duration; use thiserror::Error; use util::serde::is_default; -use zed_llm_client::{ - CompletionRequestStatus, MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME, - MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME, UsageLimit, -}; +use zed_llm_client::CompletionRequestStatus; pub use crate::model::*; pub use crate::rate_limiter::*; @@ -106,32 +101,6 @@ pub enum StopReason { Refusal, } -#[derive(Debug, Clone, Copy)] -pub struct RequestUsage { - pub limit: UsageLimit, - pub amount: i32, -} - -impl RequestUsage { - pub fn from_headers(headers: &HeaderMap) -> Result { - let limit = headers - .get(MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME) - .with_context(|| { - format!("missing {MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME:?} header") - })?; - let limit = UsageLimit::from_str(limit.to_str()?)?; - - let amount = headers - .get(MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME) - .with_context(|| { - format!("missing {MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME:?} header") - })?; - let amount = amount.to_str()?.parse::()?; - - Ok(Self { limit, amount }) - } -} - #[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)] pub struct TokenUsage { #[serde(default, skip_serializing_if = "is_default")] diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs index 59a5537ae9bf3dea26786cb608b92975baec982c..1062d732a42d0d7fdd15e99d15a50b72826ed03c 100644 --- a/crates/language_models/src/provider/cloud.rs +++ b/crates/language_models/src/provider/cloud.rs @@ -1,6 +1,6 @@ use anthropic::{AnthropicModelMode, parse_prompt_too_long}; use anyhow::{Context as _, Result, anyhow}; -use client::{Client, UserStore, zed_urls}; +use client::{Client, ModelRequestUsage, UserStore, zed_urls}; use futures::{ AsyncBufReadExt, FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream, }; @@ -14,7 +14,7 @@ use language_model::{ LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, LanguageModelProviderTosView, LanguageModelRequest, LanguageModelToolChoice, - LanguageModelToolSchemaFormat, ModelRequestLimitReachedError, RateLimiter, RequestUsage, + LanguageModelToolSchemaFormat, ModelRequestLimitReachedError, RateLimiter, ZED_CLOUD_PROVIDER_ID, }; use language_model::{ @@ -530,7 +530,7 @@ pub struct CloudLanguageModel { struct PerformLlmCompletionResponse { response: Response, - usage: Option, + usage: Option, tool_use_limit_reached: bool, includes_status_messages: bool, } @@ -581,7 +581,7 @@ impl CloudLanguageModel { let usage = if includes_status_messages { None } else { - RequestUsage::from_headers(response.headers()).ok() + ModelRequestUsage::from_headers(response.headers()).ok() }; return Ok(PerformLlmCompletionResponse { @@ -1002,7 +1002,7 @@ where } fn usage_updated_event( - usage: Option, + usage: Option, ) -> impl Stream>> { futures::stream::iter(usage.map(|usage| { Ok(CloudCompletionEvent::Status( diff --git a/crates/zeta/src/zeta.rs b/crates/zeta/src/zeta.rs index 23ce320ee9a77f42856cba5ca32723cb7ea18eaa..4d643c9db08da49144d26f28a416524fd5a3ceab 100644 --- a/crates/zeta/src/zeta.rs +++ b/crates/zeta/src/zeta.rs @@ -9,14 +9,14 @@ mod rate_completion_modal; pub(crate) use completion_diff_element::*; use db::kvp::KEY_VALUE_STORE; pub use init::*; -use inline_completion::{DataCollectionState, EditPredictionUsage}; +use inline_completion::DataCollectionState; use license_detection::LICENSE_FILES_TO_CHECK; pub use license_detection::is_license_eligible_for_data_collection; pub use rate_completion_modal::*; use anyhow::{Context as _, Result, anyhow}; use arrayvec::ArrayVec; -use client::{Client, UserStore}; +use client::{Client, EditPredictionUsage, UserStore}; use collections::{HashMap, HashSet, VecDeque}; use futures::AsyncReadExt; use gpui::{ @@ -48,7 +48,7 @@ use std::{ }; use telemetry_events::InlineCompletionRating; use thiserror::Error; -use util::{ResultExt, maybe}; +use util::ResultExt; use uuid::Uuid; use workspace::Workspace; use workspace::notifications::{ErrorMessagePrompt, NotificationId}; @@ -188,7 +188,6 @@ pub struct Zeta { data_collection_choice: Entity, llm_token: LlmApiToken, _llm_token_subscription: Subscription, - last_usage: Option, /// Whether the terms of service have been accepted. tos_accepted: bool, /// Whether an update to a newer version of Zed is required to continue using Zeta. @@ -234,25 +233,7 @@ impl Zeta { } pub fn usage(&self, cx: &App) -> Option { - self.last_usage.or_else(|| { - let user_store = self.user_store.read(cx); - maybe!({ - let amount = user_store.edit_predictions_usage_amount()?; - let limit = user_store.edit_predictions_usage_limit()?.variant?; - - Some(EditPredictionUsage { - amount: amount as i32, - limit: match limit { - proto::usage_limit::Variant::Limited(limited) => { - zed_llm_client::UsageLimit::Limited(limited.limit as i32) - } - proto::usage_limit::Variant::Unlimited(_) => { - zed_llm_client::UsageLimit::Unlimited - } - }, - }) - }) - }) + self.user_store.read(cx).edit_prediction_usage() } fn new( @@ -287,7 +268,6 @@ impl Zeta { .detach_and_log_err(cx); }, ), - last_usage: None, tos_accepted: user_store .read(cx) .current_user_has_accepted_terms() @@ -533,8 +513,10 @@ impl Zeta { log::debug!("completion response: {}", &response.output_excerpt); if let Some(usage) = usage { - this.update(cx, |this, _cx| { - this.last_usage = Some(usage); + this.update(cx, |this, cx| { + this.user_store.update(cx, |user_store, cx| { + user_store.update_edit_prediction_usage(usage, cx); + }); }) .ok(); } @@ -874,8 +856,9 @@ and then another if response.status().is_success() { if let Some(usage) = EditPredictionUsage::from_headers(response.headers()).ok() { this.update(cx, |this, cx| { - this.last_usage = Some(usage); - cx.notify(); + this.user_store.update(cx, |user_store, cx| { + user_store.update_edit_prediction_usage(usage, cx); + }); })?; }