From 2eafa6e6aa471a9930a5b8c0209afa66030fb41a Mon Sep 17 00:00:00 2001
From: Ben Brandt
Date: Wed, 22 Apr 2026 15:39:48 +0200
Subject: [PATCH] language_models: Remove unused language model token counting (#54177)

Drop the `count_tokens` API and related implementations across providers,
and remove the unused `tiktoken-rs` dependency.

I was going to update the dependency because upstream finally released a
fix we needed. But then I realized we only used this API in one place, the
Rules library. And for most models the counts would have been wildly wrong
anyway, because we use tiktoken, i.e. OpenAI tokenizers, for almost every
model, regardless of the actual vendor. (A short sketch of the removed
estimation path follows the diff.)

Given that, I just removed the API entirely: the inconsistency in how these
counts were produced has caused plenty of confusion in the past.

Self-Review Checklist:

- [x] I've reviewed my own diff for quality, security, and reliability
- [x] Unsafe blocks (if any) have justifying comments
- [x] The content is consistent with the [UI/UX checklist](https://github.com/zed-industries/zed/blob/main/CONTRIBUTING.md#uiux-checklist)
- [x] Tests cover the new/changed behavior
- [x] Performance impact has been considered and is acceptable

Release Notes:

- N/A
---
 Cargo.lock | 33 +---
 Cargo.toml | 1 -
 .../agent_ui/src/language_model_selector.rs | 10 +-
 crates/anthropic/Cargo.toml | 3 -
 crates/anthropic/src/anthropic.rs | 65 --------
 crates/anthropic/src/completion.rs | 152 +-----------------
 .../cloud_llm_client/src/cloud_llm_client.rs | 12 --
 crates/google_ai/Cargo.toml | 1 -
 crates/google_ai/src/completion.rs | 23 ---
 crates/google_ai/src/google_ai.rs | 46 ------
 crates/language_model/src/fake_provider.rs | 4 -
 crates/language_model/src/language_model.rs | 6 -
 crates/language_models/Cargo.toml | 1 -
 .../language_models/src/provider/anthropic.rs | 51 +-----
 .../language_models/src/provider/bedrock.rs | 70 --------
 .../src/provider/copilot_chat.rs | 40 -----
 .../language_models/src/provider/deepseek.rs | 26 ---
 crates/language_models/src/provider/google.rs | 34 +---
 .../language_models/src/provider/lmstudio.rs | 16 --
 .../language_models/src/provider/mistral.rs | 26 ---
 crates/language_models/src/provider/ollama.rs | 17 --
 .../language_models/src/provider/open_ai.rs | 13 +-
 .../src/provider/open_ai_compatible.rs | 21 ---
 .../src/provider/open_router.rs | 34 ----
 .../language_models/src/provider/opencode.rs | 28 +---
 crates/language_models/src/provider/vercel.rs | 57 +------
 .../src/provider/vercel_ai_gateway.rs | 18 ---
 crates/language_models/src/provider/x_ai.rs | 11 --
 crates/language_models_cloud/Cargo.toml | 1 -
 .../src/language_models_cloud.rs | 94 +----------
 crates/open_ai/Cargo.toml | 1 -
 crates/open_ai/src/completion.rs | 92 +---------
 crates/rules_library/src/rules_library.rs | 123 ++------------
 crates/x_ai/Cargo.toml | 2 -
 crates/x_ai/src/completion.rs | 30 ----
 crates/x_ai/src/x_ai.rs | 2 -
 36 files changed, 34 insertions(+), 1130 deletions(-)
 delete mode 100644 crates/x_ai/src/completion.rs

diff --git a/Cargo.lock b/Cargo.lock
index d76f07aeef8e824ea78ca235d48d39ef7fe252e7..dd70dba47ea243095cc6f075c5a711e17e586d85 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -661,7 +661,6 @@ dependencies = [
  "serde_json",
  "strum 0.27.2",
  "thiserror 2.0.17",
- "tiktoken-rs",
 ]

 [[package]]
 checksum = "34cd60c5e3152cef0a592f1b296f1cc93715d89d2551d85315828c3a09575ff4"

 [[package]]
 name = "anyhow"
-version = "1.0.100"
+version = "1.0.102"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = 
"a23eb6b1614318a8071c9b2521f36b424b2c83db5eb3a0fead4a6c0809af6e61" +checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c" [[package]] name = "approx" @@ -2209,9 +2208,9 @@ dependencies = [ [[package]] name = "bstr" -version = "1.12.0" +version = "1.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "234113d19d0d7d613b40e86fb654acf958910802bcceab913a4f9e7cda03b1a4" +checksum = "63044e1ae8e69f3b5a92c736ca6269b8d12fa7efe39bf34ddb06d102cf0e2cab" dependencies = [ "memchr", "regex-automata", @@ -7533,7 +7532,6 @@ dependencies = [ "serde", "serde_json", "strum 0.27.2", - "tiktoken-rs", ] [[package]] @@ -9574,7 +9572,6 @@ dependencies = [ "settings", "smol", "strum 0.27.2", - "tiktoken-rs", "tokio", "ui", "ui_input", @@ -9602,7 +9599,6 @@ dependencies = [ "serde_json", "smol", "thiserror 2.0.17", - "x_ai", ] [[package]] @@ -11766,7 +11762,6 @@ dependencies = [ "serde_json", "strum 0.27.2", "thiserror 2.0.17", - "tiktoken-rs", ] [[package]] @@ -14414,9 +14409,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.12.2" +version = "1.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "843bc0191f75f3e22651ae5f1e72939ab2f72a4bc30fa80a066bd66edefc24d4" +checksum = "e10754a14b9137dd7b1e3e5b0493cc9171fdd105e0ab477f51b72e7f3ac0e276" dependencies = [ "aho-corasick", "memchr", @@ -17855,20 +17850,6 @@ dependencies = [ "zune-jpeg 0.5.15", ] -[[package]] -name = "tiktoken-rs" -version = "0.9.1" -source = "git+https://github.com/zed-industries/tiktoken-rs?rev=2570c4387a8505fb8f1d3f3557454b474f1e8271#2570c4387a8505fb8f1d3f3557454b474f1e8271" -dependencies = [ - "anyhow", - "base64 0.22.1", - "bstr", - "fancy-regex 0.16.2", - "lazy_static", - "regex", - "rustc-hash 1.1.0", -] - [[package]] name = "time" version = "0.3.47" @@ -21920,11 +21901,9 @@ name = "x_ai" version = "0.1.0" dependencies = [ "anyhow", - "language_model_core", "schemars", "serde", "strum 0.27.2", - "tiktoken-rs", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 1d96615ebfe92b8a1728bfabca76f694a7e54b0e..448a4dd25c3b67f6f012194563cba996d3dc06bd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -733,7 +733,6 @@ sysinfo = "0.37.0" take-until = "0.2.0" tempfile = "3.20.0" thiserror = "2.0.12" -tiktoken-rs = { git = "https://github.com/zed-industries/tiktoken-rs", rev = "2570c4387a8505fb8f1d3f3557454b474f1e8271" } time = { version = "0.3", features = [ "macros", "parsing", diff --git a/crates/agent_ui/src/language_model_selector.rs b/crates/agent_ui/src/language_model_selector.rs index 7de58fd54ffd0d984b3a6079681f15f6a56507ae..4f870b7cfbd4800bb8f9c50d9f548af4a191a701 100644 --- a/crates/agent_ui/src/language_model_selector.rs +++ b/crates/agent_ui/src/language_model_selector.rs @@ -566,7 +566,7 @@ impl PickerDelegate for LanguageModelPickerDelegate { mod tests { use super::*; use futures::{future::BoxFuture, stream::BoxStream}; - use gpui::{AsyncApp, TestAppContext, http_client}; + use gpui::{AsyncApp, TestAppContext}; use language_model::{ LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProviderId, LanguageModelProviderName, @@ -630,14 +630,6 @@ mod tests { 1000 } - fn count_tokens( - &self, - _: LanguageModelRequest, - _: &App, - ) -> BoxFuture<'static, http_client::Result> { - unimplemented!() - } - fn stream_completion( &self, _: LanguageModelRequest, diff --git a/crates/anthropic/Cargo.toml b/crates/anthropic/Cargo.toml index 
458f9bfae7da4736c4e54e42f08b5e3a926ed30a..3001b5801c067ec4053f34247985114cd8e8087c 100644 --- a/crates/anthropic/Cargo.toml +++ b/crates/anthropic/Cargo.toml @@ -28,6 +28,3 @@ serde.workspace = true serde_json.workspace = true strum.workspace = true thiserror.workspace = true -tiktoken-rs.workspace = true - - diff --git a/crates/anthropic/src/anthropic.rs b/crates/anthropic/src/anthropic.rs index 488802d7db3fe169708c33710f7e583c46d41ae9..ac94daba71194ef925b87e7d73592aa9b05d97c2 100644 --- a/crates/anthropic/src/anthropic.rs +++ b/crates/anthropic/src/anthropic.rs @@ -1000,71 +1000,6 @@ pub fn parse_prompt_too_long(message: &str) -> Option { .ok() } -/// Request body for the token counting API. -/// Similar to `Request` but without `max_tokens` since it's not needed for counting. -#[derive(Debug, Serialize)] -pub struct CountTokensRequest { - pub model: String, - pub messages: Vec, - #[serde(default, skip_serializing_if = "Option::is_none")] - pub system: Option, - #[serde(default, skip_serializing_if = "Vec::is_empty")] - pub tools: Vec, - #[serde(default, skip_serializing_if = "Option::is_none")] - pub thinking: Option, - #[serde(default, skip_serializing_if = "Option::is_none")] - pub tool_choice: Option, -} - -/// Response from the token counting API. -#[derive(Debug, Deserialize)] -pub struct CountTokensResponse { - pub input_tokens: u64, -} - -/// Count the number of tokens in a message without creating it. -pub async fn count_tokens( - client: &dyn HttpClient, - api_url: &str, - api_key: &str, - request: CountTokensRequest, -) -> Result { - let uri = format!("{api_url}/v1/messages/count_tokens"); - - let request_builder = HttpRequest::builder() - .method(Method::POST) - .uri(uri) - .header("Anthropic-Version", "2023-06-01") - .header("X-Api-Key", api_key.trim()) - .header("Content-Type", "application/json"); - - let serialized_request = - serde_json::to_string(&request).map_err(AnthropicError::SerializeRequest)?; - let http_request = request_builder - .body(AsyncBody::from(serialized_request)) - .map_err(AnthropicError::BuildRequestBody)?; - - let mut response = client - .send(http_request) - .await - .map_err(AnthropicError::HttpSend)?; - - let rate_limits = RateLimitInfo::from_headers(response.headers()); - - if response.status().is_success() { - let mut body = String::new(); - response - .body_mut() - .read_to_string(&mut body) - .await - .map_err(AnthropicError::ReadResponse)?; - - serde_json::from_str(&body).map_err(AnthropicError::DeserializeResponse) - } else { - Err(handle_error_response(response, rate_limits).await) - } -} - // -- Conversions from/to `language_model_core` types -- impl From for Speed { diff --git a/crates/anthropic/src/completion.rs b/crates/anthropic/src/completion.rs index a6175a4f7c24b3b724734b2edef48ef8acfaa159..16bb5012e583f44d34031d4bfe58e7680a90f289 100644 --- a/crates/anthropic/src/completion.rs +++ b/crates/anthropic/src/completion.rs @@ -11,9 +11,9 @@ use std::pin::Pin; use std::str::FromStr; use crate::{ - AnthropicError, AnthropicModelMode, CacheControl, CacheControlType, ContentDelta, - CountTokensRequest, Event, ImageSource, Message, RequestContent, ResponseContent, - StringOrContents, Thinking, Tool, ToolChoice, ToolResultContent, ToolResultPart, Usage, + AnthropicError, AnthropicModelMode, CacheControl, CacheControlType, ContentDelta, Event, + ImageSource, Message, RequestContent, ResponseContent, StringOrContents, Thinking, Tool, + ToolChoice, ToolResultContent, ToolResultPart, Usage, }; fn to_anthropic_content(content: MessageContent) -> 
Option { @@ -92,152 +92,6 @@ fn to_anthropic_content(content: MessageContent) -> Option { } } -/// Convert a LanguageModelRequest to an Anthropic CountTokensRequest. -pub fn into_anthropic_count_tokens_request( - request: LanguageModelRequest, - model: String, - mode: AnthropicModelMode, -) -> CountTokensRequest { - let mut new_messages: Vec = Vec::new(); - let mut system_message = String::new(); - - for message in request.messages { - if message.contents_empty() { - continue; - } - - match message.role { - Role::User | Role::Assistant => { - let anthropic_message_content: Vec = message - .content - .into_iter() - .filter_map(to_anthropic_content) - .collect(); - let anthropic_role = match message.role { - Role::User => crate::Role::User, - Role::Assistant => crate::Role::Assistant, - Role::System => unreachable!("System role should never occur here"), - }; - if anthropic_message_content.is_empty() { - continue; - } - - if let Some(last_message) = new_messages.last_mut() - && last_message.role == anthropic_role - { - last_message.content.extend(anthropic_message_content); - continue; - } - - new_messages.push(Message { - role: anthropic_role, - content: anthropic_message_content, - }); - } - Role::System => { - if !system_message.is_empty() { - system_message.push_str("\n\n"); - } - system_message.push_str(&message.string_contents()); - } - } - } - - CountTokensRequest { - model, - messages: new_messages, - system: if system_message.is_empty() { - None - } else { - Some(StringOrContents::String(system_message)) - }, - thinking: if request.thinking_allowed { - match mode { - AnthropicModelMode::Thinking { budget_tokens } => { - Some(Thinking::Enabled { budget_tokens }) - } - AnthropicModelMode::AdaptiveThinking => Some(Thinking::Adaptive), - AnthropicModelMode::Default => None, - } - } else { - None - }, - tools: request - .tools - .into_iter() - .map(|tool| Tool { - name: tool.name, - description: tool.description, - input_schema: tool.input_schema, - eager_input_streaming: tool.use_input_streaming, - }) - .collect(), - tool_choice: request.tool_choice.map(|choice| match choice { - LanguageModelToolChoice::Auto => ToolChoice::Auto, - LanguageModelToolChoice::Any => ToolChoice::Any, - LanguageModelToolChoice::None => ToolChoice::None, - }), - } -} - -/// Estimate tokens using tiktoken. Used as a fallback when the API is unavailable, -/// or by providers (like Zed Cloud) that don't have direct Anthropic API access. -pub fn count_anthropic_tokens_with_tiktoken(request: LanguageModelRequest) -> Result { - let messages = request.messages; - let mut tokens_from_images = 0; - let mut string_messages = Vec::with_capacity(messages.len()); - - for message in messages { - let mut string_contents = String::new(); - - for content in message.content { - match content { - MessageContent::Text(text) => { - string_contents.push_str(&text); - } - MessageContent::Thinking { .. } => { - // Thinking blocks are not included in the input token count. - } - MessageContent::RedactedThinking(_) => { - // Thinking blocks are not included in the input token count. - } - MessageContent::Image(image) => { - tokens_from_images += image.estimate_tokens(); - } - MessageContent::ToolUse(_tool_use) => { - // TODO: Estimate token usage from tool uses. 
- } - MessageContent::ToolResult(tool_result) => match &tool_result.content { - LanguageModelToolResultContent::Text(text) => { - string_contents.push_str(text); - } - LanguageModelToolResultContent::Image(image) => { - tokens_from_images += image.estimate_tokens(); - } - }, - } - } - - if !string_contents.is_empty() { - string_messages.push(tiktoken_rs::ChatCompletionRequestMessage { - role: match message.role { - Role::User => "user".into(), - Role::Assistant => "assistant".into(), - Role::System => "system".into(), - }, - content: Some(string_contents), - name: None, - function_call: None, - }); - } - } - - // Tiktoken doesn't yet support these models, so we manually use the - // same tokenizer as GPT-4. - tiktoken_rs::num_tokens_from_messages("gpt-4", &string_messages) - .map(|tokens| (tokens + tokens_from_images) as u64) -} - pub fn into_anthropic( request: LanguageModelRequest, model: String, diff --git a/crates/cloud_llm_client/src/cloud_llm_client.rs b/crates/cloud_llm_client/src/cloud_llm_client.rs index d5b3af394ea96b4269c13ab3259dd5a99ec6cf49..8d1dbeb4394cbbf52f3d47d94bbc2f7fcc75e5db 100644 --- a/crates/cloud_llm_client/src/cloud_llm_client.rs +++ b/crates/cloud_llm_client/src/cloud_llm_client.rs @@ -268,18 +268,6 @@ pub struct WebSearchResult { pub text: String, } -#[derive(Serialize, Deserialize)] -pub struct CountTokensBody { - pub provider: LanguageModelProvider, - pub model: String, - pub provider_request: serde_json::Value, -} - -#[derive(Serialize, Deserialize)] -pub struct CountTokensResponse { - pub tokens: usize, -} - #[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)] pub struct LanguageModelId(pub Arc); diff --git a/crates/google_ai/Cargo.toml b/crates/google_ai/Cargo.toml index d91d28851997723835ba85be343a453918301c71..3848ed5f87514dfb70fb75d70e636c92ad5c4550 100644 --- a/crates/google_ai/Cargo.toml +++ b/crates/google_ai/Cargo.toml @@ -24,4 +24,3 @@ schemars = { workspace = true, optional = true } serde.workspace = true serde_json.workspace = true strum.workspace = true -tiktoken-rs.workspace = true diff --git a/crates/google_ai/src/completion.rs b/crates/google_ai/src/completion.rs index 3a15fdaa0187e52cb82dc8c71b5b861eb797f1a8..efbd1dc9ff731fbc12c5c01b697f8d944cd3e195 100644 --- a/crates/google_ai/src/completion.rs +++ b/crates/google_ai/src/completion.rs @@ -313,29 +313,6 @@ impl GoogleEventMapper { } } -/// Count tokens for a Google AI model using tiktoken. This is synchronous; -/// callers should spawn it on a background thread if needed. -pub fn count_google_tokens(request: LanguageModelRequest) -> Result { - let messages = request - .messages - .into_iter() - .map(|message| tiktoken_rs::ChatCompletionRequestMessage { - role: match message.role { - Role::User => "user".into(), - Role::Assistant => "assistant".into(), - Role::System => "system".into(), - }, - content: Some(message.string_contents()), - name: None, - function_call: None, - }) - .collect::>(); - - // Tiktoken doesn't yet support these models, so we manually use the - // same tokenizer as GPT-4. 
- tiktoken_rs::num_tokens_from_messages("gpt-4", &messages).map(|tokens| tokens as u64) -} - fn update_usage(usage: &mut UsageMetadata, new: &UsageMetadata) { if let Some(prompt_token_count) = new.prompt_token_count { usage.prompt_token_count = Some(prompt_token_count); diff --git a/crates/google_ai/src/google_ai.rs b/crates/google_ai/src/google_ai.rs index 7917eb45c6292d05ede5267ba669a942348e575a..1461197bd972121685f2ff45d9cb1cec56492c84 100644 --- a/crates/google_ai/src/google_ai.rs +++ b/crates/google_ai/src/google_ai.rs @@ -64,38 +64,6 @@ pub async fn stream_generate_content( } } -pub async fn count_tokens( - client: &dyn HttpClient, - api_url: &str, - api_key: &str, - request: CountTokensRequest, -) -> Result { - validate_generate_content_request(&request.generate_content_request)?; - - let uri = format!( - "{api_url}/v1beta/models/{model_id}:countTokens?key={api_key}", - model_id = &request.generate_content_request.model.model_id, - ); - - let request = serde_json::to_string(&request)?; - let request_builder = HttpRequest::builder() - .method(Method::POST) - .uri(&uri) - .header("Content-Type", "application/json"); - let http_request = request_builder.body(AsyncBody::from(request))?; - - let mut response = client.send(http_request).await?; - let mut text = String::new(); - response.body_mut().read_to_string(&mut text).await?; - anyhow::ensure!( - response.status().is_success(), - "error during countTokens, status code: {:?}, body: {}", - response.status(), - text - ); - Ok(serde_json::from_str::(&text)?) -} - pub fn validate_generate_content_request(request: &GenerateContentRequest) -> Result<()> { if request.model.is_empty() { bail!("Model must be specified"); @@ -123,8 +91,6 @@ pub enum Task { GenerateContent, #[serde(rename = "streamGenerateContent")] StreamGenerateContent, - #[serde(rename = "countTokens")] - CountTokens, #[serde(rename = "embedContent")] EmbedContent, #[serde(rename = "batchEmbedContents")] @@ -382,18 +348,6 @@ pub struct SafetyRating { pub probability: HarmProbability, } -#[derive(Debug, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct CountTokensRequest { - pub generate_content_request: GenerateContentRequest, -} - -#[derive(Debug, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct CountTokensResponse { - pub total_tokens: u64, -} - #[derive(Debug, Serialize, Deserialize)] pub struct FunctionCall { pub name: String, diff --git a/crates/language_model/src/fake_provider.rs b/crates/language_model/src/fake_provider.rs index cee65c21e575e7c96579c271805386527a29d4da..4466a3f2762b033c869afda984d1aa453068f00e 100644 --- a/crates/language_model/src/fake_provider.rs +++ b/crates/language_model/src/fake_provider.rs @@ -299,10 +299,6 @@ impl LanguageModel for FakeLanguageModel { 1000000 } - fn count_tokens(&self, _: LanguageModelRequest, _: &App) -> BoxFuture<'static, Result> { - futures::future::ready(Ok(0)).boxed() - } - fn stream_completion( &self, request: LanguageModelRequest, diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs index 60e8228fec52ffee763e19541f042ce47246dad2..4f7372777d7c1b1f2e4835426b30440f877eb097 100644 --- a/crates/language_model/src/language_model.rs +++ b/crates/language_model/src/language_model.rs @@ -121,12 +121,6 @@ pub trait LanguageModel: Send + Sync { None } - fn count_tokens( - &self, - request: LanguageModelRequest, - cx: &App, - ) -> BoxFuture<'static, Result>; - fn stream_completion( &self, request: LanguageModelRequest, diff --git 
a/crates/language_models/Cargo.toml b/crates/language_models/Cargo.toml index 60670114529b07dca78202cc438ff5e243acaeee..f5828fa28d7064817374409ead1155a2c5809dca 100644 --- a/crates/language_models/Cargo.toml +++ b/crates/language_models/Cargo.toml @@ -57,7 +57,6 @@ serde_json.workspace = true settings.workspace = true smol.workspace = true strum.workspace = true -tiktoken-rs.workspace = true tokio = { workspace = true, features = ["rt", "rt-multi-thread"] } ui.workspace = true ui_input.workspace = true diff --git a/crates/language_models/src/provider/anthropic.rs b/crates/language_models/src/provider/anthropic.rs index 3d2b763a6e4a97a3a819505a5d81a6e9330fbb9c..af5e53300a785b2390e4ea8c6571d46447c4333a 100644 --- a/crates/language_models/src/provider/anthropic.rs +++ b/crates/language_models/src/provider/anthropic.rs @@ -22,10 +22,7 @@ use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*}; use ui_input::InputField; use util::ResultExt; -pub use anthropic::completion::{ - AnthropicEventMapper, count_anthropic_tokens_with_tiktoken, into_anthropic, - into_anthropic_count_tokens_request, -}; +pub use anthropic::completion::{AnthropicEventMapper, into_anthropic}; pub use settings::AnthropicAvailableModel as AvailableModel; const PROVIDER_ID: LanguageModelProviderId = ANTHROPIC_PROVIDER_ID; @@ -378,52 +375,6 @@ impl LanguageModel for AnthropicModel { Some(self.model.max_output_tokens()) } - fn count_tokens( - &self, - request: LanguageModelRequest, - cx: &App, - ) -> BoxFuture<'static, Result> { - let http_client = self.http_client.clone(); - let model_id = self.model.request_id().to_string(); - let mode = self.model.mode(); - - let (api_key, api_url) = self.state.read_with(cx, |state, cx| { - let api_url = AnthropicLanguageModelProvider::api_url(cx); - ( - state.api_key_state.key(&api_url).map(|k| k.to_string()), - api_url.to_string(), - ) - }); - - let background = cx.background_executor().clone(); - async move { - // If no API key, fall back to tiktoken estimation - let Some(api_key) = api_key else { - return background - .spawn(async move { count_anthropic_tokens_with_tiktoken(request) }) - .await; - }; - - let count_request = - into_anthropic_count_tokens_request(request.clone(), model_id, mode); - - match anthropic::count_tokens(http_client.as_ref(), &api_url, &api_key, count_request) - .await - { - Ok(response) => Ok(response.input_tokens), - Err(err) => { - log::error!( - "Anthropic count_tokens API failed, falling back to tiktoken: {err:?}" - ); - background - .spawn(async move { count_anthropic_tokens_with_tiktoken(request) }) - .await - } - } - } - .boxed() - } - fn stream_completion( &self, request: LanguageModelRequest, diff --git a/crates/language_models/src/provider/bedrock.rs b/crates/language_models/src/provider/bedrock.rs index 80c758769cd990c00f5942433143bf6fb2216b7c..1069ad80fc02493c54eb2cf22d9b85e98708ac55 100644 --- a/crates/language_models/src/provider/bedrock.rs +++ b/crates/language_models/src/provider/bedrock.rs @@ -706,14 +706,6 @@ impl LanguageModel for BedrockModel { Some(self.model.max_output_tokens()) } - fn count_tokens( - &self, - request: LanguageModelRequest, - cx: &App, - ) -> BoxFuture<'static, Result> { - get_bedrock_tokens(request, cx) - } - fn stream_completion( &self, request: LanguageModelRequest, @@ -1151,68 +1143,6 @@ pub fn into_bedrock( }) } -// TODO: just call the ConverseOutput.usage() method: -// https://docs.rs/aws-sdk-bedrockruntime/latest/aws_sdk_bedrockruntime/operation/converse/struct.ConverseOutput.html#method.output -pub 
fn get_bedrock_tokens( - request: LanguageModelRequest, - cx: &App, -) -> BoxFuture<'static, Result> { - cx.background_executor() - .spawn(async move { - let messages = request.messages; - let mut tokens_from_images = 0; - let mut string_messages = Vec::with_capacity(messages.len()); - - for message in messages { - use language_model::MessageContent; - - let mut string_contents = String::new(); - - for content in message.content { - match content { - MessageContent::Text(text) | MessageContent::Thinking { text, .. } => { - string_contents.push_str(&text); - } - MessageContent::RedactedThinking(_) => {} - MessageContent::Image(image) => { - tokens_from_images += image.estimate_tokens(); - } - MessageContent::ToolUse(_tool_use) => { - // TODO: Estimate token usage from tool uses. - } - MessageContent::ToolResult(tool_result) => match tool_result.content { - LanguageModelToolResultContent::Text(text) => { - string_contents.push_str(&text); - } - LanguageModelToolResultContent::Image(image) => { - tokens_from_images += image.estimate_tokens(); - } - }, - } - } - - if !string_contents.is_empty() { - string_messages.push(tiktoken_rs::ChatCompletionRequestMessage { - role: match message.role { - Role::User => "user".into(), - Role::Assistant => "assistant".into(), - Role::System => "system".into(), - }, - content: Some(string_contents), - name: None, - function_call: None, - }); - } - } - - // Tiktoken doesn't yet support these models, so we manually use the - // same tokenizer as GPT-4. - tiktoken_rs::num_tokens_from_messages("gpt-4", &string_messages) - .map(|tokens| (tokens + tokens_from_images) as u64) - }) - .boxed() -} - pub fn map_to_language_model_completion_events( events: Pin>>>, ) -> impl Stream> { diff --git a/crates/language_models/src/provider/copilot_chat.rs b/crates/language_models/src/provider/copilot_chat.rs index 0d7d03c8c754217c664238cfcb51134b36fb9ce5..ef9fbae1131a23d07d3e7934cd335c81603e710b 100644 --- a/crates/language_models/src/provider/copilot_chat.rs +++ b/crates/language_models/src/provider/copilot_chat.rs @@ -203,25 +203,6 @@ impl LanguageModelProvider for CopilotChatLanguageModelProvider { } } -fn collect_tiktoken_messages( - request: LanguageModelRequest, -) -> Vec { - request - .messages - .into_iter() - .map(|message| tiktoken_rs::ChatCompletionRequestMessage { - role: match message.role { - Role::User => "user".into(), - Role::Assistant => "assistant".into(), - Role::System => "system".into(), - }, - content: Some(message.string_contents()), - name: None, - function_call: None, - }) - .collect::>() -} - pub struct CopilotChatLanguageModel { model: CopilotChatModel, request_limiter: RateLimiter, @@ -318,27 +299,6 @@ impl LanguageModel for CopilotChatLanguageModel { self.model.max_token_count() } - fn count_tokens( - &self, - request: LanguageModelRequest, - cx: &App, - ) -> BoxFuture<'static, Result> { - let model = self.model.clone(); - cx.background_spawn(async move { - let messages = collect_tiktoken_messages(request); - // Copilot uses OpenAI tiktoken tokenizer for all it's model irrespective of the underlying provider(vendor). 
- let tokenizer_model = match model.tokenizer() { - Some("o200k_base") => "gpt-4o", - Some("cl100k_base") => "gpt-4", - _ => "gpt-4o", - }; - - tiktoken_rs::num_tokens_from_messages(tokenizer_model, &messages) - .map(|tokens| tokens as u64) - }) - .boxed() - } - fn stream_completion( &self, request: LanguageModelRequest, diff --git a/crates/language_models/src/provider/deepseek.rs b/crates/language_models/src/provider/deepseek.rs index f3dccd5cc1a2e1a5ddfe2bc6b43901f2b549e532..dfc8521154e17abf04e91143cfbb0f8e79e9f1eb 100644 --- a/crates/language_models/src/provider/deepseek.rs +++ b/crates/language_models/src/provider/deepseek.rs @@ -293,32 +293,6 @@ impl LanguageModel for DeepSeekLanguageModel { self.model.max_output_tokens() } - fn count_tokens( - &self, - request: LanguageModelRequest, - cx: &App, - ) -> BoxFuture<'static, Result> { - cx.background_spawn(async move { - let messages = request - .messages - .into_iter() - .map(|message| tiktoken_rs::ChatCompletionRequestMessage { - role: match message.role { - Role::User => "user".into(), - Role::Assistant => "assistant".into(), - Role::System => "system".into(), - }, - content: Some(message.string_contents()), - name: None, - function_call: None, - }) - .collect::>(); - - tiktoken_rs::num_tokens_from_messages("gpt-4", &messages).map(|tokens| tokens as u64) - }) - .boxed() - } - fn stream_completion( &self, request: LanguageModelRequest, diff --git a/crates/language_models/src/provider/google.rs b/crates/language_models/src/provider/google.rs index 92278839c6ff5119849f8881409928686f055331..87f2eeb26ab0f8f87da10671f5adc1c36b3426b8 100644 --- a/crates/language_models/src/provider/google.rs +++ b/crates/language_models/src/provider/google.rs @@ -2,7 +2,7 @@ use anyhow::{Context as _, Result}; use collections::BTreeMap; use credentials_provider::CredentialsProvider; use futures::{FutureExt, StreamExt, future::BoxFuture}; -pub use google_ai::completion::{GoogleEventMapper, count_google_tokens, into_google}; +pub use google_ai::completion::{GoogleEventMapper, into_google}; use google_ai::{GenerateContentResponse, GoogleModelMode}; use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window}; use http_client::HttpClient; @@ -327,38 +327,6 @@ impl LanguageModel for GoogleLanguageModel { self.model.max_output_tokens() } - fn count_tokens( - &self, - request: LanguageModelRequest, - cx: &App, - ) -> BoxFuture<'static, Result> { - let model_id = self.model.request_id().to_string(); - let request = into_google(request, model_id, self.model.mode()); - let http_client = self.http_client.clone(); - let api_url = GoogleLanguageModelProvider::api_url(cx); - let api_key = self.state.read(cx).api_key_state.key(&api_url); - - async move { - let Some(api_key) = api_key else { - return Err(LanguageModelCompletionError::NoApiKey { - provider: PROVIDER_NAME, - } - .into()); - }; - let response = google_ai::count_tokens( - http_client.as_ref(), - &api_url, - &api_key, - google_ai::CountTokensRequest { - generate_content_request: request, - }, - ) - .await?; - Ok(response.total_tokens) - } - .boxed() - } - fn stream_completion( &self, request: LanguageModelRequest, diff --git a/crates/language_models/src/provider/lmstudio.rs b/crates/language_models/src/provider/lmstudio.rs index a541da8cd8092d5d0fa43af1217c31833f10cdeb..f035e765f0737dc8f6dce0588fbbf69619902230 100644 --- a/crates/language_models/src/provider/lmstudio.rs +++ b/crates/language_models/src/provider/lmstudio.rs @@ -505,22 +505,6 @@ impl LanguageModel for LmStudioLanguageModel { 
self.model.max_token_count() } - fn count_tokens( - &self, - request: LanguageModelRequest, - _cx: &App, - ) -> BoxFuture<'static, Result> { - // Endpoint for this is coming soon. In the meantime, hacky estimation - let token_count = request - .messages - .iter() - .map(|msg| msg.string_contents().split_whitespace().count()) - .sum::(); - - let estimated_tokens = (token_count as f64 * 0.75) as u64; - async move { Ok(estimated_tokens) }.boxed() - } - fn stream_completion( &self, request: LanguageModelRequest, diff --git a/crates/language_models/src/provider/mistral.rs b/crates/language_models/src/provider/mistral.rs index fdb0fb7b3a7f510c8e55deefcea8e3b7f4d1eb86..cce5448b9938e3c642764fd645abbdbe8fa625e2 100644 --- a/crates/language_models/src/provider/mistral.rs +++ b/crates/language_models/src/provider/mistral.rs @@ -327,32 +327,6 @@ impl LanguageModel for MistralLanguageModel { self.model.max_output_tokens() } - fn count_tokens( - &self, - request: LanguageModelRequest, - cx: &App, - ) -> BoxFuture<'static, Result> { - cx.background_spawn(async move { - let messages = request - .messages - .into_iter() - .map(|message| tiktoken_rs::ChatCompletionRequestMessage { - role: match message.role { - Role::User => "user".into(), - Role::Assistant => "assistant".into(), - Role::System => "system".into(), - }, - content: Some(message.string_contents()), - name: None, - function_call: None, - }) - .collect::>(); - - tiktoken_rs::num_tokens_from_messages("gpt-4", &messages).map(|tokens| tokens as u64) - }) - .boxed() - } - fn stream_completion( &self, request: LanguageModelRequest, diff --git a/crates/language_models/src/provider/ollama.rs b/crates/language_models/src/provider/ollama.rs index 49c326683a225bf73f604a584307ea1316a710c4..229b59e2bfded2473feebd970991cdacf2717471 100644 --- a/crates/language_models/src/provider/ollama.rs +++ b/crates/language_models/src/provider/ollama.rs @@ -493,23 +493,6 @@ impl LanguageModel for OllamaLanguageModel { self.model.max_token_count() } - fn count_tokens( - &self, - request: LanguageModelRequest, - _cx: &App, - ) -> BoxFuture<'static, Result> { - // There is no endpoint for this _yet_ in Ollama - // see: https://github.com/ollama/ollama/issues/1716 and https://github.com/ollama/ollama/issues/3582 - let token_count = request - .messages - .iter() - .map(|msg| msg.string_contents().chars().count()) - .sum::() - / 4; - - async move { Ok(token_count as u64) }.boxed() - } - fn stream_completion( &self, request: LanguageModelRequest, diff --git a/crates/language_models/src/provider/open_ai.rs b/crates/language_models/src/provider/open_ai.rs index da341211855ca07bef526740a5a39260a4403982..f5ee65c8d85ff6d142620690c6bcbd0e0c387754 100644 --- a/crates/language_models/src/provider/open_ai.rs +++ b/crates/language_models/src/provider/open_ai.rs @@ -25,8 +25,7 @@ use ui_input::InputField; use util::ResultExt; pub use open_ai::completion::{ - OpenAiEventMapper, OpenAiResponseEventMapper, collect_tiktoken_messages, count_open_ai_tokens, - into_open_ai, into_open_ai_response, + OpenAiEventMapper, OpenAiResponseEventMapper, into_open_ai, into_open_ai_response, }; const PROVIDER_ID: LanguageModelProviderId = OPEN_AI_PROVIDER_ID; @@ -369,16 +368,6 @@ impl LanguageModel for OpenAiLanguageModel { self.model.max_output_tokens() } - fn count_tokens( - &self, - request: LanguageModelRequest, - cx: &App, - ) -> BoxFuture<'static, Result> { - let model = self.model.clone(); - cx.background_spawn(async move { count_open_ai_tokens(request, model) }) - .boxed() - } - fn 
stream_completion( &self, request: LanguageModelRequest, diff --git a/crates/language_models/src/provider/open_ai_compatible.rs b/crates/language_models/src/provider/open_ai_compatible.rs index b82e3b1ae3aa80373dbad3550e7bd896b8879f2b..5f7f6db3d36a45a6b41893421a59355188b7902b 100644 --- a/crates/language_models/src/provider/open_ai_compatible.rs +++ b/crates/language_models/src/provider/open_ai_compatible.rs @@ -360,27 +360,6 @@ impl LanguageModel for OpenAiCompatibleLanguageModel { self.model.max_output_tokens } - fn count_tokens( - &self, - request: LanguageModelRequest, - cx: &App, - ) -> BoxFuture<'static, Result> { - let max_token_count = self.max_token_count(); - cx.background_spawn(async move { - let messages = super::open_ai::collect_tiktoken_messages(request); - let model = if max_token_count >= 100_000 { - // If the max tokens is 100k or more, it is likely the o200k_base tokenizer from gpt4o - "gpt-4o" - } else { - // Otherwise fallback to gpt-4, since only cl100k_base and o200k_base are - // supported with this tiktoken method - "gpt-4" - }; - tiktoken_rs::num_tokens_from_messages(model, &messages).map(|tokens| tokens as u64) - }) - .boxed() - } - fn stream_completion( &self, request: LanguageModelRequest, diff --git a/crates/language_models/src/provider/open_router.rs b/crates/language_models/src/provider/open_router.rs index fba3a6938aecf1db80680e014e408e4d59c42ff7..6562d9de085229b4f8b80982b813c811382b45b1 100644 --- a/crates/language_models/src/provider/open_router.rs +++ b/crates/language_models/src/provider/open_router.rs @@ -372,14 +372,6 @@ impl LanguageModel for OpenRouterLanguageModel { self.model.supports_images.unwrap_or(false) } - fn count_tokens( - &self, - request: LanguageModelRequest, - cx: &App, - ) -> BoxFuture<'static, Result> { - count_open_router_tokens(request, self.model.clone(), cx) - } - fn stream_completion( &self, request: LanguageModelRequest, @@ -741,32 +733,6 @@ struct RawToolCall { thought_signature: Option, } -pub fn count_open_router_tokens( - request: LanguageModelRequest, - _model: open_router::Model, - cx: &App, -) -> BoxFuture<'static, Result> { - cx.background_spawn(async move { - let messages = request - .messages - .into_iter() - .map(|message| tiktoken_rs::ChatCompletionRequestMessage { - role: match message.role { - Role::User => "user".into(), - Role::Assistant => "assistant".into(), - Role::System => "system".into(), - }, - content: Some(message.string_contents()), - name: None, - function_call: None, - }) - .collect::>(); - - tiktoken_rs::num_tokens_from_messages("gpt-4o", &messages).map(|tokens| tokens as u64) - }) - .boxed() -} - struct ConfigurationView { api_key_editor: Entity, state: Entity, diff --git a/crates/language_models/src/provider/opencode.rs b/crates/language_models/src/provider/opencode.rs index 4754741715f39104f392b6871f68a1c04e3bdfce..4b0f8e5992a22c4bb5049e477419c9b9fa2d6f6a 100644 --- a/crates/language_models/src/provider/opencode.rs +++ b/crates/language_models/src/provider/opencode.rs @@ -8,7 +8,7 @@ use language_model::{ ApiKeyState, AuthenticateError, EnvVar, IconOrSvg, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, - LanguageModelRequest, LanguageModelToolChoice, RateLimiter, Role, env_var, + LanguageModelRequest, LanguageModelToolChoice, RateLimiter, env_var, }; use opencode::{ApiProtocol, OPENCODE_API_URL}; pub use 
settings::OpenCodeAvailableModel as AvailableModel; @@ -426,32 +426,6 @@ impl LanguageModel for OpenCodeLanguageModel { self.model.max_output_tokens() } - fn count_tokens( - &self, - request: LanguageModelRequest, - cx: &App, - ) -> BoxFuture<'static, Result> { - cx.background_spawn(async move { - let messages = request - .messages - .into_iter() - .map(|message| tiktoken_rs::ChatCompletionRequestMessage { - role: match message.role { - Role::User => "user".into(), - Role::Assistant => "assistant".into(), - Role::System => "system".into(), - }, - content: Some(message.string_contents()), - name: None, - function_call: None, - }) - .collect::>(); - - tiktoken_rs::num_tokens_from_messages("gpt-4o", &messages).map(|tokens| tokens as u64) - }) - .boxed() - } - fn stream_completion( &self, request: LanguageModelRequest, diff --git a/crates/language_models/src/provider/vercel.rs b/crates/language_models/src/provider/vercel.rs index ce9870073ee9f399e9f02a5a093931c1a4304fdb..188cb6d0322d36d120f2faf70a3a4d3b33997512 100644 --- a/crates/language_models/src/provider/vercel.rs +++ b/crates/language_models/src/provider/vercel.rs @@ -8,7 +8,7 @@ use language_model::{ ApiKeyState, AuthenticateError, EnvVar, IconOrSvg, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, - LanguageModelRequest, LanguageModelToolChoice, RateLimiter, Role, env_var, + LanguageModelRequest, LanguageModelToolChoice, RateLimiter, env_var, }; use open_ai::ResponseStreamEvent; pub use settings::VercelAvailableModel as AvailableModel; @@ -18,7 +18,7 @@ use strum::IntoEnumIterator; use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*}; use ui_input::InputField; use util::ResultExt; -use vercel::{Model, VERCEL_API_URL}; +use vercel::VERCEL_API_URL; const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("vercel"); const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Vercel"); @@ -295,14 +295,6 @@ impl LanguageModel for VercelLanguageModel { self.model.max_output_tokens() } - fn count_tokens( - &self, - request: LanguageModelRequest, - cx: &App, - ) -> BoxFuture<'static, Result> { - count_vercel_tokens(request, self.model.clone(), cx) - } - fn stream_completion( &self, request: LanguageModelRequest, @@ -335,51 +327,6 @@ impl LanguageModel for VercelLanguageModel { } } -pub fn count_vercel_tokens( - request: LanguageModelRequest, - model: Model, - cx: &App, -) -> BoxFuture<'static, Result> { - cx.background_spawn(async move { - let messages = request - .messages - .into_iter() - .map(|message| tiktoken_rs::ChatCompletionRequestMessage { - role: match message.role { - Role::User => "user".into(), - Role::Assistant => "assistant".into(), - Role::System => "system".into(), - }, - content: Some(message.string_contents()), - name: None, - function_call: None, - }) - .collect::>(); - - match model { - Model::Custom { max_tokens, .. 
} => { - let model = if max_tokens >= 100_000 { - // If the max tokens is 100k or more, it is likely the o200k_base tokenizer from gpt4o - "gpt-4o" - } else { - // Otherwise fallback to gpt-4, since only cl100k_base and o200k_base are - // supported with this tiktoken method - "gpt-4" - }; - tiktoken_rs::num_tokens_from_messages(model, &messages) - } - // Map Vercel models to appropriate OpenAI models for token counting - // since Vercel uses OpenAI-compatible API - Model::VZeroOnePointFiveMedium => { - // Vercel v0 is similar to GPT-4o, so use gpt-4o for token counting - tiktoken_rs::num_tokens_from_messages("gpt-4o", &messages) - } - } - .map(|tokens| tokens as u64) - }) - .boxed() -} - struct ConfigurationView { api_key_editor: Entity, state: Entity, diff --git a/crates/language_models/src/provider/vercel_ai_gateway.rs b/crates/language_models/src/provider/vercel_ai_gateway.rs index cf379e6edc1db181127cc284834b19c61143d692..789e8e35e8546a3a7493a4fbfe500f4baef04d85 100644 --- a/crates/language_models/src/provider/vercel_ai_gateway.rs +++ b/crates/language_models/src/provider/vercel_ai_gateway.rs @@ -422,24 +422,6 @@ impl LanguageModel for VercelAiGatewayLanguageModel { self.model.max_output_tokens } - fn count_tokens( - &self, - request: LanguageModelRequest, - cx: &App, - ) -> BoxFuture<'static, Result> { - let max_token_count = self.max_token_count(); - cx.background_spawn(async move { - let messages = crate::provider::open_ai::collect_tiktoken_messages(request); - let model = if max_token_count >= 100_000 { - "gpt-4o" - } else { - "gpt-4" - }; - tiktoken_rs::num_tokens_from_messages(model, &messages).map(|tokens| tokens as u64) - }) - .boxed() - } - fn stream_completion( &self, request: LanguageModelRequest, diff --git a/crates/language_models/src/provider/x_ai.rs b/crates/language_models/src/provider/x_ai.rs index b14cb82bf83ac773a58104207d1594b36994a91c..12f195417b5220afa9cd4e5b33211d40b39e14c3 100644 --- a/crates/language_models/src/provider/x_ai.rs +++ b/crates/language_models/src/provider/x_ai.rs @@ -20,7 +20,6 @@ use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*}; use ui_input::InputField; use util::ResultExt; use x_ai::XAI_API_URL; -pub use x_ai::completion::count_xai_tokens; const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("x_ai"); const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("xAI"); @@ -316,16 +315,6 @@ impl LanguageModel for XAiLanguageModel { true } - fn count_tokens( - &self, - request: LanguageModelRequest, - cx: &App, - ) -> BoxFuture<'static, Result> { - let model = self.model.clone(); - cx.background_spawn(async move { count_xai_tokens(request, model) }) - .boxed() - } - fn stream_completion( &self, request: LanguageModelRequest, diff --git a/crates/language_models_cloud/Cargo.toml b/crates/language_models_cloud/Cargo.toml index b08acc5ecd5c2a718e936378c2dbfbc3d1c32df0..de82fdfa627829121183793de42ce620a72c4682 100644 --- a/crates/language_models_cloud/Cargo.toml +++ b/crates/language_models_cloud/Cargo.toml @@ -27,7 +27,6 @@ serde.workspace = true serde_json.workspace = true smol.workspace = true thiserror.workspace = true -x_ai = { workspace = true, features = ["schemars"] } [dev-dependencies] language_model = { workspace = true, features = ["test-support"] } diff --git a/crates/language_models_cloud/src/language_models_cloud.rs b/crates/language_models_cloud/src/language_models_cloud.rs index 4e444def7b4df6295a5f12ebccb08802abdfca4d..adae72068c508ea853281755f258ff8e5432b405 100644 --- 
a/crates/language_models_cloud/src/language_models_cloud.rs +++ b/crates/language_models_cloud/src/language_models_cloud.rs @@ -3,9 +3,8 @@ use anyhow::{Context as _, Result, anyhow}; use cloud_llm_client::{ CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME, CLIENT_SUPPORTS_X_AI_HEADER_NAME, CompletionBody, CompletionEvent, CompletionRequestStatus, - CountTokensBody, CountTokensResponse, EXPIRED_LLM_TOKEN_HEADER_NAME, ListModelsResponse, - OUTDATED_LLM_TOKEN_HEADER_NAME, SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, - ZED_VERSION_HEADER_NAME, + EXPIRED_LLM_TOKEN_HEADER_NAME, ListModelsResponse, OUTDATED_LLM_TOKEN_HEADER_NAME, + SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, ZED_VERSION_HEADER_NAME, }; use futures::{ AsyncBufReadExt, FutureExt, Stream, StreamExt, @@ -13,7 +12,7 @@ use futures::{ stream::{self, BoxStream}, }; use google_ai::GoogleModelMode; -use gpui::{App, AppContext, AsyncApp, Context, Task}; +use gpui::{AppContext, AsyncApp, Context, Task}; use http_client::http::{HeaderMap, HeaderValue}; use http_client::{ AsyncBody, HttpClient, HttpClientWithUrl, HttpRequestExt, Method, Response, StatusCode, @@ -40,15 +39,11 @@ use std::task::Poll; use std::time::Duration; use thiserror::Error; -use anthropic::completion::{ - AnthropicEventMapper, count_anthropic_tokens_with_tiktoken, into_anthropic, -}; +use anthropic::completion::{AnthropicEventMapper, into_anthropic}; use google_ai::completion::{GoogleEventMapper, into_google}; use open_ai::completion::{ - OpenAiEventMapper, OpenAiResponseEventMapper, count_open_ai_tokens, into_open_ai, - into_open_ai_response, + OpenAiEventMapper, OpenAiResponseEventMapper, into_open_ai, into_open_ai_response, }; -use x_ai::completion::count_xai_tokens; const PROVIDER_ID: LanguageModelProviderId = ZED_CLOUD_PROVIDER_ID; const PROVIDER_NAME: LanguageModelProviderName = ZED_CLOUD_PROVIDER_NAME; @@ -374,85 +369,6 @@ impl LanguageModel for CloudLanguageModel BoxFuture<'static, Result> { - match self.model.provider { - cloud_llm_client::LanguageModelProvider::Anthropic => cx - .background_spawn(async move { count_anthropic_tokens_with_tiktoken(request) }) - .boxed(), - cloud_llm_client::LanguageModelProvider::OpenAi => { - let model = match open_ai::Model::from_id(&self.model.id.0) { - Ok(model) => model, - Err(err) => return async move { Err(anyhow!(err)) }.boxed(), - }; - cx.background_spawn(async move { count_open_ai_tokens(request, model) }) - .boxed() - } - cloud_llm_client::LanguageModelProvider::XAi => { - let model = match x_ai::Model::from_id(&self.model.id.0) { - Ok(model) => model, - Err(err) => return async move { Err(anyhow!(err)) }.boxed(), - }; - cx.background_spawn(async move { count_xai_tokens(request, model) }) - .boxed() - } - cloud_llm_client::LanguageModelProvider::Google => { - let http_client = self.http_client.clone(); - let token_provider = self.token_provider.clone(); - let model_id = self.model.id.to_string(); - let generate_content_request = - into_google(request, model_id.clone(), GoogleModelMode::Default); - let auth_context = token_provider.auth_context(cx); - async move { - let token = token_provider.acquire_token(auth_context).await?; - - let request_body = CountTokensBody { - provider: cloud_llm_client::LanguageModelProvider::Google, - model: model_id, - provider_request: serde_json::to_value(&google_ai::CountTokensRequest { - generate_content_request, - })?, - }; - let request = http_client::Request::builder() - .method(Method::POST) - .uri( - http_client - 
.build_zed_llm_url("/count_tokens", &[])? - .as_ref(), - ) - .header("Content-Type", "application/json") - .header("Authorization", format!("Bearer {token}")) - .body(serde_json::to_string(&request_body)?.into())?; - let mut response = http_client.send(request).await?; - let status = response.status(); - let headers = response.headers().clone(); - let mut response_body = String::new(); - response - .body_mut() - .read_to_string(&mut response_body) - .await?; - - if status.is_success() { - let response_body: CountTokensResponse = - serde_json::from_str(&response_body)?; - - Ok(response_body.tokens as u64) - } else { - Err(anyhow!(ApiError { - status, - body: response_body, - headers - })) - } - } - .boxed() - } - } - } - fn stream_completion( &self, request: LanguageModelRequest, diff --git a/crates/open_ai/Cargo.toml b/crates/open_ai/Cargo.toml index 9a73e73196fa225691fa68e2ca839a19783bc3ca..5083e97c56014708d73ecaefd9463aae6189a8ac 100644 --- a/crates/open_ai/Cargo.toml +++ b/crates/open_ai/Cargo.toml @@ -28,7 +28,6 @@ serde.workspace = true serde_json.workspace = true strum.workspace = true thiserror.workspace = true -tiktoken-rs.workspace = true [dev-dependencies] pretty_assertions.workspace = true diff --git a/crates/open_ai/src/completion.rs b/crates/open_ai/src/completion.rs index e37f57dd8b2efd8b1332c313bdb51713cdd142e6..3068f57f582db16182e4c4a78e10fd632bc2815f 100644 --- a/crates/open_ai/src/completion.rs +++ b/crates/open_ai/src/completion.rs @@ -18,7 +18,7 @@ use crate::responses::{ StreamEvent as ResponsesStreamEvent, }; use crate::{ - FunctionContent, FunctionDefinition, ImageUrl, MessagePart, Model, ReasoningEffort, + FunctionContent, FunctionDefinition, ImageUrl, MessagePart, ReasoningEffort, ResponseStreamEvent, ToolCall, ToolCallContent, }; @@ -818,68 +818,6 @@ fn token_usage_from_response_usage(usage: &ResponsesUsage) -> TokenUsage { } } -pub fn collect_tiktoken_messages( - request: LanguageModelRequest, -) -> Vec { - request - .messages - .into_iter() - .map(|message| tiktoken_rs::ChatCompletionRequestMessage { - role: match message.role { - Role::User => "user".into(), - Role::Assistant => "assistant".into(), - Role::System => "system".into(), - }, - content: Some(message.string_contents()), - name: None, - function_call: None, - }) - .collect::>() -} - -/// Count tokens for an OpenAI model. This is synchronous; callers should spawn -/// it on a background thread if needed. -pub fn count_open_ai_tokens(request: LanguageModelRequest, model: Model) -> Result { - let messages = collect_tiktoken_messages(request); - match model { - Model::Custom { max_tokens, .. } => { - let model = if max_tokens >= 100_000 { - // If the max tokens is 100k or more, it likely uses the o200k_base tokenizer - "gpt-4o" - } else { - // Otherwise fallback to gpt-4, since only cl100k_base and o200k_base are - // supported with this tiktoken method - "gpt-4" - }; - tiktoken_rs::num_tokens_from_messages(model, &messages) - } - // Currently supported by tiktoken_rs - // Sometimes tiktoken-rs is behind on model support. If that is the case, make a new branch - // arm with an override. We enumerate all supported models here so that we can check if new - // models are supported yet or not. 
- Model::ThreePointFiveTurbo - | Model::Four - | Model::FourTurbo - | Model::FourOmniMini - | Model::FourPointOneNano - | Model::O1 - | Model::O3 - | Model::O3Mini - | Model::Five - | Model::FiveCodex - | Model::FiveMini - | Model::FiveNano => tiktoken_rs::num_tokens_from_messages(model.id(), &messages), - // GPT-5.1, 5.2, 5.2-codex, 5.3-codex, 5.4, and 5.4-pro don't have dedicated tiktoken support; use gpt-5 tokenizer - Model::FivePointOne - | Model::FivePointTwo - | Model::FivePointTwoCodex - | Model::FivePointThreeCodex - | Model::FivePointFour - | Model::FivePointFourPro => tiktoken_rs::num_tokens_from_messages("gpt-5", &messages), - } - .map(|tokens| tokens as u64) -} - #[cfg(test)] mod tests { use crate::responses::{ @@ -929,34 +867,6 @@ mod tests { }) } - #[test] - fn tiktoken_rs_support() { - let request = LanguageModelRequest { - thread_id: None, - prompt_id: None, - intent: None, - messages: vec![LanguageModelRequestMessage { - role: Role::User, - content: vec![MessageContent::Text("message".into())], - cache: false, - reasoning_details: None, - }], - tools: vec![], - tool_choice: None, - stop: vec![], - temperature: None, - thinking_allowed: true, - thinking_effort: None, - speed: None, - }; - - // Validate that all models are supported by tiktoken-rs - for model in ::iter() { - let count = count_open_ai_tokens(request.clone(), model).unwrap(); - assert!(count > 0); - } - } - #[test] fn responses_stream_maps_text_and_usage() { let events = vec![ diff --git a/crates/rules_library/src/rules_library.rs b/crates/rules_library/src/rules_library.rs index 425f7d2aa3d9e9259fe005a0e15dee10e4e4baf1..e5105081ca7af71520f6dae7f8f01b818d2b8877 100644 --- a/crates/rules_library/src/rules_library.rs +++ b/crates/rules_library/src/rules_library.rs @@ -8,9 +8,7 @@ use gpui::{ WindowOptions, actions, point, size, transparent_black, }; use language::{Buffer, LanguageRegistry, language_settings::SoftWrap}; -use language_model::{ - ConfiguredModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role, -}; +use language_model::{ConfiguredModel, LanguageModelRegistry}; use picker::{Picker, PickerDelegate}; use platform_title_bar::PlatformTitleBar; use release_channel::ReleaseChannel; @@ -165,8 +163,6 @@ pub struct RulesLibrary { struct RuleEditor { title_editor: Entity, body_editor: Entity, - token_count: Option, - pending_token_count: Task>, next_title_and_body_to_save: Option<(String, Rope)>, pending_save: Option>>, _subscriptions: Vec, @@ -785,13 +781,10 @@ impl RulesLibrary { body_editor, next_title_and_body_to_save: None, pending_save: None, - token_count: None, - pending_token_count: Task::ready(None), _subscriptions, }, ); this.set_active_rule(Some(prompt_id), window, cx); - this.count_tokens(prompt_id, window, cx); } Err(error) => { // TODO: we should show the error in the UI. @@ -1019,7 +1012,6 @@ impl RulesLibrary { match event { EditorEvent::BufferEdited => { self.save_rule(prompt_id, window, cx); - self.count_tokens(prompt_id, window, cx); } EditorEvent::Blurred => { title_editor.update(cx, |title_editor, cx| { @@ -1049,7 +1041,6 @@ impl RulesLibrary { match event { EditorEvent::BufferEdited => { self.save_rule(prompt_id, window, cx); - self.count_tokens(prompt_id, window, cx); } EditorEvent::Blurred => { body_editor.update(cx, |body_editor, cx| { @@ -1068,59 +1059,6 @@ impl RulesLibrary { } } - fn count_tokens(&mut self, prompt_id: PromptId, window: &mut Window, cx: &mut Context) { - let Some(ConfiguredModel { model, .. 
}) = - LanguageModelRegistry::read_global(cx).default_model() - else { - return; - }; - if let Some(rule) = self.rule_editors.get_mut(&prompt_id) { - let editor = &rule.body_editor.read(cx); - let buffer = &editor.buffer().read(cx).as_singleton().unwrap().read(cx); - let body = buffer.as_rope().clone(); - rule.pending_token_count = cx.spawn_in(window, async move |this, cx| { - async move { - const DEBOUNCE_TIMEOUT: Duration = Duration::from_secs(1); - - cx.background_executor().timer(DEBOUNCE_TIMEOUT).await; - let token_count = cx - .update(|_, cx| { - model.count_tokens( - LanguageModelRequest { - thread_id: None, - prompt_id: None, - intent: None, - messages: vec![LanguageModelRequestMessage { - role: Role::System, - content: vec![body.to_string().into()], - cache: false, - reasoning_details: None, - }], - tools: Vec::new(), - tool_choice: None, - stop: Vec::new(), - temperature: None, - thinking_allowed: true, - thinking_effort: None, - speed: None, - }, - cx, - ) - })? - .await?; - - this.update(cx, |this, cx| { - let rule_editor = this.rule_editors.get_mut(&prompt_id).unwrap(); - rule_editor.token_count = Some(token_count); - cx.notify(); - }) - } - .log_err() - .await - }); - } - } - fn render_rule_list(&mut self, cx: &mut Context) -> impl IntoElement { v_flex() .id("rule-list") @@ -1293,8 +1231,6 @@ impl RulesLibrary { let rule_metadata = self.store.read(cx).metadata(prompt_id)?; let rule_editor = &self.rule_editors[&prompt_id]; let focus_handle = rule_editor.body_editor.focus_handle(cx); - let registry = LanguageModelRegistry::read_global(cx); - let model = registry.default_model().map(|default| default.model); let built_in = prompt_id.is_built_in(); Some( @@ -1318,52 +1254,17 @@ impl RulesLibrary { built_in, cx, )) - .child( - h_flex() - .h_full() - .flex_shrink_0() - .children(rule_editor.token_count.map(|token_count| { - let token_count: SharedString = - token_count.to_string().into(); - let label_token_count: SharedString = - token_count.to_string().into(); - - div() - .id("token_count") - .mr_1() - .flex_shrink_0() - .tooltip(move |_window, cx| { - Tooltip::with_meta( - "Token Estimation", - None, - format!( - "Model: {}", - model - .as_ref() - .map(|model| model.name().0) - .unwrap_or_default() - ), - cx, - ) - }) - .child( - Label::new(format!( - "{} tokens", - label_token_count - )) - .color(Color::Muted), - ) - })) - .map(|this| { - if built_in { - this.child(self.render_built_in_rule_controls()) - } else { - this.child(self.render_regular_rule_controls( - rule_metadata.default, - )) - } - }), - ), + .child(h_flex().h_full().flex_shrink_0().map(|this| { + if built_in { + this.child(self.render_built_in_rule_controls()) + } else { + this.child( + self.render_regular_rule_controls( + rule_metadata.default, + ), + ) + } + })), ) .child( div() diff --git a/crates/x_ai/Cargo.toml b/crates/x_ai/Cargo.toml index 2d1c9d0ecebeb8a1e0965b0ac914603b41383f00..8ff020df8c1ccaf284157d8b46ddaa0e678b3cd7 100644 --- a/crates/x_ai/Cargo.toml +++ b/crates/x_ai/Cargo.toml @@ -17,8 +17,6 @@ schemars = ["dep:schemars"] [dependencies] anyhow.workspace = true -language_model_core.workspace = true schemars = { workspace = true, optional = true } serde.workspace = true strum.workspace = true -tiktoken-rs.workspace = true diff --git a/crates/x_ai/src/completion.rs b/crates/x_ai/src/completion.rs deleted file mode 100644 index aad03d227eb82768c972283f7e1617ea7486f22f..0000000000000000000000000000000000000000 --- a/crates/x_ai/src/completion.rs +++ /dev/null @@ -1,30 +0,0 @@ -use anyhow::Result; -use 
language_model_core::{LanguageModelRequest, Role}; - -use crate::Model; - -/// Count tokens for an xAI model using tiktoken. This is synchronous; -/// callers should spawn it on a background thread if needed. -pub fn count_xai_tokens(request: LanguageModelRequest, model: Model) -> Result { - let messages = request - .messages - .into_iter() - .map(|message| tiktoken_rs::ChatCompletionRequestMessage { - role: match message.role { - Role::User => "user".into(), - Role::Assistant => "assistant".into(), - Role::System => "system".into(), - }, - content: Some(message.string_contents()), - name: None, - function_call: None, - }) - .collect::>(); - - let model_name = if model.max_token_count() >= 100_000 { - "gpt-4o" - } else { - "gpt-4" - }; - tiktoken_rs::num_tokens_from_messages(model_name, &messages).map(|tokens| tokens as u64) -} diff --git a/crates/x_ai/src/x_ai.rs b/crates/x_ai/src/x_ai.rs index bc49a3e2b37d6ac83c66a2fba3af83ea7a451576..afa7d62aa3c99127580779eb0f0f1563e3ab658d 100644 --- a/crates/x_ai/src/x_ai.rs +++ b/crates/x_ai/src/x_ai.rs @@ -1,5 +1,3 @@ -pub mod completion; - use anyhow::Result; use serde::{Deserialize, Serialize}; use strum::EnumIter;
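
For reference, a minimal sketch of the estimation path this patch deletes. Apart from Anthropic and Google, which could call their providers' real token-counting endpoints (with tiktoken as a fallback), most of the removed `count_tokens` implementations flattened each request into tiktoken's chat-message format and counted with an OpenAI tokenizer. The calls below are the same `tiktoken-rs` APIs used in the deleted code; the message text is made up for illustration:

```rust
use tiktoken_rs::ChatCompletionRequestMessage;

fn main() -> anyhow::Result<()> {
    // The deleted helpers flattened each request into role + plain text,
    // dropping images, tool calls, and thinking blocks along the way.
    let messages = vec![ChatCompletionRequestMessage {
        role: "user".into(),
        content: Some("Explain tokenizers in one sentence.".into()),
        name: None,
        function_call: None,
    }];

    // "gpt-4" selects the cl100k_base tokenizer. Since almost every provider
    // funneled through an OpenAI tokenizer like this one, for non-OpenAI
    // models the result was only ever a rough estimate, not the count the
    // model's own tokenizer would report.
    let tokens = tiktoken_rs::num_tokens_from_messages("gpt-4", &messages)?;
    println!("estimated input tokens: {tokens}");
    Ok(())
}
```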