diff --git a/Cargo.lock b/Cargo.lock
index d76f07aeef8e824ea78ca235d48d39ef7fe252e7..dd70dba47ea243095cc6f075c5a711e17e586d85 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -661,7 +661,6 @@ dependencies = [
  "serde_json",
  "strum 0.27.2",
  "thiserror 2.0.17",
- "tiktoken-rs",
 ]
 
 [[package]]
@@ -672,9 +671,9 @@
 checksum = "34cd60c5e3152cef0a592f1b296f1cc93715d89d2551d85315828c3a09575ff4"
 
 [[package]]
 name = "anyhow"
-version = "1.0.100"
+version = "1.0.102"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a23eb6b1614318a8071c9b2521f36b424b2c83db5eb3a0fead4a6c0809af6e61"
+checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c"
 
 [[package]]
 name = "approx"
@@ -2209,9 +2208,9 @@ dependencies = [
 
 [[package]]
 name = "bstr"
-version = "1.12.0"
+version = "1.12.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "234113d19d0d7d613b40e86fb654acf958910802bcceab913a4f9e7cda03b1a4"
+checksum = "63044e1ae8e69f3b5a92c736ca6269b8d12fa7efe39bf34ddb06d102cf0e2cab"
 dependencies = [
  "memchr",
  "regex-automata",
@@ -7533,7 +7532,6 @@ dependencies = [
  "serde",
  "serde_json",
  "strum 0.27.2",
- "tiktoken-rs",
 ]
 
 [[package]]
@@ -9574,7 +9572,6 @@ dependencies = [
  "settings",
  "smol",
  "strum 0.27.2",
- "tiktoken-rs",
  "tokio",
  "ui",
  "ui_input",
@@ -9602,7 +9599,6 @@ dependencies = [
  "serde_json",
  "smol",
  "thiserror 2.0.17",
- "x_ai",
 ]
 
 [[package]]
@@ -11766,7 +11762,6 @@ dependencies = [
  "serde_json",
  "strum 0.27.2",
  "thiserror 2.0.17",
- "tiktoken-rs",
 ]
 
 [[package]]
@@ -14414,9 +14409,9 @@ dependencies = [
 
 [[package]]
 name = "regex"
-version = "1.12.2"
+version = "1.12.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "843bc0191f75f3e22651ae5f1e72939ab2f72a4bc30fa80a066bd66edefc24d4"
+checksum = "e10754a14b9137dd7b1e3e5b0493cc9171fdd105e0ab477f51b72e7f3ac0e276"
 dependencies = [
  "aho-corasick",
  "memchr",
@@ -17855,20 +17850,6 @@ dependencies = [
  "zune-jpeg 0.5.15",
 ]
 
-[[package]]
-name = "tiktoken-rs"
-version = "0.9.1"
-source = "git+https://github.com/zed-industries/tiktoken-rs?rev=2570c4387a8505fb8f1d3f3557454b474f1e8271#2570c4387a8505fb8f1d3f3557454b474f1e8271"
-dependencies = [
- "anyhow",
- "base64 0.22.1",
- "bstr",
- "fancy-regex 0.16.2",
- "lazy_static",
- "regex",
- "rustc-hash 1.1.0",
-]
-
 [[package]]
 name = "time"
 version = "0.3.47"
@@ -21920,11 +21901,9 @@ name = "x_ai"
 version = "0.1.0"
 dependencies = [
  "anyhow",
- "language_model_core",
  "schemars",
  "serde",
  "strum 0.27.2",
- "tiktoken-rs",
 ]
 
 [[package]]
diff --git a/Cargo.toml b/Cargo.toml
index 1d96615ebfe92b8a1728bfabca76f694a7e54b0e..448a4dd25c3b67f6f012194563cba996d3dc06bd 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -733,7 +733,6 @@ sysinfo = "0.37.0"
 take-until = "0.2.0"
 tempfile = "3.20.0"
 thiserror = "2.0.12"
-tiktoken-rs = { git = "https://github.com/zed-industries/tiktoken-rs", rev = "2570c4387a8505fb8f1d3f3557454b474f1e8271" }
 time = { version = "0.3", features = [
     "macros",
     "parsing",
diff --git a/crates/agent_ui/src/language_model_selector.rs b/crates/agent_ui/src/language_model_selector.rs
index 7de58fd54ffd0d984b3a6079681f15f6a56507ae..4f870b7cfbd4800bb8f9c50d9f548af4a191a701 100644
--- a/crates/agent_ui/src/language_model_selector.rs
+++ b/crates/agent_ui/src/language_model_selector.rs
@@ -566,7 +566,7 @@ impl PickerDelegate for LanguageModelPickerDelegate {
 mod tests {
     use super::*;
     use futures::{future::BoxFuture, stream::BoxStream};
-    use gpui::{AsyncApp, TestAppContext, http_client};
+    use gpui::{AsyncApp, TestAppContext};
     use language_model::{
         LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
         LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
@@ -630,14 +630,6 @@ mod tests {
         1000
     }
 
-    fn count_tokens(
-        &self,
-        _: LanguageModelRequest,
-        _: &App,
-    ) -> BoxFuture<'static, http_client::Result<u64>> {
-        unimplemented!()
-    }
-
     fn stream_completion(
         &self,
         _: LanguageModelRequest,
diff --git a/crates/anthropic/Cargo.toml b/crates/anthropic/Cargo.toml
index 458f9bfae7da4736c4e54e42f08b5e3a926ed30a..3001b5801c067ec4053f34247985114cd8e8087c 100644
--- a/crates/anthropic/Cargo.toml
+++ b/crates/anthropic/Cargo.toml
@@ -28,6 +28,3 @@ serde.workspace = true
 serde_json.workspace = true
 strum.workspace = true
 thiserror.workspace = true
-tiktoken-rs.workspace = true
-
-
diff --git a/crates/anthropic/src/anthropic.rs b/crates/anthropic/src/anthropic.rs
index 488802d7db3fe169708c33710f7e583c46d41ae9..ac94daba71194ef925b87e7d73592aa9b05d97c2 100644
--- a/crates/anthropic/src/anthropic.rs
+++ b/crates/anthropic/src/anthropic.rs
@@ -1000,71 +1000,6 @@ pub fn parse_prompt_too_long(message: &str) -> Option<u64> {
         .ok()
 }
 
-/// Request body for the token counting API.
-/// Similar to `Request` but without `max_tokens` since it's not needed for counting.
-#[derive(Debug, Serialize)]
-pub struct CountTokensRequest {
-    pub model: String,
-    pub messages: Vec<Message>,
-    #[serde(default, skip_serializing_if = "Option::is_none")]
-    pub system: Option<StringOrContents>,
-    #[serde(default, skip_serializing_if = "Vec::is_empty")]
-    pub tools: Vec<Tool>,
-    #[serde(default, skip_serializing_if = "Option::is_none")]
-    pub thinking: Option<Thinking>,
-    #[serde(default, skip_serializing_if = "Option::is_none")]
-    pub tool_choice: Option<ToolChoice>,
-}
-
-/// Response from the token counting API.
-#[derive(Debug, Deserialize)]
-pub struct CountTokensResponse {
-    pub input_tokens: u64,
-}
-
-/// Count the number of tokens in a message without creating it.
-pub async fn count_tokens(
-    client: &dyn HttpClient,
-    api_url: &str,
-    api_key: &str,
-    request: CountTokensRequest,
-) -> Result<CountTokensResponse, AnthropicError> {
-    let uri = format!("{api_url}/v1/messages/count_tokens");
-
-    let request_builder = HttpRequest::builder()
-        .method(Method::POST)
-        .uri(uri)
-        .header("Anthropic-Version", "2023-06-01")
-        .header("X-Api-Key", api_key.trim())
-        .header("Content-Type", "application/json");
-
-    let serialized_request =
-        serde_json::to_string(&request).map_err(AnthropicError::SerializeRequest)?;
-    let http_request = request_builder
-        .body(AsyncBody::from(serialized_request))
-        .map_err(AnthropicError::BuildRequestBody)?;
-
-    let mut response = client
-        .send(http_request)
-        .await
-        .map_err(AnthropicError::HttpSend)?;
-
-    let rate_limits = RateLimitInfo::from_headers(response.headers());
-
-    if response.status().is_success() {
-        let mut body = String::new();
-        response
-            .body_mut()
-            .read_to_string(&mut body)
-            .await
-            .map_err(AnthropicError::ReadResponse)?;
-
-        serde_json::from_str(&body).map_err(AnthropicError::DeserializeResponse)
-    } else {
-        Err(handle_error_response(response, rate_limits).await)
-    }
-}
-
 // -- Conversions from/to `language_model_core` types --
 
 impl From for Speed {
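Note: the helper deleted above was a thin wrapper over Anthropic's public token-counting route: POST {api_url}/v1/messages/count_tokens with `Anthropic-Version` and `X-Api-Key` headers, a body shaped like `Request` minus `max_tokens`, and a reply carrying a single `input_tokens` field. For downstream code that still needs the call, here is a minimal sketch of the wire format only; the struct names are illustrative, the sketch assumes serde (with the derive feature) and serde_json, and none of it is part of this patch:

use serde::{Deserialize, Serialize};

// Simplified stand-in for the crate's richer Message type; the real body
// also accepts optional `system`, `tools`, `thinking`, and `tool_choice`.
#[derive(Serialize)]
struct CountTokensBody {
    model: String,
    messages: Vec<SimpleMessage>,
}

#[derive(Serialize)]
struct SimpleMessage {
    role: String, // "user" or "assistant"
    content: String,
}

#[derive(Debug, Deserialize)]
struct CountTokensReply {
    input_tokens: u64,
}

fn main() -> serde_json::Result<()> {
    let body = CountTokensBody {
        model: "claude-sonnet-4-5".into(),
        messages: vec![SimpleMessage {
            role: "user".into(),
            content: "Hello, world".into(),
        }],
    };
    // This JSON is what the deleted helper POSTed to /v1/messages/count_tokens.
    println!("{}", serde_json::to_string_pretty(&body)?);
    // And this is the entire shape of the success response.
    let reply: CountTokensReply = serde_json::from_str(r#"{"input_tokens":42}"#)?;
    println!("{reply:?}");
    Ok(())
}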
diff --git a/crates/anthropic/src/completion.rs b/crates/anthropic/src/completion.rs
index a6175a4f7c24b3b724734b2edef48ef8acfaa159..16bb5012e583f44d34031d4bfe58e7680a90f289 100644
--- a/crates/anthropic/src/completion.rs
+++ b/crates/anthropic/src/completion.rs
@@ -11,9 +11,9 @@ use std::pin::Pin;
 use std::str::FromStr;
 
 use crate::{
-    AnthropicError, AnthropicModelMode, CacheControl, CacheControlType, ContentDelta,
-    CountTokensRequest, Event, ImageSource, Message, RequestContent, ResponseContent,
-    StringOrContents, Thinking, Tool, ToolChoice, ToolResultContent, ToolResultPart, Usage,
+    AnthropicError, AnthropicModelMode, CacheControl, CacheControlType, ContentDelta, Event,
+    ImageSource, Message, RequestContent, ResponseContent, StringOrContents, Thinking, Tool,
+    ToolChoice, ToolResultContent, ToolResultPart, Usage,
 };
 
 fn to_anthropic_content(content: MessageContent) -> Option<RequestContent> {
@@ -92,152 +92,6 @@ fn to_anthropic_content(content: MessageContent) -> Option<RequestContent> {
     }
 }
 
-/// Convert a LanguageModelRequest to an Anthropic CountTokensRequest.
-pub fn into_anthropic_count_tokens_request(
-    request: LanguageModelRequest,
-    model: String,
-    mode: AnthropicModelMode,
-) -> CountTokensRequest {
-    let mut new_messages: Vec<Message> = Vec::new();
-    let mut system_message = String::new();
-
-    for message in request.messages {
-        if message.contents_empty() {
-            continue;
-        }
-
-        match message.role {
-            Role::User | Role::Assistant => {
-                let anthropic_message_content: Vec<RequestContent> = message
-                    .content
-                    .into_iter()
-                    .filter_map(to_anthropic_content)
-                    .collect();
-                let anthropic_role = match message.role {
-                    Role::User => crate::Role::User,
-                    Role::Assistant => crate::Role::Assistant,
-                    Role::System => unreachable!("System role should never occur here"),
-                };
-                if anthropic_message_content.is_empty() {
-                    continue;
-                }
-
-                if let Some(last_message) = new_messages.last_mut()
-                    && last_message.role == anthropic_role
-                {
-                    last_message.content.extend(anthropic_message_content);
-                    continue;
-                }
-
-                new_messages.push(Message {
-                    role: anthropic_role,
-                    content: anthropic_message_content,
-                });
-            }
-            Role::System => {
-                if !system_message.is_empty() {
-                    system_message.push_str("\n\n");
-                }
-                system_message.push_str(&message.string_contents());
-            }
-        }
-    }
-
-    CountTokensRequest {
-        model,
-        messages: new_messages,
-        system: if system_message.is_empty() {
-            None
-        } else {
-            Some(StringOrContents::String(system_message))
-        },
-        thinking: if request.thinking_allowed {
-            match mode {
-                AnthropicModelMode::Thinking { budget_tokens } => {
-                    Some(Thinking::Enabled { budget_tokens })
-                }
-                AnthropicModelMode::AdaptiveThinking => Some(Thinking::Adaptive),
-                AnthropicModelMode::Default => None,
-            }
-        } else {
-            None
-        },
-        tools: request
-            .tools
-            .into_iter()
-            .map(|tool| Tool {
-                name: tool.name,
-                description: tool.description,
-                input_schema: tool.input_schema,
-                eager_input_streaming: tool.use_input_streaming,
-            })
-            .collect(),
-        tool_choice: request.tool_choice.map(|choice| match choice {
-            LanguageModelToolChoice::Auto => ToolChoice::Auto,
-            LanguageModelToolChoice::Any => ToolChoice::Any,
-            LanguageModelToolChoice::None => ToolChoice::None,
-        }),
-    }
-}
-
-/// Estimate tokens using tiktoken. Used as a fallback when the API is unavailable,
-/// or by providers (like Zed Cloud) that don't have direct Anthropic API access.
-pub fn count_anthropic_tokens_with_tiktoken(request: LanguageModelRequest) -> Result<u64> {
-    let messages = request.messages;
-    let mut tokens_from_images = 0;
-    let mut string_messages = Vec::with_capacity(messages.len());
-
-    for message in messages {
-        let mut string_contents = String::new();
-
-        for content in message.content {
-            match content {
-                MessageContent::Text(text) => {
-                    string_contents.push_str(&text);
-                }
-                MessageContent::Thinking { .. } => {
-                    // Thinking blocks are not included in the input token count.
-                }
-                MessageContent::RedactedThinking(_) => {
-                    // Thinking blocks are not included in the input token count.
-                }
-                MessageContent::Image(image) => {
-                    tokens_from_images += image.estimate_tokens();
-                }
-                MessageContent::ToolUse(_tool_use) => {
-                    // TODO: Estimate token usage from tool uses.
-                }
-                MessageContent::ToolResult(tool_result) => match &tool_result.content {
-                    LanguageModelToolResultContent::Text(text) => {
-                        string_contents.push_str(text);
-                    }
-                    LanguageModelToolResultContent::Image(image) => {
-                        tokens_from_images += image.estimate_tokens();
-                    }
-                },
-            }
-        }
-
-        if !string_contents.is_empty() {
-            string_messages.push(tiktoken_rs::ChatCompletionRequestMessage {
-                role: match message.role {
-                    Role::User => "user".into(),
-                    Role::Assistant => "assistant".into(),
-                    Role::System => "system".into(),
-                },
-                content: Some(string_contents),
-                name: None,
-                function_call: None,
-            });
-        }
-    }
-
-    // Tiktoken doesn't yet support these models, so we manually use the
-    // same tokenizer as GPT-4.
-    tiktoken_rs::num_tokens_from_messages("gpt-4", &string_messages)
-        .map(|tokens| (tokens + tokens_from_images) as u64)
-}
-
 pub fn into_anthropic(
     request: LanguageModelRequest,
     model: String,
diff --git a/crates/cloud_llm_client/src/cloud_llm_client.rs b/crates/cloud_llm_client/src/cloud_llm_client.rs
index d5b3af394ea96b4269c13ab3259dd5a99ec6cf49..8d1dbeb4394cbbf52f3d47d94bbc2f7fcc75e5db 100644
--- a/crates/cloud_llm_client/src/cloud_llm_client.rs
+++ b/crates/cloud_llm_client/src/cloud_llm_client.rs
@@ -268,18 +268,6 @@ pub struct WebSearchResult {
     pub text: String,
 }
 
-#[derive(Serialize, Deserialize)]
-pub struct CountTokensBody {
-    pub provider: LanguageModelProvider,
-    pub model: String,
-    pub provider_request: serde_json::Value,
-}
-
-#[derive(Serialize, Deserialize)]
-pub struct CountTokensResponse {
-    pub tokens: usize,
-}
-
 #[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
 pub struct LanguageModelId(pub Arc<str>);
diff --git a/crates/google_ai/Cargo.toml b/crates/google_ai/Cargo.toml
index d91d28851997723835ba85be343a453918301c71..3848ed5f87514dfb70fb75d70e636c92ad5c4550 100644
--- a/crates/google_ai/Cargo.toml
+++ b/crates/google_ai/Cargo.toml
@@ -24,4 +24,3 @@ schemars = { workspace = true, optional = true }
 serde.workspace = true
 serde_json.workspace = true
 strum.workspace = true
-tiktoken-rs.workspace = true
diff --git a/crates/google_ai/src/completion.rs b/crates/google_ai/src/completion.rs
index 3a15fdaa0187e52cb82dc8c71b5b861eb797f1a8..efbd1dc9ff731fbc12c5c01b697f8d944cd3e195 100644
--- a/crates/google_ai/src/completion.rs
+++ b/crates/google_ai/src/completion.rs
@@ -313,29 +313,6 @@ impl GoogleEventMapper {
     }
 }
 
-/// Count tokens for a Google AI model using tiktoken. This is synchronous;
-/// callers should spawn it on a background thread if needed.
-pub fn count_google_tokens(request: LanguageModelRequest) -> Result<u64> {
-    let messages = request
-        .messages
-        .into_iter()
-        .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
-            role: match message.role {
-                Role::User => "user".into(),
-                Role::Assistant => "assistant".into(),
-                Role::System => "system".into(),
-            },
-            content: Some(message.string_contents()),
-            name: None,
-            function_call: None,
-        })
-        .collect::<Vec<_>>();
-
-    // Tiktoken doesn't yet support these models, so we manually use the
-    // same tokenizer as GPT-4.
-    tiktoken_rs::num_tokens_from_messages("gpt-4", &messages).map(|tokens| tokens as u64)
-}
-
 fn update_usage(usage: &mut UsageMetadata, new: &UsageMetadata) {
     if let Some(prompt_token_count) = new.prompt_token_count {
         usage.prompt_token_count = Some(prompt_token_count);
diff --git a/crates/google_ai/src/google_ai.rs b/crates/google_ai/src/google_ai.rs
index 7917eb45c6292d05ede5267ba669a942348e575a..1461197bd972121685f2ff45d9cb1cec56492c84 100644
--- a/crates/google_ai/src/google_ai.rs
+++ b/crates/google_ai/src/google_ai.rs
@@ -64,38 +64,6 @@ pub async fn stream_generate_content(
     }
 }
 
-pub async fn count_tokens(
-    client: &dyn HttpClient,
-    api_url: &str,
-    api_key: &str,
-    request: CountTokensRequest,
-) -> Result<CountTokensResponse> {
-    validate_generate_content_request(&request.generate_content_request)?;
-
-    let uri = format!(
-        "{api_url}/v1beta/models/{model_id}:countTokens?key={api_key}",
-        model_id = &request.generate_content_request.model.model_id,
-    );
-
-    let request = serde_json::to_string(&request)?;
-    let request_builder = HttpRequest::builder()
-        .method(Method::POST)
-        .uri(&uri)
-        .header("Content-Type", "application/json");
-    let http_request = request_builder.body(AsyncBody::from(request))?;
-
-    let mut response = client.send(http_request).await?;
-    let mut text = String::new();
-    response.body_mut().read_to_string(&mut text).await?;
-    anyhow::ensure!(
-        response.status().is_success(),
-        "error during countTokens, status code: {:?}, body: {}",
-        response.status(),
-        text
-    );
-    Ok(serde_json::from_str::<CountTokensResponse>(&text)?)
-}
-
 pub fn validate_generate_content_request(request: &GenerateContentRequest) -> Result<()> {
     if request.model.is_empty() {
         bail!("Model must be specified");
@@ -123,8 +91,6 @@ pub enum Task {
     GenerateContent,
     #[serde(rename = "streamGenerateContent")]
     StreamGenerateContent,
-    #[serde(rename = "countTokens")]
-    CountTokens,
     #[serde(rename = "embedContent")]
     EmbedContent,
     #[serde(rename = "batchEmbedContents")]
@@ -382,18 +348,6 @@ pub struct SafetyRating {
     pub probability: HarmProbability,
 }
 
-#[derive(Debug, Serialize, Deserialize)]
-#[serde(rename_all = "camelCase")]
-pub struct CountTokensRequest {
-    pub generate_content_request: GenerateContentRequest,
-}
-
-#[derive(Debug, Serialize, Deserialize)]
-#[serde(rename_all = "camelCase")]
-pub struct CountTokensResponse {
-    pub total_tokens: u64,
-}
-
 #[derive(Debug, Serialize, Deserialize)]
 pub struct FunctionCall {
     pub name: String,
diff --git a/crates/language_model/src/fake_provider.rs b/crates/language_model/src/fake_provider.rs
index cee65c21e575e7c96579c271805386527a29d4da..4466a3f2762b033c869afda984d1aa453068f00e 100644
--- a/crates/language_model/src/fake_provider.rs
+++ b/crates/language_model/src/fake_provider.rs
@@ -299,10 +299,6 @@ impl LanguageModel for FakeLanguageModel {
         1000000
     }
 
-    fn count_tokens(&self, _: LanguageModelRequest, _: &App) -> BoxFuture<'static, Result<u64>> {
-        futures::future::ready(Ok(0)).boxed()
-    }
-
     fn stream_completion(
         &self,
         request: LanguageModelRequest,
diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs
index 60e8228fec52ffee763e19541f042ce47246dad2..4f7372777d7c1b1f2e4835426b30440f877eb097 100644
--- a/crates/language_model/src/language_model.rs
+++ b/crates/language_model/src/language_model.rs
@@ -121,12 +121,6 @@ pub trait LanguageModel: Send + Sync {
         None
     }
 
-    fn count_tokens(
-        &self,
-        request: LanguageModelRequest,
-        cx: &App,
-    ) -> BoxFuture<'static, Result<u64>>;
-
     fn stream_completion(
         &self,
         request: LanguageModelRequest,
diff --git a/crates/language_models/Cargo.toml b/crates/language_models/Cargo.toml
index 60670114529b07dca78202cc438ff5e243acaeee..f5828fa28d7064817374409ead1155a2c5809dca 100644
--- a/crates/language_models/Cargo.toml
+++ b/crates/language_models/Cargo.toml
@@ -57,7 +57,6 @@ serde_json.workspace = true
 settings.workspace = true
 smol.workspace = true
 strum.workspace = true
-tiktoken-rs.workspace = true
 tokio = { workspace = true, features = ["rt", "rt-multi-thread"] }
 ui.workspace = true
 ui_input.workspace = true
diff --git a/crates/language_models/src/provider/anthropic.rs b/crates/language_models/src/provider/anthropic.rs
index 3d2b763a6e4a97a3a819505a5d81a6e9330fbb9c..af5e53300a785b2390e4ea8c6571d46447c4333a 100644
--- a/crates/language_models/src/provider/anthropic.rs
+++ b/crates/language_models/src/provider/anthropic.rs
@@ -22,10 +22,7 @@ use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*};
 use ui_input::InputField;
 use util::ResultExt;
 
-pub use anthropic::completion::{
-    AnthropicEventMapper, count_anthropic_tokens_with_tiktoken, into_anthropic,
-    into_anthropic_count_tokens_request,
-};
+pub use anthropic::completion::{AnthropicEventMapper, into_anthropic};
 pub use settings::AnthropicAvailableModel as AvailableModel;
 
 const PROVIDER_ID: LanguageModelProviderId = ANTHROPIC_PROVIDER_ID;
@@ -378,52 +375,6 @@ impl LanguageModel for AnthropicModel {
         Some(self.model.max_output_tokens())
     }
 
-    fn count_tokens(
-        &self,
-        request: LanguageModelRequest,
-        cx: &App,
-    ) -> BoxFuture<'static, Result<u64>> {
-        let http_client = self.http_client.clone();
-        let model_id = self.model.request_id().to_string();
-        let mode = self.model.mode();
-
-        let (api_key, api_url) = self.state.read_with(cx, |state, cx| {
-            let api_url = AnthropicLanguageModelProvider::api_url(cx);
-            (
-                state.api_key_state.key(&api_url).map(|k| k.to_string()),
-                api_url.to_string(),
-            )
-        });
-
-        let background = cx.background_executor().clone();
-        async move {
-            // If no API key, fall back to tiktoken estimation
-            let Some(api_key) = api_key else {
-                return background
-                    .spawn(async move { count_anthropic_tokens_with_tiktoken(request) })
-                    .await;
-            };
-
-            let count_request =
-                into_anthropic_count_tokens_request(request.clone(), model_id, mode);
-
-            match anthropic::count_tokens(http_client.as_ref(), &api_url, &api_key, count_request)
-                .await
-            {
-                Ok(response) => Ok(response.input_tokens),
-                Err(err) => {
-                    log::error!(
-                        "Anthropic count_tokens API failed, falling back to tiktoken: {err:?}"
-                    );
-                    background
-                        .spawn(async move { count_anthropic_tokens_with_tiktoken(request) })
-                        .await
-                }
-            }
-        }
-        .boxed()
-    }
-
     fn stream_completion(
         &self,
         request: LanguageModelRequest,
diff --git a/crates/language_models/src/provider/bedrock.rs b/crates/language_models/src/provider/bedrock.rs
index 80c758769cd990c00f5942433143bf6fb2216b7c..1069ad80fc02493c54eb2cf22d9b85e98708ac55 100644
--- a/crates/language_models/src/provider/bedrock.rs
+++ b/crates/language_models/src/provider/bedrock.rs
@@ -706,14 +706,6 @@ impl LanguageModel for BedrockModel {
         Some(self.model.max_output_tokens())
     }
 
-    fn count_tokens(
-        &self,
-        request: LanguageModelRequest,
-        cx: &App,
-    ) -> BoxFuture<'static, Result<u64>> {
-        get_bedrock_tokens(request, cx)
-    }
-
     fn stream_completion(
         &self,
         request: LanguageModelRequest,
@@ -1151,68 +1143,6 @@ pub fn into_bedrock(
     })
 }
 
-// TODO: just call the ConverseOutput.usage() method:
-// https://docs.rs/aws-sdk-bedrockruntime/latest/aws_sdk_bedrockruntime/operation/converse/struct.ConverseOutput.html#method.output
-pub fn get_bedrock_tokens(
-    request: LanguageModelRequest,
-    cx: &App,
-) -> BoxFuture<'static, Result<u64>> {
-    cx.background_executor()
-        .spawn(async move {
-            let messages = request.messages;
-            let mut tokens_from_images = 0;
-            let mut string_messages = Vec::with_capacity(messages.len());
-
-            for message in messages {
-                use language_model::MessageContent;
-
-                let mut string_contents = String::new();
-
-                for content in message.content {
-                    match content {
-                        MessageContent::Text(text) | MessageContent::Thinking { text, .. } => {
-                            string_contents.push_str(&text);
-                        }
-                        MessageContent::RedactedThinking(_) => {}
-                        MessageContent::Image(image) => {
-                            tokens_from_images += image.estimate_tokens();
-                        }
-                        MessageContent::ToolUse(_tool_use) => {
-                            // TODO: Estimate token usage from tool uses.
-                        }
-                        MessageContent::ToolResult(tool_result) => match tool_result.content {
-                            LanguageModelToolResultContent::Text(text) => {
-                                string_contents.push_str(&text);
-                            }
-                            LanguageModelToolResultContent::Image(image) => {
-                                tokens_from_images += image.estimate_tokens();
-                            }
-                        },
-                    }
-                }
-
-                if !string_contents.is_empty() {
-                    string_messages.push(tiktoken_rs::ChatCompletionRequestMessage {
-                        role: match message.role {
-                            Role::User => "user".into(),
-                            Role::Assistant => "assistant".into(),
-                            Role::System => "system".into(),
-                        },
-                        content: Some(string_contents),
-                        name: None,
-                        function_call: None,
-                    });
-                }
-            }
-
-            // Tiktoken doesn't yet support these models, so we manually use the
-            // same tokenizer as GPT-4.
-            tiktoken_rs::num_tokens_from_messages("gpt-4", &string_messages)
-                .map(|tokens| (tokens + tokens_from_images) as u64)
-        })
-        .boxed()
-}
-
 pub fn map_to_language_model_completion_events(
     events: Pin<Box<dyn Send + Stream<Item = Result<ConverseStreamOutput, BedrockError>>>>,
 ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
diff --git a/crates/language_models/src/provider/copilot_chat.rs b/crates/language_models/src/provider/copilot_chat.rs
index 0d7d03c8c754217c664238cfcb51134b36fb9ce5..ef9fbae1131a23d07d3e7934cd335c81603e710b 100644
--- a/crates/language_models/src/provider/copilot_chat.rs
+++ b/crates/language_models/src/provider/copilot_chat.rs
@@ -203,25 +203,6 @@ impl LanguageModelProvider for CopilotChatLanguageModelProvider {
     }
 }
 
-fn collect_tiktoken_messages(
-    request: LanguageModelRequest,
-) -> Vec<tiktoken_rs::ChatCompletionRequestMessage> {
-    request
-        .messages
-        .into_iter()
-        .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
-            role: match message.role {
-                Role::User => "user".into(),
-                Role::Assistant => "assistant".into(),
-                Role::System => "system".into(),
-            },
-            content: Some(message.string_contents()),
-            name: None,
-            function_call: None,
-        })
-        .collect::<Vec<_>>()
-}
-
 pub struct CopilotChatLanguageModel {
     model: CopilotChatModel,
     request_limiter: RateLimiter,
@@ -318,27 +299,6 @@ impl LanguageModel for CopilotChatLanguageModel {
         self.model.max_token_count()
     }
 
-    fn count_tokens(
-        &self,
-        request: LanguageModelRequest,
-        cx: &App,
-    ) -> BoxFuture<'static, Result<u64>> {
-        let model = self.model.clone();
-        cx.background_spawn(async move {
-            let messages = collect_tiktoken_messages(request);
-            // Copilot uses OpenAI tiktoken tokenizer for all it's model irrespective of the underlying provider(vendor).
-            let tokenizer_model = match model.tokenizer() {
-                Some("o200k_base") => "gpt-4o",
-                Some("cl100k_base") => "gpt-4",
-                _ => "gpt-4o",
-            };
-
-            tiktoken_rs::num_tokens_from_messages(tokenizer_model, &messages)
-                .map(|tokens| tokens as u64)
-        })
-        .boxed()
-    }
-
     fn stream_completion(
         &self,
         request: LanguageModelRequest,
diff --git a/crates/language_models/src/provider/deepseek.rs b/crates/language_models/src/provider/deepseek.rs
index f3dccd5cc1a2e1a5ddfe2bc6b43901f2b549e532..dfc8521154e17abf04e91143cfbb0f8e79e9f1eb 100644
--- a/crates/language_models/src/provider/deepseek.rs
+++ b/crates/language_models/src/provider/deepseek.rs
@@ -293,32 +293,6 @@ impl LanguageModel for DeepSeekLanguageModel {
         self.model.max_output_tokens()
     }
 
-    fn count_tokens(
-        &self,
-        request: LanguageModelRequest,
-        cx: &App,
-    ) -> BoxFuture<'static, Result<u64>> {
-        cx.background_spawn(async move {
-            let messages = request
-                .messages
-                .into_iter()
-                .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
-                    role: match message.role {
-                        Role::User => "user".into(),
-                        Role::Assistant => "assistant".into(),
-                        Role::System => "system".into(),
-                    },
-                    content: Some(message.string_contents()),
-                    name: None,
-                    function_call: None,
-                })
-                .collect::<Vec<_>>();
-
-            tiktoken_rs::num_tokens_from_messages("gpt-4", &messages).map(|tokens| tokens as u64)
-        })
-        .boxed()
-    }
-
     fn stream_completion(
         &self,
         request: LanguageModelRequest,
diff --git a/crates/language_models/src/provider/google.rs b/crates/language_models/src/provider/google.rs
index 92278839c6ff5119849f8881409928686f055331..87f2eeb26ab0f8f87da10671f5adc1c36b3426b8 100644
--- a/crates/language_models/src/provider/google.rs
+++ b/crates/language_models/src/provider/google.rs
@@ -2,7 +2,7 @@ use anyhow::{Context as _, Result};
 use collections::BTreeMap;
 use credentials_provider::CredentialsProvider;
 use futures::{FutureExt, StreamExt, future::BoxFuture};
-pub use google_ai::completion::{GoogleEventMapper, count_google_tokens, into_google};
+pub use google_ai::completion::{GoogleEventMapper, into_google};
 use google_ai::{GenerateContentResponse, GoogleModelMode};
 use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
 use http_client::HttpClient;
@@ -327,38 +327,6 @@ impl LanguageModel for GoogleLanguageModel {
         self.model.max_output_tokens()
     }
 
-    fn count_tokens(
-        &self,
-        request: LanguageModelRequest,
-        cx: &App,
-    ) -> BoxFuture<'static, Result<u64>> {
-        let model_id = self.model.request_id().to_string();
-        let request = into_google(request, model_id, self.model.mode());
-        let http_client = self.http_client.clone();
-        let api_url = GoogleLanguageModelProvider::api_url(cx);
-        let api_key = self.state.read(cx).api_key_state.key(&api_url);
-
-        async move {
-            let Some(api_key) = api_key else {
-                return Err(LanguageModelCompletionError::NoApiKey {
-                    provider: PROVIDER_NAME,
-                }
-                .into());
-            };
-            let response = google_ai::count_tokens(
-                http_client.as_ref(),
-                &api_url,
-                &api_key,
-                google_ai::CountTokensRequest {
-                    generate_content_request: request,
-                },
-            )
-            .await?;
-            Ok(response.total_tokens)
-        }
-        .boxed()
-    }
-
     fn stream_completion(
         &self,
         request: LanguageModelRequest,
diff --git a/crates/language_models/src/provider/lmstudio.rs b/crates/language_models/src/provider/lmstudio.rs
index a541da8cd8092d5d0fa43af1217c31833f10cdeb..f035e765f0737dc8f6dce0588fbbf69619902230 100644
--- a/crates/language_models/src/provider/lmstudio.rs
+++ b/crates/language_models/src/provider/lmstudio.rs
@@ -505,22 +505,6 @@ impl LanguageModel for LmStudioLanguageModel {
         self.model.max_token_count()
     }
 
-    fn count_tokens(
-        &self,
-        request: LanguageModelRequest,
-        _cx: &App,
-    ) -> BoxFuture<'static, Result<u64>> {
-        // Endpoint for this is coming soon. In the meantime, hacky estimation
-        let token_count = request
-            .messages
-            .iter()
-            .map(|msg| msg.string_contents().split_whitespace().count())
-            .sum::<usize>();
-
-        let estimated_tokens = (token_count as f64 * 0.75) as u64;
-        async move { Ok(estimated_tokens) }.boxed()
-    }
-
     fn stream_completion(
         &self,
         request: LanguageModelRequest,
diff --git a/crates/language_models/src/provider/mistral.rs b/crates/language_models/src/provider/mistral.rs
index fdb0fb7b3a7f510c8e55deefcea8e3b7f4d1eb86..cce5448b9938e3c642764fd645abbdbe8fa625e2 100644
--- a/crates/language_models/src/provider/mistral.rs
+++ b/crates/language_models/src/provider/mistral.rs
@@ -327,32 +327,6 @@ impl LanguageModel for MistralLanguageModel {
         self.model.max_output_tokens()
     }
 
-    fn count_tokens(
-        &self,
-        request: LanguageModelRequest,
-        cx: &App,
-    ) -> BoxFuture<'static, Result<u64>> {
-        cx.background_spawn(async move {
-            let messages = request
-                .messages
-                .into_iter()
-                .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
-                    role: match message.role {
-                        Role::User => "user".into(),
-                        Role::Assistant => "assistant".into(),
-                        Role::System => "system".into(),
-                    },
-                    content: Some(message.string_contents()),
-                    name: None,
-                    function_call: None,
-                })
-                .collect::<Vec<_>>();
-
-            tiktoken_rs::num_tokens_from_messages("gpt-4", &messages).map(|tokens| tokens as u64)
-        })
-        .boxed()
-    }
-
     fn stream_completion(
         &self,
         request: LanguageModelRequest,
diff --git a/crates/language_models/src/provider/ollama.rs b/crates/language_models/src/provider/ollama.rs
index 49c326683a225bf73f604a584307ea1316a710c4..229b59e2bfded2473feebd970991cdacf2717471 100644
--- a/crates/language_models/src/provider/ollama.rs
+++ b/crates/language_models/src/provider/ollama.rs
@@ -493,23 +493,6 @@ impl LanguageModel for OllamaLanguageModel {
         self.model.max_token_count()
     }
 
-    fn count_tokens(
-        &self,
-        request: LanguageModelRequest,
-        _cx: &App,
-    ) -> BoxFuture<'static, Result<u64>> {
-        // There is no endpoint for this _yet_ in Ollama
-        // see: https://github.com/ollama/ollama/issues/1716 and https://github.com/ollama/ollama/issues/3582
-        let token_count = request
-            .messages
-            .iter()
-            .map(|msg| msg.string_contents().chars().count())
-            .sum::<usize>()
-            / 4;
-
-        async move { Ok(token_count as u64) }.boxed()
-    }
-
     fn stream_completion(
         &self,
         request: LanguageModelRequest,
diff --git a/crates/language_models/src/provider/open_ai.rs b/crates/language_models/src/provider/open_ai.rs
index da341211855ca07bef526740a5a39260a4403982..f5ee65c8d85ff6d142620690c6bcbd0e0c387754 100644
--- a/crates/language_models/src/provider/open_ai.rs
+++ b/crates/language_models/src/provider/open_ai.rs
@@ -25,8 +25,7 @@ use ui_input::InputField;
 use util::ResultExt;
 
 pub use open_ai::completion::{
-    OpenAiEventMapper, OpenAiResponseEventMapper, collect_tiktoken_messages, count_open_ai_tokens,
-    into_open_ai, into_open_ai_response,
+    OpenAiEventMapper, OpenAiResponseEventMapper, into_open_ai, into_open_ai_response,
 };
 
 const PROVIDER_ID: LanguageModelProviderId = OPEN_AI_PROVIDER_ID;
@@ -369,16 +368,6 @@ impl LanguageModel for OpenAiLanguageModel {
         self.model.max_output_tokens()
     }
 
-    fn count_tokens(
-        &self,
-        request: LanguageModelRequest,
-        cx: &App,
-    ) -> BoxFuture<'static, Result<u64>> {
-        let model = self.model.clone();
-        cx.background_spawn(async move { count_open_ai_tokens(request, model) })
-            .boxed()
-    }
-
     fn stream_completion(
         &self,
         request: LanguageModelRequest,
diff --git a/crates/language_models/src/provider/open_ai_compatible.rs b/crates/language_models/src/provider/open_ai_compatible.rs
index b82e3b1ae3aa80373dbad3550e7bd896b8879f2b..5f7f6db3d36a45a6b41893421a59355188b7902b 100644
--- a/crates/language_models/src/provider/open_ai_compatible.rs
+++ b/crates/language_models/src/provider/open_ai_compatible.rs
@@ -360,27 +360,6 @@ impl LanguageModel for OpenAiCompatibleLanguageModel {
         self.model.max_output_tokens
     }
 
-    fn count_tokens(
-        &self,
-        request: LanguageModelRequest,
-        cx: &App,
-    ) -> BoxFuture<'static, Result<u64>> {
-        let max_token_count = self.max_token_count();
-        cx.background_spawn(async move {
-            let messages = super::open_ai::collect_tiktoken_messages(request);
-            let model = if max_token_count >= 100_000 {
-                // If the max tokens is 100k or more, it is likely the o200k_base tokenizer from gpt4o
-                "gpt-4o"
-            } else {
-                // Otherwise fallback to gpt-4, since only cl100k_base and o200k_base are
-                // supported with this tiktoken method
-                "gpt-4"
-            };
-            tiktoken_rs::num_tokens_from_messages(model, &messages).map(|tokens| tokens as u64)
-        })
-        .boxed()
-    }
-
     fn stream_completion(
         &self,
         request: LanguageModelRequest,
diff --git a/crates/language_models/src/provider/open_router.rs b/crates/language_models/src/provider/open_router.rs
index fba3a6938aecf1db80680e014e408e4d59c42ff7..6562d9de085229b4f8b80982b813c811382b45b1 100644
--- a/crates/language_models/src/provider/open_router.rs
+++ b/crates/language_models/src/provider/open_router.rs
@@ -372,14 +372,6 @@ impl LanguageModel for OpenRouterLanguageModel {
         self.model.supports_images.unwrap_or(false)
     }
 
-    fn count_tokens(
-        &self,
-        request: LanguageModelRequest,
-        cx: &App,
-    ) -> BoxFuture<'static, Result<u64>> {
-        count_open_router_tokens(request, self.model.clone(), cx)
-    }
-
     fn stream_completion(
         &self,
         request: LanguageModelRequest,
@@ -741,32 +733,6 @@ struct RawToolCall {
     thought_signature: Option<String>,
 }
 
-pub fn count_open_router_tokens(
-    request: LanguageModelRequest,
-    _model: open_router::Model,
-    cx: &App,
-) -> BoxFuture<'static, Result<u64>> {
-    cx.background_spawn(async move {
-        let messages = request
-            .messages
-            .into_iter()
-            .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
-                role: match message.role {
-                    Role::User => "user".into(),
-                    Role::Assistant => "assistant".into(),
-                    Role::System => "system".into(),
-                },
-                content: Some(message.string_contents()),
-                name: None,
-                function_call: None,
-            })
-            .collect::<Vec<_>>();
-
-        tiktoken_rs::num_tokens_from_messages("gpt-4o", &messages).map(|tokens| tokens as u64)
-    })
-    .boxed()
-}
-
 struct ConfigurationView {
     api_key_editor: Entity<InputField>,
     state: Entity<State>,
diff --git a/crates/language_models/src/provider/opencode.rs b/crates/language_models/src/provider/opencode.rs
index 4754741715f39104f392b6871f68a1c04e3bdfce..4b0f8e5992a22c4bb5049e477419c9b9fa2d6f6a 100644
--- a/crates/language_models/src/provider/opencode.rs
+++ b/crates/language_models/src/provider/opencode.rs
@@ -8,7 +8,7 @@ use language_model::{
     ApiKeyState, AuthenticateError, EnvVar, IconOrSvg, LanguageModel, LanguageModelCompletionError,
     LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
     LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
-    LanguageModelRequest, LanguageModelToolChoice, RateLimiter, Role, env_var,
+    LanguageModelRequest, LanguageModelToolChoice, RateLimiter, env_var,
 };
 use opencode::{ApiProtocol, OPENCODE_API_URL};
 pub use settings::OpenCodeAvailableModel as AvailableModel;
@@ -426,32 +426,6 @@ impl LanguageModel for OpenCodeLanguageModel {
         self.model.max_output_tokens()
     }
 
-    fn count_tokens(
-        &self,
-        request: LanguageModelRequest,
-        cx: &App,
-    ) -> BoxFuture<'static, Result<u64>> {
-        cx.background_spawn(async move {
-            let messages = request
-                .messages
-                .into_iter()
-                .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
-                    role: match message.role {
-                        Role::User => "user".into(),
-                        Role::Assistant => "assistant".into(),
-                        Role::System => "system".into(),
-                    },
-                    content: Some(message.string_contents()),
-                    name: None,
-                    function_call: None,
-                })
-                .collect::<Vec<_>>();
-
-            tiktoken_rs::num_tokens_from_messages("gpt-4o", &messages).map(|tokens| tokens as u64)
-        })
-        .boxed()
-    }
-
     fn stream_completion(
         &self,
         request: LanguageModelRequest,
diff --git a/crates/language_models/src/provider/vercel.rs b/crates/language_models/src/provider/vercel.rs
index ce9870073ee9f399e9f02a5a093931c1a4304fdb..188cb6d0322d36d120f2faf70a3a4d3b33997512 100644
--- a/crates/language_models/src/provider/vercel.rs
+++ b/crates/language_models/src/provider/vercel.rs
@@ -8,7 +8,7 @@ use language_model::{
     ApiKeyState, AuthenticateError, EnvVar, IconOrSvg, LanguageModel, LanguageModelCompletionError,
     LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
     LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
-    LanguageModelRequest, LanguageModelToolChoice, RateLimiter, Role, env_var,
+    LanguageModelRequest, LanguageModelToolChoice, RateLimiter, env_var,
 };
 use open_ai::ResponseStreamEvent;
 pub use settings::VercelAvailableModel as AvailableModel;
@@ -18,7 +18,7 @@ use strum::IntoEnumIterator;
 use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*};
 use ui_input::InputField;
 use util::ResultExt;
-use vercel::{Model, VERCEL_API_URL};
+use vercel::VERCEL_API_URL;
 
 const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("vercel");
 const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Vercel");
@@ -295,14 +295,6 @@ impl LanguageModel for VercelLanguageModel {
         self.model.max_output_tokens()
     }
 
-    fn count_tokens(
-        &self,
-        request: LanguageModelRequest,
-        cx: &App,
-    ) -> BoxFuture<'static, Result<u64>> {
-        count_vercel_tokens(request, self.model.clone(), cx)
-    }
-
     fn stream_completion(
         &self,
         request: LanguageModelRequest,
@@ -335,51 +327,6 @@ impl LanguageModel for VercelLanguageModel {
     }
 }
 
-pub fn count_vercel_tokens(
-    request: LanguageModelRequest,
-    model: Model,
-    cx: &App,
-) -> BoxFuture<'static, Result<u64>> {
-    cx.background_spawn(async move {
-        let messages = request
-            .messages
-            .into_iter()
-            .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
-                role: match message.role {
-                    Role::User => "user".into(),
-                    Role::Assistant => "assistant".into(),
-                    Role::System => "system".into(),
-                },
-                content: Some(message.string_contents()),
-                name: None,
-                function_call: None,
-            })
-            .collect::<Vec<_>>();
-
-        match model {
-            Model::Custom { max_tokens, .. } => {
-                let model = if max_tokens >= 100_000 {
-                    // If the max tokens is 100k or more, it is likely the o200k_base tokenizer from gpt4o
-                    "gpt-4o"
-                } else {
-                    // Otherwise fallback to gpt-4, since only cl100k_base and o200k_base are
-                    // supported with this tiktoken method
-                    "gpt-4"
-                };
-                tiktoken_rs::num_tokens_from_messages(model, &messages)
-            }
-            // Map Vercel models to appropriate OpenAI models for token counting
-            // since Vercel uses OpenAI-compatible API
-            Model::VZeroOnePointFiveMedium => {
-                // Vercel v0 is similar to GPT-4o, so use gpt-4o for token counting
-                tiktoken_rs::num_tokens_from_messages("gpt-4o", &messages)
-            }
-        }
-        .map(|tokens| tokens as u64)
-    })
-    .boxed()
-}
-
 struct ConfigurationView {
     api_key_editor: Entity<InputField>,
     state: Entity<State>,
diff --git a/crates/language_models/src/provider/vercel_ai_gateway.rs b/crates/language_models/src/provider/vercel_ai_gateway.rs
index cf379e6edc1db181127cc284834b19c61143d692..789e8e35e8546a3a7493a4fbfe500f4baef04d85 100644
--- a/crates/language_models/src/provider/vercel_ai_gateway.rs
+++ b/crates/language_models/src/provider/vercel_ai_gateway.rs
@@ -422,24 +422,6 @@ impl LanguageModel for VercelAiGatewayLanguageModel {
         self.model.max_output_tokens
     }
 
-    fn count_tokens(
-        &self,
-        request: LanguageModelRequest,
-        cx: &App,
-    ) -> BoxFuture<'static, Result<u64>> {
-        let max_token_count = self.max_token_count();
-        cx.background_spawn(async move {
-            let messages = crate::provider::open_ai::collect_tiktoken_messages(request);
-            let model = if max_token_count >= 100_000 {
-                "gpt-4o"
-            } else {
-                "gpt-4"
-            };
-            tiktoken_rs::num_tokens_from_messages(model, &messages).map(|tokens| tokens as u64)
-        })
-        .boxed()
-    }
-
     fn stream_completion(
         &self,
         request: LanguageModelRequest,
diff --git a/crates/language_models/src/provider/x_ai.rs b/crates/language_models/src/provider/x_ai.rs
index b14cb82bf83ac773a58104207d1594b36994a91c..12f195417b5220afa9cd4e5b33211d40b39e14c3 100644
--- a/crates/language_models/src/provider/x_ai.rs
+++ b/crates/language_models/src/provider/x_ai.rs
@@ -20,7 +20,6 @@ use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*};
 use ui_input::InputField;
 use util::ResultExt;
 use x_ai::XAI_API_URL;
-pub use x_ai::completion::count_xai_tokens;
 
 const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("x_ai");
 const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("xAI");
@@ -316,16 +315,6 @@ impl LanguageModel for XAiLanguageModel {
         true
     }
 
-    fn count_tokens(
-        &self,
-        request: LanguageModelRequest,
-        cx: &App,
-    ) -> BoxFuture<'static, Result<u64>> {
-        let model = self.model.clone();
-        cx.background_spawn(async move { count_xai_tokens(request, model) })
-            .boxed()
-    }
-
     fn stream_completion(
         &self,
         request: LanguageModelRequest,
diff --git a/crates/language_models_cloud/Cargo.toml b/crates/language_models_cloud/Cargo.toml
index b08acc5ecd5c2a718e936378c2dbfbc3d1c32df0..de82fdfa627829121183793de42ce620a72c4682 100644
--- a/crates/language_models_cloud/Cargo.toml
+++ b/crates/language_models_cloud/Cargo.toml
@@ -27,7 +27,6 @@ serde.workspace = true
 serde_json.workspace = true
 smol.workspace = true
 thiserror.workspace = true
-x_ai = { workspace = true, features = ["schemars"] }
 
 [dev-dependencies]
 language_model = { workspace = true, features = ["test-support"] }
diff --git a/crates/language_models_cloud/src/language_models_cloud.rs b/crates/language_models_cloud/src/language_models_cloud.rs
index 4e444def7b4df6295a5f12ebccb08802abdfca4d..adae72068c508ea853281755f258ff8e5432b405 100644
--- a/crates/language_models_cloud/src/language_models_cloud.rs
+++ b/crates/language_models_cloud/src/language_models_cloud.rs
@@ -3,9 +3,8 @@ use anyhow::{Context as _, Result, anyhow};
 use cloud_llm_client::{
     CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME,
     CLIENT_SUPPORTS_X_AI_HEADER_NAME, CompletionBody, CompletionEvent, CompletionRequestStatus,
-    CountTokensBody, CountTokensResponse, EXPIRED_LLM_TOKEN_HEADER_NAME, ListModelsResponse,
-    OUTDATED_LLM_TOKEN_HEADER_NAME, SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME,
-    ZED_VERSION_HEADER_NAME,
+    EXPIRED_LLM_TOKEN_HEADER_NAME, ListModelsResponse, OUTDATED_LLM_TOKEN_HEADER_NAME,
+    SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, ZED_VERSION_HEADER_NAME,
 };
 use futures::{
     AsyncBufReadExt, FutureExt, Stream, StreamExt,
@@ -13,7 +12,7 @@ use futures::{
     stream::{self, BoxStream},
 };
 use google_ai::GoogleModelMode;
-use gpui::{App, AppContext, AsyncApp, Context, Task};
+use gpui::{AppContext, AsyncApp, Context, Task};
 use http_client::http::{HeaderMap, HeaderValue};
 use http_client::{
     AsyncBody, HttpClient, HttpClientWithUrl, HttpRequestExt, Method, Response, StatusCode,
@@ -40,15 +39,11 @@ use std::task::Poll;
 use std::time::Duration;
 use thiserror::Error;
 
-use anthropic::completion::{
-    AnthropicEventMapper, count_anthropic_tokens_with_tiktoken, into_anthropic,
-};
+use anthropic::completion::{AnthropicEventMapper, into_anthropic};
 use google_ai::completion::{GoogleEventMapper, into_google};
 use open_ai::completion::{
-    OpenAiEventMapper, OpenAiResponseEventMapper, count_open_ai_tokens, into_open_ai,
-    into_open_ai_response,
+    OpenAiEventMapper, OpenAiResponseEventMapper, into_open_ai, into_open_ai_response,
 };
-use x_ai::completion::count_xai_tokens;
 
 const PROVIDER_ID: LanguageModelProviderId = ZED_CLOUD_PROVIDER_ID;
 const PROVIDER_NAME: LanguageModelProviderName = ZED_CLOUD_PROVIDER_NAME;
@@ -374,85 +369,6 @@ impl LanguageModel for CloudLanguageModel {
-    fn count_tokens(
-        &self,
-        request: LanguageModelRequest,
-        cx: &App,
-    ) -> BoxFuture<'static, Result<u64>> {
-        match self.model.provider {
-            cloud_llm_client::LanguageModelProvider::Anthropic => cx
-                .background_spawn(async move { count_anthropic_tokens_with_tiktoken(request) })
-                .boxed(),
-            cloud_llm_client::LanguageModelProvider::OpenAi => {
-                let model = match open_ai::Model::from_id(&self.model.id.0) {
-                    Ok(model) => model,
-                    Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
-                };
-                cx.background_spawn(async move { count_open_ai_tokens(request, model) })
-                    .boxed()
-            }
-            cloud_llm_client::LanguageModelProvider::XAi => {
-                let model = match x_ai::Model::from_id(&self.model.id.0) {
-                    Ok(model) => model,
-                    Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
-                };
-                cx.background_spawn(async move { count_xai_tokens(request, model) })
-                    .boxed()
-            }
-            cloud_llm_client::LanguageModelProvider::Google => {
-                let http_client = self.http_client.clone();
-                let token_provider = self.token_provider.clone();
-                let model_id = self.model.id.to_string();
-                let generate_content_request =
-                    into_google(request, model_id.clone(), GoogleModelMode::Default);
-                let auth_context = token_provider.auth_context(cx);
-                async move {
-                    let token = token_provider.acquire_token(auth_context).await?;
-
-                    let request_body = CountTokensBody {
-                        provider: cloud_llm_client::LanguageModelProvider::Google,
-                        model: model_id,
-                        provider_request: serde_json::to_value(&google_ai::CountTokensRequest {
-                            generate_content_request,
-                        })?,
-                    };
-                    let request = http_client::Request::builder()
-                        .method(Method::POST)
-                        .uri(
-                            http_client
-                                .build_zed_llm_url("/count_tokens", &[])?
-                                .as_ref(),
-                        )
-                        .header("Content-Type", "application/json")
-                        .header("Authorization", format!("Bearer {token}"))
-                        .body(serde_json::to_string(&request_body)?.into())?;
-                    let mut response = http_client.send(request).await?;
-                    let status = response.status();
-                    let headers = response.headers().clone();
-                    let mut response_body = String::new();
-                    response
-                        .body_mut()
-                        .read_to_string(&mut response_body)
-                        .await?;
-
-                    if status.is_success() {
-                        let response_body: CountTokensResponse =
-                            serde_json::from_str(&response_body)?;
-
-                        Ok(response_body.tokens as u64)
-                    } else {
-                        Err(anyhow!(ApiError {
-                            status,
-                            body: response_body,
-                            headers
-                        }))
-                    }
-                }
-                .boxed()
-            }
-        }
-    }
-
     fn stream_completion(
         &self,
         request: LanguageModelRequest,
diff --git a/crates/open_ai/Cargo.toml b/crates/open_ai/Cargo.toml
index 9a73e73196fa225691fa68e2ca839a19783bc3ca..5083e97c56014708d73ecaefd9463aae6189a8ac 100644
--- a/crates/open_ai/Cargo.toml
+++ b/crates/open_ai/Cargo.toml
@@ -28,7 +28,6 @@ serde.workspace = true
 serde_json.workspace = true
 strum.workspace = true
 thiserror.workspace = true
-tiktoken-rs.workspace = true
 
 [dev-dependencies]
 pretty_assertions.workspace = true
diff --git a/crates/open_ai/src/completion.rs b/crates/open_ai/src/completion.rs
index e37f57dd8b2efd8b1332c313bdb51713cdd142e6..3068f57f582db16182e4c4a78e10fd632bc2815f 100644
--- a/crates/open_ai/src/completion.rs
+++ b/crates/open_ai/src/completion.rs
@@ -18,7 +18,7 @@ use crate::responses::{
     StreamEvent as ResponsesStreamEvent,
 };
 use crate::{
-    FunctionContent, FunctionDefinition, ImageUrl, MessagePart, Model, ReasoningEffort,
+    FunctionContent, FunctionDefinition, ImageUrl, MessagePart, ReasoningEffort,
     ResponseStreamEvent, ToolCall, ToolCallContent,
 };
@@ -818,68 +818,6 @@ fn token_usage_from_response_usage(usage: &ResponsesUsage) -> TokenUsage {
     }
 }
 
-pub fn collect_tiktoken_messages(
-    request: LanguageModelRequest,
-) -> Vec<tiktoken_rs::ChatCompletionRequestMessage> {
-    request
-        .messages
-        .into_iter()
-        .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
-            role: match message.role {
-                Role::User => "user".into(),
-                Role::Assistant => "assistant".into(),
-                Role::System => "system".into(),
-            },
-            content: Some(message.string_contents()),
-            name: None,
-            function_call: None,
-        })
-        .collect::<Vec<_>>()
-}
-
-/// Count tokens for an OpenAI model. This is synchronous; callers should spawn
-/// it on a background thread if needed.
-pub fn count_open_ai_tokens(request: LanguageModelRequest, model: Model) -> Result<u64> {
-    let messages = collect_tiktoken_messages(request);
-    match model {
-        Model::Custom { max_tokens, .. } => {
-            let model = if max_tokens >= 100_000 {
-                // If the max tokens is 100k or more, it likely uses the o200k_base tokenizer
-                "gpt-4o"
-            } else {
-                // Otherwise fallback to gpt-4, since only cl100k_base and o200k_base are
-                // supported with this tiktoken method
-                "gpt-4"
-            };
-            tiktoken_rs::num_tokens_from_messages(model, &messages)
-        }
-        // Currently supported by tiktoken_rs
-        // Sometimes tiktoken-rs is behind on model support. If that is the case, make a new branch
-        // arm with an override. We enumerate all supported models here so that we can check if new
-        // models are supported yet or not.
-        Model::ThreePointFiveTurbo
-        | Model::Four
-        | Model::FourTurbo
-        | Model::FourOmniMini
-        | Model::FourPointOneNano
-        | Model::O1
-        | Model::O3
-        | Model::O3Mini
-        | Model::Five
-        | Model::FiveCodex
-        | Model::FiveMini
-        | Model::FiveNano => tiktoken_rs::num_tokens_from_messages(model.id(), &messages),
-        // GPT-5.1, 5.2, 5.2-codex, 5.3-codex, 5.4, and 5.4-pro don't have dedicated tiktoken support; use gpt-5 tokenizer
-        Model::FivePointOne
-        | Model::FivePointTwo
-        | Model::FivePointTwoCodex
-        | Model::FivePointThreeCodex
-        | Model::FivePointFour
-        | Model::FivePointFourPro => tiktoken_rs::num_tokens_from_messages("gpt-5", &messages),
-    }
-    .map(|tokens| tokens as u64)
-}
-
 #[cfg(test)]
 mod tests {
     use crate::responses::{
@@ -929,34 +867,6 @@ mod tests {
         })
     }
 
-    #[test]
-    fn tiktoken_rs_support() {
-        let request = LanguageModelRequest {
-            thread_id: None,
-            prompt_id: None,
-            intent: None,
-            messages: vec![LanguageModelRequestMessage {
-                role: Role::User,
-                content: vec![MessageContent::Text("message".into())],
-                cache: false,
-                reasoning_details: None,
-            }],
-            tools: vec![],
-            tool_choice: None,
-            stop: vec![],
-            temperature: None,
-            thinking_allowed: true,
-            thinking_effort: None,
-            speed: None,
-        };
-
-        // Validate that all models are supported by tiktoken-rs
-        for model in <Model as strum::IntoEnumIterator>::iter() {
-            let count = count_open_ai_tokens(request.clone(), model).unwrap();
-            assert!(count > 0);
-        }
-    }
-
     #[test]
     fn responses_stream_maps_text_and_usage() {
         let events = vec![
diff --git a/crates/rules_library/src/rules_library.rs b/crates/rules_library/src/rules_library.rs
index 425f7d2aa3d9e9259fe005a0e15dee10e4e4baf1..e5105081ca7af71520f6dae7f8f01b818d2b8877 100644
--- a/crates/rules_library/src/rules_library.rs
+++ b/crates/rules_library/src/rules_library.rs
@@ -8,9 +8,7 @@ use gpui::{
     WindowOptions, actions, point, size, transparent_black,
 };
 use language::{Buffer, LanguageRegistry, language_settings::SoftWrap};
-use language_model::{
-    ConfiguredModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
-};
+use language_model::{ConfiguredModel, LanguageModelRegistry};
 use picker::{Picker, PickerDelegate};
 use platform_title_bar::PlatformTitleBar;
 use release_channel::ReleaseChannel;
@@ -165,8 +163,6 @@ pub struct RulesLibrary {
 struct RuleEditor {
     title_editor: Entity<Editor>,
     body_editor: Entity<Editor>,
-    token_count: Option<u64>,
-    pending_token_count: Task<Option<()>>,
     next_title_and_body_to_save: Option<(String, Rope)>,
     pending_save: Option<Task<Result<()>>>,
     _subscriptions: Vec<Subscription>,
@@ -785,13 +781,10 @@
                         body_editor,
                         next_title_and_body_to_save: None,
                         pending_save: None,
-                        token_count: None,
-                        pending_token_count: Task::ready(None),
                         _subscriptions,
                     },
                 );
                 this.set_active_rule(Some(prompt_id), window, cx);
-                this.count_tokens(prompt_id, window, cx);
             }
             Err(error) => {
                 // TODO: we should show the error in the UI.
@@ -1019,7 +1012,6 @@ match event {
             EditorEvent::BufferEdited => {
                 self.save_rule(prompt_id, window, cx);
-                self.count_tokens(prompt_id, window, cx);
             }
             EditorEvent::Blurred => {
                 title_editor.update(cx, |title_editor, cx| {
@@ -1049,7 +1041,6 @@ match event {
             EditorEvent::BufferEdited => {
                 self.save_rule(prompt_id, window, cx);
-                self.count_tokens(prompt_id, window, cx);
             }
             EditorEvent::Blurred => {
                 body_editor.update(cx, |body_editor, cx| {
@@ -1068,59 +1059,6 @@
         }
     }
 
-    fn count_tokens(&mut self, prompt_id: PromptId, window: &mut Window, cx: &mut Context<Self>) {
-        let Some(ConfiguredModel { model, .. }) =
-            LanguageModelRegistry::read_global(cx).default_model()
-        else {
-            return;
-        };
-        if let Some(rule) = self.rule_editors.get_mut(&prompt_id) {
-            let editor = &rule.body_editor.read(cx);
-            let buffer = &editor.buffer().read(cx).as_singleton().unwrap().read(cx);
-            let body = buffer.as_rope().clone();
-            rule.pending_token_count = cx.spawn_in(window, async move |this, cx| {
-                async move {
-                    const DEBOUNCE_TIMEOUT: Duration = Duration::from_secs(1);
-
-                    cx.background_executor().timer(DEBOUNCE_TIMEOUT).await;
-                    let token_count = cx
-                        .update(|_, cx| {
-                            model.count_tokens(
-                                LanguageModelRequest {
-                                    thread_id: None,
-                                    prompt_id: None,
-                                    intent: None,
-                                    messages: vec![LanguageModelRequestMessage {
-                                        role: Role::System,
-                                        content: vec![body.to_string().into()],
-                                        cache: false,
-                                        reasoning_details: None,
-                                    }],
-                                    tools: Vec::new(),
-                                    tool_choice: None,
-                                    stop: Vec::new(),
-                                    temperature: None,
-                                    thinking_allowed: true,
-                                    thinking_effort: None,
-                                    speed: None,
-                                },
-                                cx,
-                            )
-                        })?
-                        .await?;
-
-                    this.update(cx, |this, cx| {
-                        let rule_editor = this.rule_editors.get_mut(&prompt_id).unwrap();
-                        rule_editor.token_count = Some(token_count);
-                        cx.notify();
-                    })
-                }
-                .log_err()
-                .await
-            });
-        }
-    }
-
     fn render_rule_list(&mut self, cx: &mut Context<Self>) -> impl IntoElement {
         v_flex()
             .id("rule-list")
@@ -1293,8 +1231,6 @@
         let rule_metadata = self.store.read(cx).metadata(prompt_id)?;
         let rule_editor = &self.rule_editors[&prompt_id];
         let focus_handle = rule_editor.body_editor.focus_handle(cx);
-        let registry = LanguageModelRegistry::read_global(cx);
-        let model = registry.default_model().map(|default| default.model);
         let built_in = prompt_id.is_built_in();
 
         Some(
                                 built_in,
                                 cx,
                             ))
-                            .child(
-                                h_flex()
-                                    .h_full()
-                                    .flex_shrink_0()
-                                    .children(rule_editor.token_count.map(|token_count| {
-                                        let token_count: SharedString =
-                                            token_count.to_string().into();
-                                        let label_token_count: SharedString =
-                                            token_count.to_string().into();
-
-                                        div()
-                                            .id("token_count")
-                                            .mr_1()
-                                            .flex_shrink_0()
-                                            .tooltip(move |_window, cx| {
-                                                Tooltip::with_meta(
-                                                    "Token Estimation",
-                                                    None,
-                                                    format!(
-                                                        "Model: {}",
-                                                        model
-                                                            .as_ref()
-                                                            .map(|model| model.name().0)
-                                                            .unwrap_or_default()
-                                                    ),
-                                                    cx,
-                                                )
-                                            })
-                                            .child(
-                                                Label::new(format!(
-                                                    "{} tokens",
-                                                    label_token_count
-                                                ))
-                                                .color(Color::Muted),
-                                            )
-                                    }))
-                                    .map(|this| {
-                                        if built_in {
-                                            this.child(self.render_built_in_rule_controls())
-                                        } else {
-                                            this.child(self.render_regular_rule_controls(
-                                                rule_metadata.default,
-                                            ))
-                                        }
-                                    }),
-                            ),
+                            .child(h_flex().h_full().flex_shrink_0().map(|this| {
+                                if built_in {
+                                    this.child(self.render_built_in_rule_controls())
+                                } else {
+                                    this.child(
+                                        self.render_regular_rule_controls(
+                                            rule_metadata.default,
+                                        ),
+                                    )
+                                }
+                            })),
                         )
                         .child(
                             div()
diff --git a/crates/x_ai/Cargo.toml b/crates/x_ai/Cargo.toml
index 2d1c9d0ecebeb8a1e0965b0ac914603b41383f00..8ff020df8c1ccaf284157d8b46ddaa0e678b3cd7 100644
--- a/crates/x_ai/Cargo.toml
+++ b/crates/x_ai/Cargo.toml
@@ -17,8 +17,6 @@ schemars = ["dep:schemars"]
 
 [dependencies]
 anyhow.workspace = true
-language_model_core.workspace = true
 schemars = { workspace = true, optional = true }
 serde.workspace = true
 strum.workspace = true
-tiktoken-rs.workspace = true
diff --git a/crates/x_ai/src/completion.rs b/crates/x_ai/src/completion.rs
deleted file mode 100644
index aad03d227eb82768c972283f7e1617ea7486f22f..0000000000000000000000000000000000000000
--- a/crates/x_ai/src/completion.rs
+++ /dev/null
@@ -1,30 +0,0 @@
-use anyhow::Result;
-use language_model_core::{LanguageModelRequest, Role};
-
-use crate::Model;
-
-/// Count tokens for an xAI model using tiktoken. This is synchronous;
-/// callers should spawn it on a background thread if needed.
-pub fn count_xai_tokens(request: LanguageModelRequest, model: Model) -> Result<u64> {
-    let messages = request
-        .messages
-        .into_iter()
-        .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
-            role: match message.role {
-                Role::User => "user".into(),
-                Role::Assistant => "assistant".into(),
-                Role::System => "system".into(),
-            },
-            content: Some(message.string_contents()),
-            name: None,
-            function_call: None,
-        })
-        .collect::<Vec<_>>();
-
-    let model_name = if model.max_token_count() >= 100_000 {
-        "gpt-4o"
-    } else {
-        "gpt-4"
-    };
-    tiktoken_rs::num_tokens_from_messages(model_name, &messages).map(|tokens| tokens as u64)
-}
diff --git a/crates/x_ai/src/x_ai.rs b/crates/x_ai/src/x_ai.rs
index bc49a3e2b37d6ac83c66a2fba3af83ea7a451576..afa7d62aa3c99127580779eb0f0f1563e3ab658d 100644
--- a/crates/x_ai/src/x_ai.rs
+++ b/crates/x_ai/src/x_ai.rs
@@ -1,5 +1,3 @@
-pub mod completion;
-
 use anyhow::Result;
 use serde::{Deserialize, Serialize};
 use strum::EnumIter;
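Note: after this patch the only authoritative token counts left are the providers' own endpoints, both of which were visible in the deleted helpers: Anthropic's POST {api_url}/v1/messages/count_tokens and Google's {api_url}/v1beta/models/{model_id}:countTokens?key={api_key}. A sketch of the URL construction for the Google route follows; it is string handling only, purely illustrative, and the model id shown is just an example:

fn google_count_tokens_url(api_url: &str, model_id: &str, api_key: &str) -> String {
    // Mirrors the format! call in the deleted google_ai::count_tokens.
    format!("{api_url}/v1beta/models/{model_id}:countTokens?key={api_key}")
}

fn main() {
    let url = google_count_tokens_url(
        "https://generativelanguage.googleapis.com",
        "gemini-2.0-flash",
        "API_KEY",
    );
    assert!(url.ends_with(":countTokens?key=API_KEY"));
    println!("{url}");
}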