cloud_api_client.rs

  1mod websocket;
  2
  3use std::sync::Arc;
  4
  5use anyhow::{Context, Result, anyhow};
  6use cloud_api_types::websocket_protocol::{PROTOCOL_VERSION, PROTOCOL_VERSION_HEADER_NAME};
  7pub use cloud_api_types::*;
  8use futures::AsyncReadExt as _;
  9use gpui::{App, Task};
 10use gpui_tokio::Tokio;
 11use http_client::http::request;
 12use http_client::{AsyncBody, HttpClientWithUrl, Method, Request, StatusCode};
 13use parking_lot::RwLock;
 14use yawc::WebSocket;
 15
 16use crate::websocket::Connection;
 17
 /// Authentication material for Zed Cloud requests.
///
/// Both values are combined into the `Authorization` header as `"<user_id> <access_token>"`.
struct Credentials {
    /// Secret token paired with `user_id`; never log this value.
    access_token: String,
    /// Numeric identifier of the authenticated user.
    user_id: u32,
}
 22
 23pub struct CloudApiClient {
 24    credentials: RwLock<Option<Credentials>>,
 25    http_client: Arc<HttpClientWithUrl>,
 26}
 27
 28impl CloudApiClient {
 29    pub fn new(http_client: Arc<HttpClientWithUrl>) -> Self {
 30        Self {
 31            credentials: RwLock::new(None),
 32            http_client,
 33        }
 34    }
 35
 36    pub fn has_credentials(&self) -> bool {
 37        self.credentials.read().is_some()
 38    }
 39
 40    pub fn set_credentials(&self, user_id: u32, access_token: String) {
 41        *self.credentials.write() = Some(Credentials {
 42            user_id,
 43            access_token,
 44        });
 45    }
 46
 47    pub fn clear_credentials(&self) {
 48        *self.credentials.write() = None;
 49    }
 50
 51    fn build_request(
 52        &self,
 53        req: request::Builder,
 54        body: impl Into<AsyncBody>,
 55    ) -> Result<Request<AsyncBody>> {
 56        let credentials = self.credentials.read();
 57        let credentials = credentials.as_ref().context("no credentials provided")?;
 58        build_request(req, body, credentials)
 59    }
 60
 61    pub async fn get_authenticated_user(&self) -> Result<GetAuthenticatedUserResponse> {
 62        let request = self.build_request(
 63            Request::builder().method(Method::GET).uri(
 64                self.http_client
 65                    .build_zed_cloud_url("/client/users/me", &[])?
 66                    .as_ref(),
 67            ),
 68            AsyncBody::default(),
 69        )?;
 70
 71        let mut response = self.http_client.send(request).await?;
 72
 73        if !response.status().is_success() {
 74            let mut body = String::new();
 75            response.body_mut().read_to_string(&mut body).await?;
 76
 77            anyhow::bail!(
 78                "Failed to get authenticated user.\nStatus: {:?}\nBody: {body}",
 79                response.status()
 80            )
 81        }
 82
 83        let mut body = String::new();
 84        response.body_mut().read_to_string(&mut body).await?;
 85
 86        Ok(serde_json::from_str(&body)?)
 87    }
 88
 89    pub fn connect(&self, cx: &App) -> Result<Task<Result<Connection>>> {
 90        let mut connect_url = self
 91            .http_client
 92            .build_zed_cloud_url("/client/users/connect", &[])?;
 93        connect_url
 94            .set_scheme(match connect_url.scheme() {
 95                "https" => "wss",
 96                "http" => "ws",
 97                scheme => Err(anyhow!("invalid URL scheme: {scheme}"))?,
 98            })
 99            .map_err(|_| anyhow!("failed to set URL scheme"))?;
100
101        let credentials = self.credentials.read();
102        let credentials = credentials.as_ref().context("no credentials provided")?;
103        let authorization_header = format!("{} {}", credentials.user_id, credentials.access_token);
104
105        Ok(cx.spawn(async move |cx| {
106            let handle = cx
107                .update(|cx| Tokio::handle(cx))
108                .ok()
109                .context("failed to get Tokio handle")?;
110            let _guard = handle.enter();
111
112            let ws = WebSocket::connect(connect_url)
113                .with_request(
114                    request::Builder::new()
115                        .header("Authorization", authorization_header)
116                        .header(PROTOCOL_VERSION_HEADER_NAME, PROTOCOL_VERSION.to_string()),
117                )
118                .await?;
119
120            Ok(Connection::new(ws))
121        }))
122    }
123
124    pub async fn accept_terms_of_service(&self) -> Result<AcceptTermsOfServiceResponse> {
125        let request = self.build_request(
126            Request::builder().method(Method::POST).uri(
127                self.http_client
128                    .build_zed_cloud_url("/client/terms_of_service/accept", &[])?
129                    .as_ref(),
130            ),
131            AsyncBody::default(),
132        )?;
133
134        let mut response = self.http_client.send(request).await?;
135
136        if !response.status().is_success() {
137            let mut body = String::new();
138            response.body_mut().read_to_string(&mut body).await?;
139
140            anyhow::bail!(
141                "Failed to accept terms of service.\nStatus: {:?}\nBody: {body}",
142                response.status()
143            )
144        }
145
146        let mut body = String::new();
147        response.body_mut().read_to_string(&mut body).await?;
148
149        Ok(serde_json::from_str(&body)?)
150    }
151
152    pub async fn create_llm_token(
153        &self,
154        system_id: Option<String>,
155    ) -> Result<CreateLlmTokenResponse> {
156        let mut request_builder = Request::builder().method(Method::POST).uri(
157            self.http_client
158                .build_zed_cloud_url("/client/llm_tokens", &[])?
159                .as_ref(),
160        );
161
162        if let Some(system_id) = system_id {
163            request_builder = request_builder.header(ZED_SYSTEM_ID_HEADER_NAME, system_id);
164        }
165
166        let request = self.build_request(request_builder, AsyncBody::default())?;
167
168        let mut response = self.http_client.send(request).await?;
169
170        if !response.status().is_success() {
171            let mut body = String::new();
172            response.body_mut().read_to_string(&mut body).await?;
173
174            anyhow::bail!(
175                "Failed to create LLM token.\nStatus: {:?}\nBody: {body}",
176                response.status()
177            )
178        }
179
180        let mut body = String::new();
181        response.body_mut().read_to_string(&mut body).await?;
182
183        Ok(serde_json::from_str(&body)?)
184    }
185
186    pub async fn validate_credentials(&self, user_id: u32, access_token: &str) -> Result<bool> {
187        let request = build_request(
188            Request::builder().method(Method::GET).uri(
189                self.http_client
190                    .build_zed_cloud_url("/client/users/me", &[])?
191                    .as_ref(),
192            ),
193            AsyncBody::default(),
194            &Credentials {
195                user_id,
196                access_token: access_token.into(),
197            },
198        )?;
199
200        let mut response = self.http_client.send(request).await?;
201
202        if response.status().is_success() {
203            Ok(true)
204        } else {
205            let mut body = String::new();
206            response.body_mut().read_to_string(&mut body).await?;
207            if response.status() == StatusCode::UNAUTHORIZED {
208                Ok(false)
209            } else {
210                Err(anyhow!(
211                    "Failed to get authenticated user.\nStatus: {:?}\nBody: {body}",
212                    response.status()
213                ))
214            }
215        }
216    }
217}
218
219fn build_request(
220    req: request::Builder,
221    body: impl Into<AsyncBody>,
222    credentials: &Credentials,
223) -> Result<Request<AsyncBody>> {
224    Ok(req
225        .header("Content-Type", "application/json")
226        .header(
227            "Authorization",
228            format!("{} {}", credentials.user_id, credentials.access_token),
229        )
230        .body(body.into())?)
231}