1mod connection_pool;
2
3use crate::api::CloudflareIpCountryHeader;
4use crate::llm::LlmTokenClaims;
5use crate::{
6 auth,
7 db::{
8 self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
9 CreatedChannelMessage, Database, DevServerId, DevServerProjectId, InviteMemberResult,
10 MembershipUpdated, MessageId, NotificationId, PrincipalId, Project, ProjectId,
11 RejoinedProject, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
12 ServerId, UpdatedChannelMessage, User, UserId,
13 },
14 executor::Executor,
15 AppState, Config, Error, RateLimit, Result,
16};
17use anyhow::{anyhow, bail, Context as _};
18use async_tungstenite::tungstenite::{
19 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
20};
21use axum::{
22 body::Body,
23 extract::{
24 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
25 ConnectInfo, WebSocketUpgrade,
26 },
27 headers::{Header, HeaderName},
28 http::StatusCode,
29 middleware,
30 response::IntoResponse,
31 routing::get,
32 Extension, Router, TypedHeader,
33};
34use chrono::Utc;
35use collections::{HashMap, HashSet};
36pub use connection_pool::{ConnectionPool, ZedVersion};
37use core::fmt::{self, Debug, Formatter};
38use http_client::HttpClient;
39use isahc_http_client::IsahcHttpClient;
40use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
41use sha2::Digest;
42use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
43
44use futures::{
45 channel::oneshot,
46 future::{self, BoxFuture},
47 stream::FuturesUnordered,
48 FutureExt, SinkExt, StreamExt, TryStreamExt,
49};
50use prometheus::{register_int_gauge, IntGauge};
51use rpc::{
52 proto::{
53 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
54 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
55 },
56 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
57};
58use semantic_version::SemanticVersion;
59use serde::{Serialize, Serializer};
60use std::{
61 any::TypeId,
62 future::Future,
63 marker::PhantomData,
64 mem,
65 net::SocketAddr,
66 ops::{Deref, DerefMut},
67 rc::Rc,
68 sync::{
69 atomic::{AtomicBool, Ordering::SeqCst},
70 Arc, OnceLock,
71 },
72 time::{Duration, Instant},
73};
74use time::OffsetDateTime;
75use tokio::sync::{watch, MutexGuard, Semaphore};
76use tower::ServiceBuilder;
77use tracing::{
78 field::{self},
79 info_span, instrument, Instrument,
80};
81
82pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
83
84// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
85pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
86
87const MESSAGE_COUNT_PER_PAGE: usize = 100;
88const MAX_MESSAGE_LEN: usize = 1024;
89const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
90
91type MessageHandler =
92 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
93
94struct Response<R> {
95 peer: Arc<Peer>,
96 receipt: Receipt<R>,
97 responded: Arc<AtomicBool>,
98}
99
100impl<R: RequestMessage> Response<R> {
101 fn send(self, payload: R::Response) -> Result<()> {
102 self.responded.store(true, SeqCst);
103 self.peer.respond(self.receipt, payload)?;
104 Ok(())
105 }
106}
107
108#[derive(Clone, Debug)]
109pub enum Principal {
110 User(User),
111 Impersonated { user: User, admin: User },
112 DevServer(dev_server::Model),
113}
114
115impl Principal {
116 fn update_span(&self, span: &tracing::Span) {
117 match &self {
118 Principal::User(user) => {
119 span.record("user_id", user.id.0);
120 span.record("login", &user.github_login);
121 }
122 Principal::Impersonated { user, admin } => {
123 span.record("user_id", user.id.0);
124 span.record("login", &user.github_login);
125 span.record("impersonator", &admin.github_login);
126 }
127 Principal::DevServer(dev_server) => {
128 span.record("dev_server_id", dev_server.id.0);
129 }
130 }
131 }
132}
133
/// Per-connection state handed to every message handler.
///
/// A clone is passed to each handler invocation (see `handle_connection`). The shared parts
/// (database, peer, connection pool, app state, clients) live behind `Arc`s.
#[derive(Clone)]
struct Session {
    /// Who is on the other end of this connection.
    principal: Principal,
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; acquire it through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Shared with the `Server`; acquire it through `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Present only when a Supermaven admin API key is configured.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    http_client: Arc<dyn HttpClient>,
    /// The GeoIP country code for the user.
    // NOTE(review): `allow(unused)` presumably because nothing reads this field yet — confirm.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    _executor: Executor,
}
149
150impl Session {
151 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
152 #[cfg(test)]
153 tokio::task::yield_now().await;
154 let guard = self.db.lock().await;
155 #[cfg(test)]
156 tokio::task::yield_now().await;
157 guard
158 }
159
160 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
161 #[cfg(test)]
162 tokio::task::yield_now().await;
163 let guard = self.connection_pool.lock();
164 ConnectionPoolGuard {
165 guard,
166 _not_send: PhantomData,
167 }
168 }
169
170 fn for_user(self) -> Option<UserSession> {
171 UserSession::new(self)
172 }
173
174 fn for_dev_server(self) -> Option<DevServerSession> {
175 DevServerSession::new(self)
176 }
177
178 fn user_id(&self) -> Option<UserId> {
179 match &self.principal {
180 Principal::User(user) => Some(user.id),
181 Principal::Impersonated { user, .. } => Some(user.id),
182 Principal::DevServer(_) => None,
183 }
184 }
185
186 fn is_staff(&self) -> bool {
187 match &self.principal {
188 Principal::User(user) => user.admin,
189 Principal::Impersonated { .. } => true,
190 Principal::DevServer(_) => false,
191 }
192 }
193
194 pub async fn current_plan(&self, db: MutexGuard<'_, DbHandle>) -> anyhow::Result<proto::Plan> {
195 if self.is_staff() {
196 return Ok(proto::Plan::ZedPro);
197 }
198
199 let Some(user_id) = self.user_id() else {
200 return Ok(proto::Plan::Free);
201 };
202
203 if db.has_active_billing_subscription(user_id).await? {
204 Ok(proto::Plan::ZedPro)
205 } else {
206 Ok(proto::Plan::Free)
207 }
208 }
209
210 fn dev_server_id(&self) -> Option<DevServerId> {
211 match &self.principal {
212 Principal::User(_) | Principal::Impersonated { .. } => None,
213 Principal::DevServer(dev_server) => Some(dev_server.id),
214 }
215 }
216
217 fn principal_id(&self) -> PrincipalId {
218 match &self.principal {
219 Principal::User(user) => PrincipalId::UserId(user.id),
220 Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
221 Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
222 }
223 }
224}
225
226impl Debug for Session {
227 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
228 let mut result = f.debug_struct("Session");
229 match &self.principal {
230 Principal::User(user) => {
231 result.field("user", &user.github_login);
232 }
233 Principal::Impersonated { user, admin } => {
234 result.field("user", &user.github_login);
235 result.field("impersonator", &admin.github_login);
236 }
237 Principal::DevServer(dev_server) => {
238 result.field("dev_server", &dev_server.id);
239 }
240 }
241 result.field("connection_id", &self.connection_id).finish()
242 }
243}
244
245struct UserSession(Session);
246
247impl UserSession {
248 pub fn new(s: Session) -> Option<Self> {
249 s.user_id().map(|_| UserSession(s))
250 }
251 pub fn user_id(&self) -> UserId {
252 self.0.user_id().unwrap()
253 }
254
255 pub fn email(&self) -> Option<String> {
256 match &self.0.principal {
257 Principal::User(user) => user.email_address.clone(),
258 Principal::Impersonated { user, .. } => user.email_address.clone(),
259 Principal::DevServer(..) => None,
260 }
261 }
262}
263
264impl Deref for UserSession {
265 type Target = Session;
266
267 fn deref(&self) -> &Self::Target {
268 &self.0
269 }
270}
271impl DerefMut for UserSession {
272 fn deref_mut(&mut self) -> &mut Self::Target {
273 &mut self.0
274 }
275}
276
277struct DevServerSession(Session);
278
279impl DevServerSession {
280 pub fn new(s: Session) -> Option<Self> {
281 s.dev_server_id().map(|_| DevServerSession(s))
282 }
283 pub fn dev_server_id(&self) -> DevServerId {
284 self.0.dev_server_id().unwrap()
285 }
286
287 fn dev_server(&self) -> &dev_server::Model {
288 match &self.0.principal {
289 Principal::DevServer(dev_server) => dev_server,
290 _ => unreachable!(),
291 }
292 }
293}
294
295impl Deref for DevServerSession {
296 type Target = Session;
297
298 fn deref(&self) -> &Self::Target {
299 &self.0
300 }
301}
302impl DerefMut for DevServerSession {
303 fn deref_mut(&mut self) -> &mut Self::Target {
304 &mut self.0
305 }
306}
307
308fn user_handler<M: RequestMessage, Fut>(
309 handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
310) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
311where
312 Fut: Send + Future<Output = Result<()>>,
313{
314 let handler = Arc::new(handler);
315 move |message, response, session| {
316 let handler = handler.clone();
317 Box::pin(async move {
318 if let Some(user_session) = session.for_user() {
319 Ok(handler(message, response, user_session).await?)
320 } else {
321 Err(Error::Internal(anyhow!(
322 "must be a user to call {}",
323 M::NAME
324 )))
325 }
326 })
327 }
328}
329
330fn dev_server_handler<M: RequestMessage, Fut>(
331 handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
332) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
333where
334 Fut: Send + Future<Output = Result<()>>,
335{
336 let handler = Arc::new(handler);
337 move |message, response, session| {
338 let handler = handler.clone();
339 Box::pin(async move {
340 if let Some(dev_server_session) = session.for_dev_server() {
341 Ok(handler(message, response, dev_server_session).await?)
342 } else {
343 Err(Error::Internal(anyhow!(
344 "must be a dev server to call {}",
345 M::NAME
346 )))
347 }
348 })
349 }
350}
351
352fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
353 handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
354) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
355where
356 InnertRetFut: Send + Future<Output = Result<()>>,
357{
358 let handler = Arc::new(handler);
359 move |message, session| {
360 let handler = handler.clone();
361 Box::pin(async move {
362 if let Some(user_session) = session.for_user() {
363 Ok(handler(message, user_session).await?)
364 } else {
365 Err(Error::Internal(anyhow!(
366 "must be a user to call {}",
367 M::NAME
368 )))
369 }
370 })
371 }
372}
373
374struct DbHandle(Arc<Database>);
375
376impl Deref for DbHandle {
377 type Target = Database;
378
379 fn deref(&self) -> &Self::Target {
380 self.0.as_ref()
381 }
382}
383
/// The collaboration RPC server: owns the peer, the shared connection pool, and the table
/// of message handlers.
pub struct Server {
    /// This server's id. It sits behind a mutex so it can be changed through `&self` (done by
    /// the test-only `reset`).
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Type-erased handlers keyed by the `TypeId` of the message type they handle.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Teardown flag. Connection tasks subscribe to it and stop when it changes.
    teardown: watch::Sender<bool>,
}
392
/// Guard over the locked shared `ConnectionPool`.
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send`. A future that holds it across
/// an `.await` is therefore itself `!Send`, which keeps the synchronous lock from being held
/// across suspension points in `Send` futures.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
397
/// Serializable snapshot of the server's peer and connection pool.
///
/// The pool is serialized through the guard's `Deref` target (via `serialize_deref`). The
/// pool stays locked for as long as the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
404
405pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
406where
407 S: Serializer,
408 T: Deref<Target = U>,
409 U: Serialize,
410{
411 Serialize::serialize(value.deref(), serializer)
412}
413
414impl Server {
    /// Creates a server with the given id and registers every RPC message handler.
    ///
    /// Handlers wrapped in `user_handler` / `user_message_handler` reject non-user sessions.
    /// Those wrapped in `dev_server_handler` reject non-dev-server sessions. Unwrapped
    /// handlers receive the raw `Session`. Registering two handlers for the same message
    /// type panics (see `add_handler`).
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            // NOTE(review): if `ServerId` wraps an integer wider than u32, this `as` cast
            // silently truncates — confirm the id range.
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            // Only the sender is kept; each connection creates its own receiver via
            // `subscribe` in `handle_connection`.
            teardown: watch::channel(false).0,
        };

        server
            .add_request_handler(ping)
            .add_request_handler(user_handler(create_room))
            .add_request_handler(user_handler(join_room))
            .add_request_handler(user_handler(rejoin_room))
            .add_request_handler(user_handler(leave_room))
            .add_request_handler(user_handler(set_room_participant_role))
            .add_request_handler(user_handler(call))
            .add_request_handler(user_handler(cancel_call))
            .add_message_handler(user_message_handler(decline_call))
            .add_request_handler(user_handler(update_participant_location))
            .add_request_handler(user_handler(share_project))
            .add_message_handler(unshare_project)
            .add_request_handler(user_handler(join_project))
            .add_request_handler(user_handler(join_hosted_project))
            .add_request_handler(user_handler(rejoin_dev_server_projects))
            .add_request_handler(user_handler(create_dev_server_project))
            .add_request_handler(user_handler(update_dev_server_project))
            .add_request_handler(user_handler(delete_dev_server_project))
            .add_request_handler(user_handler(create_dev_server))
            .add_request_handler(user_handler(regenerate_dev_server_token))
            .add_request_handler(user_handler(rename_dev_server))
            .add_request_handler(user_handler(delete_dev_server))
            .add_request_handler(user_handler(list_remote_directory))
            .add_request_handler(dev_server_handler(share_dev_server_project))
            .add_request_handler(dev_server_handler(shutdown_dev_server))
            .add_request_handler(dev_server_handler(reconnect_dev_server))
            .add_message_handler(user_message_handler(leave_project))
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            .add_request_handler(user_handler(
                forward_project_request_for_owner::<proto::TaskContextForLocation>,
            ))
            .add_request_handler(user_handler(
                forward_project_request_for_owner::<proto::TaskTemplates>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetHover>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetTypeDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetReferences>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::SearchProject>,
            ))
            .add_request_handler(user_handler(forward_find_search_candidates_request))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDocumentHighlights>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetProjectSymbols>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferForSymbol>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferById>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::SynchronizeBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::InlayHints>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::ResolveInlayHint>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferByPath>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCompletions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OpenNewBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCodeActions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCodeAction>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PrepareRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PerformRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ReloadBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::FormatBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CreateProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RenameProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CopyProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::DeleteProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ExpandProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OnTypeFormatting>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::SaveBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::BlameBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::MultiLspQuery>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RestartLanguageServers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::LinkedEditingRange>,
            ))
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            .add_request_handler(get_users)
            .add_request_handler(user_handler(fuzzy_search_users))
            .add_request_handler(user_handler(request_contact))
            .add_request_handler(user_handler(remove_contact))
            .add_request_handler(user_handler(respond_to_contact_request))
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(user_handler(create_channel))
            .add_request_handler(user_handler(delete_channel))
            .add_request_handler(user_handler(invite_channel_member))
            .add_request_handler(user_handler(remove_channel_member))
            .add_request_handler(user_handler(set_channel_member_role))
            .add_request_handler(user_handler(set_channel_visibility))
            .add_request_handler(user_handler(rename_channel))
            .add_request_handler(user_handler(join_channel_buffer))
            .add_request_handler(user_handler(leave_channel_buffer))
            .add_message_handler(user_message_handler(update_channel_buffer))
            .add_request_handler(user_handler(rejoin_channel_buffers))
            .add_request_handler(user_handler(get_channel_members))
            .add_request_handler(user_handler(respond_to_channel_invite))
            .add_request_handler(user_handler(join_channel))
            .add_request_handler(user_handler(join_channel_chat))
            .add_message_handler(user_message_handler(leave_channel_chat))
            .add_request_handler(user_handler(send_channel_message))
            .add_request_handler(user_handler(remove_channel_message))
            .add_request_handler(user_handler(update_channel_message))
            .add_request_handler(user_handler(get_channel_messages))
            .add_request_handler(user_handler(get_channel_messages_by_id))
            .add_request_handler(user_handler(get_notifications))
            .add_request_handler(user_handler(mark_notification_as_read))
            .add_request_handler(user_handler(move_channel))
            .add_request_handler(user_handler(follow))
            .add_message_handler(user_message_handler(unfollow))
            .add_message_handler(user_message_handler(update_followers))
            .add_request_handler(user_handler(get_private_user_info))
            .add_request_handler(user_handler(get_llm_api_token))
            .add_request_handler(user_handler(accept_terms_of_service))
            .add_message_handler(user_message_handler(acknowledge_channel_message))
            .add_message_handler(user_message_handler(acknowledge_buffer_version))
            .add_request_handler(user_handler(get_supermaven_api_key))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OpenContext>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CreateContext>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::SynchronizeContexts>,
            ))
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context)
            // Not wrapped in `user_handler`: this handler gets the raw `Session` plus the
            // server config captured from `app_state`.
            .add_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    let app_state = app_state.clone();
                    async move {
                        count_language_model_tokens(request, response, session, &app_state.config)
                            .await
                    }
                }
            })
            .add_request_handler({
                user_handler(move |request, response, session| {
                    get_cached_embeddings(request, response, session)
                })
            })
            // The OpenAI API key is read from the config and cloned on every request.
            .add_request_handler({
                let app_state = app_state.clone();
                user_handler(move |request, response, session| {
                    compute_embeddings(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                    )
                })
            });

        Arc::new(server)
    }
654
    /// Spawns a detached background task that cleans up stale resources, then returns
    /// `Ok(())` immediately.
    ///
    /// The task sleeps for `CLEANUP_TIMEOUT` first. It then asks the database for stale room
    /// and channel-buffer ids in this environment.
    // NOTE(review): presumably resources owned by servers other than `server_id` — confirm
    // against `stale_server_resource_ids`.
    /// Next it clears stale collaborators and participants, notifies affected connections
    /// and contacts, and deletes LiveKit rooms left with no participants. Finally it deletes
    /// the stale server records. Each database or send failure is logged via `trace_err` and
    /// does not abort the task.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Channel buffers: drop stale collaborators and send the refreshed
                    // collaborator list to every remaining connection.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Defaults are used when clearing the room fails: nothing to
                        // notify and no LiveKit room to delete.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scope the pool lock to the synchronous sends below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // For each affected user, push their updated contact entry to every
                        // accepted contact's connections. The pool lock is taken only after
                        // the database awaits.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
800
801 pub fn teardown(&self) {
802 self.peer.teardown();
803 self.connection_pool.lock().reset();
804 let _ = self.teardown.send(true);
805 }
806
807 #[cfg(test)]
808 pub fn reset(&self, id: ServerId) {
809 self.teardown();
810 *self.id.lock() = id;
811 self.peer.reset(id.0 as u32);
812 let _ = self.teardown.send(false);
813 }
814
815 #[cfg(test)]
816 pub fn id(&self) -> ServerId {
817 *self.id.lock()
818 }
819
820 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
821 where
822 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
823 Fut: 'static + Send + Future<Output = Result<()>>,
824 M: EnvelopedMessage,
825 {
826 let prev_handler = self.handlers.insert(
827 TypeId::of::<M>(),
828 Box::new(move |envelope, session| {
829 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
830 let received_at = envelope.received_at;
831 tracing::info!("message received");
832 let start_time = Instant::now();
833 let future = (handler)(*envelope, session);
834 async move {
835 let result = future.await;
836 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
837 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
838 let queue_duration_ms = total_duration_ms - processing_duration_ms;
839 let payload_type = M::NAME;
840
841 match result {
842 Err(error) => {
843 tracing::error!(
844 ?error,
845 total_duration_ms,
846 processing_duration_ms,
847 queue_duration_ms,
848 payload_type,
849 "error handling message"
850 )
851 }
852 Ok(()) => tracing::info!(
853 total_duration_ms,
854 processing_duration_ms,
855 queue_duration_ms,
856 "finished handling message"
857 ),
858 }
859 }
860 .boxed()
861 }),
862 );
863 if prev_handler.is_some() {
864 panic!("registered a handler for the same message twice");
865 }
866 self
867 }
868
869 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
870 where
871 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
872 Fut: 'static + Send + Future<Output = Result<()>>,
873 M: EnvelopedMessage,
874 {
875 self.add_handler(move |envelope, session| handler(envelope.payload, session));
876 self
877 }
878
879 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
880 where
881 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
882 Fut: Send + Future<Output = Result<()>>,
883 M: RequestMessage,
884 {
885 let handler = Arc::new(handler);
886 self.add_handler(move |envelope, session| {
887 let receipt = envelope.receipt();
888 let handler = handler.clone();
889 async move {
890 let peer = session.peer.clone();
891 let responded = Arc::new(AtomicBool::default());
892 let response = Response {
893 peer: peer.clone(),
894 responded: responded.clone(),
895 receipt,
896 };
897 match (handler)(envelope.payload, response, session).await {
898 Ok(()) => {
899 if responded.load(std::sync::atomic::Ordering::SeqCst) {
900 Ok(())
901 } else {
902 Err(anyhow!("handler did not send a response"))?
903 }
904 }
905 Err(error) => {
906 let proto_err = match &error {
907 Error::Internal(err) => err.to_proto(),
908 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
909 };
910 peer.respond_with_error(receipt, proto_err)?;
911 Err(error)
912 }
913 }
914 }
915 })
916 }
917
918 #[allow(clippy::too_many_arguments)]
919 pub fn handle_connection(
920 self: &Arc<Self>,
921 connection: Connection,
922 address: String,
923 principal: Principal,
924 zed_version: ZedVersion,
925 geoip_country_code: Option<String>,
926 send_connection_id: Option<oneshot::Sender<ConnectionId>>,
927 executor: Executor,
928 ) -> impl Future<Output = ()> {
929 let this = self.clone();
930 let span = info_span!("handle connection", %address,
931 connection_id=field::Empty,
932 user_id=field::Empty,
933 login=field::Empty,
934 impersonator=field::Empty,
935 dev_server_id=field::Empty,
936 geoip_country_code=field::Empty
937 );
938 principal.update_span(&span);
939 if let Some(country_code) = geoip_country_code.as_ref() {
940 span.record("geoip_country_code", country_code);
941 }
942
943 let mut teardown = self.teardown.subscribe();
944 async move {
945 if *teardown.borrow() {
946 tracing::error!("server is tearing down");
947 return
948 }
949 let (connection_id, handle_io, mut incoming_rx) = this
950 .peer
951 .add_connection(connection, {
952 let executor = executor.clone();
953 move |duration| executor.sleep(duration)
954 });
955 tracing::Span::current().record("connection_id", format!("{}", connection_id));
956
957 tracing::info!("connection opened");
958
959 let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
960 let http_client = match IsahcHttpClient::builder().default_header("User-Agent", user_agent).build() {
961 Ok(http_client) => Arc::new(IsahcHttpClient::from(http_client)),
962 Err(error) => {
963 tracing::error!(?error, "failed to create HTTP client");
964 return;
965 }
966 };
967
968 let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
969 supermaven_admin_api_key.to_string(),
970 http_client.clone(),
971 )));
972
973 let session = Session {
974 principal: principal.clone(),
975 connection_id,
976 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
977 peer: this.peer.clone(),
978 connection_pool: this.connection_pool.clone(),
979 app_state: this.app_state.clone(),
980 http_client,
981 geoip_country_code,
982 _executor: executor.clone(),
983 supermaven_client,
984 };
985
986 if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
987 tracing::error!(?error, "failed to send initial client update");
988 return;
989 }
990
991 let handle_io = handle_io.fuse();
992 futures::pin_mut!(handle_io);
993
994 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
995 // This prevents deadlocks when e.g., client A performs a request to client B and
996 // client B performs a request to client A. If both clients stop processing further
997 // messages until their respective request completes, they won't have a chance to
998 // respond to the other client's request and cause a deadlock.
999 //
1000 // This arrangement ensures we will attempt to process earlier messages first, but fall
1001 // back to processing messages arrived later in the spirit of making progress.
1002 let mut foreground_message_handlers = FuturesUnordered::new();
1003 let concurrent_handlers = Arc::new(Semaphore::new(256));
1004 loop {
1005 let next_message = async {
1006 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
1007 let message = incoming_rx.next().await;
1008 (permit, message)
1009 }.fuse();
1010 futures::pin_mut!(next_message);
1011 futures::select_biased! {
1012 _ = teardown.changed().fuse() => return,
1013 result = handle_io => {
1014 if let Err(error) = result {
1015 tracing::error!(?error, "error handling I/O");
1016 }
1017 break;
1018 }
1019 _ = foreground_message_handlers.next() => {}
1020 next_message = next_message => {
1021 let (permit, message) = next_message;
1022 if let Some(message) = message {
1023 let type_name = message.payload_type_name();
1024 // note: we copy all the fields from the parent span so we can query them in the logs.
1025 // (https://github.com/tokio-rs/tracing/issues/2670).
1026 let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
1027 user_id=field::Empty,
1028 login=field::Empty,
1029 impersonator=field::Empty,
1030 dev_server_id=field::Empty
1031 );
1032 principal.update_span(&span);
1033 let span_enter = span.enter();
1034 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
1035 let is_background = message.is_background();
1036 let handle_message = (handler)(message, session.clone());
1037 drop(span_enter);
1038
1039 let handle_message = async move {
1040 handle_message.await;
1041 drop(permit);
1042 }.instrument(span);
1043 if is_background {
1044 executor.spawn_detached(handle_message);
1045 } else {
1046 foreground_message_handlers.push(handle_message);
1047 }
1048 } else {
1049 tracing::error!("no message handler");
1050 }
1051 } else {
1052 tracing::info!("connection closed");
1053 break;
1054 }
1055 }
1056 }
1057 }
1058
1059 drop(foreground_message_handlers);
1060 tracing::info!("signing out");
1061 if let Err(error) = connection_lost(session, teardown, executor).await {
1062 tracing::error!(?error, "error signing out");
1063 }
1064
1065 }.instrument(span)
1066 }
1067
    /// Performs the post-handshake setup for a newly accepted connection.
    ///
    /// Always sends `proto::Hello` (carrying the new peer id) first. It then reports
    /// `connection_id` through `send_connection_id` when a sender was supplied;
    /// `handle_websocket_request` passes `None`.
    ///
    /// The remaining setup depends on the principal:
    ///
    /// - Users (including impersonated users):
    ///   - on the first-ever connection, sends `ShowContacts` and persists `connected_once`;
    ///   - refreshes the user's plan;
    ///   - registers the connection in the pool and sends the initial contacts update;
    ///   - subscribes the user to channels when `should_auto_subscribe_to_channels(zed_version)`;
    ///   - sends dev server project status and any pending incoming call;
    ///   - finally calls `update_user_contacts`.
    /// - Dev servers:
    ///   - if the pool already has a connection for the same dev server, sends it
    ///     `ShutdownDevServer` and removes it from the pool;
    ///   - registers this connection;
    ///   - sends the dev server its `DevServerInstructions`;
    ///   - pushes the owning user's dev server project status.
    ///
    /// # Errors
    /// Returns the first send or database error. The connection handler logs it and
    /// stops serving the connection.
    async fn send_initial_client_update(
        &self,
        connection_id: ConnectionId,
        principal: &Principal,
        zed_version: ZedVersion,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        session: &Session,
    ) -> Result<()> {
        self.peer.send(
            connection_id,
            proto::Hello {
                peer_id: Some(connection_id.into()),
            },
        )?;
        tracing::info!("sent hello message");
        // The receiver may already be gone; that is not an error for the connection.
        if let Some(send_connection_id) = send_connection_id.take() {
            let _ = send_connection_id.send(connection_id);
        }

        match principal {
            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
                if !user.connected_once {
                    self.peer.send(connection_id, proto::ShowContacts {})?;
                    self.app_state
                        .db
                        .set_user_connected_once(user.id, true)
                        .await?;
                }

                update_user_plan(user.id, session).await?;

                // Both queries only read; run them concurrently.
                let (contacts, dev_server_projects) = future::try_join(
                    self.app_state.db.get_contacts(user.id),
                    self.app_state.db.dev_server_projects_update(user.id),
                )
                .await?;

                // The pool lock lives only in this scope, so it is released before the
                // awaits below. The initial contacts update is built after this
                // connection has been registered in the pool.
                {
                    let mut pool = self.connection_pool.lock();
                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
                    self.peer.send(
                        connection_id,
                        build_initial_contacts_update(contacts, &pool),
                    )?;
                }

                if should_auto_subscribe_to_channels(zed_version) {
                    subscribe_user_to_channels(user.id, session).await?;
                }

                send_dev_server_projects_update(user.id, dev_server_projects, session).await;

                if let Some(incoming_call) =
                    self.app_state.db.incoming_call_for_user(user.id).await?
                {
                    self.peer.send(connection_id, incoming_call)?;
                }

                update_user_contacts(user.id, session).await?;
            }
            Principal::DevServer(dev_server) => {
                {
                    let mut pool = self.connection_pool.lock();
                    // Only one connection per dev server is kept: evict the previous one.
                    // NOTE(review): a failed send to the stale connection aborts setup of
                    // this new connection via `?`; confirm that is intended.
                    if let Some(stale_connection_id) = pool.dev_server_connection_id(dev_server.id)
                    {
                        self.peer.send(
                            stale_connection_id,
                            proto::ShutdownDevServer {
                                reason: Some(
                                    "another dev server connected with the same token".to_string(),
                                ),
                            },
                        )?;
                        pool.remove_connection(stale_connection_id)?;
                    };
                    pool.add_dev_server(connection_id, dev_server.id, zed_version);
                }

                let projects = self
                    .app_state
                    .db
                    .get_projects_for_dev_server(dev_server.id)
                    .await?;
                self.peer
                    .send(connection_id, proto::DevServerInstructions { projects })?;

                // Status is reported to the user who owns this dev server.
                let status = self
                    .app_state
                    .db
                    .dev_server_projects_update(dev_server.user_id)
                    .await?;
                send_dev_server_projects_update(dev_server.user_id, status, session).await;
            }
        }

        Ok(())
    }
1165
1166 pub async fn invite_code_redeemed(
1167 self: &Arc<Self>,
1168 inviter_id: UserId,
1169 invitee_id: UserId,
1170 ) -> Result<()> {
1171 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1172 if let Some(code) = &user.invite_code {
1173 let pool = self.connection_pool.lock();
1174 let invitee_contact = contact_for_user(invitee_id, false, &pool);
1175 for connection_id in pool.user_connection_ids(inviter_id) {
1176 self.peer.send(
1177 connection_id,
1178 proto::UpdateContacts {
1179 contacts: vec![invitee_contact.clone()],
1180 ..Default::default()
1181 },
1182 )?;
1183 self.peer.send(
1184 connection_id,
1185 proto::UpdateInviteInfo {
1186 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1187 count: user.invite_count as u32,
1188 },
1189 )?;
1190 }
1191 }
1192 }
1193 Ok(())
1194 }
1195
1196 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1197 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1198 if let Some(invite_code) = &user.invite_code {
1199 let pool = self.connection_pool.lock();
1200 for connection_id in pool.user_connection_ids(user_id) {
1201 self.peer.send(
1202 connection_id,
1203 proto::UpdateInviteInfo {
1204 url: format!(
1205 "{}{}",
1206 self.app_state.config.invite_link_prefix, invite_code
1207 ),
1208 count: user.invite_count as u32,
1209 },
1210 )?;
1211 }
1212 }
1213 }
1214 Ok(())
1215 }
1216
1217 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1218 ServerSnapshot {
1219 connection_pool: ConnectionPoolGuard {
1220 guard: self.connection_pool.lock(),
1221 _not_send: PhantomData,
1222 },
1223 peer: &self.peer,
1224 }
1225 }
1226}
1227
1228impl<'a> Deref for ConnectionPoolGuard<'a> {
1229 type Target = ConnectionPool;
1230
1231 fn deref(&self) -> &Self::Target {
1232 &self.guard
1233 }
1234}
1235
1236impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1237 fn deref_mut(&mut self) -> &mut Self::Target {
1238 &mut self.guard
1239 }
1240}
1241
/// Releasing the guard releases the connection pool lock; its `guard` field is
/// dropped after `drop` returns.
///
/// In test builds, `check_invariants` first runs on the still-locked pool. This
/// presumably ensures every mutation made through a guard is validated as soon as
/// the guard goes out of scope.
impl<'a> Drop for ConnectionPoolGuard<'a> {
    fn drop(&mut self) {
        // Compiled only under `cfg(test)`; non-test builds skip the check entirely.
        #[cfg(test)]
        self.check_invariants();
    }
}
1248
1249fn broadcast<F>(
1250 sender_id: Option<ConnectionId>,
1251 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1252 mut f: F,
1253) where
1254 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1255{
1256 for receiver_id in receiver_ids {
1257 if Some(receiver_id) != sender_id {
1258 if let Err(error) = f(receiver_id) {
1259 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1260 }
1261 }
1262 }
1263}
1264
1265pub struct ProtocolVersion(u32);
1266
1267impl Header for ProtocolVersion {
1268 fn name() -> &'static HeaderName {
1269 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1270 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1271 }
1272
1273 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1274 where
1275 Self: Sized,
1276 I: Iterator<Item = &'i axum::http::HeaderValue>,
1277 {
1278 let version = values
1279 .next()
1280 .ok_or_else(axum::headers::Error::invalid)?
1281 .to_str()
1282 .map_err(|_| axum::headers::Error::invalid())?
1283 .parse()
1284 .map_err(|_| axum::headers::Error::invalid())?;
1285 Ok(Self(version))
1286 }
1287
1288 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1289 values.extend([self.0.to_string().parse().unwrap()]);
1290 }
1291}
1292
1293pub struct AppVersionHeader(SemanticVersion);
1294impl Header for AppVersionHeader {
1295 fn name() -> &'static HeaderName {
1296 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1297 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1298 }
1299
1300 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1301 where
1302 Self: Sized,
1303 I: Iterator<Item = &'i axum::http::HeaderValue>,
1304 {
1305 let version = values
1306 .next()
1307 .ok_or_else(axum::headers::Error::invalid)?
1308 .to_str()
1309 .map_err(|_| axum::headers::Error::invalid())?
1310 .parse()
1311 .map_err(|_| axum::headers::Error::invalid())?;
1312 Ok(Self(version))
1313 }
1314
1315 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1316 values.extend([self.0.to_string().parse().unwrap()]);
1317 }
1318}
1319
1320pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1321 Router::new()
1322 .route("/rpc", get(handle_websocket_request))
1323 .layer(
1324 ServiceBuilder::new()
1325 .layer(Extension(server.app_state.clone()))
1326 .layer(middleware::from_fn(auth::validate_header)),
1327 )
1328 .route("/metrics", get(handle_metrics))
1329 .layer(Extension(server))
1330}
1331
1332pub async fn handle_websocket_request(
1333 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1334 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1335 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1336 Extension(server): Extension<Arc<Server>>,
1337 Extension(principal): Extension<Principal>,
1338 country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1339 ws: WebSocketUpgrade,
1340) -> axum::response::Response {
1341 if protocol_version != rpc::PROTOCOL_VERSION {
1342 return (
1343 StatusCode::UPGRADE_REQUIRED,
1344 "client must be upgraded".to_string(),
1345 )
1346 .into_response();
1347 }
1348
1349 let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1350 return (
1351 StatusCode::UPGRADE_REQUIRED,
1352 "no version header found".to_string(),
1353 )
1354 .into_response();
1355 };
1356
1357 if !version.can_collaborate() {
1358 return (
1359 StatusCode::UPGRADE_REQUIRED,
1360 "client must be upgraded".to_string(),
1361 )
1362 .into_response();
1363 }
1364
1365 let socket_address = socket_address.to_string();
1366 ws.on_upgrade(move |socket| {
1367 let socket = socket
1368 .map_ok(to_tungstenite_message)
1369 .err_into()
1370 .with(|message| async move { to_axum_message(message) });
1371 let connection = Connection::new(Box::pin(socket));
1372 async move {
1373 server
1374 .handle_connection(
1375 connection,
1376 socket_address,
1377 principal,
1378 version,
1379 country_code_header.map(|header| header.to_string()),
1380 None,
1381 Executor::Production,
1382 )
1383 .await;
1384 }
1385 })
1386}
1387
1388pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1389 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1390 let connections_metric = CONNECTIONS_METRIC
1391 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1392
1393 let connections = server
1394 .connection_pool
1395 .lock()
1396 .connections()
1397 .filter(|connection| !connection.admin)
1398 .count();
1399 connections_metric.set(connections as _);
1400
1401 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1402 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1403 register_int_gauge!(
1404 "shared_projects",
1405 "number of open projects with one or more guests"
1406 )
1407 .unwrap()
1408 });
1409
1410 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1411 shared_projects_metric.set(shared_projects as _);
1412
1413 let encoder = prometheus::TextEncoder::new();
1414 let metric_families = prometheus::gather();
1415 let encoded_metrics = encoder
1416 .encode_to_string(&metric_families)
1417 .map_err(|err| anyhow!("{}", err))?;
1418 Ok(encoded_metrics)
1419}
1420
/// Cleans up server-side state after a connection closes.
///
/// Immediately:
/// - disconnects the peer;
/// - removes the connection from the pool (an error here is returned);
/// - records the lost connection in the database (an error here is only logged).
///
/// It then waits `RECONNECT_TIMEOUT`. The delay presumably gives the client a
/// chance to reconnect first (TODO confirm). If the server starts tearing down
/// before the timer fires, it returns without further cleanup. Otherwise it
/// releases the principal's remaining resources:
/// - users: leaves the room and the channel buffers for this connection, declines
///   a pending call when the user has no other online connection, and calls
///   `update_user_contacts`;
/// - dev servers: calls `lost_dev_server_connection`.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // Best-effort: a database failure is logged so the remaining cleanup still runs.
    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls the timer branch before the teardown branch.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            match &session.principal {
                Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
                    // The principal was just matched as a user, so this is expected to succeed.
                    let session = session.for_user().unwrap();

                    log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
                    leave_room_for_session(&session, session.connection_id).await.trace_err();
                    leave_channel_buffers_for_session(&session)
                        .await
                        .trace_err();

                    // Only decline a pending call if no other connection of this user is online.
                    if !session
                        .connection_pool()
                        .await
                        .is_user_online(session.user_id())
                    {
                        let db = session.db().await;
                        if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                            room_updated(&room, &session.peer);
                        }
                    }

                    update_user_contacts(session.user_id(), &session).await?;
                },
                Principal::DevServer(_) => {
                    lost_dev_server_connection(&session.for_dev_server().unwrap()).await?;
                },
            }
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1475
1476/// Acknowledges a ping from a client, used to keep the connection alive.
1477async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1478 response.send(proto::Ack {})?;
1479 Ok(())
1480}
1481
1482/// Creates a new room for calling (outside of channels)
1483async fn create_room(
1484 _request: proto::CreateRoom,
1485 response: Response<proto::CreateRoom>,
1486 session: UserSession,
1487) -> Result<()> {
1488 let live_kit_room = nanoid::nanoid!(30);
1489
1490 let live_kit_connection_info = util::maybe!(async {
1491 let live_kit = session.app_state.live_kit_client.as_ref();
1492 let live_kit = live_kit?;
1493 let user_id = session.user_id().to_string();
1494
1495 let token = live_kit
1496 .room_token(&live_kit_room, &user_id.to_string())
1497 .trace_err()?;
1498
1499 Some(proto::LiveKitConnectionInfo {
1500 server_url: live_kit.url().into(),
1501 token,
1502 can_publish: true,
1503 })
1504 })
1505 .await;
1506
1507 let room = session
1508 .db()
1509 .await
1510 .create_room(session.user_id(), session.connection_id, &live_kit_room)
1511 .await?;
1512
1513 response.send(proto::CreateRoomResponse {
1514 room: Some(room.clone()),
1515 live_kit_connection_info,
1516 })?;
1517
1518 update_user_contacts(session.user_id(), &session).await?;
1519 Ok(())
1520}
1521
1522/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1523async fn join_room(
1524 request: proto::JoinRoom,
1525 response: Response<proto::JoinRoom>,
1526 session: UserSession,
1527) -> Result<()> {
1528 let room_id = RoomId::from_proto(request.id);
1529
1530 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1531
1532 if let Some(channel_id) = channel_id {
1533 return join_channel_internal(channel_id, Box::new(response), session).await;
1534 }
1535
1536 let joined_room = {
1537 let room = session
1538 .db()
1539 .await
1540 .join_room(room_id, session.user_id(), session.connection_id)
1541 .await?;
1542 room_updated(&room.room, &session.peer);
1543 room.into_inner()
1544 };
1545
1546 for connection_id in session
1547 .connection_pool()
1548 .await
1549 .user_connection_ids(session.user_id())
1550 {
1551 session
1552 .peer
1553 .send(
1554 connection_id,
1555 proto::CallCanceled {
1556 room_id: room_id.to_proto(),
1557 },
1558 )
1559 .trace_err();
1560 }
1561
1562 let live_kit_connection_info =
1563 if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
1564 live_kit
1565 .room_token(
1566 &joined_room.room.live_kit_room,
1567 &session.user_id().to_string(),
1568 )
1569 .trace_err()
1570 .map(|token| proto::LiveKitConnectionInfo {
1571 server_url: live_kit.url().into(),
1572 token,
1573 can_publish: true,
1574 })
1575 } else {
1576 None
1577 };
1578
1579 response.send(proto::JoinRoomResponse {
1580 room: Some(joined_room.room),
1581 channel_id: None,
1582 live_kit_connection_info,
1583 })?;
1584
1585 update_user_contacts(session.user_id(), &session).await?;
1586 Ok(())
1587}
1588
/// Reconnects this connection to a room after a connection error.
///
/// It first responds with the room plus the re-shared and re-joined projects.
/// It then:
/// - broadcasts the updated room;
/// - for each re-shared project, tells its collaborators that `old_connection_id`
///   is now this connection, and forwards the project's current worktrees to every
///   collaborator except this connection;
/// - streams re-joined project state via `notify_rejoined_projects`;
/// - refreshes the room's channel (if it has one) and the user's contacts.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: UserSession,
) -> Result<()> {
    let room;
    let channel;
    // Everything that needs the database guard happens in this block. Only the room
    // and channel are carried out, so the guard is released before the connection
    // pool is awaited below.
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Best-effort: tell each collaborator that this user's peer id changed.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1680
1681fn notify_rejoined_projects(
1682 rejoined_projects: &mut Vec<RejoinedProject>,
1683 session: &UserSession,
1684) -> Result<()> {
1685 for project in rejoined_projects.iter() {
1686 for collaborator in &project.collaborators {
1687 session
1688 .peer
1689 .send(
1690 collaborator.connection_id,
1691 proto::UpdateProjectCollaborator {
1692 project_id: project.id.to_proto(),
1693 old_peer_id: Some(project.old_connection_id.into()),
1694 new_peer_id: Some(session.connection_id.into()),
1695 },
1696 )
1697 .trace_err();
1698 }
1699 }
1700
1701 for project in rejoined_projects {
1702 for worktree in mem::take(&mut project.worktrees) {
1703 #[cfg(any(test, feature = "test-support"))]
1704 const MAX_CHUNK_SIZE: usize = 2;
1705 #[cfg(not(any(test, feature = "test-support")))]
1706 const MAX_CHUNK_SIZE: usize = 256;
1707
1708 // Stream this worktree's entries.
1709 let message = proto::UpdateWorktree {
1710 project_id: project.id.to_proto(),
1711 worktree_id: worktree.id,
1712 abs_path: worktree.abs_path.clone(),
1713 root_name: worktree.root_name,
1714 updated_entries: worktree.updated_entries,
1715 removed_entries: worktree.removed_entries,
1716 scan_id: worktree.scan_id,
1717 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1718 updated_repositories: worktree.updated_repositories,
1719 removed_repositories: worktree.removed_repositories,
1720 };
1721 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1722 session.peer.send(session.connection_id, update.clone())?;
1723 }
1724
1725 // Stream this worktree's diagnostics.
1726 for summary in worktree.diagnostic_summaries {
1727 session.peer.send(
1728 session.connection_id,
1729 proto::UpdateDiagnosticSummary {
1730 project_id: project.id.to_proto(),
1731 worktree_id: worktree.id,
1732 summary: Some(summary),
1733 },
1734 )?;
1735 }
1736
1737 for settings_file in worktree.settings_files {
1738 session.peer.send(
1739 session.connection_id,
1740 proto::UpdateWorktreeSettings {
1741 project_id: project.id.to_proto(),
1742 worktree_id: worktree.id,
1743 path: settings_file.path,
1744 content: Some(settings_file.content),
1745 },
1746 )?;
1747 }
1748 }
1749
1750 for language_server in &project.language_servers {
1751 session.peer.send(
1752 session.connection_id,
1753 proto::UpdateLanguageServer {
1754 project_id: project.id.to_proto(),
1755 language_server_id: language_server.id,
1756 variant: Some(
1757 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1758 proto::LspDiskBasedDiagnosticsUpdated {},
1759 ),
1760 ),
1761 },
1762 )?;
1763 }
1764 }
1765 Ok(())
1766}
1767
1768/// leave room disconnects from the room.
1769async fn leave_room(
1770 _: proto::LeaveRoom,
1771 response: Response<proto::LeaveRoom>,
1772 session: UserSession,
1773) -> Result<()> {
1774 leave_room_for_session(&session, session.connection_id).await?;
1775 response.send(proto::Ack {})?;
1776 Ok(())
1777}
1778
1779/// Updates the permissions of someone else in the room.
1780async fn set_room_participant_role(
1781 request: proto::SetRoomParticipantRole,
1782 response: Response<proto::SetRoomParticipantRole>,
1783 session: UserSession,
1784) -> Result<()> {
1785 let user_id = UserId::from_proto(request.user_id);
1786 let role = ChannelRole::from(request.role());
1787
1788 let (live_kit_room, can_publish) = {
1789 let room = session
1790 .db()
1791 .await
1792 .set_room_participant_role(
1793 session.user_id(),
1794 RoomId::from_proto(request.room_id),
1795 user_id,
1796 role,
1797 )
1798 .await?;
1799
1800 let live_kit_room = room.live_kit_room.clone();
1801 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1802 room_updated(&room, &session.peer);
1803 (live_kit_room, can_publish)
1804 };
1805
1806 if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
1807 live_kit
1808 .update_participant(
1809 live_kit_room.clone(),
1810 request.user_id.to_string(),
1811 live_kit_server::proto::ParticipantPermission {
1812 can_subscribe: true,
1813 can_publish,
1814 can_publish_data: can_publish,
1815 hidden: false,
1816 recorder: false,
1817 },
1818 )
1819 .await
1820 .trace_err();
1821 }
1822
1823 response.send(proto::Ack {})?;
1824 Ok(())
1825}
1826
/// Rings another user into the caller's current room.
///
/// The callee must be one of the caller's contacts. The call is recorded in the
/// database, the updated room is broadcast, and the callee's contacts are
/// refreshed. Every connection of the callee is then sent the incoming call
/// concurrently, and the request is acknowledged as soon as any one of them
/// responds successfully.
///
/// If none does (including when the callee has no connections):
/// - the call is marked failed;
/// - the room is broadcast again;
/// - contacts are refreshed;
/// - an error is returned.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: UserSession,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(room, &session.peer);
        // Move the message out of the guarded data so it outlives the guard,
        // which is released at the end of this block.
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    // First successful ring wins; the remaining in-flight requests are dropped on return.
    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    // No connection accepted the call: record the failure and report it.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1895
1896/// Cancel an outgoing call.
1897async fn cancel_call(
1898 request: proto::CancelCall,
1899 response: Response<proto::CancelCall>,
1900 session: UserSession,
1901) -> Result<()> {
1902 let called_user_id = UserId::from_proto(request.called_user_id);
1903 let room_id = RoomId::from_proto(request.room_id);
1904 {
1905 let room = session
1906 .db()
1907 .await
1908 .cancel_call(room_id, session.connection_id, called_user_id)
1909 .await?;
1910 room_updated(&room, &session.peer);
1911 }
1912
1913 for connection_id in session
1914 .connection_pool()
1915 .await
1916 .user_connection_ids(called_user_id)
1917 {
1918 session
1919 .peer
1920 .send(
1921 connection_id,
1922 proto::CallCanceled {
1923 room_id: room_id.to_proto(),
1924 },
1925 )
1926 .trace_err();
1927 }
1928 response.send(proto::Ack {})?;
1929
1930 update_user_contacts(called_user_id, &session).await?;
1931 Ok(())
1932}
1933
1934/// Decline an incoming call.
1935async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1936 let room_id = RoomId::from_proto(message.room_id);
1937 {
1938 let room = session
1939 .db()
1940 .await
1941 .decline_call(Some(room_id), session.user_id())
1942 .await?
1943 .ok_or_else(|| anyhow!("failed to decline call"))?;
1944 room_updated(&room, &session.peer);
1945 }
1946
1947 for connection_id in session
1948 .connection_pool()
1949 .await
1950 .user_connection_ids(session.user_id())
1951 {
1952 session
1953 .peer
1954 .send(
1955 connection_id,
1956 proto::CallCanceled {
1957 room_id: room_id.to_proto(),
1958 },
1959 )
1960 .trace_err();
1961 }
1962 update_user_contacts(session.user_id(), &session).await?;
1963 Ok(())
1964}
1965
1966/// Updates other participants in the room with your current location.
1967async fn update_participant_location(
1968 request: proto::UpdateParticipantLocation,
1969 response: Response<proto::UpdateParticipantLocation>,
1970 session: UserSession,
1971) -> Result<()> {
1972 let room_id = RoomId::from_proto(request.room_id);
1973 let location = request
1974 .location
1975 .ok_or_else(|| anyhow!("invalid location"))?;
1976
1977 let db = session.db().await;
1978 let room = db
1979 .update_room_participant_location(room_id, session.connection_id, location)
1980 .await?;
1981
1982 room_updated(&room, &session.peer);
1983 response.send(proto::Ack {})?;
1984 Ok(())
1985}
1986
/// Shares one of this connection's projects into a room.
///
/// The project is recorded in the database, together with:
/// - its worktrees;
/// - whether it is an SSH project;
/// - the optional dev server project id.
///
/// The handler responds with the new project id and then broadcasts the updated
/// room.
async fn share_project(
    request: proto::ShareProject,
    response: Response<proto::ShareProject>,
    session: UserSession,
) -> Result<()> {
    let (project_id, room) = &*session
        .db()
        .await
        .share_project(
            RoomId::from_proto(request.room_id),
            session.connection_id,
            &request.worktrees,
            request.is_ssh_project,
            request
                .dev_server_project_id
                .map(DevServerProjectId::from_proto),
        )
        .await?;
    // The sharer learns the project id before other participants see the new room.
    response.send(proto::ShareProjectResponse {
        project_id: project_id.to_proto(),
    })?;
    room_updated(room, &session.peer);

    Ok(())
}
2013
2014/// Unshare a project from the room.
2015async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
2016 let project_id = ProjectId::from_proto(message.project_id);
2017 unshare_project_internal(
2018 project_id,
2019 session.connection_id,
2020 session.user_id(),
2021 &session,
2022 )
2023 .await
2024}
2025
/// Unshares `project_id` on behalf of `connection_id`, acting as `user_id` when
/// one is given.
///
/// The database decides which guests to notify, which room (if any) changed, and
/// whether the project row should be deleted. This function then:
/// - sends `UnshareProject` to every guest connection other than `connection_id`;
/// - broadcasts the updated room, when there is one;
/// - deletes the project if the database asked for that.
async fn unshare_project_internal(
    project_id: ProjectId,
    connection_id: ConnectionId,
    user_id: Option<UserId>,
    session: &Session,
) -> Result<()> {
    // Everything that reads the guard happens inside this block; the guard is
    // released before the deletion below takes the database handle again.
    let delete = {
        let room_guard = session
            .db()
            .await
            .unshare_project(project_id, connection_id, user_id)
            .await?;

        let (delete, room, guest_connection_ids) = &*room_guard;

        let message = proto::UnshareProject {
            project_id: project_id.to_proto(),
        };

        broadcast(
            Some(connection_id),
            guest_connection_ids.iter().copied(),
            |conn_id| session.peer.send(conn_id, message.clone()),
        );
        if let Some(room) = room {
            room_updated(room, &session.peer);
        }

        *delete
    };

    if delete {
        let db = session.db().await;
        db.delete_project(project_id).await?;
    }

    Ok(())
}
2064
2065/// DevServer makes a project available online
2066async fn share_dev_server_project(
2067 request: proto::ShareDevServerProject,
2068 response: Response<proto::ShareDevServerProject>,
2069 session: DevServerSession,
2070) -> Result<()> {
2071 let (dev_server_project, user_id, status) = session
2072 .db()
2073 .await
2074 .share_dev_server_project(
2075 DevServerProjectId::from_proto(request.dev_server_project_id),
2076 session.dev_server_id(),
2077 session.connection_id,
2078 &request.worktrees,
2079 )
2080 .await?;
2081 let Some(project_id) = dev_server_project.project_id else {
2082 return Err(anyhow!("failed to share remote project"))?;
2083 };
2084
2085 send_dev_server_projects_update(user_id, status, &session).await;
2086
2087 response.send(proto::ShareProjectResponse { project_id })?;
2088
2089 Ok(())
2090}
2091
/// Joins someone else's shared project.
///
/// Adds this connection as a collaborator in the database, then hands off to
/// `join_project_internal`. That function notifies the existing collaborators and
/// sends the project state to the joining client.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: UserSession,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    // Release the session's database handle before the synchronous fan-out in
    // `join_project_internal`; the joined project data above remains in use.
    drop(db);
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
2110
/// Sends the `JoinProjectResponse` that completes a join request.
///
/// Implemented for the response handles of both `JoinProject` and
/// `JoinHostedProject`, which reply with the same message type. This lets
/// `join_project_internal` serve either request.
trait JoinProjectInternalResponse {
    /// Sends `result` as the reply, consuming the response handle.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
/// Forwards to `Response::send` for `JoinProject` requests.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinProject>::send(self, result)
    }
}
/// Forwards to `Response::send` for `JoinHostedProject` requests.
impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinHostedProject>::send(self, result)
    }
}
2124
2125fn join_project_internal(
2126 response: impl JoinProjectInternalResponse,
2127 session: UserSession,
2128 project: &mut Project,
2129 replica_id: &ReplicaId,
2130) -> Result<()> {
2131 let collaborators = project
2132 .collaborators
2133 .iter()
2134 .filter(|collaborator| collaborator.connection_id != session.connection_id)
2135 .map(|collaborator| collaborator.to_proto())
2136 .collect::<Vec<_>>();
2137 let project_id = project.id;
2138 let guest_user_id = session.user_id();
2139
2140 let worktrees = project
2141 .worktrees
2142 .iter()
2143 .map(|(id, worktree)| proto::WorktreeMetadata {
2144 id: *id,
2145 root_name: worktree.root_name.clone(),
2146 visible: worktree.visible,
2147 abs_path: worktree.abs_path.clone(),
2148 })
2149 .collect::<Vec<_>>();
2150
2151 let add_project_collaborator = proto::AddProjectCollaborator {
2152 project_id: project_id.to_proto(),
2153 collaborator: Some(proto::Collaborator {
2154 peer_id: Some(session.connection_id.into()),
2155 replica_id: replica_id.0 as u32,
2156 user_id: guest_user_id.to_proto(),
2157 }),
2158 };
2159
2160 for collaborator in &collaborators {
2161 session
2162 .peer
2163 .send(
2164 collaborator.peer_id.unwrap().into(),
2165 add_project_collaborator.clone(),
2166 )
2167 .trace_err();
2168 }
2169
2170 // First, we send the metadata associated with each worktree.
2171 response.send(proto::JoinProjectResponse {
2172 project_id: project.id.0 as u64,
2173 worktrees: worktrees.clone(),
2174 replica_id: replica_id.0 as u32,
2175 collaborators: collaborators.clone(),
2176 language_servers: project.language_servers.clone(),
2177 role: project.role.into(),
2178 dev_server_project_id: project
2179 .dev_server_project_id
2180 .map(|dev_server_project_id| dev_server_project_id.0 as u64),
2181 })?;
2182
2183 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2184 #[cfg(any(test, feature = "test-support"))]
2185 const MAX_CHUNK_SIZE: usize = 2;
2186 #[cfg(not(any(test, feature = "test-support")))]
2187 const MAX_CHUNK_SIZE: usize = 256;
2188
2189 // Stream this worktree's entries.
2190 let message = proto::UpdateWorktree {
2191 project_id: project_id.to_proto(),
2192 worktree_id,
2193 abs_path: worktree.abs_path.clone(),
2194 root_name: worktree.root_name,
2195 updated_entries: worktree.entries,
2196 removed_entries: Default::default(),
2197 scan_id: worktree.scan_id,
2198 is_last_update: worktree.scan_id == worktree.completed_scan_id,
2199 updated_repositories: worktree.repository_entries.into_values().collect(),
2200 removed_repositories: Default::default(),
2201 };
2202 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
2203 session.peer.send(session.connection_id, update.clone())?;
2204 }
2205
2206 // Stream this worktree's diagnostics.
2207 for summary in worktree.diagnostic_summaries {
2208 session.peer.send(
2209 session.connection_id,
2210 proto::UpdateDiagnosticSummary {
2211 project_id: project_id.to_proto(),
2212 worktree_id: worktree.id,
2213 summary: Some(summary),
2214 },
2215 )?;
2216 }
2217
2218 for settings_file in worktree.settings_files {
2219 session.peer.send(
2220 session.connection_id,
2221 proto::UpdateWorktreeSettings {
2222 project_id: project_id.to_proto(),
2223 worktree_id: worktree.id,
2224 path: settings_file.path,
2225 content: Some(settings_file.content),
2226 },
2227 )?;
2228 }
2229 }
2230
2231 for language_server in &project.language_servers {
2232 session.peer.send(
2233 session.connection_id,
2234 proto::UpdateLanguageServer {
2235 project_id: project_id.to_proto(),
2236 language_server_id: language_server.id,
2237 variant: Some(
2238 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2239 proto::LspDiskBasedDiagnosticsUpdated {},
2240 ),
2241 ),
2242 },
2243 )?;
2244 }
2245
2246 Ok(())
2247}
2248
2249/// Leave someone elses shared project.
2250async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2251 let sender_id = session.connection_id;
2252 let project_id = ProjectId::from_proto(request.project_id);
2253 let db = session.db().await;
2254 if db.is_hosted_project(project_id).await? {
2255 let project = db.leave_hosted_project(project_id, sender_id).await?;
2256 project_left(&project, &session);
2257 return Ok(());
2258 }
2259
2260 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2261 tracing::info!(
2262 %project_id,
2263 "leave project"
2264 );
2265
2266 project_left(project, &session);
2267 if let Some(room) = room {
2268 room_updated(room, &session.peer);
2269 }
2270
2271 Ok(())
2272}
2273
2274async fn join_hosted_project(
2275 request: proto::JoinHostedProject,
2276 response: Response<proto::JoinHostedProject>,
2277 session: UserSession,
2278) -> Result<()> {
2279 let (mut project, replica_id) = session
2280 .db()
2281 .await
2282 .join_hosted_project(
2283 ProjectId(request.project_id as i32),
2284 session.user_id(),
2285 session.connection_id,
2286 )
2287 .await?;
2288
2289 join_project_internal(response, session, &mut project, &replica_id)
2290}
2291
2292async fn list_remote_directory(
2293 request: proto::ListRemoteDirectory,
2294 response: Response<proto::ListRemoteDirectory>,
2295 session: UserSession,
2296) -> Result<()> {
2297 let dev_server_id = DevServerId(request.dev_server_id as i32);
2298 let dev_server_connection_id = session
2299 .connection_pool()
2300 .await
2301 .dev_server_connection_id_supporting(dev_server_id, ZedVersion::with_list_directory())?;
2302
2303 session
2304 .db()
2305 .await
2306 .get_dev_server_for_user(dev_server_id, session.user_id())
2307 .await?;
2308
2309 response.send(
2310 session
2311 .peer
2312 .forward_request(session.connection_id, dev_server_connection_id, request)
2313 .await?,
2314 )?;
2315 Ok(())
2316}
2317
2318async fn update_dev_server_project(
2319 request: proto::UpdateDevServerProject,
2320 response: Response<proto::UpdateDevServerProject>,
2321 session: UserSession,
2322) -> Result<()> {
2323 let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2324
2325 let (dev_server_project, update) = session
2326 .db()
2327 .await
2328 .update_dev_server_project(dev_server_project_id, &request.paths, session.user_id())
2329 .await?;
2330
2331 let projects = session
2332 .db()
2333 .await
2334 .get_projects_for_dev_server(dev_server_project.dev_server_id)
2335 .await?;
2336
2337 let dev_server_connection_id = session
2338 .connection_pool()
2339 .await
2340 .dev_server_connection_id_supporting(
2341 dev_server_project.dev_server_id,
2342 ZedVersion::with_list_directory(),
2343 )?;
2344
2345 session.peer.send(
2346 dev_server_connection_id,
2347 proto::DevServerInstructions { projects },
2348 )?;
2349
2350 send_dev_server_projects_update(session.user_id(), update, &session).await;
2351
2352 response.send(proto::Ack {})
2353}
2354
2355async fn create_dev_server_project(
2356 request: proto::CreateDevServerProject,
2357 response: Response<proto::CreateDevServerProject>,
2358 session: UserSession,
2359) -> Result<()> {
2360 let dev_server_id = DevServerId(request.dev_server_id as i32);
2361 let dev_server_connection_id = session
2362 .connection_pool()
2363 .await
2364 .dev_server_connection_id(dev_server_id);
2365 let Some(dev_server_connection_id) = dev_server_connection_id else {
2366 Err(ErrorCode::DevServerOffline
2367 .message("Cannot create a remote project when the dev server is offline".to_string())
2368 .anyhow())?
2369 };
2370
2371 let path = request.path.clone();
2372 //Check that the path exists on the dev server
2373 session
2374 .peer
2375 .forward_request(
2376 session.connection_id,
2377 dev_server_connection_id,
2378 proto::ValidateDevServerProjectRequest { path: path.clone() },
2379 )
2380 .await?;
2381
2382 let (dev_server_project, update) = session
2383 .db()
2384 .await
2385 .create_dev_server_project(
2386 DevServerId(request.dev_server_id as i32),
2387 &request.path,
2388 session.user_id(),
2389 )
2390 .await?;
2391
2392 let projects = session
2393 .db()
2394 .await
2395 .get_projects_for_dev_server(dev_server_project.dev_server_id)
2396 .await?;
2397
2398 session.peer.send(
2399 dev_server_connection_id,
2400 proto::DevServerInstructions { projects },
2401 )?;
2402
2403 send_dev_server_projects_update(session.user_id(), update, &session).await;
2404
2405 response.send(proto::CreateDevServerProjectResponse {
2406 dev_server_project: Some(dev_server_project.to_proto(None)),
2407 })?;
2408 Ok(())
2409}
2410
2411async fn create_dev_server(
2412 request: proto::CreateDevServer,
2413 response: Response<proto::CreateDevServer>,
2414 session: UserSession,
2415) -> Result<()> {
2416 let access_token = auth::random_token();
2417 let hashed_access_token = auth::hash_access_token(&access_token);
2418
2419 if request.name.is_empty() {
2420 return Err(proto::ErrorCode::Forbidden
2421 .message("Dev server name cannot be empty".to_string())
2422 .anyhow())?;
2423 }
2424
2425 let (dev_server, status) = session
2426 .db()
2427 .await
2428 .create_dev_server(
2429 &request.name,
2430 request.ssh_connection_string.as_deref(),
2431 &hashed_access_token,
2432 session.user_id(),
2433 )
2434 .await?;
2435
2436 send_dev_server_projects_update(session.user_id(), status, &session).await;
2437
2438 response.send(proto::CreateDevServerResponse {
2439 dev_server_id: dev_server.id.0 as u64,
2440 access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2441 name: request.name,
2442 })?;
2443 Ok(())
2444}
2445
2446async fn regenerate_dev_server_token(
2447 request: proto::RegenerateDevServerToken,
2448 response: Response<proto::RegenerateDevServerToken>,
2449 session: UserSession,
2450) -> Result<()> {
2451 let dev_server_id = DevServerId(request.dev_server_id as i32);
2452 let access_token = auth::random_token();
2453 let hashed_access_token = auth::hash_access_token(&access_token);
2454
2455 let connection_id = session
2456 .connection_pool()
2457 .await
2458 .dev_server_connection_id(dev_server_id);
2459 if let Some(connection_id) = connection_id {
2460 shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2461 session.peer.send(
2462 connection_id,
2463 proto::ShutdownDevServer {
2464 reason: Some("dev server token was regenerated".to_string()),
2465 },
2466 )?;
2467 let _ = remove_dev_server_connection(dev_server_id, &session).await;
2468 }
2469
2470 let status = session
2471 .db()
2472 .await
2473 .update_dev_server_token(dev_server_id, &hashed_access_token, session.user_id())
2474 .await?;
2475
2476 send_dev_server_projects_update(session.user_id(), status, &session).await;
2477
2478 response.send(proto::RegenerateDevServerTokenResponse {
2479 dev_server_id: dev_server_id.to_proto(),
2480 access_token: auth::generate_dev_server_token(dev_server_id.0 as usize, access_token),
2481 })?;
2482 Ok(())
2483}
2484
2485async fn rename_dev_server(
2486 request: proto::RenameDevServer,
2487 response: Response<proto::RenameDevServer>,
2488 session: UserSession,
2489) -> Result<()> {
2490 if request.name.trim().is_empty() {
2491 return Err(proto::ErrorCode::Forbidden
2492 .message("Dev server name cannot be empty".to_string())
2493 .anyhow())?;
2494 }
2495
2496 let dev_server_id = DevServerId(request.dev_server_id as i32);
2497 let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2498 if dev_server.user_id != session.user_id() {
2499 return Err(anyhow!(ErrorCode::Forbidden))?;
2500 }
2501
2502 let status = session
2503 .db()
2504 .await
2505 .rename_dev_server(
2506 dev_server_id,
2507 &request.name,
2508 request.ssh_connection_string.as_deref(),
2509 session.user_id(),
2510 )
2511 .await?;
2512
2513 send_dev_server_projects_update(session.user_id(), status, &session).await;
2514
2515 response.send(proto::Ack {})?;
2516 Ok(())
2517}
2518
2519async fn delete_dev_server(
2520 request: proto::DeleteDevServer,
2521 response: Response<proto::DeleteDevServer>,
2522 session: UserSession,
2523) -> Result<()> {
2524 let dev_server_id = DevServerId(request.dev_server_id as i32);
2525 let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2526 if dev_server.user_id != session.user_id() {
2527 return Err(anyhow!(ErrorCode::Forbidden))?;
2528 }
2529
2530 let connection_id = session
2531 .connection_pool()
2532 .await
2533 .dev_server_connection_id(dev_server_id);
2534 if let Some(connection_id) = connection_id {
2535 shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2536 session.peer.send(
2537 connection_id,
2538 proto::ShutdownDevServer {
2539 reason: Some("dev server was deleted".to_string()),
2540 },
2541 )?;
2542 let _ = remove_dev_server_connection(dev_server_id, &session).await;
2543 }
2544
2545 let status = session
2546 .db()
2547 .await
2548 .delete_dev_server(dev_server_id, session.user_id())
2549 .await?;
2550
2551 send_dev_server_projects_update(session.user_id(), status, &session).await;
2552
2553 response.send(proto::Ack {})?;
2554 Ok(())
2555}
2556
/// Delete a dev-server project on behalf of the dev server's owner.
///
/// Steps:
/// 1. Rejects the request with `ErrorCode::Forbidden` unless the requesting user
///    owns the dev server hosting the project.
/// 2. If that dev server is online and the project is currently shared, unshares it.
/// 3. Deletes the project from the database.
/// 4. If the dev server is online, sends it its remaining project list.
/// 5. Notifies the owner's clients of the updated status.
async fn delete_dev_server_project(
    request: proto::DeleteDevServerProject,
    response: Response<proto::DeleteDevServerProject>,
    session: UserSession,
) -> Result<()> {
    let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
    let dev_server_project = session
        .db()
        .await
        .get_dev_server_project(dev_server_project_id)
        .await?;

    let dev_server = session
        .db()
        .await
        .get_dev_server(dev_server_project.dev_server_id)
        .await?;
    if dev_server.user_id != session.user_id() {
        return Err(anyhow!(ErrorCode::Forbidden))?;
    }

    let dev_server_connection_id = session
        .connection_pool()
        .await
        .dev_server_connection_id(dev_server.id);

    if let Some(dev_server_connection_id) = dev_server_connection_id {
        let project = session
            .db()
            .await
            .find_dev_server_project(dev_server_project_id)
            .await;
        // A lookup error is treated as "nothing to unshare" rather than failing the
        // deletion.
        if let Ok(project) = project {
            unshare_project_internal(
                project.id,
                dev_server_connection_id,
                Some(session.user_id()),
                &session,
            )
            .await?;
        }
    }

    let (projects, status) = session
        .db()
        .await
        .delete_dev_server_project(dev_server_project_id, dev_server.id, session.user_id())
        .await?;

    // Tell the (online) dev server which projects it should still be serving.
    if let Some(dev_server_connection_id) = dev_server_connection_id {
        session.peer.send(
            dev_server_connection_id,
            proto::DevServerInstructions { projects },
        )?;
    }

    send_dev_server_projects_update(session.user_id(), status, &session).await;

    response.send(proto::Ack {})?;
    Ok(())
}
2618
2619async fn rejoin_dev_server_projects(
2620 request: proto::RejoinRemoteProjects,
2621 response: Response<proto::RejoinRemoteProjects>,
2622 session: UserSession,
2623) -> Result<()> {
2624 let mut rejoined_projects = {
2625 let db = session.db().await;
2626 db.rejoin_dev_server_projects(
2627 &request.rejoined_projects,
2628 session.user_id(),
2629 session.0.connection_id,
2630 )
2631 .await?
2632 };
2633 response.send(proto::RejoinRemoteProjectsResponse {
2634 rejoined_projects: rejoined_projects
2635 .iter()
2636 .map(|project| project.to_proto())
2637 .collect(),
2638 })?;
2639 notify_rejoined_projects(&mut rejoined_projects, &session)
2640}
2641
/// Handle a dev server reconnecting and re-sharing its projects under a new
/// connection.
///
/// Each project is reshared in the database. Then, for every reshared project:
/// - each collaborator is told (best-effort) that the host's peer id changed from
///   the old connection to this one;
/// - an `UpdateProject` carrying the project's current worktrees is broadcast to
///   the collaborators.
///
/// Finally the dev server receives the reshared projects and their collaborators.
async fn reconnect_dev_server(
    request: proto::ReconnectDevServer,
    response: Response<proto::ReconnectDevServer>,
    session: DevServerSession,
) -> Result<()> {
    let reshared_projects = {
        let db = session.db().await;
        db.reshare_dev_server_projects(
            &request.reshared_projects,
            session.dev_server_id(),
            session.0.connection_id,
        )
        .await?
    };

    for project in &reshared_projects {
        for collaborator in &project.collaborators {
            // Best-effort: send failures are only logged.
            session
                .peer
                .send(
                    collaborator.connection_id,
                    proto::UpdateProjectCollaborator {
                        project_id: project.id.to_proto(),
                        old_peer_id: Some(project.old_connection_id.into()),
                        new_peer_id: Some(session.connection_id.into()),
                    },
                )
                .trace_err();
        }

        broadcast(
            Some(session.connection_id),
            project
                .collaborators
                .iter()
                .map(|collaborator| collaborator.connection_id),
            |connection_id| {
                session.peer.forward_send(
                    session.connection_id,
                    connection_id,
                    proto::UpdateProject {
                        project_id: project.id.to_proto(),
                        worktrees: project.worktrees.clone(),
                    },
                )
            },
        );
    }

    response.send(proto::ReconnectDevServerResponse {
        reshared_projects: reshared_projects
            .iter()
            .map(|project| proto::ResharedProject {
                id: project.id.to_proto(),
                collaborators: project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.to_proto())
                    .collect(),
            })
            .collect(),
    })?;

    Ok(())
}
2707
2708async fn shutdown_dev_server(
2709 _: proto::ShutdownDevServer,
2710 response: Response<proto::ShutdownDevServer>,
2711 session: DevServerSession,
2712) -> Result<()> {
2713 response.send(proto::Ack {})?;
2714 shutdown_dev_server_internal(session.dev_server_id(), session.connection_id, &session).await?;
2715 remove_dev_server_connection(session.dev_server_id(), &session).await
2716}
2717
/// Take a dev server offline without closing its connection.
///
/// Steps:
/// 1. Unshares every dev-server project that currently has a shared `project_id`,
///    acting as `connection_id` with no requesting user.
/// 2. Marks the dev server offline in the connection pool.
/// 3. Sends the dev server's owner the recomputed dev-server projects status.
///
/// Removing the connection and notifying the dev server itself are left to callers.
async fn shutdown_dev_server_internal(
    dev_server_id: DevServerId,
    connection_id: ConnectionId,
    session: &Session,
) -> Result<()> {
    // Both reads happen under one `session.db()` handle, released at the end of
    // this block.
    let (dev_server_projects, dev_server) = {
        let db = session.db().await;
        let dev_server_projects = db.get_projects_for_dev_server(dev_server_id).await?;
        let dev_server = db.get_dev_server(dev_server_id).await?;
        (dev_server_projects, dev_server)
    };

    // Only projects that are currently shared have a `project_id`.
    for project_id in dev_server_projects.iter().filter_map(|p| p.project_id) {
        unshare_project_internal(
            ProjectId::from_proto(project_id),
            connection_id,
            None,
            session,
        )
        .await?;
    }

    session
        .connection_pool()
        .await
        .set_dev_server_offline(dev_server_id);

    let status = session
        .db()
        .await
        .dev_server_projects_update(dev_server.user_id)
        .await?;
    send_dev_server_projects_update(dev_server.user_id, status, session).await;

    Ok(())
}
2754
2755async fn remove_dev_server_connection(dev_server_id: DevServerId, session: &Session) -> Result<()> {
2756 let dev_server_connection = session
2757 .connection_pool()
2758 .await
2759 .dev_server_connection_id(dev_server_id);
2760
2761 if let Some(dev_server_connection) = dev_server_connection {
2762 session
2763 .connection_pool()
2764 .await
2765 .remove_connection(dev_server_connection)?;
2766 }
2767 Ok(())
2768}
2769
2770/// Updates other participants with changes to the project
2771async fn update_project(
2772 request: proto::UpdateProject,
2773 response: Response<proto::UpdateProject>,
2774 session: Session,
2775) -> Result<()> {
2776 let project_id = ProjectId::from_proto(request.project_id);
2777 let (room, guest_connection_ids) = &*session
2778 .db()
2779 .await
2780 .update_project(project_id, session.connection_id, &request.worktrees)
2781 .await?;
2782 broadcast(
2783 Some(session.connection_id),
2784 guest_connection_ids.iter().copied(),
2785 |connection_id| {
2786 session
2787 .peer
2788 .forward_send(session.connection_id, connection_id, request.clone())
2789 },
2790 );
2791 if let Some(room) = room {
2792 room_updated(room, &session.peer);
2793 }
2794 response.send(proto::Ack {})?;
2795
2796 Ok(())
2797}
2798
2799/// Updates other participants with changes to the worktree
2800async fn update_worktree(
2801 request: proto::UpdateWorktree,
2802 response: Response<proto::UpdateWorktree>,
2803 session: Session,
2804) -> Result<()> {
2805 let guest_connection_ids = session
2806 .db()
2807 .await
2808 .update_worktree(&request, session.connection_id)
2809 .await?;
2810
2811 broadcast(
2812 Some(session.connection_id),
2813 guest_connection_ids.iter().copied(),
2814 |connection_id| {
2815 session
2816 .peer
2817 .forward_send(session.connection_id, connection_id, request.clone())
2818 },
2819 );
2820 response.send(proto::Ack {})?;
2821 Ok(())
2822}
2823
2824/// Updates other participants with changes to the diagnostics
2825async fn update_diagnostic_summary(
2826 message: proto::UpdateDiagnosticSummary,
2827 session: Session,
2828) -> Result<()> {
2829 let guest_connection_ids = session
2830 .db()
2831 .await
2832 .update_diagnostic_summary(&message, session.connection_id)
2833 .await?;
2834
2835 broadcast(
2836 Some(session.connection_id),
2837 guest_connection_ids.iter().copied(),
2838 |connection_id| {
2839 session
2840 .peer
2841 .forward_send(session.connection_id, connection_id, message.clone())
2842 },
2843 );
2844
2845 Ok(())
2846}
2847
2848/// Updates other participants with changes to the worktree settings
2849async fn update_worktree_settings(
2850 message: proto::UpdateWorktreeSettings,
2851 session: Session,
2852) -> Result<()> {
2853 let guest_connection_ids = session
2854 .db()
2855 .await
2856 .update_worktree_settings(&message, session.connection_id)
2857 .await?;
2858
2859 broadcast(
2860 Some(session.connection_id),
2861 guest_connection_ids.iter().copied(),
2862 |connection_id| {
2863 session
2864 .peer
2865 .forward_send(session.connection_id, connection_id, message.clone())
2866 },
2867 );
2868
2869 Ok(())
2870}
2871
2872/// Notify other participants that a language server has started.
2873async fn start_language_server(
2874 request: proto::StartLanguageServer,
2875 session: Session,
2876) -> Result<()> {
2877 let guest_connection_ids = session
2878 .db()
2879 .await
2880 .start_language_server(&request, session.connection_id)
2881 .await?;
2882
2883 broadcast(
2884 Some(session.connection_id),
2885 guest_connection_ids.iter().copied(),
2886 |connection_id| {
2887 session
2888 .peer
2889 .forward_send(session.connection_id, connection_id, request.clone())
2890 },
2891 );
2892 Ok(())
2893}
2894
2895/// Notify other participants that a language server has changed.
2896async fn update_language_server(
2897 request: proto::UpdateLanguageServer,
2898 session: Session,
2899) -> Result<()> {
2900 let project_id = ProjectId::from_proto(request.project_id);
2901 let project_connection_ids = session
2902 .db()
2903 .await
2904 .project_connection_ids(project_id, session.connection_id, true)
2905 .await?;
2906 broadcast(
2907 Some(session.connection_id),
2908 project_connection_ids.iter().copied(),
2909 |connection_id| {
2910 session
2911 .peer
2912 .forward_send(session.connection_id, connection_id, request.clone())
2913 },
2914 );
2915 Ok(())
2916}
2917
2918/// forward a project request to the host. These requests should be read only
2919/// as guests are allowed to send them.
2920async fn forward_read_only_project_request<T>(
2921 request: T,
2922 response: Response<T>,
2923 session: UserSession,
2924) -> Result<()>
2925where
2926 T: EntityMessage + RequestMessage,
2927{
2928 let project_id = ProjectId::from_proto(request.remote_entity_id());
2929 let host_connection_id = session
2930 .db()
2931 .await
2932 .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2933 .await?;
2934 let payload = session
2935 .peer
2936 .forward_request(session.connection_id, host_connection_id, request)
2937 .await?;
2938 response.send(payload)?;
2939 Ok(())
2940}
2941
2942async fn forward_find_search_candidates_request(
2943 request: proto::FindSearchCandidates,
2944 response: Response<proto::FindSearchCandidates>,
2945 session: UserSession,
2946) -> Result<()> {
2947 let project_id = ProjectId::from_proto(request.remote_entity_id());
2948 let host_connection_id = session
2949 .db()
2950 .await
2951 .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2952 .await?;
2953
2954 let host_version = session
2955 .connection_pool()
2956 .await
2957 .connection(host_connection_id)
2958 .map(|c| c.zed_version);
2959
2960 if host_version.is_some_and(|host_version| host_version < ZedVersion::with_search_candidates())
2961 {
2962 let query = request.query.ok_or_else(|| anyhow!("missing query"))?;
2963 let search = proto::SearchProject {
2964 project_id: project_id.to_proto(),
2965 query: query.query,
2966 regex: query.regex,
2967 whole_word: query.whole_word,
2968 case_sensitive: query.case_sensitive,
2969 files_to_include: query.files_to_include,
2970 files_to_exclude: query.files_to_exclude,
2971 include_ignored: query.include_ignored,
2972 };
2973
2974 let payload = session
2975 .peer
2976 .forward_request(session.connection_id, host_connection_id, search)
2977 .await?;
2978 return response.send(proto::FindSearchCandidatesResponse {
2979 buffer_ids: payload
2980 .locations
2981 .into_iter()
2982 .map(|loc| loc.buffer_id)
2983 .collect(),
2984 });
2985 }
2986
2987 let payload = session
2988 .peer
2989 .forward_request(session.connection_id, host_connection_id, request)
2990 .await?;
2991 response.send(payload)?;
2992 Ok(())
2993}
2994
2995/// forward a project request to the dev server. Only allowed
2996/// if it's your dev server.
2997async fn forward_project_request_for_owner<T>(
2998 request: T,
2999 response: Response<T>,
3000 session: UserSession,
3001) -> Result<()>
3002where
3003 T: EntityMessage + RequestMessage,
3004{
3005 let project_id = ProjectId::from_proto(request.remote_entity_id());
3006
3007 let host_connection_id = session
3008 .db()
3009 .await
3010 .host_for_owner_project_request(project_id, session.connection_id, session.user_id())
3011 .await?;
3012 let payload = session
3013 .peer
3014 .forward_request(session.connection_id, host_connection_id, request)
3015 .await?;
3016 response.send(payload)?;
3017 Ok(())
3018}
3019
3020/// forward a project request to the host. These requests are disallowed
3021/// for guests.
3022async fn forward_mutating_project_request<T>(
3023 request: T,
3024 response: Response<T>,
3025 session: UserSession,
3026) -> Result<()>
3027where
3028 T: EntityMessage + RequestMessage,
3029{
3030 let project_id = ProjectId::from_proto(request.remote_entity_id());
3031
3032 let host_connection_id = session
3033 .db()
3034 .await
3035 .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
3036 .await?;
3037 let payload = session
3038 .peer
3039 .forward_request(session.connection_id, host_connection_id, request)
3040 .await?;
3041 response.send(payload)?;
3042 Ok(())
3043}
3044
3045/// Notify other participants that a new buffer has been created
3046async fn create_buffer_for_peer(
3047 request: proto::CreateBufferForPeer,
3048 session: Session,
3049) -> Result<()> {
3050 session
3051 .db()
3052 .await
3053 .check_user_is_project_host(
3054 ProjectId::from_proto(request.project_id),
3055 session.connection_id,
3056 )
3057 .await?;
3058 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
3059 session
3060 .peer
3061 .forward_send(session.connection_id, peer_id.into(), request)?;
3062 Ok(())
3063}
3064
3065/// Notify other participants that a buffer has been updated. This is
3066/// allowed for guests as long as the update is limited to selections.
3067async fn update_buffer(
3068 request: proto::UpdateBuffer,
3069 response: Response<proto::UpdateBuffer>,
3070 session: Session,
3071) -> Result<()> {
3072 let project_id = ProjectId::from_proto(request.project_id);
3073 let mut capability = Capability::ReadOnly;
3074
3075 for op in request.operations.iter() {
3076 match op.variant {
3077 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
3078 Some(_) => capability = Capability::ReadWrite,
3079 }
3080 }
3081
3082 let host = {
3083 let guard = session
3084 .db()
3085 .await
3086 .connections_for_buffer_update(
3087 project_id,
3088 session.principal_id(),
3089 session.connection_id,
3090 capability,
3091 )
3092 .await?;
3093
3094 let (host, guests) = &*guard;
3095
3096 broadcast(
3097 Some(session.connection_id),
3098 guests.clone(),
3099 |connection_id| {
3100 session
3101 .peer
3102 .forward_send(session.connection_id, connection_id, request.clone())
3103 },
3104 );
3105
3106 *host
3107 };
3108
3109 if host != session.connection_id {
3110 session
3111 .peer
3112 .forward_request(session.connection_id, host, request.clone())
3113 .await?;
3114 }
3115
3116 response.send(proto::Ack {})?;
3117 Ok(())
3118}
3119
/// Relay an assistant-context operation to every other connection on the project.
///
/// The required capability is derived from the operation:
/// - a buffer operation whose inner operation has no variant or is a selection
///   update needs only `ReadOnly`;
/// - any other buffer operation, a buffer operation with no inner operation, or any
///   other context operation needs `ReadWrite`;
/// - an operation with no variant needs `ReadOnly`.
///
/// # Errors
/// Fails with "invalid operation" if the message carries no operation, or if
/// `connections_for_buffer_update` rejects the sender.
async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
    let project_id = ProjectId::from_proto(message.project_id);

    let operation = message.operation.as_ref().context("invalid operation")?;
    let capability = match operation.variant.as_ref() {
        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
            if let Some(buffer_op) = buffer_op.operation.as_ref() {
                match buffer_op.variant {
                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
                        Capability::ReadOnly
                    }
                    _ => Capability::ReadWrite,
                }
            } else {
                Capability::ReadWrite
            }
        }
        Some(_) => Capability::ReadWrite,
        None => Capability::ReadOnly,
    };

    let guard = session
        .db()
        .await
        .connections_for_buffer_update(
            project_id,
            session.principal_id(),
            session.connection_id,
            capability,
        )
        .await?;

    let (host, guests) = &*guard;

    // Unlike `update_buffer`, the host is included in the broadcast. The sender's id
    // is passed as the first argument, presumably so `broadcast` skips it.
    broadcast(
        Some(session.connection_id),
        guests.iter().chain([host]).copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, message.clone())
        },
    );

    Ok(())
}
3166
3167/// Notify other participants that a project has been updated.
3168async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
3169 request: T,
3170 session: Session,
3171) -> Result<()> {
3172 let project_id = ProjectId::from_proto(request.remote_entity_id());
3173 let project_connection_ids = session
3174 .db()
3175 .await
3176 .project_connection_ids(project_id, session.connection_id, false)
3177 .await?;
3178
3179 broadcast(
3180 Some(session.connection_id),
3181 project_connection_ids.iter().copied(),
3182 |connection_id| {
3183 session
3184 .peer
3185 .forward_send(session.connection_id, connection_id, request.clone())
3186 },
3187 );
3188 Ok(())
3189}
3190
3191/// Start following another user in a call.
3192async fn follow(
3193 request: proto::Follow,
3194 response: Response<proto::Follow>,
3195 session: UserSession,
3196) -> Result<()> {
3197 let room_id = RoomId::from_proto(request.room_id);
3198 let project_id = request.project_id.map(ProjectId::from_proto);
3199 let leader_id = request
3200 .leader_id
3201 .ok_or_else(|| anyhow!("invalid leader id"))?
3202 .into();
3203 let follower_id = session.connection_id;
3204
3205 session
3206 .db()
3207 .await
3208 .check_room_participants(room_id, leader_id, session.connection_id)
3209 .await?;
3210
3211 let response_payload = session
3212 .peer
3213 .forward_request(session.connection_id, leader_id, request)
3214 .await?;
3215 response.send(response_payload)?;
3216
3217 if let Some(project_id) = project_id {
3218 let room = session
3219 .db()
3220 .await
3221 .follow(room_id, project_id, leader_id, follower_id)
3222 .await?;
3223 room_updated(&room, &session.peer);
3224 }
3225
3226 Ok(())
3227}
3228
3229/// Stop following another user in a call.
3230async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
3231 let room_id = RoomId::from_proto(request.room_id);
3232 let project_id = request.project_id.map(ProjectId::from_proto);
3233 let leader_id = request
3234 .leader_id
3235 .ok_or_else(|| anyhow!("invalid leader id"))?
3236 .into();
3237 let follower_id = session.connection_id;
3238
3239 session
3240 .db()
3241 .await
3242 .check_room_participants(room_id, leader_id, session.connection_id)
3243 .await?;
3244
3245 session
3246 .peer
3247 .forward_send(session.connection_id, leader_id, request)?;
3248
3249 if let Some(project_id) = project_id {
3250 let room = session
3251 .db()
3252 .await
3253 .unfollow(room_id, project_id, leader_id, follower_id)
3254 .await?;
3255 room_updated(&room, &session.peer);
3256 }
3257
3258 Ok(())
3259}
3260
3261/// Notify everyone following you of your current location.
3262async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
3263 let room_id = RoomId::from_proto(request.room_id);
3264 let database = session.db.lock().await;
3265
3266 let connection_ids = if let Some(project_id) = request.project_id {
3267 let project_id = ProjectId::from_proto(project_id);
3268 database
3269 .project_connection_ids(project_id, session.connection_id, true)
3270 .await?
3271 } else {
3272 database
3273 .room_connection_ids(room_id, session.connection_id)
3274 .await?
3275 };
3276
3277 // For now, don't send view update messages back to that view's current leader.
3278 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
3279 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
3280 _ => None,
3281 });
3282
3283 for connection_id in connection_ids.iter().cloned() {
3284 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
3285 session
3286 .peer
3287 .forward_send(session.connection_id, connection_id, request.clone())?;
3288 }
3289 }
3290 Ok(())
3291}
3292
3293/// Get public data about users.
3294async fn get_users(
3295 request: proto::GetUsers,
3296 response: Response<proto::GetUsers>,
3297 session: Session,
3298) -> Result<()> {
3299 let user_ids = request
3300 .user_ids
3301 .into_iter()
3302 .map(UserId::from_proto)
3303 .collect();
3304 let users = session
3305 .db()
3306 .await
3307 .get_users_by_ids(user_ids)
3308 .await?
3309 .into_iter()
3310 .map(|user| proto::User {
3311 id: user.id.to_proto(),
3312 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3313 github_login: user.github_login,
3314 })
3315 .collect();
3316 response.send(proto::UsersResponse { users })?;
3317 Ok(())
3318}
3319
3320/// Search for users (to invite) buy Github login
3321async fn fuzzy_search_users(
3322 request: proto::FuzzySearchUsers,
3323 response: Response<proto::FuzzySearchUsers>,
3324 session: UserSession,
3325) -> Result<()> {
3326 let query = request.query;
3327 let users = match query.len() {
3328 0 => vec![],
3329 1 | 2 => session
3330 .db()
3331 .await
3332 .get_user_by_github_login(&query)
3333 .await?
3334 .into_iter()
3335 .collect(),
3336 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
3337 };
3338 let users = users
3339 .into_iter()
3340 .filter(|user| user.id != session.user_id())
3341 .map(|user| proto::User {
3342 id: user.id.to_proto(),
3343 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3344 github_login: user.github_login,
3345 })
3346 .collect();
3347 response.send(proto::UsersResponse { users })?;
3348 Ok(())
3349}
3350
3351/// Send a contact request to another user.
3352async fn request_contact(
3353 request: proto::RequestContact,
3354 response: Response<proto::RequestContact>,
3355 session: UserSession,
3356) -> Result<()> {
3357 let requester_id = session.user_id();
3358 let responder_id = UserId::from_proto(request.responder_id);
3359 if requester_id == responder_id {
3360 return Err(anyhow!("cannot add yourself as a contact"))?;
3361 }
3362
3363 let notifications = session
3364 .db()
3365 .await
3366 .send_contact_request(requester_id, responder_id)
3367 .await?;
3368
3369 // Update outgoing contact requests of requester
3370 let mut update = proto::UpdateContacts::default();
3371 update.outgoing_requests.push(responder_id.to_proto());
3372 for connection_id in session
3373 .connection_pool()
3374 .await
3375 .user_connection_ids(requester_id)
3376 {
3377 session.peer.send(connection_id, update.clone())?;
3378 }
3379
3380 // Update incoming contact requests of responder
3381 let mut update = proto::UpdateContacts::default();
3382 update
3383 .incoming_requests
3384 .push(proto::IncomingContactRequest {
3385 requester_id: requester_id.to_proto(),
3386 });
3387 let connection_pool = session.connection_pool().await;
3388 for connection_id in connection_pool.user_connection_ids(responder_id) {
3389 session.peer.send(connection_id, update.clone())?;
3390 }
3391
3392 send_notifications(&connection_pool, &session.peer, notifications);
3393
3394 response.send(proto::Ack {})?;
3395 Ok(())
3396}
3397
3398/// Accept or decline a contact request
3399async fn respond_to_contact_request(
3400 request: proto::RespondToContactRequest,
3401 response: Response<proto::RespondToContactRequest>,
3402 session: UserSession,
3403) -> Result<()> {
3404 let responder_id = session.user_id();
3405 let requester_id = UserId::from_proto(request.requester_id);
3406 let db = session.db().await;
3407 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
3408 db.dismiss_contact_notification(responder_id, requester_id)
3409 .await?;
3410 } else {
3411 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
3412
3413 let notifications = db
3414 .respond_to_contact_request(responder_id, requester_id, accept)
3415 .await?;
3416 let requester_busy = db.is_user_busy(requester_id).await?;
3417 let responder_busy = db.is_user_busy(responder_id).await?;
3418
3419 let pool = session.connection_pool().await;
3420 // Update responder with new contact
3421 let mut update = proto::UpdateContacts::default();
3422 if accept {
3423 update
3424 .contacts
3425 .push(contact_for_user(requester_id, requester_busy, &pool));
3426 }
3427 update
3428 .remove_incoming_requests
3429 .push(requester_id.to_proto());
3430 for connection_id in pool.user_connection_ids(responder_id) {
3431 session.peer.send(connection_id, update.clone())?;
3432 }
3433
3434 // Update requester with new contact
3435 let mut update = proto::UpdateContacts::default();
3436 if accept {
3437 update
3438 .contacts
3439 .push(contact_for_user(responder_id, responder_busy, &pool));
3440 }
3441 update
3442 .remove_outgoing_requests
3443 .push(responder_id.to_proto());
3444
3445 for connection_id in pool.user_connection_ids(requester_id) {
3446 session.peer.send(connection_id, update.clone())?;
3447 }
3448
3449 send_notifications(&pool, &session.peer, notifications);
3450 }
3451
3452 response.send(proto::Ack {})?;
3453 Ok(())
3454}
3455
3456/// Remove a contact.
3457async fn remove_contact(
3458 request: proto::RemoveContact,
3459 response: Response<proto::RemoveContact>,
3460 session: UserSession,
3461) -> Result<()> {
3462 let requester_id = session.user_id();
3463 let responder_id = UserId::from_proto(request.user_id);
3464 let db = session.db().await;
3465 let (contact_accepted, deleted_notification_id) =
3466 db.remove_contact(requester_id, responder_id).await?;
3467
3468 let pool = session.connection_pool().await;
3469 // Update outgoing contact requests of requester
3470 let mut update = proto::UpdateContacts::default();
3471 if contact_accepted {
3472 update.remove_contacts.push(responder_id.to_proto());
3473 } else {
3474 update
3475 .remove_outgoing_requests
3476 .push(responder_id.to_proto());
3477 }
3478 for connection_id in pool.user_connection_ids(requester_id) {
3479 session.peer.send(connection_id, update.clone())?;
3480 }
3481
3482 // Update incoming contact requests of responder
3483 let mut update = proto::UpdateContacts::default();
3484 if contact_accepted {
3485 update.remove_contacts.push(requester_id.to_proto());
3486 } else {
3487 update
3488 .remove_incoming_requests
3489 .push(requester_id.to_proto());
3490 }
3491 for connection_id in pool.user_connection_ids(responder_id) {
3492 session.peer.send(connection_id, update.clone())?;
3493 if let Some(notification_id) = deleted_notification_id {
3494 session.peer.send(
3495 connection_id,
3496 proto::DeleteNotification {
3497 notification_id: notification_id.to_proto(),
3498 },
3499 )?;
3500 }
3501 }
3502
3503 response.send(proto::Ack {})?;
3504 Ok(())
3505}
3506
3507fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
3508 version.0.minor() < 139
3509}
3510
3511async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
3512 let plan = session.current_plan(session.db().await).await?;
3513
3514 session
3515 .peer
3516 .send(
3517 session.connection_id,
3518 proto::UpdateUserPlan { plan: plan.into() },
3519 )
3520 .trace_err();
3521
3522 Ok(())
3523}
3524
3525async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
3526 subscribe_user_to_channels(
3527 session.user_id().ok_or_else(|| anyhow!("must be a user"))?,
3528 &session,
3529 )
3530 .await?;
3531 Ok(())
3532}
3533
3534async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
3535 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3536 let mut pool = session.connection_pool().await;
3537 for membership in &channels_for_user.channel_memberships {
3538 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3539 }
3540 session.peer.send(
3541 session.connection_id,
3542 build_update_user_channels(&channels_for_user),
3543 )?;
3544 session.peer.send(
3545 session.connection_id,
3546 build_channels_update(channels_for_user),
3547 )?;
3548 Ok(())
3549}
3550
3551/// Creates a new channel.
3552async fn create_channel(
3553 request: proto::CreateChannel,
3554 response: Response<proto::CreateChannel>,
3555 session: UserSession,
3556) -> Result<()> {
3557 let db = session.db().await;
3558
3559 let parent_id = request.parent_id.map(ChannelId::from_proto);
3560 let (channel, membership) = db
3561 .create_channel(&request.name, parent_id, session.user_id())
3562 .await?;
3563
3564 let root_id = channel.root_id();
3565 let channel = Channel::from_model(channel);
3566
3567 response.send(proto::CreateChannelResponse {
3568 channel: Some(channel.to_proto()),
3569 parent_id: request.parent_id,
3570 })?;
3571
3572 let mut connection_pool = session.connection_pool().await;
3573 if let Some(membership) = membership {
3574 connection_pool.subscribe_to_channel(
3575 membership.user_id,
3576 membership.channel_id,
3577 membership.role,
3578 );
3579 let update = proto::UpdateUserChannels {
3580 channel_memberships: vec![proto::ChannelMembership {
3581 channel_id: membership.channel_id.to_proto(),
3582 role: membership.role.into(),
3583 }],
3584 ..Default::default()
3585 };
3586 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3587 session.peer.send(connection_id, update.clone())?;
3588 }
3589 }
3590
3591 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3592 if !role.can_see_channel(channel.visibility) {
3593 continue;
3594 }
3595
3596 let update = proto::UpdateChannels {
3597 channels: vec![channel.to_proto()],
3598 ..Default::default()
3599 };
3600 session.peer.send(connection_id, update.clone())?;
3601 }
3602
3603 Ok(())
3604}
3605
3606/// Delete a channel
3607async fn delete_channel(
3608 request: proto::DeleteChannel,
3609 response: Response<proto::DeleteChannel>,
3610 session: UserSession,
3611) -> Result<()> {
3612 let db = session.db().await;
3613
3614 let channel_id = request.channel_id;
3615 let (root_channel, removed_channels) = db
3616 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3617 .await?;
3618 response.send(proto::Ack {})?;
3619
3620 // Notify members of removed channels
3621 let mut update = proto::UpdateChannels::default();
3622 update
3623 .delete_channels
3624 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3625
3626 let connection_pool = session.connection_pool().await;
3627 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3628 session.peer.send(connection_id, update.clone())?;
3629 }
3630
3631 Ok(())
3632}
3633
3634/// Invite someone to join a channel.
3635async fn invite_channel_member(
3636 request: proto::InviteChannelMember,
3637 response: Response<proto::InviteChannelMember>,
3638 session: UserSession,
3639) -> Result<()> {
3640 let db = session.db().await;
3641 let channel_id = ChannelId::from_proto(request.channel_id);
3642 let invitee_id = UserId::from_proto(request.user_id);
3643 let InviteMemberResult {
3644 channel,
3645 notifications,
3646 } = db
3647 .invite_channel_member(
3648 channel_id,
3649 invitee_id,
3650 session.user_id(),
3651 request.role().into(),
3652 )
3653 .await?;
3654
3655 let update = proto::UpdateChannels {
3656 channel_invitations: vec![channel.to_proto()],
3657 ..Default::default()
3658 };
3659
3660 let connection_pool = session.connection_pool().await;
3661 for connection_id in connection_pool.user_connection_ids(invitee_id) {
3662 session.peer.send(connection_id, update.clone())?;
3663 }
3664
3665 send_notifications(&connection_pool, &session.peer, notifications);
3666
3667 response.send(proto::Ack {})?;
3668 Ok(())
3669}
3670
3671/// remove someone from a channel
3672async fn remove_channel_member(
3673 request: proto::RemoveChannelMember,
3674 response: Response<proto::RemoveChannelMember>,
3675 session: UserSession,
3676) -> Result<()> {
3677 let db = session.db().await;
3678 let channel_id = ChannelId::from_proto(request.channel_id);
3679 let member_id = UserId::from_proto(request.user_id);
3680
3681 let RemoveChannelMemberResult {
3682 membership_update,
3683 notification_id,
3684 } = db
3685 .remove_channel_member(channel_id, member_id, session.user_id())
3686 .await?;
3687
3688 let mut connection_pool = session.connection_pool().await;
3689 notify_membership_updated(
3690 &mut connection_pool,
3691 membership_update,
3692 member_id,
3693 &session.peer,
3694 );
3695 for connection_id in connection_pool.user_connection_ids(member_id) {
3696 if let Some(notification_id) = notification_id {
3697 session
3698 .peer
3699 .send(
3700 connection_id,
3701 proto::DeleteNotification {
3702 notification_id: notification_id.to_proto(),
3703 },
3704 )
3705 .trace_err();
3706 }
3707 }
3708
3709 response.send(proto::Ack {})?;
3710 Ok(())
3711}
3712
3713/// Toggle the channel between public and private.
3714/// Care is taken to maintain the invariant that public channels only descend from public channels,
3715/// (though members-only channels can appear at any point in the hierarchy).
3716async fn set_channel_visibility(
3717 request: proto::SetChannelVisibility,
3718 response: Response<proto::SetChannelVisibility>,
3719 session: UserSession,
3720) -> Result<()> {
3721 let db = session.db().await;
3722 let channel_id = ChannelId::from_proto(request.channel_id);
3723 let visibility = request.visibility().into();
3724
3725 let channel_model = db
3726 .set_channel_visibility(channel_id, visibility, session.user_id())
3727 .await?;
3728 let root_id = channel_model.root_id();
3729 let channel = Channel::from_model(channel_model);
3730
3731 let mut connection_pool = session.connection_pool().await;
3732 for (user_id, role) in connection_pool
3733 .channel_user_ids(root_id)
3734 .collect::<Vec<_>>()
3735 .into_iter()
3736 {
3737 let update = if role.can_see_channel(channel.visibility) {
3738 connection_pool.subscribe_to_channel(user_id, channel_id, role);
3739 proto::UpdateChannels {
3740 channels: vec![channel.to_proto()],
3741 ..Default::default()
3742 }
3743 } else {
3744 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3745 proto::UpdateChannels {
3746 delete_channels: vec![channel.id.to_proto()],
3747 ..Default::default()
3748 }
3749 };
3750
3751 for connection_id in connection_pool.user_connection_ids(user_id) {
3752 session.peer.send(connection_id, update.clone())?;
3753 }
3754 }
3755
3756 response.send(proto::Ack {})?;
3757 Ok(())
3758}
3759
3760/// Alter the role for a user in the channel.
3761async fn set_channel_member_role(
3762 request: proto::SetChannelMemberRole,
3763 response: Response<proto::SetChannelMemberRole>,
3764 session: UserSession,
3765) -> Result<()> {
3766 let db = session.db().await;
3767 let channel_id = ChannelId::from_proto(request.channel_id);
3768 let member_id = UserId::from_proto(request.user_id);
3769 let result = db
3770 .set_channel_member_role(
3771 channel_id,
3772 session.user_id(),
3773 member_id,
3774 request.role().into(),
3775 )
3776 .await?;
3777
3778 match result {
3779 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3780 let mut connection_pool = session.connection_pool().await;
3781 notify_membership_updated(
3782 &mut connection_pool,
3783 membership_update,
3784 member_id,
3785 &session.peer,
3786 )
3787 }
3788 db::SetMemberRoleResult::InviteUpdated(channel) => {
3789 let update = proto::UpdateChannels {
3790 channel_invitations: vec![channel.to_proto()],
3791 ..Default::default()
3792 };
3793
3794 for connection_id in session
3795 .connection_pool()
3796 .await
3797 .user_connection_ids(member_id)
3798 {
3799 session.peer.send(connection_id, update.clone())?;
3800 }
3801 }
3802 }
3803
3804 response.send(proto::Ack {})?;
3805 Ok(())
3806}
3807
3808/// Change the name of a channel
3809async fn rename_channel(
3810 request: proto::RenameChannel,
3811 response: Response<proto::RenameChannel>,
3812 session: UserSession,
3813) -> Result<()> {
3814 let db = session.db().await;
3815 let channel_id = ChannelId::from_proto(request.channel_id);
3816 let channel_model = db
3817 .rename_channel(channel_id, session.user_id(), &request.name)
3818 .await?;
3819 let root_id = channel_model.root_id();
3820 let channel = Channel::from_model(channel_model);
3821
3822 response.send(proto::RenameChannelResponse {
3823 channel: Some(channel.to_proto()),
3824 })?;
3825
3826 let connection_pool = session.connection_pool().await;
3827 let update = proto::UpdateChannels {
3828 channels: vec![channel.to_proto()],
3829 ..Default::default()
3830 };
3831 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3832 if role.can_see_channel(channel.visibility) {
3833 session.peer.send(connection_id, update.clone())?;
3834 }
3835 }
3836
3837 Ok(())
3838}
3839
3840/// Move a channel to a new parent.
3841async fn move_channel(
3842 request: proto::MoveChannel,
3843 response: Response<proto::MoveChannel>,
3844 session: UserSession,
3845) -> Result<()> {
3846 let channel_id = ChannelId::from_proto(request.channel_id);
3847 let to = ChannelId::from_proto(request.to);
3848
3849 let (root_id, channels) = session
3850 .db()
3851 .await
3852 .move_channel(channel_id, to, session.user_id())
3853 .await?;
3854
3855 let connection_pool = session.connection_pool().await;
3856 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3857 let channels = channels
3858 .iter()
3859 .filter_map(|channel| {
3860 if role.can_see_channel(channel.visibility) {
3861 Some(channel.to_proto())
3862 } else {
3863 None
3864 }
3865 })
3866 .collect::<Vec<_>>();
3867 if channels.is_empty() {
3868 continue;
3869 }
3870
3871 let update = proto::UpdateChannels {
3872 channels,
3873 ..Default::default()
3874 };
3875
3876 session.peer.send(connection_id, update.clone())?;
3877 }
3878
3879 response.send(Ack {})?;
3880 Ok(())
3881}
3882
3883/// Get the list of channel members
3884async fn get_channel_members(
3885 request: proto::GetChannelMembers,
3886 response: Response<proto::GetChannelMembers>,
3887 session: UserSession,
3888) -> Result<()> {
3889 let db = session.db().await;
3890 let channel_id = ChannelId::from_proto(request.channel_id);
3891 let limit = if request.limit == 0 {
3892 u16::MAX as u64
3893 } else {
3894 request.limit
3895 };
3896 let (members, users) = db
3897 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3898 .await?;
3899 response.send(proto::GetChannelMembersResponse { members, users })?;
3900 Ok(())
3901}
3902
3903/// Accept or decline a channel invitation.
3904async fn respond_to_channel_invite(
3905 request: proto::RespondToChannelInvite,
3906 response: Response<proto::RespondToChannelInvite>,
3907 session: UserSession,
3908) -> Result<()> {
3909 let db = session.db().await;
3910 let channel_id = ChannelId::from_proto(request.channel_id);
3911 let RespondToChannelInvite {
3912 membership_update,
3913 notifications,
3914 } = db
3915 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3916 .await?;
3917
3918 let mut connection_pool = session.connection_pool().await;
3919 if let Some(membership_update) = membership_update {
3920 notify_membership_updated(
3921 &mut connection_pool,
3922 membership_update,
3923 session.user_id(),
3924 &session.peer,
3925 );
3926 } else {
3927 let update = proto::UpdateChannels {
3928 remove_channel_invitations: vec![channel_id.to_proto()],
3929 ..Default::default()
3930 };
3931
3932 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3933 session.peer.send(connection_id, update.clone())?;
3934 }
3935 };
3936
3937 send_notifications(&connection_pool, &session.peer, notifications);
3938
3939 response.send(proto::Ack {})?;
3940
3941 Ok(())
3942}
3943
3944/// Join the channels' room
3945async fn join_channel(
3946 request: proto::JoinChannel,
3947 response: Response<proto::JoinChannel>,
3948 session: UserSession,
3949) -> Result<()> {
3950 let channel_id = ChannelId::from_proto(request.channel_id);
3951 join_channel_internal(channel_id, Box::new(response), session).await
3952}
3953
/// A response handle that can complete a channel join with a
/// `proto::JoinRoomResponse`.
///
/// `join_channel_internal` takes any implementor, so the same join logic can
/// reply through different request types. This file implements it for the
/// `JoinChannel` and `JoinRoom` response handles.
trait JoinChannelInternalResponse {
    /// Send `result` as the reply, consuming the handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
/// Lets a `JoinChannel` request be answered by `join_channel_internal`.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Path-qualified call: `Type::send` resolves to the inherent
        // `Response::send`, not back into this trait method.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
/// Lets a `JoinRoom` request be answered by `join_channel_internal`.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Path-qualified call: `Type::send` resolves to the inherent
        // `Response::send`, not back into this trait method.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3967
/// Shared implementation for joining the room attached to `channel_id`.
///
/// Called by `join_channel` (and by anything else holding a
/// `JoinChannelInternalResponse`). In order, it:
/// 1. leaves any stale room connection the database reports for this user;
/// 2. joins via `db.join_channel`, which yields the room, an optional
///    membership update, and the user's `ChannelRole`;
/// 3. if a LiveKit client is configured, mints a token:
///    - `guest_token` with `can_publish: false` for `ChannelRole::Guest`;
///    - `room_token` with `can_publish: true` otherwise;
///    - a token error is logged via `trace_err` and leaves
///      `live_kit_connection_info` as `None` rather than failing the join;
/// 4. replies with the room, channel id and LiveKit info, then propagates any
///    membership update and the room update;
/// 5. broadcasts the channel update and refreshes the user's contacts.
///
/// # Errors
/// Returns an error if any of the following fails:
/// - a database call;
/// - leaving the stale room;
/// - the reply;
/// - the contacts update.
///
/// Also errors if the joined room has no channel.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: UserSession,
) -> Result<()> {
    // Scoped block: the `db` and `connection_pool` guards acquired in here are
    // dropped when it ends, before the connection pool is locked again for
    // `channel_updated` below.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database guard while leaving the stale room, then
            // re-acquire it. NOTE(review): presumably `leave_room_for_session`
            // needs the database lock itself — confirm.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured or minting a token
        // failed (`trace_err()?` logs the error and returns `None` from the
        // closure).
        let live_kit_connection_info =
            session
                .app_state
                .live_kit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.live_kit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.live_kit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
4063
4064/// Start editing the channel notes
4065async fn join_channel_buffer(
4066 request: proto::JoinChannelBuffer,
4067 response: Response<proto::JoinChannelBuffer>,
4068 session: UserSession,
4069) -> Result<()> {
4070 let db = session.db().await;
4071 let channel_id = ChannelId::from_proto(request.channel_id);
4072
4073 let open_response = db
4074 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
4075 .await?;
4076
4077 let collaborators = open_response.collaborators.clone();
4078 response.send(open_response)?;
4079
4080 let update = UpdateChannelBufferCollaborators {
4081 channel_id: channel_id.to_proto(),
4082 collaborators: collaborators.clone(),
4083 };
4084 channel_buffer_updated(
4085 session.connection_id,
4086 collaborators
4087 .iter()
4088 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
4089 &update,
4090 &session.peer,
4091 );
4092
4093 Ok(())
4094}
4095
4096/// Edit the channel notes
4097async fn update_channel_buffer(
4098 request: proto::UpdateChannelBuffer,
4099 session: UserSession,
4100) -> Result<()> {
4101 let db = session.db().await;
4102 let channel_id = ChannelId::from_proto(request.channel_id);
4103
4104 let (collaborators, epoch, version) = db
4105 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
4106 .await?;
4107
4108 channel_buffer_updated(
4109 session.connection_id,
4110 collaborators.clone(),
4111 &proto::UpdateChannelBuffer {
4112 channel_id: channel_id.to_proto(),
4113 operations: request.operations,
4114 },
4115 &session.peer,
4116 );
4117
4118 let pool = &*session.connection_pool().await;
4119
4120 let non_collaborators =
4121 pool.channel_connection_ids(channel_id)
4122 .filter_map(|(connection_id, _)| {
4123 if collaborators.contains(&connection_id) {
4124 None
4125 } else {
4126 Some(connection_id)
4127 }
4128 });
4129
4130 broadcast(None, non_collaborators, |peer_id| {
4131 session.peer.send(
4132 peer_id,
4133 proto::UpdateChannels {
4134 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
4135 channel_id: channel_id.to_proto(),
4136 epoch: epoch as u64,
4137 version: version.clone(),
4138 }],
4139 ..Default::default()
4140 },
4141 )
4142 });
4143
4144 Ok(())
4145}
4146
4147/// Rejoin the channel notes after a connection blip
4148async fn rejoin_channel_buffers(
4149 request: proto::RejoinChannelBuffers,
4150 response: Response<proto::RejoinChannelBuffers>,
4151 session: UserSession,
4152) -> Result<()> {
4153 let db = session.db().await;
4154 let buffers = db
4155 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
4156 .await?;
4157
4158 for rejoined_buffer in &buffers {
4159 let collaborators_to_notify = rejoined_buffer
4160 .buffer
4161 .collaborators
4162 .iter()
4163 .filter_map(|c| Some(c.peer_id?.into()));
4164 channel_buffer_updated(
4165 session.connection_id,
4166 collaborators_to_notify,
4167 &proto::UpdateChannelBufferCollaborators {
4168 channel_id: rejoined_buffer.buffer.channel_id,
4169 collaborators: rejoined_buffer.buffer.collaborators.clone(),
4170 },
4171 &session.peer,
4172 );
4173 }
4174
4175 response.send(proto::RejoinChannelBuffersResponse {
4176 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
4177 })?;
4178
4179 Ok(())
4180}
4181
4182/// Stop editing the channel notes
4183async fn leave_channel_buffer(
4184 request: proto::LeaveChannelBuffer,
4185 response: Response<proto::LeaveChannelBuffer>,
4186 session: UserSession,
4187) -> Result<()> {
4188 let db = session.db().await;
4189 let channel_id = ChannelId::from_proto(request.channel_id);
4190
4191 let left_buffer = db
4192 .leave_channel_buffer(channel_id, session.connection_id)
4193 .await?;
4194
4195 response.send(Ack {})?;
4196
4197 channel_buffer_updated(
4198 session.connection_id,
4199 left_buffer.connections,
4200 &proto::UpdateChannelBufferCollaborators {
4201 channel_id: channel_id.to_proto(),
4202 collaborators: left_buffer.collaborators,
4203 },
4204 &session.peer,
4205 );
4206
4207 Ok(())
4208}
4209
4210fn channel_buffer_updated<T: EnvelopedMessage>(
4211 sender_id: ConnectionId,
4212 collaborators: impl IntoIterator<Item = ConnectionId>,
4213 message: &T,
4214 peer: &Peer,
4215) {
4216 broadcast(Some(sender_id), collaborators, |peer_id| {
4217 peer.send(peer_id, message.clone())
4218 });
4219}
4220
4221fn send_notifications(
4222 connection_pool: &ConnectionPool,
4223 peer: &Peer,
4224 notifications: db::NotificationBatch,
4225) {
4226 for (user_id, notification) in notifications {
4227 for connection_id in connection_pool.user_connection_ids(user_id) {
4228 if let Err(error) = peer.send(
4229 connection_id,
4230 proto::AddNotification {
4231 notification: Some(notification.clone()),
4232 },
4233 ) {
4234 tracing::error!(
4235 "failed to send notification to {:?} {}",
4236 connection_id,
4237 error
4238 );
4239 }
4240 }
4241 }
4242}
4243
/// Send a message to the channel
///
/// The body is trimmed and rejected if it is empty or longer than
/// `MAX_MESSAGE_LEN` bytes; a `nonce` is also required. The message is stored
/// with `create_channel_message`, after which:
/// - the participant connections returned by the database receive
///   `ChannelMessageSent` (through `broadcast`, with this connection as the
///   sender);
/// - the sender receives `SendChannelMessageResponse`;
/// - connections from `channel_connection_ids(channel_id)` that are not
///   participants receive an `UpdateChannels` carrying only the latest message
///   id;
/// - any notifications produced by the database are delivered.
async fn send_channel_message(
    request: proto::SendChannelMessage,
    response: Response<proto::SendChannelMessage>,
    session: UserSession,
) -> Result<()> {
    // Validate the message body.
    let body = request.body.trim().to_string();
    if body.len() > MAX_MESSAGE_LEN {
        return Err(anyhow!("message is too long"))?;
    }
    if body.is_empty() {
        return Err(anyhow!("message can't be blank"))?;
    }

    // TODO: adjust mentions if body is trimmed
    // NOTE(review): `request.mentions` is stored and echoed unchanged below. If
    // its ranges are offsets into the untrimmed body, leading whitespace removed
    // by `trim` shifts them — confirm.

    let timestamp = OffsetDateTime::now_utc();
    let nonce = request
        .nonce
        .ok_or_else(|| anyhow!("nonce can't be blank"))?;

    let channel_id = ChannelId::from_proto(request.channel_id);
    let CreatedChannelMessage {
        message_id,
        participant_connection_ids,
        notifications,
    } = session
        .db()
        .await
        .create_channel_message(
            channel_id,
            session.user_id(),
            &body,
            &request.mentions,
            timestamp,
            nonce.clone().into(),
            request.reply_to_message_id.map(MessageId::from_proto),
        )
        .await?;

    let message = proto::ChannelMessage {
        sender_id: session.user_id().to_proto(),
        id: message_id.to_proto(),
        body,
        mentions: request.mentions,
        timestamp: timestamp.unix_timestamp() as u64,
        nonce: Some(nonce),
        reply_to_message_id: request.reply_to_message_id,
        edited_at: None,
    };
    broadcast(
        Some(session.connection_id),
        participant_connection_ids.clone(),
        |connection| {
            session.peer.send(
                connection,
                proto::ChannelMessageSent {
                    channel_id: channel_id.to_proto(),
                    message: Some(message.clone()),
                },
            )
        },
    );
    response.send(proto::SendChannelMessageResponse {
        message: Some(message),
    })?;

    // Channel subscribers that are not participants only learn the id of the
    // latest message, not its contents.
    let pool = &*session.connection_pool().await;
    let non_participants =
        pool.channel_connection_ids(channel_id)
            .filter_map(|(connection_id, _)| {
                if participant_connection_ids.contains(&connection_id) {
                    None
                } else {
                    Some(connection_id)
                }
            });
    broadcast(None, non_participants, |peer_id| {
        session.peer.send(
            peer_id,
            proto::UpdateChannels {
                latest_channel_message_ids: vec![proto::ChannelMessageId {
                    channel_id: channel_id.to_proto(),
                    message_id: message_id.to_proto(),
                }],
                ..Default::default()
            },
        )
    });
    send_notifications(pool, &session.peer, notifications);

    Ok(())
}
4338
4339/// Delete a channel message
4340async fn remove_channel_message(
4341 request: proto::RemoveChannelMessage,
4342 response: Response<proto::RemoveChannelMessage>,
4343 session: UserSession,
4344) -> Result<()> {
4345 let channel_id = ChannelId::from_proto(request.channel_id);
4346 let message_id = MessageId::from_proto(request.message_id);
4347 let (connection_ids, existing_notification_ids) = session
4348 .db()
4349 .await
4350 .remove_channel_message(channel_id, message_id, session.user_id())
4351 .await?;
4352
4353 broadcast(
4354 Some(session.connection_id),
4355 connection_ids,
4356 move |connection| {
4357 session.peer.send(connection, request.clone())?;
4358
4359 for notification_id in &existing_notification_ids {
4360 session.peer.send(
4361 connection,
4362 proto::DeleteNotification {
4363 notification_id: (*notification_id).to_proto(),
4364 },
4365 )?;
4366 }
4367
4368 Ok(())
4369 },
4370 );
4371 response.send(proto::Ack {})?;
4372 Ok(())
4373}
4374
4375async fn update_channel_message(
4376 request: proto::UpdateChannelMessage,
4377 response: Response<proto::UpdateChannelMessage>,
4378 session: UserSession,
4379) -> Result<()> {
4380 let channel_id = ChannelId::from_proto(request.channel_id);
4381 let message_id = MessageId::from_proto(request.message_id);
4382 let updated_at = OffsetDateTime::now_utc();
4383 let UpdatedChannelMessage {
4384 message_id,
4385 participant_connection_ids,
4386 notifications,
4387 reply_to_message_id,
4388 timestamp,
4389 deleted_mention_notification_ids,
4390 updated_mention_notifications,
4391 } = session
4392 .db()
4393 .await
4394 .update_channel_message(
4395 channel_id,
4396 message_id,
4397 session.user_id(),
4398 request.body.as_str(),
4399 &request.mentions,
4400 updated_at,
4401 )
4402 .await?;
4403
4404 let nonce = request
4405 .nonce
4406 .clone()
4407 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4408
4409 let message = proto::ChannelMessage {
4410 sender_id: session.user_id().to_proto(),
4411 id: message_id.to_proto(),
4412 body: request.body.clone(),
4413 mentions: request.mentions.clone(),
4414 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
4415 nonce: Some(nonce),
4416 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
4417 edited_at: Some(updated_at.unix_timestamp() as u64),
4418 };
4419
4420 response.send(proto::Ack {})?;
4421
4422 let pool = &*session.connection_pool().await;
4423 broadcast(
4424 Some(session.connection_id),
4425 participant_connection_ids,
4426 |connection| {
4427 session.peer.send(
4428 connection,
4429 proto::ChannelMessageUpdate {
4430 channel_id: channel_id.to_proto(),
4431 message: Some(message.clone()),
4432 },
4433 )?;
4434
4435 for notification_id in &deleted_mention_notification_ids {
4436 session.peer.send(
4437 connection,
4438 proto::DeleteNotification {
4439 notification_id: (*notification_id).to_proto(),
4440 },
4441 )?;
4442 }
4443
4444 for notification in &updated_mention_notifications {
4445 session.peer.send(
4446 connection,
4447 proto::UpdateNotification {
4448 notification: Some(notification.clone()),
4449 },
4450 )?;
4451 }
4452
4453 Ok(())
4454 },
4455 );
4456
4457 send_notifications(pool, &session.peer, notifications);
4458
4459 Ok(())
4460}
4461
4462/// Mark a channel message as read
4463async fn acknowledge_channel_message(
4464 request: proto::AckChannelMessage,
4465 session: UserSession,
4466) -> Result<()> {
4467 let channel_id = ChannelId::from_proto(request.channel_id);
4468 let message_id = MessageId::from_proto(request.message_id);
4469 let notifications = session
4470 .db()
4471 .await
4472 .observe_channel_message(channel_id, session.user_id(), message_id)
4473 .await?;
4474 send_notifications(
4475 &*session.connection_pool().await,
4476 &session.peer,
4477 notifications,
4478 );
4479 Ok(())
4480}
4481
4482/// Mark a buffer version as synced
4483async fn acknowledge_buffer_version(
4484 request: proto::AckBufferOperation,
4485 session: UserSession,
4486) -> Result<()> {
4487 let buffer_id = BufferId::from_proto(request.buffer_id);
4488 session
4489 .db()
4490 .await
4491 .observe_buffer_version(
4492 buffer_id,
4493 session.user_id(),
4494 request.epoch as i32,
4495 &request.version,
4496 )
4497 .await?;
4498 Ok(())
4499}
4500
4501async fn count_language_model_tokens(
4502 request: proto::CountLanguageModelTokens,
4503 response: Response<proto::CountLanguageModelTokens>,
4504 session: Session,
4505 config: &Config,
4506) -> Result<()> {
4507 let Some(session) = session.for_user() else {
4508 return Err(anyhow!("user not found"))?;
4509 };
4510 authorize_access_to_legacy_llm_endpoints(&session).await?;
4511
4512 let rate_limit: Box<dyn RateLimit> = match session.current_plan(session.db().await).await? {
4513 proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
4514 proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
4515 };
4516
4517 session
4518 .app_state
4519 .rate_limiter
4520 .check(&*rate_limit, session.user_id())
4521 .await?;
4522
4523 let result = match proto::LanguageModelProvider::from_i32(request.provider) {
4524 Some(proto::LanguageModelProvider::Google) => {
4525 let api_key = config
4526 .google_ai_api_key
4527 .as_ref()
4528 .context("no Google AI API key configured on the server")?;
4529 google_ai::count_tokens(
4530 session.http_client.as_ref(),
4531 google_ai::API_URL,
4532 api_key,
4533 serde_json::from_str(&request.request)?,
4534 None,
4535 )
4536 .await?
4537 }
4538 _ => return Err(anyhow!("unsupported provider"))?,
4539 };
4540
4541 response.send(proto::CountLanguageModelTokensResponse {
4542 token_count: result.total_tokens as u32,
4543 })?;
4544
4545 Ok(())
4546}
4547
4548struct ZedProCountLanguageModelTokensRateLimit;
4549
4550impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
4551 fn capacity(&self) -> usize {
4552 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
4553 .ok()
4554 .and_then(|v| v.parse().ok())
4555 .unwrap_or(600) // Picked arbitrarily
4556 }
4557
4558 fn refill_duration(&self) -> chrono::Duration {
4559 chrono::Duration::hours(1)
4560 }
4561
4562 fn db_name(&self) -> &'static str {
4563 "zed-pro:count-language-model-tokens"
4564 }
4565}
4566
4567struct FreeCountLanguageModelTokensRateLimit;
4568
4569impl RateLimit for FreeCountLanguageModelTokensRateLimit {
4570 fn capacity(&self) -> usize {
4571 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
4572 .ok()
4573 .and_then(|v| v.parse().ok())
4574 .unwrap_or(600 / 10) // Picked arbitrarily
4575 }
4576
4577 fn refill_duration(&self) -> chrono::Duration {
4578 chrono::Duration::hours(1)
4579 }
4580
4581 fn db_name(&self) -> &'static str {
4582 "free:count-language-model-tokens"
4583 }
4584}
4585
4586struct ZedProComputeEmbeddingsRateLimit;
4587
4588impl RateLimit for ZedProComputeEmbeddingsRateLimit {
4589 fn capacity(&self) -> usize {
4590 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4591 .ok()
4592 .and_then(|v| v.parse().ok())
4593 .unwrap_or(5000) // Picked arbitrarily
4594 }
4595
4596 fn refill_duration(&self) -> chrono::Duration {
4597 chrono::Duration::hours(1)
4598 }
4599
4600 fn db_name(&self) -> &'static str {
4601 "zed-pro:compute-embeddings"
4602 }
4603}
4604
4605struct FreeComputeEmbeddingsRateLimit;
4606
4607impl RateLimit for FreeComputeEmbeddingsRateLimit {
4608 fn capacity(&self) -> usize {
4609 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
4610 .ok()
4611 .and_then(|v| v.parse().ok())
4612 .unwrap_or(5000 / 10) // Picked arbitrarily
4613 }
4614
4615 fn refill_duration(&self) -> chrono::Duration {
4616 chrono::Duration::hours(1)
4617 }
4618
4619 fn db_name(&self) -> &'static str {
4620 "free:compute-embeddings"
4621 }
4622}
4623
4624async fn compute_embeddings(
4625 request: proto::ComputeEmbeddings,
4626 response: Response<proto::ComputeEmbeddings>,
4627 session: UserSession,
4628 api_key: Option<Arc<str>>,
4629) -> Result<()> {
4630 let api_key = api_key.context("no OpenAI API key configured on the server")?;
4631 authorize_access_to_legacy_llm_endpoints(&session).await?;
4632
4633 let rate_limit: Box<dyn RateLimit> = match session.current_plan(session.db().await).await? {
4634 proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
4635 proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
4636 };
4637
4638 session
4639 .app_state
4640 .rate_limiter
4641 .check(&*rate_limit, session.user_id())
4642 .await?;
4643
4644 let embeddings = match request.model.as_str() {
4645 "openai/text-embedding-3-small" => {
4646 open_ai::embed(
4647 session.http_client.as_ref(),
4648 OPEN_AI_API_URL,
4649 &api_key,
4650 OpenAiEmbeddingModel::TextEmbedding3Small,
4651 request.texts.iter().map(|text| text.as_str()),
4652 )
4653 .await?
4654 }
4655 provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
4656 };
4657
4658 let embeddings = request
4659 .texts
4660 .iter()
4661 .map(|text| {
4662 let mut hasher = sha2::Sha256::new();
4663 hasher.update(text.as_bytes());
4664 let result = hasher.finalize();
4665 result.to_vec()
4666 })
4667 .zip(
4668 embeddings
4669 .data
4670 .into_iter()
4671 .map(|embedding| embedding.embedding),
4672 )
4673 .collect::<HashMap<_, _>>();
4674
4675 let db = session.db().await;
4676 db.save_embeddings(&request.model, &embeddings)
4677 .await
4678 .context("failed to save embeddings")
4679 .trace_err();
4680
4681 response.send(proto::ComputeEmbeddingsResponse {
4682 embeddings: embeddings
4683 .into_iter()
4684 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4685 .collect(),
4686 })?;
4687 Ok(())
4688}
4689
4690async fn get_cached_embeddings(
4691 request: proto::GetCachedEmbeddings,
4692 response: Response<proto::GetCachedEmbeddings>,
4693 session: UserSession,
4694) -> Result<()> {
4695 authorize_access_to_legacy_llm_endpoints(&session).await?;
4696
4697 let db = session.db().await;
4698 let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
4699
4700 response.send(proto::GetCachedEmbeddingsResponse {
4701 embeddings: embeddings
4702 .into_iter()
4703 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4704 .collect(),
4705 })?;
4706 Ok(())
4707}
4708
4709/// This is leftover from before the LLM service.
4710///
4711/// The endpoints protected by this check will be moved there eventually.
4712async fn authorize_access_to_legacy_llm_endpoints(session: &UserSession) -> Result<(), Error> {
4713 if session.is_staff() {
4714 Ok(())
4715 } else {
4716 Err(anyhow!("permission denied"))?
4717 }
4718}
4719
4720/// Get a Supermaven API key for the user
4721async fn get_supermaven_api_key(
4722 _request: proto::GetSupermavenApiKey,
4723 response: Response<proto::GetSupermavenApiKey>,
4724 session: UserSession,
4725) -> Result<()> {
4726 let user_id: String = session.user_id().to_string();
4727 if !session.is_staff() {
4728 return Err(anyhow!("supermaven not enabled for this account"))?;
4729 }
4730
4731 let email = session
4732 .email()
4733 .ok_or_else(|| anyhow!("user must have an email"))?;
4734
4735 let supermaven_admin_api = session
4736 .supermaven_client
4737 .as_ref()
4738 .ok_or_else(|| anyhow!("supermaven not configured"))?;
4739
4740 let result = supermaven_admin_api
4741 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4742 .await?;
4743
4744 response.send(proto::GetSupermavenApiKeyResponse {
4745 api_key: result.api_key,
4746 })?;
4747
4748 Ok(())
4749}
4750
4751/// Start receiving chat updates for a channel
4752async fn join_channel_chat(
4753 request: proto::JoinChannelChat,
4754 response: Response<proto::JoinChannelChat>,
4755 session: UserSession,
4756) -> Result<()> {
4757 let channel_id = ChannelId::from_proto(request.channel_id);
4758
4759 let db = session.db().await;
4760 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4761 .await?;
4762 let messages = db
4763 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4764 .await?;
4765 response.send(proto::JoinChannelChatResponse {
4766 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4767 messages,
4768 })?;
4769 Ok(())
4770}
4771
4772/// Stop receiving chat updates for a channel
4773async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
4774 let channel_id = ChannelId::from_proto(request.channel_id);
4775 session
4776 .db()
4777 .await
4778 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4779 .await?;
4780 Ok(())
4781}
4782
4783/// Retrieve the chat history for a channel
4784async fn get_channel_messages(
4785 request: proto::GetChannelMessages,
4786 response: Response<proto::GetChannelMessages>,
4787 session: UserSession,
4788) -> Result<()> {
4789 let channel_id = ChannelId::from_proto(request.channel_id);
4790 let messages = session
4791 .db()
4792 .await
4793 .get_channel_messages(
4794 channel_id,
4795 session.user_id(),
4796 MESSAGE_COUNT_PER_PAGE,
4797 Some(MessageId::from_proto(request.before_message_id)),
4798 )
4799 .await?;
4800 response.send(proto::GetChannelMessagesResponse {
4801 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4802 messages,
4803 })?;
4804 Ok(())
4805}
4806
4807/// Retrieve specific chat messages
4808async fn get_channel_messages_by_id(
4809 request: proto::GetChannelMessagesById,
4810 response: Response<proto::GetChannelMessagesById>,
4811 session: UserSession,
4812) -> Result<()> {
4813 let message_ids = request
4814 .message_ids
4815 .iter()
4816 .map(|id| MessageId::from_proto(*id))
4817 .collect::<Vec<_>>();
4818 let messages = session
4819 .db()
4820 .await
4821 .get_channel_messages_by_id(session.user_id(), &message_ids)
4822 .await?;
4823 response.send(proto::GetChannelMessagesResponse {
4824 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4825 messages,
4826 })?;
4827 Ok(())
4828}
4829
4830/// Retrieve the current users notifications
4831async fn get_notifications(
4832 request: proto::GetNotifications,
4833 response: Response<proto::GetNotifications>,
4834 session: UserSession,
4835) -> Result<()> {
4836 let notifications = session
4837 .db()
4838 .await
4839 .get_notifications(
4840 session.user_id(),
4841 NOTIFICATION_COUNT_PER_PAGE,
4842 request.before_id.map(db::NotificationId::from_proto),
4843 )
4844 .await?;
4845 response.send(proto::GetNotificationsResponse {
4846 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4847 notifications,
4848 })?;
4849 Ok(())
4850}
4851
4852/// Mark notifications as read
4853async fn mark_notification_as_read(
4854 request: proto::MarkNotificationRead,
4855 response: Response<proto::MarkNotificationRead>,
4856 session: UserSession,
4857) -> Result<()> {
4858 let database = &session.db().await;
4859 let notifications = database
4860 .mark_notification_as_read_by_id(
4861 session.user_id(),
4862 NotificationId::from_proto(request.notification_id),
4863 )
4864 .await?;
4865 send_notifications(
4866 &*session.connection_pool().await,
4867 &session.peer,
4868 notifications,
4869 );
4870 response.send(proto::Ack {})?;
4871 Ok(())
4872}
4873
4874/// Get the current users information
4875async fn get_private_user_info(
4876 _request: proto::GetPrivateUserInfo,
4877 response: Response<proto::GetPrivateUserInfo>,
4878 session: UserSession,
4879) -> Result<()> {
4880 let db = session.db().await;
4881
4882 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4883 let user = db
4884 .get_user_by_id(session.user_id())
4885 .await?
4886 .ok_or_else(|| anyhow!("user not found"))?;
4887 let flags = db.get_user_flags(session.user_id()).await?;
4888
4889 response.send(proto::GetPrivateUserInfoResponse {
4890 metrics_id,
4891 staff: user.admin,
4892 flags,
4893 accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4894 })?;
4895 Ok(())
4896}
4897
4898/// Accept the terms of service (tos) on behalf of the current user
4899async fn accept_terms_of_service(
4900 _request: proto::AcceptTermsOfService,
4901 response: Response<proto::AcceptTermsOfService>,
4902 session: UserSession,
4903) -> Result<()> {
4904 let db = session.db().await;
4905
4906 let accepted_tos_at = Utc::now();
4907 db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4908 .await?;
4909
4910 response.send(proto::AcceptTermsOfServiceResponse {
4911 accepted_tos_at: accepted_tos_at.timestamp() as u64,
4912 })?;
4913 Ok(())
4914}
4915
4916/// The minimum account age an account must have in order to use the LLM service.
4917const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4918
4919async fn get_llm_api_token(
4920 _request: proto::GetLlmToken,
4921 response: Response<proto::GetLlmToken>,
4922 session: UserSession,
4923) -> Result<()> {
4924 let db = session.db().await;
4925
4926 let flags = db.get_user_flags(session.user_id()).await?;
4927 let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
4928 let has_llm_closed_beta_feature_flag = flags.iter().any(|flag| flag == "llm-closed-beta");
4929
4930 if !session.is_staff() && !has_language_models_feature_flag {
4931 Err(anyhow!("permission denied"))?
4932 }
4933
4934 let user_id = session.user_id();
4935 let user = db
4936 .get_user_by_id(user_id)
4937 .await?
4938 .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4939
4940 if user.accepted_tos_at.is_none() {
4941 Err(anyhow!("terms of service not accepted"))?
4942 }
4943
4944 let mut account_created_at = user.created_at;
4945 if let Some(github_created_at) = user.github_user_created_at {
4946 account_created_at = account_created_at.min(github_created_at);
4947 }
4948 if Utc::now().naive_utc() - account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE {
4949 Err(anyhow!("account too young"))?
4950 }
4951 let token = LlmTokenClaims::create(
4952 user.id,
4953 user.github_login.clone(),
4954 session.is_staff(),
4955 has_llm_closed_beta_feature_flag,
4956 session.current_plan(db).await?,
4957 &session.app_state.config,
4958 )?;
4959 response.send(proto::GetLlmTokenResponse { token })?;
4960 Ok(())
4961}
4962
4963fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4964 let message = match message {
4965 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4966 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4967 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4968 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4969 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4970 code: frame.code.into(),
4971 reason: frame.reason,
4972 })),
4973 // We should never receive a frame while reading the message, according
4974 // to the `tungstenite` maintainers:
4975 //
4976 // > It cannot occur when you read messages from the WebSocket, but it
4977 // > can be used when you want to send the raw frames (e.g. you want to
4978 // > send the frames to the WebSocket without composing the full message first).
4979 // >
4980 // > — https://github.com/snapview/tungstenite-rs/issues/268
4981 TungsteniteMessage::Frame(_) => {
4982 bail!("received an unexpected frame while reading the message")
4983 }
4984 };
4985
4986 Ok(message)
4987}
4988
4989fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4990 match message {
4991 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4992 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4993 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4994 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4995 AxumMessage::Close(frame) => {
4996 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4997 code: frame.code.into(),
4998 reason: frame.reason,
4999 }))
5000 }
5001 }
5002}
5003
5004fn notify_membership_updated(
5005 connection_pool: &mut ConnectionPool,
5006 result: MembershipUpdated,
5007 user_id: UserId,
5008 peer: &Peer,
5009) {
5010 for membership in &result.new_channels.channel_memberships {
5011 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
5012 }
5013 for channel_id in &result.removed_channels {
5014 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
5015 }
5016
5017 let user_channels_update = proto::UpdateUserChannels {
5018 channel_memberships: result
5019 .new_channels
5020 .channel_memberships
5021 .iter()
5022 .map(|cm| proto::ChannelMembership {
5023 channel_id: cm.channel_id.to_proto(),
5024 role: cm.role.into(),
5025 })
5026 .collect(),
5027 ..Default::default()
5028 };
5029
5030 let mut update = build_channels_update(result.new_channels);
5031 update.delete_channels = result
5032 .removed_channels
5033 .into_iter()
5034 .map(|id| id.to_proto())
5035 .collect();
5036 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
5037
5038 for connection_id in connection_pool.user_connection_ids(user_id) {
5039 peer.send(connection_id, user_channels_update.clone())
5040 .trace_err();
5041 peer.send(connection_id, update.clone()).trace_err();
5042 }
5043}
5044
5045fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
5046 proto::UpdateUserChannels {
5047 channel_memberships: channels
5048 .channel_memberships
5049 .iter()
5050 .map(|m| proto::ChannelMembership {
5051 channel_id: m.channel_id.to_proto(),
5052 role: m.role.into(),
5053 })
5054 .collect(),
5055 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
5056 observed_channel_message_id: channels.observed_channel_messages.clone(),
5057 }
5058}
5059
5060fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
5061 let mut update = proto::UpdateChannels::default();
5062
5063 for channel in channels.channels {
5064 update.channels.push(channel.to_proto());
5065 }
5066
5067 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
5068 update.latest_channel_message_ids = channels.latest_channel_messages;
5069
5070 for (channel_id, participants) in channels.channel_participants {
5071 update
5072 .channel_participants
5073 .push(proto::ChannelParticipants {
5074 channel_id: channel_id.to_proto(),
5075 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
5076 });
5077 }
5078
5079 for channel in channels.invited_channels {
5080 update.channel_invitations.push(channel.to_proto());
5081 }
5082
5083 update.hosted_projects = channels.hosted_projects;
5084 update
5085}
5086
5087fn build_initial_contacts_update(
5088 contacts: Vec<db::Contact>,
5089 pool: &ConnectionPool,
5090) -> proto::UpdateContacts {
5091 let mut update = proto::UpdateContacts::default();
5092
5093 for contact in contacts {
5094 match contact {
5095 db::Contact::Accepted { user_id, busy } => {
5096 update.contacts.push(contact_for_user(user_id, busy, pool));
5097 }
5098 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
5099 db::Contact::Incoming { user_id } => {
5100 update
5101 .incoming_requests
5102 .push(proto::IncomingContactRequest {
5103 requester_id: user_id.to_proto(),
5104 })
5105 }
5106 }
5107 }
5108
5109 update
5110}
5111
5112fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
5113 proto::Contact {
5114 user_id: user_id.to_proto(),
5115 online: pool.is_user_online(user_id),
5116 busy,
5117 }
5118}
5119
5120fn room_updated(room: &proto::Room, peer: &Peer) {
5121 broadcast(
5122 None,
5123 room.participants
5124 .iter()
5125 .filter_map(|participant| Some(participant.peer_id?.into())),
5126 |peer_id| {
5127 peer.send(
5128 peer_id,
5129 proto::RoomUpdated {
5130 room: Some(room.clone()),
5131 },
5132 )
5133 },
5134 );
5135}
5136
5137fn channel_updated(
5138 channel: &db::channel::Model,
5139 room: &proto::Room,
5140 peer: &Peer,
5141 pool: &ConnectionPool,
5142) {
5143 let participants = room
5144 .participants
5145 .iter()
5146 .map(|p| p.user_id)
5147 .collect::<Vec<_>>();
5148
5149 broadcast(
5150 None,
5151 pool.channel_connection_ids(channel.root_id())
5152 .filter_map(|(channel_id, role)| {
5153 role.can_see_channel(channel.visibility)
5154 .then_some(channel_id)
5155 }),
5156 |peer_id| {
5157 peer.send(
5158 peer_id,
5159 proto::UpdateChannels {
5160 channel_participants: vec![proto::ChannelParticipants {
5161 channel_id: channel.id.to_proto(),
5162 participant_user_ids: participants.clone(),
5163 }],
5164 ..Default::default()
5165 },
5166 )
5167 },
5168 );
5169}
5170
5171async fn send_dev_server_projects_update(
5172 user_id: UserId,
5173 mut status: proto::DevServerProjectsUpdate,
5174 session: &Session,
5175) {
5176 let pool = session.connection_pool().await;
5177 for dev_server in &mut status.dev_servers {
5178 dev_server.status =
5179 pool.dev_server_status(DevServerId(dev_server.dev_server_id as i32)) as i32;
5180 }
5181 let connections = pool.user_connection_ids(user_id);
5182 for connection_id in connections {
5183 session.peer.send(connection_id, status.clone()).trace_err();
5184 }
5185}
5186
5187async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
5188 let db = session.db().await;
5189
5190 let contacts = db.get_contacts(user_id).await?;
5191 let busy = db.is_user_busy(user_id).await?;
5192
5193 let pool = session.connection_pool().await;
5194 let updated_contact = contact_for_user(user_id, busy, &pool);
5195 for contact in contacts {
5196 if let db::Contact::Accepted {
5197 user_id: contact_user_id,
5198 ..
5199 } = contact
5200 {
5201 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
5202 session
5203 .peer
5204 .send(
5205 contact_conn_id,
5206 proto::UpdateContacts {
5207 contacts: vec![updated_contact.clone()],
5208 remove_contacts: Default::default(),
5209 incoming_requests: Default::default(),
5210 remove_incoming_requests: Default::default(),
5211 outgoing_requests: Default::default(),
5212 remove_outgoing_requests: Default::default(),
5213 },
5214 )
5215 .trace_err();
5216 }
5217 }
5218 }
5219 Ok(())
5220}
5221
5222async fn lost_dev_server_connection(session: &DevServerSession) -> Result<()> {
5223 log::info!("lost dev server connection, unsharing projects");
5224 let project_ids = session
5225 .db()
5226 .await
5227 .get_stale_dev_server_projects(session.connection_id)
5228 .await?;
5229
5230 for project_id in project_ids {
5231 // not unshare re-checks the connection ids match, so we get away with no transaction
5232 unshare_project_internal(project_id, session.connection_id, None, session).await?;
5233 }
5234
5235 let user_id = session.dev_server().user_id;
5236 let update = session
5237 .db()
5238 .await
5239 .dev_server_projects_update(user_id)
5240 .await?;
5241
5242 send_dev_server_projects_update(user_id, update, session).await;
5243
5244 Ok(())
5245}
5246
/// Remove `connection_id` from its room and clean up after it.
///
/// If the database's `leave_room` returns `None`, this returns `Ok(())`
/// without doing anything else. Otherwise it:
/// - notifies collaborators on each project left;
/// - broadcasts the updated room, plus channel participants when the database
///   reports a channel for the room;
/// - sends `CallCanceled` to every user in `canceled_calls_to_user_ids`;
/// - refreshes the contact status of the leaving user and of each of those
///   users;
/// - removes the user from the LiveKit room when a LiveKit client is
///   configured, deleting the room if the database marked it as deleted.
///
/// # Errors
/// Propagates errors from `leave_room` and `update_user_contacts`; peer-send
/// and LiveKit failures are only logged via `trace_err`.
async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Assigned inside the `if let` below; declared here so they remain usable
    // after it.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel;

    // NOTE(review): under pre-2024-edition temporary rules, the value returned by
    // `session.db()` likely stays alive for this whole `if let` body — confirm if
    // that scope matters.
    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Taken out of `left_room.room` before the room itself is moved, so the
        // `room` broadcast below carries this field at its `Default` value.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the pool guard is released before `update_user_contacts`,
    // which acquires the connection pool itself.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    // LiveKit cleanup is best-effort: failures are logged, not returned.
    if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
5320
5321async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
5322 let left_channel_buffers = session
5323 .db()
5324 .await
5325 .leave_channel_buffers(session.connection_id)
5326 .await?;
5327
5328 for left_buffer in left_channel_buffers {
5329 channel_buffer_updated(
5330 session.connection_id,
5331 left_buffer.connections,
5332 &proto::UpdateChannelBufferCollaborators {
5333 channel_id: left_buffer.channel_id.to_proto(),
5334 collaborators: left_buffer.collaborators,
5335 },
5336 &session.peer,
5337 );
5338 }
5339
5340 Ok(())
5341}
5342
5343fn project_left(project: &db::LeftProject, session: &UserSession) {
5344 for connection_id in &project.connection_ids {
5345 if project.should_unshare {
5346 session
5347 .peer
5348 .send(
5349 *connection_id,
5350 proto::UnshareProject {
5351 project_id: project.id.to_proto(),
5352 },
5353 )
5354 .trace_err();
5355 } else {
5356 session
5357 .peer
5358 .send(
5359 *connection_id,
5360 proto::RemoveProjectCollaborator {
5361 project_id: project.id.to_proto(),
5362 peer_id: Some(session.connection_id.into()),
5363 },
5364 )
5365 .trace_err();
5366 }
5367 }
5368}
5369
5370pub trait ResultExt {
5371 type Ok;
5372
5373 fn trace_err(self) -> Option<Self::Ok>;
5374}
5375
5376impl<T, E> ResultExt for Result<T, E>
5377where
5378 E: std::fmt::Debug,
5379{
5380 type Ok = T;
5381
5382 #[track_caller]
5383 fn trace_err(self) -> Option<T> {
5384 match self {
5385 Ok(value) => Some(value),
5386 Err(error) => {
5387 tracing::error!("{:?}", error);
5388 None
5389 }
5390 }
5391 }
5392}