diff --git a/Cargo.lock b/Cargo.lock index 4e4d86b947be1f68d03b225d4a62747659c99bf8..b1ff28fcf52e118830e2100d35a6cdbca6f6f013 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -19851,6 +19851,7 @@ version = "0.1.0" dependencies = [ "anyhow", "client", + "cloud_api_types", "cloud_llm_client", "futures 0.3.31", "gpui", diff --git a/crates/cloud_api_client/src/cloud_api_client.rs b/crates/cloud_api_client/src/cloud_api_client.rs index f485e2d20c619715ea342fccd2a5cec0ecaa6f4e..13d67838b216f4990f15ec22c1701aa7aef9dbf2 100644 --- a/crates/cloud_api_client/src/cloud_api_client.rs +++ b/crates/cloud_api_client/src/cloud_api_client.rs @@ -9,7 +9,9 @@ use futures::AsyncReadExt as _; use gpui::{App, Task}; use gpui_tokio::Tokio; use http_client::http::request; -use http_client::{AsyncBody, HttpClientWithUrl, HttpRequestExt, Method, Request, StatusCode}; +use http_client::{ + AsyncBody, HttpClientWithUrl, HttpRequestExt, Json, Method, Request, StatusCode, +}; use parking_lot::RwLock; use thiserror::Error; use yawc::WebSocket; @@ -141,6 +143,7 @@ impl CloudApiClient { pub async fn create_llm_token( &self, system_id: Option, + organization_id: Option, ) -> Result { let request_builder = Request::builder() .method(Method::POST) @@ -153,7 +156,10 @@ impl CloudApiClient { builder.header(ZED_SYSTEM_ID_HEADER_NAME, system_id) }); - let request = self.build_request(request_builder, AsyncBody::default())?; + let request = self.build_request( + request_builder, + Json(CreateLlmTokenBody { organization_id }), + )?; let mut response = self.http_client.send(request).await?; diff --git a/crates/cloud_api_types/src/cloud_api_types.rs b/crates/cloud_api_types/src/cloud_api_types.rs index 2d457fc6630d5b32f049e67a6a460047e925973a..42d3442bfc016f5cb1a39ba421ccdfe386bcbc65 100644 --- a/crates/cloud_api_types/src/cloud_api_types.rs +++ b/crates/cloud_api_types/src/cloud_api_types.rs @@ -52,6 +52,12 @@ pub struct AcceptTermsOfServiceResponse { #[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] pub struct LlmToken(pub String); +#[derive(Debug, Default, PartialEq, Clone, Serialize, Deserialize)] +pub struct CreateLlmTokenBody { + #[serde(default)] + pub organization_id: Option, +} + #[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] pub struct CreateLlmTokenResponse { pub token: LlmToken, diff --git a/crates/edit_prediction/src/edit_prediction.rs b/crates/edit_prediction/src/edit_prediction.rs index 33c3ea1e56648c73682e06f685f91f54344200d6..6b2019aa30030b0852f74bc851e2012feac4f0e2 100644 --- a/crates/edit_prediction/src/edit_prediction.rs +++ b/crates/edit_prediction/src/edit_prediction.rs @@ -1,7 +1,7 @@ use anyhow::Result; use arrayvec::ArrayVec; use client::{Client, EditPredictionUsage, UserStore}; -use cloud_api_types::SubmitEditPredictionFeedbackBody; +use cloud_api_types::{OrganizationId, SubmitEditPredictionFeedbackBody}; use cloud_llm_client::predict_edits_v3::{ PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse, }; @@ -143,7 +143,7 @@ pub struct EditPredictionStore { pub sweep_ai: SweepAi, pub mercury: Mercury, data_collection_choice: DataCollectionChoice, - reject_predictions_tx: mpsc::UnboundedSender, + reject_predictions_tx: mpsc::UnboundedSender, settled_predictions_tx: mpsc::UnboundedSender, shown_predictions: VecDeque, rated_predictions: HashSet, @@ -151,6 +151,11 @@ pub struct EditPredictionStore { settled_event_callback: Option>, } +pub(crate) struct EditPredictionRejectionPayload { + rejection: EditPredictionRejection, + organization_id: Option, +} + #[derive(Copy, Clone, PartialEq, Eq)] pub enum EditPredictionModel { Zeta, @@ -719,8 +724,13 @@ impl EditPredictionStore { |this, _listener, _event, cx| { let client = this.client.clone(); let llm_token = this.llm_token.clone(); + let organization_id = this + .user_store + .read(cx) + .current_organization() + .map(|organization| organization.id.clone()); cx.spawn(async move |_this, _cx| { - llm_token.refresh(&client).await?; + llm_token.refresh(&client, organization_id).await?; anyhow::Ok(()) }) .detach_and_log_err(cx); @@ -781,11 +791,17 @@ impl EditPredictionStore { let client = self.client.clone(); let llm_token = self.llm_token.clone(); let app_version = AppVersion::global(cx); + let organization_id = self + .user_store + .read(cx) + .current_organization() + .map(|organization| organization.id.clone()); + cx.spawn(async move |this, cx| { let experiments = cx .background_spawn(async move { let http_client = client.http_client(); - let token = llm_token.acquire(&client).await?; + let token = llm_token.acquire(&client, organization_id).await?; let url = http_client.build_zed_llm_url("/edit_prediction_experiments", &[])?; let request = http_client::Request::builder() .method(Method::GET) @@ -1424,7 +1440,7 @@ impl EditPredictionStore { } async fn handle_rejected_predictions( - rx: UnboundedReceiver, + rx: UnboundedReceiver, client: Arc, llm_token: LlmApiToken, app_version: Version, @@ -1433,7 +1449,11 @@ impl EditPredictionStore { let mut rx = std::pin::pin!(rx.peekable()); let mut batched = Vec::new(); - while let Some(rejection) = rx.next().await { + while let Some(EditPredictionRejectionPayload { + rejection, + organization_id, + }) = rx.next().await + { batched.push(rejection); if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 { @@ -1471,6 +1491,7 @@ impl EditPredictionStore { }, client.clone(), llm_token.clone(), + organization_id, app_version.clone(), true, ) @@ -1676,13 +1697,23 @@ impl EditPredictionStore { all_language_settings(None, cx).edit_predictions.provider, EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi ); + if is_cloud { + let organization_id = self + .user_store + .read(cx) + .current_organization() + .map(|organization| organization.id.clone()); + self.reject_predictions_tx - .unbounded_send(EditPredictionRejection { - request_id: prediction_id.to_string(), - reason, - was_shown, - model_version, + .unbounded_send(EditPredictionRejectionPayload { + rejection: EditPredictionRejection { + request_id: prediction_id.to_string(), + reason, + was_shown, + model_version, + }, + organization_id, }) .log_err(); } @@ -2337,6 +2368,7 @@ impl EditPredictionStore { client: Arc, custom_url: Option>, llm_token: LlmApiToken, + organization_id: Option, app_version: Version, ) -> Result<(RawCompletionResponse, Option)> { let url = if let Some(custom_url) = custom_url { @@ -2356,6 +2388,7 @@ impl EditPredictionStore { }, client, llm_token, + organization_id, app_version, true, ) @@ -2366,6 +2399,7 @@ impl EditPredictionStore { input: ZetaPromptInput, client: Arc, llm_token: LlmApiToken, + organization_id: Option, app_version: Version, trigger: PredictEditsRequestTrigger, ) -> Result<(PredictEditsV3Response, Option)> { @@ -2388,6 +2422,7 @@ impl EditPredictionStore { }, client, llm_token, + organization_id, app_version, true, ) @@ -2441,6 +2476,7 @@ impl EditPredictionStore { build: impl Fn(http_client::http::request::Builder) -> Result>, client: Arc, llm_token: LlmApiToken, + organization_id: Option, app_version: Version, require_auth: bool, ) -> Result<(Res, Option)> @@ -2450,9 +2486,12 @@ impl EditPredictionStore { let http_client = client.http_client(); let mut token = if require_auth { - Some(llm_token.acquire(&client).await?) + Some(llm_token.acquire(&client, organization_id.clone()).await?) } else { - llm_token.acquire(&client).await.ok() + llm_token + .acquire(&client, organization_id.clone()) + .await + .ok() }; let mut did_retry = false; @@ -2494,7 +2533,7 @@ impl EditPredictionStore { return Ok((serde_json::from_slice(&body)?, usage)); } else if !did_retry && token.is_some() && response.needs_llm_token_refresh() { did_retry = true; - token = Some(llm_token.refresh(&client).await?); + token = Some(llm_token.refresh(&client, organization_id.clone()).await?); } else { let mut body = String::new(); response.body_mut().read_to_string(&mut body).await?; diff --git a/crates/edit_prediction/src/zeta.rs b/crates/edit_prediction/src/zeta.rs index f038d2a4ca1929faee2a02391534539b5b63e2d0..8c158c074bf926d2cee9b77cec65b28c4317a22a 100644 --- a/crates/edit_prediction/src/zeta.rs +++ b/crates/edit_prediction/src/zeta.rs @@ -66,6 +66,11 @@ pub fn request_prediction_with_zeta( let client = store.client.clone(); let llm_token = store.llm_token.clone(); + let organization_id = store + .user_store + .read(cx) + .current_organization() + .map(|organization| organization.id.clone()); let app_version = AppVersion::global(cx); let request_task = cx.background_spawn({ @@ -201,6 +206,7 @@ pub fn request_prediction_with_zeta( client, None, llm_token, + organization_id, app_version, ) .await?; @@ -219,6 +225,7 @@ pub fn request_prediction_with_zeta( prompt_input.clone(), client, llm_token, + organization_id, app_version, trigger, ) @@ -430,6 +437,11 @@ pub(crate) fn edit_prediction_accepted( let require_auth = custom_accept_url.is_none(); let client = store.client.clone(); let llm_token = store.llm_token.clone(); + let organization_id = store + .user_store + .read(cx) + .current_organization() + .map(|organization| organization.id.clone()); let app_version = AppVersion::global(cx); cx.background_spawn(async move { @@ -454,6 +466,7 @@ pub(crate) fn edit_prediction_accepted( }, client, llm_token, + organization_id, app_version, require_auth, ) diff --git a/crates/http_client/src/async_body.rs b/crates/http_client/src/async_body.rs index 8fb49f218568ea36078d772a7225229f31a916c4..a59a7339db1e4449b875e2c539e98c86b4279365 100644 --- a/crates/http_client/src/async_body.rs +++ b/crates/http_client/src/async_body.rs @@ -7,6 +7,7 @@ use std::{ use bytes::Bytes; use futures::AsyncRead; use http_body::{Body, Frame}; +use serde::Serialize; /// Based on the implementation of AsyncBody in /// . @@ -88,6 +89,19 @@ impl From<&'static str> for AsyncBody { } } +/// Newtype wrapper that serializes a value as JSON into an `AsyncBody`. +pub struct Json(pub T); + +impl From> for AsyncBody { + fn from(json: Json) -> Self { + Self::from_bytes( + serde_json::to_vec(&json.0) + .expect("failed to serialize JSON") + .into(), + ) + } +} + impl> From> for AsyncBody { fn from(body: Option) -> Self { match body { diff --git a/crates/http_client/src/http_client.rs b/crates/http_client/src/http_client.rs index 5cf25a8277872ba3c6d502565e8057623b267d42..bbbe3b1a832332bd6bee693b4c0b916b4f4c182a 100644 --- a/crates/http_client/src/http_client.rs +++ b/crates/http_client/src/http_client.rs @@ -5,7 +5,7 @@ pub mod github; pub mod github_download; pub use anyhow::{Result, anyhow}; -pub use async_body::{AsyncBody, Inner}; +pub use async_body::{AsyncBody, Inner, Json}; use derive_more::Deref; use http::HeaderValue; pub use http::{self, Method, Request, Response, StatusCode, Uri, request::Builder}; diff --git a/crates/language_model/src/model/cloud_model.rs b/crates/language_model/src/model/cloud_model.rs index 18e099b4d6fc62867bf35fbd1d4573093af44744..b2af80a3c295cab1cf40a330eb8d84f94a137eb7 100644 --- a/crates/language_model/src/model/cloud_model.rs +++ b/crates/language_model/src/model/cloud_model.rs @@ -4,6 +4,7 @@ use std::sync::Arc; use anyhow::{Context as _, Result}; use client::Client; use cloud_api_client::ClientApiError; +use cloud_api_types::OrganizationId; use cloud_api_types::websocket_protocol::MessageToClient; use cloud_llm_client::{EXPIRED_LLM_TOKEN_HEADER_NAME, OUTDATED_LLM_TOKEN_HEADER_NAME}; use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Global, ReadGlobal as _}; @@ -26,29 +27,46 @@ impl fmt::Display for PaymentRequiredError { pub struct LlmApiToken(Arc>>); impl LlmApiToken { - pub async fn acquire(&self, client: &Arc) -> Result { + pub async fn acquire( + &self, + client: &Arc, + organization_id: Option, + ) -> Result { let lock = self.0.upgradable_read().await; if let Some(token) = lock.as_ref() { Ok(token.to_string()) } else { - Self::fetch(RwLockUpgradableReadGuard::upgrade(lock).await, client).await + Self::fetch( + RwLockUpgradableReadGuard::upgrade(lock).await, + client, + organization_id, + ) + .await } } - pub async fn refresh(&self, client: &Arc) -> Result { - Self::fetch(self.0.write().await, client).await + pub async fn refresh( + &self, + client: &Arc, + organization_id: Option, + ) -> Result { + Self::fetch(self.0.write().await, client, organization_id).await } async fn fetch( mut lock: RwLockWriteGuard<'_, Option>, client: &Arc, + organization_id: Option, ) -> Result { let system_id = client .telemetry() .system_id() .map(|system_id| system_id.to_string()); - let result = client.cloud_client().create_llm_token(system_id).await; + let result = client + .cloud_client() + .create_llm_token(system_id, organization_id) + .await; match result { Ok(response) => { *lock = Some(response.token.0.clone()); diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs index 19009013bf84ad9751e9ed0de2d3338b279a258e..b84b19b038905ba9f3d9a0637c770acc95687976 100644 --- a/crates/language_models/src/provider/cloud.rs +++ b/crates/language_models/src/provider/cloud.rs @@ -3,7 +3,7 @@ use anthropic::AnthropicModelMode; use anyhow::{Context as _, Result, anyhow}; use chrono::{DateTime, Utc}; use client::{Client, UserStore, zed_urls}; -use cloud_api_types::Plan; +use cloud_api_types::{OrganizationId, Plan}; use cloud_llm_client::{ CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME, CLIENT_SUPPORTS_X_AI_HEADER_NAME, CompletionBody, CompletionEvent, CompletionRequestStatus, @@ -122,15 +122,25 @@ impl State { recommended_models: Vec::new(), _fetch_models_task: cx.spawn(async move |this, cx| { maybe!(async move { - let (client, llm_api_token) = this - .read_with(cx, |this, _cx| (client.clone(), this.llm_api_token.clone()))?; + let (client, llm_api_token, organization_id) = + this.read_with(cx, |this, cx| { + ( + client.clone(), + this.llm_api_token.clone(), + this.user_store + .read(cx) + .current_organization() + .map(|o| o.id.clone()), + ) + })?; while current_user.borrow().is_none() { current_user.next().await; } let response = - Self::fetch_models(client.clone(), llm_api_token.clone()).await?; + Self::fetch_models(client.clone(), llm_api_token.clone(), organization_id) + .await?; this.update(cx, |this, cx| this.update_models(response, cx))?; anyhow::Ok(()) }) @@ -146,9 +156,17 @@ impl State { move |this, _listener, _event, cx| { let client = this.client.clone(); let llm_api_token = this.llm_api_token.clone(); + let organization_id = this + .user_store + .read(cx) + .current_organization() + .map(|o| o.id.clone()); cx.spawn(async move |this, cx| { - llm_api_token.refresh(&client).await?; - let response = Self::fetch_models(client, llm_api_token).await?; + llm_api_token + .refresh(&client, organization_id.clone()) + .await?; + let response = + Self::fetch_models(client, llm_api_token, organization_id).await?; this.update(cx, |this, cx| { this.update_models(response, cx); }) @@ -209,9 +227,10 @@ impl State { async fn fetch_models( client: Arc, llm_api_token: LlmApiToken, + organization_id: Option, ) -> Result { let http_client = &client.http_client(); - let token = llm_api_token.acquire(&client).await?; + let token = llm_api_token.acquire(&client, organization_id).await?; let request = http_client::Request::builder() .method(Method::GET) @@ -273,11 +292,13 @@ impl CloudLanguageModelProvider { &self, model: Arc, llm_api_token: LlmApiToken, + user_store: Entity, ) -> Arc { Arc::new(CloudLanguageModel { id: LanguageModelId(SharedString::from(model.id.0.clone())), model, llm_api_token, + user_store, client: self.client.clone(), request_limiter: RateLimiter::new(4), }) @@ -306,36 +327,46 @@ impl LanguageModelProvider for CloudLanguageModelProvider { } fn default_model(&self, cx: &App) -> Option> { - let default_model = self.state.read(cx).default_model.clone()?; - let llm_api_token = self.state.read(cx).llm_api_token.clone(); - Some(self.create_language_model(default_model, llm_api_token)) + let state = self.state.read(cx); + let default_model = state.default_model.clone()?; + let llm_api_token = state.llm_api_token.clone(); + let user_store = state.user_store.clone(); + Some(self.create_language_model(default_model, llm_api_token, user_store)) } fn default_fast_model(&self, cx: &App) -> Option> { - let default_fast_model = self.state.read(cx).default_fast_model.clone()?; - let llm_api_token = self.state.read(cx).llm_api_token.clone(); - Some(self.create_language_model(default_fast_model, llm_api_token)) + let state = self.state.read(cx); + let default_fast_model = state.default_fast_model.clone()?; + let llm_api_token = state.llm_api_token.clone(); + let user_store = state.user_store.clone(); + Some(self.create_language_model(default_fast_model, llm_api_token, user_store)) } fn recommended_models(&self, cx: &App) -> Vec> { - let llm_api_token = self.state.read(cx).llm_api_token.clone(); - self.state - .read(cx) + let state = self.state.read(cx); + let llm_api_token = state.llm_api_token.clone(); + let user_store = state.user_store.clone(); + state .recommended_models .iter() .cloned() - .map(|model| self.create_language_model(model, llm_api_token.clone())) + .map(|model| { + self.create_language_model(model, llm_api_token.clone(), user_store.clone()) + }) .collect() } fn provided_models(&self, cx: &App) -> Vec> { - let llm_api_token = self.state.read(cx).llm_api_token.clone(); - self.state - .read(cx) + let state = self.state.read(cx); + let llm_api_token = state.llm_api_token.clone(); + let user_store = state.user_store.clone(); + state .models .iter() .cloned() - .map(|model| self.create_language_model(model, llm_api_token.clone())) + .map(|model| { + self.create_language_model(model, llm_api_token.clone(), user_store.clone()) + }) .collect() } @@ -367,6 +398,7 @@ pub struct CloudLanguageModel { id: LanguageModelId, model: Arc, llm_api_token: LlmApiToken, + user_store: Entity, client: Arc, request_limiter: RateLimiter, } @@ -380,12 +412,15 @@ impl CloudLanguageModel { async fn perform_llm_completion( client: Arc, llm_api_token: LlmApiToken, + organization_id: Option, app_version: Option, body: CompletionBody, ) -> Result { let http_client = &client.http_client(); - let mut token = llm_api_token.acquire(&client).await?; + let mut token = llm_api_token + .acquire(&client, organization_id.clone()) + .await?; let mut refreshed_token = false; loop { @@ -416,7 +451,9 @@ impl CloudLanguageModel { } if !refreshed_token && response.needs_llm_token_refresh() { - token = llm_api_token.refresh(&client).await?; + token = llm_api_token + .refresh(&client, organization_id.clone()) + .await?; refreshed_token = true; continue; } @@ -670,12 +707,17 @@ impl LanguageModel for CloudLanguageModel { cloud_llm_client::LanguageModelProvider::Google => { let client = self.client.clone(); let llm_api_token = self.llm_api_token.clone(); + let organization_id = self + .user_store + .read(cx) + .current_organization() + .map(|o| o.id.clone()); let model_id = self.model.id.to_string(); let generate_content_request = into_google(request, model_id.clone(), GoogleModelMode::Default); async move { let http_client = &client.http_client(); - let token = llm_api_token.acquire(&client).await?; + let token = llm_api_token.acquire(&client, organization_id).await?; let request_body = CountTokensBody { provider: cloud_llm_client::LanguageModelProvider::Google, @@ -736,6 +778,13 @@ impl LanguageModel for CloudLanguageModel { let prompt_id = request.prompt_id.clone(); let intent = request.intent; let app_version = Some(cx.update(|cx| AppVersion::global(cx))); + let user_store = self.user_store.clone(); + let organization_id = cx.update(|cx| { + user_store + .read(cx) + .current_organization() + .map(|o| o.id.clone()) + }); let thinking_allowed = request.thinking_allowed; let enable_thinking = thinking_allowed && self.model.supports_thinking; let provider_name = provider_name(&self.model.provider); @@ -767,6 +816,7 @@ impl LanguageModel for CloudLanguageModel { let client = self.client.clone(); let llm_api_token = self.llm_api_token.clone(); + let organization_id = organization_id.clone(); let future = self.request_limiter.stream(async move { let PerformLlmCompletionResponse { response, @@ -774,6 +824,7 @@ impl LanguageModel for CloudLanguageModel { } = Self::perform_llm_completion( client.clone(), llm_api_token, + organization_id, app_version, CompletionBody { thread_id, @@ -803,6 +854,7 @@ impl LanguageModel for CloudLanguageModel { cloud_llm_client::LanguageModelProvider::OpenAi => { let client = self.client.clone(); let llm_api_token = self.llm_api_token.clone(); + let organization_id = organization_id.clone(); let effort = request .thinking_effort .as_ref() @@ -828,6 +880,7 @@ impl LanguageModel for CloudLanguageModel { } = Self::perform_llm_completion( client.clone(), llm_api_token, + organization_id, app_version, CompletionBody { thread_id, @@ -861,6 +914,7 @@ impl LanguageModel for CloudLanguageModel { None, ); let llm_api_token = self.llm_api_token.clone(); + let organization_id = organization_id.clone(); let future = self.request_limiter.stream(async move { let PerformLlmCompletionResponse { response, @@ -868,6 +922,7 @@ impl LanguageModel for CloudLanguageModel { } = Self::perform_llm_completion( client.clone(), llm_api_token, + organization_id, app_version, CompletionBody { thread_id, @@ -902,6 +957,7 @@ impl LanguageModel for CloudLanguageModel { } = Self::perform_llm_completion( client.clone(), llm_api_token, + organization_id, app_version, CompletionBody { thread_id, diff --git a/crates/web_search_providers/Cargo.toml b/crates/web_search_providers/Cargo.toml index ecdca5883ff541459e94170986df3b7f16036c5a..ff264edcb150063237c633de746b2f6b9f6f250c 100644 --- a/crates/web_search_providers/Cargo.toml +++ b/crates/web_search_providers/Cargo.toml @@ -14,6 +14,7 @@ path = "src/web_search_providers.rs" [dependencies] anyhow.workspace = true client.workspace = true +cloud_api_types.workspace = true cloud_llm_client.workspace = true futures.workspace = true gpui.workspace = true diff --git a/crates/web_search_providers/src/cloud.rs b/crates/web_search_providers/src/cloud.rs index 2f3ccdbb52a884471250ad458e8b7922437cb9ae..c8bc89953f2b2d3ec62bac07e80f2737522824f7 100644 --- a/crates/web_search_providers/src/cloud.rs +++ b/crates/web_search_providers/src/cloud.rs @@ -1,7 +1,8 @@ use std::sync::Arc; use anyhow::{Context as _, Result}; -use client::Client; +use client::{Client, UserStore}; +use cloud_api_types::OrganizationId; use cloud_llm_client::{WebSearchBody, WebSearchResponse}; use futures::AsyncReadExt as _; use gpui::{App, AppContext, Context, Entity, Subscription, Task}; @@ -14,8 +15,8 @@ pub struct CloudWebSearchProvider { } impl CloudWebSearchProvider { - pub fn new(client: Arc, cx: &mut App) -> Self { - let state = cx.new(|cx| State::new(client, cx)); + pub fn new(client: Arc, user_store: Entity, cx: &mut App) -> Self { + let state = cx.new(|cx| State::new(client, user_store, cx)); Self { state } } @@ -23,24 +24,31 @@ impl CloudWebSearchProvider { pub struct State { client: Arc, + user_store: Entity, llm_api_token: LlmApiToken, _llm_token_subscription: Subscription, } impl State { - pub fn new(client: Arc, cx: &mut Context) -> Self { + pub fn new(client: Arc, user_store: Entity, cx: &mut Context) -> Self { let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx); Self { client, + user_store, llm_api_token: LlmApiToken::default(), _llm_token_subscription: cx.subscribe( &refresh_llm_token_listener, |this, _, _event, cx| { let client = this.client.clone(); let llm_api_token = this.llm_api_token.clone(); + let organization_id = this + .user_store + .read(cx) + .current_organization() + .map(|o| o.id.clone()); cx.spawn(async move |_this, _cx| { - llm_api_token.refresh(&client).await?; + llm_api_token.refresh(&client, organization_id).await?; anyhow::Ok(()) }) .detach_and_log_err(cx); @@ -61,21 +69,31 @@ impl WebSearchProvider for CloudWebSearchProvider { let state = self.state.read(cx); let client = state.client.clone(); let llm_api_token = state.llm_api_token.clone(); + let organization_id = state + .user_store + .read(cx) + .current_organization() + .map(|o| o.id.clone()); let body = WebSearchBody { query }; - cx.background_spawn(async move { perform_web_search(client, llm_api_token, body).await }) + cx.background_spawn(async move { + perform_web_search(client, llm_api_token, organization_id, body).await + }) } } async fn perform_web_search( client: Arc, llm_api_token: LlmApiToken, + organization_id: Option, body: WebSearchBody, ) -> Result { const MAX_RETRIES: usize = 3; let http_client = &client.http_client(); let mut retries_remaining = MAX_RETRIES; - let mut token = llm_api_token.acquire(&client).await?; + let mut token = llm_api_token + .acquire(&client, organization_id.clone()) + .await?; loop { if retries_remaining == 0 { @@ -100,7 +118,9 @@ async fn perform_web_search( response.body_mut().read_to_string(&mut body).await?; return Ok(serde_json::from_str(&body)?); } else if response.needs_llm_token_refresh() { - token = llm_api_token.refresh(&client).await?; + token = llm_api_token + .refresh(&client, organization_id.clone()) + .await?; retries_remaining -= 1; } else { // For now we will only retry if the LLM token is expired, diff --git a/crates/web_search_providers/src/web_search_providers.rs b/crates/web_search_providers/src/web_search_providers.rs index 8ab0aee47a414c4cc669ab05e727a827d17c2844..509632429fb167cd489cd4253ceae0ce479b10a8 100644 --- a/crates/web_search_providers/src/web_search_providers.rs +++ b/crates/web_search_providers/src/web_search_providers.rs @@ -1,26 +1,28 @@ mod cloud; -use client::Client; +use client::{Client, UserStore}; use gpui::{App, Context, Entity}; use language_model::LanguageModelRegistry; use std::sync::Arc; use web_search::{WebSearchProviderId, WebSearchRegistry}; -pub fn init(client: Arc, cx: &mut App) { +pub fn init(client: Arc, user_store: Entity, cx: &mut App) { let registry = WebSearchRegistry::global(cx); registry.update(cx, |registry, cx| { - register_web_search_providers(registry, client, cx); + register_web_search_providers(registry, client, user_store, cx); }); } fn register_web_search_providers( registry: &mut WebSearchRegistry, client: Arc, + user_store: Entity, cx: &mut Context, ) { register_zed_web_search_provider( registry, client.clone(), + user_store.clone(), &LanguageModelRegistry::global(cx), cx, ); @@ -29,7 +31,13 @@ fn register_web_search_providers( &LanguageModelRegistry::global(cx), move |this, registry, event, cx| { if let language_model::Event::DefaultModelChanged = event { - register_zed_web_search_provider(this, client.clone(), ®istry, cx) + register_zed_web_search_provider( + this, + client.clone(), + user_store.clone(), + ®istry, + cx, + ) } }, ) @@ -39,6 +47,7 @@ fn register_web_search_providers( fn register_zed_web_search_provider( registry: &mut WebSearchRegistry, client: Arc, + user_store: Entity, language_model_registry: &Entity, cx: &mut Context, ) { @@ -47,7 +56,10 @@ fn register_zed_web_search_provider( .default_model() .is_some_and(|default| default.is_provided_by_zed()); if using_zed_provider { - registry.register_provider(cloud::CloudWebSearchProvider::new(client, cx), cx) + registry.register_provider( + cloud::CloudWebSearchProvider::new(client, user_store, cx), + cx, + ) } else { registry.unregister_provider(WebSearchProviderId( cloud::ZED_WEB_SEARCH_PROVIDER_ID.into(), diff --git a/crates/zed/src/main.rs b/crates/zed/src/main.rs index 38238d8af519c0506ab451bccaa1abe3a893e4c9..a3379a6017b7e3b7c26e2a98346e4926e90e0999 100644 --- a/crates/zed/src/main.rs +++ b/crates/zed/src/main.rs @@ -645,7 +645,7 @@ fn main() { zed::remote_debug::init(cx); edit_prediction_ui::init(cx); web_search::init(cx); - web_search_providers::init(app_state.client.clone(), cx); + web_search_providers::init(app_state.client.clone(), app_state.user_store.clone(), cx); snippet_provider::init(cx); edit_prediction_registry::init(app_state.client.clone(), app_state.user_store.clone(), cx); let prompt_builder = PromptBuilder::load(app_state.fs.clone(), stdout_is_a_pty(), cx); diff --git a/crates/zed/src/zed.rs b/crates/zed/src/zed.rs index 17832bdd1833cabb42af2195f9d9aab1a6bf3fab..20629785c7172241f49a0e7a69f9dcc1953f6a95 100644 --- a/crates/zed/src/zed.rs +++ b/crates/zed/src/zed.rs @@ -5021,7 +5021,7 @@ mod tests { language_models::init(app_state.user_store.clone(), app_state.client.clone(), cx); web_search::init(cx); git_graph::init(cx); - web_search_providers::init(app_state.client.clone(), cx); + web_search_providers::init(app_state.client.clone(), app_state.user_store.clone(), cx); let prompt_builder = PromptBuilder::load(app_state.fs.clone(), false, cx); project::AgentRegistryStore::init_global( cx,