user.rs

  1use super::{http::HttpClient, proto, Client, Status, TypedEnvelope};
  2use anyhow::{anyhow, Context, Result};
  3use collections::{hash_map::Entry, HashMap, HashSet};
  4use futures::{channel::mpsc, future, AsyncReadExt, Future, StreamExt};
  5use gpui::{AsyncAppContext, Entity, ImageData, ModelContext, ModelHandle, Task};
  6use postage::{sink::Sink, watch};
  7use rpc::proto::{RequestMessage, UsersResponse};
  8use std::sync::{Arc, Weak};
  9use util::TryFutureExt as _;
 10
 11#[derive(Default, Debug)]
 12pub struct User {
 13    pub id: u64,
 14    pub github_login: String,
 15    pub avatar: Option<Arc<ImageData>>,
 16}
 17
 18impl PartialOrd for User {
 19    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
 20        Some(self.cmp(other))
 21    }
 22}
 23
 24impl Ord for User {
 25    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
 26        self.github_login.cmp(&other.github_login)
 27    }
 28}
 29
 30impl PartialEq for User {
 31    fn eq(&self, other: &Self) -> bool {
 32        self.id == other.id && self.github_login == other.github_login
 33    }
 34}
 35
 36impl Eq for User {}
 37
 38#[derive(Debug, PartialEq)]
 39pub struct Contact {
 40    pub user: Arc<User>,
 41    pub online: bool,
 42}
 43
 44#[derive(Debug, Clone, Copy, PartialEq, Eq)]
 45pub enum ContactRequestStatus {
 46    None,
 47    RequestSent,
 48    RequestReceived,
 49    RequestAccepted,
 50}
 51
/// Client-side cache of users, plus the current user's contacts and contact
/// requests. Constructed with `UserStore::new`.
pub struct UserStore {
    /// Users fetched so far, keyed by user id.
    users: HashMap<u64, Arc<User>>,
    /// Queue drained, in order, by the `_maintain_contacts` task.
    update_contacts_tx: mpsc::UnboundedSender<UpdateContacts>,
    /// The current user as published by `_maintain_current_user`; `None` when
    /// signed out or not yet known.
    current_user: watch::Receiver<Option<Arc<User>>>,
    /// Accepted contacts, kept sorted by `user.github_login`.
    contacts: Vec<Arc<Contact>>,
    /// Users who sent us a contact request, kept sorted by `github_login`.
    incoming_contact_requests: Vec<Arc<User>>,
    /// Users we sent a contact request to, kept sorted by `github_login`.
    outgoing_contact_requests: Vec<Arc<User>>,
    /// Number of in-flight `perform_contact_request` calls per target user id.
    pending_contact_requests: HashMap<u64, usize>,
    /// Latest invite info pushed by the server, if any.
    invite_info: Option<InviteInfo>,
    /// Weak handle, so the store does not keep the client alive.
    client: Weak<Client>,
    /// HTTP client used to download user avatars.
    http: Arc<dyn HttpClient>,
    /// Background task applying queued `UpdateContacts` messages in order.
    _maintain_contacts: Task<()>,
    /// Background task following connection status and the current user.
    _maintain_current_user: Task<()>,
}
 66
 67#[derive(Clone)]
 68pub struct InviteInfo {
 69    pub count: u32,
 70    pub url: Arc<str>,
 71}
 72
 73pub enum Event {
 74    Contact {
 75        user: Arc<User>,
 76        kind: ContactEventKind,
 77    },
 78    ShowContacts,
 79}
 80
 81#[derive(Clone, Copy)]
 82pub enum ContactEventKind {
 83    Requested,
 84    Accepted,
 85    Cancelled,
 86}
 87
 88impl Entity for UserStore {
 89    type Event = Event;
 90}
 91
 92enum UpdateContacts {
 93    Update(proto::UpdateContacts),
 94    Wait(postage::barrier::Sender),
 95    Clear(postage::barrier::Sender),
 96}
 97
 98impl UserStore {
    /// Creates the store and starts its two background tasks.
    ///
    /// * Registers RPC message handlers for `UpdateContacts`, `UpdateInviteInfo`
    ///   and `ShowContacts`. The returned subscriptions are moved into the
    ///   `_maintain_contacts` task and held for as long as it runs (presumably
    ///   dropping them would unregister the handlers).
    /// * `_maintain_contacts` drains `update_contacts_rx` one message at a time
    ///   and awaits each `update_contacts` task before taking the next, so
    ///   queued updates are applied strictly in order.
    /// * `_maintain_current_user` follows `client.status()` and publishes the
    ///   current user on the `current_user` watch channel.
    ///
    /// Only a weak reference to `client` is stored on the struct.
    pub fn new(
        client: Arc<Client>,
        http: Arc<dyn HttpClient>,
        cx: &mut ModelContext<Self>,
    ) -> Self {
        let (mut current_user_tx, current_user_rx) = watch::channel();
        let (update_contacts_tx, mut update_contacts_rx) = mpsc::unbounded();
        let rpc_subscriptions = vec![
            client.add_message_handler(cx.handle(), Self::handle_update_contacts),
            client.add_message_handler(cx.handle(), Self::handle_update_invite_info),
            client.add_message_handler(cx.handle(), Self::handle_show_contacts),
        ];
        Self {
            users: Default::default(),
            current_user: current_user_rx,
            contacts: Default::default(),
            incoming_contact_requests: Default::default(),
            outgoing_contact_requests: Default::default(),
            invite_info: None,
            // Struct fields are evaluated in the order written: this downgrade
            // must happen before `client` is moved into the
            // `_maintain_current_user` future below.
            client: Arc::downgrade(&client),
            update_contacts_tx,
            http,
            _maintain_contacts: cx.spawn_weak(|this, mut cx| async move {
                // Hold the RPC subscriptions for the lifetime of this task.
                let _subscriptions = rpc_subscriptions;
                while let Some(message) = update_contacts_rx.next().await {
                    // If the store is gone, the message is dropped unprocessed
                    // (which also drops any barrier sender it carries).
                    if let Some(this) = this.upgrade(&cx) {
                        this.update(&mut cx, |this, cx| this.update_contacts(message, cx))
                            .log_err()
                            .await;
                    }
                }
            }),
            _maintain_current_user: cx.spawn_weak(|this, mut cx| async move {
                let mut status = client.status();
                while let Some(status) = status.next().await {
                    match status {
                        Status::Connected { .. } => {
                            // Needs both a live store and a known user id.
                            if let Some((this, user_id)) = this.upgrade(&cx).zip(client.user_id()) {
                                // `log_err` turns a failed lookup into `None`,
                                // which is then published as the current user.
                                let user = this
                                    .update(&mut cx, |this, cx| this.get_user(user_id, cx))
                                    .log_err()
                                    .await;
                                current_user_tx.send(user).await.ok();
                            }
                        }
                        Status::SignedOut => {
                            current_user_tx.send(None).await.ok();
                            if let Some(this) = this.upgrade(&cx) {
                                this.update(&mut cx, |this, _| this.clear_contacts()).await;
                            }
                        }
                        // Contacts are cleared, but the current user is left as-is.
                        Status::ConnectionLost => {
                            if let Some(this) = this.upgrade(&cx) {
                                this.update(&mut cx, |this, _| this.clear_contacts()).await;
                            }
                        }
                        _ => {}
                    }
                }
            }),
            pending_contact_requests: Default::default(),
        }
    }
162
163    async fn handle_update_invite_info(
164        this: ModelHandle<Self>,
165        message: TypedEnvelope<proto::UpdateInviteInfo>,
166        _: Arc<Client>,
167        mut cx: AsyncAppContext,
168    ) -> Result<()> {
169        this.update(&mut cx, |this, cx| {
170            this.invite_info = Some(InviteInfo {
171                url: Arc::from(message.payload.url),
172                count: message.payload.count,
173            });
174            cx.notify();
175        });
176        Ok(())
177    }
178
179    async fn handle_show_contacts(
180        this: ModelHandle<Self>,
181        _: TypedEnvelope<proto::ShowContacts>,
182        _: Arc<Client>,
183        mut cx: AsyncAppContext,
184    ) -> Result<()> {
185        this.update(&mut cx, |_, cx| cx.emit(Event::ShowContacts));
186        Ok(())
187    }
188
189    pub fn invite_info(&self) -> Option<&InviteInfo> {
190        self.invite_info.as_ref()
191    }
192
193    async fn handle_update_contacts(
194        this: ModelHandle<Self>,
195        message: TypedEnvelope<proto::UpdateContacts>,
196        _: Arc<Client>,
197        mut cx: AsyncAppContext,
198    ) -> Result<()> {
199        this.update(&mut cx, |this, _| {
200            this.update_contacts_tx
201                .unbounded_send(UpdateContacts::Update(message.payload))
202                .unwrap();
203        });
204        Ok(())
205    }
206
    /// Processes one message from the contact-update queue.
    ///
    /// * `Wait` drops its barrier immediately. Because the queue is drained in
    ///   order (see `new`), this tells the waiter that every earlier message
    ///   has been processed.
    /// * `Clear` empties the contact list and both request lists, then drops
    ///   its barrier.
    /// * `Update` loads every user referenced by the message, then applies
    ///   removals and insertions to `contacts`, `incoming_contact_requests` and
    ///   `outgoing_contact_requests`, emitting `Event::Contact` where
    ///   appropriate.
    ///
    /// Insertions use the index returned by `binary_search_by_key` on
    /// `github_login`, which keeps all three lists sorted by login.
    fn update_contacts(
        &mut self,
        message: UpdateContacts,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<()>> {
        match message {
            UpdateContacts::Wait(barrier) => {
                drop(barrier);
                Task::ready(Ok(()))
            }
            UpdateContacts::Clear(barrier) => {
                self.contacts.clear();
                self.incoming_contact_requests.clear();
                self.outgoing_contact_requests.clear();
                drop(barrier);
                Task::ready(Ok(()))
            }
            UpdateContacts::Update(message) => {
                // Gather every user id this update refers to, so all of them
                // are fetched by a single `get_users` call.
                let mut user_ids = HashSet::default();
                for contact in &message.contacts {
                    user_ids.insert(contact.user_id);
                }
                user_ids.extend(message.incoming_requests.iter().map(|req| req.requester_id));
                user_ids.extend(message.outgoing_requests.iter());

                let load_users = self.get_users(user_ids.into_iter().collect(), cx);
                cx.spawn(|this, mut cx| async move {
                    load_users.await?;

                    // All referenced users were fetched and cached by the
                    // `get_users` call above, so the sequential lookups below
                    // are cache hits; there is no need to parallelize them.
                    let mut updated_contacts = Vec::new();
                    for contact in message.contacts {
                        let should_notify = contact.should_notify;
                        updated_contacts.push((
                            Arc::new(Contact::from_proto(contact, &this, &mut cx).await?),
                            should_notify,
                        ));
                    }

                    let mut incoming_requests = Vec::new();
                    for request in message.incoming_requests {
                        incoming_requests.push({
                            let user = this
                                .update(&mut cx, |this, cx| this.get_user(request.requester_id, cx))
                                .await?;
                            (user, request.should_notify)
                        });
                    }

                    let mut outgoing_requests = Vec::new();
                    for requested_user_id in message.outgoing_requests {
                        outgoing_requests.push(
                            this.update(&mut cx, |this, cx| this.get_user(requested_user_id, cx))
                                .await?,
                        );
                    }

                    let removed_contacts =
                        HashSet::<u64>::from_iter(message.remove_contacts.iter().copied());
                    let removed_incoming_requests =
                        HashSet::<u64>::from_iter(message.remove_incoming_requests.iter().copied());
                    let removed_outgoing_requests =
                        HashSet::<u64>::from_iter(message.remove_outgoing_requests.iter().copied());

                    this.update(&mut cx, |this, cx| {
                        // Remove contacts
                        this.contacts
                            .retain(|contact| !removed_contacts.contains(&contact.user.id));
                        // Update existing contacts and insert new ones, keeping
                        // the list sorted by login.
                        for (updated_contact, should_notify) in updated_contacts {
                            if should_notify {
                                cx.emit(Event::Contact {
                                    user: updated_contact.user.clone(),
                                    kind: ContactEventKind::Accepted,
                                });
                            }
                            match this.contacts.binary_search_by_key(
                                &&updated_contact.user.github_login,
                                |contact| &contact.user.github_login,
                            ) {
                                Ok(ix) => this.contacts[ix] = updated_contact,
                                Err(ix) => this.contacts.insert(ix, updated_contact),
                            }
                        }

                        // Remove incoming contact requests, emitting a
                        // `Cancelled` event for each one removed.
                        this.incoming_contact_requests.retain(|user| {
                            if removed_incoming_requests.contains(&user.id) {
                                cx.emit(Event::Contact {
                                    user: user.clone(),
                                    kind: ContactEventKind::Cancelled,
                                });
                                false
                            } else {
                                true
                            }
                        });
                        // Update existing incoming requests and insert new ones
                        for (user, should_notify) in incoming_requests {
                            if should_notify {
                                cx.emit(Event::Contact {
                                    user: user.clone(),
                                    kind: ContactEventKind::Requested,
                                });
                            }

                            match this
                                .incoming_contact_requests
                                .binary_search_by_key(&&user.github_login, |contact| {
                                    &contact.github_login
                                }) {
                                Ok(ix) => this.incoming_contact_requests[ix] = user,
                                Err(ix) => this.incoming_contact_requests.insert(ix, user),
                            }
                        }

                        // Remove outgoing contact requests (no event is emitted)
                        this.outgoing_contact_requests
                            .retain(|user| !removed_outgoing_requests.contains(&user.id));
                        // Update existing outgoing requests and insert new ones
                        for request in outgoing_requests {
                            match this
                                .outgoing_contact_requests
                                .binary_search_by_key(&&request.github_login, |contact| {
                                    &contact.github_login
                                }) {
                                Ok(ix) => this.outgoing_contact_requests[ix] = request,
                                Err(ix) => this.outgoing_contact_requests.insert(ix, request),
                            }
                        }

                        cx.notify();
                    });

                    Ok(())
                })
            }
        }
    }
347
348    pub fn contacts(&self) -> &[Arc<Contact>] {
349        &self.contacts
350    }
351
352    pub fn has_contact(&self, user: &Arc<User>) -> bool {
353        self.contacts
354            .binary_search_by_key(&&user.github_login, |contact| &contact.user.github_login)
355            .is_ok()
356    }
357
358    pub fn incoming_contact_requests(&self) -> &[Arc<User>] {
359        &self.incoming_contact_requests
360    }
361
362    pub fn outgoing_contact_requests(&self) -> &[Arc<User>] {
363        &self.outgoing_contact_requests
364    }
365
366    pub fn is_contact_request_pending(&self, user: &User) -> bool {
367        self.pending_contact_requests.contains_key(&user.id)
368    }
369
370    pub fn contact_request_status(&self, user: &User) -> ContactRequestStatus {
371        if self
372            .contacts
373            .binary_search_by_key(&&user.github_login, |contact| &contact.user.github_login)
374            .is_ok()
375        {
376            ContactRequestStatus::RequestAccepted
377        } else if self
378            .outgoing_contact_requests
379            .binary_search_by_key(&&user.github_login, |user| &user.github_login)
380            .is_ok()
381        {
382            ContactRequestStatus::RequestSent
383        } else if self
384            .incoming_contact_requests
385            .binary_search_by_key(&&user.github_login, |user| &user.github_login)
386            .is_ok()
387        {
388            ContactRequestStatus::RequestReceived
389        } else {
390            ContactRequestStatus::None
391        }
392    }
393
394    pub fn request_contact(
395        &mut self,
396        responder_id: u64,
397        cx: &mut ModelContext<Self>,
398    ) -> Task<Result<()>> {
399        self.perform_contact_request(responder_id, proto::RequestContact { responder_id }, cx)
400    }
401
402    pub fn remove_contact(
403        &mut self,
404        user_id: u64,
405        cx: &mut ModelContext<Self>,
406    ) -> Task<Result<()>> {
407        self.perform_contact_request(user_id, proto::RemoveContact { user_id }, cx)
408    }
409
410    pub fn respond_to_contact_request(
411        &mut self,
412        requester_id: u64,
413        accept: bool,
414        cx: &mut ModelContext<Self>,
415    ) -> Task<Result<()>> {
416        self.perform_contact_request(
417            requester_id,
418            proto::RespondToContactRequest {
419                requester_id,
420                response: if accept {
421                    proto::ContactRequestResponse::Accept
422                } else {
423                    proto::ContactRequestResponse::Decline
424                } as i32,
425            },
426            cx,
427        )
428    }
429
430    pub fn dismiss_contact_request(
431        &mut self,
432        requester_id: u64,
433        cx: &mut ModelContext<Self>,
434    ) -> Task<Result<()>> {
435        let client = self.client.upgrade();
436        cx.spawn_weak(|_, _| async move {
437            client
438                .ok_or_else(|| anyhow!("can't upgrade client reference"))?
439                .request(proto::RespondToContactRequest {
440                    requester_id,
441                    response: proto::ContactRequestResponse::Dismiss as i32,
442                })
443                .await?;
444            Ok(())
445        })
446    }
447
448    fn perform_contact_request<T: RequestMessage>(
449        &mut self,
450        user_id: u64,
451        request: T,
452        cx: &mut ModelContext<Self>,
453    ) -> Task<Result<()>> {
454        let client = self.client.upgrade();
455        *self.pending_contact_requests.entry(user_id).or_insert(0) += 1;
456        cx.notify();
457
458        cx.spawn(|this, mut cx| async move {
459            let response = client
460                .ok_or_else(|| anyhow!("can't upgrade client reference"))?
461                .request(request)
462                .await;
463            this.update(&mut cx, |this, cx| {
464                if let Entry::Occupied(mut request_count) =
465                    this.pending_contact_requests.entry(user_id)
466                {
467                    *request_count.get_mut() -= 1;
468                    if *request_count.get() == 0 {
469                        request_count.remove();
470                    }
471                }
472                cx.notify();
473            });
474            response?;
475            Ok(())
476        })
477    }
478
479    pub fn clear_contacts(&mut self) -> impl Future<Output = ()> {
480        let (tx, mut rx) = postage::barrier::channel();
481        self.update_contacts_tx
482            .unbounded_send(UpdateContacts::Clear(tx))
483            .unwrap();
484        async move {
485            rx.next().await;
486        }
487    }
488
489    pub fn contact_updates_done(&mut self) -> impl Future<Output = ()> {
490        let (tx, mut rx) = postage::barrier::channel();
491        self.update_contacts_tx
492            .unbounded_send(UpdateContacts::Wait(tx))
493            .unwrap();
494        async move {
495            rx.next().await;
496        }
497    }
498
499    pub fn get_users(
500        &mut self,
501        user_ids: Vec<u64>,
502        cx: &mut ModelContext<Self>,
503    ) -> Task<Result<Vec<Arc<User>>>> {
504        let mut user_ids_to_fetch = user_ids.clone();
505        user_ids_to_fetch.retain(|id| !self.users.contains_key(id));
506
507        cx.spawn(|this, mut cx| async move {
508            if !user_ids_to_fetch.is_empty() {
509                this.update(&mut cx, |this, cx| {
510                    this.load_users(
511                        proto::GetUsers {
512                            user_ids: user_ids_to_fetch,
513                        },
514                        cx,
515                    )
516                })
517                .await?;
518            }
519
520            this.read_with(&cx, |this, _| {
521                user_ids
522                    .iter()
523                    .map(|user_id| {
524                        this.users
525                            .get(user_id)
526                            .cloned()
527                            .ok_or_else(|| anyhow!("user {} not found", user_id))
528                    })
529                    .collect()
530            })
531        })
532    }
533
534    pub fn fuzzy_search_users(
535        &mut self,
536        query: String,
537        cx: &mut ModelContext<Self>,
538    ) -> Task<Result<Vec<Arc<User>>>> {
539        self.load_users(proto::FuzzySearchUsers { query }, cx)
540    }
541
542    pub fn get_user(
543        &mut self,
544        user_id: u64,
545        cx: &mut ModelContext<Self>,
546    ) -> Task<Result<Arc<User>>> {
547        if let Some(user) = self.users.get(&user_id).cloned() {
548            return cx.foreground().spawn(async move { Ok(user) });
549        }
550
551        let load_users = self.get_users(vec![user_id], cx);
552        cx.spawn(|this, mut cx| async move {
553            load_users.await?;
554            this.update(&mut cx, |this, _| {
555                this.users
556                    .get(&user_id)
557                    .cloned()
558                    .ok_or_else(|| anyhow!("server responded with no users"))
559            })
560        })
561    }
562
563    pub fn current_user(&self) -> Option<Arc<User>> {
564        self.current_user.borrow().clone()
565    }
566
567    pub fn watch_current_user(&self) -> watch::Receiver<Option<Arc<User>>> {
568        self.current_user.clone()
569    }
570
571    fn load_users(
572        &mut self,
573        request: impl RequestMessage<Response = UsersResponse>,
574        cx: &mut ModelContext<Self>,
575    ) -> Task<Result<Vec<Arc<User>>>> {
576        let client = self.client.clone();
577        let http = self.http.clone();
578        cx.spawn_weak(|this, mut cx| async move {
579            if let Some(rpc) = client.upgrade() {
580                let response = rpc.request(request).await.context("error loading users")?;
581                let users = future::join_all(
582                    response
583                        .users
584                        .into_iter()
585                        .map(|user| User::new(user, http.as_ref())),
586                )
587                .await;
588
589                if let Some(this) = this.upgrade(&cx) {
590                    this.update(&mut cx, |this, _| {
591                        for user in &users {
592                            this.users.insert(user.id, user.clone());
593                        }
594                    });
595                }
596                Ok(users)
597            } else {
598                Ok(Vec::new())
599            }
600        })
601    }
602}
603
604impl User {
605    async fn new(message: proto::User, http: &dyn HttpClient) -> Arc<Self> {
606        Arc::new(User {
607            id: message.id,
608            github_login: message.github_login,
609            avatar: fetch_avatar(http, &message.avatar_url).warn_on_err().await,
610        })
611    }
612}
613
614impl Contact {
615    async fn from_proto(
616        contact: proto::Contact,
617        user_store: &ModelHandle<UserStore>,
618        cx: &mut AsyncAppContext,
619    ) -> Result<Self> {
620        let user = user_store
621            .update(cx, |user_store, cx| {
622                user_store.get_user(contact.user_id, cx)
623            })
624            .await?;
625        Ok(Self {
626            user,
627            online: contact.online,
628        })
629    }
630}
631
632async fn fetch_avatar(http: &dyn HttpClient, url: &str) -> Result<Arc<ImageData>> {
633    let mut response = http
634        .get(url, Default::default(), true)
635        .await
636        .map_err(|e| anyhow!("failed to send user avatar request: {}", e))?;
637
638    if !response.status().is_success() {
639        return Err(anyhow!("avatar request failed {:?}", response.status()));
640    }
641
642    let mut body = Vec::new();
643    response
644        .body_mut()
645        .read_to_end(&mut body)
646        .await
647        .map_err(|e| anyhow!("failed to read user avatar response body: {}", e))?;
648    let format = image::guess_format(&body)?;
649    let image = image::load_from_memory_with_format(&body, format)?.into_bgra8();
650    Ok(ImageData::new(image))
651}