1mod connection_pool;
2
3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
4use crate::llm::LlmTokenClaims;
5use crate::{
6 AppState, Error, Result, auth,
7 db::{
8 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
9 CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
10 NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
11 RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
12 },
13 executor::Executor,
14};
15use anyhow::{Context as _, anyhow, bail};
16use async_tungstenite::tungstenite::{
17 Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
18};
19use axum::{
20 Extension, Router, TypedHeader,
21 body::Body,
22 extract::{
23 ConnectInfo, WebSocketUpgrade,
24 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
25 },
26 headers::{Header, HeaderName},
27 http::StatusCode,
28 middleware,
29 response::IntoResponse,
30 routing::get,
31};
32use chrono::Utc;
33use collections::{HashMap, HashSet};
34pub use connection_pool::{ConnectionPool, ZedVersion};
35use core::fmt::{self, Debug, Formatter};
36use reqwest_client::ReqwestClient;
37use rpc::proto::split_repository_update;
38use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
39
40use futures::{
41 FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
42 stream::FuturesUnordered,
43};
44use prometheus::{IntGauge, register_int_gauge};
45use rpc::{
46 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
47 proto::{
48 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
49 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
50 },
51};
52use semantic_version::SemanticVersion;
53use serde::{Serialize, Serializer};
54use std::{
55 any::TypeId,
56 future::Future,
57 marker::PhantomData,
58 mem,
59 net::SocketAddr,
60 ops::{Deref, DerefMut},
61 rc::Rc,
62 sync::{
63 Arc, OnceLock,
64 atomic::{AtomicBool, Ordering::SeqCst},
65 },
66 time::{Duration, Instant},
67};
68use time::OffsetDateTime;
69use tokio::sync::{MutexGuard, Semaphore, watch};
70use tower::ServiceBuilder;
71use tracing::{
72 Instrument,
73 field::{self},
74 info_span, instrument,
75};
76
/// Reconnection window: 30 seconds.
///
/// NOTE(review): presumably how long a dropped client may take to reconnect before its
/// session state is released — the usage site is not in this part of the file; confirm.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// Delay that `Server::start` waits before cleaning up resources left behind by stale servers.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can
/// clean up old resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// NOTE(review): the three limits below are not used in this part of the file; judging by
// their names they bound chat-message paging, chat-message length and notification paging —
// confirm against the handlers that consume them.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
85
/// A type-erased message handler, stored in `Server::handlers` keyed by the message's `TypeId`.
///
/// It receives the boxed envelope together with the `Session` of the connection the message
/// arrived on, and returns a boxed future that resolves once handling has finished.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
88
89struct Response<R> {
90 peer: Arc<Peer>,
91 receipt: Receipt<R>,
92 responded: Arc<AtomicBool>,
93}
94
95impl<R: RequestMessage> Response<R> {
96 fn send(self, payload: R::Response) -> Result<()> {
97 self.responded.store(true, SeqCst);
98 self.peer.respond(self.receipt, payload)?;
99 Ok(())
100 }
101}
102
/// The identity a connection acts as.
#[derive(Clone, Debug)]
pub enum Principal {
    /// A user acting as themselves.
    User(User),
    /// `admin` is acting on behalf of `user`; the session's user id and email are taken
    /// from `user` (see `Session::user_id` / `Session::email`).
    Impersonated { user: User, admin: User },
}
108
109impl Principal {
110 fn update_span(&self, span: &tracing::Span) {
111 match &self {
112 Principal::User(user) => {
113 span.record("user_id", user.id.0);
114 span.record("login", &user.github_login);
115 }
116 Principal::Impersonated { user, admin } => {
117 span.record("user_id", user.id.0);
118 span.record("login", &user.github_login);
119 span.record("impersonator", &admin.github_login);
120 }
121 }
122 }
123}
124
/// Per-connection state that is cloned and handed to every message handler.
#[derive(Clone)]
struct Session {
    /// Who this connection is acting as.
    principal: Principal,
    /// The id the peer assigned to this connection.
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; access it through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    /// The server's peer, used to send messages and responses.
    peer: Arc<Peer>,
    /// The server-wide connection pool; access it through `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Supermaven admin client; `Some` only when the server config provides a
    /// `supermaven_admin_api_key` (see `Server::handle_connection`).
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    /// The GeoIP country code for the user.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    /// System id supplied when the connection was established, if any.
    system_id: Option<String>,
    _executor: Executor,
}
140
141impl Session {
142 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
143 #[cfg(test)]
144 tokio::task::yield_now().await;
145 let guard = self.db.lock().await;
146 #[cfg(test)]
147 tokio::task::yield_now().await;
148 guard
149 }
150
151 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
152 #[cfg(test)]
153 tokio::task::yield_now().await;
154 let guard = self.connection_pool.lock();
155 ConnectionPoolGuard {
156 guard,
157 _not_send: PhantomData,
158 }
159 }
160
161 fn is_staff(&self) -> bool {
162 match &self.principal {
163 Principal::User(user) => user.admin,
164 Principal::Impersonated { .. } => true,
165 }
166 }
167
168 pub async fn has_llm_subscription(
169 &self,
170 db: &MutexGuard<'_, DbHandle>,
171 ) -> anyhow::Result<bool> {
172 if self.is_staff() {
173 return Ok(true);
174 }
175
176 let user_id = self.user_id();
177
178 Ok(db.has_active_billing_subscription(user_id).await?)
179 }
180
181 pub async fn current_plan(
182 &self,
183 _db: &MutexGuard<'_, DbHandle>,
184 ) -> anyhow::Result<proto::Plan> {
185 if self.is_staff() {
186 Ok(proto::Plan::ZedPro)
187 } else {
188 Ok(proto::Plan::Free)
189 }
190 }
191
192 fn user_id(&self) -> UserId {
193 match &self.principal {
194 Principal::User(user) => user.id,
195 Principal::Impersonated { user, .. } => user.id,
196 }
197 }
198
199 pub fn email(&self) -> Option<String> {
200 match &self.principal {
201 Principal::User(user) => user.email_address.clone(),
202 Principal::Impersonated { user, .. } => user.email_address.clone(),
203 }
204 }
205}
206
207impl Debug for Session {
208 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
209 let mut result = f.debug_struct("Session");
210 match &self.principal {
211 Principal::User(user) => {
212 result.field("user", &user.github_login);
213 }
214 Principal::Impersonated { user, admin } => {
215 result.field("user", &user.github_login);
216 result.field("impersonator", &admin.github_login);
217 }
218 }
219 result.field("connection_id", &self.connection_id).finish()
220 }
221}
222
223struct DbHandle(Arc<Database>);
224
225impl Deref for DbHandle {
226 type Target = Database;
227
228 fn deref(&self) -> &Self::Target {
229 self.0.as_ref()
230 }
231}
232
/// The RPC server: owns the peer, the connection pool and the table of message handlers.
pub struct Server {
    /// This server's id. Behind a mutex because the test-only `reset` replaces it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// One handler per message type, keyed by the message's `TypeId`.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Teardown flag: `teardown()` publishes `true`; connection tasks subscribe to it.
    teardown: watch::Sender<bool>,
}
241
/// A lock guard over the connection pool that cannot be sent to another thread.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    // `Rc<()>` is not `Send`, so this marker makes the whole guard `!Send`.
    _not_send: PhantomData<Rc<()>>,
}
246
/// A serializable view of the server's peer and connection pool.
///
/// Holds the connection pool lock for as long as the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // The guard is not `Serialize` itself; `serialize_deref` serializes the pool it
    // dereferences to.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
253
254pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
255where
256 S: Serializer,
257 T: Deref<Target = U>,
258 U: Serialize,
259{
260 Serialize::serialize(value.deref(), serializer)
261}
262
263impl Server {
264 pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
265 let mut server = Self {
266 id: parking_lot::Mutex::new(id),
267 peer: Peer::new(id.0 as u32),
268 app_state: app_state.clone(),
269 connection_pool: Default::default(),
270 handlers: Default::default(),
271 teardown: watch::channel(false).0,
272 };
273
274 server
275 .add_request_handler(ping)
276 .add_request_handler(create_room)
277 .add_request_handler(join_room)
278 .add_request_handler(rejoin_room)
279 .add_request_handler(leave_room)
280 .add_request_handler(set_room_participant_role)
281 .add_request_handler(call)
282 .add_request_handler(cancel_call)
283 .add_message_handler(decline_call)
284 .add_request_handler(update_participant_location)
285 .add_request_handler(share_project)
286 .add_message_handler(unshare_project)
287 .add_request_handler(join_project)
288 .add_message_handler(leave_project)
289 .add_request_handler(update_project)
290 .add_request_handler(update_worktree)
291 .add_request_handler(update_repository)
292 .add_request_handler(remove_repository)
293 .add_message_handler(start_language_server)
294 .add_message_handler(update_language_server)
295 .add_message_handler(update_diagnostic_summary)
296 .add_message_handler(update_worktree_settings)
297 .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
298 .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
299 .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
300 .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
301 .add_request_handler(forward_find_search_candidates_request)
302 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
303 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
304 .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
305 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
306 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
307 .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
308 .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
309 .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
310 .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
311 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
312 .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
313 .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
314 .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
315 .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
316 .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
317 .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
318 .add_request_handler(
319 forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
320 )
321 .add_request_handler(
322 forward_read_only_project_request::<proto::LanguageServerIdForName>,
323 )
324 .add_request_handler(
325 forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
326 )
327 .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
328 .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
329 .add_request_handler(
330 forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
331 )
332 .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
333 .add_request_handler(
334 forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
335 )
336 .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
337 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
338 .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
339 .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
340 .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
341 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
342 .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
343 .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
344 .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
345 .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
346 .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
347 .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
348 .add_request_handler(
349 forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
350 )
351 .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
352 .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
353 .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
354 .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
355 .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
356 .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
357 .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
358 .add_message_handler(create_buffer_for_peer)
359 .add_request_handler(update_buffer)
360 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
361 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
362 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
363 .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
364 .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
365 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
366 .add_request_handler(get_users)
367 .add_request_handler(fuzzy_search_users)
368 .add_request_handler(request_contact)
369 .add_request_handler(remove_contact)
370 .add_request_handler(respond_to_contact_request)
371 .add_message_handler(subscribe_to_channels)
372 .add_request_handler(create_channel)
373 .add_request_handler(delete_channel)
374 .add_request_handler(invite_channel_member)
375 .add_request_handler(remove_channel_member)
376 .add_request_handler(set_channel_member_role)
377 .add_request_handler(set_channel_visibility)
378 .add_request_handler(rename_channel)
379 .add_request_handler(join_channel_buffer)
380 .add_request_handler(leave_channel_buffer)
381 .add_message_handler(update_channel_buffer)
382 .add_request_handler(rejoin_channel_buffers)
383 .add_request_handler(get_channel_members)
384 .add_request_handler(respond_to_channel_invite)
385 .add_request_handler(join_channel)
386 .add_request_handler(join_channel_chat)
387 .add_message_handler(leave_channel_chat)
388 .add_request_handler(send_channel_message)
389 .add_request_handler(remove_channel_message)
390 .add_request_handler(update_channel_message)
391 .add_request_handler(get_channel_messages)
392 .add_request_handler(get_channel_messages_by_id)
393 .add_request_handler(get_notifications)
394 .add_request_handler(mark_notification_as_read)
395 .add_request_handler(move_channel)
396 .add_request_handler(follow)
397 .add_message_handler(unfollow)
398 .add_message_handler(update_followers)
399 .add_request_handler(get_private_user_info)
400 .add_request_handler(get_llm_api_token)
401 .add_request_handler(accept_terms_of_service)
402 .add_message_handler(acknowledge_channel_message)
403 .add_message_handler(acknowledge_buffer_version)
404 .add_request_handler(get_supermaven_api_key)
405 .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
406 .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
407 .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
408 .add_request_handler(forward_mutating_project_request::<proto::Stage>)
409 .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
410 .add_request_handler(forward_mutating_project_request::<proto::Commit>)
411 .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
412 .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
413 .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
414 .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
415 .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
416 .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
417 .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
418 .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
419 .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
420 .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
421 .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
422 .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
423 .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
424 .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
425 .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
426 .add_message_handler(update_context);
427
428 Arc::new(server)
429 }
430
    /// Schedules the cleanup of resources left behind by stale servers and returns at once.
    ///
    /// A detached task waits for `CLEANUP_TIMEOUT`, then asks the database for the stale room
    /// and channel ids. For every stale channel buffer it clears the stale collaborators and
    /// sends the refreshed collaborator list to the remaining connections. For every stale
    /// room it clears the stale participants, announces the refreshed room, cancels pending
    /// calls, pushes contact updates for the affected users, and deletes the LiveKit room
    /// when no participants remain. Finally it asks the database to delete the stale servers.
    ///
    /// Failures inside the task go through `trace_err` and do not abort the cleanup, so this
    /// function itself always returns `Ok(())`.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // The sleep future is created now, before the task is spawned.
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Stale channel buffers: drop stale collaborators and tell every
                    // remaining connection about the refreshed collaborator list.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Defaults used when refreshing the room fails: nothing to notify
                        // and no LiveKit room to delete.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            // Both the stale participants and the users whose incoming calls
                            // were canceled get a contacts update below.
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // The pool lock is scoped to this block so that it is released
                        // before the awaits that follow.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            // Only proceed when both lookups succeeded.
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    // Only accepted contacts are notified.
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Delete the LiveKit room only when the refreshed room is empty and
                        // a LiveKit client is configured.
                        if let Some(live_kit) = livekit_client.as_ref() {
                            if delete_livekit_room {
                                live_kit.delete_room(livekit_room).await.trace_err();
                            }
                        }
                    }
                }

                // Runs even when retrieving the stale resource ids failed.
                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
576
577 pub fn teardown(&self) {
578 self.peer.teardown();
579 self.connection_pool.lock().reset();
580 let _ = self.teardown.send(true);
581 }
582
583 #[cfg(test)]
584 pub fn reset(&self, id: ServerId) {
585 self.teardown();
586 *self.id.lock() = id;
587 self.peer.reset(id.0 as u32);
588 let _ = self.teardown.send(false);
589 }
590
591 #[cfg(test)]
592 pub fn id(&self) -> ServerId {
593 *self.id.lock()
594 }
595
596 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
597 where
598 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
599 Fut: 'static + Send + Future<Output = Result<()>>,
600 M: EnvelopedMessage,
601 {
602 let prev_handler = self.handlers.insert(
603 TypeId::of::<M>(),
604 Box::new(move |envelope, session| {
605 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
606 let received_at = envelope.received_at;
607 tracing::info!("message received");
608 let start_time = Instant::now();
609 let future = (handler)(*envelope, session);
610 async move {
611 let result = future.await;
612 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
613 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
614 let queue_duration_ms = total_duration_ms - processing_duration_ms;
615 let payload_type = M::NAME;
616
617 match result {
618 Err(error) => {
619 tracing::error!(
620 ?error,
621 total_duration_ms,
622 processing_duration_ms,
623 queue_duration_ms,
624 payload_type,
625 "error handling message"
626 )
627 }
628 Ok(()) => tracing::info!(
629 total_duration_ms,
630 processing_duration_ms,
631 queue_duration_ms,
632 "finished handling message"
633 ),
634 }
635 }
636 .boxed()
637 }),
638 );
639 if prev_handler.is_some() {
640 panic!("registered a handler for the same message twice");
641 }
642 self
643 }
644
645 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
646 where
647 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
648 Fut: 'static + Send + Future<Output = Result<()>>,
649 M: EnvelopedMessage,
650 {
651 self.add_handler(move |envelope, session| handler(envelope.payload, session));
652 self
653 }
654
655 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
656 where
657 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
658 Fut: Send + Future<Output = Result<()>>,
659 M: RequestMessage,
660 {
661 let handler = Arc::new(handler);
662 self.add_handler(move |envelope, session| {
663 let receipt = envelope.receipt();
664 let handler = handler.clone();
665 async move {
666 let peer = session.peer.clone();
667 let responded = Arc::new(AtomicBool::default());
668 let response = Response {
669 peer: peer.clone(),
670 responded: responded.clone(),
671 receipt,
672 };
673 match (handler)(envelope.payload, response, session).await {
674 Ok(()) => {
675 if responded.load(std::sync::atomic::Ordering::SeqCst) {
676 Ok(())
677 } else {
678 Err(anyhow!("handler did not send a response"))?
679 }
680 }
681 Err(error) => {
682 let proto_err = match &error {
683 Error::Internal(err) => err.to_proto(),
684 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
685 };
686 peer.respond_with_error(receipt, proto_err)?;
687 Err(error)
688 }
689 }
690 }
691 })
692 }
693
    /// Returns a future that drives one client connection from registration to sign-out.
    ///
    /// The future:
    /// 1. returns immediately if the server is already tearing down;
    /// 2. registers the connection with the peer and builds the connection's `Session`;
    /// 3. sends the initial client update (see `send_initial_client_update`), returning
    ///    early if that fails;
    /// 4. dispatches incoming messages to the registered handlers until connection I/O
    ///    finishes or the incoming stream ends;
    /// 5. signs the session out via `connection_lost`.
    ///
    /// If the teardown flag changes while the loop is running, the future returns right away
    /// without calling `connection_lost`.
    ///
    /// `send_connection_id`, when provided, receives the connection id once the hello
    /// message has been sent.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        // Fields declared as `Empty` are filled in later with `record`.
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            geoip_country_code=field::Empty
        );
        principal.update_span(&span);
        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            // A fresh HTTP client is built per connection; it is only handed to the
            // Supermaven admin client below.
            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
                    supermaven_admin_api_key.to_string(),
                    http_client.clone(),
                )));

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages that arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // At most 256 handlers run at once for this connection: a permit is acquired
            // before the next message is read and released when its handler completes.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // `select_biased!` polls the branches in the order they are written.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background messages run as detached tasks; foreground ones
                                // are polled by this loop through `foreground_message_handlers`.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Foreground handlers that have not finished yet are dropped before sign-out.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
841
842 async fn send_initial_client_update(
843 &self,
844 connection_id: ConnectionId,
845 principal: &Principal,
846 zed_version: ZedVersion,
847 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
848 session: &Session,
849 ) -> Result<()> {
850 self.peer.send(
851 connection_id,
852 proto::Hello {
853 peer_id: Some(connection_id.into()),
854 },
855 )?;
856 tracing::info!("sent hello message");
857 if let Some(send_connection_id) = send_connection_id.take() {
858 let _ = send_connection_id.send(connection_id);
859 }
860
861 match principal {
862 Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
863 if !user.connected_once {
864 self.peer.send(connection_id, proto::ShowContacts {})?;
865 self.app_state
866 .db
867 .set_user_connected_once(user.id, true)
868 .await?;
869 }
870
871 update_user_plan(user.id, session).await?;
872
873 let contacts = self.app_state.db.get_contacts(user.id).await?;
874
875 {
876 let mut pool = self.connection_pool.lock();
877 pool.add_connection(connection_id, user.id, user.admin, zed_version);
878 self.peer.send(
879 connection_id,
880 build_initial_contacts_update(contacts, &pool),
881 )?;
882 }
883
884 if should_auto_subscribe_to_channels(zed_version) {
885 subscribe_user_to_channels(user.id, session).await?;
886 }
887
888 if let Some(incoming_call) =
889 self.app_state.db.incoming_call_for_user(user.id).await?
890 {
891 self.peer.send(connection_id, incoming_call)?;
892 }
893
894 update_user_contacts(user.id, session).await?;
895 }
896 }
897
898 Ok(())
899 }
900
901 pub async fn invite_code_redeemed(
902 self: &Arc<Self>,
903 inviter_id: UserId,
904 invitee_id: UserId,
905 ) -> Result<()> {
906 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
907 if let Some(code) = &user.invite_code {
908 let pool = self.connection_pool.lock();
909 let invitee_contact = contact_for_user(invitee_id, false, &pool);
910 for connection_id in pool.user_connection_ids(inviter_id) {
911 self.peer.send(
912 connection_id,
913 proto::UpdateContacts {
914 contacts: vec![invitee_contact.clone()],
915 ..Default::default()
916 },
917 )?;
918 self.peer.send(
919 connection_id,
920 proto::UpdateInviteInfo {
921 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
922 count: user.invite_count as u32,
923 },
924 )?;
925 }
926 }
927 }
928 Ok(())
929 }
930
931 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
932 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
933 if let Some(invite_code) = &user.invite_code {
934 let pool = self.connection_pool.lock();
935 for connection_id in pool.user_connection_ids(user_id) {
936 self.peer.send(
937 connection_id,
938 proto::UpdateInviteInfo {
939 url: format!(
940 "{}{}",
941 self.app_state.config.invite_link_prefix, invite_code
942 ),
943 count: user.invite_count as u32,
944 },
945 )?;
946 }
947 }
948 }
949 Ok(())
950 }
951
952 pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
953 let pool = self.connection_pool.lock();
954 for connection_id in pool.user_connection_ids(user_id) {
955 self.peer
956 .send(connection_id, proto::RefreshLlmToken {})
957 .trace_err();
958 }
959 }
960
961 pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot {
962 ServerSnapshot {
963 connection_pool: ConnectionPoolGuard {
964 guard: self.connection_pool.lock(),
965 _not_send: PhantomData,
966 },
967 peer: &self.peer,
968 }
969 }
970}
971
972impl Deref for ConnectionPoolGuard<'_> {
973 type Target = ConnectionPool;
974
975 fn deref(&self) -> &Self::Target {
976 &self.guard
977 }
978}
979
980impl DerefMut for ConnectionPoolGuard<'_> {
981 fn deref_mut(&mut self) -> &mut Self::Target {
982 &mut self.guard
983 }
984}
985
/// Runs `check_invariants` whenever a guard is released, but only in test builds.
impl Drop for ConnectionPoolGuard<'_> {
    fn drop(&mut self) {
        // The check is compiled out entirely outside of `cfg(test)`.
        #[cfg(test)]
        self.check_invariants();
    }
}
992
993fn broadcast<F>(
994 sender_id: Option<ConnectionId>,
995 receiver_ids: impl IntoIterator<Item = ConnectionId>,
996 mut f: F,
997) where
998 F: FnMut(ConnectionId) -> anyhow::Result<()>,
999{
1000 for receiver_id in receiver_ids {
1001 if Some(receiver_id) != sender_id {
1002 if let Err(error) = f(receiver_id) {
1003 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1004 }
1005 }
1006 }
1007}
1008
1009pub struct ProtocolVersion(u32);
1010
1011impl Header for ProtocolVersion {
1012 fn name() -> &'static HeaderName {
1013 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1014 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1015 }
1016
1017 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1018 where
1019 Self: Sized,
1020 I: Iterator<Item = &'i axum::http::HeaderValue>,
1021 {
1022 let version = values
1023 .next()
1024 .ok_or_else(axum::headers::Error::invalid)?
1025 .to_str()
1026 .map_err(|_| axum::headers::Error::invalid())?
1027 .parse()
1028 .map_err(|_| axum::headers::Error::invalid())?;
1029 Ok(Self(version))
1030 }
1031
1032 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1033 values.extend([self.0.to_string().parse().unwrap()]);
1034 }
1035}
1036
1037pub struct AppVersionHeader(SemanticVersion);
1038impl Header for AppVersionHeader {
1039 fn name() -> &'static HeaderName {
1040 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1041 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1042 }
1043
1044 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1045 where
1046 Self: Sized,
1047 I: Iterator<Item = &'i axum::http::HeaderValue>,
1048 {
1049 let version = values
1050 .next()
1051 .ok_or_else(axum::headers::Error::invalid)?
1052 .to_str()
1053 .map_err(|_| axum::headers::Error::invalid())?
1054 .parse()
1055 .map_err(|_| axum::headers::Error::invalid())?;
1056 Ok(Self(version))
1057 }
1058
1059 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1060 values.extend([self.0.to_string().parse().unwrap()]);
1061 }
1062}
1063
1064pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1065 Router::new()
1066 .route("/rpc", get(handle_websocket_request))
1067 .layer(
1068 ServiceBuilder::new()
1069 .layer(Extension(server.app_state.clone()))
1070 .layer(middleware::from_fn(auth::validate_header)),
1071 )
1072 .route("/metrics", get(handle_metrics))
1073 .layer(Extension(server))
1074}
1075
1076pub async fn handle_websocket_request(
1077 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1078 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1079 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1080 Extension(server): Extension<Arc<Server>>,
1081 Extension(principal): Extension<Principal>,
1082 country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1083 system_id_header: Option<TypedHeader<SystemIdHeader>>,
1084 ws: WebSocketUpgrade,
1085) -> axum::response::Response {
1086 if protocol_version != rpc::PROTOCOL_VERSION {
1087 return (
1088 StatusCode::UPGRADE_REQUIRED,
1089 "client must be upgraded".to_string(),
1090 )
1091 .into_response();
1092 }
1093
1094 let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1095 return (
1096 StatusCode::UPGRADE_REQUIRED,
1097 "no version header found".to_string(),
1098 )
1099 .into_response();
1100 };
1101
1102 if !version.can_collaborate() {
1103 return (
1104 StatusCode::UPGRADE_REQUIRED,
1105 "client must be upgraded".to_string(),
1106 )
1107 .into_response();
1108 }
1109
1110 let socket_address = socket_address.to_string();
1111 ws.on_upgrade(move |socket| {
1112 let socket = socket
1113 .map_ok(to_tungstenite_message)
1114 .err_into()
1115 .with(|message| async move { to_axum_message(message) });
1116 let connection = Connection::new(Box::pin(socket));
1117 async move {
1118 server
1119 .handle_connection(
1120 connection,
1121 socket_address,
1122 principal,
1123 version,
1124 country_code_header.map(|header| header.to_string()),
1125 system_id_header.map(|header| header.to_string()),
1126 None,
1127 Executor::Production,
1128 )
1129 .await;
1130 }
1131 })
1132}
1133
1134pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1135 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1136 let connections_metric = CONNECTIONS_METRIC
1137 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1138
1139 let connections = server
1140 .connection_pool
1141 .lock()
1142 .connections()
1143 .filter(|connection| !connection.admin)
1144 .count();
1145 connections_metric.set(connections as _);
1146
1147 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1148 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1149 register_int_gauge!(
1150 "shared_projects",
1151 "number of open projects with one or more guests"
1152 )
1153 .unwrap()
1154 });
1155
1156 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1157 shared_projects_metric.set(shared_projects as _);
1158
1159 let encoder = prometheus::TextEncoder::new();
1160 let metric_families = prometheus::gather();
1161 let encoded_metrics = encoder
1162 .encode_to_string(&metric_families)
1163 .map_err(|err| anyhow!("{}", err))?;
1164 Ok(encoded_metrics)
1165}
1166
/// Cleans up after a connection has dropped.
///
/// The connection is disconnected from the peer and removed from the connection pool,
/// and the database is told that the connection was lost. The function then waits for
/// whichever comes first: `RECONNECT_TIMEOUT` elapsing or the `teardown` watch changing.
/// Only when the timeout fires does it leave the session's room and channel buffers,
/// decline a call for the user if the pool no longer reports them online, and refresh
/// the user's contacts.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // A database failure here is traced rather than propagated.
    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            // Both cleanup steps are best-effort: errors are traced and cleanup continues.
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline on the user's behalf when the pool reports them as offline.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        // Teardown signalled: skip the cleanup above.
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1213
1214/// Acknowledges a ping from a client, used to keep the connection alive.
1215async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1216 response.send(proto::Ack {})?;
1217 Ok(())
1218}
1219
1220/// Creates a new room for calling (outside of channels)
1221async fn create_room(
1222 _request: proto::CreateRoom,
1223 response: Response<proto::CreateRoom>,
1224 session: Session,
1225) -> Result<()> {
1226 let livekit_room = nanoid::nanoid!(30);
1227
1228 let live_kit_connection_info = util::maybe!(async {
1229 let live_kit = session.app_state.livekit_client.as_ref();
1230 let live_kit = live_kit?;
1231 let user_id = session.user_id().to_string();
1232
1233 let token = live_kit
1234 .room_token(&livekit_room, &user_id.to_string())
1235 .trace_err()?;
1236
1237 Some(proto::LiveKitConnectionInfo {
1238 server_url: live_kit.url().into(),
1239 token,
1240 can_publish: true,
1241 })
1242 })
1243 .await;
1244
1245 let room = session
1246 .db()
1247 .await
1248 .create_room(session.user_id(), session.connection_id, &livekit_room)
1249 .await?;
1250
1251 response.send(proto::CreateRoomResponse {
1252 room: Some(room.clone()),
1253 live_kit_connection_info,
1254 })?;
1255
1256 update_user_contacts(session.user_id(), &session).await?;
1257 Ok(())
1258}
1259
1260/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1261async fn join_room(
1262 request: proto::JoinRoom,
1263 response: Response<proto::JoinRoom>,
1264 session: Session,
1265) -> Result<()> {
1266 let room_id = RoomId::from_proto(request.id);
1267
1268 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1269
1270 if let Some(channel_id) = channel_id {
1271 return join_channel_internal(channel_id, Box::new(response), session).await;
1272 }
1273
1274 let joined_room = {
1275 let room = session
1276 .db()
1277 .await
1278 .join_room(room_id, session.user_id(), session.connection_id)
1279 .await?;
1280 room_updated(&room.room, &session.peer);
1281 room.into_inner()
1282 };
1283
1284 for connection_id in session
1285 .connection_pool()
1286 .await
1287 .user_connection_ids(session.user_id())
1288 {
1289 session
1290 .peer
1291 .send(
1292 connection_id,
1293 proto::CallCanceled {
1294 room_id: room_id.to_proto(),
1295 },
1296 )
1297 .trace_err();
1298 }
1299
1300 let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1301 {
1302 live_kit
1303 .room_token(
1304 &joined_room.room.livekit_room,
1305 &session.user_id().to_string(),
1306 )
1307 .trace_err()
1308 .map(|token| proto::LiveKitConnectionInfo {
1309 server_url: live_kit.url().into(),
1310 token,
1311 can_publish: true,
1312 })
1313 } else {
1314 None
1315 };
1316
1317 response.send(proto::JoinRoomResponse {
1318 room: Some(joined_room.room),
1319 channel_id: None,
1320 live_kit_connection_info,
1321 })?;
1322
1323 update_user_contacts(session.user_id(), &session).await?;
1324 Ok(())
1325}
1326
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// The database result is answered first (room, reshared projects, rejoined projects).
/// Collaborators of each reshared project are then told about the new peer id and are
/// forwarded an `UpdateProject`; rejoined projects are handled by
/// `notify_rejoined_projects`. Finally the channel (if any) and the user's contacts are
/// updated.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    // Declared outside the block so they outlive the database result below.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell each collaborator that the old connection id was replaced by this one.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the project's worktrees to every collaborator except this connection.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1418
1419fn notify_rejoined_projects(
1420 rejoined_projects: &mut Vec<RejoinedProject>,
1421 session: &Session,
1422) -> Result<()> {
1423 for project in rejoined_projects.iter() {
1424 for collaborator in &project.collaborators {
1425 session
1426 .peer
1427 .send(
1428 collaborator.connection_id,
1429 proto::UpdateProjectCollaborator {
1430 project_id: project.id.to_proto(),
1431 old_peer_id: Some(project.old_connection_id.into()),
1432 new_peer_id: Some(session.connection_id.into()),
1433 },
1434 )
1435 .trace_err();
1436 }
1437 }
1438
1439 for project in rejoined_projects {
1440 for worktree in mem::take(&mut project.worktrees) {
1441 // Stream this worktree's entries.
1442 let message = proto::UpdateWorktree {
1443 project_id: project.id.to_proto(),
1444 worktree_id: worktree.id,
1445 abs_path: worktree.abs_path.clone(),
1446 root_name: worktree.root_name,
1447 updated_entries: worktree.updated_entries,
1448 removed_entries: worktree.removed_entries,
1449 scan_id: worktree.scan_id,
1450 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1451 updated_repositories: worktree.updated_repositories,
1452 removed_repositories: worktree.removed_repositories,
1453 };
1454 for update in proto::split_worktree_update(message) {
1455 session.peer.send(session.connection_id, update)?;
1456 }
1457
1458 // Stream this worktree's diagnostics.
1459 for summary in worktree.diagnostic_summaries {
1460 session.peer.send(
1461 session.connection_id,
1462 proto::UpdateDiagnosticSummary {
1463 project_id: project.id.to_proto(),
1464 worktree_id: worktree.id,
1465 summary: Some(summary),
1466 },
1467 )?;
1468 }
1469
1470 for settings_file in worktree.settings_files {
1471 session.peer.send(
1472 session.connection_id,
1473 proto::UpdateWorktreeSettings {
1474 project_id: project.id.to_proto(),
1475 worktree_id: worktree.id,
1476 path: settings_file.path,
1477 content: Some(settings_file.content),
1478 kind: Some(settings_file.kind.to_proto().into()),
1479 },
1480 )?;
1481 }
1482 }
1483
1484 for repository in mem::take(&mut project.updated_repositories) {
1485 for update in split_repository_update(repository) {
1486 session.peer.send(session.connection_id, update)?;
1487 }
1488 }
1489
1490 for id in mem::take(&mut project.removed_repositories) {
1491 session.peer.send(
1492 session.connection_id,
1493 proto::RemoveRepository {
1494 project_id: project.id.to_proto(),
1495 id,
1496 },
1497 )?;
1498 }
1499 }
1500
1501 Ok(())
1502}
1503
1504/// leave room disconnects from the room.
1505async fn leave_room(
1506 _: proto::LeaveRoom,
1507 response: Response<proto::LeaveRoom>,
1508 session: Session,
1509) -> Result<()> {
1510 leave_room_for_session(&session, session.connection_id).await?;
1511 response.send(proto::Ack {})?;
1512 Ok(())
1513}
1514
1515/// Updates the permissions of someone else in the room.
1516async fn set_room_participant_role(
1517 request: proto::SetRoomParticipantRole,
1518 response: Response<proto::SetRoomParticipantRole>,
1519 session: Session,
1520) -> Result<()> {
1521 let user_id = UserId::from_proto(request.user_id);
1522 let role = ChannelRole::from(request.role());
1523
1524 let (livekit_room, can_publish) = {
1525 let room = session
1526 .db()
1527 .await
1528 .set_room_participant_role(
1529 session.user_id(),
1530 RoomId::from_proto(request.room_id),
1531 user_id,
1532 role,
1533 )
1534 .await?;
1535
1536 let livekit_room = room.livekit_room.clone();
1537 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1538 room_updated(&room, &session.peer);
1539 (livekit_room, can_publish)
1540 };
1541
1542 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1543 live_kit
1544 .update_participant(
1545 livekit_room.clone(),
1546 request.user_id.to_string(),
1547 livekit_api::proto::ParticipantPermission {
1548 can_subscribe: true,
1549 can_publish,
1550 can_publish_data: can_publish,
1551 hidden: false,
1552 recorder: false,
1553 },
1554 )
1555 .await
1556 .trace_err();
1557 }
1558
1559 response.send(proto::Ack {})?;
1560 Ok(())
1561}
1562
/// Call someone else into the current room
///
/// Fails if the called user is not a contact of the caller. Otherwise the call is
/// recorded, the incoming-call request is sent to every connection of the called user,
/// and the first successful reply acknowledges the caller. If no connection answers
/// successfully, the call is marked as failed and an error is returned.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: Session,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(room, &session.peer);
        // Move the message out of the database result, leaving a default in its place.
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // One request per connection of the called user, polled concurrently.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            // The first successful reply is enough to acknowledge the caller.
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    // Reached only when no request succeeded (including when there were none).
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1631
1632/// Cancel an outgoing call.
1633async fn cancel_call(
1634 request: proto::CancelCall,
1635 response: Response<proto::CancelCall>,
1636 session: Session,
1637) -> Result<()> {
1638 let called_user_id = UserId::from_proto(request.called_user_id);
1639 let room_id = RoomId::from_proto(request.room_id);
1640 {
1641 let room = session
1642 .db()
1643 .await
1644 .cancel_call(room_id, session.connection_id, called_user_id)
1645 .await?;
1646 room_updated(&room, &session.peer);
1647 }
1648
1649 for connection_id in session
1650 .connection_pool()
1651 .await
1652 .user_connection_ids(called_user_id)
1653 {
1654 session
1655 .peer
1656 .send(
1657 connection_id,
1658 proto::CallCanceled {
1659 room_id: room_id.to_proto(),
1660 },
1661 )
1662 .trace_err();
1663 }
1664 response.send(proto::Ack {})?;
1665
1666 update_user_contacts(called_user_id, &session).await?;
1667 Ok(())
1668}
1669
1670/// Decline an incoming call.
1671async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1672 let room_id = RoomId::from_proto(message.room_id);
1673 {
1674 let room = session
1675 .db()
1676 .await
1677 .decline_call(Some(room_id), session.user_id())
1678 .await?
1679 .ok_or_else(|| anyhow!("failed to decline call"))?;
1680 room_updated(&room, &session.peer);
1681 }
1682
1683 for connection_id in session
1684 .connection_pool()
1685 .await
1686 .user_connection_ids(session.user_id())
1687 {
1688 session
1689 .peer
1690 .send(
1691 connection_id,
1692 proto::CallCanceled {
1693 room_id: room_id.to_proto(),
1694 },
1695 )
1696 .trace_err();
1697 }
1698 update_user_contacts(session.user_id(), &session).await?;
1699 Ok(())
1700}
1701
1702/// Updates other participants in the room with your current location.
1703async fn update_participant_location(
1704 request: proto::UpdateParticipantLocation,
1705 response: Response<proto::UpdateParticipantLocation>,
1706 session: Session,
1707) -> Result<()> {
1708 let room_id = RoomId::from_proto(request.room_id);
1709 let location = request
1710 .location
1711 .ok_or_else(|| anyhow!("invalid location"))?;
1712
1713 let db = session.db().await;
1714 let room = db
1715 .update_room_participant_location(room_id, session.connection_id, location)
1716 .await?;
1717
1718 room_updated(&room, &session.peer);
1719 response.send(proto::Ack {})?;
1720 Ok(())
1721}
1722
/// Share a project into the room.
///
/// Registers the project with the database, answers with the new project id, and then
/// broadcasts the updated room.
async fn share_project(
    request: proto::ShareProject,
    response: Response<proto::ShareProject>,
    session: Session,
) -> Result<()> {
    // Borrow the project id and room out of the value returned by the database.
    let (project_id, room) = &*session
        .db()
        .await
        .share_project(
            RoomId::from_proto(request.room_id),
            session.connection_id,
            &request.worktrees,
            request.is_ssh_project,
        )
        .await?;
    response.send(proto::ShareProjectResponse {
        project_id: project_id.to_proto(),
    })?;
    room_updated(room, &session.peer);

    Ok(())
}
1746
1747/// Unshare a project from the room.
1748async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1749 let project_id = ProjectId::from_proto(message.project_id);
1750 unshare_project_internal(project_id, session.connection_id, &session).await
1751}
1752
/// Unshares `project_id` for `connection_id`.
///
/// Guests (other than `connection_id`) are sent `UnshareProject`, the room update is
/// broadcast when the database returns a room, and the project is deleted when the
/// database result says so.
async fn unshare_project_internal(
    project_id: ProjectId,
    connection_id: ConnectionId,
    session: &Session,
) -> Result<()> {
    // Scoped so the database result is released before the delete below.
    let delete = {
        let room_guard = session
            .db()
            .await
            .unshare_project(project_id, connection_id)
            .await?;

        let (delete, room, guest_connection_ids) = &*room_guard;

        let message = proto::UnshareProject {
            project_id: project_id.to_proto(),
        };

        broadcast(
            Some(connection_id),
            guest_connection_ids.iter().copied(),
            |conn_id| session.peer.send(conn_id, message.clone()),
        );
        if let Some(room) = room {
            room_updated(room, &session.peer);
        }

        *delete
    };

    if delete {
        let db = session.db().await;
        db.delete_project(project_id).await?;
    }

    Ok(())
}
1790
/// Join someone else's shared project.
///
/// Registers this connection with the project in the database, then delegates the
/// response and the initial state streaming to `join_project_internal`.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: Session,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    // Release the database handle before handing `session` over below.
    drop(db);
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
1809
/// Abstraction over the response channel used by `join_project_internal`, so that the
/// function accepts any type able to deliver a `JoinProjectResponse`.
trait JoinProjectInternalResponse {
    /// Consumes the responder and delivers `result`.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Fully qualified so the call targets `Response`'s own `send`.
        Response::<proto::JoinProject>::send(self, result)
    }
}
1818
1819fn join_project_internal(
1820 response: impl JoinProjectInternalResponse,
1821 session: Session,
1822 project: &mut Project,
1823 replica_id: &ReplicaId,
1824) -> Result<()> {
1825 let collaborators = project
1826 .collaborators
1827 .iter()
1828 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1829 .map(|collaborator| collaborator.to_proto())
1830 .collect::<Vec<_>>();
1831 let project_id = project.id;
1832 let guest_user_id = session.user_id();
1833
1834 let worktrees = project
1835 .worktrees
1836 .iter()
1837 .map(|(id, worktree)| proto::WorktreeMetadata {
1838 id: *id,
1839 root_name: worktree.root_name.clone(),
1840 visible: worktree.visible,
1841 abs_path: worktree.abs_path.clone(),
1842 })
1843 .collect::<Vec<_>>();
1844
1845 let add_project_collaborator = proto::AddProjectCollaborator {
1846 project_id: project_id.to_proto(),
1847 collaborator: Some(proto::Collaborator {
1848 peer_id: Some(session.connection_id.into()),
1849 replica_id: replica_id.0 as u32,
1850 user_id: guest_user_id.to_proto(),
1851 is_host: false,
1852 }),
1853 };
1854
1855 for collaborator in &collaborators {
1856 session
1857 .peer
1858 .send(
1859 collaborator.peer_id.unwrap().into(),
1860 add_project_collaborator.clone(),
1861 )
1862 .trace_err();
1863 }
1864
1865 // First, we send the metadata associated with each worktree.
1866 response.send(proto::JoinProjectResponse {
1867 project_id: project.id.0 as u64,
1868 worktrees: worktrees.clone(),
1869 replica_id: replica_id.0 as u32,
1870 collaborators: collaborators.clone(),
1871 language_servers: project.language_servers.clone(),
1872 role: project.role.into(),
1873 })?;
1874
1875 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1876 // Stream this worktree's entries.
1877 let message = proto::UpdateWorktree {
1878 project_id: project_id.to_proto(),
1879 worktree_id,
1880 abs_path: worktree.abs_path.clone(),
1881 root_name: worktree.root_name,
1882 updated_entries: worktree.entries,
1883 removed_entries: Default::default(),
1884 scan_id: worktree.scan_id,
1885 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1886 updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1887 removed_repositories: Default::default(),
1888 };
1889 for update in proto::split_worktree_update(message) {
1890 session.peer.send(session.connection_id, update.clone())?;
1891 }
1892
1893 // Stream this worktree's diagnostics.
1894 for summary in worktree.diagnostic_summaries {
1895 session.peer.send(
1896 session.connection_id,
1897 proto::UpdateDiagnosticSummary {
1898 project_id: project_id.to_proto(),
1899 worktree_id: worktree.id,
1900 summary: Some(summary),
1901 },
1902 )?;
1903 }
1904
1905 for settings_file in worktree.settings_files {
1906 session.peer.send(
1907 session.connection_id,
1908 proto::UpdateWorktreeSettings {
1909 project_id: project_id.to_proto(),
1910 worktree_id: worktree.id,
1911 path: settings_file.path,
1912 content: Some(settings_file.content),
1913 kind: Some(settings_file.kind.to_proto() as i32),
1914 },
1915 )?;
1916 }
1917 }
1918
1919 for repository in mem::take(&mut project.repositories) {
1920 for update in split_repository_update(repository) {
1921 session.peer.send(session.connection_id, update)?;
1922 }
1923 }
1924
1925 for language_server in &project.language_servers {
1926 session.peer.send(
1927 session.connection_id,
1928 proto::UpdateLanguageServer {
1929 project_id: project_id.to_proto(),
1930 language_server_id: language_server.id,
1931 variant: Some(
1932 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1933 proto::LspDiskBasedDiagnosticsUpdated {},
1934 ),
1935 ),
1936 },
1937 )?;
1938 }
1939
1940 Ok(())
1941}
1942
/// Leave someone else's shared project.
///
/// Removes this connection from the project in the database, notifies via
/// `project_left`, and broadcasts the room update when the database returns a room.
async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
    let sender_id = session.connection_id;
    let project_id = ProjectId::from_proto(request.project_id);
    let db = session.db().await;

    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
    tracing::info!(
        %project_id,
        "leave project"
    );

    project_left(project, &session);
    if let Some(room) = room {
        room_updated(room, &session.peer);
    }

    Ok(())
}
1962
1963/// Updates other participants with changes to the project
1964async fn update_project(
1965 request: proto::UpdateProject,
1966 response: Response<proto::UpdateProject>,
1967 session: Session,
1968) -> Result<()> {
1969 let project_id = ProjectId::from_proto(request.project_id);
1970 let (room, guest_connection_ids) = &*session
1971 .db()
1972 .await
1973 .update_project(project_id, session.connection_id, &request.worktrees)
1974 .await?;
1975 broadcast(
1976 Some(session.connection_id),
1977 guest_connection_ids.iter().copied(),
1978 |connection_id| {
1979 session
1980 .peer
1981 .forward_send(session.connection_id, connection_id, request.clone())
1982 },
1983 );
1984 if let Some(room) = room {
1985 room_updated(room, &session.peer);
1986 }
1987 response.send(proto::Ack {})?;
1988
1989 Ok(())
1990}
1991
1992/// Updates other participants with changes to the worktree
1993async fn update_worktree(
1994 request: proto::UpdateWorktree,
1995 response: Response<proto::UpdateWorktree>,
1996 session: Session,
1997) -> Result<()> {
1998 let guest_connection_ids = session
1999 .db()
2000 .await
2001 .update_worktree(&request, session.connection_id)
2002 .await?;
2003
2004 broadcast(
2005 Some(session.connection_id),
2006 guest_connection_ids.iter().copied(),
2007 |connection_id| {
2008 session
2009 .peer
2010 .forward_send(session.connection_id, connection_id, request.clone())
2011 },
2012 );
2013 response.send(proto::Ack {})?;
2014 Ok(())
2015}
2016
2017async fn update_repository(
2018 request: proto::UpdateRepository,
2019 response: Response<proto::UpdateRepository>,
2020 session: Session,
2021) -> Result<()> {
2022 let guest_connection_ids = session
2023 .db()
2024 .await
2025 .update_repository(&request, session.connection_id)
2026 .await?;
2027
2028 broadcast(
2029 Some(session.connection_id),
2030 guest_connection_ids.iter().copied(),
2031 |connection_id| {
2032 session
2033 .peer
2034 .forward_send(session.connection_id, connection_id, request.clone())
2035 },
2036 );
2037 response.send(proto::Ack {})?;
2038 Ok(())
2039}
2040
2041async fn remove_repository(
2042 request: proto::RemoveRepository,
2043 response: Response<proto::RemoveRepository>,
2044 session: Session,
2045) -> Result<()> {
2046 let guest_connection_ids = session
2047 .db()
2048 .await
2049 .remove_repository(&request, session.connection_id)
2050 .await?;
2051
2052 broadcast(
2053 Some(session.connection_id),
2054 guest_connection_ids.iter().copied(),
2055 |connection_id| {
2056 session
2057 .peer
2058 .forward_send(session.connection_id, connection_id, request.clone())
2059 },
2060 );
2061 response.send(proto::Ack {})?;
2062 Ok(())
2063}
2064
2065/// Updates other participants with changes to the diagnostics
2066async fn update_diagnostic_summary(
2067 message: proto::UpdateDiagnosticSummary,
2068 session: Session,
2069) -> Result<()> {
2070 let guest_connection_ids = session
2071 .db()
2072 .await
2073 .update_diagnostic_summary(&message, session.connection_id)
2074 .await?;
2075
2076 broadcast(
2077 Some(session.connection_id),
2078 guest_connection_ids.iter().copied(),
2079 |connection_id| {
2080 session
2081 .peer
2082 .forward_send(session.connection_id, connection_id, message.clone())
2083 },
2084 );
2085
2086 Ok(())
2087}
2088
2089/// Updates other participants with changes to the worktree settings
2090async fn update_worktree_settings(
2091 message: proto::UpdateWorktreeSettings,
2092 session: Session,
2093) -> Result<()> {
2094 let guest_connection_ids = session
2095 .db()
2096 .await
2097 .update_worktree_settings(&message, session.connection_id)
2098 .await?;
2099
2100 broadcast(
2101 Some(session.connection_id),
2102 guest_connection_ids.iter().copied(),
2103 |connection_id| {
2104 session
2105 .peer
2106 .forward_send(session.connection_id, connection_id, message.clone())
2107 },
2108 );
2109
2110 Ok(())
2111}
2112
2113/// Notify other participants that a language server has started.
2114async fn start_language_server(
2115 request: proto::StartLanguageServer,
2116 session: Session,
2117) -> Result<()> {
2118 let guest_connection_ids = session
2119 .db()
2120 .await
2121 .start_language_server(&request, session.connection_id)
2122 .await?;
2123
2124 broadcast(
2125 Some(session.connection_id),
2126 guest_connection_ids.iter().copied(),
2127 |connection_id| {
2128 session
2129 .peer
2130 .forward_send(session.connection_id, connection_id, request.clone())
2131 },
2132 );
2133 Ok(())
2134}
2135
2136/// Notify other participants that a language server has changed.
2137async fn update_language_server(
2138 request: proto::UpdateLanguageServer,
2139 session: Session,
2140) -> Result<()> {
2141 let project_id = ProjectId::from_proto(request.project_id);
2142 let project_connection_ids = session
2143 .db()
2144 .await
2145 .project_connection_ids(project_id, session.connection_id, true)
2146 .await?;
2147 broadcast(
2148 Some(session.connection_id),
2149 project_connection_ids.iter().copied(),
2150 |connection_id| {
2151 session
2152 .peer
2153 .forward_send(session.connection_id, connection_id, request.clone())
2154 },
2155 );
2156 Ok(())
2157}
2158
2159/// forward a project request to the host. These requests should be read only
2160/// as guests are allowed to send them.
2161async fn forward_read_only_project_request<T>(
2162 request: T,
2163 response: Response<T>,
2164 session: Session,
2165) -> Result<()>
2166where
2167 T: EntityMessage + RequestMessage,
2168{
2169 let project_id = ProjectId::from_proto(request.remote_entity_id());
2170 let host_connection_id = session
2171 .db()
2172 .await
2173 .host_for_read_only_project_request(project_id, session.connection_id)
2174 .await?;
2175 let payload = session
2176 .peer
2177 .forward_request(session.connection_id, host_connection_id, request)
2178 .await?;
2179 response.send(payload)?;
2180 Ok(())
2181}
2182
2183async fn forward_find_search_candidates_request(
2184 request: proto::FindSearchCandidates,
2185 response: Response<proto::FindSearchCandidates>,
2186 session: Session,
2187) -> Result<()> {
2188 let project_id = ProjectId::from_proto(request.remote_entity_id());
2189 let host_connection_id = session
2190 .db()
2191 .await
2192 .host_for_read_only_project_request(project_id, session.connection_id)
2193 .await?;
2194 let payload = session
2195 .peer
2196 .forward_request(session.connection_id, host_connection_id, request)
2197 .await?;
2198 response.send(payload)?;
2199 Ok(())
2200}
2201
2202/// forward a project request to the host. These requests are disallowed
2203/// for guests.
2204async fn forward_mutating_project_request<T>(
2205 request: T,
2206 response: Response<T>,
2207 session: Session,
2208) -> Result<()>
2209where
2210 T: EntityMessage + RequestMessage,
2211{
2212 let project_id = ProjectId::from_proto(request.remote_entity_id());
2213
2214 let host_connection_id = session
2215 .db()
2216 .await
2217 .host_for_mutating_project_request(project_id, session.connection_id)
2218 .await?;
2219 let payload = session
2220 .peer
2221 .forward_request(session.connection_id, host_connection_id, request)
2222 .await?;
2223 response.send(payload)?;
2224 Ok(())
2225}
2226
2227/// Notify other participants that a new buffer has been created
2228async fn create_buffer_for_peer(
2229 request: proto::CreateBufferForPeer,
2230 session: Session,
2231) -> Result<()> {
2232 session
2233 .db()
2234 .await
2235 .check_user_is_project_host(
2236 ProjectId::from_proto(request.project_id),
2237 session.connection_id,
2238 )
2239 .await?;
2240 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2241 session
2242 .peer
2243 .forward_send(session.connection_id, peer_id.into(), request)?;
2244 Ok(())
2245}
2246
2247/// Notify other participants that a buffer has been updated. This is
2248/// allowed for guests as long as the update is limited to selections.
2249async fn update_buffer(
2250 request: proto::UpdateBuffer,
2251 response: Response<proto::UpdateBuffer>,
2252 session: Session,
2253) -> Result<()> {
2254 let project_id = ProjectId::from_proto(request.project_id);
2255 let mut capability = Capability::ReadOnly;
2256
2257 for op in request.operations.iter() {
2258 match op.variant {
2259 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2260 Some(_) => capability = Capability::ReadWrite,
2261 }
2262 }
2263
2264 let host = {
2265 let guard = session
2266 .db()
2267 .await
2268 .connections_for_buffer_update(project_id, session.connection_id, capability)
2269 .await?;
2270
2271 let (host, guests) = &*guard;
2272
2273 broadcast(
2274 Some(session.connection_id),
2275 guests.clone(),
2276 |connection_id| {
2277 session
2278 .peer
2279 .forward_send(session.connection_id, connection_id, request.clone())
2280 },
2281 );
2282
2283 *host
2284 };
2285
2286 if host != session.connection_id {
2287 session
2288 .peer
2289 .forward_request(session.connection_id, host, request.clone())
2290 .await?;
2291 }
2292
2293 response.send(proto::Ack {})?;
2294 Ok(())
2295}
2296
2297async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2298 let project_id = ProjectId::from_proto(message.project_id);
2299
2300 let operation = message.operation.as_ref().context("invalid operation")?;
2301 let capability = match operation.variant.as_ref() {
2302 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2303 if let Some(buffer_op) = buffer_op.operation.as_ref() {
2304 match buffer_op.variant {
2305 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2306 Capability::ReadOnly
2307 }
2308 _ => Capability::ReadWrite,
2309 }
2310 } else {
2311 Capability::ReadWrite
2312 }
2313 }
2314 Some(_) => Capability::ReadWrite,
2315 None => Capability::ReadOnly,
2316 };
2317
2318 let guard = session
2319 .db()
2320 .await
2321 .connections_for_buffer_update(project_id, session.connection_id, capability)
2322 .await?;
2323
2324 let (host, guests) = &*guard;
2325
2326 broadcast(
2327 Some(session.connection_id),
2328 guests.iter().chain([host]).copied(),
2329 |connection_id| {
2330 session
2331 .peer
2332 .forward_send(session.connection_id, connection_id, message.clone())
2333 },
2334 );
2335
2336 Ok(())
2337}
2338
2339/// Notify other participants that a project has been updated.
2340async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2341 request: T,
2342 session: Session,
2343) -> Result<()> {
2344 let project_id = ProjectId::from_proto(request.remote_entity_id());
2345 let project_connection_ids = session
2346 .db()
2347 .await
2348 .project_connection_ids(project_id, session.connection_id, false)
2349 .await?;
2350
2351 broadcast(
2352 Some(session.connection_id),
2353 project_connection_ids.iter().copied(),
2354 |connection_id| {
2355 session
2356 .peer
2357 .forward_send(session.connection_id, connection_id, request.clone())
2358 },
2359 );
2360 Ok(())
2361}
2362
2363/// Start following another user in a call.
2364async fn follow(
2365 request: proto::Follow,
2366 response: Response<proto::Follow>,
2367 session: Session,
2368) -> Result<()> {
2369 let room_id = RoomId::from_proto(request.room_id);
2370 let project_id = request.project_id.map(ProjectId::from_proto);
2371 let leader_id = request
2372 .leader_id
2373 .ok_or_else(|| anyhow!("invalid leader id"))?
2374 .into();
2375 let follower_id = session.connection_id;
2376
2377 session
2378 .db()
2379 .await
2380 .check_room_participants(room_id, leader_id, session.connection_id)
2381 .await?;
2382
2383 let response_payload = session
2384 .peer
2385 .forward_request(session.connection_id, leader_id, request)
2386 .await?;
2387 response.send(response_payload)?;
2388
2389 if let Some(project_id) = project_id {
2390 let room = session
2391 .db()
2392 .await
2393 .follow(room_id, project_id, leader_id, follower_id)
2394 .await?;
2395 room_updated(&room, &session.peer);
2396 }
2397
2398 Ok(())
2399}
2400
2401/// Stop following another user in a call.
2402async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2403 let room_id = RoomId::from_proto(request.room_id);
2404 let project_id = request.project_id.map(ProjectId::from_proto);
2405 let leader_id = request
2406 .leader_id
2407 .ok_or_else(|| anyhow!("invalid leader id"))?
2408 .into();
2409 let follower_id = session.connection_id;
2410
2411 session
2412 .db()
2413 .await
2414 .check_room_participants(room_id, leader_id, session.connection_id)
2415 .await?;
2416
2417 session
2418 .peer
2419 .forward_send(session.connection_id, leader_id, request)?;
2420
2421 if let Some(project_id) = project_id {
2422 let room = session
2423 .db()
2424 .await
2425 .unfollow(room_id, project_id, leader_id, follower_id)
2426 .await?;
2427 room_updated(&room, &session.peer);
2428 }
2429
2430 Ok(())
2431}
2432
2433/// Notify everyone following you of your current location.
2434async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2435 let room_id = RoomId::from_proto(request.room_id);
2436 let database = session.db.lock().await;
2437
2438 let connection_ids = if let Some(project_id) = request.project_id {
2439 let project_id = ProjectId::from_proto(project_id);
2440 database
2441 .project_connection_ids(project_id, session.connection_id, true)
2442 .await?
2443 } else {
2444 database
2445 .room_connection_ids(room_id, session.connection_id)
2446 .await?
2447 };
2448
2449 // For now, don't send view update messages back to that view's current leader.
2450 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2451 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2452 _ => None,
2453 });
2454
2455 for connection_id in connection_ids.iter().cloned() {
2456 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2457 session
2458 .peer
2459 .forward_send(session.connection_id, connection_id, request.clone())?;
2460 }
2461 }
2462 Ok(())
2463}
2464
2465/// Get public data about users.
2466async fn get_users(
2467 request: proto::GetUsers,
2468 response: Response<proto::GetUsers>,
2469 session: Session,
2470) -> Result<()> {
2471 let user_ids = request
2472 .user_ids
2473 .into_iter()
2474 .map(UserId::from_proto)
2475 .collect();
2476 let users = session
2477 .db()
2478 .await
2479 .get_users_by_ids(user_ids)
2480 .await?
2481 .into_iter()
2482 .map(|user| proto::User {
2483 id: user.id.to_proto(),
2484 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2485 github_login: user.github_login,
2486 email: user.email_address,
2487 name: user.name,
2488 })
2489 .collect();
2490 response.send(proto::UsersResponse { users })?;
2491 Ok(())
2492}
2493
2494/// Search for users (to invite) buy Github login
2495async fn fuzzy_search_users(
2496 request: proto::FuzzySearchUsers,
2497 response: Response<proto::FuzzySearchUsers>,
2498 session: Session,
2499) -> Result<()> {
2500 let query = request.query;
2501 let users = match query.len() {
2502 0 => vec![],
2503 1 | 2 => session
2504 .db()
2505 .await
2506 .get_user_by_github_login(&query)
2507 .await?
2508 .into_iter()
2509 .collect(),
2510 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2511 };
2512 let users = users
2513 .into_iter()
2514 .filter(|user| user.id != session.user_id())
2515 .map(|user| proto::User {
2516 id: user.id.to_proto(),
2517 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2518 github_login: user.github_login,
2519 name: user.name,
2520 email: user.email_address,
2521 })
2522 .collect();
2523 response.send(proto::UsersResponse { users })?;
2524 Ok(())
2525}
2526
2527/// Send a contact request to another user.
2528async fn request_contact(
2529 request: proto::RequestContact,
2530 response: Response<proto::RequestContact>,
2531 session: Session,
2532) -> Result<()> {
2533 let requester_id = session.user_id();
2534 let responder_id = UserId::from_proto(request.responder_id);
2535 if requester_id == responder_id {
2536 return Err(anyhow!("cannot add yourself as a contact"))?;
2537 }
2538
2539 let notifications = session
2540 .db()
2541 .await
2542 .send_contact_request(requester_id, responder_id)
2543 .await?;
2544
2545 // Update outgoing contact requests of requester
2546 let mut update = proto::UpdateContacts::default();
2547 update.outgoing_requests.push(responder_id.to_proto());
2548 for connection_id in session
2549 .connection_pool()
2550 .await
2551 .user_connection_ids(requester_id)
2552 {
2553 session.peer.send(connection_id, update.clone())?;
2554 }
2555
2556 // Update incoming contact requests of responder
2557 let mut update = proto::UpdateContacts::default();
2558 update
2559 .incoming_requests
2560 .push(proto::IncomingContactRequest {
2561 requester_id: requester_id.to_proto(),
2562 });
2563 let connection_pool = session.connection_pool().await;
2564 for connection_id in connection_pool.user_connection_ids(responder_id) {
2565 session.peer.send(connection_id, update.clone())?;
2566 }
2567
2568 send_notifications(&connection_pool, &session.peer, notifications);
2569
2570 response.send(proto::Ack {})?;
2571 Ok(())
2572}
2573
2574/// Accept or decline a contact request
2575async fn respond_to_contact_request(
2576 request: proto::RespondToContactRequest,
2577 response: Response<proto::RespondToContactRequest>,
2578 session: Session,
2579) -> Result<()> {
2580 let responder_id = session.user_id();
2581 let requester_id = UserId::from_proto(request.requester_id);
2582 let db = session.db().await;
2583 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2584 db.dismiss_contact_notification(responder_id, requester_id)
2585 .await?;
2586 } else {
2587 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2588
2589 let notifications = db
2590 .respond_to_contact_request(responder_id, requester_id, accept)
2591 .await?;
2592 let requester_busy = db.is_user_busy(requester_id).await?;
2593 let responder_busy = db.is_user_busy(responder_id).await?;
2594
2595 let pool = session.connection_pool().await;
2596 // Update responder with new contact
2597 let mut update = proto::UpdateContacts::default();
2598 if accept {
2599 update
2600 .contacts
2601 .push(contact_for_user(requester_id, requester_busy, &pool));
2602 }
2603 update
2604 .remove_incoming_requests
2605 .push(requester_id.to_proto());
2606 for connection_id in pool.user_connection_ids(responder_id) {
2607 session.peer.send(connection_id, update.clone())?;
2608 }
2609
2610 // Update requester with new contact
2611 let mut update = proto::UpdateContacts::default();
2612 if accept {
2613 update
2614 .contacts
2615 .push(contact_for_user(responder_id, responder_busy, &pool));
2616 }
2617 update
2618 .remove_outgoing_requests
2619 .push(responder_id.to_proto());
2620
2621 for connection_id in pool.user_connection_ids(requester_id) {
2622 session.peer.send(connection_id, update.clone())?;
2623 }
2624
2625 send_notifications(&pool, &session.peer, notifications);
2626 }
2627
2628 response.send(proto::Ack {})?;
2629 Ok(())
2630}
2631
2632/// Remove a contact.
2633async fn remove_contact(
2634 request: proto::RemoveContact,
2635 response: Response<proto::RemoveContact>,
2636 session: Session,
2637) -> Result<()> {
2638 let requester_id = session.user_id();
2639 let responder_id = UserId::from_proto(request.user_id);
2640 let db = session.db().await;
2641 let (contact_accepted, deleted_notification_id) =
2642 db.remove_contact(requester_id, responder_id).await?;
2643
2644 let pool = session.connection_pool().await;
2645 // Update outgoing contact requests of requester
2646 let mut update = proto::UpdateContacts::default();
2647 if contact_accepted {
2648 update.remove_contacts.push(responder_id.to_proto());
2649 } else {
2650 update
2651 .remove_outgoing_requests
2652 .push(responder_id.to_proto());
2653 }
2654 for connection_id in pool.user_connection_ids(requester_id) {
2655 session.peer.send(connection_id, update.clone())?;
2656 }
2657
2658 // Update incoming contact requests of responder
2659 let mut update = proto::UpdateContacts::default();
2660 if contact_accepted {
2661 update.remove_contacts.push(requester_id.to_proto());
2662 } else {
2663 update
2664 .remove_incoming_requests
2665 .push(requester_id.to_proto());
2666 }
2667 for connection_id in pool.user_connection_ids(responder_id) {
2668 session.peer.send(connection_id, update.clone())?;
2669 if let Some(notification_id) = deleted_notification_id {
2670 session.peer.send(
2671 connection_id,
2672 proto::DeleteNotification {
2673 notification_id: notification_id.to_proto(),
2674 },
2675 )?;
2676 }
2677 }
2678
2679 response.send(proto::Ack {})?;
2680 Ok(())
2681}
2682
2683fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2684 version.0.minor() < 139
2685}
2686
2687async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
2688 let plan = session.current_plan(&session.db().await).await?;
2689
2690 session
2691 .peer
2692 .send(
2693 session.connection_id,
2694 proto::UpdateUserPlan { plan: plan.into() },
2695 )
2696 .trace_err();
2697
2698 Ok(())
2699}
2700
2701async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2702 subscribe_user_to_channels(session.user_id(), &session).await?;
2703 Ok(())
2704}
2705
2706async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2707 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2708 let mut pool = session.connection_pool().await;
2709 for membership in &channels_for_user.channel_memberships {
2710 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2711 }
2712 session.peer.send(
2713 session.connection_id,
2714 build_update_user_channels(&channels_for_user),
2715 )?;
2716 session.peer.send(
2717 session.connection_id,
2718 build_channels_update(channels_for_user),
2719 )?;
2720 Ok(())
2721}
2722
2723/// Creates a new channel.
2724async fn create_channel(
2725 request: proto::CreateChannel,
2726 response: Response<proto::CreateChannel>,
2727 session: Session,
2728) -> Result<()> {
2729 let db = session.db().await;
2730
2731 let parent_id = request.parent_id.map(ChannelId::from_proto);
2732 let (channel, membership) = db
2733 .create_channel(&request.name, parent_id, session.user_id())
2734 .await?;
2735
2736 let root_id = channel.root_id();
2737 let channel = Channel::from_model(channel);
2738
2739 response.send(proto::CreateChannelResponse {
2740 channel: Some(channel.to_proto()),
2741 parent_id: request.parent_id,
2742 })?;
2743
2744 let mut connection_pool = session.connection_pool().await;
2745 if let Some(membership) = membership {
2746 connection_pool.subscribe_to_channel(
2747 membership.user_id,
2748 membership.channel_id,
2749 membership.role,
2750 );
2751 let update = proto::UpdateUserChannels {
2752 channel_memberships: vec![proto::ChannelMembership {
2753 channel_id: membership.channel_id.to_proto(),
2754 role: membership.role.into(),
2755 }],
2756 ..Default::default()
2757 };
2758 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2759 session.peer.send(connection_id, update.clone())?;
2760 }
2761 }
2762
2763 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2764 if !role.can_see_channel(channel.visibility) {
2765 continue;
2766 }
2767
2768 let update = proto::UpdateChannels {
2769 channels: vec![channel.to_proto()],
2770 ..Default::default()
2771 };
2772 session.peer.send(connection_id, update.clone())?;
2773 }
2774
2775 Ok(())
2776}
2777
2778/// Delete a channel
2779async fn delete_channel(
2780 request: proto::DeleteChannel,
2781 response: Response<proto::DeleteChannel>,
2782 session: Session,
2783) -> Result<()> {
2784 let db = session.db().await;
2785
2786 let channel_id = request.channel_id;
2787 let (root_channel, removed_channels) = db
2788 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2789 .await?;
2790 response.send(proto::Ack {})?;
2791
2792 // Notify members of removed channels
2793 let mut update = proto::UpdateChannels::default();
2794 update
2795 .delete_channels
2796 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2797
2798 let connection_pool = session.connection_pool().await;
2799 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2800 session.peer.send(connection_id, update.clone())?;
2801 }
2802
2803 Ok(())
2804}
2805
2806/// Invite someone to join a channel.
2807async fn invite_channel_member(
2808 request: proto::InviteChannelMember,
2809 response: Response<proto::InviteChannelMember>,
2810 session: Session,
2811) -> Result<()> {
2812 let db = session.db().await;
2813 let channel_id = ChannelId::from_proto(request.channel_id);
2814 let invitee_id = UserId::from_proto(request.user_id);
2815 let InviteMemberResult {
2816 channel,
2817 notifications,
2818 } = db
2819 .invite_channel_member(
2820 channel_id,
2821 invitee_id,
2822 session.user_id(),
2823 request.role().into(),
2824 )
2825 .await?;
2826
2827 let update = proto::UpdateChannels {
2828 channel_invitations: vec![channel.to_proto()],
2829 ..Default::default()
2830 };
2831
2832 let connection_pool = session.connection_pool().await;
2833 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2834 session.peer.send(connection_id, update.clone())?;
2835 }
2836
2837 send_notifications(&connection_pool, &session.peer, notifications);
2838
2839 response.send(proto::Ack {})?;
2840 Ok(())
2841}
2842
2843/// remove someone from a channel
2844async fn remove_channel_member(
2845 request: proto::RemoveChannelMember,
2846 response: Response<proto::RemoveChannelMember>,
2847 session: Session,
2848) -> Result<()> {
2849 let db = session.db().await;
2850 let channel_id = ChannelId::from_proto(request.channel_id);
2851 let member_id = UserId::from_proto(request.user_id);
2852
2853 let RemoveChannelMemberResult {
2854 membership_update,
2855 notification_id,
2856 } = db
2857 .remove_channel_member(channel_id, member_id, session.user_id())
2858 .await?;
2859
2860 let mut connection_pool = session.connection_pool().await;
2861 notify_membership_updated(
2862 &mut connection_pool,
2863 membership_update,
2864 member_id,
2865 &session.peer,
2866 );
2867 for connection_id in connection_pool.user_connection_ids(member_id) {
2868 if let Some(notification_id) = notification_id {
2869 session
2870 .peer
2871 .send(
2872 connection_id,
2873 proto::DeleteNotification {
2874 notification_id: notification_id.to_proto(),
2875 },
2876 )
2877 .trace_err();
2878 }
2879 }
2880
2881 response.send(proto::Ack {})?;
2882 Ok(())
2883}
2884
2885/// Toggle the channel between public and private.
2886/// Care is taken to maintain the invariant that public channels only descend from public channels,
2887/// (though members-only channels can appear at any point in the hierarchy).
2888async fn set_channel_visibility(
2889 request: proto::SetChannelVisibility,
2890 response: Response<proto::SetChannelVisibility>,
2891 session: Session,
2892) -> Result<()> {
2893 let db = session.db().await;
2894 let channel_id = ChannelId::from_proto(request.channel_id);
2895 let visibility = request.visibility().into();
2896
2897 let channel_model = db
2898 .set_channel_visibility(channel_id, visibility, session.user_id())
2899 .await?;
2900 let root_id = channel_model.root_id();
2901 let channel = Channel::from_model(channel_model);
2902
2903 let mut connection_pool = session.connection_pool().await;
2904 for (user_id, role) in connection_pool
2905 .channel_user_ids(root_id)
2906 .collect::<Vec<_>>()
2907 .into_iter()
2908 {
2909 let update = if role.can_see_channel(channel.visibility) {
2910 connection_pool.subscribe_to_channel(user_id, channel_id, role);
2911 proto::UpdateChannels {
2912 channels: vec![channel.to_proto()],
2913 ..Default::default()
2914 }
2915 } else {
2916 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2917 proto::UpdateChannels {
2918 delete_channels: vec![channel.id.to_proto()],
2919 ..Default::default()
2920 }
2921 };
2922
2923 for connection_id in connection_pool.user_connection_ids(user_id) {
2924 session.peer.send(connection_id, update.clone())?;
2925 }
2926 }
2927
2928 response.send(proto::Ack {})?;
2929 Ok(())
2930}
2931
2932/// Alter the role for a user in the channel.
2933async fn set_channel_member_role(
2934 request: proto::SetChannelMemberRole,
2935 response: Response<proto::SetChannelMemberRole>,
2936 session: Session,
2937) -> Result<()> {
2938 let db = session.db().await;
2939 let channel_id = ChannelId::from_proto(request.channel_id);
2940 let member_id = UserId::from_proto(request.user_id);
2941 let result = db
2942 .set_channel_member_role(
2943 channel_id,
2944 session.user_id(),
2945 member_id,
2946 request.role().into(),
2947 )
2948 .await?;
2949
2950 match result {
2951 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2952 let mut connection_pool = session.connection_pool().await;
2953 notify_membership_updated(
2954 &mut connection_pool,
2955 membership_update,
2956 member_id,
2957 &session.peer,
2958 )
2959 }
2960 db::SetMemberRoleResult::InviteUpdated(channel) => {
2961 let update = proto::UpdateChannels {
2962 channel_invitations: vec![channel.to_proto()],
2963 ..Default::default()
2964 };
2965
2966 for connection_id in session
2967 .connection_pool()
2968 .await
2969 .user_connection_ids(member_id)
2970 {
2971 session.peer.send(connection_id, update.clone())?;
2972 }
2973 }
2974 }
2975
2976 response.send(proto::Ack {})?;
2977 Ok(())
2978}
2979
2980/// Change the name of a channel
2981async fn rename_channel(
2982 request: proto::RenameChannel,
2983 response: Response<proto::RenameChannel>,
2984 session: Session,
2985) -> Result<()> {
2986 let db = session.db().await;
2987 let channel_id = ChannelId::from_proto(request.channel_id);
2988 let channel_model = db
2989 .rename_channel(channel_id, session.user_id(), &request.name)
2990 .await?;
2991 let root_id = channel_model.root_id();
2992 let channel = Channel::from_model(channel_model);
2993
2994 response.send(proto::RenameChannelResponse {
2995 channel: Some(channel.to_proto()),
2996 })?;
2997
2998 let connection_pool = session.connection_pool().await;
2999 let update = proto::UpdateChannels {
3000 channels: vec![channel.to_proto()],
3001 ..Default::default()
3002 };
3003 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3004 if role.can_see_channel(channel.visibility) {
3005 session.peer.send(connection_id, update.clone())?;
3006 }
3007 }
3008
3009 Ok(())
3010}
3011
3012/// Move a channel to a new parent.
3013async fn move_channel(
3014 request: proto::MoveChannel,
3015 response: Response<proto::MoveChannel>,
3016 session: Session,
3017) -> Result<()> {
3018 let channel_id = ChannelId::from_proto(request.channel_id);
3019 let to = ChannelId::from_proto(request.to);
3020
3021 let (root_id, channels) = session
3022 .db()
3023 .await
3024 .move_channel(channel_id, to, session.user_id())
3025 .await?;
3026
3027 let connection_pool = session.connection_pool().await;
3028 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3029 let channels = channels
3030 .iter()
3031 .filter_map(|channel| {
3032 if role.can_see_channel(channel.visibility) {
3033 Some(channel.to_proto())
3034 } else {
3035 None
3036 }
3037 })
3038 .collect::<Vec<_>>();
3039 if channels.is_empty() {
3040 continue;
3041 }
3042
3043 let update = proto::UpdateChannels {
3044 channels,
3045 ..Default::default()
3046 };
3047
3048 session.peer.send(connection_id, update.clone())?;
3049 }
3050
3051 response.send(Ack {})?;
3052 Ok(())
3053}
3054
3055/// Get the list of channel members
3056async fn get_channel_members(
3057 request: proto::GetChannelMembers,
3058 response: Response<proto::GetChannelMembers>,
3059 session: Session,
3060) -> Result<()> {
3061 let db = session.db().await;
3062 let channel_id = ChannelId::from_proto(request.channel_id);
3063 let limit = if request.limit == 0 {
3064 u16::MAX as u64
3065 } else {
3066 request.limit
3067 };
3068 let (members, users) = db
3069 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3070 .await?;
3071 response.send(proto::GetChannelMembersResponse { members, users })?;
3072 Ok(())
3073}
3074
3075/// Accept or decline a channel invitation.
3076async fn respond_to_channel_invite(
3077 request: proto::RespondToChannelInvite,
3078 response: Response<proto::RespondToChannelInvite>,
3079 session: Session,
3080) -> Result<()> {
3081 let db = session.db().await;
3082 let channel_id = ChannelId::from_proto(request.channel_id);
3083 let RespondToChannelInvite {
3084 membership_update,
3085 notifications,
3086 } = db
3087 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3088 .await?;
3089
3090 let mut connection_pool = session.connection_pool().await;
3091 if let Some(membership_update) = membership_update {
3092 notify_membership_updated(
3093 &mut connection_pool,
3094 membership_update,
3095 session.user_id(),
3096 &session.peer,
3097 );
3098 } else {
3099 let update = proto::UpdateChannels {
3100 remove_channel_invitations: vec![channel_id.to_proto()],
3101 ..Default::default()
3102 };
3103
3104 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3105 session.peer.send(connection_id, update.clone())?;
3106 }
3107 };
3108
3109 send_notifications(&connection_pool, &session.peer, notifications);
3110
3111 response.send(proto::Ack {})?;
3112
3113 Ok(())
3114}
3115
3116/// Join the channels' room
3117async fn join_channel(
3118 request: proto::JoinChannel,
3119 response: Response<proto::JoinChannel>,
3120 session: Session,
3121) -> Result<()> {
3122 let channel_id = ChannelId::from_proto(request.channel_id);
3123 join_channel_internal(channel_id, Box::new(response), session).await
3124}
3125
3126trait JoinChannelInternalResponse {
3127 fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
3128}
3129impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
3130 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3131 Response::<proto::JoinChannel>::send(self, result)
3132 }
3133}
3134impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
3135 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3136 Response::<proto::JoinRoom>::send(self, result)
3137 }
3138}
3139
3140async fn join_channel_internal(
3141 channel_id: ChannelId,
3142 response: Box<impl JoinChannelInternalResponse>,
3143 session: Session,
3144) -> Result<()> {
3145 let joined_room = {
3146 let mut db = session.db().await;
3147 // If zed quits without leaving the room, and the user re-opens zed before the
3148 // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
3149 // room they were in.
3150 if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
3151 tracing::info!(
3152 stale_connection_id = %connection,
3153 "cleaning up stale connection",
3154 );
3155 drop(db);
3156 leave_room_for_session(&session, connection).await?;
3157 db = session.db().await;
3158 }
3159
3160 let (joined_room, membership_updated, role) = db
3161 .join_channel(channel_id, session.user_id(), session.connection_id)
3162 .await?;
3163
3164 let live_kit_connection_info =
3165 session
3166 .app_state
3167 .livekit_client
3168 .as_ref()
3169 .and_then(|live_kit| {
3170 let (can_publish, token) = if role == ChannelRole::Guest {
3171 (
3172 false,
3173 live_kit
3174 .guest_token(
3175 &joined_room.room.livekit_room,
3176 &session.user_id().to_string(),
3177 )
3178 .trace_err()?,
3179 )
3180 } else {
3181 (
3182 true,
3183 live_kit
3184 .room_token(
3185 &joined_room.room.livekit_room,
3186 &session.user_id().to_string(),
3187 )
3188 .trace_err()?,
3189 )
3190 };
3191
3192 Some(LiveKitConnectionInfo {
3193 server_url: live_kit.url().into(),
3194 token,
3195 can_publish,
3196 })
3197 });
3198
3199 response.send(proto::JoinRoomResponse {
3200 room: Some(joined_room.room.clone()),
3201 channel_id: joined_room
3202 .channel
3203 .as_ref()
3204 .map(|channel| channel.id.to_proto()),
3205 live_kit_connection_info,
3206 })?;
3207
3208 let mut connection_pool = session.connection_pool().await;
3209 if let Some(membership_updated) = membership_updated {
3210 notify_membership_updated(
3211 &mut connection_pool,
3212 membership_updated,
3213 session.user_id(),
3214 &session.peer,
3215 );
3216 }
3217
3218 room_updated(&joined_room.room, &session.peer);
3219
3220 joined_room
3221 };
3222
3223 channel_updated(
3224 &joined_room
3225 .channel
3226 .ok_or_else(|| anyhow!("channel not returned"))?,
3227 &joined_room.room,
3228 &session.peer,
3229 &*session.connection_pool().await,
3230 );
3231
3232 update_user_contacts(session.user_id(), &session).await?;
3233 Ok(())
3234}
3235
3236/// Start editing the channel notes
3237async fn join_channel_buffer(
3238 request: proto::JoinChannelBuffer,
3239 response: Response<proto::JoinChannelBuffer>,
3240 session: Session,
3241) -> Result<()> {
3242 let db = session.db().await;
3243 let channel_id = ChannelId::from_proto(request.channel_id);
3244
3245 let open_response = db
3246 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3247 .await?;
3248
3249 let collaborators = open_response.collaborators.clone();
3250 response.send(open_response)?;
3251
3252 let update = UpdateChannelBufferCollaborators {
3253 channel_id: channel_id.to_proto(),
3254 collaborators: collaborators.clone(),
3255 };
3256 channel_buffer_updated(
3257 session.connection_id,
3258 collaborators
3259 .iter()
3260 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3261 &update,
3262 &session.peer,
3263 );
3264
3265 Ok(())
3266}
3267
3268/// Edit the channel notes
3269async fn update_channel_buffer(
3270 request: proto::UpdateChannelBuffer,
3271 session: Session,
3272) -> Result<()> {
3273 let db = session.db().await;
3274 let channel_id = ChannelId::from_proto(request.channel_id);
3275
3276 let (collaborators, epoch, version) = db
3277 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3278 .await?;
3279
3280 channel_buffer_updated(
3281 session.connection_id,
3282 collaborators.clone(),
3283 &proto::UpdateChannelBuffer {
3284 channel_id: channel_id.to_proto(),
3285 operations: request.operations,
3286 },
3287 &session.peer,
3288 );
3289
3290 let pool = &*session.connection_pool().await;
3291
3292 let non_collaborators =
3293 pool.channel_connection_ids(channel_id)
3294 .filter_map(|(connection_id, _)| {
3295 if collaborators.contains(&connection_id) {
3296 None
3297 } else {
3298 Some(connection_id)
3299 }
3300 });
3301
3302 broadcast(None, non_collaborators, |peer_id| {
3303 session.peer.send(
3304 peer_id,
3305 proto::UpdateChannels {
3306 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3307 channel_id: channel_id.to_proto(),
3308 epoch: epoch as u64,
3309 version: version.clone(),
3310 }],
3311 ..Default::default()
3312 },
3313 )
3314 });
3315
3316 Ok(())
3317}
3318
3319/// Rejoin the channel notes after a connection blip
3320async fn rejoin_channel_buffers(
3321 request: proto::RejoinChannelBuffers,
3322 response: Response<proto::RejoinChannelBuffers>,
3323 session: Session,
3324) -> Result<()> {
3325 let db = session.db().await;
3326 let buffers = db
3327 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3328 .await?;
3329
3330 for rejoined_buffer in &buffers {
3331 let collaborators_to_notify = rejoined_buffer
3332 .buffer
3333 .collaborators
3334 .iter()
3335 .filter_map(|c| Some(c.peer_id?.into()));
3336 channel_buffer_updated(
3337 session.connection_id,
3338 collaborators_to_notify,
3339 &proto::UpdateChannelBufferCollaborators {
3340 channel_id: rejoined_buffer.buffer.channel_id,
3341 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3342 },
3343 &session.peer,
3344 );
3345 }
3346
3347 response.send(proto::RejoinChannelBuffersResponse {
3348 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3349 })?;
3350
3351 Ok(())
3352}
3353
3354/// Stop editing the channel notes
3355async fn leave_channel_buffer(
3356 request: proto::LeaveChannelBuffer,
3357 response: Response<proto::LeaveChannelBuffer>,
3358 session: Session,
3359) -> Result<()> {
3360 let db = session.db().await;
3361 let channel_id = ChannelId::from_proto(request.channel_id);
3362
3363 let left_buffer = db
3364 .leave_channel_buffer(channel_id, session.connection_id)
3365 .await?;
3366
3367 response.send(Ack {})?;
3368
3369 channel_buffer_updated(
3370 session.connection_id,
3371 left_buffer.connections,
3372 &proto::UpdateChannelBufferCollaborators {
3373 channel_id: channel_id.to_proto(),
3374 collaborators: left_buffer.collaborators,
3375 },
3376 &session.peer,
3377 );
3378
3379 Ok(())
3380}
3381
3382fn channel_buffer_updated<T: EnvelopedMessage>(
3383 sender_id: ConnectionId,
3384 collaborators: impl IntoIterator<Item = ConnectionId>,
3385 message: &T,
3386 peer: &Peer,
3387) {
3388 broadcast(Some(sender_id), collaborators, |peer_id| {
3389 peer.send(peer_id, message.clone())
3390 });
3391}
3392
3393fn send_notifications(
3394 connection_pool: &ConnectionPool,
3395 peer: &Peer,
3396 notifications: db::NotificationBatch,
3397) {
3398 for (user_id, notification) in notifications {
3399 for connection_id in connection_pool.user_connection_ids(user_id) {
3400 if let Err(error) = peer.send(
3401 connection_id,
3402 proto::AddNotification {
3403 notification: Some(notification.clone()),
3404 },
3405 ) {
3406 tracing::error!(
3407 "failed to send notification to {:?} {}",
3408 connection_id,
3409 error
3410 );
3411 }
3412 }
3413 }
3414}
3415
3416/// Send a message to the channel
3417async fn send_channel_message(
3418 request: proto::SendChannelMessage,
3419 response: Response<proto::SendChannelMessage>,
3420 session: Session,
3421) -> Result<()> {
3422 // Validate the message body.
3423 let body = request.body.trim().to_string();
3424 if body.len() > MAX_MESSAGE_LEN {
3425 return Err(anyhow!("message is too long"))?;
3426 }
3427 if body.is_empty() {
3428 return Err(anyhow!("message can't be blank"))?;
3429 }
3430
3431 // TODO: adjust mentions if body is trimmed
3432
3433 let timestamp = OffsetDateTime::now_utc();
3434 let nonce = request
3435 .nonce
3436 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3437
3438 let channel_id = ChannelId::from_proto(request.channel_id);
3439 let CreatedChannelMessage {
3440 message_id,
3441 participant_connection_ids,
3442 notifications,
3443 } = session
3444 .db()
3445 .await
3446 .create_channel_message(
3447 channel_id,
3448 session.user_id(),
3449 &body,
3450 &request.mentions,
3451 timestamp,
3452 nonce.clone().into(),
3453 request.reply_to_message_id.map(MessageId::from_proto),
3454 )
3455 .await?;
3456
3457 let message = proto::ChannelMessage {
3458 sender_id: session.user_id().to_proto(),
3459 id: message_id.to_proto(),
3460 body,
3461 mentions: request.mentions,
3462 timestamp: timestamp.unix_timestamp() as u64,
3463 nonce: Some(nonce),
3464 reply_to_message_id: request.reply_to_message_id,
3465 edited_at: None,
3466 };
3467 broadcast(
3468 Some(session.connection_id),
3469 participant_connection_ids.clone(),
3470 |connection| {
3471 session.peer.send(
3472 connection,
3473 proto::ChannelMessageSent {
3474 channel_id: channel_id.to_proto(),
3475 message: Some(message.clone()),
3476 },
3477 )
3478 },
3479 );
3480 response.send(proto::SendChannelMessageResponse {
3481 message: Some(message),
3482 })?;
3483
3484 let pool = &*session.connection_pool().await;
3485 let non_participants =
3486 pool.channel_connection_ids(channel_id)
3487 .filter_map(|(connection_id, _)| {
3488 if participant_connection_ids.contains(&connection_id) {
3489 None
3490 } else {
3491 Some(connection_id)
3492 }
3493 });
3494 broadcast(None, non_participants, |peer_id| {
3495 session.peer.send(
3496 peer_id,
3497 proto::UpdateChannels {
3498 latest_channel_message_ids: vec![proto::ChannelMessageId {
3499 channel_id: channel_id.to_proto(),
3500 message_id: message_id.to_proto(),
3501 }],
3502 ..Default::default()
3503 },
3504 )
3505 });
3506 send_notifications(pool, &session.peer, notifications);
3507
3508 Ok(())
3509}
3510
3511/// Delete a channel message
3512async fn remove_channel_message(
3513 request: proto::RemoveChannelMessage,
3514 response: Response<proto::RemoveChannelMessage>,
3515 session: Session,
3516) -> Result<()> {
3517 let channel_id = ChannelId::from_proto(request.channel_id);
3518 let message_id = MessageId::from_proto(request.message_id);
3519 let (connection_ids, existing_notification_ids) = session
3520 .db()
3521 .await
3522 .remove_channel_message(channel_id, message_id, session.user_id())
3523 .await?;
3524
3525 broadcast(
3526 Some(session.connection_id),
3527 connection_ids,
3528 move |connection| {
3529 session.peer.send(connection, request.clone())?;
3530
3531 for notification_id in &existing_notification_ids {
3532 session.peer.send(
3533 connection,
3534 proto::DeleteNotification {
3535 notification_id: (*notification_id).to_proto(),
3536 },
3537 )?;
3538 }
3539
3540 Ok(())
3541 },
3542 );
3543 response.send(proto::Ack {})?;
3544 Ok(())
3545}
3546
3547async fn update_channel_message(
3548 request: proto::UpdateChannelMessage,
3549 response: Response<proto::UpdateChannelMessage>,
3550 session: Session,
3551) -> Result<()> {
3552 let channel_id = ChannelId::from_proto(request.channel_id);
3553 let message_id = MessageId::from_proto(request.message_id);
3554 let updated_at = OffsetDateTime::now_utc();
3555 let UpdatedChannelMessage {
3556 message_id,
3557 participant_connection_ids,
3558 notifications,
3559 reply_to_message_id,
3560 timestamp,
3561 deleted_mention_notification_ids,
3562 updated_mention_notifications,
3563 } = session
3564 .db()
3565 .await
3566 .update_channel_message(
3567 channel_id,
3568 message_id,
3569 session.user_id(),
3570 request.body.as_str(),
3571 &request.mentions,
3572 updated_at,
3573 )
3574 .await?;
3575
3576 let nonce = request
3577 .nonce
3578 .clone()
3579 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3580
3581 let message = proto::ChannelMessage {
3582 sender_id: session.user_id().to_proto(),
3583 id: message_id.to_proto(),
3584 body: request.body.clone(),
3585 mentions: request.mentions.clone(),
3586 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3587 nonce: Some(nonce),
3588 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3589 edited_at: Some(updated_at.unix_timestamp() as u64),
3590 };
3591
3592 response.send(proto::Ack {})?;
3593
3594 let pool = &*session.connection_pool().await;
3595 broadcast(
3596 Some(session.connection_id),
3597 participant_connection_ids,
3598 |connection| {
3599 session.peer.send(
3600 connection,
3601 proto::ChannelMessageUpdate {
3602 channel_id: channel_id.to_proto(),
3603 message: Some(message.clone()),
3604 },
3605 )?;
3606
3607 for notification_id in &deleted_mention_notification_ids {
3608 session.peer.send(
3609 connection,
3610 proto::DeleteNotification {
3611 notification_id: (*notification_id).to_proto(),
3612 },
3613 )?;
3614 }
3615
3616 for notification in &updated_mention_notifications {
3617 session.peer.send(
3618 connection,
3619 proto::UpdateNotification {
3620 notification: Some(notification.clone()),
3621 },
3622 )?;
3623 }
3624
3625 Ok(())
3626 },
3627 );
3628
3629 send_notifications(pool, &session.peer, notifications);
3630
3631 Ok(())
3632}
3633
3634/// Mark a channel message as read
3635async fn acknowledge_channel_message(
3636 request: proto::AckChannelMessage,
3637 session: Session,
3638) -> Result<()> {
3639 let channel_id = ChannelId::from_proto(request.channel_id);
3640 let message_id = MessageId::from_proto(request.message_id);
3641 let notifications = session
3642 .db()
3643 .await
3644 .observe_channel_message(channel_id, session.user_id(), message_id)
3645 .await?;
3646 send_notifications(
3647 &*session.connection_pool().await,
3648 &session.peer,
3649 notifications,
3650 );
3651 Ok(())
3652}
3653
3654/// Mark a buffer version as synced
3655async fn acknowledge_buffer_version(
3656 request: proto::AckBufferOperation,
3657 session: Session,
3658) -> Result<()> {
3659 let buffer_id = BufferId::from_proto(request.buffer_id);
3660 session
3661 .db()
3662 .await
3663 .observe_buffer_version(
3664 buffer_id,
3665 session.user_id(),
3666 request.epoch as i32,
3667 &request.version,
3668 )
3669 .await?;
3670 Ok(())
3671}
3672
3673/// Get a Supermaven API key for the user
3674async fn get_supermaven_api_key(
3675 _request: proto::GetSupermavenApiKey,
3676 response: Response<proto::GetSupermavenApiKey>,
3677 session: Session,
3678) -> Result<()> {
3679 let user_id: String = session.user_id().to_string();
3680 if !session.is_staff() {
3681 return Err(anyhow!("supermaven not enabled for this account"))?;
3682 }
3683
3684 let email = session
3685 .email()
3686 .ok_or_else(|| anyhow!("user must have an email"))?;
3687
3688 let supermaven_admin_api = session
3689 .supermaven_client
3690 .as_ref()
3691 .ok_or_else(|| anyhow!("supermaven not configured"))?;
3692
3693 let result = supermaven_admin_api
3694 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3695 .await?;
3696
3697 response.send(proto::GetSupermavenApiKeyResponse {
3698 api_key: result.api_key,
3699 })?;
3700
3701 Ok(())
3702}
3703
3704/// Start receiving chat updates for a channel
3705async fn join_channel_chat(
3706 request: proto::JoinChannelChat,
3707 response: Response<proto::JoinChannelChat>,
3708 session: Session,
3709) -> Result<()> {
3710 let channel_id = ChannelId::from_proto(request.channel_id);
3711
3712 let db = session.db().await;
3713 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3714 .await?;
3715 let messages = db
3716 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3717 .await?;
3718 response.send(proto::JoinChannelChatResponse {
3719 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3720 messages,
3721 })?;
3722 Ok(())
3723}
3724
3725/// Stop receiving chat updates for a channel
3726async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3727 let channel_id = ChannelId::from_proto(request.channel_id);
3728 session
3729 .db()
3730 .await
3731 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3732 .await?;
3733 Ok(())
3734}
3735
3736/// Retrieve the chat history for a channel
3737async fn get_channel_messages(
3738 request: proto::GetChannelMessages,
3739 response: Response<proto::GetChannelMessages>,
3740 session: Session,
3741) -> Result<()> {
3742 let channel_id = ChannelId::from_proto(request.channel_id);
3743 let messages = session
3744 .db()
3745 .await
3746 .get_channel_messages(
3747 channel_id,
3748 session.user_id(),
3749 MESSAGE_COUNT_PER_PAGE,
3750 Some(MessageId::from_proto(request.before_message_id)),
3751 )
3752 .await?;
3753 response.send(proto::GetChannelMessagesResponse {
3754 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3755 messages,
3756 })?;
3757 Ok(())
3758}
3759
3760/// Retrieve specific chat messages
3761async fn get_channel_messages_by_id(
3762 request: proto::GetChannelMessagesById,
3763 response: Response<proto::GetChannelMessagesById>,
3764 session: Session,
3765) -> Result<()> {
3766 let message_ids = request
3767 .message_ids
3768 .iter()
3769 .map(|id| MessageId::from_proto(*id))
3770 .collect::<Vec<_>>();
3771 let messages = session
3772 .db()
3773 .await
3774 .get_channel_messages_by_id(session.user_id(), &message_ids)
3775 .await?;
3776 response.send(proto::GetChannelMessagesResponse {
3777 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3778 messages,
3779 })?;
3780 Ok(())
3781}
3782
3783/// Retrieve the current users notifications
3784async fn get_notifications(
3785 request: proto::GetNotifications,
3786 response: Response<proto::GetNotifications>,
3787 session: Session,
3788) -> Result<()> {
3789 let notifications = session
3790 .db()
3791 .await
3792 .get_notifications(
3793 session.user_id(),
3794 NOTIFICATION_COUNT_PER_PAGE,
3795 request.before_id.map(db::NotificationId::from_proto),
3796 )
3797 .await?;
3798 response.send(proto::GetNotificationsResponse {
3799 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3800 notifications,
3801 })?;
3802 Ok(())
3803}
3804
3805/// Mark notifications as read
3806async fn mark_notification_as_read(
3807 request: proto::MarkNotificationRead,
3808 response: Response<proto::MarkNotificationRead>,
3809 session: Session,
3810) -> Result<()> {
3811 let database = &session.db().await;
3812 let notifications = database
3813 .mark_notification_as_read_by_id(
3814 session.user_id(),
3815 NotificationId::from_proto(request.notification_id),
3816 )
3817 .await?;
3818 send_notifications(
3819 &*session.connection_pool().await,
3820 &session.peer,
3821 notifications,
3822 );
3823 response.send(proto::Ack {})?;
3824 Ok(())
3825}
3826
3827/// Get the current users information
3828async fn get_private_user_info(
3829 _request: proto::GetPrivateUserInfo,
3830 response: Response<proto::GetPrivateUserInfo>,
3831 session: Session,
3832) -> Result<()> {
3833 let db = session.db().await;
3834
3835 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
3836 let user = db
3837 .get_user_by_id(session.user_id())
3838 .await?
3839 .ok_or_else(|| anyhow!("user not found"))?;
3840 let flags = db.get_user_flags(session.user_id()).await?;
3841
3842 response.send(proto::GetPrivateUserInfoResponse {
3843 metrics_id,
3844 staff: user.admin,
3845 flags,
3846 accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
3847 })?;
3848 Ok(())
3849}
3850
3851/// Accept the terms of service (tos) on behalf of the current user
3852async fn accept_terms_of_service(
3853 _request: proto::AcceptTermsOfService,
3854 response: Response<proto::AcceptTermsOfService>,
3855 session: Session,
3856) -> Result<()> {
3857 let db = session.db().await;
3858
3859 let accepted_tos_at = Utc::now();
3860 db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
3861 .await?;
3862
3863 response.send(proto::AcceptTermsOfServiceResponse {
3864 accepted_tos_at: accepted_tos_at.timestamp() as u64,
3865 })?;
3866 Ok(())
3867}
3868
/// The minimum account age an account must have in order to use the LLM service.
///
/// Currently 30 days.
pub const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
3871
3872async fn get_llm_api_token(
3873 _request: proto::GetLlmToken,
3874 response: Response<proto::GetLlmToken>,
3875 session: Session,
3876) -> Result<()> {
3877 let db = session.db().await;
3878
3879 let flags = db.get_user_flags(session.user_id()).await?;
3880 let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
3881
3882 if !session.is_staff() && !has_language_models_feature_flag {
3883 Err(anyhow!("permission denied"))?
3884 }
3885
3886 let user_id = session.user_id();
3887 let user = db
3888 .get_user_by_id(user_id)
3889 .await?
3890 .ok_or_else(|| anyhow!("user {} not found", user_id))?;
3891
3892 if user.accepted_tos_at.is_none() {
3893 Err(anyhow!("terms of service not accepted"))?
3894 }
3895
3896 let has_legacy_llm_subscription = session.has_llm_subscription(&db).await?;
3897 let billing_subscription = db.get_active_billing_subscription(user.id).await?;
3898 let billing_preferences = db.get_billing_preferences(user.id).await?;
3899
3900 let token = LlmTokenClaims::create(
3901 &user,
3902 session.is_staff(),
3903 billing_preferences,
3904 &flags,
3905 has_legacy_llm_subscription,
3906 billing_subscription,
3907 session.system_id.clone(),
3908 &session.app_state.config,
3909 )?;
3910 response.send(proto::GetLlmTokenResponse { token })?;
3911 Ok(())
3912}
3913
3914fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
3915 let message = match message {
3916 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
3917 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
3918 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
3919 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
3920 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3921 code: frame.code.into(),
3922 reason: frame.reason.as_str().to_owned().into(),
3923 })),
3924 // We should never receive a frame while reading the message, according
3925 // to the `tungstenite` maintainers:
3926 //
3927 // > It cannot occur when you read messages from the WebSocket, but it
3928 // > can be used when you want to send the raw frames (e.g. you want to
3929 // > send the frames to the WebSocket without composing the full message first).
3930 // >
3931 // > — https://github.com/snapview/tungstenite-rs/issues/268
3932 TungsteniteMessage::Frame(_) => {
3933 bail!("received an unexpected frame while reading the message")
3934 }
3935 };
3936
3937 Ok(message)
3938}
3939
3940fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3941 match message {
3942 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
3943 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
3944 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
3945 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
3946 AxumMessage::Close(frame) => {
3947 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3948 code: frame.code.into(),
3949 reason: frame.reason.as_ref().into(),
3950 }))
3951 }
3952 }
3953}
3954
3955fn notify_membership_updated(
3956 connection_pool: &mut ConnectionPool,
3957 result: MembershipUpdated,
3958 user_id: UserId,
3959 peer: &Peer,
3960) {
3961 for membership in &result.new_channels.channel_memberships {
3962 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3963 }
3964 for channel_id in &result.removed_channels {
3965 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
3966 }
3967
3968 let user_channels_update = proto::UpdateUserChannels {
3969 channel_memberships: result
3970 .new_channels
3971 .channel_memberships
3972 .iter()
3973 .map(|cm| proto::ChannelMembership {
3974 channel_id: cm.channel_id.to_proto(),
3975 role: cm.role.into(),
3976 })
3977 .collect(),
3978 ..Default::default()
3979 };
3980
3981 let mut update = build_channels_update(result.new_channels);
3982 update.delete_channels = result
3983 .removed_channels
3984 .into_iter()
3985 .map(|id| id.to_proto())
3986 .collect();
3987 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3988
3989 for connection_id in connection_pool.user_connection_ids(user_id) {
3990 peer.send(connection_id, user_channels_update.clone())
3991 .trace_err();
3992 peer.send(connection_id, update.clone()).trace_err();
3993 }
3994}
3995
3996fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3997 proto::UpdateUserChannels {
3998 channel_memberships: channels
3999 .channel_memberships
4000 .iter()
4001 .map(|m| proto::ChannelMembership {
4002 channel_id: m.channel_id.to_proto(),
4003 role: m.role.into(),
4004 })
4005 .collect(),
4006 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4007 observed_channel_message_id: channels.observed_channel_messages.clone(),
4008 }
4009}
4010
4011fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4012 let mut update = proto::UpdateChannels::default();
4013
4014 for channel in channels.channels {
4015 update.channels.push(channel.to_proto());
4016 }
4017
4018 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4019 update.latest_channel_message_ids = channels.latest_channel_messages;
4020
4021 for (channel_id, participants) in channels.channel_participants {
4022 update
4023 .channel_participants
4024 .push(proto::ChannelParticipants {
4025 channel_id: channel_id.to_proto(),
4026 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4027 });
4028 }
4029
4030 for channel in channels.invited_channels {
4031 update.channel_invitations.push(channel.to_proto());
4032 }
4033
4034 update
4035}
4036
4037fn build_initial_contacts_update(
4038 contacts: Vec<db::Contact>,
4039 pool: &ConnectionPool,
4040) -> proto::UpdateContacts {
4041 let mut update = proto::UpdateContacts::default();
4042
4043 for contact in contacts {
4044 match contact {
4045 db::Contact::Accepted { user_id, busy } => {
4046 update.contacts.push(contact_for_user(user_id, busy, pool));
4047 }
4048 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4049 db::Contact::Incoming { user_id } => {
4050 update
4051 .incoming_requests
4052 .push(proto::IncomingContactRequest {
4053 requester_id: user_id.to_proto(),
4054 })
4055 }
4056 }
4057 }
4058
4059 update
4060}
4061
4062fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4063 proto::Contact {
4064 user_id: user_id.to_proto(),
4065 online: pool.is_user_online(user_id),
4066 busy,
4067 }
4068}
4069
4070fn room_updated(room: &proto::Room, peer: &Peer) {
4071 broadcast(
4072 None,
4073 room.participants
4074 .iter()
4075 .filter_map(|participant| Some(participant.peer_id?.into())),
4076 |peer_id| {
4077 peer.send(
4078 peer_id,
4079 proto::RoomUpdated {
4080 room: Some(room.clone()),
4081 },
4082 )
4083 },
4084 );
4085}
4086
4087fn channel_updated(
4088 channel: &db::channel::Model,
4089 room: &proto::Room,
4090 peer: &Peer,
4091 pool: &ConnectionPool,
4092) {
4093 let participants = room
4094 .participants
4095 .iter()
4096 .map(|p| p.user_id)
4097 .collect::<Vec<_>>();
4098
4099 broadcast(
4100 None,
4101 pool.channel_connection_ids(channel.root_id())
4102 .filter_map(|(channel_id, role)| {
4103 role.can_see_channel(channel.visibility)
4104 .then_some(channel_id)
4105 }),
4106 |peer_id| {
4107 peer.send(
4108 peer_id,
4109 proto::UpdateChannels {
4110 channel_participants: vec![proto::ChannelParticipants {
4111 channel_id: channel.id.to_proto(),
4112 participant_user_ids: participants.clone(),
4113 }],
4114 ..Default::default()
4115 },
4116 )
4117 },
4118 );
4119}
4120
4121async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4122 let db = session.db().await;
4123
4124 let contacts = db.get_contacts(user_id).await?;
4125 let busy = db.is_user_busy(user_id).await?;
4126
4127 let pool = session.connection_pool().await;
4128 let updated_contact = contact_for_user(user_id, busy, &pool);
4129 for contact in contacts {
4130 if let db::Contact::Accepted {
4131 user_id: contact_user_id,
4132 ..
4133 } = contact
4134 {
4135 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4136 session
4137 .peer
4138 .send(
4139 contact_conn_id,
4140 proto::UpdateContacts {
4141 contacts: vec![updated_contact.clone()],
4142 remove_contacts: Default::default(),
4143 incoming_requests: Default::default(),
4144 remove_incoming_requests: Default::default(),
4145 outgoing_requests: Default::default(),
4146 remove_outgoing_requests: Default::default(),
4147 },
4148 )
4149 .trace_err();
4150 }
4151 }
4152 }
4153 Ok(())
4154}
4155
4156async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
4157 let mut contacts_to_update = HashSet::default();
4158
4159 let room_id;
4160 let canceled_calls_to_user_ids;
4161 let livekit_room;
4162 let delete_livekit_room;
4163 let room;
4164 let channel;
4165
4166 if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
4167 contacts_to_update.insert(session.user_id());
4168
4169 for project in left_room.left_projects.values() {
4170 project_left(project, session);
4171 }
4172
4173 room_id = RoomId::from_proto(left_room.room.id);
4174 canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
4175 livekit_room = mem::take(&mut left_room.room.livekit_room);
4176 delete_livekit_room = left_room.deleted;
4177 room = mem::take(&mut left_room.room);
4178 channel = mem::take(&mut left_room.channel);
4179
4180 room_updated(&room, &session.peer);
4181 } else {
4182 return Ok(());
4183 }
4184
4185 if let Some(channel) = channel {
4186 channel_updated(
4187 &channel,
4188 &room,
4189 &session.peer,
4190 &*session.connection_pool().await,
4191 );
4192 }
4193
4194 {
4195 let pool = session.connection_pool().await;
4196 for canceled_user_id in canceled_calls_to_user_ids {
4197 for connection_id in pool.user_connection_ids(canceled_user_id) {
4198 session
4199 .peer
4200 .send(
4201 connection_id,
4202 proto::CallCanceled {
4203 room_id: room_id.to_proto(),
4204 },
4205 )
4206 .trace_err();
4207 }
4208 contacts_to_update.insert(canceled_user_id);
4209 }
4210 }
4211
4212 for contact_user_id in contacts_to_update {
4213 update_user_contacts(contact_user_id, session).await?;
4214 }
4215
4216 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
4217 live_kit
4218 .remove_participant(livekit_room.clone(), session.user_id().to_string())
4219 .await
4220 .trace_err();
4221
4222 if delete_livekit_room {
4223 live_kit.delete_room(livekit_room).await.trace_err();
4224 }
4225 }
4226
4227 Ok(())
4228}
4229
4230async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4231 let left_channel_buffers = session
4232 .db()
4233 .await
4234 .leave_channel_buffers(session.connection_id)
4235 .await?;
4236
4237 for left_buffer in left_channel_buffers {
4238 channel_buffer_updated(
4239 session.connection_id,
4240 left_buffer.connections,
4241 &proto::UpdateChannelBufferCollaborators {
4242 channel_id: left_buffer.channel_id.to_proto(),
4243 collaborators: left_buffer.collaborators,
4244 },
4245 &session.peer,
4246 );
4247 }
4248
4249 Ok(())
4250}
4251
4252fn project_left(project: &db::LeftProject, session: &Session) {
4253 for connection_id in &project.connection_ids {
4254 if project.should_unshare {
4255 session
4256 .peer
4257 .send(
4258 *connection_id,
4259 proto::UnshareProject {
4260 project_id: project.id.to_proto(),
4261 },
4262 )
4263 .trace_err();
4264 } else {
4265 session
4266 .peer
4267 .send(
4268 *connection_id,
4269 proto::RemoveProjectCollaborator {
4270 project_id: project.id.to_proto(),
4271 peer_id: Some(session.connection_id.into()),
4272 },
4273 )
4274 .trace_err();
4275 }
4276 }
4277}
4278
/// Extension trait for turning a fallible value into an `Option` without propagating
/// the error.
pub trait ResultExt {
    /// The success type of the underlying value.
    type Ok;

    /// Returns the success value, or `None` when there was an error.
    fn trace_err(self) -> Option<Self::Ok>;
}
4284
4285impl<T, E> ResultExt for Result<T, E>
4286where
4287 E: std::fmt::Debug,
4288{
4289 type Ok = T;
4290
4291 #[track_caller]
4292 fn trace_err(self) -> Option<T> {
4293 match self {
4294 Ok(value) => Some(value),
4295 Err(error) => {
4296 tracing::error!("{:?}", error);
4297 None
4298 }
4299 }
4300 }
4301}