1mod connection_pool;
2
3use crate::api::CloudflareIpCountryHeader;
4use crate::llm::LlmTokenClaims;
5use crate::{
6 auth,
7 db::{
8 self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
9 CreatedChannelMessage, Database, DevServerId, DevServerProjectId, InviteMemberResult,
10 MembershipUpdated, MessageId, NotificationId, PrincipalId, Project, ProjectId,
11 RejoinedProject, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
12 ServerId, UpdatedChannelMessage, User, UserId,
13 },
14 executor::Executor,
15 AppState, Config, Error, RateLimit, Result,
16};
17use anyhow::{anyhow, bail, Context as _};
18use async_tungstenite::tungstenite::{
19 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
20};
21use axum::{
22 body::Body,
23 extract::{
24 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
25 ConnectInfo, WebSocketUpgrade,
26 },
27 headers::{Header, HeaderName},
28 http::StatusCode,
29 middleware,
30 response::IntoResponse,
31 routing::get,
32 Extension, Router, TypedHeader,
33};
34use chrono::Utc;
35use collections::{HashMap, HashSet};
36pub use connection_pool::{ConnectionPool, ZedVersion};
37use core::fmt::{self, Debug, Formatter};
38use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
39use sha2::Digest;
40use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
41
42use futures::{
43 channel::oneshot,
44 future::{self, BoxFuture},
45 stream::FuturesUnordered,
46 FutureExt, SinkExt, StreamExt, TryStreamExt,
47};
48use http_client::IsahcHttpClient;
49use prometheus::{register_int_gauge, IntGauge};
50use rpc::{
51 proto::{
52 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
53 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
54 },
55 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
56};
57use semantic_version::SemanticVersion;
58use serde::{Serialize, Serializer};
59use std::{
60 any::TypeId,
61 future::Future,
62 marker::PhantomData,
63 mem,
64 net::SocketAddr,
65 ops::{Deref, DerefMut},
66 rc::Rc,
67 sync::{
68 atomic::{AtomicBool, Ordering::SeqCst},
69 Arc, OnceLock,
70 },
71 time::{Duration, Instant},
72};
73use time::OffsetDateTime;
74use tokio::sync::{watch, MutexGuard, Semaphore};
75use tower::ServiceBuilder;
76use tracing::{
77 field::{self},
78 info_span, instrument, Instrument,
79};
80
/// A 30 second reconnection timeout.
/// NOTE(review): not referenced in the visible part of this file; confirm its exact role at the
/// use sites.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// How long `Server::start` sleeps before it starts cleaning up stale resources.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can clean
/// up old resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// NOTE(review): the three limits below are not used in the visible part of this file. Judging by
// their names they bound page sizes and chat message length; confirm at the use sites.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
89
/// Type-erased message handler as stored in `Server::handlers`.
///
/// It receives the boxed envelope together with the [`Session`] of the connection the message
/// arrived on, and returns a boxed future that produces no value.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
92
/// Handle given to a request handler so that it can reply to the request of type `R`.
struct Response<R> {
    /// Peer used to deliver the reply.
    peer: Arc<Peer>,
    /// Receipt of the request being answered.
    receipt: Receipt<R>,
    /// Set to `true` by `send`; `Server::add_request_handler` keeps a clone of this flag and
    /// checks it after the handler returns.
    responded: Arc<AtomicBool>,
}
98
99impl<R: RequestMessage> Response<R> {
100 fn send(self, payload: R::Response) -> Result<()> {
101 self.responded.store(true, SeqCst);
102 self.peer.respond(self.receipt, payload)?;
103 Ok(())
104 }
105}
106
/// The identity on whose behalf a connection acts.
#[derive(Clone, Debug)]
pub enum Principal {
    /// A user acting as themselves.
    User(User),
    /// `admin` acting as `user`: `Session::user_id` yields the id of `user`, and
    /// `Session::is_staff` is always `true` for this variant.
    Impersonated { user: User, admin: User },
    /// A dev server; it has no user id.
    DevServer(dev_server::Model),
}
113
114impl Principal {
115 fn update_span(&self, span: &tracing::Span) {
116 match &self {
117 Principal::User(user) => {
118 span.record("user_id", &user.id.0);
119 span.record("login", &user.github_login);
120 }
121 Principal::Impersonated { user, admin } => {
122 span.record("user_id", &user.id.0);
123 span.record("login", &user.github_login);
124 span.record("impersonator", &admin.github_login);
125 }
126 Principal::DevServer(dev_server) => {
127 span.record("dev_server_id", &dev_server.id.0);
128 }
129 }
130 }
131}
132
/// Per-connection state handed (by clone) to every message handler.
#[derive(Clone)]
struct Session {
    /// Who this connection acts as.
    principal: Principal,
    /// Id assigned to the connection by the peer.
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; use `Session::db` to lock it.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    /// The server's RPC peer.
    peer: Arc<Peer>,
    /// The server-wide connection pool; use `Session::connection_pool` to lock it.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// Shared application state.
    app_state: Arc<AppState>,
    /// Present only when a Supermaven admin API key is configured.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    /// HTTP client created for this connection.
    http_client: Arc<IsahcHttpClient>,
    /// The GeoIP country code for the user.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    /// Executor the connection was started with.
    _executor: Executor,
}
148
149impl Session {
150 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
151 #[cfg(test)]
152 tokio::task::yield_now().await;
153 let guard = self.db.lock().await;
154 #[cfg(test)]
155 tokio::task::yield_now().await;
156 guard
157 }
158
159 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
160 #[cfg(test)]
161 tokio::task::yield_now().await;
162 let guard = self.connection_pool.lock();
163 ConnectionPoolGuard {
164 guard,
165 _not_send: PhantomData,
166 }
167 }
168
169 fn for_user(self) -> Option<UserSession> {
170 UserSession::new(self)
171 }
172
173 fn for_dev_server(self) -> Option<DevServerSession> {
174 DevServerSession::new(self)
175 }
176
177 fn user_id(&self) -> Option<UserId> {
178 match &self.principal {
179 Principal::User(user) => Some(user.id),
180 Principal::Impersonated { user, .. } => Some(user.id),
181 Principal::DevServer(_) => None,
182 }
183 }
184
185 fn is_staff(&self) -> bool {
186 match &self.principal {
187 Principal::User(user) => user.admin,
188 Principal::Impersonated { .. } => true,
189 Principal::DevServer(_) => false,
190 }
191 }
192
193 pub async fn current_plan(&self, db: MutexGuard<'_, DbHandle>) -> anyhow::Result<proto::Plan> {
194 if self.is_staff() {
195 return Ok(proto::Plan::ZedPro);
196 }
197
198 let Some(user_id) = self.user_id() else {
199 return Ok(proto::Plan::Free);
200 };
201
202 if db.has_active_billing_subscription(user_id).await? {
203 Ok(proto::Plan::ZedPro)
204 } else {
205 Ok(proto::Plan::Free)
206 }
207 }
208
209 fn dev_server_id(&self) -> Option<DevServerId> {
210 match &self.principal {
211 Principal::User(_) | Principal::Impersonated { .. } => None,
212 Principal::DevServer(dev_server) => Some(dev_server.id),
213 }
214 }
215
216 fn principal_id(&self) -> PrincipalId {
217 match &self.principal {
218 Principal::User(user) => PrincipalId::UserId(user.id),
219 Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
220 Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
221 }
222 }
223}
224
225impl Debug for Session {
226 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
227 let mut result = f.debug_struct("Session");
228 match &self.principal {
229 Principal::User(user) => {
230 result.field("user", &user.github_login);
231 }
232 Principal::Impersonated { user, admin } => {
233 result.field("user", &user.github_login);
234 result.field("impersonator", &admin.github_login);
235 }
236 Principal::DevServer(dev_server) => {
237 result.field("dev_server", &dev_server.id);
238 }
239 }
240 result.field("connection_id", &self.connection_id).finish()
241 }
242}
243
/// A [`Session`] whose principal is known to be a user (plain or impersonated); only
/// constructed through `UserSession::new`.
struct UserSession(Session);
245
246impl UserSession {
247 pub fn new(s: Session) -> Option<Self> {
248 s.user_id().map(|_| UserSession(s))
249 }
250 pub fn user_id(&self) -> UserId {
251 self.0.user_id().unwrap()
252 }
253
254 pub fn email(&self) -> Option<String> {
255 match &self.0.principal {
256 Principal::User(user) => user.email_address.clone(),
257 Principal::Impersonated { user, .. } => user.email_address.clone(),
258 Principal::DevServer(..) => None,
259 }
260 }
261}
262
263impl Deref for UserSession {
264 type Target = Session;
265
266 fn deref(&self) -> &Self::Target {
267 &self.0
268 }
269}
270impl DerefMut for UserSession {
271 fn deref_mut(&mut self) -> &mut Self::Target {
272 &mut self.0
273 }
274}
275
/// A [`Session`] whose principal is known to be a dev server; only constructed through
/// `DevServerSession::new`.
struct DevServerSession(Session);
277
278impl DevServerSession {
279 pub fn new(s: Session) -> Option<Self> {
280 s.dev_server_id().map(|_| DevServerSession(s))
281 }
282 pub fn dev_server_id(&self) -> DevServerId {
283 self.0.dev_server_id().unwrap()
284 }
285
286 fn dev_server(&self) -> &dev_server::Model {
287 match &self.0.principal {
288 Principal::DevServer(dev_server) => dev_server,
289 _ => unreachable!(),
290 }
291 }
292}
293
294impl Deref for DevServerSession {
295 type Target = Session;
296
297 fn deref(&self) -> &Self::Target {
298 &self.0
299 }
300}
301impl DerefMut for DevServerSession {
302 fn deref_mut(&mut self) -> &mut Self::Target {
303 &mut self.0
304 }
305}
306
307fn user_handler<M: RequestMessage, Fut>(
308 handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
309) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
310where
311 Fut: Send + Future<Output = Result<()>>,
312{
313 let handler = Arc::new(handler);
314 move |message, response, session| {
315 let handler = handler.clone();
316 Box::pin(async move {
317 if let Some(user_session) = session.for_user() {
318 Ok(handler(message, response, user_session).await?)
319 } else {
320 Err(Error::Internal(anyhow!(
321 "must be a user to call {}",
322 M::NAME
323 )))
324 }
325 })
326 }
327}
328
329fn dev_server_handler<M: RequestMessage, Fut>(
330 handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
331) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
332where
333 Fut: Send + Future<Output = Result<()>>,
334{
335 let handler = Arc::new(handler);
336 move |message, response, session| {
337 let handler = handler.clone();
338 Box::pin(async move {
339 if let Some(dev_server_session) = session.for_dev_server() {
340 Ok(handler(message, response, dev_server_session).await?)
341 } else {
342 Err(Error::Internal(anyhow!(
343 "must be a dev server to call {}",
344 M::NAME
345 )))
346 }
347 })
348 }
349}
350
351fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
352 handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
353) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
354where
355 InnertRetFut: Send + Future<Output = Result<()>>,
356{
357 let handler = Arc::new(handler);
358 move |message, session| {
359 let handler = handler.clone();
360 Box::pin(async move {
361 if let Some(user_session) = session.for_user() {
362 Ok(handler(message, user_session).await?)
363 } else {
364 Err(Error::Internal(anyhow!(
365 "must be a user to call {}",
366 M::NAME
367 )))
368 }
369 })
370 }
371}
372
/// Newtype around the shared [`Database`]; each [`Session`] keeps one behind an async mutex.
struct DbHandle(Arc<Database>);
374
375impl Deref for DbHandle {
376 type Target = Database;
377
378 fn deref(&self) -> &Self::Target {
379 self.0.as_ref()
380 }
381}
382
/// The RPC server: owns the peer, the connection pool and the table of message handlers.
pub struct Server {
    /// Id of this server instance; behind a mutex because the test-only `reset` replaces it.
    id: parking_lot::Mutex<ServerId>,
    /// RPC peer shared with every session.
    peer: Arc<Peer>,
    /// Pool of live connections, shared with every session.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// Shared application state.
    app_state: Arc<AppState>,
    /// Handlers keyed by the `TypeId` of the message type they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Teardown flag: `teardown` publishes `true`, and connection tasks subscribe to it.
    teardown: watch::Sender<bool>,
}
391
/// A held lock on the connection pool, as returned by `Session::connection_pool`.
pub(crate) struct ConnectionPoolGuard<'a> {
    /// The underlying mutex guard.
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    /// `Rc` is not `Send`, so this marker keeps the guard type from being `Send`.
    _not_send: PhantomData<Rc<()>>,
}
396
/// Serializable view of the server: the peer plus the locked connection pool.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    /// Serialized through `serialize_deref`, i.e. as the value the guard dereferences to.
    /// NOTE(review): this relies on a `Deref` impl for `ConnectionPoolGuard` that is outside the
    /// visible part of the file.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
403
404pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
405where
406 S: Serializer,
407 T: Deref<Target = U>,
408 U: Serialize,
409{
410 Serialize::serialize(value.deref(), serializer)
411}
412
413impl Server {
    /// Creates a server with the given id and registers a handler for every supported message
    /// type.
    ///
    /// Handlers wrapped in `user_handler` / `user_message_handler` reject non-user principals,
    /// and those wrapped in `dev_server_handler` reject everything but dev servers.
    ///
    /// # Panics
    ///
    /// Panics if two handlers are registered for the same message type (see `add_handler`).
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            // Only the sender is kept; receivers are created later via `subscribe()`.
            teardown: watch::channel(false).0,
        };

        server
            .add_request_handler(ping)
            .add_request_handler(user_handler(create_room))
            .add_request_handler(user_handler(join_room))
            .add_request_handler(user_handler(rejoin_room))
            .add_request_handler(user_handler(leave_room))
            .add_request_handler(user_handler(set_room_participant_role))
            .add_request_handler(user_handler(call))
            .add_request_handler(user_handler(cancel_call))
            .add_message_handler(user_message_handler(decline_call))
            .add_request_handler(user_handler(update_participant_location))
            .add_request_handler(user_handler(share_project))
            .add_message_handler(unshare_project)
            .add_request_handler(user_handler(join_project))
            .add_request_handler(user_handler(join_hosted_project))
            .add_request_handler(user_handler(rejoin_dev_server_projects))
            .add_request_handler(user_handler(create_dev_server_project))
            .add_request_handler(user_handler(update_dev_server_project))
            .add_request_handler(user_handler(delete_dev_server_project))
            .add_request_handler(user_handler(create_dev_server))
            .add_request_handler(user_handler(regenerate_dev_server_token))
            .add_request_handler(user_handler(rename_dev_server))
            .add_request_handler(user_handler(delete_dev_server))
            .add_request_handler(user_handler(list_remote_directory))
            .add_request_handler(dev_server_handler(share_dev_server_project))
            .add_request_handler(dev_server_handler(shutdown_dev_server))
            .add_request_handler(dev_server_handler(reconnect_dev_server))
            .add_message_handler(user_message_handler(leave_project))
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            .add_request_handler(user_handler(
                forward_project_request_for_owner::<proto::TaskContextForLocation>,
            ))
            .add_request_handler(user_handler(
                forward_project_request_for_owner::<proto::TaskTemplates>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetHover>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetTypeDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetReferences>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::SearchProject>,
            ))
            .add_request_handler(user_handler(forward_find_search_candidates_request))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDocumentHighlights>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetProjectSymbols>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferForSymbol>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferById>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::SynchronizeBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::InlayHints>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferByPath>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCompletions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OpenNewBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCodeActions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCodeAction>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PrepareRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PerformRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ReloadBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::FormatBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CreateProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RenameProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CopyProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::DeleteProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ExpandProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OnTypeFormatting>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::SaveBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::BlameBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::MultiLspQuery>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RestartLanguageServers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::LinkedEditingRange>,
            ))
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            .add_request_handler(get_users)
            .add_request_handler(user_handler(fuzzy_search_users))
            .add_request_handler(user_handler(request_contact))
            .add_request_handler(user_handler(remove_contact))
            .add_request_handler(user_handler(respond_to_contact_request))
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(user_handler(create_channel))
            .add_request_handler(user_handler(delete_channel))
            .add_request_handler(user_handler(invite_channel_member))
            .add_request_handler(user_handler(remove_channel_member))
            .add_request_handler(user_handler(set_channel_member_role))
            .add_request_handler(user_handler(set_channel_visibility))
            .add_request_handler(user_handler(rename_channel))
            .add_request_handler(user_handler(join_channel_buffer))
            .add_request_handler(user_handler(leave_channel_buffer))
            .add_message_handler(user_message_handler(update_channel_buffer))
            .add_request_handler(user_handler(rejoin_channel_buffers))
            .add_request_handler(user_handler(get_channel_members))
            .add_request_handler(user_handler(respond_to_channel_invite))
            .add_request_handler(user_handler(join_channel))
            .add_request_handler(user_handler(join_channel_chat))
            .add_message_handler(user_message_handler(leave_channel_chat))
            .add_request_handler(user_handler(send_channel_message))
            .add_request_handler(user_handler(remove_channel_message))
            .add_request_handler(user_handler(update_channel_message))
            .add_request_handler(user_handler(get_channel_messages))
            .add_request_handler(user_handler(get_channel_messages_by_id))
            .add_request_handler(user_handler(get_notifications))
            .add_request_handler(user_handler(mark_notification_as_read))
            .add_request_handler(user_handler(move_channel))
            .add_request_handler(user_handler(follow))
            .add_message_handler(user_message_handler(unfollow))
            .add_message_handler(user_message_handler(update_followers))
            .add_request_handler(user_handler(get_private_user_info))
            .add_request_handler(user_handler(get_llm_api_token))
            .add_request_handler(user_handler(accept_terms_of_service))
            .add_message_handler(user_message_handler(acknowledge_channel_message))
            .add_message_handler(user_message_handler(acknowledge_buffer_version))
            .add_request_handler(user_handler(get_supermaven_api_key))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OpenContext>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CreateContext>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::SynchronizeContexts>,
            ))
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context)
            // The remaining handlers are closures because they need extra arguments taken from
            // the application state.
            .add_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    let app_state = app_state.clone();
                    async move {
                        count_language_model_tokens(request, response, session, &app_state.config)
                            .await
                    }
                }
            })
            .add_request_handler({
                user_handler(move |request, response, session| {
                    get_cached_embeddings(request, response, session)
                })
            })
            .add_request_handler({
                let app_state = app_state.clone();
                user_handler(move |request, response, session| {
                    compute_embeddings(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                    )
                })
            });

        Arc::new(server)
    }
650
    /// Spawns a detached background task that cleans up stale server resources, then returns
    /// `Ok(())` immediately.
    ///
    /// After sleeping for `CLEANUP_TIMEOUT` the task:
    /// 1. asks the database for stale room and channel ids,
    /// 2. clears stale channel buffer collaborators and notifies the remaining connections,
    /// 3. clears stale room participants, broadcasts the refreshed room, cancels pending calls
    ///    and pushes updated contact entries,
    /// 4. deletes the LiveKit room when a refreshed room has no participants left,
    /// 5. deletes the stale server records.
    ///
    /// Every fallible step goes through `trace_err()`, and a failed step is skipped rather than
    /// aborting the task.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Stale channel buffers: tell the remaining connections who is still
                    // collaborating.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Collected while refreshing the room; they keep their empty defaults
                        // when the refresh fails.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // The pool lock is scoped to this block so that it is released before
                        // the next `.await`.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push the refreshed contact entry of each affected user to all of their
                        // accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Only delete the LiveKit room when the refreshed room ended up empty.
                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                // Runs even when retrieving the stale resource ids failed.
                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
796
    /// Shuts the server down: tears down the peer, resets the connection pool, and finally
    /// publishes `true` on the teardown channel that connection tasks subscribe to.
    pub fn teardown(&self) {
        self.peer.teardown();
        self.connection_pool.lock().reset();
        // The result of the send is deliberately discarded.
        let _ = self.teardown.send(true);
    }
802
    /// Test-only: tears the server down, then brings it back under a new id by replacing the
    /// stored id, resetting the peer, and clearing the teardown flag.
    #[cfg(test)]
    pub fn reset(&self, id: ServerId) {
        self.teardown();
        *self.id.lock() = id;
        self.peer.reset(id.0 as u32);
        // Publish `false` again so that new connections are accepted after the reset.
        let _ = self.teardown.send(false);
    }
810
811 #[cfg(test)]
812 pub fn id(&self) -> ServerId {
813 *self.id.lock()
814 }
815
816 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
817 where
818 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
819 Fut: 'static + Send + Future<Output = Result<()>>,
820 M: EnvelopedMessage,
821 {
822 let prev_handler = self.handlers.insert(
823 TypeId::of::<M>(),
824 Box::new(move |envelope, session| {
825 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
826 let received_at = envelope.received_at;
827 tracing::info!("message received");
828 let start_time = Instant::now();
829 let future = (handler)(*envelope, session);
830 async move {
831 let result = future.await;
832 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
833 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
834 let queue_duration_ms = total_duration_ms - processing_duration_ms;
835 let payload_type = M::NAME;
836
837 match result {
838 Err(error) => {
839 tracing::error!(
840 ?error,
841 total_duration_ms,
842 processing_duration_ms,
843 queue_duration_ms,
844 payload_type,
845 "error handling message"
846 )
847 }
848 Ok(()) => tracing::info!(
849 total_duration_ms,
850 processing_duration_ms,
851 queue_duration_ms,
852 "finished handling message"
853 ),
854 }
855 }
856 .boxed()
857 }),
858 );
859 if prev_handler.is_some() {
860 panic!("registered a handler for the same message twice");
861 }
862 self
863 }
864
865 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
866 where
867 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
868 Fut: 'static + Send + Future<Output = Result<()>>,
869 M: EnvelopedMessage,
870 {
871 self.add_handler(move |envelope, session| handler(envelope.payload, session));
872 self
873 }
874
875 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
876 where
877 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
878 Fut: Send + Future<Output = Result<()>>,
879 M: RequestMessage,
880 {
881 let handler = Arc::new(handler);
882 self.add_handler(move |envelope, session| {
883 let receipt = envelope.receipt();
884 let handler = handler.clone();
885 async move {
886 let peer = session.peer.clone();
887 let responded = Arc::new(AtomicBool::default());
888 let response = Response {
889 peer: peer.clone(),
890 responded: responded.clone(),
891 receipt,
892 };
893 match (handler)(envelope.payload, response, session).await {
894 Ok(()) => {
895 if responded.load(std::sync::atomic::Ordering::SeqCst) {
896 Ok(())
897 } else {
898 Err(anyhow!("handler did not send a response"))?
899 }
900 }
901 Err(error) => {
902 let proto_err = match &error {
903 Error::Internal(err) => err.to_proto(),
904 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
905 };
906 peer.respond_with_error(receipt, proto_err)?;
907 Err(error)
908 }
909 }
910 }
911 })
912 }
913
914 #[allow(clippy::too_many_arguments)]
915 pub fn handle_connection(
916 self: &Arc<Self>,
917 connection: Connection,
918 address: String,
919 principal: Principal,
920 zed_version: ZedVersion,
921 geoip_country_code: Option<String>,
922 send_connection_id: Option<oneshot::Sender<ConnectionId>>,
923 executor: Executor,
924 ) -> impl Future<Output = ()> {
925 let this = self.clone();
926 let span = info_span!("handle connection", %address,
927 connection_id=field::Empty,
928 user_id=field::Empty,
929 login=field::Empty,
930 impersonator=field::Empty,
931 dev_server_id=field::Empty,
932 geoip_country_code=field::Empty
933 );
934 principal.update_span(&span);
935 if let Some(country_code) = geoip_country_code.as_ref() {
936 span.record("geoip_country_code", country_code);
937 }
938
939 let mut teardown = self.teardown.subscribe();
940 async move {
941 if *teardown.borrow() {
942 tracing::error!("server is tearing down");
943 return
944 }
945 let (connection_id, handle_io, mut incoming_rx) = this
946 .peer
947 .add_connection(connection, {
948 let executor = executor.clone();
949 move |duration| executor.sleep(duration)
950 });
951 tracing::Span::current().record("connection_id", format!("{}", connection_id));
952
953 tracing::info!("connection opened");
954
955 let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
956 let http_client = match IsahcHttpClient::builder().default_header("User-Agent", user_agent).build() {
957 Ok(http_client) => Arc::new(http_client),
958 Err(error) => {
959 tracing::error!(?error, "failed to create HTTP client");
960 return;
961 }
962 };
963
964 let supermaven_client = if let Some(supermaven_admin_api_key) = this.app_state.config.supermaven_admin_api_key.clone() {
965 Some(Arc::new(SupermavenAdminApi::new(
966 supermaven_admin_api_key.to_string(),
967 http_client.clone(),
968 )))
969 } else {
970 None
971 };
972
973 let session = Session {
974 principal: principal.clone(),
975 connection_id,
976 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
977 peer: this.peer.clone(),
978 connection_pool: this.connection_pool.clone(),
979 app_state: this.app_state.clone(),
980 http_client,
981 geoip_country_code,
982 _executor: executor.clone(),
983 supermaven_client,
984 };
985
986 if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
987 tracing::error!(?error, "failed to send initial client update");
988 return;
989 }
990
991 let handle_io = handle_io.fuse();
992 futures::pin_mut!(handle_io);
993
994 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
995 // This prevents deadlocks when e.g., client A performs a request to client B and
996 // client B performs a request to client A. If both clients stop processing further
997 // messages until their respective request completes, they won't have a chance to
998 // respond to the other client's request and cause a deadlock.
999 //
1000 // This arrangement ensures we will attempt to process earlier messages first, but fall
1001 // back to processing messages arrived later in the spirit of making progress.
1002 let mut foreground_message_handlers = FuturesUnordered::new();
1003 let concurrent_handlers = Arc::new(Semaphore::new(256));
1004 loop {
1005 let next_message = async {
1006 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
1007 let message = incoming_rx.next().await;
1008 (permit, message)
1009 }.fuse();
1010 futures::pin_mut!(next_message);
1011 futures::select_biased! {
1012 _ = teardown.changed().fuse() => return,
1013 result = handle_io => {
1014 if let Err(error) = result {
1015 tracing::error!(?error, "error handling I/O");
1016 }
1017 break;
1018 }
1019 _ = foreground_message_handlers.next() => {}
1020 next_message = next_message => {
1021 let (permit, message) = next_message;
1022 if let Some(message) = message {
1023 let type_name = message.payload_type_name();
1024 // note: we copy all the fields from the parent span so we can query them in the logs.
1025 // (https://github.com/tokio-rs/tracing/issues/2670).
1026 let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
1027 user_id=field::Empty,
1028 login=field::Empty,
1029 impersonator=field::Empty,
1030 dev_server_id=field::Empty
1031 );
1032 principal.update_span(&span);
1033 let span_enter = span.enter();
1034 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
1035 let is_background = message.is_background();
1036 let handle_message = (handler)(message, session.clone());
1037 drop(span_enter);
1038
1039 let handle_message = async move {
1040 handle_message.await;
1041 drop(permit);
1042 }.instrument(span);
1043 if is_background {
1044 executor.spawn_detached(handle_message);
1045 } else {
1046 foreground_message_handlers.push(handle_message);
1047 }
1048 } else {
1049 tracing::error!("no message handler");
1050 }
1051 } else {
1052 tracing::info!("connection closed");
1053 break;
1054 }
1055 }
1056 }
1057 }
1058
1059 drop(foreground_message_handlers);
1060 tracing::info!("signing out");
1061 if let Err(error) = connection_lost(session, teardown, executor).await {
1062 tracing::error!(?error, "error signing out");
1063 }
1064
1065 }.instrument(span)
1066 }
1067
/// Sends the first batch of messages to a newly accepted connection and registers
/// that connection in the connection pool.
///
/// Every connection is first sent a `proto::Hello` carrying its own peer id, after
/// which `send_connection_id` (if supplied) is handed the new `connection_id`. The
/// remaining work depends on `principal`:
///
/// * `Principal::User` / `Principal::Impersonated`: sends `proto::ShowContacts` if
///   the user has never connected before, calls `update_user_plan`, adds the
///   connection to the pool, sends the initial contacts update, optionally
///   subscribes the user to channels, sends the dev server projects update and any
///   pending incoming call, and finally calls `update_user_contacts`.
/// * `Principal::DevServer`: sends `proto::ShutdownDevServer` to a connection that
///   is already registered for the same dev server id (and removes it from the
///   pool), registers this connection instead, then sends
///   `proto::DevServerInstructions` and a dev server projects update for the dev
///   server's `user_id`.
///
/// Returns the first error from any send or database call propagated with `?`.
async fn send_initial_client_update(
    &self,
    connection_id: ConnectionId,
    principal: &Principal,
    zed_version: ZedVersion,
    mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
    session: &Session,
) -> Result<()> {
    self.peer.send(
        connection_id,
        proto::Hello {
            peer_id: Some(connection_id.into()),
        },
    )?;
    tracing::info!("sent hello message");
    if let Some(send_connection_id) = send_connection_id.take() {
        // The result is deliberately ignored: the receiving half may already be gone.
        let _ = send_connection_id.send(connection_id);
    }

    match principal {
        Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
            // First-ever connection for this user: show the contacts panel and
            // persist the flag so this only happens once.
            if !user.connected_once {
                self.peer.send(connection_id, proto::ShowContacts {})?;
                self.app_state
                    .db
                    .set_user_connected_once(user.id, true)
                    .await?;
            }

            update_user_plan(user.id, session).await?;

            // Both queries are awaited together; either failure aborts the update.
            let (contacts, dev_server_projects) = future::try_join(
                self.app_state.db.get_contacts(user.id),
                self.app_state.db.dev_server_projects_update(user.id),
            )
            .await?;

            {
                // The pool lock is confined to this block so that it is released
                // before the next `.await`.
                let mut pool = self.connection_pool.lock();
                pool.add_connection(connection_id, user.id, user.admin, zed_version);
                self.peer.send(
                    connection_id,
                    build_initial_contacts_update(contacts, &pool),
                )?;
            }

            if should_auto_subscribe_to_channels(zed_version) {
                subscribe_user_to_channels(user.id, session).await?;
            }

            send_dev_server_projects_update(user.id, dev_server_projects, session).await;

            // Forward a call that is already pending for this user, if any.
            if let Some(incoming_call) =
                self.app_state.db.incoming_call_for_user(user.id).await?
            {
                self.peer.send(connection_id, incoming_call)?;
            }

            update_user_contacts(user.id, &session).await?;
        }
        Principal::DevServer(dev_server) => {
            {
                // The pool lock is confined to this block so that it is released
                // before the database calls below.
                let mut pool = self.connection_pool.lock();
                // Only one connection per dev server id is kept: an existing one
                // is told to shut down and is dropped from the pool.
                if let Some(stale_connection_id) = pool.dev_server_connection_id(dev_server.id)
                {
                    self.peer.send(
                        stale_connection_id,
                        proto::ShutdownDevServer {
                            reason: Some(
                                "another dev server connected with the same token".to_string(),
                            ),
                        },
                    )?;
                    pool.remove_connection(stale_connection_id)?;
                };
                pool.add_dev_server(connection_id, dev_server.id, zed_version);
            }

            let projects = self
                .app_state
                .db
                .get_projects_for_dev_server(dev_server.id)
                .await?;
            self.peer
                .send(connection_id, proto::DevServerInstructions { projects })?;

            // The projects update is addressed to the user that owns the dev server.
            let status = self
                .app_state
                .db
                .dev_server_projects_update(dev_server.user_id)
                .await?;
            send_dev_server_projects_update(dev_server.user_id, status, &session).await;
        }
    }

    Ok(())
}
1165
1166 pub async fn invite_code_redeemed(
1167 self: &Arc<Self>,
1168 inviter_id: UserId,
1169 invitee_id: UserId,
1170 ) -> Result<()> {
1171 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1172 if let Some(code) = &user.invite_code {
1173 let pool = self.connection_pool.lock();
1174 let invitee_contact = contact_for_user(invitee_id, false, &pool);
1175 for connection_id in pool.user_connection_ids(inviter_id) {
1176 self.peer.send(
1177 connection_id,
1178 proto::UpdateContacts {
1179 contacts: vec![invitee_contact.clone()],
1180 ..Default::default()
1181 },
1182 )?;
1183 self.peer.send(
1184 connection_id,
1185 proto::UpdateInviteInfo {
1186 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1187 count: user.invite_count as u32,
1188 },
1189 )?;
1190 }
1191 }
1192 }
1193 Ok(())
1194 }
1195
1196 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1197 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1198 if let Some(invite_code) = &user.invite_code {
1199 let pool = self.connection_pool.lock();
1200 for connection_id in pool.user_connection_ids(user_id) {
1201 self.peer.send(
1202 connection_id,
1203 proto::UpdateInviteInfo {
1204 url: format!(
1205 "{}{}",
1206 self.app_state.config.invite_link_prefix, invite_code
1207 ),
1208 count: user.invite_count as u32,
1209 },
1210 )?;
1211 }
1212 }
1213 }
1214 Ok(())
1215 }
1216
1217 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1218 ServerSnapshot {
1219 connection_pool: ConnectionPoolGuard {
1220 guard: self.connection_pool.lock(),
1221 _not_send: PhantomData,
1222 },
1223 peer: &self.peer,
1224 }
1225 }
1226}
1227
1228impl<'a> Deref for ConnectionPoolGuard<'a> {
1229 type Target = ConnectionPool;
1230
1231 fn deref(&self) -> &Self::Target {
1232 &self.guard
1233 }
1234}
1235
1236impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1237 fn deref_mut(&mut self) -> &mut Self::Target {
1238 &mut self.guard
1239 }
1240}
1241
1242impl<'a> Drop for ConnectionPoolGuard<'a> {
1243 fn drop(&mut self) {
1244 #[cfg(test)]
1245 self.check_invariants();
1246 }
1247}
1248
1249fn broadcast<F>(
1250 sender_id: Option<ConnectionId>,
1251 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1252 mut f: F,
1253) where
1254 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1255{
1256 for receiver_id in receiver_ids {
1257 if Some(receiver_id) != sender_id {
1258 if let Err(error) = f(receiver_id) {
1259 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1260 }
1261 }
1262 }
1263}
1264
1265pub struct ProtocolVersion(u32);
1266
1267impl Header for ProtocolVersion {
1268 fn name() -> &'static HeaderName {
1269 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1270 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1271 }
1272
1273 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1274 where
1275 Self: Sized,
1276 I: Iterator<Item = &'i axum::http::HeaderValue>,
1277 {
1278 let version = values
1279 .next()
1280 .ok_or_else(axum::headers::Error::invalid)?
1281 .to_str()
1282 .map_err(|_| axum::headers::Error::invalid())?
1283 .parse()
1284 .map_err(|_| axum::headers::Error::invalid())?;
1285 Ok(Self(version))
1286 }
1287
1288 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1289 values.extend([self.0.to_string().parse().unwrap()]);
1290 }
1291}
1292
1293pub struct AppVersionHeader(SemanticVersion);
1294impl Header for AppVersionHeader {
1295 fn name() -> &'static HeaderName {
1296 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1297 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1298 }
1299
1300 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1301 where
1302 Self: Sized,
1303 I: Iterator<Item = &'i axum::http::HeaderValue>,
1304 {
1305 let version = values
1306 .next()
1307 .ok_or_else(axum::headers::Error::invalid)?
1308 .to_str()
1309 .map_err(|_| axum::headers::Error::invalid())?
1310 .parse()
1311 .map_err(|_| axum::headers::Error::invalid())?;
1312 Ok(Self(version))
1313 }
1314
1315 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1316 values.extend([self.0.to_string().parse().unwrap()]);
1317 }
1318}
1319
1320pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1321 Router::new()
1322 .route("/rpc", get(handle_websocket_request))
1323 .layer(
1324 ServiceBuilder::new()
1325 .layer(Extension(server.app_state.clone()))
1326 .layer(middleware::from_fn(auth::validate_header)),
1327 )
1328 .route("/metrics", get(handle_metrics))
1329 .layer(Extension(server))
1330}
1331
1332pub async fn handle_websocket_request(
1333 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1334 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1335 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1336 Extension(server): Extension<Arc<Server>>,
1337 Extension(principal): Extension<Principal>,
1338 country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1339 ws: WebSocketUpgrade,
1340) -> axum::response::Response {
1341 if protocol_version != rpc::PROTOCOL_VERSION {
1342 return (
1343 StatusCode::UPGRADE_REQUIRED,
1344 "client must be upgraded".to_string(),
1345 )
1346 .into_response();
1347 }
1348
1349 let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1350 return (
1351 StatusCode::UPGRADE_REQUIRED,
1352 "no version header found".to_string(),
1353 )
1354 .into_response();
1355 };
1356
1357 if !version.can_collaborate() {
1358 return (
1359 StatusCode::UPGRADE_REQUIRED,
1360 "client must be upgraded".to_string(),
1361 )
1362 .into_response();
1363 }
1364
1365 let socket_address = socket_address.to_string();
1366 ws.on_upgrade(move |socket| {
1367 let socket = socket
1368 .map_ok(to_tungstenite_message)
1369 .err_into()
1370 .with(|message| async move { to_axum_message(message) });
1371 let connection = Connection::new(Box::pin(socket));
1372 async move {
1373 server
1374 .handle_connection(
1375 connection,
1376 socket_address,
1377 principal,
1378 version,
1379 country_code_header.map(|header| header.to_string()),
1380 None,
1381 Executor::Production,
1382 )
1383 .await;
1384 }
1385 })
1386}
1387
1388pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1389 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1390 let connections_metric = CONNECTIONS_METRIC
1391 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1392
1393 let connections = server
1394 .connection_pool
1395 .lock()
1396 .connections()
1397 .filter(|connection| !connection.admin)
1398 .count();
1399 connections_metric.set(connections as _);
1400
1401 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1402 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1403 register_int_gauge!(
1404 "shared_projects",
1405 "number of open projects with one or more guests"
1406 )
1407 .unwrap()
1408 });
1409
1410 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1411 shared_projects_metric.set(shared_projects as _);
1412
1413 let encoder = prometheus::TextEncoder::new();
1414 let metric_families = prometheus::gather();
1415 let encoded_metrics = encoder
1416 .encode_to_string(&metric_families)
1417 .map_err(|err| anyhow!("{}", err))?;
1418 Ok(encoded_metrics)
1419}
1420
/// Cleans up after a connection has gone away.
///
/// The peer connection is disconnected and removed from the connection pool
/// immediately, and the database is told the connection was lost. The function
/// then waits for either `RECONNECT_TIMEOUT` to elapse or `teardown` to change:
///
/// * On timeout, the resources tied to the connection are released. For users this
///   means leaving the room and channel buffers, declining a pending call when the
///   user has no other connection in the pool, and updating the user's contacts.
///   For dev servers it means calling `lost_dev_server_connection`.
/// * If `teardown` changes first, that cleanup is skipped.
///
/// Returns an error if removing the connection from the pool fails, or if the
/// final contacts update / dev server cleanup fails.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // A failure here goes through `trace_err` instead of aborting the cleanup.
    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls the branches in the order written, so the sleep is
    // checked before the teardown signal.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            match &session.principal {
                Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
                    let session = session.for_user().unwrap();

                    log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
                    leave_room_for_session(&session, session.connection_id).await.trace_err();
                    leave_channel_buffers_for_session(&session)
                        .await
                        .trace_err();

                    // Only decline a pending call when the user has no other
                    // connection left in the pool.
                    if !session
                        .connection_pool()
                        .await
                        .is_user_online(session.user_id())
                    {
                        let db = session.db().await;
                        if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                            room_updated(&room, &session.peer);
                        }
                    }

                    update_user_contacts(session.user_id(), &session).await?;
                },
                Principal::DevServer(_) => {
                    lost_dev_server_connection(&session.for_dev_server().unwrap()).await?;
                },
            }
        },
        // Server teardown: skip the per-connection cleanup above.
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1475
1476/// Acknowledges a ping from a client, used to keep the connection alive.
1477async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1478 response.send(proto::Ack {})?;
1479 Ok(())
1480}
1481
1482/// Creates a new room for calling (outside of channels)
1483async fn create_room(
1484 _request: proto::CreateRoom,
1485 response: Response<proto::CreateRoom>,
1486 session: UserSession,
1487) -> Result<()> {
1488 let live_kit_room = nanoid::nanoid!(30);
1489
1490 let live_kit_connection_info = util::maybe!(async {
1491 let live_kit = session.app_state.live_kit_client.as_ref();
1492 let live_kit = live_kit?;
1493 let user_id = session.user_id().to_string();
1494
1495 let token = live_kit
1496 .room_token(&live_kit_room, &user_id.to_string())
1497 .trace_err()?;
1498
1499 Some(proto::LiveKitConnectionInfo {
1500 server_url: live_kit.url().into(),
1501 token,
1502 can_publish: true,
1503 })
1504 })
1505 .await;
1506
1507 let room = session
1508 .db()
1509 .await
1510 .create_room(session.user_id(), session.connection_id, &live_kit_room)
1511 .await?;
1512
1513 response.send(proto::CreateRoomResponse {
1514 room: Some(room.clone()),
1515 live_kit_connection_info,
1516 })?;
1517
1518 update_user_contacts(session.user_id(), &session).await?;
1519 Ok(())
1520}
1521
1522/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1523async fn join_room(
1524 request: proto::JoinRoom,
1525 response: Response<proto::JoinRoom>,
1526 session: UserSession,
1527) -> Result<()> {
1528 let room_id = RoomId::from_proto(request.id);
1529
1530 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1531
1532 if let Some(channel_id) = channel_id {
1533 return join_channel_internal(channel_id, Box::new(response), session).await;
1534 }
1535
1536 let joined_room = {
1537 let room = session
1538 .db()
1539 .await
1540 .join_room(room_id, session.user_id(), session.connection_id)
1541 .await?;
1542 room_updated(&room.room, &session.peer);
1543 room.into_inner()
1544 };
1545
1546 for connection_id in session
1547 .connection_pool()
1548 .await
1549 .user_connection_ids(session.user_id())
1550 {
1551 session
1552 .peer
1553 .send(
1554 connection_id,
1555 proto::CallCanceled {
1556 room_id: room_id.to_proto(),
1557 },
1558 )
1559 .trace_err();
1560 }
1561
1562 let live_kit_connection_info =
1563 if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
1564 if let Some(token) = live_kit
1565 .room_token(
1566 &joined_room.room.live_kit_room,
1567 &session.user_id().to_string(),
1568 )
1569 .trace_err()
1570 {
1571 Some(proto::LiveKitConnectionInfo {
1572 server_url: live_kit.url().into(),
1573 token,
1574 can_publish: true,
1575 })
1576 } else {
1577 None
1578 }
1579 } else {
1580 None
1581 };
1582
1583 response.send(proto::JoinRoomResponse {
1584 room: Some(joined_room.room),
1585 channel_id: None,
1586 live_kit_connection_info,
1587 })?;
1588
1589 update_user_contacts(session.user_id(), &session).await?;
1590 Ok(())
1591}
1592
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// The database is asked to rejoin the room for this user and connection. The
/// response lists the room, the projects this user had shared ("reshared") and the
/// projects the user had joined ("rejoined"). Collaborators on reshared projects
/// are told about the new peer id and sent the project's worktree metadata;
/// rejoined projects are handled by `notify_rejoined_projects`. If the room
/// belongs to a channel, `channel_updated` is called, and finally the user's
/// contacts are updated.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: UserSession,
) -> Result<()> {
    // Declared outside the block below so they outlive `rejoined_room`, which is
    // consumed inside it.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell each collaborator that the host's peer id changed from the old
            // connection to this one; send failures go through `trace_err`.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the project's worktree list to every collaborator other than
            // this connection.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1684
1685fn notify_rejoined_projects(
1686 rejoined_projects: &mut Vec<RejoinedProject>,
1687 session: &UserSession,
1688) -> Result<()> {
1689 for project in rejoined_projects.iter() {
1690 for collaborator in &project.collaborators {
1691 session
1692 .peer
1693 .send(
1694 collaborator.connection_id,
1695 proto::UpdateProjectCollaborator {
1696 project_id: project.id.to_proto(),
1697 old_peer_id: Some(project.old_connection_id.into()),
1698 new_peer_id: Some(session.connection_id.into()),
1699 },
1700 )
1701 .trace_err();
1702 }
1703 }
1704
1705 for project in rejoined_projects {
1706 for worktree in mem::take(&mut project.worktrees) {
1707 #[cfg(any(test, feature = "test-support"))]
1708 const MAX_CHUNK_SIZE: usize = 2;
1709 #[cfg(not(any(test, feature = "test-support")))]
1710 const MAX_CHUNK_SIZE: usize = 256;
1711
1712 // Stream this worktree's entries.
1713 let message = proto::UpdateWorktree {
1714 project_id: project.id.to_proto(),
1715 worktree_id: worktree.id,
1716 abs_path: worktree.abs_path.clone(),
1717 root_name: worktree.root_name,
1718 updated_entries: worktree.updated_entries,
1719 removed_entries: worktree.removed_entries,
1720 scan_id: worktree.scan_id,
1721 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1722 updated_repositories: worktree.updated_repositories,
1723 removed_repositories: worktree.removed_repositories,
1724 };
1725 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1726 session.peer.send(session.connection_id, update.clone())?;
1727 }
1728
1729 // Stream this worktree's diagnostics.
1730 for summary in worktree.diagnostic_summaries {
1731 session.peer.send(
1732 session.connection_id,
1733 proto::UpdateDiagnosticSummary {
1734 project_id: project.id.to_proto(),
1735 worktree_id: worktree.id,
1736 summary: Some(summary),
1737 },
1738 )?;
1739 }
1740
1741 for settings_file in worktree.settings_files {
1742 session.peer.send(
1743 session.connection_id,
1744 proto::UpdateWorktreeSettings {
1745 project_id: project.id.to_proto(),
1746 worktree_id: worktree.id,
1747 path: settings_file.path,
1748 content: Some(settings_file.content),
1749 },
1750 )?;
1751 }
1752 }
1753
1754 for language_server in &project.language_servers {
1755 session.peer.send(
1756 session.connection_id,
1757 proto::UpdateLanguageServer {
1758 project_id: project.id.to_proto(),
1759 language_server_id: language_server.id,
1760 variant: Some(
1761 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1762 proto::LspDiskBasedDiagnosticsUpdated {},
1763 ),
1764 ),
1765 },
1766 )?;
1767 }
1768 }
1769 Ok(())
1770}
1771
1772/// leave room disconnects from the room.
1773async fn leave_room(
1774 _: proto::LeaveRoom,
1775 response: Response<proto::LeaveRoom>,
1776 session: UserSession,
1777) -> Result<()> {
1778 leave_room_for_session(&session, session.connection_id).await?;
1779 response.send(proto::Ack {})?;
1780 Ok(())
1781}
1782
1783/// Updates the permissions of someone else in the room.
1784async fn set_room_participant_role(
1785 request: proto::SetRoomParticipantRole,
1786 response: Response<proto::SetRoomParticipantRole>,
1787 session: UserSession,
1788) -> Result<()> {
1789 let user_id = UserId::from_proto(request.user_id);
1790 let role = ChannelRole::from(request.role());
1791
1792 let (live_kit_room, can_publish) = {
1793 let room = session
1794 .db()
1795 .await
1796 .set_room_participant_role(
1797 session.user_id(),
1798 RoomId::from_proto(request.room_id),
1799 user_id,
1800 role,
1801 )
1802 .await?;
1803
1804 let live_kit_room = room.live_kit_room.clone();
1805 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1806 room_updated(&room, &session.peer);
1807 (live_kit_room, can_publish)
1808 };
1809
1810 if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
1811 live_kit
1812 .update_participant(
1813 live_kit_room.clone(),
1814 request.user_id.to_string(),
1815 live_kit_server::proto::ParticipantPermission {
1816 can_subscribe: true,
1817 can_publish,
1818 can_publish_data: can_publish,
1819 hidden: false,
1820 recorder: false,
1821 },
1822 )
1823 .await
1824 .trace_err();
1825 }
1826
1827 response.send(proto::Ack {})?;
1828 Ok(())
1829}
1830
1831/// Call someone else into the current room
1832async fn call(
1833 request: proto::Call,
1834 response: Response<proto::Call>,
1835 session: UserSession,
1836) -> Result<()> {
1837 let room_id = RoomId::from_proto(request.room_id);
1838 let calling_user_id = session.user_id();
1839 let calling_connection_id = session.connection_id;
1840 let called_user_id = UserId::from_proto(request.called_user_id);
1841 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1842 if !session
1843 .db()
1844 .await
1845 .has_contact(calling_user_id, called_user_id)
1846 .await?
1847 {
1848 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1849 }
1850
1851 let incoming_call = {
1852 let (room, incoming_call) = &mut *session
1853 .db()
1854 .await
1855 .call(
1856 room_id,
1857 calling_user_id,
1858 calling_connection_id,
1859 called_user_id,
1860 initial_project_id,
1861 )
1862 .await?;
1863 room_updated(&room, &session.peer);
1864 mem::take(incoming_call)
1865 };
1866 update_user_contacts(called_user_id, &session).await?;
1867
1868 let mut calls = session
1869 .connection_pool()
1870 .await
1871 .user_connection_ids(called_user_id)
1872 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1873 .collect::<FuturesUnordered<_>>();
1874
1875 while let Some(call_response) = calls.next().await {
1876 match call_response.as_ref() {
1877 Ok(_) => {
1878 response.send(proto::Ack {})?;
1879 return Ok(());
1880 }
1881 Err(_) => {
1882 call_response.trace_err();
1883 }
1884 }
1885 }
1886
1887 {
1888 let room = session
1889 .db()
1890 .await
1891 .call_failed(room_id, called_user_id)
1892 .await?;
1893 room_updated(&room, &session.peer);
1894 }
1895 update_user_contacts(called_user_id, &session).await?;
1896
1897 Err(anyhow!("failed to ring user"))?
1898}
1899
1900/// Cancel an outgoing call.
1901async fn cancel_call(
1902 request: proto::CancelCall,
1903 response: Response<proto::CancelCall>,
1904 session: UserSession,
1905) -> Result<()> {
1906 let called_user_id = UserId::from_proto(request.called_user_id);
1907 let room_id = RoomId::from_proto(request.room_id);
1908 {
1909 let room = session
1910 .db()
1911 .await
1912 .cancel_call(room_id, session.connection_id, called_user_id)
1913 .await?;
1914 room_updated(&room, &session.peer);
1915 }
1916
1917 for connection_id in session
1918 .connection_pool()
1919 .await
1920 .user_connection_ids(called_user_id)
1921 {
1922 session
1923 .peer
1924 .send(
1925 connection_id,
1926 proto::CallCanceled {
1927 room_id: room_id.to_proto(),
1928 },
1929 )
1930 .trace_err();
1931 }
1932 response.send(proto::Ack {})?;
1933
1934 update_user_contacts(called_user_id, &session).await?;
1935 Ok(())
1936}
1937
1938/// Decline an incoming call.
1939async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1940 let room_id = RoomId::from_proto(message.room_id);
1941 {
1942 let room = session
1943 .db()
1944 .await
1945 .decline_call(Some(room_id), session.user_id())
1946 .await?
1947 .ok_or_else(|| anyhow!("failed to decline call"))?;
1948 room_updated(&room, &session.peer);
1949 }
1950
1951 for connection_id in session
1952 .connection_pool()
1953 .await
1954 .user_connection_ids(session.user_id())
1955 {
1956 session
1957 .peer
1958 .send(
1959 connection_id,
1960 proto::CallCanceled {
1961 room_id: room_id.to_proto(),
1962 },
1963 )
1964 .trace_err();
1965 }
1966 update_user_contacts(session.user_id(), &session).await?;
1967 Ok(())
1968}
1969
1970/// Updates other participants in the room with your current location.
1971async fn update_participant_location(
1972 request: proto::UpdateParticipantLocation,
1973 response: Response<proto::UpdateParticipantLocation>,
1974 session: UserSession,
1975) -> Result<()> {
1976 let room_id = RoomId::from_proto(request.room_id);
1977 let location = request
1978 .location
1979 .ok_or_else(|| anyhow!("invalid location"))?;
1980
1981 let db = session.db().await;
1982 let room = db
1983 .update_room_participant_location(room_id, session.connection_id, location)
1984 .await?;
1985
1986 room_updated(&room, &session.peer);
1987 response.send(proto::Ack {})?;
1988 Ok(())
1989}
1990
1991/// Share a project into the room.
1992async fn share_project(
1993 request: proto::ShareProject,
1994 response: Response<proto::ShareProject>,
1995 session: UserSession,
1996) -> Result<()> {
1997 let (project_id, room) = &*session
1998 .db()
1999 .await
2000 .share_project(
2001 RoomId::from_proto(request.room_id),
2002 session.connection_id,
2003 &request.worktrees,
2004 request
2005 .dev_server_project_id
2006 .map(|id| DevServerProjectId::from_proto(id)),
2007 )
2008 .await?;
2009 response.send(proto::ShareProjectResponse {
2010 project_id: project_id.to_proto(),
2011 })?;
2012 room_updated(&room, &session.peer);
2013
2014 Ok(())
2015}
2016
2017/// Unshare a project from the room.
2018async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
2019 let project_id = ProjectId::from_proto(message.project_id);
2020 unshare_project_internal(
2021 project_id,
2022 session.connection_id,
2023 session.user_id(),
2024 &session,
2025 )
2026 .await
2027}
2028
2029async fn unshare_project_internal(
2030 project_id: ProjectId,
2031 connection_id: ConnectionId,
2032 user_id: Option<UserId>,
2033 session: &Session,
2034) -> Result<()> {
2035 let delete = {
2036 let room_guard = session
2037 .db()
2038 .await
2039 .unshare_project(project_id, connection_id, user_id)
2040 .await?;
2041
2042 let (delete, room, guest_connection_ids) = &*room_guard;
2043
2044 let message = proto::UnshareProject {
2045 project_id: project_id.to_proto(),
2046 };
2047
2048 broadcast(
2049 Some(connection_id),
2050 guest_connection_ids.iter().copied(),
2051 |conn_id| session.peer.send(conn_id, message.clone()),
2052 );
2053 if let Some(room) = room {
2054 room_updated(room, &session.peer);
2055 }
2056
2057 *delete
2058 };
2059
2060 if delete {
2061 let db = session.db().await;
2062 db.delete_project(project_id).await?;
2063 }
2064
2065 Ok(())
2066}
2067
2068/// DevServer makes a project available online
2069async fn share_dev_server_project(
2070 request: proto::ShareDevServerProject,
2071 response: Response<proto::ShareDevServerProject>,
2072 session: DevServerSession,
2073) -> Result<()> {
2074 let (dev_server_project, user_id, status) = session
2075 .db()
2076 .await
2077 .share_dev_server_project(
2078 DevServerProjectId::from_proto(request.dev_server_project_id),
2079 session.dev_server_id(),
2080 session.connection_id,
2081 &request.worktrees,
2082 )
2083 .await?;
2084 let Some(project_id) = dev_server_project.project_id else {
2085 return Err(anyhow!("failed to share remote project"))?;
2086 };
2087
2088 send_dev_server_projects_update(user_id, status, &session).await;
2089
2090 response.send(proto::ShareProjectResponse { project_id })?;
2091
2092 Ok(())
2093}
2094
2095/// Join someone elses shared project.
2096async fn join_project(
2097 request: proto::JoinProject,
2098 response: Response<proto::JoinProject>,
2099 session: UserSession,
2100) -> Result<()> {
2101 let project_id = ProjectId::from_proto(request.project_id);
2102
2103 tracing::info!(%project_id, "join project");
2104
2105 let db = session.db().await;
2106 let (project, replica_id) = &mut *db
2107 .join_project(project_id, session.connection_id, session.user_id())
2108 .await?;
2109 drop(db);
2110 tracing::info!(%project_id, "join remote project");
2111 join_project_internal(response, session, project, replica_id)
2112}
2113
2114trait JoinProjectInternalResponse {
2115 fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
2116}
2117impl JoinProjectInternalResponse for Response<proto::JoinProject> {
2118 fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
2119 Response::<proto::JoinProject>::send(self, result)
2120 }
2121}
2122impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
2123 fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
2124 Response::<proto::JoinHostedProject>::send(self, result)
2125 }
2126}
2127
2128fn join_project_internal(
2129 response: impl JoinProjectInternalResponse,
2130 session: UserSession,
2131 project: &mut Project,
2132 replica_id: &ReplicaId,
2133) -> Result<()> {
2134 let collaborators = project
2135 .collaborators
2136 .iter()
2137 .filter(|collaborator| collaborator.connection_id != session.connection_id)
2138 .map(|collaborator| collaborator.to_proto())
2139 .collect::<Vec<_>>();
2140 let project_id = project.id;
2141 let guest_user_id = session.user_id();
2142
2143 let worktrees = project
2144 .worktrees
2145 .iter()
2146 .map(|(id, worktree)| proto::WorktreeMetadata {
2147 id: *id,
2148 root_name: worktree.root_name.clone(),
2149 visible: worktree.visible,
2150 abs_path: worktree.abs_path.clone(),
2151 })
2152 .collect::<Vec<_>>();
2153
2154 let add_project_collaborator = proto::AddProjectCollaborator {
2155 project_id: project_id.to_proto(),
2156 collaborator: Some(proto::Collaborator {
2157 peer_id: Some(session.connection_id.into()),
2158 replica_id: replica_id.0 as u32,
2159 user_id: guest_user_id.to_proto(),
2160 }),
2161 };
2162
2163 for collaborator in &collaborators {
2164 session
2165 .peer
2166 .send(
2167 collaborator.peer_id.unwrap().into(),
2168 add_project_collaborator.clone(),
2169 )
2170 .trace_err();
2171 }
2172
2173 // First, we send the metadata associated with each worktree.
2174 response.send(proto::JoinProjectResponse {
2175 project_id: project.id.0 as u64,
2176 worktrees: worktrees.clone(),
2177 replica_id: replica_id.0 as u32,
2178 collaborators: collaborators.clone(),
2179 language_servers: project.language_servers.clone(),
2180 role: project.role.into(),
2181 dev_server_project_id: project
2182 .dev_server_project_id
2183 .map(|dev_server_project_id| dev_server_project_id.0 as u64),
2184 })?;
2185
2186 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2187 #[cfg(any(test, feature = "test-support"))]
2188 const MAX_CHUNK_SIZE: usize = 2;
2189 #[cfg(not(any(test, feature = "test-support")))]
2190 const MAX_CHUNK_SIZE: usize = 256;
2191
2192 // Stream this worktree's entries.
2193 let message = proto::UpdateWorktree {
2194 project_id: project_id.to_proto(),
2195 worktree_id,
2196 abs_path: worktree.abs_path.clone(),
2197 root_name: worktree.root_name,
2198 updated_entries: worktree.entries,
2199 removed_entries: Default::default(),
2200 scan_id: worktree.scan_id,
2201 is_last_update: worktree.scan_id == worktree.completed_scan_id,
2202 updated_repositories: worktree.repository_entries.into_values().collect(),
2203 removed_repositories: Default::default(),
2204 };
2205 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
2206 session.peer.send(session.connection_id, update.clone())?;
2207 }
2208
2209 // Stream this worktree's diagnostics.
2210 for summary in worktree.diagnostic_summaries {
2211 session.peer.send(
2212 session.connection_id,
2213 proto::UpdateDiagnosticSummary {
2214 project_id: project_id.to_proto(),
2215 worktree_id: worktree.id,
2216 summary: Some(summary),
2217 },
2218 )?;
2219 }
2220
2221 for settings_file in worktree.settings_files {
2222 session.peer.send(
2223 session.connection_id,
2224 proto::UpdateWorktreeSettings {
2225 project_id: project_id.to_proto(),
2226 worktree_id: worktree.id,
2227 path: settings_file.path,
2228 content: Some(settings_file.content),
2229 },
2230 )?;
2231 }
2232 }
2233
2234 for language_server in &project.language_servers {
2235 session.peer.send(
2236 session.connection_id,
2237 proto::UpdateLanguageServer {
2238 project_id: project_id.to_proto(),
2239 language_server_id: language_server.id,
2240 variant: Some(
2241 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2242 proto::LspDiskBasedDiagnosticsUpdated {},
2243 ),
2244 ),
2245 },
2246 )?;
2247 }
2248
2249 Ok(())
2250}
2251
2252/// Leave someone elses shared project.
2253async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2254 let sender_id = session.connection_id;
2255 let project_id = ProjectId::from_proto(request.project_id);
2256 let db = session.db().await;
2257 if db.is_hosted_project(project_id).await? {
2258 let project = db.leave_hosted_project(project_id, sender_id).await?;
2259 project_left(&project, &session);
2260 return Ok(());
2261 }
2262
2263 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2264 tracing::info!(
2265 %project_id,
2266 "leave project"
2267 );
2268
2269 project_left(&project, &session);
2270 if let Some(room) = room {
2271 room_updated(&room, &session.peer);
2272 }
2273
2274 Ok(())
2275}
2276
2277async fn join_hosted_project(
2278 request: proto::JoinHostedProject,
2279 response: Response<proto::JoinHostedProject>,
2280 session: UserSession,
2281) -> Result<()> {
2282 let (mut project, replica_id) = session
2283 .db()
2284 .await
2285 .join_hosted_project(
2286 ProjectId(request.project_id as i32),
2287 session.user_id(),
2288 session.connection_id,
2289 )
2290 .await?;
2291
2292 join_project_internal(response, session, &mut project, &replica_id)
2293}
2294
2295async fn list_remote_directory(
2296 request: proto::ListRemoteDirectory,
2297 response: Response<proto::ListRemoteDirectory>,
2298 session: UserSession,
2299) -> Result<()> {
2300 let dev_server_id = DevServerId(request.dev_server_id as i32);
2301 let dev_server_connection_id = session
2302 .connection_pool()
2303 .await
2304 .dev_server_connection_id_supporting(dev_server_id, ZedVersion::with_list_directory())?;
2305
2306 session
2307 .db()
2308 .await
2309 .get_dev_server_for_user(dev_server_id, session.user_id())
2310 .await?;
2311
2312 response.send(
2313 session
2314 .peer
2315 .forward_request(session.connection_id, dev_server_connection_id, request)
2316 .await?,
2317 )?;
2318 Ok(())
2319}
2320
2321async fn update_dev_server_project(
2322 request: proto::UpdateDevServerProject,
2323 response: Response<proto::UpdateDevServerProject>,
2324 session: UserSession,
2325) -> Result<()> {
2326 let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2327
2328 let (dev_server_project, update) = session
2329 .db()
2330 .await
2331 .update_dev_server_project(dev_server_project_id, &request.paths, session.user_id())
2332 .await?;
2333
2334 let projects = session
2335 .db()
2336 .await
2337 .get_projects_for_dev_server(dev_server_project.dev_server_id)
2338 .await?;
2339
2340 let dev_server_connection_id = session
2341 .connection_pool()
2342 .await
2343 .dev_server_connection_id_supporting(
2344 dev_server_project.dev_server_id,
2345 ZedVersion::with_list_directory(),
2346 )?;
2347
2348 session.peer.send(
2349 dev_server_connection_id,
2350 proto::DevServerInstructions { projects },
2351 )?;
2352
2353 send_dev_server_projects_update(session.user_id(), update, &session).await;
2354
2355 response.send(proto::Ack {})
2356}
2357
2358async fn create_dev_server_project(
2359 request: proto::CreateDevServerProject,
2360 response: Response<proto::CreateDevServerProject>,
2361 session: UserSession,
2362) -> Result<()> {
2363 let dev_server_id = DevServerId(request.dev_server_id as i32);
2364 let dev_server_connection_id = session
2365 .connection_pool()
2366 .await
2367 .dev_server_connection_id(dev_server_id);
2368 let Some(dev_server_connection_id) = dev_server_connection_id else {
2369 Err(ErrorCode::DevServerOffline
2370 .message("Cannot create a remote project when the dev server is offline".to_string())
2371 .anyhow())?
2372 };
2373
2374 let path = request.path.clone();
2375 //Check that the path exists on the dev server
2376 session
2377 .peer
2378 .forward_request(
2379 session.connection_id,
2380 dev_server_connection_id,
2381 proto::ValidateDevServerProjectRequest { path: path.clone() },
2382 )
2383 .await?;
2384
2385 let (dev_server_project, update) = session
2386 .db()
2387 .await
2388 .create_dev_server_project(
2389 DevServerId(request.dev_server_id as i32),
2390 &request.path,
2391 session.user_id(),
2392 )
2393 .await?;
2394
2395 let projects = session
2396 .db()
2397 .await
2398 .get_projects_for_dev_server(dev_server_project.dev_server_id)
2399 .await?;
2400
2401 session.peer.send(
2402 dev_server_connection_id,
2403 proto::DevServerInstructions { projects },
2404 )?;
2405
2406 send_dev_server_projects_update(session.user_id(), update, &session).await;
2407
2408 response.send(proto::CreateDevServerProjectResponse {
2409 dev_server_project: Some(dev_server_project.to_proto(None)),
2410 })?;
2411 Ok(())
2412}
2413
2414async fn create_dev_server(
2415 request: proto::CreateDevServer,
2416 response: Response<proto::CreateDevServer>,
2417 session: UserSession,
2418) -> Result<()> {
2419 let access_token = auth::random_token();
2420 let hashed_access_token = auth::hash_access_token(&access_token);
2421
2422 if request.name.is_empty() {
2423 return Err(proto::ErrorCode::Forbidden
2424 .message("Dev server name cannot be empty".to_string())
2425 .anyhow())?;
2426 }
2427
2428 let (dev_server, status) = session
2429 .db()
2430 .await
2431 .create_dev_server(
2432 &request.name,
2433 request.ssh_connection_string.as_deref(),
2434 &hashed_access_token,
2435 session.user_id(),
2436 )
2437 .await?;
2438
2439 send_dev_server_projects_update(session.user_id(), status, &session).await;
2440
2441 response.send(proto::CreateDevServerResponse {
2442 dev_server_id: dev_server.id.0 as u64,
2443 access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2444 name: request.name,
2445 })?;
2446 Ok(())
2447}
2448
2449async fn regenerate_dev_server_token(
2450 request: proto::RegenerateDevServerToken,
2451 response: Response<proto::RegenerateDevServerToken>,
2452 session: UserSession,
2453) -> Result<()> {
2454 let dev_server_id = DevServerId(request.dev_server_id as i32);
2455 let access_token = auth::random_token();
2456 let hashed_access_token = auth::hash_access_token(&access_token);
2457
2458 let connection_id = session
2459 .connection_pool()
2460 .await
2461 .dev_server_connection_id(dev_server_id);
2462 if let Some(connection_id) = connection_id {
2463 shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2464 session.peer.send(
2465 connection_id,
2466 proto::ShutdownDevServer {
2467 reason: Some("dev server token was regenerated".to_string()),
2468 },
2469 )?;
2470 let _ = remove_dev_server_connection(dev_server_id, &session).await;
2471 }
2472
2473 let status = session
2474 .db()
2475 .await
2476 .update_dev_server_token(dev_server_id, &hashed_access_token, session.user_id())
2477 .await?;
2478
2479 send_dev_server_projects_update(session.user_id(), status, &session).await;
2480
2481 response.send(proto::RegenerateDevServerTokenResponse {
2482 dev_server_id: dev_server_id.to_proto(),
2483 access_token: auth::generate_dev_server_token(dev_server_id.0 as usize, access_token),
2484 })?;
2485 Ok(())
2486}
2487
2488async fn rename_dev_server(
2489 request: proto::RenameDevServer,
2490 response: Response<proto::RenameDevServer>,
2491 session: UserSession,
2492) -> Result<()> {
2493 if request.name.trim().is_empty() {
2494 return Err(proto::ErrorCode::Forbidden
2495 .message("Dev server name cannot be empty".to_string())
2496 .anyhow())?;
2497 }
2498
2499 let dev_server_id = DevServerId(request.dev_server_id as i32);
2500 let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2501 if dev_server.user_id != session.user_id() {
2502 return Err(anyhow!(ErrorCode::Forbidden))?;
2503 }
2504
2505 let status = session
2506 .db()
2507 .await
2508 .rename_dev_server(
2509 dev_server_id,
2510 &request.name,
2511 request.ssh_connection_string.as_deref(),
2512 session.user_id(),
2513 )
2514 .await?;
2515
2516 send_dev_server_projects_update(session.user_id(), status, &session).await;
2517
2518 response.send(proto::Ack {})?;
2519 Ok(())
2520}
2521
2522async fn delete_dev_server(
2523 request: proto::DeleteDevServer,
2524 response: Response<proto::DeleteDevServer>,
2525 session: UserSession,
2526) -> Result<()> {
2527 let dev_server_id = DevServerId(request.dev_server_id as i32);
2528 let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2529 if dev_server.user_id != session.user_id() {
2530 return Err(anyhow!(ErrorCode::Forbidden))?;
2531 }
2532
2533 let connection_id = session
2534 .connection_pool()
2535 .await
2536 .dev_server_connection_id(dev_server_id);
2537 if let Some(connection_id) = connection_id {
2538 shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2539 session.peer.send(
2540 connection_id,
2541 proto::ShutdownDevServer {
2542 reason: Some("dev server was deleted".to_string()),
2543 },
2544 )?;
2545 let _ = remove_dev_server_connection(dev_server_id, &session).await;
2546 }
2547
2548 let status = session
2549 .db()
2550 .await
2551 .delete_dev_server(dev_server_id, session.user_id())
2552 .await?;
2553
2554 send_dev_server_projects_update(session.user_id(), status, &session).await;
2555
2556 response.send(proto::Ack {})?;
2557 Ok(())
2558}
2559
2560async fn delete_dev_server_project(
2561 request: proto::DeleteDevServerProject,
2562 response: Response<proto::DeleteDevServerProject>,
2563 session: UserSession,
2564) -> Result<()> {
2565 let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2566 let dev_server_project = session
2567 .db()
2568 .await
2569 .get_dev_server_project(dev_server_project_id)
2570 .await?;
2571
2572 let dev_server = session
2573 .db()
2574 .await
2575 .get_dev_server(dev_server_project.dev_server_id)
2576 .await?;
2577 if dev_server.user_id != session.user_id() {
2578 return Err(anyhow!(ErrorCode::Forbidden))?;
2579 }
2580
2581 let dev_server_connection_id = session
2582 .connection_pool()
2583 .await
2584 .dev_server_connection_id(dev_server.id);
2585
2586 if let Some(dev_server_connection_id) = dev_server_connection_id {
2587 let project = session
2588 .db()
2589 .await
2590 .find_dev_server_project(dev_server_project_id)
2591 .await;
2592 if let Ok(project) = project {
2593 unshare_project_internal(
2594 project.id,
2595 dev_server_connection_id,
2596 Some(session.user_id()),
2597 &session,
2598 )
2599 .await?;
2600 }
2601 }
2602
2603 let (projects, status) = session
2604 .db()
2605 .await
2606 .delete_dev_server_project(dev_server_project_id, dev_server.id, session.user_id())
2607 .await?;
2608
2609 if let Some(dev_server_connection_id) = dev_server_connection_id {
2610 session.peer.send(
2611 dev_server_connection_id,
2612 proto::DevServerInstructions { projects },
2613 )?;
2614 }
2615
2616 send_dev_server_projects_update(session.user_id(), status, &session).await;
2617
2618 response.send(proto::Ack {})?;
2619 Ok(())
2620}
2621
2622async fn rejoin_dev_server_projects(
2623 request: proto::RejoinRemoteProjects,
2624 response: Response<proto::RejoinRemoteProjects>,
2625 session: UserSession,
2626) -> Result<()> {
2627 let mut rejoined_projects = {
2628 let db = session.db().await;
2629 db.rejoin_dev_server_projects(
2630 &request.rejoined_projects,
2631 session.user_id(),
2632 session.0.connection_id,
2633 )
2634 .await?
2635 };
2636 response.send(proto::RejoinRemoteProjectsResponse {
2637 rejoined_projects: rejoined_projects
2638 .iter()
2639 .map(|project| project.to_proto())
2640 .collect(),
2641 })?;
2642 notify_rejoined_projects(&mut rejoined_projects, &session)
2643}
2644
2645async fn reconnect_dev_server(
2646 request: proto::ReconnectDevServer,
2647 response: Response<proto::ReconnectDevServer>,
2648 session: DevServerSession,
2649) -> Result<()> {
2650 let reshared_projects = {
2651 let db = session.db().await;
2652 db.reshare_dev_server_projects(
2653 &request.reshared_projects,
2654 session.dev_server_id(),
2655 session.0.connection_id,
2656 )
2657 .await?
2658 };
2659
2660 for project in &reshared_projects {
2661 for collaborator in &project.collaborators {
2662 session
2663 .peer
2664 .send(
2665 collaborator.connection_id,
2666 proto::UpdateProjectCollaborator {
2667 project_id: project.id.to_proto(),
2668 old_peer_id: Some(project.old_connection_id.into()),
2669 new_peer_id: Some(session.connection_id.into()),
2670 },
2671 )
2672 .trace_err();
2673 }
2674
2675 broadcast(
2676 Some(session.connection_id),
2677 project
2678 .collaborators
2679 .iter()
2680 .map(|collaborator| collaborator.connection_id),
2681 |connection_id| {
2682 session.peer.forward_send(
2683 session.connection_id,
2684 connection_id,
2685 proto::UpdateProject {
2686 project_id: project.id.to_proto(),
2687 worktrees: project.worktrees.clone(),
2688 },
2689 )
2690 },
2691 );
2692 }
2693
2694 response.send(proto::ReconnectDevServerResponse {
2695 reshared_projects: reshared_projects
2696 .iter()
2697 .map(|project| proto::ResharedProject {
2698 id: project.id.to_proto(),
2699 collaborators: project
2700 .collaborators
2701 .iter()
2702 .map(|collaborator| collaborator.to_proto())
2703 .collect(),
2704 })
2705 .collect(),
2706 })?;
2707
2708 Ok(())
2709}
2710
2711async fn shutdown_dev_server(
2712 _: proto::ShutdownDevServer,
2713 response: Response<proto::ShutdownDevServer>,
2714 session: DevServerSession,
2715) -> Result<()> {
2716 response.send(proto::Ack {})?;
2717 shutdown_dev_server_internal(session.dev_server_id(), session.connection_id, &session).await?;
2718 remove_dev_server_connection(session.dev_server_id(), &session).await
2719}
2720
2721async fn shutdown_dev_server_internal(
2722 dev_server_id: DevServerId,
2723 connection_id: ConnectionId,
2724 session: &Session,
2725) -> Result<()> {
2726 let (dev_server_projects, dev_server) = {
2727 let db = session.db().await;
2728 let dev_server_projects = db.get_projects_for_dev_server(dev_server_id).await?;
2729 let dev_server = db.get_dev_server(dev_server_id).await?;
2730 (dev_server_projects, dev_server)
2731 };
2732
2733 for project_id in dev_server_projects.iter().filter_map(|p| p.project_id) {
2734 unshare_project_internal(
2735 ProjectId::from_proto(project_id),
2736 connection_id,
2737 None,
2738 session,
2739 )
2740 .await?;
2741 }
2742
2743 session
2744 .connection_pool()
2745 .await
2746 .set_dev_server_offline(dev_server_id);
2747
2748 let status = session
2749 .db()
2750 .await
2751 .dev_server_projects_update(dev_server.user_id)
2752 .await?;
2753 send_dev_server_projects_update(dev_server.user_id, status, &session).await;
2754
2755 Ok(())
2756}
2757
2758async fn remove_dev_server_connection(dev_server_id: DevServerId, session: &Session) -> Result<()> {
2759 let dev_server_connection = session
2760 .connection_pool()
2761 .await
2762 .dev_server_connection_id(dev_server_id);
2763
2764 if let Some(dev_server_connection) = dev_server_connection {
2765 session
2766 .connection_pool()
2767 .await
2768 .remove_connection(dev_server_connection)?;
2769 }
2770 Ok(())
2771}
2772
2773/// Updates other participants with changes to the project
2774async fn update_project(
2775 request: proto::UpdateProject,
2776 response: Response<proto::UpdateProject>,
2777 session: Session,
2778) -> Result<()> {
2779 let project_id = ProjectId::from_proto(request.project_id);
2780 let (room, guest_connection_ids) = &*session
2781 .db()
2782 .await
2783 .update_project(project_id, session.connection_id, &request.worktrees)
2784 .await?;
2785 broadcast(
2786 Some(session.connection_id),
2787 guest_connection_ids.iter().copied(),
2788 |connection_id| {
2789 session
2790 .peer
2791 .forward_send(session.connection_id, connection_id, request.clone())
2792 },
2793 );
2794 if let Some(room) = room {
2795 room_updated(&room, &session.peer);
2796 }
2797 response.send(proto::Ack {})?;
2798
2799 Ok(())
2800}
2801
2802/// Updates other participants with changes to the worktree
2803async fn update_worktree(
2804 request: proto::UpdateWorktree,
2805 response: Response<proto::UpdateWorktree>,
2806 session: Session,
2807) -> Result<()> {
2808 let guest_connection_ids = session
2809 .db()
2810 .await
2811 .update_worktree(&request, session.connection_id)
2812 .await?;
2813
2814 broadcast(
2815 Some(session.connection_id),
2816 guest_connection_ids.iter().copied(),
2817 |connection_id| {
2818 session
2819 .peer
2820 .forward_send(session.connection_id, connection_id, request.clone())
2821 },
2822 );
2823 response.send(proto::Ack {})?;
2824 Ok(())
2825}
2826
2827/// Updates other participants with changes to the diagnostics
2828async fn update_diagnostic_summary(
2829 message: proto::UpdateDiagnosticSummary,
2830 session: Session,
2831) -> Result<()> {
2832 let guest_connection_ids = session
2833 .db()
2834 .await
2835 .update_diagnostic_summary(&message, session.connection_id)
2836 .await?;
2837
2838 broadcast(
2839 Some(session.connection_id),
2840 guest_connection_ids.iter().copied(),
2841 |connection_id| {
2842 session
2843 .peer
2844 .forward_send(session.connection_id, connection_id, message.clone())
2845 },
2846 );
2847
2848 Ok(())
2849}
2850
2851/// Updates other participants with changes to the worktree settings
2852async fn update_worktree_settings(
2853 message: proto::UpdateWorktreeSettings,
2854 session: Session,
2855) -> Result<()> {
2856 let guest_connection_ids = session
2857 .db()
2858 .await
2859 .update_worktree_settings(&message, session.connection_id)
2860 .await?;
2861
2862 broadcast(
2863 Some(session.connection_id),
2864 guest_connection_ids.iter().copied(),
2865 |connection_id| {
2866 session
2867 .peer
2868 .forward_send(session.connection_id, connection_id, message.clone())
2869 },
2870 );
2871
2872 Ok(())
2873}
2874
2875/// Notify other participants that a language server has started.
2876async fn start_language_server(
2877 request: proto::StartLanguageServer,
2878 session: Session,
2879) -> Result<()> {
2880 let guest_connection_ids = session
2881 .db()
2882 .await
2883 .start_language_server(&request, session.connection_id)
2884 .await?;
2885
2886 broadcast(
2887 Some(session.connection_id),
2888 guest_connection_ids.iter().copied(),
2889 |connection_id| {
2890 session
2891 .peer
2892 .forward_send(session.connection_id, connection_id, request.clone())
2893 },
2894 );
2895 Ok(())
2896}
2897
2898/// Notify other participants that a language server has changed.
2899async fn update_language_server(
2900 request: proto::UpdateLanguageServer,
2901 session: Session,
2902) -> Result<()> {
2903 let project_id = ProjectId::from_proto(request.project_id);
2904 let project_connection_ids = session
2905 .db()
2906 .await
2907 .project_connection_ids(project_id, session.connection_id, true)
2908 .await?;
2909 broadcast(
2910 Some(session.connection_id),
2911 project_connection_ids.iter().copied(),
2912 |connection_id| {
2913 session
2914 .peer
2915 .forward_send(session.connection_id, connection_id, request.clone())
2916 },
2917 );
2918 Ok(())
2919}
2920
2921/// forward a project request to the host. These requests should be read only
2922/// as guests are allowed to send them.
2923async fn forward_read_only_project_request<T>(
2924 request: T,
2925 response: Response<T>,
2926 session: UserSession,
2927) -> Result<()>
2928where
2929 T: EntityMessage + RequestMessage,
2930{
2931 let project_id = ProjectId::from_proto(request.remote_entity_id());
2932 let host_connection_id = session
2933 .db()
2934 .await
2935 .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2936 .await?;
2937 let payload = session
2938 .peer
2939 .forward_request(session.connection_id, host_connection_id, request)
2940 .await?;
2941 response.send(payload)?;
2942 Ok(())
2943}
2944
2945async fn forward_find_search_candidates_request(
2946 request: proto::FindSearchCandidates,
2947 response: Response<proto::FindSearchCandidates>,
2948 session: UserSession,
2949) -> Result<()> {
2950 let project_id = ProjectId::from_proto(request.remote_entity_id());
2951 let host_connection_id = session
2952 .db()
2953 .await
2954 .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2955 .await?;
2956
2957 let host_version = session
2958 .connection_pool()
2959 .await
2960 .connection(host_connection_id)
2961 .map(|c| c.zed_version);
2962
2963 if host_version.is_some_and(|host_version| host_version < ZedVersion::with_search_candidates())
2964 {
2965 let query = request.query.ok_or_else(|| anyhow!("missing query"))?;
2966 let search = proto::SearchProject {
2967 project_id: project_id.to_proto(),
2968 query: query.query,
2969 regex: query.regex,
2970 whole_word: query.whole_word,
2971 case_sensitive: query.case_sensitive,
2972 files_to_include: query.files_to_include,
2973 files_to_exclude: query.files_to_exclude,
2974 include_ignored: query.include_ignored,
2975 };
2976
2977 let payload = session
2978 .peer
2979 .forward_request(session.connection_id, host_connection_id, search)
2980 .await?;
2981 return response.send(proto::FindSearchCandidatesResponse {
2982 buffer_ids: payload
2983 .locations
2984 .into_iter()
2985 .map(|loc| loc.buffer_id)
2986 .collect(),
2987 });
2988 }
2989
2990 let payload = session
2991 .peer
2992 .forward_request(session.connection_id, host_connection_id, request)
2993 .await?;
2994 response.send(payload)?;
2995 Ok(())
2996}
2997
2998/// forward a project request to the dev server. Only allowed
2999/// if it's your dev server.
3000async fn forward_project_request_for_owner<T>(
3001 request: T,
3002 response: Response<T>,
3003 session: UserSession,
3004) -> Result<()>
3005where
3006 T: EntityMessage + RequestMessage,
3007{
3008 let project_id = ProjectId::from_proto(request.remote_entity_id());
3009
3010 let host_connection_id = session
3011 .db()
3012 .await
3013 .host_for_owner_project_request(project_id, session.connection_id, session.user_id())
3014 .await?;
3015 let payload = session
3016 .peer
3017 .forward_request(session.connection_id, host_connection_id, request)
3018 .await?;
3019 response.send(payload)?;
3020 Ok(())
3021}
3022
3023/// forward a project request to the host. These requests are disallowed
3024/// for guests.
3025async fn forward_mutating_project_request<T>(
3026 request: T,
3027 response: Response<T>,
3028 session: UserSession,
3029) -> Result<()>
3030where
3031 T: EntityMessage + RequestMessage,
3032{
3033 let project_id = ProjectId::from_proto(request.remote_entity_id());
3034
3035 let host_connection_id = session
3036 .db()
3037 .await
3038 .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
3039 .await?;
3040 let payload = session
3041 .peer
3042 .forward_request(session.connection_id, host_connection_id, request)
3043 .await?;
3044 response.send(payload)?;
3045 Ok(())
3046}
3047
3048/// Notify other participants that a new buffer has been created
3049async fn create_buffer_for_peer(
3050 request: proto::CreateBufferForPeer,
3051 session: Session,
3052) -> Result<()> {
3053 session
3054 .db()
3055 .await
3056 .check_user_is_project_host(
3057 ProjectId::from_proto(request.project_id),
3058 session.connection_id,
3059 )
3060 .await?;
3061 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
3062 session
3063 .peer
3064 .forward_send(session.connection_id, peer_id.into(), request)?;
3065 Ok(())
3066}
3067
3068/// Notify other participants that a buffer has been updated. This is
3069/// allowed for guests as long as the update is limited to selections.
3070async fn update_buffer(
3071 request: proto::UpdateBuffer,
3072 response: Response<proto::UpdateBuffer>,
3073 session: Session,
3074) -> Result<()> {
3075 let project_id = ProjectId::from_proto(request.project_id);
3076 let mut capability = Capability::ReadOnly;
3077
3078 for op in request.operations.iter() {
3079 match op.variant {
3080 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
3081 Some(_) => capability = Capability::ReadWrite,
3082 }
3083 }
3084
3085 let host = {
3086 let guard = session
3087 .db()
3088 .await
3089 .connections_for_buffer_update(
3090 project_id,
3091 session.principal_id(),
3092 session.connection_id,
3093 capability,
3094 )
3095 .await?;
3096
3097 let (host, guests) = &*guard;
3098
3099 broadcast(
3100 Some(session.connection_id),
3101 guests.clone(),
3102 |connection_id| {
3103 session
3104 .peer
3105 .forward_send(session.connection_id, connection_id, request.clone())
3106 },
3107 );
3108
3109 *host
3110 };
3111
3112 if host != session.connection_id {
3113 session
3114 .peer
3115 .forward_request(session.connection_id, host, request.clone())
3116 .await?;
3117 }
3118
3119 response.send(proto::Ack {})?;
3120 Ok(())
3121}
3122
3123async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
3124 let project_id = ProjectId::from_proto(message.project_id);
3125
3126 let operation = message.operation.as_ref().context("invalid operation")?;
3127 let capability = match operation.variant.as_ref() {
3128 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
3129 if let Some(buffer_op) = buffer_op.operation.as_ref() {
3130 match buffer_op.variant {
3131 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
3132 Capability::ReadOnly
3133 }
3134 _ => Capability::ReadWrite,
3135 }
3136 } else {
3137 Capability::ReadWrite
3138 }
3139 }
3140 Some(_) => Capability::ReadWrite,
3141 None => Capability::ReadOnly,
3142 };
3143
3144 let guard = session
3145 .db()
3146 .await
3147 .connections_for_buffer_update(
3148 project_id,
3149 session.principal_id(),
3150 session.connection_id,
3151 capability,
3152 )
3153 .await?;
3154
3155 let (host, guests) = &*guard;
3156
3157 broadcast(
3158 Some(session.connection_id),
3159 guests.iter().chain([host]).copied(),
3160 |connection_id| {
3161 session
3162 .peer
3163 .forward_send(session.connection_id, connection_id, message.clone())
3164 },
3165 );
3166
3167 Ok(())
3168}
3169
3170/// Notify other participants that a project has been updated.
3171async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
3172 request: T,
3173 session: Session,
3174) -> Result<()> {
3175 let project_id = ProjectId::from_proto(request.remote_entity_id());
3176 let project_connection_ids = session
3177 .db()
3178 .await
3179 .project_connection_ids(project_id, session.connection_id, false)
3180 .await?;
3181
3182 broadcast(
3183 Some(session.connection_id),
3184 project_connection_ids.iter().copied(),
3185 |connection_id| {
3186 session
3187 .peer
3188 .forward_send(session.connection_id, connection_id, request.clone())
3189 },
3190 );
3191 Ok(())
3192}
3193
3194/// Start following another user in a call.
3195async fn follow(
3196 request: proto::Follow,
3197 response: Response<proto::Follow>,
3198 session: UserSession,
3199) -> Result<()> {
3200 let room_id = RoomId::from_proto(request.room_id);
3201 let project_id = request.project_id.map(ProjectId::from_proto);
3202 let leader_id = request
3203 .leader_id
3204 .ok_or_else(|| anyhow!("invalid leader id"))?
3205 .into();
3206 let follower_id = session.connection_id;
3207
3208 session
3209 .db()
3210 .await
3211 .check_room_participants(room_id, leader_id, session.connection_id)
3212 .await?;
3213
3214 let response_payload = session
3215 .peer
3216 .forward_request(session.connection_id, leader_id, request)
3217 .await?;
3218 response.send(response_payload)?;
3219
3220 if let Some(project_id) = project_id {
3221 let room = session
3222 .db()
3223 .await
3224 .follow(room_id, project_id, leader_id, follower_id)
3225 .await?;
3226 room_updated(&room, &session.peer);
3227 }
3228
3229 Ok(())
3230}
3231
3232/// Stop following another user in a call.
3233async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
3234 let room_id = RoomId::from_proto(request.room_id);
3235 let project_id = request.project_id.map(ProjectId::from_proto);
3236 let leader_id = request
3237 .leader_id
3238 .ok_or_else(|| anyhow!("invalid leader id"))?
3239 .into();
3240 let follower_id = session.connection_id;
3241
3242 session
3243 .db()
3244 .await
3245 .check_room_participants(room_id, leader_id, session.connection_id)
3246 .await?;
3247
3248 session
3249 .peer
3250 .forward_send(session.connection_id, leader_id, request)?;
3251
3252 if let Some(project_id) = project_id {
3253 let room = session
3254 .db()
3255 .await
3256 .unfollow(room_id, project_id, leader_id, follower_id)
3257 .await?;
3258 room_updated(&room, &session.peer);
3259 }
3260
3261 Ok(())
3262}
3263
3264/// Notify everyone following you of your current location.
3265async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
3266 let room_id = RoomId::from_proto(request.room_id);
3267 let database = session.db.lock().await;
3268
3269 let connection_ids = if let Some(project_id) = request.project_id {
3270 let project_id = ProjectId::from_proto(project_id);
3271 database
3272 .project_connection_ids(project_id, session.connection_id, true)
3273 .await?
3274 } else {
3275 database
3276 .room_connection_ids(room_id, session.connection_id)
3277 .await?
3278 };
3279
3280 // For now, don't send view update messages back to that view's current leader.
3281 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
3282 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
3283 _ => None,
3284 });
3285
3286 for connection_id in connection_ids.iter().cloned() {
3287 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
3288 session
3289 .peer
3290 .forward_send(session.connection_id, connection_id, request.clone())?;
3291 }
3292 }
3293 Ok(())
3294}
3295
3296/// Get public data about users.
3297async fn get_users(
3298 request: proto::GetUsers,
3299 response: Response<proto::GetUsers>,
3300 session: Session,
3301) -> Result<()> {
3302 let user_ids = request
3303 .user_ids
3304 .into_iter()
3305 .map(UserId::from_proto)
3306 .collect();
3307 let users = session
3308 .db()
3309 .await
3310 .get_users_by_ids(user_ids)
3311 .await?
3312 .into_iter()
3313 .map(|user| proto::User {
3314 id: user.id.to_proto(),
3315 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3316 github_login: user.github_login,
3317 })
3318 .collect();
3319 response.send(proto::UsersResponse { users })?;
3320 Ok(())
3321}
3322
3323/// Search for users (to invite) buy Github login
3324async fn fuzzy_search_users(
3325 request: proto::FuzzySearchUsers,
3326 response: Response<proto::FuzzySearchUsers>,
3327 session: UserSession,
3328) -> Result<()> {
3329 let query = request.query;
3330 let users = match query.len() {
3331 0 => vec![],
3332 1 | 2 => session
3333 .db()
3334 .await
3335 .get_user_by_github_login(&query)
3336 .await?
3337 .into_iter()
3338 .collect(),
3339 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
3340 };
3341 let users = users
3342 .into_iter()
3343 .filter(|user| user.id != session.user_id())
3344 .map(|user| proto::User {
3345 id: user.id.to_proto(),
3346 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3347 github_login: user.github_login,
3348 })
3349 .collect();
3350 response.send(proto::UsersResponse { users })?;
3351 Ok(())
3352}
3353
3354/// Send a contact request to another user.
3355async fn request_contact(
3356 request: proto::RequestContact,
3357 response: Response<proto::RequestContact>,
3358 session: UserSession,
3359) -> Result<()> {
3360 let requester_id = session.user_id();
3361 let responder_id = UserId::from_proto(request.responder_id);
3362 if requester_id == responder_id {
3363 return Err(anyhow!("cannot add yourself as a contact"))?;
3364 }
3365
3366 let notifications = session
3367 .db()
3368 .await
3369 .send_contact_request(requester_id, responder_id)
3370 .await?;
3371
3372 // Update outgoing contact requests of requester
3373 let mut update = proto::UpdateContacts::default();
3374 update.outgoing_requests.push(responder_id.to_proto());
3375 for connection_id in session
3376 .connection_pool()
3377 .await
3378 .user_connection_ids(requester_id)
3379 {
3380 session.peer.send(connection_id, update.clone())?;
3381 }
3382
3383 // Update incoming contact requests of responder
3384 let mut update = proto::UpdateContacts::default();
3385 update
3386 .incoming_requests
3387 .push(proto::IncomingContactRequest {
3388 requester_id: requester_id.to_proto(),
3389 });
3390 let connection_pool = session.connection_pool().await;
3391 for connection_id in connection_pool.user_connection_ids(responder_id) {
3392 session.peer.send(connection_id, update.clone())?;
3393 }
3394
3395 send_notifications(&connection_pool, &session.peer, notifications);
3396
3397 response.send(proto::Ack {})?;
3398 Ok(())
3399}
3400
3401/// Accept or decline a contact request
3402async fn respond_to_contact_request(
3403 request: proto::RespondToContactRequest,
3404 response: Response<proto::RespondToContactRequest>,
3405 session: UserSession,
3406) -> Result<()> {
3407 let responder_id = session.user_id();
3408 let requester_id = UserId::from_proto(request.requester_id);
3409 let db = session.db().await;
3410 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
3411 db.dismiss_contact_notification(responder_id, requester_id)
3412 .await?;
3413 } else {
3414 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
3415
3416 let notifications = db
3417 .respond_to_contact_request(responder_id, requester_id, accept)
3418 .await?;
3419 let requester_busy = db.is_user_busy(requester_id).await?;
3420 let responder_busy = db.is_user_busy(responder_id).await?;
3421
3422 let pool = session.connection_pool().await;
3423 // Update responder with new contact
3424 let mut update = proto::UpdateContacts::default();
3425 if accept {
3426 update
3427 .contacts
3428 .push(contact_for_user(requester_id, requester_busy, &pool));
3429 }
3430 update
3431 .remove_incoming_requests
3432 .push(requester_id.to_proto());
3433 for connection_id in pool.user_connection_ids(responder_id) {
3434 session.peer.send(connection_id, update.clone())?;
3435 }
3436
3437 // Update requester with new contact
3438 let mut update = proto::UpdateContacts::default();
3439 if accept {
3440 update
3441 .contacts
3442 .push(contact_for_user(responder_id, responder_busy, &pool));
3443 }
3444 update
3445 .remove_outgoing_requests
3446 .push(responder_id.to_proto());
3447
3448 for connection_id in pool.user_connection_ids(requester_id) {
3449 session.peer.send(connection_id, update.clone())?;
3450 }
3451
3452 send_notifications(&pool, &session.peer, notifications);
3453 }
3454
3455 response.send(proto::Ack {})?;
3456 Ok(())
3457}
3458
3459/// Remove a contact.
3460async fn remove_contact(
3461 request: proto::RemoveContact,
3462 response: Response<proto::RemoveContact>,
3463 session: UserSession,
3464) -> Result<()> {
3465 let requester_id = session.user_id();
3466 let responder_id = UserId::from_proto(request.user_id);
3467 let db = session.db().await;
3468 let (contact_accepted, deleted_notification_id) =
3469 db.remove_contact(requester_id, responder_id).await?;
3470
3471 let pool = session.connection_pool().await;
3472 // Update outgoing contact requests of requester
3473 let mut update = proto::UpdateContacts::default();
3474 if contact_accepted {
3475 update.remove_contacts.push(responder_id.to_proto());
3476 } else {
3477 update
3478 .remove_outgoing_requests
3479 .push(responder_id.to_proto());
3480 }
3481 for connection_id in pool.user_connection_ids(requester_id) {
3482 session.peer.send(connection_id, update.clone())?;
3483 }
3484
3485 // Update incoming contact requests of responder
3486 let mut update = proto::UpdateContacts::default();
3487 if contact_accepted {
3488 update.remove_contacts.push(requester_id.to_proto());
3489 } else {
3490 update
3491 .remove_incoming_requests
3492 .push(requester_id.to_proto());
3493 }
3494 for connection_id in pool.user_connection_ids(responder_id) {
3495 session.peer.send(connection_id, update.clone())?;
3496 if let Some(notification_id) = deleted_notification_id {
3497 session.peer.send(
3498 connection_id,
3499 proto::DeleteNotification {
3500 notification_id: notification_id.to_proto(),
3501 },
3502 )?;
3503 }
3504 }
3505
3506 response.send(proto::Ack {})?;
3507 Ok(())
3508}
3509
3510fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
3511 version.0.minor() < 139
3512}
3513
3514async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
3515 let plan = session.current_plan(session.db().await).await?;
3516
3517 session
3518 .peer
3519 .send(
3520 session.connection_id,
3521 proto::UpdateUserPlan { plan: plan.into() },
3522 )
3523 .trace_err();
3524
3525 Ok(())
3526}
3527
3528async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
3529 subscribe_user_to_channels(
3530 session.user_id().ok_or_else(|| anyhow!("must be a user"))?,
3531 &session,
3532 )
3533 .await?;
3534 Ok(())
3535}
3536
3537async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
3538 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3539 let mut pool = session.connection_pool().await;
3540 for membership in &channels_for_user.channel_memberships {
3541 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3542 }
3543 session.peer.send(
3544 session.connection_id,
3545 build_update_user_channels(&channels_for_user),
3546 )?;
3547 session.peer.send(
3548 session.connection_id,
3549 build_channels_update(channels_for_user),
3550 )?;
3551 Ok(())
3552}
3553
3554/// Creates a new channel.
3555async fn create_channel(
3556 request: proto::CreateChannel,
3557 response: Response<proto::CreateChannel>,
3558 session: UserSession,
3559) -> Result<()> {
3560 let db = session.db().await;
3561
3562 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
3563 let (channel, membership) = db
3564 .create_channel(&request.name, parent_id, session.user_id())
3565 .await?;
3566
3567 let root_id = channel.root_id();
3568 let channel = Channel::from_model(channel);
3569
3570 response.send(proto::CreateChannelResponse {
3571 channel: Some(channel.to_proto()),
3572 parent_id: request.parent_id,
3573 })?;
3574
3575 let mut connection_pool = session.connection_pool().await;
3576 if let Some(membership) = membership {
3577 connection_pool.subscribe_to_channel(
3578 membership.user_id,
3579 membership.channel_id,
3580 membership.role,
3581 );
3582 let update = proto::UpdateUserChannels {
3583 channel_memberships: vec![proto::ChannelMembership {
3584 channel_id: membership.channel_id.to_proto(),
3585 role: membership.role.into(),
3586 }],
3587 ..Default::default()
3588 };
3589 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3590 session.peer.send(connection_id, update.clone())?;
3591 }
3592 }
3593
3594 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3595 if !role.can_see_channel(channel.visibility) {
3596 continue;
3597 }
3598
3599 let update = proto::UpdateChannels {
3600 channels: vec![channel.to_proto()],
3601 ..Default::default()
3602 };
3603 session.peer.send(connection_id, update.clone())?;
3604 }
3605
3606 Ok(())
3607}
3608
3609/// Delete a channel
3610async fn delete_channel(
3611 request: proto::DeleteChannel,
3612 response: Response<proto::DeleteChannel>,
3613 session: UserSession,
3614) -> Result<()> {
3615 let db = session.db().await;
3616
3617 let channel_id = request.channel_id;
3618 let (root_channel, removed_channels) = db
3619 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3620 .await?;
3621 response.send(proto::Ack {})?;
3622
3623 // Notify members of removed channels
3624 let mut update = proto::UpdateChannels::default();
3625 update
3626 .delete_channels
3627 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3628
3629 let connection_pool = session.connection_pool().await;
3630 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3631 session.peer.send(connection_id, update.clone())?;
3632 }
3633
3634 Ok(())
3635}
3636
3637/// Invite someone to join a channel.
3638async fn invite_channel_member(
3639 request: proto::InviteChannelMember,
3640 response: Response<proto::InviteChannelMember>,
3641 session: UserSession,
3642) -> Result<()> {
3643 let db = session.db().await;
3644 let channel_id = ChannelId::from_proto(request.channel_id);
3645 let invitee_id = UserId::from_proto(request.user_id);
3646 let InviteMemberResult {
3647 channel,
3648 notifications,
3649 } = db
3650 .invite_channel_member(
3651 channel_id,
3652 invitee_id,
3653 session.user_id(),
3654 request.role().into(),
3655 )
3656 .await?;
3657
3658 let update = proto::UpdateChannels {
3659 channel_invitations: vec![channel.to_proto()],
3660 ..Default::default()
3661 };
3662
3663 let connection_pool = session.connection_pool().await;
3664 for connection_id in connection_pool.user_connection_ids(invitee_id) {
3665 session.peer.send(connection_id, update.clone())?;
3666 }
3667
3668 send_notifications(&connection_pool, &session.peer, notifications);
3669
3670 response.send(proto::Ack {})?;
3671 Ok(())
3672}
3673
3674/// remove someone from a channel
3675async fn remove_channel_member(
3676 request: proto::RemoveChannelMember,
3677 response: Response<proto::RemoveChannelMember>,
3678 session: UserSession,
3679) -> Result<()> {
3680 let db = session.db().await;
3681 let channel_id = ChannelId::from_proto(request.channel_id);
3682 let member_id = UserId::from_proto(request.user_id);
3683
3684 let RemoveChannelMemberResult {
3685 membership_update,
3686 notification_id,
3687 } = db
3688 .remove_channel_member(channel_id, member_id, session.user_id())
3689 .await?;
3690
3691 let mut connection_pool = session.connection_pool().await;
3692 notify_membership_updated(
3693 &mut connection_pool,
3694 membership_update,
3695 member_id,
3696 &session.peer,
3697 );
3698 for connection_id in connection_pool.user_connection_ids(member_id) {
3699 if let Some(notification_id) = notification_id {
3700 session
3701 .peer
3702 .send(
3703 connection_id,
3704 proto::DeleteNotification {
3705 notification_id: notification_id.to_proto(),
3706 },
3707 )
3708 .trace_err();
3709 }
3710 }
3711
3712 response.send(proto::Ack {})?;
3713 Ok(())
3714}
3715
3716/// Toggle the channel between public and private.
3717/// Care is taken to maintain the invariant that public channels only descend from public channels,
3718/// (though members-only channels can appear at any point in the hierarchy).
3719async fn set_channel_visibility(
3720 request: proto::SetChannelVisibility,
3721 response: Response<proto::SetChannelVisibility>,
3722 session: UserSession,
3723) -> Result<()> {
3724 let db = session.db().await;
3725 let channel_id = ChannelId::from_proto(request.channel_id);
3726 let visibility = request.visibility().into();
3727
3728 let channel_model = db
3729 .set_channel_visibility(channel_id, visibility, session.user_id())
3730 .await?;
3731 let root_id = channel_model.root_id();
3732 let channel = Channel::from_model(channel_model);
3733
3734 let mut connection_pool = session.connection_pool().await;
3735 for (user_id, role) in connection_pool
3736 .channel_user_ids(root_id)
3737 .collect::<Vec<_>>()
3738 .into_iter()
3739 {
3740 let update = if role.can_see_channel(channel.visibility) {
3741 connection_pool.subscribe_to_channel(user_id, channel_id, role);
3742 proto::UpdateChannels {
3743 channels: vec![channel.to_proto()],
3744 ..Default::default()
3745 }
3746 } else {
3747 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3748 proto::UpdateChannels {
3749 delete_channels: vec![channel.id.to_proto()],
3750 ..Default::default()
3751 }
3752 };
3753
3754 for connection_id in connection_pool.user_connection_ids(user_id) {
3755 session.peer.send(connection_id, update.clone())?;
3756 }
3757 }
3758
3759 response.send(proto::Ack {})?;
3760 Ok(())
3761}
3762
3763/// Alter the role for a user in the channel.
3764async fn set_channel_member_role(
3765 request: proto::SetChannelMemberRole,
3766 response: Response<proto::SetChannelMemberRole>,
3767 session: UserSession,
3768) -> Result<()> {
3769 let db = session.db().await;
3770 let channel_id = ChannelId::from_proto(request.channel_id);
3771 let member_id = UserId::from_proto(request.user_id);
3772 let result = db
3773 .set_channel_member_role(
3774 channel_id,
3775 session.user_id(),
3776 member_id,
3777 request.role().into(),
3778 )
3779 .await?;
3780
3781 match result {
3782 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3783 let mut connection_pool = session.connection_pool().await;
3784 notify_membership_updated(
3785 &mut connection_pool,
3786 membership_update,
3787 member_id,
3788 &session.peer,
3789 )
3790 }
3791 db::SetMemberRoleResult::InviteUpdated(channel) => {
3792 let update = proto::UpdateChannels {
3793 channel_invitations: vec![channel.to_proto()],
3794 ..Default::default()
3795 };
3796
3797 for connection_id in session
3798 .connection_pool()
3799 .await
3800 .user_connection_ids(member_id)
3801 {
3802 session.peer.send(connection_id, update.clone())?;
3803 }
3804 }
3805 }
3806
3807 response.send(proto::Ack {})?;
3808 Ok(())
3809}
3810
3811/// Change the name of a channel
3812async fn rename_channel(
3813 request: proto::RenameChannel,
3814 response: Response<proto::RenameChannel>,
3815 session: UserSession,
3816) -> Result<()> {
3817 let db = session.db().await;
3818 let channel_id = ChannelId::from_proto(request.channel_id);
3819 let channel_model = db
3820 .rename_channel(channel_id, session.user_id(), &request.name)
3821 .await?;
3822 let root_id = channel_model.root_id();
3823 let channel = Channel::from_model(channel_model);
3824
3825 response.send(proto::RenameChannelResponse {
3826 channel: Some(channel.to_proto()),
3827 })?;
3828
3829 let connection_pool = session.connection_pool().await;
3830 let update = proto::UpdateChannels {
3831 channels: vec![channel.to_proto()],
3832 ..Default::default()
3833 };
3834 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3835 if role.can_see_channel(channel.visibility) {
3836 session.peer.send(connection_id, update.clone())?;
3837 }
3838 }
3839
3840 Ok(())
3841}
3842
3843/// Move a channel to a new parent.
3844async fn move_channel(
3845 request: proto::MoveChannel,
3846 response: Response<proto::MoveChannel>,
3847 session: UserSession,
3848) -> Result<()> {
3849 let channel_id = ChannelId::from_proto(request.channel_id);
3850 let to = ChannelId::from_proto(request.to);
3851
3852 let (root_id, channels) = session
3853 .db()
3854 .await
3855 .move_channel(channel_id, to, session.user_id())
3856 .await?;
3857
3858 let connection_pool = session.connection_pool().await;
3859 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3860 let channels = channels
3861 .iter()
3862 .filter_map(|channel| {
3863 if role.can_see_channel(channel.visibility) {
3864 Some(channel.to_proto())
3865 } else {
3866 None
3867 }
3868 })
3869 .collect::<Vec<_>>();
3870 if channels.is_empty() {
3871 continue;
3872 }
3873
3874 let update = proto::UpdateChannels {
3875 channels,
3876 ..Default::default()
3877 };
3878
3879 session.peer.send(connection_id, update.clone())?;
3880 }
3881
3882 response.send(Ack {})?;
3883 Ok(())
3884}
3885
3886/// Get the list of channel members
3887async fn get_channel_members(
3888 request: proto::GetChannelMembers,
3889 response: Response<proto::GetChannelMembers>,
3890 session: UserSession,
3891) -> Result<()> {
3892 let db = session.db().await;
3893 let channel_id = ChannelId::from_proto(request.channel_id);
3894 let limit = if request.limit == 0 {
3895 u16::MAX as u64
3896 } else {
3897 request.limit
3898 };
3899 let (members, users) = db
3900 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3901 .await?;
3902 response.send(proto::GetChannelMembersResponse { members, users })?;
3903 Ok(())
3904}
3905
3906/// Accept or decline a channel invitation.
3907async fn respond_to_channel_invite(
3908 request: proto::RespondToChannelInvite,
3909 response: Response<proto::RespondToChannelInvite>,
3910 session: UserSession,
3911) -> Result<()> {
3912 let db = session.db().await;
3913 let channel_id = ChannelId::from_proto(request.channel_id);
3914 let RespondToChannelInvite {
3915 membership_update,
3916 notifications,
3917 } = db
3918 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3919 .await?;
3920
3921 let mut connection_pool = session.connection_pool().await;
3922 if let Some(membership_update) = membership_update {
3923 notify_membership_updated(
3924 &mut connection_pool,
3925 membership_update,
3926 session.user_id(),
3927 &session.peer,
3928 );
3929 } else {
3930 let update = proto::UpdateChannels {
3931 remove_channel_invitations: vec![channel_id.to_proto()],
3932 ..Default::default()
3933 };
3934
3935 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3936 session.peer.send(connection_id, update.clone())?;
3937 }
3938 };
3939
3940 send_notifications(&connection_pool, &session.peer, notifications);
3941
3942 response.send(proto::Ack {})?;
3943
3944 Ok(())
3945}
3946
3947/// Join the channels' room
3948async fn join_channel(
3949 request: proto::JoinChannel,
3950 response: Response<proto::JoinChannel>,
3951 session: UserSession,
3952) -> Result<()> {
3953 let channel_id = ChannelId::from_proto(request.channel_id);
3954 join_channel_internal(channel_id, Box::new(response), session).await
3955}
3956
3957trait JoinChannelInternalResponse {
3958 fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
3959}
3960impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
3961 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3962 Response::<proto::JoinChannel>::send(self, result)
3963 }
3964}
3965impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
3966 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3967 Response::<proto::JoinRoom>::send(self, result)
3968 }
3969}
3970
3971async fn join_channel_internal(
3972 channel_id: ChannelId,
3973 response: Box<impl JoinChannelInternalResponse>,
3974 session: UserSession,
3975) -> Result<()> {
3976 let joined_room = {
3977 let mut db = session.db().await;
3978 // If zed quits without leaving the room, and the user re-opens zed before the
3979 // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
3980 // room they were in.
3981 if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
3982 tracing::info!(
3983 stale_connection_id = %connection,
3984 "cleaning up stale connection",
3985 );
3986 drop(db);
3987 leave_room_for_session(&session, connection).await?;
3988 db = session.db().await;
3989 }
3990
3991 let (joined_room, membership_updated, role) = db
3992 .join_channel(channel_id, session.user_id(), session.connection_id)
3993 .await?;
3994
3995 let live_kit_connection_info =
3996 session
3997 .app_state
3998 .live_kit_client
3999 .as_ref()
4000 .and_then(|live_kit| {
4001 let (can_publish, token) = if role == ChannelRole::Guest {
4002 (
4003 false,
4004 live_kit
4005 .guest_token(
4006 &joined_room.room.live_kit_room,
4007 &session.user_id().to_string(),
4008 )
4009 .trace_err()?,
4010 )
4011 } else {
4012 (
4013 true,
4014 live_kit
4015 .room_token(
4016 &joined_room.room.live_kit_room,
4017 &session.user_id().to_string(),
4018 )
4019 .trace_err()?,
4020 )
4021 };
4022
4023 Some(LiveKitConnectionInfo {
4024 server_url: live_kit.url().into(),
4025 token,
4026 can_publish,
4027 })
4028 });
4029
4030 response.send(proto::JoinRoomResponse {
4031 room: Some(joined_room.room.clone()),
4032 channel_id: joined_room
4033 .channel
4034 .as_ref()
4035 .map(|channel| channel.id.to_proto()),
4036 live_kit_connection_info,
4037 })?;
4038
4039 let mut connection_pool = session.connection_pool().await;
4040 if let Some(membership_updated) = membership_updated {
4041 notify_membership_updated(
4042 &mut connection_pool,
4043 membership_updated,
4044 session.user_id(),
4045 &session.peer,
4046 );
4047 }
4048
4049 room_updated(&joined_room.room, &session.peer);
4050
4051 joined_room
4052 };
4053
4054 channel_updated(
4055 &joined_room
4056 .channel
4057 .ok_or_else(|| anyhow!("channel not returned"))?,
4058 &joined_room.room,
4059 &session.peer,
4060 &*session.connection_pool().await,
4061 );
4062
4063 update_user_contacts(session.user_id(), &session).await?;
4064 Ok(())
4065}
4066
4067/// Start editing the channel notes
4068async fn join_channel_buffer(
4069 request: proto::JoinChannelBuffer,
4070 response: Response<proto::JoinChannelBuffer>,
4071 session: UserSession,
4072) -> Result<()> {
4073 let db = session.db().await;
4074 let channel_id = ChannelId::from_proto(request.channel_id);
4075
4076 let open_response = db
4077 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
4078 .await?;
4079
4080 let collaborators = open_response.collaborators.clone();
4081 response.send(open_response)?;
4082
4083 let update = UpdateChannelBufferCollaborators {
4084 channel_id: channel_id.to_proto(),
4085 collaborators: collaborators.clone(),
4086 };
4087 channel_buffer_updated(
4088 session.connection_id,
4089 collaborators
4090 .iter()
4091 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
4092 &update,
4093 &session.peer,
4094 );
4095
4096 Ok(())
4097}
4098
4099/// Edit the channel notes
4100async fn update_channel_buffer(
4101 request: proto::UpdateChannelBuffer,
4102 session: UserSession,
4103) -> Result<()> {
4104 let db = session.db().await;
4105 let channel_id = ChannelId::from_proto(request.channel_id);
4106
4107 let (collaborators, epoch, version) = db
4108 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
4109 .await?;
4110
4111 channel_buffer_updated(
4112 session.connection_id,
4113 collaborators.clone(),
4114 &proto::UpdateChannelBuffer {
4115 channel_id: channel_id.to_proto(),
4116 operations: request.operations,
4117 },
4118 &session.peer,
4119 );
4120
4121 let pool = &*session.connection_pool().await;
4122
4123 let non_collaborators =
4124 pool.channel_connection_ids(channel_id)
4125 .filter_map(|(connection_id, _)| {
4126 if collaborators.contains(&connection_id) {
4127 None
4128 } else {
4129 Some(connection_id)
4130 }
4131 });
4132
4133 broadcast(None, non_collaborators, |peer_id| {
4134 session.peer.send(
4135 peer_id,
4136 proto::UpdateChannels {
4137 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
4138 channel_id: channel_id.to_proto(),
4139 epoch: epoch as u64,
4140 version: version.clone(),
4141 }],
4142 ..Default::default()
4143 },
4144 )
4145 });
4146
4147 Ok(())
4148}
4149
4150/// Rejoin the channel notes after a connection blip
4151async fn rejoin_channel_buffers(
4152 request: proto::RejoinChannelBuffers,
4153 response: Response<proto::RejoinChannelBuffers>,
4154 session: UserSession,
4155) -> Result<()> {
4156 let db = session.db().await;
4157 let buffers = db
4158 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
4159 .await?;
4160
4161 for rejoined_buffer in &buffers {
4162 let collaborators_to_notify = rejoined_buffer
4163 .buffer
4164 .collaborators
4165 .iter()
4166 .filter_map(|c| Some(c.peer_id?.into()));
4167 channel_buffer_updated(
4168 session.connection_id,
4169 collaborators_to_notify,
4170 &proto::UpdateChannelBufferCollaborators {
4171 channel_id: rejoined_buffer.buffer.channel_id,
4172 collaborators: rejoined_buffer.buffer.collaborators.clone(),
4173 },
4174 &session.peer,
4175 );
4176 }
4177
4178 response.send(proto::RejoinChannelBuffersResponse {
4179 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
4180 })?;
4181
4182 Ok(())
4183}
4184
4185/// Stop editing the channel notes
4186async fn leave_channel_buffer(
4187 request: proto::LeaveChannelBuffer,
4188 response: Response<proto::LeaveChannelBuffer>,
4189 session: UserSession,
4190) -> Result<()> {
4191 let db = session.db().await;
4192 let channel_id = ChannelId::from_proto(request.channel_id);
4193
4194 let left_buffer = db
4195 .leave_channel_buffer(channel_id, session.connection_id)
4196 .await?;
4197
4198 response.send(Ack {})?;
4199
4200 channel_buffer_updated(
4201 session.connection_id,
4202 left_buffer.connections,
4203 &proto::UpdateChannelBufferCollaborators {
4204 channel_id: channel_id.to_proto(),
4205 collaborators: left_buffer.collaborators,
4206 },
4207 &session.peer,
4208 );
4209
4210 Ok(())
4211}
4212
4213fn channel_buffer_updated<T: EnvelopedMessage>(
4214 sender_id: ConnectionId,
4215 collaborators: impl IntoIterator<Item = ConnectionId>,
4216 message: &T,
4217 peer: &Peer,
4218) {
4219 broadcast(Some(sender_id), collaborators, |peer_id| {
4220 peer.send(peer_id, message.clone())
4221 });
4222}
4223
4224fn send_notifications(
4225 connection_pool: &ConnectionPool,
4226 peer: &Peer,
4227 notifications: db::NotificationBatch,
4228) {
4229 for (user_id, notification) in notifications {
4230 for connection_id in connection_pool.user_connection_ids(user_id) {
4231 if let Err(error) = peer.send(
4232 connection_id,
4233 proto::AddNotification {
4234 notification: Some(notification.clone()),
4235 },
4236 ) {
4237 tracing::error!(
4238 "failed to send notification to {:?} {}",
4239 connection_id,
4240 error
4241 );
4242 }
4243 }
4244 }
4245}
4246
4247/// Send a message to the channel
4248async fn send_channel_message(
4249 request: proto::SendChannelMessage,
4250 response: Response<proto::SendChannelMessage>,
4251 session: UserSession,
4252) -> Result<()> {
4253 // Validate the message body.
4254 let body = request.body.trim().to_string();
4255 if body.len() > MAX_MESSAGE_LEN {
4256 return Err(anyhow!("message is too long"))?;
4257 }
4258 if body.is_empty() {
4259 return Err(anyhow!("message can't be blank"))?;
4260 }
4261
4262 // TODO: adjust mentions if body is trimmed
4263
4264 let timestamp = OffsetDateTime::now_utc();
4265 let nonce = request
4266 .nonce
4267 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4268
4269 let channel_id = ChannelId::from_proto(request.channel_id);
4270 let CreatedChannelMessage {
4271 message_id,
4272 participant_connection_ids,
4273 notifications,
4274 } = session
4275 .db()
4276 .await
4277 .create_channel_message(
4278 channel_id,
4279 session.user_id(),
4280 &body,
4281 &request.mentions,
4282 timestamp,
4283 nonce.clone().into(),
4284 match request.reply_to_message_id {
4285 Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
4286 None => None,
4287 },
4288 )
4289 .await?;
4290
4291 let message = proto::ChannelMessage {
4292 sender_id: session.user_id().to_proto(),
4293 id: message_id.to_proto(),
4294 body,
4295 mentions: request.mentions,
4296 timestamp: timestamp.unix_timestamp() as u64,
4297 nonce: Some(nonce),
4298 reply_to_message_id: request.reply_to_message_id,
4299 edited_at: None,
4300 };
4301 broadcast(
4302 Some(session.connection_id),
4303 participant_connection_ids.clone(),
4304 |connection| {
4305 session.peer.send(
4306 connection,
4307 proto::ChannelMessageSent {
4308 channel_id: channel_id.to_proto(),
4309 message: Some(message.clone()),
4310 },
4311 )
4312 },
4313 );
4314 response.send(proto::SendChannelMessageResponse {
4315 message: Some(message),
4316 })?;
4317
4318 let pool = &*session.connection_pool().await;
4319 let non_participants =
4320 pool.channel_connection_ids(channel_id)
4321 .filter_map(|(connection_id, _)| {
4322 if participant_connection_ids.contains(&connection_id) {
4323 None
4324 } else {
4325 Some(connection_id)
4326 }
4327 });
4328 broadcast(None, non_participants, |peer_id| {
4329 session.peer.send(
4330 peer_id,
4331 proto::UpdateChannels {
4332 latest_channel_message_ids: vec![proto::ChannelMessageId {
4333 channel_id: channel_id.to_proto(),
4334 message_id: message_id.to_proto(),
4335 }],
4336 ..Default::default()
4337 },
4338 )
4339 });
4340 send_notifications(pool, &session.peer, notifications);
4341
4342 Ok(())
4343}
4344
4345/// Delete a channel message
4346async fn remove_channel_message(
4347 request: proto::RemoveChannelMessage,
4348 response: Response<proto::RemoveChannelMessage>,
4349 session: UserSession,
4350) -> Result<()> {
4351 let channel_id = ChannelId::from_proto(request.channel_id);
4352 let message_id = MessageId::from_proto(request.message_id);
4353 let (connection_ids, existing_notification_ids) = session
4354 .db()
4355 .await
4356 .remove_channel_message(channel_id, message_id, session.user_id())
4357 .await?;
4358
4359 broadcast(
4360 Some(session.connection_id),
4361 connection_ids,
4362 move |connection| {
4363 session.peer.send(connection, request.clone())?;
4364
4365 for notification_id in &existing_notification_ids {
4366 session.peer.send(
4367 connection,
4368 proto::DeleteNotification {
4369 notification_id: (*notification_id).to_proto(),
4370 },
4371 )?;
4372 }
4373
4374 Ok(())
4375 },
4376 );
4377 response.send(proto::Ack {})?;
4378 Ok(())
4379}
4380
4381async fn update_channel_message(
4382 request: proto::UpdateChannelMessage,
4383 response: Response<proto::UpdateChannelMessage>,
4384 session: UserSession,
4385) -> Result<()> {
4386 let channel_id = ChannelId::from_proto(request.channel_id);
4387 let message_id = MessageId::from_proto(request.message_id);
4388 let updated_at = OffsetDateTime::now_utc();
4389 let UpdatedChannelMessage {
4390 message_id,
4391 participant_connection_ids,
4392 notifications,
4393 reply_to_message_id,
4394 timestamp,
4395 deleted_mention_notification_ids,
4396 updated_mention_notifications,
4397 } = session
4398 .db()
4399 .await
4400 .update_channel_message(
4401 channel_id,
4402 message_id,
4403 session.user_id(),
4404 request.body.as_str(),
4405 &request.mentions,
4406 updated_at,
4407 )
4408 .await?;
4409
4410 let nonce = request
4411 .nonce
4412 .clone()
4413 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4414
4415 let message = proto::ChannelMessage {
4416 sender_id: session.user_id().to_proto(),
4417 id: message_id.to_proto(),
4418 body: request.body.clone(),
4419 mentions: request.mentions.clone(),
4420 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
4421 nonce: Some(nonce),
4422 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
4423 edited_at: Some(updated_at.unix_timestamp() as u64),
4424 };
4425
4426 response.send(proto::Ack {})?;
4427
4428 let pool = &*session.connection_pool().await;
4429 broadcast(
4430 Some(session.connection_id),
4431 participant_connection_ids,
4432 |connection| {
4433 session.peer.send(
4434 connection,
4435 proto::ChannelMessageUpdate {
4436 channel_id: channel_id.to_proto(),
4437 message: Some(message.clone()),
4438 },
4439 )?;
4440
4441 for notification_id in &deleted_mention_notification_ids {
4442 session.peer.send(
4443 connection,
4444 proto::DeleteNotification {
4445 notification_id: (*notification_id).to_proto(),
4446 },
4447 )?;
4448 }
4449
4450 for notification in &updated_mention_notifications {
4451 session.peer.send(
4452 connection,
4453 proto::UpdateNotification {
4454 notification: Some(notification.clone()),
4455 },
4456 )?;
4457 }
4458
4459 Ok(())
4460 },
4461 );
4462
4463 send_notifications(pool, &session.peer, notifications);
4464
4465 Ok(())
4466}
4467
4468/// Mark a channel message as read
4469async fn acknowledge_channel_message(
4470 request: proto::AckChannelMessage,
4471 session: UserSession,
4472) -> Result<()> {
4473 let channel_id = ChannelId::from_proto(request.channel_id);
4474 let message_id = MessageId::from_proto(request.message_id);
4475 let notifications = session
4476 .db()
4477 .await
4478 .observe_channel_message(channel_id, session.user_id(), message_id)
4479 .await?;
4480 send_notifications(
4481 &*session.connection_pool().await,
4482 &session.peer,
4483 notifications,
4484 );
4485 Ok(())
4486}
4487
4488/// Mark a buffer version as synced
4489async fn acknowledge_buffer_version(
4490 request: proto::AckBufferOperation,
4491 session: UserSession,
4492) -> Result<()> {
4493 let buffer_id = BufferId::from_proto(request.buffer_id);
4494 session
4495 .db()
4496 .await
4497 .observe_buffer_version(
4498 buffer_id,
4499 session.user_id(),
4500 request.epoch as i32,
4501 &request.version,
4502 )
4503 .await?;
4504 Ok(())
4505}
4506
4507async fn count_language_model_tokens(
4508 request: proto::CountLanguageModelTokens,
4509 response: Response<proto::CountLanguageModelTokens>,
4510 session: Session,
4511 config: &Config,
4512) -> Result<()> {
4513 let Some(session) = session.for_user() else {
4514 return Err(anyhow!("user not found"))?;
4515 };
4516 authorize_access_to_legacy_llm_endpoints(&session).await?;
4517
4518 let rate_limit: Box<dyn RateLimit> = match session.current_plan(session.db().await).await? {
4519 proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
4520 proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
4521 };
4522
4523 session
4524 .app_state
4525 .rate_limiter
4526 .check(&*rate_limit, session.user_id())
4527 .await?;
4528
4529 let result = match proto::LanguageModelProvider::from_i32(request.provider) {
4530 Some(proto::LanguageModelProvider::Google) => {
4531 let api_key = config
4532 .google_ai_api_key
4533 .as_ref()
4534 .context("no Google AI API key configured on the server")?;
4535 google_ai::count_tokens(
4536 session.http_client.as_ref(),
4537 google_ai::API_URL,
4538 api_key,
4539 serde_json::from_str(&request.request)?,
4540 )
4541 .await?
4542 }
4543 _ => return Err(anyhow!("unsupported provider"))?,
4544 };
4545
4546 response.send(proto::CountLanguageModelTokensResponse {
4547 token_count: result.total_tokens as u32,
4548 })?;
4549
4550 Ok(())
4551}
4552
4553struct ZedProCountLanguageModelTokensRateLimit;
4554
4555impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
4556 fn capacity(&self) -> usize {
4557 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
4558 .ok()
4559 .and_then(|v| v.parse().ok())
4560 .unwrap_or(600) // Picked arbitrarily
4561 }
4562
4563 fn refill_duration(&self) -> chrono::Duration {
4564 chrono::Duration::hours(1)
4565 }
4566
4567 fn db_name(&self) -> &'static str {
4568 "zed-pro:count-language-model-tokens"
4569 }
4570}
4571
4572struct FreeCountLanguageModelTokensRateLimit;
4573
4574impl RateLimit for FreeCountLanguageModelTokensRateLimit {
4575 fn capacity(&self) -> usize {
4576 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
4577 .ok()
4578 .and_then(|v| v.parse().ok())
4579 .unwrap_or(600 / 10) // Picked arbitrarily
4580 }
4581
4582 fn refill_duration(&self) -> chrono::Duration {
4583 chrono::Duration::hours(1)
4584 }
4585
4586 fn db_name(&self) -> &'static str {
4587 "free:count-language-model-tokens"
4588 }
4589}
4590
4591struct ZedProComputeEmbeddingsRateLimit;
4592
4593impl RateLimit for ZedProComputeEmbeddingsRateLimit {
4594 fn capacity(&self) -> usize {
4595 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4596 .ok()
4597 .and_then(|v| v.parse().ok())
4598 .unwrap_or(5000) // Picked arbitrarily
4599 }
4600
4601 fn refill_duration(&self) -> chrono::Duration {
4602 chrono::Duration::hours(1)
4603 }
4604
4605 fn db_name(&self) -> &'static str {
4606 "zed-pro:compute-embeddings"
4607 }
4608}
4609
4610struct FreeComputeEmbeddingsRateLimit;
4611
4612impl RateLimit for FreeComputeEmbeddingsRateLimit {
4613 fn capacity(&self) -> usize {
4614 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
4615 .ok()
4616 .and_then(|v| v.parse().ok())
4617 .unwrap_or(5000 / 10) // Picked arbitrarily
4618 }
4619
4620 fn refill_duration(&self) -> chrono::Duration {
4621 chrono::Duration::hours(1)
4622 }
4623
4624 fn db_name(&self) -> &'static str {
4625 "free:compute-embeddings"
4626 }
4627}
4628
4629async fn compute_embeddings(
4630 request: proto::ComputeEmbeddings,
4631 response: Response<proto::ComputeEmbeddings>,
4632 session: UserSession,
4633 api_key: Option<Arc<str>>,
4634) -> Result<()> {
4635 let api_key = api_key.context("no OpenAI API key configured on the server")?;
4636 authorize_access_to_legacy_llm_endpoints(&session).await?;
4637
4638 let rate_limit: Box<dyn RateLimit> = match session.current_plan(session.db().await).await? {
4639 proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
4640 proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
4641 };
4642
4643 session
4644 .app_state
4645 .rate_limiter
4646 .check(&*rate_limit, session.user_id())
4647 .await?;
4648
4649 let embeddings = match request.model.as_str() {
4650 "openai/text-embedding-3-small" => {
4651 open_ai::embed(
4652 session.http_client.as_ref(),
4653 OPEN_AI_API_URL,
4654 &api_key,
4655 OpenAiEmbeddingModel::TextEmbedding3Small,
4656 request.texts.iter().map(|text| text.as_str()),
4657 )
4658 .await?
4659 }
4660 provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
4661 };
4662
4663 let embeddings = request
4664 .texts
4665 .iter()
4666 .map(|text| {
4667 let mut hasher = sha2::Sha256::new();
4668 hasher.update(text.as_bytes());
4669 let result = hasher.finalize();
4670 result.to_vec()
4671 })
4672 .zip(
4673 embeddings
4674 .data
4675 .into_iter()
4676 .map(|embedding| embedding.embedding),
4677 )
4678 .collect::<HashMap<_, _>>();
4679
4680 let db = session.db().await;
4681 db.save_embeddings(&request.model, &embeddings)
4682 .await
4683 .context("failed to save embeddings")
4684 .trace_err();
4685
4686 response.send(proto::ComputeEmbeddingsResponse {
4687 embeddings: embeddings
4688 .into_iter()
4689 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4690 .collect(),
4691 })?;
4692 Ok(())
4693}
4694
4695async fn get_cached_embeddings(
4696 request: proto::GetCachedEmbeddings,
4697 response: Response<proto::GetCachedEmbeddings>,
4698 session: UserSession,
4699) -> Result<()> {
4700 authorize_access_to_legacy_llm_endpoints(&session).await?;
4701
4702 let db = session.db().await;
4703 let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
4704
4705 response.send(proto::GetCachedEmbeddingsResponse {
4706 embeddings: embeddings
4707 .into_iter()
4708 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4709 .collect(),
4710 })?;
4711 Ok(())
4712}
4713
4714/// This is leftover from before the LLM service.
4715///
4716/// The endpoints protected by this check will be moved there eventually.
4717async fn authorize_access_to_legacy_llm_endpoints(session: &UserSession) -> Result<(), Error> {
4718 if session.is_staff() {
4719 Ok(())
4720 } else {
4721 Err(anyhow!("permission denied"))?
4722 }
4723}
4724
4725/// Get a Supermaven API key for the user
4726async fn get_supermaven_api_key(
4727 _request: proto::GetSupermavenApiKey,
4728 response: Response<proto::GetSupermavenApiKey>,
4729 session: UserSession,
4730) -> Result<()> {
4731 let user_id: String = session.user_id().to_string();
4732 if !session.is_staff() {
4733 return Err(anyhow!("supermaven not enabled for this account"))?;
4734 }
4735
4736 let email = session
4737 .email()
4738 .ok_or_else(|| anyhow!("user must have an email"))?;
4739
4740 let supermaven_admin_api = session
4741 .supermaven_client
4742 .as_ref()
4743 .ok_or_else(|| anyhow!("supermaven not configured"))?;
4744
4745 let result = supermaven_admin_api
4746 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4747 .await?;
4748
4749 response.send(proto::GetSupermavenApiKeyResponse {
4750 api_key: result.api_key,
4751 })?;
4752
4753 Ok(())
4754}
4755
4756/// Start receiving chat updates for a channel
4757async fn join_channel_chat(
4758 request: proto::JoinChannelChat,
4759 response: Response<proto::JoinChannelChat>,
4760 session: UserSession,
4761) -> Result<()> {
4762 let channel_id = ChannelId::from_proto(request.channel_id);
4763
4764 let db = session.db().await;
4765 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4766 .await?;
4767 let messages = db
4768 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4769 .await?;
4770 response.send(proto::JoinChannelChatResponse {
4771 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4772 messages,
4773 })?;
4774 Ok(())
4775}
4776
4777/// Stop receiving chat updates for a channel
4778async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
4779 let channel_id = ChannelId::from_proto(request.channel_id);
4780 session
4781 .db()
4782 .await
4783 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4784 .await?;
4785 Ok(())
4786}
4787
4788/// Retrieve the chat history for a channel
4789async fn get_channel_messages(
4790 request: proto::GetChannelMessages,
4791 response: Response<proto::GetChannelMessages>,
4792 session: UserSession,
4793) -> Result<()> {
4794 let channel_id = ChannelId::from_proto(request.channel_id);
4795 let messages = session
4796 .db()
4797 .await
4798 .get_channel_messages(
4799 channel_id,
4800 session.user_id(),
4801 MESSAGE_COUNT_PER_PAGE,
4802 Some(MessageId::from_proto(request.before_message_id)),
4803 )
4804 .await?;
4805 response.send(proto::GetChannelMessagesResponse {
4806 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4807 messages,
4808 })?;
4809 Ok(())
4810}
4811
4812/// Retrieve specific chat messages
4813async fn get_channel_messages_by_id(
4814 request: proto::GetChannelMessagesById,
4815 response: Response<proto::GetChannelMessagesById>,
4816 session: UserSession,
4817) -> Result<()> {
4818 let message_ids = request
4819 .message_ids
4820 .iter()
4821 .map(|id| MessageId::from_proto(*id))
4822 .collect::<Vec<_>>();
4823 let messages = session
4824 .db()
4825 .await
4826 .get_channel_messages_by_id(session.user_id(), &message_ids)
4827 .await?;
4828 response.send(proto::GetChannelMessagesResponse {
4829 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4830 messages,
4831 })?;
4832 Ok(())
4833}
4834
4835/// Retrieve the current users notifications
4836async fn get_notifications(
4837 request: proto::GetNotifications,
4838 response: Response<proto::GetNotifications>,
4839 session: UserSession,
4840) -> Result<()> {
4841 let notifications = session
4842 .db()
4843 .await
4844 .get_notifications(
4845 session.user_id(),
4846 NOTIFICATION_COUNT_PER_PAGE,
4847 request
4848 .before_id
4849 .map(|id| db::NotificationId::from_proto(id)),
4850 )
4851 .await?;
4852 response.send(proto::GetNotificationsResponse {
4853 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4854 notifications,
4855 })?;
4856 Ok(())
4857}
4858
4859/// Mark notifications as read
4860async fn mark_notification_as_read(
4861 request: proto::MarkNotificationRead,
4862 response: Response<proto::MarkNotificationRead>,
4863 session: UserSession,
4864) -> Result<()> {
4865 let database = &session.db().await;
4866 let notifications = database
4867 .mark_notification_as_read_by_id(
4868 session.user_id(),
4869 NotificationId::from_proto(request.notification_id),
4870 )
4871 .await?;
4872 send_notifications(
4873 &*session.connection_pool().await,
4874 &session.peer,
4875 notifications,
4876 );
4877 response.send(proto::Ack {})?;
4878 Ok(())
4879}
4880
4881/// Get the current users information
4882async fn get_private_user_info(
4883 _request: proto::GetPrivateUserInfo,
4884 response: Response<proto::GetPrivateUserInfo>,
4885 session: UserSession,
4886) -> Result<()> {
4887 let db = session.db().await;
4888
4889 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4890 let user = db
4891 .get_user_by_id(session.user_id())
4892 .await?
4893 .ok_or_else(|| anyhow!("user not found"))?;
4894 let flags = db.get_user_flags(session.user_id()).await?;
4895
4896 response.send(proto::GetPrivateUserInfoResponse {
4897 metrics_id,
4898 staff: user.admin,
4899 flags,
4900 accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4901 })?;
4902 Ok(())
4903}
4904
4905/// Accept the terms of service (tos) on behalf of the current user
4906async fn accept_terms_of_service(
4907 _request: proto::AcceptTermsOfService,
4908 response: Response<proto::AcceptTermsOfService>,
4909 session: UserSession,
4910) -> Result<()> {
4911 let db = session.db().await;
4912
4913 let accepted_tos_at = Utc::now();
4914 db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4915 .await?;
4916
4917 response.send(proto::AcceptTermsOfServiceResponse {
4918 accepted_tos_at: accepted_tos_at.timestamp() as u64,
4919 })?;
4920 Ok(())
4921}
4922
4923/// The minimum account age an account must have in order to use the LLM service.
4924const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4925
4926async fn get_llm_api_token(
4927 _request: proto::GetLlmToken,
4928 response: Response<proto::GetLlmToken>,
4929 session: UserSession,
4930) -> Result<()> {
4931 let db = session.db().await;
4932
4933 let flags = db.get_user_flags(session.user_id()).await?;
4934 let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
4935 let has_llm_closed_beta_feature_flag = flags.iter().any(|flag| flag == "llm-closed-beta");
4936
4937 if !session.is_staff() && !has_language_models_feature_flag {
4938 Err(anyhow!("permission denied"))?
4939 }
4940
4941 let user_id = session.user_id();
4942 let user = db
4943 .get_user_by_id(user_id)
4944 .await?
4945 .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4946
4947 if user.accepted_tos_at.is_none() {
4948 Err(anyhow!("terms of service not accepted"))?
4949 }
4950
4951 let mut account_created_at = user.created_at;
4952 if let Some(github_created_at) = user.github_user_created_at {
4953 account_created_at = account_created_at.min(github_created_at);
4954 }
4955 if Utc::now().naive_utc() - account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE {
4956 Err(anyhow!("account too young"))?
4957 }
4958 let token = LlmTokenClaims::create(
4959 user.id,
4960 user.github_login.clone(),
4961 session.is_staff(),
4962 has_llm_closed_beta_feature_flag,
4963 session.current_plan(db).await?,
4964 &session.app_state.config,
4965 )?;
4966 response.send(proto::GetLlmTokenResponse { token })?;
4967 Ok(())
4968}
4969
4970fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4971 let message = match message {
4972 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4973 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4974 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4975 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4976 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4977 code: frame.code.into(),
4978 reason: frame.reason,
4979 })),
4980 // We should never receive a frame while reading the message, according
4981 // to the `tungstenite` maintainers:
4982 //
4983 // > It cannot occur when you read messages from the WebSocket, but it
4984 // > can be used when you want to send the raw frames (e.g. you want to
4985 // > send the frames to the WebSocket without composing the full message first).
4986 // >
4987 // > — https://github.com/snapview/tungstenite-rs/issues/268
4988 TungsteniteMessage::Frame(_) => {
4989 bail!("received an unexpected frame while reading the message")
4990 }
4991 };
4992
4993 Ok(message)
4994}
4995
4996fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4997 match message {
4998 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4999 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
5000 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
5001 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
5002 AxumMessage::Close(frame) => {
5003 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
5004 code: frame.code.into(),
5005 reason: frame.reason,
5006 }))
5007 }
5008 }
5009}
5010
5011fn notify_membership_updated(
5012 connection_pool: &mut ConnectionPool,
5013 result: MembershipUpdated,
5014 user_id: UserId,
5015 peer: &Peer,
5016) {
5017 for membership in &result.new_channels.channel_memberships {
5018 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
5019 }
5020 for channel_id in &result.removed_channels {
5021 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
5022 }
5023
5024 let user_channels_update = proto::UpdateUserChannels {
5025 channel_memberships: result
5026 .new_channels
5027 .channel_memberships
5028 .iter()
5029 .map(|cm| proto::ChannelMembership {
5030 channel_id: cm.channel_id.to_proto(),
5031 role: cm.role.into(),
5032 })
5033 .collect(),
5034 ..Default::default()
5035 };
5036
5037 let mut update = build_channels_update(result.new_channels);
5038 update.delete_channels = result
5039 .removed_channels
5040 .into_iter()
5041 .map(|id| id.to_proto())
5042 .collect();
5043 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
5044
5045 for connection_id in connection_pool.user_connection_ids(user_id) {
5046 peer.send(connection_id, user_channels_update.clone())
5047 .trace_err();
5048 peer.send(connection_id, update.clone()).trace_err();
5049 }
5050}
5051
5052fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
5053 proto::UpdateUserChannels {
5054 channel_memberships: channels
5055 .channel_memberships
5056 .iter()
5057 .map(|m| proto::ChannelMembership {
5058 channel_id: m.channel_id.to_proto(),
5059 role: m.role.into(),
5060 })
5061 .collect(),
5062 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
5063 observed_channel_message_id: channels.observed_channel_messages.clone(),
5064 }
5065}
5066
5067fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
5068 let mut update = proto::UpdateChannels::default();
5069
5070 for channel in channels.channels {
5071 update.channels.push(channel.to_proto());
5072 }
5073
5074 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
5075 update.latest_channel_message_ids = channels.latest_channel_messages;
5076
5077 for (channel_id, participants) in channels.channel_participants {
5078 update
5079 .channel_participants
5080 .push(proto::ChannelParticipants {
5081 channel_id: channel_id.to_proto(),
5082 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
5083 });
5084 }
5085
5086 for channel in channels.invited_channels {
5087 update.channel_invitations.push(channel.to_proto());
5088 }
5089
5090 update.hosted_projects = channels.hosted_projects;
5091 update
5092}
5093
5094fn build_initial_contacts_update(
5095 contacts: Vec<db::Contact>,
5096 pool: &ConnectionPool,
5097) -> proto::UpdateContacts {
5098 let mut update = proto::UpdateContacts::default();
5099
5100 for contact in contacts {
5101 match contact {
5102 db::Contact::Accepted { user_id, busy } => {
5103 update.contacts.push(contact_for_user(user_id, busy, &pool));
5104 }
5105 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
5106 db::Contact::Incoming { user_id } => {
5107 update
5108 .incoming_requests
5109 .push(proto::IncomingContactRequest {
5110 requester_id: user_id.to_proto(),
5111 })
5112 }
5113 }
5114 }
5115
5116 update
5117}
5118
5119fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
5120 proto::Contact {
5121 user_id: user_id.to_proto(),
5122 online: pool.is_user_online(user_id),
5123 busy,
5124 }
5125}
5126
5127fn room_updated(room: &proto::Room, peer: &Peer) {
5128 broadcast(
5129 None,
5130 room.participants
5131 .iter()
5132 .filter_map(|participant| Some(participant.peer_id?.into())),
5133 |peer_id| {
5134 peer.send(
5135 peer_id,
5136 proto::RoomUpdated {
5137 room: Some(room.clone()),
5138 },
5139 )
5140 },
5141 );
5142}
5143
5144fn channel_updated(
5145 channel: &db::channel::Model,
5146 room: &proto::Room,
5147 peer: &Peer,
5148 pool: &ConnectionPool,
5149) {
5150 let participants = room
5151 .participants
5152 .iter()
5153 .map(|p| p.user_id)
5154 .collect::<Vec<_>>();
5155
5156 broadcast(
5157 None,
5158 pool.channel_connection_ids(channel.root_id())
5159 .filter_map(|(channel_id, role)| {
5160 role.can_see_channel(channel.visibility).then(|| channel_id)
5161 }),
5162 |peer_id| {
5163 peer.send(
5164 peer_id,
5165 proto::UpdateChannels {
5166 channel_participants: vec![proto::ChannelParticipants {
5167 channel_id: channel.id.to_proto(),
5168 participant_user_ids: participants.clone(),
5169 }],
5170 ..Default::default()
5171 },
5172 )
5173 },
5174 );
5175}
5176
5177async fn send_dev_server_projects_update(
5178 user_id: UserId,
5179 mut status: proto::DevServerProjectsUpdate,
5180 session: &Session,
5181) {
5182 let pool = session.connection_pool().await;
5183 for dev_server in &mut status.dev_servers {
5184 dev_server.status =
5185 pool.dev_server_status(DevServerId(dev_server.dev_server_id as i32)) as i32;
5186 }
5187 let connections = pool.user_connection_ids(user_id);
5188 for connection_id in connections {
5189 session.peer.send(connection_id, status.clone()).trace_err();
5190 }
5191}
5192
5193async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
5194 let db = session.db().await;
5195
5196 let contacts = db.get_contacts(user_id).await?;
5197 let busy = db.is_user_busy(user_id).await?;
5198
5199 let pool = session.connection_pool().await;
5200 let updated_contact = contact_for_user(user_id, busy, &pool);
5201 for contact in contacts {
5202 if let db::Contact::Accepted {
5203 user_id: contact_user_id,
5204 ..
5205 } = contact
5206 {
5207 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
5208 session
5209 .peer
5210 .send(
5211 contact_conn_id,
5212 proto::UpdateContacts {
5213 contacts: vec![updated_contact.clone()],
5214 remove_contacts: Default::default(),
5215 incoming_requests: Default::default(),
5216 remove_incoming_requests: Default::default(),
5217 outgoing_requests: Default::default(),
5218 remove_outgoing_requests: Default::default(),
5219 },
5220 )
5221 .trace_err();
5222 }
5223 }
5224 }
5225 Ok(())
5226}
5227
5228async fn lost_dev_server_connection(session: &DevServerSession) -> Result<()> {
5229 log::info!("lost dev server connection, unsharing projects");
5230 let project_ids = session
5231 .db()
5232 .await
5233 .get_stale_dev_server_projects(session.connection_id)
5234 .await?;
5235
5236 for project_id in project_ids {
5237 // not unshare re-checks the connection ids match, so we get away with no transaction
5238 unshare_project_internal(project_id, session.connection_id, None, &session).await?;
5239 }
5240
5241 let user_id = session.dev_server().user_id;
5242 let update = session
5243 .db()
5244 .await
5245 .dev_server_projects_update(user_id)
5246 .await?;
5247
5248 send_dev_server_projects_update(user_id, update, session).await;
5249
5250 Ok(())
5251}
5252
5253async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
5254 let mut contacts_to_update = HashSet::default();
5255
5256 let room_id;
5257 let canceled_calls_to_user_ids;
5258 let live_kit_room;
5259 let delete_live_kit_room;
5260 let room;
5261 let channel;
5262
5263 if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
5264 contacts_to_update.insert(session.user_id());
5265
5266 for project in left_room.left_projects.values() {
5267 project_left(project, session);
5268 }
5269
5270 room_id = RoomId::from_proto(left_room.room.id);
5271 canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
5272 live_kit_room = mem::take(&mut left_room.room.live_kit_room);
5273 delete_live_kit_room = left_room.deleted;
5274 room = mem::take(&mut left_room.room);
5275 channel = mem::take(&mut left_room.channel);
5276
5277 room_updated(&room, &session.peer);
5278 } else {
5279 return Ok(());
5280 }
5281
5282 if let Some(channel) = channel {
5283 channel_updated(
5284 &channel,
5285 &room,
5286 &session.peer,
5287 &*session.connection_pool().await,
5288 );
5289 }
5290
5291 {
5292 let pool = session.connection_pool().await;
5293 for canceled_user_id in canceled_calls_to_user_ids {
5294 for connection_id in pool.user_connection_ids(canceled_user_id) {
5295 session
5296 .peer
5297 .send(
5298 connection_id,
5299 proto::CallCanceled {
5300 room_id: room_id.to_proto(),
5301 },
5302 )
5303 .trace_err();
5304 }
5305 contacts_to_update.insert(canceled_user_id);
5306 }
5307 }
5308
5309 for contact_user_id in contacts_to_update {
5310 update_user_contacts(contact_user_id, &session).await?;
5311 }
5312
5313 if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
5314 live_kit
5315 .remove_participant(live_kit_room.clone(), session.user_id().to_string())
5316 .await
5317 .trace_err();
5318
5319 if delete_live_kit_room {
5320 live_kit.delete_room(live_kit_room).await.trace_err();
5321 }
5322 }
5323
5324 Ok(())
5325}
5326
5327async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
5328 let left_channel_buffers = session
5329 .db()
5330 .await
5331 .leave_channel_buffers(session.connection_id)
5332 .await?;
5333
5334 for left_buffer in left_channel_buffers {
5335 channel_buffer_updated(
5336 session.connection_id,
5337 left_buffer.connections,
5338 &proto::UpdateChannelBufferCollaborators {
5339 channel_id: left_buffer.channel_id.to_proto(),
5340 collaborators: left_buffer.collaborators,
5341 },
5342 &session.peer,
5343 );
5344 }
5345
5346 Ok(())
5347}
5348
5349fn project_left(project: &db::LeftProject, session: &UserSession) {
5350 for connection_id in &project.connection_ids {
5351 if project.should_unshare {
5352 session
5353 .peer
5354 .send(
5355 *connection_id,
5356 proto::UnshareProject {
5357 project_id: project.id.to_proto(),
5358 },
5359 )
5360 .trace_err();
5361 } else {
5362 session
5363 .peer
5364 .send(
5365 *connection_id,
5366 proto::RemoveProjectCollaborator {
5367 project_id: project.id.to_proto(),
5368 peer_id: Some(session.connection_id.into()),
5369 },
5370 )
5371 .trace_err();
5372 }
5373 }
5374}
5375
5376pub trait ResultExt {
5377 type Ok;
5378
5379 fn trace_err(self) -> Option<Self::Ok>;
5380}
5381
5382impl<T, E> ResultExt for Result<T, E>
5383where
5384 E: std::fmt::Debug,
5385{
5386 type Ok = T;
5387
5388 #[track_caller]
5389 fn trace_err(self) -> Option<T> {
5390 match self {
5391 Ok(value) => Some(value),
5392 Err(error) => {
5393 tracing::error!("{:?}", error);
5394 None
5395 }
5396 }
5397 }
5398}