1mod connection_pool;
2
3use crate::api::CloudflareIpCountryHeader;
4use crate::llm::LlmTokenClaims;
5use crate::{
6 auth,
7 db::{
8 self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
9 CreatedChannelMessage, Database, DevServerId, DevServerProjectId, InviteMemberResult,
10 MembershipUpdated, MessageId, NotificationId, PrincipalId, Project, ProjectId,
11 RejoinedProject, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
12 ServerId, UpdatedChannelMessage, User, UserId,
13 },
14 executor::Executor,
15 AppState, Config, Error, RateLimit, Result,
16};
17use anyhow::{anyhow, bail, Context as _};
18use async_tungstenite::tungstenite::{
19 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
20};
21use axum::{
22 body::Body,
23 extract::{
24 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
25 ConnectInfo, WebSocketUpgrade,
26 },
27 headers::{Header, HeaderName},
28 http::StatusCode,
29 middleware,
30 response::IntoResponse,
31 routing::get,
32 Extension, Router, TypedHeader,
33};
34use chrono::Utc;
35use collections::{HashMap, HashSet};
36pub use connection_pool::{ConnectionPool, ZedVersion};
37use core::fmt::{self, Debug, Formatter};
38use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
39use sha2::Digest;
40use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
41
42use futures::{
43 channel::oneshot,
44 future::{self, BoxFuture},
45 stream::FuturesUnordered,
46 FutureExt, SinkExt, StreamExt, TryStreamExt,
47};
48use http_client::IsahcHttpClient;
49use prometheus::{register_int_gauge, IntGauge};
50use rpc::{
51 proto::{
52 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
53 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
54 },
55 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
56};
57use semantic_version::SemanticVersion;
58use serde::{Serialize, Serializer};
59use std::{
60 any::TypeId,
61 future::Future,
62 marker::PhantomData,
63 mem,
64 net::SocketAddr,
65 ops::{Deref, DerefMut},
66 rc::Rc,
67 sync::{
68 atomic::{AtomicBool, Ordering::SeqCst},
69 Arc, OnceLock,
70 },
71 time::{Duration, Instant},
72};
73use time::OffsetDateTime;
74use tokio::sync::{watch, MutexGuard, Semaphore};
75use tower::ServiceBuilder;
76use tracing::{
77 field::{self},
78 info_span, instrument, Instrument,
79};
80
/// How long a disconnected client is given to reconnect.
/// NOTE(review): the consumer of this constant is not visible in this chunk; confirm.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

// Kubernetes gives terminated pods 10s to shut down gracefully. Once they are gone, we can
// clean up their old resources; `Server::start` waits this long before doing so, which
// leaves a margin past the 10s grace period.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

/// Page size used when fetching channel messages (presumably; usage is outside this chunk).
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Upper bound on a channel message's length (presumably in bytes; confirm at the use site).
const MAX_MESSAGE_LEN: usize = 1024;
/// Page size used when fetching notifications (presumably; usage is outside this chunk).
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;

/// Type-erased message handler stored in `Server::handlers`, keyed by the message's `TypeId`.
/// It receives the boxed envelope plus the connection's `Session` and yields a boxed future.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
92
93struct Response<R> {
94 peer: Arc<Peer>,
95 receipt: Receipt<R>,
96 responded: Arc<AtomicBool>,
97}
98
99impl<R: RequestMessage> Response<R> {
100 fn send(self, payload: R::Response) -> Result<()> {
101 self.responded.store(true, SeqCst);
102 self.peer.respond(self.receipt, payload)?;
103 Ok(())
104 }
105}
106
107#[derive(Clone, Debug)]
108pub enum Principal {
109 User(User),
110 Impersonated { user: User, admin: User },
111 DevServer(dev_server::Model),
112}
113
114impl Principal {
115 fn update_span(&self, span: &tracing::Span) {
116 match &self {
117 Principal::User(user) => {
118 span.record("user_id", &user.id.0);
119 span.record("login", &user.github_login);
120 }
121 Principal::Impersonated { user, admin } => {
122 span.record("user_id", &user.id.0);
123 span.record("login", &user.github_login);
124 span.record("impersonator", &admin.github_login);
125 }
126 Principal::DevServer(dev_server) => {
127 span.record("dev_server_id", &dev_server.id.0);
128 }
129 }
130 }
131}
132
/// State handed to every message handler for a single connection.
///
/// A clone is passed to each handler invocation (see `handle_connection`).
#[derive(Clone)]
struct Session {
    /// Who is on the other end of the connection.
    principal: Principal,
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; acquire it through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Pool shared with the `Server`; acquire it through `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Only present when a Supermaven admin API key is configured.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    http_client: Arc<IsahcHttpClient>,
    /// The GeoIP country code for the user.
    // NOTE(review): the allow suggests no handler reads this yet; confirm before removing it.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    _executor: Executor,
}
148
149impl Session {
150 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
151 #[cfg(test)]
152 tokio::task::yield_now().await;
153 let guard = self.db.lock().await;
154 #[cfg(test)]
155 tokio::task::yield_now().await;
156 guard
157 }
158
159 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
160 #[cfg(test)]
161 tokio::task::yield_now().await;
162 let guard = self.connection_pool.lock();
163 ConnectionPoolGuard {
164 guard,
165 _not_send: PhantomData,
166 }
167 }
168
169 fn for_user(self) -> Option<UserSession> {
170 UserSession::new(self)
171 }
172
173 fn for_dev_server(self) -> Option<DevServerSession> {
174 DevServerSession::new(self)
175 }
176
177 fn user_id(&self) -> Option<UserId> {
178 match &self.principal {
179 Principal::User(user) => Some(user.id),
180 Principal::Impersonated { user, .. } => Some(user.id),
181 Principal::DevServer(_) => None,
182 }
183 }
184
185 fn is_staff(&self) -> bool {
186 match &self.principal {
187 Principal::User(user) => user.admin,
188 Principal::Impersonated { .. } => true,
189 Principal::DevServer(_) => false,
190 }
191 }
192
193 pub async fn current_plan(&self, db: MutexGuard<'_, DbHandle>) -> anyhow::Result<proto::Plan> {
194 if self.is_staff() {
195 return Ok(proto::Plan::ZedPro);
196 }
197
198 let Some(user_id) = self.user_id() else {
199 return Ok(proto::Plan::Free);
200 };
201
202 if db.has_active_billing_subscription(user_id).await? {
203 Ok(proto::Plan::ZedPro)
204 } else {
205 Ok(proto::Plan::Free)
206 }
207 }
208
209 fn dev_server_id(&self) -> Option<DevServerId> {
210 match &self.principal {
211 Principal::User(_) | Principal::Impersonated { .. } => None,
212 Principal::DevServer(dev_server) => Some(dev_server.id),
213 }
214 }
215
216 fn principal_id(&self) -> PrincipalId {
217 match &self.principal {
218 Principal::User(user) => PrincipalId::UserId(user.id),
219 Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
220 Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
221 }
222 }
223}
224
225impl Debug for Session {
226 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
227 let mut result = f.debug_struct("Session");
228 match &self.principal {
229 Principal::User(user) => {
230 result.field("user", &user.github_login);
231 }
232 Principal::Impersonated { user, admin } => {
233 result.field("user", &user.github_login);
234 result.field("impersonator", &admin.github_login);
235 }
236 Principal::DevServer(dev_server) => {
237 result.field("dev_server", &dev_server.id);
238 }
239 }
240 result.field("connection_id", &self.connection_id).finish()
241 }
242}
243
244struct UserSession(Session);
245
246impl UserSession {
247 pub fn new(s: Session) -> Option<Self> {
248 s.user_id().map(|_| UserSession(s))
249 }
250 pub fn user_id(&self) -> UserId {
251 self.0.user_id().unwrap()
252 }
253
254 pub fn email(&self) -> Option<String> {
255 match &self.0.principal {
256 Principal::User(user) => user.email_address.clone(),
257 Principal::Impersonated { user, .. } => user.email_address.clone(),
258 Principal::DevServer(..) => None,
259 }
260 }
261}
262
263impl Deref for UserSession {
264 type Target = Session;
265
266 fn deref(&self) -> &Self::Target {
267 &self.0
268 }
269}
270impl DerefMut for UserSession {
271 fn deref_mut(&mut self) -> &mut Self::Target {
272 &mut self.0
273 }
274}
275
276struct DevServerSession(Session);
277
278impl DevServerSession {
279 pub fn new(s: Session) -> Option<Self> {
280 s.dev_server_id().map(|_| DevServerSession(s))
281 }
282 pub fn dev_server_id(&self) -> DevServerId {
283 self.0.dev_server_id().unwrap()
284 }
285
286 fn dev_server(&self) -> &dev_server::Model {
287 match &self.0.principal {
288 Principal::DevServer(dev_server) => dev_server,
289 _ => unreachable!(),
290 }
291 }
292}
293
294impl Deref for DevServerSession {
295 type Target = Session;
296
297 fn deref(&self) -> &Self::Target {
298 &self.0
299 }
300}
301impl DerefMut for DevServerSession {
302 fn deref_mut(&mut self) -> &mut Self::Target {
303 &mut self.0
304 }
305}
306
307fn user_handler<M: RequestMessage, Fut>(
308 handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
309) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
310where
311 Fut: Send + Future<Output = Result<()>>,
312{
313 let handler = Arc::new(handler);
314 move |message, response, session| {
315 let handler = handler.clone();
316 Box::pin(async move {
317 if let Some(user_session) = session.for_user() {
318 Ok(handler(message, response, user_session).await?)
319 } else {
320 Err(Error::Internal(anyhow!(
321 "must be a user to call {}",
322 M::NAME
323 )))
324 }
325 })
326 }
327}
328
329fn dev_server_handler<M: RequestMessage, Fut>(
330 handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
331) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
332where
333 Fut: Send + Future<Output = Result<()>>,
334{
335 let handler = Arc::new(handler);
336 move |message, response, session| {
337 let handler = handler.clone();
338 Box::pin(async move {
339 if let Some(dev_server_session) = session.for_dev_server() {
340 Ok(handler(message, response, dev_server_session).await?)
341 } else {
342 Err(Error::Internal(anyhow!(
343 "must be a dev server to call {}",
344 M::NAME
345 )))
346 }
347 })
348 }
349}
350
351fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
352 handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
353) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
354where
355 InnertRetFut: Send + Future<Output = Result<()>>,
356{
357 let handler = Arc::new(handler);
358 move |message, session| {
359 let handler = handler.clone();
360 Box::pin(async move {
361 if let Some(user_session) = session.for_user() {
362 Ok(handler(message, user_session).await?)
363 } else {
364 Err(Error::Internal(anyhow!(
365 "must be a user to call {}",
366 M::NAME
367 )))
368 }
369 })
370 }
371}
372
373struct DbHandle(Arc<Database>);
374
375impl Deref for DbHandle {
376 type Target = Database;
377
378 fn deref(&self) -> &Self::Target {
379 self.0.as_ref()
380 }
381}
382
/// The collaboration RPC server: owns the peer, the shared pool of live connections, and
/// the table of registered message handlers.
pub struct Server {
    /// Behind a mutex so the test-only `reset` can swap in a new id.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Handlers keyed by the `TypeId` of the message type they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Teardown flag. Connections subscribe to it, refuse to start while it is `true`, and
    /// exit when it changes.
    teardown: watch::Sender<bool>,
}
391
/// Lock guard over the shared `ConnectionPool`.
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send`. That prevents it from being
/// held across an `.await` inside handler futures, which must be `Send`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
396
/// Serializable view of the server's peer and connection pool.
///
/// The pool is serialized through its guard via `serialize_deref`, so the pool stays locked
/// for as long as the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
403
404pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
405where
406 S: Serializer,
407 T: Deref<Target = U>,
408 U: Serialize,
409{
410 Serialize::serialize(value.deref(), serializer)
411}
412
413impl Server {
/// Creates a server with the given `id` and registers every RPC handler it serves.
///
/// Handlers wrapped in `user_handler` / `user_message_handler` reject sessions whose
/// principal is not a user. Handlers wrapped in `dev_server_handler` reject sessions whose
/// principal is not a dev server. Registering two handlers for the same message type
/// panics (see `add_handler`).
pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
    let mut server = Self {
        id: parking_lot::Mutex::new(id),
        peer: Peer::new(id.0 as u32),
        app_state: app_state.clone(),
        connection_pool: Default::default(),
        handlers: Default::default(),
        // Only the sender is kept; connections obtain receivers via `subscribe`.
        teardown: watch::channel(false).0,
    };

    server
        // Liveness, rooms, and calls.
        .add_request_handler(ping)
        .add_request_handler(user_handler(create_room))
        .add_request_handler(user_handler(join_room))
        .add_request_handler(user_handler(rejoin_room))
        .add_request_handler(user_handler(leave_room))
        .add_request_handler(user_handler(set_room_participant_role))
        .add_request_handler(user_handler(call))
        .add_request_handler(user_handler(cancel_call))
        .add_message_handler(user_message_handler(decline_call))
        .add_request_handler(user_handler(update_participant_location))
        // Projects and dev servers.
        .add_request_handler(user_handler(share_project))
        .add_message_handler(unshare_project)
        .add_request_handler(user_handler(join_project))
        .add_request_handler(user_handler(join_hosted_project))
        .add_request_handler(user_handler(rejoin_dev_server_projects))
        .add_request_handler(user_handler(create_dev_server_project))
        .add_request_handler(user_handler(update_dev_server_project))
        .add_request_handler(user_handler(delete_dev_server_project))
        .add_request_handler(user_handler(create_dev_server))
        .add_request_handler(user_handler(regenerate_dev_server_token))
        .add_request_handler(user_handler(rename_dev_server))
        .add_request_handler(user_handler(delete_dev_server))
        .add_request_handler(user_handler(list_remote_directory))
        .add_request_handler(dev_server_handler(share_dev_server_project))
        .add_request_handler(dev_server_handler(shutdown_dev_server))
        .add_request_handler(dev_server_handler(reconnect_dev_server))
        .add_message_handler(user_message_handler(leave_project))
        .add_request_handler(update_project)
        .add_request_handler(update_worktree)
        .add_message_handler(start_language_server)
        .add_message_handler(update_language_server)
        .add_message_handler(update_diagnostic_summary)
        .add_message_handler(update_worktree_settings)
        // Project requests relayed through the generic `forward_*` helpers.
        .add_request_handler(user_handler(
            forward_project_request_for_owner::<proto::TaskContextForLocation>,
        ))
        .add_request_handler(user_handler(
            forward_project_request_for_owner::<proto::TaskTemplates>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::GetHover>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::GetDefinition>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::GetTypeDefinition>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::GetReferences>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::SearchProject>,
        ))
        .add_request_handler(user_handler(forward_find_search_candidates_request))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::GetDocumentHighlights>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::GetProjectSymbols>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::OpenBufferForSymbol>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::OpenBufferById>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::SynchronizeBuffers>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::InlayHints>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::ResolveInlayHint>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::OpenBufferByPath>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::GetCompletions>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::OpenNewBuffer>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::GetCodeActions>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::ApplyCodeAction>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::PrepareRename>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::PerformRename>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::ReloadBuffers>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::FormatBuffers>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::CreateProjectEntry>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::RenameProjectEntry>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::CopyProjectEntry>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::DeleteProjectEntry>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::ExpandProjectEntry>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::OnTypeFormatting>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::SaveBuffer>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::BlameBuffer>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::MultiLspQuery>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::RestartLanguageServers>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::LinkedEditingRange>,
        ))
        // Buffer traffic and host-originated broadcasts.
        .add_message_handler(create_buffer_for_peer)
        .add_request_handler(update_buffer)
        .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
        .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
        .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
        .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
        .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
        // Users and contacts.
        .add_request_handler(get_users)
        .add_request_handler(user_handler(fuzzy_search_users))
        .add_request_handler(user_handler(request_contact))
        .add_request_handler(user_handler(remove_contact))
        .add_request_handler(user_handler(respond_to_contact_request))
        // Channels, channel buffers, and chat.
        .add_message_handler(subscribe_to_channels)
        .add_request_handler(user_handler(create_channel))
        .add_request_handler(user_handler(delete_channel))
        .add_request_handler(user_handler(invite_channel_member))
        .add_request_handler(user_handler(remove_channel_member))
        .add_request_handler(user_handler(set_channel_member_role))
        .add_request_handler(user_handler(set_channel_visibility))
        .add_request_handler(user_handler(rename_channel))
        .add_request_handler(user_handler(join_channel_buffer))
        .add_request_handler(user_handler(leave_channel_buffer))
        .add_message_handler(user_message_handler(update_channel_buffer))
        .add_request_handler(user_handler(rejoin_channel_buffers))
        .add_request_handler(user_handler(get_channel_members))
        .add_request_handler(user_handler(respond_to_channel_invite))
        .add_request_handler(user_handler(join_channel))
        .add_request_handler(user_handler(join_channel_chat))
        .add_message_handler(user_message_handler(leave_channel_chat))
        .add_request_handler(user_handler(send_channel_message))
        .add_request_handler(user_handler(remove_channel_message))
        .add_request_handler(user_handler(update_channel_message))
        .add_request_handler(user_handler(get_channel_messages))
        .add_request_handler(user_handler(get_channel_messages_by_id))
        .add_request_handler(user_handler(get_notifications))
        .add_request_handler(user_handler(mark_notification_as_read))
        .add_request_handler(user_handler(move_channel))
        // Following and account-level requests.
        .add_request_handler(user_handler(follow))
        .add_message_handler(user_message_handler(unfollow))
        .add_message_handler(user_message_handler(update_followers))
        .add_request_handler(user_handler(get_private_user_info))
        .add_request_handler(user_handler(get_llm_api_token))
        .add_request_handler(user_handler(accept_terms_of_service))
        .add_message_handler(user_message_handler(acknowledge_channel_message))
        .add_message_handler(user_message_handler(acknowledge_buffer_version))
        .add_request_handler(user_handler(get_supermaven_api_key))
        // Assistant contexts.
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::OpenContext>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::CreateContext>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::SynchronizeContexts>,
        ))
        .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
        .add_message_handler(update_context)
        // Not wrapped in `user_handler`: receives the raw `Session` plus the server config.
        .add_request_handler({
            let app_state = app_state.clone();
            move |request, response, session| {
                let app_state = app_state.clone();
                async move {
                    count_language_model_tokens(request, response, session, &app_state.config)
                        .await
                }
            }
        })
        .add_request_handler({
            user_handler(move |request, response, session| {
                get_cached_embeddings(request, response, session)
            })
        })
        // Embeddings are computed with the configured OpenAI API key (if any).
        .add_request_handler({
            let app_state = app_state.clone();
            user_handler(move |request, response, session| {
                compute_embeddings(
                    request,
                    response,
                    session,
                    app_state.config.openai_api_key.clone(),
                )
            })
        });

    Arc::new(server)
}
653
/// Spawns a detached background task that cleans up after previous server instances.
///
/// The task sleeps for `CLEANUP_TIMEOUT`, then asks the database for the room and channel
/// ids still tied to stale servers in this `zed_environment`. It then:
/// - clears stale channel-buffer collaborators and sends each remaining connection an
///   `UpdateChannelBufferCollaborators` message;
/// - clears stale room participants and broadcasts the refreshed room (and its channel, if
///   any);
/// - sends `CallCanceled` to users whose pending calls were canceled;
/// - pushes an updated contact entry to the accepted contacts of every affected user;
/// - deletes the LiveKit room once it has no participants left (when a LiveKit client is
///   configured);
/// - finally deletes the stale server records.
///
/// Every fallible step goes through `trace_err`, which yields `None`/discards the error, so
/// one failure does not abort the remaining cleanup. The method itself only spawns the task
/// and always returns `Ok(())`.
pub async fn start(&self) -> Result<()> {
    let server_id = *self.id.lock();
    let app_state = self.app_state.clone();
    let peer = self.peer.clone();
    let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
    let pool = self.connection_pool.clone();
    let live_kit_client = self.app_state.live_kit_client.clone();

    let span = info_span!("start server");
    self.app_state.executor.spawn_detached(
        async move {
            tracing::info!("waiting for cleanup timeout");
            timeout.await;
            tracing::info!("cleanup timeout expired, retrieving stale rooms");
            if let Some((room_ids, channel_ids)) = app_state
                .db
                .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                .await
                .trace_err()
            {
                tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                tracing::info!(
                    stale_channel_buffer_count = channel_ids.len(),
                    "retrieved stale channel buffers"
                );

                // Channel buffers: tell every remaining collaborator about the refreshed list.
                for channel_id in channel_ids {
                    if let Some(refreshed_channel_buffer) = app_state
                        .db
                        .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                        .await
                        .trace_err()
                    {
                        for connection_id in refreshed_channel_buffer.connection_ids {
                            peer.send(
                                connection_id,
                                proto::UpdateChannelBufferCollaborators {
                                    channel_id: channel_id.to_proto(),
                                    collaborators: refreshed_channel_buffer
                                        .collaborators
                                        .clone(),
                                },
                            )
                            .trace_err();
                        }
                    }
                }

                for room_id in room_ids {
                    // Filled in only if the room could be refreshed; otherwise the steps
                    // below operate on these empty defaults.
                    let mut contacts_to_update = HashSet::default();
                    let mut canceled_calls_to_user_ids = Vec::new();
                    let mut live_kit_room = String::new();
                    let mut delete_live_kit_room = false;

                    if let Some(mut refreshed_room) = app_state
                        .db
                        .clear_stale_room_participants(room_id, server_id)
                        .await
                        .trace_err()
                    {
                        tracing::info!(
                            room_id = room_id.0,
                            new_participant_count = refreshed_room.room.participants.len(),
                            "refreshed room"
                        );
                        room_updated(&refreshed_room.room, &peer);
                        if let Some(channel) = refreshed_room.channel.as_ref() {
                            channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                        }
                        contacts_to_update
                            .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                        contacts_to_update
                            .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                        canceled_calls_to_user_ids =
                            mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                        live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                        delete_live_kit_room = refreshed_room.room.participants.is_empty();
                    }

                    // Scoped so the synchronous pool lock is released before the awaits below.
                    {
                        let pool = pool.lock();
                        for canceled_user_id in canceled_calls_to_user_ids {
                            for connection_id in pool.user_connection_ids(canceled_user_id) {
                                peer.send(
                                    connection_id,
                                    proto::CallCanceled {
                                        room_id: room_id.to_proto(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    // Both lookups are awaited before the pool is locked, so the lock is
                    // never held across an `.await`.
                    for user_id in contacts_to_update {
                        let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                        let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                        if let Some((busy, contacts)) = busy.zip(contacts) {
                            let pool = pool.lock();
                            let updated_contact = contact_for_user(user_id, busy, &pool);
                            for contact in contacts {
                                if let db::Contact::Accepted {
                                    user_id: contact_user_id,
                                    ..
                                } = contact
                                {
                                    for contact_conn_id in
                                        pool.user_connection_ids(contact_user_id)
                                    {
                                        peer.send(
                                            contact_conn_id,
                                            proto::UpdateContacts {
                                                contacts: vec![updated_contact.clone()],
                                                remove_contacts: Default::default(),
                                                incoming_requests: Default::default(),
                                                remove_incoming_requests: Default::default(),
                                                outgoing_requests: Default::default(),
                                                remove_outgoing_requests: Default::default(),
                                            },
                                        )
                                        .trace_err();
                                    }
                                }
                            }
                        }
                    }

                    if let Some(live_kit) = live_kit_client.as_ref() {
                        if delete_live_kit_room {
                            live_kit.delete_room(live_kit_room).await.trace_err();
                        }
                    }
                }
            }

            app_state
                .db
                .delete_stale_servers(&app_state.config.zed_environment, server_id)
                .await
                .trace_err();
        }
        .instrument(span),
    );
    Ok(())
}
799
800 pub fn teardown(&self) {
801 self.peer.teardown();
802 self.connection_pool.lock().reset();
803 let _ = self.teardown.send(true);
804 }
805
/// Test-only: tears the server down and brings it back up under a new `id`.
#[cfg(test)]
pub fn reset(&self, id: ServerId) {
    self.teardown();
    *self.id.lock() = id;
    self.peer.reset(id.0 as u32);
    // `send` would silently keep the stale `true` if every receiver had already been
    // dropped, and later connections would then be rejected as "tearing down".
    // `send_replace` always stores the new value.
    let _ = self.teardown.send_replace(false);
}
813
/// Test-only: the server's current id.
#[cfg(test)]
pub fn id(&self) -> ServerId {
    let current = self.id.lock();
    *current
}
818
819 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
820 where
821 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
822 Fut: 'static + Send + Future<Output = Result<()>>,
823 M: EnvelopedMessage,
824 {
825 let prev_handler = self.handlers.insert(
826 TypeId::of::<M>(),
827 Box::new(move |envelope, session| {
828 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
829 let received_at = envelope.received_at;
830 tracing::info!("message received");
831 let start_time = Instant::now();
832 let future = (handler)(*envelope, session);
833 async move {
834 let result = future.await;
835 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
836 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
837 let queue_duration_ms = total_duration_ms - processing_duration_ms;
838 let payload_type = M::NAME;
839
840 match result {
841 Err(error) => {
842 tracing::error!(
843 ?error,
844 total_duration_ms,
845 processing_duration_ms,
846 queue_duration_ms,
847 payload_type,
848 "error handling message"
849 )
850 }
851 Ok(()) => tracing::info!(
852 total_duration_ms,
853 processing_duration_ms,
854 queue_duration_ms,
855 "finished handling message"
856 ),
857 }
858 }
859 .boxed()
860 }),
861 );
862 if prev_handler.is_some() {
863 panic!("registered a handler for the same message twice");
864 }
865 self
866 }
867
868 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
869 where
870 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
871 Fut: 'static + Send + Future<Output = Result<()>>,
872 M: EnvelopedMessage,
873 {
874 self.add_handler(move |envelope, session| handler(envelope.payload, session));
875 self
876 }
877
878 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
879 where
880 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
881 Fut: Send + Future<Output = Result<()>>,
882 M: RequestMessage,
883 {
884 let handler = Arc::new(handler);
885 self.add_handler(move |envelope, session| {
886 let receipt = envelope.receipt();
887 let handler = handler.clone();
888 async move {
889 let peer = session.peer.clone();
890 let responded = Arc::new(AtomicBool::default());
891 let response = Response {
892 peer: peer.clone(),
893 responded: responded.clone(),
894 receipt,
895 };
896 match (handler)(envelope.payload, response, session).await {
897 Ok(()) => {
898 if responded.load(std::sync::atomic::Ordering::SeqCst) {
899 Ok(())
900 } else {
901 Err(anyhow!("handler did not send a response"))?
902 }
903 }
904 Err(error) => {
905 let proto_err = match &error {
906 Error::Internal(err) => err.to_proto(),
907 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
908 };
909 peer.respond_with_error(receipt, proto_err)?;
910 Err(error)
911 }
912 }
913 }
914 })
915 }
916
917 #[allow(clippy::too_many_arguments)]
918 pub fn handle_connection(
919 self: &Arc<Self>,
920 connection: Connection,
921 address: String,
922 principal: Principal,
923 zed_version: ZedVersion,
924 geoip_country_code: Option<String>,
925 send_connection_id: Option<oneshot::Sender<ConnectionId>>,
926 executor: Executor,
927 ) -> impl Future<Output = ()> {
928 let this = self.clone();
929 let span = info_span!("handle connection", %address,
930 connection_id=field::Empty,
931 user_id=field::Empty,
932 login=field::Empty,
933 impersonator=field::Empty,
934 dev_server_id=field::Empty,
935 geoip_country_code=field::Empty
936 );
937 principal.update_span(&span);
938 if let Some(country_code) = geoip_country_code.as_ref() {
939 span.record("geoip_country_code", country_code);
940 }
941
942 let mut teardown = self.teardown.subscribe();
943 async move {
944 if *teardown.borrow() {
945 tracing::error!("server is tearing down");
946 return
947 }
948 let (connection_id, handle_io, mut incoming_rx) = this
949 .peer
950 .add_connection(connection, {
951 let executor = executor.clone();
952 move |duration| executor.sleep(duration)
953 });
954 tracing::Span::current().record("connection_id", format!("{}", connection_id));
955
956 tracing::info!("connection opened");
957
958 let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
959 let http_client = match IsahcHttpClient::builder().default_header("User-Agent", user_agent).build() {
960 Ok(http_client) => Arc::new(http_client),
961 Err(error) => {
962 tracing::error!(?error, "failed to create HTTP client");
963 return;
964 }
965 };
966
967 let supermaven_client = if let Some(supermaven_admin_api_key) = this.app_state.config.supermaven_admin_api_key.clone() {
968 Some(Arc::new(SupermavenAdminApi::new(
969 supermaven_admin_api_key.to_string(),
970 http_client.clone(),
971 )))
972 } else {
973 None
974 };
975
976 let session = Session {
977 principal: principal.clone(),
978 connection_id,
979 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
980 peer: this.peer.clone(),
981 connection_pool: this.connection_pool.clone(),
982 app_state: this.app_state.clone(),
983 http_client,
984 geoip_country_code,
985 _executor: executor.clone(),
986 supermaven_client,
987 };
988
989 if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
990 tracing::error!(?error, "failed to send initial client update");
991 return;
992 }
993
994 let handle_io = handle_io.fuse();
995 futures::pin_mut!(handle_io);
996
997 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
998 // This prevents deadlocks when e.g., client A performs a request to client B and
999 // client B performs a request to client A. If both clients stop processing further
1000 // messages until their respective request completes, they won't have a chance to
1001 // respond to the other client's request and cause a deadlock.
1002 //
1003 // This arrangement ensures we will attempt to process earlier messages first, but fall
1004 // back to processing messages arrived later in the spirit of making progress.
1005 let mut foreground_message_handlers = FuturesUnordered::new();
1006 let concurrent_handlers = Arc::new(Semaphore::new(256));
1007 loop {
1008 let next_message = async {
1009 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
1010 let message = incoming_rx.next().await;
1011 (permit, message)
1012 }.fuse();
1013 futures::pin_mut!(next_message);
1014 futures::select_biased! {
1015 _ = teardown.changed().fuse() => return,
1016 result = handle_io => {
1017 if let Err(error) = result {
1018 tracing::error!(?error, "error handling I/O");
1019 }
1020 break;
1021 }
1022 _ = foreground_message_handlers.next() => {}
1023 next_message = next_message => {
1024 let (permit, message) = next_message;
1025 if let Some(message) = message {
1026 let type_name = message.payload_type_name();
1027 // note: we copy all the fields from the parent span so we can query them in the logs.
1028 // (https://github.com/tokio-rs/tracing/issues/2670).
1029 let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
1030 user_id=field::Empty,
1031 login=field::Empty,
1032 impersonator=field::Empty,
1033 dev_server_id=field::Empty
1034 );
1035 principal.update_span(&span);
1036 let span_enter = span.enter();
1037 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
1038 let is_background = message.is_background();
1039 let handle_message = (handler)(message, session.clone());
1040 drop(span_enter);
1041
1042 let handle_message = async move {
1043 handle_message.await;
1044 drop(permit);
1045 }.instrument(span);
1046 if is_background {
1047 executor.spawn_detached(handle_message);
1048 } else {
1049 foreground_message_handlers.push(handle_message);
1050 }
1051 } else {
1052 tracing::error!("no message handler");
1053 }
1054 } else {
1055 tracing::info!("connection closed");
1056 break;
1057 }
1058 }
1059 }
1060 }
1061
1062 drop(foreground_message_handlers);
1063 tracing::info!("signing out");
1064 if let Err(error) = connection_lost(session, teardown, executor).await {
1065 tracing::error!(?error, "error signing out");
1066 }
1067
1068 }.instrument(span)
1069 }
1070
/// Performs the initial handshake for a freshly added connection.
///
/// Sends `Hello` (carrying the connection's peer id), reports the connection id through
/// `send_connection_id` when one was supplied, and then pushes the initial state for the
/// principal:
///
/// - Users (including impersonated users): `ShowContacts` on the first-ever connection,
///   a plan update, registration in the connection pool together with the initial
///   contacts, channel subscription when `should_auto_subscribe_to_channels(zed_version)`
///   allows it, the dev-server project status, any pending incoming call, and finally a
///   contacts update for this user.
/// - Dev servers: an existing connection for the same dev server is sent
///   `ShutdownDevServer` and removed from the pool. The new connection is registered,
///   receives `DevServerInstructions` with its projects, and a dev-server project status
///   update is sent for `dev_server.user_id`.
///
/// # Errors
///
/// Returns the first failure from sending a message or from a database call.
///
/// NOTE(review): in the user branch the connection is added to the pool before several
/// fallible awaits. On error, `handle_connection` (in this file) logs and returns without
/// reaching `connection_lost`, which appears to leave the pool entry behind. Confirm
/// whether cleanup happens elsewhere.
async fn send_initial_client_update(
    &self,
    connection_id: ConnectionId,
    principal: &Principal,
    zed_version: ZedVersion,
    mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
    session: &Session,
) -> Result<()> {
    // Greet the client first; the `Hello` carries the peer id assigned to it.
    self.peer.send(
        connection_id,
        proto::Hello {
            peer_id: Some(connection_id.into()),
        },
    )?;
    tracing::info!("sent hello message");
    // Report the assigned connection id to the caller that asked for it, if any.
    if let Some(send_connection_id) = send_connection_id.take() {
        // A dropped receiver is deliberately ignored; it does not fail the handshake.
        let _ = send_connection_id.send(connection_id);
    }

    match principal {
        Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
            // First connection ever for this user: ask the client to show contacts,
            // then persist that the user has connected at least once.
            if !user.connected_once {
                self.peer.send(connection_id, proto::ShowContacts {})?;
                self.app_state
                    .db
                    .set_user_connected_once(user.id, true)
                    .await?;
            }

            update_user_plan(user.id, session).await?;

            // Fetch contacts and the dev-server project status concurrently.
            let (contacts, dev_server_projects) = future::try_join(
                self.app_state.db.get_contacts(user.id),
                self.app_state.db.dev_server_projects_update(user.id),
            )
            .await?;

            // The pool lock is held across both steps, so the initial contacts update is
            // built from a pool that already contains this connection.
            {
                let mut pool = self.connection_pool.lock();
                pool.add_connection(connection_id, user.id, user.admin, zed_version);
                self.peer.send(
                    connection_id,
                    build_initial_contacts_update(contacts, &pool),
                )?;
            }

            // Channel auto-subscription is gated on the client version.
            if should_auto_subscribe_to_channels(zed_version) {
                subscribe_user_to_channels(user.id, session).await?;
            }

            send_dev_server_projects_update(user.id, dev_server_projects, session).await;

            // If the database reports a pending incoming call, forward it to this connection.
            if let Some(incoming_call) =
                self.app_state.db.incoming_call_for_user(user.id).await?
            {
                self.peer.send(connection_id, incoming_call)?;
            }

            update_user_contacts(user.id, &session).await?;
        }
        Principal::DevServer(dev_server) => {
            // Only one connection per dev server is kept: evict any previous one.
            {
                let mut pool = self.connection_pool.lock();
                if let Some(stale_connection_id) = pool.dev_server_connection_id(dev_server.id)
                {
                    // NOTE(review): if this send fails, `?` returns before the stale entry
                    // is removed and before the new connection is registered.
                    self.peer.send(
                        stale_connection_id,
                        proto::ShutdownDevServer {
                            reason: Some(
                                "another dev server connected with the same token".to_string(),
                            ),
                        },
                    )?;
                    pool.remove_connection(stale_connection_id)?;
                };
                pool.add_dev_server(connection_id, dev_server.id, zed_version);
            }

            // Send the dev server the list of projects it serves.
            let projects = self
                .app_state
                .db
                .get_projects_for_dev_server(dev_server.id)
                .await?;
            self.peer
                .send(connection_id, proto::DevServerInstructions { projects })?;

            // Publish the refreshed dev-server project status for the owning user id.
            let status = self
                .app_state
                .db
                .dev_server_projects_update(dev_server.user_id)
                .await?;
            send_dev_server_projects_update(dev_server.user_id, status, &session).await;
        }
    }

    Ok(())
}
1168
1169 pub async fn invite_code_redeemed(
1170 self: &Arc<Self>,
1171 inviter_id: UserId,
1172 invitee_id: UserId,
1173 ) -> Result<()> {
1174 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1175 if let Some(code) = &user.invite_code {
1176 let pool = self.connection_pool.lock();
1177 let invitee_contact = contact_for_user(invitee_id, false, &pool);
1178 for connection_id in pool.user_connection_ids(inviter_id) {
1179 self.peer.send(
1180 connection_id,
1181 proto::UpdateContacts {
1182 contacts: vec![invitee_contact.clone()],
1183 ..Default::default()
1184 },
1185 )?;
1186 self.peer.send(
1187 connection_id,
1188 proto::UpdateInviteInfo {
1189 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1190 count: user.invite_count as u32,
1191 },
1192 )?;
1193 }
1194 }
1195 }
1196 Ok(())
1197 }
1198
1199 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1200 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1201 if let Some(invite_code) = &user.invite_code {
1202 let pool = self.connection_pool.lock();
1203 for connection_id in pool.user_connection_ids(user_id) {
1204 self.peer.send(
1205 connection_id,
1206 proto::UpdateInviteInfo {
1207 url: format!(
1208 "{}{}",
1209 self.app_state.config.invite_link_prefix, invite_code
1210 ),
1211 count: user.invite_count as u32,
1212 },
1213 )?;
1214 }
1215 }
1216 }
1217 Ok(())
1218 }
1219
1220 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1221 ServerSnapshot {
1222 connection_pool: ConnectionPoolGuard {
1223 guard: self.connection_pool.lock(),
1224 _not_send: PhantomData,
1225 },
1226 peer: &self.peer,
1227 }
1228 }
1229}
1230
1231impl<'a> Deref for ConnectionPoolGuard<'a> {
1232 type Target = ConnectionPool;
1233
1234 fn deref(&self) -> &Self::Target {
1235 &self.guard
1236 }
1237}
1238
1239impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1240 fn deref_mut(&mut self) -> &mut Self::Target {
1241 &mut self.guard
1242 }
1243}
1244
1245impl<'a> Drop for ConnectionPoolGuard<'a> {
1246 fn drop(&mut self) {
1247 #[cfg(test)]
1248 self.check_invariants();
1249 }
1250}
1251
1252fn broadcast<F>(
1253 sender_id: Option<ConnectionId>,
1254 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1255 mut f: F,
1256) where
1257 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1258{
1259 for receiver_id in receiver_ids {
1260 if Some(receiver_id) != sender_id {
1261 if let Err(error) = f(receiver_id) {
1262 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1263 }
1264 }
1265 }
1266}
1267
1268pub struct ProtocolVersion(u32);
1269
1270impl Header for ProtocolVersion {
1271 fn name() -> &'static HeaderName {
1272 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1273 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1274 }
1275
1276 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1277 where
1278 Self: Sized,
1279 I: Iterator<Item = &'i axum::http::HeaderValue>,
1280 {
1281 let version = values
1282 .next()
1283 .ok_or_else(axum::headers::Error::invalid)?
1284 .to_str()
1285 .map_err(|_| axum::headers::Error::invalid())?
1286 .parse()
1287 .map_err(|_| axum::headers::Error::invalid())?;
1288 Ok(Self(version))
1289 }
1290
1291 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1292 values.extend([self.0.to_string().parse().unwrap()]);
1293 }
1294}
1295
1296pub struct AppVersionHeader(SemanticVersion);
1297impl Header for AppVersionHeader {
1298 fn name() -> &'static HeaderName {
1299 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1300 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1301 }
1302
1303 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1304 where
1305 Self: Sized,
1306 I: Iterator<Item = &'i axum::http::HeaderValue>,
1307 {
1308 let version = values
1309 .next()
1310 .ok_or_else(axum::headers::Error::invalid)?
1311 .to_str()
1312 .map_err(|_| axum::headers::Error::invalid())?
1313 .parse()
1314 .map_err(|_| axum::headers::Error::invalid())?;
1315 Ok(Self(version))
1316 }
1317
1318 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1319 values.extend([self.0.to_string().parse().unwrap()]);
1320 }
1321}
1322
1323pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1324 Router::new()
1325 .route("/rpc", get(handle_websocket_request))
1326 .layer(
1327 ServiceBuilder::new()
1328 .layer(Extension(server.app_state.clone()))
1329 .layer(middleware::from_fn(auth::validate_header)),
1330 )
1331 .route("/metrics", get(handle_metrics))
1332 .layer(Extension(server))
1333}
1334
1335pub async fn handle_websocket_request(
1336 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1337 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1338 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1339 Extension(server): Extension<Arc<Server>>,
1340 Extension(principal): Extension<Principal>,
1341 country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1342 ws: WebSocketUpgrade,
1343) -> axum::response::Response {
1344 if protocol_version != rpc::PROTOCOL_VERSION {
1345 return (
1346 StatusCode::UPGRADE_REQUIRED,
1347 "client must be upgraded".to_string(),
1348 )
1349 .into_response();
1350 }
1351
1352 let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1353 return (
1354 StatusCode::UPGRADE_REQUIRED,
1355 "no version header found".to_string(),
1356 )
1357 .into_response();
1358 };
1359
1360 if !version.can_collaborate() {
1361 return (
1362 StatusCode::UPGRADE_REQUIRED,
1363 "client must be upgraded".to_string(),
1364 )
1365 .into_response();
1366 }
1367
1368 let socket_address = socket_address.to_string();
1369 ws.on_upgrade(move |socket| {
1370 let socket = socket
1371 .map_ok(to_tungstenite_message)
1372 .err_into()
1373 .with(|message| async move { to_axum_message(message) });
1374 let connection = Connection::new(Box::pin(socket));
1375 async move {
1376 server
1377 .handle_connection(
1378 connection,
1379 socket_address,
1380 principal,
1381 version,
1382 country_code_header.map(|header| header.to_string()),
1383 None,
1384 Executor::Production,
1385 )
1386 .await;
1387 }
1388 })
1389}
1390
1391pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1392 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1393 let connections_metric = CONNECTIONS_METRIC
1394 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1395
1396 let connections = server
1397 .connection_pool
1398 .lock()
1399 .connections()
1400 .filter(|connection| !connection.admin)
1401 .count();
1402 connections_metric.set(connections as _);
1403
1404 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1405 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1406 register_int_gauge!(
1407 "shared_projects",
1408 "number of open projects with one or more guests"
1409 )
1410 .unwrap()
1411 });
1412
1413 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1414 shared_projects_metric.set(shared_projects as _);
1415
1416 let encoder = prometheus::TextEncoder::new();
1417 let metric_families = prometheus::gather();
1418 let encoded_metrics = encoder
1419 .encode_to_string(&metric_families)
1420 .map_err(|err| anyhow!("{}", err))?;
1421 Ok(encoded_metrics)
1422}
1423
1424#[instrument(err, skip(executor))]
1425async fn connection_lost(
1426 session: Session,
1427 mut teardown: watch::Receiver<bool>,
1428 executor: Executor,
1429) -> Result<()> {
1430 session.peer.disconnect(session.connection_id);
1431 session
1432 .connection_pool()
1433 .await
1434 .remove_connection(session.connection_id)?;
1435
1436 session
1437 .db()
1438 .await
1439 .connection_lost(session.connection_id)
1440 .await
1441 .trace_err();
1442
1443 futures::select_biased! {
1444 _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1445 match &session.principal {
1446 Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
1447 let session = session.for_user().unwrap();
1448
1449 log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1450 leave_room_for_session(&session, session.connection_id).await.trace_err();
1451 leave_channel_buffers_for_session(&session)
1452 .await
1453 .trace_err();
1454
1455 if !session
1456 .connection_pool()
1457 .await
1458 .is_user_online(session.user_id())
1459 {
1460 let db = session.db().await;
1461 if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1462 room_updated(&room, &session.peer);
1463 }
1464 }
1465
1466 update_user_contacts(session.user_id(), &session).await?;
1467 },
1468 Principal::DevServer(_) => {
1469 lost_dev_server_connection(&session.for_dev_server().unwrap()).await?;
1470 },
1471 }
1472 },
1473 _ = teardown.changed().fuse() => {}
1474 }
1475
1476 Ok(())
1477}
1478
1479/// Acknowledges a ping from a client, used to keep the connection alive.
1480async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1481 response.send(proto::Ack {})?;
1482 Ok(())
1483}
1484
1485/// Creates a new room for calling (outside of channels)
1486async fn create_room(
1487 _request: proto::CreateRoom,
1488 response: Response<proto::CreateRoom>,
1489 session: UserSession,
1490) -> Result<()> {
1491 let live_kit_room = nanoid::nanoid!(30);
1492
1493 let live_kit_connection_info = util::maybe!(async {
1494 let live_kit = session.app_state.live_kit_client.as_ref();
1495 let live_kit = live_kit?;
1496 let user_id = session.user_id().to_string();
1497
1498 let token = live_kit
1499 .room_token(&live_kit_room, &user_id.to_string())
1500 .trace_err()?;
1501
1502 Some(proto::LiveKitConnectionInfo {
1503 server_url: live_kit.url().into(),
1504 token,
1505 can_publish: true,
1506 })
1507 })
1508 .await;
1509
1510 let room = session
1511 .db()
1512 .await
1513 .create_room(session.user_id(), session.connection_id, &live_kit_room)
1514 .await?;
1515
1516 response.send(proto::CreateRoomResponse {
1517 room: Some(room.clone()),
1518 live_kit_connection_info,
1519 })?;
1520
1521 update_user_contacts(session.user_id(), &session).await?;
1522 Ok(())
1523}
1524
1525/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1526async fn join_room(
1527 request: proto::JoinRoom,
1528 response: Response<proto::JoinRoom>,
1529 session: UserSession,
1530) -> Result<()> {
1531 let room_id = RoomId::from_proto(request.id);
1532
1533 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1534
1535 if let Some(channel_id) = channel_id {
1536 return join_channel_internal(channel_id, Box::new(response), session).await;
1537 }
1538
1539 let joined_room = {
1540 let room = session
1541 .db()
1542 .await
1543 .join_room(room_id, session.user_id(), session.connection_id)
1544 .await?;
1545 room_updated(&room.room, &session.peer);
1546 room.into_inner()
1547 };
1548
1549 for connection_id in session
1550 .connection_pool()
1551 .await
1552 .user_connection_ids(session.user_id())
1553 {
1554 session
1555 .peer
1556 .send(
1557 connection_id,
1558 proto::CallCanceled {
1559 room_id: room_id.to_proto(),
1560 },
1561 )
1562 .trace_err();
1563 }
1564
1565 let live_kit_connection_info =
1566 if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
1567 if let Some(token) = live_kit
1568 .room_token(
1569 &joined_room.room.live_kit_room,
1570 &session.user_id().to_string(),
1571 )
1572 .trace_err()
1573 {
1574 Some(proto::LiveKitConnectionInfo {
1575 server_url: live_kit.url().into(),
1576 token,
1577 can_publish: true,
1578 })
1579 } else {
1580 None
1581 }
1582 } else {
1583 None
1584 };
1585
1586 response.send(proto::JoinRoomResponse {
1587 room: Some(joined_room.room),
1588 channel_id: None,
1589 live_kit_connection_info,
1590 })?;
1591
1592 update_user_contacts(session.user_id(), &session).await?;
1593 Ok(())
1594}
1595
/// Reconnects the caller to a room after connection errors.
///
/// The database call returns the room, the "reshared" projects and the "rejoined"
/// projects (and the room's channel, if any). The reshared projects are presumably the
/// ones this user hosts, since their worktrees are forwarded from this connection. The
/// rejoined projects are presumably the ones the user had joined as a guest.
///
/// The reply is sent first. Afterwards:
/// - the room is broadcast;
/// - collaborators on each reshared project are told this user's new peer id and
///   receive the project's worktree metadata;
/// - the caller is brought up to date on each rejoined project via
///   `notify_rejoined_projects`;
/// - the channel is updated when the room belongs to one;
/// - the user's contacts are refreshed.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: UserSession,
) -> Result<()> {
    let room;
    let channel;
    // The block confines the `rejoined_room` guard; only the room and channel escape it.
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        // The response goes out before the broadcasts and per-project updates below.
        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell every collaborator that this user's peer id changed from
            // `old_connection_id` to the current connection (failures are logged).
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the project's worktree metadata, attributed to this connection, to
            // every collaborator other than this connection.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1687
1688fn notify_rejoined_projects(
1689 rejoined_projects: &mut Vec<RejoinedProject>,
1690 session: &UserSession,
1691) -> Result<()> {
1692 for project in rejoined_projects.iter() {
1693 for collaborator in &project.collaborators {
1694 session
1695 .peer
1696 .send(
1697 collaborator.connection_id,
1698 proto::UpdateProjectCollaborator {
1699 project_id: project.id.to_proto(),
1700 old_peer_id: Some(project.old_connection_id.into()),
1701 new_peer_id: Some(session.connection_id.into()),
1702 },
1703 )
1704 .trace_err();
1705 }
1706 }
1707
1708 for project in rejoined_projects {
1709 for worktree in mem::take(&mut project.worktrees) {
1710 #[cfg(any(test, feature = "test-support"))]
1711 const MAX_CHUNK_SIZE: usize = 2;
1712 #[cfg(not(any(test, feature = "test-support")))]
1713 const MAX_CHUNK_SIZE: usize = 256;
1714
1715 // Stream this worktree's entries.
1716 let message = proto::UpdateWorktree {
1717 project_id: project.id.to_proto(),
1718 worktree_id: worktree.id,
1719 abs_path: worktree.abs_path.clone(),
1720 root_name: worktree.root_name,
1721 updated_entries: worktree.updated_entries,
1722 removed_entries: worktree.removed_entries,
1723 scan_id: worktree.scan_id,
1724 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1725 updated_repositories: worktree.updated_repositories,
1726 removed_repositories: worktree.removed_repositories,
1727 };
1728 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1729 session.peer.send(session.connection_id, update.clone())?;
1730 }
1731
1732 // Stream this worktree's diagnostics.
1733 for summary in worktree.diagnostic_summaries {
1734 session.peer.send(
1735 session.connection_id,
1736 proto::UpdateDiagnosticSummary {
1737 project_id: project.id.to_proto(),
1738 worktree_id: worktree.id,
1739 summary: Some(summary),
1740 },
1741 )?;
1742 }
1743
1744 for settings_file in worktree.settings_files {
1745 session.peer.send(
1746 session.connection_id,
1747 proto::UpdateWorktreeSettings {
1748 project_id: project.id.to_proto(),
1749 worktree_id: worktree.id,
1750 path: settings_file.path,
1751 content: Some(settings_file.content),
1752 },
1753 )?;
1754 }
1755 }
1756
1757 for language_server in &project.language_servers {
1758 session.peer.send(
1759 session.connection_id,
1760 proto::UpdateLanguageServer {
1761 project_id: project.id.to_proto(),
1762 language_server_id: language_server.id,
1763 variant: Some(
1764 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1765 proto::LspDiskBasedDiagnosticsUpdated {},
1766 ),
1767 ),
1768 },
1769 )?;
1770 }
1771 }
1772 Ok(())
1773}
1774
1775/// leave room disconnects from the room.
1776async fn leave_room(
1777 _: proto::LeaveRoom,
1778 response: Response<proto::LeaveRoom>,
1779 session: UserSession,
1780) -> Result<()> {
1781 leave_room_for_session(&session, session.connection_id).await?;
1782 response.send(proto::Ack {})?;
1783 Ok(())
1784}
1785
1786/// Updates the permissions of someone else in the room.
1787async fn set_room_participant_role(
1788 request: proto::SetRoomParticipantRole,
1789 response: Response<proto::SetRoomParticipantRole>,
1790 session: UserSession,
1791) -> Result<()> {
1792 let user_id = UserId::from_proto(request.user_id);
1793 let role = ChannelRole::from(request.role());
1794
1795 let (live_kit_room, can_publish) = {
1796 let room = session
1797 .db()
1798 .await
1799 .set_room_participant_role(
1800 session.user_id(),
1801 RoomId::from_proto(request.room_id),
1802 user_id,
1803 role,
1804 )
1805 .await?;
1806
1807 let live_kit_room = room.live_kit_room.clone();
1808 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1809 room_updated(&room, &session.peer);
1810 (live_kit_room, can_publish)
1811 };
1812
1813 if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
1814 live_kit
1815 .update_participant(
1816 live_kit_room.clone(),
1817 request.user_id.to_string(),
1818 live_kit_server::proto::ParticipantPermission {
1819 can_subscribe: true,
1820 can_publish,
1821 can_publish_data: can_publish,
1822 hidden: false,
1823 recorder: false,
1824 },
1825 )
1826 .await
1827 .trace_err();
1828 }
1829
1830 response.send(proto::Ack {})?;
1831 Ok(())
1832}
1833
1834/// Call someone else into the current room
1835async fn call(
1836 request: proto::Call,
1837 response: Response<proto::Call>,
1838 session: UserSession,
1839) -> Result<()> {
1840 let room_id = RoomId::from_proto(request.room_id);
1841 let calling_user_id = session.user_id();
1842 let calling_connection_id = session.connection_id;
1843 let called_user_id = UserId::from_proto(request.called_user_id);
1844 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1845 if !session
1846 .db()
1847 .await
1848 .has_contact(calling_user_id, called_user_id)
1849 .await?
1850 {
1851 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1852 }
1853
1854 let incoming_call = {
1855 let (room, incoming_call) = &mut *session
1856 .db()
1857 .await
1858 .call(
1859 room_id,
1860 calling_user_id,
1861 calling_connection_id,
1862 called_user_id,
1863 initial_project_id,
1864 )
1865 .await?;
1866 room_updated(&room, &session.peer);
1867 mem::take(incoming_call)
1868 };
1869 update_user_contacts(called_user_id, &session).await?;
1870
1871 let mut calls = session
1872 .connection_pool()
1873 .await
1874 .user_connection_ids(called_user_id)
1875 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1876 .collect::<FuturesUnordered<_>>();
1877
1878 while let Some(call_response) = calls.next().await {
1879 match call_response.as_ref() {
1880 Ok(_) => {
1881 response.send(proto::Ack {})?;
1882 return Ok(());
1883 }
1884 Err(_) => {
1885 call_response.trace_err();
1886 }
1887 }
1888 }
1889
1890 {
1891 let room = session
1892 .db()
1893 .await
1894 .call_failed(room_id, called_user_id)
1895 .await?;
1896 room_updated(&room, &session.peer);
1897 }
1898 update_user_contacts(called_user_id, &session).await?;
1899
1900 Err(anyhow!("failed to ring user"))?
1901}
1902
1903/// Cancel an outgoing call.
1904async fn cancel_call(
1905 request: proto::CancelCall,
1906 response: Response<proto::CancelCall>,
1907 session: UserSession,
1908) -> Result<()> {
1909 let called_user_id = UserId::from_proto(request.called_user_id);
1910 let room_id = RoomId::from_proto(request.room_id);
1911 {
1912 let room = session
1913 .db()
1914 .await
1915 .cancel_call(room_id, session.connection_id, called_user_id)
1916 .await?;
1917 room_updated(&room, &session.peer);
1918 }
1919
1920 for connection_id in session
1921 .connection_pool()
1922 .await
1923 .user_connection_ids(called_user_id)
1924 {
1925 session
1926 .peer
1927 .send(
1928 connection_id,
1929 proto::CallCanceled {
1930 room_id: room_id.to_proto(),
1931 },
1932 )
1933 .trace_err();
1934 }
1935 response.send(proto::Ack {})?;
1936
1937 update_user_contacts(called_user_id, &session).await?;
1938 Ok(())
1939}
1940
1941/// Decline an incoming call.
1942async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1943 let room_id = RoomId::from_proto(message.room_id);
1944 {
1945 let room = session
1946 .db()
1947 .await
1948 .decline_call(Some(room_id), session.user_id())
1949 .await?
1950 .ok_or_else(|| anyhow!("failed to decline call"))?;
1951 room_updated(&room, &session.peer);
1952 }
1953
1954 for connection_id in session
1955 .connection_pool()
1956 .await
1957 .user_connection_ids(session.user_id())
1958 {
1959 session
1960 .peer
1961 .send(
1962 connection_id,
1963 proto::CallCanceled {
1964 room_id: room_id.to_proto(),
1965 },
1966 )
1967 .trace_err();
1968 }
1969 update_user_contacts(session.user_id(), &session).await?;
1970 Ok(())
1971}
1972
1973/// Updates other participants in the room with your current location.
1974async fn update_participant_location(
1975 request: proto::UpdateParticipantLocation,
1976 response: Response<proto::UpdateParticipantLocation>,
1977 session: UserSession,
1978) -> Result<()> {
1979 let room_id = RoomId::from_proto(request.room_id);
1980 let location = request
1981 .location
1982 .ok_or_else(|| anyhow!("invalid location"))?;
1983
1984 let db = session.db().await;
1985 let room = db
1986 .update_room_participant_location(room_id, session.connection_id, location)
1987 .await?;
1988
1989 room_updated(&room, &session.peer);
1990 response.send(proto::Ack {})?;
1991 Ok(())
1992}
1993
1994/// Share a project into the room.
1995async fn share_project(
1996 request: proto::ShareProject,
1997 response: Response<proto::ShareProject>,
1998 session: UserSession,
1999) -> Result<()> {
2000 let (project_id, room) = &*session
2001 .db()
2002 .await
2003 .share_project(
2004 RoomId::from_proto(request.room_id),
2005 session.connection_id,
2006 &request.worktrees,
2007 request
2008 .dev_server_project_id
2009 .map(|id| DevServerProjectId::from_proto(id)),
2010 )
2011 .await?;
2012 response.send(proto::ShareProjectResponse {
2013 project_id: project_id.to_proto(),
2014 })?;
2015 room_updated(&room, &session.peer);
2016
2017 Ok(())
2018}
2019
2020/// Unshare a project from the room.
2021async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
2022 let project_id = ProjectId::from_proto(message.project_id);
2023 unshare_project_internal(
2024 project_id,
2025 session.connection_id,
2026 session.user_id(),
2027 &session,
2028 )
2029 .await
2030}
2031
2032async fn unshare_project_internal(
2033 project_id: ProjectId,
2034 connection_id: ConnectionId,
2035 user_id: Option<UserId>,
2036 session: &Session,
2037) -> Result<()> {
2038 let delete = {
2039 let room_guard = session
2040 .db()
2041 .await
2042 .unshare_project(project_id, connection_id, user_id)
2043 .await?;
2044
2045 let (delete, room, guest_connection_ids) = &*room_guard;
2046
2047 let message = proto::UnshareProject {
2048 project_id: project_id.to_proto(),
2049 };
2050
2051 broadcast(
2052 Some(connection_id),
2053 guest_connection_ids.iter().copied(),
2054 |conn_id| session.peer.send(conn_id, message.clone()),
2055 );
2056 if let Some(room) = room {
2057 room_updated(room, &session.peer);
2058 }
2059
2060 *delete
2061 };
2062
2063 if delete {
2064 let db = session.db().await;
2065 db.delete_project(project_id).await?;
2066 }
2067
2068 Ok(())
2069}
2070
2071/// DevServer makes a project available online
2072async fn share_dev_server_project(
2073 request: proto::ShareDevServerProject,
2074 response: Response<proto::ShareDevServerProject>,
2075 session: DevServerSession,
2076) -> Result<()> {
2077 let (dev_server_project, user_id, status) = session
2078 .db()
2079 .await
2080 .share_dev_server_project(
2081 DevServerProjectId::from_proto(request.dev_server_project_id),
2082 session.dev_server_id(),
2083 session.connection_id,
2084 &request.worktrees,
2085 )
2086 .await?;
2087 let Some(project_id) = dev_server_project.project_id else {
2088 return Err(anyhow!("failed to share remote project"))?;
2089 };
2090
2091 send_dev_server_projects_update(user_id, status, &session).await;
2092
2093 response.send(proto::ShareProjectResponse { project_id })?;
2094
2095 Ok(())
2096}
2097
/// Joins a project that someone else has shared.
///
/// The database call returns the project and the replica id assigned to the caller.
/// The reply and all follow-up messages are produced by `join_project_internal`.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: UserSession,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    // `&mut *` borrows the contents of the value returned by `join_project`. Temporary
    // lifetime extension keeps that value alive until the end of this function.
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    // Release the database handle before building and sending the response.
    drop(db);
    // NOTE(review): this is logged for every join, not only for remote projects.
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
2116
/// Abstracts over the response handles of the project-joining requests, so that
/// `join_project_internal` can reply to any of them with a `proto::JoinProjectResponse`.
trait JoinProjectInternalResponse {
    /// Sends `result` as the reply to the original request.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
/// Replies to a regular `JoinProject` request.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // The fully-qualified path makes it explicit that this forwards to the inherent
        // `Response::send` rather than recursing into this trait method.
        Response::<proto::JoinProject>::send(self, result)
    }
}
/// Replies to a `JoinHostedProject` request with the same `JoinProjectResponse` payload.
impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // The fully-qualified path makes it explicit that this forwards to the inherent
        // `Response::send` rather than recursing into this trait method.
        Response::<proto::JoinHostedProject>::send(self, result)
    }
}
2130
2131fn join_project_internal(
2132 response: impl JoinProjectInternalResponse,
2133 session: UserSession,
2134 project: &mut Project,
2135 replica_id: &ReplicaId,
2136) -> Result<()> {
2137 let collaborators = project
2138 .collaborators
2139 .iter()
2140 .filter(|collaborator| collaborator.connection_id != session.connection_id)
2141 .map(|collaborator| collaborator.to_proto())
2142 .collect::<Vec<_>>();
2143 let project_id = project.id;
2144 let guest_user_id = session.user_id();
2145
2146 let worktrees = project
2147 .worktrees
2148 .iter()
2149 .map(|(id, worktree)| proto::WorktreeMetadata {
2150 id: *id,
2151 root_name: worktree.root_name.clone(),
2152 visible: worktree.visible,
2153 abs_path: worktree.abs_path.clone(),
2154 })
2155 .collect::<Vec<_>>();
2156
2157 let add_project_collaborator = proto::AddProjectCollaborator {
2158 project_id: project_id.to_proto(),
2159 collaborator: Some(proto::Collaborator {
2160 peer_id: Some(session.connection_id.into()),
2161 replica_id: replica_id.0 as u32,
2162 user_id: guest_user_id.to_proto(),
2163 }),
2164 };
2165
2166 for collaborator in &collaborators {
2167 session
2168 .peer
2169 .send(
2170 collaborator.peer_id.unwrap().into(),
2171 add_project_collaborator.clone(),
2172 )
2173 .trace_err();
2174 }
2175
2176 // First, we send the metadata associated with each worktree.
2177 response.send(proto::JoinProjectResponse {
2178 project_id: project.id.0 as u64,
2179 worktrees: worktrees.clone(),
2180 replica_id: replica_id.0 as u32,
2181 collaborators: collaborators.clone(),
2182 language_servers: project.language_servers.clone(),
2183 role: project.role.into(),
2184 dev_server_project_id: project
2185 .dev_server_project_id
2186 .map(|dev_server_project_id| dev_server_project_id.0 as u64),
2187 })?;
2188
2189 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2190 #[cfg(any(test, feature = "test-support"))]
2191 const MAX_CHUNK_SIZE: usize = 2;
2192 #[cfg(not(any(test, feature = "test-support")))]
2193 const MAX_CHUNK_SIZE: usize = 256;
2194
2195 // Stream this worktree's entries.
2196 let message = proto::UpdateWorktree {
2197 project_id: project_id.to_proto(),
2198 worktree_id,
2199 abs_path: worktree.abs_path.clone(),
2200 root_name: worktree.root_name,
2201 updated_entries: worktree.entries,
2202 removed_entries: Default::default(),
2203 scan_id: worktree.scan_id,
2204 is_last_update: worktree.scan_id == worktree.completed_scan_id,
2205 updated_repositories: worktree.repository_entries.into_values().collect(),
2206 removed_repositories: Default::default(),
2207 };
2208 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
2209 session.peer.send(session.connection_id, update.clone())?;
2210 }
2211
2212 // Stream this worktree's diagnostics.
2213 for summary in worktree.diagnostic_summaries {
2214 session.peer.send(
2215 session.connection_id,
2216 proto::UpdateDiagnosticSummary {
2217 project_id: project_id.to_proto(),
2218 worktree_id: worktree.id,
2219 summary: Some(summary),
2220 },
2221 )?;
2222 }
2223
2224 for settings_file in worktree.settings_files {
2225 session.peer.send(
2226 session.connection_id,
2227 proto::UpdateWorktreeSettings {
2228 project_id: project_id.to_proto(),
2229 worktree_id: worktree.id,
2230 path: settings_file.path,
2231 content: Some(settings_file.content),
2232 },
2233 )?;
2234 }
2235 }
2236
2237 for language_server in &project.language_servers {
2238 session.peer.send(
2239 session.connection_id,
2240 proto::UpdateLanguageServer {
2241 project_id: project_id.to_proto(),
2242 language_server_id: language_server.id,
2243 variant: Some(
2244 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2245 proto::LspDiskBasedDiagnosticsUpdated {},
2246 ),
2247 ),
2248 },
2249 )?;
2250 }
2251
2252 Ok(())
2253}
2254
2255/// Leave someone elses shared project.
2256async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2257 let sender_id = session.connection_id;
2258 let project_id = ProjectId::from_proto(request.project_id);
2259 let db = session.db().await;
2260 if db.is_hosted_project(project_id).await? {
2261 let project = db.leave_hosted_project(project_id, sender_id).await?;
2262 project_left(&project, &session);
2263 return Ok(());
2264 }
2265
2266 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2267 tracing::info!(
2268 %project_id,
2269 "leave project"
2270 );
2271
2272 project_left(&project, &session);
2273 if let Some(room) = room {
2274 room_updated(&room, &session.peer);
2275 }
2276
2277 Ok(())
2278}
2279
2280async fn join_hosted_project(
2281 request: proto::JoinHostedProject,
2282 response: Response<proto::JoinHostedProject>,
2283 session: UserSession,
2284) -> Result<()> {
2285 let (mut project, replica_id) = session
2286 .db()
2287 .await
2288 .join_hosted_project(
2289 ProjectId(request.project_id as i32),
2290 session.user_id(),
2291 session.connection_id,
2292 )
2293 .await?;
2294
2295 join_project_internal(response, session, &mut project, &replica_id)
2296}
2297
2298async fn list_remote_directory(
2299 request: proto::ListRemoteDirectory,
2300 response: Response<proto::ListRemoteDirectory>,
2301 session: UserSession,
2302) -> Result<()> {
2303 let dev_server_id = DevServerId(request.dev_server_id as i32);
2304 let dev_server_connection_id = session
2305 .connection_pool()
2306 .await
2307 .dev_server_connection_id_supporting(dev_server_id, ZedVersion::with_list_directory())?;
2308
2309 session
2310 .db()
2311 .await
2312 .get_dev_server_for_user(dev_server_id, session.user_id())
2313 .await?;
2314
2315 response.send(
2316 session
2317 .peer
2318 .forward_request(session.connection_id, dev_server_connection_id, request)
2319 .await?,
2320 )?;
2321 Ok(())
2322}
2323
2324async fn update_dev_server_project(
2325 request: proto::UpdateDevServerProject,
2326 response: Response<proto::UpdateDevServerProject>,
2327 session: UserSession,
2328) -> Result<()> {
2329 let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2330
2331 let (dev_server_project, update) = session
2332 .db()
2333 .await
2334 .update_dev_server_project(dev_server_project_id, &request.paths, session.user_id())
2335 .await?;
2336
2337 let projects = session
2338 .db()
2339 .await
2340 .get_projects_for_dev_server(dev_server_project.dev_server_id)
2341 .await?;
2342
2343 let dev_server_connection_id = session
2344 .connection_pool()
2345 .await
2346 .dev_server_connection_id_supporting(
2347 dev_server_project.dev_server_id,
2348 ZedVersion::with_list_directory(),
2349 )?;
2350
2351 session.peer.send(
2352 dev_server_connection_id,
2353 proto::DevServerInstructions { projects },
2354 )?;
2355
2356 send_dev_server_projects_update(session.user_id(), update, &session).await;
2357
2358 response.send(proto::Ack {})
2359}
2360
2361async fn create_dev_server_project(
2362 request: proto::CreateDevServerProject,
2363 response: Response<proto::CreateDevServerProject>,
2364 session: UserSession,
2365) -> Result<()> {
2366 let dev_server_id = DevServerId(request.dev_server_id as i32);
2367 let dev_server_connection_id = session
2368 .connection_pool()
2369 .await
2370 .dev_server_connection_id(dev_server_id);
2371 let Some(dev_server_connection_id) = dev_server_connection_id else {
2372 Err(ErrorCode::DevServerOffline
2373 .message("Cannot create a remote project when the dev server is offline".to_string())
2374 .anyhow())?
2375 };
2376
2377 let path = request.path.clone();
2378 //Check that the path exists on the dev server
2379 session
2380 .peer
2381 .forward_request(
2382 session.connection_id,
2383 dev_server_connection_id,
2384 proto::ValidateDevServerProjectRequest { path: path.clone() },
2385 )
2386 .await?;
2387
2388 let (dev_server_project, update) = session
2389 .db()
2390 .await
2391 .create_dev_server_project(
2392 DevServerId(request.dev_server_id as i32),
2393 &request.path,
2394 session.user_id(),
2395 )
2396 .await?;
2397
2398 let projects = session
2399 .db()
2400 .await
2401 .get_projects_for_dev_server(dev_server_project.dev_server_id)
2402 .await?;
2403
2404 session.peer.send(
2405 dev_server_connection_id,
2406 proto::DevServerInstructions { projects },
2407 )?;
2408
2409 send_dev_server_projects_update(session.user_id(), update, &session).await;
2410
2411 response.send(proto::CreateDevServerProjectResponse {
2412 dev_server_project: Some(dev_server_project.to_proto(None)),
2413 })?;
2414 Ok(())
2415}
2416
2417async fn create_dev_server(
2418 request: proto::CreateDevServer,
2419 response: Response<proto::CreateDevServer>,
2420 session: UserSession,
2421) -> Result<()> {
2422 let access_token = auth::random_token();
2423 let hashed_access_token = auth::hash_access_token(&access_token);
2424
2425 if request.name.is_empty() {
2426 return Err(proto::ErrorCode::Forbidden
2427 .message("Dev server name cannot be empty".to_string())
2428 .anyhow())?;
2429 }
2430
2431 let (dev_server, status) = session
2432 .db()
2433 .await
2434 .create_dev_server(
2435 &request.name,
2436 request.ssh_connection_string.as_deref(),
2437 &hashed_access_token,
2438 session.user_id(),
2439 )
2440 .await?;
2441
2442 send_dev_server_projects_update(session.user_id(), status, &session).await;
2443
2444 response.send(proto::CreateDevServerResponse {
2445 dev_server_id: dev_server.id.0 as u64,
2446 access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2447 name: request.name,
2448 })?;
2449 Ok(())
2450}
2451
2452async fn regenerate_dev_server_token(
2453 request: proto::RegenerateDevServerToken,
2454 response: Response<proto::RegenerateDevServerToken>,
2455 session: UserSession,
2456) -> Result<()> {
2457 let dev_server_id = DevServerId(request.dev_server_id as i32);
2458 let access_token = auth::random_token();
2459 let hashed_access_token = auth::hash_access_token(&access_token);
2460
2461 let connection_id = session
2462 .connection_pool()
2463 .await
2464 .dev_server_connection_id(dev_server_id);
2465 if let Some(connection_id) = connection_id {
2466 shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2467 session.peer.send(
2468 connection_id,
2469 proto::ShutdownDevServer {
2470 reason: Some("dev server token was regenerated".to_string()),
2471 },
2472 )?;
2473 let _ = remove_dev_server_connection(dev_server_id, &session).await;
2474 }
2475
2476 let status = session
2477 .db()
2478 .await
2479 .update_dev_server_token(dev_server_id, &hashed_access_token, session.user_id())
2480 .await?;
2481
2482 send_dev_server_projects_update(session.user_id(), status, &session).await;
2483
2484 response.send(proto::RegenerateDevServerTokenResponse {
2485 dev_server_id: dev_server_id.to_proto(),
2486 access_token: auth::generate_dev_server_token(dev_server_id.0 as usize, access_token),
2487 })?;
2488 Ok(())
2489}
2490
2491async fn rename_dev_server(
2492 request: proto::RenameDevServer,
2493 response: Response<proto::RenameDevServer>,
2494 session: UserSession,
2495) -> Result<()> {
2496 if request.name.trim().is_empty() {
2497 return Err(proto::ErrorCode::Forbidden
2498 .message("Dev server name cannot be empty".to_string())
2499 .anyhow())?;
2500 }
2501
2502 let dev_server_id = DevServerId(request.dev_server_id as i32);
2503 let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2504 if dev_server.user_id != session.user_id() {
2505 return Err(anyhow!(ErrorCode::Forbidden))?;
2506 }
2507
2508 let status = session
2509 .db()
2510 .await
2511 .rename_dev_server(
2512 dev_server_id,
2513 &request.name,
2514 request.ssh_connection_string.as_deref(),
2515 session.user_id(),
2516 )
2517 .await?;
2518
2519 send_dev_server_projects_update(session.user_id(), status, &session).await;
2520
2521 response.send(proto::Ack {})?;
2522 Ok(())
2523}
2524
2525async fn delete_dev_server(
2526 request: proto::DeleteDevServer,
2527 response: Response<proto::DeleteDevServer>,
2528 session: UserSession,
2529) -> Result<()> {
2530 let dev_server_id = DevServerId(request.dev_server_id as i32);
2531 let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2532 if dev_server.user_id != session.user_id() {
2533 return Err(anyhow!(ErrorCode::Forbidden))?;
2534 }
2535
2536 let connection_id = session
2537 .connection_pool()
2538 .await
2539 .dev_server_connection_id(dev_server_id);
2540 if let Some(connection_id) = connection_id {
2541 shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2542 session.peer.send(
2543 connection_id,
2544 proto::ShutdownDevServer {
2545 reason: Some("dev server was deleted".to_string()),
2546 },
2547 )?;
2548 let _ = remove_dev_server_connection(dev_server_id, &session).await;
2549 }
2550
2551 let status = session
2552 .db()
2553 .await
2554 .delete_dev_server(dev_server_id, session.user_id())
2555 .await?;
2556
2557 send_dev_server_projects_update(session.user_id(), status, &session).await;
2558
2559 response.send(proto::Ack {})?;
2560 Ok(())
2561}
2562
2563async fn delete_dev_server_project(
2564 request: proto::DeleteDevServerProject,
2565 response: Response<proto::DeleteDevServerProject>,
2566 session: UserSession,
2567) -> Result<()> {
2568 let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2569 let dev_server_project = session
2570 .db()
2571 .await
2572 .get_dev_server_project(dev_server_project_id)
2573 .await?;
2574
2575 let dev_server = session
2576 .db()
2577 .await
2578 .get_dev_server(dev_server_project.dev_server_id)
2579 .await?;
2580 if dev_server.user_id != session.user_id() {
2581 return Err(anyhow!(ErrorCode::Forbidden))?;
2582 }
2583
2584 let dev_server_connection_id = session
2585 .connection_pool()
2586 .await
2587 .dev_server_connection_id(dev_server.id);
2588
2589 if let Some(dev_server_connection_id) = dev_server_connection_id {
2590 let project = session
2591 .db()
2592 .await
2593 .find_dev_server_project(dev_server_project_id)
2594 .await;
2595 if let Ok(project) = project {
2596 unshare_project_internal(
2597 project.id,
2598 dev_server_connection_id,
2599 Some(session.user_id()),
2600 &session,
2601 )
2602 .await?;
2603 }
2604 }
2605
2606 let (projects, status) = session
2607 .db()
2608 .await
2609 .delete_dev_server_project(dev_server_project_id, dev_server.id, session.user_id())
2610 .await?;
2611
2612 if let Some(dev_server_connection_id) = dev_server_connection_id {
2613 session.peer.send(
2614 dev_server_connection_id,
2615 proto::DevServerInstructions { projects },
2616 )?;
2617 }
2618
2619 send_dev_server_projects_update(session.user_id(), status, &session).await;
2620
2621 response.send(proto::Ack {})?;
2622 Ok(())
2623}
2624
2625async fn rejoin_dev_server_projects(
2626 request: proto::RejoinRemoteProjects,
2627 response: Response<proto::RejoinRemoteProjects>,
2628 session: UserSession,
2629) -> Result<()> {
2630 let mut rejoined_projects = {
2631 let db = session.db().await;
2632 db.rejoin_dev_server_projects(
2633 &request.rejoined_projects,
2634 session.user_id(),
2635 session.0.connection_id,
2636 )
2637 .await?
2638 };
2639 response.send(proto::RejoinRemoteProjectsResponse {
2640 rejoined_projects: rejoined_projects
2641 .iter()
2642 .map(|project| project.to_proto())
2643 .collect(),
2644 })?;
2645 notify_rejoined_projects(&mut rejoined_projects, &session)
2646}
2647
2648async fn reconnect_dev_server(
2649 request: proto::ReconnectDevServer,
2650 response: Response<proto::ReconnectDevServer>,
2651 session: DevServerSession,
2652) -> Result<()> {
2653 let reshared_projects = {
2654 let db = session.db().await;
2655 db.reshare_dev_server_projects(
2656 &request.reshared_projects,
2657 session.dev_server_id(),
2658 session.0.connection_id,
2659 )
2660 .await?
2661 };
2662
2663 for project in &reshared_projects {
2664 for collaborator in &project.collaborators {
2665 session
2666 .peer
2667 .send(
2668 collaborator.connection_id,
2669 proto::UpdateProjectCollaborator {
2670 project_id: project.id.to_proto(),
2671 old_peer_id: Some(project.old_connection_id.into()),
2672 new_peer_id: Some(session.connection_id.into()),
2673 },
2674 )
2675 .trace_err();
2676 }
2677
2678 broadcast(
2679 Some(session.connection_id),
2680 project
2681 .collaborators
2682 .iter()
2683 .map(|collaborator| collaborator.connection_id),
2684 |connection_id| {
2685 session.peer.forward_send(
2686 session.connection_id,
2687 connection_id,
2688 proto::UpdateProject {
2689 project_id: project.id.to_proto(),
2690 worktrees: project.worktrees.clone(),
2691 },
2692 )
2693 },
2694 );
2695 }
2696
2697 response.send(proto::ReconnectDevServerResponse {
2698 reshared_projects: reshared_projects
2699 .iter()
2700 .map(|project| proto::ResharedProject {
2701 id: project.id.to_proto(),
2702 collaborators: project
2703 .collaborators
2704 .iter()
2705 .map(|collaborator| collaborator.to_proto())
2706 .collect(),
2707 })
2708 .collect(),
2709 })?;
2710
2711 Ok(())
2712}
2713
2714async fn shutdown_dev_server(
2715 _: proto::ShutdownDevServer,
2716 response: Response<proto::ShutdownDevServer>,
2717 session: DevServerSession,
2718) -> Result<()> {
2719 response.send(proto::Ack {})?;
2720 shutdown_dev_server_internal(session.dev_server_id(), session.connection_id, &session).await?;
2721 remove_dev_server_connection(session.dev_server_id(), &session).await
2722}
2723
2724async fn shutdown_dev_server_internal(
2725 dev_server_id: DevServerId,
2726 connection_id: ConnectionId,
2727 session: &Session,
2728) -> Result<()> {
2729 let (dev_server_projects, dev_server) = {
2730 let db = session.db().await;
2731 let dev_server_projects = db.get_projects_for_dev_server(dev_server_id).await?;
2732 let dev_server = db.get_dev_server(dev_server_id).await?;
2733 (dev_server_projects, dev_server)
2734 };
2735
2736 for project_id in dev_server_projects.iter().filter_map(|p| p.project_id) {
2737 unshare_project_internal(
2738 ProjectId::from_proto(project_id),
2739 connection_id,
2740 None,
2741 session,
2742 )
2743 .await?;
2744 }
2745
2746 session
2747 .connection_pool()
2748 .await
2749 .set_dev_server_offline(dev_server_id);
2750
2751 let status = session
2752 .db()
2753 .await
2754 .dev_server_projects_update(dev_server.user_id)
2755 .await?;
2756 send_dev_server_projects_update(dev_server.user_id, status, &session).await;
2757
2758 Ok(())
2759}
2760
2761async fn remove_dev_server_connection(dev_server_id: DevServerId, session: &Session) -> Result<()> {
2762 let dev_server_connection = session
2763 .connection_pool()
2764 .await
2765 .dev_server_connection_id(dev_server_id);
2766
2767 if let Some(dev_server_connection) = dev_server_connection {
2768 session
2769 .connection_pool()
2770 .await
2771 .remove_connection(dev_server_connection)?;
2772 }
2773 Ok(())
2774}
2775
2776/// Updates other participants with changes to the project
2777async fn update_project(
2778 request: proto::UpdateProject,
2779 response: Response<proto::UpdateProject>,
2780 session: Session,
2781) -> Result<()> {
2782 let project_id = ProjectId::from_proto(request.project_id);
2783 let (room, guest_connection_ids) = &*session
2784 .db()
2785 .await
2786 .update_project(project_id, session.connection_id, &request.worktrees)
2787 .await?;
2788 broadcast(
2789 Some(session.connection_id),
2790 guest_connection_ids.iter().copied(),
2791 |connection_id| {
2792 session
2793 .peer
2794 .forward_send(session.connection_id, connection_id, request.clone())
2795 },
2796 );
2797 if let Some(room) = room {
2798 room_updated(&room, &session.peer);
2799 }
2800 response.send(proto::Ack {})?;
2801
2802 Ok(())
2803}
2804
2805/// Updates other participants with changes to the worktree
2806async fn update_worktree(
2807 request: proto::UpdateWorktree,
2808 response: Response<proto::UpdateWorktree>,
2809 session: Session,
2810) -> Result<()> {
2811 let guest_connection_ids = session
2812 .db()
2813 .await
2814 .update_worktree(&request, session.connection_id)
2815 .await?;
2816
2817 broadcast(
2818 Some(session.connection_id),
2819 guest_connection_ids.iter().copied(),
2820 |connection_id| {
2821 session
2822 .peer
2823 .forward_send(session.connection_id, connection_id, request.clone())
2824 },
2825 );
2826 response.send(proto::Ack {})?;
2827 Ok(())
2828}
2829
2830/// Updates other participants with changes to the diagnostics
2831async fn update_diagnostic_summary(
2832 message: proto::UpdateDiagnosticSummary,
2833 session: Session,
2834) -> Result<()> {
2835 let guest_connection_ids = session
2836 .db()
2837 .await
2838 .update_diagnostic_summary(&message, session.connection_id)
2839 .await?;
2840
2841 broadcast(
2842 Some(session.connection_id),
2843 guest_connection_ids.iter().copied(),
2844 |connection_id| {
2845 session
2846 .peer
2847 .forward_send(session.connection_id, connection_id, message.clone())
2848 },
2849 );
2850
2851 Ok(())
2852}
2853
2854/// Updates other participants with changes to the worktree settings
2855async fn update_worktree_settings(
2856 message: proto::UpdateWorktreeSettings,
2857 session: Session,
2858) -> Result<()> {
2859 let guest_connection_ids = session
2860 .db()
2861 .await
2862 .update_worktree_settings(&message, session.connection_id)
2863 .await?;
2864
2865 broadcast(
2866 Some(session.connection_id),
2867 guest_connection_ids.iter().copied(),
2868 |connection_id| {
2869 session
2870 .peer
2871 .forward_send(session.connection_id, connection_id, message.clone())
2872 },
2873 );
2874
2875 Ok(())
2876}
2877
2878/// Notify other participants that a language server has started.
2879async fn start_language_server(
2880 request: proto::StartLanguageServer,
2881 session: Session,
2882) -> Result<()> {
2883 let guest_connection_ids = session
2884 .db()
2885 .await
2886 .start_language_server(&request, session.connection_id)
2887 .await?;
2888
2889 broadcast(
2890 Some(session.connection_id),
2891 guest_connection_ids.iter().copied(),
2892 |connection_id| {
2893 session
2894 .peer
2895 .forward_send(session.connection_id, connection_id, request.clone())
2896 },
2897 );
2898 Ok(())
2899}
2900
2901/// Notify other participants that a language server has changed.
2902async fn update_language_server(
2903 request: proto::UpdateLanguageServer,
2904 session: Session,
2905) -> Result<()> {
2906 let project_id = ProjectId::from_proto(request.project_id);
2907 let project_connection_ids = session
2908 .db()
2909 .await
2910 .project_connection_ids(project_id, session.connection_id, true)
2911 .await?;
2912 broadcast(
2913 Some(session.connection_id),
2914 project_connection_ids.iter().copied(),
2915 |connection_id| {
2916 session
2917 .peer
2918 .forward_send(session.connection_id, connection_id, request.clone())
2919 },
2920 );
2921 Ok(())
2922}
2923
2924/// forward a project request to the host. These requests should be read only
2925/// as guests are allowed to send them.
2926async fn forward_read_only_project_request<T>(
2927 request: T,
2928 response: Response<T>,
2929 session: UserSession,
2930) -> Result<()>
2931where
2932 T: EntityMessage + RequestMessage,
2933{
2934 let project_id = ProjectId::from_proto(request.remote_entity_id());
2935 let host_connection_id = session
2936 .db()
2937 .await
2938 .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2939 .await?;
2940 let payload = session
2941 .peer
2942 .forward_request(session.connection_id, host_connection_id, request)
2943 .await?;
2944 response.send(payload)?;
2945 Ok(())
2946}
2947
2948async fn forward_find_search_candidates_request(
2949 request: proto::FindSearchCandidates,
2950 response: Response<proto::FindSearchCandidates>,
2951 session: UserSession,
2952) -> Result<()> {
2953 let project_id = ProjectId::from_proto(request.remote_entity_id());
2954 let host_connection_id = session
2955 .db()
2956 .await
2957 .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2958 .await?;
2959
2960 let host_version = session
2961 .connection_pool()
2962 .await
2963 .connection(host_connection_id)
2964 .map(|c| c.zed_version);
2965
2966 if host_version.is_some_and(|host_version| host_version < ZedVersion::with_search_candidates())
2967 {
2968 let query = request.query.ok_or_else(|| anyhow!("missing query"))?;
2969 let search = proto::SearchProject {
2970 project_id: project_id.to_proto(),
2971 query: query.query,
2972 regex: query.regex,
2973 whole_word: query.whole_word,
2974 case_sensitive: query.case_sensitive,
2975 files_to_include: query.files_to_include,
2976 files_to_exclude: query.files_to_exclude,
2977 include_ignored: query.include_ignored,
2978 };
2979
2980 let payload = session
2981 .peer
2982 .forward_request(session.connection_id, host_connection_id, search)
2983 .await?;
2984 return response.send(proto::FindSearchCandidatesResponse {
2985 buffer_ids: payload
2986 .locations
2987 .into_iter()
2988 .map(|loc| loc.buffer_id)
2989 .collect(),
2990 });
2991 }
2992
2993 let payload = session
2994 .peer
2995 .forward_request(session.connection_id, host_connection_id, request)
2996 .await?;
2997 response.send(payload)?;
2998 Ok(())
2999}
3000
3001/// forward a project request to the dev server. Only allowed
3002/// if it's your dev server.
3003async fn forward_project_request_for_owner<T>(
3004 request: T,
3005 response: Response<T>,
3006 session: UserSession,
3007) -> Result<()>
3008where
3009 T: EntityMessage + RequestMessage,
3010{
3011 let project_id = ProjectId::from_proto(request.remote_entity_id());
3012
3013 let host_connection_id = session
3014 .db()
3015 .await
3016 .host_for_owner_project_request(project_id, session.connection_id, session.user_id())
3017 .await?;
3018 let payload = session
3019 .peer
3020 .forward_request(session.connection_id, host_connection_id, request)
3021 .await?;
3022 response.send(payload)?;
3023 Ok(())
3024}
3025
3026/// forward a project request to the host. These requests are disallowed
3027/// for guests.
3028async fn forward_mutating_project_request<T>(
3029 request: T,
3030 response: Response<T>,
3031 session: UserSession,
3032) -> Result<()>
3033where
3034 T: EntityMessage + RequestMessage,
3035{
3036 let project_id = ProjectId::from_proto(request.remote_entity_id());
3037
3038 let host_connection_id = session
3039 .db()
3040 .await
3041 .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
3042 .await?;
3043 let payload = session
3044 .peer
3045 .forward_request(session.connection_id, host_connection_id, request)
3046 .await?;
3047 response.send(payload)?;
3048 Ok(())
3049}
3050
3051/// Notify other participants that a new buffer has been created
3052async fn create_buffer_for_peer(
3053 request: proto::CreateBufferForPeer,
3054 session: Session,
3055) -> Result<()> {
3056 session
3057 .db()
3058 .await
3059 .check_user_is_project_host(
3060 ProjectId::from_proto(request.project_id),
3061 session.connection_id,
3062 )
3063 .await?;
3064 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
3065 session
3066 .peer
3067 .forward_send(session.connection_id, peer_id.into(), request)?;
3068 Ok(())
3069}
3070
3071/// Notify other participants that a buffer has been updated. This is
3072/// allowed for guests as long as the update is limited to selections.
3073async fn update_buffer(
3074 request: proto::UpdateBuffer,
3075 response: Response<proto::UpdateBuffer>,
3076 session: Session,
3077) -> Result<()> {
3078 let project_id = ProjectId::from_proto(request.project_id);
3079 let mut capability = Capability::ReadOnly;
3080
3081 for op in request.operations.iter() {
3082 match op.variant {
3083 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
3084 Some(_) => capability = Capability::ReadWrite,
3085 }
3086 }
3087
3088 let host = {
3089 let guard = session
3090 .db()
3091 .await
3092 .connections_for_buffer_update(
3093 project_id,
3094 session.principal_id(),
3095 session.connection_id,
3096 capability,
3097 )
3098 .await?;
3099
3100 let (host, guests) = &*guard;
3101
3102 broadcast(
3103 Some(session.connection_id),
3104 guests.clone(),
3105 |connection_id| {
3106 session
3107 .peer
3108 .forward_send(session.connection_id, connection_id, request.clone())
3109 },
3110 );
3111
3112 *host
3113 };
3114
3115 if host != session.connection_id {
3116 session
3117 .peer
3118 .forward_request(session.connection_id, host, request.clone())
3119 .await?;
3120 }
3121
3122 response.send(proto::Ack {})?;
3123 Ok(())
3124}
3125
3126async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
3127 let project_id = ProjectId::from_proto(message.project_id);
3128
3129 let operation = message.operation.as_ref().context("invalid operation")?;
3130 let capability = match operation.variant.as_ref() {
3131 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
3132 if let Some(buffer_op) = buffer_op.operation.as_ref() {
3133 match buffer_op.variant {
3134 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
3135 Capability::ReadOnly
3136 }
3137 _ => Capability::ReadWrite,
3138 }
3139 } else {
3140 Capability::ReadWrite
3141 }
3142 }
3143 Some(_) => Capability::ReadWrite,
3144 None => Capability::ReadOnly,
3145 };
3146
3147 let guard = session
3148 .db()
3149 .await
3150 .connections_for_buffer_update(
3151 project_id,
3152 session.principal_id(),
3153 session.connection_id,
3154 capability,
3155 )
3156 .await?;
3157
3158 let (host, guests) = &*guard;
3159
3160 broadcast(
3161 Some(session.connection_id),
3162 guests.iter().chain([host]).copied(),
3163 |connection_id| {
3164 session
3165 .peer
3166 .forward_send(session.connection_id, connection_id, message.clone())
3167 },
3168 );
3169
3170 Ok(())
3171}
3172
3173/// Notify other participants that a project has been updated.
3174async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
3175 request: T,
3176 session: Session,
3177) -> Result<()> {
3178 let project_id = ProjectId::from_proto(request.remote_entity_id());
3179 let project_connection_ids = session
3180 .db()
3181 .await
3182 .project_connection_ids(project_id, session.connection_id, false)
3183 .await?;
3184
3185 broadcast(
3186 Some(session.connection_id),
3187 project_connection_ids.iter().copied(),
3188 |connection_id| {
3189 session
3190 .peer
3191 .forward_send(session.connection_id, connection_id, request.clone())
3192 },
3193 );
3194 Ok(())
3195}
3196
3197/// Start following another user in a call.
3198async fn follow(
3199 request: proto::Follow,
3200 response: Response<proto::Follow>,
3201 session: UserSession,
3202) -> Result<()> {
3203 let room_id = RoomId::from_proto(request.room_id);
3204 let project_id = request.project_id.map(ProjectId::from_proto);
3205 let leader_id = request
3206 .leader_id
3207 .ok_or_else(|| anyhow!("invalid leader id"))?
3208 .into();
3209 let follower_id = session.connection_id;
3210
3211 session
3212 .db()
3213 .await
3214 .check_room_participants(room_id, leader_id, session.connection_id)
3215 .await?;
3216
3217 let response_payload = session
3218 .peer
3219 .forward_request(session.connection_id, leader_id, request)
3220 .await?;
3221 response.send(response_payload)?;
3222
3223 if let Some(project_id) = project_id {
3224 let room = session
3225 .db()
3226 .await
3227 .follow(room_id, project_id, leader_id, follower_id)
3228 .await?;
3229 room_updated(&room, &session.peer);
3230 }
3231
3232 Ok(())
3233}
3234
3235/// Stop following another user in a call.
3236async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
3237 let room_id = RoomId::from_proto(request.room_id);
3238 let project_id = request.project_id.map(ProjectId::from_proto);
3239 let leader_id = request
3240 .leader_id
3241 .ok_or_else(|| anyhow!("invalid leader id"))?
3242 .into();
3243 let follower_id = session.connection_id;
3244
3245 session
3246 .db()
3247 .await
3248 .check_room_participants(room_id, leader_id, session.connection_id)
3249 .await?;
3250
3251 session
3252 .peer
3253 .forward_send(session.connection_id, leader_id, request)?;
3254
3255 if let Some(project_id) = project_id {
3256 let room = session
3257 .db()
3258 .await
3259 .unfollow(room_id, project_id, leader_id, follower_id)
3260 .await?;
3261 room_updated(&room, &session.peer);
3262 }
3263
3264 Ok(())
3265}
3266
3267/// Notify everyone following you of your current location.
3268async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
3269 let room_id = RoomId::from_proto(request.room_id);
3270 let database = session.db.lock().await;
3271
3272 let connection_ids = if let Some(project_id) = request.project_id {
3273 let project_id = ProjectId::from_proto(project_id);
3274 database
3275 .project_connection_ids(project_id, session.connection_id, true)
3276 .await?
3277 } else {
3278 database
3279 .room_connection_ids(room_id, session.connection_id)
3280 .await?
3281 };
3282
3283 // For now, don't send view update messages back to that view's current leader.
3284 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
3285 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
3286 _ => None,
3287 });
3288
3289 for connection_id in connection_ids.iter().cloned() {
3290 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
3291 session
3292 .peer
3293 .forward_send(session.connection_id, connection_id, request.clone())?;
3294 }
3295 }
3296 Ok(())
3297}
3298
3299/// Get public data about users.
3300async fn get_users(
3301 request: proto::GetUsers,
3302 response: Response<proto::GetUsers>,
3303 session: Session,
3304) -> Result<()> {
3305 let user_ids = request
3306 .user_ids
3307 .into_iter()
3308 .map(UserId::from_proto)
3309 .collect();
3310 let users = session
3311 .db()
3312 .await
3313 .get_users_by_ids(user_ids)
3314 .await?
3315 .into_iter()
3316 .map(|user| proto::User {
3317 id: user.id.to_proto(),
3318 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3319 github_login: user.github_login,
3320 })
3321 .collect();
3322 response.send(proto::UsersResponse { users })?;
3323 Ok(())
3324}
3325
3326/// Search for users (to invite) buy Github login
3327async fn fuzzy_search_users(
3328 request: proto::FuzzySearchUsers,
3329 response: Response<proto::FuzzySearchUsers>,
3330 session: UserSession,
3331) -> Result<()> {
3332 let query = request.query;
3333 let users = match query.len() {
3334 0 => vec![],
3335 1 | 2 => session
3336 .db()
3337 .await
3338 .get_user_by_github_login(&query)
3339 .await?
3340 .into_iter()
3341 .collect(),
3342 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
3343 };
3344 let users = users
3345 .into_iter()
3346 .filter(|user| user.id != session.user_id())
3347 .map(|user| proto::User {
3348 id: user.id.to_proto(),
3349 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3350 github_login: user.github_login,
3351 })
3352 .collect();
3353 response.send(proto::UsersResponse { users })?;
3354 Ok(())
3355}
3356
3357/// Send a contact request to another user.
3358async fn request_contact(
3359 request: proto::RequestContact,
3360 response: Response<proto::RequestContact>,
3361 session: UserSession,
3362) -> Result<()> {
3363 let requester_id = session.user_id();
3364 let responder_id = UserId::from_proto(request.responder_id);
3365 if requester_id == responder_id {
3366 return Err(anyhow!("cannot add yourself as a contact"))?;
3367 }
3368
3369 let notifications = session
3370 .db()
3371 .await
3372 .send_contact_request(requester_id, responder_id)
3373 .await?;
3374
3375 // Update outgoing contact requests of requester
3376 let mut update = proto::UpdateContacts::default();
3377 update.outgoing_requests.push(responder_id.to_proto());
3378 for connection_id in session
3379 .connection_pool()
3380 .await
3381 .user_connection_ids(requester_id)
3382 {
3383 session.peer.send(connection_id, update.clone())?;
3384 }
3385
3386 // Update incoming contact requests of responder
3387 let mut update = proto::UpdateContacts::default();
3388 update
3389 .incoming_requests
3390 .push(proto::IncomingContactRequest {
3391 requester_id: requester_id.to_proto(),
3392 });
3393 let connection_pool = session.connection_pool().await;
3394 for connection_id in connection_pool.user_connection_ids(responder_id) {
3395 session.peer.send(connection_id, update.clone())?;
3396 }
3397
3398 send_notifications(&connection_pool, &session.peer, notifications);
3399
3400 response.send(proto::Ack {})?;
3401 Ok(())
3402}
3403
/// Accept or decline a contact request
///
/// `request.requester_id` is the user who sent the contact request; the session's user is the
/// responder.
///
/// A `Dismiss` response only dismisses the contact notification. Any other response records the
/// answer: `Accept` accepts, and everything else is treated as a decline. Both users'
/// connections are then updated:
/// - on acceptance, each side gains the other as a contact, including that user's busy status;
/// - in every case, the pending incoming/outgoing request is removed.
///
/// Notifications returned by the database are delivered before the request is acknowledged.
async fn respond_to_contact_request(
    request: proto::RespondToContactRequest,
    response: Response<proto::RespondToContactRequest>,
    session: UserSession,
) -> Result<()> {
    let responder_id = session.user_id();
    let requester_id = UserId::from_proto(request.requester_id);
    // The database guard stays alive for the rest of the handler, including while the
    // connection pool is locked below.
    let db = session.db().await;
    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
        db.dismiss_contact_notification(responder_id, requester_id)
            .await?;
    } else {
        // `Dismiss` was handled above, so anything other than `Accept` is a decline.
        let accept = request.response == proto::ContactRequestResponse::Accept as i32;

        let notifications = db
            .respond_to_contact_request(responder_id, requester_id, accept)
            .await?;
        // Busy status is embedded in the contact entries sent on acceptance.
        let requester_busy = db.is_user_busy(requester_id).await?;
        let responder_busy = db.is_user_busy(responder_id).await?;

        let pool = session.connection_pool().await;
        // Update responder with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(requester_id, requester_busy, &pool));
        }
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
        for connection_id in pool.user_connection_ids(responder_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        // Update requester with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(responder_id, responder_busy, &pool));
        }
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());

        for connection_id in pool.user_connection_ids(requester_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        send_notifications(&pool, &session.peer, notifications);
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3461
3462/// Remove a contact.
3463async fn remove_contact(
3464 request: proto::RemoveContact,
3465 response: Response<proto::RemoveContact>,
3466 session: UserSession,
3467) -> Result<()> {
3468 let requester_id = session.user_id();
3469 let responder_id = UserId::from_proto(request.user_id);
3470 let db = session.db().await;
3471 let (contact_accepted, deleted_notification_id) =
3472 db.remove_contact(requester_id, responder_id).await?;
3473
3474 let pool = session.connection_pool().await;
3475 // Update outgoing contact requests of requester
3476 let mut update = proto::UpdateContacts::default();
3477 if contact_accepted {
3478 update.remove_contacts.push(responder_id.to_proto());
3479 } else {
3480 update
3481 .remove_outgoing_requests
3482 .push(responder_id.to_proto());
3483 }
3484 for connection_id in pool.user_connection_ids(requester_id) {
3485 session.peer.send(connection_id, update.clone())?;
3486 }
3487
3488 // Update incoming contact requests of responder
3489 let mut update = proto::UpdateContacts::default();
3490 if contact_accepted {
3491 update.remove_contacts.push(requester_id.to_proto());
3492 } else {
3493 update
3494 .remove_incoming_requests
3495 .push(requester_id.to_proto());
3496 }
3497 for connection_id in pool.user_connection_ids(responder_id) {
3498 session.peer.send(connection_id, update.clone())?;
3499 if let Some(notification_id) = deleted_notification_id {
3500 session.peer.send(
3501 connection_id,
3502 proto::DeleteNotification {
3503 notification_id: notification_id.to_proto(),
3504 },
3505 )?;
3506 }
3507 }
3508
3509 response.send(proto::Ack {})?;
3510 Ok(())
3511}
3512
3513fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
3514 version.0.minor() < 139
3515}
3516
3517async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
3518 let plan = session.current_plan(session.db().await).await?;
3519
3520 session
3521 .peer
3522 .send(
3523 session.connection_id,
3524 proto::UpdateUserPlan { plan: plan.into() },
3525 )
3526 .trace_err();
3527
3528 Ok(())
3529}
3530
3531async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
3532 subscribe_user_to_channels(
3533 session.user_id().ok_or_else(|| anyhow!("must be a user"))?,
3534 &session,
3535 )
3536 .await?;
3537 Ok(())
3538}
3539
3540async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
3541 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3542 let mut pool = session.connection_pool().await;
3543 for membership in &channels_for_user.channel_memberships {
3544 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3545 }
3546 session.peer.send(
3547 session.connection_id,
3548 build_update_user_channels(&channels_for_user),
3549 )?;
3550 session.peer.send(
3551 session.connection_id,
3552 build_channels_update(channels_for_user),
3553 )?;
3554 Ok(())
3555}
3556
3557/// Creates a new channel.
3558async fn create_channel(
3559 request: proto::CreateChannel,
3560 response: Response<proto::CreateChannel>,
3561 session: UserSession,
3562) -> Result<()> {
3563 let db = session.db().await;
3564
3565 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
3566 let (channel, membership) = db
3567 .create_channel(&request.name, parent_id, session.user_id())
3568 .await?;
3569
3570 let root_id = channel.root_id();
3571 let channel = Channel::from_model(channel);
3572
3573 response.send(proto::CreateChannelResponse {
3574 channel: Some(channel.to_proto()),
3575 parent_id: request.parent_id,
3576 })?;
3577
3578 let mut connection_pool = session.connection_pool().await;
3579 if let Some(membership) = membership {
3580 connection_pool.subscribe_to_channel(
3581 membership.user_id,
3582 membership.channel_id,
3583 membership.role,
3584 );
3585 let update = proto::UpdateUserChannels {
3586 channel_memberships: vec![proto::ChannelMembership {
3587 channel_id: membership.channel_id.to_proto(),
3588 role: membership.role.into(),
3589 }],
3590 ..Default::default()
3591 };
3592 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3593 session.peer.send(connection_id, update.clone())?;
3594 }
3595 }
3596
3597 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3598 if !role.can_see_channel(channel.visibility) {
3599 continue;
3600 }
3601
3602 let update = proto::UpdateChannels {
3603 channels: vec![channel.to_proto()],
3604 ..Default::default()
3605 };
3606 session.peer.send(connection_id, update.clone())?;
3607 }
3608
3609 Ok(())
3610}
3611
3612/// Delete a channel
3613async fn delete_channel(
3614 request: proto::DeleteChannel,
3615 response: Response<proto::DeleteChannel>,
3616 session: UserSession,
3617) -> Result<()> {
3618 let db = session.db().await;
3619
3620 let channel_id = request.channel_id;
3621 let (root_channel, removed_channels) = db
3622 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3623 .await?;
3624 response.send(proto::Ack {})?;
3625
3626 // Notify members of removed channels
3627 let mut update = proto::UpdateChannels::default();
3628 update
3629 .delete_channels
3630 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3631
3632 let connection_pool = session.connection_pool().await;
3633 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3634 session.peer.send(connection_id, update.clone())?;
3635 }
3636
3637 Ok(())
3638}
3639
3640/// Invite someone to join a channel.
3641async fn invite_channel_member(
3642 request: proto::InviteChannelMember,
3643 response: Response<proto::InviteChannelMember>,
3644 session: UserSession,
3645) -> Result<()> {
3646 let db = session.db().await;
3647 let channel_id = ChannelId::from_proto(request.channel_id);
3648 let invitee_id = UserId::from_proto(request.user_id);
3649 let InviteMemberResult {
3650 channel,
3651 notifications,
3652 } = db
3653 .invite_channel_member(
3654 channel_id,
3655 invitee_id,
3656 session.user_id(),
3657 request.role().into(),
3658 )
3659 .await?;
3660
3661 let update = proto::UpdateChannels {
3662 channel_invitations: vec![channel.to_proto()],
3663 ..Default::default()
3664 };
3665
3666 let connection_pool = session.connection_pool().await;
3667 for connection_id in connection_pool.user_connection_ids(invitee_id) {
3668 session.peer.send(connection_id, update.clone())?;
3669 }
3670
3671 send_notifications(&connection_pool, &session.peer, notifications);
3672
3673 response.send(proto::Ack {})?;
3674 Ok(())
3675}
3676
3677/// remove someone from a channel
3678async fn remove_channel_member(
3679 request: proto::RemoveChannelMember,
3680 response: Response<proto::RemoveChannelMember>,
3681 session: UserSession,
3682) -> Result<()> {
3683 let db = session.db().await;
3684 let channel_id = ChannelId::from_proto(request.channel_id);
3685 let member_id = UserId::from_proto(request.user_id);
3686
3687 let RemoveChannelMemberResult {
3688 membership_update,
3689 notification_id,
3690 } = db
3691 .remove_channel_member(channel_id, member_id, session.user_id())
3692 .await?;
3693
3694 let mut connection_pool = session.connection_pool().await;
3695 notify_membership_updated(
3696 &mut connection_pool,
3697 membership_update,
3698 member_id,
3699 &session.peer,
3700 );
3701 for connection_id in connection_pool.user_connection_ids(member_id) {
3702 if let Some(notification_id) = notification_id {
3703 session
3704 .peer
3705 .send(
3706 connection_id,
3707 proto::DeleteNotification {
3708 notification_id: notification_id.to_proto(),
3709 },
3710 )
3711 .trace_err();
3712 }
3713 }
3714
3715 response.send(proto::Ack {})?;
3716 Ok(())
3717}
3718
3719/// Toggle the channel between public and private.
3720/// Care is taken to maintain the invariant that public channels only descend from public channels,
3721/// (though members-only channels can appear at any point in the hierarchy).
3722async fn set_channel_visibility(
3723 request: proto::SetChannelVisibility,
3724 response: Response<proto::SetChannelVisibility>,
3725 session: UserSession,
3726) -> Result<()> {
3727 let db = session.db().await;
3728 let channel_id = ChannelId::from_proto(request.channel_id);
3729 let visibility = request.visibility().into();
3730
3731 let channel_model = db
3732 .set_channel_visibility(channel_id, visibility, session.user_id())
3733 .await?;
3734 let root_id = channel_model.root_id();
3735 let channel = Channel::from_model(channel_model);
3736
3737 let mut connection_pool = session.connection_pool().await;
3738 for (user_id, role) in connection_pool
3739 .channel_user_ids(root_id)
3740 .collect::<Vec<_>>()
3741 .into_iter()
3742 {
3743 let update = if role.can_see_channel(channel.visibility) {
3744 connection_pool.subscribe_to_channel(user_id, channel_id, role);
3745 proto::UpdateChannels {
3746 channels: vec![channel.to_proto()],
3747 ..Default::default()
3748 }
3749 } else {
3750 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3751 proto::UpdateChannels {
3752 delete_channels: vec![channel.id.to_proto()],
3753 ..Default::default()
3754 }
3755 };
3756
3757 for connection_id in connection_pool.user_connection_ids(user_id) {
3758 session.peer.send(connection_id, update.clone())?;
3759 }
3760 }
3761
3762 response.send(proto::Ack {})?;
3763 Ok(())
3764}
3765
3766/// Alter the role for a user in the channel.
3767async fn set_channel_member_role(
3768 request: proto::SetChannelMemberRole,
3769 response: Response<proto::SetChannelMemberRole>,
3770 session: UserSession,
3771) -> Result<()> {
3772 let db = session.db().await;
3773 let channel_id = ChannelId::from_proto(request.channel_id);
3774 let member_id = UserId::from_proto(request.user_id);
3775 let result = db
3776 .set_channel_member_role(
3777 channel_id,
3778 session.user_id(),
3779 member_id,
3780 request.role().into(),
3781 )
3782 .await?;
3783
3784 match result {
3785 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3786 let mut connection_pool = session.connection_pool().await;
3787 notify_membership_updated(
3788 &mut connection_pool,
3789 membership_update,
3790 member_id,
3791 &session.peer,
3792 )
3793 }
3794 db::SetMemberRoleResult::InviteUpdated(channel) => {
3795 let update = proto::UpdateChannels {
3796 channel_invitations: vec![channel.to_proto()],
3797 ..Default::default()
3798 };
3799
3800 for connection_id in session
3801 .connection_pool()
3802 .await
3803 .user_connection_ids(member_id)
3804 {
3805 session.peer.send(connection_id, update.clone())?;
3806 }
3807 }
3808 }
3809
3810 response.send(proto::Ack {})?;
3811 Ok(())
3812}
3813
3814/// Change the name of a channel
3815async fn rename_channel(
3816 request: proto::RenameChannel,
3817 response: Response<proto::RenameChannel>,
3818 session: UserSession,
3819) -> Result<()> {
3820 let db = session.db().await;
3821 let channel_id = ChannelId::from_proto(request.channel_id);
3822 let channel_model = db
3823 .rename_channel(channel_id, session.user_id(), &request.name)
3824 .await?;
3825 let root_id = channel_model.root_id();
3826 let channel = Channel::from_model(channel_model);
3827
3828 response.send(proto::RenameChannelResponse {
3829 channel: Some(channel.to_proto()),
3830 })?;
3831
3832 let connection_pool = session.connection_pool().await;
3833 let update = proto::UpdateChannels {
3834 channels: vec![channel.to_proto()],
3835 ..Default::default()
3836 };
3837 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3838 if role.can_see_channel(channel.visibility) {
3839 session.peer.send(connection_id, update.clone())?;
3840 }
3841 }
3842
3843 Ok(())
3844}
3845
3846/// Move a channel to a new parent.
3847async fn move_channel(
3848 request: proto::MoveChannel,
3849 response: Response<proto::MoveChannel>,
3850 session: UserSession,
3851) -> Result<()> {
3852 let channel_id = ChannelId::from_proto(request.channel_id);
3853 let to = ChannelId::from_proto(request.to);
3854
3855 let (root_id, channels) = session
3856 .db()
3857 .await
3858 .move_channel(channel_id, to, session.user_id())
3859 .await?;
3860
3861 let connection_pool = session.connection_pool().await;
3862 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3863 let channels = channels
3864 .iter()
3865 .filter_map(|channel| {
3866 if role.can_see_channel(channel.visibility) {
3867 Some(channel.to_proto())
3868 } else {
3869 None
3870 }
3871 })
3872 .collect::<Vec<_>>();
3873 if channels.is_empty() {
3874 continue;
3875 }
3876
3877 let update = proto::UpdateChannels {
3878 channels,
3879 ..Default::default()
3880 };
3881
3882 session.peer.send(connection_id, update.clone())?;
3883 }
3884
3885 response.send(Ack {})?;
3886 Ok(())
3887}
3888
3889/// Get the list of channel members
3890async fn get_channel_members(
3891 request: proto::GetChannelMembers,
3892 response: Response<proto::GetChannelMembers>,
3893 session: UserSession,
3894) -> Result<()> {
3895 let db = session.db().await;
3896 let channel_id = ChannelId::from_proto(request.channel_id);
3897 let limit = if request.limit == 0 {
3898 u16::MAX as u64
3899 } else {
3900 request.limit
3901 };
3902 let (members, users) = db
3903 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3904 .await?;
3905 response.send(proto::GetChannelMembersResponse { members, users })?;
3906 Ok(())
3907}
3908
3909/// Accept or decline a channel invitation.
3910async fn respond_to_channel_invite(
3911 request: proto::RespondToChannelInvite,
3912 response: Response<proto::RespondToChannelInvite>,
3913 session: UserSession,
3914) -> Result<()> {
3915 let db = session.db().await;
3916 let channel_id = ChannelId::from_proto(request.channel_id);
3917 let RespondToChannelInvite {
3918 membership_update,
3919 notifications,
3920 } = db
3921 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3922 .await?;
3923
3924 let mut connection_pool = session.connection_pool().await;
3925 if let Some(membership_update) = membership_update {
3926 notify_membership_updated(
3927 &mut connection_pool,
3928 membership_update,
3929 session.user_id(),
3930 &session.peer,
3931 );
3932 } else {
3933 let update = proto::UpdateChannels {
3934 remove_channel_invitations: vec![channel_id.to_proto()],
3935 ..Default::default()
3936 };
3937
3938 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3939 session.peer.send(connection_id, update.clone())?;
3940 }
3941 };
3942
3943 send_notifications(&connection_pool, &session.peer, notifications);
3944
3945 response.send(proto::Ack {})?;
3946
3947 Ok(())
3948}
3949
3950/// Join the channels' room
3951async fn join_channel(
3952 request: proto::JoinChannel,
3953 response: Response<proto::JoinChannel>,
3954 session: UserSession,
3955) -> Result<()> {
3956 let channel_id = ChannelId::from_proto(request.channel_id);
3957 join_channel_internal(channel_id, Box::new(response), session).await
3958}
3959
/// Abstracts over the response handles that can answer a channel join with a
/// `proto::JoinRoomResponse`, so `join_channel_internal` can serve more than one request type.
trait JoinChannelInternalResponse {
    /// Sends `result` as the reply to the original request.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// The fully-qualified `Response::<T>::send` path picks the inherent `send` method, which takes
// priority over this trait's method of the same name, so the impls do not recurse.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3973
/// Joins the session's user to the room of `channel_id` and replies through `response`.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, that connection leaves its
///    room first. The database guard is dropped while doing so and re-acquired afterwards.
/// 2. The user joins the channel's room. The database returns the joined room, an optional
///    membership update, and the user's channel role.
/// 3. If a LiveKit client is configured, a token is minted:
///    - guests get a guest token with `can_publish: false`;
///    - everyone else gets a room token with `can_publish: true`.
///
///    Token errors are passed through `trace_err` and yield no connection info instead of
///    failing the join.
/// 4. While the database guard is still alive (it lives until the end of the inner block), the
///    response is sent, any membership update is propagated, and `room_updated` is called.
/// 5. After the inner block ends, `channel_updated` is called and the user's contacts are
///    refreshed.
///
/// Fails with "channel not returned" if the joined room has no channel. At that point the
/// response has already been sent.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: UserSession,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database guard before leaving the stale room; presumably
            // `leave_room_for_session` locks the database itself (verify).
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured, or when minting the token failed
        // (the error is traced and `?` short-circuits the closure).
        let live_kit_connection_info =
            session
                .app_state
                .live_kit_client
                .as_ref()
                .and_then(|live_kit| {
                    // Guests may join but not publish.
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.live_kit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.live_kit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        // The pool guard also lives until the end of this block; it must be released before
        // the pool is locked again for `channel_updated` below.
        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
4069
4070/// Start editing the channel notes
4071async fn join_channel_buffer(
4072 request: proto::JoinChannelBuffer,
4073 response: Response<proto::JoinChannelBuffer>,
4074 session: UserSession,
4075) -> Result<()> {
4076 let db = session.db().await;
4077 let channel_id = ChannelId::from_proto(request.channel_id);
4078
4079 let open_response = db
4080 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
4081 .await?;
4082
4083 let collaborators = open_response.collaborators.clone();
4084 response.send(open_response)?;
4085
4086 let update = UpdateChannelBufferCollaborators {
4087 channel_id: channel_id.to_proto(),
4088 collaborators: collaborators.clone(),
4089 };
4090 channel_buffer_updated(
4091 session.connection_id,
4092 collaborators
4093 .iter()
4094 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
4095 &update,
4096 &session.peer,
4097 );
4098
4099 Ok(())
4100}
4101
4102/// Edit the channel notes
4103async fn update_channel_buffer(
4104 request: proto::UpdateChannelBuffer,
4105 session: UserSession,
4106) -> Result<()> {
4107 let db = session.db().await;
4108 let channel_id = ChannelId::from_proto(request.channel_id);
4109
4110 let (collaborators, epoch, version) = db
4111 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
4112 .await?;
4113
4114 channel_buffer_updated(
4115 session.connection_id,
4116 collaborators.clone(),
4117 &proto::UpdateChannelBuffer {
4118 channel_id: channel_id.to_proto(),
4119 operations: request.operations,
4120 },
4121 &session.peer,
4122 );
4123
4124 let pool = &*session.connection_pool().await;
4125
4126 let non_collaborators =
4127 pool.channel_connection_ids(channel_id)
4128 .filter_map(|(connection_id, _)| {
4129 if collaborators.contains(&connection_id) {
4130 None
4131 } else {
4132 Some(connection_id)
4133 }
4134 });
4135
4136 broadcast(None, non_collaborators, |peer_id| {
4137 session.peer.send(
4138 peer_id,
4139 proto::UpdateChannels {
4140 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
4141 channel_id: channel_id.to_proto(),
4142 epoch: epoch as u64,
4143 version: version.clone(),
4144 }],
4145 ..Default::default()
4146 },
4147 )
4148 });
4149
4150 Ok(())
4151}
4152
4153/// Rejoin the channel notes after a connection blip
4154async fn rejoin_channel_buffers(
4155 request: proto::RejoinChannelBuffers,
4156 response: Response<proto::RejoinChannelBuffers>,
4157 session: UserSession,
4158) -> Result<()> {
4159 let db = session.db().await;
4160 let buffers = db
4161 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
4162 .await?;
4163
4164 for rejoined_buffer in &buffers {
4165 let collaborators_to_notify = rejoined_buffer
4166 .buffer
4167 .collaborators
4168 .iter()
4169 .filter_map(|c| Some(c.peer_id?.into()));
4170 channel_buffer_updated(
4171 session.connection_id,
4172 collaborators_to_notify,
4173 &proto::UpdateChannelBufferCollaborators {
4174 channel_id: rejoined_buffer.buffer.channel_id,
4175 collaborators: rejoined_buffer.buffer.collaborators.clone(),
4176 },
4177 &session.peer,
4178 );
4179 }
4180
4181 response.send(proto::RejoinChannelBuffersResponse {
4182 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
4183 })?;
4184
4185 Ok(())
4186}
4187
4188/// Stop editing the channel notes
4189async fn leave_channel_buffer(
4190 request: proto::LeaveChannelBuffer,
4191 response: Response<proto::LeaveChannelBuffer>,
4192 session: UserSession,
4193) -> Result<()> {
4194 let db = session.db().await;
4195 let channel_id = ChannelId::from_proto(request.channel_id);
4196
4197 let left_buffer = db
4198 .leave_channel_buffer(channel_id, session.connection_id)
4199 .await?;
4200
4201 response.send(Ack {})?;
4202
4203 channel_buffer_updated(
4204 session.connection_id,
4205 left_buffer.connections,
4206 &proto::UpdateChannelBufferCollaborators {
4207 channel_id: channel_id.to_proto(),
4208 collaborators: left_buffer.collaborators,
4209 },
4210 &session.peer,
4211 );
4212
4213 Ok(())
4214}
4215
4216fn channel_buffer_updated<T: EnvelopedMessage>(
4217 sender_id: ConnectionId,
4218 collaborators: impl IntoIterator<Item = ConnectionId>,
4219 message: &T,
4220 peer: &Peer,
4221) {
4222 broadcast(Some(sender_id), collaborators, |peer_id| {
4223 peer.send(peer_id, message.clone())
4224 });
4225}
4226
4227fn send_notifications(
4228 connection_pool: &ConnectionPool,
4229 peer: &Peer,
4230 notifications: db::NotificationBatch,
4231) {
4232 for (user_id, notification) in notifications {
4233 for connection_id in connection_pool.user_connection_ids(user_id) {
4234 if let Err(error) = peer.send(
4235 connection_id,
4236 proto::AddNotification {
4237 notification: Some(notification.clone()),
4238 },
4239 ) {
4240 tracing::error!(
4241 "failed to send notification to {:?} {}",
4242 connection_id,
4243 error
4244 );
4245 }
4246 }
4247 }
4248}
4249
4250/// Send a message to the channel
4251async fn send_channel_message(
4252 request: proto::SendChannelMessage,
4253 response: Response<proto::SendChannelMessage>,
4254 session: UserSession,
4255) -> Result<()> {
4256 // Validate the message body.
4257 let body = request.body.trim().to_string();
4258 if body.len() > MAX_MESSAGE_LEN {
4259 return Err(anyhow!("message is too long"))?;
4260 }
4261 if body.is_empty() {
4262 return Err(anyhow!("message can't be blank"))?;
4263 }
4264
4265 // TODO: adjust mentions if body is trimmed
4266
4267 let timestamp = OffsetDateTime::now_utc();
4268 let nonce = request
4269 .nonce
4270 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4271
4272 let channel_id = ChannelId::from_proto(request.channel_id);
4273 let CreatedChannelMessage {
4274 message_id,
4275 participant_connection_ids,
4276 notifications,
4277 } = session
4278 .db()
4279 .await
4280 .create_channel_message(
4281 channel_id,
4282 session.user_id(),
4283 &body,
4284 &request.mentions,
4285 timestamp,
4286 nonce.clone().into(),
4287 match request.reply_to_message_id {
4288 Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
4289 None => None,
4290 },
4291 )
4292 .await?;
4293
4294 let message = proto::ChannelMessage {
4295 sender_id: session.user_id().to_proto(),
4296 id: message_id.to_proto(),
4297 body,
4298 mentions: request.mentions,
4299 timestamp: timestamp.unix_timestamp() as u64,
4300 nonce: Some(nonce),
4301 reply_to_message_id: request.reply_to_message_id,
4302 edited_at: None,
4303 };
4304 broadcast(
4305 Some(session.connection_id),
4306 participant_connection_ids.clone(),
4307 |connection| {
4308 session.peer.send(
4309 connection,
4310 proto::ChannelMessageSent {
4311 channel_id: channel_id.to_proto(),
4312 message: Some(message.clone()),
4313 },
4314 )
4315 },
4316 );
4317 response.send(proto::SendChannelMessageResponse {
4318 message: Some(message),
4319 })?;
4320
4321 let pool = &*session.connection_pool().await;
4322 let non_participants =
4323 pool.channel_connection_ids(channel_id)
4324 .filter_map(|(connection_id, _)| {
4325 if participant_connection_ids.contains(&connection_id) {
4326 None
4327 } else {
4328 Some(connection_id)
4329 }
4330 });
4331 broadcast(None, non_participants, |peer_id| {
4332 session.peer.send(
4333 peer_id,
4334 proto::UpdateChannels {
4335 latest_channel_message_ids: vec![proto::ChannelMessageId {
4336 channel_id: channel_id.to_proto(),
4337 message_id: message_id.to_proto(),
4338 }],
4339 ..Default::default()
4340 },
4341 )
4342 });
4343 send_notifications(pool, &session.peer, notifications);
4344
4345 Ok(())
4346}
4347
/// Delete a channel message
///
/// Removes the message on behalf of the session's user. The database returns two things:
/// - the connections to notify;
/// - existing notification ids (presumably those tied to the message — verify).
///
/// Each of those connections except the sender's receives the original removal request,
/// followed by a `DeleteNotification` for every returned id. The request is acknowledged
/// afterwards.
async fn remove_channel_message(
    request: proto::RemoveChannelMessage,
    response: Response<proto::RemoveChannelMessage>,
    session: UserSession,
) -> Result<()> {
    let channel_id = ChannelId::from_proto(request.channel_id);
    let message_id = MessageId::from_proto(request.message_id);
    let (connection_ids, existing_notification_ids) = session
        .db()
        .await
        .remove_channel_message(channel_id, message_id, session.user_id())
        .await?;

    broadcast(
        Some(session.connection_id),
        connection_ids,
        move |connection| {
            // A failed send returns early for this connection only, skipping its remaining
            // messages.
            session.peer.send(connection, request.clone())?;

            for notification_id in &existing_notification_ids {
                session.peer.send(
                    connection,
                    proto::DeleteNotification {
                        notification_id: (*notification_id).to_proto(),
                    },
                )?;
            }

            Ok(())
        },
    );
    response.send(proto::Ack {})?;
    Ok(())
}
4383
4384async fn update_channel_message(
4385 request: proto::UpdateChannelMessage,
4386 response: Response<proto::UpdateChannelMessage>,
4387 session: UserSession,
4388) -> Result<()> {
4389 let channel_id = ChannelId::from_proto(request.channel_id);
4390 let message_id = MessageId::from_proto(request.message_id);
4391 let updated_at = OffsetDateTime::now_utc();
4392 let UpdatedChannelMessage {
4393 message_id,
4394 participant_connection_ids,
4395 notifications,
4396 reply_to_message_id,
4397 timestamp,
4398 deleted_mention_notification_ids,
4399 updated_mention_notifications,
4400 } = session
4401 .db()
4402 .await
4403 .update_channel_message(
4404 channel_id,
4405 message_id,
4406 session.user_id(),
4407 request.body.as_str(),
4408 &request.mentions,
4409 updated_at,
4410 )
4411 .await?;
4412
4413 let nonce = request
4414 .nonce
4415 .clone()
4416 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4417
4418 let message = proto::ChannelMessage {
4419 sender_id: session.user_id().to_proto(),
4420 id: message_id.to_proto(),
4421 body: request.body.clone(),
4422 mentions: request.mentions.clone(),
4423 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
4424 nonce: Some(nonce),
4425 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
4426 edited_at: Some(updated_at.unix_timestamp() as u64),
4427 };
4428
4429 response.send(proto::Ack {})?;
4430
4431 let pool = &*session.connection_pool().await;
4432 broadcast(
4433 Some(session.connection_id),
4434 participant_connection_ids,
4435 |connection| {
4436 session.peer.send(
4437 connection,
4438 proto::ChannelMessageUpdate {
4439 channel_id: channel_id.to_proto(),
4440 message: Some(message.clone()),
4441 },
4442 )?;
4443
4444 for notification_id in &deleted_mention_notification_ids {
4445 session.peer.send(
4446 connection,
4447 proto::DeleteNotification {
4448 notification_id: (*notification_id).to_proto(),
4449 },
4450 )?;
4451 }
4452
4453 for notification in &updated_mention_notifications {
4454 session.peer.send(
4455 connection,
4456 proto::UpdateNotification {
4457 notification: Some(notification.clone()),
4458 },
4459 )?;
4460 }
4461
4462 Ok(())
4463 },
4464 );
4465
4466 send_notifications(pool, &session.peer, notifications);
4467
4468 Ok(())
4469}
4470
4471/// Mark a channel message as read
4472async fn acknowledge_channel_message(
4473 request: proto::AckChannelMessage,
4474 session: UserSession,
4475) -> Result<()> {
4476 let channel_id = ChannelId::from_proto(request.channel_id);
4477 let message_id = MessageId::from_proto(request.message_id);
4478 let notifications = session
4479 .db()
4480 .await
4481 .observe_channel_message(channel_id, session.user_id(), message_id)
4482 .await?;
4483 send_notifications(
4484 &*session.connection_pool().await,
4485 &session.peer,
4486 notifications,
4487 );
4488 Ok(())
4489}
4490
4491/// Mark a buffer version as synced
4492async fn acknowledge_buffer_version(
4493 request: proto::AckBufferOperation,
4494 session: UserSession,
4495) -> Result<()> {
4496 let buffer_id = BufferId::from_proto(request.buffer_id);
4497 session
4498 .db()
4499 .await
4500 .observe_buffer_version(
4501 buffer_id,
4502 session.user_id(),
4503 request.epoch as i32,
4504 &request.version,
4505 )
4506 .await?;
4507 Ok(())
4508}
4509
4510async fn count_language_model_tokens(
4511 request: proto::CountLanguageModelTokens,
4512 response: Response<proto::CountLanguageModelTokens>,
4513 session: Session,
4514 config: &Config,
4515) -> Result<()> {
4516 let Some(session) = session.for_user() else {
4517 return Err(anyhow!("user not found"))?;
4518 };
4519 authorize_access_to_legacy_llm_endpoints(&session).await?;
4520
4521 let rate_limit: Box<dyn RateLimit> = match session.current_plan(session.db().await).await? {
4522 proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
4523 proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
4524 };
4525
4526 session
4527 .app_state
4528 .rate_limiter
4529 .check(&*rate_limit, session.user_id())
4530 .await?;
4531
4532 let result = match proto::LanguageModelProvider::from_i32(request.provider) {
4533 Some(proto::LanguageModelProvider::Google) => {
4534 let api_key = config
4535 .google_ai_api_key
4536 .as_ref()
4537 .context("no Google AI API key configured on the server")?;
4538 google_ai::count_tokens(
4539 session.http_client.as_ref(),
4540 google_ai::API_URL,
4541 api_key,
4542 serde_json::from_str(&request.request)?,
4543 )
4544 .await?
4545 }
4546 _ => return Err(anyhow!("unsupported provider"))?,
4547 };
4548
4549 response.send(proto::CountLanguageModelTokensResponse {
4550 token_count: result.total_tokens as u32,
4551 })?;
4552
4553 Ok(())
4554}
4555
4556struct ZedProCountLanguageModelTokensRateLimit;
4557
4558impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
4559 fn capacity(&self) -> usize {
4560 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
4561 .ok()
4562 .and_then(|v| v.parse().ok())
4563 .unwrap_or(600) // Picked arbitrarily
4564 }
4565
4566 fn refill_duration(&self) -> chrono::Duration {
4567 chrono::Duration::hours(1)
4568 }
4569
4570 fn db_name(&self) -> &'static str {
4571 "zed-pro:count-language-model-tokens"
4572 }
4573}
4574
4575struct FreeCountLanguageModelTokensRateLimit;
4576
4577impl RateLimit for FreeCountLanguageModelTokensRateLimit {
4578 fn capacity(&self) -> usize {
4579 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
4580 .ok()
4581 .and_then(|v| v.parse().ok())
4582 .unwrap_or(600 / 10) // Picked arbitrarily
4583 }
4584
4585 fn refill_duration(&self) -> chrono::Duration {
4586 chrono::Duration::hours(1)
4587 }
4588
4589 fn db_name(&self) -> &'static str {
4590 "free:count-language-model-tokens"
4591 }
4592}
4593
4594struct ZedProComputeEmbeddingsRateLimit;
4595
4596impl RateLimit for ZedProComputeEmbeddingsRateLimit {
4597 fn capacity(&self) -> usize {
4598 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4599 .ok()
4600 .and_then(|v| v.parse().ok())
4601 .unwrap_or(5000) // Picked arbitrarily
4602 }
4603
4604 fn refill_duration(&self) -> chrono::Duration {
4605 chrono::Duration::hours(1)
4606 }
4607
4608 fn db_name(&self) -> &'static str {
4609 "zed-pro:compute-embeddings"
4610 }
4611}
4612
4613struct FreeComputeEmbeddingsRateLimit;
4614
4615impl RateLimit for FreeComputeEmbeddingsRateLimit {
4616 fn capacity(&self) -> usize {
4617 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
4618 .ok()
4619 .and_then(|v| v.parse().ok())
4620 .unwrap_or(5000 / 10) // Picked arbitrarily
4621 }
4622
4623 fn refill_duration(&self) -> chrono::Duration {
4624 chrono::Duration::hours(1)
4625 }
4626
4627 fn db_name(&self) -> &'static str {
4628 "free:compute-embeddings"
4629 }
4630}
4631
4632async fn compute_embeddings(
4633 request: proto::ComputeEmbeddings,
4634 response: Response<proto::ComputeEmbeddings>,
4635 session: UserSession,
4636 api_key: Option<Arc<str>>,
4637) -> Result<()> {
4638 let api_key = api_key.context("no OpenAI API key configured on the server")?;
4639 authorize_access_to_legacy_llm_endpoints(&session).await?;
4640
4641 let rate_limit: Box<dyn RateLimit> = match session.current_plan(session.db().await).await? {
4642 proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
4643 proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
4644 };
4645
4646 session
4647 .app_state
4648 .rate_limiter
4649 .check(&*rate_limit, session.user_id())
4650 .await?;
4651
4652 let embeddings = match request.model.as_str() {
4653 "openai/text-embedding-3-small" => {
4654 open_ai::embed(
4655 session.http_client.as_ref(),
4656 OPEN_AI_API_URL,
4657 &api_key,
4658 OpenAiEmbeddingModel::TextEmbedding3Small,
4659 request.texts.iter().map(|text| text.as_str()),
4660 )
4661 .await?
4662 }
4663 provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
4664 };
4665
4666 let embeddings = request
4667 .texts
4668 .iter()
4669 .map(|text| {
4670 let mut hasher = sha2::Sha256::new();
4671 hasher.update(text.as_bytes());
4672 let result = hasher.finalize();
4673 result.to_vec()
4674 })
4675 .zip(
4676 embeddings
4677 .data
4678 .into_iter()
4679 .map(|embedding| embedding.embedding),
4680 )
4681 .collect::<HashMap<_, _>>();
4682
4683 let db = session.db().await;
4684 db.save_embeddings(&request.model, &embeddings)
4685 .await
4686 .context("failed to save embeddings")
4687 .trace_err();
4688
4689 response.send(proto::ComputeEmbeddingsResponse {
4690 embeddings: embeddings
4691 .into_iter()
4692 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4693 .collect(),
4694 })?;
4695 Ok(())
4696}
4697
4698async fn get_cached_embeddings(
4699 request: proto::GetCachedEmbeddings,
4700 response: Response<proto::GetCachedEmbeddings>,
4701 session: UserSession,
4702) -> Result<()> {
4703 authorize_access_to_legacy_llm_endpoints(&session).await?;
4704
4705 let db = session.db().await;
4706 let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
4707
4708 response.send(proto::GetCachedEmbeddingsResponse {
4709 embeddings: embeddings
4710 .into_iter()
4711 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4712 .collect(),
4713 })?;
4714 Ok(())
4715}
4716
4717/// This is leftover from before the LLM service.
4718///
4719/// The endpoints protected by this check will be moved there eventually.
4720async fn authorize_access_to_legacy_llm_endpoints(session: &UserSession) -> Result<(), Error> {
4721 if session.is_staff() {
4722 Ok(())
4723 } else {
4724 Err(anyhow!("permission denied"))?
4725 }
4726}
4727
4728/// Get a Supermaven API key for the user
4729async fn get_supermaven_api_key(
4730 _request: proto::GetSupermavenApiKey,
4731 response: Response<proto::GetSupermavenApiKey>,
4732 session: UserSession,
4733) -> Result<()> {
4734 let user_id: String = session.user_id().to_string();
4735 if !session.is_staff() {
4736 return Err(anyhow!("supermaven not enabled for this account"))?;
4737 }
4738
4739 let email = session
4740 .email()
4741 .ok_or_else(|| anyhow!("user must have an email"))?;
4742
4743 let supermaven_admin_api = session
4744 .supermaven_client
4745 .as_ref()
4746 .ok_or_else(|| anyhow!("supermaven not configured"))?;
4747
4748 let result = supermaven_admin_api
4749 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4750 .await?;
4751
4752 response.send(proto::GetSupermavenApiKeyResponse {
4753 api_key: result.api_key,
4754 })?;
4755
4756 Ok(())
4757}
4758
4759/// Start receiving chat updates for a channel
4760async fn join_channel_chat(
4761 request: proto::JoinChannelChat,
4762 response: Response<proto::JoinChannelChat>,
4763 session: UserSession,
4764) -> Result<()> {
4765 let channel_id = ChannelId::from_proto(request.channel_id);
4766
4767 let db = session.db().await;
4768 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4769 .await?;
4770 let messages = db
4771 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4772 .await?;
4773 response.send(proto::JoinChannelChatResponse {
4774 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4775 messages,
4776 })?;
4777 Ok(())
4778}
4779
4780/// Stop receiving chat updates for a channel
4781async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
4782 let channel_id = ChannelId::from_proto(request.channel_id);
4783 session
4784 .db()
4785 .await
4786 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4787 .await?;
4788 Ok(())
4789}
4790
4791/// Retrieve the chat history for a channel
4792async fn get_channel_messages(
4793 request: proto::GetChannelMessages,
4794 response: Response<proto::GetChannelMessages>,
4795 session: UserSession,
4796) -> Result<()> {
4797 let channel_id = ChannelId::from_proto(request.channel_id);
4798 let messages = session
4799 .db()
4800 .await
4801 .get_channel_messages(
4802 channel_id,
4803 session.user_id(),
4804 MESSAGE_COUNT_PER_PAGE,
4805 Some(MessageId::from_proto(request.before_message_id)),
4806 )
4807 .await?;
4808 response.send(proto::GetChannelMessagesResponse {
4809 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4810 messages,
4811 })?;
4812 Ok(())
4813}
4814
4815/// Retrieve specific chat messages
4816async fn get_channel_messages_by_id(
4817 request: proto::GetChannelMessagesById,
4818 response: Response<proto::GetChannelMessagesById>,
4819 session: UserSession,
4820) -> Result<()> {
4821 let message_ids = request
4822 .message_ids
4823 .iter()
4824 .map(|id| MessageId::from_proto(*id))
4825 .collect::<Vec<_>>();
4826 let messages = session
4827 .db()
4828 .await
4829 .get_channel_messages_by_id(session.user_id(), &message_ids)
4830 .await?;
4831 response.send(proto::GetChannelMessagesResponse {
4832 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4833 messages,
4834 })?;
4835 Ok(())
4836}
4837
4838/// Retrieve the current users notifications
4839async fn get_notifications(
4840 request: proto::GetNotifications,
4841 response: Response<proto::GetNotifications>,
4842 session: UserSession,
4843) -> Result<()> {
4844 let notifications = session
4845 .db()
4846 .await
4847 .get_notifications(
4848 session.user_id(),
4849 NOTIFICATION_COUNT_PER_PAGE,
4850 request
4851 .before_id
4852 .map(|id| db::NotificationId::from_proto(id)),
4853 )
4854 .await?;
4855 response.send(proto::GetNotificationsResponse {
4856 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4857 notifications,
4858 })?;
4859 Ok(())
4860}
4861
4862/// Mark notifications as read
4863async fn mark_notification_as_read(
4864 request: proto::MarkNotificationRead,
4865 response: Response<proto::MarkNotificationRead>,
4866 session: UserSession,
4867) -> Result<()> {
4868 let database = &session.db().await;
4869 let notifications = database
4870 .mark_notification_as_read_by_id(
4871 session.user_id(),
4872 NotificationId::from_proto(request.notification_id),
4873 )
4874 .await?;
4875 send_notifications(
4876 &*session.connection_pool().await,
4877 &session.peer,
4878 notifications,
4879 );
4880 response.send(proto::Ack {})?;
4881 Ok(())
4882}
4883
4884/// Get the current users information
4885async fn get_private_user_info(
4886 _request: proto::GetPrivateUserInfo,
4887 response: Response<proto::GetPrivateUserInfo>,
4888 session: UserSession,
4889) -> Result<()> {
4890 let db = session.db().await;
4891
4892 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4893 let user = db
4894 .get_user_by_id(session.user_id())
4895 .await?
4896 .ok_or_else(|| anyhow!("user not found"))?;
4897 let flags = db.get_user_flags(session.user_id()).await?;
4898
4899 response.send(proto::GetPrivateUserInfoResponse {
4900 metrics_id,
4901 staff: user.admin,
4902 flags,
4903 accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4904 })?;
4905 Ok(())
4906}
4907
4908/// Accept the terms of service (tos) on behalf of the current user
4909async fn accept_terms_of_service(
4910 _request: proto::AcceptTermsOfService,
4911 response: Response<proto::AcceptTermsOfService>,
4912 session: UserSession,
4913) -> Result<()> {
4914 let db = session.db().await;
4915
4916 let accepted_tos_at = Utc::now();
4917 db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4918 .await?;
4919
4920 response.send(proto::AcceptTermsOfServiceResponse {
4921 accepted_tos_at: accepted_tos_at.timestamp() as u64,
4922 })?;
4923 Ok(())
4924}
4925
/// The minimum account age an account must have in order to use the LLM service.
///
/// `get_llm_api_token` measures age from the earlier of the user's Zed account
/// creation time and, when known, their GitHub account creation time. It
/// refuses to issue a token to accounts younger than this.
const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4928
4929async fn get_llm_api_token(
4930 _request: proto::GetLlmToken,
4931 response: Response<proto::GetLlmToken>,
4932 session: UserSession,
4933) -> Result<()> {
4934 let db = session.db().await;
4935
4936 let flags = db.get_user_flags(session.user_id()).await?;
4937 let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
4938 let has_llm_closed_beta_feature_flag = flags.iter().any(|flag| flag == "llm-closed-beta");
4939
4940 if !session.is_staff() && !has_language_models_feature_flag {
4941 Err(anyhow!("permission denied"))?
4942 }
4943
4944 let user_id = session.user_id();
4945 let user = db
4946 .get_user_by_id(user_id)
4947 .await?
4948 .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4949
4950 if user.accepted_tos_at.is_none() {
4951 Err(anyhow!("terms of service not accepted"))?
4952 }
4953
4954 let mut account_created_at = user.created_at;
4955 if let Some(github_created_at) = user.github_user_created_at {
4956 account_created_at = account_created_at.min(github_created_at);
4957 }
4958 if Utc::now().naive_utc() - account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE {
4959 Err(anyhow!("account too young"))?
4960 }
4961 let token = LlmTokenClaims::create(
4962 user.id,
4963 user.github_login.clone(),
4964 session.is_staff(),
4965 has_llm_closed_beta_feature_flag,
4966 session.current_plan(db).await?,
4967 &session.app_state.config,
4968 )?;
4969 response.send(proto::GetLlmTokenResponse { token })?;
4970 Ok(())
4971}
4972
4973fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4974 let message = match message {
4975 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4976 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4977 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4978 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4979 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4980 code: frame.code.into(),
4981 reason: frame.reason,
4982 })),
4983 // We should never receive a frame while reading the message, according
4984 // to the `tungstenite` maintainers:
4985 //
4986 // > It cannot occur when you read messages from the WebSocket, but it
4987 // > can be used when you want to send the raw frames (e.g. you want to
4988 // > send the frames to the WebSocket without composing the full message first).
4989 // >
4990 // > — https://github.com/snapview/tungstenite-rs/issues/268
4991 TungsteniteMessage::Frame(_) => {
4992 bail!("received an unexpected frame while reading the message")
4993 }
4994 };
4995
4996 Ok(message)
4997}
4998
4999fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
5000 match message {
5001 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
5002 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
5003 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
5004 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
5005 AxumMessage::Close(frame) => {
5006 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
5007 code: frame.code.into(),
5008 reason: frame.reason,
5009 }))
5010 }
5011 }
5012}
5013
5014fn notify_membership_updated(
5015 connection_pool: &mut ConnectionPool,
5016 result: MembershipUpdated,
5017 user_id: UserId,
5018 peer: &Peer,
5019) {
5020 for membership in &result.new_channels.channel_memberships {
5021 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
5022 }
5023 for channel_id in &result.removed_channels {
5024 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
5025 }
5026
5027 let user_channels_update = proto::UpdateUserChannels {
5028 channel_memberships: result
5029 .new_channels
5030 .channel_memberships
5031 .iter()
5032 .map(|cm| proto::ChannelMembership {
5033 channel_id: cm.channel_id.to_proto(),
5034 role: cm.role.into(),
5035 })
5036 .collect(),
5037 ..Default::default()
5038 };
5039
5040 let mut update = build_channels_update(result.new_channels);
5041 update.delete_channels = result
5042 .removed_channels
5043 .into_iter()
5044 .map(|id| id.to_proto())
5045 .collect();
5046 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
5047
5048 for connection_id in connection_pool.user_connection_ids(user_id) {
5049 peer.send(connection_id, user_channels_update.clone())
5050 .trace_err();
5051 peer.send(connection_id, update.clone()).trace_err();
5052 }
5053}
5054
5055fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
5056 proto::UpdateUserChannels {
5057 channel_memberships: channels
5058 .channel_memberships
5059 .iter()
5060 .map(|m| proto::ChannelMembership {
5061 channel_id: m.channel_id.to_proto(),
5062 role: m.role.into(),
5063 })
5064 .collect(),
5065 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
5066 observed_channel_message_id: channels.observed_channel_messages.clone(),
5067 }
5068}
5069
5070fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
5071 let mut update = proto::UpdateChannels::default();
5072
5073 for channel in channels.channels {
5074 update.channels.push(channel.to_proto());
5075 }
5076
5077 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
5078 update.latest_channel_message_ids = channels.latest_channel_messages;
5079
5080 for (channel_id, participants) in channels.channel_participants {
5081 update
5082 .channel_participants
5083 .push(proto::ChannelParticipants {
5084 channel_id: channel_id.to_proto(),
5085 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
5086 });
5087 }
5088
5089 for channel in channels.invited_channels {
5090 update.channel_invitations.push(channel.to_proto());
5091 }
5092
5093 update.hosted_projects = channels.hosted_projects;
5094 update
5095}
5096
5097fn build_initial_contacts_update(
5098 contacts: Vec<db::Contact>,
5099 pool: &ConnectionPool,
5100) -> proto::UpdateContacts {
5101 let mut update = proto::UpdateContacts::default();
5102
5103 for contact in contacts {
5104 match contact {
5105 db::Contact::Accepted { user_id, busy } => {
5106 update.contacts.push(contact_for_user(user_id, busy, &pool));
5107 }
5108 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
5109 db::Contact::Incoming { user_id } => {
5110 update
5111 .incoming_requests
5112 .push(proto::IncomingContactRequest {
5113 requester_id: user_id.to_proto(),
5114 })
5115 }
5116 }
5117 }
5118
5119 update
5120}
5121
5122fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
5123 proto::Contact {
5124 user_id: user_id.to_proto(),
5125 online: pool.is_user_online(user_id),
5126 busy,
5127 }
5128}
5129
5130fn room_updated(room: &proto::Room, peer: &Peer) {
5131 broadcast(
5132 None,
5133 room.participants
5134 .iter()
5135 .filter_map(|participant| Some(participant.peer_id?.into())),
5136 |peer_id| {
5137 peer.send(
5138 peer_id,
5139 proto::RoomUpdated {
5140 room: Some(room.clone()),
5141 },
5142 )
5143 },
5144 );
5145}
5146
5147fn channel_updated(
5148 channel: &db::channel::Model,
5149 room: &proto::Room,
5150 peer: &Peer,
5151 pool: &ConnectionPool,
5152) {
5153 let participants = room
5154 .participants
5155 .iter()
5156 .map(|p| p.user_id)
5157 .collect::<Vec<_>>();
5158
5159 broadcast(
5160 None,
5161 pool.channel_connection_ids(channel.root_id())
5162 .filter_map(|(channel_id, role)| {
5163 role.can_see_channel(channel.visibility).then(|| channel_id)
5164 }),
5165 |peer_id| {
5166 peer.send(
5167 peer_id,
5168 proto::UpdateChannels {
5169 channel_participants: vec![proto::ChannelParticipants {
5170 channel_id: channel.id.to_proto(),
5171 participant_user_ids: participants.clone(),
5172 }],
5173 ..Default::default()
5174 },
5175 )
5176 },
5177 );
5178}
5179
5180async fn send_dev_server_projects_update(
5181 user_id: UserId,
5182 mut status: proto::DevServerProjectsUpdate,
5183 session: &Session,
5184) {
5185 let pool = session.connection_pool().await;
5186 for dev_server in &mut status.dev_servers {
5187 dev_server.status =
5188 pool.dev_server_status(DevServerId(dev_server.dev_server_id as i32)) as i32;
5189 }
5190 let connections = pool.user_connection_ids(user_id);
5191 for connection_id in connections {
5192 session.peer.send(connection_id, status.clone()).trace_err();
5193 }
5194}
5195
5196async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
5197 let db = session.db().await;
5198
5199 let contacts = db.get_contacts(user_id).await?;
5200 let busy = db.is_user_busy(user_id).await?;
5201
5202 let pool = session.connection_pool().await;
5203 let updated_contact = contact_for_user(user_id, busy, &pool);
5204 for contact in contacts {
5205 if let db::Contact::Accepted {
5206 user_id: contact_user_id,
5207 ..
5208 } = contact
5209 {
5210 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
5211 session
5212 .peer
5213 .send(
5214 contact_conn_id,
5215 proto::UpdateContacts {
5216 contacts: vec![updated_contact.clone()],
5217 remove_contacts: Default::default(),
5218 incoming_requests: Default::default(),
5219 remove_incoming_requests: Default::default(),
5220 outgoing_requests: Default::default(),
5221 remove_outgoing_requests: Default::default(),
5222 },
5223 )
5224 .trace_err();
5225 }
5226 }
5227 }
5228 Ok(())
5229}
5230
5231async fn lost_dev_server_connection(session: &DevServerSession) -> Result<()> {
5232 log::info!("lost dev server connection, unsharing projects");
5233 let project_ids = session
5234 .db()
5235 .await
5236 .get_stale_dev_server_projects(session.connection_id)
5237 .await?;
5238
5239 for project_id in project_ids {
5240 // not unshare re-checks the connection ids match, so we get away with no transaction
5241 unshare_project_internal(project_id, session.connection_id, None, &session).await?;
5242 }
5243
5244 let user_id = session.dev_server().user_id;
5245 let update = session
5246 .db()
5247 .await
5248 .dev_server_projects_update(user_id)
5249 .await?;
5250
5251 send_dev_server_projects_update(user_id, update, session).await;
5252
5253 Ok(())
5254}
5255
/// Remove `connection_id` from whatever room the database says it is in, and
/// propagate the consequences.
///
/// If the connection was not in a room, this returns `Ok(())` and does nothing.
/// Otherwise it:
/// - notifies collaborators of projects this user left (`project_left`);
/// - broadcasts the new room state and, for channel rooms, channel participants;
/// - sends `CallCanceled` to users whose pending calls were canceled;
/// - refreshes contact entries for this user and for those canceled callees;
/// - removes the participant from LiveKit and deletes the LiveKit room when
///   the database reports the room as deleted. LiveKit failures are only
///   logged.
async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Deferred-initialized: assigned only on the "left a room" branch below;
    // the other branch returns early.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel;

    // The value from `session.db().await` is a temporary in the `if let`
    // scrutinee, so it lives until the end of this `if let`/`else`.
    // NOTE(review): presumably a lock guard held while `project_left` and
    // `room_updated` run; keep this in mind before restructuring.
    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Taken out of `left_room.room` *before* the room itself is taken, so
        // the `room` broadcast below carries an emptied `live_kit_room`.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the pool guard is dropped before `update_user_contacts`, which
    // acquires the connection pool itself.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    // NOTE(review): an error here returns early and skips the LiveKit cleanup
    // below; confirm that is intended.
    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
5329
5330async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
5331 let left_channel_buffers = session
5332 .db()
5333 .await
5334 .leave_channel_buffers(session.connection_id)
5335 .await?;
5336
5337 for left_buffer in left_channel_buffers {
5338 channel_buffer_updated(
5339 session.connection_id,
5340 left_buffer.connections,
5341 &proto::UpdateChannelBufferCollaborators {
5342 channel_id: left_buffer.channel_id.to_proto(),
5343 collaborators: left_buffer.collaborators,
5344 },
5345 &session.peer,
5346 );
5347 }
5348
5349 Ok(())
5350}
5351
5352fn project_left(project: &db::LeftProject, session: &UserSession) {
5353 for connection_id in &project.connection_ids {
5354 if project.should_unshare {
5355 session
5356 .peer
5357 .send(
5358 *connection_id,
5359 proto::UnshareProject {
5360 project_id: project.id.to_proto(),
5361 },
5362 )
5363 .trace_err();
5364 } else {
5365 session
5366 .peer
5367 .send(
5368 *connection_id,
5369 proto::RemoveProjectCollaborator {
5370 project_id: project.id.to_proto(),
5371 peer_id: Some(session.connection_id.into()),
5372 },
5373 )
5374 .trace_err();
5375 }
5376 }
5377}
5378
5379pub trait ResultExt {
5380 type Ok;
5381
5382 fn trace_err(self) -> Option<Self::Ok>;
5383}
5384
5385impl<T, E> ResultExt for Result<T, E>
5386where
5387 E: std::fmt::Debug,
5388{
5389 type Ok = T;
5390
5391 #[track_caller]
5392 fn trace_err(self) -> Option<T> {
5393 match self {
5394 Ok(value) => Some(value),
5395 Err(error) => {
5396 tracing::error!("{:?}", error);
5397 None
5398 }
5399 }
5400 }
5401}