user.rs

   1use super::{Client, Status, TypedEnvelope, proto};
   2use anyhow::{Context as _, Result};
   3use chrono::{DateTime, Utc};
   4use cloud_api_client::websocket_protocol::MessageToClient;
   5use cloud_api_client::{
   6    GetAuthenticatedUserResponse, KnownOrUnknown, Organization, OrganizationId, Plan, PlanInfo,
   7};
   8use cloud_llm_client::{
   9    EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME, EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME, UsageLimit,
  10};
  11use collections::{HashMap, HashSet, hash_map::Entry};
  12use db::kvp::KeyValueStore;
  13use derive_more::Deref;
  14use feature_flags::FeatureFlagAppExt;
  15use futures::{Future, StreamExt, channel::mpsc};
  16use gpui::{
  17    App, AsyncApp, Context, Entity, EventEmitter, SharedString, SharedUri, Task, WeakEntity,
  18};
  19use http_client::http::{HeaderMap, HeaderValue};
  20use postage::{sink::Sink, watch};
  21use rpc::proto::{RequestMessage, UsersResponse};
  22use std::{
  23    str::FromStr as _,
  24    sync::{Arc, Weak},
  25};
  26use text::ReplicaId;
  27use util::{ResultExt, TryFutureExt as _};
  28
/// Key-value store key under which the id of the currently selected
/// organization is persisted (written by `set_current_organization`, read back
/// when the authenticated user is loaded).
const CURRENT_ORGANIZATION_ID_KEY: &str = "current_organization_id";

/// Numeric identifier of a user; the user cache in `UserStore` is keyed by it.
pub type UserId = u64;
  32
  33#[derive(
  34    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, serde::Serialize, serde::Deserialize,
  35)]
  36pub struct ChannelId(pub u64);
  37
  38impl std::fmt::Display for ChannelId {
  39    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  40        self.0.fmt(f)
  41    }
  42}
  43
  44#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy)]
  45pub struct ProjectId(pub u64);
  46
  47impl ProjectId {
  48    pub fn to_proto(self) -> u64 {
  49        self.0
  50    }
  51}
  52
  53#[derive(Debug, Clone, Copy, PartialEq, Eq)]
  54pub struct ParticipantIndex(pub u32);
  55
  56#[derive(Default, Debug)]
  57pub struct User {
  58    pub id: UserId,
  59    pub github_login: SharedString,
  60    pub avatar_uri: SharedUri,
  61    pub name: Option<String>,
  62}
  63
  64#[derive(Clone, Debug, PartialEq, Eq)]
  65pub struct Collaborator {
  66    pub peer_id: proto::PeerId,
  67    pub replica_id: ReplicaId,
  68    pub user_id: UserId,
  69    pub is_host: bool,
  70    pub committer_name: Option<String>,
  71    pub committer_email: Option<String>,
  72}
  73
  74impl PartialOrd for User {
  75    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
  76        Some(self.cmp(other))
  77    }
  78}
  79
  80impl Ord for User {
  81    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
  82        self.github_login.cmp(&other.github_login)
  83    }
  84}
  85
  86impl PartialEq for User {
  87    fn eq(&self, other: &Self) -> bool {
  88        self.id == other.id && self.github_login == other.github_login
  89    }
  90}
  91
  92impl Eq for User {}
  93
  94#[derive(Debug, PartialEq)]
  95pub struct Contact {
  96    pub user: Arc<User>,
  97    pub online: bool,
  98    pub busy: bool,
  99}
 100
 101#[derive(Debug, Clone, Copy, PartialEq, Eq)]
 102pub enum ContactRequestStatus {
 103    None,
 104    RequestSent,
 105    RequestReceived,
 106    RequestAccepted,
 107}
 108
/// Client-side store of users plus the signed-in user's contacts, contact
/// requests, organizations, plans and usage.
///
/// NOTE(review): fields are dropped in declaration order, so reordering them
/// changes when the background tasks at the end are dropped.
pub struct UserStore {
    /// Cache of users by id.
    users: HashMap<u64, Arc<User>>,
    /// GitHub login -> user id, for users inserted into `users`.
    by_github_login: HashMap<SharedString, u64>,
    participant_indices: HashMap<u64, ParticipantIndex>,
    /// Queue of contact messages, consumed one at a time by `_maintain_contacts`.
    update_contacts_tx: mpsc::UnboundedSender<UpdateContacts>,
    edit_prediction_usage: Option<EditPredictionUsage>,
    /// User-level plan info; `plan()` falls back to it when the current
    /// organization has no recorded plan.
    plan_info: Option<PlanInfo>,
    /// Signed-in user as published by `_maintain_current_user` (`None` after sign-out).
    current_user: watch::Receiver<Option<Arc<User>>>,
    current_organization: Option<Arc<Organization>>,
    organizations: Vec<Arc<Organization>>,
    plans_by_organization: HashMap<OrganizationId, Plan>,
    /// Kept sorted by `user.github_login`; lookups and updates binary-search it.
    contacts: Vec<Arc<Contact>>,
    /// Kept sorted by `github_login`.
    incoming_contact_requests: Vec<Arc<User>>,
    /// Kept sorted by `github_login`.
    outgoing_contact_requests: Vec<Arc<User>>,
    /// Number of in-flight contact RPCs per user id; entries are removed at zero.
    pending_contact_requests: HashMap<u64, usize>,
    /// Weak reference, so the store does not keep the client alive.
    client: Weak<Client>,
    _maintain_contacts: Task<()>,
    _maintain_current_user: Task<Result<()>>,
    _handle_sign_out: Task<()>,
    weak_self: WeakEntity<Self>,
}
 130
 131#[derive(Clone)]
 132pub struct InviteInfo {
 133    pub count: u32,
 134    pub url: Arc<str>,
 135}
 136
 137pub enum Event {
 138    Contact {
 139        user: Arc<User>,
 140        kind: ContactEventKind,
 141    },
 142    ShowContacts,
 143    ParticipantIndicesChanged,
 144    PrivateUserInfoUpdated,
 145    PlanUpdated,
 146    OrganizationChanged,
 147}
 148
 149#[derive(Clone, Copy)]
 150pub enum ContactEventKind {
 151    Requested,
 152    Accepted,
 153    Cancelled,
 154}
 155
 156impl EventEmitter<Event> for UserStore {}
 157
 158enum UpdateContacts {
 159    Update(proto::UpdateContacts),
 160    Wait(postage::barrier::Sender),
 161    Clear(postage::barrier::Sender),
 162}
 163
 164#[derive(Debug, Clone, Copy, Deref)]
 165pub struct EditPredictionUsage(pub RequestUsage);
 166
 167#[derive(Debug, Clone, Copy)]
 168pub struct RequestUsage {
 169    pub limit: UsageLimit,
 170    pub amount: i32,
 171}
 172
 173impl UserStore {
    /// Creates the store, registers its RPC handlers and starts its background tasks.
    ///
    /// - `_maintain_contacts` applies queued `UpdateContacts` messages one at a time.
    /// - `_maintain_current_user` follows `client.status()`. On (re)connect it
    ///   fetches the authenticated user. On sign-out it clears per-user state.
    ///   On connection loss it clears contacts.
    /// - `_handle_sign_out` calls `client.sign_out` for every message received
    ///   on the sign-out channel installed into `client.sign_out_tx`.
    pub fn new(client: Arc<Client>, cx: &Context<Self>) -> Self {
        let (mut current_user_tx, current_user_rx) = watch::channel();
        let (sign_out_tx, mut sign_out_rx) = mpsc::unbounded();
        let (update_contacts_tx, mut update_contacts_rx) = mpsc::unbounded();
        let rpc_subscriptions = vec![
            client.add_message_handler(cx.weak_entity(), Self::handle_update_contacts),
            client.add_message_handler(cx.weak_entity(), Self::handle_show_contacts),
        ];

        // Sign-out requests sent through the client are received by `_handle_sign_out`.
        client.sign_out_tx.lock().replace(sign_out_tx);
        client.add_message_to_client_handler({
            let this = cx.weak_entity();
            move |message, cx| Self::handle_message_to_client(this.clone(), message, cx)
        });

        Self {
            users: Default::default(),
            by_github_login: Default::default(),
            current_user: current_user_rx,
            current_organization: None,
            organizations: Vec::new(),
            plans_by_organization: HashMap::default(),
            plan_info: None,
            edit_prediction_usage: None,
            contacts: Default::default(),
            incoming_contact_requests: Default::default(),
            participant_indices: Default::default(),
            outgoing_contact_requests: Default::default(),
            client: Arc::downgrade(&client),
            update_contacts_tx,
            _maintain_contacts: cx.spawn(async move |this, cx| {
                // The subscriptions are owned by this task, so they live as long as
                // it runs (presumably dropping them unregisters the handlers).
                let _subscriptions = rpc_subscriptions;
                // Messages are applied strictly in order: each update's task is
                // awaited before the next message is taken from the queue.
                while let Some(message) = update_contacts_rx.next().await {
                    if let Ok(task) = this.update(cx, |this, cx| this.update_contacts(message, cx))
                    {
                        task.log_err().await;
                    } else {
                        // The store has been released; stop processing.
                        break;
                    }
                }
            }),
            _maintain_current_user: cx.spawn(async move |this, cx| {
                let mut status = client.status();
                // Keep only a weak reference so this task does not keep the client alive.
                let weak = Arc::downgrade(&client);
                drop(client);
                while let Some(status) = status.next().await {
                    // if the client is dropped, the app is shutting down.
                    let Some(client) = weak.upgrade() else {
                        return Ok(());
                    };
                    match status {
                        Status::Authenticated
                        | Status::Reauthenticated
                        | Status::Connected { .. } => {
                            if let Some(user_id) = client.user_id() {
                                // A failed fetch is logged and results in `None`
                                // being published as the current user below.
                                let response = client
                                    .cloud_client()
                                    .get_authenticated_user()
                                    .await
                                    .log_err();

                                let current_user_and_response = if let Some(response) = response {
                                    let user = Arc::new(User {
                                        id: user_id,
                                        github_login: response.user.github_login.clone().into(),
                                        avatar_uri: response.user.avatar_url.clone().into(),
                                        name: response.user.name.clone(),
                                    });

                                    Some((user, response))
                                } else {
                                    None
                                };
                                current_user_tx
                                    .send(
                                        current_user_and_response
                                            .as_ref()
                                            .map(|(user, _)| user.clone()),
                                    )
                                    .await
                                    .ok();

                                // Cache the user and apply the rest of the response
                                // (flags, organizations, plans, usage).
                                cx.update(|cx| {
                                    if let Some((user, response)) = current_user_and_response {
                                        this.update(cx, |this, cx| {
                                            this.by_github_login
                                                .insert(user.github_login.clone(), user_id);
                                            this.users.insert(user_id, user);
                                            this.update_authenticated_user(response, cx)
                                        })
                                    } else {
                                        anyhow::Ok(())
                                    }
                                })?;

                                this.update(cx, |_, cx| cx.notify())?;
                            }
                        }
                        Status::SignedOut => {
                            // Reset all per-user state and wait until the contact
                            // lists have actually been cleared.
                            current_user_tx.send(None).await.ok();
                            this.update(cx, |this, cx| {
                                this.clear_organizations();
                                this.clear_plan_and_usage();
                                cx.emit(Event::PrivateUserInfoUpdated);
                                cx.notify();
                                this.clear_contacts()
                            })?
                            .await;
                        }
                        Status::ConnectionLost => {
                            // Only contacts are cleared; user, plan and
                            // organization info are kept.
                            this.update(cx, |this, cx| {
                                cx.notify();
                                this.clear_contacts()
                            })?
                            .await;
                        }
                        _ => {}
                    }
                }
                Ok(())
            }),
            _handle_sign_out: cx.spawn(async move |this, cx| {
                while let Some(()) = sign_out_rx.next().await {
                    // Stop once either the store or the client is gone.
                    let Some(client) = this
                        .read_with(cx, |this, _cx| this.client.upgrade())
                        .ok()
                        .flatten()
                    else {
                        break;
                    };

                    client.sign_out(cx).await;
                }
            }),
            pending_contact_requests: Default::default(),
            weak_self: cx.weak_entity(),
        }
    }
 312
 313    #[cfg(feature = "test-support")]
 314    pub fn clear_cache(&mut self) {
 315        self.users.clear();
 316        self.by_github_login.clear();
 317    }
 318
 319    async fn handle_show_contacts(
 320        this: Entity<Self>,
 321        _: TypedEnvelope<proto::ShowContacts>,
 322        mut cx: AsyncApp,
 323    ) -> Result<()> {
 324        this.update(&mut cx, |_, cx| cx.emit(Event::ShowContacts));
 325        Ok(())
 326    }
 327
 328    async fn handle_update_contacts(
 329        this: Entity<Self>,
 330        message: TypedEnvelope<proto::UpdateContacts>,
 331        cx: AsyncApp,
 332    ) -> Result<()> {
 333        this.read_with(&cx, |this, _| {
 334            this.update_contacts_tx
 335                .unbounded_send(UpdateContacts::Update(message.payload))
 336                .unwrap();
 337        });
 338        Ok(())
 339    }
 340
    /// Applies one queued contact message.
    ///
    /// - `Wait` and `Clear` complete immediately. Dropping the barrier sender
    ///   wakes whoever awaits it; `Clear` empties the three lists first.
    /// - `Update` first loads every user it references. It then removes the
    ///   listed contacts and requests, and upserts the new ones into the lists,
    ///   which stay sorted by GitHub login (the binary searches rely on this).
    ///   Removed incoming requests emit `Event::Contact` with
    ///   `ContactEventKind::Cancelled`.
    fn update_contacts(&mut self, message: UpdateContacts, cx: &Context<Self>) -> Task<Result<()>> {
        match message {
            UpdateContacts::Wait(barrier) => {
                drop(barrier);
                Task::ready(Ok(()))
            }
            UpdateContacts::Clear(barrier) => {
                self.contacts.clear();
                self.incoming_contact_requests.clear();
                self.outgoing_contact_requests.clear();
                drop(barrier);
                Task::ready(Ok(()))
            }
            UpdateContacts::Update(message) => {
                // Collect every user id referenced by the update so they can be
                // fetched in a single batch.
                let mut user_ids = HashSet::default();
                for contact in &message.contacts {
                    user_ids.insert(contact.user_id);
                }
                user_ids.extend(message.incoming_requests.iter().map(|req| req.requester_id));
                user_ids.extend(message.outgoing_requests.iter());

                let load_users = self.get_users(user_ids.into_iter().collect(), cx);
                cx.spawn(async move |this, cx| {
                    load_users.await?;

                    // All referenced users were fetched (and cached) by the
                    // get_users call above, so the lookups below are expected to
                    // hit the cache; no need to parallelize them.
                    let mut updated_contacts = Vec::new();
                    let this = this.upgrade().context("can't upgrade user store handle")?;
                    for contact in message.contacts {
                        updated_contacts
                            .push(Arc::new(Contact::from_proto(contact, &this, cx).await?));
                    }

                    let mut incoming_requests = Vec::new();
                    for request in message.incoming_requests {
                        incoming_requests.push({
                            this.update(cx, |this, cx| this.get_user(request.requester_id, cx))
                                .await?
                        });
                    }

                    let mut outgoing_requests = Vec::new();
                    for requested_user_id in message.outgoing_requests {
                        outgoing_requests.push(
                            this.update(cx, |this, cx| this.get_user(requested_user_id, cx))
                                .await?,
                        );
                    }

                    let removed_contacts =
                        HashSet::<u64>::from_iter(message.remove_contacts.iter().copied());
                    let removed_incoming_requests =
                        HashSet::<u64>::from_iter(message.remove_incoming_requests.iter().copied());
                    let removed_outgoing_requests =
                        HashSet::<u64>::from_iter(message.remove_outgoing_requests.iter().copied());

                    this.update(cx, |this, cx| {
                        // Remove contacts
                        this.contacts
                            .retain(|contact| !removed_contacts.contains(&contact.user.id));
                        // Update existing contacts and insert new ones, keeping
                        // the list sorted by GitHub login.
                        for updated_contact in updated_contacts {
                            match this.contacts.binary_search_by_key(
                                &&updated_contact.user.github_login,
                                |contact| &contact.user.github_login,
                            ) {
                                Ok(ix) => this.contacts[ix] = updated_contact,
                                Err(ix) => this.contacts.insert(ix, updated_contact),
                            }
                        }

                        // Remove incoming contact requests, emitting a
                        // `Cancelled` event for each removed one.
                        this.incoming_contact_requests.retain(|user| {
                            if removed_incoming_requests.contains(&user.id) {
                                cx.emit(Event::Contact {
                                    user: user.clone(),
                                    kind: ContactEventKind::Cancelled,
                                });
                                false
                            } else {
                                true
                            }
                        });
                        // Update existing incoming requests and insert new ones
                        for user in incoming_requests {
                            match this
                                .incoming_contact_requests
                                .binary_search_by_key(&&user.github_login, |contact| {
                                    &contact.github_login
                                }) {
                                Ok(ix) => this.incoming_contact_requests[ix] = user,
                                Err(ix) => this.incoming_contact_requests.insert(ix, user),
                            }
                        }

                        // Remove outgoing contact requests
                        this.outgoing_contact_requests
                            .retain(|user| !removed_outgoing_requests.contains(&user.id));
                        // Update existing outgoing requests and insert new ones
                        for request in outgoing_requests {
                            match this
                                .outgoing_contact_requests
                                .binary_search_by_key(&&request.github_login, |contact| {
                                    &contact.github_login
                                }) {
                                Ok(ix) => this.outgoing_contact_requests[ix] = request,
                                Err(ix) => this.outgoing_contact_requests.insert(ix, request),
                            }
                        }

                        cx.notify();
                    });

                    Ok(())
                })
            }
        }
    }
 460
 461    pub fn contacts(&self) -> &[Arc<Contact>] {
 462        &self.contacts
 463    }
 464
 465    pub fn has_contact(&self, user: &Arc<User>) -> bool {
 466        self.contacts
 467            .binary_search_by_key(&&user.github_login, |contact| &contact.user.github_login)
 468            .is_ok()
 469    }
 470
 471    pub fn incoming_contact_requests(&self) -> &[Arc<User>] {
 472        &self.incoming_contact_requests
 473    }
 474
 475    pub fn outgoing_contact_requests(&self) -> &[Arc<User>] {
 476        &self.outgoing_contact_requests
 477    }
 478
 479    pub fn is_contact_request_pending(&self, user: &User) -> bool {
 480        self.pending_contact_requests.contains_key(&user.id)
 481    }
 482
 483    pub fn contact_request_status(&self, user: &User) -> ContactRequestStatus {
 484        if self
 485            .contacts
 486            .binary_search_by_key(&&user.github_login, |contact| &contact.user.github_login)
 487            .is_ok()
 488        {
 489            ContactRequestStatus::RequestAccepted
 490        } else if self
 491            .outgoing_contact_requests
 492            .binary_search_by_key(&&user.github_login, |user| &user.github_login)
 493            .is_ok()
 494        {
 495            ContactRequestStatus::RequestSent
 496        } else if self
 497            .incoming_contact_requests
 498            .binary_search_by_key(&&user.github_login, |user| &user.github_login)
 499            .is_ok()
 500        {
 501            ContactRequestStatus::RequestReceived
 502        } else {
 503            ContactRequestStatus::None
 504        }
 505    }
 506
 507    pub fn request_contact(
 508        &mut self,
 509        responder_id: u64,
 510        cx: &mut Context<Self>,
 511    ) -> Task<Result<()>> {
 512        self.perform_contact_request(responder_id, proto::RequestContact { responder_id }, cx)
 513    }
 514
 515    pub fn remove_contact(&mut self, user_id: u64, cx: &mut Context<Self>) -> Task<Result<()>> {
 516        self.perform_contact_request(user_id, proto::RemoveContact { user_id }, cx)
 517    }
 518
 519    pub fn has_incoming_contact_request(&self, user_id: u64) -> bool {
 520        self.incoming_contact_requests
 521            .iter()
 522            .any(|user| user.id == user_id)
 523    }
 524
 525    pub fn respond_to_contact_request(
 526        &mut self,
 527        requester_id: u64,
 528        accept: bool,
 529        cx: &mut Context<Self>,
 530    ) -> Task<Result<()>> {
 531        self.perform_contact_request(
 532            requester_id,
 533            proto::RespondToContactRequest {
 534                requester_id,
 535                response: if accept {
 536                    proto::ContactRequestResponse::Accept
 537                } else {
 538                    proto::ContactRequestResponse::Decline
 539                } as i32,
 540            },
 541            cx,
 542        )
 543    }
 544
 545    pub fn dismiss_contact_request(
 546        &self,
 547        requester_id: u64,
 548        cx: &Context<Self>,
 549    ) -> Task<Result<()>> {
 550        let client = self.client.upgrade();
 551        cx.spawn(async move |_, _| {
 552            client
 553                .context("can't upgrade client reference")?
 554                .request(proto::RespondToContactRequest {
 555                    requester_id,
 556                    response: proto::ContactRequestResponse::Dismiss as i32,
 557                })
 558                .await?;
 559            Ok(())
 560        })
 561    }
 562
 563    fn perform_contact_request<T: RequestMessage>(
 564        &mut self,
 565        user_id: u64,
 566        request: T,
 567        cx: &mut Context<Self>,
 568    ) -> Task<Result<()>> {
 569        let client = self.client.upgrade();
 570        *self.pending_contact_requests.entry(user_id).or_insert(0) += 1;
 571        cx.notify();
 572
 573        cx.spawn(async move |this, cx| {
 574            let response = client
 575                .context("can't upgrade client reference")?
 576                .request(request)
 577                .await;
 578            this.update(cx, |this, cx| {
 579                if let Entry::Occupied(mut request_count) =
 580                    this.pending_contact_requests.entry(user_id)
 581                {
 582                    *request_count.get_mut() -= 1;
 583                    if *request_count.get() == 0 {
 584                        request_count.remove();
 585                    }
 586                }
 587                cx.notify();
 588            })?;
 589            response?;
 590            Ok(())
 591        })
 592    }
 593
 594    pub fn clear_contacts(&self) -> impl Future<Output = ()> + use<> {
 595        let (tx, mut rx) = postage::barrier::channel();
 596        self.update_contacts_tx
 597            .unbounded_send(UpdateContacts::Clear(tx))
 598            .unwrap();
 599        async move {
 600            rx.next().await;
 601        }
 602    }
 603
 604    pub fn contact_updates_done(&self) -> impl Future<Output = ()> {
 605        let (tx, mut rx) = postage::barrier::channel();
 606        self.update_contacts_tx
 607            .unbounded_send(UpdateContacts::Wait(tx))
 608            .unwrap();
 609        async move {
 610            rx.next().await;
 611        }
 612    }
 613
 614    pub fn get_users(
 615        &self,
 616        user_ids: Vec<u64>,
 617        cx: &Context<Self>,
 618    ) -> Task<Result<Vec<Arc<User>>>> {
 619        let mut user_ids_to_fetch = user_ids.clone();
 620        user_ids_to_fetch.retain(|id| !self.users.contains_key(id));
 621
 622        cx.spawn(async move |this, cx| {
 623            if !user_ids_to_fetch.is_empty() {
 624                this.update(cx, |this, cx| {
 625                    this.load_users(
 626                        proto::GetUsers {
 627                            user_ids: user_ids_to_fetch,
 628                        },
 629                        cx,
 630                    )
 631                })?
 632                .await?;
 633            }
 634
 635            this.read_with(cx, |this, _| {
 636                user_ids
 637                    .iter()
 638                    .map(|user_id| {
 639                        this.users
 640                            .get(user_id)
 641                            .cloned()
 642                            .with_context(|| format!("user {user_id} not found"))
 643                    })
 644                    .collect()
 645            })?
 646        })
 647    }
 648
 649    pub fn fuzzy_search_users(
 650        &self,
 651        query: String,
 652        cx: &Context<Self>,
 653    ) -> Task<Result<Vec<Arc<User>>>> {
 654        self.load_users(proto::FuzzySearchUsers { query }, cx)
 655    }
 656
 657    pub fn get_cached_user(&self, user_id: u64) -> Option<Arc<User>> {
 658        self.users.get(&user_id).cloned()
 659    }
 660
 661    pub fn get_user_optimistic(&self, user_id: u64, cx: &Context<Self>) -> Option<Arc<User>> {
 662        if let Some(user) = self.users.get(&user_id).cloned() {
 663            return Some(user);
 664        }
 665
 666        self.get_user(user_id, cx).detach_and_log_err(cx);
 667        None
 668    }
 669
 670    pub fn get_user(&self, user_id: u64, cx: &Context<Self>) -> Task<Result<Arc<User>>> {
 671        if let Some(user) = self.users.get(&user_id).cloned() {
 672            return Task::ready(Ok(user));
 673        }
 674
 675        let load_users = self.get_users(vec![user_id], cx);
 676        cx.spawn(async move |this, cx| {
 677            load_users.await?;
 678            this.read_with(cx, |this, _| {
 679                this.users
 680                    .get(&user_id)
 681                    .cloned()
 682                    .context("server responded with no users")
 683            })?
 684        })
 685    }
 686
 687    pub fn cached_user_by_github_login(&self, github_login: &str) -> Option<Arc<User>> {
 688        self.by_github_login
 689            .get(github_login)
 690            .and_then(|id| self.users.get(id).cloned())
 691    }
 692
 693    pub fn current_user(&self) -> Option<Arc<User>> {
 694        self.current_user.borrow().clone()
 695    }
 696
 697    pub fn current_organization(&self) -> Option<Arc<Organization>> {
 698        self.current_organization.clone()
 699    }
 700
 701    pub fn set_current_organization(
 702        &mut self,
 703        organization: Arc<Organization>,
 704        cx: &mut Context<Self>,
 705    ) {
 706        let is_same_organization = self
 707            .current_organization
 708            .as_ref()
 709            .is_some_and(|current| current.id == organization.id);
 710
 711        if !is_same_organization {
 712            let organization_id = organization.id.0.to_string();
 713            self.current_organization.replace(organization);
 714            cx.emit(Event::OrganizationChanged);
 715            cx.notify();
 716
 717            let kvp = KeyValueStore::global(cx);
 718            db::write_and_log(cx, move || async move {
 719                kvp.write_kvp(CURRENT_ORGANIZATION_ID_KEY.into(), organization_id)
 720                    .await
 721            });
 722        }
 723    }
 724
 725    pub fn organizations(&self) -> &Vec<Arc<Organization>> {
 726        &self.organizations
 727    }
 728
 729    pub fn plan_for_organization(&self, organization_id: &OrganizationId) -> Option<Plan> {
 730        self.plans_by_organization.get(organization_id).copied()
 731    }
 732
 733    pub fn plan(&self) -> Option<Plan> {
 734        #[cfg(debug_assertions)]
 735        if let Ok(plan) = std::env::var("ZED_SIMULATE_PLAN").as_ref() {
 736            use cloud_api_client::Plan;
 737
 738            return match plan.as_str() {
 739                "free" => Some(Plan::ZedFree),
 740                "trial" => Some(Plan::ZedProTrial),
 741                "pro" => Some(Plan::ZedPro),
 742                _ => {
 743                    panic!("ZED_SIMULATE_PLAN must be one of 'free', 'trial', or 'pro'");
 744                }
 745            };
 746        }
 747
 748        if let Some(organization) = &self.current_organization
 749            && let Some(plan) = self.plan_for_organization(&organization.id)
 750        {
 751            return Some(plan);
 752        }
 753
 754        self.plan_info.as_ref().map(|info| info.plan())
 755    }
 756
 757    pub fn subscription_period(&self) -> Option<(DateTime<Utc>, DateTime<Utc>)> {
 758        self.plan_info
 759            .as_ref()
 760            .and_then(|plan| plan.subscription_period)
 761            .map(|subscription_period| {
 762                (
 763                    subscription_period.started_at.0,
 764                    subscription_period.ended_at.0,
 765                )
 766            })
 767    }
 768
 769    pub fn trial_started_at(&self) -> Option<DateTime<Utc>> {
 770        self.plan_info
 771            .as_ref()
 772            .and_then(|plan| plan.trial_started_at)
 773            .map(|trial_started_at| trial_started_at.0)
 774    }
 775
 776    /// Returns whether the user's account is too new to use the service.
 777    pub fn account_too_young(&self) -> bool {
 778        self.plan_info
 779            .as_ref()
 780            .map(|plan| plan.is_account_too_young)
 781            .unwrap_or_default()
 782    }
 783
 784    /// Returns whether the current user has overdue invoices and usage should be blocked.
 785    pub fn has_overdue_invoices(&self) -> bool {
 786        self.plan_info
 787            .as_ref()
 788            .map(|plan| plan.has_overdue_invoices)
 789            .unwrap_or_default()
 790    }
 791
 792    pub fn edit_prediction_usage(&self) -> Option<EditPredictionUsage> {
 793        self.edit_prediction_usage
 794    }
 795
 796    pub fn update_edit_prediction_usage(
 797        &mut self,
 798        usage: EditPredictionUsage,
 799        cx: &mut Context<Self>,
 800    ) {
 801        self.edit_prediction_usage = Some(usage);
 802        cx.notify();
 803    }
 804
 805    pub fn clear_organizations(&mut self) {
 806        self.organizations.clear();
 807        self.current_organization = None;
 808    }
 809
 810    pub fn clear_plan_and_usage(&mut self) {
 811        self.plan_info = None;
 812        self.edit_prediction_usage = None;
 813    }
 814
 815    fn update_authenticated_user(
 816        &mut self,
 817        response: GetAuthenticatedUserResponse,
 818        cx: &mut Context<Self>,
 819    ) {
 820        let staff = response.user.is_staff && !*feature_flags::ZED_DISABLE_STAFF;
 821        cx.update_flags(staff, response.feature_flags);
 822        if let Some(client) = self.client.upgrade() {
 823            client
 824                .telemetry
 825                .set_authenticated_user_info(Some(response.user.metrics_id.clone()), staff);
 826        }
 827
 828        self.organizations = response.organizations.into_iter().map(Arc::new).collect();
 829        let persisted_org_id = KeyValueStore::global(cx)
 830            .read_kvp(CURRENT_ORGANIZATION_ID_KEY)
 831            .log_err()
 832            .flatten()
 833            .map(|id| OrganizationId(Arc::from(id)));
 834
 835        self.current_organization = persisted_org_id
 836            .and_then(|persisted_id| {
 837                self.organizations
 838                    .iter()
 839                    .find(|org| org.id == persisted_id)
 840                    .cloned()
 841            })
 842            .or_else(|| {
 843                response
 844                    .default_organization_id
 845                    .and_then(|default_organization_id| {
 846                        self.organizations
 847                            .iter()
 848                            .find(|organization| organization.id == default_organization_id)
 849                            .cloned()
 850                    })
 851            })
 852            .or_else(|| self.organizations.first().cloned());
 853        self.plans_by_organization = response
 854            .plans_by_organization
 855            .into_iter()
 856            .map(|(organization_id, plan)| {
 857                let plan = match plan {
 858                    KnownOrUnknown::Known(plan) => plan,
 859                    KnownOrUnknown::Unknown(_) => {
 860                        // If we get a plan that we don't recognize, fall back to the Free plan.
 861                        Plan::ZedFree
 862                    }
 863                };
 864
 865                (organization_id, plan)
 866            })
 867            .collect();
 868
 869        self.edit_prediction_usage = Some(EditPredictionUsage(RequestUsage {
 870            limit: response.plan.usage.edit_predictions.limit,
 871            amount: response.plan.usage.edit_predictions.used as i32,
 872        }));
 873        self.plan_info = Some(response.plan);
 874        cx.emit(Event::PrivateUserInfoUpdated);
 875    }
 876
    /// Reacts to a message pushed from Cloud over the websocket connection.
    ///
    /// The work runs on a detached task. Any error it produces is logged via
    /// `detach_and_log_err` rather than returned to the caller.
    ///
    /// NOTE(review): the spawned future refers to the borrowed `message`. This
    /// presumably compiles only because matching the sole unit variant does not
    /// capture the reference. Adding a variant may require copying the message
    /// out before spawning — confirm.
    fn handle_message_to_client(this: WeakEntity<Self>, message: &MessageToClient, cx: &App) {
        cx.spawn(async move |cx| {
            match message {
                // Refetch the authenticated user and apply the response to the store.
                MessageToClient::UserUpdated => {
                    // Fails if the store cannot be read or its `Client` has been dropped.
                    let cloud_client = cx
                        .update(|cx| {
                            this.read_with(cx, |this, _cx| {
                                this.client.upgrade().map(|client| client.cloud_client())
                            })
                        })?
                        .ok_or(anyhow::anyhow!("Failed to get Cloud client"))?;

                    let response = cloud_client.get_authenticated_user().await?;
                    cx.update(|cx| {
                        this.update(cx, |this, cx| {
                            this.update_authenticated_user(response, cx);
                        })
                    })?;
                }
            }

            anyhow::Ok(())
        })
        .detach_and_log_err(cx);
    }
 902
 903    pub fn watch_current_user(&self) -> watch::Receiver<Option<Arc<User>>> {
 904        self.current_user.clone()
 905    }
 906
 907    fn load_users(
 908        &self,
 909        request: impl RequestMessage<Response = UsersResponse>,
 910        cx: &Context<Self>,
 911    ) -> Task<Result<Vec<Arc<User>>>> {
 912        let client = self.client.clone();
 913        cx.spawn(async move |this, cx| {
 914            if let Some(rpc) = client.upgrade() {
 915                let response = rpc.request(request).await.context("error loading users")?;
 916                let users = response.users;
 917
 918                this.update(cx, |this, _| this.insert(users))
 919            } else {
 920                Ok(Vec::new())
 921            }
 922        })
 923    }
 924
 925    pub fn insert(&mut self, users: Vec<proto::User>) -> Vec<Arc<User>> {
 926        let mut ret = Vec::with_capacity(users.len());
 927        for user in users {
 928            let user = User::new(user);
 929            if let Some(old) = self.users.insert(user.id, user.clone())
 930                && old.github_login != user.github_login
 931            {
 932                self.by_github_login.remove(&old.github_login);
 933            }
 934            self.by_github_login
 935                .insert(user.github_login.clone(), user.id);
 936            ret.push(user)
 937        }
 938        ret
 939    }
 940
 941    pub fn set_participant_indices(
 942        &mut self,
 943        participant_indices: HashMap<u64, ParticipantIndex>,
 944        cx: &mut Context<Self>,
 945    ) {
 946        if participant_indices != self.participant_indices {
 947            self.participant_indices = participant_indices;
 948            cx.emit(Event::ParticipantIndicesChanged);
 949        }
 950    }
 951
 952    pub fn participant_indices(&self) -> &HashMap<u64, ParticipantIndex> {
 953        &self.participant_indices
 954    }
 955
 956    pub fn participant_names(
 957        &self,
 958        user_ids: impl Iterator<Item = u64>,
 959        cx: &App,
 960    ) -> HashMap<u64, SharedString> {
 961        let mut ret = HashMap::default();
 962        let mut missing_user_ids = Vec::new();
 963        for id in user_ids {
 964            if let Some(github_login) = self.get_cached_user(id).map(|u| u.github_login.clone()) {
 965                ret.insert(id, github_login);
 966            } else {
 967                missing_user_ids.push(id)
 968            }
 969        }
 970        if !missing_user_ids.is_empty() {
 971            let this = self.weak_self.clone();
 972            cx.spawn(async move |cx| {
 973                this.update(cx, |this, cx| this.get_users(missing_user_ids, cx))?
 974                    .await
 975            })
 976            .detach_and_log_err(cx);
 977        }
 978        ret
 979    }
 980}
 981
 982impl User {
 983    fn new(message: proto::User) -> Arc<Self> {
 984        Arc::new(User {
 985            id: message.id,
 986            github_login: message.github_login.into(),
 987            avatar_uri: message.avatar_url.into(),
 988            name: message.name,
 989        })
 990    }
 991}
 992
 993impl Contact {
 994    async fn from_proto(
 995        contact: proto::Contact,
 996        user_store: &Entity<UserStore>,
 997        cx: &mut AsyncApp,
 998    ) -> Result<Self> {
 999        let user = user_store
1000            .update(cx, |user_store, cx| {
1001                user_store.get_user(contact.user_id, cx)
1002            })
1003            .await?;
1004        Ok(Self {
1005            user,
1006            online: contact.online,
1007            busy: contact.busy,
1008        })
1009    }
1010}
1011
1012impl Collaborator {
1013    pub fn from_proto(message: proto::Collaborator) -> Result<Self> {
1014        Ok(Self {
1015            peer_id: message.peer_id.context("invalid peer id")?,
1016            replica_id: ReplicaId::new(message.replica_id as u16),
1017            user_id: message.user_id as UserId,
1018            is_host: message.is_host,
1019            committer_name: message.committer_name,
1020            committer_email: message.committer_email,
1021        })
1022    }
1023}
1024
1025impl RequestUsage {
1026    pub fn over_limit(&self) -> bool {
1027        match self.limit {
1028            UsageLimit::Limited(limit) => self.amount >= limit,
1029            UsageLimit::Unlimited => false,
1030        }
1031    }
1032
1033    fn from_headers(
1034        limit_name: &str,
1035        amount_name: &str,
1036        headers: &HeaderMap<HeaderValue>,
1037    ) -> Result<Self> {
1038        let limit = headers
1039            .get(limit_name)
1040            .with_context(|| format!("missing {limit_name:?} header"))?;
1041        let limit = UsageLimit::from_str(limit.to_str()?)?;
1042
1043        let amount = headers
1044            .get(amount_name)
1045            .with_context(|| format!("missing {amount_name:?} header"))?;
1046        let amount = amount.to_str()?.parse::<i32>()?;
1047
1048        Ok(Self { limit, amount })
1049    }
1050}
1051
1052impl EditPredictionUsage {
1053    pub fn from_headers(headers: &HeaderMap<HeaderValue>) -> Result<Self> {
1054        Ok(Self(RequestUsage::from_headers(
1055            EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME,
1056            EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME,
1057            headers,
1058        )?))
1059    }
1060}