1mod connection_pool;
2
3use crate::api::CloudflareIpCountryHeader;
4use crate::llm::LlmTokenClaims;
5use crate::{
6 auth,
7 db::{
8 self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
9 CreatedChannelMessage, Database, DevServerId, DevServerProjectId, InviteMemberResult,
10 MembershipUpdated, MessageId, NotificationId, PrincipalId, Project, ProjectId,
11 RejoinedProject, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
12 ServerId, UpdatedChannelMessage, User, UserId,
13 },
14 executor::Executor,
15 AppState, Config, Error, RateLimit, Result,
16};
17use anyhow::{anyhow, bail, Context as _};
18use async_tungstenite::tungstenite::{
19 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
20};
21use axum::{
22 body::Body,
23 extract::{
24 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
25 ConnectInfo, WebSocketUpgrade,
26 },
27 headers::{Header, HeaderName},
28 http::StatusCode,
29 middleware,
30 response::IntoResponse,
31 routing::get,
32 Extension, Router, TypedHeader,
33};
34use chrono::Utc;
35use collections::{HashMap, HashSet};
36pub use connection_pool::{ConnectionPool, ZedVersion};
37use core::fmt::{self, Debug, Formatter};
38use http_client::HttpClient;
39use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
40use reqwest_client::ReqwestClient;
41use sha2::Digest;
42use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
43
44use futures::{
45 channel::oneshot,
46 future::{self, BoxFuture},
47 stream::FuturesUnordered,
48 FutureExt, SinkExt, StreamExt, TryStreamExt,
49};
50use prometheus::{register_int_gauge, IntGauge};
51use rpc::{
52 proto::{
53 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
54 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
55 },
56 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
57};
58use semantic_version::SemanticVersion;
59use serde::{Serialize, Serializer};
60use std::{
61 any::TypeId,
62 future::Future,
63 marker::PhantomData,
64 mem,
65 net::SocketAddr,
66 ops::{Deref, DerefMut},
67 rc::Rc,
68 sync::{
69 atomic::{AtomicBool, Ordering::SeqCst},
70 Arc, OnceLock,
71 },
72 time::{Duration, Instant},
73};
74use time::OffsetDateTime;
75use tokio::sync::{watch, MutexGuard, Semaphore};
76use tower::ServiceBuilder;
77use tracing::{
78 field::{self},
79 info_span, instrument, Instrument,
80};
81
/// Reconnect timeout: 30 seconds.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// How long `Server::start` waits before cleaning up resources left by stale servers.
///
/// Kubernetes gives terminated pods 10 seconds to shut down gracefully. Waiting somewhat
/// longer than that ensures the old pods are gone before their resources are cleaned up.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

/// Number of channel messages per page.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Maximum message length. NOTE(review): units (bytes vs. chars) are decided by the caller — confirm.
const MAX_MESSAGE_LEN: usize = 1024;
/// Number of notifications per page.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
90
91type MessageHandler =
92 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
93
94struct Response<R> {
95 peer: Arc<Peer>,
96 receipt: Receipt<R>,
97 responded: Arc<AtomicBool>,
98}
99
100impl<R: RequestMessage> Response<R> {
101 fn send(self, payload: R::Response) -> Result<()> {
102 self.responded.store(true, SeqCst);
103 self.peer.respond(self.receipt, payload)?;
104 Ok(())
105 }
106}
107
108#[derive(Clone, Debug)]
109pub enum Principal {
110 User(User),
111 Impersonated { user: User, admin: User },
112 DevServer(dev_server::Model),
113}
114
115impl Principal {
116 fn update_span(&self, span: &tracing::Span) {
117 match &self {
118 Principal::User(user) => {
119 span.record("user_id", user.id.0);
120 span.record("login", &user.github_login);
121 }
122 Principal::Impersonated { user, admin } => {
123 span.record("user_id", user.id.0);
124 span.record("login", &user.github_login);
125 span.record("impersonator", &admin.github_login);
126 }
127 Principal::DevServer(dev_server) => {
128 span.record("dev_server_id", dev_server.id.0);
129 }
130 }
131 }
132}
133
/// State for one client connection, cloned into every message handler invocation.
#[derive(Clone)]
struct Session {
    /// Who is on the other end of the connection.
    principal: Principal,
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; acquire it with `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Connection pool shared with the `Server`; acquire it with `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Supermaven admin API client; `None` when no Supermaven admin API key is configured.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    http_client: Arc<dyn HttpClient>,
    /// The GeoIP country code for the user.
    // `#[allow(unused)]` silences the dead-code warning for this field.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    _executor: Executor,
}
149
150impl Session {
151 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
152 #[cfg(test)]
153 tokio::task::yield_now().await;
154 let guard = self.db.lock().await;
155 #[cfg(test)]
156 tokio::task::yield_now().await;
157 guard
158 }
159
160 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
161 #[cfg(test)]
162 tokio::task::yield_now().await;
163 let guard = self.connection_pool.lock();
164 ConnectionPoolGuard {
165 guard,
166 _not_send: PhantomData,
167 }
168 }
169
170 fn for_user(self) -> Option<UserSession> {
171 UserSession::new(self)
172 }
173
174 fn for_dev_server(self) -> Option<DevServerSession> {
175 DevServerSession::new(self)
176 }
177
178 fn user_id(&self) -> Option<UserId> {
179 match &self.principal {
180 Principal::User(user) => Some(user.id),
181 Principal::Impersonated { user, .. } => Some(user.id),
182 Principal::DevServer(_) => None,
183 }
184 }
185
186 fn is_staff(&self) -> bool {
187 match &self.principal {
188 Principal::User(user) => user.admin,
189 Principal::Impersonated { .. } => true,
190 Principal::DevServer(_) => false,
191 }
192 }
193
194 pub async fn has_llm_subscription(
195 &self,
196 db: &MutexGuard<'_, DbHandle>,
197 ) -> anyhow::Result<bool> {
198 if self.is_staff() {
199 return Ok(true);
200 }
201
202 let Some(user_id) = self.user_id() else {
203 return Ok(false);
204 };
205
206 Ok(db.has_active_billing_subscription(user_id).await?)
207 }
208
209 pub async fn current_plan(
210 &self,
211 _db: &MutexGuard<'_, DbHandle>,
212 ) -> anyhow::Result<proto::Plan> {
213 if self.is_staff() {
214 Ok(proto::Plan::ZedPro)
215 } else {
216 Ok(proto::Plan::Free)
217 }
218 }
219
220 fn dev_server_id(&self) -> Option<DevServerId> {
221 match &self.principal {
222 Principal::User(_) | Principal::Impersonated { .. } => None,
223 Principal::DevServer(dev_server) => Some(dev_server.id),
224 }
225 }
226
227 fn principal_id(&self) -> PrincipalId {
228 match &self.principal {
229 Principal::User(user) => PrincipalId::UserId(user.id),
230 Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
231 Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
232 }
233 }
234}
235
236impl Debug for Session {
237 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
238 let mut result = f.debug_struct("Session");
239 match &self.principal {
240 Principal::User(user) => {
241 result.field("user", &user.github_login);
242 }
243 Principal::Impersonated { user, admin } => {
244 result.field("user", &user.github_login);
245 result.field("impersonator", &admin.github_login);
246 }
247 Principal::DevServer(dev_server) => {
248 result.field("dev_server", &dev_server.id);
249 }
250 }
251 result.field("connection_id", &self.connection_id).finish()
252 }
253}
254
255struct UserSession(Session);
256
257impl UserSession {
258 pub fn new(s: Session) -> Option<Self> {
259 s.user_id().map(|_| UserSession(s))
260 }
261 pub fn user_id(&self) -> UserId {
262 self.0.user_id().unwrap()
263 }
264
265 pub fn email(&self) -> Option<String> {
266 match &self.0.principal {
267 Principal::User(user) => user.email_address.clone(),
268 Principal::Impersonated { user, .. } => user.email_address.clone(),
269 Principal::DevServer(..) => None,
270 }
271 }
272}
273
274impl Deref for UserSession {
275 type Target = Session;
276
277 fn deref(&self) -> &Self::Target {
278 &self.0
279 }
280}
281impl DerefMut for UserSession {
282 fn deref_mut(&mut self) -> &mut Self::Target {
283 &mut self.0
284 }
285}
286
287struct DevServerSession(Session);
288
289impl DevServerSession {
290 pub fn new(s: Session) -> Option<Self> {
291 s.dev_server_id().map(|_| DevServerSession(s))
292 }
293 pub fn dev_server_id(&self) -> DevServerId {
294 self.0.dev_server_id().unwrap()
295 }
296
297 fn dev_server(&self) -> &dev_server::Model {
298 match &self.0.principal {
299 Principal::DevServer(dev_server) => dev_server,
300 _ => unreachable!(),
301 }
302 }
303}
304
305impl Deref for DevServerSession {
306 type Target = Session;
307
308 fn deref(&self) -> &Self::Target {
309 &self.0
310 }
311}
312impl DerefMut for DevServerSession {
313 fn deref_mut(&mut self) -> &mut Self::Target {
314 &mut self.0
315 }
316}
317
318fn user_handler<M: RequestMessage, Fut>(
319 handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
320) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
321where
322 Fut: Send + Future<Output = Result<()>>,
323{
324 let handler = Arc::new(handler);
325 move |message, response, session| {
326 let handler = handler.clone();
327 Box::pin(async move {
328 if let Some(user_session) = session.for_user() {
329 Ok(handler(message, response, user_session).await?)
330 } else {
331 Err(Error::Internal(anyhow!(
332 "must be a user to call {}",
333 M::NAME
334 )))
335 }
336 })
337 }
338}
339
340fn dev_server_handler<M: RequestMessage, Fut>(
341 handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
342) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
343where
344 Fut: Send + Future<Output = Result<()>>,
345{
346 let handler = Arc::new(handler);
347 move |message, response, session| {
348 let handler = handler.clone();
349 Box::pin(async move {
350 if let Some(dev_server_session) = session.for_dev_server() {
351 Ok(handler(message, response, dev_server_session).await?)
352 } else {
353 Err(Error::Internal(anyhow!(
354 "must be a dev server to call {}",
355 M::NAME
356 )))
357 }
358 })
359 }
360}
361
362fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
363 handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
364) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
365where
366 InnertRetFut: Send + Future<Output = Result<()>>,
367{
368 let handler = Arc::new(handler);
369 move |message, session| {
370 let handler = handler.clone();
371 Box::pin(async move {
372 if let Some(user_session) = session.for_user() {
373 Ok(handler(message, user_session).await?)
374 } else {
375 Err(Error::Internal(anyhow!(
376 "must be a user to call {}",
377 M::NAME
378 )))
379 }
380 })
381 }
382}
383
384struct DbHandle(Arc<Database>);
385
386impl Deref for DbHandle {
387 type Target = Database;
388
389 fn deref(&self) -> &Self::Target {
390 self.0.as_ref()
391 }
392}
393
/// The RPC server: owns the [`Peer`], the shared connection pool, and the message-handler
/// table populated in `Server::new`.
pub struct Server {
    /// This server's id; behind a mutex so the test-only `reset` can replace it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    /// Connection pool, also cloned into every `Session`.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Handlers keyed by the `TypeId` of the message type they handle.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Set to `true` by `teardown` (and back to `false` by the test-only `reset`);
    /// connection loops subscribe to it and stop when it changes.
    teardown: watch::Sender<bool>,
}
402
/// Lock guard over the shared [`ConnectionPool`].
///
/// The `PhantomData<Rc<()>>` marker makes this type `!Send` — presumably so the
/// synchronous parking_lot guard cannot be held across an `.await` in a `Send` future.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
407
/// Serializable view of the server's peer and (locked) connection pool.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // The guard itself isn't `Serialize`; serialize the pool it dereferences to.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
414
415pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
416where
417 S: Serializer,
418 T: Deref<Target = U>,
419 U: Serialize,
420{
421 Serialize::serialize(value.deref(), serializer)
422}
423
424impl Server {
    /// Creates a server with the given id and registers every RPC handler.
    ///
    /// Handlers wrapped in `user_handler` / `user_message_handler` reject non-user principals;
    /// handlers wrapped in `dev_server_handler` reject non-dev-server principals; unwrapped
    /// handlers receive the raw `Session` for any principal. Registering two handlers for the
    /// same message type panics (see `add_handler`).
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            // NOTE(review): `as u32` silently truncates if `ServerId`'s inner integer is wider
            // than u32 and ever exceeds u32::MAX — confirm the id range.
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            teardown: watch::channel(false).0,
        };

        server
            .add_request_handler(ping)
            .add_request_handler(user_handler(create_room))
            .add_request_handler(user_handler(join_room))
            .add_request_handler(user_handler(rejoin_room))
            .add_request_handler(user_handler(leave_room))
            .add_request_handler(user_handler(set_room_participant_role))
            .add_request_handler(user_handler(call))
            .add_request_handler(user_handler(cancel_call))
            .add_message_handler(user_message_handler(decline_call))
            .add_request_handler(user_handler(update_participant_location))
            .add_request_handler(user_handler(share_project))
            .add_message_handler(unshare_project)
            .add_request_handler(user_handler(join_project))
            .add_request_handler(user_handler(join_hosted_project))
            .add_request_handler(user_handler(rejoin_dev_server_projects))
            .add_request_handler(user_handler(create_dev_server_project))
            .add_request_handler(user_handler(update_dev_server_project))
            .add_request_handler(user_handler(delete_dev_server_project))
            .add_request_handler(user_handler(create_dev_server))
            .add_request_handler(user_handler(regenerate_dev_server_token))
            .add_request_handler(user_handler(rename_dev_server))
            .add_request_handler(user_handler(delete_dev_server))
            .add_request_handler(user_handler(list_remote_directory))
            // Requests that only dev-server principals may make.
            .add_request_handler(dev_server_handler(share_dev_server_project))
            .add_request_handler(dev_server_handler(shutdown_dev_server))
            .add_request_handler(dev_server_handler(reconnect_dev_server))
            .add_message_handler(user_message_handler(leave_project))
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            .add_request_handler(user_handler(
                forward_project_request_for_owner::<proto::TaskContextForLocation>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetHover>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetTypeDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetReferences>,
            ))
            .add_request_handler(user_handler(forward_find_search_candidates_request))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDocumentHighlights>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetProjectSymbols>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferForSymbol>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferById>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::SynchronizeBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::InlayHints>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::ResolveInlayHint>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferByPath>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCompletions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OpenNewBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCodeActions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCodeAction>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PrepareRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PerformRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ReloadBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::FormatBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CreateProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RenameProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CopyProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::DeleteProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ExpandProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OnTypeFormatting>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::SaveBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::BlameBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::MultiLspQuery>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RestartLanguageServers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::LinkedEditingRange>,
            ))
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            .add_request_handler(get_users)
            .add_request_handler(user_handler(fuzzy_search_users))
            .add_request_handler(user_handler(request_contact))
            .add_request_handler(user_handler(remove_contact))
            .add_request_handler(user_handler(respond_to_contact_request))
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(user_handler(create_channel))
            .add_request_handler(user_handler(delete_channel))
            .add_request_handler(user_handler(invite_channel_member))
            .add_request_handler(user_handler(remove_channel_member))
            .add_request_handler(user_handler(set_channel_member_role))
            .add_request_handler(user_handler(set_channel_visibility))
            .add_request_handler(user_handler(rename_channel))
            .add_request_handler(user_handler(join_channel_buffer))
            .add_request_handler(user_handler(leave_channel_buffer))
            .add_message_handler(user_message_handler(update_channel_buffer))
            .add_request_handler(user_handler(rejoin_channel_buffers))
            .add_request_handler(user_handler(get_channel_members))
            .add_request_handler(user_handler(respond_to_channel_invite))
            .add_request_handler(user_handler(join_channel))
            .add_request_handler(user_handler(join_channel_chat))
            .add_message_handler(user_message_handler(leave_channel_chat))
            .add_request_handler(user_handler(send_channel_message))
            .add_request_handler(user_handler(remove_channel_message))
            .add_request_handler(user_handler(update_channel_message))
            .add_request_handler(user_handler(get_channel_messages))
            .add_request_handler(user_handler(get_channel_messages_by_id))
            .add_request_handler(user_handler(get_notifications))
            .add_request_handler(user_handler(mark_notification_as_read))
            .add_request_handler(user_handler(move_channel))
            .add_request_handler(user_handler(follow))
            .add_message_handler(user_message_handler(unfollow))
            .add_message_handler(user_message_handler(update_followers))
            .add_request_handler(user_handler(get_private_user_info))
            .add_request_handler(user_handler(get_llm_api_token))
            .add_request_handler(user_handler(accept_terms_of_service))
            .add_message_handler(user_message_handler(acknowledge_channel_message))
            .add_message_handler(user_message_handler(acknowledge_buffer_version))
            .add_request_handler(user_handler(get_supermaven_api_key))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OpenContext>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CreateContext>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::SynchronizeContexts>,
            ))
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context)
            // Not wrapped in `user_handler`: receives the raw `Session` plus the app config.
            .add_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    let app_state = app_state.clone();
                    async move {
                        count_language_model_tokens(request, response, session, &app_state.config)
                            .await
                    }
                }
            })
            .add_request_handler({
                user_handler(move |request, response, session| {
                    get_cached_embeddings(request, response, session)
                })
            })
            // Passes the configured OpenAI API key through to `compute_embeddings`.
            .add_request_handler({
                let app_state = app_state.clone();
                user_handler(move |request, response, session| {
                    compute_embeddings(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                    )
                })
            });

        Arc::new(server)
    }
658
    /// Schedules cleanup of resources left behind by stale servers, then returns `Ok(())`
    /// immediately; the cleanup runs on a detached task.
    ///
    /// After waiting [`CLEANUP_TIMEOUT`], the task asks the database for stale room and
    /// channel ids (`stale_server_resource_ids`) and then:
    /// - for each channel, clears stale channel-buffer collaborators and sends the refreshed
    ///   collaborator list to the remaining connections;
    /// - for each room, clears stale participants, broadcasts the updated room (and its
    ///   channel, if any), sends `CallCanceled` to users whose calls were canceled, pushes
    ///   contact updates for affected users, and deletes the LiveKit room if no participants
    ///   remain;
    /// - finally calls `delete_stale_servers`.
    ///
    /// Each fallible step goes through `trace_err`, so a failure skips that step instead of
    /// aborting the task.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // These are only filled in if clearing stale participants succeeds;
                        // otherwise the defaults make the rest of this iteration a no-op.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the pool lock is released before the `.await`s below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                // Lock taken only after both awaits above have completed.
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
804
805 pub fn teardown(&self) {
806 self.peer.teardown();
807 self.connection_pool.lock().reset();
808 let _ = self.teardown.send(true);
809 }
810
811 #[cfg(test)]
812 pub fn reset(&self, id: ServerId) {
813 self.teardown();
814 *self.id.lock() = id;
815 self.peer.reset(id.0 as u32);
816 let _ = self.teardown.send(false);
817 }
818
819 #[cfg(test)]
820 pub fn id(&self) -> ServerId {
821 *self.id.lock()
822 }
823
824 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
825 where
826 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
827 Fut: 'static + Send + Future<Output = Result<()>>,
828 M: EnvelopedMessage,
829 {
830 let prev_handler = self.handlers.insert(
831 TypeId::of::<M>(),
832 Box::new(move |envelope, session| {
833 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
834 let received_at = envelope.received_at;
835 tracing::info!("message received");
836 let start_time = Instant::now();
837 let future = (handler)(*envelope, session);
838 async move {
839 let result = future.await;
840 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
841 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
842 let queue_duration_ms = total_duration_ms - processing_duration_ms;
843 let payload_type = M::NAME;
844
845 match result {
846 Err(error) => {
847 tracing::error!(
848 ?error,
849 total_duration_ms,
850 processing_duration_ms,
851 queue_duration_ms,
852 payload_type,
853 "error handling message"
854 )
855 }
856 Ok(()) => tracing::info!(
857 total_duration_ms,
858 processing_duration_ms,
859 queue_duration_ms,
860 "finished handling message"
861 ),
862 }
863 }
864 .boxed()
865 }),
866 );
867 if prev_handler.is_some() {
868 panic!("registered a handler for the same message twice");
869 }
870 self
871 }
872
873 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
874 where
875 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
876 Fut: 'static + Send + Future<Output = Result<()>>,
877 M: EnvelopedMessage,
878 {
879 self.add_handler(move |envelope, session| handler(envelope.payload, session));
880 self
881 }
882
883 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
884 where
885 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
886 Fut: Send + Future<Output = Result<()>>,
887 M: RequestMessage,
888 {
889 let handler = Arc::new(handler);
890 self.add_handler(move |envelope, session| {
891 let receipt = envelope.receipt();
892 let handler = handler.clone();
893 async move {
894 let peer = session.peer.clone();
895 let responded = Arc::new(AtomicBool::default());
896 let response = Response {
897 peer: peer.clone(),
898 responded: responded.clone(),
899 receipt,
900 };
901 match (handler)(envelope.payload, response, session).await {
902 Ok(()) => {
903 if responded.load(std::sync::atomic::Ordering::SeqCst) {
904 Ok(())
905 } else {
906 Err(anyhow!("handler did not send a response"))?
907 }
908 }
909 Err(error) => {
910 let proto_err = match &error {
911 Error::Internal(err) => err.to_proto(),
912 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
913 };
914 peer.respond_with_error(receipt, proto_err)?;
915 Err(error)
916 }
917 }
918 }
919 })
920 }
921
922 #[allow(clippy::too_many_arguments)]
923 pub fn handle_connection(
924 self: &Arc<Self>,
925 connection: Connection,
926 address: String,
927 principal: Principal,
928 zed_version: ZedVersion,
929 geoip_country_code: Option<String>,
930 send_connection_id: Option<oneshot::Sender<ConnectionId>>,
931 executor: Executor,
932 ) -> impl Future<Output = ()> {
933 let this = self.clone();
934 let span = info_span!("handle connection", %address,
935 connection_id=field::Empty,
936 user_id=field::Empty,
937 login=field::Empty,
938 impersonator=field::Empty,
939 dev_server_id=field::Empty,
940 geoip_country_code=field::Empty
941 );
942 principal.update_span(&span);
943 if let Some(country_code) = geoip_country_code.as_ref() {
944 span.record("geoip_country_code", country_code);
945 }
946
947 let mut teardown = self.teardown.subscribe();
948 async move {
949 if *teardown.borrow() {
950 tracing::error!("server is tearing down");
951 return
952 }
953 let (connection_id, handle_io, mut incoming_rx) = this
954 .peer
955 .add_connection(connection, {
956 let executor = executor.clone();
957 move |duration| executor.sleep(duration)
958 });
959 tracing::Span::current().record("connection_id", format!("{}", connection_id));
960
961 tracing::info!("connection opened");
962
963 let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
964 let http_client = match ReqwestClient::user_agent(&user_agent) {
965 Ok(http_client) => Arc::new(http_client),
966 Err(error) => {
967 tracing::error!(?error, "failed to create HTTP client");
968 return;
969 }
970 };
971
972 let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
973 supermaven_admin_api_key.to_string(),
974 http_client.clone(),
975 )));
976
977 let session = Session {
978 principal: principal.clone(),
979 connection_id,
980 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
981 peer: this.peer.clone(),
982 connection_pool: this.connection_pool.clone(),
983 app_state: this.app_state.clone(),
984 http_client,
985 geoip_country_code,
986 _executor: executor.clone(),
987 supermaven_client,
988 };
989
990 if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
991 tracing::error!(?error, "failed to send initial client update");
992 return;
993 }
994
995 let handle_io = handle_io.fuse();
996 futures::pin_mut!(handle_io);
997
998 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
999 // This prevents deadlocks when e.g., client A performs a request to client B and
1000 // client B performs a request to client A. If both clients stop processing further
1001 // messages until their respective request completes, they won't have a chance to
1002 // respond to the other client's request and cause a deadlock.
1003 //
1004 // This arrangement ensures we will attempt to process earlier messages first, but fall
1005 // back to processing messages arrived later in the spirit of making progress.
1006 let mut foreground_message_handlers = FuturesUnordered::new();
1007 let concurrent_handlers = Arc::new(Semaphore::new(256));
1008 loop {
1009 let next_message = async {
1010 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
1011 let message = incoming_rx.next().await;
1012 (permit, message)
1013 }.fuse();
1014 futures::pin_mut!(next_message);
1015 futures::select_biased! {
1016 _ = teardown.changed().fuse() => return,
1017 result = handle_io => {
1018 if let Err(error) = result {
1019 tracing::error!(?error, "error handling I/O");
1020 }
1021 break;
1022 }
1023 _ = foreground_message_handlers.next() => {}
1024 next_message = next_message => {
1025 let (permit, message) = next_message;
1026 if let Some(message) = message {
1027 let type_name = message.payload_type_name();
1028 // note: we copy all the fields from the parent span so we can query them in the logs.
1029 // (https://github.com/tokio-rs/tracing/issues/2670).
1030 let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
1031 user_id=field::Empty,
1032 login=field::Empty,
1033 impersonator=field::Empty,
1034 dev_server_id=field::Empty
1035 );
1036 principal.update_span(&span);
1037 let span_enter = span.enter();
1038 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
1039 let is_background = message.is_background();
1040 let handle_message = (handler)(message, session.clone());
1041 drop(span_enter);
1042
1043 let handle_message = async move {
1044 handle_message.await;
1045 drop(permit);
1046 }.instrument(span);
1047 if is_background {
1048 executor.spawn_detached(handle_message);
1049 } else {
1050 foreground_message_handlers.push(handle_message);
1051 }
1052 } else {
1053 tracing::error!("no message handler");
1054 }
1055 } else {
1056 tracing::info!("connection closed");
1057 break;
1058 }
1059 }
1060 }
1061 }
1062
1063 drop(foreground_message_handlers);
1064 tracing::info!("signing out");
1065 if let Err(error) = connection_lost(session, teardown, executor).await {
1066 tracing::error!(?error, "error signing out");
1067 }
1068
1069 }.instrument(span)
1070 }
1071
    /// Performs the initial handshake for a newly added connection.
    ///
    /// Always sends `proto::Hello` (carrying this connection's peer id) first and,
    /// when `send_connection_id` is provided, reports `connection_id` through it.
    /// What follows depends on the principal:
    ///
    /// * `User` / `Impersonated`: on the user's first-ever connection sends
    ///   `ShowContacts` and marks the user as connected once; updates the user's
    ///   plan; registers the connection in the pool and sends the initial contacts
    ///   update; subscribes the user to channels when
    ///   `should_auto_subscribe_to_channels(zed_version)` allows it; sends the
    ///   user's dev-server project status; forwards any pending incoming call; and
    ///   finally calls `update_user_contacts`.
    /// * `DevServer`: shuts down and removes any connection already registered for
    ///   the same dev server id, registers this connection, sends the dev server
    ///   its `DevServerInstructions`, and re-sends the owning user's dev-server
    ///   project status.
    ///
    /// # Errors
    ///
    /// Returns the first failure from sending a message, from a database call, or
    /// from removing a stale connection from the pool; the remaining steps are
    /// then skipped.
    async fn send_initial_client_update(
        &self,
        connection_id: ConnectionId,
        principal: &Principal,
        zed_version: ZedVersion,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        session: &Session,
    ) -> Result<()> {
        self.peer.send(
            connection_id,
            proto::Hello {
                peer_id: Some(connection_id.into()),
            },
        )?;
        tracing::info!("sent hello message");
        // Reported only after `Hello` was sent; a dropped receiver is ignored.
        if let Some(send_connection_id) = send_connection_id.take() {
            let _ = send_connection_id.send(connection_id);
        }

        match principal {
            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
                if !user.connected_once {
                    self.peer.send(connection_id, proto::ShowContacts {})?;
                    self.app_state
                        .db
                        .set_user_connected_once(user.id, true)
                        .await?;
                }

                update_user_plan(user.id, session).await?;

                // Both queries run concurrently; the first error aborts the handshake.
                let (contacts, dev_server_projects) = future::try_join(
                    self.app_state.db.get_contacts(user.id),
                    self.app_state.db.dev_server_projects_update(user.id),
                )
                .await?;

                // Register the connection and build the initial contacts update under
                // the same pool lock, so the update is computed from a pool that
                // already contains this connection.
                {
                    let mut pool = self.connection_pool.lock();
                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
                    self.peer.send(
                        connection_id,
                        build_initial_contacts_update(contacts, &pool),
                    )?;
                }

                if should_auto_subscribe_to_channels(zed_version) {
                    subscribe_user_to_channels(user.id, session).await?;
                }

                send_dev_server_projects_update(user.id, dev_server_projects, session).await;

                if let Some(incoming_call) =
                    self.app_state.db.incoming_call_for_user(user.id).await?
                {
                    self.peer.send(connection_id, incoming_call)?;
                }

                update_user_contacts(user.id, session).await?;
            }
            Principal::DevServer(dev_server) => {
                // Only one connection per dev server id is kept: a previously
                // registered one is told to shut down and removed before this one is
                // added, all under a single pool lock.
                {
                    let mut pool = self.connection_pool.lock();
                    if let Some(stale_connection_id) = pool.dev_server_connection_id(dev_server.id)
                    {
                        self.peer.send(
                            stale_connection_id,
                            proto::ShutdownDevServer {
                                reason: Some(
                                    "another dev server connected with the same token".to_string(),
                                ),
                            },
                        )?;
                        pool.remove_connection(stale_connection_id)?;
                    };
                    pool.add_dev_server(connection_id, dev_server.id, zed_version);
                }

                let projects = self
                    .app_state
                    .db
                    .get_projects_for_dev_server(dev_server.id)
                    .await?;
                self.peer
                    .send(connection_id, proto::DevServerInstructions { projects })?;

                let status = self
                    .app_state
                    .db
                    .dev_server_projects_update(dev_server.user_id)
                    .await?;
                send_dev_server_projects_update(dev_server.user_id, status, session).await;
            }
        }

        Ok(())
    }
1169
1170 pub async fn invite_code_redeemed(
1171 self: &Arc<Self>,
1172 inviter_id: UserId,
1173 invitee_id: UserId,
1174 ) -> Result<()> {
1175 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1176 if let Some(code) = &user.invite_code {
1177 let pool = self.connection_pool.lock();
1178 let invitee_contact = contact_for_user(invitee_id, false, &pool);
1179 for connection_id in pool.user_connection_ids(inviter_id) {
1180 self.peer.send(
1181 connection_id,
1182 proto::UpdateContacts {
1183 contacts: vec![invitee_contact.clone()],
1184 ..Default::default()
1185 },
1186 )?;
1187 self.peer.send(
1188 connection_id,
1189 proto::UpdateInviteInfo {
1190 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1191 count: user.invite_count as u32,
1192 },
1193 )?;
1194 }
1195 }
1196 }
1197 Ok(())
1198 }
1199
1200 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1201 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1202 if let Some(invite_code) = &user.invite_code {
1203 let pool = self.connection_pool.lock();
1204 for connection_id in pool.user_connection_ids(user_id) {
1205 self.peer.send(
1206 connection_id,
1207 proto::UpdateInviteInfo {
1208 url: format!(
1209 "{}{}",
1210 self.app_state.config.invite_link_prefix, invite_code
1211 ),
1212 count: user.invite_count as u32,
1213 },
1214 )?;
1215 }
1216 }
1217 }
1218 Ok(())
1219 }
1220
1221 pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
1222 let pool = self.connection_pool.lock();
1223 for connection_id in pool.user_connection_ids(user_id) {
1224 self.peer
1225 .send(connection_id, proto::RefreshLlmToken {})
1226 .trace_err();
1227 }
1228 }
1229
1230 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1231 ServerSnapshot {
1232 connection_pool: ConnectionPoolGuard {
1233 guard: self.connection_pool.lock(),
1234 _not_send: PhantomData,
1235 },
1236 peer: &self.peer,
1237 }
1238 }
1239}
1240
1241impl<'a> Deref for ConnectionPoolGuard<'a> {
1242 type Target = ConnectionPool;
1243
1244 fn deref(&self) -> &Self::Target {
1245 &self.guard
1246 }
1247}
1248
1249impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1250 fn deref_mut(&mut self) -> &mut Self::Target {
1251 &mut self.guard
1252 }
1253}
1254
impl<'a> Drop for ConnectionPoolGuard<'a> {
    /// In test builds, checks the pool's invariants whenever a guard is dropped.
    /// This runs before the inner `guard` field is dropped, i.e. while the pool is
    /// still locked. In non-test builds this is a no-op.
    fn drop(&mut self) {
        #[cfg(test)]
        self.check_invariants();
    }
}
1261
1262fn broadcast<F>(
1263 sender_id: Option<ConnectionId>,
1264 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1265 mut f: F,
1266) where
1267 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1268{
1269 for receiver_id in receiver_ids {
1270 if Some(receiver_id) != sender_id {
1271 if let Err(error) = f(receiver_id) {
1272 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1273 }
1274 }
1275 }
1276}
1277
1278pub struct ProtocolVersion(u32);
1279
1280impl Header for ProtocolVersion {
1281 fn name() -> &'static HeaderName {
1282 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1283 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1284 }
1285
1286 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1287 where
1288 Self: Sized,
1289 I: Iterator<Item = &'i axum::http::HeaderValue>,
1290 {
1291 let version = values
1292 .next()
1293 .ok_or_else(axum::headers::Error::invalid)?
1294 .to_str()
1295 .map_err(|_| axum::headers::Error::invalid())?
1296 .parse()
1297 .map_err(|_| axum::headers::Error::invalid())?;
1298 Ok(Self(version))
1299 }
1300
1301 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1302 values.extend([self.0.to_string().parse().unwrap()]);
1303 }
1304}
1305
1306pub struct AppVersionHeader(SemanticVersion);
1307impl Header for AppVersionHeader {
1308 fn name() -> &'static HeaderName {
1309 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1310 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1311 }
1312
1313 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1314 where
1315 Self: Sized,
1316 I: Iterator<Item = &'i axum::http::HeaderValue>,
1317 {
1318 let version = values
1319 .next()
1320 .ok_or_else(axum::headers::Error::invalid)?
1321 .to_str()
1322 .map_err(|_| axum::headers::Error::invalid())?
1323 .parse()
1324 .map_err(|_| axum::headers::Error::invalid())?;
1325 Ok(Self(version))
1326 }
1327
1328 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1329 values.extend([self.0.to_string().parse().unwrap()]);
1330 }
1331}
1332
1333pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1334 Router::new()
1335 .route("/rpc", get(handle_websocket_request))
1336 .layer(
1337 ServiceBuilder::new()
1338 .layer(Extension(server.app_state.clone()))
1339 .layer(middleware::from_fn(auth::validate_header)),
1340 )
1341 .route("/metrics", get(handle_metrics))
1342 .layer(Extension(server))
1343}
1344
1345pub async fn handle_websocket_request(
1346 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1347 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1348 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1349 Extension(server): Extension<Arc<Server>>,
1350 Extension(principal): Extension<Principal>,
1351 country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1352 ws: WebSocketUpgrade,
1353) -> axum::response::Response {
1354 if protocol_version != rpc::PROTOCOL_VERSION {
1355 return (
1356 StatusCode::UPGRADE_REQUIRED,
1357 "client must be upgraded".to_string(),
1358 )
1359 .into_response();
1360 }
1361
1362 let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1363 return (
1364 StatusCode::UPGRADE_REQUIRED,
1365 "no version header found".to_string(),
1366 )
1367 .into_response();
1368 };
1369
1370 if !version.can_collaborate() {
1371 return (
1372 StatusCode::UPGRADE_REQUIRED,
1373 "client must be upgraded".to_string(),
1374 )
1375 .into_response();
1376 }
1377
1378 let socket_address = socket_address.to_string();
1379 ws.on_upgrade(move |socket| {
1380 let socket = socket
1381 .map_ok(to_tungstenite_message)
1382 .err_into()
1383 .with(|message| async move { to_axum_message(message) });
1384 let connection = Connection::new(Box::pin(socket));
1385 async move {
1386 server
1387 .handle_connection(
1388 connection,
1389 socket_address,
1390 principal,
1391 version,
1392 country_code_header.map(|header| header.to_string()),
1393 None,
1394 Executor::Production,
1395 )
1396 .await;
1397 }
1398 })
1399}
1400
1401pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1402 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1403 let connections_metric = CONNECTIONS_METRIC
1404 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1405
1406 let connections = server
1407 .connection_pool
1408 .lock()
1409 .connections()
1410 .filter(|connection| !connection.admin)
1411 .count();
1412 connections_metric.set(connections as _);
1413
1414 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1415 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1416 register_int_gauge!(
1417 "shared_projects",
1418 "number of open projects with one or more guests"
1419 )
1420 .unwrap()
1421 });
1422
1423 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1424 shared_projects_metric.set(shared_projects as _);
1425
1426 let encoder = prometheus::TextEncoder::new();
1427 let metric_families = prometheus::gather();
1428 let encoded_metrics = encoder
1429 .encode_to_string(&metric_families)
1430 .map_err(|err| anyhow!("{}", err))?;
1431 Ok(encoded_metrics)
1432}
1433
1434#[instrument(err, skip(executor))]
1435async fn connection_lost(
1436 session: Session,
1437 mut teardown: watch::Receiver<bool>,
1438 executor: Executor,
1439) -> Result<()> {
1440 session.peer.disconnect(session.connection_id);
1441 session
1442 .connection_pool()
1443 .await
1444 .remove_connection(session.connection_id)?;
1445
1446 session
1447 .db()
1448 .await
1449 .connection_lost(session.connection_id)
1450 .await
1451 .trace_err();
1452
1453 futures::select_biased! {
1454 _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1455 match &session.principal {
1456 Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
1457 let session = session.for_user().unwrap();
1458
1459 log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1460 leave_room_for_session(&session, session.connection_id).await.trace_err();
1461 leave_channel_buffers_for_session(&session)
1462 .await
1463 .trace_err();
1464
1465 if !session
1466 .connection_pool()
1467 .await
1468 .is_user_online(session.user_id())
1469 {
1470 let db = session.db().await;
1471 if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1472 room_updated(&room, &session.peer);
1473 }
1474 }
1475
1476 update_user_contacts(session.user_id(), &session).await?;
1477 },
1478 Principal::DevServer(_) => {
1479 lost_dev_server_connection(&session.for_dev_server().unwrap()).await?;
1480 },
1481 }
1482 },
1483 _ = teardown.changed().fuse() => {}
1484 }
1485
1486 Ok(())
1487}
1488
1489/// Acknowledges a ping from a client, used to keep the connection alive.
1490async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1491 response.send(proto::Ack {})?;
1492 Ok(())
1493}
1494
1495/// Creates a new room for calling (outside of channels)
1496async fn create_room(
1497 _request: proto::CreateRoom,
1498 response: Response<proto::CreateRoom>,
1499 session: UserSession,
1500) -> Result<()> {
1501 let live_kit_room = nanoid::nanoid!(30);
1502
1503 let live_kit_connection_info = util::maybe!(async {
1504 let live_kit = session.app_state.live_kit_client.as_ref();
1505 let live_kit = live_kit?;
1506 let user_id = session.user_id().to_string();
1507
1508 let token = live_kit
1509 .room_token(&live_kit_room, &user_id.to_string())
1510 .trace_err()?;
1511
1512 Some(proto::LiveKitConnectionInfo {
1513 server_url: live_kit.url().into(),
1514 token,
1515 can_publish: true,
1516 })
1517 })
1518 .await;
1519
1520 let room = session
1521 .db()
1522 .await
1523 .create_room(session.user_id(), session.connection_id, &live_kit_room)
1524 .await?;
1525
1526 response.send(proto::CreateRoomResponse {
1527 room: Some(room.clone()),
1528 live_kit_connection_info,
1529 })?;
1530
1531 update_user_contacts(session.user_id(), &session).await?;
1532 Ok(())
1533}
1534
/// Join a room from an invitation. Equivalent to joining a channel if there is one.
///
/// For a room without a channel, this function:
/// - joins the room in the database and broadcasts the updated room;
/// - sends `CallCanceled` for this room to every connection of the joining user;
/// - responds with the room, plus LiveKit connection info when a LiveKit client is
///   configured and a token could be minted;
/// - finally calls `update_user_contacts` for the user.
async fn join_room(
    request: proto::JoinRoom,
    response: Response<proto::JoinRoom>,
    session: UserSession,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.id);

    // Bound to a local first: the temporary `session.db()` handle is dropped at the
    // end of this `let`. Matching on the call directly in `if let` would instead keep
    // it alive through `join_channel_internal` below.
    let channel_id = session.db().await.channel_id_for_room(room_id).await?;

    if let Some(channel_id) = channel_id {
        return join_channel_internal(channel_id, Box::new(response), session).await;
    }

    let joined_room = {
        let room = session
            .db()
            .await
            .join_room(room_id, session.user_id(), session.connection_id)
            .await?;
        // Broadcast before `into_inner` unwraps the value returned by `join_room`.
        room_updated(&room.room, &session.peer);
        room.into_inner()
    };

    // Sent to every connection of the joining user; presumably this dismisses the
    // incoming call for this room on the user's other connections.
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(session.user_id())
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }

    // `None` when no LiveKit client is configured or minting the token failed.
    let live_kit_connection_info =
        if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
            live_kit
                .room_token(
                    &joined_room.room.live_kit_room,
                    &session.user_id().to_string(),
                )
                .trace_err()
                .map(|token| proto::LiveKitConnectionInfo {
                    server_url: live_kit.url().into(),
                    token,
                    can_publish: true,
                })
        } else {
            None
        };

    response.send(proto::JoinRoomResponse {
        room: Some(joined_room.room),
        channel_id: None,
        live_kit_connection_info,
    })?;

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1601
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// While the value returned by the database's `rejoin_room` is still held, this
/// function:
/// - responds to the client with the room and its reshared and rejoined projects;
/// - broadcasts the room;
/// - tells each reshared project's collaborators about the reconnecting
///   connection's new peer id and forwards that project's worktree list to them;
/// - replays rejoined projects through `notify_rejoined_projects`.
///
/// After that value is unwrapped, the room's channel (if any) is re-broadcast with
/// `channel_updated`, and `update_user_contacts` runs.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: UserSession,
) -> Result<()> {
    // Declared outside the block so they outlive the value returned by
    // `rejoin_room`, which is only unwrapped at the end of the block.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Map the connection's old peer id to its new one for every collaborator.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1693
1694fn notify_rejoined_projects(
1695 rejoined_projects: &mut Vec<RejoinedProject>,
1696 session: &UserSession,
1697) -> Result<()> {
1698 for project in rejoined_projects.iter() {
1699 for collaborator in &project.collaborators {
1700 session
1701 .peer
1702 .send(
1703 collaborator.connection_id,
1704 proto::UpdateProjectCollaborator {
1705 project_id: project.id.to_proto(),
1706 old_peer_id: Some(project.old_connection_id.into()),
1707 new_peer_id: Some(session.connection_id.into()),
1708 },
1709 )
1710 .trace_err();
1711 }
1712 }
1713
1714 for project in rejoined_projects {
1715 for worktree in mem::take(&mut project.worktrees) {
1716 #[cfg(any(test, feature = "test-support"))]
1717 const MAX_CHUNK_SIZE: usize = 2;
1718 #[cfg(not(any(test, feature = "test-support")))]
1719 const MAX_CHUNK_SIZE: usize = 256;
1720
1721 // Stream this worktree's entries.
1722 let message = proto::UpdateWorktree {
1723 project_id: project.id.to_proto(),
1724 worktree_id: worktree.id,
1725 abs_path: worktree.abs_path.clone(),
1726 root_name: worktree.root_name,
1727 updated_entries: worktree.updated_entries,
1728 removed_entries: worktree.removed_entries,
1729 scan_id: worktree.scan_id,
1730 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1731 updated_repositories: worktree.updated_repositories,
1732 removed_repositories: worktree.removed_repositories,
1733 };
1734 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1735 session.peer.send(session.connection_id, update.clone())?;
1736 }
1737
1738 // Stream this worktree's diagnostics.
1739 for summary in worktree.diagnostic_summaries {
1740 session.peer.send(
1741 session.connection_id,
1742 proto::UpdateDiagnosticSummary {
1743 project_id: project.id.to_proto(),
1744 worktree_id: worktree.id,
1745 summary: Some(summary),
1746 },
1747 )?;
1748 }
1749
1750 for settings_file in worktree.settings_files {
1751 session.peer.send(
1752 session.connection_id,
1753 proto::UpdateWorktreeSettings {
1754 project_id: project.id.to_proto(),
1755 worktree_id: worktree.id,
1756 path: settings_file.path,
1757 content: Some(settings_file.content),
1758 kind: Some(settings_file.kind.to_proto().into()),
1759 },
1760 )?;
1761 }
1762 }
1763
1764 for language_server in &project.language_servers {
1765 session.peer.send(
1766 session.connection_id,
1767 proto::UpdateLanguageServer {
1768 project_id: project.id.to_proto(),
1769 language_server_id: language_server.id,
1770 variant: Some(
1771 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1772 proto::LspDiskBasedDiagnosticsUpdated {},
1773 ),
1774 ),
1775 },
1776 )?;
1777 }
1778 }
1779 Ok(())
1780}
1781
1782/// leave room disconnects from the room.
1783async fn leave_room(
1784 _: proto::LeaveRoom,
1785 response: Response<proto::LeaveRoom>,
1786 session: UserSession,
1787) -> Result<()> {
1788 leave_room_for_session(&session, session.connection_id).await?;
1789 response.send(proto::Ack {})?;
1790 Ok(())
1791}
1792
1793/// Updates the permissions of someone else in the room.
1794async fn set_room_participant_role(
1795 request: proto::SetRoomParticipantRole,
1796 response: Response<proto::SetRoomParticipantRole>,
1797 session: UserSession,
1798) -> Result<()> {
1799 let user_id = UserId::from_proto(request.user_id);
1800 let role = ChannelRole::from(request.role());
1801
1802 let (live_kit_room, can_publish) = {
1803 let room = session
1804 .db()
1805 .await
1806 .set_room_participant_role(
1807 session.user_id(),
1808 RoomId::from_proto(request.room_id),
1809 user_id,
1810 role,
1811 )
1812 .await?;
1813
1814 let live_kit_room = room.live_kit_room.clone();
1815 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1816 room_updated(&room, &session.peer);
1817 (live_kit_room, can_publish)
1818 };
1819
1820 if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
1821 live_kit
1822 .update_participant(
1823 live_kit_room.clone(),
1824 request.user_id.to_string(),
1825 live_kit_server::proto::ParticipantPermission {
1826 can_subscribe: true,
1827 can_publish,
1828 can_publish_data: can_publish,
1829 hidden: false,
1830 recorder: false,
1831 },
1832 )
1833 .await
1834 .trace_err();
1835 }
1836
1837 response.send(proto::Ack {})?;
1838 Ok(())
1839}
1840
/// Call someone else into the current room
///
/// Fails unless the called user is a contact of the caller. Otherwise the function:
/// - records the call in the database and broadcasts the room;
/// - updates the called user's contacts;
/// - sends the incoming call as a request to every connection of the called user.
///
/// The first connection that answers successfully causes an `Ack` to the caller.
/// If none does, including when the called user has no connections, the call is
/// marked as failed. The room is then broadcast again, contacts are updated, and
/// "failed to ring user" is returned.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: UserSession,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    // The incoming-call payload is taken out of the value returned by `call`,
    // leaving a default in its place, so it outlives this block and can be cloned
    // once per connection below.
    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(room, &session.peer);
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // The pool handle is a temporary released at the end of this statement; the
    // request futures are awaited afterwards without holding it.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    // Responses are handled in completion order. The first success acknowledges the
    // caller; failures are logged and the next response is awaited.
    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1909
/// Cancel an outgoing call.
///
/// The function:
/// - cancels the call in the database and broadcasts the updated room;
/// - sends `CallCanceled` for the room to every connection of the called user
///   (send errors go through `trace_err`);
/// - acknowledges the request;
/// - calls `update_user_contacts` for the called user.
async fn cancel_call(
    request: proto::CancelCall,
    response: Response<proto::CancelCall>,
    session: UserSession,
) -> Result<()> {
    let called_user_id = UserId::from_proto(request.called_user_id);
    let room_id = RoomId::from_proto(request.room_id);
    // Scoped so the value returned by `cancel_call` is released right after the broadcast.
    {
        let room = session
            .db()
            .await
            .cancel_call(room_id, session.connection_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }
    response.send(proto::Ack {})?;

    update_user_contacts(called_user_id, &session).await?;
    Ok(())
}
1947
/// Decline an incoming call.
///
/// This is a one-way message handler, so it sends no response. It declines the
/// call for `message.room_id` in the database, failing with "failed to decline
/// call" when the database returns no room, and broadcasts the updated room. It
/// then sends `CallCanceled` to every connection of the declining user and calls
/// `update_user_contacts` for them.
async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
    let room_id = RoomId::from_proto(message.room_id);
    // Scoped so the value returned by `decline_call` is released right after the broadcast.
    {
        let room = session
            .db()
            .await
            .decline_call(Some(room_id), session.user_id())
            .await?
            .ok_or_else(|| anyhow!("failed to decline call"))?;
        room_updated(&room, &session.peer);
    }

    // Sent to the declining user's own connections (not the caller's).
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(session.user_id())
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }
    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1979
1980/// Updates other participants in the room with your current location.
1981async fn update_participant_location(
1982 request: proto::UpdateParticipantLocation,
1983 response: Response<proto::UpdateParticipantLocation>,
1984 session: UserSession,
1985) -> Result<()> {
1986 let room_id = RoomId::from_proto(request.room_id);
1987 let location = request
1988 .location
1989 .ok_or_else(|| anyhow!("invalid location"))?;
1990
1991 let db = session.db().await;
1992 let room = db
1993 .update_room_participant_location(room_id, session.connection_id, location)
1994 .await?;
1995
1996 room_updated(&room, &session.peer);
1997 response.send(proto::Ack {})?;
1998 Ok(())
1999}
2000
/// Share a project into the room.
///
/// Registers the project in the database with its worktrees, SSH flag and optional
/// dev-server project id. It then responds with the new project id and broadcasts
/// the updated room.
async fn share_project(
    request: proto::ShareProject,
    response: Response<proto::ShareProject>,
    session: UserSession,
) -> Result<()> {
    // Destructured through a borrow of the value returned by `share_project`. That
    // value stays alive for the rest of the function, so both the response and the
    // room broadcast below happen while it is held.
    let (project_id, room) = &*session
        .db()
        .await
        .share_project(
            RoomId::from_proto(request.room_id),
            session.connection_id,
            &request.worktrees,
            request.is_ssh_project,
            request
                .dev_server_project_id
                .map(DevServerProjectId::from_proto),
        )
        .await?;
    response.send(proto::ShareProjectResponse {
        project_id: project_id.to_proto(),
    })?;
    room_updated(room, &session.peer);

    Ok(())
}
2027
2028/// Unshare a project from the room.
2029async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
2030 let project_id = ProjectId::from_proto(message.project_id);
2031 unshare_project_internal(
2032 project_id,
2033 session.connection_id,
2034 session.user_id(),
2035 &session,
2036 )
2037 .await
2038}
2039
/// Unshares `project_id` on behalf of `connection_id`.
///
/// The function:
/// - sends `UnshareProject` to the guest connections returned by the database,
///   skipping `connection_id` itself;
/// - broadcasts the room when the database returns one;
/// - deletes the project when the database result says it should be deleted.
///
/// `user_id` is passed to the database unchanged. It comes from
/// `Session::user_id()` and may be `None` (presumably for non-user principals;
/// confirm).
async fn unshare_project_internal(
    project_id: ProjectId,
    connection_id: ConnectionId,
    user_id: Option<UserId>,
    session: &Session,
) -> Result<()> {
    // Notifications are sent while the value returned by `unshare_project` is
    // held; only the `delete` flag escapes the block.
    let delete = {
        let room_guard = session
            .db()
            .await
            .unshare_project(project_id, connection_id, user_id)
            .await?;

        let (delete, room, guest_connection_ids) = &*room_guard;

        let message = proto::UnshareProject {
            project_id: project_id.to_proto(),
        };

        broadcast(
            Some(connection_id),
            guest_connection_ids.iter().copied(),
            |conn_id| session.peer.send(conn_id, message.clone()),
        );
        if let Some(room) = room {
            room_updated(room, &session.peer);
        }

        *delete
    };

    // Deletion happens after the block above has released `room_guard`.
    if delete {
        let db = session.db().await;
        db.delete_project(project_id).await?;
    }

    Ok(())
}
2078
2079/// DevServer makes a project available online
2080async fn share_dev_server_project(
2081 request: proto::ShareDevServerProject,
2082 response: Response<proto::ShareDevServerProject>,
2083 session: DevServerSession,
2084) -> Result<()> {
2085 let (dev_server_project, user_id, status) = session
2086 .db()
2087 .await
2088 .share_dev_server_project(
2089 DevServerProjectId::from_proto(request.dev_server_project_id),
2090 session.dev_server_id(),
2091 session.connection_id,
2092 &request.worktrees,
2093 )
2094 .await?;
2095 let Some(project_id) = dev_server_project.project_id else {
2096 return Err(anyhow!("failed to share remote project"))?;
2097 };
2098
2099 send_dev_server_projects_update(user_id, status, &session).await;
2100
2101 response.send(proto::ShareProjectResponse { project_id })?;
2102
2103 Ok(())
2104}
2105
/// Join someone else's shared project.
///
/// Records the join in the database and then hands the resulting project and
/// replica id to the synchronous `join_project_internal`, which notifies the
/// existing collaborators and responds to the joining client.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: UserSession,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    // Borrowed from the value returned by `join_project`, which stays alive for
    // the rest of the function even after `db` is dropped.
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    // Release the database handle before fanning out to collaborators.
    drop(db);
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
2124
/// Abstracts over the response handles that can carry a
/// `proto::JoinProjectResponse`, so `join_project_internal` can answer both
/// `JoinProject` and `JoinHostedProject` requests.
trait JoinProjectInternalResponse {
    /// Sends `result` as the reply to the original request.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Path form resolves to `Response`'s own `send`, not this trait method.
        Response::<proto::JoinProject>::send(self, result)
    }
}
impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Path form resolves to `Response`'s own `send`, not this trait method.
        Response::<proto::JoinHostedProject>::send(self, result)
    }
}
2138
2139fn join_project_internal(
2140 response: impl JoinProjectInternalResponse,
2141 session: UserSession,
2142 project: &mut Project,
2143 replica_id: &ReplicaId,
2144) -> Result<()> {
2145 let collaborators = project
2146 .collaborators
2147 .iter()
2148 .filter(|collaborator| collaborator.connection_id != session.connection_id)
2149 .map(|collaborator| collaborator.to_proto())
2150 .collect::<Vec<_>>();
2151 let project_id = project.id;
2152 let guest_user_id = session.user_id();
2153
2154 let worktrees = project
2155 .worktrees
2156 .iter()
2157 .map(|(id, worktree)| proto::WorktreeMetadata {
2158 id: *id,
2159 root_name: worktree.root_name.clone(),
2160 visible: worktree.visible,
2161 abs_path: worktree.abs_path.clone(),
2162 })
2163 .collect::<Vec<_>>();
2164
2165 let add_project_collaborator = proto::AddProjectCollaborator {
2166 project_id: project_id.to_proto(),
2167 collaborator: Some(proto::Collaborator {
2168 peer_id: Some(session.connection_id.into()),
2169 replica_id: replica_id.0 as u32,
2170 user_id: guest_user_id.to_proto(),
2171 }),
2172 };
2173
2174 for collaborator in &collaborators {
2175 session
2176 .peer
2177 .send(
2178 collaborator.peer_id.unwrap().into(),
2179 add_project_collaborator.clone(),
2180 )
2181 .trace_err();
2182 }
2183
2184 // First, we send the metadata associated with each worktree.
2185 response.send(proto::JoinProjectResponse {
2186 project_id: project.id.0 as u64,
2187 worktrees: worktrees.clone(),
2188 replica_id: replica_id.0 as u32,
2189 collaborators: collaborators.clone(),
2190 language_servers: project.language_servers.clone(),
2191 role: project.role.into(),
2192 dev_server_project_id: project
2193 .dev_server_project_id
2194 .map(|dev_server_project_id| dev_server_project_id.0 as u64),
2195 })?;
2196
2197 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2198 #[cfg(any(test, feature = "test-support"))]
2199 const MAX_CHUNK_SIZE: usize = 2;
2200 #[cfg(not(any(test, feature = "test-support")))]
2201 const MAX_CHUNK_SIZE: usize = 256;
2202
2203 // Stream this worktree's entries.
2204 let message = proto::UpdateWorktree {
2205 project_id: project_id.to_proto(),
2206 worktree_id,
2207 abs_path: worktree.abs_path.clone(),
2208 root_name: worktree.root_name,
2209 updated_entries: worktree.entries,
2210 removed_entries: Default::default(),
2211 scan_id: worktree.scan_id,
2212 is_last_update: worktree.scan_id == worktree.completed_scan_id,
2213 updated_repositories: worktree.repository_entries.into_values().collect(),
2214 removed_repositories: Default::default(),
2215 };
2216 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
2217 session.peer.send(session.connection_id, update.clone())?;
2218 }
2219
2220 // Stream this worktree's diagnostics.
2221 for summary in worktree.diagnostic_summaries {
2222 session.peer.send(
2223 session.connection_id,
2224 proto::UpdateDiagnosticSummary {
2225 project_id: project_id.to_proto(),
2226 worktree_id: worktree.id,
2227 summary: Some(summary),
2228 },
2229 )?;
2230 }
2231
2232 for settings_file in worktree.settings_files {
2233 session.peer.send(
2234 session.connection_id,
2235 proto::UpdateWorktreeSettings {
2236 project_id: project_id.to_proto(),
2237 worktree_id: worktree.id,
2238 path: settings_file.path,
2239 content: Some(settings_file.content),
2240 kind: Some(proto::update_user_settings::Kind::Settings.into()),
2241 },
2242 )?;
2243 }
2244 }
2245
2246 for language_server in &project.language_servers {
2247 session.peer.send(
2248 session.connection_id,
2249 proto::UpdateLanguageServer {
2250 project_id: project_id.to_proto(),
2251 language_server_id: language_server.id,
2252 variant: Some(
2253 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2254 proto::LspDiskBasedDiagnosticsUpdated {},
2255 ),
2256 ),
2257 },
2258 )?;
2259 }
2260
2261 Ok(())
2262}
2263
2264/// Leave someone elses shared project.
2265async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2266 let sender_id = session.connection_id;
2267 let project_id = ProjectId::from_proto(request.project_id);
2268 let db = session.db().await;
2269 if db.is_hosted_project(project_id).await? {
2270 let project = db.leave_hosted_project(project_id, sender_id).await?;
2271 project_left(&project, &session);
2272 return Ok(());
2273 }
2274
2275 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2276 tracing::info!(
2277 %project_id,
2278 "leave project"
2279 );
2280
2281 project_left(project, &session);
2282 if let Some(room) = room {
2283 room_updated(room, &session.peer);
2284 }
2285
2286 Ok(())
2287}
2288
2289async fn join_hosted_project(
2290 request: proto::JoinHostedProject,
2291 response: Response<proto::JoinHostedProject>,
2292 session: UserSession,
2293) -> Result<()> {
2294 let (mut project, replica_id) = session
2295 .db()
2296 .await
2297 .join_hosted_project(
2298 ProjectId(request.project_id as i32),
2299 session.user_id(),
2300 session.connection_id,
2301 )
2302 .await?;
2303
2304 join_project_internal(response, session, &mut project, &replica_id)
2305}
2306
2307async fn list_remote_directory(
2308 request: proto::ListRemoteDirectory,
2309 response: Response<proto::ListRemoteDirectory>,
2310 session: UserSession,
2311) -> Result<()> {
2312 let dev_server_id = DevServerId(request.dev_server_id as i32);
2313 let dev_server_connection_id = session
2314 .connection_pool()
2315 .await
2316 .online_dev_server_connection_id(dev_server_id)?;
2317
2318 session
2319 .db()
2320 .await
2321 .get_dev_server_for_user(dev_server_id, session.user_id())
2322 .await?;
2323
2324 response.send(
2325 session
2326 .peer
2327 .forward_request(session.connection_id, dev_server_connection_id, request)
2328 .await?,
2329 )?;
2330 Ok(())
2331}
2332
2333async fn update_dev_server_project(
2334 request: proto::UpdateDevServerProject,
2335 response: Response<proto::UpdateDevServerProject>,
2336 session: UserSession,
2337) -> Result<()> {
2338 let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2339
2340 let (dev_server_project, update) = session
2341 .db()
2342 .await
2343 .update_dev_server_project(dev_server_project_id, &request.paths, session.user_id())
2344 .await?;
2345
2346 let projects = session
2347 .db()
2348 .await
2349 .get_projects_for_dev_server(dev_server_project.dev_server_id)
2350 .await?;
2351
2352 let dev_server_connection_id = session
2353 .connection_pool()
2354 .await
2355 .online_dev_server_connection_id(dev_server_project.dev_server_id)?;
2356
2357 session.peer.send(
2358 dev_server_connection_id,
2359 proto::DevServerInstructions { projects },
2360 )?;
2361
2362 send_dev_server_projects_update(session.user_id(), update, &session).await;
2363
2364 response.send(proto::Ack {})
2365}
2366
2367async fn create_dev_server_project(
2368 request: proto::CreateDevServerProject,
2369 response: Response<proto::CreateDevServerProject>,
2370 session: UserSession,
2371) -> Result<()> {
2372 let dev_server_id = DevServerId(request.dev_server_id as i32);
2373 let dev_server_connection_id = session
2374 .connection_pool()
2375 .await
2376 .dev_server_connection_id(dev_server_id);
2377 let Some(dev_server_connection_id) = dev_server_connection_id else {
2378 Err(ErrorCode::DevServerOffline
2379 .message("Cannot create a remote project when the dev server is offline".to_string())
2380 .anyhow())?
2381 };
2382
2383 let path = request.path.clone();
2384 //Check that the path exists on the dev server
2385 session
2386 .peer
2387 .forward_request(
2388 session.connection_id,
2389 dev_server_connection_id,
2390 proto::ValidateDevServerProjectRequest { path: path.clone() },
2391 )
2392 .await?;
2393
2394 let (dev_server_project, update) = session
2395 .db()
2396 .await
2397 .create_dev_server_project(
2398 DevServerId(request.dev_server_id as i32),
2399 &request.path,
2400 session.user_id(),
2401 )
2402 .await?;
2403
2404 let projects = session
2405 .db()
2406 .await
2407 .get_projects_for_dev_server(dev_server_project.dev_server_id)
2408 .await?;
2409
2410 session.peer.send(
2411 dev_server_connection_id,
2412 proto::DevServerInstructions { projects },
2413 )?;
2414
2415 send_dev_server_projects_update(session.user_id(), update, &session).await;
2416
2417 response.send(proto::CreateDevServerProjectResponse {
2418 dev_server_project: Some(dev_server_project.to_proto(None)),
2419 })?;
2420 Ok(())
2421}
2422
2423async fn create_dev_server(
2424 request: proto::CreateDevServer,
2425 response: Response<proto::CreateDevServer>,
2426 session: UserSession,
2427) -> Result<()> {
2428 let access_token = auth::random_token();
2429 let hashed_access_token = auth::hash_access_token(&access_token);
2430
2431 if request.name.is_empty() {
2432 return Err(proto::ErrorCode::Forbidden
2433 .message("Dev server name cannot be empty".to_string())
2434 .anyhow())?;
2435 }
2436
2437 let (dev_server, status) = session
2438 .db()
2439 .await
2440 .create_dev_server(
2441 &request.name,
2442 request.ssh_connection_string.as_deref(),
2443 &hashed_access_token,
2444 session.user_id(),
2445 )
2446 .await?;
2447
2448 send_dev_server_projects_update(session.user_id(), status, &session).await;
2449
2450 response.send(proto::CreateDevServerResponse {
2451 dev_server_id: dev_server.id.0 as u64,
2452 access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2453 name: request.name,
2454 })?;
2455 Ok(())
2456}
2457
2458async fn regenerate_dev_server_token(
2459 request: proto::RegenerateDevServerToken,
2460 response: Response<proto::RegenerateDevServerToken>,
2461 session: UserSession,
2462) -> Result<()> {
2463 let dev_server_id = DevServerId(request.dev_server_id as i32);
2464 let access_token = auth::random_token();
2465 let hashed_access_token = auth::hash_access_token(&access_token);
2466
2467 let connection_id = session
2468 .connection_pool()
2469 .await
2470 .dev_server_connection_id(dev_server_id);
2471 if let Some(connection_id) = connection_id {
2472 shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2473 session.peer.send(
2474 connection_id,
2475 proto::ShutdownDevServer {
2476 reason: Some("dev server token was regenerated".to_string()),
2477 },
2478 )?;
2479 let _ = remove_dev_server_connection(dev_server_id, &session).await;
2480 }
2481
2482 let status = session
2483 .db()
2484 .await
2485 .update_dev_server_token(dev_server_id, &hashed_access_token, session.user_id())
2486 .await?;
2487
2488 send_dev_server_projects_update(session.user_id(), status, &session).await;
2489
2490 response.send(proto::RegenerateDevServerTokenResponse {
2491 dev_server_id: dev_server_id.to_proto(),
2492 access_token: auth::generate_dev_server_token(dev_server_id.0 as usize, access_token),
2493 })?;
2494 Ok(())
2495}
2496
2497async fn rename_dev_server(
2498 request: proto::RenameDevServer,
2499 response: Response<proto::RenameDevServer>,
2500 session: UserSession,
2501) -> Result<()> {
2502 if request.name.trim().is_empty() {
2503 return Err(proto::ErrorCode::Forbidden
2504 .message("Dev server name cannot be empty".to_string())
2505 .anyhow())?;
2506 }
2507
2508 let dev_server_id = DevServerId(request.dev_server_id as i32);
2509 let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2510 if dev_server.user_id != session.user_id() {
2511 return Err(anyhow!(ErrorCode::Forbidden))?;
2512 }
2513
2514 let status = session
2515 .db()
2516 .await
2517 .rename_dev_server(
2518 dev_server_id,
2519 &request.name,
2520 request.ssh_connection_string.as_deref(),
2521 session.user_id(),
2522 )
2523 .await?;
2524
2525 send_dev_server_projects_update(session.user_id(), status, &session).await;
2526
2527 response.send(proto::Ack {})?;
2528 Ok(())
2529}
2530
2531async fn delete_dev_server(
2532 request: proto::DeleteDevServer,
2533 response: Response<proto::DeleteDevServer>,
2534 session: UserSession,
2535) -> Result<()> {
2536 let dev_server_id = DevServerId(request.dev_server_id as i32);
2537 let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2538 if dev_server.user_id != session.user_id() {
2539 return Err(anyhow!(ErrorCode::Forbidden))?;
2540 }
2541
2542 let connection_id = session
2543 .connection_pool()
2544 .await
2545 .dev_server_connection_id(dev_server_id);
2546 if let Some(connection_id) = connection_id {
2547 shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2548 session.peer.send(
2549 connection_id,
2550 proto::ShutdownDevServer {
2551 reason: Some("dev server was deleted".to_string()),
2552 },
2553 )?;
2554 let _ = remove_dev_server_connection(dev_server_id, &session).await;
2555 }
2556
2557 let status = session
2558 .db()
2559 .await
2560 .delete_dev_server(dev_server_id, session.user_id())
2561 .await?;
2562
2563 send_dev_server_projects_update(session.user_id(), status, &session).await;
2564
2565 response.send(proto::Ack {})?;
2566 Ok(())
2567}
2568
2569async fn delete_dev_server_project(
2570 request: proto::DeleteDevServerProject,
2571 response: Response<proto::DeleteDevServerProject>,
2572 session: UserSession,
2573) -> Result<()> {
2574 let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2575 let dev_server_project = session
2576 .db()
2577 .await
2578 .get_dev_server_project(dev_server_project_id)
2579 .await?;
2580
2581 let dev_server = session
2582 .db()
2583 .await
2584 .get_dev_server(dev_server_project.dev_server_id)
2585 .await?;
2586 if dev_server.user_id != session.user_id() {
2587 return Err(anyhow!(ErrorCode::Forbidden))?;
2588 }
2589
2590 let dev_server_connection_id = session
2591 .connection_pool()
2592 .await
2593 .dev_server_connection_id(dev_server.id);
2594
2595 if let Some(dev_server_connection_id) = dev_server_connection_id {
2596 let project = session
2597 .db()
2598 .await
2599 .find_dev_server_project(dev_server_project_id)
2600 .await;
2601 if let Ok(project) = project {
2602 unshare_project_internal(
2603 project.id,
2604 dev_server_connection_id,
2605 Some(session.user_id()),
2606 &session,
2607 )
2608 .await?;
2609 }
2610 }
2611
2612 let (projects, status) = session
2613 .db()
2614 .await
2615 .delete_dev_server_project(dev_server_project_id, dev_server.id, session.user_id())
2616 .await?;
2617
2618 if let Some(dev_server_connection_id) = dev_server_connection_id {
2619 session.peer.send(
2620 dev_server_connection_id,
2621 proto::DevServerInstructions { projects },
2622 )?;
2623 }
2624
2625 send_dev_server_projects_update(session.user_id(), status, &session).await;
2626
2627 response.send(proto::Ack {})?;
2628 Ok(())
2629}
2630
2631async fn rejoin_dev_server_projects(
2632 request: proto::RejoinRemoteProjects,
2633 response: Response<proto::RejoinRemoteProjects>,
2634 session: UserSession,
2635) -> Result<()> {
2636 let mut rejoined_projects = {
2637 let db = session.db().await;
2638 db.rejoin_dev_server_projects(
2639 &request.rejoined_projects,
2640 session.user_id(),
2641 session.0.connection_id,
2642 )
2643 .await?
2644 };
2645 response.send(proto::RejoinRemoteProjectsResponse {
2646 rejoined_projects: rejoined_projects
2647 .iter()
2648 .map(|project| project.to_proto())
2649 .collect(),
2650 })?;
2651 notify_rejoined_projects(&mut rejoined_projects, &session)
2652}
2653
2654async fn reconnect_dev_server(
2655 request: proto::ReconnectDevServer,
2656 response: Response<proto::ReconnectDevServer>,
2657 session: DevServerSession,
2658) -> Result<()> {
2659 let reshared_projects = {
2660 let db = session.db().await;
2661 db.reshare_dev_server_projects(
2662 &request.reshared_projects,
2663 session.dev_server_id(),
2664 session.0.connection_id,
2665 )
2666 .await?
2667 };
2668
2669 for project in &reshared_projects {
2670 for collaborator in &project.collaborators {
2671 session
2672 .peer
2673 .send(
2674 collaborator.connection_id,
2675 proto::UpdateProjectCollaborator {
2676 project_id: project.id.to_proto(),
2677 old_peer_id: Some(project.old_connection_id.into()),
2678 new_peer_id: Some(session.connection_id.into()),
2679 },
2680 )
2681 .trace_err();
2682 }
2683
2684 broadcast(
2685 Some(session.connection_id),
2686 project
2687 .collaborators
2688 .iter()
2689 .map(|collaborator| collaborator.connection_id),
2690 |connection_id| {
2691 session.peer.forward_send(
2692 session.connection_id,
2693 connection_id,
2694 proto::UpdateProject {
2695 project_id: project.id.to_proto(),
2696 worktrees: project.worktrees.clone(),
2697 },
2698 )
2699 },
2700 );
2701 }
2702
2703 response.send(proto::ReconnectDevServerResponse {
2704 reshared_projects: reshared_projects
2705 .iter()
2706 .map(|project| proto::ResharedProject {
2707 id: project.id.to_proto(),
2708 collaborators: project
2709 .collaborators
2710 .iter()
2711 .map(|collaborator| collaborator.to_proto())
2712 .collect(),
2713 })
2714 .collect(),
2715 })?;
2716
2717 Ok(())
2718}
2719
2720async fn shutdown_dev_server(
2721 _: proto::ShutdownDevServer,
2722 response: Response<proto::ShutdownDevServer>,
2723 session: DevServerSession,
2724) -> Result<()> {
2725 response.send(proto::Ack {})?;
2726 shutdown_dev_server_internal(session.dev_server_id(), session.connection_id, &session).await?;
2727 remove_dev_server_connection(session.dev_server_id(), &session).await
2728}
2729
2730async fn shutdown_dev_server_internal(
2731 dev_server_id: DevServerId,
2732 connection_id: ConnectionId,
2733 session: &Session,
2734) -> Result<()> {
2735 let (dev_server_projects, dev_server) = {
2736 let db = session.db().await;
2737 let dev_server_projects = db.get_projects_for_dev_server(dev_server_id).await?;
2738 let dev_server = db.get_dev_server(dev_server_id).await?;
2739 (dev_server_projects, dev_server)
2740 };
2741
2742 for project_id in dev_server_projects.iter().filter_map(|p| p.project_id) {
2743 unshare_project_internal(
2744 ProjectId::from_proto(project_id),
2745 connection_id,
2746 None,
2747 session,
2748 )
2749 .await?;
2750 }
2751
2752 session
2753 .connection_pool()
2754 .await
2755 .set_dev_server_offline(dev_server_id);
2756
2757 let status = session
2758 .db()
2759 .await
2760 .dev_server_projects_update(dev_server.user_id)
2761 .await?;
2762 send_dev_server_projects_update(dev_server.user_id, status, session).await;
2763
2764 Ok(())
2765}
2766
2767async fn remove_dev_server_connection(dev_server_id: DevServerId, session: &Session) -> Result<()> {
2768 let dev_server_connection = session
2769 .connection_pool()
2770 .await
2771 .dev_server_connection_id(dev_server_id);
2772
2773 if let Some(dev_server_connection) = dev_server_connection {
2774 session
2775 .connection_pool()
2776 .await
2777 .remove_connection(dev_server_connection)?;
2778 }
2779 Ok(())
2780}
2781
2782/// Updates other participants with changes to the project
2783async fn update_project(
2784 request: proto::UpdateProject,
2785 response: Response<proto::UpdateProject>,
2786 session: Session,
2787) -> Result<()> {
2788 let project_id = ProjectId::from_proto(request.project_id);
2789 let (room, guest_connection_ids) = &*session
2790 .db()
2791 .await
2792 .update_project(project_id, session.connection_id, &request.worktrees)
2793 .await?;
2794 broadcast(
2795 Some(session.connection_id),
2796 guest_connection_ids.iter().copied(),
2797 |connection_id| {
2798 session
2799 .peer
2800 .forward_send(session.connection_id, connection_id, request.clone())
2801 },
2802 );
2803 if let Some(room) = room {
2804 room_updated(room, &session.peer);
2805 }
2806 response.send(proto::Ack {})?;
2807
2808 Ok(())
2809}
2810
2811/// Updates other participants with changes to the worktree
2812async fn update_worktree(
2813 request: proto::UpdateWorktree,
2814 response: Response<proto::UpdateWorktree>,
2815 session: Session,
2816) -> Result<()> {
2817 let guest_connection_ids = session
2818 .db()
2819 .await
2820 .update_worktree(&request, session.connection_id)
2821 .await?;
2822
2823 broadcast(
2824 Some(session.connection_id),
2825 guest_connection_ids.iter().copied(),
2826 |connection_id| {
2827 session
2828 .peer
2829 .forward_send(session.connection_id, connection_id, request.clone())
2830 },
2831 );
2832 response.send(proto::Ack {})?;
2833 Ok(())
2834}
2835
2836/// Updates other participants with changes to the diagnostics
2837async fn update_diagnostic_summary(
2838 message: proto::UpdateDiagnosticSummary,
2839 session: Session,
2840) -> Result<()> {
2841 let guest_connection_ids = session
2842 .db()
2843 .await
2844 .update_diagnostic_summary(&message, session.connection_id)
2845 .await?;
2846
2847 broadcast(
2848 Some(session.connection_id),
2849 guest_connection_ids.iter().copied(),
2850 |connection_id| {
2851 session
2852 .peer
2853 .forward_send(session.connection_id, connection_id, message.clone())
2854 },
2855 );
2856
2857 Ok(())
2858}
2859
2860/// Updates other participants with changes to the worktree settings
2861async fn update_worktree_settings(
2862 message: proto::UpdateWorktreeSettings,
2863 session: Session,
2864) -> Result<()> {
2865 let guest_connection_ids = session
2866 .db()
2867 .await
2868 .update_worktree_settings(&message, session.connection_id)
2869 .await?;
2870
2871 broadcast(
2872 Some(session.connection_id),
2873 guest_connection_ids.iter().copied(),
2874 |connection_id| {
2875 session
2876 .peer
2877 .forward_send(session.connection_id, connection_id, message.clone())
2878 },
2879 );
2880
2881 Ok(())
2882}
2883
2884/// Notify other participants that a language server has started.
2885async fn start_language_server(
2886 request: proto::StartLanguageServer,
2887 session: Session,
2888) -> Result<()> {
2889 let guest_connection_ids = session
2890 .db()
2891 .await
2892 .start_language_server(&request, session.connection_id)
2893 .await?;
2894
2895 broadcast(
2896 Some(session.connection_id),
2897 guest_connection_ids.iter().copied(),
2898 |connection_id| {
2899 session
2900 .peer
2901 .forward_send(session.connection_id, connection_id, request.clone())
2902 },
2903 );
2904 Ok(())
2905}
2906
2907/// Notify other participants that a language server has changed.
2908async fn update_language_server(
2909 request: proto::UpdateLanguageServer,
2910 session: Session,
2911) -> Result<()> {
2912 let project_id = ProjectId::from_proto(request.project_id);
2913 let project_connection_ids = session
2914 .db()
2915 .await
2916 .project_connection_ids(project_id, session.connection_id, true)
2917 .await?;
2918 broadcast(
2919 Some(session.connection_id),
2920 project_connection_ids.iter().copied(),
2921 |connection_id| {
2922 session
2923 .peer
2924 .forward_send(session.connection_id, connection_id, request.clone())
2925 },
2926 );
2927 Ok(())
2928}
2929
2930/// forward a project request to the host. These requests should be read only
2931/// as guests are allowed to send them.
2932async fn forward_read_only_project_request<T>(
2933 request: T,
2934 response: Response<T>,
2935 session: UserSession,
2936) -> Result<()>
2937where
2938 T: EntityMessage + RequestMessage,
2939{
2940 let project_id = ProjectId::from_proto(request.remote_entity_id());
2941 let host_connection_id = session
2942 .db()
2943 .await
2944 .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2945 .await?;
2946 let payload = session
2947 .peer
2948 .forward_request(session.connection_id, host_connection_id, request)
2949 .await?;
2950 response.send(payload)?;
2951 Ok(())
2952}
2953
2954async fn forward_find_search_candidates_request(
2955 request: proto::FindSearchCandidates,
2956 response: Response<proto::FindSearchCandidates>,
2957 session: UserSession,
2958) -> Result<()> {
2959 let project_id = ProjectId::from_proto(request.remote_entity_id());
2960 let host_connection_id = session
2961 .db()
2962 .await
2963 .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2964 .await?;
2965 let payload = session
2966 .peer
2967 .forward_request(session.connection_id, host_connection_id, request)
2968 .await?;
2969 response.send(payload)?;
2970 Ok(())
2971}
2972
2973/// forward a project request to the dev server. Only allowed
2974/// if it's your dev server.
2975async fn forward_project_request_for_owner<T>(
2976 request: T,
2977 response: Response<T>,
2978 session: UserSession,
2979) -> Result<()>
2980where
2981 T: EntityMessage + RequestMessage,
2982{
2983 let project_id = ProjectId::from_proto(request.remote_entity_id());
2984
2985 let host_connection_id = session
2986 .db()
2987 .await
2988 .host_for_owner_project_request(project_id, session.connection_id, session.user_id())
2989 .await?;
2990 let payload = session
2991 .peer
2992 .forward_request(session.connection_id, host_connection_id, request)
2993 .await?;
2994 response.send(payload)?;
2995 Ok(())
2996}
2997
2998/// forward a project request to the host. These requests are disallowed
2999/// for guests.
3000async fn forward_mutating_project_request<T>(
3001 request: T,
3002 response: Response<T>,
3003 session: UserSession,
3004) -> Result<()>
3005where
3006 T: EntityMessage + RequestMessage,
3007{
3008 let project_id = ProjectId::from_proto(request.remote_entity_id());
3009
3010 let host_connection_id = session
3011 .db()
3012 .await
3013 .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
3014 .await?;
3015 let payload = session
3016 .peer
3017 .forward_request(session.connection_id, host_connection_id, request)
3018 .await?;
3019 response.send(payload)?;
3020 Ok(())
3021}
3022
3023/// Notify other participants that a new buffer has been created
3024async fn create_buffer_for_peer(
3025 request: proto::CreateBufferForPeer,
3026 session: Session,
3027) -> Result<()> {
3028 session
3029 .db()
3030 .await
3031 .check_user_is_project_host(
3032 ProjectId::from_proto(request.project_id),
3033 session.connection_id,
3034 )
3035 .await?;
3036 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
3037 session
3038 .peer
3039 .forward_send(session.connection_id, peer_id.into(), request)?;
3040 Ok(())
3041}
3042
3043/// Notify other participants that a buffer has been updated. This is
3044/// allowed for guests as long as the update is limited to selections.
3045async fn update_buffer(
3046 request: proto::UpdateBuffer,
3047 response: Response<proto::UpdateBuffer>,
3048 session: Session,
3049) -> Result<()> {
3050 let project_id = ProjectId::from_proto(request.project_id);
3051 let mut capability = Capability::ReadOnly;
3052
3053 for op in request.operations.iter() {
3054 match op.variant {
3055 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
3056 Some(_) => capability = Capability::ReadWrite,
3057 }
3058 }
3059
3060 let host = {
3061 let guard = session
3062 .db()
3063 .await
3064 .connections_for_buffer_update(
3065 project_id,
3066 session.principal_id(),
3067 session.connection_id,
3068 capability,
3069 )
3070 .await?;
3071
3072 let (host, guests) = &*guard;
3073
3074 broadcast(
3075 Some(session.connection_id),
3076 guests.clone(),
3077 |connection_id| {
3078 session
3079 .peer
3080 .forward_send(session.connection_id, connection_id, request.clone())
3081 },
3082 );
3083
3084 *host
3085 };
3086
3087 if host != session.connection_id {
3088 session
3089 .peer
3090 .forward_request(session.connection_id, host, request.clone())
3091 .await?;
3092 }
3093
3094 response.send(proto::Ack {})?;
3095 Ok(())
3096}
3097
3098async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
3099 let project_id = ProjectId::from_proto(message.project_id);
3100
3101 let operation = message.operation.as_ref().context("invalid operation")?;
3102 let capability = match operation.variant.as_ref() {
3103 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
3104 if let Some(buffer_op) = buffer_op.operation.as_ref() {
3105 match buffer_op.variant {
3106 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
3107 Capability::ReadOnly
3108 }
3109 _ => Capability::ReadWrite,
3110 }
3111 } else {
3112 Capability::ReadWrite
3113 }
3114 }
3115 Some(_) => Capability::ReadWrite,
3116 None => Capability::ReadOnly,
3117 };
3118
3119 let guard = session
3120 .db()
3121 .await
3122 .connections_for_buffer_update(
3123 project_id,
3124 session.principal_id(),
3125 session.connection_id,
3126 capability,
3127 )
3128 .await?;
3129
3130 let (host, guests) = &*guard;
3131
3132 broadcast(
3133 Some(session.connection_id),
3134 guests.iter().chain([host]).copied(),
3135 |connection_id| {
3136 session
3137 .peer
3138 .forward_send(session.connection_id, connection_id, message.clone())
3139 },
3140 );
3141
3142 Ok(())
3143}
3144
3145/// Notify other participants that a project has been updated.
3146async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
3147 request: T,
3148 session: Session,
3149) -> Result<()> {
3150 let project_id = ProjectId::from_proto(request.remote_entity_id());
3151 let project_connection_ids = session
3152 .db()
3153 .await
3154 .project_connection_ids(project_id, session.connection_id, false)
3155 .await?;
3156
3157 broadcast(
3158 Some(session.connection_id),
3159 project_connection_ids.iter().copied(),
3160 |connection_id| {
3161 session
3162 .peer
3163 .forward_send(session.connection_id, connection_id, request.clone())
3164 },
3165 );
3166 Ok(())
3167}
3168
3169/// Start following another user in a call.
3170async fn follow(
3171 request: proto::Follow,
3172 response: Response<proto::Follow>,
3173 session: UserSession,
3174) -> Result<()> {
3175 let room_id = RoomId::from_proto(request.room_id);
3176 let project_id = request.project_id.map(ProjectId::from_proto);
3177 let leader_id = request
3178 .leader_id
3179 .ok_or_else(|| anyhow!("invalid leader id"))?
3180 .into();
3181 let follower_id = session.connection_id;
3182
3183 session
3184 .db()
3185 .await
3186 .check_room_participants(room_id, leader_id, session.connection_id)
3187 .await?;
3188
3189 let response_payload = session
3190 .peer
3191 .forward_request(session.connection_id, leader_id, request)
3192 .await?;
3193 response.send(response_payload)?;
3194
3195 if let Some(project_id) = project_id {
3196 let room = session
3197 .db()
3198 .await
3199 .follow(room_id, project_id, leader_id, follower_id)
3200 .await?;
3201 room_updated(&room, &session.peer);
3202 }
3203
3204 Ok(())
3205}
3206
3207/// Stop following another user in a call.
3208async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
3209 let room_id = RoomId::from_proto(request.room_id);
3210 let project_id = request.project_id.map(ProjectId::from_proto);
3211 let leader_id = request
3212 .leader_id
3213 .ok_or_else(|| anyhow!("invalid leader id"))?
3214 .into();
3215 let follower_id = session.connection_id;
3216
3217 session
3218 .db()
3219 .await
3220 .check_room_participants(room_id, leader_id, session.connection_id)
3221 .await?;
3222
3223 session
3224 .peer
3225 .forward_send(session.connection_id, leader_id, request)?;
3226
3227 if let Some(project_id) = project_id {
3228 let room = session
3229 .db()
3230 .await
3231 .unfollow(room_id, project_id, leader_id, follower_id)
3232 .await?;
3233 room_updated(&room, &session.peer);
3234 }
3235
3236 Ok(())
3237}
3238
3239/// Notify everyone following you of your current location.
3240async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
3241 let room_id = RoomId::from_proto(request.room_id);
3242 let database = session.db.lock().await;
3243
3244 let connection_ids = if let Some(project_id) = request.project_id {
3245 let project_id = ProjectId::from_proto(project_id);
3246 database
3247 .project_connection_ids(project_id, session.connection_id, true)
3248 .await?
3249 } else {
3250 database
3251 .room_connection_ids(room_id, session.connection_id)
3252 .await?
3253 };
3254
3255 // For now, don't send view update messages back to that view's current leader.
3256 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
3257 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
3258 _ => None,
3259 });
3260
3261 for connection_id in connection_ids.iter().cloned() {
3262 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
3263 session
3264 .peer
3265 .forward_send(session.connection_id, connection_id, request.clone())?;
3266 }
3267 }
3268 Ok(())
3269}
3270
3271/// Get public data about users.
3272async fn get_users(
3273 request: proto::GetUsers,
3274 response: Response<proto::GetUsers>,
3275 session: Session,
3276) -> Result<()> {
3277 let user_ids = request
3278 .user_ids
3279 .into_iter()
3280 .map(UserId::from_proto)
3281 .collect();
3282 let users = session
3283 .db()
3284 .await
3285 .get_users_by_ids(user_ids)
3286 .await?
3287 .into_iter()
3288 .map(|user| proto::User {
3289 id: user.id.to_proto(),
3290 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3291 github_login: user.github_login,
3292 })
3293 .collect();
3294 response.send(proto::UsersResponse { users })?;
3295 Ok(())
3296}
3297
3298/// Search for users (to invite) buy Github login
3299async fn fuzzy_search_users(
3300 request: proto::FuzzySearchUsers,
3301 response: Response<proto::FuzzySearchUsers>,
3302 session: UserSession,
3303) -> Result<()> {
3304 let query = request.query;
3305 let users = match query.len() {
3306 0 => vec![],
3307 1 | 2 => session
3308 .db()
3309 .await
3310 .get_user_by_github_login(&query)
3311 .await?
3312 .into_iter()
3313 .collect(),
3314 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
3315 };
3316 let users = users
3317 .into_iter()
3318 .filter(|user| user.id != session.user_id())
3319 .map(|user| proto::User {
3320 id: user.id.to_proto(),
3321 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3322 github_login: user.github_login,
3323 })
3324 .collect();
3325 response.send(proto::UsersResponse { users })?;
3326 Ok(())
3327}
3328
3329/// Send a contact request to another user.
3330async fn request_contact(
3331 request: proto::RequestContact,
3332 response: Response<proto::RequestContact>,
3333 session: UserSession,
3334) -> Result<()> {
3335 let requester_id = session.user_id();
3336 let responder_id = UserId::from_proto(request.responder_id);
3337 if requester_id == responder_id {
3338 return Err(anyhow!("cannot add yourself as a contact"))?;
3339 }
3340
3341 let notifications = session
3342 .db()
3343 .await
3344 .send_contact_request(requester_id, responder_id)
3345 .await?;
3346
3347 // Update outgoing contact requests of requester
3348 let mut update = proto::UpdateContacts::default();
3349 update.outgoing_requests.push(responder_id.to_proto());
3350 for connection_id in session
3351 .connection_pool()
3352 .await
3353 .user_connection_ids(requester_id)
3354 {
3355 session.peer.send(connection_id, update.clone())?;
3356 }
3357
3358 // Update incoming contact requests of responder
3359 let mut update = proto::UpdateContacts::default();
3360 update
3361 .incoming_requests
3362 .push(proto::IncomingContactRequest {
3363 requester_id: requester_id.to_proto(),
3364 });
3365 let connection_pool = session.connection_pool().await;
3366 for connection_id in connection_pool.user_connection_ids(responder_id) {
3367 session.peer.send(connection_id, update.clone())?;
3368 }
3369
3370 send_notifications(&connection_pool, &session.peer, notifications);
3371
3372 response.send(proto::Ack {})?;
3373 Ok(())
3374}
3375
3376/// Accept or decline a contact request
3377async fn respond_to_contact_request(
3378 request: proto::RespondToContactRequest,
3379 response: Response<proto::RespondToContactRequest>,
3380 session: UserSession,
3381) -> Result<()> {
3382 let responder_id = session.user_id();
3383 let requester_id = UserId::from_proto(request.requester_id);
3384 let db = session.db().await;
3385 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
3386 db.dismiss_contact_notification(responder_id, requester_id)
3387 .await?;
3388 } else {
3389 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
3390
3391 let notifications = db
3392 .respond_to_contact_request(responder_id, requester_id, accept)
3393 .await?;
3394 let requester_busy = db.is_user_busy(requester_id).await?;
3395 let responder_busy = db.is_user_busy(responder_id).await?;
3396
3397 let pool = session.connection_pool().await;
3398 // Update responder with new contact
3399 let mut update = proto::UpdateContacts::default();
3400 if accept {
3401 update
3402 .contacts
3403 .push(contact_for_user(requester_id, requester_busy, &pool));
3404 }
3405 update
3406 .remove_incoming_requests
3407 .push(requester_id.to_proto());
3408 for connection_id in pool.user_connection_ids(responder_id) {
3409 session.peer.send(connection_id, update.clone())?;
3410 }
3411
3412 // Update requester with new contact
3413 let mut update = proto::UpdateContacts::default();
3414 if accept {
3415 update
3416 .contacts
3417 .push(contact_for_user(responder_id, responder_busy, &pool));
3418 }
3419 update
3420 .remove_outgoing_requests
3421 .push(responder_id.to_proto());
3422
3423 for connection_id in pool.user_connection_ids(requester_id) {
3424 session.peer.send(connection_id, update.clone())?;
3425 }
3426
3427 send_notifications(&pool, &session.peer, notifications);
3428 }
3429
3430 response.send(proto::Ack {})?;
3431 Ok(())
3432}
3433
3434/// Remove a contact.
3435async fn remove_contact(
3436 request: proto::RemoveContact,
3437 response: Response<proto::RemoveContact>,
3438 session: UserSession,
3439) -> Result<()> {
3440 let requester_id = session.user_id();
3441 let responder_id = UserId::from_proto(request.user_id);
3442 let db = session.db().await;
3443 let (contact_accepted, deleted_notification_id) =
3444 db.remove_contact(requester_id, responder_id).await?;
3445
3446 let pool = session.connection_pool().await;
3447 // Update outgoing contact requests of requester
3448 let mut update = proto::UpdateContacts::default();
3449 if contact_accepted {
3450 update.remove_contacts.push(responder_id.to_proto());
3451 } else {
3452 update
3453 .remove_outgoing_requests
3454 .push(responder_id.to_proto());
3455 }
3456 for connection_id in pool.user_connection_ids(requester_id) {
3457 session.peer.send(connection_id, update.clone())?;
3458 }
3459
3460 // Update incoming contact requests of responder
3461 let mut update = proto::UpdateContacts::default();
3462 if contact_accepted {
3463 update.remove_contacts.push(requester_id.to_proto());
3464 } else {
3465 update
3466 .remove_incoming_requests
3467 .push(requester_id.to_proto());
3468 }
3469 for connection_id in pool.user_connection_ids(responder_id) {
3470 session.peer.send(connection_id, update.clone())?;
3471 if let Some(notification_id) = deleted_notification_id {
3472 session.peer.send(
3473 connection_id,
3474 proto::DeleteNotification {
3475 notification_id: notification_id.to_proto(),
3476 },
3477 )?;
3478 }
3479 }
3480
3481 response.send(proto::Ack {})?;
3482 Ok(())
3483}
3484
3485fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
3486 version.0.minor() < 139
3487}
3488
3489async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
3490 let plan = session.current_plan(&session.db().await).await?;
3491
3492 session
3493 .peer
3494 .send(
3495 session.connection_id,
3496 proto::UpdateUserPlan { plan: plan.into() },
3497 )
3498 .trace_err();
3499
3500 Ok(())
3501}
3502
3503async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
3504 subscribe_user_to_channels(
3505 session.user_id().ok_or_else(|| anyhow!("must be a user"))?,
3506 &session,
3507 )
3508 .await?;
3509 Ok(())
3510}
3511
3512async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
3513 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3514 let mut pool = session.connection_pool().await;
3515 for membership in &channels_for_user.channel_memberships {
3516 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3517 }
3518 session.peer.send(
3519 session.connection_id,
3520 build_update_user_channels(&channels_for_user),
3521 )?;
3522 session.peer.send(
3523 session.connection_id,
3524 build_channels_update(channels_for_user),
3525 )?;
3526 Ok(())
3527}
3528
3529/// Creates a new channel.
3530async fn create_channel(
3531 request: proto::CreateChannel,
3532 response: Response<proto::CreateChannel>,
3533 session: UserSession,
3534) -> Result<()> {
3535 let db = session.db().await;
3536
3537 let parent_id = request.parent_id.map(ChannelId::from_proto);
3538 let (channel, membership) = db
3539 .create_channel(&request.name, parent_id, session.user_id())
3540 .await?;
3541
3542 let root_id = channel.root_id();
3543 let channel = Channel::from_model(channel);
3544
3545 response.send(proto::CreateChannelResponse {
3546 channel: Some(channel.to_proto()),
3547 parent_id: request.parent_id,
3548 })?;
3549
3550 let mut connection_pool = session.connection_pool().await;
3551 if let Some(membership) = membership {
3552 connection_pool.subscribe_to_channel(
3553 membership.user_id,
3554 membership.channel_id,
3555 membership.role,
3556 );
3557 let update = proto::UpdateUserChannels {
3558 channel_memberships: vec![proto::ChannelMembership {
3559 channel_id: membership.channel_id.to_proto(),
3560 role: membership.role.into(),
3561 }],
3562 ..Default::default()
3563 };
3564 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3565 session.peer.send(connection_id, update.clone())?;
3566 }
3567 }
3568
3569 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3570 if !role.can_see_channel(channel.visibility) {
3571 continue;
3572 }
3573
3574 let update = proto::UpdateChannels {
3575 channels: vec![channel.to_proto()],
3576 ..Default::default()
3577 };
3578 session.peer.send(connection_id, update.clone())?;
3579 }
3580
3581 Ok(())
3582}
3583
3584/// Delete a channel
3585async fn delete_channel(
3586 request: proto::DeleteChannel,
3587 response: Response<proto::DeleteChannel>,
3588 session: UserSession,
3589) -> Result<()> {
3590 let db = session.db().await;
3591
3592 let channel_id = request.channel_id;
3593 let (root_channel, removed_channels) = db
3594 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3595 .await?;
3596 response.send(proto::Ack {})?;
3597
3598 // Notify members of removed channels
3599 let mut update = proto::UpdateChannels::default();
3600 update
3601 .delete_channels
3602 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3603
3604 let connection_pool = session.connection_pool().await;
3605 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3606 session.peer.send(connection_id, update.clone())?;
3607 }
3608
3609 Ok(())
3610}
3611
3612/// Invite someone to join a channel.
3613async fn invite_channel_member(
3614 request: proto::InviteChannelMember,
3615 response: Response<proto::InviteChannelMember>,
3616 session: UserSession,
3617) -> Result<()> {
3618 let db = session.db().await;
3619 let channel_id = ChannelId::from_proto(request.channel_id);
3620 let invitee_id = UserId::from_proto(request.user_id);
3621 let InviteMemberResult {
3622 channel,
3623 notifications,
3624 } = db
3625 .invite_channel_member(
3626 channel_id,
3627 invitee_id,
3628 session.user_id(),
3629 request.role().into(),
3630 )
3631 .await?;
3632
3633 let update = proto::UpdateChannels {
3634 channel_invitations: vec![channel.to_proto()],
3635 ..Default::default()
3636 };
3637
3638 let connection_pool = session.connection_pool().await;
3639 for connection_id in connection_pool.user_connection_ids(invitee_id) {
3640 session.peer.send(connection_id, update.clone())?;
3641 }
3642
3643 send_notifications(&connection_pool, &session.peer, notifications);
3644
3645 response.send(proto::Ack {})?;
3646 Ok(())
3647}
3648
3649/// remove someone from a channel
3650async fn remove_channel_member(
3651 request: proto::RemoveChannelMember,
3652 response: Response<proto::RemoveChannelMember>,
3653 session: UserSession,
3654) -> Result<()> {
3655 let db = session.db().await;
3656 let channel_id = ChannelId::from_proto(request.channel_id);
3657 let member_id = UserId::from_proto(request.user_id);
3658
3659 let RemoveChannelMemberResult {
3660 membership_update,
3661 notification_id,
3662 } = db
3663 .remove_channel_member(channel_id, member_id, session.user_id())
3664 .await?;
3665
3666 let mut connection_pool = session.connection_pool().await;
3667 notify_membership_updated(
3668 &mut connection_pool,
3669 membership_update,
3670 member_id,
3671 &session.peer,
3672 );
3673 for connection_id in connection_pool.user_connection_ids(member_id) {
3674 if let Some(notification_id) = notification_id {
3675 session
3676 .peer
3677 .send(
3678 connection_id,
3679 proto::DeleteNotification {
3680 notification_id: notification_id.to_proto(),
3681 },
3682 )
3683 .trace_err();
3684 }
3685 }
3686
3687 response.send(proto::Ack {})?;
3688 Ok(())
3689}
3690
3691/// Toggle the channel between public and private.
3692/// Care is taken to maintain the invariant that public channels only descend from public channels,
3693/// (though members-only channels can appear at any point in the hierarchy).
3694async fn set_channel_visibility(
3695 request: proto::SetChannelVisibility,
3696 response: Response<proto::SetChannelVisibility>,
3697 session: UserSession,
3698) -> Result<()> {
3699 let db = session.db().await;
3700 let channel_id = ChannelId::from_proto(request.channel_id);
3701 let visibility = request.visibility().into();
3702
3703 let channel_model = db
3704 .set_channel_visibility(channel_id, visibility, session.user_id())
3705 .await?;
3706 let root_id = channel_model.root_id();
3707 let channel = Channel::from_model(channel_model);
3708
3709 let mut connection_pool = session.connection_pool().await;
3710 for (user_id, role) in connection_pool
3711 .channel_user_ids(root_id)
3712 .collect::<Vec<_>>()
3713 .into_iter()
3714 {
3715 let update = if role.can_see_channel(channel.visibility) {
3716 connection_pool.subscribe_to_channel(user_id, channel_id, role);
3717 proto::UpdateChannels {
3718 channels: vec![channel.to_proto()],
3719 ..Default::default()
3720 }
3721 } else {
3722 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3723 proto::UpdateChannels {
3724 delete_channels: vec![channel.id.to_proto()],
3725 ..Default::default()
3726 }
3727 };
3728
3729 for connection_id in connection_pool.user_connection_ids(user_id) {
3730 session.peer.send(connection_id, update.clone())?;
3731 }
3732 }
3733
3734 response.send(proto::Ack {})?;
3735 Ok(())
3736}
3737
3738/// Alter the role for a user in the channel.
3739async fn set_channel_member_role(
3740 request: proto::SetChannelMemberRole,
3741 response: Response<proto::SetChannelMemberRole>,
3742 session: UserSession,
3743) -> Result<()> {
3744 let db = session.db().await;
3745 let channel_id = ChannelId::from_proto(request.channel_id);
3746 let member_id = UserId::from_proto(request.user_id);
3747 let result = db
3748 .set_channel_member_role(
3749 channel_id,
3750 session.user_id(),
3751 member_id,
3752 request.role().into(),
3753 )
3754 .await?;
3755
3756 match result {
3757 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3758 let mut connection_pool = session.connection_pool().await;
3759 notify_membership_updated(
3760 &mut connection_pool,
3761 membership_update,
3762 member_id,
3763 &session.peer,
3764 )
3765 }
3766 db::SetMemberRoleResult::InviteUpdated(channel) => {
3767 let update = proto::UpdateChannels {
3768 channel_invitations: vec![channel.to_proto()],
3769 ..Default::default()
3770 };
3771
3772 for connection_id in session
3773 .connection_pool()
3774 .await
3775 .user_connection_ids(member_id)
3776 {
3777 session.peer.send(connection_id, update.clone())?;
3778 }
3779 }
3780 }
3781
3782 response.send(proto::Ack {})?;
3783 Ok(())
3784}
3785
3786/// Change the name of a channel
3787async fn rename_channel(
3788 request: proto::RenameChannel,
3789 response: Response<proto::RenameChannel>,
3790 session: UserSession,
3791) -> Result<()> {
3792 let db = session.db().await;
3793 let channel_id = ChannelId::from_proto(request.channel_id);
3794 let channel_model = db
3795 .rename_channel(channel_id, session.user_id(), &request.name)
3796 .await?;
3797 let root_id = channel_model.root_id();
3798 let channel = Channel::from_model(channel_model);
3799
3800 response.send(proto::RenameChannelResponse {
3801 channel: Some(channel.to_proto()),
3802 })?;
3803
3804 let connection_pool = session.connection_pool().await;
3805 let update = proto::UpdateChannels {
3806 channels: vec![channel.to_proto()],
3807 ..Default::default()
3808 };
3809 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3810 if role.can_see_channel(channel.visibility) {
3811 session.peer.send(connection_id, update.clone())?;
3812 }
3813 }
3814
3815 Ok(())
3816}
3817
3818/// Move a channel to a new parent.
3819async fn move_channel(
3820 request: proto::MoveChannel,
3821 response: Response<proto::MoveChannel>,
3822 session: UserSession,
3823) -> Result<()> {
3824 let channel_id = ChannelId::from_proto(request.channel_id);
3825 let to = ChannelId::from_proto(request.to);
3826
3827 let (root_id, channels) = session
3828 .db()
3829 .await
3830 .move_channel(channel_id, to, session.user_id())
3831 .await?;
3832
3833 let connection_pool = session.connection_pool().await;
3834 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3835 let channels = channels
3836 .iter()
3837 .filter_map(|channel| {
3838 if role.can_see_channel(channel.visibility) {
3839 Some(channel.to_proto())
3840 } else {
3841 None
3842 }
3843 })
3844 .collect::<Vec<_>>();
3845 if channels.is_empty() {
3846 continue;
3847 }
3848
3849 let update = proto::UpdateChannels {
3850 channels,
3851 ..Default::default()
3852 };
3853
3854 session.peer.send(connection_id, update.clone())?;
3855 }
3856
3857 response.send(Ack {})?;
3858 Ok(())
3859}
3860
3861/// Get the list of channel members
3862async fn get_channel_members(
3863 request: proto::GetChannelMembers,
3864 response: Response<proto::GetChannelMembers>,
3865 session: UserSession,
3866) -> Result<()> {
3867 let db = session.db().await;
3868 let channel_id = ChannelId::from_proto(request.channel_id);
3869 let limit = if request.limit == 0 {
3870 u16::MAX as u64
3871 } else {
3872 request.limit
3873 };
3874 let (members, users) = db
3875 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3876 .await?;
3877 response.send(proto::GetChannelMembersResponse { members, users })?;
3878 Ok(())
3879}
3880
3881/// Accept or decline a channel invitation.
3882async fn respond_to_channel_invite(
3883 request: proto::RespondToChannelInvite,
3884 response: Response<proto::RespondToChannelInvite>,
3885 session: UserSession,
3886) -> Result<()> {
3887 let db = session.db().await;
3888 let channel_id = ChannelId::from_proto(request.channel_id);
3889 let RespondToChannelInvite {
3890 membership_update,
3891 notifications,
3892 } = db
3893 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3894 .await?;
3895
3896 let mut connection_pool = session.connection_pool().await;
3897 if let Some(membership_update) = membership_update {
3898 notify_membership_updated(
3899 &mut connection_pool,
3900 membership_update,
3901 session.user_id(),
3902 &session.peer,
3903 );
3904 } else {
3905 let update = proto::UpdateChannels {
3906 remove_channel_invitations: vec![channel_id.to_proto()],
3907 ..Default::default()
3908 };
3909
3910 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3911 session.peer.send(connection_id, update.clone())?;
3912 }
3913 };
3914
3915 send_notifications(&connection_pool, &session.peer, notifications);
3916
3917 response.send(proto::Ack {})?;
3918
3919 Ok(())
3920}
3921
3922/// Join the channels' room
3923async fn join_channel(
3924 request: proto::JoinChannel,
3925 response: Response<proto::JoinChannel>,
3926 session: UserSession,
3927) -> Result<()> {
3928 let channel_id = ChannelId::from_proto(request.channel_id);
3929 join_channel_internal(channel_id, Box::new(response), session).await
3930}
3931
/// Abstracts over the response handle used to answer a channel-join request.
///
/// `join_channel_internal` can then reply with a `proto::JoinRoomResponse`
/// whether the handle is a `Response<proto::JoinChannel>` or a
/// `Response<proto::JoinRoom>`; both implement this trait below.
trait JoinChannelInternalResponse {
    /// Sends the join result to the requesting client, consuming the handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// Forwards to the inherent `Response::send`. The explicit path makes it clear
// this is not a recursive call to the trait method of the same name.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
// Same forwarding as above, for `JoinRoom` requests.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3945
/// Shared implementation for joining the room of `channel_id`. The reply goes
/// through `response`, which is either a `JoinChannel` or a `JoinRoom`
/// response handle.
///
/// Steps:
/// 1. If the database reports a stale room connection for this user, that
///    connection leaves its room first. The database lock is released while
///    this happens.
/// 2. The user joins the channel's room in the database.
/// 3. LiveKit credentials are built when a LiveKit client is configured.
///    Guests get a guest token with `can_publish: false`; other roles get a
///    room token with `can_publish: true`. If token creation fails, the error
///    is logged and no connection info is sent.
/// 4. The client gets its `JoinRoomResponse`. Any membership change is then
///    propagated, and the room update is broadcast.
/// 5. After the database and connection-pool guards are dropped, the channel
///    update is broadcast and the user's contacts are refreshed.
///
/// # Errors
/// Returns any database or send error. Also returns an error if the joined
/// room carries no channel.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: UserSession,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database lock while leaving the old room, then reacquire it.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured or token creation failed.
        // `trace_err()?` logs the error and returns `None` from the closure.
        let live_kit_connection_info =
            session
                .app_state
                .live_kit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.live_kit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.live_kit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        // Leaving this scope drops `db` and `connection_pool`, so the
        // connection pool can be locked again below.
        joined_room
    };

    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
4041
4042/// Start editing the channel notes
4043async fn join_channel_buffer(
4044 request: proto::JoinChannelBuffer,
4045 response: Response<proto::JoinChannelBuffer>,
4046 session: UserSession,
4047) -> Result<()> {
4048 let db = session.db().await;
4049 let channel_id = ChannelId::from_proto(request.channel_id);
4050
4051 let open_response = db
4052 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
4053 .await?;
4054
4055 let collaborators = open_response.collaborators.clone();
4056 response.send(open_response)?;
4057
4058 let update = UpdateChannelBufferCollaborators {
4059 channel_id: channel_id.to_proto(),
4060 collaborators: collaborators.clone(),
4061 };
4062 channel_buffer_updated(
4063 session.connection_id,
4064 collaborators
4065 .iter()
4066 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
4067 &update,
4068 &session.peer,
4069 );
4070
4071 Ok(())
4072}
4073
4074/// Edit the channel notes
4075async fn update_channel_buffer(
4076 request: proto::UpdateChannelBuffer,
4077 session: UserSession,
4078) -> Result<()> {
4079 let db = session.db().await;
4080 let channel_id = ChannelId::from_proto(request.channel_id);
4081
4082 let (collaborators, epoch, version) = db
4083 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
4084 .await?;
4085
4086 channel_buffer_updated(
4087 session.connection_id,
4088 collaborators.clone(),
4089 &proto::UpdateChannelBuffer {
4090 channel_id: channel_id.to_proto(),
4091 operations: request.operations,
4092 },
4093 &session.peer,
4094 );
4095
4096 let pool = &*session.connection_pool().await;
4097
4098 let non_collaborators =
4099 pool.channel_connection_ids(channel_id)
4100 .filter_map(|(connection_id, _)| {
4101 if collaborators.contains(&connection_id) {
4102 None
4103 } else {
4104 Some(connection_id)
4105 }
4106 });
4107
4108 broadcast(None, non_collaborators, |peer_id| {
4109 session.peer.send(
4110 peer_id,
4111 proto::UpdateChannels {
4112 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
4113 channel_id: channel_id.to_proto(),
4114 epoch: epoch as u64,
4115 version: version.clone(),
4116 }],
4117 ..Default::default()
4118 },
4119 )
4120 });
4121
4122 Ok(())
4123}
4124
4125/// Rejoin the channel notes after a connection blip
4126async fn rejoin_channel_buffers(
4127 request: proto::RejoinChannelBuffers,
4128 response: Response<proto::RejoinChannelBuffers>,
4129 session: UserSession,
4130) -> Result<()> {
4131 let db = session.db().await;
4132 let buffers = db
4133 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
4134 .await?;
4135
4136 for rejoined_buffer in &buffers {
4137 let collaborators_to_notify = rejoined_buffer
4138 .buffer
4139 .collaborators
4140 .iter()
4141 .filter_map(|c| Some(c.peer_id?.into()));
4142 channel_buffer_updated(
4143 session.connection_id,
4144 collaborators_to_notify,
4145 &proto::UpdateChannelBufferCollaborators {
4146 channel_id: rejoined_buffer.buffer.channel_id,
4147 collaborators: rejoined_buffer.buffer.collaborators.clone(),
4148 },
4149 &session.peer,
4150 );
4151 }
4152
4153 response.send(proto::RejoinChannelBuffersResponse {
4154 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
4155 })?;
4156
4157 Ok(())
4158}
4159
4160/// Stop editing the channel notes
4161async fn leave_channel_buffer(
4162 request: proto::LeaveChannelBuffer,
4163 response: Response<proto::LeaveChannelBuffer>,
4164 session: UserSession,
4165) -> Result<()> {
4166 let db = session.db().await;
4167 let channel_id = ChannelId::from_proto(request.channel_id);
4168
4169 let left_buffer = db
4170 .leave_channel_buffer(channel_id, session.connection_id)
4171 .await?;
4172
4173 response.send(Ack {})?;
4174
4175 channel_buffer_updated(
4176 session.connection_id,
4177 left_buffer.connections,
4178 &proto::UpdateChannelBufferCollaborators {
4179 channel_id: channel_id.to_proto(),
4180 collaborators: left_buffer.collaborators,
4181 },
4182 &session.peer,
4183 );
4184
4185 Ok(())
4186}
4187
4188fn channel_buffer_updated<T: EnvelopedMessage>(
4189 sender_id: ConnectionId,
4190 collaborators: impl IntoIterator<Item = ConnectionId>,
4191 message: &T,
4192 peer: &Peer,
4193) {
4194 broadcast(Some(sender_id), collaborators, |peer_id| {
4195 peer.send(peer_id, message.clone())
4196 });
4197}
4198
4199fn send_notifications(
4200 connection_pool: &ConnectionPool,
4201 peer: &Peer,
4202 notifications: db::NotificationBatch,
4203) {
4204 for (user_id, notification) in notifications {
4205 for connection_id in connection_pool.user_connection_ids(user_id) {
4206 if let Err(error) = peer.send(
4207 connection_id,
4208 proto::AddNotification {
4209 notification: Some(notification.clone()),
4210 },
4211 ) {
4212 tracing::error!(
4213 "failed to send notification to {:?} {}",
4214 connection_id,
4215 error
4216 );
4217 }
4218 }
4219 }
4220}
4221
4222/// Send a message to the channel
4223async fn send_channel_message(
4224 request: proto::SendChannelMessage,
4225 response: Response<proto::SendChannelMessage>,
4226 session: UserSession,
4227) -> Result<()> {
4228 // Validate the message body.
4229 let body = request.body.trim().to_string();
4230 if body.len() > MAX_MESSAGE_LEN {
4231 return Err(anyhow!("message is too long"))?;
4232 }
4233 if body.is_empty() {
4234 return Err(anyhow!("message can't be blank"))?;
4235 }
4236
4237 // TODO: adjust mentions if body is trimmed
4238
4239 let timestamp = OffsetDateTime::now_utc();
4240 let nonce = request
4241 .nonce
4242 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4243
4244 let channel_id = ChannelId::from_proto(request.channel_id);
4245 let CreatedChannelMessage {
4246 message_id,
4247 participant_connection_ids,
4248 notifications,
4249 } = session
4250 .db()
4251 .await
4252 .create_channel_message(
4253 channel_id,
4254 session.user_id(),
4255 &body,
4256 &request.mentions,
4257 timestamp,
4258 nonce.clone().into(),
4259 request.reply_to_message_id.map(MessageId::from_proto),
4260 )
4261 .await?;
4262
4263 let message = proto::ChannelMessage {
4264 sender_id: session.user_id().to_proto(),
4265 id: message_id.to_proto(),
4266 body,
4267 mentions: request.mentions,
4268 timestamp: timestamp.unix_timestamp() as u64,
4269 nonce: Some(nonce),
4270 reply_to_message_id: request.reply_to_message_id,
4271 edited_at: None,
4272 };
4273 broadcast(
4274 Some(session.connection_id),
4275 participant_connection_ids.clone(),
4276 |connection| {
4277 session.peer.send(
4278 connection,
4279 proto::ChannelMessageSent {
4280 channel_id: channel_id.to_proto(),
4281 message: Some(message.clone()),
4282 },
4283 )
4284 },
4285 );
4286 response.send(proto::SendChannelMessageResponse {
4287 message: Some(message),
4288 })?;
4289
4290 let pool = &*session.connection_pool().await;
4291 let non_participants =
4292 pool.channel_connection_ids(channel_id)
4293 .filter_map(|(connection_id, _)| {
4294 if participant_connection_ids.contains(&connection_id) {
4295 None
4296 } else {
4297 Some(connection_id)
4298 }
4299 });
4300 broadcast(None, non_participants, |peer_id| {
4301 session.peer.send(
4302 peer_id,
4303 proto::UpdateChannels {
4304 latest_channel_message_ids: vec![proto::ChannelMessageId {
4305 channel_id: channel_id.to_proto(),
4306 message_id: message_id.to_proto(),
4307 }],
4308 ..Default::default()
4309 },
4310 )
4311 });
4312 send_notifications(pool, &session.peer, notifications);
4313
4314 Ok(())
4315}
4316
4317/// Delete a channel message
4318async fn remove_channel_message(
4319 request: proto::RemoveChannelMessage,
4320 response: Response<proto::RemoveChannelMessage>,
4321 session: UserSession,
4322) -> Result<()> {
4323 let channel_id = ChannelId::from_proto(request.channel_id);
4324 let message_id = MessageId::from_proto(request.message_id);
4325 let (connection_ids, existing_notification_ids) = session
4326 .db()
4327 .await
4328 .remove_channel_message(channel_id, message_id, session.user_id())
4329 .await?;
4330
4331 broadcast(
4332 Some(session.connection_id),
4333 connection_ids,
4334 move |connection| {
4335 session.peer.send(connection, request.clone())?;
4336
4337 for notification_id in &existing_notification_ids {
4338 session.peer.send(
4339 connection,
4340 proto::DeleteNotification {
4341 notification_id: (*notification_id).to_proto(),
4342 },
4343 )?;
4344 }
4345
4346 Ok(())
4347 },
4348 );
4349 response.send(proto::Ack {})?;
4350 Ok(())
4351}
4352
4353async fn update_channel_message(
4354 request: proto::UpdateChannelMessage,
4355 response: Response<proto::UpdateChannelMessage>,
4356 session: UserSession,
4357) -> Result<()> {
4358 let channel_id = ChannelId::from_proto(request.channel_id);
4359 let message_id = MessageId::from_proto(request.message_id);
4360 let updated_at = OffsetDateTime::now_utc();
4361 let UpdatedChannelMessage {
4362 message_id,
4363 participant_connection_ids,
4364 notifications,
4365 reply_to_message_id,
4366 timestamp,
4367 deleted_mention_notification_ids,
4368 updated_mention_notifications,
4369 } = session
4370 .db()
4371 .await
4372 .update_channel_message(
4373 channel_id,
4374 message_id,
4375 session.user_id(),
4376 request.body.as_str(),
4377 &request.mentions,
4378 updated_at,
4379 )
4380 .await?;
4381
4382 let nonce = request
4383 .nonce
4384 .clone()
4385 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4386
4387 let message = proto::ChannelMessage {
4388 sender_id: session.user_id().to_proto(),
4389 id: message_id.to_proto(),
4390 body: request.body.clone(),
4391 mentions: request.mentions.clone(),
4392 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
4393 nonce: Some(nonce),
4394 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
4395 edited_at: Some(updated_at.unix_timestamp() as u64),
4396 };
4397
4398 response.send(proto::Ack {})?;
4399
4400 let pool = &*session.connection_pool().await;
4401 broadcast(
4402 Some(session.connection_id),
4403 participant_connection_ids,
4404 |connection| {
4405 session.peer.send(
4406 connection,
4407 proto::ChannelMessageUpdate {
4408 channel_id: channel_id.to_proto(),
4409 message: Some(message.clone()),
4410 },
4411 )?;
4412
4413 for notification_id in &deleted_mention_notification_ids {
4414 session.peer.send(
4415 connection,
4416 proto::DeleteNotification {
4417 notification_id: (*notification_id).to_proto(),
4418 },
4419 )?;
4420 }
4421
4422 for notification in &updated_mention_notifications {
4423 session.peer.send(
4424 connection,
4425 proto::UpdateNotification {
4426 notification: Some(notification.clone()),
4427 },
4428 )?;
4429 }
4430
4431 Ok(())
4432 },
4433 );
4434
4435 send_notifications(pool, &session.peer, notifications);
4436
4437 Ok(())
4438}
4439
4440/// Mark a channel message as read
4441async fn acknowledge_channel_message(
4442 request: proto::AckChannelMessage,
4443 session: UserSession,
4444) -> Result<()> {
4445 let channel_id = ChannelId::from_proto(request.channel_id);
4446 let message_id = MessageId::from_proto(request.message_id);
4447 let notifications = session
4448 .db()
4449 .await
4450 .observe_channel_message(channel_id, session.user_id(), message_id)
4451 .await?;
4452 send_notifications(
4453 &*session.connection_pool().await,
4454 &session.peer,
4455 notifications,
4456 );
4457 Ok(())
4458}
4459
4460/// Mark a buffer version as synced
4461async fn acknowledge_buffer_version(
4462 request: proto::AckBufferOperation,
4463 session: UserSession,
4464) -> Result<()> {
4465 let buffer_id = BufferId::from_proto(request.buffer_id);
4466 session
4467 .db()
4468 .await
4469 .observe_buffer_version(
4470 buffer_id,
4471 session.user_id(),
4472 request.epoch as i32,
4473 &request.version,
4474 )
4475 .await?;
4476 Ok(())
4477}
4478
4479async fn count_language_model_tokens(
4480 request: proto::CountLanguageModelTokens,
4481 response: Response<proto::CountLanguageModelTokens>,
4482 session: Session,
4483 config: &Config,
4484) -> Result<()> {
4485 let Some(session) = session.for_user() else {
4486 return Err(anyhow!("user not found"))?;
4487 };
4488 authorize_access_to_legacy_llm_endpoints(&session).await?;
4489
4490 let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
4491 proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
4492 proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
4493 };
4494
4495 session
4496 .app_state
4497 .rate_limiter
4498 .check(&*rate_limit, session.user_id())
4499 .await?;
4500
4501 let result = match proto::LanguageModelProvider::from_i32(request.provider) {
4502 Some(proto::LanguageModelProvider::Google) => {
4503 let api_key = config
4504 .google_ai_api_key
4505 .as_ref()
4506 .context("no Google AI API key configured on the server")?;
4507 google_ai::count_tokens(
4508 session.http_client.as_ref(),
4509 google_ai::API_URL,
4510 api_key,
4511 serde_json::from_str(&request.request)?,
4512 None,
4513 )
4514 .await?
4515 }
4516 _ => return Err(anyhow!("unsupported provider"))?,
4517 };
4518
4519 response.send(proto::CountLanguageModelTokensResponse {
4520 token_count: result.total_tokens as u32,
4521 })?;
4522
4523 Ok(())
4524}
4525
4526struct ZedProCountLanguageModelTokensRateLimit;
4527
4528impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
4529 fn capacity(&self) -> usize {
4530 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
4531 .ok()
4532 .and_then(|v| v.parse().ok())
4533 .unwrap_or(600) // Picked arbitrarily
4534 }
4535
4536 fn refill_duration(&self) -> chrono::Duration {
4537 chrono::Duration::hours(1)
4538 }
4539
4540 fn db_name(&self) -> &'static str {
4541 "zed-pro:count-language-model-tokens"
4542 }
4543}
4544
4545struct FreeCountLanguageModelTokensRateLimit;
4546
4547impl RateLimit for FreeCountLanguageModelTokensRateLimit {
4548 fn capacity(&self) -> usize {
4549 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
4550 .ok()
4551 .and_then(|v| v.parse().ok())
4552 .unwrap_or(600 / 10) // Picked arbitrarily
4553 }
4554
4555 fn refill_duration(&self) -> chrono::Duration {
4556 chrono::Duration::hours(1)
4557 }
4558
4559 fn db_name(&self) -> &'static str {
4560 "free:count-language-model-tokens"
4561 }
4562}
4563
4564struct ZedProComputeEmbeddingsRateLimit;
4565
4566impl RateLimit for ZedProComputeEmbeddingsRateLimit {
4567 fn capacity(&self) -> usize {
4568 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4569 .ok()
4570 .and_then(|v| v.parse().ok())
4571 .unwrap_or(5000) // Picked arbitrarily
4572 }
4573
4574 fn refill_duration(&self) -> chrono::Duration {
4575 chrono::Duration::hours(1)
4576 }
4577
4578 fn db_name(&self) -> &'static str {
4579 "zed-pro:compute-embeddings"
4580 }
4581}
4582
4583struct FreeComputeEmbeddingsRateLimit;
4584
4585impl RateLimit for FreeComputeEmbeddingsRateLimit {
4586 fn capacity(&self) -> usize {
4587 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
4588 .ok()
4589 .and_then(|v| v.parse().ok())
4590 .unwrap_or(5000 / 10) // Picked arbitrarily
4591 }
4592
4593 fn refill_duration(&self) -> chrono::Duration {
4594 chrono::Duration::hours(1)
4595 }
4596
4597 fn db_name(&self) -> &'static str {
4598 "free:compute-embeddings"
4599 }
4600}
4601
4602async fn compute_embeddings(
4603 request: proto::ComputeEmbeddings,
4604 response: Response<proto::ComputeEmbeddings>,
4605 session: UserSession,
4606 api_key: Option<Arc<str>>,
4607) -> Result<()> {
4608 let api_key = api_key.context("no OpenAI API key configured on the server")?;
4609 authorize_access_to_legacy_llm_endpoints(&session).await?;
4610
4611 let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
4612 proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
4613 proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
4614 };
4615
4616 session
4617 .app_state
4618 .rate_limiter
4619 .check(&*rate_limit, session.user_id())
4620 .await?;
4621
4622 let embeddings = match request.model.as_str() {
4623 "openai/text-embedding-3-small" => {
4624 open_ai::embed(
4625 session.http_client.as_ref(),
4626 OPEN_AI_API_URL,
4627 &api_key,
4628 OpenAiEmbeddingModel::TextEmbedding3Small,
4629 request.texts.iter().map(|text| text.as_str()),
4630 )
4631 .await?
4632 }
4633 provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
4634 };
4635
4636 let embeddings = request
4637 .texts
4638 .iter()
4639 .map(|text| {
4640 let mut hasher = sha2::Sha256::new();
4641 hasher.update(text.as_bytes());
4642 let result = hasher.finalize();
4643 result.to_vec()
4644 })
4645 .zip(
4646 embeddings
4647 .data
4648 .into_iter()
4649 .map(|embedding| embedding.embedding),
4650 )
4651 .collect::<HashMap<_, _>>();
4652
4653 let db = session.db().await;
4654 db.save_embeddings(&request.model, &embeddings)
4655 .await
4656 .context("failed to save embeddings")
4657 .trace_err();
4658
4659 response.send(proto::ComputeEmbeddingsResponse {
4660 embeddings: embeddings
4661 .into_iter()
4662 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4663 .collect(),
4664 })?;
4665 Ok(())
4666}
4667
4668async fn get_cached_embeddings(
4669 request: proto::GetCachedEmbeddings,
4670 response: Response<proto::GetCachedEmbeddings>,
4671 session: UserSession,
4672) -> Result<()> {
4673 authorize_access_to_legacy_llm_endpoints(&session).await?;
4674
4675 let db = session.db().await;
4676 let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
4677
4678 response.send(proto::GetCachedEmbeddingsResponse {
4679 embeddings: embeddings
4680 .into_iter()
4681 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4682 .collect(),
4683 })?;
4684 Ok(())
4685}
4686
4687/// This is leftover from before the LLM service.
4688///
4689/// The endpoints protected by this check will be moved there eventually.
4690async fn authorize_access_to_legacy_llm_endpoints(session: &UserSession) -> Result<(), Error> {
4691 if session.is_staff() {
4692 Ok(())
4693 } else {
4694 Err(anyhow!("permission denied"))?
4695 }
4696}
4697
4698/// Get a Supermaven API key for the user
4699async fn get_supermaven_api_key(
4700 _request: proto::GetSupermavenApiKey,
4701 response: Response<proto::GetSupermavenApiKey>,
4702 session: UserSession,
4703) -> Result<()> {
4704 let user_id: String = session.user_id().to_string();
4705 if !session.is_staff() {
4706 return Err(anyhow!("supermaven not enabled for this account"))?;
4707 }
4708
4709 let email = session
4710 .email()
4711 .ok_or_else(|| anyhow!("user must have an email"))?;
4712
4713 let supermaven_admin_api = session
4714 .supermaven_client
4715 .as_ref()
4716 .ok_or_else(|| anyhow!("supermaven not configured"))?;
4717
4718 let result = supermaven_admin_api
4719 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4720 .await?;
4721
4722 response.send(proto::GetSupermavenApiKeyResponse {
4723 api_key: result.api_key,
4724 })?;
4725
4726 Ok(())
4727}
4728
4729/// Start receiving chat updates for a channel
4730async fn join_channel_chat(
4731 request: proto::JoinChannelChat,
4732 response: Response<proto::JoinChannelChat>,
4733 session: UserSession,
4734) -> Result<()> {
4735 let channel_id = ChannelId::from_proto(request.channel_id);
4736
4737 let db = session.db().await;
4738 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4739 .await?;
4740 let messages = db
4741 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4742 .await?;
4743 response.send(proto::JoinChannelChatResponse {
4744 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4745 messages,
4746 })?;
4747 Ok(())
4748}
4749
4750/// Stop receiving chat updates for a channel
4751async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
4752 let channel_id = ChannelId::from_proto(request.channel_id);
4753 session
4754 .db()
4755 .await
4756 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4757 .await?;
4758 Ok(())
4759}
4760
4761/// Retrieve the chat history for a channel
4762async fn get_channel_messages(
4763 request: proto::GetChannelMessages,
4764 response: Response<proto::GetChannelMessages>,
4765 session: UserSession,
4766) -> Result<()> {
4767 let channel_id = ChannelId::from_proto(request.channel_id);
4768 let messages = session
4769 .db()
4770 .await
4771 .get_channel_messages(
4772 channel_id,
4773 session.user_id(),
4774 MESSAGE_COUNT_PER_PAGE,
4775 Some(MessageId::from_proto(request.before_message_id)),
4776 )
4777 .await?;
4778 response.send(proto::GetChannelMessagesResponse {
4779 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4780 messages,
4781 })?;
4782 Ok(())
4783}
4784
4785/// Retrieve specific chat messages
4786async fn get_channel_messages_by_id(
4787 request: proto::GetChannelMessagesById,
4788 response: Response<proto::GetChannelMessagesById>,
4789 session: UserSession,
4790) -> Result<()> {
4791 let message_ids = request
4792 .message_ids
4793 .iter()
4794 .map(|id| MessageId::from_proto(*id))
4795 .collect::<Vec<_>>();
4796 let messages = session
4797 .db()
4798 .await
4799 .get_channel_messages_by_id(session.user_id(), &message_ids)
4800 .await?;
4801 response.send(proto::GetChannelMessagesResponse {
4802 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4803 messages,
4804 })?;
4805 Ok(())
4806}
4807
4808/// Retrieve the current users notifications
4809async fn get_notifications(
4810 request: proto::GetNotifications,
4811 response: Response<proto::GetNotifications>,
4812 session: UserSession,
4813) -> Result<()> {
4814 let notifications = session
4815 .db()
4816 .await
4817 .get_notifications(
4818 session.user_id(),
4819 NOTIFICATION_COUNT_PER_PAGE,
4820 request.before_id.map(db::NotificationId::from_proto),
4821 )
4822 .await?;
4823 response.send(proto::GetNotificationsResponse {
4824 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4825 notifications,
4826 })?;
4827 Ok(())
4828}
4829
4830/// Mark notifications as read
4831async fn mark_notification_as_read(
4832 request: proto::MarkNotificationRead,
4833 response: Response<proto::MarkNotificationRead>,
4834 session: UserSession,
4835) -> Result<()> {
4836 let database = &session.db().await;
4837 let notifications = database
4838 .mark_notification_as_read_by_id(
4839 session.user_id(),
4840 NotificationId::from_proto(request.notification_id),
4841 )
4842 .await?;
4843 send_notifications(
4844 &*session.connection_pool().await,
4845 &session.peer,
4846 notifications,
4847 );
4848 response.send(proto::Ack {})?;
4849 Ok(())
4850}
4851
4852/// Get the current users information
4853async fn get_private_user_info(
4854 _request: proto::GetPrivateUserInfo,
4855 response: Response<proto::GetPrivateUserInfo>,
4856 session: UserSession,
4857) -> Result<()> {
4858 let db = session.db().await;
4859
4860 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4861 let user = db
4862 .get_user_by_id(session.user_id())
4863 .await?
4864 .ok_or_else(|| anyhow!("user not found"))?;
4865 let flags = db.get_user_flags(session.user_id()).await?;
4866
4867 response.send(proto::GetPrivateUserInfoResponse {
4868 metrics_id,
4869 staff: user.admin,
4870 flags,
4871 accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4872 })?;
4873 Ok(())
4874}
4875
4876/// Accept the terms of service (tos) on behalf of the current user
4877async fn accept_terms_of_service(
4878 _request: proto::AcceptTermsOfService,
4879 response: Response<proto::AcceptTermsOfService>,
4880 session: UserSession,
4881) -> Result<()> {
4882 let db = session.db().await;
4883
4884 let accepted_tos_at = Utc::now();
4885 db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4886 .await?;
4887
4888 response.send(proto::AcceptTermsOfServiceResponse {
4889 accepted_tos_at: accepted_tos_at.timestamp() as u64,
4890 })?;
4891 Ok(())
4892}
4893
4894/// The minimum account age an account must have in order to use the LLM service.
4895const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4896
4897async fn get_llm_api_token(
4898 _request: proto::GetLlmToken,
4899 response: Response<proto::GetLlmToken>,
4900 session: UserSession,
4901) -> Result<()> {
4902 let db = session.db().await;
4903
4904 let flags = db.get_user_flags(session.user_id()).await?;
4905 let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
4906 let has_llm_closed_beta_feature_flag = flags.iter().any(|flag| flag == "llm-closed-beta");
4907
4908 if !session.is_staff() && !has_language_models_feature_flag {
4909 Err(anyhow!("permission denied"))?
4910 }
4911
4912 let user_id = session.user_id();
4913 let user = db
4914 .get_user_by_id(user_id)
4915 .await?
4916 .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4917
4918 if user.accepted_tos_at.is_none() {
4919 Err(anyhow!("terms of service not accepted"))?
4920 }
4921
4922 let mut account_created_at = user.created_at;
4923 if let Some(github_created_at) = user.github_user_created_at {
4924 account_created_at = account_created_at.min(github_created_at);
4925 }
4926 if Utc::now().naive_utc() - account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE {
4927 Err(anyhow!("account too young"))?
4928 }
4929
4930 let billing_preferences = db.get_billing_preferences(user.id).await?;
4931
4932 let token = LlmTokenClaims::create(
4933 user.id,
4934 user.github_login.clone(),
4935 session.is_staff(),
4936 billing_preferences,
4937 has_llm_closed_beta_feature_flag,
4938 session.has_llm_subscription(&db).await?,
4939 session.current_plan(&db).await?,
4940 &session.app_state.config,
4941 )?;
4942 response.send(proto::GetLlmTokenResponse { token })?;
4943 Ok(())
4944}
4945
4946fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4947 let message = match message {
4948 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4949 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4950 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4951 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4952 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4953 code: frame.code.into(),
4954 reason: frame.reason,
4955 })),
4956 // We should never receive a frame while reading the message, according
4957 // to the `tungstenite` maintainers:
4958 //
4959 // > It cannot occur when you read messages from the WebSocket, but it
4960 // > can be used when you want to send the raw frames (e.g. you want to
4961 // > send the frames to the WebSocket without composing the full message first).
4962 // >
4963 // > — https://github.com/snapview/tungstenite-rs/issues/268
4964 TungsteniteMessage::Frame(_) => {
4965 bail!("received an unexpected frame while reading the message")
4966 }
4967 };
4968
4969 Ok(message)
4970}
4971
4972fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4973 match message {
4974 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4975 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4976 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4977 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4978 AxumMessage::Close(frame) => {
4979 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4980 code: frame.code.into(),
4981 reason: frame.reason,
4982 }))
4983 }
4984 }
4985}
4986
4987fn notify_membership_updated(
4988 connection_pool: &mut ConnectionPool,
4989 result: MembershipUpdated,
4990 user_id: UserId,
4991 peer: &Peer,
4992) {
4993 for membership in &result.new_channels.channel_memberships {
4994 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4995 }
4996 for channel_id in &result.removed_channels {
4997 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4998 }
4999
5000 let user_channels_update = proto::UpdateUserChannels {
5001 channel_memberships: result
5002 .new_channels
5003 .channel_memberships
5004 .iter()
5005 .map(|cm| proto::ChannelMembership {
5006 channel_id: cm.channel_id.to_proto(),
5007 role: cm.role.into(),
5008 })
5009 .collect(),
5010 ..Default::default()
5011 };
5012
5013 let mut update = build_channels_update(result.new_channels);
5014 update.delete_channels = result
5015 .removed_channels
5016 .into_iter()
5017 .map(|id| id.to_proto())
5018 .collect();
5019 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
5020
5021 for connection_id in connection_pool.user_connection_ids(user_id) {
5022 peer.send(connection_id, user_channels_update.clone())
5023 .trace_err();
5024 peer.send(connection_id, update.clone()).trace_err();
5025 }
5026}
5027
5028fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
5029 proto::UpdateUserChannels {
5030 channel_memberships: channels
5031 .channel_memberships
5032 .iter()
5033 .map(|m| proto::ChannelMembership {
5034 channel_id: m.channel_id.to_proto(),
5035 role: m.role.into(),
5036 })
5037 .collect(),
5038 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
5039 observed_channel_message_id: channels.observed_channel_messages.clone(),
5040 }
5041}
5042
5043fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
5044 let mut update = proto::UpdateChannels::default();
5045
5046 for channel in channels.channels {
5047 update.channels.push(channel.to_proto());
5048 }
5049
5050 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
5051 update.latest_channel_message_ids = channels.latest_channel_messages;
5052
5053 for (channel_id, participants) in channels.channel_participants {
5054 update
5055 .channel_participants
5056 .push(proto::ChannelParticipants {
5057 channel_id: channel_id.to_proto(),
5058 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
5059 });
5060 }
5061
5062 for channel in channels.invited_channels {
5063 update.channel_invitations.push(channel.to_proto());
5064 }
5065
5066 update.hosted_projects = channels.hosted_projects;
5067 update
5068}
5069
5070fn build_initial_contacts_update(
5071 contacts: Vec<db::Contact>,
5072 pool: &ConnectionPool,
5073) -> proto::UpdateContacts {
5074 let mut update = proto::UpdateContacts::default();
5075
5076 for contact in contacts {
5077 match contact {
5078 db::Contact::Accepted { user_id, busy } => {
5079 update.contacts.push(contact_for_user(user_id, busy, pool));
5080 }
5081 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
5082 db::Contact::Incoming { user_id } => {
5083 update
5084 .incoming_requests
5085 .push(proto::IncomingContactRequest {
5086 requester_id: user_id.to_proto(),
5087 })
5088 }
5089 }
5090 }
5091
5092 update
5093}
5094
5095fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
5096 proto::Contact {
5097 user_id: user_id.to_proto(),
5098 online: pool.is_user_online(user_id),
5099 busy,
5100 }
5101}
5102
5103fn room_updated(room: &proto::Room, peer: &Peer) {
5104 broadcast(
5105 None,
5106 room.participants
5107 .iter()
5108 .filter_map(|participant| Some(participant.peer_id?.into())),
5109 |peer_id| {
5110 peer.send(
5111 peer_id,
5112 proto::RoomUpdated {
5113 room: Some(room.clone()),
5114 },
5115 )
5116 },
5117 );
5118}
5119
5120fn channel_updated(
5121 channel: &db::channel::Model,
5122 room: &proto::Room,
5123 peer: &Peer,
5124 pool: &ConnectionPool,
5125) {
5126 let participants = room
5127 .participants
5128 .iter()
5129 .map(|p| p.user_id)
5130 .collect::<Vec<_>>();
5131
5132 broadcast(
5133 None,
5134 pool.channel_connection_ids(channel.root_id())
5135 .filter_map(|(channel_id, role)| {
5136 role.can_see_channel(channel.visibility)
5137 .then_some(channel_id)
5138 }),
5139 |peer_id| {
5140 peer.send(
5141 peer_id,
5142 proto::UpdateChannels {
5143 channel_participants: vec![proto::ChannelParticipants {
5144 channel_id: channel.id.to_proto(),
5145 participant_user_ids: participants.clone(),
5146 }],
5147 ..Default::default()
5148 },
5149 )
5150 },
5151 );
5152}
5153
5154async fn send_dev_server_projects_update(
5155 user_id: UserId,
5156 mut status: proto::DevServerProjectsUpdate,
5157 session: &Session,
5158) {
5159 let pool = session.connection_pool().await;
5160 for dev_server in &mut status.dev_servers {
5161 dev_server.status =
5162 pool.dev_server_status(DevServerId(dev_server.dev_server_id as i32)) as i32;
5163 }
5164 let connections = pool.user_connection_ids(user_id);
5165 for connection_id in connections {
5166 session.peer.send(connection_id, status.clone()).trace_err();
5167 }
5168}
5169
5170async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
5171 let db = session.db().await;
5172
5173 let contacts = db.get_contacts(user_id).await?;
5174 let busy = db.is_user_busy(user_id).await?;
5175
5176 let pool = session.connection_pool().await;
5177 let updated_contact = contact_for_user(user_id, busy, &pool);
5178 for contact in contacts {
5179 if let db::Contact::Accepted {
5180 user_id: contact_user_id,
5181 ..
5182 } = contact
5183 {
5184 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
5185 session
5186 .peer
5187 .send(
5188 contact_conn_id,
5189 proto::UpdateContacts {
5190 contacts: vec![updated_contact.clone()],
5191 remove_contacts: Default::default(),
5192 incoming_requests: Default::default(),
5193 remove_incoming_requests: Default::default(),
5194 outgoing_requests: Default::default(),
5195 remove_outgoing_requests: Default::default(),
5196 },
5197 )
5198 .trace_err();
5199 }
5200 }
5201 }
5202 Ok(())
5203}
5204
5205async fn lost_dev_server_connection(session: &DevServerSession) -> Result<()> {
5206 log::info!("lost dev server connection, unsharing projects");
5207 let project_ids = session
5208 .db()
5209 .await
5210 .get_stale_dev_server_projects(session.connection_id)
5211 .await?;
5212
5213 for project_id in project_ids {
5214 // not unshare re-checks the connection ids match, so we get away with no transaction
5215 unshare_project_internal(project_id, session.connection_id, None, session).await?;
5216 }
5217
5218 let user_id = session.dev_server().user_id;
5219 let update = session
5220 .db()
5221 .await
5222 .dev_server_projects_update(user_id)
5223 .await?;
5224
5225 send_dev_server_projects_update(user_id, update, session).await;
5226
5227 Ok(())
5228}
5229
/// Remove `connection_id` from the room it is in (if any) and propagate the change.
///
/// If the database reports that the connection left a room, this:
/// - notifies the other connections of every project it left (see `project_left`);
/// - broadcasts the updated room and, for a channel room, the channel's participants;
/// - sends `CallCanceled` to users whose pending calls were canceled;
/// - refreshes the contact entries of this user and of those users;
/// - when a LiveKit client is configured, removes the user from the LiveKit
///   room, and deletes that room if the database marked it as deleted.
///
/// If the connection was not in a room, returns `Ok(())` without doing anything.
async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Values extracted from the left room; each is assigned exactly once in the branch below.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // NOTE(review): taken before `room` is moved out below, so the room
        // broadcast by `room_updated` carries a default (empty) `live_kit_room`.
        // Confirm clients do not rely on that field in `RoomUpdated`.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the connection-pool guard is released before the awaits below.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    // LiveKit cleanup is best-effort: failures are logged via `trace_err`.
    if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
5303
5304async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
5305 let left_channel_buffers = session
5306 .db()
5307 .await
5308 .leave_channel_buffers(session.connection_id)
5309 .await?;
5310
5311 for left_buffer in left_channel_buffers {
5312 channel_buffer_updated(
5313 session.connection_id,
5314 left_buffer.connections,
5315 &proto::UpdateChannelBufferCollaborators {
5316 channel_id: left_buffer.channel_id.to_proto(),
5317 collaborators: left_buffer.collaborators,
5318 },
5319 &session.peer,
5320 );
5321 }
5322
5323 Ok(())
5324}
5325
5326fn project_left(project: &db::LeftProject, session: &UserSession) {
5327 for connection_id in &project.connection_ids {
5328 if project.should_unshare {
5329 session
5330 .peer
5331 .send(
5332 *connection_id,
5333 proto::UnshareProject {
5334 project_id: project.id.to_proto(),
5335 },
5336 )
5337 .trace_err();
5338 } else {
5339 session
5340 .peer
5341 .send(
5342 *connection_id,
5343 proto::RemoveProjectCollaborator {
5344 project_id: project.id.to_proto(),
5345 peer_id: Some(session.connection_id.into()),
5346 },
5347 )
5348 .trace_err();
5349 }
5350 }
5351}
5352
/// Extension for results whose errors should be logged rather than propagated.
pub trait ResultExt {
    /// The success type of the extended result.
    type Ok;

    /// Convert into an `Option`: the success value on success, otherwise log the error and return `None`.
    fn trace_err(self) -> Option<Self::Ok>;
}
5358
5359impl<T, E> ResultExt for Result<T, E>
5360where
5361 E: std::fmt::Debug,
5362{
5363 type Ok = T;
5364
5365 #[track_caller]
5366 fn trace_err(self) -> Option<T> {
5367 match self {
5368 Ok(value) => Some(value),
5369 Err(error) => {
5370 tracing::error!("{:?}", error);
5371 None
5372 }
5373 }
5374 }
5375}