1mod connection_pool;
2
3use crate::api::billing::find_or_create_billing_customer;
4use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
5use crate::db::billing_subscription::SubscriptionKind;
6use crate::llm::db::LlmDatabase;
7use crate::llm::{
8 AGENT_EXTENDED_TRIAL_FEATURE_FLAG, BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG, LlmTokenClaims,
9 MIN_ACCOUNT_AGE_FOR_LLM_USE,
10};
11use crate::stripe_client::StripeCustomerId;
12use crate::{
13 AppState, Error, Result, auth,
14 db::{
15 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
16 CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
17 NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
18 RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
19 },
20 executor::Executor,
21};
22use anyhow::{Context as _, anyhow, bail};
23use async_tungstenite::tungstenite::{
24 Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
25};
26use axum::{
27 Extension, Router, TypedHeader,
28 body::Body,
29 extract::{
30 ConnectInfo, WebSocketUpgrade,
31 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
32 },
33 headers::{Header, HeaderName},
34 http::StatusCode,
35 middleware,
36 response::IntoResponse,
37 routing::get,
38};
39use chrono::Utc;
40use collections::{HashMap, HashSet};
41pub use connection_pool::{ConnectionPool, ZedVersion};
42use core::fmt::{self, Debug, Formatter};
43use reqwest_client::ReqwestClient;
44use rpc::proto::split_repository_update;
45use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
46
47use futures::{
48 FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
49 stream::FuturesUnordered,
50};
51use prometheus::{IntGauge, register_int_gauge};
52use rpc::{
53 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
54 proto::{
55 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
56 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
57 },
58};
59use semantic_version::SemanticVersion;
60use serde::{Serialize, Serializer};
61use std::{
62 any::TypeId,
63 future::Future,
64 marker::PhantomData,
65 mem,
66 net::SocketAddr,
67 ops::{Deref, DerefMut},
68 rc::Rc,
69 sync::{
70 Arc, OnceLock,
71 atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
72 },
73 time::{Duration, Instant},
74};
75use time::OffsetDateTime;
76use tokio::sync::{Semaphore, watch};
77use tower::ServiceBuilder;
78use tracing::{
79 Instrument,
80 field::{self},
81 info_span, instrument,
82};
83
/// A 30-second timeout. It is not referenced in this part of the file; presumably the
/// window a disconnected client has to reconnect — TODO confirm at the use sites.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can clean
// up old resources. `Server::start` sleeps for this long before running its stale-resource cleanup.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// NOTE(review): the three limits below are not used in the visible part of this file; judging
// by their names they bound channel-message pages, message length and notification pages —
// confirm against the handlers that read them.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
/// Upper bound enforced by `ConnectionGuard::try_acquire` on `CONCURRENT_CONNECTIONS`.
const MAX_CONCURRENT_CONNECTIONS: usize = 512;

/// Number of live `ConnectionGuard`s: incremented in `ConnectionGuard::try_acquire` and
/// decremented when a guard is dropped.
static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);
95
/// Type-erased message handler as stored in `Server::handlers` (keyed by the message's
/// `TypeId`): it receives the boxed envelope together with the `Session` of the connection
/// it arrived on and returns a boxed `'static` future that resolves to `()`.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
98
99pub struct ConnectionGuard;
100
101impl ConnectionGuard {
102 pub fn try_acquire() -> Result<Self, ()> {
103 let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
104 if current_connections >= MAX_CONCURRENT_CONNECTIONS {
105 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
106 tracing::error!(
107 "too many concurrent connections: {}",
108 current_connections + 1
109 );
110 return Err(());
111 }
112 Ok(ConnectionGuard)
113 }
114}
115
116impl Drop for ConnectionGuard {
117 fn drop(&mut self) {
118 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
119 }
120}
121
122struct Response<R> {
123 peer: Arc<Peer>,
124 receipt: Receipt<R>,
125 responded: Arc<AtomicBool>,
126}
127
128impl<R: RequestMessage> Response<R> {
129 fn send(self, payload: R::Response) -> Result<()> {
130 self.responded.store(true, SeqCst);
131 self.peer.respond(self.receipt, payload)?;
132 Ok(())
133 }
134}
135
136#[derive(Clone, Debug)]
137pub enum Principal {
138 User(User),
139 Impersonated { user: User, admin: User },
140}
141
142impl Principal {
143 fn user(&self) -> &User {
144 match self {
145 Principal::User(user) => user,
146 Principal::Impersonated { user, .. } => user,
147 }
148 }
149
150 fn update_span(&self, span: &tracing::Span) {
151 match &self {
152 Principal::User(user) => {
153 span.record("user_id", user.id.0);
154 span.record("login", &user.github_login);
155 }
156 Principal::Impersonated { user, admin } => {
157 span.record("user_id", user.id.0);
158 span.record("login", &user.github_login);
159 span.record("impersonator", &admin.github_login);
160 }
161 }
162 }
163}
164
/// State of one client connection. `Server::handle_connection` builds one `Session` per
/// connection and hands a clone of it to every message handler it invokes.
#[derive(Clone)]
struct Session {
    /// Who the connection acts as (a plain user, or a user impersonated by an admin).
    principal: Principal,
    /// Id the peer assigned to this connection.
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; access it through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    /// The server's peer, used to send messages and responses.
    peer: Arc<Peer>,
    /// The server's shared connection pool; access it through `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// Shared application state.
    app_state: Arc<AppState>,
    /// Supermaven admin client; `None` when no `supermaven_admin_api_key` is configured.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    /// The GeoIP country code for the user.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    /// System id supplied by the caller of `handle_connection`, if any.
    system_id: Option<String>,
    /// Executor the connection was started with (kept but not read in the visible code).
    _executor: Executor,
}
180
181impl Session {
182 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
183 #[cfg(test)]
184 tokio::task::yield_now().await;
185 let guard = self.db.lock().await;
186 #[cfg(test)]
187 tokio::task::yield_now().await;
188 guard
189 }
190
191 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
192 #[cfg(test)]
193 tokio::task::yield_now().await;
194 let guard = self.connection_pool.lock();
195 ConnectionPoolGuard {
196 guard,
197 _not_send: PhantomData,
198 }
199 }
200
201 fn is_staff(&self) -> bool {
202 match &self.principal {
203 Principal::User(user) => user.admin,
204 Principal::Impersonated { .. } => true,
205 }
206 }
207
208 fn user_id(&self) -> UserId {
209 match &self.principal {
210 Principal::User(user) => user.id,
211 Principal::Impersonated { user, .. } => user.id,
212 }
213 }
214
215 pub fn email(&self) -> Option<String> {
216 match &self.principal {
217 Principal::User(user) => user.email_address.clone(),
218 Principal::Impersonated { user, .. } => user.email_address.clone(),
219 }
220 }
221}
222
223impl Debug for Session {
224 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
225 let mut result = f.debug_struct("Session");
226 match &self.principal {
227 Principal::User(user) => {
228 result.field("user", &user.github_login);
229 }
230 Principal::Impersonated { user, admin } => {
231 result.field("user", &user.github_login);
232 result.field("impersonator", &admin.github_login);
233 }
234 }
235 result.field("connection_id", &self.connection_id).finish()
236 }
237}
238
239struct DbHandle(Arc<Database>);
240
241impl Deref for DbHandle {
242 type Target = Database;
243
244 fn deref(&self) -> &Self::Target {
245 self.0.as_ref()
246 }
247}
248
/// The RPC server: owns the peer, the connection pool and the table of message handlers.
pub struct Server {
    /// Current server id; behind a mutex so the test-only `reset` can replace it.
    id: parking_lot::Mutex<ServerId>,
    /// Peer through which all connections send and receive messages.
    peer: Arc<Peer>,
    /// Pool of live connections, shared with every `Session`.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// Shared application state.
    app_state: Arc<AppState>,
    /// Registered handlers, keyed by the `TypeId` of the message type they handle.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Teardown signal: `teardown` sends `true`; each connection subscribes to it in
    /// `handle_connection`.
    teardown: watch::Sender<bool>,
}
257
/// Lock guard over the `ConnectionPool`, as returned by `Session::connection_pool`.
pub(crate) struct ConnectionPoolGuard<'a> {
    /// The underlying `parking_lot` mutex guard.
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    /// `Rc` is `!Send`, so this marker makes the whole guard `!Send`.
    _not_send: PhantomData<Rc<()>>,
}
262
/// Serializable view of the server: the peer plus the (locked) connection pool.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized through `serialize_deref`, i.e. the value the guard dereferences to is
    // written rather than the guard itself (the guard's `Deref` impl is outside this part
    // of the file).
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
269
270pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
271where
272 S: Serializer,
273 T: Deref<Target = U>,
274 U: Serialize,
275{
276 Serialize::serialize(value.deref(), serializer)
277}
278
279impl Server {
/// Creates a server with the given id and registers a handler for every message type it
/// serves.
///
/// Each `add_request_handler` / `add_message_handler` call stores its handler in
/// `self.handlers` under the message's `TypeId`; registering the same message type twice
/// panics (see `add_handler`). The peer is created from `id.0` cast to `u32`, and the
/// teardown channel starts out as `false`.
pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
    let mut server = Self {
        id: parking_lot::Mutex::new(id),
        peer: Peer::new(id.0 as u32),
        app_state: app_state.clone(),
        connection_pool: Default::default(),
        handlers: Default::default(),
        // Only the sender is kept; receivers are created on demand via `subscribe()`.
        teardown: watch::channel(false).0,
    };

    // NOTE(review): the `forward_read_only_project_request` / `forward_mutating_project_request`
    // helpers are defined outside this part of the file; which of the two a message uses
    // presumably decides the capability required of the requester — confirm there.
    server
        .add_request_handler(ping)
        .add_request_handler(create_room)
        .add_request_handler(join_room)
        .add_request_handler(rejoin_room)
        .add_request_handler(leave_room)
        .add_request_handler(set_room_participant_role)
        .add_request_handler(call)
        .add_request_handler(cancel_call)
        .add_message_handler(decline_call)
        .add_request_handler(update_participant_location)
        .add_request_handler(share_project)
        .add_message_handler(unshare_project)
        .add_request_handler(join_project)
        .add_message_handler(leave_project)
        .add_request_handler(update_project)
        .add_request_handler(update_worktree)
        .add_request_handler(update_repository)
        .add_request_handler(remove_repository)
        .add_message_handler(start_language_server)
        .add_message_handler(update_language_server)
        .add_message_handler(update_diagnostic_summary)
        .add_message_handler(update_worktree_settings)
        .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
        .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
        .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
        .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
        .add_request_handler(forward_find_search_candidates_request)
        .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
        .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
        .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
        .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
        .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
        .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
        .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
        .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
        .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
        .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
        .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
        .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
        .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
        .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
        .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
        .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
        .add_request_handler(
            forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
        )
        .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
        .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
        .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
        .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
        .add_request_handler(
            forward_read_only_project_request::<proto::LanguageServerIdForName>,
        )
        .add_request_handler(forward_read_only_project_request::<proto::GetDocumentDiagnostics>)
        .add_request_handler(
            forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
        )
        .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
        .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
        .add_request_handler(
            forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
        )
        .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
        .add_request_handler(
            forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
        )
        .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
        .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
        .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
        .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
        .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
        .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
        .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
        .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
        .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
        .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
        .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
        .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
        .add_request_handler(
            forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
        )
        .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
        .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
        .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
        .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
        .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
        .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
        .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
        .add_message_handler(create_buffer_for_peer)
        .add_request_handler(update_buffer)
        .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
        .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
        .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
        .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
        .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
        .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
        .add_message_handler(
            broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
        )
        .add_request_handler(get_users)
        .add_request_handler(fuzzy_search_users)
        .add_request_handler(request_contact)
        .add_request_handler(remove_contact)
        .add_request_handler(respond_to_contact_request)
        .add_message_handler(subscribe_to_channels)
        .add_request_handler(create_channel)
        .add_request_handler(delete_channel)
        .add_request_handler(invite_channel_member)
        .add_request_handler(remove_channel_member)
        .add_request_handler(set_channel_member_role)
        .add_request_handler(set_channel_visibility)
        .add_request_handler(rename_channel)
        .add_request_handler(join_channel_buffer)
        .add_request_handler(leave_channel_buffer)
        .add_message_handler(update_channel_buffer)
        .add_request_handler(rejoin_channel_buffers)
        .add_request_handler(get_channel_members)
        .add_request_handler(respond_to_channel_invite)
        .add_request_handler(join_channel)
        .add_request_handler(join_channel_chat)
        .add_message_handler(leave_channel_chat)
        .add_request_handler(send_channel_message)
        .add_request_handler(remove_channel_message)
        .add_request_handler(update_channel_message)
        .add_request_handler(get_channel_messages)
        .add_request_handler(get_channel_messages_by_id)
        .add_request_handler(get_notifications)
        .add_request_handler(mark_notification_as_read)
        .add_request_handler(move_channel)
        .add_request_handler(reorder_channel)
        .add_request_handler(follow)
        .add_message_handler(unfollow)
        .add_message_handler(update_followers)
        .add_request_handler(get_private_user_info)
        .add_request_handler(get_llm_api_token)
        .add_request_handler(accept_terms_of_service)
        .add_message_handler(acknowledge_channel_message)
        .add_message_handler(acknowledge_buffer_version)
        .add_request_handler(get_supermaven_api_key)
        .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
        .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
        .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
        .add_request_handler(forward_mutating_project_request::<proto::Stage>)
        .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
        .add_request_handler(forward_mutating_project_request::<proto::Commit>)
        .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
        .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
        .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
        .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
        .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
        .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
        .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
        .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
        .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
        .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
        .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
        .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
        .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
        .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
        .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
        .add_message_handler(update_context);

    Arc::new(server)
}
455
/// Spawns a detached background task that waits `CLEANUP_TIMEOUT` and then cleans up stale
/// rooms, channel buffers, chat participants, worktree entries and servers in the database,
/// notifying affected connections along the way.
///
/// The function itself returns `Ok(())` as soon as the task is spawned. Inside the task every
/// fallible step goes through `trace_err()` (which yields an `Option`) instead of being
/// propagated, so one failing step does not stop the remaining cleanup.
pub async fn start(&self) -> Result<()> {
    let server_id = *self.id.lock();
    let app_state = self.app_state.clone();
    let peer = self.peer.clone();
    // The sleep future is created here, before the task is spawned.
    let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
    let pool = self.connection_pool.clone();
    let livekit_client = self.app_state.livekit_client.clone();

    let span = info_span!("start server");
    self.app_state.executor.spawn_detached(
        async move {
            tracing::info!("waiting for cleanup timeout");
            timeout.await;
            tracing::info!("cleanup timeout expired, retrieving stale rooms");

            app_state
                .db
                .delete_stale_channel_chat_participants(
                    &app_state.config.zed_environment,
                    server_id,
                )
                .await
                .trace_err();

            if let Some((room_ids, channel_ids)) = app_state
                .db
                .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                .await
                .trace_err()
            {
                tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                tracing::info!(
                    stale_channel_buffer_count = channel_ids.len(),
                    "retrieved stale channel buffers"
                );

                // Channel buffers: clear stale collaborators and send the refreshed
                // collaborator list to each connection id the database returns.
                for channel_id in channel_ids {
                    if let Some(refreshed_channel_buffer) = app_state
                        .db
                        .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                        .await
                        .trace_err()
                    {
                        for connection_id in refreshed_channel_buffer.connection_ids {
                            peer.send(
                                connection_id,
                                proto::UpdateChannelBufferCollaborators {
                                    channel_id: channel_id.to_proto(),
                                    collaborators: refreshed_channel_buffer
                                        .collaborators
                                        .clone(),
                                },
                            )
                            .trace_err();
                        }
                    }
                }

                // Rooms: clear stale participants, then notify about the room, canceled
                // calls and changed contacts, and finally delete empty LiveKit rooms.
                for room_id in room_ids {
                    // These keep their defaults when clearing the room fails, which turns
                    // the follow-up steps below into no-ops for this room.
                    let mut contacts_to_update = HashSet::default();
                    let mut canceled_calls_to_user_ids = Vec::new();
                    let mut livekit_room = String::new();
                    let mut delete_livekit_room = false;

                    if let Some(mut refreshed_room) = app_state
                        .db
                        .clear_stale_room_participants(room_id, server_id)
                        .await
                        .trace_err()
                    {
                        tracing::info!(
                            room_id = room_id.0,
                            new_participant_count = refreshed_room.room.participants.len(),
                            "refreshed room"
                        );
                        room_updated(&refreshed_room.room, &peer);
                        if let Some(channel) = refreshed_room.channel.as_ref() {
                            channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                        }
                        contacts_to_update
                            .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                        contacts_to_update
                            .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                        canceled_calls_to_user_ids =
                            mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                        livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                        // The LiveKit room is deleted only when nobody is left in the room.
                        delete_livekit_room = refreshed_room.room.participants.is_empty();
                    }

                    // The pool lock is confined to this block so it is released before the
                    // next `.await`.
                    {
                        let pool = pool.lock();
                        for canceled_user_id in canceled_calls_to_user_ids {
                            for connection_id in pool.user_connection_ids(canceled_user_id) {
                                peer.send(
                                    connection_id,
                                    proto::CallCanceled {
                                        room_id: room_id.to_proto(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for user_id in contacts_to_update {
                        let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                        let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                        // Both lookups must succeed; otherwise this user is skipped.
                        if let Some((busy, contacts)) = busy.zip(contacts) {
                            let pool = pool.lock();
                            let updated_contact = contact_for_user(user_id, busy, &pool);
                            for contact in contacts {
                                // Only accepted contacts are sent the update.
                                if let db::Contact::Accepted {
                                    user_id: contact_user_id,
                                    ..
                                } = contact
                                {
                                    for contact_conn_id in
                                        pool.user_connection_ids(contact_user_id)
                                    {
                                        peer.send(
                                            contact_conn_id,
                                            proto::UpdateContacts {
                                                contacts: vec![updated_contact.clone()],
                                                remove_contacts: Default::default(),
                                                incoming_requests: Default::default(),
                                                remove_incoming_requests: Default::default(),
                                                outgoing_requests: Default::default(),
                                                remove_outgoing_requests: Default::default(),
                                            },
                                        )
                                        .trace_err();
                                    }
                                }
                            }
                        }
                    }

                    if let Some(live_kit) = livekit_client.as_ref() {
                        if delete_livekit_room {
                            live_kit.delete_room(livekit_room).await.trace_err();
                        }
                    }
                }
            }

            // Stale chat participants are deleted a second time here, after the room and
            // channel cleanup above.
            app_state
                .db
                .delete_stale_channel_chat_participants(
                    &app_state.config.zed_environment,
                    server_id,
                )
                .await
                .trace_err();

            app_state
                .db
                .clear_old_worktree_entries(server_id)
                .await
                .trace_err();

            app_state
                .db
                .delete_stale_servers(&app_state.config.zed_environment, server_id)
                .await
                .trace_err();
        }
        .instrument(span),
    );
    Ok(())
}
626
/// Tears down the peer, resets the connection pool and then publishes `true` on the teardown
/// channel, which `handle_connection` watches.
pub fn teardown(&self) {
    self.peer.teardown();
    self.connection_pool.lock().reset();
    // The send result is deliberately ignored.
    let _ = self.teardown.send(true);
}
632
/// Test-only: tears the server down, gives it the new `id`, resets the peer with that id
/// (cast to `u32`) and finally publishes `false` on the teardown channel again.
#[cfg(test)]
pub fn reset(&self, id: ServerId) {
    self.teardown();
    *self.id.lock() = id;
    self.peer.reset(id.0 as u32);
    // The send result is deliberately ignored.
    let _ = self.teardown.send(false);
}
640
/// Test-only: returns a copy of the server's current id.
#[cfg(test)]
pub fn id(&self) -> ServerId {
    let current_id = self.id.lock();
    *current_id
}
645
646 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
647 where
648 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
649 Fut: 'static + Send + Future<Output = Result<()>>,
650 M: EnvelopedMessage,
651 {
652 let prev_handler = self.handlers.insert(
653 TypeId::of::<M>(),
654 Box::new(move |envelope, session| {
655 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
656 let received_at = envelope.received_at;
657 tracing::info!("message received");
658 let start_time = Instant::now();
659 let future = (handler)(*envelope, session);
660 async move {
661 let result = future.await;
662 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
663 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
664 let queue_duration_ms = total_duration_ms - processing_duration_ms;
665 let payload_type = M::NAME;
666
667 match result {
668 Err(error) => {
669 tracing::error!(
670 ?error,
671 total_duration_ms,
672 processing_duration_ms,
673 queue_duration_ms,
674 payload_type,
675 "error handling message"
676 )
677 }
678 Ok(()) => tracing::info!(
679 total_duration_ms,
680 processing_duration_ms,
681 queue_duration_ms,
682 "finished handling message"
683 ),
684 }
685 }
686 .boxed()
687 }),
688 );
689 if prev_handler.is_some() {
690 panic!("registered a handler for the same message twice");
691 }
692 self
693 }
694
695 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
696 where
697 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
698 Fut: 'static + Send + Future<Output = Result<()>>,
699 M: EnvelopedMessage,
700 {
701 self.add_handler(move |envelope, session| handler(envelope.payload, session));
702 self
703 }
704
705 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
706 where
707 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
708 Fut: Send + Future<Output = Result<()>>,
709 M: RequestMessage,
710 {
711 let handler = Arc::new(handler);
712 self.add_handler(move |envelope, session| {
713 let receipt = envelope.receipt();
714 let handler = handler.clone();
715 async move {
716 let peer = session.peer.clone();
717 let responded = Arc::new(AtomicBool::default());
718 let response = Response {
719 peer: peer.clone(),
720 responded: responded.clone(),
721 receipt,
722 };
723 match (handler)(envelope.payload, response, session).await {
724 Ok(()) => {
725 if responded.load(std::sync::atomic::Ordering::SeqCst) {
726 Ok(())
727 } else {
728 Err(anyhow!("handler did not send a response"))?
729 }
730 }
731 Err(error) => {
732 let proto_err = match &error {
733 Error::Internal(err) => err.to_proto(),
734 _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
735 };
736 peer.respond_with_error(receipt, proto_err)?;
737 Err(error)
738 }
739 }
740 }
741 })
742 }
743
/// Returns a future that serves one client connection from start to finish.
///
/// The future:
/// 1. returns immediately if the server is already tearing down;
/// 2. registers the connection with the peer and builds the connection's `Session`;
/// 3. sends the initial client update (see `send_initial_client_update`) and, once that
///    succeeded, drops `connection_guard`;
/// 4. runs the message loop until I/O finishes, the incoming stream ends, or teardown is
///    signalled;
/// 5. unless teardown was signalled, calls `connection_lost` to sign the connection out.
///
/// `send_connection_id`, when provided, receives the id assigned to the connection.
pub fn handle_connection(
    self: &Arc<Self>,
    connection: Connection,
    address: String,
    principal: Principal,
    zed_version: ZedVersion,
    geoip_country_code: Option<String>,
    system_id: Option<String>,
    send_connection_id: Option<oneshot::Sender<ConnectionId>>,
    executor: Executor,
    connection_guard: Option<ConnectionGuard>,
) -> impl Future<Output = ()> + use<> {
    let this = self.clone();
    // Fields declared `Empty` here are filled in below (and `connection_id` once it is known).
    let span = info_span!("handle connection", %address,
        connection_id=field::Empty,
        user_id=field::Empty,
        login=field::Empty,
        impersonator=field::Empty,
        geoip_country_code=field::Empty
    );
    principal.update_span(&span);
    if let Some(country_code) = geoip_country_code.as_ref() {
        span.record("geoip_country_code", country_code);
    }

    let mut teardown = self.teardown.subscribe();
    async move {
        if *teardown.borrow() {
            tracing::error!("server is tearing down");
            return
        }

        let (connection_id, handle_io, mut incoming_rx) = this
            .peer
            .add_connection(connection, {
                let executor = executor.clone();
                move |duration| executor.sleep(duration)
            });
        tracing::Span::current().record("connection_id", format!("{}", connection_id));

        tracing::info!("connection opened");

        let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
        let http_client = match ReqwestClient::user_agent(&user_agent) {
            Ok(http_client) => Arc::new(http_client),
            Err(error) => {
                tracing::error!(?error, "failed to create HTTP client");
                return;
            }
        };

        // Only present when a Supermaven admin API key is configured.
        let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
            supermaven_admin_api_key.to_string(),
            http_client.clone(),
        )));

        let session = Session {
            principal: principal.clone(),
            connection_id,
            db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
            peer: this.peer.clone(),
            connection_pool: this.connection_pool.clone(),
            app_state: this.app_state.clone(),
            geoip_country_code,
            system_id,
            _executor: executor.clone(),
            supermaven_client,
        };

        if let Err(error) = this.send_initial_client_update(connection_id, zed_version, send_connection_id, &session).await {
            tracing::error!(?error, "failed to send initial client update");
            return;
        }
        // The guard (if any) is held only until the initial update has been sent; dropping
        // it decrements the concurrent-connection counter.
        drop(connection_guard);

        let handle_io = handle_io.fuse();
        futures::pin_mut!(handle_io);

        // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
        // This prevents deadlocks when e.g., client A performs a request to client B and
        // client B performs a request to client A. If both clients stop processing further
        // messages until their respective request completes, they won't have a chance to
        // respond to the other client's request and cause a deadlock.
        //
        // This arrangement ensures we will attempt to process earlier messages first, but fall
        // back to processing messages arrived later in the spirit of making progress.
        let mut foreground_message_handlers = FuturesUnordered::new();
        // At most 256 handlers per connection run at once: a permit is acquired before the
        // next message is read and released when its handler finishes.
        let concurrent_handlers = Arc::new(Semaphore::new(256));
        loop {
            let next_message = async {
                let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                let message = incoming_rx.next().await;
                (permit, message)
            }.fuse();
            futures::pin_mut!(next_message);
            // `select_biased!` polls the branches in the order written: teardown, then I/O,
            // then running foreground handlers, then the next incoming message.
            futures::select_biased! {
                // Teardown returns directly, skipping the sign-out below.
                _ = teardown.changed().fuse() => return,
                result = handle_io => {
                    if let Err(error) = result {
                        tracing::error!(?error, "error handling I/O");
                    }
                    break;
                }
                _ = foreground_message_handlers.next() => {}
                next_message = next_message => {
                    let (permit, message) = next_message;
                    if let Some(message) = message {
                        let type_name = message.payload_type_name();
                        // note: we copy all the fields from the parent span so we can query them in the logs.
                        // (https://github.com/tokio-rs/tracing/issues/2670).
                        let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                            user_id=field::Empty,
                            login=field::Empty,
                            impersonator=field::Empty,
                        );
                        principal.update_span(&span);
                        let span_enter = span.enter();
                        if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                            let is_background = message.is_background();
                            let handle_message = (handler)(message, session.clone());
                            // Leave the span before the future is moved; the future itself
                            // is instrumented with the span below.
                            drop(span_enter);

                            let handle_message = async move {
                                handle_message.await;
                                drop(permit);
                            }.instrument(span);
                            // Background messages run as detached tasks; foreground ones are
                            // driven by this loop through `foreground_message_handlers`.
                            if is_background {
                                executor.spawn_detached(handle_message);
                            } else {
                                foreground_message_handlers.push(handle_message);
                            }
                        } else {
                            tracing::error!("no message handler");
                        }
                    } else {
                        tracing::info!("connection closed");
                        break;
                    }
                }
            }
        }

        // Foreground handlers that have not finished yet are dropped here.
        drop(foreground_message_handlers);
        tracing::info!("signing out");
        if let Err(error) = connection_lost(session, teardown, executor).await {
            tracing::error!(?error, "error signing out");
        }

    }.instrument(span)
}
894
895 async fn send_initial_client_update(
896 &self,
897 connection_id: ConnectionId,
898 zed_version: ZedVersion,
899 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
900 session: &Session,
901 ) -> Result<()> {
902 self.peer.send(
903 connection_id,
904 proto::Hello {
905 peer_id: Some(connection_id.into()),
906 },
907 )?;
908 tracing::info!("sent hello message");
909 if let Some(send_connection_id) = send_connection_id.take() {
910 let _ = send_connection_id.send(connection_id);
911 }
912
913 match &session.principal {
914 Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
915 if !user.connected_once {
916 self.peer.send(connection_id, proto::ShowContacts {})?;
917 self.app_state
918 .db
919 .set_user_connected_once(user.id, true)
920 .await?;
921 }
922
923 update_user_plan(session).await?;
924
925 let contacts = self.app_state.db.get_contacts(user.id).await?;
926
927 {
928 let mut pool = self.connection_pool.lock();
929 pool.add_connection(connection_id, user.id, user.admin, zed_version);
930 self.peer.send(
931 connection_id,
932 build_initial_contacts_update(contacts, &pool),
933 )?;
934 }
935
936 if should_auto_subscribe_to_channels(zed_version) {
937 subscribe_user_to_channels(user.id, session).await?;
938 }
939
940 if let Some(incoming_call) =
941 self.app_state.db.incoming_call_for_user(user.id).await?
942 {
943 self.peer.send(connection_id, incoming_call)?;
944 }
945
946 update_user_contacts(user.id, session).await?;
947 }
948 }
949
950 Ok(())
951 }
952
953 pub async fn invite_code_redeemed(
954 self: &Arc<Self>,
955 inviter_id: UserId,
956 invitee_id: UserId,
957 ) -> Result<()> {
958 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
959 if let Some(code) = &user.invite_code {
960 let pool = self.connection_pool.lock();
961 let invitee_contact = contact_for_user(invitee_id, false, &pool);
962 for connection_id in pool.user_connection_ids(inviter_id) {
963 self.peer.send(
964 connection_id,
965 proto::UpdateContacts {
966 contacts: vec![invitee_contact.clone()],
967 ..Default::default()
968 },
969 )?;
970 self.peer.send(
971 connection_id,
972 proto::UpdateInviteInfo {
973 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
974 count: user.invite_count as u32,
975 },
976 )?;
977 }
978 }
979 }
980 Ok(())
981 }
982
983 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
984 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
985 if let Some(invite_code) = &user.invite_code {
986 let pool = self.connection_pool.lock();
987 for connection_id in pool.user_connection_ids(user_id) {
988 self.peer.send(
989 connection_id,
990 proto::UpdateInviteInfo {
991 url: format!(
992 "{}{}",
993 self.app_state.config.invite_link_prefix, invite_code
994 ),
995 count: user.invite_count as u32,
996 },
997 )?;
998 }
999 }
1000 }
1001 Ok(())
1002 }
1003
1004 pub async fn update_plan_for_user(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1005 let user = self
1006 .app_state
1007 .db
1008 .get_user_by_id(user_id)
1009 .await?
1010 .context("user not found")?;
1011
1012 let update_user_plan = make_update_user_plan_message(
1013 &user,
1014 user.admin,
1015 &self.app_state.db,
1016 self.app_state.llm_db.clone(),
1017 )
1018 .await?;
1019
1020 let pool = self.connection_pool.lock();
1021 for connection_id in pool.user_connection_ids(user_id) {
1022 self.peer
1023 .send(connection_id, update_user_plan.clone())
1024 .trace_err();
1025 }
1026
1027 Ok(())
1028 }
1029
1030 pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
1031 let pool = self.connection_pool.lock();
1032 for connection_id in pool.user_connection_ids(user_id) {
1033 self.peer
1034 .send(connection_id, proto::RefreshLlmToken {})
1035 .trace_err();
1036 }
1037 }
1038
1039 pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot {
1040 ServerSnapshot {
1041 connection_pool: ConnectionPoolGuard {
1042 guard: self.connection_pool.lock(),
1043 _not_send: PhantomData,
1044 },
1045 peer: &self.peer,
1046 }
1047 }
1048}
1049
1050impl Deref for ConnectionPoolGuard<'_> {
1051 type Target = ConnectionPool;
1052
1053 fn deref(&self) -> &Self::Target {
1054 &self.guard
1055 }
1056}
1057
1058impl DerefMut for ConnectionPoolGuard<'_> {
1059 fn deref_mut(&mut self) -> &mut Self::Target {
1060 &mut self.guard
1061 }
1062}
1063
1064impl Drop for ConnectionPoolGuard<'_> {
1065 fn drop(&mut self) {
1066 #[cfg(test)]
1067 self.check_invariants();
1068 }
1069}
1070
1071fn broadcast<F>(
1072 sender_id: Option<ConnectionId>,
1073 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1074 mut f: F,
1075) where
1076 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1077{
1078 for receiver_id in receiver_ids {
1079 if Some(receiver_id) != sender_id {
1080 if let Err(error) = f(receiver_id) {
1081 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1082 }
1083 }
1084 }
1085}
1086
1087pub struct ProtocolVersion(u32);
1088
1089impl Header for ProtocolVersion {
1090 fn name() -> &'static HeaderName {
1091 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1092 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1093 }
1094
1095 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1096 where
1097 Self: Sized,
1098 I: Iterator<Item = &'i axum::http::HeaderValue>,
1099 {
1100 let version = values
1101 .next()
1102 .ok_or_else(axum::headers::Error::invalid)?
1103 .to_str()
1104 .map_err(|_| axum::headers::Error::invalid())?
1105 .parse()
1106 .map_err(|_| axum::headers::Error::invalid())?;
1107 Ok(Self(version))
1108 }
1109
1110 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1111 values.extend([self.0.to_string().parse().unwrap()]);
1112 }
1113}
1114
1115pub struct AppVersionHeader(SemanticVersion);
1116impl Header for AppVersionHeader {
1117 fn name() -> &'static HeaderName {
1118 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1119 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1120 }
1121
1122 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1123 where
1124 Self: Sized,
1125 I: Iterator<Item = &'i axum::http::HeaderValue>,
1126 {
1127 let version = values
1128 .next()
1129 .ok_or_else(axum::headers::Error::invalid)?
1130 .to_str()
1131 .map_err(|_| axum::headers::Error::invalid())?
1132 .parse()
1133 .map_err(|_| axum::headers::Error::invalid())?;
1134 Ok(Self(version))
1135 }
1136
1137 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1138 values.extend([self.0.to_string().parse().unwrap()]);
1139 }
1140}
1141
1142pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1143 Router::new()
1144 .route("/rpc", get(handle_websocket_request))
1145 .layer(
1146 ServiceBuilder::new()
1147 .layer(Extension(server.app_state.clone()))
1148 .layer(middleware::from_fn(auth::validate_header)),
1149 )
1150 .route("/metrics", get(handle_metrics))
1151 .layer(Extension(server))
1152}
1153
1154pub async fn handle_websocket_request(
1155 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1156 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1157 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1158 Extension(server): Extension<Arc<Server>>,
1159 Extension(principal): Extension<Principal>,
1160 country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1161 system_id_header: Option<TypedHeader<SystemIdHeader>>,
1162 ws: WebSocketUpgrade,
1163) -> axum::response::Response {
1164 if protocol_version != rpc::PROTOCOL_VERSION {
1165 return (
1166 StatusCode::UPGRADE_REQUIRED,
1167 "client must be upgraded".to_string(),
1168 )
1169 .into_response();
1170 }
1171
1172 let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1173 return (
1174 StatusCode::UPGRADE_REQUIRED,
1175 "no version header found".to_string(),
1176 )
1177 .into_response();
1178 };
1179
1180 if !version.can_collaborate() {
1181 return (
1182 StatusCode::UPGRADE_REQUIRED,
1183 "client must be upgraded".to_string(),
1184 )
1185 .into_response();
1186 }
1187
1188 let socket_address = socket_address.to_string();
1189
1190 // Acquire connection guard before WebSocket upgrade
1191 let connection_guard = match ConnectionGuard::try_acquire() {
1192 Ok(guard) => guard,
1193 Err(()) => {
1194 return (
1195 StatusCode::SERVICE_UNAVAILABLE,
1196 "Too many concurrent connections",
1197 )
1198 .into_response();
1199 }
1200 };
1201
1202 ws.on_upgrade(move |socket| {
1203 let socket = socket
1204 .map_ok(to_tungstenite_message)
1205 .err_into()
1206 .with(|message| async move { to_axum_message(message) });
1207 let connection = Connection::new(Box::pin(socket));
1208 async move {
1209 server
1210 .handle_connection(
1211 connection,
1212 socket_address,
1213 principal,
1214 version,
1215 country_code_header.map(|header| header.to_string()),
1216 system_id_header.map(|header| header.to_string()),
1217 None,
1218 Executor::Production,
1219 Some(connection_guard),
1220 )
1221 .await;
1222 }
1223 })
1224}
1225
1226pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1227 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1228 let connections_metric = CONNECTIONS_METRIC
1229 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1230
1231 let connections = server
1232 .connection_pool
1233 .lock()
1234 .connections()
1235 .filter(|connection| !connection.admin)
1236 .count();
1237 connections_metric.set(connections as _);
1238
1239 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1240 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1241 register_int_gauge!(
1242 "shared_projects",
1243 "number of open projects with one or more guests"
1244 )
1245 .unwrap()
1246 });
1247
1248 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1249 shared_projects_metric.set(shared_projects as _);
1250
1251 let encoder = prometheus::TextEncoder::new();
1252 let metric_families = prometheus::gather();
1253 let encoded_metrics = encoder
1254 .encode_to_string(&metric_families)
1255 .map_err(|err| anyhow!("{err}"))?;
1256 Ok(encoded_metrics)
1257}
1258
1259#[instrument(err, skip(executor))]
1260async fn connection_lost(
1261 session: Session,
1262 mut teardown: watch::Receiver<bool>,
1263 executor: Executor,
1264) -> Result<()> {
1265 session.peer.disconnect(session.connection_id);
1266 session
1267 .connection_pool()
1268 .await
1269 .remove_connection(session.connection_id)?;
1270
1271 session
1272 .db()
1273 .await
1274 .connection_lost(session.connection_id)
1275 .await
1276 .trace_err();
1277
1278 futures::select_biased! {
1279 _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1280
1281 log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1282 leave_room_for_session(&session, session.connection_id).await.trace_err();
1283 leave_channel_buffers_for_session(&session)
1284 .await
1285 .trace_err();
1286
1287 if !session
1288 .connection_pool()
1289 .await
1290 .is_user_online(session.user_id())
1291 {
1292 let db = session.db().await;
1293 if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1294 room_updated(&room, &session.peer);
1295 }
1296 }
1297
1298 update_user_contacts(session.user_id(), &session).await?;
1299 },
1300 _ = teardown.changed().fuse() => {}
1301 }
1302
1303 Ok(())
1304}
1305
1306/// Acknowledges a ping from a client, used to keep the connection alive.
1307async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1308 response.send(proto::Ack {})?;
1309 Ok(())
1310}
1311
1312/// Creates a new room for calling (outside of channels)
1313async fn create_room(
1314 _request: proto::CreateRoom,
1315 response: Response<proto::CreateRoom>,
1316 session: Session,
1317) -> Result<()> {
1318 let livekit_room = nanoid::nanoid!(30);
1319
1320 let live_kit_connection_info = util::maybe!(async {
1321 let live_kit = session.app_state.livekit_client.as_ref();
1322 let live_kit = live_kit?;
1323 let user_id = session.user_id().to_string();
1324
1325 let token = live_kit
1326 .room_token(&livekit_room, &user_id.to_string())
1327 .trace_err()?;
1328
1329 Some(proto::LiveKitConnectionInfo {
1330 server_url: live_kit.url().into(),
1331 token,
1332 can_publish: true,
1333 })
1334 })
1335 .await;
1336
1337 let room = session
1338 .db()
1339 .await
1340 .create_room(session.user_id(), session.connection_id, &livekit_room)
1341 .await?;
1342
1343 response.send(proto::CreateRoomResponse {
1344 room: Some(room.clone()),
1345 live_kit_connection_info,
1346 })?;
1347
1348 update_user_contacts(session.user_id(), &session).await?;
1349 Ok(())
1350}
1351
1352/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1353async fn join_room(
1354 request: proto::JoinRoom,
1355 response: Response<proto::JoinRoom>,
1356 session: Session,
1357) -> Result<()> {
1358 let room_id = RoomId::from_proto(request.id);
1359
1360 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1361
1362 if let Some(channel_id) = channel_id {
1363 return join_channel_internal(channel_id, Box::new(response), session).await;
1364 }
1365
1366 let joined_room = {
1367 let room = session
1368 .db()
1369 .await
1370 .join_room(room_id, session.user_id(), session.connection_id)
1371 .await?;
1372 room_updated(&room.room, &session.peer);
1373 room.into_inner()
1374 };
1375
1376 for connection_id in session
1377 .connection_pool()
1378 .await
1379 .user_connection_ids(session.user_id())
1380 {
1381 session
1382 .peer
1383 .send(
1384 connection_id,
1385 proto::CallCanceled {
1386 room_id: room_id.to_proto(),
1387 },
1388 )
1389 .trace_err();
1390 }
1391
1392 let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1393 {
1394 live_kit
1395 .room_token(
1396 &joined_room.room.livekit_room,
1397 &session.user_id().to_string(),
1398 )
1399 .trace_err()
1400 .map(|token| proto::LiveKitConnectionInfo {
1401 server_url: live_kit.url().into(),
1402 token,
1403 can_publish: true,
1404 })
1405 } else {
1406 None
1407 };
1408
1409 response.send(proto::JoinRoomResponse {
1410 room: Some(joined_room.room),
1411 channel_id: None,
1412 live_kit_connection_info,
1413 })?;
1414
1415 update_user_contacts(session.user_id(), &session).await?;
1416 Ok(())
1417}
1418
1419/// Rejoin room is used to reconnect to a room after connection errors.
1420async fn rejoin_room(
1421 request: proto::RejoinRoom,
1422 response: Response<proto::RejoinRoom>,
1423 session: Session,
1424) -> Result<()> {
1425 let room;
1426 let channel;
1427 {
1428 let mut rejoined_room = session
1429 .db()
1430 .await
1431 .rejoin_room(request, session.user_id(), session.connection_id)
1432 .await?;
1433
1434 response.send(proto::RejoinRoomResponse {
1435 room: Some(rejoined_room.room.clone()),
1436 reshared_projects: rejoined_room
1437 .reshared_projects
1438 .iter()
1439 .map(|project| proto::ResharedProject {
1440 id: project.id.to_proto(),
1441 collaborators: project
1442 .collaborators
1443 .iter()
1444 .map(|collaborator| collaborator.to_proto())
1445 .collect(),
1446 })
1447 .collect(),
1448 rejoined_projects: rejoined_room
1449 .rejoined_projects
1450 .iter()
1451 .map(|rejoined_project| rejoined_project.to_proto())
1452 .collect(),
1453 })?;
1454 room_updated(&rejoined_room.room, &session.peer);
1455
1456 for project in &rejoined_room.reshared_projects {
1457 for collaborator in &project.collaborators {
1458 session
1459 .peer
1460 .send(
1461 collaborator.connection_id,
1462 proto::UpdateProjectCollaborator {
1463 project_id: project.id.to_proto(),
1464 old_peer_id: Some(project.old_connection_id.into()),
1465 new_peer_id: Some(session.connection_id.into()),
1466 },
1467 )
1468 .trace_err();
1469 }
1470
1471 broadcast(
1472 Some(session.connection_id),
1473 project
1474 .collaborators
1475 .iter()
1476 .map(|collaborator| collaborator.connection_id),
1477 |connection_id| {
1478 session.peer.forward_send(
1479 session.connection_id,
1480 connection_id,
1481 proto::UpdateProject {
1482 project_id: project.id.to_proto(),
1483 worktrees: project.worktrees.clone(),
1484 },
1485 )
1486 },
1487 );
1488 }
1489
1490 notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;
1491
1492 let rejoined_room = rejoined_room.into_inner();
1493
1494 room = rejoined_room.room;
1495 channel = rejoined_room.channel;
1496 }
1497
1498 if let Some(channel) = channel {
1499 channel_updated(
1500 &channel,
1501 &room,
1502 &session.peer,
1503 &*session.connection_pool().await,
1504 );
1505 }
1506
1507 update_user_contacts(session.user_id(), &session).await?;
1508 Ok(())
1509}
1510
1511fn notify_rejoined_projects(
1512 rejoined_projects: &mut Vec<RejoinedProject>,
1513 session: &Session,
1514) -> Result<()> {
1515 for project in rejoined_projects.iter() {
1516 for collaborator in &project.collaborators {
1517 session
1518 .peer
1519 .send(
1520 collaborator.connection_id,
1521 proto::UpdateProjectCollaborator {
1522 project_id: project.id.to_proto(),
1523 old_peer_id: Some(project.old_connection_id.into()),
1524 new_peer_id: Some(session.connection_id.into()),
1525 },
1526 )
1527 .trace_err();
1528 }
1529 }
1530
1531 for project in rejoined_projects {
1532 for worktree in mem::take(&mut project.worktrees) {
1533 // Stream this worktree's entries.
1534 let message = proto::UpdateWorktree {
1535 project_id: project.id.to_proto(),
1536 worktree_id: worktree.id,
1537 abs_path: worktree.abs_path.clone(),
1538 root_name: worktree.root_name,
1539 updated_entries: worktree.updated_entries,
1540 removed_entries: worktree.removed_entries,
1541 scan_id: worktree.scan_id,
1542 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1543 updated_repositories: worktree.updated_repositories,
1544 removed_repositories: worktree.removed_repositories,
1545 };
1546 for update in proto::split_worktree_update(message) {
1547 session.peer.send(session.connection_id, update)?;
1548 }
1549
1550 // Stream this worktree's diagnostics.
1551 for summary in worktree.diagnostic_summaries {
1552 session.peer.send(
1553 session.connection_id,
1554 proto::UpdateDiagnosticSummary {
1555 project_id: project.id.to_proto(),
1556 worktree_id: worktree.id,
1557 summary: Some(summary),
1558 },
1559 )?;
1560 }
1561
1562 for settings_file in worktree.settings_files {
1563 session.peer.send(
1564 session.connection_id,
1565 proto::UpdateWorktreeSettings {
1566 project_id: project.id.to_proto(),
1567 worktree_id: worktree.id,
1568 path: settings_file.path,
1569 content: Some(settings_file.content),
1570 kind: Some(settings_file.kind.to_proto().into()),
1571 },
1572 )?;
1573 }
1574 }
1575
1576 for repository in mem::take(&mut project.updated_repositories) {
1577 for update in split_repository_update(repository) {
1578 session.peer.send(session.connection_id, update)?;
1579 }
1580 }
1581
1582 for id in mem::take(&mut project.removed_repositories) {
1583 session.peer.send(
1584 session.connection_id,
1585 proto::RemoveRepository {
1586 project_id: project.id.to_proto(),
1587 id,
1588 },
1589 )?;
1590 }
1591 }
1592
1593 Ok(())
1594}
1595
1596/// leave room disconnects from the room.
1597async fn leave_room(
1598 _: proto::LeaveRoom,
1599 response: Response<proto::LeaveRoom>,
1600 session: Session,
1601) -> Result<()> {
1602 leave_room_for_session(&session, session.connection_id).await?;
1603 response.send(proto::Ack {})?;
1604 Ok(())
1605}
1606
1607/// Updates the permissions of someone else in the room.
1608async fn set_room_participant_role(
1609 request: proto::SetRoomParticipantRole,
1610 response: Response<proto::SetRoomParticipantRole>,
1611 session: Session,
1612) -> Result<()> {
1613 let user_id = UserId::from_proto(request.user_id);
1614 let role = ChannelRole::from(request.role());
1615
1616 let (livekit_room, can_publish) = {
1617 let room = session
1618 .db()
1619 .await
1620 .set_room_participant_role(
1621 session.user_id(),
1622 RoomId::from_proto(request.room_id),
1623 user_id,
1624 role,
1625 )
1626 .await?;
1627
1628 let livekit_room = room.livekit_room.clone();
1629 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1630 room_updated(&room, &session.peer);
1631 (livekit_room, can_publish)
1632 };
1633
1634 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1635 live_kit
1636 .update_participant(
1637 livekit_room.clone(),
1638 request.user_id.to_string(),
1639 livekit_api::proto::ParticipantPermission {
1640 can_subscribe: true,
1641 can_publish,
1642 can_publish_data: can_publish,
1643 hidden: false,
1644 recorder: false,
1645 },
1646 )
1647 .await
1648 .trace_err();
1649 }
1650
1651 response.send(proto::Ack {})?;
1652 Ok(())
1653}
1654
1655/// Call someone else into the current room
1656async fn call(
1657 request: proto::Call,
1658 response: Response<proto::Call>,
1659 session: Session,
1660) -> Result<()> {
1661 let room_id = RoomId::from_proto(request.room_id);
1662 let calling_user_id = session.user_id();
1663 let calling_connection_id = session.connection_id;
1664 let called_user_id = UserId::from_proto(request.called_user_id);
1665 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1666 if !session
1667 .db()
1668 .await
1669 .has_contact(calling_user_id, called_user_id)
1670 .await?
1671 {
1672 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1673 }
1674
1675 let incoming_call = {
1676 let (room, incoming_call) = &mut *session
1677 .db()
1678 .await
1679 .call(
1680 room_id,
1681 calling_user_id,
1682 calling_connection_id,
1683 called_user_id,
1684 initial_project_id,
1685 )
1686 .await?;
1687 room_updated(room, &session.peer);
1688 mem::take(incoming_call)
1689 };
1690 update_user_contacts(called_user_id, &session).await?;
1691
1692 let mut calls = session
1693 .connection_pool()
1694 .await
1695 .user_connection_ids(called_user_id)
1696 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1697 .collect::<FuturesUnordered<_>>();
1698
1699 while let Some(call_response) = calls.next().await {
1700 match call_response.as_ref() {
1701 Ok(_) => {
1702 response.send(proto::Ack {})?;
1703 return Ok(());
1704 }
1705 Err(_) => {
1706 call_response.trace_err();
1707 }
1708 }
1709 }
1710
1711 {
1712 let room = session
1713 .db()
1714 .await
1715 .call_failed(room_id, called_user_id)
1716 .await?;
1717 room_updated(&room, &session.peer);
1718 }
1719 update_user_contacts(called_user_id, &session).await?;
1720
1721 Err(anyhow!("failed to ring user"))?
1722}
1723
1724/// Cancel an outgoing call.
1725async fn cancel_call(
1726 request: proto::CancelCall,
1727 response: Response<proto::CancelCall>,
1728 session: Session,
1729) -> Result<()> {
1730 let called_user_id = UserId::from_proto(request.called_user_id);
1731 let room_id = RoomId::from_proto(request.room_id);
1732 {
1733 let room = session
1734 .db()
1735 .await
1736 .cancel_call(room_id, session.connection_id, called_user_id)
1737 .await?;
1738 room_updated(&room, &session.peer);
1739 }
1740
1741 for connection_id in session
1742 .connection_pool()
1743 .await
1744 .user_connection_ids(called_user_id)
1745 {
1746 session
1747 .peer
1748 .send(
1749 connection_id,
1750 proto::CallCanceled {
1751 room_id: room_id.to_proto(),
1752 },
1753 )
1754 .trace_err();
1755 }
1756 response.send(proto::Ack {})?;
1757
1758 update_user_contacts(called_user_id, &session).await?;
1759 Ok(())
1760}
1761
1762/// Decline an incoming call.
1763async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1764 let room_id = RoomId::from_proto(message.room_id);
1765 {
1766 let room = session
1767 .db()
1768 .await
1769 .decline_call(Some(room_id), session.user_id())
1770 .await?
1771 .context("declining call")?;
1772 room_updated(&room, &session.peer);
1773 }
1774
1775 for connection_id in session
1776 .connection_pool()
1777 .await
1778 .user_connection_ids(session.user_id())
1779 {
1780 session
1781 .peer
1782 .send(
1783 connection_id,
1784 proto::CallCanceled {
1785 room_id: room_id.to_proto(),
1786 },
1787 )
1788 .trace_err();
1789 }
1790 update_user_contacts(session.user_id(), &session).await?;
1791 Ok(())
1792}
1793
1794/// Updates other participants in the room with your current location.
1795async fn update_participant_location(
1796 request: proto::UpdateParticipantLocation,
1797 response: Response<proto::UpdateParticipantLocation>,
1798 session: Session,
1799) -> Result<()> {
1800 let room_id = RoomId::from_proto(request.room_id);
1801 let location = request.location.context("invalid location")?;
1802
1803 let db = session.db().await;
1804 let room = db
1805 .update_room_participant_location(room_id, session.connection_id, location)
1806 .await?;
1807
1808 room_updated(&room, &session.peer);
1809 response.send(proto::Ack {})?;
1810 Ok(())
1811}
1812
1813/// Share a project into the room.
1814async fn share_project(
1815 request: proto::ShareProject,
1816 response: Response<proto::ShareProject>,
1817 session: Session,
1818) -> Result<()> {
1819 let (project_id, room) = &*session
1820 .db()
1821 .await
1822 .share_project(
1823 RoomId::from_proto(request.room_id),
1824 session.connection_id,
1825 &request.worktrees,
1826 request.is_ssh_project,
1827 )
1828 .await?;
1829 response.send(proto::ShareProjectResponse {
1830 project_id: project_id.to_proto(),
1831 })?;
1832 room_updated(room, &session.peer);
1833
1834 Ok(())
1835}
1836
1837/// Unshare a project from the room.
1838async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1839 let project_id = ProjectId::from_proto(message.project_id);
1840 unshare_project_internal(project_id, session.connection_id, &session).await
1841}
1842
1843async fn unshare_project_internal(
1844 project_id: ProjectId,
1845 connection_id: ConnectionId,
1846 session: &Session,
1847) -> Result<()> {
1848 let delete = {
1849 let room_guard = session
1850 .db()
1851 .await
1852 .unshare_project(project_id, connection_id)
1853 .await?;
1854
1855 let (delete, room, guest_connection_ids) = &*room_guard;
1856
1857 let message = proto::UnshareProject {
1858 project_id: project_id.to_proto(),
1859 };
1860
1861 broadcast(
1862 Some(connection_id),
1863 guest_connection_ids.iter().copied(),
1864 |conn_id| session.peer.send(conn_id, message.clone()),
1865 );
1866 if let Some(room) = room {
1867 room_updated(room, &session.peer);
1868 }
1869
1870 *delete
1871 };
1872
1873 if delete {
1874 let db = session.db().await;
1875 db.delete_project(project_id).await?;
1876 }
1877
1878 Ok(())
1879}
1880
1881/// Join someone elses shared project.
1882async fn join_project(
1883 request: proto::JoinProject,
1884 response: Response<proto::JoinProject>,
1885 session: Session,
1886) -> Result<()> {
1887 let project_id = ProjectId::from_proto(request.project_id);
1888
1889 tracing::info!(%project_id, "join project");
1890
1891 let db = session.db().await;
1892 let (project, replica_id) = &mut *db
1893 .join_project(project_id, session.connection_id, session.user_id())
1894 .await?;
1895 drop(db);
1896 tracing::info!(%project_id, "join remote project");
1897 join_project_internal(response, session, project, replica_id)
1898}
1899
1900trait JoinProjectInternalResponse {
1901 fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
1902}
1903impl JoinProjectInternalResponse for Response<proto::JoinProject> {
1904 fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
1905 Response::<proto::JoinProject>::send(self, result)
1906 }
1907}
1908
1909fn join_project_internal(
1910 response: impl JoinProjectInternalResponse,
1911 session: Session,
1912 project: &mut Project,
1913 replica_id: &ReplicaId,
1914) -> Result<()> {
1915 let collaborators = project
1916 .collaborators
1917 .iter()
1918 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1919 .map(|collaborator| collaborator.to_proto())
1920 .collect::<Vec<_>>();
1921 let project_id = project.id;
1922 let guest_user_id = session.user_id();
1923
1924 let worktrees = project
1925 .worktrees
1926 .iter()
1927 .map(|(id, worktree)| proto::WorktreeMetadata {
1928 id: *id,
1929 root_name: worktree.root_name.clone(),
1930 visible: worktree.visible,
1931 abs_path: worktree.abs_path.clone(),
1932 })
1933 .collect::<Vec<_>>();
1934
1935 let add_project_collaborator = proto::AddProjectCollaborator {
1936 project_id: project_id.to_proto(),
1937 collaborator: Some(proto::Collaborator {
1938 peer_id: Some(session.connection_id.into()),
1939 replica_id: replica_id.0 as u32,
1940 user_id: guest_user_id.to_proto(),
1941 is_host: false,
1942 }),
1943 };
1944
1945 for collaborator in &collaborators {
1946 session
1947 .peer
1948 .send(
1949 collaborator.peer_id.unwrap().into(),
1950 add_project_collaborator.clone(),
1951 )
1952 .trace_err();
1953 }
1954
1955 // First, we send the metadata associated with each worktree.
1956 response.send(proto::JoinProjectResponse {
1957 project_id: project.id.0 as u64,
1958 worktrees: worktrees.clone(),
1959 replica_id: replica_id.0 as u32,
1960 collaborators: collaborators.clone(),
1961 language_servers: project.language_servers.clone(),
1962 role: project.role.into(),
1963 })?;
1964
1965 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1966 // Stream this worktree's entries.
1967 let message = proto::UpdateWorktree {
1968 project_id: project_id.to_proto(),
1969 worktree_id,
1970 abs_path: worktree.abs_path.clone(),
1971 root_name: worktree.root_name,
1972 updated_entries: worktree.entries,
1973 removed_entries: Default::default(),
1974 scan_id: worktree.scan_id,
1975 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1976 updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1977 removed_repositories: Default::default(),
1978 };
1979 for update in proto::split_worktree_update(message) {
1980 session.peer.send(session.connection_id, update.clone())?;
1981 }
1982
1983 // Stream this worktree's diagnostics.
1984 for summary in worktree.diagnostic_summaries {
1985 session.peer.send(
1986 session.connection_id,
1987 proto::UpdateDiagnosticSummary {
1988 project_id: project_id.to_proto(),
1989 worktree_id: worktree.id,
1990 summary: Some(summary),
1991 },
1992 )?;
1993 }
1994
1995 for settings_file in worktree.settings_files {
1996 session.peer.send(
1997 session.connection_id,
1998 proto::UpdateWorktreeSettings {
1999 project_id: project_id.to_proto(),
2000 worktree_id: worktree.id,
2001 path: settings_file.path,
2002 content: Some(settings_file.content),
2003 kind: Some(settings_file.kind.to_proto() as i32),
2004 },
2005 )?;
2006 }
2007 }
2008
2009 for repository in mem::take(&mut project.repositories) {
2010 for update in split_repository_update(repository) {
2011 session.peer.send(session.connection_id, update)?;
2012 }
2013 }
2014
2015 for language_server in &project.language_servers {
2016 session.peer.send(
2017 session.connection_id,
2018 proto::UpdateLanguageServer {
2019 project_id: project_id.to_proto(),
2020 language_server_id: language_server.id,
2021 variant: Some(
2022 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2023 proto::LspDiskBasedDiagnosticsUpdated {},
2024 ),
2025 ),
2026 },
2027 )?;
2028 }
2029
2030 Ok(())
2031}
2032
2033/// Leave someone elses shared project.
2034async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
2035 let sender_id = session.connection_id;
2036 let project_id = ProjectId::from_proto(request.project_id);
2037 let db = session.db().await;
2038
2039 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2040 tracing::info!(
2041 %project_id,
2042 "leave project"
2043 );
2044
2045 project_left(project, &session);
2046 if let Some(room) = room {
2047 room_updated(room, &session.peer);
2048 }
2049
2050 Ok(())
2051}
2052
2053/// Updates other participants with changes to the project
2054async fn update_project(
2055 request: proto::UpdateProject,
2056 response: Response<proto::UpdateProject>,
2057 session: Session,
2058) -> Result<()> {
2059 let project_id = ProjectId::from_proto(request.project_id);
2060 let (room, guest_connection_ids) = &*session
2061 .db()
2062 .await
2063 .update_project(project_id, session.connection_id, &request.worktrees)
2064 .await?;
2065 broadcast(
2066 Some(session.connection_id),
2067 guest_connection_ids.iter().copied(),
2068 |connection_id| {
2069 session
2070 .peer
2071 .forward_send(session.connection_id, connection_id, request.clone())
2072 },
2073 );
2074 if let Some(room) = room {
2075 room_updated(room, &session.peer);
2076 }
2077 response.send(proto::Ack {})?;
2078
2079 Ok(())
2080}
2081
2082/// Updates other participants with changes to the worktree
2083async fn update_worktree(
2084 request: proto::UpdateWorktree,
2085 response: Response<proto::UpdateWorktree>,
2086 session: Session,
2087) -> Result<()> {
2088 let guest_connection_ids = session
2089 .db()
2090 .await
2091 .update_worktree(&request, session.connection_id)
2092 .await?;
2093
2094 broadcast(
2095 Some(session.connection_id),
2096 guest_connection_ids.iter().copied(),
2097 |connection_id| {
2098 session
2099 .peer
2100 .forward_send(session.connection_id, connection_id, request.clone())
2101 },
2102 );
2103 response.send(proto::Ack {})?;
2104 Ok(())
2105}
2106
2107async fn update_repository(
2108 request: proto::UpdateRepository,
2109 response: Response<proto::UpdateRepository>,
2110 session: Session,
2111) -> Result<()> {
2112 let guest_connection_ids = session
2113 .db()
2114 .await
2115 .update_repository(&request, session.connection_id)
2116 .await?;
2117
2118 broadcast(
2119 Some(session.connection_id),
2120 guest_connection_ids.iter().copied(),
2121 |connection_id| {
2122 session
2123 .peer
2124 .forward_send(session.connection_id, connection_id, request.clone())
2125 },
2126 );
2127 response.send(proto::Ack {})?;
2128 Ok(())
2129}
2130
2131async fn remove_repository(
2132 request: proto::RemoveRepository,
2133 response: Response<proto::RemoveRepository>,
2134 session: Session,
2135) -> Result<()> {
2136 let guest_connection_ids = session
2137 .db()
2138 .await
2139 .remove_repository(&request, session.connection_id)
2140 .await?;
2141
2142 broadcast(
2143 Some(session.connection_id),
2144 guest_connection_ids.iter().copied(),
2145 |connection_id| {
2146 session
2147 .peer
2148 .forward_send(session.connection_id, connection_id, request.clone())
2149 },
2150 );
2151 response.send(proto::Ack {})?;
2152 Ok(())
2153}
2154
2155/// Updates other participants with changes to the diagnostics
2156async fn update_diagnostic_summary(
2157 message: proto::UpdateDiagnosticSummary,
2158 session: Session,
2159) -> Result<()> {
2160 let guest_connection_ids = session
2161 .db()
2162 .await
2163 .update_diagnostic_summary(&message, session.connection_id)
2164 .await?;
2165
2166 broadcast(
2167 Some(session.connection_id),
2168 guest_connection_ids.iter().copied(),
2169 |connection_id| {
2170 session
2171 .peer
2172 .forward_send(session.connection_id, connection_id, message.clone())
2173 },
2174 );
2175
2176 Ok(())
2177}
2178
2179/// Updates other participants with changes to the worktree settings
2180async fn update_worktree_settings(
2181 message: proto::UpdateWorktreeSettings,
2182 session: Session,
2183) -> Result<()> {
2184 let guest_connection_ids = session
2185 .db()
2186 .await
2187 .update_worktree_settings(&message, session.connection_id)
2188 .await?;
2189
2190 broadcast(
2191 Some(session.connection_id),
2192 guest_connection_ids.iter().copied(),
2193 |connection_id| {
2194 session
2195 .peer
2196 .forward_send(session.connection_id, connection_id, message.clone())
2197 },
2198 );
2199
2200 Ok(())
2201}
2202
2203/// Notify other participants that a language server has started.
2204async fn start_language_server(
2205 request: proto::StartLanguageServer,
2206 session: Session,
2207) -> Result<()> {
2208 let guest_connection_ids = session
2209 .db()
2210 .await
2211 .start_language_server(&request, session.connection_id)
2212 .await?;
2213
2214 broadcast(
2215 Some(session.connection_id),
2216 guest_connection_ids.iter().copied(),
2217 |connection_id| {
2218 session
2219 .peer
2220 .forward_send(session.connection_id, connection_id, request.clone())
2221 },
2222 );
2223 Ok(())
2224}
2225
2226/// Notify other participants that a language server has changed.
2227async fn update_language_server(
2228 request: proto::UpdateLanguageServer,
2229 session: Session,
2230) -> Result<()> {
2231 let project_id = ProjectId::from_proto(request.project_id);
2232 let project_connection_ids = session
2233 .db()
2234 .await
2235 .project_connection_ids(project_id, session.connection_id, true)
2236 .await?;
2237 broadcast(
2238 Some(session.connection_id),
2239 project_connection_ids.iter().copied(),
2240 |connection_id| {
2241 session
2242 .peer
2243 .forward_send(session.connection_id, connection_id, request.clone())
2244 },
2245 );
2246 Ok(())
2247}
2248
2249/// forward a project request to the host. These requests should be read only
2250/// as guests are allowed to send them.
2251async fn forward_read_only_project_request<T>(
2252 request: T,
2253 response: Response<T>,
2254 session: Session,
2255) -> Result<()>
2256where
2257 T: EntityMessage + RequestMessage,
2258{
2259 let project_id = ProjectId::from_proto(request.remote_entity_id());
2260 let host_connection_id = session
2261 .db()
2262 .await
2263 .host_for_read_only_project_request(project_id, session.connection_id)
2264 .await?;
2265 let payload = session
2266 .peer
2267 .forward_request(session.connection_id, host_connection_id, request)
2268 .await?;
2269 response.send(payload)?;
2270 Ok(())
2271}
2272
2273async fn forward_find_search_candidates_request(
2274 request: proto::FindSearchCandidates,
2275 response: Response<proto::FindSearchCandidates>,
2276 session: Session,
2277) -> Result<()> {
2278 let project_id = ProjectId::from_proto(request.remote_entity_id());
2279 let host_connection_id = session
2280 .db()
2281 .await
2282 .host_for_read_only_project_request(project_id, session.connection_id)
2283 .await?;
2284 let payload = session
2285 .peer
2286 .forward_request(session.connection_id, host_connection_id, request)
2287 .await?;
2288 response.send(payload)?;
2289 Ok(())
2290}
2291
2292/// forward a project request to the host. These requests are disallowed
2293/// for guests.
2294async fn forward_mutating_project_request<T>(
2295 request: T,
2296 response: Response<T>,
2297 session: Session,
2298) -> Result<()>
2299where
2300 T: EntityMessage + RequestMessage,
2301{
2302 let project_id = ProjectId::from_proto(request.remote_entity_id());
2303
2304 let host_connection_id = session
2305 .db()
2306 .await
2307 .host_for_mutating_project_request(project_id, session.connection_id)
2308 .await?;
2309 let payload = session
2310 .peer
2311 .forward_request(session.connection_id, host_connection_id, request)
2312 .await?;
2313 response.send(payload)?;
2314 Ok(())
2315}
2316
2317/// Notify other participants that a new buffer has been created
2318async fn create_buffer_for_peer(
2319 request: proto::CreateBufferForPeer,
2320 session: Session,
2321) -> Result<()> {
2322 session
2323 .db()
2324 .await
2325 .check_user_is_project_host(
2326 ProjectId::from_proto(request.project_id),
2327 session.connection_id,
2328 )
2329 .await?;
2330 let peer_id = request.peer_id.context("invalid peer id")?;
2331 session
2332 .peer
2333 .forward_send(session.connection_id, peer_id.into(), request)?;
2334 Ok(())
2335}
2336
2337/// Notify other participants that a buffer has been updated. This is
2338/// allowed for guests as long as the update is limited to selections.
2339async fn update_buffer(
2340 request: proto::UpdateBuffer,
2341 response: Response<proto::UpdateBuffer>,
2342 session: Session,
2343) -> Result<()> {
2344 let project_id = ProjectId::from_proto(request.project_id);
2345 let mut capability = Capability::ReadOnly;
2346
2347 for op in request.operations.iter() {
2348 match op.variant {
2349 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2350 Some(_) => capability = Capability::ReadWrite,
2351 }
2352 }
2353
2354 let host = {
2355 let guard = session
2356 .db()
2357 .await
2358 .connections_for_buffer_update(project_id, session.connection_id, capability)
2359 .await?;
2360
2361 let (host, guests) = &*guard;
2362
2363 broadcast(
2364 Some(session.connection_id),
2365 guests.clone(),
2366 |connection_id| {
2367 session
2368 .peer
2369 .forward_send(session.connection_id, connection_id, request.clone())
2370 },
2371 );
2372
2373 *host
2374 };
2375
2376 if host != session.connection_id {
2377 session
2378 .peer
2379 .forward_request(session.connection_id, host, request.clone())
2380 .await?;
2381 }
2382
2383 response.send(proto::Ack {})?;
2384 Ok(())
2385}
2386
2387async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2388 let project_id = ProjectId::from_proto(message.project_id);
2389
2390 let operation = message.operation.as_ref().context("invalid operation")?;
2391 let capability = match operation.variant.as_ref() {
2392 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2393 if let Some(buffer_op) = buffer_op.operation.as_ref() {
2394 match buffer_op.variant {
2395 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2396 Capability::ReadOnly
2397 }
2398 _ => Capability::ReadWrite,
2399 }
2400 } else {
2401 Capability::ReadWrite
2402 }
2403 }
2404 Some(_) => Capability::ReadWrite,
2405 None => Capability::ReadOnly,
2406 };
2407
2408 let guard = session
2409 .db()
2410 .await
2411 .connections_for_buffer_update(project_id, session.connection_id, capability)
2412 .await?;
2413
2414 let (host, guests) = &*guard;
2415
2416 broadcast(
2417 Some(session.connection_id),
2418 guests.iter().chain([host]).copied(),
2419 |connection_id| {
2420 session
2421 .peer
2422 .forward_send(session.connection_id, connection_id, message.clone())
2423 },
2424 );
2425
2426 Ok(())
2427}
2428
2429/// Notify other participants that a project has been updated.
2430async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2431 request: T,
2432 session: Session,
2433) -> Result<()> {
2434 let project_id = ProjectId::from_proto(request.remote_entity_id());
2435 let project_connection_ids = session
2436 .db()
2437 .await
2438 .project_connection_ids(project_id, session.connection_id, false)
2439 .await?;
2440
2441 broadcast(
2442 Some(session.connection_id),
2443 project_connection_ids.iter().copied(),
2444 |connection_id| {
2445 session
2446 .peer
2447 .forward_send(session.connection_id, connection_id, request.clone())
2448 },
2449 );
2450 Ok(())
2451}
2452
2453/// Start following another user in a call.
2454async fn follow(
2455 request: proto::Follow,
2456 response: Response<proto::Follow>,
2457 session: Session,
2458) -> Result<()> {
2459 let room_id = RoomId::from_proto(request.room_id);
2460 let project_id = request.project_id.map(ProjectId::from_proto);
2461 let leader_id = request.leader_id.context("invalid leader id")?.into();
2462 let follower_id = session.connection_id;
2463
2464 session
2465 .db()
2466 .await
2467 .check_room_participants(room_id, leader_id, session.connection_id)
2468 .await?;
2469
2470 let response_payload = session
2471 .peer
2472 .forward_request(session.connection_id, leader_id, request)
2473 .await?;
2474 response.send(response_payload)?;
2475
2476 if let Some(project_id) = project_id {
2477 let room = session
2478 .db()
2479 .await
2480 .follow(room_id, project_id, leader_id, follower_id)
2481 .await?;
2482 room_updated(&room, &session.peer);
2483 }
2484
2485 Ok(())
2486}
2487
2488/// Stop following another user in a call.
2489async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2490 let room_id = RoomId::from_proto(request.room_id);
2491 let project_id = request.project_id.map(ProjectId::from_proto);
2492 let leader_id = request.leader_id.context("invalid leader id")?.into();
2493 let follower_id = session.connection_id;
2494
2495 session
2496 .db()
2497 .await
2498 .check_room_participants(room_id, leader_id, session.connection_id)
2499 .await?;
2500
2501 session
2502 .peer
2503 .forward_send(session.connection_id, leader_id, request)?;
2504
2505 if let Some(project_id) = project_id {
2506 let room = session
2507 .db()
2508 .await
2509 .unfollow(room_id, project_id, leader_id, follower_id)
2510 .await?;
2511 room_updated(&room, &session.peer);
2512 }
2513
2514 Ok(())
2515}
2516
2517/// Notify everyone following you of your current location.
2518async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2519 let room_id = RoomId::from_proto(request.room_id);
2520 let database = session.db.lock().await;
2521
2522 let connection_ids = if let Some(project_id) = request.project_id {
2523 let project_id = ProjectId::from_proto(project_id);
2524 database
2525 .project_connection_ids(project_id, session.connection_id, true)
2526 .await?
2527 } else {
2528 database
2529 .room_connection_ids(room_id, session.connection_id)
2530 .await?
2531 };
2532
2533 // For now, don't send view update messages back to that view's current leader.
2534 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2535 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2536 _ => None,
2537 });
2538
2539 for connection_id in connection_ids.iter().cloned() {
2540 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2541 session
2542 .peer
2543 .forward_send(session.connection_id, connection_id, request.clone())?;
2544 }
2545 }
2546 Ok(())
2547}
2548
2549/// Get public data about users.
2550async fn get_users(
2551 request: proto::GetUsers,
2552 response: Response<proto::GetUsers>,
2553 session: Session,
2554) -> Result<()> {
2555 let user_ids = request
2556 .user_ids
2557 .into_iter()
2558 .map(UserId::from_proto)
2559 .collect();
2560 let users = session
2561 .db()
2562 .await
2563 .get_users_by_ids(user_ids)
2564 .await?
2565 .into_iter()
2566 .map(|user| proto::User {
2567 id: user.id.to_proto(),
2568 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2569 github_login: user.github_login,
2570 email: user.email_address,
2571 name: user.name,
2572 })
2573 .collect();
2574 response.send(proto::UsersResponse { users })?;
2575 Ok(())
2576}
2577
2578/// Search for users (to invite) buy Github login
2579async fn fuzzy_search_users(
2580 request: proto::FuzzySearchUsers,
2581 response: Response<proto::FuzzySearchUsers>,
2582 session: Session,
2583) -> Result<()> {
2584 let query = request.query;
2585 let users = match query.len() {
2586 0 => vec![],
2587 1 | 2 => session
2588 .db()
2589 .await
2590 .get_user_by_github_login(&query)
2591 .await?
2592 .into_iter()
2593 .collect(),
2594 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2595 };
2596 let users = users
2597 .into_iter()
2598 .filter(|user| user.id != session.user_id())
2599 .map(|user| proto::User {
2600 id: user.id.to_proto(),
2601 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2602 github_login: user.github_login,
2603 name: user.name,
2604 email: user.email_address,
2605 })
2606 .collect();
2607 response.send(proto::UsersResponse { users })?;
2608 Ok(())
2609}
2610
2611/// Send a contact request to another user.
2612async fn request_contact(
2613 request: proto::RequestContact,
2614 response: Response<proto::RequestContact>,
2615 session: Session,
2616) -> Result<()> {
2617 let requester_id = session.user_id();
2618 let responder_id = UserId::from_proto(request.responder_id);
2619 if requester_id == responder_id {
2620 return Err(anyhow!("cannot add yourself as a contact"))?;
2621 }
2622
2623 let notifications = session
2624 .db()
2625 .await
2626 .send_contact_request(requester_id, responder_id)
2627 .await?;
2628
2629 // Update outgoing contact requests of requester
2630 let mut update = proto::UpdateContacts::default();
2631 update.outgoing_requests.push(responder_id.to_proto());
2632 for connection_id in session
2633 .connection_pool()
2634 .await
2635 .user_connection_ids(requester_id)
2636 {
2637 session.peer.send(connection_id, update.clone())?;
2638 }
2639
2640 // Update incoming contact requests of responder
2641 let mut update = proto::UpdateContacts::default();
2642 update
2643 .incoming_requests
2644 .push(proto::IncomingContactRequest {
2645 requester_id: requester_id.to_proto(),
2646 });
2647 let connection_pool = session.connection_pool().await;
2648 for connection_id in connection_pool.user_connection_ids(responder_id) {
2649 session.peer.send(connection_id, update.clone())?;
2650 }
2651
2652 send_notifications(&connection_pool, &session.peer, notifications);
2653
2654 response.send(proto::Ack {})?;
2655 Ok(())
2656}
2657
2658/// Accept or decline a contact request
2659async fn respond_to_contact_request(
2660 request: proto::RespondToContactRequest,
2661 response: Response<proto::RespondToContactRequest>,
2662 session: Session,
2663) -> Result<()> {
2664 let responder_id = session.user_id();
2665 let requester_id = UserId::from_proto(request.requester_id);
2666 let db = session.db().await;
2667 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2668 db.dismiss_contact_notification(responder_id, requester_id)
2669 .await?;
2670 } else {
2671 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2672
2673 let notifications = db
2674 .respond_to_contact_request(responder_id, requester_id, accept)
2675 .await?;
2676 let requester_busy = db.is_user_busy(requester_id).await?;
2677 let responder_busy = db.is_user_busy(responder_id).await?;
2678
2679 let pool = session.connection_pool().await;
2680 // Update responder with new contact
2681 let mut update = proto::UpdateContacts::default();
2682 if accept {
2683 update
2684 .contacts
2685 .push(contact_for_user(requester_id, requester_busy, &pool));
2686 }
2687 update
2688 .remove_incoming_requests
2689 .push(requester_id.to_proto());
2690 for connection_id in pool.user_connection_ids(responder_id) {
2691 session.peer.send(connection_id, update.clone())?;
2692 }
2693
2694 // Update requester with new contact
2695 let mut update = proto::UpdateContacts::default();
2696 if accept {
2697 update
2698 .contacts
2699 .push(contact_for_user(responder_id, responder_busy, &pool));
2700 }
2701 update
2702 .remove_outgoing_requests
2703 .push(responder_id.to_proto());
2704
2705 for connection_id in pool.user_connection_ids(requester_id) {
2706 session.peer.send(connection_id, update.clone())?;
2707 }
2708
2709 send_notifications(&pool, &session.peer, notifications);
2710 }
2711
2712 response.send(proto::Ack {})?;
2713 Ok(())
2714}
2715
2716/// Remove a contact.
2717async fn remove_contact(
2718 request: proto::RemoveContact,
2719 response: Response<proto::RemoveContact>,
2720 session: Session,
2721) -> Result<()> {
2722 let requester_id = session.user_id();
2723 let responder_id = UserId::from_proto(request.user_id);
2724 let db = session.db().await;
2725 let (contact_accepted, deleted_notification_id) =
2726 db.remove_contact(requester_id, responder_id).await?;
2727
2728 let pool = session.connection_pool().await;
2729 // Update outgoing contact requests of requester
2730 let mut update = proto::UpdateContacts::default();
2731 if contact_accepted {
2732 update.remove_contacts.push(responder_id.to_proto());
2733 } else {
2734 update
2735 .remove_outgoing_requests
2736 .push(responder_id.to_proto());
2737 }
2738 for connection_id in pool.user_connection_ids(requester_id) {
2739 session.peer.send(connection_id, update.clone())?;
2740 }
2741
2742 // Update incoming contact requests of responder
2743 let mut update = proto::UpdateContacts::default();
2744 if contact_accepted {
2745 update.remove_contacts.push(requester_id.to_proto());
2746 } else {
2747 update
2748 .remove_incoming_requests
2749 .push(requester_id.to_proto());
2750 }
2751 for connection_id in pool.user_connection_ids(responder_id) {
2752 session.peer.send(connection_id, update.clone())?;
2753 if let Some(notification_id) = deleted_notification_id {
2754 session.peer.send(
2755 connection_id,
2756 proto::DeleteNotification {
2757 notification_id: notification_id.to_proto(),
2758 },
2759 )?;
2760 }
2761 }
2762
2763 response.send(proto::Ack {})?;
2764 Ok(())
2765}
2766
2767fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2768 version.0.minor() < 139
2769}
2770
2771async fn current_plan(db: &Arc<Database>, user_id: UserId, is_staff: bool) -> Result<proto::Plan> {
2772 if is_staff {
2773 return Ok(proto::Plan::ZedPro);
2774 }
2775
2776 let subscription = db.get_active_billing_subscription(user_id).await?;
2777 let subscription_kind = subscription.and_then(|subscription| subscription.kind);
2778
2779 let plan = if let Some(subscription_kind) = subscription_kind {
2780 match subscription_kind {
2781 SubscriptionKind::ZedPro => proto::Plan::ZedPro,
2782 SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
2783 SubscriptionKind::ZedFree => proto::Plan::Free,
2784 }
2785 } else {
2786 proto::Plan::Free
2787 };
2788
2789 Ok(plan)
2790}
2791
2792async fn make_update_user_plan_message(
2793 user: &User,
2794 is_staff: bool,
2795 db: &Arc<Database>,
2796 llm_db: Option<Arc<LlmDatabase>>,
2797) -> Result<proto::UpdateUserPlan> {
2798 let feature_flags = db.get_user_flags(user.id).await?;
2799 let plan = current_plan(db, user.id, is_staff).await?;
2800 let billing_customer = db.get_billing_customer_by_user_id(user.id).await?;
2801 let billing_preferences = db.get_billing_preferences(user.id).await?;
2802
2803 let (subscription_period, usage) = if let Some(llm_db) = llm_db {
2804 let subscription = db.get_active_billing_subscription(user.id).await?;
2805
2806 let subscription_period =
2807 crate::db::billing_subscription::Model::current_period(subscription, is_staff);
2808
2809 let usage = if let Some((period_start_at, period_end_at)) = subscription_period {
2810 llm_db
2811 .get_subscription_usage_for_period(user.id, period_start_at, period_end_at)
2812 .await?
2813 } else {
2814 None
2815 };
2816
2817 (subscription_period, usage)
2818 } else {
2819 (None, None)
2820 };
2821
2822 let bypass_account_age_check = feature_flags
2823 .iter()
2824 .any(|flag| flag == BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG);
2825 let account_too_young = !matches!(plan, proto::Plan::ZedPro)
2826 && !bypass_account_age_check
2827 && user.account_age() < MIN_ACCOUNT_AGE_FOR_LLM_USE;
2828
2829 Ok(proto::UpdateUserPlan {
2830 plan: plan.into(),
2831 trial_started_at: billing_customer
2832 .as_ref()
2833 .and_then(|billing_customer| billing_customer.trial_started_at)
2834 .map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64),
2835 is_usage_based_billing_enabled: if is_staff {
2836 Some(true)
2837 } else {
2838 billing_preferences.map(|preferences| preferences.model_request_overages_enabled)
2839 },
2840 subscription_period: subscription_period.map(|(started_at, ended_at)| {
2841 proto::SubscriptionPeriod {
2842 started_at: started_at.timestamp() as u64,
2843 ended_at: ended_at.timestamp() as u64,
2844 }
2845 }),
2846 account_too_young: Some(account_too_young),
2847 has_overdue_invoices: billing_customer
2848 .map(|billing_customer| billing_customer.has_overdue_invoices),
2849 usage: usage.map(|usage| {
2850 let plan = match plan {
2851 proto::Plan::Free => zed_llm_client::Plan::ZedFree,
2852 proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro,
2853 proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
2854 };
2855
2856 let model_requests_limit = match plan.model_requests_limit() {
2857 zed_llm_client::UsageLimit::Limited(limit) => {
2858 let limit = if plan == zed_llm_client::Plan::ZedProTrial
2859 && feature_flags
2860 .iter()
2861 .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
2862 {
2863 1_000
2864 } else {
2865 limit
2866 };
2867
2868 zed_llm_client::UsageLimit::Limited(limit)
2869 }
2870 zed_llm_client::UsageLimit::Unlimited => zed_llm_client::UsageLimit::Unlimited,
2871 };
2872
2873 proto::SubscriptionUsage {
2874 model_requests_usage_amount: usage.model_requests as u32,
2875 model_requests_usage_limit: Some(proto::UsageLimit {
2876 variant: Some(match model_requests_limit {
2877 zed_llm_client::UsageLimit::Limited(limit) => {
2878 proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2879 limit: limit as u32,
2880 })
2881 }
2882 zed_llm_client::UsageLimit::Unlimited => {
2883 proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2884 }
2885 }),
2886 }),
2887 edit_predictions_usage_amount: usage.edit_predictions as u32,
2888 edit_predictions_usage_limit: Some(proto::UsageLimit {
2889 variant: Some(match plan.edit_predictions_limit() {
2890 zed_llm_client::UsageLimit::Limited(limit) => {
2891 proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2892 limit: limit as u32,
2893 })
2894 }
2895 zed_llm_client::UsageLimit::Unlimited => {
2896 proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2897 }
2898 }),
2899 }),
2900 }
2901 }),
2902 })
2903}
2904
2905async fn update_user_plan(session: &Session) -> Result<()> {
2906 let db = session.db().await;
2907
2908 let update_user_plan = make_update_user_plan_message(
2909 session.principal.user(),
2910 session.is_staff(),
2911 &db.0,
2912 session.app_state.llm_db.clone(),
2913 )
2914 .await?;
2915
2916 session
2917 .peer
2918 .send(session.connection_id, update_user_plan)
2919 .trace_err();
2920
2921 Ok(())
2922}
2923
2924async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2925 subscribe_user_to_channels(session.user_id(), &session).await?;
2926 Ok(())
2927}
2928
2929async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2930 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2931 let mut pool = session.connection_pool().await;
2932 for membership in &channels_for_user.channel_memberships {
2933 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2934 }
2935 session.peer.send(
2936 session.connection_id,
2937 build_update_user_channels(&channels_for_user),
2938 )?;
2939 session.peer.send(
2940 session.connection_id,
2941 build_channels_update(channels_for_user),
2942 )?;
2943 Ok(())
2944}
2945
2946/// Creates a new channel.
2947async fn create_channel(
2948 request: proto::CreateChannel,
2949 response: Response<proto::CreateChannel>,
2950 session: Session,
2951) -> Result<()> {
2952 let db = session.db().await;
2953
2954 let parent_id = request.parent_id.map(ChannelId::from_proto);
2955 let (channel, membership) = db
2956 .create_channel(&request.name, parent_id, session.user_id())
2957 .await?;
2958
2959 let root_id = channel.root_id();
2960 let channel = Channel::from_model(channel);
2961
2962 response.send(proto::CreateChannelResponse {
2963 channel: Some(channel.to_proto()),
2964 parent_id: request.parent_id,
2965 })?;
2966
2967 let mut connection_pool = session.connection_pool().await;
2968 if let Some(membership) = membership {
2969 connection_pool.subscribe_to_channel(
2970 membership.user_id,
2971 membership.channel_id,
2972 membership.role,
2973 );
2974 let update = proto::UpdateUserChannels {
2975 channel_memberships: vec![proto::ChannelMembership {
2976 channel_id: membership.channel_id.to_proto(),
2977 role: membership.role.into(),
2978 }],
2979 ..Default::default()
2980 };
2981 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2982 session.peer.send(connection_id, update.clone())?;
2983 }
2984 }
2985
2986 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2987 if !role.can_see_channel(channel.visibility) {
2988 continue;
2989 }
2990
2991 let update = proto::UpdateChannels {
2992 channels: vec![channel.to_proto()],
2993 ..Default::default()
2994 };
2995 session.peer.send(connection_id, update.clone())?;
2996 }
2997
2998 Ok(())
2999}
3000
3001/// Delete a channel
3002async fn delete_channel(
3003 request: proto::DeleteChannel,
3004 response: Response<proto::DeleteChannel>,
3005 session: Session,
3006) -> Result<()> {
3007 let db = session.db().await;
3008
3009 let channel_id = request.channel_id;
3010 let (root_channel, removed_channels) = db
3011 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3012 .await?;
3013 response.send(proto::Ack {})?;
3014
3015 // Notify members of removed channels
3016 let mut update = proto::UpdateChannels::default();
3017 update
3018 .delete_channels
3019 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3020
3021 let connection_pool = session.connection_pool().await;
3022 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3023 session.peer.send(connection_id, update.clone())?;
3024 }
3025
3026 Ok(())
3027}
3028
3029/// Invite someone to join a channel.
3030async fn invite_channel_member(
3031 request: proto::InviteChannelMember,
3032 response: Response<proto::InviteChannelMember>,
3033 session: Session,
3034) -> Result<()> {
3035 let db = session.db().await;
3036 let channel_id = ChannelId::from_proto(request.channel_id);
3037 let invitee_id = UserId::from_proto(request.user_id);
3038 let InviteMemberResult {
3039 channel,
3040 notifications,
3041 } = db
3042 .invite_channel_member(
3043 channel_id,
3044 invitee_id,
3045 session.user_id(),
3046 request.role().into(),
3047 )
3048 .await?;
3049
3050 let update = proto::UpdateChannels {
3051 channel_invitations: vec![channel.to_proto()],
3052 ..Default::default()
3053 };
3054
3055 let connection_pool = session.connection_pool().await;
3056 for connection_id in connection_pool.user_connection_ids(invitee_id) {
3057 session.peer.send(connection_id, update.clone())?;
3058 }
3059
3060 send_notifications(&connection_pool, &session.peer, notifications);
3061
3062 response.send(proto::Ack {})?;
3063 Ok(())
3064}
3065
3066/// remove someone from a channel
3067async fn remove_channel_member(
3068 request: proto::RemoveChannelMember,
3069 response: Response<proto::RemoveChannelMember>,
3070 session: Session,
3071) -> Result<()> {
3072 let db = session.db().await;
3073 let channel_id = ChannelId::from_proto(request.channel_id);
3074 let member_id = UserId::from_proto(request.user_id);
3075
3076 let RemoveChannelMemberResult {
3077 membership_update,
3078 notification_id,
3079 } = db
3080 .remove_channel_member(channel_id, member_id, session.user_id())
3081 .await?;
3082
3083 let mut connection_pool = session.connection_pool().await;
3084 notify_membership_updated(
3085 &mut connection_pool,
3086 membership_update,
3087 member_id,
3088 &session.peer,
3089 );
3090 for connection_id in connection_pool.user_connection_ids(member_id) {
3091 if let Some(notification_id) = notification_id {
3092 session
3093 .peer
3094 .send(
3095 connection_id,
3096 proto::DeleteNotification {
3097 notification_id: notification_id.to_proto(),
3098 },
3099 )
3100 .trace_err();
3101 }
3102 }
3103
3104 response.send(proto::Ack {})?;
3105 Ok(())
3106}
3107
3108/// Toggle the channel between public and private.
3109/// Care is taken to maintain the invariant that public channels only descend from public channels,
3110/// (though members-only channels can appear at any point in the hierarchy).
3111async fn set_channel_visibility(
3112 request: proto::SetChannelVisibility,
3113 response: Response<proto::SetChannelVisibility>,
3114 session: Session,
3115) -> Result<()> {
3116 let db = session.db().await;
3117 let channel_id = ChannelId::from_proto(request.channel_id);
3118 let visibility = request.visibility().into();
3119
3120 let channel_model = db
3121 .set_channel_visibility(channel_id, visibility, session.user_id())
3122 .await?;
3123 let root_id = channel_model.root_id();
3124 let channel = Channel::from_model(channel_model);
3125
3126 let mut connection_pool = session.connection_pool().await;
3127 for (user_id, role) in connection_pool
3128 .channel_user_ids(root_id)
3129 .collect::<Vec<_>>()
3130 .into_iter()
3131 {
3132 let update = if role.can_see_channel(channel.visibility) {
3133 connection_pool.subscribe_to_channel(user_id, channel_id, role);
3134 proto::UpdateChannels {
3135 channels: vec![channel.to_proto()],
3136 ..Default::default()
3137 }
3138 } else {
3139 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3140 proto::UpdateChannels {
3141 delete_channels: vec![channel.id.to_proto()],
3142 ..Default::default()
3143 }
3144 };
3145
3146 for connection_id in connection_pool.user_connection_ids(user_id) {
3147 session.peer.send(connection_id, update.clone())?;
3148 }
3149 }
3150
3151 response.send(proto::Ack {})?;
3152 Ok(())
3153}
3154
3155/// Alter the role for a user in the channel.
3156async fn set_channel_member_role(
3157 request: proto::SetChannelMemberRole,
3158 response: Response<proto::SetChannelMemberRole>,
3159 session: Session,
3160) -> Result<()> {
3161 let db = session.db().await;
3162 let channel_id = ChannelId::from_proto(request.channel_id);
3163 let member_id = UserId::from_proto(request.user_id);
3164 let result = db
3165 .set_channel_member_role(
3166 channel_id,
3167 session.user_id(),
3168 member_id,
3169 request.role().into(),
3170 )
3171 .await?;
3172
3173 match result {
3174 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3175 let mut connection_pool = session.connection_pool().await;
3176 notify_membership_updated(
3177 &mut connection_pool,
3178 membership_update,
3179 member_id,
3180 &session.peer,
3181 )
3182 }
3183 db::SetMemberRoleResult::InviteUpdated(channel) => {
3184 let update = proto::UpdateChannels {
3185 channel_invitations: vec![channel.to_proto()],
3186 ..Default::default()
3187 };
3188
3189 for connection_id in session
3190 .connection_pool()
3191 .await
3192 .user_connection_ids(member_id)
3193 {
3194 session.peer.send(connection_id, update.clone())?;
3195 }
3196 }
3197 }
3198
3199 response.send(proto::Ack {})?;
3200 Ok(())
3201}
3202
3203/// Change the name of a channel
3204async fn rename_channel(
3205 request: proto::RenameChannel,
3206 response: Response<proto::RenameChannel>,
3207 session: Session,
3208) -> Result<()> {
3209 let db = session.db().await;
3210 let channel_id = ChannelId::from_proto(request.channel_id);
3211 let channel_model = db
3212 .rename_channel(channel_id, session.user_id(), &request.name)
3213 .await?;
3214 let root_id = channel_model.root_id();
3215 let channel = Channel::from_model(channel_model);
3216
3217 response.send(proto::RenameChannelResponse {
3218 channel: Some(channel.to_proto()),
3219 })?;
3220
3221 let connection_pool = session.connection_pool().await;
3222 let update = proto::UpdateChannels {
3223 channels: vec![channel.to_proto()],
3224 ..Default::default()
3225 };
3226 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3227 if role.can_see_channel(channel.visibility) {
3228 session.peer.send(connection_id, update.clone())?;
3229 }
3230 }
3231
3232 Ok(())
3233}
3234
3235/// Move a channel to a new parent.
3236async fn move_channel(
3237 request: proto::MoveChannel,
3238 response: Response<proto::MoveChannel>,
3239 session: Session,
3240) -> Result<()> {
3241 let channel_id = ChannelId::from_proto(request.channel_id);
3242 let to = ChannelId::from_proto(request.to);
3243
3244 let (root_id, channels) = session
3245 .db()
3246 .await
3247 .move_channel(channel_id, to, session.user_id())
3248 .await?;
3249
3250 let connection_pool = session.connection_pool().await;
3251 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3252 let channels = channels
3253 .iter()
3254 .filter_map(|channel| {
3255 if role.can_see_channel(channel.visibility) {
3256 Some(channel.to_proto())
3257 } else {
3258 None
3259 }
3260 })
3261 .collect::<Vec<_>>();
3262 if channels.is_empty() {
3263 continue;
3264 }
3265
3266 let update = proto::UpdateChannels {
3267 channels,
3268 ..Default::default()
3269 };
3270
3271 session.peer.send(connection_id, update.clone())?;
3272 }
3273
3274 response.send(Ack {})?;
3275 Ok(())
3276}
3277
3278async fn reorder_channel(
3279 request: proto::ReorderChannel,
3280 response: Response<proto::ReorderChannel>,
3281 session: Session,
3282) -> Result<()> {
3283 let channel_id = ChannelId::from_proto(request.channel_id);
3284 let direction = request.direction();
3285
3286 let updated_channels = session
3287 .db()
3288 .await
3289 .reorder_channel(channel_id, direction, session.user_id())
3290 .await?;
3291
3292 if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3293 let connection_pool = session.connection_pool().await;
3294 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3295 let channels = updated_channels
3296 .iter()
3297 .filter_map(|channel| {
3298 if role.can_see_channel(channel.visibility) {
3299 Some(channel.to_proto())
3300 } else {
3301 None
3302 }
3303 })
3304 .collect::<Vec<_>>();
3305
3306 if channels.is_empty() {
3307 continue;
3308 }
3309
3310 let update = proto::UpdateChannels {
3311 channels,
3312 ..Default::default()
3313 };
3314
3315 session.peer.send(connection_id, update.clone())?;
3316 }
3317 }
3318
3319 response.send(Ack {})?;
3320 Ok(())
3321}
3322
3323/// Get the list of channel members
3324async fn get_channel_members(
3325 request: proto::GetChannelMembers,
3326 response: Response<proto::GetChannelMembers>,
3327 session: Session,
3328) -> Result<()> {
3329 let db = session.db().await;
3330 let channel_id = ChannelId::from_proto(request.channel_id);
3331 let limit = if request.limit == 0 {
3332 u16::MAX as u64
3333 } else {
3334 request.limit
3335 };
3336 let (members, users) = db
3337 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3338 .await?;
3339 response.send(proto::GetChannelMembersResponse { members, users })?;
3340 Ok(())
3341}
3342
3343/// Accept or decline a channel invitation.
3344async fn respond_to_channel_invite(
3345 request: proto::RespondToChannelInvite,
3346 response: Response<proto::RespondToChannelInvite>,
3347 session: Session,
3348) -> Result<()> {
3349 let db = session.db().await;
3350 let channel_id = ChannelId::from_proto(request.channel_id);
3351 let RespondToChannelInvite {
3352 membership_update,
3353 notifications,
3354 } = db
3355 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3356 .await?;
3357
3358 let mut connection_pool = session.connection_pool().await;
3359 if let Some(membership_update) = membership_update {
3360 notify_membership_updated(
3361 &mut connection_pool,
3362 membership_update,
3363 session.user_id(),
3364 &session.peer,
3365 );
3366 } else {
3367 let update = proto::UpdateChannels {
3368 remove_channel_invitations: vec![channel_id.to_proto()],
3369 ..Default::default()
3370 };
3371
3372 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3373 session.peer.send(connection_id, update.clone())?;
3374 }
3375 };
3376
3377 send_notifications(&connection_pool, &session.peer, notifications);
3378
3379 response.send(proto::Ack {})?;
3380
3381 Ok(())
3382}
3383
3384/// Join the channels' room
3385async fn join_channel(
3386 request: proto::JoinChannel,
3387 response: Response<proto::JoinChannel>,
3388 session: Session,
3389) -> Result<()> {
3390 let channel_id = ChannelId::from_proto(request.channel_id);
3391 join_channel_internal(channel_id, Box::new(response), session).await
3392}
3393
3394trait JoinChannelInternalResponse {
3395 fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
3396}
3397impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
3398 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3399 Response::<proto::JoinChannel>::send(self, result)
3400 }
3401}
3402impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
3403 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3404 Response::<proto::JoinRoom>::send(self, result)
3405 }
3406}
3407
3408async fn join_channel_internal(
3409 channel_id: ChannelId,
3410 response: Box<impl JoinChannelInternalResponse>,
3411 session: Session,
3412) -> Result<()> {
3413 let joined_room = {
3414 let mut db = session.db().await;
3415 // If zed quits without leaving the room, and the user re-opens zed before the
3416 // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
3417 // room they were in.
3418 if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
3419 tracing::info!(
3420 stale_connection_id = %connection,
3421 "cleaning up stale connection",
3422 );
3423 drop(db);
3424 leave_room_for_session(&session, connection).await?;
3425 db = session.db().await;
3426 }
3427
3428 let (joined_room, membership_updated, role) = db
3429 .join_channel(channel_id, session.user_id(), session.connection_id)
3430 .await?;
3431
3432 let live_kit_connection_info =
3433 session
3434 .app_state
3435 .livekit_client
3436 .as_ref()
3437 .and_then(|live_kit| {
3438 let (can_publish, token) = if role == ChannelRole::Guest {
3439 (
3440 false,
3441 live_kit
3442 .guest_token(
3443 &joined_room.room.livekit_room,
3444 &session.user_id().to_string(),
3445 )
3446 .trace_err()?,
3447 )
3448 } else {
3449 (
3450 true,
3451 live_kit
3452 .room_token(
3453 &joined_room.room.livekit_room,
3454 &session.user_id().to_string(),
3455 )
3456 .trace_err()?,
3457 )
3458 };
3459
3460 Some(LiveKitConnectionInfo {
3461 server_url: live_kit.url().into(),
3462 token,
3463 can_publish,
3464 })
3465 });
3466
3467 response.send(proto::JoinRoomResponse {
3468 room: Some(joined_room.room.clone()),
3469 channel_id: joined_room
3470 .channel
3471 .as_ref()
3472 .map(|channel| channel.id.to_proto()),
3473 live_kit_connection_info,
3474 })?;
3475
3476 let mut connection_pool = session.connection_pool().await;
3477 if let Some(membership_updated) = membership_updated {
3478 notify_membership_updated(
3479 &mut connection_pool,
3480 membership_updated,
3481 session.user_id(),
3482 &session.peer,
3483 );
3484 }
3485
3486 room_updated(&joined_room.room, &session.peer);
3487
3488 joined_room
3489 };
3490
3491 channel_updated(
3492 &joined_room.channel.context("channel not returned")?,
3493 &joined_room.room,
3494 &session.peer,
3495 &*session.connection_pool().await,
3496 );
3497
3498 update_user_contacts(session.user_id(), &session).await?;
3499 Ok(())
3500}
3501
3502/// Start editing the channel notes
3503async fn join_channel_buffer(
3504 request: proto::JoinChannelBuffer,
3505 response: Response<proto::JoinChannelBuffer>,
3506 session: Session,
3507) -> Result<()> {
3508 let db = session.db().await;
3509 let channel_id = ChannelId::from_proto(request.channel_id);
3510
3511 let open_response = db
3512 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3513 .await?;
3514
3515 let collaborators = open_response.collaborators.clone();
3516 response.send(open_response)?;
3517
3518 let update = UpdateChannelBufferCollaborators {
3519 channel_id: channel_id.to_proto(),
3520 collaborators: collaborators.clone(),
3521 };
3522 channel_buffer_updated(
3523 session.connection_id,
3524 collaborators
3525 .iter()
3526 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3527 &update,
3528 &session.peer,
3529 );
3530
3531 Ok(())
3532}
3533
3534/// Edit the channel notes
3535async fn update_channel_buffer(
3536 request: proto::UpdateChannelBuffer,
3537 session: Session,
3538) -> Result<()> {
3539 let db = session.db().await;
3540 let channel_id = ChannelId::from_proto(request.channel_id);
3541
3542 let (collaborators, epoch, version) = db
3543 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3544 .await?;
3545
3546 channel_buffer_updated(
3547 session.connection_id,
3548 collaborators.clone(),
3549 &proto::UpdateChannelBuffer {
3550 channel_id: channel_id.to_proto(),
3551 operations: request.operations,
3552 },
3553 &session.peer,
3554 );
3555
3556 let pool = &*session.connection_pool().await;
3557
3558 let non_collaborators =
3559 pool.channel_connection_ids(channel_id)
3560 .filter_map(|(connection_id, _)| {
3561 if collaborators.contains(&connection_id) {
3562 None
3563 } else {
3564 Some(connection_id)
3565 }
3566 });
3567
3568 broadcast(None, non_collaborators, |peer_id| {
3569 session.peer.send(
3570 peer_id,
3571 proto::UpdateChannels {
3572 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3573 channel_id: channel_id.to_proto(),
3574 epoch: epoch as u64,
3575 version: version.clone(),
3576 }],
3577 ..Default::default()
3578 },
3579 )
3580 });
3581
3582 Ok(())
3583}
3584
3585/// Rejoin the channel notes after a connection blip
3586async fn rejoin_channel_buffers(
3587 request: proto::RejoinChannelBuffers,
3588 response: Response<proto::RejoinChannelBuffers>,
3589 session: Session,
3590) -> Result<()> {
3591 let db = session.db().await;
3592 let buffers = db
3593 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3594 .await?;
3595
3596 for rejoined_buffer in &buffers {
3597 let collaborators_to_notify = rejoined_buffer
3598 .buffer
3599 .collaborators
3600 .iter()
3601 .filter_map(|c| Some(c.peer_id?.into()));
3602 channel_buffer_updated(
3603 session.connection_id,
3604 collaborators_to_notify,
3605 &proto::UpdateChannelBufferCollaborators {
3606 channel_id: rejoined_buffer.buffer.channel_id,
3607 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3608 },
3609 &session.peer,
3610 );
3611 }
3612
3613 response.send(proto::RejoinChannelBuffersResponse {
3614 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3615 })?;
3616
3617 Ok(())
3618}
3619
3620/// Stop editing the channel notes
3621async fn leave_channel_buffer(
3622 request: proto::LeaveChannelBuffer,
3623 response: Response<proto::LeaveChannelBuffer>,
3624 session: Session,
3625) -> Result<()> {
3626 let db = session.db().await;
3627 let channel_id = ChannelId::from_proto(request.channel_id);
3628
3629 let left_buffer = db
3630 .leave_channel_buffer(channel_id, session.connection_id)
3631 .await?;
3632
3633 response.send(Ack {})?;
3634
3635 channel_buffer_updated(
3636 session.connection_id,
3637 left_buffer.connections,
3638 &proto::UpdateChannelBufferCollaborators {
3639 channel_id: channel_id.to_proto(),
3640 collaborators: left_buffer.collaborators,
3641 },
3642 &session.peer,
3643 );
3644
3645 Ok(())
3646}
3647
3648fn channel_buffer_updated<T: EnvelopedMessage>(
3649 sender_id: ConnectionId,
3650 collaborators: impl IntoIterator<Item = ConnectionId>,
3651 message: &T,
3652 peer: &Peer,
3653) {
3654 broadcast(Some(sender_id), collaborators, |peer_id| {
3655 peer.send(peer_id, message.clone())
3656 });
3657}
3658
3659fn send_notifications(
3660 connection_pool: &ConnectionPool,
3661 peer: &Peer,
3662 notifications: db::NotificationBatch,
3663) {
3664 for (user_id, notification) in notifications {
3665 for connection_id in connection_pool.user_connection_ids(user_id) {
3666 if let Err(error) = peer.send(
3667 connection_id,
3668 proto::AddNotification {
3669 notification: Some(notification.clone()),
3670 },
3671 ) {
3672 tracing::error!(
3673 "failed to send notification to {:?} {}",
3674 connection_id,
3675 error
3676 );
3677 }
3678 }
3679 }
3680}
3681
3682/// Send a message to the channel
3683async fn send_channel_message(
3684 request: proto::SendChannelMessage,
3685 response: Response<proto::SendChannelMessage>,
3686 session: Session,
3687) -> Result<()> {
3688 // Validate the message body.
3689 let body = request.body.trim().to_string();
3690 if body.len() > MAX_MESSAGE_LEN {
3691 return Err(anyhow!("message is too long"))?;
3692 }
3693 if body.is_empty() {
3694 return Err(anyhow!("message can't be blank"))?;
3695 }
3696
3697 // TODO: adjust mentions if body is trimmed
3698
3699 let timestamp = OffsetDateTime::now_utc();
3700 let nonce = request.nonce.context("nonce can't be blank")?;
3701
3702 let channel_id = ChannelId::from_proto(request.channel_id);
3703 let CreatedChannelMessage {
3704 message_id,
3705 participant_connection_ids,
3706 notifications,
3707 } = session
3708 .db()
3709 .await
3710 .create_channel_message(
3711 channel_id,
3712 session.user_id(),
3713 &body,
3714 &request.mentions,
3715 timestamp,
3716 nonce.clone().into(),
3717 request.reply_to_message_id.map(MessageId::from_proto),
3718 )
3719 .await?;
3720
3721 let message = proto::ChannelMessage {
3722 sender_id: session.user_id().to_proto(),
3723 id: message_id.to_proto(),
3724 body,
3725 mentions: request.mentions,
3726 timestamp: timestamp.unix_timestamp() as u64,
3727 nonce: Some(nonce),
3728 reply_to_message_id: request.reply_to_message_id,
3729 edited_at: None,
3730 };
3731 broadcast(
3732 Some(session.connection_id),
3733 participant_connection_ids.clone(),
3734 |connection| {
3735 session.peer.send(
3736 connection,
3737 proto::ChannelMessageSent {
3738 channel_id: channel_id.to_proto(),
3739 message: Some(message.clone()),
3740 },
3741 )
3742 },
3743 );
3744 response.send(proto::SendChannelMessageResponse {
3745 message: Some(message),
3746 })?;
3747
3748 let pool = &*session.connection_pool().await;
3749 let non_participants =
3750 pool.channel_connection_ids(channel_id)
3751 .filter_map(|(connection_id, _)| {
3752 if participant_connection_ids.contains(&connection_id) {
3753 None
3754 } else {
3755 Some(connection_id)
3756 }
3757 });
3758 broadcast(None, non_participants, |peer_id| {
3759 session.peer.send(
3760 peer_id,
3761 proto::UpdateChannels {
3762 latest_channel_message_ids: vec![proto::ChannelMessageId {
3763 channel_id: channel_id.to_proto(),
3764 message_id: message_id.to_proto(),
3765 }],
3766 ..Default::default()
3767 },
3768 )
3769 });
3770 send_notifications(pool, &session.peer, notifications);
3771
3772 Ok(())
3773}
3774
3775/// Delete a channel message
3776async fn remove_channel_message(
3777 request: proto::RemoveChannelMessage,
3778 response: Response<proto::RemoveChannelMessage>,
3779 session: Session,
3780) -> Result<()> {
3781 let channel_id = ChannelId::from_proto(request.channel_id);
3782 let message_id = MessageId::from_proto(request.message_id);
3783 let (connection_ids, existing_notification_ids) = session
3784 .db()
3785 .await
3786 .remove_channel_message(channel_id, message_id, session.user_id())
3787 .await?;
3788
3789 broadcast(
3790 Some(session.connection_id),
3791 connection_ids,
3792 move |connection| {
3793 session.peer.send(connection, request.clone())?;
3794
3795 for notification_id in &existing_notification_ids {
3796 session.peer.send(
3797 connection,
3798 proto::DeleteNotification {
3799 notification_id: (*notification_id).to_proto(),
3800 },
3801 )?;
3802 }
3803
3804 Ok(())
3805 },
3806 );
3807 response.send(proto::Ack {})?;
3808 Ok(())
3809}
3810
3811async fn update_channel_message(
3812 request: proto::UpdateChannelMessage,
3813 response: Response<proto::UpdateChannelMessage>,
3814 session: Session,
3815) -> Result<()> {
3816 let channel_id = ChannelId::from_proto(request.channel_id);
3817 let message_id = MessageId::from_proto(request.message_id);
3818 let updated_at = OffsetDateTime::now_utc();
3819 let UpdatedChannelMessage {
3820 message_id,
3821 participant_connection_ids,
3822 notifications,
3823 reply_to_message_id,
3824 timestamp,
3825 deleted_mention_notification_ids,
3826 updated_mention_notifications,
3827 } = session
3828 .db()
3829 .await
3830 .update_channel_message(
3831 channel_id,
3832 message_id,
3833 session.user_id(),
3834 request.body.as_str(),
3835 &request.mentions,
3836 updated_at,
3837 )
3838 .await?;
3839
3840 let nonce = request.nonce.clone().context("nonce can't be blank")?;
3841
3842 let message = proto::ChannelMessage {
3843 sender_id: session.user_id().to_proto(),
3844 id: message_id.to_proto(),
3845 body: request.body.clone(),
3846 mentions: request.mentions.clone(),
3847 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3848 nonce: Some(nonce),
3849 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3850 edited_at: Some(updated_at.unix_timestamp() as u64),
3851 };
3852
3853 response.send(proto::Ack {})?;
3854
3855 let pool = &*session.connection_pool().await;
3856 broadcast(
3857 Some(session.connection_id),
3858 participant_connection_ids,
3859 |connection| {
3860 session.peer.send(
3861 connection,
3862 proto::ChannelMessageUpdate {
3863 channel_id: channel_id.to_proto(),
3864 message: Some(message.clone()),
3865 },
3866 )?;
3867
3868 for notification_id in &deleted_mention_notification_ids {
3869 session.peer.send(
3870 connection,
3871 proto::DeleteNotification {
3872 notification_id: (*notification_id).to_proto(),
3873 },
3874 )?;
3875 }
3876
3877 for notification in &updated_mention_notifications {
3878 session.peer.send(
3879 connection,
3880 proto::UpdateNotification {
3881 notification: Some(notification.clone()),
3882 },
3883 )?;
3884 }
3885
3886 Ok(())
3887 },
3888 );
3889
3890 send_notifications(pool, &session.peer, notifications);
3891
3892 Ok(())
3893}
3894
3895/// Mark a channel message as read
3896async fn acknowledge_channel_message(
3897 request: proto::AckChannelMessage,
3898 session: Session,
3899) -> Result<()> {
3900 let channel_id = ChannelId::from_proto(request.channel_id);
3901 let message_id = MessageId::from_proto(request.message_id);
3902 let notifications = session
3903 .db()
3904 .await
3905 .observe_channel_message(channel_id, session.user_id(), message_id)
3906 .await?;
3907 send_notifications(
3908 &*session.connection_pool().await,
3909 &session.peer,
3910 notifications,
3911 );
3912 Ok(())
3913}
3914
3915/// Mark a buffer version as synced
3916async fn acknowledge_buffer_version(
3917 request: proto::AckBufferOperation,
3918 session: Session,
3919) -> Result<()> {
3920 let buffer_id = BufferId::from_proto(request.buffer_id);
3921 session
3922 .db()
3923 .await
3924 .observe_buffer_version(
3925 buffer_id,
3926 session.user_id(),
3927 request.epoch as i32,
3928 &request.version,
3929 )
3930 .await?;
3931 Ok(())
3932}
3933
3934/// Get a Supermaven API key for the user
3935async fn get_supermaven_api_key(
3936 _request: proto::GetSupermavenApiKey,
3937 response: Response<proto::GetSupermavenApiKey>,
3938 session: Session,
3939) -> Result<()> {
3940 let user_id: String = session.user_id().to_string();
3941 if !session.is_staff() {
3942 return Err(anyhow!("supermaven not enabled for this account"))?;
3943 }
3944
3945 let email = session.email().context("user must have an email")?;
3946
3947 let supermaven_admin_api = session
3948 .supermaven_client
3949 .as_ref()
3950 .context("supermaven not configured")?;
3951
3952 let result = supermaven_admin_api
3953 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3954 .await?;
3955
3956 response.send(proto::GetSupermavenApiKeyResponse {
3957 api_key: result.api_key,
3958 })?;
3959
3960 Ok(())
3961}
3962
3963/// Start receiving chat updates for a channel
3964async fn join_channel_chat(
3965 request: proto::JoinChannelChat,
3966 response: Response<proto::JoinChannelChat>,
3967 session: Session,
3968) -> Result<()> {
3969 let channel_id = ChannelId::from_proto(request.channel_id);
3970
3971 let db = session.db().await;
3972 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3973 .await?;
3974 let messages = db
3975 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3976 .await?;
3977 response.send(proto::JoinChannelChatResponse {
3978 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3979 messages,
3980 })?;
3981 Ok(())
3982}
3983
3984/// Stop receiving chat updates for a channel
3985async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3986 let channel_id = ChannelId::from_proto(request.channel_id);
3987 session
3988 .db()
3989 .await
3990 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3991 .await?;
3992 Ok(())
3993}
3994
3995/// Retrieve the chat history for a channel
3996async fn get_channel_messages(
3997 request: proto::GetChannelMessages,
3998 response: Response<proto::GetChannelMessages>,
3999 session: Session,
4000) -> Result<()> {
4001 let channel_id = ChannelId::from_proto(request.channel_id);
4002 let messages = session
4003 .db()
4004 .await
4005 .get_channel_messages(
4006 channel_id,
4007 session.user_id(),
4008 MESSAGE_COUNT_PER_PAGE,
4009 Some(MessageId::from_proto(request.before_message_id)),
4010 )
4011 .await?;
4012 response.send(proto::GetChannelMessagesResponse {
4013 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4014 messages,
4015 })?;
4016 Ok(())
4017}
4018
4019/// Retrieve specific chat messages
4020async fn get_channel_messages_by_id(
4021 request: proto::GetChannelMessagesById,
4022 response: Response<proto::GetChannelMessagesById>,
4023 session: Session,
4024) -> Result<()> {
4025 let message_ids = request
4026 .message_ids
4027 .iter()
4028 .map(|id| MessageId::from_proto(*id))
4029 .collect::<Vec<_>>();
4030 let messages = session
4031 .db()
4032 .await
4033 .get_channel_messages_by_id(session.user_id(), &message_ids)
4034 .await?;
4035 response.send(proto::GetChannelMessagesResponse {
4036 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4037 messages,
4038 })?;
4039 Ok(())
4040}
4041
4042/// Retrieve the current users notifications
4043async fn get_notifications(
4044 request: proto::GetNotifications,
4045 response: Response<proto::GetNotifications>,
4046 session: Session,
4047) -> Result<()> {
4048 let notifications = session
4049 .db()
4050 .await
4051 .get_notifications(
4052 session.user_id(),
4053 NOTIFICATION_COUNT_PER_PAGE,
4054 request.before_id.map(db::NotificationId::from_proto),
4055 )
4056 .await?;
4057 response.send(proto::GetNotificationsResponse {
4058 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4059 notifications,
4060 })?;
4061 Ok(())
4062}
4063
4064/// Mark notifications as read
4065async fn mark_notification_as_read(
4066 request: proto::MarkNotificationRead,
4067 response: Response<proto::MarkNotificationRead>,
4068 session: Session,
4069) -> Result<()> {
4070 let database = &session.db().await;
4071 let notifications = database
4072 .mark_notification_as_read_by_id(
4073 session.user_id(),
4074 NotificationId::from_proto(request.notification_id),
4075 )
4076 .await?;
4077 send_notifications(
4078 &*session.connection_pool().await,
4079 &session.peer,
4080 notifications,
4081 );
4082 response.send(proto::Ack {})?;
4083 Ok(())
4084}
4085
4086/// Get the current users information
4087async fn get_private_user_info(
4088 _request: proto::GetPrivateUserInfo,
4089 response: Response<proto::GetPrivateUserInfo>,
4090 session: Session,
4091) -> Result<()> {
4092 let db = session.db().await;
4093
4094 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4095 let user = db
4096 .get_user_by_id(session.user_id())
4097 .await?
4098 .context("user not found")?;
4099 let flags = db.get_user_flags(session.user_id()).await?;
4100
4101 response.send(proto::GetPrivateUserInfoResponse {
4102 metrics_id,
4103 staff: user.admin,
4104 flags,
4105 accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4106 })?;
4107 Ok(())
4108}
4109
4110/// Accept the terms of service (tos) on behalf of the current user
4111async fn accept_terms_of_service(
4112 _request: proto::AcceptTermsOfService,
4113 response: Response<proto::AcceptTermsOfService>,
4114 session: Session,
4115) -> Result<()> {
4116 let db = session.db().await;
4117
4118 let accepted_tos_at = Utc::now();
4119 db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4120 .await?;
4121
4122 response.send(proto::AcceptTermsOfServiceResponse {
4123 accepted_tos_at: accepted_tos_at.timestamp() as u64,
4124 })?;
4125 Ok(())
4126}
4127
4128async fn get_llm_api_token(
4129 _request: proto::GetLlmToken,
4130 response: Response<proto::GetLlmToken>,
4131 session: Session,
4132) -> Result<()> {
4133 let db = session.db().await;
4134
4135 let flags = db.get_user_flags(session.user_id()).await?;
4136
4137 let user_id = session.user_id();
4138 let user = db
4139 .get_user_by_id(user_id)
4140 .await?
4141 .with_context(|| format!("user {user_id} not found"))?;
4142
4143 if user.accepted_tos_at.is_none() {
4144 Err(anyhow!("terms of service not accepted"))?
4145 }
4146
4147 let stripe_client = session
4148 .app_state
4149 .stripe_client
4150 .as_ref()
4151 .context("failed to retrieve Stripe client")?;
4152
4153 let stripe_billing = session
4154 .app_state
4155 .stripe_billing
4156 .as_ref()
4157 .context("failed to retrieve Stripe billing object")?;
4158
4159 let billing_customer = if let Some(billing_customer) =
4160 db.get_billing_customer_by_user_id(user.id).await?
4161 {
4162 billing_customer
4163 } else {
4164 let customer_id = stripe_billing
4165 .find_or_create_customer_by_email(user.email_address.as_deref())
4166 .await?;
4167
4168 find_or_create_billing_customer(&session.app_state, stripe_client.as_ref(), &customer_id)
4169 .await?
4170 .context("billing customer not found")?
4171 };
4172
4173 let billing_subscription =
4174 if let Some(billing_subscription) = db.get_active_billing_subscription(user.id).await? {
4175 billing_subscription
4176 } else {
4177 let stripe_customer_id =
4178 StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
4179
4180 let stripe_subscription = stripe_billing
4181 .subscribe_to_zed_free(stripe_customer_id)
4182 .await?;
4183
4184 db.create_billing_subscription(&db::CreateBillingSubscriptionParams {
4185 billing_customer_id: billing_customer.id,
4186 kind: Some(SubscriptionKind::ZedFree),
4187 stripe_subscription_id: stripe_subscription.id.to_string(),
4188 stripe_subscription_status: stripe_subscription.status.into(),
4189 stripe_cancellation_reason: None,
4190 stripe_current_period_start: Some(stripe_subscription.current_period_start),
4191 stripe_current_period_end: Some(stripe_subscription.current_period_end),
4192 })
4193 .await?
4194 };
4195
4196 let billing_preferences = db.get_billing_preferences(user.id).await?;
4197
4198 let token = LlmTokenClaims::create(
4199 &user,
4200 session.is_staff(),
4201 billing_customer,
4202 billing_preferences,
4203 &flags,
4204 billing_subscription,
4205 session.system_id.clone(),
4206 &session.app_state.config,
4207 )?;
4208 response.send(proto::GetLlmTokenResponse { token })?;
4209 Ok(())
4210}
4211
4212fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4213 let message = match message {
4214 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
4215 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
4216 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
4217 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
4218 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4219 code: frame.code.into(),
4220 reason: frame.reason.as_str().to_owned().into(),
4221 })),
4222 // We should never receive a frame while reading the message, according
4223 // to the `tungstenite` maintainers:
4224 //
4225 // > It cannot occur when you read messages from the WebSocket, but it
4226 // > can be used when you want to send the raw frames (e.g. you want to
4227 // > send the frames to the WebSocket without composing the full message first).
4228 // >
4229 // > — https://github.com/snapview/tungstenite-rs/issues/268
4230 TungsteniteMessage::Frame(_) => {
4231 bail!("received an unexpected frame while reading the message")
4232 }
4233 };
4234
4235 Ok(message)
4236}
4237
4238fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4239 match message {
4240 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
4241 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
4242 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
4243 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
4244 AxumMessage::Close(frame) => {
4245 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4246 code: frame.code.into(),
4247 reason: frame.reason.as_ref().into(),
4248 }))
4249 }
4250 }
4251}
4252
4253fn notify_membership_updated(
4254 connection_pool: &mut ConnectionPool,
4255 result: MembershipUpdated,
4256 user_id: UserId,
4257 peer: &Peer,
4258) {
4259 for membership in &result.new_channels.channel_memberships {
4260 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4261 }
4262 for channel_id in &result.removed_channels {
4263 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4264 }
4265
4266 let user_channels_update = proto::UpdateUserChannels {
4267 channel_memberships: result
4268 .new_channels
4269 .channel_memberships
4270 .iter()
4271 .map(|cm| proto::ChannelMembership {
4272 channel_id: cm.channel_id.to_proto(),
4273 role: cm.role.into(),
4274 })
4275 .collect(),
4276 ..Default::default()
4277 };
4278
4279 let mut update = build_channels_update(result.new_channels);
4280 update.delete_channels = result
4281 .removed_channels
4282 .into_iter()
4283 .map(|id| id.to_proto())
4284 .collect();
4285 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4286
4287 for connection_id in connection_pool.user_connection_ids(user_id) {
4288 peer.send(connection_id, user_channels_update.clone())
4289 .trace_err();
4290 peer.send(connection_id, update.clone()).trace_err();
4291 }
4292}
4293
4294fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4295 proto::UpdateUserChannels {
4296 channel_memberships: channels
4297 .channel_memberships
4298 .iter()
4299 .map(|m| proto::ChannelMembership {
4300 channel_id: m.channel_id.to_proto(),
4301 role: m.role.into(),
4302 })
4303 .collect(),
4304 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4305 observed_channel_message_id: channels.observed_channel_messages.clone(),
4306 }
4307}
4308
4309fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4310 let mut update = proto::UpdateChannels::default();
4311
4312 for channel in channels.channels {
4313 update.channels.push(channel.to_proto());
4314 }
4315
4316 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4317 update.latest_channel_message_ids = channels.latest_channel_messages;
4318
4319 for (channel_id, participants) in channels.channel_participants {
4320 update
4321 .channel_participants
4322 .push(proto::ChannelParticipants {
4323 channel_id: channel_id.to_proto(),
4324 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4325 });
4326 }
4327
4328 for channel in channels.invited_channels {
4329 update.channel_invitations.push(channel.to_proto());
4330 }
4331
4332 update
4333}
4334
4335fn build_initial_contacts_update(
4336 contacts: Vec<db::Contact>,
4337 pool: &ConnectionPool,
4338) -> proto::UpdateContacts {
4339 let mut update = proto::UpdateContacts::default();
4340
4341 for contact in contacts {
4342 match contact {
4343 db::Contact::Accepted { user_id, busy } => {
4344 update.contacts.push(contact_for_user(user_id, busy, pool));
4345 }
4346 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4347 db::Contact::Incoming { user_id } => {
4348 update
4349 .incoming_requests
4350 .push(proto::IncomingContactRequest {
4351 requester_id: user_id.to_proto(),
4352 })
4353 }
4354 }
4355 }
4356
4357 update
4358}
4359
4360fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4361 proto::Contact {
4362 user_id: user_id.to_proto(),
4363 online: pool.is_user_online(user_id),
4364 busy,
4365 }
4366}
4367
4368fn room_updated(room: &proto::Room, peer: &Peer) {
4369 broadcast(
4370 None,
4371 room.participants
4372 .iter()
4373 .filter_map(|participant| Some(participant.peer_id?.into())),
4374 |peer_id| {
4375 peer.send(
4376 peer_id,
4377 proto::RoomUpdated {
4378 room: Some(room.clone()),
4379 },
4380 )
4381 },
4382 );
4383}
4384
4385fn channel_updated(
4386 channel: &db::channel::Model,
4387 room: &proto::Room,
4388 peer: &Peer,
4389 pool: &ConnectionPool,
4390) {
4391 let participants = room
4392 .participants
4393 .iter()
4394 .map(|p| p.user_id)
4395 .collect::<Vec<_>>();
4396
4397 broadcast(
4398 None,
4399 pool.channel_connection_ids(channel.root_id())
4400 .filter_map(|(channel_id, role)| {
4401 role.can_see_channel(channel.visibility)
4402 .then_some(channel_id)
4403 }),
4404 |peer_id| {
4405 peer.send(
4406 peer_id,
4407 proto::UpdateChannels {
4408 channel_participants: vec![proto::ChannelParticipants {
4409 channel_id: channel.id.to_proto(),
4410 participant_user_ids: participants.clone(),
4411 }],
4412 ..Default::default()
4413 },
4414 )
4415 },
4416 );
4417}
4418
4419async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4420 let db = session.db().await;
4421
4422 let contacts = db.get_contacts(user_id).await?;
4423 let busy = db.is_user_busy(user_id).await?;
4424
4425 let pool = session.connection_pool().await;
4426 let updated_contact = contact_for_user(user_id, busy, &pool);
4427 for contact in contacts {
4428 if let db::Contact::Accepted {
4429 user_id: contact_user_id,
4430 ..
4431 } = contact
4432 {
4433 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4434 session
4435 .peer
4436 .send(
4437 contact_conn_id,
4438 proto::UpdateContacts {
4439 contacts: vec![updated_contact.clone()],
4440 remove_contacts: Default::default(),
4441 incoming_requests: Default::default(),
4442 remove_incoming_requests: Default::default(),
4443 outgoing_requests: Default::default(),
4444 remove_outgoing_requests: Default::default(),
4445 },
4446 )
4447 .trace_err();
4448 }
4449 }
4450 }
4451 Ok(())
4452}
4453
/// Removes `connection_id` from the room it is currently in and fans out the
/// resulting notifications.
///
/// If `leave_room` reports that there was nothing to leave, this returns
/// `Ok(())` without sending anything. Otherwise it:
///
/// 1. notifies the remaining connections of every project that was left
///    (`project_left`),
/// 2. broadcasts the updated room and, when the result carries a channel, the
///    channel's participant list,
/// 3. sends `CallCanceled` to every connection of each user whose call was
///    canceled,
/// 4. refreshes the contact entry of the session's user and of those users,
/// 5. when a LiveKit client is configured, removes the session's user from the
///    LiveKit room and deletes that room if the result was marked `deleted`.
///
/// # Errors
///
/// Propagates failures from `leave_room` and `update_user_contacts`. Failures
/// to send a message or to call LiveKit are handed to `trace_err` and do not
/// abort the function.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Values pulled out of the `leave_room` result. They are declared here and
    // assigned inside the branch below so they remain usable after
    // `left_room` has gone out of scope.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    // NOTE(review): `left_room` (and the temporary database handle) are
    // dropped when this `if let` ends; if either is a lock guard, everything
    // after the branch runs with it released — confirm before restructuring.
    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Taken before `room` itself is moved out, so the `room` value that is
        // broadcast below has its `livekit_room` field left at the default.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        // Nothing was left, so there is nobody to notify.
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so that `pool` is dropped before the awaits that follow.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            // Users whose call was canceled also get their contact entry refreshed.
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    // LiveKit cleanup is best-effort: errors are traced, not returned.
    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4527
4528async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4529 let left_channel_buffers = session
4530 .db()
4531 .await
4532 .leave_channel_buffers(session.connection_id)
4533 .await?;
4534
4535 for left_buffer in left_channel_buffers {
4536 channel_buffer_updated(
4537 session.connection_id,
4538 left_buffer.connections,
4539 &proto::UpdateChannelBufferCollaborators {
4540 channel_id: left_buffer.channel_id.to_proto(),
4541 collaborators: left_buffer.collaborators,
4542 },
4543 &session.peer,
4544 );
4545 }
4546
4547 Ok(())
4548}
4549
4550fn project_left(project: &db::LeftProject, session: &Session) {
4551 for connection_id in &project.connection_ids {
4552 if project.should_unshare {
4553 session
4554 .peer
4555 .send(
4556 *connection_id,
4557 proto::UnshareProject {
4558 project_id: project.id.to_proto(),
4559 },
4560 )
4561 .trace_err();
4562 } else {
4563 session
4564 .peer
4565 .send(
4566 *connection_id,
4567 proto::RemoveProjectCollaborator {
4568 project_id: project.id.to_proto(),
4569 peer_id: Some(session.connection_id.into()),
4570 },
4571 )
4572 .trace_err();
4573 }
4574 }
4575}
4576
/// Extension trait for turning a fallible value into an `Option` instead of
/// propagating its error.
///
/// The implementation for `Result` in this file logs the error with
/// `tracing::error!` and yields `None`.
pub trait ResultExt {
    /// The success type produced by [`ResultExt::trace_err`].
    type Ok;

    /// Consumes `self` and returns the success value, if there is one.
    fn trace_err(self) -> Option<Self::Ok>;
}
4582
4583impl<T, E> ResultExt for Result<T, E>
4584where
4585 E: std::fmt::Debug,
4586{
4587 type Ok = T;
4588
4589 #[track_caller]
4590 fn trace_err(self) -> Option<T> {
4591 match self {
4592 Ok(value) => Some(value),
4593 Err(error) => {
4594 tracing::error!("{:?}", error);
4595 None
4596 }
4597 }
4598 }
4599}