Marshall Bowers created this pull request.
This PR removes the `CountLanguageModelTokens` RPC message from collab.
We were only using this for Google AI models through the Zed provider
(which is only available to Zed staff).
For now we're returning `0`, but we'll bring this back soon.
Release Notes:
- N/A
 Cargo.lock                                   |   1 -
 crates/collab/Cargo.toml                     |   1 -
 crates/collab/src/rpc.rs                     | 111 +--------------------
 crates/language_models/src/provider/cloud.rs |  19 +--
 crates/proto/proto/ai.proto                  |  16 ---
 crates/proto/proto/zed.proto                 |   4 +-
 crates/proto/src/proto.rs                    |   3 --
 7 files changed, 4 insertions(+), 151 deletions(-)
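Concretely, the client-side behavior change is small: the Google arm of `count_tokens` stops issuing the RPC and resolves immediately. A minimal sketch of the new arm (the full match is in the cloud.rs hunk below):

```rust
// Before: Google token counts were proxied through collab via
// proto::CountLanguageModelTokens (see the deleted rpc.rs handler below).
// After: the arm is a stub until real counting is brought back.
CloudModel::Google(_model) => async move { Ok(0) }.boxed(),
```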
Cargo.lock
@@ -2998,7 +2998,6 @@ dependencies = [
"git",
"git_hosting_providers",
"git_ui",
- "google_ai",
"gpui",
"gpui_tokio",
"hex",
crates/collab/Cargo.toml
@@ -34,7 +34,6 @@ dashmap.workspace = true
derive_more.workspace = true
envy = "0.4.2"
futures.workspace = true
-google_ai.workspace = true
hex.workspace = true
http_client.workspace = true
jsonwebtoken.workspace = true
crates/collab/src/rpc.rs
@@ -3,7 +3,7 @@ mod connection_pool;
use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
use crate::llm::LlmTokenClaims;
use crate::{
- AppState, Config, Error, RateLimit, Result, auth,
+ AppState, Error, Result, auth,
db::{
self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
@@ -33,7 +33,6 @@ use chrono::Utc;
use collections::{HashMap, HashSet};
pub use connection_pool::{ConnectionPool, ZedVersion};
use core::fmt::{self, Debug, Formatter};
-use http_client::HttpClient;
use reqwest_client::ReqwestClient;
use rpc::proto::split_repository_update;
use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
@@ -132,7 +131,6 @@ struct Session {
connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
app_state: Arc<AppState>,
supermaven_client: Option<Arc<SupermavenAdminApi>>,
- http_client: Arc<dyn HttpClient>,
/// The GeoIP country code for the user.
#[allow(unused)]
geoip_country_code: Option<String>,
@@ -425,17 +423,7 @@ impl Server {
.add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
.add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
.add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
- .add_message_handler(update_context)
- .add_request_handler({
- let app_state = app_state.clone();
- move |request, response, session| {
- let app_state = app_state.clone();
- async move {
- count_language_model_tokens(request, response, session, &app_state.config)
- .await
- }
- }
- });
+ .add_message_handler(update_context);
Arc::new(server)
}
@@ -764,7 +752,6 @@ impl Server {
peer: this.peer.clone(),
connection_pool: this.connection_pool.clone(),
app_state: this.app_state.clone(),
- http_client,
geoip_country_code,
system_id,
_executor: executor.clone(),
@@ -3683,100 +3670,6 @@ async fn acknowledge_buffer_version(
Ok(())
}
-async fn count_language_model_tokens(
- request: proto::CountLanguageModelTokens,
- response: Response<proto::CountLanguageModelTokens>,
- session: Session,
- config: &Config,
-) -> Result<()> {
- authorize_access_to_legacy_llm_endpoints(&session).await?;
-
- let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
- proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
- proto::Plan::Free | proto::Plan::ZedProTrial => {
- Box::new(FreeCountLanguageModelTokensRateLimit)
- }
- };
-
- session
- .app_state
- .rate_limiter
- .check(&*rate_limit, session.user_id())
- .await?;
-
- let result = match proto::LanguageModelProvider::from_i32(request.provider) {
- Some(proto::LanguageModelProvider::Google) => {
- let api_key = config
- .google_ai_api_key
- .as_ref()
- .context("no Google AI API key configured on the server")?;
- google_ai::count_tokens(
- session.http_client.as_ref(),
- google_ai::API_URL,
- api_key,
- serde_json::from_str(&request.request)?,
- )
- .await?
- }
- _ => return Err(anyhow!("unsupported provider"))?,
- };
-
- response.send(proto::CountLanguageModelTokensResponse {
- token_count: result.total_tokens as u32,
- })?;
-
- Ok(())
-}
-
-struct ZedProCountLanguageModelTokensRateLimit;
-
-impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
- fn capacity(&self) -> usize {
- std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
- .ok()
- .and_then(|v| v.parse().ok())
- .unwrap_or(600) // Picked arbitrarily
- }
-
- fn refill_duration(&self) -> chrono::Duration {
- chrono::Duration::hours(1)
- }
-
- fn db_name(&self) -> &'static str {
- "zed-pro:count-language-model-tokens"
- }
-}
-
-struct FreeCountLanguageModelTokensRateLimit;
-
-impl RateLimit for FreeCountLanguageModelTokensRateLimit {
- fn capacity(&self) -> usize {
- std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
- .ok()
- .and_then(|v| v.parse().ok())
- .unwrap_or(600 / 10) // Picked arbitrarily
- }
-
- fn refill_duration(&self) -> chrono::Duration {
- chrono::Duration::hours(1)
- }
-
- fn db_name(&self) -> &'static str {
- "free:count-language-model-tokens"
- }
-}
-
-/// This is leftover from before the LLM service.
-///
-/// The endpoints protected by this check will be moved there eventually.
-async fn authorize_access_to_legacy_llm_endpoints(session: &Session) -> Result<(), Error> {
- if session.is_staff() {
- Ok(())
- } else {
- Err(anyhow!("permission denied"))?
- }
-}
-
/// Get a Supermaven API key for the user
async fn get_supermaven_api_key(
_request: proto::GetSupermavenApiKey,
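For context on the rate-limiting code deleted above: the two removed types implemented collab's `RateLimit` trait. Reconstructed from those impls, the trait's shape is roughly the following sketch (the real definition lives in the collab crate and may differ; `Send + Sync` is inferred from the `Box<dyn RateLimit>` usage):

```rust
// Rough reconstruction of collab's RateLimit trait, inferred from the two
// impls deleted above; not the actual definition.
pub trait RateLimit: Send + Sync {
    /// Maximum number of requests per refill window.
    fn capacity(&self) -> usize;
    /// Length of the refill window (both removed impls used one hour).
    fn refill_duration(&self) -> chrono::Duration;
    /// Key under which usage is persisted, e.g. "free:count-language-model-tokens".
    fn db_name(&self) -> &'static str;
}
```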
crates/language_models/src/provider/cloud.rs
@@ -686,24 +686,7 @@ impl LanguageModel for CloudLanguageModel {
match self.model.clone() {
CloudModel::Anthropic(_) => count_anthropic_tokens(request, cx),
CloudModel::OpenAi(model) => count_open_ai_tokens(request, model, cx),
- CloudModel::Google(model) => {
- let client = self.client.clone();
- let request = into_google(request, model.id().into());
- let request = google_ai::CountTokensRequest {
- contents: request.contents,
- };
- async move {
- let request = serde_json::to_string(&request)?;
- let response = client
- .request(proto::CountLanguageModelTokens {
- provider: proto::LanguageModelProvider::Google as i32,
- request,
- })
- .await?;
- Ok(response.token_count as usize)
- }
- .boxed()
- }
+ CloudModel::Google(_model) => async move { Ok(0) }.boxed(),
}
}
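If token counting comes back as a direct provider call rather than an RPC, it might look like the following sketch. This assumes `google_ai::count_tokens` keeps the signature used in the deleted collab handler above (`count_tokens(http_client, api_url, api_key, request)`); `count_google_tokens` itself is a hypothetical helper, not part of this PR.

```rust
use anyhow::Result;
use http_client::HttpClient;

// Hypothetical sketch only: count tokens by calling Google AI directly,
// using the same google_ai API the removed collab handler used.
async fn count_google_tokens(
    http_client: &dyn HttpClient,
    api_key: &str,
    request: google_ai::CountTokensRequest,
) -> Result<usize> {
    let response =
        google_ai::count_tokens(http_client, google_ai::API_URL, api_key, request).await?;
    Ok(response.total_tokens as usize)
}
```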
crates/proto/proto/ai.proto
@@ -172,19 +172,3 @@ enum LanguageModelRole {
LanguageModelSystem = 2;
reserved 3;
}
-
-message CountLanguageModelTokens {
- LanguageModelProvider provider = 1;
- string request = 2;
-}
-
-message CountLanguageModelTokensResponse {
- uint32 token_count = 1;
-}
-
-enum LanguageModelProvider {
- Anthropic = 0;
- OpenAI = 1;
- Google = 2;
- Zed = 3;
-}
crates/proto/proto/zed.proto
@@ -206,9 +206,6 @@ message Envelope {
GetImplementation get_implementation = 162;
GetImplementationResponse get_implementation_response = 163;
- CountLanguageModelTokens count_language_model_tokens = 230;
- CountLanguageModelTokensResponse count_language_model_tokens_response = 231;
-
UpdateChannelMessage update_channel_message = 170;
ChannelMessageUpdate channel_message_update = 171;
@@ -397,6 +394,7 @@ message Envelope {
reserved 205 to 206;
reserved 221;
reserved 224 to 229;
+ reserved 230 to 231;
reserved 246;
reserved 270;
reserved 247 to 254;
crates/proto/src/proto.rs
@@ -50,8 +50,6 @@ messages!(
(CloseBuffer, Foreground),
(Commit, Background),
(CopyProjectEntry, Foreground),
- (CountLanguageModelTokens, Background),
- (CountLanguageModelTokensResponse, Background),
(CreateBufferForPeer, Foreground),
(CreateChannel, Foreground),
(CreateChannelResponse, Foreground),
@@ -374,7 +372,6 @@ request_messages!(
(PerformRename, PerformRenameResponse),
(Ping, Ack),
(PrepareRename, PrepareRenameResponse),
- (CountLanguageModelTokens, CountLanguageModelTokensResponse),
(RefreshInlayHints, Ack),
(RefreshCodeLens, Ack),
(RejoinChannelBuffers, RejoinChannelBuffersResponse),
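A note on these proto.rs changes: `messages!` declares each envelope variant (with its scheduling priority), while `request_messages!` pairs a request with its response type; that pairing is what made the deleted cloud.rs call typed. An illustrative reconstruction of the removed round trip, using names from the deleted code (the surrounding `client` and request types are Zed internals, so this is a sketch, not runnable on its own):

```rust
// Because request_messages! paired CountLanguageModelTokens with
// CountLanguageModelTokensResponse, client.request(..) resolved directly
// to the typed response.
let response = client
    .request(proto::CountLanguageModelTokens {
        provider: proto::LanguageModelProvider::Google as i32,
        request: serde_json::to_string(&count_tokens_request)?,
    })
    .await?;
let token_count = response.token_count as usize;
```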