rpc.rs

  1use super::{
  2    auth,
  3    db::{ChannelId, UserId},
  4    AppState,
  5};
  6use anyhow::anyhow;
  7use async_std::{sync::RwLock, task};
  8use async_tungstenite::{
  9    tungstenite::{protocol::Role, Error as WebSocketError, Message as WebSocketMessage},
 10    WebSocketStream,
 11};
 12use futures::{future::BoxFuture, FutureExt};
 13use postage::prelude::Stream as _;
 14use sha1::{Digest as _, Sha1};
 15use std::{
 16    any::TypeId,
 17    collections::{HashMap, HashSet},
 18    future::Future,
 19    mem,
 20    sync::Arc,
 21    time::Instant,
 22};
 23use surf::StatusCode;
 24use tide::log;
 25use tide::{
 26    http::headers::{HeaderName, CONNECTION, UPGRADE},
 27    Request, Response,
 28};
 29use time::OffsetDateTime;
 30use zrpc::{
 31    auth::random_token,
 32    proto::{self, AnyTypedEnvelope, EnvelopedMessage},
 33    ConnectionId, Peer, TypedEnvelope,
 34};
 35
/// Identifier of a participant's replica within a shared worktree. The server
/// reports the host as replica 0 and hands guests ids starting at 1.
type ReplicaId = u16;

/// Type-erased message handler stored in `Server::handlers`.
///
/// It takes the server and a boxed envelope and returns a boxed future that
/// resolves to the handler's result. `Server::add_handler` builds these closures.
type MessageHandler = Box<
    dyn Send
        + Sync
        + Fn(Arc<Server>, Box<dyn AnyTypedEnvelope>) -> BoxFuture<'static, tide::Result<()>>,
>;
 43
/// RPC server: dispatches incoming messages to per-payload-type handlers and
/// tracks in-memory collaboration state (connections, shared worktrees, channels).
pub struct Server {
    /// Transport used to send, respond to, and forward messages.
    peer: Arc<Peer>,
    /// Mutable in-memory state shared by every connection's handlers.
    state: RwLock<ServerState>,
    /// Application-wide state; handlers use its `db` for persistent data.
    app_state: Arc<AppState>,
    /// Handlers keyed by the `TypeId` of the message payload they accept.
    handlers: HashMap<TypeId, MessageHandler>,
}
 50
 51#[derive(Default)]
 52struct ServerState {
 53    connections: HashMap<ConnectionId, Connection>,
 54    pub worktrees: HashMap<u64, Worktree>,
 55    channels: HashMap<ChannelId, Channel>,
 56    next_worktree_id: u64,
 57}
 58
/// Bookkeeping for one live connection.
struct Connection {
    /// Authenticated user that owns this connection.
    user_id: UserId,
    /// Ids of worktrees this connection was added to; walked on disconnect to
    /// detach the connection from them.
    worktrees: HashSet<u64>,
    /// Channels this connection has joined; walked on disconnect to remove it.
    channels: HashSet<ChannelId>,
}
 64
/// A worktree shared by a host with zero or more guests.
struct Worktree {
    /// Connection of the sharing host; set to `None` when the host closes the worktree.
    host_connection_id: Option<ConnectionId>,
    /// Guest connections and the replica id assigned to each.
    guest_connection_ids: HashMap<ConnectionId, ReplicaId>,
    /// Replica ids currently held by guests, consulted when allocating a new one.
    active_replica_ids: HashSet<ReplicaId>,
    /// Token a guest must present to join.
    access_token: String,
    /// Root name sent by the host when sharing.
    root_name: String,
    /// Latest known entries keyed by entry id; kept current by `UpdateWorktree`.
    entries: HashMap<u64, proto::Entry>,
}
 73
/// In-memory membership of a chat channel.
#[derive(Default)]
struct Channel {
    /// Connections that joined the channel and receive its message broadcasts.
    connection_ids: HashSet<ConnectionId>,
}
 78
 79impl Server {
 80    pub fn new(app_state: Arc<AppState>, peer: Arc<Peer>) -> Arc<Self> {
 81        let mut server = Server {
 82            peer,
 83            app_state,
 84            state: Default::default(),
 85            handlers: Default::default(),
 86        };
 87
 88        server
 89            .add_handler(Server::share_worktree)
 90            .add_handler(Server::join_worktree)
 91            .add_handler(Server::update_worktree)
 92            .add_handler(Server::close_worktree)
 93            .add_handler(Server::open_buffer)
 94            .add_handler(Server::close_buffer)
 95            .add_handler(Server::update_buffer)
 96            .add_handler(Server::buffer_saved)
 97            .add_handler(Server::save_buffer)
 98            .add_handler(Server::get_channels)
 99            .add_handler(Server::get_users)
100            .add_handler(Server::join_channel)
101            .add_handler(Server::send_channel_message);
102
103        Arc::new(server)
104    }
105
106    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
107    where
108        F: 'static + Send + Sync + Fn(Arc<Self>, TypedEnvelope<M>) -> Fut,
109        Fut: 'static + Send + Future<Output = tide::Result<()>>,
110        M: EnvelopedMessage,
111    {
112        let prev_handler = self.handlers.insert(
113            TypeId::of::<M>(),
114            Box::new(move |server, envelope| {
115                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
116                (handler)(server, *envelope).boxed()
117            }),
118        );
119        if prev_handler.is_some() {
120            panic!("registered a handler for the same message twice");
121        }
122        self
123    }
124
125    pub fn handle_connection<Conn>(
126        self: &Arc<Self>,
127        connection: Conn,
128        addr: String,
129        user_id: UserId,
130    ) -> impl Future<Output = ()>
131    where
132        Conn: 'static
133            + futures::Sink<WebSocketMessage, Error = WebSocketError>
134            + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
135            + Send
136            + Unpin,
137    {
138        let this = self.clone();
139        async move {
140            let (connection_id, handle_io, mut incoming_rx) =
141                this.peer.add_connection(connection).await;
142            this.add_connection(connection_id, user_id).await;
143
144            let handle_io = handle_io.fuse();
145            futures::pin_mut!(handle_io);
146            loop {
147                let next_message = incoming_rx.recv().fuse();
148                futures::pin_mut!(next_message);
149                futures::select_biased! {
150                    message = next_message => {
151                        if let Some(message) = message {
152                            let start_time = Instant::now();
153                            log::info!("RPC message received: {}", message.payload_type_name());
154                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
155                                if let Err(err) = (handler)(this.clone(), message).await {
156                                    log::error!("error handling message: {:?}", err);
157                                } else {
158                                    log::info!("RPC message handled. duration:{:?}", start_time.elapsed());
159                                }
160                            } else {
161                                log::warn!("unhandled message: {}", message.payload_type_name());
162                            }
163                        } else {
164                            log::info!("rpc connection closed {:?}", addr);
165                            break;
166                        }
167                    }
168                    handle_io = handle_io => {
169                        if let Err(err) = handle_io {
170                            log::error!("error handling rpc connection {:?} - {:?}", addr, err);
171                        }
172                        break;
173                    }
174                }
175            }
176
177            if let Err(err) = this.sign_out(connection_id).await {
178                log::error!("error signing out connection {:?} - {:?}", addr, err);
179            }
180        }
181    }
182
183    async fn sign_out(self: &Arc<Self>, connection_id: zrpc::ConnectionId) -> tide::Result<()> {
184        self.peer.disconnect(connection_id).await;
185        let worktree_ids = self.remove_connection(connection_id).await;
186        for worktree_id in worktree_ids {
187            let state = self.state.read().await;
188            if let Some(worktree) = state.worktrees.get(&worktree_id) {
189                broadcast(connection_id, worktree.connection_ids(), |conn_id| {
190                    self.peer.send(
191                        conn_id,
192                        proto::RemovePeer {
193                            worktree_id,
194                            peer_id: connection_id.0,
195                        },
196                    )
197                })
198                .await?;
199            }
200        }
201        Ok(())
202    }
203
204    // Add a new connection associated with a given user.
205    async fn add_connection(&self, connection_id: ConnectionId, user_id: UserId) {
206        self.state.write().await.connections.insert(
207            connection_id,
208            Connection {
209                user_id,
210                worktrees: Default::default(),
211                channels: Default::default(),
212            },
213        );
214    }
215
216    // Remove the given connection and its association with any worktrees.
217    async fn remove_connection(&self, connection_id: ConnectionId) -> Vec<u64> {
218        let mut worktree_ids = Vec::new();
219        let mut state = self.state.write().await;
220        if let Some(connection) = state.connections.remove(&connection_id) {
221            for channel_id in connection.channels {
222                if let Some(channel) = state.channels.get_mut(&channel_id) {
223                    channel.connection_ids.remove(&connection_id);
224                }
225            }
226            for worktree_id in connection.worktrees {
227                if let Some(worktree) = state.worktrees.get_mut(&worktree_id) {
228                    if worktree.host_connection_id == Some(connection_id) {
229                        worktree_ids.push(worktree_id);
230                    } else if let Some(replica_id) =
231                        worktree.guest_connection_ids.remove(&connection_id)
232                    {
233                        worktree.active_replica_ids.remove(&replica_id);
234                        worktree_ids.push(worktree_id);
235                    }
236                }
237            }
238        }
239        worktree_ids
240    }
241
242    async fn share_worktree(
243        self: Arc<Server>,
244        mut request: TypedEnvelope<proto::ShareWorktree>,
245    ) -> tide::Result<()> {
246        let mut state = self.state.write().await;
247        let worktree_id = state.next_worktree_id;
248        state.next_worktree_id += 1;
249        let access_token = random_token();
250        let worktree = request
251            .payload
252            .worktree
253            .as_mut()
254            .ok_or_else(|| anyhow!("missing worktree"))?;
255        let entries = mem::take(&mut worktree.entries)
256            .into_iter()
257            .map(|entry| (entry.id, entry))
258            .collect();
259        state.worktrees.insert(
260            worktree_id,
261            Worktree {
262                host_connection_id: Some(request.sender_id),
263                guest_connection_ids: Default::default(),
264                active_replica_ids: Default::default(),
265                access_token: access_token.clone(),
266                root_name: mem::take(&mut worktree.root_name),
267                entries,
268            },
269        );
270
271        self.peer
272            .respond(
273                request.receipt(),
274                proto::ShareWorktreeResponse {
275                    worktree_id,
276                    access_token,
277                },
278            )
279            .await?;
280        Ok(())
281    }
282
283    async fn join_worktree(
284        self: Arc<Server>,
285        request: TypedEnvelope<proto::OpenWorktree>,
286    ) -> tide::Result<()> {
287        let worktree_id = request.payload.worktree_id;
288        let access_token = &request.payload.access_token;
289
290        let mut state = self.state.write().await;
291        if let Some((peer_replica_id, worktree)) =
292            state.join_worktree(request.sender_id, worktree_id, access_token)
293        {
294            let mut peers = Vec::new();
295            if let Some(host_connection_id) = worktree.host_connection_id {
296                peers.push(proto::Peer {
297                    peer_id: host_connection_id.0,
298                    replica_id: 0,
299                });
300            }
301            for (peer_conn_id, peer_replica_id) in &worktree.guest_connection_ids {
302                if *peer_conn_id != request.sender_id {
303                    peers.push(proto::Peer {
304                        peer_id: peer_conn_id.0,
305                        replica_id: *peer_replica_id as u32,
306                    });
307                }
308            }
309
310            broadcast(request.sender_id, worktree.connection_ids(), |conn_id| {
311                self.peer.send(
312                    conn_id,
313                    proto::AddPeer {
314                        worktree_id,
315                        peer: Some(proto::Peer {
316                            peer_id: request.sender_id.0,
317                            replica_id: peer_replica_id as u32,
318                        }),
319                    },
320                )
321            })
322            .await?;
323            self.peer
324                .respond(
325                    request.receipt(),
326                    proto::OpenWorktreeResponse {
327                        worktree_id,
328                        worktree: Some(proto::Worktree {
329                            root_name: worktree.root_name.clone(),
330                            entries: worktree.entries.values().cloned().collect(),
331                        }),
332                        replica_id: peer_replica_id as u32,
333                        peers,
334                    },
335                )
336                .await?;
337        } else {
338            self.peer
339                .respond(
340                    request.receipt(),
341                    proto::OpenWorktreeResponse {
342                        worktree_id,
343                        worktree: None,
344                        replica_id: 0,
345                        peers: Vec::new(),
346                    },
347                )
348                .await?;
349        }
350
351        Ok(())
352    }
353
354    async fn update_worktree(
355        self: Arc<Server>,
356        request: TypedEnvelope<proto::UpdateWorktree>,
357    ) -> tide::Result<()> {
358        {
359            let mut state = self.state.write().await;
360            let worktree = state.write_worktree(request.payload.worktree_id, request.sender_id)?;
361            for entry_id in &request.payload.removed_entries {
362                worktree.entries.remove(&entry_id);
363            }
364
365            for entry in &request.payload.updated_entries {
366                worktree.entries.insert(entry.id, entry.clone());
367            }
368        }
369
370        self.broadcast_in_worktree(request.payload.worktree_id, &request)
371            .await?;
372        Ok(())
373    }
374
375    async fn close_worktree(
376        self: Arc<Server>,
377        request: TypedEnvelope<proto::CloseWorktree>,
378    ) -> tide::Result<()> {
379        let connection_ids;
380        {
381            let mut state = self.state.write().await;
382            let worktree = state.write_worktree(request.payload.worktree_id, request.sender_id)?;
383            connection_ids = worktree.connection_ids();
384            if worktree.host_connection_id == Some(request.sender_id) {
385                worktree.host_connection_id = None;
386            } else if let Some(replica_id) =
387                worktree.guest_connection_ids.remove(&request.sender_id)
388            {
389                worktree.active_replica_ids.remove(&replica_id);
390            }
391        }
392
393        broadcast(request.sender_id, connection_ids, |conn_id| {
394            self.peer.send(
395                conn_id,
396                proto::RemovePeer {
397                    worktree_id: request.payload.worktree_id,
398                    peer_id: request.sender_id.0,
399                },
400            )
401        })
402        .await?;
403
404        Ok(())
405    }
406
407    async fn open_buffer(
408        self: Arc<Server>,
409        request: TypedEnvelope<proto::OpenBuffer>,
410    ) -> tide::Result<()> {
411        let receipt = request.receipt();
412        let worktree_id = request.payload.worktree_id;
413        let host_connection_id = self
414            .state
415            .read()
416            .await
417            .read_worktree(worktree_id, request.sender_id)?
418            .host_connection_id()?;
419
420        let response = self
421            .peer
422            .forward_request(request.sender_id, host_connection_id, request.payload)
423            .await?;
424        self.peer.respond(receipt, response).await?;
425        Ok(())
426    }
427
428    async fn close_buffer(
429        self: Arc<Server>,
430        request: TypedEnvelope<proto::CloseBuffer>,
431    ) -> tide::Result<()> {
432        let host_connection_id = self
433            .state
434            .read()
435            .await
436            .read_worktree(request.payload.worktree_id, request.sender_id)?
437            .host_connection_id()?;
438
439        self.peer
440            .forward_send(request.sender_id, host_connection_id, request.payload)
441            .await?;
442
443        Ok(())
444    }
445
446    async fn save_buffer(
447        self: Arc<Server>,
448        request: TypedEnvelope<proto::SaveBuffer>,
449    ) -> tide::Result<()> {
450        let host;
451        let guests;
452        {
453            let state = self.state.read().await;
454            let worktree = state.read_worktree(request.payload.worktree_id, request.sender_id)?;
455            host = worktree.host_connection_id()?;
456            guests = worktree
457                .guest_connection_ids
458                .keys()
459                .copied()
460                .collect::<Vec<_>>();
461        }
462
463        let sender = request.sender_id;
464        let receipt = request.receipt();
465        let response = self
466            .peer
467            .forward_request(sender, host, request.payload.clone())
468            .await?;
469
470        broadcast(host, guests, |conn_id| {
471            let response = response.clone();
472            let peer = &self.peer;
473            async move {
474                if conn_id == sender {
475                    peer.respond(receipt, response).await
476                } else {
477                    peer.forward_send(host, conn_id, response).await
478                }
479            }
480        })
481        .await?;
482
483        Ok(())
484    }
485
486    async fn update_buffer(
487        self: Arc<Server>,
488        request: TypedEnvelope<proto::UpdateBuffer>,
489    ) -> tide::Result<()> {
490        self.broadcast_in_worktree(request.payload.worktree_id, &request)
491            .await
492    }
493
494    async fn buffer_saved(
495        self: Arc<Server>,
496        request: TypedEnvelope<proto::BufferSaved>,
497    ) -> tide::Result<()> {
498        self.broadcast_in_worktree(request.payload.worktree_id, &request)
499            .await
500    }
501
502    async fn get_channels(
503        self: Arc<Server>,
504        request: TypedEnvelope<proto::GetChannels>,
505    ) -> tide::Result<()> {
506        let user_id = self
507            .state
508            .read()
509            .await
510            .user_id_for_connection(request.sender_id)?;
511        let channels = self.app_state.db.get_channels_for_user(user_id).await?;
512        self.peer
513            .respond(
514                request.receipt(),
515                proto::GetChannelsResponse {
516                    channels: channels
517                        .into_iter()
518                        .map(|chan| proto::Channel {
519                            id: chan.id.to_proto(),
520                            name: chan.name,
521                        })
522                        .collect(),
523                },
524            )
525            .await?;
526        Ok(())
527    }
528
529    async fn get_users(
530        self: Arc<Server>,
531        request: TypedEnvelope<proto::GetUsers>,
532    ) -> tide::Result<()> {
533        let user_id = self
534            .state
535            .read()
536            .await
537            .user_id_for_connection(request.sender_id)?;
538        let receipt = request.receipt();
539        let user_ids = request.payload.user_ids.into_iter().map(UserId::from_proto);
540        let users = self
541            .app_state
542            .db
543            .get_users_by_ids(user_id, user_ids)
544            .await?
545            .into_iter()
546            .map(|user| proto::User {
547                id: user.id.to_proto(),
548                github_login: user.github_login,
549                avatar_url: String::new(),
550            })
551            .collect();
552        self.peer
553            .respond(receipt, proto::GetUsersResponse { users })
554            .await?;
555        Ok(())
556    }
557
558    async fn join_channel(
559        self: Arc<Self>,
560        request: TypedEnvelope<proto::JoinChannel>,
561    ) -> tide::Result<()> {
562        let user_id = self
563            .state
564            .read()
565            .await
566            .user_id_for_connection(request.sender_id)?;
567        let channel_id = ChannelId::from_proto(request.payload.channel_id);
568        if !self
569            .app_state
570            .db
571            .can_user_access_channel(user_id, channel_id)
572            .await?
573        {
574            Err(anyhow!("access denied"))?;
575        }
576
577        self.state
578            .write()
579            .await
580            .join_channel(request.sender_id, channel_id);
581        let messages = self
582            .app_state
583            .db
584            .get_recent_channel_messages(channel_id, 50)
585            .await?
586            .into_iter()
587            .map(|msg| proto::ChannelMessage {
588                id: msg.id.to_proto(),
589                body: msg.body,
590                timestamp: msg.sent_at.unix_timestamp() as u64,
591                sender_id: msg.sender_id.to_proto(),
592            })
593            .collect();
594        self.peer
595            .respond(request.receipt(), proto::JoinChannelResponse { messages })
596            .await?;
597        Ok(())
598    }
599
600    async fn send_channel_message(
601        self: Arc<Self>,
602        request: TypedEnvelope<proto::SendChannelMessage>,
603    ) -> tide::Result<()> {
604        let channel_id = ChannelId::from_proto(request.payload.channel_id);
605        let user_id;
606        let connection_ids;
607        {
608            let state = self.state.read().await;
609            user_id = state.user_id_for_connection(request.sender_id)?;
610            if let Some(channel) = state.channels.get(&channel_id) {
611                connection_ids = channel.connection_ids();
612            } else {
613                return Ok(());
614            }
615        }
616
617        let timestamp = OffsetDateTime::now_utc();
618        let message_id = self
619            .app_state
620            .db
621            .create_channel_message(channel_id, user_id, &request.payload.body, timestamp)
622            .await?
623            .to_proto();
624        let receipt = request.receipt();
625        let message = proto::ChannelMessageSent {
626            channel_id: channel_id.to_proto(),
627            message: Some(proto::ChannelMessage {
628                sender_id: user_id.to_proto(),
629                id: message_id,
630                body: request.payload.body,
631                timestamp: timestamp.unix_timestamp() as u64,
632            }),
633        };
634        broadcast(request.sender_id, connection_ids, |conn_id| {
635            self.peer.send(conn_id, message.clone())
636        })
637        .await?;
638        self.peer
639            .respond(
640                receipt,
641                proto::SendChannelMessageResponse {
642                    message_id,
643                    timestamp: timestamp.unix_timestamp() as u64,
644                },
645            )
646            .await?;
647        Ok(())
648    }
649
650    async fn broadcast_in_worktree<T: proto::EnvelopedMessage>(
651        &self,
652        worktree_id: u64,
653        message: &TypedEnvelope<T>,
654    ) -> tide::Result<()> {
655        let connection_ids = self
656            .state
657            .read()
658            .await
659            .read_worktree(worktree_id, message.sender_id)?
660            .connection_ids();
661
662        broadcast(message.sender_id, connection_ids, |conn_id| {
663            self.peer
664                .forward_send(message.sender_id, conn_id, message.payload.clone())
665        })
666        .await?;
667
668        Ok(())
669    }
670}
671
672pub async fn broadcast<F, T>(
673    sender_id: ConnectionId,
674    receiver_ids: Vec<ConnectionId>,
675    mut f: F,
676) -> anyhow::Result<()>
677where
678    F: FnMut(ConnectionId) -> T,
679    T: Future<Output = anyhow::Result<()>>,
680{
681    let futures = receiver_ids
682        .into_iter()
683        .filter(|id| *id != sender_id)
684        .map(|id| f(id));
685    futures::future::try_join_all(futures).await?;
686    Ok(())
687}
688
689impl ServerState {
690    fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
691        if let Some(connection) = self.connections.get_mut(&connection_id) {
692            connection.channels.insert(channel_id);
693            self.channels
694                .entry(channel_id)
695                .or_default()
696                .connection_ids
697                .insert(connection_id);
698        }
699    }
700
701    fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
702        Ok(self
703            .connections
704            .get(&connection_id)
705            .ok_or_else(|| anyhow!("unknown connection"))?
706            .user_id)
707    }
708
709    // Add the given connection as a guest of the given worktree
710    fn join_worktree(
711        &mut self,
712        connection_id: ConnectionId,
713        worktree_id: u64,
714        access_token: &str,
715    ) -> Option<(ReplicaId, &Worktree)> {
716        if let Some(worktree) = self.worktrees.get_mut(&worktree_id) {
717            if access_token == worktree.access_token {
718                if let Some(connection) = self.connections.get_mut(&connection_id) {
719                    connection.worktrees.insert(worktree_id);
720                }
721
722                let mut replica_id = 1;
723                while worktree.active_replica_ids.contains(&replica_id) {
724                    replica_id += 1;
725                }
726                worktree.active_replica_ids.insert(replica_id);
727                worktree
728                    .guest_connection_ids
729                    .insert(connection_id, replica_id);
730                Some((replica_id, worktree))
731            } else {
732                None
733            }
734        } else {
735            None
736        }
737    }
738
739    fn read_worktree(
740        &self,
741        worktree_id: u64,
742        connection_id: ConnectionId,
743    ) -> tide::Result<&Worktree> {
744        let worktree = self
745            .worktrees
746            .get(&worktree_id)
747            .ok_or_else(|| anyhow!("worktree not found"))?;
748
749        if worktree.host_connection_id == Some(connection_id)
750            || worktree.guest_connection_ids.contains_key(&connection_id)
751        {
752            Ok(worktree)
753        } else {
754            Err(anyhow!(
755                "{} is not a member of worktree {}",
756                connection_id,
757                worktree_id
758            ))?
759        }
760    }
761
762    fn write_worktree(
763        &mut self,
764        worktree_id: u64,
765        connection_id: ConnectionId,
766    ) -> tide::Result<&mut Worktree> {
767        let worktree = self
768            .worktrees
769            .get_mut(&worktree_id)
770            .ok_or_else(|| anyhow!("worktree not found"))?;
771
772        if worktree.host_connection_id == Some(connection_id)
773            || worktree.guest_connection_ids.contains_key(&connection_id)
774        {
775            Ok(worktree)
776        } else {
777            Err(anyhow!(
778                "{} is not a member of worktree {}",
779                connection_id,
780                worktree_id
781            ))?
782        }
783    }
784}
785
786impl Worktree {
787    pub fn connection_ids(&self) -> Vec<ConnectionId> {
788        self.guest_connection_ids
789            .keys()
790            .copied()
791            .chain(self.host_connection_id)
792            .collect()
793    }
794
795    fn host_connection_id(&self) -> tide::Result<ConnectionId> {
796        Ok(self
797            .host_connection_id
798            .ok_or_else(|| anyhow!("host disconnected from worktree"))?)
799    }
800}
801
802impl Channel {
803    fn connection_ids(&self) -> Vec<ConnectionId> {
804        self.connection_ids.iter().copied().collect()
805    }
806}
807
808pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
809    let server = Server::new(app.state().clone(), rpc.clone());
810    app.at("/rpc").with(auth::VerifyToken).get(move |request: Request<Arc<AppState>>| {
811        let user_id = request.ext::<UserId>().copied();
812        let server = server.clone();
813        async move {
814            const WEBSOCKET_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
815
816            let connection_upgrade = header_contains_ignore_case(&request, CONNECTION, "upgrade");
817            let upgrade_to_websocket = header_contains_ignore_case(&request, UPGRADE, "websocket");
818            let upgrade_requested = connection_upgrade && upgrade_to_websocket;
819
820            if !upgrade_requested {
821                return Ok(Response::new(StatusCode::UpgradeRequired));
822            }
823
824            let header = match request.header("Sec-Websocket-Key") {
825                Some(h) => h.as_str(),
826                None => return Err(anyhow!("expected sec-websocket-key"))?,
827            };
828
829            let mut response = Response::new(StatusCode::SwitchingProtocols);
830            response.insert_header(UPGRADE, "websocket");
831            response.insert_header(CONNECTION, "Upgrade");
832            let hash = Sha1::new().chain(header).chain(WEBSOCKET_GUID).finalize();
833            response.insert_header("Sec-Websocket-Accept", base64::encode(&hash[..]));
834            response.insert_header("Sec-Websocket-Version", "13");
835
836            let http_res: &mut tide::http::Response = response.as_mut();
837            let upgrade_receiver = http_res.recv_upgrade().await;
838            let addr = request.remote().unwrap_or("unknown").to_string();
839            let user_id = user_id.ok_or_else(|| anyhow!("user_id is not present on request. ensure auth::VerifyToken middleware is present"))?;
840            task::spawn(async move {
841                if let Some(stream) = upgrade_receiver.await {
842                    let stream = WebSocketStream::from_raw_socket(stream, Role::Server, None).await;
843                    server.handle_connection(stream, addr, user_id).await;
844                }
845            });
846
847            Ok(response)
848        }
849    });
850}
851
852fn header_contains_ignore_case<T>(
853    request: &tide::Request<T>,
854    header_name: HeaderName,
855    value: &str,
856) -> bool {
857    request
858        .header(header_name)
859        .map(|h| {
860            h.as_str()
861                .split(',')
862                .any(|s| s.trim().eq_ignore_ascii_case(value.trim()))
863        })
864        .unwrap_or(false)
865}