Detailed changes
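These hunks hoist the cloud LLM plumbing out of the `language_models` crate's cloud provider and into the shared `language_model` crate: `PaymentRequiredError`, `MaxMonthlySpendReachedError`, the `LlmApiToken` cache, and `RefreshLlmTokenListener` now live in `language_model`, and `language_model::init` takes the `Client` so it can register the listener itself. Downstream crates — the assistant crates and `zeta` — can then drop their `language_models` dependency.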
@@ -462,7 +462,6 @@ dependencies = [
"language",
"language_model",
"language_model_selector",
- "language_models",
"log",
"lsp",
"markdown",
@@ -517,7 +516,6 @@ dependencies = [
"language",
"language_model",
"language_model_selector",
- "language_models",
"languages",
"log",
"multi_buffer",
@@ -7083,7 +7081,6 @@ dependencies = [
"smol",
"strum",
"theme",
- "thiserror 1.0.69",
"tiktoken-rs",
"ui",
"util",
@@ -17190,7 +17187,7 @@ dependencies = [
"indoc",
"inline_completion",
"language",
- "language_models",
+ "language_model",
"log",
"menu",
"migrator",
@@ -46,7 +46,6 @@ itertools.workspace = true
language.workspace = true
language_model.workspace = true
language_model_selector.workspace = true
-language_models.workspace = true
log.workspace = true
lsp.workspace = true
markdown.workspace = true
@@ -10,9 +10,9 @@ use gpui::{App, Context, EventEmitter, SharedString, Task};
use language_model::{
    LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
    LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse,
-    LanguageModelToolUseId, MessageContent, Role, StopReason,
+    LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError,
+    Role, StopReason,
};
-use language_models::provider::cloud::{MaxMonthlySpendReachedError, PaymentRequiredError};
use serde::{Deserialize, Serialize};
use util::{post_inc, TryFutureExt as _};
use uuid::Uuid;
@@ -30,7 +30,6 @@ indexed_docs.workspace = true
language.workspace = true
language_model.workspace = true
language_model_selector.workspace = true
-language_models.workspace = true
log.workspace = true
multi_buffer.workspace = true
open_ai.workspace = true
@@ -21,9 +21,9 @@ use language::{AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, P
use language_model::{
    report_assistant_event, LanguageModel, LanguageModelCacheConfiguration,
    LanguageModelCompletionEvent, LanguageModelImage, LanguageModelRegistry, LanguageModelRequest,
-    LanguageModelRequestMessage, LanguageModelToolUseId, MessageContent, Role, StopReason,
+    LanguageModelRequestMessage, LanguageModelToolUseId, MaxMonthlySpendReachedError,
+    MessageContent, PaymentRequiredError, Role, StopReason,
};
-use language_models::provider::cloud::{MaxMonthlySpendReachedError, PaymentRequiredError};
use open_ai::Model as OpenAiModel;
use paths::contexts_dir;
use project::Project;
@@ -9,6 +9,7 @@ mod telemetry;
pub mod fake_provider;

use anyhow::Result;
+use client::Client;
use futures::FutureExt;
use futures::{future::BoxFuture, stream::BoxStream, StreamExt, TryStreamExt as _};
use gpui::{AnyElement, AnyView, App, AsyncApp, SharedString, Task, Window};
@@ -29,8 +30,9 @@ pub use crate::telemetry::*;

pub const ZED_CLOUD_PROVIDER_ID: &str = "zed.dev";

-pub fn init(cx: &mut App) {
+pub fn init(client: Arc<Client>, cx: &mut App) {
    registry::init(cx);
+    RefreshLlmTokenListener::register(client.clone(), cx);
}

/// The availability of a [`LanguageModel`].
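Since `init` now needs the `Client` in order to register the token listener up front, every call site has to thread it through; the `zed` hunks near the end of this diff update both the production and test initialization paths accordingly.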
@@ -1,7 +1,17 @@
-use proto::Plan;
+use std::fmt;
+use std::sync::Arc;
+
+use anyhow::Result;
+use client::Client;
+use gpui::{
+    App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Global, ReadGlobal as _,
+};
+use proto::{Plan, TypedEnvelope};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
+use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard};
use strum::EnumIter;
+use thiserror::Error;
use ui::IconName;

use crate::LanguageModelAvailability;
@@ -102,3 +112,92 @@ impl CloudModel {
        }
    }
}
+
+#[derive(Error, Debug)]
+pub struct PaymentRequiredError;
+
+impl fmt::Display for PaymentRequiredError {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        write!(
+            f,
+            "Payment required to use this language model. Please upgrade your account."
+        )
+    }
+}
+
+#[derive(Error, Debug)]
+pub struct MaxMonthlySpendReachedError;
+
+impl fmt::Display for MaxMonthlySpendReachedError {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        write!(
+            f,
+            "Maximum spending limit reached for this month. For more usage, increase your spending limit."
+        )
+    }
+}
+
+#[derive(Clone, Default)]
+pub struct LlmApiToken(Arc<RwLock<Option<String>>>);
+
+impl LlmApiToken {
+    pub async fn acquire(&self, client: &Arc<Client>) -> Result<String> {
+        let lock = self.0.upgradable_read().await;
+        if let Some(token) = lock.as_ref() {
+            Ok(token.to_string())
+        } else {
+            Self::fetch(RwLockUpgradableReadGuard::upgrade(lock).await, client).await
+        }
+    }
+
+    pub async fn refresh(&self, client: &Arc<Client>) -> Result<String> {
+        Self::fetch(self.0.write().await, client).await
+    }
+
+    async fn fetch<'a>(
+        mut lock: RwLockWriteGuard<'a, Option<String>>,
+        client: &Arc<Client>,
+    ) -> Result<String> {
+        let response = client.request(proto::GetLlmToken {}).await?;
+        *lock = Some(response.token.clone());
+        Ok(response.token.clone())
+    }
+}
+
+struct GlobalRefreshLlmTokenListener(Entity<RefreshLlmTokenListener>);
+
+impl Global for GlobalRefreshLlmTokenListener {}
+
+pub struct RefreshLlmTokenEvent;
+
+pub struct RefreshLlmTokenListener {
+    _llm_token_subscription: client::Subscription,
+}
+
+impl EventEmitter<RefreshLlmTokenEvent> for RefreshLlmTokenListener {}
+
+impl RefreshLlmTokenListener {
+    pub fn register(client: Arc<Client>, cx: &mut App) {
+        let listener = cx.new(|cx| RefreshLlmTokenListener::new(client, cx));
+        cx.set_global(GlobalRefreshLlmTokenListener(listener));
+    }
+
+    pub fn global(cx: &App) -> Entity<Self> {
+        GlobalRefreshLlmTokenListener::global(cx).0.clone()
+    }
+
+    fn new(client: Arc<Client>, cx: &mut Context<Self>) -> Self {
+        Self {
+            _llm_token_subscription: client
+                .add_message_handler(cx.weak_entity(), Self::handle_refresh_llm_token),
+        }
+    }
+
+    async fn handle_refresh_llm_token(
+        this: Entity<Self>,
+        _: TypedEnvelope<proto::RefreshLlmToken>,
+        mut cx: AsyncApp,
+    ) -> Result<()> {
+        this.update(&mut cx, |_this, cx| cx.emit(RefreshLlmTokenEvent))
+    }
+}
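For orientation, here is a minimal sketch of how the relocated token cache is meant to be consumed (a hypothetical caller, not part of this diff; `token_for_request` and the `expired` flag are illustrative):

```rust
use std::sync::Arc;

use anyhow::Result;
use client::Client;
use language_model::LlmApiToken;

// Hypothetical helper: `acquire` returns the cached token, fetching one via
// `proto::GetLlmToken` only on first use; `refresh` always re-fetches, e.g.
// after the server has invalidated the current token.
async fn token_for_request(
    token: &LlmApiToken,
    client: &Arc<Client>,
    expired: bool,
) -> Result<String> {
    if expired {
        token.refresh(client).await
    } else {
        token.acquire(client).await
    }
}
```

Consumers that need to react to a server-initiated refresh fetch the listener entity with `RefreshLlmTokenListener::global(cx)` and subscribe to its `RefreshLlmTokenEvent`; the `zeta` hunks at the end of this diff switch to exactly that path.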
@@ -41,7 +41,6 @@ settings.workspace = true
smol.workspace = true
strum.workspace = true
theme.workspace = true
-thiserror.workspace = true
tiktoken-rs.workspace = true
ui.workspace = true
util.workspace = true
@@ -11,8 +11,6 @@ mod settings;

use crate::provider::anthropic::AnthropicLanguageModelProvider;
use crate::provider::cloud::CloudLanguageModelProvider;
-pub use crate::provider::cloud::LlmApiToken;
-pub use crate::provider::cloud::RefreshLlmTokenListener;
use crate::provider::copilot_chat::CopilotChatLanguageModelProvider;
use crate::provider::google::GoogleLanguageModelProvider;
use crate::provider::lmstudio::LmStudioLanguageModelProvider;
@@ -37,8 +35,6 @@ fn register_language_model_providers(
) {
    use feature_flags::FeatureFlagAppExt;

-    RefreshLlmTokenListener::register(client.clone(), cx);
-
    registry.register_provider(
        AnthropicLanguageModelProvider::new(client.http_client(), cx),
        cx,
@@ -10,10 +10,7 @@ use futures::{
    future::BoxFuture, stream::BoxStream, AsyncBufReadExt, FutureExt, Stream, StreamExt,
    TryStreamExt as _,
};
-use gpui::{
-    AnyElement, AnyView, App, AsyncApp, Context, Entity, EventEmitter, Global, ReadGlobal,
-    Subscription, Task,
-};
+use gpui::{AnyElement, AnyView, App, AsyncApp, Context, Entity, Subscription, Task};
use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode};
use language_model::{
    AuthenticateError, CloudModel, LanguageModel, LanguageModelCacheConfiguration, LanguageModelId,
@@ -22,24 +19,19 @@ use language_model::{
    ZED_CLOUD_PROVIDER_ID,
};
use language_model::{
-    LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider,
+    LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider, LlmApiToken,
+    MaxMonthlySpendReachedError, PaymentRequiredError, RefreshLlmTokenListener,
};
-use proto::TypedEnvelope;
use schemars::JsonSchema;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde_json::value::RawValue;
use settings::{Settings, SettingsStore};
-use smol::{
-    io::{AsyncReadExt, BufReader},
-    lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard},
-};
-use std::fmt;
+use smol::io::{AsyncReadExt, BufReader};
use std::{
    future,
    sync::{Arc, LazyLock},
};
use strum::IntoEnumIterator;
-use thiserror::Error;
use ui::{prelude::*, TintColor};

use crate::provider::anthropic::{
@@ -101,44 +93,6 @@ pub struct AvailableModel {
    pub extra_beta_headers: Vec<String>,
}

-struct GlobalRefreshLlmTokenListener(Entity<RefreshLlmTokenListener>);
-
-impl Global for GlobalRefreshLlmTokenListener {}
-
-pub struct RefreshLlmTokenEvent;
-
-pub struct RefreshLlmTokenListener {
-    _llm_token_subscription: client::Subscription,
-}
-
-impl EventEmitter<RefreshLlmTokenEvent> for RefreshLlmTokenListener {}
-
-impl RefreshLlmTokenListener {
-    pub fn register(client: Arc<Client>, cx: &mut App) {
-        let listener = cx.new(|cx| RefreshLlmTokenListener::new(client, cx));
-        cx.set_global(GlobalRefreshLlmTokenListener(listener));
-    }
-
-    pub fn global(cx: &App) -> Entity<Self> {
-        GlobalRefreshLlmTokenListener::global(cx).0.clone()
-    }
-
-    fn new(client: Arc<Client>, cx: &mut Context<Self>) -> Self {
-        Self {
-            _llm_token_subscription: client
-                .add_message_handler(cx.weak_entity(), Self::handle_refresh_llm_token),
-        }
-    }
-
-    async fn handle_refresh_llm_token(
-        this: Entity<Self>,
-        _: TypedEnvelope<proto::RefreshLlmToken>,
-        mut cx: AsyncApp,
-    ) -> Result<()> {
-        this.update(&mut cx, |_this, cx| cx.emit(RefreshLlmTokenEvent))
-    }
-}
-
pub struct CloudLanguageModelProvider {
    client: Arc<Client>,
    state: gpui::Entity<State>,
@@ -475,33 +429,6 @@ pub struct CloudLanguageModel {
    request_limiter: RateLimiter,
}

-#[derive(Clone, Default)]
-pub struct LlmApiToken(Arc<RwLock<Option<String>>>);
-
-#[derive(Error, Debug)]
-pub struct PaymentRequiredError;
-
-impl fmt::Display for PaymentRequiredError {
-    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
-        write!(
-            f,
-            "Payment required to use this language model. Please upgrade your account."
-        )
-    }
-}
-
-#[derive(Error, Debug)]
-pub struct MaxMonthlySpendReachedError;
-
-impl fmt::Display for MaxMonthlySpendReachedError {
-    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
-        write!(
-            f,
-            "Maximum spending limit reached for this month. For more usage, increase your spending limit."
-        )
-    }
-}
-
impl CloudLanguageModel {
    async fn perform_llm_completion(
        client: Arc<Client>,
@@ -847,30 +774,6 @@ fn response_lines<T: DeserializeOwned>(
    )
}

-impl LlmApiToken {
-    pub async fn acquire(&self, client: &Arc<Client>) -> Result<String> {
-        let lock = self.0.upgradable_read().await;
-        if let Some(token) = lock.as_ref() {
-            Ok(token.to_string())
-        } else {
-            Self::fetch(RwLockUpgradableReadGuard::upgrade(lock).await, client).await
-        }
-    }
-
-    pub async fn refresh(&self, client: &Arc<Client>) -> Result<String> {
-        Self::fetch(self.0.write().await, client).await
-    }
-
-    async fn fetch<'a>(
-        mut lock: RwLockWriteGuard<'a, Option<String>>,
-        client: &Arc<Client>,
-    ) -> Result<String> {
-        let response = client.request(proto::GetLlmToken {}).await?;
-        *lock = Some(response.token.clone());
-        Ok(response.token.clone())
-    }
-}
-
struct ConfigurationView {
    state: gpui::Entity<State>,
}
@@ -436,7 +436,7 @@ fn main() {
            cx,
        );
        supermaven::init(app_state.client.clone(), cx);
-        language_model::init(cx);
+        language_model::init(app_state.client.clone(), cx);
        language_models::init(
            app_state.user_store.clone(),
            app_state.client.clone(),
@@ -4237,7 +4237,7 @@ mod tests {
                cx,
            );
            image_viewer::init(cx);
-            language_model::init(cx);
+            language_model::init(app_state.client.clone(), cx);
            language_models::init(
                app_state.user_store.clone(),
                app_state.client.clone(),
@@ -33,7 +33,7 @@ http_client.workspace = true
indoc.workspace = true
inline_completion.workspace = true
language.workspace = true
-language_models.workspace = true
+language_model.workspace = true
log.workspace = true
menu.workspace = true
migrator.workspace = true
@@ -31,7 +31,7 @@ use input_excerpt::excerpt_for_cursor_position;
use language::{
    text_diff, Anchor, Buffer, BufferSnapshot, EditPreview, OffsetRangeExt, ToOffset, ToPoint,
};
-use language_models::LlmApiToken;
+use language_model::{LlmApiToken, RefreshLlmTokenListener};
use postage::watch;
use project::Project;
use release_channel::AppVersion;
@@ -244,7 +244,7 @@ impl Zeta {
        user_store: Entity<UserStore>,
        cx: &mut Context<Self>,
    ) -> Self {
-        let refresh_llm_token_listener = language_models::RefreshLlmTokenListener::global(cx);
+        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);

        let data_collection_choice = Self::load_data_collection_choices();
        let data_collection_choice = cx.new(|_| data_collection_choice);
@@ -1649,7 +1649,6 @@ mod tests {
    use http_client::FakeHttpClient;
    use indoc::indoc;
    use language::Point;
-    use language_models::RefreshLlmTokenListener;
    use rpc::proto;
    use settings::SettingsStore;