1mod connection_pool;
2
3use crate::api::CloudflareIpCountryHeader;
4use crate::llm::LlmTokenClaims;
5use crate::{
6 auth,
7 db::{
8 self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
9 CreatedChannelMessage, Database, DevServerId, DevServerProjectId, InviteMemberResult,
10 MembershipUpdated, MessageId, NotificationId, PrincipalId, Project, ProjectId,
11 RejoinedProject, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
12 ServerId, UpdatedChannelMessage, User, UserId,
13 },
14 executor::Executor,
15 AppState, Config, Error, RateLimit, Result,
16};
17use anyhow::{anyhow, bail, Context as _};
18use async_tungstenite::tungstenite::{
19 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
20};
21use axum::{
22 body::Body,
23 extract::{
24 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
25 ConnectInfo, WebSocketUpgrade,
26 },
27 headers::{Header, HeaderName},
28 http::StatusCode,
29 middleware,
30 response::IntoResponse,
31 routing::get,
32 Extension, Router, TypedHeader,
33};
34use chrono::Utc;
35use collections::{HashMap, HashSet};
36pub use connection_pool::{ConnectionPool, ZedVersion};
37use core::fmt::{self, Debug, Formatter};
38use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
39use sha2::Digest;
40use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
41
42use futures::{
43 channel::oneshot,
44 future::{self, BoxFuture},
45 stream::FuturesUnordered,
46 FutureExt, SinkExt, StreamExt, TryStreamExt,
47};
48use http_client::IsahcHttpClient;
49use prometheus::{register_int_gauge, IntGauge};
50use rpc::{
51 proto::{
52 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
53 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
54 },
55 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
56};
57use semantic_version::SemanticVersion;
58use serde::{Serialize, Serializer};
59use std::{
60 any::TypeId,
61 future::Future,
62 marker::PhantomData,
63 mem,
64 net::SocketAddr,
65 ops::{Deref, DerefMut},
66 rc::Rc,
67 sync::{
68 atomic::{AtomicBool, Ordering::SeqCst},
69 Arc, OnceLock,
70 },
71 time::{Duration, Instant},
72};
73use time::OffsetDateTime;
74use tokio::sync::{watch, MutexGuard, Semaphore};
75use tower::ServiceBuilder;
76use tracing::{
77 field::{self},
78 info_span, instrument, Instrument,
79};
80
/// Reconnection grace period. NOTE(review): the code that consumes this constant is
/// outside this chunk; presumably it is how long a disconnected client may take to
/// reconnect before its state is released — confirm.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// Delay, after [`Server::start`] is called, before resources left behind by stale
/// servers are cleaned up.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can
/// clean up old resources, so this is set a little longer than that grace period.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

/// Page size for channel-message queries (usage not visible in this chunk).
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Maximum channel message length. NOTE(review): unit (bytes vs. chars) is decided by
/// code outside this chunk — confirm.
const MAX_MESSAGE_LEN: usize = 1024;
/// Page size for notification queries (usage not visible in this chunk).
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;

/// Type-erased handler stored in `Server::handlers`.
///
/// It receives the boxed envelope (which it downcasts back to its concrete message type)
/// plus the connection's `Session`, and returns a boxed `'static` future.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
92
93struct Response<R> {
94 peer: Arc<Peer>,
95 receipt: Receipt<R>,
96 responded: Arc<AtomicBool>,
97}
98
99impl<R: RequestMessage> Response<R> {
100 fn send(self, payload: R::Response) -> Result<()> {
101 self.responded.store(true, SeqCst);
102 self.peer.respond(self.receipt, payload)?;
103 Ok(())
104 }
105}
106
107#[derive(Clone, Debug)]
108pub enum Principal {
109 User(User),
110 Impersonated { user: User, admin: User },
111 DevServer(dev_server::Model),
112}
113
114impl Principal {
115 fn update_span(&self, span: &tracing::Span) {
116 match &self {
117 Principal::User(user) => {
118 span.record("user_id", user.id.0);
119 span.record("login", &user.github_login);
120 }
121 Principal::Impersonated { user, admin } => {
122 span.record("user_id", user.id.0);
123 span.record("login", &user.github_login);
124 span.record("impersonator", &admin.github_login);
125 }
126 Principal::DevServer(dev_server) => {
127 span.record("dev_server_id", dev_server.id.0);
128 }
129 }
130 }
131}
132
/// Per-connection state handed to every message handler.
///
/// `handle_connection` builds one per connection and clones it for each incoming message.
#[derive(Clone)]
struct Session {
    /// Who is on the other end of this connection.
    principal: Principal,
    /// This connection's id within `peer`.
    connection_id: ConnectionId,
    /// Database handle behind an async mutex. A fresh mutex is created per connection,
    /// so handlers for the same connection take turns using it via `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    /// RPC peer used to send messages and responses.
    peer: Arc<Peer>,
    /// Server-wide pool of live connections, shared with `Server`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Only present when a Supermaven admin API key is configured.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    http_client: Arc<IsahcHttpClient>,
    /// The GeoIP country code for the user.
    // Not read anywhere yet, hence the `allow(unused)`.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    // Held but not read in the visible code. NOTE(review): presumably kept so handlers
    // can spawn work later — confirm.
    _executor: Executor,
}
148
149impl Session {
150 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
151 #[cfg(test)]
152 tokio::task::yield_now().await;
153 let guard = self.db.lock().await;
154 #[cfg(test)]
155 tokio::task::yield_now().await;
156 guard
157 }
158
159 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
160 #[cfg(test)]
161 tokio::task::yield_now().await;
162 let guard = self.connection_pool.lock();
163 ConnectionPoolGuard {
164 guard,
165 _not_send: PhantomData,
166 }
167 }
168
169 fn for_user(self) -> Option<UserSession> {
170 UserSession::new(self)
171 }
172
173 fn for_dev_server(self) -> Option<DevServerSession> {
174 DevServerSession::new(self)
175 }
176
177 fn user_id(&self) -> Option<UserId> {
178 match &self.principal {
179 Principal::User(user) => Some(user.id),
180 Principal::Impersonated { user, .. } => Some(user.id),
181 Principal::DevServer(_) => None,
182 }
183 }
184
185 fn is_staff(&self) -> bool {
186 match &self.principal {
187 Principal::User(user) => user.admin,
188 Principal::Impersonated { .. } => true,
189 Principal::DevServer(_) => false,
190 }
191 }
192
193 pub async fn current_plan(&self, db: MutexGuard<'_, DbHandle>) -> anyhow::Result<proto::Plan> {
194 if self.is_staff() {
195 return Ok(proto::Plan::ZedPro);
196 }
197
198 let Some(user_id) = self.user_id() else {
199 return Ok(proto::Plan::Free);
200 };
201
202 if db.has_active_billing_subscription(user_id).await? {
203 Ok(proto::Plan::ZedPro)
204 } else {
205 Ok(proto::Plan::Free)
206 }
207 }
208
209 fn dev_server_id(&self) -> Option<DevServerId> {
210 match &self.principal {
211 Principal::User(_) | Principal::Impersonated { .. } => None,
212 Principal::DevServer(dev_server) => Some(dev_server.id),
213 }
214 }
215
216 fn principal_id(&self) -> PrincipalId {
217 match &self.principal {
218 Principal::User(user) => PrincipalId::UserId(user.id),
219 Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
220 Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
221 }
222 }
223}
224
225impl Debug for Session {
226 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
227 let mut result = f.debug_struct("Session");
228 match &self.principal {
229 Principal::User(user) => {
230 result.field("user", &user.github_login);
231 }
232 Principal::Impersonated { user, admin } => {
233 result.field("user", &user.github_login);
234 result.field("impersonator", &admin.github_login);
235 }
236 Principal::DevServer(dev_server) => {
237 result.field("dev_server", &dev_server.id);
238 }
239 }
240 result.field("connection_id", &self.connection_id).finish()
241 }
242}
243
244struct UserSession(Session);
245
246impl UserSession {
247 pub fn new(s: Session) -> Option<Self> {
248 s.user_id().map(|_| UserSession(s))
249 }
250 pub fn user_id(&self) -> UserId {
251 self.0.user_id().unwrap()
252 }
253
254 pub fn email(&self) -> Option<String> {
255 match &self.0.principal {
256 Principal::User(user) => user.email_address.clone(),
257 Principal::Impersonated { user, .. } => user.email_address.clone(),
258 Principal::DevServer(..) => None,
259 }
260 }
261}
262
263impl Deref for UserSession {
264 type Target = Session;
265
266 fn deref(&self) -> &Self::Target {
267 &self.0
268 }
269}
270impl DerefMut for UserSession {
271 fn deref_mut(&mut self) -> &mut Self::Target {
272 &mut self.0
273 }
274}
275
276struct DevServerSession(Session);
277
278impl DevServerSession {
279 pub fn new(s: Session) -> Option<Self> {
280 s.dev_server_id().map(|_| DevServerSession(s))
281 }
282 pub fn dev_server_id(&self) -> DevServerId {
283 self.0.dev_server_id().unwrap()
284 }
285
286 fn dev_server(&self) -> &dev_server::Model {
287 match &self.0.principal {
288 Principal::DevServer(dev_server) => dev_server,
289 _ => unreachable!(),
290 }
291 }
292}
293
294impl Deref for DevServerSession {
295 type Target = Session;
296
297 fn deref(&self) -> &Self::Target {
298 &self.0
299 }
300}
301impl DerefMut for DevServerSession {
302 fn deref_mut(&mut self) -> &mut Self::Target {
303 &mut self.0
304 }
305}
306
307fn user_handler<M: RequestMessage, Fut>(
308 handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
309) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
310where
311 Fut: Send + Future<Output = Result<()>>,
312{
313 let handler = Arc::new(handler);
314 move |message, response, session| {
315 let handler = handler.clone();
316 Box::pin(async move {
317 if let Some(user_session) = session.for_user() {
318 Ok(handler(message, response, user_session).await?)
319 } else {
320 Err(Error::Internal(anyhow!(
321 "must be a user to call {}",
322 M::NAME
323 )))
324 }
325 })
326 }
327}
328
329fn dev_server_handler<M: RequestMessage, Fut>(
330 handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
331) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
332where
333 Fut: Send + Future<Output = Result<()>>,
334{
335 let handler = Arc::new(handler);
336 move |message, response, session| {
337 let handler = handler.clone();
338 Box::pin(async move {
339 if let Some(dev_server_session) = session.for_dev_server() {
340 Ok(handler(message, response, dev_server_session).await?)
341 } else {
342 Err(Error::Internal(anyhow!(
343 "must be a dev server to call {}",
344 M::NAME
345 )))
346 }
347 })
348 }
349}
350
351fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
352 handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
353) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
354where
355 InnertRetFut: Send + Future<Output = Result<()>>,
356{
357 let handler = Arc::new(handler);
358 move |message, session| {
359 let handler = handler.clone();
360 Box::pin(async move {
361 if let Some(user_session) = session.for_user() {
362 Ok(handler(message, user_session).await?)
363 } else {
364 Err(Error::Internal(anyhow!(
365 "must be a user to call {}",
366 M::NAME
367 )))
368 }
369 })
370 }
371}
372
373struct DbHandle(Arc<Database>);
374
375impl Deref for DbHandle {
376 type Target = Database;
377
378 fn deref(&self) -> &Self::Target {
379 self.0.as_ref()
380 }
381}
382
/// The collaboration RPC server.
///
/// Owns the peer, the pool of live connections, and the table of message handlers
/// registered in [`Server::new`].
pub struct Server {
    /// This server instance's id. Behind a mutex because the test-only `reset` replaces it.
    id: parking_lot::Mutex<ServerId>,
    /// RPC peer used to send and receive messages on every connection.
    peer: Arc<Peer>,
    /// Bookkeeping for connected clients; the same `Arc` is shared with every `Session`.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Handlers keyed by the `TypeId` of the message type they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Teardown flag. `teardown` sends `true` on it. `handle_connection` subscribes to it,
    /// refuses to start while it reads `true`, and stops when it changes.
    teardown: watch::Sender<bool>,
}
391
/// Lock guard over the shared [`ConnectionPool`], returned by `Session::connection_pool`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    // `Rc` is `!Send`, which makes this guard `!Send`. That way, holding it across an
    // `.await` in a future that must be `Send` is a compile error.
    _not_send: PhantomData<Rc<()>>,
}
396
/// Serializable view of the server's peer and connection pool.
///
/// The connection pool lock is held for as long as the snapshot lives.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized through the guard's `Deref` target (the pool itself).
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
403
404pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
405where
406 S: Serializer,
407 T: Deref<Target = U>,
408 U: Serialize,
409{
410 Serialize::serialize(value.deref(), serializer)
411}
412
413impl Server {
    /// Creates a server with the given id and registers every RPC handler.
    ///
    /// Handlers wrapped in `user_handler` / `user_message_handler` fail unless the
    /// session's principal is a user, and those wrapped in `dev_server_handler` fail unless
    /// it is a dev server. All other handlers receive the raw `Session`. Registering two
    /// handlers for the same message type panics (see `add_handler`).
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            // NOTE(review): `as u32` silently reinterprets/truncates ids that don't fit a u32.
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            // Only the sender is kept; the initial receiver is dropped immediately and
            // connections create their own via `subscribe`.
            teardown: watch::channel(false).0,
        };

        server
            .add_request_handler(ping)
            // Rooms and calls.
            .add_request_handler(user_handler(create_room))
            .add_request_handler(user_handler(join_room))
            .add_request_handler(user_handler(rejoin_room))
            .add_request_handler(user_handler(leave_room))
            .add_request_handler(user_handler(set_room_participant_role))
            .add_request_handler(user_handler(call))
            .add_request_handler(user_handler(cancel_call))
            .add_message_handler(user_message_handler(decline_call))
            .add_request_handler(user_handler(update_participant_location))
            // Projects and dev servers.
            .add_request_handler(user_handler(share_project))
            .add_message_handler(unshare_project)
            .add_request_handler(user_handler(join_project))
            .add_request_handler(user_handler(join_hosted_project))
            .add_request_handler(user_handler(rejoin_dev_server_projects))
            .add_request_handler(user_handler(create_dev_server_project))
            .add_request_handler(user_handler(update_dev_server_project))
            .add_request_handler(user_handler(delete_dev_server_project))
            .add_request_handler(user_handler(create_dev_server))
            .add_request_handler(user_handler(regenerate_dev_server_token))
            .add_request_handler(user_handler(rename_dev_server))
            .add_request_handler(user_handler(delete_dev_server))
            .add_request_handler(user_handler(list_remote_directory))
            .add_request_handler(dev_server_handler(share_dev_server_project))
            .add_request_handler(dev_server_handler(shutdown_dev_server))
            .add_request_handler(dev_server_handler(reconnect_dev_server))
            .add_message_handler(user_message_handler(leave_project))
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Project requests forwarded by generic helpers; the helper names suggest
            // owner-only, read-only, and mutating access checks respectively
            // (their bodies are outside this chunk).
            .add_request_handler(user_handler(
                forward_project_request_for_owner::<proto::TaskContextForLocation>,
            ))
            .add_request_handler(user_handler(
                forward_project_request_for_owner::<proto::TaskTemplates>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetHover>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetTypeDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetReferences>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::SearchProject>,
            ))
            .add_request_handler(user_handler(forward_find_search_candidates_request))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDocumentHighlights>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetProjectSymbols>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferForSymbol>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferById>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::SynchronizeBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::InlayHints>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::ResolveInlayHint>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferByPath>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCompletions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OpenNewBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCodeActions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCodeAction>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PrepareRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PerformRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ReloadBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::FormatBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CreateProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RenameProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CopyProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::DeleteProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ExpandProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OnTypeFormatting>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::SaveBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::BlameBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::MultiLspQuery>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RestartLanguageServers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::LinkedEditingRange>,
            ))
            // Buffers.
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            // Users and contacts.
            .add_request_handler(get_users)
            .add_request_handler(user_handler(fuzzy_search_users))
            .add_request_handler(user_handler(request_contact))
            .add_request_handler(user_handler(remove_contact))
            .add_request_handler(user_handler(respond_to_contact_request))
            // Channels, channel buffers, and channel chat.
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(user_handler(create_channel))
            .add_request_handler(user_handler(delete_channel))
            .add_request_handler(user_handler(invite_channel_member))
            .add_request_handler(user_handler(remove_channel_member))
            .add_request_handler(user_handler(set_channel_member_role))
            .add_request_handler(user_handler(set_channel_visibility))
            .add_request_handler(user_handler(rename_channel))
            .add_request_handler(user_handler(join_channel_buffer))
            .add_request_handler(user_handler(leave_channel_buffer))
            .add_message_handler(user_message_handler(update_channel_buffer))
            .add_request_handler(user_handler(rejoin_channel_buffers))
            .add_request_handler(user_handler(get_channel_members))
            .add_request_handler(user_handler(respond_to_channel_invite))
            .add_request_handler(user_handler(join_channel))
            .add_request_handler(user_handler(join_channel_chat))
            .add_message_handler(user_message_handler(leave_channel_chat))
            .add_request_handler(user_handler(send_channel_message))
            .add_request_handler(user_handler(remove_channel_message))
            .add_request_handler(user_handler(update_channel_message))
            .add_request_handler(user_handler(get_channel_messages))
            .add_request_handler(user_handler(get_channel_messages_by_id))
            .add_request_handler(user_handler(get_notifications))
            .add_request_handler(user_handler(mark_notification_as_read))
            .add_request_handler(user_handler(move_channel))
            // Following.
            .add_request_handler(user_handler(follow))
            .add_message_handler(user_message_handler(unfollow))
            .add_message_handler(user_message_handler(update_followers))
            // Account, LLM, and misc.
            .add_request_handler(user_handler(get_private_user_info))
            .add_request_handler(user_handler(get_llm_api_token))
            .add_request_handler(user_handler(accept_terms_of_service))
            .add_message_handler(user_message_handler(acknowledge_channel_message))
            .add_message_handler(user_message_handler(acknowledge_buffer_version))
            .add_request_handler(user_handler(get_supermaven_api_key))
            // Assistant contexts.
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OpenContext>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CreateContext>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::SynchronizeContexts>,
            ))
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context)
            // Not wrapped in `user_handler`: receives the raw `Session` and the server config.
            .add_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    let app_state = app_state.clone();
                    async move {
                        count_language_model_tokens(request, response, session, &app_state.config)
                            .await
                    }
                }
            })
            .add_request_handler({
                user_handler(move |request, response, session| {
                    get_cached_embeddings(request, response, session)
                })
            })
            // Passes the configured OpenAI API key along to the embeddings handler.
            .add_request_handler({
                let app_state = app_state.clone();
                user_handler(move |request, response, session| {
                    compute_embeddings(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                    )
                })
            });

        Arc::new(server)
    }
653
    /// Schedules cleanup of resources left behind by stale server instances.
    ///
    /// Returns `Ok(())` immediately. The work runs on a detached task that:
    /// 1. sleeps for [`CLEANUP_TIMEOUT`];
    /// 2. asks the database (via `stale_server_resource_ids`) for stale room and
    ///    channel-buffer ids in this environment, relative to this server's id;
    /// 3. clears stale participants and collaborators and notifies affected connections;
    /// 4. finally calls `delete_stale_servers`.
    ///
    /// Database and send errors are logged with `trace_err` and skipped; they never abort
    /// the task.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Clear stale collaborators from each channel buffer, then push the
                    // refreshed collaborator list to every connection still on that buffer.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Filled from the refreshed room below. They stay empty/false when
                        // clearing the room fails, which makes the later steps no-ops.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            // Both removed participants and users whose incoming calls were
                            // canceled get their contacts refreshed below.
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            // The LiveKit room is only deleted once no participants remain.
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // The pool lock is scoped to this block, so it is released before
                        // the `.await`s below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            // Only notify when both lookups succeeded. The pool is locked
                            // after the awaits and released at the end of this `if`.
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    // Only accepted contacts are notified.
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                // Runs even if fetching the stale resource ids failed above.
                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
799
800 pub fn teardown(&self) {
801 self.peer.teardown();
802 self.connection_pool.lock().reset();
803 let _ = self.teardown.send(true);
804 }
805
806 #[cfg(test)]
807 pub fn reset(&self, id: ServerId) {
808 self.teardown();
809 *self.id.lock() = id;
810 self.peer.reset(id.0 as u32);
811 let _ = self.teardown.send(false);
812 }
813
814 #[cfg(test)]
815 pub fn id(&self) -> ServerId {
816 *self.id.lock()
817 }
818
819 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
820 where
821 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
822 Fut: 'static + Send + Future<Output = Result<()>>,
823 M: EnvelopedMessage,
824 {
825 let prev_handler = self.handlers.insert(
826 TypeId::of::<M>(),
827 Box::new(move |envelope, session| {
828 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
829 let received_at = envelope.received_at;
830 tracing::info!("message received");
831 let start_time = Instant::now();
832 let future = (handler)(*envelope, session);
833 async move {
834 let result = future.await;
835 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
836 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
837 let queue_duration_ms = total_duration_ms - processing_duration_ms;
838 let payload_type = M::NAME;
839
840 match result {
841 Err(error) => {
842 tracing::error!(
843 ?error,
844 total_duration_ms,
845 processing_duration_ms,
846 queue_duration_ms,
847 payload_type,
848 "error handling message"
849 )
850 }
851 Ok(()) => tracing::info!(
852 total_duration_ms,
853 processing_duration_ms,
854 queue_duration_ms,
855 "finished handling message"
856 ),
857 }
858 }
859 .boxed()
860 }),
861 );
862 if prev_handler.is_some() {
863 panic!("registered a handler for the same message twice");
864 }
865 self
866 }
867
868 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
869 where
870 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
871 Fut: 'static + Send + Future<Output = Result<()>>,
872 M: EnvelopedMessage,
873 {
874 self.add_handler(move |envelope, session| handler(envelope.payload, session));
875 self
876 }
877
878 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
879 where
880 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
881 Fut: Send + Future<Output = Result<()>>,
882 M: RequestMessage,
883 {
884 let handler = Arc::new(handler);
885 self.add_handler(move |envelope, session| {
886 let receipt = envelope.receipt();
887 let handler = handler.clone();
888 async move {
889 let peer = session.peer.clone();
890 let responded = Arc::new(AtomicBool::default());
891 let response = Response {
892 peer: peer.clone(),
893 responded: responded.clone(),
894 receipt,
895 };
896 match (handler)(envelope.payload, response, session).await {
897 Ok(()) => {
898 if responded.load(std::sync::atomic::Ordering::SeqCst) {
899 Ok(())
900 } else {
901 Err(anyhow!("handler did not send a response"))?
902 }
903 }
904 Err(error) => {
905 let proto_err = match &error {
906 Error::Internal(err) => err.to_proto(),
907 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
908 };
909 peer.respond_with_error(receipt, proto_err)?;
910 Err(error)
911 }
912 }
913 }
914 })
915 }
916
917 #[allow(clippy::too_many_arguments)]
918 pub fn handle_connection(
919 self: &Arc<Self>,
920 connection: Connection,
921 address: String,
922 principal: Principal,
923 zed_version: ZedVersion,
924 geoip_country_code: Option<String>,
925 send_connection_id: Option<oneshot::Sender<ConnectionId>>,
926 executor: Executor,
927 ) -> impl Future<Output = ()> {
928 let this = self.clone();
929 let span = info_span!("handle connection", %address,
930 connection_id=field::Empty,
931 user_id=field::Empty,
932 login=field::Empty,
933 impersonator=field::Empty,
934 dev_server_id=field::Empty,
935 geoip_country_code=field::Empty
936 );
937 principal.update_span(&span);
938 if let Some(country_code) = geoip_country_code.as_ref() {
939 span.record("geoip_country_code", country_code);
940 }
941
942 let mut teardown = self.teardown.subscribe();
943 async move {
944 if *teardown.borrow() {
945 tracing::error!("server is tearing down");
946 return
947 }
948 let (connection_id, handle_io, mut incoming_rx) = this
949 .peer
950 .add_connection(connection, {
951 let executor = executor.clone();
952 move |duration| executor.sleep(duration)
953 });
954 tracing::Span::current().record("connection_id", format!("{}", connection_id));
955
956 tracing::info!("connection opened");
957
958 let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
959 let http_client = match IsahcHttpClient::builder().default_header("User-Agent", user_agent).build() {
960 Ok(http_client) => Arc::new(http_client),
961 Err(error) => {
962 tracing::error!(?error, "failed to create HTTP client");
963 return;
964 }
965 };
966
967 let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
968 supermaven_admin_api_key.to_string(),
969 http_client.clone(),
970 )));
971
972 let session = Session {
973 principal: principal.clone(),
974 connection_id,
975 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
976 peer: this.peer.clone(),
977 connection_pool: this.connection_pool.clone(),
978 app_state: this.app_state.clone(),
979 http_client,
980 geoip_country_code,
981 _executor: executor.clone(),
982 supermaven_client,
983 };
984
985 if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
986 tracing::error!(?error, "failed to send initial client update");
987 return;
988 }
989
990 let handle_io = handle_io.fuse();
991 futures::pin_mut!(handle_io);
992
993 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
994 // This prevents deadlocks when e.g., client A performs a request to client B and
995 // client B performs a request to client A. If both clients stop processing further
996 // messages until their respective request completes, they won't have a chance to
997 // respond to the other client's request and cause a deadlock.
998 //
999 // This arrangement ensures we will attempt to process earlier messages first, but fall
1000 // back to processing messages arrived later in the spirit of making progress.
1001 let mut foreground_message_handlers = FuturesUnordered::new();
1002 let concurrent_handlers = Arc::new(Semaphore::new(256));
1003 loop {
1004 let next_message = async {
1005 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
1006 let message = incoming_rx.next().await;
1007 (permit, message)
1008 }.fuse();
1009 futures::pin_mut!(next_message);
1010 futures::select_biased! {
1011 _ = teardown.changed().fuse() => return,
1012 result = handle_io => {
1013 if let Err(error) = result {
1014 tracing::error!(?error, "error handling I/O");
1015 }
1016 break;
1017 }
1018 _ = foreground_message_handlers.next() => {}
1019 next_message = next_message => {
1020 let (permit, message) = next_message;
1021 if let Some(message) = message {
1022 let type_name = message.payload_type_name();
1023 // note: we copy all the fields from the parent span so we can query them in the logs.
1024 // (https://github.com/tokio-rs/tracing/issues/2670).
1025 let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
1026 user_id=field::Empty,
1027 login=field::Empty,
1028 impersonator=field::Empty,
1029 dev_server_id=field::Empty
1030 );
1031 principal.update_span(&span);
1032 let span_enter = span.enter();
1033 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
1034 let is_background = message.is_background();
1035 let handle_message = (handler)(message, session.clone());
1036 drop(span_enter);
1037
1038 let handle_message = async move {
1039 handle_message.await;
1040 drop(permit);
1041 }.instrument(span);
1042 if is_background {
1043 executor.spawn_detached(handle_message);
1044 } else {
1045 foreground_message_handlers.push(handle_message);
1046 }
1047 } else {
1048 tracing::error!("no message handler");
1049 }
1050 } else {
1051 tracing::info!("connection closed");
1052 break;
1053 }
1054 }
1055 }
1056 }
1057
1058 drop(foreground_message_handlers);
1059 tracing::info!("signing out");
1060 if let Err(error) = connection_lost(session, teardown, executor).await {
1061 tracing::error!(?error, "error signing out");
1062 }
1063
1064 }.instrument(span)
1065 }
1066
    /// Performs the post-handshake setup for a freshly added connection.
    ///
    /// Sends `Hello` (carrying this connection's peer id), reports `connection_id`
    /// through `send_connection_id` when a sender was supplied, then runs the
    /// principal-specific setup:
    ///
    /// * users (plain or impersonated): ask the client to show contacts on the very
    ///   first connection, refresh the user's plan, register the connection in the
    ///   pool, send the initial contacts, optionally auto-subscribe to channels, send
    ///   dev-server project status, replay a pending incoming call (if any), and
    ///   update the user's contacts;
    /// * dev servers: shut down any previously registered connection for the same dev
    ///   server, register this one, send it its project instructions, and send a
    ///   project-status update to the dev server's owner.
    ///
    /// Failures of the sends and database calls that use `?` are propagated;
    /// `send_dev_server_projects_update` is awaited without `?`.
    async fn send_initial_client_update(
        &self,
        connection_id: ConnectionId,
        principal: &Principal,
        zed_version: ZedVersion,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        session: &Session,
    ) -> Result<()> {
        self.peer.send(
            connection_id,
            proto::Hello {
                peer_id: Some(connection_id.into()),
            },
        )?;
        tracing::info!("sent hello message");
        // Hand the assigned connection id back to whoever asked for it; a dropped
        // receiver is deliberately ignored.
        if let Some(send_connection_id) = send_connection_id.take() {
            let _ = send_connection_id.send(connection_id);
        }

        match principal {
            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
                // First connection ever for this user: prompt the contacts UI and
                // persist that the user has now connected once.
                if !user.connected_once {
                    self.peer.send(connection_id, proto::ShowContacts {})?;
                    self.app_state
                        .db
                        .set_user_connected_once(user.id, true)
                        .await?;
                }

                update_user_plan(user.id, session).await?;

                // Fetch contacts and dev-server project status concurrently.
                let (contacts, dev_server_projects) = future::try_join(
                    self.app_state.db.get_contacts(user.id),
                    self.app_state.db.dev_server_projects_update(user.id),
                )
                .await?;

                // Scoped so the pool lock is released before the awaits below.
                {
                    let mut pool = self.connection_pool.lock();
                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
                    self.peer.send(
                        connection_id,
                        build_initial_contacts_update(contacts, &pool),
                    )?;
                }

                if should_auto_subscribe_to_channels(zed_version) {
                    subscribe_user_to_channels(user.id, session).await?;
                }

                send_dev_server_projects_update(user.id, dev_server_projects, session).await;

                if let Some(incoming_call) =
                    self.app_state.db.incoming_call_for_user(user.id).await?
                {
                    self.peer.send(connection_id, incoming_call)?;
                }

                update_user_contacts(user.id, session).await?;
            }
            Principal::DevServer(dev_server) => {
                // Scoped so the pool lock is released before the awaits below.
                {
                    let mut pool = self.connection_pool.lock();
                    // Only one connection per dev server is kept: tell any previous
                    // connection to shut down and drop it from the pool.
                    if let Some(stale_connection_id) = pool.dev_server_connection_id(dev_server.id)
                    {
                        self.peer.send(
                            stale_connection_id,
                            proto::ShutdownDevServer {
                                reason: Some(
                                    "another dev server connected with the same token".to_string(),
                                ),
                            },
                        )?;
                        pool.remove_connection(stale_connection_id)?;
                    };
                    pool.add_dev_server(connection_id, dev_server.id, zed_version);
                }

                let projects = self
                    .app_state
                    .db
                    .get_projects_for_dev_server(dev_server.id)
                    .await?;
                self.peer
                    .send(connection_id, proto::DevServerInstructions { projects })?;

                // Let the dev server's owner see the updated project status.
                let status = self
                    .app_state
                    .db
                    .dev_server_projects_update(dev_server.user_id)
                    .await?;
                send_dev_server_projects_update(dev_server.user_id, status, session).await;
            }
        }

        Ok(())
    }
1164
1165 pub async fn invite_code_redeemed(
1166 self: &Arc<Self>,
1167 inviter_id: UserId,
1168 invitee_id: UserId,
1169 ) -> Result<()> {
1170 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1171 if let Some(code) = &user.invite_code {
1172 let pool = self.connection_pool.lock();
1173 let invitee_contact = contact_for_user(invitee_id, false, &pool);
1174 for connection_id in pool.user_connection_ids(inviter_id) {
1175 self.peer.send(
1176 connection_id,
1177 proto::UpdateContacts {
1178 contacts: vec![invitee_contact.clone()],
1179 ..Default::default()
1180 },
1181 )?;
1182 self.peer.send(
1183 connection_id,
1184 proto::UpdateInviteInfo {
1185 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1186 count: user.invite_count as u32,
1187 },
1188 )?;
1189 }
1190 }
1191 }
1192 Ok(())
1193 }
1194
1195 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1196 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1197 if let Some(invite_code) = &user.invite_code {
1198 let pool = self.connection_pool.lock();
1199 for connection_id in pool.user_connection_ids(user_id) {
1200 self.peer.send(
1201 connection_id,
1202 proto::UpdateInviteInfo {
1203 url: format!(
1204 "{}{}",
1205 self.app_state.config.invite_link_prefix, invite_code
1206 ),
1207 count: user.invite_count as u32,
1208 },
1209 )?;
1210 }
1211 }
1212 }
1213 Ok(())
1214 }
1215
1216 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1217 ServerSnapshot {
1218 connection_pool: ConnectionPoolGuard {
1219 guard: self.connection_pool.lock(),
1220 _not_send: PhantomData,
1221 },
1222 peer: &self.peer,
1223 }
1224 }
1225}
1226
1227impl<'a> Deref for ConnectionPoolGuard<'a> {
1228 type Target = ConnectionPool;
1229
1230 fn deref(&self) -> &Self::Target {
1231 &self.guard
1232 }
1233}
1234
1235impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1236 fn deref_mut(&mut self) -> &mut Self::Target {
1237 &mut self.guard
1238 }
1239}
1240
impl<'a> Drop for ConnectionPoolGuard<'a> {
    /// In test builds only, runs the pool's `check_invariants` every time a guard is
    /// dropped; in non-test builds this drop does nothing extra.
    fn drop(&mut self) {
        #[cfg(test)]
        self.check_invariants();
    }
}
1247
1248fn broadcast<F>(
1249 sender_id: Option<ConnectionId>,
1250 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1251 mut f: F,
1252) where
1253 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1254{
1255 for receiver_id in receiver_ids {
1256 if Some(receiver_id) != sender_id {
1257 if let Err(error) = f(receiver_id) {
1258 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1259 }
1260 }
1261 }
1262}
1263
1264pub struct ProtocolVersion(u32);
1265
1266impl Header for ProtocolVersion {
1267 fn name() -> &'static HeaderName {
1268 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1269 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1270 }
1271
1272 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1273 where
1274 Self: Sized,
1275 I: Iterator<Item = &'i axum::http::HeaderValue>,
1276 {
1277 let version = values
1278 .next()
1279 .ok_or_else(axum::headers::Error::invalid)?
1280 .to_str()
1281 .map_err(|_| axum::headers::Error::invalid())?
1282 .parse()
1283 .map_err(|_| axum::headers::Error::invalid())?;
1284 Ok(Self(version))
1285 }
1286
1287 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1288 values.extend([self.0.to_string().parse().unwrap()]);
1289 }
1290}
1291
1292pub struct AppVersionHeader(SemanticVersion);
1293impl Header for AppVersionHeader {
1294 fn name() -> &'static HeaderName {
1295 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1296 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1297 }
1298
1299 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1300 where
1301 Self: Sized,
1302 I: Iterator<Item = &'i axum::http::HeaderValue>,
1303 {
1304 let version = values
1305 .next()
1306 .ok_or_else(axum::headers::Error::invalid)?
1307 .to_str()
1308 .map_err(|_| axum::headers::Error::invalid())?
1309 .parse()
1310 .map_err(|_| axum::headers::Error::invalid())?;
1311 Ok(Self(version))
1312 }
1313
1314 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1315 values.extend([self.0.to_string().parse().unwrap()]);
1316 }
1317}
1318
1319pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1320 Router::new()
1321 .route("/rpc", get(handle_websocket_request))
1322 .layer(
1323 ServiceBuilder::new()
1324 .layer(Extension(server.app_state.clone()))
1325 .layer(middleware::from_fn(auth::validate_header)),
1326 )
1327 .route("/metrics", get(handle_metrics))
1328 .layer(Extension(server))
1329}
1330
1331pub async fn handle_websocket_request(
1332 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1333 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1334 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1335 Extension(server): Extension<Arc<Server>>,
1336 Extension(principal): Extension<Principal>,
1337 country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1338 ws: WebSocketUpgrade,
1339) -> axum::response::Response {
1340 if protocol_version != rpc::PROTOCOL_VERSION {
1341 return (
1342 StatusCode::UPGRADE_REQUIRED,
1343 "client must be upgraded".to_string(),
1344 )
1345 .into_response();
1346 }
1347
1348 let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1349 return (
1350 StatusCode::UPGRADE_REQUIRED,
1351 "no version header found".to_string(),
1352 )
1353 .into_response();
1354 };
1355
1356 if !version.can_collaborate() {
1357 return (
1358 StatusCode::UPGRADE_REQUIRED,
1359 "client must be upgraded".to_string(),
1360 )
1361 .into_response();
1362 }
1363
1364 let socket_address = socket_address.to_string();
1365 ws.on_upgrade(move |socket| {
1366 let socket = socket
1367 .map_ok(to_tungstenite_message)
1368 .err_into()
1369 .with(|message| async move { to_axum_message(message) });
1370 let connection = Connection::new(Box::pin(socket));
1371 async move {
1372 server
1373 .handle_connection(
1374 connection,
1375 socket_address,
1376 principal,
1377 version,
1378 country_code_header.map(|header| header.to_string()),
1379 None,
1380 Executor::Production,
1381 )
1382 .await;
1383 }
1384 })
1385}
1386
1387pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1388 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1389 let connections_metric = CONNECTIONS_METRIC
1390 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1391
1392 let connections = server
1393 .connection_pool
1394 .lock()
1395 .connections()
1396 .filter(|connection| !connection.admin)
1397 .count();
1398 connections_metric.set(connections as _);
1399
1400 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1401 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1402 register_int_gauge!(
1403 "shared_projects",
1404 "number of open projects with one or more guests"
1405 )
1406 .unwrap()
1407 });
1408
1409 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1410 shared_projects_metric.set(shared_projects as _);
1411
1412 let encoder = prometheus::TextEncoder::new();
1413 let metric_families = prometheus::gather();
1414 let encoded_metrics = encoder
1415 .encode_to_string(&metric_families)
1416 .map_err(|err| anyhow!("{}", err))?;
1417 Ok(encoded_metrics)
1418}
1419
/// Cleans up after a connection has gone away.
///
/// Immediately disconnects the peer, removes the connection from the pool, and
/// records the loss in the database (a database error there is only logged).
/// It then waits for `RECONNECT_TIMEOUT`. If the `teardown` watch changes first,
/// nothing more is done. Otherwise the principal's remaining resources are released:
///
/// * users leave their room and channel buffers, decline any pending call when they
///   have no other connection online, and have their contacts updated;
/// * dev servers are handed to `lost_dev_server_connection`.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // Wait out the reconnect window. `select_biased!` polls the branches in order,
    // but the sleep only completes once the timeout has elapsed.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            match &session.principal {
                Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
                    // Presumably infallible: the principal was just matched as a user.
                    let session = session.for_user().unwrap();

                    log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
                    leave_room_for_session(&session, session.connection_id).await.trace_err();
                    leave_channel_buffers_for_session(&session)
                        .await
                        .trace_err();

                    // Only decline a pending call if no other connection of this
                    // user is still online.
                    if !session
                        .connection_pool()
                        .await
                        .is_user_online(session.user_id())
                    {
                        let db = session.db().await;
                        if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                            room_updated(&room, &session.peer);
                        }
                    }

                    update_user_contacts(session.user_id(), &session).await?;
                },
                Principal::DevServer(_) => {
                    // Presumably infallible: the principal was just matched as a dev server.
                    lost_dev_server_connection(&session.for_dev_server().unwrap()).await?;
                },
            }
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1474
1475/// Acknowledges a ping from a client, used to keep the connection alive.
1476async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1477 response.send(proto::Ack {})?;
1478 Ok(())
1479}
1480
1481/// Creates a new room for calling (outside of channels)
1482async fn create_room(
1483 _request: proto::CreateRoom,
1484 response: Response<proto::CreateRoom>,
1485 session: UserSession,
1486) -> Result<()> {
1487 let live_kit_room = nanoid::nanoid!(30);
1488
1489 let live_kit_connection_info = util::maybe!(async {
1490 let live_kit = session.app_state.live_kit_client.as_ref();
1491 let live_kit = live_kit?;
1492 let user_id = session.user_id().to_string();
1493
1494 let token = live_kit
1495 .room_token(&live_kit_room, &user_id.to_string())
1496 .trace_err()?;
1497
1498 Some(proto::LiveKitConnectionInfo {
1499 server_url: live_kit.url().into(),
1500 token,
1501 can_publish: true,
1502 })
1503 })
1504 .await;
1505
1506 let room = session
1507 .db()
1508 .await
1509 .create_room(session.user_id(), session.connection_id, &live_kit_room)
1510 .await?;
1511
1512 response.send(proto::CreateRoomResponse {
1513 room: Some(room.clone()),
1514 live_kit_connection_info,
1515 })?;
1516
1517 update_user_contacts(session.user_id(), &session).await?;
1518 Ok(())
1519}
1520
/// Join a room from an invitation. Equivalent to joining a channel if there is one.
///
/// For rooms that are not tied to a channel: joins in the database, broadcasts the
/// room update, and sends `CallCanceled` to every connection of the joining user
/// (presumably to stop the invitation ringing elsewhere). It then replies with the
/// room and, when LiveKit is configured, a publish-capable token, and finally
/// updates the user's contacts.
async fn join_room(
    request: proto::JoinRoom,
    response: Response<proto::JoinRoom>,
    session: UserSession,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.id);

    let channel_id = session.db().await.channel_id_for_room(room_id).await?;

    // Channel rooms are joined through the channel path instead.
    if let Some(channel_id) = channel_id {
        return join_channel_internal(channel_id, Box::new(response), session).await;
    }

    // The guard returned by `join_room` is consumed by `into_inner()` at the end of
    // this block, before the connection pool is locked below.
    let joined_room = {
        let room = session
            .db()
            .await
            .join_room(room_id, session.user_id(), session.connection_id)
            .await?;
        room_updated(&room.room, &session.peer);
        room.into_inner()
    };

    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(session.user_id())
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }

    // A token-minting failure is logged and results in `None`, not an error.
    let live_kit_connection_info =
        if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
            live_kit
                .room_token(
                    &joined_room.room.live_kit_room,
                    &session.user_id().to_string(),
                )
                .trace_err()
                .map(|token| proto::LiveKitConnectionInfo {
                    server_url: live_kit.url().into(),
                    token,
                    can_publish: true,
                })
        } else {
            None
        };

    response.send(proto::JoinRoomResponse {
        room: Some(joined_room.room),
        channel_id: None,
        live_kit_connection_info,
    })?;

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1587
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Restores the caller's room membership in the database and replies with the room
/// plus the reshared (hosted) and rejoined (guest) projects. It broadcasts the room
/// update and points every collaborator of those projects at the caller's new
/// connection id. If the room belongs to a channel, the channel is refreshed as
/// well. Finally the caller's contacts are updated.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: UserSession,
) -> Result<()> {
    let room;
    let channel;
    // `rejoined_room` is consumed by `into_inner()` at the end of this block, before
    // the connection pool is locked below.
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        // For projects the caller hosts: tell collaborators about the new peer id,
        // then forward the current worktree list to everyone but the caller.
        for project in &rejoined_room.reshared_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        // For projects the caller had joined as a guest: stream their state back.
        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1679
1680fn notify_rejoined_projects(
1681 rejoined_projects: &mut Vec<RejoinedProject>,
1682 session: &UserSession,
1683) -> Result<()> {
1684 for project in rejoined_projects.iter() {
1685 for collaborator in &project.collaborators {
1686 session
1687 .peer
1688 .send(
1689 collaborator.connection_id,
1690 proto::UpdateProjectCollaborator {
1691 project_id: project.id.to_proto(),
1692 old_peer_id: Some(project.old_connection_id.into()),
1693 new_peer_id: Some(session.connection_id.into()),
1694 },
1695 )
1696 .trace_err();
1697 }
1698 }
1699
1700 for project in rejoined_projects {
1701 for worktree in mem::take(&mut project.worktrees) {
1702 #[cfg(any(test, feature = "test-support"))]
1703 const MAX_CHUNK_SIZE: usize = 2;
1704 #[cfg(not(any(test, feature = "test-support")))]
1705 const MAX_CHUNK_SIZE: usize = 256;
1706
1707 // Stream this worktree's entries.
1708 let message = proto::UpdateWorktree {
1709 project_id: project.id.to_proto(),
1710 worktree_id: worktree.id,
1711 abs_path: worktree.abs_path.clone(),
1712 root_name: worktree.root_name,
1713 updated_entries: worktree.updated_entries,
1714 removed_entries: worktree.removed_entries,
1715 scan_id: worktree.scan_id,
1716 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1717 updated_repositories: worktree.updated_repositories,
1718 removed_repositories: worktree.removed_repositories,
1719 };
1720 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1721 session.peer.send(session.connection_id, update.clone())?;
1722 }
1723
1724 // Stream this worktree's diagnostics.
1725 for summary in worktree.diagnostic_summaries {
1726 session.peer.send(
1727 session.connection_id,
1728 proto::UpdateDiagnosticSummary {
1729 project_id: project.id.to_proto(),
1730 worktree_id: worktree.id,
1731 summary: Some(summary),
1732 },
1733 )?;
1734 }
1735
1736 for settings_file in worktree.settings_files {
1737 session.peer.send(
1738 session.connection_id,
1739 proto::UpdateWorktreeSettings {
1740 project_id: project.id.to_proto(),
1741 worktree_id: worktree.id,
1742 path: settings_file.path,
1743 content: Some(settings_file.content),
1744 },
1745 )?;
1746 }
1747 }
1748
1749 for language_server in &project.language_servers {
1750 session.peer.send(
1751 session.connection_id,
1752 proto::UpdateLanguageServer {
1753 project_id: project.id.to_proto(),
1754 language_server_id: language_server.id,
1755 variant: Some(
1756 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1757 proto::LspDiskBasedDiagnosticsUpdated {},
1758 ),
1759 ),
1760 },
1761 )?;
1762 }
1763 }
1764 Ok(())
1765}
1766
1767/// leave room disconnects from the room.
1768async fn leave_room(
1769 _: proto::LeaveRoom,
1770 response: Response<proto::LeaveRoom>,
1771 session: UserSession,
1772) -> Result<()> {
1773 leave_room_for_session(&session, session.connection_id).await?;
1774 response.send(proto::Ack {})?;
1775 Ok(())
1776}
1777
1778/// Updates the permissions of someone else in the room.
1779async fn set_room_participant_role(
1780 request: proto::SetRoomParticipantRole,
1781 response: Response<proto::SetRoomParticipantRole>,
1782 session: UserSession,
1783) -> Result<()> {
1784 let user_id = UserId::from_proto(request.user_id);
1785 let role = ChannelRole::from(request.role());
1786
1787 let (live_kit_room, can_publish) = {
1788 let room = session
1789 .db()
1790 .await
1791 .set_room_participant_role(
1792 session.user_id(),
1793 RoomId::from_proto(request.room_id),
1794 user_id,
1795 role,
1796 )
1797 .await?;
1798
1799 let live_kit_room = room.live_kit_room.clone();
1800 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1801 room_updated(&room, &session.peer);
1802 (live_kit_room, can_publish)
1803 };
1804
1805 if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
1806 live_kit
1807 .update_participant(
1808 live_kit_room.clone(),
1809 request.user_id.to_string(),
1810 live_kit_server::proto::ParticipantPermission {
1811 can_subscribe: true,
1812 can_publish,
1813 can_publish_data: can_publish,
1814 hidden: false,
1815 recorder: false,
1816 },
1817 )
1818 .await
1819 .trace_err();
1820 }
1821
1822 response.send(proto::Ack {})?;
1823 Ok(())
1824}
1825
/// Call someone else into the current room.
///
/// Only contacts may be called. Records the call in the database, broadcasts the
/// room update, and rings every connection of the called user. The request succeeds
/// as soon as one connection acknowledges the ring. If none do, the call is marked
/// as failed, the room is broadcast again, and an error is returned.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: UserSession,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    // Move the incoming-call payload out of the database result so that result is
    // dropped at the end of this block.
    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(room, &session.peer);
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring all of the callee's connections concurrently.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    // The first successful ring completes the request; failures are logged and
    // the remaining rings are still awaited.
    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    // No connection answered: record the failure and let the room know.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1894
1895/// Cancel an outgoing call.
1896async fn cancel_call(
1897 request: proto::CancelCall,
1898 response: Response<proto::CancelCall>,
1899 session: UserSession,
1900) -> Result<()> {
1901 let called_user_id = UserId::from_proto(request.called_user_id);
1902 let room_id = RoomId::from_proto(request.room_id);
1903 {
1904 let room = session
1905 .db()
1906 .await
1907 .cancel_call(room_id, session.connection_id, called_user_id)
1908 .await?;
1909 room_updated(&room, &session.peer);
1910 }
1911
1912 for connection_id in session
1913 .connection_pool()
1914 .await
1915 .user_connection_ids(called_user_id)
1916 {
1917 session
1918 .peer
1919 .send(
1920 connection_id,
1921 proto::CallCanceled {
1922 room_id: room_id.to_proto(),
1923 },
1924 )
1925 .trace_err();
1926 }
1927 response.send(proto::Ack {})?;
1928
1929 update_user_contacts(called_user_id, &session).await?;
1930 Ok(())
1931}
1932
1933/// Decline an incoming call.
1934async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1935 let room_id = RoomId::from_proto(message.room_id);
1936 {
1937 let room = session
1938 .db()
1939 .await
1940 .decline_call(Some(room_id), session.user_id())
1941 .await?
1942 .ok_or_else(|| anyhow!("failed to decline call"))?;
1943 room_updated(&room, &session.peer);
1944 }
1945
1946 for connection_id in session
1947 .connection_pool()
1948 .await
1949 .user_connection_ids(session.user_id())
1950 {
1951 session
1952 .peer
1953 .send(
1954 connection_id,
1955 proto::CallCanceled {
1956 room_id: room_id.to_proto(),
1957 },
1958 )
1959 .trace_err();
1960 }
1961 update_user_contacts(session.user_id(), &session).await?;
1962 Ok(())
1963}
1964
1965/// Updates other participants in the room with your current location.
1966async fn update_participant_location(
1967 request: proto::UpdateParticipantLocation,
1968 response: Response<proto::UpdateParticipantLocation>,
1969 session: UserSession,
1970) -> Result<()> {
1971 let room_id = RoomId::from_proto(request.room_id);
1972 let location = request
1973 .location
1974 .ok_or_else(|| anyhow!("invalid location"))?;
1975
1976 let db = session.db().await;
1977 let room = db
1978 .update_room_participant_location(room_id, session.connection_id, location)
1979 .await?;
1980
1981 room_updated(&room, &session.peer);
1982 response.send(proto::Ack {})?;
1983 Ok(())
1984}
1985
1986/// Share a project into the room.
1987async fn share_project(
1988 request: proto::ShareProject,
1989 response: Response<proto::ShareProject>,
1990 session: UserSession,
1991) -> Result<()> {
1992 let (project_id, room) = &*session
1993 .db()
1994 .await
1995 .share_project(
1996 RoomId::from_proto(request.room_id),
1997 session.connection_id,
1998 &request.worktrees,
1999 request.is_ssh_project,
2000 request
2001 .dev_server_project_id
2002 .map(DevServerProjectId::from_proto),
2003 )
2004 .await?;
2005 response.send(proto::ShareProjectResponse {
2006 project_id: project_id.to_proto(),
2007 })?;
2008 room_updated(room, &session.peer);
2009
2010 Ok(())
2011}
2012
2013/// Unshare a project from the room.
2014async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
2015 let project_id = ProjectId::from_proto(message.project_id);
2016 unshare_project_internal(
2017 project_id,
2018 session.connection_id,
2019 session.user_id(),
2020 &session,
2021 )
2022 .await
2023}
2024
2025async fn unshare_project_internal(
2026 project_id: ProjectId,
2027 connection_id: ConnectionId,
2028 user_id: Option<UserId>,
2029 session: &Session,
2030) -> Result<()> {
2031 let delete = {
2032 let room_guard = session
2033 .db()
2034 .await
2035 .unshare_project(project_id, connection_id, user_id)
2036 .await?;
2037
2038 let (delete, room, guest_connection_ids) = &*room_guard;
2039
2040 let message = proto::UnshareProject {
2041 project_id: project_id.to_proto(),
2042 };
2043
2044 broadcast(
2045 Some(connection_id),
2046 guest_connection_ids.iter().copied(),
2047 |conn_id| session.peer.send(conn_id, message.clone()),
2048 );
2049 if let Some(room) = room {
2050 room_updated(room, &session.peer);
2051 }
2052
2053 *delete
2054 };
2055
2056 if delete {
2057 let db = session.db().await;
2058 db.delete_project(project_id).await?;
2059 }
2060
2061 Ok(())
2062}
2063
2064/// DevServer makes a project available online
2065async fn share_dev_server_project(
2066 request: proto::ShareDevServerProject,
2067 response: Response<proto::ShareDevServerProject>,
2068 session: DevServerSession,
2069) -> Result<()> {
2070 let (dev_server_project, user_id, status) = session
2071 .db()
2072 .await
2073 .share_dev_server_project(
2074 DevServerProjectId::from_proto(request.dev_server_project_id),
2075 session.dev_server_id(),
2076 session.connection_id,
2077 &request.worktrees,
2078 )
2079 .await?;
2080 let Some(project_id) = dev_server_project.project_id else {
2081 return Err(anyhow!("failed to share remote project"))?;
2082 };
2083
2084 send_dev_server_projects_update(user_id, status, &session).await;
2085
2086 response.send(proto::ShareProjectResponse { project_id })?;
2087
2088 Ok(())
2089}
2090
2091/// Join someone elses shared project.
2092async fn join_project(
2093 request: proto::JoinProject,
2094 response: Response<proto::JoinProject>,
2095 session: UserSession,
2096) -> Result<()> {
2097 let project_id = ProjectId::from_proto(request.project_id);
2098
2099 tracing::info!(%project_id, "join project");
2100
2101 let db = session.db().await;
2102 let (project, replica_id) = &mut *db
2103 .join_project(project_id, session.connection_id, session.user_id())
2104 .await?;
2105 drop(db);
2106 tracing::info!(%project_id, "join remote project");
2107 join_project_internal(response, session, project, replica_id)
2108}
2109
/// Abstraction over response handles that can carry a `proto::JoinProjectResponse`,
/// letting `join_project_internal` serve both `JoinProject` and `JoinHostedProject`.
trait JoinProjectInternalResponse {
    /// Sends `result` as the reply, consuming the response handle.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
/// Delegates to `Response::<proto::JoinProject>::send`.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinProject>::send(self, result)
    }
}
/// Delegates to `Response::<proto::JoinHostedProject>::send`.
impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinHostedProject>::send(self, result)
    }
}
2123
2124fn join_project_internal(
2125 response: impl JoinProjectInternalResponse,
2126 session: UserSession,
2127 project: &mut Project,
2128 replica_id: &ReplicaId,
2129) -> Result<()> {
2130 let collaborators = project
2131 .collaborators
2132 .iter()
2133 .filter(|collaborator| collaborator.connection_id != session.connection_id)
2134 .map(|collaborator| collaborator.to_proto())
2135 .collect::<Vec<_>>();
2136 let project_id = project.id;
2137 let guest_user_id = session.user_id();
2138
2139 let worktrees = project
2140 .worktrees
2141 .iter()
2142 .map(|(id, worktree)| proto::WorktreeMetadata {
2143 id: *id,
2144 root_name: worktree.root_name.clone(),
2145 visible: worktree.visible,
2146 abs_path: worktree.abs_path.clone(),
2147 })
2148 .collect::<Vec<_>>();
2149
2150 let add_project_collaborator = proto::AddProjectCollaborator {
2151 project_id: project_id.to_proto(),
2152 collaborator: Some(proto::Collaborator {
2153 peer_id: Some(session.connection_id.into()),
2154 replica_id: replica_id.0 as u32,
2155 user_id: guest_user_id.to_proto(),
2156 }),
2157 };
2158
2159 for collaborator in &collaborators {
2160 session
2161 .peer
2162 .send(
2163 collaborator.peer_id.unwrap().into(),
2164 add_project_collaborator.clone(),
2165 )
2166 .trace_err();
2167 }
2168
2169 // First, we send the metadata associated with each worktree.
2170 response.send(proto::JoinProjectResponse {
2171 project_id: project.id.0 as u64,
2172 worktrees: worktrees.clone(),
2173 replica_id: replica_id.0 as u32,
2174 collaborators: collaborators.clone(),
2175 language_servers: project.language_servers.clone(),
2176 role: project.role.into(),
2177 dev_server_project_id: project
2178 .dev_server_project_id
2179 .map(|dev_server_project_id| dev_server_project_id.0 as u64),
2180 })?;
2181
2182 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2183 #[cfg(any(test, feature = "test-support"))]
2184 const MAX_CHUNK_SIZE: usize = 2;
2185 #[cfg(not(any(test, feature = "test-support")))]
2186 const MAX_CHUNK_SIZE: usize = 256;
2187
2188 // Stream this worktree's entries.
2189 let message = proto::UpdateWorktree {
2190 project_id: project_id.to_proto(),
2191 worktree_id,
2192 abs_path: worktree.abs_path.clone(),
2193 root_name: worktree.root_name,
2194 updated_entries: worktree.entries,
2195 removed_entries: Default::default(),
2196 scan_id: worktree.scan_id,
2197 is_last_update: worktree.scan_id == worktree.completed_scan_id,
2198 updated_repositories: worktree.repository_entries.into_values().collect(),
2199 removed_repositories: Default::default(),
2200 };
2201 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
2202 session.peer.send(session.connection_id, update.clone())?;
2203 }
2204
2205 // Stream this worktree's diagnostics.
2206 for summary in worktree.diagnostic_summaries {
2207 session.peer.send(
2208 session.connection_id,
2209 proto::UpdateDiagnosticSummary {
2210 project_id: project_id.to_proto(),
2211 worktree_id: worktree.id,
2212 summary: Some(summary),
2213 },
2214 )?;
2215 }
2216
2217 for settings_file in worktree.settings_files {
2218 session.peer.send(
2219 session.connection_id,
2220 proto::UpdateWorktreeSettings {
2221 project_id: project_id.to_proto(),
2222 worktree_id: worktree.id,
2223 path: settings_file.path,
2224 content: Some(settings_file.content),
2225 },
2226 )?;
2227 }
2228 }
2229
2230 for language_server in &project.language_servers {
2231 session.peer.send(
2232 session.connection_id,
2233 proto::UpdateLanguageServer {
2234 project_id: project_id.to_proto(),
2235 language_server_id: language_server.id,
2236 variant: Some(
2237 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2238 proto::LspDiskBasedDiagnosticsUpdated {},
2239 ),
2240 ),
2241 },
2242 )?;
2243 }
2244
2245 Ok(())
2246}
2247
2248/// Leave someone elses shared project.
2249async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2250 let sender_id = session.connection_id;
2251 let project_id = ProjectId::from_proto(request.project_id);
2252 let db = session.db().await;
2253 if db.is_hosted_project(project_id).await? {
2254 let project = db.leave_hosted_project(project_id, sender_id).await?;
2255 project_left(&project, &session);
2256 return Ok(());
2257 }
2258
2259 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2260 tracing::info!(
2261 %project_id,
2262 "leave project"
2263 );
2264
2265 project_left(project, &session);
2266 if let Some(room) = room {
2267 room_updated(room, &session.peer);
2268 }
2269
2270 Ok(())
2271}
2272
2273async fn join_hosted_project(
2274 request: proto::JoinHostedProject,
2275 response: Response<proto::JoinHostedProject>,
2276 session: UserSession,
2277) -> Result<()> {
2278 let (mut project, replica_id) = session
2279 .db()
2280 .await
2281 .join_hosted_project(
2282 ProjectId(request.project_id as i32),
2283 session.user_id(),
2284 session.connection_id,
2285 )
2286 .await?;
2287
2288 join_project_internal(response, session, &mut project, &replica_id)
2289}
2290
2291async fn list_remote_directory(
2292 request: proto::ListRemoteDirectory,
2293 response: Response<proto::ListRemoteDirectory>,
2294 session: UserSession,
2295) -> Result<()> {
2296 let dev_server_id = DevServerId(request.dev_server_id as i32);
2297 let dev_server_connection_id = session
2298 .connection_pool()
2299 .await
2300 .dev_server_connection_id_supporting(dev_server_id, ZedVersion::with_list_directory())?;
2301
2302 session
2303 .db()
2304 .await
2305 .get_dev_server_for_user(dev_server_id, session.user_id())
2306 .await?;
2307
2308 response.send(
2309 session
2310 .peer
2311 .forward_request(session.connection_id, dev_server_connection_id, request)
2312 .await?,
2313 )?;
2314 Ok(())
2315}
2316
2317async fn update_dev_server_project(
2318 request: proto::UpdateDevServerProject,
2319 response: Response<proto::UpdateDevServerProject>,
2320 session: UserSession,
2321) -> Result<()> {
2322 let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2323
2324 let (dev_server_project, update) = session
2325 .db()
2326 .await
2327 .update_dev_server_project(dev_server_project_id, &request.paths, session.user_id())
2328 .await?;
2329
2330 let projects = session
2331 .db()
2332 .await
2333 .get_projects_for_dev_server(dev_server_project.dev_server_id)
2334 .await?;
2335
2336 let dev_server_connection_id = session
2337 .connection_pool()
2338 .await
2339 .dev_server_connection_id_supporting(
2340 dev_server_project.dev_server_id,
2341 ZedVersion::with_list_directory(),
2342 )?;
2343
2344 session.peer.send(
2345 dev_server_connection_id,
2346 proto::DevServerInstructions { projects },
2347 )?;
2348
2349 send_dev_server_projects_update(session.user_id(), update, &session).await;
2350
2351 response.send(proto::Ack {})
2352}
2353
2354async fn create_dev_server_project(
2355 request: proto::CreateDevServerProject,
2356 response: Response<proto::CreateDevServerProject>,
2357 session: UserSession,
2358) -> Result<()> {
2359 let dev_server_id = DevServerId(request.dev_server_id as i32);
2360 let dev_server_connection_id = session
2361 .connection_pool()
2362 .await
2363 .dev_server_connection_id(dev_server_id);
2364 let Some(dev_server_connection_id) = dev_server_connection_id else {
2365 Err(ErrorCode::DevServerOffline
2366 .message("Cannot create a remote project when the dev server is offline".to_string())
2367 .anyhow())?
2368 };
2369
2370 let path = request.path.clone();
2371 //Check that the path exists on the dev server
2372 session
2373 .peer
2374 .forward_request(
2375 session.connection_id,
2376 dev_server_connection_id,
2377 proto::ValidateDevServerProjectRequest { path: path.clone() },
2378 )
2379 .await?;
2380
2381 let (dev_server_project, update) = session
2382 .db()
2383 .await
2384 .create_dev_server_project(
2385 DevServerId(request.dev_server_id as i32),
2386 &request.path,
2387 session.user_id(),
2388 )
2389 .await?;
2390
2391 let projects = session
2392 .db()
2393 .await
2394 .get_projects_for_dev_server(dev_server_project.dev_server_id)
2395 .await?;
2396
2397 session.peer.send(
2398 dev_server_connection_id,
2399 proto::DevServerInstructions { projects },
2400 )?;
2401
2402 send_dev_server_projects_update(session.user_id(), update, &session).await;
2403
2404 response.send(proto::CreateDevServerProjectResponse {
2405 dev_server_project: Some(dev_server_project.to_proto(None)),
2406 })?;
2407 Ok(())
2408}
2409
2410async fn create_dev_server(
2411 request: proto::CreateDevServer,
2412 response: Response<proto::CreateDevServer>,
2413 session: UserSession,
2414) -> Result<()> {
2415 let access_token = auth::random_token();
2416 let hashed_access_token = auth::hash_access_token(&access_token);
2417
2418 if request.name.is_empty() {
2419 return Err(proto::ErrorCode::Forbidden
2420 .message("Dev server name cannot be empty".to_string())
2421 .anyhow())?;
2422 }
2423
2424 let (dev_server, status) = session
2425 .db()
2426 .await
2427 .create_dev_server(
2428 &request.name,
2429 request.ssh_connection_string.as_deref(),
2430 &hashed_access_token,
2431 session.user_id(),
2432 )
2433 .await?;
2434
2435 send_dev_server_projects_update(session.user_id(), status, &session).await;
2436
2437 response.send(proto::CreateDevServerResponse {
2438 dev_server_id: dev_server.id.0 as u64,
2439 access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2440 name: request.name,
2441 })?;
2442 Ok(())
2443}
2444
2445async fn regenerate_dev_server_token(
2446 request: proto::RegenerateDevServerToken,
2447 response: Response<proto::RegenerateDevServerToken>,
2448 session: UserSession,
2449) -> Result<()> {
2450 let dev_server_id = DevServerId(request.dev_server_id as i32);
2451 let access_token = auth::random_token();
2452 let hashed_access_token = auth::hash_access_token(&access_token);
2453
2454 let connection_id = session
2455 .connection_pool()
2456 .await
2457 .dev_server_connection_id(dev_server_id);
2458 if let Some(connection_id) = connection_id {
2459 shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2460 session.peer.send(
2461 connection_id,
2462 proto::ShutdownDevServer {
2463 reason: Some("dev server token was regenerated".to_string()),
2464 },
2465 )?;
2466 let _ = remove_dev_server_connection(dev_server_id, &session).await;
2467 }
2468
2469 let status = session
2470 .db()
2471 .await
2472 .update_dev_server_token(dev_server_id, &hashed_access_token, session.user_id())
2473 .await?;
2474
2475 send_dev_server_projects_update(session.user_id(), status, &session).await;
2476
2477 response.send(proto::RegenerateDevServerTokenResponse {
2478 dev_server_id: dev_server_id.to_proto(),
2479 access_token: auth::generate_dev_server_token(dev_server_id.0 as usize, access_token),
2480 })?;
2481 Ok(())
2482}
2483
2484async fn rename_dev_server(
2485 request: proto::RenameDevServer,
2486 response: Response<proto::RenameDevServer>,
2487 session: UserSession,
2488) -> Result<()> {
2489 if request.name.trim().is_empty() {
2490 return Err(proto::ErrorCode::Forbidden
2491 .message("Dev server name cannot be empty".to_string())
2492 .anyhow())?;
2493 }
2494
2495 let dev_server_id = DevServerId(request.dev_server_id as i32);
2496 let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2497 if dev_server.user_id != session.user_id() {
2498 return Err(anyhow!(ErrorCode::Forbidden))?;
2499 }
2500
2501 let status = session
2502 .db()
2503 .await
2504 .rename_dev_server(
2505 dev_server_id,
2506 &request.name,
2507 request.ssh_connection_string.as_deref(),
2508 session.user_id(),
2509 )
2510 .await?;
2511
2512 send_dev_server_projects_update(session.user_id(), status, &session).await;
2513
2514 response.send(proto::Ack {})?;
2515 Ok(())
2516}
2517
2518async fn delete_dev_server(
2519 request: proto::DeleteDevServer,
2520 response: Response<proto::DeleteDevServer>,
2521 session: UserSession,
2522) -> Result<()> {
2523 let dev_server_id = DevServerId(request.dev_server_id as i32);
2524 let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2525 if dev_server.user_id != session.user_id() {
2526 return Err(anyhow!(ErrorCode::Forbidden))?;
2527 }
2528
2529 let connection_id = session
2530 .connection_pool()
2531 .await
2532 .dev_server_connection_id(dev_server_id);
2533 if let Some(connection_id) = connection_id {
2534 shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2535 session.peer.send(
2536 connection_id,
2537 proto::ShutdownDevServer {
2538 reason: Some("dev server was deleted".to_string()),
2539 },
2540 )?;
2541 let _ = remove_dev_server_connection(dev_server_id, &session).await;
2542 }
2543
2544 let status = session
2545 .db()
2546 .await
2547 .delete_dev_server(dev_server_id, session.user_id())
2548 .await?;
2549
2550 send_dev_server_projects_update(session.user_id(), status, &session).await;
2551
2552 response.send(proto::Ack {})?;
2553 Ok(())
2554}
2555
2556async fn delete_dev_server_project(
2557 request: proto::DeleteDevServerProject,
2558 response: Response<proto::DeleteDevServerProject>,
2559 session: UserSession,
2560) -> Result<()> {
2561 let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2562 let dev_server_project = session
2563 .db()
2564 .await
2565 .get_dev_server_project(dev_server_project_id)
2566 .await?;
2567
2568 let dev_server = session
2569 .db()
2570 .await
2571 .get_dev_server(dev_server_project.dev_server_id)
2572 .await?;
2573 if dev_server.user_id != session.user_id() {
2574 return Err(anyhow!(ErrorCode::Forbidden))?;
2575 }
2576
2577 let dev_server_connection_id = session
2578 .connection_pool()
2579 .await
2580 .dev_server_connection_id(dev_server.id);
2581
2582 if let Some(dev_server_connection_id) = dev_server_connection_id {
2583 let project = session
2584 .db()
2585 .await
2586 .find_dev_server_project(dev_server_project_id)
2587 .await;
2588 if let Ok(project) = project {
2589 unshare_project_internal(
2590 project.id,
2591 dev_server_connection_id,
2592 Some(session.user_id()),
2593 &session,
2594 )
2595 .await?;
2596 }
2597 }
2598
2599 let (projects, status) = session
2600 .db()
2601 .await
2602 .delete_dev_server_project(dev_server_project_id, dev_server.id, session.user_id())
2603 .await?;
2604
2605 if let Some(dev_server_connection_id) = dev_server_connection_id {
2606 session.peer.send(
2607 dev_server_connection_id,
2608 proto::DevServerInstructions { projects },
2609 )?;
2610 }
2611
2612 send_dev_server_projects_update(session.user_id(), status, &session).await;
2613
2614 response.send(proto::Ack {})?;
2615 Ok(())
2616}
2617
2618async fn rejoin_dev_server_projects(
2619 request: proto::RejoinRemoteProjects,
2620 response: Response<proto::RejoinRemoteProjects>,
2621 session: UserSession,
2622) -> Result<()> {
2623 let mut rejoined_projects = {
2624 let db = session.db().await;
2625 db.rejoin_dev_server_projects(
2626 &request.rejoined_projects,
2627 session.user_id(),
2628 session.0.connection_id,
2629 )
2630 .await?
2631 };
2632 response.send(proto::RejoinRemoteProjectsResponse {
2633 rejoined_projects: rejoined_projects
2634 .iter()
2635 .map(|project| project.to_proto())
2636 .collect(),
2637 })?;
2638 notify_rejoined_projects(&mut rejoined_projects, &session)
2639}
2640
2641async fn reconnect_dev_server(
2642 request: proto::ReconnectDevServer,
2643 response: Response<proto::ReconnectDevServer>,
2644 session: DevServerSession,
2645) -> Result<()> {
2646 let reshared_projects = {
2647 let db = session.db().await;
2648 db.reshare_dev_server_projects(
2649 &request.reshared_projects,
2650 session.dev_server_id(),
2651 session.0.connection_id,
2652 )
2653 .await?
2654 };
2655
2656 for project in &reshared_projects {
2657 for collaborator in &project.collaborators {
2658 session
2659 .peer
2660 .send(
2661 collaborator.connection_id,
2662 proto::UpdateProjectCollaborator {
2663 project_id: project.id.to_proto(),
2664 old_peer_id: Some(project.old_connection_id.into()),
2665 new_peer_id: Some(session.connection_id.into()),
2666 },
2667 )
2668 .trace_err();
2669 }
2670
2671 broadcast(
2672 Some(session.connection_id),
2673 project
2674 .collaborators
2675 .iter()
2676 .map(|collaborator| collaborator.connection_id),
2677 |connection_id| {
2678 session.peer.forward_send(
2679 session.connection_id,
2680 connection_id,
2681 proto::UpdateProject {
2682 project_id: project.id.to_proto(),
2683 worktrees: project.worktrees.clone(),
2684 },
2685 )
2686 },
2687 );
2688 }
2689
2690 response.send(proto::ReconnectDevServerResponse {
2691 reshared_projects: reshared_projects
2692 .iter()
2693 .map(|project| proto::ResharedProject {
2694 id: project.id.to_proto(),
2695 collaborators: project
2696 .collaborators
2697 .iter()
2698 .map(|collaborator| collaborator.to_proto())
2699 .collect(),
2700 })
2701 .collect(),
2702 })?;
2703
2704 Ok(())
2705}
2706
2707async fn shutdown_dev_server(
2708 _: proto::ShutdownDevServer,
2709 response: Response<proto::ShutdownDevServer>,
2710 session: DevServerSession,
2711) -> Result<()> {
2712 response.send(proto::Ack {})?;
2713 shutdown_dev_server_internal(session.dev_server_id(), session.connection_id, &session).await?;
2714 remove_dev_server_connection(session.dev_server_id(), &session).await
2715}
2716
2717async fn shutdown_dev_server_internal(
2718 dev_server_id: DevServerId,
2719 connection_id: ConnectionId,
2720 session: &Session,
2721) -> Result<()> {
2722 let (dev_server_projects, dev_server) = {
2723 let db = session.db().await;
2724 let dev_server_projects = db.get_projects_for_dev_server(dev_server_id).await?;
2725 let dev_server = db.get_dev_server(dev_server_id).await?;
2726 (dev_server_projects, dev_server)
2727 };
2728
2729 for project_id in dev_server_projects.iter().filter_map(|p| p.project_id) {
2730 unshare_project_internal(
2731 ProjectId::from_proto(project_id),
2732 connection_id,
2733 None,
2734 session,
2735 )
2736 .await?;
2737 }
2738
2739 session
2740 .connection_pool()
2741 .await
2742 .set_dev_server_offline(dev_server_id);
2743
2744 let status = session
2745 .db()
2746 .await
2747 .dev_server_projects_update(dev_server.user_id)
2748 .await?;
2749 send_dev_server_projects_update(dev_server.user_id, status, session).await;
2750
2751 Ok(())
2752}
2753
2754async fn remove_dev_server_connection(dev_server_id: DevServerId, session: &Session) -> Result<()> {
2755 let dev_server_connection = session
2756 .connection_pool()
2757 .await
2758 .dev_server_connection_id(dev_server_id);
2759
2760 if let Some(dev_server_connection) = dev_server_connection {
2761 session
2762 .connection_pool()
2763 .await
2764 .remove_connection(dev_server_connection)?;
2765 }
2766 Ok(())
2767}
2768
2769/// Updates other participants with changes to the project
2770async fn update_project(
2771 request: proto::UpdateProject,
2772 response: Response<proto::UpdateProject>,
2773 session: Session,
2774) -> Result<()> {
2775 let project_id = ProjectId::from_proto(request.project_id);
2776 let (room, guest_connection_ids) = &*session
2777 .db()
2778 .await
2779 .update_project(project_id, session.connection_id, &request.worktrees)
2780 .await?;
2781 broadcast(
2782 Some(session.connection_id),
2783 guest_connection_ids.iter().copied(),
2784 |connection_id| {
2785 session
2786 .peer
2787 .forward_send(session.connection_id, connection_id, request.clone())
2788 },
2789 );
2790 if let Some(room) = room {
2791 room_updated(room, &session.peer);
2792 }
2793 response.send(proto::Ack {})?;
2794
2795 Ok(())
2796}
2797
2798/// Updates other participants with changes to the worktree
2799async fn update_worktree(
2800 request: proto::UpdateWorktree,
2801 response: Response<proto::UpdateWorktree>,
2802 session: Session,
2803) -> Result<()> {
2804 let guest_connection_ids = session
2805 .db()
2806 .await
2807 .update_worktree(&request, session.connection_id)
2808 .await?;
2809
2810 broadcast(
2811 Some(session.connection_id),
2812 guest_connection_ids.iter().copied(),
2813 |connection_id| {
2814 session
2815 .peer
2816 .forward_send(session.connection_id, connection_id, request.clone())
2817 },
2818 );
2819 response.send(proto::Ack {})?;
2820 Ok(())
2821}
2822
2823/// Updates other participants with changes to the diagnostics
2824async fn update_diagnostic_summary(
2825 message: proto::UpdateDiagnosticSummary,
2826 session: Session,
2827) -> Result<()> {
2828 let guest_connection_ids = session
2829 .db()
2830 .await
2831 .update_diagnostic_summary(&message, session.connection_id)
2832 .await?;
2833
2834 broadcast(
2835 Some(session.connection_id),
2836 guest_connection_ids.iter().copied(),
2837 |connection_id| {
2838 session
2839 .peer
2840 .forward_send(session.connection_id, connection_id, message.clone())
2841 },
2842 );
2843
2844 Ok(())
2845}
2846
2847/// Updates other participants with changes to the worktree settings
2848async fn update_worktree_settings(
2849 message: proto::UpdateWorktreeSettings,
2850 session: Session,
2851) -> Result<()> {
2852 let guest_connection_ids = session
2853 .db()
2854 .await
2855 .update_worktree_settings(&message, session.connection_id)
2856 .await?;
2857
2858 broadcast(
2859 Some(session.connection_id),
2860 guest_connection_ids.iter().copied(),
2861 |connection_id| {
2862 session
2863 .peer
2864 .forward_send(session.connection_id, connection_id, message.clone())
2865 },
2866 );
2867
2868 Ok(())
2869}
2870
2871/// Notify other participants that a language server has started.
2872async fn start_language_server(
2873 request: proto::StartLanguageServer,
2874 session: Session,
2875) -> Result<()> {
2876 let guest_connection_ids = session
2877 .db()
2878 .await
2879 .start_language_server(&request, session.connection_id)
2880 .await?;
2881
2882 broadcast(
2883 Some(session.connection_id),
2884 guest_connection_ids.iter().copied(),
2885 |connection_id| {
2886 session
2887 .peer
2888 .forward_send(session.connection_id, connection_id, request.clone())
2889 },
2890 );
2891 Ok(())
2892}
2893
2894/// Notify other participants that a language server has changed.
2895async fn update_language_server(
2896 request: proto::UpdateLanguageServer,
2897 session: Session,
2898) -> Result<()> {
2899 let project_id = ProjectId::from_proto(request.project_id);
2900 let project_connection_ids = session
2901 .db()
2902 .await
2903 .project_connection_ids(project_id, session.connection_id, true)
2904 .await?;
2905 broadcast(
2906 Some(session.connection_id),
2907 project_connection_ids.iter().copied(),
2908 |connection_id| {
2909 session
2910 .peer
2911 .forward_send(session.connection_id, connection_id, request.clone())
2912 },
2913 );
2914 Ok(())
2915}
2916
2917/// forward a project request to the host. These requests should be read only
2918/// as guests are allowed to send them.
2919async fn forward_read_only_project_request<T>(
2920 request: T,
2921 response: Response<T>,
2922 session: UserSession,
2923) -> Result<()>
2924where
2925 T: EntityMessage + RequestMessage,
2926{
2927 let project_id = ProjectId::from_proto(request.remote_entity_id());
2928 let host_connection_id = session
2929 .db()
2930 .await
2931 .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2932 .await?;
2933 let payload = session
2934 .peer
2935 .forward_request(session.connection_id, host_connection_id, request)
2936 .await?;
2937 response.send(payload)?;
2938 Ok(())
2939}
2940
2941async fn forward_find_search_candidates_request(
2942 request: proto::FindSearchCandidates,
2943 response: Response<proto::FindSearchCandidates>,
2944 session: UserSession,
2945) -> Result<()> {
2946 let project_id = ProjectId::from_proto(request.remote_entity_id());
2947 let host_connection_id = session
2948 .db()
2949 .await
2950 .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2951 .await?;
2952
2953 let host_version = session
2954 .connection_pool()
2955 .await
2956 .connection(host_connection_id)
2957 .map(|c| c.zed_version);
2958
2959 if host_version.is_some_and(|host_version| host_version < ZedVersion::with_search_candidates())
2960 {
2961 let query = request.query.ok_or_else(|| anyhow!("missing query"))?;
2962 let search = proto::SearchProject {
2963 project_id: project_id.to_proto(),
2964 query: query.query,
2965 regex: query.regex,
2966 whole_word: query.whole_word,
2967 case_sensitive: query.case_sensitive,
2968 files_to_include: query.files_to_include,
2969 files_to_exclude: query.files_to_exclude,
2970 include_ignored: query.include_ignored,
2971 };
2972
2973 let payload = session
2974 .peer
2975 .forward_request(session.connection_id, host_connection_id, search)
2976 .await?;
2977 return response.send(proto::FindSearchCandidatesResponse {
2978 buffer_ids: payload
2979 .locations
2980 .into_iter()
2981 .map(|loc| loc.buffer_id)
2982 .collect(),
2983 });
2984 }
2985
2986 let payload = session
2987 .peer
2988 .forward_request(session.connection_id, host_connection_id, request)
2989 .await?;
2990 response.send(payload)?;
2991 Ok(())
2992}
2993
2994/// forward a project request to the dev server. Only allowed
2995/// if it's your dev server.
2996async fn forward_project_request_for_owner<T>(
2997 request: T,
2998 response: Response<T>,
2999 session: UserSession,
3000) -> Result<()>
3001where
3002 T: EntityMessage + RequestMessage,
3003{
3004 let project_id = ProjectId::from_proto(request.remote_entity_id());
3005
3006 let host_connection_id = session
3007 .db()
3008 .await
3009 .host_for_owner_project_request(project_id, session.connection_id, session.user_id())
3010 .await?;
3011 let payload = session
3012 .peer
3013 .forward_request(session.connection_id, host_connection_id, request)
3014 .await?;
3015 response.send(payload)?;
3016 Ok(())
3017}
3018
3019/// forward a project request to the host. These requests are disallowed
3020/// for guests.
3021async fn forward_mutating_project_request<T>(
3022 request: T,
3023 response: Response<T>,
3024 session: UserSession,
3025) -> Result<()>
3026where
3027 T: EntityMessage + RequestMessage,
3028{
3029 let project_id = ProjectId::from_proto(request.remote_entity_id());
3030
3031 let host_connection_id = session
3032 .db()
3033 .await
3034 .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
3035 .await?;
3036 let payload = session
3037 .peer
3038 .forward_request(session.connection_id, host_connection_id, request)
3039 .await?;
3040 response.send(payload)?;
3041 Ok(())
3042}
3043
3044/// Notify other participants that a new buffer has been created
3045async fn create_buffer_for_peer(
3046 request: proto::CreateBufferForPeer,
3047 session: Session,
3048) -> Result<()> {
3049 session
3050 .db()
3051 .await
3052 .check_user_is_project_host(
3053 ProjectId::from_proto(request.project_id),
3054 session.connection_id,
3055 )
3056 .await?;
3057 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
3058 session
3059 .peer
3060 .forward_send(session.connection_id, peer_id.into(), request)?;
3061 Ok(())
3062}
3063
3064/// Notify other participants that a buffer has been updated. This is
3065/// allowed for guests as long as the update is limited to selections.
3066async fn update_buffer(
3067 request: proto::UpdateBuffer,
3068 response: Response<proto::UpdateBuffer>,
3069 session: Session,
3070) -> Result<()> {
3071 let project_id = ProjectId::from_proto(request.project_id);
3072 let mut capability = Capability::ReadOnly;
3073
3074 for op in request.operations.iter() {
3075 match op.variant {
3076 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
3077 Some(_) => capability = Capability::ReadWrite,
3078 }
3079 }
3080
3081 let host = {
3082 let guard = session
3083 .db()
3084 .await
3085 .connections_for_buffer_update(
3086 project_id,
3087 session.principal_id(),
3088 session.connection_id,
3089 capability,
3090 )
3091 .await?;
3092
3093 let (host, guests) = &*guard;
3094
3095 broadcast(
3096 Some(session.connection_id),
3097 guests.clone(),
3098 |connection_id| {
3099 session
3100 .peer
3101 .forward_send(session.connection_id, connection_id, request.clone())
3102 },
3103 );
3104
3105 *host
3106 };
3107
3108 if host != session.connection_id {
3109 session
3110 .peer
3111 .forward_request(session.connection_id, host, request.clone())
3112 .await?;
3113 }
3114
3115 response.send(proto::Ack {})?;
3116 Ok(())
3117}
3118
3119async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
3120 let project_id = ProjectId::from_proto(message.project_id);
3121
3122 let operation = message.operation.as_ref().context("invalid operation")?;
3123 let capability = match operation.variant.as_ref() {
3124 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
3125 if let Some(buffer_op) = buffer_op.operation.as_ref() {
3126 match buffer_op.variant {
3127 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
3128 Capability::ReadOnly
3129 }
3130 _ => Capability::ReadWrite,
3131 }
3132 } else {
3133 Capability::ReadWrite
3134 }
3135 }
3136 Some(_) => Capability::ReadWrite,
3137 None => Capability::ReadOnly,
3138 };
3139
3140 let guard = session
3141 .db()
3142 .await
3143 .connections_for_buffer_update(
3144 project_id,
3145 session.principal_id(),
3146 session.connection_id,
3147 capability,
3148 )
3149 .await?;
3150
3151 let (host, guests) = &*guard;
3152
3153 broadcast(
3154 Some(session.connection_id),
3155 guests.iter().chain([host]).copied(),
3156 |connection_id| {
3157 session
3158 .peer
3159 .forward_send(session.connection_id, connection_id, message.clone())
3160 },
3161 );
3162
3163 Ok(())
3164}
3165
3166/// Notify other participants that a project has been updated.
3167async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
3168 request: T,
3169 session: Session,
3170) -> Result<()> {
3171 let project_id = ProjectId::from_proto(request.remote_entity_id());
3172 let project_connection_ids = session
3173 .db()
3174 .await
3175 .project_connection_ids(project_id, session.connection_id, false)
3176 .await?;
3177
3178 broadcast(
3179 Some(session.connection_id),
3180 project_connection_ids.iter().copied(),
3181 |connection_id| {
3182 session
3183 .peer
3184 .forward_send(session.connection_id, connection_id, request.clone())
3185 },
3186 );
3187 Ok(())
3188}
3189
3190/// Start following another user in a call.
3191async fn follow(
3192 request: proto::Follow,
3193 response: Response<proto::Follow>,
3194 session: UserSession,
3195) -> Result<()> {
3196 let room_id = RoomId::from_proto(request.room_id);
3197 let project_id = request.project_id.map(ProjectId::from_proto);
3198 let leader_id = request
3199 .leader_id
3200 .ok_or_else(|| anyhow!("invalid leader id"))?
3201 .into();
3202 let follower_id = session.connection_id;
3203
3204 session
3205 .db()
3206 .await
3207 .check_room_participants(room_id, leader_id, session.connection_id)
3208 .await?;
3209
3210 let response_payload = session
3211 .peer
3212 .forward_request(session.connection_id, leader_id, request)
3213 .await?;
3214 response.send(response_payload)?;
3215
3216 if let Some(project_id) = project_id {
3217 let room = session
3218 .db()
3219 .await
3220 .follow(room_id, project_id, leader_id, follower_id)
3221 .await?;
3222 room_updated(&room, &session.peer);
3223 }
3224
3225 Ok(())
3226}
3227
3228/// Stop following another user in a call.
3229async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
3230 let room_id = RoomId::from_proto(request.room_id);
3231 let project_id = request.project_id.map(ProjectId::from_proto);
3232 let leader_id = request
3233 .leader_id
3234 .ok_or_else(|| anyhow!("invalid leader id"))?
3235 .into();
3236 let follower_id = session.connection_id;
3237
3238 session
3239 .db()
3240 .await
3241 .check_room_participants(room_id, leader_id, session.connection_id)
3242 .await?;
3243
3244 session
3245 .peer
3246 .forward_send(session.connection_id, leader_id, request)?;
3247
3248 if let Some(project_id) = project_id {
3249 let room = session
3250 .db()
3251 .await
3252 .unfollow(room_id, project_id, leader_id, follower_id)
3253 .await?;
3254 room_updated(&room, &session.peer);
3255 }
3256
3257 Ok(())
3258}
3259
3260/// Notify everyone following you of your current location.
3261async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
3262 let room_id = RoomId::from_proto(request.room_id);
3263 let database = session.db.lock().await;
3264
3265 let connection_ids = if let Some(project_id) = request.project_id {
3266 let project_id = ProjectId::from_proto(project_id);
3267 database
3268 .project_connection_ids(project_id, session.connection_id, true)
3269 .await?
3270 } else {
3271 database
3272 .room_connection_ids(room_id, session.connection_id)
3273 .await?
3274 };
3275
3276 // For now, don't send view update messages back to that view's current leader.
3277 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
3278 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
3279 _ => None,
3280 });
3281
3282 for connection_id in connection_ids.iter().cloned() {
3283 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
3284 session
3285 .peer
3286 .forward_send(session.connection_id, connection_id, request.clone())?;
3287 }
3288 }
3289 Ok(())
3290}
3291
3292/// Get public data about users.
3293async fn get_users(
3294 request: proto::GetUsers,
3295 response: Response<proto::GetUsers>,
3296 session: Session,
3297) -> Result<()> {
3298 let user_ids = request
3299 .user_ids
3300 .into_iter()
3301 .map(UserId::from_proto)
3302 .collect();
3303 let users = session
3304 .db()
3305 .await
3306 .get_users_by_ids(user_ids)
3307 .await?
3308 .into_iter()
3309 .map(|user| proto::User {
3310 id: user.id.to_proto(),
3311 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3312 github_login: user.github_login,
3313 })
3314 .collect();
3315 response.send(proto::UsersResponse { users })?;
3316 Ok(())
3317}
3318
3319/// Search for users (to invite) buy Github login
3320async fn fuzzy_search_users(
3321 request: proto::FuzzySearchUsers,
3322 response: Response<proto::FuzzySearchUsers>,
3323 session: UserSession,
3324) -> Result<()> {
3325 let query = request.query;
3326 let users = match query.len() {
3327 0 => vec![],
3328 1 | 2 => session
3329 .db()
3330 .await
3331 .get_user_by_github_login(&query)
3332 .await?
3333 .into_iter()
3334 .collect(),
3335 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
3336 };
3337 let users = users
3338 .into_iter()
3339 .filter(|user| user.id != session.user_id())
3340 .map(|user| proto::User {
3341 id: user.id.to_proto(),
3342 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3343 github_login: user.github_login,
3344 })
3345 .collect();
3346 response.send(proto::UsersResponse { users })?;
3347 Ok(())
3348}
3349
3350/// Send a contact request to another user.
3351async fn request_contact(
3352 request: proto::RequestContact,
3353 response: Response<proto::RequestContact>,
3354 session: UserSession,
3355) -> Result<()> {
3356 let requester_id = session.user_id();
3357 let responder_id = UserId::from_proto(request.responder_id);
3358 if requester_id == responder_id {
3359 return Err(anyhow!("cannot add yourself as a contact"))?;
3360 }
3361
3362 let notifications = session
3363 .db()
3364 .await
3365 .send_contact_request(requester_id, responder_id)
3366 .await?;
3367
3368 // Update outgoing contact requests of requester
3369 let mut update = proto::UpdateContacts::default();
3370 update.outgoing_requests.push(responder_id.to_proto());
3371 for connection_id in session
3372 .connection_pool()
3373 .await
3374 .user_connection_ids(requester_id)
3375 {
3376 session.peer.send(connection_id, update.clone())?;
3377 }
3378
3379 // Update incoming contact requests of responder
3380 let mut update = proto::UpdateContacts::default();
3381 update
3382 .incoming_requests
3383 .push(proto::IncomingContactRequest {
3384 requester_id: requester_id.to_proto(),
3385 });
3386 let connection_pool = session.connection_pool().await;
3387 for connection_id in connection_pool.user_connection_ids(responder_id) {
3388 session.peer.send(connection_id, update.clone())?;
3389 }
3390
3391 send_notifications(&connection_pool, &session.peer, notifications);
3392
3393 response.send(proto::Ack {})?;
3394 Ok(())
3395}
3396
3397/// Accept or decline a contact request
3398async fn respond_to_contact_request(
3399 request: proto::RespondToContactRequest,
3400 response: Response<proto::RespondToContactRequest>,
3401 session: UserSession,
3402) -> Result<()> {
3403 let responder_id = session.user_id();
3404 let requester_id = UserId::from_proto(request.requester_id);
3405 let db = session.db().await;
3406 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
3407 db.dismiss_contact_notification(responder_id, requester_id)
3408 .await?;
3409 } else {
3410 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
3411
3412 let notifications = db
3413 .respond_to_contact_request(responder_id, requester_id, accept)
3414 .await?;
3415 let requester_busy = db.is_user_busy(requester_id).await?;
3416 let responder_busy = db.is_user_busy(responder_id).await?;
3417
3418 let pool = session.connection_pool().await;
3419 // Update responder with new contact
3420 let mut update = proto::UpdateContacts::default();
3421 if accept {
3422 update
3423 .contacts
3424 .push(contact_for_user(requester_id, requester_busy, &pool));
3425 }
3426 update
3427 .remove_incoming_requests
3428 .push(requester_id.to_proto());
3429 for connection_id in pool.user_connection_ids(responder_id) {
3430 session.peer.send(connection_id, update.clone())?;
3431 }
3432
3433 // Update requester with new contact
3434 let mut update = proto::UpdateContacts::default();
3435 if accept {
3436 update
3437 .contacts
3438 .push(contact_for_user(responder_id, responder_busy, &pool));
3439 }
3440 update
3441 .remove_outgoing_requests
3442 .push(responder_id.to_proto());
3443
3444 for connection_id in pool.user_connection_ids(requester_id) {
3445 session.peer.send(connection_id, update.clone())?;
3446 }
3447
3448 send_notifications(&pool, &session.peer, notifications);
3449 }
3450
3451 response.send(proto::Ack {})?;
3452 Ok(())
3453}
3454
3455/// Remove a contact.
3456async fn remove_contact(
3457 request: proto::RemoveContact,
3458 response: Response<proto::RemoveContact>,
3459 session: UserSession,
3460) -> Result<()> {
3461 let requester_id = session.user_id();
3462 let responder_id = UserId::from_proto(request.user_id);
3463 let db = session.db().await;
3464 let (contact_accepted, deleted_notification_id) =
3465 db.remove_contact(requester_id, responder_id).await?;
3466
3467 let pool = session.connection_pool().await;
3468 // Update outgoing contact requests of requester
3469 let mut update = proto::UpdateContacts::default();
3470 if contact_accepted {
3471 update.remove_contacts.push(responder_id.to_proto());
3472 } else {
3473 update
3474 .remove_outgoing_requests
3475 .push(responder_id.to_proto());
3476 }
3477 for connection_id in pool.user_connection_ids(requester_id) {
3478 session.peer.send(connection_id, update.clone())?;
3479 }
3480
3481 // Update incoming contact requests of responder
3482 let mut update = proto::UpdateContacts::default();
3483 if contact_accepted {
3484 update.remove_contacts.push(requester_id.to_proto());
3485 } else {
3486 update
3487 .remove_incoming_requests
3488 .push(requester_id.to_proto());
3489 }
3490 for connection_id in pool.user_connection_ids(responder_id) {
3491 session.peer.send(connection_id, update.clone())?;
3492 if let Some(notification_id) = deleted_notification_id {
3493 session.peer.send(
3494 connection_id,
3495 proto::DeleteNotification {
3496 notification_id: notification_id.to_proto(),
3497 },
3498 )?;
3499 }
3500 }
3501
3502 response.send(proto::Ack {})?;
3503 Ok(())
3504}
3505
3506fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
3507 version.0.minor() < 139
3508}
3509
3510async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
3511 let plan = session.current_plan(session.db().await).await?;
3512
3513 session
3514 .peer
3515 .send(
3516 session.connection_id,
3517 proto::UpdateUserPlan { plan: plan.into() },
3518 )
3519 .trace_err();
3520
3521 Ok(())
3522}
3523
3524async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
3525 subscribe_user_to_channels(
3526 session.user_id().ok_or_else(|| anyhow!("must be a user"))?,
3527 &session,
3528 )
3529 .await?;
3530 Ok(())
3531}
3532
3533async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
3534 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3535 let mut pool = session.connection_pool().await;
3536 for membership in &channels_for_user.channel_memberships {
3537 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3538 }
3539 session.peer.send(
3540 session.connection_id,
3541 build_update_user_channels(&channels_for_user),
3542 )?;
3543 session.peer.send(
3544 session.connection_id,
3545 build_channels_update(channels_for_user),
3546 )?;
3547 Ok(())
3548}
3549
3550/// Creates a new channel.
3551async fn create_channel(
3552 request: proto::CreateChannel,
3553 response: Response<proto::CreateChannel>,
3554 session: UserSession,
3555) -> Result<()> {
3556 let db = session.db().await;
3557
3558 let parent_id = request.parent_id.map(ChannelId::from_proto);
3559 let (channel, membership) = db
3560 .create_channel(&request.name, parent_id, session.user_id())
3561 .await?;
3562
3563 let root_id = channel.root_id();
3564 let channel = Channel::from_model(channel);
3565
3566 response.send(proto::CreateChannelResponse {
3567 channel: Some(channel.to_proto()),
3568 parent_id: request.parent_id,
3569 })?;
3570
3571 let mut connection_pool = session.connection_pool().await;
3572 if let Some(membership) = membership {
3573 connection_pool.subscribe_to_channel(
3574 membership.user_id,
3575 membership.channel_id,
3576 membership.role,
3577 );
3578 let update = proto::UpdateUserChannels {
3579 channel_memberships: vec![proto::ChannelMembership {
3580 channel_id: membership.channel_id.to_proto(),
3581 role: membership.role.into(),
3582 }],
3583 ..Default::default()
3584 };
3585 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3586 session.peer.send(connection_id, update.clone())?;
3587 }
3588 }
3589
3590 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3591 if !role.can_see_channel(channel.visibility) {
3592 continue;
3593 }
3594
3595 let update = proto::UpdateChannels {
3596 channels: vec![channel.to_proto()],
3597 ..Default::default()
3598 };
3599 session.peer.send(connection_id, update.clone())?;
3600 }
3601
3602 Ok(())
3603}
3604
3605/// Delete a channel
3606async fn delete_channel(
3607 request: proto::DeleteChannel,
3608 response: Response<proto::DeleteChannel>,
3609 session: UserSession,
3610) -> Result<()> {
3611 let db = session.db().await;
3612
3613 let channel_id = request.channel_id;
3614 let (root_channel, removed_channels) = db
3615 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3616 .await?;
3617 response.send(proto::Ack {})?;
3618
3619 // Notify members of removed channels
3620 let mut update = proto::UpdateChannels::default();
3621 update
3622 .delete_channels
3623 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3624
3625 let connection_pool = session.connection_pool().await;
3626 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3627 session.peer.send(connection_id, update.clone())?;
3628 }
3629
3630 Ok(())
3631}
3632
3633/// Invite someone to join a channel.
3634async fn invite_channel_member(
3635 request: proto::InviteChannelMember,
3636 response: Response<proto::InviteChannelMember>,
3637 session: UserSession,
3638) -> Result<()> {
3639 let db = session.db().await;
3640 let channel_id = ChannelId::from_proto(request.channel_id);
3641 let invitee_id = UserId::from_proto(request.user_id);
3642 let InviteMemberResult {
3643 channel,
3644 notifications,
3645 } = db
3646 .invite_channel_member(
3647 channel_id,
3648 invitee_id,
3649 session.user_id(),
3650 request.role().into(),
3651 )
3652 .await?;
3653
3654 let update = proto::UpdateChannels {
3655 channel_invitations: vec![channel.to_proto()],
3656 ..Default::default()
3657 };
3658
3659 let connection_pool = session.connection_pool().await;
3660 for connection_id in connection_pool.user_connection_ids(invitee_id) {
3661 session.peer.send(connection_id, update.clone())?;
3662 }
3663
3664 send_notifications(&connection_pool, &session.peer, notifications);
3665
3666 response.send(proto::Ack {})?;
3667 Ok(())
3668}
3669
3670/// remove someone from a channel
3671async fn remove_channel_member(
3672 request: proto::RemoveChannelMember,
3673 response: Response<proto::RemoveChannelMember>,
3674 session: UserSession,
3675) -> Result<()> {
3676 let db = session.db().await;
3677 let channel_id = ChannelId::from_proto(request.channel_id);
3678 let member_id = UserId::from_proto(request.user_id);
3679
3680 let RemoveChannelMemberResult {
3681 membership_update,
3682 notification_id,
3683 } = db
3684 .remove_channel_member(channel_id, member_id, session.user_id())
3685 .await?;
3686
3687 let mut connection_pool = session.connection_pool().await;
3688 notify_membership_updated(
3689 &mut connection_pool,
3690 membership_update,
3691 member_id,
3692 &session.peer,
3693 );
3694 for connection_id in connection_pool.user_connection_ids(member_id) {
3695 if let Some(notification_id) = notification_id {
3696 session
3697 .peer
3698 .send(
3699 connection_id,
3700 proto::DeleteNotification {
3701 notification_id: notification_id.to_proto(),
3702 },
3703 )
3704 .trace_err();
3705 }
3706 }
3707
3708 response.send(proto::Ack {})?;
3709 Ok(())
3710}
3711
/// Toggle the channel between public and private.
/// Care is taken to maintain the invariant that public channels only descend from public channels
/// (though members-only channels can appear at any point in the hierarchy).
///
/// After the database update, every user subscribed to the channel's root is
/// re-evaluated against the new visibility. Users whose role can see the
/// channel are subscribed to it and sent the channel. The others are
/// unsubscribed from it and told to delete it. The request is acknowledged
/// once all users have been updated.
async fn set_channel_visibility(
    request: proto::SetChannelVisibility,
    response: Response<proto::SetChannelVisibility>,
    session: UserSession,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let visibility = request.visibility().into();

    let channel_model = db
        .set_channel_visibility(channel_id, visibility, session.user_id())
        .await?;
    let root_id = channel_model.root_id();
    let channel = Channel::from_model(channel_model);

    let mut connection_pool = session.connection_pool().await;
    // Snapshot the (user, role) pairs first. The loop body changes the pool's
    // subscriptions, which needs a mutable borrow of the pool.
    for (user_id, role) in connection_pool
        .channel_user_ids(root_id)
        .collect::<Vec<_>>()
        .into_iter()
    {
        let update = if role.can_see_channel(channel.visibility) {
            // Still visible to this role: keep it subscribed and send the fresh channel.
            connection_pool.subscribe_to_channel(user_id, channel_id, role);
            proto::UpdateChannels {
                channels: vec![channel.to_proto()],
                ..Default::default()
            }
        } else {
            // No longer visible: drop the subscription and remove it client-side.
            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
            proto::UpdateChannels {
                delete_channels: vec![channel.id.to_proto()],
                ..Default::default()
            }
        };

        for connection_id in connection_pool.user_connection_ids(user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3758
3759/// Alter the role for a user in the channel.
3760async fn set_channel_member_role(
3761 request: proto::SetChannelMemberRole,
3762 response: Response<proto::SetChannelMemberRole>,
3763 session: UserSession,
3764) -> Result<()> {
3765 let db = session.db().await;
3766 let channel_id = ChannelId::from_proto(request.channel_id);
3767 let member_id = UserId::from_proto(request.user_id);
3768 let result = db
3769 .set_channel_member_role(
3770 channel_id,
3771 session.user_id(),
3772 member_id,
3773 request.role().into(),
3774 )
3775 .await?;
3776
3777 match result {
3778 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3779 let mut connection_pool = session.connection_pool().await;
3780 notify_membership_updated(
3781 &mut connection_pool,
3782 membership_update,
3783 member_id,
3784 &session.peer,
3785 )
3786 }
3787 db::SetMemberRoleResult::InviteUpdated(channel) => {
3788 let update = proto::UpdateChannels {
3789 channel_invitations: vec![channel.to_proto()],
3790 ..Default::default()
3791 };
3792
3793 for connection_id in session
3794 .connection_pool()
3795 .await
3796 .user_connection_ids(member_id)
3797 {
3798 session.peer.send(connection_id, update.clone())?;
3799 }
3800 }
3801 }
3802
3803 response.send(proto::Ack {})?;
3804 Ok(())
3805}
3806
3807/// Change the name of a channel
3808async fn rename_channel(
3809 request: proto::RenameChannel,
3810 response: Response<proto::RenameChannel>,
3811 session: UserSession,
3812) -> Result<()> {
3813 let db = session.db().await;
3814 let channel_id = ChannelId::from_proto(request.channel_id);
3815 let channel_model = db
3816 .rename_channel(channel_id, session.user_id(), &request.name)
3817 .await?;
3818 let root_id = channel_model.root_id();
3819 let channel = Channel::from_model(channel_model);
3820
3821 response.send(proto::RenameChannelResponse {
3822 channel: Some(channel.to_proto()),
3823 })?;
3824
3825 let connection_pool = session.connection_pool().await;
3826 let update = proto::UpdateChannels {
3827 channels: vec![channel.to_proto()],
3828 ..Default::default()
3829 };
3830 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3831 if role.can_see_channel(channel.visibility) {
3832 session.peer.send(connection_id, update.clone())?;
3833 }
3834 }
3835
3836 Ok(())
3837}
3838
3839/// Move a channel to a new parent.
3840async fn move_channel(
3841 request: proto::MoveChannel,
3842 response: Response<proto::MoveChannel>,
3843 session: UserSession,
3844) -> Result<()> {
3845 let channel_id = ChannelId::from_proto(request.channel_id);
3846 let to = ChannelId::from_proto(request.to);
3847
3848 let (root_id, channels) = session
3849 .db()
3850 .await
3851 .move_channel(channel_id, to, session.user_id())
3852 .await?;
3853
3854 let connection_pool = session.connection_pool().await;
3855 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3856 let channels = channels
3857 .iter()
3858 .filter_map(|channel| {
3859 if role.can_see_channel(channel.visibility) {
3860 Some(channel.to_proto())
3861 } else {
3862 None
3863 }
3864 })
3865 .collect::<Vec<_>>();
3866 if channels.is_empty() {
3867 continue;
3868 }
3869
3870 let update = proto::UpdateChannels {
3871 channels,
3872 ..Default::default()
3873 };
3874
3875 session.peer.send(connection_id, update.clone())?;
3876 }
3877
3878 response.send(Ack {})?;
3879 Ok(())
3880}
3881
3882/// Get the list of channel members
3883async fn get_channel_members(
3884 request: proto::GetChannelMembers,
3885 response: Response<proto::GetChannelMembers>,
3886 session: UserSession,
3887) -> Result<()> {
3888 let db = session.db().await;
3889 let channel_id = ChannelId::from_proto(request.channel_id);
3890 let limit = if request.limit == 0 {
3891 u16::MAX as u64
3892 } else {
3893 request.limit
3894 };
3895 let (members, users) = db
3896 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3897 .await?;
3898 response.send(proto::GetChannelMembersResponse { members, users })?;
3899 Ok(())
3900}
3901
3902/// Accept or decline a channel invitation.
3903async fn respond_to_channel_invite(
3904 request: proto::RespondToChannelInvite,
3905 response: Response<proto::RespondToChannelInvite>,
3906 session: UserSession,
3907) -> Result<()> {
3908 let db = session.db().await;
3909 let channel_id = ChannelId::from_proto(request.channel_id);
3910 let RespondToChannelInvite {
3911 membership_update,
3912 notifications,
3913 } = db
3914 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3915 .await?;
3916
3917 let mut connection_pool = session.connection_pool().await;
3918 if let Some(membership_update) = membership_update {
3919 notify_membership_updated(
3920 &mut connection_pool,
3921 membership_update,
3922 session.user_id(),
3923 &session.peer,
3924 );
3925 } else {
3926 let update = proto::UpdateChannels {
3927 remove_channel_invitations: vec![channel_id.to_proto()],
3928 ..Default::default()
3929 };
3930
3931 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3932 session.peer.send(connection_id, update.clone())?;
3933 }
3934 };
3935
3936 send_notifications(&connection_pool, &session.peer, notifications);
3937
3938 response.send(proto::Ack {})?;
3939
3940 Ok(())
3941}
3942
3943/// Join the channels' room
3944async fn join_channel(
3945 request: proto::JoinChannel,
3946 response: Response<proto::JoinChannel>,
3947 session: UserSession,
3948) -> Result<()> {
3949 let channel_id = ChannelId::from_proto(request.channel_id);
3950 join_channel_internal(channel_id, Box::new(response), session).await
3951}
3952
/// A response handle that `join_channel_internal` can reply through.
///
/// Implemented for both `Response<proto::JoinChannel>` and
/// `Response<proto::JoinRoom>`; each replies with a `proto::JoinRoomResponse`.
trait JoinChannelInternalResponse {
    /// Send `result` as the reply, consuming the handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
/// Delegates to `Response::<proto::JoinChannel>::send`.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
/// Delegates to `Response::<proto::JoinRoom>::send`.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3966
/// Join the room of `channel_id` for the session's user and reply through `response`.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, leave
///    that room first via `leave_room_for_session`.
/// 2. Join the channel's room. The database also returns any membership change
///    and the user's role in the channel.
/// 3. If a LiveKit client is configured, mint a token:
///    - `ChannelRole::Guest` gets a guest token with `can_publish: false`.
///    - Everyone else gets a room token with `can_publish: true`.
///    A token error is logged (`trace_err`) and leaves the connection info as `None`.
/// 4. Send the `JoinRoomResponse`, propagate any membership change, and
///    broadcast the room update.
/// 5. Once the database guard has been released, broadcast the channel update
///    and refresh the user's contacts.
///
/// Fails if the joined room carries no channel, or if any database, send, or
/// contact-update step fails.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: UserSession,
) -> Result<()> {
    // The database guard `db` lives only inside this block; steps after it run unlocked.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database guard while leaving the stale room, then re-acquire it.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured or a token could not be created.
        let live_kit_connection_info =
            session
                .app_state
                .live_kit_client
                .as_ref()
                .and_then(|live_kit| {
                    // Guests may only subscribe; every other role may also publish.
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.live_kit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.live_kit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
4062
4063/// Start editing the channel notes
4064async fn join_channel_buffer(
4065 request: proto::JoinChannelBuffer,
4066 response: Response<proto::JoinChannelBuffer>,
4067 session: UserSession,
4068) -> Result<()> {
4069 let db = session.db().await;
4070 let channel_id = ChannelId::from_proto(request.channel_id);
4071
4072 let open_response = db
4073 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
4074 .await?;
4075
4076 let collaborators = open_response.collaborators.clone();
4077 response.send(open_response)?;
4078
4079 let update = UpdateChannelBufferCollaborators {
4080 channel_id: channel_id.to_proto(),
4081 collaborators: collaborators.clone(),
4082 };
4083 channel_buffer_updated(
4084 session.connection_id,
4085 collaborators
4086 .iter()
4087 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
4088 &update,
4089 &session.peer,
4090 );
4091
4092 Ok(())
4093}
4094
4095/// Edit the channel notes
4096async fn update_channel_buffer(
4097 request: proto::UpdateChannelBuffer,
4098 session: UserSession,
4099) -> Result<()> {
4100 let db = session.db().await;
4101 let channel_id = ChannelId::from_proto(request.channel_id);
4102
4103 let (collaborators, epoch, version) = db
4104 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
4105 .await?;
4106
4107 channel_buffer_updated(
4108 session.connection_id,
4109 collaborators.clone(),
4110 &proto::UpdateChannelBuffer {
4111 channel_id: channel_id.to_proto(),
4112 operations: request.operations,
4113 },
4114 &session.peer,
4115 );
4116
4117 let pool = &*session.connection_pool().await;
4118
4119 let non_collaborators =
4120 pool.channel_connection_ids(channel_id)
4121 .filter_map(|(connection_id, _)| {
4122 if collaborators.contains(&connection_id) {
4123 None
4124 } else {
4125 Some(connection_id)
4126 }
4127 });
4128
4129 broadcast(None, non_collaborators, |peer_id| {
4130 session.peer.send(
4131 peer_id,
4132 proto::UpdateChannels {
4133 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
4134 channel_id: channel_id.to_proto(),
4135 epoch: epoch as u64,
4136 version: version.clone(),
4137 }],
4138 ..Default::default()
4139 },
4140 )
4141 });
4142
4143 Ok(())
4144}
4145
4146/// Rejoin the channel notes after a connection blip
4147async fn rejoin_channel_buffers(
4148 request: proto::RejoinChannelBuffers,
4149 response: Response<proto::RejoinChannelBuffers>,
4150 session: UserSession,
4151) -> Result<()> {
4152 let db = session.db().await;
4153 let buffers = db
4154 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
4155 .await?;
4156
4157 for rejoined_buffer in &buffers {
4158 let collaborators_to_notify = rejoined_buffer
4159 .buffer
4160 .collaborators
4161 .iter()
4162 .filter_map(|c| Some(c.peer_id?.into()));
4163 channel_buffer_updated(
4164 session.connection_id,
4165 collaborators_to_notify,
4166 &proto::UpdateChannelBufferCollaborators {
4167 channel_id: rejoined_buffer.buffer.channel_id,
4168 collaborators: rejoined_buffer.buffer.collaborators.clone(),
4169 },
4170 &session.peer,
4171 );
4172 }
4173
4174 response.send(proto::RejoinChannelBuffersResponse {
4175 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
4176 })?;
4177
4178 Ok(())
4179}
4180
4181/// Stop editing the channel notes
4182async fn leave_channel_buffer(
4183 request: proto::LeaveChannelBuffer,
4184 response: Response<proto::LeaveChannelBuffer>,
4185 session: UserSession,
4186) -> Result<()> {
4187 let db = session.db().await;
4188 let channel_id = ChannelId::from_proto(request.channel_id);
4189
4190 let left_buffer = db
4191 .leave_channel_buffer(channel_id, session.connection_id)
4192 .await?;
4193
4194 response.send(Ack {})?;
4195
4196 channel_buffer_updated(
4197 session.connection_id,
4198 left_buffer.connections,
4199 &proto::UpdateChannelBufferCollaborators {
4200 channel_id: channel_id.to_proto(),
4201 collaborators: left_buffer.collaborators,
4202 },
4203 &session.peer,
4204 );
4205
4206 Ok(())
4207}
4208
4209fn channel_buffer_updated<T: EnvelopedMessage>(
4210 sender_id: ConnectionId,
4211 collaborators: impl IntoIterator<Item = ConnectionId>,
4212 message: &T,
4213 peer: &Peer,
4214) {
4215 broadcast(Some(sender_id), collaborators, |peer_id| {
4216 peer.send(peer_id, message.clone())
4217 });
4218}
4219
4220fn send_notifications(
4221 connection_pool: &ConnectionPool,
4222 peer: &Peer,
4223 notifications: db::NotificationBatch,
4224) {
4225 for (user_id, notification) in notifications {
4226 for connection_id in connection_pool.user_connection_ids(user_id) {
4227 if let Err(error) = peer.send(
4228 connection_id,
4229 proto::AddNotification {
4230 notification: Some(notification.clone()),
4231 },
4232 ) {
4233 tracing::error!(
4234 "failed to send notification to {:?} {}",
4235 connection_id,
4236 error
4237 );
4238 }
4239 }
4240 }
4241}
4242
/// Send a message to the channel.
///
/// The body is trimmed, then rejected if it is longer than `MAX_MESSAGE_LEN`
/// bytes or empty. A nonce is required. After the message is stored:
/// 1. it is broadcast to the participant connections returned by the database
///    (passing the sender's connection to `broadcast` as the originator);
/// 2. it is returned to the sender in the response;
/// 3. connections subscribed to `channel_id` that are not participants only
///    receive the channel's latest message id;
/// 4. notifications produced by the database are delivered.
async fn send_channel_message(
    request: proto::SendChannelMessage,
    response: Response<proto::SendChannelMessage>,
    session: UserSession,
) -> Result<()> {
    // Validate the message body.
    // `String::len` counts bytes, so the limit is a byte limit.
    let body = request.body.trim().to_string();
    if body.len() > MAX_MESSAGE_LEN {
        return Err(anyhow!("message is too long"))?;
    }
    if body.is_empty() {
        return Err(anyhow!("message can't be blank"))?;
    }

    // TODO: adjust mentions if body is trimmed

    let timestamp = OffsetDateTime::now_utc();
    let nonce = request
        .nonce
        .ok_or_else(|| anyhow!("nonce can't be blank"))?;

    let channel_id = ChannelId::from_proto(request.channel_id);
    let CreatedChannelMessage {
        message_id,
        participant_connection_ids,
        notifications,
    } = session
        .db()
        .await
        .create_channel_message(
            channel_id,
            session.user_id(),
            &body,
            &request.mentions,
            timestamp,
            nonce.clone().into(),
            request.reply_to_message_id.map(MessageId::from_proto),
        )
        .await?;

    let message = proto::ChannelMessage {
        sender_id: session.user_id().to_proto(),
        id: message_id.to_proto(),
        body,
        mentions: request.mentions,
        timestamp: timestamp.unix_timestamp() as u64,
        nonce: Some(nonce),
        reply_to_message_id: request.reply_to_message_id,
        edited_at: None,
    };
    broadcast(
        Some(session.connection_id),
        participant_connection_ids.clone(),
        |connection| {
            session.peer.send(
                connection,
                proto::ChannelMessageSent {
                    channel_id: channel_id.to_proto(),
                    message: Some(message.clone()),
                },
            )
        },
    );
    response.send(proto::SendChannelMessageResponse {
        message: Some(message),
    })?;

    // Subscribers that are not participants only get the latest message id.
    let pool = &*session.connection_pool().await;
    let non_participants =
        pool.channel_connection_ids(channel_id)
            .filter_map(|(connection_id, _)| {
                if participant_connection_ids.contains(&connection_id) {
                    None
                } else {
                    Some(connection_id)
                }
            });
    broadcast(None, non_participants, |peer_id| {
        session.peer.send(
            peer_id,
            proto::UpdateChannels {
                latest_channel_message_ids: vec![proto::ChannelMessageId {
                    channel_id: channel_id.to_proto(),
                    message_id: message_id.to_proto(),
                }],
                ..Default::default()
            },
        )
    });
    send_notifications(pool, &session.peer, notifications);

    Ok(())
}
4337
4338/// Delete a channel message
4339async fn remove_channel_message(
4340 request: proto::RemoveChannelMessage,
4341 response: Response<proto::RemoveChannelMessage>,
4342 session: UserSession,
4343) -> Result<()> {
4344 let channel_id = ChannelId::from_proto(request.channel_id);
4345 let message_id = MessageId::from_proto(request.message_id);
4346 let (connection_ids, existing_notification_ids) = session
4347 .db()
4348 .await
4349 .remove_channel_message(channel_id, message_id, session.user_id())
4350 .await?;
4351
4352 broadcast(
4353 Some(session.connection_id),
4354 connection_ids,
4355 move |connection| {
4356 session.peer.send(connection, request.clone())?;
4357
4358 for notification_id in &existing_notification_ids {
4359 session.peer.send(
4360 connection,
4361 proto::DeleteNotification {
4362 notification_id: (*notification_id).to_proto(),
4363 },
4364 )?;
4365 }
4366
4367 Ok(())
4368 },
4369 );
4370 response.send(proto::Ack {})?;
4371 Ok(())
4372}
4373
4374async fn update_channel_message(
4375 request: proto::UpdateChannelMessage,
4376 response: Response<proto::UpdateChannelMessage>,
4377 session: UserSession,
4378) -> Result<()> {
4379 let channel_id = ChannelId::from_proto(request.channel_id);
4380 let message_id = MessageId::from_proto(request.message_id);
4381 let updated_at = OffsetDateTime::now_utc();
4382 let UpdatedChannelMessage {
4383 message_id,
4384 participant_connection_ids,
4385 notifications,
4386 reply_to_message_id,
4387 timestamp,
4388 deleted_mention_notification_ids,
4389 updated_mention_notifications,
4390 } = session
4391 .db()
4392 .await
4393 .update_channel_message(
4394 channel_id,
4395 message_id,
4396 session.user_id(),
4397 request.body.as_str(),
4398 &request.mentions,
4399 updated_at,
4400 )
4401 .await?;
4402
4403 let nonce = request
4404 .nonce
4405 .clone()
4406 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4407
4408 let message = proto::ChannelMessage {
4409 sender_id: session.user_id().to_proto(),
4410 id: message_id.to_proto(),
4411 body: request.body.clone(),
4412 mentions: request.mentions.clone(),
4413 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
4414 nonce: Some(nonce),
4415 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
4416 edited_at: Some(updated_at.unix_timestamp() as u64),
4417 };
4418
4419 response.send(proto::Ack {})?;
4420
4421 let pool = &*session.connection_pool().await;
4422 broadcast(
4423 Some(session.connection_id),
4424 participant_connection_ids,
4425 |connection| {
4426 session.peer.send(
4427 connection,
4428 proto::ChannelMessageUpdate {
4429 channel_id: channel_id.to_proto(),
4430 message: Some(message.clone()),
4431 },
4432 )?;
4433
4434 for notification_id in &deleted_mention_notification_ids {
4435 session.peer.send(
4436 connection,
4437 proto::DeleteNotification {
4438 notification_id: (*notification_id).to_proto(),
4439 },
4440 )?;
4441 }
4442
4443 for notification in &updated_mention_notifications {
4444 session.peer.send(
4445 connection,
4446 proto::UpdateNotification {
4447 notification: Some(notification.clone()),
4448 },
4449 )?;
4450 }
4451
4452 Ok(())
4453 },
4454 );
4455
4456 send_notifications(pool, &session.peer, notifications);
4457
4458 Ok(())
4459}
4460
4461/// Mark a channel message as read
4462async fn acknowledge_channel_message(
4463 request: proto::AckChannelMessage,
4464 session: UserSession,
4465) -> Result<()> {
4466 let channel_id = ChannelId::from_proto(request.channel_id);
4467 let message_id = MessageId::from_proto(request.message_id);
4468 let notifications = session
4469 .db()
4470 .await
4471 .observe_channel_message(channel_id, session.user_id(), message_id)
4472 .await?;
4473 send_notifications(
4474 &*session.connection_pool().await,
4475 &session.peer,
4476 notifications,
4477 );
4478 Ok(())
4479}
4480
4481/// Mark a buffer version as synced
4482async fn acknowledge_buffer_version(
4483 request: proto::AckBufferOperation,
4484 session: UserSession,
4485) -> Result<()> {
4486 let buffer_id = BufferId::from_proto(request.buffer_id);
4487 session
4488 .db()
4489 .await
4490 .observe_buffer_version(
4491 buffer_id,
4492 session.user_id(),
4493 request.epoch as i32,
4494 &request.version,
4495 )
4496 .await?;
4497 Ok(())
4498}
4499
4500async fn count_language_model_tokens(
4501 request: proto::CountLanguageModelTokens,
4502 response: Response<proto::CountLanguageModelTokens>,
4503 session: Session,
4504 config: &Config,
4505) -> Result<()> {
4506 let Some(session) = session.for_user() else {
4507 return Err(anyhow!("user not found"))?;
4508 };
4509 authorize_access_to_legacy_llm_endpoints(&session).await?;
4510
4511 let rate_limit: Box<dyn RateLimit> = match session.current_plan(session.db().await).await? {
4512 proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
4513 proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
4514 };
4515
4516 session
4517 .app_state
4518 .rate_limiter
4519 .check(&*rate_limit, session.user_id())
4520 .await?;
4521
4522 let result = match proto::LanguageModelProvider::from_i32(request.provider) {
4523 Some(proto::LanguageModelProvider::Google) => {
4524 let api_key = config
4525 .google_ai_api_key
4526 .as_ref()
4527 .context("no Google AI API key configured on the server")?;
4528 google_ai::count_tokens(
4529 session.http_client.as_ref(),
4530 google_ai::API_URL,
4531 api_key,
4532 serde_json::from_str(&request.request)?,
4533 None,
4534 )
4535 .await?
4536 }
4537 _ => return Err(anyhow!("unsupported provider"))?,
4538 };
4539
4540 response.send(proto::CountLanguageModelTokensResponse {
4541 token_count: result.total_tokens as u32,
4542 })?;
4543
4544 Ok(())
4545}
4546
4547struct ZedProCountLanguageModelTokensRateLimit;
4548
4549impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
4550 fn capacity(&self) -> usize {
4551 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
4552 .ok()
4553 .and_then(|v| v.parse().ok())
4554 .unwrap_or(600) // Picked arbitrarily
4555 }
4556
4557 fn refill_duration(&self) -> chrono::Duration {
4558 chrono::Duration::hours(1)
4559 }
4560
4561 fn db_name(&self) -> &'static str {
4562 "zed-pro:count-language-model-tokens"
4563 }
4564}
4565
4566struct FreeCountLanguageModelTokensRateLimit;
4567
4568impl RateLimit for FreeCountLanguageModelTokensRateLimit {
4569 fn capacity(&self) -> usize {
4570 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
4571 .ok()
4572 .and_then(|v| v.parse().ok())
4573 .unwrap_or(600 / 10) // Picked arbitrarily
4574 }
4575
4576 fn refill_duration(&self) -> chrono::Duration {
4577 chrono::Duration::hours(1)
4578 }
4579
4580 fn db_name(&self) -> &'static str {
4581 "free:count-language-model-tokens"
4582 }
4583}
4584
4585struct ZedProComputeEmbeddingsRateLimit;
4586
4587impl RateLimit for ZedProComputeEmbeddingsRateLimit {
4588 fn capacity(&self) -> usize {
4589 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4590 .ok()
4591 .and_then(|v| v.parse().ok())
4592 .unwrap_or(5000) // Picked arbitrarily
4593 }
4594
4595 fn refill_duration(&self) -> chrono::Duration {
4596 chrono::Duration::hours(1)
4597 }
4598
4599 fn db_name(&self) -> &'static str {
4600 "zed-pro:compute-embeddings"
4601 }
4602}
4603
4604struct FreeComputeEmbeddingsRateLimit;
4605
4606impl RateLimit for FreeComputeEmbeddingsRateLimit {
4607 fn capacity(&self) -> usize {
4608 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
4609 .ok()
4610 .and_then(|v| v.parse().ok())
4611 .unwrap_or(5000 / 10) // Picked arbitrarily
4612 }
4613
4614 fn refill_duration(&self) -> chrono::Duration {
4615 chrono::Duration::hours(1)
4616 }
4617
4618 fn db_name(&self) -> &'static str {
4619 "free:compute-embeddings"
4620 }
4621}
4622
4623async fn compute_embeddings(
4624 request: proto::ComputeEmbeddings,
4625 response: Response<proto::ComputeEmbeddings>,
4626 session: UserSession,
4627 api_key: Option<Arc<str>>,
4628) -> Result<()> {
4629 let api_key = api_key.context("no OpenAI API key configured on the server")?;
4630 authorize_access_to_legacy_llm_endpoints(&session).await?;
4631
4632 let rate_limit: Box<dyn RateLimit> = match session.current_plan(session.db().await).await? {
4633 proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
4634 proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
4635 };
4636
4637 session
4638 .app_state
4639 .rate_limiter
4640 .check(&*rate_limit, session.user_id())
4641 .await?;
4642
4643 let embeddings = match request.model.as_str() {
4644 "openai/text-embedding-3-small" => {
4645 open_ai::embed(
4646 session.http_client.as_ref(),
4647 OPEN_AI_API_URL,
4648 &api_key,
4649 OpenAiEmbeddingModel::TextEmbedding3Small,
4650 request.texts.iter().map(|text| text.as_str()),
4651 )
4652 .await?
4653 }
4654 provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
4655 };
4656
4657 let embeddings = request
4658 .texts
4659 .iter()
4660 .map(|text| {
4661 let mut hasher = sha2::Sha256::new();
4662 hasher.update(text.as_bytes());
4663 let result = hasher.finalize();
4664 result.to_vec()
4665 })
4666 .zip(
4667 embeddings
4668 .data
4669 .into_iter()
4670 .map(|embedding| embedding.embedding),
4671 )
4672 .collect::<HashMap<_, _>>();
4673
4674 let db = session.db().await;
4675 db.save_embeddings(&request.model, &embeddings)
4676 .await
4677 .context("failed to save embeddings")
4678 .trace_err();
4679
4680 response.send(proto::ComputeEmbeddingsResponse {
4681 embeddings: embeddings
4682 .into_iter()
4683 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4684 .collect(),
4685 })?;
4686 Ok(())
4687}
4688
4689async fn get_cached_embeddings(
4690 request: proto::GetCachedEmbeddings,
4691 response: Response<proto::GetCachedEmbeddings>,
4692 session: UserSession,
4693) -> Result<()> {
4694 authorize_access_to_legacy_llm_endpoints(&session).await?;
4695
4696 let db = session.db().await;
4697 let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
4698
4699 response.send(proto::GetCachedEmbeddingsResponse {
4700 embeddings: embeddings
4701 .into_iter()
4702 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4703 .collect(),
4704 })?;
4705 Ok(())
4706}
4707
4708/// This is leftover from before the LLM service.
4709///
4710/// The endpoints protected by this check will be moved there eventually.
4711async fn authorize_access_to_legacy_llm_endpoints(session: &UserSession) -> Result<(), Error> {
4712 if session.is_staff() {
4713 Ok(())
4714 } else {
4715 Err(anyhow!("permission denied"))?
4716 }
4717}
4718
4719/// Get a Supermaven API key for the user
4720async fn get_supermaven_api_key(
4721 _request: proto::GetSupermavenApiKey,
4722 response: Response<proto::GetSupermavenApiKey>,
4723 session: UserSession,
4724) -> Result<()> {
4725 let user_id: String = session.user_id().to_string();
4726 if !session.is_staff() {
4727 return Err(anyhow!("supermaven not enabled for this account"))?;
4728 }
4729
4730 let email = session
4731 .email()
4732 .ok_or_else(|| anyhow!("user must have an email"))?;
4733
4734 let supermaven_admin_api = session
4735 .supermaven_client
4736 .as_ref()
4737 .ok_or_else(|| anyhow!("supermaven not configured"))?;
4738
4739 let result = supermaven_admin_api
4740 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4741 .await?;
4742
4743 response.send(proto::GetSupermavenApiKeyResponse {
4744 api_key: result.api_key,
4745 })?;
4746
4747 Ok(())
4748}
4749
4750/// Start receiving chat updates for a channel
4751async fn join_channel_chat(
4752 request: proto::JoinChannelChat,
4753 response: Response<proto::JoinChannelChat>,
4754 session: UserSession,
4755) -> Result<()> {
4756 let channel_id = ChannelId::from_proto(request.channel_id);
4757
4758 let db = session.db().await;
4759 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4760 .await?;
4761 let messages = db
4762 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4763 .await?;
4764 response.send(proto::JoinChannelChatResponse {
4765 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4766 messages,
4767 })?;
4768 Ok(())
4769}
4770
4771/// Stop receiving chat updates for a channel
4772async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
4773 let channel_id = ChannelId::from_proto(request.channel_id);
4774 session
4775 .db()
4776 .await
4777 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4778 .await?;
4779 Ok(())
4780}
4781
4782/// Retrieve the chat history for a channel
4783async fn get_channel_messages(
4784 request: proto::GetChannelMessages,
4785 response: Response<proto::GetChannelMessages>,
4786 session: UserSession,
4787) -> Result<()> {
4788 let channel_id = ChannelId::from_proto(request.channel_id);
4789 let messages = session
4790 .db()
4791 .await
4792 .get_channel_messages(
4793 channel_id,
4794 session.user_id(),
4795 MESSAGE_COUNT_PER_PAGE,
4796 Some(MessageId::from_proto(request.before_message_id)),
4797 )
4798 .await?;
4799 response.send(proto::GetChannelMessagesResponse {
4800 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4801 messages,
4802 })?;
4803 Ok(())
4804}
4805
4806/// Retrieve specific chat messages
4807async fn get_channel_messages_by_id(
4808 request: proto::GetChannelMessagesById,
4809 response: Response<proto::GetChannelMessagesById>,
4810 session: UserSession,
4811) -> Result<()> {
4812 let message_ids = request
4813 .message_ids
4814 .iter()
4815 .map(|id| MessageId::from_proto(*id))
4816 .collect::<Vec<_>>();
4817 let messages = session
4818 .db()
4819 .await
4820 .get_channel_messages_by_id(session.user_id(), &message_ids)
4821 .await?;
4822 response.send(proto::GetChannelMessagesResponse {
4823 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4824 messages,
4825 })?;
4826 Ok(())
4827}
4828
4829/// Retrieve the current users notifications
4830async fn get_notifications(
4831 request: proto::GetNotifications,
4832 response: Response<proto::GetNotifications>,
4833 session: UserSession,
4834) -> Result<()> {
4835 let notifications = session
4836 .db()
4837 .await
4838 .get_notifications(
4839 session.user_id(),
4840 NOTIFICATION_COUNT_PER_PAGE,
4841 request.before_id.map(db::NotificationId::from_proto),
4842 )
4843 .await?;
4844 response.send(proto::GetNotificationsResponse {
4845 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4846 notifications,
4847 })?;
4848 Ok(())
4849}
4850
4851/// Mark notifications as read
4852async fn mark_notification_as_read(
4853 request: proto::MarkNotificationRead,
4854 response: Response<proto::MarkNotificationRead>,
4855 session: UserSession,
4856) -> Result<()> {
4857 let database = &session.db().await;
4858 let notifications = database
4859 .mark_notification_as_read_by_id(
4860 session.user_id(),
4861 NotificationId::from_proto(request.notification_id),
4862 )
4863 .await?;
4864 send_notifications(
4865 &*session.connection_pool().await,
4866 &session.peer,
4867 notifications,
4868 );
4869 response.send(proto::Ack {})?;
4870 Ok(())
4871}
4872
4873/// Get the current users information
4874async fn get_private_user_info(
4875 _request: proto::GetPrivateUserInfo,
4876 response: Response<proto::GetPrivateUserInfo>,
4877 session: UserSession,
4878) -> Result<()> {
4879 let db = session.db().await;
4880
4881 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4882 let user = db
4883 .get_user_by_id(session.user_id())
4884 .await?
4885 .ok_or_else(|| anyhow!("user not found"))?;
4886 let flags = db.get_user_flags(session.user_id()).await?;
4887
4888 response.send(proto::GetPrivateUserInfoResponse {
4889 metrics_id,
4890 staff: user.admin,
4891 flags,
4892 accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4893 })?;
4894 Ok(())
4895}
4896
4897/// Accept the terms of service (tos) on behalf of the current user
4898async fn accept_terms_of_service(
4899 _request: proto::AcceptTermsOfService,
4900 response: Response<proto::AcceptTermsOfService>,
4901 session: UserSession,
4902) -> Result<()> {
4903 let db = session.db().await;
4904
4905 let accepted_tos_at = Utc::now();
4906 db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4907 .await?;
4908
4909 response.send(proto::AcceptTermsOfServiceResponse {
4910 accepted_tos_at: accepted_tos_at.timestamp() as u64,
4911 })?;
4912 Ok(())
4913}
4914
/// The minimum account age an account must have in order to use the LLM service.
///
/// `get_llm_api_token` compares this against the time elapsed since the earlier of the user's
/// account creation time and (when known) their GitHub account creation time.
const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4917
4918async fn get_llm_api_token(
4919 _request: proto::GetLlmToken,
4920 response: Response<proto::GetLlmToken>,
4921 session: UserSession,
4922) -> Result<()> {
4923 let db = session.db().await;
4924
4925 let flags = db.get_user_flags(session.user_id()).await?;
4926 let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
4927 let has_llm_closed_beta_feature_flag = flags.iter().any(|flag| flag == "llm-closed-beta");
4928
4929 if !session.is_staff() && !has_language_models_feature_flag {
4930 Err(anyhow!("permission denied"))?
4931 }
4932
4933 let user_id = session.user_id();
4934 let user = db
4935 .get_user_by_id(user_id)
4936 .await?
4937 .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4938
4939 if user.accepted_tos_at.is_none() {
4940 Err(anyhow!("terms of service not accepted"))?
4941 }
4942
4943 let mut account_created_at = user.created_at;
4944 if let Some(github_created_at) = user.github_user_created_at {
4945 account_created_at = account_created_at.min(github_created_at);
4946 }
4947 if Utc::now().naive_utc() - account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE {
4948 Err(anyhow!("account too young"))?
4949 }
4950 let token = LlmTokenClaims::create(
4951 user.id,
4952 user.github_login.clone(),
4953 session.is_staff(),
4954 has_llm_closed_beta_feature_flag,
4955 session.current_plan(db).await?,
4956 &session.app_state.config,
4957 )?;
4958 response.send(proto::GetLlmTokenResponse { token })?;
4959 Ok(())
4960}
4961
4962fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4963 let message = match message {
4964 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4965 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4966 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4967 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4968 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4969 code: frame.code.into(),
4970 reason: frame.reason,
4971 })),
4972 // We should never receive a frame while reading the message, according
4973 // to the `tungstenite` maintainers:
4974 //
4975 // > It cannot occur when you read messages from the WebSocket, but it
4976 // > can be used when you want to send the raw frames (e.g. you want to
4977 // > send the frames to the WebSocket without composing the full message first).
4978 // >
4979 // > — https://github.com/snapview/tungstenite-rs/issues/268
4980 TungsteniteMessage::Frame(_) => {
4981 bail!("received an unexpected frame while reading the message")
4982 }
4983 };
4984
4985 Ok(message)
4986}
4987
4988fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4989 match message {
4990 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4991 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4992 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4993 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4994 AxumMessage::Close(frame) => {
4995 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4996 code: frame.code.into(),
4997 reason: frame.reason,
4998 }))
4999 }
5000 }
5001}
5002
5003fn notify_membership_updated(
5004 connection_pool: &mut ConnectionPool,
5005 result: MembershipUpdated,
5006 user_id: UserId,
5007 peer: &Peer,
5008) {
5009 for membership in &result.new_channels.channel_memberships {
5010 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
5011 }
5012 for channel_id in &result.removed_channels {
5013 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
5014 }
5015
5016 let user_channels_update = proto::UpdateUserChannels {
5017 channel_memberships: result
5018 .new_channels
5019 .channel_memberships
5020 .iter()
5021 .map(|cm| proto::ChannelMembership {
5022 channel_id: cm.channel_id.to_proto(),
5023 role: cm.role.into(),
5024 })
5025 .collect(),
5026 ..Default::default()
5027 };
5028
5029 let mut update = build_channels_update(result.new_channels);
5030 update.delete_channels = result
5031 .removed_channels
5032 .into_iter()
5033 .map(|id| id.to_proto())
5034 .collect();
5035 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
5036
5037 for connection_id in connection_pool.user_connection_ids(user_id) {
5038 peer.send(connection_id, user_channels_update.clone())
5039 .trace_err();
5040 peer.send(connection_id, update.clone()).trace_err();
5041 }
5042}
5043
5044fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
5045 proto::UpdateUserChannels {
5046 channel_memberships: channels
5047 .channel_memberships
5048 .iter()
5049 .map(|m| proto::ChannelMembership {
5050 channel_id: m.channel_id.to_proto(),
5051 role: m.role.into(),
5052 })
5053 .collect(),
5054 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
5055 observed_channel_message_id: channels.observed_channel_messages.clone(),
5056 }
5057}
5058
5059fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
5060 let mut update = proto::UpdateChannels::default();
5061
5062 for channel in channels.channels {
5063 update.channels.push(channel.to_proto());
5064 }
5065
5066 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
5067 update.latest_channel_message_ids = channels.latest_channel_messages;
5068
5069 for (channel_id, participants) in channels.channel_participants {
5070 update
5071 .channel_participants
5072 .push(proto::ChannelParticipants {
5073 channel_id: channel_id.to_proto(),
5074 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
5075 });
5076 }
5077
5078 for channel in channels.invited_channels {
5079 update.channel_invitations.push(channel.to_proto());
5080 }
5081
5082 update.hosted_projects = channels.hosted_projects;
5083 update
5084}
5085
5086fn build_initial_contacts_update(
5087 contacts: Vec<db::Contact>,
5088 pool: &ConnectionPool,
5089) -> proto::UpdateContacts {
5090 let mut update = proto::UpdateContacts::default();
5091
5092 for contact in contacts {
5093 match contact {
5094 db::Contact::Accepted { user_id, busy } => {
5095 update.contacts.push(contact_for_user(user_id, busy, pool));
5096 }
5097 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
5098 db::Contact::Incoming { user_id } => {
5099 update
5100 .incoming_requests
5101 .push(proto::IncomingContactRequest {
5102 requester_id: user_id.to_proto(),
5103 })
5104 }
5105 }
5106 }
5107
5108 update
5109}
5110
5111fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
5112 proto::Contact {
5113 user_id: user_id.to_proto(),
5114 online: pool.is_user_online(user_id),
5115 busy,
5116 }
5117}
5118
5119fn room_updated(room: &proto::Room, peer: &Peer) {
5120 broadcast(
5121 None,
5122 room.participants
5123 .iter()
5124 .filter_map(|participant| Some(participant.peer_id?.into())),
5125 |peer_id| {
5126 peer.send(
5127 peer_id,
5128 proto::RoomUpdated {
5129 room: Some(room.clone()),
5130 },
5131 )
5132 },
5133 );
5134}
5135
5136fn channel_updated(
5137 channel: &db::channel::Model,
5138 room: &proto::Room,
5139 peer: &Peer,
5140 pool: &ConnectionPool,
5141) {
5142 let participants = room
5143 .participants
5144 .iter()
5145 .map(|p| p.user_id)
5146 .collect::<Vec<_>>();
5147
5148 broadcast(
5149 None,
5150 pool.channel_connection_ids(channel.root_id())
5151 .filter_map(|(channel_id, role)| {
5152 role.can_see_channel(channel.visibility)
5153 .then_some(channel_id)
5154 }),
5155 |peer_id| {
5156 peer.send(
5157 peer_id,
5158 proto::UpdateChannels {
5159 channel_participants: vec![proto::ChannelParticipants {
5160 channel_id: channel.id.to_proto(),
5161 participant_user_ids: participants.clone(),
5162 }],
5163 ..Default::default()
5164 },
5165 )
5166 },
5167 );
5168}
5169
5170async fn send_dev_server_projects_update(
5171 user_id: UserId,
5172 mut status: proto::DevServerProjectsUpdate,
5173 session: &Session,
5174) {
5175 let pool = session.connection_pool().await;
5176 for dev_server in &mut status.dev_servers {
5177 dev_server.status =
5178 pool.dev_server_status(DevServerId(dev_server.dev_server_id as i32)) as i32;
5179 }
5180 let connections = pool.user_connection_ids(user_id);
5181 for connection_id in connections {
5182 session.peer.send(connection_id, status.clone()).trace_err();
5183 }
5184}
5185
5186async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
5187 let db = session.db().await;
5188
5189 let contacts = db.get_contacts(user_id).await?;
5190 let busy = db.is_user_busy(user_id).await?;
5191
5192 let pool = session.connection_pool().await;
5193 let updated_contact = contact_for_user(user_id, busy, &pool);
5194 for contact in contacts {
5195 if let db::Contact::Accepted {
5196 user_id: contact_user_id,
5197 ..
5198 } = contact
5199 {
5200 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
5201 session
5202 .peer
5203 .send(
5204 contact_conn_id,
5205 proto::UpdateContacts {
5206 contacts: vec![updated_contact.clone()],
5207 remove_contacts: Default::default(),
5208 incoming_requests: Default::default(),
5209 remove_incoming_requests: Default::default(),
5210 outgoing_requests: Default::default(),
5211 remove_outgoing_requests: Default::default(),
5212 },
5213 )
5214 .trace_err();
5215 }
5216 }
5217 }
5218 Ok(())
5219}
5220
5221async fn lost_dev_server_connection(session: &DevServerSession) -> Result<()> {
5222 log::info!("lost dev server connection, unsharing projects");
5223 let project_ids = session
5224 .db()
5225 .await
5226 .get_stale_dev_server_projects(session.connection_id)
5227 .await?;
5228
5229 for project_id in project_ids {
5230 // not unshare re-checks the connection ids match, so we get away with no transaction
5231 unshare_project_internal(project_id, session.connection_id, None, session).await?;
5232 }
5233
5234 let user_id = session.dev_server().user_id;
5235 let update = session
5236 .db()
5237 .await
5238 .dev_server_projects_update(user_id)
5239 .await?;
5240
5241 send_dev_server_projects_update(user_id, update, session).await;
5242
5243 Ok(())
5244}
5245
/// Remove `connection_id` from the room it is in, if any, and propagate the consequences.
///
/// Returns `Ok(())` without doing anything when the database reports no room was left.
/// Otherwise it:
/// - notifies collaborators of each project the connection left (via `project_left`);
/// - broadcasts the updated room to its participants;
/// - if the room belongs to a channel, broadcasts the channel's new participant list;
/// - sends `CallCanceled` to every connection of each user in `canceled_calls_to_user_ids`;
/// - refreshes contact presence for the current user and those users;
/// - removes the current user from the LiveKit room, and deletes that LiveKit room when the
///   database marks the left room as `deleted`.
///
/// # Errors
///
/// Propagates errors from the database call that leaves the room and from
/// `update_user_contacts`. Send and LiveKit failures are only logged via `trace_err`.
async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Declared up front and assigned inside the `if let` below so the extracted values outlive it.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel;

    // NOTE(review): the guard returned by `session.db().await` is a temporary in the `if let`
    // scrutinee, so it stays alive until the end of this `if let`; the body does not `.await`.
    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Taken before `room` is moved out, so the `room` broadcast below carries the default
        // (empty) `live_kit_room`. NOTE(review): confirm clients rely on/accept that.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the connection-pool guard is dropped before `update_user_contacts`, which
    // acquires the connection pool itself.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    // LiveKit cleanup is best-effort: failures are logged and do not fail the leave.
    if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
5319
5320async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
5321 let left_channel_buffers = session
5322 .db()
5323 .await
5324 .leave_channel_buffers(session.connection_id)
5325 .await?;
5326
5327 for left_buffer in left_channel_buffers {
5328 channel_buffer_updated(
5329 session.connection_id,
5330 left_buffer.connections,
5331 &proto::UpdateChannelBufferCollaborators {
5332 channel_id: left_buffer.channel_id.to_proto(),
5333 collaborators: left_buffer.collaborators,
5334 },
5335 &session.peer,
5336 );
5337 }
5338
5339 Ok(())
5340}
5341
5342fn project_left(project: &db::LeftProject, session: &UserSession) {
5343 for connection_id in &project.connection_ids {
5344 if project.should_unshare {
5345 session
5346 .peer
5347 .send(
5348 *connection_id,
5349 proto::UnshareProject {
5350 project_id: project.id.to_proto(),
5351 },
5352 )
5353 .trace_err();
5354 } else {
5355 session
5356 .peer
5357 .send(
5358 *connection_id,
5359 proto::RemoveProjectCollaborator {
5360 project_id: project.id.to_proto(),
5361 peer_id: Some(session.connection_id.into()),
5362 },
5363 )
5364 .trace_err();
5365 }
5366 }
5367}
5368
5369pub trait ResultExt {
5370 type Ok;
5371
5372 fn trace_err(self) -> Option<Self::Ok>;
5373}
5374
5375impl<T, E> ResultExt for Result<T, E>
5376where
5377 E: std::fmt::Debug,
5378{
5379 type Ok = T;
5380
5381 #[track_caller]
5382 fn trace_err(self) -> Option<T> {
5383 match self {
5384 Ok(value) => Some(value),
5385 Err(error) => {
5386 tracing::error!("{:?}", error);
5387 None
5388 }
5389 }
5390 }
5391}