1mod connection_pool;
2
3use crate::api::CloudflareIpCountryHeader;
4use crate::llm::LlmTokenClaims;
5use crate::{
6 auth,
7 db::{
8 self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
9 CreatedChannelMessage, Database, DevServerId, DevServerProjectId, InviteMemberResult,
10 MembershipUpdated, MessageId, NotificationId, PrincipalId, Project, ProjectId,
11 RejoinedProject, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
12 ServerId, UpdatedChannelMessage, User, UserId,
13 },
14 executor::Executor,
15 AppState, Config, Error, RateLimit, Result,
16};
17use anyhow::{anyhow, bail, Context as _};
18use async_tungstenite::tungstenite::{
19 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
20};
21use axum::{
22 body::Body,
23 extract::{
24 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
25 ConnectInfo, WebSocketUpgrade,
26 },
27 headers::{Header, HeaderName},
28 http::StatusCode,
29 middleware,
30 response::IntoResponse,
31 routing::get,
32 Extension, Router, TypedHeader,
33};
34use chrono::Utc;
35use collections::{HashMap, HashSet};
36pub use connection_pool::{ConnectionPool, ZedVersion};
37use core::fmt::{self, Debug, Formatter};
38use http_client::HttpClient;
39use isahc_http_client::IsahcHttpClient;
40use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
41use sha2::Digest;
42use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
43
44use futures::{
45 channel::oneshot,
46 future::{self, BoxFuture},
47 stream::FuturesUnordered,
48 FutureExt, SinkExt, StreamExt, TryStreamExt,
49};
50use prometheus::{register_int_gauge, IntGauge};
51use rpc::{
52 proto::{
53 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
54 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
55 },
56 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
57};
58use semantic_version::SemanticVersion;
59use serde::{Serialize, Serializer};
60use std::{
61 any::TypeId,
62 future::Future,
63 marker::PhantomData,
64 mem,
65 net::SocketAddr,
66 ops::{Deref, DerefMut},
67 rc::Rc,
68 sync::{
69 atomic::{AtomicBool, Ordering::SeqCst},
70 Arc, OnceLock,
71 },
72 time::{Duration, Instant},
73};
74use time::OffsetDateTime;
75use tokio::sync::{watch, MutexGuard, Semaphore};
76use tower::ServiceBuilder;
77use tracing::{
78 field::{self},
79 info_span, instrument, Instrument,
80};
81
// ---- Timeouts -------------------------------------------------------------

/// How long a disconnected client is given to reconnect.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// Delay before cleaning up resources left behind by previous server instances.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully; waiting a bit longer
/// than that means the old pods are gone before their stale resources are removed.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// ---- Sizes and page lengths -----------------------------------------------

const MAX_MESSAGE_LEN: usize = 1024;
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
90
91type MessageHandler =
92 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
93
94struct Response<R> {
95 peer: Arc<Peer>,
96 receipt: Receipt<R>,
97 responded: Arc<AtomicBool>,
98}
99
100impl<R: RequestMessage> Response<R> {
101 fn send(self, payload: R::Response) -> Result<()> {
102 self.responded.store(true, SeqCst);
103 self.peer.respond(self.receipt, payload)?;
104 Ok(())
105 }
106}
107
108#[derive(Clone, Debug)]
109pub enum Principal {
110 User(User),
111 Impersonated { user: User, admin: User },
112 DevServer(dev_server::Model),
113}
114
115impl Principal {
116 fn update_span(&self, span: &tracing::Span) {
117 match &self {
118 Principal::User(user) => {
119 span.record("user_id", user.id.0);
120 span.record("login", &user.github_login);
121 }
122 Principal::Impersonated { user, admin } => {
123 span.record("user_id", user.id.0);
124 span.record("login", &user.github_login);
125 span.record("impersonator", &admin.github_login);
126 }
127 Principal::DevServer(dev_server) => {
128 span.record("dev_server_id", dev_server.id.0);
129 }
130 }
131 }
132}
133
/// Per-connection state handed (by clone) to every message handler.
#[derive(Clone)]
struct Session {
    /// Who is on the other end of this connection.
    principal: Principal,
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; acquire it through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Shared with the owning `Server`; lock it through `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Only built when a Supermaven admin API key is configured.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    http_client: Arc<dyn HttpClient>,
    /// The GeoIP country code for the user.
    // NOTE(review): no reader of this field is visible in this chunk, hence the allow.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    _executor: Executor,
}
149
150impl Session {
151 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
152 #[cfg(test)]
153 tokio::task::yield_now().await;
154 let guard = self.db.lock().await;
155 #[cfg(test)]
156 tokio::task::yield_now().await;
157 guard
158 }
159
160 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
161 #[cfg(test)]
162 tokio::task::yield_now().await;
163 let guard = self.connection_pool.lock();
164 ConnectionPoolGuard {
165 guard,
166 _not_send: PhantomData,
167 }
168 }
169
170 fn for_user(self) -> Option<UserSession> {
171 UserSession::new(self)
172 }
173
174 fn for_dev_server(self) -> Option<DevServerSession> {
175 DevServerSession::new(self)
176 }
177
178 fn user_id(&self) -> Option<UserId> {
179 match &self.principal {
180 Principal::User(user) => Some(user.id),
181 Principal::Impersonated { user, .. } => Some(user.id),
182 Principal::DevServer(_) => None,
183 }
184 }
185
186 fn is_staff(&self) -> bool {
187 match &self.principal {
188 Principal::User(user) => user.admin,
189 Principal::Impersonated { .. } => true,
190 Principal::DevServer(_) => false,
191 }
192 }
193
194 pub async fn has_llm_subscription(
195 &self,
196 db: &MutexGuard<'_, DbHandle>,
197 ) -> anyhow::Result<bool> {
198 if self.is_staff() {
199 return Ok(true);
200 }
201
202 let Some(user_id) = self.user_id() else {
203 return Ok(false);
204 };
205
206 Ok(db.has_active_billing_subscription(user_id).await?)
207 }
208
209 pub async fn current_plan(
210 &self,
211 _db: &MutexGuard<'_, DbHandle>,
212 ) -> anyhow::Result<proto::Plan> {
213 if self.is_staff() {
214 Ok(proto::Plan::ZedPro)
215 } else {
216 Ok(proto::Plan::Free)
217 }
218 }
219
220 fn dev_server_id(&self) -> Option<DevServerId> {
221 match &self.principal {
222 Principal::User(_) | Principal::Impersonated { .. } => None,
223 Principal::DevServer(dev_server) => Some(dev_server.id),
224 }
225 }
226
227 fn principal_id(&self) -> PrincipalId {
228 match &self.principal {
229 Principal::User(user) => PrincipalId::UserId(user.id),
230 Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
231 Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
232 }
233 }
234}
235
236impl Debug for Session {
237 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
238 let mut result = f.debug_struct("Session");
239 match &self.principal {
240 Principal::User(user) => {
241 result.field("user", &user.github_login);
242 }
243 Principal::Impersonated { user, admin } => {
244 result.field("user", &user.github_login);
245 result.field("impersonator", &admin.github_login);
246 }
247 Principal::DevServer(dev_server) => {
248 result.field("dev_server", &dev_server.id);
249 }
250 }
251 result.field("connection_id", &self.connection_id).finish()
252 }
253}
254
255struct UserSession(Session);
256
257impl UserSession {
258 pub fn new(s: Session) -> Option<Self> {
259 s.user_id().map(|_| UserSession(s))
260 }
261 pub fn user_id(&self) -> UserId {
262 self.0.user_id().unwrap()
263 }
264
265 pub fn email(&self) -> Option<String> {
266 match &self.0.principal {
267 Principal::User(user) => user.email_address.clone(),
268 Principal::Impersonated { user, .. } => user.email_address.clone(),
269 Principal::DevServer(..) => None,
270 }
271 }
272}
273
274impl Deref for UserSession {
275 type Target = Session;
276
277 fn deref(&self) -> &Self::Target {
278 &self.0
279 }
280}
281impl DerefMut for UserSession {
282 fn deref_mut(&mut self) -> &mut Self::Target {
283 &mut self.0
284 }
285}
286
287struct DevServerSession(Session);
288
289impl DevServerSession {
290 pub fn new(s: Session) -> Option<Self> {
291 s.dev_server_id().map(|_| DevServerSession(s))
292 }
293 pub fn dev_server_id(&self) -> DevServerId {
294 self.0.dev_server_id().unwrap()
295 }
296
297 fn dev_server(&self) -> &dev_server::Model {
298 match &self.0.principal {
299 Principal::DevServer(dev_server) => dev_server,
300 _ => unreachable!(),
301 }
302 }
303}
304
305impl Deref for DevServerSession {
306 type Target = Session;
307
308 fn deref(&self) -> &Self::Target {
309 &self.0
310 }
311}
312impl DerefMut for DevServerSession {
313 fn deref_mut(&mut self) -> &mut Self::Target {
314 &mut self.0
315 }
316}
317
318fn user_handler<M: RequestMessage, Fut>(
319 handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
320) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
321where
322 Fut: Send + Future<Output = Result<()>>,
323{
324 let handler = Arc::new(handler);
325 move |message, response, session| {
326 let handler = handler.clone();
327 Box::pin(async move {
328 if let Some(user_session) = session.for_user() {
329 Ok(handler(message, response, user_session).await?)
330 } else {
331 Err(Error::Internal(anyhow!(
332 "must be a user to call {}",
333 M::NAME
334 )))
335 }
336 })
337 }
338}
339
340fn dev_server_handler<M: RequestMessage, Fut>(
341 handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
342) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
343where
344 Fut: Send + Future<Output = Result<()>>,
345{
346 let handler = Arc::new(handler);
347 move |message, response, session| {
348 let handler = handler.clone();
349 Box::pin(async move {
350 if let Some(dev_server_session) = session.for_dev_server() {
351 Ok(handler(message, response, dev_server_session).await?)
352 } else {
353 Err(Error::Internal(anyhow!(
354 "must be a dev server to call {}",
355 M::NAME
356 )))
357 }
358 })
359 }
360}
361
362fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
363 handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
364) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
365where
366 InnertRetFut: Send + Future<Output = Result<()>>,
367{
368 let handler = Arc::new(handler);
369 move |message, session| {
370 let handler = handler.clone();
371 Box::pin(async move {
372 if let Some(user_session) = session.for_user() {
373 Ok(handler(message, user_session).await?)
374 } else {
375 Err(Error::Internal(anyhow!(
376 "must be a user to call {}",
377 M::NAME
378 )))
379 }
380 })
381 }
382}
383
384struct DbHandle(Arc<Database>);
385
386impl Deref for DbHandle {
387 type Target = Database;
388
389 fn deref(&self) -> &Self::Target {
390 self.0.as_ref()
391 }
392}
393
/// The collaboration RPC server: owns the peer, the pool of live connections, and the
/// table of message handlers registered in [`Server::new`].
pub struct Server {
    /// Read when starting cleanup; replaced by the test-only `reset`.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Handlers keyed by the `TypeId` of the message type each one accepts.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Set to `true` by `teardown`; connection loops subscribe to it and return when it
    /// changes.
    teardown: watch::Sender<bool>,
}
402
/// Lock guard over the shared [`ConnectionPool`], produced by `Session::connection_pool`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    /// `Rc<()>` is `!Send`, which makes this guard `!Send` too, so it cannot be held
    /// across an `.await` inside a future that must be `Send`.
    _not_send: PhantomData<Rc<()>>,
}
407
/// Serializable snapshot of the server's peer and connection pool.
///
/// Fields serialize in declaration order.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    /// Keeps the pool locked for as long as the snapshot lives; serialized through the
    /// guard's `Deref` target.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
414
415pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
416where
417 S: Serializer,
418 T: Deref<Target = U>,
419 U: Serialize,
420{
421 Serialize::serialize(value.deref(), serializer)
422}
423
424impl Server {
    /// Creates a server with the given `id` and registers every RPC handler it serves.
    ///
    /// Each message type may be registered at most once: `add_handler` panics on a
    /// duplicate, so every entry in the chain below must name a distinct message type.
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            // NOTE(review): `as u32` silently truncates if `ServerId`'s inner integer is wider
            // than u32 — confirm ids stay in range.
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            // Only the sender is kept; receivers are created later via `subscribe()`.
            teardown: watch::channel(false).0,
        };

        // `user_handler` / `user_message_handler` reject non-user sessions and
        // `dev_server_handler` rejects non-dev-server sessions; handlers passed without a
        // wrapper receive the raw `Session`.
        server
            .add_request_handler(ping)
            .add_request_handler(user_handler(create_room))
            .add_request_handler(user_handler(join_room))
            .add_request_handler(user_handler(rejoin_room))
            .add_request_handler(user_handler(leave_room))
            .add_request_handler(user_handler(set_room_participant_role))
            .add_request_handler(user_handler(call))
            .add_request_handler(user_handler(cancel_call))
            .add_message_handler(user_message_handler(decline_call))
            .add_request_handler(user_handler(update_participant_location))
            .add_request_handler(user_handler(share_project))
            .add_message_handler(unshare_project)
            .add_request_handler(user_handler(join_project))
            .add_request_handler(user_handler(join_hosted_project))
            .add_request_handler(user_handler(rejoin_dev_server_projects))
            .add_request_handler(user_handler(create_dev_server_project))
            .add_request_handler(user_handler(update_dev_server_project))
            .add_request_handler(user_handler(delete_dev_server_project))
            .add_request_handler(user_handler(create_dev_server))
            .add_request_handler(user_handler(regenerate_dev_server_token))
            .add_request_handler(user_handler(rename_dev_server))
            .add_request_handler(user_handler(delete_dev_server))
            .add_request_handler(user_handler(list_remote_directory))
            .add_request_handler(dev_server_handler(share_dev_server_project))
            .add_request_handler(dev_server_handler(shutdown_dev_server))
            .add_request_handler(dev_server_handler(reconnect_dev_server))
            .add_message_handler(user_message_handler(leave_project))
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Project requests forwarded to the project's host.
            .add_request_handler(user_handler(
                forward_project_request_for_owner::<proto::TaskContextForLocation>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetHover>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetTypeDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetReferences>,
            ))
            .add_request_handler(user_handler(forward_find_search_candidates_request))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDocumentHighlights>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetProjectSymbols>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferForSymbol>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferById>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::SynchronizeBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::InlayHints>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::ResolveInlayHint>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferByPath>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCompletions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OpenNewBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCodeActions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCodeAction>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PrepareRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PerformRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ReloadBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::FormatBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CreateProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RenameProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CopyProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::DeleteProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ExpandProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OnTypeFormatting>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::SaveBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::BlameBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::MultiLspQuery>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RestartLanguageServers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::LinkedEditingRange>,
            ))
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            // Host-originated notifications rebroadcast to the project's guests.
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            // Users and contacts.
            .add_request_handler(get_users)
            .add_request_handler(user_handler(fuzzy_search_users))
            .add_request_handler(user_handler(request_contact))
            .add_request_handler(user_handler(remove_contact))
            .add_request_handler(user_handler(respond_to_contact_request))
            // Channels, channel buffers, and channel chat.
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(user_handler(create_channel))
            .add_request_handler(user_handler(delete_channel))
            .add_request_handler(user_handler(invite_channel_member))
            .add_request_handler(user_handler(remove_channel_member))
            .add_request_handler(user_handler(set_channel_member_role))
            .add_request_handler(user_handler(set_channel_visibility))
            .add_request_handler(user_handler(rename_channel))
            .add_request_handler(user_handler(join_channel_buffer))
            .add_request_handler(user_handler(leave_channel_buffer))
            .add_message_handler(user_message_handler(update_channel_buffer))
            .add_request_handler(user_handler(rejoin_channel_buffers))
            .add_request_handler(user_handler(get_channel_members))
            .add_request_handler(user_handler(respond_to_channel_invite))
            .add_request_handler(user_handler(join_channel))
            .add_request_handler(user_handler(join_channel_chat))
            .add_message_handler(user_message_handler(leave_channel_chat))
            .add_request_handler(user_handler(send_channel_message))
            .add_request_handler(user_handler(remove_channel_message))
            .add_request_handler(user_handler(update_channel_message))
            .add_request_handler(user_handler(get_channel_messages))
            .add_request_handler(user_handler(get_channel_messages_by_id))
            .add_request_handler(user_handler(get_notifications))
            .add_request_handler(user_handler(mark_notification_as_read))
            .add_request_handler(user_handler(move_channel))
            .add_request_handler(user_handler(follow))
            .add_message_handler(user_message_handler(unfollow))
            .add_message_handler(user_message_handler(update_followers))
            .add_request_handler(user_handler(get_private_user_info))
            .add_request_handler(user_handler(get_llm_api_token))
            .add_request_handler(user_handler(accept_terms_of_service))
            .add_message_handler(user_message_handler(acknowledge_channel_message))
            .add_message_handler(user_message_handler(acknowledge_buffer_version))
            .add_request_handler(user_handler(get_supermaven_api_key))
            // Assistant contexts.
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OpenContext>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CreateContext>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::SynchronizeContexts>,
            ))
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context)
            // Not wrapped in `user_handler`: receives the raw `Session` plus the app config.
            .add_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    let app_state = app_state.clone();
                    async move {
                        count_language_model_tokens(request, response, session, &app_state.config)
                            .await
                    }
                }
            })
            .add_request_handler({
                user_handler(move |request, response, session| {
                    get_cached_embeddings(request, response, session)
                })
            })
            .add_request_handler({
                let app_state = app_state.clone();
                user_handler(move |request, response, session| {
                    compute_embeddings(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                    )
                })
            });

        Arc::new(server)
    }
658
    /// Schedules cleanup of resources left behind by previous server instances.
    ///
    /// Spawns a detached task that first awaits `CLEANUP_TIMEOUT`. It then asks the
    /// database for stale room and channel-buffer ids in this environment. For each one it:
    /// - clears stale collaborators/participants;
    /// - notifies affected clients (collaborator lists, room/channel updates, canceled
    ///   calls, contact status);
    /// - deletes the LiveKit room once it has no participants left.
    ///
    /// Finally it deletes the stale server records. Each fallible step goes through
    /// `trace_err()` and is skipped on failure rather than aborting the task.
    ///
    /// Always returns `Ok(())`; the cleanup runs in the background.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Clear stale collaborators from each channel buffer and send the refreshed
                    // collaborator list to every connection still in that buffer.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Defaults apply when clearing the room fails: nothing to notify and the
                        // LiveKit room is left alone.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            // Only an emptied room is deleted from LiveKit.
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the pool lock is released before the awaits below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push each affected user's refreshed contact entry to their accepted
                        // contacts. The pool is locked only after the DB awaits, so the
                        // parking_lot guard is never held across an `.await`.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                // Runs even if fetching the stale resource ids failed above.
                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
804
805 pub fn teardown(&self) {
806 self.peer.teardown();
807 self.connection_pool.lock().reset();
808 let _ = self.teardown.send(true);
809 }
810
811 #[cfg(test)]
812 pub fn reset(&self, id: ServerId) {
813 self.teardown();
814 *self.id.lock() = id;
815 self.peer.reset(id.0 as u32);
816 let _ = self.teardown.send(false);
817 }
818
819 #[cfg(test)]
820 pub fn id(&self) -> ServerId {
821 *self.id.lock()
822 }
823
824 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
825 where
826 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
827 Fut: 'static + Send + Future<Output = Result<()>>,
828 M: EnvelopedMessage,
829 {
830 let prev_handler = self.handlers.insert(
831 TypeId::of::<M>(),
832 Box::new(move |envelope, session| {
833 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
834 let received_at = envelope.received_at;
835 tracing::info!("message received");
836 let start_time = Instant::now();
837 let future = (handler)(*envelope, session);
838 async move {
839 let result = future.await;
840 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
841 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
842 let queue_duration_ms = total_duration_ms - processing_duration_ms;
843 let payload_type = M::NAME;
844
845 match result {
846 Err(error) => {
847 tracing::error!(
848 ?error,
849 total_duration_ms,
850 processing_duration_ms,
851 queue_duration_ms,
852 payload_type,
853 "error handling message"
854 )
855 }
856 Ok(()) => tracing::info!(
857 total_duration_ms,
858 processing_duration_ms,
859 queue_duration_ms,
860 "finished handling message"
861 ),
862 }
863 }
864 .boxed()
865 }),
866 );
867 if prev_handler.is_some() {
868 panic!("registered a handler for the same message twice");
869 }
870 self
871 }
872
873 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
874 where
875 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
876 Fut: 'static + Send + Future<Output = Result<()>>,
877 M: EnvelopedMessage,
878 {
879 self.add_handler(move |envelope, session| handler(envelope.payload, session));
880 self
881 }
882
883 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
884 where
885 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
886 Fut: Send + Future<Output = Result<()>>,
887 M: RequestMessage,
888 {
889 let handler = Arc::new(handler);
890 self.add_handler(move |envelope, session| {
891 let receipt = envelope.receipt();
892 let handler = handler.clone();
893 async move {
894 let peer = session.peer.clone();
895 let responded = Arc::new(AtomicBool::default());
896 let response = Response {
897 peer: peer.clone(),
898 responded: responded.clone(),
899 receipt,
900 };
901 match (handler)(envelope.payload, response, session).await {
902 Ok(()) => {
903 if responded.load(std::sync::atomic::Ordering::SeqCst) {
904 Ok(())
905 } else {
906 Err(anyhow!("handler did not send a response"))?
907 }
908 }
909 Err(error) => {
910 let proto_err = match &error {
911 Error::Internal(err) => err.to_proto(),
912 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
913 };
914 peer.respond_with_error(receipt, proto_err)?;
915 Err(error)
916 }
917 }
918 }
919 })
920 }
921
922 #[allow(clippy::too_many_arguments)]
923 pub fn handle_connection(
924 self: &Arc<Self>,
925 connection: Connection,
926 address: String,
927 principal: Principal,
928 zed_version: ZedVersion,
929 geoip_country_code: Option<String>,
930 send_connection_id: Option<oneshot::Sender<ConnectionId>>,
931 executor: Executor,
932 ) -> impl Future<Output = ()> {
933 let this = self.clone();
934 let span = info_span!("handle connection", %address,
935 connection_id=field::Empty,
936 user_id=field::Empty,
937 login=field::Empty,
938 impersonator=field::Empty,
939 dev_server_id=field::Empty,
940 geoip_country_code=field::Empty
941 );
942 principal.update_span(&span);
943 if let Some(country_code) = geoip_country_code.as_ref() {
944 span.record("geoip_country_code", country_code);
945 }
946
947 let mut teardown = self.teardown.subscribe();
948 async move {
949 if *teardown.borrow() {
950 tracing::error!("server is tearing down");
951 return
952 }
953 let (connection_id, handle_io, mut incoming_rx) = this
954 .peer
955 .add_connection(connection, {
956 let executor = executor.clone();
957 move |duration| executor.sleep(duration)
958 });
959 tracing::Span::current().record("connection_id", format!("{}", connection_id));
960
961 tracing::info!("connection opened");
962
963 let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
964 let http_client = match IsahcHttpClient::builder().default_header("User-Agent", user_agent).build() {
965 Ok(http_client) => Arc::new(IsahcHttpClient::from(http_client)),
966 Err(error) => {
967 tracing::error!(?error, "failed to create HTTP client");
968 return;
969 }
970 };
971
972 let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
973 supermaven_admin_api_key.to_string(),
974 http_client.clone(),
975 )));
976
977 let session = Session {
978 principal: principal.clone(),
979 connection_id,
980 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
981 peer: this.peer.clone(),
982 connection_pool: this.connection_pool.clone(),
983 app_state: this.app_state.clone(),
984 http_client,
985 geoip_country_code,
986 _executor: executor.clone(),
987 supermaven_client,
988 };
989
990 if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
991 tracing::error!(?error, "failed to send initial client update");
992 return;
993 }
994
995 let handle_io = handle_io.fuse();
996 futures::pin_mut!(handle_io);
997
998 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
999 // This prevents deadlocks when e.g., client A performs a request to client B and
1000 // client B performs a request to client A. If both clients stop processing further
1001 // messages until their respective request completes, they won't have a chance to
1002 // respond to the other client's request and cause a deadlock.
1003 //
1004 // This arrangement ensures we will attempt to process earlier messages first, but fall
1005 // back to processing messages arrived later in the spirit of making progress.
1006 let mut foreground_message_handlers = FuturesUnordered::new();
1007 let concurrent_handlers = Arc::new(Semaphore::new(256));
1008 loop {
1009 let next_message = async {
1010 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
1011 let message = incoming_rx.next().await;
1012 (permit, message)
1013 }.fuse();
1014 futures::pin_mut!(next_message);
1015 futures::select_biased! {
1016 _ = teardown.changed().fuse() => return,
1017 result = handle_io => {
1018 if let Err(error) = result {
1019 tracing::error!(?error, "error handling I/O");
1020 }
1021 break;
1022 }
1023 _ = foreground_message_handlers.next() => {}
1024 next_message = next_message => {
1025 let (permit, message) = next_message;
1026 if let Some(message) = message {
1027 let type_name = message.payload_type_name();
1028 // note: we copy all the fields from the parent span so we can query them in the logs.
1029 // (https://github.com/tokio-rs/tracing/issues/2670).
1030 let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
1031 user_id=field::Empty,
1032 login=field::Empty,
1033 impersonator=field::Empty,
1034 dev_server_id=field::Empty
1035 );
1036 principal.update_span(&span);
1037 let span_enter = span.enter();
1038 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
1039 let is_background = message.is_background();
1040 let handle_message = (handler)(message, session.clone());
1041 drop(span_enter);
1042
1043 let handle_message = async move {
1044 handle_message.await;
1045 drop(permit);
1046 }.instrument(span);
1047 if is_background {
1048 executor.spawn_detached(handle_message);
1049 } else {
1050 foreground_message_handlers.push(handle_message);
1051 }
1052 } else {
1053 tracing::error!("no message handler");
1054 }
1055 } else {
1056 tracing::info!("connection closed");
1057 break;
1058 }
1059 }
1060 }
1061 }
1062
1063 drop(foreground_message_handlers);
1064 tracing::info!("signing out");
1065 if let Err(error) = connection_lost(session, teardown, executor).await {
1066 tracing::error!(?error, "error signing out");
1067 }
1068
1069 }.instrument(span)
1070 }
1071
    /// Performs the server side of the handshake for a freshly opened connection.
    ///
    /// Sends `proto::Hello` carrying the assigned peer id, reports the connection id
    /// through `send_connection_id` when one was supplied, registers the connection in
    /// the connection pool and pushes the initial state, which depends on the principal:
    ///
    /// - Users (including impersonated users): `ShowContacts` on the first-ever
    ///   connection, the user's plan, the initial contacts update, channel
    ///   subscriptions (for versions accepted by `should_auto_subscribe_to_channels`),
    ///   dev server project status, any pending incoming call, and finally
    ///   `update_user_contacts`.
    /// - Dev servers: any connection already registered for the same dev server is sent
    ///   `ShutdownDevServer` and removed from the pool before this one is added; the dev
    ///   server then receives its projects, and `dev_server.user_id` receives refreshed
    ///   dev server project status.
    ///
    /// # Errors
    ///
    /// Returns the first failure from sending a message or from a database call; the
    /// remaining steps are skipped in that case.
    async fn send_initial_client_update(
        &self,
        connection_id: ConnectionId,
        principal: &Principal,
        zed_version: ZedVersion,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        session: &Session,
    ) -> Result<()> {
        self.peer.send(
            connection_id,
            proto::Hello {
                peer_id: Some(connection_id.into()),
            },
        )?;
        tracing::info!("sent hello message");
        // A failed send (e.g. the receiving side was dropped) is deliberately ignored.
        if let Some(send_connection_id) = send_connection_id.take() {
            let _ = send_connection_id.send(connection_id);
        }

        match principal {
            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
                // `ShowContacts` is only sent on the user's first-ever connection; the
                // flag is then persisted so later connections skip it.
                if !user.connected_once {
                    self.peer.send(connection_id, proto::ShowContacts {})?;
                    self.app_state
                        .db
                        .set_user_connected_once(user.id, true)
                        .await?;
                }

                update_user_plan(user.id, session).await?;

                // Both queries are independent, so they run concurrently.
                let (contacts, dev_server_projects) = future::try_join(
                    self.app_state.db.get_contacts(user.id),
                    self.app_state.db.dev_server_projects_update(user.id),
                )
                .await?;

                // Register the connection and build the initial contacts update under the
                // same pool lock, so the update is computed from a pool that already
                // contains this connection.
                {
                    let mut pool = self.connection_pool.lock();
                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
                    self.peer.send(
                        connection_id,
                        build_initial_contacts_update(contacts, &pool),
                    )?;
                }

                if should_auto_subscribe_to_channels(zed_version) {
                    subscribe_user_to_channels(user.id, session).await?;
                }

                send_dev_server_projects_update(user.id, dev_server_projects, session).await;

                // Deliver a pending incoming call, if the database has one for this user.
                if let Some(incoming_call) =
                    self.app_state.db.incoming_call_for_user(user.id).await?
                {
                    self.peer.send(connection_id, incoming_call)?;
                }

                update_user_contacts(user.id, session).await?;
            }
            Principal::DevServer(dev_server) => {
                {
                    let mut pool = self.connection_pool.lock();
                    // Only one connection per dev server is kept: an existing one is told
                    // to shut down and unregistered before this one is registered.
                    if let Some(stale_connection_id) = pool.dev_server_connection_id(dev_server.id)
                    {
                        self.peer.send(
                            stale_connection_id,
                            proto::ShutdownDevServer {
                                reason: Some(
                                    "another dev server connected with the same token".to_string(),
                                ),
                            },
                        )?;
                        pool.remove_connection(stale_connection_id)?;
                    };
                    pool.add_dev_server(connection_id, dev_server.id, zed_version);
                }

                let projects = self
                    .app_state
                    .db
                    .get_projects_for_dev_server(dev_server.id)
                    .await?;
                self.peer
                    .send(connection_id, proto::DevServerInstructions { projects })?;

                // Push refreshed dev server project status to the dev server's user.
                let status = self
                    .app_state
                    .db
                    .dev_server_projects_update(dev_server.user_id)
                    .await?;
                send_dev_server_projects_update(dev_server.user_id, status, session).await;
            }
        }

        Ok(())
    }
1169
1170 pub async fn invite_code_redeemed(
1171 self: &Arc<Self>,
1172 inviter_id: UserId,
1173 invitee_id: UserId,
1174 ) -> Result<()> {
1175 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1176 if let Some(code) = &user.invite_code {
1177 let pool = self.connection_pool.lock();
1178 let invitee_contact = contact_for_user(invitee_id, false, &pool);
1179 for connection_id in pool.user_connection_ids(inviter_id) {
1180 self.peer.send(
1181 connection_id,
1182 proto::UpdateContacts {
1183 contacts: vec![invitee_contact.clone()],
1184 ..Default::default()
1185 },
1186 )?;
1187 self.peer.send(
1188 connection_id,
1189 proto::UpdateInviteInfo {
1190 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1191 count: user.invite_count as u32,
1192 },
1193 )?;
1194 }
1195 }
1196 }
1197 Ok(())
1198 }
1199
1200 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1201 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1202 if let Some(invite_code) = &user.invite_code {
1203 let pool = self.connection_pool.lock();
1204 for connection_id in pool.user_connection_ids(user_id) {
1205 self.peer.send(
1206 connection_id,
1207 proto::UpdateInviteInfo {
1208 url: format!(
1209 "{}{}",
1210 self.app_state.config.invite_link_prefix, invite_code
1211 ),
1212 count: user.invite_count as u32,
1213 },
1214 )?;
1215 }
1216 }
1217 }
1218 Ok(())
1219 }
1220
1221 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1222 ServerSnapshot {
1223 connection_pool: ConnectionPoolGuard {
1224 guard: self.connection_pool.lock(),
1225 _not_send: PhantomData,
1226 },
1227 peer: &self.peer,
1228 }
1229 }
1230}
1231
1232impl<'a> Deref for ConnectionPoolGuard<'a> {
1233 type Target = ConnectionPool;
1234
1235 fn deref(&self) -> &Self::Target {
1236 &self.guard
1237 }
1238}
1239
1240impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1241 fn deref_mut(&mut self) -> &mut Self::Target {
1242 &mut self.guard
1243 }
1244}
1245
impl<'a> Drop for ConnectionPoolGuard<'a> {
    /// In test builds, runs `check_invariants` every time a guard is released. Outside
    /// tests this body is empty; the inner lock guard is then released by normal drop
    /// glue.
    fn drop(&mut self) {
        #[cfg(test)]
        self.check_invariants();
    }
}
1252
1253fn broadcast<F>(
1254 sender_id: Option<ConnectionId>,
1255 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1256 mut f: F,
1257) where
1258 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1259{
1260 for receiver_id in receiver_ids {
1261 if Some(receiver_id) != sender_id {
1262 if let Err(error) = f(receiver_id) {
1263 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1264 }
1265 }
1266 }
1267}
1268
1269pub struct ProtocolVersion(u32);
1270
1271impl Header for ProtocolVersion {
1272 fn name() -> &'static HeaderName {
1273 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1274 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1275 }
1276
1277 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1278 where
1279 Self: Sized,
1280 I: Iterator<Item = &'i axum::http::HeaderValue>,
1281 {
1282 let version = values
1283 .next()
1284 .ok_or_else(axum::headers::Error::invalid)?
1285 .to_str()
1286 .map_err(|_| axum::headers::Error::invalid())?
1287 .parse()
1288 .map_err(|_| axum::headers::Error::invalid())?;
1289 Ok(Self(version))
1290 }
1291
1292 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1293 values.extend([self.0.to_string().parse().unwrap()]);
1294 }
1295}
1296
1297pub struct AppVersionHeader(SemanticVersion);
1298impl Header for AppVersionHeader {
1299 fn name() -> &'static HeaderName {
1300 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1301 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1302 }
1303
1304 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1305 where
1306 Self: Sized,
1307 I: Iterator<Item = &'i axum::http::HeaderValue>,
1308 {
1309 let version = values
1310 .next()
1311 .ok_or_else(axum::headers::Error::invalid)?
1312 .to_str()
1313 .map_err(|_| axum::headers::Error::invalid())?
1314 .parse()
1315 .map_err(|_| axum::headers::Error::invalid())?;
1316 Ok(Self(version))
1317 }
1318
1319 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1320 values.extend([self.0.to_string().parse().unwrap()]);
1321 }
1322}
1323
1324pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1325 Router::new()
1326 .route("/rpc", get(handle_websocket_request))
1327 .layer(
1328 ServiceBuilder::new()
1329 .layer(Extension(server.app_state.clone()))
1330 .layer(middleware::from_fn(auth::validate_header)),
1331 )
1332 .route("/metrics", get(handle_metrics))
1333 .layer(Extension(server))
1334}
1335
1336pub async fn handle_websocket_request(
1337 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1338 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1339 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1340 Extension(server): Extension<Arc<Server>>,
1341 Extension(principal): Extension<Principal>,
1342 country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1343 ws: WebSocketUpgrade,
1344) -> axum::response::Response {
1345 if protocol_version != rpc::PROTOCOL_VERSION {
1346 return (
1347 StatusCode::UPGRADE_REQUIRED,
1348 "client must be upgraded".to_string(),
1349 )
1350 .into_response();
1351 }
1352
1353 let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1354 return (
1355 StatusCode::UPGRADE_REQUIRED,
1356 "no version header found".to_string(),
1357 )
1358 .into_response();
1359 };
1360
1361 if !version.can_collaborate() {
1362 return (
1363 StatusCode::UPGRADE_REQUIRED,
1364 "client must be upgraded".to_string(),
1365 )
1366 .into_response();
1367 }
1368
1369 let socket_address = socket_address.to_string();
1370 ws.on_upgrade(move |socket| {
1371 let socket = socket
1372 .map_ok(to_tungstenite_message)
1373 .err_into()
1374 .with(|message| async move { to_axum_message(message) });
1375 let connection = Connection::new(Box::pin(socket));
1376 async move {
1377 server
1378 .handle_connection(
1379 connection,
1380 socket_address,
1381 principal,
1382 version,
1383 country_code_header.map(|header| header.to_string()),
1384 None,
1385 Executor::Production,
1386 )
1387 .await;
1388 }
1389 })
1390}
1391
1392pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1393 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1394 let connections_metric = CONNECTIONS_METRIC
1395 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1396
1397 let connections = server
1398 .connection_pool
1399 .lock()
1400 .connections()
1401 .filter(|connection| !connection.admin)
1402 .count();
1403 connections_metric.set(connections as _);
1404
1405 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1406 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1407 register_int_gauge!(
1408 "shared_projects",
1409 "number of open projects with one or more guests"
1410 )
1411 .unwrap()
1412 });
1413
1414 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1415 shared_projects_metric.set(shared_projects as _);
1416
1417 let encoder = prometheus::TextEncoder::new();
1418 let metric_families = prometheus::gather();
1419 let encoded_metrics = encoder
1420 .encode_to_string(&metric_families)
1421 .map_err(|err| anyhow!("{}", err))?;
1422 Ok(encoded_metrics)
1423}
1424
/// Cleans up after a connection has gone away.
///
/// Immediately disconnects the peer, removes the connection from the connection pool
/// and records the loss in the database. It then waits `RECONNECT_TIMEOUT`. If
/// `teardown` changes first, it returns without further cleanup. Otherwise it releases
/// what the principal still holds:
///
/// - Users: leaves the room and channel buffers for this connection. If the user has no
///   other connection online, it declines their call (`decline_call(None, ..)`) and
///   broadcasts any returned room. Finally it calls `update_user_contacts`.
/// - Dev servers: delegates to `lost_dev_server_connection`.
///
/// # Errors
///
/// Returns failures from `remove_connection`, `update_user_contacts` and
/// `lost_dev_server_connection`. Failures from the database `connection_lost`, from
/// leaving the room or channel buffers, and from declining the call are only logged.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        // Resource cleanup runs only after `RECONNECT_TIMEOUT` (presumably a grace period
        // for the client to reconnect), unless teardown wins the race.
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            match &session.principal {
                Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
                    // Expected to succeed since the principal was just matched as a user.
                    let session = session.for_user().unwrap();

                    log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
                    leave_room_for_session(&session, session.connection_id).await.trace_err();
                    leave_channel_buffers_for_session(&session)
                        .await
                        .trace_err();

                    // Only decline the call when no other connection of this user is online.
                    if !session
                        .connection_pool()
                        .await
                        .is_user_online(session.user_id())
                    {
                        let db = session.db().await;
                        if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                            room_updated(&room, &session.peer);
                        }
                    }

                    update_user_contacts(session.user_id(), &session).await?;
                },
                Principal::DevServer(_) => {
                    lost_dev_server_connection(&session.for_dev_server().unwrap()).await?;
                },
            }
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1479
1480/// Acknowledges a ping from a client, used to keep the connection alive.
1481async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1482 response.send(proto::Ack {})?;
1483 Ok(())
1484}
1485
1486/// Creates a new room for calling (outside of channels)
1487async fn create_room(
1488 _request: proto::CreateRoom,
1489 response: Response<proto::CreateRoom>,
1490 session: UserSession,
1491) -> Result<()> {
1492 let live_kit_room = nanoid::nanoid!(30);
1493
1494 let live_kit_connection_info = util::maybe!(async {
1495 let live_kit = session.app_state.live_kit_client.as_ref();
1496 let live_kit = live_kit?;
1497 let user_id = session.user_id().to_string();
1498
1499 let token = live_kit
1500 .room_token(&live_kit_room, &user_id.to_string())
1501 .trace_err()?;
1502
1503 Some(proto::LiveKitConnectionInfo {
1504 server_url: live_kit.url().into(),
1505 token,
1506 can_publish: true,
1507 })
1508 })
1509 .await;
1510
1511 let room = session
1512 .db()
1513 .await
1514 .create_room(session.user_id(), session.connection_id, &live_kit_room)
1515 .await?;
1516
1517 response.send(proto::CreateRoomResponse {
1518 room: Some(room.clone()),
1519 live_kit_connection_info,
1520 })?;
1521
1522 update_user_contacts(session.user_id(), &session).await?;
1523 Ok(())
1524}
1525
1526/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1527async fn join_room(
1528 request: proto::JoinRoom,
1529 response: Response<proto::JoinRoom>,
1530 session: UserSession,
1531) -> Result<()> {
1532 let room_id = RoomId::from_proto(request.id);
1533
1534 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1535
1536 if let Some(channel_id) = channel_id {
1537 return join_channel_internal(channel_id, Box::new(response), session).await;
1538 }
1539
1540 let joined_room = {
1541 let room = session
1542 .db()
1543 .await
1544 .join_room(room_id, session.user_id(), session.connection_id)
1545 .await?;
1546 room_updated(&room.room, &session.peer);
1547 room.into_inner()
1548 };
1549
1550 for connection_id in session
1551 .connection_pool()
1552 .await
1553 .user_connection_ids(session.user_id())
1554 {
1555 session
1556 .peer
1557 .send(
1558 connection_id,
1559 proto::CallCanceled {
1560 room_id: room_id.to_proto(),
1561 },
1562 )
1563 .trace_err();
1564 }
1565
1566 let live_kit_connection_info =
1567 if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
1568 live_kit
1569 .room_token(
1570 &joined_room.room.live_kit_room,
1571 &session.user_id().to_string(),
1572 )
1573 .trace_err()
1574 .map(|token| proto::LiveKitConnectionInfo {
1575 server_url: live_kit.url().into(),
1576 token,
1577 can_publish: true,
1578 })
1579 } else {
1580 None
1581 };
1582
1583 response.send(proto::JoinRoomResponse {
1584 room: Some(joined_room.room),
1585 channel_id: None,
1586 live_kit_connection_info,
1587 })?;
1588
1589 update_user_contacts(session.user_id(), &session).await?;
1590 Ok(())
1591}
1592
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// It first calls the database's `rejoin_room` for this user and connection, replies
/// with the room plus the reshared and rejoined projects, and broadcasts the room
/// update. For every reshared project (presumably one this connection hosts), it tells
/// collaborators that the peer id changed and forwards the current worktree list.
/// Rejoined projects are handled by `notify_rejoined_projects`. Finally, if the room
/// belongs to a channel, channel members are notified, and the user's contacts are
/// updated.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: UserSession,
) -> Result<()> {
    let room;
    let channel;
    // The guard returned by the database's `rejoin_room` lives only inside this block,
    // so it is dropped before the channel notification and contacts update below.
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell each collaborator that this client's peer id moved from the old
            // connection to the current one; failures are only logged.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the project's current worktree list to every collaborator other
            // than this connection.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1684
1685fn notify_rejoined_projects(
1686 rejoined_projects: &mut Vec<RejoinedProject>,
1687 session: &UserSession,
1688) -> Result<()> {
1689 for project in rejoined_projects.iter() {
1690 for collaborator in &project.collaborators {
1691 session
1692 .peer
1693 .send(
1694 collaborator.connection_id,
1695 proto::UpdateProjectCollaborator {
1696 project_id: project.id.to_proto(),
1697 old_peer_id: Some(project.old_connection_id.into()),
1698 new_peer_id: Some(session.connection_id.into()),
1699 },
1700 )
1701 .trace_err();
1702 }
1703 }
1704
1705 for project in rejoined_projects {
1706 for worktree in mem::take(&mut project.worktrees) {
1707 #[cfg(any(test, feature = "test-support"))]
1708 const MAX_CHUNK_SIZE: usize = 2;
1709 #[cfg(not(any(test, feature = "test-support")))]
1710 const MAX_CHUNK_SIZE: usize = 256;
1711
1712 // Stream this worktree's entries.
1713 let message = proto::UpdateWorktree {
1714 project_id: project.id.to_proto(),
1715 worktree_id: worktree.id,
1716 abs_path: worktree.abs_path.clone(),
1717 root_name: worktree.root_name,
1718 updated_entries: worktree.updated_entries,
1719 removed_entries: worktree.removed_entries,
1720 scan_id: worktree.scan_id,
1721 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1722 updated_repositories: worktree.updated_repositories,
1723 removed_repositories: worktree.removed_repositories,
1724 };
1725 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1726 session.peer.send(session.connection_id, update.clone())?;
1727 }
1728
1729 // Stream this worktree's diagnostics.
1730 for summary in worktree.diagnostic_summaries {
1731 session.peer.send(
1732 session.connection_id,
1733 proto::UpdateDiagnosticSummary {
1734 project_id: project.id.to_proto(),
1735 worktree_id: worktree.id,
1736 summary: Some(summary),
1737 },
1738 )?;
1739 }
1740
1741 for settings_file in worktree.settings_files {
1742 session.peer.send(
1743 session.connection_id,
1744 proto::UpdateWorktreeSettings {
1745 project_id: project.id.to_proto(),
1746 worktree_id: worktree.id,
1747 path: settings_file.path,
1748 content: Some(settings_file.content),
1749 kind: Some(settings_file.kind.to_proto().into()),
1750 },
1751 )?;
1752 }
1753 }
1754
1755 for language_server in &project.language_servers {
1756 session.peer.send(
1757 session.connection_id,
1758 proto::UpdateLanguageServer {
1759 project_id: project.id.to_proto(),
1760 language_server_id: language_server.id,
1761 variant: Some(
1762 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1763 proto::LspDiskBasedDiagnosticsUpdated {},
1764 ),
1765 ),
1766 },
1767 )?;
1768 }
1769 }
1770 Ok(())
1771}
1772
1773/// leave room disconnects from the room.
1774async fn leave_room(
1775 _: proto::LeaveRoom,
1776 response: Response<proto::LeaveRoom>,
1777 session: UserSession,
1778) -> Result<()> {
1779 leave_room_for_session(&session, session.connection_id).await?;
1780 response.send(proto::Ack {})?;
1781 Ok(())
1782}
1783
1784/// Updates the permissions of someone else in the room.
1785async fn set_room_participant_role(
1786 request: proto::SetRoomParticipantRole,
1787 response: Response<proto::SetRoomParticipantRole>,
1788 session: UserSession,
1789) -> Result<()> {
1790 let user_id = UserId::from_proto(request.user_id);
1791 let role = ChannelRole::from(request.role());
1792
1793 let (live_kit_room, can_publish) = {
1794 let room = session
1795 .db()
1796 .await
1797 .set_room_participant_role(
1798 session.user_id(),
1799 RoomId::from_proto(request.room_id),
1800 user_id,
1801 role,
1802 )
1803 .await?;
1804
1805 let live_kit_room = room.live_kit_room.clone();
1806 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1807 room_updated(&room, &session.peer);
1808 (live_kit_room, can_publish)
1809 };
1810
1811 if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
1812 live_kit
1813 .update_participant(
1814 live_kit_room.clone(),
1815 request.user_id.to_string(),
1816 live_kit_server::proto::ParticipantPermission {
1817 can_subscribe: true,
1818 can_publish,
1819 can_publish_data: can_publish,
1820 hidden: false,
1821 recorder: false,
1822 },
1823 )
1824 .await
1825 .trace_err();
1826 }
1827
1828 response.send(proto::Ack {})?;
1829 Ok(())
1830}
1831
1832/// Call someone else into the current room
1833async fn call(
1834 request: proto::Call,
1835 response: Response<proto::Call>,
1836 session: UserSession,
1837) -> Result<()> {
1838 let room_id = RoomId::from_proto(request.room_id);
1839 let calling_user_id = session.user_id();
1840 let calling_connection_id = session.connection_id;
1841 let called_user_id = UserId::from_proto(request.called_user_id);
1842 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1843 if !session
1844 .db()
1845 .await
1846 .has_contact(calling_user_id, called_user_id)
1847 .await?
1848 {
1849 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1850 }
1851
1852 let incoming_call = {
1853 let (room, incoming_call) = &mut *session
1854 .db()
1855 .await
1856 .call(
1857 room_id,
1858 calling_user_id,
1859 calling_connection_id,
1860 called_user_id,
1861 initial_project_id,
1862 )
1863 .await?;
1864 room_updated(room, &session.peer);
1865 mem::take(incoming_call)
1866 };
1867 update_user_contacts(called_user_id, &session).await?;
1868
1869 let mut calls = session
1870 .connection_pool()
1871 .await
1872 .user_connection_ids(called_user_id)
1873 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1874 .collect::<FuturesUnordered<_>>();
1875
1876 while let Some(call_response) = calls.next().await {
1877 match call_response.as_ref() {
1878 Ok(_) => {
1879 response.send(proto::Ack {})?;
1880 return Ok(());
1881 }
1882 Err(_) => {
1883 call_response.trace_err();
1884 }
1885 }
1886 }
1887
1888 {
1889 let room = session
1890 .db()
1891 .await
1892 .call_failed(room_id, called_user_id)
1893 .await?;
1894 room_updated(&room, &session.peer);
1895 }
1896 update_user_contacts(called_user_id, &session).await?;
1897
1898 Err(anyhow!("failed to ring user"))?
1899}
1900
1901/// Cancel an outgoing call.
1902async fn cancel_call(
1903 request: proto::CancelCall,
1904 response: Response<proto::CancelCall>,
1905 session: UserSession,
1906) -> Result<()> {
1907 let called_user_id = UserId::from_proto(request.called_user_id);
1908 let room_id = RoomId::from_proto(request.room_id);
1909 {
1910 let room = session
1911 .db()
1912 .await
1913 .cancel_call(room_id, session.connection_id, called_user_id)
1914 .await?;
1915 room_updated(&room, &session.peer);
1916 }
1917
1918 for connection_id in session
1919 .connection_pool()
1920 .await
1921 .user_connection_ids(called_user_id)
1922 {
1923 session
1924 .peer
1925 .send(
1926 connection_id,
1927 proto::CallCanceled {
1928 room_id: room_id.to_proto(),
1929 },
1930 )
1931 .trace_err();
1932 }
1933 response.send(proto::Ack {})?;
1934
1935 update_user_contacts(called_user_id, &session).await?;
1936 Ok(())
1937}
1938
1939/// Decline an incoming call.
1940async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1941 let room_id = RoomId::from_proto(message.room_id);
1942 {
1943 let room = session
1944 .db()
1945 .await
1946 .decline_call(Some(room_id), session.user_id())
1947 .await?
1948 .ok_or_else(|| anyhow!("failed to decline call"))?;
1949 room_updated(&room, &session.peer);
1950 }
1951
1952 for connection_id in session
1953 .connection_pool()
1954 .await
1955 .user_connection_ids(session.user_id())
1956 {
1957 session
1958 .peer
1959 .send(
1960 connection_id,
1961 proto::CallCanceled {
1962 room_id: room_id.to_proto(),
1963 },
1964 )
1965 .trace_err();
1966 }
1967 update_user_contacts(session.user_id(), &session).await?;
1968 Ok(())
1969}
1970
1971/// Updates other participants in the room with your current location.
1972async fn update_participant_location(
1973 request: proto::UpdateParticipantLocation,
1974 response: Response<proto::UpdateParticipantLocation>,
1975 session: UserSession,
1976) -> Result<()> {
1977 let room_id = RoomId::from_proto(request.room_id);
1978 let location = request
1979 .location
1980 .ok_or_else(|| anyhow!("invalid location"))?;
1981
1982 let db = session.db().await;
1983 let room = db
1984 .update_room_participant_location(room_id, session.connection_id, location)
1985 .await?;
1986
1987 room_updated(&room, &session.peer);
1988 response.send(proto::Ack {})?;
1989 Ok(())
1990}
1991
1992/// Share a project into the room.
1993async fn share_project(
1994 request: proto::ShareProject,
1995 response: Response<proto::ShareProject>,
1996 session: UserSession,
1997) -> Result<()> {
1998 let (project_id, room) = &*session
1999 .db()
2000 .await
2001 .share_project(
2002 RoomId::from_proto(request.room_id),
2003 session.connection_id,
2004 &request.worktrees,
2005 request.is_ssh_project,
2006 request
2007 .dev_server_project_id
2008 .map(DevServerProjectId::from_proto),
2009 )
2010 .await?;
2011 response.send(proto::ShareProjectResponse {
2012 project_id: project_id.to_proto(),
2013 })?;
2014 room_updated(room, &session.peer);
2015
2016 Ok(())
2017}
2018
2019/// Unshare a project from the room.
2020async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
2021 let project_id = ProjectId::from_proto(message.project_id);
2022 unshare_project_internal(
2023 project_id,
2024 session.connection_id,
2025 session.user_id(),
2026 &session,
2027 )
2028 .await
2029}
2030
2031async fn unshare_project_internal(
2032 project_id: ProjectId,
2033 connection_id: ConnectionId,
2034 user_id: Option<UserId>,
2035 session: &Session,
2036) -> Result<()> {
2037 let delete = {
2038 let room_guard = session
2039 .db()
2040 .await
2041 .unshare_project(project_id, connection_id, user_id)
2042 .await?;
2043
2044 let (delete, room, guest_connection_ids) = &*room_guard;
2045
2046 let message = proto::UnshareProject {
2047 project_id: project_id.to_proto(),
2048 };
2049
2050 broadcast(
2051 Some(connection_id),
2052 guest_connection_ids.iter().copied(),
2053 |conn_id| session.peer.send(conn_id, message.clone()),
2054 );
2055 if let Some(room) = room {
2056 room_updated(room, &session.peer);
2057 }
2058
2059 *delete
2060 };
2061
2062 if delete {
2063 let db = session.db().await;
2064 db.delete_project(project_id).await?;
2065 }
2066
2067 Ok(())
2068}
2069
2070/// DevServer makes a project available online
2071async fn share_dev_server_project(
2072 request: proto::ShareDevServerProject,
2073 response: Response<proto::ShareDevServerProject>,
2074 session: DevServerSession,
2075) -> Result<()> {
2076 let (dev_server_project, user_id, status) = session
2077 .db()
2078 .await
2079 .share_dev_server_project(
2080 DevServerProjectId::from_proto(request.dev_server_project_id),
2081 session.dev_server_id(),
2082 session.connection_id,
2083 &request.worktrees,
2084 )
2085 .await?;
2086 let Some(project_id) = dev_server_project.project_id else {
2087 return Err(anyhow!("failed to share remote project"))?;
2088 };
2089
2090 send_dev_server_projects_update(user_id, status, &session).await;
2091
2092 response.send(proto::ShareProjectResponse { project_id })?;
2093
2094 Ok(())
2095}
2096
2097/// Join someone elses shared project.
2098async fn join_project(
2099 request: proto::JoinProject,
2100 response: Response<proto::JoinProject>,
2101 session: UserSession,
2102) -> Result<()> {
2103 let project_id = ProjectId::from_proto(request.project_id);
2104
2105 tracing::info!(%project_id, "join project");
2106
2107 let db = session.db().await;
2108 let (project, replica_id) = &mut *db
2109 .join_project(project_id, session.connection_id, session.user_id())
2110 .await?;
2111 drop(db);
2112 tracing::info!(%project_id, "join remote project");
2113 join_project_internal(response, session, project, replica_id)
2114}
2115
/// Abstracts over the response handles of the RPCs that reply with a
/// `proto::JoinProjectResponse`, so `join_project_internal` can answer either one.
trait JoinProjectInternalResponse {
    /// Sends `result` as the reply to the originating request.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    // The fully qualified path makes it explicit that the inherent `Response::send` is
    // called here, not this trait method.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinProject>::send(self, result)
    }
}
impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
    // The fully qualified path makes it explicit that the inherent `Response::send` is
    // called here, not this trait method.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinHostedProject>::send(self, result)
    }
}
2129
2130fn join_project_internal(
2131 response: impl JoinProjectInternalResponse,
2132 session: UserSession,
2133 project: &mut Project,
2134 replica_id: &ReplicaId,
2135) -> Result<()> {
2136 let collaborators = project
2137 .collaborators
2138 .iter()
2139 .filter(|collaborator| collaborator.connection_id != session.connection_id)
2140 .map(|collaborator| collaborator.to_proto())
2141 .collect::<Vec<_>>();
2142 let project_id = project.id;
2143 let guest_user_id = session.user_id();
2144
2145 let worktrees = project
2146 .worktrees
2147 .iter()
2148 .map(|(id, worktree)| proto::WorktreeMetadata {
2149 id: *id,
2150 root_name: worktree.root_name.clone(),
2151 visible: worktree.visible,
2152 abs_path: worktree.abs_path.clone(),
2153 })
2154 .collect::<Vec<_>>();
2155
2156 let add_project_collaborator = proto::AddProjectCollaborator {
2157 project_id: project_id.to_proto(),
2158 collaborator: Some(proto::Collaborator {
2159 peer_id: Some(session.connection_id.into()),
2160 replica_id: replica_id.0 as u32,
2161 user_id: guest_user_id.to_proto(),
2162 }),
2163 };
2164
2165 for collaborator in &collaborators {
2166 session
2167 .peer
2168 .send(
2169 collaborator.peer_id.unwrap().into(),
2170 add_project_collaborator.clone(),
2171 )
2172 .trace_err();
2173 }
2174
2175 // First, we send the metadata associated with each worktree.
2176 response.send(proto::JoinProjectResponse {
2177 project_id: project.id.0 as u64,
2178 worktrees: worktrees.clone(),
2179 replica_id: replica_id.0 as u32,
2180 collaborators: collaborators.clone(),
2181 language_servers: project.language_servers.clone(),
2182 role: project.role.into(),
2183 dev_server_project_id: project
2184 .dev_server_project_id
2185 .map(|dev_server_project_id| dev_server_project_id.0 as u64),
2186 })?;
2187
2188 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2189 #[cfg(any(test, feature = "test-support"))]
2190 const MAX_CHUNK_SIZE: usize = 2;
2191 #[cfg(not(any(test, feature = "test-support")))]
2192 const MAX_CHUNK_SIZE: usize = 256;
2193
2194 // Stream this worktree's entries.
2195 let message = proto::UpdateWorktree {
2196 project_id: project_id.to_proto(),
2197 worktree_id,
2198 abs_path: worktree.abs_path.clone(),
2199 root_name: worktree.root_name,
2200 updated_entries: worktree.entries,
2201 removed_entries: Default::default(),
2202 scan_id: worktree.scan_id,
2203 is_last_update: worktree.scan_id == worktree.completed_scan_id,
2204 updated_repositories: worktree.repository_entries.into_values().collect(),
2205 removed_repositories: Default::default(),
2206 };
2207 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
2208 session.peer.send(session.connection_id, update.clone())?;
2209 }
2210
2211 // Stream this worktree's diagnostics.
2212 for summary in worktree.diagnostic_summaries {
2213 session.peer.send(
2214 session.connection_id,
2215 proto::UpdateDiagnosticSummary {
2216 project_id: project_id.to_proto(),
2217 worktree_id: worktree.id,
2218 summary: Some(summary),
2219 },
2220 )?;
2221 }
2222
2223 for settings_file in worktree.settings_files {
2224 session.peer.send(
2225 session.connection_id,
2226 proto::UpdateWorktreeSettings {
2227 project_id: project_id.to_proto(),
2228 worktree_id: worktree.id,
2229 path: settings_file.path,
2230 content: Some(settings_file.content),
2231 kind: Some(proto::update_user_settings::Kind::Settings.into()),
2232 },
2233 )?;
2234 }
2235 }
2236
2237 for language_server in &project.language_servers {
2238 session.peer.send(
2239 session.connection_id,
2240 proto::UpdateLanguageServer {
2241 project_id: project_id.to_proto(),
2242 language_server_id: language_server.id,
2243 variant: Some(
2244 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2245 proto::LspDiskBasedDiagnosticsUpdated {},
2246 ),
2247 ),
2248 },
2249 )?;
2250 }
2251
2252 Ok(())
2253}
2254
2255/// Leave someone elses shared project.
2256async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2257 let sender_id = session.connection_id;
2258 let project_id = ProjectId::from_proto(request.project_id);
2259 let db = session.db().await;
2260 if db.is_hosted_project(project_id).await? {
2261 let project = db.leave_hosted_project(project_id, sender_id).await?;
2262 project_left(&project, &session);
2263 return Ok(());
2264 }
2265
2266 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2267 tracing::info!(
2268 %project_id,
2269 "leave project"
2270 );
2271
2272 project_left(project, &session);
2273 if let Some(room) = room {
2274 room_updated(room, &session.peer);
2275 }
2276
2277 Ok(())
2278}
2279
2280async fn join_hosted_project(
2281 request: proto::JoinHostedProject,
2282 response: Response<proto::JoinHostedProject>,
2283 session: UserSession,
2284) -> Result<()> {
2285 let (mut project, replica_id) = session
2286 .db()
2287 .await
2288 .join_hosted_project(
2289 ProjectId(request.project_id as i32),
2290 session.user_id(),
2291 session.connection_id,
2292 )
2293 .await?;
2294
2295 join_project_internal(response, session, &mut project, &replica_id)
2296}
2297
2298async fn list_remote_directory(
2299 request: proto::ListRemoteDirectory,
2300 response: Response<proto::ListRemoteDirectory>,
2301 session: UserSession,
2302) -> Result<()> {
2303 let dev_server_id = DevServerId(request.dev_server_id as i32);
2304 let dev_server_connection_id = session
2305 .connection_pool()
2306 .await
2307 .online_dev_server_connection_id(dev_server_id)?;
2308
2309 session
2310 .db()
2311 .await
2312 .get_dev_server_for_user(dev_server_id, session.user_id())
2313 .await?;
2314
2315 response.send(
2316 session
2317 .peer
2318 .forward_request(session.connection_id, dev_server_connection_id, request)
2319 .await?,
2320 )?;
2321 Ok(())
2322}
2323
2324async fn update_dev_server_project(
2325 request: proto::UpdateDevServerProject,
2326 response: Response<proto::UpdateDevServerProject>,
2327 session: UserSession,
2328) -> Result<()> {
2329 let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2330
2331 let (dev_server_project, update) = session
2332 .db()
2333 .await
2334 .update_dev_server_project(dev_server_project_id, &request.paths, session.user_id())
2335 .await?;
2336
2337 let projects = session
2338 .db()
2339 .await
2340 .get_projects_for_dev_server(dev_server_project.dev_server_id)
2341 .await?;
2342
2343 let dev_server_connection_id = session
2344 .connection_pool()
2345 .await
2346 .online_dev_server_connection_id(dev_server_project.dev_server_id)?;
2347
2348 session.peer.send(
2349 dev_server_connection_id,
2350 proto::DevServerInstructions { projects },
2351 )?;
2352
2353 send_dev_server_projects_update(session.user_id(), update, &session).await;
2354
2355 response.send(proto::Ack {})
2356}
2357
2358async fn create_dev_server_project(
2359 request: proto::CreateDevServerProject,
2360 response: Response<proto::CreateDevServerProject>,
2361 session: UserSession,
2362) -> Result<()> {
2363 let dev_server_id = DevServerId(request.dev_server_id as i32);
2364 let dev_server_connection_id = session
2365 .connection_pool()
2366 .await
2367 .dev_server_connection_id(dev_server_id);
2368 let Some(dev_server_connection_id) = dev_server_connection_id else {
2369 Err(ErrorCode::DevServerOffline
2370 .message("Cannot create a remote project when the dev server is offline".to_string())
2371 .anyhow())?
2372 };
2373
2374 let path = request.path.clone();
2375 //Check that the path exists on the dev server
2376 session
2377 .peer
2378 .forward_request(
2379 session.connection_id,
2380 dev_server_connection_id,
2381 proto::ValidateDevServerProjectRequest { path: path.clone() },
2382 )
2383 .await?;
2384
2385 let (dev_server_project, update) = session
2386 .db()
2387 .await
2388 .create_dev_server_project(
2389 DevServerId(request.dev_server_id as i32),
2390 &request.path,
2391 session.user_id(),
2392 )
2393 .await?;
2394
2395 let projects = session
2396 .db()
2397 .await
2398 .get_projects_for_dev_server(dev_server_project.dev_server_id)
2399 .await?;
2400
2401 session.peer.send(
2402 dev_server_connection_id,
2403 proto::DevServerInstructions { projects },
2404 )?;
2405
2406 send_dev_server_projects_update(session.user_id(), update, &session).await;
2407
2408 response.send(proto::CreateDevServerProjectResponse {
2409 dev_server_project: Some(dev_server_project.to_proto(None)),
2410 })?;
2411 Ok(())
2412}
2413
2414async fn create_dev_server(
2415 request: proto::CreateDevServer,
2416 response: Response<proto::CreateDevServer>,
2417 session: UserSession,
2418) -> Result<()> {
2419 let access_token = auth::random_token();
2420 let hashed_access_token = auth::hash_access_token(&access_token);
2421
2422 if request.name.is_empty() {
2423 return Err(proto::ErrorCode::Forbidden
2424 .message("Dev server name cannot be empty".to_string())
2425 .anyhow())?;
2426 }
2427
2428 let (dev_server, status) = session
2429 .db()
2430 .await
2431 .create_dev_server(
2432 &request.name,
2433 request.ssh_connection_string.as_deref(),
2434 &hashed_access_token,
2435 session.user_id(),
2436 )
2437 .await?;
2438
2439 send_dev_server_projects_update(session.user_id(), status, &session).await;
2440
2441 response.send(proto::CreateDevServerResponse {
2442 dev_server_id: dev_server.id.0 as u64,
2443 access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2444 name: request.name,
2445 })?;
2446 Ok(())
2447}
2448
2449async fn regenerate_dev_server_token(
2450 request: proto::RegenerateDevServerToken,
2451 response: Response<proto::RegenerateDevServerToken>,
2452 session: UserSession,
2453) -> Result<()> {
2454 let dev_server_id = DevServerId(request.dev_server_id as i32);
2455 let access_token = auth::random_token();
2456 let hashed_access_token = auth::hash_access_token(&access_token);
2457
2458 let connection_id = session
2459 .connection_pool()
2460 .await
2461 .dev_server_connection_id(dev_server_id);
2462 if let Some(connection_id) = connection_id {
2463 shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2464 session.peer.send(
2465 connection_id,
2466 proto::ShutdownDevServer {
2467 reason: Some("dev server token was regenerated".to_string()),
2468 },
2469 )?;
2470 let _ = remove_dev_server_connection(dev_server_id, &session).await;
2471 }
2472
2473 let status = session
2474 .db()
2475 .await
2476 .update_dev_server_token(dev_server_id, &hashed_access_token, session.user_id())
2477 .await?;
2478
2479 send_dev_server_projects_update(session.user_id(), status, &session).await;
2480
2481 response.send(proto::RegenerateDevServerTokenResponse {
2482 dev_server_id: dev_server_id.to_proto(),
2483 access_token: auth::generate_dev_server_token(dev_server_id.0 as usize, access_token),
2484 })?;
2485 Ok(())
2486}
2487
2488async fn rename_dev_server(
2489 request: proto::RenameDevServer,
2490 response: Response<proto::RenameDevServer>,
2491 session: UserSession,
2492) -> Result<()> {
2493 if request.name.trim().is_empty() {
2494 return Err(proto::ErrorCode::Forbidden
2495 .message("Dev server name cannot be empty".to_string())
2496 .anyhow())?;
2497 }
2498
2499 let dev_server_id = DevServerId(request.dev_server_id as i32);
2500 let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2501 if dev_server.user_id != session.user_id() {
2502 return Err(anyhow!(ErrorCode::Forbidden))?;
2503 }
2504
2505 let status = session
2506 .db()
2507 .await
2508 .rename_dev_server(
2509 dev_server_id,
2510 &request.name,
2511 request.ssh_connection_string.as_deref(),
2512 session.user_id(),
2513 )
2514 .await?;
2515
2516 send_dev_server_projects_update(session.user_id(), status, &session).await;
2517
2518 response.send(proto::Ack {})?;
2519 Ok(())
2520}
2521
2522async fn delete_dev_server(
2523 request: proto::DeleteDevServer,
2524 response: Response<proto::DeleteDevServer>,
2525 session: UserSession,
2526) -> Result<()> {
2527 let dev_server_id = DevServerId(request.dev_server_id as i32);
2528 let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2529 if dev_server.user_id != session.user_id() {
2530 return Err(anyhow!(ErrorCode::Forbidden))?;
2531 }
2532
2533 let connection_id = session
2534 .connection_pool()
2535 .await
2536 .dev_server_connection_id(dev_server_id);
2537 if let Some(connection_id) = connection_id {
2538 shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2539 session.peer.send(
2540 connection_id,
2541 proto::ShutdownDevServer {
2542 reason: Some("dev server was deleted".to_string()),
2543 },
2544 )?;
2545 let _ = remove_dev_server_connection(dev_server_id, &session).await;
2546 }
2547
2548 let status = session
2549 .db()
2550 .await
2551 .delete_dev_server(dev_server_id, session.user_id())
2552 .await?;
2553
2554 send_dev_server_projects_update(session.user_id(), status, &session).await;
2555
2556 response.send(proto::Ack {})?;
2557 Ok(())
2558}
2559
2560async fn delete_dev_server_project(
2561 request: proto::DeleteDevServerProject,
2562 response: Response<proto::DeleteDevServerProject>,
2563 session: UserSession,
2564) -> Result<()> {
2565 let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2566 let dev_server_project = session
2567 .db()
2568 .await
2569 .get_dev_server_project(dev_server_project_id)
2570 .await?;
2571
2572 let dev_server = session
2573 .db()
2574 .await
2575 .get_dev_server(dev_server_project.dev_server_id)
2576 .await?;
2577 if dev_server.user_id != session.user_id() {
2578 return Err(anyhow!(ErrorCode::Forbidden))?;
2579 }
2580
2581 let dev_server_connection_id = session
2582 .connection_pool()
2583 .await
2584 .dev_server_connection_id(dev_server.id);
2585
2586 if let Some(dev_server_connection_id) = dev_server_connection_id {
2587 let project = session
2588 .db()
2589 .await
2590 .find_dev_server_project(dev_server_project_id)
2591 .await;
2592 if let Ok(project) = project {
2593 unshare_project_internal(
2594 project.id,
2595 dev_server_connection_id,
2596 Some(session.user_id()),
2597 &session,
2598 )
2599 .await?;
2600 }
2601 }
2602
2603 let (projects, status) = session
2604 .db()
2605 .await
2606 .delete_dev_server_project(dev_server_project_id, dev_server.id, session.user_id())
2607 .await?;
2608
2609 if let Some(dev_server_connection_id) = dev_server_connection_id {
2610 session.peer.send(
2611 dev_server_connection_id,
2612 proto::DevServerInstructions { projects },
2613 )?;
2614 }
2615
2616 send_dev_server_projects_update(session.user_id(), status, &session).await;
2617
2618 response.send(proto::Ack {})?;
2619 Ok(())
2620}
2621
2622async fn rejoin_dev_server_projects(
2623 request: proto::RejoinRemoteProjects,
2624 response: Response<proto::RejoinRemoteProjects>,
2625 session: UserSession,
2626) -> Result<()> {
2627 let mut rejoined_projects = {
2628 let db = session.db().await;
2629 db.rejoin_dev_server_projects(
2630 &request.rejoined_projects,
2631 session.user_id(),
2632 session.0.connection_id,
2633 )
2634 .await?
2635 };
2636 response.send(proto::RejoinRemoteProjectsResponse {
2637 rejoined_projects: rejoined_projects
2638 .iter()
2639 .map(|project| project.to_proto())
2640 .collect(),
2641 })?;
2642 notify_rejoined_projects(&mut rejoined_projects, &session)
2643}
2644
2645async fn reconnect_dev_server(
2646 request: proto::ReconnectDevServer,
2647 response: Response<proto::ReconnectDevServer>,
2648 session: DevServerSession,
2649) -> Result<()> {
2650 let reshared_projects = {
2651 let db = session.db().await;
2652 db.reshare_dev_server_projects(
2653 &request.reshared_projects,
2654 session.dev_server_id(),
2655 session.0.connection_id,
2656 )
2657 .await?
2658 };
2659
2660 for project in &reshared_projects {
2661 for collaborator in &project.collaborators {
2662 session
2663 .peer
2664 .send(
2665 collaborator.connection_id,
2666 proto::UpdateProjectCollaborator {
2667 project_id: project.id.to_proto(),
2668 old_peer_id: Some(project.old_connection_id.into()),
2669 new_peer_id: Some(session.connection_id.into()),
2670 },
2671 )
2672 .trace_err();
2673 }
2674
2675 broadcast(
2676 Some(session.connection_id),
2677 project
2678 .collaborators
2679 .iter()
2680 .map(|collaborator| collaborator.connection_id),
2681 |connection_id| {
2682 session.peer.forward_send(
2683 session.connection_id,
2684 connection_id,
2685 proto::UpdateProject {
2686 project_id: project.id.to_proto(),
2687 worktrees: project.worktrees.clone(),
2688 },
2689 )
2690 },
2691 );
2692 }
2693
2694 response.send(proto::ReconnectDevServerResponse {
2695 reshared_projects: reshared_projects
2696 .iter()
2697 .map(|project| proto::ResharedProject {
2698 id: project.id.to_proto(),
2699 collaborators: project
2700 .collaborators
2701 .iter()
2702 .map(|collaborator| collaborator.to_proto())
2703 .collect(),
2704 })
2705 .collect(),
2706 })?;
2707
2708 Ok(())
2709}
2710
2711async fn shutdown_dev_server(
2712 _: proto::ShutdownDevServer,
2713 response: Response<proto::ShutdownDevServer>,
2714 session: DevServerSession,
2715) -> Result<()> {
2716 response.send(proto::Ack {})?;
2717 shutdown_dev_server_internal(session.dev_server_id(), session.connection_id, &session).await?;
2718 remove_dev_server_connection(session.dev_server_id(), &session).await
2719}
2720
2721async fn shutdown_dev_server_internal(
2722 dev_server_id: DevServerId,
2723 connection_id: ConnectionId,
2724 session: &Session,
2725) -> Result<()> {
2726 let (dev_server_projects, dev_server) = {
2727 let db = session.db().await;
2728 let dev_server_projects = db.get_projects_for_dev_server(dev_server_id).await?;
2729 let dev_server = db.get_dev_server(dev_server_id).await?;
2730 (dev_server_projects, dev_server)
2731 };
2732
2733 for project_id in dev_server_projects.iter().filter_map(|p| p.project_id) {
2734 unshare_project_internal(
2735 ProjectId::from_proto(project_id),
2736 connection_id,
2737 None,
2738 session,
2739 )
2740 .await?;
2741 }
2742
2743 session
2744 .connection_pool()
2745 .await
2746 .set_dev_server_offline(dev_server_id);
2747
2748 let status = session
2749 .db()
2750 .await
2751 .dev_server_projects_update(dev_server.user_id)
2752 .await?;
2753 send_dev_server_projects_update(dev_server.user_id, status, session).await;
2754
2755 Ok(())
2756}
2757
2758async fn remove_dev_server_connection(dev_server_id: DevServerId, session: &Session) -> Result<()> {
2759 let dev_server_connection = session
2760 .connection_pool()
2761 .await
2762 .dev_server_connection_id(dev_server_id);
2763
2764 if let Some(dev_server_connection) = dev_server_connection {
2765 session
2766 .connection_pool()
2767 .await
2768 .remove_connection(dev_server_connection)?;
2769 }
2770 Ok(())
2771}
2772
2773/// Updates other participants with changes to the project
2774async fn update_project(
2775 request: proto::UpdateProject,
2776 response: Response<proto::UpdateProject>,
2777 session: Session,
2778) -> Result<()> {
2779 let project_id = ProjectId::from_proto(request.project_id);
2780 let (room, guest_connection_ids) = &*session
2781 .db()
2782 .await
2783 .update_project(project_id, session.connection_id, &request.worktrees)
2784 .await?;
2785 broadcast(
2786 Some(session.connection_id),
2787 guest_connection_ids.iter().copied(),
2788 |connection_id| {
2789 session
2790 .peer
2791 .forward_send(session.connection_id, connection_id, request.clone())
2792 },
2793 );
2794 if let Some(room) = room {
2795 room_updated(room, &session.peer);
2796 }
2797 response.send(proto::Ack {})?;
2798
2799 Ok(())
2800}
2801
2802/// Updates other participants with changes to the worktree
2803async fn update_worktree(
2804 request: proto::UpdateWorktree,
2805 response: Response<proto::UpdateWorktree>,
2806 session: Session,
2807) -> Result<()> {
2808 let guest_connection_ids = session
2809 .db()
2810 .await
2811 .update_worktree(&request, session.connection_id)
2812 .await?;
2813
2814 broadcast(
2815 Some(session.connection_id),
2816 guest_connection_ids.iter().copied(),
2817 |connection_id| {
2818 session
2819 .peer
2820 .forward_send(session.connection_id, connection_id, request.clone())
2821 },
2822 );
2823 response.send(proto::Ack {})?;
2824 Ok(())
2825}
2826
2827/// Updates other participants with changes to the diagnostics
2828async fn update_diagnostic_summary(
2829 message: proto::UpdateDiagnosticSummary,
2830 session: Session,
2831) -> Result<()> {
2832 let guest_connection_ids = session
2833 .db()
2834 .await
2835 .update_diagnostic_summary(&message, session.connection_id)
2836 .await?;
2837
2838 broadcast(
2839 Some(session.connection_id),
2840 guest_connection_ids.iter().copied(),
2841 |connection_id| {
2842 session
2843 .peer
2844 .forward_send(session.connection_id, connection_id, message.clone())
2845 },
2846 );
2847
2848 Ok(())
2849}
2850
2851/// Updates other participants with changes to the worktree settings
2852async fn update_worktree_settings(
2853 message: proto::UpdateWorktreeSettings,
2854 session: Session,
2855) -> Result<()> {
2856 let guest_connection_ids = session
2857 .db()
2858 .await
2859 .update_worktree_settings(&message, session.connection_id)
2860 .await?;
2861
2862 broadcast(
2863 Some(session.connection_id),
2864 guest_connection_ids.iter().copied(),
2865 |connection_id| {
2866 session
2867 .peer
2868 .forward_send(session.connection_id, connection_id, message.clone())
2869 },
2870 );
2871
2872 Ok(())
2873}
2874
2875/// Notify other participants that a language server has started.
2876async fn start_language_server(
2877 request: proto::StartLanguageServer,
2878 session: Session,
2879) -> Result<()> {
2880 let guest_connection_ids = session
2881 .db()
2882 .await
2883 .start_language_server(&request, session.connection_id)
2884 .await?;
2885
2886 broadcast(
2887 Some(session.connection_id),
2888 guest_connection_ids.iter().copied(),
2889 |connection_id| {
2890 session
2891 .peer
2892 .forward_send(session.connection_id, connection_id, request.clone())
2893 },
2894 );
2895 Ok(())
2896}
2897
2898/// Notify other participants that a language server has changed.
2899async fn update_language_server(
2900 request: proto::UpdateLanguageServer,
2901 session: Session,
2902) -> Result<()> {
2903 let project_id = ProjectId::from_proto(request.project_id);
2904 let project_connection_ids = session
2905 .db()
2906 .await
2907 .project_connection_ids(project_id, session.connection_id, true)
2908 .await?;
2909 broadcast(
2910 Some(session.connection_id),
2911 project_connection_ids.iter().copied(),
2912 |connection_id| {
2913 session
2914 .peer
2915 .forward_send(session.connection_id, connection_id, request.clone())
2916 },
2917 );
2918 Ok(())
2919}
2920
2921/// forward a project request to the host. These requests should be read only
2922/// as guests are allowed to send them.
2923async fn forward_read_only_project_request<T>(
2924 request: T,
2925 response: Response<T>,
2926 session: UserSession,
2927) -> Result<()>
2928where
2929 T: EntityMessage + RequestMessage,
2930{
2931 let project_id = ProjectId::from_proto(request.remote_entity_id());
2932 let host_connection_id = session
2933 .db()
2934 .await
2935 .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2936 .await?;
2937 let payload = session
2938 .peer
2939 .forward_request(session.connection_id, host_connection_id, request)
2940 .await?;
2941 response.send(payload)?;
2942 Ok(())
2943}
2944
2945async fn forward_find_search_candidates_request(
2946 request: proto::FindSearchCandidates,
2947 response: Response<proto::FindSearchCandidates>,
2948 session: UserSession,
2949) -> Result<()> {
2950 let project_id = ProjectId::from_proto(request.remote_entity_id());
2951 let host_connection_id = session
2952 .db()
2953 .await
2954 .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2955 .await?;
2956 let payload = session
2957 .peer
2958 .forward_request(session.connection_id, host_connection_id, request)
2959 .await?;
2960 response.send(payload)?;
2961 Ok(())
2962}
2963
2964/// forward a project request to the dev server. Only allowed
2965/// if it's your dev server.
2966async fn forward_project_request_for_owner<T>(
2967 request: T,
2968 response: Response<T>,
2969 session: UserSession,
2970) -> Result<()>
2971where
2972 T: EntityMessage + RequestMessage,
2973{
2974 let project_id = ProjectId::from_proto(request.remote_entity_id());
2975
2976 let host_connection_id = session
2977 .db()
2978 .await
2979 .host_for_owner_project_request(project_id, session.connection_id, session.user_id())
2980 .await?;
2981 let payload = session
2982 .peer
2983 .forward_request(session.connection_id, host_connection_id, request)
2984 .await?;
2985 response.send(payload)?;
2986 Ok(())
2987}
2988
2989/// forward a project request to the host. These requests are disallowed
2990/// for guests.
2991async fn forward_mutating_project_request<T>(
2992 request: T,
2993 response: Response<T>,
2994 session: UserSession,
2995) -> Result<()>
2996where
2997 T: EntityMessage + RequestMessage,
2998{
2999 let project_id = ProjectId::from_proto(request.remote_entity_id());
3000
3001 let host_connection_id = session
3002 .db()
3003 .await
3004 .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
3005 .await?;
3006 let payload = session
3007 .peer
3008 .forward_request(session.connection_id, host_connection_id, request)
3009 .await?;
3010 response.send(payload)?;
3011 Ok(())
3012}
3013
3014/// Notify other participants that a new buffer has been created
3015async fn create_buffer_for_peer(
3016 request: proto::CreateBufferForPeer,
3017 session: Session,
3018) -> Result<()> {
3019 session
3020 .db()
3021 .await
3022 .check_user_is_project_host(
3023 ProjectId::from_proto(request.project_id),
3024 session.connection_id,
3025 )
3026 .await?;
3027 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
3028 session
3029 .peer
3030 .forward_send(session.connection_id, peer_id.into(), request)?;
3031 Ok(())
3032}
3033
3034/// Notify other participants that a buffer has been updated. This is
3035/// allowed for guests as long as the update is limited to selections.
3036async fn update_buffer(
3037 request: proto::UpdateBuffer,
3038 response: Response<proto::UpdateBuffer>,
3039 session: Session,
3040) -> Result<()> {
3041 let project_id = ProjectId::from_proto(request.project_id);
3042 let mut capability = Capability::ReadOnly;
3043
3044 for op in request.operations.iter() {
3045 match op.variant {
3046 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
3047 Some(_) => capability = Capability::ReadWrite,
3048 }
3049 }
3050
3051 let host = {
3052 let guard = session
3053 .db()
3054 .await
3055 .connections_for_buffer_update(
3056 project_id,
3057 session.principal_id(),
3058 session.connection_id,
3059 capability,
3060 )
3061 .await?;
3062
3063 let (host, guests) = &*guard;
3064
3065 broadcast(
3066 Some(session.connection_id),
3067 guests.clone(),
3068 |connection_id| {
3069 session
3070 .peer
3071 .forward_send(session.connection_id, connection_id, request.clone())
3072 },
3073 );
3074
3075 *host
3076 };
3077
3078 if host != session.connection_id {
3079 session
3080 .peer
3081 .forward_request(session.connection_id, host, request.clone())
3082 .await?;
3083 }
3084
3085 response.send(proto::Ack {})?;
3086 Ok(())
3087}
3088
3089async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
3090 let project_id = ProjectId::from_proto(message.project_id);
3091
3092 let operation = message.operation.as_ref().context("invalid operation")?;
3093 let capability = match operation.variant.as_ref() {
3094 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
3095 if let Some(buffer_op) = buffer_op.operation.as_ref() {
3096 match buffer_op.variant {
3097 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
3098 Capability::ReadOnly
3099 }
3100 _ => Capability::ReadWrite,
3101 }
3102 } else {
3103 Capability::ReadWrite
3104 }
3105 }
3106 Some(_) => Capability::ReadWrite,
3107 None => Capability::ReadOnly,
3108 };
3109
3110 let guard = session
3111 .db()
3112 .await
3113 .connections_for_buffer_update(
3114 project_id,
3115 session.principal_id(),
3116 session.connection_id,
3117 capability,
3118 )
3119 .await?;
3120
3121 let (host, guests) = &*guard;
3122
3123 broadcast(
3124 Some(session.connection_id),
3125 guests.iter().chain([host]).copied(),
3126 |connection_id| {
3127 session
3128 .peer
3129 .forward_send(session.connection_id, connection_id, message.clone())
3130 },
3131 );
3132
3133 Ok(())
3134}
3135
3136/// Notify other participants that a project has been updated.
3137async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
3138 request: T,
3139 session: Session,
3140) -> Result<()> {
3141 let project_id = ProjectId::from_proto(request.remote_entity_id());
3142 let project_connection_ids = session
3143 .db()
3144 .await
3145 .project_connection_ids(project_id, session.connection_id, false)
3146 .await?;
3147
3148 broadcast(
3149 Some(session.connection_id),
3150 project_connection_ids.iter().copied(),
3151 |connection_id| {
3152 session
3153 .peer
3154 .forward_send(session.connection_id, connection_id, request.clone())
3155 },
3156 );
3157 Ok(())
3158}
3159
3160/// Start following another user in a call.
3161async fn follow(
3162 request: proto::Follow,
3163 response: Response<proto::Follow>,
3164 session: UserSession,
3165) -> Result<()> {
3166 let room_id = RoomId::from_proto(request.room_id);
3167 let project_id = request.project_id.map(ProjectId::from_proto);
3168 let leader_id = request
3169 .leader_id
3170 .ok_or_else(|| anyhow!("invalid leader id"))?
3171 .into();
3172 let follower_id = session.connection_id;
3173
3174 session
3175 .db()
3176 .await
3177 .check_room_participants(room_id, leader_id, session.connection_id)
3178 .await?;
3179
3180 let response_payload = session
3181 .peer
3182 .forward_request(session.connection_id, leader_id, request)
3183 .await?;
3184 response.send(response_payload)?;
3185
3186 if let Some(project_id) = project_id {
3187 let room = session
3188 .db()
3189 .await
3190 .follow(room_id, project_id, leader_id, follower_id)
3191 .await?;
3192 room_updated(&room, &session.peer);
3193 }
3194
3195 Ok(())
3196}
3197
3198/// Stop following another user in a call.
3199async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
3200 let room_id = RoomId::from_proto(request.room_id);
3201 let project_id = request.project_id.map(ProjectId::from_proto);
3202 let leader_id = request
3203 .leader_id
3204 .ok_or_else(|| anyhow!("invalid leader id"))?
3205 .into();
3206 let follower_id = session.connection_id;
3207
3208 session
3209 .db()
3210 .await
3211 .check_room_participants(room_id, leader_id, session.connection_id)
3212 .await?;
3213
3214 session
3215 .peer
3216 .forward_send(session.connection_id, leader_id, request)?;
3217
3218 if let Some(project_id) = project_id {
3219 let room = session
3220 .db()
3221 .await
3222 .unfollow(room_id, project_id, leader_id, follower_id)
3223 .await?;
3224 room_updated(&room, &session.peer);
3225 }
3226
3227 Ok(())
3228}
3229
3230/// Notify everyone following you of your current location.
3231async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
3232 let room_id = RoomId::from_proto(request.room_id);
3233 let database = session.db.lock().await;
3234
3235 let connection_ids = if let Some(project_id) = request.project_id {
3236 let project_id = ProjectId::from_proto(project_id);
3237 database
3238 .project_connection_ids(project_id, session.connection_id, true)
3239 .await?
3240 } else {
3241 database
3242 .room_connection_ids(room_id, session.connection_id)
3243 .await?
3244 };
3245
3246 // For now, don't send view update messages back to that view's current leader.
3247 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
3248 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
3249 _ => None,
3250 });
3251
3252 for connection_id in connection_ids.iter().cloned() {
3253 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
3254 session
3255 .peer
3256 .forward_send(session.connection_id, connection_id, request.clone())?;
3257 }
3258 }
3259 Ok(())
3260}
3261
3262/// Get public data about users.
3263async fn get_users(
3264 request: proto::GetUsers,
3265 response: Response<proto::GetUsers>,
3266 session: Session,
3267) -> Result<()> {
3268 let user_ids = request
3269 .user_ids
3270 .into_iter()
3271 .map(UserId::from_proto)
3272 .collect();
3273 let users = session
3274 .db()
3275 .await
3276 .get_users_by_ids(user_ids)
3277 .await?
3278 .into_iter()
3279 .map(|user| proto::User {
3280 id: user.id.to_proto(),
3281 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3282 github_login: user.github_login,
3283 })
3284 .collect();
3285 response.send(proto::UsersResponse { users })?;
3286 Ok(())
3287}
3288
3289/// Search for users (to invite) buy Github login
3290async fn fuzzy_search_users(
3291 request: proto::FuzzySearchUsers,
3292 response: Response<proto::FuzzySearchUsers>,
3293 session: UserSession,
3294) -> Result<()> {
3295 let query = request.query;
3296 let users = match query.len() {
3297 0 => vec![],
3298 1 | 2 => session
3299 .db()
3300 .await
3301 .get_user_by_github_login(&query)
3302 .await?
3303 .into_iter()
3304 .collect(),
3305 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
3306 };
3307 let users = users
3308 .into_iter()
3309 .filter(|user| user.id != session.user_id())
3310 .map(|user| proto::User {
3311 id: user.id.to_proto(),
3312 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3313 github_login: user.github_login,
3314 })
3315 .collect();
3316 response.send(proto::UsersResponse { users })?;
3317 Ok(())
3318}
3319
3320/// Send a contact request to another user.
3321async fn request_contact(
3322 request: proto::RequestContact,
3323 response: Response<proto::RequestContact>,
3324 session: UserSession,
3325) -> Result<()> {
3326 let requester_id = session.user_id();
3327 let responder_id = UserId::from_proto(request.responder_id);
3328 if requester_id == responder_id {
3329 return Err(anyhow!("cannot add yourself as a contact"))?;
3330 }
3331
3332 let notifications = session
3333 .db()
3334 .await
3335 .send_contact_request(requester_id, responder_id)
3336 .await?;
3337
3338 // Update outgoing contact requests of requester
3339 let mut update = proto::UpdateContacts::default();
3340 update.outgoing_requests.push(responder_id.to_proto());
3341 for connection_id in session
3342 .connection_pool()
3343 .await
3344 .user_connection_ids(requester_id)
3345 {
3346 session.peer.send(connection_id, update.clone())?;
3347 }
3348
3349 // Update incoming contact requests of responder
3350 let mut update = proto::UpdateContacts::default();
3351 update
3352 .incoming_requests
3353 .push(proto::IncomingContactRequest {
3354 requester_id: requester_id.to_proto(),
3355 });
3356 let connection_pool = session.connection_pool().await;
3357 for connection_id in connection_pool.user_connection_ids(responder_id) {
3358 session.peer.send(connection_id, update.clone())?;
3359 }
3360
3361 send_notifications(&connection_pool, &session.peer, notifications);
3362
3363 response.send(proto::Ack {})?;
3364 Ok(())
3365}
3366
3367/// Accept or decline a contact request
3368async fn respond_to_contact_request(
3369 request: proto::RespondToContactRequest,
3370 response: Response<proto::RespondToContactRequest>,
3371 session: UserSession,
3372) -> Result<()> {
3373 let responder_id = session.user_id();
3374 let requester_id = UserId::from_proto(request.requester_id);
3375 let db = session.db().await;
3376 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
3377 db.dismiss_contact_notification(responder_id, requester_id)
3378 .await?;
3379 } else {
3380 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
3381
3382 let notifications = db
3383 .respond_to_contact_request(responder_id, requester_id, accept)
3384 .await?;
3385 let requester_busy = db.is_user_busy(requester_id).await?;
3386 let responder_busy = db.is_user_busy(responder_id).await?;
3387
3388 let pool = session.connection_pool().await;
3389 // Update responder with new contact
3390 let mut update = proto::UpdateContacts::default();
3391 if accept {
3392 update
3393 .contacts
3394 .push(contact_for_user(requester_id, requester_busy, &pool));
3395 }
3396 update
3397 .remove_incoming_requests
3398 .push(requester_id.to_proto());
3399 for connection_id in pool.user_connection_ids(responder_id) {
3400 session.peer.send(connection_id, update.clone())?;
3401 }
3402
3403 // Update requester with new contact
3404 let mut update = proto::UpdateContacts::default();
3405 if accept {
3406 update
3407 .contacts
3408 .push(contact_for_user(responder_id, responder_busy, &pool));
3409 }
3410 update
3411 .remove_outgoing_requests
3412 .push(responder_id.to_proto());
3413
3414 for connection_id in pool.user_connection_ids(requester_id) {
3415 session.peer.send(connection_id, update.clone())?;
3416 }
3417
3418 send_notifications(&pool, &session.peer, notifications);
3419 }
3420
3421 response.send(proto::Ack {})?;
3422 Ok(())
3423}
3424
3425/// Remove a contact.
3426async fn remove_contact(
3427 request: proto::RemoveContact,
3428 response: Response<proto::RemoveContact>,
3429 session: UserSession,
3430) -> Result<()> {
3431 let requester_id = session.user_id();
3432 let responder_id = UserId::from_proto(request.user_id);
3433 let db = session.db().await;
3434 let (contact_accepted, deleted_notification_id) =
3435 db.remove_contact(requester_id, responder_id).await?;
3436
3437 let pool = session.connection_pool().await;
3438 // Update outgoing contact requests of requester
3439 let mut update = proto::UpdateContacts::default();
3440 if contact_accepted {
3441 update.remove_contacts.push(responder_id.to_proto());
3442 } else {
3443 update
3444 .remove_outgoing_requests
3445 .push(responder_id.to_proto());
3446 }
3447 for connection_id in pool.user_connection_ids(requester_id) {
3448 session.peer.send(connection_id, update.clone())?;
3449 }
3450
3451 // Update incoming contact requests of responder
3452 let mut update = proto::UpdateContacts::default();
3453 if contact_accepted {
3454 update.remove_contacts.push(requester_id.to_proto());
3455 } else {
3456 update
3457 .remove_incoming_requests
3458 .push(requester_id.to_proto());
3459 }
3460 for connection_id in pool.user_connection_ids(responder_id) {
3461 session.peer.send(connection_id, update.clone())?;
3462 if let Some(notification_id) = deleted_notification_id {
3463 session.peer.send(
3464 connection_id,
3465 proto::DeleteNotification {
3466 notification_id: notification_id.to_proto(),
3467 },
3468 )?;
3469 }
3470 }
3471
3472 response.send(proto::Ack {})?;
3473 Ok(())
3474}
3475
3476fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
3477 version.0.minor() < 139
3478}
3479
3480async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
3481 let plan = session.current_plan(&session.db().await).await?;
3482
3483 session
3484 .peer
3485 .send(
3486 session.connection_id,
3487 proto::UpdateUserPlan { plan: plan.into() },
3488 )
3489 .trace_err();
3490
3491 Ok(())
3492}
3493
3494async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
3495 subscribe_user_to_channels(
3496 session.user_id().ok_or_else(|| anyhow!("must be a user"))?,
3497 &session,
3498 )
3499 .await?;
3500 Ok(())
3501}
3502
3503async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
3504 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3505 let mut pool = session.connection_pool().await;
3506 for membership in &channels_for_user.channel_memberships {
3507 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3508 }
3509 session.peer.send(
3510 session.connection_id,
3511 build_update_user_channels(&channels_for_user),
3512 )?;
3513 session.peer.send(
3514 session.connection_id,
3515 build_channels_update(channels_for_user),
3516 )?;
3517 Ok(())
3518}
3519
3520/// Creates a new channel.
3521async fn create_channel(
3522 request: proto::CreateChannel,
3523 response: Response<proto::CreateChannel>,
3524 session: UserSession,
3525) -> Result<()> {
3526 let db = session.db().await;
3527
3528 let parent_id = request.parent_id.map(ChannelId::from_proto);
3529 let (channel, membership) = db
3530 .create_channel(&request.name, parent_id, session.user_id())
3531 .await?;
3532
3533 let root_id = channel.root_id();
3534 let channel = Channel::from_model(channel);
3535
3536 response.send(proto::CreateChannelResponse {
3537 channel: Some(channel.to_proto()),
3538 parent_id: request.parent_id,
3539 })?;
3540
3541 let mut connection_pool = session.connection_pool().await;
3542 if let Some(membership) = membership {
3543 connection_pool.subscribe_to_channel(
3544 membership.user_id,
3545 membership.channel_id,
3546 membership.role,
3547 );
3548 let update = proto::UpdateUserChannels {
3549 channel_memberships: vec![proto::ChannelMembership {
3550 channel_id: membership.channel_id.to_proto(),
3551 role: membership.role.into(),
3552 }],
3553 ..Default::default()
3554 };
3555 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3556 session.peer.send(connection_id, update.clone())?;
3557 }
3558 }
3559
3560 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3561 if !role.can_see_channel(channel.visibility) {
3562 continue;
3563 }
3564
3565 let update = proto::UpdateChannels {
3566 channels: vec![channel.to_proto()],
3567 ..Default::default()
3568 };
3569 session.peer.send(connection_id, update.clone())?;
3570 }
3571
3572 Ok(())
3573}
3574
3575/// Delete a channel
3576async fn delete_channel(
3577 request: proto::DeleteChannel,
3578 response: Response<proto::DeleteChannel>,
3579 session: UserSession,
3580) -> Result<()> {
3581 let db = session.db().await;
3582
3583 let channel_id = request.channel_id;
3584 let (root_channel, removed_channels) = db
3585 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3586 .await?;
3587 response.send(proto::Ack {})?;
3588
3589 // Notify members of removed channels
3590 let mut update = proto::UpdateChannels::default();
3591 update
3592 .delete_channels
3593 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3594
3595 let connection_pool = session.connection_pool().await;
3596 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3597 session.peer.send(connection_id, update.clone())?;
3598 }
3599
3600 Ok(())
3601}
3602
3603/// Invite someone to join a channel.
3604async fn invite_channel_member(
3605 request: proto::InviteChannelMember,
3606 response: Response<proto::InviteChannelMember>,
3607 session: UserSession,
3608) -> Result<()> {
3609 let db = session.db().await;
3610 let channel_id = ChannelId::from_proto(request.channel_id);
3611 let invitee_id = UserId::from_proto(request.user_id);
3612 let InviteMemberResult {
3613 channel,
3614 notifications,
3615 } = db
3616 .invite_channel_member(
3617 channel_id,
3618 invitee_id,
3619 session.user_id(),
3620 request.role().into(),
3621 )
3622 .await?;
3623
3624 let update = proto::UpdateChannels {
3625 channel_invitations: vec![channel.to_proto()],
3626 ..Default::default()
3627 };
3628
3629 let connection_pool = session.connection_pool().await;
3630 for connection_id in connection_pool.user_connection_ids(invitee_id) {
3631 session.peer.send(connection_id, update.clone())?;
3632 }
3633
3634 send_notifications(&connection_pool, &session.peer, notifications);
3635
3636 response.send(proto::Ack {})?;
3637 Ok(())
3638}
3639
3640/// remove someone from a channel
3641async fn remove_channel_member(
3642 request: proto::RemoveChannelMember,
3643 response: Response<proto::RemoveChannelMember>,
3644 session: UserSession,
3645) -> Result<()> {
3646 let db = session.db().await;
3647 let channel_id = ChannelId::from_proto(request.channel_id);
3648 let member_id = UserId::from_proto(request.user_id);
3649
3650 let RemoveChannelMemberResult {
3651 membership_update,
3652 notification_id,
3653 } = db
3654 .remove_channel_member(channel_id, member_id, session.user_id())
3655 .await?;
3656
3657 let mut connection_pool = session.connection_pool().await;
3658 notify_membership_updated(
3659 &mut connection_pool,
3660 membership_update,
3661 member_id,
3662 &session.peer,
3663 );
3664 for connection_id in connection_pool.user_connection_ids(member_id) {
3665 if let Some(notification_id) = notification_id {
3666 session
3667 .peer
3668 .send(
3669 connection_id,
3670 proto::DeleteNotification {
3671 notification_id: notification_id.to_proto(),
3672 },
3673 )
3674 .trace_err();
3675 }
3676 }
3677
3678 response.send(proto::Ack {})?;
3679 Ok(())
3680}
3681
3682/// Toggle the channel between public and private.
3683/// Care is taken to maintain the invariant that public channels only descend from public channels,
3684/// (though members-only channels can appear at any point in the hierarchy).
3685async fn set_channel_visibility(
3686 request: proto::SetChannelVisibility,
3687 response: Response<proto::SetChannelVisibility>,
3688 session: UserSession,
3689) -> Result<()> {
3690 let db = session.db().await;
3691 let channel_id = ChannelId::from_proto(request.channel_id);
3692 let visibility = request.visibility().into();
3693
3694 let channel_model = db
3695 .set_channel_visibility(channel_id, visibility, session.user_id())
3696 .await?;
3697 let root_id = channel_model.root_id();
3698 let channel = Channel::from_model(channel_model);
3699
3700 let mut connection_pool = session.connection_pool().await;
3701 for (user_id, role) in connection_pool
3702 .channel_user_ids(root_id)
3703 .collect::<Vec<_>>()
3704 .into_iter()
3705 {
3706 let update = if role.can_see_channel(channel.visibility) {
3707 connection_pool.subscribe_to_channel(user_id, channel_id, role);
3708 proto::UpdateChannels {
3709 channels: vec![channel.to_proto()],
3710 ..Default::default()
3711 }
3712 } else {
3713 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3714 proto::UpdateChannels {
3715 delete_channels: vec![channel.id.to_proto()],
3716 ..Default::default()
3717 }
3718 };
3719
3720 for connection_id in connection_pool.user_connection_ids(user_id) {
3721 session.peer.send(connection_id, update.clone())?;
3722 }
3723 }
3724
3725 response.send(proto::Ack {})?;
3726 Ok(())
3727}
3728
3729/// Alter the role for a user in the channel.
3730async fn set_channel_member_role(
3731 request: proto::SetChannelMemberRole,
3732 response: Response<proto::SetChannelMemberRole>,
3733 session: UserSession,
3734) -> Result<()> {
3735 let db = session.db().await;
3736 let channel_id = ChannelId::from_proto(request.channel_id);
3737 let member_id = UserId::from_proto(request.user_id);
3738 let result = db
3739 .set_channel_member_role(
3740 channel_id,
3741 session.user_id(),
3742 member_id,
3743 request.role().into(),
3744 )
3745 .await?;
3746
3747 match result {
3748 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3749 let mut connection_pool = session.connection_pool().await;
3750 notify_membership_updated(
3751 &mut connection_pool,
3752 membership_update,
3753 member_id,
3754 &session.peer,
3755 )
3756 }
3757 db::SetMemberRoleResult::InviteUpdated(channel) => {
3758 let update = proto::UpdateChannels {
3759 channel_invitations: vec![channel.to_proto()],
3760 ..Default::default()
3761 };
3762
3763 for connection_id in session
3764 .connection_pool()
3765 .await
3766 .user_connection_ids(member_id)
3767 {
3768 session.peer.send(connection_id, update.clone())?;
3769 }
3770 }
3771 }
3772
3773 response.send(proto::Ack {})?;
3774 Ok(())
3775}
3776
3777/// Change the name of a channel
3778async fn rename_channel(
3779 request: proto::RenameChannel,
3780 response: Response<proto::RenameChannel>,
3781 session: UserSession,
3782) -> Result<()> {
3783 let db = session.db().await;
3784 let channel_id = ChannelId::from_proto(request.channel_id);
3785 let channel_model = db
3786 .rename_channel(channel_id, session.user_id(), &request.name)
3787 .await?;
3788 let root_id = channel_model.root_id();
3789 let channel = Channel::from_model(channel_model);
3790
3791 response.send(proto::RenameChannelResponse {
3792 channel: Some(channel.to_proto()),
3793 })?;
3794
3795 let connection_pool = session.connection_pool().await;
3796 let update = proto::UpdateChannels {
3797 channels: vec![channel.to_proto()],
3798 ..Default::default()
3799 };
3800 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3801 if role.can_see_channel(channel.visibility) {
3802 session.peer.send(connection_id, update.clone())?;
3803 }
3804 }
3805
3806 Ok(())
3807}
3808
3809/// Move a channel to a new parent.
3810async fn move_channel(
3811 request: proto::MoveChannel,
3812 response: Response<proto::MoveChannel>,
3813 session: UserSession,
3814) -> Result<()> {
3815 let channel_id = ChannelId::from_proto(request.channel_id);
3816 let to = ChannelId::from_proto(request.to);
3817
3818 let (root_id, channels) = session
3819 .db()
3820 .await
3821 .move_channel(channel_id, to, session.user_id())
3822 .await?;
3823
3824 let connection_pool = session.connection_pool().await;
3825 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3826 let channels = channels
3827 .iter()
3828 .filter_map(|channel| {
3829 if role.can_see_channel(channel.visibility) {
3830 Some(channel.to_proto())
3831 } else {
3832 None
3833 }
3834 })
3835 .collect::<Vec<_>>();
3836 if channels.is_empty() {
3837 continue;
3838 }
3839
3840 let update = proto::UpdateChannels {
3841 channels,
3842 ..Default::default()
3843 };
3844
3845 session.peer.send(connection_id, update.clone())?;
3846 }
3847
3848 response.send(Ack {})?;
3849 Ok(())
3850}
3851
3852/// Get the list of channel members
3853async fn get_channel_members(
3854 request: proto::GetChannelMembers,
3855 response: Response<proto::GetChannelMembers>,
3856 session: UserSession,
3857) -> Result<()> {
3858 let db = session.db().await;
3859 let channel_id = ChannelId::from_proto(request.channel_id);
3860 let limit = if request.limit == 0 {
3861 u16::MAX as u64
3862 } else {
3863 request.limit
3864 };
3865 let (members, users) = db
3866 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3867 .await?;
3868 response.send(proto::GetChannelMembersResponse { members, users })?;
3869 Ok(())
3870}
3871
3872/// Accept or decline a channel invitation.
3873async fn respond_to_channel_invite(
3874 request: proto::RespondToChannelInvite,
3875 response: Response<proto::RespondToChannelInvite>,
3876 session: UserSession,
3877) -> Result<()> {
3878 let db = session.db().await;
3879 let channel_id = ChannelId::from_proto(request.channel_id);
3880 let RespondToChannelInvite {
3881 membership_update,
3882 notifications,
3883 } = db
3884 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3885 .await?;
3886
3887 let mut connection_pool = session.connection_pool().await;
3888 if let Some(membership_update) = membership_update {
3889 notify_membership_updated(
3890 &mut connection_pool,
3891 membership_update,
3892 session.user_id(),
3893 &session.peer,
3894 );
3895 } else {
3896 let update = proto::UpdateChannels {
3897 remove_channel_invitations: vec![channel_id.to_proto()],
3898 ..Default::default()
3899 };
3900
3901 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3902 session.peer.send(connection_id, update.clone())?;
3903 }
3904 };
3905
3906 send_notifications(&connection_pool, &session.peer, notifications);
3907
3908 response.send(proto::Ack {})?;
3909
3910 Ok(())
3911}
3912
3913/// Join the channels' room
3914async fn join_channel(
3915 request: proto::JoinChannel,
3916 response: Response<proto::JoinChannel>,
3917 session: UserSession,
3918) -> Result<()> {
3919 let channel_id = ChannelId::from_proto(request.channel_id);
3920 join_channel_internal(channel_id, Box::new(response), session).await
3921}
3922
/// A response handle that can be answered with a `proto::JoinRoomResponse`.
///
/// This lets `join_channel_internal` reply to either a `JoinChannel` or a
/// `JoinRoom` request with the same code.
trait JoinChannelInternalResponse {
    /// Send `result` as the reply to the original request.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Explicit path to `Response`'s own `send` rather than this trait method.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Explicit path to `Response`'s own `send` rather than this trait method.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3936
/// Join the room of `channel_id` for `session`'s user and reply through `response`.
///
/// Called by `join_channel`. NOTE(review): the `JoinChannelInternalResponse` impl
/// for `Response<proto::JoinRoom>` suggests a room-join path calls this too;
/// confirm at the call sites.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, leave that
///    room first.
/// 2. Join the channel's room in the database. Reply with the room, its channel
///    id and, when a LiveKit client is configured, LiveKit connection info.
///    - Guests get a guest token with `can_publish: false`.
///    - Other roles get a room token with `can_publish: true`.
///    - If token creation fails, the error is logged and no LiveKit info is sent.
/// 3. Broadcast any membership change and the updated room.
/// 4. After the database guard is dropped, broadcast the channel update and
///    refresh the user's contacts.
///
/// Errors if a database call or the reply fails, or if the joined room has no
/// channel. That last check runs after the reply has already been sent.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: UserSession,
) -> Result<()> {
    // Scoped so the database guard `db` is dropped before `channel_updated` below.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Give up the database guard while leaving the stale room, then take it
            // again. Presumably `leave_room_for_session` acquires it itself; verify.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured or token creation fails
        // (`trace_err` logs the failure and yields `None`).
        let live_kit_connection_info =
            session
                .app_state
                .live_kit_client
                .as_ref()
                .and_then(|live_kit| {
                    // Guests receive a guest token and are marked as unable to publish.
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.live_kit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.live_kit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
4032
4033/// Start editing the channel notes
4034async fn join_channel_buffer(
4035 request: proto::JoinChannelBuffer,
4036 response: Response<proto::JoinChannelBuffer>,
4037 session: UserSession,
4038) -> Result<()> {
4039 let db = session.db().await;
4040 let channel_id = ChannelId::from_proto(request.channel_id);
4041
4042 let open_response = db
4043 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
4044 .await?;
4045
4046 let collaborators = open_response.collaborators.clone();
4047 response.send(open_response)?;
4048
4049 let update = UpdateChannelBufferCollaborators {
4050 channel_id: channel_id.to_proto(),
4051 collaborators: collaborators.clone(),
4052 };
4053 channel_buffer_updated(
4054 session.connection_id,
4055 collaborators
4056 .iter()
4057 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
4058 &update,
4059 &session.peer,
4060 );
4061
4062 Ok(())
4063}
4064
4065/// Edit the channel notes
4066async fn update_channel_buffer(
4067 request: proto::UpdateChannelBuffer,
4068 session: UserSession,
4069) -> Result<()> {
4070 let db = session.db().await;
4071 let channel_id = ChannelId::from_proto(request.channel_id);
4072
4073 let (collaborators, epoch, version) = db
4074 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
4075 .await?;
4076
4077 channel_buffer_updated(
4078 session.connection_id,
4079 collaborators.clone(),
4080 &proto::UpdateChannelBuffer {
4081 channel_id: channel_id.to_proto(),
4082 operations: request.operations,
4083 },
4084 &session.peer,
4085 );
4086
4087 let pool = &*session.connection_pool().await;
4088
4089 let non_collaborators =
4090 pool.channel_connection_ids(channel_id)
4091 .filter_map(|(connection_id, _)| {
4092 if collaborators.contains(&connection_id) {
4093 None
4094 } else {
4095 Some(connection_id)
4096 }
4097 });
4098
4099 broadcast(None, non_collaborators, |peer_id| {
4100 session.peer.send(
4101 peer_id,
4102 proto::UpdateChannels {
4103 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
4104 channel_id: channel_id.to_proto(),
4105 epoch: epoch as u64,
4106 version: version.clone(),
4107 }],
4108 ..Default::default()
4109 },
4110 )
4111 });
4112
4113 Ok(())
4114}
4115
4116/// Rejoin the channel notes after a connection blip
4117async fn rejoin_channel_buffers(
4118 request: proto::RejoinChannelBuffers,
4119 response: Response<proto::RejoinChannelBuffers>,
4120 session: UserSession,
4121) -> Result<()> {
4122 let db = session.db().await;
4123 let buffers = db
4124 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
4125 .await?;
4126
4127 for rejoined_buffer in &buffers {
4128 let collaborators_to_notify = rejoined_buffer
4129 .buffer
4130 .collaborators
4131 .iter()
4132 .filter_map(|c| Some(c.peer_id?.into()));
4133 channel_buffer_updated(
4134 session.connection_id,
4135 collaborators_to_notify,
4136 &proto::UpdateChannelBufferCollaborators {
4137 channel_id: rejoined_buffer.buffer.channel_id,
4138 collaborators: rejoined_buffer.buffer.collaborators.clone(),
4139 },
4140 &session.peer,
4141 );
4142 }
4143
4144 response.send(proto::RejoinChannelBuffersResponse {
4145 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
4146 })?;
4147
4148 Ok(())
4149}
4150
4151/// Stop editing the channel notes
4152async fn leave_channel_buffer(
4153 request: proto::LeaveChannelBuffer,
4154 response: Response<proto::LeaveChannelBuffer>,
4155 session: UserSession,
4156) -> Result<()> {
4157 let db = session.db().await;
4158 let channel_id = ChannelId::from_proto(request.channel_id);
4159
4160 let left_buffer = db
4161 .leave_channel_buffer(channel_id, session.connection_id)
4162 .await?;
4163
4164 response.send(Ack {})?;
4165
4166 channel_buffer_updated(
4167 session.connection_id,
4168 left_buffer.connections,
4169 &proto::UpdateChannelBufferCollaborators {
4170 channel_id: channel_id.to_proto(),
4171 collaborators: left_buffer.collaborators,
4172 },
4173 &session.peer,
4174 );
4175
4176 Ok(())
4177}
4178
4179fn channel_buffer_updated<T: EnvelopedMessage>(
4180 sender_id: ConnectionId,
4181 collaborators: impl IntoIterator<Item = ConnectionId>,
4182 message: &T,
4183 peer: &Peer,
4184) {
4185 broadcast(Some(sender_id), collaborators, |peer_id| {
4186 peer.send(peer_id, message.clone())
4187 });
4188}
4189
4190fn send_notifications(
4191 connection_pool: &ConnectionPool,
4192 peer: &Peer,
4193 notifications: db::NotificationBatch,
4194) {
4195 for (user_id, notification) in notifications {
4196 for connection_id in connection_pool.user_connection_ids(user_id) {
4197 if let Err(error) = peer.send(
4198 connection_id,
4199 proto::AddNotification {
4200 notification: Some(notification.clone()),
4201 },
4202 ) {
4203 tracing::error!(
4204 "failed to send notification to {:?} {}",
4205 connection_id,
4206 error
4207 );
4208 }
4209 }
4210 }
4211}
4212
/// Send a message to the channel
///
/// Steps, in order:
/// 1. Trim the body. Reject it if it is longer than `MAX_MESSAGE_LEN` bytes or
///    empty, and require a nonce.
/// 2. Store the message with a server-side timestamp.
/// 3. Send `ChannelMessageSent` to the participant connections returned by the
///    database, with this connection passed to `broadcast` as the sender.
/// 4. Reply with the stored message.
/// 5. Send the new latest message id to every other connection subscribed to the
///    channel, then deliver notifications.
async fn send_channel_message(
    request: proto::SendChannelMessage,
    response: Response<proto::SendChannelMessage>,
    session: UserSession,
) -> Result<()> {
    // Validate the message body.
    let body = request.body.trim().to_string();
    // `len()` is measured in bytes.
    if body.len() > MAX_MESSAGE_LEN {
        return Err(anyhow!("message is too long"))?;
    }
    if body.is_empty() {
        return Err(anyhow!("message can't be blank"))?;
    }

    // TODO: adjust mentions if body is trimmed
    // NOTE(review): `request.mentions` is stored and echoed unchanged, so if its
    // ranges index into the untrimmed body they are off after leading whitespace
    // is trimmed.

    let timestamp = OffsetDateTime::now_utc();
    let nonce = request
        .nonce
        .ok_or_else(|| anyhow!("nonce can't be blank"))?;

    let channel_id = ChannelId::from_proto(request.channel_id);
    let CreatedChannelMessage {
        message_id,
        participant_connection_ids,
        notifications,
    } = session
        .db()
        .await
        .create_channel_message(
            channel_id,
            session.user_id(),
            &body,
            &request.mentions,
            timestamp,
            nonce.clone().into(),
            request.reply_to_message_id.map(MessageId::from_proto),
        )
        .await?;

    let message = proto::ChannelMessage {
        sender_id: session.user_id().to_proto(),
        id: message_id.to_proto(),
        body,
        mentions: request.mentions,
        timestamp: timestamp.unix_timestamp() as u64,
        nonce: Some(nonce),
        reply_to_message_id: request.reply_to_message_id,
        edited_at: None,
    };
    // Cloned because the participant list is reused below to find non-participants.
    broadcast(
        Some(session.connection_id),
        participant_connection_ids.clone(),
        |connection| {
            session.peer.send(
                connection,
                proto::ChannelMessageSent {
                    channel_id: channel_id.to_proto(),
                    message: Some(message.clone()),
                },
            )
        },
    );
    response.send(proto::SendChannelMessageResponse {
        message: Some(message),
    })?;

    // Channel subscribers that are not participants only learn the new latest
    // message id, not the message itself.
    let pool = &*session.connection_pool().await;
    let non_participants =
        pool.channel_connection_ids(channel_id)
            .filter_map(|(connection_id, _)| {
                if participant_connection_ids.contains(&connection_id) {
                    None
                } else {
                    Some(connection_id)
                }
            });
    broadcast(None, non_participants, |peer_id| {
        session.peer.send(
            peer_id,
            proto::UpdateChannels {
                latest_channel_message_ids: vec![proto::ChannelMessageId {
                    channel_id: channel_id.to_proto(),
                    message_id: message_id.to_proto(),
                }],
                ..Default::default()
            },
        )
    });
    send_notifications(pool, &session.peer, notifications);

    Ok(())
}
4307
4308/// Delete a channel message
4309async fn remove_channel_message(
4310 request: proto::RemoveChannelMessage,
4311 response: Response<proto::RemoveChannelMessage>,
4312 session: UserSession,
4313) -> Result<()> {
4314 let channel_id = ChannelId::from_proto(request.channel_id);
4315 let message_id = MessageId::from_proto(request.message_id);
4316 let (connection_ids, existing_notification_ids) = session
4317 .db()
4318 .await
4319 .remove_channel_message(channel_id, message_id, session.user_id())
4320 .await?;
4321
4322 broadcast(
4323 Some(session.connection_id),
4324 connection_ids,
4325 move |connection| {
4326 session.peer.send(connection, request.clone())?;
4327
4328 for notification_id in &existing_notification_ids {
4329 session.peer.send(
4330 connection,
4331 proto::DeleteNotification {
4332 notification_id: (*notification_id).to_proto(),
4333 },
4334 )?;
4335 }
4336
4337 Ok(())
4338 },
4339 );
4340 response.send(proto::Ack {})?;
4341 Ok(())
4342}
4343
4344async fn update_channel_message(
4345 request: proto::UpdateChannelMessage,
4346 response: Response<proto::UpdateChannelMessage>,
4347 session: UserSession,
4348) -> Result<()> {
4349 let channel_id = ChannelId::from_proto(request.channel_id);
4350 let message_id = MessageId::from_proto(request.message_id);
4351 let updated_at = OffsetDateTime::now_utc();
4352 let UpdatedChannelMessage {
4353 message_id,
4354 participant_connection_ids,
4355 notifications,
4356 reply_to_message_id,
4357 timestamp,
4358 deleted_mention_notification_ids,
4359 updated_mention_notifications,
4360 } = session
4361 .db()
4362 .await
4363 .update_channel_message(
4364 channel_id,
4365 message_id,
4366 session.user_id(),
4367 request.body.as_str(),
4368 &request.mentions,
4369 updated_at,
4370 )
4371 .await?;
4372
4373 let nonce = request
4374 .nonce
4375 .clone()
4376 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4377
4378 let message = proto::ChannelMessage {
4379 sender_id: session.user_id().to_proto(),
4380 id: message_id.to_proto(),
4381 body: request.body.clone(),
4382 mentions: request.mentions.clone(),
4383 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
4384 nonce: Some(nonce),
4385 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
4386 edited_at: Some(updated_at.unix_timestamp() as u64),
4387 };
4388
4389 response.send(proto::Ack {})?;
4390
4391 let pool = &*session.connection_pool().await;
4392 broadcast(
4393 Some(session.connection_id),
4394 participant_connection_ids,
4395 |connection| {
4396 session.peer.send(
4397 connection,
4398 proto::ChannelMessageUpdate {
4399 channel_id: channel_id.to_proto(),
4400 message: Some(message.clone()),
4401 },
4402 )?;
4403
4404 for notification_id in &deleted_mention_notification_ids {
4405 session.peer.send(
4406 connection,
4407 proto::DeleteNotification {
4408 notification_id: (*notification_id).to_proto(),
4409 },
4410 )?;
4411 }
4412
4413 for notification in &updated_mention_notifications {
4414 session.peer.send(
4415 connection,
4416 proto::UpdateNotification {
4417 notification: Some(notification.clone()),
4418 },
4419 )?;
4420 }
4421
4422 Ok(())
4423 },
4424 );
4425
4426 send_notifications(pool, &session.peer, notifications);
4427
4428 Ok(())
4429}
4430
4431/// Mark a channel message as read
4432async fn acknowledge_channel_message(
4433 request: proto::AckChannelMessage,
4434 session: UserSession,
4435) -> Result<()> {
4436 let channel_id = ChannelId::from_proto(request.channel_id);
4437 let message_id = MessageId::from_proto(request.message_id);
4438 let notifications = session
4439 .db()
4440 .await
4441 .observe_channel_message(channel_id, session.user_id(), message_id)
4442 .await?;
4443 send_notifications(
4444 &*session.connection_pool().await,
4445 &session.peer,
4446 notifications,
4447 );
4448 Ok(())
4449}
4450
4451/// Mark a buffer version as synced
4452async fn acknowledge_buffer_version(
4453 request: proto::AckBufferOperation,
4454 session: UserSession,
4455) -> Result<()> {
4456 let buffer_id = BufferId::from_proto(request.buffer_id);
4457 session
4458 .db()
4459 .await
4460 .observe_buffer_version(
4461 buffer_id,
4462 session.user_id(),
4463 request.epoch as i32,
4464 &request.version,
4465 )
4466 .await?;
4467 Ok(())
4468}
4469
4470async fn count_language_model_tokens(
4471 request: proto::CountLanguageModelTokens,
4472 response: Response<proto::CountLanguageModelTokens>,
4473 session: Session,
4474 config: &Config,
4475) -> Result<()> {
4476 let Some(session) = session.for_user() else {
4477 return Err(anyhow!("user not found"))?;
4478 };
4479 authorize_access_to_legacy_llm_endpoints(&session).await?;
4480
4481 let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
4482 proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
4483 proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
4484 };
4485
4486 session
4487 .app_state
4488 .rate_limiter
4489 .check(&*rate_limit, session.user_id())
4490 .await?;
4491
4492 let result = match proto::LanguageModelProvider::from_i32(request.provider) {
4493 Some(proto::LanguageModelProvider::Google) => {
4494 let api_key = config
4495 .google_ai_api_key
4496 .as_ref()
4497 .context("no Google AI API key configured on the server")?;
4498 google_ai::count_tokens(
4499 session.http_client.as_ref(),
4500 google_ai::API_URL,
4501 api_key,
4502 serde_json::from_str(&request.request)?,
4503 None,
4504 )
4505 .await?
4506 }
4507 _ => return Err(anyhow!("unsupported provider"))?,
4508 };
4509
4510 response.send(proto::CountLanguageModelTokensResponse {
4511 token_count: result.total_tokens as u32,
4512 })?;
4513
4514 Ok(())
4515}
4516
4517struct ZedProCountLanguageModelTokensRateLimit;
4518
4519impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
4520 fn capacity(&self) -> usize {
4521 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
4522 .ok()
4523 .and_then(|v| v.parse().ok())
4524 .unwrap_or(600) // Picked arbitrarily
4525 }
4526
4527 fn refill_duration(&self) -> chrono::Duration {
4528 chrono::Duration::hours(1)
4529 }
4530
4531 fn db_name(&self) -> &'static str {
4532 "zed-pro:count-language-model-tokens"
4533 }
4534}
4535
4536struct FreeCountLanguageModelTokensRateLimit;
4537
4538impl RateLimit for FreeCountLanguageModelTokensRateLimit {
4539 fn capacity(&self) -> usize {
4540 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
4541 .ok()
4542 .and_then(|v| v.parse().ok())
4543 .unwrap_or(600 / 10) // Picked arbitrarily
4544 }
4545
4546 fn refill_duration(&self) -> chrono::Duration {
4547 chrono::Duration::hours(1)
4548 }
4549
4550 fn db_name(&self) -> &'static str {
4551 "free:count-language-model-tokens"
4552 }
4553}
4554
4555struct ZedProComputeEmbeddingsRateLimit;
4556
4557impl RateLimit for ZedProComputeEmbeddingsRateLimit {
4558 fn capacity(&self) -> usize {
4559 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4560 .ok()
4561 .and_then(|v| v.parse().ok())
4562 .unwrap_or(5000) // Picked arbitrarily
4563 }
4564
4565 fn refill_duration(&self) -> chrono::Duration {
4566 chrono::Duration::hours(1)
4567 }
4568
4569 fn db_name(&self) -> &'static str {
4570 "zed-pro:compute-embeddings"
4571 }
4572}
4573
4574struct FreeComputeEmbeddingsRateLimit;
4575
4576impl RateLimit for FreeComputeEmbeddingsRateLimit {
4577 fn capacity(&self) -> usize {
4578 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
4579 .ok()
4580 .and_then(|v| v.parse().ok())
4581 .unwrap_or(5000 / 10) // Picked arbitrarily
4582 }
4583
4584 fn refill_duration(&self) -> chrono::Duration {
4585 chrono::Duration::hours(1)
4586 }
4587
4588 fn db_name(&self) -> &'static str {
4589 "free:compute-embeddings"
4590 }
4591}
4592
4593async fn compute_embeddings(
4594 request: proto::ComputeEmbeddings,
4595 response: Response<proto::ComputeEmbeddings>,
4596 session: UserSession,
4597 api_key: Option<Arc<str>>,
4598) -> Result<()> {
4599 let api_key = api_key.context("no OpenAI API key configured on the server")?;
4600 authorize_access_to_legacy_llm_endpoints(&session).await?;
4601
4602 let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
4603 proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
4604 proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
4605 };
4606
4607 session
4608 .app_state
4609 .rate_limiter
4610 .check(&*rate_limit, session.user_id())
4611 .await?;
4612
4613 let embeddings = match request.model.as_str() {
4614 "openai/text-embedding-3-small" => {
4615 open_ai::embed(
4616 session.http_client.as_ref(),
4617 OPEN_AI_API_URL,
4618 &api_key,
4619 OpenAiEmbeddingModel::TextEmbedding3Small,
4620 request.texts.iter().map(|text| text.as_str()),
4621 )
4622 .await?
4623 }
4624 provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
4625 };
4626
4627 let embeddings = request
4628 .texts
4629 .iter()
4630 .map(|text| {
4631 let mut hasher = sha2::Sha256::new();
4632 hasher.update(text.as_bytes());
4633 let result = hasher.finalize();
4634 result.to_vec()
4635 })
4636 .zip(
4637 embeddings
4638 .data
4639 .into_iter()
4640 .map(|embedding| embedding.embedding),
4641 )
4642 .collect::<HashMap<_, _>>();
4643
4644 let db = session.db().await;
4645 db.save_embeddings(&request.model, &embeddings)
4646 .await
4647 .context("failed to save embeddings")
4648 .trace_err();
4649
4650 response.send(proto::ComputeEmbeddingsResponse {
4651 embeddings: embeddings
4652 .into_iter()
4653 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4654 .collect(),
4655 })?;
4656 Ok(())
4657}
4658
4659async fn get_cached_embeddings(
4660 request: proto::GetCachedEmbeddings,
4661 response: Response<proto::GetCachedEmbeddings>,
4662 session: UserSession,
4663) -> Result<()> {
4664 authorize_access_to_legacy_llm_endpoints(&session).await?;
4665
4666 let db = session.db().await;
4667 let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
4668
4669 response.send(proto::GetCachedEmbeddingsResponse {
4670 embeddings: embeddings
4671 .into_iter()
4672 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4673 .collect(),
4674 })?;
4675 Ok(())
4676}
4677
4678/// This is leftover from before the LLM service.
4679///
4680/// The endpoints protected by this check will be moved there eventually.
4681async fn authorize_access_to_legacy_llm_endpoints(session: &UserSession) -> Result<(), Error> {
4682 if session.is_staff() {
4683 Ok(())
4684 } else {
4685 Err(anyhow!("permission denied"))?
4686 }
4687}
4688
4689/// Get a Supermaven API key for the user
4690async fn get_supermaven_api_key(
4691 _request: proto::GetSupermavenApiKey,
4692 response: Response<proto::GetSupermavenApiKey>,
4693 session: UserSession,
4694) -> Result<()> {
4695 let user_id: String = session.user_id().to_string();
4696 if !session.is_staff() {
4697 return Err(anyhow!("supermaven not enabled for this account"))?;
4698 }
4699
4700 let email = session
4701 .email()
4702 .ok_or_else(|| anyhow!("user must have an email"))?;
4703
4704 let supermaven_admin_api = session
4705 .supermaven_client
4706 .as_ref()
4707 .ok_or_else(|| anyhow!("supermaven not configured"))?;
4708
4709 let result = supermaven_admin_api
4710 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4711 .await?;
4712
4713 response.send(proto::GetSupermavenApiKeyResponse {
4714 api_key: result.api_key,
4715 })?;
4716
4717 Ok(())
4718}
4719
4720/// Start receiving chat updates for a channel
4721async fn join_channel_chat(
4722 request: proto::JoinChannelChat,
4723 response: Response<proto::JoinChannelChat>,
4724 session: UserSession,
4725) -> Result<()> {
4726 let channel_id = ChannelId::from_proto(request.channel_id);
4727
4728 let db = session.db().await;
4729 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4730 .await?;
4731 let messages = db
4732 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4733 .await?;
4734 response.send(proto::JoinChannelChatResponse {
4735 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4736 messages,
4737 })?;
4738 Ok(())
4739}
4740
4741/// Stop receiving chat updates for a channel
4742async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
4743 let channel_id = ChannelId::from_proto(request.channel_id);
4744 session
4745 .db()
4746 .await
4747 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4748 .await?;
4749 Ok(())
4750}
4751
4752/// Retrieve the chat history for a channel
4753async fn get_channel_messages(
4754 request: proto::GetChannelMessages,
4755 response: Response<proto::GetChannelMessages>,
4756 session: UserSession,
4757) -> Result<()> {
4758 let channel_id = ChannelId::from_proto(request.channel_id);
4759 let messages = session
4760 .db()
4761 .await
4762 .get_channel_messages(
4763 channel_id,
4764 session.user_id(),
4765 MESSAGE_COUNT_PER_PAGE,
4766 Some(MessageId::from_proto(request.before_message_id)),
4767 )
4768 .await?;
4769 response.send(proto::GetChannelMessagesResponse {
4770 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4771 messages,
4772 })?;
4773 Ok(())
4774}
4775
4776/// Retrieve specific chat messages
4777async fn get_channel_messages_by_id(
4778 request: proto::GetChannelMessagesById,
4779 response: Response<proto::GetChannelMessagesById>,
4780 session: UserSession,
4781) -> Result<()> {
4782 let message_ids = request
4783 .message_ids
4784 .iter()
4785 .map(|id| MessageId::from_proto(*id))
4786 .collect::<Vec<_>>();
4787 let messages = session
4788 .db()
4789 .await
4790 .get_channel_messages_by_id(session.user_id(), &message_ids)
4791 .await?;
4792 response.send(proto::GetChannelMessagesResponse {
4793 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4794 messages,
4795 })?;
4796 Ok(())
4797}
4798
4799/// Retrieve the current users notifications
4800async fn get_notifications(
4801 request: proto::GetNotifications,
4802 response: Response<proto::GetNotifications>,
4803 session: UserSession,
4804) -> Result<()> {
4805 let notifications = session
4806 .db()
4807 .await
4808 .get_notifications(
4809 session.user_id(),
4810 NOTIFICATION_COUNT_PER_PAGE,
4811 request.before_id.map(db::NotificationId::from_proto),
4812 )
4813 .await?;
4814 response.send(proto::GetNotificationsResponse {
4815 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4816 notifications,
4817 })?;
4818 Ok(())
4819}
4820
4821/// Mark notifications as read
4822async fn mark_notification_as_read(
4823 request: proto::MarkNotificationRead,
4824 response: Response<proto::MarkNotificationRead>,
4825 session: UserSession,
4826) -> Result<()> {
4827 let database = &session.db().await;
4828 let notifications = database
4829 .mark_notification_as_read_by_id(
4830 session.user_id(),
4831 NotificationId::from_proto(request.notification_id),
4832 )
4833 .await?;
4834 send_notifications(
4835 &*session.connection_pool().await,
4836 &session.peer,
4837 notifications,
4838 );
4839 response.send(proto::Ack {})?;
4840 Ok(())
4841}
4842
4843/// Get the current users information
4844async fn get_private_user_info(
4845 _request: proto::GetPrivateUserInfo,
4846 response: Response<proto::GetPrivateUserInfo>,
4847 session: UserSession,
4848) -> Result<()> {
4849 let db = session.db().await;
4850
4851 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4852 let user = db
4853 .get_user_by_id(session.user_id())
4854 .await?
4855 .ok_or_else(|| anyhow!("user not found"))?;
4856 let flags = db.get_user_flags(session.user_id()).await?;
4857
4858 response.send(proto::GetPrivateUserInfoResponse {
4859 metrics_id,
4860 staff: user.admin,
4861 flags,
4862 accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4863 })?;
4864 Ok(())
4865}
4866
4867/// Accept the terms of service (tos) on behalf of the current user
4868async fn accept_terms_of_service(
4869 _request: proto::AcceptTermsOfService,
4870 response: Response<proto::AcceptTermsOfService>,
4871 session: UserSession,
4872) -> Result<()> {
4873 let db = session.db().await;
4874
4875 let accepted_tos_at = Utc::now();
4876 db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4877 .await?;
4878
4879 response.send(proto::AcceptTermsOfServiceResponse {
4880 accepted_tos_at: accepted_tos_at.timestamp() as u64,
4881 })?;
4882 Ok(())
4883}
4884
/// The minimum account age an account must have in order to use the LLM service.
///
/// `get_llm_api_token` measures the age from the earlier of two times: the
/// Zed account's creation time and, when known, the GitHub account's creation
/// time.
const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4887
4888async fn get_llm_api_token(
4889 _request: proto::GetLlmToken,
4890 response: Response<proto::GetLlmToken>,
4891 session: UserSession,
4892) -> Result<()> {
4893 let db = session.db().await;
4894
4895 let flags = db.get_user_flags(session.user_id()).await?;
4896 let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
4897 let has_llm_closed_beta_feature_flag = flags.iter().any(|flag| flag == "llm-closed-beta");
4898
4899 if !session.is_staff() && !has_language_models_feature_flag {
4900 Err(anyhow!("permission denied"))?
4901 }
4902
4903 let user_id = session.user_id();
4904 let user = db
4905 .get_user_by_id(user_id)
4906 .await?
4907 .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4908
4909 if user.accepted_tos_at.is_none() {
4910 Err(anyhow!("terms of service not accepted"))?
4911 }
4912
4913 let mut account_created_at = user.created_at;
4914 if let Some(github_created_at) = user.github_user_created_at {
4915 account_created_at = account_created_at.min(github_created_at);
4916 }
4917 if Utc::now().naive_utc() - account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE {
4918 Err(anyhow!("account too young"))?
4919 }
4920
4921 let billing_preferences = db.get_billing_preferences(user.id).await?;
4922
4923 let token = LlmTokenClaims::create(
4924 user.id,
4925 user.github_login.clone(),
4926 session.is_staff(),
4927 billing_preferences,
4928 has_llm_closed_beta_feature_flag,
4929 session.has_llm_subscription(&db).await?,
4930 session.current_plan(&db).await?,
4931 &session.app_state.config,
4932 )?;
4933 response.send(proto::GetLlmTokenResponse { token })?;
4934 Ok(())
4935}
4936
4937fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4938 let message = match message {
4939 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4940 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4941 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4942 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4943 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4944 code: frame.code.into(),
4945 reason: frame.reason,
4946 })),
4947 // We should never receive a frame while reading the message, according
4948 // to the `tungstenite` maintainers:
4949 //
4950 // > It cannot occur when you read messages from the WebSocket, but it
4951 // > can be used when you want to send the raw frames (e.g. you want to
4952 // > send the frames to the WebSocket without composing the full message first).
4953 // >
4954 // > — https://github.com/snapview/tungstenite-rs/issues/268
4955 TungsteniteMessage::Frame(_) => {
4956 bail!("received an unexpected frame while reading the message")
4957 }
4958 };
4959
4960 Ok(message)
4961}
4962
4963fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4964 match message {
4965 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4966 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4967 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4968 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4969 AxumMessage::Close(frame) => {
4970 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4971 code: frame.code.into(),
4972 reason: frame.reason,
4973 }))
4974 }
4975 }
4976}
4977
/// Apply a channel-membership change for `user_id` and tell every one of that
/// user's connections about it.
///
/// The in-memory connection pool is updated first. The user is subscribed to
/// each channel in `result.new_channels.channel_memberships` (with its role)
/// and unsubscribed from each channel in `result.removed_channels`.
///
/// Each of the user's connections is then sent two messages:
/// - an `UpdateUserChannels` carrying the new memberships;
/// - an `UpdateChannels` built from `result.new_channels`. It also lists the
///   removed channels as deleted and removes the invitation for
///   `result.channel_id`.
///
/// Send failures are logged via `trace_err` and otherwise ignored.
fn notify_membership_updated(
    connection_pool: &mut ConnectionPool,
    result: MembershipUpdated,
    user_id: UserId,
    peer: &Peer,
) {
    // Update the pool's channel subscriptions before any messages are sent.
    for membership in &result.new_channels.channel_memberships {
        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
    }
    for channel_id in &result.removed_channels {
        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
    }

    // Built from a borrow because `result.new_channels` is consumed by
    // `build_channels_update` below.
    let user_channels_update = proto::UpdateUserChannels {
        channel_memberships: result
            .new_channels
            .channel_memberships
            .iter()
            .map(|cm| proto::ChannelMembership {
                channel_id: cm.channel_id.to_proto(),
                role: cm.role.into(),
            })
            .collect(),
        ..Default::default()
    };

    let mut update = build_channels_update(result.new_channels);
    update.delete_channels = result
        .removed_channels
        .into_iter()
        .map(|id| id.to_proto())
        .collect();
    // The membership change resolves the invitation for this channel.
    update.remove_channel_invitations = vec![result.channel_id.to_proto()];

    for connection_id in connection_pool.user_connection_ids(user_id) {
        peer.send(connection_id, user_channels_update.clone())
            .trace_err();
        peer.send(connection_id, update.clone()).trace_err();
    }
}
5018
5019fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
5020 proto::UpdateUserChannels {
5021 channel_memberships: channels
5022 .channel_memberships
5023 .iter()
5024 .map(|m| proto::ChannelMembership {
5025 channel_id: m.channel_id.to_proto(),
5026 role: m.role.into(),
5027 })
5028 .collect(),
5029 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
5030 observed_channel_message_id: channels.observed_channel_messages.clone(),
5031 }
5032}
5033
5034fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
5035 let mut update = proto::UpdateChannels::default();
5036
5037 for channel in channels.channels {
5038 update.channels.push(channel.to_proto());
5039 }
5040
5041 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
5042 update.latest_channel_message_ids = channels.latest_channel_messages;
5043
5044 for (channel_id, participants) in channels.channel_participants {
5045 update
5046 .channel_participants
5047 .push(proto::ChannelParticipants {
5048 channel_id: channel_id.to_proto(),
5049 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
5050 });
5051 }
5052
5053 for channel in channels.invited_channels {
5054 update.channel_invitations.push(channel.to_proto());
5055 }
5056
5057 update.hosted_projects = channels.hosted_projects;
5058 update
5059}
5060
5061fn build_initial_contacts_update(
5062 contacts: Vec<db::Contact>,
5063 pool: &ConnectionPool,
5064) -> proto::UpdateContacts {
5065 let mut update = proto::UpdateContacts::default();
5066
5067 for contact in contacts {
5068 match contact {
5069 db::Contact::Accepted { user_id, busy } => {
5070 update.contacts.push(contact_for_user(user_id, busy, pool));
5071 }
5072 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
5073 db::Contact::Incoming { user_id } => {
5074 update
5075 .incoming_requests
5076 .push(proto::IncomingContactRequest {
5077 requester_id: user_id.to_proto(),
5078 })
5079 }
5080 }
5081 }
5082
5083 update
5084}
5085
5086fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
5087 proto::Contact {
5088 user_id: user_id.to_proto(),
5089 online: pool.is_user_online(user_id),
5090 busy,
5091 }
5092}
5093
5094fn room_updated(room: &proto::Room, peer: &Peer) {
5095 broadcast(
5096 None,
5097 room.participants
5098 .iter()
5099 .filter_map(|participant| Some(participant.peer_id?.into())),
5100 |peer_id| {
5101 peer.send(
5102 peer_id,
5103 proto::RoomUpdated {
5104 room: Some(room.clone()),
5105 },
5106 )
5107 },
5108 );
5109}
5110
5111fn channel_updated(
5112 channel: &db::channel::Model,
5113 room: &proto::Room,
5114 peer: &Peer,
5115 pool: &ConnectionPool,
5116) {
5117 let participants = room
5118 .participants
5119 .iter()
5120 .map(|p| p.user_id)
5121 .collect::<Vec<_>>();
5122
5123 broadcast(
5124 None,
5125 pool.channel_connection_ids(channel.root_id())
5126 .filter_map(|(channel_id, role)| {
5127 role.can_see_channel(channel.visibility)
5128 .then_some(channel_id)
5129 }),
5130 |peer_id| {
5131 peer.send(
5132 peer_id,
5133 proto::UpdateChannels {
5134 channel_participants: vec![proto::ChannelParticipants {
5135 channel_id: channel.id.to_proto(),
5136 participant_user_ids: participants.clone(),
5137 }],
5138 ..Default::default()
5139 },
5140 )
5141 },
5142 );
5143}
5144
5145async fn send_dev_server_projects_update(
5146 user_id: UserId,
5147 mut status: proto::DevServerProjectsUpdate,
5148 session: &Session,
5149) {
5150 let pool = session.connection_pool().await;
5151 for dev_server in &mut status.dev_servers {
5152 dev_server.status =
5153 pool.dev_server_status(DevServerId(dev_server.dev_server_id as i32)) as i32;
5154 }
5155 let connections = pool.user_connection_ids(user_id);
5156 for connection_id in connections {
5157 session.peer.send(connection_id, status.clone()).trace_err();
5158 }
5159}
5160
5161async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
5162 let db = session.db().await;
5163
5164 let contacts = db.get_contacts(user_id).await?;
5165 let busy = db.is_user_busy(user_id).await?;
5166
5167 let pool = session.connection_pool().await;
5168 let updated_contact = contact_for_user(user_id, busy, &pool);
5169 for contact in contacts {
5170 if let db::Contact::Accepted {
5171 user_id: contact_user_id,
5172 ..
5173 } = contact
5174 {
5175 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
5176 session
5177 .peer
5178 .send(
5179 contact_conn_id,
5180 proto::UpdateContacts {
5181 contacts: vec![updated_contact.clone()],
5182 remove_contacts: Default::default(),
5183 incoming_requests: Default::default(),
5184 remove_incoming_requests: Default::default(),
5185 outgoing_requests: Default::default(),
5186 remove_outgoing_requests: Default::default(),
5187 },
5188 )
5189 .trace_err();
5190 }
5191 }
5192 }
5193 Ok(())
5194}
5195
/// Clean up after a dev server's connection is lost.
///
/// Unshares every project the database reports as stale for this connection.
/// It then sends the dev server's user (`session.dev_server().user_id`) a
/// fresh dev-server-projects update.
///
/// NOTE(review): an error while unsharing any one project returns early via
/// `?`. That leaves the remaining projects shared and skips the status
/// update. Confirm this is intended.
async fn lost_dev_server_connection(session: &DevServerSession) -> Result<()> {
    log::info!("lost dev server connection, unsharing projects");
    let project_ids = session
        .db()
        .await
        .get_stale_dev_server_projects(session.connection_id)
        .await?;

    for project_id in project_ids {
        // Note: unsharing re-checks that the connection ids match, so we get away without a transaction.
        unshare_project_internal(project_id, session.connection_id, None, session).await?;
    }

    let user_id = session.dev_server().user_id;
    let update = session
        .db()
        .await
        .dev_server_projects_update(user_id)
        .await?;

    send_dev_server_projects_update(user_id, update, session).await;

    Ok(())
}
5220
/// Remove `connection_id` from the room it is in (if any) and propagate the
/// consequences.
///
/// Returns `Ok(())` immediately if the database reports no room was left.
/// Otherwise, in order:
/// 1. notifies collaborators of each project that was left;
/// 2. broadcasts the updated room to its participants;
/// 3. if the room belongs to a channel, broadcasts the channel's participants;
/// 4. sends `CallCanceled` to every user whose call was canceled;
/// 5. refreshes contact status for the session's user and those callees;
/// 6. removes the session's user from the LiveKit room, and deletes the room
///    when `left_room.deleted` is set.
///
/// Message-send and LiveKit failures are logged via `trace_err`. Database
/// failures are propagated.
///
/// NOTE(review): the session's user id and connection id are used throughout.
/// This presumably assumes `connection_id` belongs to `session`. Confirm
/// against the callers.
async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Declared up front so they can be initialized inside the `if let` below
    // and used after it; the `else` branch returns early.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        // Move fields out of `left_room` piecewise rather than cloning them.
        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scope the connection-pool guard so it is released before
    // `update_user_contacts`, which acquires the pool again.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
5294
5295async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
5296 let left_channel_buffers = session
5297 .db()
5298 .await
5299 .leave_channel_buffers(session.connection_id)
5300 .await?;
5301
5302 for left_buffer in left_channel_buffers {
5303 channel_buffer_updated(
5304 session.connection_id,
5305 left_buffer.connections,
5306 &proto::UpdateChannelBufferCollaborators {
5307 channel_id: left_buffer.channel_id.to_proto(),
5308 collaborators: left_buffer.collaborators,
5309 },
5310 &session.peer,
5311 );
5312 }
5313
5314 Ok(())
5315}
5316
5317fn project_left(project: &db::LeftProject, session: &UserSession) {
5318 for connection_id in &project.connection_ids {
5319 if project.should_unshare {
5320 session
5321 .peer
5322 .send(
5323 *connection_id,
5324 proto::UnshareProject {
5325 project_id: project.id.to_proto(),
5326 },
5327 )
5328 .trace_err();
5329 } else {
5330 session
5331 .peer
5332 .send(
5333 *connection_id,
5334 proto::RemoveProjectCollaborator {
5335 project_id: project.id.to_proto(),
5336 peer_id: Some(session.connection_id.into()),
5337 },
5338 )
5339 .trace_err();
5340 }
5341 }
5342}
5343
5344pub trait ResultExt {
5345 type Ok;
5346
5347 fn trace_err(self) -> Option<Self::Ok>;
5348}
5349
5350impl<T, E> ResultExt for Result<T, E>
5351where
5352 E: std::fmt::Debug,
5353{
5354 type Ok = T;
5355
5356 #[track_caller]
5357 fn trace_err(self) -> Option<T> {
5358 match self {
5359 Ok(value) => Some(value),
5360 Err(error) => {
5361 tracing::error!("{:?}", error);
5362 None
5363 }
5364 }
5365 }
5366}