1use super::{
2 auth,
3 db::{ChannelId, UserId},
4 AppState,
5};
6use anyhow::anyhow;
7use async_std::{sync::RwLock, task};
8use async_tungstenite::{
9 tungstenite::{protocol::Role, Error as WebSocketError, Message as WebSocketMessage},
10 WebSocketStream,
11};
12use futures::{future::BoxFuture, FutureExt};
13use postage::prelude::Stream as _;
14use sha1::{Digest as _, Sha1};
15use std::{
16 any::TypeId,
17 collections::{HashMap, HashSet},
18 future::Future,
19 mem,
20 sync::Arc,
21 time::Instant,
22};
23use surf::StatusCode;
24use tide::log;
25use tide::{
26 http::headers::{HeaderName, CONNECTION, UPGRADE},
27 Request, Response,
28};
29use time::OffsetDateTime;
30use zrpc::{
31 auth::random_token,
32 proto::{self, AnyTypedEnvelope, EnvelopedMessage},
33 ConnectionId, Peer, TypedEnvelope,
34};
35
/// Numeric id that distinguishes the participants of a shared worktree.
///
/// In this file the host is always reported with replica id 0 and guests are
/// numbered starting from 1.
type ReplicaId = u16;
37
38type MessageHandler = Box<
39 dyn Send
40 + Sync
41 + Fn(Arc<Server>, Box<dyn AnyTypedEnvelope>) -> BoxFuture<'static, tide::Result<()>>,
42>;
43
44pub struct Server {
45 peer: Arc<Peer>,
46 state: RwLock<ServerState>,
47 app_state: Arc<AppState>,
48 handlers: HashMap<TypeId, MessageHandler>,
49}
50
51#[derive(Default)]
52struct ServerState {
53 connections: HashMap<ConnectionId, Connection>,
54 pub worktrees: HashMap<u64, Worktree>,
55 channels: HashMap<ChannelId, Channel>,
56 next_worktree_id: u64,
57}
58
59struct Connection {
60 user_id: UserId,
61 worktrees: HashSet<u64>,
62 channels: HashSet<ChannelId>,
63}
64
65struct Worktree {
66 host_connection_id: Option<ConnectionId>,
67 guest_connection_ids: HashMap<ConnectionId, ReplicaId>,
68 active_replica_ids: HashSet<ReplicaId>,
69 access_token: String,
70 root_name: String,
71 entries: HashMap<u64, proto::Entry>,
72}
73
74#[derive(Default)]
75struct Channel {
76 connection_ids: HashSet<ConnectionId>,
77}
78
79impl Server {
80 pub fn new(app_state: Arc<AppState>, peer: Arc<Peer>) -> Arc<Self> {
81 let mut server = Server {
82 peer,
83 app_state,
84 state: Default::default(),
85 handlers: Default::default(),
86 };
87
88 server
89 .add_handler(Server::share_worktree)
90 .add_handler(Server::join_worktree)
91 .add_handler(Server::update_worktree)
92 .add_handler(Server::close_worktree)
93 .add_handler(Server::open_buffer)
94 .add_handler(Server::close_buffer)
95 .add_handler(Server::update_buffer)
96 .add_handler(Server::buffer_saved)
97 .add_handler(Server::save_buffer)
98 .add_handler(Server::get_channels)
99 .add_handler(Server::get_users)
100 .add_handler(Server::join_channel)
101 .add_handler(Server::send_channel_message);
102
103 Arc::new(server)
104 }
105
106 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
107 where
108 F: 'static + Send + Sync + Fn(Arc<Self>, TypedEnvelope<M>) -> Fut,
109 Fut: 'static + Send + Future<Output = tide::Result<()>>,
110 M: EnvelopedMessage,
111 {
112 let prev_handler = self.handlers.insert(
113 TypeId::of::<M>(),
114 Box::new(move |server, envelope| {
115 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
116 (handler)(server, *envelope).boxed()
117 }),
118 );
119 if prev_handler.is_some() {
120 panic!("registered a handler for the same message twice");
121 }
122 self
123 }
124
125 pub fn handle_connection<Conn>(
126 self: &Arc<Self>,
127 connection: Conn,
128 addr: String,
129 user_id: UserId,
130 ) -> impl Future<Output = ()>
131 where
132 Conn: 'static
133 + futures::Sink<WebSocketMessage, Error = WebSocketError>
134 + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
135 + Send
136 + Unpin,
137 {
138 let this = self.clone();
139 async move {
140 let (connection_id, handle_io, mut incoming_rx) =
141 this.peer.add_connection(connection).await;
142 this.add_connection(connection_id, user_id).await;
143
144 let handle_io = handle_io.fuse();
145 futures::pin_mut!(handle_io);
146 loop {
147 let next_message = incoming_rx.recv().fuse();
148 futures::pin_mut!(next_message);
149 futures::select_biased! {
150 message = next_message => {
151 if let Some(message) = message {
152 let start_time = Instant::now();
153 log::info!("RPC message received: {}", message.payload_type_name());
154 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
155 if let Err(err) = (handler)(this.clone(), message).await {
156 log::error!("error handling message: {:?}", err);
157 } else {
158 log::info!("RPC message handled. duration:{:?}", start_time.elapsed());
159 }
160 } else {
161 log::warn!("unhandled message: {}", message.payload_type_name());
162 }
163 } else {
164 log::info!("rpc connection closed {:?}", addr);
165 break;
166 }
167 }
168 handle_io = handle_io => {
169 if let Err(err) = handle_io {
170 log::error!("error handling rpc connection {:?} - {:?}", addr, err);
171 }
172 break;
173 }
174 }
175 }
176
177 if let Err(err) = this.sign_out(connection_id).await {
178 log::error!("error signing out connection {:?} - {:?}", addr, err);
179 }
180 }
181 }
182
183 async fn sign_out(self: &Arc<Self>, connection_id: zrpc::ConnectionId) -> tide::Result<()> {
184 self.peer.disconnect(connection_id).await;
185 let worktree_ids = self.remove_connection(connection_id).await;
186 for worktree_id in worktree_ids {
187 let state = self.state.read().await;
188 if let Some(worktree) = state.worktrees.get(&worktree_id) {
189 broadcast(connection_id, worktree.connection_ids(), |conn_id| {
190 self.peer.send(
191 conn_id,
192 proto::RemovePeer {
193 worktree_id,
194 peer_id: connection_id.0,
195 },
196 )
197 })
198 .await?;
199 }
200 }
201 Ok(())
202 }
203
204 // Add a new connection associated with a given user.
205 async fn add_connection(&self, connection_id: ConnectionId, user_id: UserId) {
206 self.state.write().await.connections.insert(
207 connection_id,
208 Connection {
209 user_id,
210 worktrees: Default::default(),
211 channels: Default::default(),
212 },
213 );
214 }
215
216 // Remove the given connection and its association with any worktrees.
217 async fn remove_connection(&self, connection_id: ConnectionId) -> Vec<u64> {
218 let mut worktree_ids = Vec::new();
219 let mut state = self.state.write().await;
220 if let Some(connection) = state.connections.remove(&connection_id) {
221 for channel_id in connection.channels {
222 if let Some(channel) = state.channels.get_mut(&channel_id) {
223 channel.connection_ids.remove(&connection_id);
224 }
225 }
226 for worktree_id in connection.worktrees {
227 if let Some(worktree) = state.worktrees.get_mut(&worktree_id) {
228 if worktree.host_connection_id == Some(connection_id) {
229 worktree_ids.push(worktree_id);
230 } else if let Some(replica_id) =
231 worktree.guest_connection_ids.remove(&connection_id)
232 {
233 worktree.active_replica_ids.remove(&replica_id);
234 worktree_ids.push(worktree_id);
235 }
236 }
237 }
238 }
239 worktree_ids
240 }
241
242 async fn share_worktree(
243 self: Arc<Server>,
244 mut request: TypedEnvelope<proto::ShareWorktree>,
245 ) -> tide::Result<()> {
246 let mut state = self.state.write().await;
247 let worktree_id = state.next_worktree_id;
248 state.next_worktree_id += 1;
249 let access_token = random_token();
250 let worktree = request
251 .payload
252 .worktree
253 .as_mut()
254 .ok_or_else(|| anyhow!("missing worktree"))?;
255 let entries = mem::take(&mut worktree.entries)
256 .into_iter()
257 .map(|entry| (entry.id, entry))
258 .collect();
259 state.worktrees.insert(
260 worktree_id,
261 Worktree {
262 host_connection_id: Some(request.sender_id),
263 guest_connection_ids: Default::default(),
264 active_replica_ids: Default::default(),
265 access_token: access_token.clone(),
266 root_name: mem::take(&mut worktree.root_name),
267 entries,
268 },
269 );
270
271 self.peer
272 .respond(
273 request.receipt(),
274 proto::ShareWorktreeResponse {
275 worktree_id,
276 access_token,
277 },
278 )
279 .await?;
280 Ok(())
281 }
282
283 async fn join_worktree(
284 self: Arc<Server>,
285 request: TypedEnvelope<proto::OpenWorktree>,
286 ) -> tide::Result<()> {
287 let worktree_id = request.payload.worktree_id;
288 let access_token = &request.payload.access_token;
289
290 let mut state = self.state.write().await;
291 if let Some((peer_replica_id, worktree)) =
292 state.join_worktree(request.sender_id, worktree_id, access_token)
293 {
294 let mut peers = Vec::new();
295 if let Some(host_connection_id) = worktree.host_connection_id {
296 peers.push(proto::Peer {
297 peer_id: host_connection_id.0,
298 replica_id: 0,
299 });
300 }
301 for (peer_conn_id, peer_replica_id) in &worktree.guest_connection_ids {
302 if *peer_conn_id != request.sender_id {
303 peers.push(proto::Peer {
304 peer_id: peer_conn_id.0,
305 replica_id: *peer_replica_id as u32,
306 });
307 }
308 }
309
310 broadcast(request.sender_id, worktree.connection_ids(), |conn_id| {
311 self.peer.send(
312 conn_id,
313 proto::AddPeer {
314 worktree_id,
315 peer: Some(proto::Peer {
316 peer_id: request.sender_id.0,
317 replica_id: peer_replica_id as u32,
318 }),
319 },
320 )
321 })
322 .await?;
323 self.peer
324 .respond(
325 request.receipt(),
326 proto::OpenWorktreeResponse {
327 worktree_id,
328 worktree: Some(proto::Worktree {
329 root_name: worktree.root_name.clone(),
330 entries: worktree.entries.values().cloned().collect(),
331 }),
332 replica_id: peer_replica_id as u32,
333 peers,
334 },
335 )
336 .await?;
337 } else {
338 self.peer
339 .respond(
340 request.receipt(),
341 proto::OpenWorktreeResponse {
342 worktree_id,
343 worktree: None,
344 replica_id: 0,
345 peers: Vec::new(),
346 },
347 )
348 .await?;
349 }
350
351 Ok(())
352 }
353
354 async fn update_worktree(
355 self: Arc<Server>,
356 request: TypedEnvelope<proto::UpdateWorktree>,
357 ) -> tide::Result<()> {
358 {
359 let mut state = self.state.write().await;
360 let worktree = state.write_worktree(request.payload.worktree_id, request.sender_id)?;
361 for entry_id in &request.payload.removed_entries {
362 worktree.entries.remove(&entry_id);
363 }
364
365 for entry in &request.payload.updated_entries {
366 worktree.entries.insert(entry.id, entry.clone());
367 }
368 }
369
370 self.broadcast_in_worktree(request.payload.worktree_id, &request)
371 .await?;
372 Ok(())
373 }
374
375 async fn close_worktree(
376 self: Arc<Server>,
377 request: TypedEnvelope<proto::CloseWorktree>,
378 ) -> tide::Result<()> {
379 let connection_ids;
380 {
381 let mut state = self.state.write().await;
382 let worktree = state.write_worktree(request.payload.worktree_id, request.sender_id)?;
383 connection_ids = worktree.connection_ids();
384 if worktree.host_connection_id == Some(request.sender_id) {
385 worktree.host_connection_id = None;
386 } else if let Some(replica_id) =
387 worktree.guest_connection_ids.remove(&request.sender_id)
388 {
389 worktree.active_replica_ids.remove(&replica_id);
390 }
391 }
392
393 broadcast(request.sender_id, connection_ids, |conn_id| {
394 self.peer.send(
395 conn_id,
396 proto::RemovePeer {
397 worktree_id: request.payload.worktree_id,
398 peer_id: request.sender_id.0,
399 },
400 )
401 })
402 .await?;
403
404 Ok(())
405 }
406
407 async fn open_buffer(
408 self: Arc<Server>,
409 request: TypedEnvelope<proto::OpenBuffer>,
410 ) -> tide::Result<()> {
411 let receipt = request.receipt();
412 let worktree_id = request.payload.worktree_id;
413 let host_connection_id = self
414 .state
415 .read()
416 .await
417 .read_worktree(worktree_id, request.sender_id)?
418 .host_connection_id()?;
419
420 let response = self
421 .peer
422 .forward_request(request.sender_id, host_connection_id, request.payload)
423 .await?;
424 self.peer.respond(receipt, response).await?;
425 Ok(())
426 }
427
428 async fn close_buffer(
429 self: Arc<Server>,
430 request: TypedEnvelope<proto::CloseBuffer>,
431 ) -> tide::Result<()> {
432 let host_connection_id = self
433 .state
434 .read()
435 .await
436 .read_worktree(request.payload.worktree_id, request.sender_id)?
437 .host_connection_id()?;
438
439 self.peer
440 .forward_send(request.sender_id, host_connection_id, request.payload)
441 .await?;
442
443 Ok(())
444 }
445
446 async fn save_buffer(
447 self: Arc<Server>,
448 request: TypedEnvelope<proto::SaveBuffer>,
449 ) -> tide::Result<()> {
450 let host;
451 let guests;
452 {
453 let state = self.state.read().await;
454 let worktree = state.read_worktree(request.payload.worktree_id, request.sender_id)?;
455 host = worktree.host_connection_id()?;
456 guests = worktree
457 .guest_connection_ids
458 .keys()
459 .copied()
460 .collect::<Vec<_>>();
461 }
462
463 let sender = request.sender_id;
464 let receipt = request.receipt();
465 let response = self
466 .peer
467 .forward_request(sender, host, request.payload.clone())
468 .await?;
469
470 broadcast(host, guests, |conn_id| {
471 let response = response.clone();
472 let peer = &self.peer;
473 async move {
474 if conn_id == sender {
475 peer.respond(receipt, response).await
476 } else {
477 peer.forward_send(host, conn_id, response).await
478 }
479 }
480 })
481 .await?;
482
483 Ok(())
484 }
485
486 async fn update_buffer(
487 self: Arc<Server>,
488 request: TypedEnvelope<proto::UpdateBuffer>,
489 ) -> tide::Result<()> {
490 self.broadcast_in_worktree(request.payload.worktree_id, &request)
491 .await
492 }
493
494 async fn buffer_saved(
495 self: Arc<Server>,
496 request: TypedEnvelope<proto::BufferSaved>,
497 ) -> tide::Result<()> {
498 self.broadcast_in_worktree(request.payload.worktree_id, &request)
499 .await
500 }
501
502 async fn get_channels(
503 self: Arc<Server>,
504 request: TypedEnvelope<proto::GetChannels>,
505 ) -> tide::Result<()> {
506 let user_id = self
507 .state
508 .read()
509 .await
510 .user_id_for_connection(request.sender_id)?;
511 let channels = self.app_state.db.get_channels_for_user(user_id).await?;
512 self.peer
513 .respond(
514 request.receipt(),
515 proto::GetChannelsResponse {
516 channels: channels
517 .into_iter()
518 .map(|chan| proto::Channel {
519 id: chan.id.to_proto(),
520 name: chan.name,
521 })
522 .collect(),
523 },
524 )
525 .await?;
526 Ok(())
527 }
528
529 async fn get_users(
530 self: Arc<Server>,
531 request: TypedEnvelope<proto::GetUsers>,
532 ) -> tide::Result<()> {
533 let user_id = self
534 .state
535 .read()
536 .await
537 .user_id_for_connection(request.sender_id)?;
538 let receipt = request.receipt();
539 let user_ids = request.payload.user_ids.into_iter().map(UserId::from_proto);
540 let users = self
541 .app_state
542 .db
543 .get_users_by_ids(user_id, user_ids)
544 .await?
545 .into_iter()
546 .map(|user| proto::User {
547 id: user.id.to_proto(),
548 github_login: user.github_login,
549 avatar_url: String::new(),
550 })
551 .collect();
552 self.peer
553 .respond(receipt, proto::GetUsersResponse { users })
554 .await?;
555 Ok(())
556 }
557
558 async fn join_channel(
559 self: Arc<Self>,
560 request: TypedEnvelope<proto::JoinChannel>,
561 ) -> tide::Result<()> {
562 let user_id = self
563 .state
564 .read()
565 .await
566 .user_id_for_connection(request.sender_id)?;
567 let channel_id = ChannelId::from_proto(request.payload.channel_id);
568 if !self
569 .app_state
570 .db
571 .can_user_access_channel(user_id, channel_id)
572 .await?
573 {
574 Err(anyhow!("access denied"))?;
575 }
576
577 self.state
578 .write()
579 .await
580 .join_channel(request.sender_id, channel_id);
581 let messages = self
582 .app_state
583 .db
584 .get_recent_channel_messages(channel_id, 50)
585 .await?
586 .into_iter()
587 .map(|msg| proto::ChannelMessage {
588 id: msg.id.to_proto(),
589 body: msg.body,
590 timestamp: msg.sent_at.unix_timestamp() as u64,
591 sender_id: msg.sender_id.to_proto(),
592 })
593 .collect();
594 self.peer
595 .respond(request.receipt(), proto::JoinChannelResponse { messages })
596 .await?;
597 Ok(())
598 }
599
600 async fn send_channel_message(
601 self: Arc<Self>,
602 request: TypedEnvelope<proto::SendChannelMessage>,
603 ) -> tide::Result<()> {
604 let channel_id = ChannelId::from_proto(request.payload.channel_id);
605 let user_id;
606 let connection_ids;
607 {
608 let state = self.state.read().await;
609 user_id = state.user_id_for_connection(request.sender_id)?;
610 if let Some(channel) = state.channels.get(&channel_id) {
611 connection_ids = channel.connection_ids();
612 } else {
613 return Ok(());
614 }
615 }
616
617 let timestamp = OffsetDateTime::now_utc();
618 let message_id = self
619 .app_state
620 .db
621 .create_channel_message(channel_id, user_id, &request.payload.body, timestamp)
622 .await?
623 .to_proto();
624 let receipt = request.receipt();
625 let message = proto::ChannelMessageSent {
626 channel_id: channel_id.to_proto(),
627 message: Some(proto::ChannelMessage {
628 sender_id: user_id.to_proto(),
629 id: message_id,
630 body: request.payload.body,
631 timestamp: timestamp.unix_timestamp() as u64,
632 }),
633 };
634 broadcast(request.sender_id, connection_ids, |conn_id| {
635 self.peer.send(conn_id, message.clone())
636 })
637 .await?;
638 self.peer
639 .respond(
640 receipt,
641 proto::SendChannelMessageResponse {
642 message_id,
643 timestamp: timestamp.unix_timestamp() as u64,
644 },
645 )
646 .await?;
647 Ok(())
648 }
649
650 async fn broadcast_in_worktree<T: proto::EnvelopedMessage>(
651 &self,
652 worktree_id: u64,
653 message: &TypedEnvelope<T>,
654 ) -> tide::Result<()> {
655 let connection_ids = self
656 .state
657 .read()
658 .await
659 .read_worktree(worktree_id, message.sender_id)?
660 .connection_ids();
661
662 broadcast(message.sender_id, connection_ids, |conn_id| {
663 self.peer
664 .forward_send(message.sender_id, conn_id, message.payload.clone())
665 })
666 .await?;
667
668 Ok(())
669 }
670}
671
672pub async fn broadcast<F, T>(
673 sender_id: ConnectionId,
674 receiver_ids: Vec<ConnectionId>,
675 mut f: F,
676) -> anyhow::Result<()>
677where
678 F: FnMut(ConnectionId) -> T,
679 T: Future<Output = anyhow::Result<()>>,
680{
681 let futures = receiver_ids
682 .into_iter()
683 .filter(|id| *id != sender_id)
684 .map(|id| f(id));
685 futures::future::try_join_all(futures).await?;
686 Ok(())
687}
688
689impl ServerState {
690 fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
691 if let Some(connection) = self.connections.get_mut(&connection_id) {
692 connection.channels.insert(channel_id);
693 self.channels
694 .entry(channel_id)
695 .or_default()
696 .connection_ids
697 .insert(connection_id);
698 }
699 }
700
701 fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
702 Ok(self
703 .connections
704 .get(&connection_id)
705 .ok_or_else(|| anyhow!("unknown connection"))?
706 .user_id)
707 }
708
709 // Add the given connection as a guest of the given worktree
710 fn join_worktree(
711 &mut self,
712 connection_id: ConnectionId,
713 worktree_id: u64,
714 access_token: &str,
715 ) -> Option<(ReplicaId, &Worktree)> {
716 if let Some(worktree) = self.worktrees.get_mut(&worktree_id) {
717 if access_token == worktree.access_token {
718 if let Some(connection) = self.connections.get_mut(&connection_id) {
719 connection.worktrees.insert(worktree_id);
720 }
721
722 let mut replica_id = 1;
723 while worktree.active_replica_ids.contains(&replica_id) {
724 replica_id += 1;
725 }
726 worktree.active_replica_ids.insert(replica_id);
727 worktree
728 .guest_connection_ids
729 .insert(connection_id, replica_id);
730 Some((replica_id, worktree))
731 } else {
732 None
733 }
734 } else {
735 None
736 }
737 }
738
739 fn read_worktree(
740 &self,
741 worktree_id: u64,
742 connection_id: ConnectionId,
743 ) -> tide::Result<&Worktree> {
744 let worktree = self
745 .worktrees
746 .get(&worktree_id)
747 .ok_or_else(|| anyhow!("worktree not found"))?;
748
749 if worktree.host_connection_id == Some(connection_id)
750 || worktree.guest_connection_ids.contains_key(&connection_id)
751 {
752 Ok(worktree)
753 } else {
754 Err(anyhow!(
755 "{} is not a member of worktree {}",
756 connection_id,
757 worktree_id
758 ))?
759 }
760 }
761
762 fn write_worktree(
763 &mut self,
764 worktree_id: u64,
765 connection_id: ConnectionId,
766 ) -> tide::Result<&mut Worktree> {
767 let worktree = self
768 .worktrees
769 .get_mut(&worktree_id)
770 .ok_or_else(|| anyhow!("worktree not found"))?;
771
772 if worktree.host_connection_id == Some(connection_id)
773 || worktree.guest_connection_ids.contains_key(&connection_id)
774 {
775 Ok(worktree)
776 } else {
777 Err(anyhow!(
778 "{} is not a member of worktree {}",
779 connection_id,
780 worktree_id
781 ))?
782 }
783 }
784}
785
786impl Worktree {
787 pub fn connection_ids(&self) -> Vec<ConnectionId> {
788 self.guest_connection_ids
789 .keys()
790 .copied()
791 .chain(self.host_connection_id)
792 .collect()
793 }
794
795 fn host_connection_id(&self) -> tide::Result<ConnectionId> {
796 Ok(self
797 .host_connection_id
798 .ok_or_else(|| anyhow!("host disconnected from worktree"))?)
799 }
800}
801
802impl Channel {
803 fn connection_ids(&self) -> Vec<ConnectionId> {
804 self.connection_ids.iter().copied().collect()
805 }
806}
807
808pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
809 let server = Server::new(app.state().clone(), rpc.clone());
810 app.at("/rpc").with(auth::VerifyToken).get(move |request: Request<Arc<AppState>>| {
811 let user_id = request.ext::<UserId>().copied();
812 let server = server.clone();
813 async move {
814 const WEBSOCKET_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
815
816 let connection_upgrade = header_contains_ignore_case(&request, CONNECTION, "upgrade");
817 let upgrade_to_websocket = header_contains_ignore_case(&request, UPGRADE, "websocket");
818 let upgrade_requested = connection_upgrade && upgrade_to_websocket;
819
820 if !upgrade_requested {
821 return Ok(Response::new(StatusCode::UpgradeRequired));
822 }
823
824 let header = match request.header("Sec-Websocket-Key") {
825 Some(h) => h.as_str(),
826 None => return Err(anyhow!("expected sec-websocket-key"))?,
827 };
828
829 let mut response = Response::new(StatusCode::SwitchingProtocols);
830 response.insert_header(UPGRADE, "websocket");
831 response.insert_header(CONNECTION, "Upgrade");
832 let hash = Sha1::new().chain(header).chain(WEBSOCKET_GUID).finalize();
833 response.insert_header("Sec-Websocket-Accept", base64::encode(&hash[..]));
834 response.insert_header("Sec-Websocket-Version", "13");
835
836 let http_res: &mut tide::http::Response = response.as_mut();
837 let upgrade_receiver = http_res.recv_upgrade().await;
838 let addr = request.remote().unwrap_or("unknown").to_string();
839 let user_id = user_id.ok_or_else(|| anyhow!("user_id is not present on request. ensure auth::VerifyToken middleware is present"))?;
840 task::spawn(async move {
841 if let Some(stream) = upgrade_receiver.await {
842 let stream = WebSocketStream::from_raw_socket(stream, Role::Server, None).await;
843 server.handle_connection(stream, addr, user_id).await;
844 }
845 });
846
847 Ok(response)
848 }
849 });
850}
851
852fn header_contains_ignore_case<T>(
853 request: &tide::Request<T>,
854 header_name: HeaderName,
855 value: &str,
856) -> bool {
857 request
858 .header(header_name)
859 .map(|h| {
860 h.as_str()
861 .split(',')
862 .any(|s| s.trim().eq_ignore_ascii_case(value.trim()))
863 })
864 .unwrap_or(false)
865}