1mod connection_pool;
2
3use crate::{
4 auth,
5 db::{
6 self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
7 CreatedChannelMessage, Database, DevServerId, InviteMemberResult, MembershipUpdated,
8 MessageId, NotificationId, PrincipalId, Project, ProjectId, RejoinedProject,
9 RemoteProjectId, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
10 ServerId, UpdatedChannelMessage, User, UserId,
11 },
12 executor::Executor,
13 AppState, Error, RateLimit, RateLimiter, Result,
14};
15use anyhow::{anyhow, Context as _};
16use async_tungstenite::tungstenite::{
17 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
18};
19use axum::{
20 body::Body,
21 extract::{
22 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
23 ConnectInfo, WebSocketUpgrade,
24 },
25 headers::{Header, HeaderName},
26 http::StatusCode,
27 middleware,
28 response::IntoResponse,
29 routing::get,
30 Extension, Router, TypedHeader,
31};
32use collections::{HashMap, HashSet};
33pub use connection_pool::{ConnectionPool, ZedVersion};
34use core::fmt::{self, Debug, Formatter};
35
36use futures::{
37 channel::oneshot,
38 future::{self, BoxFuture},
39 stream::FuturesUnordered,
40 FutureExt, SinkExt, StreamExt, TryStreamExt,
41};
42use prometheus::{register_int_gauge, IntGauge};
43use rpc::{
44 proto::{
45 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LanguageModelRole,
46 LiveKitConnectionInfo, RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
47 },
48 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
49};
50use semantic_version::SemanticVersion;
51use serde::{Serialize, Serializer};
52use std::{
53 any::TypeId,
54 future::Future,
55 marker::PhantomData,
56 mem,
57 net::SocketAddr,
58 ops::{Deref, DerefMut},
59 rc::Rc,
60 sync::{
61 atomic::{AtomicBool, Ordering::SeqCst},
62 Arc, OnceLock,
63 },
64 time::{Duration, Instant},
65};
66use time::OffsetDateTime;
67use tokio::sync::{watch, Semaphore};
68use tower::ServiceBuilder;
69use tracing::{
70 field::{self},
71 info_span, instrument, Instrument,
72};
73use util::http::IsahcHttpClient;
74
/// Reconnection grace period for clients.
/// NOTE(review): not referenced in this part of the file — confirm the exact
/// semantics at its call sites.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// How long `Server::start` waits before cleaning up resources owned by stale
/// servers. Kubernetes gives terminated pods 10s to shut down gracefully; only
/// once they are gone is it safe to clean up the resources they left behind.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

/// Presumably the page size for channel-message queries (used outside this chunk).
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Presumably the maximum accepted channel-message length (used outside this chunk).
const MAX_MESSAGE_LEN: usize = 1_024;
/// Presumably the page size for notification queries (used outside this chunk).
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
83
84type MessageHandler =
85 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
86
87struct Response<R> {
88 peer: Arc<Peer>,
89 receipt: Receipt<R>,
90 responded: Arc<AtomicBool>,
91}
92
93impl<R: RequestMessage> Response<R> {
94 fn send(self, payload: R::Response) -> Result<()> {
95 self.responded.store(true, SeqCst);
96 self.peer.respond(self.receipt, payload)?;
97 Ok(())
98 }
99}
100
101struct StreamingResponse<R: RequestMessage> {
102 peer: Arc<Peer>,
103 receipt: Receipt<R>,
104}
105
106impl<R: RequestMessage> StreamingResponse<R> {
107 fn send(&self, payload: R::Response) -> Result<()> {
108 self.peer.respond(self.receipt, payload)?;
109 Ok(())
110 }
111}
112
113#[derive(Clone, Debug)]
114pub enum Principal {
115 User(User),
116 Impersonated { user: User, admin: User },
117 DevServer(dev_server::Model),
118}
119
120impl Principal {
121 fn update_span(&self, span: &tracing::Span) {
122 match &self {
123 Principal::User(user) => {
124 span.record("user_id", &user.id.0);
125 span.record("login", &user.github_login);
126 }
127 Principal::Impersonated { user, admin } => {
128 span.record("user_id", &user.id.0);
129 span.record("login", &user.github_login);
130 span.record("impersonator", &admin.github_login);
131 }
132 Principal::DevServer(dev_server) => {
133 span.record("dev_server_id", &dev_server.id.0);
134 }
135 }
136 }
137}
138
139#[derive(Clone)]
140struct Session {
141 principal: Principal,
142 connection_id: ConnectionId,
143 db: Arc<tokio::sync::Mutex<DbHandle>>,
144 peer: Arc<Peer>,
145 connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
146 live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
147 http_client: IsahcHttpClient,
148 rate_limiter: Arc<RateLimiter>,
149 _executor: Executor,
150}
151
152impl Session {
153 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
154 #[cfg(test)]
155 tokio::task::yield_now().await;
156 let guard = self.db.lock().await;
157 #[cfg(test)]
158 tokio::task::yield_now().await;
159 guard
160 }
161
162 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
163 #[cfg(test)]
164 tokio::task::yield_now().await;
165 let guard = self.connection_pool.lock();
166 ConnectionPoolGuard {
167 guard,
168 _not_send: PhantomData,
169 }
170 }
171
172 fn for_user(self) -> Option<UserSession> {
173 UserSession::new(self)
174 }
175
176 fn for_dev_server(self) -> Option<DevServerSession> {
177 DevServerSession::new(self)
178 }
179
180 fn user_id(&self) -> Option<UserId> {
181 match &self.principal {
182 Principal::User(user) => Some(user.id),
183 Principal::Impersonated { user, .. } => Some(user.id),
184 Principal::DevServer(_) => None,
185 }
186 }
187
188 fn dev_server_id(&self) -> Option<DevServerId> {
189 match &self.principal {
190 Principal::User(_) | Principal::Impersonated { .. } => None,
191 Principal::DevServer(dev_server) => Some(dev_server.id),
192 }
193 }
194
195 fn principal_id(&self) -> PrincipalId {
196 match &self.principal {
197 Principal::User(user) => PrincipalId::UserId(user.id),
198 Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
199 Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
200 }
201 }
202}
203
204impl Debug for Session {
205 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
206 let mut result = f.debug_struct("Session");
207 match &self.principal {
208 Principal::User(user) => {
209 result.field("user", &user.github_login);
210 }
211 Principal::Impersonated { user, admin } => {
212 result.field("user", &user.github_login);
213 result.field("impersonator", &admin.github_login);
214 }
215 Principal::DevServer(dev_server) => {
216 result.field("dev_server", &dev_server.id);
217 }
218 }
219 result.field("connection_id", &self.connection_id).finish()
220 }
221}
222
223struct UserSession(Session);
224
225impl UserSession {
226 pub fn new(s: Session) -> Option<Self> {
227 s.user_id().map(|_| UserSession(s))
228 }
229 pub fn user_id(&self) -> UserId {
230 self.0.user_id().unwrap()
231 }
232}
233
234impl Deref for UserSession {
235 type Target = Session;
236
237 fn deref(&self) -> &Self::Target {
238 &self.0
239 }
240}
241impl DerefMut for UserSession {
242 fn deref_mut(&mut self) -> &mut Self::Target {
243 &mut self.0
244 }
245}
246
247struct DevServerSession(Session);
248
249impl DevServerSession {
250 pub fn new(s: Session) -> Option<Self> {
251 s.dev_server_id().map(|_| DevServerSession(s))
252 }
253 pub fn dev_server_id(&self) -> DevServerId {
254 self.0.dev_server_id().unwrap()
255 }
256}
257
258impl Deref for DevServerSession {
259 type Target = Session;
260
261 fn deref(&self) -> &Self::Target {
262 &self.0
263 }
264}
265impl DerefMut for DevServerSession {
266 fn deref_mut(&mut self) -> &mut Self::Target {
267 &mut self.0
268 }
269}
270
271fn user_handler<M: RequestMessage, Fut>(
272 handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
273) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
274where
275 Fut: Send + Future<Output = Result<()>>,
276{
277 let handler = Arc::new(handler);
278 move |message, response, session| {
279 let handler = handler.clone();
280 Box::pin(async move {
281 if let Some(user_session) = session.for_user() {
282 Ok(handler(message, response, user_session).await?)
283 } else {
284 Err(Error::Internal(anyhow!(
285 "must be a user to call {}",
286 M::NAME
287 )))
288 }
289 })
290 }
291}
292
293fn dev_server_handler<M: RequestMessage, Fut>(
294 handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
295) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
296where
297 Fut: Send + Future<Output = Result<()>>,
298{
299 let handler = Arc::new(handler);
300 move |message, response, session| {
301 let handler = handler.clone();
302 Box::pin(async move {
303 if let Some(dev_server_session) = session.for_dev_server() {
304 Ok(handler(message, response, dev_server_session).await?)
305 } else {
306 Err(Error::Internal(anyhow!(
307 "must be a dev server to call {}",
308 M::NAME
309 )))
310 }
311 })
312 }
313}
314
315fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
316 handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
317) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
318where
319 InnertRetFut: Send + Future<Output = Result<()>>,
320{
321 let handler = Arc::new(handler);
322 move |message, session| {
323 let handler = handler.clone();
324 Box::pin(async move {
325 if let Some(user_session) = session.for_user() {
326 Ok(handler(message, user_session).await?)
327 } else {
328 Err(Error::Internal(anyhow!(
329 "must be a user to call {}",
330 M::NAME
331 )))
332 }
333 })
334 }
335}
336
337struct DbHandle(Arc<Database>);
338
339impl Deref for DbHandle {
340 type Target = Database;
341
342 fn deref(&self) -> &Self::Target {
343 self.0.as_ref()
344 }
345}
346
/// The RPC server: owns the peer, the shared connection pool, and the table of
/// handlers that incoming messages are dispatched to.
pub struct Server {
    // Behind a mutex so the test-only `reset` can swap in a new id.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    // Keyed by the `TypeId` of the message type each handler accepts.
    handlers: HashMap<TypeId, MessageHandler>,
    // Set to `true` by `teardown`; every connection subscribes to it.
    teardown: watch::Sender<bool>,
}
355
/// Lock guard over the shared `ConnectionPool`, obtained via
/// `Session::connection_pool`.
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send`, so it cannot be
/// held across an `.await` inside a future that must be `Send`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
360
/// Serializable view of the server: its peer plus the connection pool.
///
/// `connection_pool` is a lock guard, so the pool stays locked for as long as
/// the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // The guard type does not derive `Serialize`; serialize the pool it derefs to.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
367
368pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
369where
370 S: Serializer,
371 T: Deref<Target = U>,
372 U: Serialize,
373{
374 Serialize::serialize(value.deref(), serializer)
375}
376
377impl Server {
    /// Creates a server identified by `id` and registers every RPC handler it
    /// serves.
    ///
    /// Handlers that only users may invoke are wrapped in `user_handler` /
    /// `user_message_handler`, and dev-server-only handlers in
    /// `dev_server_handler`. The remaining handlers receive the raw `Session`.
    /// Registering two handlers for the same message type panics.
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            // Only the sender is kept; connections subscribe in `handle_connection`.
            teardown: watch::channel(false).0,
        };

        server
            .add_request_handler(ping)
            // Rooms and calls.
            .add_request_handler(user_handler(create_room))
            .add_request_handler(user_handler(join_room))
            .add_request_handler(user_handler(rejoin_room))
            .add_request_handler(user_handler(leave_room))
            .add_request_handler(user_handler(set_room_participant_role))
            .add_request_handler(user_handler(call))
            .add_request_handler(user_handler(cancel_call))
            .add_message_handler(user_message_handler(decline_call))
            .add_request_handler(user_handler(update_participant_location))
            // Projects, remote projects and dev servers.
            .add_request_handler(user_handler(share_project))
            .add_message_handler(unshare_project)
            .add_request_handler(user_handler(join_project))
            .add_request_handler(user_handler(join_hosted_project))
            .add_request_handler(user_handler(rejoin_remote_projects))
            .add_request_handler(user_handler(create_remote_project))
            .add_request_handler(user_handler(create_dev_server))
            .add_request_handler(dev_server_handler(share_remote_project))
            .add_request_handler(dev_server_handler(shutdown_dev_server))
            .add_request_handler(dev_server_handler(reconnect_dev_server))
            .add_message_handler(user_message_handler(leave_project))
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Project requests relayed by the forwarding helpers: read-only first,
            // then mutating.
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetHover>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetTypeDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetReferences>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::SearchProject>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDocumentHighlights>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetProjectSymbols>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferForSymbol>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferById>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::SynchronizeBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::InlayHints>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferByPath>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCompletions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCodeActions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCodeAction>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PrepareRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PerformRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ReloadBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::FormatBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CreateProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RenameProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CopyProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::DeleteProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ExpandProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OnTypeFormatting>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::SaveBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::BlameBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::MultiLspQuery>,
            ))
            // Buffers and messages broadcast from the project host.
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            // Users and contacts.
            .add_request_handler(get_users)
            .add_request_handler(user_handler(fuzzy_search_users))
            .add_request_handler(user_handler(request_contact))
            .add_request_handler(user_handler(remove_contact))
            .add_request_handler(user_handler(respond_to_contact_request))
            // Channels, channel buffers and channel chat.
            .add_request_handler(user_handler(create_channel))
            .add_request_handler(user_handler(delete_channel))
            .add_request_handler(user_handler(invite_channel_member))
            .add_request_handler(user_handler(remove_channel_member))
            .add_request_handler(user_handler(set_channel_member_role))
            .add_request_handler(user_handler(set_channel_visibility))
            .add_request_handler(user_handler(rename_channel))
            .add_request_handler(user_handler(join_channel_buffer))
            .add_request_handler(user_handler(leave_channel_buffer))
            .add_message_handler(user_message_handler(update_channel_buffer))
            .add_request_handler(user_handler(rejoin_channel_buffers))
            .add_request_handler(user_handler(get_channel_members))
            .add_request_handler(user_handler(respond_to_channel_invite))
            .add_request_handler(user_handler(join_channel))
            .add_request_handler(user_handler(join_channel_chat))
            .add_message_handler(user_message_handler(leave_channel_chat))
            .add_request_handler(user_handler(send_channel_message))
            .add_request_handler(user_handler(remove_channel_message))
            .add_request_handler(user_handler(update_channel_message))
            .add_request_handler(user_handler(get_channel_messages))
            .add_request_handler(user_handler(get_channel_messages_by_id))
            .add_request_handler(user_handler(get_notifications))
            .add_request_handler(user_handler(mark_notification_as_read))
            .add_request_handler(user_handler(move_channel))
            // Following.
            .add_request_handler(user_handler(follow))
            .add_message_handler(user_message_handler(unfollow))
            .add_message_handler(user_message_handler(update_followers))
            .add_request_handler(user_handler(get_private_user_info))
            .add_message_handler(user_message_handler(acknowledge_channel_message))
            .add_message_handler(user_message_handler(acknowledge_buffer_version))
            // Language models: the closures capture `app_state` and clone the API
            // keys out of its config on every call.
            .add_streaming_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    complete_with_language_model(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                        app_state.config.google_ai_api_key.clone(),
                        app_state.config.anthropic_api_key.clone(),
                    )
                }
            })
            .add_request_handler({
                let app_state = app_state.clone();
                user_handler(move |request, response, session| {
                    count_tokens_with_language_model(
                        request,
                        response,
                        session,
                        app_state.config.google_ai_api_key.clone(),
                    )
                })
            });

        Arc::new(server)
    }
575
    /// Schedules background cleanup of resources left behind by stale servers
    /// in this environment, then returns `Ok(())` immediately.
    ///
    /// The detached task waits `CLEANUP_TIMEOUT`, then asks the database for
    /// stale rooms and channel buffers. It clears the stale
    /// participants/collaborators, notifies the affected connections, and
    /// finally deletes the stale server records. Individual failures are passed
    /// through `trace_err` (which presumably logs them) instead of aborting the
    /// cleanup.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Drop stale collaborators from each channel buffer and tell the
                    // remaining connections about the new collaborator list.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // These stay empty/false unless the room is refreshed successfully.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            // The LiveKit room is only deleted once nobody is left in it.
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the pool lock is released before the `.await`s below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Re-send each affected user's contact entry (with its refreshed
                        // busy status) to all of their accepted contacts' connections.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
721
722 pub fn teardown(&self) {
723 self.peer.teardown();
724 self.connection_pool.lock().reset();
725 let _ = self.teardown.send(true);
726 }
727
728 #[cfg(test)]
729 pub fn reset(&self, id: ServerId) {
730 self.teardown();
731 *self.id.lock() = id;
732 self.peer.reset(id.0 as u32);
733 let _ = self.teardown.send(false);
734 }
735
736 #[cfg(test)]
737 pub fn id(&self) -> ServerId {
738 *self.id.lock()
739 }
740
741 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
742 where
743 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
744 Fut: 'static + Send + Future<Output = Result<()>>,
745 M: EnvelopedMessage,
746 {
747 let prev_handler = self.handlers.insert(
748 TypeId::of::<M>(),
749 Box::new(move |envelope, session| {
750 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
751 let received_at = envelope.received_at;
752 tracing::info!(
753 "message received"
754 );
755 let start_time = Instant::now();
756 let future = (handler)(*envelope, session);
757 async move {
758 let result = future.await;
759 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
760 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
761 let queue_duration_ms = total_duration_ms - processing_duration_ms;
762 let payload_type = M::NAME;
763 match result {
764 Err(error) => {
765 // todo!(), why isn't this logged inside the span?
766 tracing::error!(%error, total_duration_ms, processing_duration_ms, queue_duration_ms, payload_type, "error handling message")
767 }
768 Ok(()) => tracing::info!(total_duration_ms, processing_duration_ms, queue_duration_ms, "finished handling message"),
769 }
770 }
771 .boxed()
772 }),
773 );
774 if prev_handler.is_some() {
775 panic!("registered a handler for the same message twice");
776 }
777 self
778 }
779
780 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
781 where
782 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
783 Fut: 'static + Send + Future<Output = Result<()>>,
784 M: EnvelopedMessage,
785 {
786 self.add_handler(move |envelope, session| handler(envelope.payload, session));
787 self
788 }
789
790 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
791 where
792 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
793 Fut: Send + Future<Output = Result<()>>,
794 M: RequestMessage,
795 {
796 let handler = Arc::new(handler);
797 self.add_handler(move |envelope, session| {
798 let receipt = envelope.receipt();
799 let handler = handler.clone();
800 async move {
801 let peer = session.peer.clone();
802 let responded = Arc::new(AtomicBool::default());
803 let response = Response {
804 peer: peer.clone(),
805 responded: responded.clone(),
806 receipt,
807 };
808 match (handler)(envelope.payload, response, session).await {
809 Ok(()) => {
810 if responded.load(std::sync::atomic::Ordering::SeqCst) {
811 Ok(())
812 } else {
813 Err(anyhow!("handler did not send a response"))?
814 }
815 }
816 Err(error) => {
817 let proto_err = match &error {
818 Error::Internal(err) => err.to_proto(),
819 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
820 };
821 peer.respond_with_error(receipt, proto_err)?;
822 Err(error)
823 }
824 }
825 }
826 })
827 }
828
829 fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
830 where
831 F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
832 Fut: Send + Future<Output = Result<()>>,
833 M: RequestMessage,
834 {
835 let handler = Arc::new(handler);
836 self.add_handler(move |envelope, session| {
837 let receipt = envelope.receipt();
838 let handler = handler.clone();
839 async move {
840 let peer = session.peer.clone();
841 let response = StreamingResponse {
842 peer: peer.clone(),
843 receipt,
844 };
845 match (handler)(envelope.payload, response, session).await {
846 Ok(()) => {
847 peer.end_stream(receipt)?;
848 Ok(())
849 }
850 Err(error) => {
851 let proto_err = match &error {
852 Error::Internal(err) => err.to_proto(),
853 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
854 };
855 peer.respond_with_error(receipt, proto_err)?;
856 Err(error)
857 }
858 }
859 }
860 })
861 }
862
    /// Returns a future that serves one client connection from registration to
    /// sign-out.
    ///
    /// The future registers the connection with the peer, sends the initial
    /// client state, then dispatches incoming messages to the registered
    /// handlers. It stops when the I/O future completes, the incoming stream
    /// ends, or the server is torn down. Everything runs inside a
    /// "handle connection" span carrying the principal's identity.
    #[allow(clippy::too_many_arguments)]
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> {
        let this = self.clone();
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            dev_server_id=field::Empty
        );
        principal.update_span(&span);

        let mut teardown = self.teardown.subscribe();
        async move {
            // Refuse the connection outright if the server is already tearing down.
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));
            tracing::info!("connection opened");

            let http_client = match IsahcHttpClient::new() {
                Ok(http_client) => http_client,
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                http_client,
                rate_limiter: this.app_state.rate_limiter.clone(),
                _executor: executor.clone(),
            };

            // NOTE(review): returning here skips `connection_lost`, although
            // `send_initial_client_update` may already have added this connection
            // to the pool before a later step failed — confirm it is cleaned up
            // elsewhere.
            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Caps in-flight handlers at 256: a permit is taken *before* reading the
            // next message and released when that message's handler finishes (or
            // immediately if no handler exists), so reading pauses while all slots
            // are busy.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Branches are polled in order: teardown, I/O, in-flight foreground
                // handlers, then the next incoming message.
                futures::select_biased! {
                    // On teardown, exit without running `connection_lost`.
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                dev_server_id=field::Empty
                            );
                            principal.update_span(&span);
                            // The handler's synchronous prologue runs while the span is
                            // entered; its async part is instrumented with the span below.
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background handlers run detached on the executor;
                                // foreground ones are driven by this loop.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers still in flight.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
999
    /// Performs the initial handshake with a newly accepted connection.
    ///
    /// Sends `proto::Hello` carrying the connection's peer id, forwards the
    /// connection id through `send_connection_id` when one is supplied, and then
    /// initializes per-principal state:
    ///
    /// - Users (including impersonated users): on the very first connection,
    ///   sends `ShowContacts` and records `connected_once`; registers the
    ///   connection and its channel subscriptions in the pool; sends the initial
    ///   contacts, user-channel and channel updates; forwards any pending
    ///   incoming call; and refreshes the user's contacts.
    /// - Dev servers: rejects the connection with
    ///   `ErrorCode::DevServerAlreadyOnline` when the pool already holds a
    ///   connection for this dev server; otherwise registers it, marks the dev
    ///   server online, and sends its remote projects as `DevServerInstructions`.
    ///
    /// # Errors
    ///
    /// Propagates failures from sending messages and from database queries.
    async fn send_initial_client_update(
        &self,
        connection_id: ConnectionId,
        principal: &Principal,
        zed_version: ZedVersion,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        session: &Session,
    ) -> Result<()> {
        self.peer.send(
            connection_id,
            proto::Hello {
                peer_id: Some(connection_id.into()),
            },
        )?;
        tracing::info!("sent hello message");
        // The result of the hand-off is deliberately ignored (e.g. if the receiver is gone).
        if let Some(send_connection_id) = send_connection_id.take() {
            let _ = send_connection_id.send(connection_id);
        }

        match principal {
            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
                // First-ever connection: prompt the client to show contacts and
                // persist that this user has now connected once.
                if !user.connected_once {
                    self.peer.send(connection_id, proto::ShowContacts {})?;
                    self.app_state
                        .db
                        .set_user_connected_once(user.id, true)
                        .await?;
                }

                // Contacts, channels and channel invites are fetched concurrently.
                let (contacts, channels_for_user, channel_invites) = future::try_join3(
                    self.app_state.db.get_contacts(user.id),
                    self.app_state.db.get_channels_for_user(user.id),
                    self.app_state.db.get_channel_invites_for_user(user.id),
                )
                .await?;

                // The connection is added to the pool before the initial updates are
                // built from it, and the pool lock is held across both steps.
                {
                    let mut pool = self.connection_pool.lock();
                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
                    for membership in &channels_for_user.channel_memberships {
                        pool.subscribe_to_channel(user.id, membership.channel_id, membership.role)
                    }
                    self.peer.send(
                        connection_id,
                        build_initial_contacts_update(contacts, &pool),
                    )?;
                    self.peer.send(
                        connection_id,
                        build_update_user_channels(&channels_for_user),
                    )?;
                    self.peer.send(
                        connection_id,
                        build_channels_update(channels_for_user, channel_invites, &pool),
                    )?;
                }

                // Re-deliver a call that was ringing before this connection existed.
                if let Some(incoming_call) =
                    self.app_state.db.incoming_call_for_user(user.id).await?
                {
                    self.peer.send(connection_id, incoming_call)?;
                }

                update_user_contacts(user.id, &session).await?;
            }
            Principal::DevServer(dev_server) => {
                {
                    let mut pool = self.connection_pool.lock();
                    // Only one live connection per dev server is accepted.
                    if pool.dev_server_connection_id(dev_server.id).is_some() {
                        return Err(anyhow!(ErrorCode::DevServerAlreadyOnline))?;
                    };
                    pool.add_dev_server(connection_id, dev_server.id, zed_version);
                }
                update_dev_server_status(dev_server, proto::DevServerStatus::Online, &session)
                    .await;
                // NOTE(review): the one-connection-per-dev-server rule is enforced by the
                // pool check above.

                let projects = self
                    .app_state
                    .db
                    .get_remote_projects_for_dev_server(dev_server.id)
                    .await?;
                self.peer
                    .send(connection_id, proto::DevServerInstructions { projects })?;
            }
        }

        Ok(())
    }
1088
1089 pub async fn invite_code_redeemed(
1090 self: &Arc<Self>,
1091 inviter_id: UserId,
1092 invitee_id: UserId,
1093 ) -> Result<()> {
1094 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1095 if let Some(code) = &user.invite_code {
1096 let pool = self.connection_pool.lock();
1097 let invitee_contact = contact_for_user(invitee_id, false, &pool);
1098 for connection_id in pool.user_connection_ids(inviter_id) {
1099 self.peer.send(
1100 connection_id,
1101 proto::UpdateContacts {
1102 contacts: vec![invitee_contact.clone()],
1103 ..Default::default()
1104 },
1105 )?;
1106 self.peer.send(
1107 connection_id,
1108 proto::UpdateInviteInfo {
1109 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1110 count: user.invite_count as u32,
1111 },
1112 )?;
1113 }
1114 }
1115 }
1116 Ok(())
1117 }
1118
1119 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1120 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1121 if let Some(invite_code) = &user.invite_code {
1122 let pool = self.connection_pool.lock();
1123 for connection_id in pool.user_connection_ids(user_id) {
1124 self.peer.send(
1125 connection_id,
1126 proto::UpdateInviteInfo {
1127 url: format!(
1128 "{}{}",
1129 self.app_state.config.invite_link_prefix, invite_code
1130 ),
1131 count: user.invite_count as u32,
1132 },
1133 )?;
1134 }
1135 }
1136 }
1137 Ok(())
1138 }
1139
1140 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1141 ServerSnapshot {
1142 connection_pool: ConnectionPoolGuard {
1143 guard: self.connection_pool.lock(),
1144 _not_send: PhantomData,
1145 },
1146 peer: &self.peer,
1147 }
1148 }
1149}
1150
1151impl<'a> Deref for ConnectionPoolGuard<'a> {
1152 type Target = ConnectionPool;
1153
1154 fn deref(&self) -> &Self::Target {
1155 &self.guard
1156 }
1157}
1158
1159impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1160 fn deref_mut(&mut self) -> &mut Self::Target {
1161 &mut self.guard
1162 }
1163}
1164
/// In test builds, verifies the connection pool's invariants whenever a guard is
/// dropped. `drop` runs before the `guard` field itself is dropped, so the check
/// happens while the pool is still locked.
impl<'a> Drop for ConnectionPoolGuard<'a> {
    fn drop(&mut self) {
        // Compiled only under `cfg(test)`; non-test builds do nothing here.
        #[cfg(test)]
        self.check_invariants();
    }
}
1171
1172fn broadcast<F>(
1173 sender_id: Option<ConnectionId>,
1174 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1175 mut f: F,
1176) where
1177 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1178{
1179 for receiver_id in receiver_ids {
1180 if Some(receiver_id) != sender_id {
1181 if let Err(error) = f(receiver_id) {
1182 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1183 }
1184 }
1185 }
1186}
1187
1188pub struct ProtocolVersion(u32);
1189
1190impl Header for ProtocolVersion {
1191 fn name() -> &'static HeaderName {
1192 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1193 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1194 }
1195
1196 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1197 where
1198 Self: Sized,
1199 I: Iterator<Item = &'i axum::http::HeaderValue>,
1200 {
1201 let version = values
1202 .next()
1203 .ok_or_else(axum::headers::Error::invalid)?
1204 .to_str()
1205 .map_err(|_| axum::headers::Error::invalid())?
1206 .parse()
1207 .map_err(|_| axum::headers::Error::invalid())?;
1208 Ok(Self(version))
1209 }
1210
1211 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1212 values.extend([self.0.to_string().parse().unwrap()]);
1213 }
1214}
1215
1216pub struct AppVersionHeader(SemanticVersion);
1217impl Header for AppVersionHeader {
1218 fn name() -> &'static HeaderName {
1219 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1220 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1221 }
1222
1223 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1224 where
1225 Self: Sized,
1226 I: Iterator<Item = &'i axum::http::HeaderValue>,
1227 {
1228 let version = values
1229 .next()
1230 .ok_or_else(axum::headers::Error::invalid)?
1231 .to_str()
1232 .map_err(|_| axum::headers::Error::invalid())?
1233 .parse()
1234 .map_err(|_| axum::headers::Error::invalid())?;
1235 Ok(Self(version))
1236 }
1237
1238 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1239 values.extend([self.0.to_string().parse().unwrap()]);
1240 }
1241}
1242
/// Builds the collab server's HTTP routes.
///
/// - `/rpc` upgrades to the websocket RPC connection. It is wrapped in a layer
///   stack that provides `AppState` as an extension and runs
///   `auth::validate_header` as middleware.
/// - `/metrics` serves Prometheus metrics. It is registered after that layer.
///   Axum's `Router::layer` only wraps routes that were already added, so this
///   route presumably bypasses header validation (confirm before relying on it).
///
/// The final layer attaches the `Server` as an `Extension`.
pub fn routes(server: Arc<Server>) -> Router<(), Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
1254
1255pub async fn handle_websocket_request(
1256 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1257 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1258 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1259 Extension(server): Extension<Arc<Server>>,
1260 Extension(principal): Extension<Principal>,
1261 ws: WebSocketUpgrade,
1262) -> axum::response::Response {
1263 if protocol_version != rpc::PROTOCOL_VERSION {
1264 return (
1265 StatusCode::UPGRADE_REQUIRED,
1266 "client must be upgraded".to_string(),
1267 )
1268 .into_response();
1269 }
1270
1271 let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1272 return (
1273 StatusCode::UPGRADE_REQUIRED,
1274 "no version header found".to_string(),
1275 )
1276 .into_response();
1277 };
1278
1279 if !version.can_collaborate() {
1280 return (
1281 StatusCode::UPGRADE_REQUIRED,
1282 "client must be upgraded".to_string(),
1283 )
1284 .into_response();
1285 }
1286
1287 let socket_address = socket_address.to_string();
1288 ws.on_upgrade(move |socket| {
1289 let socket = socket
1290 .map_ok(to_tungstenite_message)
1291 .err_into()
1292 .with(|message| async move { Ok(to_axum_message(message)) });
1293 let connection = Connection::new(Box::pin(socket));
1294 async move {
1295 server
1296 .handle_connection(
1297 connection,
1298 socket_address,
1299 principal,
1300 version,
1301 None,
1302 Executor::Production,
1303 )
1304 .await;
1305 }
1306 })
1307}
1308
1309pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1310 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1311 let connections_metric = CONNECTIONS_METRIC
1312 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1313
1314 let connections = server
1315 .connection_pool
1316 .lock()
1317 .connections()
1318 .filter(|connection| !connection.admin)
1319 .count();
1320 connections_metric.set(connections as _);
1321
1322 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1323 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1324 register_int_gauge!(
1325 "shared_projects",
1326 "number of open projects with one or more guests"
1327 )
1328 .unwrap()
1329 });
1330
1331 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1332 shared_projects_metric.set(shared_projects as _);
1333
1334 let encoder = prometheus::TextEncoder::new();
1335 let metric_families = prometheus::gather();
1336 let encoded_metrics = encoder
1337 .encode_to_string(&metric_families)
1338 .map_err(|err| anyhow!("{}", err))?;
1339 Ok(encoded_metrics)
1340}
1341
/// Cleans up after a connection's message loop has ended.
///
/// Immediately:
/// - disconnects the peer;
/// - removes the connection from the pool (propagating any error);
/// - reports the lost connection to the database (errors only passed to `trace_err`).
///
/// It then waits for whichever of two events comes first:
///
/// - `RECONNECT_TIMEOUT` elapses: the session's remaining resources are released.
///   - Users: the room and channel buffers are left, and a pending call is
///     declined when the user has no other connection online. Contacts are
///     then refreshed.
///   - Dev servers: `lost_dev_server_connection` runs and the status is set to
///     `Offline`.
/// - `teardown` changes: nothing further is done.
///
/// `select_biased!` polls the timeout branch first when both are ready.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // Best effort: a database failure here is logged rather than aborting cleanup.
    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        // Grace period elapsed without a teardown signal: release everything.
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            match &session.principal {
                Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
                    let session = session.for_user().unwrap();

                    log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
                    leave_room_for_session(&session, session.connection_id).await.trace_err();
                    leave_channel_buffers_for_session(&session)
                        .await
                        .trace_err();

                    // Only decline a ringing call if no other connection of this user remains.
                    if !session
                        .connection_pool()
                        .await
                        .is_user_online(session.user_id())
                    {
                        let db = session.db().await;
                        if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                            room_updated(&room, &session.peer);
                        }
                    }

                    update_user_contacts(session.user_id(), &session).await?;
                },
                Principal::DevServer(dev_server) => {
                    lost_dev_server_connection(&session).await?;
                    update_dev_server_status(&dev_server, proto::DevServerStatus::Offline, &session)
                        .await;
                },
            }
        },
        // Teardown requested: skip the per-session cleanup.
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1398
1399/// Acknowledges a ping from a client, used to keep the connection alive.
1400async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1401 response.send(proto::Ack {})?;
1402 Ok(())
1403}
1404
1405/// Creates a new room for calling (outside of channels)
1406async fn create_room(
1407 _request: proto::CreateRoom,
1408 response: Response<proto::CreateRoom>,
1409 session: UserSession,
1410) -> Result<()> {
1411 let live_kit_room = nanoid::nanoid!(30);
1412
1413 let live_kit_connection_info = util::maybe!(async {
1414 let live_kit = session.live_kit_client.as_ref();
1415 let live_kit = live_kit?;
1416 let user_id = session.user_id().to_string();
1417
1418 let token = live_kit
1419 .room_token(&live_kit_room, &user_id.to_string())
1420 .trace_err()?;
1421
1422 Some(proto::LiveKitConnectionInfo {
1423 server_url: live_kit.url().into(),
1424 token,
1425 can_publish: true,
1426 })
1427 })
1428 .await;
1429
1430 let room = session
1431 .db()
1432 .await
1433 .create_room(session.user_id(), session.connection_id, &live_kit_room)
1434 .await?;
1435
1436 response.send(proto::CreateRoomResponse {
1437 room: Some(room.clone()),
1438 live_kit_connection_info,
1439 })?;
1440
1441 update_user_contacts(session.user_id(), &session).await?;
1442 Ok(())
1443}
1444
1445/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1446async fn join_room(
1447 request: proto::JoinRoom,
1448 response: Response<proto::JoinRoom>,
1449 session: UserSession,
1450) -> Result<()> {
1451 let room_id = RoomId::from_proto(request.id);
1452
1453 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1454
1455 if let Some(channel_id) = channel_id {
1456 return join_channel_internal(channel_id, Box::new(response), session).await;
1457 }
1458
1459 let joined_room = {
1460 let room = session
1461 .db()
1462 .await
1463 .join_room(room_id, session.user_id(), session.connection_id)
1464 .await?;
1465 room_updated(&room.room, &session.peer);
1466 room.into_inner()
1467 };
1468
1469 for connection_id in session
1470 .connection_pool()
1471 .await
1472 .user_connection_ids(session.user_id())
1473 {
1474 session
1475 .peer
1476 .send(
1477 connection_id,
1478 proto::CallCanceled {
1479 room_id: room_id.to_proto(),
1480 },
1481 )
1482 .trace_err();
1483 }
1484
1485 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1486 if let Some(token) = live_kit
1487 .room_token(
1488 &joined_room.room.live_kit_room,
1489 &session.user_id().to_string(),
1490 )
1491 .trace_err()
1492 {
1493 Some(proto::LiveKitConnectionInfo {
1494 server_url: live_kit.url().into(),
1495 token,
1496 can_publish: true,
1497 })
1498 } else {
1499 None
1500 }
1501 } else {
1502 None
1503 };
1504
1505 response.send(proto::JoinRoomResponse {
1506 room: Some(joined_room.room),
1507 channel_id: None,
1508 live_kit_connection_info,
1509 })?;
1510
1511 update_user_contacts(session.user_id(), &session).await?;
1512 Ok(())
1513}
1514
/// Rejoins a room after the client's previous connection was lost.
///
/// The database moves this user's room membership onto the current connection.
/// While that result is held, the handler:
/// - replies with the room plus its reshared and rejoined projects;
/// - broadcasts the updated room;
/// - for each reshared project, tells every collaborator that
///   `old_connection_id` was replaced by this connection, then forwards the
///   project's worktree list to all of them except this connection;
/// - streams the full state of each rejoined project to this connection via
///   `notify_rejoined_projects`.
///
/// Once that result is consumed, the room's channel (if any) is broadcast and
/// the user's contacts are refreshed.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: UserSession,
) -> Result<()> {
    let room;
    let channel;
    // Scoped so the database result is consumed before the connection pool is
    // locked and contacts are updated below.
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell each collaborator that this connection replaced the old peer id.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Re-send the project's worktree list to every collaborator but ourselves.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1606
/// Brings this connection up to date on the projects it has rejoined.
///
/// First, every collaborator of each project is told that this connection
/// replaced `old_connection_id`. Failures there are passed to `trace_err`.
///
/// Then, per project, the following are sent to this connection:
/// - each worktree's entries, split with `proto::split_worktree_update` using
///   `MAX_CHUNK_SIZE` (2 in test builds, 256 otherwise);
/// - each worktree's diagnostic summaries and settings files;
/// - a `DiskBasedDiagnosticsUpdated` notice for every language server.
///
/// Each project's worktrees are moved out with `mem::take`, leaving the default
/// value behind, to avoid cloning them.
///
/// # Errors
///
/// Returns the first failure to send a message to this connection.
fn notify_rejoined_projects(
    rejoined_projects: &mut Vec<RejoinedProject>,
    session: &UserSession,
) -> Result<()> {
    for project in rejoined_projects.iter() {
        for collaborator in &project.collaborators {
            session
                .peer
                .send(
                    collaborator.connection_id,
                    proto::UpdateProjectCollaborator {
                        project_id: project.id.to_proto(),
                        old_peer_id: Some(project.old_connection_id.into()),
                        new_peer_id: Some(session.connection_id.into()),
                    },
                )
                .trace_err();
        }
    }

    for project in rejoined_projects {
        for worktree in mem::take(&mut project.worktrees) {
            // Small chunks in tests exercise the multi-message path.
            #[cfg(any(test, feature = "test-support"))]
            const MAX_CHUNK_SIZE: usize = 2;
            #[cfg(not(any(test, feature = "test-support")))]
            const MAX_CHUNK_SIZE: usize = 256;

            // Stream this worktree's entries.
            let message = proto::UpdateWorktree {
                project_id: project.id.to_proto(),
                worktree_id: worktree.id,
                abs_path: worktree.abs_path.clone(),
                root_name: worktree.root_name,
                updated_entries: worktree.updated_entries,
                removed_entries: worktree.removed_entries,
                scan_id: worktree.scan_id,
                is_last_update: worktree.completed_scan_id == worktree.scan_id,
                updated_repositories: worktree.updated_repositories,
                removed_repositories: worktree.removed_repositories,
            };
            for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
                session.peer.send(session.connection_id, update.clone())?;
            }

            // Stream this worktree's diagnostics.
            for summary in worktree.diagnostic_summaries {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateDiagnosticSummary {
                        project_id: project.id.to_proto(),
                        worktree_id: worktree.id,
                        summary: Some(summary),
                    },
                )?;
            }

            // Stream this worktree's settings files.
            for settings_file in worktree.settings_files {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateWorktreeSettings {
                        project_id: project.id.to_proto(),
                        worktree_id: worktree.id,
                        path: settings_file.path,
                        content: Some(settings_file.content),
                    },
                )?;
            }
        }

        for language_server in &project.language_servers {
            session.peer.send(
                session.connection_id,
                proto::UpdateLanguageServer {
                    project_id: project.id.to_proto(),
                    language_server_id: language_server.id,
                    variant: Some(
                        proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
                            proto::LspDiskBasedDiagnosticsUpdated {},
                        ),
                    ),
                },
            )?;
        }
    }
    Ok(())
}
1693
1694/// leave room disconnects from the room.
1695async fn leave_room(
1696 _: proto::LeaveRoom,
1697 response: Response<proto::LeaveRoom>,
1698 session: UserSession,
1699) -> Result<()> {
1700 leave_room_for_session(&session, session.connection_id).await?;
1701 response.send(proto::Ack {})?;
1702 Ok(())
1703}
1704
1705/// Updates the permissions of someone else in the room.
1706async fn set_room_participant_role(
1707 request: proto::SetRoomParticipantRole,
1708 response: Response<proto::SetRoomParticipantRole>,
1709 session: UserSession,
1710) -> Result<()> {
1711 let user_id = UserId::from_proto(request.user_id);
1712 let role = ChannelRole::from(request.role());
1713
1714 let (live_kit_room, can_publish) = {
1715 let room = session
1716 .db()
1717 .await
1718 .set_room_participant_role(
1719 session.user_id(),
1720 RoomId::from_proto(request.room_id),
1721 user_id,
1722 role,
1723 )
1724 .await?;
1725
1726 let live_kit_room = room.live_kit_room.clone();
1727 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1728 room_updated(&room, &session.peer);
1729 (live_kit_room, can_publish)
1730 };
1731
1732 if let Some(live_kit) = session.live_kit_client.as_ref() {
1733 live_kit
1734 .update_participant(
1735 live_kit_room.clone(),
1736 request.user_id.to_string(),
1737 live_kit_server::proto::ParticipantPermission {
1738 can_subscribe: true,
1739 can_publish,
1740 can_publish_data: can_publish,
1741 hidden: false,
1742 recorder: false,
1743 },
1744 )
1745 .await
1746 .trace_err();
1747 }
1748
1749 response.send(proto::Ack {})?;
1750 Ok(())
1751}
1752
1753/// Call someone else into the current room
1754async fn call(
1755 request: proto::Call,
1756 response: Response<proto::Call>,
1757 session: UserSession,
1758) -> Result<()> {
1759 let room_id = RoomId::from_proto(request.room_id);
1760 let calling_user_id = session.user_id();
1761 let calling_connection_id = session.connection_id;
1762 let called_user_id = UserId::from_proto(request.called_user_id);
1763 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1764 if !session
1765 .db()
1766 .await
1767 .has_contact(calling_user_id, called_user_id)
1768 .await?
1769 {
1770 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1771 }
1772
1773 let incoming_call = {
1774 let (room, incoming_call) = &mut *session
1775 .db()
1776 .await
1777 .call(
1778 room_id,
1779 calling_user_id,
1780 calling_connection_id,
1781 called_user_id,
1782 initial_project_id,
1783 )
1784 .await?;
1785 room_updated(&room, &session.peer);
1786 mem::take(incoming_call)
1787 };
1788 update_user_contacts(called_user_id, &session).await?;
1789
1790 let mut calls = session
1791 .connection_pool()
1792 .await
1793 .user_connection_ids(called_user_id)
1794 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1795 .collect::<FuturesUnordered<_>>();
1796
1797 while let Some(call_response) = calls.next().await {
1798 match call_response.as_ref() {
1799 Ok(_) => {
1800 response.send(proto::Ack {})?;
1801 return Ok(());
1802 }
1803 Err(_) => {
1804 call_response.trace_err();
1805 }
1806 }
1807 }
1808
1809 {
1810 let room = session
1811 .db()
1812 .await
1813 .call_failed(room_id, called_user_id)
1814 .await?;
1815 room_updated(&room, &session.peer);
1816 }
1817 update_user_contacts(called_user_id, &session).await?;
1818
1819 Err(anyhow!("failed to ring user"))?
1820}
1821
1822/// Cancel an outgoing call.
1823async fn cancel_call(
1824 request: proto::CancelCall,
1825 response: Response<proto::CancelCall>,
1826 session: UserSession,
1827) -> Result<()> {
1828 let called_user_id = UserId::from_proto(request.called_user_id);
1829 let room_id = RoomId::from_proto(request.room_id);
1830 {
1831 let room = session
1832 .db()
1833 .await
1834 .cancel_call(room_id, session.connection_id, called_user_id)
1835 .await?;
1836 room_updated(&room, &session.peer);
1837 }
1838
1839 for connection_id in session
1840 .connection_pool()
1841 .await
1842 .user_connection_ids(called_user_id)
1843 {
1844 session
1845 .peer
1846 .send(
1847 connection_id,
1848 proto::CallCanceled {
1849 room_id: room_id.to_proto(),
1850 },
1851 )
1852 .trace_err();
1853 }
1854 response.send(proto::Ack {})?;
1855
1856 update_user_contacts(called_user_id, &session).await?;
1857 Ok(())
1858}
1859
1860/// Decline an incoming call.
1861async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1862 let room_id = RoomId::from_proto(message.room_id);
1863 {
1864 let room = session
1865 .db()
1866 .await
1867 .decline_call(Some(room_id), session.user_id())
1868 .await?
1869 .ok_or_else(|| anyhow!("failed to decline call"))?;
1870 room_updated(&room, &session.peer);
1871 }
1872
1873 for connection_id in session
1874 .connection_pool()
1875 .await
1876 .user_connection_ids(session.user_id())
1877 {
1878 session
1879 .peer
1880 .send(
1881 connection_id,
1882 proto::CallCanceled {
1883 room_id: room_id.to_proto(),
1884 },
1885 )
1886 .trace_err();
1887 }
1888 update_user_contacts(session.user_id(), &session).await?;
1889 Ok(())
1890}
1891
1892/// Updates other participants in the room with your current location.
1893async fn update_participant_location(
1894 request: proto::UpdateParticipantLocation,
1895 response: Response<proto::UpdateParticipantLocation>,
1896 session: UserSession,
1897) -> Result<()> {
1898 let room_id = RoomId::from_proto(request.room_id);
1899 let location = request
1900 .location
1901 .ok_or_else(|| anyhow!("invalid location"))?;
1902
1903 let db = session.db().await;
1904 let room = db
1905 .update_room_participant_location(room_id, session.connection_id, location)
1906 .await?;
1907
1908 room_updated(&room, &session.peer);
1909 response.send(proto::Ack {})?;
1910 Ok(())
1911}
1912
1913/// Share a project into the room.
1914async fn share_project(
1915 request: proto::ShareProject,
1916 response: Response<proto::ShareProject>,
1917 session: UserSession,
1918) -> Result<()> {
1919 let (project_id, room) = &*session
1920 .db()
1921 .await
1922 .share_project(
1923 RoomId::from_proto(request.room_id),
1924 session.connection_id,
1925 &request.worktrees,
1926 )
1927 .await?;
1928 response.send(proto::ShareProjectResponse {
1929 project_id: project_id.to_proto(),
1930 })?;
1931 room_updated(&room, &session.peer);
1932
1933 Ok(())
1934}
1935
1936/// Unshare a project from the room.
1937async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1938 let project_id = ProjectId::from_proto(message.project_id);
1939 unshare_project_internal(project_id, &session).await
1940}
1941
1942async fn unshare_project_internal(project_id: ProjectId, session: &Session) -> Result<()> {
1943 let (room, guest_connection_ids) = &*session
1944 .db()
1945 .await
1946 .unshare_project(project_id, session.connection_id)
1947 .await?;
1948
1949 let message = proto::UnshareProject {
1950 project_id: project_id.to_proto(),
1951 };
1952
1953 broadcast(
1954 Some(session.connection_id),
1955 guest_connection_ids.iter().copied(),
1956 |conn_id| session.peer.send(conn_id, message.clone()),
1957 );
1958 if let Some(room) = room {
1959 room_updated(room, &session.peer);
1960 }
1961
1962 Ok(())
1963}
1964
1965/// Share a project into the room.
1966async fn share_remote_project(
1967 request: proto::ShareRemoteProject,
1968 response: Response<proto::ShareRemoteProject>,
1969 session: DevServerSession,
1970) -> Result<()> {
1971 let remote_project = session
1972 .db()
1973 .await
1974 .share_remote_project(
1975 RemoteProjectId::from_proto(request.remote_project_id),
1976 session.dev_server_id(),
1977 session.connection_id,
1978 &request.worktrees,
1979 )
1980 .await?;
1981 let Some(project_id) = remote_project.project_id else {
1982 return Err(anyhow!("failed to share remote project"))?;
1983 };
1984
1985 for (connection_id, _) in session
1986 .connection_pool()
1987 .await
1988 .channel_connection_ids(ChannelId::from_proto(remote_project.channel_id))
1989 {
1990 session
1991 .peer
1992 .send(
1993 connection_id,
1994 proto::UpdateChannels {
1995 remote_projects: vec![remote_project.clone()],
1996 ..Default::default()
1997 },
1998 )
1999 .trace_err();
2000 }
2001
2002 response.send(proto::ShareProjectResponse { project_id })?;
2003
2004 Ok(())
2005}
2006
/// Joins someone else's shared project.
///
/// The join is recorded in the database and the database handle is released.
/// The project is then streamed to the caller via `join_project_internal`.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: UserSession,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    // Release the database handle before streaming; the join result stays borrowed.
    drop(db);
    // NOTE(review): this log label says "remote" although this handler serves
    // regular shared projects; confirm the intended wording.
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
2025
/// Abstracts over request responses that reply with a `proto::JoinProjectResponse`,
/// so `join_project_internal` can serve both `JoinProject` and `JoinHostedProject`.
trait JoinProjectInternalResponse {
    /// Sends `result` as the reply to the original request.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Presumably resolves to `Response`'s inherent `send`, since inherent
        // items take precedence over this trait method.
        Response::<proto::JoinProject>::send(self, result)
    }
}
impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Same delegation as above, for hosted projects.
        Response::<proto::JoinHostedProject>::send(self, result)
    }
}
2039
/// Completes a project join for this session's connection.
///
/// The joiner is announced to every other collaborator via
/// `AddProjectCollaborator` (failures only logged). The caller is then sent:
/// - the `JoinProjectResponse` (worktree metadata, replica id, collaborators,
///   language servers and role);
/// - each worktree's entries, split with `proto::split_worktree_update` using
///   `MAX_CHUNK_SIZE` (2 in test builds, 256 otherwise);
/// - each worktree's diagnostic summaries and settings files;
/// - a `DiskBasedDiagnosticsUpdated` notice for every language server.
///
/// `project.worktrees` is moved out with `mem::take` while streaming.
///
/// # Errors
///
/// Returns the first failure to send the response or a message to this connection.
fn join_project_internal(
    response: impl JoinProjectInternalResponse,
    session: UserSession,
    project: &mut Project,
    replica_id: &ReplicaId,
) -> Result<()> {
    // Everyone already in the project except the joining connection itself.
    let collaborators = project
        .collaborators
        .iter()
        .filter(|collaborator| collaborator.connection_id != session.connection_id)
        .map(|collaborator| collaborator.to_proto())
        .collect::<Vec<_>>();
    let project_id = project.id;
    let guest_user_id = session.user_id();

    let worktrees = project
        .worktrees
        .iter()
        .map(|(id, worktree)| proto::WorktreeMetadata {
            id: *id,
            root_name: worktree.root_name.clone(),
            visible: worktree.visible,
            abs_path: worktree.abs_path.clone(),
        })
        .collect::<Vec<_>>();

    // Announce the new collaborator to everyone already in the project.
    for collaborator in &collaborators {
        session
            .peer
            .send(
                collaborator.peer_id.unwrap().into(),
                proto::AddProjectCollaborator {
                    project_id: project_id.to_proto(),
                    collaborator: Some(proto::Collaborator {
                        peer_id: Some(session.connection_id.into()),
                        replica_id: replica_id.0 as u32,
                        user_id: guest_user_id.to_proto(),
                    }),
                },
            )
            .trace_err();
    }

    // First, we send the metadata associated with each worktree.
    response.send(proto::JoinProjectResponse {
        project_id: project.id.0 as u64,
        worktrees: worktrees.clone(),
        replica_id: replica_id.0 as u32,
        collaborators: collaborators.clone(),
        language_servers: project.language_servers.clone(),
        role: project.role.into(), // todo
    })?;

    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
        // Small chunks in tests exercise the multi-message path.
        #[cfg(any(test, feature = "test-support"))]
        const MAX_CHUNK_SIZE: usize = 2;
        #[cfg(not(any(test, feature = "test-support")))]
        const MAX_CHUNK_SIZE: usize = 256;

        // Stream this worktree's entries.
        let message = proto::UpdateWorktree {
            project_id: project_id.to_proto(),
            worktree_id,
            abs_path: worktree.abs_path.clone(),
            root_name: worktree.root_name,
            updated_entries: worktree.entries,
            removed_entries: Default::default(),
            scan_id: worktree.scan_id,
            is_last_update: worktree.scan_id == worktree.completed_scan_id,
            updated_repositories: worktree.repository_entries.into_values().collect(),
            removed_repositories: Default::default(),
        };
        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
            session.peer.send(session.connection_id, update.clone())?;
        }

        // Stream this worktree's diagnostics.
        for summary in worktree.diagnostic_summaries {
            session.peer.send(
                session.connection_id,
                proto::UpdateDiagnosticSummary {
                    project_id: project_id.to_proto(),
                    worktree_id: worktree.id,
                    summary: Some(summary),
                },
            )?;
        }

        // Stream this worktree's settings files.
        for settings_file in worktree.settings_files {
            session.peer.send(
                session.connection_id,
                proto::UpdateWorktreeSettings {
                    project_id: project_id.to_proto(),
                    worktree_id: worktree.id,
                    path: settings_file.path,
                    content: Some(settings_file.content),
                },
            )?;
        }
    }

    for language_server in &project.language_servers {
        session.peer.send(
            session.connection_id,
            proto::UpdateLanguageServer {
                project_id: project_id.to_proto(),
                language_server_id: language_server.id,
                variant: Some(
                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
                        proto::LspDiskBasedDiagnosticsUpdated {},
                    ),
                ),
            },
        )?;
    }

    Ok(())
}
2158
2159/// Leave someone elses shared project.
2160async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2161 let sender_id = session.connection_id;
2162 let project_id = ProjectId::from_proto(request.project_id);
2163 let db = session.db().await;
2164 if db.is_hosted_project(project_id).await? {
2165 let project = db.leave_hosted_project(project_id, sender_id).await?;
2166 project_left(&project, &session);
2167 return Ok(());
2168 }
2169
2170 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2171 tracing::info!(
2172 %project_id,
2173 host_user_id = ?project.host_user_id,
2174 host_connection_id = ?project.host_connection_id,
2175 "leave project"
2176 );
2177
2178 project_left(&project, &session);
2179 if let Some(room) = room {
2180 room_updated(&room, &session.peer);
2181 }
2182
2183 Ok(())
2184}
2185
2186async fn join_hosted_project(
2187 request: proto::JoinHostedProject,
2188 response: Response<proto::JoinHostedProject>,
2189 session: UserSession,
2190) -> Result<()> {
2191 let (mut project, replica_id) = session
2192 .db()
2193 .await
2194 .join_hosted_project(
2195 ProjectId(request.project_id as i32),
2196 session.user_id(),
2197 session.connection_id,
2198 )
2199 .await?;
2200
2201 join_project_internal(response, session, &mut project, &replica_id)
2202}
2203
2204async fn create_remote_project(
2205 request: proto::CreateRemoteProject,
2206 response: Response<proto::CreateRemoteProject>,
2207 session: UserSession,
2208) -> Result<()> {
2209 let (channel, remote_project) = session
2210 .db()
2211 .await
2212 .create_remote_project(
2213 ChannelId(request.channel_id as i32),
2214 DevServerId(request.dev_server_id as i32),
2215 &request.name,
2216 &request.path,
2217 session.user_id(),
2218 )
2219 .await?;
2220
2221 let projects = session
2222 .db()
2223 .await
2224 .get_remote_projects_for_dev_server(remote_project.dev_server_id)
2225 .await?;
2226
2227 let update = proto::UpdateChannels {
2228 remote_projects: vec![remote_project.to_proto(None)],
2229 ..Default::default()
2230 };
2231 let connection_pool = session.connection_pool().await;
2232 for (connection_id, role) in connection_pool.channel_connection_ids(channel.root_id()) {
2233 if role.can_see_all_descendants() {
2234 session.peer.send(connection_id, update.clone())?;
2235 }
2236 }
2237
2238 let dev_server_id = remote_project.dev_server_id;
2239 let dev_server_connection_id = connection_pool.dev_server_connection_id(dev_server_id);
2240 if let Some(dev_server_connection_id) = dev_server_connection_id {
2241 session.peer.send(
2242 dev_server_connection_id,
2243 proto::DevServerInstructions { projects },
2244 )?;
2245 }
2246
2247 response.send(proto::CreateRemoteProjectResponse {
2248 remote_project: Some(remote_project.to_proto(None)),
2249 })?;
2250 Ok(())
2251}
2252
2253async fn create_dev_server(
2254 request: proto::CreateDevServer,
2255 response: Response<proto::CreateDevServer>,
2256 session: UserSession,
2257) -> Result<()> {
2258 let access_token = auth::random_token();
2259 let hashed_access_token = auth::hash_access_token(&access_token);
2260
2261 let (channel, dev_server) = session
2262 .db()
2263 .await
2264 .create_dev_server(
2265 ChannelId(request.channel_id as i32),
2266 &request.name,
2267 &hashed_access_token,
2268 session.user_id(),
2269 )
2270 .await?;
2271
2272 let update = proto::UpdateChannels {
2273 dev_servers: vec![dev_server.to_proto(proto::DevServerStatus::Offline)],
2274 ..Default::default()
2275 };
2276 let connection_pool = session.connection_pool().await;
2277 for (connection_id, role) in connection_pool.channel_connection_ids(channel.root_id()) {
2278 if role.can_see_channel(channel.visibility) {
2279 session.peer.send(connection_id, update.clone())?;
2280 }
2281 }
2282
2283 response.send(proto::CreateDevServerResponse {
2284 dev_server_id: dev_server.id.0 as u64,
2285 channel_id: request.channel_id,
2286 access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2287 name: request.name.clone(),
2288 })?;
2289 Ok(())
2290}
2291
2292async fn rejoin_remote_projects(
2293 request: proto::RejoinRemoteProjects,
2294 response: Response<proto::RejoinRemoteProjects>,
2295 session: UserSession,
2296) -> Result<()> {
2297 let mut rejoined_projects = {
2298 let db = session.db().await;
2299 db.rejoin_remote_projects(
2300 &request.rejoined_projects,
2301 session.user_id(),
2302 session.0.connection_id,
2303 )
2304 .await?
2305 };
2306 notify_rejoined_projects(&mut rejoined_projects, &session)?;
2307
2308 response.send(proto::RejoinRemoteProjectsResponse {
2309 rejoined_projects: rejoined_projects
2310 .into_iter()
2311 .map(|project| project.to_proto())
2312 .collect(),
2313 })
2314}
2315
2316async fn reconnect_dev_server(
2317 request: proto::ReconnectDevServer,
2318 response: Response<proto::ReconnectDevServer>,
2319 session: DevServerSession,
2320) -> Result<()> {
2321 let reshared_projects = {
2322 let db = session.db().await;
2323 db.reshare_remote_projects(
2324 &request.reshared_projects,
2325 session.dev_server_id(),
2326 session.0.connection_id,
2327 )
2328 .await?
2329 };
2330
2331 for project in &reshared_projects {
2332 for collaborator in &project.collaborators {
2333 session
2334 .peer
2335 .send(
2336 collaborator.connection_id,
2337 proto::UpdateProjectCollaborator {
2338 project_id: project.id.to_proto(),
2339 old_peer_id: Some(project.old_connection_id.into()),
2340 new_peer_id: Some(session.connection_id.into()),
2341 },
2342 )
2343 .trace_err();
2344 }
2345
2346 broadcast(
2347 Some(session.connection_id),
2348 project
2349 .collaborators
2350 .iter()
2351 .map(|collaborator| collaborator.connection_id),
2352 |connection_id| {
2353 session.peer.forward_send(
2354 session.connection_id,
2355 connection_id,
2356 proto::UpdateProject {
2357 project_id: project.id.to_proto(),
2358 worktrees: project.worktrees.clone(),
2359 },
2360 )
2361 },
2362 );
2363 }
2364
2365 response.send(proto::ReconnectDevServerResponse {
2366 reshared_projects: reshared_projects
2367 .iter()
2368 .map(|project| proto::ResharedProject {
2369 id: project.id.to_proto(),
2370 collaborators: project
2371 .collaborators
2372 .iter()
2373 .map(|collaborator| collaborator.to_proto())
2374 .collect(),
2375 })
2376 .collect(),
2377 })?;
2378
2379 Ok(())
2380}
2381
2382async fn shutdown_dev_server(
2383 _: proto::ShutdownDevServer,
2384 response: Response<proto::ShutdownDevServer>,
2385 session: DevServerSession,
2386) -> Result<()> {
2387 response.send(proto::Ack {})?;
2388 let (remote_projects, dev_server) = {
2389 let dev_server_id = session.dev_server_id();
2390 let db = session.db().await;
2391 let remote_projects = db.get_remote_projects_for_dev_server(dev_server_id).await?;
2392 let dev_server = db.get_dev_server(dev_server_id).await?;
2393 (remote_projects, dev_server)
2394 };
2395
2396 for project_id in remote_projects.iter().filter_map(|p| p.project_id) {
2397 unshare_project_internal(ProjectId::from_proto(project_id), &session.0).await?;
2398 }
2399
2400 let update = proto::UpdateChannels {
2401 remote_projects,
2402 dev_servers: vec![dev_server.to_proto(proto::DevServerStatus::Offline)],
2403 ..Default::default()
2404 };
2405
2406 for (connection_id, _) in session
2407 .connection_pool()
2408 .await
2409 .channel_connection_ids(dev_server.channel_id)
2410 {
2411 session.peer.send(connection_id, update.clone()).trace_err();
2412 }
2413
2414 Ok(())
2415}
2416
2417/// Updates other participants with changes to the project
2418async fn update_project(
2419 request: proto::UpdateProject,
2420 response: Response<proto::UpdateProject>,
2421 session: Session,
2422) -> Result<()> {
2423 let project_id = ProjectId::from_proto(request.project_id);
2424 let (room, guest_connection_ids) = &*session
2425 .db()
2426 .await
2427 .update_project(project_id, session.connection_id, &request.worktrees)
2428 .await?;
2429 broadcast(
2430 Some(session.connection_id),
2431 guest_connection_ids.iter().copied(),
2432 |connection_id| {
2433 session
2434 .peer
2435 .forward_send(session.connection_id, connection_id, request.clone())
2436 },
2437 );
2438 if let Some(room) = room {
2439 room_updated(&room, &session.peer);
2440 }
2441 response.send(proto::Ack {})?;
2442
2443 Ok(())
2444}
2445
2446/// Updates other participants with changes to the worktree
2447async fn update_worktree(
2448 request: proto::UpdateWorktree,
2449 response: Response<proto::UpdateWorktree>,
2450 session: Session,
2451) -> Result<()> {
2452 let guest_connection_ids = session
2453 .db()
2454 .await
2455 .update_worktree(&request, session.connection_id)
2456 .await?;
2457
2458 broadcast(
2459 Some(session.connection_id),
2460 guest_connection_ids.iter().copied(),
2461 |connection_id| {
2462 session
2463 .peer
2464 .forward_send(session.connection_id, connection_id, request.clone())
2465 },
2466 );
2467 response.send(proto::Ack {})?;
2468 Ok(())
2469}
2470
2471/// Updates other participants with changes to the diagnostics
2472async fn update_diagnostic_summary(
2473 message: proto::UpdateDiagnosticSummary,
2474 session: Session,
2475) -> Result<()> {
2476 let guest_connection_ids = session
2477 .db()
2478 .await
2479 .update_diagnostic_summary(&message, session.connection_id)
2480 .await?;
2481
2482 broadcast(
2483 Some(session.connection_id),
2484 guest_connection_ids.iter().copied(),
2485 |connection_id| {
2486 session
2487 .peer
2488 .forward_send(session.connection_id, connection_id, message.clone())
2489 },
2490 );
2491
2492 Ok(())
2493}
2494
2495/// Updates other participants with changes to the worktree settings
2496async fn update_worktree_settings(
2497 message: proto::UpdateWorktreeSettings,
2498 session: Session,
2499) -> Result<()> {
2500 let guest_connection_ids = session
2501 .db()
2502 .await
2503 .update_worktree_settings(&message, session.connection_id)
2504 .await?;
2505
2506 broadcast(
2507 Some(session.connection_id),
2508 guest_connection_ids.iter().copied(),
2509 |connection_id| {
2510 session
2511 .peer
2512 .forward_send(session.connection_id, connection_id, message.clone())
2513 },
2514 );
2515
2516 Ok(())
2517}
2518
2519/// Notify other participants that a language server has started.
2520async fn start_language_server(
2521 request: proto::StartLanguageServer,
2522 session: Session,
2523) -> Result<()> {
2524 let guest_connection_ids = session
2525 .db()
2526 .await
2527 .start_language_server(&request, session.connection_id)
2528 .await?;
2529
2530 broadcast(
2531 Some(session.connection_id),
2532 guest_connection_ids.iter().copied(),
2533 |connection_id| {
2534 session
2535 .peer
2536 .forward_send(session.connection_id, connection_id, request.clone())
2537 },
2538 );
2539 Ok(())
2540}
2541
2542/// Notify other participants that a language server has changed.
2543async fn update_language_server(
2544 request: proto::UpdateLanguageServer,
2545 session: Session,
2546) -> Result<()> {
2547 let project_id = ProjectId::from_proto(request.project_id);
2548 let project_connection_ids = session
2549 .db()
2550 .await
2551 .project_connection_ids(project_id, session.connection_id, true)
2552 .await?;
2553 broadcast(
2554 Some(session.connection_id),
2555 project_connection_ids.iter().copied(),
2556 |connection_id| {
2557 session
2558 .peer
2559 .forward_send(session.connection_id, connection_id, request.clone())
2560 },
2561 );
2562 Ok(())
2563}
2564
2565/// forward a project request to the host. These requests should be read only
2566/// as guests are allowed to send them.
2567async fn forward_read_only_project_request<T>(
2568 request: T,
2569 response: Response<T>,
2570 session: UserSession,
2571) -> Result<()>
2572where
2573 T: EntityMessage + RequestMessage,
2574{
2575 let project_id = ProjectId::from_proto(request.remote_entity_id());
2576 let host_connection_id = session
2577 .db()
2578 .await
2579 .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2580 .await?;
2581 let payload = session
2582 .peer
2583 .forward_request(session.connection_id, host_connection_id, request)
2584 .await?;
2585 response.send(payload)?;
2586 Ok(())
2587}
2588
2589/// forward a project request to the host. These requests are disallowed
2590/// for guests.
2591async fn forward_mutating_project_request<T>(
2592 request: T,
2593 response: Response<T>,
2594 session: UserSession,
2595) -> Result<()>
2596where
2597 T: EntityMessage + RequestMessage,
2598{
2599 let project_id = ProjectId::from_proto(request.remote_entity_id());
2600 let host_connection_id = session
2601 .db()
2602 .await
2603 .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
2604 .await?;
2605 let payload = session
2606 .peer
2607 .forward_request(session.connection_id, host_connection_id, request)
2608 .await?;
2609 response.send(payload)?;
2610 Ok(())
2611}
2612
2613/// Notify other participants that a new buffer has been created
2614async fn create_buffer_for_peer(
2615 request: proto::CreateBufferForPeer,
2616 session: Session,
2617) -> Result<()> {
2618 session
2619 .db()
2620 .await
2621 .check_user_is_project_host(
2622 ProjectId::from_proto(request.project_id),
2623 session.connection_id,
2624 )
2625 .await?;
2626 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2627 session
2628 .peer
2629 .forward_send(session.connection_id, peer_id.into(), request)?;
2630 Ok(())
2631}
2632
2633/// Notify other participants that a buffer has been updated. This is
2634/// allowed for guests as long as the update is limited to selections.
2635async fn update_buffer(
2636 request: proto::UpdateBuffer,
2637 response: Response<proto::UpdateBuffer>,
2638 session: Session,
2639) -> Result<()> {
2640 let project_id = ProjectId::from_proto(request.project_id);
2641 let mut capability = Capability::ReadOnly;
2642
2643 for op in request.operations.iter() {
2644 match op.variant {
2645 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2646 Some(_) => capability = Capability::ReadWrite,
2647 }
2648 }
2649
2650 let host = {
2651 let guard = session
2652 .db()
2653 .await
2654 .connections_for_buffer_update(
2655 project_id,
2656 session.principal_id(),
2657 session.connection_id,
2658 capability,
2659 )
2660 .await?;
2661
2662 let (host, guests) = &*guard;
2663
2664 broadcast(
2665 Some(session.connection_id),
2666 guests.clone(),
2667 |connection_id| {
2668 session
2669 .peer
2670 .forward_send(session.connection_id, connection_id, request.clone())
2671 },
2672 );
2673
2674 *host
2675 };
2676
2677 if host != session.connection_id {
2678 session
2679 .peer
2680 .forward_request(session.connection_id, host, request.clone())
2681 .await?;
2682 }
2683
2684 response.send(proto::Ack {})?;
2685 Ok(())
2686}
2687
2688/// Notify other participants that a project has been updated.
2689async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2690 request: T,
2691 session: Session,
2692) -> Result<()> {
2693 let project_id = ProjectId::from_proto(request.remote_entity_id());
2694 let project_connection_ids = session
2695 .db()
2696 .await
2697 .project_connection_ids(project_id, session.connection_id, false)
2698 .await?;
2699
2700 broadcast(
2701 Some(session.connection_id),
2702 project_connection_ids.iter().copied(),
2703 |connection_id| {
2704 session
2705 .peer
2706 .forward_send(session.connection_id, connection_id, request.clone())
2707 },
2708 );
2709 Ok(())
2710}
2711
2712/// Start following another user in a call.
2713async fn follow(
2714 request: proto::Follow,
2715 response: Response<proto::Follow>,
2716 session: UserSession,
2717) -> Result<()> {
2718 let room_id = RoomId::from_proto(request.room_id);
2719 let project_id = request.project_id.map(ProjectId::from_proto);
2720 let leader_id = request
2721 .leader_id
2722 .ok_or_else(|| anyhow!("invalid leader id"))?
2723 .into();
2724 let follower_id = session.connection_id;
2725
2726 session
2727 .db()
2728 .await
2729 .check_room_participants(room_id, leader_id, session.connection_id)
2730 .await?;
2731
2732 let response_payload = session
2733 .peer
2734 .forward_request(session.connection_id, leader_id, request)
2735 .await?;
2736 response.send(response_payload)?;
2737
2738 if let Some(project_id) = project_id {
2739 let room = session
2740 .db()
2741 .await
2742 .follow(room_id, project_id, leader_id, follower_id)
2743 .await?;
2744 room_updated(&room, &session.peer);
2745 }
2746
2747 Ok(())
2748}
2749
2750/// Stop following another user in a call.
2751async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
2752 let room_id = RoomId::from_proto(request.room_id);
2753 let project_id = request.project_id.map(ProjectId::from_proto);
2754 let leader_id = request
2755 .leader_id
2756 .ok_or_else(|| anyhow!("invalid leader id"))?
2757 .into();
2758 let follower_id = session.connection_id;
2759
2760 session
2761 .db()
2762 .await
2763 .check_room_participants(room_id, leader_id, session.connection_id)
2764 .await?;
2765
2766 session
2767 .peer
2768 .forward_send(session.connection_id, leader_id, request)?;
2769
2770 if let Some(project_id) = project_id {
2771 let room = session
2772 .db()
2773 .await
2774 .unfollow(room_id, project_id, leader_id, follower_id)
2775 .await?;
2776 room_updated(&room, &session.peer);
2777 }
2778
2779 Ok(())
2780}
2781
2782/// Notify everyone following you of your current location.
2783async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
2784 let room_id = RoomId::from_proto(request.room_id);
2785 let database = session.db.lock().await;
2786
2787 let connection_ids = if let Some(project_id) = request.project_id {
2788 let project_id = ProjectId::from_proto(project_id);
2789 database
2790 .project_connection_ids(project_id, session.connection_id, true)
2791 .await?
2792 } else {
2793 database
2794 .room_connection_ids(room_id, session.connection_id)
2795 .await?
2796 };
2797
2798 // For now, don't send view update messages back to that view's current leader.
2799 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2800 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2801 _ => None,
2802 });
2803
2804 for connection_id in connection_ids.iter().cloned() {
2805 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2806 session
2807 .peer
2808 .forward_send(session.connection_id, connection_id, request.clone())?;
2809 }
2810 }
2811 Ok(())
2812}
2813
2814/// Get public data about users.
2815async fn get_users(
2816 request: proto::GetUsers,
2817 response: Response<proto::GetUsers>,
2818 session: Session,
2819) -> Result<()> {
2820 let user_ids = request
2821 .user_ids
2822 .into_iter()
2823 .map(UserId::from_proto)
2824 .collect();
2825 let users = session
2826 .db()
2827 .await
2828 .get_users_by_ids(user_ids)
2829 .await?
2830 .into_iter()
2831 .map(|user| proto::User {
2832 id: user.id.to_proto(),
2833 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2834 github_login: user.github_login,
2835 })
2836 .collect();
2837 response.send(proto::UsersResponse { users })?;
2838 Ok(())
2839}
2840
2841/// Search for users (to invite) buy Github login
2842async fn fuzzy_search_users(
2843 request: proto::FuzzySearchUsers,
2844 response: Response<proto::FuzzySearchUsers>,
2845 session: UserSession,
2846) -> Result<()> {
2847 let query = request.query;
2848 let users = match query.len() {
2849 0 => vec![],
2850 1 | 2 => session
2851 .db()
2852 .await
2853 .get_user_by_github_login(&query)
2854 .await?
2855 .into_iter()
2856 .collect(),
2857 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2858 };
2859 let users = users
2860 .into_iter()
2861 .filter(|user| user.id != session.user_id())
2862 .map(|user| proto::User {
2863 id: user.id.to_proto(),
2864 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2865 github_login: user.github_login,
2866 })
2867 .collect();
2868 response.send(proto::UsersResponse { users })?;
2869 Ok(())
2870}
2871
2872/// Send a contact request to another user.
2873async fn request_contact(
2874 request: proto::RequestContact,
2875 response: Response<proto::RequestContact>,
2876 session: UserSession,
2877) -> Result<()> {
2878 let requester_id = session.user_id();
2879 let responder_id = UserId::from_proto(request.responder_id);
2880 if requester_id == responder_id {
2881 return Err(anyhow!("cannot add yourself as a contact"))?;
2882 }
2883
2884 let notifications = session
2885 .db()
2886 .await
2887 .send_contact_request(requester_id, responder_id)
2888 .await?;
2889
2890 // Update outgoing contact requests of requester
2891 let mut update = proto::UpdateContacts::default();
2892 update.outgoing_requests.push(responder_id.to_proto());
2893 for connection_id in session
2894 .connection_pool()
2895 .await
2896 .user_connection_ids(requester_id)
2897 {
2898 session.peer.send(connection_id, update.clone())?;
2899 }
2900
2901 // Update incoming contact requests of responder
2902 let mut update = proto::UpdateContacts::default();
2903 update
2904 .incoming_requests
2905 .push(proto::IncomingContactRequest {
2906 requester_id: requester_id.to_proto(),
2907 });
2908 let connection_pool = session.connection_pool().await;
2909 for connection_id in connection_pool.user_connection_ids(responder_id) {
2910 session.peer.send(connection_id, update.clone())?;
2911 }
2912
2913 send_notifications(&connection_pool, &session.peer, notifications);
2914
2915 response.send(proto::Ack {})?;
2916 Ok(())
2917}
2918
2919/// Accept or decline a contact request
2920async fn respond_to_contact_request(
2921 request: proto::RespondToContactRequest,
2922 response: Response<proto::RespondToContactRequest>,
2923 session: UserSession,
2924) -> Result<()> {
2925 let responder_id = session.user_id();
2926 let requester_id = UserId::from_proto(request.requester_id);
2927 let db = session.db().await;
2928 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2929 db.dismiss_contact_notification(responder_id, requester_id)
2930 .await?;
2931 } else {
2932 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2933
2934 let notifications = db
2935 .respond_to_contact_request(responder_id, requester_id, accept)
2936 .await?;
2937 let requester_busy = db.is_user_busy(requester_id).await?;
2938 let responder_busy = db.is_user_busy(responder_id).await?;
2939
2940 let pool = session.connection_pool().await;
2941 // Update responder with new contact
2942 let mut update = proto::UpdateContacts::default();
2943 if accept {
2944 update
2945 .contacts
2946 .push(contact_for_user(requester_id, requester_busy, &pool));
2947 }
2948 update
2949 .remove_incoming_requests
2950 .push(requester_id.to_proto());
2951 for connection_id in pool.user_connection_ids(responder_id) {
2952 session.peer.send(connection_id, update.clone())?;
2953 }
2954
2955 // Update requester with new contact
2956 let mut update = proto::UpdateContacts::default();
2957 if accept {
2958 update
2959 .contacts
2960 .push(contact_for_user(responder_id, responder_busy, &pool));
2961 }
2962 update
2963 .remove_outgoing_requests
2964 .push(responder_id.to_proto());
2965
2966 for connection_id in pool.user_connection_ids(requester_id) {
2967 session.peer.send(connection_id, update.clone())?;
2968 }
2969
2970 send_notifications(&pool, &session.peer, notifications);
2971 }
2972
2973 response.send(proto::Ack {})?;
2974 Ok(())
2975}
2976
2977/// Remove a contact.
2978async fn remove_contact(
2979 request: proto::RemoveContact,
2980 response: Response<proto::RemoveContact>,
2981 session: UserSession,
2982) -> Result<()> {
2983 let requester_id = session.user_id();
2984 let responder_id = UserId::from_proto(request.user_id);
2985 let db = session.db().await;
2986 let (contact_accepted, deleted_notification_id) =
2987 db.remove_contact(requester_id, responder_id).await?;
2988
2989 let pool = session.connection_pool().await;
2990 // Update outgoing contact requests of requester
2991 let mut update = proto::UpdateContacts::default();
2992 if contact_accepted {
2993 update.remove_contacts.push(responder_id.to_proto());
2994 } else {
2995 update
2996 .remove_outgoing_requests
2997 .push(responder_id.to_proto());
2998 }
2999 for connection_id in pool.user_connection_ids(requester_id) {
3000 session.peer.send(connection_id, update.clone())?;
3001 }
3002
3003 // Update incoming contact requests of responder
3004 let mut update = proto::UpdateContacts::default();
3005 if contact_accepted {
3006 update.remove_contacts.push(requester_id.to_proto());
3007 } else {
3008 update
3009 .remove_incoming_requests
3010 .push(requester_id.to_proto());
3011 }
3012 for connection_id in pool.user_connection_ids(responder_id) {
3013 session.peer.send(connection_id, update.clone())?;
3014 if let Some(notification_id) = deleted_notification_id {
3015 session.peer.send(
3016 connection_id,
3017 proto::DeleteNotification {
3018 notification_id: notification_id.to_proto(),
3019 },
3020 )?;
3021 }
3022 }
3023
3024 response.send(proto::Ack {})?;
3025 Ok(())
3026}
3027
3028/// Creates a new channel.
3029async fn create_channel(
3030 request: proto::CreateChannel,
3031 response: Response<proto::CreateChannel>,
3032 session: UserSession,
3033) -> Result<()> {
3034 let db = session.db().await;
3035
3036 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
3037 let (channel, membership) = db
3038 .create_channel(&request.name, parent_id, session.user_id())
3039 .await?;
3040
3041 let root_id = channel.root_id();
3042 let channel = Channel::from_model(channel);
3043
3044 response.send(proto::CreateChannelResponse {
3045 channel: Some(channel.to_proto()),
3046 parent_id: request.parent_id,
3047 })?;
3048
3049 let mut connection_pool = session.connection_pool().await;
3050 if let Some(membership) = membership {
3051 connection_pool.subscribe_to_channel(
3052 membership.user_id,
3053 membership.channel_id,
3054 membership.role,
3055 );
3056 let update = proto::UpdateUserChannels {
3057 channel_memberships: vec![proto::ChannelMembership {
3058 channel_id: membership.channel_id.to_proto(),
3059 role: membership.role.into(),
3060 }],
3061 ..Default::default()
3062 };
3063 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3064 session.peer.send(connection_id, update.clone())?;
3065 }
3066 }
3067
3068 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3069 if !role.can_see_channel(channel.visibility) {
3070 continue;
3071 }
3072
3073 let update = proto::UpdateChannels {
3074 channels: vec![channel.to_proto()],
3075 ..Default::default()
3076 };
3077 session.peer.send(connection_id, update.clone())?;
3078 }
3079
3080 Ok(())
3081}
3082
3083/// Delete a channel
3084async fn delete_channel(
3085 request: proto::DeleteChannel,
3086 response: Response<proto::DeleteChannel>,
3087 session: UserSession,
3088) -> Result<()> {
3089 let db = session.db().await;
3090
3091 let channel_id = request.channel_id;
3092 let (root_channel, removed_channels) = db
3093 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3094 .await?;
3095 response.send(proto::Ack {})?;
3096
3097 // Notify members of removed channels
3098 let mut update = proto::UpdateChannels::default();
3099 update
3100 .delete_channels
3101 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3102
3103 let connection_pool = session.connection_pool().await;
3104 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3105 session.peer.send(connection_id, update.clone())?;
3106 }
3107
3108 Ok(())
3109}
3110
3111/// Invite someone to join a channel.
3112async fn invite_channel_member(
3113 request: proto::InviteChannelMember,
3114 response: Response<proto::InviteChannelMember>,
3115 session: UserSession,
3116) -> Result<()> {
3117 let db = session.db().await;
3118 let channel_id = ChannelId::from_proto(request.channel_id);
3119 let invitee_id = UserId::from_proto(request.user_id);
3120 let InviteMemberResult {
3121 channel,
3122 notifications,
3123 } = db
3124 .invite_channel_member(
3125 channel_id,
3126 invitee_id,
3127 session.user_id(),
3128 request.role().into(),
3129 )
3130 .await?;
3131
3132 let update = proto::UpdateChannels {
3133 channel_invitations: vec![channel.to_proto()],
3134 ..Default::default()
3135 };
3136
3137 let connection_pool = session.connection_pool().await;
3138 for connection_id in connection_pool.user_connection_ids(invitee_id) {
3139 session.peer.send(connection_id, update.clone())?;
3140 }
3141
3142 send_notifications(&connection_pool, &session.peer, notifications);
3143
3144 response.send(proto::Ack {})?;
3145 Ok(())
3146}
3147
3148/// remove someone from a channel
3149async fn remove_channel_member(
3150 request: proto::RemoveChannelMember,
3151 response: Response<proto::RemoveChannelMember>,
3152 session: UserSession,
3153) -> Result<()> {
3154 let db = session.db().await;
3155 let channel_id = ChannelId::from_proto(request.channel_id);
3156 let member_id = UserId::from_proto(request.user_id);
3157
3158 let RemoveChannelMemberResult {
3159 membership_update,
3160 notification_id,
3161 } = db
3162 .remove_channel_member(channel_id, member_id, session.user_id())
3163 .await?;
3164
3165 let mut connection_pool = session.connection_pool().await;
3166 notify_membership_updated(
3167 &mut connection_pool,
3168 membership_update,
3169 member_id,
3170 &session.peer,
3171 );
3172 for connection_id in connection_pool.user_connection_ids(member_id) {
3173 if let Some(notification_id) = notification_id {
3174 session
3175 .peer
3176 .send(
3177 connection_id,
3178 proto::DeleteNotification {
3179 notification_id: notification_id.to_proto(),
3180 },
3181 )
3182 .trace_err();
3183 }
3184 }
3185
3186 response.send(proto::Ack {})?;
3187 Ok(())
3188}
3189
3190/// Toggle the channel between public and private.
3191/// Care is taken to maintain the invariant that public channels only descend from public channels,
3192/// (though members-only channels can appear at any point in the hierarchy).
3193async fn set_channel_visibility(
3194 request: proto::SetChannelVisibility,
3195 response: Response<proto::SetChannelVisibility>,
3196 session: UserSession,
3197) -> Result<()> {
3198 let db = session.db().await;
3199 let channel_id = ChannelId::from_proto(request.channel_id);
3200 let visibility = request.visibility().into();
3201
3202 let channel_model = db
3203 .set_channel_visibility(channel_id, visibility, session.user_id())
3204 .await?;
3205 let root_id = channel_model.root_id();
3206 let channel = Channel::from_model(channel_model);
3207
3208 let mut connection_pool = session.connection_pool().await;
3209 for (user_id, role) in connection_pool
3210 .channel_user_ids(root_id)
3211 .collect::<Vec<_>>()
3212 .into_iter()
3213 {
3214 let update = if role.can_see_channel(channel.visibility) {
3215 connection_pool.subscribe_to_channel(user_id, channel_id, role);
3216 proto::UpdateChannels {
3217 channels: vec![channel.to_proto()],
3218 ..Default::default()
3219 }
3220 } else {
3221 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3222 proto::UpdateChannels {
3223 delete_channels: vec![channel.id.to_proto()],
3224 ..Default::default()
3225 }
3226 };
3227
3228 for connection_id in connection_pool.user_connection_ids(user_id) {
3229 session.peer.send(connection_id, update.clone())?;
3230 }
3231 }
3232
3233 response.send(proto::Ack {})?;
3234 Ok(())
3235}
3236
3237/// Alter the role for a user in the channel.
3238async fn set_channel_member_role(
3239 request: proto::SetChannelMemberRole,
3240 response: Response<proto::SetChannelMemberRole>,
3241 session: UserSession,
3242) -> Result<()> {
3243 let db = session.db().await;
3244 let channel_id = ChannelId::from_proto(request.channel_id);
3245 let member_id = UserId::from_proto(request.user_id);
3246 let result = db
3247 .set_channel_member_role(
3248 channel_id,
3249 session.user_id(),
3250 member_id,
3251 request.role().into(),
3252 )
3253 .await?;
3254
3255 match result {
3256 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3257 let mut connection_pool = session.connection_pool().await;
3258 notify_membership_updated(
3259 &mut connection_pool,
3260 membership_update,
3261 member_id,
3262 &session.peer,
3263 )
3264 }
3265 db::SetMemberRoleResult::InviteUpdated(channel) => {
3266 let update = proto::UpdateChannels {
3267 channel_invitations: vec![channel.to_proto()],
3268 ..Default::default()
3269 };
3270
3271 for connection_id in session
3272 .connection_pool()
3273 .await
3274 .user_connection_ids(member_id)
3275 {
3276 session.peer.send(connection_id, update.clone())?;
3277 }
3278 }
3279 }
3280
3281 response.send(proto::Ack {})?;
3282 Ok(())
3283}
3284
3285/// Change the name of a channel
3286async fn rename_channel(
3287 request: proto::RenameChannel,
3288 response: Response<proto::RenameChannel>,
3289 session: UserSession,
3290) -> Result<()> {
3291 let db = session.db().await;
3292 let channel_id = ChannelId::from_proto(request.channel_id);
3293 let channel_model = db
3294 .rename_channel(channel_id, session.user_id(), &request.name)
3295 .await?;
3296 let root_id = channel_model.root_id();
3297 let channel = Channel::from_model(channel_model);
3298
3299 response.send(proto::RenameChannelResponse {
3300 channel: Some(channel.to_proto()),
3301 })?;
3302
3303 let connection_pool = session.connection_pool().await;
3304 let update = proto::UpdateChannels {
3305 channels: vec![channel.to_proto()],
3306 ..Default::default()
3307 };
3308 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3309 if role.can_see_channel(channel.visibility) {
3310 session.peer.send(connection_id, update.clone())?;
3311 }
3312 }
3313
3314 Ok(())
3315}
3316
3317/// Move a channel to a new parent.
3318async fn move_channel(
3319 request: proto::MoveChannel,
3320 response: Response<proto::MoveChannel>,
3321 session: UserSession,
3322) -> Result<()> {
3323 let channel_id = ChannelId::from_proto(request.channel_id);
3324 let to = ChannelId::from_proto(request.to);
3325
3326 let (root_id, channels) = session
3327 .db()
3328 .await
3329 .move_channel(channel_id, to, session.user_id())
3330 .await?;
3331
3332 let connection_pool = session.connection_pool().await;
3333 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3334 let channels = channels
3335 .iter()
3336 .filter_map(|channel| {
3337 if role.can_see_channel(channel.visibility) {
3338 Some(channel.to_proto())
3339 } else {
3340 None
3341 }
3342 })
3343 .collect::<Vec<_>>();
3344 if channels.is_empty() {
3345 continue;
3346 }
3347
3348 let update = proto::UpdateChannels {
3349 channels,
3350 ..Default::default()
3351 };
3352
3353 session.peer.send(connection_id, update.clone())?;
3354 }
3355
3356 response.send(Ack {})?;
3357 Ok(())
3358}
3359
3360/// Get the list of channel members
3361async fn get_channel_members(
3362 request: proto::GetChannelMembers,
3363 response: Response<proto::GetChannelMembers>,
3364 session: UserSession,
3365) -> Result<()> {
3366 let db = session.db().await;
3367 let channel_id = ChannelId::from_proto(request.channel_id);
3368 let members = db
3369 .get_channel_participant_details(channel_id, session.user_id())
3370 .await?;
3371 response.send(proto::GetChannelMembersResponse { members })?;
3372 Ok(())
3373}
3374
3375/// Accept or decline a channel invitation.
3376async fn respond_to_channel_invite(
3377 request: proto::RespondToChannelInvite,
3378 response: Response<proto::RespondToChannelInvite>,
3379 session: UserSession,
3380) -> Result<()> {
3381 let db = session.db().await;
3382 let channel_id = ChannelId::from_proto(request.channel_id);
3383 let RespondToChannelInvite {
3384 membership_update,
3385 notifications,
3386 } = db
3387 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3388 .await?;
3389
3390 let mut connection_pool = session.connection_pool().await;
3391 if let Some(membership_update) = membership_update {
3392 notify_membership_updated(
3393 &mut connection_pool,
3394 membership_update,
3395 session.user_id(),
3396 &session.peer,
3397 );
3398 } else {
3399 let update = proto::UpdateChannels {
3400 remove_channel_invitations: vec![channel_id.to_proto()],
3401 ..Default::default()
3402 };
3403
3404 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3405 session.peer.send(connection_id, update.clone())?;
3406 }
3407 };
3408
3409 send_notifications(&connection_pool, &session.peer, notifications);
3410
3411 response.send(proto::Ack {})?;
3412
3413 Ok(())
3414}
3415
3416/// Join the channels' room
3417async fn join_channel(
3418 request: proto::JoinChannel,
3419 response: Response<proto::JoinChannel>,
3420 session: UserSession,
3421) -> Result<()> {
3422 let channel_id = ChannelId::from_proto(request.channel_id);
3423 join_channel_internal(channel_id, Box::new(response), session).await
3424}
3425
/// A response handle that `join_channel_internal` can reply through with a
/// `proto::JoinRoomResponse`.
///
/// Implemented for the `JoinChannel` and `JoinRoom` response handles, so one join
/// implementation can answer either request.
trait JoinChannelInternalResponse {
    /// Consume the handle and send `result` as the reply.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // The type-qualified path makes explicit that this delegates to the inherent
        // `Response::send`, not to this trait method of the same name.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Delegates to the inherent `Response::send`; see the `JoinChannel` impl above.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3439
/// Shared implementation for joining a channel's room.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, leave that room first.
/// 2. Join the channel in the database. This yields the joined room, an optional membership
///    update, and the user's role in the channel.
/// 3. If a LiveKit client is configured, mint a token. Guests get a guest token and
///    `can_publish = false`; all other roles get a room token and `can_publish = true`.
/// 4. Reply to the client with the room, channel id and LiveKit info.
/// 5. Broadcast, in order:
///    - the membership update (if any);
///    - the room update;
///    - the channel update;
///    - the user's contact update.
///
/// Returns an error if the database does not return a channel for the joined room. That
/// check happens after the reply has already been sent.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: UserSession,
) -> Result<()> {
    // The database handle and connection-pool guard taken inside this block are released
    // when it ends, before `channel_updated` re-acquires the pool below.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database handle while leaving the old room, then re-acquire it.
            // NOTE(review): presumably `leave_room_for_session` needs the db itself — confirm.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured, or when token generation fails
        // (`trace_err()` turns the error into `None`; presumably it also logs it).
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        // Reply before broadcasting, so the joiner has the room state first.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3530
3531/// Start editing the channel notes
3532async fn join_channel_buffer(
3533 request: proto::JoinChannelBuffer,
3534 response: Response<proto::JoinChannelBuffer>,
3535 session: UserSession,
3536) -> Result<()> {
3537 let db = session.db().await;
3538 let channel_id = ChannelId::from_proto(request.channel_id);
3539
3540 let open_response = db
3541 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3542 .await?;
3543
3544 let collaborators = open_response.collaborators.clone();
3545 response.send(open_response)?;
3546
3547 let update = UpdateChannelBufferCollaborators {
3548 channel_id: channel_id.to_proto(),
3549 collaborators: collaborators.clone(),
3550 };
3551 channel_buffer_updated(
3552 session.connection_id,
3553 collaborators
3554 .iter()
3555 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3556 &update,
3557 &session.peer,
3558 );
3559
3560 Ok(())
3561}
3562
3563/// Edit the channel notes
3564async fn update_channel_buffer(
3565 request: proto::UpdateChannelBuffer,
3566 session: UserSession,
3567) -> Result<()> {
3568 let db = session.db().await;
3569 let channel_id = ChannelId::from_proto(request.channel_id);
3570
3571 let (collaborators, non_collaborators, epoch, version) = db
3572 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3573 .await?;
3574
3575 channel_buffer_updated(
3576 session.connection_id,
3577 collaborators,
3578 &proto::UpdateChannelBuffer {
3579 channel_id: channel_id.to_proto(),
3580 operations: request.operations,
3581 },
3582 &session.peer,
3583 );
3584
3585 let pool = &*session.connection_pool().await;
3586
3587 broadcast(
3588 None,
3589 non_collaborators
3590 .iter()
3591 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3592 |peer_id| {
3593 session.peer.send(
3594 peer_id,
3595 proto::UpdateChannels {
3596 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3597 channel_id: channel_id.to_proto(),
3598 epoch: epoch as u64,
3599 version: version.clone(),
3600 }],
3601 ..Default::default()
3602 },
3603 )
3604 },
3605 );
3606
3607 Ok(())
3608}
3609
3610/// Rejoin the channel notes after a connection blip
3611async fn rejoin_channel_buffers(
3612 request: proto::RejoinChannelBuffers,
3613 response: Response<proto::RejoinChannelBuffers>,
3614 session: UserSession,
3615) -> Result<()> {
3616 let db = session.db().await;
3617 let buffers = db
3618 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3619 .await?;
3620
3621 for rejoined_buffer in &buffers {
3622 let collaborators_to_notify = rejoined_buffer
3623 .buffer
3624 .collaborators
3625 .iter()
3626 .filter_map(|c| Some(c.peer_id?.into()));
3627 channel_buffer_updated(
3628 session.connection_id,
3629 collaborators_to_notify,
3630 &proto::UpdateChannelBufferCollaborators {
3631 channel_id: rejoined_buffer.buffer.channel_id,
3632 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3633 },
3634 &session.peer,
3635 );
3636 }
3637
3638 response.send(proto::RejoinChannelBuffersResponse {
3639 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3640 })?;
3641
3642 Ok(())
3643}
3644
3645/// Stop editing the channel notes
3646async fn leave_channel_buffer(
3647 request: proto::LeaveChannelBuffer,
3648 response: Response<proto::LeaveChannelBuffer>,
3649 session: UserSession,
3650) -> Result<()> {
3651 let db = session.db().await;
3652 let channel_id = ChannelId::from_proto(request.channel_id);
3653
3654 let left_buffer = db
3655 .leave_channel_buffer(channel_id, session.connection_id)
3656 .await?;
3657
3658 response.send(Ack {})?;
3659
3660 channel_buffer_updated(
3661 session.connection_id,
3662 left_buffer.connections,
3663 &proto::UpdateChannelBufferCollaborators {
3664 channel_id: channel_id.to_proto(),
3665 collaborators: left_buffer.collaborators,
3666 },
3667 &session.peer,
3668 );
3669
3670 Ok(())
3671}
3672
3673fn channel_buffer_updated<T: EnvelopedMessage>(
3674 sender_id: ConnectionId,
3675 collaborators: impl IntoIterator<Item = ConnectionId>,
3676 message: &T,
3677 peer: &Peer,
3678) {
3679 broadcast(Some(sender_id), collaborators, |peer_id| {
3680 peer.send(peer_id, message.clone())
3681 });
3682}
3683
3684fn send_notifications(
3685 connection_pool: &ConnectionPool,
3686 peer: &Peer,
3687 notifications: db::NotificationBatch,
3688) {
3689 for (user_id, notification) in notifications {
3690 for connection_id in connection_pool.user_connection_ids(user_id) {
3691 if let Err(error) = peer.send(
3692 connection_id,
3693 proto::AddNotification {
3694 notification: Some(notification.clone()),
3695 },
3696 ) {
3697 tracing::error!(
3698 "failed to send notification to {:?} {}",
3699 connection_id,
3700 error
3701 );
3702 }
3703 }
3704 }
3705}
3706
3707/// Send a message to the channel
3708async fn send_channel_message(
3709 request: proto::SendChannelMessage,
3710 response: Response<proto::SendChannelMessage>,
3711 session: UserSession,
3712) -> Result<()> {
3713 // Validate the message body.
3714 let body = request.body.trim().to_string();
3715 if body.len() > MAX_MESSAGE_LEN {
3716 return Err(anyhow!("message is too long"))?;
3717 }
3718 if body.is_empty() {
3719 return Err(anyhow!("message can't be blank"))?;
3720 }
3721
3722 // TODO: adjust mentions if body is trimmed
3723
3724 let timestamp = OffsetDateTime::now_utc();
3725 let nonce = request
3726 .nonce
3727 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3728
3729 let channel_id = ChannelId::from_proto(request.channel_id);
3730 let CreatedChannelMessage {
3731 message_id,
3732 participant_connection_ids,
3733 channel_members,
3734 notifications,
3735 } = session
3736 .db()
3737 .await
3738 .create_channel_message(
3739 channel_id,
3740 session.user_id(),
3741 &body,
3742 &request.mentions,
3743 timestamp,
3744 nonce.clone().into(),
3745 match request.reply_to_message_id {
3746 Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
3747 None => None,
3748 },
3749 )
3750 .await?;
3751
3752 let message = proto::ChannelMessage {
3753 sender_id: session.user_id().to_proto(),
3754 id: message_id.to_proto(),
3755 body,
3756 mentions: request.mentions,
3757 timestamp: timestamp.unix_timestamp() as u64,
3758 nonce: Some(nonce),
3759 reply_to_message_id: request.reply_to_message_id,
3760 edited_at: None,
3761 };
3762 broadcast(
3763 Some(session.connection_id),
3764 participant_connection_ids,
3765 |connection| {
3766 session.peer.send(
3767 connection,
3768 proto::ChannelMessageSent {
3769 channel_id: channel_id.to_proto(),
3770 message: Some(message.clone()),
3771 },
3772 )
3773 },
3774 );
3775 response.send(proto::SendChannelMessageResponse {
3776 message: Some(message),
3777 })?;
3778
3779 let pool = &*session.connection_pool().await;
3780 broadcast(
3781 None,
3782 channel_members
3783 .iter()
3784 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3785 |peer_id| {
3786 session.peer.send(
3787 peer_id,
3788 proto::UpdateChannels {
3789 latest_channel_message_ids: vec![proto::ChannelMessageId {
3790 channel_id: channel_id.to_proto(),
3791 message_id: message_id.to_proto(),
3792 }],
3793 ..Default::default()
3794 },
3795 )
3796 },
3797 );
3798 send_notifications(pool, &session.peer, notifications);
3799
3800 Ok(())
3801}
3802
3803/// Delete a channel message
3804async fn remove_channel_message(
3805 request: proto::RemoveChannelMessage,
3806 response: Response<proto::RemoveChannelMessage>,
3807 session: UserSession,
3808) -> Result<()> {
3809 let channel_id = ChannelId::from_proto(request.channel_id);
3810 let message_id = MessageId::from_proto(request.message_id);
3811 let (connection_ids, existing_notification_ids) = session
3812 .db()
3813 .await
3814 .remove_channel_message(channel_id, message_id, session.user_id())
3815 .await?;
3816
3817 broadcast(
3818 Some(session.connection_id),
3819 connection_ids,
3820 move |connection| {
3821 session.peer.send(connection, request.clone())?;
3822
3823 for notification_id in &existing_notification_ids {
3824 session.peer.send(
3825 connection,
3826 proto::DeleteNotification {
3827 notification_id: (*notification_id).to_proto(),
3828 },
3829 )?;
3830 }
3831
3832 Ok(())
3833 },
3834 );
3835 response.send(proto::Ack {})?;
3836 Ok(())
3837}
3838
3839async fn update_channel_message(
3840 request: proto::UpdateChannelMessage,
3841 response: Response<proto::UpdateChannelMessage>,
3842 session: UserSession,
3843) -> Result<()> {
3844 let channel_id = ChannelId::from_proto(request.channel_id);
3845 let message_id = MessageId::from_proto(request.message_id);
3846 let updated_at = OffsetDateTime::now_utc();
3847 let UpdatedChannelMessage {
3848 message_id,
3849 participant_connection_ids,
3850 notifications,
3851 reply_to_message_id,
3852 timestamp,
3853 deleted_mention_notification_ids,
3854 updated_mention_notifications,
3855 } = session
3856 .db()
3857 .await
3858 .update_channel_message(
3859 channel_id,
3860 message_id,
3861 session.user_id(),
3862 request.body.as_str(),
3863 &request.mentions,
3864 updated_at,
3865 )
3866 .await?;
3867
3868 let nonce = request
3869 .nonce
3870 .clone()
3871 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3872
3873 let message = proto::ChannelMessage {
3874 sender_id: session.user_id().to_proto(),
3875 id: message_id.to_proto(),
3876 body: request.body.clone(),
3877 mentions: request.mentions.clone(),
3878 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3879 nonce: Some(nonce),
3880 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3881 edited_at: Some(updated_at.unix_timestamp() as u64),
3882 };
3883
3884 response.send(proto::Ack {})?;
3885
3886 let pool = &*session.connection_pool().await;
3887 broadcast(
3888 Some(session.connection_id),
3889 participant_connection_ids,
3890 |connection| {
3891 session.peer.send(
3892 connection,
3893 proto::ChannelMessageUpdate {
3894 channel_id: channel_id.to_proto(),
3895 message: Some(message.clone()),
3896 },
3897 )?;
3898
3899 for notification_id in &deleted_mention_notification_ids {
3900 session.peer.send(
3901 connection,
3902 proto::DeleteNotification {
3903 notification_id: (*notification_id).to_proto(),
3904 },
3905 )?;
3906 }
3907
3908 for notification in &updated_mention_notifications {
3909 session.peer.send(
3910 connection,
3911 proto::UpdateNotification {
3912 notification: Some(notification.clone()),
3913 },
3914 )?;
3915 }
3916
3917 Ok(())
3918 },
3919 );
3920
3921 send_notifications(pool, &session.peer, notifications);
3922
3923 Ok(())
3924}
3925
3926/// Mark a channel message as read
3927async fn acknowledge_channel_message(
3928 request: proto::AckChannelMessage,
3929 session: UserSession,
3930) -> Result<()> {
3931 let channel_id = ChannelId::from_proto(request.channel_id);
3932 let message_id = MessageId::from_proto(request.message_id);
3933 let notifications = session
3934 .db()
3935 .await
3936 .observe_channel_message(channel_id, session.user_id(), message_id)
3937 .await?;
3938 send_notifications(
3939 &*session.connection_pool().await,
3940 &session.peer,
3941 notifications,
3942 );
3943 Ok(())
3944}
3945
3946/// Mark a buffer version as synced
3947async fn acknowledge_buffer_version(
3948 request: proto::AckBufferOperation,
3949 session: UserSession,
3950) -> Result<()> {
3951 let buffer_id = BufferId::from_proto(request.buffer_id);
3952 session
3953 .db()
3954 .await
3955 .observe_buffer_version(
3956 buffer_id,
3957 session.user_id(),
3958 request.epoch as i32,
3959 &request.version,
3960 )
3961 .await?;
3962 Ok(())
3963}
3964
3965struct CompleteWithLanguageModelRateLimit;
3966
3967impl RateLimit for CompleteWithLanguageModelRateLimit {
3968 fn capacity() -> usize {
3969 std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
3970 .ok()
3971 .and_then(|v| v.parse().ok())
3972 .unwrap_or(120) // Picked arbitrarily
3973 }
3974
3975 fn refill_duration() -> chrono::Duration {
3976 chrono::Duration::hours(1)
3977 }
3978
3979 fn db_name() -> &'static str {
3980 "complete-with-language-model"
3981 }
3982}
3983
3984async fn complete_with_language_model(
3985 request: proto::CompleteWithLanguageModel,
3986 response: StreamingResponse<proto::CompleteWithLanguageModel>,
3987 session: Session,
3988 open_ai_api_key: Option<Arc<str>>,
3989 google_ai_api_key: Option<Arc<str>>,
3990 anthropic_api_key: Option<Arc<str>>,
3991) -> Result<()> {
3992 let Some(session) = session.for_user() else {
3993 return Err(anyhow!("user not found"))?;
3994 };
3995 authorize_access_to_language_models(&session).await?;
3996 session
3997 .rate_limiter
3998 .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
3999 .await?;
4000
4001 if request.model.starts_with("gpt") {
4002 let api_key =
4003 open_ai_api_key.ok_or_else(|| anyhow!("no OpenAI API key configured on the server"))?;
4004 complete_with_open_ai(request, response, session, api_key).await?;
4005 } else if request.model.starts_with("gemini") {
4006 let api_key = google_ai_api_key
4007 .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
4008 complete_with_google_ai(request, response, session, api_key).await?;
4009 } else if request.model.starts_with("claude") {
4010 let api_key = anthropic_api_key
4011 .ok_or_else(|| anyhow!("no Anthropic AI API key configured on the server"))?;
4012 complete_with_anthropic(request, response, session, api_key).await?;
4013 }
4014
4015 Ok(())
4016}
4017
4018async fn complete_with_open_ai(
4019 request: proto::CompleteWithLanguageModel,
4020 response: StreamingResponse<proto::CompleteWithLanguageModel>,
4021 session: UserSession,
4022 api_key: Arc<str>,
4023) -> Result<()> {
4024 const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
4025
4026 let mut completion_stream = open_ai::stream_completion(
4027 &session.http_client,
4028 OPEN_AI_API_URL,
4029 &api_key,
4030 crate::ai::language_model_request_to_open_ai(request)?,
4031 )
4032 .await
4033 .context("open_ai::stream_completion request failed")?;
4034
4035 while let Some(event) = completion_stream.next().await {
4036 let event = event?;
4037 response.send(proto::LanguageModelResponse {
4038 choices: event
4039 .choices
4040 .into_iter()
4041 .map(|choice| proto::LanguageModelChoiceDelta {
4042 index: choice.index,
4043 delta: Some(proto::LanguageModelResponseMessage {
4044 role: choice.delta.role.map(|role| match role {
4045 open_ai::Role::User => LanguageModelRole::LanguageModelUser,
4046 open_ai::Role::Assistant => LanguageModelRole::LanguageModelAssistant,
4047 open_ai::Role::System => LanguageModelRole::LanguageModelSystem,
4048 } as i32),
4049 content: choice.delta.content,
4050 }),
4051 finish_reason: choice.finish_reason,
4052 })
4053 .collect(),
4054 })?;
4055 }
4056
4057 Ok(())
4058}
4059
4060async fn complete_with_google_ai(
4061 request: proto::CompleteWithLanguageModel,
4062 response: StreamingResponse<proto::CompleteWithLanguageModel>,
4063 session: UserSession,
4064 api_key: Arc<str>,
4065) -> Result<()> {
4066 let mut stream = google_ai::stream_generate_content(
4067 &session.http_client,
4068 google_ai::API_URL,
4069 api_key.as_ref(),
4070 crate::ai::language_model_request_to_google_ai(request)?,
4071 )
4072 .await
4073 .context("google_ai::stream_generate_content request failed")?;
4074
4075 while let Some(event) = stream.next().await {
4076 let event = event?;
4077 response.send(proto::LanguageModelResponse {
4078 choices: event
4079 .candidates
4080 .unwrap_or_default()
4081 .into_iter()
4082 .map(|candidate| proto::LanguageModelChoiceDelta {
4083 index: candidate.index as u32,
4084 delta: Some(proto::LanguageModelResponseMessage {
4085 role: Some(match candidate.content.role {
4086 google_ai::Role::User => LanguageModelRole::LanguageModelUser,
4087 google_ai::Role::Model => LanguageModelRole::LanguageModelAssistant,
4088 } as i32),
4089 content: Some(
4090 candidate
4091 .content
4092 .parts
4093 .into_iter()
4094 .filter_map(|part| match part {
4095 google_ai::Part::TextPart(part) => Some(part.text),
4096 google_ai::Part::InlineDataPart(_) => None,
4097 })
4098 .collect(),
4099 ),
4100 }),
4101 finish_reason: candidate.finish_reason.map(|reason| reason.to_string()),
4102 })
4103 .collect(),
4104 })?;
4105 }
4106
4107 Ok(())
4108}
4109
/// Stream a chat completion from Anthropic.
///
/// Request conversion:
/// - user and assistant messages are passed through;
/// - system messages are joined with blank lines into the request's separate `system` field.
///
/// Streamed events are translated into `LanguageModelResponse` chunks, all reported as
/// choice index 0.
async fn complete_with_anthropic(
    request: proto::CompleteWithLanguageModel,
    response: StreamingResponse<proto::CompleteWithLanguageModel>,
    session: UserSession,
    api_key: Arc<str>,
) -> Result<()> {
    let model = anthropic::Model::from_id(&request.model)?;

    // Accumulates every system message; filled in as a side effect of the `filter_map` below.
    let mut system_message = String::new();
    let messages = request
        .messages
        .into_iter()
        .filter_map(|message| match message.role() {
            LanguageModelRole::LanguageModelUser => Some(anthropic::RequestMessage {
                role: anthropic::Role::User,
                content: message.content,
            }),
            LanguageModelRole::LanguageModelAssistant => Some(anthropic::RequestMessage {
                role: anthropic::Role::Assistant,
                content: message.content,
            }),
            // Anthropic's API breaks system instructions out as a separate field rather
            // than having a system message role.
            LanguageModelRole::LanguageModelSystem => {
                if !system_message.is_empty() {
                    system_message.push_str("\n\n");
                }
                system_message.push_str(&message.content);

                None
            }
        })
        .collect();

    let mut stream = anthropic::stream_completion(
        &session.http_client,
        "https://api.anthropic.com",
        &api_key,
        anthropic::Request {
            model,
            messages,
            stream: true,
            system: system_message,
            // NOTE(review): 4092 is an unusual value and may have been meant as 4096. Confirm
            // against the model's output-token limit before changing it.
            max_tokens: 4092,
        },
    )
    .await?;

    // Role attributed to streamed text. It starts as assistant and is updated whenever a
    // `MessageStart` event reports a recognized role.
    let mut current_role = proto::LanguageModelRole::LanguageModelAssistant;

    while let Some(event) = stream.next().await {
        let event = event?;

        match event {
            // Track the role of the message being streamed; unrecognized roles are ignored.
            anthropic::ResponseEvent::MessageStart { message } => {
                if let Some(role) = message.role {
                    if role == "assistant" {
                        current_role = proto::LanguageModelRole::LanguageModelAssistant;
                    } else if role == "user" {
                        current_role = proto::LanguageModelRole::LanguageModelUser;
                    }
                }
            }
            // A content block may start with initial text; forward it unless it is empty.
            anthropic::ResponseEvent::ContentBlockStart { content_block, .. } => {
                match content_block {
                    anthropic::ContentBlock::Text { text } => {
                        if !text.is_empty() {
                            response.send(proto::LanguageModelResponse {
                                choices: vec![proto::LanguageModelChoiceDelta {
                                    index: 0,
                                    delta: Some(proto::LanguageModelResponseMessage {
                                        role: Some(current_role as i32),
                                        content: Some(text),
                                    }),
                                    finish_reason: None,
                                }],
                            })?;
                        }
                    }
                }
            }
            // Incremental text; forwarded as-is.
            anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => match delta {
                anthropic::TextDelta::TextDelta { text } => {
                    response.send(proto::LanguageModelResponse {
                        choices: vec![proto::LanguageModelChoiceDelta {
                            index: 0,
                            delta: Some(proto::LanguageModelResponseMessage {
                                role: Some(current_role as i32),
                                content: Some(text),
                            }),
                            finish_reason: None,
                        }],
                    })?;
                }
            },
            // A stop reason is forwarded as the choice's finish reason, with no content delta.
            anthropic::ResponseEvent::MessageDelta { delta, .. } => {
                if let Some(stop_reason) = delta.stop_reason {
                    response.send(proto::LanguageModelResponse {
                        choices: vec![proto::LanguageModelChoiceDelta {
                            index: 0,
                            delta: None,
                            finish_reason: Some(stop_reason),
                        }],
                    })?;
                }
            }
            // These events produce nothing for the client.
            anthropic::ResponseEvent::ContentBlockStop { .. } => {}
            anthropic::ResponseEvent::MessageStop {} => {}
            anthropic::ResponseEvent::Ping {} => {}
        }
    }

    Ok(())
}
4224
4225struct CountTokensWithLanguageModelRateLimit;
4226
4227impl RateLimit for CountTokensWithLanguageModelRateLimit {
4228 fn capacity() -> usize {
4229 std::env::var("COUNT_TOKENS_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4230 .ok()
4231 .and_then(|v| v.parse().ok())
4232 .unwrap_or(600) // Picked arbitrarily
4233 }
4234
4235 fn refill_duration() -> chrono::Duration {
4236 chrono::Duration::hours(1)
4237 }
4238
4239 fn db_name() -> &'static str {
4240 "count-tokens-with-language-model"
4241 }
4242}
4243
4244async fn count_tokens_with_language_model(
4245 request: proto::CountTokensWithLanguageModel,
4246 response: Response<proto::CountTokensWithLanguageModel>,
4247 session: UserSession,
4248 google_ai_api_key: Option<Arc<str>>,
4249) -> Result<()> {
4250 authorize_access_to_language_models(&session).await?;
4251
4252 if !request.model.starts_with("gemini") {
4253 return Err(anyhow!(
4254 "counting tokens for model: {:?} is not supported",
4255 request.model
4256 ))?;
4257 }
4258
4259 session
4260 .rate_limiter
4261 .check::<CountTokensWithLanguageModelRateLimit>(session.user_id())
4262 .await?;
4263
4264 let api_key = google_ai_api_key
4265 .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
4266 let tokens_response = google_ai::count_tokens(
4267 &session.http_client,
4268 google_ai::API_URL,
4269 &api_key,
4270 crate::ai::count_tokens_request_to_google_ai(request)?,
4271 )
4272 .await?;
4273 response.send(proto::CountTokensResponse {
4274 token_count: tokens_response.total_tokens as u32,
4275 })?;
4276 Ok(())
4277}
4278
4279async fn authorize_access_to_language_models(session: &UserSession) -> Result<(), Error> {
4280 let db = session.db().await;
4281 let flags = db.get_user_flags(session.user_id()).await?;
4282 if flags.iter().any(|flag| flag == "language-models") {
4283 Ok(())
4284 } else {
4285 Err(anyhow!("permission denied"))?
4286 }
4287}
4288
4289/// Start receiving chat updates for a channel
4290async fn join_channel_chat(
4291 request: proto::JoinChannelChat,
4292 response: Response<proto::JoinChannelChat>,
4293 session: UserSession,
4294) -> Result<()> {
4295 let channel_id = ChannelId::from_proto(request.channel_id);
4296
4297 let db = session.db().await;
4298 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4299 .await?;
4300 let messages = db
4301 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4302 .await?;
4303 response.send(proto::JoinChannelChatResponse {
4304 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4305 messages,
4306 })?;
4307 Ok(())
4308}
4309
4310/// Stop receiving chat updates for a channel
4311async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
4312 let channel_id = ChannelId::from_proto(request.channel_id);
4313 session
4314 .db()
4315 .await
4316 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4317 .await?;
4318 Ok(())
4319}
4320
4321/// Retrieve the chat history for a channel
4322async fn get_channel_messages(
4323 request: proto::GetChannelMessages,
4324 response: Response<proto::GetChannelMessages>,
4325 session: UserSession,
4326) -> Result<()> {
4327 let channel_id = ChannelId::from_proto(request.channel_id);
4328 let messages = session
4329 .db()
4330 .await
4331 .get_channel_messages(
4332 channel_id,
4333 session.user_id(),
4334 MESSAGE_COUNT_PER_PAGE,
4335 Some(MessageId::from_proto(request.before_message_id)),
4336 )
4337 .await?;
4338 response.send(proto::GetChannelMessagesResponse {
4339 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4340 messages,
4341 })?;
4342 Ok(())
4343}
4344
4345/// Retrieve specific chat messages
4346async fn get_channel_messages_by_id(
4347 request: proto::GetChannelMessagesById,
4348 response: Response<proto::GetChannelMessagesById>,
4349 session: UserSession,
4350) -> Result<()> {
4351 let message_ids = request
4352 .message_ids
4353 .iter()
4354 .map(|id| MessageId::from_proto(*id))
4355 .collect::<Vec<_>>();
4356 let messages = session
4357 .db()
4358 .await
4359 .get_channel_messages_by_id(session.user_id(), &message_ids)
4360 .await?;
4361 response.send(proto::GetChannelMessagesResponse {
4362 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4363 messages,
4364 })?;
4365 Ok(())
4366}
4367
4368/// Retrieve the current users notifications
4369async fn get_notifications(
4370 request: proto::GetNotifications,
4371 response: Response<proto::GetNotifications>,
4372 session: UserSession,
4373) -> Result<()> {
4374 let notifications = session
4375 .db()
4376 .await
4377 .get_notifications(
4378 session.user_id(),
4379 NOTIFICATION_COUNT_PER_PAGE,
4380 request
4381 .before_id
4382 .map(|id| db::NotificationId::from_proto(id)),
4383 )
4384 .await?;
4385 response.send(proto::GetNotificationsResponse {
4386 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4387 notifications,
4388 })?;
4389 Ok(())
4390}
4391
4392/// Mark notifications as read
4393async fn mark_notification_as_read(
4394 request: proto::MarkNotificationRead,
4395 response: Response<proto::MarkNotificationRead>,
4396 session: UserSession,
4397) -> Result<()> {
4398 let database = &session.db().await;
4399 let notifications = database
4400 .mark_notification_as_read_by_id(
4401 session.user_id(),
4402 NotificationId::from_proto(request.notification_id),
4403 )
4404 .await?;
4405 send_notifications(
4406 &*session.connection_pool().await,
4407 &session.peer,
4408 notifications,
4409 );
4410 response.send(proto::Ack {})?;
4411 Ok(())
4412}
4413
4414/// Get the current users information
4415async fn get_private_user_info(
4416 _request: proto::GetPrivateUserInfo,
4417 response: Response<proto::GetPrivateUserInfo>,
4418 session: UserSession,
4419) -> Result<()> {
4420 let db = session.db().await;
4421
4422 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4423 let user = db
4424 .get_user_by_id(session.user_id())
4425 .await?
4426 .ok_or_else(|| anyhow!("user not found"))?;
4427 let flags = db.get_user_flags(session.user_id()).await?;
4428
4429 response.send(proto::GetPrivateUserInfoResponse {
4430 metrics_id,
4431 staff: user.admin,
4432 flags,
4433 })?;
4434 Ok(())
4435}
4436
4437fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
4438 match message {
4439 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4440 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4441 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4442 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4443 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4444 code: frame.code.into(),
4445 reason: frame.reason,
4446 })),
4447 }
4448}
4449
4450fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4451 match message {
4452 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4453 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4454 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4455 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4456 AxumMessage::Close(frame) => {
4457 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4458 code: frame.code.into(),
4459 reason: frame.reason,
4460 }))
4461 }
4462 }
4463}
4464
4465fn notify_membership_updated(
4466 connection_pool: &mut ConnectionPool,
4467 result: MembershipUpdated,
4468 user_id: UserId,
4469 peer: &Peer,
4470) {
4471 for membership in &result.new_channels.channel_memberships {
4472 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4473 }
4474 for channel_id in &result.removed_channels {
4475 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4476 }
4477
4478 let user_channels_update = proto::UpdateUserChannels {
4479 channel_memberships: result
4480 .new_channels
4481 .channel_memberships
4482 .iter()
4483 .map(|cm| proto::ChannelMembership {
4484 channel_id: cm.channel_id.to_proto(),
4485 role: cm.role.into(),
4486 })
4487 .collect(),
4488 ..Default::default()
4489 };
4490
4491 let mut update = build_channels_update(result.new_channels, vec![], connection_pool);
4492 update.delete_channels = result
4493 .removed_channels
4494 .into_iter()
4495 .map(|id| id.to_proto())
4496 .collect();
4497 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4498
4499 for connection_id in connection_pool.user_connection_ids(user_id) {
4500 peer.send(connection_id, user_channels_update.clone())
4501 .trace_err();
4502 peer.send(connection_id, update.clone()).trace_err();
4503 }
4504}
4505
4506fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4507 proto::UpdateUserChannels {
4508 channel_memberships: channels
4509 .channel_memberships
4510 .iter()
4511 .map(|m| proto::ChannelMembership {
4512 channel_id: m.channel_id.to_proto(),
4513 role: m.role.into(),
4514 })
4515 .collect(),
4516 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4517 observed_channel_message_id: channels.observed_channel_messages.clone(),
4518 }
4519}
4520
4521fn build_channels_update(
4522 channels: ChannelsForUser,
4523 channel_invites: Vec<db::Channel>,
4524 pool: &ConnectionPool,
4525) -> proto::UpdateChannels {
4526 let mut update = proto::UpdateChannels::default();
4527
4528 for channel in channels.channels {
4529 update.channels.push(channel.to_proto());
4530 }
4531
4532 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4533 update.latest_channel_message_ids = channels.latest_channel_messages;
4534
4535 for (channel_id, participants) in channels.channel_participants {
4536 update
4537 .channel_participants
4538 .push(proto::ChannelParticipants {
4539 channel_id: channel_id.to_proto(),
4540 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4541 });
4542 }
4543
4544 for channel in channel_invites {
4545 update.channel_invitations.push(channel.to_proto());
4546 }
4547
4548 update.hosted_projects = channels.hosted_projects;
4549 update.dev_servers = channels
4550 .dev_servers
4551 .into_iter()
4552 .map(|dev_server| dev_server.to_proto(pool.dev_server_status(dev_server.id)))
4553 .collect();
4554 update.remote_projects = channels.remote_projects;
4555
4556 update
4557}
4558
4559fn build_initial_contacts_update(
4560 contacts: Vec<db::Contact>,
4561 pool: &ConnectionPool,
4562) -> proto::UpdateContacts {
4563 let mut update = proto::UpdateContacts::default();
4564
4565 for contact in contacts {
4566 match contact {
4567 db::Contact::Accepted { user_id, busy } => {
4568 update.contacts.push(contact_for_user(user_id, busy, &pool));
4569 }
4570 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4571 db::Contact::Incoming { user_id } => {
4572 update
4573 .incoming_requests
4574 .push(proto::IncomingContactRequest {
4575 requester_id: user_id.to_proto(),
4576 })
4577 }
4578 }
4579 }
4580
4581 update
4582}
4583
4584fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4585 proto::Contact {
4586 user_id: user_id.to_proto(),
4587 online: pool.is_user_online(user_id),
4588 busy,
4589 }
4590}
4591
4592fn room_updated(room: &proto::Room, peer: &Peer) {
4593 broadcast(
4594 None,
4595 room.participants
4596 .iter()
4597 .filter_map(|participant| Some(participant.peer_id?.into())),
4598 |peer_id| {
4599 peer.send(
4600 peer_id,
4601 proto::RoomUpdated {
4602 room: Some(room.clone()),
4603 },
4604 )
4605 },
4606 );
4607}
4608
4609fn channel_updated(
4610 channel: &db::channel::Model,
4611 room: &proto::Room,
4612 peer: &Peer,
4613 pool: &ConnectionPool,
4614) {
4615 let participants = room
4616 .participants
4617 .iter()
4618 .map(|p| p.user_id)
4619 .collect::<Vec<_>>();
4620
4621 broadcast(
4622 None,
4623 pool.channel_connection_ids(channel.root_id())
4624 .filter_map(|(channel_id, role)| {
4625 role.can_see_channel(channel.visibility).then(|| channel_id)
4626 }),
4627 |peer_id| {
4628 peer.send(
4629 peer_id,
4630 proto::UpdateChannels {
4631 channel_participants: vec![proto::ChannelParticipants {
4632 channel_id: channel.id.to_proto(),
4633 participant_user_ids: participants.clone(),
4634 }],
4635 ..Default::default()
4636 },
4637 )
4638 },
4639 );
4640}
4641
4642async fn update_dev_server_status(
4643 dev_server: &dev_server::Model,
4644 status: proto::DevServerStatus,
4645 session: &Session,
4646) {
4647 let pool = session.connection_pool().await;
4648 let connections = pool.channel_connection_ids(dev_server.channel_id);
4649 for (connection_id, _) in connections {
4650 session
4651 .peer
4652 .send(
4653 connection_id,
4654 proto::UpdateChannels {
4655 dev_servers: vec![dev_server.to_proto(status)],
4656 ..Default::default()
4657 },
4658 )
4659 .trace_err();
4660 }
4661}
4662
4663async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4664 let db = session.db().await;
4665
4666 let contacts = db.get_contacts(user_id).await?;
4667 let busy = db.is_user_busy(user_id).await?;
4668
4669 let pool = session.connection_pool().await;
4670 let updated_contact = contact_for_user(user_id, busy, &pool);
4671 for contact in contacts {
4672 if let db::Contact::Accepted {
4673 user_id: contact_user_id,
4674 ..
4675 } = contact
4676 {
4677 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4678 session
4679 .peer
4680 .send(
4681 contact_conn_id,
4682 proto::UpdateContacts {
4683 contacts: vec![updated_contact.clone()],
4684 remove_contacts: Default::default(),
4685 incoming_requests: Default::default(),
4686 remove_incoming_requests: Default::default(),
4687 outgoing_requests: Default::default(),
4688 remove_outgoing_requests: Default::default(),
4689 },
4690 )
4691 .trace_err();
4692 }
4693 }
4694 }
4695 Ok(())
4696}
4697
4698async fn lost_dev_server_connection(session: &Session) -> Result<()> {
4699 log::info!("lost dev server connection, unsharing projects");
4700 let project_ids = session
4701 .db()
4702 .await
4703 .get_stale_dev_server_projects(session.connection_id)
4704 .await?;
4705
4706 for project_id in project_ids {
4707 // not unshare re-checks the connection ids match, so we get away with no transaction
4708 unshare_project_internal(project_id, &session).await?;
4709 }
4710
4711 Ok(())
4712}
4713
/// Removes `connection_id` from the room it is in (if any) and propagates the
/// consequences to other clients and to LiveKit.
///
/// When the database reports that a room was left, this:
/// - notifies each left project's connections via `project_left`,
/// - broadcasts the updated room via `room_updated`; if the left room has an
///   associated channel, it also broadcasts that channel's participants via
///   `channel_updated`,
/// - sends `CallCanceled` to every connection of each user whose call was
///   canceled,
/// - pushes fresh contact state for this user and for those canceled callees
///   via `update_user_contacts`,
/// - if a LiveKit client is configured, removes the user from the LiveKit room,
///   and deletes that room too when `left_room.deleted` is set.
///
/// Returns `Ok(())` immediately if the database reports that no room was left.
///
/// # Errors
///
/// Propagates errors from `leave_room` and `update_user_contacts`. An error
/// from `update_user_contacts` returns early, so the LiveKit cleanup is skipped.
/// LiveKit failures themselves are only logged.
async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Declared up front and assigned inside the `if let` below so the values
    // remain available after it.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel;

    // NOTE(review): the `session.db()` guard is a temporary in this `if let`
    // scrutinee. Under edition-2021 temporary rules it stays alive until the
    // whole `if let`/`else` finishes. Keep that in mind before restructuring
    // this (e.g. into `let ... else`).
    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Taken before `room` itself is taken, so the `room` broadcast below
        // carries a default `live_kit_room` value.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the connection-pool guard is released before awaiting
    // `update_user_contacts`, which acquires the pool again.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    if let Some(live_kit) = session.live_kit_client.as_ref() {
        // Cloned because `live_kit_room` is needed again for `delete_room`.
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
4787
4788async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4789 let left_channel_buffers = session
4790 .db()
4791 .await
4792 .leave_channel_buffers(session.connection_id)
4793 .await?;
4794
4795 for left_buffer in left_channel_buffers {
4796 channel_buffer_updated(
4797 session.connection_id,
4798 left_buffer.connections,
4799 &proto::UpdateChannelBufferCollaborators {
4800 channel_id: left_buffer.channel_id.to_proto(),
4801 collaborators: left_buffer.collaborators,
4802 },
4803 &session.peer,
4804 );
4805 }
4806
4807 Ok(())
4808}
4809
4810fn project_left(project: &db::LeftProject, session: &UserSession) {
4811 for connection_id in &project.connection_ids {
4812 if project.host_user_id == Some(session.user_id()) {
4813 session
4814 .peer
4815 .send(
4816 *connection_id,
4817 proto::UnshareProject {
4818 project_id: project.id.to_proto(),
4819 },
4820 )
4821 .trace_err();
4822 } else {
4823 session
4824 .peer
4825 .send(
4826 *connection_id,
4827 proto::RemoveProjectCollaborator {
4828 project_id: project.id.to_proto(),
4829 peer_id: Some(session.connection_id.into()),
4830 },
4831 )
4832 .trace_err();
4833 }
4834 }
4835}
4836
4837pub trait ResultExt {
4838 type Ok;
4839
4840 fn trace_err(self) -> Option<Self::Ok>;
4841}
4842
4843impl<T, E> ResultExt for Result<T, E>
4844where
4845 E: std::fmt::Debug,
4846{
4847 type Ok = T;
4848
4849 #[track_caller]
4850 fn trace_err(self) -> Option<T> {
4851 match self {
4852 Ok(value) => Some(value),
4853 Err(error) => {
4854 tracing::error!("{:?}", error);
4855 None
4856 }
4857 }
4858 }
4859}