1mod connection_pool;
2
3use crate::api::CloudflareIpCountryHeader;
4use crate::llm::LlmTokenClaims;
5use crate::{
6 auth,
7 db::{
8 self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
9 CreatedChannelMessage, Database, DevServerId, DevServerProjectId, InviteMemberResult,
10 MembershipUpdated, MessageId, NotificationId, PrincipalId, Project, ProjectId,
11 RejoinedProject, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
12 ServerId, UpdatedChannelMessage, User, UserId,
13 },
14 executor::Executor,
15 AppState, Config, Error, RateLimit, Result,
16};
17use anyhow::{anyhow, bail, Context as _};
18use async_tungstenite::tungstenite::{
19 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
20};
21use axum::{
22 body::Body,
23 extract::{
24 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
25 ConnectInfo, WebSocketUpgrade,
26 },
27 headers::{Header, HeaderName},
28 http::StatusCode,
29 middleware,
30 response::IntoResponse,
31 routing::get,
32 Extension, Router, TypedHeader,
33};
34use chrono::Utc;
35use collections::{HashMap, HashSet};
36pub use connection_pool::{ConnectionPool, ZedVersion};
37use core::fmt::{self, Debug, Formatter};
38use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
39use sha2::Digest;
40use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
41
42use futures::{
43 channel::oneshot,
44 future::{self, BoxFuture},
45 stream::FuturesUnordered,
46 FutureExt, SinkExt, StreamExt, TryStreamExt,
47};
48use http_client::IsahcHttpClient;
49use prometheus::{register_int_gauge, IntGauge};
50use rpc::{
51 proto::{
52 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
53 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
54 },
55 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
56};
57use semantic_version::SemanticVersion;
58use serde::{Serialize, Serializer};
59use std::{
60 any::TypeId,
61 future::Future,
62 marker::PhantomData,
63 mem,
64 net::SocketAddr,
65 ops::{Deref, DerefMut},
66 rc::Rc,
67 sync::{
68 atomic::{AtomicBool, Ordering::SeqCst},
69 Arc, OnceLock,
70 },
71 time::{Duration, Instant},
72};
73use time::OffsetDateTime;
74use tokio::sync::{watch, Semaphore};
75use tower::ServiceBuilder;
76use tracing::{
77 field::{self},
78 info_span, instrument, Instrument,
79};
80
81use self::connection_pool::VersionedMessage;
82
/// Reconnection timeout.
// NOTE(review): not used in this excerpt; presumably the grace period a disconnected client
// has to reconnect before its state is released — confirm at the use sites.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can
// clean up old resources. `Server::start` waits this long before clearing stale rooms and
// channel buffers.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// NOTE(review): the three limits below are not used in this excerpt; their names suggest a
// channel-message page size, a maximum channel-message length, and a notification page size —
// confirm at the use sites.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
91
92type MessageHandler =
93 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
94
/// Reply handle passed to request handlers registered with `Server::add_request_handler`.
///
/// `responded` is shared with `add_request_handler`, which inspects it once the handler
/// finishes to detect handlers that returned without replying.
struct Response<R> {
    peer: Arc<Peer>,
    receipt: Receipt<R>,
    /// Set to `true` by `Response::send`.
    responded: Arc<AtomicBool>,
}
100
101impl<R: RequestMessage> Response<R> {
102 fn send(self, payload: R::Response) -> Result<()> {
103 self.responded.store(true, SeqCst);
104 self.peer.respond(self.receipt, payload)?;
105 Ok(())
106 }
107}
108
/// The identity a connection acts as.
#[derive(Clone, Debug)]
pub enum Principal {
    /// A user connected as themselves.
    User(User),
    /// A session for `user` opened by `admin` (recorded as the "impersonator" in logs).
    Impersonated { user: User, admin: User },
    /// A dev server connection.
    DevServer(dev_server::Model),
}
115
116impl Principal {
117 fn update_span(&self, span: &tracing::Span) {
118 match &self {
119 Principal::User(user) => {
120 span.record("user_id", &user.id.0);
121 span.record("login", &user.github_login);
122 }
123 Principal::Impersonated { user, admin } => {
124 span.record("user_id", &user.id.0);
125 span.record("login", &user.github_login);
126 span.record("impersonator", &admin.github_login);
127 }
128 Principal::DevServer(dev_server) => {
129 span.record("dev_server_id", &dev_server.id.0);
130 }
131 }
132 }
133}
134
/// Per-connection state; `handle_connection` builds one per connection and clones it into
/// every message handler it dispatches.
#[derive(Clone)]
struct Session {
    principal: Principal,
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; acquire it via `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Shared with `Server`; acquire it via `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Present only when `supermaven_admin_api_key` is configured.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    http_client: Arc<IsahcHttpClient>,
    /// The GeoIP country code for the user.
    // Kept for future use; nothing in this excerpt reads it, hence `allow(unused)`.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    _executor: Executor,
}
150
151impl Session {
152 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
153 #[cfg(test)]
154 tokio::task::yield_now().await;
155 let guard = self.db.lock().await;
156 #[cfg(test)]
157 tokio::task::yield_now().await;
158 guard
159 }
160
161 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
162 #[cfg(test)]
163 tokio::task::yield_now().await;
164 let guard = self.connection_pool.lock();
165 ConnectionPoolGuard {
166 guard,
167 _not_send: PhantomData,
168 }
169 }
170
171 fn for_user(self) -> Option<UserSession> {
172 UserSession::new(self)
173 }
174
175 fn for_dev_server(self) -> Option<DevServerSession> {
176 DevServerSession::new(self)
177 }
178
179 fn user_id(&self) -> Option<UserId> {
180 match &self.principal {
181 Principal::User(user) => Some(user.id),
182 Principal::Impersonated { user, .. } => Some(user.id),
183 Principal::DevServer(_) => None,
184 }
185 }
186
187 fn is_staff(&self) -> bool {
188 match &self.principal {
189 Principal::User(user) => user.admin,
190 Principal::Impersonated { .. } => true,
191 Principal::DevServer(_) => false,
192 }
193 }
194
195 pub async fn current_plan(&self) -> anyhow::Result<proto::Plan> {
196 if self.is_staff() {
197 return Ok(proto::Plan::ZedPro);
198 }
199
200 let Some(user_id) = self.user_id() else {
201 return Ok(proto::Plan::Free);
202 };
203
204 let db = self.db().await;
205 if db.has_active_billing_subscription(user_id).await? {
206 Ok(proto::Plan::ZedPro)
207 } else {
208 Ok(proto::Plan::Free)
209 }
210 }
211
212 fn dev_server_id(&self) -> Option<DevServerId> {
213 match &self.principal {
214 Principal::User(_) | Principal::Impersonated { .. } => None,
215 Principal::DevServer(dev_server) => Some(dev_server.id),
216 }
217 }
218
219 fn principal_id(&self) -> PrincipalId {
220 match &self.principal {
221 Principal::User(user) => PrincipalId::UserId(user.id),
222 Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
223 Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
224 }
225 }
226}
227
228impl Debug for Session {
229 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
230 let mut result = f.debug_struct("Session");
231 match &self.principal {
232 Principal::User(user) => {
233 result.field("user", &user.github_login);
234 }
235 Principal::Impersonated { user, admin } => {
236 result.field("user", &user.github_login);
237 result.field("impersonator", &admin.github_login);
238 }
239 Principal::DevServer(dev_server) => {
240 result.field("dev_server", &dev_server.id);
241 }
242 }
243 result.field("connection_id", &self.connection_id).finish()
244 }
245}
246
247struct UserSession(Session);
248
249impl UserSession {
250 pub fn new(s: Session) -> Option<Self> {
251 s.user_id().map(|_| UserSession(s))
252 }
253 pub fn user_id(&self) -> UserId {
254 self.0.user_id().unwrap()
255 }
256
257 pub fn email(&self) -> Option<String> {
258 match &self.0.principal {
259 Principal::User(user) => user.email_address.clone(),
260 Principal::Impersonated { user, .. } => user.email_address.clone(),
261 Principal::DevServer(..) => None,
262 }
263 }
264}
265
266impl Deref for UserSession {
267 type Target = Session;
268
269 fn deref(&self) -> &Self::Target {
270 &self.0
271 }
272}
273impl DerefMut for UserSession {
274 fn deref_mut(&mut self) -> &mut Self::Target {
275 &mut self.0
276 }
277}
278
279struct DevServerSession(Session);
280
281impl DevServerSession {
282 pub fn new(s: Session) -> Option<Self> {
283 s.dev_server_id().map(|_| DevServerSession(s))
284 }
285 pub fn dev_server_id(&self) -> DevServerId {
286 self.0.dev_server_id().unwrap()
287 }
288
289 fn dev_server(&self) -> &dev_server::Model {
290 match &self.0.principal {
291 Principal::DevServer(dev_server) => dev_server,
292 _ => unreachable!(),
293 }
294 }
295}
296
297impl Deref for DevServerSession {
298 type Target = Session;
299
300 fn deref(&self) -> &Self::Target {
301 &self.0
302 }
303}
304impl DerefMut for DevServerSession {
305 fn deref_mut(&mut self) -> &mut Self::Target {
306 &mut self.0
307 }
308}
309
310fn user_handler<M: RequestMessage, Fut>(
311 handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
312) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
313where
314 Fut: Send + Future<Output = Result<()>>,
315{
316 let handler = Arc::new(handler);
317 move |message, response, session| {
318 let handler = handler.clone();
319 Box::pin(async move {
320 if let Some(user_session) = session.for_user() {
321 Ok(handler(message, response, user_session).await?)
322 } else {
323 Err(Error::Internal(anyhow!(
324 "must be a user to call {}",
325 M::NAME
326 )))
327 }
328 })
329 }
330}
331
332fn dev_server_handler<M: RequestMessage, Fut>(
333 handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
334) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
335where
336 Fut: Send + Future<Output = Result<()>>,
337{
338 let handler = Arc::new(handler);
339 move |message, response, session| {
340 let handler = handler.clone();
341 Box::pin(async move {
342 if let Some(dev_server_session) = session.for_dev_server() {
343 Ok(handler(message, response, dev_server_session).await?)
344 } else {
345 Err(Error::Internal(anyhow!(
346 "must be a dev server to call {}",
347 M::NAME
348 )))
349 }
350 })
351 }
352}
353
354fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
355 handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
356) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
357where
358 InnertRetFut: Send + Future<Output = Result<()>>,
359{
360 let handler = Arc::new(handler);
361 move |message, session| {
362 let handler = handler.clone();
363 Box::pin(async move {
364 if let Some(user_session) = session.for_user() {
365 Ok(handler(message, user_session).await?)
366 } else {
367 Err(Error::Internal(anyhow!(
368 "must be a user to call {}",
369 M::NAME
370 )))
371 }
372 })
373 }
374}
375
376struct DbHandle(Arc<Database>);
377
378impl Deref for DbHandle {
379 type Target = Database;
380
381 fn deref(&self) -> &Self::Target {
382 self.0.as_ref()
383 }
384}
385
/// The collaboration RPC server.
///
/// It owns the peer, the shared connection pool, and the table of per-message-type handlers.
pub struct Server {
    /// Behind a mutex so the test-only `reset` can replace it; `start` reads it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Keyed by the `TypeId` of the message type; populated by `add_handler`.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Set to `true` by `teardown`; connection loops subscribe to it and stop when it changes.
    teardown: watch::Sender<bool>,
}
394
/// Lock guard for the shared connection pool, returned by `Session::connection_pool`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    // `Rc` is `!Send`, so this marker makes the guard `!Send` and prevents holding it across
    // an `.await` inside a future that must be `Send`.
    _not_send: PhantomData<Rc<()>>,
}
399
/// Serializable view of server state: the peer and the locked connection pool.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // The guard itself is not `Serialize`; serialize the pool it dereferences to.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
406
407pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
408where
409 S: Serializer,
410 T: Deref<Target = U>,
411 U: Serialize,
412{
413 Serialize::serialize(value.deref(), serializer)
414}
415
416impl Server {
    /// Creates a server with the given `id` and registers a handler for every supported RPC
    /// message type.
    ///
    /// Registration follows these rules:
    /// - Requests, which must be answered, go through `add_request_handler`.
    /// - One-way messages go through `add_message_handler`.
    /// - Handlers wrapped in `user_handler` / `user_message_handler` reject non-user principals.
    /// - Handlers wrapped in `dev_server_handler` reject non-dev-server principals.
    /// - Unwrapped handlers receive the raw `Session`.
    ///
    /// # Panics
    ///
    /// Panics if two handlers are registered for the same message type; `add_handler` checks
    /// this.
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            // NOTE(review): `as u32` silently wraps/truncates if the inner id doesn't fit.
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            teardown: watch::channel(false).0,
        };

        server
            .add_request_handler(ping)
            .add_request_handler(user_handler(create_room))
            .add_request_handler(user_handler(join_room))
            .add_request_handler(user_handler(rejoin_room))
            .add_request_handler(user_handler(leave_room))
            .add_request_handler(user_handler(set_room_participant_role))
            .add_request_handler(user_handler(call))
            .add_request_handler(user_handler(cancel_call))
            .add_message_handler(user_message_handler(decline_call))
            .add_request_handler(user_handler(update_participant_location))
            .add_request_handler(user_handler(share_project))
            .add_message_handler(unshare_project)
            .add_request_handler(user_handler(join_project))
            .add_request_handler(user_handler(join_hosted_project))
            .add_request_handler(user_handler(rejoin_dev_server_projects))
            .add_request_handler(user_handler(create_dev_server_project))
            .add_request_handler(user_handler(update_dev_server_project))
            .add_request_handler(user_handler(delete_dev_server_project))
            .add_request_handler(user_handler(create_dev_server))
            .add_request_handler(user_handler(regenerate_dev_server_token))
            .add_request_handler(user_handler(rename_dev_server))
            .add_request_handler(user_handler(delete_dev_server))
            .add_request_handler(user_handler(list_remote_directory))
            // Only dev-server principals may send these three requests.
            .add_request_handler(dev_server_handler(share_dev_server_project))
            .add_request_handler(dev_server_handler(shutdown_dev_server))
            .add_request_handler(dev_server_handler(reconnect_dev_server))
            .add_message_handler(user_message_handler(leave_project))
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            .add_request_handler(user_handler(
                forward_project_request_for_owner::<proto::TaskContextForLocation>,
            ))
            .add_request_handler(user_handler(
                forward_project_request_for_owner::<proto::TaskTemplates>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetHover>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetTypeDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetReferences>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::SearchProject>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDocumentHighlights>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetProjectSymbols>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferForSymbol>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferById>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::SynchronizeBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::InlayHints>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferByPath>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCompletions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            ))
            .add_request_handler(user_handler(
                forward_versioned_mutating_project_request::<proto::OpenNewBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCodeActions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCodeAction>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PrepareRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PerformRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ReloadBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::FormatBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CreateProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RenameProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CopyProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::DeleteProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ExpandProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OnTypeFormatting>,
            ))
            .add_request_handler(user_handler(
                forward_versioned_mutating_project_request::<proto::SaveBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::BlameBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::MultiLspQuery>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RestartLanguageServers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::LinkedEditingRange>,
            ))
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            .add_request_handler(get_users)
            .add_request_handler(user_handler(fuzzy_search_users))
            .add_request_handler(user_handler(request_contact))
            .add_request_handler(user_handler(remove_contact))
            .add_request_handler(user_handler(respond_to_contact_request))
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(user_handler(create_channel))
            .add_request_handler(user_handler(delete_channel))
            .add_request_handler(user_handler(invite_channel_member))
            .add_request_handler(user_handler(remove_channel_member))
            .add_request_handler(user_handler(set_channel_member_role))
            .add_request_handler(user_handler(set_channel_visibility))
            .add_request_handler(user_handler(rename_channel))
            .add_request_handler(user_handler(join_channel_buffer))
            .add_request_handler(user_handler(leave_channel_buffer))
            .add_message_handler(user_message_handler(update_channel_buffer))
            .add_request_handler(user_handler(rejoin_channel_buffers))
            .add_request_handler(user_handler(get_channel_members))
            .add_request_handler(user_handler(respond_to_channel_invite))
            .add_request_handler(user_handler(join_channel))
            .add_request_handler(user_handler(join_channel_chat))
            .add_message_handler(user_message_handler(leave_channel_chat))
            .add_request_handler(user_handler(send_channel_message))
            .add_request_handler(user_handler(remove_channel_message))
            .add_request_handler(user_handler(update_channel_message))
            .add_request_handler(user_handler(get_channel_messages))
            .add_request_handler(user_handler(get_channel_messages_by_id))
            .add_request_handler(user_handler(get_notifications))
            .add_request_handler(user_handler(mark_notification_as_read))
            .add_request_handler(user_handler(move_channel))
            .add_request_handler(user_handler(follow))
            .add_message_handler(user_message_handler(unfollow))
            .add_message_handler(user_message_handler(update_followers))
            .add_request_handler(user_handler(get_private_user_info))
            .add_request_handler(user_handler(get_llm_api_token))
            .add_request_handler(user_handler(accept_terms_of_service))
            .add_message_handler(user_message_handler(acknowledge_channel_message))
            .add_message_handler(user_message_handler(acknowledge_buffer_version))
            .add_request_handler(user_handler(get_supermaven_api_key))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OpenContext>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CreateContext>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::SynchronizeContexts>,
            ))
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context)
            // Registered without `user_handler`: it receives the raw `Session` plus the
            // server `Config`.
            .add_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    let app_state = app_state.clone();
                    async move {
                        count_language_model_tokens(request, response, session, &app_state.config)
                            .await
                    }
                }
            })
            .add_request_handler({
                user_handler(move |request, response, session| {
                    get_cached_embeddings(request, response, session)
                })
            })
            // The OpenAI API key is cloned out of the config on every request.
            .add_request_handler({
                let app_state = app_state.clone();
                user_handler(move |request, response, session| {
                    compute_embeddings(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                    )
                })
            });

        Arc::new(server)
    }
652
    /// Schedules a background cleanup of stale server resources and returns immediately.
    ///
    /// The spawned task first waits for `CLEANUP_TIMEOUT`. It then asks the database for stale
    /// room and channel-buffer ids for this environment and `server_id`.
    // NOTE(review): presumably these are resources left behind by servers other than this
    // one — confirm against `stale_server_resource_ids`.
    ///
    /// For each stale channel buffer, the task clears stale collaborators and sends the
    /// refreshed collaborator list to the remaining connections.
    ///
    /// For each stale room, the task:
    /// 1. clears stale participants;
    /// 2. broadcasts the room update, plus a channel update when the room has a channel;
    /// 3. sends `CallCanceled` to the users whose calls were canceled;
    /// 4. pushes an updated contact entry for affected users to their accepted contacts;
    /// 5. deletes the LiveKit room once it has no participants.
    ///
    /// Finally the task deletes stale server records. Failures are passed through
    /// `trace_err()`, and the remaining cleanup still runs.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Channel buffers: drop stale collaborators and tell the remaining
                    // connections who is still collaborating.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // These defaults are used when clearing the room fails, so the steps
                        // below become no-ops.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the pool lock is released before the awaits below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            // Only notify when both lookups succeeded.
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
798
799 pub fn teardown(&self) {
800 self.peer.teardown();
801 self.connection_pool.lock().reset();
802 let _ = self.teardown.send(true);
803 }
804
805 #[cfg(test)]
806 pub fn reset(&self, id: ServerId) {
807 self.teardown();
808 *self.id.lock() = id;
809 self.peer.reset(id.0 as u32);
810 let _ = self.teardown.send(false);
811 }
812
813 #[cfg(test)]
814 pub fn id(&self) -> ServerId {
815 *self.id.lock()
816 }
817
818 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
819 where
820 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
821 Fut: 'static + Send + Future<Output = Result<()>>,
822 M: EnvelopedMessage,
823 {
824 let prev_handler = self.handlers.insert(
825 TypeId::of::<M>(),
826 Box::new(move |envelope, session| {
827 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
828 let received_at = envelope.received_at;
829 tracing::info!("message received");
830 let start_time = Instant::now();
831 let future = (handler)(*envelope, session);
832 async move {
833 let result = future.await;
834 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
835 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
836 let queue_duration_ms = total_duration_ms - processing_duration_ms;
837 let payload_type = M::NAME;
838
839 match result {
840 Err(error) => {
841 tracing::error!(
842 ?error,
843 total_duration_ms,
844 processing_duration_ms,
845 queue_duration_ms,
846 payload_type,
847 "error handling message"
848 )
849 }
850 Ok(()) => tracing::info!(
851 total_duration_ms,
852 processing_duration_ms,
853 queue_duration_ms,
854 "finished handling message"
855 ),
856 }
857 }
858 .boxed()
859 }),
860 );
861 if prev_handler.is_some() {
862 panic!("registered a handler for the same message twice");
863 }
864 self
865 }
866
867 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
868 where
869 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
870 Fut: 'static + Send + Future<Output = Result<()>>,
871 M: EnvelopedMessage,
872 {
873 self.add_handler(move |envelope, session| handler(envelope.payload, session));
874 self
875 }
876
877 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
878 where
879 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
880 Fut: Send + Future<Output = Result<()>>,
881 M: RequestMessage,
882 {
883 let handler = Arc::new(handler);
884 self.add_handler(move |envelope, session| {
885 let receipt = envelope.receipt();
886 let handler = handler.clone();
887 async move {
888 let peer = session.peer.clone();
889 let responded = Arc::new(AtomicBool::default());
890 let response = Response {
891 peer: peer.clone(),
892 responded: responded.clone(),
893 receipt,
894 };
895 match (handler)(envelope.payload, response, session).await {
896 Ok(()) => {
897 if responded.load(std::sync::atomic::Ordering::SeqCst) {
898 Ok(())
899 } else {
900 Err(anyhow!("handler did not send a response"))?
901 }
902 }
903 Err(error) => {
904 let proto_err = match &error {
905 Error::Internal(err) => err.to_proto(),
906 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
907 };
908 peer.respond_with_error(receipt, proto_err)?;
909 Err(error)
910 }
911 }
912 }
913 })
914 }
915
916 #[allow(clippy::too_many_arguments)]
917 pub fn handle_connection(
918 self: &Arc<Self>,
919 connection: Connection,
920 address: String,
921 principal: Principal,
922 zed_version: ZedVersion,
923 geoip_country_code: Option<String>,
924 send_connection_id: Option<oneshot::Sender<ConnectionId>>,
925 executor: Executor,
926 ) -> impl Future<Output = ()> {
927 let this = self.clone();
928 let span = info_span!("handle connection", %address,
929 connection_id=field::Empty,
930 user_id=field::Empty,
931 login=field::Empty,
932 impersonator=field::Empty,
933 dev_server_id=field::Empty,
934 geoip_country_code=field::Empty
935 );
936 principal.update_span(&span);
937 if let Some(country_code) = geoip_country_code.as_ref() {
938 span.record("geoip_country_code", country_code);
939 }
940
941 let mut teardown = self.teardown.subscribe();
942 async move {
943 if *teardown.borrow() {
944 tracing::error!("server is tearing down");
945 return
946 }
947 let (connection_id, handle_io, mut incoming_rx) = this
948 .peer
949 .add_connection(connection, {
950 let executor = executor.clone();
951 move |duration| executor.sleep(duration)
952 });
953 tracing::Span::current().record("connection_id", format!("{}", connection_id));
954
955 tracing::info!("connection opened");
956
957 let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
958 let http_client = match IsahcHttpClient::builder().default_header("User-Agent", user_agent).build() {
959 Ok(http_client) => Arc::new(http_client),
960 Err(error) => {
961 tracing::error!(?error, "failed to create HTTP client");
962 return;
963 }
964 };
965
966 let supermaven_client = if let Some(supermaven_admin_api_key) = this.app_state.config.supermaven_admin_api_key.clone() {
967 Some(Arc::new(SupermavenAdminApi::new(
968 supermaven_admin_api_key.to_string(),
969 http_client.clone(),
970 )))
971 } else {
972 None
973 };
974
975 let session = Session {
976 principal: principal.clone(),
977 connection_id,
978 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
979 peer: this.peer.clone(),
980 connection_pool: this.connection_pool.clone(),
981 app_state: this.app_state.clone(),
982 http_client,
983 geoip_country_code,
984 _executor: executor.clone(),
985 supermaven_client,
986 };
987
988 if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
989 tracing::error!(?error, "failed to send initial client update");
990 return;
991 }
992
993 let handle_io = handle_io.fuse();
994 futures::pin_mut!(handle_io);
995
996 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
997 // This prevents deadlocks when e.g., client A performs a request to client B and
998 // client B performs a request to client A. If both clients stop processing further
999 // messages until their respective request completes, they won't have a chance to
1000 // respond to the other client's request and cause a deadlock.
1001 //
1002 // This arrangement ensures we will attempt to process earlier messages first, but fall
1003 // back to processing messages arrived later in the spirit of making progress.
1004 let mut foreground_message_handlers = FuturesUnordered::new();
1005 let concurrent_handlers = Arc::new(Semaphore::new(256));
1006 loop {
1007 let next_message = async {
1008 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
1009 let message = incoming_rx.next().await;
1010 (permit, message)
1011 }.fuse();
1012 futures::pin_mut!(next_message);
1013 futures::select_biased! {
1014 _ = teardown.changed().fuse() => return,
1015 result = handle_io => {
1016 if let Err(error) = result {
1017 tracing::error!(?error, "error handling I/O");
1018 }
1019 break;
1020 }
1021 _ = foreground_message_handlers.next() => {}
1022 next_message = next_message => {
1023 let (permit, message) = next_message;
1024 if let Some(message) = message {
1025 let type_name = message.payload_type_name();
1026 // note: we copy all the fields from the parent span so we can query them in the logs.
1027 // (https://github.com/tokio-rs/tracing/issues/2670).
1028 let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
1029 user_id=field::Empty,
1030 login=field::Empty,
1031 impersonator=field::Empty,
1032 dev_server_id=field::Empty
1033 );
1034 principal.update_span(&span);
1035 let span_enter = span.enter();
1036 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
1037 let is_background = message.is_background();
1038 let handle_message = (handler)(message, session.clone());
1039 drop(span_enter);
1040
1041 let handle_message = async move {
1042 handle_message.await;
1043 drop(permit);
1044 }.instrument(span);
1045 if is_background {
1046 executor.spawn_detached(handle_message);
1047 } else {
1048 foreground_message_handlers.push(handle_message);
1049 }
1050 } else {
1051 tracing::error!("no message handler");
1052 }
1053 } else {
1054 tracing::info!("connection closed");
1055 break;
1056 }
1057 }
1058 }
1059 }
1060
1061 drop(foreground_message_handlers);
1062 tracing::info!("signing out");
1063 if let Err(error) = connection_lost(session, teardown, executor).await {
1064 tracing::error!(?error, "error signing out");
1065 }
1066
1067 }.instrument(span)
1068 }
1069
    /// Performs the initial handshake for a freshly added connection.
    ///
    /// Sends `proto::Hello` carrying the peer id assigned to `connection_id`, then
    /// reports `connection_id` through `send_connection_id` when one was provided.
    /// The remaining steps depend on `principal`:
    ///
    /// - Users (including impersonated users):
    ///   - On the very first connection, sends `ShowContacts` and records that the
    ///     user has connected once.
    ///   - Calls `update_user_plan`, then registers the connection in the pool and
    ///     sends the initial contacts update.
    ///   - Subscribes to channels when `zed_version` allows it.
    ///   - Sends the dev server projects status and forwards any pending incoming
    ///     call.
    ///   - Updates the user's contacts.
    /// - Dev servers:
    ///   - Sends `ShutdownDevServer` to any stale connection registered for the same
    ///     dev server and removes it from the pool.
    ///   - Registers this connection and sends `DevServerInstructions` with the
    ///     dev server's projects.
    ///   - Pushes the dev server projects status to the dev server's user.
    ///
    /// # Errors
    ///
    /// Returns the first error from sending a message, a database query, or
    /// removing the stale connection from the pool; the remaining steps are skipped.
    async fn send_initial_client_update(
        &self,
        connection_id: ConnectionId,
        principal: &Principal,
        zed_version: ZedVersion,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        session: &Session,
    ) -> Result<()> {
        // `Hello` carries the peer id assigned to this connection.
        self.peer.send(
            connection_id,
            proto::Hello {
                peer_id: Some(connection_id.into()),
            },
        )?;
        tracing::info!("sent hello message");
        if let Some(send_connection_id) = send_connection_id.take() {
            // The receiving side may already be gone; a failed send is ignored.
            let _ = send_connection_id.send(connection_id);
        }

        match principal {
            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
                // First connection ever for this user: show contacts, then persist
                // the `connected_once` flag.
                if !user.connected_once {
                    self.peer.send(connection_id, proto::ShowContacts {})?;
                    self.app_state
                        .db
                        .set_user_connected_once(user.id, true)
                        .await?;
                }

                update_user_plan(user.id, session).await?;

                // Contacts and dev server project status are fetched concurrently.
                let (contacts, dev_server_projects) = future::try_join(
                    self.app_state.db.get_contacts(user.id),
                    self.app_state.db.dev_server_projects_update(user.id),
                )
                .await?;

                {
                    // The connection is added before building the contacts update,
                    // which is computed against this same locked pool.
                    let mut pool = self.connection_pool.lock();
                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
                    self.peer.send(
                        connection_id,
                        build_initial_contacts_update(contacts, &pool),
                    )?;
                }

                // Whether channels are auto-subscribed depends on the client version.
                if should_auto_subscribe_to_channels(zed_version) {
                    subscribe_user_to_channels(user.id, session).await?;
                }

                send_dev_server_projects_update(user.id, dev_server_projects, session).await;

                // Forward a call that is still ringing for this user, if any.
                if let Some(incoming_call) =
                    self.app_state.db.incoming_call_for_user(user.id).await?
                {
                    self.peer.send(connection_id, incoming_call)?;
                }

                update_user_contacts(user.id, &session).await?;
            }
            Principal::DevServer(dev_server) => {
                {
                    // Only one connection per dev server is kept: the previously
                    // registered one is told to shut down and removed from the pool.
                    let mut pool = self.connection_pool.lock();
                    if let Some(stale_connection_id) = pool.dev_server_connection_id(dev_server.id)
                    {
                        self.peer.send(
                            stale_connection_id,
                            proto::ShutdownDevServer {
                                reason: Some(
                                    "another dev server connected with the same token".to_string(),
                                ),
                            },
                        )?;
                        pool.remove_connection(stale_connection_id)?;
                    };
                    pool.add_dev_server(connection_id, dev_server.id, zed_version);
                }

                let projects = self
                    .app_state
                    .db
                    .get_projects_for_dev_server(dev_server.id)
                    .await?;
                self.peer
                    .send(connection_id, proto::DevServerInstructions { projects })?;

                let status = self
                    .app_state
                    .db
                    .dev_server_projects_update(dev_server.user_id)
                    .await?;
                send_dev_server_projects_update(dev_server.user_id, status, &session).await;
            }
        }

        Ok(())
    }
1167
1168 pub async fn invite_code_redeemed(
1169 self: &Arc<Self>,
1170 inviter_id: UserId,
1171 invitee_id: UserId,
1172 ) -> Result<()> {
1173 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1174 if let Some(code) = &user.invite_code {
1175 let pool = self.connection_pool.lock();
1176 let invitee_contact = contact_for_user(invitee_id, false, &pool);
1177 for connection_id in pool.user_connection_ids(inviter_id) {
1178 self.peer.send(
1179 connection_id,
1180 proto::UpdateContacts {
1181 contacts: vec![invitee_contact.clone()],
1182 ..Default::default()
1183 },
1184 )?;
1185 self.peer.send(
1186 connection_id,
1187 proto::UpdateInviteInfo {
1188 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1189 count: user.invite_count as u32,
1190 },
1191 )?;
1192 }
1193 }
1194 }
1195 Ok(())
1196 }
1197
1198 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1199 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1200 if let Some(invite_code) = &user.invite_code {
1201 let pool = self.connection_pool.lock();
1202 for connection_id in pool.user_connection_ids(user_id) {
1203 self.peer.send(
1204 connection_id,
1205 proto::UpdateInviteInfo {
1206 url: format!(
1207 "{}{}",
1208 self.app_state.config.invite_link_prefix, invite_code
1209 ),
1210 count: user.invite_count as u32,
1211 },
1212 )?;
1213 }
1214 }
1215 }
1216 Ok(())
1217 }
1218
1219 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1220 ServerSnapshot {
1221 connection_pool: ConnectionPoolGuard {
1222 guard: self.connection_pool.lock(),
1223 _not_send: PhantomData,
1224 },
1225 peer: &self.peer,
1226 }
1227 }
1228}
1229
1230impl<'a> Deref for ConnectionPoolGuard<'a> {
1231 type Target = ConnectionPool;
1232
1233 fn deref(&self) -> &Self::Target {
1234 &self.guard
1235 }
1236}
1237
1238impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1239 fn deref_mut(&mut self) -> &mut Self::Target {
1240 &mut self.guard
1241 }
1242}
1243
/// Runs the pool's invariant checks when the guard is released (test builds only).
impl<'a> Drop for ConnectionPoolGuard<'a> {
    fn drop(&mut self) {
        // `check_invariants` is only compiled/called under `cfg(test)`; outside of
        // tests dropping the guard does nothing beyond dropping its fields.
        #[cfg(test)]
        self.check_invariants();
    }
}
1250
1251fn broadcast<F>(
1252 sender_id: Option<ConnectionId>,
1253 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1254 mut f: F,
1255) where
1256 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1257{
1258 for receiver_id in receiver_ids {
1259 if Some(receiver_id) != sender_id {
1260 if let Err(error) = f(receiver_id) {
1261 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1262 }
1263 }
1264 }
1265}
1266
1267pub struct ProtocolVersion(u32);
1268
1269impl Header for ProtocolVersion {
1270 fn name() -> &'static HeaderName {
1271 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1272 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1273 }
1274
1275 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1276 where
1277 Self: Sized,
1278 I: Iterator<Item = &'i axum::http::HeaderValue>,
1279 {
1280 let version = values
1281 .next()
1282 .ok_or_else(axum::headers::Error::invalid)?
1283 .to_str()
1284 .map_err(|_| axum::headers::Error::invalid())?
1285 .parse()
1286 .map_err(|_| axum::headers::Error::invalid())?;
1287 Ok(Self(version))
1288 }
1289
1290 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1291 values.extend([self.0.to_string().parse().unwrap()]);
1292 }
1293}
1294
1295pub struct AppVersionHeader(SemanticVersion);
1296impl Header for AppVersionHeader {
1297 fn name() -> &'static HeaderName {
1298 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1299 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1300 }
1301
1302 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1303 where
1304 Self: Sized,
1305 I: Iterator<Item = &'i axum::http::HeaderValue>,
1306 {
1307 let version = values
1308 .next()
1309 .ok_or_else(axum::headers::Error::invalid)?
1310 .to_str()
1311 .map_err(|_| axum::headers::Error::invalid())?
1312 .parse()
1313 .map_err(|_| axum::headers::Error::invalid())?;
1314 Ok(Self(version))
1315 }
1316
1317 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1318 values.extend([self.0.to_string().parse().unwrap()]);
1319 }
1320}
1321
1322pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1323 Router::new()
1324 .route("/rpc", get(handle_websocket_request))
1325 .layer(
1326 ServiceBuilder::new()
1327 .layer(Extension(server.app_state.clone()))
1328 .layer(middleware::from_fn(auth::validate_header)),
1329 )
1330 .route("/metrics", get(handle_metrics))
1331 .layer(Extension(server))
1332}
1333
1334pub async fn handle_websocket_request(
1335 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1336 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1337 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1338 Extension(server): Extension<Arc<Server>>,
1339 Extension(principal): Extension<Principal>,
1340 country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1341 ws: WebSocketUpgrade,
1342) -> axum::response::Response {
1343 if protocol_version != rpc::PROTOCOL_VERSION {
1344 return (
1345 StatusCode::UPGRADE_REQUIRED,
1346 "client must be upgraded".to_string(),
1347 )
1348 .into_response();
1349 }
1350
1351 let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1352 return (
1353 StatusCode::UPGRADE_REQUIRED,
1354 "no version header found".to_string(),
1355 )
1356 .into_response();
1357 };
1358
1359 if !version.can_collaborate() {
1360 return (
1361 StatusCode::UPGRADE_REQUIRED,
1362 "client must be upgraded".to_string(),
1363 )
1364 .into_response();
1365 }
1366
1367 let socket_address = socket_address.to_string();
1368 ws.on_upgrade(move |socket| {
1369 let socket = socket
1370 .map_ok(to_tungstenite_message)
1371 .err_into()
1372 .with(|message| async move { to_axum_message(message) });
1373 let connection = Connection::new(Box::pin(socket));
1374 async move {
1375 server
1376 .handle_connection(
1377 connection,
1378 socket_address,
1379 principal,
1380 version,
1381 country_code_header.map(|header| header.to_string()),
1382 None,
1383 Executor::Production,
1384 )
1385 .await;
1386 }
1387 })
1388}
1389
1390pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1391 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1392 let connections_metric = CONNECTIONS_METRIC
1393 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1394
1395 let connections = server
1396 .connection_pool
1397 .lock()
1398 .connections()
1399 .filter(|connection| !connection.admin)
1400 .count();
1401 connections_metric.set(connections as _);
1402
1403 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1404 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1405 register_int_gauge!(
1406 "shared_projects",
1407 "number of open projects with one or more guests"
1408 )
1409 .unwrap()
1410 });
1411
1412 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1413 shared_projects_metric.set(shared_projects as _);
1414
1415 let encoder = prometheus::TextEncoder::new();
1416 let metric_families = prometheus::gather();
1417 let encoded_metrics = encoder
1418 .encode_to_string(&metric_families)
1419 .map_err(|err| anyhow!("{}", err))?;
1420 Ok(encoded_metrics)
1421}
1422
/// Cleans up after a connection has gone away.
///
/// These steps run immediately:
/// - disconnect the peer;
/// - remove the connection from the pool;
/// - report the lost connection to the database (a failure there is only logged).
///
/// It then waits for `RECONNECT_TIMEOUT`. If `teardown` changes first, it returns
/// without further cleanup. Otherwise the principal's remaining resources are
/// released:
///
/// - Users:
///   - Leave the room and channel buffers for this connection.
///   - If the user has no other online connection, call `decline_call(None, ..)`
///     and broadcast the returned room, if any.
///   - Update the user's contacts.
/// - Dev servers: delegate to `lost_dev_server_connection`.
///
/// # Errors
///
/// Returns an error if removing the connection from the pool, updating the user's
/// contacts, or `lost_dev_server_connection` fails.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // Best effort: a database error is logged and the cleanup continues.
    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // Biased select: the reconnect timeout branch is polled first. A change on
    // `teardown` ends the wait without performing the cleanup below.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            match &session.principal {
                Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
                    // NOTE(review): assumes `for_user` always succeeds for a user
                    // principal (matched just above) — confirm.
                    let session = session.for_user().unwrap();

                    log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
                    leave_room_for_session(&session, session.connection_id).await.trace_err();
                    leave_channel_buffers_for_session(&session)
                        .await
                        .trace_err();

                    // Pending calls are only declined once none of the user's
                    // connections are online any more.
                    if !session
                        .connection_pool()
                        .await
                        .is_user_online(session.user_id())
                    {
                        let db = session.db().await;
                        if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                            room_updated(&room, &session.peer);
                        }
                    }

                    update_user_contacts(session.user_id(), &session).await?;
                },
                Principal::DevServer(_) => {
                    lost_dev_server_connection(&session.for_dev_server().unwrap()).await?;
                },
            }
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1477
1478/// Acknowledges a ping from a client, used to keep the connection alive.
1479async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1480 response.send(proto::Ack {})?;
1481 Ok(())
1482}
1483
1484/// Creates a new room for calling (outside of channels)
1485async fn create_room(
1486 _request: proto::CreateRoom,
1487 response: Response<proto::CreateRoom>,
1488 session: UserSession,
1489) -> Result<()> {
1490 let live_kit_room = nanoid::nanoid!(30);
1491
1492 let live_kit_connection_info = util::maybe!(async {
1493 let live_kit = session.app_state.live_kit_client.as_ref();
1494 let live_kit = live_kit?;
1495 let user_id = session.user_id().to_string();
1496
1497 let token = live_kit
1498 .room_token(&live_kit_room, &user_id.to_string())
1499 .trace_err()?;
1500
1501 Some(proto::LiveKitConnectionInfo {
1502 server_url: live_kit.url().into(),
1503 token,
1504 can_publish: true,
1505 })
1506 })
1507 .await;
1508
1509 let room = session
1510 .db()
1511 .await
1512 .create_room(session.user_id(), session.connection_id, &live_kit_room)
1513 .await?;
1514
1515 response.send(proto::CreateRoomResponse {
1516 room: Some(room.clone()),
1517 live_kit_connection_info,
1518 })?;
1519
1520 update_user_contacts(session.user_id(), &session).await?;
1521 Ok(())
1522}
1523
1524/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1525async fn join_room(
1526 request: proto::JoinRoom,
1527 response: Response<proto::JoinRoom>,
1528 session: UserSession,
1529) -> Result<()> {
1530 let room_id = RoomId::from_proto(request.id);
1531
1532 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1533
1534 if let Some(channel_id) = channel_id {
1535 return join_channel_internal(channel_id, Box::new(response), session).await;
1536 }
1537
1538 let joined_room = {
1539 let room = session
1540 .db()
1541 .await
1542 .join_room(room_id, session.user_id(), session.connection_id)
1543 .await?;
1544 room_updated(&room.room, &session.peer);
1545 room.into_inner()
1546 };
1547
1548 for connection_id in session
1549 .connection_pool()
1550 .await
1551 .user_connection_ids(session.user_id())
1552 {
1553 session
1554 .peer
1555 .send(
1556 connection_id,
1557 proto::CallCanceled {
1558 room_id: room_id.to_proto(),
1559 },
1560 )
1561 .trace_err();
1562 }
1563
1564 let live_kit_connection_info =
1565 if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
1566 if let Some(token) = live_kit
1567 .room_token(
1568 &joined_room.room.live_kit_room,
1569 &session.user_id().to_string(),
1570 )
1571 .trace_err()
1572 {
1573 Some(proto::LiveKitConnectionInfo {
1574 server_url: live_kit.url().into(),
1575 token,
1576 can_publish: true,
1577 })
1578 } else {
1579 None
1580 }
1581 } else {
1582 None
1583 };
1584
1585 response.send(proto::JoinRoomResponse {
1586 room: Some(joined_room.room),
1587 channel_id: None,
1588 live_kit_connection_info,
1589 })?;
1590
1591 update_user_contacts(session.user_id(), &session).await?;
1592 Ok(())
1593}
1594
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// The database call returns the room, its reshared projects and its rejoined
/// projects. The function then:
/// - responds with the room and the state of both project lists;
/// - broadcasts the room update;
/// - for each reshared project, tells every collaborator that this connection's
///   peer id changed, and forwards the project's worktrees to the collaborators
///   other than this connection;
/// - streams the rejoined projects' state to this client via
///   `notify_rejoined_projects`.
///
/// After the database guard is released, the channel is updated (if the room
/// belongs to one) and the user's contacts are updated.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: UserSession,
) -> Result<()> {
    // Assigned inside the block below so the guard returned by `rejoin_room` is
    // dropped before the connection pool is locked for `channel_updated`.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Collaborators learn that this connection's peer id changed; send
            // failures are only logged.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the project's worktrees to every collaborator except this one.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1686
1687fn notify_rejoined_projects(
1688 rejoined_projects: &mut Vec<RejoinedProject>,
1689 session: &UserSession,
1690) -> Result<()> {
1691 for project in rejoined_projects.iter() {
1692 for collaborator in &project.collaborators {
1693 session
1694 .peer
1695 .send(
1696 collaborator.connection_id,
1697 proto::UpdateProjectCollaborator {
1698 project_id: project.id.to_proto(),
1699 old_peer_id: Some(project.old_connection_id.into()),
1700 new_peer_id: Some(session.connection_id.into()),
1701 },
1702 )
1703 .trace_err();
1704 }
1705 }
1706
1707 for project in rejoined_projects {
1708 for worktree in mem::take(&mut project.worktrees) {
1709 #[cfg(any(test, feature = "test-support"))]
1710 const MAX_CHUNK_SIZE: usize = 2;
1711 #[cfg(not(any(test, feature = "test-support")))]
1712 const MAX_CHUNK_SIZE: usize = 256;
1713
1714 // Stream this worktree's entries.
1715 let message = proto::UpdateWorktree {
1716 project_id: project.id.to_proto(),
1717 worktree_id: worktree.id,
1718 abs_path: worktree.abs_path.clone(),
1719 root_name: worktree.root_name,
1720 updated_entries: worktree.updated_entries,
1721 removed_entries: worktree.removed_entries,
1722 scan_id: worktree.scan_id,
1723 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1724 updated_repositories: worktree.updated_repositories,
1725 removed_repositories: worktree.removed_repositories,
1726 };
1727 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1728 session.peer.send(session.connection_id, update.clone())?;
1729 }
1730
1731 // Stream this worktree's diagnostics.
1732 for summary in worktree.diagnostic_summaries {
1733 session.peer.send(
1734 session.connection_id,
1735 proto::UpdateDiagnosticSummary {
1736 project_id: project.id.to_proto(),
1737 worktree_id: worktree.id,
1738 summary: Some(summary),
1739 },
1740 )?;
1741 }
1742
1743 for settings_file in worktree.settings_files {
1744 session.peer.send(
1745 session.connection_id,
1746 proto::UpdateWorktreeSettings {
1747 project_id: project.id.to_proto(),
1748 worktree_id: worktree.id,
1749 path: settings_file.path,
1750 content: Some(settings_file.content),
1751 },
1752 )?;
1753 }
1754 }
1755
1756 for language_server in &project.language_servers {
1757 session.peer.send(
1758 session.connection_id,
1759 proto::UpdateLanguageServer {
1760 project_id: project.id.to_proto(),
1761 language_server_id: language_server.id,
1762 variant: Some(
1763 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1764 proto::LspDiskBasedDiagnosticsUpdated {},
1765 ),
1766 ),
1767 },
1768 )?;
1769 }
1770 }
1771 Ok(())
1772}
1773
1774/// leave room disconnects from the room.
1775async fn leave_room(
1776 _: proto::LeaveRoom,
1777 response: Response<proto::LeaveRoom>,
1778 session: UserSession,
1779) -> Result<()> {
1780 leave_room_for_session(&session, session.connection_id).await?;
1781 response.send(proto::Ack {})?;
1782 Ok(())
1783}
1784
1785/// Updates the permissions of someone else in the room.
1786async fn set_room_participant_role(
1787 request: proto::SetRoomParticipantRole,
1788 response: Response<proto::SetRoomParticipantRole>,
1789 session: UserSession,
1790) -> Result<()> {
1791 let user_id = UserId::from_proto(request.user_id);
1792 let role = ChannelRole::from(request.role());
1793
1794 let (live_kit_room, can_publish) = {
1795 let room = session
1796 .db()
1797 .await
1798 .set_room_participant_role(
1799 session.user_id(),
1800 RoomId::from_proto(request.room_id),
1801 user_id,
1802 role,
1803 )
1804 .await?;
1805
1806 let live_kit_room = room.live_kit_room.clone();
1807 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1808 room_updated(&room, &session.peer);
1809 (live_kit_room, can_publish)
1810 };
1811
1812 if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
1813 live_kit
1814 .update_participant(
1815 live_kit_room.clone(),
1816 request.user_id.to_string(),
1817 live_kit_server::proto::ParticipantPermission {
1818 can_subscribe: true,
1819 can_publish,
1820 can_publish_data: can_publish,
1821 hidden: false,
1822 recorder: false,
1823 },
1824 )
1825 .await
1826 .trace_err();
1827 }
1828
1829 response.send(proto::Ack {})?;
1830 Ok(())
1831}
1832
/// Call someone else into the current room.
///
/// Only contacts of the caller can be called. The call is recorded in the
/// database, the room update is broadcast, and the called user's contacts are
/// updated. The incoming call is then sent as a request to every connection of the
/// called user; the first successful response acknowledges the call.
///
/// If no connection answers successfully:
/// - the call is marked as failed in the database;
/// - the room is broadcast again and the contacts are updated;
/// - an error is returned.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: UserSession,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    // The room guard is confined to this block; the incoming call is taken out of
    // it before the guard is dropped.
    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(&room, &session.peer);
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring all of the called user's connections concurrently.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                // Log this connection's failure and keep waiting for the others.
                call_response.trace_err();
            }
        }
    }

    // No connection accepted the request.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1901
1902/// Cancel an outgoing call.
1903async fn cancel_call(
1904 request: proto::CancelCall,
1905 response: Response<proto::CancelCall>,
1906 session: UserSession,
1907) -> Result<()> {
1908 let called_user_id = UserId::from_proto(request.called_user_id);
1909 let room_id = RoomId::from_proto(request.room_id);
1910 {
1911 let room = session
1912 .db()
1913 .await
1914 .cancel_call(room_id, session.connection_id, called_user_id)
1915 .await?;
1916 room_updated(&room, &session.peer);
1917 }
1918
1919 for connection_id in session
1920 .connection_pool()
1921 .await
1922 .user_connection_ids(called_user_id)
1923 {
1924 session
1925 .peer
1926 .send(
1927 connection_id,
1928 proto::CallCanceled {
1929 room_id: room_id.to_proto(),
1930 },
1931 )
1932 .trace_err();
1933 }
1934 response.send(proto::Ack {})?;
1935
1936 update_user_contacts(called_user_id, &session).await?;
1937 Ok(())
1938}
1939
1940/// Decline an incoming call.
1941async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1942 let room_id = RoomId::from_proto(message.room_id);
1943 {
1944 let room = session
1945 .db()
1946 .await
1947 .decline_call(Some(room_id), session.user_id())
1948 .await?
1949 .ok_or_else(|| anyhow!("failed to decline call"))?;
1950 room_updated(&room, &session.peer);
1951 }
1952
1953 for connection_id in session
1954 .connection_pool()
1955 .await
1956 .user_connection_ids(session.user_id())
1957 {
1958 session
1959 .peer
1960 .send(
1961 connection_id,
1962 proto::CallCanceled {
1963 room_id: room_id.to_proto(),
1964 },
1965 )
1966 .trace_err();
1967 }
1968 update_user_contacts(session.user_id(), &session).await?;
1969 Ok(())
1970}
1971
1972/// Updates other participants in the room with your current location.
1973async fn update_participant_location(
1974 request: proto::UpdateParticipantLocation,
1975 response: Response<proto::UpdateParticipantLocation>,
1976 session: UserSession,
1977) -> Result<()> {
1978 let room_id = RoomId::from_proto(request.room_id);
1979 let location = request
1980 .location
1981 .ok_or_else(|| anyhow!("invalid location"))?;
1982
1983 let db = session.db().await;
1984 let room = db
1985 .update_room_participant_location(room_id, session.connection_id, location)
1986 .await?;
1987
1988 room_updated(&room, &session.peer);
1989 response.send(proto::Ack {})?;
1990 Ok(())
1991}
1992
1993/// Share a project into the room.
1994async fn share_project(
1995 request: proto::ShareProject,
1996 response: Response<proto::ShareProject>,
1997 session: UserSession,
1998) -> Result<()> {
1999 let (project_id, room) = &*session
2000 .db()
2001 .await
2002 .share_project(
2003 RoomId::from_proto(request.room_id),
2004 session.connection_id,
2005 &request.worktrees,
2006 request
2007 .dev_server_project_id
2008 .map(|id| DevServerProjectId::from_proto(id)),
2009 )
2010 .await?;
2011 response.send(proto::ShareProjectResponse {
2012 project_id: project_id.to_proto(),
2013 })?;
2014 room_updated(&room, &session.peer);
2015
2016 Ok(())
2017}
2018
2019/// Unshare a project from the room.
2020async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
2021 let project_id = ProjectId::from_proto(message.project_id);
2022 unshare_project_internal(
2023 project_id,
2024 session.connection_id,
2025 session.user_id(),
2026 &session,
2027 )
2028 .await
2029}
2030
2031async fn unshare_project_internal(
2032 project_id: ProjectId,
2033 connection_id: ConnectionId,
2034 user_id: Option<UserId>,
2035 session: &Session,
2036) -> Result<()> {
2037 let delete = {
2038 let room_guard = session
2039 .db()
2040 .await
2041 .unshare_project(project_id, connection_id, user_id)
2042 .await?;
2043
2044 let (delete, room, guest_connection_ids) = &*room_guard;
2045
2046 let message = proto::UnshareProject {
2047 project_id: project_id.to_proto(),
2048 };
2049
2050 broadcast(
2051 Some(connection_id),
2052 guest_connection_ids.iter().copied(),
2053 |conn_id| session.peer.send(conn_id, message.clone()),
2054 );
2055 if let Some(room) = room {
2056 room_updated(room, &session.peer);
2057 }
2058
2059 *delete
2060 };
2061
2062 if delete {
2063 let db = session.db().await;
2064 db.delete_project(project_id).await?;
2065 }
2066
2067 Ok(())
2068}
2069
2070/// DevServer makes a project available online
2071async fn share_dev_server_project(
2072 request: proto::ShareDevServerProject,
2073 response: Response<proto::ShareDevServerProject>,
2074 session: DevServerSession,
2075) -> Result<()> {
2076 let (dev_server_project, user_id, status) = session
2077 .db()
2078 .await
2079 .share_dev_server_project(
2080 DevServerProjectId::from_proto(request.dev_server_project_id),
2081 session.dev_server_id(),
2082 session.connection_id,
2083 &request.worktrees,
2084 )
2085 .await?;
2086 let Some(project_id) = dev_server_project.project_id else {
2087 return Err(anyhow!("failed to share remote project"))?;
2088 };
2089
2090 send_dev_server_projects_update(user_id, status, &session).await;
2091
2092 response.send(proto::ShareProjectResponse { project_id })?;
2093
2094 Ok(())
2095}
2096
/// Join someone else's shared project.
///
/// Adds this connection to the project in the database, releases the session's
/// database handle, and then sends the project state to the client via
/// `join_project_internal`.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: UserSession,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    // `project` and `replica_id` borrow from the value returned by `join_project`,
    // which is kept alive by this binding; only the session's db handle is dropped.
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    drop(db);
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
2115
/// Response handles for requests answered with a `proto::JoinProjectResponse`.
///
/// This lets `join_project_internal` reply to both `JoinProject` and
/// `JoinHostedProject` requests.
trait JoinProjectInternalResponse {
    /// Sends `result` as the reply to the originating request.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
// Delegates to `Response`'s own `send` for `JoinProject` requests.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinProject>::send(self, result)
    }
}
// Delegates to `Response`'s own `send` for `JoinHostedProject` requests.
impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinHostedProject>::send(self, result)
    }
}
2129
2130fn join_project_internal(
2131 response: impl JoinProjectInternalResponse,
2132 session: UserSession,
2133 project: &mut Project,
2134 replica_id: &ReplicaId,
2135) -> Result<()> {
2136 let collaborators = project
2137 .collaborators
2138 .iter()
2139 .filter(|collaborator| collaborator.connection_id != session.connection_id)
2140 .map(|collaborator| collaborator.to_proto())
2141 .collect::<Vec<_>>();
2142 let project_id = project.id;
2143 let guest_user_id = session.user_id();
2144
2145 let worktrees = project
2146 .worktrees
2147 .iter()
2148 .map(|(id, worktree)| proto::WorktreeMetadata {
2149 id: *id,
2150 root_name: worktree.root_name.clone(),
2151 visible: worktree.visible,
2152 abs_path: worktree.abs_path.clone(),
2153 })
2154 .collect::<Vec<_>>();
2155
2156 let add_project_collaborator = proto::AddProjectCollaborator {
2157 project_id: project_id.to_proto(),
2158 collaborator: Some(proto::Collaborator {
2159 peer_id: Some(session.connection_id.into()),
2160 replica_id: replica_id.0 as u32,
2161 user_id: guest_user_id.to_proto(),
2162 }),
2163 };
2164
2165 for collaborator in &collaborators {
2166 session
2167 .peer
2168 .send(
2169 collaborator.peer_id.unwrap().into(),
2170 add_project_collaborator.clone(),
2171 )
2172 .trace_err();
2173 }
2174
2175 // First, we send the metadata associated with each worktree.
2176 response.send(proto::JoinProjectResponse {
2177 project_id: project.id.0 as u64,
2178 worktrees: worktrees.clone(),
2179 replica_id: replica_id.0 as u32,
2180 collaborators: collaborators.clone(),
2181 language_servers: project.language_servers.clone(),
2182 role: project.role.into(),
2183 dev_server_project_id: project
2184 .dev_server_project_id
2185 .map(|dev_server_project_id| dev_server_project_id.0 as u64),
2186 })?;
2187
2188 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2189 #[cfg(any(test, feature = "test-support"))]
2190 const MAX_CHUNK_SIZE: usize = 2;
2191 #[cfg(not(any(test, feature = "test-support")))]
2192 const MAX_CHUNK_SIZE: usize = 256;
2193
2194 // Stream this worktree's entries.
2195 let message = proto::UpdateWorktree {
2196 project_id: project_id.to_proto(),
2197 worktree_id,
2198 abs_path: worktree.abs_path.clone(),
2199 root_name: worktree.root_name,
2200 updated_entries: worktree.entries,
2201 removed_entries: Default::default(),
2202 scan_id: worktree.scan_id,
2203 is_last_update: worktree.scan_id == worktree.completed_scan_id,
2204 updated_repositories: worktree.repository_entries.into_values().collect(),
2205 removed_repositories: Default::default(),
2206 };
2207 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
2208 session.peer.send(session.connection_id, update.clone())?;
2209 }
2210
2211 // Stream this worktree's diagnostics.
2212 for summary in worktree.diagnostic_summaries {
2213 session.peer.send(
2214 session.connection_id,
2215 proto::UpdateDiagnosticSummary {
2216 project_id: project_id.to_proto(),
2217 worktree_id: worktree.id,
2218 summary: Some(summary),
2219 },
2220 )?;
2221 }
2222
2223 for settings_file in worktree.settings_files {
2224 session.peer.send(
2225 session.connection_id,
2226 proto::UpdateWorktreeSettings {
2227 project_id: project_id.to_proto(),
2228 worktree_id: worktree.id,
2229 path: settings_file.path,
2230 content: Some(settings_file.content),
2231 },
2232 )?;
2233 }
2234 }
2235
2236 for language_server in &project.language_servers {
2237 session.peer.send(
2238 session.connection_id,
2239 proto::UpdateLanguageServer {
2240 project_id: project_id.to_proto(),
2241 language_server_id: language_server.id,
2242 variant: Some(
2243 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2244 proto::LspDiskBasedDiagnosticsUpdated {},
2245 ),
2246 ),
2247 },
2248 )?;
2249 }
2250
2251 Ok(())
2252}
2253
2254/// Leave someone elses shared project.
2255async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2256 let sender_id = session.connection_id;
2257 let project_id = ProjectId::from_proto(request.project_id);
2258 let db = session.db().await;
2259 if db.is_hosted_project(project_id).await? {
2260 let project = db.leave_hosted_project(project_id, sender_id).await?;
2261 project_left(&project, &session);
2262 return Ok(());
2263 }
2264
2265 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2266 tracing::info!(
2267 %project_id,
2268 "leave project"
2269 );
2270
2271 project_left(&project, &session);
2272 if let Some(room) = room {
2273 room_updated(&room, &session.peer);
2274 }
2275
2276 Ok(())
2277}
2278
2279async fn join_hosted_project(
2280 request: proto::JoinHostedProject,
2281 response: Response<proto::JoinHostedProject>,
2282 session: UserSession,
2283) -> Result<()> {
2284 let (mut project, replica_id) = session
2285 .db()
2286 .await
2287 .join_hosted_project(
2288 ProjectId(request.project_id as i32),
2289 session.user_id(),
2290 session.connection_id,
2291 )
2292 .await?;
2293
2294 join_project_internal(response, session, &mut project, &replica_id)
2295}
2296
2297async fn list_remote_directory(
2298 request: proto::ListRemoteDirectory,
2299 response: Response<proto::ListRemoteDirectory>,
2300 session: UserSession,
2301) -> Result<()> {
2302 let dev_server_id = DevServerId(request.dev_server_id as i32);
2303 let dev_server_connection_id = session
2304 .connection_pool()
2305 .await
2306 .dev_server_connection_id_supporting(dev_server_id, ZedVersion::with_list_directory())?;
2307
2308 session
2309 .db()
2310 .await
2311 .get_dev_server_for_user(dev_server_id, session.user_id())
2312 .await?;
2313
2314 response.send(
2315 session
2316 .peer
2317 .forward_request(session.connection_id, dev_server_connection_id, request)
2318 .await?,
2319 )?;
2320 Ok(())
2321}
2322
2323async fn update_dev_server_project(
2324 request: proto::UpdateDevServerProject,
2325 response: Response<proto::UpdateDevServerProject>,
2326 session: UserSession,
2327) -> Result<()> {
2328 let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2329
2330 let (dev_server_project, update) = session
2331 .db()
2332 .await
2333 .update_dev_server_project(dev_server_project_id, &request.paths, session.user_id())
2334 .await?;
2335
2336 let projects = session
2337 .db()
2338 .await
2339 .get_projects_for_dev_server(dev_server_project.dev_server_id)
2340 .await?;
2341
2342 let dev_server_connection_id = session
2343 .connection_pool()
2344 .await
2345 .dev_server_connection_id_supporting(
2346 dev_server_project.dev_server_id,
2347 ZedVersion::with_list_directory(),
2348 )?;
2349
2350 session.peer.send(
2351 dev_server_connection_id,
2352 proto::DevServerInstructions { projects },
2353 )?;
2354
2355 send_dev_server_projects_update(session.user_id(), update, &session).await;
2356
2357 response.send(proto::Ack {})
2358}
2359
2360async fn create_dev_server_project(
2361 request: proto::CreateDevServerProject,
2362 response: Response<proto::CreateDevServerProject>,
2363 session: UserSession,
2364) -> Result<()> {
2365 let dev_server_id = DevServerId(request.dev_server_id as i32);
2366 let dev_server_connection_id = session
2367 .connection_pool()
2368 .await
2369 .dev_server_connection_id(dev_server_id);
2370 let Some(dev_server_connection_id) = dev_server_connection_id else {
2371 Err(ErrorCode::DevServerOffline
2372 .message("Cannot create a remote project when the dev server is offline".to_string())
2373 .anyhow())?
2374 };
2375
2376 let path = request.path.clone();
2377 //Check that the path exists on the dev server
2378 session
2379 .peer
2380 .forward_request(
2381 session.connection_id,
2382 dev_server_connection_id,
2383 proto::ValidateDevServerProjectRequest { path: path.clone() },
2384 )
2385 .await?;
2386
2387 let (dev_server_project, update) = session
2388 .db()
2389 .await
2390 .create_dev_server_project(
2391 DevServerId(request.dev_server_id as i32),
2392 &request.path,
2393 session.user_id(),
2394 )
2395 .await?;
2396
2397 let projects = session
2398 .db()
2399 .await
2400 .get_projects_for_dev_server(dev_server_project.dev_server_id)
2401 .await?;
2402
2403 session.peer.send(
2404 dev_server_connection_id,
2405 proto::DevServerInstructions { projects },
2406 )?;
2407
2408 send_dev_server_projects_update(session.user_id(), update, &session).await;
2409
2410 response.send(proto::CreateDevServerProjectResponse {
2411 dev_server_project: Some(dev_server_project.to_proto(None)),
2412 })?;
2413 Ok(())
2414}
2415
2416async fn create_dev_server(
2417 request: proto::CreateDevServer,
2418 response: Response<proto::CreateDevServer>,
2419 session: UserSession,
2420) -> Result<()> {
2421 let access_token = auth::random_token();
2422 let hashed_access_token = auth::hash_access_token(&access_token);
2423
2424 if request.name.is_empty() {
2425 return Err(proto::ErrorCode::Forbidden
2426 .message("Dev server name cannot be empty".to_string())
2427 .anyhow())?;
2428 }
2429
2430 let (dev_server, status) = session
2431 .db()
2432 .await
2433 .create_dev_server(
2434 &request.name,
2435 request.ssh_connection_string.as_deref(),
2436 &hashed_access_token,
2437 session.user_id(),
2438 )
2439 .await?;
2440
2441 send_dev_server_projects_update(session.user_id(), status, &session).await;
2442
2443 response.send(proto::CreateDevServerResponse {
2444 dev_server_id: dev_server.id.0 as u64,
2445 access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2446 name: request.name,
2447 })?;
2448 Ok(())
2449}
2450
2451async fn regenerate_dev_server_token(
2452 request: proto::RegenerateDevServerToken,
2453 response: Response<proto::RegenerateDevServerToken>,
2454 session: UserSession,
2455) -> Result<()> {
2456 let dev_server_id = DevServerId(request.dev_server_id as i32);
2457 let access_token = auth::random_token();
2458 let hashed_access_token = auth::hash_access_token(&access_token);
2459
2460 let connection_id = session
2461 .connection_pool()
2462 .await
2463 .dev_server_connection_id(dev_server_id);
2464 if let Some(connection_id) = connection_id {
2465 shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2466 session.peer.send(
2467 connection_id,
2468 proto::ShutdownDevServer {
2469 reason: Some("dev server token was regenerated".to_string()),
2470 },
2471 )?;
2472 let _ = remove_dev_server_connection(dev_server_id, &session).await;
2473 }
2474
2475 let status = session
2476 .db()
2477 .await
2478 .update_dev_server_token(dev_server_id, &hashed_access_token, session.user_id())
2479 .await?;
2480
2481 send_dev_server_projects_update(session.user_id(), status, &session).await;
2482
2483 response.send(proto::RegenerateDevServerTokenResponse {
2484 dev_server_id: dev_server_id.to_proto(),
2485 access_token: auth::generate_dev_server_token(dev_server_id.0 as usize, access_token),
2486 })?;
2487 Ok(())
2488}
2489
2490async fn rename_dev_server(
2491 request: proto::RenameDevServer,
2492 response: Response<proto::RenameDevServer>,
2493 session: UserSession,
2494) -> Result<()> {
2495 if request.name.trim().is_empty() {
2496 return Err(proto::ErrorCode::Forbidden
2497 .message("Dev server name cannot be empty".to_string())
2498 .anyhow())?;
2499 }
2500
2501 let dev_server_id = DevServerId(request.dev_server_id as i32);
2502 let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2503 if dev_server.user_id != session.user_id() {
2504 return Err(anyhow!(ErrorCode::Forbidden))?;
2505 }
2506
2507 let status = session
2508 .db()
2509 .await
2510 .rename_dev_server(
2511 dev_server_id,
2512 &request.name,
2513 request.ssh_connection_string.as_deref(),
2514 session.user_id(),
2515 )
2516 .await?;
2517
2518 send_dev_server_projects_update(session.user_id(), status, &session).await;
2519
2520 response.send(proto::Ack {})?;
2521 Ok(())
2522}
2523
/// Deletes one of the caller's dev servers.
///
/// If the dev server is currently connected, `shutdown_dev_server_internal` runs
/// first (unsharing its projects and marking it offline). The dev server is then
/// sent a `ShutdownDevServer` message and its connection is removed from the pool.
/// All of this happens before `delete_dev_server` is called on the database.
/// Finally, the caller receives the refreshed dev-server-project status.
///
/// # Errors
/// Returns `ErrorCode::Forbidden` if the dev server belongs to another user.
async fn delete_dev_server(
    request: proto::DeleteDevServer,
    response: Response<proto::DeleteDevServer>,
    session: UserSession,
) -> Result<()> {
    let dev_server_id = DevServerId(request.dev_server_id as i32);
    let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
    // Authorize before touching the dev server's connection.
    if dev_server.user_id != session.user_id() {
        return Err(anyhow!(ErrorCode::Forbidden))?;
    }

    let connection_id = session
        .connection_pool()
        .await
        .dev_server_connection_id(dev_server_id);
    if let Some(connection_id) = connection_id {
        shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
        session.peer.send(
            connection_id,
            proto::ShutdownDevServer {
                reason: Some("dev server was deleted".to_string()),
            },
        )?;
        // Errors from removing the connection are deliberately ignored.
        let _ = remove_dev_server_connection(dev_server_id, &session).await;
    }

    let status = session
        .db()
        .await
        .delete_dev_server(dev_server_id, session.user_id())
        .await?;

    send_dev_server_projects_update(session.user_id(), status, &session).await;

    response.send(proto::Ack {})?;
    Ok(())
}
2561
/// Deletes one of the caller's dev server projects.
///
/// If the owning dev server is connected and `find_dev_server_project` returns a
/// project for it, that project is unshared first. After the database deletion, a
/// connected dev server receives its remaining project list. The caller then
/// receives the refreshed dev-server-project status.
///
/// # Errors
/// Returns `ErrorCode::Forbidden` if the dev server belongs to another user.
async fn delete_dev_server_project(
    request: proto::DeleteDevServerProject,
    response: Response<proto::DeleteDevServerProject>,
    session: UserSession,
) -> Result<()> {
    let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
    let dev_server_project = session
        .db()
        .await
        .get_dev_server_project(dev_server_project_id)
        .await?;

    // Ownership is checked against the dev server hosting this project.
    let dev_server = session
        .db()
        .await
        .get_dev_server(dev_server_project.dev_server_id)
        .await?;
    if dev_server.user_id != session.user_id() {
        return Err(anyhow!(ErrorCode::Forbidden))?;
    }

    let dev_server_connection_id = session
        .connection_pool()
        .await
        .dev_server_connection_id(dev_server.id);

    if let Some(dev_server_connection_id) = dev_server_connection_id {
        let project = session
            .db()
            .await
            .find_dev_server_project(dev_server_project_id)
            .await;
        // A failed lookup is not treated as fatal; presumably it means the project
        // is not currently shared. TODO: confirm.
        if let Ok(project) = project {
            unshare_project_internal(
                project.id,
                dev_server_connection_id,
                Some(session.user_id()),
                &session,
            )
            .await?;
        }
    }

    let (projects, status) = session
        .db()
        .await
        .delete_dev_server_project(dev_server_project_id, dev_server.id, session.user_id())
        .await?;

    // Only a connected dev server can receive its updated project list.
    if let Some(dev_server_connection_id) = dev_server_connection_id {
        session.peer.send(
            dev_server_connection_id,
            proto::DevServerInstructions { projects },
        )?;
    }

    send_dev_server_projects_update(session.user_id(), status, &session).await;

    response.send(proto::Ack {})?;
    Ok(())
}
2623
2624async fn rejoin_dev_server_projects(
2625 request: proto::RejoinRemoteProjects,
2626 response: Response<proto::RejoinRemoteProjects>,
2627 session: UserSession,
2628) -> Result<()> {
2629 let mut rejoined_projects = {
2630 let db = session.db().await;
2631 db.rejoin_dev_server_projects(
2632 &request.rejoined_projects,
2633 session.user_id(),
2634 session.0.connection_id,
2635 )
2636 .await?
2637 };
2638 response.send(proto::RejoinRemoteProjectsResponse {
2639 rejoined_projects: rejoined_projects
2640 .iter()
2641 .map(|project| project.to_proto())
2642 .collect(),
2643 })?;
2644 notify_rejoined_projects(&mut rejoined_projects, &session)
2645}
2646
/// Handles a `ReconnectDevServer` request from a dev server.
///
/// The database's `reshare_dev_server_projects` associates the projects listed in
/// `request.reshared_projects` with the dev server's current connection. For each
/// reshared project, every collaborator is told that the host's peer id changed
/// and is sent the project's current worktree list. The dev server is then told
/// which projects were reshared and who their collaborators are.
async fn reconnect_dev_server(
    request: proto::ReconnectDevServer,
    response: Response<proto::ReconnectDevServer>,
    session: DevServerSession,
) -> Result<()> {
    // The database handle is scoped so it is released before any messages go out.
    let reshared_projects = {
        let db = session.db().await;
        db.reshare_dev_server_projects(
            &request.reshared_projects,
            session.dev_server_id(),
            session.0.connection_id,
        )
        .await?
    };

    for project in &reshared_projects {
        // Point each collaborator from the old host connection to the new one.
        // Failures are only logged.
        for collaborator in &project.collaborators {
            session
                .peer
                .send(
                    collaborator.connection_id,
                    proto::UpdateProjectCollaborator {
                        project_id: project.id.to_proto(),
                        old_peer_id: Some(project.old_connection_id.into()),
                        new_peer_id: Some(session.connection_id.into()),
                    },
                )
                .trace_err();
        }

        broadcast(
            Some(session.connection_id),
            project
                .collaborators
                .iter()
                .map(|collaborator| collaborator.connection_id),
            |connection_id| {
                session.peer.forward_send(
                    session.connection_id,
                    connection_id,
                    proto::UpdateProject {
                        project_id: project.id.to_proto(),
                        worktrees: project.worktrees.clone(),
                    },
                )
            },
        );
    }

    response.send(proto::ReconnectDevServerResponse {
        reshared_projects: reshared_projects
            .iter()
            .map(|project| proto::ResharedProject {
                id: project.id.to_proto(),
                collaborators: project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.to_proto())
                    .collect(),
            })
            .collect(),
    })?;

    Ok(())
}
2712
2713async fn shutdown_dev_server(
2714 _: proto::ShutdownDevServer,
2715 response: Response<proto::ShutdownDevServer>,
2716 session: DevServerSession,
2717) -> Result<()> {
2718 response.send(proto::Ack {})?;
2719 shutdown_dev_server_internal(session.dev_server_id(), session.connection_id, &session).await?;
2720 remove_dev_server_connection(session.dev_server_id(), &session).await
2721}
2722
/// Takes a dev server offline on the server side.
///
/// Steps:
/// - unshares every project of the dev server that currently has a `project_id`,
///   passing `connection_id` to `unshare_project_internal`;
/// - marks the dev server offline in the connection pool;
/// - sends the owning user the refreshed dev-server-project status.
///
/// The dev server's connection itself is not removed from the pool here.
///
/// # Errors
/// Stops at the first failing database lookup or unshare; later steps are then
/// skipped.
async fn shutdown_dev_server_internal(
    dev_server_id: DevServerId,
    connection_id: ConnectionId,
    session: &Session,
) -> Result<()> {
    // Both lookups share one database handle, released at the end of this block.
    let (dev_server_projects, dev_server) = {
        let db = session.db().await;
        let dev_server_projects = db.get_projects_for_dev_server(dev_server_id).await?;
        let dev_server = db.get_dev_server(dev_server_id).await?;
        (dev_server_projects, dev_server)
    };

    // Only projects that currently have a `project_id` are unshared.
    for project_id in dev_server_projects.iter().filter_map(|p| p.project_id) {
        unshare_project_internal(
            ProjectId::from_proto(project_id),
            connection_id,
            None,
            session,
        )
        .await?;
    }

    session
        .connection_pool()
        .await
        .set_dev_server_offline(dev_server_id);

    let status = session
        .db()
        .await
        .dev_server_projects_update(dev_server.user_id)
        .await?;
    send_dev_server_projects_update(dev_server.user_id, status, &session).await;

    Ok(())
}
2759
2760async fn remove_dev_server_connection(dev_server_id: DevServerId, session: &Session) -> Result<()> {
2761 let dev_server_connection = session
2762 .connection_pool()
2763 .await
2764 .dev_server_connection_id(dev_server_id);
2765
2766 if let Some(dev_server_connection) = dev_server_connection {
2767 session
2768 .connection_pool()
2769 .await
2770 .remove_connection(dev_server_connection)?;
2771 }
2772 Ok(())
2773}
2774
/// Updates other participants with changes to the project
///
/// The database's `update_project` stores the new worktree list for
/// `request.project_id` as sent by `session.connection_id`. It returns the guest
/// connections and, optionally, a room. The request is forwarded verbatim to the
/// guests, a returned room is passed to `room_updated`, and the sender is
/// acknowledged.
async fn update_project(
    request: proto::UpdateProject,
    response: Response<proto::UpdateProject>,
    session: Session,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    // `room` and `guest_connection_ids` borrow from the returned value, whose
    // lifetime is extended to the end of the function by the `&*` binding.
    let (room, guest_connection_ids) = &*session
        .db()
        .await
        .update_project(project_id, session.connection_id, &request.worktrees)
        .await?;
    broadcast(
        Some(session.connection_id),
        guest_connection_ids.iter().copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())
        },
    );
    if let Some(room) = room {
        room_updated(&room, &session.peer);
    }
    response.send(proto::Ack {})?;

    Ok(())
}
2803
2804/// Updates other participants with changes to the worktree
2805async fn update_worktree(
2806 request: proto::UpdateWorktree,
2807 response: Response<proto::UpdateWorktree>,
2808 session: Session,
2809) -> Result<()> {
2810 let guest_connection_ids = session
2811 .db()
2812 .await
2813 .update_worktree(&request, session.connection_id)
2814 .await?;
2815
2816 broadcast(
2817 Some(session.connection_id),
2818 guest_connection_ids.iter().copied(),
2819 |connection_id| {
2820 session
2821 .peer
2822 .forward_send(session.connection_id, connection_id, request.clone())
2823 },
2824 );
2825 response.send(proto::Ack {})?;
2826 Ok(())
2827}
2828
2829/// Updates other participants with changes to the diagnostics
2830async fn update_diagnostic_summary(
2831 message: proto::UpdateDiagnosticSummary,
2832 session: Session,
2833) -> Result<()> {
2834 let guest_connection_ids = session
2835 .db()
2836 .await
2837 .update_diagnostic_summary(&message, session.connection_id)
2838 .await?;
2839
2840 broadcast(
2841 Some(session.connection_id),
2842 guest_connection_ids.iter().copied(),
2843 |connection_id| {
2844 session
2845 .peer
2846 .forward_send(session.connection_id, connection_id, message.clone())
2847 },
2848 );
2849
2850 Ok(())
2851}
2852
2853/// Updates other participants with changes to the worktree settings
2854async fn update_worktree_settings(
2855 message: proto::UpdateWorktreeSettings,
2856 session: Session,
2857) -> Result<()> {
2858 let guest_connection_ids = session
2859 .db()
2860 .await
2861 .update_worktree_settings(&message, session.connection_id)
2862 .await?;
2863
2864 broadcast(
2865 Some(session.connection_id),
2866 guest_connection_ids.iter().copied(),
2867 |connection_id| {
2868 session
2869 .peer
2870 .forward_send(session.connection_id, connection_id, message.clone())
2871 },
2872 );
2873
2874 Ok(())
2875}
2876
2877/// Notify other participants that a language server has started.
2878async fn start_language_server(
2879 request: proto::StartLanguageServer,
2880 session: Session,
2881) -> Result<()> {
2882 let guest_connection_ids = session
2883 .db()
2884 .await
2885 .start_language_server(&request, session.connection_id)
2886 .await?;
2887
2888 broadcast(
2889 Some(session.connection_id),
2890 guest_connection_ids.iter().copied(),
2891 |connection_id| {
2892 session
2893 .peer
2894 .forward_send(session.connection_id, connection_id, request.clone())
2895 },
2896 );
2897 Ok(())
2898}
2899
2900/// Notify other participants that a language server has changed.
2901async fn update_language_server(
2902 request: proto::UpdateLanguageServer,
2903 session: Session,
2904) -> Result<()> {
2905 let project_id = ProjectId::from_proto(request.project_id);
2906 let project_connection_ids = session
2907 .db()
2908 .await
2909 .project_connection_ids(project_id, session.connection_id, true)
2910 .await?;
2911 broadcast(
2912 Some(session.connection_id),
2913 project_connection_ids.iter().copied(),
2914 |connection_id| {
2915 session
2916 .peer
2917 .forward_send(session.connection_id, connection_id, request.clone())
2918 },
2919 );
2920 Ok(())
2921}
2922
2923/// forward a project request to the host. These requests should be read only
2924/// as guests are allowed to send them.
2925async fn forward_read_only_project_request<T>(
2926 request: T,
2927 response: Response<T>,
2928 session: UserSession,
2929) -> Result<()>
2930where
2931 T: EntityMessage + RequestMessage,
2932{
2933 let project_id = ProjectId::from_proto(request.remote_entity_id());
2934 let host_connection_id = session
2935 .db()
2936 .await
2937 .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2938 .await?;
2939 let payload = session
2940 .peer
2941 .forward_request(session.connection_id, host_connection_id, request)
2942 .await?;
2943 response.send(payload)?;
2944 Ok(())
2945}
2946
2947/// forward a project request to the dev server. Only allowed
2948/// if it's your dev server.
2949async fn forward_project_request_for_owner<T>(
2950 request: T,
2951 response: Response<T>,
2952 session: UserSession,
2953) -> Result<()>
2954where
2955 T: EntityMessage + RequestMessage,
2956{
2957 let project_id = ProjectId::from_proto(request.remote_entity_id());
2958
2959 let host_connection_id = session
2960 .db()
2961 .await
2962 .host_for_owner_project_request(project_id, session.connection_id, session.user_id())
2963 .await?;
2964 let payload = session
2965 .peer
2966 .forward_request(session.connection_id, host_connection_id, request)
2967 .await?;
2968 response.send(payload)?;
2969 Ok(())
2970}
2971
2972/// forward a project request to the host. These requests are disallowed
2973/// for guests.
2974async fn forward_mutating_project_request<T>(
2975 request: T,
2976 response: Response<T>,
2977 session: UserSession,
2978) -> Result<()>
2979where
2980 T: EntityMessage + RequestMessage,
2981{
2982 let project_id = ProjectId::from_proto(request.remote_entity_id());
2983
2984 let host_connection_id = session
2985 .db()
2986 .await
2987 .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
2988 .await?;
2989 let payload = session
2990 .peer
2991 .forward_request(session.connection_id, host_connection_id, request)
2992 .await?;
2993 response.send(payload)?;
2994 Ok(())
2995}
2996
2997/// forward a project request to the host. These requests are disallowed
2998/// for guests.
2999async fn forward_versioned_mutating_project_request<T>(
3000 request: T,
3001 response: Response<T>,
3002 session: UserSession,
3003) -> Result<()>
3004where
3005 T: EntityMessage + RequestMessage + VersionedMessage,
3006{
3007 let project_id = ProjectId::from_proto(request.remote_entity_id());
3008
3009 let host_connection_id = session
3010 .db()
3011 .await
3012 .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
3013 .await?;
3014 if let Some(host_version) = session
3015 .connection_pool()
3016 .await
3017 .connection(host_connection_id)
3018 .map(|c| c.zed_version)
3019 {
3020 if let Some(min_required_version) = request.required_host_version() {
3021 if min_required_version > host_version {
3022 return Err(anyhow!(ErrorCode::RemoteUpgradeRequired
3023 .with_tag("required", &min_required_version.to_string())))?;
3024 }
3025 }
3026 }
3027
3028 let payload = session
3029 .peer
3030 .forward_request(session.connection_id, host_connection_id, request)
3031 .await?;
3032 response.send(payload)?;
3033 Ok(())
3034}
3035
3036/// Notify other participants that a new buffer has been created
3037async fn create_buffer_for_peer(
3038 request: proto::CreateBufferForPeer,
3039 session: Session,
3040) -> Result<()> {
3041 session
3042 .db()
3043 .await
3044 .check_user_is_project_host(
3045 ProjectId::from_proto(request.project_id),
3046 session.connection_id,
3047 )
3048 .await?;
3049 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
3050 session
3051 .peer
3052 .forward_send(session.connection_id, peer_id.into(), request)?;
3053 Ok(())
3054}
3055
/// Notify other participants that a buffer has been updated. This is
/// allowed for guests as long as the update is limited to selections.
///
/// Any operation other than an empty one or a selection update requires
/// `Capability::ReadWrite`. The capability is passed to
/// `connections_for_buffer_update`, which returns the host and guest connections.
/// The update is sent to the guests first. Unless the sender is the host, it is
/// then forwarded to the host as a request, whose reply is awaited before the
/// sender is acknowledged.
async fn update_buffer(
    request: proto::UpdateBuffer,
    response: Response<proto::UpdateBuffer>,
    session: Session,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    let mut capability = Capability::ReadOnly;

    for op in request.operations.iter() {
        match op.variant {
            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
            Some(_) => capability = Capability::ReadWrite,
        }
    }

    // The guard returned by `connections_for_buffer_update` is dropped at the end
    // of this block, before the host request below is awaited.
    let host = {
        let guard = session
            .db()
            .await
            .connections_for_buffer_update(
                project_id,
                session.principal_id(),
                session.connection_id,
                capability,
            )
            .await?;

        let (host, guests) = &*guard;

        broadcast(
            Some(session.connection_id),
            guests.clone(),
            |connection_id| {
                session
                    .peer
                    .forward_send(session.connection_id, connection_id, request.clone())
            },
        );

        *host
    };

    if host != session.connection_id {
        session
            .peer
            .forward_request(session.connection_id, host, request.clone())
            .await?;
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3110
/// Relays an `UpdateContext` operation from the sender to the project's host and
/// guests.
///
/// Capability rules (passed to `connections_for_buffer_update`):
/// - read-only: a buffer operation whose inner variant is absent or only updates
///   selections, or an operation with no variant at all;
/// - read-write: any other buffer operation, a buffer-operation wrapper without an
///   inner operation, or any other context-operation variant.
///
/// # Errors
/// Fails if `message.operation` is missing ("invalid operation") or if the
/// database call fails.
async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
    let project_id = ProjectId::from_proto(message.project_id);

    let operation = message.operation.as_ref().context("invalid operation")?;
    let capability = match operation.variant.as_ref() {
        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
            if let Some(buffer_op) = buffer_op.operation.as_ref() {
                match buffer_op.variant {
                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
                        Capability::ReadOnly
                    }
                    _ => Capability::ReadWrite,
                }
            } else {
                Capability::ReadWrite
            }
        }
        Some(_) => Capability::ReadWrite,
        None => Capability::ReadOnly,
    };

    let guard = session
        .db()
        .await
        .connections_for_buffer_update(
            project_id,
            session.principal_id(),
            session.connection_id,
            capability,
        )
        .await?;

    let (host, guests) = &*guard;

    // Unlike `update_buffer`, the host is included in the broadcast (`broadcast`
    // is given the sender id as the connection to skip).
    broadcast(
        Some(session.connection_id),
        guests.iter().chain([host]).copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, message.clone())
        },
    );

    Ok(())
}
3157
3158/// Notify other participants that a project has been updated.
3159async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
3160 request: T,
3161 session: Session,
3162) -> Result<()> {
3163 let project_id = ProjectId::from_proto(request.remote_entity_id());
3164 let project_connection_ids = session
3165 .db()
3166 .await
3167 .project_connection_ids(project_id, session.connection_id, false)
3168 .await?;
3169
3170 broadcast(
3171 Some(session.connection_id),
3172 project_connection_ids.iter().copied(),
3173 |connection_id| {
3174 session
3175 .peer
3176 .forward_send(session.connection_id, connection_id, request.clone())
3177 },
3178 );
3179 Ok(())
3180}
3181
/// Start following another user in a call.
///
/// After `check_room_participants` accepts the leader and follower connections for
/// `request.room_id`, the request is forwarded to the leader and the leader's reply
/// is relayed to the follower. When the follow is scoped to a project, the follow
/// is then recorded in the database and the resulting room is passed to
/// `room_updated`.
///
/// # Errors
/// Fails if `leader_id` is missing ("invalid leader id"), if the participant check
/// fails, or if forwarding or recording the follow fails.
async fn follow(
    request: proto::Follow,
    response: Response<proto::Follow>,
    session: UserSession,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let project_id = request.project_id.map(ProjectId::from_proto);
    let leader_id = request
        .leader_id
        .ok_or_else(|| anyhow!("invalid leader id"))?
        .into();
    let follower_id = session.connection_id;

    session
        .db()
        .await
        .check_room_participants(room_id, leader_id, session.connection_id)
        .await?;

    let response_payload = session
        .peer
        .forward_request(session.connection_id, leader_id, request)
        .await?;
    response.send(response_payload)?;

    // The follow relation is only recorded for project-scoped follows.
    if let Some(project_id) = project_id {
        let room = session
            .db()
            .await
            .follow(room_id, project_id, leader_id, follower_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    Ok(())
}
3219
/// Stop following another user in a call.
///
/// After `check_room_participants` accepts the leader and follower connections for
/// `request.room_id`, the message is forwarded to the leader. When the unfollow is
/// scoped to a project, it is then recorded in the database and the resulting room
/// is passed to `room_updated`.
///
/// # Errors
/// Fails if `leader_id` is missing ("invalid leader id"), if the participant check
/// fails, or if forwarding or recording the unfollow fails.
async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let project_id = request.project_id.map(ProjectId::from_proto);
    let leader_id = request
        .leader_id
        .ok_or_else(|| anyhow!("invalid leader id"))?
        .into();
    let follower_id = session.connection_id;

    session
        .db()
        .await
        .check_room_participants(room_id, leader_id, session.connection_id)
        .await?;

    session
        .peer
        .forward_send(session.connection_id, leader_id, request)?;

    // The follow relation is only tracked for project-scoped follows.
    if let Some(project_id) = project_id {
        let room = session
            .db()
            .await
            .unfollow(room_id, project_id, leader_id, follower_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    Ok(())
}
3251
3252/// Notify everyone following you of your current location.
3253async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
3254 let room_id = RoomId::from_proto(request.room_id);
3255 let database = session.db.lock().await;
3256
3257 let connection_ids = if let Some(project_id) = request.project_id {
3258 let project_id = ProjectId::from_proto(project_id);
3259 database
3260 .project_connection_ids(project_id, session.connection_id, true)
3261 .await?
3262 } else {
3263 database
3264 .room_connection_ids(room_id, session.connection_id)
3265 .await?
3266 };
3267
3268 // For now, don't send view update messages back to that view's current leader.
3269 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
3270 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
3271 _ => None,
3272 });
3273
3274 for connection_id in connection_ids.iter().cloned() {
3275 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
3276 session
3277 .peer
3278 .forward_send(session.connection_id, connection_id, request.clone())?;
3279 }
3280 }
3281 Ok(())
3282}
3283
3284/// Get public data about users.
3285async fn get_users(
3286 request: proto::GetUsers,
3287 response: Response<proto::GetUsers>,
3288 session: Session,
3289) -> Result<()> {
3290 let user_ids = request
3291 .user_ids
3292 .into_iter()
3293 .map(UserId::from_proto)
3294 .collect();
3295 let users = session
3296 .db()
3297 .await
3298 .get_users_by_ids(user_ids)
3299 .await?
3300 .into_iter()
3301 .map(|user| proto::User {
3302 id: user.id.to_proto(),
3303 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3304 github_login: user.github_login,
3305 })
3306 .collect();
3307 response.send(proto::UsersResponse { users })?;
3308 Ok(())
3309}
3310
3311/// Search for users (to invite) buy Github login
3312async fn fuzzy_search_users(
3313 request: proto::FuzzySearchUsers,
3314 response: Response<proto::FuzzySearchUsers>,
3315 session: UserSession,
3316) -> Result<()> {
3317 let query = request.query;
3318 let users = match query.len() {
3319 0 => vec![],
3320 1 | 2 => session
3321 .db()
3322 .await
3323 .get_user_by_github_login(&query)
3324 .await?
3325 .into_iter()
3326 .collect(),
3327 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
3328 };
3329 let users = users
3330 .into_iter()
3331 .filter(|user| user.id != session.user_id())
3332 .map(|user| proto::User {
3333 id: user.id.to_proto(),
3334 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3335 github_login: user.github_login,
3336 })
3337 .collect();
3338 response.send(proto::UsersResponse { users })?;
3339 Ok(())
3340}
3341
3342/// Send a contact request to another user.
3343async fn request_contact(
3344 request: proto::RequestContact,
3345 response: Response<proto::RequestContact>,
3346 session: UserSession,
3347) -> Result<()> {
3348 let requester_id = session.user_id();
3349 let responder_id = UserId::from_proto(request.responder_id);
3350 if requester_id == responder_id {
3351 return Err(anyhow!("cannot add yourself as a contact"))?;
3352 }
3353
3354 let notifications = session
3355 .db()
3356 .await
3357 .send_contact_request(requester_id, responder_id)
3358 .await?;
3359
3360 // Update outgoing contact requests of requester
3361 let mut update = proto::UpdateContacts::default();
3362 update.outgoing_requests.push(responder_id.to_proto());
3363 for connection_id in session
3364 .connection_pool()
3365 .await
3366 .user_connection_ids(requester_id)
3367 {
3368 session.peer.send(connection_id, update.clone())?;
3369 }
3370
3371 // Update incoming contact requests of responder
3372 let mut update = proto::UpdateContacts::default();
3373 update
3374 .incoming_requests
3375 .push(proto::IncomingContactRequest {
3376 requester_id: requester_id.to_proto(),
3377 });
3378 let connection_pool = session.connection_pool().await;
3379 for connection_id in connection_pool.user_connection_ids(responder_id) {
3380 session.peer.send(connection_id, update.clone())?;
3381 }
3382
3383 send_notifications(&connection_pool, &session.peer, notifications);
3384
3385 response.send(proto::Ack {})?;
3386 Ok(())
3387}
3388
3389/// Accept or decline a contact request
3390async fn respond_to_contact_request(
3391 request: proto::RespondToContactRequest,
3392 response: Response<proto::RespondToContactRequest>,
3393 session: UserSession,
3394) -> Result<()> {
3395 let responder_id = session.user_id();
3396 let requester_id = UserId::from_proto(request.requester_id);
3397 let db = session.db().await;
3398 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
3399 db.dismiss_contact_notification(responder_id, requester_id)
3400 .await?;
3401 } else {
3402 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
3403
3404 let notifications = db
3405 .respond_to_contact_request(responder_id, requester_id, accept)
3406 .await?;
3407 let requester_busy = db.is_user_busy(requester_id).await?;
3408 let responder_busy = db.is_user_busy(responder_id).await?;
3409
3410 let pool = session.connection_pool().await;
3411 // Update responder with new contact
3412 let mut update = proto::UpdateContacts::default();
3413 if accept {
3414 update
3415 .contacts
3416 .push(contact_for_user(requester_id, requester_busy, &pool));
3417 }
3418 update
3419 .remove_incoming_requests
3420 .push(requester_id.to_proto());
3421 for connection_id in pool.user_connection_ids(responder_id) {
3422 session.peer.send(connection_id, update.clone())?;
3423 }
3424
3425 // Update requester with new contact
3426 let mut update = proto::UpdateContacts::default();
3427 if accept {
3428 update
3429 .contacts
3430 .push(contact_for_user(responder_id, responder_busy, &pool));
3431 }
3432 update
3433 .remove_outgoing_requests
3434 .push(responder_id.to_proto());
3435
3436 for connection_id in pool.user_connection_ids(requester_id) {
3437 session.peer.send(connection_id, update.clone())?;
3438 }
3439
3440 send_notifications(&pool, &session.peer, notifications);
3441 }
3442
3443 response.send(proto::Ack {})?;
3444 Ok(())
3445}
3446
3447/// Remove a contact.
3448async fn remove_contact(
3449 request: proto::RemoveContact,
3450 response: Response<proto::RemoveContact>,
3451 session: UserSession,
3452) -> Result<()> {
3453 let requester_id = session.user_id();
3454 let responder_id = UserId::from_proto(request.user_id);
3455 let db = session.db().await;
3456 let (contact_accepted, deleted_notification_id) =
3457 db.remove_contact(requester_id, responder_id).await?;
3458
3459 let pool = session.connection_pool().await;
3460 // Update outgoing contact requests of requester
3461 let mut update = proto::UpdateContacts::default();
3462 if contact_accepted {
3463 update.remove_contacts.push(responder_id.to_proto());
3464 } else {
3465 update
3466 .remove_outgoing_requests
3467 .push(responder_id.to_proto());
3468 }
3469 for connection_id in pool.user_connection_ids(requester_id) {
3470 session.peer.send(connection_id, update.clone())?;
3471 }
3472
3473 // Update incoming contact requests of responder
3474 let mut update = proto::UpdateContacts::default();
3475 if contact_accepted {
3476 update.remove_contacts.push(requester_id.to_proto());
3477 } else {
3478 update
3479 .remove_incoming_requests
3480 .push(requester_id.to_proto());
3481 }
3482 for connection_id in pool.user_connection_ids(responder_id) {
3483 session.peer.send(connection_id, update.clone())?;
3484 if let Some(notification_id) = deleted_notification_id {
3485 session.peer.send(
3486 connection_id,
3487 proto::DeleteNotification {
3488 notification_id: notification_id.to_proto(),
3489 },
3490 )?;
3491 }
3492 }
3493
3494 response.send(proto::Ack {})?;
3495 Ok(())
3496}
3497
3498fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
3499 version.0.minor() < 139
3500}
3501
3502async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
3503 let plan = session.current_plan().await?;
3504
3505 session
3506 .peer
3507 .send(
3508 session.connection_id,
3509 proto::UpdateUserPlan { plan: plan.into() },
3510 )
3511 .trace_err();
3512
3513 Ok(())
3514}
3515
3516async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
3517 subscribe_user_to_channels(
3518 session.user_id().ok_or_else(|| anyhow!("must be a user"))?,
3519 &session,
3520 )
3521 .await?;
3522 Ok(())
3523}
3524
3525async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
3526 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3527 let mut pool = session.connection_pool().await;
3528 for membership in &channels_for_user.channel_memberships {
3529 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3530 }
3531 session.peer.send(
3532 session.connection_id,
3533 build_update_user_channels(&channels_for_user),
3534 )?;
3535 session.peer.send(
3536 session.connection_id,
3537 build_channels_update(channels_for_user),
3538 )?;
3539 Ok(())
3540}
3541
3542/// Creates a new channel.
3543async fn create_channel(
3544 request: proto::CreateChannel,
3545 response: Response<proto::CreateChannel>,
3546 session: UserSession,
3547) -> Result<()> {
3548 let db = session.db().await;
3549
3550 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
3551 let (channel, membership) = db
3552 .create_channel(&request.name, parent_id, session.user_id())
3553 .await?;
3554
3555 let root_id = channel.root_id();
3556 let channel = Channel::from_model(channel);
3557
3558 response.send(proto::CreateChannelResponse {
3559 channel: Some(channel.to_proto()),
3560 parent_id: request.parent_id,
3561 })?;
3562
3563 let mut connection_pool = session.connection_pool().await;
3564 if let Some(membership) = membership {
3565 connection_pool.subscribe_to_channel(
3566 membership.user_id,
3567 membership.channel_id,
3568 membership.role,
3569 );
3570 let update = proto::UpdateUserChannels {
3571 channel_memberships: vec![proto::ChannelMembership {
3572 channel_id: membership.channel_id.to_proto(),
3573 role: membership.role.into(),
3574 }],
3575 ..Default::default()
3576 };
3577 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3578 session.peer.send(connection_id, update.clone())?;
3579 }
3580 }
3581
3582 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3583 if !role.can_see_channel(channel.visibility) {
3584 continue;
3585 }
3586
3587 let update = proto::UpdateChannels {
3588 channels: vec![channel.to_proto()],
3589 ..Default::default()
3590 };
3591 session.peer.send(connection_id, update.clone())?;
3592 }
3593
3594 Ok(())
3595}
3596
3597/// Delete a channel
3598async fn delete_channel(
3599 request: proto::DeleteChannel,
3600 response: Response<proto::DeleteChannel>,
3601 session: UserSession,
3602) -> Result<()> {
3603 let db = session.db().await;
3604
3605 let channel_id = request.channel_id;
3606 let (root_channel, removed_channels) = db
3607 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3608 .await?;
3609 response.send(proto::Ack {})?;
3610
3611 // Notify members of removed channels
3612 let mut update = proto::UpdateChannels::default();
3613 update
3614 .delete_channels
3615 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3616
3617 let connection_pool = session.connection_pool().await;
3618 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3619 session.peer.send(connection_id, update.clone())?;
3620 }
3621
3622 Ok(())
3623}
3624
3625/// Invite someone to join a channel.
3626async fn invite_channel_member(
3627 request: proto::InviteChannelMember,
3628 response: Response<proto::InviteChannelMember>,
3629 session: UserSession,
3630) -> Result<()> {
3631 let db = session.db().await;
3632 let channel_id = ChannelId::from_proto(request.channel_id);
3633 let invitee_id = UserId::from_proto(request.user_id);
3634 let InviteMemberResult {
3635 channel,
3636 notifications,
3637 } = db
3638 .invite_channel_member(
3639 channel_id,
3640 invitee_id,
3641 session.user_id(),
3642 request.role().into(),
3643 )
3644 .await?;
3645
3646 let update = proto::UpdateChannels {
3647 channel_invitations: vec![channel.to_proto()],
3648 ..Default::default()
3649 };
3650
3651 let connection_pool = session.connection_pool().await;
3652 for connection_id in connection_pool.user_connection_ids(invitee_id) {
3653 session.peer.send(connection_id, update.clone())?;
3654 }
3655
3656 send_notifications(&connection_pool, &session.peer, notifications);
3657
3658 response.send(proto::Ack {})?;
3659 Ok(())
3660}
3661
3662/// remove someone from a channel
3663async fn remove_channel_member(
3664 request: proto::RemoveChannelMember,
3665 response: Response<proto::RemoveChannelMember>,
3666 session: UserSession,
3667) -> Result<()> {
3668 let db = session.db().await;
3669 let channel_id = ChannelId::from_proto(request.channel_id);
3670 let member_id = UserId::from_proto(request.user_id);
3671
3672 let RemoveChannelMemberResult {
3673 membership_update,
3674 notification_id,
3675 } = db
3676 .remove_channel_member(channel_id, member_id, session.user_id())
3677 .await?;
3678
3679 let mut connection_pool = session.connection_pool().await;
3680 notify_membership_updated(
3681 &mut connection_pool,
3682 membership_update,
3683 member_id,
3684 &session.peer,
3685 );
3686 for connection_id in connection_pool.user_connection_ids(member_id) {
3687 if let Some(notification_id) = notification_id {
3688 session
3689 .peer
3690 .send(
3691 connection_id,
3692 proto::DeleteNotification {
3693 notification_id: notification_id.to_proto(),
3694 },
3695 )
3696 .trace_err();
3697 }
3698 }
3699
3700 response.send(proto::Ack {})?;
3701 Ok(())
3702}
3703
3704/// Toggle the channel between public and private.
3705/// Care is taken to maintain the invariant that public channels only descend from public channels,
3706/// (though members-only channels can appear at any point in the hierarchy).
3707async fn set_channel_visibility(
3708 request: proto::SetChannelVisibility,
3709 response: Response<proto::SetChannelVisibility>,
3710 session: UserSession,
3711) -> Result<()> {
3712 let db = session.db().await;
3713 let channel_id = ChannelId::from_proto(request.channel_id);
3714 let visibility = request.visibility().into();
3715
3716 let channel_model = db
3717 .set_channel_visibility(channel_id, visibility, session.user_id())
3718 .await?;
3719 let root_id = channel_model.root_id();
3720 let channel = Channel::from_model(channel_model);
3721
3722 let mut connection_pool = session.connection_pool().await;
3723 for (user_id, role) in connection_pool
3724 .channel_user_ids(root_id)
3725 .collect::<Vec<_>>()
3726 .into_iter()
3727 {
3728 let update = if role.can_see_channel(channel.visibility) {
3729 connection_pool.subscribe_to_channel(user_id, channel_id, role);
3730 proto::UpdateChannels {
3731 channels: vec![channel.to_proto()],
3732 ..Default::default()
3733 }
3734 } else {
3735 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3736 proto::UpdateChannels {
3737 delete_channels: vec![channel.id.to_proto()],
3738 ..Default::default()
3739 }
3740 };
3741
3742 for connection_id in connection_pool.user_connection_ids(user_id) {
3743 session.peer.send(connection_id, update.clone())?;
3744 }
3745 }
3746
3747 response.send(proto::Ack {})?;
3748 Ok(())
3749}
3750
3751/// Alter the role for a user in the channel.
3752async fn set_channel_member_role(
3753 request: proto::SetChannelMemberRole,
3754 response: Response<proto::SetChannelMemberRole>,
3755 session: UserSession,
3756) -> Result<()> {
3757 let db = session.db().await;
3758 let channel_id = ChannelId::from_proto(request.channel_id);
3759 let member_id = UserId::from_proto(request.user_id);
3760 let result = db
3761 .set_channel_member_role(
3762 channel_id,
3763 session.user_id(),
3764 member_id,
3765 request.role().into(),
3766 )
3767 .await?;
3768
3769 match result {
3770 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3771 let mut connection_pool = session.connection_pool().await;
3772 notify_membership_updated(
3773 &mut connection_pool,
3774 membership_update,
3775 member_id,
3776 &session.peer,
3777 )
3778 }
3779 db::SetMemberRoleResult::InviteUpdated(channel) => {
3780 let update = proto::UpdateChannels {
3781 channel_invitations: vec![channel.to_proto()],
3782 ..Default::default()
3783 };
3784
3785 for connection_id in session
3786 .connection_pool()
3787 .await
3788 .user_connection_ids(member_id)
3789 {
3790 session.peer.send(connection_id, update.clone())?;
3791 }
3792 }
3793 }
3794
3795 response.send(proto::Ack {})?;
3796 Ok(())
3797}
3798
3799/// Change the name of a channel
3800async fn rename_channel(
3801 request: proto::RenameChannel,
3802 response: Response<proto::RenameChannel>,
3803 session: UserSession,
3804) -> Result<()> {
3805 let db = session.db().await;
3806 let channel_id = ChannelId::from_proto(request.channel_id);
3807 let channel_model = db
3808 .rename_channel(channel_id, session.user_id(), &request.name)
3809 .await?;
3810 let root_id = channel_model.root_id();
3811 let channel = Channel::from_model(channel_model);
3812
3813 response.send(proto::RenameChannelResponse {
3814 channel: Some(channel.to_proto()),
3815 })?;
3816
3817 let connection_pool = session.connection_pool().await;
3818 let update = proto::UpdateChannels {
3819 channels: vec![channel.to_proto()],
3820 ..Default::default()
3821 };
3822 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3823 if role.can_see_channel(channel.visibility) {
3824 session.peer.send(connection_id, update.clone())?;
3825 }
3826 }
3827
3828 Ok(())
3829}
3830
3831/// Move a channel to a new parent.
3832async fn move_channel(
3833 request: proto::MoveChannel,
3834 response: Response<proto::MoveChannel>,
3835 session: UserSession,
3836) -> Result<()> {
3837 let channel_id = ChannelId::from_proto(request.channel_id);
3838 let to = ChannelId::from_proto(request.to);
3839
3840 let (root_id, channels) = session
3841 .db()
3842 .await
3843 .move_channel(channel_id, to, session.user_id())
3844 .await?;
3845
3846 let connection_pool = session.connection_pool().await;
3847 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3848 let channels = channels
3849 .iter()
3850 .filter_map(|channel| {
3851 if role.can_see_channel(channel.visibility) {
3852 Some(channel.to_proto())
3853 } else {
3854 None
3855 }
3856 })
3857 .collect::<Vec<_>>();
3858 if channels.is_empty() {
3859 continue;
3860 }
3861
3862 let update = proto::UpdateChannels {
3863 channels,
3864 ..Default::default()
3865 };
3866
3867 session.peer.send(connection_id, update.clone())?;
3868 }
3869
3870 response.send(Ack {})?;
3871 Ok(())
3872}
3873
3874/// Get the list of channel members
3875async fn get_channel_members(
3876 request: proto::GetChannelMembers,
3877 response: Response<proto::GetChannelMembers>,
3878 session: UserSession,
3879) -> Result<()> {
3880 let db = session.db().await;
3881 let channel_id = ChannelId::from_proto(request.channel_id);
3882 let limit = if request.limit == 0 {
3883 u16::MAX as u64
3884 } else {
3885 request.limit
3886 };
3887 let (members, users) = db
3888 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3889 .await?;
3890 response.send(proto::GetChannelMembersResponse { members, users })?;
3891 Ok(())
3892}
3893
3894/// Accept or decline a channel invitation.
3895async fn respond_to_channel_invite(
3896 request: proto::RespondToChannelInvite,
3897 response: Response<proto::RespondToChannelInvite>,
3898 session: UserSession,
3899) -> Result<()> {
3900 let db = session.db().await;
3901 let channel_id = ChannelId::from_proto(request.channel_id);
3902 let RespondToChannelInvite {
3903 membership_update,
3904 notifications,
3905 } = db
3906 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3907 .await?;
3908
3909 let mut connection_pool = session.connection_pool().await;
3910 if let Some(membership_update) = membership_update {
3911 notify_membership_updated(
3912 &mut connection_pool,
3913 membership_update,
3914 session.user_id(),
3915 &session.peer,
3916 );
3917 } else {
3918 let update = proto::UpdateChannels {
3919 remove_channel_invitations: vec![channel_id.to_proto()],
3920 ..Default::default()
3921 };
3922
3923 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3924 session.peer.send(connection_id, update.clone())?;
3925 }
3926 };
3927
3928 send_notifications(&connection_pool, &session.peer, notifications);
3929
3930 response.send(proto::Ack {})?;
3931
3932 Ok(())
3933}
3934
3935/// Join the channels' room
3936async fn join_channel(
3937 request: proto::JoinChannel,
3938 response: Response<proto::JoinChannel>,
3939 session: UserSession,
3940) -> Result<()> {
3941 let channel_id = ChannelId::from_proto(request.channel_id);
3942 join_channel_internal(channel_id, Box::new(response), session).await
3943}
3944
/// A reply handle that can deliver a `JoinRoomResponse`.
///
/// Implemented for both `Response<proto::JoinChannel>` and
/// `Response<proto::JoinRoom>`, so `join_channel_internal` can answer either
/// kind of request with the same code.
trait JoinChannelInternalResponse {
    /// Send `result` as the reply, consuming the handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Type-qualified path: inherent methods take precedence over trait
        // methods, so this calls `Response`'s own `send` instead of
        // recursing into this trait method.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Same as above: resolves to the inherent `Response::send`.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3958
/// Join the room attached to `channel_id` for `session`, replying through
/// `response`.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, that
///    connection is first removed via `leave_room_for_session`. The database
///    guard is dropped for that call and re-acquired afterwards.
/// 2. `join_channel` is called on the database. It yields the joined room,
///    an optional membership change, and the user's channel role.
/// 3. When a LiveKit client is configured, a token is minted:
///    - `ChannelRole::Guest` gets a `guest_token` with `can_publish: false`;
///    - every other role gets a `room_token` with `can_publish: true`.
///
///    A token error goes through `trace_err` and yields no connection info,
///    so it does not fail the join.
/// 4. The reply is sent, any membership change is broadcast, and the room
///    update is sent to the room's participants.
/// 5. Once the database guard (and the first connection-pool guard) have
///    been dropped, `channel_updated` is called and the user's contacts are
///    refreshed.
///
/// Returns an error if a database call fails, the reply cannot be sent, or
/// the joined room carries no channel.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: UserSession,
) -> Result<()> {
    // Scoped so the database and connection-pool guards taken below are
    // released before `channel_updated` and `update_user_contacts` run.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database guard while leaving the stale room.
            // NOTE(review): presumably `leave_room_for_session` acquires the
            // database itself — confirm.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured or token creation fails.
        let live_kit_connection_info =
            session
                .app_state
                .live_kit_client
                .as_ref()
                .and_then(|live_kit| {
                    // Guests may join but not publish; every other role may publish.
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.live_kit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.live_kit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // A joined channel room is expected to carry its channel; treat a
    // missing one as an error.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
4054
4055/// Start editing the channel notes
4056async fn join_channel_buffer(
4057 request: proto::JoinChannelBuffer,
4058 response: Response<proto::JoinChannelBuffer>,
4059 session: UserSession,
4060) -> Result<()> {
4061 let db = session.db().await;
4062 let channel_id = ChannelId::from_proto(request.channel_id);
4063
4064 let open_response = db
4065 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
4066 .await?;
4067
4068 let collaborators = open_response.collaborators.clone();
4069 response.send(open_response)?;
4070
4071 let update = UpdateChannelBufferCollaborators {
4072 channel_id: channel_id.to_proto(),
4073 collaborators: collaborators.clone(),
4074 };
4075 channel_buffer_updated(
4076 session.connection_id,
4077 collaborators
4078 .iter()
4079 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
4080 &update,
4081 &session.peer,
4082 );
4083
4084 Ok(())
4085}
4086
4087/// Edit the channel notes
4088async fn update_channel_buffer(
4089 request: proto::UpdateChannelBuffer,
4090 session: UserSession,
4091) -> Result<()> {
4092 let db = session.db().await;
4093 let channel_id = ChannelId::from_proto(request.channel_id);
4094
4095 let (collaborators, epoch, version) = db
4096 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
4097 .await?;
4098
4099 channel_buffer_updated(
4100 session.connection_id,
4101 collaborators.clone(),
4102 &proto::UpdateChannelBuffer {
4103 channel_id: channel_id.to_proto(),
4104 operations: request.operations,
4105 },
4106 &session.peer,
4107 );
4108
4109 let pool = &*session.connection_pool().await;
4110
4111 let non_collaborators =
4112 pool.channel_connection_ids(channel_id)
4113 .filter_map(|(connection_id, _)| {
4114 if collaborators.contains(&connection_id) {
4115 None
4116 } else {
4117 Some(connection_id)
4118 }
4119 });
4120
4121 broadcast(None, non_collaborators, |peer_id| {
4122 session.peer.send(
4123 peer_id,
4124 proto::UpdateChannels {
4125 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
4126 channel_id: channel_id.to_proto(),
4127 epoch: epoch as u64,
4128 version: version.clone(),
4129 }],
4130 ..Default::default()
4131 },
4132 )
4133 });
4134
4135 Ok(())
4136}
4137
4138/// Rejoin the channel notes after a connection blip
4139async fn rejoin_channel_buffers(
4140 request: proto::RejoinChannelBuffers,
4141 response: Response<proto::RejoinChannelBuffers>,
4142 session: UserSession,
4143) -> Result<()> {
4144 let db = session.db().await;
4145 let buffers = db
4146 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
4147 .await?;
4148
4149 for rejoined_buffer in &buffers {
4150 let collaborators_to_notify = rejoined_buffer
4151 .buffer
4152 .collaborators
4153 .iter()
4154 .filter_map(|c| Some(c.peer_id?.into()));
4155 channel_buffer_updated(
4156 session.connection_id,
4157 collaborators_to_notify,
4158 &proto::UpdateChannelBufferCollaborators {
4159 channel_id: rejoined_buffer.buffer.channel_id,
4160 collaborators: rejoined_buffer.buffer.collaborators.clone(),
4161 },
4162 &session.peer,
4163 );
4164 }
4165
4166 response.send(proto::RejoinChannelBuffersResponse {
4167 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
4168 })?;
4169
4170 Ok(())
4171}
4172
4173/// Stop editing the channel notes
4174async fn leave_channel_buffer(
4175 request: proto::LeaveChannelBuffer,
4176 response: Response<proto::LeaveChannelBuffer>,
4177 session: UserSession,
4178) -> Result<()> {
4179 let db = session.db().await;
4180 let channel_id = ChannelId::from_proto(request.channel_id);
4181
4182 let left_buffer = db
4183 .leave_channel_buffer(channel_id, session.connection_id)
4184 .await?;
4185
4186 response.send(Ack {})?;
4187
4188 channel_buffer_updated(
4189 session.connection_id,
4190 left_buffer.connections,
4191 &proto::UpdateChannelBufferCollaborators {
4192 channel_id: channel_id.to_proto(),
4193 collaborators: left_buffer.collaborators,
4194 },
4195 &session.peer,
4196 );
4197
4198 Ok(())
4199}
4200
4201fn channel_buffer_updated<T: EnvelopedMessage>(
4202 sender_id: ConnectionId,
4203 collaborators: impl IntoIterator<Item = ConnectionId>,
4204 message: &T,
4205 peer: &Peer,
4206) {
4207 broadcast(Some(sender_id), collaborators, |peer_id| {
4208 peer.send(peer_id, message.clone())
4209 });
4210}
4211
4212fn send_notifications(
4213 connection_pool: &ConnectionPool,
4214 peer: &Peer,
4215 notifications: db::NotificationBatch,
4216) {
4217 for (user_id, notification) in notifications {
4218 for connection_id in connection_pool.user_connection_ids(user_id) {
4219 if let Err(error) = peer.send(
4220 connection_id,
4221 proto::AddNotification {
4222 notification: Some(notification.clone()),
4223 },
4224 ) {
4225 tracing::error!(
4226 "failed to send notification to {:?} {}",
4227 connection_id,
4228 error
4229 );
4230 }
4231 }
4232 }
4233}
4234
4235/// Send a message to the channel
4236async fn send_channel_message(
4237 request: proto::SendChannelMessage,
4238 response: Response<proto::SendChannelMessage>,
4239 session: UserSession,
4240) -> Result<()> {
4241 // Validate the message body.
4242 let body = request.body.trim().to_string();
4243 if body.len() > MAX_MESSAGE_LEN {
4244 return Err(anyhow!("message is too long"))?;
4245 }
4246 if body.is_empty() {
4247 return Err(anyhow!("message can't be blank"))?;
4248 }
4249
4250 // TODO: adjust mentions if body is trimmed
4251
4252 let timestamp = OffsetDateTime::now_utc();
4253 let nonce = request
4254 .nonce
4255 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4256
4257 let channel_id = ChannelId::from_proto(request.channel_id);
4258 let CreatedChannelMessage {
4259 message_id,
4260 participant_connection_ids,
4261 notifications,
4262 } = session
4263 .db()
4264 .await
4265 .create_channel_message(
4266 channel_id,
4267 session.user_id(),
4268 &body,
4269 &request.mentions,
4270 timestamp,
4271 nonce.clone().into(),
4272 match request.reply_to_message_id {
4273 Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
4274 None => None,
4275 },
4276 )
4277 .await?;
4278
4279 let message = proto::ChannelMessage {
4280 sender_id: session.user_id().to_proto(),
4281 id: message_id.to_proto(),
4282 body,
4283 mentions: request.mentions,
4284 timestamp: timestamp.unix_timestamp() as u64,
4285 nonce: Some(nonce),
4286 reply_to_message_id: request.reply_to_message_id,
4287 edited_at: None,
4288 };
4289 broadcast(
4290 Some(session.connection_id),
4291 participant_connection_ids.clone(),
4292 |connection| {
4293 session.peer.send(
4294 connection,
4295 proto::ChannelMessageSent {
4296 channel_id: channel_id.to_proto(),
4297 message: Some(message.clone()),
4298 },
4299 )
4300 },
4301 );
4302 response.send(proto::SendChannelMessageResponse {
4303 message: Some(message),
4304 })?;
4305
4306 let pool = &*session.connection_pool().await;
4307 let non_participants =
4308 pool.channel_connection_ids(channel_id)
4309 .filter_map(|(connection_id, _)| {
4310 if participant_connection_ids.contains(&connection_id) {
4311 None
4312 } else {
4313 Some(connection_id)
4314 }
4315 });
4316 broadcast(None, non_participants, |peer_id| {
4317 session.peer.send(
4318 peer_id,
4319 proto::UpdateChannels {
4320 latest_channel_message_ids: vec![proto::ChannelMessageId {
4321 channel_id: channel_id.to_proto(),
4322 message_id: message_id.to_proto(),
4323 }],
4324 ..Default::default()
4325 },
4326 )
4327 });
4328 send_notifications(pool, &session.peer, notifications);
4329
4330 Ok(())
4331}
4332
4333/// Delete a channel message
4334async fn remove_channel_message(
4335 request: proto::RemoveChannelMessage,
4336 response: Response<proto::RemoveChannelMessage>,
4337 session: UserSession,
4338) -> Result<()> {
4339 let channel_id = ChannelId::from_proto(request.channel_id);
4340 let message_id = MessageId::from_proto(request.message_id);
4341 let (connection_ids, existing_notification_ids) = session
4342 .db()
4343 .await
4344 .remove_channel_message(channel_id, message_id, session.user_id())
4345 .await?;
4346
4347 broadcast(
4348 Some(session.connection_id),
4349 connection_ids,
4350 move |connection| {
4351 session.peer.send(connection, request.clone())?;
4352
4353 for notification_id in &existing_notification_ids {
4354 session.peer.send(
4355 connection,
4356 proto::DeleteNotification {
4357 notification_id: (*notification_id).to_proto(),
4358 },
4359 )?;
4360 }
4361
4362 Ok(())
4363 },
4364 );
4365 response.send(proto::Ack {})?;
4366 Ok(())
4367}
4368
4369async fn update_channel_message(
4370 request: proto::UpdateChannelMessage,
4371 response: Response<proto::UpdateChannelMessage>,
4372 session: UserSession,
4373) -> Result<()> {
4374 let channel_id = ChannelId::from_proto(request.channel_id);
4375 let message_id = MessageId::from_proto(request.message_id);
4376 let updated_at = OffsetDateTime::now_utc();
4377 let UpdatedChannelMessage {
4378 message_id,
4379 participant_connection_ids,
4380 notifications,
4381 reply_to_message_id,
4382 timestamp,
4383 deleted_mention_notification_ids,
4384 updated_mention_notifications,
4385 } = session
4386 .db()
4387 .await
4388 .update_channel_message(
4389 channel_id,
4390 message_id,
4391 session.user_id(),
4392 request.body.as_str(),
4393 &request.mentions,
4394 updated_at,
4395 )
4396 .await?;
4397
4398 let nonce = request
4399 .nonce
4400 .clone()
4401 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4402
4403 let message = proto::ChannelMessage {
4404 sender_id: session.user_id().to_proto(),
4405 id: message_id.to_proto(),
4406 body: request.body.clone(),
4407 mentions: request.mentions.clone(),
4408 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
4409 nonce: Some(nonce),
4410 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
4411 edited_at: Some(updated_at.unix_timestamp() as u64),
4412 };
4413
4414 response.send(proto::Ack {})?;
4415
4416 let pool = &*session.connection_pool().await;
4417 broadcast(
4418 Some(session.connection_id),
4419 participant_connection_ids,
4420 |connection| {
4421 session.peer.send(
4422 connection,
4423 proto::ChannelMessageUpdate {
4424 channel_id: channel_id.to_proto(),
4425 message: Some(message.clone()),
4426 },
4427 )?;
4428
4429 for notification_id in &deleted_mention_notification_ids {
4430 session.peer.send(
4431 connection,
4432 proto::DeleteNotification {
4433 notification_id: (*notification_id).to_proto(),
4434 },
4435 )?;
4436 }
4437
4438 for notification in &updated_mention_notifications {
4439 session.peer.send(
4440 connection,
4441 proto::UpdateNotification {
4442 notification: Some(notification.clone()),
4443 },
4444 )?;
4445 }
4446
4447 Ok(())
4448 },
4449 );
4450
4451 send_notifications(pool, &session.peer, notifications);
4452
4453 Ok(())
4454}
4455
4456/// Mark a channel message as read
4457async fn acknowledge_channel_message(
4458 request: proto::AckChannelMessage,
4459 session: UserSession,
4460) -> Result<()> {
4461 let channel_id = ChannelId::from_proto(request.channel_id);
4462 let message_id = MessageId::from_proto(request.message_id);
4463 let notifications = session
4464 .db()
4465 .await
4466 .observe_channel_message(channel_id, session.user_id(), message_id)
4467 .await?;
4468 send_notifications(
4469 &*session.connection_pool().await,
4470 &session.peer,
4471 notifications,
4472 );
4473 Ok(())
4474}
4475
4476/// Mark a buffer version as synced
4477async fn acknowledge_buffer_version(
4478 request: proto::AckBufferOperation,
4479 session: UserSession,
4480) -> Result<()> {
4481 let buffer_id = BufferId::from_proto(request.buffer_id);
4482 session
4483 .db()
4484 .await
4485 .observe_buffer_version(
4486 buffer_id,
4487 session.user_id(),
4488 request.epoch as i32,
4489 &request.version,
4490 )
4491 .await?;
4492 Ok(())
4493}
4494
4495async fn count_language_model_tokens(
4496 request: proto::CountLanguageModelTokens,
4497 response: Response<proto::CountLanguageModelTokens>,
4498 session: Session,
4499 config: &Config,
4500) -> Result<()> {
4501 let Some(session) = session.for_user() else {
4502 return Err(anyhow!("user not found"))?;
4503 };
4504 authorize_access_to_legacy_llm_endpoints(&session).await?;
4505
4506 let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
4507 proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
4508 proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
4509 };
4510
4511 session
4512 .app_state
4513 .rate_limiter
4514 .check(&*rate_limit, session.user_id())
4515 .await?;
4516
4517 let result = match proto::LanguageModelProvider::from_i32(request.provider) {
4518 Some(proto::LanguageModelProvider::Google) => {
4519 let api_key = config
4520 .google_ai_api_key
4521 .as_ref()
4522 .context("no Google AI API key configured on the server")?;
4523 google_ai::count_tokens(
4524 session.http_client.as_ref(),
4525 google_ai::API_URL,
4526 api_key,
4527 serde_json::from_str(&request.request)?,
4528 )
4529 .await?
4530 }
4531 _ => return Err(anyhow!("unsupported provider"))?,
4532 };
4533
4534 response.send(proto::CountLanguageModelTokensResponse {
4535 token_count: result.total_tokens as u32,
4536 })?;
4537
4538 Ok(())
4539}
4540
4541struct ZedProCountLanguageModelTokensRateLimit;
4542
4543impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
4544 fn capacity(&self) -> usize {
4545 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
4546 .ok()
4547 .and_then(|v| v.parse().ok())
4548 .unwrap_or(600) // Picked arbitrarily
4549 }
4550
4551 fn refill_duration(&self) -> chrono::Duration {
4552 chrono::Duration::hours(1)
4553 }
4554
4555 fn db_name(&self) -> &'static str {
4556 "zed-pro:count-language-model-tokens"
4557 }
4558}
4559
4560struct FreeCountLanguageModelTokensRateLimit;
4561
4562impl RateLimit for FreeCountLanguageModelTokensRateLimit {
4563 fn capacity(&self) -> usize {
4564 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
4565 .ok()
4566 .and_then(|v| v.parse().ok())
4567 .unwrap_or(600 / 10) // Picked arbitrarily
4568 }
4569
4570 fn refill_duration(&self) -> chrono::Duration {
4571 chrono::Duration::hours(1)
4572 }
4573
4574 fn db_name(&self) -> &'static str {
4575 "free:count-language-model-tokens"
4576 }
4577}
4578
4579struct ZedProComputeEmbeddingsRateLimit;
4580
4581impl RateLimit for ZedProComputeEmbeddingsRateLimit {
4582 fn capacity(&self) -> usize {
4583 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4584 .ok()
4585 .and_then(|v| v.parse().ok())
4586 .unwrap_or(5000) // Picked arbitrarily
4587 }
4588
4589 fn refill_duration(&self) -> chrono::Duration {
4590 chrono::Duration::hours(1)
4591 }
4592
4593 fn db_name(&self) -> &'static str {
4594 "zed-pro:compute-embeddings"
4595 }
4596}
4597
4598struct FreeComputeEmbeddingsRateLimit;
4599
4600impl RateLimit for FreeComputeEmbeddingsRateLimit {
4601 fn capacity(&self) -> usize {
4602 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
4603 .ok()
4604 .and_then(|v| v.parse().ok())
4605 .unwrap_or(5000 / 10) // Picked arbitrarily
4606 }
4607
4608 fn refill_duration(&self) -> chrono::Duration {
4609 chrono::Duration::hours(1)
4610 }
4611
4612 fn db_name(&self) -> &'static str {
4613 "free:compute-embeddings"
4614 }
4615}
4616
4617async fn compute_embeddings(
4618 request: proto::ComputeEmbeddings,
4619 response: Response<proto::ComputeEmbeddings>,
4620 session: UserSession,
4621 api_key: Option<Arc<str>>,
4622) -> Result<()> {
4623 let api_key = api_key.context("no OpenAI API key configured on the server")?;
4624 authorize_access_to_legacy_llm_endpoints(&session).await?;
4625
4626 let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
4627 proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
4628 proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
4629 };
4630
4631 session
4632 .app_state
4633 .rate_limiter
4634 .check(&*rate_limit, session.user_id())
4635 .await?;
4636
4637 let embeddings = match request.model.as_str() {
4638 "openai/text-embedding-3-small" => {
4639 open_ai::embed(
4640 session.http_client.as_ref(),
4641 OPEN_AI_API_URL,
4642 &api_key,
4643 OpenAiEmbeddingModel::TextEmbedding3Small,
4644 request.texts.iter().map(|text| text.as_str()),
4645 )
4646 .await?
4647 }
4648 provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
4649 };
4650
4651 let embeddings = request
4652 .texts
4653 .iter()
4654 .map(|text| {
4655 let mut hasher = sha2::Sha256::new();
4656 hasher.update(text.as_bytes());
4657 let result = hasher.finalize();
4658 result.to_vec()
4659 })
4660 .zip(
4661 embeddings
4662 .data
4663 .into_iter()
4664 .map(|embedding| embedding.embedding),
4665 )
4666 .collect::<HashMap<_, _>>();
4667
4668 let db = session.db().await;
4669 db.save_embeddings(&request.model, &embeddings)
4670 .await
4671 .context("failed to save embeddings")
4672 .trace_err();
4673
4674 response.send(proto::ComputeEmbeddingsResponse {
4675 embeddings: embeddings
4676 .into_iter()
4677 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4678 .collect(),
4679 })?;
4680 Ok(())
4681}
4682
4683async fn get_cached_embeddings(
4684 request: proto::GetCachedEmbeddings,
4685 response: Response<proto::GetCachedEmbeddings>,
4686 session: UserSession,
4687) -> Result<()> {
4688 authorize_access_to_legacy_llm_endpoints(&session).await?;
4689
4690 let db = session.db().await;
4691 let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
4692
4693 response.send(proto::GetCachedEmbeddingsResponse {
4694 embeddings: embeddings
4695 .into_iter()
4696 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4697 .collect(),
4698 })?;
4699 Ok(())
4700}
4701
4702/// This is leftover from before the LLM service.
4703///
4704/// The endpoints protected by this check will be moved there eventually.
4705async fn authorize_access_to_legacy_llm_endpoints(session: &UserSession) -> Result<(), Error> {
4706 if session.is_staff() {
4707 Ok(())
4708 } else {
4709 Err(anyhow!("permission denied"))?
4710 }
4711}
4712
4713/// Get a Supermaven API key for the user
4714async fn get_supermaven_api_key(
4715 _request: proto::GetSupermavenApiKey,
4716 response: Response<proto::GetSupermavenApiKey>,
4717 session: UserSession,
4718) -> Result<()> {
4719 let user_id: String = session.user_id().to_string();
4720 if !session.is_staff() {
4721 return Err(anyhow!("supermaven not enabled for this account"))?;
4722 }
4723
4724 let email = session
4725 .email()
4726 .ok_or_else(|| anyhow!("user must have an email"))?;
4727
4728 let supermaven_admin_api = session
4729 .supermaven_client
4730 .as_ref()
4731 .ok_or_else(|| anyhow!("supermaven not configured"))?;
4732
4733 let result = supermaven_admin_api
4734 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4735 .await?;
4736
4737 response.send(proto::GetSupermavenApiKeyResponse {
4738 api_key: result.api_key,
4739 })?;
4740
4741 Ok(())
4742}
4743
4744/// Start receiving chat updates for a channel
4745async fn join_channel_chat(
4746 request: proto::JoinChannelChat,
4747 response: Response<proto::JoinChannelChat>,
4748 session: UserSession,
4749) -> Result<()> {
4750 let channel_id = ChannelId::from_proto(request.channel_id);
4751
4752 let db = session.db().await;
4753 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4754 .await?;
4755 let messages = db
4756 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4757 .await?;
4758 response.send(proto::JoinChannelChatResponse {
4759 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4760 messages,
4761 })?;
4762 Ok(())
4763}
4764
4765/// Stop receiving chat updates for a channel
4766async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
4767 let channel_id = ChannelId::from_proto(request.channel_id);
4768 session
4769 .db()
4770 .await
4771 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4772 .await?;
4773 Ok(())
4774}
4775
4776/// Retrieve the chat history for a channel
4777async fn get_channel_messages(
4778 request: proto::GetChannelMessages,
4779 response: Response<proto::GetChannelMessages>,
4780 session: UserSession,
4781) -> Result<()> {
4782 let channel_id = ChannelId::from_proto(request.channel_id);
4783 let messages = session
4784 .db()
4785 .await
4786 .get_channel_messages(
4787 channel_id,
4788 session.user_id(),
4789 MESSAGE_COUNT_PER_PAGE,
4790 Some(MessageId::from_proto(request.before_message_id)),
4791 )
4792 .await?;
4793 response.send(proto::GetChannelMessagesResponse {
4794 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4795 messages,
4796 })?;
4797 Ok(())
4798}
4799
4800/// Retrieve specific chat messages
4801async fn get_channel_messages_by_id(
4802 request: proto::GetChannelMessagesById,
4803 response: Response<proto::GetChannelMessagesById>,
4804 session: UserSession,
4805) -> Result<()> {
4806 let message_ids = request
4807 .message_ids
4808 .iter()
4809 .map(|id| MessageId::from_proto(*id))
4810 .collect::<Vec<_>>();
4811 let messages = session
4812 .db()
4813 .await
4814 .get_channel_messages_by_id(session.user_id(), &message_ids)
4815 .await?;
4816 response.send(proto::GetChannelMessagesResponse {
4817 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4818 messages,
4819 })?;
4820 Ok(())
4821}
4822
4823/// Retrieve the current users notifications
4824async fn get_notifications(
4825 request: proto::GetNotifications,
4826 response: Response<proto::GetNotifications>,
4827 session: UserSession,
4828) -> Result<()> {
4829 let notifications = session
4830 .db()
4831 .await
4832 .get_notifications(
4833 session.user_id(),
4834 NOTIFICATION_COUNT_PER_PAGE,
4835 request
4836 .before_id
4837 .map(|id| db::NotificationId::from_proto(id)),
4838 )
4839 .await?;
4840 response.send(proto::GetNotificationsResponse {
4841 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4842 notifications,
4843 })?;
4844 Ok(())
4845}
4846
4847/// Mark notifications as read
4848async fn mark_notification_as_read(
4849 request: proto::MarkNotificationRead,
4850 response: Response<proto::MarkNotificationRead>,
4851 session: UserSession,
4852) -> Result<()> {
4853 let database = &session.db().await;
4854 let notifications = database
4855 .mark_notification_as_read_by_id(
4856 session.user_id(),
4857 NotificationId::from_proto(request.notification_id),
4858 )
4859 .await?;
4860 send_notifications(
4861 &*session.connection_pool().await,
4862 &session.peer,
4863 notifications,
4864 );
4865 response.send(proto::Ack {})?;
4866 Ok(())
4867}
4868
4869/// Get the current users information
4870async fn get_private_user_info(
4871 _request: proto::GetPrivateUserInfo,
4872 response: Response<proto::GetPrivateUserInfo>,
4873 session: UserSession,
4874) -> Result<()> {
4875 let db = session.db().await;
4876
4877 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4878 let user = db
4879 .get_user_by_id(session.user_id())
4880 .await?
4881 .ok_or_else(|| anyhow!("user not found"))?;
4882 let flags = db.get_user_flags(session.user_id()).await?;
4883
4884 response.send(proto::GetPrivateUserInfoResponse {
4885 metrics_id,
4886 staff: user.admin,
4887 flags,
4888 accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4889 })?;
4890 Ok(())
4891}
4892
4893/// Accept the terms of service (tos) on behalf of the current user
4894async fn accept_terms_of_service(
4895 _request: proto::AcceptTermsOfService,
4896 response: Response<proto::AcceptTermsOfService>,
4897 session: UserSession,
4898) -> Result<()> {
4899 let db = session.db().await;
4900
4901 let accepted_tos_at = Utc::now();
4902 db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4903 .await?;
4904
4905 response.send(proto::AcceptTermsOfServiceResponse {
4906 accepted_tos_at: accepted_tos_at.timestamp() as u64,
4907 })?;
4908 Ok(())
4909}
4910
4911/// The minimum account age an account must have in order to use the LLM service.
4912const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4913
4914async fn get_llm_api_token(
4915 _request: proto::GetLlmToken,
4916 response: Response<proto::GetLlmToken>,
4917 session: UserSession,
4918) -> Result<()> {
4919 let db = session.db().await;
4920
4921 let flags = db.get_user_flags(session.user_id()).await?;
4922 if !session.is_staff() && !flags.iter().any(|flag| flag == "language-models") {
4923 Err(anyhow!("permission denied"))?
4924 }
4925
4926 let user_id = session.user_id();
4927 let user = db
4928 .get_user_by_id(user_id)
4929 .await?
4930 .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4931
4932 if user.accepted_tos_at.is_none() {
4933 Err(anyhow!("terms of service not accepted"))?
4934 }
4935
4936 let mut account_created_at = user.created_at;
4937 if let Some(github_created_at) = user.github_user_created_at {
4938 account_created_at = account_created_at.min(github_created_at);
4939 }
4940 if Utc::now().naive_utc() - account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE {
4941 Err(anyhow!("account too young"))?
4942 }
4943
4944 let token = LlmTokenClaims::create(
4945 user.id,
4946 session.is_staff(),
4947 session.current_plan().await?,
4948 &session.app_state.config,
4949 )?;
4950 response.send(proto::GetLlmTokenResponse { token })?;
4951 Ok(())
4952}
4953
4954fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4955 let message = match message {
4956 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4957 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4958 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4959 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4960 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4961 code: frame.code.into(),
4962 reason: frame.reason,
4963 })),
4964 // We should never receive a frame while reading the message, according
4965 // to the `tungstenite` maintainers:
4966 //
4967 // > It cannot occur when you read messages from the WebSocket, but it
4968 // > can be used when you want to send the raw frames (e.g. you want to
4969 // > send the frames to the WebSocket without composing the full message first).
4970 // >
4971 // > — https://github.com/snapview/tungstenite-rs/issues/268
4972 TungsteniteMessage::Frame(_) => {
4973 bail!("received an unexpected frame while reading the message")
4974 }
4975 };
4976
4977 Ok(message)
4978}
4979
4980fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4981 match message {
4982 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4983 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4984 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4985 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4986 AxumMessage::Close(frame) => {
4987 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4988 code: frame.code.into(),
4989 reason: frame.reason,
4990 }))
4991 }
4992 }
4993}
4994
4995fn notify_membership_updated(
4996 connection_pool: &mut ConnectionPool,
4997 result: MembershipUpdated,
4998 user_id: UserId,
4999 peer: &Peer,
5000) {
5001 for membership in &result.new_channels.channel_memberships {
5002 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
5003 }
5004 for channel_id in &result.removed_channels {
5005 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
5006 }
5007
5008 let user_channels_update = proto::UpdateUserChannels {
5009 channel_memberships: result
5010 .new_channels
5011 .channel_memberships
5012 .iter()
5013 .map(|cm| proto::ChannelMembership {
5014 channel_id: cm.channel_id.to_proto(),
5015 role: cm.role.into(),
5016 })
5017 .collect(),
5018 ..Default::default()
5019 };
5020
5021 let mut update = build_channels_update(result.new_channels);
5022 update.delete_channels = result
5023 .removed_channels
5024 .into_iter()
5025 .map(|id| id.to_proto())
5026 .collect();
5027 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
5028
5029 for connection_id in connection_pool.user_connection_ids(user_id) {
5030 peer.send(connection_id, user_channels_update.clone())
5031 .trace_err();
5032 peer.send(connection_id, update.clone()).trace_err();
5033 }
5034}
5035
5036fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
5037 proto::UpdateUserChannels {
5038 channel_memberships: channels
5039 .channel_memberships
5040 .iter()
5041 .map(|m| proto::ChannelMembership {
5042 channel_id: m.channel_id.to_proto(),
5043 role: m.role.into(),
5044 })
5045 .collect(),
5046 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
5047 observed_channel_message_id: channels.observed_channel_messages.clone(),
5048 }
5049}
5050
5051fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
5052 let mut update = proto::UpdateChannels::default();
5053
5054 for channel in channels.channels {
5055 update.channels.push(channel.to_proto());
5056 }
5057
5058 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
5059 update.latest_channel_message_ids = channels.latest_channel_messages;
5060
5061 for (channel_id, participants) in channels.channel_participants {
5062 update
5063 .channel_participants
5064 .push(proto::ChannelParticipants {
5065 channel_id: channel_id.to_proto(),
5066 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
5067 });
5068 }
5069
5070 for channel in channels.invited_channels {
5071 update.channel_invitations.push(channel.to_proto());
5072 }
5073
5074 update.hosted_projects = channels.hosted_projects;
5075 update
5076}
5077
5078fn build_initial_contacts_update(
5079 contacts: Vec<db::Contact>,
5080 pool: &ConnectionPool,
5081) -> proto::UpdateContacts {
5082 let mut update = proto::UpdateContacts::default();
5083
5084 for contact in contacts {
5085 match contact {
5086 db::Contact::Accepted { user_id, busy } => {
5087 update.contacts.push(contact_for_user(user_id, busy, &pool));
5088 }
5089 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
5090 db::Contact::Incoming { user_id } => {
5091 update
5092 .incoming_requests
5093 .push(proto::IncomingContactRequest {
5094 requester_id: user_id.to_proto(),
5095 })
5096 }
5097 }
5098 }
5099
5100 update
5101}
5102
5103fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
5104 proto::Contact {
5105 user_id: user_id.to_proto(),
5106 online: pool.is_user_online(user_id),
5107 busy,
5108 }
5109}
5110
5111fn room_updated(room: &proto::Room, peer: &Peer) {
5112 broadcast(
5113 None,
5114 room.participants
5115 .iter()
5116 .filter_map(|participant| Some(participant.peer_id?.into())),
5117 |peer_id| {
5118 peer.send(
5119 peer_id,
5120 proto::RoomUpdated {
5121 room: Some(room.clone()),
5122 },
5123 )
5124 },
5125 );
5126}
5127
5128fn channel_updated(
5129 channel: &db::channel::Model,
5130 room: &proto::Room,
5131 peer: &Peer,
5132 pool: &ConnectionPool,
5133) {
5134 let participants = room
5135 .participants
5136 .iter()
5137 .map(|p| p.user_id)
5138 .collect::<Vec<_>>();
5139
5140 broadcast(
5141 None,
5142 pool.channel_connection_ids(channel.root_id())
5143 .filter_map(|(channel_id, role)| {
5144 role.can_see_channel(channel.visibility).then(|| channel_id)
5145 }),
5146 |peer_id| {
5147 peer.send(
5148 peer_id,
5149 proto::UpdateChannels {
5150 channel_participants: vec![proto::ChannelParticipants {
5151 channel_id: channel.id.to_proto(),
5152 participant_user_ids: participants.clone(),
5153 }],
5154 ..Default::default()
5155 },
5156 )
5157 },
5158 );
5159}
5160
5161async fn send_dev_server_projects_update(
5162 user_id: UserId,
5163 mut status: proto::DevServerProjectsUpdate,
5164 session: &Session,
5165) {
5166 let pool = session.connection_pool().await;
5167 for dev_server in &mut status.dev_servers {
5168 dev_server.status =
5169 pool.dev_server_status(DevServerId(dev_server.dev_server_id as i32)) as i32;
5170 }
5171 let connections = pool.user_connection_ids(user_id);
5172 for connection_id in connections {
5173 session.peer.send(connection_id, status.clone()).trace_err();
5174 }
5175}
5176
5177async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
5178 let db = session.db().await;
5179
5180 let contacts = db.get_contacts(user_id).await?;
5181 let busy = db.is_user_busy(user_id).await?;
5182
5183 let pool = session.connection_pool().await;
5184 let updated_contact = contact_for_user(user_id, busy, &pool);
5185 for contact in contacts {
5186 if let db::Contact::Accepted {
5187 user_id: contact_user_id,
5188 ..
5189 } = contact
5190 {
5191 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
5192 session
5193 .peer
5194 .send(
5195 contact_conn_id,
5196 proto::UpdateContacts {
5197 contacts: vec![updated_contact.clone()],
5198 remove_contacts: Default::default(),
5199 incoming_requests: Default::default(),
5200 remove_incoming_requests: Default::default(),
5201 outgoing_requests: Default::default(),
5202 remove_outgoing_requests: Default::default(),
5203 },
5204 )
5205 .trace_err();
5206 }
5207 }
5208 }
5209 Ok(())
5210}
5211
5212async fn lost_dev_server_connection(session: &DevServerSession) -> Result<()> {
5213 log::info!("lost dev server connection, unsharing projects");
5214 let project_ids = session
5215 .db()
5216 .await
5217 .get_stale_dev_server_projects(session.connection_id)
5218 .await?;
5219
5220 for project_id in project_ids {
5221 // not unshare re-checks the connection ids match, so we get away with no transaction
5222 unshare_project_internal(project_id, session.connection_id, None, &session).await?;
5223 }
5224
5225 let user_id = session.dev_server().user_id;
5226 let update = session
5227 .db()
5228 .await
5229 .dev_server_projects_update(user_id)
5230 .await?;
5231
5232 send_dev_server_projects_update(user_id, update, session).await;
5233
5234 Ok(())
5235}
5236
/// Remove `connection_id` from the room it is in and propagate the
/// consequences.
///
/// If `leave_room` returns `None`, this returns `Ok(())` without doing
/// anything. Otherwise, in order, it:
/// - notifies the other connections of each left project (`project_left`),
/// - broadcasts the updated room (`room_updated`) and, if the room belongs to
///   a channel, that channel's participant list (`channel_updated`),
/// - sends `CallCanceled` to every connection of each user in
///   `canceled_calls_to_user_ids`,
/// - refreshes contact entries for this user and for those users,
/// - if a LiveKit client is configured, removes the user from the LiveKit
///   room, and also deletes that room when `left_room.deleted` is set.
///
/// LiveKit and send failures are logged and ignored.
async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Values extracted from the `left_room` result inside the `if let` below.
    // NOTE(review): the temporary from `session.db().await` in that `if let`
    // scrutinee lives until the end of the `if let` statement. Pulling these
    // values out, instead of keeping `left_room` alive, appears intended to
    // drop it before `update_user_contacts` calls `session.db()` again.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // `live_kit_room` is taken before `room` is taken, so the room
        // broadcast below carries an empty `live_kit_room`.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the pool guard is dropped before `update_user_contacts`, which
    // acquires the connection pool itself.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
5310
5311async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
5312 let left_channel_buffers = session
5313 .db()
5314 .await
5315 .leave_channel_buffers(session.connection_id)
5316 .await?;
5317
5318 for left_buffer in left_channel_buffers {
5319 channel_buffer_updated(
5320 session.connection_id,
5321 left_buffer.connections,
5322 &proto::UpdateChannelBufferCollaborators {
5323 channel_id: left_buffer.channel_id.to_proto(),
5324 collaborators: left_buffer.collaborators,
5325 },
5326 &session.peer,
5327 );
5328 }
5329
5330 Ok(())
5331}
5332
5333fn project_left(project: &db::LeftProject, session: &UserSession) {
5334 for connection_id in &project.connection_ids {
5335 if project.should_unshare {
5336 session
5337 .peer
5338 .send(
5339 *connection_id,
5340 proto::UnshareProject {
5341 project_id: project.id.to_proto(),
5342 },
5343 )
5344 .trace_err();
5345 } else {
5346 session
5347 .peer
5348 .send(
5349 *connection_id,
5350 proto::RemoveProjectCollaborator {
5351 project_id: project.id.to_proto(),
5352 peer_id: Some(session.connection_id.into()),
5353 },
5354 )
5355 .trace_err();
5356 }
5357 }
5358}
5359
/// Extension for call sites that treat a failure as non-fatal: report the
/// error and carry on with an `Option`.
pub trait ResultExt {
    /// The value produced on success.
    type Ok;

    /// Consume `self`, returning the success value if there is one. Otherwise
    /// the implementation reports the error and returns `None`.
    fn trace_err(self) -> Option<Self::Ok>;
}
5365
5366impl<T, E> ResultExt for Result<T, E>
5367where
5368 E: std::fmt::Debug,
5369{
5370 type Ok = T;
5371
5372 #[track_caller]
5373 fn trace_err(self) -> Option<T> {
5374 match self {
5375 Ok(value) => Some(value),
5376 Err(error) => {
5377 tracing::error!("{:?}", error);
5378 None
5379 }
5380 }
5381 }
5382}