user.rs

   1use super::{Client, Status, TypedEnvelope, proto};
   2use anyhow::{Context as _, Result, anyhow};
   3use chrono::{DateTime, Utc};
   4use cloud_api_client::websocket_protocol::MessageToClient;
   5use cloud_api_client::{GetAuthenticatedUserResponse, PlanInfo};
   6use cloud_llm_client::{
   7    EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME, EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME,
   8    MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME, MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME, UsageLimit,
   9};
  10use collections::{HashMap, HashSet, hash_map::Entry};
  11use derive_more::Deref;
  12use feature_flags::FeatureFlagAppExt;
  13use futures::{Future, StreamExt, channel::mpsc};
  14use gpui::{
  15    App, AsyncApp, Context, Entity, EventEmitter, SharedString, SharedUri, Task, WeakEntity,
  16};
  17use http_client::http::{HeaderMap, HeaderValue};
  18use postage::{sink::Sink, watch};
  19use rpc::proto::{RequestMessage, UsersResponse};
  20use std::{
  21    str::FromStr as _,
  22    sync::{Arc, Weak},
  23};
  24use text::ReplicaId;
  25use util::{ResultExt, TryFutureExt as _};
  26
/// Numeric identifier of a user.
pub type UserId = u64;
  28
  29#[derive(
  30    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, serde::Serialize, serde::Deserialize,
  31)]
  32pub struct ChannelId(pub u64);
  33
  34impl std::fmt::Display for ChannelId {
  35    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  36        self.0.fmt(f)
  37    }
  38}
  39
  40#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy)]
  41pub struct ProjectId(pub u64);
  42
  43impl ProjectId {
  44    pub fn to_proto(&self) -> u64 {
  45        self.0
  46    }
  47}
  48
  49#[derive(
  50    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, serde::Serialize, serde::Deserialize,
  51)]
  52pub struct DevServerProjectId(pub u64);
  53
  54#[derive(Debug, Clone, Copy, PartialEq, Eq)]
  55pub struct ParticipantIndex(pub u32);
  56
  57#[derive(Default, Debug)]
  58pub struct User {
  59    pub id: UserId,
  60    pub github_login: SharedString,
  61    pub avatar_uri: SharedUri,
  62    pub name: Option<String>,
  63}
  64
  65#[derive(Clone, Debug, PartialEq, Eq)]
  66pub struct Collaborator {
  67    pub peer_id: proto::PeerId,
  68    pub replica_id: ReplicaId,
  69    pub user_id: UserId,
  70    pub is_host: bool,
  71    pub committer_name: Option<String>,
  72    pub committer_email: Option<String>,
  73}
  74
  75impl PartialOrd for User {
  76    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
  77        Some(self.cmp(other))
  78    }
  79}
  80
  81impl Ord for User {
  82    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
  83        self.github_login.cmp(&other.github_login)
  84    }
  85}
  86
  87impl PartialEq for User {
  88    fn eq(&self, other: &Self) -> bool {
  89        self.id == other.id && self.github_login == other.github_login
  90    }
  91}
  92
  93impl Eq for User {}
  94
  95#[derive(Debug, PartialEq)]
  96pub struct Contact {
  97    pub user: Arc<User>,
  98    pub online: bool,
  99    pub busy: bool,
 100}
 101
 102#[derive(Debug, Clone, Copy, PartialEq, Eq)]
 103pub enum ContactRequestStatus {
 104    None,
 105    RequestSent,
 106    RequestReceived,
 107    RequestAccepted,
 108}
 109
 110pub struct UserStore {
 111    users: HashMap<u64, Arc<User>>,
 112    by_github_login: HashMap<SharedString, u64>,
 113    participant_indices: HashMap<u64, ParticipantIndex>,
 114    update_contacts_tx: mpsc::UnboundedSender<UpdateContacts>,
 115    model_request_usage: Option<ModelRequestUsage>,
 116    edit_prediction_usage: Option<EditPredictionUsage>,
 117    plan_info: Option<PlanInfo>,
 118    current_user: watch::Receiver<Option<Arc<User>>>,
 119    accepted_tos_at: Option<Option<cloud_api_client::Timestamp>>,
 120    contacts: Vec<Arc<Contact>>,
 121    incoming_contact_requests: Vec<Arc<User>>,
 122    outgoing_contact_requests: Vec<Arc<User>>,
 123    pending_contact_requests: HashMap<u64, usize>,
 124    invite_info: Option<InviteInfo>,
 125    client: Weak<Client>,
 126    _maintain_contacts: Task<()>,
 127    _maintain_current_user: Task<Result<()>>,
 128    weak_self: WeakEntity<Self>,
 129}
 130
 131#[derive(Clone)]
 132pub struct InviteInfo {
 133    pub count: u32,
 134    pub url: Arc<str>,
 135}
 136
 137pub enum Event {
 138    Contact {
 139        user: Arc<User>,
 140        kind: ContactEventKind,
 141    },
 142    ShowContacts,
 143    ParticipantIndicesChanged,
 144    PrivateUserInfoUpdated,
 145    PlanUpdated,
 146}
 147
 148#[derive(Clone, Copy)]
 149pub enum ContactEventKind {
 150    Requested,
 151    Accepted,
 152    Cancelled,
 153}
 154
// Lets `UserStore` emit `Event` values through `cx.emit`.
impl EventEmitter<Event> for UserStore {}
 156
 157enum UpdateContacts {
 158    Update(proto::UpdateContacts),
 159    Wait(postage::barrier::Sender),
 160    Clear(postage::barrier::Sender),
 161}
 162
 163#[derive(Debug, Clone, Copy, Deref)]
 164pub struct ModelRequestUsage(pub RequestUsage);
 165
 166#[derive(Debug, Clone, Copy, Deref)]
 167pub struct EditPredictionUsage(pub RequestUsage);
 168
 169#[derive(Debug, Clone, Copy)]
 170pub struct RequestUsage {
 171    pub limit: UsageLimit,
 172    pub amount: i32,
 173}
 174
 175impl UserStore {
 176    pub fn new(client: Arc<Client>, cx: &Context<Self>) -> Self {
 177        let (mut current_user_tx, current_user_rx) = watch::channel();
 178        let (update_contacts_tx, mut update_contacts_rx) = mpsc::unbounded();
 179        let rpc_subscriptions = vec![
 180            client.add_message_handler(cx.weak_entity(), Self::handle_update_plan),
 181            client.add_message_handler(cx.weak_entity(), Self::handle_update_contacts),
 182            client.add_message_handler(cx.weak_entity(), Self::handle_update_invite_info),
 183            client.add_message_handler(cx.weak_entity(), Self::handle_show_contacts),
 184        ];
 185
 186        client.add_message_to_client_handler({
 187            let this = cx.weak_entity();
 188            move |message, cx| Self::handle_message_to_client(this.clone(), message, cx)
 189        });
 190
 191        Self {
 192            users: Default::default(),
 193            by_github_login: Default::default(),
 194            current_user: current_user_rx,
 195            plan_info: None,
 196            model_request_usage: None,
 197            edit_prediction_usage: None,
 198            accepted_tos_at: None,
 199            contacts: Default::default(),
 200            incoming_contact_requests: Default::default(),
 201            participant_indices: Default::default(),
 202            outgoing_contact_requests: Default::default(),
 203            invite_info: None,
 204            client: Arc::downgrade(&client),
 205            update_contacts_tx,
 206            _maintain_contacts: cx.spawn(async move |this, cx| {
 207                let _subscriptions = rpc_subscriptions;
 208                while let Some(message) = update_contacts_rx.next().await {
 209                    if let Ok(task) = this.update(cx, |this, cx| this.update_contacts(message, cx))
 210                    {
 211                        task.log_err().await;
 212                    } else {
 213                        break;
 214                    }
 215                }
 216            }),
 217            _maintain_current_user: cx.spawn(async move |this, cx| {
 218                let mut status = client.status();
 219                let weak = Arc::downgrade(&client);
 220                drop(client);
 221                while let Some(status) = status.next().await {
 222                    // if the client is dropped, the app is shutting down.
 223                    let Some(client) = weak.upgrade() else {
 224                        return Ok(());
 225                    };
 226                    match status {
 227                        Status::Authenticated | Status::Connected { .. } => {
 228                            if let Some(user_id) = client.user_id() {
 229                                let response = client
 230                                    .cloud_client()
 231                                    .get_authenticated_user()
 232                                    .await
 233                                    .log_err();
 234
 235                                let current_user_and_response = if let Some(response) = response {
 236                                    let user = Arc::new(User {
 237                                        id: user_id,
 238                                        github_login: response.user.github_login.clone().into(),
 239                                        avatar_uri: response.user.avatar_url.clone().into(),
 240                                        name: response.user.name.clone(),
 241                                    });
 242
 243                                    Some((user, response))
 244                                } else {
 245                                    None
 246                                };
 247                                current_user_tx
 248                                    .send(
 249                                        current_user_and_response
 250                                            .as_ref()
 251                                            .map(|(user, _)| user.clone()),
 252                                    )
 253                                    .await
 254                                    .ok();
 255
 256                                cx.update(|cx| {
 257                                    if let Some((user, response)) = current_user_and_response {
 258                                        this.update(cx, |this, cx| {
 259                                            this.by_github_login
 260                                                .insert(user.github_login.clone(), user_id);
 261                                            this.users.insert(user_id, user);
 262                                            this.update_authenticated_user(response, cx)
 263                                        })
 264                                    } else {
 265                                        anyhow::Ok(())
 266                                    }
 267                                })??;
 268
 269                                this.update(cx, |_, cx| cx.notify())?;
 270                            }
 271                        }
 272                        Status::SignedOut => {
 273                            current_user_tx.send(None).await.ok();
 274                            this.update(cx, |this, cx| {
 275                                this.accepted_tos_at = None;
 276                                cx.emit(Event::PrivateUserInfoUpdated);
 277                                cx.notify();
 278                                this.clear_contacts()
 279                            })?
 280                            .await;
 281                        }
 282                        Status::ConnectionLost => {
 283                            this.update(cx, |this, cx| {
 284                                cx.notify();
 285                                this.clear_contacts()
 286                            })?
 287                            .await;
 288                        }
 289                        _ => {}
 290                    }
 291                }
 292                Ok(())
 293            }),
 294            pending_contact_requests: Default::default(),
 295            weak_self: cx.weak_entity(),
 296        }
 297    }
 298
 299    #[cfg(feature = "test-support")]
 300    pub fn clear_cache(&mut self) {
 301        self.users.clear();
 302        self.by_github_login.clear();
 303    }
 304
 305    async fn handle_update_invite_info(
 306        this: Entity<Self>,
 307        message: TypedEnvelope<proto::UpdateInviteInfo>,
 308        mut cx: AsyncApp,
 309    ) -> Result<()> {
 310        this.update(&mut cx, |this, cx| {
 311            this.invite_info = Some(InviteInfo {
 312                url: Arc::from(message.payload.url),
 313                count: message.payload.count,
 314            });
 315            cx.notify();
 316        })?;
 317        Ok(())
 318    }
 319
 320    async fn handle_show_contacts(
 321        this: Entity<Self>,
 322        _: TypedEnvelope<proto::ShowContacts>,
 323        mut cx: AsyncApp,
 324    ) -> Result<()> {
 325        this.update(&mut cx, |_, cx| cx.emit(Event::ShowContacts))?;
 326        Ok(())
 327    }
 328
 329    pub fn invite_info(&self) -> Option<&InviteInfo> {
 330        self.invite_info.as_ref()
 331    }
 332
 333    async fn handle_update_contacts(
 334        this: Entity<Self>,
 335        message: TypedEnvelope<proto::UpdateContacts>,
 336        mut cx: AsyncApp,
 337    ) -> Result<()> {
 338        this.read_with(&mut cx, |this, _| {
 339            this.update_contacts_tx
 340                .unbounded_send(UpdateContacts::Update(message.payload))
 341                .unwrap();
 342        })?;
 343        Ok(())
 344    }
 345
 346    async fn handle_update_plan(
 347        this: Entity<Self>,
 348        _message: TypedEnvelope<proto::UpdateUserPlan>,
 349        mut cx: AsyncApp,
 350    ) -> Result<()> {
 351        let client = this
 352            .read_with(&cx, |this, _| this.client.upgrade())?
 353            .context("client was dropped")?;
 354
 355        let response = client
 356            .cloud_client()
 357            .get_authenticated_user()
 358            .await
 359            .context("failed to fetch authenticated user")?;
 360
 361        this.update(&mut cx, |this, cx| {
 362            this.update_authenticated_user(response, cx);
 363        })
 364    }
 365
 366    fn update_contacts(&mut self, message: UpdateContacts, cx: &Context<Self>) -> Task<Result<()>> {
 367        match message {
 368            UpdateContacts::Wait(barrier) => {
 369                drop(barrier);
 370                Task::ready(Ok(()))
 371            }
 372            UpdateContacts::Clear(barrier) => {
 373                self.contacts.clear();
 374                self.incoming_contact_requests.clear();
 375                self.outgoing_contact_requests.clear();
 376                drop(barrier);
 377                Task::ready(Ok(()))
 378            }
 379            UpdateContacts::Update(message) => {
 380                let mut user_ids = HashSet::default();
 381                for contact in &message.contacts {
 382                    user_ids.insert(contact.user_id);
 383                }
 384                user_ids.extend(message.incoming_requests.iter().map(|req| req.requester_id));
 385                user_ids.extend(message.outgoing_requests.iter());
 386
 387                let load_users = self.get_users(user_ids.into_iter().collect(), cx);
 388                cx.spawn(async move |this, cx| {
 389                    load_users.await?;
 390
 391                    // Users are fetched in parallel above and cached in call to get_users
 392                    // No need to parallelize here
 393                    let mut updated_contacts = Vec::new();
 394                    let this = this.upgrade().context("can't upgrade user store handle")?;
 395                    for contact in message.contacts {
 396                        updated_contacts
 397                            .push(Arc::new(Contact::from_proto(contact, &this, cx).await?));
 398                    }
 399
 400                    let mut incoming_requests = Vec::new();
 401                    for request in message.incoming_requests {
 402                        incoming_requests.push({
 403                            this.update(cx, |this, cx| this.get_user(request.requester_id, cx))?
 404                                .await?
 405                        });
 406                    }
 407
 408                    let mut outgoing_requests = Vec::new();
 409                    for requested_user_id in message.outgoing_requests {
 410                        outgoing_requests.push(
 411                            this.update(cx, |this, cx| this.get_user(requested_user_id, cx))?
 412                                .await?,
 413                        );
 414                    }
 415
 416                    let removed_contacts =
 417                        HashSet::<u64>::from_iter(message.remove_contacts.iter().copied());
 418                    let removed_incoming_requests =
 419                        HashSet::<u64>::from_iter(message.remove_incoming_requests.iter().copied());
 420                    let removed_outgoing_requests =
 421                        HashSet::<u64>::from_iter(message.remove_outgoing_requests.iter().copied());
 422
 423                    this.update(cx, |this, cx| {
 424                        // Remove contacts
 425                        this.contacts
 426                            .retain(|contact| !removed_contacts.contains(&contact.user.id));
 427                        // Update existing contacts and insert new ones
 428                        for updated_contact in updated_contacts {
 429                            match this.contacts.binary_search_by_key(
 430                                &&updated_contact.user.github_login,
 431                                |contact| &contact.user.github_login,
 432                            ) {
 433                                Ok(ix) => this.contacts[ix] = updated_contact,
 434                                Err(ix) => this.contacts.insert(ix, updated_contact),
 435                            }
 436                        }
 437
 438                        // Remove incoming contact requests
 439                        this.incoming_contact_requests.retain(|user| {
 440                            if removed_incoming_requests.contains(&user.id) {
 441                                cx.emit(Event::Contact {
 442                                    user: user.clone(),
 443                                    kind: ContactEventKind::Cancelled,
 444                                });
 445                                false
 446                            } else {
 447                                true
 448                            }
 449                        });
 450                        // Update existing incoming requests and insert new ones
 451                        for user in incoming_requests {
 452                            match this
 453                                .incoming_contact_requests
 454                                .binary_search_by_key(&&user.github_login, |contact| {
 455                                    &contact.github_login
 456                                }) {
 457                                Ok(ix) => this.incoming_contact_requests[ix] = user,
 458                                Err(ix) => this.incoming_contact_requests.insert(ix, user),
 459                            }
 460                        }
 461
 462                        // Remove outgoing contact requests
 463                        this.outgoing_contact_requests
 464                            .retain(|user| !removed_outgoing_requests.contains(&user.id));
 465                        // Update existing incoming requests and insert new ones
 466                        for request in outgoing_requests {
 467                            match this
 468                                .outgoing_contact_requests
 469                                .binary_search_by_key(&&request.github_login, |contact| {
 470                                    &contact.github_login
 471                                }) {
 472                                Ok(ix) => this.outgoing_contact_requests[ix] = request,
 473                                Err(ix) => this.outgoing_contact_requests.insert(ix, request),
 474                            }
 475                        }
 476
 477                        cx.notify();
 478                    })?;
 479
 480                    Ok(())
 481                })
 482            }
 483        }
 484    }
 485
 486    pub fn contacts(&self) -> &[Arc<Contact>] {
 487        &self.contacts
 488    }
 489
 490    pub fn has_contact(&self, user: &Arc<User>) -> bool {
 491        self.contacts
 492            .binary_search_by_key(&&user.github_login, |contact| &contact.user.github_login)
 493            .is_ok()
 494    }
 495
 496    pub fn incoming_contact_requests(&self) -> &[Arc<User>] {
 497        &self.incoming_contact_requests
 498    }
 499
 500    pub fn outgoing_contact_requests(&self) -> &[Arc<User>] {
 501        &self.outgoing_contact_requests
 502    }
 503
 504    pub fn is_contact_request_pending(&self, user: &User) -> bool {
 505        self.pending_contact_requests.contains_key(&user.id)
 506    }
 507
 508    pub fn contact_request_status(&self, user: &User) -> ContactRequestStatus {
 509        if self
 510            .contacts
 511            .binary_search_by_key(&&user.github_login, |contact| &contact.user.github_login)
 512            .is_ok()
 513        {
 514            ContactRequestStatus::RequestAccepted
 515        } else if self
 516            .outgoing_contact_requests
 517            .binary_search_by_key(&&user.github_login, |user| &user.github_login)
 518            .is_ok()
 519        {
 520            ContactRequestStatus::RequestSent
 521        } else if self
 522            .incoming_contact_requests
 523            .binary_search_by_key(&&user.github_login, |user| &user.github_login)
 524            .is_ok()
 525        {
 526            ContactRequestStatus::RequestReceived
 527        } else {
 528            ContactRequestStatus::None
 529        }
 530    }
 531
 532    pub fn request_contact(
 533        &mut self,
 534        responder_id: u64,
 535        cx: &mut Context<Self>,
 536    ) -> Task<Result<()>> {
 537        self.perform_contact_request(responder_id, proto::RequestContact { responder_id }, cx)
 538    }
 539
 540    pub fn remove_contact(&mut self, user_id: u64, cx: &mut Context<Self>) -> Task<Result<()>> {
 541        self.perform_contact_request(user_id, proto::RemoveContact { user_id }, cx)
 542    }
 543
 544    pub fn has_incoming_contact_request(&self, user_id: u64) -> bool {
 545        self.incoming_contact_requests
 546            .iter()
 547            .any(|user| user.id == user_id)
 548    }
 549
 550    pub fn respond_to_contact_request(
 551        &mut self,
 552        requester_id: u64,
 553        accept: bool,
 554        cx: &mut Context<Self>,
 555    ) -> Task<Result<()>> {
 556        self.perform_contact_request(
 557            requester_id,
 558            proto::RespondToContactRequest {
 559                requester_id,
 560                response: if accept {
 561                    proto::ContactRequestResponse::Accept
 562                } else {
 563                    proto::ContactRequestResponse::Decline
 564                } as i32,
 565            },
 566            cx,
 567        )
 568    }
 569
 570    pub fn dismiss_contact_request(
 571        &self,
 572        requester_id: u64,
 573        cx: &Context<Self>,
 574    ) -> Task<Result<()>> {
 575        let client = self.client.upgrade();
 576        cx.spawn(async move |_, _| {
 577            client
 578                .context("can't upgrade client reference")?
 579                .request(proto::RespondToContactRequest {
 580                    requester_id,
 581                    response: proto::ContactRequestResponse::Dismiss as i32,
 582                })
 583                .await?;
 584            Ok(())
 585        })
 586    }
 587
 588    fn perform_contact_request<T: RequestMessage>(
 589        &mut self,
 590        user_id: u64,
 591        request: T,
 592        cx: &mut Context<Self>,
 593    ) -> Task<Result<()>> {
 594        let client = self.client.upgrade();
 595        *self.pending_contact_requests.entry(user_id).or_insert(0) += 1;
 596        cx.notify();
 597
 598        cx.spawn(async move |this, cx| {
 599            let response = client
 600                .context("can't upgrade client reference")?
 601                .request(request)
 602                .await;
 603            this.update(cx, |this, cx| {
 604                if let Entry::Occupied(mut request_count) =
 605                    this.pending_contact_requests.entry(user_id)
 606                {
 607                    *request_count.get_mut() -= 1;
 608                    if *request_count.get() == 0 {
 609                        request_count.remove();
 610                    }
 611                }
 612                cx.notify();
 613            })?;
 614            response?;
 615            Ok(())
 616        })
 617    }
 618
 619    pub fn clear_contacts(&self) -> impl Future<Output = ()> + use<> {
 620        let (tx, mut rx) = postage::barrier::channel();
 621        self.update_contacts_tx
 622            .unbounded_send(UpdateContacts::Clear(tx))
 623            .unwrap();
 624        async move {
 625            rx.next().await;
 626        }
 627    }
 628
 629    pub fn contact_updates_done(&self) -> impl Future<Output = ()> {
 630        let (tx, mut rx) = postage::barrier::channel();
 631        self.update_contacts_tx
 632            .unbounded_send(UpdateContacts::Wait(tx))
 633            .unwrap();
 634        async move {
 635            rx.next().await;
 636        }
 637    }
 638
 639    pub fn get_users(
 640        &self,
 641        user_ids: Vec<u64>,
 642        cx: &Context<Self>,
 643    ) -> Task<Result<Vec<Arc<User>>>> {
 644        let mut user_ids_to_fetch = user_ids.clone();
 645        user_ids_to_fetch.retain(|id| !self.users.contains_key(id));
 646
 647        cx.spawn(async move |this, cx| {
 648            if !user_ids_to_fetch.is_empty() {
 649                this.update(cx, |this, cx| {
 650                    this.load_users(
 651                        proto::GetUsers {
 652                            user_ids: user_ids_to_fetch,
 653                        },
 654                        cx,
 655                    )
 656                })?
 657                .await?;
 658            }
 659
 660            this.read_with(cx, |this, _| {
 661                user_ids
 662                    .iter()
 663                    .map(|user_id| {
 664                        this.users
 665                            .get(user_id)
 666                            .cloned()
 667                            .with_context(|| format!("user {user_id} not found"))
 668                    })
 669                    .collect()
 670            })?
 671        })
 672    }
 673
 674    pub fn fuzzy_search_users(
 675        &self,
 676        query: String,
 677        cx: &Context<Self>,
 678    ) -> Task<Result<Vec<Arc<User>>>> {
 679        self.load_users(proto::FuzzySearchUsers { query }, cx)
 680    }
 681
 682    pub fn get_cached_user(&self, user_id: u64) -> Option<Arc<User>> {
 683        self.users.get(&user_id).cloned()
 684    }
 685
 686    pub fn get_user_optimistic(&self, user_id: u64, cx: &Context<Self>) -> Option<Arc<User>> {
 687        if let Some(user) = self.users.get(&user_id).cloned() {
 688            return Some(user);
 689        }
 690
 691        self.get_user(user_id, cx).detach_and_log_err(cx);
 692        None
 693    }
 694
 695    pub fn get_user(&self, user_id: u64, cx: &Context<Self>) -> Task<Result<Arc<User>>> {
 696        if let Some(user) = self.users.get(&user_id).cloned() {
 697            return Task::ready(Ok(user));
 698        }
 699
 700        let load_users = self.get_users(vec![user_id], cx);
 701        cx.spawn(async move |this, cx| {
 702            load_users.await?;
 703            this.read_with(cx, |this, _| {
 704                this.users
 705                    .get(&user_id)
 706                    .cloned()
 707                    .context("server responded with no users")
 708            })?
 709        })
 710    }
 711
 712    pub fn cached_user_by_github_login(&self, github_login: &str) -> Option<Arc<User>> {
 713        self.by_github_login
 714            .get(github_login)
 715            .and_then(|id| self.users.get(id).cloned())
 716    }
 717
 718    pub fn current_user(&self) -> Option<Arc<User>> {
 719        self.current_user.borrow().clone()
 720    }
 721
 722    pub fn plan(&self) -> Option<cloud_llm_client::Plan> {
 723        #[cfg(debug_assertions)]
 724        if let Ok(plan) = std::env::var("ZED_SIMULATE_PLAN").as_ref() {
 725            return match plan.as_str() {
 726                "free" => Some(cloud_llm_client::Plan::ZedFree),
 727                "trial" => Some(cloud_llm_client::Plan::ZedProTrial),
 728                "pro" => Some(cloud_llm_client::Plan::ZedPro),
 729                _ => {
 730                    panic!("ZED_SIMULATE_PLAN must be one of 'free', 'trial', or 'pro'");
 731                }
 732            };
 733        }
 734
 735        self.plan_info.as_ref().map(|info| info.plan)
 736    }
 737
 738    pub fn subscription_period(&self) -> Option<(DateTime<Utc>, DateTime<Utc>)> {
 739        self.plan_info
 740            .as_ref()
 741            .and_then(|plan| plan.subscription_period)
 742            .map(|subscription_period| {
 743                (
 744                    subscription_period.started_at.0,
 745                    subscription_period.ended_at.0,
 746                )
 747            })
 748    }
 749
 750    pub fn trial_started_at(&self) -> Option<DateTime<Utc>> {
 751        self.plan_info
 752            .as_ref()
 753            .and_then(|plan| plan.trial_started_at)
 754            .map(|trial_started_at| trial_started_at.0)
 755    }
 756
 757    /// Returns whether the user's account is too new to use the service.
 758    pub fn account_too_young(&self) -> bool {
 759        self.plan_info
 760            .as_ref()
 761            .map(|plan| plan.is_account_too_young)
 762            .unwrap_or_default()
 763    }
 764
 765    /// Returns whether the current user has overdue invoices and usage should be blocked.
 766    pub fn has_overdue_invoices(&self) -> bool {
 767        self.plan_info
 768            .as_ref()
 769            .map(|plan| plan.has_overdue_invoices)
 770            .unwrap_or_default()
 771    }
 772
 773    pub fn is_usage_based_billing_enabled(&self) -> bool {
 774        self.plan_info
 775            .as_ref()
 776            .map(|plan| plan.is_usage_based_billing_enabled)
 777            .unwrap_or_default()
 778    }
 779
 780    pub fn model_request_usage(&self) -> Option<ModelRequestUsage> {
 781        self.model_request_usage
 782    }
 783
 784    pub fn update_model_request_usage(&mut self, usage: ModelRequestUsage, cx: &mut Context<Self>) {
 785        self.model_request_usage = Some(usage);
 786        cx.notify();
 787    }
 788
 789    pub fn edit_prediction_usage(&self) -> Option<EditPredictionUsage> {
 790        self.edit_prediction_usage
 791    }
 792
 793    pub fn update_edit_prediction_usage(
 794        &mut self,
 795        usage: EditPredictionUsage,
 796        cx: &mut Context<Self>,
 797    ) {
 798        self.edit_prediction_usage = Some(usage);
 799        cx.notify();
 800    }
 801
 802    fn update_authenticated_user(
 803        &mut self,
 804        response: GetAuthenticatedUserResponse,
 805        cx: &mut Context<Self>,
 806    ) {
 807        let staff = response.user.is_staff && !*feature_flags::ZED_DISABLE_STAFF;
 808        cx.update_flags(staff, response.feature_flags);
 809        if let Some(client) = self.client.upgrade() {
 810            client
 811                .telemetry
 812                .set_authenticated_user_info(Some(response.user.metrics_id.clone()), staff);
 813        }
 814
 815        let accepted_tos_at = {
 816            #[cfg(debug_assertions)]
 817            if std::env::var("ZED_IGNORE_ACCEPTED_TOS").is_ok() {
 818                None
 819            } else {
 820                response.user.accepted_tos_at
 821            }
 822
 823            #[cfg(not(debug_assertions))]
 824            response.user.accepted_tos_at
 825        };
 826
 827        self.accepted_tos_at = Some(accepted_tos_at);
 828        self.model_request_usage = Some(ModelRequestUsage(RequestUsage {
 829            limit: response.plan.usage.model_requests.limit,
 830            amount: response.plan.usage.model_requests.used as i32,
 831        }));
 832        self.edit_prediction_usage = Some(EditPredictionUsage(RequestUsage {
 833            limit: response.plan.usage.edit_predictions.limit,
 834            amount: response.plan.usage.edit_predictions.used as i32,
 835        }));
 836        self.plan_info = Some(response.plan);
 837        cx.emit(Event::PrivateUserInfoUpdated);
 838    }
 839
 840    fn handle_message_to_client(this: WeakEntity<Self>, message: &MessageToClient, cx: &App) {
 841        cx.spawn(async move |cx| {
 842            match message {
 843                MessageToClient::UserUpdated => {
 844                    let cloud_client = cx
 845                        .update(|cx| {
 846                            this.read_with(cx, |this, _cx| {
 847                                this.client.upgrade().map(|client| client.cloud_client())
 848                            })
 849                        })??
 850                        .ok_or(anyhow::anyhow!("Failed to get Cloud client"))?;
 851
 852                    let response = cloud_client.get_authenticated_user().await?;
 853                    cx.update(|cx| {
 854                        this.update(cx, |this, cx| {
 855                            this.update_authenticated_user(response, cx);
 856                        })
 857                    })??;
 858                }
 859            }
 860
 861            anyhow::Ok(())
 862        })
 863        .detach_and_log_err(cx);
 864    }
 865
 866    pub fn watch_current_user(&self) -> watch::Receiver<Option<Arc<User>>> {
 867        self.current_user.clone()
 868    }
 869
 870    pub fn has_accepted_terms_of_service(&self) -> bool {
 871        self.accepted_tos_at
 872            .map_or(false, |accepted_tos_at| accepted_tos_at.is_some())
 873    }
 874
 875    pub fn accept_terms_of_service(&self, cx: &Context<Self>) -> Task<Result<()>> {
 876        if self.current_user().is_none() {
 877            return Task::ready(Err(anyhow!("no current user")));
 878        };
 879
 880        let client = self.client.clone();
 881        cx.spawn(async move |this, cx| -> anyhow::Result<()> {
 882            let client = client.upgrade().context("client not found")?;
 883            let response = client
 884                .cloud_client()
 885                .accept_terms_of_service()
 886                .await
 887                .context("error accepting tos")?;
 888            this.update(cx, |this, cx| {
 889                this.accepted_tos_at = Some(response.user.accepted_tos_at);
 890                cx.emit(Event::PrivateUserInfoUpdated);
 891            })?;
 892            Ok(())
 893        })
 894    }
 895
 896    fn load_users(
 897        &self,
 898        request: impl RequestMessage<Response = UsersResponse>,
 899        cx: &Context<Self>,
 900    ) -> Task<Result<Vec<Arc<User>>>> {
 901        let client = self.client.clone();
 902        cx.spawn(async move |this, cx| {
 903            if let Some(rpc) = client.upgrade() {
 904                let response = rpc.request(request).await.context("error loading users")?;
 905                let users = response.users;
 906
 907                this.update(cx, |this, _| this.insert(users))
 908            } else {
 909                Ok(Vec::new())
 910            }
 911        })
 912    }
 913
 914    pub fn insert(&mut self, users: Vec<proto::User>) -> Vec<Arc<User>> {
 915        let mut ret = Vec::with_capacity(users.len());
 916        for user in users {
 917            let user = User::new(user);
 918            if let Some(old) = self.users.insert(user.id, user.clone()) {
 919                if old.github_login != user.github_login {
 920                    self.by_github_login.remove(&old.github_login);
 921                }
 922            }
 923            self.by_github_login
 924                .insert(user.github_login.clone(), user.id);
 925            ret.push(user)
 926        }
 927        ret
 928    }
 929
 930    pub fn set_participant_indices(
 931        &mut self,
 932        participant_indices: HashMap<u64, ParticipantIndex>,
 933        cx: &mut Context<Self>,
 934    ) {
 935        if participant_indices != self.participant_indices {
 936            self.participant_indices = participant_indices;
 937            cx.emit(Event::ParticipantIndicesChanged);
 938        }
 939    }
 940
 941    pub fn participant_indices(&self) -> &HashMap<u64, ParticipantIndex> {
 942        &self.participant_indices
 943    }
 944
 945    pub fn participant_names(
 946        &self,
 947        user_ids: impl Iterator<Item = u64>,
 948        cx: &App,
 949    ) -> HashMap<u64, SharedString> {
 950        let mut ret = HashMap::default();
 951        let mut missing_user_ids = Vec::new();
 952        for id in user_ids {
 953            if let Some(github_login) = self.get_cached_user(id).map(|u| u.github_login.clone()) {
 954                ret.insert(id, github_login);
 955            } else {
 956                missing_user_ids.push(id)
 957            }
 958        }
 959        if !missing_user_ids.is_empty() {
 960            let this = self.weak_self.clone();
 961            cx.spawn(async move |cx| {
 962                this.update(cx, |this, cx| this.get_users(missing_user_ids, cx))?
 963                    .await
 964            })
 965            .detach_and_log_err(cx);
 966        }
 967        ret
 968    }
 969}
 970
 971impl User {
 972    fn new(message: proto::User) -> Arc<Self> {
 973        Arc::new(User {
 974            id: message.id,
 975            github_login: message.github_login.into(),
 976            avatar_uri: message.avatar_url.into(),
 977            name: message.name,
 978        })
 979    }
 980}
 981
 982impl Contact {
 983    async fn from_proto(
 984        contact: proto::Contact,
 985        user_store: &Entity<UserStore>,
 986        cx: &mut AsyncApp,
 987    ) -> Result<Self> {
 988        let user = user_store
 989            .update(cx, |user_store, cx| {
 990                user_store.get_user(contact.user_id, cx)
 991            })?
 992            .await?;
 993        Ok(Self {
 994            user,
 995            online: contact.online,
 996            busy: contact.busy,
 997        })
 998    }
 999}
1000
1001impl Collaborator {
1002    pub fn from_proto(message: proto::Collaborator) -> Result<Self> {
1003        Ok(Self {
1004            peer_id: message.peer_id.context("invalid peer id")?,
1005            replica_id: message.replica_id as ReplicaId,
1006            user_id: message.user_id as UserId,
1007            is_host: message.is_host,
1008            committer_name: message.committer_name,
1009            committer_email: message.committer_email,
1010        })
1011    }
1012}
1013
1014impl RequestUsage {
1015    pub fn over_limit(&self) -> bool {
1016        match self.limit {
1017            UsageLimit::Limited(limit) => self.amount >= limit,
1018            UsageLimit::Unlimited => false,
1019        }
1020    }
1021
1022    pub fn from_proto(amount: u32, limit: proto::UsageLimit) -> Option<Self> {
1023        let limit = match limit.variant? {
1024            proto::usage_limit::Variant::Limited(limited) => {
1025                UsageLimit::Limited(limited.limit as i32)
1026            }
1027            proto::usage_limit::Variant::Unlimited(_) => UsageLimit::Unlimited,
1028        };
1029        Some(RequestUsage {
1030            limit,
1031            amount: amount as i32,
1032        })
1033    }
1034
1035    fn from_headers(
1036        limit_name: &str,
1037        amount_name: &str,
1038        headers: &HeaderMap<HeaderValue>,
1039    ) -> Result<Self> {
1040        let limit = headers
1041            .get(limit_name)
1042            .with_context(|| format!("missing {limit_name:?} header"))?;
1043        let limit = UsageLimit::from_str(limit.to_str()?)?;
1044
1045        let amount = headers
1046            .get(amount_name)
1047            .with_context(|| format!("missing {amount_name:?} header"))?;
1048        let amount = amount.to_str()?.parse::<i32>()?;
1049
1050        Ok(Self { limit, amount })
1051    }
1052}
1053
1054impl ModelRequestUsage {
1055    pub fn from_headers(headers: &HeaderMap<HeaderValue>) -> Result<Self> {
1056        Ok(Self(RequestUsage::from_headers(
1057            MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME,
1058            MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME,
1059            headers,
1060        )?))
1061    }
1062}
1063
1064impl EditPredictionUsage {
1065    pub fn from_headers(headers: &HeaderMap<HeaderValue>) -> Result<Self> {
1066        Ok(Self(RequestUsage::from_headers(
1067            EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME,
1068            EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME,
1069            headers,
1070        )?))
1071    }
1072}