From def342e35c9795a505663a84b597f5b708b2b832 Mon Sep 17 00:00:00 2001 From: Marshall Bowers Date: Mon, 24 Feb 2025 17:46:45 -0500 Subject: [PATCH] Remove dependents of `language_models` (#25511) This PR removes the dependents of the `language_models` crate. The following types have been moved from `language_models` to `language_model` to facilitate this: - `LlmApiToken` - `RefreshLlmTokenListener` - `MaxMonthlySpendReachedError` - `PaymentRequiredError` With this change only `zed` now depends on `language_models`. Release Notes: - N/A --- Cargo.lock | 5 +- crates/assistant2/Cargo.toml | 1 - crates/assistant2/src/thread.rs | 4 +- crates/assistant_context_editor/Cargo.toml | 1 - .../assistant_context_editor/src/context.rs | 4 +- crates/language_model/src/language_model.rs | 4 +- .../language_model/src/model/cloud_model.rs | 101 ++++++++++++++++- crates/language_models/Cargo.toml | 1 - crates/language_models/src/language_models.rs | 4 - crates/language_models/src/provider/cloud.rs | 105 +----------------- crates/zed/src/main.rs | 2 +- crates/zed/src/zed.rs | 2 +- crates/zeta/Cargo.toml | 2 +- crates/zeta/src/zeta.rs | 5 +- 14 files changed, 117 insertions(+), 124 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a932282cfff5a170cb492e33f64df4af4543188b..9b959de3362eb54bcd86f278fe0dea8da541639c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -462,7 +462,6 @@ dependencies = [ "language", "language_model", "language_model_selector", - "language_models", "log", "lsp", "markdown", @@ -517,7 +516,6 @@ dependencies = [ "language", "language_model", "language_model_selector", - "language_models", "languages", "log", "multi_buffer", @@ -7083,7 +7081,6 @@ dependencies = [ "smol", "strum", "theme", - "thiserror 1.0.69", "tiktoken-rs", "ui", "util", @@ -17190,7 +17187,7 @@ dependencies = [ "indoc", "inline_completion", "language", - "language_models", + "language_model", "log", "menu", "migrator", diff --git a/crates/assistant2/Cargo.toml b/crates/assistant2/Cargo.toml index 9a74a5e2fec07bc54c76a7671f5a72315fb02af4..13116c2ab5f0221a3f858388d70925b4a7df5168 100644 --- a/crates/assistant2/Cargo.toml +++ b/crates/assistant2/Cargo.toml @@ -46,7 +46,6 @@ itertools.workspace = true language.workspace = true language_model.workspace = true language_model_selector.workspace = true -language_models.workspace = true log.workspace = true lsp.workspace = true markdown.workspace = true diff --git a/crates/assistant2/src/thread.rs b/crates/assistant2/src/thread.rs index 9ccb0664807bf546649e3a2bcef1f86b2b173a5b..7eeb13f8726e42f8239dd8086c04d2e83a9df0c3 100644 --- a/crates/assistant2/src/thread.rs +++ b/crates/assistant2/src/thread.rs @@ -10,9 +10,9 @@ use gpui::{App, Context, EventEmitter, SharedString, Task}; use language_model::{ LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse, - LanguageModelToolUseId, MessageContent, Role, StopReason, + LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError, + Role, StopReason, }; -use language_models::provider::cloud::{MaxMonthlySpendReachedError, PaymentRequiredError}; use serde::{Deserialize, Serialize}; use util::{post_inc, TryFutureExt as _}; use uuid::Uuid; diff --git a/crates/assistant_context_editor/Cargo.toml b/crates/assistant_context_editor/Cargo.toml index aebadc4ca9d0296b62b605bf73dddd129ef35638..0feb0543bf85f088f8d8a434b8d3cb7877eea945 100644 --- a/crates/assistant_context_editor/Cargo.toml +++ b/crates/assistant_context_editor/Cargo.toml @@ -30,7 +30,6 @@ indexed_docs.workspace = true language.workspace = true language_model.workspace = true language_model_selector.workspace = true -language_models.workspace = true log.workspace = true multi_buffer.workspace = true open_ai.workspace = true diff --git a/crates/assistant_context_editor/src/context.rs b/crates/assistant_context_editor/src/context.rs index 0ebdb9049a94bb16a05ce430e1bce569345395a6..d6447572bc7e037f5457f9faa0573111df31909d 100644 --- a/crates/assistant_context_editor/src/context.rs +++ b/crates/assistant_context_editor/src/context.rs @@ -21,9 +21,9 @@ use language::{AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, P use language_model::{ report_assistant_event, LanguageModel, LanguageModelCacheConfiguration, LanguageModelCompletionEvent, LanguageModelImage, LanguageModelRegistry, LanguageModelRequest, - LanguageModelRequestMessage, LanguageModelToolUseId, MessageContent, Role, StopReason, + LanguageModelRequestMessage, LanguageModelToolUseId, MaxMonthlySpendReachedError, + MessageContent, PaymentRequiredError, Role, StopReason, }; -use language_models::provider::cloud::{MaxMonthlySpendReachedError, PaymentRequiredError}; use open_ai::Model as OpenAiModel; use paths::contexts_dir; use project::Project; diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs index 964d3b84e7a1b245487c95c39549d064e70d7b6f..72ff92142d978b46cdcb82c220ebdd0de248a589 100644 --- a/crates/language_model/src/language_model.rs +++ b/crates/language_model/src/language_model.rs @@ -9,6 +9,7 @@ mod telemetry; pub mod fake_provider; use anyhow::Result; +use client::Client; use futures::FutureExt; use futures::{future::BoxFuture, stream::BoxStream, StreamExt, TryStreamExt as _}; use gpui::{AnyElement, AnyView, App, AsyncApp, SharedString, Task, Window}; @@ -29,8 +30,9 @@ pub use crate::telemetry::*; pub const ZED_CLOUD_PROVIDER_ID: &str = "zed.dev"; -pub fn init(cx: &mut App) { +pub fn init(client: Arc, cx: &mut App) { registry::init(cx); + RefreshLlmTokenListener::register(client.clone(), cx); } /// The availability of a [`LanguageModel`]. diff --git a/crates/language_model/src/model/cloud_model.rs b/crates/language_model/src/model/cloud_model.rs index 3d935b7d9d9b50a641bb78f7c86a8ffa76fb2c1d..2a09ab5c4219bca4609c8a229b8fd34a3000de65 100644 --- a/crates/language_model/src/model/cloud_model.rs +++ b/crates/language_model/src/model/cloud_model.rs @@ -1,7 +1,17 @@ -use proto::Plan; +use std::fmt; +use std::sync::Arc; + +use anyhow::Result; +use client::Client; +use gpui::{ + App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Global, ReadGlobal as _, +}; +use proto::{Plan, TypedEnvelope}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; +use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard}; use strum::EnumIter; +use thiserror::Error; use ui::IconName; use crate::LanguageModelAvailability; @@ -102,3 +112,92 @@ impl CloudModel { } } } + +#[derive(Error, Debug)] +pub struct PaymentRequiredError; + +impl fmt::Display for PaymentRequiredError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "Payment required to use this language model. Please upgrade your account." + ) + } +} + +#[derive(Error, Debug)] +pub struct MaxMonthlySpendReachedError; + +impl fmt::Display for MaxMonthlySpendReachedError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "Maximum spending limit reached for this month. For more usage, increase your spending limit." + ) + } +} + +#[derive(Clone, Default)] +pub struct LlmApiToken(Arc>>); + +impl LlmApiToken { + pub async fn acquire(&self, client: &Arc) -> Result { + let lock = self.0.upgradable_read().await; + if let Some(token) = lock.as_ref() { + Ok(token.to_string()) + } else { + Self::fetch(RwLockUpgradableReadGuard::upgrade(lock).await, client).await + } + } + + pub async fn refresh(&self, client: &Arc) -> Result { + Self::fetch(self.0.write().await, client).await + } + + async fn fetch<'a>( + mut lock: RwLockWriteGuard<'a, Option>, + client: &Arc, + ) -> Result { + let response = client.request(proto::GetLlmToken {}).await?; + *lock = Some(response.token.clone()); + Ok(response.token.clone()) + } +} + +struct GlobalRefreshLlmTokenListener(Entity); + +impl Global for GlobalRefreshLlmTokenListener {} + +pub struct RefreshLlmTokenEvent; + +pub struct RefreshLlmTokenListener { + _llm_token_subscription: client::Subscription, +} + +impl EventEmitter for RefreshLlmTokenListener {} + +impl RefreshLlmTokenListener { + pub fn register(client: Arc, cx: &mut App) { + let listener = cx.new(|cx| RefreshLlmTokenListener::new(client, cx)); + cx.set_global(GlobalRefreshLlmTokenListener(listener)); + } + + pub fn global(cx: &App) -> Entity { + GlobalRefreshLlmTokenListener::global(cx).0.clone() + } + + fn new(client: Arc, cx: &mut Context) -> Self { + Self { + _llm_token_subscription: client + .add_message_handler(cx.weak_entity(), Self::handle_refresh_llm_token), + } + } + + async fn handle_refresh_llm_token( + this: Entity, + _: TypedEnvelope, + mut cx: AsyncApp, + ) -> Result<()> { + this.update(&mut cx, |_this, cx| cx.emit(RefreshLlmTokenEvent)) + } +} diff --git a/crates/language_models/Cargo.toml b/crates/language_models/Cargo.toml index cc64a2b5c0eeb21036edd36cc16565983a44d6b5..7cec7fb4a073b63a466bbe1e134f74e1871749c4 100644 --- a/crates/language_models/Cargo.toml +++ b/crates/language_models/Cargo.toml @@ -41,7 +41,6 @@ settings.workspace = true smol.workspace = true strum.workspace = true theme.workspace = true -thiserror.workspace = true tiktoken-rs.workspace = true ui.workspace = true util.workspace = true diff --git a/crates/language_models/src/language_models.rs b/crates/language_models/src/language_models.rs index 6a963875e6eaf158d0b03822b115e021516e816b..6ed0e30959afbde4d3f6e932423fdd89e866de8b 100644 --- a/crates/language_models/src/language_models.rs +++ b/crates/language_models/src/language_models.rs @@ -11,8 +11,6 @@ mod settings; use crate::provider::anthropic::AnthropicLanguageModelProvider; use crate::provider::cloud::CloudLanguageModelProvider; -pub use crate::provider::cloud::LlmApiToken; -pub use crate::provider::cloud::RefreshLlmTokenListener; use crate::provider::copilot_chat::CopilotChatLanguageModelProvider; use crate::provider::google::GoogleLanguageModelProvider; use crate::provider::lmstudio::LmStudioLanguageModelProvider; @@ -37,8 +35,6 @@ fn register_language_model_providers( ) { use feature_flags::FeatureFlagAppExt; - RefreshLlmTokenListener::register(client.clone(), cx); - registry.register_provider( AnthropicLanguageModelProvider::new(client.http_client(), cx), cx, diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs index cbdf1785e01c95a3efeff5020aa7814a8ca74fcb..9c9401532a8d3b83900a4ce8d9221538ac7b4928 100644 --- a/crates/language_models/src/provider/cloud.rs +++ b/crates/language_models/src/provider/cloud.rs @@ -10,10 +10,7 @@ use futures::{ future::BoxFuture, stream::BoxStream, AsyncBufReadExt, FutureExt, Stream, StreamExt, TryStreamExt as _, }; -use gpui::{ - AnyElement, AnyView, App, AsyncApp, Context, Entity, EventEmitter, Global, ReadGlobal, - Subscription, Task, -}; +use gpui::{AnyElement, AnyView, App, AsyncApp, Context, Entity, Subscription, Task}; use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode}; use language_model::{ AuthenticateError, CloudModel, LanguageModel, LanguageModelCacheConfiguration, LanguageModelId, @@ -22,24 +19,19 @@ use language_model::{ ZED_CLOUD_PROVIDER_ID, }; use language_model::{ - LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider, + LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider, LlmApiToken, + MaxMonthlySpendReachedError, PaymentRequiredError, RefreshLlmTokenListener, }; -use proto::TypedEnvelope; use schemars::JsonSchema; use serde::{de::DeserializeOwned, Deserialize, Serialize}; use serde_json::value::RawValue; use settings::{Settings, SettingsStore}; -use smol::{ - io::{AsyncReadExt, BufReader}, - lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard}, -}; -use std::fmt; +use smol::io::{AsyncReadExt, BufReader}; use std::{ future, sync::{Arc, LazyLock}, }; use strum::IntoEnumIterator; -use thiserror::Error; use ui::{prelude::*, TintColor}; use crate::provider::anthropic::{ @@ -101,44 +93,6 @@ pub struct AvailableModel { pub extra_beta_headers: Vec, } -struct GlobalRefreshLlmTokenListener(Entity); - -impl Global for GlobalRefreshLlmTokenListener {} - -pub struct RefreshLlmTokenEvent; - -pub struct RefreshLlmTokenListener { - _llm_token_subscription: client::Subscription, -} - -impl EventEmitter for RefreshLlmTokenListener {} - -impl RefreshLlmTokenListener { - pub fn register(client: Arc, cx: &mut App) { - let listener = cx.new(|cx| RefreshLlmTokenListener::new(client, cx)); - cx.set_global(GlobalRefreshLlmTokenListener(listener)); - } - - pub fn global(cx: &App) -> Entity { - GlobalRefreshLlmTokenListener::global(cx).0.clone() - } - - fn new(client: Arc, cx: &mut Context) -> Self { - Self { - _llm_token_subscription: client - .add_message_handler(cx.weak_entity(), Self::handle_refresh_llm_token), - } - } - - async fn handle_refresh_llm_token( - this: Entity, - _: TypedEnvelope, - mut cx: AsyncApp, - ) -> Result<()> { - this.update(&mut cx, |_this, cx| cx.emit(RefreshLlmTokenEvent)) - } -} - pub struct CloudLanguageModelProvider { client: Arc, state: gpui::Entity, @@ -475,33 +429,6 @@ pub struct CloudLanguageModel { request_limiter: RateLimiter, } -#[derive(Clone, Default)] -pub struct LlmApiToken(Arc>>); - -#[derive(Error, Debug)] -pub struct PaymentRequiredError; - -impl fmt::Display for PaymentRequiredError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!( - f, - "Payment required to use this language model. Please upgrade your account." - ) - } -} - -#[derive(Error, Debug)] -pub struct MaxMonthlySpendReachedError; - -impl fmt::Display for MaxMonthlySpendReachedError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!( - f, - "Maximum spending limit reached for this month. For more usage, increase your spending limit." - ) - } -} - impl CloudLanguageModel { async fn perform_llm_completion( client: Arc, @@ -847,30 +774,6 @@ fn response_lines( ) } -impl LlmApiToken { - pub async fn acquire(&self, client: &Arc) -> Result { - let lock = self.0.upgradable_read().await; - if let Some(token) = lock.as_ref() { - Ok(token.to_string()) - } else { - Self::fetch(RwLockUpgradableReadGuard::upgrade(lock).await, client).await - } - } - - pub async fn refresh(&self, client: &Arc) -> Result { - Self::fetch(self.0.write().await, client).await - } - - async fn fetch<'a>( - mut lock: RwLockWriteGuard<'a, Option>, - client: &Arc, - ) -> Result { - let response = client.request(proto::GetLlmToken {}).await?; - *lock = Some(response.token.clone()); - Ok(response.token.clone()) - } -} - struct ConfigurationView { state: gpui::Entity, } diff --git a/crates/zed/src/main.rs b/crates/zed/src/main.rs index 0f537cdee1afce855681678962036e22910c5fd8..b25052f8399a5c9827d88108a20119f6b133fd94 100644 --- a/crates/zed/src/main.rs +++ b/crates/zed/src/main.rs @@ -436,7 +436,7 @@ fn main() { cx, ); supermaven::init(app_state.client.clone(), cx); - language_model::init(cx); + language_model::init(app_state.client.clone(), cx); language_models::init( app_state.user_store.clone(), app_state.client.clone(), diff --git a/crates/zed/src/zed.rs b/crates/zed/src/zed.rs index 6169fa75d018d05682c9b2943ed37bb6ae5e094e..8ee2a6cdb8a5027ce0231c6d795db5d5d993201b 100644 --- a/crates/zed/src/zed.rs +++ b/crates/zed/src/zed.rs @@ -4237,7 +4237,7 @@ mod tests { cx, ); image_viewer::init(cx); - language_model::init(cx); + language_model::init(app_state.client.clone(), cx); language_models::init( app_state.user_store.clone(), app_state.client.clone(), diff --git a/crates/zeta/Cargo.toml b/crates/zeta/Cargo.toml index 515624962a6b94667f06b64c803b7ecdb9e22a13..6621417b1885f6674e1cd04298173c0edbba2c68 100644 --- a/crates/zeta/Cargo.toml +++ b/crates/zeta/Cargo.toml @@ -33,7 +33,7 @@ http_client.workspace = true indoc.workspace = true inline_completion.workspace = true language.workspace = true -language_models.workspace = true +language_model.workspace = true log.workspace = true menu.workspace = true migrator.workspace = true diff --git a/crates/zeta/src/zeta.rs b/crates/zeta/src/zeta.rs index 82fd0e9991d19fa7bcef859589d2834acc064397..e6cda8e7e5f26e2056a0ae4b732e00a3558e0a62 100644 --- a/crates/zeta/src/zeta.rs +++ b/crates/zeta/src/zeta.rs @@ -31,7 +31,7 @@ use input_excerpt::excerpt_for_cursor_position; use language::{ text_diff, Anchor, Buffer, BufferSnapshot, EditPreview, OffsetRangeExt, ToOffset, ToPoint, }; -use language_models::LlmApiToken; +use language_model::{LlmApiToken, RefreshLlmTokenListener}; use postage::watch; use project::Project; use release_channel::AppVersion; @@ -244,7 +244,7 @@ impl Zeta { user_store: Entity, cx: &mut Context, ) -> Self { - let refresh_llm_token_listener = language_models::RefreshLlmTokenListener::global(cx); + let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx); let data_collection_choice = Self::load_data_collection_choices(); let data_collection_choice = cx.new(|_| data_collection_choice); @@ -1649,7 +1649,6 @@ mod tests { use http_client::FakeHttpClient; use indoc::indoc; use language::Point; - use language_models::RefreshLlmTokenListener; use rpc::proto; use settings::SettingsStore;