1mod connection_pool;
2
3use crate::api::CloudflareIpCountryHeader;
4use crate::llm::LlmTokenClaims;
5use crate::{
6 auth,
7 db::{
8 self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
9 CreatedChannelMessage, Database, DevServerId, DevServerProjectId, InviteMemberResult,
10 MembershipUpdated, MessageId, NotificationId, PrincipalId, Project, ProjectId,
11 RejoinedProject, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
12 ServerId, UpdatedChannelMessage, User, UserId,
13 },
14 executor::Executor,
15 AppState, Config, Error, RateLimit, Result,
16};
17use anyhow::{anyhow, bail, Context as _};
18use async_tungstenite::tungstenite::{
19 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
20};
21use axum::{
22 body::Body,
23 extract::{
24 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
25 ConnectInfo, WebSocketUpgrade,
26 },
27 headers::{Header, HeaderName},
28 http::StatusCode,
29 middleware,
30 response::IntoResponse,
31 routing::get,
32 Extension, Router, TypedHeader,
33};
34use chrono::Utc;
35use collections::{HashMap, HashSet};
36pub use connection_pool::{ConnectionPool, ZedVersion};
37use core::fmt::{self, Debug, Formatter};
38use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
39use sha2::Digest;
40use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
41
42use futures::{
43 channel::oneshot,
44 future::{self, BoxFuture},
45 stream::FuturesUnordered,
46 FutureExt, SinkExt, StreamExt, TryStreamExt,
47};
48use http_client::IsahcHttpClient;
49use prometheus::{register_int_gauge, IntGauge};
50use rpc::{
51 proto::{
52 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
53 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
54 },
55 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
56};
57use semantic_version::SemanticVersion;
58use serde::{Serialize, Serializer};
59use std::{
60 any::TypeId,
61 future::Future,
62 marker::PhantomData,
63 mem,
64 net::SocketAddr,
65 ops::{Deref, DerefMut},
66 rc::Rc,
67 sync::{
68 atomic::{AtomicBool, Ordering::SeqCst},
69 Arc, OnceLock,
70 },
71 time::{Duration, Instant},
72};
73use time::OffsetDateTime;
74use tokio::sync::{watch, MutexGuard, Semaphore};
75use tower::ServiceBuilder;
76use tracing::{
77 field::{self},
78 info_span, instrument, Instrument,
79};
80
81use self::connection_pool::VersionedMessage;
82
/// Reconnection grace period. Not used in this chunk — presumably how long a
/// disconnected client may take to reconnect before its state is discarded
/// (TODO confirm against the call sites).
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// Delay before `Server::start` begins cleaning up stale resources.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully. Waiting 15s
/// leaves a margin beyond that, so by the time cleanup starts the old pods
/// should be gone and their leftover resources can safely be reclaimed.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

/// Page size for message history requests (not used in this chunk; presumably
/// channel chat history — TODO confirm).
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Maximum message length (not used in this chunk; whether this counts bytes
/// or chars is not visible here — TODO confirm).
const MAX_MESSAGE_LEN: usize = 1024;
/// Page size for notification listings (not used in this chunk).
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;

/// Type-erased handler stored in `Server::handlers`: it receives a boxed
/// envelope plus the connection's [`Session`] and returns a boxed, `Send`
/// future that performs the work.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
94
95struct Response<R> {
96 peer: Arc<Peer>,
97 receipt: Receipt<R>,
98 responded: Arc<AtomicBool>,
99}
100
101impl<R: RequestMessage> Response<R> {
102 fn send(self, payload: R::Response) -> Result<()> {
103 self.responded.store(true, SeqCst);
104 self.peer.respond(self.receipt, payload)?;
105 Ok(())
106 }
107}
108
109#[derive(Clone, Debug)]
110pub enum Principal {
111 User(User),
112 Impersonated { user: User, admin: User },
113 DevServer(dev_server::Model),
114}
115
116impl Principal {
117 fn update_span(&self, span: &tracing::Span) {
118 match &self {
119 Principal::User(user) => {
120 span.record("user_id", &user.id.0);
121 span.record("login", &user.github_login);
122 }
123 Principal::Impersonated { user, admin } => {
124 span.record("user_id", &user.id.0);
125 span.record("login", &user.github_login);
126 span.record("impersonator", &admin.github_login);
127 }
128 Principal::DevServer(dev_server) => {
129 span.record("dev_server_id", &dev_server.id.0);
130 }
131 }
132 }
133}
134
/// Per-connection state handed (by clone) to every message handler.
#[derive(Clone)]
struct Session {
    /// Who is on the other end of this connection.
    principal: Principal,
    /// This connection's id within `peer`.
    connection_id: ConnectionId,
    /// Database handle behind an async mutex, created once per connection;
    /// acquire it through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Pool shared with the owning `Server`; lock it through
    /// `Session::connection_pool` to get a `!Send` guard.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Present only when a Supermaven admin API key is configured.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    /// HTTP client built per connection with a "Zed Server/<version>" User-Agent.
    http_client: Arc<IsahcHttpClient>,
    /// The GeoIP country code for the user.
    // NOTE(review): presumably not read by any handler yet, hence the allow.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    _executor: Executor,
}
150
151impl Session {
152 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
153 #[cfg(test)]
154 tokio::task::yield_now().await;
155 let guard = self.db.lock().await;
156 #[cfg(test)]
157 tokio::task::yield_now().await;
158 guard
159 }
160
161 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
162 #[cfg(test)]
163 tokio::task::yield_now().await;
164 let guard = self.connection_pool.lock();
165 ConnectionPoolGuard {
166 guard,
167 _not_send: PhantomData,
168 }
169 }
170
171 fn for_user(self) -> Option<UserSession> {
172 UserSession::new(self)
173 }
174
175 fn for_dev_server(self) -> Option<DevServerSession> {
176 DevServerSession::new(self)
177 }
178
179 fn user_id(&self) -> Option<UserId> {
180 match &self.principal {
181 Principal::User(user) => Some(user.id),
182 Principal::Impersonated { user, .. } => Some(user.id),
183 Principal::DevServer(_) => None,
184 }
185 }
186
187 fn is_staff(&self) -> bool {
188 match &self.principal {
189 Principal::User(user) => user.admin,
190 Principal::Impersonated { .. } => true,
191 Principal::DevServer(_) => false,
192 }
193 }
194
195 pub async fn current_plan(&self, db: MutexGuard<'_, DbHandle>) -> anyhow::Result<proto::Plan> {
196 if self.is_staff() {
197 return Ok(proto::Plan::ZedPro);
198 }
199
200 let Some(user_id) = self.user_id() else {
201 return Ok(proto::Plan::Free);
202 };
203
204 if db.has_active_billing_subscription(user_id).await? {
205 Ok(proto::Plan::ZedPro)
206 } else {
207 Ok(proto::Plan::Free)
208 }
209 }
210
211 fn dev_server_id(&self) -> Option<DevServerId> {
212 match &self.principal {
213 Principal::User(_) | Principal::Impersonated { .. } => None,
214 Principal::DevServer(dev_server) => Some(dev_server.id),
215 }
216 }
217
218 fn principal_id(&self) -> PrincipalId {
219 match &self.principal {
220 Principal::User(user) => PrincipalId::UserId(user.id),
221 Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
222 Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
223 }
224 }
225}
226
227impl Debug for Session {
228 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
229 let mut result = f.debug_struct("Session");
230 match &self.principal {
231 Principal::User(user) => {
232 result.field("user", &user.github_login);
233 }
234 Principal::Impersonated { user, admin } => {
235 result.field("user", &user.github_login);
236 result.field("impersonator", &admin.github_login);
237 }
238 Principal::DevServer(dev_server) => {
239 result.field("dev_server", &dev_server.id);
240 }
241 }
242 result.field("connection_id", &self.connection_id).finish()
243 }
244}
245
246struct UserSession(Session);
247
248impl UserSession {
249 pub fn new(s: Session) -> Option<Self> {
250 s.user_id().map(|_| UserSession(s))
251 }
252 pub fn user_id(&self) -> UserId {
253 self.0.user_id().unwrap()
254 }
255
256 pub fn email(&self) -> Option<String> {
257 match &self.0.principal {
258 Principal::User(user) => user.email_address.clone(),
259 Principal::Impersonated { user, .. } => user.email_address.clone(),
260 Principal::DevServer(..) => None,
261 }
262 }
263}
264
265impl Deref for UserSession {
266 type Target = Session;
267
268 fn deref(&self) -> &Self::Target {
269 &self.0
270 }
271}
272impl DerefMut for UserSession {
273 fn deref_mut(&mut self) -> &mut Self::Target {
274 &mut self.0
275 }
276}
277
278struct DevServerSession(Session);
279
280impl DevServerSession {
281 pub fn new(s: Session) -> Option<Self> {
282 s.dev_server_id().map(|_| DevServerSession(s))
283 }
284 pub fn dev_server_id(&self) -> DevServerId {
285 self.0.dev_server_id().unwrap()
286 }
287
288 fn dev_server(&self) -> &dev_server::Model {
289 match &self.0.principal {
290 Principal::DevServer(dev_server) => dev_server,
291 _ => unreachable!(),
292 }
293 }
294}
295
296impl Deref for DevServerSession {
297 type Target = Session;
298
299 fn deref(&self) -> &Self::Target {
300 &self.0
301 }
302}
303impl DerefMut for DevServerSession {
304 fn deref_mut(&mut self) -> &mut Self::Target {
305 &mut self.0
306 }
307}
308
309fn user_handler<M: RequestMessage, Fut>(
310 handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
311) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
312where
313 Fut: Send + Future<Output = Result<()>>,
314{
315 let handler = Arc::new(handler);
316 move |message, response, session| {
317 let handler = handler.clone();
318 Box::pin(async move {
319 if let Some(user_session) = session.for_user() {
320 Ok(handler(message, response, user_session).await?)
321 } else {
322 Err(Error::Internal(anyhow!(
323 "must be a user to call {}",
324 M::NAME
325 )))
326 }
327 })
328 }
329}
330
331fn dev_server_handler<M: RequestMessage, Fut>(
332 handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
333) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
334where
335 Fut: Send + Future<Output = Result<()>>,
336{
337 let handler = Arc::new(handler);
338 move |message, response, session| {
339 let handler = handler.clone();
340 Box::pin(async move {
341 if let Some(dev_server_session) = session.for_dev_server() {
342 Ok(handler(message, response, dev_server_session).await?)
343 } else {
344 Err(Error::Internal(anyhow!(
345 "must be a dev server to call {}",
346 M::NAME
347 )))
348 }
349 })
350 }
351}
352
353fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
354 handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
355) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
356where
357 InnertRetFut: Send + Future<Output = Result<()>>,
358{
359 let handler = Arc::new(handler);
360 move |message, session| {
361 let handler = handler.clone();
362 Box::pin(async move {
363 if let Some(user_session) = session.for_user() {
364 Ok(handler(message, user_session).await?)
365 } else {
366 Err(Error::Internal(anyhow!(
367 "must be a user to call {}",
368 M::NAME
369 )))
370 }
371 })
372 }
373}
374
375struct DbHandle(Arc<Database>);
376
377impl Deref for DbHandle {
378 type Target = Database;
379
380 fn deref(&self) -> &Self::Target {
381 self.0.as_ref()
382 }
383}
384
/// The collaboration RPC server: owns the peer, the connection pool shared with
/// every [`Session`], and the table of per-message-type handlers.
pub struct Server {
    /// Behind a mutex so it can be replaced in place; in this chunk only the
    /// test-only `reset` does so.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    /// Cloned into each connection's `Session` by `handle_connection`.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Handlers keyed by the `TypeId` of the message payload; filled in `new`.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Set to `true` by `teardown`; connection loops subscribe to it and exit
    /// when it changes.
    teardown: watch::Sender<bool>,
}
393
/// Lock guard over the shared [`ConnectionPool`], returned by
/// `Session::connection_pool`.
///
/// `_not_send` is a `PhantomData<Rc<()>>`, which makes the guard `!Send`.
/// Holding it across an `.await` therefore makes the enclosing future `!Send`,
/// which the compiler rejects wherever a `Send` future (such as a
/// `BoxFuture`) is required. This keeps the synchronous lock from being held
/// across suspension points in handlers.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}

/// Serializable snapshot of the server's peer and connection pool.
///
/// The pool stays locked (via the guard) for as long as the snapshot lives.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized as the value the guard dereferences to (the guard's `Deref`
    // impl lives outside this chunk).
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
405
406pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
407where
408 S: Serializer,
409 T: Deref<Target = U>,
410 U: Serialize,
411{
412 Serialize::serialize(value.deref(), serializer)
413}
414
415impl Server {
416 pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
417 let mut server = Self {
418 id: parking_lot::Mutex::new(id),
419 peer: Peer::new(id.0 as u32),
420 app_state: app_state.clone(),
421 connection_pool: Default::default(),
422 handlers: Default::default(),
423 teardown: watch::channel(false).0,
424 };
425
426 server
427 .add_request_handler(ping)
428 .add_request_handler(user_handler(create_room))
429 .add_request_handler(user_handler(join_room))
430 .add_request_handler(user_handler(rejoin_room))
431 .add_request_handler(user_handler(leave_room))
432 .add_request_handler(user_handler(set_room_participant_role))
433 .add_request_handler(user_handler(call))
434 .add_request_handler(user_handler(cancel_call))
435 .add_message_handler(user_message_handler(decline_call))
436 .add_request_handler(user_handler(update_participant_location))
437 .add_request_handler(user_handler(share_project))
438 .add_message_handler(unshare_project)
439 .add_request_handler(user_handler(join_project))
440 .add_request_handler(user_handler(join_hosted_project))
441 .add_request_handler(user_handler(rejoin_dev_server_projects))
442 .add_request_handler(user_handler(create_dev_server_project))
443 .add_request_handler(user_handler(update_dev_server_project))
444 .add_request_handler(user_handler(delete_dev_server_project))
445 .add_request_handler(user_handler(create_dev_server))
446 .add_request_handler(user_handler(regenerate_dev_server_token))
447 .add_request_handler(user_handler(rename_dev_server))
448 .add_request_handler(user_handler(delete_dev_server))
449 .add_request_handler(user_handler(list_remote_directory))
450 .add_request_handler(dev_server_handler(share_dev_server_project))
451 .add_request_handler(dev_server_handler(shutdown_dev_server))
452 .add_request_handler(dev_server_handler(reconnect_dev_server))
453 .add_message_handler(user_message_handler(leave_project))
454 .add_request_handler(update_project)
455 .add_request_handler(update_worktree)
456 .add_message_handler(start_language_server)
457 .add_message_handler(update_language_server)
458 .add_message_handler(update_diagnostic_summary)
459 .add_message_handler(update_worktree_settings)
460 .add_request_handler(user_handler(
461 forward_project_request_for_owner::<proto::TaskContextForLocation>,
462 ))
463 .add_request_handler(user_handler(
464 forward_project_request_for_owner::<proto::TaskTemplates>,
465 ))
466 .add_request_handler(user_handler(
467 forward_read_only_project_request::<proto::GetHover>,
468 ))
469 .add_request_handler(user_handler(
470 forward_read_only_project_request::<proto::GetDefinition>,
471 ))
472 .add_request_handler(user_handler(
473 forward_read_only_project_request::<proto::GetTypeDefinition>,
474 ))
475 .add_request_handler(user_handler(
476 forward_read_only_project_request::<proto::GetReferences>,
477 ))
478 .add_request_handler(user_handler(
479 forward_read_only_project_request::<proto::SearchProject>,
480 ))
481 .add_request_handler(user_handler(forward_find_search_candidates_request))
482 .add_request_handler(user_handler(
483 forward_read_only_project_request::<proto::GetDocumentHighlights>,
484 ))
485 .add_request_handler(user_handler(
486 forward_read_only_project_request::<proto::GetProjectSymbols>,
487 ))
488 .add_request_handler(user_handler(
489 forward_read_only_project_request::<proto::OpenBufferForSymbol>,
490 ))
491 .add_request_handler(user_handler(
492 forward_read_only_project_request::<proto::OpenBufferById>,
493 ))
494 .add_request_handler(user_handler(
495 forward_read_only_project_request::<proto::SynchronizeBuffers>,
496 ))
497 .add_request_handler(user_handler(
498 forward_read_only_project_request::<proto::InlayHints>,
499 ))
500 .add_request_handler(user_handler(
501 forward_read_only_project_request::<proto::OpenBufferByPath>,
502 ))
503 .add_request_handler(user_handler(
504 forward_mutating_project_request::<proto::GetCompletions>,
505 ))
506 .add_request_handler(user_handler(
507 forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
508 ))
509 .add_request_handler(user_handler(
510 forward_versioned_mutating_project_request::<proto::OpenNewBuffer>,
511 ))
512 .add_request_handler(user_handler(
513 forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
514 ))
515 .add_request_handler(user_handler(
516 forward_mutating_project_request::<proto::GetCodeActions>,
517 ))
518 .add_request_handler(user_handler(
519 forward_mutating_project_request::<proto::ApplyCodeAction>,
520 ))
521 .add_request_handler(user_handler(
522 forward_mutating_project_request::<proto::PrepareRename>,
523 ))
524 .add_request_handler(user_handler(
525 forward_mutating_project_request::<proto::PerformRename>,
526 ))
527 .add_request_handler(user_handler(
528 forward_mutating_project_request::<proto::ReloadBuffers>,
529 ))
530 .add_request_handler(user_handler(
531 forward_mutating_project_request::<proto::FormatBuffers>,
532 ))
533 .add_request_handler(user_handler(
534 forward_mutating_project_request::<proto::CreateProjectEntry>,
535 ))
536 .add_request_handler(user_handler(
537 forward_mutating_project_request::<proto::RenameProjectEntry>,
538 ))
539 .add_request_handler(user_handler(
540 forward_mutating_project_request::<proto::CopyProjectEntry>,
541 ))
542 .add_request_handler(user_handler(
543 forward_mutating_project_request::<proto::DeleteProjectEntry>,
544 ))
545 .add_request_handler(user_handler(
546 forward_mutating_project_request::<proto::ExpandProjectEntry>,
547 ))
548 .add_request_handler(user_handler(
549 forward_mutating_project_request::<proto::OnTypeFormatting>,
550 ))
551 .add_request_handler(user_handler(
552 forward_versioned_mutating_project_request::<proto::SaveBuffer>,
553 ))
554 .add_request_handler(user_handler(
555 forward_mutating_project_request::<proto::BlameBuffer>,
556 ))
557 .add_request_handler(user_handler(
558 forward_mutating_project_request::<proto::MultiLspQuery>,
559 ))
560 .add_request_handler(user_handler(
561 forward_mutating_project_request::<proto::RestartLanguageServers>,
562 ))
563 .add_request_handler(user_handler(
564 forward_mutating_project_request::<proto::LinkedEditingRange>,
565 ))
566 .add_message_handler(create_buffer_for_peer)
567 .add_request_handler(update_buffer)
568 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
569 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
570 .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
571 .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
572 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
573 .add_request_handler(get_users)
574 .add_request_handler(user_handler(fuzzy_search_users))
575 .add_request_handler(user_handler(request_contact))
576 .add_request_handler(user_handler(remove_contact))
577 .add_request_handler(user_handler(respond_to_contact_request))
578 .add_message_handler(subscribe_to_channels)
579 .add_request_handler(user_handler(create_channel))
580 .add_request_handler(user_handler(delete_channel))
581 .add_request_handler(user_handler(invite_channel_member))
582 .add_request_handler(user_handler(remove_channel_member))
583 .add_request_handler(user_handler(set_channel_member_role))
584 .add_request_handler(user_handler(set_channel_visibility))
585 .add_request_handler(user_handler(rename_channel))
586 .add_request_handler(user_handler(join_channel_buffer))
587 .add_request_handler(user_handler(leave_channel_buffer))
588 .add_message_handler(user_message_handler(update_channel_buffer))
589 .add_request_handler(user_handler(rejoin_channel_buffers))
590 .add_request_handler(user_handler(get_channel_members))
591 .add_request_handler(user_handler(respond_to_channel_invite))
592 .add_request_handler(user_handler(join_channel))
593 .add_request_handler(user_handler(join_channel_chat))
594 .add_message_handler(user_message_handler(leave_channel_chat))
595 .add_request_handler(user_handler(send_channel_message))
596 .add_request_handler(user_handler(remove_channel_message))
597 .add_request_handler(user_handler(update_channel_message))
598 .add_request_handler(user_handler(get_channel_messages))
599 .add_request_handler(user_handler(get_channel_messages_by_id))
600 .add_request_handler(user_handler(get_notifications))
601 .add_request_handler(user_handler(mark_notification_as_read))
602 .add_request_handler(user_handler(move_channel))
603 .add_request_handler(user_handler(follow))
604 .add_message_handler(user_message_handler(unfollow))
605 .add_message_handler(user_message_handler(update_followers))
606 .add_request_handler(user_handler(get_private_user_info))
607 .add_request_handler(user_handler(get_llm_api_token))
608 .add_request_handler(user_handler(accept_terms_of_service))
609 .add_message_handler(user_message_handler(acknowledge_channel_message))
610 .add_message_handler(user_message_handler(acknowledge_buffer_version))
611 .add_request_handler(user_handler(get_supermaven_api_key))
612 .add_request_handler(user_handler(
613 forward_mutating_project_request::<proto::OpenContext>,
614 ))
615 .add_request_handler(user_handler(
616 forward_mutating_project_request::<proto::CreateContext>,
617 ))
618 .add_request_handler(user_handler(
619 forward_mutating_project_request::<proto::SynchronizeContexts>,
620 ))
621 .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
622 .add_message_handler(update_context)
623 .add_request_handler({
624 let app_state = app_state.clone();
625 move |request, response, session| {
626 let app_state = app_state.clone();
627 async move {
628 count_language_model_tokens(request, response, session, &app_state.config)
629 .await
630 }
631 }
632 })
633 .add_request_handler({
634 user_handler(move |request, response, session| {
635 get_cached_embeddings(request, response, session)
636 })
637 })
638 .add_request_handler({
639 let app_state = app_state.clone();
640 user_handler(move |request, response, session| {
641 compute_embeddings(
642 request,
643 response,
644 session,
645 app_state.config.openai_api_key.clone(),
646 )
647 })
648 });
649
650 Arc::new(server)
651 }
652
    /// Spawns a detached task that cleans up resources left behind by stale
    /// servers, then returns `Ok(())` immediately.
    ///
    /// After [`CLEANUP_TIMEOUT`] the task:
    /// 1. asks the database for stale room and channel-buffer ids in this
    ///    environment (the staleness criteria live in
    ///    `stale_server_resource_ids`, not visible here);
    /// 2. for each stale channel buffer, clears stale collaborators and sends
    ///    the refreshed collaborator list to the buffer's remaining connections;
    /// 3. for each stale room, clears stale participants and then:
    ///    - broadcasts the room (and its channel, if any);
    ///    - sends `CallCanceled` to users whose calls were canceled;
    ///    - pushes refreshed contact info to the accepted contacts of affected
    ///      users;
    ///    - deletes the LiveKit room if no participants remain;
    /// 4. deletes stale server records.
    ///
    /// Every fallible step is passed through `trace_err` (which presumably
    /// logs) and skipped on failure, so one error does not abort the cleanup.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // Created now, awaited inside the spawned task.
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Drop stale collaborators from each channel buffer and tell
                    // the remaining connections who is still collaborating.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Defaults used if clearing the room fails: nobody to
                        // notify and no LiveKit room to delete.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            // Both removed participants and canceled callees may
                            // have changed status, so their contacts need an update.
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the synchronous pool lock is released before
                        // the next `.await`.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            // Only proceed when both lookups succeeded; the pool
                            // is locked after the awaits above.
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Only delete the LiveKit room once no participants remain.
                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
798
799 pub fn teardown(&self) {
800 self.peer.teardown();
801 self.connection_pool.lock().reset();
802 let _ = self.teardown.send(true);
803 }
804
805 #[cfg(test)]
806 pub fn reset(&self, id: ServerId) {
807 self.teardown();
808 *self.id.lock() = id;
809 self.peer.reset(id.0 as u32);
810 let _ = self.teardown.send(false);
811 }
812
813 #[cfg(test)]
814 pub fn id(&self) -> ServerId {
815 *self.id.lock()
816 }
817
818 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
819 where
820 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
821 Fut: 'static + Send + Future<Output = Result<()>>,
822 M: EnvelopedMessage,
823 {
824 let prev_handler = self.handlers.insert(
825 TypeId::of::<M>(),
826 Box::new(move |envelope, session| {
827 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
828 let received_at = envelope.received_at;
829 tracing::info!("message received");
830 let start_time = Instant::now();
831 let future = (handler)(*envelope, session);
832 async move {
833 let result = future.await;
834 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
835 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
836 let queue_duration_ms = total_duration_ms - processing_duration_ms;
837 let payload_type = M::NAME;
838
839 match result {
840 Err(error) => {
841 tracing::error!(
842 ?error,
843 total_duration_ms,
844 processing_duration_ms,
845 queue_duration_ms,
846 payload_type,
847 "error handling message"
848 )
849 }
850 Ok(()) => tracing::info!(
851 total_duration_ms,
852 processing_duration_ms,
853 queue_duration_ms,
854 "finished handling message"
855 ),
856 }
857 }
858 .boxed()
859 }),
860 );
861 if prev_handler.is_some() {
862 panic!("registered a handler for the same message twice");
863 }
864 self
865 }
866
867 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
868 where
869 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
870 Fut: 'static + Send + Future<Output = Result<()>>,
871 M: EnvelopedMessage,
872 {
873 self.add_handler(move |envelope, session| handler(envelope.payload, session));
874 self
875 }
876
877 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
878 where
879 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
880 Fut: Send + Future<Output = Result<()>>,
881 M: RequestMessage,
882 {
883 let handler = Arc::new(handler);
884 self.add_handler(move |envelope, session| {
885 let receipt = envelope.receipt();
886 let handler = handler.clone();
887 async move {
888 let peer = session.peer.clone();
889 let responded = Arc::new(AtomicBool::default());
890 let response = Response {
891 peer: peer.clone(),
892 responded: responded.clone(),
893 receipt,
894 };
895 match (handler)(envelope.payload, response, session).await {
896 Ok(()) => {
897 if responded.load(std::sync::atomic::Ordering::SeqCst) {
898 Ok(())
899 } else {
900 Err(anyhow!("handler did not send a response"))?
901 }
902 }
903 Err(error) => {
904 let proto_err = match &error {
905 Error::Internal(err) => err.to_proto(),
906 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
907 };
908 peer.respond_with_error(receipt, proto_err)?;
909 Err(error)
910 }
911 }
912 }
913 })
914 }
915
916 #[allow(clippy::too_many_arguments)]
917 pub fn handle_connection(
918 self: &Arc<Self>,
919 connection: Connection,
920 address: String,
921 principal: Principal,
922 zed_version: ZedVersion,
923 geoip_country_code: Option<String>,
924 send_connection_id: Option<oneshot::Sender<ConnectionId>>,
925 executor: Executor,
926 ) -> impl Future<Output = ()> {
927 let this = self.clone();
928 let span = info_span!("handle connection", %address,
929 connection_id=field::Empty,
930 user_id=field::Empty,
931 login=field::Empty,
932 impersonator=field::Empty,
933 dev_server_id=field::Empty,
934 geoip_country_code=field::Empty
935 );
936 principal.update_span(&span);
937 if let Some(country_code) = geoip_country_code.as_ref() {
938 span.record("geoip_country_code", country_code);
939 }
940
941 let mut teardown = self.teardown.subscribe();
942 async move {
943 if *teardown.borrow() {
944 tracing::error!("server is tearing down");
945 return
946 }
947 let (connection_id, handle_io, mut incoming_rx) = this
948 .peer
949 .add_connection(connection, {
950 let executor = executor.clone();
951 move |duration| executor.sleep(duration)
952 });
953 tracing::Span::current().record("connection_id", format!("{}", connection_id));
954
955 tracing::info!("connection opened");
956
957 let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
958 let http_client = match IsahcHttpClient::builder().default_header("User-Agent", user_agent).build() {
959 Ok(http_client) => Arc::new(http_client),
960 Err(error) => {
961 tracing::error!(?error, "failed to create HTTP client");
962 return;
963 }
964 };
965
966 let supermaven_client = if let Some(supermaven_admin_api_key) = this.app_state.config.supermaven_admin_api_key.clone() {
967 Some(Arc::new(SupermavenAdminApi::new(
968 supermaven_admin_api_key.to_string(),
969 http_client.clone(),
970 )))
971 } else {
972 None
973 };
974
975 let session = Session {
976 principal: principal.clone(),
977 connection_id,
978 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
979 peer: this.peer.clone(),
980 connection_pool: this.connection_pool.clone(),
981 app_state: this.app_state.clone(),
982 http_client,
983 geoip_country_code,
984 _executor: executor.clone(),
985 supermaven_client,
986 };
987
988 if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
989 tracing::error!(?error, "failed to send initial client update");
990 return;
991 }
992
993 let handle_io = handle_io.fuse();
994 futures::pin_mut!(handle_io);
995
996 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
997 // This prevents deadlocks when e.g., client A performs a request to client B and
998 // client B performs a request to client A. If both clients stop processing further
999 // messages until their respective request completes, they won't have a chance to
1000 // respond to the other client's request and cause a deadlock.
1001 //
1002 // This arrangement ensures we will attempt to process earlier messages first, but fall
1003 // back to processing messages arrived later in the spirit of making progress.
1004 let mut foreground_message_handlers = FuturesUnordered::new();
1005 let concurrent_handlers = Arc::new(Semaphore::new(256));
1006 loop {
1007 let next_message = async {
1008 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
1009 let message = incoming_rx.next().await;
1010 (permit, message)
1011 }.fuse();
1012 futures::pin_mut!(next_message);
1013 futures::select_biased! {
1014 _ = teardown.changed().fuse() => return,
1015 result = handle_io => {
1016 if let Err(error) = result {
1017 tracing::error!(?error, "error handling I/O");
1018 }
1019 break;
1020 }
1021 _ = foreground_message_handlers.next() => {}
1022 next_message = next_message => {
1023 let (permit, message) = next_message;
1024 if let Some(message) = message {
1025 let type_name = message.payload_type_name();
1026 // note: we copy all the fields from the parent span so we can query them in the logs.
1027 // (https://github.com/tokio-rs/tracing/issues/2670).
1028 let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
1029 user_id=field::Empty,
1030 login=field::Empty,
1031 impersonator=field::Empty,
1032 dev_server_id=field::Empty
1033 );
1034 principal.update_span(&span);
1035 let span_enter = span.enter();
1036 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
1037 let is_background = message.is_background();
1038 let handle_message = (handler)(message, session.clone());
1039 drop(span_enter);
1040
1041 let handle_message = async move {
1042 handle_message.await;
1043 drop(permit);
1044 }.instrument(span);
1045 if is_background {
1046 executor.spawn_detached(handle_message);
1047 } else {
1048 foreground_message_handlers.push(handle_message);
1049 }
1050 } else {
1051 tracing::error!("no message handler");
1052 }
1053 } else {
1054 tracing::info!("connection closed");
1055 break;
1056 }
1057 }
1058 }
1059 }
1060
1061 drop(foreground_message_handlers);
1062 tracing::info!("signing out");
1063 if let Err(error) = connection_lost(session, teardown, executor).await {
1064 tracing::error!(?error, "error signing out");
1065 }
1066
1067 }.instrument(span)
1068 }
1069
1070 async fn send_initial_client_update(
1071 &self,
1072 connection_id: ConnectionId,
1073 principal: &Principal,
1074 zed_version: ZedVersion,
1075 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
1076 session: &Session,
1077 ) -> Result<()> {
1078 self.peer.send(
1079 connection_id,
1080 proto::Hello {
1081 peer_id: Some(connection_id.into()),
1082 },
1083 )?;
1084 tracing::info!("sent hello message");
1085 if let Some(send_connection_id) = send_connection_id.take() {
1086 let _ = send_connection_id.send(connection_id);
1087 }
1088
1089 match principal {
1090 Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
1091 if !user.connected_once {
1092 self.peer.send(connection_id, proto::ShowContacts {})?;
1093 self.app_state
1094 .db
1095 .set_user_connected_once(user.id, true)
1096 .await?;
1097 }
1098
1099 update_user_plan(user.id, session).await?;
1100
1101 let (contacts, dev_server_projects) = future::try_join(
1102 self.app_state.db.get_contacts(user.id),
1103 self.app_state.db.dev_server_projects_update(user.id),
1104 )
1105 .await?;
1106
1107 {
1108 let mut pool = self.connection_pool.lock();
1109 pool.add_connection(connection_id, user.id, user.admin, zed_version);
1110 self.peer.send(
1111 connection_id,
1112 build_initial_contacts_update(contacts, &pool),
1113 )?;
1114 }
1115
1116 if should_auto_subscribe_to_channels(zed_version) {
1117 subscribe_user_to_channels(user.id, session).await?;
1118 }
1119
1120 send_dev_server_projects_update(user.id, dev_server_projects, session).await;
1121
1122 if let Some(incoming_call) =
1123 self.app_state.db.incoming_call_for_user(user.id).await?
1124 {
1125 self.peer.send(connection_id, incoming_call)?;
1126 }
1127
1128 update_user_contacts(user.id, &session).await?;
1129 }
1130 Principal::DevServer(dev_server) => {
1131 {
1132 let mut pool = self.connection_pool.lock();
1133 if let Some(stale_connection_id) = pool.dev_server_connection_id(dev_server.id)
1134 {
1135 self.peer.send(
1136 stale_connection_id,
1137 proto::ShutdownDevServer {
1138 reason: Some(
1139 "another dev server connected with the same token".to_string(),
1140 ),
1141 },
1142 )?;
1143 pool.remove_connection(stale_connection_id)?;
1144 };
1145 pool.add_dev_server(connection_id, dev_server.id, zed_version);
1146 }
1147
1148 let projects = self
1149 .app_state
1150 .db
1151 .get_projects_for_dev_server(dev_server.id)
1152 .await?;
1153 self.peer
1154 .send(connection_id, proto::DevServerInstructions { projects })?;
1155
1156 let status = self
1157 .app_state
1158 .db
1159 .dev_server_projects_update(dev_server.user_id)
1160 .await?;
1161 send_dev_server_projects_update(dev_server.user_id, status, &session).await;
1162 }
1163 }
1164
1165 Ok(())
1166 }
1167
1168 pub async fn invite_code_redeemed(
1169 self: &Arc<Self>,
1170 inviter_id: UserId,
1171 invitee_id: UserId,
1172 ) -> Result<()> {
1173 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1174 if let Some(code) = &user.invite_code {
1175 let pool = self.connection_pool.lock();
1176 let invitee_contact = contact_for_user(invitee_id, false, &pool);
1177 for connection_id in pool.user_connection_ids(inviter_id) {
1178 self.peer.send(
1179 connection_id,
1180 proto::UpdateContacts {
1181 contacts: vec![invitee_contact.clone()],
1182 ..Default::default()
1183 },
1184 )?;
1185 self.peer.send(
1186 connection_id,
1187 proto::UpdateInviteInfo {
1188 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1189 count: user.invite_count as u32,
1190 },
1191 )?;
1192 }
1193 }
1194 }
1195 Ok(())
1196 }
1197
1198 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1199 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1200 if let Some(invite_code) = &user.invite_code {
1201 let pool = self.connection_pool.lock();
1202 for connection_id in pool.user_connection_ids(user_id) {
1203 self.peer.send(
1204 connection_id,
1205 proto::UpdateInviteInfo {
1206 url: format!(
1207 "{}{}",
1208 self.app_state.config.invite_link_prefix, invite_code
1209 ),
1210 count: user.invite_count as u32,
1211 },
1212 )?;
1213 }
1214 }
1215 }
1216 Ok(())
1217 }
1218
1219 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1220 ServerSnapshot {
1221 connection_pool: ConnectionPoolGuard {
1222 guard: self.connection_pool.lock(),
1223 _not_send: PhantomData,
1224 },
1225 peer: &self.peer,
1226 }
1227 }
1228}
1229
/// Lets a `ConnectionPoolGuard` be used anywhere a `&ConnectionPool` is expected by
/// delegating to the wrapped lock guard.
impl<'a> Deref for ConnectionPoolGuard<'a> {
    type Target = ConnectionPool;

    fn deref(&self) -> &Self::Target {
        &self.guard
    }
}
1237
/// Grants mutable access to the pool through the wrapped lock guard.
impl<'a> DerefMut for ConnectionPoolGuard<'a> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.guard
    }
}
1243
/// In test builds, runs `check_invariants` whenever a guard is released. `drop` runs before
/// the `guard` field itself is dropped, so the check still happens under the lock.
/// In non-test builds this is a no-op.
impl<'a> Drop for ConnectionPoolGuard<'a> {
    fn drop(&mut self) {
        #[cfg(test)]
        self.check_invariants();
    }
}
1250
1251fn broadcast<F>(
1252 sender_id: Option<ConnectionId>,
1253 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1254 mut f: F,
1255) where
1256 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1257{
1258 for receiver_id in receiver_ids {
1259 if Some(receiver_id) != sender_id {
1260 if let Err(error) = f(receiver_id) {
1261 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1262 }
1263 }
1264 }
1265}
1266
1267pub struct ProtocolVersion(u32);
1268
1269impl Header for ProtocolVersion {
1270 fn name() -> &'static HeaderName {
1271 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1272 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1273 }
1274
1275 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1276 where
1277 Self: Sized,
1278 I: Iterator<Item = &'i axum::http::HeaderValue>,
1279 {
1280 let version = values
1281 .next()
1282 .ok_or_else(axum::headers::Error::invalid)?
1283 .to_str()
1284 .map_err(|_| axum::headers::Error::invalid())?
1285 .parse()
1286 .map_err(|_| axum::headers::Error::invalid())?;
1287 Ok(Self(version))
1288 }
1289
1290 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1291 values.extend([self.0.to_string().parse().unwrap()]);
1292 }
1293}
1294
1295pub struct AppVersionHeader(SemanticVersion);
1296impl Header for AppVersionHeader {
1297 fn name() -> &'static HeaderName {
1298 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1299 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1300 }
1301
1302 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1303 where
1304 Self: Sized,
1305 I: Iterator<Item = &'i axum::http::HeaderValue>,
1306 {
1307 let version = values
1308 .next()
1309 .ok_or_else(axum::headers::Error::invalid)?
1310 .to_str()
1311 .map_err(|_| axum::headers::Error::invalid())?
1312 .parse()
1313 .map_err(|_| axum::headers::Error::invalid())?;
1314 Ok(Self(version))
1315 }
1316
1317 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1318 values.extend([self.0.to_string().parse().unwrap()]);
1319 }
1320}
1321
1322pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1323 Router::new()
1324 .route("/rpc", get(handle_websocket_request))
1325 .layer(
1326 ServiceBuilder::new()
1327 .layer(Extension(server.app_state.clone()))
1328 .layer(middleware::from_fn(auth::validate_header)),
1329 )
1330 .route("/metrics", get(handle_metrics))
1331 .layer(Extension(server))
1332}
1333
1334pub async fn handle_websocket_request(
1335 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1336 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1337 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1338 Extension(server): Extension<Arc<Server>>,
1339 Extension(principal): Extension<Principal>,
1340 country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1341 ws: WebSocketUpgrade,
1342) -> axum::response::Response {
1343 if protocol_version != rpc::PROTOCOL_VERSION {
1344 return (
1345 StatusCode::UPGRADE_REQUIRED,
1346 "client must be upgraded".to_string(),
1347 )
1348 .into_response();
1349 }
1350
1351 let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1352 return (
1353 StatusCode::UPGRADE_REQUIRED,
1354 "no version header found".to_string(),
1355 )
1356 .into_response();
1357 };
1358
1359 if !version.can_collaborate() {
1360 return (
1361 StatusCode::UPGRADE_REQUIRED,
1362 "client must be upgraded".to_string(),
1363 )
1364 .into_response();
1365 }
1366
1367 let socket_address = socket_address.to_string();
1368 ws.on_upgrade(move |socket| {
1369 let socket = socket
1370 .map_ok(to_tungstenite_message)
1371 .err_into()
1372 .with(|message| async move { to_axum_message(message) });
1373 let connection = Connection::new(Box::pin(socket));
1374 async move {
1375 server
1376 .handle_connection(
1377 connection,
1378 socket_address,
1379 principal,
1380 version,
1381 country_code_header.map(|header| header.to_string()),
1382 None,
1383 Executor::Production,
1384 )
1385 .await;
1386 }
1387 })
1388}
1389
1390pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1391 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1392 let connections_metric = CONNECTIONS_METRIC
1393 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1394
1395 let connections = server
1396 .connection_pool
1397 .lock()
1398 .connections()
1399 .filter(|connection| !connection.admin)
1400 .count();
1401 connections_metric.set(connections as _);
1402
1403 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1404 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1405 register_int_gauge!(
1406 "shared_projects",
1407 "number of open projects with one or more guests"
1408 )
1409 .unwrap()
1410 });
1411
1412 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1413 shared_projects_metric.set(shared_projects as _);
1414
1415 let encoder = prometheus::TextEncoder::new();
1416 let metric_families = prometheus::gather();
1417 let encoded_metrics = encoder
1418 .encode_to_string(&metric_families)
1419 .map_err(|err| anyhow!("{}", err))?;
1420 Ok(encoded_metrics)
1421}
1422
/// Cleans up after a connection has closed.
///
/// Immediately disconnects the peer, removes the connection from the pool, and records the
/// loss in the database. Errors from the database call are logged rather than returned.
/// It then sleeps for `RECONNECT_TIMEOUT` (presumably a grace period for the client to
/// reconnect — confirm) before releasing the principal's resources:
///
/// * users: leave their room and channel buffers, decline any pending call if the user has
///   no other online connection, and refresh their contacts;
/// * dev servers: delegated to `lost_dev_server_connection`.
///
/// If `teardown` changes before the sleep finishes, the delayed cleanup is skipped.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // Biased: if both are ready, the timeout branch is polled first.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            match &session.principal {
                Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
                    let session = session.for_user().unwrap();

                    log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
                    leave_room_for_session(&session, session.connection_id).await.trace_err();
                    leave_channel_buffers_for_session(&session)
                        .await
                        .trace_err();

                    // Only decline a ringing call if no other connection of this user is
                    // still online to answer it.
                    if !session
                        .connection_pool()
                        .await
                        .is_user_online(session.user_id())
                    {
                        let db = session.db().await;
                        if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                            room_updated(&room, &session.peer);
                        }
                    }

                    update_user_contacts(session.user_id(), &session).await?;
                },
                Principal::DevServer(_) => {
                    lost_dev_server_connection(&session.for_dev_server().unwrap()).await?;
                },
            }
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1477
1478/// Acknowledges a ping from a client, used to keep the connection alive.
1479async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1480 response.send(proto::Ack {})?;
1481 Ok(())
1482}
1483
1484/// Creates a new room for calling (outside of channels)
1485async fn create_room(
1486 _request: proto::CreateRoom,
1487 response: Response<proto::CreateRoom>,
1488 session: UserSession,
1489) -> Result<()> {
1490 let live_kit_room = nanoid::nanoid!(30);
1491
1492 let live_kit_connection_info = util::maybe!(async {
1493 let live_kit = session.app_state.live_kit_client.as_ref();
1494 let live_kit = live_kit?;
1495 let user_id = session.user_id().to_string();
1496
1497 let token = live_kit
1498 .room_token(&live_kit_room, &user_id.to_string())
1499 .trace_err()?;
1500
1501 Some(proto::LiveKitConnectionInfo {
1502 server_url: live_kit.url().into(),
1503 token,
1504 can_publish: true,
1505 })
1506 })
1507 .await;
1508
1509 let room = session
1510 .db()
1511 .await
1512 .create_room(session.user_id(), session.connection_id, &live_kit_room)
1513 .await?;
1514
1515 response.send(proto::CreateRoomResponse {
1516 room: Some(room.clone()),
1517 live_kit_connection_info,
1518 })?;
1519
1520 update_user_contacts(session.user_id(), &session).await?;
1521 Ok(())
1522}
1523
1524/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1525async fn join_room(
1526 request: proto::JoinRoom,
1527 response: Response<proto::JoinRoom>,
1528 session: UserSession,
1529) -> Result<()> {
1530 let room_id = RoomId::from_proto(request.id);
1531
1532 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1533
1534 if let Some(channel_id) = channel_id {
1535 return join_channel_internal(channel_id, Box::new(response), session).await;
1536 }
1537
1538 let joined_room = {
1539 let room = session
1540 .db()
1541 .await
1542 .join_room(room_id, session.user_id(), session.connection_id)
1543 .await?;
1544 room_updated(&room.room, &session.peer);
1545 room.into_inner()
1546 };
1547
1548 for connection_id in session
1549 .connection_pool()
1550 .await
1551 .user_connection_ids(session.user_id())
1552 {
1553 session
1554 .peer
1555 .send(
1556 connection_id,
1557 proto::CallCanceled {
1558 room_id: room_id.to_proto(),
1559 },
1560 )
1561 .trace_err();
1562 }
1563
1564 let live_kit_connection_info =
1565 if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
1566 if let Some(token) = live_kit
1567 .room_token(
1568 &joined_room.room.live_kit_room,
1569 &session.user_id().to_string(),
1570 )
1571 .trace_err()
1572 {
1573 Some(proto::LiveKitConnectionInfo {
1574 server_url: live_kit.url().into(),
1575 token,
1576 can_publish: true,
1577 })
1578 } else {
1579 None
1580 }
1581 } else {
1582 None
1583 };
1584
1585 response.send(proto::JoinRoomResponse {
1586 room: Some(joined_room.room),
1587 channel_id: None,
1588 live_kit_connection_info,
1589 })?;
1590
1591 update_user_contacts(session.user_id(), &session).await?;
1592 Ok(())
1593}
1594
/// Reconnects a client to a room after connection errors.
///
/// Within the scope of the database guard returned by `rejoin_room`, this:
/// 1. replies with the room, the projects this user re-shares, and the projects it rejoins;
/// 2. broadcasts the room update;
/// 3. for each re-shared project, tells its collaborators about the host's new peer id and
///    forwards the project's worktree list to them;
/// 4. re-syncs rejoined projects via `notify_rejoined_projects`.
///
/// After the guard is released, the channel (if any) is updated and the user's contacts
/// are refreshed.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: UserSession,
) -> Result<()> {
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Collaborators must learn that the host's old peer id maps to this connection.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1686
1687fn notify_rejoined_projects(
1688 rejoined_projects: &mut Vec<RejoinedProject>,
1689 session: &UserSession,
1690) -> Result<()> {
1691 for project in rejoined_projects.iter() {
1692 for collaborator in &project.collaborators {
1693 session
1694 .peer
1695 .send(
1696 collaborator.connection_id,
1697 proto::UpdateProjectCollaborator {
1698 project_id: project.id.to_proto(),
1699 old_peer_id: Some(project.old_connection_id.into()),
1700 new_peer_id: Some(session.connection_id.into()),
1701 },
1702 )
1703 .trace_err();
1704 }
1705 }
1706
1707 for project in rejoined_projects {
1708 for worktree in mem::take(&mut project.worktrees) {
1709 #[cfg(any(test, feature = "test-support"))]
1710 const MAX_CHUNK_SIZE: usize = 2;
1711 #[cfg(not(any(test, feature = "test-support")))]
1712 const MAX_CHUNK_SIZE: usize = 256;
1713
1714 // Stream this worktree's entries.
1715 let message = proto::UpdateWorktree {
1716 project_id: project.id.to_proto(),
1717 worktree_id: worktree.id,
1718 abs_path: worktree.abs_path.clone(),
1719 root_name: worktree.root_name,
1720 updated_entries: worktree.updated_entries,
1721 removed_entries: worktree.removed_entries,
1722 scan_id: worktree.scan_id,
1723 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1724 updated_repositories: worktree.updated_repositories,
1725 removed_repositories: worktree.removed_repositories,
1726 };
1727 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1728 session.peer.send(session.connection_id, update.clone())?;
1729 }
1730
1731 // Stream this worktree's diagnostics.
1732 for summary in worktree.diagnostic_summaries {
1733 session.peer.send(
1734 session.connection_id,
1735 proto::UpdateDiagnosticSummary {
1736 project_id: project.id.to_proto(),
1737 worktree_id: worktree.id,
1738 summary: Some(summary),
1739 },
1740 )?;
1741 }
1742
1743 for settings_file in worktree.settings_files {
1744 session.peer.send(
1745 session.connection_id,
1746 proto::UpdateWorktreeSettings {
1747 project_id: project.id.to_proto(),
1748 worktree_id: worktree.id,
1749 path: settings_file.path,
1750 content: Some(settings_file.content),
1751 },
1752 )?;
1753 }
1754 }
1755
1756 for language_server in &project.language_servers {
1757 session.peer.send(
1758 session.connection_id,
1759 proto::UpdateLanguageServer {
1760 project_id: project.id.to_proto(),
1761 language_server_id: language_server.id,
1762 variant: Some(
1763 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1764 proto::LspDiskBasedDiagnosticsUpdated {},
1765 ),
1766 ),
1767 },
1768 )?;
1769 }
1770 }
1771 Ok(())
1772}
1773
1774/// leave room disconnects from the room.
1775async fn leave_room(
1776 _: proto::LeaveRoom,
1777 response: Response<proto::LeaveRoom>,
1778 session: UserSession,
1779) -> Result<()> {
1780 leave_room_for_session(&session, session.connection_id).await?;
1781 response.send(proto::Ack {})?;
1782 Ok(())
1783}
1784
1785/// Updates the permissions of someone else in the room.
1786async fn set_room_participant_role(
1787 request: proto::SetRoomParticipantRole,
1788 response: Response<proto::SetRoomParticipantRole>,
1789 session: UserSession,
1790) -> Result<()> {
1791 let user_id = UserId::from_proto(request.user_id);
1792 let role = ChannelRole::from(request.role());
1793
1794 let (live_kit_room, can_publish) = {
1795 let room = session
1796 .db()
1797 .await
1798 .set_room_participant_role(
1799 session.user_id(),
1800 RoomId::from_proto(request.room_id),
1801 user_id,
1802 role,
1803 )
1804 .await?;
1805
1806 let live_kit_room = room.live_kit_room.clone();
1807 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1808 room_updated(&room, &session.peer);
1809 (live_kit_room, can_publish)
1810 };
1811
1812 if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
1813 live_kit
1814 .update_participant(
1815 live_kit_room.clone(),
1816 request.user_id.to_string(),
1817 live_kit_server::proto::ParticipantPermission {
1818 can_subscribe: true,
1819 can_publish,
1820 can_publish_data: can_publish,
1821 hidden: false,
1822 recorder: false,
1823 },
1824 )
1825 .await
1826 .trace_err();
1827 }
1828
1829 response.send(proto::Ack {})?;
1830 Ok(())
1831}
1832
/// Calls another user (who must be a contact) into the current room.
///
/// Records the call in the database, broadcasts the room update, and refreshes the callee's
/// contacts. It then sends the incoming call as a request to every connection of the
/// callee. The first connection that answers successfully completes the request with an
/// `Ack`. If none succeeds, the call is marked as failed, the room and the callee's
/// contacts are refreshed again, and an error is returned.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: UserSession,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    // Scoped so the database guards are released before notifying contacts.
    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(&room, &session.peer);
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // The pool guard is a temporary of this statement, so it is released before any of
    // the ring requests are awaited.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1901
/// Cancels an outgoing call.
///
/// Records the cancellation, broadcasts the room update, and sends `CallCanceled` to every
/// connection of the called user (send failures are only logged). It then acknowledges the
/// request and refreshes the called user's contacts.
async fn cancel_call(
    request: proto::CancelCall,
    response: Response<proto::CancelCall>,
    session: UserSession,
) -> Result<()> {
    let called_user_id = UserId::from_proto(request.called_user_id);
    let room_id = RoomId::from_proto(request.room_id);
    {
        let room = session
            .db()
            .await
            .cancel_call(room_id, session.connection_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    // The pool guard lives only for this `for` statement; it must be released before
    // `update_user_contacts` is awaited below.
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }
    response.send(proto::Ack {})?;

    update_user_contacts(called_user_id, &session).await?;
    Ok(())
}
1939
/// Declines an incoming call.
///
/// Records the decline (an error is returned if the database reports no call to decline),
/// broadcasts the room update, and sends `CallCanceled` to all of the user's own
/// connections so the call is dismissed everywhere. Finally the user's contacts are
/// refreshed.
async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
    let room_id = RoomId::from_proto(message.room_id);
    {
        let room = session
            .db()
            .await
            .decline_call(Some(room_id), session.user_id())
            .await?
            .ok_or_else(|| anyhow!("failed to decline call"))?;
        room_updated(&room, &session.peer);
    }

    // The pool guard lives only for this `for` statement and is released before the
    // `update_user_contacts` await.
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(session.user_id())
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }
    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1971
1972/// Updates other participants in the room with your current location.
1973async fn update_participant_location(
1974 request: proto::UpdateParticipantLocation,
1975 response: Response<proto::UpdateParticipantLocation>,
1976 session: UserSession,
1977) -> Result<()> {
1978 let room_id = RoomId::from_proto(request.room_id);
1979 let location = request
1980 .location
1981 .ok_or_else(|| anyhow!("invalid location"))?;
1982
1983 let db = session.db().await;
1984 let room = db
1985 .update_room_participant_location(room_id, session.connection_id, location)
1986 .await?;
1987
1988 room_updated(&room, &session.peer);
1989 response.send(proto::Ack {})?;
1990 Ok(())
1991}
1992
1993/// Share a project into the room.
1994async fn share_project(
1995 request: proto::ShareProject,
1996 response: Response<proto::ShareProject>,
1997 session: UserSession,
1998) -> Result<()> {
1999 let (project_id, room) = &*session
2000 .db()
2001 .await
2002 .share_project(
2003 RoomId::from_proto(request.room_id),
2004 session.connection_id,
2005 &request.worktrees,
2006 request
2007 .dev_server_project_id
2008 .map(|id| DevServerProjectId::from_proto(id)),
2009 )
2010 .await?;
2011 response.send(proto::ShareProjectResponse {
2012 project_id: project_id.to_proto(),
2013 })?;
2014 room_updated(&room, &session.peer);
2015
2016 Ok(())
2017}
2018
2019/// Unshare a project from the room.
2020async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
2021 let project_id = ProjectId::from_proto(message.project_id);
2022 unshare_project_internal(
2023 project_id,
2024 session.connection_id,
2025 session.user_id(),
2026 &session,
2027 )
2028 .await
2029}
2030
/// Stops sharing `project_id` on behalf of `connection_id` (and `user_id`, when known).
///
/// While the guard returned by `unshare_project` is held, sends `UnshareProject` to every
/// guest connection except `connection_id` and broadcasts the room update, if the project
/// was in a room. If the database reports that the project should be deleted, deletion
/// happens only after that guard has been dropped.
async fn unshare_project_internal(
    project_id: ProjectId,
    connection_id: ConnectionId,
    user_id: Option<UserId>,
    session: &Session,
) -> Result<()> {
    let delete = {
        let room_guard = session
            .db()
            .await
            .unshare_project(project_id, connection_id, user_id)
            .await?;

        let (delete, room, guest_connection_ids) = &*room_guard;

        let message = proto::UnshareProject {
            project_id: project_id.to_proto(),
        };

        broadcast(
            Some(connection_id),
            guest_connection_ids.iter().copied(),
            |conn_id| session.peer.send(conn_id, message.clone()),
        );
        if let Some(room) = room {
            room_updated(room, &session.peer);
        }

        *delete
    };

    if delete {
        let db = session.db().await;
        db.delete_project(project_id).await?;
    }

    Ok(())
}
2069
2070/// DevServer makes a project available online
2071async fn share_dev_server_project(
2072 request: proto::ShareDevServerProject,
2073 response: Response<proto::ShareDevServerProject>,
2074 session: DevServerSession,
2075) -> Result<()> {
2076 let (dev_server_project, user_id, status) = session
2077 .db()
2078 .await
2079 .share_dev_server_project(
2080 DevServerProjectId::from_proto(request.dev_server_project_id),
2081 session.dev_server_id(),
2082 session.connection_id,
2083 &request.worktrees,
2084 )
2085 .await?;
2086 let Some(project_id) = dev_server_project.project_id else {
2087 return Err(anyhow!("failed to share remote project"))?;
2088 };
2089
2090 send_dev_server_projects_update(user_id, status, &session).await;
2091
2092 response.send(proto::ShareProjectResponse { project_id })?;
2093
2094 Ok(())
2095}
2096
/// Joins someone else's shared project.
///
/// Records the join in the database, which yields the project and this connection's
/// replica id. The database handle is explicitly dropped before `join_project_internal`
/// notifies collaborators and sends the project state to the guest.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: UserSession,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    drop(db);
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
2115
/// A response handle that can complete a project join with a `JoinProjectResponse`.
///
/// Lets `join_project_internal` accept either a `JoinProject` or a `JoinHostedProject`
/// response.
trait JoinProjectInternalResponse {
    /// Sends `result` as the reply to the originating request.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Path call resolves to the inherent `Response::send`, not this trait method.
        Response::<proto::JoinProject>::send(self, result)
    }
}
impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Path call resolves to the inherent `Response::send`, not this trait method.
        Response::<proto::JoinHostedProject>::send(self, result)
    }
}
2129
2130fn join_project_internal(
2131 response: impl JoinProjectInternalResponse,
2132 session: UserSession,
2133 project: &mut Project,
2134 replica_id: &ReplicaId,
2135) -> Result<()> {
2136 let collaborators = project
2137 .collaborators
2138 .iter()
2139 .filter(|collaborator| collaborator.connection_id != session.connection_id)
2140 .map(|collaborator| collaborator.to_proto())
2141 .collect::<Vec<_>>();
2142 let project_id = project.id;
2143 let guest_user_id = session.user_id();
2144
2145 let worktrees = project
2146 .worktrees
2147 .iter()
2148 .map(|(id, worktree)| proto::WorktreeMetadata {
2149 id: *id,
2150 root_name: worktree.root_name.clone(),
2151 visible: worktree.visible,
2152 abs_path: worktree.abs_path.clone(),
2153 })
2154 .collect::<Vec<_>>();
2155
2156 let add_project_collaborator = proto::AddProjectCollaborator {
2157 project_id: project_id.to_proto(),
2158 collaborator: Some(proto::Collaborator {
2159 peer_id: Some(session.connection_id.into()),
2160 replica_id: replica_id.0 as u32,
2161 user_id: guest_user_id.to_proto(),
2162 }),
2163 };
2164
2165 for collaborator in &collaborators {
2166 session
2167 .peer
2168 .send(
2169 collaborator.peer_id.unwrap().into(),
2170 add_project_collaborator.clone(),
2171 )
2172 .trace_err();
2173 }
2174
2175 // First, we send the metadata associated with each worktree.
2176 response.send(proto::JoinProjectResponse {
2177 project_id: project.id.0 as u64,
2178 worktrees: worktrees.clone(),
2179 replica_id: replica_id.0 as u32,
2180 collaborators: collaborators.clone(),
2181 language_servers: project.language_servers.clone(),
2182 role: project.role.into(),
2183 dev_server_project_id: project
2184 .dev_server_project_id
2185 .map(|dev_server_project_id| dev_server_project_id.0 as u64),
2186 })?;
2187
2188 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2189 #[cfg(any(test, feature = "test-support"))]
2190 const MAX_CHUNK_SIZE: usize = 2;
2191 #[cfg(not(any(test, feature = "test-support")))]
2192 const MAX_CHUNK_SIZE: usize = 256;
2193
2194 // Stream this worktree's entries.
2195 let message = proto::UpdateWorktree {
2196 project_id: project_id.to_proto(),
2197 worktree_id,
2198 abs_path: worktree.abs_path.clone(),
2199 root_name: worktree.root_name,
2200 updated_entries: worktree.entries,
2201 removed_entries: Default::default(),
2202 scan_id: worktree.scan_id,
2203 is_last_update: worktree.scan_id == worktree.completed_scan_id,
2204 updated_repositories: worktree.repository_entries.into_values().collect(),
2205 removed_repositories: Default::default(),
2206 };
2207 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
2208 session.peer.send(session.connection_id, update.clone())?;
2209 }
2210
2211 // Stream this worktree's diagnostics.
2212 for summary in worktree.diagnostic_summaries {
2213 session.peer.send(
2214 session.connection_id,
2215 proto::UpdateDiagnosticSummary {
2216 project_id: project_id.to_proto(),
2217 worktree_id: worktree.id,
2218 summary: Some(summary),
2219 },
2220 )?;
2221 }
2222
2223 for settings_file in worktree.settings_files {
2224 session.peer.send(
2225 session.connection_id,
2226 proto::UpdateWorktreeSettings {
2227 project_id: project_id.to_proto(),
2228 worktree_id: worktree.id,
2229 path: settings_file.path,
2230 content: Some(settings_file.content),
2231 },
2232 )?;
2233 }
2234 }
2235
2236 for language_server in &project.language_servers {
2237 session.peer.send(
2238 session.connection_id,
2239 proto::UpdateLanguageServer {
2240 project_id: project_id.to_proto(),
2241 language_server_id: language_server.id,
2242 variant: Some(
2243 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2244 proto::LspDiskBasedDiagnosticsUpdated {},
2245 ),
2246 ),
2247 },
2248 )?;
2249 }
2250
2251 Ok(())
2252}
2253
2254/// Leave someone elses shared project.
2255async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2256 let sender_id = session.connection_id;
2257 let project_id = ProjectId::from_proto(request.project_id);
2258 let db = session.db().await;
2259 if db.is_hosted_project(project_id).await? {
2260 let project = db.leave_hosted_project(project_id, sender_id).await?;
2261 project_left(&project, &session);
2262 return Ok(());
2263 }
2264
2265 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2266 tracing::info!(
2267 %project_id,
2268 "leave project"
2269 );
2270
2271 project_left(&project, &session);
2272 if let Some(room) = room {
2273 room_updated(&room, &session.peer);
2274 }
2275
2276 Ok(())
2277}
2278
2279async fn join_hosted_project(
2280 request: proto::JoinHostedProject,
2281 response: Response<proto::JoinHostedProject>,
2282 session: UserSession,
2283) -> Result<()> {
2284 let (mut project, replica_id) = session
2285 .db()
2286 .await
2287 .join_hosted_project(
2288 ProjectId(request.project_id as i32),
2289 session.user_id(),
2290 session.connection_id,
2291 )
2292 .await?;
2293
2294 join_project_internal(response, session, &mut project, &replica_id)
2295}
2296
2297async fn list_remote_directory(
2298 request: proto::ListRemoteDirectory,
2299 response: Response<proto::ListRemoteDirectory>,
2300 session: UserSession,
2301) -> Result<()> {
2302 let dev_server_id = DevServerId(request.dev_server_id as i32);
2303 let dev_server_connection_id = session
2304 .connection_pool()
2305 .await
2306 .dev_server_connection_id_supporting(dev_server_id, ZedVersion::with_list_directory())?;
2307
2308 session
2309 .db()
2310 .await
2311 .get_dev_server_for_user(dev_server_id, session.user_id())
2312 .await?;
2313
2314 response.send(
2315 session
2316 .peer
2317 .forward_request(session.connection_id, dev_server_connection_id, request)
2318 .await?,
2319 )?;
2320 Ok(())
2321}
2322
2323async fn update_dev_server_project(
2324 request: proto::UpdateDevServerProject,
2325 response: Response<proto::UpdateDevServerProject>,
2326 session: UserSession,
2327) -> Result<()> {
2328 let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2329
2330 let (dev_server_project, update) = session
2331 .db()
2332 .await
2333 .update_dev_server_project(dev_server_project_id, &request.paths, session.user_id())
2334 .await?;
2335
2336 let projects = session
2337 .db()
2338 .await
2339 .get_projects_for_dev_server(dev_server_project.dev_server_id)
2340 .await?;
2341
2342 let dev_server_connection_id = session
2343 .connection_pool()
2344 .await
2345 .dev_server_connection_id_supporting(
2346 dev_server_project.dev_server_id,
2347 ZedVersion::with_list_directory(),
2348 )?;
2349
2350 session.peer.send(
2351 dev_server_connection_id,
2352 proto::DevServerInstructions { projects },
2353 )?;
2354
2355 send_dev_server_projects_update(session.user_id(), update, &session).await;
2356
2357 response.send(proto::Ack {})
2358}
2359
2360async fn create_dev_server_project(
2361 request: proto::CreateDevServerProject,
2362 response: Response<proto::CreateDevServerProject>,
2363 session: UserSession,
2364) -> Result<()> {
2365 let dev_server_id = DevServerId(request.dev_server_id as i32);
2366 let dev_server_connection_id = session
2367 .connection_pool()
2368 .await
2369 .dev_server_connection_id(dev_server_id);
2370 let Some(dev_server_connection_id) = dev_server_connection_id else {
2371 Err(ErrorCode::DevServerOffline
2372 .message("Cannot create a remote project when the dev server is offline".to_string())
2373 .anyhow())?
2374 };
2375
2376 let path = request.path.clone();
2377 //Check that the path exists on the dev server
2378 session
2379 .peer
2380 .forward_request(
2381 session.connection_id,
2382 dev_server_connection_id,
2383 proto::ValidateDevServerProjectRequest { path: path.clone() },
2384 )
2385 .await?;
2386
2387 let (dev_server_project, update) = session
2388 .db()
2389 .await
2390 .create_dev_server_project(
2391 DevServerId(request.dev_server_id as i32),
2392 &request.path,
2393 session.user_id(),
2394 )
2395 .await?;
2396
2397 let projects = session
2398 .db()
2399 .await
2400 .get_projects_for_dev_server(dev_server_project.dev_server_id)
2401 .await?;
2402
2403 session.peer.send(
2404 dev_server_connection_id,
2405 proto::DevServerInstructions { projects },
2406 )?;
2407
2408 send_dev_server_projects_update(session.user_id(), update, &session).await;
2409
2410 response.send(proto::CreateDevServerProjectResponse {
2411 dev_server_project: Some(dev_server_project.to_proto(None)),
2412 })?;
2413 Ok(())
2414}
2415
2416async fn create_dev_server(
2417 request: proto::CreateDevServer,
2418 response: Response<proto::CreateDevServer>,
2419 session: UserSession,
2420) -> Result<()> {
2421 let access_token = auth::random_token();
2422 let hashed_access_token = auth::hash_access_token(&access_token);
2423
2424 if request.name.is_empty() {
2425 return Err(proto::ErrorCode::Forbidden
2426 .message("Dev server name cannot be empty".to_string())
2427 .anyhow())?;
2428 }
2429
2430 let (dev_server, status) = session
2431 .db()
2432 .await
2433 .create_dev_server(
2434 &request.name,
2435 request.ssh_connection_string.as_deref(),
2436 &hashed_access_token,
2437 session.user_id(),
2438 )
2439 .await?;
2440
2441 send_dev_server_projects_update(session.user_id(), status, &session).await;
2442
2443 response.send(proto::CreateDevServerResponse {
2444 dev_server_id: dev_server.id.0 as u64,
2445 access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2446 name: request.name,
2447 })?;
2448 Ok(())
2449}
2450
2451async fn regenerate_dev_server_token(
2452 request: proto::RegenerateDevServerToken,
2453 response: Response<proto::RegenerateDevServerToken>,
2454 session: UserSession,
2455) -> Result<()> {
2456 let dev_server_id = DevServerId(request.dev_server_id as i32);
2457 let access_token = auth::random_token();
2458 let hashed_access_token = auth::hash_access_token(&access_token);
2459
2460 let connection_id = session
2461 .connection_pool()
2462 .await
2463 .dev_server_connection_id(dev_server_id);
2464 if let Some(connection_id) = connection_id {
2465 shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2466 session.peer.send(
2467 connection_id,
2468 proto::ShutdownDevServer {
2469 reason: Some("dev server token was regenerated".to_string()),
2470 },
2471 )?;
2472 let _ = remove_dev_server_connection(dev_server_id, &session).await;
2473 }
2474
2475 let status = session
2476 .db()
2477 .await
2478 .update_dev_server_token(dev_server_id, &hashed_access_token, session.user_id())
2479 .await?;
2480
2481 send_dev_server_projects_update(session.user_id(), status, &session).await;
2482
2483 response.send(proto::RegenerateDevServerTokenResponse {
2484 dev_server_id: dev_server_id.to_proto(),
2485 access_token: auth::generate_dev_server_token(dev_server_id.0 as usize, access_token),
2486 })?;
2487 Ok(())
2488}
2489
2490async fn rename_dev_server(
2491 request: proto::RenameDevServer,
2492 response: Response<proto::RenameDevServer>,
2493 session: UserSession,
2494) -> Result<()> {
2495 if request.name.trim().is_empty() {
2496 return Err(proto::ErrorCode::Forbidden
2497 .message("Dev server name cannot be empty".to_string())
2498 .anyhow())?;
2499 }
2500
2501 let dev_server_id = DevServerId(request.dev_server_id as i32);
2502 let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2503 if dev_server.user_id != session.user_id() {
2504 return Err(anyhow!(ErrorCode::Forbidden))?;
2505 }
2506
2507 let status = session
2508 .db()
2509 .await
2510 .rename_dev_server(
2511 dev_server_id,
2512 &request.name,
2513 request.ssh_connection_string.as_deref(),
2514 session.user_id(),
2515 )
2516 .await?;
2517
2518 send_dev_server_projects_update(session.user_id(), status, &session).await;
2519
2520 response.send(proto::Ack {})?;
2521 Ok(())
2522}
2523
2524async fn delete_dev_server(
2525 request: proto::DeleteDevServer,
2526 response: Response<proto::DeleteDevServer>,
2527 session: UserSession,
2528) -> Result<()> {
2529 let dev_server_id = DevServerId(request.dev_server_id as i32);
2530 let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2531 if dev_server.user_id != session.user_id() {
2532 return Err(anyhow!(ErrorCode::Forbidden))?;
2533 }
2534
2535 let connection_id = session
2536 .connection_pool()
2537 .await
2538 .dev_server_connection_id(dev_server_id);
2539 if let Some(connection_id) = connection_id {
2540 shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2541 session.peer.send(
2542 connection_id,
2543 proto::ShutdownDevServer {
2544 reason: Some("dev server was deleted".to_string()),
2545 },
2546 )?;
2547 let _ = remove_dev_server_connection(dev_server_id, &session).await;
2548 }
2549
2550 let status = session
2551 .db()
2552 .await
2553 .delete_dev_server(dev_server_id, session.user_id())
2554 .await?;
2555
2556 send_dev_server_projects_update(session.user_id(), status, &session).await;
2557
2558 response.send(proto::Ack {})?;
2559 Ok(())
2560}
2561
2562async fn delete_dev_server_project(
2563 request: proto::DeleteDevServerProject,
2564 response: Response<proto::DeleteDevServerProject>,
2565 session: UserSession,
2566) -> Result<()> {
2567 let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2568 let dev_server_project = session
2569 .db()
2570 .await
2571 .get_dev_server_project(dev_server_project_id)
2572 .await?;
2573
2574 let dev_server = session
2575 .db()
2576 .await
2577 .get_dev_server(dev_server_project.dev_server_id)
2578 .await?;
2579 if dev_server.user_id != session.user_id() {
2580 return Err(anyhow!(ErrorCode::Forbidden))?;
2581 }
2582
2583 let dev_server_connection_id = session
2584 .connection_pool()
2585 .await
2586 .dev_server_connection_id(dev_server.id);
2587
2588 if let Some(dev_server_connection_id) = dev_server_connection_id {
2589 let project = session
2590 .db()
2591 .await
2592 .find_dev_server_project(dev_server_project_id)
2593 .await;
2594 if let Ok(project) = project {
2595 unshare_project_internal(
2596 project.id,
2597 dev_server_connection_id,
2598 Some(session.user_id()),
2599 &session,
2600 )
2601 .await?;
2602 }
2603 }
2604
2605 let (projects, status) = session
2606 .db()
2607 .await
2608 .delete_dev_server_project(dev_server_project_id, dev_server.id, session.user_id())
2609 .await?;
2610
2611 if let Some(dev_server_connection_id) = dev_server_connection_id {
2612 session.peer.send(
2613 dev_server_connection_id,
2614 proto::DevServerInstructions { projects },
2615 )?;
2616 }
2617
2618 send_dev_server_projects_update(session.user_id(), status, &session).await;
2619
2620 response.send(proto::Ack {})?;
2621 Ok(())
2622}
2623
2624async fn rejoin_dev_server_projects(
2625 request: proto::RejoinRemoteProjects,
2626 response: Response<proto::RejoinRemoteProjects>,
2627 session: UserSession,
2628) -> Result<()> {
2629 let mut rejoined_projects = {
2630 let db = session.db().await;
2631 db.rejoin_dev_server_projects(
2632 &request.rejoined_projects,
2633 session.user_id(),
2634 session.0.connection_id,
2635 )
2636 .await?
2637 };
2638 response.send(proto::RejoinRemoteProjectsResponse {
2639 rejoined_projects: rejoined_projects
2640 .iter()
2641 .map(|project| project.to_proto())
2642 .collect(),
2643 })?;
2644 notify_rejoined_projects(&mut rejoined_projects, &session)
2645}
2646
2647async fn reconnect_dev_server(
2648 request: proto::ReconnectDevServer,
2649 response: Response<proto::ReconnectDevServer>,
2650 session: DevServerSession,
2651) -> Result<()> {
2652 let reshared_projects = {
2653 let db = session.db().await;
2654 db.reshare_dev_server_projects(
2655 &request.reshared_projects,
2656 session.dev_server_id(),
2657 session.0.connection_id,
2658 )
2659 .await?
2660 };
2661
2662 for project in &reshared_projects {
2663 for collaborator in &project.collaborators {
2664 session
2665 .peer
2666 .send(
2667 collaborator.connection_id,
2668 proto::UpdateProjectCollaborator {
2669 project_id: project.id.to_proto(),
2670 old_peer_id: Some(project.old_connection_id.into()),
2671 new_peer_id: Some(session.connection_id.into()),
2672 },
2673 )
2674 .trace_err();
2675 }
2676
2677 broadcast(
2678 Some(session.connection_id),
2679 project
2680 .collaborators
2681 .iter()
2682 .map(|collaborator| collaborator.connection_id),
2683 |connection_id| {
2684 session.peer.forward_send(
2685 session.connection_id,
2686 connection_id,
2687 proto::UpdateProject {
2688 project_id: project.id.to_proto(),
2689 worktrees: project.worktrees.clone(),
2690 },
2691 )
2692 },
2693 );
2694 }
2695
2696 response.send(proto::ReconnectDevServerResponse {
2697 reshared_projects: reshared_projects
2698 .iter()
2699 .map(|project| proto::ResharedProject {
2700 id: project.id.to_proto(),
2701 collaborators: project
2702 .collaborators
2703 .iter()
2704 .map(|collaborator| collaborator.to_proto())
2705 .collect(),
2706 })
2707 .collect(),
2708 })?;
2709
2710 Ok(())
2711}
2712
2713async fn shutdown_dev_server(
2714 _: proto::ShutdownDevServer,
2715 response: Response<proto::ShutdownDevServer>,
2716 session: DevServerSession,
2717) -> Result<()> {
2718 response.send(proto::Ack {})?;
2719 shutdown_dev_server_internal(session.dev_server_id(), session.connection_id, &session).await?;
2720 remove_dev_server_connection(session.dev_server_id(), &session).await
2721}
2722
2723async fn shutdown_dev_server_internal(
2724 dev_server_id: DevServerId,
2725 connection_id: ConnectionId,
2726 session: &Session,
2727) -> Result<()> {
2728 let (dev_server_projects, dev_server) = {
2729 let db = session.db().await;
2730 let dev_server_projects = db.get_projects_for_dev_server(dev_server_id).await?;
2731 let dev_server = db.get_dev_server(dev_server_id).await?;
2732 (dev_server_projects, dev_server)
2733 };
2734
2735 for project_id in dev_server_projects.iter().filter_map(|p| p.project_id) {
2736 unshare_project_internal(
2737 ProjectId::from_proto(project_id),
2738 connection_id,
2739 None,
2740 session,
2741 )
2742 .await?;
2743 }
2744
2745 session
2746 .connection_pool()
2747 .await
2748 .set_dev_server_offline(dev_server_id);
2749
2750 let status = session
2751 .db()
2752 .await
2753 .dev_server_projects_update(dev_server.user_id)
2754 .await?;
2755 send_dev_server_projects_update(dev_server.user_id, status, &session).await;
2756
2757 Ok(())
2758}
2759
2760async fn remove_dev_server_connection(dev_server_id: DevServerId, session: &Session) -> Result<()> {
2761 let dev_server_connection = session
2762 .connection_pool()
2763 .await
2764 .dev_server_connection_id(dev_server_id);
2765
2766 if let Some(dev_server_connection) = dev_server_connection {
2767 session
2768 .connection_pool()
2769 .await
2770 .remove_connection(dev_server_connection)?;
2771 }
2772 Ok(())
2773}
2774
2775/// Updates other participants with changes to the project
2776async fn update_project(
2777 request: proto::UpdateProject,
2778 response: Response<proto::UpdateProject>,
2779 session: Session,
2780) -> Result<()> {
2781 let project_id = ProjectId::from_proto(request.project_id);
2782 let (room, guest_connection_ids) = &*session
2783 .db()
2784 .await
2785 .update_project(project_id, session.connection_id, &request.worktrees)
2786 .await?;
2787 broadcast(
2788 Some(session.connection_id),
2789 guest_connection_ids.iter().copied(),
2790 |connection_id| {
2791 session
2792 .peer
2793 .forward_send(session.connection_id, connection_id, request.clone())
2794 },
2795 );
2796 if let Some(room) = room {
2797 room_updated(&room, &session.peer);
2798 }
2799 response.send(proto::Ack {})?;
2800
2801 Ok(())
2802}
2803
2804/// Updates other participants with changes to the worktree
2805async fn update_worktree(
2806 request: proto::UpdateWorktree,
2807 response: Response<proto::UpdateWorktree>,
2808 session: Session,
2809) -> Result<()> {
2810 let guest_connection_ids = session
2811 .db()
2812 .await
2813 .update_worktree(&request, session.connection_id)
2814 .await?;
2815
2816 broadcast(
2817 Some(session.connection_id),
2818 guest_connection_ids.iter().copied(),
2819 |connection_id| {
2820 session
2821 .peer
2822 .forward_send(session.connection_id, connection_id, request.clone())
2823 },
2824 );
2825 response.send(proto::Ack {})?;
2826 Ok(())
2827}
2828
2829/// Updates other participants with changes to the diagnostics
2830async fn update_diagnostic_summary(
2831 message: proto::UpdateDiagnosticSummary,
2832 session: Session,
2833) -> Result<()> {
2834 let guest_connection_ids = session
2835 .db()
2836 .await
2837 .update_diagnostic_summary(&message, session.connection_id)
2838 .await?;
2839
2840 broadcast(
2841 Some(session.connection_id),
2842 guest_connection_ids.iter().copied(),
2843 |connection_id| {
2844 session
2845 .peer
2846 .forward_send(session.connection_id, connection_id, message.clone())
2847 },
2848 );
2849
2850 Ok(())
2851}
2852
2853/// Updates other participants with changes to the worktree settings
2854async fn update_worktree_settings(
2855 message: proto::UpdateWorktreeSettings,
2856 session: Session,
2857) -> Result<()> {
2858 let guest_connection_ids = session
2859 .db()
2860 .await
2861 .update_worktree_settings(&message, session.connection_id)
2862 .await?;
2863
2864 broadcast(
2865 Some(session.connection_id),
2866 guest_connection_ids.iter().copied(),
2867 |connection_id| {
2868 session
2869 .peer
2870 .forward_send(session.connection_id, connection_id, message.clone())
2871 },
2872 );
2873
2874 Ok(())
2875}
2876
2877/// Notify other participants that a language server has started.
2878async fn start_language_server(
2879 request: proto::StartLanguageServer,
2880 session: Session,
2881) -> Result<()> {
2882 let guest_connection_ids = session
2883 .db()
2884 .await
2885 .start_language_server(&request, session.connection_id)
2886 .await?;
2887
2888 broadcast(
2889 Some(session.connection_id),
2890 guest_connection_ids.iter().copied(),
2891 |connection_id| {
2892 session
2893 .peer
2894 .forward_send(session.connection_id, connection_id, request.clone())
2895 },
2896 );
2897 Ok(())
2898}
2899
2900/// Notify other participants that a language server has changed.
2901async fn update_language_server(
2902 request: proto::UpdateLanguageServer,
2903 session: Session,
2904) -> Result<()> {
2905 let project_id = ProjectId::from_proto(request.project_id);
2906 let project_connection_ids = session
2907 .db()
2908 .await
2909 .project_connection_ids(project_id, session.connection_id, true)
2910 .await?;
2911 broadcast(
2912 Some(session.connection_id),
2913 project_connection_ids.iter().copied(),
2914 |connection_id| {
2915 session
2916 .peer
2917 .forward_send(session.connection_id, connection_id, request.clone())
2918 },
2919 );
2920 Ok(())
2921}
2922
2923/// forward a project request to the host. These requests should be read only
2924/// as guests are allowed to send them.
2925async fn forward_read_only_project_request<T>(
2926 request: T,
2927 response: Response<T>,
2928 session: UserSession,
2929) -> Result<()>
2930where
2931 T: EntityMessage + RequestMessage,
2932{
2933 let project_id = ProjectId::from_proto(request.remote_entity_id());
2934 let host_connection_id = session
2935 .db()
2936 .await
2937 .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2938 .await?;
2939 let payload = session
2940 .peer
2941 .forward_request(session.connection_id, host_connection_id, request)
2942 .await?;
2943 response.send(payload)?;
2944 Ok(())
2945}
2946
2947async fn forward_find_search_candidates_request(
2948 request: proto::FindSearchCandidates,
2949 response: Response<proto::FindSearchCandidates>,
2950 session: UserSession,
2951) -> Result<()> {
2952 let project_id = ProjectId::from_proto(request.remote_entity_id());
2953 let host_connection_id = session
2954 .db()
2955 .await
2956 .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2957 .await?;
2958
2959 let host_version = session
2960 .connection_pool()
2961 .await
2962 .connection(host_connection_id)
2963 .map(|c| c.zed_version);
2964
2965 if host_version.is_some_and(|host_version| host_version < ZedVersion::with_search_candidates())
2966 {
2967 let query = request.query.ok_or_else(|| anyhow!("missing query"))?;
2968 let search = proto::SearchProject {
2969 project_id: project_id.to_proto(),
2970 query: query.query,
2971 regex: query.regex,
2972 whole_word: query.whole_word,
2973 case_sensitive: query.case_sensitive,
2974 files_to_include: query.files_to_include,
2975 files_to_exclude: query.files_to_exclude,
2976 include_ignored: query.include_ignored,
2977 };
2978
2979 let payload = session
2980 .peer
2981 .forward_request(session.connection_id, host_connection_id, search)
2982 .await?;
2983 return response.send(proto::FindSearchCandidatesResponse {
2984 buffer_ids: payload
2985 .locations
2986 .into_iter()
2987 .map(|loc| loc.buffer_id)
2988 .collect(),
2989 });
2990 }
2991
2992 let payload = session
2993 .peer
2994 .forward_request(session.connection_id, host_connection_id, request)
2995 .await?;
2996 response.send(payload)?;
2997 Ok(())
2998}
2999
3000/// forward a project request to the dev server. Only allowed
3001/// if it's your dev server.
3002async fn forward_project_request_for_owner<T>(
3003 request: T,
3004 response: Response<T>,
3005 session: UserSession,
3006) -> Result<()>
3007where
3008 T: EntityMessage + RequestMessage,
3009{
3010 let project_id = ProjectId::from_proto(request.remote_entity_id());
3011
3012 let host_connection_id = session
3013 .db()
3014 .await
3015 .host_for_owner_project_request(project_id, session.connection_id, session.user_id())
3016 .await?;
3017 let payload = session
3018 .peer
3019 .forward_request(session.connection_id, host_connection_id, request)
3020 .await?;
3021 response.send(payload)?;
3022 Ok(())
3023}
3024
3025/// forward a project request to the host. These requests are disallowed
3026/// for guests.
3027async fn forward_mutating_project_request<T>(
3028 request: T,
3029 response: Response<T>,
3030 session: UserSession,
3031) -> Result<()>
3032where
3033 T: EntityMessage + RequestMessage,
3034{
3035 let project_id = ProjectId::from_proto(request.remote_entity_id());
3036
3037 let host_connection_id = session
3038 .db()
3039 .await
3040 .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
3041 .await?;
3042 let payload = session
3043 .peer
3044 .forward_request(session.connection_id, host_connection_id, request)
3045 .await?;
3046 response.send(payload)?;
3047 Ok(())
3048}
3049
3050/// forward a project request to the host. These requests are disallowed
3051/// for guests.
3052async fn forward_versioned_mutating_project_request<T>(
3053 request: T,
3054 response: Response<T>,
3055 session: UserSession,
3056) -> Result<()>
3057where
3058 T: EntityMessage + RequestMessage + VersionedMessage,
3059{
3060 let project_id = ProjectId::from_proto(request.remote_entity_id());
3061
3062 let host_connection_id = session
3063 .db()
3064 .await
3065 .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
3066 .await?;
3067 if let Some(host_version) = session
3068 .connection_pool()
3069 .await
3070 .connection(host_connection_id)
3071 .map(|c| c.zed_version)
3072 {
3073 if let Some(min_required_version) = request.required_host_version() {
3074 if min_required_version > host_version {
3075 return Err(anyhow!(ErrorCode::RemoteUpgradeRequired
3076 .with_tag("required", &min_required_version.to_string())))?;
3077 }
3078 }
3079 }
3080
3081 let payload = session
3082 .peer
3083 .forward_request(session.connection_id, host_connection_id, request)
3084 .await?;
3085 response.send(payload)?;
3086 Ok(())
3087}
3088
3089/// Notify other participants that a new buffer has been created
3090async fn create_buffer_for_peer(
3091 request: proto::CreateBufferForPeer,
3092 session: Session,
3093) -> Result<()> {
3094 session
3095 .db()
3096 .await
3097 .check_user_is_project_host(
3098 ProjectId::from_proto(request.project_id),
3099 session.connection_id,
3100 )
3101 .await?;
3102 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
3103 session
3104 .peer
3105 .forward_send(session.connection_id, peer_id.into(), request)?;
3106 Ok(())
3107}
3108
3109/// Notify other participants that a buffer has been updated. This is
3110/// allowed for guests as long as the update is limited to selections.
3111async fn update_buffer(
3112 request: proto::UpdateBuffer,
3113 response: Response<proto::UpdateBuffer>,
3114 session: Session,
3115) -> Result<()> {
3116 let project_id = ProjectId::from_proto(request.project_id);
3117 let mut capability = Capability::ReadOnly;
3118
3119 for op in request.operations.iter() {
3120 match op.variant {
3121 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
3122 Some(_) => capability = Capability::ReadWrite,
3123 }
3124 }
3125
3126 let host = {
3127 let guard = session
3128 .db()
3129 .await
3130 .connections_for_buffer_update(
3131 project_id,
3132 session.principal_id(),
3133 session.connection_id,
3134 capability,
3135 )
3136 .await?;
3137
3138 let (host, guests) = &*guard;
3139
3140 broadcast(
3141 Some(session.connection_id),
3142 guests.clone(),
3143 |connection_id| {
3144 session
3145 .peer
3146 .forward_send(session.connection_id, connection_id, request.clone())
3147 },
3148 );
3149
3150 *host
3151 };
3152
3153 if host != session.connection_id {
3154 session
3155 .peer
3156 .forward_request(session.connection_id, host, request.clone())
3157 .await?;
3158 }
3159
3160 response.send(proto::Ack {})?;
3161 Ok(())
3162}
3163
3164async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
3165 let project_id = ProjectId::from_proto(message.project_id);
3166
3167 let operation = message.operation.as_ref().context("invalid operation")?;
3168 let capability = match operation.variant.as_ref() {
3169 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
3170 if let Some(buffer_op) = buffer_op.operation.as_ref() {
3171 match buffer_op.variant {
3172 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
3173 Capability::ReadOnly
3174 }
3175 _ => Capability::ReadWrite,
3176 }
3177 } else {
3178 Capability::ReadWrite
3179 }
3180 }
3181 Some(_) => Capability::ReadWrite,
3182 None => Capability::ReadOnly,
3183 };
3184
3185 let guard = session
3186 .db()
3187 .await
3188 .connections_for_buffer_update(
3189 project_id,
3190 session.principal_id(),
3191 session.connection_id,
3192 capability,
3193 )
3194 .await?;
3195
3196 let (host, guests) = &*guard;
3197
3198 broadcast(
3199 Some(session.connection_id),
3200 guests.iter().chain([host]).copied(),
3201 |connection_id| {
3202 session
3203 .peer
3204 .forward_send(session.connection_id, connection_id, message.clone())
3205 },
3206 );
3207
3208 Ok(())
3209}
3210
3211/// Notify other participants that a project has been updated.
3212async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
3213 request: T,
3214 session: Session,
3215) -> Result<()> {
3216 let project_id = ProjectId::from_proto(request.remote_entity_id());
3217 let project_connection_ids = session
3218 .db()
3219 .await
3220 .project_connection_ids(project_id, session.connection_id, false)
3221 .await?;
3222
3223 broadcast(
3224 Some(session.connection_id),
3225 project_connection_ids.iter().copied(),
3226 |connection_id| {
3227 session
3228 .peer
3229 .forward_send(session.connection_id, connection_id, request.clone())
3230 },
3231 );
3232 Ok(())
3233}
3234
3235/// Start following another user in a call.
3236async fn follow(
3237 request: proto::Follow,
3238 response: Response<proto::Follow>,
3239 session: UserSession,
3240) -> Result<()> {
3241 let room_id = RoomId::from_proto(request.room_id);
3242 let project_id = request.project_id.map(ProjectId::from_proto);
3243 let leader_id = request
3244 .leader_id
3245 .ok_or_else(|| anyhow!("invalid leader id"))?
3246 .into();
3247 let follower_id = session.connection_id;
3248
3249 session
3250 .db()
3251 .await
3252 .check_room_participants(room_id, leader_id, session.connection_id)
3253 .await?;
3254
3255 let response_payload = session
3256 .peer
3257 .forward_request(session.connection_id, leader_id, request)
3258 .await?;
3259 response.send(response_payload)?;
3260
3261 if let Some(project_id) = project_id {
3262 let room = session
3263 .db()
3264 .await
3265 .follow(room_id, project_id, leader_id, follower_id)
3266 .await?;
3267 room_updated(&room, &session.peer);
3268 }
3269
3270 Ok(())
3271}
3272
3273/// Stop following another user in a call.
3274async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
3275 let room_id = RoomId::from_proto(request.room_id);
3276 let project_id = request.project_id.map(ProjectId::from_proto);
3277 let leader_id = request
3278 .leader_id
3279 .ok_or_else(|| anyhow!("invalid leader id"))?
3280 .into();
3281 let follower_id = session.connection_id;
3282
3283 session
3284 .db()
3285 .await
3286 .check_room_participants(room_id, leader_id, session.connection_id)
3287 .await?;
3288
3289 session
3290 .peer
3291 .forward_send(session.connection_id, leader_id, request)?;
3292
3293 if let Some(project_id) = project_id {
3294 let room = session
3295 .db()
3296 .await
3297 .unfollow(room_id, project_id, leader_id, follower_id)
3298 .await?;
3299 room_updated(&room, &session.peer);
3300 }
3301
3302 Ok(())
3303}
3304
3305/// Notify everyone following you of your current location.
3306async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
3307 let room_id = RoomId::from_proto(request.room_id);
3308 let database = session.db.lock().await;
3309
3310 let connection_ids = if let Some(project_id) = request.project_id {
3311 let project_id = ProjectId::from_proto(project_id);
3312 database
3313 .project_connection_ids(project_id, session.connection_id, true)
3314 .await?
3315 } else {
3316 database
3317 .room_connection_ids(room_id, session.connection_id)
3318 .await?
3319 };
3320
3321 // For now, don't send view update messages back to that view's current leader.
3322 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
3323 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
3324 _ => None,
3325 });
3326
3327 for connection_id in connection_ids.iter().cloned() {
3328 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
3329 session
3330 .peer
3331 .forward_send(session.connection_id, connection_id, request.clone())?;
3332 }
3333 }
3334 Ok(())
3335}
3336
3337/// Get public data about users.
3338async fn get_users(
3339 request: proto::GetUsers,
3340 response: Response<proto::GetUsers>,
3341 session: Session,
3342) -> Result<()> {
3343 let user_ids = request
3344 .user_ids
3345 .into_iter()
3346 .map(UserId::from_proto)
3347 .collect();
3348 let users = session
3349 .db()
3350 .await
3351 .get_users_by_ids(user_ids)
3352 .await?
3353 .into_iter()
3354 .map(|user| proto::User {
3355 id: user.id.to_proto(),
3356 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3357 github_login: user.github_login,
3358 })
3359 .collect();
3360 response.send(proto::UsersResponse { users })?;
3361 Ok(())
3362}
3363
3364/// Search for users (to invite) buy Github login
3365async fn fuzzy_search_users(
3366 request: proto::FuzzySearchUsers,
3367 response: Response<proto::FuzzySearchUsers>,
3368 session: UserSession,
3369) -> Result<()> {
3370 let query = request.query;
3371 let users = match query.len() {
3372 0 => vec![],
3373 1 | 2 => session
3374 .db()
3375 .await
3376 .get_user_by_github_login(&query)
3377 .await?
3378 .into_iter()
3379 .collect(),
3380 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
3381 };
3382 let users = users
3383 .into_iter()
3384 .filter(|user| user.id != session.user_id())
3385 .map(|user| proto::User {
3386 id: user.id.to_proto(),
3387 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3388 github_login: user.github_login,
3389 })
3390 .collect();
3391 response.send(proto::UsersResponse { users })?;
3392 Ok(())
3393}
3394
3395/// Send a contact request to another user.
3396async fn request_contact(
3397 request: proto::RequestContact,
3398 response: Response<proto::RequestContact>,
3399 session: UserSession,
3400) -> Result<()> {
3401 let requester_id = session.user_id();
3402 let responder_id = UserId::from_proto(request.responder_id);
3403 if requester_id == responder_id {
3404 return Err(anyhow!("cannot add yourself as a contact"))?;
3405 }
3406
3407 let notifications = session
3408 .db()
3409 .await
3410 .send_contact_request(requester_id, responder_id)
3411 .await?;
3412
3413 // Update outgoing contact requests of requester
3414 let mut update = proto::UpdateContacts::default();
3415 update.outgoing_requests.push(responder_id.to_proto());
3416 for connection_id in session
3417 .connection_pool()
3418 .await
3419 .user_connection_ids(requester_id)
3420 {
3421 session.peer.send(connection_id, update.clone())?;
3422 }
3423
3424 // Update incoming contact requests of responder
3425 let mut update = proto::UpdateContacts::default();
3426 update
3427 .incoming_requests
3428 .push(proto::IncomingContactRequest {
3429 requester_id: requester_id.to_proto(),
3430 });
3431 let connection_pool = session.connection_pool().await;
3432 for connection_id in connection_pool.user_connection_ids(responder_id) {
3433 session.peer.send(connection_id, update.clone())?;
3434 }
3435
3436 send_notifications(&connection_pool, &session.peer, notifications);
3437
3438 response.send(proto::Ack {})?;
3439 Ok(())
3440}
3441
3442/// Accept or decline a contact request
3443async fn respond_to_contact_request(
3444 request: proto::RespondToContactRequest,
3445 response: Response<proto::RespondToContactRequest>,
3446 session: UserSession,
3447) -> Result<()> {
3448 let responder_id = session.user_id();
3449 let requester_id = UserId::from_proto(request.requester_id);
3450 let db = session.db().await;
3451 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
3452 db.dismiss_contact_notification(responder_id, requester_id)
3453 .await?;
3454 } else {
3455 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
3456
3457 let notifications = db
3458 .respond_to_contact_request(responder_id, requester_id, accept)
3459 .await?;
3460 let requester_busy = db.is_user_busy(requester_id).await?;
3461 let responder_busy = db.is_user_busy(responder_id).await?;
3462
3463 let pool = session.connection_pool().await;
3464 // Update responder with new contact
3465 let mut update = proto::UpdateContacts::default();
3466 if accept {
3467 update
3468 .contacts
3469 .push(contact_for_user(requester_id, requester_busy, &pool));
3470 }
3471 update
3472 .remove_incoming_requests
3473 .push(requester_id.to_proto());
3474 for connection_id in pool.user_connection_ids(responder_id) {
3475 session.peer.send(connection_id, update.clone())?;
3476 }
3477
3478 // Update requester with new contact
3479 let mut update = proto::UpdateContacts::default();
3480 if accept {
3481 update
3482 .contacts
3483 .push(contact_for_user(responder_id, responder_busy, &pool));
3484 }
3485 update
3486 .remove_outgoing_requests
3487 .push(responder_id.to_proto());
3488
3489 for connection_id in pool.user_connection_ids(requester_id) {
3490 session.peer.send(connection_id, update.clone())?;
3491 }
3492
3493 send_notifications(&pool, &session.peer, notifications);
3494 }
3495
3496 response.send(proto::Ack {})?;
3497 Ok(())
3498}
3499
3500/// Remove a contact.
3501async fn remove_contact(
3502 request: proto::RemoveContact,
3503 response: Response<proto::RemoveContact>,
3504 session: UserSession,
3505) -> Result<()> {
3506 let requester_id = session.user_id();
3507 let responder_id = UserId::from_proto(request.user_id);
3508 let db = session.db().await;
3509 let (contact_accepted, deleted_notification_id) =
3510 db.remove_contact(requester_id, responder_id).await?;
3511
3512 let pool = session.connection_pool().await;
3513 // Update outgoing contact requests of requester
3514 let mut update = proto::UpdateContacts::default();
3515 if contact_accepted {
3516 update.remove_contacts.push(responder_id.to_proto());
3517 } else {
3518 update
3519 .remove_outgoing_requests
3520 .push(responder_id.to_proto());
3521 }
3522 for connection_id in pool.user_connection_ids(requester_id) {
3523 session.peer.send(connection_id, update.clone())?;
3524 }
3525
3526 // Update incoming contact requests of responder
3527 let mut update = proto::UpdateContacts::default();
3528 if contact_accepted {
3529 update.remove_contacts.push(requester_id.to_proto());
3530 } else {
3531 update
3532 .remove_incoming_requests
3533 .push(requester_id.to_proto());
3534 }
3535 for connection_id in pool.user_connection_ids(responder_id) {
3536 session.peer.send(connection_id, update.clone())?;
3537 if let Some(notification_id) = deleted_notification_id {
3538 session.peer.send(
3539 connection_id,
3540 proto::DeleteNotification {
3541 notification_id: notification_id.to_proto(),
3542 },
3543 )?;
3544 }
3545 }
3546
3547 response.send(proto::Ack {})?;
3548 Ok(())
3549}
3550
3551fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
3552 version.0.minor() < 139
3553}
3554
3555async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
3556 let plan = session.current_plan(session.db().await).await?;
3557
3558 session
3559 .peer
3560 .send(
3561 session.connection_id,
3562 proto::UpdateUserPlan { plan: plan.into() },
3563 )
3564 .trace_err();
3565
3566 Ok(())
3567}
3568
3569async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
3570 subscribe_user_to_channels(
3571 session.user_id().ok_or_else(|| anyhow!("must be a user"))?,
3572 &session,
3573 )
3574 .await?;
3575 Ok(())
3576}
3577
3578async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
3579 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3580 let mut pool = session.connection_pool().await;
3581 for membership in &channels_for_user.channel_memberships {
3582 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3583 }
3584 session.peer.send(
3585 session.connection_id,
3586 build_update_user_channels(&channels_for_user),
3587 )?;
3588 session.peer.send(
3589 session.connection_id,
3590 build_channels_update(channels_for_user),
3591 )?;
3592 Ok(())
3593}
3594
3595/// Creates a new channel.
3596async fn create_channel(
3597 request: proto::CreateChannel,
3598 response: Response<proto::CreateChannel>,
3599 session: UserSession,
3600) -> Result<()> {
3601 let db = session.db().await;
3602
3603 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
3604 let (channel, membership) = db
3605 .create_channel(&request.name, parent_id, session.user_id())
3606 .await?;
3607
3608 let root_id = channel.root_id();
3609 let channel = Channel::from_model(channel);
3610
3611 response.send(proto::CreateChannelResponse {
3612 channel: Some(channel.to_proto()),
3613 parent_id: request.parent_id,
3614 })?;
3615
3616 let mut connection_pool = session.connection_pool().await;
3617 if let Some(membership) = membership {
3618 connection_pool.subscribe_to_channel(
3619 membership.user_id,
3620 membership.channel_id,
3621 membership.role,
3622 );
3623 let update = proto::UpdateUserChannels {
3624 channel_memberships: vec![proto::ChannelMembership {
3625 channel_id: membership.channel_id.to_proto(),
3626 role: membership.role.into(),
3627 }],
3628 ..Default::default()
3629 };
3630 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3631 session.peer.send(connection_id, update.clone())?;
3632 }
3633 }
3634
3635 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3636 if !role.can_see_channel(channel.visibility) {
3637 continue;
3638 }
3639
3640 let update = proto::UpdateChannels {
3641 channels: vec![channel.to_proto()],
3642 ..Default::default()
3643 };
3644 session.peer.send(connection_id, update.clone())?;
3645 }
3646
3647 Ok(())
3648}
3649
3650/// Delete a channel
3651async fn delete_channel(
3652 request: proto::DeleteChannel,
3653 response: Response<proto::DeleteChannel>,
3654 session: UserSession,
3655) -> Result<()> {
3656 let db = session.db().await;
3657
3658 let channel_id = request.channel_id;
3659 let (root_channel, removed_channels) = db
3660 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3661 .await?;
3662 response.send(proto::Ack {})?;
3663
3664 // Notify members of removed channels
3665 let mut update = proto::UpdateChannels::default();
3666 update
3667 .delete_channels
3668 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3669
3670 let connection_pool = session.connection_pool().await;
3671 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3672 session.peer.send(connection_id, update.clone())?;
3673 }
3674
3675 Ok(())
3676}
3677
3678/// Invite someone to join a channel.
3679async fn invite_channel_member(
3680 request: proto::InviteChannelMember,
3681 response: Response<proto::InviteChannelMember>,
3682 session: UserSession,
3683) -> Result<()> {
3684 let db = session.db().await;
3685 let channel_id = ChannelId::from_proto(request.channel_id);
3686 let invitee_id = UserId::from_proto(request.user_id);
3687 let InviteMemberResult {
3688 channel,
3689 notifications,
3690 } = db
3691 .invite_channel_member(
3692 channel_id,
3693 invitee_id,
3694 session.user_id(),
3695 request.role().into(),
3696 )
3697 .await?;
3698
3699 let update = proto::UpdateChannels {
3700 channel_invitations: vec![channel.to_proto()],
3701 ..Default::default()
3702 };
3703
3704 let connection_pool = session.connection_pool().await;
3705 for connection_id in connection_pool.user_connection_ids(invitee_id) {
3706 session.peer.send(connection_id, update.clone())?;
3707 }
3708
3709 send_notifications(&connection_pool, &session.peer, notifications);
3710
3711 response.send(proto::Ack {})?;
3712 Ok(())
3713}
3714
3715/// remove someone from a channel
3716async fn remove_channel_member(
3717 request: proto::RemoveChannelMember,
3718 response: Response<proto::RemoveChannelMember>,
3719 session: UserSession,
3720) -> Result<()> {
3721 let db = session.db().await;
3722 let channel_id = ChannelId::from_proto(request.channel_id);
3723 let member_id = UserId::from_proto(request.user_id);
3724
3725 let RemoveChannelMemberResult {
3726 membership_update,
3727 notification_id,
3728 } = db
3729 .remove_channel_member(channel_id, member_id, session.user_id())
3730 .await?;
3731
3732 let mut connection_pool = session.connection_pool().await;
3733 notify_membership_updated(
3734 &mut connection_pool,
3735 membership_update,
3736 member_id,
3737 &session.peer,
3738 );
3739 for connection_id in connection_pool.user_connection_ids(member_id) {
3740 if let Some(notification_id) = notification_id {
3741 session
3742 .peer
3743 .send(
3744 connection_id,
3745 proto::DeleteNotification {
3746 notification_id: notification_id.to_proto(),
3747 },
3748 )
3749 .trace_err();
3750 }
3751 }
3752
3753 response.send(proto::Ack {})?;
3754 Ok(())
3755}
3756
3757/// Toggle the channel between public and private.
3758/// Care is taken to maintain the invariant that public channels only descend from public channels,
3759/// (though members-only channels can appear at any point in the hierarchy).
3760async fn set_channel_visibility(
3761 request: proto::SetChannelVisibility,
3762 response: Response<proto::SetChannelVisibility>,
3763 session: UserSession,
3764) -> Result<()> {
3765 let db = session.db().await;
3766 let channel_id = ChannelId::from_proto(request.channel_id);
3767 let visibility = request.visibility().into();
3768
3769 let channel_model = db
3770 .set_channel_visibility(channel_id, visibility, session.user_id())
3771 .await?;
3772 let root_id = channel_model.root_id();
3773 let channel = Channel::from_model(channel_model);
3774
3775 let mut connection_pool = session.connection_pool().await;
3776 for (user_id, role) in connection_pool
3777 .channel_user_ids(root_id)
3778 .collect::<Vec<_>>()
3779 .into_iter()
3780 {
3781 let update = if role.can_see_channel(channel.visibility) {
3782 connection_pool.subscribe_to_channel(user_id, channel_id, role);
3783 proto::UpdateChannels {
3784 channels: vec![channel.to_proto()],
3785 ..Default::default()
3786 }
3787 } else {
3788 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3789 proto::UpdateChannels {
3790 delete_channels: vec![channel.id.to_proto()],
3791 ..Default::default()
3792 }
3793 };
3794
3795 for connection_id in connection_pool.user_connection_ids(user_id) {
3796 session.peer.send(connection_id, update.clone())?;
3797 }
3798 }
3799
3800 response.send(proto::Ack {})?;
3801 Ok(())
3802}
3803
3804/// Alter the role for a user in the channel.
3805async fn set_channel_member_role(
3806 request: proto::SetChannelMemberRole,
3807 response: Response<proto::SetChannelMemberRole>,
3808 session: UserSession,
3809) -> Result<()> {
3810 let db = session.db().await;
3811 let channel_id = ChannelId::from_proto(request.channel_id);
3812 let member_id = UserId::from_proto(request.user_id);
3813 let result = db
3814 .set_channel_member_role(
3815 channel_id,
3816 session.user_id(),
3817 member_id,
3818 request.role().into(),
3819 )
3820 .await?;
3821
3822 match result {
3823 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3824 let mut connection_pool = session.connection_pool().await;
3825 notify_membership_updated(
3826 &mut connection_pool,
3827 membership_update,
3828 member_id,
3829 &session.peer,
3830 )
3831 }
3832 db::SetMemberRoleResult::InviteUpdated(channel) => {
3833 let update = proto::UpdateChannels {
3834 channel_invitations: vec![channel.to_proto()],
3835 ..Default::default()
3836 };
3837
3838 for connection_id in session
3839 .connection_pool()
3840 .await
3841 .user_connection_ids(member_id)
3842 {
3843 session.peer.send(connection_id, update.clone())?;
3844 }
3845 }
3846 }
3847
3848 response.send(proto::Ack {})?;
3849 Ok(())
3850}
3851
3852/// Change the name of a channel
3853async fn rename_channel(
3854 request: proto::RenameChannel,
3855 response: Response<proto::RenameChannel>,
3856 session: UserSession,
3857) -> Result<()> {
3858 let db = session.db().await;
3859 let channel_id = ChannelId::from_proto(request.channel_id);
3860 let channel_model = db
3861 .rename_channel(channel_id, session.user_id(), &request.name)
3862 .await?;
3863 let root_id = channel_model.root_id();
3864 let channel = Channel::from_model(channel_model);
3865
3866 response.send(proto::RenameChannelResponse {
3867 channel: Some(channel.to_proto()),
3868 })?;
3869
3870 let connection_pool = session.connection_pool().await;
3871 let update = proto::UpdateChannels {
3872 channels: vec![channel.to_proto()],
3873 ..Default::default()
3874 };
3875 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3876 if role.can_see_channel(channel.visibility) {
3877 session.peer.send(connection_id, update.clone())?;
3878 }
3879 }
3880
3881 Ok(())
3882}
3883
3884/// Move a channel to a new parent.
3885async fn move_channel(
3886 request: proto::MoveChannel,
3887 response: Response<proto::MoveChannel>,
3888 session: UserSession,
3889) -> Result<()> {
3890 let channel_id = ChannelId::from_proto(request.channel_id);
3891 let to = ChannelId::from_proto(request.to);
3892
3893 let (root_id, channels) = session
3894 .db()
3895 .await
3896 .move_channel(channel_id, to, session.user_id())
3897 .await?;
3898
3899 let connection_pool = session.connection_pool().await;
3900 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3901 let channels = channels
3902 .iter()
3903 .filter_map(|channel| {
3904 if role.can_see_channel(channel.visibility) {
3905 Some(channel.to_proto())
3906 } else {
3907 None
3908 }
3909 })
3910 .collect::<Vec<_>>();
3911 if channels.is_empty() {
3912 continue;
3913 }
3914
3915 let update = proto::UpdateChannels {
3916 channels,
3917 ..Default::default()
3918 };
3919
3920 session.peer.send(connection_id, update.clone())?;
3921 }
3922
3923 response.send(Ack {})?;
3924 Ok(())
3925}
3926
3927/// Get the list of channel members
3928async fn get_channel_members(
3929 request: proto::GetChannelMembers,
3930 response: Response<proto::GetChannelMembers>,
3931 session: UserSession,
3932) -> Result<()> {
3933 let db = session.db().await;
3934 let channel_id = ChannelId::from_proto(request.channel_id);
3935 let limit = if request.limit == 0 {
3936 u16::MAX as u64
3937 } else {
3938 request.limit
3939 };
3940 let (members, users) = db
3941 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3942 .await?;
3943 response.send(proto::GetChannelMembersResponse { members, users })?;
3944 Ok(())
3945}
3946
3947/// Accept or decline a channel invitation.
3948async fn respond_to_channel_invite(
3949 request: proto::RespondToChannelInvite,
3950 response: Response<proto::RespondToChannelInvite>,
3951 session: UserSession,
3952) -> Result<()> {
3953 let db = session.db().await;
3954 let channel_id = ChannelId::from_proto(request.channel_id);
3955 let RespondToChannelInvite {
3956 membership_update,
3957 notifications,
3958 } = db
3959 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3960 .await?;
3961
3962 let mut connection_pool = session.connection_pool().await;
3963 if let Some(membership_update) = membership_update {
3964 notify_membership_updated(
3965 &mut connection_pool,
3966 membership_update,
3967 session.user_id(),
3968 &session.peer,
3969 );
3970 } else {
3971 let update = proto::UpdateChannels {
3972 remove_channel_invitations: vec![channel_id.to_proto()],
3973 ..Default::default()
3974 };
3975
3976 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3977 session.peer.send(connection_id, update.clone())?;
3978 }
3979 };
3980
3981 send_notifications(&connection_pool, &session.peer, notifications);
3982
3983 response.send(proto::Ack {})?;
3984
3985 Ok(())
3986}
3987
3988/// Join the channels' room
3989async fn join_channel(
3990 request: proto::JoinChannel,
3991 response: Response<proto::JoinChannel>,
3992 session: UserSession,
3993) -> Result<()> {
3994 let channel_id = ChannelId::from_proto(request.channel_id);
3995 join_channel_internal(channel_id, Box::new(response), session).await
3996}
3997
/// A response handle that can deliver a `proto::JoinRoomResponse`.
///
/// Implemented for the `JoinChannel` and `JoinRoom` response handles so that
/// `join_channel_internal` can serve either request type.
trait JoinChannelInternalResponse {
    /// Send the join result back to the requesting client.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // The type-qualified path resolves to `Response`'s inherent `send`
        // (inherent items take precedence over trait items), so this does not
        // recurse into the trait method of the same name.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Same as above: delegates to the inherent `Response::send`.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
4011
/// Shared implementation for joining a channel's room on behalf of the
/// session's user, reporting the result through any response handle that
/// implements `JoinChannelInternalResponse`.
///
/// Sequence:
/// 1. If the user still has a stale room connection, leave that room first.
/// 2. Join the channel's room in the database. This also yields any membership
///    change and the user's role in the channel.
/// 3. Reply with the room, the channel id, and LiveKit connection info. The
///    LiveKit info is present only when a LiveKit client is configured and
///    token generation succeeds.
/// 4. Propagate any membership change and broadcast the room update.
/// 5. Broadcast the channel update and refresh the user's contacts.
///
/// # Errors
/// Returns an error if any of these fail:
/// - a database step,
/// - leaving the stale room,
/// - the reply,
/// - the contacts update.
///
/// Also returns an error if the joined room comes back without a channel.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: UserSession,
) -> Result<()> {
    // The database guard (and, later, the connection-pool guard) live only
    // inside this block. Both are dropped before the pool is acquired again
    // for `channel_updated` below.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database guard while leaving the stale room, then
            // re-acquire it. NOTE(review): presumably `leave_room_for_session`
            // takes the database lock itself; confirm.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // Guests get a `guest_token` and may not publish; every other role
        // gets a `room_token` and may publish. A token error is logged
        // (`trace_err`) and yields `None`, so no LiveKit info is sent.
        let live_kit_connection_info =
            session
                .app_state
                .live_kit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.live_kit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.live_kit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
4107
4108/// Start editing the channel notes
4109async fn join_channel_buffer(
4110 request: proto::JoinChannelBuffer,
4111 response: Response<proto::JoinChannelBuffer>,
4112 session: UserSession,
4113) -> Result<()> {
4114 let db = session.db().await;
4115 let channel_id = ChannelId::from_proto(request.channel_id);
4116
4117 let open_response = db
4118 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
4119 .await?;
4120
4121 let collaborators = open_response.collaborators.clone();
4122 response.send(open_response)?;
4123
4124 let update = UpdateChannelBufferCollaborators {
4125 channel_id: channel_id.to_proto(),
4126 collaborators: collaborators.clone(),
4127 };
4128 channel_buffer_updated(
4129 session.connection_id,
4130 collaborators
4131 .iter()
4132 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
4133 &update,
4134 &session.peer,
4135 );
4136
4137 Ok(())
4138}
4139
4140/// Edit the channel notes
4141async fn update_channel_buffer(
4142 request: proto::UpdateChannelBuffer,
4143 session: UserSession,
4144) -> Result<()> {
4145 let db = session.db().await;
4146 let channel_id = ChannelId::from_proto(request.channel_id);
4147
4148 let (collaborators, epoch, version) = db
4149 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
4150 .await?;
4151
4152 channel_buffer_updated(
4153 session.connection_id,
4154 collaborators.clone(),
4155 &proto::UpdateChannelBuffer {
4156 channel_id: channel_id.to_proto(),
4157 operations: request.operations,
4158 },
4159 &session.peer,
4160 );
4161
4162 let pool = &*session.connection_pool().await;
4163
4164 let non_collaborators =
4165 pool.channel_connection_ids(channel_id)
4166 .filter_map(|(connection_id, _)| {
4167 if collaborators.contains(&connection_id) {
4168 None
4169 } else {
4170 Some(connection_id)
4171 }
4172 });
4173
4174 broadcast(None, non_collaborators, |peer_id| {
4175 session.peer.send(
4176 peer_id,
4177 proto::UpdateChannels {
4178 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
4179 channel_id: channel_id.to_proto(),
4180 epoch: epoch as u64,
4181 version: version.clone(),
4182 }],
4183 ..Default::default()
4184 },
4185 )
4186 });
4187
4188 Ok(())
4189}
4190
4191/// Rejoin the channel notes after a connection blip
4192async fn rejoin_channel_buffers(
4193 request: proto::RejoinChannelBuffers,
4194 response: Response<proto::RejoinChannelBuffers>,
4195 session: UserSession,
4196) -> Result<()> {
4197 let db = session.db().await;
4198 let buffers = db
4199 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
4200 .await?;
4201
4202 for rejoined_buffer in &buffers {
4203 let collaborators_to_notify = rejoined_buffer
4204 .buffer
4205 .collaborators
4206 .iter()
4207 .filter_map(|c| Some(c.peer_id?.into()));
4208 channel_buffer_updated(
4209 session.connection_id,
4210 collaborators_to_notify,
4211 &proto::UpdateChannelBufferCollaborators {
4212 channel_id: rejoined_buffer.buffer.channel_id,
4213 collaborators: rejoined_buffer.buffer.collaborators.clone(),
4214 },
4215 &session.peer,
4216 );
4217 }
4218
4219 response.send(proto::RejoinChannelBuffersResponse {
4220 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
4221 })?;
4222
4223 Ok(())
4224}
4225
4226/// Stop editing the channel notes
4227async fn leave_channel_buffer(
4228 request: proto::LeaveChannelBuffer,
4229 response: Response<proto::LeaveChannelBuffer>,
4230 session: UserSession,
4231) -> Result<()> {
4232 let db = session.db().await;
4233 let channel_id = ChannelId::from_proto(request.channel_id);
4234
4235 let left_buffer = db
4236 .leave_channel_buffer(channel_id, session.connection_id)
4237 .await?;
4238
4239 response.send(Ack {})?;
4240
4241 channel_buffer_updated(
4242 session.connection_id,
4243 left_buffer.connections,
4244 &proto::UpdateChannelBufferCollaborators {
4245 channel_id: channel_id.to_proto(),
4246 collaborators: left_buffer.collaborators,
4247 },
4248 &session.peer,
4249 );
4250
4251 Ok(())
4252}
4253
4254fn channel_buffer_updated<T: EnvelopedMessage>(
4255 sender_id: ConnectionId,
4256 collaborators: impl IntoIterator<Item = ConnectionId>,
4257 message: &T,
4258 peer: &Peer,
4259) {
4260 broadcast(Some(sender_id), collaborators, |peer_id| {
4261 peer.send(peer_id, message.clone())
4262 });
4263}
4264
4265fn send_notifications(
4266 connection_pool: &ConnectionPool,
4267 peer: &Peer,
4268 notifications: db::NotificationBatch,
4269) {
4270 for (user_id, notification) in notifications {
4271 for connection_id in connection_pool.user_connection_ids(user_id) {
4272 if let Err(error) = peer.send(
4273 connection_id,
4274 proto::AddNotification {
4275 notification: Some(notification.clone()),
4276 },
4277 ) {
4278 tracing::error!(
4279 "failed to send notification to {:?} {}",
4280 connection_id,
4281 error
4282 );
4283 }
4284 }
4285 }
4286}
4287
4288/// Send a message to the channel
4289async fn send_channel_message(
4290 request: proto::SendChannelMessage,
4291 response: Response<proto::SendChannelMessage>,
4292 session: UserSession,
4293) -> Result<()> {
4294 // Validate the message body.
4295 let body = request.body.trim().to_string();
4296 if body.len() > MAX_MESSAGE_LEN {
4297 return Err(anyhow!("message is too long"))?;
4298 }
4299 if body.is_empty() {
4300 return Err(anyhow!("message can't be blank"))?;
4301 }
4302
4303 // TODO: adjust mentions if body is trimmed
4304
4305 let timestamp = OffsetDateTime::now_utc();
4306 let nonce = request
4307 .nonce
4308 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4309
4310 let channel_id = ChannelId::from_proto(request.channel_id);
4311 let CreatedChannelMessage {
4312 message_id,
4313 participant_connection_ids,
4314 notifications,
4315 } = session
4316 .db()
4317 .await
4318 .create_channel_message(
4319 channel_id,
4320 session.user_id(),
4321 &body,
4322 &request.mentions,
4323 timestamp,
4324 nonce.clone().into(),
4325 match request.reply_to_message_id {
4326 Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
4327 None => None,
4328 },
4329 )
4330 .await?;
4331
4332 let message = proto::ChannelMessage {
4333 sender_id: session.user_id().to_proto(),
4334 id: message_id.to_proto(),
4335 body,
4336 mentions: request.mentions,
4337 timestamp: timestamp.unix_timestamp() as u64,
4338 nonce: Some(nonce),
4339 reply_to_message_id: request.reply_to_message_id,
4340 edited_at: None,
4341 };
4342 broadcast(
4343 Some(session.connection_id),
4344 participant_connection_ids.clone(),
4345 |connection| {
4346 session.peer.send(
4347 connection,
4348 proto::ChannelMessageSent {
4349 channel_id: channel_id.to_proto(),
4350 message: Some(message.clone()),
4351 },
4352 )
4353 },
4354 );
4355 response.send(proto::SendChannelMessageResponse {
4356 message: Some(message),
4357 })?;
4358
4359 let pool = &*session.connection_pool().await;
4360 let non_participants =
4361 pool.channel_connection_ids(channel_id)
4362 .filter_map(|(connection_id, _)| {
4363 if participant_connection_ids.contains(&connection_id) {
4364 None
4365 } else {
4366 Some(connection_id)
4367 }
4368 });
4369 broadcast(None, non_participants, |peer_id| {
4370 session.peer.send(
4371 peer_id,
4372 proto::UpdateChannels {
4373 latest_channel_message_ids: vec![proto::ChannelMessageId {
4374 channel_id: channel_id.to_proto(),
4375 message_id: message_id.to_proto(),
4376 }],
4377 ..Default::default()
4378 },
4379 )
4380 });
4381 send_notifications(pool, &session.peer, notifications);
4382
4383 Ok(())
4384}
4385
4386/// Delete a channel message
4387async fn remove_channel_message(
4388 request: proto::RemoveChannelMessage,
4389 response: Response<proto::RemoveChannelMessage>,
4390 session: UserSession,
4391) -> Result<()> {
4392 let channel_id = ChannelId::from_proto(request.channel_id);
4393 let message_id = MessageId::from_proto(request.message_id);
4394 let (connection_ids, existing_notification_ids) = session
4395 .db()
4396 .await
4397 .remove_channel_message(channel_id, message_id, session.user_id())
4398 .await?;
4399
4400 broadcast(
4401 Some(session.connection_id),
4402 connection_ids,
4403 move |connection| {
4404 session.peer.send(connection, request.clone())?;
4405
4406 for notification_id in &existing_notification_ids {
4407 session.peer.send(
4408 connection,
4409 proto::DeleteNotification {
4410 notification_id: (*notification_id).to_proto(),
4411 },
4412 )?;
4413 }
4414
4415 Ok(())
4416 },
4417 );
4418 response.send(proto::Ack {})?;
4419 Ok(())
4420}
4421
4422async fn update_channel_message(
4423 request: proto::UpdateChannelMessage,
4424 response: Response<proto::UpdateChannelMessage>,
4425 session: UserSession,
4426) -> Result<()> {
4427 let channel_id = ChannelId::from_proto(request.channel_id);
4428 let message_id = MessageId::from_proto(request.message_id);
4429 let updated_at = OffsetDateTime::now_utc();
4430 let UpdatedChannelMessage {
4431 message_id,
4432 participant_connection_ids,
4433 notifications,
4434 reply_to_message_id,
4435 timestamp,
4436 deleted_mention_notification_ids,
4437 updated_mention_notifications,
4438 } = session
4439 .db()
4440 .await
4441 .update_channel_message(
4442 channel_id,
4443 message_id,
4444 session.user_id(),
4445 request.body.as_str(),
4446 &request.mentions,
4447 updated_at,
4448 )
4449 .await?;
4450
4451 let nonce = request
4452 .nonce
4453 .clone()
4454 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4455
4456 let message = proto::ChannelMessage {
4457 sender_id: session.user_id().to_proto(),
4458 id: message_id.to_proto(),
4459 body: request.body.clone(),
4460 mentions: request.mentions.clone(),
4461 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
4462 nonce: Some(nonce),
4463 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
4464 edited_at: Some(updated_at.unix_timestamp() as u64),
4465 };
4466
4467 response.send(proto::Ack {})?;
4468
4469 let pool = &*session.connection_pool().await;
4470 broadcast(
4471 Some(session.connection_id),
4472 participant_connection_ids,
4473 |connection| {
4474 session.peer.send(
4475 connection,
4476 proto::ChannelMessageUpdate {
4477 channel_id: channel_id.to_proto(),
4478 message: Some(message.clone()),
4479 },
4480 )?;
4481
4482 for notification_id in &deleted_mention_notification_ids {
4483 session.peer.send(
4484 connection,
4485 proto::DeleteNotification {
4486 notification_id: (*notification_id).to_proto(),
4487 },
4488 )?;
4489 }
4490
4491 for notification in &updated_mention_notifications {
4492 session.peer.send(
4493 connection,
4494 proto::UpdateNotification {
4495 notification: Some(notification.clone()),
4496 },
4497 )?;
4498 }
4499
4500 Ok(())
4501 },
4502 );
4503
4504 send_notifications(pool, &session.peer, notifications);
4505
4506 Ok(())
4507}
4508
4509/// Mark a channel message as read
4510async fn acknowledge_channel_message(
4511 request: proto::AckChannelMessage,
4512 session: UserSession,
4513) -> Result<()> {
4514 let channel_id = ChannelId::from_proto(request.channel_id);
4515 let message_id = MessageId::from_proto(request.message_id);
4516 let notifications = session
4517 .db()
4518 .await
4519 .observe_channel_message(channel_id, session.user_id(), message_id)
4520 .await?;
4521 send_notifications(
4522 &*session.connection_pool().await,
4523 &session.peer,
4524 notifications,
4525 );
4526 Ok(())
4527}
4528
4529/// Mark a buffer version as synced
4530async fn acknowledge_buffer_version(
4531 request: proto::AckBufferOperation,
4532 session: UserSession,
4533) -> Result<()> {
4534 let buffer_id = BufferId::from_proto(request.buffer_id);
4535 session
4536 .db()
4537 .await
4538 .observe_buffer_version(
4539 buffer_id,
4540 session.user_id(),
4541 request.epoch as i32,
4542 &request.version,
4543 )
4544 .await?;
4545 Ok(())
4546}
4547
4548async fn count_language_model_tokens(
4549 request: proto::CountLanguageModelTokens,
4550 response: Response<proto::CountLanguageModelTokens>,
4551 session: Session,
4552 config: &Config,
4553) -> Result<()> {
4554 let Some(session) = session.for_user() else {
4555 return Err(anyhow!("user not found"))?;
4556 };
4557 authorize_access_to_legacy_llm_endpoints(&session).await?;
4558
4559 let rate_limit: Box<dyn RateLimit> = match session.current_plan(session.db().await).await? {
4560 proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
4561 proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
4562 };
4563
4564 session
4565 .app_state
4566 .rate_limiter
4567 .check(&*rate_limit, session.user_id())
4568 .await?;
4569
4570 let result = match proto::LanguageModelProvider::from_i32(request.provider) {
4571 Some(proto::LanguageModelProvider::Google) => {
4572 let api_key = config
4573 .google_ai_api_key
4574 .as_ref()
4575 .context("no Google AI API key configured on the server")?;
4576 google_ai::count_tokens(
4577 session.http_client.as_ref(),
4578 google_ai::API_URL,
4579 api_key,
4580 serde_json::from_str(&request.request)?,
4581 )
4582 .await?
4583 }
4584 _ => return Err(anyhow!("unsupported provider"))?,
4585 };
4586
4587 response.send(proto::CountLanguageModelTokensResponse {
4588 token_count: result.total_tokens as u32,
4589 })?;
4590
4591 Ok(())
4592}
4593
4594struct ZedProCountLanguageModelTokensRateLimit;
4595
4596impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
4597 fn capacity(&self) -> usize {
4598 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
4599 .ok()
4600 .and_then(|v| v.parse().ok())
4601 .unwrap_or(600) // Picked arbitrarily
4602 }
4603
4604 fn refill_duration(&self) -> chrono::Duration {
4605 chrono::Duration::hours(1)
4606 }
4607
4608 fn db_name(&self) -> &'static str {
4609 "zed-pro:count-language-model-tokens"
4610 }
4611}
4612
4613struct FreeCountLanguageModelTokensRateLimit;
4614
4615impl RateLimit for FreeCountLanguageModelTokensRateLimit {
4616 fn capacity(&self) -> usize {
4617 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
4618 .ok()
4619 .and_then(|v| v.parse().ok())
4620 .unwrap_or(600 / 10) // Picked arbitrarily
4621 }
4622
4623 fn refill_duration(&self) -> chrono::Duration {
4624 chrono::Duration::hours(1)
4625 }
4626
4627 fn db_name(&self) -> &'static str {
4628 "free:count-language-model-tokens"
4629 }
4630}
4631
4632struct ZedProComputeEmbeddingsRateLimit;
4633
4634impl RateLimit for ZedProComputeEmbeddingsRateLimit {
4635 fn capacity(&self) -> usize {
4636 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4637 .ok()
4638 .and_then(|v| v.parse().ok())
4639 .unwrap_or(5000) // Picked arbitrarily
4640 }
4641
4642 fn refill_duration(&self) -> chrono::Duration {
4643 chrono::Duration::hours(1)
4644 }
4645
4646 fn db_name(&self) -> &'static str {
4647 "zed-pro:compute-embeddings"
4648 }
4649}
4650
4651struct FreeComputeEmbeddingsRateLimit;
4652
4653impl RateLimit for FreeComputeEmbeddingsRateLimit {
4654 fn capacity(&self) -> usize {
4655 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
4656 .ok()
4657 .and_then(|v| v.parse().ok())
4658 .unwrap_or(5000 / 10) // Picked arbitrarily
4659 }
4660
4661 fn refill_duration(&self) -> chrono::Duration {
4662 chrono::Duration::hours(1)
4663 }
4664
4665 fn db_name(&self) -> &'static str {
4666 "free:compute-embeddings"
4667 }
4668}
4669
4670async fn compute_embeddings(
4671 request: proto::ComputeEmbeddings,
4672 response: Response<proto::ComputeEmbeddings>,
4673 session: UserSession,
4674 api_key: Option<Arc<str>>,
4675) -> Result<()> {
4676 let api_key = api_key.context("no OpenAI API key configured on the server")?;
4677 authorize_access_to_legacy_llm_endpoints(&session).await?;
4678
4679 let rate_limit: Box<dyn RateLimit> = match session.current_plan(session.db().await).await? {
4680 proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
4681 proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
4682 };
4683
4684 session
4685 .app_state
4686 .rate_limiter
4687 .check(&*rate_limit, session.user_id())
4688 .await?;
4689
4690 let embeddings = match request.model.as_str() {
4691 "openai/text-embedding-3-small" => {
4692 open_ai::embed(
4693 session.http_client.as_ref(),
4694 OPEN_AI_API_URL,
4695 &api_key,
4696 OpenAiEmbeddingModel::TextEmbedding3Small,
4697 request.texts.iter().map(|text| text.as_str()),
4698 )
4699 .await?
4700 }
4701 provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
4702 };
4703
4704 let embeddings = request
4705 .texts
4706 .iter()
4707 .map(|text| {
4708 let mut hasher = sha2::Sha256::new();
4709 hasher.update(text.as_bytes());
4710 let result = hasher.finalize();
4711 result.to_vec()
4712 })
4713 .zip(
4714 embeddings
4715 .data
4716 .into_iter()
4717 .map(|embedding| embedding.embedding),
4718 )
4719 .collect::<HashMap<_, _>>();
4720
4721 let db = session.db().await;
4722 db.save_embeddings(&request.model, &embeddings)
4723 .await
4724 .context("failed to save embeddings")
4725 .trace_err();
4726
4727 response.send(proto::ComputeEmbeddingsResponse {
4728 embeddings: embeddings
4729 .into_iter()
4730 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4731 .collect(),
4732 })?;
4733 Ok(())
4734}
4735
4736async fn get_cached_embeddings(
4737 request: proto::GetCachedEmbeddings,
4738 response: Response<proto::GetCachedEmbeddings>,
4739 session: UserSession,
4740) -> Result<()> {
4741 authorize_access_to_legacy_llm_endpoints(&session).await?;
4742
4743 let db = session.db().await;
4744 let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
4745
4746 response.send(proto::GetCachedEmbeddingsResponse {
4747 embeddings: embeddings
4748 .into_iter()
4749 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4750 .collect(),
4751 })?;
4752 Ok(())
4753}
4754
4755/// This is leftover from before the LLM service.
4756///
4757/// The endpoints protected by this check will be moved there eventually.
4758async fn authorize_access_to_legacy_llm_endpoints(session: &UserSession) -> Result<(), Error> {
4759 if session.is_staff() {
4760 Ok(())
4761 } else {
4762 Err(anyhow!("permission denied"))?
4763 }
4764}
4765
4766/// Get a Supermaven API key for the user
4767async fn get_supermaven_api_key(
4768 _request: proto::GetSupermavenApiKey,
4769 response: Response<proto::GetSupermavenApiKey>,
4770 session: UserSession,
4771) -> Result<()> {
4772 let user_id: String = session.user_id().to_string();
4773 if !session.is_staff() {
4774 return Err(anyhow!("supermaven not enabled for this account"))?;
4775 }
4776
4777 let email = session
4778 .email()
4779 .ok_or_else(|| anyhow!("user must have an email"))?;
4780
4781 let supermaven_admin_api = session
4782 .supermaven_client
4783 .as_ref()
4784 .ok_or_else(|| anyhow!("supermaven not configured"))?;
4785
4786 let result = supermaven_admin_api
4787 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4788 .await?;
4789
4790 response.send(proto::GetSupermavenApiKeyResponse {
4791 api_key: result.api_key,
4792 })?;
4793
4794 Ok(())
4795}
4796
4797/// Start receiving chat updates for a channel
4798async fn join_channel_chat(
4799 request: proto::JoinChannelChat,
4800 response: Response<proto::JoinChannelChat>,
4801 session: UserSession,
4802) -> Result<()> {
4803 let channel_id = ChannelId::from_proto(request.channel_id);
4804
4805 let db = session.db().await;
4806 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4807 .await?;
4808 let messages = db
4809 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4810 .await?;
4811 response.send(proto::JoinChannelChatResponse {
4812 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4813 messages,
4814 })?;
4815 Ok(())
4816}
4817
4818/// Stop receiving chat updates for a channel
4819async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
4820 let channel_id = ChannelId::from_proto(request.channel_id);
4821 session
4822 .db()
4823 .await
4824 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4825 .await?;
4826 Ok(())
4827}
4828
4829/// Retrieve the chat history for a channel
4830async fn get_channel_messages(
4831 request: proto::GetChannelMessages,
4832 response: Response<proto::GetChannelMessages>,
4833 session: UserSession,
4834) -> Result<()> {
4835 let channel_id = ChannelId::from_proto(request.channel_id);
4836 let messages = session
4837 .db()
4838 .await
4839 .get_channel_messages(
4840 channel_id,
4841 session.user_id(),
4842 MESSAGE_COUNT_PER_PAGE,
4843 Some(MessageId::from_proto(request.before_message_id)),
4844 )
4845 .await?;
4846 response.send(proto::GetChannelMessagesResponse {
4847 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4848 messages,
4849 })?;
4850 Ok(())
4851}
4852
4853/// Retrieve specific chat messages
4854async fn get_channel_messages_by_id(
4855 request: proto::GetChannelMessagesById,
4856 response: Response<proto::GetChannelMessagesById>,
4857 session: UserSession,
4858) -> Result<()> {
4859 let message_ids = request
4860 .message_ids
4861 .iter()
4862 .map(|id| MessageId::from_proto(*id))
4863 .collect::<Vec<_>>();
4864 let messages = session
4865 .db()
4866 .await
4867 .get_channel_messages_by_id(session.user_id(), &message_ids)
4868 .await?;
4869 response.send(proto::GetChannelMessagesResponse {
4870 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4871 messages,
4872 })?;
4873 Ok(())
4874}
4875
4876/// Retrieve the current users notifications
4877async fn get_notifications(
4878 request: proto::GetNotifications,
4879 response: Response<proto::GetNotifications>,
4880 session: UserSession,
4881) -> Result<()> {
4882 let notifications = session
4883 .db()
4884 .await
4885 .get_notifications(
4886 session.user_id(),
4887 NOTIFICATION_COUNT_PER_PAGE,
4888 request
4889 .before_id
4890 .map(|id| db::NotificationId::from_proto(id)),
4891 )
4892 .await?;
4893 response.send(proto::GetNotificationsResponse {
4894 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4895 notifications,
4896 })?;
4897 Ok(())
4898}
4899
4900/// Mark notifications as read
4901async fn mark_notification_as_read(
4902 request: proto::MarkNotificationRead,
4903 response: Response<proto::MarkNotificationRead>,
4904 session: UserSession,
4905) -> Result<()> {
4906 let database = &session.db().await;
4907 let notifications = database
4908 .mark_notification_as_read_by_id(
4909 session.user_id(),
4910 NotificationId::from_proto(request.notification_id),
4911 )
4912 .await?;
4913 send_notifications(
4914 &*session.connection_pool().await,
4915 &session.peer,
4916 notifications,
4917 );
4918 response.send(proto::Ack {})?;
4919 Ok(())
4920}
4921
4922/// Get the current users information
4923async fn get_private_user_info(
4924 _request: proto::GetPrivateUserInfo,
4925 response: Response<proto::GetPrivateUserInfo>,
4926 session: UserSession,
4927) -> Result<()> {
4928 let db = session.db().await;
4929
4930 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4931 let user = db
4932 .get_user_by_id(session.user_id())
4933 .await?
4934 .ok_or_else(|| anyhow!("user not found"))?;
4935 let flags = db.get_user_flags(session.user_id()).await?;
4936
4937 response.send(proto::GetPrivateUserInfoResponse {
4938 metrics_id,
4939 staff: user.admin,
4940 flags,
4941 accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4942 })?;
4943 Ok(())
4944}
4945
4946/// Accept the terms of service (tos) on behalf of the current user
4947async fn accept_terms_of_service(
4948 _request: proto::AcceptTermsOfService,
4949 response: Response<proto::AcceptTermsOfService>,
4950 session: UserSession,
4951) -> Result<()> {
4952 let db = session.db().await;
4953
4954 let accepted_tos_at = Utc::now();
4955 db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4956 .await?;
4957
4958 response.send(proto::AcceptTermsOfServiceResponse {
4959 accepted_tos_at: accepted_tos_at.timestamp() as u64,
4960 })?;
4961 Ok(())
4962}
4963
/// The minimum account age an account must have in order to use the LLM service.
///
/// `get_llm_api_token` compares this against the age of the older of two
/// timestamps: the user's Zed account creation time and, when known, their
/// GitHub account creation time.
const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4966
4967async fn get_llm_api_token(
4968 _request: proto::GetLlmToken,
4969 response: Response<proto::GetLlmToken>,
4970 session: UserSession,
4971) -> Result<()> {
4972 let db = session.db().await;
4973
4974 let flags = db.get_user_flags(session.user_id()).await?;
4975 let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
4976 let has_llm_closed_beta_feature_flag = flags.iter().any(|flag| flag == "llm-closed-beta");
4977
4978 if !session.is_staff() && !has_language_models_feature_flag {
4979 Err(anyhow!("permission denied"))?
4980 }
4981
4982 let user_id = session.user_id();
4983 let user = db
4984 .get_user_by_id(user_id)
4985 .await?
4986 .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4987
4988 if user.accepted_tos_at.is_none() {
4989 Err(anyhow!("terms of service not accepted"))?
4990 }
4991
4992 let mut account_created_at = user.created_at;
4993 if let Some(github_created_at) = user.github_user_created_at {
4994 account_created_at = account_created_at.min(github_created_at);
4995 }
4996 if Utc::now().naive_utc() - account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE {
4997 Err(anyhow!("account too young"))?
4998 }
4999 let token = LlmTokenClaims::create(
5000 user.id,
5001 user.github_login.clone(),
5002 session.is_staff(),
5003 has_llm_closed_beta_feature_flag,
5004 session.current_plan(db).await?,
5005 &session.app_state.config,
5006 )?;
5007 response.send(proto::GetLlmTokenResponse { token })?;
5008 Ok(())
5009}
5010
5011fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
5012 let message = match message {
5013 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
5014 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
5015 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
5016 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
5017 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
5018 code: frame.code.into(),
5019 reason: frame.reason,
5020 })),
5021 // We should never receive a frame while reading the message, according
5022 // to the `tungstenite` maintainers:
5023 //
5024 // > It cannot occur when you read messages from the WebSocket, but it
5025 // > can be used when you want to send the raw frames (e.g. you want to
5026 // > send the frames to the WebSocket without composing the full message first).
5027 // >
5028 // > — https://github.com/snapview/tungstenite-rs/issues/268
5029 TungsteniteMessage::Frame(_) => {
5030 bail!("received an unexpected frame while reading the message")
5031 }
5032 };
5033
5034 Ok(message)
5035}
5036
5037fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
5038 match message {
5039 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
5040 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
5041 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
5042 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
5043 AxumMessage::Close(frame) => {
5044 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
5045 code: frame.code.into(),
5046 reason: frame.reason,
5047 }))
5048 }
5049 }
5050}
5051
/// Apply a channel-membership change for `user_id` and notify that user.
///
/// The connection pool's channel subscriptions are updated first: the user is
/// subscribed to every channel in `result.new_channels` (with its role) and
/// unsubscribed from every channel in `result.removed_channels`.
///
/// Each of the user's connections then receives two messages:
/// - an `UpdateUserChannels` carrying the new memberships;
/// - an `UpdateChannels` carrying the new channel data, the removed channel
///   ids, and the removal of the invitation for `result.channel_id`.
///
/// Send failures are logged via `trace_err` and otherwise ignored.
fn notify_membership_updated(
    connection_pool: &mut ConnectionPool,
    result: MembershipUpdated,
    user_id: UserId,
    peer: &Peer,
) {
    // Mirror the membership change in the pool's channel subscriptions.
    for membership in &result.new_channels.channel_memberships {
        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
    }
    for channel_id in &result.removed_channels {
        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
    }

    // Only the membership list is populated; other fields keep their defaults.
    let user_channels_update = proto::UpdateUserChannels {
        channel_memberships: result
            .new_channels
            .channel_memberships
            .iter()
            .map(|cm| proto::ChannelMembership {
                channel_id: cm.channel_id.to_proto(),
                role: cm.role.into(),
            })
            .collect(),
        ..Default::default()
    };

    let mut update = build_channels_update(result.new_channels);
    update.delete_channels = result
        .removed_channels
        .into_iter()
        .map(|id| id.to_proto())
        .collect();
    // The invitation for the channel this change concerns is no longer pending.
    update.remove_channel_invitations = vec![result.channel_id.to_proto()];

    for connection_id in connection_pool.user_connection_ids(user_id) {
        peer.send(connection_id, user_channels_update.clone())
            .trace_err();
        peer.send(connection_id, update.clone()).trace_err();
    }
}
5092
5093fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
5094 proto::UpdateUserChannels {
5095 channel_memberships: channels
5096 .channel_memberships
5097 .iter()
5098 .map(|m| proto::ChannelMembership {
5099 channel_id: m.channel_id.to_proto(),
5100 role: m.role.into(),
5101 })
5102 .collect(),
5103 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
5104 observed_channel_message_id: channels.observed_channel_messages.clone(),
5105 }
5106}
5107
5108fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
5109 let mut update = proto::UpdateChannels::default();
5110
5111 for channel in channels.channels {
5112 update.channels.push(channel.to_proto());
5113 }
5114
5115 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
5116 update.latest_channel_message_ids = channels.latest_channel_messages;
5117
5118 for (channel_id, participants) in channels.channel_participants {
5119 update
5120 .channel_participants
5121 .push(proto::ChannelParticipants {
5122 channel_id: channel_id.to_proto(),
5123 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
5124 });
5125 }
5126
5127 for channel in channels.invited_channels {
5128 update.channel_invitations.push(channel.to_proto());
5129 }
5130
5131 update.hosted_projects = channels.hosted_projects;
5132 update
5133}
5134
5135fn build_initial_contacts_update(
5136 contacts: Vec<db::Contact>,
5137 pool: &ConnectionPool,
5138) -> proto::UpdateContacts {
5139 let mut update = proto::UpdateContacts::default();
5140
5141 for contact in contacts {
5142 match contact {
5143 db::Contact::Accepted { user_id, busy } => {
5144 update.contacts.push(contact_for_user(user_id, busy, &pool));
5145 }
5146 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
5147 db::Contact::Incoming { user_id } => {
5148 update
5149 .incoming_requests
5150 .push(proto::IncomingContactRequest {
5151 requester_id: user_id.to_proto(),
5152 })
5153 }
5154 }
5155 }
5156
5157 update
5158}
5159
5160fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
5161 proto::Contact {
5162 user_id: user_id.to_proto(),
5163 online: pool.is_user_online(user_id),
5164 busy,
5165 }
5166}
5167
5168fn room_updated(room: &proto::Room, peer: &Peer) {
5169 broadcast(
5170 None,
5171 room.participants
5172 .iter()
5173 .filter_map(|participant| Some(participant.peer_id?.into())),
5174 |peer_id| {
5175 peer.send(
5176 peer_id,
5177 proto::RoomUpdated {
5178 room: Some(room.clone()),
5179 },
5180 )
5181 },
5182 );
5183}
5184
5185fn channel_updated(
5186 channel: &db::channel::Model,
5187 room: &proto::Room,
5188 peer: &Peer,
5189 pool: &ConnectionPool,
5190) {
5191 let participants = room
5192 .participants
5193 .iter()
5194 .map(|p| p.user_id)
5195 .collect::<Vec<_>>();
5196
5197 broadcast(
5198 None,
5199 pool.channel_connection_ids(channel.root_id())
5200 .filter_map(|(channel_id, role)| {
5201 role.can_see_channel(channel.visibility).then(|| channel_id)
5202 }),
5203 |peer_id| {
5204 peer.send(
5205 peer_id,
5206 proto::UpdateChannels {
5207 channel_participants: vec![proto::ChannelParticipants {
5208 channel_id: channel.id.to_proto(),
5209 participant_user_ids: participants.clone(),
5210 }],
5211 ..Default::default()
5212 },
5213 )
5214 },
5215 );
5216}
5217
5218async fn send_dev_server_projects_update(
5219 user_id: UserId,
5220 mut status: proto::DevServerProjectsUpdate,
5221 session: &Session,
5222) {
5223 let pool = session.connection_pool().await;
5224 for dev_server in &mut status.dev_servers {
5225 dev_server.status =
5226 pool.dev_server_status(DevServerId(dev_server.dev_server_id as i32)) as i32;
5227 }
5228 let connections = pool.user_connection_ids(user_id);
5229 for connection_id in connections {
5230 session.peer.send(connection_id, status.clone()).trace_err();
5231 }
5232}
5233
5234async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
5235 let db = session.db().await;
5236
5237 let contacts = db.get_contacts(user_id).await?;
5238 let busy = db.is_user_busy(user_id).await?;
5239
5240 let pool = session.connection_pool().await;
5241 let updated_contact = contact_for_user(user_id, busy, &pool);
5242 for contact in contacts {
5243 if let db::Contact::Accepted {
5244 user_id: contact_user_id,
5245 ..
5246 } = contact
5247 {
5248 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
5249 session
5250 .peer
5251 .send(
5252 contact_conn_id,
5253 proto::UpdateContacts {
5254 contacts: vec![updated_contact.clone()],
5255 remove_contacts: Default::default(),
5256 incoming_requests: Default::default(),
5257 remove_incoming_requests: Default::default(),
5258 outgoing_requests: Default::default(),
5259 remove_outgoing_requests: Default::default(),
5260 },
5261 )
5262 .trace_err();
5263 }
5264 }
5265 }
5266 Ok(())
5267}
5268
/// Clean up after a dev server's connection is lost.
///
/// First, every project the database reports as stale for this connection is
/// unshared. An error unsharing any project aborts the remaining cleanup. Then
/// the owning user's dev-server projects update is recomputed and sent to that
/// user's connections.
async fn lost_dev_server_connection(session: &DevServerSession) -> Result<()> {
    log::info!("lost dev server connection, unsharing projects");
    let project_ids = session
        .db()
        .await
        .get_stale_dev_server_projects(session.connection_id)
        .await?;

    for project_id in project_ids {
        // Note: unshare re-checks that the connection ids match, so we get away
        // without wrapping this loop in a transaction.
        unshare_project_internal(project_id, session.connection_id, None, &session).await?;
    }

    let user_id = session.dev_server().user_id;
    let update = session
        .db()
        .await
        .dev_server_projects_update(user_id)
        .await?;

    send_dev_server_projects_update(user_id, update, session).await;

    Ok(())
}
5293
/// Remove `connection_id` from whatever room it is in and propagate the change.
///
/// Returns `Ok(())` immediately if the database reports that the connection
/// was not in a room. Otherwise it:
/// 1. notifies collaborators on the projects that were left;
/// 2. broadcasts the updated room;
/// 3. updates the channel's participant list if the room has a channel;
/// 4. sends `CallCanceled` to each user whose call was canceled;
/// 5. refreshes contact status for the leaving user and those users;
/// 6. removes the user from the LiveKit room, deleting the room when the
///    database reported it as deleted.
///
/// LiveKit failures are logged, not returned.
async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Assigned inside the `if let` below; the `else` branch returns early.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Taken before `room` so the LiveKit room name survives the move below.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the connection-pool guard is released before the awaits below.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
5367
5368async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
5369 let left_channel_buffers = session
5370 .db()
5371 .await
5372 .leave_channel_buffers(session.connection_id)
5373 .await?;
5374
5375 for left_buffer in left_channel_buffers {
5376 channel_buffer_updated(
5377 session.connection_id,
5378 left_buffer.connections,
5379 &proto::UpdateChannelBufferCollaborators {
5380 channel_id: left_buffer.channel_id.to_proto(),
5381 collaborators: left_buffer.collaborators,
5382 },
5383 &session.peer,
5384 );
5385 }
5386
5387 Ok(())
5388}
5389
5390fn project_left(project: &db::LeftProject, session: &UserSession) {
5391 for connection_id in &project.connection_ids {
5392 if project.should_unshare {
5393 session
5394 .peer
5395 .send(
5396 *connection_id,
5397 proto::UnshareProject {
5398 project_id: project.id.to_proto(),
5399 },
5400 )
5401 .trace_err();
5402 } else {
5403 session
5404 .peer
5405 .send(
5406 *connection_id,
5407 proto::RemoveProjectCollaborator {
5408 project_id: project.id.to_proto(),
5409 peer_id: Some(session.connection_id.into()),
5410 },
5411 )
5412 .trace_err();
5413 }
5414 }
5415}
5416
/// Extension for results whose error should be logged rather than propagated.
pub trait ResultExt {
    /// The success type of the underlying result.
    type Ok;

    /// Convert to an `Option`, discarding the error after logging it.
    /// Returns `None` when there was an error.
    fn trace_err(self) -> Option<Self::Ok>;
}
5422
5423impl<T, E> ResultExt for Result<T, E>
5424where
5425 E: std::fmt::Debug,
5426{
5427 type Ok = T;
5428
5429 #[track_caller]
5430 fn trace_err(self) -> Option<T> {
5431 match self {
5432 Ok(value) => Some(value),
5433 Err(error) => {
5434 tracing::error!("{:?}", error);
5435 None
5436 }
5437 }
5438 }
5439}