@@ -18536,9 +18536,9 @@ dependencies = [
[[package]]
name = "zed_llm_client"
-version = "0.7.0"
+version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "3c1666cd923c5eb4635f3743e69c6920d0ed71f29b26920616a5d220607df7c4"
+checksum = "cc9ec491b7112cb8c2fba3c17d9a349d8ab695fb1a4ef6c5c4b9fd8d7aa975c1"
dependencies = [
"anyhow",
"serde",
@@ -606,7 +606,7 @@ wasmtime-wasi = "29"
which = "6.0.0"
wit-component = "0.221"
workspace-hack = "0.1.0"
-zed_llm_client = "0.7.0"
+zed_llm_client = "0.7.1"
zstd = "0.11"
metal = "0.29"
@@ -35,9 +35,9 @@ use strum::IntoEnumIterator;
use thiserror::Error;
use ui::{TintColor, prelude::*};
use zed_llm_client::{
- CURRENT_PLAN_HEADER_NAME, CompletionBody, CompletionMode, EXPIRED_LLM_TOKEN_HEADER_NAME,
- MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME, MODEL_REQUESTS_RESOURCE_HEADER_VALUE,
- SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME,
+ CURRENT_PLAN_HEADER_NAME, CompletionBody, CompletionMode, CountTokensBody, CountTokensResponse,
+ EXPIRED_LLM_TOKEN_HEADER_NAME, MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME,
+ MODEL_REQUESTS_RESOURCE_HEADER_VALUE, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME,
};
use crate::AllLanguageModelSettings;
@@ -686,7 +686,63 @@ impl LanguageModel for CloudLanguageModel {
match self.model.clone() {
CloudModel::Anthropic(_) => count_anthropic_tokens(request, cx),
CloudModel::OpenAi(model) => count_open_ai_tokens(request, model, cx),
- CloudModel::Google(_model) => async move { Ok(0) }.boxed(),
+ CloudModel::Google(model) => {
+ let client = self.client.clone();
+ let llm_api_token = self.llm_api_token.clone();
+ let request = into_google(request, model.id().into());
+ async move {
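+ // Reuse the cloud client's HTTP stack and acquire an LLM API token to authenticate the request.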
+ let http_client = &client.http_client();
+ let token = llm_api_token.acquire(&client).await?;
+
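+ // Build a POST to the count_tokens endpoint, honoring a ZED_COUNT_TOKENS_URL override when that environment variable is set.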
+ let request_builder = http_client::Request::builder().method(Method::POST);
+ let request_builder =
+ if let Ok(count_tokens_url) = std::env::var("ZED_COUNT_TOKENS_URL") {
+ request_builder.uri(count_tokens_url)
+ } else {
+ request_builder.uri(
+ http_client
+ .build_zed_llm_url("/count_tokens", &[])?
+ .as_ref(),
+ )
+ };
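+ // Wrap the Google-specific CountTokensRequest as the provider_request payload of the zed_llm_client CountTokensBody.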
+ let request_body = CountTokensBody {
+ provider: zed_llm_client::LanguageModelProvider::Google,
+ model: model.id().into(),
+ provider_request: serde_json::to_value(&google_ai::CountTokensRequest {
+ contents: request.contents,
+ })?,
+ };
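+ // Send the JSON body with the bearer token and read the full response before inspecting the status.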
+ let request = request_builder
+ .header("Content-Type", "application/json")
+ .header("Authorization", format!("Bearer {token}"))
+ .body(serde_json::to_string(&request_body)?.into())?;
+ let mut response = http_client.send(request).await?;
+ let status = response.status();
+ let mut response_body = String::new();
+ response
+ .body_mut()
+ .read_to_string(&mut response_body)
+ .await?;
+
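+ // A successful response deserializes into CountTokensResponse; any other status is surfaced as an ApiError with the raw body.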
+ if status.is_success() {
+ let response_body: CountTokensResponse =
+ serde_json::from_str(&response_body)?;
+
+ Ok(response_body.tokens)
+ } else {
+ Err(anyhow!(ApiError {
+ status,
+ body: response_body
+ }))
+ }
+ }
+ .boxed()
+ }
}
}