diff --git a/crates/agent/src/tests/mod.rs b/crates/agent/src/tests/mod.rs
index 5a581c5db80a4c4f527efc8b1711fbf16c8097f8..45028902e467fe67945ddf444c9ae417dcaed654 100644
--- a/crates/agent/src/tests/mod.rs
+++ b/crates/agent/src/tests/mod.rs
@@ -2809,3 +2809,181 @@ fn setup_context_server(
     cx.run_until_parked();
     mcp_tool_calls_rx
 }
+
+#[gpui::test]
+async fn test_tokens_before_message(cx: &mut TestAppContext) {
+    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
+    let fake_model = model.as_fake();
+
+    // First message
+    let message_1_id = UserMessageId::new();
+    thread
+        .update(cx, |thread, cx| {
+            thread.send(message_1_id.clone(), ["First message"], cx)
+        })
+        .unwrap();
+    cx.run_until_parked();
+
+    // Before any response, tokens_before_message should return None for first message
+    thread.read_with(cx, |thread, _| {
+        assert_eq!(
+            thread.tokens_before_message(&message_1_id),
+            None,
+            "First message should have no tokens before it"
+        );
+    });
+
+    // Complete first message with usage
+    fake_model.send_last_completion_stream_text_chunk("Response 1");
+    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
+        language_model::TokenUsage {
+            input_tokens: 100,
+            output_tokens: 50,
+            cache_creation_input_tokens: 0,
+            cache_read_input_tokens: 0,
+        },
+    ));
+    fake_model.end_last_completion_stream();
+    cx.run_until_parked();
+
+    // First message still has no tokens before it
+    thread.read_with(cx, |thread, _| {
+        assert_eq!(
+            thread.tokens_before_message(&message_1_id),
+            None,
+            "First message should still have no tokens before it after response"
+        );
+    });
+
+    // Second message
+    let message_2_id = UserMessageId::new();
+    thread
+        .update(cx, |thread, cx| {
+            thread.send(message_2_id.clone(), ["Second message"], cx)
+        })
+        .unwrap();
+    cx.run_until_parked();
+
+    // Second message should have first message's input tokens before it
+    thread.read_with(cx, |thread, _| {
+        assert_eq!(
+            thread.tokens_before_message(&message_2_id),
+            Some(100),
+            "Second message should have 100 tokens before it (from first request)"
+        );
+    });
+
+    // Complete second message
+    fake_model.send_last_completion_stream_text_chunk("Response 2");
+    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
+        language_model::TokenUsage {
+            input_tokens: 250, // Total for this request (includes previous context)
+            output_tokens: 75,
+            cache_creation_input_tokens: 0,
+            cache_read_input_tokens: 0,
+        },
+    ));
+    fake_model.end_last_completion_stream();
+    cx.run_until_parked();
+
+    // Third message
+    let message_3_id = UserMessageId::new();
+    thread
+        .update(cx, |thread, cx| {
+            thread.send(message_3_id.clone(), ["Third message"], cx)
+        })
+        .unwrap();
+    cx.run_until_parked();
+
+    // Third message should have second message's input tokens (250) before it
+    thread.read_with(cx, |thread, _| {
+        assert_eq!(
+            thread.tokens_before_message(&message_3_id),
+            Some(250),
+            "Third message should have 250 tokens before it (from second request)"
+        );
+        // Second message should still have 100
+        assert_eq!(
+            thread.tokens_before_message(&message_2_id),
+            Some(100),
+            "Second message should still have 100 tokens before it"
+        );
+        // First message still has none
+        assert_eq!(
+            thread.tokens_before_message(&message_1_id),
+            None,
+            "First message should still have no tokens before it"
+        );
+    });
+}
+
+#[gpui::test]
+async fn test_tokens_before_message_after_truncate(cx: &mut TestAppContext) {
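+    // Two user turns with recorded usage are set up, then the thread is truncated at the
+    // second turn; afterwards the second turn's usage entry should no longer be reachable.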
+    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
+    let fake_model = model.as_fake();
+
+    // Set up two messages with responses
+    let message_1_id = UserMessageId::new();
+    thread
+        .update(cx, |thread, cx| {
+            thread.send(message_1_id.clone(), ["Message 1"], cx)
+        })
+        .unwrap();
+    cx.run_until_parked();
+    fake_model.send_last_completion_stream_text_chunk("Response 1");
+    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
+        language_model::TokenUsage {
+            input_tokens: 100,
+            output_tokens: 50,
+            cache_creation_input_tokens: 0,
+            cache_read_input_tokens: 0,
+        },
+    ));
+    fake_model.end_last_completion_stream();
+    cx.run_until_parked();
+
+    let message_2_id = UserMessageId::new();
+    thread
+        .update(cx, |thread, cx| {
+            thread.send(message_2_id.clone(), ["Message 2"], cx)
+        })
+        .unwrap();
+    cx.run_until_parked();
+    fake_model.send_last_completion_stream_text_chunk("Response 2");
+    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
+        language_model::TokenUsage {
+            input_tokens: 250,
+            output_tokens: 75,
+            cache_creation_input_tokens: 0,
+            cache_read_input_tokens: 0,
+        },
+    ));
+    fake_model.end_last_completion_stream();
+    cx.run_until_parked();
+
+    // Verify initial state
+    thread.read_with(cx, |thread, _| {
+        assert_eq!(thread.tokens_before_message(&message_2_id), Some(100));
+    });
+
+    // Truncate at message 2 (removes message 2 and everything after)
+    thread
+        .update(cx, |thread, cx| thread.truncate(message_2_id.clone(), cx))
+        .unwrap();
+    cx.run_until_parked();
+
+    // After truncation, message_2_id no longer exists, so lookup should return None
+    thread.read_with(cx, |thread, _| {
+        assert_eq!(
+            thread.tokens_before_message(&message_2_id),
+            None,
+            "After truncation, message 2 no longer exists"
+        );
+        // Message 1 still exists but has no tokens before it
+        assert_eq!(
+            thread.tokens_before_message(&message_1_id),
+            None,
+            "First message still has no tokens before it"
+        );
+    });
+}
diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs
index bb22470b9e7db934f949a13b86fd13f9dc58beed..f8f46af5fe2bbea5888ded6e24495afee71680dd 100644
--- a/crates/agent/src/thread.rs
+++ b/crates/agent/src/thread.rs
@@ -1095,6 +1095,28 @@ impl Thread {
         })
     }
 
+    /// Get the total input token count as of the message before the given message.
+    ///
+    /// Returns `None` if:
+    /// - `target_id` is the first message (no previous message)
+    /// - The previous message hasn't received a response yet (no usage data)
+    /// - `target_id` is not found in the messages
+    pub fn tokens_before_message(&self, target_id: &UserMessageId) -> Option<u64> {
+        let mut previous_user_message_id: Option<&UserMessageId> = None;
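+        // `request_token_usage` is keyed by the user message that initiated each request, so
+        // the context size before `target_id` is the input token count recorded for the
+        // previous user message's request.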
+
+        for message in &self.messages {
+            if let Message::User(user_msg) = message {
+                if &user_msg.id == target_id {
+                    let prev_id = previous_user_message_id?;
+                    let usage = self.request_token_usage.get(prev_id)?;
+                    return Some(usage.input_tokens);
+                }
+                previous_user_message_id = Some(&user_msg.id);
+            }
+        }
+        None
+    }
+
     /// Look up the active profile and resolve its preferred model if one is configured.
     fn resolve_profile_model(
         profile_id: &AgentProfileId,
diff --git a/crates/anthropic/src/anthropic.rs b/crates/anthropic/src/anthropic.rs
index e976b7f5dc36905d2a32b4cdc04869f3267705fe..f0dde3eedea657ea2d2ebe9ede457e329bd8b9a5 100644
--- a/crates/anthropic/src/anthropic.rs
+++ b/crates/anthropic/src/anthropic.rs
@@ -1052,6 +1052,71 @@ pub fn parse_prompt_too_long(message: &str) -> Option<u64> {
         .ok()
 }
 
+/// Request body for the token counting API.
+/// Similar to `Request`, but without `max_tokens` since it's not needed for counting.
+#[derive(Debug, Serialize)]
+pub struct CountTokensRequest {
+    pub model: String,
+    pub messages: Vec<Message>,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub system: Option<StringOrContents>,
+    #[serde(default, skip_serializing_if = "Vec::is_empty")]
+    pub tools: Vec<Tool>,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub thinking: Option<Thinking>,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub tool_choice: Option<ToolChoice>,
+}
+
+/// Response from the token counting API.
+#[derive(Debug, Deserialize)]
+pub struct CountTokensResponse {
+    pub input_tokens: u64,
+}
+
+/// Count the number of tokens in a message without creating it.
+pub async fn count_tokens(
+    client: &dyn HttpClient,
+    api_url: &str,
+    api_key: &str,
+    request: CountTokensRequest,
+) -> Result<CountTokensResponse, AnthropicError> {
+    let uri = format!("{api_url}/v1/messages/count_tokens");
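+    // `/v1/messages/count_tokens` is Anthropic's token-counting endpoint; it takes the same
+    // message shape as `/v1/messages` and returns only an input token total.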
+
+    let request_builder = HttpRequest::builder()
+        .method(Method::POST)
+        .uri(uri)
+        .header("Anthropic-Version", "2023-06-01")
+        .header("X-Api-Key", api_key.trim())
+        .header("Content-Type", "application/json");
+
+    let serialized_request =
+        serde_json::to_string(&request).map_err(AnthropicError::SerializeRequest)?;
+    let http_request = request_builder
+        .body(AsyncBody::from(serialized_request))
+        .map_err(AnthropicError::BuildRequestBody)?;
+
+    let mut response = client
+        .send(http_request)
+        .await
+        .map_err(AnthropicError::HttpSend)?;
+
+    let rate_limits = RateLimitInfo::from_headers(response.headers());
+
+    if response.status().is_success() {
+        let mut body = String::new();
+        response
+            .body_mut()
+            .read_to_string(&mut body)
+            .await
+            .map_err(AnthropicError::ReadResponse)?;
+
+        serde_json::from_str(&body).map_err(AnthropicError::DeserializeResponse)
+    } else {
+        Err(handle_error_response(response, rate_limits).await)
+    }
+}
+
 #[test]
 fn test_match_window_exceeded() {
     let error = ApiError {
diff --git a/crates/language_models/src/provider/anthropic.rs b/crates/language_models/src/provider/anthropic.rs
index 25ba7615dc23e2561648e173588be6d93c28e295..d8c972399c33922386bfba4236e1369d03d338dc 100644
--- a/crates/language_models/src/provider/anthropic.rs
+++ b/crates/language_models/src/provider/anthropic.rs
@@ -1,6 +1,6 @@
 use anthropic::{
-    ANTHROPIC_API_URL, AnthropicError, AnthropicModelMode, ContentDelta, Event, ResponseContent,
-    ToolResultContent, ToolResultPart, Usage,
+    ANTHROPIC_API_URL, AnthropicError, AnthropicModelMode, ContentDelta, CountTokensRequest, Event,
+    ResponseContent, ToolResultContent, ToolResultPart, Usage,
 };
 use anyhow::{Result, anyhow};
 use collections::{BTreeMap, HashMap};
@@ -219,68 +219,215 @@ pub struct AnthropicModel {
     request_limiter: RateLimiter,
 }
 
-pub fn count_anthropic_tokens(
+/// Convert a LanguageModelRequest to an Anthropic CountTokensRequest.
+pub fn into_anthropic_count_tokens_request(
     request: LanguageModelRequest,
-    cx: &App,
-) -> BoxFuture<'static, Result<u64>> {
-    cx.background_spawn(async move {
-        let messages = request.messages;
-        let mut tokens_from_images = 0;
-        let mut string_messages = Vec::with_capacity(messages.len());
-
-        for message in messages {
-            use language_model::MessageContent;
-
-            let mut string_contents = String::new();
-
-            for content in message.content {
-                match content {
-                    MessageContent::Text(text) => {
-                        string_contents.push_str(&text);
-                    }
-                    MessageContent::Thinking { .. } => {
-                        // Thinking blocks are not included in the input token count.
-                    }
-                    MessageContent::RedactedThinking(_) => {
-                        // Thinking blocks are not included in the input token count.
-                    }
-                    MessageContent::Image(image) => {
-                        tokens_from_images += image.estimate_tokens();
-                    }
-                    MessageContent::ToolUse(_tool_use) => {
-                        // TODO: Estimate token usage from tool uses.
-                    }
-                    MessageContent::ToolResult(tool_result) => match &tool_result.content {
-                        LanguageModelToolResultContent::Text(text) => {
-                            string_contents.push_str(text);
+    model: String,
+    mode: AnthropicModelMode,
+) -> CountTokensRequest {
+    let mut new_messages: Vec<anthropic::Message> = Vec::new();
+    let mut system_message = String::new();
+
+    for message in request.messages {
+        if message.contents_empty() {
+            continue;
+        }
+
+        match message.role {
+            Role::User | Role::Assistant => {
+                let anthropic_message_content: Vec<anthropic::RequestContent> = message
+                    .content
+                    .into_iter()
+                    .filter_map(|content| match content {
+                        MessageContent::Text(text) => {
+                            let text = if text.chars().last().is_some_and(|c| c.is_whitespace()) {
+                                text.trim_end().to_string()
+                            } else {
+                                text
+                            };
+                            if !text.is_empty() {
+                                Some(anthropic::RequestContent::Text {
+                                    text,
+                                    cache_control: None,
+                                })
+                            } else {
+                                None
+                            }
+                        }
+                        MessageContent::Thinking {
+                            text: thinking,
+                            signature,
+                        } => {
+                            if !thinking.is_empty() {
+                                Some(anthropic::RequestContent::Thinking {
+                                    thinking,
+                                    signature: signature.unwrap_or_default(),
+                                    cache_control: None,
+                                })
+                            } else {
+                                None
+                            }
+                        }
+                        MessageContent::RedactedThinking(data) => {
+                            if !data.is_empty() {
+                                Some(anthropic::RequestContent::RedactedThinking { data })
+                            } else {
+                                None
+                            }
                         }
-                        LanguageModelToolResultContent::Image(image) => {
-                            tokens_from_images += image.estimate_tokens();
+                        MessageContent::Image(image) => Some(anthropic::RequestContent::Image {
+                            source: anthropic::ImageSource {
+                                source_type: "base64".to_string(),
+                                media_type: "image/png".to_string(),
+                                data: image.source.to_string(),
+                            },
+                            cache_control: None,
+                        }),
+                        MessageContent::ToolUse(tool_use) => {
+                            Some(anthropic::RequestContent::ToolUse {
+                                id: tool_use.id.to_string(),
+                                name: tool_use.name.to_string(),
+                                input: tool_use.input,
+                                cache_control: None,
+                            })
+                        }
+                        MessageContent::ToolResult(tool_result) => {
+                            Some(anthropic::RequestContent::ToolResult {
+                                tool_use_id: tool_result.tool_use_id.to_string(),
+                                is_error: tool_result.is_error,
+                                content: match tool_result.content {
+                                    LanguageModelToolResultContent::Text(text) => {
+                                        ToolResultContent::Plain(text.to_string())
+                                    }
+                                    LanguageModelToolResultContent::Image(image) => {
+                                        ToolResultContent::Multipart(vec![ToolResultPart::Image {
+                                            source: anthropic::ImageSource {
+                                                source_type: "base64".to_string(),
+                                                media_type: "image/png".to_string(),
+                                                data: image.source.to_string(),
+                                            },
+                                        }])
+                                    }
+                                },
+                                cache_control: None,
+                            })
                         }
-                    },
+                    })
+                    .collect();
+                let anthropic_role = match message.role {
+                    Role::User => anthropic::Role::User,
+                    Role::Assistant => anthropic::Role::Assistant,
+                    Role::System => unreachable!("System role should never occur here"),
+                };
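+                // Consecutive messages with the same role are merged into one, mirroring
+                // `into_anthropic`, since the Messages API expects alternating user/assistant turns.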
+                if let Some(last_message) = new_messages.last_mut()
+                    && last_message.role == anthropic_role
+                {
+                    last_message.content.extend(anthropic_message_content);
+                    continue;
                 }
-            }
-            if !string_contents.is_empty() {
-                string_messages.push(tiktoken_rs::ChatCompletionRequestMessage {
-                    role: match message.role {
-                        Role::User => "user".into(),
-                        Role::Assistant => "assistant".into(),
-                        Role::System => "system".into(),
-                    },
-                    content: Some(string_contents),
-                    name: None,
-                    function_call: None,
+                new_messages.push(anthropic::Message {
+                    role: anthropic_role,
+                    content: anthropic_message_content,
                 });
             }
+            Role::System => {
+                if !system_message.is_empty() {
+                    system_message.push_str("\n\n");
+                }
+                system_message.push_str(&message.string_contents());
+            }
+        }
+    }
+
+    CountTokensRequest {
+        model,
+        messages: new_messages,
+        system: if system_message.is_empty() {
+            None
+        } else {
+            Some(anthropic::StringOrContents::String(system_message))
+        },
+        thinking: if request.thinking_allowed
+            && let AnthropicModelMode::Thinking { budget_tokens } = mode
+        {
+            Some(anthropic::Thinking::Enabled { budget_tokens })
+        } else {
+            None
+        },
+        tools: request
+            .tools
+            .into_iter()
+            .map(|tool| anthropic::Tool {
+                name: tool.name,
+                description: tool.description,
+                input_schema: tool.input_schema,
+            })
+            .collect(),
+        tool_choice: request.tool_choice.map(|choice| match choice {
+            LanguageModelToolChoice::Auto => anthropic::ToolChoice::Auto,
+            LanguageModelToolChoice::Any => anthropic::ToolChoice::Any,
+            LanguageModelToolChoice::None => anthropic::ToolChoice::None,
+        }),
+    }
+}
+
+/// Estimate tokens using tiktoken. Used as a fallback when the API is unavailable,
+/// or by providers (like Zed Cloud) that don't have direct Anthropic API access.
+pub fn count_anthropic_tokens_with_tiktoken(request: LanguageModelRequest) -> Result<u64> {
+    let messages = request.messages;
+    let mut tokens_from_images = 0;
+    let mut string_messages = Vec::with_capacity(messages.len());
+
+    for message in messages {
+        let mut string_contents = String::new();
+
+        for content in message.content {
+            match content {
+                MessageContent::Text(text) => {
+                    string_contents.push_str(&text);
+                }
+                MessageContent::Thinking { .. } => {
+                    // Thinking blocks are not included in the input token count.
+                }
+                MessageContent::RedactedThinking(_) => {
+                    // Thinking blocks are not included in the input token count.
+                }
+                MessageContent::Image(image) => {
+                    tokens_from_images += image.estimate_tokens();
+                }
+                MessageContent::ToolUse(_tool_use) => {
+                    // TODO: Estimate token usage from tool uses.
+                }
+                MessageContent::ToolResult(tool_result) => match &tool_result.content {
+                    LanguageModelToolResultContent::Text(text) => {
+                        string_contents.push_str(text);
+                    }
+                    LanguageModelToolResultContent::Image(image) => {
+                        tokens_from_images += image.estimate_tokens();
+                    }
+                },
+            }
+        }
+
+        if !string_contents.is_empty() {
+            string_messages.push(tiktoken_rs::ChatCompletionRequestMessage {
+                role: match message.role {
+                    Role::User => "user".into(),
+                    Role::Assistant => "assistant".into(),
+                    Role::System => "system".into(),
+                },
+                content: Some(string_contents),
+                name: None,
+                function_call: None,
+            });
         }
+    }
 
-        // Tiktoken doesn't yet support these models, so we manually use the
-        // same tokenizer as GPT-4.
-        tiktoken_rs::num_tokens_from_messages("gpt-4", &string_messages)
-            .map(|tokens| (tokens + tokens_from_images) as u64)
-    })
-    .boxed()
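+            // `request` is cloned here so the tiktoken fallback below can still consume the
+            // original if the API call fails.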
+    tiktoken_rs::num_tokens_from_messages("gpt-4", &string_messages)
+        .map(|tokens| (tokens + tokens_from_images) as u64)
 }
 
 impl AnthropicModel {
@@ -386,7 +533,40 @@ impl LanguageModel for AnthropicModel {
         request: LanguageModelRequest,
         cx: &App,
     ) -> BoxFuture<'static, Result<u64>> {
-        count_anthropic_tokens(request, cx)
+        let http_client = self.http_client.clone();
+        let model_id = self.model.request_id().to_string();
+        let mode = self.model.mode();
+
+        let (api_key, api_url) = self.state.read_with(cx, |state, cx| {
+            let api_url = AnthropicLanguageModelProvider::api_url(cx);
+            (
+                state.api_key_state.key(&api_url).map(|k| k.to_string()),
+                api_url.to_string(),
+            )
+        });
+
+        async move {
+            // If no API key, fall back to tiktoken estimation
+            let Some(api_key) = api_key else {
+                return count_anthropic_tokens_with_tiktoken(request);
+            };
+
+            let count_request =
+                into_anthropic_count_tokens_request(request.clone(), model_id, mode);
+
+            match anthropic::count_tokens(http_client.as_ref(), &api_url, &api_key, count_request)
+                .await
+            {
+                Ok(response) => Ok(response.input_tokens),
+                Err(err) => {
+                    log::error!(
+                        "Anthropic count_tokens API failed, falling back to tiktoken: {err:?}"
+                    );
+                    count_anthropic_tokens_with_tiktoken(request)
+                }
+            }
+        }
+        .boxed()
     }
 
     fn stream_completion(
diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs
index 508a77d38abcf2143170382e945ab6ce31f3a623..def1cef84d3166d08dcc7638ca5a29cabbd149c5 100644
--- a/crates/language_models/src/provider/cloud.rs
+++ b/crates/language_models/src/provider/cloud.rs
@@ -42,7 +42,9 @@ use thiserror::Error;
 use ui::{TintColor, prelude::*};
 use util::{ResultExt as _, maybe};
 
-use crate::provider::anthropic::{AnthropicEventMapper, count_anthropic_tokens, into_anthropic};
+use crate::provider::anthropic::{
+    AnthropicEventMapper, count_anthropic_tokens_with_tiktoken, into_anthropic,
+};
 use crate::provider::google::{GoogleEventMapper, into_google};
 use crate::provider::open_ai::{OpenAiEventMapper, count_open_ai_tokens, into_open_ai};
 use crate::provider::x_ai::count_xai_tokens;
@@ -667,9 +669,9 @@ impl LanguageModel for CloudLanguageModel {
         cx: &App,
     ) -> BoxFuture<'static, Result<u64>> {
         match self.model.provider {
-            cloud_llm_client::LanguageModelProvider::Anthropic => {
-                count_anthropic_tokens(request, cx)
-            }
+            cloud_llm_client::LanguageModelProvider::Anthropic => cx
+                .background_spawn(async move { count_anthropic_tokens_with_tiktoken(request) })
+                .boxed(),
             cloud_llm_client::LanguageModelProvider::OpenAi => {
                 let model = match open_ai::Model::from_id(&self.model.id.0) {
                     Ok(model) => model,