1mod connection_pool;
2
3use crate::api::CloudflareIpCountryHeader;
4use crate::llm::LlmTokenClaims;
5use crate::{
6 auth,
7 db::{
8 self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
9 CreatedChannelMessage, Database, DevServerId, DevServerProjectId, InviteMemberResult,
10 MembershipUpdated, MessageId, NotificationId, PrincipalId, Project, ProjectId,
11 RejoinedProject, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
12 ServerId, UpdatedChannelMessage, User, UserId,
13 },
14 executor::Executor,
15 AppState, Config, Error, RateLimit, Result,
16};
17use anyhow::{anyhow, bail, Context as _};
18use async_tungstenite::tungstenite::{
19 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
20};
21use axum::{
22 body::Body,
23 extract::{
24 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
25 ConnectInfo, WebSocketUpgrade,
26 },
27 headers::{Header, HeaderName},
28 http::StatusCode,
29 middleware,
30 response::IntoResponse,
31 routing::get,
32 Extension, Router, TypedHeader,
33};
34use chrono::Utc;
35use collections::{HashMap, HashSet};
36pub use connection_pool::{ConnectionPool, ZedVersion};
37use core::fmt::{self, Debug, Formatter};
38use http_client::HttpClient;
39use isahc_http_client::IsahcHttpClient;
40use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
41use sha2::Digest;
42use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
43
44use futures::{
45 channel::oneshot,
46 future::{self, BoxFuture},
47 stream::FuturesUnordered,
48 FutureExt, SinkExt, StreamExt, TryStreamExt,
49};
50use prometheus::{register_int_gauge, IntGauge};
51use rpc::{
52 proto::{
53 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
54 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
55 },
56 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
57};
58use semantic_version::SemanticVersion;
59use serde::{Serialize, Serializer};
60use std::{
61 any::TypeId,
62 future::Future,
63 marker::PhantomData,
64 mem,
65 net::SocketAddr,
66 ops::{Deref, DerefMut},
67 rc::Rc,
68 sync::{
69 atomic::{AtomicBool, Ordering::SeqCst},
70 Arc, OnceLock,
71 },
72 time::{Duration, Instant},
73};
74use time::OffsetDateTime;
75use tokio::sync::{watch, MutexGuard, Semaphore};
76use tower::ServiceBuilder;
77use tracing::{
78 field::{self},
79 info_span, instrument, Instrument,
80};
81
/// Reconnect timeout of 30 seconds.
///
/// NOTE(review): not referenced in the visible part of this file; presumably the
/// window a disconnected client has to reconnect — confirm at the use sites.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// How long `Server::start` sleeps before it cleans up stale resources.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully. After they're
/// gone, we can clean up old resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// Paging and size limits.
// NOTE(review): none of these are used in the visible part of this file; the
// names suggest channel-message paging, a maximum chat message length and
// notification paging — confirm against the handlers that use them.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
90
/// Type-erased handler stored in `Server::handlers`, keyed by the message's `TypeId`.
///
/// It receives the still-untyped envelope together with the `Session` of the
/// connection it arrived on, and returns a boxed future that resolves to `()`.
/// The handlers built by `Server::add_handler` log any error themselves rather
/// than returning it.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
93
/// Handle given to a request handler so it can reply to the request `R`.
struct Response<R> {
    /// Peer used to deliver the reply.
    peer: Arc<Peer>,
    /// Receipt of the request being answered.
    receipt: Receipt<R>,
    /// Set to `true` by `send`; `Server::add_request_handler` holds a clone of
    /// this flag and checks it after the handler returns `Ok(())`.
    responded: Arc<AtomicBool>,
}
99
100impl<R: RequestMessage> Response<R> {
101 fn send(self, payload: R::Response) -> Result<()> {
102 self.responded.store(true, SeqCst);
103 self.peer.respond(self.receipt, payload)?;
104 Ok(())
105 }
106}
107
/// The identity behind a connection.
#[derive(Clone, Debug)]
pub enum Principal {
    /// A user connected as themselves.
    User(User),
    /// `admin` is connected while acting as `user`. Sessions with this
    /// principal report `user`'s id and are always treated as staff.
    Impersonated { user: User, admin: User },
    /// A dev server rather than a user.
    DevServer(dev_server::Model),
}
114
115impl Principal {
116 fn update_span(&self, span: &tracing::Span) {
117 match &self {
118 Principal::User(user) => {
119 span.record("user_id", user.id.0);
120 span.record("login", &user.github_login);
121 }
122 Principal::Impersonated { user, admin } => {
123 span.record("user_id", user.id.0);
124 span.record("login", &user.github_login);
125 span.record("impersonator", &admin.github_login);
126 }
127 Principal::DevServer(dev_server) => {
128 span.record("dev_server_id", dev_server.id.0);
129 }
130 }
131 }
132}
133
/// Per-connection state handed (by clone) to every message handler.
#[derive(Clone)]
struct Session {
    /// Who is on the other end of the connection.
    principal: Principal,
    /// Id the peer assigned to this connection.
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; `handle_connection` creates one
    /// mutex per connection around the shared `Database`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    /// The server's peer, used to send messages and responses.
    peer: Arc<Peer>,
    /// The server-wide connection pool.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// Shared application state (config, database, executor, ...).
    app_state: Arc<AppState>,
    /// Present only when a Supermaven admin API key is configured.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    /// HTTP client created for this connection.
    http_client: Arc<dyn HttpClient>,
    /// The GeoIP country code for the user.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    /// Executor the connection was started with.
    _executor: Executor,
}
149
150impl Session {
151 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
152 #[cfg(test)]
153 tokio::task::yield_now().await;
154 let guard = self.db.lock().await;
155 #[cfg(test)]
156 tokio::task::yield_now().await;
157 guard
158 }
159
160 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
161 #[cfg(test)]
162 tokio::task::yield_now().await;
163 let guard = self.connection_pool.lock();
164 ConnectionPoolGuard {
165 guard,
166 _not_send: PhantomData,
167 }
168 }
169
170 fn for_user(self) -> Option<UserSession> {
171 UserSession::new(self)
172 }
173
174 fn for_dev_server(self) -> Option<DevServerSession> {
175 DevServerSession::new(self)
176 }
177
178 fn user_id(&self) -> Option<UserId> {
179 match &self.principal {
180 Principal::User(user) => Some(user.id),
181 Principal::Impersonated { user, .. } => Some(user.id),
182 Principal::DevServer(_) => None,
183 }
184 }
185
186 fn is_staff(&self) -> bool {
187 match &self.principal {
188 Principal::User(user) => user.admin,
189 Principal::Impersonated { .. } => true,
190 Principal::DevServer(_) => false,
191 }
192 }
193
194 pub async fn current_plan(&self, db: MutexGuard<'_, DbHandle>) -> anyhow::Result<proto::Plan> {
195 if self.is_staff() {
196 return Ok(proto::Plan::ZedPro);
197 }
198
199 let Some(user_id) = self.user_id() else {
200 return Ok(proto::Plan::Free);
201 };
202
203 if db.has_active_billing_subscription(user_id).await? {
204 Ok(proto::Plan::ZedPro)
205 } else {
206 Ok(proto::Plan::Free)
207 }
208 }
209
210 fn dev_server_id(&self) -> Option<DevServerId> {
211 match &self.principal {
212 Principal::User(_) | Principal::Impersonated { .. } => None,
213 Principal::DevServer(dev_server) => Some(dev_server.id),
214 }
215 }
216
217 fn principal_id(&self) -> PrincipalId {
218 match &self.principal {
219 Principal::User(user) => PrincipalId::UserId(user.id),
220 Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
221 Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
222 }
223 }
224}
225
226impl Debug for Session {
227 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
228 let mut result = f.debug_struct("Session");
229 match &self.principal {
230 Principal::User(user) => {
231 result.field("user", &user.github_login);
232 }
233 Principal::Impersonated { user, admin } => {
234 result.field("user", &user.github_login);
235 result.field("impersonator", &admin.github_login);
236 }
237 Principal::DevServer(dev_server) => {
238 result.field("dev_server", &dev_server.id);
239 }
240 }
241 result.field("connection_id", &self.connection_id).finish()
242 }
243}
244
245struct UserSession(Session);
246
247impl UserSession {
248 pub fn new(s: Session) -> Option<Self> {
249 s.user_id().map(|_| UserSession(s))
250 }
251 pub fn user_id(&self) -> UserId {
252 self.0.user_id().unwrap()
253 }
254
255 pub fn email(&self) -> Option<String> {
256 match &self.0.principal {
257 Principal::User(user) => user.email_address.clone(),
258 Principal::Impersonated { user, .. } => user.email_address.clone(),
259 Principal::DevServer(..) => None,
260 }
261 }
262}
263
264impl Deref for UserSession {
265 type Target = Session;
266
267 fn deref(&self) -> &Self::Target {
268 &self.0
269 }
270}
271impl DerefMut for UserSession {
272 fn deref_mut(&mut self) -> &mut Self::Target {
273 &mut self.0
274 }
275}
276
277struct DevServerSession(Session);
278
279impl DevServerSession {
280 pub fn new(s: Session) -> Option<Self> {
281 s.dev_server_id().map(|_| DevServerSession(s))
282 }
283 pub fn dev_server_id(&self) -> DevServerId {
284 self.0.dev_server_id().unwrap()
285 }
286
287 fn dev_server(&self) -> &dev_server::Model {
288 match &self.0.principal {
289 Principal::DevServer(dev_server) => dev_server,
290 _ => unreachable!(),
291 }
292 }
293}
294
295impl Deref for DevServerSession {
296 type Target = Session;
297
298 fn deref(&self) -> &Self::Target {
299 &self.0
300 }
301}
302impl DerefMut for DevServerSession {
303 fn deref_mut(&mut self) -> &mut Self::Target {
304 &mut self.0
305 }
306}
307
308fn user_handler<M: RequestMessage, Fut>(
309 handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
310) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
311where
312 Fut: Send + Future<Output = Result<()>>,
313{
314 let handler = Arc::new(handler);
315 move |message, response, session| {
316 let handler = handler.clone();
317 Box::pin(async move {
318 if let Some(user_session) = session.for_user() {
319 Ok(handler(message, response, user_session).await?)
320 } else {
321 Err(Error::Internal(anyhow!(
322 "must be a user to call {}",
323 M::NAME
324 )))
325 }
326 })
327 }
328}
329
330fn dev_server_handler<M: RequestMessage, Fut>(
331 handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
332) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
333where
334 Fut: Send + Future<Output = Result<()>>,
335{
336 let handler = Arc::new(handler);
337 move |message, response, session| {
338 let handler = handler.clone();
339 Box::pin(async move {
340 if let Some(dev_server_session) = session.for_dev_server() {
341 Ok(handler(message, response, dev_server_session).await?)
342 } else {
343 Err(Error::Internal(anyhow!(
344 "must be a dev server to call {}",
345 M::NAME
346 )))
347 }
348 })
349 }
350}
351
352fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
353 handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
354) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
355where
356 InnertRetFut: Send + Future<Output = Result<()>>,
357{
358 let handler = Arc::new(handler);
359 move |message, session| {
360 let handler = handler.clone();
361 Box::pin(async move {
362 if let Some(user_session) = session.for_user() {
363 Ok(handler(message, user_session).await?)
364 } else {
365 Err(Error::Internal(anyhow!(
366 "must be a user to call {}",
367 M::NAME
368 )))
369 }
370 })
371 }
372}
373
374struct DbHandle(Arc<Database>);
375
376impl Deref for DbHandle {
377 type Target = Database;
378
379 fn deref(&self) -> &Self::Target {
380 self.0.as_ref()
381 }
382}
383
/// The RPC server: owns the peer, the connection pool and the table of message
/// handlers shared by all connections.
pub struct Server {
    /// This server's id; only rewritten by the test-only `reset`.
    id: parking_lot::Mutex<ServerId>,
    /// Peer that manages connections and sends messages.
    peer: Arc<Peer>,
    /// Pool of live connections, shared with every `Session`.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// Shared application state (config, database, executor, ...).
    app_state: Arc<AppState>,
    /// Registered handlers, keyed by the `TypeId` of the message they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Teardown flag; each connection subscribes to it and stops when it changes.
    teardown: watch::Sender<bool>,
}
392
/// A lock guard over the server's `ConnectionPool`.
pub(crate) struct ConnectionPoolGuard<'a> {
    /// The underlying `parking_lot` mutex guard.
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    /// `Rc` is not `Send`, so this marker makes the whole guard `!Send`.
    /// NOTE(review): presumably intended to stop the guard from being held
    /// across an `.await` in a `Send` future — confirm.
    _not_send: PhantomData<Rc<()>>,
}
397
/// Serializable view of the server: its peer plus the locked connection pool.
///
/// The connection pool stays locked for as long as the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    /// Serialized through the guard's `Deref` target rather than the guard itself.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
404
405pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
406where
407 S: Serializer,
408 T: Deref<Target = U>,
409 U: Serialize,
410{
411 Serialize::serialize(value.deref(), serializer)
412}
413
414impl Server {
/// Creates a server with the given id and registers a handler for every RPC
/// message type it serves.
///
/// Handlers wrapped in `user_handler` / `user_message_handler` fail for
/// non-user principals, and those wrapped in `dev_server_handler` fail for
/// principals that are not dev servers. Unwrapped handlers receive the plain
/// `Session`.
///
/// # Panics
///
/// Panics if two handlers are registered for the same message type
/// (see `add_handler`).
pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
    let mut server = Self {
        id: parking_lot::Mutex::new(id),
        peer: Peer::new(id.0 as u32),
        app_state: app_state.clone(),
        connection_pool: Default::default(),
        handlers: Default::default(),
        // Only the sender half is kept; connections obtain receivers later
        // through `subscribe()`.
        teardown: watch::channel(false).0,
    };

    server
        .add_request_handler(ping)
        .add_request_handler(user_handler(create_room))
        .add_request_handler(user_handler(join_room))
        .add_request_handler(user_handler(rejoin_room))
        .add_request_handler(user_handler(leave_room))
        .add_request_handler(user_handler(set_room_participant_role))
        .add_request_handler(user_handler(call))
        .add_request_handler(user_handler(cancel_call))
        .add_message_handler(user_message_handler(decline_call))
        .add_request_handler(user_handler(update_participant_location))
        .add_request_handler(user_handler(share_project))
        .add_message_handler(unshare_project)
        .add_request_handler(user_handler(join_project))
        .add_request_handler(user_handler(join_hosted_project))
        .add_request_handler(user_handler(rejoin_dev_server_projects))
        .add_request_handler(user_handler(create_dev_server_project))
        .add_request_handler(user_handler(update_dev_server_project))
        .add_request_handler(user_handler(delete_dev_server_project))
        .add_request_handler(user_handler(create_dev_server))
        .add_request_handler(user_handler(regenerate_dev_server_token))
        .add_request_handler(user_handler(rename_dev_server))
        .add_request_handler(user_handler(delete_dev_server))
        .add_request_handler(user_handler(list_remote_directory))
        // Requests that only a dev server principal may send.
        .add_request_handler(dev_server_handler(share_dev_server_project))
        .add_request_handler(dev_server_handler(shutdown_dev_server))
        .add_request_handler(dev_server_handler(reconnect_dev_server))
        .add_message_handler(user_message_handler(leave_project))
        .add_request_handler(update_project)
        .add_request_handler(update_worktree)
        .add_message_handler(start_language_server)
        .add_message_handler(update_language_server)
        .add_message_handler(update_diagnostic_summary)
        .add_message_handler(update_worktree_settings)
        // Project requests routed through `forward_project_request_for_owner`.
        .add_request_handler(user_handler(
            forward_project_request_for_owner::<proto::TaskContextForLocation>,
        ))
        .add_request_handler(user_handler(
            forward_project_request_for_owner::<proto::TaskTemplates>,
        ))
        // Project requests routed through `forward_read_only_project_request`.
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::GetHover>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::GetDefinition>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::GetTypeDefinition>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::GetReferences>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::SearchProject>,
        ))
        .add_request_handler(user_handler(forward_find_search_candidates_request))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::GetDocumentHighlights>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::GetProjectSymbols>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::OpenBufferForSymbol>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::OpenBufferById>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::SynchronizeBuffers>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::InlayHints>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::ResolveInlayHint>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::OpenBufferByPath>,
        ))
        // Project requests routed through `forward_mutating_project_request`.
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::GetCompletions>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::OpenNewBuffer>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::GetCodeActions>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::ApplyCodeAction>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::PrepareRename>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::PerformRename>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::ReloadBuffers>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::FormatBuffers>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::CreateProjectEntry>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::RenameProjectEntry>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::CopyProjectEntry>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::DeleteProjectEntry>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::ExpandProjectEntry>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::OnTypeFormatting>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::SaveBuffer>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::BlameBuffer>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::MultiLspQuery>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::RestartLanguageServers>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::LinkedEditingRange>,
        ))
        .add_message_handler(create_buffer_for_peer)
        .add_request_handler(update_buffer)
        // Messages routed through `broadcast_project_message_from_host`.
        .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
        .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
        .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
        .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
        .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
        .add_request_handler(get_users)
        .add_request_handler(user_handler(fuzzy_search_users))
        .add_request_handler(user_handler(request_contact))
        .add_request_handler(user_handler(remove_contact))
        .add_request_handler(user_handler(respond_to_contact_request))
        .add_message_handler(subscribe_to_channels)
        .add_request_handler(user_handler(create_channel))
        .add_request_handler(user_handler(delete_channel))
        .add_request_handler(user_handler(invite_channel_member))
        .add_request_handler(user_handler(remove_channel_member))
        .add_request_handler(user_handler(set_channel_member_role))
        .add_request_handler(user_handler(set_channel_visibility))
        .add_request_handler(user_handler(rename_channel))
        .add_request_handler(user_handler(join_channel_buffer))
        .add_request_handler(user_handler(leave_channel_buffer))
        .add_message_handler(user_message_handler(update_channel_buffer))
        .add_request_handler(user_handler(rejoin_channel_buffers))
        .add_request_handler(user_handler(get_channel_members))
        .add_request_handler(user_handler(respond_to_channel_invite))
        .add_request_handler(user_handler(join_channel))
        .add_request_handler(user_handler(join_channel_chat))
        .add_message_handler(user_message_handler(leave_channel_chat))
        .add_request_handler(user_handler(send_channel_message))
        .add_request_handler(user_handler(remove_channel_message))
        .add_request_handler(user_handler(update_channel_message))
        .add_request_handler(user_handler(get_channel_messages))
        .add_request_handler(user_handler(get_channel_messages_by_id))
        .add_request_handler(user_handler(get_notifications))
        .add_request_handler(user_handler(mark_notification_as_read))
        .add_request_handler(user_handler(move_channel))
        .add_request_handler(user_handler(follow))
        .add_message_handler(user_message_handler(unfollow))
        .add_message_handler(user_message_handler(update_followers))
        .add_request_handler(user_handler(get_private_user_info))
        .add_request_handler(user_handler(get_llm_api_token))
        .add_request_handler(user_handler(accept_terms_of_service))
        .add_message_handler(user_message_handler(acknowledge_channel_message))
        .add_message_handler(user_message_handler(acknowledge_buffer_version))
        .add_request_handler(user_handler(get_supermaven_api_key))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::OpenContext>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::CreateContext>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::SynchronizeContexts>,
        ))
        .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
        .add_message_handler(update_context)
        // The remaining handlers are closures; two of them capture `app_state`
        // to read the server configuration.
        .add_request_handler({
            let app_state = app_state.clone();
            move |request, response, session| {
                // Clone per call so the returned future owns its own handle.
                let app_state = app_state.clone();
                async move {
                    count_language_model_tokens(request, response, session, &app_state.config)
                        .await
                }
            }
        })
        .add_request_handler({
            user_handler(move |request, response, session| {
                get_cached_embeddings(request, response, session)
            })
        })
        .add_request_handler({
            let app_state = app_state.clone();
            user_handler(move |request, response, session| {
                compute_embeddings(
                    request,
                    response,
                    session,
                    app_state.config.openai_api_key.clone(),
                )
            })
        });

    Arc::new(server)
}
654
/// Spawns a detached background task that cleans up stale server resources.
///
/// The task sleeps for `CLEANUP_TIMEOUT`, then asks the database for stale
/// room and channel ids, clears their stale collaborators/participants,
/// notifies the affected connections, and finally deletes stale server
/// records.
///
/// This method itself only spawns the task and always returns `Ok(())`.
/// Failures inside the task go through `trace_err()` and the affected step is
/// skipped; nothing is propagated to the caller.
///
/// NOTE(review): what counts as "stale" is decided by the `Database` methods,
/// which are not visible here — presumably resources owned by servers other
/// than `server_id` in the same environment; confirm.
pub async fn start(&self) -> Result<()> {
    let server_id = *self.id.lock();
    let app_state = self.app_state.clone();
    let peer = self.peer.clone();
    let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
    let pool = self.connection_pool.clone();
    let live_kit_client = self.app_state.live_kit_client.clone();

    let span = info_span!("start server");
    self.app_state.executor.spawn_detached(
        async move {
            tracing::info!("waiting for cleanup timeout");
            timeout.await;
            tracing::info!("cleanup timeout expired, retrieving stale rooms");
            if let Some((room_ids, channel_ids)) = app_state
                .db
                .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                .await
                .trace_err()
            {
                tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                tracing::info!(
                    stale_channel_buffer_count = channel_ids.len(),
                    "retrieved stale channel buffers"
                );

                // Channel buffers: clear stale collaborators, then send the
                // refreshed collaborator list to each returned connection.
                for channel_id in channel_ids {
                    if let Some(refreshed_channel_buffer) = app_state
                        .db
                        .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                        .await
                        .trace_err()
                    {
                        for connection_id in refreshed_channel_buffer.connection_ids {
                            peer.send(
                                connection_id,
                                proto::UpdateChannelBufferCollaborators {
                                    channel_id: channel_id.to_proto(),
                                    collaborators: refreshed_channel_buffer
                                        .collaborators
                                        .clone(),
                                },
                            )
                            .trace_err();
                        }
                    }
                }

                // Rooms: clear stale participants, then notify everyone affected.
                for room_id in room_ids {
                    // These keep their defaults when the room could not be
                    // refreshed, which turns the steps below into no-ops.
                    let mut contacts_to_update = HashSet::default();
                    let mut canceled_calls_to_user_ids = Vec::new();
                    let mut live_kit_room = String::new();
                    let mut delete_live_kit_room = false;

                    if let Some(mut refreshed_room) = app_state
                        .db
                        .clear_stale_room_participants(room_id, server_id)
                        .await
                        .trace_err()
                    {
                        tracing::info!(
                            room_id = room_id.0,
                            new_participant_count = refreshed_room.room.participants.len(),
                            "refreshed room"
                        );
                        room_updated(&refreshed_room.room, &peer);
                        if let Some(channel) = refreshed_room.channel.as_ref() {
                            channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                        }
                        contacts_to_update
                            .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                        contacts_to_update
                            .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                        canceled_calls_to_user_ids =
                            mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                        live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                        // The LiveKit room is only deleted once nobody is left in the room.
                        delete_live_kit_room = refreshed_room.room.participants.is_empty();
                    }

                    // Scoped so the pool lock is released before the awaits below.
                    {
                        let pool = pool.lock();
                        for canceled_user_id in canceled_calls_to_user_ids {
                            for connection_id in pool.user_connection_ids(canceled_user_id) {
                                peer.send(
                                    connection_id,
                                    proto::CallCanceled {
                                        room_id: room_id.to_proto(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    // Send each affected user's updated contact entry to all
                    // connections of their accepted contacts.
                    for user_id in contacts_to_update {
                        let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                        let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                        // Skipped unless both lookups succeeded.
                        if let Some((busy, contacts)) = busy.zip(contacts) {
                            let pool = pool.lock();
                            let updated_contact = contact_for_user(user_id, busy, &pool);
                            for contact in contacts {
                                if let db::Contact::Accepted {
                                    user_id: contact_user_id,
                                    ..
                                } = contact
                                {
                                    for contact_conn_id in
                                        pool.user_connection_ids(contact_user_id)
                                    {
                                        peer.send(
                                            contact_conn_id,
                                            proto::UpdateContacts {
                                                contacts: vec![updated_contact.clone()],
                                                remove_contacts: Default::default(),
                                                incoming_requests: Default::default(),
                                                remove_incoming_requests: Default::default(),
                                                outgoing_requests: Default::default(),
                                                remove_outgoing_requests: Default::default(),
                                            },
                                        )
                                        .trace_err();
                                    }
                                }
                            }
                        }
                    }

                    if let Some(live_kit) = live_kit_client.as_ref() {
                        if delete_live_kit_room {
                            live_kit.delete_room(live_kit_room).await.trace_err();
                        }
                    }
                }
            }

            // Runs even when the stale resource ids could not be retrieved.
            app_state
                .db
                .delete_stale_servers(&app_state.config.zed_environment, server_id)
                .await
                .trace_err();
        }
        .instrument(span),
    );
    Ok(())
}
800
801 pub fn teardown(&self) {
802 self.peer.teardown();
803 self.connection_pool.lock().reset();
804 let _ = self.teardown.send(true);
805 }
806
807 #[cfg(test)]
808 pub fn reset(&self, id: ServerId) {
809 self.teardown();
810 *self.id.lock() = id;
811 self.peer.reset(id.0 as u32);
812 let _ = self.teardown.send(false);
813 }
814
/// Test-only: returns the server's current id.
#[cfg(test)]
pub fn id(&self) -> ServerId {
    let current = self.id.lock();
    *current
}
819
820 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
821 where
822 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
823 Fut: 'static + Send + Future<Output = Result<()>>,
824 M: EnvelopedMessage,
825 {
826 let prev_handler = self.handlers.insert(
827 TypeId::of::<M>(),
828 Box::new(move |envelope, session| {
829 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
830 let received_at = envelope.received_at;
831 tracing::info!("message received");
832 let start_time = Instant::now();
833 let future = (handler)(*envelope, session);
834 async move {
835 let result = future.await;
836 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
837 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
838 let queue_duration_ms = total_duration_ms - processing_duration_ms;
839 let payload_type = M::NAME;
840
841 match result {
842 Err(error) => {
843 tracing::error!(
844 ?error,
845 total_duration_ms,
846 processing_duration_ms,
847 queue_duration_ms,
848 payload_type,
849 "error handling message"
850 )
851 }
852 Ok(()) => tracing::info!(
853 total_duration_ms,
854 processing_duration_ms,
855 queue_duration_ms,
856 "finished handling message"
857 ),
858 }
859 }
860 .boxed()
861 }),
862 );
863 if prev_handler.is_some() {
864 panic!("registered a handler for the same message twice");
865 }
866 self
867 }
868
869 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
870 where
871 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
872 Fut: 'static + Send + Future<Output = Result<()>>,
873 M: EnvelopedMessage,
874 {
875 self.add_handler(move |envelope, session| handler(envelope.payload, session));
876 self
877 }
878
879 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
880 where
881 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
882 Fut: Send + Future<Output = Result<()>>,
883 M: RequestMessage,
884 {
885 let handler = Arc::new(handler);
886 self.add_handler(move |envelope, session| {
887 let receipt = envelope.receipt();
888 let handler = handler.clone();
889 async move {
890 let peer = session.peer.clone();
891 let responded = Arc::new(AtomicBool::default());
892 let response = Response {
893 peer: peer.clone(),
894 responded: responded.clone(),
895 receipt,
896 };
897 match (handler)(envelope.payload, response, session).await {
898 Ok(()) => {
899 if responded.load(std::sync::atomic::Ordering::SeqCst) {
900 Ok(())
901 } else {
902 Err(anyhow!("handler did not send a response"))?
903 }
904 }
905 Err(error) => {
906 let proto_err = match &error {
907 Error::Internal(err) => err.to_proto(),
908 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
909 };
910 peer.respond_with_error(receipt, proto_err)?;
911 Err(error)
912 }
913 }
914 }
915 })
916 }
917
918 #[allow(clippy::too_many_arguments)]
919 pub fn handle_connection(
920 self: &Arc<Self>,
921 connection: Connection,
922 address: String,
923 principal: Principal,
924 zed_version: ZedVersion,
925 geoip_country_code: Option<String>,
926 send_connection_id: Option<oneshot::Sender<ConnectionId>>,
927 executor: Executor,
928 ) -> impl Future<Output = ()> {
929 let this = self.clone();
930 let span = info_span!("handle connection", %address,
931 connection_id=field::Empty,
932 user_id=field::Empty,
933 login=field::Empty,
934 impersonator=field::Empty,
935 dev_server_id=field::Empty,
936 geoip_country_code=field::Empty
937 );
938 principal.update_span(&span);
939 if let Some(country_code) = geoip_country_code.as_ref() {
940 span.record("geoip_country_code", country_code);
941 }
942
943 let mut teardown = self.teardown.subscribe();
944 async move {
945 if *teardown.borrow() {
946 tracing::error!("server is tearing down");
947 return
948 }
949 let (connection_id, handle_io, mut incoming_rx) = this
950 .peer
951 .add_connection(connection, {
952 let executor = executor.clone();
953 move |duration| executor.sleep(duration)
954 });
955 tracing::Span::current().record("connection_id", format!("{}", connection_id));
956
957 tracing::info!("connection opened");
958
959
960 let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
961 let http_client = match IsahcHttpClient::builder().default_header("User-Agent", user_agent).build() {
962 Ok(http_client) => Arc::new(IsahcHttpClient::from(http_client)),
963 Err(error) => {
964 tracing::error!(?error, "failed to create HTTP client");
965 return;
966 }
967 };
968
969 let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
970 supermaven_admin_api_key.to_string(),
971 http_client.clone(),
972 )));
973
974 let session = Session {
975 principal: principal.clone(),
976 connection_id,
977 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
978 peer: this.peer.clone(),
979 connection_pool: this.connection_pool.clone(),
980 app_state: this.app_state.clone(),
981 http_client,
982 geoip_country_code,
983 _executor: executor.clone(),
984 supermaven_client,
985 };
986
987 if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
988 tracing::error!(?error, "failed to send initial client update");
989 return;
990 }
991
992 let handle_io = handle_io.fuse();
993 futures::pin_mut!(handle_io);
994
995 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
996 // This prevents deadlocks when e.g., client A performs a request to client B and
997 // client B performs a request to client A. If both clients stop processing further
998 // messages until their respective request completes, they won't have a chance to
999 // respond to the other client's request and cause a deadlock.
1000 //
1001 // This arrangement ensures we will attempt to process earlier messages first, but fall
1002 // back to processing messages arrived later in the spirit of making progress.
1003 let mut foreground_message_handlers = FuturesUnordered::new();
1004 let concurrent_handlers = Arc::new(Semaphore::new(256));
1005 loop {
1006 let next_message = async {
1007 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
1008 let message = incoming_rx.next().await;
1009 (permit, message)
1010 }.fuse();
1011 futures::pin_mut!(next_message);
1012 futures::select_biased! {
1013 _ = teardown.changed().fuse() => return,
1014 result = handle_io => {
1015 if let Err(error) = result {
1016 tracing::error!(?error, "error handling I/O");
1017 }
1018 break;
1019 }
1020 _ = foreground_message_handlers.next() => {}
1021 next_message = next_message => {
1022 let (permit, message) = next_message;
1023 if let Some(message) = message {
1024 let type_name = message.payload_type_name();
1025 // note: we copy all the fields from the parent span so we can query them in the logs.
1026 // (https://github.com/tokio-rs/tracing/issues/2670).
1027 let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
1028 user_id=field::Empty,
1029 login=field::Empty,
1030 impersonator=field::Empty,
1031 dev_server_id=field::Empty
1032 );
1033 principal.update_span(&span);
1034 let span_enter = span.enter();
1035 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
1036 let is_background = message.is_background();
1037 let handle_message = (handler)(message, session.clone());
1038 drop(span_enter);
1039
1040 let handle_message = async move {
1041 handle_message.await;
1042 drop(permit);
1043 }.instrument(span);
1044 if is_background {
1045 executor.spawn_detached(handle_message);
1046 } else {
1047 foreground_message_handlers.push(handle_message);
1048 }
1049 } else {
1050 tracing::error!("no message handler");
1051 }
1052 } else {
1053 tracing::info!("connection closed");
1054 break;
1055 }
1056 }
1057 }
1058 }
1059
1060 drop(foreground_message_handlers);
1061 tracing::info!("signing out");
1062 if let Err(error) = connection_lost(session, teardown, executor).await {
1063 tracing::error!(?error, "error signing out");
1064 }
1065
1066 }.instrument(span)
1067 }
1068
    /// Sends the first batch of messages to a freshly accepted connection and registers the
    /// connection in the connection pool.
    ///
    /// A `proto::Hello` carrying the connection's peer id is always sent first; afterwards
    /// `send_connection_id`, when provided, is resolved with `connection_id`. What follows
    /// depends on the principal:
    ///
    /// * Users and impersonated users: `ShowContacts` is sent and the "connected once" flag is
    ///   stored if the user had never connected before, the user's plan is updated, the
    ///   connection is added to the pool, and the initial contacts update, the dev server
    ///   projects update and any pending incoming call are sent. Channel subscription only
    ///   happens when `should_auto_subscribe_to_channels(zed_version)` is true.
    /// * Dev servers: a connection already registered for the same dev server is sent
    ///   `ShutdownDevServer` and removed from the pool before the new one is added; the dev
    ///   server then receives `DevServerInstructions` for its projects and a dev server
    ///   projects update is sent for the owning user.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by any send or database call that is propagated
    /// with `?`.
    async fn send_initial_client_update(
        &self,
        connection_id: ConnectionId,
        principal: &Principal,
        zed_version: ZedVersion,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        session: &Session,
    ) -> Result<()> {
        self.peer.send(
            connection_id,
            proto::Hello {
                peer_id: Some(connection_id.into()),
            },
        )?;
        tracing::info!("sent hello message");
        if let Some(send_connection_id) = send_connection_id.take() {
            // The result is ignored on purpose: a failed send here is not treated as an error.
            let _ = send_connection_id.send(connection_id);
        }

        match principal {
            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
                if !user.connected_once {
                    self.peer.send(connection_id, proto::ShowContacts {})?;
                    self.app_state
                        .db
                        .set_user_connected_once(user.id, true)
                        .await?;
                }

                update_user_plan(user.id, session).await?;

                // Both queries are awaited together; the first failure aborts the pair.
                let (contacts, dev_server_projects) = future::try_join(
                    self.app_state.db.get_contacts(user.id),
                    self.app_state.db.dev_server_projects_update(user.id),
                )
                .await?;

                {
                    // The pool guard is confined to this block, so it is released before the
                    // next `.await`.
                    let mut pool = self.connection_pool.lock();
                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
                    self.peer.send(
                        connection_id,
                        build_initial_contacts_update(contacts, &pool),
                    )?;
                }

                if should_auto_subscribe_to_channels(zed_version) {
                    subscribe_user_to_channels(user.id, session).await?;
                }

                send_dev_server_projects_update(user.id, dev_server_projects, session).await;

                if let Some(incoming_call) =
                    self.app_state.db.incoming_call_for_user(user.id).await?
                {
                    self.peer.send(connection_id, incoming_call)?;
                }

                update_user_contacts(user.id, session).await?;
            }
            Principal::DevServer(dev_server) => {
                {
                    // Same scoping as above: the guard is dropped before the database calls.
                    let mut pool = self.connection_pool.lock();
                    if let Some(stale_connection_id) = pool.dev_server_connection_id(dev_server.id)
                    {
                        // A connection is already registered for this dev server: ask it to
                        // shut down and unregister it before registering the new one.
                        self.peer.send(
                            stale_connection_id,
                            proto::ShutdownDevServer {
                                reason: Some(
                                    "another dev server connected with the same token".to_string(),
                                ),
                            },
                        )?;
                        pool.remove_connection(stale_connection_id)?;
                    };
                    pool.add_dev_server(connection_id, dev_server.id, zed_version);
                }

                let projects = self
                    .app_state
                    .db
                    .get_projects_for_dev_server(dev_server.id)
                    .await?;
                self.peer
                    .send(connection_id, proto::DevServerInstructions { projects })?;

                let status = self
                    .app_state
                    .db
                    .dev_server_projects_update(dev_server.user_id)
                    .await?;
                send_dev_server_projects_update(dev_server.user_id, status, session).await;
            }
        }

        Ok(())
    }
1166
1167 pub async fn invite_code_redeemed(
1168 self: &Arc<Self>,
1169 inviter_id: UserId,
1170 invitee_id: UserId,
1171 ) -> Result<()> {
1172 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1173 if let Some(code) = &user.invite_code {
1174 let pool = self.connection_pool.lock();
1175 let invitee_contact = contact_for_user(invitee_id, false, &pool);
1176 for connection_id in pool.user_connection_ids(inviter_id) {
1177 self.peer.send(
1178 connection_id,
1179 proto::UpdateContacts {
1180 contacts: vec![invitee_contact.clone()],
1181 ..Default::default()
1182 },
1183 )?;
1184 self.peer.send(
1185 connection_id,
1186 proto::UpdateInviteInfo {
1187 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1188 count: user.invite_count as u32,
1189 },
1190 )?;
1191 }
1192 }
1193 }
1194 Ok(())
1195 }
1196
1197 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1198 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1199 if let Some(invite_code) = &user.invite_code {
1200 let pool = self.connection_pool.lock();
1201 for connection_id in pool.user_connection_ids(user_id) {
1202 self.peer.send(
1203 connection_id,
1204 proto::UpdateInviteInfo {
1205 url: format!(
1206 "{}{}",
1207 self.app_state.config.invite_link_prefix, invite_code
1208 ),
1209 count: user.invite_count as u32,
1210 },
1211 )?;
1212 }
1213 }
1214 }
1215 Ok(())
1216 }
1217
1218 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1219 ServerSnapshot {
1220 connection_pool: ConnectionPoolGuard {
1221 guard: self.connection_pool.lock(),
1222 _not_send: PhantomData,
1223 },
1224 peer: &self.peer,
1225 }
1226 }
1227}
1228
1229impl<'a> Deref for ConnectionPoolGuard<'a> {
1230 type Target = ConnectionPool;
1231
1232 fn deref(&self) -> &Self::Target {
1233 &self.guard
1234 }
1235}
1236
1237impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1238 fn deref_mut(&mut self) -> &mut Self::Target {
1239 &mut self.guard
1240 }
1241}
1242
1243impl<'a> Drop for ConnectionPoolGuard<'a> {
1244 fn drop(&mut self) {
1245 #[cfg(test)]
1246 self.check_invariants();
1247 }
1248}
1249
1250fn broadcast<F>(
1251 sender_id: Option<ConnectionId>,
1252 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1253 mut f: F,
1254) where
1255 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1256{
1257 for receiver_id in receiver_ids {
1258 if Some(receiver_id) != sender_id {
1259 if let Err(error) = f(receiver_id) {
1260 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1261 }
1262 }
1263 }
1264}
1265
1266pub struct ProtocolVersion(u32);
1267
1268impl Header for ProtocolVersion {
1269 fn name() -> &'static HeaderName {
1270 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1271 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1272 }
1273
1274 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1275 where
1276 Self: Sized,
1277 I: Iterator<Item = &'i axum::http::HeaderValue>,
1278 {
1279 let version = values
1280 .next()
1281 .ok_or_else(axum::headers::Error::invalid)?
1282 .to_str()
1283 .map_err(|_| axum::headers::Error::invalid())?
1284 .parse()
1285 .map_err(|_| axum::headers::Error::invalid())?;
1286 Ok(Self(version))
1287 }
1288
1289 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1290 values.extend([self.0.to_string().parse().unwrap()]);
1291 }
1292}
1293
1294pub struct AppVersionHeader(SemanticVersion);
1295impl Header for AppVersionHeader {
1296 fn name() -> &'static HeaderName {
1297 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1298 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1299 }
1300
1301 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1302 where
1303 Self: Sized,
1304 I: Iterator<Item = &'i axum::http::HeaderValue>,
1305 {
1306 let version = values
1307 .next()
1308 .ok_or_else(axum::headers::Error::invalid)?
1309 .to_str()
1310 .map_err(|_| axum::headers::Error::invalid())?
1311 .parse()
1312 .map_err(|_| axum::headers::Error::invalid())?;
1313 Ok(Self(version))
1314 }
1315
1316 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1317 values.extend([self.0.to_string().parse().unwrap()]);
1318 }
1319}
1320
1321pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1322 Router::new()
1323 .route("/rpc", get(handle_websocket_request))
1324 .layer(
1325 ServiceBuilder::new()
1326 .layer(Extension(server.app_state.clone()))
1327 .layer(middleware::from_fn(auth::validate_header)),
1328 )
1329 .route("/metrics", get(handle_metrics))
1330 .layer(Extension(server))
1331}
1332
1333pub async fn handle_websocket_request(
1334 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1335 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1336 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1337 Extension(server): Extension<Arc<Server>>,
1338 Extension(principal): Extension<Principal>,
1339 country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1340 ws: WebSocketUpgrade,
1341) -> axum::response::Response {
1342 if protocol_version != rpc::PROTOCOL_VERSION {
1343 return (
1344 StatusCode::UPGRADE_REQUIRED,
1345 "client must be upgraded".to_string(),
1346 )
1347 .into_response();
1348 }
1349
1350 let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1351 return (
1352 StatusCode::UPGRADE_REQUIRED,
1353 "no version header found".to_string(),
1354 )
1355 .into_response();
1356 };
1357
1358 if !version.can_collaborate() {
1359 return (
1360 StatusCode::UPGRADE_REQUIRED,
1361 "client must be upgraded".to_string(),
1362 )
1363 .into_response();
1364 }
1365
1366 let socket_address = socket_address.to_string();
1367 ws.on_upgrade(move |socket| {
1368 let socket = socket
1369 .map_ok(to_tungstenite_message)
1370 .err_into()
1371 .with(|message| async move { to_axum_message(message) });
1372 let connection = Connection::new(Box::pin(socket));
1373 async move {
1374 server
1375 .handle_connection(
1376 connection,
1377 socket_address,
1378 principal,
1379 version,
1380 country_code_header.map(|header| header.to_string()),
1381 None,
1382 Executor::Production,
1383 )
1384 .await;
1385 }
1386 })
1387}
1388
1389pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1390 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1391 let connections_metric = CONNECTIONS_METRIC
1392 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1393
1394 let connections = server
1395 .connection_pool
1396 .lock()
1397 .connections()
1398 .filter(|connection| !connection.admin)
1399 .count();
1400 connections_metric.set(connections as _);
1401
1402 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1403 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1404 register_int_gauge!(
1405 "shared_projects",
1406 "number of open projects with one or more guests"
1407 )
1408 .unwrap()
1409 });
1410
1411 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1412 shared_projects_metric.set(shared_projects as _);
1413
1414 let encoder = prometheus::TextEncoder::new();
1415 let metric_families = prometheus::gather();
1416 let encoded_metrics = encoder
1417 .encode_to_string(&metric_families)
1418 .map_err(|err| anyhow!("{}", err))?;
1419 Ok(encoded_metrics)
1420}
1421
/// Cleans up after a connection has gone away.
///
/// The connection is disconnected from the peer and removed from the connection pool right
/// away, and the database is told the connection was lost (a failure there is handed to
/// `trace_err` instead of being returned). The function then waits for whichever happens
/// first: `RECONNECT_TIMEOUT` elapsing or the `teardown` signal changing.
///
/// * If the timeout elapses, the resources tied to the connection are released: a user
///   session leaves its room and channel buffers, declines any call when the user has no
///   connection left online, and has its contacts updated; a dev server session goes
///   through `lost_dev_server_connection`.
/// * If `teardown` changes first, nothing further is done.
///
/// # Errors
///
/// Returns an error if the connection cannot be removed from the pool, or if the contact
/// update or the dev server cleanup fails.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls its branches in the order written, so the timeout branch is
    // checked before the teardown branch.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            match &session.principal {
                Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
                    let session = session.for_user().unwrap();

                    log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
                    leave_room_for_session(&session, session.connection_id).await.trace_err();
                    leave_channel_buffers_for_session(&session)
                        .await
                        .trace_err();

                    // Only decline a pending call when the user has no connection left.
                    if !session
                        .connection_pool()
                        .await
                        .is_user_online(session.user_id())
                    {
                        let db = session.db().await;
                        if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                            room_updated(&room, &session.peer);
                        }
                    }

                    update_user_contacts(session.user_id(), &session).await?;
                },
                Principal::DevServer(_) => {
                    lost_dev_server_connection(&session.for_dev_server().unwrap()).await?;
                },
            }
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1476
1477/// Acknowledges a ping from a client, used to keep the connection alive.
1478async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1479 response.send(proto::Ack {})?;
1480 Ok(())
1481}
1482
1483/// Creates a new room for calling (outside of channels)
1484async fn create_room(
1485 _request: proto::CreateRoom,
1486 response: Response<proto::CreateRoom>,
1487 session: UserSession,
1488) -> Result<()> {
1489 let live_kit_room = nanoid::nanoid!(30);
1490
1491 let live_kit_connection_info = util::maybe!(async {
1492 let live_kit = session.app_state.live_kit_client.as_ref();
1493 let live_kit = live_kit?;
1494 let user_id = session.user_id().to_string();
1495
1496 let token = live_kit
1497 .room_token(&live_kit_room, &user_id.to_string())
1498 .trace_err()?;
1499
1500 Some(proto::LiveKitConnectionInfo {
1501 server_url: live_kit.url().into(),
1502 token,
1503 can_publish: true,
1504 })
1505 })
1506 .await;
1507
1508 let room = session
1509 .db()
1510 .await
1511 .create_room(session.user_id(), session.connection_id, &live_kit_room)
1512 .await?;
1513
1514 response.send(proto::CreateRoomResponse {
1515 room: Some(room.clone()),
1516 live_kit_connection_info,
1517 })?;
1518
1519 update_user_contacts(session.user_id(), &session).await?;
1520 Ok(())
1521}
1522
1523/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1524async fn join_room(
1525 request: proto::JoinRoom,
1526 response: Response<proto::JoinRoom>,
1527 session: UserSession,
1528) -> Result<()> {
1529 let room_id = RoomId::from_proto(request.id);
1530
1531 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1532
1533 if let Some(channel_id) = channel_id {
1534 return join_channel_internal(channel_id, Box::new(response), session).await;
1535 }
1536
1537 let joined_room = {
1538 let room = session
1539 .db()
1540 .await
1541 .join_room(room_id, session.user_id(), session.connection_id)
1542 .await?;
1543 room_updated(&room.room, &session.peer);
1544 room.into_inner()
1545 };
1546
1547 for connection_id in session
1548 .connection_pool()
1549 .await
1550 .user_connection_ids(session.user_id())
1551 {
1552 session
1553 .peer
1554 .send(
1555 connection_id,
1556 proto::CallCanceled {
1557 room_id: room_id.to_proto(),
1558 },
1559 )
1560 .trace_err();
1561 }
1562
1563 let live_kit_connection_info =
1564 if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
1565 live_kit
1566 .room_token(
1567 &joined_room.room.live_kit_room,
1568 &session.user_id().to_string(),
1569 )
1570 .trace_err()
1571 .map(|token| proto::LiveKitConnectionInfo {
1572 server_url: live_kit.url().into(),
1573 token,
1574 can_publish: true,
1575 })
1576 } else {
1577 None
1578 };
1579
1580 response.send(proto::JoinRoomResponse {
1581 room: Some(joined_room.room),
1582 channel_id: None,
1583 live_kit_connection_info,
1584 })?;
1585
1586 update_user_contacts(session.user_id(), &session).await?;
1587 Ok(())
1588}
1589
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// The response lists the room along with the projects the caller re-shared and the
/// projects it rejoined. Collaborators of every re-shared project are told about the
/// caller's new peer id and are forwarded an `UpdateProject` with the project's worktrees;
/// rejoined projects are handled by `notify_rejoined_projects`. If the room belongs to a
/// channel, `channel_updated` is called, and finally the caller's contacts are updated.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: UserSession,
) -> Result<()> {
    // Declared outside the block below so that only these two values outlive
    // `rejoined_room`.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell each collaborator that the caller's old peer id was replaced by the
            // current connection.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the project's worktrees to every collaborator other than the caller.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1681
1682fn notify_rejoined_projects(
1683 rejoined_projects: &mut Vec<RejoinedProject>,
1684 session: &UserSession,
1685) -> Result<()> {
1686 for project in rejoined_projects.iter() {
1687 for collaborator in &project.collaborators {
1688 session
1689 .peer
1690 .send(
1691 collaborator.connection_id,
1692 proto::UpdateProjectCollaborator {
1693 project_id: project.id.to_proto(),
1694 old_peer_id: Some(project.old_connection_id.into()),
1695 new_peer_id: Some(session.connection_id.into()),
1696 },
1697 )
1698 .trace_err();
1699 }
1700 }
1701
1702 for project in rejoined_projects {
1703 for worktree in mem::take(&mut project.worktrees) {
1704 #[cfg(any(test, feature = "test-support"))]
1705 const MAX_CHUNK_SIZE: usize = 2;
1706 #[cfg(not(any(test, feature = "test-support")))]
1707 const MAX_CHUNK_SIZE: usize = 256;
1708
1709 // Stream this worktree's entries.
1710 let message = proto::UpdateWorktree {
1711 project_id: project.id.to_proto(),
1712 worktree_id: worktree.id,
1713 abs_path: worktree.abs_path.clone(),
1714 root_name: worktree.root_name,
1715 updated_entries: worktree.updated_entries,
1716 removed_entries: worktree.removed_entries,
1717 scan_id: worktree.scan_id,
1718 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1719 updated_repositories: worktree.updated_repositories,
1720 removed_repositories: worktree.removed_repositories,
1721 };
1722 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1723 session.peer.send(session.connection_id, update.clone())?;
1724 }
1725
1726 // Stream this worktree's diagnostics.
1727 for summary in worktree.diagnostic_summaries {
1728 session.peer.send(
1729 session.connection_id,
1730 proto::UpdateDiagnosticSummary {
1731 project_id: project.id.to_proto(),
1732 worktree_id: worktree.id,
1733 summary: Some(summary),
1734 },
1735 )?;
1736 }
1737
1738 for settings_file in worktree.settings_files {
1739 session.peer.send(
1740 session.connection_id,
1741 proto::UpdateWorktreeSettings {
1742 project_id: project.id.to_proto(),
1743 worktree_id: worktree.id,
1744 path: settings_file.path,
1745 content: Some(settings_file.content),
1746 },
1747 )?;
1748 }
1749 }
1750
1751 for language_server in &project.language_servers {
1752 session.peer.send(
1753 session.connection_id,
1754 proto::UpdateLanguageServer {
1755 project_id: project.id.to_proto(),
1756 language_server_id: language_server.id,
1757 variant: Some(
1758 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1759 proto::LspDiskBasedDiagnosticsUpdated {},
1760 ),
1761 ),
1762 },
1763 )?;
1764 }
1765 }
1766 Ok(())
1767}
1768
1769/// leave room disconnects from the room.
1770async fn leave_room(
1771 _: proto::LeaveRoom,
1772 response: Response<proto::LeaveRoom>,
1773 session: UserSession,
1774) -> Result<()> {
1775 leave_room_for_session(&session, session.connection_id).await?;
1776 response.send(proto::Ack {})?;
1777 Ok(())
1778}
1779
1780/// Updates the permissions of someone else in the room.
1781async fn set_room_participant_role(
1782 request: proto::SetRoomParticipantRole,
1783 response: Response<proto::SetRoomParticipantRole>,
1784 session: UserSession,
1785) -> Result<()> {
1786 let user_id = UserId::from_proto(request.user_id);
1787 let role = ChannelRole::from(request.role());
1788
1789 let (live_kit_room, can_publish) = {
1790 let room = session
1791 .db()
1792 .await
1793 .set_room_participant_role(
1794 session.user_id(),
1795 RoomId::from_proto(request.room_id),
1796 user_id,
1797 role,
1798 )
1799 .await?;
1800
1801 let live_kit_room = room.live_kit_room.clone();
1802 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1803 room_updated(&room, &session.peer);
1804 (live_kit_room, can_publish)
1805 };
1806
1807 if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
1808 live_kit
1809 .update_participant(
1810 live_kit_room.clone(),
1811 request.user_id.to_string(),
1812 live_kit_server::proto::ParticipantPermission {
1813 can_subscribe: true,
1814 can_publish,
1815 can_publish_data: can_publish,
1816 hidden: false,
1817 recorder: false,
1818 },
1819 )
1820 .await
1821 .trace_err();
1822 }
1823
1824 response.send(proto::Ack {})?;
1825 Ok(())
1826}
1827
1828/// Call someone else into the current room
1829async fn call(
1830 request: proto::Call,
1831 response: Response<proto::Call>,
1832 session: UserSession,
1833) -> Result<()> {
1834 let room_id = RoomId::from_proto(request.room_id);
1835 let calling_user_id = session.user_id();
1836 let calling_connection_id = session.connection_id;
1837 let called_user_id = UserId::from_proto(request.called_user_id);
1838 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1839 if !session
1840 .db()
1841 .await
1842 .has_contact(calling_user_id, called_user_id)
1843 .await?
1844 {
1845 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1846 }
1847
1848 let incoming_call = {
1849 let (room, incoming_call) = &mut *session
1850 .db()
1851 .await
1852 .call(
1853 room_id,
1854 calling_user_id,
1855 calling_connection_id,
1856 called_user_id,
1857 initial_project_id,
1858 )
1859 .await?;
1860 room_updated(room, &session.peer);
1861 mem::take(incoming_call)
1862 };
1863 update_user_contacts(called_user_id, &session).await?;
1864
1865 let mut calls = session
1866 .connection_pool()
1867 .await
1868 .user_connection_ids(called_user_id)
1869 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1870 .collect::<FuturesUnordered<_>>();
1871
1872 while let Some(call_response) = calls.next().await {
1873 match call_response.as_ref() {
1874 Ok(_) => {
1875 response.send(proto::Ack {})?;
1876 return Ok(());
1877 }
1878 Err(_) => {
1879 call_response.trace_err();
1880 }
1881 }
1882 }
1883
1884 {
1885 let room = session
1886 .db()
1887 .await
1888 .call_failed(room_id, called_user_id)
1889 .await?;
1890 room_updated(&room, &session.peer);
1891 }
1892 update_user_contacts(called_user_id, &session).await?;
1893
1894 Err(anyhow!("failed to ring user"))?
1895}
1896
1897/// Cancel an outgoing call.
1898async fn cancel_call(
1899 request: proto::CancelCall,
1900 response: Response<proto::CancelCall>,
1901 session: UserSession,
1902) -> Result<()> {
1903 let called_user_id = UserId::from_proto(request.called_user_id);
1904 let room_id = RoomId::from_proto(request.room_id);
1905 {
1906 let room = session
1907 .db()
1908 .await
1909 .cancel_call(room_id, session.connection_id, called_user_id)
1910 .await?;
1911 room_updated(&room, &session.peer);
1912 }
1913
1914 for connection_id in session
1915 .connection_pool()
1916 .await
1917 .user_connection_ids(called_user_id)
1918 {
1919 session
1920 .peer
1921 .send(
1922 connection_id,
1923 proto::CallCanceled {
1924 room_id: room_id.to_proto(),
1925 },
1926 )
1927 .trace_err();
1928 }
1929 response.send(proto::Ack {})?;
1930
1931 update_user_contacts(called_user_id, &session).await?;
1932 Ok(())
1933}
1934
1935/// Decline an incoming call.
1936async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1937 let room_id = RoomId::from_proto(message.room_id);
1938 {
1939 let room = session
1940 .db()
1941 .await
1942 .decline_call(Some(room_id), session.user_id())
1943 .await?
1944 .ok_or_else(|| anyhow!("failed to decline call"))?;
1945 room_updated(&room, &session.peer);
1946 }
1947
1948 for connection_id in session
1949 .connection_pool()
1950 .await
1951 .user_connection_ids(session.user_id())
1952 {
1953 session
1954 .peer
1955 .send(
1956 connection_id,
1957 proto::CallCanceled {
1958 room_id: room_id.to_proto(),
1959 },
1960 )
1961 .trace_err();
1962 }
1963 update_user_contacts(session.user_id(), &session).await?;
1964 Ok(())
1965}
1966
1967/// Updates other participants in the room with your current location.
1968async fn update_participant_location(
1969 request: proto::UpdateParticipantLocation,
1970 response: Response<proto::UpdateParticipantLocation>,
1971 session: UserSession,
1972) -> Result<()> {
1973 let room_id = RoomId::from_proto(request.room_id);
1974 let location = request
1975 .location
1976 .ok_or_else(|| anyhow!("invalid location"))?;
1977
1978 let db = session.db().await;
1979 let room = db
1980 .update_room_participant_location(room_id, session.connection_id, location)
1981 .await?;
1982
1983 room_updated(&room, &session.peer);
1984 response.send(proto::Ack {})?;
1985 Ok(())
1986}
1987
1988/// Share a project into the room.
1989async fn share_project(
1990 request: proto::ShareProject,
1991 response: Response<proto::ShareProject>,
1992 session: UserSession,
1993) -> Result<()> {
1994 let (project_id, room) = &*session
1995 .db()
1996 .await
1997 .share_project(
1998 RoomId::from_proto(request.room_id),
1999 session.connection_id,
2000 &request.worktrees,
2001 request.is_ssh_project,
2002 request
2003 .dev_server_project_id
2004 .map(DevServerProjectId::from_proto),
2005 )
2006 .await?;
2007 response.send(proto::ShareProjectResponse {
2008 project_id: project_id.to_proto(),
2009 })?;
2010 room_updated(room, &session.peer);
2011
2012 Ok(())
2013}
2014
2015/// Unshare a project from the room.
2016async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
2017 let project_id = ProjectId::from_proto(message.project_id);
2018 unshare_project_internal(
2019 project_id,
2020 session.connection_id,
2021 session.user_id(),
2022 &session,
2023 )
2024 .await
2025}
2026
2027async fn unshare_project_internal(
2028 project_id: ProjectId,
2029 connection_id: ConnectionId,
2030 user_id: Option<UserId>,
2031 session: &Session,
2032) -> Result<()> {
2033 let delete = {
2034 let room_guard = session
2035 .db()
2036 .await
2037 .unshare_project(project_id, connection_id, user_id)
2038 .await?;
2039
2040 let (delete, room, guest_connection_ids) = &*room_guard;
2041
2042 let message = proto::UnshareProject {
2043 project_id: project_id.to_proto(),
2044 };
2045
2046 broadcast(
2047 Some(connection_id),
2048 guest_connection_ids.iter().copied(),
2049 |conn_id| session.peer.send(conn_id, message.clone()),
2050 );
2051 if let Some(room) = room {
2052 room_updated(room, &session.peer);
2053 }
2054
2055 *delete
2056 };
2057
2058 if delete {
2059 let db = session.db().await;
2060 db.delete_project(project_id).await?;
2061 }
2062
2063 Ok(())
2064}
2065
2066/// DevServer makes a project available online
2067async fn share_dev_server_project(
2068 request: proto::ShareDevServerProject,
2069 response: Response<proto::ShareDevServerProject>,
2070 session: DevServerSession,
2071) -> Result<()> {
2072 let (dev_server_project, user_id, status) = session
2073 .db()
2074 .await
2075 .share_dev_server_project(
2076 DevServerProjectId::from_proto(request.dev_server_project_id),
2077 session.dev_server_id(),
2078 session.connection_id,
2079 &request.worktrees,
2080 )
2081 .await?;
2082 let Some(project_id) = dev_server_project.project_id else {
2083 return Err(anyhow!("failed to share remote project"))?;
2084 };
2085
2086 send_dev_server_projects_update(user_id, status, &session).await;
2087
2088 response.send(proto::ShareProjectResponse { project_id })?;
2089
2090 Ok(())
2091}
2092
/// Join someone else's shared project.
///
/// The caller's connection is registered as a collaborator through the database, and the
/// rest of the work (notifying collaborators and answering the request) is delegated to
/// `join_project_internal`.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: UserSession,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    // The database handle is released explicitly before handing over to the internal helper.
    drop(db);
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
2111
/// A response handle that can be completed with a `proto::JoinProjectResponse`.
///
/// It lets `join_project_internal` answer any request type that implements it.
trait JoinProjectInternalResponse {
    /// Consumes the handle and sends `result` as the response.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
/// Forwards to the inherent `send` of `Response<proto::JoinProject>`.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // The fully qualified path makes it explicit which `send` is being called.
        Response::<proto::JoinProject>::send(self, result)
    }
}
/// Forwards to the inherent `send` of `Response<proto::JoinHostedProject>`.
impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // The fully qualified path makes it explicit which `send` is being called.
        Response::<proto::JoinHostedProject>::send(self, result)
    }
}
2125
2126fn join_project_internal(
2127 response: impl JoinProjectInternalResponse,
2128 session: UserSession,
2129 project: &mut Project,
2130 replica_id: &ReplicaId,
2131) -> Result<()> {
2132 let collaborators = project
2133 .collaborators
2134 .iter()
2135 .filter(|collaborator| collaborator.connection_id != session.connection_id)
2136 .map(|collaborator| collaborator.to_proto())
2137 .collect::<Vec<_>>();
2138 let project_id = project.id;
2139 let guest_user_id = session.user_id();
2140
2141 let worktrees = project
2142 .worktrees
2143 .iter()
2144 .map(|(id, worktree)| proto::WorktreeMetadata {
2145 id: *id,
2146 root_name: worktree.root_name.clone(),
2147 visible: worktree.visible,
2148 abs_path: worktree.abs_path.clone(),
2149 })
2150 .collect::<Vec<_>>();
2151
2152 let add_project_collaborator = proto::AddProjectCollaborator {
2153 project_id: project_id.to_proto(),
2154 collaborator: Some(proto::Collaborator {
2155 peer_id: Some(session.connection_id.into()),
2156 replica_id: replica_id.0 as u32,
2157 user_id: guest_user_id.to_proto(),
2158 }),
2159 };
2160
2161 for collaborator in &collaborators {
2162 session
2163 .peer
2164 .send(
2165 collaborator.peer_id.unwrap().into(),
2166 add_project_collaborator.clone(),
2167 )
2168 .trace_err();
2169 }
2170
2171 // First, we send the metadata associated with each worktree.
2172 response.send(proto::JoinProjectResponse {
2173 project_id: project.id.0 as u64,
2174 worktrees: worktrees.clone(),
2175 replica_id: replica_id.0 as u32,
2176 collaborators: collaborators.clone(),
2177 language_servers: project.language_servers.clone(),
2178 role: project.role.into(),
2179 dev_server_project_id: project
2180 .dev_server_project_id
2181 .map(|dev_server_project_id| dev_server_project_id.0 as u64),
2182 })?;
2183
2184 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2185 #[cfg(any(test, feature = "test-support"))]
2186 const MAX_CHUNK_SIZE: usize = 2;
2187 #[cfg(not(any(test, feature = "test-support")))]
2188 const MAX_CHUNK_SIZE: usize = 256;
2189
2190 // Stream this worktree's entries.
2191 let message = proto::UpdateWorktree {
2192 project_id: project_id.to_proto(),
2193 worktree_id,
2194 abs_path: worktree.abs_path.clone(),
2195 root_name: worktree.root_name,
2196 updated_entries: worktree.entries,
2197 removed_entries: Default::default(),
2198 scan_id: worktree.scan_id,
2199 is_last_update: worktree.scan_id == worktree.completed_scan_id,
2200 updated_repositories: worktree.repository_entries.into_values().collect(),
2201 removed_repositories: Default::default(),
2202 };
2203 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
2204 session.peer.send(session.connection_id, update.clone())?;
2205 }
2206
2207 // Stream this worktree's diagnostics.
2208 for summary in worktree.diagnostic_summaries {
2209 session.peer.send(
2210 session.connection_id,
2211 proto::UpdateDiagnosticSummary {
2212 project_id: project_id.to_proto(),
2213 worktree_id: worktree.id,
2214 summary: Some(summary),
2215 },
2216 )?;
2217 }
2218
2219 for settings_file in worktree.settings_files {
2220 session.peer.send(
2221 session.connection_id,
2222 proto::UpdateWorktreeSettings {
2223 project_id: project_id.to_proto(),
2224 worktree_id: worktree.id,
2225 path: settings_file.path,
2226 content: Some(settings_file.content),
2227 },
2228 )?;
2229 }
2230 }
2231
2232 for language_server in &project.language_servers {
2233 session.peer.send(
2234 session.connection_id,
2235 proto::UpdateLanguageServer {
2236 project_id: project_id.to_proto(),
2237 language_server_id: language_server.id,
2238 variant: Some(
2239 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2240 proto::LspDiskBasedDiagnosticsUpdated {},
2241 ),
2242 ),
2243 },
2244 )?;
2245 }
2246
2247 Ok(())
2248}
2249
2250/// Leave someone elses shared project.
2251async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2252 let sender_id = session.connection_id;
2253 let project_id = ProjectId::from_proto(request.project_id);
2254 let db = session.db().await;
2255 if db.is_hosted_project(project_id).await? {
2256 let project = db.leave_hosted_project(project_id, sender_id).await?;
2257 project_left(&project, &session);
2258 return Ok(());
2259 }
2260
2261 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2262 tracing::info!(
2263 %project_id,
2264 "leave project"
2265 );
2266
2267 project_left(project, &session);
2268 if let Some(room) = room {
2269 room_updated(room, &session.peer);
2270 }
2271
2272 Ok(())
2273}
2274
2275async fn join_hosted_project(
2276 request: proto::JoinHostedProject,
2277 response: Response<proto::JoinHostedProject>,
2278 session: UserSession,
2279) -> Result<()> {
2280 let (mut project, replica_id) = session
2281 .db()
2282 .await
2283 .join_hosted_project(
2284 ProjectId(request.project_id as i32),
2285 session.user_id(),
2286 session.connection_id,
2287 )
2288 .await?;
2289
2290 join_project_internal(response, session, &mut project, &replica_id)
2291}
2292
2293async fn list_remote_directory(
2294 request: proto::ListRemoteDirectory,
2295 response: Response<proto::ListRemoteDirectory>,
2296 session: UserSession,
2297) -> Result<()> {
2298 let dev_server_id = DevServerId(request.dev_server_id as i32);
2299 let dev_server_connection_id = session
2300 .connection_pool()
2301 .await
2302 .dev_server_connection_id_supporting(dev_server_id, ZedVersion::with_list_directory())?;
2303
2304 session
2305 .db()
2306 .await
2307 .get_dev_server_for_user(dev_server_id, session.user_id())
2308 .await?;
2309
2310 response.send(
2311 session
2312 .peer
2313 .forward_request(session.connection_id, dev_server_connection_id, request)
2314 .await?,
2315 )?;
2316 Ok(())
2317}
2318
2319async fn update_dev_server_project(
2320 request: proto::UpdateDevServerProject,
2321 response: Response<proto::UpdateDevServerProject>,
2322 session: UserSession,
2323) -> Result<()> {
2324 let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2325
2326 let (dev_server_project, update) = session
2327 .db()
2328 .await
2329 .update_dev_server_project(dev_server_project_id, &request.paths, session.user_id())
2330 .await?;
2331
2332 let projects = session
2333 .db()
2334 .await
2335 .get_projects_for_dev_server(dev_server_project.dev_server_id)
2336 .await?;
2337
2338 let dev_server_connection_id = session
2339 .connection_pool()
2340 .await
2341 .dev_server_connection_id_supporting(
2342 dev_server_project.dev_server_id,
2343 ZedVersion::with_list_directory(),
2344 )?;
2345
2346 session.peer.send(
2347 dev_server_connection_id,
2348 proto::DevServerInstructions { projects },
2349 )?;
2350
2351 send_dev_server_projects_update(session.user_id(), update, &session).await;
2352
2353 response.send(proto::Ack {})
2354}
2355
2356async fn create_dev_server_project(
2357 request: proto::CreateDevServerProject,
2358 response: Response<proto::CreateDevServerProject>,
2359 session: UserSession,
2360) -> Result<()> {
2361 let dev_server_id = DevServerId(request.dev_server_id as i32);
2362 let dev_server_connection_id = session
2363 .connection_pool()
2364 .await
2365 .dev_server_connection_id(dev_server_id);
2366 let Some(dev_server_connection_id) = dev_server_connection_id else {
2367 Err(ErrorCode::DevServerOffline
2368 .message("Cannot create a remote project when the dev server is offline".to_string())
2369 .anyhow())?
2370 };
2371
2372 let path = request.path.clone();
2373 //Check that the path exists on the dev server
2374 session
2375 .peer
2376 .forward_request(
2377 session.connection_id,
2378 dev_server_connection_id,
2379 proto::ValidateDevServerProjectRequest { path: path.clone() },
2380 )
2381 .await?;
2382
2383 let (dev_server_project, update) = session
2384 .db()
2385 .await
2386 .create_dev_server_project(
2387 DevServerId(request.dev_server_id as i32),
2388 &request.path,
2389 session.user_id(),
2390 )
2391 .await?;
2392
2393 let projects = session
2394 .db()
2395 .await
2396 .get_projects_for_dev_server(dev_server_project.dev_server_id)
2397 .await?;
2398
2399 session.peer.send(
2400 dev_server_connection_id,
2401 proto::DevServerInstructions { projects },
2402 )?;
2403
2404 send_dev_server_projects_update(session.user_id(), update, &session).await;
2405
2406 response.send(proto::CreateDevServerProjectResponse {
2407 dev_server_project: Some(dev_server_project.to_proto(None)),
2408 })?;
2409 Ok(())
2410}
2411
2412async fn create_dev_server(
2413 request: proto::CreateDevServer,
2414 response: Response<proto::CreateDevServer>,
2415 session: UserSession,
2416) -> Result<()> {
2417 let access_token = auth::random_token();
2418 let hashed_access_token = auth::hash_access_token(&access_token);
2419
2420 if request.name.is_empty() {
2421 return Err(proto::ErrorCode::Forbidden
2422 .message("Dev server name cannot be empty".to_string())
2423 .anyhow())?;
2424 }
2425
2426 let (dev_server, status) = session
2427 .db()
2428 .await
2429 .create_dev_server(
2430 &request.name,
2431 request.ssh_connection_string.as_deref(),
2432 &hashed_access_token,
2433 session.user_id(),
2434 )
2435 .await?;
2436
2437 send_dev_server_projects_update(session.user_id(), status, &session).await;
2438
2439 response.send(proto::CreateDevServerResponse {
2440 dev_server_id: dev_server.id.0 as u64,
2441 access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2442 name: request.name,
2443 })?;
2444 Ok(())
2445}
2446
2447async fn regenerate_dev_server_token(
2448 request: proto::RegenerateDevServerToken,
2449 response: Response<proto::RegenerateDevServerToken>,
2450 session: UserSession,
2451) -> Result<()> {
2452 let dev_server_id = DevServerId(request.dev_server_id as i32);
2453 let access_token = auth::random_token();
2454 let hashed_access_token = auth::hash_access_token(&access_token);
2455
2456 let connection_id = session
2457 .connection_pool()
2458 .await
2459 .dev_server_connection_id(dev_server_id);
2460 if let Some(connection_id) = connection_id {
2461 shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2462 session.peer.send(
2463 connection_id,
2464 proto::ShutdownDevServer {
2465 reason: Some("dev server token was regenerated".to_string()),
2466 },
2467 )?;
2468 let _ = remove_dev_server_connection(dev_server_id, &session).await;
2469 }
2470
2471 let status = session
2472 .db()
2473 .await
2474 .update_dev_server_token(dev_server_id, &hashed_access_token, session.user_id())
2475 .await?;
2476
2477 send_dev_server_projects_update(session.user_id(), status, &session).await;
2478
2479 response.send(proto::RegenerateDevServerTokenResponse {
2480 dev_server_id: dev_server_id.to_proto(),
2481 access_token: auth::generate_dev_server_token(dev_server_id.0 as usize, access_token),
2482 })?;
2483 Ok(())
2484}
2485
2486async fn rename_dev_server(
2487 request: proto::RenameDevServer,
2488 response: Response<proto::RenameDevServer>,
2489 session: UserSession,
2490) -> Result<()> {
2491 if request.name.trim().is_empty() {
2492 return Err(proto::ErrorCode::Forbidden
2493 .message("Dev server name cannot be empty".to_string())
2494 .anyhow())?;
2495 }
2496
2497 let dev_server_id = DevServerId(request.dev_server_id as i32);
2498 let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2499 if dev_server.user_id != session.user_id() {
2500 return Err(anyhow!(ErrorCode::Forbidden))?;
2501 }
2502
2503 let status = session
2504 .db()
2505 .await
2506 .rename_dev_server(
2507 dev_server_id,
2508 &request.name,
2509 request.ssh_connection_string.as_deref(),
2510 session.user_id(),
2511 )
2512 .await?;
2513
2514 send_dev_server_projects_update(session.user_id(), status, &session).await;
2515
2516 response.send(proto::Ack {})?;
2517 Ok(())
2518}
2519
2520async fn delete_dev_server(
2521 request: proto::DeleteDevServer,
2522 response: Response<proto::DeleteDevServer>,
2523 session: UserSession,
2524) -> Result<()> {
2525 let dev_server_id = DevServerId(request.dev_server_id as i32);
2526 let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2527 if dev_server.user_id != session.user_id() {
2528 return Err(anyhow!(ErrorCode::Forbidden))?;
2529 }
2530
2531 let connection_id = session
2532 .connection_pool()
2533 .await
2534 .dev_server_connection_id(dev_server_id);
2535 if let Some(connection_id) = connection_id {
2536 shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2537 session.peer.send(
2538 connection_id,
2539 proto::ShutdownDevServer {
2540 reason: Some("dev server was deleted".to_string()),
2541 },
2542 )?;
2543 let _ = remove_dev_server_connection(dev_server_id, &session).await;
2544 }
2545
2546 let status = session
2547 .db()
2548 .await
2549 .delete_dev_server(dev_server_id, session.user_id())
2550 .await?;
2551
2552 send_dev_server_projects_update(session.user_id(), status, &session).await;
2553
2554 response.send(proto::Ack {})?;
2555 Ok(())
2556}
2557
2558async fn delete_dev_server_project(
2559 request: proto::DeleteDevServerProject,
2560 response: Response<proto::DeleteDevServerProject>,
2561 session: UserSession,
2562) -> Result<()> {
2563 let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2564 let dev_server_project = session
2565 .db()
2566 .await
2567 .get_dev_server_project(dev_server_project_id)
2568 .await?;
2569
2570 let dev_server = session
2571 .db()
2572 .await
2573 .get_dev_server(dev_server_project.dev_server_id)
2574 .await?;
2575 if dev_server.user_id != session.user_id() {
2576 return Err(anyhow!(ErrorCode::Forbidden))?;
2577 }
2578
2579 let dev_server_connection_id = session
2580 .connection_pool()
2581 .await
2582 .dev_server_connection_id(dev_server.id);
2583
2584 if let Some(dev_server_connection_id) = dev_server_connection_id {
2585 let project = session
2586 .db()
2587 .await
2588 .find_dev_server_project(dev_server_project_id)
2589 .await;
2590 if let Ok(project) = project {
2591 unshare_project_internal(
2592 project.id,
2593 dev_server_connection_id,
2594 Some(session.user_id()),
2595 &session,
2596 )
2597 .await?;
2598 }
2599 }
2600
2601 let (projects, status) = session
2602 .db()
2603 .await
2604 .delete_dev_server_project(dev_server_project_id, dev_server.id, session.user_id())
2605 .await?;
2606
2607 if let Some(dev_server_connection_id) = dev_server_connection_id {
2608 session.peer.send(
2609 dev_server_connection_id,
2610 proto::DevServerInstructions { projects },
2611 )?;
2612 }
2613
2614 send_dev_server_projects_update(session.user_id(), status, &session).await;
2615
2616 response.send(proto::Ack {})?;
2617 Ok(())
2618}
2619
2620async fn rejoin_dev_server_projects(
2621 request: proto::RejoinRemoteProjects,
2622 response: Response<proto::RejoinRemoteProjects>,
2623 session: UserSession,
2624) -> Result<()> {
2625 let mut rejoined_projects = {
2626 let db = session.db().await;
2627 db.rejoin_dev_server_projects(
2628 &request.rejoined_projects,
2629 session.user_id(),
2630 session.0.connection_id,
2631 )
2632 .await?
2633 };
2634 response.send(proto::RejoinRemoteProjectsResponse {
2635 rejoined_projects: rejoined_projects
2636 .iter()
2637 .map(|project| project.to_proto())
2638 .collect(),
2639 })?;
2640 notify_rejoined_projects(&mut rejoined_projects, &session)
2641}
2642
2643async fn reconnect_dev_server(
2644 request: proto::ReconnectDevServer,
2645 response: Response<proto::ReconnectDevServer>,
2646 session: DevServerSession,
2647) -> Result<()> {
2648 let reshared_projects = {
2649 let db = session.db().await;
2650 db.reshare_dev_server_projects(
2651 &request.reshared_projects,
2652 session.dev_server_id(),
2653 session.0.connection_id,
2654 )
2655 .await?
2656 };
2657
2658 for project in &reshared_projects {
2659 for collaborator in &project.collaborators {
2660 session
2661 .peer
2662 .send(
2663 collaborator.connection_id,
2664 proto::UpdateProjectCollaborator {
2665 project_id: project.id.to_proto(),
2666 old_peer_id: Some(project.old_connection_id.into()),
2667 new_peer_id: Some(session.connection_id.into()),
2668 },
2669 )
2670 .trace_err();
2671 }
2672
2673 broadcast(
2674 Some(session.connection_id),
2675 project
2676 .collaborators
2677 .iter()
2678 .map(|collaborator| collaborator.connection_id),
2679 |connection_id| {
2680 session.peer.forward_send(
2681 session.connection_id,
2682 connection_id,
2683 proto::UpdateProject {
2684 project_id: project.id.to_proto(),
2685 worktrees: project.worktrees.clone(),
2686 },
2687 )
2688 },
2689 );
2690 }
2691
2692 response.send(proto::ReconnectDevServerResponse {
2693 reshared_projects: reshared_projects
2694 .iter()
2695 .map(|project| proto::ResharedProject {
2696 id: project.id.to_proto(),
2697 collaborators: project
2698 .collaborators
2699 .iter()
2700 .map(|collaborator| collaborator.to_proto())
2701 .collect(),
2702 })
2703 .collect(),
2704 })?;
2705
2706 Ok(())
2707}
2708
2709async fn shutdown_dev_server(
2710 _: proto::ShutdownDevServer,
2711 response: Response<proto::ShutdownDevServer>,
2712 session: DevServerSession,
2713) -> Result<()> {
2714 response.send(proto::Ack {})?;
2715 shutdown_dev_server_internal(session.dev_server_id(), session.connection_id, &session).await?;
2716 remove_dev_server_connection(session.dev_server_id(), &session).await
2717}
2718
2719async fn shutdown_dev_server_internal(
2720 dev_server_id: DevServerId,
2721 connection_id: ConnectionId,
2722 session: &Session,
2723) -> Result<()> {
2724 let (dev_server_projects, dev_server) = {
2725 let db = session.db().await;
2726 let dev_server_projects = db.get_projects_for_dev_server(dev_server_id).await?;
2727 let dev_server = db.get_dev_server(dev_server_id).await?;
2728 (dev_server_projects, dev_server)
2729 };
2730
2731 for project_id in dev_server_projects.iter().filter_map(|p| p.project_id) {
2732 unshare_project_internal(
2733 ProjectId::from_proto(project_id),
2734 connection_id,
2735 None,
2736 session,
2737 )
2738 .await?;
2739 }
2740
2741 session
2742 .connection_pool()
2743 .await
2744 .set_dev_server_offline(dev_server_id);
2745
2746 let status = session
2747 .db()
2748 .await
2749 .dev_server_projects_update(dev_server.user_id)
2750 .await?;
2751 send_dev_server_projects_update(dev_server.user_id, status, session).await;
2752
2753 Ok(())
2754}
2755
2756async fn remove_dev_server_connection(dev_server_id: DevServerId, session: &Session) -> Result<()> {
2757 let dev_server_connection = session
2758 .connection_pool()
2759 .await
2760 .dev_server_connection_id(dev_server_id);
2761
2762 if let Some(dev_server_connection) = dev_server_connection {
2763 session
2764 .connection_pool()
2765 .await
2766 .remove_connection(dev_server_connection)?;
2767 }
2768 Ok(())
2769}
2770
2771/// Updates other participants with changes to the project
2772async fn update_project(
2773 request: proto::UpdateProject,
2774 response: Response<proto::UpdateProject>,
2775 session: Session,
2776) -> Result<()> {
2777 let project_id = ProjectId::from_proto(request.project_id);
2778 let (room, guest_connection_ids) = &*session
2779 .db()
2780 .await
2781 .update_project(project_id, session.connection_id, &request.worktrees)
2782 .await?;
2783 broadcast(
2784 Some(session.connection_id),
2785 guest_connection_ids.iter().copied(),
2786 |connection_id| {
2787 session
2788 .peer
2789 .forward_send(session.connection_id, connection_id, request.clone())
2790 },
2791 );
2792 if let Some(room) = room {
2793 room_updated(room, &session.peer);
2794 }
2795 response.send(proto::Ack {})?;
2796
2797 Ok(())
2798}
2799
2800/// Updates other participants with changes to the worktree
2801async fn update_worktree(
2802 request: proto::UpdateWorktree,
2803 response: Response<proto::UpdateWorktree>,
2804 session: Session,
2805) -> Result<()> {
2806 let guest_connection_ids = session
2807 .db()
2808 .await
2809 .update_worktree(&request, session.connection_id)
2810 .await?;
2811
2812 broadcast(
2813 Some(session.connection_id),
2814 guest_connection_ids.iter().copied(),
2815 |connection_id| {
2816 session
2817 .peer
2818 .forward_send(session.connection_id, connection_id, request.clone())
2819 },
2820 );
2821 response.send(proto::Ack {})?;
2822 Ok(())
2823}
2824
2825/// Updates other participants with changes to the diagnostics
2826async fn update_diagnostic_summary(
2827 message: proto::UpdateDiagnosticSummary,
2828 session: Session,
2829) -> Result<()> {
2830 let guest_connection_ids = session
2831 .db()
2832 .await
2833 .update_diagnostic_summary(&message, session.connection_id)
2834 .await?;
2835
2836 broadcast(
2837 Some(session.connection_id),
2838 guest_connection_ids.iter().copied(),
2839 |connection_id| {
2840 session
2841 .peer
2842 .forward_send(session.connection_id, connection_id, message.clone())
2843 },
2844 );
2845
2846 Ok(())
2847}
2848
2849/// Updates other participants with changes to the worktree settings
2850async fn update_worktree_settings(
2851 message: proto::UpdateWorktreeSettings,
2852 session: Session,
2853) -> Result<()> {
2854 let guest_connection_ids = session
2855 .db()
2856 .await
2857 .update_worktree_settings(&message, session.connection_id)
2858 .await?;
2859
2860 broadcast(
2861 Some(session.connection_id),
2862 guest_connection_ids.iter().copied(),
2863 |connection_id| {
2864 session
2865 .peer
2866 .forward_send(session.connection_id, connection_id, message.clone())
2867 },
2868 );
2869
2870 Ok(())
2871}
2872
2873/// Notify other participants that a language server has started.
2874async fn start_language_server(
2875 request: proto::StartLanguageServer,
2876 session: Session,
2877) -> Result<()> {
2878 let guest_connection_ids = session
2879 .db()
2880 .await
2881 .start_language_server(&request, session.connection_id)
2882 .await?;
2883
2884 broadcast(
2885 Some(session.connection_id),
2886 guest_connection_ids.iter().copied(),
2887 |connection_id| {
2888 session
2889 .peer
2890 .forward_send(session.connection_id, connection_id, request.clone())
2891 },
2892 );
2893 Ok(())
2894}
2895
2896/// Notify other participants that a language server has changed.
2897async fn update_language_server(
2898 request: proto::UpdateLanguageServer,
2899 session: Session,
2900) -> Result<()> {
2901 let project_id = ProjectId::from_proto(request.project_id);
2902 let project_connection_ids = session
2903 .db()
2904 .await
2905 .project_connection_ids(project_id, session.connection_id, true)
2906 .await?;
2907 broadcast(
2908 Some(session.connection_id),
2909 project_connection_ids.iter().copied(),
2910 |connection_id| {
2911 session
2912 .peer
2913 .forward_send(session.connection_id, connection_id, request.clone())
2914 },
2915 );
2916 Ok(())
2917}
2918
2919/// forward a project request to the host. These requests should be read only
2920/// as guests are allowed to send them.
2921async fn forward_read_only_project_request<T>(
2922 request: T,
2923 response: Response<T>,
2924 session: UserSession,
2925) -> Result<()>
2926where
2927 T: EntityMessage + RequestMessage,
2928{
2929 let project_id = ProjectId::from_proto(request.remote_entity_id());
2930 let host_connection_id = session
2931 .db()
2932 .await
2933 .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2934 .await?;
2935 let payload = session
2936 .peer
2937 .forward_request(session.connection_id, host_connection_id, request)
2938 .await?;
2939 response.send(payload)?;
2940 Ok(())
2941}
2942
2943async fn forward_find_search_candidates_request(
2944 request: proto::FindSearchCandidates,
2945 response: Response<proto::FindSearchCandidates>,
2946 session: UserSession,
2947) -> Result<()> {
2948 let project_id = ProjectId::from_proto(request.remote_entity_id());
2949 let host_connection_id = session
2950 .db()
2951 .await
2952 .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2953 .await?;
2954
2955 let host_version = session
2956 .connection_pool()
2957 .await
2958 .connection(host_connection_id)
2959 .map(|c| c.zed_version);
2960
2961 if host_version.is_some_and(|host_version| host_version < ZedVersion::with_search_candidates())
2962 {
2963 let query = request.query.ok_or_else(|| anyhow!("missing query"))?;
2964 let search = proto::SearchProject {
2965 project_id: project_id.to_proto(),
2966 query: query.query,
2967 regex: query.regex,
2968 whole_word: query.whole_word,
2969 case_sensitive: query.case_sensitive,
2970 files_to_include: query.files_to_include,
2971 files_to_exclude: query.files_to_exclude,
2972 include_ignored: query.include_ignored,
2973 };
2974
2975 let payload = session
2976 .peer
2977 .forward_request(session.connection_id, host_connection_id, search)
2978 .await?;
2979 return response.send(proto::FindSearchCandidatesResponse {
2980 buffer_ids: payload
2981 .locations
2982 .into_iter()
2983 .map(|loc| loc.buffer_id)
2984 .collect(),
2985 });
2986 }
2987
2988 let payload = session
2989 .peer
2990 .forward_request(session.connection_id, host_connection_id, request)
2991 .await?;
2992 response.send(payload)?;
2993 Ok(())
2994}
2995
2996/// forward a project request to the dev server. Only allowed
2997/// if it's your dev server.
2998async fn forward_project_request_for_owner<T>(
2999 request: T,
3000 response: Response<T>,
3001 session: UserSession,
3002) -> Result<()>
3003where
3004 T: EntityMessage + RequestMessage,
3005{
3006 let project_id = ProjectId::from_proto(request.remote_entity_id());
3007
3008 let host_connection_id = session
3009 .db()
3010 .await
3011 .host_for_owner_project_request(project_id, session.connection_id, session.user_id())
3012 .await?;
3013 let payload = session
3014 .peer
3015 .forward_request(session.connection_id, host_connection_id, request)
3016 .await?;
3017 response.send(payload)?;
3018 Ok(())
3019}
3020
3021/// forward a project request to the host. These requests are disallowed
3022/// for guests.
3023async fn forward_mutating_project_request<T>(
3024 request: T,
3025 response: Response<T>,
3026 session: UserSession,
3027) -> Result<()>
3028where
3029 T: EntityMessage + RequestMessage,
3030{
3031 let project_id = ProjectId::from_proto(request.remote_entity_id());
3032
3033 let host_connection_id = session
3034 .db()
3035 .await
3036 .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
3037 .await?;
3038 let payload = session
3039 .peer
3040 .forward_request(session.connection_id, host_connection_id, request)
3041 .await?;
3042 response.send(payload)?;
3043 Ok(())
3044}
3045
3046/// Notify other participants that a new buffer has been created
3047async fn create_buffer_for_peer(
3048 request: proto::CreateBufferForPeer,
3049 session: Session,
3050) -> Result<()> {
3051 session
3052 .db()
3053 .await
3054 .check_user_is_project_host(
3055 ProjectId::from_proto(request.project_id),
3056 session.connection_id,
3057 )
3058 .await?;
3059 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
3060 session
3061 .peer
3062 .forward_send(session.connection_id, peer_id.into(), request)?;
3063 Ok(())
3064}
3065
3066/// Notify other participants that a buffer has been updated. This is
3067/// allowed for guests as long as the update is limited to selections.
3068async fn update_buffer(
3069 request: proto::UpdateBuffer,
3070 response: Response<proto::UpdateBuffer>,
3071 session: Session,
3072) -> Result<()> {
3073 let project_id = ProjectId::from_proto(request.project_id);
3074 let mut capability = Capability::ReadOnly;
3075
3076 for op in request.operations.iter() {
3077 match op.variant {
3078 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
3079 Some(_) => capability = Capability::ReadWrite,
3080 }
3081 }
3082
3083 let host = {
3084 let guard = session
3085 .db()
3086 .await
3087 .connections_for_buffer_update(
3088 project_id,
3089 session.principal_id(),
3090 session.connection_id,
3091 capability,
3092 )
3093 .await?;
3094
3095 let (host, guests) = &*guard;
3096
3097 broadcast(
3098 Some(session.connection_id),
3099 guests.clone(),
3100 |connection_id| {
3101 session
3102 .peer
3103 .forward_send(session.connection_id, connection_id, request.clone())
3104 },
3105 );
3106
3107 *host
3108 };
3109
3110 if host != session.connection_id {
3111 session
3112 .peer
3113 .forward_request(session.connection_id, host, request.clone())
3114 .await?;
3115 }
3116
3117 response.send(proto::Ack {})?;
3118 Ok(())
3119}
3120
3121async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
3122 let project_id = ProjectId::from_proto(message.project_id);
3123
3124 let operation = message.operation.as_ref().context("invalid operation")?;
3125 let capability = match operation.variant.as_ref() {
3126 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
3127 if let Some(buffer_op) = buffer_op.operation.as_ref() {
3128 match buffer_op.variant {
3129 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
3130 Capability::ReadOnly
3131 }
3132 _ => Capability::ReadWrite,
3133 }
3134 } else {
3135 Capability::ReadWrite
3136 }
3137 }
3138 Some(_) => Capability::ReadWrite,
3139 None => Capability::ReadOnly,
3140 };
3141
3142 let guard = session
3143 .db()
3144 .await
3145 .connections_for_buffer_update(
3146 project_id,
3147 session.principal_id(),
3148 session.connection_id,
3149 capability,
3150 )
3151 .await?;
3152
3153 let (host, guests) = &*guard;
3154
3155 broadcast(
3156 Some(session.connection_id),
3157 guests.iter().chain([host]).copied(),
3158 |connection_id| {
3159 session
3160 .peer
3161 .forward_send(session.connection_id, connection_id, message.clone())
3162 },
3163 );
3164
3165 Ok(())
3166}
3167
3168/// Notify other participants that a project has been updated.
3169async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
3170 request: T,
3171 session: Session,
3172) -> Result<()> {
3173 let project_id = ProjectId::from_proto(request.remote_entity_id());
3174 let project_connection_ids = session
3175 .db()
3176 .await
3177 .project_connection_ids(project_id, session.connection_id, false)
3178 .await?;
3179
3180 broadcast(
3181 Some(session.connection_id),
3182 project_connection_ids.iter().copied(),
3183 |connection_id| {
3184 session
3185 .peer
3186 .forward_send(session.connection_id, connection_id, request.clone())
3187 },
3188 );
3189 Ok(())
3190}
3191
3192/// Start following another user in a call.
3193async fn follow(
3194 request: proto::Follow,
3195 response: Response<proto::Follow>,
3196 session: UserSession,
3197) -> Result<()> {
3198 let room_id = RoomId::from_proto(request.room_id);
3199 let project_id = request.project_id.map(ProjectId::from_proto);
3200 let leader_id = request
3201 .leader_id
3202 .ok_or_else(|| anyhow!("invalid leader id"))?
3203 .into();
3204 let follower_id = session.connection_id;
3205
3206 session
3207 .db()
3208 .await
3209 .check_room_participants(room_id, leader_id, session.connection_id)
3210 .await?;
3211
3212 let response_payload = session
3213 .peer
3214 .forward_request(session.connection_id, leader_id, request)
3215 .await?;
3216 response.send(response_payload)?;
3217
3218 if let Some(project_id) = project_id {
3219 let room = session
3220 .db()
3221 .await
3222 .follow(room_id, project_id, leader_id, follower_id)
3223 .await?;
3224 room_updated(&room, &session.peer);
3225 }
3226
3227 Ok(())
3228}
3229
3230/// Stop following another user in a call.
3231async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
3232 let room_id = RoomId::from_proto(request.room_id);
3233 let project_id = request.project_id.map(ProjectId::from_proto);
3234 let leader_id = request
3235 .leader_id
3236 .ok_or_else(|| anyhow!("invalid leader id"))?
3237 .into();
3238 let follower_id = session.connection_id;
3239
3240 session
3241 .db()
3242 .await
3243 .check_room_participants(room_id, leader_id, session.connection_id)
3244 .await?;
3245
3246 session
3247 .peer
3248 .forward_send(session.connection_id, leader_id, request)?;
3249
3250 if let Some(project_id) = project_id {
3251 let room = session
3252 .db()
3253 .await
3254 .unfollow(room_id, project_id, leader_id, follower_id)
3255 .await?;
3256 room_updated(&room, &session.peer);
3257 }
3258
3259 Ok(())
3260}
3261
3262/// Notify everyone following you of your current location.
3263async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
3264 let room_id = RoomId::from_proto(request.room_id);
3265 let database = session.db.lock().await;
3266
3267 let connection_ids = if let Some(project_id) = request.project_id {
3268 let project_id = ProjectId::from_proto(project_id);
3269 database
3270 .project_connection_ids(project_id, session.connection_id, true)
3271 .await?
3272 } else {
3273 database
3274 .room_connection_ids(room_id, session.connection_id)
3275 .await?
3276 };
3277
3278 // For now, don't send view update messages back to that view's current leader.
3279 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
3280 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
3281 _ => None,
3282 });
3283
3284 for connection_id in connection_ids.iter().cloned() {
3285 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
3286 session
3287 .peer
3288 .forward_send(session.connection_id, connection_id, request.clone())?;
3289 }
3290 }
3291 Ok(())
3292}
3293
3294/// Get public data about users.
3295async fn get_users(
3296 request: proto::GetUsers,
3297 response: Response<proto::GetUsers>,
3298 session: Session,
3299) -> Result<()> {
3300 let user_ids = request
3301 .user_ids
3302 .into_iter()
3303 .map(UserId::from_proto)
3304 .collect();
3305 let users = session
3306 .db()
3307 .await
3308 .get_users_by_ids(user_ids)
3309 .await?
3310 .into_iter()
3311 .map(|user| proto::User {
3312 id: user.id.to_proto(),
3313 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3314 github_login: user.github_login,
3315 })
3316 .collect();
3317 response.send(proto::UsersResponse { users })?;
3318 Ok(())
3319}
3320
3321/// Search for users (to invite) buy Github login
3322async fn fuzzy_search_users(
3323 request: proto::FuzzySearchUsers,
3324 response: Response<proto::FuzzySearchUsers>,
3325 session: UserSession,
3326) -> Result<()> {
3327 let query = request.query;
3328 let users = match query.len() {
3329 0 => vec![],
3330 1 | 2 => session
3331 .db()
3332 .await
3333 .get_user_by_github_login(&query)
3334 .await?
3335 .into_iter()
3336 .collect(),
3337 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
3338 };
3339 let users = users
3340 .into_iter()
3341 .filter(|user| user.id != session.user_id())
3342 .map(|user| proto::User {
3343 id: user.id.to_proto(),
3344 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3345 github_login: user.github_login,
3346 })
3347 .collect();
3348 response.send(proto::UsersResponse { users })?;
3349 Ok(())
3350}
3351
3352/// Send a contact request to another user.
3353async fn request_contact(
3354 request: proto::RequestContact,
3355 response: Response<proto::RequestContact>,
3356 session: UserSession,
3357) -> Result<()> {
3358 let requester_id = session.user_id();
3359 let responder_id = UserId::from_proto(request.responder_id);
3360 if requester_id == responder_id {
3361 return Err(anyhow!("cannot add yourself as a contact"))?;
3362 }
3363
3364 let notifications = session
3365 .db()
3366 .await
3367 .send_contact_request(requester_id, responder_id)
3368 .await?;
3369
3370 // Update outgoing contact requests of requester
3371 let mut update = proto::UpdateContacts::default();
3372 update.outgoing_requests.push(responder_id.to_proto());
3373 for connection_id in session
3374 .connection_pool()
3375 .await
3376 .user_connection_ids(requester_id)
3377 {
3378 session.peer.send(connection_id, update.clone())?;
3379 }
3380
3381 // Update incoming contact requests of responder
3382 let mut update = proto::UpdateContacts::default();
3383 update
3384 .incoming_requests
3385 .push(proto::IncomingContactRequest {
3386 requester_id: requester_id.to_proto(),
3387 });
3388 let connection_pool = session.connection_pool().await;
3389 for connection_id in connection_pool.user_connection_ids(responder_id) {
3390 session.peer.send(connection_id, update.clone())?;
3391 }
3392
3393 send_notifications(&connection_pool, &session.peer, notifications);
3394
3395 response.send(proto::Ack {})?;
3396 Ok(())
3397}
3398
3399/// Accept or decline a contact request
3400async fn respond_to_contact_request(
3401 request: proto::RespondToContactRequest,
3402 response: Response<proto::RespondToContactRequest>,
3403 session: UserSession,
3404) -> Result<()> {
3405 let responder_id = session.user_id();
3406 let requester_id = UserId::from_proto(request.requester_id);
3407 let db = session.db().await;
3408 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
3409 db.dismiss_contact_notification(responder_id, requester_id)
3410 .await?;
3411 } else {
3412 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
3413
3414 let notifications = db
3415 .respond_to_contact_request(responder_id, requester_id, accept)
3416 .await?;
3417 let requester_busy = db.is_user_busy(requester_id).await?;
3418 let responder_busy = db.is_user_busy(responder_id).await?;
3419
3420 let pool = session.connection_pool().await;
3421 // Update responder with new contact
3422 let mut update = proto::UpdateContacts::default();
3423 if accept {
3424 update
3425 .contacts
3426 .push(contact_for_user(requester_id, requester_busy, &pool));
3427 }
3428 update
3429 .remove_incoming_requests
3430 .push(requester_id.to_proto());
3431 for connection_id in pool.user_connection_ids(responder_id) {
3432 session.peer.send(connection_id, update.clone())?;
3433 }
3434
3435 // Update requester with new contact
3436 let mut update = proto::UpdateContacts::default();
3437 if accept {
3438 update
3439 .contacts
3440 .push(contact_for_user(responder_id, responder_busy, &pool));
3441 }
3442 update
3443 .remove_outgoing_requests
3444 .push(responder_id.to_proto());
3445
3446 for connection_id in pool.user_connection_ids(requester_id) {
3447 session.peer.send(connection_id, update.clone())?;
3448 }
3449
3450 send_notifications(&pool, &session.peer, notifications);
3451 }
3452
3453 response.send(proto::Ack {})?;
3454 Ok(())
3455}
3456
3457/// Remove a contact.
3458async fn remove_contact(
3459 request: proto::RemoveContact,
3460 response: Response<proto::RemoveContact>,
3461 session: UserSession,
3462) -> Result<()> {
3463 let requester_id = session.user_id();
3464 let responder_id = UserId::from_proto(request.user_id);
3465 let db = session.db().await;
3466 let (contact_accepted, deleted_notification_id) =
3467 db.remove_contact(requester_id, responder_id).await?;
3468
3469 let pool = session.connection_pool().await;
3470 // Update outgoing contact requests of requester
3471 let mut update = proto::UpdateContacts::default();
3472 if contact_accepted {
3473 update.remove_contacts.push(responder_id.to_proto());
3474 } else {
3475 update
3476 .remove_outgoing_requests
3477 .push(responder_id.to_proto());
3478 }
3479 for connection_id in pool.user_connection_ids(requester_id) {
3480 session.peer.send(connection_id, update.clone())?;
3481 }
3482
3483 // Update incoming contact requests of responder
3484 let mut update = proto::UpdateContacts::default();
3485 if contact_accepted {
3486 update.remove_contacts.push(requester_id.to_proto());
3487 } else {
3488 update
3489 .remove_incoming_requests
3490 .push(requester_id.to_proto());
3491 }
3492 for connection_id in pool.user_connection_ids(responder_id) {
3493 session.peer.send(connection_id, update.clone())?;
3494 if let Some(notification_id) = deleted_notification_id {
3495 session.peer.send(
3496 connection_id,
3497 proto::DeleteNotification {
3498 notification_id: notification_id.to_proto(),
3499 },
3500 )?;
3501 }
3502 }
3503
3504 response.send(proto::Ack {})?;
3505 Ok(())
3506}
3507
3508fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
3509 version.0.minor() < 139
3510}
3511
3512async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
3513 let plan = session.current_plan(session.db().await).await?;
3514
3515 session
3516 .peer
3517 .send(
3518 session.connection_id,
3519 proto::UpdateUserPlan { plan: plan.into() },
3520 )
3521 .trace_err();
3522
3523 Ok(())
3524}
3525
3526async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
3527 subscribe_user_to_channels(
3528 session.user_id().ok_or_else(|| anyhow!("must be a user"))?,
3529 &session,
3530 )
3531 .await?;
3532 Ok(())
3533}
3534
3535async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
3536 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3537 let mut pool = session.connection_pool().await;
3538 for membership in &channels_for_user.channel_memberships {
3539 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3540 }
3541 session.peer.send(
3542 session.connection_id,
3543 build_update_user_channels(&channels_for_user),
3544 )?;
3545 session.peer.send(
3546 session.connection_id,
3547 build_channels_update(channels_for_user),
3548 )?;
3549 Ok(())
3550}
3551
3552/// Creates a new channel.
3553async fn create_channel(
3554 request: proto::CreateChannel,
3555 response: Response<proto::CreateChannel>,
3556 session: UserSession,
3557) -> Result<()> {
3558 let db = session.db().await;
3559
3560 let parent_id = request.parent_id.map(ChannelId::from_proto);
3561 let (channel, membership) = db
3562 .create_channel(&request.name, parent_id, session.user_id())
3563 .await?;
3564
3565 let root_id = channel.root_id();
3566 let channel = Channel::from_model(channel);
3567
3568 response.send(proto::CreateChannelResponse {
3569 channel: Some(channel.to_proto()),
3570 parent_id: request.parent_id,
3571 })?;
3572
3573 let mut connection_pool = session.connection_pool().await;
3574 if let Some(membership) = membership {
3575 connection_pool.subscribe_to_channel(
3576 membership.user_id,
3577 membership.channel_id,
3578 membership.role,
3579 );
3580 let update = proto::UpdateUserChannels {
3581 channel_memberships: vec![proto::ChannelMembership {
3582 channel_id: membership.channel_id.to_proto(),
3583 role: membership.role.into(),
3584 }],
3585 ..Default::default()
3586 };
3587 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3588 session.peer.send(connection_id, update.clone())?;
3589 }
3590 }
3591
3592 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3593 if !role.can_see_channel(channel.visibility) {
3594 continue;
3595 }
3596
3597 let update = proto::UpdateChannels {
3598 channels: vec![channel.to_proto()],
3599 ..Default::default()
3600 };
3601 session.peer.send(connection_id, update.clone())?;
3602 }
3603
3604 Ok(())
3605}
3606
3607/// Delete a channel
3608async fn delete_channel(
3609 request: proto::DeleteChannel,
3610 response: Response<proto::DeleteChannel>,
3611 session: UserSession,
3612) -> Result<()> {
3613 let db = session.db().await;
3614
3615 let channel_id = request.channel_id;
3616 let (root_channel, removed_channels) = db
3617 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3618 .await?;
3619 response.send(proto::Ack {})?;
3620
3621 // Notify members of removed channels
3622 let mut update = proto::UpdateChannels::default();
3623 update
3624 .delete_channels
3625 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3626
3627 let connection_pool = session.connection_pool().await;
3628 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3629 session.peer.send(connection_id, update.clone())?;
3630 }
3631
3632 Ok(())
3633}
3634
3635/// Invite someone to join a channel.
3636async fn invite_channel_member(
3637 request: proto::InviteChannelMember,
3638 response: Response<proto::InviteChannelMember>,
3639 session: UserSession,
3640) -> Result<()> {
3641 let db = session.db().await;
3642 let channel_id = ChannelId::from_proto(request.channel_id);
3643 let invitee_id = UserId::from_proto(request.user_id);
3644 let InviteMemberResult {
3645 channel,
3646 notifications,
3647 } = db
3648 .invite_channel_member(
3649 channel_id,
3650 invitee_id,
3651 session.user_id(),
3652 request.role().into(),
3653 )
3654 .await?;
3655
3656 let update = proto::UpdateChannels {
3657 channel_invitations: vec![channel.to_proto()],
3658 ..Default::default()
3659 };
3660
3661 let connection_pool = session.connection_pool().await;
3662 for connection_id in connection_pool.user_connection_ids(invitee_id) {
3663 session.peer.send(connection_id, update.clone())?;
3664 }
3665
3666 send_notifications(&connection_pool, &session.peer, notifications);
3667
3668 response.send(proto::Ack {})?;
3669 Ok(())
3670}
3671
3672/// remove someone from a channel
3673async fn remove_channel_member(
3674 request: proto::RemoveChannelMember,
3675 response: Response<proto::RemoveChannelMember>,
3676 session: UserSession,
3677) -> Result<()> {
3678 let db = session.db().await;
3679 let channel_id = ChannelId::from_proto(request.channel_id);
3680 let member_id = UserId::from_proto(request.user_id);
3681
3682 let RemoveChannelMemberResult {
3683 membership_update,
3684 notification_id,
3685 } = db
3686 .remove_channel_member(channel_id, member_id, session.user_id())
3687 .await?;
3688
3689 let mut connection_pool = session.connection_pool().await;
3690 notify_membership_updated(
3691 &mut connection_pool,
3692 membership_update,
3693 member_id,
3694 &session.peer,
3695 );
3696 for connection_id in connection_pool.user_connection_ids(member_id) {
3697 if let Some(notification_id) = notification_id {
3698 session
3699 .peer
3700 .send(
3701 connection_id,
3702 proto::DeleteNotification {
3703 notification_id: notification_id.to_proto(),
3704 },
3705 )
3706 .trace_err();
3707 }
3708 }
3709
3710 response.send(proto::Ack {})?;
3711 Ok(())
3712}
3713
3714/// Toggle the channel between public and private.
3715/// Care is taken to maintain the invariant that public channels only descend from public channels,
3716/// (though members-only channels can appear at any point in the hierarchy).
3717async fn set_channel_visibility(
3718 request: proto::SetChannelVisibility,
3719 response: Response<proto::SetChannelVisibility>,
3720 session: UserSession,
3721) -> Result<()> {
3722 let db = session.db().await;
3723 let channel_id = ChannelId::from_proto(request.channel_id);
3724 let visibility = request.visibility().into();
3725
3726 let channel_model = db
3727 .set_channel_visibility(channel_id, visibility, session.user_id())
3728 .await?;
3729 let root_id = channel_model.root_id();
3730 let channel = Channel::from_model(channel_model);
3731
3732 let mut connection_pool = session.connection_pool().await;
3733 for (user_id, role) in connection_pool
3734 .channel_user_ids(root_id)
3735 .collect::<Vec<_>>()
3736 .into_iter()
3737 {
3738 let update = if role.can_see_channel(channel.visibility) {
3739 connection_pool.subscribe_to_channel(user_id, channel_id, role);
3740 proto::UpdateChannels {
3741 channels: vec![channel.to_proto()],
3742 ..Default::default()
3743 }
3744 } else {
3745 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3746 proto::UpdateChannels {
3747 delete_channels: vec![channel.id.to_proto()],
3748 ..Default::default()
3749 }
3750 };
3751
3752 for connection_id in connection_pool.user_connection_ids(user_id) {
3753 session.peer.send(connection_id, update.clone())?;
3754 }
3755 }
3756
3757 response.send(proto::Ack {})?;
3758 Ok(())
3759}
3760
3761/// Alter the role for a user in the channel.
3762async fn set_channel_member_role(
3763 request: proto::SetChannelMemberRole,
3764 response: Response<proto::SetChannelMemberRole>,
3765 session: UserSession,
3766) -> Result<()> {
3767 let db = session.db().await;
3768 let channel_id = ChannelId::from_proto(request.channel_id);
3769 let member_id = UserId::from_proto(request.user_id);
3770 let result = db
3771 .set_channel_member_role(
3772 channel_id,
3773 session.user_id(),
3774 member_id,
3775 request.role().into(),
3776 )
3777 .await?;
3778
3779 match result {
3780 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3781 let mut connection_pool = session.connection_pool().await;
3782 notify_membership_updated(
3783 &mut connection_pool,
3784 membership_update,
3785 member_id,
3786 &session.peer,
3787 )
3788 }
3789 db::SetMemberRoleResult::InviteUpdated(channel) => {
3790 let update = proto::UpdateChannels {
3791 channel_invitations: vec![channel.to_proto()],
3792 ..Default::default()
3793 };
3794
3795 for connection_id in session
3796 .connection_pool()
3797 .await
3798 .user_connection_ids(member_id)
3799 {
3800 session.peer.send(connection_id, update.clone())?;
3801 }
3802 }
3803 }
3804
3805 response.send(proto::Ack {})?;
3806 Ok(())
3807}
3808
3809/// Change the name of a channel
3810async fn rename_channel(
3811 request: proto::RenameChannel,
3812 response: Response<proto::RenameChannel>,
3813 session: UserSession,
3814) -> Result<()> {
3815 let db = session.db().await;
3816 let channel_id = ChannelId::from_proto(request.channel_id);
3817 let channel_model = db
3818 .rename_channel(channel_id, session.user_id(), &request.name)
3819 .await?;
3820 let root_id = channel_model.root_id();
3821 let channel = Channel::from_model(channel_model);
3822
3823 response.send(proto::RenameChannelResponse {
3824 channel: Some(channel.to_proto()),
3825 })?;
3826
3827 let connection_pool = session.connection_pool().await;
3828 let update = proto::UpdateChannels {
3829 channels: vec![channel.to_proto()],
3830 ..Default::default()
3831 };
3832 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3833 if role.can_see_channel(channel.visibility) {
3834 session.peer.send(connection_id, update.clone())?;
3835 }
3836 }
3837
3838 Ok(())
3839}
3840
3841/// Move a channel to a new parent.
3842async fn move_channel(
3843 request: proto::MoveChannel,
3844 response: Response<proto::MoveChannel>,
3845 session: UserSession,
3846) -> Result<()> {
3847 let channel_id = ChannelId::from_proto(request.channel_id);
3848 let to = ChannelId::from_proto(request.to);
3849
3850 let (root_id, channels) = session
3851 .db()
3852 .await
3853 .move_channel(channel_id, to, session.user_id())
3854 .await?;
3855
3856 let connection_pool = session.connection_pool().await;
3857 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3858 let channels = channels
3859 .iter()
3860 .filter_map(|channel| {
3861 if role.can_see_channel(channel.visibility) {
3862 Some(channel.to_proto())
3863 } else {
3864 None
3865 }
3866 })
3867 .collect::<Vec<_>>();
3868 if channels.is_empty() {
3869 continue;
3870 }
3871
3872 let update = proto::UpdateChannels {
3873 channels,
3874 ..Default::default()
3875 };
3876
3877 session.peer.send(connection_id, update.clone())?;
3878 }
3879
3880 response.send(Ack {})?;
3881 Ok(())
3882}
3883
3884/// Get the list of channel members
3885async fn get_channel_members(
3886 request: proto::GetChannelMembers,
3887 response: Response<proto::GetChannelMembers>,
3888 session: UserSession,
3889) -> Result<()> {
3890 let db = session.db().await;
3891 let channel_id = ChannelId::from_proto(request.channel_id);
3892 let limit = if request.limit == 0 {
3893 u16::MAX as u64
3894 } else {
3895 request.limit
3896 };
3897 let (members, users) = db
3898 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3899 .await?;
3900 response.send(proto::GetChannelMembersResponse { members, users })?;
3901 Ok(())
3902}
3903
3904/// Accept or decline a channel invitation.
3905async fn respond_to_channel_invite(
3906 request: proto::RespondToChannelInvite,
3907 response: Response<proto::RespondToChannelInvite>,
3908 session: UserSession,
3909) -> Result<()> {
3910 let db = session.db().await;
3911 let channel_id = ChannelId::from_proto(request.channel_id);
3912 let RespondToChannelInvite {
3913 membership_update,
3914 notifications,
3915 } = db
3916 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3917 .await?;
3918
3919 let mut connection_pool = session.connection_pool().await;
3920 if let Some(membership_update) = membership_update {
3921 notify_membership_updated(
3922 &mut connection_pool,
3923 membership_update,
3924 session.user_id(),
3925 &session.peer,
3926 );
3927 } else {
3928 let update = proto::UpdateChannels {
3929 remove_channel_invitations: vec![channel_id.to_proto()],
3930 ..Default::default()
3931 };
3932
3933 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3934 session.peer.send(connection_id, update.clone())?;
3935 }
3936 };
3937
3938 send_notifications(&connection_pool, &session.peer, notifications);
3939
3940 response.send(proto::Ack {})?;
3941
3942 Ok(())
3943}
3944
3945/// Join the channels' room
3946async fn join_channel(
3947 request: proto::JoinChannel,
3948 response: Response<proto::JoinChannel>,
3949 session: UserSession,
3950) -> Result<()> {
3951 let channel_id = ChannelId::from_proto(request.channel_id);
3952 join_channel_internal(channel_id, Box::new(response), session).await
3953}
3954
3955trait JoinChannelInternalResponse {
3956 fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
3957}
3958impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
3959 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3960 Response::<proto::JoinChannel>::send(self, result)
3961 }
3962}
3963impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
3964 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3965 Response::<proto::JoinRoom>::send(self, result)
3966 }
3967}
3968
3969async fn join_channel_internal(
3970 channel_id: ChannelId,
3971 response: Box<impl JoinChannelInternalResponse>,
3972 session: UserSession,
3973) -> Result<()> {
3974 let joined_room = {
3975 let mut db = session.db().await;
3976 // If zed quits without leaving the room, and the user re-opens zed before the
3977 // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
3978 // room they were in.
3979 if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
3980 tracing::info!(
3981 stale_connection_id = %connection,
3982 "cleaning up stale connection",
3983 );
3984 drop(db);
3985 leave_room_for_session(&session, connection).await?;
3986 db = session.db().await;
3987 }
3988
3989 let (joined_room, membership_updated, role) = db
3990 .join_channel(channel_id, session.user_id(), session.connection_id)
3991 .await?;
3992
3993 let live_kit_connection_info =
3994 session
3995 .app_state
3996 .live_kit_client
3997 .as_ref()
3998 .and_then(|live_kit| {
3999 let (can_publish, token) = if role == ChannelRole::Guest {
4000 (
4001 false,
4002 live_kit
4003 .guest_token(
4004 &joined_room.room.live_kit_room,
4005 &session.user_id().to_string(),
4006 )
4007 .trace_err()?,
4008 )
4009 } else {
4010 (
4011 true,
4012 live_kit
4013 .room_token(
4014 &joined_room.room.live_kit_room,
4015 &session.user_id().to_string(),
4016 )
4017 .trace_err()?,
4018 )
4019 };
4020
4021 Some(LiveKitConnectionInfo {
4022 server_url: live_kit.url().into(),
4023 token,
4024 can_publish,
4025 })
4026 });
4027
4028 response.send(proto::JoinRoomResponse {
4029 room: Some(joined_room.room.clone()),
4030 channel_id: joined_room
4031 .channel
4032 .as_ref()
4033 .map(|channel| channel.id.to_proto()),
4034 live_kit_connection_info,
4035 })?;
4036
4037 let mut connection_pool = session.connection_pool().await;
4038 if let Some(membership_updated) = membership_updated {
4039 notify_membership_updated(
4040 &mut connection_pool,
4041 membership_updated,
4042 session.user_id(),
4043 &session.peer,
4044 );
4045 }
4046
4047 room_updated(&joined_room.room, &session.peer);
4048
4049 joined_room
4050 };
4051
4052 channel_updated(
4053 &joined_room
4054 .channel
4055 .ok_or_else(|| anyhow!("channel not returned"))?,
4056 &joined_room.room,
4057 &session.peer,
4058 &*session.connection_pool().await,
4059 );
4060
4061 update_user_contacts(session.user_id(), &session).await?;
4062 Ok(())
4063}
4064
4065/// Start editing the channel notes
4066async fn join_channel_buffer(
4067 request: proto::JoinChannelBuffer,
4068 response: Response<proto::JoinChannelBuffer>,
4069 session: UserSession,
4070) -> Result<()> {
4071 let db = session.db().await;
4072 let channel_id = ChannelId::from_proto(request.channel_id);
4073
4074 let open_response = db
4075 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
4076 .await?;
4077
4078 let collaborators = open_response.collaborators.clone();
4079 response.send(open_response)?;
4080
4081 let update = UpdateChannelBufferCollaborators {
4082 channel_id: channel_id.to_proto(),
4083 collaborators: collaborators.clone(),
4084 };
4085 channel_buffer_updated(
4086 session.connection_id,
4087 collaborators
4088 .iter()
4089 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
4090 &update,
4091 &session.peer,
4092 );
4093
4094 Ok(())
4095}
4096
4097/// Edit the channel notes
4098async fn update_channel_buffer(
4099 request: proto::UpdateChannelBuffer,
4100 session: UserSession,
4101) -> Result<()> {
4102 let db = session.db().await;
4103 let channel_id = ChannelId::from_proto(request.channel_id);
4104
4105 let (collaborators, epoch, version) = db
4106 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
4107 .await?;
4108
4109 channel_buffer_updated(
4110 session.connection_id,
4111 collaborators.clone(),
4112 &proto::UpdateChannelBuffer {
4113 channel_id: channel_id.to_proto(),
4114 operations: request.operations,
4115 },
4116 &session.peer,
4117 );
4118
4119 let pool = &*session.connection_pool().await;
4120
4121 let non_collaborators =
4122 pool.channel_connection_ids(channel_id)
4123 .filter_map(|(connection_id, _)| {
4124 if collaborators.contains(&connection_id) {
4125 None
4126 } else {
4127 Some(connection_id)
4128 }
4129 });
4130
4131 broadcast(None, non_collaborators, |peer_id| {
4132 session.peer.send(
4133 peer_id,
4134 proto::UpdateChannels {
4135 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
4136 channel_id: channel_id.to_proto(),
4137 epoch: epoch as u64,
4138 version: version.clone(),
4139 }],
4140 ..Default::default()
4141 },
4142 )
4143 });
4144
4145 Ok(())
4146}
4147
4148/// Rejoin the channel notes after a connection blip
4149async fn rejoin_channel_buffers(
4150 request: proto::RejoinChannelBuffers,
4151 response: Response<proto::RejoinChannelBuffers>,
4152 session: UserSession,
4153) -> Result<()> {
4154 let db = session.db().await;
4155 let buffers = db
4156 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
4157 .await?;
4158
4159 for rejoined_buffer in &buffers {
4160 let collaborators_to_notify = rejoined_buffer
4161 .buffer
4162 .collaborators
4163 .iter()
4164 .filter_map(|c| Some(c.peer_id?.into()));
4165 channel_buffer_updated(
4166 session.connection_id,
4167 collaborators_to_notify,
4168 &proto::UpdateChannelBufferCollaborators {
4169 channel_id: rejoined_buffer.buffer.channel_id,
4170 collaborators: rejoined_buffer.buffer.collaborators.clone(),
4171 },
4172 &session.peer,
4173 );
4174 }
4175
4176 response.send(proto::RejoinChannelBuffersResponse {
4177 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
4178 })?;
4179
4180 Ok(())
4181}
4182
4183/// Stop editing the channel notes
4184async fn leave_channel_buffer(
4185 request: proto::LeaveChannelBuffer,
4186 response: Response<proto::LeaveChannelBuffer>,
4187 session: UserSession,
4188) -> Result<()> {
4189 let db = session.db().await;
4190 let channel_id = ChannelId::from_proto(request.channel_id);
4191
4192 let left_buffer = db
4193 .leave_channel_buffer(channel_id, session.connection_id)
4194 .await?;
4195
4196 response.send(Ack {})?;
4197
4198 channel_buffer_updated(
4199 session.connection_id,
4200 left_buffer.connections,
4201 &proto::UpdateChannelBufferCollaborators {
4202 channel_id: channel_id.to_proto(),
4203 collaborators: left_buffer.collaborators,
4204 },
4205 &session.peer,
4206 );
4207
4208 Ok(())
4209}
4210
4211fn channel_buffer_updated<T: EnvelopedMessage>(
4212 sender_id: ConnectionId,
4213 collaborators: impl IntoIterator<Item = ConnectionId>,
4214 message: &T,
4215 peer: &Peer,
4216) {
4217 broadcast(Some(sender_id), collaborators, |peer_id| {
4218 peer.send(peer_id, message.clone())
4219 });
4220}
4221
4222fn send_notifications(
4223 connection_pool: &ConnectionPool,
4224 peer: &Peer,
4225 notifications: db::NotificationBatch,
4226) {
4227 for (user_id, notification) in notifications {
4228 for connection_id in connection_pool.user_connection_ids(user_id) {
4229 if let Err(error) = peer.send(
4230 connection_id,
4231 proto::AddNotification {
4232 notification: Some(notification.clone()),
4233 },
4234 ) {
4235 tracing::error!(
4236 "failed to send notification to {:?} {}",
4237 connection_id,
4238 error
4239 );
4240 }
4241 }
4242 }
4243}
4244
4245/// Send a message to the channel
4246async fn send_channel_message(
4247 request: proto::SendChannelMessage,
4248 response: Response<proto::SendChannelMessage>,
4249 session: UserSession,
4250) -> Result<()> {
4251 // Validate the message body.
4252 let body = request.body.trim().to_string();
4253 if body.len() > MAX_MESSAGE_LEN {
4254 return Err(anyhow!("message is too long"))?;
4255 }
4256 if body.is_empty() {
4257 return Err(anyhow!("message can't be blank"))?;
4258 }
4259
4260 // TODO: adjust mentions if body is trimmed
4261
4262 let timestamp = OffsetDateTime::now_utc();
4263 let nonce = request
4264 .nonce
4265 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4266
4267 let channel_id = ChannelId::from_proto(request.channel_id);
4268 let CreatedChannelMessage {
4269 message_id,
4270 participant_connection_ids,
4271 notifications,
4272 } = session
4273 .db()
4274 .await
4275 .create_channel_message(
4276 channel_id,
4277 session.user_id(),
4278 &body,
4279 &request.mentions,
4280 timestamp,
4281 nonce.clone().into(),
4282 request.reply_to_message_id.map(MessageId::from_proto),
4283 )
4284 .await?;
4285
4286 let message = proto::ChannelMessage {
4287 sender_id: session.user_id().to_proto(),
4288 id: message_id.to_proto(),
4289 body,
4290 mentions: request.mentions,
4291 timestamp: timestamp.unix_timestamp() as u64,
4292 nonce: Some(nonce),
4293 reply_to_message_id: request.reply_to_message_id,
4294 edited_at: None,
4295 };
4296 broadcast(
4297 Some(session.connection_id),
4298 participant_connection_ids.clone(),
4299 |connection| {
4300 session.peer.send(
4301 connection,
4302 proto::ChannelMessageSent {
4303 channel_id: channel_id.to_proto(),
4304 message: Some(message.clone()),
4305 },
4306 )
4307 },
4308 );
4309 response.send(proto::SendChannelMessageResponse {
4310 message: Some(message),
4311 })?;
4312
4313 let pool = &*session.connection_pool().await;
4314 let non_participants =
4315 pool.channel_connection_ids(channel_id)
4316 .filter_map(|(connection_id, _)| {
4317 if participant_connection_ids.contains(&connection_id) {
4318 None
4319 } else {
4320 Some(connection_id)
4321 }
4322 });
4323 broadcast(None, non_participants, |peer_id| {
4324 session.peer.send(
4325 peer_id,
4326 proto::UpdateChannels {
4327 latest_channel_message_ids: vec![proto::ChannelMessageId {
4328 channel_id: channel_id.to_proto(),
4329 message_id: message_id.to_proto(),
4330 }],
4331 ..Default::default()
4332 },
4333 )
4334 });
4335 send_notifications(pool, &session.peer, notifications);
4336
4337 Ok(())
4338}
4339
4340/// Delete a channel message
4341async fn remove_channel_message(
4342 request: proto::RemoveChannelMessage,
4343 response: Response<proto::RemoveChannelMessage>,
4344 session: UserSession,
4345) -> Result<()> {
4346 let channel_id = ChannelId::from_proto(request.channel_id);
4347 let message_id = MessageId::from_proto(request.message_id);
4348 let (connection_ids, existing_notification_ids) = session
4349 .db()
4350 .await
4351 .remove_channel_message(channel_id, message_id, session.user_id())
4352 .await?;
4353
4354 broadcast(
4355 Some(session.connection_id),
4356 connection_ids,
4357 move |connection| {
4358 session.peer.send(connection, request.clone())?;
4359
4360 for notification_id in &existing_notification_ids {
4361 session.peer.send(
4362 connection,
4363 proto::DeleteNotification {
4364 notification_id: (*notification_id).to_proto(),
4365 },
4366 )?;
4367 }
4368
4369 Ok(())
4370 },
4371 );
4372 response.send(proto::Ack {})?;
4373 Ok(())
4374}
4375
4376async fn update_channel_message(
4377 request: proto::UpdateChannelMessage,
4378 response: Response<proto::UpdateChannelMessage>,
4379 session: UserSession,
4380) -> Result<()> {
4381 let channel_id = ChannelId::from_proto(request.channel_id);
4382 let message_id = MessageId::from_proto(request.message_id);
4383 let updated_at = OffsetDateTime::now_utc();
4384 let UpdatedChannelMessage {
4385 message_id,
4386 participant_connection_ids,
4387 notifications,
4388 reply_to_message_id,
4389 timestamp,
4390 deleted_mention_notification_ids,
4391 updated_mention_notifications,
4392 } = session
4393 .db()
4394 .await
4395 .update_channel_message(
4396 channel_id,
4397 message_id,
4398 session.user_id(),
4399 request.body.as_str(),
4400 &request.mentions,
4401 updated_at,
4402 )
4403 .await?;
4404
4405 let nonce = request
4406 .nonce
4407 .clone()
4408 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4409
4410 let message = proto::ChannelMessage {
4411 sender_id: session.user_id().to_proto(),
4412 id: message_id.to_proto(),
4413 body: request.body.clone(),
4414 mentions: request.mentions.clone(),
4415 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
4416 nonce: Some(nonce),
4417 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
4418 edited_at: Some(updated_at.unix_timestamp() as u64),
4419 };
4420
4421 response.send(proto::Ack {})?;
4422
4423 let pool = &*session.connection_pool().await;
4424 broadcast(
4425 Some(session.connection_id),
4426 participant_connection_ids,
4427 |connection| {
4428 session.peer.send(
4429 connection,
4430 proto::ChannelMessageUpdate {
4431 channel_id: channel_id.to_proto(),
4432 message: Some(message.clone()),
4433 },
4434 )?;
4435
4436 for notification_id in &deleted_mention_notification_ids {
4437 session.peer.send(
4438 connection,
4439 proto::DeleteNotification {
4440 notification_id: (*notification_id).to_proto(),
4441 },
4442 )?;
4443 }
4444
4445 for notification in &updated_mention_notifications {
4446 session.peer.send(
4447 connection,
4448 proto::UpdateNotification {
4449 notification: Some(notification.clone()),
4450 },
4451 )?;
4452 }
4453
4454 Ok(())
4455 },
4456 );
4457
4458 send_notifications(pool, &session.peer, notifications);
4459
4460 Ok(())
4461}
4462
4463/// Mark a channel message as read
4464async fn acknowledge_channel_message(
4465 request: proto::AckChannelMessage,
4466 session: UserSession,
4467) -> Result<()> {
4468 let channel_id = ChannelId::from_proto(request.channel_id);
4469 let message_id = MessageId::from_proto(request.message_id);
4470 let notifications = session
4471 .db()
4472 .await
4473 .observe_channel_message(channel_id, session.user_id(), message_id)
4474 .await?;
4475 send_notifications(
4476 &*session.connection_pool().await,
4477 &session.peer,
4478 notifications,
4479 );
4480 Ok(())
4481}
4482
4483/// Mark a buffer version as synced
4484async fn acknowledge_buffer_version(
4485 request: proto::AckBufferOperation,
4486 session: UserSession,
4487) -> Result<()> {
4488 let buffer_id = BufferId::from_proto(request.buffer_id);
4489 session
4490 .db()
4491 .await
4492 .observe_buffer_version(
4493 buffer_id,
4494 session.user_id(),
4495 request.epoch as i32,
4496 &request.version,
4497 )
4498 .await?;
4499 Ok(())
4500}
4501
4502async fn count_language_model_tokens(
4503 request: proto::CountLanguageModelTokens,
4504 response: Response<proto::CountLanguageModelTokens>,
4505 session: Session,
4506 config: &Config,
4507) -> Result<()> {
4508 let Some(session) = session.for_user() else {
4509 return Err(anyhow!("user not found"))?;
4510 };
4511 authorize_access_to_legacy_llm_endpoints(&session).await?;
4512
4513 let rate_limit: Box<dyn RateLimit> = match session.current_plan(session.db().await).await? {
4514 proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
4515 proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
4516 };
4517
4518 session
4519 .app_state
4520 .rate_limiter
4521 .check(&*rate_limit, session.user_id())
4522 .await?;
4523
4524 let result = match proto::LanguageModelProvider::from_i32(request.provider) {
4525 Some(proto::LanguageModelProvider::Google) => {
4526 let api_key = config
4527 .google_ai_api_key
4528 .as_ref()
4529 .context("no Google AI API key configured on the server")?;
4530 google_ai::count_tokens(
4531 session.http_client.as_ref(),
4532 google_ai::API_URL,
4533 api_key,
4534 serde_json::from_str(&request.request)?,
4535 None,
4536 )
4537 .await?
4538 }
4539 _ => return Err(anyhow!("unsupported provider"))?,
4540 };
4541
4542 response.send(proto::CountLanguageModelTokensResponse {
4543 token_count: result.total_tokens as u32,
4544 })?;
4545
4546 Ok(())
4547}
4548
4549struct ZedProCountLanguageModelTokensRateLimit;
4550
4551impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
4552 fn capacity(&self) -> usize {
4553 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
4554 .ok()
4555 .and_then(|v| v.parse().ok())
4556 .unwrap_or(600) // Picked arbitrarily
4557 }
4558
4559 fn refill_duration(&self) -> chrono::Duration {
4560 chrono::Duration::hours(1)
4561 }
4562
4563 fn db_name(&self) -> &'static str {
4564 "zed-pro:count-language-model-tokens"
4565 }
4566}
4567
4568struct FreeCountLanguageModelTokensRateLimit;
4569
4570impl RateLimit for FreeCountLanguageModelTokensRateLimit {
4571 fn capacity(&self) -> usize {
4572 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
4573 .ok()
4574 .and_then(|v| v.parse().ok())
4575 .unwrap_or(600 / 10) // Picked arbitrarily
4576 }
4577
4578 fn refill_duration(&self) -> chrono::Duration {
4579 chrono::Duration::hours(1)
4580 }
4581
4582 fn db_name(&self) -> &'static str {
4583 "free:count-language-model-tokens"
4584 }
4585}
4586
4587struct ZedProComputeEmbeddingsRateLimit;
4588
4589impl RateLimit for ZedProComputeEmbeddingsRateLimit {
4590 fn capacity(&self) -> usize {
4591 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4592 .ok()
4593 .and_then(|v| v.parse().ok())
4594 .unwrap_or(5000) // Picked arbitrarily
4595 }
4596
4597 fn refill_duration(&self) -> chrono::Duration {
4598 chrono::Duration::hours(1)
4599 }
4600
4601 fn db_name(&self) -> &'static str {
4602 "zed-pro:compute-embeddings"
4603 }
4604}
4605
4606struct FreeComputeEmbeddingsRateLimit;
4607
4608impl RateLimit for FreeComputeEmbeddingsRateLimit {
4609 fn capacity(&self) -> usize {
4610 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
4611 .ok()
4612 .and_then(|v| v.parse().ok())
4613 .unwrap_or(5000 / 10) // Picked arbitrarily
4614 }
4615
4616 fn refill_duration(&self) -> chrono::Duration {
4617 chrono::Duration::hours(1)
4618 }
4619
4620 fn db_name(&self) -> &'static str {
4621 "free:compute-embeddings"
4622 }
4623}
4624
4625async fn compute_embeddings(
4626 request: proto::ComputeEmbeddings,
4627 response: Response<proto::ComputeEmbeddings>,
4628 session: UserSession,
4629 api_key: Option<Arc<str>>,
4630) -> Result<()> {
4631 let api_key = api_key.context("no OpenAI API key configured on the server")?;
4632 authorize_access_to_legacy_llm_endpoints(&session).await?;
4633
4634 let rate_limit: Box<dyn RateLimit> = match session.current_plan(session.db().await).await? {
4635 proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
4636 proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
4637 };
4638
4639 session
4640 .app_state
4641 .rate_limiter
4642 .check(&*rate_limit, session.user_id())
4643 .await?;
4644
4645 let embeddings = match request.model.as_str() {
4646 "openai/text-embedding-3-small" => {
4647 open_ai::embed(
4648 session.http_client.as_ref(),
4649 OPEN_AI_API_URL,
4650 &api_key,
4651 OpenAiEmbeddingModel::TextEmbedding3Small,
4652 request.texts.iter().map(|text| text.as_str()),
4653 )
4654 .await?
4655 }
4656 provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
4657 };
4658
4659 let embeddings = request
4660 .texts
4661 .iter()
4662 .map(|text| {
4663 let mut hasher = sha2::Sha256::new();
4664 hasher.update(text.as_bytes());
4665 let result = hasher.finalize();
4666 result.to_vec()
4667 })
4668 .zip(
4669 embeddings
4670 .data
4671 .into_iter()
4672 .map(|embedding| embedding.embedding),
4673 )
4674 .collect::<HashMap<_, _>>();
4675
4676 let db = session.db().await;
4677 db.save_embeddings(&request.model, &embeddings)
4678 .await
4679 .context("failed to save embeddings")
4680 .trace_err();
4681
4682 response.send(proto::ComputeEmbeddingsResponse {
4683 embeddings: embeddings
4684 .into_iter()
4685 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4686 .collect(),
4687 })?;
4688 Ok(())
4689}
4690
4691async fn get_cached_embeddings(
4692 request: proto::GetCachedEmbeddings,
4693 response: Response<proto::GetCachedEmbeddings>,
4694 session: UserSession,
4695) -> Result<()> {
4696 authorize_access_to_legacy_llm_endpoints(&session).await?;
4697
4698 let db = session.db().await;
4699 let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
4700
4701 response.send(proto::GetCachedEmbeddingsResponse {
4702 embeddings: embeddings
4703 .into_iter()
4704 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4705 .collect(),
4706 })?;
4707 Ok(())
4708}
4709
4710/// This is leftover from before the LLM service.
4711///
4712/// The endpoints protected by this check will be moved there eventually.
4713async fn authorize_access_to_legacy_llm_endpoints(session: &UserSession) -> Result<(), Error> {
4714 if session.is_staff() {
4715 Ok(())
4716 } else {
4717 Err(anyhow!("permission denied"))?
4718 }
4719}
4720
4721/// Get a Supermaven API key for the user
4722async fn get_supermaven_api_key(
4723 _request: proto::GetSupermavenApiKey,
4724 response: Response<proto::GetSupermavenApiKey>,
4725 session: UserSession,
4726) -> Result<()> {
4727 let user_id: String = session.user_id().to_string();
4728 if !session.is_staff() {
4729 return Err(anyhow!("supermaven not enabled for this account"))?;
4730 }
4731
4732 let email = session
4733 .email()
4734 .ok_or_else(|| anyhow!("user must have an email"))?;
4735
4736 let supermaven_admin_api = session
4737 .supermaven_client
4738 .as_ref()
4739 .ok_or_else(|| anyhow!("supermaven not configured"))?;
4740
4741 let result = supermaven_admin_api
4742 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4743 .await?;
4744
4745 response.send(proto::GetSupermavenApiKeyResponse {
4746 api_key: result.api_key,
4747 })?;
4748
4749 Ok(())
4750}
4751
4752/// Start receiving chat updates for a channel
4753async fn join_channel_chat(
4754 request: proto::JoinChannelChat,
4755 response: Response<proto::JoinChannelChat>,
4756 session: UserSession,
4757) -> Result<()> {
4758 let channel_id = ChannelId::from_proto(request.channel_id);
4759
4760 let db = session.db().await;
4761 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4762 .await?;
4763 let messages = db
4764 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4765 .await?;
4766 response.send(proto::JoinChannelChatResponse {
4767 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4768 messages,
4769 })?;
4770 Ok(())
4771}
4772
4773/// Stop receiving chat updates for a channel
4774async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
4775 let channel_id = ChannelId::from_proto(request.channel_id);
4776 session
4777 .db()
4778 .await
4779 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4780 .await?;
4781 Ok(())
4782}
4783
4784/// Retrieve the chat history for a channel
4785async fn get_channel_messages(
4786 request: proto::GetChannelMessages,
4787 response: Response<proto::GetChannelMessages>,
4788 session: UserSession,
4789) -> Result<()> {
4790 let channel_id = ChannelId::from_proto(request.channel_id);
4791 let messages = session
4792 .db()
4793 .await
4794 .get_channel_messages(
4795 channel_id,
4796 session.user_id(),
4797 MESSAGE_COUNT_PER_PAGE,
4798 Some(MessageId::from_proto(request.before_message_id)),
4799 )
4800 .await?;
4801 response.send(proto::GetChannelMessagesResponse {
4802 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4803 messages,
4804 })?;
4805 Ok(())
4806}
4807
4808/// Retrieve specific chat messages
4809async fn get_channel_messages_by_id(
4810 request: proto::GetChannelMessagesById,
4811 response: Response<proto::GetChannelMessagesById>,
4812 session: UserSession,
4813) -> Result<()> {
4814 let message_ids = request
4815 .message_ids
4816 .iter()
4817 .map(|id| MessageId::from_proto(*id))
4818 .collect::<Vec<_>>();
4819 let messages = session
4820 .db()
4821 .await
4822 .get_channel_messages_by_id(session.user_id(), &message_ids)
4823 .await?;
4824 response.send(proto::GetChannelMessagesResponse {
4825 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4826 messages,
4827 })?;
4828 Ok(())
4829}
4830
4831/// Retrieve the current users notifications
4832async fn get_notifications(
4833 request: proto::GetNotifications,
4834 response: Response<proto::GetNotifications>,
4835 session: UserSession,
4836) -> Result<()> {
4837 let notifications = session
4838 .db()
4839 .await
4840 .get_notifications(
4841 session.user_id(),
4842 NOTIFICATION_COUNT_PER_PAGE,
4843 request.before_id.map(db::NotificationId::from_proto),
4844 )
4845 .await?;
4846 response.send(proto::GetNotificationsResponse {
4847 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4848 notifications,
4849 })?;
4850 Ok(())
4851}
4852
4853/// Mark notifications as read
4854async fn mark_notification_as_read(
4855 request: proto::MarkNotificationRead,
4856 response: Response<proto::MarkNotificationRead>,
4857 session: UserSession,
4858) -> Result<()> {
4859 let database = &session.db().await;
4860 let notifications = database
4861 .mark_notification_as_read_by_id(
4862 session.user_id(),
4863 NotificationId::from_proto(request.notification_id),
4864 )
4865 .await?;
4866 send_notifications(
4867 &*session.connection_pool().await,
4868 &session.peer,
4869 notifications,
4870 );
4871 response.send(proto::Ack {})?;
4872 Ok(())
4873}
4874
4875/// Get the current users information
4876async fn get_private_user_info(
4877 _request: proto::GetPrivateUserInfo,
4878 response: Response<proto::GetPrivateUserInfo>,
4879 session: UserSession,
4880) -> Result<()> {
4881 let db = session.db().await;
4882
4883 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4884 let user = db
4885 .get_user_by_id(session.user_id())
4886 .await?
4887 .ok_or_else(|| anyhow!("user not found"))?;
4888 let flags = db.get_user_flags(session.user_id()).await?;
4889
4890 response.send(proto::GetPrivateUserInfoResponse {
4891 metrics_id,
4892 staff: user.admin,
4893 flags,
4894 accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4895 })?;
4896 Ok(())
4897}
4898
4899/// Accept the terms of service (tos) on behalf of the current user
4900async fn accept_terms_of_service(
4901 _request: proto::AcceptTermsOfService,
4902 response: Response<proto::AcceptTermsOfService>,
4903 session: UserSession,
4904) -> Result<()> {
4905 let db = session.db().await;
4906
4907 let accepted_tos_at = Utc::now();
4908 db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4909 .await?;
4910
4911 response.send(proto::AcceptTermsOfServiceResponse {
4912 accepted_tos_at: accepted_tos_at.timestamp() as u64,
4913 })?;
4914 Ok(())
4915}
4916
/// The minimum age an account must have in order to use the LLM service.
///
/// `get_llm_api_token` measures the age from the earlier of the user's
/// `created_at` and, when present, `github_user_created_at`.
const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4919
4920async fn get_llm_api_token(
4921 _request: proto::GetLlmToken,
4922 response: Response<proto::GetLlmToken>,
4923 session: UserSession,
4924) -> Result<()> {
4925 let db = session.db().await;
4926
4927 let flags = db.get_user_flags(session.user_id()).await?;
4928 let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
4929 let has_llm_closed_beta_feature_flag = flags.iter().any(|flag| flag == "llm-closed-beta");
4930
4931 if !session.is_staff() && !has_language_models_feature_flag {
4932 Err(anyhow!("permission denied"))?
4933 }
4934
4935 let user_id = session.user_id();
4936 let user = db
4937 .get_user_by_id(user_id)
4938 .await?
4939 .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4940
4941 if user.accepted_tos_at.is_none() {
4942 Err(anyhow!("terms of service not accepted"))?
4943 }
4944
4945 let mut account_created_at = user.created_at;
4946 if let Some(github_created_at) = user.github_user_created_at {
4947 account_created_at = account_created_at.min(github_created_at);
4948 }
4949 if Utc::now().naive_utc() - account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE {
4950 Err(anyhow!("account too young"))?
4951 }
4952 let token = LlmTokenClaims::create(
4953 user.id,
4954 user.github_login.clone(),
4955 session.is_staff(),
4956 has_llm_closed_beta_feature_flag,
4957 session.current_plan(db).await?,
4958 &session.app_state.config,
4959 )?;
4960 response.send(proto::GetLlmTokenResponse { token })?;
4961 Ok(())
4962}
4963
4964fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4965 let message = match message {
4966 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4967 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4968 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4969 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4970 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4971 code: frame.code.into(),
4972 reason: frame.reason,
4973 })),
4974 // We should never receive a frame while reading the message, according
4975 // to the `tungstenite` maintainers:
4976 //
4977 // > It cannot occur when you read messages from the WebSocket, but it
4978 // > can be used when you want to send the raw frames (e.g. you want to
4979 // > send the frames to the WebSocket without composing the full message first).
4980 // >
4981 // > — https://github.com/snapview/tungstenite-rs/issues/268
4982 TungsteniteMessage::Frame(_) => {
4983 bail!("received an unexpected frame while reading the message")
4984 }
4985 };
4986
4987 Ok(message)
4988}
4989
4990fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4991 match message {
4992 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4993 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4994 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4995 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4996 AxumMessage::Close(frame) => {
4997 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4998 code: frame.code.into(),
4999 reason: frame.reason,
5000 }))
5001 }
5002 }
5003}
5004
5005fn notify_membership_updated(
5006 connection_pool: &mut ConnectionPool,
5007 result: MembershipUpdated,
5008 user_id: UserId,
5009 peer: &Peer,
5010) {
5011 for membership in &result.new_channels.channel_memberships {
5012 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
5013 }
5014 for channel_id in &result.removed_channels {
5015 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
5016 }
5017
5018 let user_channels_update = proto::UpdateUserChannels {
5019 channel_memberships: result
5020 .new_channels
5021 .channel_memberships
5022 .iter()
5023 .map(|cm| proto::ChannelMembership {
5024 channel_id: cm.channel_id.to_proto(),
5025 role: cm.role.into(),
5026 })
5027 .collect(),
5028 ..Default::default()
5029 };
5030
5031 let mut update = build_channels_update(result.new_channels);
5032 update.delete_channels = result
5033 .removed_channels
5034 .into_iter()
5035 .map(|id| id.to_proto())
5036 .collect();
5037 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
5038
5039 for connection_id in connection_pool.user_connection_ids(user_id) {
5040 peer.send(connection_id, user_channels_update.clone())
5041 .trace_err();
5042 peer.send(connection_id, update.clone()).trace_err();
5043 }
5044}
5045
5046fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
5047 proto::UpdateUserChannels {
5048 channel_memberships: channels
5049 .channel_memberships
5050 .iter()
5051 .map(|m| proto::ChannelMembership {
5052 channel_id: m.channel_id.to_proto(),
5053 role: m.role.into(),
5054 })
5055 .collect(),
5056 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
5057 observed_channel_message_id: channels.observed_channel_messages.clone(),
5058 }
5059}
5060
5061fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
5062 let mut update = proto::UpdateChannels::default();
5063
5064 for channel in channels.channels {
5065 update.channels.push(channel.to_proto());
5066 }
5067
5068 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
5069 update.latest_channel_message_ids = channels.latest_channel_messages;
5070
5071 for (channel_id, participants) in channels.channel_participants {
5072 update
5073 .channel_participants
5074 .push(proto::ChannelParticipants {
5075 channel_id: channel_id.to_proto(),
5076 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
5077 });
5078 }
5079
5080 for channel in channels.invited_channels {
5081 update.channel_invitations.push(channel.to_proto());
5082 }
5083
5084 update.hosted_projects = channels.hosted_projects;
5085 update
5086}
5087
5088fn build_initial_contacts_update(
5089 contacts: Vec<db::Contact>,
5090 pool: &ConnectionPool,
5091) -> proto::UpdateContacts {
5092 let mut update = proto::UpdateContacts::default();
5093
5094 for contact in contacts {
5095 match contact {
5096 db::Contact::Accepted { user_id, busy } => {
5097 update.contacts.push(contact_for_user(user_id, busy, pool));
5098 }
5099 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
5100 db::Contact::Incoming { user_id } => {
5101 update
5102 .incoming_requests
5103 .push(proto::IncomingContactRequest {
5104 requester_id: user_id.to_proto(),
5105 })
5106 }
5107 }
5108 }
5109
5110 update
5111}
5112
5113fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
5114 proto::Contact {
5115 user_id: user_id.to_proto(),
5116 online: pool.is_user_online(user_id),
5117 busy,
5118 }
5119}
5120
5121fn room_updated(room: &proto::Room, peer: &Peer) {
5122 broadcast(
5123 None,
5124 room.participants
5125 .iter()
5126 .filter_map(|participant| Some(participant.peer_id?.into())),
5127 |peer_id| {
5128 peer.send(
5129 peer_id,
5130 proto::RoomUpdated {
5131 room: Some(room.clone()),
5132 },
5133 )
5134 },
5135 );
5136}
5137
5138fn channel_updated(
5139 channel: &db::channel::Model,
5140 room: &proto::Room,
5141 peer: &Peer,
5142 pool: &ConnectionPool,
5143) {
5144 let participants = room
5145 .participants
5146 .iter()
5147 .map(|p| p.user_id)
5148 .collect::<Vec<_>>();
5149
5150 broadcast(
5151 None,
5152 pool.channel_connection_ids(channel.root_id())
5153 .filter_map(|(channel_id, role)| {
5154 role.can_see_channel(channel.visibility)
5155 .then_some(channel_id)
5156 }),
5157 |peer_id| {
5158 peer.send(
5159 peer_id,
5160 proto::UpdateChannels {
5161 channel_participants: vec![proto::ChannelParticipants {
5162 channel_id: channel.id.to_proto(),
5163 participant_user_ids: participants.clone(),
5164 }],
5165 ..Default::default()
5166 },
5167 )
5168 },
5169 );
5170}
5171
5172async fn send_dev_server_projects_update(
5173 user_id: UserId,
5174 mut status: proto::DevServerProjectsUpdate,
5175 session: &Session,
5176) {
5177 let pool = session.connection_pool().await;
5178 for dev_server in &mut status.dev_servers {
5179 dev_server.status =
5180 pool.dev_server_status(DevServerId(dev_server.dev_server_id as i32)) as i32;
5181 }
5182 let connections = pool.user_connection_ids(user_id);
5183 for connection_id in connections {
5184 session.peer.send(connection_id, status.clone()).trace_err();
5185 }
5186}
5187
5188async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
5189 let db = session.db().await;
5190
5191 let contacts = db.get_contacts(user_id).await?;
5192 let busy = db.is_user_busy(user_id).await?;
5193
5194 let pool = session.connection_pool().await;
5195 let updated_contact = contact_for_user(user_id, busy, &pool);
5196 for contact in contacts {
5197 if let db::Contact::Accepted {
5198 user_id: contact_user_id,
5199 ..
5200 } = contact
5201 {
5202 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
5203 session
5204 .peer
5205 .send(
5206 contact_conn_id,
5207 proto::UpdateContacts {
5208 contacts: vec![updated_contact.clone()],
5209 remove_contacts: Default::default(),
5210 incoming_requests: Default::default(),
5211 remove_incoming_requests: Default::default(),
5212 outgoing_requests: Default::default(),
5213 remove_outgoing_requests: Default::default(),
5214 },
5215 )
5216 .trace_err();
5217 }
5218 }
5219 }
5220 Ok(())
5221}
5222
5223async fn lost_dev_server_connection(session: &DevServerSession) -> Result<()> {
5224 log::info!("lost dev server connection, unsharing projects");
5225 let project_ids = session
5226 .db()
5227 .await
5228 .get_stale_dev_server_projects(session.connection_id)
5229 .await?;
5230
5231 for project_id in project_ids {
5232 // not unshare re-checks the connection ids match, so we get away with no transaction
5233 unshare_project_internal(project_id, session.connection_id, None, session).await?;
5234 }
5235
5236 let user_id = session.dev_server().user_id;
5237 let update = session
5238 .db()
5239 .await
5240 .dev_server_projects_update(user_id)
5241 .await?;
5242
5243 send_dev_server_projects_update(user_id, update, session).await;
5244
5245 Ok(())
5246}
5247
/// Remove `connection_id` from the room it is in, if any, and notify everyone affected.
///
/// When the database reports that a room was left, this:
/// - notifies the connections of every left project via `project_left`,
/// - broadcasts the updated room via `room_updated` and, when the room belongs
///   to a channel, via `channel_updated`,
/// - sends `CallCanceled` to every connection of each user whose call was canceled,
/// - refreshes the contacts of the session's user and of those users,
/// - removes the participant from the LiveKit room and deletes that room when
///   the database marked it as deleted (only when a LiveKit client is configured).
///
/// Returns `Ok(())` without doing anything when no room was left.
async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Declared up front and assigned inside the `if let` below, so that the
    // values outlive `left_room`.
    // NOTE(review): presumably this keeps `left_room` (possibly a guard) from
    // being held across the awaits further down — confirm before restructuring.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // `live_kit_room` is taken out of `left_room.room` before the room itself
        // is taken, so the `room` broadcast below has that field reset to its default.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // The pool is only needed for this block and is released before the
    // contact updates below acquire it again.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    // LiveKit failures are traced and do not fail the request.
    if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
5321
5322async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
5323 let left_channel_buffers = session
5324 .db()
5325 .await
5326 .leave_channel_buffers(session.connection_id)
5327 .await?;
5328
5329 for left_buffer in left_channel_buffers {
5330 channel_buffer_updated(
5331 session.connection_id,
5332 left_buffer.connections,
5333 &proto::UpdateChannelBufferCollaborators {
5334 channel_id: left_buffer.channel_id.to_proto(),
5335 collaborators: left_buffer.collaborators,
5336 },
5337 &session.peer,
5338 );
5339 }
5340
5341 Ok(())
5342}
5343
5344fn project_left(project: &db::LeftProject, session: &UserSession) {
5345 for connection_id in &project.connection_ids {
5346 if project.should_unshare {
5347 session
5348 .peer
5349 .send(
5350 *connection_id,
5351 proto::UnshareProject {
5352 project_id: project.id.to_proto(),
5353 },
5354 )
5355 .trace_err();
5356 } else {
5357 session
5358 .peer
5359 .send(
5360 *connection_id,
5361 proto::RemoveProjectCollaborator {
5362 project_id: project.id.to_proto(),
5363 peer_id: Some(session.connection_id.into()),
5364 },
5365 )
5366 .trace_err();
5367 }
5368 }
5369}
5370
/// Extension trait for turning a fallible value into an `Option`.
pub trait ResultExt {
    /// The success type produced by [`ResultExt::trace_err`].
    type Ok;

    /// Consumes `self` and returns the success value, if there is one.
    fn trace_err(self) -> Option<Self::Ok>;
}
5376
5377impl<T, E> ResultExt for Result<T, E>
5378where
5379 E: std::fmt::Debug,
5380{
5381 type Ok = T;
5382
5383 #[track_caller]
5384 fn trace_err(self) -> Option<T> {
5385 match self {
5386 Ok(value) => Some(value),
5387 Err(error) => {
5388 tracing::error!("{:?}", error);
5389 None
5390 }
5391 }
5392 }
5393}