Cargo.lock 🔗
@@ -19851,6 +19851,7 @@ version = "0.1.0"
dependencies = [
"anyhow",
"client",
+ "cloud_api_types",
"cloud_llm_client",
"futures 0.3.31",
"gpui",
Tom Houlé created this pull request.
The organization ID is already expected on the cloud side. Sending it
lets the server know which organization the user is logged in under
when requesting an LLM API token.
Closes CLO-337
Closes CLO-337
Release Notes:
- N/A
Cargo.lock | 1
crates/cloud_api_client/src/cloud_api_client.rs | 10
crates/cloud_api_types/src/cloud_api_types.rs | 6
crates/edit_prediction/src/edit_prediction.rs | 67 +++++-
crates/edit_prediction/src/zeta.rs | 13 +
crates/http_client/src/async_body.rs | 14 +
crates/http_client/src/http_client.rs | 2
crates/language_model/src/model/cloud_model.rs | 28 ++
crates/language_models/src/provider/cloud.rs | 104 ++++++++--
crates/web_search_providers/Cargo.toml | 1
crates/web_search_providers/src/cloud.rs | 36 ++
crates/web_search_providers/src/web_search_providers.rs | 22 +
crates/zed/src/main.rs | 2
crates/zed/src/zed.rs | 2
14 files changed, 247 insertions(+), 61 deletions(-)
@@ -19851,6 +19851,7 @@ version = "0.1.0"
dependencies = [
"anyhow",
"client",
+ "cloud_api_types",
"cloud_llm_client",
"futures 0.3.31",
"gpui",
@@ -9,7 +9,9 @@ use futures::AsyncReadExt as _;
use gpui::{App, Task};
use gpui_tokio::Tokio;
use http_client::http::request;
-use http_client::{AsyncBody, HttpClientWithUrl, HttpRequestExt, Method, Request, StatusCode};
+use http_client::{
+ AsyncBody, HttpClientWithUrl, HttpRequestExt, Json, Method, Request, StatusCode,
+};
use parking_lot::RwLock;
use thiserror::Error;
use yawc::WebSocket;
@@ -141,6 +143,7 @@ impl CloudApiClient {
pub async fn create_llm_token(
&self,
system_id: Option<String>,
+ organization_id: Option<OrganizationId>,
) -> Result<CreateLlmTokenResponse, ClientApiError> {
let request_builder = Request::builder()
.method(Method::POST)
@@ -153,7 +156,10 @@ impl CloudApiClient {
builder.header(ZED_SYSTEM_ID_HEADER_NAME, system_id)
});
- let request = self.build_request(request_builder, AsyncBody::default())?;
+ let request = self.build_request(
+ request_builder,
+ Json(CreateLlmTokenBody { organization_id }),
+ )?;
let mut response = self.http_client.send(request).await?;
@@ -52,6 +52,12 @@ pub struct AcceptTermsOfServiceResponse {
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub struct LlmToken(pub String);
+#[derive(Debug, Default, PartialEq, Clone, Serialize, Deserialize)]
+pub struct CreateLlmTokenBody {
+ #[serde(default)]
+ pub organization_id: Option<OrganizationId>,
+}
+
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub struct CreateLlmTokenResponse {
pub token: LlmToken,
@@ -1,7 +1,7 @@
use anyhow::Result;
use arrayvec::ArrayVec;
use client::{Client, EditPredictionUsage, UserStore};
-use cloud_api_types::SubmitEditPredictionFeedbackBody;
+use cloud_api_types::{OrganizationId, SubmitEditPredictionFeedbackBody};
use cloud_llm_client::predict_edits_v3::{
PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
};
@@ -143,7 +143,7 @@ pub struct EditPredictionStore {
pub sweep_ai: SweepAi,
pub mercury: Mercury,
data_collection_choice: DataCollectionChoice,
- reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
+ reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejectionPayload>,
settled_predictions_tx: mpsc::UnboundedSender<Instant>,
shown_predictions: VecDeque<EditPrediction>,
rated_predictions: HashSet<EditPredictionId>,
@@ -151,6 +151,11 @@ pub struct EditPredictionStore {
settled_event_callback: Option<Box<dyn Fn(EditPredictionId, String)>>,
}
+pub(crate) struct EditPredictionRejectionPayload {
+ rejection: EditPredictionRejection,
+ organization_id: Option<OrganizationId>,
+}
+
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum EditPredictionModel {
Zeta,
@@ -719,8 +724,13 @@ impl EditPredictionStore {
|this, _listener, _event, cx| {
let client = this.client.clone();
let llm_token = this.llm_token.clone();
+ let organization_id = this
+ .user_store
+ .read(cx)
+ .current_organization()
+ .map(|organization| organization.id.clone());
cx.spawn(async move |_this, _cx| {
- llm_token.refresh(&client).await?;
+ llm_token.refresh(&client, organization_id).await?;
anyhow::Ok(())
})
.detach_and_log_err(cx);
@@ -781,11 +791,17 @@ impl EditPredictionStore {
let client = self.client.clone();
let llm_token = self.llm_token.clone();
let app_version = AppVersion::global(cx);
+ let organization_id = self
+ .user_store
+ .read(cx)
+ .current_organization()
+ .map(|organization| organization.id.clone());
+
cx.spawn(async move |this, cx| {
let experiments = cx
.background_spawn(async move {
let http_client = client.http_client();
- let token = llm_token.acquire(&client).await?;
+ let token = llm_token.acquire(&client, organization_id).await?;
let url = http_client.build_zed_llm_url("/edit_prediction_experiments", &[])?;
let request = http_client::Request::builder()
.method(Method::GET)
@@ -1424,7 +1440,7 @@ impl EditPredictionStore {
}
async fn handle_rejected_predictions(
- rx: UnboundedReceiver<EditPredictionRejection>,
+ rx: UnboundedReceiver<EditPredictionRejectionPayload>,
client: Arc<Client>,
llm_token: LlmApiToken,
app_version: Version,
@@ -1433,7 +1449,11 @@ impl EditPredictionStore {
let mut rx = std::pin::pin!(rx.peekable());
let mut batched = Vec::new();
- while let Some(rejection) = rx.next().await {
+ while let Some(EditPredictionRejectionPayload {
+ rejection,
+ organization_id,
+ }) = rx.next().await
+ {
batched.push(rejection);
if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
@@ -1471,6 +1491,7 @@ impl EditPredictionStore {
},
client.clone(),
llm_token.clone(),
+ organization_id,
app_version.clone(),
true,
)
@@ -1676,13 +1697,23 @@ impl EditPredictionStore {
all_language_settings(None, cx).edit_predictions.provider,
EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
);
+
if is_cloud {
+ let organization_id = self
+ .user_store
+ .read(cx)
+ .current_organization()
+ .map(|organization| organization.id.clone());
+
self.reject_predictions_tx
- .unbounded_send(EditPredictionRejection {
- request_id: prediction_id.to_string(),
- reason,
- was_shown,
- model_version,
+ .unbounded_send(EditPredictionRejectionPayload {
+ rejection: EditPredictionRejection {
+ request_id: prediction_id.to_string(),
+ reason,
+ was_shown,
+ model_version,
+ },
+ organization_id,
})
.log_err();
}
@@ -2337,6 +2368,7 @@ impl EditPredictionStore {
client: Arc<Client>,
custom_url: Option<Arc<Url>>,
llm_token: LlmApiToken,
+ organization_id: Option<OrganizationId>,
app_version: Version,
) -> Result<(RawCompletionResponse, Option<EditPredictionUsage>)> {
let url = if let Some(custom_url) = custom_url {
@@ -2356,6 +2388,7 @@ impl EditPredictionStore {
},
client,
llm_token,
+ organization_id,
app_version,
true,
)
@@ -2366,6 +2399,7 @@ impl EditPredictionStore {
input: ZetaPromptInput,
client: Arc<Client>,
llm_token: LlmApiToken,
+ organization_id: Option<OrganizationId>,
app_version: Version,
trigger: PredictEditsRequestTrigger,
) -> Result<(PredictEditsV3Response, Option<EditPredictionUsage>)> {
@@ -2388,6 +2422,7 @@ impl EditPredictionStore {
},
client,
llm_token,
+ organization_id,
app_version,
true,
)
@@ -2441,6 +2476,7 @@ impl EditPredictionStore {
build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
client: Arc<Client>,
llm_token: LlmApiToken,
+ organization_id: Option<OrganizationId>,
app_version: Version,
require_auth: bool,
) -> Result<(Res, Option<EditPredictionUsage>)>
@@ -2450,9 +2486,12 @@ impl EditPredictionStore {
let http_client = client.http_client();
let mut token = if require_auth {
- Some(llm_token.acquire(&client).await?)
+ Some(llm_token.acquire(&client, organization_id.clone()).await?)
} else {
- llm_token.acquire(&client).await.ok()
+ llm_token
+ .acquire(&client, organization_id.clone())
+ .await
+ .ok()
};
let mut did_retry = false;
@@ -2494,7 +2533,7 @@ impl EditPredictionStore {
return Ok((serde_json::from_slice(&body)?, usage));
} else if !did_retry && token.is_some() && response.needs_llm_token_refresh() {
did_retry = true;
- token = Some(llm_token.refresh(&client).await?);
+ token = Some(llm_token.refresh(&client, organization_id.clone()).await?);
} else {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
@@ -66,6 +66,11 @@ pub fn request_prediction_with_zeta(
let client = store.client.clone();
let llm_token = store.llm_token.clone();
+ let organization_id = store
+ .user_store
+ .read(cx)
+ .current_organization()
+ .map(|organization| organization.id.clone());
let app_version = AppVersion::global(cx);
let request_task = cx.background_spawn({
@@ -201,6 +206,7 @@ pub fn request_prediction_with_zeta(
client,
None,
llm_token,
+ organization_id,
app_version,
)
.await?;
@@ -219,6 +225,7 @@ pub fn request_prediction_with_zeta(
prompt_input.clone(),
client,
llm_token,
+ organization_id,
app_version,
trigger,
)
@@ -430,6 +437,11 @@ pub(crate) fn edit_prediction_accepted(
let require_auth = custom_accept_url.is_none();
let client = store.client.clone();
let llm_token = store.llm_token.clone();
+ let organization_id = store
+ .user_store
+ .read(cx)
+ .current_organization()
+ .map(|organization| organization.id.clone());
let app_version = AppVersion::global(cx);
cx.background_spawn(async move {
@@ -454,6 +466,7 @@ pub(crate) fn edit_prediction_accepted(
},
client,
llm_token,
+ organization_id,
app_version,
require_auth,
)
@@ -7,6 +7,7 @@ use std::{
use bytes::Bytes;
use futures::AsyncRead;
use http_body::{Body, Frame};
+use serde::Serialize;
/// Based on the implementation of AsyncBody in
/// <https://github.com/sagebind/isahc/blob/5c533f1ef4d6bdf1fd291b5103c22110f41d0bf0/src/body/mod.rs>.
@@ -88,6 +89,19 @@ impl From<&'static str> for AsyncBody {
}
}
+/// Newtype wrapper that serializes a value as JSON into an `AsyncBody`.
+pub struct Json<T: Serialize>(pub T);
+
+impl<T: Serialize> From<Json<T>> for AsyncBody {
+ fn from(json: Json<T>) -> Self {
+ Self::from_bytes(
+ serde_json::to_vec(&json.0)
+ .expect("failed to serialize JSON")
+ .into(),
+ )
+ }
+}
+
impl<T: Into<Self>> From<Option<T>> for AsyncBody {
fn from(body: Option<T>) -> Self {
match body {
@@ -5,7 +5,7 @@ pub mod github;
pub mod github_download;
pub use anyhow::{Result, anyhow};
-pub use async_body::{AsyncBody, Inner};
+pub use async_body::{AsyncBody, Inner, Json};
use derive_more::Deref;
use http::HeaderValue;
pub use http::{self, Method, Request, Response, StatusCode, Uri, request::Builder};
@@ -4,6 +4,7 @@ use std::sync::Arc;
use anyhow::{Context as _, Result};
use client::Client;
use cloud_api_client::ClientApiError;
+use cloud_api_types::OrganizationId;
use cloud_api_types::websocket_protocol::MessageToClient;
use cloud_llm_client::{EXPIRED_LLM_TOKEN_HEADER_NAME, OUTDATED_LLM_TOKEN_HEADER_NAME};
use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Global, ReadGlobal as _};
@@ -26,29 +27,46 @@ impl fmt::Display for PaymentRequiredError {
pub struct LlmApiToken(Arc<RwLock<Option<String>>>);
impl LlmApiToken {
- pub async fn acquire(&self, client: &Arc<Client>) -> Result<String> {
+ pub async fn acquire(
+ &self,
+ client: &Arc<Client>,
+ organization_id: Option<OrganizationId>,
+ ) -> Result<String> {
let lock = self.0.upgradable_read().await;
if let Some(token) = lock.as_ref() {
Ok(token.to_string())
} else {
- Self::fetch(RwLockUpgradableReadGuard::upgrade(lock).await, client).await
+ Self::fetch(
+ RwLockUpgradableReadGuard::upgrade(lock).await,
+ client,
+ organization_id,
+ )
+ .await
}
}
- pub async fn refresh(&self, client: &Arc<Client>) -> Result<String> {
- Self::fetch(self.0.write().await, client).await
+ pub async fn refresh(
+ &self,
+ client: &Arc<Client>,
+ organization_id: Option<OrganizationId>,
+ ) -> Result<String> {
+ Self::fetch(self.0.write().await, client, organization_id).await
}
async fn fetch(
mut lock: RwLockWriteGuard<'_, Option<String>>,
client: &Arc<Client>,
+ organization_id: Option<OrganizationId>,
) -> Result<String> {
let system_id = client
.telemetry()
.system_id()
.map(|system_id| system_id.to_string());
- let result = client.cloud_client().create_llm_token(system_id).await;
+ let result = client
+ .cloud_client()
+ .create_llm_token(system_id, organization_id)
+ .await;
match result {
Ok(response) => {
*lock = Some(response.token.0.clone());
@@ -3,7 +3,7 @@ use anthropic::AnthropicModelMode;
use anyhow::{Context as _, Result, anyhow};
use chrono::{DateTime, Utc};
use client::{Client, UserStore, zed_urls};
-use cloud_api_types::Plan;
+use cloud_api_types::{OrganizationId, Plan};
use cloud_llm_client::{
CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME,
CLIENT_SUPPORTS_X_AI_HEADER_NAME, CompletionBody, CompletionEvent, CompletionRequestStatus,
@@ -122,15 +122,25 @@ impl State {
recommended_models: Vec::new(),
_fetch_models_task: cx.spawn(async move |this, cx| {
maybe!(async move {
- let (client, llm_api_token) = this
- .read_with(cx, |this, _cx| (client.clone(), this.llm_api_token.clone()))?;
+ let (client, llm_api_token, organization_id) =
+ this.read_with(cx, |this, cx| {
+ (
+ client.clone(),
+ this.llm_api_token.clone(),
+ this.user_store
+ .read(cx)
+ .current_organization()
+ .map(|o| o.id.clone()),
+ )
+ })?;
while current_user.borrow().is_none() {
current_user.next().await;
}
let response =
- Self::fetch_models(client.clone(), llm_api_token.clone()).await?;
+ Self::fetch_models(client.clone(), llm_api_token.clone(), organization_id)
+ .await?;
this.update(cx, |this, cx| this.update_models(response, cx))?;
anyhow::Ok(())
})
@@ -146,9 +156,17 @@ impl State {
move |this, _listener, _event, cx| {
let client = this.client.clone();
let llm_api_token = this.llm_api_token.clone();
+ let organization_id = this
+ .user_store
+ .read(cx)
+ .current_organization()
+ .map(|o| o.id.clone());
cx.spawn(async move |this, cx| {
- llm_api_token.refresh(&client).await?;
- let response = Self::fetch_models(client, llm_api_token).await?;
+ llm_api_token
+ .refresh(&client, organization_id.clone())
+ .await?;
+ let response =
+ Self::fetch_models(client, llm_api_token, organization_id).await?;
this.update(cx, |this, cx| {
this.update_models(response, cx);
})
@@ -209,9 +227,10 @@ impl State {
async fn fetch_models(
client: Arc<Client>,
llm_api_token: LlmApiToken,
+ organization_id: Option<OrganizationId>,
) -> Result<ListModelsResponse> {
let http_client = &client.http_client();
- let token = llm_api_token.acquire(&client).await?;
+ let token = llm_api_token.acquire(&client, organization_id).await?;
let request = http_client::Request::builder()
.method(Method::GET)
@@ -273,11 +292,13 @@ impl CloudLanguageModelProvider {
&self,
model: Arc<cloud_llm_client::LanguageModel>,
llm_api_token: LlmApiToken,
+ user_store: Entity<UserStore>,
) -> Arc<dyn LanguageModel> {
Arc::new(CloudLanguageModel {
id: LanguageModelId(SharedString::from(model.id.0.clone())),
model,
llm_api_token,
+ user_store,
client: self.client.clone(),
request_limiter: RateLimiter::new(4),
})
@@ -306,36 +327,46 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
}
fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
- let default_model = self.state.read(cx).default_model.clone()?;
- let llm_api_token = self.state.read(cx).llm_api_token.clone();
- Some(self.create_language_model(default_model, llm_api_token))
+ let state = self.state.read(cx);
+ let default_model = state.default_model.clone()?;
+ let llm_api_token = state.llm_api_token.clone();
+ let user_store = state.user_store.clone();
+ Some(self.create_language_model(default_model, llm_api_token, user_store))
}
fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
- let default_fast_model = self.state.read(cx).default_fast_model.clone()?;
- let llm_api_token = self.state.read(cx).llm_api_token.clone();
- Some(self.create_language_model(default_fast_model, llm_api_token))
+ let state = self.state.read(cx);
+ let default_fast_model = state.default_fast_model.clone()?;
+ let llm_api_token = state.llm_api_token.clone();
+ let user_store = state.user_store.clone();
+ Some(self.create_language_model(default_fast_model, llm_api_token, user_store))
}
fn recommended_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
- let llm_api_token = self.state.read(cx).llm_api_token.clone();
- self.state
- .read(cx)
+ let state = self.state.read(cx);
+ let llm_api_token = state.llm_api_token.clone();
+ let user_store = state.user_store.clone();
+ state
.recommended_models
.iter()
.cloned()
- .map(|model| self.create_language_model(model, llm_api_token.clone()))
+ .map(|model| {
+ self.create_language_model(model, llm_api_token.clone(), user_store.clone())
+ })
.collect()
}
fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
- let llm_api_token = self.state.read(cx).llm_api_token.clone();
- self.state
- .read(cx)
+ let state = self.state.read(cx);
+ let llm_api_token = state.llm_api_token.clone();
+ let user_store = state.user_store.clone();
+ state
.models
.iter()
.cloned()
- .map(|model| self.create_language_model(model, llm_api_token.clone()))
+ .map(|model| {
+ self.create_language_model(model, llm_api_token.clone(), user_store.clone())
+ })
.collect()
}
@@ -367,6 +398,7 @@ pub struct CloudLanguageModel {
id: LanguageModelId,
model: Arc<cloud_llm_client::LanguageModel>,
llm_api_token: LlmApiToken,
+ user_store: Entity<UserStore>,
client: Arc<Client>,
request_limiter: RateLimiter,
}
@@ -380,12 +412,15 @@ impl CloudLanguageModel {
async fn perform_llm_completion(
client: Arc<Client>,
llm_api_token: LlmApiToken,
+ organization_id: Option<OrganizationId>,
app_version: Option<Version>,
body: CompletionBody,
) -> Result<PerformLlmCompletionResponse> {
let http_client = &client.http_client();
- let mut token = llm_api_token.acquire(&client).await?;
+ let mut token = llm_api_token
+ .acquire(&client, organization_id.clone())
+ .await?;
let mut refreshed_token = false;
loop {
@@ -416,7 +451,9 @@ impl CloudLanguageModel {
}
if !refreshed_token && response.needs_llm_token_refresh() {
- token = llm_api_token.refresh(&client).await?;
+ token = llm_api_token
+ .refresh(&client, organization_id.clone())
+ .await?;
refreshed_token = true;
continue;
}
@@ -670,12 +707,17 @@ impl LanguageModel for CloudLanguageModel {
cloud_llm_client::LanguageModelProvider::Google => {
let client = self.client.clone();
let llm_api_token = self.llm_api_token.clone();
+ let organization_id = self
+ .user_store
+ .read(cx)
+ .current_organization()
+ .map(|o| o.id.clone());
let model_id = self.model.id.to_string();
let generate_content_request =
into_google(request, model_id.clone(), GoogleModelMode::Default);
async move {
let http_client = &client.http_client();
- let token = llm_api_token.acquire(&client).await?;
+ let token = llm_api_token.acquire(&client, organization_id).await?;
let request_body = CountTokensBody {
provider: cloud_llm_client::LanguageModelProvider::Google,
@@ -736,6 +778,13 @@ impl LanguageModel for CloudLanguageModel {
let prompt_id = request.prompt_id.clone();
let intent = request.intent;
let app_version = Some(cx.update(|cx| AppVersion::global(cx)));
+ let user_store = self.user_store.clone();
+ let organization_id = cx.update(|cx| {
+ user_store
+ .read(cx)
+ .current_organization()
+ .map(|o| o.id.clone())
+ });
let thinking_allowed = request.thinking_allowed;
let enable_thinking = thinking_allowed && self.model.supports_thinking;
let provider_name = provider_name(&self.model.provider);
@@ -767,6 +816,7 @@ impl LanguageModel for CloudLanguageModel {
let client = self.client.clone();
let llm_api_token = self.llm_api_token.clone();
+ let organization_id = organization_id.clone();
let future = self.request_limiter.stream(async move {
let PerformLlmCompletionResponse {
response,
@@ -774,6 +824,7 @@ impl LanguageModel for CloudLanguageModel {
} = Self::perform_llm_completion(
client.clone(),
llm_api_token,
+ organization_id,
app_version,
CompletionBody {
thread_id,
@@ -803,6 +854,7 @@ impl LanguageModel for CloudLanguageModel {
cloud_llm_client::LanguageModelProvider::OpenAi => {
let client = self.client.clone();
let llm_api_token = self.llm_api_token.clone();
+ let organization_id = organization_id.clone();
let effort = request
.thinking_effort
.as_ref()
@@ -828,6 +880,7 @@ impl LanguageModel for CloudLanguageModel {
} = Self::perform_llm_completion(
client.clone(),
llm_api_token,
+ organization_id,
app_version,
CompletionBody {
thread_id,
@@ -861,6 +914,7 @@ impl LanguageModel for CloudLanguageModel {
None,
);
let llm_api_token = self.llm_api_token.clone();
+ let organization_id = organization_id.clone();
let future = self.request_limiter.stream(async move {
let PerformLlmCompletionResponse {
response,
@@ -868,6 +922,7 @@ impl LanguageModel for CloudLanguageModel {
} = Self::perform_llm_completion(
client.clone(),
llm_api_token,
+ organization_id,
app_version,
CompletionBody {
thread_id,
@@ -902,6 +957,7 @@ impl LanguageModel for CloudLanguageModel {
} = Self::perform_llm_completion(
client.clone(),
llm_api_token,
+ organization_id,
app_version,
CompletionBody {
thread_id,
@@ -14,6 +14,7 @@ path = "src/web_search_providers.rs"
[dependencies]
anyhow.workspace = true
client.workspace = true
+cloud_api_types.workspace = true
cloud_llm_client.workspace = true
futures.workspace = true
gpui.workspace = true
@@ -1,7 +1,8 @@
use std::sync::Arc;
use anyhow::{Context as _, Result};
-use client::Client;
+use client::{Client, UserStore};
+use cloud_api_types::OrganizationId;
use cloud_llm_client::{WebSearchBody, WebSearchResponse};
use futures::AsyncReadExt as _;
use gpui::{App, AppContext, Context, Entity, Subscription, Task};
@@ -14,8 +15,8 @@ pub struct CloudWebSearchProvider {
}
impl CloudWebSearchProvider {
- pub fn new(client: Arc<Client>, cx: &mut App) -> Self {
- let state = cx.new(|cx| State::new(client, cx));
+ pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) -> Self {
+ let state = cx.new(|cx| State::new(client, user_store, cx));
Self { state }
}
@@ -23,24 +24,31 @@ impl CloudWebSearchProvider {
pub struct State {
client: Arc<Client>,
+ user_store: Entity<UserStore>,
llm_api_token: LlmApiToken,
_llm_token_subscription: Subscription,
}
impl State {
- pub fn new(client: Arc<Client>, cx: &mut Context<Self>) -> Self {
+ pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
Self {
client,
+ user_store,
llm_api_token: LlmApiToken::default(),
_llm_token_subscription: cx.subscribe(
&refresh_llm_token_listener,
|this, _, _event, cx| {
let client = this.client.clone();
let llm_api_token = this.llm_api_token.clone();
+ let organization_id = this
+ .user_store
+ .read(cx)
+ .current_organization()
+ .map(|o| o.id.clone());
cx.spawn(async move |_this, _cx| {
- llm_api_token.refresh(&client).await?;
+ llm_api_token.refresh(&client, organization_id).await?;
anyhow::Ok(())
})
.detach_and_log_err(cx);
@@ -61,21 +69,31 @@ impl WebSearchProvider for CloudWebSearchProvider {
let state = self.state.read(cx);
let client = state.client.clone();
let llm_api_token = state.llm_api_token.clone();
+ let organization_id = state
+ .user_store
+ .read(cx)
+ .current_organization()
+ .map(|o| o.id.clone());
let body = WebSearchBody { query };
- cx.background_spawn(async move { perform_web_search(client, llm_api_token, body).await })
+ cx.background_spawn(async move {
+ perform_web_search(client, llm_api_token, organization_id, body).await
+ })
}
}
async fn perform_web_search(
client: Arc<Client>,
llm_api_token: LlmApiToken,
+ organization_id: Option<OrganizationId>,
body: WebSearchBody,
) -> Result<WebSearchResponse> {
const MAX_RETRIES: usize = 3;
let http_client = &client.http_client();
let mut retries_remaining = MAX_RETRIES;
- let mut token = llm_api_token.acquire(&client).await?;
+ let mut token = llm_api_token
+ .acquire(&client, organization_id.clone())
+ .await?;
loop {
if retries_remaining == 0 {
@@ -100,7 +118,9 @@ async fn perform_web_search(
response.body_mut().read_to_string(&mut body).await?;
return Ok(serde_json::from_str(&body)?);
} else if response.needs_llm_token_refresh() {
- token = llm_api_token.refresh(&client).await?;
+ token = llm_api_token
+ .refresh(&client, organization_id.clone())
+ .await?;
retries_remaining -= 1;
} else {
// For now we will only retry if the LLM token is expired,
@@ -1,26 +1,28 @@
mod cloud;
-use client::Client;
+use client::{Client, UserStore};
use gpui::{App, Context, Entity};
use language_model::LanguageModelRegistry;
use std::sync::Arc;
use web_search::{WebSearchProviderId, WebSearchRegistry};
-pub fn init(client: Arc<Client>, cx: &mut App) {
+pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
let registry = WebSearchRegistry::global(cx);
registry.update(cx, |registry, cx| {
- register_web_search_providers(registry, client, cx);
+ register_web_search_providers(registry, client, user_store, cx);
});
}
fn register_web_search_providers(
registry: &mut WebSearchRegistry,
client: Arc<Client>,
+ user_store: Entity<UserStore>,
cx: &mut Context<WebSearchRegistry>,
) {
register_zed_web_search_provider(
registry,
client.clone(),
+ user_store.clone(),
&LanguageModelRegistry::global(cx),
cx,
);
@@ -29,7 +31,13 @@ fn register_web_search_providers(
&LanguageModelRegistry::global(cx),
move |this, registry, event, cx| {
if let language_model::Event::DefaultModelChanged = event {
- register_zed_web_search_provider(this, client.clone(), ®istry, cx)
+ register_zed_web_search_provider(
+ this,
+ client.clone(),
+ user_store.clone(),
+ ®istry,
+ cx,
+ )
}
},
)
@@ -39,6 +47,7 @@ fn register_web_search_providers(
fn register_zed_web_search_provider(
registry: &mut WebSearchRegistry,
client: Arc<Client>,
+ user_store: Entity<UserStore>,
language_model_registry: &Entity<LanguageModelRegistry>,
cx: &mut Context<WebSearchRegistry>,
) {
@@ -47,7 +56,10 @@ fn register_zed_web_search_provider(
.default_model()
.is_some_and(|default| default.is_provided_by_zed());
if using_zed_provider {
- registry.register_provider(cloud::CloudWebSearchProvider::new(client, cx), cx)
+ registry.register_provider(
+ cloud::CloudWebSearchProvider::new(client, user_store, cx),
+ cx,
+ )
} else {
registry.unregister_provider(WebSearchProviderId(
cloud::ZED_WEB_SEARCH_PROVIDER_ID.into(),
@@ -645,7 +645,7 @@ fn main() {
zed::remote_debug::init(cx);
edit_prediction_ui::init(cx);
web_search::init(cx);
- web_search_providers::init(app_state.client.clone(), cx);
+ web_search_providers::init(app_state.client.clone(), app_state.user_store.clone(), cx);
snippet_provider::init(cx);
edit_prediction_registry::init(app_state.client.clone(), app_state.user_store.clone(), cx);
let prompt_builder = PromptBuilder::load(app_state.fs.clone(), stdout_is_a_pty(), cx);
@@ -5021,7 +5021,7 @@ mod tests {
language_models::init(app_state.user_store.clone(), app_state.client.clone(), cx);
web_search::init(cx);
git_graph::init(cx);
- web_search_providers::init(app_state.client.clone(), cx);
+ web_search_providers::init(app_state.client.clone(), app_state.user_store.clone(), cx);
let prompt_builder = PromptBuilder::load(app_state.fs.clone(), false, cx);
project::AgentRegistryStore::init_global(
cx,