1mod connection_pool;
2
3use crate::{
4 auth,
5 db::{
6 self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
7 CreatedChannelMessage, Database, DevServerId, InviteMemberResult, MembershipUpdated,
8 MessageId, NotificationId, PrincipalId, Project, ProjectId, RejoinedProject,
9 RemoteProjectId, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
10 ServerId, UpdatedChannelMessage, User, UserId,
11 },
12 executor::Executor,
13 AppState, Error, RateLimit, RateLimiter, Result,
14};
15use anyhow::{anyhow, Context as _};
16use async_tungstenite::tungstenite::{
17 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
18};
19use axum::{
20 body::Body,
21 extract::{
22 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
23 ConnectInfo, WebSocketUpgrade,
24 },
25 headers::{Header, HeaderName},
26 http::StatusCode,
27 middleware,
28 response::IntoResponse,
29 routing::get,
30 Extension, Router, TypedHeader,
31};
32use collections::{HashMap, HashSet};
33pub use connection_pool::{ConnectionPool, ZedVersion};
34use core::fmt::{self, Debug, Formatter};
35use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
36use sha2::Digest;
37
38use futures::{
39 channel::oneshot,
40 future::{self, BoxFuture},
41 stream::FuturesUnordered,
42 FutureExt, SinkExt, StreamExt, TryStreamExt,
43};
44use prometheus::{register_int_gauge, IntGauge};
45use rpc::{
46 proto::{
47 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LanguageModelRole,
48 LiveKitConnectionInfo, RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
49 },
50 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
51};
52use semantic_version::SemanticVersion;
53use serde::{Serialize, Serializer};
54use std::{
55 any::TypeId,
56 future::Future,
57 marker::PhantomData,
58 mem,
59 net::SocketAddr,
60 ops::{Deref, DerefMut},
61 rc::Rc,
62 sync::{
63 atomic::{AtomicBool, Ordering::SeqCst},
64 Arc, OnceLock,
65 },
66 time::{Duration, Instant},
67};
68use time::OffsetDateTime;
69use tokio::sync::{watch, Semaphore};
70use tower::ServiceBuilder;
71use tracing::{
72 field::{self},
73 info_span, instrument, Instrument,
74};
75use util::http::IsahcHttpClient;
76
/// Reconnection timeout (30 seconds).
/// NOTE(review): the code that consumes this constant is not visible here — confirm exact usage.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can
// clean up old resources. `Server::start` sleeps for this long before clearing stale
// rooms and channel buffers.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// Paging and size limits. Presumably applied by the channel-message and notification
// handlers defined elsewhere in this file — TODO confirm.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
85
/// Type-erased message handler, stored by the server in a map keyed by the message's `TypeId`.
///
/// A handler receives the boxed envelope together with the `Session` of the connection
/// that sent it, and returns a boxed future that resolves to `()`.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
88
89struct Response<R> {
90 peer: Arc<Peer>,
91 receipt: Receipt<R>,
92 responded: Arc<AtomicBool>,
93}
94
95impl<R: RequestMessage> Response<R> {
96 fn send(self, payload: R::Response) -> Result<()> {
97 self.responded.store(true, SeqCst);
98 self.peer.respond(self.receipt, payload)?;
99 Ok(())
100 }
101}
102
103struct StreamingResponse<R: RequestMessage> {
104 peer: Arc<Peer>,
105 receipt: Receipt<R>,
106}
107
108impl<R: RequestMessage> StreamingResponse<R> {
109 fn send(&self, payload: R::Response) -> Result<()> {
110 self.peer.respond(self.receipt, payload)?;
111 Ok(())
112 }
113}
114
115#[derive(Clone, Debug)]
116pub enum Principal {
117 User(User),
118 Impersonated { user: User, admin: User },
119 DevServer(dev_server::Model),
120}
121
122impl Principal {
123 fn update_span(&self, span: &tracing::Span) {
124 match &self {
125 Principal::User(user) => {
126 span.record("user_id", &user.id.0);
127 span.record("login", &user.github_login);
128 }
129 Principal::Impersonated { user, admin } => {
130 span.record("user_id", &user.id.0);
131 span.record("login", &user.github_login);
132 span.record("impersonator", &admin.github_login);
133 }
134 Principal::DevServer(dev_server) => {
135 span.record("dev_server_id", &dev_server.id.0);
136 }
137 }
138 }
139}
140
/// State handed to every message handler for a given connection.
///
/// The struct is `Clone`; the `Arc` fields are shared between clones.
#[derive(Clone)]
struct Session {
    /// Identity this connection is acting as.
    principal: Principal,
    /// Id of the connection this session belongs to.
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; access it through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    /// RPC peer used to send messages and responses.
    peer: Arc<Peer>,
    /// Connection pool behind a `parking_lot` mutex; access it through `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// LiveKit API client, if one is configured.
    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
    /// HTTP client for outbound requests.
    http_client: IsahcHttpClient,
    /// Shared rate limiter.
    rate_limiter: Arc<RateLimiter>,
    /// Executor handle; the leading underscore marks it as not read in the visible code.
    _executor: Executor,
}
153
154impl Session {
155 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
156 #[cfg(test)]
157 tokio::task::yield_now().await;
158 let guard = self.db.lock().await;
159 #[cfg(test)]
160 tokio::task::yield_now().await;
161 guard
162 }
163
164 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
165 #[cfg(test)]
166 tokio::task::yield_now().await;
167 let guard = self.connection_pool.lock();
168 ConnectionPoolGuard {
169 guard,
170 _not_send: PhantomData,
171 }
172 }
173
174 fn for_user(self) -> Option<UserSession> {
175 UserSession::new(self)
176 }
177
178 fn for_dev_server(self) -> Option<DevServerSession> {
179 DevServerSession::new(self)
180 }
181
182 fn user_id(&self) -> Option<UserId> {
183 match &self.principal {
184 Principal::User(user) => Some(user.id),
185 Principal::Impersonated { user, .. } => Some(user.id),
186 Principal::DevServer(_) => None,
187 }
188 }
189
190 fn dev_server_id(&self) -> Option<DevServerId> {
191 match &self.principal {
192 Principal::User(_) | Principal::Impersonated { .. } => None,
193 Principal::DevServer(dev_server) => Some(dev_server.id),
194 }
195 }
196
197 fn principal_id(&self) -> PrincipalId {
198 match &self.principal {
199 Principal::User(user) => PrincipalId::UserId(user.id),
200 Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
201 Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
202 }
203 }
204}
205
206impl Debug for Session {
207 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
208 let mut result = f.debug_struct("Session");
209 match &self.principal {
210 Principal::User(user) => {
211 result.field("user", &user.github_login);
212 }
213 Principal::Impersonated { user, admin } => {
214 result.field("user", &user.github_login);
215 result.field("impersonator", &admin.github_login);
216 }
217 Principal::DevServer(dev_server) => {
218 result.field("dev_server", &dev_server.id);
219 }
220 }
221 result.field("connection_id", &self.connection_id).finish()
222 }
223}
224
225struct UserSession(Session);
226
227impl UserSession {
228 pub fn new(s: Session) -> Option<Self> {
229 s.user_id().map(|_| UserSession(s))
230 }
231 pub fn user_id(&self) -> UserId {
232 self.0.user_id().unwrap()
233 }
234}
235
236impl Deref for UserSession {
237 type Target = Session;
238
239 fn deref(&self) -> &Self::Target {
240 &self.0
241 }
242}
243impl DerefMut for UserSession {
244 fn deref_mut(&mut self) -> &mut Self::Target {
245 &mut self.0
246 }
247}
248
249struct DevServerSession(Session);
250
251impl DevServerSession {
252 pub fn new(s: Session) -> Option<Self> {
253 s.dev_server_id().map(|_| DevServerSession(s))
254 }
255 pub fn dev_server_id(&self) -> DevServerId {
256 self.0.dev_server_id().unwrap()
257 }
258}
259
260impl Deref for DevServerSession {
261 type Target = Session;
262
263 fn deref(&self) -> &Self::Target {
264 &self.0
265 }
266}
267impl DerefMut for DevServerSession {
268 fn deref_mut(&mut self) -> &mut Self::Target {
269 &mut self.0
270 }
271}
272
273fn user_handler<M: RequestMessage, Fut>(
274 handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
275) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
276where
277 Fut: Send + Future<Output = Result<()>>,
278{
279 let handler = Arc::new(handler);
280 move |message, response, session| {
281 let handler = handler.clone();
282 Box::pin(async move {
283 if let Some(user_session) = session.for_user() {
284 Ok(handler(message, response, user_session).await?)
285 } else {
286 Err(Error::Internal(anyhow!(
287 "must be a user to call {}",
288 M::NAME
289 )))
290 }
291 })
292 }
293}
294
295fn dev_server_handler<M: RequestMessage, Fut>(
296 handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
297) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
298where
299 Fut: Send + Future<Output = Result<()>>,
300{
301 let handler = Arc::new(handler);
302 move |message, response, session| {
303 let handler = handler.clone();
304 Box::pin(async move {
305 if let Some(dev_server_session) = session.for_dev_server() {
306 Ok(handler(message, response, dev_server_session).await?)
307 } else {
308 Err(Error::Internal(anyhow!(
309 "must be a dev server to call {}",
310 M::NAME
311 )))
312 }
313 })
314 }
315}
316
317fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
318 handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
319) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
320where
321 InnertRetFut: Send + Future<Output = Result<()>>,
322{
323 let handler = Arc::new(handler);
324 move |message, session| {
325 let handler = handler.clone();
326 Box::pin(async move {
327 if let Some(user_session) = session.for_user() {
328 Ok(handler(message, user_session).await?)
329 } else {
330 Err(Error::Internal(anyhow!(
331 "must be a user to call {}",
332 M::NAME
333 )))
334 }
335 })
336 }
337}
338
339struct DbHandle(Arc<Database>);
340
341impl Deref for DbHandle {
342 type Target = Database;
343
344 fn deref(&self) -> &Self::Target {
345 self.0.as_ref()
346 }
347}
348
/// The RPC server: owns the peer, the connection pool and the table of message handlers.
pub struct Server {
    /// This server's id; behind a mutex because the test-only `reset` replaces it.
    id: parking_lot::Mutex<ServerId>,
    /// RPC peer through which all connections send and receive messages.
    peer: Arc<Peer>,
    /// Pool of live connections, shared with every `Session`.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// Shared application state (database, config, executor, ...).
    app_state: Arc<AppState>,
    /// Registered handlers, keyed by the `TypeId` of the message type they handle.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Watch channel set to `true` by `teardown`; connection loops subscribe to it.
    teardown: watch::Sender<bool>,
}
357
/// Lock guard over the `ConnectionPool`.
///
/// The `PhantomData<Rc<()>>` marker makes this type `!Send`, because `Rc` is not `Send`.
/// NOTE(review): presumably this is to stop the guard being held across an `.await` in a
/// `Send` future — confirm.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
362
/// Serializable view of the server: the peer plus the locked connection pool.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // The guard is serialized through `serialize_deref`, which serializes the value the
    // guard dereferences to rather than the guard itself.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
369
370pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
371where
372 S: Serializer,
373 T: Deref<Target = U>,
374 U: Serialize,
375{
376 Serialize::serialize(value.deref(), serializer)
377}
378
379impl Server {
    /// Builds a server with the given id and registers one handler per RPC message type.
    ///
    /// Handlers wrapped in `user_handler` / `user_message_handler` fail when the calling
    /// principal is not a user; those wrapped in `dev_server_handler` fail when it is not a
    /// dev server. Unwrapped handlers receive the plain `Session`.
    ///
    /// # Panics
    ///
    /// Panics if two handlers are registered for the same message type (enforced by
    /// `add_handler`).
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            // Only the sender half is kept; receivers are created on demand via `subscribe`.
            teardown: watch::channel(false).0,
        };

        server
            // Rooms, calls and project sharing.
            .add_request_handler(ping)
            .add_request_handler(user_handler(create_room))
            .add_request_handler(user_handler(join_room))
            .add_request_handler(user_handler(rejoin_room))
            .add_request_handler(user_handler(leave_room))
            .add_request_handler(user_handler(set_room_participant_role))
            .add_request_handler(user_handler(call))
            .add_request_handler(user_handler(cancel_call))
            .add_message_handler(user_message_handler(decline_call))
            .add_request_handler(user_handler(update_participant_location))
            .add_request_handler(user_handler(share_project))
            .add_message_handler(unshare_project)
            .add_request_handler(user_handler(join_project))
            .add_request_handler(user_handler(join_hosted_project))
            .add_request_handler(user_handler(rejoin_remote_projects))
            .add_request_handler(user_handler(create_remote_project))
            .add_request_handler(user_handler(create_dev_server))
            .add_request_handler(dev_server_handler(share_remote_project))
            .add_request_handler(dev_server_handler(shutdown_dev_server))
            .add_request_handler(dev_server_handler(reconnect_dev_server))
            .add_message_handler(user_message_handler(leave_project))
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Project requests routed through `forward_read_only_project_request`.
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetHover>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetTypeDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetReferences>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::SearchProject>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDocumentHighlights>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetProjectSymbols>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferForSymbol>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferById>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::SynchronizeBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::InlayHints>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferByPath>,
            ))
            // Project requests routed through `forward_mutating_project_request`.
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCompletions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCodeActions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCodeAction>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PrepareRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PerformRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ReloadBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::FormatBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CreateProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RenameProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CopyProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::DeleteProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ExpandProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OnTypeFormatting>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::SaveBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::BlameBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::MultiLspQuery>,
            ))
            // Buffer messages, including those routed through
            // `broadcast_project_message_from_host`.
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            // Users, contacts, channels, chat, notifications and following.
            .add_request_handler(get_users)
            .add_request_handler(user_handler(fuzzy_search_users))
            .add_request_handler(user_handler(request_contact))
            .add_request_handler(user_handler(remove_contact))
            .add_request_handler(user_handler(respond_to_contact_request))
            .add_request_handler(user_handler(create_channel))
            .add_request_handler(user_handler(delete_channel))
            .add_request_handler(user_handler(invite_channel_member))
            .add_request_handler(user_handler(remove_channel_member))
            .add_request_handler(user_handler(set_channel_member_role))
            .add_request_handler(user_handler(set_channel_visibility))
            .add_request_handler(user_handler(rename_channel))
            .add_request_handler(user_handler(join_channel_buffer))
            .add_request_handler(user_handler(leave_channel_buffer))
            .add_message_handler(user_message_handler(update_channel_buffer))
            .add_request_handler(user_handler(rejoin_channel_buffers))
            .add_request_handler(user_handler(get_channel_members))
            .add_request_handler(user_handler(respond_to_channel_invite))
            .add_request_handler(user_handler(join_channel))
            .add_request_handler(user_handler(join_channel_chat))
            .add_message_handler(user_message_handler(leave_channel_chat))
            .add_request_handler(user_handler(send_channel_message))
            .add_request_handler(user_handler(remove_channel_message))
            .add_request_handler(user_handler(update_channel_message))
            .add_request_handler(user_handler(get_channel_messages))
            .add_request_handler(user_handler(get_channel_messages_by_id))
            .add_request_handler(user_handler(get_notifications))
            .add_request_handler(user_handler(mark_notification_as_read))
            .add_request_handler(user_handler(move_channel))
            .add_request_handler(user_handler(follow))
            .add_message_handler(user_message_handler(unfollow))
            .add_message_handler(user_message_handler(update_followers))
            .add_request_handler(user_handler(get_private_user_info))
            .add_message_handler(user_message_handler(acknowledge_channel_message))
            .add_message_handler(user_message_handler(acknowledge_buffer_version))
            // Language-model and embedding handlers. Each closure captures its own clone of
            // `app_state` and clones the relevant API key(s) out of the config on every call.
            .add_streaming_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    complete_with_language_model(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                        app_state.config.google_ai_api_key.clone(),
                        app_state.config.anthropic_api_key.clone(),
                    )
                }
            })
            .add_request_handler({
                let app_state = app_state.clone();
                user_handler(move |request, response, session| {
                    count_tokens_with_language_model(
                        request,
                        response,
                        session,
                        app_state.config.google_ai_api_key.clone(),
                    )
                })
            })
            .add_request_handler({
                user_handler(move |request, response, session| {
                    get_cached_embeddings(request, response, session)
                })
            })
            .add_request_handler({
                let app_state = app_state.clone();
                user_handler(move |request, response, session| {
                    compute_embeddings(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                    )
                })
            });

        Arc::new(server)
    }
593
    /// Schedules the cleanup of resources left behind by stale servers.
    ///
    /// A detached task is spawned that:
    /// 1. sleeps for `CLEANUP_TIMEOUT`;
    /// 2. asks the database for stale room and channel ids for this environment and server id;
    /// 3. clears stale channel-buffer collaborators and notifies the affected connections;
    /// 4. clears stale room participants, notifies rooms/channels, cancels pending calls,
    ///    refreshes the contacts of affected users and deletes LiveKit rooms left empty;
    /// 5. deletes the stale server records.
    ///
    /// Every database and send error inside the task is passed to `trace_err` and does not
    /// abort the task. This method itself only spawns the task and always returns `Ok(())`.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Channel buffers: drop stale collaborators and push the refreshed list
                    // to every connection the database reports for the buffer.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Collected while the refreshed room is in scope, consumed below.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            // A room left with no participants has its LiveKit room deleted below.
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // The pool lock is scoped to this block so it is released before the
                        // awaits that follow.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Send each affected user's refreshed contact entry to every
                        // connection of their accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                // Runs even when the stale-resource lookup above failed.
                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
739
740 pub fn teardown(&self) {
741 self.peer.teardown();
742 self.connection_pool.lock().reset();
743 let _ = self.teardown.send(true);
744 }
745
746 #[cfg(test)]
747 pub fn reset(&self, id: ServerId) {
748 self.teardown();
749 *self.id.lock() = id;
750 self.peer.reset(id.0 as u32);
751 let _ = self.teardown.send(false);
752 }
753
754 #[cfg(test)]
755 pub fn id(&self) -> ServerId {
756 *self.id.lock()
757 }
758
759 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
760 where
761 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
762 Fut: 'static + Send + Future<Output = Result<()>>,
763 M: EnvelopedMessage,
764 {
765 let prev_handler = self.handlers.insert(
766 TypeId::of::<M>(),
767 Box::new(move |envelope, session| {
768 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
769 let received_at = envelope.received_at;
770 tracing::info!(
771 "message received"
772 );
773 let start_time = Instant::now();
774 let future = (handler)(*envelope, session);
775 async move {
776 let result = future.await;
777 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
778 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
779 let queue_duration_ms = total_duration_ms - processing_duration_ms;
780 let payload_type = M::NAME;
781 match result {
782 Err(error) => {
783 // todo!(), why isn't this logged inside the span?
784 tracing::error!(%error, total_duration_ms, processing_duration_ms, queue_duration_ms, payload_type, "error handling message")
785 }
786 Ok(()) => tracing::info!(total_duration_ms, processing_duration_ms, queue_duration_ms, "finished handling message"),
787 }
788 }
789 .boxed()
790 }),
791 );
792 if prev_handler.is_some() {
793 panic!("registered a handler for the same message twice");
794 }
795 self
796 }
797
798 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
799 where
800 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
801 Fut: 'static + Send + Future<Output = Result<()>>,
802 M: EnvelopedMessage,
803 {
804 self.add_handler(move |envelope, session| handler(envelope.payload, session));
805 self
806 }
807
808 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
809 where
810 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
811 Fut: Send + Future<Output = Result<()>>,
812 M: RequestMessage,
813 {
814 let handler = Arc::new(handler);
815 self.add_handler(move |envelope, session| {
816 let receipt = envelope.receipt();
817 let handler = handler.clone();
818 async move {
819 let peer = session.peer.clone();
820 let responded = Arc::new(AtomicBool::default());
821 let response = Response {
822 peer: peer.clone(),
823 responded: responded.clone(),
824 receipt,
825 };
826 match (handler)(envelope.payload, response, session).await {
827 Ok(()) => {
828 if responded.load(std::sync::atomic::Ordering::SeqCst) {
829 Ok(())
830 } else {
831 Err(anyhow!("handler did not send a response"))?
832 }
833 }
834 Err(error) => {
835 let proto_err = match &error {
836 Error::Internal(err) => err.to_proto(),
837 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
838 };
839 peer.respond_with_error(receipt, proto_err)?;
840 Err(error)
841 }
842 }
843 }
844 })
845 }
846
847 fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
848 where
849 F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
850 Fut: Send + Future<Output = Result<()>>,
851 M: RequestMessage,
852 {
853 let handler = Arc::new(handler);
854 self.add_handler(move |envelope, session| {
855 let receipt = envelope.receipt();
856 let handler = handler.clone();
857 async move {
858 let peer = session.peer.clone();
859 let response = StreamingResponse {
860 peer: peer.clone(),
861 receipt,
862 };
863 match (handler)(envelope.payload, response, session).await {
864 Ok(()) => {
865 peer.end_stream(receipt)?;
866 Ok(())
867 }
868 Err(error) => {
869 let proto_err = match &error {
870 Error::Internal(err) => err.to_proto(),
871 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
872 };
873 peer.respond_with_error(receipt, proto_err)?;
874 Err(error)
875 }
876 }
877 }
878 })
879 }
880
    /// Returns a future that drives one client connection until it closes, its I/O ends,
    /// or the server is torn down.
    ///
    /// The future:
    /// 1. returns immediately if the server is already tearing down;
    /// 2. registers `connection` with the peer and builds a `Session` for it;
    /// 3. sends the initial client update (passing `send_connection_id` along to it);
    /// 4. dispatches incoming messages to the registered handlers until the connection ends;
    /// 5. calls `connection_lost` to sign the session out.
    ///
    /// At most 256 handler invocations are in flight per connection: a semaphore permit is
    /// acquired before the next message is read and released when its handler finishes.
    ///
    /// NOTE(review): the early returns after `add_connection` (HTTP client creation failure,
    /// initial update failure) and the teardown branch of the loop all skip
    /// `connection_lost` — confirm that this is intended.
    #[allow(clippy::too_many_arguments)]
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> {
        let this = self.clone();
        // Fields declared as `Empty` are filled in later: the principal's fields right
        // below, `connection_id` once the peer has assigned it.
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            dev_server_id=field::Empty
        );
        principal.update_span(&span);

        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));
            tracing::info!("connection opened");

            let http_client = match IsahcHttpClient::new() {
                Ok(http_client) => http_client,
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                http_client,
                rate_limiter: this.app_state.rate_limiter.clone(),
                _executor: executor.clone(),
            };

            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                // A permit is acquired before the next message is read, so no further
                // message is taken off the channel while all permits are in use.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // `select_biased!` polls the branches in the order written: teardown, then
                // I/O, then in-flight foreground handlers, then the next incoming message.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                dev_server_id=field::Empty
                            );
                            principal.update_span(&span);
                            // The span is entered only while the handler future is created;
                            // the future itself is instrumented with the span below.
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                // The permit is released only once the handler has finished.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background messages run detached on the executor; foreground
                                // ones are polled by this loop.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers that are still pending.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
1017
    /// Sends a newly established connection everything it needs to get started.
    ///
    /// Every principal first receives `proto::Hello` carrying its own peer id, and the
    /// connection id is reported through `send_connection_id` when a sender was supplied.
    /// The remaining work depends on the principal:
    ///
    /// * `Principal::User` / `Principal::Impersonated`: the connection is added to the
    ///   pool and subscribed to the user's channel memberships, then the contacts,
    ///   channel and invite updates plus any pending incoming call are sent.
    /// * `Principal::DevServer`: the dev server is added to the pool, reported as online,
    ///   and sent its remote projects as `proto::DevServerInstructions`.
    ///
    /// # Errors
    ///
    /// Fails if a send or database call fails, or with `ErrorCode::DevServerAlreadyOnline`
    /// when the pool already holds a connection for the dev server.
    async fn send_initial_client_update(
        &self,
        connection_id: ConnectionId,
        principal: &Principal,
        zed_version: ZedVersion,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        session: &Session,
    ) -> Result<()> {
        self.peer.send(
            connection_id,
            proto::Hello {
                peer_id: Some(connection_id.into()),
            },
        )?;
        tracing::info!("sent hello message");
        // Reporting the connection id is best-effort: a dropped receiver is ignored.
        if let Some(send_connection_id) = send_connection_id.take() {
            let _ = send_connection_id.send(connection_id);
        }

        match principal {
            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
                // First-ever connection: ask the client to show contacts and persist
                // the flag so this happens only once.
                if !user.connected_once {
                    self.peer.send(connection_id, proto::ShowContacts {})?;
                    self.app_state
                        .db
                        .set_user_connected_once(user.id, true)
                        .await?;
                }

                // Run the three lookups concurrently; the first failure aborts.
                let (contacts, channels_for_user, channel_invites) = future::try_join3(
                    self.app_state.db.get_contacts(user.id),
                    self.app_state.db.get_channels_for_user(user.id),
                    self.app_state.db.get_channel_invites_for_user(user.id),
                )
                .await?;

                // The pool lock is confined to this block so that it is released
                // before the next `.await`.
                {
                    let mut pool = self.connection_pool.lock();
                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
                    for membership in &channels_for_user.channel_memberships {
                        pool.subscribe_to_channel(user.id, membership.channel_id, membership.role)
                    }
                    self.peer.send(
                        connection_id,
                        build_initial_contacts_update(contacts, &pool),
                    )?;
                    self.peer.send(
                        connection_id,
                        build_update_user_channels(&channels_for_user),
                    )?;
                    self.peer.send(
                        connection_id,
                        build_channels_update(channels_for_user, channel_invites, &pool),
                    )?;
                }

                // Forward a call that is already waiting for this user, if any.
                if let Some(incoming_call) =
                    self.app_state.db.incoming_call_for_user(user.id).await?
                {
                    self.peer.send(connection_id, incoming_call)?;
                }

                update_user_contacts(user.id, &session).await?;
            }
            Principal::DevServer(dev_server) => {
                // Check and registration happen under a single lock acquisition.
                {
                    let mut pool = self.connection_pool.lock();
                    if pool.dev_server_connection_id(dev_server.id).is_some() {
                        return Err(anyhow!(ErrorCode::DevServerAlreadyOnline))?;
                    };
                    pool.add_dev_server(connection_id, dev_server.id, zed_version);
                }
                update_dev_server_status(dev_server, proto::DevServerStatus::Online, &session)
                    .await;
                // todo!() allow only one connection.

                let projects = self
                    .app_state
                    .db
                    .get_remote_projects_for_dev_server(dev_server.id)
                    .await?;
                self.peer
                    .send(connection_id, proto::DevServerInstructions { projects })?;
            }
        }

        Ok(())
    }
1106
1107 pub async fn invite_code_redeemed(
1108 self: &Arc<Self>,
1109 inviter_id: UserId,
1110 invitee_id: UserId,
1111 ) -> Result<()> {
1112 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1113 if let Some(code) = &user.invite_code {
1114 let pool = self.connection_pool.lock();
1115 let invitee_contact = contact_for_user(invitee_id, false, &pool);
1116 for connection_id in pool.user_connection_ids(inviter_id) {
1117 self.peer.send(
1118 connection_id,
1119 proto::UpdateContacts {
1120 contacts: vec![invitee_contact.clone()],
1121 ..Default::default()
1122 },
1123 )?;
1124 self.peer.send(
1125 connection_id,
1126 proto::UpdateInviteInfo {
1127 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1128 count: user.invite_count as u32,
1129 },
1130 )?;
1131 }
1132 }
1133 }
1134 Ok(())
1135 }
1136
1137 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1138 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1139 if let Some(invite_code) = &user.invite_code {
1140 let pool = self.connection_pool.lock();
1141 for connection_id in pool.user_connection_ids(user_id) {
1142 self.peer.send(
1143 connection_id,
1144 proto::UpdateInviteInfo {
1145 url: format!(
1146 "{}{}",
1147 self.app_state.config.invite_link_prefix, invite_code
1148 ),
1149 count: user.invite_count as u32,
1150 },
1151 )?;
1152 }
1153 }
1154 }
1155 Ok(())
1156 }
1157
1158 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1159 ServerSnapshot {
1160 connection_pool: ConnectionPoolGuard {
1161 guard: self.connection_pool.lock(),
1162 _not_send: PhantomData,
1163 },
1164 peer: &self.peer,
1165 }
1166 }
1167}
1168
1169impl<'a> Deref for ConnectionPoolGuard<'a> {
1170 type Target = ConnectionPool;
1171
1172 fn deref(&self) -> &Self::Target {
1173 &self.guard
1174 }
1175}
1176
1177impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1178 fn deref_mut(&mut self) -> &mut Self::Target {
1179 &mut self.guard
1180 }
1181}
1182
1183impl<'a> Drop for ConnectionPoolGuard<'a> {
1184 fn drop(&mut self) {
1185 #[cfg(test)]
1186 self.check_invariants();
1187 }
1188}
1189
1190fn broadcast<F>(
1191 sender_id: Option<ConnectionId>,
1192 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1193 mut f: F,
1194) where
1195 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1196{
1197 for receiver_id in receiver_ids {
1198 if Some(receiver_id) != sender_id {
1199 if let Err(error) = f(receiver_id) {
1200 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1201 }
1202 }
1203 }
1204}
1205
1206pub struct ProtocolVersion(u32);
1207
1208impl Header for ProtocolVersion {
1209 fn name() -> &'static HeaderName {
1210 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1211 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1212 }
1213
1214 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1215 where
1216 Self: Sized,
1217 I: Iterator<Item = &'i axum::http::HeaderValue>,
1218 {
1219 let version = values
1220 .next()
1221 .ok_or_else(axum::headers::Error::invalid)?
1222 .to_str()
1223 .map_err(|_| axum::headers::Error::invalid())?
1224 .parse()
1225 .map_err(|_| axum::headers::Error::invalid())?;
1226 Ok(Self(version))
1227 }
1228
1229 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1230 values.extend([self.0.to_string().parse().unwrap()]);
1231 }
1232}
1233
1234pub struct AppVersionHeader(SemanticVersion);
1235impl Header for AppVersionHeader {
1236 fn name() -> &'static HeaderName {
1237 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1238 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1239 }
1240
1241 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1242 where
1243 Self: Sized,
1244 I: Iterator<Item = &'i axum::http::HeaderValue>,
1245 {
1246 let version = values
1247 .next()
1248 .ok_or_else(axum::headers::Error::invalid)?
1249 .to_str()
1250 .map_err(|_| axum::headers::Error::invalid())?
1251 .parse()
1252 .map_err(|_| axum::headers::Error::invalid())?;
1253 Ok(Self(version))
1254 }
1255
1256 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1257 values.extend([self.0.to_string().parse().unwrap()]);
1258 }
1259}
1260
1261pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1262 Router::new()
1263 .route("/rpc", get(handle_websocket_request))
1264 .layer(
1265 ServiceBuilder::new()
1266 .layer(Extension(server.app_state.clone()))
1267 .layer(middleware::from_fn(auth::validate_header)),
1268 )
1269 .route("/metrics", get(handle_metrics))
1270 .layer(Extension(server))
1271}
1272
1273pub async fn handle_websocket_request(
1274 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1275 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1276 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1277 Extension(server): Extension<Arc<Server>>,
1278 Extension(principal): Extension<Principal>,
1279 ws: WebSocketUpgrade,
1280) -> axum::response::Response {
1281 if protocol_version != rpc::PROTOCOL_VERSION {
1282 return (
1283 StatusCode::UPGRADE_REQUIRED,
1284 "client must be upgraded".to_string(),
1285 )
1286 .into_response();
1287 }
1288
1289 let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1290 return (
1291 StatusCode::UPGRADE_REQUIRED,
1292 "no version header found".to_string(),
1293 )
1294 .into_response();
1295 };
1296
1297 if !version.can_collaborate() {
1298 return (
1299 StatusCode::UPGRADE_REQUIRED,
1300 "client must be upgraded".to_string(),
1301 )
1302 .into_response();
1303 }
1304
1305 let socket_address = socket_address.to_string();
1306 ws.on_upgrade(move |socket| {
1307 let socket = socket
1308 .map_ok(to_tungstenite_message)
1309 .err_into()
1310 .with(|message| async move { Ok(to_axum_message(message)) });
1311 let connection = Connection::new(Box::pin(socket));
1312 async move {
1313 server
1314 .handle_connection(
1315 connection,
1316 socket_address,
1317 principal,
1318 version,
1319 None,
1320 Executor::Production,
1321 )
1322 .await;
1323 }
1324 })
1325}
1326
1327pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1328 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1329 let connections_metric = CONNECTIONS_METRIC
1330 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1331
1332 let connections = server
1333 .connection_pool
1334 .lock()
1335 .connections()
1336 .filter(|connection| !connection.admin)
1337 .count();
1338 connections_metric.set(connections as _);
1339
1340 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1341 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1342 register_int_gauge!(
1343 "shared_projects",
1344 "number of open projects with one or more guests"
1345 )
1346 .unwrap()
1347 });
1348
1349 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1350 shared_projects_metric.set(shared_projects as _);
1351
1352 let encoder = prometheus::TextEncoder::new();
1353 let metric_families = prometheus::gather();
1354 let encoded_metrics = encoder
1355 .encode_to_string(&metric_families)
1356 .map_err(|err| anyhow!("{}", err))?;
1357 Ok(encoded_metrics)
1358}
1359
/// Cleans up after a connection has gone away.
///
/// The peer is disconnected and the connection is removed from the pool right away,
/// and the loss is recorded in the database (a failure there is traced, not returned).
/// The function then waits for `RECONNECT_TIMEOUT`; if `teardown` changes first, no
/// further cleanup is done. When the timeout elapses first:
///
/// * for a user, the session leaves its room and channel buffers, a pending call is
///   declined if the pool reports no remaining connection for the user, and the user's
///   contacts are updated;
/// * for a dev server, `lost_dev_server_connection` runs and the dev server is
///   reported as offline.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // Best-effort: a database failure here must not prevent the rest of the cleanup.
    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls the branches in order, so the sleep is checked first.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            match &session.principal {
                Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
                    // The principal was just matched as a user, so the conversion is
                    // expected to succeed.
                    let session = session.for_user().unwrap();

                    log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
                    leave_room_for_session(&session, session.connection_id).await.trace_err();
                    leave_channel_buffers_for_session(&session)
                        .await
                        .trace_err();

                    // Only decline a pending call when the pool reports the user as
                    // no longer online.
                    if !session
                        .connection_pool()
                        .await
                        .is_user_online(session.user_id())
                    {
                        let db = session.db().await;
                        if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                            room_updated(&room, &session.peer);
                        }
                    }

                    update_user_contacts(session.user_id(), &session).await?;
                },
                Principal::DevServer(dev_server) => {
                    lost_dev_server_connection(&session).await?;
                    update_dev_server_status(&dev_server, proto::DevServerStatus::Offline, &session)
                        .await;
                },
            }
        },
        // Teardown signalled before the timeout: skip the per-principal cleanup.
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1416
1417/// Acknowledges a ping from a client, used to keep the connection alive.
1418async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1419 response.send(proto::Ack {})?;
1420 Ok(())
1421}
1422
1423/// Creates a new room for calling (outside of channels)
1424async fn create_room(
1425 _request: proto::CreateRoom,
1426 response: Response<proto::CreateRoom>,
1427 session: UserSession,
1428) -> Result<()> {
1429 let live_kit_room = nanoid::nanoid!(30);
1430
1431 let live_kit_connection_info = util::maybe!(async {
1432 let live_kit = session.live_kit_client.as_ref();
1433 let live_kit = live_kit?;
1434 let user_id = session.user_id().to_string();
1435
1436 let token = live_kit
1437 .room_token(&live_kit_room, &user_id.to_string())
1438 .trace_err()?;
1439
1440 Some(proto::LiveKitConnectionInfo {
1441 server_url: live_kit.url().into(),
1442 token,
1443 can_publish: true,
1444 })
1445 })
1446 .await;
1447
1448 let room = session
1449 .db()
1450 .await
1451 .create_room(session.user_id(), session.connection_id, &live_kit_room)
1452 .await?;
1453
1454 response.send(proto::CreateRoomResponse {
1455 room: Some(room.clone()),
1456 live_kit_connection_info,
1457 })?;
1458
1459 update_user_contacts(session.user_id(), &session).await?;
1460 Ok(())
1461}
1462
1463/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1464async fn join_room(
1465 request: proto::JoinRoom,
1466 response: Response<proto::JoinRoom>,
1467 session: UserSession,
1468) -> Result<()> {
1469 let room_id = RoomId::from_proto(request.id);
1470
1471 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1472
1473 if let Some(channel_id) = channel_id {
1474 return join_channel_internal(channel_id, Box::new(response), session).await;
1475 }
1476
1477 let joined_room = {
1478 let room = session
1479 .db()
1480 .await
1481 .join_room(room_id, session.user_id(), session.connection_id)
1482 .await?;
1483 room_updated(&room.room, &session.peer);
1484 room.into_inner()
1485 };
1486
1487 for connection_id in session
1488 .connection_pool()
1489 .await
1490 .user_connection_ids(session.user_id())
1491 {
1492 session
1493 .peer
1494 .send(
1495 connection_id,
1496 proto::CallCanceled {
1497 room_id: room_id.to_proto(),
1498 },
1499 )
1500 .trace_err();
1501 }
1502
1503 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1504 if let Some(token) = live_kit
1505 .room_token(
1506 &joined_room.room.live_kit_room,
1507 &session.user_id().to_string(),
1508 )
1509 .trace_err()
1510 {
1511 Some(proto::LiveKitConnectionInfo {
1512 server_url: live_kit.url().into(),
1513 token,
1514 can_publish: true,
1515 })
1516 } else {
1517 None
1518 }
1519 } else {
1520 None
1521 };
1522
1523 response.send(proto::JoinRoomResponse {
1524 room: Some(joined_room.room),
1525 channel_id: None,
1526 live_kit_connection_info,
1527 })?;
1528
1529 update_user_contacts(session.user_id(), &session).await?;
1530 Ok(())
1531}
1532
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// The response lists the room together with the projects the caller re-shared and
/// re-joined. Collaborators of each re-shared project are told about the caller's new
/// peer id and sent the project's worktrees; re-joined projects are handled by
/// `notify_rejoined_projects`. If the room belongs to a channel, a channel update is
/// sent as well, and finally the caller's contacts are updated.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: UserSession,
) -> Result<()> {
    // Assigned inside the block below so that `rejoined_room` goes out of scope
    // before the connection pool is awaited.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell each collaborator that the caller's peer id changed.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Send the project's worktrees to every collaborator except the caller.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1624
1625fn notify_rejoined_projects(
1626 rejoined_projects: &mut Vec<RejoinedProject>,
1627 session: &UserSession,
1628) -> Result<()> {
1629 for project in rejoined_projects.iter() {
1630 for collaborator in &project.collaborators {
1631 session
1632 .peer
1633 .send(
1634 collaborator.connection_id,
1635 proto::UpdateProjectCollaborator {
1636 project_id: project.id.to_proto(),
1637 old_peer_id: Some(project.old_connection_id.into()),
1638 new_peer_id: Some(session.connection_id.into()),
1639 },
1640 )
1641 .trace_err();
1642 }
1643 }
1644
1645 for project in rejoined_projects {
1646 for worktree in mem::take(&mut project.worktrees) {
1647 #[cfg(any(test, feature = "test-support"))]
1648 const MAX_CHUNK_SIZE: usize = 2;
1649 #[cfg(not(any(test, feature = "test-support")))]
1650 const MAX_CHUNK_SIZE: usize = 256;
1651
1652 // Stream this worktree's entries.
1653 let message = proto::UpdateWorktree {
1654 project_id: project.id.to_proto(),
1655 worktree_id: worktree.id,
1656 abs_path: worktree.abs_path.clone(),
1657 root_name: worktree.root_name,
1658 updated_entries: worktree.updated_entries,
1659 removed_entries: worktree.removed_entries,
1660 scan_id: worktree.scan_id,
1661 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1662 updated_repositories: worktree.updated_repositories,
1663 removed_repositories: worktree.removed_repositories,
1664 };
1665 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1666 session.peer.send(session.connection_id, update.clone())?;
1667 }
1668
1669 // Stream this worktree's diagnostics.
1670 for summary in worktree.diagnostic_summaries {
1671 session.peer.send(
1672 session.connection_id,
1673 proto::UpdateDiagnosticSummary {
1674 project_id: project.id.to_proto(),
1675 worktree_id: worktree.id,
1676 summary: Some(summary),
1677 },
1678 )?;
1679 }
1680
1681 for settings_file in worktree.settings_files {
1682 session.peer.send(
1683 session.connection_id,
1684 proto::UpdateWorktreeSettings {
1685 project_id: project.id.to_proto(),
1686 worktree_id: worktree.id,
1687 path: settings_file.path,
1688 content: Some(settings_file.content),
1689 },
1690 )?;
1691 }
1692 }
1693
1694 for language_server in &project.language_servers {
1695 session.peer.send(
1696 session.connection_id,
1697 proto::UpdateLanguageServer {
1698 project_id: project.id.to_proto(),
1699 language_server_id: language_server.id,
1700 variant: Some(
1701 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1702 proto::LspDiskBasedDiagnosticsUpdated {},
1703 ),
1704 ),
1705 },
1706 )?;
1707 }
1708 }
1709 Ok(())
1710}
1711
1712/// leave room disconnects from the room.
1713async fn leave_room(
1714 _: proto::LeaveRoom,
1715 response: Response<proto::LeaveRoom>,
1716 session: UserSession,
1717) -> Result<()> {
1718 leave_room_for_session(&session, session.connection_id).await?;
1719 response.send(proto::Ack {})?;
1720 Ok(())
1721}
1722
1723/// Updates the permissions of someone else in the room.
1724async fn set_room_participant_role(
1725 request: proto::SetRoomParticipantRole,
1726 response: Response<proto::SetRoomParticipantRole>,
1727 session: UserSession,
1728) -> Result<()> {
1729 let user_id = UserId::from_proto(request.user_id);
1730 let role = ChannelRole::from(request.role());
1731
1732 let (live_kit_room, can_publish) = {
1733 let room = session
1734 .db()
1735 .await
1736 .set_room_participant_role(
1737 session.user_id(),
1738 RoomId::from_proto(request.room_id),
1739 user_id,
1740 role,
1741 )
1742 .await?;
1743
1744 let live_kit_room = room.live_kit_room.clone();
1745 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1746 room_updated(&room, &session.peer);
1747 (live_kit_room, can_publish)
1748 };
1749
1750 if let Some(live_kit) = session.live_kit_client.as_ref() {
1751 live_kit
1752 .update_participant(
1753 live_kit_room.clone(),
1754 request.user_id.to_string(),
1755 live_kit_server::proto::ParticipantPermission {
1756 can_subscribe: true,
1757 can_publish,
1758 can_publish_data: can_publish,
1759 hidden: false,
1760 recorder: false,
1761 },
1762 )
1763 .await
1764 .trace_err();
1765 }
1766
1767 response.send(proto::Ack {})?;
1768 Ok(())
1769}
1770
1771/// Call someone else into the current room
1772async fn call(
1773 request: proto::Call,
1774 response: Response<proto::Call>,
1775 session: UserSession,
1776) -> Result<()> {
1777 let room_id = RoomId::from_proto(request.room_id);
1778 let calling_user_id = session.user_id();
1779 let calling_connection_id = session.connection_id;
1780 let called_user_id = UserId::from_proto(request.called_user_id);
1781 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1782 if !session
1783 .db()
1784 .await
1785 .has_contact(calling_user_id, called_user_id)
1786 .await?
1787 {
1788 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1789 }
1790
1791 let incoming_call = {
1792 let (room, incoming_call) = &mut *session
1793 .db()
1794 .await
1795 .call(
1796 room_id,
1797 calling_user_id,
1798 calling_connection_id,
1799 called_user_id,
1800 initial_project_id,
1801 )
1802 .await?;
1803 room_updated(&room, &session.peer);
1804 mem::take(incoming_call)
1805 };
1806 update_user_contacts(called_user_id, &session).await?;
1807
1808 let mut calls = session
1809 .connection_pool()
1810 .await
1811 .user_connection_ids(called_user_id)
1812 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1813 .collect::<FuturesUnordered<_>>();
1814
1815 while let Some(call_response) = calls.next().await {
1816 match call_response.as_ref() {
1817 Ok(_) => {
1818 response.send(proto::Ack {})?;
1819 return Ok(());
1820 }
1821 Err(_) => {
1822 call_response.trace_err();
1823 }
1824 }
1825 }
1826
1827 {
1828 let room = session
1829 .db()
1830 .await
1831 .call_failed(room_id, called_user_id)
1832 .await?;
1833 room_updated(&room, &session.peer);
1834 }
1835 update_user_contacts(called_user_id, &session).await?;
1836
1837 Err(anyhow!("failed to ring user"))?
1838}
1839
1840/// Cancel an outgoing call.
1841async fn cancel_call(
1842 request: proto::CancelCall,
1843 response: Response<proto::CancelCall>,
1844 session: UserSession,
1845) -> Result<()> {
1846 let called_user_id = UserId::from_proto(request.called_user_id);
1847 let room_id = RoomId::from_proto(request.room_id);
1848 {
1849 let room = session
1850 .db()
1851 .await
1852 .cancel_call(room_id, session.connection_id, called_user_id)
1853 .await?;
1854 room_updated(&room, &session.peer);
1855 }
1856
1857 for connection_id in session
1858 .connection_pool()
1859 .await
1860 .user_connection_ids(called_user_id)
1861 {
1862 session
1863 .peer
1864 .send(
1865 connection_id,
1866 proto::CallCanceled {
1867 room_id: room_id.to_proto(),
1868 },
1869 )
1870 .trace_err();
1871 }
1872 response.send(proto::Ack {})?;
1873
1874 update_user_contacts(called_user_id, &session).await?;
1875 Ok(())
1876}
1877
1878/// Decline an incoming call.
1879async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1880 let room_id = RoomId::from_proto(message.room_id);
1881 {
1882 let room = session
1883 .db()
1884 .await
1885 .decline_call(Some(room_id), session.user_id())
1886 .await?
1887 .ok_or_else(|| anyhow!("failed to decline call"))?;
1888 room_updated(&room, &session.peer);
1889 }
1890
1891 for connection_id in session
1892 .connection_pool()
1893 .await
1894 .user_connection_ids(session.user_id())
1895 {
1896 session
1897 .peer
1898 .send(
1899 connection_id,
1900 proto::CallCanceled {
1901 room_id: room_id.to_proto(),
1902 },
1903 )
1904 .trace_err();
1905 }
1906 update_user_contacts(session.user_id(), &session).await?;
1907 Ok(())
1908}
1909
1910/// Updates other participants in the room with your current location.
1911async fn update_participant_location(
1912 request: proto::UpdateParticipantLocation,
1913 response: Response<proto::UpdateParticipantLocation>,
1914 session: UserSession,
1915) -> Result<()> {
1916 let room_id = RoomId::from_proto(request.room_id);
1917 let location = request
1918 .location
1919 .ok_or_else(|| anyhow!("invalid location"))?;
1920
1921 let db = session.db().await;
1922 let room = db
1923 .update_room_participant_location(room_id, session.connection_id, location)
1924 .await?;
1925
1926 room_updated(&room, &session.peer);
1927 response.send(proto::Ack {})?;
1928 Ok(())
1929}
1930
1931/// Share a project into the room.
1932async fn share_project(
1933 request: proto::ShareProject,
1934 response: Response<proto::ShareProject>,
1935 session: UserSession,
1936) -> Result<()> {
1937 let (project_id, room) = &*session
1938 .db()
1939 .await
1940 .share_project(
1941 RoomId::from_proto(request.room_id),
1942 session.connection_id,
1943 &request.worktrees,
1944 )
1945 .await?;
1946 response.send(proto::ShareProjectResponse {
1947 project_id: project_id.to_proto(),
1948 })?;
1949 room_updated(&room, &session.peer);
1950
1951 Ok(())
1952}
1953
1954/// Unshare a project from the room.
1955async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1956 let project_id = ProjectId::from_proto(message.project_id);
1957 unshare_project_internal(project_id, &session).await
1958}
1959
1960async fn unshare_project_internal(project_id: ProjectId, session: &Session) -> Result<()> {
1961 let (room, guest_connection_ids) = &*session
1962 .db()
1963 .await
1964 .unshare_project(project_id, session.connection_id)
1965 .await?;
1966
1967 let message = proto::UnshareProject {
1968 project_id: project_id.to_proto(),
1969 };
1970
1971 broadcast(
1972 Some(session.connection_id),
1973 guest_connection_ids.iter().copied(),
1974 |conn_id| session.peer.send(conn_id, message.clone()),
1975 );
1976 if let Some(room) = room {
1977 room_updated(room, &session.peer);
1978 }
1979
1980 Ok(())
1981}
1982
1983/// Share a project into the room.
1984async fn share_remote_project(
1985 request: proto::ShareRemoteProject,
1986 response: Response<proto::ShareRemoteProject>,
1987 session: DevServerSession,
1988) -> Result<()> {
1989 let remote_project = session
1990 .db()
1991 .await
1992 .share_remote_project(
1993 RemoteProjectId::from_proto(request.remote_project_id),
1994 session.dev_server_id(),
1995 session.connection_id,
1996 &request.worktrees,
1997 )
1998 .await?;
1999 let Some(project_id) = remote_project.project_id else {
2000 return Err(anyhow!("failed to share remote project"))?;
2001 };
2002
2003 for (connection_id, _) in session
2004 .connection_pool()
2005 .await
2006 .channel_connection_ids(ChannelId::from_proto(remote_project.channel_id))
2007 {
2008 session
2009 .peer
2010 .send(
2011 connection_id,
2012 proto::UpdateChannels {
2013 remote_projects: vec![remote_project.clone()],
2014 ..Default::default()
2015 },
2016 )
2017 .trace_err();
2018 }
2019
2020 response.send(proto::ShareProjectResponse { project_id })?;
2021
2022 Ok(())
2023}
2024
/// Join someone else's shared project.
///
/// Registers the caller as a collaborator through the database, then hands the
/// resulting project and replica id to `join_project_internal`, which answers the
/// request.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: UserSession,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    // Release the database handle explicitly before `session` is moved below.
    drop(db);
    // NOTE(review): this message says "remote" although nothing here checks for a
    // remote project; it may be a copy-paste leftover. Confirm the intended wording.
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
2043
2044trait JoinProjectInternalResponse {
2045 fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
2046}
2047impl JoinProjectInternalResponse for Response<proto::JoinProject> {
2048 fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
2049 Response::<proto::JoinProject>::send(self, result)
2050 }
2051}
2052impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
2053 fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
2054 Response::<proto::JoinHostedProject>::send(self, result)
2055 }
2056}
2057
2058fn join_project_internal(
2059 response: impl JoinProjectInternalResponse,
2060 session: UserSession,
2061 project: &mut Project,
2062 replica_id: &ReplicaId,
2063) -> Result<()> {
2064 let collaborators = project
2065 .collaborators
2066 .iter()
2067 .filter(|collaborator| collaborator.connection_id != session.connection_id)
2068 .map(|collaborator| collaborator.to_proto())
2069 .collect::<Vec<_>>();
2070 let project_id = project.id;
2071 let guest_user_id = session.user_id();
2072
2073 let worktrees = project
2074 .worktrees
2075 .iter()
2076 .map(|(id, worktree)| proto::WorktreeMetadata {
2077 id: *id,
2078 root_name: worktree.root_name.clone(),
2079 visible: worktree.visible,
2080 abs_path: worktree.abs_path.clone(),
2081 })
2082 .collect::<Vec<_>>();
2083
2084 for collaborator in &collaborators {
2085 session
2086 .peer
2087 .send(
2088 collaborator.peer_id.unwrap().into(),
2089 proto::AddProjectCollaborator {
2090 project_id: project_id.to_proto(),
2091 collaborator: Some(proto::Collaborator {
2092 peer_id: Some(session.connection_id.into()),
2093 replica_id: replica_id.0 as u32,
2094 user_id: guest_user_id.to_proto(),
2095 }),
2096 },
2097 )
2098 .trace_err();
2099 }
2100
2101 // First, we send the metadata associated with each worktree.
2102 response.send(proto::JoinProjectResponse {
2103 project_id: project.id.0 as u64,
2104 worktrees: worktrees.clone(),
2105 replica_id: replica_id.0 as u32,
2106 collaborators: collaborators.clone(),
2107 language_servers: project.language_servers.clone(),
2108 role: project.role.into(), // todo
2109 })?;
2110
2111 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2112 #[cfg(any(test, feature = "test-support"))]
2113 const MAX_CHUNK_SIZE: usize = 2;
2114 #[cfg(not(any(test, feature = "test-support")))]
2115 const MAX_CHUNK_SIZE: usize = 256;
2116
2117 // Stream this worktree's entries.
2118 let message = proto::UpdateWorktree {
2119 project_id: project_id.to_proto(),
2120 worktree_id,
2121 abs_path: worktree.abs_path.clone(),
2122 root_name: worktree.root_name,
2123 updated_entries: worktree.entries,
2124 removed_entries: Default::default(),
2125 scan_id: worktree.scan_id,
2126 is_last_update: worktree.scan_id == worktree.completed_scan_id,
2127 updated_repositories: worktree.repository_entries.into_values().collect(),
2128 removed_repositories: Default::default(),
2129 };
2130 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
2131 session.peer.send(session.connection_id, update.clone())?;
2132 }
2133
2134 // Stream this worktree's diagnostics.
2135 for summary in worktree.diagnostic_summaries {
2136 session.peer.send(
2137 session.connection_id,
2138 proto::UpdateDiagnosticSummary {
2139 project_id: project_id.to_proto(),
2140 worktree_id: worktree.id,
2141 summary: Some(summary),
2142 },
2143 )?;
2144 }
2145
2146 for settings_file in worktree.settings_files {
2147 session.peer.send(
2148 session.connection_id,
2149 proto::UpdateWorktreeSettings {
2150 project_id: project_id.to_proto(),
2151 worktree_id: worktree.id,
2152 path: settings_file.path,
2153 content: Some(settings_file.content),
2154 },
2155 )?;
2156 }
2157 }
2158
2159 for language_server in &project.language_servers {
2160 session.peer.send(
2161 session.connection_id,
2162 proto::UpdateLanguageServer {
2163 project_id: project_id.to_proto(),
2164 language_server_id: language_server.id,
2165 variant: Some(
2166 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2167 proto::LspDiskBasedDiagnosticsUpdated {},
2168 ),
2169 ),
2170 },
2171 )?;
2172 }
2173
2174 Ok(())
2175}
2176
2177/// Leave someone elses shared project.
2178async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2179 let sender_id = session.connection_id;
2180 let project_id = ProjectId::from_proto(request.project_id);
2181 let db = session.db().await;
2182 if db.is_hosted_project(project_id).await? {
2183 let project = db.leave_hosted_project(project_id, sender_id).await?;
2184 project_left(&project, &session);
2185 return Ok(());
2186 }
2187
2188 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2189 tracing::info!(
2190 %project_id,
2191 host_user_id = ?project.host_user_id,
2192 host_connection_id = ?project.host_connection_id,
2193 "leave project"
2194 );
2195
2196 project_left(&project, &session);
2197 if let Some(room) = room {
2198 room_updated(&room, &session.peer);
2199 }
2200
2201 Ok(())
2202}
2203
2204async fn join_hosted_project(
2205 request: proto::JoinHostedProject,
2206 response: Response<proto::JoinHostedProject>,
2207 session: UserSession,
2208) -> Result<()> {
2209 let (mut project, replica_id) = session
2210 .db()
2211 .await
2212 .join_hosted_project(
2213 ProjectId(request.project_id as i32),
2214 session.user_id(),
2215 session.connection_id,
2216 )
2217 .await?;
2218
2219 join_project_internal(response, session, &mut project, &replica_id)
2220}
2221
2222async fn create_remote_project(
2223 request: proto::CreateRemoteProject,
2224 response: Response<proto::CreateRemoteProject>,
2225 session: UserSession,
2226) -> Result<()> {
2227 let (channel, remote_project) = session
2228 .db()
2229 .await
2230 .create_remote_project(
2231 ChannelId(request.channel_id as i32),
2232 DevServerId(request.dev_server_id as i32),
2233 &request.name,
2234 &request.path,
2235 session.user_id(),
2236 )
2237 .await?;
2238
2239 let projects = session
2240 .db()
2241 .await
2242 .get_remote_projects_for_dev_server(remote_project.dev_server_id)
2243 .await?;
2244
2245 let update = proto::UpdateChannels {
2246 remote_projects: vec![remote_project.to_proto(None)],
2247 ..Default::default()
2248 };
2249 let connection_pool = session.connection_pool().await;
2250 for (connection_id, role) in connection_pool.channel_connection_ids(channel.root_id()) {
2251 if role.can_see_all_descendants() {
2252 session.peer.send(connection_id, update.clone())?;
2253 }
2254 }
2255
2256 let dev_server_id = remote_project.dev_server_id;
2257 let dev_server_connection_id = connection_pool.dev_server_connection_id(dev_server_id);
2258 if let Some(dev_server_connection_id) = dev_server_connection_id {
2259 session.peer.send(
2260 dev_server_connection_id,
2261 proto::DevServerInstructions { projects },
2262 )?;
2263 }
2264
2265 response.send(proto::CreateRemoteProjectResponse {
2266 remote_project: Some(remote_project.to_proto(None)),
2267 })?;
2268 Ok(())
2269}
2270
2271async fn create_dev_server(
2272 request: proto::CreateDevServer,
2273 response: Response<proto::CreateDevServer>,
2274 session: UserSession,
2275) -> Result<()> {
2276 let access_token = auth::random_token();
2277 let hashed_access_token = auth::hash_access_token(&access_token);
2278
2279 let (channel, dev_server) = session
2280 .db()
2281 .await
2282 .create_dev_server(
2283 ChannelId(request.channel_id as i32),
2284 &request.name,
2285 &hashed_access_token,
2286 session.user_id(),
2287 )
2288 .await?;
2289
2290 let update = proto::UpdateChannels {
2291 dev_servers: vec![dev_server.to_proto(proto::DevServerStatus::Offline)],
2292 ..Default::default()
2293 };
2294 let connection_pool = session.connection_pool().await;
2295 for (connection_id, role) in connection_pool.channel_connection_ids(channel.root_id()) {
2296 if role.can_see_channel(channel.visibility) {
2297 session.peer.send(connection_id, update.clone())?;
2298 }
2299 }
2300
2301 response.send(proto::CreateDevServerResponse {
2302 dev_server_id: dev_server.id.0 as u64,
2303 channel_id: request.channel_id,
2304 access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2305 name: request.name.clone(),
2306 })?;
2307 Ok(())
2308}
2309
2310async fn rejoin_remote_projects(
2311 request: proto::RejoinRemoteProjects,
2312 response: Response<proto::RejoinRemoteProjects>,
2313 session: UserSession,
2314) -> Result<()> {
2315 let mut rejoined_projects = {
2316 let db = session.db().await;
2317 db.rejoin_remote_projects(
2318 &request.rejoined_projects,
2319 session.user_id(),
2320 session.0.connection_id,
2321 )
2322 .await?
2323 };
2324 notify_rejoined_projects(&mut rejoined_projects, &session)?;
2325
2326 response.send(proto::RejoinRemoteProjectsResponse {
2327 rejoined_projects: rejoined_projects
2328 .into_iter()
2329 .map(|project| project.to_proto())
2330 .collect(),
2331 })
2332}
2333
2334async fn reconnect_dev_server(
2335 request: proto::ReconnectDevServer,
2336 response: Response<proto::ReconnectDevServer>,
2337 session: DevServerSession,
2338) -> Result<()> {
2339 let reshared_projects = {
2340 let db = session.db().await;
2341 db.reshare_remote_projects(
2342 &request.reshared_projects,
2343 session.dev_server_id(),
2344 session.0.connection_id,
2345 )
2346 .await?
2347 };
2348
2349 for project in &reshared_projects {
2350 for collaborator in &project.collaborators {
2351 session
2352 .peer
2353 .send(
2354 collaborator.connection_id,
2355 proto::UpdateProjectCollaborator {
2356 project_id: project.id.to_proto(),
2357 old_peer_id: Some(project.old_connection_id.into()),
2358 new_peer_id: Some(session.connection_id.into()),
2359 },
2360 )
2361 .trace_err();
2362 }
2363
2364 broadcast(
2365 Some(session.connection_id),
2366 project
2367 .collaborators
2368 .iter()
2369 .map(|collaborator| collaborator.connection_id),
2370 |connection_id| {
2371 session.peer.forward_send(
2372 session.connection_id,
2373 connection_id,
2374 proto::UpdateProject {
2375 project_id: project.id.to_proto(),
2376 worktrees: project.worktrees.clone(),
2377 },
2378 )
2379 },
2380 );
2381 }
2382
2383 response.send(proto::ReconnectDevServerResponse {
2384 reshared_projects: reshared_projects
2385 .iter()
2386 .map(|project| proto::ResharedProject {
2387 id: project.id.to_proto(),
2388 collaborators: project
2389 .collaborators
2390 .iter()
2391 .map(|collaborator| collaborator.to_proto())
2392 .collect(),
2393 })
2394 .collect(),
2395 })?;
2396
2397 Ok(())
2398}
2399
2400async fn shutdown_dev_server(
2401 _: proto::ShutdownDevServer,
2402 response: Response<proto::ShutdownDevServer>,
2403 session: DevServerSession,
2404) -> Result<()> {
2405 response.send(proto::Ack {})?;
2406 let (remote_projects, dev_server) = {
2407 let dev_server_id = session.dev_server_id();
2408 let db = session.db().await;
2409 let remote_projects = db.get_remote_projects_for_dev_server(dev_server_id).await?;
2410 let dev_server = db.get_dev_server(dev_server_id).await?;
2411 (remote_projects, dev_server)
2412 };
2413
2414 for project_id in remote_projects.iter().filter_map(|p| p.project_id) {
2415 unshare_project_internal(ProjectId::from_proto(project_id), &session.0).await?;
2416 }
2417
2418 let update = proto::UpdateChannels {
2419 remote_projects,
2420 dev_servers: vec![dev_server.to_proto(proto::DevServerStatus::Offline)],
2421 ..Default::default()
2422 };
2423
2424 for (connection_id, _) in session
2425 .connection_pool()
2426 .await
2427 .channel_connection_ids(dev_server.channel_id)
2428 {
2429 session.peer.send(connection_id, update.clone()).trace_err();
2430 }
2431
2432 Ok(())
2433}
2434
2435/// Updates other participants with changes to the project
2436async fn update_project(
2437 request: proto::UpdateProject,
2438 response: Response<proto::UpdateProject>,
2439 session: Session,
2440) -> Result<()> {
2441 let project_id = ProjectId::from_proto(request.project_id);
2442 let (room, guest_connection_ids) = &*session
2443 .db()
2444 .await
2445 .update_project(project_id, session.connection_id, &request.worktrees)
2446 .await?;
2447 broadcast(
2448 Some(session.connection_id),
2449 guest_connection_ids.iter().copied(),
2450 |connection_id| {
2451 session
2452 .peer
2453 .forward_send(session.connection_id, connection_id, request.clone())
2454 },
2455 );
2456 if let Some(room) = room {
2457 room_updated(&room, &session.peer);
2458 }
2459 response.send(proto::Ack {})?;
2460
2461 Ok(())
2462}
2463
2464/// Updates other participants with changes to the worktree
2465async fn update_worktree(
2466 request: proto::UpdateWorktree,
2467 response: Response<proto::UpdateWorktree>,
2468 session: Session,
2469) -> Result<()> {
2470 let guest_connection_ids = session
2471 .db()
2472 .await
2473 .update_worktree(&request, session.connection_id)
2474 .await?;
2475
2476 broadcast(
2477 Some(session.connection_id),
2478 guest_connection_ids.iter().copied(),
2479 |connection_id| {
2480 session
2481 .peer
2482 .forward_send(session.connection_id, connection_id, request.clone())
2483 },
2484 );
2485 response.send(proto::Ack {})?;
2486 Ok(())
2487}
2488
2489/// Updates other participants with changes to the diagnostics
2490async fn update_diagnostic_summary(
2491 message: proto::UpdateDiagnosticSummary,
2492 session: Session,
2493) -> Result<()> {
2494 let guest_connection_ids = session
2495 .db()
2496 .await
2497 .update_diagnostic_summary(&message, session.connection_id)
2498 .await?;
2499
2500 broadcast(
2501 Some(session.connection_id),
2502 guest_connection_ids.iter().copied(),
2503 |connection_id| {
2504 session
2505 .peer
2506 .forward_send(session.connection_id, connection_id, message.clone())
2507 },
2508 );
2509
2510 Ok(())
2511}
2512
2513/// Updates other participants with changes to the worktree settings
2514async fn update_worktree_settings(
2515 message: proto::UpdateWorktreeSettings,
2516 session: Session,
2517) -> Result<()> {
2518 let guest_connection_ids = session
2519 .db()
2520 .await
2521 .update_worktree_settings(&message, session.connection_id)
2522 .await?;
2523
2524 broadcast(
2525 Some(session.connection_id),
2526 guest_connection_ids.iter().copied(),
2527 |connection_id| {
2528 session
2529 .peer
2530 .forward_send(session.connection_id, connection_id, message.clone())
2531 },
2532 );
2533
2534 Ok(())
2535}
2536
2537/// Notify other participants that a language server has started.
2538async fn start_language_server(
2539 request: proto::StartLanguageServer,
2540 session: Session,
2541) -> Result<()> {
2542 let guest_connection_ids = session
2543 .db()
2544 .await
2545 .start_language_server(&request, session.connection_id)
2546 .await?;
2547
2548 broadcast(
2549 Some(session.connection_id),
2550 guest_connection_ids.iter().copied(),
2551 |connection_id| {
2552 session
2553 .peer
2554 .forward_send(session.connection_id, connection_id, request.clone())
2555 },
2556 );
2557 Ok(())
2558}
2559
2560/// Notify other participants that a language server has changed.
2561async fn update_language_server(
2562 request: proto::UpdateLanguageServer,
2563 session: Session,
2564) -> Result<()> {
2565 let project_id = ProjectId::from_proto(request.project_id);
2566 let project_connection_ids = session
2567 .db()
2568 .await
2569 .project_connection_ids(project_id, session.connection_id, true)
2570 .await?;
2571 broadcast(
2572 Some(session.connection_id),
2573 project_connection_ids.iter().copied(),
2574 |connection_id| {
2575 session
2576 .peer
2577 .forward_send(session.connection_id, connection_id, request.clone())
2578 },
2579 );
2580 Ok(())
2581}
2582
2583/// forward a project request to the host. These requests should be read only
2584/// as guests are allowed to send them.
2585async fn forward_read_only_project_request<T>(
2586 request: T,
2587 response: Response<T>,
2588 session: UserSession,
2589) -> Result<()>
2590where
2591 T: EntityMessage + RequestMessage,
2592{
2593 let project_id = ProjectId::from_proto(request.remote_entity_id());
2594 let host_connection_id = session
2595 .db()
2596 .await
2597 .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2598 .await?;
2599 let payload = session
2600 .peer
2601 .forward_request(session.connection_id, host_connection_id, request)
2602 .await?;
2603 response.send(payload)?;
2604 Ok(())
2605}
2606
2607/// forward a project request to the host. These requests are disallowed
2608/// for guests.
2609async fn forward_mutating_project_request<T>(
2610 request: T,
2611 response: Response<T>,
2612 session: UserSession,
2613) -> Result<()>
2614where
2615 T: EntityMessage + RequestMessage,
2616{
2617 let project_id = ProjectId::from_proto(request.remote_entity_id());
2618 let host_connection_id = session
2619 .db()
2620 .await
2621 .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
2622 .await?;
2623 let payload = session
2624 .peer
2625 .forward_request(session.connection_id, host_connection_id, request)
2626 .await?;
2627 response.send(payload)?;
2628 Ok(())
2629}
2630
2631/// Notify other participants that a new buffer has been created
2632async fn create_buffer_for_peer(
2633 request: proto::CreateBufferForPeer,
2634 session: Session,
2635) -> Result<()> {
2636 session
2637 .db()
2638 .await
2639 .check_user_is_project_host(
2640 ProjectId::from_proto(request.project_id),
2641 session.connection_id,
2642 )
2643 .await?;
2644 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2645 session
2646 .peer
2647 .forward_send(session.connection_id, peer_id.into(), request)?;
2648 Ok(())
2649}
2650
2651/// Notify other participants that a buffer has been updated. This is
2652/// allowed for guests as long as the update is limited to selections.
2653async fn update_buffer(
2654 request: proto::UpdateBuffer,
2655 response: Response<proto::UpdateBuffer>,
2656 session: Session,
2657) -> Result<()> {
2658 let project_id = ProjectId::from_proto(request.project_id);
2659 let mut capability = Capability::ReadOnly;
2660
2661 for op in request.operations.iter() {
2662 match op.variant {
2663 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2664 Some(_) => capability = Capability::ReadWrite,
2665 }
2666 }
2667
2668 let host = {
2669 let guard = session
2670 .db()
2671 .await
2672 .connections_for_buffer_update(
2673 project_id,
2674 session.principal_id(),
2675 session.connection_id,
2676 capability,
2677 )
2678 .await?;
2679
2680 let (host, guests) = &*guard;
2681
2682 broadcast(
2683 Some(session.connection_id),
2684 guests.clone(),
2685 |connection_id| {
2686 session
2687 .peer
2688 .forward_send(session.connection_id, connection_id, request.clone())
2689 },
2690 );
2691
2692 *host
2693 };
2694
2695 if host != session.connection_id {
2696 session
2697 .peer
2698 .forward_request(session.connection_id, host, request.clone())
2699 .await?;
2700 }
2701
2702 response.send(proto::Ack {})?;
2703 Ok(())
2704}
2705
2706/// Notify other participants that a project has been updated.
2707async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2708 request: T,
2709 session: Session,
2710) -> Result<()> {
2711 let project_id = ProjectId::from_proto(request.remote_entity_id());
2712 let project_connection_ids = session
2713 .db()
2714 .await
2715 .project_connection_ids(project_id, session.connection_id, false)
2716 .await?;
2717
2718 broadcast(
2719 Some(session.connection_id),
2720 project_connection_ids.iter().copied(),
2721 |connection_id| {
2722 session
2723 .peer
2724 .forward_send(session.connection_id, connection_id, request.clone())
2725 },
2726 );
2727 Ok(())
2728}
2729
2730/// Start following another user in a call.
2731async fn follow(
2732 request: proto::Follow,
2733 response: Response<proto::Follow>,
2734 session: UserSession,
2735) -> Result<()> {
2736 let room_id = RoomId::from_proto(request.room_id);
2737 let project_id = request.project_id.map(ProjectId::from_proto);
2738 let leader_id = request
2739 .leader_id
2740 .ok_or_else(|| anyhow!("invalid leader id"))?
2741 .into();
2742 let follower_id = session.connection_id;
2743
2744 session
2745 .db()
2746 .await
2747 .check_room_participants(room_id, leader_id, session.connection_id)
2748 .await?;
2749
2750 let response_payload = session
2751 .peer
2752 .forward_request(session.connection_id, leader_id, request)
2753 .await?;
2754 response.send(response_payload)?;
2755
2756 if let Some(project_id) = project_id {
2757 let room = session
2758 .db()
2759 .await
2760 .follow(room_id, project_id, leader_id, follower_id)
2761 .await?;
2762 room_updated(&room, &session.peer);
2763 }
2764
2765 Ok(())
2766}
2767
2768/// Stop following another user in a call.
2769async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
2770 let room_id = RoomId::from_proto(request.room_id);
2771 let project_id = request.project_id.map(ProjectId::from_proto);
2772 let leader_id = request
2773 .leader_id
2774 .ok_or_else(|| anyhow!("invalid leader id"))?
2775 .into();
2776 let follower_id = session.connection_id;
2777
2778 session
2779 .db()
2780 .await
2781 .check_room_participants(room_id, leader_id, session.connection_id)
2782 .await?;
2783
2784 session
2785 .peer
2786 .forward_send(session.connection_id, leader_id, request)?;
2787
2788 if let Some(project_id) = project_id {
2789 let room = session
2790 .db()
2791 .await
2792 .unfollow(room_id, project_id, leader_id, follower_id)
2793 .await?;
2794 room_updated(&room, &session.peer);
2795 }
2796
2797 Ok(())
2798}
2799
2800/// Notify everyone following you of your current location.
2801async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
2802 let room_id = RoomId::from_proto(request.room_id);
2803 let database = session.db.lock().await;
2804
2805 let connection_ids = if let Some(project_id) = request.project_id {
2806 let project_id = ProjectId::from_proto(project_id);
2807 database
2808 .project_connection_ids(project_id, session.connection_id, true)
2809 .await?
2810 } else {
2811 database
2812 .room_connection_ids(room_id, session.connection_id)
2813 .await?
2814 };
2815
2816 // For now, don't send view update messages back to that view's current leader.
2817 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2818 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2819 _ => None,
2820 });
2821
2822 for connection_id in connection_ids.iter().cloned() {
2823 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2824 session
2825 .peer
2826 .forward_send(session.connection_id, connection_id, request.clone())?;
2827 }
2828 }
2829 Ok(())
2830}
2831
2832/// Get public data about users.
2833async fn get_users(
2834 request: proto::GetUsers,
2835 response: Response<proto::GetUsers>,
2836 session: Session,
2837) -> Result<()> {
2838 let user_ids = request
2839 .user_ids
2840 .into_iter()
2841 .map(UserId::from_proto)
2842 .collect();
2843 let users = session
2844 .db()
2845 .await
2846 .get_users_by_ids(user_ids)
2847 .await?
2848 .into_iter()
2849 .map(|user| proto::User {
2850 id: user.id.to_proto(),
2851 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2852 github_login: user.github_login,
2853 })
2854 .collect();
2855 response.send(proto::UsersResponse { users })?;
2856 Ok(())
2857}
2858
2859/// Search for users (to invite) buy Github login
2860async fn fuzzy_search_users(
2861 request: proto::FuzzySearchUsers,
2862 response: Response<proto::FuzzySearchUsers>,
2863 session: UserSession,
2864) -> Result<()> {
2865 let query = request.query;
2866 let users = match query.len() {
2867 0 => vec![],
2868 1 | 2 => session
2869 .db()
2870 .await
2871 .get_user_by_github_login(&query)
2872 .await?
2873 .into_iter()
2874 .collect(),
2875 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2876 };
2877 let users = users
2878 .into_iter()
2879 .filter(|user| user.id != session.user_id())
2880 .map(|user| proto::User {
2881 id: user.id.to_proto(),
2882 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2883 github_login: user.github_login,
2884 })
2885 .collect();
2886 response.send(proto::UsersResponse { users })?;
2887 Ok(())
2888}
2889
2890/// Send a contact request to another user.
2891async fn request_contact(
2892 request: proto::RequestContact,
2893 response: Response<proto::RequestContact>,
2894 session: UserSession,
2895) -> Result<()> {
2896 let requester_id = session.user_id();
2897 let responder_id = UserId::from_proto(request.responder_id);
2898 if requester_id == responder_id {
2899 return Err(anyhow!("cannot add yourself as a contact"))?;
2900 }
2901
2902 let notifications = session
2903 .db()
2904 .await
2905 .send_contact_request(requester_id, responder_id)
2906 .await?;
2907
2908 // Update outgoing contact requests of requester
2909 let mut update = proto::UpdateContacts::default();
2910 update.outgoing_requests.push(responder_id.to_proto());
2911 for connection_id in session
2912 .connection_pool()
2913 .await
2914 .user_connection_ids(requester_id)
2915 {
2916 session.peer.send(connection_id, update.clone())?;
2917 }
2918
2919 // Update incoming contact requests of responder
2920 let mut update = proto::UpdateContacts::default();
2921 update
2922 .incoming_requests
2923 .push(proto::IncomingContactRequest {
2924 requester_id: requester_id.to_proto(),
2925 });
2926 let connection_pool = session.connection_pool().await;
2927 for connection_id in connection_pool.user_connection_ids(responder_id) {
2928 session.peer.send(connection_id, update.clone())?;
2929 }
2930
2931 send_notifications(&connection_pool, &session.peer, notifications);
2932
2933 response.send(proto::Ack {})?;
2934 Ok(())
2935}
2936
2937/// Accept or decline a contact request
2938async fn respond_to_contact_request(
2939 request: proto::RespondToContactRequest,
2940 response: Response<proto::RespondToContactRequest>,
2941 session: UserSession,
2942) -> Result<()> {
2943 let responder_id = session.user_id();
2944 let requester_id = UserId::from_proto(request.requester_id);
2945 let db = session.db().await;
2946 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2947 db.dismiss_contact_notification(responder_id, requester_id)
2948 .await?;
2949 } else {
2950 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2951
2952 let notifications = db
2953 .respond_to_contact_request(responder_id, requester_id, accept)
2954 .await?;
2955 let requester_busy = db.is_user_busy(requester_id).await?;
2956 let responder_busy = db.is_user_busy(responder_id).await?;
2957
2958 let pool = session.connection_pool().await;
2959 // Update responder with new contact
2960 let mut update = proto::UpdateContacts::default();
2961 if accept {
2962 update
2963 .contacts
2964 .push(contact_for_user(requester_id, requester_busy, &pool));
2965 }
2966 update
2967 .remove_incoming_requests
2968 .push(requester_id.to_proto());
2969 for connection_id in pool.user_connection_ids(responder_id) {
2970 session.peer.send(connection_id, update.clone())?;
2971 }
2972
2973 // Update requester with new contact
2974 let mut update = proto::UpdateContacts::default();
2975 if accept {
2976 update
2977 .contacts
2978 .push(contact_for_user(responder_id, responder_busy, &pool));
2979 }
2980 update
2981 .remove_outgoing_requests
2982 .push(responder_id.to_proto());
2983
2984 for connection_id in pool.user_connection_ids(requester_id) {
2985 session.peer.send(connection_id, update.clone())?;
2986 }
2987
2988 send_notifications(&pool, &session.peer, notifications);
2989 }
2990
2991 response.send(proto::Ack {})?;
2992 Ok(())
2993}
2994
2995/// Remove a contact.
2996async fn remove_contact(
2997 request: proto::RemoveContact,
2998 response: Response<proto::RemoveContact>,
2999 session: UserSession,
3000) -> Result<()> {
3001 let requester_id = session.user_id();
3002 let responder_id = UserId::from_proto(request.user_id);
3003 let db = session.db().await;
3004 let (contact_accepted, deleted_notification_id) =
3005 db.remove_contact(requester_id, responder_id).await?;
3006
3007 let pool = session.connection_pool().await;
3008 // Update outgoing contact requests of requester
3009 let mut update = proto::UpdateContacts::default();
3010 if contact_accepted {
3011 update.remove_contacts.push(responder_id.to_proto());
3012 } else {
3013 update
3014 .remove_outgoing_requests
3015 .push(responder_id.to_proto());
3016 }
3017 for connection_id in pool.user_connection_ids(requester_id) {
3018 session.peer.send(connection_id, update.clone())?;
3019 }
3020
3021 // Update incoming contact requests of responder
3022 let mut update = proto::UpdateContacts::default();
3023 if contact_accepted {
3024 update.remove_contacts.push(requester_id.to_proto());
3025 } else {
3026 update
3027 .remove_incoming_requests
3028 .push(requester_id.to_proto());
3029 }
3030 for connection_id in pool.user_connection_ids(responder_id) {
3031 session.peer.send(connection_id, update.clone())?;
3032 if let Some(notification_id) = deleted_notification_id {
3033 session.peer.send(
3034 connection_id,
3035 proto::DeleteNotification {
3036 notification_id: notification_id.to_proto(),
3037 },
3038 )?;
3039 }
3040 }
3041
3042 response.send(proto::Ack {})?;
3043 Ok(())
3044}
3045
3046/// Creates a new channel.
3047async fn create_channel(
3048 request: proto::CreateChannel,
3049 response: Response<proto::CreateChannel>,
3050 session: UserSession,
3051) -> Result<()> {
3052 let db = session.db().await;
3053
3054 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
3055 let (channel, membership) = db
3056 .create_channel(&request.name, parent_id, session.user_id())
3057 .await?;
3058
3059 let root_id = channel.root_id();
3060 let channel = Channel::from_model(channel);
3061
3062 response.send(proto::CreateChannelResponse {
3063 channel: Some(channel.to_proto()),
3064 parent_id: request.parent_id,
3065 })?;
3066
3067 let mut connection_pool = session.connection_pool().await;
3068 if let Some(membership) = membership {
3069 connection_pool.subscribe_to_channel(
3070 membership.user_id,
3071 membership.channel_id,
3072 membership.role,
3073 );
3074 let update = proto::UpdateUserChannels {
3075 channel_memberships: vec![proto::ChannelMembership {
3076 channel_id: membership.channel_id.to_proto(),
3077 role: membership.role.into(),
3078 }],
3079 ..Default::default()
3080 };
3081 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3082 session.peer.send(connection_id, update.clone())?;
3083 }
3084 }
3085
3086 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3087 if !role.can_see_channel(channel.visibility) {
3088 continue;
3089 }
3090
3091 let update = proto::UpdateChannels {
3092 channels: vec![channel.to_proto()],
3093 ..Default::default()
3094 };
3095 session.peer.send(connection_id, update.clone())?;
3096 }
3097
3098 Ok(())
3099}
3100
3101/// Delete a channel
3102async fn delete_channel(
3103 request: proto::DeleteChannel,
3104 response: Response<proto::DeleteChannel>,
3105 session: UserSession,
3106) -> Result<()> {
3107 let db = session.db().await;
3108
3109 let channel_id = request.channel_id;
3110 let (root_channel, removed_channels) = db
3111 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3112 .await?;
3113 response.send(proto::Ack {})?;
3114
3115 // Notify members of removed channels
3116 let mut update = proto::UpdateChannels::default();
3117 update
3118 .delete_channels
3119 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3120
3121 let connection_pool = session.connection_pool().await;
3122 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3123 session.peer.send(connection_id, update.clone())?;
3124 }
3125
3126 Ok(())
3127}
3128
3129/// Invite someone to join a channel.
3130async fn invite_channel_member(
3131 request: proto::InviteChannelMember,
3132 response: Response<proto::InviteChannelMember>,
3133 session: UserSession,
3134) -> Result<()> {
3135 let db = session.db().await;
3136 let channel_id = ChannelId::from_proto(request.channel_id);
3137 let invitee_id = UserId::from_proto(request.user_id);
3138 let InviteMemberResult {
3139 channel,
3140 notifications,
3141 } = db
3142 .invite_channel_member(
3143 channel_id,
3144 invitee_id,
3145 session.user_id(),
3146 request.role().into(),
3147 )
3148 .await?;
3149
3150 let update = proto::UpdateChannels {
3151 channel_invitations: vec![channel.to_proto()],
3152 ..Default::default()
3153 };
3154
3155 let connection_pool = session.connection_pool().await;
3156 for connection_id in connection_pool.user_connection_ids(invitee_id) {
3157 session.peer.send(connection_id, update.clone())?;
3158 }
3159
3160 send_notifications(&connection_pool, &session.peer, notifications);
3161
3162 response.send(proto::Ack {})?;
3163 Ok(())
3164}
3165
3166/// remove someone from a channel
3167async fn remove_channel_member(
3168 request: proto::RemoveChannelMember,
3169 response: Response<proto::RemoveChannelMember>,
3170 session: UserSession,
3171) -> Result<()> {
3172 let db = session.db().await;
3173 let channel_id = ChannelId::from_proto(request.channel_id);
3174 let member_id = UserId::from_proto(request.user_id);
3175
3176 let RemoveChannelMemberResult {
3177 membership_update,
3178 notification_id,
3179 } = db
3180 .remove_channel_member(channel_id, member_id, session.user_id())
3181 .await?;
3182
3183 let mut connection_pool = session.connection_pool().await;
3184 notify_membership_updated(
3185 &mut connection_pool,
3186 membership_update,
3187 member_id,
3188 &session.peer,
3189 );
3190 for connection_id in connection_pool.user_connection_ids(member_id) {
3191 if let Some(notification_id) = notification_id {
3192 session
3193 .peer
3194 .send(
3195 connection_id,
3196 proto::DeleteNotification {
3197 notification_id: notification_id.to_proto(),
3198 },
3199 )
3200 .trace_err();
3201 }
3202 }
3203
3204 response.send(proto::Ack {})?;
3205 Ok(())
3206}
3207
3208/// Toggle the channel between public and private.
3209/// Care is taken to maintain the invariant that public channels only descend from public channels,
3210/// (though members-only channels can appear at any point in the hierarchy).
3211async fn set_channel_visibility(
3212 request: proto::SetChannelVisibility,
3213 response: Response<proto::SetChannelVisibility>,
3214 session: UserSession,
3215) -> Result<()> {
3216 let db = session.db().await;
3217 let channel_id = ChannelId::from_proto(request.channel_id);
3218 let visibility = request.visibility().into();
3219
3220 let channel_model = db
3221 .set_channel_visibility(channel_id, visibility, session.user_id())
3222 .await?;
3223 let root_id = channel_model.root_id();
3224 let channel = Channel::from_model(channel_model);
3225
3226 let mut connection_pool = session.connection_pool().await;
3227 for (user_id, role) in connection_pool
3228 .channel_user_ids(root_id)
3229 .collect::<Vec<_>>()
3230 .into_iter()
3231 {
3232 let update = if role.can_see_channel(channel.visibility) {
3233 connection_pool.subscribe_to_channel(user_id, channel_id, role);
3234 proto::UpdateChannels {
3235 channels: vec![channel.to_proto()],
3236 ..Default::default()
3237 }
3238 } else {
3239 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3240 proto::UpdateChannels {
3241 delete_channels: vec![channel.id.to_proto()],
3242 ..Default::default()
3243 }
3244 };
3245
3246 for connection_id in connection_pool.user_connection_ids(user_id) {
3247 session.peer.send(connection_id, update.clone())?;
3248 }
3249 }
3250
3251 response.send(proto::Ack {})?;
3252 Ok(())
3253}
3254
3255/// Alter the role for a user in the channel.
3256async fn set_channel_member_role(
3257 request: proto::SetChannelMemberRole,
3258 response: Response<proto::SetChannelMemberRole>,
3259 session: UserSession,
3260) -> Result<()> {
3261 let db = session.db().await;
3262 let channel_id = ChannelId::from_proto(request.channel_id);
3263 let member_id = UserId::from_proto(request.user_id);
3264 let result = db
3265 .set_channel_member_role(
3266 channel_id,
3267 session.user_id(),
3268 member_id,
3269 request.role().into(),
3270 )
3271 .await?;
3272
3273 match result {
3274 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3275 let mut connection_pool = session.connection_pool().await;
3276 notify_membership_updated(
3277 &mut connection_pool,
3278 membership_update,
3279 member_id,
3280 &session.peer,
3281 )
3282 }
3283 db::SetMemberRoleResult::InviteUpdated(channel) => {
3284 let update = proto::UpdateChannels {
3285 channel_invitations: vec![channel.to_proto()],
3286 ..Default::default()
3287 };
3288
3289 for connection_id in session
3290 .connection_pool()
3291 .await
3292 .user_connection_ids(member_id)
3293 {
3294 session.peer.send(connection_id, update.clone())?;
3295 }
3296 }
3297 }
3298
3299 response.send(proto::Ack {})?;
3300 Ok(())
3301}
3302
3303/// Change the name of a channel
3304async fn rename_channel(
3305 request: proto::RenameChannel,
3306 response: Response<proto::RenameChannel>,
3307 session: UserSession,
3308) -> Result<()> {
3309 let db = session.db().await;
3310 let channel_id = ChannelId::from_proto(request.channel_id);
3311 let channel_model = db
3312 .rename_channel(channel_id, session.user_id(), &request.name)
3313 .await?;
3314 let root_id = channel_model.root_id();
3315 let channel = Channel::from_model(channel_model);
3316
3317 response.send(proto::RenameChannelResponse {
3318 channel: Some(channel.to_proto()),
3319 })?;
3320
3321 let connection_pool = session.connection_pool().await;
3322 let update = proto::UpdateChannels {
3323 channels: vec![channel.to_proto()],
3324 ..Default::default()
3325 };
3326 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3327 if role.can_see_channel(channel.visibility) {
3328 session.peer.send(connection_id, update.clone())?;
3329 }
3330 }
3331
3332 Ok(())
3333}
3334
3335/// Move a channel to a new parent.
3336async fn move_channel(
3337 request: proto::MoveChannel,
3338 response: Response<proto::MoveChannel>,
3339 session: UserSession,
3340) -> Result<()> {
3341 let channel_id = ChannelId::from_proto(request.channel_id);
3342 let to = ChannelId::from_proto(request.to);
3343
3344 let (root_id, channels) = session
3345 .db()
3346 .await
3347 .move_channel(channel_id, to, session.user_id())
3348 .await?;
3349
3350 let connection_pool = session.connection_pool().await;
3351 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3352 let channels = channels
3353 .iter()
3354 .filter_map(|channel| {
3355 if role.can_see_channel(channel.visibility) {
3356 Some(channel.to_proto())
3357 } else {
3358 None
3359 }
3360 })
3361 .collect::<Vec<_>>();
3362 if channels.is_empty() {
3363 continue;
3364 }
3365
3366 let update = proto::UpdateChannels {
3367 channels,
3368 ..Default::default()
3369 };
3370
3371 session.peer.send(connection_id, update.clone())?;
3372 }
3373
3374 response.send(Ack {})?;
3375 Ok(())
3376}
3377
3378/// Get the list of channel members
3379async fn get_channel_members(
3380 request: proto::GetChannelMembers,
3381 response: Response<proto::GetChannelMembers>,
3382 session: UserSession,
3383) -> Result<()> {
3384 let db = session.db().await;
3385 let channel_id = ChannelId::from_proto(request.channel_id);
3386 let members = db
3387 .get_channel_participant_details(channel_id, session.user_id())
3388 .await?;
3389 response.send(proto::GetChannelMembersResponse { members })?;
3390 Ok(())
3391}
3392
3393/// Accept or decline a channel invitation.
3394async fn respond_to_channel_invite(
3395 request: proto::RespondToChannelInvite,
3396 response: Response<proto::RespondToChannelInvite>,
3397 session: UserSession,
3398) -> Result<()> {
3399 let db = session.db().await;
3400 let channel_id = ChannelId::from_proto(request.channel_id);
3401 let RespondToChannelInvite {
3402 membership_update,
3403 notifications,
3404 } = db
3405 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3406 .await?;
3407
3408 let mut connection_pool = session.connection_pool().await;
3409 if let Some(membership_update) = membership_update {
3410 notify_membership_updated(
3411 &mut connection_pool,
3412 membership_update,
3413 session.user_id(),
3414 &session.peer,
3415 );
3416 } else {
3417 let update = proto::UpdateChannels {
3418 remove_channel_invitations: vec![channel_id.to_proto()],
3419 ..Default::default()
3420 };
3421
3422 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3423 session.peer.send(connection_id, update.clone())?;
3424 }
3425 };
3426
3427 send_notifications(&connection_pool, &session.peer, notifications);
3428
3429 response.send(proto::Ack {})?;
3430
3431 Ok(())
3432}
3433
3434/// Join the channels' room
3435async fn join_channel(
3436 request: proto::JoinChannel,
3437 response: Response<proto::JoinChannel>,
3438 session: UserSession,
3439) -> Result<()> {
3440 let channel_id = ChannelId::from_proto(request.channel_id);
3441 join_channel_internal(channel_id, Box::new(response), session).await
3442}
3443
3444trait JoinChannelInternalResponse {
3445 fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
3446}
3447impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
3448 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3449 Response::<proto::JoinChannel>::send(self, result)
3450 }
3451}
3452impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
3453 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3454 Response::<proto::JoinRoom>::send(self, result)
3455 }
3456}
3457
/// Shared implementation for joining a channel's room.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, leave that room first.
/// 2. Join the channel in the database, obtaining the joined room, an optional membership
///    update, and the user's role.
/// 3. If a LiveKit client is configured, create a token: a guest token (cannot publish) for
///    `ChannelRole::Guest`, a regular room token (can publish) otherwise. A token failure is
///    traced and results in no LiveKit connection info rather than an error.
/// 4. Send the `JoinRoomResponse`, forward the membership update (if any), and broadcast the
///    room and channel updates.
/// 5. Refresh the user's contacts.
///
/// Returns an error if any database call, the response send, or the contacts update fails,
/// or if the joined room carries no channel.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: UserSession,
) -> Result<()> {
    // The database handle and the connection pool guard acquired inside this block are
    // both released when the block ends.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database handle before leaving the room, then take it again.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `and_then` + `trace_err()?`: a token error is logged and yields `None`.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        // Reply to the requester before notifying anyone else.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // The connection pool is re-acquired here, only for the duration of this call.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3548
3549/// Start editing the channel notes
3550async fn join_channel_buffer(
3551 request: proto::JoinChannelBuffer,
3552 response: Response<proto::JoinChannelBuffer>,
3553 session: UserSession,
3554) -> Result<()> {
3555 let db = session.db().await;
3556 let channel_id = ChannelId::from_proto(request.channel_id);
3557
3558 let open_response = db
3559 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3560 .await?;
3561
3562 let collaborators = open_response.collaborators.clone();
3563 response.send(open_response)?;
3564
3565 let update = UpdateChannelBufferCollaborators {
3566 channel_id: channel_id.to_proto(),
3567 collaborators: collaborators.clone(),
3568 };
3569 channel_buffer_updated(
3570 session.connection_id,
3571 collaborators
3572 .iter()
3573 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3574 &update,
3575 &session.peer,
3576 );
3577
3578 Ok(())
3579}
3580
3581/// Edit the channel notes
3582async fn update_channel_buffer(
3583 request: proto::UpdateChannelBuffer,
3584 session: UserSession,
3585) -> Result<()> {
3586 let db = session.db().await;
3587 let channel_id = ChannelId::from_proto(request.channel_id);
3588
3589 let (collaborators, non_collaborators, epoch, version) = db
3590 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3591 .await?;
3592
3593 channel_buffer_updated(
3594 session.connection_id,
3595 collaborators,
3596 &proto::UpdateChannelBuffer {
3597 channel_id: channel_id.to_proto(),
3598 operations: request.operations,
3599 },
3600 &session.peer,
3601 );
3602
3603 let pool = &*session.connection_pool().await;
3604
3605 broadcast(
3606 None,
3607 non_collaborators
3608 .iter()
3609 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3610 |peer_id| {
3611 session.peer.send(
3612 peer_id,
3613 proto::UpdateChannels {
3614 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3615 channel_id: channel_id.to_proto(),
3616 epoch: epoch as u64,
3617 version: version.clone(),
3618 }],
3619 ..Default::default()
3620 },
3621 )
3622 },
3623 );
3624
3625 Ok(())
3626}
3627
3628/// Rejoin the channel notes after a connection blip
3629async fn rejoin_channel_buffers(
3630 request: proto::RejoinChannelBuffers,
3631 response: Response<proto::RejoinChannelBuffers>,
3632 session: UserSession,
3633) -> Result<()> {
3634 let db = session.db().await;
3635 let buffers = db
3636 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3637 .await?;
3638
3639 for rejoined_buffer in &buffers {
3640 let collaborators_to_notify = rejoined_buffer
3641 .buffer
3642 .collaborators
3643 .iter()
3644 .filter_map(|c| Some(c.peer_id?.into()));
3645 channel_buffer_updated(
3646 session.connection_id,
3647 collaborators_to_notify,
3648 &proto::UpdateChannelBufferCollaborators {
3649 channel_id: rejoined_buffer.buffer.channel_id,
3650 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3651 },
3652 &session.peer,
3653 );
3654 }
3655
3656 response.send(proto::RejoinChannelBuffersResponse {
3657 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3658 })?;
3659
3660 Ok(())
3661}
3662
3663/// Stop editing the channel notes
3664async fn leave_channel_buffer(
3665 request: proto::LeaveChannelBuffer,
3666 response: Response<proto::LeaveChannelBuffer>,
3667 session: UserSession,
3668) -> Result<()> {
3669 let db = session.db().await;
3670 let channel_id = ChannelId::from_proto(request.channel_id);
3671
3672 let left_buffer = db
3673 .leave_channel_buffer(channel_id, session.connection_id)
3674 .await?;
3675
3676 response.send(Ack {})?;
3677
3678 channel_buffer_updated(
3679 session.connection_id,
3680 left_buffer.connections,
3681 &proto::UpdateChannelBufferCollaborators {
3682 channel_id: channel_id.to_proto(),
3683 collaborators: left_buffer.collaborators,
3684 },
3685 &session.peer,
3686 );
3687
3688 Ok(())
3689}
3690
3691fn channel_buffer_updated<T: EnvelopedMessage>(
3692 sender_id: ConnectionId,
3693 collaborators: impl IntoIterator<Item = ConnectionId>,
3694 message: &T,
3695 peer: &Peer,
3696) {
3697 broadcast(Some(sender_id), collaborators, |peer_id| {
3698 peer.send(peer_id, message.clone())
3699 });
3700}
3701
3702fn send_notifications(
3703 connection_pool: &ConnectionPool,
3704 peer: &Peer,
3705 notifications: db::NotificationBatch,
3706) {
3707 for (user_id, notification) in notifications {
3708 for connection_id in connection_pool.user_connection_ids(user_id) {
3709 if let Err(error) = peer.send(
3710 connection_id,
3711 proto::AddNotification {
3712 notification: Some(notification.clone()),
3713 },
3714 ) {
3715 tracing::error!(
3716 "failed to send notification to {:?} {}",
3717 connection_id,
3718 error
3719 );
3720 }
3721 }
3722 }
3723}
3724
3725/// Send a message to the channel
3726async fn send_channel_message(
3727 request: proto::SendChannelMessage,
3728 response: Response<proto::SendChannelMessage>,
3729 session: UserSession,
3730) -> Result<()> {
3731 // Validate the message body.
3732 let body = request.body.trim().to_string();
3733 if body.len() > MAX_MESSAGE_LEN {
3734 return Err(anyhow!("message is too long"))?;
3735 }
3736 if body.is_empty() {
3737 return Err(anyhow!("message can't be blank"))?;
3738 }
3739
3740 // TODO: adjust mentions if body is trimmed
3741
3742 let timestamp = OffsetDateTime::now_utc();
3743 let nonce = request
3744 .nonce
3745 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3746
3747 let channel_id = ChannelId::from_proto(request.channel_id);
3748 let CreatedChannelMessage {
3749 message_id,
3750 participant_connection_ids,
3751 channel_members,
3752 notifications,
3753 } = session
3754 .db()
3755 .await
3756 .create_channel_message(
3757 channel_id,
3758 session.user_id(),
3759 &body,
3760 &request.mentions,
3761 timestamp,
3762 nonce.clone().into(),
3763 match request.reply_to_message_id {
3764 Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
3765 None => None,
3766 },
3767 )
3768 .await?;
3769
3770 let message = proto::ChannelMessage {
3771 sender_id: session.user_id().to_proto(),
3772 id: message_id.to_proto(),
3773 body,
3774 mentions: request.mentions,
3775 timestamp: timestamp.unix_timestamp() as u64,
3776 nonce: Some(nonce),
3777 reply_to_message_id: request.reply_to_message_id,
3778 edited_at: None,
3779 };
3780 broadcast(
3781 Some(session.connection_id),
3782 participant_connection_ids,
3783 |connection| {
3784 session.peer.send(
3785 connection,
3786 proto::ChannelMessageSent {
3787 channel_id: channel_id.to_proto(),
3788 message: Some(message.clone()),
3789 },
3790 )
3791 },
3792 );
3793 response.send(proto::SendChannelMessageResponse {
3794 message: Some(message),
3795 })?;
3796
3797 let pool = &*session.connection_pool().await;
3798 broadcast(
3799 None,
3800 channel_members
3801 .iter()
3802 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3803 |peer_id| {
3804 session.peer.send(
3805 peer_id,
3806 proto::UpdateChannels {
3807 latest_channel_message_ids: vec![proto::ChannelMessageId {
3808 channel_id: channel_id.to_proto(),
3809 message_id: message_id.to_proto(),
3810 }],
3811 ..Default::default()
3812 },
3813 )
3814 },
3815 );
3816 send_notifications(pool, &session.peer, notifications);
3817
3818 Ok(())
3819}
3820
3821/// Delete a channel message
3822async fn remove_channel_message(
3823 request: proto::RemoveChannelMessage,
3824 response: Response<proto::RemoveChannelMessage>,
3825 session: UserSession,
3826) -> Result<()> {
3827 let channel_id = ChannelId::from_proto(request.channel_id);
3828 let message_id = MessageId::from_proto(request.message_id);
3829 let (connection_ids, existing_notification_ids) = session
3830 .db()
3831 .await
3832 .remove_channel_message(channel_id, message_id, session.user_id())
3833 .await?;
3834
3835 broadcast(
3836 Some(session.connection_id),
3837 connection_ids,
3838 move |connection| {
3839 session.peer.send(connection, request.clone())?;
3840
3841 for notification_id in &existing_notification_ids {
3842 session.peer.send(
3843 connection,
3844 proto::DeleteNotification {
3845 notification_id: (*notification_id).to_proto(),
3846 },
3847 )?;
3848 }
3849
3850 Ok(())
3851 },
3852 );
3853 response.send(proto::Ack {})?;
3854 Ok(())
3855}
3856
3857async fn update_channel_message(
3858 request: proto::UpdateChannelMessage,
3859 response: Response<proto::UpdateChannelMessage>,
3860 session: UserSession,
3861) -> Result<()> {
3862 let channel_id = ChannelId::from_proto(request.channel_id);
3863 let message_id = MessageId::from_proto(request.message_id);
3864 let updated_at = OffsetDateTime::now_utc();
3865 let UpdatedChannelMessage {
3866 message_id,
3867 participant_connection_ids,
3868 notifications,
3869 reply_to_message_id,
3870 timestamp,
3871 deleted_mention_notification_ids,
3872 updated_mention_notifications,
3873 } = session
3874 .db()
3875 .await
3876 .update_channel_message(
3877 channel_id,
3878 message_id,
3879 session.user_id(),
3880 request.body.as_str(),
3881 &request.mentions,
3882 updated_at,
3883 )
3884 .await?;
3885
3886 let nonce = request
3887 .nonce
3888 .clone()
3889 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3890
3891 let message = proto::ChannelMessage {
3892 sender_id: session.user_id().to_proto(),
3893 id: message_id.to_proto(),
3894 body: request.body.clone(),
3895 mentions: request.mentions.clone(),
3896 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3897 nonce: Some(nonce),
3898 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3899 edited_at: Some(updated_at.unix_timestamp() as u64),
3900 };
3901
3902 response.send(proto::Ack {})?;
3903
3904 let pool = &*session.connection_pool().await;
3905 broadcast(
3906 Some(session.connection_id),
3907 participant_connection_ids,
3908 |connection| {
3909 session.peer.send(
3910 connection,
3911 proto::ChannelMessageUpdate {
3912 channel_id: channel_id.to_proto(),
3913 message: Some(message.clone()),
3914 },
3915 )?;
3916
3917 for notification_id in &deleted_mention_notification_ids {
3918 session.peer.send(
3919 connection,
3920 proto::DeleteNotification {
3921 notification_id: (*notification_id).to_proto(),
3922 },
3923 )?;
3924 }
3925
3926 for notification in &updated_mention_notifications {
3927 session.peer.send(
3928 connection,
3929 proto::UpdateNotification {
3930 notification: Some(notification.clone()),
3931 },
3932 )?;
3933 }
3934
3935 Ok(())
3936 },
3937 );
3938
3939 send_notifications(pool, &session.peer, notifications);
3940
3941 Ok(())
3942}
3943
3944/// Mark a channel message as read
3945async fn acknowledge_channel_message(
3946 request: proto::AckChannelMessage,
3947 session: UserSession,
3948) -> Result<()> {
3949 let channel_id = ChannelId::from_proto(request.channel_id);
3950 let message_id = MessageId::from_proto(request.message_id);
3951 let notifications = session
3952 .db()
3953 .await
3954 .observe_channel_message(channel_id, session.user_id(), message_id)
3955 .await?;
3956 send_notifications(
3957 &*session.connection_pool().await,
3958 &session.peer,
3959 notifications,
3960 );
3961 Ok(())
3962}
3963
3964/// Mark a buffer version as synced
3965async fn acknowledge_buffer_version(
3966 request: proto::AckBufferOperation,
3967 session: UserSession,
3968) -> Result<()> {
3969 let buffer_id = BufferId::from_proto(request.buffer_id);
3970 session
3971 .db()
3972 .await
3973 .observe_buffer_version(
3974 buffer_id,
3975 session.user_id(),
3976 request.epoch as i32,
3977 &request.version,
3978 )
3979 .await?;
3980 Ok(())
3981}
3982
3983struct CompleteWithLanguageModelRateLimit;
3984
3985impl RateLimit for CompleteWithLanguageModelRateLimit {
3986 fn capacity() -> usize {
3987 std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
3988 .ok()
3989 .and_then(|v| v.parse().ok())
3990 .unwrap_or(120) // Picked arbitrarily
3991 }
3992
3993 fn refill_duration() -> chrono::Duration {
3994 chrono::Duration::hours(1)
3995 }
3996
3997 fn db_name() -> &'static str {
3998 "complete-with-language-model"
3999 }
4000}
4001
4002async fn complete_with_language_model(
4003 request: proto::CompleteWithLanguageModel,
4004 response: StreamingResponse<proto::CompleteWithLanguageModel>,
4005 session: Session,
4006 open_ai_api_key: Option<Arc<str>>,
4007 google_ai_api_key: Option<Arc<str>>,
4008 anthropic_api_key: Option<Arc<str>>,
4009) -> Result<()> {
4010 let Some(session) = session.for_user() else {
4011 return Err(anyhow!("user not found"))?;
4012 };
4013 authorize_access_to_language_models(&session).await?;
4014 session
4015 .rate_limiter
4016 .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
4017 .await?;
4018
4019 if request.model.starts_with("gpt") {
4020 let api_key =
4021 open_ai_api_key.ok_or_else(|| anyhow!("no OpenAI API key configured on the server"))?;
4022 complete_with_open_ai(request, response, session, api_key).await?;
4023 } else if request.model.starts_with("gemini") {
4024 let api_key = google_ai_api_key
4025 .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
4026 complete_with_google_ai(request, response, session, api_key).await?;
4027 } else if request.model.starts_with("claude") {
4028 let api_key = anthropic_api_key
4029 .ok_or_else(|| anyhow!("no Anthropic AI API key configured on the server"))?;
4030 complete_with_anthropic(request, response, session, api_key).await?;
4031 }
4032
4033 Ok(())
4034}
4035
4036async fn complete_with_open_ai(
4037 request: proto::CompleteWithLanguageModel,
4038 response: StreamingResponse<proto::CompleteWithLanguageModel>,
4039 session: UserSession,
4040 api_key: Arc<str>,
4041) -> Result<()> {
4042 let mut completion_stream = open_ai::stream_completion(
4043 &session.http_client,
4044 OPEN_AI_API_URL,
4045 &api_key,
4046 crate::ai::language_model_request_to_open_ai(request)?,
4047 )
4048 .await
4049 .context("open_ai::stream_completion request failed")?;
4050
4051 while let Some(event) = completion_stream.next().await {
4052 let event = event?;
4053 response.send(proto::LanguageModelResponse {
4054 choices: event
4055 .choices
4056 .into_iter()
4057 .map(|choice| proto::LanguageModelChoiceDelta {
4058 index: choice.index,
4059 delta: Some(proto::LanguageModelResponseMessage {
4060 role: choice.delta.role.map(|role| match role {
4061 open_ai::Role::User => LanguageModelRole::LanguageModelUser,
4062 open_ai::Role::Assistant => LanguageModelRole::LanguageModelAssistant,
4063 open_ai::Role::System => LanguageModelRole::LanguageModelSystem,
4064 } as i32),
4065 content: choice.delta.content,
4066 }),
4067 finish_reason: choice.finish_reason,
4068 })
4069 .collect(),
4070 })?;
4071 }
4072
4073 Ok(())
4074}
4075
4076async fn complete_with_google_ai(
4077 request: proto::CompleteWithLanguageModel,
4078 response: StreamingResponse<proto::CompleteWithLanguageModel>,
4079 session: UserSession,
4080 api_key: Arc<str>,
4081) -> Result<()> {
4082 let mut stream = google_ai::stream_generate_content(
4083 &session.http_client,
4084 google_ai::API_URL,
4085 api_key.as_ref(),
4086 crate::ai::language_model_request_to_google_ai(request)?,
4087 )
4088 .await
4089 .context("google_ai::stream_generate_content request failed")?;
4090
4091 while let Some(event) = stream.next().await {
4092 let event = event?;
4093 response.send(proto::LanguageModelResponse {
4094 choices: event
4095 .candidates
4096 .unwrap_or_default()
4097 .into_iter()
4098 .map(|candidate| proto::LanguageModelChoiceDelta {
4099 index: candidate.index as u32,
4100 delta: Some(proto::LanguageModelResponseMessage {
4101 role: Some(match candidate.content.role {
4102 google_ai::Role::User => LanguageModelRole::LanguageModelUser,
4103 google_ai::Role::Model => LanguageModelRole::LanguageModelAssistant,
4104 } as i32),
4105 content: Some(
4106 candidate
4107 .content
4108 .parts
4109 .into_iter()
4110 .filter_map(|part| match part {
4111 google_ai::Part::TextPart(part) => Some(part.text),
4112 google_ai::Part::InlineDataPart(_) => None,
4113 })
4114 .collect(),
4115 ),
4116 }),
4117 finish_reason: candidate.finish_reason.map(|reason| reason.to_string()),
4118 })
4119 .collect(),
4120 })?;
4121 }
4122
4123 Ok(())
4124}
4125
4126async fn complete_with_anthropic(
4127 request: proto::CompleteWithLanguageModel,
4128 response: StreamingResponse<proto::CompleteWithLanguageModel>,
4129 session: UserSession,
4130 api_key: Arc<str>,
4131) -> Result<()> {
4132 let model = anthropic::Model::from_id(&request.model)?;
4133
4134 let mut system_message = String::new();
4135 let messages = request
4136 .messages
4137 .into_iter()
4138 .filter_map(|message| match message.role() {
4139 LanguageModelRole::LanguageModelUser => Some(anthropic::RequestMessage {
4140 role: anthropic::Role::User,
4141 content: message.content,
4142 }),
4143 LanguageModelRole::LanguageModelAssistant => Some(anthropic::RequestMessage {
4144 role: anthropic::Role::Assistant,
4145 content: message.content,
4146 }),
4147 // Anthropic's API breaks system instructions out as a separate field rather
4148 // than having a system message role.
4149 LanguageModelRole::LanguageModelSystem => {
4150 if !system_message.is_empty() {
4151 system_message.push_str("\n\n");
4152 }
4153 system_message.push_str(&message.content);
4154
4155 None
4156 }
4157 })
4158 .collect();
4159
4160 let mut stream = anthropic::stream_completion(
4161 &session.http_client,
4162 "https://api.anthropic.com",
4163 &api_key,
4164 anthropic::Request {
4165 model,
4166 messages,
4167 stream: true,
4168 system: system_message,
4169 max_tokens: 4092,
4170 },
4171 )
4172 .await?;
4173
4174 let mut current_role = proto::LanguageModelRole::LanguageModelAssistant;
4175
4176 while let Some(event) = stream.next().await {
4177 let event = event?;
4178
4179 match event {
4180 anthropic::ResponseEvent::MessageStart { message } => {
4181 if let Some(role) = message.role {
4182 if role == "assistant" {
4183 current_role = proto::LanguageModelRole::LanguageModelAssistant;
4184 } else if role == "user" {
4185 current_role = proto::LanguageModelRole::LanguageModelUser;
4186 }
4187 }
4188 }
4189 anthropic::ResponseEvent::ContentBlockStart { content_block, .. } => {
4190 match content_block {
4191 anthropic::ContentBlock::Text { text } => {
4192 if !text.is_empty() {
4193 response.send(proto::LanguageModelResponse {
4194 choices: vec![proto::LanguageModelChoiceDelta {
4195 index: 0,
4196 delta: Some(proto::LanguageModelResponseMessage {
4197 role: Some(current_role as i32),
4198 content: Some(text),
4199 }),
4200 finish_reason: None,
4201 }],
4202 })?;
4203 }
4204 }
4205 }
4206 }
4207 anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => match delta {
4208 anthropic::TextDelta::TextDelta { text } => {
4209 response.send(proto::LanguageModelResponse {
4210 choices: vec![proto::LanguageModelChoiceDelta {
4211 index: 0,
4212 delta: Some(proto::LanguageModelResponseMessage {
4213 role: Some(current_role as i32),
4214 content: Some(text),
4215 }),
4216 finish_reason: None,
4217 }],
4218 })?;
4219 }
4220 },
4221 anthropic::ResponseEvent::MessageDelta { delta, .. } => {
4222 if let Some(stop_reason) = delta.stop_reason {
4223 response.send(proto::LanguageModelResponse {
4224 choices: vec![proto::LanguageModelChoiceDelta {
4225 index: 0,
4226 delta: None,
4227 finish_reason: Some(stop_reason),
4228 }],
4229 })?;
4230 }
4231 }
4232 anthropic::ResponseEvent::ContentBlockStop { .. } => {}
4233 anthropic::ResponseEvent::MessageStop {} => {}
4234 anthropic::ResponseEvent::Ping {} => {}
4235 }
4236 }
4237
4238 Ok(())
4239}
4240
4241struct CountTokensWithLanguageModelRateLimit;
4242
4243impl RateLimit for CountTokensWithLanguageModelRateLimit {
4244 fn capacity() -> usize {
4245 std::env::var("COUNT_TOKENS_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4246 .ok()
4247 .and_then(|v| v.parse().ok())
4248 .unwrap_or(600) // Picked arbitrarily
4249 }
4250
4251 fn refill_duration() -> chrono::Duration {
4252 chrono::Duration::hours(1)
4253 }
4254
4255 fn db_name() -> &'static str {
4256 "count-tokens-with-language-model"
4257 }
4258}
4259
4260async fn count_tokens_with_language_model(
4261 request: proto::CountTokensWithLanguageModel,
4262 response: Response<proto::CountTokensWithLanguageModel>,
4263 session: UserSession,
4264 google_ai_api_key: Option<Arc<str>>,
4265) -> Result<()> {
4266 authorize_access_to_language_models(&session).await?;
4267
4268 if !request.model.starts_with("gemini") {
4269 return Err(anyhow!(
4270 "counting tokens for model: {:?} is not supported",
4271 request.model
4272 ))?;
4273 }
4274
4275 session
4276 .rate_limiter
4277 .check::<CountTokensWithLanguageModelRateLimit>(session.user_id())
4278 .await?;
4279
4280 let api_key = google_ai_api_key
4281 .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
4282 let tokens_response = google_ai::count_tokens(
4283 &session.http_client,
4284 google_ai::API_URL,
4285 &api_key,
4286 crate::ai::count_tokens_request_to_google_ai(request)?,
4287 )
4288 .await?;
4289 response.send(proto::CountTokensResponse {
4290 token_count: tokens_response.total_tokens as u32,
4291 })?;
4292 Ok(())
4293}
4294
4295struct ComputeEmbeddingsRateLimit;
4296
4297impl RateLimit for ComputeEmbeddingsRateLimit {
4298 fn capacity() -> usize {
4299 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4300 .ok()
4301 .and_then(|v| v.parse().ok())
4302 .unwrap_or(120) // Picked arbitrarily
4303 }
4304
4305 fn refill_duration() -> chrono::Duration {
4306 chrono::Duration::hours(1)
4307 }
4308
4309 fn db_name() -> &'static str {
4310 "compute-embeddings"
4311 }
4312}
4313
4314async fn compute_embeddings(
4315 request: proto::ComputeEmbeddings,
4316 response: Response<proto::ComputeEmbeddings>,
4317 session: UserSession,
4318 api_key: Option<Arc<str>>,
4319) -> Result<()> {
4320 let api_key = api_key.context("no OpenAI API key configured on the server")?;
4321 authorize_access_to_language_models(&session).await?;
4322
4323 session
4324 .rate_limiter
4325 .check::<ComputeEmbeddingsRateLimit>(session.user_id())
4326 .await?;
4327
4328 let embeddings = match request.model.as_str() {
4329 "openai/text-embedding-3-small" => {
4330 open_ai::embed(
4331 &session.http_client,
4332 OPEN_AI_API_URL,
4333 &api_key,
4334 OpenAiEmbeddingModel::TextEmbedding3Small,
4335 request.texts.iter().map(|text| text.as_str()),
4336 )
4337 .await?
4338 }
4339 provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
4340 };
4341
4342 let embeddings = request
4343 .texts
4344 .iter()
4345 .map(|text| {
4346 let mut hasher = sha2::Sha256::new();
4347 hasher.update(text.as_bytes());
4348 let result = hasher.finalize();
4349 result.to_vec()
4350 })
4351 .zip(
4352 embeddings
4353 .data
4354 .into_iter()
4355 .map(|embedding| embedding.embedding),
4356 )
4357 .collect::<HashMap<_, _>>();
4358
4359 let db = session.db().await;
4360 db.save_embeddings(&request.model, &embeddings)
4361 .await
4362 .context("failed to save embeddings")
4363 .trace_err();
4364
4365 response.send(proto::ComputeEmbeddingsResponse {
4366 embeddings: embeddings
4367 .into_iter()
4368 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4369 .collect(),
4370 })?;
4371 Ok(())
4372}
4373
4374struct GetCachedEmbeddingsRateLimit;
4375
4376impl RateLimit for GetCachedEmbeddingsRateLimit {
4377 fn capacity() -> usize {
4378 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4379 .ok()
4380 .and_then(|v| v.parse().ok())
4381 .unwrap_or(120) // Picked arbitrarily
4382 }
4383
4384 fn refill_duration() -> chrono::Duration {
4385 chrono::Duration::hours(1)
4386 }
4387
4388 fn db_name() -> &'static str {
4389 "get-cached-embeddings"
4390 }
4391}
4392
4393async fn get_cached_embeddings(
4394 request: proto::GetCachedEmbeddings,
4395 response: Response<proto::GetCachedEmbeddings>,
4396 session: UserSession,
4397) -> Result<()> {
4398 authorize_access_to_language_models(&session).await?;
4399
4400 session
4401 .rate_limiter
4402 .check::<GetCachedEmbeddingsRateLimit>(session.user_id())
4403 .await?;
4404
4405 let db = session.db().await;
4406 let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
4407
4408 response.send(proto::GetCachedEmbeddingsResponse {
4409 embeddings: embeddings
4410 .into_iter()
4411 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4412 .collect(),
4413 })?;
4414 Ok(())
4415}
4416
4417async fn authorize_access_to_language_models(session: &UserSession) -> Result<(), Error> {
4418 let db = session.db().await;
4419 let flags = db.get_user_flags(session.user_id()).await?;
4420 if flags.iter().any(|flag| flag == "language-models") {
4421 Ok(())
4422 } else {
4423 Err(anyhow!("permission denied"))?
4424 }
4425}
4426
4427/// Start receiving chat updates for a channel
4428async fn join_channel_chat(
4429 request: proto::JoinChannelChat,
4430 response: Response<proto::JoinChannelChat>,
4431 session: UserSession,
4432) -> Result<()> {
4433 let channel_id = ChannelId::from_proto(request.channel_id);
4434
4435 let db = session.db().await;
4436 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4437 .await?;
4438 let messages = db
4439 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4440 .await?;
4441 response.send(proto::JoinChannelChatResponse {
4442 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4443 messages,
4444 })?;
4445 Ok(())
4446}
4447
4448/// Stop receiving chat updates for a channel
4449async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
4450 let channel_id = ChannelId::from_proto(request.channel_id);
4451 session
4452 .db()
4453 .await
4454 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4455 .await?;
4456 Ok(())
4457}
4458
4459/// Retrieve the chat history for a channel
4460async fn get_channel_messages(
4461 request: proto::GetChannelMessages,
4462 response: Response<proto::GetChannelMessages>,
4463 session: UserSession,
4464) -> Result<()> {
4465 let channel_id = ChannelId::from_proto(request.channel_id);
4466 let messages = session
4467 .db()
4468 .await
4469 .get_channel_messages(
4470 channel_id,
4471 session.user_id(),
4472 MESSAGE_COUNT_PER_PAGE,
4473 Some(MessageId::from_proto(request.before_message_id)),
4474 )
4475 .await?;
4476 response.send(proto::GetChannelMessagesResponse {
4477 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4478 messages,
4479 })?;
4480 Ok(())
4481}
4482
4483/// Retrieve specific chat messages
4484async fn get_channel_messages_by_id(
4485 request: proto::GetChannelMessagesById,
4486 response: Response<proto::GetChannelMessagesById>,
4487 session: UserSession,
4488) -> Result<()> {
4489 let message_ids = request
4490 .message_ids
4491 .iter()
4492 .map(|id| MessageId::from_proto(*id))
4493 .collect::<Vec<_>>();
4494 let messages = session
4495 .db()
4496 .await
4497 .get_channel_messages_by_id(session.user_id(), &message_ids)
4498 .await?;
4499 response.send(proto::GetChannelMessagesResponse {
4500 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4501 messages,
4502 })?;
4503 Ok(())
4504}
4505
4506/// Retrieve the current users notifications
4507async fn get_notifications(
4508 request: proto::GetNotifications,
4509 response: Response<proto::GetNotifications>,
4510 session: UserSession,
4511) -> Result<()> {
4512 let notifications = session
4513 .db()
4514 .await
4515 .get_notifications(
4516 session.user_id(),
4517 NOTIFICATION_COUNT_PER_PAGE,
4518 request
4519 .before_id
4520 .map(|id| db::NotificationId::from_proto(id)),
4521 )
4522 .await?;
4523 response.send(proto::GetNotificationsResponse {
4524 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4525 notifications,
4526 })?;
4527 Ok(())
4528}
4529
4530/// Mark notifications as read
4531async fn mark_notification_as_read(
4532 request: proto::MarkNotificationRead,
4533 response: Response<proto::MarkNotificationRead>,
4534 session: UserSession,
4535) -> Result<()> {
4536 let database = &session.db().await;
4537 let notifications = database
4538 .mark_notification_as_read_by_id(
4539 session.user_id(),
4540 NotificationId::from_proto(request.notification_id),
4541 )
4542 .await?;
4543 send_notifications(
4544 &*session.connection_pool().await,
4545 &session.peer,
4546 notifications,
4547 );
4548 response.send(proto::Ack {})?;
4549 Ok(())
4550}
4551
4552/// Get the current users information
4553async fn get_private_user_info(
4554 _request: proto::GetPrivateUserInfo,
4555 response: Response<proto::GetPrivateUserInfo>,
4556 session: UserSession,
4557) -> Result<()> {
4558 let db = session.db().await;
4559
4560 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4561 let user = db
4562 .get_user_by_id(session.user_id())
4563 .await?
4564 .ok_or_else(|| anyhow!("user not found"))?;
4565 let flags = db.get_user_flags(session.user_id()).await?;
4566
4567 response.send(proto::GetPrivateUserInfoResponse {
4568 metrics_id,
4569 staff: user.admin,
4570 flags,
4571 })?;
4572 Ok(())
4573}
4574
4575fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
4576 match message {
4577 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4578 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4579 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4580 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4581 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4582 code: frame.code.into(),
4583 reason: frame.reason,
4584 })),
4585 }
4586}
4587
4588fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4589 match message {
4590 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4591 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4592 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4593 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4594 AxumMessage::Close(frame) => {
4595 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4596 code: frame.code.into(),
4597 reason: frame.reason,
4598 }))
4599 }
4600 }
4601}
4602
4603fn notify_membership_updated(
4604 connection_pool: &mut ConnectionPool,
4605 result: MembershipUpdated,
4606 user_id: UserId,
4607 peer: &Peer,
4608) {
4609 for membership in &result.new_channels.channel_memberships {
4610 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4611 }
4612 for channel_id in &result.removed_channels {
4613 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4614 }
4615
4616 let user_channels_update = proto::UpdateUserChannels {
4617 channel_memberships: result
4618 .new_channels
4619 .channel_memberships
4620 .iter()
4621 .map(|cm| proto::ChannelMembership {
4622 channel_id: cm.channel_id.to_proto(),
4623 role: cm.role.into(),
4624 })
4625 .collect(),
4626 ..Default::default()
4627 };
4628
4629 let mut update = build_channels_update(result.new_channels, vec![], connection_pool);
4630 update.delete_channels = result
4631 .removed_channels
4632 .into_iter()
4633 .map(|id| id.to_proto())
4634 .collect();
4635 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4636
4637 for connection_id in connection_pool.user_connection_ids(user_id) {
4638 peer.send(connection_id, user_channels_update.clone())
4639 .trace_err();
4640 peer.send(connection_id, update.clone()).trace_err();
4641 }
4642}
4643
4644fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4645 proto::UpdateUserChannels {
4646 channel_memberships: channels
4647 .channel_memberships
4648 .iter()
4649 .map(|m| proto::ChannelMembership {
4650 channel_id: m.channel_id.to_proto(),
4651 role: m.role.into(),
4652 })
4653 .collect(),
4654 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4655 observed_channel_message_id: channels.observed_channel_messages.clone(),
4656 }
4657}
4658
4659fn build_channels_update(
4660 channels: ChannelsForUser,
4661 channel_invites: Vec<db::Channel>,
4662 pool: &ConnectionPool,
4663) -> proto::UpdateChannels {
4664 let mut update = proto::UpdateChannels::default();
4665
4666 for channel in channels.channels {
4667 update.channels.push(channel.to_proto());
4668 }
4669
4670 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4671 update.latest_channel_message_ids = channels.latest_channel_messages;
4672
4673 for (channel_id, participants) in channels.channel_participants {
4674 update
4675 .channel_participants
4676 .push(proto::ChannelParticipants {
4677 channel_id: channel_id.to_proto(),
4678 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4679 });
4680 }
4681
4682 for channel in channel_invites {
4683 update.channel_invitations.push(channel.to_proto());
4684 }
4685
4686 update.hosted_projects = channels.hosted_projects;
4687 update.dev_servers = channels
4688 .dev_servers
4689 .into_iter()
4690 .map(|dev_server| dev_server.to_proto(pool.dev_server_status(dev_server.id)))
4691 .collect();
4692 update.remote_projects = channels.remote_projects;
4693
4694 update
4695}
4696
4697fn build_initial_contacts_update(
4698 contacts: Vec<db::Contact>,
4699 pool: &ConnectionPool,
4700) -> proto::UpdateContacts {
4701 let mut update = proto::UpdateContacts::default();
4702
4703 for contact in contacts {
4704 match contact {
4705 db::Contact::Accepted { user_id, busy } => {
4706 update.contacts.push(contact_for_user(user_id, busy, &pool));
4707 }
4708 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4709 db::Contact::Incoming { user_id } => {
4710 update
4711 .incoming_requests
4712 .push(proto::IncomingContactRequest {
4713 requester_id: user_id.to_proto(),
4714 })
4715 }
4716 }
4717 }
4718
4719 update
4720}
4721
4722fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4723 proto::Contact {
4724 user_id: user_id.to_proto(),
4725 online: pool.is_user_online(user_id),
4726 busy,
4727 }
4728}
4729
4730fn room_updated(room: &proto::Room, peer: &Peer) {
4731 broadcast(
4732 None,
4733 room.participants
4734 .iter()
4735 .filter_map(|participant| Some(participant.peer_id?.into())),
4736 |peer_id| {
4737 peer.send(
4738 peer_id,
4739 proto::RoomUpdated {
4740 room: Some(room.clone()),
4741 },
4742 )
4743 },
4744 );
4745}
4746
4747fn channel_updated(
4748 channel: &db::channel::Model,
4749 room: &proto::Room,
4750 peer: &Peer,
4751 pool: &ConnectionPool,
4752) {
4753 let participants = room
4754 .participants
4755 .iter()
4756 .map(|p| p.user_id)
4757 .collect::<Vec<_>>();
4758
4759 broadcast(
4760 None,
4761 pool.channel_connection_ids(channel.root_id())
4762 .filter_map(|(channel_id, role)| {
4763 role.can_see_channel(channel.visibility).then(|| channel_id)
4764 }),
4765 |peer_id| {
4766 peer.send(
4767 peer_id,
4768 proto::UpdateChannels {
4769 channel_participants: vec![proto::ChannelParticipants {
4770 channel_id: channel.id.to_proto(),
4771 participant_user_ids: participants.clone(),
4772 }],
4773 ..Default::default()
4774 },
4775 )
4776 },
4777 );
4778}
4779
4780async fn update_dev_server_status(
4781 dev_server: &dev_server::Model,
4782 status: proto::DevServerStatus,
4783 session: &Session,
4784) {
4785 let pool = session.connection_pool().await;
4786 let connections = pool.channel_connection_ids(dev_server.channel_id);
4787 for (connection_id, _) in connections {
4788 session
4789 .peer
4790 .send(
4791 connection_id,
4792 proto::UpdateChannels {
4793 dev_servers: vec![dev_server.to_proto(status)],
4794 ..Default::default()
4795 },
4796 )
4797 .trace_err();
4798 }
4799}
4800
4801async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4802 let db = session.db().await;
4803
4804 let contacts = db.get_contacts(user_id).await?;
4805 let busy = db.is_user_busy(user_id).await?;
4806
4807 let pool = session.connection_pool().await;
4808 let updated_contact = contact_for_user(user_id, busy, &pool);
4809 for contact in contacts {
4810 if let db::Contact::Accepted {
4811 user_id: contact_user_id,
4812 ..
4813 } = contact
4814 {
4815 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4816 session
4817 .peer
4818 .send(
4819 contact_conn_id,
4820 proto::UpdateContacts {
4821 contacts: vec![updated_contact.clone()],
4822 remove_contacts: Default::default(),
4823 incoming_requests: Default::default(),
4824 remove_incoming_requests: Default::default(),
4825 outgoing_requests: Default::default(),
4826 remove_outgoing_requests: Default::default(),
4827 },
4828 )
4829 .trace_err();
4830 }
4831 }
4832 }
4833 Ok(())
4834}
4835
4836async fn lost_dev_server_connection(session: &Session) -> Result<()> {
4837 log::info!("lost dev server connection, unsharing projects");
4838 let project_ids = session
4839 .db()
4840 .await
4841 .get_stale_dev_server_projects(session.connection_id)
4842 .await?;
4843
4844 for project_id in project_ids {
4845 // not unshare re-checks the connection ids match, so we get away with no transaction
4846 unshare_project_internal(project_id, &session).await?;
4847 }
4848
4849 Ok(())
4850}
4851
4852async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
4853 let mut contacts_to_update = HashSet::default();
4854
4855 let room_id;
4856 let canceled_calls_to_user_ids;
4857 let live_kit_room;
4858 let delete_live_kit_room;
4859 let room;
4860 let channel;
4861
4862 if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
4863 contacts_to_update.insert(session.user_id());
4864
4865 for project in left_room.left_projects.values() {
4866 project_left(project, session);
4867 }
4868
4869 room_id = RoomId::from_proto(left_room.room.id);
4870 canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
4871 live_kit_room = mem::take(&mut left_room.room.live_kit_room);
4872 delete_live_kit_room = left_room.deleted;
4873 room = mem::take(&mut left_room.room);
4874 channel = mem::take(&mut left_room.channel);
4875
4876 room_updated(&room, &session.peer);
4877 } else {
4878 return Ok(());
4879 }
4880
4881 if let Some(channel) = channel {
4882 channel_updated(
4883 &channel,
4884 &room,
4885 &session.peer,
4886 &*session.connection_pool().await,
4887 );
4888 }
4889
4890 {
4891 let pool = session.connection_pool().await;
4892 for canceled_user_id in canceled_calls_to_user_ids {
4893 for connection_id in pool.user_connection_ids(canceled_user_id) {
4894 session
4895 .peer
4896 .send(
4897 connection_id,
4898 proto::CallCanceled {
4899 room_id: room_id.to_proto(),
4900 },
4901 )
4902 .trace_err();
4903 }
4904 contacts_to_update.insert(canceled_user_id);
4905 }
4906 }
4907
4908 for contact_user_id in contacts_to_update {
4909 update_user_contacts(contact_user_id, &session).await?;
4910 }
4911
4912 if let Some(live_kit) = session.live_kit_client.as_ref() {
4913 live_kit
4914 .remove_participant(live_kit_room.clone(), session.user_id().to_string())
4915 .await
4916 .trace_err();
4917
4918 if delete_live_kit_room {
4919 live_kit.delete_room(live_kit_room).await.trace_err();
4920 }
4921 }
4922
4923 Ok(())
4924}
4925
4926async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4927 let left_channel_buffers = session
4928 .db()
4929 .await
4930 .leave_channel_buffers(session.connection_id)
4931 .await?;
4932
4933 for left_buffer in left_channel_buffers {
4934 channel_buffer_updated(
4935 session.connection_id,
4936 left_buffer.connections,
4937 &proto::UpdateChannelBufferCollaborators {
4938 channel_id: left_buffer.channel_id.to_proto(),
4939 collaborators: left_buffer.collaborators,
4940 },
4941 &session.peer,
4942 );
4943 }
4944
4945 Ok(())
4946}
4947
4948fn project_left(project: &db::LeftProject, session: &UserSession) {
4949 for connection_id in &project.connection_ids {
4950 if project.host_user_id == Some(session.user_id()) {
4951 session
4952 .peer
4953 .send(
4954 *connection_id,
4955 proto::UnshareProject {
4956 project_id: project.id.to_proto(),
4957 },
4958 )
4959 .trace_err();
4960 } else {
4961 session
4962 .peer
4963 .send(
4964 *connection_id,
4965 proto::RemoveProjectCollaborator {
4966 project_id: project.id.to_proto(),
4967 peer_id: Some(session.connection_id.into()),
4968 },
4969 )
4970 .trace_err();
4971 }
4972 }
4973}
4974
4975pub trait ResultExt {
4976 type Ok;
4977
4978 fn trace_err(self) -> Option<Self::Ok>;
4979}
4980
4981impl<T, E> ResultExt for Result<T, E>
4982where
4983 E: std::fmt::Debug,
4984{
4985 type Ok = T;
4986
4987 #[track_caller]
4988 fn trace_err(self) -> Option<T> {
4989 match self {
4990 Ok(value) => Some(value),
4991 Err(error) => {
4992 tracing::error!("{:?}", error);
4993 None
4994 }
4995 }
4996 }
4997}