Detailed changes
@@ -2809,3 +2809,181 @@ fn setup_context_server(
cx.run_until_parked();
mcp_tool_calls_rx
}
+
+#[gpui::test]
+async fn test_tokens_before_message(cx: &mut TestAppContext) {
+ let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
+ let fake_model = model.as_fake();
+
+ // First message
+ let message_1_id = UserMessageId::new();
+ thread
+ .update(cx, |thread, cx| {
+ thread.send(message_1_id.clone(), ["First message"], cx)
+ })
+ .unwrap();
+ cx.run_until_parked();
+
+ // Before any response, tokens_before_message should return None for the first message
+ thread.read_with(cx, |thread, _| {
+ assert_eq!(
+ thread.tokens_before_message(&message_1_id),
+ None,
+ "First message should have no tokens before it"
+ );
+ });
+
+ // Complete first message with usage
+ fake_model.send_last_completion_stream_text_chunk("Response 1");
+ fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
+ language_model::TokenUsage {
+ input_tokens: 100,
+ output_tokens: 50,
+ cache_creation_input_tokens: 0,
+ cache_read_input_tokens: 0,
+ },
+ ));
+ fake_model.end_last_completion_stream();
+ cx.run_until_parked();
+
+ // First message still has no tokens before it
+ thread.read_with(cx, |thread, _| {
+ assert_eq!(
+ thread.tokens_before_message(&message_1_id),
+ None,
+ "First message should still have no tokens before it after response"
+ );
+ });
+
+ // Second message
+ let message_2_id = UserMessageId::new();
+ thread
+ .update(cx, |thread, cx| {
+ thread.send(message_2_id.clone(), ["Second message"], cx)
+ })
+ .unwrap();
+ cx.run_until_parked();
+
+ // Second message should have first message's input tokens before it
+ thread.read_with(cx, |thread, _| {
+ assert_eq!(
+ thread.tokens_before_message(&message_2_id),
+ Some(100),
+ "Second message should have 100 tokens before it (from first request)"
+ );
+ });
+
+ // Complete second message
+ fake_model.send_last_completion_stream_text_chunk("Response 2");
+ fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
+ language_model::TokenUsage {
+ input_tokens: 250, // Total for this request (includes previous context)
+ output_tokens: 75,
+ cache_creation_input_tokens: 0,
+ cache_read_input_tokens: 0,
+ },
+ ));
+ fake_model.end_last_completion_stream();
+ cx.run_until_parked();
+
+ // Third message
+ let message_3_id = UserMessageId::new();
+ thread
+ .update(cx, |thread, cx| {
+ thread.send(message_3_id.clone(), ["Third message"], cx)
+ })
+ .unwrap();
+ cx.run_until_parked();
+
+ // Third message should have second message's input tokens (250) before it
+ thread.read_with(cx, |thread, _| {
+ assert_eq!(
+ thread.tokens_before_message(&message_3_id),
+ Some(250),
+ "Third message should have 250 tokens before it (from second request)"
+ );
+ // Second message should still have 100
+ assert_eq!(
+ thread.tokens_before_message(&message_2_id),
+ Some(100),
+ "Second message should still have 100 tokens before it"
+ );
+ // First message still has none
+ assert_eq!(
+ thread.tokens_before_message(&message_1_id),
+ None,
+ "First message should still have no tokens before it"
+ );
+ });
+}
+
+#[gpui::test]
+async fn test_tokens_before_message_after_truncate(cx: &mut TestAppContext) {
+ let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
+ let fake_model = model.as_fake();
+
+ // Set up three messages with responses
+ let message_1_id = UserMessageId::new();
+ thread
+ .update(cx, |thread, cx| {
+ thread.send(message_1_id.clone(), ["Message 1"], cx)
+ })
+ .unwrap();
+ cx.run_until_parked();
+ fake_model.send_last_completion_stream_text_chunk("Response 1");
+ fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
+ language_model::TokenUsage {
+ input_tokens: 100,
+ output_tokens: 50,
+ cache_creation_input_tokens: 0,
+ cache_read_input_tokens: 0,
+ },
+ ));
+ fake_model.end_last_completion_stream();
+ cx.run_until_parked();
+
+ let message_2_id = UserMessageId::new();
+ thread
+ .update(cx, |thread, cx| {
+ thread.send(message_2_id.clone(), ["Message 2"], cx)
+ })
+ .unwrap();
+ cx.run_until_parked();
+ fake_model.send_last_completion_stream_text_chunk("Response 2");
+ fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
+ language_model::TokenUsage {
+ input_tokens: 250,
+ output_tokens: 75,
+ cache_creation_input_tokens: 0,
+ cache_read_input_tokens: 0,
+ },
+ ));
+ fake_model.end_last_completion_stream();
+ cx.run_until_parked();
+
+ // Verify initial state
+ thread.read_with(cx, |thread, _| {
+ assert_eq!(thread.tokens_before_message(&message_2_id), Some(100));
+ });
+
+ // Truncate at message 2 (removes message 2 and everything after)
+ thread
+ .update(cx, |thread, cx| thread.truncate(message_2_id.clone(), cx))
+ .unwrap();
+ cx.run_until_parked();
+
+ // After truncation, message_2_id no longer exists, so lookup should return None
+ thread.read_with(cx, |thread, _| {
+ assert_eq!(
+ thread.tokens_before_message(&message_2_id),
+ None,
+ "After truncation, message 2 no longer exists"
+ );
+ // Message 1 still exists but has no tokens before it
+ assert_eq!(
+ thread.tokens_before_message(&message_1_id),
+ None,
+ "First message still has no tokens before it"
+ );
+ });
+}
@@ -1095,6 +1095,28 @@ impl Thread {
})
}
+ /// Get the total input token count as of the user message immediately preceding the given message.
+ ///
+ /// Returns `None` if:
+ /// - `target_id` is the first message (no previous message)
+ /// - The previous message hasn't received a response yet (no usage data)
+ /// - `target_id` is not found in the messages
+ pub fn tokens_before_message(&self, target_id: &UserMessageId) -> Option<u64> {
+ let mut previous_user_message_id: Option<&UserMessageId> = None;
+
+ for message in &self.messages {
+ if let Message::User(user_msg) = message {
+ if &user_msg.id == target_id {
+ let prev_id = previous_user_message_id?;
+ let usage = self.request_token_usage.get(prev_id)?;
+ return Some(usage.input_tokens);
+ }
+ previous_user_message_id = Some(&user_msg.id);
+ }
+ }
+ None
+ }
+
/// Look up the active profile and resolve its preferred model if one is configured.
fn resolve_profile_model(
profile_id: &AgentProfileId,
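
The accessor walks the thread's messages and looks up the usage recorded for the previous user message. A minimal usage sketch of a hypothetical call site (the `thread` entity, `message_id`, and `cx` bindings are illustrative, not part of this change):

    // Hypothetical call site; yields None for the first message or before any
    // usage has been recorded for the preceding request.
    let tokens_before = thread.read(cx).tokens_before_message(&message_id);
    if let Some(tokens) = tokens_before {
        log::debug!("{tokens} input tokens precede this message");
    }
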
@@ -1052,6 +1052,71 @@ pub fn parse_prompt_too_long(message: &str) -> Option<u64> {
.ok()
}
+/// Request body for the token counting API.
+/// Similar to `Request` but without `max_tokens` since it's not needed for counting.
+#[derive(Debug, Serialize)]
+pub struct CountTokensRequest {
+ pub model: String,
+ pub messages: Vec<Message>,
+ #[serde(default, skip_serializing_if = "Option::is_none")]
+ pub system: Option<StringOrContents>,
+ #[serde(default, skip_serializing_if = "Vec::is_empty")]
+ pub tools: Vec<Tool>,
+ #[serde(default, skip_serializing_if = "Option::is_none")]
+ pub thinking: Option<Thinking>,
+ #[serde(default, skip_serializing_if = "Option::is_none")]
+ pub tool_choice: Option<ToolChoice>,
+}
+
+/// Response from the token counting API.
+#[derive(Debug, Deserialize)]
+pub struct CountTokensResponse {
+ pub input_tokens: u64,
+}
+
+/// Count the number of tokens in a message without creating it.
+pub async fn count_tokens(
+ client: &dyn HttpClient,
+ api_url: &str,
+ api_key: &str,
+ request: CountTokensRequest,
+) -> Result<CountTokensResponse, AnthropicError> {
+ let uri = format!("{api_url}/v1/messages/count_tokens");
+
+ let request_builder = HttpRequest::builder()
+ .method(Method::POST)
+ .uri(uri)
+ .header("Anthropic-Version", "2023-06-01")
+ .header("X-Api-Key", api_key.trim())
+ .header("Content-Type", "application/json");
+
+ let serialized_request =
+ serde_json::to_string(&request).map_err(AnthropicError::SerializeRequest)?;
+ let http_request = request_builder
+ .body(AsyncBody::from(serialized_request))
+ .map_err(AnthropicError::BuildRequestBody)?;
+
+ let mut response = client
+ .send(http_request)
+ .await
+ .map_err(AnthropicError::HttpSend)?;
+
+ let rate_limits = RateLimitInfo::from_headers(response.headers());
+
+ if response.status().is_success() {
+ let mut body = String::new();
+ response
+ .body_mut()
+ .read_to_string(&mut body)
+ .await
+ .map_err(AnthropicError::ReadResponse)?;
+
+ serde_json::from_str(&body).map_err(AnthropicError::DeserializeResponse)
+ } else {
+ Err(handle_error_response(response, rate_limits).await)
+ }
+}
+
#[test]
fn test_match_window_exceeded() {
let error = ApiError {
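
The new `count_tokens` helper mirrors the existing completion request plumbing but targets `POST /v1/messages/count_tokens`. A hedged sketch of calling it directly from an async context (the `http_client: Arc<dyn HttpClient>` and `api_key` bindings, and the model id, are placeholders):

    // Illustrative only: build a one-message request and ask the API for an exact count.
    let request = anthropic::CountTokensRequest {
        model: "claude-sonnet-4-5".to_string(), // placeholder model id
        messages: vec![anthropic::Message {
            role: anthropic::Role::User,
            content: vec![anthropic::RequestContent::Text {
                text: "Hello, world".to_string(),
                cache_control: None,
            }],
        }],
        system: None,
        tools: Vec::new(),
        thinking: None,
        tool_choice: None,
    };
    let response =
        anthropic::count_tokens(http_client.as_ref(), anthropic::ANTHROPIC_API_URL, &api_key, request)
            .await?;
    println!("input_tokens: {}", response.input_tokens);
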
@@ -1,6 +1,6 @@
use anthropic::{
- ANTHROPIC_API_URL, AnthropicError, AnthropicModelMode, ContentDelta, Event, ResponseContent,
- ToolResultContent, ToolResultPart, Usage,
+ ANTHROPIC_API_URL, AnthropicError, AnthropicModelMode, ContentDelta, CountTokensRequest, Event,
+ ResponseContent, ToolResultContent, ToolResultPart, Usage,
};
use anyhow::{Result, anyhow};
use collections::{BTreeMap, HashMap};
@@ -219,68 +219,215 @@ pub struct AnthropicModel {
request_limiter: RateLimiter,
}
-pub fn count_anthropic_tokens(
+/// Convert a `LanguageModelRequest` to an Anthropic `CountTokensRequest`.
+pub fn into_anthropic_count_tokens_request(
request: LanguageModelRequest,
- cx: &App,
-) -> BoxFuture<'static, Result<u64>> {
- cx.background_spawn(async move {
- let messages = request.messages;
- let mut tokens_from_images = 0;
- let mut string_messages = Vec::with_capacity(messages.len());
-
- for message in messages {
- use language_model::MessageContent;
-
- let mut string_contents = String::new();
-
- for content in message.content {
- match content {
- MessageContent::Text(text) => {
- string_contents.push_str(&text);
- }
- MessageContent::Thinking { .. } => {
- // Thinking blocks are not included in the input token count.
- }
- MessageContent::RedactedThinking(_) => {
- // Thinking blocks are not included in the input token count.
- }
- MessageContent::Image(image) => {
- tokens_from_images += image.estimate_tokens();
- }
- MessageContent::ToolUse(_tool_use) => {
- // TODO: Estimate token usage from tool uses.
- }
- MessageContent::ToolResult(tool_result) => match &tool_result.content {
- LanguageModelToolResultContent::Text(text) => {
- string_contents.push_str(text);
+ model: String,
+ mode: AnthropicModelMode,
+) -> CountTokensRequest {
+ let mut new_messages: Vec<anthropic::Message> = Vec::new();
+ let mut system_message = String::new();
+
+ for message in request.messages {
+ if message.contents_empty() {
+ continue;
+ }
+
+ match message.role {
+ Role::User | Role::Assistant => {
+ let anthropic_message_content: Vec<anthropic::RequestContent> = message
+ .content
+ .into_iter()
+ .filter_map(|content| match content {
+ MessageContent::Text(text) => {
+ let text = if text.chars().last().is_some_and(|c| c.is_whitespace()) {
+ text.trim_end().to_string()
+ } else {
+ text
+ };
+ if !text.is_empty() {
+ Some(anthropic::RequestContent::Text {
+ text,
+ cache_control: None,
+ })
+ } else {
+ None
+ }
+ }
+ MessageContent::Thinking {
+ text: thinking,
+ signature,
+ } => {
+ if !thinking.is_empty() {
+ Some(anthropic::RequestContent::Thinking {
+ thinking,
+ signature: signature.unwrap_or_default(),
+ cache_control: None,
+ })
+ } else {
+ None
+ }
+ }
+ MessageContent::RedactedThinking(data) => {
+ if !data.is_empty() {
+ Some(anthropic::RequestContent::RedactedThinking { data })
+ } else {
+ None
+ }
}
- LanguageModelToolResultContent::Image(image) => {
- tokens_from_images += image.estimate_tokens();
+ MessageContent::Image(image) => Some(anthropic::RequestContent::Image {
+ source: anthropic::ImageSource {
+ source_type: "base64".to_string(),
+ media_type: "image/png".to_string(),
+ data: image.source.to_string(),
+ },
+ cache_control: None,
+ }),
+ MessageContent::ToolUse(tool_use) => {
+ Some(anthropic::RequestContent::ToolUse {
+ id: tool_use.id.to_string(),
+ name: tool_use.name.to_string(),
+ input: tool_use.input,
+ cache_control: None,
+ })
+ }
+ MessageContent::ToolResult(tool_result) => {
+ Some(anthropic::RequestContent::ToolResult {
+ tool_use_id: tool_result.tool_use_id.to_string(),
+ is_error: tool_result.is_error,
+ content: match tool_result.content {
+ LanguageModelToolResultContent::Text(text) => {
+ ToolResultContent::Plain(text.to_string())
+ }
+ LanguageModelToolResultContent::Image(image) => {
+ ToolResultContent::Multipart(vec![ToolResultPart::Image {
+ source: anthropic::ImageSource {
+ source_type: "base64".to_string(),
+ media_type: "image/png".to_string(),
+ data: image.source.to_string(),
+ },
+ }])
+ }
+ },
+ cache_control: None,
+ })
}
- },
+ })
+ .collect();
+ let anthropic_role = match message.role {
+ Role::User => anthropic::Role::User,
+ Role::Assistant => anthropic::Role::Assistant,
+ Role::System => unreachable!("System role should never occur here"),
+ };
+ if let Some(last_message) = new_messages.last_mut()
+ && last_message.role == anthropic_role
+ {
+ last_message.content.extend(anthropic_message_content);
+ continue;
}
- }
- if !string_contents.is_empty() {
- string_messages.push(tiktoken_rs::ChatCompletionRequestMessage {
- role: match message.role {
- Role::User => "user".into(),
- Role::Assistant => "assistant".into(),
- Role::System => "system".into(),
- },
- content: Some(string_contents),
- name: None,
- function_call: None,
+ new_messages.push(anthropic::Message {
+ role: anthropic_role,
+ content: anthropic_message_content,
});
}
+ Role::System => {
+ if !system_message.is_empty() {
+ system_message.push_str("\n\n");
+ }
+ system_message.push_str(&message.string_contents());
+ }
+ }
+ }
+
+ CountTokensRequest {
+ model,
+ messages: new_messages,
+ system: if system_message.is_empty() {
+ None
+ } else {
+ Some(anthropic::StringOrContents::String(system_message))
+ },
+ thinking: if request.thinking_allowed
+ && let AnthropicModelMode::Thinking { budget_tokens } = mode
+ {
+ Some(anthropic::Thinking::Enabled { budget_tokens })
+ } else {
+ None
+ },
+ tools: request
+ .tools
+ .into_iter()
+ .map(|tool| anthropic::Tool {
+ name: tool.name,
+ description: tool.description,
+ input_schema: tool.input_schema,
+ })
+ .collect(),
+ tool_choice: request.tool_choice.map(|choice| match choice {
+ LanguageModelToolChoice::Auto => anthropic::ToolChoice::Auto,
+ LanguageModelToolChoice::Any => anthropic::ToolChoice::Any,
+ LanguageModelToolChoice::None => anthropic::ToolChoice::None,
+ }),
+ }
+}
+
+/// Estimate tokens using tiktoken. Used as a fallback when the API is unavailable,
+/// or by providers (like Zed Cloud) that don't have direct Anthropic API access.
+pub fn count_anthropic_tokens_with_tiktoken(request: LanguageModelRequest) -> Result<u64> {
+ let messages = request.messages;
+ let mut tokens_from_images = 0;
+ let mut string_messages = Vec::with_capacity(messages.len());
+
+ for message in messages {
+ let mut string_contents = String::new();
+
+ for content in message.content {
+ match content {
+ MessageContent::Text(text) => {
+ string_contents.push_str(&text);
+ }
+ MessageContent::Thinking { .. } => {
+ // Thinking blocks are not included in the input token count.
+ }
+ MessageContent::RedactedThinking(_) => {
+ // Thinking blocks are not included in the input token count.
+ }
+ MessageContent::Image(image) => {
+ tokens_from_images += image.estimate_tokens();
+ }
+ MessageContent::ToolUse(_tool_use) => {
+ // TODO: Estimate token usage from tool uses.
+ }
+ MessageContent::ToolResult(tool_result) => match &tool_result.content {
+ LanguageModelToolResultContent::Text(text) => {
+ string_contents.push_str(text);
+ }
+ LanguageModelToolResultContent::Image(image) => {
+ tokens_from_images += image.estimate_tokens();
+ }
+ },
+ }
+ }
+
+ if !string_contents.is_empty() {
+ string_messages.push(tiktoken_rs::ChatCompletionRequestMessage {
+ role: match message.role {
+ Role::User => "user".into(),
+ Role::Assistant => "assistant".into(),
+ Role::System => "system".into(),
+ },
+ content: Some(string_contents),
+ name: None,
+ function_call: None,
+ });
}
+ }
- // Tiktoken doesn't yet support these models, so we manually use the
- // same tokenizer as GPT-4.
- tiktoken_rs::num_tokens_from_messages("gpt-4", &string_messages)
- .map(|tokens| (tokens + tokens_from_images) as u64)
- })
- .boxed()
+ // Tiktoken doesn't yet support these models, so we manually use the
+ // same tokenizer as GPT-4.
+ tiktoken_rs::num_tokens_from_messages("gpt-4", &string_messages)
+ .map(|tokens| (tokens + tokens_from_images) as u64)
}
impl AnthropicModel {
@@ -386,7 +533,40 @@ impl LanguageModel for AnthropicModel {
request: LanguageModelRequest,
cx: &App,
) -> BoxFuture<'static, Result<u64>> {
- count_anthropic_tokens(request, cx)
+ let http_client = self.http_client.clone();
+ let model_id = self.model.request_id().to_string();
+ let mode = self.model.mode();
+
+ let (api_key, api_url) = self.state.read_with(cx, |state, cx| {
+ let api_url = AnthropicLanguageModelProvider::api_url(cx);
+ (
+ state.api_key_state.key(&api_url).map(|k| k.to_string()),
+ api_url.to_string(),
+ )
+ });
+
+ async move {
+ // If no API key, fall back to tiktoken estimation
+ let Some(api_key) = api_key else {
+ return count_anthropic_tokens_with_tiktoken(request);
+ };
+
+ let count_request =
+ into_anthropic_count_tokens_request(request.clone(), model_id, mode);
+
+ match anthropic::count_tokens(http_client.as_ref(), &api_url, &api_key, count_request)
+ .await
+ {
+ Ok(response) => Ok(response.input_tokens),
+ Err(err) => {
+ log::error!(
+ "Anthropic count_tokens API failed, falling back to tiktoken: {err:?}"
+ );
+ count_anthropic_tokens_with_tiktoken(request)
+ }
+ }
+ }
+ .boxed()
}
fn stream_completion(
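
Taken together, the provider now asks the API for an exact count when a key is available and degrades to the local tiktoken estimate otherwise; the same shape, condensed (bindings as in the diff above):

    // Condensed version of the fallback flow; `request`, `model_id`, `mode`,
    // `http_client`, `api_url`, and `api_key` are assumed to be in scope.
    let count_request = into_anthropic_count_tokens_request(request.clone(), model_id, mode);
    let tokens = match anthropic::count_tokens(http_client.as_ref(), &api_url, &api_key, count_request).await {
        Ok(response) => response.input_tokens,                    // exact count from the API
        Err(_) => count_anthropic_tokens_with_tiktoken(request)?, // approximate local fallback
    };
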
@@ -42,7 +42,9 @@ use thiserror::Error;
use ui::{TintColor, prelude::*};
use util::{ResultExt as _, maybe};
-use crate::provider::anthropic::{AnthropicEventMapper, count_anthropic_tokens, into_anthropic};
+use crate::provider::anthropic::{
+ AnthropicEventMapper, count_anthropic_tokens_with_tiktoken, into_anthropic,
+};
use crate::provider::google::{GoogleEventMapper, into_google};
use crate::provider::open_ai::{OpenAiEventMapper, count_open_ai_tokens, into_open_ai};
use crate::provider::x_ai::count_xai_tokens;
@@ -667,9 +669,9 @@ impl LanguageModel for CloudLanguageModel {
cx: &App,
) -> BoxFuture<'static, Result<u64>> {
match self.model.provider {
- cloud_llm_client::LanguageModelProvider::Anthropic => {
- count_anthropic_tokens(request, cx)
- }
+ cloud_llm_client::LanguageModelProvider::Anthropic => cx
+ .background_spawn(async move { count_anthropic_tokens_with_tiktoken(request) })
+ .boxed(),
cloud_llm_client::LanguageModelProvider::OpenAi => {
let model = match open_ai::Model::from_id(&self.model.id.0) {
Ok(model) => model,