1mod connection_pool;
2
3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
4use crate::db::billing_subscription::SubscriptionKind;
5use crate::llm::LlmTokenClaims;
6use crate::{
7 AppState, Error, Result, auth,
8 db::{
9 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
10 CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
11 NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
12 RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
13 },
14 executor::Executor,
15};
16use anyhow::{Context as _, anyhow, bail};
17use async_tungstenite::tungstenite::{
18 Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
19};
20use axum::{
21 Extension, Router, TypedHeader,
22 body::Body,
23 extract::{
24 ConnectInfo, WebSocketUpgrade,
25 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
26 },
27 headers::{Header, HeaderName},
28 http::StatusCode,
29 middleware,
30 response::IntoResponse,
31 routing::get,
32};
33use chrono::Utc;
34use collections::{HashMap, HashSet};
35pub use connection_pool::{ConnectionPool, ZedVersion};
36use core::fmt::{self, Debug, Formatter};
37use reqwest_client::ReqwestClient;
38use rpc::proto::split_repository_update;
39use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
40
41use futures::{
42 FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
43 stream::FuturesUnordered,
44};
45use prometheus::{IntGauge, register_int_gauge};
46use rpc::{
47 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
48 proto::{
49 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
50 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
51 },
52};
53use semantic_version::SemanticVersion;
54use serde::{Serialize, Serializer};
55use std::{
56 any::TypeId,
57 future::Future,
58 marker::PhantomData,
59 mem,
60 net::SocketAddr,
61 ops::{Deref, DerefMut},
62 rc::Rc,
63 sync::{
64 Arc, OnceLock,
65 atomic::{AtomicBool, Ordering::SeqCst},
66 },
67 time::{Duration, Instant},
68};
69use time::OffsetDateTime;
70use tokio::sync::{MutexGuard, Semaphore, watch};
71use tower::ServiceBuilder;
72use tracing::{
73 Instrument,
74 field::{self},
75 info_span, instrument,
76};
77
/// A 30-second reconnection timeout.
// NOTE(review): not referenced in the visible part of this file; presumably the
// grace period a dropped client gets before its state is released — confirm at the use site.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can clean up old resources.
/// How long `Server::start` sleeps before it clears resources left behind by stale servers.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// Paging and size limits. Their consumers are outside the visible part of this
// file; the names suggest channel-message and notification handlers — confirm there.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;

/// Type-erased message handler as stored in `Server::handlers`: it receives the
/// boxed envelope together with the `Session` it arrived on and returns a boxed
/// future that resolves once the message has been handled.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
89
90struct Response<R> {
91 peer: Arc<Peer>,
92 receipt: Receipt<R>,
93 responded: Arc<AtomicBool>,
94}
95
96impl<R: RequestMessage> Response<R> {
97 fn send(self, payload: R::Response) -> Result<()> {
98 self.responded.store(true, SeqCst);
99 self.peer.respond(self.receipt, payload)?;
100 Ok(())
101 }
102}
103
104#[derive(Clone, Debug)]
105pub enum Principal {
106 User(User),
107 Impersonated { user: User, admin: User },
108}
109
110impl Principal {
111 fn update_span(&self, span: &tracing::Span) {
112 match &self {
113 Principal::User(user) => {
114 span.record("user_id", user.id.0);
115 span.record("login", &user.github_login);
116 }
117 Principal::Impersonated { user, admin } => {
118 span.record("user_id", user.id.0);
119 span.record("login", &user.github_login);
120 span.record("impersonator", &admin.github_login);
121 }
122 }
123 }
124}
125
126#[derive(Clone)]
127struct Session {
128 principal: Principal,
129 connection_id: ConnectionId,
130 db: Arc<tokio::sync::Mutex<DbHandle>>,
131 peer: Arc<Peer>,
132 connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
133 app_state: Arc<AppState>,
134 supermaven_client: Option<Arc<SupermavenAdminApi>>,
135 /// The GeoIP country code for the user.
136 #[allow(unused)]
137 geoip_country_code: Option<String>,
138 system_id: Option<String>,
139 _executor: Executor,
140}
141
142impl Session {
143 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
144 #[cfg(test)]
145 tokio::task::yield_now().await;
146 let guard = self.db.lock().await;
147 #[cfg(test)]
148 tokio::task::yield_now().await;
149 guard
150 }
151
152 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
153 #[cfg(test)]
154 tokio::task::yield_now().await;
155 let guard = self.connection_pool.lock();
156 ConnectionPoolGuard {
157 guard,
158 _not_send: PhantomData,
159 }
160 }
161
162 fn is_staff(&self) -> bool {
163 match &self.principal {
164 Principal::User(user) => user.admin,
165 Principal::Impersonated { .. } => true,
166 }
167 }
168
169 pub async fn has_llm_subscription(
170 &self,
171 db: &MutexGuard<'_, DbHandle>,
172 ) -> anyhow::Result<bool> {
173 if self.is_staff() {
174 return Ok(true);
175 }
176
177 let user_id = self.user_id();
178
179 Ok(db.has_active_billing_subscription(user_id).await?)
180 }
181
182 pub async fn current_plan(&self, db: &MutexGuard<'_, DbHandle>) -> anyhow::Result<proto::Plan> {
183 let user_id = self.user_id();
184
185 let subscription = db.get_active_billing_subscription(user_id).await?;
186 let subscription_kind = subscription.and_then(|subscription| subscription.kind);
187
188 let plan = if let Some(subscription_kind) = subscription_kind {
189 match subscription_kind {
190 SubscriptionKind::ZedPro => proto::Plan::ZedPro,
191 SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
192 SubscriptionKind::ZedFree => proto::Plan::Free,
193 }
194 } else {
195 proto::Plan::Free
196 };
197
198 Ok(plan)
199 }
200
201 fn user_id(&self) -> UserId {
202 match &self.principal {
203 Principal::User(user) => user.id,
204 Principal::Impersonated { user, .. } => user.id,
205 }
206 }
207
208 pub fn email(&self) -> Option<String> {
209 match &self.principal {
210 Principal::User(user) => user.email_address.clone(),
211 Principal::Impersonated { user, .. } => user.email_address.clone(),
212 }
213 }
214}
215
216impl Debug for Session {
217 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
218 let mut result = f.debug_struct("Session");
219 match &self.principal {
220 Principal::User(user) => {
221 result.field("user", &user.github_login);
222 }
223 Principal::Impersonated { user, admin } => {
224 result.field("user", &user.github_login);
225 result.field("impersonator", &admin.github_login);
226 }
227 }
228 result.field("connection_id", &self.connection_id).finish()
229 }
230}
231
232struct DbHandle(Arc<Database>);
233
234impl Deref for DbHandle {
235 type Target = Database;
236
237 fn deref(&self) -> &Self::Target {
238 self.0.as_ref()
239 }
240}
241
/// The RPC server: owns the peer, the connection pool and the table of
/// registered message handlers.
pub struct Server {
    /// Current server id; replaced by the test-only `reset`.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    /// Pool of live connections, shared with every `Session`.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Handlers keyed by the `TypeId` of the message type they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Set to `true` by `teardown`; each connection subscribes to it.
    teardown: watch::Sender<bool>,
}
250
/// Wrapper around a locked [`ConnectionPool`] that dereferences to the pool.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    // `Rc` is not `Send`, so this marker makes the wrapper itself not `Send`.
    _not_send: PhantomData<Rc<()>>,
}
255
/// Serializable view of the server: its peer and the (locked) connection pool.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized via `serialize_deref`, i.e. as the `ConnectionPool` the guard points to.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
262
263pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
264where
265 S: Serializer,
266 T: Deref<Target = U>,
267 U: Serialize,
268{
269 Serialize::serialize(value.deref(), serializer)
270}
271
272impl Server {
273 pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
274 let mut server = Self {
275 id: parking_lot::Mutex::new(id),
276 peer: Peer::new(id.0 as u32),
277 app_state: app_state.clone(),
278 connection_pool: Default::default(),
279 handlers: Default::default(),
280 teardown: watch::channel(false).0,
281 };
282
283 server
284 .add_request_handler(ping)
285 .add_request_handler(create_room)
286 .add_request_handler(join_room)
287 .add_request_handler(rejoin_room)
288 .add_request_handler(leave_room)
289 .add_request_handler(set_room_participant_role)
290 .add_request_handler(call)
291 .add_request_handler(cancel_call)
292 .add_message_handler(decline_call)
293 .add_request_handler(update_participant_location)
294 .add_request_handler(share_project)
295 .add_message_handler(unshare_project)
296 .add_request_handler(join_project)
297 .add_message_handler(leave_project)
298 .add_request_handler(update_project)
299 .add_request_handler(update_worktree)
300 .add_request_handler(update_repository)
301 .add_request_handler(remove_repository)
302 .add_message_handler(start_language_server)
303 .add_message_handler(update_language_server)
304 .add_message_handler(update_diagnostic_summary)
305 .add_message_handler(update_worktree_settings)
306 .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
307 .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
308 .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
309 .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
310 .add_request_handler(forward_find_search_candidates_request)
311 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
312 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
313 .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
314 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
315 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
316 .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
317 .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
318 .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
319 .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
320 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
321 .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
322 .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
323 .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
324 .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
325 .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
326 .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
327 .add_request_handler(
328 forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
329 )
330 .add_request_handler(
331 forward_read_only_project_request::<proto::LanguageServerIdForName>,
332 )
333 .add_request_handler(
334 forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
335 )
336 .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
337 .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
338 .add_request_handler(
339 forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
340 )
341 .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
342 .add_request_handler(
343 forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
344 )
345 .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
346 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
347 .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
348 .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
349 .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
350 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
351 .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
352 .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
353 .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
354 .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
355 .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
356 .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
357 .add_request_handler(
358 forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
359 )
360 .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
361 .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
362 .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
363 .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
364 .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
365 .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
366 .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
367 .add_message_handler(create_buffer_for_peer)
368 .add_request_handler(update_buffer)
369 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
370 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
371 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
372 .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
373 .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
374 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
375 .add_request_handler(get_users)
376 .add_request_handler(fuzzy_search_users)
377 .add_request_handler(request_contact)
378 .add_request_handler(remove_contact)
379 .add_request_handler(respond_to_contact_request)
380 .add_message_handler(subscribe_to_channels)
381 .add_request_handler(create_channel)
382 .add_request_handler(delete_channel)
383 .add_request_handler(invite_channel_member)
384 .add_request_handler(remove_channel_member)
385 .add_request_handler(set_channel_member_role)
386 .add_request_handler(set_channel_visibility)
387 .add_request_handler(rename_channel)
388 .add_request_handler(join_channel_buffer)
389 .add_request_handler(leave_channel_buffer)
390 .add_message_handler(update_channel_buffer)
391 .add_request_handler(rejoin_channel_buffers)
392 .add_request_handler(get_channel_members)
393 .add_request_handler(respond_to_channel_invite)
394 .add_request_handler(join_channel)
395 .add_request_handler(join_channel_chat)
396 .add_message_handler(leave_channel_chat)
397 .add_request_handler(send_channel_message)
398 .add_request_handler(remove_channel_message)
399 .add_request_handler(update_channel_message)
400 .add_request_handler(get_channel_messages)
401 .add_request_handler(get_channel_messages_by_id)
402 .add_request_handler(get_notifications)
403 .add_request_handler(mark_notification_as_read)
404 .add_request_handler(move_channel)
405 .add_request_handler(follow)
406 .add_message_handler(unfollow)
407 .add_message_handler(update_followers)
408 .add_request_handler(get_private_user_info)
409 .add_request_handler(get_llm_api_token)
410 .add_request_handler(accept_terms_of_service)
411 .add_message_handler(acknowledge_channel_message)
412 .add_message_handler(acknowledge_buffer_version)
413 .add_request_handler(get_supermaven_api_key)
414 .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
415 .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
416 .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
417 .add_request_handler(forward_mutating_project_request::<proto::Stage>)
418 .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
419 .add_request_handler(forward_mutating_project_request::<proto::Commit>)
420 .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
421 .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
422 .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
423 .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
424 .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
425 .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
426 .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
427 .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
428 .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
429 .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
430 .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
431 .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
432 .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
433 .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
434 .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
435 .add_message_handler(update_context);
436
437 Arc::new(server)
438 }
439
440 pub async fn start(&self) -> Result<()> {
441 let server_id = *self.id.lock();
442 let app_state = self.app_state.clone();
443 let peer = self.peer.clone();
444 let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
445 let pool = self.connection_pool.clone();
446 let livekit_client = self.app_state.livekit_client.clone();
447
448 let span = info_span!("start server");
449 self.app_state.executor.spawn_detached(
450 async move {
451 tracing::info!("waiting for cleanup timeout");
452 timeout.await;
453 tracing::info!("cleanup timeout expired, retrieving stale rooms");
454 if let Some((room_ids, channel_ids)) = app_state
455 .db
456 .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
457 .await
458 .trace_err()
459 {
460 tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
461 tracing::info!(
462 stale_channel_buffer_count = channel_ids.len(),
463 "retrieved stale channel buffers"
464 );
465
466 for channel_id in channel_ids {
467 if let Some(refreshed_channel_buffer) = app_state
468 .db
469 .clear_stale_channel_buffer_collaborators(channel_id, server_id)
470 .await
471 .trace_err()
472 {
473 for connection_id in refreshed_channel_buffer.connection_ids {
474 peer.send(
475 connection_id,
476 proto::UpdateChannelBufferCollaborators {
477 channel_id: channel_id.to_proto(),
478 collaborators: refreshed_channel_buffer
479 .collaborators
480 .clone(),
481 },
482 )
483 .trace_err();
484 }
485 }
486 }
487
488 for room_id in room_ids {
489 let mut contacts_to_update = HashSet::default();
490 let mut canceled_calls_to_user_ids = Vec::new();
491 let mut livekit_room = String::new();
492 let mut delete_livekit_room = false;
493
494 if let Some(mut refreshed_room) = app_state
495 .db
496 .clear_stale_room_participants(room_id, server_id)
497 .await
498 .trace_err()
499 {
500 tracing::info!(
501 room_id = room_id.0,
502 new_participant_count = refreshed_room.room.participants.len(),
503 "refreshed room"
504 );
505 room_updated(&refreshed_room.room, &peer);
506 if let Some(channel) = refreshed_room.channel.as_ref() {
507 channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
508 }
509 contacts_to_update
510 .extend(refreshed_room.stale_participant_user_ids.iter().copied());
511 contacts_to_update
512 .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
513 canceled_calls_to_user_ids =
514 mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
515 livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
516 delete_livekit_room = refreshed_room.room.participants.is_empty();
517 }
518
519 {
520 let pool = pool.lock();
521 for canceled_user_id in canceled_calls_to_user_ids {
522 for connection_id in pool.user_connection_ids(canceled_user_id) {
523 peer.send(
524 connection_id,
525 proto::CallCanceled {
526 room_id: room_id.to_proto(),
527 },
528 )
529 .trace_err();
530 }
531 }
532 }
533
534 for user_id in contacts_to_update {
535 let busy = app_state.db.is_user_busy(user_id).await.trace_err();
536 let contacts = app_state.db.get_contacts(user_id).await.trace_err();
537 if let Some((busy, contacts)) = busy.zip(contacts) {
538 let pool = pool.lock();
539 let updated_contact = contact_for_user(user_id, busy, &pool);
540 for contact in contacts {
541 if let db::Contact::Accepted {
542 user_id: contact_user_id,
543 ..
544 } = contact
545 {
546 for contact_conn_id in
547 pool.user_connection_ids(contact_user_id)
548 {
549 peer.send(
550 contact_conn_id,
551 proto::UpdateContacts {
552 contacts: vec![updated_contact.clone()],
553 remove_contacts: Default::default(),
554 incoming_requests: Default::default(),
555 remove_incoming_requests: Default::default(),
556 outgoing_requests: Default::default(),
557 remove_outgoing_requests: Default::default(),
558 },
559 )
560 .trace_err();
561 }
562 }
563 }
564 }
565 }
566
567 if let Some(live_kit) = livekit_client.as_ref() {
568 if delete_livekit_room {
569 live_kit.delete_room(livekit_room).await.trace_err();
570 }
571 }
572 }
573 }
574
575 app_state
576 .db
577 .delete_stale_servers(&app_state.config.zed_environment, server_id)
578 .await
579 .trace_err();
580 }
581 .instrument(span),
582 );
583 Ok(())
584 }
585
586 pub fn teardown(&self) {
587 self.peer.teardown();
588 self.connection_pool.lock().reset();
589 let _ = self.teardown.send(true);
590 }
591
592 #[cfg(test)]
593 pub fn reset(&self, id: ServerId) {
594 self.teardown();
595 *self.id.lock() = id;
596 self.peer.reset(id.0 as u32);
597 let _ = self.teardown.send(false);
598 }
599
600 #[cfg(test)]
601 pub fn id(&self) -> ServerId {
602 *self.id.lock()
603 }
604
605 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
606 where
607 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
608 Fut: 'static + Send + Future<Output = Result<()>>,
609 M: EnvelopedMessage,
610 {
611 let prev_handler = self.handlers.insert(
612 TypeId::of::<M>(),
613 Box::new(move |envelope, session| {
614 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
615 let received_at = envelope.received_at;
616 tracing::info!("message received");
617 let start_time = Instant::now();
618 let future = (handler)(*envelope, session);
619 async move {
620 let result = future.await;
621 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
622 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
623 let queue_duration_ms = total_duration_ms - processing_duration_ms;
624 let payload_type = M::NAME;
625
626 match result {
627 Err(error) => {
628 tracing::error!(
629 ?error,
630 total_duration_ms,
631 processing_duration_ms,
632 queue_duration_ms,
633 payload_type,
634 "error handling message"
635 )
636 }
637 Ok(()) => tracing::info!(
638 total_duration_ms,
639 processing_duration_ms,
640 queue_duration_ms,
641 "finished handling message"
642 ),
643 }
644 }
645 .boxed()
646 }),
647 );
648 if prev_handler.is_some() {
649 panic!("registered a handler for the same message twice");
650 }
651 self
652 }
653
654 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
655 where
656 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
657 Fut: 'static + Send + Future<Output = Result<()>>,
658 M: EnvelopedMessage,
659 {
660 self.add_handler(move |envelope, session| handler(envelope.payload, session));
661 self
662 }
663
664 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
665 where
666 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
667 Fut: Send + Future<Output = Result<()>>,
668 M: RequestMessage,
669 {
670 let handler = Arc::new(handler);
671 self.add_handler(move |envelope, session| {
672 let receipt = envelope.receipt();
673 let handler = handler.clone();
674 async move {
675 let peer = session.peer.clone();
676 let responded = Arc::new(AtomicBool::default());
677 let response = Response {
678 peer: peer.clone(),
679 responded: responded.clone(),
680 receipt,
681 };
682 match (handler)(envelope.payload, response, session).await {
683 Ok(()) => {
684 if responded.load(std::sync::atomic::Ordering::SeqCst) {
685 Ok(())
686 } else {
687 Err(anyhow!("handler did not send a response"))?
688 }
689 }
690 Err(error) => {
691 let proto_err = match &error {
692 Error::Internal(err) => err.to_proto(),
693 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
694 };
695 peer.respond_with_error(receipt, proto_err)?;
696 Err(error)
697 }
698 }
699 }
700 })
701 }
702
703 pub fn handle_connection(
704 self: &Arc<Self>,
705 connection: Connection,
706 address: String,
707 principal: Principal,
708 zed_version: ZedVersion,
709 geoip_country_code: Option<String>,
710 system_id: Option<String>,
711 send_connection_id: Option<oneshot::Sender<ConnectionId>>,
712 executor: Executor,
713 ) -> impl Future<Output = ()> + use<> {
714 let this = self.clone();
715 let span = info_span!("handle connection", %address,
716 connection_id=field::Empty,
717 user_id=field::Empty,
718 login=field::Empty,
719 impersonator=field::Empty,
720 geoip_country_code=field::Empty
721 );
722 principal.update_span(&span);
723 if let Some(country_code) = geoip_country_code.as_ref() {
724 span.record("geoip_country_code", country_code);
725 }
726
727 let mut teardown = self.teardown.subscribe();
728 async move {
729 if *teardown.borrow() {
730 tracing::error!("server is tearing down");
731 return
732 }
733 let (connection_id, handle_io, mut incoming_rx) = this
734 .peer
735 .add_connection(connection, {
736 let executor = executor.clone();
737 move |duration| executor.sleep(duration)
738 });
739 tracing::Span::current().record("connection_id", format!("{}", connection_id));
740
741 tracing::info!("connection opened");
742
743 let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
744 let http_client = match ReqwestClient::user_agent(&user_agent) {
745 Ok(http_client) => Arc::new(http_client),
746 Err(error) => {
747 tracing::error!(?error, "failed to create HTTP client");
748 return;
749 }
750 };
751
752 let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
753 supermaven_admin_api_key.to_string(),
754 http_client.clone(),
755 )));
756
757 let session = Session {
758 principal: principal.clone(),
759 connection_id,
760 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
761 peer: this.peer.clone(),
762 connection_pool: this.connection_pool.clone(),
763 app_state: this.app_state.clone(),
764 geoip_country_code,
765 system_id,
766 _executor: executor.clone(),
767 supermaven_client,
768 };
769
770 if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
771 tracing::error!(?error, "failed to send initial client update");
772 return;
773 }
774
775 let handle_io = handle_io.fuse();
776 futures::pin_mut!(handle_io);
777
778 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
779 // This prevents deadlocks when e.g., client A performs a request to client B and
780 // client B performs a request to client A. If both clients stop processing further
781 // messages until their respective request completes, they won't have a chance to
782 // respond to the other client's request and cause a deadlock.
783 //
784 // This arrangement ensures we will attempt to process earlier messages first, but fall
785 // back to processing messages arrived later in the spirit of making progress.
786 let mut foreground_message_handlers = FuturesUnordered::new();
787 let concurrent_handlers = Arc::new(Semaphore::new(256));
788 loop {
789 let next_message = async {
790 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
791 let message = incoming_rx.next().await;
792 (permit, message)
793 }.fuse();
794 futures::pin_mut!(next_message);
795 futures::select_biased! {
796 _ = teardown.changed().fuse() => return,
797 result = handle_io => {
798 if let Err(error) = result {
799 tracing::error!(?error, "error handling I/O");
800 }
801 break;
802 }
803 _ = foreground_message_handlers.next() => {}
804 next_message = next_message => {
805 let (permit, message) = next_message;
806 if let Some(message) = message {
807 let type_name = message.payload_type_name();
808 // note: we copy all the fields from the parent span so we can query them in the logs.
809 // (https://github.com/tokio-rs/tracing/issues/2670).
810 let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
811 user_id=field::Empty,
812 login=field::Empty,
813 impersonator=field::Empty,
814 );
815 principal.update_span(&span);
816 let span_enter = span.enter();
817 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
818 let is_background = message.is_background();
819 let handle_message = (handler)(message, session.clone());
820 drop(span_enter);
821
822 let handle_message = async move {
823 handle_message.await;
824 drop(permit);
825 }.instrument(span);
826 if is_background {
827 executor.spawn_detached(handle_message);
828 } else {
829 foreground_message_handlers.push(handle_message);
830 }
831 } else {
832 tracing::error!("no message handler");
833 }
834 } else {
835 tracing::info!("connection closed");
836 break;
837 }
838 }
839 }
840 }
841
842 drop(foreground_message_handlers);
843 tracing::info!("signing out");
844 if let Err(error) = connection_lost(session, teardown, executor).await {
845 tracing::error!(?error, "error signing out");
846 }
847
848 }.instrument(span)
849 }
850
851 async fn send_initial_client_update(
852 &self,
853 connection_id: ConnectionId,
854 principal: &Principal,
855 zed_version: ZedVersion,
856 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
857 session: &Session,
858 ) -> Result<()> {
859 self.peer.send(
860 connection_id,
861 proto::Hello {
862 peer_id: Some(connection_id.into()),
863 },
864 )?;
865 tracing::info!("sent hello message");
866 if let Some(send_connection_id) = send_connection_id.take() {
867 let _ = send_connection_id.send(connection_id);
868 }
869
870 match principal {
871 Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
872 if !user.connected_once {
873 self.peer.send(connection_id, proto::ShowContacts {})?;
874 self.app_state
875 .db
876 .set_user_connected_once(user.id, true)
877 .await?;
878 }
879
880 update_user_plan(user.id, session).await?;
881
882 let contacts = self.app_state.db.get_contacts(user.id).await?;
883
884 {
885 let mut pool = self.connection_pool.lock();
886 pool.add_connection(connection_id, user.id, user.admin, zed_version);
887 self.peer.send(
888 connection_id,
889 build_initial_contacts_update(contacts, &pool),
890 )?;
891 }
892
893 if should_auto_subscribe_to_channels(zed_version) {
894 subscribe_user_to_channels(user.id, session).await?;
895 }
896
897 if let Some(incoming_call) =
898 self.app_state.db.incoming_call_for_user(user.id).await?
899 {
900 self.peer.send(connection_id, incoming_call)?;
901 }
902
903 update_user_contacts(user.id, session).await?;
904 }
905 }
906
907 Ok(())
908 }
909
910 pub async fn invite_code_redeemed(
911 self: &Arc<Self>,
912 inviter_id: UserId,
913 invitee_id: UserId,
914 ) -> Result<()> {
915 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
916 if let Some(code) = &user.invite_code {
917 let pool = self.connection_pool.lock();
918 let invitee_contact = contact_for_user(invitee_id, false, &pool);
919 for connection_id in pool.user_connection_ids(inviter_id) {
920 self.peer.send(
921 connection_id,
922 proto::UpdateContacts {
923 contacts: vec![invitee_contact.clone()],
924 ..Default::default()
925 },
926 )?;
927 self.peer.send(
928 connection_id,
929 proto::UpdateInviteInfo {
930 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
931 count: user.invite_count as u32,
932 },
933 )?;
934 }
935 }
936 }
937 Ok(())
938 }
939
940 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
941 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
942 if let Some(invite_code) = &user.invite_code {
943 let pool = self.connection_pool.lock();
944 for connection_id in pool.user_connection_ids(user_id) {
945 self.peer.send(
946 connection_id,
947 proto::UpdateInviteInfo {
948 url: format!(
949 "{}{}",
950 self.app_state.config.invite_link_prefix, invite_code
951 ),
952 count: user.invite_count as u32,
953 },
954 )?;
955 }
956 }
957 }
958 Ok(())
959 }
960
961 pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
962 let pool = self.connection_pool.lock();
963 for connection_id in pool.user_connection_ids(user_id) {
964 self.peer
965 .send(connection_id, proto::RefreshLlmToken {})
966 .trace_err();
967 }
968 }
969
970 pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot {
971 ServerSnapshot {
972 connection_pool: ConnectionPoolGuard {
973 guard: self.connection_pool.lock(),
974 _not_send: PhantomData,
975 },
976 peer: &self.peer,
977 }
978 }
979}
980
981impl Deref for ConnectionPoolGuard<'_> {
982 type Target = ConnectionPool;
983
984 fn deref(&self) -> &Self::Target {
985 &self.guard
986 }
987}
988
989impl DerefMut for ConnectionPoolGuard<'_> {
990 fn deref_mut(&mut self) -> &mut Self::Target {
991 &mut self.guard
992 }
993}
994
995impl Drop for ConnectionPoolGuard<'_> {
996 fn drop(&mut self) {
997 #[cfg(test)]
998 self.check_invariants();
999 }
1000}
1001
1002fn broadcast<F>(
1003 sender_id: Option<ConnectionId>,
1004 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1005 mut f: F,
1006) where
1007 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1008{
1009 for receiver_id in receiver_ids {
1010 if Some(receiver_id) != sender_id {
1011 if let Err(error) = f(receiver_id) {
1012 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1013 }
1014 }
1015 }
1016}
1017
1018pub struct ProtocolVersion(u32);
1019
1020impl Header for ProtocolVersion {
1021 fn name() -> &'static HeaderName {
1022 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1023 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1024 }
1025
1026 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1027 where
1028 Self: Sized,
1029 I: Iterator<Item = &'i axum::http::HeaderValue>,
1030 {
1031 let version = values
1032 .next()
1033 .ok_or_else(axum::headers::Error::invalid)?
1034 .to_str()
1035 .map_err(|_| axum::headers::Error::invalid())?
1036 .parse()
1037 .map_err(|_| axum::headers::Error::invalid())?;
1038 Ok(Self(version))
1039 }
1040
1041 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1042 values.extend([self.0.to_string().parse().unwrap()]);
1043 }
1044}
1045
1046pub struct AppVersionHeader(SemanticVersion);
1047impl Header for AppVersionHeader {
1048 fn name() -> &'static HeaderName {
1049 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1050 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1051 }
1052
1053 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1054 where
1055 Self: Sized,
1056 I: Iterator<Item = &'i axum::http::HeaderValue>,
1057 {
1058 let version = values
1059 .next()
1060 .ok_or_else(axum::headers::Error::invalid)?
1061 .to_str()
1062 .map_err(|_| axum::headers::Error::invalid())?
1063 .parse()
1064 .map_err(|_| axum::headers::Error::invalid())?;
1065 Ok(Self(version))
1066 }
1067
1068 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1069 values.extend([self.0.to_string().parse().unwrap()]);
1070 }
1071}
1072
1073pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1074 Router::new()
1075 .route("/rpc", get(handle_websocket_request))
1076 .layer(
1077 ServiceBuilder::new()
1078 .layer(Extension(server.app_state.clone()))
1079 .layer(middleware::from_fn(auth::validate_header)),
1080 )
1081 .route("/metrics", get(handle_metrics))
1082 .layer(Extension(server))
1083}
1084
1085pub async fn handle_websocket_request(
1086 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1087 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1088 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1089 Extension(server): Extension<Arc<Server>>,
1090 Extension(principal): Extension<Principal>,
1091 country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1092 system_id_header: Option<TypedHeader<SystemIdHeader>>,
1093 ws: WebSocketUpgrade,
1094) -> axum::response::Response {
1095 if protocol_version != rpc::PROTOCOL_VERSION {
1096 return (
1097 StatusCode::UPGRADE_REQUIRED,
1098 "client must be upgraded".to_string(),
1099 )
1100 .into_response();
1101 }
1102
1103 let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1104 return (
1105 StatusCode::UPGRADE_REQUIRED,
1106 "no version header found".to_string(),
1107 )
1108 .into_response();
1109 };
1110
1111 if !version.can_collaborate() {
1112 return (
1113 StatusCode::UPGRADE_REQUIRED,
1114 "client must be upgraded".to_string(),
1115 )
1116 .into_response();
1117 }
1118
1119 let socket_address = socket_address.to_string();
1120 ws.on_upgrade(move |socket| {
1121 let socket = socket
1122 .map_ok(to_tungstenite_message)
1123 .err_into()
1124 .with(|message| async move { to_axum_message(message) });
1125 let connection = Connection::new(Box::pin(socket));
1126 async move {
1127 server
1128 .handle_connection(
1129 connection,
1130 socket_address,
1131 principal,
1132 version,
1133 country_code_header.map(|header| header.to_string()),
1134 system_id_header.map(|header| header.to_string()),
1135 None,
1136 Executor::Production,
1137 )
1138 .await;
1139 }
1140 })
1141}
1142
1143pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1144 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1145 let connections_metric = CONNECTIONS_METRIC
1146 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1147
1148 let connections = server
1149 .connection_pool
1150 .lock()
1151 .connections()
1152 .filter(|connection| !connection.admin)
1153 .count();
1154 connections_metric.set(connections as _);
1155
1156 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1157 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1158 register_int_gauge!(
1159 "shared_projects",
1160 "number of open projects with one or more guests"
1161 )
1162 .unwrap()
1163 });
1164
1165 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1166 shared_projects_metric.set(shared_projects as _);
1167
1168 let encoder = prometheus::TextEncoder::new();
1169 let metric_families = prometheus::gather();
1170 let encoded_metrics = encoder
1171 .encode_to_string(&metric_families)
1172 .map_err(|err| anyhow!("{}", err))?;
1173 Ok(encoded_metrics)
1174}
1175
1176#[instrument(err, skip(executor))]
1177async fn connection_lost(
1178 session: Session,
1179 mut teardown: watch::Receiver<bool>,
1180 executor: Executor,
1181) -> Result<()> {
1182 session.peer.disconnect(session.connection_id);
1183 session
1184 .connection_pool()
1185 .await
1186 .remove_connection(session.connection_id)?;
1187
1188 session
1189 .db()
1190 .await
1191 .connection_lost(session.connection_id)
1192 .await
1193 .trace_err();
1194
1195 futures::select_biased! {
1196 _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1197
1198 log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1199 leave_room_for_session(&session, session.connection_id).await.trace_err();
1200 leave_channel_buffers_for_session(&session)
1201 .await
1202 .trace_err();
1203
1204 if !session
1205 .connection_pool()
1206 .await
1207 .is_user_online(session.user_id())
1208 {
1209 let db = session.db().await;
1210 if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1211 room_updated(&room, &session.peer);
1212 }
1213 }
1214
1215 update_user_contacts(session.user_id(), &session).await?;
1216 },
1217 _ = teardown.changed().fuse() => {}
1218 }
1219
1220 Ok(())
1221}
1222
1223/// Acknowledges a ping from a client, used to keep the connection alive.
1224async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1225 response.send(proto::Ack {})?;
1226 Ok(())
1227}
1228
1229/// Creates a new room for calling (outside of channels)
1230async fn create_room(
1231 _request: proto::CreateRoom,
1232 response: Response<proto::CreateRoom>,
1233 session: Session,
1234) -> Result<()> {
1235 let livekit_room = nanoid::nanoid!(30);
1236
1237 let live_kit_connection_info = util::maybe!(async {
1238 let live_kit = session.app_state.livekit_client.as_ref();
1239 let live_kit = live_kit?;
1240 let user_id = session.user_id().to_string();
1241
1242 let token = live_kit
1243 .room_token(&livekit_room, &user_id.to_string())
1244 .trace_err()?;
1245
1246 Some(proto::LiveKitConnectionInfo {
1247 server_url: live_kit.url().into(),
1248 token,
1249 can_publish: true,
1250 })
1251 })
1252 .await;
1253
1254 let room = session
1255 .db()
1256 .await
1257 .create_room(session.user_id(), session.connection_id, &livekit_room)
1258 .await?;
1259
1260 response.send(proto::CreateRoomResponse {
1261 room: Some(room.clone()),
1262 live_kit_connection_info,
1263 })?;
1264
1265 update_user_contacts(session.user_id(), &session).await?;
1266 Ok(())
1267}
1268
1269/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1270async fn join_room(
1271 request: proto::JoinRoom,
1272 response: Response<proto::JoinRoom>,
1273 session: Session,
1274) -> Result<()> {
1275 let room_id = RoomId::from_proto(request.id);
1276
1277 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1278
1279 if let Some(channel_id) = channel_id {
1280 return join_channel_internal(channel_id, Box::new(response), session).await;
1281 }
1282
1283 let joined_room = {
1284 let room = session
1285 .db()
1286 .await
1287 .join_room(room_id, session.user_id(), session.connection_id)
1288 .await?;
1289 room_updated(&room.room, &session.peer);
1290 room.into_inner()
1291 };
1292
1293 for connection_id in session
1294 .connection_pool()
1295 .await
1296 .user_connection_ids(session.user_id())
1297 {
1298 session
1299 .peer
1300 .send(
1301 connection_id,
1302 proto::CallCanceled {
1303 room_id: room_id.to_proto(),
1304 },
1305 )
1306 .trace_err();
1307 }
1308
1309 let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1310 {
1311 live_kit
1312 .room_token(
1313 &joined_room.room.livekit_room,
1314 &session.user_id().to_string(),
1315 )
1316 .trace_err()
1317 .map(|token| proto::LiveKitConnectionInfo {
1318 server_url: live_kit.url().into(),
1319 token,
1320 can_publish: true,
1321 })
1322 } else {
1323 None
1324 };
1325
1326 response.send(proto::JoinRoomResponse {
1327 room: Some(joined_room.room),
1328 channel_id: None,
1329 live_kit_connection_info,
1330 })?;
1331
1332 update_user_contacts(session.user_id(), &session).await?;
1333 Ok(())
1334}
1335
1336/// Rejoin room is used to reconnect to a room after connection errors.
1337async fn rejoin_room(
1338 request: proto::RejoinRoom,
1339 response: Response<proto::RejoinRoom>,
1340 session: Session,
1341) -> Result<()> {
1342 let room;
1343 let channel;
1344 {
1345 let mut rejoined_room = session
1346 .db()
1347 .await
1348 .rejoin_room(request, session.user_id(), session.connection_id)
1349 .await?;
1350
1351 response.send(proto::RejoinRoomResponse {
1352 room: Some(rejoined_room.room.clone()),
1353 reshared_projects: rejoined_room
1354 .reshared_projects
1355 .iter()
1356 .map(|project| proto::ResharedProject {
1357 id: project.id.to_proto(),
1358 collaborators: project
1359 .collaborators
1360 .iter()
1361 .map(|collaborator| collaborator.to_proto())
1362 .collect(),
1363 })
1364 .collect(),
1365 rejoined_projects: rejoined_room
1366 .rejoined_projects
1367 .iter()
1368 .map(|rejoined_project| rejoined_project.to_proto())
1369 .collect(),
1370 })?;
1371 room_updated(&rejoined_room.room, &session.peer);
1372
1373 for project in &rejoined_room.reshared_projects {
1374 for collaborator in &project.collaborators {
1375 session
1376 .peer
1377 .send(
1378 collaborator.connection_id,
1379 proto::UpdateProjectCollaborator {
1380 project_id: project.id.to_proto(),
1381 old_peer_id: Some(project.old_connection_id.into()),
1382 new_peer_id: Some(session.connection_id.into()),
1383 },
1384 )
1385 .trace_err();
1386 }
1387
1388 broadcast(
1389 Some(session.connection_id),
1390 project
1391 .collaborators
1392 .iter()
1393 .map(|collaborator| collaborator.connection_id),
1394 |connection_id| {
1395 session.peer.forward_send(
1396 session.connection_id,
1397 connection_id,
1398 proto::UpdateProject {
1399 project_id: project.id.to_proto(),
1400 worktrees: project.worktrees.clone(),
1401 },
1402 )
1403 },
1404 );
1405 }
1406
1407 notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;
1408
1409 let rejoined_room = rejoined_room.into_inner();
1410
1411 room = rejoined_room.room;
1412 channel = rejoined_room.channel;
1413 }
1414
1415 if let Some(channel) = channel {
1416 channel_updated(
1417 &channel,
1418 &room,
1419 &session.peer,
1420 &*session.connection_pool().await,
1421 );
1422 }
1423
1424 update_user_contacts(session.user_id(), &session).await?;
1425 Ok(())
1426}
1427
1428fn notify_rejoined_projects(
1429 rejoined_projects: &mut Vec<RejoinedProject>,
1430 session: &Session,
1431) -> Result<()> {
1432 for project in rejoined_projects.iter() {
1433 for collaborator in &project.collaborators {
1434 session
1435 .peer
1436 .send(
1437 collaborator.connection_id,
1438 proto::UpdateProjectCollaborator {
1439 project_id: project.id.to_proto(),
1440 old_peer_id: Some(project.old_connection_id.into()),
1441 new_peer_id: Some(session.connection_id.into()),
1442 },
1443 )
1444 .trace_err();
1445 }
1446 }
1447
1448 for project in rejoined_projects {
1449 for worktree in mem::take(&mut project.worktrees) {
1450 // Stream this worktree's entries.
1451 let message = proto::UpdateWorktree {
1452 project_id: project.id.to_proto(),
1453 worktree_id: worktree.id,
1454 abs_path: worktree.abs_path.clone(),
1455 root_name: worktree.root_name,
1456 updated_entries: worktree.updated_entries,
1457 removed_entries: worktree.removed_entries,
1458 scan_id: worktree.scan_id,
1459 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1460 updated_repositories: worktree.updated_repositories,
1461 removed_repositories: worktree.removed_repositories,
1462 };
1463 for update in proto::split_worktree_update(message) {
1464 session.peer.send(session.connection_id, update)?;
1465 }
1466
1467 // Stream this worktree's diagnostics.
1468 for summary in worktree.diagnostic_summaries {
1469 session.peer.send(
1470 session.connection_id,
1471 proto::UpdateDiagnosticSummary {
1472 project_id: project.id.to_proto(),
1473 worktree_id: worktree.id,
1474 summary: Some(summary),
1475 },
1476 )?;
1477 }
1478
1479 for settings_file in worktree.settings_files {
1480 session.peer.send(
1481 session.connection_id,
1482 proto::UpdateWorktreeSettings {
1483 project_id: project.id.to_proto(),
1484 worktree_id: worktree.id,
1485 path: settings_file.path,
1486 content: Some(settings_file.content),
1487 kind: Some(settings_file.kind.to_proto().into()),
1488 },
1489 )?;
1490 }
1491 }
1492
1493 for repository in mem::take(&mut project.updated_repositories) {
1494 for update in split_repository_update(repository) {
1495 session.peer.send(session.connection_id, update)?;
1496 }
1497 }
1498
1499 for id in mem::take(&mut project.removed_repositories) {
1500 session.peer.send(
1501 session.connection_id,
1502 proto::RemoveRepository {
1503 project_id: project.id.to_proto(),
1504 id,
1505 },
1506 )?;
1507 }
1508 }
1509
1510 Ok(())
1511}
1512
1513/// leave room disconnects from the room.
1514async fn leave_room(
1515 _: proto::LeaveRoom,
1516 response: Response<proto::LeaveRoom>,
1517 session: Session,
1518) -> Result<()> {
1519 leave_room_for_session(&session, session.connection_id).await?;
1520 response.send(proto::Ack {})?;
1521 Ok(())
1522}
1523
1524/// Updates the permissions of someone else in the room.
1525async fn set_room_participant_role(
1526 request: proto::SetRoomParticipantRole,
1527 response: Response<proto::SetRoomParticipantRole>,
1528 session: Session,
1529) -> Result<()> {
1530 let user_id = UserId::from_proto(request.user_id);
1531 let role = ChannelRole::from(request.role());
1532
1533 let (livekit_room, can_publish) = {
1534 let room = session
1535 .db()
1536 .await
1537 .set_room_participant_role(
1538 session.user_id(),
1539 RoomId::from_proto(request.room_id),
1540 user_id,
1541 role,
1542 )
1543 .await?;
1544
1545 let livekit_room = room.livekit_room.clone();
1546 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1547 room_updated(&room, &session.peer);
1548 (livekit_room, can_publish)
1549 };
1550
1551 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1552 live_kit
1553 .update_participant(
1554 livekit_room.clone(),
1555 request.user_id.to_string(),
1556 livekit_api::proto::ParticipantPermission {
1557 can_subscribe: true,
1558 can_publish,
1559 can_publish_data: can_publish,
1560 hidden: false,
1561 recorder: false,
1562 },
1563 )
1564 .await
1565 .trace_err();
1566 }
1567
1568 response.send(proto::Ack {})?;
1569 Ok(())
1570}
1571
1572/// Call someone else into the current room
1573async fn call(
1574 request: proto::Call,
1575 response: Response<proto::Call>,
1576 session: Session,
1577) -> Result<()> {
1578 let room_id = RoomId::from_proto(request.room_id);
1579 let calling_user_id = session.user_id();
1580 let calling_connection_id = session.connection_id;
1581 let called_user_id = UserId::from_proto(request.called_user_id);
1582 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1583 if !session
1584 .db()
1585 .await
1586 .has_contact(calling_user_id, called_user_id)
1587 .await?
1588 {
1589 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1590 }
1591
1592 let incoming_call = {
1593 let (room, incoming_call) = &mut *session
1594 .db()
1595 .await
1596 .call(
1597 room_id,
1598 calling_user_id,
1599 calling_connection_id,
1600 called_user_id,
1601 initial_project_id,
1602 )
1603 .await?;
1604 room_updated(room, &session.peer);
1605 mem::take(incoming_call)
1606 };
1607 update_user_contacts(called_user_id, &session).await?;
1608
1609 let mut calls = session
1610 .connection_pool()
1611 .await
1612 .user_connection_ids(called_user_id)
1613 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1614 .collect::<FuturesUnordered<_>>();
1615
1616 while let Some(call_response) = calls.next().await {
1617 match call_response.as_ref() {
1618 Ok(_) => {
1619 response.send(proto::Ack {})?;
1620 return Ok(());
1621 }
1622 Err(_) => {
1623 call_response.trace_err();
1624 }
1625 }
1626 }
1627
1628 {
1629 let room = session
1630 .db()
1631 .await
1632 .call_failed(room_id, called_user_id)
1633 .await?;
1634 room_updated(&room, &session.peer);
1635 }
1636 update_user_contacts(called_user_id, &session).await?;
1637
1638 Err(anyhow!("failed to ring user"))?
1639}
1640
1641/// Cancel an outgoing call.
1642async fn cancel_call(
1643 request: proto::CancelCall,
1644 response: Response<proto::CancelCall>,
1645 session: Session,
1646) -> Result<()> {
1647 let called_user_id = UserId::from_proto(request.called_user_id);
1648 let room_id = RoomId::from_proto(request.room_id);
1649 {
1650 let room = session
1651 .db()
1652 .await
1653 .cancel_call(room_id, session.connection_id, called_user_id)
1654 .await?;
1655 room_updated(&room, &session.peer);
1656 }
1657
1658 for connection_id in session
1659 .connection_pool()
1660 .await
1661 .user_connection_ids(called_user_id)
1662 {
1663 session
1664 .peer
1665 .send(
1666 connection_id,
1667 proto::CallCanceled {
1668 room_id: room_id.to_proto(),
1669 },
1670 )
1671 .trace_err();
1672 }
1673 response.send(proto::Ack {})?;
1674
1675 update_user_contacts(called_user_id, &session).await?;
1676 Ok(())
1677}
1678
1679/// Decline an incoming call.
1680async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1681 let room_id = RoomId::from_proto(message.room_id);
1682 {
1683 let room = session
1684 .db()
1685 .await
1686 .decline_call(Some(room_id), session.user_id())
1687 .await?
1688 .ok_or_else(|| anyhow!("failed to decline call"))?;
1689 room_updated(&room, &session.peer);
1690 }
1691
1692 for connection_id in session
1693 .connection_pool()
1694 .await
1695 .user_connection_ids(session.user_id())
1696 {
1697 session
1698 .peer
1699 .send(
1700 connection_id,
1701 proto::CallCanceled {
1702 room_id: room_id.to_proto(),
1703 },
1704 )
1705 .trace_err();
1706 }
1707 update_user_contacts(session.user_id(), &session).await?;
1708 Ok(())
1709}
1710
1711/// Updates other participants in the room with your current location.
1712async fn update_participant_location(
1713 request: proto::UpdateParticipantLocation,
1714 response: Response<proto::UpdateParticipantLocation>,
1715 session: Session,
1716) -> Result<()> {
1717 let room_id = RoomId::from_proto(request.room_id);
1718 let location = request
1719 .location
1720 .ok_or_else(|| anyhow!("invalid location"))?;
1721
1722 let db = session.db().await;
1723 let room = db
1724 .update_room_participant_location(room_id, session.connection_id, location)
1725 .await?;
1726
1727 room_updated(&room, &session.peer);
1728 response.send(proto::Ack {})?;
1729 Ok(())
1730}
1731
1732/// Share a project into the room.
1733async fn share_project(
1734 request: proto::ShareProject,
1735 response: Response<proto::ShareProject>,
1736 session: Session,
1737) -> Result<()> {
1738 let (project_id, room) = &*session
1739 .db()
1740 .await
1741 .share_project(
1742 RoomId::from_proto(request.room_id),
1743 session.connection_id,
1744 &request.worktrees,
1745 request.is_ssh_project,
1746 )
1747 .await?;
1748 response.send(proto::ShareProjectResponse {
1749 project_id: project_id.to_proto(),
1750 })?;
1751 room_updated(room, &session.peer);
1752
1753 Ok(())
1754}
1755
1756/// Unshare a project from the room.
1757async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1758 let project_id = ProjectId::from_proto(message.project_id);
1759 unshare_project_internal(project_id, session.connection_id, &session).await
1760}
1761
1762async fn unshare_project_internal(
1763 project_id: ProjectId,
1764 connection_id: ConnectionId,
1765 session: &Session,
1766) -> Result<()> {
1767 let delete = {
1768 let room_guard = session
1769 .db()
1770 .await
1771 .unshare_project(project_id, connection_id)
1772 .await?;
1773
1774 let (delete, room, guest_connection_ids) = &*room_guard;
1775
1776 let message = proto::UnshareProject {
1777 project_id: project_id.to_proto(),
1778 };
1779
1780 broadcast(
1781 Some(connection_id),
1782 guest_connection_ids.iter().copied(),
1783 |conn_id| session.peer.send(conn_id, message.clone()),
1784 );
1785 if let Some(room) = room {
1786 room_updated(room, &session.peer);
1787 }
1788
1789 *delete
1790 };
1791
1792 if delete {
1793 let db = session.db().await;
1794 db.delete_project(project_id).await?;
1795 }
1796
1797 Ok(())
1798}
1799
1800/// Join someone elses shared project.
1801async fn join_project(
1802 request: proto::JoinProject,
1803 response: Response<proto::JoinProject>,
1804 session: Session,
1805) -> Result<()> {
1806 let project_id = ProjectId::from_proto(request.project_id);
1807
1808 tracing::info!(%project_id, "join project");
1809
1810 let db = session.db().await;
1811 let (project, replica_id) = &mut *db
1812 .join_project(project_id, session.connection_id, session.user_id())
1813 .await?;
1814 drop(db);
1815 tracing::info!(%project_id, "join remote project");
1816 join_project_internal(response, session, project, replica_id)
1817}
1818
1819trait JoinProjectInternalResponse {
1820 fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
1821}
1822impl JoinProjectInternalResponse for Response<proto::JoinProject> {
1823 fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
1824 Response::<proto::JoinProject>::send(self, result)
1825 }
1826}
1827
1828fn join_project_internal(
1829 response: impl JoinProjectInternalResponse,
1830 session: Session,
1831 project: &mut Project,
1832 replica_id: &ReplicaId,
1833) -> Result<()> {
1834 let collaborators = project
1835 .collaborators
1836 .iter()
1837 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1838 .map(|collaborator| collaborator.to_proto())
1839 .collect::<Vec<_>>();
1840 let project_id = project.id;
1841 let guest_user_id = session.user_id();
1842
1843 let worktrees = project
1844 .worktrees
1845 .iter()
1846 .map(|(id, worktree)| proto::WorktreeMetadata {
1847 id: *id,
1848 root_name: worktree.root_name.clone(),
1849 visible: worktree.visible,
1850 abs_path: worktree.abs_path.clone(),
1851 })
1852 .collect::<Vec<_>>();
1853
1854 let add_project_collaborator = proto::AddProjectCollaborator {
1855 project_id: project_id.to_proto(),
1856 collaborator: Some(proto::Collaborator {
1857 peer_id: Some(session.connection_id.into()),
1858 replica_id: replica_id.0 as u32,
1859 user_id: guest_user_id.to_proto(),
1860 is_host: false,
1861 }),
1862 };
1863
1864 for collaborator in &collaborators {
1865 session
1866 .peer
1867 .send(
1868 collaborator.peer_id.unwrap().into(),
1869 add_project_collaborator.clone(),
1870 )
1871 .trace_err();
1872 }
1873
1874 // First, we send the metadata associated with each worktree.
1875 response.send(proto::JoinProjectResponse {
1876 project_id: project.id.0 as u64,
1877 worktrees: worktrees.clone(),
1878 replica_id: replica_id.0 as u32,
1879 collaborators: collaborators.clone(),
1880 language_servers: project.language_servers.clone(),
1881 role: project.role.into(),
1882 })?;
1883
1884 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1885 // Stream this worktree's entries.
1886 let message = proto::UpdateWorktree {
1887 project_id: project_id.to_proto(),
1888 worktree_id,
1889 abs_path: worktree.abs_path.clone(),
1890 root_name: worktree.root_name,
1891 updated_entries: worktree.entries,
1892 removed_entries: Default::default(),
1893 scan_id: worktree.scan_id,
1894 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1895 updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1896 removed_repositories: Default::default(),
1897 };
1898 for update in proto::split_worktree_update(message) {
1899 session.peer.send(session.connection_id, update.clone())?;
1900 }
1901
1902 // Stream this worktree's diagnostics.
1903 for summary in worktree.diagnostic_summaries {
1904 session.peer.send(
1905 session.connection_id,
1906 proto::UpdateDiagnosticSummary {
1907 project_id: project_id.to_proto(),
1908 worktree_id: worktree.id,
1909 summary: Some(summary),
1910 },
1911 )?;
1912 }
1913
1914 for settings_file in worktree.settings_files {
1915 session.peer.send(
1916 session.connection_id,
1917 proto::UpdateWorktreeSettings {
1918 project_id: project_id.to_proto(),
1919 worktree_id: worktree.id,
1920 path: settings_file.path,
1921 content: Some(settings_file.content),
1922 kind: Some(settings_file.kind.to_proto() as i32),
1923 },
1924 )?;
1925 }
1926 }
1927
1928 for repository in mem::take(&mut project.repositories) {
1929 for update in split_repository_update(repository) {
1930 session.peer.send(session.connection_id, update)?;
1931 }
1932 }
1933
1934 for language_server in &project.language_servers {
1935 session.peer.send(
1936 session.connection_id,
1937 proto::UpdateLanguageServer {
1938 project_id: project_id.to_proto(),
1939 language_server_id: language_server.id,
1940 variant: Some(
1941 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1942 proto::LspDiskBasedDiagnosticsUpdated {},
1943 ),
1944 ),
1945 },
1946 )?;
1947 }
1948
1949 Ok(())
1950}
1951
1952/// Leave someone elses shared project.
1953async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1954 let sender_id = session.connection_id;
1955 let project_id = ProjectId::from_proto(request.project_id);
1956 let db = session.db().await;
1957
1958 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1959 tracing::info!(
1960 %project_id,
1961 "leave project"
1962 );
1963
1964 project_left(project, &session);
1965 if let Some(room) = room {
1966 room_updated(room, &session.peer);
1967 }
1968
1969 Ok(())
1970}
1971
1972/// Updates other participants with changes to the project
1973async fn update_project(
1974 request: proto::UpdateProject,
1975 response: Response<proto::UpdateProject>,
1976 session: Session,
1977) -> Result<()> {
1978 let project_id = ProjectId::from_proto(request.project_id);
1979 let (room, guest_connection_ids) = &*session
1980 .db()
1981 .await
1982 .update_project(project_id, session.connection_id, &request.worktrees)
1983 .await?;
1984 broadcast(
1985 Some(session.connection_id),
1986 guest_connection_ids.iter().copied(),
1987 |connection_id| {
1988 session
1989 .peer
1990 .forward_send(session.connection_id, connection_id, request.clone())
1991 },
1992 );
1993 if let Some(room) = room {
1994 room_updated(room, &session.peer);
1995 }
1996 response.send(proto::Ack {})?;
1997
1998 Ok(())
1999}
2000
2001/// Updates other participants with changes to the worktree
2002async fn update_worktree(
2003 request: proto::UpdateWorktree,
2004 response: Response<proto::UpdateWorktree>,
2005 session: Session,
2006) -> Result<()> {
2007 let guest_connection_ids = session
2008 .db()
2009 .await
2010 .update_worktree(&request, session.connection_id)
2011 .await?;
2012
2013 broadcast(
2014 Some(session.connection_id),
2015 guest_connection_ids.iter().copied(),
2016 |connection_id| {
2017 session
2018 .peer
2019 .forward_send(session.connection_id, connection_id, request.clone())
2020 },
2021 );
2022 response.send(proto::Ack {})?;
2023 Ok(())
2024}
2025
2026async fn update_repository(
2027 request: proto::UpdateRepository,
2028 response: Response<proto::UpdateRepository>,
2029 session: Session,
2030) -> Result<()> {
2031 let guest_connection_ids = session
2032 .db()
2033 .await
2034 .update_repository(&request, session.connection_id)
2035 .await?;
2036
2037 broadcast(
2038 Some(session.connection_id),
2039 guest_connection_ids.iter().copied(),
2040 |connection_id| {
2041 session
2042 .peer
2043 .forward_send(session.connection_id, connection_id, request.clone())
2044 },
2045 );
2046 response.send(proto::Ack {})?;
2047 Ok(())
2048}
2049
2050async fn remove_repository(
2051 request: proto::RemoveRepository,
2052 response: Response<proto::RemoveRepository>,
2053 session: Session,
2054) -> Result<()> {
2055 let guest_connection_ids = session
2056 .db()
2057 .await
2058 .remove_repository(&request, session.connection_id)
2059 .await?;
2060
2061 broadcast(
2062 Some(session.connection_id),
2063 guest_connection_ids.iter().copied(),
2064 |connection_id| {
2065 session
2066 .peer
2067 .forward_send(session.connection_id, connection_id, request.clone())
2068 },
2069 );
2070 response.send(proto::Ack {})?;
2071 Ok(())
2072}
2073
2074/// Updates other participants with changes to the diagnostics
2075async fn update_diagnostic_summary(
2076 message: proto::UpdateDiagnosticSummary,
2077 session: Session,
2078) -> Result<()> {
2079 let guest_connection_ids = session
2080 .db()
2081 .await
2082 .update_diagnostic_summary(&message, session.connection_id)
2083 .await?;
2084
2085 broadcast(
2086 Some(session.connection_id),
2087 guest_connection_ids.iter().copied(),
2088 |connection_id| {
2089 session
2090 .peer
2091 .forward_send(session.connection_id, connection_id, message.clone())
2092 },
2093 );
2094
2095 Ok(())
2096}
2097
2098/// Updates other participants with changes to the worktree settings
2099async fn update_worktree_settings(
2100 message: proto::UpdateWorktreeSettings,
2101 session: Session,
2102) -> Result<()> {
2103 let guest_connection_ids = session
2104 .db()
2105 .await
2106 .update_worktree_settings(&message, session.connection_id)
2107 .await?;
2108
2109 broadcast(
2110 Some(session.connection_id),
2111 guest_connection_ids.iter().copied(),
2112 |connection_id| {
2113 session
2114 .peer
2115 .forward_send(session.connection_id, connection_id, message.clone())
2116 },
2117 );
2118
2119 Ok(())
2120}
2121
2122/// Notify other participants that a language server has started.
2123async fn start_language_server(
2124 request: proto::StartLanguageServer,
2125 session: Session,
2126) -> Result<()> {
2127 let guest_connection_ids = session
2128 .db()
2129 .await
2130 .start_language_server(&request, session.connection_id)
2131 .await?;
2132
2133 broadcast(
2134 Some(session.connection_id),
2135 guest_connection_ids.iter().copied(),
2136 |connection_id| {
2137 session
2138 .peer
2139 .forward_send(session.connection_id, connection_id, request.clone())
2140 },
2141 );
2142 Ok(())
2143}
2144
2145/// Notify other participants that a language server has changed.
2146async fn update_language_server(
2147 request: proto::UpdateLanguageServer,
2148 session: Session,
2149) -> Result<()> {
2150 let project_id = ProjectId::from_proto(request.project_id);
2151 let project_connection_ids = session
2152 .db()
2153 .await
2154 .project_connection_ids(project_id, session.connection_id, true)
2155 .await?;
2156 broadcast(
2157 Some(session.connection_id),
2158 project_connection_ids.iter().copied(),
2159 |connection_id| {
2160 session
2161 .peer
2162 .forward_send(session.connection_id, connection_id, request.clone())
2163 },
2164 );
2165 Ok(())
2166}
2167
2168/// forward a project request to the host. These requests should be read only
2169/// as guests are allowed to send them.
2170async fn forward_read_only_project_request<T>(
2171 request: T,
2172 response: Response<T>,
2173 session: Session,
2174) -> Result<()>
2175where
2176 T: EntityMessage + RequestMessage,
2177{
2178 let project_id = ProjectId::from_proto(request.remote_entity_id());
2179 let host_connection_id = session
2180 .db()
2181 .await
2182 .host_for_read_only_project_request(project_id, session.connection_id)
2183 .await?;
2184 let payload = session
2185 .peer
2186 .forward_request(session.connection_id, host_connection_id, request)
2187 .await?;
2188 response.send(payload)?;
2189 Ok(())
2190}
2191
2192async fn forward_find_search_candidates_request(
2193 request: proto::FindSearchCandidates,
2194 response: Response<proto::FindSearchCandidates>,
2195 session: Session,
2196) -> Result<()> {
2197 let project_id = ProjectId::from_proto(request.remote_entity_id());
2198 let host_connection_id = session
2199 .db()
2200 .await
2201 .host_for_read_only_project_request(project_id, session.connection_id)
2202 .await?;
2203 let payload = session
2204 .peer
2205 .forward_request(session.connection_id, host_connection_id, request)
2206 .await?;
2207 response.send(payload)?;
2208 Ok(())
2209}
2210
2211/// forward a project request to the host. These requests are disallowed
2212/// for guests.
2213async fn forward_mutating_project_request<T>(
2214 request: T,
2215 response: Response<T>,
2216 session: Session,
2217) -> Result<()>
2218where
2219 T: EntityMessage + RequestMessage,
2220{
2221 let project_id = ProjectId::from_proto(request.remote_entity_id());
2222
2223 let host_connection_id = session
2224 .db()
2225 .await
2226 .host_for_mutating_project_request(project_id, session.connection_id)
2227 .await?;
2228 let payload = session
2229 .peer
2230 .forward_request(session.connection_id, host_connection_id, request)
2231 .await?;
2232 response.send(payload)?;
2233 Ok(())
2234}
2235
2236/// Notify other participants that a new buffer has been created
2237async fn create_buffer_for_peer(
2238 request: proto::CreateBufferForPeer,
2239 session: Session,
2240) -> Result<()> {
2241 session
2242 .db()
2243 .await
2244 .check_user_is_project_host(
2245 ProjectId::from_proto(request.project_id),
2246 session.connection_id,
2247 )
2248 .await?;
2249 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2250 session
2251 .peer
2252 .forward_send(session.connection_id, peer_id.into(), request)?;
2253 Ok(())
2254}
2255
2256/// Notify other participants that a buffer has been updated. This is
2257/// allowed for guests as long as the update is limited to selections.
2258async fn update_buffer(
2259 request: proto::UpdateBuffer,
2260 response: Response<proto::UpdateBuffer>,
2261 session: Session,
2262) -> Result<()> {
2263 let project_id = ProjectId::from_proto(request.project_id);
2264 let mut capability = Capability::ReadOnly;
2265
2266 for op in request.operations.iter() {
2267 match op.variant {
2268 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2269 Some(_) => capability = Capability::ReadWrite,
2270 }
2271 }
2272
2273 let host = {
2274 let guard = session
2275 .db()
2276 .await
2277 .connections_for_buffer_update(project_id, session.connection_id, capability)
2278 .await?;
2279
2280 let (host, guests) = &*guard;
2281
2282 broadcast(
2283 Some(session.connection_id),
2284 guests.clone(),
2285 |connection_id| {
2286 session
2287 .peer
2288 .forward_send(session.connection_id, connection_id, request.clone())
2289 },
2290 );
2291
2292 *host
2293 };
2294
2295 if host != session.connection_id {
2296 session
2297 .peer
2298 .forward_request(session.connection_id, host, request.clone())
2299 .await?;
2300 }
2301
2302 response.send(proto::Ack {})?;
2303 Ok(())
2304}
2305
2306async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2307 let project_id = ProjectId::from_proto(message.project_id);
2308
2309 let operation = message.operation.as_ref().context("invalid operation")?;
2310 let capability = match operation.variant.as_ref() {
2311 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2312 if let Some(buffer_op) = buffer_op.operation.as_ref() {
2313 match buffer_op.variant {
2314 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2315 Capability::ReadOnly
2316 }
2317 _ => Capability::ReadWrite,
2318 }
2319 } else {
2320 Capability::ReadWrite
2321 }
2322 }
2323 Some(_) => Capability::ReadWrite,
2324 None => Capability::ReadOnly,
2325 };
2326
2327 let guard = session
2328 .db()
2329 .await
2330 .connections_for_buffer_update(project_id, session.connection_id, capability)
2331 .await?;
2332
2333 let (host, guests) = &*guard;
2334
2335 broadcast(
2336 Some(session.connection_id),
2337 guests.iter().chain([host]).copied(),
2338 |connection_id| {
2339 session
2340 .peer
2341 .forward_send(session.connection_id, connection_id, message.clone())
2342 },
2343 );
2344
2345 Ok(())
2346}
2347
2348/// Notify other participants that a project has been updated.
2349async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2350 request: T,
2351 session: Session,
2352) -> Result<()> {
2353 let project_id = ProjectId::from_proto(request.remote_entity_id());
2354 let project_connection_ids = session
2355 .db()
2356 .await
2357 .project_connection_ids(project_id, session.connection_id, false)
2358 .await?;
2359
2360 broadcast(
2361 Some(session.connection_id),
2362 project_connection_ids.iter().copied(),
2363 |connection_id| {
2364 session
2365 .peer
2366 .forward_send(session.connection_id, connection_id, request.clone())
2367 },
2368 );
2369 Ok(())
2370}
2371
2372/// Start following another user in a call.
2373async fn follow(
2374 request: proto::Follow,
2375 response: Response<proto::Follow>,
2376 session: Session,
2377) -> Result<()> {
2378 let room_id = RoomId::from_proto(request.room_id);
2379 let project_id = request.project_id.map(ProjectId::from_proto);
2380 let leader_id = request
2381 .leader_id
2382 .ok_or_else(|| anyhow!("invalid leader id"))?
2383 .into();
2384 let follower_id = session.connection_id;
2385
2386 session
2387 .db()
2388 .await
2389 .check_room_participants(room_id, leader_id, session.connection_id)
2390 .await?;
2391
2392 let response_payload = session
2393 .peer
2394 .forward_request(session.connection_id, leader_id, request)
2395 .await?;
2396 response.send(response_payload)?;
2397
2398 if let Some(project_id) = project_id {
2399 let room = session
2400 .db()
2401 .await
2402 .follow(room_id, project_id, leader_id, follower_id)
2403 .await?;
2404 room_updated(&room, &session.peer);
2405 }
2406
2407 Ok(())
2408}
2409
2410/// Stop following another user in a call.
2411async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2412 let room_id = RoomId::from_proto(request.room_id);
2413 let project_id = request.project_id.map(ProjectId::from_proto);
2414 let leader_id = request
2415 .leader_id
2416 .ok_or_else(|| anyhow!("invalid leader id"))?
2417 .into();
2418 let follower_id = session.connection_id;
2419
2420 session
2421 .db()
2422 .await
2423 .check_room_participants(room_id, leader_id, session.connection_id)
2424 .await?;
2425
2426 session
2427 .peer
2428 .forward_send(session.connection_id, leader_id, request)?;
2429
2430 if let Some(project_id) = project_id {
2431 let room = session
2432 .db()
2433 .await
2434 .unfollow(room_id, project_id, leader_id, follower_id)
2435 .await?;
2436 room_updated(&room, &session.peer);
2437 }
2438
2439 Ok(())
2440}
2441
2442/// Notify everyone following you of your current location.
2443async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2444 let room_id = RoomId::from_proto(request.room_id);
2445 let database = session.db.lock().await;
2446
2447 let connection_ids = if let Some(project_id) = request.project_id {
2448 let project_id = ProjectId::from_proto(project_id);
2449 database
2450 .project_connection_ids(project_id, session.connection_id, true)
2451 .await?
2452 } else {
2453 database
2454 .room_connection_ids(room_id, session.connection_id)
2455 .await?
2456 };
2457
2458 // For now, don't send view update messages back to that view's current leader.
2459 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2460 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2461 _ => None,
2462 });
2463
2464 for connection_id in connection_ids.iter().cloned() {
2465 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2466 session
2467 .peer
2468 .forward_send(session.connection_id, connection_id, request.clone())?;
2469 }
2470 }
2471 Ok(())
2472}
2473
2474/// Get public data about users.
2475async fn get_users(
2476 request: proto::GetUsers,
2477 response: Response<proto::GetUsers>,
2478 session: Session,
2479) -> Result<()> {
2480 let user_ids = request
2481 .user_ids
2482 .into_iter()
2483 .map(UserId::from_proto)
2484 .collect();
2485 let users = session
2486 .db()
2487 .await
2488 .get_users_by_ids(user_ids)
2489 .await?
2490 .into_iter()
2491 .map(|user| proto::User {
2492 id: user.id.to_proto(),
2493 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2494 github_login: user.github_login,
2495 email: user.email_address,
2496 name: user.name,
2497 })
2498 .collect();
2499 response.send(proto::UsersResponse { users })?;
2500 Ok(())
2501}
2502
2503/// Search for users (to invite) buy Github login
2504async fn fuzzy_search_users(
2505 request: proto::FuzzySearchUsers,
2506 response: Response<proto::FuzzySearchUsers>,
2507 session: Session,
2508) -> Result<()> {
2509 let query = request.query;
2510 let users = match query.len() {
2511 0 => vec![],
2512 1 | 2 => session
2513 .db()
2514 .await
2515 .get_user_by_github_login(&query)
2516 .await?
2517 .into_iter()
2518 .collect(),
2519 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2520 };
2521 let users = users
2522 .into_iter()
2523 .filter(|user| user.id != session.user_id())
2524 .map(|user| proto::User {
2525 id: user.id.to_proto(),
2526 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2527 github_login: user.github_login,
2528 name: user.name,
2529 email: user.email_address,
2530 })
2531 .collect();
2532 response.send(proto::UsersResponse { users })?;
2533 Ok(())
2534}
2535
2536/// Send a contact request to another user.
2537async fn request_contact(
2538 request: proto::RequestContact,
2539 response: Response<proto::RequestContact>,
2540 session: Session,
2541) -> Result<()> {
2542 let requester_id = session.user_id();
2543 let responder_id = UserId::from_proto(request.responder_id);
2544 if requester_id == responder_id {
2545 return Err(anyhow!("cannot add yourself as a contact"))?;
2546 }
2547
2548 let notifications = session
2549 .db()
2550 .await
2551 .send_contact_request(requester_id, responder_id)
2552 .await?;
2553
2554 // Update outgoing contact requests of requester
2555 let mut update = proto::UpdateContacts::default();
2556 update.outgoing_requests.push(responder_id.to_proto());
2557 for connection_id in session
2558 .connection_pool()
2559 .await
2560 .user_connection_ids(requester_id)
2561 {
2562 session.peer.send(connection_id, update.clone())?;
2563 }
2564
2565 // Update incoming contact requests of responder
2566 let mut update = proto::UpdateContacts::default();
2567 update
2568 .incoming_requests
2569 .push(proto::IncomingContactRequest {
2570 requester_id: requester_id.to_proto(),
2571 });
2572 let connection_pool = session.connection_pool().await;
2573 for connection_id in connection_pool.user_connection_ids(responder_id) {
2574 session.peer.send(connection_id, update.clone())?;
2575 }
2576
2577 send_notifications(&connection_pool, &session.peer, notifications);
2578
2579 response.send(proto::Ack {})?;
2580 Ok(())
2581}
2582
2583/// Accept or decline a contact request
2584async fn respond_to_contact_request(
2585 request: proto::RespondToContactRequest,
2586 response: Response<proto::RespondToContactRequest>,
2587 session: Session,
2588) -> Result<()> {
2589 let responder_id = session.user_id();
2590 let requester_id = UserId::from_proto(request.requester_id);
2591 let db = session.db().await;
2592 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2593 db.dismiss_contact_notification(responder_id, requester_id)
2594 .await?;
2595 } else {
2596 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2597
2598 let notifications = db
2599 .respond_to_contact_request(responder_id, requester_id, accept)
2600 .await?;
2601 let requester_busy = db.is_user_busy(requester_id).await?;
2602 let responder_busy = db.is_user_busy(responder_id).await?;
2603
2604 let pool = session.connection_pool().await;
2605 // Update responder with new contact
2606 let mut update = proto::UpdateContacts::default();
2607 if accept {
2608 update
2609 .contacts
2610 .push(contact_for_user(requester_id, requester_busy, &pool));
2611 }
2612 update
2613 .remove_incoming_requests
2614 .push(requester_id.to_proto());
2615 for connection_id in pool.user_connection_ids(responder_id) {
2616 session.peer.send(connection_id, update.clone())?;
2617 }
2618
2619 // Update requester with new contact
2620 let mut update = proto::UpdateContacts::default();
2621 if accept {
2622 update
2623 .contacts
2624 .push(contact_for_user(responder_id, responder_busy, &pool));
2625 }
2626 update
2627 .remove_outgoing_requests
2628 .push(responder_id.to_proto());
2629
2630 for connection_id in pool.user_connection_ids(requester_id) {
2631 session.peer.send(connection_id, update.clone())?;
2632 }
2633
2634 send_notifications(&pool, &session.peer, notifications);
2635 }
2636
2637 response.send(proto::Ack {})?;
2638 Ok(())
2639}
2640
2641/// Remove a contact.
2642async fn remove_contact(
2643 request: proto::RemoveContact,
2644 response: Response<proto::RemoveContact>,
2645 session: Session,
2646) -> Result<()> {
2647 let requester_id = session.user_id();
2648 let responder_id = UserId::from_proto(request.user_id);
2649 let db = session.db().await;
2650 let (contact_accepted, deleted_notification_id) =
2651 db.remove_contact(requester_id, responder_id).await?;
2652
2653 let pool = session.connection_pool().await;
2654 // Update outgoing contact requests of requester
2655 let mut update = proto::UpdateContacts::default();
2656 if contact_accepted {
2657 update.remove_contacts.push(responder_id.to_proto());
2658 } else {
2659 update
2660 .remove_outgoing_requests
2661 .push(responder_id.to_proto());
2662 }
2663 for connection_id in pool.user_connection_ids(requester_id) {
2664 session.peer.send(connection_id, update.clone())?;
2665 }
2666
2667 // Update incoming contact requests of responder
2668 let mut update = proto::UpdateContacts::default();
2669 if contact_accepted {
2670 update.remove_contacts.push(requester_id.to_proto());
2671 } else {
2672 update
2673 .remove_incoming_requests
2674 .push(requester_id.to_proto());
2675 }
2676 for connection_id in pool.user_connection_ids(responder_id) {
2677 session.peer.send(connection_id, update.clone())?;
2678 if let Some(notification_id) = deleted_notification_id {
2679 session.peer.send(
2680 connection_id,
2681 proto::DeleteNotification {
2682 notification_id: notification_id.to_proto(),
2683 },
2684 )?;
2685 }
2686 }
2687
2688 response.send(proto::Ack {})?;
2689 Ok(())
2690}
2691
2692fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2693 version.0.minor() < 139
2694}
2695
2696async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
2697 let plan = session.current_plan(&session.db().await).await?;
2698
2699 session
2700 .peer
2701 .send(
2702 session.connection_id,
2703 proto::UpdateUserPlan { plan: plan.into() },
2704 )
2705 .trace_err();
2706
2707 Ok(())
2708}
2709
2710async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2711 subscribe_user_to_channels(session.user_id(), &session).await?;
2712 Ok(())
2713}
2714
2715async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2716 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2717 let mut pool = session.connection_pool().await;
2718 for membership in &channels_for_user.channel_memberships {
2719 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2720 }
2721 session.peer.send(
2722 session.connection_id,
2723 build_update_user_channels(&channels_for_user),
2724 )?;
2725 session.peer.send(
2726 session.connection_id,
2727 build_channels_update(channels_for_user),
2728 )?;
2729 Ok(())
2730}
2731
2732/// Creates a new channel.
2733async fn create_channel(
2734 request: proto::CreateChannel,
2735 response: Response<proto::CreateChannel>,
2736 session: Session,
2737) -> Result<()> {
2738 let db = session.db().await;
2739
2740 let parent_id = request.parent_id.map(ChannelId::from_proto);
2741 let (channel, membership) = db
2742 .create_channel(&request.name, parent_id, session.user_id())
2743 .await?;
2744
2745 let root_id = channel.root_id();
2746 let channel = Channel::from_model(channel);
2747
2748 response.send(proto::CreateChannelResponse {
2749 channel: Some(channel.to_proto()),
2750 parent_id: request.parent_id,
2751 })?;
2752
2753 let mut connection_pool = session.connection_pool().await;
2754 if let Some(membership) = membership {
2755 connection_pool.subscribe_to_channel(
2756 membership.user_id,
2757 membership.channel_id,
2758 membership.role,
2759 );
2760 let update = proto::UpdateUserChannels {
2761 channel_memberships: vec![proto::ChannelMembership {
2762 channel_id: membership.channel_id.to_proto(),
2763 role: membership.role.into(),
2764 }],
2765 ..Default::default()
2766 };
2767 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2768 session.peer.send(connection_id, update.clone())?;
2769 }
2770 }
2771
2772 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2773 if !role.can_see_channel(channel.visibility) {
2774 continue;
2775 }
2776
2777 let update = proto::UpdateChannels {
2778 channels: vec![channel.to_proto()],
2779 ..Default::default()
2780 };
2781 session.peer.send(connection_id, update.clone())?;
2782 }
2783
2784 Ok(())
2785}
2786
2787/// Delete a channel
2788async fn delete_channel(
2789 request: proto::DeleteChannel,
2790 response: Response<proto::DeleteChannel>,
2791 session: Session,
2792) -> Result<()> {
2793 let db = session.db().await;
2794
2795 let channel_id = request.channel_id;
2796 let (root_channel, removed_channels) = db
2797 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2798 .await?;
2799 response.send(proto::Ack {})?;
2800
2801 // Notify members of removed channels
2802 let mut update = proto::UpdateChannels::default();
2803 update
2804 .delete_channels
2805 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2806
2807 let connection_pool = session.connection_pool().await;
2808 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2809 session.peer.send(connection_id, update.clone())?;
2810 }
2811
2812 Ok(())
2813}
2814
2815/// Invite someone to join a channel.
2816async fn invite_channel_member(
2817 request: proto::InviteChannelMember,
2818 response: Response<proto::InviteChannelMember>,
2819 session: Session,
2820) -> Result<()> {
2821 let db = session.db().await;
2822 let channel_id = ChannelId::from_proto(request.channel_id);
2823 let invitee_id = UserId::from_proto(request.user_id);
2824 let InviteMemberResult {
2825 channel,
2826 notifications,
2827 } = db
2828 .invite_channel_member(
2829 channel_id,
2830 invitee_id,
2831 session.user_id(),
2832 request.role().into(),
2833 )
2834 .await?;
2835
2836 let update = proto::UpdateChannels {
2837 channel_invitations: vec![channel.to_proto()],
2838 ..Default::default()
2839 };
2840
2841 let connection_pool = session.connection_pool().await;
2842 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2843 session.peer.send(connection_id, update.clone())?;
2844 }
2845
2846 send_notifications(&connection_pool, &session.peer, notifications);
2847
2848 response.send(proto::Ack {})?;
2849 Ok(())
2850}
2851
2852/// remove someone from a channel
2853async fn remove_channel_member(
2854 request: proto::RemoveChannelMember,
2855 response: Response<proto::RemoveChannelMember>,
2856 session: Session,
2857) -> Result<()> {
2858 let db = session.db().await;
2859 let channel_id = ChannelId::from_proto(request.channel_id);
2860 let member_id = UserId::from_proto(request.user_id);
2861
2862 let RemoveChannelMemberResult {
2863 membership_update,
2864 notification_id,
2865 } = db
2866 .remove_channel_member(channel_id, member_id, session.user_id())
2867 .await?;
2868
2869 let mut connection_pool = session.connection_pool().await;
2870 notify_membership_updated(
2871 &mut connection_pool,
2872 membership_update,
2873 member_id,
2874 &session.peer,
2875 );
2876 for connection_id in connection_pool.user_connection_ids(member_id) {
2877 if let Some(notification_id) = notification_id {
2878 session
2879 .peer
2880 .send(
2881 connection_id,
2882 proto::DeleteNotification {
2883 notification_id: notification_id.to_proto(),
2884 },
2885 )
2886 .trace_err();
2887 }
2888 }
2889
2890 response.send(proto::Ack {})?;
2891 Ok(())
2892}
2893
2894/// Toggle the channel between public and private.
2895/// Care is taken to maintain the invariant that public channels only descend from public channels,
2896/// (though members-only channels can appear at any point in the hierarchy).
2897async fn set_channel_visibility(
2898 request: proto::SetChannelVisibility,
2899 response: Response<proto::SetChannelVisibility>,
2900 session: Session,
2901) -> Result<()> {
2902 let db = session.db().await;
2903 let channel_id = ChannelId::from_proto(request.channel_id);
2904 let visibility = request.visibility().into();
2905
2906 let channel_model = db
2907 .set_channel_visibility(channel_id, visibility, session.user_id())
2908 .await?;
2909 let root_id = channel_model.root_id();
2910 let channel = Channel::from_model(channel_model);
2911
2912 let mut connection_pool = session.connection_pool().await;
2913 for (user_id, role) in connection_pool
2914 .channel_user_ids(root_id)
2915 .collect::<Vec<_>>()
2916 .into_iter()
2917 {
2918 let update = if role.can_see_channel(channel.visibility) {
2919 connection_pool.subscribe_to_channel(user_id, channel_id, role);
2920 proto::UpdateChannels {
2921 channels: vec![channel.to_proto()],
2922 ..Default::default()
2923 }
2924 } else {
2925 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2926 proto::UpdateChannels {
2927 delete_channels: vec![channel.id.to_proto()],
2928 ..Default::default()
2929 }
2930 };
2931
2932 for connection_id in connection_pool.user_connection_ids(user_id) {
2933 session.peer.send(connection_id, update.clone())?;
2934 }
2935 }
2936
2937 response.send(proto::Ack {})?;
2938 Ok(())
2939}
2940
2941/// Alter the role for a user in the channel.
2942async fn set_channel_member_role(
2943 request: proto::SetChannelMemberRole,
2944 response: Response<proto::SetChannelMemberRole>,
2945 session: Session,
2946) -> Result<()> {
2947 let db = session.db().await;
2948 let channel_id = ChannelId::from_proto(request.channel_id);
2949 let member_id = UserId::from_proto(request.user_id);
2950 let result = db
2951 .set_channel_member_role(
2952 channel_id,
2953 session.user_id(),
2954 member_id,
2955 request.role().into(),
2956 )
2957 .await?;
2958
2959 match result {
2960 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2961 let mut connection_pool = session.connection_pool().await;
2962 notify_membership_updated(
2963 &mut connection_pool,
2964 membership_update,
2965 member_id,
2966 &session.peer,
2967 )
2968 }
2969 db::SetMemberRoleResult::InviteUpdated(channel) => {
2970 let update = proto::UpdateChannels {
2971 channel_invitations: vec![channel.to_proto()],
2972 ..Default::default()
2973 };
2974
2975 for connection_id in session
2976 .connection_pool()
2977 .await
2978 .user_connection_ids(member_id)
2979 {
2980 session.peer.send(connection_id, update.clone())?;
2981 }
2982 }
2983 }
2984
2985 response.send(proto::Ack {})?;
2986 Ok(())
2987}
2988
2989/// Change the name of a channel
2990async fn rename_channel(
2991 request: proto::RenameChannel,
2992 response: Response<proto::RenameChannel>,
2993 session: Session,
2994) -> Result<()> {
2995 let db = session.db().await;
2996 let channel_id = ChannelId::from_proto(request.channel_id);
2997 let channel_model = db
2998 .rename_channel(channel_id, session.user_id(), &request.name)
2999 .await?;
3000 let root_id = channel_model.root_id();
3001 let channel = Channel::from_model(channel_model);
3002
3003 response.send(proto::RenameChannelResponse {
3004 channel: Some(channel.to_proto()),
3005 })?;
3006
3007 let connection_pool = session.connection_pool().await;
3008 let update = proto::UpdateChannels {
3009 channels: vec![channel.to_proto()],
3010 ..Default::default()
3011 };
3012 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3013 if role.can_see_channel(channel.visibility) {
3014 session.peer.send(connection_id, update.clone())?;
3015 }
3016 }
3017
3018 Ok(())
3019}
3020
3021/// Move a channel to a new parent.
3022async fn move_channel(
3023 request: proto::MoveChannel,
3024 response: Response<proto::MoveChannel>,
3025 session: Session,
3026) -> Result<()> {
3027 let channel_id = ChannelId::from_proto(request.channel_id);
3028 let to = ChannelId::from_proto(request.to);
3029
3030 let (root_id, channels) = session
3031 .db()
3032 .await
3033 .move_channel(channel_id, to, session.user_id())
3034 .await?;
3035
3036 let connection_pool = session.connection_pool().await;
3037 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3038 let channels = channels
3039 .iter()
3040 .filter_map(|channel| {
3041 if role.can_see_channel(channel.visibility) {
3042 Some(channel.to_proto())
3043 } else {
3044 None
3045 }
3046 })
3047 .collect::<Vec<_>>();
3048 if channels.is_empty() {
3049 continue;
3050 }
3051
3052 let update = proto::UpdateChannels {
3053 channels,
3054 ..Default::default()
3055 };
3056
3057 session.peer.send(connection_id, update.clone())?;
3058 }
3059
3060 response.send(Ack {})?;
3061 Ok(())
3062}
3063
3064/// Get the list of channel members
3065async fn get_channel_members(
3066 request: proto::GetChannelMembers,
3067 response: Response<proto::GetChannelMembers>,
3068 session: Session,
3069) -> Result<()> {
3070 let db = session.db().await;
3071 let channel_id = ChannelId::from_proto(request.channel_id);
3072 let limit = if request.limit == 0 {
3073 u16::MAX as u64
3074 } else {
3075 request.limit
3076 };
3077 let (members, users) = db
3078 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3079 .await?;
3080 response.send(proto::GetChannelMembersResponse { members, users })?;
3081 Ok(())
3082}
3083
3084/// Accept or decline a channel invitation.
3085async fn respond_to_channel_invite(
3086 request: proto::RespondToChannelInvite,
3087 response: Response<proto::RespondToChannelInvite>,
3088 session: Session,
3089) -> Result<()> {
3090 let db = session.db().await;
3091 let channel_id = ChannelId::from_proto(request.channel_id);
3092 let RespondToChannelInvite {
3093 membership_update,
3094 notifications,
3095 } = db
3096 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3097 .await?;
3098
3099 let mut connection_pool = session.connection_pool().await;
3100 if let Some(membership_update) = membership_update {
3101 notify_membership_updated(
3102 &mut connection_pool,
3103 membership_update,
3104 session.user_id(),
3105 &session.peer,
3106 );
3107 } else {
3108 let update = proto::UpdateChannels {
3109 remove_channel_invitations: vec![channel_id.to_proto()],
3110 ..Default::default()
3111 };
3112
3113 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3114 session.peer.send(connection_id, update.clone())?;
3115 }
3116 };
3117
3118 send_notifications(&connection_pool, &session.peer, notifications);
3119
3120 response.send(proto::Ack {})?;
3121
3122 Ok(())
3123}
3124
3125/// Join the channels' room
3126async fn join_channel(
3127 request: proto::JoinChannel,
3128 response: Response<proto::JoinChannel>,
3129 session: Session,
3130) -> Result<()> {
3131 let channel_id = ChannelId::from_proto(request.channel_id);
3132 join_channel_internal(channel_id, Box::new(response), session).await
3133}
3134
/// A response handle that can be answered with a [`proto::JoinRoomResponse`].
///
/// `join_channel_internal` takes a boxed implementor of this trait, so it can
/// reply to either of the request types implemented below.
trait JoinChannelInternalResponse {
    /// Consumes the handle and sends `result` as the reply.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// Replies to a `JoinChannel` request by delegating to `Response::<proto::JoinChannel>::send`.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // The fully qualified path names the `Response` type's own `send`
        // (presumably its inherent method; confirm against `Response`).
        Response::<proto::JoinChannel>::send(self, result)
    }
}
// Replies to a `JoinRoom` request by delegating to `Response::<proto::JoinRoom>::send`.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3148
3149async fn join_channel_internal(
3150 channel_id: ChannelId,
3151 response: Box<impl JoinChannelInternalResponse>,
3152 session: Session,
3153) -> Result<()> {
3154 let joined_room = {
3155 let mut db = session.db().await;
3156 // If zed quits without leaving the room, and the user re-opens zed before the
3157 // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
3158 // room they were in.
3159 if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
3160 tracing::info!(
3161 stale_connection_id = %connection,
3162 "cleaning up stale connection",
3163 );
3164 drop(db);
3165 leave_room_for_session(&session, connection).await?;
3166 db = session.db().await;
3167 }
3168
3169 let (joined_room, membership_updated, role) = db
3170 .join_channel(channel_id, session.user_id(), session.connection_id)
3171 .await?;
3172
3173 let live_kit_connection_info =
3174 session
3175 .app_state
3176 .livekit_client
3177 .as_ref()
3178 .and_then(|live_kit| {
3179 let (can_publish, token) = if role == ChannelRole::Guest {
3180 (
3181 false,
3182 live_kit
3183 .guest_token(
3184 &joined_room.room.livekit_room,
3185 &session.user_id().to_string(),
3186 )
3187 .trace_err()?,
3188 )
3189 } else {
3190 (
3191 true,
3192 live_kit
3193 .room_token(
3194 &joined_room.room.livekit_room,
3195 &session.user_id().to_string(),
3196 )
3197 .trace_err()?,
3198 )
3199 };
3200
3201 Some(LiveKitConnectionInfo {
3202 server_url: live_kit.url().into(),
3203 token,
3204 can_publish,
3205 })
3206 });
3207
3208 response.send(proto::JoinRoomResponse {
3209 room: Some(joined_room.room.clone()),
3210 channel_id: joined_room
3211 .channel
3212 .as_ref()
3213 .map(|channel| channel.id.to_proto()),
3214 live_kit_connection_info,
3215 })?;
3216
3217 let mut connection_pool = session.connection_pool().await;
3218 if let Some(membership_updated) = membership_updated {
3219 notify_membership_updated(
3220 &mut connection_pool,
3221 membership_updated,
3222 session.user_id(),
3223 &session.peer,
3224 );
3225 }
3226
3227 room_updated(&joined_room.room, &session.peer);
3228
3229 joined_room
3230 };
3231
3232 channel_updated(
3233 &joined_room
3234 .channel
3235 .ok_or_else(|| anyhow!("channel not returned"))?,
3236 &joined_room.room,
3237 &session.peer,
3238 &*session.connection_pool().await,
3239 );
3240
3241 update_user_contacts(session.user_id(), &session).await?;
3242 Ok(())
3243}
3244
3245/// Start editing the channel notes
3246async fn join_channel_buffer(
3247 request: proto::JoinChannelBuffer,
3248 response: Response<proto::JoinChannelBuffer>,
3249 session: Session,
3250) -> Result<()> {
3251 let db = session.db().await;
3252 let channel_id = ChannelId::from_proto(request.channel_id);
3253
3254 let open_response = db
3255 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3256 .await?;
3257
3258 let collaborators = open_response.collaborators.clone();
3259 response.send(open_response)?;
3260
3261 let update = UpdateChannelBufferCollaborators {
3262 channel_id: channel_id.to_proto(),
3263 collaborators: collaborators.clone(),
3264 };
3265 channel_buffer_updated(
3266 session.connection_id,
3267 collaborators
3268 .iter()
3269 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3270 &update,
3271 &session.peer,
3272 );
3273
3274 Ok(())
3275}
3276
3277/// Edit the channel notes
3278async fn update_channel_buffer(
3279 request: proto::UpdateChannelBuffer,
3280 session: Session,
3281) -> Result<()> {
3282 let db = session.db().await;
3283 let channel_id = ChannelId::from_proto(request.channel_id);
3284
3285 let (collaborators, epoch, version) = db
3286 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3287 .await?;
3288
3289 channel_buffer_updated(
3290 session.connection_id,
3291 collaborators.clone(),
3292 &proto::UpdateChannelBuffer {
3293 channel_id: channel_id.to_proto(),
3294 operations: request.operations,
3295 },
3296 &session.peer,
3297 );
3298
3299 let pool = &*session.connection_pool().await;
3300
3301 let non_collaborators =
3302 pool.channel_connection_ids(channel_id)
3303 .filter_map(|(connection_id, _)| {
3304 if collaborators.contains(&connection_id) {
3305 None
3306 } else {
3307 Some(connection_id)
3308 }
3309 });
3310
3311 broadcast(None, non_collaborators, |peer_id| {
3312 session.peer.send(
3313 peer_id,
3314 proto::UpdateChannels {
3315 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3316 channel_id: channel_id.to_proto(),
3317 epoch: epoch as u64,
3318 version: version.clone(),
3319 }],
3320 ..Default::default()
3321 },
3322 )
3323 });
3324
3325 Ok(())
3326}
3327
3328/// Rejoin the channel notes after a connection blip
3329async fn rejoin_channel_buffers(
3330 request: proto::RejoinChannelBuffers,
3331 response: Response<proto::RejoinChannelBuffers>,
3332 session: Session,
3333) -> Result<()> {
3334 let db = session.db().await;
3335 let buffers = db
3336 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3337 .await?;
3338
3339 for rejoined_buffer in &buffers {
3340 let collaborators_to_notify = rejoined_buffer
3341 .buffer
3342 .collaborators
3343 .iter()
3344 .filter_map(|c| Some(c.peer_id?.into()));
3345 channel_buffer_updated(
3346 session.connection_id,
3347 collaborators_to_notify,
3348 &proto::UpdateChannelBufferCollaborators {
3349 channel_id: rejoined_buffer.buffer.channel_id,
3350 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3351 },
3352 &session.peer,
3353 );
3354 }
3355
3356 response.send(proto::RejoinChannelBuffersResponse {
3357 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3358 })?;
3359
3360 Ok(())
3361}
3362
3363/// Stop editing the channel notes
3364async fn leave_channel_buffer(
3365 request: proto::LeaveChannelBuffer,
3366 response: Response<proto::LeaveChannelBuffer>,
3367 session: Session,
3368) -> Result<()> {
3369 let db = session.db().await;
3370 let channel_id = ChannelId::from_proto(request.channel_id);
3371
3372 let left_buffer = db
3373 .leave_channel_buffer(channel_id, session.connection_id)
3374 .await?;
3375
3376 response.send(Ack {})?;
3377
3378 channel_buffer_updated(
3379 session.connection_id,
3380 left_buffer.connections,
3381 &proto::UpdateChannelBufferCollaborators {
3382 channel_id: channel_id.to_proto(),
3383 collaborators: left_buffer.collaborators,
3384 },
3385 &session.peer,
3386 );
3387
3388 Ok(())
3389}
3390
3391fn channel_buffer_updated<T: EnvelopedMessage>(
3392 sender_id: ConnectionId,
3393 collaborators: impl IntoIterator<Item = ConnectionId>,
3394 message: &T,
3395 peer: &Peer,
3396) {
3397 broadcast(Some(sender_id), collaborators, |peer_id| {
3398 peer.send(peer_id, message.clone())
3399 });
3400}
3401
3402fn send_notifications(
3403 connection_pool: &ConnectionPool,
3404 peer: &Peer,
3405 notifications: db::NotificationBatch,
3406) {
3407 for (user_id, notification) in notifications {
3408 for connection_id in connection_pool.user_connection_ids(user_id) {
3409 if let Err(error) = peer.send(
3410 connection_id,
3411 proto::AddNotification {
3412 notification: Some(notification.clone()),
3413 },
3414 ) {
3415 tracing::error!(
3416 "failed to send notification to {:?} {}",
3417 connection_id,
3418 error
3419 );
3420 }
3421 }
3422 }
3423}
3424
3425/// Send a message to the channel
3426async fn send_channel_message(
3427 request: proto::SendChannelMessage,
3428 response: Response<proto::SendChannelMessage>,
3429 session: Session,
3430) -> Result<()> {
3431 // Validate the message body.
3432 let body = request.body.trim().to_string();
3433 if body.len() > MAX_MESSAGE_LEN {
3434 return Err(anyhow!("message is too long"))?;
3435 }
3436 if body.is_empty() {
3437 return Err(anyhow!("message can't be blank"))?;
3438 }
3439
3440 // TODO: adjust mentions if body is trimmed
3441
3442 let timestamp = OffsetDateTime::now_utc();
3443 let nonce = request
3444 .nonce
3445 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3446
3447 let channel_id = ChannelId::from_proto(request.channel_id);
3448 let CreatedChannelMessage {
3449 message_id,
3450 participant_connection_ids,
3451 notifications,
3452 } = session
3453 .db()
3454 .await
3455 .create_channel_message(
3456 channel_id,
3457 session.user_id(),
3458 &body,
3459 &request.mentions,
3460 timestamp,
3461 nonce.clone().into(),
3462 request.reply_to_message_id.map(MessageId::from_proto),
3463 )
3464 .await?;
3465
3466 let message = proto::ChannelMessage {
3467 sender_id: session.user_id().to_proto(),
3468 id: message_id.to_proto(),
3469 body,
3470 mentions: request.mentions,
3471 timestamp: timestamp.unix_timestamp() as u64,
3472 nonce: Some(nonce),
3473 reply_to_message_id: request.reply_to_message_id,
3474 edited_at: None,
3475 };
3476 broadcast(
3477 Some(session.connection_id),
3478 participant_connection_ids.clone(),
3479 |connection| {
3480 session.peer.send(
3481 connection,
3482 proto::ChannelMessageSent {
3483 channel_id: channel_id.to_proto(),
3484 message: Some(message.clone()),
3485 },
3486 )
3487 },
3488 );
3489 response.send(proto::SendChannelMessageResponse {
3490 message: Some(message),
3491 })?;
3492
3493 let pool = &*session.connection_pool().await;
3494 let non_participants =
3495 pool.channel_connection_ids(channel_id)
3496 .filter_map(|(connection_id, _)| {
3497 if participant_connection_ids.contains(&connection_id) {
3498 None
3499 } else {
3500 Some(connection_id)
3501 }
3502 });
3503 broadcast(None, non_participants, |peer_id| {
3504 session.peer.send(
3505 peer_id,
3506 proto::UpdateChannels {
3507 latest_channel_message_ids: vec![proto::ChannelMessageId {
3508 channel_id: channel_id.to_proto(),
3509 message_id: message_id.to_proto(),
3510 }],
3511 ..Default::default()
3512 },
3513 )
3514 });
3515 send_notifications(pool, &session.peer, notifications);
3516
3517 Ok(())
3518}
3519
3520/// Delete a channel message
3521async fn remove_channel_message(
3522 request: proto::RemoveChannelMessage,
3523 response: Response<proto::RemoveChannelMessage>,
3524 session: Session,
3525) -> Result<()> {
3526 let channel_id = ChannelId::from_proto(request.channel_id);
3527 let message_id = MessageId::from_proto(request.message_id);
3528 let (connection_ids, existing_notification_ids) = session
3529 .db()
3530 .await
3531 .remove_channel_message(channel_id, message_id, session.user_id())
3532 .await?;
3533
3534 broadcast(
3535 Some(session.connection_id),
3536 connection_ids,
3537 move |connection| {
3538 session.peer.send(connection, request.clone())?;
3539
3540 for notification_id in &existing_notification_ids {
3541 session.peer.send(
3542 connection,
3543 proto::DeleteNotification {
3544 notification_id: (*notification_id).to_proto(),
3545 },
3546 )?;
3547 }
3548
3549 Ok(())
3550 },
3551 );
3552 response.send(proto::Ack {})?;
3553 Ok(())
3554}
3555
3556async fn update_channel_message(
3557 request: proto::UpdateChannelMessage,
3558 response: Response<proto::UpdateChannelMessage>,
3559 session: Session,
3560) -> Result<()> {
3561 let channel_id = ChannelId::from_proto(request.channel_id);
3562 let message_id = MessageId::from_proto(request.message_id);
3563 let updated_at = OffsetDateTime::now_utc();
3564 let UpdatedChannelMessage {
3565 message_id,
3566 participant_connection_ids,
3567 notifications,
3568 reply_to_message_id,
3569 timestamp,
3570 deleted_mention_notification_ids,
3571 updated_mention_notifications,
3572 } = session
3573 .db()
3574 .await
3575 .update_channel_message(
3576 channel_id,
3577 message_id,
3578 session.user_id(),
3579 request.body.as_str(),
3580 &request.mentions,
3581 updated_at,
3582 )
3583 .await?;
3584
3585 let nonce = request
3586 .nonce
3587 .clone()
3588 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3589
3590 let message = proto::ChannelMessage {
3591 sender_id: session.user_id().to_proto(),
3592 id: message_id.to_proto(),
3593 body: request.body.clone(),
3594 mentions: request.mentions.clone(),
3595 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3596 nonce: Some(nonce),
3597 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3598 edited_at: Some(updated_at.unix_timestamp() as u64),
3599 };
3600
3601 response.send(proto::Ack {})?;
3602
3603 let pool = &*session.connection_pool().await;
3604 broadcast(
3605 Some(session.connection_id),
3606 participant_connection_ids,
3607 |connection| {
3608 session.peer.send(
3609 connection,
3610 proto::ChannelMessageUpdate {
3611 channel_id: channel_id.to_proto(),
3612 message: Some(message.clone()),
3613 },
3614 )?;
3615
3616 for notification_id in &deleted_mention_notification_ids {
3617 session.peer.send(
3618 connection,
3619 proto::DeleteNotification {
3620 notification_id: (*notification_id).to_proto(),
3621 },
3622 )?;
3623 }
3624
3625 for notification in &updated_mention_notifications {
3626 session.peer.send(
3627 connection,
3628 proto::UpdateNotification {
3629 notification: Some(notification.clone()),
3630 },
3631 )?;
3632 }
3633
3634 Ok(())
3635 },
3636 );
3637
3638 send_notifications(pool, &session.peer, notifications);
3639
3640 Ok(())
3641}
3642
3643/// Mark a channel message as read
3644async fn acknowledge_channel_message(
3645 request: proto::AckChannelMessage,
3646 session: Session,
3647) -> Result<()> {
3648 let channel_id = ChannelId::from_proto(request.channel_id);
3649 let message_id = MessageId::from_proto(request.message_id);
3650 let notifications = session
3651 .db()
3652 .await
3653 .observe_channel_message(channel_id, session.user_id(), message_id)
3654 .await?;
3655 send_notifications(
3656 &*session.connection_pool().await,
3657 &session.peer,
3658 notifications,
3659 );
3660 Ok(())
3661}
3662
3663/// Mark a buffer version as synced
3664async fn acknowledge_buffer_version(
3665 request: proto::AckBufferOperation,
3666 session: Session,
3667) -> Result<()> {
3668 let buffer_id = BufferId::from_proto(request.buffer_id);
3669 session
3670 .db()
3671 .await
3672 .observe_buffer_version(
3673 buffer_id,
3674 session.user_id(),
3675 request.epoch as i32,
3676 &request.version,
3677 )
3678 .await?;
3679 Ok(())
3680}
3681
3682/// Get a Supermaven API key for the user
3683async fn get_supermaven_api_key(
3684 _request: proto::GetSupermavenApiKey,
3685 response: Response<proto::GetSupermavenApiKey>,
3686 session: Session,
3687) -> Result<()> {
3688 let user_id: String = session.user_id().to_string();
3689 if !session.is_staff() {
3690 return Err(anyhow!("supermaven not enabled for this account"))?;
3691 }
3692
3693 let email = session
3694 .email()
3695 .ok_or_else(|| anyhow!("user must have an email"))?;
3696
3697 let supermaven_admin_api = session
3698 .supermaven_client
3699 .as_ref()
3700 .ok_or_else(|| anyhow!("supermaven not configured"))?;
3701
3702 let result = supermaven_admin_api
3703 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3704 .await?;
3705
3706 response.send(proto::GetSupermavenApiKeyResponse {
3707 api_key: result.api_key,
3708 })?;
3709
3710 Ok(())
3711}
3712
3713/// Start receiving chat updates for a channel
3714async fn join_channel_chat(
3715 request: proto::JoinChannelChat,
3716 response: Response<proto::JoinChannelChat>,
3717 session: Session,
3718) -> Result<()> {
3719 let channel_id = ChannelId::from_proto(request.channel_id);
3720
3721 let db = session.db().await;
3722 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3723 .await?;
3724 let messages = db
3725 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3726 .await?;
3727 response.send(proto::JoinChannelChatResponse {
3728 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3729 messages,
3730 })?;
3731 Ok(())
3732}
3733
3734/// Stop receiving chat updates for a channel
3735async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3736 let channel_id = ChannelId::from_proto(request.channel_id);
3737 session
3738 .db()
3739 .await
3740 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3741 .await?;
3742 Ok(())
3743}
3744
3745/// Retrieve the chat history for a channel
3746async fn get_channel_messages(
3747 request: proto::GetChannelMessages,
3748 response: Response<proto::GetChannelMessages>,
3749 session: Session,
3750) -> Result<()> {
3751 let channel_id = ChannelId::from_proto(request.channel_id);
3752 let messages = session
3753 .db()
3754 .await
3755 .get_channel_messages(
3756 channel_id,
3757 session.user_id(),
3758 MESSAGE_COUNT_PER_PAGE,
3759 Some(MessageId::from_proto(request.before_message_id)),
3760 )
3761 .await?;
3762 response.send(proto::GetChannelMessagesResponse {
3763 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3764 messages,
3765 })?;
3766 Ok(())
3767}
3768
3769/// Retrieve specific chat messages
3770async fn get_channel_messages_by_id(
3771 request: proto::GetChannelMessagesById,
3772 response: Response<proto::GetChannelMessagesById>,
3773 session: Session,
3774) -> Result<()> {
3775 let message_ids = request
3776 .message_ids
3777 .iter()
3778 .map(|id| MessageId::from_proto(*id))
3779 .collect::<Vec<_>>();
3780 let messages = session
3781 .db()
3782 .await
3783 .get_channel_messages_by_id(session.user_id(), &message_ids)
3784 .await?;
3785 response.send(proto::GetChannelMessagesResponse {
3786 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3787 messages,
3788 })?;
3789 Ok(())
3790}
3791
3792/// Retrieve the current users notifications
3793async fn get_notifications(
3794 request: proto::GetNotifications,
3795 response: Response<proto::GetNotifications>,
3796 session: Session,
3797) -> Result<()> {
3798 let notifications = session
3799 .db()
3800 .await
3801 .get_notifications(
3802 session.user_id(),
3803 NOTIFICATION_COUNT_PER_PAGE,
3804 request.before_id.map(db::NotificationId::from_proto),
3805 )
3806 .await?;
3807 response.send(proto::GetNotificationsResponse {
3808 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3809 notifications,
3810 })?;
3811 Ok(())
3812}
3813
3814/// Mark notifications as read
3815async fn mark_notification_as_read(
3816 request: proto::MarkNotificationRead,
3817 response: Response<proto::MarkNotificationRead>,
3818 session: Session,
3819) -> Result<()> {
3820 let database = &session.db().await;
3821 let notifications = database
3822 .mark_notification_as_read_by_id(
3823 session.user_id(),
3824 NotificationId::from_proto(request.notification_id),
3825 )
3826 .await?;
3827 send_notifications(
3828 &*session.connection_pool().await,
3829 &session.peer,
3830 notifications,
3831 );
3832 response.send(proto::Ack {})?;
3833 Ok(())
3834}
3835
3836/// Get the current users information
3837async fn get_private_user_info(
3838 _request: proto::GetPrivateUserInfo,
3839 response: Response<proto::GetPrivateUserInfo>,
3840 session: Session,
3841) -> Result<()> {
3842 let db = session.db().await;
3843
3844 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
3845 let user = db
3846 .get_user_by_id(session.user_id())
3847 .await?
3848 .ok_or_else(|| anyhow!("user not found"))?;
3849 let flags = db.get_user_flags(session.user_id()).await?;
3850
3851 response.send(proto::GetPrivateUserInfoResponse {
3852 metrics_id,
3853 staff: user.admin,
3854 flags,
3855 accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
3856 })?;
3857 Ok(())
3858}
3859
3860/// Accept the terms of service (tos) on behalf of the current user
3861async fn accept_terms_of_service(
3862 _request: proto::AcceptTermsOfService,
3863 response: Response<proto::AcceptTermsOfService>,
3864 session: Session,
3865) -> Result<()> {
3866 let db = session.db().await;
3867
3868 let accepted_tos_at = Utc::now();
3869 db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
3870 .await?;
3871
3872 response.send(proto::AcceptTermsOfServiceResponse {
3873 accepted_tos_at: accepted_tos_at.timestamp() as u64,
3874 })?;
3875 Ok(())
3876}
3877
/// The minimum age (30 days) an account must have in order to use the LLM service.
pub const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
3880
3881async fn get_llm_api_token(
3882 _request: proto::GetLlmToken,
3883 response: Response<proto::GetLlmToken>,
3884 session: Session,
3885) -> Result<()> {
3886 let db = session.db().await;
3887
3888 let flags = db.get_user_flags(session.user_id()).await?;
3889 let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
3890
3891 if !session.is_staff() && !has_language_models_feature_flag {
3892 Err(anyhow!("permission denied"))?
3893 }
3894
3895 let user_id = session.user_id();
3896 let user = db
3897 .get_user_by_id(user_id)
3898 .await?
3899 .ok_or_else(|| anyhow!("user {} not found", user_id))?;
3900
3901 if user.accepted_tos_at.is_none() {
3902 Err(anyhow!("terms of service not accepted"))?
3903 }
3904
3905 let has_legacy_llm_subscription = session.has_llm_subscription(&db).await?;
3906 let billing_subscription = db.get_active_billing_subscription(user.id).await?;
3907 let billing_preferences = db.get_billing_preferences(user.id).await?;
3908
3909 let token = LlmTokenClaims::create(
3910 &user,
3911 session.is_staff(),
3912 billing_preferences,
3913 &flags,
3914 has_legacy_llm_subscription,
3915 billing_subscription,
3916 session.system_id.clone(),
3917 &session.app_state.config,
3918 )?;
3919 response.send(proto::GetLlmTokenResponse { token })?;
3920 Ok(())
3921}
3922
3923fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
3924 let message = match message {
3925 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
3926 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
3927 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
3928 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
3929 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3930 code: frame.code.into(),
3931 reason: frame.reason.as_str().to_owned().into(),
3932 })),
3933 // We should never receive a frame while reading the message, according
3934 // to the `tungstenite` maintainers:
3935 //
3936 // > It cannot occur when you read messages from the WebSocket, but it
3937 // > can be used when you want to send the raw frames (e.g. you want to
3938 // > send the frames to the WebSocket without composing the full message first).
3939 // >
3940 // > — https://github.com/snapview/tungstenite-rs/issues/268
3941 TungsteniteMessage::Frame(_) => {
3942 bail!("received an unexpected frame while reading the message")
3943 }
3944 };
3945
3946 Ok(message)
3947}
3948
3949fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3950 match message {
3951 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
3952 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
3953 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
3954 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
3955 AxumMessage::Close(frame) => {
3956 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3957 code: frame.code.into(),
3958 reason: frame.reason.as_ref().into(),
3959 }))
3960 }
3961 }
3962}
3963
3964fn notify_membership_updated(
3965 connection_pool: &mut ConnectionPool,
3966 result: MembershipUpdated,
3967 user_id: UserId,
3968 peer: &Peer,
3969) {
3970 for membership in &result.new_channels.channel_memberships {
3971 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3972 }
3973 for channel_id in &result.removed_channels {
3974 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
3975 }
3976
3977 let user_channels_update = proto::UpdateUserChannels {
3978 channel_memberships: result
3979 .new_channels
3980 .channel_memberships
3981 .iter()
3982 .map(|cm| proto::ChannelMembership {
3983 channel_id: cm.channel_id.to_proto(),
3984 role: cm.role.into(),
3985 })
3986 .collect(),
3987 ..Default::default()
3988 };
3989
3990 let mut update = build_channels_update(result.new_channels);
3991 update.delete_channels = result
3992 .removed_channels
3993 .into_iter()
3994 .map(|id| id.to_proto())
3995 .collect();
3996 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3997
3998 for connection_id in connection_pool.user_connection_ids(user_id) {
3999 peer.send(connection_id, user_channels_update.clone())
4000 .trace_err();
4001 peer.send(connection_id, update.clone()).trace_err();
4002 }
4003}
4004
4005fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4006 proto::UpdateUserChannels {
4007 channel_memberships: channels
4008 .channel_memberships
4009 .iter()
4010 .map(|m| proto::ChannelMembership {
4011 channel_id: m.channel_id.to_proto(),
4012 role: m.role.into(),
4013 })
4014 .collect(),
4015 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4016 observed_channel_message_id: channels.observed_channel_messages.clone(),
4017 }
4018}
4019
4020fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4021 let mut update = proto::UpdateChannels::default();
4022
4023 for channel in channels.channels {
4024 update.channels.push(channel.to_proto());
4025 }
4026
4027 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4028 update.latest_channel_message_ids = channels.latest_channel_messages;
4029
4030 for (channel_id, participants) in channels.channel_participants {
4031 update
4032 .channel_participants
4033 .push(proto::ChannelParticipants {
4034 channel_id: channel_id.to_proto(),
4035 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4036 });
4037 }
4038
4039 for channel in channels.invited_channels {
4040 update.channel_invitations.push(channel.to_proto());
4041 }
4042
4043 update
4044}
4045
4046fn build_initial_contacts_update(
4047 contacts: Vec<db::Contact>,
4048 pool: &ConnectionPool,
4049) -> proto::UpdateContacts {
4050 let mut update = proto::UpdateContacts::default();
4051
4052 for contact in contacts {
4053 match contact {
4054 db::Contact::Accepted { user_id, busy } => {
4055 update.contacts.push(contact_for_user(user_id, busy, pool));
4056 }
4057 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4058 db::Contact::Incoming { user_id } => {
4059 update
4060 .incoming_requests
4061 .push(proto::IncomingContactRequest {
4062 requester_id: user_id.to_proto(),
4063 })
4064 }
4065 }
4066 }
4067
4068 update
4069}
4070
4071fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4072 proto::Contact {
4073 user_id: user_id.to_proto(),
4074 online: pool.is_user_online(user_id),
4075 busy,
4076 }
4077}
4078
4079fn room_updated(room: &proto::Room, peer: &Peer) {
4080 broadcast(
4081 None,
4082 room.participants
4083 .iter()
4084 .filter_map(|participant| Some(participant.peer_id?.into())),
4085 |peer_id| {
4086 peer.send(
4087 peer_id,
4088 proto::RoomUpdated {
4089 room: Some(room.clone()),
4090 },
4091 )
4092 },
4093 );
4094}
4095
4096fn channel_updated(
4097 channel: &db::channel::Model,
4098 room: &proto::Room,
4099 peer: &Peer,
4100 pool: &ConnectionPool,
4101) {
4102 let participants = room
4103 .participants
4104 .iter()
4105 .map(|p| p.user_id)
4106 .collect::<Vec<_>>();
4107
4108 broadcast(
4109 None,
4110 pool.channel_connection_ids(channel.root_id())
4111 .filter_map(|(channel_id, role)| {
4112 role.can_see_channel(channel.visibility)
4113 .then_some(channel_id)
4114 }),
4115 |peer_id| {
4116 peer.send(
4117 peer_id,
4118 proto::UpdateChannels {
4119 channel_participants: vec![proto::ChannelParticipants {
4120 channel_id: channel.id.to_proto(),
4121 participant_user_ids: participants.clone(),
4122 }],
4123 ..Default::default()
4124 },
4125 )
4126 },
4127 );
4128}
4129
4130async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4131 let db = session.db().await;
4132
4133 let contacts = db.get_contacts(user_id).await?;
4134 let busy = db.is_user_busy(user_id).await?;
4135
4136 let pool = session.connection_pool().await;
4137 let updated_contact = contact_for_user(user_id, busy, &pool);
4138 for contact in contacts {
4139 if let db::Contact::Accepted {
4140 user_id: contact_user_id,
4141 ..
4142 } = contact
4143 {
4144 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4145 session
4146 .peer
4147 .send(
4148 contact_conn_id,
4149 proto::UpdateContacts {
4150 contacts: vec![updated_contact.clone()],
4151 remove_contacts: Default::default(),
4152 incoming_requests: Default::default(),
4153 remove_incoming_requests: Default::default(),
4154 outgoing_requests: Default::default(),
4155 remove_outgoing_requests: Default::default(),
4156 },
4157 )
4158 .trace_err();
4159 }
4160 }
4161 }
4162 Ok(())
4163}
4164
4165async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
4166 let mut contacts_to_update = HashSet::default();
4167
4168 let room_id;
4169 let canceled_calls_to_user_ids;
4170 let livekit_room;
4171 let delete_livekit_room;
4172 let room;
4173 let channel;
4174
4175 if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
4176 contacts_to_update.insert(session.user_id());
4177
4178 for project in left_room.left_projects.values() {
4179 project_left(project, session);
4180 }
4181
4182 room_id = RoomId::from_proto(left_room.room.id);
4183 canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
4184 livekit_room = mem::take(&mut left_room.room.livekit_room);
4185 delete_livekit_room = left_room.deleted;
4186 room = mem::take(&mut left_room.room);
4187 channel = mem::take(&mut left_room.channel);
4188
4189 room_updated(&room, &session.peer);
4190 } else {
4191 return Ok(());
4192 }
4193
4194 if let Some(channel) = channel {
4195 channel_updated(
4196 &channel,
4197 &room,
4198 &session.peer,
4199 &*session.connection_pool().await,
4200 );
4201 }
4202
4203 {
4204 let pool = session.connection_pool().await;
4205 for canceled_user_id in canceled_calls_to_user_ids {
4206 for connection_id in pool.user_connection_ids(canceled_user_id) {
4207 session
4208 .peer
4209 .send(
4210 connection_id,
4211 proto::CallCanceled {
4212 room_id: room_id.to_proto(),
4213 },
4214 )
4215 .trace_err();
4216 }
4217 contacts_to_update.insert(canceled_user_id);
4218 }
4219 }
4220
4221 for contact_user_id in contacts_to_update {
4222 update_user_contacts(contact_user_id, session).await?;
4223 }
4224
4225 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
4226 live_kit
4227 .remove_participant(livekit_room.clone(), session.user_id().to_string())
4228 .await
4229 .trace_err();
4230
4231 if delete_livekit_room {
4232 live_kit.delete_room(livekit_room).await.trace_err();
4233 }
4234 }
4235
4236 Ok(())
4237}
4238
4239async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4240 let left_channel_buffers = session
4241 .db()
4242 .await
4243 .leave_channel_buffers(session.connection_id)
4244 .await?;
4245
4246 for left_buffer in left_channel_buffers {
4247 channel_buffer_updated(
4248 session.connection_id,
4249 left_buffer.connections,
4250 &proto::UpdateChannelBufferCollaborators {
4251 channel_id: left_buffer.channel_id.to_proto(),
4252 collaborators: left_buffer.collaborators,
4253 },
4254 &session.peer,
4255 );
4256 }
4257
4258 Ok(())
4259}
4260
4261fn project_left(project: &db::LeftProject, session: &Session) {
4262 for connection_id in &project.connection_ids {
4263 if project.should_unshare {
4264 session
4265 .peer
4266 .send(
4267 *connection_id,
4268 proto::UnshareProject {
4269 project_id: project.id.to_proto(),
4270 },
4271 )
4272 .trace_err();
4273 } else {
4274 session
4275 .peer
4276 .send(
4277 *connection_id,
4278 proto::RemoveProjectCollaborator {
4279 project_id: project.id.to_proto(),
4280 peer_id: Some(session.connection_id.into()),
4281 },
4282 )
4283 .trace_err();
4284 }
4285 }
4286}
4287
/// Extension for result-like types that turns a failure into a logged `None`.
pub trait ResultExt {
    /// The success type carried by the implementing type.
    type Ok;

    /// Returns the success value, or logs the error and returns `None`.
    fn trace_err(self) -> Option<Self::Ok>;
}
4293
4294impl<T, E> ResultExt for Result<T, E>
4295where
4296 E: std::fmt::Debug,
4297{
4298 type Ok = T;
4299
4300 #[track_caller]
4301 fn trace_err(self) -> Option<T> {
4302 match self {
4303 Ok(value) => Some(value),
4304 Err(error) => {
4305 tracing::error!("{:?}", error);
4306 None
4307 }
4308 }
4309 }
4310}