1mod connection_pool;
2
3use crate::api::CloudflareIpCountryHeader;
4use crate::llm::LlmTokenClaims;
5use crate::{
6 auth,
7 db::{
8 self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
9 CreatedChannelMessage, Database, DevServerId, DevServerProjectId, InviteMemberResult,
10 MembershipUpdated, MessageId, NotificationId, PrincipalId, Project, ProjectId,
11 RejoinedProject, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
12 ServerId, UpdatedChannelMessage, User, UserId,
13 },
14 executor::Executor,
15 AppState, Config, Error, RateLimit, Result,
16};
17use anyhow::{anyhow, bail, Context as _};
18use async_tungstenite::tungstenite::{
19 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
20};
21use axum::{
22 body::Body,
23 extract::{
24 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
25 ConnectInfo, WebSocketUpgrade,
26 },
27 headers::{Header, HeaderName},
28 http::StatusCode,
29 middleware,
30 response::IntoResponse,
31 routing::get,
32 Extension, Router, TypedHeader,
33};
34use chrono::Utc;
35use collections::{HashMap, HashSet};
36pub use connection_pool::{ConnectionPool, ZedVersion};
37use core::fmt::{self, Debug, Formatter};
38use http_client::HttpClient;
39use isahc_http_client::IsahcHttpClient;
40use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
41use sha2::Digest;
42use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
43
44use futures::{
45 channel::oneshot,
46 future::{self, BoxFuture},
47 stream::FuturesUnordered,
48 FutureExt, SinkExt, StreamExt, TryStreamExt,
49};
50use prometheus::{register_int_gauge, IntGauge};
51use rpc::{
52 proto::{
53 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
54 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
55 },
56 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
57};
58use semantic_version::SemanticVersion;
59use serde::{Serialize, Serializer};
60use std::{
61 any::TypeId,
62 future::Future,
63 marker::PhantomData,
64 mem,
65 net::SocketAddr,
66 ops::{Deref, DerefMut},
67 rc::Rc,
68 sync::{
69 atomic::{AtomicBool, Ordering::SeqCst},
70 Arc, OnceLock,
71 },
72 time::{Duration, Instant},
73};
74use time::OffsetDateTime;
75use tokio::sync::{watch, MutexGuard, Semaphore};
76use tower::ServiceBuilder;
77use tracing::{
78 field::{self},
79 info_span, instrument, Instrument,
80};
81
/// Reconnection timeout (30 seconds).
/// NOTE(review): not referenced in the visible part of this file — presumably the window a
/// dropped connection has to reconnect; confirm against its users.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// Delay that `Server::start` waits before cleaning up resources left by stale servers.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can
/// clean up old resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// NOTE(review): the three limits below are used outside the visible part of this file;
// their names suggest paging sizes and a maximum chat message length — confirm at the use sites.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
90
91type MessageHandler =
92 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
93
94struct Response<R> {
95 peer: Arc<Peer>,
96 receipt: Receipt<R>,
97 responded: Arc<AtomicBool>,
98}
99
100impl<R: RequestMessage> Response<R> {
101 fn send(self, payload: R::Response) -> Result<()> {
102 self.responded.store(true, SeqCst);
103 self.peer.respond(self.receipt, payload)?;
104 Ok(())
105 }
106}
107
/// The identity associated with a connection (stored on every `Session`).
#[derive(Clone, Debug)]
pub enum Principal {
    /// A regular user.
    User(User),
    /// `admin` acting as `user`; user-facing lookups (`user_id`, `email`) resolve to `user`.
    Impersonated { user: User, admin: User },
    /// A dev server, represented by its database model.
    DevServer(dev_server::Model),
}
114
115impl Principal {
116 fn update_span(&self, span: &tracing::Span) {
117 match &self {
118 Principal::User(user) => {
119 span.record("user_id", user.id.0);
120 span.record("login", &user.github_login);
121 }
122 Principal::Impersonated { user, admin } => {
123 span.record("user_id", user.id.0);
124 span.record("login", &user.github_login);
125 span.record("impersonator", &admin.github_login);
126 }
127 Principal::DevServer(dev_server) => {
128 span.record("dev_server_id", dev_server.id.0);
129 }
130 }
131 }
132}
133
/// Per-connection state handed (by clone) to every message handler.
#[derive(Clone)]
struct Session {
    /// Identity this connection acts as.
    principal: Principal,
    /// Id assigned to the connection when it was added to the peer.
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; obtain it through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    /// The server's peer, used to send messages and responses.
    peer: Arc<Peer>,
    /// The server-wide connection pool; obtain it through `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// Shared application state (configuration, database, executor, ...).
    app_state: Arc<AppState>,
    /// Present only when a Supermaven admin API key is configured.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    /// HTTP client created for this connection.
    http_client: Arc<dyn HttpClient>,
    /// The GeoIP country code for the user.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    /// Executor the connection was started with; currently only stored.
    _executor: Executor,
}
149
150impl Session {
151 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
152 #[cfg(test)]
153 tokio::task::yield_now().await;
154 let guard = self.db.lock().await;
155 #[cfg(test)]
156 tokio::task::yield_now().await;
157 guard
158 }
159
160 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
161 #[cfg(test)]
162 tokio::task::yield_now().await;
163 let guard = self.connection_pool.lock();
164 ConnectionPoolGuard {
165 guard,
166 _not_send: PhantomData,
167 }
168 }
169
170 fn for_user(self) -> Option<UserSession> {
171 UserSession::new(self)
172 }
173
174 fn for_dev_server(self) -> Option<DevServerSession> {
175 DevServerSession::new(self)
176 }
177
178 fn user_id(&self) -> Option<UserId> {
179 match &self.principal {
180 Principal::User(user) => Some(user.id),
181 Principal::Impersonated { user, .. } => Some(user.id),
182 Principal::DevServer(_) => None,
183 }
184 }
185
186 fn is_staff(&self) -> bool {
187 match &self.principal {
188 Principal::User(user) => user.admin,
189 Principal::Impersonated { .. } => true,
190 Principal::DevServer(_) => false,
191 }
192 }
193
194 pub async fn has_llm_subscription(
195 &self,
196 db: &MutexGuard<'_, DbHandle>,
197 ) -> anyhow::Result<bool> {
198 if self.is_staff() {
199 return Ok(true);
200 }
201
202 let Some(user_id) = self.user_id() else {
203 return Ok(false);
204 };
205
206 Ok(db.has_active_billing_subscription(user_id).await?)
207 }
208
209 pub async fn current_plan(
210 &self,
211 _db: &MutexGuard<'_, DbHandle>,
212 ) -> anyhow::Result<proto::Plan> {
213 if self.is_staff() {
214 Ok(proto::Plan::ZedPro)
215 } else {
216 Ok(proto::Plan::Free)
217 }
218 }
219
220 fn dev_server_id(&self) -> Option<DevServerId> {
221 match &self.principal {
222 Principal::User(_) | Principal::Impersonated { .. } => None,
223 Principal::DevServer(dev_server) => Some(dev_server.id),
224 }
225 }
226
227 fn principal_id(&self) -> PrincipalId {
228 match &self.principal {
229 Principal::User(user) => PrincipalId::UserId(user.id),
230 Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
231 Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
232 }
233 }
234}
235
236impl Debug for Session {
237 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
238 let mut result = f.debug_struct("Session");
239 match &self.principal {
240 Principal::User(user) => {
241 result.field("user", &user.github_login);
242 }
243 Principal::Impersonated { user, admin } => {
244 result.field("user", &user.github_login);
245 result.field("impersonator", &admin.github_login);
246 }
247 Principal::DevServer(dev_server) => {
248 result.field("dev_server", &dev_server.id);
249 }
250 }
251 result.field("connection_id", &self.connection_id).finish()
252 }
253}
254
255struct UserSession(Session);
256
257impl UserSession {
258 pub fn new(s: Session) -> Option<Self> {
259 s.user_id().map(|_| UserSession(s))
260 }
261 pub fn user_id(&self) -> UserId {
262 self.0.user_id().unwrap()
263 }
264
265 pub fn email(&self) -> Option<String> {
266 match &self.0.principal {
267 Principal::User(user) => user.email_address.clone(),
268 Principal::Impersonated { user, .. } => user.email_address.clone(),
269 Principal::DevServer(..) => None,
270 }
271 }
272}
273
274impl Deref for UserSession {
275 type Target = Session;
276
277 fn deref(&self) -> &Self::Target {
278 &self.0
279 }
280}
281impl DerefMut for UserSession {
282 fn deref_mut(&mut self) -> &mut Self::Target {
283 &mut self.0
284 }
285}
286
287struct DevServerSession(Session);
288
289impl DevServerSession {
290 pub fn new(s: Session) -> Option<Self> {
291 s.dev_server_id().map(|_| DevServerSession(s))
292 }
293 pub fn dev_server_id(&self) -> DevServerId {
294 self.0.dev_server_id().unwrap()
295 }
296
297 fn dev_server(&self) -> &dev_server::Model {
298 match &self.0.principal {
299 Principal::DevServer(dev_server) => dev_server,
300 _ => unreachable!(),
301 }
302 }
303}
304
305impl Deref for DevServerSession {
306 type Target = Session;
307
308 fn deref(&self) -> &Self::Target {
309 &self.0
310 }
311}
312impl DerefMut for DevServerSession {
313 fn deref_mut(&mut self) -> &mut Self::Target {
314 &mut self.0
315 }
316}
317
318fn user_handler<M: RequestMessage, Fut>(
319 handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
320) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
321where
322 Fut: Send + Future<Output = Result<()>>,
323{
324 let handler = Arc::new(handler);
325 move |message, response, session| {
326 let handler = handler.clone();
327 Box::pin(async move {
328 if let Some(user_session) = session.for_user() {
329 Ok(handler(message, response, user_session).await?)
330 } else {
331 Err(Error::Internal(anyhow!(
332 "must be a user to call {}",
333 M::NAME
334 )))
335 }
336 })
337 }
338}
339
340fn dev_server_handler<M: RequestMessage, Fut>(
341 handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
342) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
343where
344 Fut: Send + Future<Output = Result<()>>,
345{
346 let handler = Arc::new(handler);
347 move |message, response, session| {
348 let handler = handler.clone();
349 Box::pin(async move {
350 if let Some(dev_server_session) = session.for_dev_server() {
351 Ok(handler(message, response, dev_server_session).await?)
352 } else {
353 Err(Error::Internal(anyhow!(
354 "must be a dev server to call {}",
355 M::NAME
356 )))
357 }
358 })
359 }
360}
361
362fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
363 handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
364) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
365where
366 InnertRetFut: Send + Future<Output = Result<()>>,
367{
368 let handler = Arc::new(handler);
369 move |message, session| {
370 let handler = handler.clone();
371 Box::pin(async move {
372 if let Some(user_session) = session.for_user() {
373 Ok(handler(message, user_session).await?)
374 } else {
375 Err(Error::Internal(anyhow!(
376 "must be a user to call {}",
377 M::NAME
378 )))
379 }
380 })
381 }
382}
383
384struct DbHandle(Arc<Database>);
385
386impl Deref for DbHandle {
387 type Target = Database;
388
389 fn deref(&self) -> &Self::Target {
390 self.0.as_ref()
391 }
392}
393
/// The RPC server: owns the peer, the connection pool, and the table of message handlers.
pub struct Server {
    /// Current server id; behind a mutex because the test-only `reset` replaces it.
    id: parking_lot::Mutex<ServerId>,
    /// Peer that connections are added to and through which messages are sent.
    peer: Arc<Peer>,
    /// Pool of live connections, shared with every `Session`.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// Shared application state (configuration, database, executor, ...).
    app_state: Arc<AppState>,
    /// Registered handlers, keyed by the `TypeId` of the message type they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Teardown flag broadcast to connection tasks: `true` once `teardown` has run.
    teardown: watch::Sender<bool>,
}
402
/// A lock guard over the `ConnectionPool`, as returned by `Session::connection_pool`.
pub(crate) struct ConnectionPoolGuard<'a> {
    /// The underlying `parking_lot` guard.
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    /// `Rc<()>` is `!Send`, so this marker makes the whole guard `!Send`.
    /// NOTE(review): presumably so the guard cannot be held across an `.await` in a
    /// future that must be `Send` — confirm intent.
    _not_send: PhantomData<Rc<()>>,
}
407
/// Serializable, borrowed view of a server's peer and (locked) connection pool.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    /// The server's peer.
    peer: &'a Peer,
    /// Guard over the connection pool; serialized via `serialize_deref`, i.e. as the
    /// value the guard dereferences to.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
414
415pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
416where
417 S: Serializer,
418 T: Deref<Target = U>,
419 U: Serialize,
420{
421 Serialize::serialize(value.deref(), serializer)
422}
423
424impl Server {
    /// Creates a server with the given id and registers a handler for every message
    /// type it serves.
    ///
    /// Request handlers are registered through `add_request_handler` and one-way message
    /// handlers through `add_message_handler`. Handlers wrapped in `user_handler` /
    /// `user_message_handler` reject sessions whose principal is not a user; handlers
    /// wrapped in `dev_server_handler` reject sessions that are not dev servers.
    ///
    /// # Panics
    ///
    /// Panics if two handlers are registered for the same message type (see `add_handler`).
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            // NOTE(review): `as u32` truncates if the id's inner integer is wider than
            // 32 bits — confirm ids stay in range.
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            // Only the sender is kept; receivers are created on demand via `subscribe()`.
            teardown: watch::channel(false).0,
        };

        server
            .add_request_handler(ping)
            .add_request_handler(user_handler(create_room))
            .add_request_handler(user_handler(join_room))
            .add_request_handler(user_handler(rejoin_room))
            .add_request_handler(user_handler(leave_room))
            .add_request_handler(user_handler(set_room_participant_role))
            .add_request_handler(user_handler(call))
            .add_request_handler(user_handler(cancel_call))
            .add_message_handler(user_message_handler(decline_call))
            .add_request_handler(user_handler(update_participant_location))
            .add_request_handler(user_handler(share_project))
            .add_message_handler(unshare_project)
            .add_request_handler(user_handler(join_project))
            .add_request_handler(user_handler(join_hosted_project))
            .add_request_handler(user_handler(rejoin_dev_server_projects))
            .add_request_handler(user_handler(create_dev_server_project))
            .add_request_handler(user_handler(update_dev_server_project))
            .add_request_handler(user_handler(delete_dev_server_project))
            .add_request_handler(user_handler(create_dev_server))
            .add_request_handler(user_handler(regenerate_dev_server_token))
            .add_request_handler(user_handler(rename_dev_server))
            .add_request_handler(user_handler(delete_dev_server))
            .add_request_handler(user_handler(list_remote_directory))
            .add_request_handler(dev_server_handler(share_dev_server_project))
            .add_request_handler(dev_server_handler(shutdown_dev_server))
            .add_request_handler(dev_server_handler(reconnect_dev_server))
            .add_message_handler(user_message_handler(leave_project))
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Project requests routed through `forward_project_request_for_owner`.
            .add_request_handler(user_handler(
                forward_project_request_for_owner::<proto::TaskContextForLocation>,
            ))
            .add_request_handler(user_handler(
                forward_project_request_for_owner::<proto::TaskTemplates>,
            ))
            // Project requests routed through `forward_read_only_project_request`.
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetHover>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetTypeDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetReferences>,
            ))
            .add_request_handler(user_handler(forward_find_search_candidates_request))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDocumentHighlights>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetProjectSymbols>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferForSymbol>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferById>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::SynchronizeBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::InlayHints>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::ResolveInlayHint>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferByPath>,
            ))
            // Project requests routed through `forward_mutating_project_request`.
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCompletions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OpenNewBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCodeActions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCodeAction>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PrepareRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PerformRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ReloadBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::FormatBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CreateProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RenameProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CopyProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::DeleteProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ExpandProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OnTypeFormatting>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::SaveBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::BlameBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::MultiLspQuery>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RestartLanguageServers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::LinkedEditingRange>,
            ))
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            // One-way messages routed through `broadcast_project_message_from_host`.
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            .add_request_handler(get_users)
            .add_request_handler(user_handler(fuzzy_search_users))
            .add_request_handler(user_handler(request_contact))
            .add_request_handler(user_handler(remove_contact))
            .add_request_handler(user_handler(respond_to_contact_request))
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(user_handler(create_channel))
            .add_request_handler(user_handler(delete_channel))
            .add_request_handler(user_handler(invite_channel_member))
            .add_request_handler(user_handler(remove_channel_member))
            .add_request_handler(user_handler(set_channel_member_role))
            .add_request_handler(user_handler(set_channel_visibility))
            .add_request_handler(user_handler(rename_channel))
            .add_request_handler(user_handler(join_channel_buffer))
            .add_request_handler(user_handler(leave_channel_buffer))
            .add_message_handler(user_message_handler(update_channel_buffer))
            .add_request_handler(user_handler(rejoin_channel_buffers))
            .add_request_handler(user_handler(get_channel_members))
            .add_request_handler(user_handler(respond_to_channel_invite))
            .add_request_handler(user_handler(join_channel))
            .add_request_handler(user_handler(join_channel_chat))
            .add_message_handler(user_message_handler(leave_channel_chat))
            .add_request_handler(user_handler(send_channel_message))
            .add_request_handler(user_handler(remove_channel_message))
            .add_request_handler(user_handler(update_channel_message))
            .add_request_handler(user_handler(get_channel_messages))
            .add_request_handler(user_handler(get_channel_messages_by_id))
            .add_request_handler(user_handler(get_notifications))
            .add_request_handler(user_handler(mark_notification_as_read))
            .add_request_handler(user_handler(move_channel))
            .add_request_handler(user_handler(follow))
            .add_message_handler(user_message_handler(unfollow))
            .add_message_handler(user_message_handler(update_followers))
            .add_request_handler(user_handler(get_private_user_info))
            .add_request_handler(user_handler(get_llm_api_token))
            .add_request_handler(user_handler(accept_terms_of_service))
            .add_message_handler(user_message_handler(acknowledge_channel_message))
            .add_message_handler(user_message_handler(acknowledge_buffer_version))
            .add_request_handler(user_handler(get_supermaven_api_key))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OpenContext>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CreateContext>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::SynchronizeContexts>,
            ))
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context)
            // Not wrapped in `user_handler`: this handler receives the raw `Session`.
            // `app_state` is cloned per call so each future owns the state it borrows
            // the config from.
            .add_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    let app_state = app_state.clone();
                    async move {
                        count_language_model_tokens(request, response, session, &app_state.config)
                            .await
                    }
                }
            })
            .add_request_handler({
                user_handler(move |request, response, session| {
                    get_cached_embeddings(request, response, session)
                })
            })
            // The OpenAI API key is read from the config and cloned for every request.
            .add_request_handler({
                let app_state = app_state.clone();
                user_handler(move |request, response, session| {
                    compute_embeddings(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                    )
                })
            });

        Arc::new(server)
    }
661
    /// Spawns a detached background task that cleans up resources left behind by stale
    /// servers, then returns `Ok(())` immediately.
    ///
    /// The task first waits for `CLEANUP_TIMEOUT`. It then asks the database for stale
    /// room and channel ids and, for each of them:
    /// - clears stale channel-buffer collaborators and sends the refreshed collaborator
    ///   list to the connections the database reports;
    /// - clears stale room participants, sends room (and channel) updates, notifies users
    ///   whose incoming calls were canceled, pushes a contact update to the accepted
    ///   contacts of every affected user, and deletes the LiveKit room if the room has
    ///   no participants left.
    ///
    /// Finally it deletes the stale server records. Every fallible step goes through
    /// `trace_err()`, so a failure skips that step instead of aborting the task.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // The sleep future is created here, before the task is spawned.
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Channel buffers: drop stale collaborators and tell the listed
                    // connections about the refreshed collaborator list.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Defaults used when clearing the room fails: nobody is notified
                        // and no LiveKit room is deleted.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                // The pool guard is a temporary that lives only for this call.
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the pool guard is released before the next `.await`.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        for user_id in contacts_to_update {
                            // Both database calls complete before the pool is locked below.
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    // Only accepted contacts receive the update.
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Delete the LiveKit room only when the refreshed room is empty.
                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                // Runs even when retrieving the stale resource ids failed above.
                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
807
808 pub fn teardown(&self) {
809 self.peer.teardown();
810 self.connection_pool.lock().reset();
811 let _ = self.teardown.send(true);
812 }
813
814 #[cfg(test)]
815 pub fn reset(&self, id: ServerId) {
816 self.teardown();
817 *self.id.lock() = id;
818 self.peer.reset(id.0 as u32);
819 let _ = self.teardown.send(false);
820 }
821
822 #[cfg(test)]
823 pub fn id(&self) -> ServerId {
824 *self.id.lock()
825 }
826
827 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
828 where
829 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
830 Fut: 'static + Send + Future<Output = Result<()>>,
831 M: EnvelopedMessage,
832 {
833 let prev_handler = self.handlers.insert(
834 TypeId::of::<M>(),
835 Box::new(move |envelope, session| {
836 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
837 let received_at = envelope.received_at;
838 tracing::info!("message received");
839 let start_time = Instant::now();
840 let future = (handler)(*envelope, session);
841 async move {
842 let result = future.await;
843 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
844 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
845 let queue_duration_ms = total_duration_ms - processing_duration_ms;
846 let payload_type = M::NAME;
847
848 match result {
849 Err(error) => {
850 tracing::error!(
851 ?error,
852 total_duration_ms,
853 processing_duration_ms,
854 queue_duration_ms,
855 payload_type,
856 "error handling message"
857 )
858 }
859 Ok(()) => tracing::info!(
860 total_duration_ms,
861 processing_duration_ms,
862 queue_duration_ms,
863 "finished handling message"
864 ),
865 }
866 }
867 .boxed()
868 }),
869 );
870 if prev_handler.is_some() {
871 panic!("registered a handler for the same message twice");
872 }
873 self
874 }
875
876 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
877 where
878 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
879 Fut: 'static + Send + Future<Output = Result<()>>,
880 M: EnvelopedMessage,
881 {
882 self.add_handler(move |envelope, session| handler(envelope.payload, session));
883 self
884 }
885
886 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
887 where
888 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
889 Fut: Send + Future<Output = Result<()>>,
890 M: RequestMessage,
891 {
892 let handler = Arc::new(handler);
893 self.add_handler(move |envelope, session| {
894 let receipt = envelope.receipt();
895 let handler = handler.clone();
896 async move {
897 let peer = session.peer.clone();
898 let responded = Arc::new(AtomicBool::default());
899 let response = Response {
900 peer: peer.clone(),
901 responded: responded.clone(),
902 receipt,
903 };
904 match (handler)(envelope.payload, response, session).await {
905 Ok(()) => {
906 if responded.load(std::sync::atomic::Ordering::SeqCst) {
907 Ok(())
908 } else {
909 Err(anyhow!("handler did not send a response"))?
910 }
911 }
912 Err(error) => {
913 let proto_err = match &error {
914 Error::Internal(err) => err.to_proto(),
915 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
916 };
917 peer.respond_with_error(receipt, proto_err)?;
918 Err(error)
919 }
920 }
921 }
922 })
923 }
924
925 #[allow(clippy::too_many_arguments)]
926 pub fn handle_connection(
927 self: &Arc<Self>,
928 connection: Connection,
929 address: String,
930 principal: Principal,
931 zed_version: ZedVersion,
932 geoip_country_code: Option<String>,
933 send_connection_id: Option<oneshot::Sender<ConnectionId>>,
934 executor: Executor,
935 ) -> impl Future<Output = ()> {
936 let this = self.clone();
937 let span = info_span!("handle connection", %address,
938 connection_id=field::Empty,
939 user_id=field::Empty,
940 login=field::Empty,
941 impersonator=field::Empty,
942 dev_server_id=field::Empty,
943 geoip_country_code=field::Empty
944 );
945 principal.update_span(&span);
946 if let Some(country_code) = geoip_country_code.as_ref() {
947 span.record("geoip_country_code", country_code);
948 }
949
950 let mut teardown = self.teardown.subscribe();
951 async move {
952 if *teardown.borrow() {
953 tracing::error!("server is tearing down");
954 return
955 }
956 let (connection_id, handle_io, mut incoming_rx) = this
957 .peer
958 .add_connection(connection, {
959 let executor = executor.clone();
960 move |duration| executor.sleep(duration)
961 });
962 tracing::Span::current().record("connection_id", format!("{}", connection_id));
963
964 tracing::info!("connection opened");
965
966 let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
967 let http_client = match IsahcHttpClient::builder().default_header("User-Agent", user_agent).build() {
968 Ok(http_client) => Arc::new(IsahcHttpClient::from(http_client)),
969 Err(error) => {
970 tracing::error!(?error, "failed to create HTTP client");
971 return;
972 }
973 };
974
975 let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
976 supermaven_admin_api_key.to_string(),
977 http_client.clone(),
978 )));
979
980 let session = Session {
981 principal: principal.clone(),
982 connection_id,
983 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
984 peer: this.peer.clone(),
985 connection_pool: this.connection_pool.clone(),
986 app_state: this.app_state.clone(),
987 http_client,
988 geoip_country_code,
989 _executor: executor.clone(),
990 supermaven_client,
991 };
992
993 if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
994 tracing::error!(?error, "failed to send initial client update");
995 return;
996 }
997
998 let handle_io = handle_io.fuse();
999 futures::pin_mut!(handle_io);
1000
1001 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
1002 // This prevents deadlocks when e.g., client A performs a request to client B and
1003 // client B performs a request to client A. If both clients stop processing further
1004 // messages until their respective request completes, they won't have a chance to
1005 // respond to the other client's request and cause a deadlock.
1006 //
1007 // This arrangement ensures we will attempt to process earlier messages first, but fall
1008 // back to processing messages arrived later in the spirit of making progress.
1009 let mut foreground_message_handlers = FuturesUnordered::new();
1010 let concurrent_handlers = Arc::new(Semaphore::new(256));
1011 loop {
1012 let next_message = async {
1013 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
1014 let message = incoming_rx.next().await;
1015 (permit, message)
1016 }.fuse();
1017 futures::pin_mut!(next_message);
1018 futures::select_biased! {
1019 _ = teardown.changed().fuse() => return,
1020 result = handle_io => {
1021 if let Err(error) = result {
1022 tracing::error!(?error, "error handling I/O");
1023 }
1024 break;
1025 }
1026 _ = foreground_message_handlers.next() => {}
1027 next_message = next_message => {
1028 let (permit, message) = next_message;
1029 if let Some(message) = message {
1030 let type_name = message.payload_type_name();
1031 // note: we copy all the fields from the parent span so we can query them in the logs.
1032 // (https://github.com/tokio-rs/tracing/issues/2670).
1033 let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
1034 user_id=field::Empty,
1035 login=field::Empty,
1036 impersonator=field::Empty,
1037 dev_server_id=field::Empty
1038 );
1039 principal.update_span(&span);
1040 let span_enter = span.enter();
1041 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
1042 let is_background = message.is_background();
1043 let handle_message = (handler)(message, session.clone());
1044 drop(span_enter);
1045
1046 let handle_message = async move {
1047 handle_message.await;
1048 drop(permit);
1049 }.instrument(span);
1050 if is_background {
1051 executor.spawn_detached(handle_message);
1052 } else {
1053 foreground_message_handlers.push(handle_message);
1054 }
1055 } else {
1056 tracing::error!("no message handler");
1057 }
1058 } else {
1059 tracing::info!("connection closed");
1060 break;
1061 }
1062 }
1063 }
1064 }
1065
1066 drop(foreground_message_handlers);
1067 tracing::info!("signing out");
1068 if let Err(error) = connection_lost(session, teardown, executor).await {
1069 tracing::error!(?error, "error signing out");
1070 }
1071
1072 }.instrument(span)
1073 }
1074
    /// Sends the initial batch of messages to a newly connected client and
    /// registers the connection in the connection pool.
    ///
    /// Every principal first receives a `proto::Hello` containing its peer id;
    /// afterwards `send_connection_id`, when provided, is handed the connection id.
    /// What follows depends on the principal:
    ///
    /// * `Principal::User` / `Principal::Impersonated`: optionally `ShowContacts`
    ///   (first connection only), a plan update, the initial contacts update,
    ///   channel subscriptions (when the client version calls for it), the dev
    ///   server projects update, any pending incoming call, and finally a contacts
    ///   update.
    /// * `Principal::DevServer`: any connection already registered for the same dev
    ///   server is told to shut down and removed from the pool, then the new
    ///   connection receives its `DevServerInstructions` and the owning user
    ///   receives a dev server projects update.
    ///
    /// Returns an error as soon as one of the fallible sends or database calls
    /// fails.
    async fn send_initial_client_update(
        &self,
        connection_id: ConnectionId,
        principal: &Principal,
        zed_version: ZedVersion,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        session: &Session,
    ) -> Result<()> {
        self.peer.send(
            connection_id,
            proto::Hello {
                peer_id: Some(connection_id.into()),
            },
        )?;
        tracing::info!("sent hello message");
        // The result of notifying the listener is deliberately ignored.
        if let Some(send_connection_id) = send_connection_id.take() {
            let _ = send_connection_id.send(connection_id);
        }

        match principal {
            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
                // On the very first connection, ask the client to show the contacts
                // panel and remember that the user has connected once.
                if !user.connected_once {
                    self.peer.send(connection_id, proto::ShowContacts {})?;
                    self.app_state
                        .db
                        .set_user_connected_once(user.id, true)
                        .await?;
                }

                update_user_plan(user.id, session).await?;

                // Both queries are awaited together; either failure aborts.
                let (contacts, dev_server_projects) = future::try_join(
                    self.app_state.db.get_contacts(user.id),
                    self.app_state.db.dev_server_projects_update(user.id),
                )
                .await?;

                // The pool guard is confined to this block so it is dropped before
                // the next `.await`.
                {
                    let mut pool = self.connection_pool.lock();
                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
                    self.peer.send(
                        connection_id,
                        build_initial_contacts_update(contacts, &pool),
                    )?;
                }

                if should_auto_subscribe_to_channels(zed_version) {
                    subscribe_user_to_channels(user.id, session).await?;
                }

                send_dev_server_projects_update(user.id, dev_server_projects, session).await;

                // Forward a call that is already pending for this user, if any.
                if let Some(incoming_call) =
                    self.app_state.db.incoming_call_for_user(user.id).await?
                {
                    self.peer.send(connection_id, incoming_call)?;
                }

                update_user_contacts(user.id, session).await?;
            }
            Principal::DevServer(dev_server) => {
                // The pool guard is confined to this block so it is dropped before
                // the database calls below are awaited.
                {
                    let mut pool = self.connection_pool.lock();
                    // Only one connection per dev server is kept: an existing one
                    // is asked to shut down and is removed from the pool.
                    if let Some(stale_connection_id) = pool.dev_server_connection_id(dev_server.id)
                    {
                        self.peer.send(
                            stale_connection_id,
                            proto::ShutdownDevServer {
                                reason: Some(
                                    "another dev server connected with the same token".to_string(),
                                ),
                            },
                        )?;
                        pool.remove_connection(stale_connection_id)?;
                    };
                    pool.add_dev_server(connection_id, dev_server.id, zed_version);
                }

                let projects = self
                    .app_state
                    .db
                    .get_projects_for_dev_server(dev_server.id)
                    .await?;
                self.peer
                    .send(connection_id, proto::DevServerInstructions { projects })?;

                // Let the owning user know about the current dev server projects.
                let status = self
                    .app_state
                    .db
                    .dev_server_projects_update(dev_server.user_id)
                    .await?;
                send_dev_server_projects_update(dev_server.user_id, status, session).await;
            }
        }

        Ok(())
    }
1172
1173 pub async fn invite_code_redeemed(
1174 self: &Arc<Self>,
1175 inviter_id: UserId,
1176 invitee_id: UserId,
1177 ) -> Result<()> {
1178 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1179 if let Some(code) = &user.invite_code {
1180 let pool = self.connection_pool.lock();
1181 let invitee_contact = contact_for_user(invitee_id, false, &pool);
1182 for connection_id in pool.user_connection_ids(inviter_id) {
1183 self.peer.send(
1184 connection_id,
1185 proto::UpdateContacts {
1186 contacts: vec![invitee_contact.clone()],
1187 ..Default::default()
1188 },
1189 )?;
1190 self.peer.send(
1191 connection_id,
1192 proto::UpdateInviteInfo {
1193 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1194 count: user.invite_count as u32,
1195 },
1196 )?;
1197 }
1198 }
1199 }
1200 Ok(())
1201 }
1202
1203 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1204 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1205 if let Some(invite_code) = &user.invite_code {
1206 let pool = self.connection_pool.lock();
1207 for connection_id in pool.user_connection_ids(user_id) {
1208 self.peer.send(
1209 connection_id,
1210 proto::UpdateInviteInfo {
1211 url: format!(
1212 "{}{}",
1213 self.app_state.config.invite_link_prefix, invite_code
1214 ),
1215 count: user.invite_count as u32,
1216 },
1217 )?;
1218 }
1219 }
1220 }
1221 Ok(())
1222 }
1223
1224 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1225 ServerSnapshot {
1226 connection_pool: ConnectionPoolGuard {
1227 guard: self.connection_pool.lock(),
1228 _not_send: PhantomData,
1229 },
1230 peer: &self.peer,
1231 }
1232 }
1233}
1234
1235impl<'a> Deref for ConnectionPoolGuard<'a> {
1236 type Target = ConnectionPool;
1237
1238 fn deref(&self) -> &Self::Target {
1239 &self.guard
1240 }
1241}
1242
1243impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1244 fn deref_mut(&mut self) -> &mut Self::Target {
1245 &mut self.guard
1246 }
1247}
1248
/// In test builds, checks the connection pool's invariants whenever a guard is
/// dropped; in other builds the body compiles to nothing.
impl<'a> Drop for ConnectionPoolGuard<'a> {
    fn drop(&mut self) {
        // `check_invariants` is reached through `Deref`/`DerefMut` on the guard.
        #[cfg(test)]
        self.check_invariants();
    }
}
1255
1256fn broadcast<F>(
1257 sender_id: Option<ConnectionId>,
1258 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1259 mut f: F,
1260) where
1261 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1262{
1263 for receiver_id in receiver_ids {
1264 if Some(receiver_id) != sender_id {
1265 if let Err(error) = f(receiver_id) {
1266 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1267 }
1268 }
1269 }
1270}
1271
1272pub struct ProtocolVersion(u32);
1273
1274impl Header for ProtocolVersion {
1275 fn name() -> &'static HeaderName {
1276 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1277 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1278 }
1279
1280 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1281 where
1282 Self: Sized,
1283 I: Iterator<Item = &'i axum::http::HeaderValue>,
1284 {
1285 let version = values
1286 .next()
1287 .ok_or_else(axum::headers::Error::invalid)?
1288 .to_str()
1289 .map_err(|_| axum::headers::Error::invalid())?
1290 .parse()
1291 .map_err(|_| axum::headers::Error::invalid())?;
1292 Ok(Self(version))
1293 }
1294
1295 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1296 values.extend([self.0.to_string().parse().unwrap()]);
1297 }
1298}
1299
1300pub struct AppVersionHeader(SemanticVersion);
1301impl Header for AppVersionHeader {
1302 fn name() -> &'static HeaderName {
1303 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1304 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1305 }
1306
1307 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1308 where
1309 Self: Sized,
1310 I: Iterator<Item = &'i axum::http::HeaderValue>,
1311 {
1312 let version = values
1313 .next()
1314 .ok_or_else(axum::headers::Error::invalid)?
1315 .to_str()
1316 .map_err(|_| axum::headers::Error::invalid())?
1317 .parse()
1318 .map_err(|_| axum::headers::Error::invalid())?;
1319 Ok(Self(version))
1320 }
1321
1322 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1323 values.extend([self.0.to_string().parse().unwrap()]);
1324 }
1325}
1326
1327pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1328 Router::new()
1329 .route("/rpc", get(handle_websocket_request))
1330 .layer(
1331 ServiceBuilder::new()
1332 .layer(Extension(server.app_state.clone()))
1333 .layer(middleware::from_fn(auth::validate_header)),
1334 )
1335 .route("/metrics", get(handle_metrics))
1336 .layer(Extension(server))
1337}
1338
1339pub async fn handle_websocket_request(
1340 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1341 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1342 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1343 Extension(server): Extension<Arc<Server>>,
1344 Extension(principal): Extension<Principal>,
1345 country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1346 ws: WebSocketUpgrade,
1347) -> axum::response::Response {
1348 if protocol_version != rpc::PROTOCOL_VERSION {
1349 return (
1350 StatusCode::UPGRADE_REQUIRED,
1351 "client must be upgraded".to_string(),
1352 )
1353 .into_response();
1354 }
1355
1356 let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1357 return (
1358 StatusCode::UPGRADE_REQUIRED,
1359 "no version header found".to_string(),
1360 )
1361 .into_response();
1362 };
1363
1364 if !version.can_collaborate() {
1365 return (
1366 StatusCode::UPGRADE_REQUIRED,
1367 "client must be upgraded".to_string(),
1368 )
1369 .into_response();
1370 }
1371
1372 let socket_address = socket_address.to_string();
1373 ws.on_upgrade(move |socket| {
1374 let socket = socket
1375 .map_ok(to_tungstenite_message)
1376 .err_into()
1377 .with(|message| async move { to_axum_message(message) });
1378 let connection = Connection::new(Box::pin(socket));
1379 async move {
1380 server
1381 .handle_connection(
1382 connection,
1383 socket_address,
1384 principal,
1385 version,
1386 country_code_header.map(|header| header.to_string()),
1387 None,
1388 Executor::Production,
1389 )
1390 .await;
1391 }
1392 })
1393}
1394
1395pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1396 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1397 let connections_metric = CONNECTIONS_METRIC
1398 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1399
1400 let connections = server
1401 .connection_pool
1402 .lock()
1403 .connections()
1404 .filter(|connection| !connection.admin)
1405 .count();
1406 connections_metric.set(connections as _);
1407
1408 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1409 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1410 register_int_gauge!(
1411 "shared_projects",
1412 "number of open projects with one or more guests"
1413 )
1414 .unwrap()
1415 });
1416
1417 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1418 shared_projects_metric.set(shared_projects as _);
1419
1420 let encoder = prometheus::TextEncoder::new();
1421 let metric_families = prometheus::gather();
1422 let encoded_metrics = encoder
1423 .encode_to_string(&metric_families)
1424 .map_err(|err| anyhow!("{}", err))?;
1425 Ok(encoded_metrics)
1426}
1427
/// Cleans up after a connection has gone away.
///
/// The connection is disconnected from the peer and removed from the connection
/// pool right away, and the loss is recorded in the database (a failure of that
/// call is passed to `trace_err` rather than propagated). The remaining cleanup
/// only runs if `RECONNECT_TIMEOUT` elapses before `teardown` changes:
///
/// * for users: leave the room and channel buffers, decline a pending call when
///   the user has no other connection online, and refresh the user's contacts;
/// * for dev servers: delegate to `lost_dev_server_connection`.
///
/// If `teardown` changes first, that cleanup is skipped.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls its branches in order, so the timeout branch is
    // checked before the teardown branch.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            match &session.principal {
                Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
                    // NOTE(review): `for_user()` is unwrapped on the assumption that
                    // it succeeds for user principals — confirm against its definition.
                    let session = session.for_user().unwrap();

                    log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
                    leave_room_for_session(&session, session.connection_id).await.trace_err();
                    leave_channel_buffers_for_session(&session)
                        .await
                        .trace_err();

                    // Only decline a pending call when no other connection of the
                    // user is still online.
                    if !session
                        .connection_pool()
                        .await
                        .is_user_online(session.user_id())
                    {
                        let db = session.db().await;
                        if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                            room_updated(&room, &session.peer);
                        }
                    }

                    update_user_contacts(session.user_id(), &session).await?;
                },
                Principal::DevServer(_) => {
                    lost_dev_server_connection(&session.for_dev_server().unwrap()).await?;
                },
            }
        },
        // Teardown changed before the timeout elapsed: skip the cleanup above.
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1482
1483/// Acknowledges a ping from a client, used to keep the connection alive.
1484async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1485 response.send(proto::Ack {})?;
1486 Ok(())
1487}
1488
1489/// Creates a new room for calling (outside of channels)
1490async fn create_room(
1491 _request: proto::CreateRoom,
1492 response: Response<proto::CreateRoom>,
1493 session: UserSession,
1494) -> Result<()> {
1495 let live_kit_room = nanoid::nanoid!(30);
1496
1497 let live_kit_connection_info = util::maybe!(async {
1498 let live_kit = session.app_state.live_kit_client.as_ref();
1499 let live_kit = live_kit?;
1500 let user_id = session.user_id().to_string();
1501
1502 let token = live_kit
1503 .room_token(&live_kit_room, &user_id.to_string())
1504 .trace_err()?;
1505
1506 Some(proto::LiveKitConnectionInfo {
1507 server_url: live_kit.url().into(),
1508 token,
1509 can_publish: true,
1510 })
1511 })
1512 .await;
1513
1514 let room = session
1515 .db()
1516 .await
1517 .create_room(session.user_id(), session.connection_id, &live_kit_room)
1518 .await?;
1519
1520 response.send(proto::CreateRoomResponse {
1521 room: Some(room.clone()),
1522 live_kit_connection_info,
1523 })?;
1524
1525 update_user_contacts(session.user_id(), &session).await?;
1526 Ok(())
1527}
1528
1529/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1530async fn join_room(
1531 request: proto::JoinRoom,
1532 response: Response<proto::JoinRoom>,
1533 session: UserSession,
1534) -> Result<()> {
1535 let room_id = RoomId::from_proto(request.id);
1536
1537 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1538
1539 if let Some(channel_id) = channel_id {
1540 return join_channel_internal(channel_id, Box::new(response), session).await;
1541 }
1542
1543 let joined_room = {
1544 let room = session
1545 .db()
1546 .await
1547 .join_room(room_id, session.user_id(), session.connection_id)
1548 .await?;
1549 room_updated(&room.room, &session.peer);
1550 room.into_inner()
1551 };
1552
1553 for connection_id in session
1554 .connection_pool()
1555 .await
1556 .user_connection_ids(session.user_id())
1557 {
1558 session
1559 .peer
1560 .send(
1561 connection_id,
1562 proto::CallCanceled {
1563 room_id: room_id.to_proto(),
1564 },
1565 )
1566 .trace_err();
1567 }
1568
1569 let live_kit_connection_info =
1570 if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
1571 live_kit
1572 .room_token(
1573 &joined_room.room.live_kit_room,
1574 &session.user_id().to_string(),
1575 )
1576 .trace_err()
1577 .map(|token| proto::LiveKitConnectionInfo {
1578 server_url: live_kit.url().into(),
1579 token,
1580 can_publish: true,
1581 })
1582 } else {
1583 None
1584 };
1585
1586 response.send(proto::JoinRoomResponse {
1587 room: Some(joined_room.room),
1588 channel_id: None,
1589 live_kit_connection_info,
1590 })?;
1591
1592 update_user_contacts(session.user_id(), &session).await?;
1593 Ok(())
1594}
1595
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// The database is asked to rejoin the room for this connection; the response
/// lists the room together with the projects that were reshared and rejoined.
/// Collaborators of reshared projects are told about the new peer id and
/// receive the project's worktree metadata, rejoined projects are replayed via
/// `notify_rejoined_projects`, the room's channel (if any) is notified, and the
/// user's contacts are refreshed.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: UserSession,
) -> Result<()> {
    // Declared outside the block so the values outlive `rejoined_room`, which is
    // confined to the block below.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell each collaborator that the host's peer id changed from the old
            // connection to the current one. Send failures are passed to
            // `trace_err` and do not abort the loop.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the project's worktree metadata to every collaborator other
            // than this connection.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1687
1688fn notify_rejoined_projects(
1689 rejoined_projects: &mut Vec<RejoinedProject>,
1690 session: &UserSession,
1691) -> Result<()> {
1692 for project in rejoined_projects.iter() {
1693 for collaborator in &project.collaborators {
1694 session
1695 .peer
1696 .send(
1697 collaborator.connection_id,
1698 proto::UpdateProjectCollaborator {
1699 project_id: project.id.to_proto(),
1700 old_peer_id: Some(project.old_connection_id.into()),
1701 new_peer_id: Some(session.connection_id.into()),
1702 },
1703 )
1704 .trace_err();
1705 }
1706 }
1707
1708 for project in rejoined_projects {
1709 for worktree in mem::take(&mut project.worktrees) {
1710 #[cfg(any(test, feature = "test-support"))]
1711 const MAX_CHUNK_SIZE: usize = 2;
1712 #[cfg(not(any(test, feature = "test-support")))]
1713 const MAX_CHUNK_SIZE: usize = 256;
1714
1715 // Stream this worktree's entries.
1716 let message = proto::UpdateWorktree {
1717 project_id: project.id.to_proto(),
1718 worktree_id: worktree.id,
1719 abs_path: worktree.abs_path.clone(),
1720 root_name: worktree.root_name,
1721 updated_entries: worktree.updated_entries,
1722 removed_entries: worktree.removed_entries,
1723 scan_id: worktree.scan_id,
1724 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1725 updated_repositories: worktree.updated_repositories,
1726 removed_repositories: worktree.removed_repositories,
1727 };
1728 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1729 session.peer.send(session.connection_id, update.clone())?;
1730 }
1731
1732 // Stream this worktree's diagnostics.
1733 for summary in worktree.diagnostic_summaries {
1734 session.peer.send(
1735 session.connection_id,
1736 proto::UpdateDiagnosticSummary {
1737 project_id: project.id.to_proto(),
1738 worktree_id: worktree.id,
1739 summary: Some(summary),
1740 },
1741 )?;
1742 }
1743
1744 for settings_file in worktree.settings_files {
1745 session.peer.send(
1746 session.connection_id,
1747 proto::UpdateWorktreeSettings {
1748 project_id: project.id.to_proto(),
1749 worktree_id: worktree.id,
1750 path: settings_file.path,
1751 content: Some(settings_file.content),
1752 kind: Some(settings_file.kind.to_proto().into()),
1753 },
1754 )?;
1755 }
1756 }
1757
1758 for language_server in &project.language_servers {
1759 session.peer.send(
1760 session.connection_id,
1761 proto::UpdateLanguageServer {
1762 project_id: project.id.to_proto(),
1763 language_server_id: language_server.id,
1764 variant: Some(
1765 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1766 proto::LspDiskBasedDiagnosticsUpdated {},
1767 ),
1768 ),
1769 },
1770 )?;
1771 }
1772 }
1773 Ok(())
1774}
1775
1776/// leave room disconnects from the room.
1777async fn leave_room(
1778 _: proto::LeaveRoom,
1779 response: Response<proto::LeaveRoom>,
1780 session: UserSession,
1781) -> Result<()> {
1782 leave_room_for_session(&session, session.connection_id).await?;
1783 response.send(proto::Ack {})?;
1784 Ok(())
1785}
1786
1787/// Updates the permissions of someone else in the room.
1788async fn set_room_participant_role(
1789 request: proto::SetRoomParticipantRole,
1790 response: Response<proto::SetRoomParticipantRole>,
1791 session: UserSession,
1792) -> Result<()> {
1793 let user_id = UserId::from_proto(request.user_id);
1794 let role = ChannelRole::from(request.role());
1795
1796 let (live_kit_room, can_publish) = {
1797 let room = session
1798 .db()
1799 .await
1800 .set_room_participant_role(
1801 session.user_id(),
1802 RoomId::from_proto(request.room_id),
1803 user_id,
1804 role,
1805 )
1806 .await?;
1807
1808 let live_kit_room = room.live_kit_room.clone();
1809 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1810 room_updated(&room, &session.peer);
1811 (live_kit_room, can_publish)
1812 };
1813
1814 if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
1815 live_kit
1816 .update_participant(
1817 live_kit_room.clone(),
1818 request.user_id.to_string(),
1819 live_kit_server::proto::ParticipantPermission {
1820 can_subscribe: true,
1821 can_publish,
1822 can_publish_data: can_publish,
1823 hidden: false,
1824 recorder: false,
1825 },
1826 )
1827 .await
1828 .trace_err();
1829 }
1830
1831 response.send(proto::Ack {})?;
1832 Ok(())
1833}
1834
1835/// Call someone else into the current room
1836async fn call(
1837 request: proto::Call,
1838 response: Response<proto::Call>,
1839 session: UserSession,
1840) -> Result<()> {
1841 let room_id = RoomId::from_proto(request.room_id);
1842 let calling_user_id = session.user_id();
1843 let calling_connection_id = session.connection_id;
1844 let called_user_id = UserId::from_proto(request.called_user_id);
1845 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1846 if !session
1847 .db()
1848 .await
1849 .has_contact(calling_user_id, called_user_id)
1850 .await?
1851 {
1852 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1853 }
1854
1855 let incoming_call = {
1856 let (room, incoming_call) = &mut *session
1857 .db()
1858 .await
1859 .call(
1860 room_id,
1861 calling_user_id,
1862 calling_connection_id,
1863 called_user_id,
1864 initial_project_id,
1865 )
1866 .await?;
1867 room_updated(room, &session.peer);
1868 mem::take(incoming_call)
1869 };
1870 update_user_contacts(called_user_id, &session).await?;
1871
1872 let mut calls = session
1873 .connection_pool()
1874 .await
1875 .user_connection_ids(called_user_id)
1876 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1877 .collect::<FuturesUnordered<_>>();
1878
1879 while let Some(call_response) = calls.next().await {
1880 match call_response.as_ref() {
1881 Ok(_) => {
1882 response.send(proto::Ack {})?;
1883 return Ok(());
1884 }
1885 Err(_) => {
1886 call_response.trace_err();
1887 }
1888 }
1889 }
1890
1891 {
1892 let room = session
1893 .db()
1894 .await
1895 .call_failed(room_id, called_user_id)
1896 .await?;
1897 room_updated(&room, &session.peer);
1898 }
1899 update_user_contacts(called_user_id, &session).await?;
1900
1901 Err(anyhow!("failed to ring user"))?
1902}
1903
1904/// Cancel an outgoing call.
1905async fn cancel_call(
1906 request: proto::CancelCall,
1907 response: Response<proto::CancelCall>,
1908 session: UserSession,
1909) -> Result<()> {
1910 let called_user_id = UserId::from_proto(request.called_user_id);
1911 let room_id = RoomId::from_proto(request.room_id);
1912 {
1913 let room = session
1914 .db()
1915 .await
1916 .cancel_call(room_id, session.connection_id, called_user_id)
1917 .await?;
1918 room_updated(&room, &session.peer);
1919 }
1920
1921 for connection_id in session
1922 .connection_pool()
1923 .await
1924 .user_connection_ids(called_user_id)
1925 {
1926 session
1927 .peer
1928 .send(
1929 connection_id,
1930 proto::CallCanceled {
1931 room_id: room_id.to_proto(),
1932 },
1933 )
1934 .trace_err();
1935 }
1936 response.send(proto::Ack {})?;
1937
1938 update_user_contacts(called_user_id, &session).await?;
1939 Ok(())
1940}
1941
1942/// Decline an incoming call.
1943async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1944 let room_id = RoomId::from_proto(message.room_id);
1945 {
1946 let room = session
1947 .db()
1948 .await
1949 .decline_call(Some(room_id), session.user_id())
1950 .await?
1951 .ok_or_else(|| anyhow!("failed to decline call"))?;
1952 room_updated(&room, &session.peer);
1953 }
1954
1955 for connection_id in session
1956 .connection_pool()
1957 .await
1958 .user_connection_ids(session.user_id())
1959 {
1960 session
1961 .peer
1962 .send(
1963 connection_id,
1964 proto::CallCanceled {
1965 room_id: room_id.to_proto(),
1966 },
1967 )
1968 .trace_err();
1969 }
1970 update_user_contacts(session.user_id(), &session).await?;
1971 Ok(())
1972}
1973
1974/// Updates other participants in the room with your current location.
1975async fn update_participant_location(
1976 request: proto::UpdateParticipantLocation,
1977 response: Response<proto::UpdateParticipantLocation>,
1978 session: UserSession,
1979) -> Result<()> {
1980 let room_id = RoomId::from_proto(request.room_id);
1981 let location = request
1982 .location
1983 .ok_or_else(|| anyhow!("invalid location"))?;
1984
1985 let db = session.db().await;
1986 let room = db
1987 .update_room_participant_location(room_id, session.connection_id, location)
1988 .await?;
1989
1990 room_updated(&room, &session.peer);
1991 response.send(proto::Ack {})?;
1992 Ok(())
1993}
1994
1995/// Share a project into the room.
1996async fn share_project(
1997 request: proto::ShareProject,
1998 response: Response<proto::ShareProject>,
1999 session: UserSession,
2000) -> Result<()> {
2001 let (project_id, room) = &*session
2002 .db()
2003 .await
2004 .share_project(
2005 RoomId::from_proto(request.room_id),
2006 session.connection_id,
2007 &request.worktrees,
2008 request.is_ssh_project,
2009 request
2010 .dev_server_project_id
2011 .map(DevServerProjectId::from_proto),
2012 )
2013 .await?;
2014 response.send(proto::ShareProjectResponse {
2015 project_id: project_id.to_proto(),
2016 })?;
2017 room_updated(room, &session.peer);
2018
2019 Ok(())
2020}
2021
2022/// Unshare a project from the room.
2023async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
2024 let project_id = ProjectId::from_proto(message.project_id);
2025 unshare_project_internal(
2026 project_id,
2027 session.connection_id,
2028 session.user_id(),
2029 &session,
2030 )
2031 .await
2032}
2033
2034async fn unshare_project_internal(
2035 project_id: ProjectId,
2036 connection_id: ConnectionId,
2037 user_id: Option<UserId>,
2038 session: &Session,
2039) -> Result<()> {
2040 let delete = {
2041 let room_guard = session
2042 .db()
2043 .await
2044 .unshare_project(project_id, connection_id, user_id)
2045 .await?;
2046
2047 let (delete, room, guest_connection_ids) = &*room_guard;
2048
2049 let message = proto::UnshareProject {
2050 project_id: project_id.to_proto(),
2051 };
2052
2053 broadcast(
2054 Some(connection_id),
2055 guest_connection_ids.iter().copied(),
2056 |conn_id| session.peer.send(conn_id, message.clone()),
2057 );
2058 if let Some(room) = room {
2059 room_updated(room, &session.peer);
2060 }
2061
2062 *delete
2063 };
2064
2065 if delete {
2066 let db = session.db().await;
2067 db.delete_project(project_id).await?;
2068 }
2069
2070 Ok(())
2071}
2072
2073/// DevServer makes a project available online
2074async fn share_dev_server_project(
2075 request: proto::ShareDevServerProject,
2076 response: Response<proto::ShareDevServerProject>,
2077 session: DevServerSession,
2078) -> Result<()> {
2079 let (dev_server_project, user_id, status) = session
2080 .db()
2081 .await
2082 .share_dev_server_project(
2083 DevServerProjectId::from_proto(request.dev_server_project_id),
2084 session.dev_server_id(),
2085 session.connection_id,
2086 &request.worktrees,
2087 )
2088 .await?;
2089 let Some(project_id) = dev_server_project.project_id else {
2090 return Err(anyhow!("failed to share remote project"))?;
2091 };
2092
2093 send_dev_server_projects_update(user_id, status, &session).await;
2094
2095 response.send(proto::ShareProjectResponse { project_id })?;
2096
2097 Ok(())
2098}
2099
/// Join someone else's shared project.
///
/// Registers the current connection and user on the project in the database
/// and hands the resulting project and replica id to `join_project_internal`,
/// which produces the response.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: UserSession,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    // `project` and `replica_id` borrow from the value returned by the database
    // call, which stays alive until the end of this function.
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    // Drop the database handle before `join_project_internal` runs.
    drop(db);
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
2118
/// Abstraction over the response handles that can answer with a
/// `proto::JoinProjectResponse`, so `join_project_internal` can accept either a
/// `JoinProject` or a `JoinHostedProject` response.
trait JoinProjectInternalResponse {
    /// Consumes the response handle and sends `result` to the requester.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // NOTE(review): the fully qualified path presumably selects the inherent
        // `Response::send` rather than this trait method — confirm.
        Response::<proto::JoinProject>::send(self, result)
    }
}
impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // NOTE(review): same fully qualified call as above, for the hosted-project
        // request type.
        Response::<proto::JoinHostedProject>::send(self, result)
    }
}
2132
2133fn join_project_internal(
2134 response: impl JoinProjectInternalResponse,
2135 session: UserSession,
2136 project: &mut Project,
2137 replica_id: &ReplicaId,
2138) -> Result<()> {
2139 let collaborators = project
2140 .collaborators
2141 .iter()
2142 .filter(|collaborator| collaborator.connection_id != session.connection_id)
2143 .map(|collaborator| collaborator.to_proto())
2144 .collect::<Vec<_>>();
2145 let project_id = project.id;
2146 let guest_user_id = session.user_id();
2147
2148 let worktrees = project
2149 .worktrees
2150 .iter()
2151 .map(|(id, worktree)| proto::WorktreeMetadata {
2152 id: *id,
2153 root_name: worktree.root_name.clone(),
2154 visible: worktree.visible,
2155 abs_path: worktree.abs_path.clone(),
2156 })
2157 .collect::<Vec<_>>();
2158
2159 let add_project_collaborator = proto::AddProjectCollaborator {
2160 project_id: project_id.to_proto(),
2161 collaborator: Some(proto::Collaborator {
2162 peer_id: Some(session.connection_id.into()),
2163 replica_id: replica_id.0 as u32,
2164 user_id: guest_user_id.to_proto(),
2165 }),
2166 };
2167
2168 for collaborator in &collaborators {
2169 session
2170 .peer
2171 .send(
2172 collaborator.peer_id.unwrap().into(),
2173 add_project_collaborator.clone(),
2174 )
2175 .trace_err();
2176 }
2177
2178 // First, we send the metadata associated with each worktree.
2179 response.send(proto::JoinProjectResponse {
2180 project_id: project.id.0 as u64,
2181 worktrees: worktrees.clone(),
2182 replica_id: replica_id.0 as u32,
2183 collaborators: collaborators.clone(),
2184 language_servers: project.language_servers.clone(),
2185 role: project.role.into(),
2186 dev_server_project_id: project
2187 .dev_server_project_id
2188 .map(|dev_server_project_id| dev_server_project_id.0 as u64),
2189 })?;
2190
2191 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2192 #[cfg(any(test, feature = "test-support"))]
2193 const MAX_CHUNK_SIZE: usize = 2;
2194 #[cfg(not(any(test, feature = "test-support")))]
2195 const MAX_CHUNK_SIZE: usize = 256;
2196
2197 // Stream this worktree's entries.
2198 let message = proto::UpdateWorktree {
2199 project_id: project_id.to_proto(),
2200 worktree_id,
2201 abs_path: worktree.abs_path.clone(),
2202 root_name: worktree.root_name,
2203 updated_entries: worktree.entries,
2204 removed_entries: Default::default(),
2205 scan_id: worktree.scan_id,
2206 is_last_update: worktree.scan_id == worktree.completed_scan_id,
2207 updated_repositories: worktree.repository_entries.into_values().collect(),
2208 removed_repositories: Default::default(),
2209 };
2210 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
2211 session.peer.send(session.connection_id, update.clone())?;
2212 }
2213
2214 // Stream this worktree's diagnostics.
2215 for summary in worktree.diagnostic_summaries {
2216 session.peer.send(
2217 session.connection_id,
2218 proto::UpdateDiagnosticSummary {
2219 project_id: project_id.to_proto(),
2220 worktree_id: worktree.id,
2221 summary: Some(summary),
2222 },
2223 )?;
2224 }
2225
2226 for settings_file in worktree.settings_files {
2227 session.peer.send(
2228 session.connection_id,
2229 proto::UpdateWorktreeSettings {
2230 project_id: project_id.to_proto(),
2231 worktree_id: worktree.id,
2232 path: settings_file.path,
2233 content: Some(settings_file.content),
2234 kind: Some(proto::update_user_settings::Kind::Settings.into()),
2235 },
2236 )?;
2237 }
2238 }
2239
2240 for language_server in &project.language_servers {
2241 session.peer.send(
2242 session.connection_id,
2243 proto::UpdateLanguageServer {
2244 project_id: project_id.to_proto(),
2245 language_server_id: language_server.id,
2246 variant: Some(
2247 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2248 proto::LspDiskBasedDiagnosticsUpdated {},
2249 ),
2250 ),
2251 },
2252 )?;
2253 }
2254
2255 Ok(())
2256}
2257
2258/// Leave someone elses shared project.
2259async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2260 let sender_id = session.connection_id;
2261 let project_id = ProjectId::from_proto(request.project_id);
2262 let db = session.db().await;
2263 if db.is_hosted_project(project_id).await? {
2264 let project = db.leave_hosted_project(project_id, sender_id).await?;
2265 project_left(&project, &session);
2266 return Ok(());
2267 }
2268
2269 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2270 tracing::info!(
2271 %project_id,
2272 "leave project"
2273 );
2274
2275 project_left(project, &session);
2276 if let Some(room) = room {
2277 room_updated(room, &session.peer);
2278 }
2279
2280 Ok(())
2281}
2282
2283async fn join_hosted_project(
2284 request: proto::JoinHostedProject,
2285 response: Response<proto::JoinHostedProject>,
2286 session: UserSession,
2287) -> Result<()> {
2288 let (mut project, replica_id) = session
2289 .db()
2290 .await
2291 .join_hosted_project(
2292 ProjectId(request.project_id as i32),
2293 session.user_id(),
2294 session.connection_id,
2295 )
2296 .await?;
2297
2298 join_project_internal(response, session, &mut project, &replica_id)
2299}
2300
2301async fn list_remote_directory(
2302 request: proto::ListRemoteDirectory,
2303 response: Response<proto::ListRemoteDirectory>,
2304 session: UserSession,
2305) -> Result<()> {
2306 let dev_server_id = DevServerId(request.dev_server_id as i32);
2307 let dev_server_connection_id = session
2308 .connection_pool()
2309 .await
2310 .online_dev_server_connection_id(dev_server_id)?;
2311
2312 session
2313 .db()
2314 .await
2315 .get_dev_server_for_user(dev_server_id, session.user_id())
2316 .await?;
2317
2318 response.send(
2319 session
2320 .peer
2321 .forward_request(session.connection_id, dev_server_connection_id, request)
2322 .await?,
2323 )?;
2324 Ok(())
2325}
2326
2327async fn update_dev_server_project(
2328 request: proto::UpdateDevServerProject,
2329 response: Response<proto::UpdateDevServerProject>,
2330 session: UserSession,
2331) -> Result<()> {
2332 let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2333
2334 let (dev_server_project, update) = session
2335 .db()
2336 .await
2337 .update_dev_server_project(dev_server_project_id, &request.paths, session.user_id())
2338 .await?;
2339
2340 let projects = session
2341 .db()
2342 .await
2343 .get_projects_for_dev_server(dev_server_project.dev_server_id)
2344 .await?;
2345
2346 let dev_server_connection_id = session
2347 .connection_pool()
2348 .await
2349 .online_dev_server_connection_id(dev_server_project.dev_server_id)?;
2350
2351 session.peer.send(
2352 dev_server_connection_id,
2353 proto::DevServerInstructions { projects },
2354 )?;
2355
2356 send_dev_server_projects_update(session.user_id(), update, &session).await;
2357
2358 response.send(proto::Ack {})
2359}
2360
2361async fn create_dev_server_project(
2362 request: proto::CreateDevServerProject,
2363 response: Response<proto::CreateDevServerProject>,
2364 session: UserSession,
2365) -> Result<()> {
2366 let dev_server_id = DevServerId(request.dev_server_id as i32);
2367 let dev_server_connection_id = session
2368 .connection_pool()
2369 .await
2370 .dev_server_connection_id(dev_server_id);
2371 let Some(dev_server_connection_id) = dev_server_connection_id else {
2372 Err(ErrorCode::DevServerOffline
2373 .message("Cannot create a remote project when the dev server is offline".to_string())
2374 .anyhow())?
2375 };
2376
2377 let path = request.path.clone();
2378 //Check that the path exists on the dev server
2379 session
2380 .peer
2381 .forward_request(
2382 session.connection_id,
2383 dev_server_connection_id,
2384 proto::ValidateDevServerProjectRequest { path: path.clone() },
2385 )
2386 .await?;
2387
2388 let (dev_server_project, update) = session
2389 .db()
2390 .await
2391 .create_dev_server_project(
2392 DevServerId(request.dev_server_id as i32),
2393 &request.path,
2394 session.user_id(),
2395 )
2396 .await?;
2397
2398 let projects = session
2399 .db()
2400 .await
2401 .get_projects_for_dev_server(dev_server_project.dev_server_id)
2402 .await?;
2403
2404 session.peer.send(
2405 dev_server_connection_id,
2406 proto::DevServerInstructions { projects },
2407 )?;
2408
2409 send_dev_server_projects_update(session.user_id(), update, &session).await;
2410
2411 response.send(proto::CreateDevServerProjectResponse {
2412 dev_server_project: Some(dev_server_project.to_proto(None)),
2413 })?;
2414 Ok(())
2415}
2416
2417async fn create_dev_server(
2418 request: proto::CreateDevServer,
2419 response: Response<proto::CreateDevServer>,
2420 session: UserSession,
2421) -> Result<()> {
2422 let access_token = auth::random_token();
2423 let hashed_access_token = auth::hash_access_token(&access_token);
2424
2425 if request.name.is_empty() {
2426 return Err(proto::ErrorCode::Forbidden
2427 .message("Dev server name cannot be empty".to_string())
2428 .anyhow())?;
2429 }
2430
2431 let (dev_server, status) = session
2432 .db()
2433 .await
2434 .create_dev_server(
2435 &request.name,
2436 request.ssh_connection_string.as_deref(),
2437 &hashed_access_token,
2438 session.user_id(),
2439 )
2440 .await?;
2441
2442 send_dev_server_projects_update(session.user_id(), status, &session).await;
2443
2444 response.send(proto::CreateDevServerResponse {
2445 dev_server_id: dev_server.id.0 as u64,
2446 access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2447 name: request.name,
2448 })?;
2449 Ok(())
2450}
2451
2452async fn regenerate_dev_server_token(
2453 request: proto::RegenerateDevServerToken,
2454 response: Response<proto::RegenerateDevServerToken>,
2455 session: UserSession,
2456) -> Result<()> {
2457 let dev_server_id = DevServerId(request.dev_server_id as i32);
2458 let access_token = auth::random_token();
2459 let hashed_access_token = auth::hash_access_token(&access_token);
2460
2461 let connection_id = session
2462 .connection_pool()
2463 .await
2464 .dev_server_connection_id(dev_server_id);
2465 if let Some(connection_id) = connection_id {
2466 shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2467 session.peer.send(
2468 connection_id,
2469 proto::ShutdownDevServer {
2470 reason: Some("dev server token was regenerated".to_string()),
2471 },
2472 )?;
2473 let _ = remove_dev_server_connection(dev_server_id, &session).await;
2474 }
2475
2476 let status = session
2477 .db()
2478 .await
2479 .update_dev_server_token(dev_server_id, &hashed_access_token, session.user_id())
2480 .await?;
2481
2482 send_dev_server_projects_update(session.user_id(), status, &session).await;
2483
2484 response.send(proto::RegenerateDevServerTokenResponse {
2485 dev_server_id: dev_server_id.to_proto(),
2486 access_token: auth::generate_dev_server_token(dev_server_id.0 as usize, access_token),
2487 })?;
2488 Ok(())
2489}
2490
2491async fn rename_dev_server(
2492 request: proto::RenameDevServer,
2493 response: Response<proto::RenameDevServer>,
2494 session: UserSession,
2495) -> Result<()> {
2496 if request.name.trim().is_empty() {
2497 return Err(proto::ErrorCode::Forbidden
2498 .message("Dev server name cannot be empty".to_string())
2499 .anyhow())?;
2500 }
2501
2502 let dev_server_id = DevServerId(request.dev_server_id as i32);
2503 let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2504 if dev_server.user_id != session.user_id() {
2505 return Err(anyhow!(ErrorCode::Forbidden))?;
2506 }
2507
2508 let status = session
2509 .db()
2510 .await
2511 .rename_dev_server(
2512 dev_server_id,
2513 &request.name,
2514 request.ssh_connection_string.as_deref(),
2515 session.user_id(),
2516 )
2517 .await?;
2518
2519 send_dev_server_projects_update(session.user_id(), status, &session).await;
2520
2521 response.send(proto::Ack {})?;
2522 Ok(())
2523}
2524
2525async fn delete_dev_server(
2526 request: proto::DeleteDevServer,
2527 response: Response<proto::DeleteDevServer>,
2528 session: UserSession,
2529) -> Result<()> {
2530 let dev_server_id = DevServerId(request.dev_server_id as i32);
2531 let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2532 if dev_server.user_id != session.user_id() {
2533 return Err(anyhow!(ErrorCode::Forbidden))?;
2534 }
2535
2536 let connection_id = session
2537 .connection_pool()
2538 .await
2539 .dev_server_connection_id(dev_server_id);
2540 if let Some(connection_id) = connection_id {
2541 shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2542 session.peer.send(
2543 connection_id,
2544 proto::ShutdownDevServer {
2545 reason: Some("dev server was deleted".to_string()),
2546 },
2547 )?;
2548 let _ = remove_dev_server_connection(dev_server_id, &session).await;
2549 }
2550
2551 let status = session
2552 .db()
2553 .await
2554 .delete_dev_server(dev_server_id, session.user_id())
2555 .await?;
2556
2557 send_dev_server_projects_update(session.user_id(), status, &session).await;
2558
2559 response.send(proto::Ack {})?;
2560 Ok(())
2561}
2562
2563async fn delete_dev_server_project(
2564 request: proto::DeleteDevServerProject,
2565 response: Response<proto::DeleteDevServerProject>,
2566 session: UserSession,
2567) -> Result<()> {
2568 let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2569 let dev_server_project = session
2570 .db()
2571 .await
2572 .get_dev_server_project(dev_server_project_id)
2573 .await?;
2574
2575 let dev_server = session
2576 .db()
2577 .await
2578 .get_dev_server(dev_server_project.dev_server_id)
2579 .await?;
2580 if dev_server.user_id != session.user_id() {
2581 return Err(anyhow!(ErrorCode::Forbidden))?;
2582 }
2583
2584 let dev_server_connection_id = session
2585 .connection_pool()
2586 .await
2587 .dev_server_connection_id(dev_server.id);
2588
2589 if let Some(dev_server_connection_id) = dev_server_connection_id {
2590 let project = session
2591 .db()
2592 .await
2593 .find_dev_server_project(dev_server_project_id)
2594 .await;
2595 if let Ok(project) = project {
2596 unshare_project_internal(
2597 project.id,
2598 dev_server_connection_id,
2599 Some(session.user_id()),
2600 &session,
2601 )
2602 .await?;
2603 }
2604 }
2605
2606 let (projects, status) = session
2607 .db()
2608 .await
2609 .delete_dev_server_project(dev_server_project_id, dev_server.id, session.user_id())
2610 .await?;
2611
2612 if let Some(dev_server_connection_id) = dev_server_connection_id {
2613 session.peer.send(
2614 dev_server_connection_id,
2615 proto::DevServerInstructions { projects },
2616 )?;
2617 }
2618
2619 send_dev_server_projects_update(session.user_id(), status, &session).await;
2620
2621 response.send(proto::Ack {})?;
2622 Ok(())
2623}
2624
2625async fn rejoin_dev_server_projects(
2626 request: proto::RejoinRemoteProjects,
2627 response: Response<proto::RejoinRemoteProjects>,
2628 session: UserSession,
2629) -> Result<()> {
2630 let mut rejoined_projects = {
2631 let db = session.db().await;
2632 db.rejoin_dev_server_projects(
2633 &request.rejoined_projects,
2634 session.user_id(),
2635 session.0.connection_id,
2636 )
2637 .await?
2638 };
2639 response.send(proto::RejoinRemoteProjectsResponse {
2640 rejoined_projects: rejoined_projects
2641 .iter()
2642 .map(|project| project.to_proto())
2643 .collect(),
2644 })?;
2645 notify_rejoined_projects(&mut rejoined_projects, &session)
2646}
2647
2648async fn reconnect_dev_server(
2649 request: proto::ReconnectDevServer,
2650 response: Response<proto::ReconnectDevServer>,
2651 session: DevServerSession,
2652) -> Result<()> {
2653 let reshared_projects = {
2654 let db = session.db().await;
2655 db.reshare_dev_server_projects(
2656 &request.reshared_projects,
2657 session.dev_server_id(),
2658 session.0.connection_id,
2659 )
2660 .await?
2661 };
2662
2663 for project in &reshared_projects {
2664 for collaborator in &project.collaborators {
2665 session
2666 .peer
2667 .send(
2668 collaborator.connection_id,
2669 proto::UpdateProjectCollaborator {
2670 project_id: project.id.to_proto(),
2671 old_peer_id: Some(project.old_connection_id.into()),
2672 new_peer_id: Some(session.connection_id.into()),
2673 },
2674 )
2675 .trace_err();
2676 }
2677
2678 broadcast(
2679 Some(session.connection_id),
2680 project
2681 .collaborators
2682 .iter()
2683 .map(|collaborator| collaborator.connection_id),
2684 |connection_id| {
2685 session.peer.forward_send(
2686 session.connection_id,
2687 connection_id,
2688 proto::UpdateProject {
2689 project_id: project.id.to_proto(),
2690 worktrees: project.worktrees.clone(),
2691 },
2692 )
2693 },
2694 );
2695 }
2696
2697 response.send(proto::ReconnectDevServerResponse {
2698 reshared_projects: reshared_projects
2699 .iter()
2700 .map(|project| proto::ResharedProject {
2701 id: project.id.to_proto(),
2702 collaborators: project
2703 .collaborators
2704 .iter()
2705 .map(|collaborator| collaborator.to_proto())
2706 .collect(),
2707 })
2708 .collect(),
2709 })?;
2710
2711 Ok(())
2712}
2713
2714async fn shutdown_dev_server(
2715 _: proto::ShutdownDevServer,
2716 response: Response<proto::ShutdownDevServer>,
2717 session: DevServerSession,
2718) -> Result<()> {
2719 response.send(proto::Ack {})?;
2720 shutdown_dev_server_internal(session.dev_server_id(), session.connection_id, &session).await?;
2721 remove_dev_server_connection(session.dev_server_id(), &session).await
2722}
2723
2724async fn shutdown_dev_server_internal(
2725 dev_server_id: DevServerId,
2726 connection_id: ConnectionId,
2727 session: &Session,
2728) -> Result<()> {
2729 let (dev_server_projects, dev_server) = {
2730 let db = session.db().await;
2731 let dev_server_projects = db.get_projects_for_dev_server(dev_server_id).await?;
2732 let dev_server = db.get_dev_server(dev_server_id).await?;
2733 (dev_server_projects, dev_server)
2734 };
2735
2736 for project_id in dev_server_projects.iter().filter_map(|p| p.project_id) {
2737 unshare_project_internal(
2738 ProjectId::from_proto(project_id),
2739 connection_id,
2740 None,
2741 session,
2742 )
2743 .await?;
2744 }
2745
2746 session
2747 .connection_pool()
2748 .await
2749 .set_dev_server_offline(dev_server_id);
2750
2751 let status = session
2752 .db()
2753 .await
2754 .dev_server_projects_update(dev_server.user_id)
2755 .await?;
2756 send_dev_server_projects_update(dev_server.user_id, status, session).await;
2757
2758 Ok(())
2759}
2760
2761async fn remove_dev_server_connection(dev_server_id: DevServerId, session: &Session) -> Result<()> {
2762 let dev_server_connection = session
2763 .connection_pool()
2764 .await
2765 .dev_server_connection_id(dev_server_id);
2766
2767 if let Some(dev_server_connection) = dev_server_connection {
2768 session
2769 .connection_pool()
2770 .await
2771 .remove_connection(dev_server_connection)?;
2772 }
2773 Ok(())
2774}
2775
2776/// Updates other participants with changes to the project
2777async fn update_project(
2778 request: proto::UpdateProject,
2779 response: Response<proto::UpdateProject>,
2780 session: Session,
2781) -> Result<()> {
2782 let project_id = ProjectId::from_proto(request.project_id);
2783 let (room, guest_connection_ids) = &*session
2784 .db()
2785 .await
2786 .update_project(project_id, session.connection_id, &request.worktrees)
2787 .await?;
2788 broadcast(
2789 Some(session.connection_id),
2790 guest_connection_ids.iter().copied(),
2791 |connection_id| {
2792 session
2793 .peer
2794 .forward_send(session.connection_id, connection_id, request.clone())
2795 },
2796 );
2797 if let Some(room) = room {
2798 room_updated(room, &session.peer);
2799 }
2800 response.send(proto::Ack {})?;
2801
2802 Ok(())
2803}
2804
2805/// Updates other participants with changes to the worktree
2806async fn update_worktree(
2807 request: proto::UpdateWorktree,
2808 response: Response<proto::UpdateWorktree>,
2809 session: Session,
2810) -> Result<()> {
2811 let guest_connection_ids = session
2812 .db()
2813 .await
2814 .update_worktree(&request, session.connection_id)
2815 .await?;
2816
2817 broadcast(
2818 Some(session.connection_id),
2819 guest_connection_ids.iter().copied(),
2820 |connection_id| {
2821 session
2822 .peer
2823 .forward_send(session.connection_id, connection_id, request.clone())
2824 },
2825 );
2826 response.send(proto::Ack {})?;
2827 Ok(())
2828}
2829
2830/// Updates other participants with changes to the diagnostics
2831async fn update_diagnostic_summary(
2832 message: proto::UpdateDiagnosticSummary,
2833 session: Session,
2834) -> Result<()> {
2835 let guest_connection_ids = session
2836 .db()
2837 .await
2838 .update_diagnostic_summary(&message, session.connection_id)
2839 .await?;
2840
2841 broadcast(
2842 Some(session.connection_id),
2843 guest_connection_ids.iter().copied(),
2844 |connection_id| {
2845 session
2846 .peer
2847 .forward_send(session.connection_id, connection_id, message.clone())
2848 },
2849 );
2850
2851 Ok(())
2852}
2853
2854/// Updates other participants with changes to the worktree settings
2855async fn update_worktree_settings(
2856 message: proto::UpdateWorktreeSettings,
2857 session: Session,
2858) -> Result<()> {
2859 let guest_connection_ids = session
2860 .db()
2861 .await
2862 .update_worktree_settings(&message, session.connection_id)
2863 .await?;
2864
2865 broadcast(
2866 Some(session.connection_id),
2867 guest_connection_ids.iter().copied(),
2868 |connection_id| {
2869 session
2870 .peer
2871 .forward_send(session.connection_id, connection_id, message.clone())
2872 },
2873 );
2874
2875 Ok(())
2876}
2877
2878/// Notify other participants that a language server has started.
2879async fn start_language_server(
2880 request: proto::StartLanguageServer,
2881 session: Session,
2882) -> Result<()> {
2883 let guest_connection_ids = session
2884 .db()
2885 .await
2886 .start_language_server(&request, session.connection_id)
2887 .await?;
2888
2889 broadcast(
2890 Some(session.connection_id),
2891 guest_connection_ids.iter().copied(),
2892 |connection_id| {
2893 session
2894 .peer
2895 .forward_send(session.connection_id, connection_id, request.clone())
2896 },
2897 );
2898 Ok(())
2899}
2900
2901/// Notify other participants that a language server has changed.
2902async fn update_language_server(
2903 request: proto::UpdateLanguageServer,
2904 session: Session,
2905) -> Result<()> {
2906 let project_id = ProjectId::from_proto(request.project_id);
2907 let project_connection_ids = session
2908 .db()
2909 .await
2910 .project_connection_ids(project_id, session.connection_id, true)
2911 .await?;
2912 broadcast(
2913 Some(session.connection_id),
2914 project_connection_ids.iter().copied(),
2915 |connection_id| {
2916 session
2917 .peer
2918 .forward_send(session.connection_id, connection_id, request.clone())
2919 },
2920 );
2921 Ok(())
2922}
2923
2924/// forward a project request to the host. These requests should be read only
2925/// as guests are allowed to send them.
2926async fn forward_read_only_project_request<T>(
2927 request: T,
2928 response: Response<T>,
2929 session: UserSession,
2930) -> Result<()>
2931where
2932 T: EntityMessage + RequestMessage,
2933{
2934 let project_id = ProjectId::from_proto(request.remote_entity_id());
2935 let host_connection_id = session
2936 .db()
2937 .await
2938 .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2939 .await?;
2940 let payload = session
2941 .peer
2942 .forward_request(session.connection_id, host_connection_id, request)
2943 .await?;
2944 response.send(payload)?;
2945 Ok(())
2946}
2947
2948async fn forward_find_search_candidates_request(
2949 request: proto::FindSearchCandidates,
2950 response: Response<proto::FindSearchCandidates>,
2951 session: UserSession,
2952) -> Result<()> {
2953 let project_id = ProjectId::from_proto(request.remote_entity_id());
2954 let host_connection_id = session
2955 .db()
2956 .await
2957 .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2958 .await?;
2959 let payload = session
2960 .peer
2961 .forward_request(session.connection_id, host_connection_id, request)
2962 .await?;
2963 response.send(payload)?;
2964 Ok(())
2965}
2966
2967/// forward a project request to the dev server. Only allowed
2968/// if it's your dev server.
2969async fn forward_project_request_for_owner<T>(
2970 request: T,
2971 response: Response<T>,
2972 session: UserSession,
2973) -> Result<()>
2974where
2975 T: EntityMessage + RequestMessage,
2976{
2977 let project_id = ProjectId::from_proto(request.remote_entity_id());
2978
2979 let host_connection_id = session
2980 .db()
2981 .await
2982 .host_for_owner_project_request(project_id, session.connection_id, session.user_id())
2983 .await?;
2984 let payload = session
2985 .peer
2986 .forward_request(session.connection_id, host_connection_id, request)
2987 .await?;
2988 response.send(payload)?;
2989 Ok(())
2990}
2991
2992/// forward a project request to the host. These requests are disallowed
2993/// for guests.
2994async fn forward_mutating_project_request<T>(
2995 request: T,
2996 response: Response<T>,
2997 session: UserSession,
2998) -> Result<()>
2999where
3000 T: EntityMessage + RequestMessage,
3001{
3002 let project_id = ProjectId::from_proto(request.remote_entity_id());
3003
3004 let host_connection_id = session
3005 .db()
3006 .await
3007 .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
3008 .await?;
3009 let payload = session
3010 .peer
3011 .forward_request(session.connection_id, host_connection_id, request)
3012 .await?;
3013 response.send(payload)?;
3014 Ok(())
3015}
3016
3017/// Notify other participants that a new buffer has been created
3018async fn create_buffer_for_peer(
3019 request: proto::CreateBufferForPeer,
3020 session: Session,
3021) -> Result<()> {
3022 session
3023 .db()
3024 .await
3025 .check_user_is_project_host(
3026 ProjectId::from_proto(request.project_id),
3027 session.connection_id,
3028 )
3029 .await?;
3030 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
3031 session
3032 .peer
3033 .forward_send(session.connection_id, peer_id.into(), request)?;
3034 Ok(())
3035}
3036
3037/// Notify other participants that a buffer has been updated. This is
3038/// allowed for guests as long as the update is limited to selections.
3039async fn update_buffer(
3040 request: proto::UpdateBuffer,
3041 response: Response<proto::UpdateBuffer>,
3042 session: Session,
3043) -> Result<()> {
3044 let project_id = ProjectId::from_proto(request.project_id);
3045 let mut capability = Capability::ReadOnly;
3046
3047 for op in request.operations.iter() {
3048 match op.variant {
3049 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
3050 Some(_) => capability = Capability::ReadWrite,
3051 }
3052 }
3053
3054 let host = {
3055 let guard = session
3056 .db()
3057 .await
3058 .connections_for_buffer_update(
3059 project_id,
3060 session.principal_id(),
3061 session.connection_id,
3062 capability,
3063 )
3064 .await?;
3065
3066 let (host, guests) = &*guard;
3067
3068 broadcast(
3069 Some(session.connection_id),
3070 guests.clone(),
3071 |connection_id| {
3072 session
3073 .peer
3074 .forward_send(session.connection_id, connection_id, request.clone())
3075 },
3076 );
3077
3078 *host
3079 };
3080
3081 if host != session.connection_id {
3082 session
3083 .peer
3084 .forward_request(session.connection_id, host, request.clone())
3085 .await?;
3086 }
3087
3088 response.send(proto::Ack {})?;
3089 Ok(())
3090}
3091
3092async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
3093 let project_id = ProjectId::from_proto(message.project_id);
3094
3095 let operation = message.operation.as_ref().context("invalid operation")?;
3096 let capability = match operation.variant.as_ref() {
3097 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
3098 if let Some(buffer_op) = buffer_op.operation.as_ref() {
3099 match buffer_op.variant {
3100 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
3101 Capability::ReadOnly
3102 }
3103 _ => Capability::ReadWrite,
3104 }
3105 } else {
3106 Capability::ReadWrite
3107 }
3108 }
3109 Some(_) => Capability::ReadWrite,
3110 None => Capability::ReadOnly,
3111 };
3112
3113 let guard = session
3114 .db()
3115 .await
3116 .connections_for_buffer_update(
3117 project_id,
3118 session.principal_id(),
3119 session.connection_id,
3120 capability,
3121 )
3122 .await?;
3123
3124 let (host, guests) = &*guard;
3125
3126 broadcast(
3127 Some(session.connection_id),
3128 guests.iter().chain([host]).copied(),
3129 |connection_id| {
3130 session
3131 .peer
3132 .forward_send(session.connection_id, connection_id, message.clone())
3133 },
3134 );
3135
3136 Ok(())
3137}
3138
3139/// Notify other participants that a project has been updated.
3140async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
3141 request: T,
3142 session: Session,
3143) -> Result<()> {
3144 let project_id = ProjectId::from_proto(request.remote_entity_id());
3145 let project_connection_ids = session
3146 .db()
3147 .await
3148 .project_connection_ids(project_id, session.connection_id, false)
3149 .await?;
3150
3151 broadcast(
3152 Some(session.connection_id),
3153 project_connection_ids.iter().copied(),
3154 |connection_id| {
3155 session
3156 .peer
3157 .forward_send(session.connection_id, connection_id, request.clone())
3158 },
3159 );
3160 Ok(())
3161}
3162
3163/// Start following another user in a call.
3164async fn follow(
3165 request: proto::Follow,
3166 response: Response<proto::Follow>,
3167 session: UserSession,
3168) -> Result<()> {
3169 let room_id = RoomId::from_proto(request.room_id);
3170 let project_id = request.project_id.map(ProjectId::from_proto);
3171 let leader_id = request
3172 .leader_id
3173 .ok_or_else(|| anyhow!("invalid leader id"))?
3174 .into();
3175 let follower_id = session.connection_id;
3176
3177 session
3178 .db()
3179 .await
3180 .check_room_participants(room_id, leader_id, session.connection_id)
3181 .await?;
3182
3183 let response_payload = session
3184 .peer
3185 .forward_request(session.connection_id, leader_id, request)
3186 .await?;
3187 response.send(response_payload)?;
3188
3189 if let Some(project_id) = project_id {
3190 let room = session
3191 .db()
3192 .await
3193 .follow(room_id, project_id, leader_id, follower_id)
3194 .await?;
3195 room_updated(&room, &session.peer);
3196 }
3197
3198 Ok(())
3199}
3200
3201/// Stop following another user in a call.
3202async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
3203 let room_id = RoomId::from_proto(request.room_id);
3204 let project_id = request.project_id.map(ProjectId::from_proto);
3205 let leader_id = request
3206 .leader_id
3207 .ok_or_else(|| anyhow!("invalid leader id"))?
3208 .into();
3209 let follower_id = session.connection_id;
3210
3211 session
3212 .db()
3213 .await
3214 .check_room_participants(room_id, leader_id, session.connection_id)
3215 .await?;
3216
3217 session
3218 .peer
3219 .forward_send(session.connection_id, leader_id, request)?;
3220
3221 if let Some(project_id) = project_id {
3222 let room = session
3223 .db()
3224 .await
3225 .unfollow(room_id, project_id, leader_id, follower_id)
3226 .await?;
3227 room_updated(&room, &session.peer);
3228 }
3229
3230 Ok(())
3231}
3232
3233/// Notify everyone following you of your current location.
3234async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
3235 let room_id = RoomId::from_proto(request.room_id);
3236 let database = session.db.lock().await;
3237
3238 let connection_ids = if let Some(project_id) = request.project_id {
3239 let project_id = ProjectId::from_proto(project_id);
3240 database
3241 .project_connection_ids(project_id, session.connection_id, true)
3242 .await?
3243 } else {
3244 database
3245 .room_connection_ids(room_id, session.connection_id)
3246 .await?
3247 };
3248
3249 // For now, don't send view update messages back to that view's current leader.
3250 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
3251 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
3252 _ => None,
3253 });
3254
3255 for connection_id in connection_ids.iter().cloned() {
3256 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
3257 session
3258 .peer
3259 .forward_send(session.connection_id, connection_id, request.clone())?;
3260 }
3261 }
3262 Ok(())
3263}
3264
3265/// Get public data about users.
3266async fn get_users(
3267 request: proto::GetUsers,
3268 response: Response<proto::GetUsers>,
3269 session: Session,
3270) -> Result<()> {
3271 let user_ids = request
3272 .user_ids
3273 .into_iter()
3274 .map(UserId::from_proto)
3275 .collect();
3276 let users = session
3277 .db()
3278 .await
3279 .get_users_by_ids(user_ids)
3280 .await?
3281 .into_iter()
3282 .map(|user| proto::User {
3283 id: user.id.to_proto(),
3284 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3285 github_login: user.github_login,
3286 })
3287 .collect();
3288 response.send(proto::UsersResponse { users })?;
3289 Ok(())
3290}
3291
3292/// Search for users (to invite) buy Github login
3293async fn fuzzy_search_users(
3294 request: proto::FuzzySearchUsers,
3295 response: Response<proto::FuzzySearchUsers>,
3296 session: UserSession,
3297) -> Result<()> {
3298 let query = request.query;
3299 let users = match query.len() {
3300 0 => vec![],
3301 1 | 2 => session
3302 .db()
3303 .await
3304 .get_user_by_github_login(&query)
3305 .await?
3306 .into_iter()
3307 .collect(),
3308 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
3309 };
3310 let users = users
3311 .into_iter()
3312 .filter(|user| user.id != session.user_id())
3313 .map(|user| proto::User {
3314 id: user.id.to_proto(),
3315 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3316 github_login: user.github_login,
3317 })
3318 .collect();
3319 response.send(proto::UsersResponse { users })?;
3320 Ok(())
3321}
3322
3323/// Send a contact request to another user.
3324async fn request_contact(
3325 request: proto::RequestContact,
3326 response: Response<proto::RequestContact>,
3327 session: UserSession,
3328) -> Result<()> {
3329 let requester_id = session.user_id();
3330 let responder_id = UserId::from_proto(request.responder_id);
3331 if requester_id == responder_id {
3332 return Err(anyhow!("cannot add yourself as a contact"))?;
3333 }
3334
3335 let notifications = session
3336 .db()
3337 .await
3338 .send_contact_request(requester_id, responder_id)
3339 .await?;
3340
3341 // Update outgoing contact requests of requester
3342 let mut update = proto::UpdateContacts::default();
3343 update.outgoing_requests.push(responder_id.to_proto());
3344 for connection_id in session
3345 .connection_pool()
3346 .await
3347 .user_connection_ids(requester_id)
3348 {
3349 session.peer.send(connection_id, update.clone())?;
3350 }
3351
3352 // Update incoming contact requests of responder
3353 let mut update = proto::UpdateContacts::default();
3354 update
3355 .incoming_requests
3356 .push(proto::IncomingContactRequest {
3357 requester_id: requester_id.to_proto(),
3358 });
3359 let connection_pool = session.connection_pool().await;
3360 for connection_id in connection_pool.user_connection_ids(responder_id) {
3361 session.peer.send(connection_id, update.clone())?;
3362 }
3363
3364 send_notifications(&connection_pool, &session.peer, notifications);
3365
3366 response.send(proto::Ack {})?;
3367 Ok(())
3368}
3369
3370/// Accept or decline a contact request
3371async fn respond_to_contact_request(
3372 request: proto::RespondToContactRequest,
3373 response: Response<proto::RespondToContactRequest>,
3374 session: UserSession,
3375) -> Result<()> {
3376 let responder_id = session.user_id();
3377 let requester_id = UserId::from_proto(request.requester_id);
3378 let db = session.db().await;
3379 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
3380 db.dismiss_contact_notification(responder_id, requester_id)
3381 .await?;
3382 } else {
3383 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
3384
3385 let notifications = db
3386 .respond_to_contact_request(responder_id, requester_id, accept)
3387 .await?;
3388 let requester_busy = db.is_user_busy(requester_id).await?;
3389 let responder_busy = db.is_user_busy(responder_id).await?;
3390
3391 let pool = session.connection_pool().await;
3392 // Update responder with new contact
3393 let mut update = proto::UpdateContacts::default();
3394 if accept {
3395 update
3396 .contacts
3397 .push(contact_for_user(requester_id, requester_busy, &pool));
3398 }
3399 update
3400 .remove_incoming_requests
3401 .push(requester_id.to_proto());
3402 for connection_id in pool.user_connection_ids(responder_id) {
3403 session.peer.send(connection_id, update.clone())?;
3404 }
3405
3406 // Update requester with new contact
3407 let mut update = proto::UpdateContacts::default();
3408 if accept {
3409 update
3410 .contacts
3411 .push(contact_for_user(responder_id, responder_busy, &pool));
3412 }
3413 update
3414 .remove_outgoing_requests
3415 .push(responder_id.to_proto());
3416
3417 for connection_id in pool.user_connection_ids(requester_id) {
3418 session.peer.send(connection_id, update.clone())?;
3419 }
3420
3421 send_notifications(&pool, &session.peer, notifications);
3422 }
3423
3424 response.send(proto::Ack {})?;
3425 Ok(())
3426}
3427
3428/// Remove a contact.
3429async fn remove_contact(
3430 request: proto::RemoveContact,
3431 response: Response<proto::RemoveContact>,
3432 session: UserSession,
3433) -> Result<()> {
3434 let requester_id = session.user_id();
3435 let responder_id = UserId::from_proto(request.user_id);
3436 let db = session.db().await;
3437 let (contact_accepted, deleted_notification_id) =
3438 db.remove_contact(requester_id, responder_id).await?;
3439
3440 let pool = session.connection_pool().await;
3441 // Update outgoing contact requests of requester
3442 let mut update = proto::UpdateContacts::default();
3443 if contact_accepted {
3444 update.remove_contacts.push(responder_id.to_proto());
3445 } else {
3446 update
3447 .remove_outgoing_requests
3448 .push(responder_id.to_proto());
3449 }
3450 for connection_id in pool.user_connection_ids(requester_id) {
3451 session.peer.send(connection_id, update.clone())?;
3452 }
3453
3454 // Update incoming contact requests of responder
3455 let mut update = proto::UpdateContacts::default();
3456 if contact_accepted {
3457 update.remove_contacts.push(requester_id.to_proto());
3458 } else {
3459 update
3460 .remove_incoming_requests
3461 .push(requester_id.to_proto());
3462 }
3463 for connection_id in pool.user_connection_ids(responder_id) {
3464 session.peer.send(connection_id, update.clone())?;
3465 if let Some(notification_id) = deleted_notification_id {
3466 session.peer.send(
3467 connection_id,
3468 proto::DeleteNotification {
3469 notification_id: notification_id.to_proto(),
3470 },
3471 )?;
3472 }
3473 }
3474
3475 response.send(proto::Ack {})?;
3476 Ok(())
3477}
3478
3479fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
3480 version.0.minor() < 139
3481}
3482
3483async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
3484 let plan = session.current_plan(&session.db().await).await?;
3485
3486 session
3487 .peer
3488 .send(
3489 session.connection_id,
3490 proto::UpdateUserPlan { plan: plan.into() },
3491 )
3492 .trace_err();
3493
3494 Ok(())
3495}
3496
3497async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
3498 subscribe_user_to_channels(
3499 session.user_id().ok_or_else(|| anyhow!("must be a user"))?,
3500 &session,
3501 )
3502 .await?;
3503 Ok(())
3504}
3505
3506async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
3507 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3508 let mut pool = session.connection_pool().await;
3509 for membership in &channels_for_user.channel_memberships {
3510 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3511 }
3512 session.peer.send(
3513 session.connection_id,
3514 build_update_user_channels(&channels_for_user),
3515 )?;
3516 session.peer.send(
3517 session.connection_id,
3518 build_channels_update(channels_for_user),
3519 )?;
3520 Ok(())
3521}
3522
3523/// Creates a new channel.
3524async fn create_channel(
3525 request: proto::CreateChannel,
3526 response: Response<proto::CreateChannel>,
3527 session: UserSession,
3528) -> Result<()> {
3529 let db = session.db().await;
3530
3531 let parent_id = request.parent_id.map(ChannelId::from_proto);
3532 let (channel, membership) = db
3533 .create_channel(&request.name, parent_id, session.user_id())
3534 .await?;
3535
3536 let root_id = channel.root_id();
3537 let channel = Channel::from_model(channel);
3538
3539 response.send(proto::CreateChannelResponse {
3540 channel: Some(channel.to_proto()),
3541 parent_id: request.parent_id,
3542 })?;
3543
3544 let mut connection_pool = session.connection_pool().await;
3545 if let Some(membership) = membership {
3546 connection_pool.subscribe_to_channel(
3547 membership.user_id,
3548 membership.channel_id,
3549 membership.role,
3550 );
3551 let update = proto::UpdateUserChannels {
3552 channel_memberships: vec![proto::ChannelMembership {
3553 channel_id: membership.channel_id.to_proto(),
3554 role: membership.role.into(),
3555 }],
3556 ..Default::default()
3557 };
3558 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3559 session.peer.send(connection_id, update.clone())?;
3560 }
3561 }
3562
3563 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3564 if !role.can_see_channel(channel.visibility) {
3565 continue;
3566 }
3567
3568 let update = proto::UpdateChannels {
3569 channels: vec![channel.to_proto()],
3570 ..Default::default()
3571 };
3572 session.peer.send(connection_id, update.clone())?;
3573 }
3574
3575 Ok(())
3576}
3577
3578/// Delete a channel
3579async fn delete_channel(
3580 request: proto::DeleteChannel,
3581 response: Response<proto::DeleteChannel>,
3582 session: UserSession,
3583) -> Result<()> {
3584 let db = session.db().await;
3585
3586 let channel_id = request.channel_id;
3587 let (root_channel, removed_channels) = db
3588 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3589 .await?;
3590 response.send(proto::Ack {})?;
3591
3592 // Notify members of removed channels
3593 let mut update = proto::UpdateChannels::default();
3594 update
3595 .delete_channels
3596 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3597
3598 let connection_pool = session.connection_pool().await;
3599 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3600 session.peer.send(connection_id, update.clone())?;
3601 }
3602
3603 Ok(())
3604}
3605
3606/// Invite someone to join a channel.
3607async fn invite_channel_member(
3608 request: proto::InviteChannelMember,
3609 response: Response<proto::InviteChannelMember>,
3610 session: UserSession,
3611) -> Result<()> {
3612 let db = session.db().await;
3613 let channel_id = ChannelId::from_proto(request.channel_id);
3614 let invitee_id = UserId::from_proto(request.user_id);
3615 let InviteMemberResult {
3616 channel,
3617 notifications,
3618 } = db
3619 .invite_channel_member(
3620 channel_id,
3621 invitee_id,
3622 session.user_id(),
3623 request.role().into(),
3624 )
3625 .await?;
3626
3627 let update = proto::UpdateChannels {
3628 channel_invitations: vec![channel.to_proto()],
3629 ..Default::default()
3630 };
3631
3632 let connection_pool = session.connection_pool().await;
3633 for connection_id in connection_pool.user_connection_ids(invitee_id) {
3634 session.peer.send(connection_id, update.clone())?;
3635 }
3636
3637 send_notifications(&connection_pool, &session.peer, notifications);
3638
3639 response.send(proto::Ack {})?;
3640 Ok(())
3641}
3642
3643/// remove someone from a channel
3644async fn remove_channel_member(
3645 request: proto::RemoveChannelMember,
3646 response: Response<proto::RemoveChannelMember>,
3647 session: UserSession,
3648) -> Result<()> {
3649 let db = session.db().await;
3650 let channel_id = ChannelId::from_proto(request.channel_id);
3651 let member_id = UserId::from_proto(request.user_id);
3652
3653 let RemoveChannelMemberResult {
3654 membership_update,
3655 notification_id,
3656 } = db
3657 .remove_channel_member(channel_id, member_id, session.user_id())
3658 .await?;
3659
3660 let mut connection_pool = session.connection_pool().await;
3661 notify_membership_updated(
3662 &mut connection_pool,
3663 membership_update,
3664 member_id,
3665 &session.peer,
3666 );
3667 for connection_id in connection_pool.user_connection_ids(member_id) {
3668 if let Some(notification_id) = notification_id {
3669 session
3670 .peer
3671 .send(
3672 connection_id,
3673 proto::DeleteNotification {
3674 notification_id: notification_id.to_proto(),
3675 },
3676 )
3677 .trace_err();
3678 }
3679 }
3680
3681 response.send(proto::Ack {})?;
3682 Ok(())
3683}
3684
3685/// Toggle the channel between public and private.
3686/// Care is taken to maintain the invariant that public channels only descend from public channels,
3687/// (though members-only channels can appear at any point in the hierarchy).
3688async fn set_channel_visibility(
3689 request: proto::SetChannelVisibility,
3690 response: Response<proto::SetChannelVisibility>,
3691 session: UserSession,
3692) -> Result<()> {
3693 let db = session.db().await;
3694 let channel_id = ChannelId::from_proto(request.channel_id);
3695 let visibility = request.visibility().into();
3696
3697 let channel_model = db
3698 .set_channel_visibility(channel_id, visibility, session.user_id())
3699 .await?;
3700 let root_id = channel_model.root_id();
3701 let channel = Channel::from_model(channel_model);
3702
3703 let mut connection_pool = session.connection_pool().await;
3704 for (user_id, role) in connection_pool
3705 .channel_user_ids(root_id)
3706 .collect::<Vec<_>>()
3707 .into_iter()
3708 {
3709 let update = if role.can_see_channel(channel.visibility) {
3710 connection_pool.subscribe_to_channel(user_id, channel_id, role);
3711 proto::UpdateChannels {
3712 channels: vec![channel.to_proto()],
3713 ..Default::default()
3714 }
3715 } else {
3716 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3717 proto::UpdateChannels {
3718 delete_channels: vec![channel.id.to_proto()],
3719 ..Default::default()
3720 }
3721 };
3722
3723 for connection_id in connection_pool.user_connection_ids(user_id) {
3724 session.peer.send(connection_id, update.clone())?;
3725 }
3726 }
3727
3728 response.send(proto::Ack {})?;
3729 Ok(())
3730}
3731
3732/// Alter the role for a user in the channel.
3733async fn set_channel_member_role(
3734 request: proto::SetChannelMemberRole,
3735 response: Response<proto::SetChannelMemberRole>,
3736 session: UserSession,
3737) -> Result<()> {
3738 let db = session.db().await;
3739 let channel_id = ChannelId::from_proto(request.channel_id);
3740 let member_id = UserId::from_proto(request.user_id);
3741 let result = db
3742 .set_channel_member_role(
3743 channel_id,
3744 session.user_id(),
3745 member_id,
3746 request.role().into(),
3747 )
3748 .await?;
3749
3750 match result {
3751 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3752 let mut connection_pool = session.connection_pool().await;
3753 notify_membership_updated(
3754 &mut connection_pool,
3755 membership_update,
3756 member_id,
3757 &session.peer,
3758 )
3759 }
3760 db::SetMemberRoleResult::InviteUpdated(channel) => {
3761 let update = proto::UpdateChannels {
3762 channel_invitations: vec![channel.to_proto()],
3763 ..Default::default()
3764 };
3765
3766 for connection_id in session
3767 .connection_pool()
3768 .await
3769 .user_connection_ids(member_id)
3770 {
3771 session.peer.send(connection_id, update.clone())?;
3772 }
3773 }
3774 }
3775
3776 response.send(proto::Ack {})?;
3777 Ok(())
3778}
3779
3780/// Change the name of a channel
3781async fn rename_channel(
3782 request: proto::RenameChannel,
3783 response: Response<proto::RenameChannel>,
3784 session: UserSession,
3785) -> Result<()> {
3786 let db = session.db().await;
3787 let channel_id = ChannelId::from_proto(request.channel_id);
3788 let channel_model = db
3789 .rename_channel(channel_id, session.user_id(), &request.name)
3790 .await?;
3791 let root_id = channel_model.root_id();
3792 let channel = Channel::from_model(channel_model);
3793
3794 response.send(proto::RenameChannelResponse {
3795 channel: Some(channel.to_proto()),
3796 })?;
3797
3798 let connection_pool = session.connection_pool().await;
3799 let update = proto::UpdateChannels {
3800 channels: vec![channel.to_proto()],
3801 ..Default::default()
3802 };
3803 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3804 if role.can_see_channel(channel.visibility) {
3805 session.peer.send(connection_id, update.clone())?;
3806 }
3807 }
3808
3809 Ok(())
3810}
3811
3812/// Move a channel to a new parent.
3813async fn move_channel(
3814 request: proto::MoveChannel,
3815 response: Response<proto::MoveChannel>,
3816 session: UserSession,
3817) -> Result<()> {
3818 let channel_id = ChannelId::from_proto(request.channel_id);
3819 let to = ChannelId::from_proto(request.to);
3820
3821 let (root_id, channels) = session
3822 .db()
3823 .await
3824 .move_channel(channel_id, to, session.user_id())
3825 .await?;
3826
3827 let connection_pool = session.connection_pool().await;
3828 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3829 let channels = channels
3830 .iter()
3831 .filter_map(|channel| {
3832 if role.can_see_channel(channel.visibility) {
3833 Some(channel.to_proto())
3834 } else {
3835 None
3836 }
3837 })
3838 .collect::<Vec<_>>();
3839 if channels.is_empty() {
3840 continue;
3841 }
3842
3843 let update = proto::UpdateChannels {
3844 channels,
3845 ..Default::default()
3846 };
3847
3848 session.peer.send(connection_id, update.clone())?;
3849 }
3850
3851 response.send(Ack {})?;
3852 Ok(())
3853}
3854
3855/// Get the list of channel members
3856async fn get_channel_members(
3857 request: proto::GetChannelMembers,
3858 response: Response<proto::GetChannelMembers>,
3859 session: UserSession,
3860) -> Result<()> {
3861 let db = session.db().await;
3862 let channel_id = ChannelId::from_proto(request.channel_id);
3863 let limit = if request.limit == 0 {
3864 u16::MAX as u64
3865 } else {
3866 request.limit
3867 };
3868 let (members, users) = db
3869 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3870 .await?;
3871 response.send(proto::GetChannelMembersResponse { members, users })?;
3872 Ok(())
3873}
3874
3875/// Accept or decline a channel invitation.
3876async fn respond_to_channel_invite(
3877 request: proto::RespondToChannelInvite,
3878 response: Response<proto::RespondToChannelInvite>,
3879 session: UserSession,
3880) -> Result<()> {
3881 let db = session.db().await;
3882 let channel_id = ChannelId::from_proto(request.channel_id);
3883 let RespondToChannelInvite {
3884 membership_update,
3885 notifications,
3886 } = db
3887 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3888 .await?;
3889
3890 let mut connection_pool = session.connection_pool().await;
3891 if let Some(membership_update) = membership_update {
3892 notify_membership_updated(
3893 &mut connection_pool,
3894 membership_update,
3895 session.user_id(),
3896 &session.peer,
3897 );
3898 } else {
3899 let update = proto::UpdateChannels {
3900 remove_channel_invitations: vec![channel_id.to_proto()],
3901 ..Default::default()
3902 };
3903
3904 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3905 session.peer.send(connection_id, update.clone())?;
3906 }
3907 };
3908
3909 send_notifications(&connection_pool, &session.peer, notifications);
3910
3911 response.send(proto::Ack {})?;
3912
3913 Ok(())
3914}
3915
3916/// Join the channels' room
3917async fn join_channel(
3918 request: proto::JoinChannel,
3919 response: Response<proto::JoinChannel>,
3920 session: UserSession,
3921) -> Result<()> {
3922 let channel_id = ChannelId::from_proto(request.channel_id);
3923 join_channel_internal(channel_id, Box::new(response), session).await
3924}
3925
3926trait JoinChannelInternalResponse {
3927 fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
3928}
3929impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
3930 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3931 Response::<proto::JoinChannel>::send(self, result)
3932 }
3933}
3934impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
3935 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3936 Response::<proto::JoinRoom>::send(self, result)
3937 }
3938}
3939
3940async fn join_channel_internal(
3941 channel_id: ChannelId,
3942 response: Box<impl JoinChannelInternalResponse>,
3943 session: UserSession,
3944) -> Result<()> {
3945 let joined_room = {
3946 let mut db = session.db().await;
3947 // If zed quits without leaving the room, and the user re-opens zed before the
3948 // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
3949 // room they were in.
3950 if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
3951 tracing::info!(
3952 stale_connection_id = %connection,
3953 "cleaning up stale connection",
3954 );
3955 drop(db);
3956 leave_room_for_session(&session, connection).await?;
3957 db = session.db().await;
3958 }
3959
3960 let (joined_room, membership_updated, role) = db
3961 .join_channel(channel_id, session.user_id(), session.connection_id)
3962 .await?;
3963
3964 let live_kit_connection_info =
3965 session
3966 .app_state
3967 .live_kit_client
3968 .as_ref()
3969 .and_then(|live_kit| {
3970 let (can_publish, token) = if role == ChannelRole::Guest {
3971 (
3972 false,
3973 live_kit
3974 .guest_token(
3975 &joined_room.room.live_kit_room,
3976 &session.user_id().to_string(),
3977 )
3978 .trace_err()?,
3979 )
3980 } else {
3981 (
3982 true,
3983 live_kit
3984 .room_token(
3985 &joined_room.room.live_kit_room,
3986 &session.user_id().to_string(),
3987 )
3988 .trace_err()?,
3989 )
3990 };
3991
3992 Some(LiveKitConnectionInfo {
3993 server_url: live_kit.url().into(),
3994 token,
3995 can_publish,
3996 })
3997 });
3998
3999 response.send(proto::JoinRoomResponse {
4000 room: Some(joined_room.room.clone()),
4001 channel_id: joined_room
4002 .channel
4003 .as_ref()
4004 .map(|channel| channel.id.to_proto()),
4005 live_kit_connection_info,
4006 })?;
4007
4008 let mut connection_pool = session.connection_pool().await;
4009 if let Some(membership_updated) = membership_updated {
4010 notify_membership_updated(
4011 &mut connection_pool,
4012 membership_updated,
4013 session.user_id(),
4014 &session.peer,
4015 );
4016 }
4017
4018 room_updated(&joined_room.room, &session.peer);
4019
4020 joined_room
4021 };
4022
4023 channel_updated(
4024 &joined_room
4025 .channel
4026 .ok_or_else(|| anyhow!("channel not returned"))?,
4027 &joined_room.room,
4028 &session.peer,
4029 &*session.connection_pool().await,
4030 );
4031
4032 update_user_contacts(session.user_id(), &session).await?;
4033 Ok(())
4034}
4035
4036/// Start editing the channel notes
4037async fn join_channel_buffer(
4038 request: proto::JoinChannelBuffer,
4039 response: Response<proto::JoinChannelBuffer>,
4040 session: UserSession,
4041) -> Result<()> {
4042 let db = session.db().await;
4043 let channel_id = ChannelId::from_proto(request.channel_id);
4044
4045 let open_response = db
4046 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
4047 .await?;
4048
4049 let collaborators = open_response.collaborators.clone();
4050 response.send(open_response)?;
4051
4052 let update = UpdateChannelBufferCollaborators {
4053 channel_id: channel_id.to_proto(),
4054 collaborators: collaborators.clone(),
4055 };
4056 channel_buffer_updated(
4057 session.connection_id,
4058 collaborators
4059 .iter()
4060 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
4061 &update,
4062 &session.peer,
4063 );
4064
4065 Ok(())
4066}
4067
4068/// Edit the channel notes
4069async fn update_channel_buffer(
4070 request: proto::UpdateChannelBuffer,
4071 session: UserSession,
4072) -> Result<()> {
4073 let db = session.db().await;
4074 let channel_id = ChannelId::from_proto(request.channel_id);
4075
4076 let (collaborators, epoch, version) = db
4077 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
4078 .await?;
4079
4080 channel_buffer_updated(
4081 session.connection_id,
4082 collaborators.clone(),
4083 &proto::UpdateChannelBuffer {
4084 channel_id: channel_id.to_proto(),
4085 operations: request.operations,
4086 },
4087 &session.peer,
4088 );
4089
4090 let pool = &*session.connection_pool().await;
4091
4092 let non_collaborators =
4093 pool.channel_connection_ids(channel_id)
4094 .filter_map(|(connection_id, _)| {
4095 if collaborators.contains(&connection_id) {
4096 None
4097 } else {
4098 Some(connection_id)
4099 }
4100 });
4101
4102 broadcast(None, non_collaborators, |peer_id| {
4103 session.peer.send(
4104 peer_id,
4105 proto::UpdateChannels {
4106 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
4107 channel_id: channel_id.to_proto(),
4108 epoch: epoch as u64,
4109 version: version.clone(),
4110 }],
4111 ..Default::default()
4112 },
4113 )
4114 });
4115
4116 Ok(())
4117}
4118
4119/// Rejoin the channel notes after a connection blip
4120async fn rejoin_channel_buffers(
4121 request: proto::RejoinChannelBuffers,
4122 response: Response<proto::RejoinChannelBuffers>,
4123 session: UserSession,
4124) -> Result<()> {
4125 let db = session.db().await;
4126 let buffers = db
4127 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
4128 .await?;
4129
4130 for rejoined_buffer in &buffers {
4131 let collaborators_to_notify = rejoined_buffer
4132 .buffer
4133 .collaborators
4134 .iter()
4135 .filter_map(|c| Some(c.peer_id?.into()));
4136 channel_buffer_updated(
4137 session.connection_id,
4138 collaborators_to_notify,
4139 &proto::UpdateChannelBufferCollaborators {
4140 channel_id: rejoined_buffer.buffer.channel_id,
4141 collaborators: rejoined_buffer.buffer.collaborators.clone(),
4142 },
4143 &session.peer,
4144 );
4145 }
4146
4147 response.send(proto::RejoinChannelBuffersResponse {
4148 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
4149 })?;
4150
4151 Ok(())
4152}
4153
4154/// Stop editing the channel notes
4155async fn leave_channel_buffer(
4156 request: proto::LeaveChannelBuffer,
4157 response: Response<proto::LeaveChannelBuffer>,
4158 session: UserSession,
4159) -> Result<()> {
4160 let db = session.db().await;
4161 let channel_id = ChannelId::from_proto(request.channel_id);
4162
4163 let left_buffer = db
4164 .leave_channel_buffer(channel_id, session.connection_id)
4165 .await?;
4166
4167 response.send(Ack {})?;
4168
4169 channel_buffer_updated(
4170 session.connection_id,
4171 left_buffer.connections,
4172 &proto::UpdateChannelBufferCollaborators {
4173 channel_id: channel_id.to_proto(),
4174 collaborators: left_buffer.collaborators,
4175 },
4176 &session.peer,
4177 );
4178
4179 Ok(())
4180}
4181
4182fn channel_buffer_updated<T: EnvelopedMessage>(
4183 sender_id: ConnectionId,
4184 collaborators: impl IntoIterator<Item = ConnectionId>,
4185 message: &T,
4186 peer: &Peer,
4187) {
4188 broadcast(Some(sender_id), collaborators, |peer_id| {
4189 peer.send(peer_id, message.clone())
4190 });
4191}
4192
4193fn send_notifications(
4194 connection_pool: &ConnectionPool,
4195 peer: &Peer,
4196 notifications: db::NotificationBatch,
4197) {
4198 for (user_id, notification) in notifications {
4199 for connection_id in connection_pool.user_connection_ids(user_id) {
4200 if let Err(error) = peer.send(
4201 connection_id,
4202 proto::AddNotification {
4203 notification: Some(notification.clone()),
4204 },
4205 ) {
4206 tracing::error!(
4207 "failed to send notification to {:?} {}",
4208 connection_id,
4209 error
4210 );
4211 }
4212 }
4213 }
4214}
4215
4216/// Send a message to the channel
4217async fn send_channel_message(
4218 request: proto::SendChannelMessage,
4219 response: Response<proto::SendChannelMessage>,
4220 session: UserSession,
4221) -> Result<()> {
4222 // Validate the message body.
4223 let body = request.body.trim().to_string();
4224 if body.len() > MAX_MESSAGE_LEN {
4225 return Err(anyhow!("message is too long"))?;
4226 }
4227 if body.is_empty() {
4228 return Err(anyhow!("message can't be blank"))?;
4229 }
4230
4231 // TODO: adjust mentions if body is trimmed
4232
4233 let timestamp = OffsetDateTime::now_utc();
4234 let nonce = request
4235 .nonce
4236 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4237
4238 let channel_id = ChannelId::from_proto(request.channel_id);
4239 let CreatedChannelMessage {
4240 message_id,
4241 participant_connection_ids,
4242 notifications,
4243 } = session
4244 .db()
4245 .await
4246 .create_channel_message(
4247 channel_id,
4248 session.user_id(),
4249 &body,
4250 &request.mentions,
4251 timestamp,
4252 nonce.clone().into(),
4253 request.reply_to_message_id.map(MessageId::from_proto),
4254 )
4255 .await?;
4256
4257 let message = proto::ChannelMessage {
4258 sender_id: session.user_id().to_proto(),
4259 id: message_id.to_proto(),
4260 body,
4261 mentions: request.mentions,
4262 timestamp: timestamp.unix_timestamp() as u64,
4263 nonce: Some(nonce),
4264 reply_to_message_id: request.reply_to_message_id,
4265 edited_at: None,
4266 };
4267 broadcast(
4268 Some(session.connection_id),
4269 participant_connection_ids.clone(),
4270 |connection| {
4271 session.peer.send(
4272 connection,
4273 proto::ChannelMessageSent {
4274 channel_id: channel_id.to_proto(),
4275 message: Some(message.clone()),
4276 },
4277 )
4278 },
4279 );
4280 response.send(proto::SendChannelMessageResponse {
4281 message: Some(message),
4282 })?;
4283
4284 let pool = &*session.connection_pool().await;
4285 let non_participants =
4286 pool.channel_connection_ids(channel_id)
4287 .filter_map(|(connection_id, _)| {
4288 if participant_connection_ids.contains(&connection_id) {
4289 None
4290 } else {
4291 Some(connection_id)
4292 }
4293 });
4294 broadcast(None, non_participants, |peer_id| {
4295 session.peer.send(
4296 peer_id,
4297 proto::UpdateChannels {
4298 latest_channel_message_ids: vec![proto::ChannelMessageId {
4299 channel_id: channel_id.to_proto(),
4300 message_id: message_id.to_proto(),
4301 }],
4302 ..Default::default()
4303 },
4304 )
4305 });
4306 send_notifications(pool, &session.peer, notifications);
4307
4308 Ok(())
4309}
4310
4311/// Delete a channel message
4312async fn remove_channel_message(
4313 request: proto::RemoveChannelMessage,
4314 response: Response<proto::RemoveChannelMessage>,
4315 session: UserSession,
4316) -> Result<()> {
4317 let channel_id = ChannelId::from_proto(request.channel_id);
4318 let message_id = MessageId::from_proto(request.message_id);
4319 let (connection_ids, existing_notification_ids) = session
4320 .db()
4321 .await
4322 .remove_channel_message(channel_id, message_id, session.user_id())
4323 .await?;
4324
4325 broadcast(
4326 Some(session.connection_id),
4327 connection_ids,
4328 move |connection| {
4329 session.peer.send(connection, request.clone())?;
4330
4331 for notification_id in &existing_notification_ids {
4332 session.peer.send(
4333 connection,
4334 proto::DeleteNotification {
4335 notification_id: (*notification_id).to_proto(),
4336 },
4337 )?;
4338 }
4339
4340 Ok(())
4341 },
4342 );
4343 response.send(proto::Ack {})?;
4344 Ok(())
4345}
4346
4347async fn update_channel_message(
4348 request: proto::UpdateChannelMessage,
4349 response: Response<proto::UpdateChannelMessage>,
4350 session: UserSession,
4351) -> Result<()> {
4352 let channel_id = ChannelId::from_proto(request.channel_id);
4353 let message_id = MessageId::from_proto(request.message_id);
4354 let updated_at = OffsetDateTime::now_utc();
4355 let UpdatedChannelMessage {
4356 message_id,
4357 participant_connection_ids,
4358 notifications,
4359 reply_to_message_id,
4360 timestamp,
4361 deleted_mention_notification_ids,
4362 updated_mention_notifications,
4363 } = session
4364 .db()
4365 .await
4366 .update_channel_message(
4367 channel_id,
4368 message_id,
4369 session.user_id(),
4370 request.body.as_str(),
4371 &request.mentions,
4372 updated_at,
4373 )
4374 .await?;
4375
4376 let nonce = request
4377 .nonce
4378 .clone()
4379 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4380
4381 let message = proto::ChannelMessage {
4382 sender_id: session.user_id().to_proto(),
4383 id: message_id.to_proto(),
4384 body: request.body.clone(),
4385 mentions: request.mentions.clone(),
4386 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
4387 nonce: Some(nonce),
4388 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
4389 edited_at: Some(updated_at.unix_timestamp() as u64),
4390 };
4391
4392 response.send(proto::Ack {})?;
4393
4394 let pool = &*session.connection_pool().await;
4395 broadcast(
4396 Some(session.connection_id),
4397 participant_connection_ids,
4398 |connection| {
4399 session.peer.send(
4400 connection,
4401 proto::ChannelMessageUpdate {
4402 channel_id: channel_id.to_proto(),
4403 message: Some(message.clone()),
4404 },
4405 )?;
4406
4407 for notification_id in &deleted_mention_notification_ids {
4408 session.peer.send(
4409 connection,
4410 proto::DeleteNotification {
4411 notification_id: (*notification_id).to_proto(),
4412 },
4413 )?;
4414 }
4415
4416 for notification in &updated_mention_notifications {
4417 session.peer.send(
4418 connection,
4419 proto::UpdateNotification {
4420 notification: Some(notification.clone()),
4421 },
4422 )?;
4423 }
4424
4425 Ok(())
4426 },
4427 );
4428
4429 send_notifications(pool, &session.peer, notifications);
4430
4431 Ok(())
4432}
4433
4434/// Mark a channel message as read
4435async fn acknowledge_channel_message(
4436 request: proto::AckChannelMessage,
4437 session: UserSession,
4438) -> Result<()> {
4439 let channel_id = ChannelId::from_proto(request.channel_id);
4440 let message_id = MessageId::from_proto(request.message_id);
4441 let notifications = session
4442 .db()
4443 .await
4444 .observe_channel_message(channel_id, session.user_id(), message_id)
4445 .await?;
4446 send_notifications(
4447 &*session.connection_pool().await,
4448 &session.peer,
4449 notifications,
4450 );
4451 Ok(())
4452}
4453
4454/// Mark a buffer version as synced
4455async fn acknowledge_buffer_version(
4456 request: proto::AckBufferOperation,
4457 session: UserSession,
4458) -> Result<()> {
4459 let buffer_id = BufferId::from_proto(request.buffer_id);
4460 session
4461 .db()
4462 .await
4463 .observe_buffer_version(
4464 buffer_id,
4465 session.user_id(),
4466 request.epoch as i32,
4467 &request.version,
4468 )
4469 .await?;
4470 Ok(())
4471}
4472
4473async fn count_language_model_tokens(
4474 request: proto::CountLanguageModelTokens,
4475 response: Response<proto::CountLanguageModelTokens>,
4476 session: Session,
4477 config: &Config,
4478) -> Result<()> {
4479 let Some(session) = session.for_user() else {
4480 return Err(anyhow!("user not found"))?;
4481 };
4482 authorize_access_to_legacy_llm_endpoints(&session).await?;
4483
4484 let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
4485 proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
4486 proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
4487 };
4488
4489 session
4490 .app_state
4491 .rate_limiter
4492 .check(&*rate_limit, session.user_id())
4493 .await?;
4494
4495 let result = match proto::LanguageModelProvider::from_i32(request.provider) {
4496 Some(proto::LanguageModelProvider::Google) => {
4497 let api_key = config
4498 .google_ai_api_key
4499 .as_ref()
4500 .context("no Google AI API key configured on the server")?;
4501 google_ai::count_tokens(
4502 session.http_client.as_ref(),
4503 google_ai::API_URL,
4504 api_key,
4505 serde_json::from_str(&request.request)?,
4506 None,
4507 )
4508 .await?
4509 }
4510 _ => return Err(anyhow!("unsupported provider"))?,
4511 };
4512
4513 response.send(proto::CountLanguageModelTokensResponse {
4514 token_count: result.total_tokens as u32,
4515 })?;
4516
4517 Ok(())
4518}
4519
4520struct ZedProCountLanguageModelTokensRateLimit;
4521
4522impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
4523 fn capacity(&self) -> usize {
4524 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
4525 .ok()
4526 .and_then(|v| v.parse().ok())
4527 .unwrap_or(600) // Picked arbitrarily
4528 }
4529
4530 fn refill_duration(&self) -> chrono::Duration {
4531 chrono::Duration::hours(1)
4532 }
4533
4534 fn db_name(&self) -> &'static str {
4535 "zed-pro:count-language-model-tokens"
4536 }
4537}
4538
4539struct FreeCountLanguageModelTokensRateLimit;
4540
4541impl RateLimit for FreeCountLanguageModelTokensRateLimit {
4542 fn capacity(&self) -> usize {
4543 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
4544 .ok()
4545 .and_then(|v| v.parse().ok())
4546 .unwrap_or(600 / 10) // Picked arbitrarily
4547 }
4548
4549 fn refill_duration(&self) -> chrono::Duration {
4550 chrono::Duration::hours(1)
4551 }
4552
4553 fn db_name(&self) -> &'static str {
4554 "free:count-language-model-tokens"
4555 }
4556}
4557
4558struct ZedProComputeEmbeddingsRateLimit;
4559
4560impl RateLimit for ZedProComputeEmbeddingsRateLimit {
4561 fn capacity(&self) -> usize {
4562 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4563 .ok()
4564 .and_then(|v| v.parse().ok())
4565 .unwrap_or(5000) // Picked arbitrarily
4566 }
4567
4568 fn refill_duration(&self) -> chrono::Duration {
4569 chrono::Duration::hours(1)
4570 }
4571
4572 fn db_name(&self) -> &'static str {
4573 "zed-pro:compute-embeddings"
4574 }
4575}
4576
4577struct FreeComputeEmbeddingsRateLimit;
4578
4579impl RateLimit for FreeComputeEmbeddingsRateLimit {
4580 fn capacity(&self) -> usize {
4581 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
4582 .ok()
4583 .and_then(|v| v.parse().ok())
4584 .unwrap_or(5000 / 10) // Picked arbitrarily
4585 }
4586
4587 fn refill_duration(&self) -> chrono::Duration {
4588 chrono::Duration::hours(1)
4589 }
4590
4591 fn db_name(&self) -> &'static str {
4592 "free:compute-embeddings"
4593 }
4594}
4595
4596async fn compute_embeddings(
4597 request: proto::ComputeEmbeddings,
4598 response: Response<proto::ComputeEmbeddings>,
4599 session: UserSession,
4600 api_key: Option<Arc<str>>,
4601) -> Result<()> {
4602 let api_key = api_key.context("no OpenAI API key configured on the server")?;
4603 authorize_access_to_legacy_llm_endpoints(&session).await?;
4604
4605 let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
4606 proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
4607 proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
4608 };
4609
4610 session
4611 .app_state
4612 .rate_limiter
4613 .check(&*rate_limit, session.user_id())
4614 .await?;
4615
4616 let embeddings = match request.model.as_str() {
4617 "openai/text-embedding-3-small" => {
4618 open_ai::embed(
4619 session.http_client.as_ref(),
4620 OPEN_AI_API_URL,
4621 &api_key,
4622 OpenAiEmbeddingModel::TextEmbedding3Small,
4623 request.texts.iter().map(|text| text.as_str()),
4624 )
4625 .await?
4626 }
4627 provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
4628 };
4629
4630 let embeddings = request
4631 .texts
4632 .iter()
4633 .map(|text| {
4634 let mut hasher = sha2::Sha256::new();
4635 hasher.update(text.as_bytes());
4636 let result = hasher.finalize();
4637 result.to_vec()
4638 })
4639 .zip(
4640 embeddings
4641 .data
4642 .into_iter()
4643 .map(|embedding| embedding.embedding),
4644 )
4645 .collect::<HashMap<_, _>>();
4646
4647 let db = session.db().await;
4648 db.save_embeddings(&request.model, &embeddings)
4649 .await
4650 .context("failed to save embeddings")
4651 .trace_err();
4652
4653 response.send(proto::ComputeEmbeddingsResponse {
4654 embeddings: embeddings
4655 .into_iter()
4656 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4657 .collect(),
4658 })?;
4659 Ok(())
4660}
4661
4662async fn get_cached_embeddings(
4663 request: proto::GetCachedEmbeddings,
4664 response: Response<proto::GetCachedEmbeddings>,
4665 session: UserSession,
4666) -> Result<()> {
4667 authorize_access_to_legacy_llm_endpoints(&session).await?;
4668
4669 let db = session.db().await;
4670 let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
4671
4672 response.send(proto::GetCachedEmbeddingsResponse {
4673 embeddings: embeddings
4674 .into_iter()
4675 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4676 .collect(),
4677 })?;
4678 Ok(())
4679}
4680
4681/// This is leftover from before the LLM service.
4682///
4683/// The endpoints protected by this check will be moved there eventually.
4684async fn authorize_access_to_legacy_llm_endpoints(session: &UserSession) -> Result<(), Error> {
4685 if session.is_staff() {
4686 Ok(())
4687 } else {
4688 Err(anyhow!("permission denied"))?
4689 }
4690}
4691
4692/// Get a Supermaven API key for the user
4693async fn get_supermaven_api_key(
4694 _request: proto::GetSupermavenApiKey,
4695 response: Response<proto::GetSupermavenApiKey>,
4696 session: UserSession,
4697) -> Result<()> {
4698 let user_id: String = session.user_id().to_string();
4699 if !session.is_staff() {
4700 return Err(anyhow!("supermaven not enabled for this account"))?;
4701 }
4702
4703 let email = session
4704 .email()
4705 .ok_or_else(|| anyhow!("user must have an email"))?;
4706
4707 let supermaven_admin_api = session
4708 .supermaven_client
4709 .as_ref()
4710 .ok_or_else(|| anyhow!("supermaven not configured"))?;
4711
4712 let result = supermaven_admin_api
4713 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4714 .await?;
4715
4716 response.send(proto::GetSupermavenApiKeyResponse {
4717 api_key: result.api_key,
4718 })?;
4719
4720 Ok(())
4721}
4722
4723/// Start receiving chat updates for a channel
4724async fn join_channel_chat(
4725 request: proto::JoinChannelChat,
4726 response: Response<proto::JoinChannelChat>,
4727 session: UserSession,
4728) -> Result<()> {
4729 let channel_id = ChannelId::from_proto(request.channel_id);
4730
4731 let db = session.db().await;
4732 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4733 .await?;
4734 let messages = db
4735 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4736 .await?;
4737 response.send(proto::JoinChannelChatResponse {
4738 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4739 messages,
4740 })?;
4741 Ok(())
4742}
4743
4744/// Stop receiving chat updates for a channel
4745async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
4746 let channel_id = ChannelId::from_proto(request.channel_id);
4747 session
4748 .db()
4749 .await
4750 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4751 .await?;
4752 Ok(())
4753}
4754
4755/// Retrieve the chat history for a channel
4756async fn get_channel_messages(
4757 request: proto::GetChannelMessages,
4758 response: Response<proto::GetChannelMessages>,
4759 session: UserSession,
4760) -> Result<()> {
4761 let channel_id = ChannelId::from_proto(request.channel_id);
4762 let messages = session
4763 .db()
4764 .await
4765 .get_channel_messages(
4766 channel_id,
4767 session.user_id(),
4768 MESSAGE_COUNT_PER_PAGE,
4769 Some(MessageId::from_proto(request.before_message_id)),
4770 )
4771 .await?;
4772 response.send(proto::GetChannelMessagesResponse {
4773 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4774 messages,
4775 })?;
4776 Ok(())
4777}
4778
4779/// Retrieve specific chat messages
4780async fn get_channel_messages_by_id(
4781 request: proto::GetChannelMessagesById,
4782 response: Response<proto::GetChannelMessagesById>,
4783 session: UserSession,
4784) -> Result<()> {
4785 let message_ids = request
4786 .message_ids
4787 .iter()
4788 .map(|id| MessageId::from_proto(*id))
4789 .collect::<Vec<_>>();
4790 let messages = session
4791 .db()
4792 .await
4793 .get_channel_messages_by_id(session.user_id(), &message_ids)
4794 .await?;
4795 response.send(proto::GetChannelMessagesResponse {
4796 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4797 messages,
4798 })?;
4799 Ok(())
4800}
4801
4802/// Retrieve the current users notifications
4803async fn get_notifications(
4804 request: proto::GetNotifications,
4805 response: Response<proto::GetNotifications>,
4806 session: UserSession,
4807) -> Result<()> {
4808 let notifications = session
4809 .db()
4810 .await
4811 .get_notifications(
4812 session.user_id(),
4813 NOTIFICATION_COUNT_PER_PAGE,
4814 request.before_id.map(db::NotificationId::from_proto),
4815 )
4816 .await?;
4817 response.send(proto::GetNotificationsResponse {
4818 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4819 notifications,
4820 })?;
4821 Ok(())
4822}
4823
4824/// Mark notifications as read
4825async fn mark_notification_as_read(
4826 request: proto::MarkNotificationRead,
4827 response: Response<proto::MarkNotificationRead>,
4828 session: UserSession,
4829) -> Result<()> {
4830 let database = &session.db().await;
4831 let notifications = database
4832 .mark_notification_as_read_by_id(
4833 session.user_id(),
4834 NotificationId::from_proto(request.notification_id),
4835 )
4836 .await?;
4837 send_notifications(
4838 &*session.connection_pool().await,
4839 &session.peer,
4840 notifications,
4841 );
4842 response.send(proto::Ack {})?;
4843 Ok(())
4844}
4845
4846/// Get the current users information
4847async fn get_private_user_info(
4848 _request: proto::GetPrivateUserInfo,
4849 response: Response<proto::GetPrivateUserInfo>,
4850 session: UserSession,
4851) -> Result<()> {
4852 let db = session.db().await;
4853
4854 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4855 let user = db
4856 .get_user_by_id(session.user_id())
4857 .await?
4858 .ok_or_else(|| anyhow!("user not found"))?;
4859 let flags = db.get_user_flags(session.user_id()).await?;
4860
4861 response.send(proto::GetPrivateUserInfoResponse {
4862 metrics_id,
4863 staff: user.admin,
4864 flags,
4865 accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4866 })?;
4867 Ok(())
4868}
4869
4870/// Accept the terms of service (tos) on behalf of the current user
4871async fn accept_terms_of_service(
4872 _request: proto::AcceptTermsOfService,
4873 response: Response<proto::AcceptTermsOfService>,
4874 session: UserSession,
4875) -> Result<()> {
4876 let db = session.db().await;
4877
4878 let accepted_tos_at = Utc::now();
4879 db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4880 .await?;
4881
4882 response.send(proto::AcceptTermsOfServiceResponse {
4883 accepted_tos_at: accepted_tos_at.timestamp() as u64,
4884 })?;
4885 Ok(())
4886}
4887
4888/// The minimum account age an account must have in order to use the LLM service.
4889const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4890
4891async fn get_llm_api_token(
4892 _request: proto::GetLlmToken,
4893 response: Response<proto::GetLlmToken>,
4894 session: UserSession,
4895) -> Result<()> {
4896 let db = session.db().await;
4897
4898 let flags = db.get_user_flags(session.user_id()).await?;
4899 let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
4900 let has_llm_closed_beta_feature_flag = flags.iter().any(|flag| flag == "llm-closed-beta");
4901
4902 if !session.is_staff() && !has_language_models_feature_flag {
4903 Err(anyhow!("permission denied"))?
4904 }
4905
4906 let user_id = session.user_id();
4907 let user = db
4908 .get_user_by_id(user_id)
4909 .await?
4910 .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4911
4912 if user.accepted_tos_at.is_none() {
4913 Err(anyhow!("terms of service not accepted"))?
4914 }
4915
4916 let mut account_created_at = user.created_at;
4917 if let Some(github_created_at) = user.github_user_created_at {
4918 account_created_at = account_created_at.min(github_created_at);
4919 }
4920 if Utc::now().naive_utc() - account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE {
4921 Err(anyhow!("account too young"))?
4922 }
4923 let token = LlmTokenClaims::create(
4924 user.id,
4925 user.github_login.clone(),
4926 session.is_staff(),
4927 has_llm_closed_beta_feature_flag,
4928 session.has_llm_subscription(&db).await?,
4929 session.current_plan(&db).await?,
4930 &session.app_state.config,
4931 )?;
4932 response.send(proto::GetLlmTokenResponse { token })?;
4933 Ok(())
4934}
4935
4936fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4937 let message = match message {
4938 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4939 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4940 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4941 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4942 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4943 code: frame.code.into(),
4944 reason: frame.reason,
4945 })),
4946 // We should never receive a frame while reading the message, according
4947 // to the `tungstenite` maintainers:
4948 //
4949 // > It cannot occur when you read messages from the WebSocket, but it
4950 // > can be used when you want to send the raw frames (e.g. you want to
4951 // > send the frames to the WebSocket without composing the full message first).
4952 // >
4953 // > — https://github.com/snapview/tungstenite-rs/issues/268
4954 TungsteniteMessage::Frame(_) => {
4955 bail!("received an unexpected frame while reading the message")
4956 }
4957 };
4958
4959 Ok(message)
4960}
4961
4962fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4963 match message {
4964 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4965 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4966 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4967 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4968 AxumMessage::Close(frame) => {
4969 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4970 code: frame.code.into(),
4971 reason: frame.reason,
4972 }))
4973 }
4974 }
4975}
4976
4977fn notify_membership_updated(
4978 connection_pool: &mut ConnectionPool,
4979 result: MembershipUpdated,
4980 user_id: UserId,
4981 peer: &Peer,
4982) {
4983 for membership in &result.new_channels.channel_memberships {
4984 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4985 }
4986 for channel_id in &result.removed_channels {
4987 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4988 }
4989
4990 let user_channels_update = proto::UpdateUserChannels {
4991 channel_memberships: result
4992 .new_channels
4993 .channel_memberships
4994 .iter()
4995 .map(|cm| proto::ChannelMembership {
4996 channel_id: cm.channel_id.to_proto(),
4997 role: cm.role.into(),
4998 })
4999 .collect(),
5000 ..Default::default()
5001 };
5002
5003 let mut update = build_channels_update(result.new_channels);
5004 update.delete_channels = result
5005 .removed_channels
5006 .into_iter()
5007 .map(|id| id.to_proto())
5008 .collect();
5009 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
5010
5011 for connection_id in connection_pool.user_connection_ids(user_id) {
5012 peer.send(connection_id, user_channels_update.clone())
5013 .trace_err();
5014 peer.send(connection_id, update.clone()).trace_err();
5015 }
5016}
5017
5018fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
5019 proto::UpdateUserChannels {
5020 channel_memberships: channels
5021 .channel_memberships
5022 .iter()
5023 .map(|m| proto::ChannelMembership {
5024 channel_id: m.channel_id.to_proto(),
5025 role: m.role.into(),
5026 })
5027 .collect(),
5028 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
5029 observed_channel_message_id: channels.observed_channel_messages.clone(),
5030 }
5031}
5032
5033fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
5034 let mut update = proto::UpdateChannels::default();
5035
5036 for channel in channels.channels {
5037 update.channels.push(channel.to_proto());
5038 }
5039
5040 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
5041 update.latest_channel_message_ids = channels.latest_channel_messages;
5042
5043 for (channel_id, participants) in channels.channel_participants {
5044 update
5045 .channel_participants
5046 .push(proto::ChannelParticipants {
5047 channel_id: channel_id.to_proto(),
5048 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
5049 });
5050 }
5051
5052 for channel in channels.invited_channels {
5053 update.channel_invitations.push(channel.to_proto());
5054 }
5055
5056 update.hosted_projects = channels.hosted_projects;
5057 update
5058}
5059
5060fn build_initial_contacts_update(
5061 contacts: Vec<db::Contact>,
5062 pool: &ConnectionPool,
5063) -> proto::UpdateContacts {
5064 let mut update = proto::UpdateContacts::default();
5065
5066 for contact in contacts {
5067 match contact {
5068 db::Contact::Accepted { user_id, busy } => {
5069 update.contacts.push(contact_for_user(user_id, busy, pool));
5070 }
5071 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
5072 db::Contact::Incoming { user_id } => {
5073 update
5074 .incoming_requests
5075 .push(proto::IncomingContactRequest {
5076 requester_id: user_id.to_proto(),
5077 })
5078 }
5079 }
5080 }
5081
5082 update
5083}
5084
5085fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
5086 proto::Contact {
5087 user_id: user_id.to_proto(),
5088 online: pool.is_user_online(user_id),
5089 busy,
5090 }
5091}
5092
5093fn room_updated(room: &proto::Room, peer: &Peer) {
5094 broadcast(
5095 None,
5096 room.participants
5097 .iter()
5098 .filter_map(|participant| Some(participant.peer_id?.into())),
5099 |peer_id| {
5100 peer.send(
5101 peer_id,
5102 proto::RoomUpdated {
5103 room: Some(room.clone()),
5104 },
5105 )
5106 },
5107 );
5108}
5109
5110fn channel_updated(
5111 channel: &db::channel::Model,
5112 room: &proto::Room,
5113 peer: &Peer,
5114 pool: &ConnectionPool,
5115) {
5116 let participants = room
5117 .participants
5118 .iter()
5119 .map(|p| p.user_id)
5120 .collect::<Vec<_>>();
5121
5122 broadcast(
5123 None,
5124 pool.channel_connection_ids(channel.root_id())
5125 .filter_map(|(channel_id, role)| {
5126 role.can_see_channel(channel.visibility)
5127 .then_some(channel_id)
5128 }),
5129 |peer_id| {
5130 peer.send(
5131 peer_id,
5132 proto::UpdateChannels {
5133 channel_participants: vec![proto::ChannelParticipants {
5134 channel_id: channel.id.to_proto(),
5135 participant_user_ids: participants.clone(),
5136 }],
5137 ..Default::default()
5138 },
5139 )
5140 },
5141 );
5142}
5143
5144async fn send_dev_server_projects_update(
5145 user_id: UserId,
5146 mut status: proto::DevServerProjectsUpdate,
5147 session: &Session,
5148) {
5149 let pool = session.connection_pool().await;
5150 for dev_server in &mut status.dev_servers {
5151 dev_server.status =
5152 pool.dev_server_status(DevServerId(dev_server.dev_server_id as i32)) as i32;
5153 }
5154 let connections = pool.user_connection_ids(user_id);
5155 for connection_id in connections {
5156 session.peer.send(connection_id, status.clone()).trace_err();
5157 }
5158}
5159
5160async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
5161 let db = session.db().await;
5162
5163 let contacts = db.get_contacts(user_id).await?;
5164 let busy = db.is_user_busy(user_id).await?;
5165
5166 let pool = session.connection_pool().await;
5167 let updated_contact = contact_for_user(user_id, busy, &pool);
5168 for contact in contacts {
5169 if let db::Contact::Accepted {
5170 user_id: contact_user_id,
5171 ..
5172 } = contact
5173 {
5174 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
5175 session
5176 .peer
5177 .send(
5178 contact_conn_id,
5179 proto::UpdateContacts {
5180 contacts: vec![updated_contact.clone()],
5181 remove_contacts: Default::default(),
5182 incoming_requests: Default::default(),
5183 remove_incoming_requests: Default::default(),
5184 outgoing_requests: Default::default(),
5185 remove_outgoing_requests: Default::default(),
5186 },
5187 )
5188 .trace_err();
5189 }
5190 }
5191 }
5192 Ok(())
5193}
5194
5195async fn lost_dev_server_connection(session: &DevServerSession) -> Result<()> {
5196 log::info!("lost dev server connection, unsharing projects");
5197 let project_ids = session
5198 .db()
5199 .await
5200 .get_stale_dev_server_projects(session.connection_id)
5201 .await?;
5202
5203 for project_id in project_ids {
5204 // not unshare re-checks the connection ids match, so we get away with no transaction
5205 unshare_project_internal(project_id, session.connection_id, None, session).await?;
5206 }
5207
5208 let user_id = session.dev_server().user_id;
5209 let update = session
5210 .db()
5211 .await
5212 .dev_server_projects_update(user_id)
5213 .await?;
5214
5215 send_dev_server_projects_update(user_id, update, session).await;
5216
5217 Ok(())
5218}
5219
5220async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
5221 let mut contacts_to_update = HashSet::default();
5222
5223 let room_id;
5224 let canceled_calls_to_user_ids;
5225 let live_kit_room;
5226 let delete_live_kit_room;
5227 let room;
5228 let channel;
5229
5230 if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
5231 contacts_to_update.insert(session.user_id());
5232
5233 for project in left_room.left_projects.values() {
5234 project_left(project, session);
5235 }
5236
5237 room_id = RoomId::from_proto(left_room.room.id);
5238 canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
5239 live_kit_room = mem::take(&mut left_room.room.live_kit_room);
5240 delete_live_kit_room = left_room.deleted;
5241 room = mem::take(&mut left_room.room);
5242 channel = mem::take(&mut left_room.channel);
5243
5244 room_updated(&room, &session.peer);
5245 } else {
5246 return Ok(());
5247 }
5248
5249 if let Some(channel) = channel {
5250 channel_updated(
5251 &channel,
5252 &room,
5253 &session.peer,
5254 &*session.connection_pool().await,
5255 );
5256 }
5257
5258 {
5259 let pool = session.connection_pool().await;
5260 for canceled_user_id in canceled_calls_to_user_ids {
5261 for connection_id in pool.user_connection_ids(canceled_user_id) {
5262 session
5263 .peer
5264 .send(
5265 connection_id,
5266 proto::CallCanceled {
5267 room_id: room_id.to_proto(),
5268 },
5269 )
5270 .trace_err();
5271 }
5272 contacts_to_update.insert(canceled_user_id);
5273 }
5274 }
5275
5276 for contact_user_id in contacts_to_update {
5277 update_user_contacts(contact_user_id, session).await?;
5278 }
5279
5280 if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
5281 live_kit
5282 .remove_participant(live_kit_room.clone(), session.user_id().to_string())
5283 .await
5284 .trace_err();
5285
5286 if delete_live_kit_room {
5287 live_kit.delete_room(live_kit_room).await.trace_err();
5288 }
5289 }
5290
5291 Ok(())
5292}
5293
5294async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
5295 let left_channel_buffers = session
5296 .db()
5297 .await
5298 .leave_channel_buffers(session.connection_id)
5299 .await?;
5300
5301 for left_buffer in left_channel_buffers {
5302 channel_buffer_updated(
5303 session.connection_id,
5304 left_buffer.connections,
5305 &proto::UpdateChannelBufferCollaborators {
5306 channel_id: left_buffer.channel_id.to_proto(),
5307 collaborators: left_buffer.collaborators,
5308 },
5309 &session.peer,
5310 );
5311 }
5312
5313 Ok(())
5314}
5315
5316fn project_left(project: &db::LeftProject, session: &UserSession) {
5317 for connection_id in &project.connection_ids {
5318 if project.should_unshare {
5319 session
5320 .peer
5321 .send(
5322 *connection_id,
5323 proto::UnshareProject {
5324 project_id: project.id.to_proto(),
5325 },
5326 )
5327 .trace_err();
5328 } else {
5329 session
5330 .peer
5331 .send(
5332 *connection_id,
5333 proto::RemoveProjectCollaborator {
5334 project_id: project.id.to_proto(),
5335 peer_id: Some(session.connection_id.into()),
5336 },
5337 )
5338 .trace_err();
5339 }
5340 }
5341}
5342
5343pub trait ResultExt {
5344 type Ok;
5345
5346 fn trace_err(self) -> Option<Self::Ok>;
5347}
5348
5349impl<T, E> ResultExt for Result<T, E>
5350where
5351 E: std::fmt::Debug,
5352{
5353 type Ok = T;
5354
5355 #[track_caller]
5356 fn trace_err(self) -> Option<T> {
5357 match self {
5358 Ok(value) => Some(value),
5359 Err(error) => {
5360 tracing::error!("{:?}", error);
5361 None
5362 }
5363 }
5364 }
5365}