1mod connection_pool;
2
3use crate::api::billing::find_or_create_billing_customer;
4use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
5use crate::db::billing_subscription::SubscriptionKind;
6use crate::llm::db::LlmDatabase;
7use crate::llm::{
8 AGENT_EXTENDED_TRIAL_FEATURE_FLAG, BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG, LlmTokenClaims,
9 MIN_ACCOUNT_AGE_FOR_LLM_USE,
10};
11use crate::stripe_client::StripeCustomerId;
12use crate::{
13 AppState, Error, Result, auth,
14 db::{
15 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
16 CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
17 NotificationId, ProjectId, RejoinedProject, RemoveChannelMemberResult,
18 RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
19 },
20 executor::Executor,
21};
22use anyhow::{Context as _, anyhow, bail};
23use async_tungstenite::tungstenite::{
24 Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
25};
26use axum::{
27 Extension, Router, TypedHeader,
28 body::Body,
29 extract::{
30 ConnectInfo, WebSocketUpgrade,
31 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
32 },
33 headers::{Header, HeaderName},
34 http::StatusCode,
35 middleware,
36 response::IntoResponse,
37 routing::get,
38};
39use chrono::Utc;
40use collections::{HashMap, HashSet};
41pub use connection_pool::{ConnectionPool, ZedVersion};
42use core::fmt::{self, Debug, Formatter};
43use reqwest_client::ReqwestClient;
44use rpc::proto::split_repository_update;
45use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
46
47use futures::{
48 FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
49 stream::FuturesUnordered,
50};
51use prometheus::{IntGauge, register_int_gauge};
52use rpc::{
53 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
54 proto::{
55 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
56 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
57 },
58};
59use semantic_version::SemanticVersion;
60use serde::{Serialize, Serializer};
61use std::{
62 any::TypeId,
63 future::Future,
64 marker::PhantomData,
65 mem,
66 net::SocketAddr,
67 ops::{Deref, DerefMut},
68 rc::Rc,
69 sync::{
70 Arc, OnceLock,
71 atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
72 },
73 time::{Duration, Instant},
74};
75use time::OffsetDateTime;
76use tokio::sync::{Semaphore, watch};
77use tower::ServiceBuilder;
78use tracing::{
79 Instrument,
80 field::{self},
81 info_span, instrument,
82};
83
/// Timeout of 30 seconds related to client reconnection.
// NOTE(review): not referenced in this part of the file — confirm where it is applied.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// Delay that [`Server::start`] waits before it starts cleaning up stale resources.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can
/// clean up old resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// NOTE(review): the three limits below are not used in this part of the file; judging by
// their names they bound message/notification paging and message length — confirm at use sites.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
/// Upper bound on the number of live [`ConnectionGuard`]s.
const MAX_CONCURRENT_CONNECTIONS: usize = 512;

/// Number of [`ConnectionGuard`]s currently alive; incremented by
/// `ConnectionGuard::try_acquire` and decremented when a guard is dropped.
static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);
95
/// Type-erased message handler as stored in `Server::handlers`: it receives the boxed incoming
/// envelope together with the [`Session`] it arrived on and returns a boxed future that
/// resolves to `()`.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
98
99pub struct ConnectionGuard;
100
101impl ConnectionGuard {
102 pub fn try_acquire() -> Result<Self, ()> {
103 let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
104 if current_connections >= MAX_CONCURRENT_CONNECTIONS {
105 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
106 tracing::error!(
107 "too many concurrent connections: {}",
108 current_connections + 1
109 );
110 return Err(());
111 }
112 Ok(ConnectionGuard)
113 }
114}
115
116impl Drop for ConnectionGuard {
117 fn drop(&mut self) {
118 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
119 }
120}
121
122struct Response<R> {
123 peer: Arc<Peer>,
124 receipt: Receipt<R>,
125 responded: Arc<AtomicBool>,
126}
127
128impl<R: RequestMessage> Response<R> {
129 fn send(self, payload: R::Response) -> Result<()> {
130 self.responded.store(true, SeqCst);
131 self.peer.respond(self.receipt, payload)?;
132 Ok(())
133 }
134}
135
136#[derive(Clone, Debug)]
137pub enum Principal {
138 User(User),
139 Impersonated { user: User, admin: User },
140}
141
142impl Principal {
143 fn user(&self) -> &User {
144 match self {
145 Principal::User(user) => user,
146 Principal::Impersonated { user, .. } => user,
147 }
148 }
149
150 fn update_span(&self, span: &tracing::Span) {
151 match &self {
152 Principal::User(user) => {
153 span.record("user_id", user.id.0);
154 span.record("login", &user.github_login);
155 }
156 Principal::Impersonated { user, admin } => {
157 span.record("user_id", user.id.0);
158 span.record("login", &user.github_login);
159 span.record("impersonator", &admin.github_login);
160 }
161 }
162 }
163}
164
/// Per-connection state handed to every message handler.
///
/// The struct is `Clone`; `handle_connection` clones it for each dispatched message.
#[derive(Clone)]
struct Session {
    /// Identity the connection acts as.
    principal: Principal,
    /// Id assigned to this connection by the peer.
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; access it through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    /// Peer shared with the server, used to send messages and responses.
    peer: Arc<Peer>,
    /// Connection pool shared with the server; access it through `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// Shared application state.
    app_state: Arc<AppState>,
    /// Supermaven admin client; `None` when no admin API key is configured.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    /// The GeoIP country code for the user.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    /// System id supplied when the connection was set up, if any.
    system_id: Option<String>,
    /// Executor the connection was started with (kept but not read through this field).
    _executor: Executor,
}
180
181impl Session {
182 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
183 #[cfg(test)]
184 tokio::task::yield_now().await;
185 let guard = self.db.lock().await;
186 #[cfg(test)]
187 tokio::task::yield_now().await;
188 guard
189 }
190
191 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
192 #[cfg(test)]
193 tokio::task::yield_now().await;
194 let guard = self.connection_pool.lock();
195 ConnectionPoolGuard {
196 guard,
197 _not_send: PhantomData,
198 }
199 }
200
201 fn is_staff(&self) -> bool {
202 match &self.principal {
203 Principal::User(user) => user.admin,
204 Principal::Impersonated { .. } => true,
205 }
206 }
207
208 fn user_id(&self) -> UserId {
209 match &self.principal {
210 Principal::User(user) => user.id,
211 Principal::Impersonated { user, .. } => user.id,
212 }
213 }
214
215 pub fn email(&self) -> Option<String> {
216 match &self.principal {
217 Principal::User(user) => user.email_address.clone(),
218 Principal::Impersonated { user, .. } => user.email_address.clone(),
219 }
220 }
221}
222
223impl Debug for Session {
224 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
225 let mut result = f.debug_struct("Session");
226 match &self.principal {
227 Principal::User(user) => {
228 result.field("user", &user.github_login);
229 }
230 Principal::Impersonated { user, admin } => {
231 result.field("user", &user.github_login);
232 result.field("impersonator", &admin.github_login);
233 }
234 }
235 result.field("connection_id", &self.connection_id).finish()
236 }
237}
238
239struct DbHandle(Arc<Database>);
240
241impl Deref for DbHandle {
242 type Target = Database;
243
244 fn deref(&self) -> &Self::Target {
245 self.0.as_ref()
246 }
247}
248
/// The RPC server: owns the peer, the connection pool and the table of message handlers.
pub struct Server {
    /// Current server id; behind a mutex because the test-only `reset` replaces it.
    id: parking_lot::Mutex<ServerId>,
    /// Peer used for all connection I/O.
    peer: Arc<Peer>,
    /// Pool of live connections, shared with every [`Session`].
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// Shared application state.
    app_state: Arc<AppState>,
    /// Registered handlers, keyed by the `TypeId` of the message type they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Teardown signal: `teardown` sends `true`; connection futures subscribe to it.
    teardown: watch::Sender<bool>,
}
257
/// Guard over the locked [`ConnectionPool`].
pub(crate) struct ConnectionPoolGuard<'a> {
    /// The underlying mutex guard.
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    /// Marker that keeps this type `!Send` (`Rc` is not `Send`).
    _not_send: PhantomData<Rc<()>>,
}
262
/// Serializable view of the server: the peer plus the locked connection pool.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    /// Serialized through `serialize_deref`, i.e. as the value the guard dereferences to.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
269
270pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
271where
272 S: Serializer,
273 T: Deref<Target = U>,
274 U: Serialize,
275{
276 Serialize::serialize(value.deref(), serializer)
277}
278
279impl Server {
    /// Creates a server with the given id and application state and registers every
    /// message and request handler it supports.
    ///
    /// Handlers are keyed by message type; registering two handlers for the same type
    /// panics (see `add_handler`), so each message type must appear at most once below.
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            // Only the sender is kept; receivers are created later via `subscribe`.
            teardown: watch::channel(false).0,
        };

        server
            // Rooms, calls and project sharing.
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(set_room_participant_role)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_request_handler(update_repository)
            .add_request_handler(remove_repository)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Project requests forwarded through either the read-only or the mutating
            // forwarding helper.
            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
            .add_request_handler(forward_find_search_candidates_request)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
            .add_request_handler(
                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
            )
            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
            .add_request_handler(
                forward_read_only_project_request::<proto::LanguageServerIdForName>,
            )
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentDiagnostics>)
            .add_request_handler(
                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
            // Buffer updates and messages that the project host broadcasts.
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
            .add_message_handler(
                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
            )
            // Users, contacts, channels, chat and notifications.
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(update_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            .add_request_handler(move_channel)
            .add_request_handler(reorder_channel)
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_request_handler(get_private_user_info)
            .add_request_handler(get_llm_api_token)
            .add_request_handler(accept_terms_of_service)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version)
            .add_request_handler(get_supermaven_api_key)
            // Assistant contexts, git operations and breakpoints.
            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
            // NOTE(review): `GitReset` and `GitCheckoutFiles` sound like mutating operations
            // but are forwarded through the read-only helper — confirm this is intentional.
            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context);

        Arc::new(server)
    }
455
    /// Schedules the cleanup of stale resources and returns immediately.
    ///
    /// A detached task is spawned that first waits for `CLEANUP_TIMEOUT` and then:
    /// 1. deletes stale channel chat participants,
    /// 2. for every stale channel buffer, clears its stale collaborators and sends the
    ///    refreshed collaborator list to the remaining connections,
    /// 3. for every stale room, clears its stale participants, notifies the room (and its
    ///    channel, if any), sends `CallCanceled` to users whose calls were canceled, sends
    ///    contact updates for affected users, and deletes the LiveKit room when no
    ///    participants remain,
    /// 4. deletes stale channel chat participants again, clears old worktree entries and
    ///    deletes stale servers.
    ///
    /// Failures inside the task are passed through `trace_err` instead of being propagated,
    /// so this function itself always returns `Ok(())`.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // The sleep future is created here, before the task is spawned.
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Refresh channel buffers and tell the remaining connections who is left.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // These stay at their defaults if the room could not be refreshed, in
                        // which case the steps below do nothing for this room.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // The pool lock is confined to this block so that it is released
                        // before the awaits below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            // Only proceed when both lookups succeeded.
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    // Only accepted contacts receive the update.
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Delete the LiveKit room once the refreshed room has no participants.
                        if let Some(live_kit) = livekit_client.as_ref() {
                            if delete_livekit_room {
                                live_kit.delete_room(livekit_room).await.trace_err();
                            }
                        }
                    }
                }

                // Second pass over stale chat participants, after rooms and buffers were handled.
                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
626
    /// Tears down the peer, resets the connection pool and broadcasts the teardown signal
    /// (`true`) to subscribed connection futures.
    pub fn teardown(&self) {
        self.peer.teardown();
        self.connection_pool.lock().reset();
        // A send error is deliberately ignored.
        let _ = self.teardown.send(true);
    }
632
    /// Test-only: tears the server down, gives it a new id, resets the peer with that id
    /// and clears the teardown signal again.
    #[cfg(test)]
    pub fn reset(&self, id: ServerId) {
        self.teardown();
        *self.id.lock() = id;
        self.peer.reset(id.0 as u32);
        // A send error is deliberately ignored.
        let _ = self.teardown.send(false);
    }
640
    /// Test-only: returns the server's current id.
    #[cfg(test)]
    pub fn id(&self) -> ServerId {
        *self.id.lock()
    }
645
646 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
647 where
648 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
649 Fut: 'static + Send + Future<Output = Result<()>>,
650 M: EnvelopedMessage,
651 {
652 let prev_handler = self.handlers.insert(
653 TypeId::of::<M>(),
654 Box::new(move |envelope, session| {
655 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
656 let received_at = envelope.received_at;
657 tracing::info!("message received");
658 let start_time = Instant::now();
659 let future = (handler)(*envelope, session);
660 async move {
661 let result = future.await;
662 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
663 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
664 let queue_duration_ms = total_duration_ms - processing_duration_ms;
665 let payload_type = M::NAME;
666
667 match result {
668 Err(error) => {
669 tracing::error!(
670 ?error,
671 total_duration_ms,
672 processing_duration_ms,
673 queue_duration_ms,
674 payload_type,
675 "error handling message"
676 )
677 }
678 Ok(()) => tracing::info!(
679 total_duration_ms,
680 processing_duration_ms,
681 queue_duration_ms,
682 "finished handling message"
683 ),
684 }
685 }
686 .boxed()
687 }),
688 );
689 if prev_handler.is_some() {
690 panic!("registered a handler for the same message twice");
691 }
692 self
693 }
694
695 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
696 where
697 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
698 Fut: 'static + Send + Future<Output = Result<()>>,
699 M: EnvelopedMessage,
700 {
701 self.add_handler(move |envelope, session| handler(envelope.payload, session));
702 self
703 }
704
705 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
706 where
707 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
708 Fut: Send + Future<Output = Result<()>>,
709 M: RequestMessage,
710 {
711 let handler = Arc::new(handler);
712 self.add_handler(move |envelope, session| {
713 let receipt = envelope.receipt();
714 let handler = handler.clone();
715 async move {
716 let peer = session.peer.clone();
717 let responded = Arc::new(AtomicBool::default());
718 let response = Response {
719 peer: peer.clone(),
720 responded: responded.clone(),
721 receipt,
722 };
723 match (handler)(envelope.payload, response, session).await {
724 Ok(()) => {
725 if responded.load(std::sync::atomic::Ordering::SeqCst) {
726 Ok(())
727 } else {
728 Err(anyhow!("handler did not send a response"))?
729 }
730 }
731 Err(error) => {
732 let proto_err = match &error {
733 Error::Internal(err) => err.to_proto(),
734 _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
735 };
736 peer.respond_with_error(receipt, proto_err)?;
737 Err(error)
738 }
739 }
740 }
741 })
742 }
743
    /// Returns a future that drives a single client connection to completion.
    ///
    /// The future registers the connection with the peer, builds the [`Session`], sends the
    /// initial client update and then dispatches incoming messages to the registered
    /// handlers until the I/O future finishes or the incoming stream ends, after which
    /// `connection_lost` is awaited. If the server is torn down while the loop is running,
    /// the future returns immediately without calling `connection_lost`.
    ///
    /// `send_connection_id`, if provided, receives the connection id once the hello message
    /// has been sent. `connection_guard`, if provided, is dropped after the initial client
    /// update has been sent successfully.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
        connection_guard: Option<ConnectionGuard>,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        // Fields are declared empty up front so they can be recorded later.
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            geoip_country_code=field::Empty
        );
        principal.update_span(&span);
        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        let mut teardown = self.teardown.subscribe();
        async move {
            // Refuse new connections while the teardown signal is set.
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }

            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            // The Supermaven client only exists when an admin API key is configured.
            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
                    supermaven_admin_api_key.to_string(),
                    http_client.clone(),
                )));

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this.send_initial_client_update(connection_id, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }
            // The connection slot is released once the initial update has gone out.
            drop(connection_guard);

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // A permit is acquired before each message is read and released when its handler
            // finishes, which bounds the number of in-flight handlers for this connection.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // `select_biased!` polls the branches in the order written.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                            );
                            principal.update_span(&span);
                            // The span is entered only while the handler future is created; the
                            // future itself is instrumented with the span below.
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background messages run detached on the executor; foreground
                                // messages are driven by this loop.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Pending foreground handlers are dropped before signing out.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
894
895 async fn send_initial_client_update(
896 &self,
897 connection_id: ConnectionId,
898 zed_version: ZedVersion,
899 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
900 session: &Session,
901 ) -> Result<()> {
902 self.peer.send(
903 connection_id,
904 proto::Hello {
905 peer_id: Some(connection_id.into()),
906 },
907 )?;
908 tracing::info!("sent hello message");
909 if let Some(send_connection_id) = send_connection_id.take() {
910 let _ = send_connection_id.send(connection_id);
911 }
912
913 match &session.principal {
914 Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
915 if !user.connected_once {
916 self.peer.send(connection_id, proto::ShowContacts {})?;
917 self.app_state
918 .db
919 .set_user_connected_once(user.id, true)
920 .await?;
921 }
922
923 update_user_plan(session).await?;
924
925 let contacts = self.app_state.db.get_contacts(user.id).await?;
926
927 {
928 let mut pool = self.connection_pool.lock();
929 pool.add_connection(connection_id, user.id, user.admin, zed_version);
930 self.peer.send(
931 connection_id,
932 build_initial_contacts_update(contacts, &pool),
933 )?;
934 }
935
936 if should_auto_subscribe_to_channels(zed_version) {
937 subscribe_user_to_channels(user.id, session).await?;
938 }
939
940 if let Some(incoming_call) =
941 self.app_state.db.incoming_call_for_user(user.id).await?
942 {
943 self.peer.send(connection_id, incoming_call)?;
944 }
945
946 update_user_contacts(user.id, session).await?;
947 }
948 }
949
950 Ok(())
951 }
952
953 pub async fn invite_code_redeemed(
954 self: &Arc<Self>,
955 inviter_id: UserId,
956 invitee_id: UserId,
957 ) -> Result<()> {
958 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
959 if let Some(code) = &user.invite_code {
960 let pool = self.connection_pool.lock();
961 let invitee_contact = contact_for_user(invitee_id, false, &pool);
962 for connection_id in pool.user_connection_ids(inviter_id) {
963 self.peer.send(
964 connection_id,
965 proto::UpdateContacts {
966 contacts: vec![invitee_contact.clone()],
967 ..Default::default()
968 },
969 )?;
970 self.peer.send(
971 connection_id,
972 proto::UpdateInviteInfo {
973 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
974 count: user.invite_count as u32,
975 },
976 )?;
977 }
978 }
979 }
980 Ok(())
981 }
982
983 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
984 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
985 if let Some(invite_code) = &user.invite_code {
986 let pool = self.connection_pool.lock();
987 for connection_id in pool.user_connection_ids(user_id) {
988 self.peer.send(
989 connection_id,
990 proto::UpdateInviteInfo {
991 url: format!(
992 "{}{}",
993 self.app_state.config.invite_link_prefix, invite_code
994 ),
995 count: user.invite_count as u32,
996 },
997 )?;
998 }
999 }
1000 }
1001 Ok(())
1002 }
1003
1004 pub async fn update_plan_for_user(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1005 let user = self
1006 .app_state
1007 .db
1008 .get_user_by_id(user_id)
1009 .await?
1010 .context("user not found")?;
1011
1012 let update_user_plan = make_update_user_plan_message(
1013 &user,
1014 user.admin,
1015 &self.app_state.db,
1016 self.app_state.llm_db.clone(),
1017 )
1018 .await?;
1019
1020 let pool = self.connection_pool.lock();
1021 for connection_id in pool.user_connection_ids(user_id) {
1022 self.peer
1023 .send(connection_id, update_user_plan.clone())
1024 .trace_err();
1025 }
1026
1027 Ok(())
1028 }
1029
1030 pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
1031 let pool = self.connection_pool.lock();
1032 for connection_id in pool.user_connection_ids(user_id) {
1033 self.peer
1034 .send(connection_id, proto::RefreshLlmToken {})
1035 .trace_err();
1036 }
1037 }
1038
1039 pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot {
1040 ServerSnapshot {
1041 connection_pool: ConnectionPoolGuard {
1042 guard: self.connection_pool.lock(),
1043 _not_send: PhantomData,
1044 },
1045 peer: &self.peer,
1046 }
1047 }
1048}
1049
1050impl Deref for ConnectionPoolGuard<'_> {
1051 type Target = ConnectionPool;
1052
1053 fn deref(&self) -> &Self::Target {
1054 &self.guard
1055 }
1056}
1057
1058impl DerefMut for ConnectionPoolGuard<'_> {
1059 fn deref_mut(&mut self) -> &mut Self::Target {
1060 &mut self.guard
1061 }
1062}
1063
/// In test builds, runs `check_invariants` whenever a guard is released; in
/// other builds the body compiles to nothing.
impl Drop for ConnectionPoolGuard<'_> {
    fn drop(&mut self) {
        // NOTE(review): presumably resolves to `ConnectionPool::check_invariants`
        // through `Deref` — confirm.
        #[cfg(test)]
        self.check_invariants();
    }
}
1070
1071fn broadcast<F>(
1072 sender_id: Option<ConnectionId>,
1073 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1074 mut f: F,
1075) where
1076 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1077{
1078 for receiver_id in receiver_ids {
1079 if Some(receiver_id) != sender_id {
1080 if let Err(error) = f(receiver_id) {
1081 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1082 }
1083 }
1084 }
1085}
1086
1087pub struct ProtocolVersion(u32);
1088
1089impl Header for ProtocolVersion {
1090 fn name() -> &'static HeaderName {
1091 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1092 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1093 }
1094
1095 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1096 where
1097 Self: Sized,
1098 I: Iterator<Item = &'i axum::http::HeaderValue>,
1099 {
1100 let version = values
1101 .next()
1102 .ok_or_else(axum::headers::Error::invalid)?
1103 .to_str()
1104 .map_err(|_| axum::headers::Error::invalid())?
1105 .parse()
1106 .map_err(|_| axum::headers::Error::invalid())?;
1107 Ok(Self(version))
1108 }
1109
1110 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1111 values.extend([self.0.to_string().parse().unwrap()]);
1112 }
1113}
1114
1115pub struct AppVersionHeader(SemanticVersion);
1116impl Header for AppVersionHeader {
1117 fn name() -> &'static HeaderName {
1118 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1119 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1120 }
1121
1122 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1123 where
1124 Self: Sized,
1125 I: Iterator<Item = &'i axum::http::HeaderValue>,
1126 {
1127 let version = values
1128 .next()
1129 .ok_or_else(axum::headers::Error::invalid)?
1130 .to_str()
1131 .map_err(|_| axum::headers::Error::invalid())?
1132 .parse()
1133 .map_err(|_| axum::headers::Error::invalid())?;
1134 Ok(Self(version))
1135 }
1136
1137 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1138 values.extend([self.0.to_string().parse().unwrap()]);
1139 }
1140}
1141
1142pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1143 Router::new()
1144 .route("/rpc", get(handle_websocket_request))
1145 .layer(
1146 ServiceBuilder::new()
1147 .layer(Extension(server.app_state.clone()))
1148 .layer(middleware::from_fn(auth::validate_header)),
1149 )
1150 .route("/metrics", get(handle_metrics))
1151 .layer(Extension(server))
1152}
1153
1154pub async fn handle_websocket_request(
1155 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1156 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1157 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1158 Extension(server): Extension<Arc<Server>>,
1159 Extension(principal): Extension<Principal>,
1160 country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1161 system_id_header: Option<TypedHeader<SystemIdHeader>>,
1162 ws: WebSocketUpgrade,
1163) -> axum::response::Response {
1164 if protocol_version != rpc::PROTOCOL_VERSION {
1165 return (
1166 StatusCode::UPGRADE_REQUIRED,
1167 "client must be upgraded".to_string(),
1168 )
1169 .into_response();
1170 }
1171
1172 let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1173 return (
1174 StatusCode::UPGRADE_REQUIRED,
1175 "no version header found".to_string(),
1176 )
1177 .into_response();
1178 };
1179
1180 if !version.can_collaborate() {
1181 return (
1182 StatusCode::UPGRADE_REQUIRED,
1183 "client must be upgraded".to_string(),
1184 )
1185 .into_response();
1186 }
1187
1188 let socket_address = socket_address.to_string();
1189
1190 // Acquire connection guard before WebSocket upgrade
1191 let connection_guard = match ConnectionGuard::try_acquire() {
1192 Ok(guard) => guard,
1193 Err(()) => {
1194 return (
1195 StatusCode::SERVICE_UNAVAILABLE,
1196 "Too many concurrent connections",
1197 )
1198 .into_response();
1199 }
1200 };
1201
1202 ws.on_upgrade(move |socket| {
1203 let socket = socket
1204 .map_ok(to_tungstenite_message)
1205 .err_into()
1206 .with(|message| async move { to_axum_message(message) });
1207 let connection = Connection::new(Box::pin(socket));
1208 async move {
1209 server
1210 .handle_connection(
1211 connection,
1212 socket_address,
1213 principal,
1214 version,
1215 country_code_header.map(|header| header.to_string()),
1216 system_id_header.map(|header| header.to_string()),
1217 None,
1218 Executor::Production,
1219 Some(connection_guard),
1220 )
1221 .await;
1222 }
1223 })
1224}
1225
1226pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1227 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1228 let connections_metric = CONNECTIONS_METRIC
1229 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1230
1231 let connections = server
1232 .connection_pool
1233 .lock()
1234 .connections()
1235 .filter(|connection| !connection.admin)
1236 .count();
1237 connections_metric.set(connections as _);
1238
1239 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1240 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1241 register_int_gauge!(
1242 "shared_projects",
1243 "number of open projects with one or more guests"
1244 )
1245 .unwrap()
1246 });
1247
1248 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1249 shared_projects_metric.set(shared_projects as _);
1250
1251 let encoder = prometheus::TextEncoder::new();
1252 let metric_families = prometheus::gather();
1253 let encoded_metrics = encoder
1254 .encode_to_string(&metric_families)
1255 .map_err(|err| anyhow!("{err}"))?;
1256 Ok(encoded_metrics)
1257}
1258
/// Cleans up after a connection has gone away.
///
/// The peer is disconnected, the connection is removed from the pool, and the
/// loss is reported to the database immediately. Room, channel-buffer and
/// pending-call state is only released if `RECONNECT_TIMEOUT` elapses before
/// `teardown` changes; when `teardown` fires first, nothing further is done.
///
/// Returns an error if removing the connection from the pool fails or if the
/// final contacts update fails; the other cleanup steps only log their errors.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // Best effort: a database failure here is logged, not propagated.
    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        // Listed first, so the timeout branch is preferred if both are ready.
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline pending calls when the user has no connection left.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        // Teardown was signalled before the timeout: skip the cleanup above.
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1305
1306/// Acknowledges a ping from a client, used to keep the connection alive.
1307async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1308 response.send(proto::Ack {})?;
1309 Ok(())
1310}
1311
1312/// Creates a new room for calling (outside of channels)
1313async fn create_room(
1314 _request: proto::CreateRoom,
1315 response: Response<proto::CreateRoom>,
1316 session: Session,
1317) -> Result<()> {
1318 let livekit_room = nanoid::nanoid!(30);
1319
1320 let live_kit_connection_info = util::maybe!(async {
1321 let live_kit = session.app_state.livekit_client.as_ref();
1322 let live_kit = live_kit?;
1323 let user_id = session.user_id().to_string();
1324
1325 let token = live_kit
1326 .room_token(&livekit_room, &user_id.to_string())
1327 .trace_err()?;
1328
1329 Some(proto::LiveKitConnectionInfo {
1330 server_url: live_kit.url().into(),
1331 token,
1332 can_publish: true,
1333 })
1334 })
1335 .await;
1336
1337 let room = session
1338 .db()
1339 .await
1340 .create_room(session.user_id(), session.connection_id, &livekit_room)
1341 .await?;
1342
1343 response.send(proto::CreateRoomResponse {
1344 room: Some(room.clone()),
1345 live_kit_connection_info,
1346 })?;
1347
1348 update_user_contacts(session.user_id(), &session).await?;
1349 Ok(())
1350}
1351
1352/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1353async fn join_room(
1354 request: proto::JoinRoom,
1355 response: Response<proto::JoinRoom>,
1356 session: Session,
1357) -> Result<()> {
1358 let room_id = RoomId::from_proto(request.id);
1359
1360 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1361
1362 if let Some(channel_id) = channel_id {
1363 return join_channel_internal(channel_id, Box::new(response), session).await;
1364 }
1365
1366 let joined_room = {
1367 let room = session
1368 .db()
1369 .await
1370 .join_room(room_id, session.user_id(), session.connection_id)
1371 .await?;
1372 room_updated(&room.room, &session.peer);
1373 room.into_inner()
1374 };
1375
1376 for connection_id in session
1377 .connection_pool()
1378 .await
1379 .user_connection_ids(session.user_id())
1380 {
1381 session
1382 .peer
1383 .send(
1384 connection_id,
1385 proto::CallCanceled {
1386 room_id: room_id.to_proto(),
1387 },
1388 )
1389 .trace_err();
1390 }
1391
1392 let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1393 {
1394 live_kit
1395 .room_token(
1396 &joined_room.room.livekit_room,
1397 &session.user_id().to_string(),
1398 )
1399 .trace_err()
1400 .map(|token| proto::LiveKitConnectionInfo {
1401 server_url: live_kit.url().into(),
1402 token,
1403 can_publish: true,
1404 })
1405 } else {
1406 None
1407 };
1408
1409 response.send(proto::JoinRoomResponse {
1410 room: Some(joined_room.room),
1411 channel_id: None,
1412 live_kit_connection_info,
1413 })?;
1414
1415 update_user_contacts(session.user_id(), &session).await?;
1416 Ok(())
1417}
1418
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Responds with the room plus the reshared and rejoined projects, broadcasts
/// the room update, tells collaborators of reshared projects about the new
/// peer id and the project's worktrees, streams rejoined-project state back to
/// this connection, and finally refreshes channel participants (if the room
/// has a channel) and the user's contacts.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    // Assigned inside the block below so the value returned by `rejoin_room`
    // is consumed before the later awaits.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell every collaborator that this peer's connection id changed.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the project's worktree list to the collaborators,
            // skipping this connection.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1510
/// Announces a reconnected peer to its projects' collaborators and replays
/// project state to the reconnected connection.
///
/// First, every collaborator of every rejoined project is sent an
/// `UpdateProjectCollaborator` mapping the old connection id to the session's
/// current one (failures are only logged). Then, for each project, the
/// worktree entries, diagnostic summaries, worktree settings, updated
/// repositories and removed repositories are sent to `session.connection_id`;
/// the first failed send there is returned as the error.
///
/// The `worktrees`, `updated_repositories` and `removed_repositories` of each
/// project are taken out of `rejoined_projects` and left empty.
fn notify_rejoined_projects(
    rejoined_projects: &mut Vec<RejoinedProject>,
    session: &Session,
) -> Result<()> {
    for project in rejoined_projects.iter() {
        for collaborator in &project.collaborators {
            session
                .peer
                .send(
                    collaborator.connection_id,
                    proto::UpdateProjectCollaborator {
                        project_id: project.id.to_proto(),
                        old_peer_id: Some(project.old_connection_id.into()),
                        new_peer_id: Some(session.connection_id.into()),
                    },
                )
                .trace_err();
        }
    }

    for project in rejoined_projects {
        // `mem::take` moves the worktrees out so their fields can be consumed
        // by value below.
        for worktree in mem::take(&mut project.worktrees) {
            // Stream this worktree's entries.
            let message = proto::UpdateWorktree {
                project_id: project.id.to_proto(),
                worktree_id: worktree.id,
                abs_path: worktree.abs_path.clone(),
                root_name: worktree.root_name,
                updated_entries: worktree.updated_entries,
                removed_entries: worktree.removed_entries,
                scan_id: worktree.scan_id,
                is_last_update: worktree.completed_scan_id == worktree.scan_id,
                updated_repositories: worktree.updated_repositories,
                removed_repositories: worktree.removed_repositories,
            };
            // The update is split into several messages before sending.
            for update in proto::split_worktree_update(message) {
                session.peer.send(session.connection_id, update)?;
            }

            // Stream this worktree's diagnostics.
            for summary in worktree.diagnostic_summaries {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateDiagnosticSummary {
                        project_id: project.id.to_proto(),
                        worktree_id: worktree.id,
                        summary: Some(summary),
                    },
                )?;
            }

            // Stream this worktree's settings files.
            for settings_file in worktree.settings_files {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateWorktreeSettings {
                        project_id: project.id.to_proto(),
                        worktree_id: worktree.id,
                        path: settings_file.path,
                        content: Some(settings_file.content),
                        kind: Some(settings_file.kind.to_proto().into()),
                    },
                )?;
            }
        }

        for repository in mem::take(&mut project.updated_repositories) {
            for update in split_repository_update(repository) {
                session.peer.send(session.connection_id, update)?;
            }
        }

        for id in mem::take(&mut project.removed_repositories) {
            session.peer.send(
                session.connection_id,
                proto::RemoveRepository {
                    project_id: project.id.to_proto(),
                    id,
                },
            )?;
        }
    }

    Ok(())
}
1595
1596/// leave room disconnects from the room.
1597async fn leave_room(
1598 _: proto::LeaveRoom,
1599 response: Response<proto::LeaveRoom>,
1600 session: Session,
1601) -> Result<()> {
1602 leave_room_for_session(&session, session.connection_id).await?;
1603 response.send(proto::Ack {})?;
1604 Ok(())
1605}
1606
1607/// Updates the permissions of someone else in the room.
1608async fn set_room_participant_role(
1609 request: proto::SetRoomParticipantRole,
1610 response: Response<proto::SetRoomParticipantRole>,
1611 session: Session,
1612) -> Result<()> {
1613 let user_id = UserId::from_proto(request.user_id);
1614 let role = ChannelRole::from(request.role());
1615
1616 let (livekit_room, can_publish) = {
1617 let room = session
1618 .db()
1619 .await
1620 .set_room_participant_role(
1621 session.user_id(),
1622 RoomId::from_proto(request.room_id),
1623 user_id,
1624 role,
1625 )
1626 .await?;
1627
1628 let livekit_room = room.livekit_room.clone();
1629 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1630 room_updated(&room, &session.peer);
1631 (livekit_room, can_publish)
1632 };
1633
1634 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1635 live_kit
1636 .update_participant(
1637 livekit_room.clone(),
1638 request.user_id.to_string(),
1639 livekit_api::proto::ParticipantPermission {
1640 can_subscribe: true,
1641 can_publish,
1642 can_publish_data: can_publish,
1643 hidden: false,
1644 recorder: false,
1645 },
1646 )
1647 .await
1648 .trace_err();
1649 }
1650
1651 response.send(proto::Ack {})?;
1652 Ok(())
1653}
1654
/// Call someone else into the current room
///
/// The called user must be a contact of the caller. The call is recorded in
/// the database, the room update is broadcast, and the incoming-call request
/// is sent to every connection of the called user. The request is
/// acknowledged as soon as one of those connections answers successfully. If
/// none does, the call is marked as failed and an error is returned.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: Session,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(room, &session.peer);
        // Move the message out of the returned value so it outlives this block.
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring every connection of the called user concurrently.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                // The first successful reply is enough to acknowledge the call.
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    // No connection answered: record the failure and notify the room.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1723
1724/// Cancel an outgoing call.
1725async fn cancel_call(
1726 request: proto::CancelCall,
1727 response: Response<proto::CancelCall>,
1728 session: Session,
1729) -> Result<()> {
1730 let called_user_id = UserId::from_proto(request.called_user_id);
1731 let room_id = RoomId::from_proto(request.room_id);
1732 {
1733 let room = session
1734 .db()
1735 .await
1736 .cancel_call(room_id, session.connection_id, called_user_id)
1737 .await?;
1738 room_updated(&room, &session.peer);
1739 }
1740
1741 for connection_id in session
1742 .connection_pool()
1743 .await
1744 .user_connection_ids(called_user_id)
1745 {
1746 session
1747 .peer
1748 .send(
1749 connection_id,
1750 proto::CallCanceled {
1751 room_id: room_id.to_proto(),
1752 },
1753 )
1754 .trace_err();
1755 }
1756 response.send(proto::Ack {})?;
1757
1758 update_user_contacts(called_user_id, &session).await?;
1759 Ok(())
1760}
1761
1762/// Decline an incoming call.
1763async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1764 let room_id = RoomId::from_proto(message.room_id);
1765 {
1766 let room = session
1767 .db()
1768 .await
1769 .decline_call(Some(room_id), session.user_id())
1770 .await?
1771 .context("declining call")?;
1772 room_updated(&room, &session.peer);
1773 }
1774
1775 for connection_id in session
1776 .connection_pool()
1777 .await
1778 .user_connection_ids(session.user_id())
1779 {
1780 session
1781 .peer
1782 .send(
1783 connection_id,
1784 proto::CallCanceled {
1785 room_id: room_id.to_proto(),
1786 },
1787 )
1788 .trace_err();
1789 }
1790 update_user_contacts(session.user_id(), &session).await?;
1791 Ok(())
1792}
1793
1794/// Updates other participants in the room with your current location.
1795async fn update_participant_location(
1796 request: proto::UpdateParticipantLocation,
1797 response: Response<proto::UpdateParticipantLocation>,
1798 session: Session,
1799) -> Result<()> {
1800 let room_id = RoomId::from_proto(request.room_id);
1801 let location = request.location.context("invalid location")?;
1802
1803 let db = session.db().await;
1804 let room = db
1805 .update_room_participant_location(room_id, session.connection_id, location)
1806 .await?;
1807
1808 room_updated(&room, &session.peer);
1809 response.send(proto::Ack {})?;
1810 Ok(())
1811}
1812
/// Share a project into the room.
///
/// Registers the project with the database, replies with the new project id,
/// and then broadcasts the room update.
async fn share_project(
    request: proto::ShareProject,
    response: Response<proto::ShareProject>,
    session: Session,
) -> Result<()> {
    // The value returned by `share_project` is borrowed for the rest of the
    // function, so it stays alive until the room update has been sent.
    let (project_id, room) = &*session
        .db()
        .await
        .share_project(
            RoomId::from_proto(request.room_id),
            session.connection_id,
            &request.worktrees,
            request.is_ssh_project,
        )
        .await?;
    response.send(proto::ShareProjectResponse {
        project_id: project_id.to_proto(),
    })?;
    room_updated(room, &session.peer);

    Ok(())
}
1836
1837/// Unshare a project from the room.
1838async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1839 let project_id = ProjectId::from_proto(message.project_id);
1840 unshare_project_internal(project_id, session.connection_id, &session).await
1841}
1842
/// Unshares `project_id` on behalf of `connection_id`.
///
/// Sends `UnshareProject` to the guest connections returned by the database
/// (skipping `connection_id`), broadcasts the room update when a room is
/// returned, and deletes the project afterwards if the database flagged it
/// for deletion.
async fn unshare_project_internal(
    project_id: ProjectId,
    connection_id: ConnectionId,
    session: &Session,
) -> Result<()> {
    // Scoped so the value returned by `unshare_project` is dropped before the
    // project is deleted below.
    let delete = {
        let room_guard = session
            .db()
            .await
            .unshare_project(project_id, connection_id)
            .await?;

        let (delete, room, guest_connection_ids) = &*room_guard;

        let message = proto::UnshareProject {
            project_id: project_id.to_proto(),
        };

        broadcast(
            Some(connection_id),
            guest_connection_ids.iter().copied(),
            |conn_id| session.peer.send(conn_id, message.clone()),
        );
        if let Some(room) = room {
            room_updated(room, &session.peer);
        }

        *delete
    };

    if delete {
        let db = session.db().await;
        db.delete_project(project_id).await?;
    }

    Ok(())
}
1880
1881/// Join someone elses shared project.
1882async fn join_project(
1883 request: proto::JoinProject,
1884 response: Response<proto::JoinProject>,
1885 session: Session,
1886) -> Result<()> {
1887 let project_id = ProjectId::from_proto(request.project_id);
1888
1889 tracing::info!(%project_id, "join project");
1890
1891 let db = session.db().await;
1892 let (project, replica_id) = &mut *db
1893 .join_project(
1894 project_id,
1895 session.connection_id,
1896 session.user_id(),
1897 request.committer_name.clone(),
1898 request.committer_email.clone(),
1899 )
1900 .await?;
1901 drop(db);
1902 tracing::info!(%project_id, "join remote project");
1903 let collaborators = project
1904 .collaborators
1905 .iter()
1906 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1907 .map(|collaborator| collaborator.to_proto())
1908 .collect::<Vec<_>>();
1909 let project_id = project.id;
1910 let guest_user_id = session.user_id();
1911
1912 let worktrees = project
1913 .worktrees
1914 .iter()
1915 .map(|(id, worktree)| proto::WorktreeMetadata {
1916 id: *id,
1917 root_name: worktree.root_name.clone(),
1918 visible: worktree.visible,
1919 abs_path: worktree.abs_path.clone(),
1920 })
1921 .collect::<Vec<_>>();
1922
1923 let add_project_collaborator = proto::AddProjectCollaborator {
1924 project_id: project_id.to_proto(),
1925 collaborator: Some(proto::Collaborator {
1926 peer_id: Some(session.connection_id.into()),
1927 replica_id: replica_id.0 as u32,
1928 user_id: guest_user_id.to_proto(),
1929 is_host: false,
1930 committer_name: request.committer_name.clone(),
1931 committer_email: request.committer_email.clone(),
1932 }),
1933 };
1934
1935 for collaborator in &collaborators {
1936 session
1937 .peer
1938 .send(
1939 collaborator.peer_id.unwrap().into(),
1940 add_project_collaborator.clone(),
1941 )
1942 .trace_err();
1943 }
1944
1945 // First, we send the metadata associated with each worktree.
1946 response.send(proto::JoinProjectResponse {
1947 project_id: project.id.0 as u64,
1948 worktrees: worktrees.clone(),
1949 replica_id: replica_id.0 as u32,
1950 collaborators: collaborators.clone(),
1951 language_servers: project.language_servers.clone(),
1952 role: project.role.into(),
1953 })?;
1954
1955 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1956 // Stream this worktree's entries.
1957 let message = proto::UpdateWorktree {
1958 project_id: project_id.to_proto(),
1959 worktree_id,
1960 abs_path: worktree.abs_path.clone(),
1961 root_name: worktree.root_name,
1962 updated_entries: worktree.entries,
1963 removed_entries: Default::default(),
1964 scan_id: worktree.scan_id,
1965 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1966 updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1967 removed_repositories: Default::default(),
1968 };
1969 for update in proto::split_worktree_update(message) {
1970 session.peer.send(session.connection_id, update.clone())?;
1971 }
1972
1973 // Stream this worktree's diagnostics.
1974 for summary in worktree.diagnostic_summaries {
1975 session.peer.send(
1976 session.connection_id,
1977 proto::UpdateDiagnosticSummary {
1978 project_id: project_id.to_proto(),
1979 worktree_id: worktree.id,
1980 summary: Some(summary),
1981 },
1982 )?;
1983 }
1984
1985 for settings_file in worktree.settings_files {
1986 session.peer.send(
1987 session.connection_id,
1988 proto::UpdateWorktreeSettings {
1989 project_id: project_id.to_proto(),
1990 worktree_id: worktree.id,
1991 path: settings_file.path,
1992 content: Some(settings_file.content),
1993 kind: Some(settings_file.kind.to_proto() as i32),
1994 },
1995 )?;
1996 }
1997 }
1998
1999 for repository in mem::take(&mut project.repositories) {
2000 for update in split_repository_update(repository) {
2001 session.peer.send(session.connection_id, update)?;
2002 }
2003 }
2004
2005 for language_server in &project.language_servers {
2006 session.peer.send(
2007 session.connection_id,
2008 proto::UpdateLanguageServer {
2009 project_id: project_id.to_proto(),
2010 language_server_id: language_server.id,
2011 variant: Some(
2012 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2013 proto::LspDiskBasedDiagnosticsUpdated {},
2014 ),
2015 ),
2016 },
2017 )?;
2018 }
2019
2020 Ok(())
2021}
2022
/// Leave someone else's shared project.
///
/// Removes the requesting connection from the project in the database, hands
/// the result to `project_left`, and broadcasts the room update when a room
/// is returned.
async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
    let sender_id = session.connection_id;
    let project_id = ProjectId::from_proto(request.project_id);
    let db = session.db().await;

    // The value returned by `leave_project` is borrowed until the end of the
    // function.
    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
    tracing::info!(
        %project_id,
        "leave project"
    );

    project_left(project, &session);
    if let Some(room) = room {
        room_updated(room, &session.peer);
    }

    Ok(())
}
2042
/// Updates other participants with changes to the project
///
/// Stores the new worktree list, forwards the original request to the guest
/// connections returned by the database (skipping the sender; forwarding
/// failures are logged by `broadcast`), broadcasts the room update when a
/// room is returned, and acknowledges the request.
async fn update_project(
    request: proto::UpdateProject,
    response: Response<proto::UpdateProject>,
    session: Session,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    // The value returned by `update_project` is borrowed until the end of the
    // function.
    let (room, guest_connection_ids) = &*session
        .db()
        .await
        .update_project(project_id, session.connection_id, &request.worktrees)
        .await?;
    broadcast(
        Some(session.connection_id),
        guest_connection_ids.iter().copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())
        },
    );
    if let Some(room) = room {
        room_updated(room, &session.peer);
    }
    response.send(proto::Ack {})?;

    Ok(())
}
2071
2072/// Updates other participants with changes to the worktree
2073async fn update_worktree(
2074 request: proto::UpdateWorktree,
2075 response: Response<proto::UpdateWorktree>,
2076 session: Session,
2077) -> Result<()> {
2078 let guest_connection_ids = session
2079 .db()
2080 .await
2081 .update_worktree(&request, session.connection_id)
2082 .await?;
2083
2084 broadcast(
2085 Some(session.connection_id),
2086 guest_connection_ids.iter().copied(),
2087 |connection_id| {
2088 session
2089 .peer
2090 .forward_send(session.connection_id, connection_id, request.clone())
2091 },
2092 );
2093 response.send(proto::Ack {})?;
2094 Ok(())
2095}
2096
2097async fn update_repository(
2098 request: proto::UpdateRepository,
2099 response: Response<proto::UpdateRepository>,
2100 session: Session,
2101) -> Result<()> {
2102 let guest_connection_ids = session
2103 .db()
2104 .await
2105 .update_repository(&request, session.connection_id)
2106 .await?;
2107
2108 broadcast(
2109 Some(session.connection_id),
2110 guest_connection_ids.iter().copied(),
2111 |connection_id| {
2112 session
2113 .peer
2114 .forward_send(session.connection_id, connection_id, request.clone())
2115 },
2116 );
2117 response.send(proto::Ack {})?;
2118 Ok(())
2119}
2120
2121async fn remove_repository(
2122 request: proto::RemoveRepository,
2123 response: Response<proto::RemoveRepository>,
2124 session: Session,
2125) -> Result<()> {
2126 let guest_connection_ids = session
2127 .db()
2128 .await
2129 .remove_repository(&request, session.connection_id)
2130 .await?;
2131
2132 broadcast(
2133 Some(session.connection_id),
2134 guest_connection_ids.iter().copied(),
2135 |connection_id| {
2136 session
2137 .peer
2138 .forward_send(session.connection_id, connection_id, request.clone())
2139 },
2140 );
2141 response.send(proto::Ack {})?;
2142 Ok(())
2143}
2144
2145/// Updates other participants with changes to the diagnostics
2146async fn update_diagnostic_summary(
2147 message: proto::UpdateDiagnosticSummary,
2148 session: Session,
2149) -> Result<()> {
2150 let guest_connection_ids = session
2151 .db()
2152 .await
2153 .update_diagnostic_summary(&message, session.connection_id)
2154 .await?;
2155
2156 broadcast(
2157 Some(session.connection_id),
2158 guest_connection_ids.iter().copied(),
2159 |connection_id| {
2160 session
2161 .peer
2162 .forward_send(session.connection_id, connection_id, message.clone())
2163 },
2164 );
2165
2166 Ok(())
2167}
2168
2169/// Updates other participants with changes to the worktree settings
2170async fn update_worktree_settings(
2171 message: proto::UpdateWorktreeSettings,
2172 session: Session,
2173) -> Result<()> {
2174 let guest_connection_ids = session
2175 .db()
2176 .await
2177 .update_worktree_settings(&message, session.connection_id)
2178 .await?;
2179
2180 broadcast(
2181 Some(session.connection_id),
2182 guest_connection_ids.iter().copied(),
2183 |connection_id| {
2184 session
2185 .peer
2186 .forward_send(session.connection_id, connection_id, message.clone())
2187 },
2188 );
2189
2190 Ok(())
2191}
2192
2193/// Notify other participants that a language server has started.
2194async fn start_language_server(
2195 request: proto::StartLanguageServer,
2196 session: Session,
2197) -> Result<()> {
2198 let guest_connection_ids = session
2199 .db()
2200 .await
2201 .start_language_server(&request, session.connection_id)
2202 .await?;
2203
2204 broadcast(
2205 Some(session.connection_id),
2206 guest_connection_ids.iter().copied(),
2207 |connection_id| {
2208 session
2209 .peer
2210 .forward_send(session.connection_id, connection_id, request.clone())
2211 },
2212 );
2213 Ok(())
2214}
2215
2216/// Notify other participants that a language server has changed.
2217async fn update_language_server(
2218 request: proto::UpdateLanguageServer,
2219 session: Session,
2220) -> Result<()> {
2221 let project_id = ProjectId::from_proto(request.project_id);
2222 let project_connection_ids = session
2223 .db()
2224 .await
2225 .project_connection_ids(project_id, session.connection_id, true)
2226 .await?;
2227 broadcast(
2228 Some(session.connection_id),
2229 project_connection_ids.iter().copied(),
2230 |connection_id| {
2231 session
2232 .peer
2233 .forward_send(session.connection_id, connection_id, request.clone())
2234 },
2235 );
2236 Ok(())
2237}
2238
2239/// forward a project request to the host. These requests should be read only
2240/// as guests are allowed to send them.
2241async fn forward_read_only_project_request<T>(
2242 request: T,
2243 response: Response<T>,
2244 session: Session,
2245) -> Result<()>
2246where
2247 T: EntityMessage + RequestMessage,
2248{
2249 let project_id = ProjectId::from_proto(request.remote_entity_id());
2250 let host_connection_id = session
2251 .db()
2252 .await
2253 .host_for_read_only_project_request(project_id, session.connection_id)
2254 .await?;
2255 let payload = session
2256 .peer
2257 .forward_request(session.connection_id, host_connection_id, request)
2258 .await?;
2259 response.send(payload)?;
2260 Ok(())
2261}
2262
2263async fn forward_find_search_candidates_request(
2264 request: proto::FindSearchCandidates,
2265 response: Response<proto::FindSearchCandidates>,
2266 session: Session,
2267) -> Result<()> {
2268 let project_id = ProjectId::from_proto(request.remote_entity_id());
2269 let host_connection_id = session
2270 .db()
2271 .await
2272 .host_for_read_only_project_request(project_id, session.connection_id)
2273 .await?;
2274 let payload = session
2275 .peer
2276 .forward_request(session.connection_id, host_connection_id, request)
2277 .await?;
2278 response.send(payload)?;
2279 Ok(())
2280}
2281
2282/// forward a project request to the host. These requests are disallowed
2283/// for guests.
2284async fn forward_mutating_project_request<T>(
2285 request: T,
2286 response: Response<T>,
2287 session: Session,
2288) -> Result<()>
2289where
2290 T: EntityMessage + RequestMessage,
2291{
2292 let project_id = ProjectId::from_proto(request.remote_entity_id());
2293
2294 let host_connection_id = session
2295 .db()
2296 .await
2297 .host_for_mutating_project_request(project_id, session.connection_id)
2298 .await?;
2299 let payload = session
2300 .peer
2301 .forward_request(session.connection_id, host_connection_id, request)
2302 .await?;
2303 response.send(payload)?;
2304 Ok(())
2305}
2306
2307/// Notify other participants that a new buffer has been created
2308async fn create_buffer_for_peer(
2309 request: proto::CreateBufferForPeer,
2310 session: Session,
2311) -> Result<()> {
2312 session
2313 .db()
2314 .await
2315 .check_user_is_project_host(
2316 ProjectId::from_proto(request.project_id),
2317 session.connection_id,
2318 )
2319 .await?;
2320 let peer_id = request.peer_id.context("invalid peer id")?;
2321 session
2322 .peer
2323 .forward_send(session.connection_id, peer_id.into(), request)?;
2324 Ok(())
2325}
2326
2327/// Notify other participants that a buffer has been updated. This is
2328/// allowed for guests as long as the update is limited to selections.
2329async fn update_buffer(
2330 request: proto::UpdateBuffer,
2331 response: Response<proto::UpdateBuffer>,
2332 session: Session,
2333) -> Result<()> {
2334 let project_id = ProjectId::from_proto(request.project_id);
2335 let mut capability = Capability::ReadOnly;
2336
2337 for op in request.operations.iter() {
2338 match op.variant {
2339 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2340 Some(_) => capability = Capability::ReadWrite,
2341 }
2342 }
2343
2344 let host = {
2345 let guard = session
2346 .db()
2347 .await
2348 .connections_for_buffer_update(project_id, session.connection_id, capability)
2349 .await?;
2350
2351 let (host, guests) = &*guard;
2352
2353 broadcast(
2354 Some(session.connection_id),
2355 guests.clone(),
2356 |connection_id| {
2357 session
2358 .peer
2359 .forward_send(session.connection_id, connection_id, request.clone())
2360 },
2361 );
2362
2363 *host
2364 };
2365
2366 if host != session.connection_id {
2367 session
2368 .peer
2369 .forward_request(session.connection_id, host, request.clone())
2370 .await?;
2371 }
2372
2373 response.send(proto::Ack {})?;
2374 Ok(())
2375}
2376
2377async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2378 let project_id = ProjectId::from_proto(message.project_id);
2379
2380 let operation = message.operation.as_ref().context("invalid operation")?;
2381 let capability = match operation.variant.as_ref() {
2382 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2383 if let Some(buffer_op) = buffer_op.operation.as_ref() {
2384 match buffer_op.variant {
2385 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2386 Capability::ReadOnly
2387 }
2388 _ => Capability::ReadWrite,
2389 }
2390 } else {
2391 Capability::ReadWrite
2392 }
2393 }
2394 Some(_) => Capability::ReadWrite,
2395 None => Capability::ReadOnly,
2396 };
2397
2398 let guard = session
2399 .db()
2400 .await
2401 .connections_for_buffer_update(project_id, session.connection_id, capability)
2402 .await?;
2403
2404 let (host, guests) = &*guard;
2405
2406 broadcast(
2407 Some(session.connection_id),
2408 guests.iter().chain([host]).copied(),
2409 |connection_id| {
2410 session
2411 .peer
2412 .forward_send(session.connection_id, connection_id, message.clone())
2413 },
2414 );
2415
2416 Ok(())
2417}
2418
2419/// Notify other participants that a project has been updated.
2420async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2421 request: T,
2422 session: Session,
2423) -> Result<()> {
2424 let project_id = ProjectId::from_proto(request.remote_entity_id());
2425 let project_connection_ids = session
2426 .db()
2427 .await
2428 .project_connection_ids(project_id, session.connection_id, false)
2429 .await?;
2430
2431 broadcast(
2432 Some(session.connection_id),
2433 project_connection_ids.iter().copied(),
2434 |connection_id| {
2435 session
2436 .peer
2437 .forward_send(session.connection_id, connection_id, request.clone())
2438 },
2439 );
2440 Ok(())
2441}
2442
2443/// Start following another user in a call.
2444async fn follow(
2445 request: proto::Follow,
2446 response: Response<proto::Follow>,
2447 session: Session,
2448) -> Result<()> {
2449 let room_id = RoomId::from_proto(request.room_id);
2450 let project_id = request.project_id.map(ProjectId::from_proto);
2451 let leader_id = request.leader_id.context("invalid leader id")?.into();
2452 let follower_id = session.connection_id;
2453
2454 session
2455 .db()
2456 .await
2457 .check_room_participants(room_id, leader_id, session.connection_id)
2458 .await?;
2459
2460 let response_payload = session
2461 .peer
2462 .forward_request(session.connection_id, leader_id, request)
2463 .await?;
2464 response.send(response_payload)?;
2465
2466 if let Some(project_id) = project_id {
2467 let room = session
2468 .db()
2469 .await
2470 .follow(room_id, project_id, leader_id, follower_id)
2471 .await?;
2472 room_updated(&room, &session.peer);
2473 }
2474
2475 Ok(())
2476}
2477
2478/// Stop following another user in a call.
2479async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2480 let room_id = RoomId::from_proto(request.room_id);
2481 let project_id = request.project_id.map(ProjectId::from_proto);
2482 let leader_id = request.leader_id.context("invalid leader id")?.into();
2483 let follower_id = session.connection_id;
2484
2485 session
2486 .db()
2487 .await
2488 .check_room_participants(room_id, leader_id, session.connection_id)
2489 .await?;
2490
2491 session
2492 .peer
2493 .forward_send(session.connection_id, leader_id, request)?;
2494
2495 if let Some(project_id) = project_id {
2496 let room = session
2497 .db()
2498 .await
2499 .unfollow(room_id, project_id, leader_id, follower_id)
2500 .await?;
2501 room_updated(&room, &session.peer);
2502 }
2503
2504 Ok(())
2505}
2506
2507/// Notify everyone following you of your current location.
2508async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2509 let room_id = RoomId::from_proto(request.room_id);
2510 let database = session.db.lock().await;
2511
2512 let connection_ids = if let Some(project_id) = request.project_id {
2513 let project_id = ProjectId::from_proto(project_id);
2514 database
2515 .project_connection_ids(project_id, session.connection_id, true)
2516 .await?
2517 } else {
2518 database
2519 .room_connection_ids(room_id, session.connection_id)
2520 .await?
2521 };
2522
2523 // For now, don't send view update messages back to that view's current leader.
2524 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2525 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2526 _ => None,
2527 });
2528
2529 for connection_id in connection_ids.iter().cloned() {
2530 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2531 session
2532 .peer
2533 .forward_send(session.connection_id, connection_id, request.clone())?;
2534 }
2535 }
2536 Ok(())
2537}
2538
2539/// Get public data about users.
2540async fn get_users(
2541 request: proto::GetUsers,
2542 response: Response<proto::GetUsers>,
2543 session: Session,
2544) -> Result<()> {
2545 let user_ids = request
2546 .user_ids
2547 .into_iter()
2548 .map(UserId::from_proto)
2549 .collect();
2550 let users = session
2551 .db()
2552 .await
2553 .get_users_by_ids(user_ids)
2554 .await?
2555 .into_iter()
2556 .map(|user| proto::User {
2557 id: user.id.to_proto(),
2558 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2559 github_login: user.github_login,
2560 name: user.name,
2561 })
2562 .collect();
2563 response.send(proto::UsersResponse { users })?;
2564 Ok(())
2565}
2566
2567/// Search for users (to invite) buy Github login
2568async fn fuzzy_search_users(
2569 request: proto::FuzzySearchUsers,
2570 response: Response<proto::FuzzySearchUsers>,
2571 session: Session,
2572) -> Result<()> {
2573 let query = request.query;
2574 let users = match query.len() {
2575 0 => vec![],
2576 1 | 2 => session
2577 .db()
2578 .await
2579 .get_user_by_github_login(&query)
2580 .await?
2581 .into_iter()
2582 .collect(),
2583 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2584 };
2585 let users = users
2586 .into_iter()
2587 .filter(|user| user.id != session.user_id())
2588 .map(|user| proto::User {
2589 id: user.id.to_proto(),
2590 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2591 github_login: user.github_login,
2592 name: user.name,
2593 })
2594 .collect();
2595 response.send(proto::UsersResponse { users })?;
2596 Ok(())
2597}
2598
2599/// Send a contact request to another user.
2600async fn request_contact(
2601 request: proto::RequestContact,
2602 response: Response<proto::RequestContact>,
2603 session: Session,
2604) -> Result<()> {
2605 let requester_id = session.user_id();
2606 let responder_id = UserId::from_proto(request.responder_id);
2607 if requester_id == responder_id {
2608 return Err(anyhow!("cannot add yourself as a contact"))?;
2609 }
2610
2611 let notifications = session
2612 .db()
2613 .await
2614 .send_contact_request(requester_id, responder_id)
2615 .await?;
2616
2617 // Update outgoing contact requests of requester
2618 let mut update = proto::UpdateContacts::default();
2619 update.outgoing_requests.push(responder_id.to_proto());
2620 for connection_id in session
2621 .connection_pool()
2622 .await
2623 .user_connection_ids(requester_id)
2624 {
2625 session.peer.send(connection_id, update.clone())?;
2626 }
2627
2628 // Update incoming contact requests of responder
2629 let mut update = proto::UpdateContacts::default();
2630 update
2631 .incoming_requests
2632 .push(proto::IncomingContactRequest {
2633 requester_id: requester_id.to_proto(),
2634 });
2635 let connection_pool = session.connection_pool().await;
2636 for connection_id in connection_pool.user_connection_ids(responder_id) {
2637 session.peer.send(connection_id, update.clone())?;
2638 }
2639
2640 send_notifications(&connection_pool, &session.peer, notifications);
2641
2642 response.send(proto::Ack {})?;
2643 Ok(())
2644}
2645
2646/// Accept or decline a contact request
2647async fn respond_to_contact_request(
2648 request: proto::RespondToContactRequest,
2649 response: Response<proto::RespondToContactRequest>,
2650 session: Session,
2651) -> Result<()> {
2652 let responder_id = session.user_id();
2653 let requester_id = UserId::from_proto(request.requester_id);
2654 let db = session.db().await;
2655 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2656 db.dismiss_contact_notification(responder_id, requester_id)
2657 .await?;
2658 } else {
2659 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2660
2661 let notifications = db
2662 .respond_to_contact_request(responder_id, requester_id, accept)
2663 .await?;
2664 let requester_busy = db.is_user_busy(requester_id).await?;
2665 let responder_busy = db.is_user_busy(responder_id).await?;
2666
2667 let pool = session.connection_pool().await;
2668 // Update responder with new contact
2669 let mut update = proto::UpdateContacts::default();
2670 if accept {
2671 update
2672 .contacts
2673 .push(contact_for_user(requester_id, requester_busy, &pool));
2674 }
2675 update
2676 .remove_incoming_requests
2677 .push(requester_id.to_proto());
2678 for connection_id in pool.user_connection_ids(responder_id) {
2679 session.peer.send(connection_id, update.clone())?;
2680 }
2681
2682 // Update requester with new contact
2683 let mut update = proto::UpdateContacts::default();
2684 if accept {
2685 update
2686 .contacts
2687 .push(contact_for_user(responder_id, responder_busy, &pool));
2688 }
2689 update
2690 .remove_outgoing_requests
2691 .push(responder_id.to_proto());
2692
2693 for connection_id in pool.user_connection_ids(requester_id) {
2694 session.peer.send(connection_id, update.clone())?;
2695 }
2696
2697 send_notifications(&pool, &session.peer, notifications);
2698 }
2699
2700 response.send(proto::Ack {})?;
2701 Ok(())
2702}
2703
2704/// Remove a contact.
2705async fn remove_contact(
2706 request: proto::RemoveContact,
2707 response: Response<proto::RemoveContact>,
2708 session: Session,
2709) -> Result<()> {
2710 let requester_id = session.user_id();
2711 let responder_id = UserId::from_proto(request.user_id);
2712 let db = session.db().await;
2713 let (contact_accepted, deleted_notification_id) =
2714 db.remove_contact(requester_id, responder_id).await?;
2715
2716 let pool = session.connection_pool().await;
2717 // Update outgoing contact requests of requester
2718 let mut update = proto::UpdateContacts::default();
2719 if contact_accepted {
2720 update.remove_contacts.push(responder_id.to_proto());
2721 } else {
2722 update
2723 .remove_outgoing_requests
2724 .push(responder_id.to_proto());
2725 }
2726 for connection_id in pool.user_connection_ids(requester_id) {
2727 session.peer.send(connection_id, update.clone())?;
2728 }
2729
2730 // Update incoming contact requests of responder
2731 let mut update = proto::UpdateContacts::default();
2732 if contact_accepted {
2733 update.remove_contacts.push(requester_id.to_proto());
2734 } else {
2735 update
2736 .remove_incoming_requests
2737 .push(requester_id.to_proto());
2738 }
2739 for connection_id in pool.user_connection_ids(responder_id) {
2740 session.peer.send(connection_id, update.clone())?;
2741 if let Some(notification_id) = deleted_notification_id {
2742 session.peer.send(
2743 connection_id,
2744 proto::DeleteNotification {
2745 notification_id: notification_id.to_proto(),
2746 },
2747 )?;
2748 }
2749 }
2750
2751 response.send(proto::Ack {})?;
2752 Ok(())
2753}
2754
2755fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2756 version.0.minor() < 139
2757}
2758
2759async fn current_plan(db: &Arc<Database>, user_id: UserId, is_staff: bool) -> Result<proto::Plan> {
2760 if is_staff {
2761 return Ok(proto::Plan::ZedPro);
2762 }
2763
2764 let subscription = db.get_active_billing_subscription(user_id).await?;
2765 let subscription_kind = subscription.and_then(|subscription| subscription.kind);
2766
2767 let plan = if let Some(subscription_kind) = subscription_kind {
2768 match subscription_kind {
2769 SubscriptionKind::ZedPro => proto::Plan::ZedPro,
2770 SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
2771 SubscriptionKind::ZedFree => proto::Plan::Free,
2772 }
2773 } else {
2774 proto::Plan::Free
2775 };
2776
2777 Ok(plan)
2778}
2779
2780async fn make_update_user_plan_message(
2781 user: &User,
2782 is_staff: bool,
2783 db: &Arc<Database>,
2784 llm_db: Option<Arc<LlmDatabase>>,
2785) -> Result<proto::UpdateUserPlan> {
2786 let feature_flags = db.get_user_flags(user.id).await?;
2787 let plan = current_plan(db, user.id, is_staff).await?;
2788 let billing_customer = db.get_billing_customer_by_user_id(user.id).await?;
2789 let billing_preferences = db.get_billing_preferences(user.id).await?;
2790
2791 let (subscription_period, usage) = if let Some(llm_db) = llm_db {
2792 let subscription = db.get_active_billing_subscription(user.id).await?;
2793
2794 let subscription_period =
2795 crate::db::billing_subscription::Model::current_period(subscription, is_staff);
2796
2797 let usage = if let Some((period_start_at, period_end_at)) = subscription_period {
2798 llm_db
2799 .get_subscription_usage_for_period(user.id, period_start_at, period_end_at)
2800 .await?
2801 } else {
2802 None
2803 };
2804
2805 (subscription_period, usage)
2806 } else {
2807 (None, None)
2808 };
2809
2810 let bypass_account_age_check = feature_flags
2811 .iter()
2812 .any(|flag| flag == BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG);
2813 let account_too_young = !matches!(plan, proto::Plan::ZedPro)
2814 && !bypass_account_age_check
2815 && user.account_age() < MIN_ACCOUNT_AGE_FOR_LLM_USE;
2816
2817 Ok(proto::UpdateUserPlan {
2818 plan: plan.into(),
2819 trial_started_at: billing_customer
2820 .as_ref()
2821 .and_then(|billing_customer| billing_customer.trial_started_at)
2822 .map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64),
2823 is_usage_based_billing_enabled: if is_staff {
2824 Some(true)
2825 } else {
2826 billing_preferences.map(|preferences| preferences.model_request_overages_enabled)
2827 },
2828 subscription_period: subscription_period.map(|(started_at, ended_at)| {
2829 proto::SubscriptionPeriod {
2830 started_at: started_at.timestamp() as u64,
2831 ended_at: ended_at.timestamp() as u64,
2832 }
2833 }),
2834 account_too_young: Some(account_too_young),
2835 has_overdue_invoices: billing_customer
2836 .map(|billing_customer| billing_customer.has_overdue_invoices),
2837 usage: usage.map(|usage| {
2838 let plan = match plan {
2839 proto::Plan::Free => zed_llm_client::Plan::ZedFree,
2840 proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro,
2841 proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
2842 };
2843
2844 let model_requests_limit = match plan.model_requests_limit() {
2845 zed_llm_client::UsageLimit::Limited(limit) => {
2846 let limit = if plan == zed_llm_client::Plan::ZedProTrial
2847 && feature_flags
2848 .iter()
2849 .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
2850 {
2851 1_000
2852 } else {
2853 limit
2854 };
2855
2856 zed_llm_client::UsageLimit::Limited(limit)
2857 }
2858 zed_llm_client::UsageLimit::Unlimited => zed_llm_client::UsageLimit::Unlimited,
2859 };
2860
2861 proto::SubscriptionUsage {
2862 model_requests_usage_amount: usage.model_requests as u32,
2863 model_requests_usage_limit: Some(proto::UsageLimit {
2864 variant: Some(match model_requests_limit {
2865 zed_llm_client::UsageLimit::Limited(limit) => {
2866 proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2867 limit: limit as u32,
2868 })
2869 }
2870 zed_llm_client::UsageLimit::Unlimited => {
2871 proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2872 }
2873 }),
2874 }),
2875 edit_predictions_usage_amount: usage.edit_predictions as u32,
2876 edit_predictions_usage_limit: Some(proto::UsageLimit {
2877 variant: Some(match plan.edit_predictions_limit() {
2878 zed_llm_client::UsageLimit::Limited(limit) => {
2879 proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2880 limit: limit as u32,
2881 })
2882 }
2883 zed_llm_client::UsageLimit::Unlimited => {
2884 proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2885 }
2886 }),
2887 }),
2888 }
2889 }),
2890 })
2891}
2892
2893async fn update_user_plan(session: &Session) -> Result<()> {
2894 let db = session.db().await;
2895
2896 let update_user_plan = make_update_user_plan_message(
2897 session.principal.user(),
2898 session.is_staff(),
2899 &db.0,
2900 session.app_state.llm_db.clone(),
2901 )
2902 .await?;
2903
2904 session
2905 .peer
2906 .send(session.connection_id, update_user_plan)
2907 .trace_err();
2908
2909 Ok(())
2910}
2911
2912async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2913 subscribe_user_to_channels(session.user_id(), &session).await?;
2914 Ok(())
2915}
2916
2917async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2918 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2919 let mut pool = session.connection_pool().await;
2920 for membership in &channels_for_user.channel_memberships {
2921 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2922 }
2923 session.peer.send(
2924 session.connection_id,
2925 build_update_user_channels(&channels_for_user),
2926 )?;
2927 session.peer.send(
2928 session.connection_id,
2929 build_channels_update(channels_for_user),
2930 )?;
2931 Ok(())
2932}
2933
2934/// Creates a new channel.
2935async fn create_channel(
2936 request: proto::CreateChannel,
2937 response: Response<proto::CreateChannel>,
2938 session: Session,
2939) -> Result<()> {
2940 let db = session.db().await;
2941
2942 let parent_id = request.parent_id.map(ChannelId::from_proto);
2943 let (channel, membership) = db
2944 .create_channel(&request.name, parent_id, session.user_id())
2945 .await?;
2946
2947 let root_id = channel.root_id();
2948 let channel = Channel::from_model(channel);
2949
2950 response.send(proto::CreateChannelResponse {
2951 channel: Some(channel.to_proto()),
2952 parent_id: request.parent_id,
2953 })?;
2954
2955 let mut connection_pool = session.connection_pool().await;
2956 if let Some(membership) = membership {
2957 connection_pool.subscribe_to_channel(
2958 membership.user_id,
2959 membership.channel_id,
2960 membership.role,
2961 );
2962 let update = proto::UpdateUserChannels {
2963 channel_memberships: vec![proto::ChannelMembership {
2964 channel_id: membership.channel_id.to_proto(),
2965 role: membership.role.into(),
2966 }],
2967 ..Default::default()
2968 };
2969 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2970 session.peer.send(connection_id, update.clone())?;
2971 }
2972 }
2973
2974 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2975 if !role.can_see_channel(channel.visibility) {
2976 continue;
2977 }
2978
2979 let update = proto::UpdateChannels {
2980 channels: vec![channel.to_proto()],
2981 ..Default::default()
2982 };
2983 session.peer.send(connection_id, update.clone())?;
2984 }
2985
2986 Ok(())
2987}
2988
2989/// Delete a channel
2990async fn delete_channel(
2991 request: proto::DeleteChannel,
2992 response: Response<proto::DeleteChannel>,
2993 session: Session,
2994) -> Result<()> {
2995 let db = session.db().await;
2996
2997 let channel_id = request.channel_id;
2998 let (root_channel, removed_channels) = db
2999 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3000 .await?;
3001 response.send(proto::Ack {})?;
3002
3003 // Notify members of removed channels
3004 let mut update = proto::UpdateChannels::default();
3005 update
3006 .delete_channels
3007 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3008
3009 let connection_pool = session.connection_pool().await;
3010 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3011 session.peer.send(connection_id, update.clone())?;
3012 }
3013
3014 Ok(())
3015}
3016
3017/// Invite someone to join a channel.
3018async fn invite_channel_member(
3019 request: proto::InviteChannelMember,
3020 response: Response<proto::InviteChannelMember>,
3021 session: Session,
3022) -> Result<()> {
3023 let db = session.db().await;
3024 let channel_id = ChannelId::from_proto(request.channel_id);
3025 let invitee_id = UserId::from_proto(request.user_id);
3026 let InviteMemberResult {
3027 channel,
3028 notifications,
3029 } = db
3030 .invite_channel_member(
3031 channel_id,
3032 invitee_id,
3033 session.user_id(),
3034 request.role().into(),
3035 )
3036 .await?;
3037
3038 let update = proto::UpdateChannels {
3039 channel_invitations: vec![channel.to_proto()],
3040 ..Default::default()
3041 };
3042
3043 let connection_pool = session.connection_pool().await;
3044 for connection_id in connection_pool.user_connection_ids(invitee_id) {
3045 session.peer.send(connection_id, update.clone())?;
3046 }
3047
3048 send_notifications(&connection_pool, &session.peer, notifications);
3049
3050 response.send(proto::Ack {})?;
3051 Ok(())
3052}
3053
3054/// remove someone from a channel
3055async fn remove_channel_member(
3056 request: proto::RemoveChannelMember,
3057 response: Response<proto::RemoveChannelMember>,
3058 session: Session,
3059) -> Result<()> {
3060 let db = session.db().await;
3061 let channel_id = ChannelId::from_proto(request.channel_id);
3062 let member_id = UserId::from_proto(request.user_id);
3063
3064 let RemoveChannelMemberResult {
3065 membership_update,
3066 notification_id,
3067 } = db
3068 .remove_channel_member(channel_id, member_id, session.user_id())
3069 .await?;
3070
3071 let mut connection_pool = session.connection_pool().await;
3072 notify_membership_updated(
3073 &mut connection_pool,
3074 membership_update,
3075 member_id,
3076 &session.peer,
3077 );
3078 for connection_id in connection_pool.user_connection_ids(member_id) {
3079 if let Some(notification_id) = notification_id {
3080 session
3081 .peer
3082 .send(
3083 connection_id,
3084 proto::DeleteNotification {
3085 notification_id: notification_id.to_proto(),
3086 },
3087 )
3088 .trace_err();
3089 }
3090 }
3091
3092 response.send(proto::Ack {})?;
3093 Ok(())
3094}
3095
3096/// Toggle the channel between public and private.
3097/// Care is taken to maintain the invariant that public channels only descend from public channels,
3098/// (though members-only channels can appear at any point in the hierarchy).
3099async fn set_channel_visibility(
3100 request: proto::SetChannelVisibility,
3101 response: Response<proto::SetChannelVisibility>,
3102 session: Session,
3103) -> Result<()> {
3104 let db = session.db().await;
3105 let channel_id = ChannelId::from_proto(request.channel_id);
3106 let visibility = request.visibility().into();
3107
3108 let channel_model = db
3109 .set_channel_visibility(channel_id, visibility, session.user_id())
3110 .await?;
3111 let root_id = channel_model.root_id();
3112 let channel = Channel::from_model(channel_model);
3113
3114 let mut connection_pool = session.connection_pool().await;
3115 for (user_id, role) in connection_pool
3116 .channel_user_ids(root_id)
3117 .collect::<Vec<_>>()
3118 .into_iter()
3119 {
3120 let update = if role.can_see_channel(channel.visibility) {
3121 connection_pool.subscribe_to_channel(user_id, channel_id, role);
3122 proto::UpdateChannels {
3123 channels: vec![channel.to_proto()],
3124 ..Default::default()
3125 }
3126 } else {
3127 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3128 proto::UpdateChannels {
3129 delete_channels: vec![channel.id.to_proto()],
3130 ..Default::default()
3131 }
3132 };
3133
3134 for connection_id in connection_pool.user_connection_ids(user_id) {
3135 session.peer.send(connection_id, update.clone())?;
3136 }
3137 }
3138
3139 response.send(proto::Ack {})?;
3140 Ok(())
3141}
3142
3143/// Alter the role for a user in the channel.
3144async fn set_channel_member_role(
3145 request: proto::SetChannelMemberRole,
3146 response: Response<proto::SetChannelMemberRole>,
3147 session: Session,
3148) -> Result<()> {
3149 let db = session.db().await;
3150 let channel_id = ChannelId::from_proto(request.channel_id);
3151 let member_id = UserId::from_proto(request.user_id);
3152 let result = db
3153 .set_channel_member_role(
3154 channel_id,
3155 session.user_id(),
3156 member_id,
3157 request.role().into(),
3158 )
3159 .await?;
3160
3161 match result {
3162 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3163 let mut connection_pool = session.connection_pool().await;
3164 notify_membership_updated(
3165 &mut connection_pool,
3166 membership_update,
3167 member_id,
3168 &session.peer,
3169 )
3170 }
3171 db::SetMemberRoleResult::InviteUpdated(channel) => {
3172 let update = proto::UpdateChannels {
3173 channel_invitations: vec![channel.to_proto()],
3174 ..Default::default()
3175 };
3176
3177 for connection_id in session
3178 .connection_pool()
3179 .await
3180 .user_connection_ids(member_id)
3181 {
3182 session.peer.send(connection_id, update.clone())?;
3183 }
3184 }
3185 }
3186
3187 response.send(proto::Ack {})?;
3188 Ok(())
3189}
3190
3191/// Change the name of a channel
3192async fn rename_channel(
3193 request: proto::RenameChannel,
3194 response: Response<proto::RenameChannel>,
3195 session: Session,
3196) -> Result<()> {
3197 let db = session.db().await;
3198 let channel_id = ChannelId::from_proto(request.channel_id);
3199 let channel_model = db
3200 .rename_channel(channel_id, session.user_id(), &request.name)
3201 .await?;
3202 let root_id = channel_model.root_id();
3203 let channel = Channel::from_model(channel_model);
3204
3205 response.send(proto::RenameChannelResponse {
3206 channel: Some(channel.to_proto()),
3207 })?;
3208
3209 let connection_pool = session.connection_pool().await;
3210 let update = proto::UpdateChannels {
3211 channels: vec![channel.to_proto()],
3212 ..Default::default()
3213 };
3214 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3215 if role.can_see_channel(channel.visibility) {
3216 session.peer.send(connection_id, update.clone())?;
3217 }
3218 }
3219
3220 Ok(())
3221}
3222
3223/// Move a channel to a new parent.
3224async fn move_channel(
3225 request: proto::MoveChannel,
3226 response: Response<proto::MoveChannel>,
3227 session: Session,
3228) -> Result<()> {
3229 let channel_id = ChannelId::from_proto(request.channel_id);
3230 let to = ChannelId::from_proto(request.to);
3231
3232 let (root_id, channels) = session
3233 .db()
3234 .await
3235 .move_channel(channel_id, to, session.user_id())
3236 .await?;
3237
3238 let connection_pool = session.connection_pool().await;
3239 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3240 let channels = channels
3241 .iter()
3242 .filter_map(|channel| {
3243 if role.can_see_channel(channel.visibility) {
3244 Some(channel.to_proto())
3245 } else {
3246 None
3247 }
3248 })
3249 .collect::<Vec<_>>();
3250 if channels.is_empty() {
3251 continue;
3252 }
3253
3254 let update = proto::UpdateChannels {
3255 channels,
3256 ..Default::default()
3257 };
3258
3259 session.peer.send(connection_id, update.clone())?;
3260 }
3261
3262 response.send(Ack {})?;
3263 Ok(())
3264}
3265
3266async fn reorder_channel(
3267 request: proto::ReorderChannel,
3268 response: Response<proto::ReorderChannel>,
3269 session: Session,
3270) -> Result<()> {
3271 let channel_id = ChannelId::from_proto(request.channel_id);
3272 let direction = request.direction();
3273
3274 let updated_channels = session
3275 .db()
3276 .await
3277 .reorder_channel(channel_id, direction, session.user_id())
3278 .await?;
3279
3280 if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3281 let connection_pool = session.connection_pool().await;
3282 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3283 let channels = updated_channels
3284 .iter()
3285 .filter_map(|channel| {
3286 if role.can_see_channel(channel.visibility) {
3287 Some(channel.to_proto())
3288 } else {
3289 None
3290 }
3291 })
3292 .collect::<Vec<_>>();
3293
3294 if channels.is_empty() {
3295 continue;
3296 }
3297
3298 let update = proto::UpdateChannels {
3299 channels,
3300 ..Default::default()
3301 };
3302
3303 session.peer.send(connection_id, update.clone())?;
3304 }
3305 }
3306
3307 response.send(Ack {})?;
3308 Ok(())
3309}
3310
3311/// Get the list of channel members
3312async fn get_channel_members(
3313 request: proto::GetChannelMembers,
3314 response: Response<proto::GetChannelMembers>,
3315 session: Session,
3316) -> Result<()> {
3317 let db = session.db().await;
3318 let channel_id = ChannelId::from_proto(request.channel_id);
3319 let limit = if request.limit == 0 {
3320 u16::MAX as u64
3321 } else {
3322 request.limit
3323 };
3324 let (members, users) = db
3325 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3326 .await?;
3327 response.send(proto::GetChannelMembersResponse { members, users })?;
3328 Ok(())
3329}
3330
3331/// Accept or decline a channel invitation.
3332async fn respond_to_channel_invite(
3333 request: proto::RespondToChannelInvite,
3334 response: Response<proto::RespondToChannelInvite>,
3335 session: Session,
3336) -> Result<()> {
3337 let db = session.db().await;
3338 let channel_id = ChannelId::from_proto(request.channel_id);
3339 let RespondToChannelInvite {
3340 membership_update,
3341 notifications,
3342 } = db
3343 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3344 .await?;
3345
3346 let mut connection_pool = session.connection_pool().await;
3347 if let Some(membership_update) = membership_update {
3348 notify_membership_updated(
3349 &mut connection_pool,
3350 membership_update,
3351 session.user_id(),
3352 &session.peer,
3353 );
3354 } else {
3355 let update = proto::UpdateChannels {
3356 remove_channel_invitations: vec![channel_id.to_proto()],
3357 ..Default::default()
3358 };
3359
3360 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3361 session.peer.send(connection_id, update.clone())?;
3362 }
3363 };
3364
3365 send_notifications(&connection_pool, &session.peer, notifications);
3366
3367 response.send(proto::Ack {})?;
3368
3369 Ok(())
3370}
3371
3372/// Join the channels' room
3373async fn join_channel(
3374 request: proto::JoinChannel,
3375 response: Response<proto::JoinChannel>,
3376 session: Session,
3377) -> Result<()> {
3378 let channel_id = ChannelId::from_proto(request.channel_id);
3379 join_channel_internal(channel_id, Box::new(response), session).await
3380}
3381
/// Abstraction over the response handles that `join_channel_internal` can answer.
///
/// Both `proto::JoinChannel` and `proto::JoinRoom` requests are answered with a
/// `proto::JoinRoomResponse`; this trait lets one code path serve either handle.
trait JoinChannelInternalResponse {
    /// Consume the response handle and send `result` to the requester.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Fully qualified so the call names `Response`'s own `send` explicitly
        // rather than this trait method of the same name.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Fully qualified so the call names `Response`'s own `send` explicitly
        // rather than this trait method of the same name.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3395
3396async fn join_channel_internal(
3397 channel_id: ChannelId,
3398 response: Box<impl JoinChannelInternalResponse>,
3399 session: Session,
3400) -> Result<()> {
3401 let joined_room = {
3402 let mut db = session.db().await;
3403 // If zed quits without leaving the room, and the user re-opens zed before the
3404 // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
3405 // room they were in.
3406 if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
3407 tracing::info!(
3408 stale_connection_id = %connection,
3409 "cleaning up stale connection",
3410 );
3411 drop(db);
3412 leave_room_for_session(&session, connection).await?;
3413 db = session.db().await;
3414 }
3415
3416 let (joined_room, membership_updated, role) = db
3417 .join_channel(channel_id, session.user_id(), session.connection_id)
3418 .await?;
3419
3420 let live_kit_connection_info =
3421 session
3422 .app_state
3423 .livekit_client
3424 .as_ref()
3425 .and_then(|live_kit| {
3426 let (can_publish, token) = if role == ChannelRole::Guest {
3427 (
3428 false,
3429 live_kit
3430 .guest_token(
3431 &joined_room.room.livekit_room,
3432 &session.user_id().to_string(),
3433 )
3434 .trace_err()?,
3435 )
3436 } else {
3437 (
3438 true,
3439 live_kit
3440 .room_token(
3441 &joined_room.room.livekit_room,
3442 &session.user_id().to_string(),
3443 )
3444 .trace_err()?,
3445 )
3446 };
3447
3448 Some(LiveKitConnectionInfo {
3449 server_url: live_kit.url().into(),
3450 token,
3451 can_publish,
3452 })
3453 });
3454
3455 response.send(proto::JoinRoomResponse {
3456 room: Some(joined_room.room.clone()),
3457 channel_id: joined_room
3458 .channel
3459 .as_ref()
3460 .map(|channel| channel.id.to_proto()),
3461 live_kit_connection_info,
3462 })?;
3463
3464 let mut connection_pool = session.connection_pool().await;
3465 if let Some(membership_updated) = membership_updated {
3466 notify_membership_updated(
3467 &mut connection_pool,
3468 membership_updated,
3469 session.user_id(),
3470 &session.peer,
3471 );
3472 }
3473
3474 room_updated(&joined_room.room, &session.peer);
3475
3476 joined_room
3477 };
3478
3479 channel_updated(
3480 &joined_room.channel.context("channel not returned")?,
3481 &joined_room.room,
3482 &session.peer,
3483 &*session.connection_pool().await,
3484 );
3485
3486 update_user_contacts(session.user_id(), &session).await?;
3487 Ok(())
3488}
3489
3490/// Start editing the channel notes
3491async fn join_channel_buffer(
3492 request: proto::JoinChannelBuffer,
3493 response: Response<proto::JoinChannelBuffer>,
3494 session: Session,
3495) -> Result<()> {
3496 let db = session.db().await;
3497 let channel_id = ChannelId::from_proto(request.channel_id);
3498
3499 let open_response = db
3500 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3501 .await?;
3502
3503 let collaborators = open_response.collaborators.clone();
3504 response.send(open_response)?;
3505
3506 let update = UpdateChannelBufferCollaborators {
3507 channel_id: channel_id.to_proto(),
3508 collaborators: collaborators.clone(),
3509 };
3510 channel_buffer_updated(
3511 session.connection_id,
3512 collaborators
3513 .iter()
3514 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3515 &update,
3516 &session.peer,
3517 );
3518
3519 Ok(())
3520}
3521
3522/// Edit the channel notes
3523async fn update_channel_buffer(
3524 request: proto::UpdateChannelBuffer,
3525 session: Session,
3526) -> Result<()> {
3527 let db = session.db().await;
3528 let channel_id = ChannelId::from_proto(request.channel_id);
3529
3530 let (collaborators, epoch, version) = db
3531 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3532 .await?;
3533
3534 channel_buffer_updated(
3535 session.connection_id,
3536 collaborators.clone(),
3537 &proto::UpdateChannelBuffer {
3538 channel_id: channel_id.to_proto(),
3539 operations: request.operations,
3540 },
3541 &session.peer,
3542 );
3543
3544 let pool = &*session.connection_pool().await;
3545
3546 let non_collaborators =
3547 pool.channel_connection_ids(channel_id)
3548 .filter_map(|(connection_id, _)| {
3549 if collaborators.contains(&connection_id) {
3550 None
3551 } else {
3552 Some(connection_id)
3553 }
3554 });
3555
3556 broadcast(None, non_collaborators, |peer_id| {
3557 session.peer.send(
3558 peer_id,
3559 proto::UpdateChannels {
3560 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3561 channel_id: channel_id.to_proto(),
3562 epoch: epoch as u64,
3563 version: version.clone(),
3564 }],
3565 ..Default::default()
3566 },
3567 )
3568 });
3569
3570 Ok(())
3571}
3572
3573/// Rejoin the channel notes after a connection blip
3574async fn rejoin_channel_buffers(
3575 request: proto::RejoinChannelBuffers,
3576 response: Response<proto::RejoinChannelBuffers>,
3577 session: Session,
3578) -> Result<()> {
3579 let db = session.db().await;
3580 let buffers = db
3581 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3582 .await?;
3583
3584 for rejoined_buffer in &buffers {
3585 let collaborators_to_notify = rejoined_buffer
3586 .buffer
3587 .collaborators
3588 .iter()
3589 .filter_map(|c| Some(c.peer_id?.into()));
3590 channel_buffer_updated(
3591 session.connection_id,
3592 collaborators_to_notify,
3593 &proto::UpdateChannelBufferCollaborators {
3594 channel_id: rejoined_buffer.buffer.channel_id,
3595 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3596 },
3597 &session.peer,
3598 );
3599 }
3600
3601 response.send(proto::RejoinChannelBuffersResponse {
3602 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3603 })?;
3604
3605 Ok(())
3606}
3607
3608/// Stop editing the channel notes
3609async fn leave_channel_buffer(
3610 request: proto::LeaveChannelBuffer,
3611 response: Response<proto::LeaveChannelBuffer>,
3612 session: Session,
3613) -> Result<()> {
3614 let db = session.db().await;
3615 let channel_id = ChannelId::from_proto(request.channel_id);
3616
3617 let left_buffer = db
3618 .leave_channel_buffer(channel_id, session.connection_id)
3619 .await?;
3620
3621 response.send(Ack {})?;
3622
3623 channel_buffer_updated(
3624 session.connection_id,
3625 left_buffer.connections,
3626 &proto::UpdateChannelBufferCollaborators {
3627 channel_id: channel_id.to_proto(),
3628 collaborators: left_buffer.collaborators,
3629 },
3630 &session.peer,
3631 );
3632
3633 Ok(())
3634}
3635
3636fn channel_buffer_updated<T: EnvelopedMessage>(
3637 sender_id: ConnectionId,
3638 collaborators: impl IntoIterator<Item = ConnectionId>,
3639 message: &T,
3640 peer: &Peer,
3641) {
3642 broadcast(Some(sender_id), collaborators, |peer_id| {
3643 peer.send(peer_id, message.clone())
3644 });
3645}
3646
3647fn send_notifications(
3648 connection_pool: &ConnectionPool,
3649 peer: &Peer,
3650 notifications: db::NotificationBatch,
3651) {
3652 for (user_id, notification) in notifications {
3653 for connection_id in connection_pool.user_connection_ids(user_id) {
3654 if let Err(error) = peer.send(
3655 connection_id,
3656 proto::AddNotification {
3657 notification: Some(notification.clone()),
3658 },
3659 ) {
3660 tracing::error!(
3661 "failed to send notification to {:?} {}",
3662 connection_id,
3663 error
3664 );
3665 }
3666 }
3667 }
3668}
3669
3670/// Send a message to the channel
3671async fn send_channel_message(
3672 request: proto::SendChannelMessage,
3673 response: Response<proto::SendChannelMessage>,
3674 session: Session,
3675) -> Result<()> {
3676 // Validate the message body.
3677 let body = request.body.trim().to_string();
3678 if body.len() > MAX_MESSAGE_LEN {
3679 return Err(anyhow!("message is too long"))?;
3680 }
3681 if body.is_empty() {
3682 return Err(anyhow!("message can't be blank"))?;
3683 }
3684
3685 // TODO: adjust mentions if body is trimmed
3686
3687 let timestamp = OffsetDateTime::now_utc();
3688 let nonce = request.nonce.context("nonce can't be blank")?;
3689
3690 let channel_id = ChannelId::from_proto(request.channel_id);
3691 let CreatedChannelMessage {
3692 message_id,
3693 participant_connection_ids,
3694 notifications,
3695 } = session
3696 .db()
3697 .await
3698 .create_channel_message(
3699 channel_id,
3700 session.user_id(),
3701 &body,
3702 &request.mentions,
3703 timestamp,
3704 nonce.clone().into(),
3705 request.reply_to_message_id.map(MessageId::from_proto),
3706 )
3707 .await?;
3708
3709 let message = proto::ChannelMessage {
3710 sender_id: session.user_id().to_proto(),
3711 id: message_id.to_proto(),
3712 body,
3713 mentions: request.mentions,
3714 timestamp: timestamp.unix_timestamp() as u64,
3715 nonce: Some(nonce),
3716 reply_to_message_id: request.reply_to_message_id,
3717 edited_at: None,
3718 };
3719 broadcast(
3720 Some(session.connection_id),
3721 participant_connection_ids.clone(),
3722 |connection| {
3723 session.peer.send(
3724 connection,
3725 proto::ChannelMessageSent {
3726 channel_id: channel_id.to_proto(),
3727 message: Some(message.clone()),
3728 },
3729 )
3730 },
3731 );
3732 response.send(proto::SendChannelMessageResponse {
3733 message: Some(message),
3734 })?;
3735
3736 let pool = &*session.connection_pool().await;
3737 let non_participants =
3738 pool.channel_connection_ids(channel_id)
3739 .filter_map(|(connection_id, _)| {
3740 if participant_connection_ids.contains(&connection_id) {
3741 None
3742 } else {
3743 Some(connection_id)
3744 }
3745 });
3746 broadcast(None, non_participants, |peer_id| {
3747 session.peer.send(
3748 peer_id,
3749 proto::UpdateChannels {
3750 latest_channel_message_ids: vec![proto::ChannelMessageId {
3751 channel_id: channel_id.to_proto(),
3752 message_id: message_id.to_proto(),
3753 }],
3754 ..Default::default()
3755 },
3756 )
3757 });
3758 send_notifications(pool, &session.peer, notifications);
3759
3760 Ok(())
3761}
3762
3763/// Delete a channel message
3764async fn remove_channel_message(
3765 request: proto::RemoveChannelMessage,
3766 response: Response<proto::RemoveChannelMessage>,
3767 session: Session,
3768) -> Result<()> {
3769 let channel_id = ChannelId::from_proto(request.channel_id);
3770 let message_id = MessageId::from_proto(request.message_id);
3771 let (connection_ids, existing_notification_ids) = session
3772 .db()
3773 .await
3774 .remove_channel_message(channel_id, message_id, session.user_id())
3775 .await?;
3776
3777 broadcast(
3778 Some(session.connection_id),
3779 connection_ids,
3780 move |connection| {
3781 session.peer.send(connection, request.clone())?;
3782
3783 for notification_id in &existing_notification_ids {
3784 session.peer.send(
3785 connection,
3786 proto::DeleteNotification {
3787 notification_id: (*notification_id).to_proto(),
3788 },
3789 )?;
3790 }
3791
3792 Ok(())
3793 },
3794 );
3795 response.send(proto::Ack {})?;
3796 Ok(())
3797}
3798
3799async fn update_channel_message(
3800 request: proto::UpdateChannelMessage,
3801 response: Response<proto::UpdateChannelMessage>,
3802 session: Session,
3803) -> Result<()> {
3804 let channel_id = ChannelId::from_proto(request.channel_id);
3805 let message_id = MessageId::from_proto(request.message_id);
3806 let updated_at = OffsetDateTime::now_utc();
3807 let UpdatedChannelMessage {
3808 message_id,
3809 participant_connection_ids,
3810 notifications,
3811 reply_to_message_id,
3812 timestamp,
3813 deleted_mention_notification_ids,
3814 updated_mention_notifications,
3815 } = session
3816 .db()
3817 .await
3818 .update_channel_message(
3819 channel_id,
3820 message_id,
3821 session.user_id(),
3822 request.body.as_str(),
3823 &request.mentions,
3824 updated_at,
3825 )
3826 .await?;
3827
3828 let nonce = request.nonce.clone().context("nonce can't be blank")?;
3829
3830 let message = proto::ChannelMessage {
3831 sender_id: session.user_id().to_proto(),
3832 id: message_id.to_proto(),
3833 body: request.body.clone(),
3834 mentions: request.mentions.clone(),
3835 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3836 nonce: Some(nonce),
3837 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3838 edited_at: Some(updated_at.unix_timestamp() as u64),
3839 };
3840
3841 response.send(proto::Ack {})?;
3842
3843 let pool = &*session.connection_pool().await;
3844 broadcast(
3845 Some(session.connection_id),
3846 participant_connection_ids,
3847 |connection| {
3848 session.peer.send(
3849 connection,
3850 proto::ChannelMessageUpdate {
3851 channel_id: channel_id.to_proto(),
3852 message: Some(message.clone()),
3853 },
3854 )?;
3855
3856 for notification_id in &deleted_mention_notification_ids {
3857 session.peer.send(
3858 connection,
3859 proto::DeleteNotification {
3860 notification_id: (*notification_id).to_proto(),
3861 },
3862 )?;
3863 }
3864
3865 for notification in &updated_mention_notifications {
3866 session.peer.send(
3867 connection,
3868 proto::UpdateNotification {
3869 notification: Some(notification.clone()),
3870 },
3871 )?;
3872 }
3873
3874 Ok(())
3875 },
3876 );
3877
3878 send_notifications(pool, &session.peer, notifications);
3879
3880 Ok(())
3881}
3882
3883/// Mark a channel message as read
3884async fn acknowledge_channel_message(
3885 request: proto::AckChannelMessage,
3886 session: Session,
3887) -> Result<()> {
3888 let channel_id = ChannelId::from_proto(request.channel_id);
3889 let message_id = MessageId::from_proto(request.message_id);
3890 let notifications = session
3891 .db()
3892 .await
3893 .observe_channel_message(channel_id, session.user_id(), message_id)
3894 .await?;
3895 send_notifications(
3896 &*session.connection_pool().await,
3897 &session.peer,
3898 notifications,
3899 );
3900 Ok(())
3901}
3902
3903/// Mark a buffer version as synced
3904async fn acknowledge_buffer_version(
3905 request: proto::AckBufferOperation,
3906 session: Session,
3907) -> Result<()> {
3908 let buffer_id = BufferId::from_proto(request.buffer_id);
3909 session
3910 .db()
3911 .await
3912 .observe_buffer_version(
3913 buffer_id,
3914 session.user_id(),
3915 request.epoch as i32,
3916 &request.version,
3917 )
3918 .await?;
3919 Ok(())
3920}
3921
3922/// Get a Supermaven API key for the user
3923async fn get_supermaven_api_key(
3924 _request: proto::GetSupermavenApiKey,
3925 response: Response<proto::GetSupermavenApiKey>,
3926 session: Session,
3927) -> Result<()> {
3928 let user_id: String = session.user_id().to_string();
3929 if !session.is_staff() {
3930 return Err(anyhow!("supermaven not enabled for this account"))?;
3931 }
3932
3933 let email = session.email().context("user must have an email")?;
3934
3935 let supermaven_admin_api = session
3936 .supermaven_client
3937 .as_ref()
3938 .context("supermaven not configured")?;
3939
3940 let result = supermaven_admin_api
3941 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3942 .await?;
3943
3944 response.send(proto::GetSupermavenApiKeyResponse {
3945 api_key: result.api_key,
3946 })?;
3947
3948 Ok(())
3949}
3950
3951/// Start receiving chat updates for a channel
3952async fn join_channel_chat(
3953 request: proto::JoinChannelChat,
3954 response: Response<proto::JoinChannelChat>,
3955 session: Session,
3956) -> Result<()> {
3957 let channel_id = ChannelId::from_proto(request.channel_id);
3958
3959 let db = session.db().await;
3960 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3961 .await?;
3962 let messages = db
3963 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3964 .await?;
3965 response.send(proto::JoinChannelChatResponse {
3966 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3967 messages,
3968 })?;
3969 Ok(())
3970}
3971
3972/// Stop receiving chat updates for a channel
3973async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3974 let channel_id = ChannelId::from_proto(request.channel_id);
3975 session
3976 .db()
3977 .await
3978 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3979 .await?;
3980 Ok(())
3981}
3982
3983/// Retrieve the chat history for a channel
3984async fn get_channel_messages(
3985 request: proto::GetChannelMessages,
3986 response: Response<proto::GetChannelMessages>,
3987 session: Session,
3988) -> Result<()> {
3989 let channel_id = ChannelId::from_proto(request.channel_id);
3990 let messages = session
3991 .db()
3992 .await
3993 .get_channel_messages(
3994 channel_id,
3995 session.user_id(),
3996 MESSAGE_COUNT_PER_PAGE,
3997 Some(MessageId::from_proto(request.before_message_id)),
3998 )
3999 .await?;
4000 response.send(proto::GetChannelMessagesResponse {
4001 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4002 messages,
4003 })?;
4004 Ok(())
4005}
4006
4007/// Retrieve specific chat messages
4008async fn get_channel_messages_by_id(
4009 request: proto::GetChannelMessagesById,
4010 response: Response<proto::GetChannelMessagesById>,
4011 session: Session,
4012) -> Result<()> {
4013 let message_ids = request
4014 .message_ids
4015 .iter()
4016 .map(|id| MessageId::from_proto(*id))
4017 .collect::<Vec<_>>();
4018 let messages = session
4019 .db()
4020 .await
4021 .get_channel_messages_by_id(session.user_id(), &message_ids)
4022 .await?;
4023 response.send(proto::GetChannelMessagesResponse {
4024 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4025 messages,
4026 })?;
4027 Ok(())
4028}
4029
4030/// Retrieve the current users notifications
4031async fn get_notifications(
4032 request: proto::GetNotifications,
4033 response: Response<proto::GetNotifications>,
4034 session: Session,
4035) -> Result<()> {
4036 let notifications = session
4037 .db()
4038 .await
4039 .get_notifications(
4040 session.user_id(),
4041 NOTIFICATION_COUNT_PER_PAGE,
4042 request.before_id.map(db::NotificationId::from_proto),
4043 )
4044 .await?;
4045 response.send(proto::GetNotificationsResponse {
4046 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4047 notifications,
4048 })?;
4049 Ok(())
4050}
4051
4052/// Mark notifications as read
4053async fn mark_notification_as_read(
4054 request: proto::MarkNotificationRead,
4055 response: Response<proto::MarkNotificationRead>,
4056 session: Session,
4057) -> Result<()> {
4058 let database = &session.db().await;
4059 let notifications = database
4060 .mark_notification_as_read_by_id(
4061 session.user_id(),
4062 NotificationId::from_proto(request.notification_id),
4063 )
4064 .await?;
4065 send_notifications(
4066 &*session.connection_pool().await,
4067 &session.peer,
4068 notifications,
4069 );
4070 response.send(proto::Ack {})?;
4071 Ok(())
4072}
4073
4074/// Get the current users information
4075async fn get_private_user_info(
4076 _request: proto::GetPrivateUserInfo,
4077 response: Response<proto::GetPrivateUserInfo>,
4078 session: Session,
4079) -> Result<()> {
4080 let db = session.db().await;
4081
4082 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4083 let user = db
4084 .get_user_by_id(session.user_id())
4085 .await?
4086 .context("user not found")?;
4087 let flags = db.get_user_flags(session.user_id()).await?;
4088
4089 response.send(proto::GetPrivateUserInfoResponse {
4090 metrics_id,
4091 staff: user.admin,
4092 flags,
4093 accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4094 })?;
4095 Ok(())
4096}
4097
4098/// Accept the terms of service (tos) on behalf of the current user
4099async fn accept_terms_of_service(
4100 _request: proto::AcceptTermsOfService,
4101 response: Response<proto::AcceptTermsOfService>,
4102 session: Session,
4103) -> Result<()> {
4104 let db = session.db().await;
4105
4106 let accepted_tos_at = Utc::now();
4107 db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4108 .await?;
4109
4110 response.send(proto::AcceptTermsOfServiceResponse {
4111 accepted_tos_at: accepted_tos_at.timestamp() as u64,
4112 })?;
4113 Ok(())
4114}
4115
4116async fn get_llm_api_token(
4117 _request: proto::GetLlmToken,
4118 response: Response<proto::GetLlmToken>,
4119 session: Session,
4120) -> Result<()> {
4121 let db = session.db().await;
4122
4123 let flags = db.get_user_flags(session.user_id()).await?;
4124
4125 let user_id = session.user_id();
4126 let user = db
4127 .get_user_by_id(user_id)
4128 .await?
4129 .with_context(|| format!("user {user_id} not found"))?;
4130
4131 if user.accepted_tos_at.is_none() {
4132 Err(anyhow!("terms of service not accepted"))?
4133 }
4134
4135 let stripe_client = session
4136 .app_state
4137 .stripe_client
4138 .as_ref()
4139 .context("failed to retrieve Stripe client")?;
4140
4141 let stripe_billing = session
4142 .app_state
4143 .stripe_billing
4144 .as_ref()
4145 .context("failed to retrieve Stripe billing object")?;
4146
4147 let billing_customer = if let Some(billing_customer) =
4148 db.get_billing_customer_by_user_id(user.id).await?
4149 {
4150 billing_customer
4151 } else {
4152 let customer_id = stripe_billing
4153 .find_or_create_customer_by_email(user.email_address.as_deref())
4154 .await?;
4155
4156 find_or_create_billing_customer(&session.app_state, stripe_client.as_ref(), &customer_id)
4157 .await?
4158 .context("billing customer not found")?
4159 };
4160
4161 let billing_subscription =
4162 if let Some(billing_subscription) = db.get_active_billing_subscription(user.id).await? {
4163 billing_subscription
4164 } else {
4165 let stripe_customer_id =
4166 StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
4167
4168 let stripe_subscription = stripe_billing
4169 .subscribe_to_zed_free(stripe_customer_id)
4170 .await?;
4171
4172 db.create_billing_subscription(&db::CreateBillingSubscriptionParams {
4173 billing_customer_id: billing_customer.id,
4174 kind: Some(SubscriptionKind::ZedFree),
4175 stripe_subscription_id: stripe_subscription.id.to_string(),
4176 stripe_subscription_status: stripe_subscription.status.into(),
4177 stripe_cancellation_reason: None,
4178 stripe_current_period_start: Some(stripe_subscription.current_period_start),
4179 stripe_current_period_end: Some(stripe_subscription.current_period_end),
4180 })
4181 .await?
4182 };
4183
4184 let billing_preferences = db.get_billing_preferences(user.id).await?;
4185
4186 let token = LlmTokenClaims::create(
4187 &user,
4188 session.is_staff(),
4189 billing_customer,
4190 billing_preferences,
4191 &flags,
4192 billing_subscription,
4193 session.system_id.clone(),
4194 &session.app_state.config,
4195 )?;
4196 response.send(proto::GetLlmTokenResponse { token })?;
4197 Ok(())
4198}
4199
4200fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4201 let message = match message {
4202 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
4203 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
4204 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
4205 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
4206 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4207 code: frame.code.into(),
4208 reason: frame.reason.as_str().to_owned().into(),
4209 })),
4210 // We should never receive a frame while reading the message, according
4211 // to the `tungstenite` maintainers:
4212 //
4213 // > It cannot occur when you read messages from the WebSocket, but it
4214 // > can be used when you want to send the raw frames (e.g. you want to
4215 // > send the frames to the WebSocket without composing the full message first).
4216 // >
4217 // > — https://github.com/snapview/tungstenite-rs/issues/268
4218 TungsteniteMessage::Frame(_) => {
4219 bail!("received an unexpected frame while reading the message")
4220 }
4221 };
4222
4223 Ok(message)
4224}
4225
4226fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4227 match message {
4228 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
4229 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
4230 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
4231 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
4232 AxumMessage::Close(frame) => {
4233 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4234 code: frame.code.into(),
4235 reason: frame.reason.as_ref().into(),
4236 }))
4237 }
4238 }
4239}
4240
4241fn notify_membership_updated(
4242 connection_pool: &mut ConnectionPool,
4243 result: MembershipUpdated,
4244 user_id: UserId,
4245 peer: &Peer,
4246) {
4247 for membership in &result.new_channels.channel_memberships {
4248 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4249 }
4250 for channel_id in &result.removed_channels {
4251 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4252 }
4253
4254 let user_channels_update = proto::UpdateUserChannels {
4255 channel_memberships: result
4256 .new_channels
4257 .channel_memberships
4258 .iter()
4259 .map(|cm| proto::ChannelMembership {
4260 channel_id: cm.channel_id.to_proto(),
4261 role: cm.role.into(),
4262 })
4263 .collect(),
4264 ..Default::default()
4265 };
4266
4267 let mut update = build_channels_update(result.new_channels);
4268 update.delete_channels = result
4269 .removed_channels
4270 .into_iter()
4271 .map(|id| id.to_proto())
4272 .collect();
4273 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4274
4275 for connection_id in connection_pool.user_connection_ids(user_id) {
4276 peer.send(connection_id, user_channels_update.clone())
4277 .trace_err();
4278 peer.send(connection_id, update.clone()).trace_err();
4279 }
4280}
4281
4282fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4283 proto::UpdateUserChannels {
4284 channel_memberships: channels
4285 .channel_memberships
4286 .iter()
4287 .map(|m| proto::ChannelMembership {
4288 channel_id: m.channel_id.to_proto(),
4289 role: m.role.into(),
4290 })
4291 .collect(),
4292 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4293 observed_channel_message_id: channels.observed_channel_messages.clone(),
4294 }
4295}
4296
4297fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4298 let mut update = proto::UpdateChannels::default();
4299
4300 for channel in channels.channels {
4301 update.channels.push(channel.to_proto());
4302 }
4303
4304 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4305 update.latest_channel_message_ids = channels.latest_channel_messages;
4306
4307 for (channel_id, participants) in channels.channel_participants {
4308 update
4309 .channel_participants
4310 .push(proto::ChannelParticipants {
4311 channel_id: channel_id.to_proto(),
4312 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4313 });
4314 }
4315
4316 for channel in channels.invited_channels {
4317 update.channel_invitations.push(channel.to_proto());
4318 }
4319
4320 update
4321}
4322
4323fn build_initial_contacts_update(
4324 contacts: Vec<db::Contact>,
4325 pool: &ConnectionPool,
4326) -> proto::UpdateContacts {
4327 let mut update = proto::UpdateContacts::default();
4328
4329 for contact in contacts {
4330 match contact {
4331 db::Contact::Accepted { user_id, busy } => {
4332 update.contacts.push(contact_for_user(user_id, busy, pool));
4333 }
4334 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4335 db::Contact::Incoming { user_id } => {
4336 update
4337 .incoming_requests
4338 .push(proto::IncomingContactRequest {
4339 requester_id: user_id.to_proto(),
4340 })
4341 }
4342 }
4343 }
4344
4345 update
4346}
4347
4348fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4349 proto::Contact {
4350 user_id: user_id.to_proto(),
4351 online: pool.is_user_online(user_id),
4352 busy,
4353 }
4354}
4355
4356fn room_updated(room: &proto::Room, peer: &Peer) {
4357 broadcast(
4358 None,
4359 room.participants
4360 .iter()
4361 .filter_map(|participant| Some(participant.peer_id?.into())),
4362 |peer_id| {
4363 peer.send(
4364 peer_id,
4365 proto::RoomUpdated {
4366 room: Some(room.clone()),
4367 },
4368 )
4369 },
4370 );
4371}
4372
4373fn channel_updated(
4374 channel: &db::channel::Model,
4375 room: &proto::Room,
4376 peer: &Peer,
4377 pool: &ConnectionPool,
4378) {
4379 let participants = room
4380 .participants
4381 .iter()
4382 .map(|p| p.user_id)
4383 .collect::<Vec<_>>();
4384
4385 broadcast(
4386 None,
4387 pool.channel_connection_ids(channel.root_id())
4388 .filter_map(|(channel_id, role)| {
4389 role.can_see_channel(channel.visibility)
4390 .then_some(channel_id)
4391 }),
4392 |peer_id| {
4393 peer.send(
4394 peer_id,
4395 proto::UpdateChannels {
4396 channel_participants: vec![proto::ChannelParticipants {
4397 channel_id: channel.id.to_proto(),
4398 participant_user_ids: participants.clone(),
4399 }],
4400 ..Default::default()
4401 },
4402 )
4403 },
4404 );
4405}
4406
4407async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4408 let db = session.db().await;
4409
4410 let contacts = db.get_contacts(user_id).await?;
4411 let busy = db.is_user_busy(user_id).await?;
4412
4413 let pool = session.connection_pool().await;
4414 let updated_contact = contact_for_user(user_id, busy, &pool);
4415 for contact in contacts {
4416 if let db::Contact::Accepted {
4417 user_id: contact_user_id,
4418 ..
4419 } = contact
4420 {
4421 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4422 session
4423 .peer
4424 .send(
4425 contact_conn_id,
4426 proto::UpdateContacts {
4427 contacts: vec![updated_contact.clone()],
4428 remove_contacts: Default::default(),
4429 incoming_requests: Default::default(),
4430 remove_incoming_requests: Default::default(),
4431 outgoing_requests: Default::default(),
4432 remove_outgoing_requests: Default::default(),
4433 },
4434 )
4435 .trace_err();
4436 }
4437 }
4438 }
4439 Ok(())
4440}
4441
4442async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
4443 let mut contacts_to_update = HashSet::default();
4444
4445 let room_id;
4446 let canceled_calls_to_user_ids;
4447 let livekit_room;
4448 let delete_livekit_room;
4449 let room;
4450 let channel;
4451
4452 if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
4453 contacts_to_update.insert(session.user_id());
4454
4455 for project in left_room.left_projects.values() {
4456 project_left(project, session);
4457 }
4458
4459 room_id = RoomId::from_proto(left_room.room.id);
4460 canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
4461 livekit_room = mem::take(&mut left_room.room.livekit_room);
4462 delete_livekit_room = left_room.deleted;
4463 room = mem::take(&mut left_room.room);
4464 channel = mem::take(&mut left_room.channel);
4465
4466 room_updated(&room, &session.peer);
4467 } else {
4468 return Ok(());
4469 }
4470
4471 if let Some(channel) = channel {
4472 channel_updated(
4473 &channel,
4474 &room,
4475 &session.peer,
4476 &*session.connection_pool().await,
4477 );
4478 }
4479
4480 {
4481 let pool = session.connection_pool().await;
4482 for canceled_user_id in canceled_calls_to_user_ids {
4483 for connection_id in pool.user_connection_ids(canceled_user_id) {
4484 session
4485 .peer
4486 .send(
4487 connection_id,
4488 proto::CallCanceled {
4489 room_id: room_id.to_proto(),
4490 },
4491 )
4492 .trace_err();
4493 }
4494 contacts_to_update.insert(canceled_user_id);
4495 }
4496 }
4497
4498 for contact_user_id in contacts_to_update {
4499 update_user_contacts(contact_user_id, session).await?;
4500 }
4501
4502 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
4503 live_kit
4504 .remove_participant(livekit_room.clone(), session.user_id().to_string())
4505 .await
4506 .trace_err();
4507
4508 if delete_livekit_room {
4509 live_kit.delete_room(livekit_room).await.trace_err();
4510 }
4511 }
4512
4513 Ok(())
4514}
4515
4516async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4517 let left_channel_buffers = session
4518 .db()
4519 .await
4520 .leave_channel_buffers(session.connection_id)
4521 .await?;
4522
4523 for left_buffer in left_channel_buffers {
4524 channel_buffer_updated(
4525 session.connection_id,
4526 left_buffer.connections,
4527 &proto::UpdateChannelBufferCollaborators {
4528 channel_id: left_buffer.channel_id.to_proto(),
4529 collaborators: left_buffer.collaborators,
4530 },
4531 &session.peer,
4532 );
4533 }
4534
4535 Ok(())
4536}
4537
4538fn project_left(project: &db::LeftProject, session: &Session) {
4539 for connection_id in &project.connection_ids {
4540 if project.should_unshare {
4541 session
4542 .peer
4543 .send(
4544 *connection_id,
4545 proto::UnshareProject {
4546 project_id: project.id.to_proto(),
4547 },
4548 )
4549 .trace_err();
4550 } else {
4551 session
4552 .peer
4553 .send(
4554 *connection_id,
4555 proto::RemoveProjectCollaborator {
4556 project_id: project.id.to_proto(),
4557 peer_id: Some(session.connection_id.into()),
4558 },
4559 )
4560 .trace_err();
4561 }
4562 }
4563}
4564
4565pub trait ResultExt {
4566 type Ok;
4567
4568 fn trace_err(self) -> Option<Self::Ok>;
4569}
4570
4571impl<T, E> ResultExt for Result<T, E>
4572where
4573 E: std::fmt::Debug,
4574{
4575 type Ok = T;
4576
4577 #[track_caller]
4578 fn trace_err(self) -> Option<T> {
4579 match self {
4580 Ok(value) => Some(value),
4581 Err(error) => {
4582 tracing::error!("{:?}", error);
4583 None
4584 }
4585 }
4586 }
4587}