cloud_api_client: Send the organization ID in LLM token requests (#50517)

Tom Houlé created

This field is already expected on the cloud side. Sending it tells the server
which organization the user is logged in under when an LLM API
token is requested.

Closes CLO-337

Release Notes:

- N/A

Change summary

Cargo.lock                                              |   1 
crates/cloud_api_client/src/cloud_api_client.rs         |  10 
crates/cloud_api_types/src/cloud_api_types.rs           |   6 
crates/edit_prediction/src/edit_prediction.rs           |  67 +++++-
crates/edit_prediction/src/zeta.rs                      |  13 +
crates/http_client/src/async_body.rs                    |  14 +
crates/http_client/src/http_client.rs                   |   2 
crates/language_model/src/model/cloud_model.rs          |  28 ++
crates/language_models/src/provider/cloud.rs            | 104 ++++++++--
crates/web_search_providers/Cargo.toml                  |   1 
crates/web_search_providers/src/cloud.rs                |  36 ++
crates/web_search_providers/src/web_search_providers.rs |  22 +
crates/zed/src/main.rs                                  |   2 
crates/zed/src/zed.rs                                   |   2 
14 files changed, 247 insertions(+), 61 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -19851,6 +19851,7 @@ version = "0.1.0"
 dependencies = [
  "anyhow",
  "client",
+ "cloud_api_types",
  "cloud_llm_client",
  "futures 0.3.31",
  "gpui",

crates/cloud_api_client/src/cloud_api_client.rs 🔗

@@ -9,7 +9,9 @@ use futures::AsyncReadExt as _;
 use gpui::{App, Task};
 use gpui_tokio::Tokio;
 use http_client::http::request;
-use http_client::{AsyncBody, HttpClientWithUrl, HttpRequestExt, Method, Request, StatusCode};
+use http_client::{
+    AsyncBody, HttpClientWithUrl, HttpRequestExt, Json, Method, Request, StatusCode,
+};
 use parking_lot::RwLock;
 use thiserror::Error;
 use yawc::WebSocket;
@@ -141,6 +143,7 @@ impl CloudApiClient {
     pub async fn create_llm_token(
         &self,
         system_id: Option<String>,
+        organization_id: Option<OrganizationId>,
     ) -> Result<CreateLlmTokenResponse, ClientApiError> {
         let request_builder = Request::builder()
             .method(Method::POST)
@@ -153,7 +156,10 @@ impl CloudApiClient {
                 builder.header(ZED_SYSTEM_ID_HEADER_NAME, system_id)
             });
 
-        let request = self.build_request(request_builder, AsyncBody::default())?;
+        let request = self.build_request(
+            request_builder,
+            Json(CreateLlmTokenBody { organization_id }),
+        )?;
 
         let mut response = self.http_client.send(request).await?;
 

crates/cloud_api_types/src/cloud_api_types.rs 🔗

@@ -52,6 +52,12 @@ pub struct AcceptTermsOfServiceResponse {
 #[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
 pub struct LlmToken(pub String);
 
+#[derive(Debug, Default, PartialEq, Clone, Serialize, Deserialize)]
+pub struct CreateLlmTokenBody {
+    #[serde(default)]
+    pub organization_id: Option<OrganizationId>,
+}
+
 #[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
 pub struct CreateLlmTokenResponse {
     pub token: LlmToken,

crates/edit_prediction/src/edit_prediction.rs 🔗

@@ -1,7 +1,7 @@
 use anyhow::Result;
 use arrayvec::ArrayVec;
 use client::{Client, EditPredictionUsage, UserStore};
-use cloud_api_types::SubmitEditPredictionFeedbackBody;
+use cloud_api_types::{OrganizationId, SubmitEditPredictionFeedbackBody};
 use cloud_llm_client::predict_edits_v3::{
     PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
 };
@@ -143,7 +143,7 @@ pub struct EditPredictionStore {
     pub sweep_ai: SweepAi,
     pub mercury: Mercury,
     data_collection_choice: DataCollectionChoice,
-    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
+    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejectionPayload>,
     settled_predictions_tx: mpsc::UnboundedSender<Instant>,
     shown_predictions: VecDeque<EditPrediction>,
     rated_predictions: HashSet<EditPredictionId>,
@@ -151,6 +151,11 @@ pub struct EditPredictionStore {
     settled_event_callback: Option<Box<dyn Fn(EditPredictionId, String)>>,
 }
 
+pub(crate) struct EditPredictionRejectionPayload {
+    rejection: EditPredictionRejection,
+    organization_id: Option<OrganizationId>,
+}
+
 #[derive(Copy, Clone, PartialEq, Eq)]
 pub enum EditPredictionModel {
     Zeta,
@@ -719,8 +724,13 @@ impl EditPredictionStore {
                 |this, _listener, _event, cx| {
                     let client = this.client.clone();
                     let llm_token = this.llm_token.clone();
+                    let organization_id = this
+                        .user_store
+                        .read(cx)
+                        .current_organization()
+                        .map(|organization| organization.id.clone());
                     cx.spawn(async move |_this, _cx| {
-                        llm_token.refresh(&client).await?;
+                        llm_token.refresh(&client, organization_id).await?;
                         anyhow::Ok(())
                     })
                     .detach_and_log_err(cx);
@@ -781,11 +791,17 @@ impl EditPredictionStore {
         let client = self.client.clone();
         let llm_token = self.llm_token.clone();
         let app_version = AppVersion::global(cx);
+        let organization_id = self
+            .user_store
+            .read(cx)
+            .current_organization()
+            .map(|organization| organization.id.clone());
+
         cx.spawn(async move |this, cx| {
             let experiments = cx
                 .background_spawn(async move {
                     let http_client = client.http_client();
-                    let token = llm_token.acquire(&client).await?;
+                    let token = llm_token.acquire(&client, organization_id).await?;
                     let url = http_client.build_zed_llm_url("/edit_prediction_experiments", &[])?;
                     let request = http_client::Request::builder()
                         .method(Method::GET)
@@ -1424,7 +1440,7 @@ impl EditPredictionStore {
     }
 
     async fn handle_rejected_predictions(
-        rx: UnboundedReceiver<EditPredictionRejection>,
+        rx: UnboundedReceiver<EditPredictionRejectionPayload>,
         client: Arc<Client>,
         llm_token: LlmApiToken,
         app_version: Version,
@@ -1433,7 +1449,11 @@ impl EditPredictionStore {
         let mut rx = std::pin::pin!(rx.peekable());
         let mut batched = Vec::new();
 
-        while let Some(rejection) = rx.next().await {
+        while let Some(EditPredictionRejectionPayload {
+            rejection,
+            organization_id,
+        }) = rx.next().await
+        {
             batched.push(rejection);
 
             if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
@@ -1471,6 +1491,7 @@ impl EditPredictionStore {
                 },
                 client.clone(),
                 llm_token.clone(),
+                organization_id,
                 app_version.clone(),
                 true,
             )
@@ -1676,13 +1697,23 @@ impl EditPredictionStore {
                     all_language_settings(None, cx).edit_predictions.provider,
                     EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
                 );
+
                 if is_cloud {
+                    let organization_id = self
+                        .user_store
+                        .read(cx)
+                        .current_organization()
+                        .map(|organization| organization.id.clone());
+
                     self.reject_predictions_tx
-                        .unbounded_send(EditPredictionRejection {
-                            request_id: prediction_id.to_string(),
-                            reason,
-                            was_shown,
-                            model_version,
+                        .unbounded_send(EditPredictionRejectionPayload {
+                            rejection: EditPredictionRejection {
+                                request_id: prediction_id.to_string(),
+                                reason,
+                                was_shown,
+                                model_version,
+                            },
+                            organization_id,
                         })
                         .log_err();
                 }
@@ -2337,6 +2368,7 @@ impl EditPredictionStore {
         client: Arc<Client>,
         custom_url: Option<Arc<Url>>,
         llm_token: LlmApiToken,
+        organization_id: Option<OrganizationId>,
         app_version: Version,
     ) -> Result<(RawCompletionResponse, Option<EditPredictionUsage>)> {
         let url = if let Some(custom_url) = custom_url {
@@ -2356,6 +2388,7 @@ impl EditPredictionStore {
             },
             client,
             llm_token,
+            organization_id,
             app_version,
             true,
         )
@@ -2366,6 +2399,7 @@ impl EditPredictionStore {
         input: ZetaPromptInput,
         client: Arc<Client>,
         llm_token: LlmApiToken,
+        organization_id: Option<OrganizationId>,
         app_version: Version,
         trigger: PredictEditsRequestTrigger,
     ) -> Result<(PredictEditsV3Response, Option<EditPredictionUsage>)> {
@@ -2388,6 +2422,7 @@ impl EditPredictionStore {
             },
             client,
             llm_token,
+            organization_id,
             app_version,
             true,
         )
@@ -2441,6 +2476,7 @@ impl EditPredictionStore {
         build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
         client: Arc<Client>,
         llm_token: LlmApiToken,
+        organization_id: Option<OrganizationId>,
         app_version: Version,
         require_auth: bool,
     ) -> Result<(Res, Option<EditPredictionUsage>)>
@@ -2450,9 +2486,12 @@ impl EditPredictionStore {
         let http_client = client.http_client();
 
         let mut token = if require_auth {
-            Some(llm_token.acquire(&client).await?)
+            Some(llm_token.acquire(&client, organization_id.clone()).await?)
         } else {
-            llm_token.acquire(&client).await.ok()
+            llm_token
+                .acquire(&client, organization_id.clone())
+                .await
+                .ok()
         };
         let mut did_retry = false;
 
@@ -2494,7 +2533,7 @@ impl EditPredictionStore {
                 return Ok((serde_json::from_slice(&body)?, usage));
             } else if !did_retry && token.is_some() && response.needs_llm_token_refresh() {
                 did_retry = true;
-                token = Some(llm_token.refresh(&client).await?);
+                token = Some(llm_token.refresh(&client, organization_id.clone()).await?);
             } else {
                 let mut body = String::new();
                 response.body_mut().read_to_string(&mut body).await?;

crates/edit_prediction/src/zeta.rs 🔗

@@ -66,6 +66,11 @@ pub fn request_prediction_with_zeta(
 
     let client = store.client.clone();
     let llm_token = store.llm_token.clone();
+    let organization_id = store
+        .user_store
+        .read(cx)
+        .current_organization()
+        .map(|organization| organization.id.clone());
     let app_version = AppVersion::global(cx);
 
     let request_task = cx.background_spawn({
@@ -201,6 +206,7 @@ pub fn request_prediction_with_zeta(
                         client,
                         None,
                         llm_token,
+                        organization_id,
                         app_version,
                     )
                     .await?;
@@ -219,6 +225,7 @@ pub fn request_prediction_with_zeta(
                         prompt_input.clone(),
                         client,
                         llm_token,
+                        organization_id,
                         app_version,
                         trigger,
                     )
@@ -430,6 +437,11 @@ pub(crate) fn edit_prediction_accepted(
     let require_auth = custom_accept_url.is_none();
     let client = store.client.clone();
     let llm_token = store.llm_token.clone();
+    let organization_id = store
+        .user_store
+        .read(cx)
+        .current_organization()
+        .map(|organization| organization.id.clone());
     let app_version = AppVersion::global(cx);
 
     cx.background_spawn(async move {
@@ -454,6 +466,7 @@ pub(crate) fn edit_prediction_accepted(
             },
             client,
             llm_token,
+            organization_id,
             app_version,
             require_auth,
         )

crates/http_client/src/async_body.rs 🔗

@@ -7,6 +7,7 @@ use std::{
 use bytes::Bytes;
 use futures::AsyncRead;
 use http_body::{Body, Frame};
+use serde::Serialize;
 
 /// Based on the implementation of AsyncBody in
 /// <https://github.com/sagebind/isahc/blob/5c533f1ef4d6bdf1fd291b5103c22110f41d0bf0/src/body/mod.rs>.
@@ -88,6 +89,19 @@ impl From<&'static str> for AsyncBody {
     }
 }
 
+/// Newtype wrapper that serializes a value as JSON into an `AsyncBody`.
+pub struct Json<T: Serialize>(pub T);
+
+impl<T: Serialize> From<Json<T>> for AsyncBody {
+    fn from(json: Json<T>) -> Self {
+        Self::from_bytes(
+            serde_json::to_vec(&json.0)
+                .expect("failed to serialize JSON")
+                .into(),
+        )
+    }
+}
+
 impl<T: Into<Self>> From<Option<T>> for AsyncBody {
     fn from(body: Option<T>) -> Self {
         match body {

crates/http_client/src/http_client.rs 🔗

@@ -5,7 +5,7 @@ pub mod github;
 pub mod github_download;
 
 pub use anyhow::{Result, anyhow};
-pub use async_body::{AsyncBody, Inner};
+pub use async_body::{AsyncBody, Inner, Json};
 use derive_more::Deref;
 use http::HeaderValue;
 pub use http::{self, Method, Request, Response, StatusCode, Uri, request::Builder};

crates/language_model/src/model/cloud_model.rs 🔗

@@ -4,6 +4,7 @@ use std::sync::Arc;
 use anyhow::{Context as _, Result};
 use client::Client;
 use cloud_api_client::ClientApiError;
+use cloud_api_types::OrganizationId;
 use cloud_api_types::websocket_protocol::MessageToClient;
 use cloud_llm_client::{EXPIRED_LLM_TOKEN_HEADER_NAME, OUTDATED_LLM_TOKEN_HEADER_NAME};
 use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Global, ReadGlobal as _};
@@ -26,29 +27,46 @@ impl fmt::Display for PaymentRequiredError {
 pub struct LlmApiToken(Arc<RwLock<Option<String>>>);
 
 impl LlmApiToken {
-    pub async fn acquire(&self, client: &Arc<Client>) -> Result<String> {
+    pub async fn acquire(
+        &self,
+        client: &Arc<Client>,
+        organization_id: Option<OrganizationId>,
+    ) -> Result<String> {
         let lock = self.0.upgradable_read().await;
         if let Some(token) = lock.as_ref() {
             Ok(token.to_string())
         } else {
-            Self::fetch(RwLockUpgradableReadGuard::upgrade(lock).await, client).await
+            Self::fetch(
+                RwLockUpgradableReadGuard::upgrade(lock).await,
+                client,
+                organization_id,
+            )
+            .await
         }
     }
 
-    pub async fn refresh(&self, client: &Arc<Client>) -> Result<String> {
-        Self::fetch(self.0.write().await, client).await
+    pub async fn refresh(
+        &self,
+        client: &Arc<Client>,
+        organization_id: Option<OrganizationId>,
+    ) -> Result<String> {
+        Self::fetch(self.0.write().await, client, organization_id).await
     }
 
     async fn fetch(
         mut lock: RwLockWriteGuard<'_, Option<String>>,
         client: &Arc<Client>,
+        organization_id: Option<OrganizationId>,
     ) -> Result<String> {
         let system_id = client
             .telemetry()
             .system_id()
             .map(|system_id| system_id.to_string());
 
-        let result = client.cloud_client().create_llm_token(system_id).await;
+        let result = client
+            .cloud_client()
+            .create_llm_token(system_id, organization_id)
+            .await;
         match result {
             Ok(response) => {
                 *lock = Some(response.token.0.clone());

crates/language_models/src/provider/cloud.rs 🔗

@@ -3,7 +3,7 @@ use anthropic::AnthropicModelMode;
 use anyhow::{Context as _, Result, anyhow};
 use chrono::{DateTime, Utc};
 use client::{Client, UserStore, zed_urls};
-use cloud_api_types::Plan;
+use cloud_api_types::{OrganizationId, Plan};
 use cloud_llm_client::{
     CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME,
     CLIENT_SUPPORTS_X_AI_HEADER_NAME, CompletionBody, CompletionEvent, CompletionRequestStatus,
@@ -122,15 +122,25 @@ impl State {
             recommended_models: Vec::new(),
             _fetch_models_task: cx.spawn(async move |this, cx| {
                 maybe!(async move {
-                    let (client, llm_api_token) = this
-                        .read_with(cx, |this, _cx| (client.clone(), this.llm_api_token.clone()))?;
+                    let (client, llm_api_token, organization_id) =
+                        this.read_with(cx, |this, cx| {
+                            (
+                                client.clone(),
+                                this.llm_api_token.clone(),
+                                this.user_store
+                                    .read(cx)
+                                    .current_organization()
+                                    .map(|o| o.id.clone()),
+                            )
+                        })?;
 
                     while current_user.borrow().is_none() {
                         current_user.next().await;
                     }
 
                     let response =
-                        Self::fetch_models(client.clone(), llm_api_token.clone()).await?;
+                        Self::fetch_models(client.clone(), llm_api_token.clone(), organization_id)
+                            .await?;
                     this.update(cx, |this, cx| this.update_models(response, cx))?;
                     anyhow::Ok(())
                 })
@@ -146,9 +156,17 @@ impl State {
                 move |this, _listener, _event, cx| {
                     let client = this.client.clone();
                     let llm_api_token = this.llm_api_token.clone();
+                    let organization_id = this
+                        .user_store
+                        .read(cx)
+                        .current_organization()
+                        .map(|o| o.id.clone());
                     cx.spawn(async move |this, cx| {
-                        llm_api_token.refresh(&client).await?;
-                        let response = Self::fetch_models(client, llm_api_token).await?;
+                        llm_api_token
+                            .refresh(&client, organization_id.clone())
+                            .await?;
+                        let response =
+                            Self::fetch_models(client, llm_api_token, organization_id).await?;
                         this.update(cx, |this, cx| {
                             this.update_models(response, cx);
                         })
@@ -209,9 +227,10 @@ impl State {
     async fn fetch_models(
         client: Arc<Client>,
         llm_api_token: LlmApiToken,
+        organization_id: Option<OrganizationId>,
     ) -> Result<ListModelsResponse> {
         let http_client = &client.http_client();
-        let token = llm_api_token.acquire(&client).await?;
+        let token = llm_api_token.acquire(&client, organization_id).await?;
 
         let request = http_client::Request::builder()
             .method(Method::GET)
@@ -273,11 +292,13 @@ impl CloudLanguageModelProvider {
         &self,
         model: Arc<cloud_llm_client::LanguageModel>,
         llm_api_token: LlmApiToken,
+        user_store: Entity<UserStore>,
     ) -> Arc<dyn LanguageModel> {
         Arc::new(CloudLanguageModel {
             id: LanguageModelId(SharedString::from(model.id.0.clone())),
             model,
             llm_api_token,
+            user_store,
             client: self.client.clone(),
             request_limiter: RateLimiter::new(4),
         })
@@ -306,36 +327,46 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
     }
 
     fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
-        let default_model = self.state.read(cx).default_model.clone()?;
-        let llm_api_token = self.state.read(cx).llm_api_token.clone();
-        Some(self.create_language_model(default_model, llm_api_token))
+        let state = self.state.read(cx);
+        let default_model = state.default_model.clone()?;
+        let llm_api_token = state.llm_api_token.clone();
+        let user_store = state.user_store.clone();
+        Some(self.create_language_model(default_model, llm_api_token, user_store))
     }
 
     fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
-        let default_fast_model = self.state.read(cx).default_fast_model.clone()?;
-        let llm_api_token = self.state.read(cx).llm_api_token.clone();
-        Some(self.create_language_model(default_fast_model, llm_api_token))
+        let state = self.state.read(cx);
+        let default_fast_model = state.default_fast_model.clone()?;
+        let llm_api_token = state.llm_api_token.clone();
+        let user_store = state.user_store.clone();
+        Some(self.create_language_model(default_fast_model, llm_api_token, user_store))
     }
 
     fn recommended_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
-        let llm_api_token = self.state.read(cx).llm_api_token.clone();
-        self.state
-            .read(cx)
+        let state = self.state.read(cx);
+        let llm_api_token = state.llm_api_token.clone();
+        let user_store = state.user_store.clone();
+        state
             .recommended_models
             .iter()
             .cloned()
-            .map(|model| self.create_language_model(model, llm_api_token.clone()))
+            .map(|model| {
+                self.create_language_model(model, llm_api_token.clone(), user_store.clone())
+            })
             .collect()
     }
 
     fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
-        let llm_api_token = self.state.read(cx).llm_api_token.clone();
-        self.state
-            .read(cx)
+        let state = self.state.read(cx);
+        let llm_api_token = state.llm_api_token.clone();
+        let user_store = state.user_store.clone();
+        state
             .models
             .iter()
             .cloned()
-            .map(|model| self.create_language_model(model, llm_api_token.clone()))
+            .map(|model| {
+                self.create_language_model(model, llm_api_token.clone(), user_store.clone())
+            })
             .collect()
     }
 
@@ -367,6 +398,7 @@ pub struct CloudLanguageModel {
     id: LanguageModelId,
     model: Arc<cloud_llm_client::LanguageModel>,
     llm_api_token: LlmApiToken,
+    user_store: Entity<UserStore>,
     client: Arc<Client>,
     request_limiter: RateLimiter,
 }
@@ -380,12 +412,15 @@ impl CloudLanguageModel {
     async fn perform_llm_completion(
         client: Arc<Client>,
         llm_api_token: LlmApiToken,
+        organization_id: Option<OrganizationId>,
         app_version: Option<Version>,
         body: CompletionBody,
     ) -> Result<PerformLlmCompletionResponse> {
         let http_client = &client.http_client();
 
-        let mut token = llm_api_token.acquire(&client).await?;
+        let mut token = llm_api_token
+            .acquire(&client, organization_id.clone())
+            .await?;
         let mut refreshed_token = false;
 
         loop {
@@ -416,7 +451,9 @@ impl CloudLanguageModel {
             }
 
             if !refreshed_token && response.needs_llm_token_refresh() {
-                token = llm_api_token.refresh(&client).await?;
+                token = llm_api_token
+                    .refresh(&client, organization_id.clone())
+                    .await?;
                 refreshed_token = true;
                 continue;
             }
@@ -670,12 +707,17 @@ impl LanguageModel for CloudLanguageModel {
             cloud_llm_client::LanguageModelProvider::Google => {
                 let client = self.client.clone();
                 let llm_api_token = self.llm_api_token.clone();
+                let organization_id = self
+                    .user_store
+                    .read(cx)
+                    .current_organization()
+                    .map(|o| o.id.clone());
                 let model_id = self.model.id.to_string();
                 let generate_content_request =
                     into_google(request, model_id.clone(), GoogleModelMode::Default);
                 async move {
                     let http_client = &client.http_client();
-                    let token = llm_api_token.acquire(&client).await?;
+                    let token = llm_api_token.acquire(&client, organization_id).await?;
 
                     let request_body = CountTokensBody {
                         provider: cloud_llm_client::LanguageModelProvider::Google,
@@ -736,6 +778,13 @@ impl LanguageModel for CloudLanguageModel {
         let prompt_id = request.prompt_id.clone();
         let intent = request.intent;
         let app_version = Some(cx.update(|cx| AppVersion::global(cx)));
+        let user_store = self.user_store.clone();
+        let organization_id = cx.update(|cx| {
+            user_store
+                .read(cx)
+                .current_organization()
+                .map(|o| o.id.clone())
+        });
         let thinking_allowed = request.thinking_allowed;
         let enable_thinking = thinking_allowed && self.model.supports_thinking;
         let provider_name = provider_name(&self.model.provider);
@@ -767,6 +816,7 @@ impl LanguageModel for CloudLanguageModel {
 
                 let client = self.client.clone();
                 let llm_api_token = self.llm_api_token.clone();
+                let organization_id = organization_id.clone();
                 let future = self.request_limiter.stream(async move {
                     let PerformLlmCompletionResponse {
                         response,
@@ -774,6 +824,7 @@ impl LanguageModel for CloudLanguageModel {
                     } = Self::perform_llm_completion(
                         client.clone(),
                         llm_api_token,
+                        organization_id,
                         app_version,
                         CompletionBody {
                             thread_id,
@@ -803,6 +854,7 @@ impl LanguageModel for CloudLanguageModel {
             cloud_llm_client::LanguageModelProvider::OpenAi => {
                 let client = self.client.clone();
                 let llm_api_token = self.llm_api_token.clone();
+                let organization_id = organization_id.clone();
                 let effort = request
                     .thinking_effort
                     .as_ref()
@@ -828,6 +880,7 @@ impl LanguageModel for CloudLanguageModel {
                     } = Self::perform_llm_completion(
                         client.clone(),
                         llm_api_token,
+                        organization_id,
                         app_version,
                         CompletionBody {
                             thread_id,
@@ -861,6 +914,7 @@ impl LanguageModel for CloudLanguageModel {
                     None,
                 );
                 let llm_api_token = self.llm_api_token.clone();
+                let organization_id = organization_id.clone();
                 let future = self.request_limiter.stream(async move {
                     let PerformLlmCompletionResponse {
                         response,
@@ -868,6 +922,7 @@ impl LanguageModel for CloudLanguageModel {
                     } = Self::perform_llm_completion(
                         client.clone(),
                         llm_api_token,
+                        organization_id,
                         app_version,
                         CompletionBody {
                             thread_id,
@@ -902,6 +957,7 @@ impl LanguageModel for CloudLanguageModel {
                     } = Self::perform_llm_completion(
                         client.clone(),
                         llm_api_token,
+                        organization_id,
                         app_version,
                         CompletionBody {
                             thread_id,

crates/web_search_providers/Cargo.toml 🔗

@@ -14,6 +14,7 @@ path = "src/web_search_providers.rs"
 [dependencies]
 anyhow.workspace = true
 client.workspace = true
+cloud_api_types.workspace = true
 cloud_llm_client.workspace = true
 futures.workspace = true
 gpui.workspace = true

crates/web_search_providers/src/cloud.rs 🔗

@@ -1,7 +1,8 @@
 use std::sync::Arc;
 
 use anyhow::{Context as _, Result};
-use client::Client;
+use client::{Client, UserStore};
+use cloud_api_types::OrganizationId;
 use cloud_llm_client::{WebSearchBody, WebSearchResponse};
 use futures::AsyncReadExt as _;
 use gpui::{App, AppContext, Context, Entity, Subscription, Task};
@@ -14,8 +15,8 @@ pub struct CloudWebSearchProvider {
 }
 
 impl CloudWebSearchProvider {
-    pub fn new(client: Arc<Client>, cx: &mut App) -> Self {
-        let state = cx.new(|cx| State::new(client, cx));
+    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) -> Self {
+        let state = cx.new(|cx| State::new(client, user_store, cx));
 
         Self { state }
     }
@@ -23,24 +24,31 @@ impl CloudWebSearchProvider {
 
 pub struct State {
     client: Arc<Client>,
+    user_store: Entity<UserStore>,
     llm_api_token: LlmApiToken,
     _llm_token_subscription: Subscription,
 }
 
 impl State {
-    pub fn new(client: Arc<Client>, cx: &mut Context<Self>) -> Self {
+    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
         let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 
         Self {
             client,
+            user_store,
             llm_api_token: LlmApiToken::default(),
             _llm_token_subscription: cx.subscribe(
                 &refresh_llm_token_listener,
                 |this, _, _event, cx| {
                     let client = this.client.clone();
                     let llm_api_token = this.llm_api_token.clone();
+                    let organization_id = this
+                        .user_store
+                        .read(cx)
+                        .current_organization()
+                        .map(|o| o.id.clone());
                     cx.spawn(async move |_this, _cx| {
-                        llm_api_token.refresh(&client).await?;
+                        llm_api_token.refresh(&client, organization_id).await?;
                         anyhow::Ok(())
                     })
                     .detach_and_log_err(cx);
@@ -61,21 +69,31 @@ impl WebSearchProvider for CloudWebSearchProvider {
         let state = self.state.read(cx);
         let client = state.client.clone();
         let llm_api_token = state.llm_api_token.clone();
+        let organization_id = state
+            .user_store
+            .read(cx)
+            .current_organization()
+            .map(|o| o.id.clone());
         let body = WebSearchBody { query };
-        cx.background_spawn(async move { perform_web_search(client, llm_api_token, body).await })
+        cx.background_spawn(async move {
+            perform_web_search(client, llm_api_token, organization_id, body).await
+        })
     }
 }
 
 async fn perform_web_search(
     client: Arc<Client>,
     llm_api_token: LlmApiToken,
+    organization_id: Option<OrganizationId>,
     body: WebSearchBody,
 ) -> Result<WebSearchResponse> {
     const MAX_RETRIES: usize = 3;
 
     let http_client = &client.http_client();
     let mut retries_remaining = MAX_RETRIES;
-    let mut token = llm_api_token.acquire(&client).await?;
+    let mut token = llm_api_token
+        .acquire(&client, organization_id.clone())
+        .await?;
 
     loop {
         if retries_remaining == 0 {
@@ -100,7 +118,9 @@ async fn perform_web_search(
             response.body_mut().read_to_string(&mut body).await?;
             return Ok(serde_json::from_str(&body)?);
         } else if response.needs_llm_token_refresh() {
-            token = llm_api_token.refresh(&client).await?;
+            token = llm_api_token
+                .refresh(&client, organization_id.clone())
+                .await?;
             retries_remaining -= 1;
         } else {
             // For now we will only retry if the LLM token is expired,

crates/web_search_providers/src/web_search_providers.rs 🔗

@@ -1,26 +1,28 @@
 mod cloud;
 
-use client::Client;
+use client::{Client, UserStore};
 use gpui::{App, Context, Entity};
 use language_model::LanguageModelRegistry;
 use std::sync::Arc;
 use web_search::{WebSearchProviderId, WebSearchRegistry};
 
-pub fn init(client: Arc<Client>, cx: &mut App) {
+pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
     let registry = WebSearchRegistry::global(cx);
     registry.update(cx, |registry, cx| {
-        register_web_search_providers(registry, client, cx);
+        register_web_search_providers(registry, client, user_store, cx);
     });
 }
 
 fn register_web_search_providers(
     registry: &mut WebSearchRegistry,
     client: Arc<Client>,
+    user_store: Entity<UserStore>,
     cx: &mut Context<WebSearchRegistry>,
 ) {
     register_zed_web_search_provider(
         registry,
         client.clone(),
+        user_store.clone(),
         &LanguageModelRegistry::global(cx),
         cx,
     );
@@ -29,7 +31,13 @@ fn register_web_search_providers(
         &LanguageModelRegistry::global(cx),
         move |this, registry, event, cx| {
             if let language_model::Event::DefaultModelChanged = event {
-                register_zed_web_search_provider(this, client.clone(), &registry, cx)
+                register_zed_web_search_provider(
+                    this,
+                    client.clone(),
+                    user_store.clone(),
+                    &registry,
+                    cx,
+                )
             }
         },
     )
@@ -39,6 +47,7 @@ fn register_web_search_providers(
 fn register_zed_web_search_provider(
     registry: &mut WebSearchRegistry,
     client: Arc<Client>,
+    user_store: Entity<UserStore>,
     language_model_registry: &Entity<LanguageModelRegistry>,
     cx: &mut Context<WebSearchRegistry>,
 ) {
@@ -47,7 +56,10 @@ fn register_zed_web_search_provider(
         .default_model()
         .is_some_and(|default| default.is_provided_by_zed());
     if using_zed_provider {
-        registry.register_provider(cloud::CloudWebSearchProvider::new(client, cx), cx)
+        registry.register_provider(
+            cloud::CloudWebSearchProvider::new(client, user_store, cx),
+            cx,
+        )
     } else {
         registry.unregister_provider(WebSearchProviderId(
             cloud::ZED_WEB_SEARCH_PROVIDER_ID.into(),

crates/zed/src/main.rs 🔗

@@ -645,7 +645,7 @@ fn main() {
         zed::remote_debug::init(cx);
         edit_prediction_ui::init(cx);
         web_search::init(cx);
-        web_search_providers::init(app_state.client.clone(), cx);
+        web_search_providers::init(app_state.client.clone(), app_state.user_store.clone(), cx);
         snippet_provider::init(cx);
         edit_prediction_registry::init(app_state.client.clone(), app_state.user_store.clone(), cx);
         let prompt_builder = PromptBuilder::load(app_state.fs.clone(), stdout_is_a_pty(), cx);

crates/zed/src/zed.rs 🔗

@@ -5021,7 +5021,7 @@ mod tests {
             language_models::init(app_state.user_store.clone(), app_state.client.clone(), cx);
             web_search::init(cx);
             git_graph::init(cx);
-            web_search_providers::init(app_state.client.clone(), cx);
+            web_search_providers::init(app_state.client.clone(), app_state.user_store.clone(), cx);
             let prompt_builder = PromptBuilder::load(app_state.fs.clone(), false, cx);
             project::AgentRegistryStore::init_global(
                 cx,