1mod connection_pool;
2
3use crate::{
4 auth,
5 db::{
6 self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
7 CreatedChannelMessage, Database, DevServerId, InviteMemberResult, MembershipUpdated,
8 MessageId, NotificationId, PrincipalId, Project, ProjectId, RejoinedProject,
9 RemoteProjectId, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
10 ServerId, UpdatedChannelMessage, User, UserId,
11 },
12 executor::Executor,
13 AppState, Error, RateLimit, RateLimiter, Result,
14};
15use anyhow::{anyhow, Context as _};
16use async_tungstenite::tungstenite::{
17 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
18};
19use axum::{
20 body::Body,
21 extract::{
22 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
23 ConnectInfo, WebSocketUpgrade,
24 },
25 headers::{Header, HeaderName},
26 http::StatusCode,
27 middleware,
28 response::IntoResponse,
29 routing::get,
30 Extension, Router, TypedHeader,
31};
32use collections::{HashMap, HashSet};
33pub use connection_pool::{ConnectionPool, ZedVersion};
34use core::fmt::{self, Debug, Formatter};
35use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
36use sha2::Digest;
37
38use futures::{
39 channel::oneshot,
40 future::{self, BoxFuture},
41 stream::FuturesUnordered,
42 FutureExt, SinkExt, StreamExt, TryStreamExt,
43};
44use prometheus::{register_int_gauge, IntGauge};
45use rpc::{
46 proto::{
47 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LanguageModelRole,
48 LiveKitConnectionInfo, RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
49 },
50 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
51};
52use semantic_version::SemanticVersion;
53use serde::{Serialize, Serializer};
54use std::{
55 any::TypeId,
56 future::Future,
57 marker::PhantomData,
58 mem,
59 net::SocketAddr,
60 ops::{Deref, DerefMut},
61 rc::Rc,
62 sync::{
63 atomic::{AtomicBool, Ordering::SeqCst},
64 Arc, OnceLock,
65 },
66 time::{Duration, Instant},
67};
68use time::OffsetDateTime;
69use tokio::sync::{watch, Semaphore};
70use tower::ServiceBuilder;
71use tracing::{
72 field::{self},
73 info_span, instrument, Instrument,
74};
75use util::http::IsahcHttpClient;
76
77use self::connection_pool::VersionedMessage;
78
/// NOTE(review): presumably how long a disconnected client may take to reconnect
/// before its state is released; not referenced in this chunk — confirm at use sites.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can clean up old resources.
/// Delay `Server::start` waits before cleaning up resources left behind by stale servers.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// NOTE(review): the constants below are not referenced in this chunk; their meanings are
// inferred from their names — confirm at the use sites.
/// Presumably the page size when fetching channel chat messages.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Presumably the maximum channel message length (confirm whether bytes or chars).
const MAX_MESSAGE_LEN: usize = 1024;
/// Presumably the page size when fetching notifications.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
87
/// Type-erased message handler stored in `Server::handlers`, keyed by the message's `TypeId`.
///
/// Takes the boxed envelope plus the sender's [`Session`] and returns a boxed future.
/// The future yields `()` because `Server::add_handler` logs the handler's result itself.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
90
91struct Response<R> {
92 peer: Arc<Peer>,
93 receipt: Receipt<R>,
94 responded: Arc<AtomicBool>,
95}
96
97impl<R: RequestMessage> Response<R> {
98 fn send(self, payload: R::Response) -> Result<()> {
99 self.responded.store(true, SeqCst);
100 self.peer.respond(self.receipt, payload)?;
101 Ok(())
102 }
103}
104
105struct StreamingResponse<R: RequestMessage> {
106 peer: Arc<Peer>,
107 receipt: Receipt<R>,
108}
109
110impl<R: RequestMessage> StreamingResponse<R> {
111 fn send(&self, payload: R::Response) -> Result<()> {
112 self.peer.respond(self.receipt, payload)?;
113 Ok(())
114 }
115}
116
/// The authenticated identity behind a connection.
#[derive(Clone, Debug)]
pub enum Principal {
    /// A user acting as themselves.
    User(User),
    /// `user` being impersonated by `admin`; the session acts as `user`
    /// (see `Session::user_id`), and `admin` is recorded as the impersonator.
    Impersonated { user: User, admin: User },
    /// A connected dev server (not a user).
    DevServer(dev_server::Model),
}
123
124impl Principal {
125 fn update_span(&self, span: &tracing::Span) {
126 match &self {
127 Principal::User(user) => {
128 span.record("user_id", &user.id.0);
129 span.record("login", &user.github_login);
130 }
131 Principal::Impersonated { user, admin } => {
132 span.record("user_id", &user.id.0);
133 span.record("login", &user.github_login);
134 span.record("impersonator", &admin.github_login);
135 }
136 Principal::DevServer(dev_server) => {
137 span.record("dev_server_id", &dev_server.id.0);
138 }
139 }
140 }
141}
142
/// Per-connection context handed (cloned) to every message handler.
#[derive(Clone)]
struct Session {
    /// Who is on the other end of the connection.
    principal: Principal,
    /// This connection's id within the peer.
    connection_id: ConnectionId,
    /// Database handle; access it via [`Session::db`]. A fresh mutex is created per
    /// connection and shared by all clones of the session, so concurrent handlers
    /// of one connection take turns using the database.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    /// The server's RPC peer, used to send messages and responses.
    peer: Arc<Peer>,
    /// Server-wide connection pool; access it via [`Session::connection_pool`].
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// LiveKit API client, if one is configured.
    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
    /// HTTP client created for this connection.
    http_client: IsahcHttpClient,
    /// Shared rate limiter from the app state.
    rate_limiter: Arc<RateLimiter>,
    /// Executor for this connection; not read within this chunk.
    _executor: Executor,
}
155
156impl Session {
157 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
158 #[cfg(test)]
159 tokio::task::yield_now().await;
160 let guard = self.db.lock().await;
161 #[cfg(test)]
162 tokio::task::yield_now().await;
163 guard
164 }
165
166 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
167 #[cfg(test)]
168 tokio::task::yield_now().await;
169 let guard = self.connection_pool.lock();
170 ConnectionPoolGuard {
171 guard,
172 _not_send: PhantomData,
173 }
174 }
175
176 fn for_user(self) -> Option<UserSession> {
177 UserSession::new(self)
178 }
179
180 fn for_dev_server(self) -> Option<DevServerSession> {
181 DevServerSession::new(self)
182 }
183
184 fn user_id(&self) -> Option<UserId> {
185 match &self.principal {
186 Principal::User(user) => Some(user.id),
187 Principal::Impersonated { user, .. } => Some(user.id),
188 Principal::DevServer(_) => None,
189 }
190 }
191
192 fn dev_server_id(&self) -> Option<DevServerId> {
193 match &self.principal {
194 Principal::User(_) | Principal::Impersonated { .. } => None,
195 Principal::DevServer(dev_server) => Some(dev_server.id),
196 }
197 }
198
199 fn principal_id(&self) -> PrincipalId {
200 match &self.principal {
201 Principal::User(user) => PrincipalId::UserId(user.id),
202 Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
203 Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
204 }
205 }
206}
207
208impl Debug for Session {
209 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
210 let mut result = f.debug_struct("Session");
211 match &self.principal {
212 Principal::User(user) => {
213 result.field("user", &user.github_login);
214 }
215 Principal::Impersonated { user, admin } => {
216 result.field("user", &user.github_login);
217 result.field("impersonator", &admin.github_login);
218 }
219 Principal::DevServer(dev_server) => {
220 result.field("dev_server", &dev_server.id);
221 }
222 }
223 result.field("connection_id", &self.connection_id).finish()
224 }
225}
226
227struct UserSession(Session);
228
229impl UserSession {
230 pub fn new(s: Session) -> Option<Self> {
231 s.user_id().map(|_| UserSession(s))
232 }
233 pub fn user_id(&self) -> UserId {
234 self.0.user_id().unwrap()
235 }
236}
237
238impl Deref for UserSession {
239 type Target = Session;
240
241 fn deref(&self) -> &Self::Target {
242 &self.0
243 }
244}
245impl DerefMut for UserSession {
246 fn deref_mut(&mut self) -> &mut Self::Target {
247 &mut self.0
248 }
249}
250
251struct DevServerSession(Session);
252
253impl DevServerSession {
254 pub fn new(s: Session) -> Option<Self> {
255 s.dev_server_id().map(|_| DevServerSession(s))
256 }
257 pub fn dev_server_id(&self) -> DevServerId {
258 self.0.dev_server_id().unwrap()
259 }
260
261 fn dev_server(&self) -> &dev_server::Model {
262 match &self.0.principal {
263 Principal::DevServer(dev_server) => dev_server,
264 _ => unreachable!(),
265 }
266 }
267}
268
269impl Deref for DevServerSession {
270 type Target = Session;
271
272 fn deref(&self) -> &Self::Target {
273 &self.0
274 }
275}
276impl DerefMut for DevServerSession {
277 fn deref_mut(&mut self) -> &mut Self::Target {
278 &mut self.0
279 }
280}
281
282fn user_handler<M: RequestMessage, Fut>(
283 handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
284) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
285where
286 Fut: Send + Future<Output = Result<()>>,
287{
288 let handler = Arc::new(handler);
289 move |message, response, session| {
290 let handler = handler.clone();
291 Box::pin(async move {
292 if let Some(user_session) = session.for_user() {
293 Ok(handler(message, response, user_session).await?)
294 } else {
295 Err(Error::Internal(anyhow!(
296 "must be a user to call {}",
297 M::NAME
298 )))
299 }
300 })
301 }
302}
303
304fn dev_server_handler<M: RequestMessage, Fut>(
305 handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
306) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
307where
308 Fut: Send + Future<Output = Result<()>>,
309{
310 let handler = Arc::new(handler);
311 move |message, response, session| {
312 let handler = handler.clone();
313 Box::pin(async move {
314 if let Some(dev_server_session) = session.for_dev_server() {
315 Ok(handler(message, response, dev_server_session).await?)
316 } else {
317 Err(Error::Internal(anyhow!(
318 "must be a dev server to call {}",
319 M::NAME
320 )))
321 }
322 })
323 }
324}
325
326fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
327 handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
328) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
329where
330 InnertRetFut: Send + Future<Output = Result<()>>,
331{
332 let handler = Arc::new(handler);
333 move |message, session| {
334 let handler = handler.clone();
335 Box::pin(async move {
336 if let Some(user_session) = session.for_user() {
337 Ok(handler(message, user_session).await?)
338 } else {
339 Err(Error::Internal(anyhow!(
340 "must be a user to call {}",
341 M::NAME
342 )))
343 }
344 })
345 }
346}
347
348struct DbHandle(Arc<Database>);
349
350impl Deref for DbHandle {
351 type Target = Database;
352
353 fn deref(&self) -> &Self::Target {
354 self.0.as_ref()
355 }
356}
357
/// The collab RPC server. Owns the peer, the shared connection pool, and the
/// table of registered message handlers.
pub struct Server {
    /// This server's id. Behind a mutex because the test-only `reset` replaces it.
    id: parking_lot::Mutex<ServerId>,
    /// RPC peer shared with every [`Session`].
    peer: Arc<Peer>,
    /// Connection pool shared with every [`Session`].
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Registered handlers, keyed by the `TypeId` of the message type they handle.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Set to `true` by `teardown` (and back to `false` by the test-only `reset`);
    /// `handle_connection` subscribes to it.
    teardown: watch::Sender<bool>,
}
366
/// Lock guard for the shared [`ConnectionPool`], returned by `Session::connection_pool`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    /// `Rc` is `!Send`, which makes this guard `!Send`. That prevents holding it
    /// across an `.await` inside `Send` futures such as the `BoxFuture`s handlers
    /// return.
    _not_send: PhantomData<Rc<()>>,
}
371
/// Serializable view of the server state: the peer plus the connection pool.
/// The pool stays locked for as long as the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialize the pool itself, not the guard wrapping it.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
378
379pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
380where
381 S: Serializer,
382 T: Deref<Target = U>,
383 U: Serialize,
384{
385 Serialize::serialize(value.deref(), serializer)
386}
387
388impl Server {
389 pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
390 let mut server = Self {
391 id: parking_lot::Mutex::new(id),
392 peer: Peer::new(id.0 as u32),
393 app_state: app_state.clone(),
394 connection_pool: Default::default(),
395 handlers: Default::default(),
396 teardown: watch::channel(false).0,
397 };
398
399 server
400 .add_request_handler(ping)
401 .add_request_handler(user_handler(create_room))
402 .add_request_handler(user_handler(join_room))
403 .add_request_handler(user_handler(rejoin_room))
404 .add_request_handler(user_handler(leave_room))
405 .add_request_handler(user_handler(set_room_participant_role))
406 .add_request_handler(user_handler(call))
407 .add_request_handler(user_handler(cancel_call))
408 .add_message_handler(user_message_handler(decline_call))
409 .add_request_handler(user_handler(update_participant_location))
410 .add_request_handler(user_handler(share_project))
411 .add_message_handler(unshare_project)
412 .add_request_handler(user_handler(join_project))
413 .add_request_handler(user_handler(join_hosted_project))
414 .add_request_handler(user_handler(rejoin_remote_projects))
415 .add_request_handler(user_handler(create_remote_project))
416 .add_request_handler(user_handler(create_dev_server))
417 .add_request_handler(user_handler(delete_dev_server))
418 .add_request_handler(dev_server_handler(share_remote_project))
419 .add_request_handler(dev_server_handler(shutdown_dev_server))
420 .add_request_handler(dev_server_handler(reconnect_dev_server))
421 .add_message_handler(user_message_handler(leave_project))
422 .add_request_handler(update_project)
423 .add_request_handler(update_worktree)
424 .add_message_handler(start_language_server)
425 .add_message_handler(update_language_server)
426 .add_message_handler(update_diagnostic_summary)
427 .add_message_handler(update_worktree_settings)
428 .add_request_handler(user_handler(
429 forward_read_only_project_request::<proto::GetHover>,
430 ))
431 .add_request_handler(user_handler(
432 forward_read_only_project_request::<proto::GetDefinition>,
433 ))
434 .add_request_handler(user_handler(
435 forward_read_only_project_request::<proto::GetTypeDefinition>,
436 ))
437 .add_request_handler(user_handler(
438 forward_read_only_project_request::<proto::GetReferences>,
439 ))
440 .add_request_handler(user_handler(
441 forward_read_only_project_request::<proto::SearchProject>,
442 ))
443 .add_request_handler(user_handler(
444 forward_read_only_project_request::<proto::GetDocumentHighlights>,
445 ))
446 .add_request_handler(user_handler(
447 forward_read_only_project_request::<proto::GetProjectSymbols>,
448 ))
449 .add_request_handler(user_handler(
450 forward_read_only_project_request::<proto::OpenBufferForSymbol>,
451 ))
452 .add_request_handler(user_handler(
453 forward_read_only_project_request::<proto::OpenBufferById>,
454 ))
455 .add_request_handler(user_handler(
456 forward_read_only_project_request::<proto::SynchronizeBuffers>,
457 ))
458 .add_request_handler(user_handler(
459 forward_read_only_project_request::<proto::InlayHints>,
460 ))
461 .add_request_handler(user_handler(
462 forward_read_only_project_request::<proto::OpenBufferByPath>,
463 ))
464 .add_request_handler(user_handler(
465 forward_mutating_project_request::<proto::GetCompletions>,
466 ))
467 .add_request_handler(user_handler(
468 forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
469 ))
470 .add_request_handler(user_handler(
471 forward_versioned_mutating_project_request::<proto::OpenNewBuffer>,
472 ))
473 .add_request_handler(user_handler(
474 forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
475 ))
476 .add_request_handler(user_handler(
477 forward_mutating_project_request::<proto::GetCodeActions>,
478 ))
479 .add_request_handler(user_handler(
480 forward_mutating_project_request::<proto::ApplyCodeAction>,
481 ))
482 .add_request_handler(user_handler(
483 forward_mutating_project_request::<proto::PrepareRename>,
484 ))
485 .add_request_handler(user_handler(
486 forward_mutating_project_request::<proto::PerformRename>,
487 ))
488 .add_request_handler(user_handler(
489 forward_mutating_project_request::<proto::ReloadBuffers>,
490 ))
491 .add_request_handler(user_handler(
492 forward_mutating_project_request::<proto::FormatBuffers>,
493 ))
494 .add_request_handler(user_handler(
495 forward_mutating_project_request::<proto::CreateProjectEntry>,
496 ))
497 .add_request_handler(user_handler(
498 forward_mutating_project_request::<proto::RenameProjectEntry>,
499 ))
500 .add_request_handler(user_handler(
501 forward_mutating_project_request::<proto::CopyProjectEntry>,
502 ))
503 .add_request_handler(user_handler(
504 forward_mutating_project_request::<proto::DeleteProjectEntry>,
505 ))
506 .add_request_handler(user_handler(
507 forward_mutating_project_request::<proto::ExpandProjectEntry>,
508 ))
509 .add_request_handler(user_handler(
510 forward_mutating_project_request::<proto::OnTypeFormatting>,
511 ))
512 .add_request_handler(user_handler(
513 forward_versioned_mutating_project_request::<proto::SaveBuffer>,
514 ))
515 .add_request_handler(user_handler(
516 forward_mutating_project_request::<proto::BlameBuffer>,
517 ))
518 .add_request_handler(user_handler(
519 forward_mutating_project_request::<proto::MultiLspQuery>,
520 ))
521 .add_message_handler(create_buffer_for_peer)
522 .add_request_handler(update_buffer)
523 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
524 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
525 .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
526 .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
527 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
528 .add_request_handler(get_users)
529 .add_request_handler(user_handler(fuzzy_search_users))
530 .add_request_handler(user_handler(request_contact))
531 .add_request_handler(user_handler(remove_contact))
532 .add_request_handler(user_handler(respond_to_contact_request))
533 .add_request_handler(user_handler(create_channel))
534 .add_request_handler(user_handler(delete_channel))
535 .add_request_handler(user_handler(invite_channel_member))
536 .add_request_handler(user_handler(remove_channel_member))
537 .add_request_handler(user_handler(set_channel_member_role))
538 .add_request_handler(user_handler(set_channel_visibility))
539 .add_request_handler(user_handler(rename_channel))
540 .add_request_handler(user_handler(join_channel_buffer))
541 .add_request_handler(user_handler(leave_channel_buffer))
542 .add_message_handler(user_message_handler(update_channel_buffer))
543 .add_request_handler(user_handler(rejoin_channel_buffers))
544 .add_request_handler(user_handler(get_channel_members))
545 .add_request_handler(user_handler(respond_to_channel_invite))
546 .add_request_handler(user_handler(join_channel))
547 .add_request_handler(user_handler(join_channel_chat))
548 .add_message_handler(user_message_handler(leave_channel_chat))
549 .add_request_handler(user_handler(send_channel_message))
550 .add_request_handler(user_handler(remove_channel_message))
551 .add_request_handler(user_handler(update_channel_message))
552 .add_request_handler(user_handler(get_channel_messages))
553 .add_request_handler(user_handler(get_channel_messages_by_id))
554 .add_request_handler(user_handler(get_notifications))
555 .add_request_handler(user_handler(mark_notification_as_read))
556 .add_request_handler(user_handler(move_channel))
557 .add_request_handler(user_handler(follow))
558 .add_message_handler(user_message_handler(unfollow))
559 .add_message_handler(user_message_handler(update_followers))
560 .add_request_handler(user_handler(get_private_user_info))
561 .add_message_handler(user_message_handler(acknowledge_channel_message))
562 .add_message_handler(user_message_handler(acknowledge_buffer_version))
563 .add_streaming_request_handler({
564 let app_state = app_state.clone();
565 move |request, response, session| {
566 complete_with_language_model(
567 request,
568 response,
569 session,
570 app_state.config.openai_api_key.clone(),
571 app_state.config.google_ai_api_key.clone(),
572 app_state.config.anthropic_api_key.clone(),
573 )
574 }
575 })
576 .add_request_handler({
577 let app_state = app_state.clone();
578 user_handler(move |request, response, session| {
579 count_tokens_with_language_model(
580 request,
581 response,
582 session,
583 app_state.config.google_ai_api_key.clone(),
584 )
585 })
586 })
587 .add_request_handler({
588 user_handler(move |request, response, session| {
589 get_cached_embeddings(request, response, session)
590 })
591 })
592 .add_request_handler({
593 let app_state = app_state.clone();
594 user_handler(move |request, response, session| {
595 compute_embeddings(
596 request,
597 response,
598 session,
599 app_state.config.openai_api_key.clone(),
600 )
601 })
602 });
603
604 Arc::new(server)
605 }
606
    /// Starts background cleanup for this server instance.
    ///
    /// Spawns a detached task that waits [`CLEANUP_TIMEOUT`] and then asks the database
    /// for rooms and channel buffers held by stale servers in this `zed_environment`.
    /// NOTE(review): presumably servers other than `server_id`; confirm in
    /// `Database::stale_server_resource_ids`. The task then:
    /// - clears stale channel-buffer collaborators and notifies the remaining connections;
    /// - clears stale room participants, rebroadcasts each room (and its channel),
    ///   sends `CallCanceled` for canceled calls, pushes updated contact entries to
    ///   affected users' accepted contacts, and deletes the LiveKit room once it is empty;
    /// - finally deletes the stale server records.
    ///
    /// Failures inside the task are only logged via `trace_err`. This method returns
    /// immediately and currently always returns `Ok(())`.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Drop stale collaborators from each channel buffer and send the
                    // refreshed collaborator list to every remaining connection.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        // Remove stale participants, broadcast the refreshed room (and its
                        // channel, if any), and collect the follow-up work done below.
                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Send `CallCanceled` to every connection of the affected users. The
                        // pool lock is scoped so it is released before the `.await`s below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push each affected user's updated contact entry (including busy
                        // status) to the connections of their accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Delete the LiveKit room only when no participants remain.
                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                // Finally, remove the stale server records themselves.
                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
752
753 pub fn teardown(&self) {
754 self.peer.teardown();
755 self.connection_pool.lock().reset();
756 let _ = self.teardown.send(true);
757 }
758
759 #[cfg(test)]
760 pub fn reset(&self, id: ServerId) {
761 self.teardown();
762 *self.id.lock() = id;
763 self.peer.reset(id.0 as u32);
764 let _ = self.teardown.send(false);
765 }
766
767 #[cfg(test)]
768 pub fn id(&self) -> ServerId {
769 *self.id.lock()
770 }
771
    /// Registers `handler` for envelopes carrying message type `M`.
    ///
    /// The stored closure downcasts the type-erased envelope back to
    /// `TypedEnvelope<M>`, runs the handler, and logs the outcome with three timings:
    /// - `total_duration_ms`: since the envelope's `received_at`;
    /// - `processing_duration_ms`: since the handler future was created;
    /// - `queue_duration_ms`: the difference between the two.
    ///
    /// Handler errors are logged here, not propagated.
    ///
    /// # Panics
    /// Panics if a handler for `M` was already registered.
    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        let prev_handler = self.handlers.insert(
            TypeId::of::<M>(),
            Box::new(move |envelope, session| {
                // The map key is `TypeId::of::<M>()`, so this downcast is expected to
                // succeed. NOTE(review): relies on `payload_type_id()` (used for lookup in
                // `handle_connection`) returning that same id — confirm in `rpc`.
                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
                let received_at = envelope.received_at;
                tracing::info!("message received");
                let start_time = Instant::now();
                let future = (handler)(*envelope, session);
                async move {
                    let result = future.await;
                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
                    let payload_type = M::NAME;

                    match result {
                        Err(error) => {
                            tracing::error!(
                                ?error,
                                total_duration_ms,
                                processing_duration_ms,
                                queue_duration_ms,
                                payload_type,
                                "error handling message"
                            )
                        }
                        Ok(()) => tracing::info!(
                            total_duration_ms,
                            processing_duration_ms,
                            queue_duration_ms,
                            "finished handling message"
                        ),
                    }
                }
                .boxed()
            }),
        );
        if prev_handler.is_some() {
            panic!("registered a handler for the same message twice");
        }
        self
    }
820
821 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
822 where
823 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
824 Fut: 'static + Send + Future<Output = Result<()>>,
825 M: EnvelopedMessage,
826 {
827 self.add_handler(move |envelope, session| handler(envelope.payload, session));
828 self
829 }
830
831 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
832 where
833 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
834 Fut: Send + Future<Output = Result<()>>,
835 M: RequestMessage,
836 {
837 let handler = Arc::new(handler);
838 self.add_handler(move |envelope, session| {
839 let receipt = envelope.receipt();
840 let handler = handler.clone();
841 async move {
842 let peer = session.peer.clone();
843 let responded = Arc::new(AtomicBool::default());
844 let response = Response {
845 peer: peer.clone(),
846 responded: responded.clone(),
847 receipt,
848 };
849 match (handler)(envelope.payload, response, session).await {
850 Ok(()) => {
851 if responded.load(std::sync::atomic::Ordering::SeqCst) {
852 Ok(())
853 } else {
854 Err(anyhow!("handler did not send a response"))?
855 }
856 }
857 Err(error) => {
858 let proto_err = match &error {
859 Error::Internal(err) => err.to_proto(),
860 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
861 };
862 peer.respond_with_error(receipt, proto_err)?;
863 Err(error)
864 }
865 }
866 }
867 })
868 }
869
870 fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
871 where
872 F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
873 Fut: Send + Future<Output = Result<()>>,
874 M: RequestMessage,
875 {
876 let handler = Arc::new(handler);
877 self.add_handler(move |envelope, session| {
878 let receipt = envelope.receipt();
879 let handler = handler.clone();
880 async move {
881 let peer = session.peer.clone();
882 let response = StreamingResponse {
883 peer: peer.clone(),
884 receipt,
885 };
886 match (handler)(envelope.payload, response, session).await {
887 Ok(()) => {
888 peer.end_stream(receipt)?;
889 Ok(())
890 }
891 Err(error) => {
892 let proto_err = match &error {
893 Error::Internal(err) => err.to_proto(),
894 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
895 };
896 peer.respond_with_error(receipt, proto_err)?;
897 Err(error)
898 }
899 }
900 }
901 })
902 }
903
    /// Drives one client connection from handshake to sign-out.
    ///
    /// Steps:
    /// - Registers `connection` with the peer and builds a [`Session`] for `principal`.
    /// - Sends the initial client update; if `send_connection_id` is provided, the
    ///   assigned [`ConnectionId`] is delivered to it after the `Hello` message.
    /// - Dispatches incoming messages to the registered handlers until the I/O future
    ///   finishes or the client closes the stream.
    /// - Then calls `connection_lost`. If the server tears down mid-loop, it returns
    ///   immediately instead.
    ///
    /// NOTE(review): the `return`s after `add_connection` skip `connection_lost`:
    /// HTTP client failure, initial-update failure, and teardown. The first two also
    /// never poll `handle_io`. Confirm the peer/pool entries are cleaned up elsewhere.
    // Each argument is a distinct per-connection input, hence the lint allowance.
    #[allow(clippy::too_many_arguments)]
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> {
        let this = self.clone();
        // Identity fields start empty and are filled in by `record` once known.
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            dev_server_id=field::Empty
        );
        principal.update_span(&span);

        let mut teardown = self.teardown.subscribe();
        async move {
            // Refuse new connections once teardown has been signaled.
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));
            tracing::info!("connection opened");

            // Each connection gets its own HTTP client; failing to build one aborts the connection.
            let http_client = match IsahcHttpClient::new() {
                Ok(http_client) => http_client,
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                http_client,
                rate_limiter: this.app_state.rate_limiter.clone(),
                _executor: executor.clone(),
            };

            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Caps in-flight handlers for this connection at 256. A permit is acquired
            // before reading each message and released when its handler finishes (or right
            // away if no handler is registered). The semaphore is never closed, so the
            // `unwrap` on `acquire_owned` should not fail.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Branches are polled in order: teardown, then I/O, then progress on
                // foreground handlers, and only then reading a new message.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                dev_server_id=field::Empty
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                // Create the handler future inside the span, then exit the
                                // span so it can be moved into `instrument` below.
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                let handle_message = async move {
                                    handle_message.await;
                                    // Free the concurrency slot only once handling is done.
                                    drop(permit);
                                }.instrument(span);
                                // Background messages run detached; foreground ones are
                                // polled by this loop via `foreground_message_handlers`.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set drops (cancels) any foreground handlers still in flight.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
1040
    /// Sends a freshly connected client its initial state.
    ///
    /// Every principal first receives a `proto::Hello` carrying its peer id, and the
    /// connection id is handed to `send_connection_id` when one was supplied. What follows
    /// depends on the principal:
    ///
    /// * Users (including impersonated ones): on their first ever connection they are sent
    ///   `ShowContacts` and marked as having connected once. Their contacts, channels,
    ///   channel invites and remote projects are then fetched concurrently. The connection is
    ///   registered in the pool and subscribed to the user's channels, and the client is sent
    ///   its contacts, user channels, channels/invites, remote projects and any pending
    ///   incoming call. Finally the user's contacts are updated.
    /// * Dev servers: rejected with `ErrorCode::DevServerAlreadyOnline` if the pool already
    ///   has a connection for the same dev server. Otherwise they are registered, sent their
    ///   `DevServerInstructions`, and the dev server's `user_id` is sent an updated
    ///   remote-projects status.
    ///
    /// # Errors
    ///
    /// Returns an error if any message fails to send or any database query fails. For users,
    /// the connection has already been added to the pool before the later sends and queries
    /// run, so an error there leaves it registered.
    async fn send_initial_client_update(
        &self,
        connection_id: ConnectionId,
        principal: &Principal,
        zed_version: ZedVersion,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        session: &Session,
    ) -> Result<()> {
        self.peer.send(
            connection_id,
            proto::Hello {
                peer_id: Some(connection_id.into()),
            },
        )?;
        tracing::info!("sent hello message");
        // Report the assigned connection id to the caller-provided channel, if any; a
        // dropped receiver is deliberately ignored.
        if let Some(send_connection_id) = send_connection_id.take() {
            let _ = send_connection_id.send(connection_id);
        }

        match principal {
            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
                // First connection ever for this user: ask the client to show contacts and
                // persist that the user has now connected once.
                if !user.connected_once {
                    self.peer.send(connection_id, proto::ShowContacts {})?;
                    self.app_state
                        .db
                        .set_user_connected_once(user.id, true)
                        .await?;
                }

                // Fetch the initial state concurrently; the first failure aborts the update.
                let (contacts, channels_for_user, channel_invites, remote_projects) =
                    future::try_join4(
                        self.app_state.db.get_contacts(user.id),
                        self.app_state.db.get_channels_for_user(user.id),
                        self.app_state.db.get_channel_invites_for_user(user.id),
                        self.app_state.db.remote_projects_update(user.id),
                    )
                    .await?;

                // The pool lock is held across registration and these sends;
                // `build_initial_contacts_update` reads the same pool state the connection
                // was just added to.
                {
                    let mut pool = self.connection_pool.lock();
                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
                    for membership in &channels_for_user.channel_memberships {
                        pool.subscribe_to_channel(user.id, membership.channel_id, membership.role)
                    }
                    self.peer.send(
                        connection_id,
                        build_initial_contacts_update(contacts, &pool),
                    )?;
                    self.peer.send(
                        connection_id,
                        build_update_user_channels(&channels_for_user),
                    )?;
                    self.peer.send(
                        connection_id,
                        build_channels_update(channels_for_user, channel_invites),
                    )?;
                }
                send_remote_projects_update(user.id, remote_projects, session).await;

                // Re-deliver a call that was ringing for this user, if any.
                if let Some(incoming_call) =
                    self.app_state.db.incoming_call_for_user(user.id).await?
                {
                    self.peer.send(connection_id, incoming_call)?;
                }

                update_user_contacts(user.id, &session).await?;
            }
            Principal::DevServer(dev_server) => {
                // Only one live connection per dev server is accepted.
                {
                    let mut pool = self.connection_pool.lock();
                    if pool.dev_server_connection_id(dev_server.id).is_some() {
                        return Err(anyhow!(ErrorCode::DevServerAlreadyOnline))?;
                    };
                    pool.add_dev_server(connection_id, dev_server.id, zed_version);
                }

                let projects = self
                    .app_state
                    .db
                    .get_remote_projects_for_dev_server(dev_server.id)
                    .await?;
                self.peer
                    .send(connection_id, proto::DevServerInstructions { projects })?;

                // Let the dev server's user see its updated remote-projects status.
                let status = self
                    .app_state
                    .db
                    .remote_projects_update(dev_server.user_id)
                    .await?;
                send_remote_projects_update(dev_server.user_id, status, &session).await;
            }
        }

        Ok(())
    }
1136
1137 pub async fn invite_code_redeemed(
1138 self: &Arc<Self>,
1139 inviter_id: UserId,
1140 invitee_id: UserId,
1141 ) -> Result<()> {
1142 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1143 if let Some(code) = &user.invite_code {
1144 let pool = self.connection_pool.lock();
1145 let invitee_contact = contact_for_user(invitee_id, false, &pool);
1146 for connection_id in pool.user_connection_ids(inviter_id) {
1147 self.peer.send(
1148 connection_id,
1149 proto::UpdateContacts {
1150 contacts: vec![invitee_contact.clone()],
1151 ..Default::default()
1152 },
1153 )?;
1154 self.peer.send(
1155 connection_id,
1156 proto::UpdateInviteInfo {
1157 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1158 count: user.invite_count as u32,
1159 },
1160 )?;
1161 }
1162 }
1163 }
1164 Ok(())
1165 }
1166
1167 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1168 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1169 if let Some(invite_code) = &user.invite_code {
1170 let pool = self.connection_pool.lock();
1171 for connection_id in pool.user_connection_ids(user_id) {
1172 self.peer.send(
1173 connection_id,
1174 proto::UpdateInviteInfo {
1175 url: format!(
1176 "{}{}",
1177 self.app_state.config.invite_link_prefix, invite_code
1178 ),
1179 count: user.invite_count as u32,
1180 },
1181 )?;
1182 }
1183 }
1184 }
1185 Ok(())
1186 }
1187
1188 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1189 ServerSnapshot {
1190 connection_pool: ConnectionPoolGuard {
1191 guard: self.connection_pool.lock(),
1192 _not_send: PhantomData,
1193 },
1194 peer: &self.peer,
1195 }
1196 }
1197}
1198
1199impl<'a> Deref for ConnectionPoolGuard<'a> {
1200 type Target = ConnectionPool;
1201
1202 fn deref(&self) -> &Self::Target {
1203 &self.guard
1204 }
1205}
1206
1207impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1208 fn deref_mut(&mut self) -> &mut Self::Target {
1209 &mut self.guard
1210 }
1211}
1212
1213impl<'a> Drop for ConnectionPoolGuard<'a> {
1214 fn drop(&mut self) {
1215 #[cfg(test)]
1216 self.check_invariants();
1217 }
1218}
1219
1220fn broadcast<F>(
1221 sender_id: Option<ConnectionId>,
1222 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1223 mut f: F,
1224) where
1225 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1226{
1227 for receiver_id in receiver_ids {
1228 if Some(receiver_id) != sender_id {
1229 if let Err(error) = f(receiver_id) {
1230 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1231 }
1232 }
1233 }
1234}
1235
1236pub struct ProtocolVersion(u32);
1237
1238impl Header for ProtocolVersion {
1239 fn name() -> &'static HeaderName {
1240 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1241 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1242 }
1243
1244 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1245 where
1246 Self: Sized,
1247 I: Iterator<Item = &'i axum::http::HeaderValue>,
1248 {
1249 let version = values
1250 .next()
1251 .ok_or_else(axum::headers::Error::invalid)?
1252 .to_str()
1253 .map_err(|_| axum::headers::Error::invalid())?
1254 .parse()
1255 .map_err(|_| axum::headers::Error::invalid())?;
1256 Ok(Self(version))
1257 }
1258
1259 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1260 values.extend([self.0.to_string().parse().unwrap()]);
1261 }
1262}
1263
1264pub struct AppVersionHeader(SemanticVersion);
1265impl Header for AppVersionHeader {
1266 fn name() -> &'static HeaderName {
1267 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1268 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1269 }
1270
1271 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1272 where
1273 Self: Sized,
1274 I: Iterator<Item = &'i axum::http::HeaderValue>,
1275 {
1276 let version = values
1277 .next()
1278 .ok_or_else(axum::headers::Error::invalid)?
1279 .to_str()
1280 .map_err(|_| axum::headers::Error::invalid())?
1281 .parse()
1282 .map_err(|_| axum::headers::Error::invalid())?;
1283 Ok(Self(version))
1284 }
1285
1286 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1287 values.extend([self.0.to_string().parse().unwrap()]);
1288 }
1289}
1290
1291pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1292 Router::new()
1293 .route("/rpc", get(handle_websocket_request))
1294 .layer(
1295 ServiceBuilder::new()
1296 .layer(Extension(server.app_state.clone()))
1297 .layer(middleware::from_fn(auth::validate_header)),
1298 )
1299 .route("/metrics", get(handle_metrics))
1300 .layer(Extension(server))
1301}
1302
1303pub async fn handle_websocket_request(
1304 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1305 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1306 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1307 Extension(server): Extension<Arc<Server>>,
1308 Extension(principal): Extension<Principal>,
1309 ws: WebSocketUpgrade,
1310) -> axum::response::Response {
1311 if protocol_version != rpc::PROTOCOL_VERSION {
1312 return (
1313 StatusCode::UPGRADE_REQUIRED,
1314 "client must be upgraded".to_string(),
1315 )
1316 .into_response();
1317 }
1318
1319 let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1320 return (
1321 StatusCode::UPGRADE_REQUIRED,
1322 "no version header found".to_string(),
1323 )
1324 .into_response();
1325 };
1326
1327 if !version.can_collaborate() {
1328 return (
1329 StatusCode::UPGRADE_REQUIRED,
1330 "client must be upgraded".to_string(),
1331 )
1332 .into_response();
1333 }
1334
1335 let socket_address = socket_address.to_string();
1336 ws.on_upgrade(move |socket| {
1337 let socket = socket
1338 .map_ok(to_tungstenite_message)
1339 .err_into()
1340 .with(|message| async move { Ok(to_axum_message(message)) });
1341 let connection = Connection::new(Box::pin(socket));
1342 async move {
1343 server
1344 .handle_connection(
1345 connection,
1346 socket_address,
1347 principal,
1348 version,
1349 None,
1350 Executor::Production,
1351 )
1352 .await;
1353 }
1354 })
1355}
1356
1357pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1358 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1359 let connections_metric = CONNECTIONS_METRIC
1360 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1361
1362 let connections = server
1363 .connection_pool
1364 .lock()
1365 .connections()
1366 .filter(|connection| !connection.admin)
1367 .count();
1368 connections_metric.set(connections as _);
1369
1370 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1371 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1372 register_int_gauge!(
1373 "shared_projects",
1374 "number of open projects with one or more guests"
1375 )
1376 .unwrap()
1377 });
1378
1379 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1380 shared_projects_metric.set(shared_projects as _);
1381
1382 let encoder = prometheus::TextEncoder::new();
1383 let metric_families = prometheus::gather();
1384 let encoded_metrics = encoder
1385 .encode_to_string(&metric_families)
1386 .map_err(|err| anyhow!("{}", err))?;
1387 Ok(encoded_metrics)
1388}
1389
/// Cleans up after a client connection has been lost.
///
/// Immediately, the peer is disconnected, the connection is removed from the connection pool,
/// and the database is told the connection was lost (a database error there is only traced).
///
/// Then, whichever happens first:
/// * `RECONNECT_TIMEOUT` elapses (this branch is polled first; the select is biased):
///   - for a user, the session leaves its room and channel buffers, declines the user's
///     pending call (if any) when the user has no other connection online, and the user's
///     contacts are updated;
///   - for a dev server, `lost_dev_server_connection` runs.
/// * `teardown` changes: the remaining cleanup is skipped.
///
/// # Errors
///
/// Returns an error if removing the connection from the pool fails, or if updating contacts or
/// the dev-server cleanup fails. Failures while leaving rooms or buffers, or while declining
/// the call, are only traced.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            match &session.principal {
                Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
                    let session = session.for_user().unwrap();

                    log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
                    leave_room_for_session(&session, session.connection_id).await.trace_err();
                    leave_channel_buffers_for_session(&session)
                        .await
                        .trace_err();

                    // Only decline the call when no other connection of this user is still online.
                    if !session
                        .connection_pool()
                        .await
                        .is_user_online(session.user_id())
                    {
                        let db = session.db().await;
                        if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                            room_updated(&room, &session.peer);
                        }
                    }

                    update_user_contacts(session.user_id(), &session).await?;
                },
                Principal::DevServer(_) => {
                    lost_dev_server_connection(&session.for_dev_server().unwrap()).await?;
                },
            }
        },
        // Teardown wins over the reconnect timeout: skip per-principal cleanup.
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1444
1445/// Acknowledges a ping from a client, used to keep the connection alive.
1446async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1447 response.send(proto::Ack {})?;
1448 Ok(())
1449}
1450
1451/// Creates a new room for calling (outside of channels)
1452async fn create_room(
1453 _request: proto::CreateRoom,
1454 response: Response<proto::CreateRoom>,
1455 session: UserSession,
1456) -> Result<()> {
1457 let live_kit_room = nanoid::nanoid!(30);
1458
1459 let live_kit_connection_info = util::maybe!(async {
1460 let live_kit = session.live_kit_client.as_ref();
1461 let live_kit = live_kit?;
1462 let user_id = session.user_id().to_string();
1463
1464 let token = live_kit
1465 .room_token(&live_kit_room, &user_id.to_string())
1466 .trace_err()?;
1467
1468 Some(proto::LiveKitConnectionInfo {
1469 server_url: live_kit.url().into(),
1470 token,
1471 can_publish: true,
1472 })
1473 })
1474 .await;
1475
1476 let room = session
1477 .db()
1478 .await
1479 .create_room(session.user_id(), session.connection_id, &live_kit_room)
1480 .await?;
1481
1482 response.send(proto::CreateRoomResponse {
1483 room: Some(room.clone()),
1484 live_kit_connection_info,
1485 })?;
1486
1487 update_user_contacts(session.user_id(), &session).await?;
1488 Ok(())
1489}
1490
1491/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1492async fn join_room(
1493 request: proto::JoinRoom,
1494 response: Response<proto::JoinRoom>,
1495 session: UserSession,
1496) -> Result<()> {
1497 let room_id = RoomId::from_proto(request.id);
1498
1499 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1500
1501 if let Some(channel_id) = channel_id {
1502 return join_channel_internal(channel_id, Box::new(response), session).await;
1503 }
1504
1505 let joined_room = {
1506 let room = session
1507 .db()
1508 .await
1509 .join_room(room_id, session.user_id(), session.connection_id)
1510 .await?;
1511 room_updated(&room.room, &session.peer);
1512 room.into_inner()
1513 };
1514
1515 for connection_id in session
1516 .connection_pool()
1517 .await
1518 .user_connection_ids(session.user_id())
1519 {
1520 session
1521 .peer
1522 .send(
1523 connection_id,
1524 proto::CallCanceled {
1525 room_id: room_id.to_proto(),
1526 },
1527 )
1528 .trace_err();
1529 }
1530
1531 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1532 if let Some(token) = live_kit
1533 .room_token(
1534 &joined_room.room.live_kit_room,
1535 &session.user_id().to_string(),
1536 )
1537 .trace_err()
1538 {
1539 Some(proto::LiveKitConnectionInfo {
1540 server_url: live_kit.url().into(),
1541 token,
1542 can_publish: true,
1543 })
1544 } else {
1545 None
1546 }
1547 } else {
1548 None
1549 };
1550
1551 response.send(proto::JoinRoomResponse {
1552 room: Some(joined_room.room),
1553 channel_id: None,
1554 live_kit_connection_info,
1555 })?;
1556
1557 update_user_contacts(session.user_id(), &session).await?;
1558 Ok(())
1559}
1560
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// The database rejoin yields the room plus two project lists:
/// - "reshared" projects, presumably ones this user hosts; their collaborators are told the
///   host's new peer id and sent its worktrees via `UpdateProject`;
/// - "rejoined" projects, which are brought up to date via `notify_rejoined_projects`.
///
/// The client first receives a `RejoinRoomResponse` describing both lists, and the room is
/// broadcast. Afterwards, the room's channel (if any) and the user's contacts are updated.
///
/// # Errors
///
/// Returns an error if the database rejoin fails, the response or a rejoined-project message
/// can't be sent, or updating contacts fails. Per-collaborator sends are only traced/logged.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: UserSession,
) -> Result<()> {
    let room;
    let channel;
    // Scope the database result (consumed by `into_inner` below) so it is released before
    // the channel and contact updates at the end.
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell collaborators this peer's id changed from the old connection to this one.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the project's worktrees from this connection to everyone else on it.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1652
1653fn notify_rejoined_projects(
1654 rejoined_projects: &mut Vec<RejoinedProject>,
1655 session: &UserSession,
1656) -> Result<()> {
1657 for project in rejoined_projects.iter() {
1658 for collaborator in &project.collaborators {
1659 session
1660 .peer
1661 .send(
1662 collaborator.connection_id,
1663 proto::UpdateProjectCollaborator {
1664 project_id: project.id.to_proto(),
1665 old_peer_id: Some(project.old_connection_id.into()),
1666 new_peer_id: Some(session.connection_id.into()),
1667 },
1668 )
1669 .trace_err();
1670 }
1671 }
1672
1673 for project in rejoined_projects {
1674 for worktree in mem::take(&mut project.worktrees) {
1675 #[cfg(any(test, feature = "test-support"))]
1676 const MAX_CHUNK_SIZE: usize = 2;
1677 #[cfg(not(any(test, feature = "test-support")))]
1678 const MAX_CHUNK_SIZE: usize = 256;
1679
1680 // Stream this worktree's entries.
1681 let message = proto::UpdateWorktree {
1682 project_id: project.id.to_proto(),
1683 worktree_id: worktree.id,
1684 abs_path: worktree.abs_path.clone(),
1685 root_name: worktree.root_name,
1686 updated_entries: worktree.updated_entries,
1687 removed_entries: worktree.removed_entries,
1688 scan_id: worktree.scan_id,
1689 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1690 updated_repositories: worktree.updated_repositories,
1691 removed_repositories: worktree.removed_repositories,
1692 };
1693 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1694 session.peer.send(session.connection_id, update.clone())?;
1695 }
1696
1697 // Stream this worktree's diagnostics.
1698 for summary in worktree.diagnostic_summaries {
1699 session.peer.send(
1700 session.connection_id,
1701 proto::UpdateDiagnosticSummary {
1702 project_id: project.id.to_proto(),
1703 worktree_id: worktree.id,
1704 summary: Some(summary),
1705 },
1706 )?;
1707 }
1708
1709 for settings_file in worktree.settings_files {
1710 session.peer.send(
1711 session.connection_id,
1712 proto::UpdateWorktreeSettings {
1713 project_id: project.id.to_proto(),
1714 worktree_id: worktree.id,
1715 path: settings_file.path,
1716 content: Some(settings_file.content),
1717 },
1718 )?;
1719 }
1720 }
1721
1722 for language_server in &project.language_servers {
1723 session.peer.send(
1724 session.connection_id,
1725 proto::UpdateLanguageServer {
1726 project_id: project.id.to_proto(),
1727 language_server_id: language_server.id,
1728 variant: Some(
1729 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1730 proto::LspDiskBasedDiagnosticsUpdated {},
1731 ),
1732 ),
1733 },
1734 )?;
1735 }
1736 }
1737 Ok(())
1738}
1739
1740/// leave room disconnects from the room.
1741async fn leave_room(
1742 _: proto::LeaveRoom,
1743 response: Response<proto::LeaveRoom>,
1744 session: UserSession,
1745) -> Result<()> {
1746 leave_room_for_session(&session, session.connection_id).await?;
1747 response.send(proto::Ack {})?;
1748 Ok(())
1749}
1750
1751/// Updates the permissions of someone else in the room.
1752async fn set_room_participant_role(
1753 request: proto::SetRoomParticipantRole,
1754 response: Response<proto::SetRoomParticipantRole>,
1755 session: UserSession,
1756) -> Result<()> {
1757 let user_id = UserId::from_proto(request.user_id);
1758 let role = ChannelRole::from(request.role());
1759
1760 let (live_kit_room, can_publish) = {
1761 let room = session
1762 .db()
1763 .await
1764 .set_room_participant_role(
1765 session.user_id(),
1766 RoomId::from_proto(request.room_id),
1767 user_id,
1768 role,
1769 )
1770 .await?;
1771
1772 let live_kit_room = room.live_kit_room.clone();
1773 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1774 room_updated(&room, &session.peer);
1775 (live_kit_room, can_publish)
1776 };
1777
1778 if let Some(live_kit) = session.live_kit_client.as_ref() {
1779 live_kit
1780 .update_participant(
1781 live_kit_room.clone(),
1782 request.user_id.to_string(),
1783 live_kit_server::proto::ParticipantPermission {
1784 can_subscribe: true,
1785 can_publish,
1786 can_publish_data: can_publish,
1787 hidden: false,
1788 recorder: false,
1789 },
1790 )
1791 .await
1792 .trace_err();
1793 }
1794
1795 response.send(proto::Ack {})?;
1796 Ok(())
1797}
1798
/// Call someone else into the current room
///
/// The callee must be one of the caller's contacts. The call is recorded in the database, the
/// room is broadcast, and the callee's contacts are updated. The incoming call is then sent to
/// every connection of the called user concurrently; the first successful reply acknowledges
/// this request.
///
/// If no connection replies successfully (including when the user has no connections), the
/// call is marked as failed, the room is broadcast again, the callee's contacts are updated,
/// and an error is returned.
///
/// # Errors
///
/// Returns an error if the users aren't contacts, a database operation fails, or the called
/// user could not be rung.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: UserSession,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    // Take the incoming-call payload out of the database result so that result can be
    // released at the end of this block.
    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(&room, &session.peer);
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring all of the callee's connections at once; replies are handled as they arrive.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    // No connection answered successfully: record the failure and report it.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1867
1868/// Cancel an outgoing call.
1869async fn cancel_call(
1870 request: proto::CancelCall,
1871 response: Response<proto::CancelCall>,
1872 session: UserSession,
1873) -> Result<()> {
1874 let called_user_id = UserId::from_proto(request.called_user_id);
1875 let room_id = RoomId::from_proto(request.room_id);
1876 {
1877 let room = session
1878 .db()
1879 .await
1880 .cancel_call(room_id, session.connection_id, called_user_id)
1881 .await?;
1882 room_updated(&room, &session.peer);
1883 }
1884
1885 for connection_id in session
1886 .connection_pool()
1887 .await
1888 .user_connection_ids(called_user_id)
1889 {
1890 session
1891 .peer
1892 .send(
1893 connection_id,
1894 proto::CallCanceled {
1895 room_id: room_id.to_proto(),
1896 },
1897 )
1898 .trace_err();
1899 }
1900 response.send(proto::Ack {})?;
1901
1902 update_user_contacts(called_user_id, &session).await?;
1903 Ok(())
1904}
1905
1906/// Decline an incoming call.
1907async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1908 let room_id = RoomId::from_proto(message.room_id);
1909 {
1910 let room = session
1911 .db()
1912 .await
1913 .decline_call(Some(room_id), session.user_id())
1914 .await?
1915 .ok_or_else(|| anyhow!("failed to decline call"))?;
1916 room_updated(&room, &session.peer);
1917 }
1918
1919 for connection_id in session
1920 .connection_pool()
1921 .await
1922 .user_connection_ids(session.user_id())
1923 {
1924 session
1925 .peer
1926 .send(
1927 connection_id,
1928 proto::CallCanceled {
1929 room_id: room_id.to_proto(),
1930 },
1931 )
1932 .trace_err();
1933 }
1934 update_user_contacts(session.user_id(), &session).await?;
1935 Ok(())
1936}
1937
1938/// Updates other participants in the room with your current location.
1939async fn update_participant_location(
1940 request: proto::UpdateParticipantLocation,
1941 response: Response<proto::UpdateParticipantLocation>,
1942 session: UserSession,
1943) -> Result<()> {
1944 let room_id = RoomId::from_proto(request.room_id);
1945 let location = request
1946 .location
1947 .ok_or_else(|| anyhow!("invalid location"))?;
1948
1949 let db = session.db().await;
1950 let room = db
1951 .update_room_participant_location(room_id, session.connection_id, location)
1952 .await?;
1953
1954 room_updated(&room, &session.peer);
1955 response.send(proto::Ack {})?;
1956 Ok(())
1957}
1958
1959/// Share a project into the room.
1960async fn share_project(
1961 request: proto::ShareProject,
1962 response: Response<proto::ShareProject>,
1963 session: UserSession,
1964) -> Result<()> {
1965 let (project_id, room) = &*session
1966 .db()
1967 .await
1968 .share_project(
1969 RoomId::from_proto(request.room_id),
1970 session.connection_id,
1971 &request.worktrees,
1972 request
1973 .remote_project_id
1974 .map(|id| RemoteProjectId::from_proto(id)),
1975 )
1976 .await?;
1977 response.send(proto::ShareProjectResponse {
1978 project_id: project_id.to_proto(),
1979 })?;
1980 room_updated(&room, &session.peer);
1981
1982 Ok(())
1983}
1984
1985/// Unshare a project from the room.
1986async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1987 let project_id = ProjectId::from_proto(message.project_id);
1988 unshare_project_internal(
1989 project_id,
1990 session.connection_id,
1991 session.user_id(),
1992 &session,
1993 )
1994 .await
1995}
1996
1997async fn unshare_project_internal(
1998 project_id: ProjectId,
1999 connection_id: ConnectionId,
2000 user_id: Option<UserId>,
2001 session: &Session,
2002) -> Result<()> {
2003 let (room, guest_connection_ids) = &*session
2004 .db()
2005 .await
2006 .unshare_project(project_id, connection_id, user_id)
2007 .await?;
2008
2009 let message = proto::UnshareProject {
2010 project_id: project_id.to_proto(),
2011 };
2012
2013 broadcast(
2014 Some(connection_id),
2015 guest_connection_ids.iter().copied(),
2016 |conn_id| session.peer.send(conn_id, message.clone()),
2017 );
2018 if let Some(room) = room {
2019 room_updated(room, &session.peer);
2020 }
2021
2022 Ok(())
2023}
2024
2025/// DevServer makes a project available online
2026async fn share_remote_project(
2027 request: proto::ShareRemoteProject,
2028 response: Response<proto::ShareRemoteProject>,
2029 session: DevServerSession,
2030) -> Result<()> {
2031 let (remote_project, user_id, status) = session
2032 .db()
2033 .await
2034 .share_remote_project(
2035 RemoteProjectId::from_proto(request.remote_project_id),
2036 session.dev_server_id(),
2037 session.connection_id,
2038 &request.worktrees,
2039 )
2040 .await?;
2041 let Some(project_id) = remote_project.project_id else {
2042 return Err(anyhow!("failed to share remote project"))?;
2043 };
2044
2045 send_remote_projects_update(user_id, status, &session).await;
2046
2047 response.send(proto::ShareProjectResponse { project_id })?;
2048
2049 Ok(())
2050}
2051
/// Join someone else's shared project.
///
/// Records the join in the database, which yields the project and the joiner's replica id.
/// It then releases the database handle and hands off to `join_project_internal`, which
/// notifies the other collaborators and streams the project state to the joiner.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: UserSession,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    // Release the database handle before streaming the project to the client.
    drop(db);
    // NOTE(review): this message is logged for every join, not only for remote projects.
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
2070
/// Abstracts over requests answered with a `proto::JoinProjectResponse`, so that
/// `join_project_internal` can serve both `JoinProject` and `JoinHostedProject`.
trait JoinProjectInternalResponse {
    /// Sends the join response to the requesting client.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Explicit path to the inherent `Response::send`, which shares this trait method's name.
        Response::<proto::JoinProject>::send(self, result)
    }
}
impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Explicit path to the inherent `Response::send`, which shares this trait method's name.
        Response::<proto::JoinHostedProject>::send(self, result)
    }
}
2084
2085fn join_project_internal(
2086 response: impl JoinProjectInternalResponse,
2087 session: UserSession,
2088 project: &mut Project,
2089 replica_id: &ReplicaId,
2090) -> Result<()> {
2091 let collaborators = project
2092 .collaborators
2093 .iter()
2094 .filter(|collaborator| collaborator.connection_id != session.connection_id)
2095 .map(|collaborator| collaborator.to_proto())
2096 .collect::<Vec<_>>();
2097 let project_id = project.id;
2098 let guest_user_id = session.user_id();
2099
2100 let worktrees = project
2101 .worktrees
2102 .iter()
2103 .map(|(id, worktree)| proto::WorktreeMetadata {
2104 id: *id,
2105 root_name: worktree.root_name.clone(),
2106 visible: worktree.visible,
2107 abs_path: worktree.abs_path.clone(),
2108 })
2109 .collect::<Vec<_>>();
2110
2111 let add_project_collaborator = proto::AddProjectCollaborator {
2112 project_id: project_id.to_proto(),
2113 collaborator: Some(proto::Collaborator {
2114 peer_id: Some(session.connection_id.into()),
2115 replica_id: replica_id.0 as u32,
2116 user_id: guest_user_id.to_proto(),
2117 }),
2118 };
2119
2120 for collaborator in &collaborators {
2121 session
2122 .peer
2123 .send(
2124 collaborator.peer_id.unwrap().into(),
2125 add_project_collaborator.clone(),
2126 )
2127 .trace_err();
2128 }
2129
2130 // First, we send the metadata associated with each worktree.
2131 response.send(proto::JoinProjectResponse {
2132 project_id: project.id.0 as u64,
2133 worktrees: worktrees.clone(),
2134 replica_id: replica_id.0 as u32,
2135 collaborators: collaborators.clone(),
2136 language_servers: project.language_servers.clone(),
2137 role: project.role.into(),
2138 remote_project_id: project
2139 .remote_project_id
2140 .map(|remote_project_id| remote_project_id.0 as u64),
2141 })?;
2142
2143 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2144 #[cfg(any(test, feature = "test-support"))]
2145 const MAX_CHUNK_SIZE: usize = 2;
2146 #[cfg(not(any(test, feature = "test-support")))]
2147 const MAX_CHUNK_SIZE: usize = 256;
2148
2149 // Stream this worktree's entries.
2150 let message = proto::UpdateWorktree {
2151 project_id: project_id.to_proto(),
2152 worktree_id,
2153 abs_path: worktree.abs_path.clone(),
2154 root_name: worktree.root_name,
2155 updated_entries: worktree.entries,
2156 removed_entries: Default::default(),
2157 scan_id: worktree.scan_id,
2158 is_last_update: worktree.scan_id == worktree.completed_scan_id,
2159 updated_repositories: worktree.repository_entries.into_values().collect(),
2160 removed_repositories: Default::default(),
2161 };
2162 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
2163 session.peer.send(session.connection_id, update.clone())?;
2164 }
2165
2166 // Stream this worktree's diagnostics.
2167 for summary in worktree.diagnostic_summaries {
2168 session.peer.send(
2169 session.connection_id,
2170 proto::UpdateDiagnosticSummary {
2171 project_id: project_id.to_proto(),
2172 worktree_id: worktree.id,
2173 summary: Some(summary),
2174 },
2175 )?;
2176 }
2177
2178 for settings_file in worktree.settings_files {
2179 session.peer.send(
2180 session.connection_id,
2181 proto::UpdateWorktreeSettings {
2182 project_id: project_id.to_proto(),
2183 worktree_id: worktree.id,
2184 path: settings_file.path,
2185 content: Some(settings_file.content),
2186 },
2187 )?;
2188 }
2189 }
2190
2191 for language_server in &project.language_servers {
2192 session.peer.send(
2193 session.connection_id,
2194 proto::UpdateLanguageServer {
2195 project_id: project_id.to_proto(),
2196 language_server_id: language_server.id,
2197 variant: Some(
2198 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2199 proto::LspDiskBasedDiagnosticsUpdated {},
2200 ),
2201 ),
2202 },
2203 )?;
2204 }
2205
2206 Ok(())
2207}
2208
2209/// Leave someone elses shared project.
2210async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2211 let sender_id = session.connection_id;
2212 let project_id = ProjectId::from_proto(request.project_id);
2213 let db = session.db().await;
2214 if db.is_hosted_project(project_id).await? {
2215 let project = db.leave_hosted_project(project_id, sender_id).await?;
2216 project_left(&project, &session);
2217 return Ok(());
2218 }
2219
2220 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2221 tracing::info!(
2222 %project_id,
2223 "leave project"
2224 );
2225
2226 project_left(&project, &session);
2227 if let Some(room) = room {
2228 room_updated(&room, &session.peer);
2229 }
2230
2231 Ok(())
2232}
2233
2234async fn join_hosted_project(
2235 request: proto::JoinHostedProject,
2236 response: Response<proto::JoinHostedProject>,
2237 session: UserSession,
2238) -> Result<()> {
2239 let (mut project, replica_id) = session
2240 .db()
2241 .await
2242 .join_hosted_project(
2243 ProjectId(request.project_id as i32),
2244 session.user_id(),
2245 session.connection_id,
2246 )
2247 .await?;
2248
2249 join_project_internal(response, session, &mut project, &replica_id)
2250}
2251
2252async fn create_remote_project(
2253 request: proto::CreateRemoteProject,
2254 response: Response<proto::CreateRemoteProject>,
2255 session: UserSession,
2256) -> Result<()> {
2257 let dev_server_id = DevServerId(request.dev_server_id as i32);
2258 let dev_server_connection_id = session
2259 .connection_pool()
2260 .await
2261 .dev_server_connection_id(dev_server_id);
2262 let Some(dev_server_connection_id) = dev_server_connection_id else {
2263 Err(ErrorCode::DevServerOffline
2264 .message("Cannot create a remote project when the dev server is offline".to_string())
2265 .anyhow())?
2266 };
2267
2268 let path = request.path.clone();
2269 //Check that the path exists on the dev server
2270 session
2271 .peer
2272 .forward_request(
2273 session.connection_id,
2274 dev_server_connection_id,
2275 proto::ValidateRemoteProjectRequest { path: path.clone() },
2276 )
2277 .await?;
2278
2279 let (remote_project, update) = session
2280 .db()
2281 .await
2282 .create_remote_project(
2283 DevServerId(request.dev_server_id as i32),
2284 &request.path,
2285 session.user_id(),
2286 )
2287 .await?;
2288
2289 let projects = session
2290 .db()
2291 .await
2292 .get_remote_projects_for_dev_server(remote_project.dev_server_id)
2293 .await?;
2294
2295 session.peer.send(
2296 dev_server_connection_id,
2297 proto::DevServerInstructions { projects },
2298 )?;
2299
2300 send_remote_projects_update(session.user_id(), update, &session).await;
2301
2302 response.send(proto::CreateRemoteProjectResponse {
2303 remote_project: Some(remote_project.to_proto(None)),
2304 })?;
2305 Ok(())
2306}
2307
2308async fn create_dev_server(
2309 request: proto::CreateDevServer,
2310 response: Response<proto::CreateDevServer>,
2311 session: UserSession,
2312) -> Result<()> {
2313 let access_token = auth::random_token();
2314 let hashed_access_token = auth::hash_access_token(&access_token);
2315
2316 let (dev_server, status) = session
2317 .db()
2318 .await
2319 .create_dev_server(&request.name, &hashed_access_token, session.user_id())
2320 .await?;
2321
2322 send_remote_projects_update(session.user_id(), status, &session).await;
2323
2324 response.send(proto::CreateDevServerResponse {
2325 dev_server_id: dev_server.id.0 as u64,
2326 access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2327 name: request.name.clone(),
2328 })?;
2329 Ok(())
2330}
2331
2332async fn delete_dev_server(
2333 request: proto::DeleteDevServer,
2334 response: Response<proto::DeleteDevServer>,
2335 session: UserSession,
2336) -> Result<()> {
2337 let dev_server_id = DevServerId(request.dev_server_id as i32);
2338 let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2339 if dev_server.user_id != session.user_id() {
2340 return Err(anyhow!(ErrorCode::Forbidden))?;
2341 }
2342
2343 let connection_id = session
2344 .connection_pool()
2345 .await
2346 .dev_server_connection_id(dev_server_id);
2347 if let Some(connection_id) = connection_id {
2348 shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2349 session
2350 .peer
2351 .send(connection_id, proto::ShutdownDevServer {})?;
2352 }
2353
2354 let status = session
2355 .db()
2356 .await
2357 .delete_dev_server(dev_server_id, session.user_id())
2358 .await?;
2359
2360 send_remote_projects_update(session.user_id(), status, &session).await;
2361
2362 response.send(proto::Ack {})?;
2363 Ok(())
2364}
2365
2366async fn rejoin_remote_projects(
2367 request: proto::RejoinRemoteProjects,
2368 response: Response<proto::RejoinRemoteProjects>,
2369 session: UserSession,
2370) -> Result<()> {
2371 let mut rejoined_projects = {
2372 let db = session.db().await;
2373 db.rejoin_remote_projects(
2374 &request.rejoined_projects,
2375 session.user_id(),
2376 session.0.connection_id,
2377 )
2378 .await?
2379 };
2380 notify_rejoined_projects(&mut rejoined_projects, &session)?;
2381
2382 response.send(proto::RejoinRemoteProjectsResponse {
2383 rejoined_projects: rejoined_projects
2384 .into_iter()
2385 .map(|project| project.to_proto())
2386 .collect(),
2387 })
2388}
2389
2390async fn reconnect_dev_server(
2391 request: proto::ReconnectDevServer,
2392 response: Response<proto::ReconnectDevServer>,
2393 session: DevServerSession,
2394) -> Result<()> {
2395 let reshared_projects = {
2396 let db = session.db().await;
2397 db.reshare_remote_projects(
2398 &request.reshared_projects,
2399 session.dev_server_id(),
2400 session.0.connection_id,
2401 )
2402 .await?
2403 };
2404
2405 for project in &reshared_projects {
2406 for collaborator in &project.collaborators {
2407 session
2408 .peer
2409 .send(
2410 collaborator.connection_id,
2411 proto::UpdateProjectCollaborator {
2412 project_id: project.id.to_proto(),
2413 old_peer_id: Some(project.old_connection_id.into()),
2414 new_peer_id: Some(session.connection_id.into()),
2415 },
2416 )
2417 .trace_err();
2418 }
2419
2420 broadcast(
2421 Some(session.connection_id),
2422 project
2423 .collaborators
2424 .iter()
2425 .map(|collaborator| collaborator.connection_id),
2426 |connection_id| {
2427 session.peer.forward_send(
2428 session.connection_id,
2429 connection_id,
2430 proto::UpdateProject {
2431 project_id: project.id.to_proto(),
2432 worktrees: project.worktrees.clone(),
2433 },
2434 )
2435 },
2436 );
2437 }
2438
2439 response.send(proto::ReconnectDevServerResponse {
2440 reshared_projects: reshared_projects
2441 .iter()
2442 .map(|project| proto::ResharedProject {
2443 id: project.id.to_proto(),
2444 collaborators: project
2445 .collaborators
2446 .iter()
2447 .map(|collaborator| collaborator.to_proto())
2448 .collect(),
2449 })
2450 .collect(),
2451 })?;
2452
2453 Ok(())
2454}
2455
2456async fn shutdown_dev_server(
2457 _: proto::ShutdownDevServer,
2458 response: Response<proto::ShutdownDevServer>,
2459 session: DevServerSession,
2460) -> Result<()> {
2461 response.send(proto::Ack {})?;
2462 shutdown_dev_server_internal(session.dev_server_id(), session.connection_id, &session).await
2463}
2464
2465async fn shutdown_dev_server_internal(
2466 dev_server_id: DevServerId,
2467 connection_id: ConnectionId,
2468 session: &Session,
2469) -> Result<()> {
2470 let (remote_projects, dev_server) = {
2471 let db = session.db().await;
2472 let remote_projects = db.get_remote_projects_for_dev_server(dev_server_id).await?;
2473 let dev_server = db.get_dev_server(dev_server_id).await?;
2474 (remote_projects, dev_server)
2475 };
2476
2477 for project_id in remote_projects.iter().filter_map(|p| p.project_id) {
2478 unshare_project_internal(
2479 ProjectId::from_proto(project_id),
2480 connection_id,
2481 None,
2482 session,
2483 )
2484 .await?;
2485 }
2486
2487 session
2488 .connection_pool()
2489 .await
2490 .set_dev_server_offline(dev_server_id);
2491
2492 let status = session
2493 .db()
2494 .await
2495 .remote_projects_update(dev_server.user_id)
2496 .await?;
2497 send_remote_projects_update(dev_server.user_id, status, &session).await;
2498
2499 Ok(())
2500}
2501
2502/// Updates other participants with changes to the project
2503async fn update_project(
2504 request: proto::UpdateProject,
2505 response: Response<proto::UpdateProject>,
2506 session: Session,
2507) -> Result<()> {
2508 let project_id = ProjectId::from_proto(request.project_id);
2509 let (room, guest_connection_ids) = &*session
2510 .db()
2511 .await
2512 .update_project(project_id, session.connection_id, &request.worktrees)
2513 .await?;
2514 broadcast(
2515 Some(session.connection_id),
2516 guest_connection_ids.iter().copied(),
2517 |connection_id| {
2518 session
2519 .peer
2520 .forward_send(session.connection_id, connection_id, request.clone())
2521 },
2522 );
2523 if let Some(room) = room {
2524 room_updated(&room, &session.peer);
2525 }
2526 response.send(proto::Ack {})?;
2527
2528 Ok(())
2529}
2530
2531/// Updates other participants with changes to the worktree
2532async fn update_worktree(
2533 request: proto::UpdateWorktree,
2534 response: Response<proto::UpdateWorktree>,
2535 session: Session,
2536) -> Result<()> {
2537 let guest_connection_ids = session
2538 .db()
2539 .await
2540 .update_worktree(&request, session.connection_id)
2541 .await?;
2542
2543 broadcast(
2544 Some(session.connection_id),
2545 guest_connection_ids.iter().copied(),
2546 |connection_id| {
2547 session
2548 .peer
2549 .forward_send(session.connection_id, connection_id, request.clone())
2550 },
2551 );
2552 response.send(proto::Ack {})?;
2553 Ok(())
2554}
2555
2556/// Updates other participants with changes to the diagnostics
2557async fn update_diagnostic_summary(
2558 message: proto::UpdateDiagnosticSummary,
2559 session: Session,
2560) -> Result<()> {
2561 let guest_connection_ids = session
2562 .db()
2563 .await
2564 .update_diagnostic_summary(&message, session.connection_id)
2565 .await?;
2566
2567 broadcast(
2568 Some(session.connection_id),
2569 guest_connection_ids.iter().copied(),
2570 |connection_id| {
2571 session
2572 .peer
2573 .forward_send(session.connection_id, connection_id, message.clone())
2574 },
2575 );
2576
2577 Ok(())
2578}
2579
2580/// Updates other participants with changes to the worktree settings
2581async fn update_worktree_settings(
2582 message: proto::UpdateWorktreeSettings,
2583 session: Session,
2584) -> Result<()> {
2585 let guest_connection_ids = session
2586 .db()
2587 .await
2588 .update_worktree_settings(&message, session.connection_id)
2589 .await?;
2590
2591 broadcast(
2592 Some(session.connection_id),
2593 guest_connection_ids.iter().copied(),
2594 |connection_id| {
2595 session
2596 .peer
2597 .forward_send(session.connection_id, connection_id, message.clone())
2598 },
2599 );
2600
2601 Ok(())
2602}
2603
2604/// Notify other participants that a language server has started.
2605async fn start_language_server(
2606 request: proto::StartLanguageServer,
2607 session: Session,
2608) -> Result<()> {
2609 let guest_connection_ids = session
2610 .db()
2611 .await
2612 .start_language_server(&request, session.connection_id)
2613 .await?;
2614
2615 broadcast(
2616 Some(session.connection_id),
2617 guest_connection_ids.iter().copied(),
2618 |connection_id| {
2619 session
2620 .peer
2621 .forward_send(session.connection_id, connection_id, request.clone())
2622 },
2623 );
2624 Ok(())
2625}
2626
2627/// Notify other participants that a language server has changed.
2628async fn update_language_server(
2629 request: proto::UpdateLanguageServer,
2630 session: Session,
2631) -> Result<()> {
2632 let project_id = ProjectId::from_proto(request.project_id);
2633 let project_connection_ids = session
2634 .db()
2635 .await
2636 .project_connection_ids(project_id, session.connection_id, true)
2637 .await?;
2638 broadcast(
2639 Some(session.connection_id),
2640 project_connection_ids.iter().copied(),
2641 |connection_id| {
2642 session
2643 .peer
2644 .forward_send(session.connection_id, connection_id, request.clone())
2645 },
2646 );
2647 Ok(())
2648}
2649
2650/// forward a project request to the host. These requests should be read only
2651/// as guests are allowed to send them.
2652async fn forward_read_only_project_request<T>(
2653 request: T,
2654 response: Response<T>,
2655 session: UserSession,
2656) -> Result<()>
2657where
2658 T: EntityMessage + RequestMessage,
2659{
2660 let project_id = ProjectId::from_proto(request.remote_entity_id());
2661 let host_connection_id = session
2662 .db()
2663 .await
2664 .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2665 .await?;
2666 let payload = session
2667 .peer
2668 .forward_request(session.connection_id, host_connection_id, request)
2669 .await?;
2670 response.send(payload)?;
2671 Ok(())
2672}
2673
2674/// forward a project request to the host. These requests are disallowed
2675/// for guests.
2676async fn forward_mutating_project_request<T>(
2677 request: T,
2678 response: Response<T>,
2679 session: UserSession,
2680) -> Result<()>
2681where
2682 T: EntityMessage + RequestMessage,
2683{
2684 let project_id = ProjectId::from_proto(request.remote_entity_id());
2685
2686 let host_connection_id = session
2687 .db()
2688 .await
2689 .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
2690 .await?;
2691 let payload = session
2692 .peer
2693 .forward_request(session.connection_id, host_connection_id, request)
2694 .await?;
2695 response.send(payload)?;
2696 Ok(())
2697}
2698
2699/// forward a project request to the host. These requests are disallowed
2700/// for guests.
2701async fn forward_versioned_mutating_project_request<T>(
2702 request: T,
2703 response: Response<T>,
2704 session: UserSession,
2705) -> Result<()>
2706where
2707 T: EntityMessage + RequestMessage + VersionedMessage,
2708{
2709 let project_id = ProjectId::from_proto(request.remote_entity_id());
2710
2711 let host_connection_id = session
2712 .db()
2713 .await
2714 .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
2715 .await?;
2716 if let Some(host_version) = session
2717 .connection_pool()
2718 .await
2719 .connection(host_connection_id)
2720 .map(|c| c.zed_version)
2721 {
2722 if let Some(min_required_version) = request.required_host_version() {
2723 if min_required_version > host_version {
2724 return Err(anyhow!(ErrorCode::RemoteUpgradeRequired
2725 .with_tag("required", &min_required_version.to_string())))?;
2726 }
2727 }
2728 }
2729
2730 let payload = session
2731 .peer
2732 .forward_request(session.connection_id, host_connection_id, request)
2733 .await?;
2734 response.send(payload)?;
2735 Ok(())
2736}
2737
2738/// Notify other participants that a new buffer has been created
2739async fn create_buffer_for_peer(
2740 request: proto::CreateBufferForPeer,
2741 session: Session,
2742) -> Result<()> {
2743 session
2744 .db()
2745 .await
2746 .check_user_is_project_host(
2747 ProjectId::from_proto(request.project_id),
2748 session.connection_id,
2749 )
2750 .await?;
2751 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2752 session
2753 .peer
2754 .forward_send(session.connection_id, peer_id.into(), request)?;
2755 Ok(())
2756}
2757
2758/// Notify other participants that a buffer has been updated. This is
2759/// allowed for guests as long as the update is limited to selections.
2760async fn update_buffer(
2761 request: proto::UpdateBuffer,
2762 response: Response<proto::UpdateBuffer>,
2763 session: Session,
2764) -> Result<()> {
2765 let project_id = ProjectId::from_proto(request.project_id);
2766 let mut capability = Capability::ReadOnly;
2767
2768 for op in request.operations.iter() {
2769 match op.variant {
2770 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2771 Some(_) => capability = Capability::ReadWrite,
2772 }
2773 }
2774
2775 let host = {
2776 let guard = session
2777 .db()
2778 .await
2779 .connections_for_buffer_update(
2780 project_id,
2781 session.principal_id(),
2782 session.connection_id,
2783 capability,
2784 )
2785 .await?;
2786
2787 let (host, guests) = &*guard;
2788
2789 broadcast(
2790 Some(session.connection_id),
2791 guests.clone(),
2792 |connection_id| {
2793 session
2794 .peer
2795 .forward_send(session.connection_id, connection_id, request.clone())
2796 },
2797 );
2798
2799 *host
2800 };
2801
2802 if host != session.connection_id {
2803 session
2804 .peer
2805 .forward_request(session.connection_id, host, request.clone())
2806 .await?;
2807 }
2808
2809 response.send(proto::Ack {})?;
2810 Ok(())
2811}
2812
2813/// Notify other participants that a project has been updated.
2814async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2815 request: T,
2816 session: Session,
2817) -> Result<()> {
2818 let project_id = ProjectId::from_proto(request.remote_entity_id());
2819 let project_connection_ids = session
2820 .db()
2821 .await
2822 .project_connection_ids(project_id, session.connection_id, false)
2823 .await?;
2824
2825 broadcast(
2826 Some(session.connection_id),
2827 project_connection_ids.iter().copied(),
2828 |connection_id| {
2829 session
2830 .peer
2831 .forward_send(session.connection_id, connection_id, request.clone())
2832 },
2833 );
2834 Ok(())
2835}
2836
2837/// Start following another user in a call.
2838async fn follow(
2839 request: proto::Follow,
2840 response: Response<proto::Follow>,
2841 session: UserSession,
2842) -> Result<()> {
2843 let room_id = RoomId::from_proto(request.room_id);
2844 let project_id = request.project_id.map(ProjectId::from_proto);
2845 let leader_id = request
2846 .leader_id
2847 .ok_or_else(|| anyhow!("invalid leader id"))?
2848 .into();
2849 let follower_id = session.connection_id;
2850
2851 session
2852 .db()
2853 .await
2854 .check_room_participants(room_id, leader_id, session.connection_id)
2855 .await?;
2856
2857 let response_payload = session
2858 .peer
2859 .forward_request(session.connection_id, leader_id, request)
2860 .await?;
2861 response.send(response_payload)?;
2862
2863 if let Some(project_id) = project_id {
2864 let room = session
2865 .db()
2866 .await
2867 .follow(room_id, project_id, leader_id, follower_id)
2868 .await?;
2869 room_updated(&room, &session.peer);
2870 }
2871
2872 Ok(())
2873}
2874
2875/// Stop following another user in a call.
2876async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
2877 let room_id = RoomId::from_proto(request.room_id);
2878 let project_id = request.project_id.map(ProjectId::from_proto);
2879 let leader_id = request
2880 .leader_id
2881 .ok_or_else(|| anyhow!("invalid leader id"))?
2882 .into();
2883 let follower_id = session.connection_id;
2884
2885 session
2886 .db()
2887 .await
2888 .check_room_participants(room_id, leader_id, session.connection_id)
2889 .await?;
2890
2891 session
2892 .peer
2893 .forward_send(session.connection_id, leader_id, request)?;
2894
2895 if let Some(project_id) = project_id {
2896 let room = session
2897 .db()
2898 .await
2899 .unfollow(room_id, project_id, leader_id, follower_id)
2900 .await?;
2901 room_updated(&room, &session.peer);
2902 }
2903
2904 Ok(())
2905}
2906
2907/// Notify everyone following you of your current location.
2908async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
2909 let room_id = RoomId::from_proto(request.room_id);
2910 let database = session.db.lock().await;
2911
2912 let connection_ids = if let Some(project_id) = request.project_id {
2913 let project_id = ProjectId::from_proto(project_id);
2914 database
2915 .project_connection_ids(project_id, session.connection_id, true)
2916 .await?
2917 } else {
2918 database
2919 .room_connection_ids(room_id, session.connection_id)
2920 .await?
2921 };
2922
2923 // For now, don't send view update messages back to that view's current leader.
2924 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2925 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2926 _ => None,
2927 });
2928
2929 for connection_id in connection_ids.iter().cloned() {
2930 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2931 session
2932 .peer
2933 .forward_send(session.connection_id, connection_id, request.clone())?;
2934 }
2935 }
2936 Ok(())
2937}
2938
2939/// Get public data about users.
2940async fn get_users(
2941 request: proto::GetUsers,
2942 response: Response<proto::GetUsers>,
2943 session: Session,
2944) -> Result<()> {
2945 let user_ids = request
2946 .user_ids
2947 .into_iter()
2948 .map(UserId::from_proto)
2949 .collect();
2950 let users = session
2951 .db()
2952 .await
2953 .get_users_by_ids(user_ids)
2954 .await?
2955 .into_iter()
2956 .map(|user| proto::User {
2957 id: user.id.to_proto(),
2958 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2959 github_login: user.github_login,
2960 })
2961 .collect();
2962 response.send(proto::UsersResponse { users })?;
2963 Ok(())
2964}
2965
2966/// Search for users (to invite) buy Github login
2967async fn fuzzy_search_users(
2968 request: proto::FuzzySearchUsers,
2969 response: Response<proto::FuzzySearchUsers>,
2970 session: UserSession,
2971) -> Result<()> {
2972 let query = request.query;
2973 let users = match query.len() {
2974 0 => vec![],
2975 1 | 2 => session
2976 .db()
2977 .await
2978 .get_user_by_github_login(&query)
2979 .await?
2980 .into_iter()
2981 .collect(),
2982 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2983 };
2984 let users = users
2985 .into_iter()
2986 .filter(|user| user.id != session.user_id())
2987 .map(|user| proto::User {
2988 id: user.id.to_proto(),
2989 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2990 github_login: user.github_login,
2991 })
2992 .collect();
2993 response.send(proto::UsersResponse { users })?;
2994 Ok(())
2995}
2996
2997/// Send a contact request to another user.
2998async fn request_contact(
2999 request: proto::RequestContact,
3000 response: Response<proto::RequestContact>,
3001 session: UserSession,
3002) -> Result<()> {
3003 let requester_id = session.user_id();
3004 let responder_id = UserId::from_proto(request.responder_id);
3005 if requester_id == responder_id {
3006 return Err(anyhow!("cannot add yourself as a contact"))?;
3007 }
3008
3009 let notifications = session
3010 .db()
3011 .await
3012 .send_contact_request(requester_id, responder_id)
3013 .await?;
3014
3015 // Update outgoing contact requests of requester
3016 let mut update = proto::UpdateContacts::default();
3017 update.outgoing_requests.push(responder_id.to_proto());
3018 for connection_id in session
3019 .connection_pool()
3020 .await
3021 .user_connection_ids(requester_id)
3022 {
3023 session.peer.send(connection_id, update.clone())?;
3024 }
3025
3026 // Update incoming contact requests of responder
3027 let mut update = proto::UpdateContacts::default();
3028 update
3029 .incoming_requests
3030 .push(proto::IncomingContactRequest {
3031 requester_id: requester_id.to_proto(),
3032 });
3033 let connection_pool = session.connection_pool().await;
3034 for connection_id in connection_pool.user_connection_ids(responder_id) {
3035 session.peer.send(connection_id, update.clone())?;
3036 }
3037
3038 send_notifications(&connection_pool, &session.peer, notifications);
3039
3040 response.send(proto::Ack {})?;
3041 Ok(())
3042}
3043
3044/// Accept or decline a contact request
3045async fn respond_to_contact_request(
3046 request: proto::RespondToContactRequest,
3047 response: Response<proto::RespondToContactRequest>,
3048 session: UserSession,
3049) -> Result<()> {
3050 let responder_id = session.user_id();
3051 let requester_id = UserId::from_proto(request.requester_id);
3052 let db = session.db().await;
3053 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
3054 db.dismiss_contact_notification(responder_id, requester_id)
3055 .await?;
3056 } else {
3057 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
3058
3059 let notifications = db
3060 .respond_to_contact_request(responder_id, requester_id, accept)
3061 .await?;
3062 let requester_busy = db.is_user_busy(requester_id).await?;
3063 let responder_busy = db.is_user_busy(responder_id).await?;
3064
3065 let pool = session.connection_pool().await;
3066 // Update responder with new contact
3067 let mut update = proto::UpdateContacts::default();
3068 if accept {
3069 update
3070 .contacts
3071 .push(contact_for_user(requester_id, requester_busy, &pool));
3072 }
3073 update
3074 .remove_incoming_requests
3075 .push(requester_id.to_proto());
3076 for connection_id in pool.user_connection_ids(responder_id) {
3077 session.peer.send(connection_id, update.clone())?;
3078 }
3079
3080 // Update requester with new contact
3081 let mut update = proto::UpdateContacts::default();
3082 if accept {
3083 update
3084 .contacts
3085 .push(contact_for_user(responder_id, responder_busy, &pool));
3086 }
3087 update
3088 .remove_outgoing_requests
3089 .push(responder_id.to_proto());
3090
3091 for connection_id in pool.user_connection_ids(requester_id) {
3092 session.peer.send(connection_id, update.clone())?;
3093 }
3094
3095 send_notifications(&pool, &session.peer, notifications);
3096 }
3097
3098 response.send(proto::Ack {})?;
3099 Ok(())
3100}
3101
3102/// Remove a contact.
3103async fn remove_contact(
3104 request: proto::RemoveContact,
3105 response: Response<proto::RemoveContact>,
3106 session: UserSession,
3107) -> Result<()> {
3108 let requester_id = session.user_id();
3109 let responder_id = UserId::from_proto(request.user_id);
3110 let db = session.db().await;
3111 let (contact_accepted, deleted_notification_id) =
3112 db.remove_contact(requester_id, responder_id).await?;
3113
3114 let pool = session.connection_pool().await;
3115 // Update outgoing contact requests of requester
3116 let mut update = proto::UpdateContacts::default();
3117 if contact_accepted {
3118 update.remove_contacts.push(responder_id.to_proto());
3119 } else {
3120 update
3121 .remove_outgoing_requests
3122 .push(responder_id.to_proto());
3123 }
3124 for connection_id in pool.user_connection_ids(requester_id) {
3125 session.peer.send(connection_id, update.clone())?;
3126 }
3127
3128 // Update incoming contact requests of responder
3129 let mut update = proto::UpdateContacts::default();
3130 if contact_accepted {
3131 update.remove_contacts.push(requester_id.to_proto());
3132 } else {
3133 update
3134 .remove_incoming_requests
3135 .push(requester_id.to_proto());
3136 }
3137 for connection_id in pool.user_connection_ids(responder_id) {
3138 session.peer.send(connection_id, update.clone())?;
3139 if let Some(notification_id) = deleted_notification_id {
3140 session.peer.send(
3141 connection_id,
3142 proto::DeleteNotification {
3143 notification_id: notification_id.to_proto(),
3144 },
3145 )?;
3146 }
3147 }
3148
3149 response.send(proto::Ack {})?;
3150 Ok(())
3151}
3152
3153/// Creates a new channel.
3154async fn create_channel(
3155 request: proto::CreateChannel,
3156 response: Response<proto::CreateChannel>,
3157 session: UserSession,
3158) -> Result<()> {
3159 let db = session.db().await;
3160
3161 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
3162 let (channel, membership) = db
3163 .create_channel(&request.name, parent_id, session.user_id())
3164 .await?;
3165
3166 let root_id = channel.root_id();
3167 let channel = Channel::from_model(channel);
3168
3169 response.send(proto::CreateChannelResponse {
3170 channel: Some(channel.to_proto()),
3171 parent_id: request.parent_id,
3172 })?;
3173
3174 let mut connection_pool = session.connection_pool().await;
3175 if let Some(membership) = membership {
3176 connection_pool.subscribe_to_channel(
3177 membership.user_id,
3178 membership.channel_id,
3179 membership.role,
3180 );
3181 let update = proto::UpdateUserChannels {
3182 channel_memberships: vec![proto::ChannelMembership {
3183 channel_id: membership.channel_id.to_proto(),
3184 role: membership.role.into(),
3185 }],
3186 ..Default::default()
3187 };
3188 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3189 session.peer.send(connection_id, update.clone())?;
3190 }
3191 }
3192
3193 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3194 if !role.can_see_channel(channel.visibility) {
3195 continue;
3196 }
3197
3198 let update = proto::UpdateChannels {
3199 channels: vec![channel.to_proto()],
3200 ..Default::default()
3201 };
3202 session.peer.send(connection_id, update.clone())?;
3203 }
3204
3205 Ok(())
3206}
3207
3208/// Delete a channel
3209async fn delete_channel(
3210 request: proto::DeleteChannel,
3211 response: Response<proto::DeleteChannel>,
3212 session: UserSession,
3213) -> Result<()> {
3214 let db = session.db().await;
3215
3216 let channel_id = request.channel_id;
3217 let (root_channel, removed_channels) = db
3218 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3219 .await?;
3220 response.send(proto::Ack {})?;
3221
3222 // Notify members of removed channels
3223 let mut update = proto::UpdateChannels::default();
3224 update
3225 .delete_channels
3226 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3227
3228 let connection_pool = session.connection_pool().await;
3229 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3230 session.peer.send(connection_id, update.clone())?;
3231 }
3232
3233 Ok(())
3234}
3235
3236/// Invite someone to join a channel.
3237async fn invite_channel_member(
3238 request: proto::InviteChannelMember,
3239 response: Response<proto::InviteChannelMember>,
3240 session: UserSession,
3241) -> Result<()> {
3242 let db = session.db().await;
3243 let channel_id = ChannelId::from_proto(request.channel_id);
3244 let invitee_id = UserId::from_proto(request.user_id);
3245 let InviteMemberResult {
3246 channel,
3247 notifications,
3248 } = db
3249 .invite_channel_member(
3250 channel_id,
3251 invitee_id,
3252 session.user_id(),
3253 request.role().into(),
3254 )
3255 .await?;
3256
3257 let update = proto::UpdateChannels {
3258 channel_invitations: vec![channel.to_proto()],
3259 ..Default::default()
3260 };
3261
3262 let connection_pool = session.connection_pool().await;
3263 for connection_id in connection_pool.user_connection_ids(invitee_id) {
3264 session.peer.send(connection_id, update.clone())?;
3265 }
3266
3267 send_notifications(&connection_pool, &session.peer, notifications);
3268
3269 response.send(proto::Ack {})?;
3270 Ok(())
3271}
3272
3273/// remove someone from a channel
3274async fn remove_channel_member(
3275 request: proto::RemoveChannelMember,
3276 response: Response<proto::RemoveChannelMember>,
3277 session: UserSession,
3278) -> Result<()> {
3279 let db = session.db().await;
3280 let channel_id = ChannelId::from_proto(request.channel_id);
3281 let member_id = UserId::from_proto(request.user_id);
3282
3283 let RemoveChannelMemberResult {
3284 membership_update,
3285 notification_id,
3286 } = db
3287 .remove_channel_member(channel_id, member_id, session.user_id())
3288 .await?;
3289
3290 let mut connection_pool = session.connection_pool().await;
3291 notify_membership_updated(
3292 &mut connection_pool,
3293 membership_update,
3294 member_id,
3295 &session.peer,
3296 );
3297 for connection_id in connection_pool.user_connection_ids(member_id) {
3298 if let Some(notification_id) = notification_id {
3299 session
3300 .peer
3301 .send(
3302 connection_id,
3303 proto::DeleteNotification {
3304 notification_id: notification_id.to_proto(),
3305 },
3306 )
3307 .trace_err();
3308 }
3309 }
3310
3311 response.send(proto::Ack {})?;
3312 Ok(())
3313}
3314
3315/// Toggle the channel between public and private.
3316/// Care is taken to maintain the invariant that public channels only descend from public channels,
3317/// (though members-only channels can appear at any point in the hierarchy).
3318async fn set_channel_visibility(
3319 request: proto::SetChannelVisibility,
3320 response: Response<proto::SetChannelVisibility>,
3321 session: UserSession,
3322) -> Result<()> {
3323 let db = session.db().await;
3324 let channel_id = ChannelId::from_proto(request.channel_id);
3325 let visibility = request.visibility().into();
3326
3327 let channel_model = db
3328 .set_channel_visibility(channel_id, visibility, session.user_id())
3329 .await?;
3330 let root_id = channel_model.root_id();
3331 let channel = Channel::from_model(channel_model);
3332
3333 let mut connection_pool = session.connection_pool().await;
3334 for (user_id, role) in connection_pool
3335 .channel_user_ids(root_id)
3336 .collect::<Vec<_>>()
3337 .into_iter()
3338 {
3339 let update = if role.can_see_channel(channel.visibility) {
3340 connection_pool.subscribe_to_channel(user_id, channel_id, role);
3341 proto::UpdateChannels {
3342 channels: vec![channel.to_proto()],
3343 ..Default::default()
3344 }
3345 } else {
3346 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3347 proto::UpdateChannels {
3348 delete_channels: vec![channel.id.to_proto()],
3349 ..Default::default()
3350 }
3351 };
3352
3353 for connection_id in connection_pool.user_connection_ids(user_id) {
3354 session.peer.send(connection_id, update.clone())?;
3355 }
3356 }
3357
3358 response.send(proto::Ack {})?;
3359 Ok(())
3360}
3361
3362/// Alter the role for a user in the channel.
3363async fn set_channel_member_role(
3364 request: proto::SetChannelMemberRole,
3365 response: Response<proto::SetChannelMemberRole>,
3366 session: UserSession,
3367) -> Result<()> {
3368 let db = session.db().await;
3369 let channel_id = ChannelId::from_proto(request.channel_id);
3370 let member_id = UserId::from_proto(request.user_id);
3371 let result = db
3372 .set_channel_member_role(
3373 channel_id,
3374 session.user_id(),
3375 member_id,
3376 request.role().into(),
3377 )
3378 .await?;
3379
3380 match result {
3381 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3382 let mut connection_pool = session.connection_pool().await;
3383 notify_membership_updated(
3384 &mut connection_pool,
3385 membership_update,
3386 member_id,
3387 &session.peer,
3388 )
3389 }
3390 db::SetMemberRoleResult::InviteUpdated(channel) => {
3391 let update = proto::UpdateChannels {
3392 channel_invitations: vec![channel.to_proto()],
3393 ..Default::default()
3394 };
3395
3396 for connection_id in session
3397 .connection_pool()
3398 .await
3399 .user_connection_ids(member_id)
3400 {
3401 session.peer.send(connection_id, update.clone())?;
3402 }
3403 }
3404 }
3405
3406 response.send(proto::Ack {})?;
3407 Ok(())
3408}
3409
3410/// Change the name of a channel
3411async fn rename_channel(
3412 request: proto::RenameChannel,
3413 response: Response<proto::RenameChannel>,
3414 session: UserSession,
3415) -> Result<()> {
3416 let db = session.db().await;
3417 let channel_id = ChannelId::from_proto(request.channel_id);
3418 let channel_model = db
3419 .rename_channel(channel_id, session.user_id(), &request.name)
3420 .await?;
3421 let root_id = channel_model.root_id();
3422 let channel = Channel::from_model(channel_model);
3423
3424 response.send(proto::RenameChannelResponse {
3425 channel: Some(channel.to_proto()),
3426 })?;
3427
3428 let connection_pool = session.connection_pool().await;
3429 let update = proto::UpdateChannels {
3430 channels: vec![channel.to_proto()],
3431 ..Default::default()
3432 };
3433 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3434 if role.can_see_channel(channel.visibility) {
3435 session.peer.send(connection_id, update.clone())?;
3436 }
3437 }
3438
3439 Ok(())
3440}
3441
3442/// Move a channel to a new parent.
3443async fn move_channel(
3444 request: proto::MoveChannel,
3445 response: Response<proto::MoveChannel>,
3446 session: UserSession,
3447) -> Result<()> {
3448 let channel_id = ChannelId::from_proto(request.channel_id);
3449 let to = ChannelId::from_proto(request.to);
3450
3451 let (root_id, channels) = session
3452 .db()
3453 .await
3454 .move_channel(channel_id, to, session.user_id())
3455 .await?;
3456
3457 let connection_pool = session.connection_pool().await;
3458 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3459 let channels = channels
3460 .iter()
3461 .filter_map(|channel| {
3462 if role.can_see_channel(channel.visibility) {
3463 Some(channel.to_proto())
3464 } else {
3465 None
3466 }
3467 })
3468 .collect::<Vec<_>>();
3469 if channels.is_empty() {
3470 continue;
3471 }
3472
3473 let update = proto::UpdateChannels {
3474 channels,
3475 ..Default::default()
3476 };
3477
3478 session.peer.send(connection_id, update.clone())?;
3479 }
3480
3481 response.send(Ack {})?;
3482 Ok(())
3483}
3484
3485/// Get the list of channel members
3486async fn get_channel_members(
3487 request: proto::GetChannelMembers,
3488 response: Response<proto::GetChannelMembers>,
3489 session: UserSession,
3490) -> Result<()> {
3491 let db = session.db().await;
3492 let channel_id = ChannelId::from_proto(request.channel_id);
3493 let members = db
3494 .get_channel_participant_details(channel_id, session.user_id())
3495 .await?;
3496 response.send(proto::GetChannelMembersResponse { members })?;
3497 Ok(())
3498}
3499
3500/// Accept or decline a channel invitation.
3501async fn respond_to_channel_invite(
3502 request: proto::RespondToChannelInvite,
3503 response: Response<proto::RespondToChannelInvite>,
3504 session: UserSession,
3505) -> Result<()> {
3506 let db = session.db().await;
3507 let channel_id = ChannelId::from_proto(request.channel_id);
3508 let RespondToChannelInvite {
3509 membership_update,
3510 notifications,
3511 } = db
3512 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3513 .await?;
3514
3515 let mut connection_pool = session.connection_pool().await;
3516 if let Some(membership_update) = membership_update {
3517 notify_membership_updated(
3518 &mut connection_pool,
3519 membership_update,
3520 session.user_id(),
3521 &session.peer,
3522 );
3523 } else {
3524 let update = proto::UpdateChannels {
3525 remove_channel_invitations: vec![channel_id.to_proto()],
3526 ..Default::default()
3527 };
3528
3529 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3530 session.peer.send(connection_id, update.clone())?;
3531 }
3532 };
3533
3534 send_notifications(&connection_pool, &session.peer, notifications);
3535
3536 response.send(proto::Ack {})?;
3537
3538 Ok(())
3539}
3540
3541/// Join the channels' room
3542async fn join_channel(
3543 request: proto::JoinChannel,
3544 response: Response<proto::JoinChannel>,
3545 session: UserSession,
3546) -> Result<()> {
3547 let channel_id = ChannelId::from_proto(request.channel_id);
3548 join_channel_internal(channel_id, Box::new(response), session).await
3549}
3550
/// A response handle that can complete a channel join.
///
/// Both `JoinChannel` and `JoinRoom` requests are answered with a
/// `proto::JoinRoomResponse`. This trait lets `join_channel_internal` send
/// the reply without knowing which of the two requests it is serving.
trait JoinChannelInternalResponse {
    /// Send `result` as the reply to the originating request.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// Forwards to the inherent `Response::<proto::JoinChannel>::send`.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
// Forwards to the inherent `Response::<proto::JoinRoom>::send`.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3564
/// Join the room of `channel_id` on behalf of the session's user.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, leave
///    that room first.
/// 2. Join the channel's room in the database.
/// 3. Reply with the room, the channel id, and (when a LiveKit client is
///    configured and a token can be minted) LiveKit connection info.
/// 4. Broadcast any membership change and the room update.
/// 5. Broadcast the channel update, then refresh the user's contacts.
///
/// The database handle and the connection-pool guard taken inside the first
/// block are released when that block ends, before the pool is taken again
/// for `channel_updated`.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: UserSession,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database handle while leaving the stale room, then
            // re-acquire it. Presumably `leave_room_for_session` needs the
            // database itself — TODO confirm.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // Guests get a guest token and may not publish; everyone else gets a
        // room token. A token error (via `trace_err()?`) yields `None` here, so
        // the join still succeeds without LiveKit info.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        // NOTE(review): the reply is sent before the "channel not returned"
        // check below, so a client can receive a successful response even if
        // that check later fails — confirm this is intended.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3655
3656/// Start editing the channel notes
3657async fn join_channel_buffer(
3658 request: proto::JoinChannelBuffer,
3659 response: Response<proto::JoinChannelBuffer>,
3660 session: UserSession,
3661) -> Result<()> {
3662 let db = session.db().await;
3663 let channel_id = ChannelId::from_proto(request.channel_id);
3664
3665 let open_response = db
3666 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3667 .await?;
3668
3669 let collaborators = open_response.collaborators.clone();
3670 response.send(open_response)?;
3671
3672 let update = UpdateChannelBufferCollaborators {
3673 channel_id: channel_id.to_proto(),
3674 collaborators: collaborators.clone(),
3675 };
3676 channel_buffer_updated(
3677 session.connection_id,
3678 collaborators
3679 .iter()
3680 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3681 &update,
3682 &session.peer,
3683 );
3684
3685 Ok(())
3686}
3687
3688/// Edit the channel notes
3689async fn update_channel_buffer(
3690 request: proto::UpdateChannelBuffer,
3691 session: UserSession,
3692) -> Result<()> {
3693 let db = session.db().await;
3694 let channel_id = ChannelId::from_proto(request.channel_id);
3695
3696 let (collaborators, non_collaborators, epoch, version) = db
3697 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3698 .await?;
3699
3700 channel_buffer_updated(
3701 session.connection_id,
3702 collaborators,
3703 &proto::UpdateChannelBuffer {
3704 channel_id: channel_id.to_proto(),
3705 operations: request.operations,
3706 },
3707 &session.peer,
3708 );
3709
3710 let pool = &*session.connection_pool().await;
3711
3712 broadcast(
3713 None,
3714 non_collaborators
3715 .iter()
3716 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3717 |peer_id| {
3718 session.peer.send(
3719 peer_id,
3720 proto::UpdateChannels {
3721 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3722 channel_id: channel_id.to_proto(),
3723 epoch: epoch as u64,
3724 version: version.clone(),
3725 }],
3726 ..Default::default()
3727 },
3728 )
3729 },
3730 );
3731
3732 Ok(())
3733}
3734
3735/// Rejoin the channel notes after a connection blip
3736async fn rejoin_channel_buffers(
3737 request: proto::RejoinChannelBuffers,
3738 response: Response<proto::RejoinChannelBuffers>,
3739 session: UserSession,
3740) -> Result<()> {
3741 let db = session.db().await;
3742 let buffers = db
3743 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3744 .await?;
3745
3746 for rejoined_buffer in &buffers {
3747 let collaborators_to_notify = rejoined_buffer
3748 .buffer
3749 .collaborators
3750 .iter()
3751 .filter_map(|c| Some(c.peer_id?.into()));
3752 channel_buffer_updated(
3753 session.connection_id,
3754 collaborators_to_notify,
3755 &proto::UpdateChannelBufferCollaborators {
3756 channel_id: rejoined_buffer.buffer.channel_id,
3757 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3758 },
3759 &session.peer,
3760 );
3761 }
3762
3763 response.send(proto::RejoinChannelBuffersResponse {
3764 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3765 })?;
3766
3767 Ok(())
3768}
3769
3770/// Stop editing the channel notes
3771async fn leave_channel_buffer(
3772 request: proto::LeaveChannelBuffer,
3773 response: Response<proto::LeaveChannelBuffer>,
3774 session: UserSession,
3775) -> Result<()> {
3776 let db = session.db().await;
3777 let channel_id = ChannelId::from_proto(request.channel_id);
3778
3779 let left_buffer = db
3780 .leave_channel_buffer(channel_id, session.connection_id)
3781 .await?;
3782
3783 response.send(Ack {})?;
3784
3785 channel_buffer_updated(
3786 session.connection_id,
3787 left_buffer.connections,
3788 &proto::UpdateChannelBufferCollaborators {
3789 channel_id: channel_id.to_proto(),
3790 collaborators: left_buffer.collaborators,
3791 },
3792 &session.peer,
3793 );
3794
3795 Ok(())
3796}
3797
3798fn channel_buffer_updated<T: EnvelopedMessage>(
3799 sender_id: ConnectionId,
3800 collaborators: impl IntoIterator<Item = ConnectionId>,
3801 message: &T,
3802 peer: &Peer,
3803) {
3804 broadcast(Some(sender_id), collaborators, |peer_id| {
3805 peer.send(peer_id, message.clone())
3806 });
3807}
3808
3809fn send_notifications(
3810 connection_pool: &ConnectionPool,
3811 peer: &Peer,
3812 notifications: db::NotificationBatch,
3813) {
3814 for (user_id, notification) in notifications {
3815 for connection_id in connection_pool.user_connection_ids(user_id) {
3816 if let Err(error) = peer.send(
3817 connection_id,
3818 proto::AddNotification {
3819 notification: Some(notification.clone()),
3820 },
3821 ) {
3822 tracing::error!(
3823 "failed to send notification to {:?} {}",
3824 connection_id,
3825 error
3826 );
3827 }
3828 }
3829 }
3830}
3831
3832/// Send a message to the channel
3833async fn send_channel_message(
3834 request: proto::SendChannelMessage,
3835 response: Response<proto::SendChannelMessage>,
3836 session: UserSession,
3837) -> Result<()> {
3838 // Validate the message body.
3839 let body = request.body.trim().to_string();
3840 if body.len() > MAX_MESSAGE_LEN {
3841 return Err(anyhow!("message is too long"))?;
3842 }
3843 if body.is_empty() {
3844 return Err(anyhow!("message can't be blank"))?;
3845 }
3846
3847 // TODO: adjust mentions if body is trimmed
3848
3849 let timestamp = OffsetDateTime::now_utc();
3850 let nonce = request
3851 .nonce
3852 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3853
3854 let channel_id = ChannelId::from_proto(request.channel_id);
3855 let CreatedChannelMessage {
3856 message_id,
3857 participant_connection_ids,
3858 channel_members,
3859 notifications,
3860 } = session
3861 .db()
3862 .await
3863 .create_channel_message(
3864 channel_id,
3865 session.user_id(),
3866 &body,
3867 &request.mentions,
3868 timestamp,
3869 nonce.clone().into(),
3870 match request.reply_to_message_id {
3871 Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
3872 None => None,
3873 },
3874 )
3875 .await?;
3876
3877 let message = proto::ChannelMessage {
3878 sender_id: session.user_id().to_proto(),
3879 id: message_id.to_proto(),
3880 body,
3881 mentions: request.mentions,
3882 timestamp: timestamp.unix_timestamp() as u64,
3883 nonce: Some(nonce),
3884 reply_to_message_id: request.reply_to_message_id,
3885 edited_at: None,
3886 };
3887 broadcast(
3888 Some(session.connection_id),
3889 participant_connection_ids,
3890 |connection| {
3891 session.peer.send(
3892 connection,
3893 proto::ChannelMessageSent {
3894 channel_id: channel_id.to_proto(),
3895 message: Some(message.clone()),
3896 },
3897 )
3898 },
3899 );
3900 response.send(proto::SendChannelMessageResponse {
3901 message: Some(message),
3902 })?;
3903
3904 let pool = &*session.connection_pool().await;
3905 broadcast(
3906 None,
3907 channel_members
3908 .iter()
3909 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3910 |peer_id| {
3911 session.peer.send(
3912 peer_id,
3913 proto::UpdateChannels {
3914 latest_channel_message_ids: vec![proto::ChannelMessageId {
3915 channel_id: channel_id.to_proto(),
3916 message_id: message_id.to_proto(),
3917 }],
3918 ..Default::default()
3919 },
3920 )
3921 },
3922 );
3923 send_notifications(pool, &session.peer, notifications);
3924
3925 Ok(())
3926}
3927
3928/// Delete a channel message
3929async fn remove_channel_message(
3930 request: proto::RemoveChannelMessage,
3931 response: Response<proto::RemoveChannelMessage>,
3932 session: UserSession,
3933) -> Result<()> {
3934 let channel_id = ChannelId::from_proto(request.channel_id);
3935 let message_id = MessageId::from_proto(request.message_id);
3936 let (connection_ids, existing_notification_ids) = session
3937 .db()
3938 .await
3939 .remove_channel_message(channel_id, message_id, session.user_id())
3940 .await?;
3941
3942 broadcast(
3943 Some(session.connection_id),
3944 connection_ids,
3945 move |connection| {
3946 session.peer.send(connection, request.clone())?;
3947
3948 for notification_id in &existing_notification_ids {
3949 session.peer.send(
3950 connection,
3951 proto::DeleteNotification {
3952 notification_id: (*notification_id).to_proto(),
3953 },
3954 )?;
3955 }
3956
3957 Ok(())
3958 },
3959 );
3960 response.send(proto::Ack {})?;
3961 Ok(())
3962}
3963
3964async fn update_channel_message(
3965 request: proto::UpdateChannelMessage,
3966 response: Response<proto::UpdateChannelMessage>,
3967 session: UserSession,
3968) -> Result<()> {
3969 let channel_id = ChannelId::from_proto(request.channel_id);
3970 let message_id = MessageId::from_proto(request.message_id);
3971 let updated_at = OffsetDateTime::now_utc();
3972 let UpdatedChannelMessage {
3973 message_id,
3974 participant_connection_ids,
3975 notifications,
3976 reply_to_message_id,
3977 timestamp,
3978 deleted_mention_notification_ids,
3979 updated_mention_notifications,
3980 } = session
3981 .db()
3982 .await
3983 .update_channel_message(
3984 channel_id,
3985 message_id,
3986 session.user_id(),
3987 request.body.as_str(),
3988 &request.mentions,
3989 updated_at,
3990 )
3991 .await?;
3992
3993 let nonce = request
3994 .nonce
3995 .clone()
3996 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3997
3998 let message = proto::ChannelMessage {
3999 sender_id: session.user_id().to_proto(),
4000 id: message_id.to_proto(),
4001 body: request.body.clone(),
4002 mentions: request.mentions.clone(),
4003 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
4004 nonce: Some(nonce),
4005 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
4006 edited_at: Some(updated_at.unix_timestamp() as u64),
4007 };
4008
4009 response.send(proto::Ack {})?;
4010
4011 let pool = &*session.connection_pool().await;
4012 broadcast(
4013 Some(session.connection_id),
4014 participant_connection_ids,
4015 |connection| {
4016 session.peer.send(
4017 connection,
4018 proto::ChannelMessageUpdate {
4019 channel_id: channel_id.to_proto(),
4020 message: Some(message.clone()),
4021 },
4022 )?;
4023
4024 for notification_id in &deleted_mention_notification_ids {
4025 session.peer.send(
4026 connection,
4027 proto::DeleteNotification {
4028 notification_id: (*notification_id).to_proto(),
4029 },
4030 )?;
4031 }
4032
4033 for notification in &updated_mention_notifications {
4034 session.peer.send(
4035 connection,
4036 proto::UpdateNotification {
4037 notification: Some(notification.clone()),
4038 },
4039 )?;
4040 }
4041
4042 Ok(())
4043 },
4044 );
4045
4046 send_notifications(pool, &session.peer, notifications);
4047
4048 Ok(())
4049}
4050
4051/// Mark a channel message as read
4052async fn acknowledge_channel_message(
4053 request: proto::AckChannelMessage,
4054 session: UserSession,
4055) -> Result<()> {
4056 let channel_id = ChannelId::from_proto(request.channel_id);
4057 let message_id = MessageId::from_proto(request.message_id);
4058 let notifications = session
4059 .db()
4060 .await
4061 .observe_channel_message(channel_id, session.user_id(), message_id)
4062 .await?;
4063 send_notifications(
4064 &*session.connection_pool().await,
4065 &session.peer,
4066 notifications,
4067 );
4068 Ok(())
4069}
4070
4071/// Mark a buffer version as synced
4072async fn acknowledge_buffer_version(
4073 request: proto::AckBufferOperation,
4074 session: UserSession,
4075) -> Result<()> {
4076 let buffer_id = BufferId::from_proto(request.buffer_id);
4077 session
4078 .db()
4079 .await
4080 .observe_buffer_version(
4081 buffer_id,
4082 session.user_id(),
4083 request.epoch as i32,
4084 &request.version,
4085 )
4086 .await?;
4087 Ok(())
4088}
4089
4090struct CompleteWithLanguageModelRateLimit;
4091
4092impl RateLimit for CompleteWithLanguageModelRateLimit {
4093 fn capacity() -> usize {
4094 std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4095 .ok()
4096 .and_then(|v| v.parse().ok())
4097 .unwrap_or(120) // Picked arbitrarily
4098 }
4099
4100 fn refill_duration() -> chrono::Duration {
4101 chrono::Duration::hours(1)
4102 }
4103
4104 fn db_name() -> &'static str {
4105 "complete-with-language-model"
4106 }
4107}
4108
4109async fn complete_with_language_model(
4110 request: proto::CompleteWithLanguageModel,
4111 response: StreamingResponse<proto::CompleteWithLanguageModel>,
4112 session: Session,
4113 open_ai_api_key: Option<Arc<str>>,
4114 google_ai_api_key: Option<Arc<str>>,
4115 anthropic_api_key: Option<Arc<str>>,
4116) -> Result<()> {
4117 let Some(session) = session.for_user() else {
4118 return Err(anyhow!("user not found"))?;
4119 };
4120 authorize_access_to_language_models(&session).await?;
4121 session
4122 .rate_limiter
4123 .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
4124 .await?;
4125
4126 if request.model.starts_with("gpt") {
4127 let api_key =
4128 open_ai_api_key.ok_or_else(|| anyhow!("no OpenAI API key configured on the server"))?;
4129 complete_with_open_ai(request, response, session, api_key).await?;
4130 } else if request.model.starts_with("gemini") {
4131 let api_key = google_ai_api_key
4132 .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
4133 complete_with_google_ai(request, response, session, api_key).await?;
4134 } else if request.model.starts_with("claude") {
4135 let api_key = anthropic_api_key
4136 .ok_or_else(|| anyhow!("no Anthropic AI API key configured on the server"))?;
4137 complete_with_anthropic(request, response, session, api_key).await?;
4138 }
4139
4140 Ok(())
4141}
4142
4143async fn complete_with_open_ai(
4144 request: proto::CompleteWithLanguageModel,
4145 response: StreamingResponse<proto::CompleteWithLanguageModel>,
4146 session: UserSession,
4147 api_key: Arc<str>,
4148) -> Result<()> {
4149 let mut completion_stream = open_ai::stream_completion(
4150 &session.http_client,
4151 OPEN_AI_API_URL,
4152 &api_key,
4153 crate::ai::language_model_request_to_open_ai(request)?,
4154 )
4155 .await
4156 .context("open_ai::stream_completion request failed within collab")?;
4157
4158 while let Some(event) = completion_stream.next().await {
4159 let event = event?;
4160 response.send(proto::LanguageModelResponse {
4161 choices: event
4162 .choices
4163 .into_iter()
4164 .map(|choice| proto::LanguageModelChoiceDelta {
4165 index: choice.index,
4166 delta: Some(proto::LanguageModelResponseMessage {
4167 role: choice.delta.role.map(|role| match role {
4168 open_ai::Role::User => LanguageModelRole::LanguageModelUser,
4169 open_ai::Role::Assistant => LanguageModelRole::LanguageModelAssistant,
4170 open_ai::Role::System => LanguageModelRole::LanguageModelSystem,
4171 open_ai::Role::Tool => LanguageModelRole::LanguageModelTool,
4172 } as i32),
4173 content: choice.delta.content,
4174 tool_calls: choice
4175 .delta
4176 .tool_calls
4177 .into_iter()
4178 .map(|delta| proto::ToolCallDelta {
4179 index: delta.index as u32,
4180 id: delta.id,
4181 variant: match delta.function {
4182 Some(function) => {
4183 let name = function.name;
4184 let arguments = function.arguments;
4185
4186 Some(proto::tool_call_delta::Variant::Function(
4187 proto::tool_call_delta::FunctionCallDelta {
4188 name,
4189 arguments,
4190 },
4191 ))
4192 }
4193 None => None,
4194 },
4195 })
4196 .collect(),
4197 }),
4198 finish_reason: choice.finish_reason,
4199 })
4200 .collect(),
4201 })?;
4202 }
4203
4204 Ok(())
4205}
4206
4207async fn complete_with_google_ai(
4208 request: proto::CompleteWithLanguageModel,
4209 response: StreamingResponse<proto::CompleteWithLanguageModel>,
4210 session: UserSession,
4211 api_key: Arc<str>,
4212) -> Result<()> {
4213 let mut stream = google_ai::stream_generate_content(
4214 &session.http_client,
4215 google_ai::API_URL,
4216 api_key.as_ref(),
4217 crate::ai::language_model_request_to_google_ai(request)?,
4218 )
4219 .await
4220 .context("google_ai::stream_generate_content request failed")?;
4221
4222 while let Some(event) = stream.next().await {
4223 let event = event?;
4224 response.send(proto::LanguageModelResponse {
4225 choices: event
4226 .candidates
4227 .unwrap_or_default()
4228 .into_iter()
4229 .map(|candidate| proto::LanguageModelChoiceDelta {
4230 index: candidate.index as u32,
4231 delta: Some(proto::LanguageModelResponseMessage {
4232 role: Some(match candidate.content.role {
4233 google_ai::Role::User => LanguageModelRole::LanguageModelUser,
4234 google_ai::Role::Model => LanguageModelRole::LanguageModelAssistant,
4235 } as i32),
4236 content: Some(
4237 candidate
4238 .content
4239 .parts
4240 .into_iter()
4241 .filter_map(|part| match part {
4242 google_ai::Part::TextPart(part) => Some(part.text),
4243 google_ai::Part::InlineDataPart(_) => None,
4244 })
4245 .collect(),
4246 ),
4247 // Tool calls are not supported for Google
4248 tool_calls: Vec::new(),
4249 }),
4250 finish_reason: candidate.finish_reason.map(|reason| reason.to_string()),
4251 })
4252 .collect(),
4253 })?;
4254 }
4255
4256 Ok(())
4257}
4258
4259async fn complete_with_anthropic(
4260 request: proto::CompleteWithLanguageModel,
4261 response: StreamingResponse<proto::CompleteWithLanguageModel>,
4262 session: UserSession,
4263 api_key: Arc<str>,
4264) -> Result<()> {
4265 let model = anthropic::Model::from_id(&request.model)?;
4266
4267 let mut system_message = String::new();
4268 let messages = request
4269 .messages
4270 .into_iter()
4271 .filter_map(|message| {
4272 match message.role() {
4273 LanguageModelRole::LanguageModelUser => Some(anthropic::RequestMessage {
4274 role: anthropic::Role::User,
4275 content: message.content,
4276 }),
4277 LanguageModelRole::LanguageModelAssistant => Some(anthropic::RequestMessage {
4278 role: anthropic::Role::Assistant,
4279 content: message.content,
4280 }),
4281 // Anthropic's API breaks system instructions out as a separate field rather
4282 // than having a system message role.
4283 LanguageModelRole::LanguageModelSystem => {
4284 if !system_message.is_empty() {
4285 system_message.push_str("\n\n");
4286 }
4287 system_message.push_str(&message.content);
4288
4289 None
4290 }
4291 // We don't yet support tool calls for Anthropic
4292 LanguageModelRole::LanguageModelTool => None,
4293 }
4294 })
4295 .collect();
4296
4297 let mut stream = anthropic::stream_completion(
4298 &session.http_client,
4299 "https://api.anthropic.com",
4300 &api_key,
4301 anthropic::Request {
4302 model,
4303 messages,
4304 stream: true,
4305 system: system_message,
4306 max_tokens: 4092,
4307 },
4308 )
4309 .await?;
4310
4311 let mut current_role = proto::LanguageModelRole::LanguageModelAssistant;
4312
4313 while let Some(event) = stream.next().await {
4314 let event = event?;
4315
4316 match event {
4317 anthropic::ResponseEvent::MessageStart { message } => {
4318 if let Some(role) = message.role {
4319 if role == "assistant" {
4320 current_role = proto::LanguageModelRole::LanguageModelAssistant;
4321 } else if role == "user" {
4322 current_role = proto::LanguageModelRole::LanguageModelUser;
4323 }
4324 }
4325 }
4326 anthropic::ResponseEvent::ContentBlockStart { content_block, .. } => {
4327 match content_block {
4328 anthropic::ContentBlock::Text { text } => {
4329 if !text.is_empty() {
4330 response.send(proto::LanguageModelResponse {
4331 choices: vec![proto::LanguageModelChoiceDelta {
4332 index: 0,
4333 delta: Some(proto::LanguageModelResponseMessage {
4334 role: Some(current_role as i32),
4335 content: Some(text),
4336 tool_calls: Vec::new(),
4337 }),
4338 finish_reason: None,
4339 }],
4340 })?;
4341 }
4342 }
4343 }
4344 }
4345 anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => match delta {
4346 anthropic::TextDelta::TextDelta { text } => {
4347 response.send(proto::LanguageModelResponse {
4348 choices: vec![proto::LanguageModelChoiceDelta {
4349 index: 0,
4350 delta: Some(proto::LanguageModelResponseMessage {
4351 role: Some(current_role as i32),
4352 content: Some(text),
4353 tool_calls: Vec::new(),
4354 }),
4355 finish_reason: None,
4356 }],
4357 })?;
4358 }
4359 },
4360 anthropic::ResponseEvent::MessageDelta { delta, .. } => {
4361 if let Some(stop_reason) = delta.stop_reason {
4362 response.send(proto::LanguageModelResponse {
4363 choices: vec![proto::LanguageModelChoiceDelta {
4364 index: 0,
4365 delta: None,
4366 finish_reason: Some(stop_reason),
4367 }],
4368 })?;
4369 }
4370 }
4371 anthropic::ResponseEvent::ContentBlockStop { .. } => {}
4372 anthropic::ResponseEvent::MessageStop {} => {}
4373 anthropic::ResponseEvent::Ping {} => {}
4374 }
4375 }
4376
4377 Ok(())
4378}
4379
4380struct CountTokensWithLanguageModelRateLimit;
4381
4382impl RateLimit for CountTokensWithLanguageModelRateLimit {
4383 fn capacity() -> usize {
4384 std::env::var("COUNT_TOKENS_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4385 .ok()
4386 .and_then(|v| v.parse().ok())
4387 .unwrap_or(600) // Picked arbitrarily
4388 }
4389
4390 fn refill_duration() -> chrono::Duration {
4391 chrono::Duration::hours(1)
4392 }
4393
4394 fn db_name() -> &'static str {
4395 "count-tokens-with-language-model"
4396 }
4397}
4398
4399async fn count_tokens_with_language_model(
4400 request: proto::CountTokensWithLanguageModel,
4401 response: Response<proto::CountTokensWithLanguageModel>,
4402 session: UserSession,
4403 google_ai_api_key: Option<Arc<str>>,
4404) -> Result<()> {
4405 authorize_access_to_language_models(&session).await?;
4406
4407 if !request.model.starts_with("gemini") {
4408 return Err(anyhow!(
4409 "counting tokens for model: {:?} is not supported",
4410 request.model
4411 ))?;
4412 }
4413
4414 session
4415 .rate_limiter
4416 .check::<CountTokensWithLanguageModelRateLimit>(session.user_id())
4417 .await?;
4418
4419 let api_key = google_ai_api_key
4420 .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
4421 let tokens_response = google_ai::count_tokens(
4422 &session.http_client,
4423 google_ai::API_URL,
4424 &api_key,
4425 crate::ai::count_tokens_request_to_google_ai(request)?,
4426 )
4427 .await?;
4428 response.send(proto::CountTokensResponse {
4429 token_count: tokens_response.total_tokens as u32,
4430 })?;
4431 Ok(())
4432}
4433
4434struct ComputeEmbeddingsRateLimit;
4435
4436impl RateLimit for ComputeEmbeddingsRateLimit {
4437 fn capacity() -> usize {
4438 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4439 .ok()
4440 .and_then(|v| v.parse().ok())
4441 .unwrap_or(120) // Picked arbitrarily
4442 }
4443
4444 fn refill_duration() -> chrono::Duration {
4445 chrono::Duration::hours(1)
4446 }
4447
4448 fn db_name() -> &'static str {
4449 "compute-embeddings"
4450 }
4451}
4452
4453async fn compute_embeddings(
4454 request: proto::ComputeEmbeddings,
4455 response: Response<proto::ComputeEmbeddings>,
4456 session: UserSession,
4457 api_key: Option<Arc<str>>,
4458) -> Result<()> {
4459 let api_key = api_key.context("no OpenAI API key configured on the server")?;
4460 authorize_access_to_language_models(&session).await?;
4461
4462 session
4463 .rate_limiter
4464 .check::<ComputeEmbeddingsRateLimit>(session.user_id())
4465 .await?;
4466
4467 let embeddings = match request.model.as_str() {
4468 "openai/text-embedding-3-small" => {
4469 open_ai::embed(
4470 &session.http_client,
4471 OPEN_AI_API_URL,
4472 &api_key,
4473 OpenAiEmbeddingModel::TextEmbedding3Small,
4474 request.texts.iter().map(|text| text.as_str()),
4475 )
4476 .await?
4477 }
4478 provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
4479 };
4480
4481 let embeddings = request
4482 .texts
4483 .iter()
4484 .map(|text| {
4485 let mut hasher = sha2::Sha256::new();
4486 hasher.update(text.as_bytes());
4487 let result = hasher.finalize();
4488 result.to_vec()
4489 })
4490 .zip(
4491 embeddings
4492 .data
4493 .into_iter()
4494 .map(|embedding| embedding.embedding),
4495 )
4496 .collect::<HashMap<_, _>>();
4497
4498 let db = session.db().await;
4499 db.save_embeddings(&request.model, &embeddings)
4500 .await
4501 .context("failed to save embeddings")
4502 .trace_err();
4503
4504 response.send(proto::ComputeEmbeddingsResponse {
4505 embeddings: embeddings
4506 .into_iter()
4507 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4508 .collect(),
4509 })?;
4510 Ok(())
4511}
4512
4513struct GetCachedEmbeddingsRateLimit;
4514
4515impl RateLimit for GetCachedEmbeddingsRateLimit {
4516 fn capacity() -> usize {
4517 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4518 .ok()
4519 .and_then(|v| v.parse().ok())
4520 .unwrap_or(120) // Picked arbitrarily
4521 }
4522
4523 fn refill_duration() -> chrono::Duration {
4524 chrono::Duration::hours(1)
4525 }
4526
4527 fn db_name() -> &'static str {
4528 "get-cached-embeddings"
4529 }
4530}
4531
4532async fn get_cached_embeddings(
4533 request: proto::GetCachedEmbeddings,
4534 response: Response<proto::GetCachedEmbeddings>,
4535 session: UserSession,
4536) -> Result<()> {
4537 authorize_access_to_language_models(&session).await?;
4538
4539 session
4540 .rate_limiter
4541 .check::<GetCachedEmbeddingsRateLimit>(session.user_id())
4542 .await?;
4543
4544 let db = session.db().await;
4545 let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
4546
4547 response.send(proto::GetCachedEmbeddingsResponse {
4548 embeddings: embeddings
4549 .into_iter()
4550 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4551 .collect(),
4552 })?;
4553 Ok(())
4554}
4555
4556async fn authorize_access_to_language_models(session: &UserSession) -> Result<(), Error> {
4557 let db = session.db().await;
4558 let flags = db.get_user_flags(session.user_id()).await?;
4559 if flags.iter().any(|flag| flag == "language-models") {
4560 Ok(())
4561 } else {
4562 Err(anyhow!("permission denied"))?
4563 }
4564}
4565
4566/// Start receiving chat updates for a channel
4567async fn join_channel_chat(
4568 request: proto::JoinChannelChat,
4569 response: Response<proto::JoinChannelChat>,
4570 session: UserSession,
4571) -> Result<()> {
4572 let channel_id = ChannelId::from_proto(request.channel_id);
4573
4574 let db = session.db().await;
4575 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4576 .await?;
4577 let messages = db
4578 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4579 .await?;
4580 response.send(proto::JoinChannelChatResponse {
4581 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4582 messages,
4583 })?;
4584 Ok(())
4585}
4586
4587/// Stop receiving chat updates for a channel
4588async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
4589 let channel_id = ChannelId::from_proto(request.channel_id);
4590 session
4591 .db()
4592 .await
4593 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4594 .await?;
4595 Ok(())
4596}
4597
4598/// Retrieve the chat history for a channel
4599async fn get_channel_messages(
4600 request: proto::GetChannelMessages,
4601 response: Response<proto::GetChannelMessages>,
4602 session: UserSession,
4603) -> Result<()> {
4604 let channel_id = ChannelId::from_proto(request.channel_id);
4605 let messages = session
4606 .db()
4607 .await
4608 .get_channel_messages(
4609 channel_id,
4610 session.user_id(),
4611 MESSAGE_COUNT_PER_PAGE,
4612 Some(MessageId::from_proto(request.before_message_id)),
4613 )
4614 .await?;
4615 response.send(proto::GetChannelMessagesResponse {
4616 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4617 messages,
4618 })?;
4619 Ok(())
4620}
4621
4622/// Retrieve specific chat messages
4623async fn get_channel_messages_by_id(
4624 request: proto::GetChannelMessagesById,
4625 response: Response<proto::GetChannelMessagesById>,
4626 session: UserSession,
4627) -> Result<()> {
4628 let message_ids = request
4629 .message_ids
4630 .iter()
4631 .map(|id| MessageId::from_proto(*id))
4632 .collect::<Vec<_>>();
4633 let messages = session
4634 .db()
4635 .await
4636 .get_channel_messages_by_id(session.user_id(), &message_ids)
4637 .await?;
4638 response.send(proto::GetChannelMessagesResponse {
4639 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4640 messages,
4641 })?;
4642 Ok(())
4643}
4644
4645/// Retrieve the current users notifications
4646async fn get_notifications(
4647 request: proto::GetNotifications,
4648 response: Response<proto::GetNotifications>,
4649 session: UserSession,
4650) -> Result<()> {
4651 let notifications = session
4652 .db()
4653 .await
4654 .get_notifications(
4655 session.user_id(),
4656 NOTIFICATION_COUNT_PER_PAGE,
4657 request
4658 .before_id
4659 .map(|id| db::NotificationId::from_proto(id)),
4660 )
4661 .await?;
4662 response.send(proto::GetNotificationsResponse {
4663 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4664 notifications,
4665 })?;
4666 Ok(())
4667}
4668
4669/// Mark notifications as read
4670async fn mark_notification_as_read(
4671 request: proto::MarkNotificationRead,
4672 response: Response<proto::MarkNotificationRead>,
4673 session: UserSession,
4674) -> Result<()> {
4675 let database = &session.db().await;
4676 let notifications = database
4677 .mark_notification_as_read_by_id(
4678 session.user_id(),
4679 NotificationId::from_proto(request.notification_id),
4680 )
4681 .await?;
4682 send_notifications(
4683 &*session.connection_pool().await,
4684 &session.peer,
4685 notifications,
4686 );
4687 response.send(proto::Ack {})?;
4688 Ok(())
4689}
4690
4691/// Get the current users information
4692async fn get_private_user_info(
4693 _request: proto::GetPrivateUserInfo,
4694 response: Response<proto::GetPrivateUserInfo>,
4695 session: UserSession,
4696) -> Result<()> {
4697 let db = session.db().await;
4698
4699 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4700 let user = db
4701 .get_user_by_id(session.user_id())
4702 .await?
4703 .ok_or_else(|| anyhow!("user not found"))?;
4704 let flags = db.get_user_flags(session.user_id()).await?;
4705
4706 response.send(proto::GetPrivateUserInfoResponse {
4707 metrics_id,
4708 staff: user.admin,
4709 flags,
4710 })?;
4711 Ok(())
4712}
4713
4714fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
4715 match message {
4716 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4717 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4718 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4719 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4720 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4721 code: frame.code.into(),
4722 reason: frame.reason,
4723 })),
4724 }
4725}
4726
4727fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4728 match message {
4729 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4730 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4731 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4732 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4733 AxumMessage::Close(frame) => {
4734 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4735 code: frame.code.into(),
4736 reason: frame.reason,
4737 }))
4738 }
4739 }
4740}
4741
4742fn notify_membership_updated(
4743 connection_pool: &mut ConnectionPool,
4744 result: MembershipUpdated,
4745 user_id: UserId,
4746 peer: &Peer,
4747) {
4748 for membership in &result.new_channels.channel_memberships {
4749 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4750 }
4751 for channel_id in &result.removed_channels {
4752 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4753 }
4754
4755 let user_channels_update = proto::UpdateUserChannels {
4756 channel_memberships: result
4757 .new_channels
4758 .channel_memberships
4759 .iter()
4760 .map(|cm| proto::ChannelMembership {
4761 channel_id: cm.channel_id.to_proto(),
4762 role: cm.role.into(),
4763 })
4764 .collect(),
4765 ..Default::default()
4766 };
4767
4768 let mut update = build_channels_update(result.new_channels, vec![]);
4769 update.delete_channels = result
4770 .removed_channels
4771 .into_iter()
4772 .map(|id| id.to_proto())
4773 .collect();
4774 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4775
4776 for connection_id in connection_pool.user_connection_ids(user_id) {
4777 peer.send(connection_id, user_channels_update.clone())
4778 .trace_err();
4779 peer.send(connection_id, update.clone()).trace_err();
4780 }
4781}
4782
4783fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4784 proto::UpdateUserChannels {
4785 channel_memberships: channels
4786 .channel_memberships
4787 .iter()
4788 .map(|m| proto::ChannelMembership {
4789 channel_id: m.channel_id.to_proto(),
4790 role: m.role.into(),
4791 })
4792 .collect(),
4793 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4794 observed_channel_message_id: channels.observed_channel_messages.clone(),
4795 }
4796}
4797
4798fn build_channels_update(
4799 channels: ChannelsForUser,
4800 channel_invites: Vec<db::Channel>,
4801) -> proto::UpdateChannels {
4802 let mut update = proto::UpdateChannels::default();
4803
4804 for channel in channels.channels {
4805 update.channels.push(channel.to_proto());
4806 }
4807
4808 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4809 update.latest_channel_message_ids = channels.latest_channel_messages;
4810
4811 for (channel_id, participants) in channels.channel_participants {
4812 update
4813 .channel_participants
4814 .push(proto::ChannelParticipants {
4815 channel_id: channel_id.to_proto(),
4816 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4817 });
4818 }
4819
4820 for channel in channel_invites {
4821 update.channel_invitations.push(channel.to_proto());
4822 }
4823
4824 update.hosted_projects = channels.hosted_projects;
4825 update
4826}
4827
4828fn build_initial_contacts_update(
4829 contacts: Vec<db::Contact>,
4830 pool: &ConnectionPool,
4831) -> proto::UpdateContacts {
4832 let mut update = proto::UpdateContacts::default();
4833
4834 for contact in contacts {
4835 match contact {
4836 db::Contact::Accepted { user_id, busy } => {
4837 update.contacts.push(contact_for_user(user_id, busy, &pool));
4838 }
4839 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4840 db::Contact::Incoming { user_id } => {
4841 update
4842 .incoming_requests
4843 .push(proto::IncomingContactRequest {
4844 requester_id: user_id.to_proto(),
4845 })
4846 }
4847 }
4848 }
4849
4850 update
4851}
4852
4853fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4854 proto::Contact {
4855 user_id: user_id.to_proto(),
4856 online: pool.is_user_online(user_id),
4857 busy,
4858 }
4859}
4860
4861fn room_updated(room: &proto::Room, peer: &Peer) {
4862 broadcast(
4863 None,
4864 room.participants
4865 .iter()
4866 .filter_map(|participant| Some(participant.peer_id?.into())),
4867 |peer_id| {
4868 peer.send(
4869 peer_id,
4870 proto::RoomUpdated {
4871 room: Some(room.clone()),
4872 },
4873 )
4874 },
4875 );
4876}
4877
4878fn channel_updated(
4879 channel: &db::channel::Model,
4880 room: &proto::Room,
4881 peer: &Peer,
4882 pool: &ConnectionPool,
4883) {
4884 let participants = room
4885 .participants
4886 .iter()
4887 .map(|p| p.user_id)
4888 .collect::<Vec<_>>();
4889
4890 broadcast(
4891 None,
4892 pool.channel_connection_ids(channel.root_id())
4893 .filter_map(|(channel_id, role)| {
4894 role.can_see_channel(channel.visibility).then(|| channel_id)
4895 }),
4896 |peer_id| {
4897 peer.send(
4898 peer_id,
4899 proto::UpdateChannels {
4900 channel_participants: vec![proto::ChannelParticipants {
4901 channel_id: channel.id.to_proto(),
4902 participant_user_ids: participants.clone(),
4903 }],
4904 ..Default::default()
4905 },
4906 )
4907 },
4908 );
4909}
4910
4911async fn send_remote_projects_update(
4912 user_id: UserId,
4913 mut status: proto::RemoteProjectsUpdate,
4914 session: &Session,
4915) {
4916 let pool = session.connection_pool().await;
4917 for dev_server in &mut status.dev_servers {
4918 dev_server.status =
4919 pool.dev_server_status(DevServerId(dev_server.dev_server_id as i32)) as i32;
4920 }
4921 let connections = pool.user_connection_ids(user_id);
4922 for connection_id in connections {
4923 session.peer.send(connection_id, status.clone()).trace_err();
4924 }
4925}
4926
4927async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4928 let db = session.db().await;
4929
4930 let contacts = db.get_contacts(user_id).await?;
4931 let busy = db.is_user_busy(user_id).await?;
4932
4933 let pool = session.connection_pool().await;
4934 let updated_contact = contact_for_user(user_id, busy, &pool);
4935 for contact in contacts {
4936 if let db::Contact::Accepted {
4937 user_id: contact_user_id,
4938 ..
4939 } = contact
4940 {
4941 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4942 session
4943 .peer
4944 .send(
4945 contact_conn_id,
4946 proto::UpdateContacts {
4947 contacts: vec![updated_contact.clone()],
4948 remove_contacts: Default::default(),
4949 incoming_requests: Default::default(),
4950 remove_incoming_requests: Default::default(),
4951 outgoing_requests: Default::default(),
4952 remove_outgoing_requests: Default::default(),
4953 },
4954 )
4955 .trace_err();
4956 }
4957 }
4958 }
4959 Ok(())
4960}
4961
/// Cleans up after a dev server connection is lost.
///
/// Unshares every project the database reports as stale for this connection. Then loads a
/// fresh remote-projects update for the dev server's user (`session.dev_server().user_id`)
/// and sends it to that user's connections.
///
/// NOTE(review): an error unsharing any one project aborts the loop via `?`. The remaining
/// projects are then left shared and no update is sent. Confirm this is intended.
async fn lost_dev_server_connection(session: &DevServerSession) -> Result<()> {
    log::info!("lost dev server connection, unsharing projects");
    let project_ids = session
        .db()
        .await
        .get_stale_dev_server_projects(session.connection_id)
        .await?;

    for project_id in project_ids {
        // Note: unshare re-checks that the connection ids match, so we get away with no transaction.
        unshare_project_internal(project_id, session.connection_id, None, &session).await?;
    }

    let user_id = session.dev_server().user_id;
    let update = session.db().await.remote_projects_update(user_id).await?;

    send_remote_projects_update(user_id, update, session).await;

    Ok(())
}
4982
/// Removes `connection_id` from its room (if any) and fans out the consequences.
///
/// Steps:
/// - notifies collaborators of projects left via the room;
/// - broadcasts the updated room, plus the channel's participant list when the room belongs
///   to a channel;
/// - sends `CallCanceled` to users whose pending calls were canceled;
/// - refreshes contact status for this user and every canceled callee;
/// - removes the participant from LiveKit, deleting the LiveKit room if the database reports
///   the room was deleted.
///
/// Returns `Ok(())` immediately if the database reports the connection was not in a room.
async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Declared up front and assigned inside the `if let` so the values outlive `left_room`.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // NOTE(review): `live_kit_room` is taken out of the room *before* the room itself is
        // taken, so the `room` broadcast below carries a defaulted `live_kit_room` field.
        // Confirm clients do not rely on that field in `RoomUpdated` here.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the pool guard is dropped before `update_user_contacts`, which acquires the
    // connection pool itself.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    // LiveKit cleanup is best-effort: failures are logged, not returned.
    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
5056
5057async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
5058 let left_channel_buffers = session
5059 .db()
5060 .await
5061 .leave_channel_buffers(session.connection_id)
5062 .await?;
5063
5064 for left_buffer in left_channel_buffers {
5065 channel_buffer_updated(
5066 session.connection_id,
5067 left_buffer.connections,
5068 &proto::UpdateChannelBufferCollaborators {
5069 channel_id: left_buffer.channel_id.to_proto(),
5070 collaborators: left_buffer.collaborators,
5071 },
5072 &session.peer,
5073 );
5074 }
5075
5076 Ok(())
5077}
5078
5079fn project_left(project: &db::LeftProject, session: &UserSession) {
5080 for connection_id in &project.connection_ids {
5081 if project.should_unshare {
5082 session
5083 .peer
5084 .send(
5085 *connection_id,
5086 proto::UnshareProject {
5087 project_id: project.id.to_proto(),
5088 },
5089 )
5090 .trace_err();
5091 } else {
5092 session
5093 .peer
5094 .send(
5095 *connection_id,
5096 proto::RemoveProjectCollaborator {
5097 project_id: project.id.to_proto(),
5098 peer_id: Some(session.connection_id.into()),
5099 },
5100 )
5101 .trace_err();
5102 }
5103 }
5104}
5105
5106pub trait ResultExt {
5107 type Ok;
5108
5109 fn trace_err(self) -> Option<Self::Ok>;
5110}
5111
5112impl<T, E> ResultExt for Result<T, E>
5113where
5114 E: std::fmt::Debug,
5115{
5116 type Ok = T;
5117
5118 #[track_caller]
5119 fn trace_err(self) -> Option<T> {
5120 match self {
5121 Ok(value) => Some(value),
5122 Err(error) => {
5123 tracing::error!("{:?}", error);
5124 None
5125 }
5126 }
5127 }
5128}