1mod connection_pool;
2
3use crate::api::CloudflareIpCountryHeader;
4use crate::llm::LlmTokenClaims;
5use crate::{
6 auth,
7 db::{
8 self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
9 CreatedChannelMessage, Database, DevServerId, DevServerProjectId, InviteMemberResult,
10 MembershipUpdated, MessageId, NotificationId, PrincipalId, Project, ProjectId,
11 RejoinedProject, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
12 ServerId, UpdatedChannelMessage, User, UserId,
13 },
14 executor::Executor,
15 AppState, Config, Error, RateLimit, Result,
16};
17use anyhow::{anyhow, bail, Context as _};
18use async_tungstenite::tungstenite::{
19 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
20};
21use axum::{
22 body::Body,
23 extract::{
24 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
25 ConnectInfo, WebSocketUpgrade,
26 },
27 headers::{Header, HeaderName},
28 http::StatusCode,
29 middleware,
30 response::IntoResponse,
31 routing::get,
32 Extension, Router, TypedHeader,
33};
34use chrono::Utc;
35use collections::{HashMap, HashSet};
36pub use connection_pool::{ConnectionPool, ZedVersion};
37use core::fmt::{self, Debug, Formatter};
38use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
39use sha2::Digest;
40use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
41
42use futures::{
43 channel::oneshot,
44 future::{self, BoxFuture},
45 stream::FuturesUnordered,
46 FutureExt, SinkExt, StreamExt, TryStreamExt,
47};
48use http_client::IsahcHttpClient;
49use prometheus::{register_int_gauge, IntGauge};
50use rpc::{
51 proto::{
52 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
53 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
54 },
55 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
56};
57use semantic_version::SemanticVersion;
58use serde::{Serialize, Serializer};
59use std::{
60 any::TypeId,
61 future::Future,
62 marker::PhantomData,
63 mem,
64 net::SocketAddr,
65 ops::{Deref, DerefMut},
66 rc::Rc,
67 sync::{
68 atomic::{AtomicBool, Ordering::SeqCst},
69 Arc, OnceLock,
70 },
71 time::{Duration, Instant},
72};
73use time::OffsetDateTime;
74use tokio::sync::{watch, MutexGuard, Semaphore};
75use tower::ServiceBuilder;
76use tracing::{
77 field::{self},
78 info_span, instrument, Instrument,
79};
80
81use self::connection_pool::VersionedMessage;
82
/// Grace period associated with client reconnection.
/// NOTE(review): its consumer is not visible in this chunk — confirm exact semantics.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can
// clean up their old resources, so we wait a little longer than that.
/// Delay `Server::start` waits before cleaning up resources left behind by stale servers.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

/// Page size for channel message history (presumably; usage not visible in this chunk).
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Maximum channel message length (presumably; usage not visible in this chunk — confirm units).
const MAX_MESSAGE_LEN: usize = 1024;
/// Page size for notification queries (presumably; usage not visible in this chunk).
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;

/// Type-erased message handler stored in `Server::handlers`, keyed by the message's `TypeId`.
/// It receives the boxed envelope plus the connection's `Session` and returns a boxed future.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
94
95struct Response<R> {
96 peer: Arc<Peer>,
97 receipt: Receipt<R>,
98 responded: Arc<AtomicBool>,
99}
100
101impl<R: RequestMessage> Response<R> {
102 fn send(self, payload: R::Response) -> Result<()> {
103 self.responded.store(true, SeqCst);
104 self.peer.respond(self.receipt, payload)?;
105 Ok(())
106 }
107}
108
109#[derive(Clone, Debug)]
110pub enum Principal {
111 User(User),
112 Impersonated { user: User, admin: User },
113 DevServer(dev_server::Model),
114}
115
116impl Principal {
117 fn update_span(&self, span: &tracing::Span) {
118 match &self {
119 Principal::User(user) => {
120 span.record("user_id", &user.id.0);
121 span.record("login", &user.github_login);
122 }
123 Principal::Impersonated { user, admin } => {
124 span.record("user_id", &user.id.0);
125 span.record("login", &user.github_login);
126 span.record("impersonator", &admin.github_login);
127 }
128 Principal::DevServer(dev_server) => {
129 span.record("dev_server_id", &dev_server.id.0);
130 }
131 }
132 }
133}
134
/// Per-connection state handed to every message handler.
///
/// Cloning a `Session` shares all of the `Arc`-wrapped state (database handle,
/// peer, connection pool, app state, HTTP clients) with the original.
#[derive(Clone)]
struct Session {
    /// Who is on the other end of this connection.
    principal: Principal,
    /// The peer-assigned id of this connection.
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; acquire it through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Shared with the owning `Server`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Present only when a Supermaven admin API key is configured.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    http_client: Arc<IsahcHttpClient>,
    /// The GeoIP country code for the user.
    // Not read yet, hence the `allow`.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    _executor: Executor,
}
150
151impl Session {
152 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
153 #[cfg(test)]
154 tokio::task::yield_now().await;
155 let guard = self.db.lock().await;
156 #[cfg(test)]
157 tokio::task::yield_now().await;
158 guard
159 }
160
161 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
162 #[cfg(test)]
163 tokio::task::yield_now().await;
164 let guard = self.connection_pool.lock();
165 ConnectionPoolGuard {
166 guard,
167 _not_send: PhantomData,
168 }
169 }
170
171 fn for_user(self) -> Option<UserSession> {
172 UserSession::new(self)
173 }
174
175 fn for_dev_server(self) -> Option<DevServerSession> {
176 DevServerSession::new(self)
177 }
178
179 fn user_id(&self) -> Option<UserId> {
180 match &self.principal {
181 Principal::User(user) => Some(user.id),
182 Principal::Impersonated { user, .. } => Some(user.id),
183 Principal::DevServer(_) => None,
184 }
185 }
186
187 fn is_staff(&self) -> bool {
188 match &self.principal {
189 Principal::User(user) => user.admin,
190 Principal::Impersonated { .. } => true,
191 Principal::DevServer(_) => false,
192 }
193 }
194
195 pub async fn current_plan(&self, db: MutexGuard<'_, DbHandle>) -> anyhow::Result<proto::Plan> {
196 if self.is_staff() {
197 return Ok(proto::Plan::ZedPro);
198 }
199
200 let Some(user_id) = self.user_id() else {
201 return Ok(proto::Plan::Free);
202 };
203
204 if db.has_active_billing_subscription(user_id).await? {
205 Ok(proto::Plan::ZedPro)
206 } else {
207 Ok(proto::Plan::Free)
208 }
209 }
210
211 fn dev_server_id(&self) -> Option<DevServerId> {
212 match &self.principal {
213 Principal::User(_) | Principal::Impersonated { .. } => None,
214 Principal::DevServer(dev_server) => Some(dev_server.id),
215 }
216 }
217
218 fn principal_id(&self) -> PrincipalId {
219 match &self.principal {
220 Principal::User(user) => PrincipalId::UserId(user.id),
221 Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
222 Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
223 }
224 }
225}
226
227impl Debug for Session {
228 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
229 let mut result = f.debug_struct("Session");
230 match &self.principal {
231 Principal::User(user) => {
232 result.field("user", &user.github_login);
233 }
234 Principal::Impersonated { user, admin } => {
235 result.field("user", &user.github_login);
236 result.field("impersonator", &admin.github_login);
237 }
238 Principal::DevServer(dev_server) => {
239 result.field("dev_server", &dev_server.id);
240 }
241 }
242 result.field("connection_id", &self.connection_id).finish()
243 }
244}
245
246struct UserSession(Session);
247
248impl UserSession {
249 pub fn new(s: Session) -> Option<Self> {
250 s.user_id().map(|_| UserSession(s))
251 }
252 pub fn user_id(&self) -> UserId {
253 self.0.user_id().unwrap()
254 }
255
256 pub fn email(&self) -> Option<String> {
257 match &self.0.principal {
258 Principal::User(user) => user.email_address.clone(),
259 Principal::Impersonated { user, .. } => user.email_address.clone(),
260 Principal::DevServer(..) => None,
261 }
262 }
263}
264
265impl Deref for UserSession {
266 type Target = Session;
267
268 fn deref(&self) -> &Self::Target {
269 &self.0
270 }
271}
272impl DerefMut for UserSession {
273 fn deref_mut(&mut self) -> &mut Self::Target {
274 &mut self.0
275 }
276}
277
278struct DevServerSession(Session);
279
280impl DevServerSession {
281 pub fn new(s: Session) -> Option<Self> {
282 s.dev_server_id().map(|_| DevServerSession(s))
283 }
284 pub fn dev_server_id(&self) -> DevServerId {
285 self.0.dev_server_id().unwrap()
286 }
287
288 fn dev_server(&self) -> &dev_server::Model {
289 match &self.0.principal {
290 Principal::DevServer(dev_server) => dev_server,
291 _ => unreachable!(),
292 }
293 }
294}
295
296impl Deref for DevServerSession {
297 type Target = Session;
298
299 fn deref(&self) -> &Self::Target {
300 &self.0
301 }
302}
303impl DerefMut for DevServerSession {
304 fn deref_mut(&mut self) -> &mut Self::Target {
305 &mut self.0
306 }
307}
308
309fn user_handler<M: RequestMessage, Fut>(
310 handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
311) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
312where
313 Fut: Send + Future<Output = Result<()>>,
314{
315 let handler = Arc::new(handler);
316 move |message, response, session| {
317 let handler = handler.clone();
318 Box::pin(async move {
319 if let Some(user_session) = session.for_user() {
320 Ok(handler(message, response, user_session).await?)
321 } else {
322 Err(Error::Internal(anyhow!(
323 "must be a user to call {}",
324 M::NAME
325 )))
326 }
327 })
328 }
329}
330
331fn dev_server_handler<M: RequestMessage, Fut>(
332 handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
333) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
334where
335 Fut: Send + Future<Output = Result<()>>,
336{
337 let handler = Arc::new(handler);
338 move |message, response, session| {
339 let handler = handler.clone();
340 Box::pin(async move {
341 if let Some(dev_server_session) = session.for_dev_server() {
342 Ok(handler(message, response, dev_server_session).await?)
343 } else {
344 Err(Error::Internal(anyhow!(
345 "must be a dev server to call {}",
346 M::NAME
347 )))
348 }
349 })
350 }
351}
352
353fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
354 handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
355) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
356where
357 InnertRetFut: Send + Future<Output = Result<()>>,
358{
359 let handler = Arc::new(handler);
360 move |message, session| {
361 let handler = handler.clone();
362 Box::pin(async move {
363 if let Some(user_session) = session.for_user() {
364 Ok(handler(message, user_session).await?)
365 } else {
366 Err(Error::Internal(anyhow!(
367 "must be a user to call {}",
368 M::NAME
369 )))
370 }
371 })
372 }
373}
374
375struct DbHandle(Arc<Database>);
376
377impl Deref for DbHandle {
378 type Target = Database;
379
380 fn deref(&self) -> &Self::Target {
381 self.0.as_ref()
382 }
383}
384
/// The collaboration RPC server.
///
/// Owns the peer, the pool of live connections, and the table of message handlers
/// registered in `Server::new`.
pub struct Server {
    /// This server's id; behind a mutex so the test-only `reset` can replace it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Handlers keyed by the `TypeId` of the message type they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Set to `true` on teardown; connection tasks subscribe to it and stop when it changes.
    teardown: watch::Sender<bool>,
}

/// Lock guard for the connection pool that is deliberately `!Send`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    // `Rc` is `!Send`, so this marker makes the guard `!Send`. Holding it across an
    // `.await` inside a `Send` future (such as a boxed handler) therefore fails to compile.
    _not_send: PhantomData<Rc<()>>,
}

/// Serializable snapshot of the server's peer and connection pool.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    /// Serialized through its `Deref` target (presumably the `ConnectionPool`).
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
405
406pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
407where
408 S: Serializer,
409 T: Deref<Target = U>,
410 U: Serialize,
411{
412 Serialize::serialize(value.deref(), serializer)
413}
414
415impl Server {
    /// Creates a server with the given id and registers every RPC message handler.
    ///
    /// Handlers wrapped in `user_handler` / `user_message_handler` reject sessions whose
    /// principal is not a user. Handlers wrapped in `dev_server_handler` reject sessions
    /// that are not dev servers. Unwrapped handlers receive the raw `Session`.
    ///
    /// # Panics
    ///
    /// Panics if two handlers are registered for the same message type.
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            // NOTE(review): `as u32` silently truncates if `id.0` does not fit in a u32.
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            // Only the sender is kept; each connection creates its own receiver via `subscribe`.
            teardown: watch::channel(false).0,
        };

        server
            .add_request_handler(ping)
            // Rooms and calls.
            .add_request_handler(user_handler(create_room))
            .add_request_handler(user_handler(join_room))
            .add_request_handler(user_handler(rejoin_room))
            .add_request_handler(user_handler(leave_room))
            .add_request_handler(user_handler(set_room_participant_role))
            .add_request_handler(user_handler(call))
            .add_request_handler(user_handler(cancel_call))
            .add_message_handler(user_message_handler(decline_call))
            .add_request_handler(user_handler(update_participant_location))
            // Projects and dev servers.
            .add_request_handler(user_handler(share_project))
            .add_message_handler(unshare_project)
            .add_request_handler(user_handler(join_project))
            .add_request_handler(user_handler(join_hosted_project))
            .add_request_handler(user_handler(rejoin_dev_server_projects))
            .add_request_handler(user_handler(create_dev_server_project))
            .add_request_handler(user_handler(update_dev_server_project))
            .add_request_handler(user_handler(delete_dev_server_project))
            .add_request_handler(user_handler(create_dev_server))
            .add_request_handler(user_handler(regenerate_dev_server_token))
            .add_request_handler(user_handler(rename_dev_server))
            .add_request_handler(user_handler(delete_dev_server))
            .add_request_handler(user_handler(list_remote_directory))
            .add_request_handler(dev_server_handler(share_dev_server_project))
            .add_request_handler(dev_server_handler(shutdown_dev_server))
            .add_request_handler(dev_server_handler(reconnect_dev_server))
            .add_message_handler(user_message_handler(leave_project))
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Project requests forwarded to another peer (owner/host).
            .add_request_handler(user_handler(
                forward_project_request_for_owner::<proto::TaskContextForLocation>,
            ))
            .add_request_handler(user_handler(
                forward_project_request_for_owner::<proto::TaskTemplates>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetHover>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetTypeDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetReferences>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::SearchProject>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDocumentHighlights>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetProjectSymbols>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferForSymbol>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferById>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::SynchronizeBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::InlayHints>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferByPath>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCompletions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            ))
            .add_request_handler(user_handler(
                forward_versioned_mutating_project_request::<proto::OpenNewBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCodeActions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCodeAction>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PrepareRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PerformRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ReloadBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::FormatBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CreateProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RenameProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CopyProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::DeleteProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ExpandProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OnTypeFormatting>,
            ))
            .add_request_handler(user_handler(
                forward_versioned_mutating_project_request::<proto::SaveBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::BlameBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::MultiLspQuery>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RestartLanguageServers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::LinkedEditingRange>,
            ))
            // Buffer updates broadcast from the project host.
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            // Users and contacts.
            .add_request_handler(get_users)
            .add_request_handler(user_handler(fuzzy_search_users))
            .add_request_handler(user_handler(request_contact))
            .add_request_handler(user_handler(remove_contact))
            .add_request_handler(user_handler(respond_to_contact_request))
            // Channels, channel buffers and channel chat.
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(user_handler(create_channel))
            .add_request_handler(user_handler(delete_channel))
            .add_request_handler(user_handler(invite_channel_member))
            .add_request_handler(user_handler(remove_channel_member))
            .add_request_handler(user_handler(set_channel_member_role))
            .add_request_handler(user_handler(set_channel_visibility))
            .add_request_handler(user_handler(rename_channel))
            .add_request_handler(user_handler(join_channel_buffer))
            .add_request_handler(user_handler(leave_channel_buffer))
            .add_message_handler(user_message_handler(update_channel_buffer))
            .add_request_handler(user_handler(rejoin_channel_buffers))
            .add_request_handler(user_handler(get_channel_members))
            .add_request_handler(user_handler(respond_to_channel_invite))
            .add_request_handler(user_handler(join_channel))
            .add_request_handler(user_handler(join_channel_chat))
            .add_message_handler(user_message_handler(leave_channel_chat))
            .add_request_handler(user_handler(send_channel_message))
            .add_request_handler(user_handler(remove_channel_message))
            .add_request_handler(user_handler(update_channel_message))
            .add_request_handler(user_handler(get_channel_messages))
            .add_request_handler(user_handler(get_channel_messages_by_id))
            .add_request_handler(user_handler(get_notifications))
            .add_request_handler(user_handler(mark_notification_as_read))
            .add_request_handler(user_handler(move_channel))
            // Following.
            .add_request_handler(user_handler(follow))
            .add_message_handler(user_message_handler(unfollow))
            .add_message_handler(user_message_handler(update_followers))
            // Account, LLM and Supermaven.
            .add_request_handler(user_handler(get_private_user_info))
            .add_request_handler(user_handler(get_llm_api_token))
            .add_request_handler(user_handler(accept_terms_of_service))
            .add_message_handler(user_message_handler(acknowledge_channel_message))
            .add_message_handler(user_message_handler(acknowledge_buffer_version))
            .add_request_handler(user_handler(get_supermaven_api_key))
            // Assistant contexts.
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OpenContext>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CreateContext>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::SynchronizeContexts>,
            ))
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context)
            // Not wrapped in `user_handler`: this handler receives the raw `Session`.
            // NOTE(review): confirm non-user principals are meant to reach it.
            .add_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    let app_state = app_state.clone();
                    async move {
                        count_language_model_tokens(request, response, session, &app_state.config)
                            .await
                    }
                }
            })
            .add_request_handler({
                user_handler(move |request, response, session| {
                    get_cached_embeddings(request, response, session)
                })
            })
            .add_request_handler({
                let app_state = app_state.clone();
                user_handler(move |request, response, session| {
                    compute_embeddings(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                    )
                })
            });

        Arc::new(server)
    }
651
    /// Schedules cleanup of resources left behind by previous server instances.
    ///
    /// Spawns a detached task and returns `Ok(())` immediately. The task:
    /// 1. waits for `CLEANUP_TIMEOUT`;
    /// 2. fetches the room and channel ids reported by `stale_server_resource_ids`
    ///    for this environment and server id;
    /// 3. for each channel, clears stale buffer collaborators and sends the refreshed
    ///    collaborator list to the remaining connections;
    /// 4. for each room, clears stale participants, broadcasts the updated room and
    ///    channel, notifies users whose calls were canceled, and pushes refreshed
    ///    contact info to the accepted contacts of affected users; if the room ended up
    ///    empty, deletes its LiveKit room;
    /// 5. calls `delete_stale_servers`. This step runs even if step 2 failed.
    ///
    /// Errors are discarded via `trace_err()` (presumably logged there) and never abort
    /// the remaining cleanup.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Channel buffers: drop stale collaborators and tell the remaining ones.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    // Rooms: drop stale participants, then fan out the consequences.
                    for room_id in room_ids {
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            // Only delete the LiveKit room once nobody is left in it.
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the pool lock is released before the awaits below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push each affected user's refreshed contact entry to every
                        // connection of their accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                // Runs regardless of whether the stale-resource lookup above succeeded.
                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
797
798 pub fn teardown(&self) {
799 self.peer.teardown();
800 self.connection_pool.lock().reset();
801 let _ = self.teardown.send(true);
802 }
803
804 #[cfg(test)]
805 pub fn reset(&self, id: ServerId) {
806 self.teardown();
807 *self.id.lock() = id;
808 self.peer.reset(id.0 as u32);
809 let _ = self.teardown.send(false);
810 }
811
812 #[cfg(test)]
813 pub fn id(&self) -> ServerId {
814 *self.id.lock()
815 }
816
817 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
818 where
819 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
820 Fut: 'static + Send + Future<Output = Result<()>>,
821 M: EnvelopedMessage,
822 {
823 let prev_handler = self.handlers.insert(
824 TypeId::of::<M>(),
825 Box::new(move |envelope, session| {
826 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
827 let received_at = envelope.received_at;
828 tracing::info!("message received");
829 let start_time = Instant::now();
830 let future = (handler)(*envelope, session);
831 async move {
832 let result = future.await;
833 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
834 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
835 let queue_duration_ms = total_duration_ms - processing_duration_ms;
836 let payload_type = M::NAME;
837
838 match result {
839 Err(error) => {
840 tracing::error!(
841 ?error,
842 total_duration_ms,
843 processing_duration_ms,
844 queue_duration_ms,
845 payload_type,
846 "error handling message"
847 )
848 }
849 Ok(()) => tracing::info!(
850 total_duration_ms,
851 processing_duration_ms,
852 queue_duration_ms,
853 "finished handling message"
854 ),
855 }
856 }
857 .boxed()
858 }),
859 );
860 if prev_handler.is_some() {
861 panic!("registered a handler for the same message twice");
862 }
863 self
864 }
865
866 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
867 where
868 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
869 Fut: 'static + Send + Future<Output = Result<()>>,
870 M: EnvelopedMessage,
871 {
872 self.add_handler(move |envelope, session| handler(envelope.payload, session));
873 self
874 }
875
876 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
877 where
878 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
879 Fut: Send + Future<Output = Result<()>>,
880 M: RequestMessage,
881 {
882 let handler = Arc::new(handler);
883 self.add_handler(move |envelope, session| {
884 let receipt = envelope.receipt();
885 let handler = handler.clone();
886 async move {
887 let peer = session.peer.clone();
888 let responded = Arc::new(AtomicBool::default());
889 let response = Response {
890 peer: peer.clone(),
891 responded: responded.clone(),
892 receipt,
893 };
894 match (handler)(envelope.payload, response, session).await {
895 Ok(()) => {
896 if responded.load(std::sync::atomic::Ordering::SeqCst) {
897 Ok(())
898 } else {
899 Err(anyhow!("handler did not send a response"))?
900 }
901 }
902 Err(error) => {
903 let proto_err = match &error {
904 Error::Internal(err) => err.to_proto(),
905 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
906 };
907 peer.respond_with_error(receipt, proto_err)?;
908 Err(error)
909 }
910 }
911 }
912 })
913 }
914
915 #[allow(clippy::too_many_arguments)]
916 pub fn handle_connection(
917 self: &Arc<Self>,
918 connection: Connection,
919 address: String,
920 principal: Principal,
921 zed_version: ZedVersion,
922 geoip_country_code: Option<String>,
923 send_connection_id: Option<oneshot::Sender<ConnectionId>>,
924 executor: Executor,
925 ) -> impl Future<Output = ()> {
926 let this = self.clone();
927 let span = info_span!("handle connection", %address,
928 connection_id=field::Empty,
929 user_id=field::Empty,
930 login=field::Empty,
931 impersonator=field::Empty,
932 dev_server_id=field::Empty,
933 geoip_country_code=field::Empty
934 );
935 principal.update_span(&span);
936 if let Some(country_code) = geoip_country_code.as_ref() {
937 span.record("geoip_country_code", country_code);
938 }
939
940 let mut teardown = self.teardown.subscribe();
941 async move {
942 if *teardown.borrow() {
943 tracing::error!("server is tearing down");
944 return
945 }
946 let (connection_id, handle_io, mut incoming_rx) = this
947 .peer
948 .add_connection(connection, {
949 let executor = executor.clone();
950 move |duration| executor.sleep(duration)
951 });
952 tracing::Span::current().record("connection_id", format!("{}", connection_id));
953
954 tracing::info!("connection opened");
955
956 let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
957 let http_client = match IsahcHttpClient::builder().default_header("User-Agent", user_agent).build() {
958 Ok(http_client) => Arc::new(http_client),
959 Err(error) => {
960 tracing::error!(?error, "failed to create HTTP client");
961 return;
962 }
963 };
964
965 let supermaven_client = if let Some(supermaven_admin_api_key) = this.app_state.config.supermaven_admin_api_key.clone() {
966 Some(Arc::new(SupermavenAdminApi::new(
967 supermaven_admin_api_key.to_string(),
968 http_client.clone(),
969 )))
970 } else {
971 None
972 };
973
974 let session = Session {
975 principal: principal.clone(),
976 connection_id,
977 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
978 peer: this.peer.clone(),
979 connection_pool: this.connection_pool.clone(),
980 app_state: this.app_state.clone(),
981 http_client,
982 geoip_country_code,
983 _executor: executor.clone(),
984 supermaven_client,
985 };
986
987 if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
988 tracing::error!(?error, "failed to send initial client update");
989 return;
990 }
991
992 let handle_io = handle_io.fuse();
993 futures::pin_mut!(handle_io);
994
995 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
996 // This prevents deadlocks when e.g., client A performs a request to client B and
997 // client B performs a request to client A. If both clients stop processing further
998 // messages until their respective request completes, they won't have a chance to
999 // respond to the other client's request and cause a deadlock.
1000 //
1001 // This arrangement ensures we will attempt to process earlier messages first, but fall
1002 // back to processing messages arrived later in the spirit of making progress.
1003 let mut foreground_message_handlers = FuturesUnordered::new();
1004 let concurrent_handlers = Arc::new(Semaphore::new(256));
1005 loop {
1006 let next_message = async {
1007 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
1008 let message = incoming_rx.next().await;
1009 (permit, message)
1010 }.fuse();
1011 futures::pin_mut!(next_message);
1012 futures::select_biased! {
1013 _ = teardown.changed().fuse() => return,
1014 result = handle_io => {
1015 if let Err(error) = result {
1016 tracing::error!(?error, "error handling I/O");
1017 }
1018 break;
1019 }
1020 _ = foreground_message_handlers.next() => {}
1021 next_message = next_message => {
1022 let (permit, message) = next_message;
1023 if let Some(message) = message {
1024 let type_name = message.payload_type_name();
1025 // note: we copy all the fields from the parent span so we can query them in the logs.
1026 // (https://github.com/tokio-rs/tracing/issues/2670).
1027 let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
1028 user_id=field::Empty,
1029 login=field::Empty,
1030 impersonator=field::Empty,
1031 dev_server_id=field::Empty
1032 );
1033 principal.update_span(&span);
1034 let span_enter = span.enter();
1035 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
1036 let is_background = message.is_background();
1037 let handle_message = (handler)(message, session.clone());
1038 drop(span_enter);
1039
1040 let handle_message = async move {
1041 handle_message.await;
1042 drop(permit);
1043 }.instrument(span);
1044 if is_background {
1045 executor.spawn_detached(handle_message);
1046 } else {
1047 foreground_message_handlers.push(handle_message);
1048 }
1049 } else {
1050 tracing::error!("no message handler");
1051 }
1052 } else {
1053 tracing::info!("connection closed");
1054 break;
1055 }
1056 }
1057 }
1058 }
1059
1060 drop(foreground_message_handlers);
1061 tracing::info!("signing out");
1062 if let Err(error) = connection_lost(session, teardown, executor).await {
1063 tracing::error!(?error, "error signing out");
1064 }
1065
1066 }.instrument(span)
1067 }
1068
1069 async fn send_initial_client_update(
1070 &self,
1071 connection_id: ConnectionId,
1072 principal: &Principal,
1073 zed_version: ZedVersion,
1074 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
1075 session: &Session,
1076 ) -> Result<()> {
1077 self.peer.send(
1078 connection_id,
1079 proto::Hello {
1080 peer_id: Some(connection_id.into()),
1081 },
1082 )?;
1083 tracing::info!("sent hello message");
1084 if let Some(send_connection_id) = send_connection_id.take() {
1085 let _ = send_connection_id.send(connection_id);
1086 }
1087
1088 match principal {
1089 Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
1090 if !user.connected_once {
1091 self.peer.send(connection_id, proto::ShowContacts {})?;
1092 self.app_state
1093 .db
1094 .set_user_connected_once(user.id, true)
1095 .await?;
1096 }
1097
1098 update_user_plan(user.id, session).await?;
1099
1100 let (contacts, dev_server_projects) = future::try_join(
1101 self.app_state.db.get_contacts(user.id),
1102 self.app_state.db.dev_server_projects_update(user.id),
1103 )
1104 .await?;
1105
1106 {
1107 let mut pool = self.connection_pool.lock();
1108 pool.add_connection(connection_id, user.id, user.admin, zed_version);
1109 self.peer.send(
1110 connection_id,
1111 build_initial_contacts_update(contacts, &pool),
1112 )?;
1113 }
1114
1115 if should_auto_subscribe_to_channels(zed_version) {
1116 subscribe_user_to_channels(user.id, session).await?;
1117 }
1118
1119 send_dev_server_projects_update(user.id, dev_server_projects, session).await;
1120
1121 if let Some(incoming_call) =
1122 self.app_state.db.incoming_call_for_user(user.id).await?
1123 {
1124 self.peer.send(connection_id, incoming_call)?;
1125 }
1126
1127 update_user_contacts(user.id, &session).await?;
1128 }
1129 Principal::DevServer(dev_server) => {
1130 {
1131 let mut pool = self.connection_pool.lock();
1132 if let Some(stale_connection_id) = pool.dev_server_connection_id(dev_server.id)
1133 {
1134 self.peer.send(
1135 stale_connection_id,
1136 proto::ShutdownDevServer {
1137 reason: Some(
1138 "another dev server connected with the same token".to_string(),
1139 ),
1140 },
1141 )?;
1142 pool.remove_connection(stale_connection_id)?;
1143 };
1144 pool.add_dev_server(connection_id, dev_server.id, zed_version);
1145 }
1146
1147 let projects = self
1148 .app_state
1149 .db
1150 .get_projects_for_dev_server(dev_server.id)
1151 .await?;
1152 self.peer
1153 .send(connection_id, proto::DevServerInstructions { projects })?;
1154
1155 let status = self
1156 .app_state
1157 .db
1158 .dev_server_projects_update(dev_server.user_id)
1159 .await?;
1160 send_dev_server_projects_update(dev_server.user_id, status, &session).await;
1161 }
1162 }
1163
1164 Ok(())
1165 }
1166
1167 pub async fn invite_code_redeemed(
1168 self: &Arc<Self>,
1169 inviter_id: UserId,
1170 invitee_id: UserId,
1171 ) -> Result<()> {
1172 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1173 if let Some(code) = &user.invite_code {
1174 let pool = self.connection_pool.lock();
1175 let invitee_contact = contact_for_user(invitee_id, false, &pool);
1176 for connection_id in pool.user_connection_ids(inviter_id) {
1177 self.peer.send(
1178 connection_id,
1179 proto::UpdateContacts {
1180 contacts: vec![invitee_contact.clone()],
1181 ..Default::default()
1182 },
1183 )?;
1184 self.peer.send(
1185 connection_id,
1186 proto::UpdateInviteInfo {
1187 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1188 count: user.invite_count as u32,
1189 },
1190 )?;
1191 }
1192 }
1193 }
1194 Ok(())
1195 }
1196
1197 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1198 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1199 if let Some(invite_code) = &user.invite_code {
1200 let pool = self.connection_pool.lock();
1201 for connection_id in pool.user_connection_ids(user_id) {
1202 self.peer.send(
1203 connection_id,
1204 proto::UpdateInviteInfo {
1205 url: format!(
1206 "{}{}",
1207 self.app_state.config.invite_link_prefix, invite_code
1208 ),
1209 count: user.invite_count as u32,
1210 },
1211 )?;
1212 }
1213 }
1214 }
1215 Ok(())
1216 }
1217
1218 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1219 ServerSnapshot {
1220 connection_pool: ConnectionPoolGuard {
1221 guard: self.connection_pool.lock(),
1222 _not_send: PhantomData,
1223 },
1224 peer: &self.peer,
1225 }
1226 }
1227}
1228
1229impl<'a> Deref for ConnectionPoolGuard<'a> {
1230 type Target = ConnectionPool;
1231
1232 fn deref(&self) -> &Self::Target {
1233 &self.guard
1234 }
1235}
1236
1237impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1238 fn deref_mut(&mut self) -> &mut Self::Target {
1239 &mut self.guard
1240 }
1241}
1242
1243impl<'a> Drop for ConnectionPoolGuard<'a> {
1244 fn drop(&mut self) {
1245 #[cfg(test)]
1246 self.check_invariants();
1247 }
1248}
1249
1250fn broadcast<F>(
1251 sender_id: Option<ConnectionId>,
1252 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1253 mut f: F,
1254) where
1255 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1256{
1257 for receiver_id in receiver_ids {
1258 if Some(receiver_id) != sender_id {
1259 if let Err(error) = f(receiver_id) {
1260 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1261 }
1262 }
1263 }
1264}
1265
1266pub struct ProtocolVersion(u32);
1267
1268impl Header for ProtocolVersion {
1269 fn name() -> &'static HeaderName {
1270 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1271 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1272 }
1273
1274 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1275 where
1276 Self: Sized,
1277 I: Iterator<Item = &'i axum::http::HeaderValue>,
1278 {
1279 let version = values
1280 .next()
1281 .ok_or_else(axum::headers::Error::invalid)?
1282 .to_str()
1283 .map_err(|_| axum::headers::Error::invalid())?
1284 .parse()
1285 .map_err(|_| axum::headers::Error::invalid())?;
1286 Ok(Self(version))
1287 }
1288
1289 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1290 values.extend([self.0.to_string().parse().unwrap()]);
1291 }
1292}
1293
1294pub struct AppVersionHeader(SemanticVersion);
1295impl Header for AppVersionHeader {
1296 fn name() -> &'static HeaderName {
1297 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1298 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1299 }
1300
1301 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1302 where
1303 Self: Sized,
1304 I: Iterator<Item = &'i axum::http::HeaderValue>,
1305 {
1306 let version = values
1307 .next()
1308 .ok_or_else(axum::headers::Error::invalid)?
1309 .to_str()
1310 .map_err(|_| axum::headers::Error::invalid())?
1311 .parse()
1312 .map_err(|_| axum::headers::Error::invalid())?;
1313 Ok(Self(version))
1314 }
1315
1316 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1317 values.extend([self.0.to_string().parse().unwrap()]);
1318 }
1319}
1320
1321pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1322 Router::new()
1323 .route("/rpc", get(handle_websocket_request))
1324 .layer(
1325 ServiceBuilder::new()
1326 .layer(Extension(server.app_state.clone()))
1327 .layer(middleware::from_fn(auth::validate_header)),
1328 )
1329 .route("/metrics", get(handle_metrics))
1330 .layer(Extension(server))
1331}
1332
1333pub async fn handle_websocket_request(
1334 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1335 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1336 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1337 Extension(server): Extension<Arc<Server>>,
1338 Extension(principal): Extension<Principal>,
1339 country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1340 ws: WebSocketUpgrade,
1341) -> axum::response::Response {
1342 if protocol_version != rpc::PROTOCOL_VERSION {
1343 return (
1344 StatusCode::UPGRADE_REQUIRED,
1345 "client must be upgraded".to_string(),
1346 )
1347 .into_response();
1348 }
1349
1350 let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1351 return (
1352 StatusCode::UPGRADE_REQUIRED,
1353 "no version header found".to_string(),
1354 )
1355 .into_response();
1356 };
1357
1358 if !version.can_collaborate() {
1359 return (
1360 StatusCode::UPGRADE_REQUIRED,
1361 "client must be upgraded".to_string(),
1362 )
1363 .into_response();
1364 }
1365
1366 let socket_address = socket_address.to_string();
1367 ws.on_upgrade(move |socket| {
1368 let socket = socket
1369 .map_ok(to_tungstenite_message)
1370 .err_into()
1371 .with(|message| async move { to_axum_message(message) });
1372 let connection = Connection::new(Box::pin(socket));
1373 async move {
1374 server
1375 .handle_connection(
1376 connection,
1377 socket_address,
1378 principal,
1379 version,
1380 country_code_header.map(|header| header.to_string()),
1381 None,
1382 Executor::Production,
1383 )
1384 .await;
1385 }
1386 })
1387}
1388
1389pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1390 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1391 let connections_metric = CONNECTIONS_METRIC
1392 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1393
1394 let connections = server
1395 .connection_pool
1396 .lock()
1397 .connections()
1398 .filter(|connection| !connection.admin)
1399 .count();
1400 connections_metric.set(connections as _);
1401
1402 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1403 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1404 register_int_gauge!(
1405 "shared_projects",
1406 "number of open projects with one or more guests"
1407 )
1408 .unwrap()
1409 });
1410
1411 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1412 shared_projects_metric.set(shared_projects as _);
1413
1414 let encoder = prometheus::TextEncoder::new();
1415 let metric_families = prometheus::gather();
1416 let encoded_metrics = encoder
1417 .encode_to_string(&metric_families)
1418 .map_err(|err| anyhow!("{}", err))?;
1419 Ok(encoded_metrics)
1420}
1421
/// Cleans up after a client connection goes away.
///
/// The first steps run immediately:
/// - disconnect the peer;
/// - remove the connection from the pool;
/// - record the loss in the database (a database failure is only logged).
///
/// It then waits for either `RECONNECT_TIMEOUT` to elapse or `teardown` to report a change,
/// whichever completes first. If `teardown` wins, nothing more is done.
///
/// After the timeout, the principal's remaining state is released:
/// - users leave their room and channel buffers, and have their contacts refreshed;
/// - if the user has no other connection left in the pool, any pending call is declined;
/// - dev servers are passed to `lost_dev_server_connection`.
///
/// # Errors
/// Fails if removing the connection from the pool fails, or if the final contacts update
/// or dev-server cleanup fails. Leaving the room or channel buffers and declining calls
/// only log their failures.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` prefers the first branch when both are ready. The delay before
    // releasing resources presumably gives the client a window to reconnect (TODO confirm
    // against RECONNECT_TIMEOUT's definition).
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            match &session.principal {
                Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
                    let session = session.for_user().unwrap();

                    log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
                    leave_room_for_session(&session, session.connection_id).await.trace_err();
                    leave_channel_buffers_for_session(&session)
                        .await
                        .trace_err();

                    // This connection was removed from the pool above, so this checks the
                    // user's *other* connections.
                    if !session
                        .connection_pool()
                        .await
                        .is_user_online(session.user_id())
                    {
                        let db = session.db().await;
                        if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                            room_updated(&room, &session.peer);
                        }
                    }

                    update_user_contacts(session.user_id(), &session).await?;
                },
                Principal::DevServer(_) => {
                    lost_dev_server_connection(&session.for_dev_server().unwrap()).await?;
                },
            }
        },
        // Server teardown: skip the per-principal cleanup.
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1476
1477/// Acknowledges a ping from a client, used to keep the connection alive.
1478async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1479 response.send(proto::Ack {})?;
1480 Ok(())
1481}
1482
1483/// Creates a new room for calling (outside of channels)
1484async fn create_room(
1485 _request: proto::CreateRoom,
1486 response: Response<proto::CreateRoom>,
1487 session: UserSession,
1488) -> Result<()> {
1489 let live_kit_room = nanoid::nanoid!(30);
1490
1491 let live_kit_connection_info = util::maybe!(async {
1492 let live_kit = session.app_state.live_kit_client.as_ref();
1493 let live_kit = live_kit?;
1494 let user_id = session.user_id().to_string();
1495
1496 let token = live_kit
1497 .room_token(&live_kit_room, &user_id.to_string())
1498 .trace_err()?;
1499
1500 Some(proto::LiveKitConnectionInfo {
1501 server_url: live_kit.url().into(),
1502 token,
1503 can_publish: true,
1504 })
1505 })
1506 .await;
1507
1508 let room = session
1509 .db()
1510 .await
1511 .create_room(session.user_id(), session.connection_id, &live_kit_room)
1512 .await?;
1513
1514 response.send(proto::CreateRoomResponse {
1515 room: Some(room.clone()),
1516 live_kit_connection_info,
1517 })?;
1518
1519 update_user_contacts(session.user_id(), &session).await?;
1520 Ok(())
1521}
1522
1523/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1524async fn join_room(
1525 request: proto::JoinRoom,
1526 response: Response<proto::JoinRoom>,
1527 session: UserSession,
1528) -> Result<()> {
1529 let room_id = RoomId::from_proto(request.id);
1530
1531 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1532
1533 if let Some(channel_id) = channel_id {
1534 return join_channel_internal(channel_id, Box::new(response), session).await;
1535 }
1536
1537 let joined_room = {
1538 let room = session
1539 .db()
1540 .await
1541 .join_room(room_id, session.user_id(), session.connection_id)
1542 .await?;
1543 room_updated(&room.room, &session.peer);
1544 room.into_inner()
1545 };
1546
1547 for connection_id in session
1548 .connection_pool()
1549 .await
1550 .user_connection_ids(session.user_id())
1551 {
1552 session
1553 .peer
1554 .send(
1555 connection_id,
1556 proto::CallCanceled {
1557 room_id: room_id.to_proto(),
1558 },
1559 )
1560 .trace_err();
1561 }
1562
1563 let live_kit_connection_info =
1564 if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
1565 if let Some(token) = live_kit
1566 .room_token(
1567 &joined_room.room.live_kit_room,
1568 &session.user_id().to_string(),
1569 )
1570 .trace_err()
1571 {
1572 Some(proto::LiveKitConnectionInfo {
1573 server_url: live_kit.url().into(),
1574 token,
1575 can_publish: true,
1576 })
1577 } else {
1578 None
1579 }
1580 } else {
1581 None
1582 };
1583
1584 response.send(proto::JoinRoomResponse {
1585 room: Some(joined_room.room),
1586 channel_id: None,
1587 live_kit_connection_info,
1588 })?;
1589
1590 update_user_contacts(session.user_id(), &session).await?;
1591 Ok(())
1592}
1593
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// While the guard returned by `rejoin_room` is held, this:
/// - replies with the room and its reshared and rejoined projects;
/// - broadcasts the room update;
/// - tells each reshared project's collaborators that the caller's peer id changed from
///   the project's old connection to this one, and forwards the project's worktree list to
///   them;
/// - resyncs the rejoined projects via `notify_rejoined_projects`.
///
/// NOTE(review): "reshared" appears to mean projects the caller hosts and "rejoined"
/// projects the caller was a guest in — confirm against the db query.
///
/// After the guard is released, the room's channel (if any) is updated and the caller's
/// contacts are refreshed.
///
/// # Errors
/// Fails on the database call, the reply, or `notify_rejoined_projects`; per-collaborator
/// sends only log failures.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: UserSession,
) -> Result<()> {
    // Declared outside the block so they outlive the database guard `rejoined_room`.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Point every collaborator at the caller's new peer id.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the current worktree list as if sent by the caller's connection.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    // The database guard has been released by now.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1685
1686fn notify_rejoined_projects(
1687 rejoined_projects: &mut Vec<RejoinedProject>,
1688 session: &UserSession,
1689) -> Result<()> {
1690 for project in rejoined_projects.iter() {
1691 for collaborator in &project.collaborators {
1692 session
1693 .peer
1694 .send(
1695 collaborator.connection_id,
1696 proto::UpdateProjectCollaborator {
1697 project_id: project.id.to_proto(),
1698 old_peer_id: Some(project.old_connection_id.into()),
1699 new_peer_id: Some(session.connection_id.into()),
1700 },
1701 )
1702 .trace_err();
1703 }
1704 }
1705
1706 for project in rejoined_projects {
1707 for worktree in mem::take(&mut project.worktrees) {
1708 #[cfg(any(test, feature = "test-support"))]
1709 const MAX_CHUNK_SIZE: usize = 2;
1710 #[cfg(not(any(test, feature = "test-support")))]
1711 const MAX_CHUNK_SIZE: usize = 256;
1712
1713 // Stream this worktree's entries.
1714 let message = proto::UpdateWorktree {
1715 project_id: project.id.to_proto(),
1716 worktree_id: worktree.id,
1717 abs_path: worktree.abs_path.clone(),
1718 root_name: worktree.root_name,
1719 updated_entries: worktree.updated_entries,
1720 removed_entries: worktree.removed_entries,
1721 scan_id: worktree.scan_id,
1722 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1723 updated_repositories: worktree.updated_repositories,
1724 removed_repositories: worktree.removed_repositories,
1725 };
1726 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1727 session.peer.send(session.connection_id, update.clone())?;
1728 }
1729
1730 // Stream this worktree's diagnostics.
1731 for summary in worktree.diagnostic_summaries {
1732 session.peer.send(
1733 session.connection_id,
1734 proto::UpdateDiagnosticSummary {
1735 project_id: project.id.to_proto(),
1736 worktree_id: worktree.id,
1737 summary: Some(summary),
1738 },
1739 )?;
1740 }
1741
1742 for settings_file in worktree.settings_files {
1743 session.peer.send(
1744 session.connection_id,
1745 proto::UpdateWorktreeSettings {
1746 project_id: project.id.to_proto(),
1747 worktree_id: worktree.id,
1748 path: settings_file.path,
1749 content: Some(settings_file.content),
1750 },
1751 )?;
1752 }
1753 }
1754
1755 for language_server in &project.language_servers {
1756 session.peer.send(
1757 session.connection_id,
1758 proto::UpdateLanguageServer {
1759 project_id: project.id.to_proto(),
1760 language_server_id: language_server.id,
1761 variant: Some(
1762 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1763 proto::LspDiskBasedDiagnosticsUpdated {},
1764 ),
1765 ),
1766 },
1767 )?;
1768 }
1769 }
1770 Ok(())
1771}
1772
1773/// leave room disconnects from the room.
1774async fn leave_room(
1775 _: proto::LeaveRoom,
1776 response: Response<proto::LeaveRoom>,
1777 session: UserSession,
1778) -> Result<()> {
1779 leave_room_for_session(&session, session.connection_id).await?;
1780 response.send(proto::Ack {})?;
1781 Ok(())
1782}
1783
1784/// Updates the permissions of someone else in the room.
1785async fn set_room_participant_role(
1786 request: proto::SetRoomParticipantRole,
1787 response: Response<proto::SetRoomParticipantRole>,
1788 session: UserSession,
1789) -> Result<()> {
1790 let user_id = UserId::from_proto(request.user_id);
1791 let role = ChannelRole::from(request.role());
1792
1793 let (live_kit_room, can_publish) = {
1794 let room = session
1795 .db()
1796 .await
1797 .set_room_participant_role(
1798 session.user_id(),
1799 RoomId::from_proto(request.room_id),
1800 user_id,
1801 role,
1802 )
1803 .await?;
1804
1805 let live_kit_room = room.live_kit_room.clone();
1806 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1807 room_updated(&room, &session.peer);
1808 (live_kit_room, can_publish)
1809 };
1810
1811 if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
1812 live_kit
1813 .update_participant(
1814 live_kit_room.clone(),
1815 request.user_id.to_string(),
1816 live_kit_server::proto::ParticipantPermission {
1817 can_subscribe: true,
1818 can_publish,
1819 can_publish_data: can_publish,
1820 hidden: false,
1821 recorder: false,
1822 },
1823 )
1824 .await
1825 .trace_err();
1826 }
1827
1828 response.send(proto::Ack {})?;
1829 Ok(())
1830}
1831
1832/// Call someone else into the current room
1833async fn call(
1834 request: proto::Call,
1835 response: Response<proto::Call>,
1836 session: UserSession,
1837) -> Result<()> {
1838 let room_id = RoomId::from_proto(request.room_id);
1839 let calling_user_id = session.user_id();
1840 let calling_connection_id = session.connection_id;
1841 let called_user_id = UserId::from_proto(request.called_user_id);
1842 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1843 if !session
1844 .db()
1845 .await
1846 .has_contact(calling_user_id, called_user_id)
1847 .await?
1848 {
1849 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1850 }
1851
1852 let incoming_call = {
1853 let (room, incoming_call) = &mut *session
1854 .db()
1855 .await
1856 .call(
1857 room_id,
1858 calling_user_id,
1859 calling_connection_id,
1860 called_user_id,
1861 initial_project_id,
1862 )
1863 .await?;
1864 room_updated(&room, &session.peer);
1865 mem::take(incoming_call)
1866 };
1867 update_user_contacts(called_user_id, &session).await?;
1868
1869 let mut calls = session
1870 .connection_pool()
1871 .await
1872 .user_connection_ids(called_user_id)
1873 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1874 .collect::<FuturesUnordered<_>>();
1875
1876 while let Some(call_response) = calls.next().await {
1877 match call_response.as_ref() {
1878 Ok(_) => {
1879 response.send(proto::Ack {})?;
1880 return Ok(());
1881 }
1882 Err(_) => {
1883 call_response.trace_err();
1884 }
1885 }
1886 }
1887
1888 {
1889 let room = session
1890 .db()
1891 .await
1892 .call_failed(room_id, called_user_id)
1893 .await?;
1894 room_updated(&room, &session.peer);
1895 }
1896 update_user_contacts(called_user_id, &session).await?;
1897
1898 Err(anyhow!("failed to ring user"))?
1899}
1900
1901/// Cancel an outgoing call.
1902async fn cancel_call(
1903 request: proto::CancelCall,
1904 response: Response<proto::CancelCall>,
1905 session: UserSession,
1906) -> Result<()> {
1907 let called_user_id = UserId::from_proto(request.called_user_id);
1908 let room_id = RoomId::from_proto(request.room_id);
1909 {
1910 let room = session
1911 .db()
1912 .await
1913 .cancel_call(room_id, session.connection_id, called_user_id)
1914 .await?;
1915 room_updated(&room, &session.peer);
1916 }
1917
1918 for connection_id in session
1919 .connection_pool()
1920 .await
1921 .user_connection_ids(called_user_id)
1922 {
1923 session
1924 .peer
1925 .send(
1926 connection_id,
1927 proto::CallCanceled {
1928 room_id: room_id.to_proto(),
1929 },
1930 )
1931 .trace_err();
1932 }
1933 response.send(proto::Ack {})?;
1934
1935 update_user_contacts(called_user_id, &session).await?;
1936 Ok(())
1937}
1938
1939/// Decline an incoming call.
1940async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1941 let room_id = RoomId::from_proto(message.room_id);
1942 {
1943 let room = session
1944 .db()
1945 .await
1946 .decline_call(Some(room_id), session.user_id())
1947 .await?
1948 .ok_or_else(|| anyhow!("failed to decline call"))?;
1949 room_updated(&room, &session.peer);
1950 }
1951
1952 for connection_id in session
1953 .connection_pool()
1954 .await
1955 .user_connection_ids(session.user_id())
1956 {
1957 session
1958 .peer
1959 .send(
1960 connection_id,
1961 proto::CallCanceled {
1962 room_id: room_id.to_proto(),
1963 },
1964 )
1965 .trace_err();
1966 }
1967 update_user_contacts(session.user_id(), &session).await?;
1968 Ok(())
1969}
1970
1971/// Updates other participants in the room with your current location.
1972async fn update_participant_location(
1973 request: proto::UpdateParticipantLocation,
1974 response: Response<proto::UpdateParticipantLocation>,
1975 session: UserSession,
1976) -> Result<()> {
1977 let room_id = RoomId::from_proto(request.room_id);
1978 let location = request
1979 .location
1980 .ok_or_else(|| anyhow!("invalid location"))?;
1981
1982 let db = session.db().await;
1983 let room = db
1984 .update_room_participant_location(room_id, session.connection_id, location)
1985 .await?;
1986
1987 room_updated(&room, &session.peer);
1988 response.send(proto::Ack {})?;
1989 Ok(())
1990}
1991
1992/// Share a project into the room.
1993async fn share_project(
1994 request: proto::ShareProject,
1995 response: Response<proto::ShareProject>,
1996 session: UserSession,
1997) -> Result<()> {
1998 let (project_id, room) = &*session
1999 .db()
2000 .await
2001 .share_project(
2002 RoomId::from_proto(request.room_id),
2003 session.connection_id,
2004 &request.worktrees,
2005 request
2006 .dev_server_project_id
2007 .map(|id| DevServerProjectId::from_proto(id)),
2008 )
2009 .await?;
2010 response.send(proto::ShareProjectResponse {
2011 project_id: project_id.to_proto(),
2012 })?;
2013 room_updated(&room, &session.peer);
2014
2015 Ok(())
2016}
2017
2018/// Unshare a project from the room.
2019async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
2020 let project_id = ProjectId::from_proto(message.project_id);
2021 unshare_project_internal(
2022 project_id,
2023 session.connection_id,
2024 session.user_id(),
2025 &session,
2026 )
2027 .await
2028}
2029
2030async fn unshare_project_internal(
2031 project_id: ProjectId,
2032 connection_id: ConnectionId,
2033 user_id: Option<UserId>,
2034 session: &Session,
2035) -> Result<()> {
2036 let delete = {
2037 let room_guard = session
2038 .db()
2039 .await
2040 .unshare_project(project_id, connection_id, user_id)
2041 .await?;
2042
2043 let (delete, room, guest_connection_ids) = &*room_guard;
2044
2045 let message = proto::UnshareProject {
2046 project_id: project_id.to_proto(),
2047 };
2048
2049 broadcast(
2050 Some(connection_id),
2051 guest_connection_ids.iter().copied(),
2052 |conn_id| session.peer.send(conn_id, message.clone()),
2053 );
2054 if let Some(room) = room {
2055 room_updated(room, &session.peer);
2056 }
2057
2058 *delete
2059 };
2060
2061 if delete {
2062 let db = session.db().await;
2063 db.delete_project(project_id).await?;
2064 }
2065
2066 Ok(())
2067}
2068
2069/// DevServer makes a project available online
2070async fn share_dev_server_project(
2071 request: proto::ShareDevServerProject,
2072 response: Response<proto::ShareDevServerProject>,
2073 session: DevServerSession,
2074) -> Result<()> {
2075 let (dev_server_project, user_id, status) = session
2076 .db()
2077 .await
2078 .share_dev_server_project(
2079 DevServerProjectId::from_proto(request.dev_server_project_id),
2080 session.dev_server_id(),
2081 session.connection_id,
2082 &request.worktrees,
2083 )
2084 .await?;
2085 let Some(project_id) = dev_server_project.project_id else {
2086 return Err(anyhow!("failed to share remote project"))?;
2087 };
2088
2089 send_dev_server_projects_update(user_id, status, &session).await;
2090
2091 response.send(proto::ShareProjectResponse { project_id })?;
2092
2093 Ok(())
2094}
2095
/// Join someone else's shared project.
///
/// Registers the calling connection as a collaborator in the database, releases the
/// database handle, and delegates the rest of the handshake to `join_project_internal`.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: UserSession,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    // The guard returned by `join_project` is kept alive (temporary lifetime extension
    // through `&mut *`) while the project and replica id are borrowed from it.
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    // Release the database handle; only the guard above stays alive from here on.
    drop(db);
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
2114
/// Response handles that can be answered with a `JoinProjectResponse`.
///
/// This lets `join_project_internal` serve both `JoinProject` and `JoinHostedProject`
/// requests.
trait JoinProjectInternalResponse {
    /// Sends `result` as the reply, consuming the handle.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
// The path-qualified calls in these impls resolve to the inherent `Response::send`
// (inherent items take precedence over trait items), so they forward rather than recurse.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinProject>::send(self, result)
    }
}
impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinHostedProject>::send(self, result)
    }
}
2128
2129fn join_project_internal(
2130 response: impl JoinProjectInternalResponse,
2131 session: UserSession,
2132 project: &mut Project,
2133 replica_id: &ReplicaId,
2134) -> Result<()> {
2135 let collaborators = project
2136 .collaborators
2137 .iter()
2138 .filter(|collaborator| collaborator.connection_id != session.connection_id)
2139 .map(|collaborator| collaborator.to_proto())
2140 .collect::<Vec<_>>();
2141 let project_id = project.id;
2142 let guest_user_id = session.user_id();
2143
2144 let worktrees = project
2145 .worktrees
2146 .iter()
2147 .map(|(id, worktree)| proto::WorktreeMetadata {
2148 id: *id,
2149 root_name: worktree.root_name.clone(),
2150 visible: worktree.visible,
2151 abs_path: worktree.abs_path.clone(),
2152 })
2153 .collect::<Vec<_>>();
2154
2155 let add_project_collaborator = proto::AddProjectCollaborator {
2156 project_id: project_id.to_proto(),
2157 collaborator: Some(proto::Collaborator {
2158 peer_id: Some(session.connection_id.into()),
2159 replica_id: replica_id.0 as u32,
2160 user_id: guest_user_id.to_proto(),
2161 }),
2162 };
2163
2164 for collaborator in &collaborators {
2165 session
2166 .peer
2167 .send(
2168 collaborator.peer_id.unwrap().into(),
2169 add_project_collaborator.clone(),
2170 )
2171 .trace_err();
2172 }
2173
2174 // First, we send the metadata associated with each worktree.
2175 response.send(proto::JoinProjectResponse {
2176 project_id: project.id.0 as u64,
2177 worktrees: worktrees.clone(),
2178 replica_id: replica_id.0 as u32,
2179 collaborators: collaborators.clone(),
2180 language_servers: project.language_servers.clone(),
2181 role: project.role.into(),
2182 dev_server_project_id: project
2183 .dev_server_project_id
2184 .map(|dev_server_project_id| dev_server_project_id.0 as u64),
2185 })?;
2186
2187 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2188 #[cfg(any(test, feature = "test-support"))]
2189 const MAX_CHUNK_SIZE: usize = 2;
2190 #[cfg(not(any(test, feature = "test-support")))]
2191 const MAX_CHUNK_SIZE: usize = 256;
2192
2193 // Stream this worktree's entries.
2194 let message = proto::UpdateWorktree {
2195 project_id: project_id.to_proto(),
2196 worktree_id,
2197 abs_path: worktree.abs_path.clone(),
2198 root_name: worktree.root_name,
2199 updated_entries: worktree.entries,
2200 removed_entries: Default::default(),
2201 scan_id: worktree.scan_id,
2202 is_last_update: worktree.scan_id == worktree.completed_scan_id,
2203 updated_repositories: worktree.repository_entries.into_values().collect(),
2204 removed_repositories: Default::default(),
2205 };
2206 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
2207 session.peer.send(session.connection_id, update.clone())?;
2208 }
2209
2210 // Stream this worktree's diagnostics.
2211 for summary in worktree.diagnostic_summaries {
2212 session.peer.send(
2213 session.connection_id,
2214 proto::UpdateDiagnosticSummary {
2215 project_id: project_id.to_proto(),
2216 worktree_id: worktree.id,
2217 summary: Some(summary),
2218 },
2219 )?;
2220 }
2221
2222 for settings_file in worktree.settings_files {
2223 session.peer.send(
2224 session.connection_id,
2225 proto::UpdateWorktreeSettings {
2226 project_id: project_id.to_proto(),
2227 worktree_id: worktree.id,
2228 path: settings_file.path,
2229 content: Some(settings_file.content),
2230 },
2231 )?;
2232 }
2233 }
2234
2235 for language_server in &project.language_servers {
2236 session.peer.send(
2237 session.connection_id,
2238 proto::UpdateLanguageServer {
2239 project_id: project_id.to_proto(),
2240 language_server_id: language_server.id,
2241 variant: Some(
2242 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2243 proto::LspDiskBasedDiagnosticsUpdated {},
2244 ),
2245 ),
2246 },
2247 )?;
2248 }
2249
2250 Ok(())
2251}
2252
2253/// Leave someone elses shared project.
2254async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2255 let sender_id = session.connection_id;
2256 let project_id = ProjectId::from_proto(request.project_id);
2257 let db = session.db().await;
2258 if db.is_hosted_project(project_id).await? {
2259 let project = db.leave_hosted_project(project_id, sender_id).await?;
2260 project_left(&project, &session);
2261 return Ok(());
2262 }
2263
2264 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2265 tracing::info!(
2266 %project_id,
2267 "leave project"
2268 );
2269
2270 project_left(&project, &session);
2271 if let Some(room) = room {
2272 room_updated(&room, &session.peer);
2273 }
2274
2275 Ok(())
2276}
2277
2278async fn join_hosted_project(
2279 request: proto::JoinHostedProject,
2280 response: Response<proto::JoinHostedProject>,
2281 session: UserSession,
2282) -> Result<()> {
2283 let (mut project, replica_id) = session
2284 .db()
2285 .await
2286 .join_hosted_project(
2287 ProjectId(request.project_id as i32),
2288 session.user_id(),
2289 session.connection_id,
2290 )
2291 .await?;
2292
2293 join_project_internal(response, session, &mut project, &replica_id)
2294}
2295
2296async fn list_remote_directory(
2297 request: proto::ListRemoteDirectory,
2298 response: Response<proto::ListRemoteDirectory>,
2299 session: UserSession,
2300) -> Result<()> {
2301 let dev_server_id = DevServerId(request.dev_server_id as i32);
2302 let dev_server_connection_id = session
2303 .connection_pool()
2304 .await
2305 .dev_server_connection_id_supporting(dev_server_id, ZedVersion::with_list_directory())?;
2306
2307 session
2308 .db()
2309 .await
2310 .get_dev_server_for_user(dev_server_id, session.user_id())
2311 .await?;
2312
2313 response.send(
2314 session
2315 .peer
2316 .forward_request(session.connection_id, dev_server_connection_id, request)
2317 .await?,
2318 )?;
2319 Ok(())
2320}
2321
2322async fn update_dev_server_project(
2323 request: proto::UpdateDevServerProject,
2324 response: Response<proto::UpdateDevServerProject>,
2325 session: UserSession,
2326) -> Result<()> {
2327 let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2328
2329 let (dev_server_project, update) = session
2330 .db()
2331 .await
2332 .update_dev_server_project(dev_server_project_id, &request.paths, session.user_id())
2333 .await?;
2334
2335 let projects = session
2336 .db()
2337 .await
2338 .get_projects_for_dev_server(dev_server_project.dev_server_id)
2339 .await?;
2340
2341 let dev_server_connection_id = session
2342 .connection_pool()
2343 .await
2344 .dev_server_connection_id_supporting(
2345 dev_server_project.dev_server_id,
2346 ZedVersion::with_list_directory(),
2347 )?;
2348
2349 session.peer.send(
2350 dev_server_connection_id,
2351 proto::DevServerInstructions { projects },
2352 )?;
2353
2354 send_dev_server_projects_update(session.user_id(), update, &session).await;
2355
2356 response.send(proto::Ack {})
2357}
2358
2359async fn create_dev_server_project(
2360 request: proto::CreateDevServerProject,
2361 response: Response<proto::CreateDevServerProject>,
2362 session: UserSession,
2363) -> Result<()> {
2364 let dev_server_id = DevServerId(request.dev_server_id as i32);
2365 let dev_server_connection_id = session
2366 .connection_pool()
2367 .await
2368 .dev_server_connection_id(dev_server_id);
2369 let Some(dev_server_connection_id) = dev_server_connection_id else {
2370 Err(ErrorCode::DevServerOffline
2371 .message("Cannot create a remote project when the dev server is offline".to_string())
2372 .anyhow())?
2373 };
2374
2375 let path = request.path.clone();
2376 //Check that the path exists on the dev server
2377 session
2378 .peer
2379 .forward_request(
2380 session.connection_id,
2381 dev_server_connection_id,
2382 proto::ValidateDevServerProjectRequest { path: path.clone() },
2383 )
2384 .await?;
2385
2386 let (dev_server_project, update) = session
2387 .db()
2388 .await
2389 .create_dev_server_project(
2390 DevServerId(request.dev_server_id as i32),
2391 &request.path,
2392 session.user_id(),
2393 )
2394 .await?;
2395
2396 let projects = session
2397 .db()
2398 .await
2399 .get_projects_for_dev_server(dev_server_project.dev_server_id)
2400 .await?;
2401
2402 session.peer.send(
2403 dev_server_connection_id,
2404 proto::DevServerInstructions { projects },
2405 )?;
2406
2407 send_dev_server_projects_update(session.user_id(), update, &session).await;
2408
2409 response.send(proto::CreateDevServerProjectResponse {
2410 dev_server_project: Some(dev_server_project.to_proto(None)),
2411 })?;
2412 Ok(())
2413}
2414
2415async fn create_dev_server(
2416 request: proto::CreateDevServer,
2417 response: Response<proto::CreateDevServer>,
2418 session: UserSession,
2419) -> Result<()> {
2420 let access_token = auth::random_token();
2421 let hashed_access_token = auth::hash_access_token(&access_token);
2422
2423 if request.name.is_empty() {
2424 return Err(proto::ErrorCode::Forbidden
2425 .message("Dev server name cannot be empty".to_string())
2426 .anyhow())?;
2427 }
2428
2429 let (dev_server, status) = session
2430 .db()
2431 .await
2432 .create_dev_server(
2433 &request.name,
2434 request.ssh_connection_string.as_deref(),
2435 &hashed_access_token,
2436 session.user_id(),
2437 )
2438 .await?;
2439
2440 send_dev_server_projects_update(session.user_id(), status, &session).await;
2441
2442 response.send(proto::CreateDevServerResponse {
2443 dev_server_id: dev_server.id.0 as u64,
2444 access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2445 name: request.name,
2446 })?;
2447 Ok(())
2448}
2449
2450async fn regenerate_dev_server_token(
2451 request: proto::RegenerateDevServerToken,
2452 response: Response<proto::RegenerateDevServerToken>,
2453 session: UserSession,
2454) -> Result<()> {
2455 let dev_server_id = DevServerId(request.dev_server_id as i32);
2456 let access_token = auth::random_token();
2457 let hashed_access_token = auth::hash_access_token(&access_token);
2458
2459 let connection_id = session
2460 .connection_pool()
2461 .await
2462 .dev_server_connection_id(dev_server_id);
2463 if let Some(connection_id) = connection_id {
2464 shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2465 session.peer.send(
2466 connection_id,
2467 proto::ShutdownDevServer {
2468 reason: Some("dev server token was regenerated".to_string()),
2469 },
2470 )?;
2471 let _ = remove_dev_server_connection(dev_server_id, &session).await;
2472 }
2473
2474 let status = session
2475 .db()
2476 .await
2477 .update_dev_server_token(dev_server_id, &hashed_access_token, session.user_id())
2478 .await?;
2479
2480 send_dev_server_projects_update(session.user_id(), status, &session).await;
2481
2482 response.send(proto::RegenerateDevServerTokenResponse {
2483 dev_server_id: dev_server_id.to_proto(),
2484 access_token: auth::generate_dev_server_token(dev_server_id.0 as usize, access_token),
2485 })?;
2486 Ok(())
2487}
2488
2489async fn rename_dev_server(
2490 request: proto::RenameDevServer,
2491 response: Response<proto::RenameDevServer>,
2492 session: UserSession,
2493) -> Result<()> {
2494 if request.name.trim().is_empty() {
2495 return Err(proto::ErrorCode::Forbidden
2496 .message("Dev server name cannot be empty".to_string())
2497 .anyhow())?;
2498 }
2499
2500 let dev_server_id = DevServerId(request.dev_server_id as i32);
2501 let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2502 if dev_server.user_id != session.user_id() {
2503 return Err(anyhow!(ErrorCode::Forbidden))?;
2504 }
2505
2506 let status = session
2507 .db()
2508 .await
2509 .rename_dev_server(
2510 dev_server_id,
2511 &request.name,
2512 request.ssh_connection_string.as_deref(),
2513 session.user_id(),
2514 )
2515 .await?;
2516
2517 send_dev_server_projects_update(session.user_id(), status, &session).await;
2518
2519 response.send(proto::Ack {})?;
2520 Ok(())
2521}
2522
2523async fn delete_dev_server(
2524 request: proto::DeleteDevServer,
2525 response: Response<proto::DeleteDevServer>,
2526 session: UserSession,
2527) -> Result<()> {
2528 let dev_server_id = DevServerId(request.dev_server_id as i32);
2529 let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2530 if dev_server.user_id != session.user_id() {
2531 return Err(anyhow!(ErrorCode::Forbidden))?;
2532 }
2533
2534 let connection_id = session
2535 .connection_pool()
2536 .await
2537 .dev_server_connection_id(dev_server_id);
2538 if let Some(connection_id) = connection_id {
2539 shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2540 session.peer.send(
2541 connection_id,
2542 proto::ShutdownDevServer {
2543 reason: Some("dev server was deleted".to_string()),
2544 },
2545 )?;
2546 let _ = remove_dev_server_connection(dev_server_id, &session).await;
2547 }
2548
2549 let status = session
2550 .db()
2551 .await
2552 .delete_dev_server(dev_server_id, session.user_id())
2553 .await?;
2554
2555 send_dev_server_projects_update(session.user_id(), status, &session).await;
2556
2557 response.send(proto::Ack {})?;
2558 Ok(())
2559}
2560
2561async fn delete_dev_server_project(
2562 request: proto::DeleteDevServerProject,
2563 response: Response<proto::DeleteDevServerProject>,
2564 session: UserSession,
2565) -> Result<()> {
2566 let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2567 let dev_server_project = session
2568 .db()
2569 .await
2570 .get_dev_server_project(dev_server_project_id)
2571 .await?;
2572
2573 let dev_server = session
2574 .db()
2575 .await
2576 .get_dev_server(dev_server_project.dev_server_id)
2577 .await?;
2578 if dev_server.user_id != session.user_id() {
2579 return Err(anyhow!(ErrorCode::Forbidden))?;
2580 }
2581
2582 let dev_server_connection_id = session
2583 .connection_pool()
2584 .await
2585 .dev_server_connection_id(dev_server.id);
2586
2587 if let Some(dev_server_connection_id) = dev_server_connection_id {
2588 let project = session
2589 .db()
2590 .await
2591 .find_dev_server_project(dev_server_project_id)
2592 .await;
2593 if let Ok(project) = project {
2594 unshare_project_internal(
2595 project.id,
2596 dev_server_connection_id,
2597 Some(session.user_id()),
2598 &session,
2599 )
2600 .await?;
2601 }
2602 }
2603
2604 let (projects, status) = session
2605 .db()
2606 .await
2607 .delete_dev_server_project(dev_server_project_id, dev_server.id, session.user_id())
2608 .await?;
2609
2610 if let Some(dev_server_connection_id) = dev_server_connection_id {
2611 session.peer.send(
2612 dev_server_connection_id,
2613 proto::DevServerInstructions { projects },
2614 )?;
2615 }
2616
2617 send_dev_server_projects_update(session.user_id(), status, &session).await;
2618
2619 response.send(proto::Ack {})?;
2620 Ok(())
2621}
2622
2623async fn rejoin_dev_server_projects(
2624 request: proto::RejoinRemoteProjects,
2625 response: Response<proto::RejoinRemoteProjects>,
2626 session: UserSession,
2627) -> Result<()> {
2628 let mut rejoined_projects = {
2629 let db = session.db().await;
2630 db.rejoin_dev_server_projects(
2631 &request.rejoined_projects,
2632 session.user_id(),
2633 session.0.connection_id,
2634 )
2635 .await?
2636 };
2637 response.send(proto::RejoinRemoteProjectsResponse {
2638 rejoined_projects: rejoined_projects
2639 .iter()
2640 .map(|project| project.to_proto())
2641 .collect(),
2642 })?;
2643 notify_rejoined_projects(&mut rejoined_projects, &session)
2644}
2645
2646async fn reconnect_dev_server(
2647 request: proto::ReconnectDevServer,
2648 response: Response<proto::ReconnectDevServer>,
2649 session: DevServerSession,
2650) -> Result<()> {
2651 let reshared_projects = {
2652 let db = session.db().await;
2653 db.reshare_dev_server_projects(
2654 &request.reshared_projects,
2655 session.dev_server_id(),
2656 session.0.connection_id,
2657 )
2658 .await?
2659 };
2660
2661 for project in &reshared_projects {
2662 for collaborator in &project.collaborators {
2663 session
2664 .peer
2665 .send(
2666 collaborator.connection_id,
2667 proto::UpdateProjectCollaborator {
2668 project_id: project.id.to_proto(),
2669 old_peer_id: Some(project.old_connection_id.into()),
2670 new_peer_id: Some(session.connection_id.into()),
2671 },
2672 )
2673 .trace_err();
2674 }
2675
2676 broadcast(
2677 Some(session.connection_id),
2678 project
2679 .collaborators
2680 .iter()
2681 .map(|collaborator| collaborator.connection_id),
2682 |connection_id| {
2683 session.peer.forward_send(
2684 session.connection_id,
2685 connection_id,
2686 proto::UpdateProject {
2687 project_id: project.id.to_proto(),
2688 worktrees: project.worktrees.clone(),
2689 },
2690 )
2691 },
2692 );
2693 }
2694
2695 response.send(proto::ReconnectDevServerResponse {
2696 reshared_projects: reshared_projects
2697 .iter()
2698 .map(|project| proto::ResharedProject {
2699 id: project.id.to_proto(),
2700 collaborators: project
2701 .collaborators
2702 .iter()
2703 .map(|collaborator| collaborator.to_proto())
2704 .collect(),
2705 })
2706 .collect(),
2707 })?;
2708
2709 Ok(())
2710}
2711
2712async fn shutdown_dev_server(
2713 _: proto::ShutdownDevServer,
2714 response: Response<proto::ShutdownDevServer>,
2715 session: DevServerSession,
2716) -> Result<()> {
2717 response.send(proto::Ack {})?;
2718 shutdown_dev_server_internal(session.dev_server_id(), session.connection_id, &session).await?;
2719 remove_dev_server_connection(session.dev_server_id(), &session).await
2720}
2721
2722async fn shutdown_dev_server_internal(
2723 dev_server_id: DevServerId,
2724 connection_id: ConnectionId,
2725 session: &Session,
2726) -> Result<()> {
2727 let (dev_server_projects, dev_server) = {
2728 let db = session.db().await;
2729 let dev_server_projects = db.get_projects_for_dev_server(dev_server_id).await?;
2730 let dev_server = db.get_dev_server(dev_server_id).await?;
2731 (dev_server_projects, dev_server)
2732 };
2733
2734 for project_id in dev_server_projects.iter().filter_map(|p| p.project_id) {
2735 unshare_project_internal(
2736 ProjectId::from_proto(project_id),
2737 connection_id,
2738 None,
2739 session,
2740 )
2741 .await?;
2742 }
2743
2744 session
2745 .connection_pool()
2746 .await
2747 .set_dev_server_offline(dev_server_id);
2748
2749 let status = session
2750 .db()
2751 .await
2752 .dev_server_projects_update(dev_server.user_id)
2753 .await?;
2754 send_dev_server_projects_update(dev_server.user_id, status, &session).await;
2755
2756 Ok(())
2757}
2758
2759async fn remove_dev_server_connection(dev_server_id: DevServerId, session: &Session) -> Result<()> {
2760 let dev_server_connection = session
2761 .connection_pool()
2762 .await
2763 .dev_server_connection_id(dev_server_id);
2764
2765 if let Some(dev_server_connection) = dev_server_connection {
2766 session
2767 .connection_pool()
2768 .await
2769 .remove_connection(dev_server_connection)?;
2770 }
2771 Ok(())
2772}
2773
2774/// Updates other participants with changes to the project
2775async fn update_project(
2776 request: proto::UpdateProject,
2777 response: Response<proto::UpdateProject>,
2778 session: Session,
2779) -> Result<()> {
2780 let project_id = ProjectId::from_proto(request.project_id);
2781 let (room, guest_connection_ids) = &*session
2782 .db()
2783 .await
2784 .update_project(project_id, session.connection_id, &request.worktrees)
2785 .await?;
2786 broadcast(
2787 Some(session.connection_id),
2788 guest_connection_ids.iter().copied(),
2789 |connection_id| {
2790 session
2791 .peer
2792 .forward_send(session.connection_id, connection_id, request.clone())
2793 },
2794 );
2795 if let Some(room) = room {
2796 room_updated(&room, &session.peer);
2797 }
2798 response.send(proto::Ack {})?;
2799
2800 Ok(())
2801}
2802
2803/// Updates other participants with changes to the worktree
2804async fn update_worktree(
2805 request: proto::UpdateWorktree,
2806 response: Response<proto::UpdateWorktree>,
2807 session: Session,
2808) -> Result<()> {
2809 let guest_connection_ids = session
2810 .db()
2811 .await
2812 .update_worktree(&request, session.connection_id)
2813 .await?;
2814
2815 broadcast(
2816 Some(session.connection_id),
2817 guest_connection_ids.iter().copied(),
2818 |connection_id| {
2819 session
2820 .peer
2821 .forward_send(session.connection_id, connection_id, request.clone())
2822 },
2823 );
2824 response.send(proto::Ack {})?;
2825 Ok(())
2826}
2827
2828/// Updates other participants with changes to the diagnostics
2829async fn update_diagnostic_summary(
2830 message: proto::UpdateDiagnosticSummary,
2831 session: Session,
2832) -> Result<()> {
2833 let guest_connection_ids = session
2834 .db()
2835 .await
2836 .update_diagnostic_summary(&message, session.connection_id)
2837 .await?;
2838
2839 broadcast(
2840 Some(session.connection_id),
2841 guest_connection_ids.iter().copied(),
2842 |connection_id| {
2843 session
2844 .peer
2845 .forward_send(session.connection_id, connection_id, message.clone())
2846 },
2847 );
2848
2849 Ok(())
2850}
2851
2852/// Updates other participants with changes to the worktree settings
2853async fn update_worktree_settings(
2854 message: proto::UpdateWorktreeSettings,
2855 session: Session,
2856) -> Result<()> {
2857 let guest_connection_ids = session
2858 .db()
2859 .await
2860 .update_worktree_settings(&message, session.connection_id)
2861 .await?;
2862
2863 broadcast(
2864 Some(session.connection_id),
2865 guest_connection_ids.iter().copied(),
2866 |connection_id| {
2867 session
2868 .peer
2869 .forward_send(session.connection_id, connection_id, message.clone())
2870 },
2871 );
2872
2873 Ok(())
2874}
2875
2876/// Notify other participants that a language server has started.
2877async fn start_language_server(
2878 request: proto::StartLanguageServer,
2879 session: Session,
2880) -> Result<()> {
2881 let guest_connection_ids = session
2882 .db()
2883 .await
2884 .start_language_server(&request, session.connection_id)
2885 .await?;
2886
2887 broadcast(
2888 Some(session.connection_id),
2889 guest_connection_ids.iter().copied(),
2890 |connection_id| {
2891 session
2892 .peer
2893 .forward_send(session.connection_id, connection_id, request.clone())
2894 },
2895 );
2896 Ok(())
2897}
2898
2899/// Notify other participants that a language server has changed.
2900async fn update_language_server(
2901 request: proto::UpdateLanguageServer,
2902 session: Session,
2903) -> Result<()> {
2904 let project_id = ProjectId::from_proto(request.project_id);
2905 let project_connection_ids = session
2906 .db()
2907 .await
2908 .project_connection_ids(project_id, session.connection_id, true)
2909 .await?;
2910 broadcast(
2911 Some(session.connection_id),
2912 project_connection_ids.iter().copied(),
2913 |connection_id| {
2914 session
2915 .peer
2916 .forward_send(session.connection_id, connection_id, request.clone())
2917 },
2918 );
2919 Ok(())
2920}
2921
2922/// forward a project request to the host. These requests should be read only
2923/// as guests are allowed to send them.
2924async fn forward_read_only_project_request<T>(
2925 request: T,
2926 response: Response<T>,
2927 session: UserSession,
2928) -> Result<()>
2929where
2930 T: EntityMessage + RequestMessage,
2931{
2932 let project_id = ProjectId::from_proto(request.remote_entity_id());
2933 let host_connection_id = session
2934 .db()
2935 .await
2936 .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2937 .await?;
2938 let payload = session
2939 .peer
2940 .forward_request(session.connection_id, host_connection_id, request)
2941 .await?;
2942 response.send(payload)?;
2943 Ok(())
2944}
2945
2946/// forward a project request to the dev server. Only allowed
2947/// if it's your dev server.
2948async fn forward_project_request_for_owner<T>(
2949 request: T,
2950 response: Response<T>,
2951 session: UserSession,
2952) -> Result<()>
2953where
2954 T: EntityMessage + RequestMessage,
2955{
2956 let project_id = ProjectId::from_proto(request.remote_entity_id());
2957
2958 let host_connection_id = session
2959 .db()
2960 .await
2961 .host_for_owner_project_request(project_id, session.connection_id, session.user_id())
2962 .await?;
2963 let payload = session
2964 .peer
2965 .forward_request(session.connection_id, host_connection_id, request)
2966 .await?;
2967 response.send(payload)?;
2968 Ok(())
2969}
2970
2971/// forward a project request to the host. These requests are disallowed
2972/// for guests.
2973async fn forward_mutating_project_request<T>(
2974 request: T,
2975 response: Response<T>,
2976 session: UserSession,
2977) -> Result<()>
2978where
2979 T: EntityMessage + RequestMessage,
2980{
2981 let project_id = ProjectId::from_proto(request.remote_entity_id());
2982
2983 let host_connection_id = session
2984 .db()
2985 .await
2986 .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
2987 .await?;
2988 let payload = session
2989 .peer
2990 .forward_request(session.connection_id, host_connection_id, request)
2991 .await?;
2992 response.send(payload)?;
2993 Ok(())
2994}
2995
2996/// forward a project request to the host. These requests are disallowed
2997/// for guests.
2998async fn forward_versioned_mutating_project_request<T>(
2999 request: T,
3000 response: Response<T>,
3001 session: UserSession,
3002) -> Result<()>
3003where
3004 T: EntityMessage + RequestMessage + VersionedMessage,
3005{
3006 let project_id = ProjectId::from_proto(request.remote_entity_id());
3007
3008 let host_connection_id = session
3009 .db()
3010 .await
3011 .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
3012 .await?;
3013 if let Some(host_version) = session
3014 .connection_pool()
3015 .await
3016 .connection(host_connection_id)
3017 .map(|c| c.zed_version)
3018 {
3019 if let Some(min_required_version) = request.required_host_version() {
3020 if min_required_version > host_version {
3021 return Err(anyhow!(ErrorCode::RemoteUpgradeRequired
3022 .with_tag("required", &min_required_version.to_string())))?;
3023 }
3024 }
3025 }
3026
3027 let payload = session
3028 .peer
3029 .forward_request(session.connection_id, host_connection_id, request)
3030 .await?;
3031 response.send(payload)?;
3032 Ok(())
3033}
3034
3035/// Notify other participants that a new buffer has been created
3036async fn create_buffer_for_peer(
3037 request: proto::CreateBufferForPeer,
3038 session: Session,
3039) -> Result<()> {
3040 session
3041 .db()
3042 .await
3043 .check_user_is_project_host(
3044 ProjectId::from_proto(request.project_id),
3045 session.connection_id,
3046 )
3047 .await?;
3048 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
3049 session
3050 .peer
3051 .forward_send(session.connection_id, peer_id.into(), request)?;
3052 Ok(())
3053}
3054
3055/// Notify other participants that a buffer has been updated. This is
3056/// allowed for guests as long as the update is limited to selections.
3057async fn update_buffer(
3058 request: proto::UpdateBuffer,
3059 response: Response<proto::UpdateBuffer>,
3060 session: Session,
3061) -> Result<()> {
3062 let project_id = ProjectId::from_proto(request.project_id);
3063 let mut capability = Capability::ReadOnly;
3064
3065 for op in request.operations.iter() {
3066 match op.variant {
3067 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
3068 Some(_) => capability = Capability::ReadWrite,
3069 }
3070 }
3071
3072 let host = {
3073 let guard = session
3074 .db()
3075 .await
3076 .connections_for_buffer_update(
3077 project_id,
3078 session.principal_id(),
3079 session.connection_id,
3080 capability,
3081 )
3082 .await?;
3083
3084 let (host, guests) = &*guard;
3085
3086 broadcast(
3087 Some(session.connection_id),
3088 guests.clone(),
3089 |connection_id| {
3090 session
3091 .peer
3092 .forward_send(session.connection_id, connection_id, request.clone())
3093 },
3094 );
3095
3096 *host
3097 };
3098
3099 if host != session.connection_id {
3100 session
3101 .peer
3102 .forward_request(session.connection_id, host, request.clone())
3103 .await?;
3104 }
3105
3106 response.send(proto::Ack {})?;
3107 Ok(())
3108}
3109
3110async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
3111 let project_id = ProjectId::from_proto(message.project_id);
3112
3113 let operation = message.operation.as_ref().context("invalid operation")?;
3114 let capability = match operation.variant.as_ref() {
3115 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
3116 if let Some(buffer_op) = buffer_op.operation.as_ref() {
3117 match buffer_op.variant {
3118 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
3119 Capability::ReadOnly
3120 }
3121 _ => Capability::ReadWrite,
3122 }
3123 } else {
3124 Capability::ReadWrite
3125 }
3126 }
3127 Some(_) => Capability::ReadWrite,
3128 None => Capability::ReadOnly,
3129 };
3130
3131 let guard = session
3132 .db()
3133 .await
3134 .connections_for_buffer_update(
3135 project_id,
3136 session.principal_id(),
3137 session.connection_id,
3138 capability,
3139 )
3140 .await?;
3141
3142 let (host, guests) = &*guard;
3143
3144 broadcast(
3145 Some(session.connection_id),
3146 guests.iter().chain([host]).copied(),
3147 |connection_id| {
3148 session
3149 .peer
3150 .forward_send(session.connection_id, connection_id, message.clone())
3151 },
3152 );
3153
3154 Ok(())
3155}
3156
3157/// Notify other participants that a project has been updated.
3158async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
3159 request: T,
3160 session: Session,
3161) -> Result<()> {
3162 let project_id = ProjectId::from_proto(request.remote_entity_id());
3163 let project_connection_ids = session
3164 .db()
3165 .await
3166 .project_connection_ids(project_id, session.connection_id, false)
3167 .await?;
3168
3169 broadcast(
3170 Some(session.connection_id),
3171 project_connection_ids.iter().copied(),
3172 |connection_id| {
3173 session
3174 .peer
3175 .forward_send(session.connection_id, connection_id, request.clone())
3176 },
3177 );
3178 Ok(())
3179}
3180
3181/// Start following another user in a call.
3182async fn follow(
3183 request: proto::Follow,
3184 response: Response<proto::Follow>,
3185 session: UserSession,
3186) -> Result<()> {
3187 let room_id = RoomId::from_proto(request.room_id);
3188 let project_id = request.project_id.map(ProjectId::from_proto);
3189 let leader_id = request
3190 .leader_id
3191 .ok_or_else(|| anyhow!("invalid leader id"))?
3192 .into();
3193 let follower_id = session.connection_id;
3194
3195 session
3196 .db()
3197 .await
3198 .check_room_participants(room_id, leader_id, session.connection_id)
3199 .await?;
3200
3201 let response_payload = session
3202 .peer
3203 .forward_request(session.connection_id, leader_id, request)
3204 .await?;
3205 response.send(response_payload)?;
3206
3207 if let Some(project_id) = project_id {
3208 let room = session
3209 .db()
3210 .await
3211 .follow(room_id, project_id, leader_id, follower_id)
3212 .await?;
3213 room_updated(&room, &session.peer);
3214 }
3215
3216 Ok(())
3217}
3218
3219/// Stop following another user in a call.
3220async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
3221 let room_id = RoomId::from_proto(request.room_id);
3222 let project_id = request.project_id.map(ProjectId::from_proto);
3223 let leader_id = request
3224 .leader_id
3225 .ok_or_else(|| anyhow!("invalid leader id"))?
3226 .into();
3227 let follower_id = session.connection_id;
3228
3229 session
3230 .db()
3231 .await
3232 .check_room_participants(room_id, leader_id, session.connection_id)
3233 .await?;
3234
3235 session
3236 .peer
3237 .forward_send(session.connection_id, leader_id, request)?;
3238
3239 if let Some(project_id) = project_id {
3240 let room = session
3241 .db()
3242 .await
3243 .unfollow(room_id, project_id, leader_id, follower_id)
3244 .await?;
3245 room_updated(&room, &session.peer);
3246 }
3247
3248 Ok(())
3249}
3250
3251/// Notify everyone following you of your current location.
3252async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
3253 let room_id = RoomId::from_proto(request.room_id);
3254 let database = session.db.lock().await;
3255
3256 let connection_ids = if let Some(project_id) = request.project_id {
3257 let project_id = ProjectId::from_proto(project_id);
3258 database
3259 .project_connection_ids(project_id, session.connection_id, true)
3260 .await?
3261 } else {
3262 database
3263 .room_connection_ids(room_id, session.connection_id)
3264 .await?
3265 };
3266
3267 // For now, don't send view update messages back to that view's current leader.
3268 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
3269 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
3270 _ => None,
3271 });
3272
3273 for connection_id in connection_ids.iter().cloned() {
3274 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
3275 session
3276 .peer
3277 .forward_send(session.connection_id, connection_id, request.clone())?;
3278 }
3279 }
3280 Ok(())
3281}
3282
3283/// Get public data about users.
3284async fn get_users(
3285 request: proto::GetUsers,
3286 response: Response<proto::GetUsers>,
3287 session: Session,
3288) -> Result<()> {
3289 let user_ids = request
3290 .user_ids
3291 .into_iter()
3292 .map(UserId::from_proto)
3293 .collect();
3294 let users = session
3295 .db()
3296 .await
3297 .get_users_by_ids(user_ids)
3298 .await?
3299 .into_iter()
3300 .map(|user| proto::User {
3301 id: user.id.to_proto(),
3302 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3303 github_login: user.github_login,
3304 })
3305 .collect();
3306 response.send(proto::UsersResponse { users })?;
3307 Ok(())
3308}
3309
3310/// Search for users (to invite) buy Github login
3311async fn fuzzy_search_users(
3312 request: proto::FuzzySearchUsers,
3313 response: Response<proto::FuzzySearchUsers>,
3314 session: UserSession,
3315) -> Result<()> {
3316 let query = request.query;
3317 let users = match query.len() {
3318 0 => vec![],
3319 1 | 2 => session
3320 .db()
3321 .await
3322 .get_user_by_github_login(&query)
3323 .await?
3324 .into_iter()
3325 .collect(),
3326 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
3327 };
3328 let users = users
3329 .into_iter()
3330 .filter(|user| user.id != session.user_id())
3331 .map(|user| proto::User {
3332 id: user.id.to_proto(),
3333 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3334 github_login: user.github_login,
3335 })
3336 .collect();
3337 response.send(proto::UsersResponse { users })?;
3338 Ok(())
3339}
3340
3341/// Send a contact request to another user.
3342async fn request_contact(
3343 request: proto::RequestContact,
3344 response: Response<proto::RequestContact>,
3345 session: UserSession,
3346) -> Result<()> {
3347 let requester_id = session.user_id();
3348 let responder_id = UserId::from_proto(request.responder_id);
3349 if requester_id == responder_id {
3350 return Err(anyhow!("cannot add yourself as a contact"))?;
3351 }
3352
3353 let notifications = session
3354 .db()
3355 .await
3356 .send_contact_request(requester_id, responder_id)
3357 .await?;
3358
3359 // Update outgoing contact requests of requester
3360 let mut update = proto::UpdateContacts::default();
3361 update.outgoing_requests.push(responder_id.to_proto());
3362 for connection_id in session
3363 .connection_pool()
3364 .await
3365 .user_connection_ids(requester_id)
3366 {
3367 session.peer.send(connection_id, update.clone())?;
3368 }
3369
3370 // Update incoming contact requests of responder
3371 let mut update = proto::UpdateContacts::default();
3372 update
3373 .incoming_requests
3374 .push(proto::IncomingContactRequest {
3375 requester_id: requester_id.to_proto(),
3376 });
3377 let connection_pool = session.connection_pool().await;
3378 for connection_id in connection_pool.user_connection_ids(responder_id) {
3379 session.peer.send(connection_id, update.clone())?;
3380 }
3381
3382 send_notifications(&connection_pool, &session.peer, notifications);
3383
3384 response.send(proto::Ack {})?;
3385 Ok(())
3386}
3387
3388/// Accept or decline a contact request
3389async fn respond_to_contact_request(
3390 request: proto::RespondToContactRequest,
3391 response: Response<proto::RespondToContactRequest>,
3392 session: UserSession,
3393) -> Result<()> {
3394 let responder_id = session.user_id();
3395 let requester_id = UserId::from_proto(request.requester_id);
3396 let db = session.db().await;
3397 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
3398 db.dismiss_contact_notification(responder_id, requester_id)
3399 .await?;
3400 } else {
3401 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
3402
3403 let notifications = db
3404 .respond_to_contact_request(responder_id, requester_id, accept)
3405 .await?;
3406 let requester_busy = db.is_user_busy(requester_id).await?;
3407 let responder_busy = db.is_user_busy(responder_id).await?;
3408
3409 let pool = session.connection_pool().await;
3410 // Update responder with new contact
3411 let mut update = proto::UpdateContacts::default();
3412 if accept {
3413 update
3414 .contacts
3415 .push(contact_for_user(requester_id, requester_busy, &pool));
3416 }
3417 update
3418 .remove_incoming_requests
3419 .push(requester_id.to_proto());
3420 for connection_id in pool.user_connection_ids(responder_id) {
3421 session.peer.send(connection_id, update.clone())?;
3422 }
3423
3424 // Update requester with new contact
3425 let mut update = proto::UpdateContacts::default();
3426 if accept {
3427 update
3428 .contacts
3429 .push(contact_for_user(responder_id, responder_busy, &pool));
3430 }
3431 update
3432 .remove_outgoing_requests
3433 .push(responder_id.to_proto());
3434
3435 for connection_id in pool.user_connection_ids(requester_id) {
3436 session.peer.send(connection_id, update.clone())?;
3437 }
3438
3439 send_notifications(&pool, &session.peer, notifications);
3440 }
3441
3442 response.send(proto::Ack {})?;
3443 Ok(())
3444}
3445
3446/// Remove a contact.
3447async fn remove_contact(
3448 request: proto::RemoveContact,
3449 response: Response<proto::RemoveContact>,
3450 session: UserSession,
3451) -> Result<()> {
3452 let requester_id = session.user_id();
3453 let responder_id = UserId::from_proto(request.user_id);
3454 let db = session.db().await;
3455 let (contact_accepted, deleted_notification_id) =
3456 db.remove_contact(requester_id, responder_id).await?;
3457
3458 let pool = session.connection_pool().await;
3459 // Update outgoing contact requests of requester
3460 let mut update = proto::UpdateContacts::default();
3461 if contact_accepted {
3462 update.remove_contacts.push(responder_id.to_proto());
3463 } else {
3464 update
3465 .remove_outgoing_requests
3466 .push(responder_id.to_proto());
3467 }
3468 for connection_id in pool.user_connection_ids(requester_id) {
3469 session.peer.send(connection_id, update.clone())?;
3470 }
3471
3472 // Update incoming contact requests of responder
3473 let mut update = proto::UpdateContacts::default();
3474 if contact_accepted {
3475 update.remove_contacts.push(requester_id.to_proto());
3476 } else {
3477 update
3478 .remove_incoming_requests
3479 .push(requester_id.to_proto());
3480 }
3481 for connection_id in pool.user_connection_ids(responder_id) {
3482 session.peer.send(connection_id, update.clone())?;
3483 if let Some(notification_id) = deleted_notification_id {
3484 session.peer.send(
3485 connection_id,
3486 proto::DeleteNotification {
3487 notification_id: notification_id.to_proto(),
3488 },
3489 )?;
3490 }
3491 }
3492
3493 response.send(proto::Ack {})?;
3494 Ok(())
3495}
3496
3497fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
3498 version.0.minor() < 139
3499}
3500
3501async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
3502 let plan = session.current_plan(session.db().await).await?;
3503
3504 session
3505 .peer
3506 .send(
3507 session.connection_id,
3508 proto::UpdateUserPlan { plan: plan.into() },
3509 )
3510 .trace_err();
3511
3512 Ok(())
3513}
3514
3515async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
3516 subscribe_user_to_channels(
3517 session.user_id().ok_or_else(|| anyhow!("must be a user"))?,
3518 &session,
3519 )
3520 .await?;
3521 Ok(())
3522}
3523
3524async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
3525 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3526 let mut pool = session.connection_pool().await;
3527 for membership in &channels_for_user.channel_memberships {
3528 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3529 }
3530 session.peer.send(
3531 session.connection_id,
3532 build_update_user_channels(&channels_for_user),
3533 )?;
3534 session.peer.send(
3535 session.connection_id,
3536 build_channels_update(channels_for_user),
3537 )?;
3538 Ok(())
3539}
3540
3541/// Creates a new channel.
3542async fn create_channel(
3543 request: proto::CreateChannel,
3544 response: Response<proto::CreateChannel>,
3545 session: UserSession,
3546) -> Result<()> {
3547 let db = session.db().await;
3548
3549 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
3550 let (channel, membership) = db
3551 .create_channel(&request.name, parent_id, session.user_id())
3552 .await?;
3553
3554 let root_id = channel.root_id();
3555 let channel = Channel::from_model(channel);
3556
3557 response.send(proto::CreateChannelResponse {
3558 channel: Some(channel.to_proto()),
3559 parent_id: request.parent_id,
3560 })?;
3561
3562 let mut connection_pool = session.connection_pool().await;
3563 if let Some(membership) = membership {
3564 connection_pool.subscribe_to_channel(
3565 membership.user_id,
3566 membership.channel_id,
3567 membership.role,
3568 );
3569 let update = proto::UpdateUserChannels {
3570 channel_memberships: vec![proto::ChannelMembership {
3571 channel_id: membership.channel_id.to_proto(),
3572 role: membership.role.into(),
3573 }],
3574 ..Default::default()
3575 };
3576 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3577 session.peer.send(connection_id, update.clone())?;
3578 }
3579 }
3580
3581 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3582 if !role.can_see_channel(channel.visibility) {
3583 continue;
3584 }
3585
3586 let update = proto::UpdateChannels {
3587 channels: vec![channel.to_proto()],
3588 ..Default::default()
3589 };
3590 session.peer.send(connection_id, update.clone())?;
3591 }
3592
3593 Ok(())
3594}
3595
3596/// Delete a channel
3597async fn delete_channel(
3598 request: proto::DeleteChannel,
3599 response: Response<proto::DeleteChannel>,
3600 session: UserSession,
3601) -> Result<()> {
3602 let db = session.db().await;
3603
3604 let channel_id = request.channel_id;
3605 let (root_channel, removed_channels) = db
3606 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3607 .await?;
3608 response.send(proto::Ack {})?;
3609
3610 // Notify members of removed channels
3611 let mut update = proto::UpdateChannels::default();
3612 update
3613 .delete_channels
3614 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3615
3616 let connection_pool = session.connection_pool().await;
3617 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3618 session.peer.send(connection_id, update.clone())?;
3619 }
3620
3621 Ok(())
3622}
3623
3624/// Invite someone to join a channel.
3625async fn invite_channel_member(
3626 request: proto::InviteChannelMember,
3627 response: Response<proto::InviteChannelMember>,
3628 session: UserSession,
3629) -> Result<()> {
3630 let db = session.db().await;
3631 let channel_id = ChannelId::from_proto(request.channel_id);
3632 let invitee_id = UserId::from_proto(request.user_id);
3633 let InviteMemberResult {
3634 channel,
3635 notifications,
3636 } = db
3637 .invite_channel_member(
3638 channel_id,
3639 invitee_id,
3640 session.user_id(),
3641 request.role().into(),
3642 )
3643 .await?;
3644
3645 let update = proto::UpdateChannels {
3646 channel_invitations: vec![channel.to_proto()],
3647 ..Default::default()
3648 };
3649
3650 let connection_pool = session.connection_pool().await;
3651 for connection_id in connection_pool.user_connection_ids(invitee_id) {
3652 session.peer.send(connection_id, update.clone())?;
3653 }
3654
3655 send_notifications(&connection_pool, &session.peer, notifications);
3656
3657 response.send(proto::Ack {})?;
3658 Ok(())
3659}
3660
3661/// remove someone from a channel
3662async fn remove_channel_member(
3663 request: proto::RemoveChannelMember,
3664 response: Response<proto::RemoveChannelMember>,
3665 session: UserSession,
3666) -> Result<()> {
3667 let db = session.db().await;
3668 let channel_id = ChannelId::from_proto(request.channel_id);
3669 let member_id = UserId::from_proto(request.user_id);
3670
3671 let RemoveChannelMemberResult {
3672 membership_update,
3673 notification_id,
3674 } = db
3675 .remove_channel_member(channel_id, member_id, session.user_id())
3676 .await?;
3677
3678 let mut connection_pool = session.connection_pool().await;
3679 notify_membership_updated(
3680 &mut connection_pool,
3681 membership_update,
3682 member_id,
3683 &session.peer,
3684 );
3685 for connection_id in connection_pool.user_connection_ids(member_id) {
3686 if let Some(notification_id) = notification_id {
3687 session
3688 .peer
3689 .send(
3690 connection_id,
3691 proto::DeleteNotification {
3692 notification_id: notification_id.to_proto(),
3693 },
3694 )
3695 .trace_err();
3696 }
3697 }
3698
3699 response.send(proto::Ack {})?;
3700 Ok(())
3701}
3702
3703/// Toggle the channel between public and private.
3704/// Care is taken to maintain the invariant that public channels only descend from public channels,
3705/// (though members-only channels can appear at any point in the hierarchy).
3706async fn set_channel_visibility(
3707 request: proto::SetChannelVisibility,
3708 response: Response<proto::SetChannelVisibility>,
3709 session: UserSession,
3710) -> Result<()> {
3711 let db = session.db().await;
3712 let channel_id = ChannelId::from_proto(request.channel_id);
3713 let visibility = request.visibility().into();
3714
3715 let channel_model = db
3716 .set_channel_visibility(channel_id, visibility, session.user_id())
3717 .await?;
3718 let root_id = channel_model.root_id();
3719 let channel = Channel::from_model(channel_model);
3720
3721 let mut connection_pool = session.connection_pool().await;
3722 for (user_id, role) in connection_pool
3723 .channel_user_ids(root_id)
3724 .collect::<Vec<_>>()
3725 .into_iter()
3726 {
3727 let update = if role.can_see_channel(channel.visibility) {
3728 connection_pool.subscribe_to_channel(user_id, channel_id, role);
3729 proto::UpdateChannels {
3730 channels: vec![channel.to_proto()],
3731 ..Default::default()
3732 }
3733 } else {
3734 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3735 proto::UpdateChannels {
3736 delete_channels: vec![channel.id.to_proto()],
3737 ..Default::default()
3738 }
3739 };
3740
3741 for connection_id in connection_pool.user_connection_ids(user_id) {
3742 session.peer.send(connection_id, update.clone())?;
3743 }
3744 }
3745
3746 response.send(proto::Ack {})?;
3747 Ok(())
3748}
3749
3750/// Alter the role for a user in the channel.
3751async fn set_channel_member_role(
3752 request: proto::SetChannelMemberRole,
3753 response: Response<proto::SetChannelMemberRole>,
3754 session: UserSession,
3755) -> Result<()> {
3756 let db = session.db().await;
3757 let channel_id = ChannelId::from_proto(request.channel_id);
3758 let member_id = UserId::from_proto(request.user_id);
3759 let result = db
3760 .set_channel_member_role(
3761 channel_id,
3762 session.user_id(),
3763 member_id,
3764 request.role().into(),
3765 )
3766 .await?;
3767
3768 match result {
3769 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3770 let mut connection_pool = session.connection_pool().await;
3771 notify_membership_updated(
3772 &mut connection_pool,
3773 membership_update,
3774 member_id,
3775 &session.peer,
3776 )
3777 }
3778 db::SetMemberRoleResult::InviteUpdated(channel) => {
3779 let update = proto::UpdateChannels {
3780 channel_invitations: vec![channel.to_proto()],
3781 ..Default::default()
3782 };
3783
3784 for connection_id in session
3785 .connection_pool()
3786 .await
3787 .user_connection_ids(member_id)
3788 {
3789 session.peer.send(connection_id, update.clone())?;
3790 }
3791 }
3792 }
3793
3794 response.send(proto::Ack {})?;
3795 Ok(())
3796}
3797
3798/// Change the name of a channel
3799async fn rename_channel(
3800 request: proto::RenameChannel,
3801 response: Response<proto::RenameChannel>,
3802 session: UserSession,
3803) -> Result<()> {
3804 let db = session.db().await;
3805 let channel_id = ChannelId::from_proto(request.channel_id);
3806 let channel_model = db
3807 .rename_channel(channel_id, session.user_id(), &request.name)
3808 .await?;
3809 let root_id = channel_model.root_id();
3810 let channel = Channel::from_model(channel_model);
3811
3812 response.send(proto::RenameChannelResponse {
3813 channel: Some(channel.to_proto()),
3814 })?;
3815
3816 let connection_pool = session.connection_pool().await;
3817 let update = proto::UpdateChannels {
3818 channels: vec![channel.to_proto()],
3819 ..Default::default()
3820 };
3821 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3822 if role.can_see_channel(channel.visibility) {
3823 session.peer.send(connection_id, update.clone())?;
3824 }
3825 }
3826
3827 Ok(())
3828}
3829
3830/// Move a channel to a new parent.
3831async fn move_channel(
3832 request: proto::MoveChannel,
3833 response: Response<proto::MoveChannel>,
3834 session: UserSession,
3835) -> Result<()> {
3836 let channel_id = ChannelId::from_proto(request.channel_id);
3837 let to = ChannelId::from_proto(request.to);
3838
3839 let (root_id, channels) = session
3840 .db()
3841 .await
3842 .move_channel(channel_id, to, session.user_id())
3843 .await?;
3844
3845 let connection_pool = session.connection_pool().await;
3846 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3847 let channels = channels
3848 .iter()
3849 .filter_map(|channel| {
3850 if role.can_see_channel(channel.visibility) {
3851 Some(channel.to_proto())
3852 } else {
3853 None
3854 }
3855 })
3856 .collect::<Vec<_>>();
3857 if channels.is_empty() {
3858 continue;
3859 }
3860
3861 let update = proto::UpdateChannels {
3862 channels,
3863 ..Default::default()
3864 };
3865
3866 session.peer.send(connection_id, update.clone())?;
3867 }
3868
3869 response.send(Ack {})?;
3870 Ok(())
3871}
3872
3873/// Get the list of channel members
3874async fn get_channel_members(
3875 request: proto::GetChannelMembers,
3876 response: Response<proto::GetChannelMembers>,
3877 session: UserSession,
3878) -> Result<()> {
3879 let db = session.db().await;
3880 let channel_id = ChannelId::from_proto(request.channel_id);
3881 let limit = if request.limit == 0 {
3882 u16::MAX as u64
3883 } else {
3884 request.limit
3885 };
3886 let (members, users) = db
3887 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3888 .await?;
3889 response.send(proto::GetChannelMembersResponse { members, users })?;
3890 Ok(())
3891}
3892
3893/// Accept or decline a channel invitation.
3894async fn respond_to_channel_invite(
3895 request: proto::RespondToChannelInvite,
3896 response: Response<proto::RespondToChannelInvite>,
3897 session: UserSession,
3898) -> Result<()> {
3899 let db = session.db().await;
3900 let channel_id = ChannelId::from_proto(request.channel_id);
3901 let RespondToChannelInvite {
3902 membership_update,
3903 notifications,
3904 } = db
3905 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3906 .await?;
3907
3908 let mut connection_pool = session.connection_pool().await;
3909 if let Some(membership_update) = membership_update {
3910 notify_membership_updated(
3911 &mut connection_pool,
3912 membership_update,
3913 session.user_id(),
3914 &session.peer,
3915 );
3916 } else {
3917 let update = proto::UpdateChannels {
3918 remove_channel_invitations: vec![channel_id.to_proto()],
3919 ..Default::default()
3920 };
3921
3922 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3923 session.peer.send(connection_id, update.clone())?;
3924 }
3925 };
3926
3927 send_notifications(&connection_pool, &session.peer, notifications);
3928
3929 response.send(proto::Ack {})?;
3930
3931 Ok(())
3932}
3933
3934/// Join the channels' room
3935async fn join_channel(
3936 request: proto::JoinChannel,
3937 response: Response<proto::JoinChannel>,
3938 session: UserSession,
3939) -> Result<()> {
3940 let channel_id = ChannelId::from_proto(request.channel_id);
3941 join_channel_internal(channel_id, Box::new(response), session).await
3942}
3943
/// A response that can carry the result of joining a channel's room.
///
/// Implemented for both `Response<proto::JoinChannel>` and `Response<proto::JoinRoom>`,
/// so `join_channel_internal` can complete either kind of request with the same
/// `proto::JoinRoomResponse` payload.
trait JoinChannelInternalResponse {
    /// Sends `result` as the reply to the originating request.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Names `Response`'s own `send` explicitly, making it clear this delegates
        // to it rather than calling this trait method again.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // See the `JoinChannel` impl: explicit path to `Response`'s own `send`.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3957
/// Shared implementation for joining a channel's room.
///
/// Steps, in order:
/// 1. If `stale_room_connection` reports a leftover connection for this user, leave
///    that room first.
/// 2. Join the channel's room in the database.
/// 3. Build LiveKit connection info (guests get a `guest_token` and cannot
///    publish; everyone else gets a `room_token`).
/// 4. Reply via `response`, pass any membership change to
///    `notify_membership_updated`, and broadcast the room with `room_updated`.
/// 5. After the inner scope ends, broadcast the channel with `channel_updated`
///    and refresh the user's contacts.
///
/// NOTE(review): the "channel not returned" error is raised after the reply has
/// already been sent.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: UserSession,
) -> Result<()> {
    // The database and connection-pool guards acquired inside this block are
    // released at its end, before `channel_updated` re-locks the connection pool.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database guard before calling `leave_room_for_session`
            // (it takes the whole session, presumably to lock the db itself —
            // verify), then re-acquire it.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured. A token error is traced
        // and also yields `None` (via `trace_err()?`) instead of failing the join.
        let live_kit_connection_info =
            session
                .app_state
                .live_kit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.live_kit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.live_kit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
4053
4054/// Start editing the channel notes
4055async fn join_channel_buffer(
4056 request: proto::JoinChannelBuffer,
4057 response: Response<proto::JoinChannelBuffer>,
4058 session: UserSession,
4059) -> Result<()> {
4060 let db = session.db().await;
4061 let channel_id = ChannelId::from_proto(request.channel_id);
4062
4063 let open_response = db
4064 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
4065 .await?;
4066
4067 let collaborators = open_response.collaborators.clone();
4068 response.send(open_response)?;
4069
4070 let update = UpdateChannelBufferCollaborators {
4071 channel_id: channel_id.to_proto(),
4072 collaborators: collaborators.clone(),
4073 };
4074 channel_buffer_updated(
4075 session.connection_id,
4076 collaborators
4077 .iter()
4078 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
4079 &update,
4080 &session.peer,
4081 );
4082
4083 Ok(())
4084}
4085
4086/// Edit the channel notes
4087async fn update_channel_buffer(
4088 request: proto::UpdateChannelBuffer,
4089 session: UserSession,
4090) -> Result<()> {
4091 let db = session.db().await;
4092 let channel_id = ChannelId::from_proto(request.channel_id);
4093
4094 let (collaborators, epoch, version) = db
4095 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
4096 .await?;
4097
4098 channel_buffer_updated(
4099 session.connection_id,
4100 collaborators.clone(),
4101 &proto::UpdateChannelBuffer {
4102 channel_id: channel_id.to_proto(),
4103 operations: request.operations,
4104 },
4105 &session.peer,
4106 );
4107
4108 let pool = &*session.connection_pool().await;
4109
4110 let non_collaborators =
4111 pool.channel_connection_ids(channel_id)
4112 .filter_map(|(connection_id, _)| {
4113 if collaborators.contains(&connection_id) {
4114 None
4115 } else {
4116 Some(connection_id)
4117 }
4118 });
4119
4120 broadcast(None, non_collaborators, |peer_id| {
4121 session.peer.send(
4122 peer_id,
4123 proto::UpdateChannels {
4124 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
4125 channel_id: channel_id.to_proto(),
4126 epoch: epoch as u64,
4127 version: version.clone(),
4128 }],
4129 ..Default::default()
4130 },
4131 )
4132 });
4133
4134 Ok(())
4135}
4136
4137/// Rejoin the channel notes after a connection blip
4138async fn rejoin_channel_buffers(
4139 request: proto::RejoinChannelBuffers,
4140 response: Response<proto::RejoinChannelBuffers>,
4141 session: UserSession,
4142) -> Result<()> {
4143 let db = session.db().await;
4144 let buffers = db
4145 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
4146 .await?;
4147
4148 for rejoined_buffer in &buffers {
4149 let collaborators_to_notify = rejoined_buffer
4150 .buffer
4151 .collaborators
4152 .iter()
4153 .filter_map(|c| Some(c.peer_id?.into()));
4154 channel_buffer_updated(
4155 session.connection_id,
4156 collaborators_to_notify,
4157 &proto::UpdateChannelBufferCollaborators {
4158 channel_id: rejoined_buffer.buffer.channel_id,
4159 collaborators: rejoined_buffer.buffer.collaborators.clone(),
4160 },
4161 &session.peer,
4162 );
4163 }
4164
4165 response.send(proto::RejoinChannelBuffersResponse {
4166 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
4167 })?;
4168
4169 Ok(())
4170}
4171
4172/// Stop editing the channel notes
4173async fn leave_channel_buffer(
4174 request: proto::LeaveChannelBuffer,
4175 response: Response<proto::LeaveChannelBuffer>,
4176 session: UserSession,
4177) -> Result<()> {
4178 let db = session.db().await;
4179 let channel_id = ChannelId::from_proto(request.channel_id);
4180
4181 let left_buffer = db
4182 .leave_channel_buffer(channel_id, session.connection_id)
4183 .await?;
4184
4185 response.send(Ack {})?;
4186
4187 channel_buffer_updated(
4188 session.connection_id,
4189 left_buffer.connections,
4190 &proto::UpdateChannelBufferCollaborators {
4191 channel_id: channel_id.to_proto(),
4192 collaborators: left_buffer.collaborators,
4193 },
4194 &session.peer,
4195 );
4196
4197 Ok(())
4198}
4199
4200fn channel_buffer_updated<T: EnvelopedMessage>(
4201 sender_id: ConnectionId,
4202 collaborators: impl IntoIterator<Item = ConnectionId>,
4203 message: &T,
4204 peer: &Peer,
4205) {
4206 broadcast(Some(sender_id), collaborators, |peer_id| {
4207 peer.send(peer_id, message.clone())
4208 });
4209}
4210
4211fn send_notifications(
4212 connection_pool: &ConnectionPool,
4213 peer: &Peer,
4214 notifications: db::NotificationBatch,
4215) {
4216 for (user_id, notification) in notifications {
4217 for connection_id in connection_pool.user_connection_ids(user_id) {
4218 if let Err(error) = peer.send(
4219 connection_id,
4220 proto::AddNotification {
4221 notification: Some(notification.clone()),
4222 },
4223 ) {
4224 tracing::error!(
4225 "failed to send notification to {:?} {}",
4226 connection_id,
4227 error
4228 );
4229 }
4230 }
4231 }
4232}
4233
4234/// Send a message to the channel
4235async fn send_channel_message(
4236 request: proto::SendChannelMessage,
4237 response: Response<proto::SendChannelMessage>,
4238 session: UserSession,
4239) -> Result<()> {
4240 // Validate the message body.
4241 let body = request.body.trim().to_string();
4242 if body.len() > MAX_MESSAGE_LEN {
4243 return Err(anyhow!("message is too long"))?;
4244 }
4245 if body.is_empty() {
4246 return Err(anyhow!("message can't be blank"))?;
4247 }
4248
4249 // TODO: adjust mentions if body is trimmed
4250
4251 let timestamp = OffsetDateTime::now_utc();
4252 let nonce = request
4253 .nonce
4254 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4255
4256 let channel_id = ChannelId::from_proto(request.channel_id);
4257 let CreatedChannelMessage {
4258 message_id,
4259 participant_connection_ids,
4260 notifications,
4261 } = session
4262 .db()
4263 .await
4264 .create_channel_message(
4265 channel_id,
4266 session.user_id(),
4267 &body,
4268 &request.mentions,
4269 timestamp,
4270 nonce.clone().into(),
4271 match request.reply_to_message_id {
4272 Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
4273 None => None,
4274 },
4275 )
4276 .await?;
4277
4278 let message = proto::ChannelMessage {
4279 sender_id: session.user_id().to_proto(),
4280 id: message_id.to_proto(),
4281 body,
4282 mentions: request.mentions,
4283 timestamp: timestamp.unix_timestamp() as u64,
4284 nonce: Some(nonce),
4285 reply_to_message_id: request.reply_to_message_id,
4286 edited_at: None,
4287 };
4288 broadcast(
4289 Some(session.connection_id),
4290 participant_connection_ids.clone(),
4291 |connection| {
4292 session.peer.send(
4293 connection,
4294 proto::ChannelMessageSent {
4295 channel_id: channel_id.to_proto(),
4296 message: Some(message.clone()),
4297 },
4298 )
4299 },
4300 );
4301 response.send(proto::SendChannelMessageResponse {
4302 message: Some(message),
4303 })?;
4304
4305 let pool = &*session.connection_pool().await;
4306 let non_participants =
4307 pool.channel_connection_ids(channel_id)
4308 .filter_map(|(connection_id, _)| {
4309 if participant_connection_ids.contains(&connection_id) {
4310 None
4311 } else {
4312 Some(connection_id)
4313 }
4314 });
4315 broadcast(None, non_participants, |peer_id| {
4316 session.peer.send(
4317 peer_id,
4318 proto::UpdateChannels {
4319 latest_channel_message_ids: vec![proto::ChannelMessageId {
4320 channel_id: channel_id.to_proto(),
4321 message_id: message_id.to_proto(),
4322 }],
4323 ..Default::default()
4324 },
4325 )
4326 });
4327 send_notifications(pool, &session.peer, notifications);
4328
4329 Ok(())
4330}
4331
4332/// Delete a channel message
4333async fn remove_channel_message(
4334 request: proto::RemoveChannelMessage,
4335 response: Response<proto::RemoveChannelMessage>,
4336 session: UserSession,
4337) -> Result<()> {
4338 let channel_id = ChannelId::from_proto(request.channel_id);
4339 let message_id = MessageId::from_proto(request.message_id);
4340 let (connection_ids, existing_notification_ids) = session
4341 .db()
4342 .await
4343 .remove_channel_message(channel_id, message_id, session.user_id())
4344 .await?;
4345
4346 broadcast(
4347 Some(session.connection_id),
4348 connection_ids,
4349 move |connection| {
4350 session.peer.send(connection, request.clone())?;
4351
4352 for notification_id in &existing_notification_ids {
4353 session.peer.send(
4354 connection,
4355 proto::DeleteNotification {
4356 notification_id: (*notification_id).to_proto(),
4357 },
4358 )?;
4359 }
4360
4361 Ok(())
4362 },
4363 );
4364 response.send(proto::Ack {})?;
4365 Ok(())
4366}
4367
4368async fn update_channel_message(
4369 request: proto::UpdateChannelMessage,
4370 response: Response<proto::UpdateChannelMessage>,
4371 session: UserSession,
4372) -> Result<()> {
4373 let channel_id = ChannelId::from_proto(request.channel_id);
4374 let message_id = MessageId::from_proto(request.message_id);
4375 let updated_at = OffsetDateTime::now_utc();
4376 let UpdatedChannelMessage {
4377 message_id,
4378 participant_connection_ids,
4379 notifications,
4380 reply_to_message_id,
4381 timestamp,
4382 deleted_mention_notification_ids,
4383 updated_mention_notifications,
4384 } = session
4385 .db()
4386 .await
4387 .update_channel_message(
4388 channel_id,
4389 message_id,
4390 session.user_id(),
4391 request.body.as_str(),
4392 &request.mentions,
4393 updated_at,
4394 )
4395 .await?;
4396
4397 let nonce = request
4398 .nonce
4399 .clone()
4400 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4401
4402 let message = proto::ChannelMessage {
4403 sender_id: session.user_id().to_proto(),
4404 id: message_id.to_proto(),
4405 body: request.body.clone(),
4406 mentions: request.mentions.clone(),
4407 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
4408 nonce: Some(nonce),
4409 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
4410 edited_at: Some(updated_at.unix_timestamp() as u64),
4411 };
4412
4413 response.send(proto::Ack {})?;
4414
4415 let pool = &*session.connection_pool().await;
4416 broadcast(
4417 Some(session.connection_id),
4418 participant_connection_ids,
4419 |connection| {
4420 session.peer.send(
4421 connection,
4422 proto::ChannelMessageUpdate {
4423 channel_id: channel_id.to_proto(),
4424 message: Some(message.clone()),
4425 },
4426 )?;
4427
4428 for notification_id in &deleted_mention_notification_ids {
4429 session.peer.send(
4430 connection,
4431 proto::DeleteNotification {
4432 notification_id: (*notification_id).to_proto(),
4433 },
4434 )?;
4435 }
4436
4437 for notification in &updated_mention_notifications {
4438 session.peer.send(
4439 connection,
4440 proto::UpdateNotification {
4441 notification: Some(notification.clone()),
4442 },
4443 )?;
4444 }
4445
4446 Ok(())
4447 },
4448 );
4449
4450 send_notifications(pool, &session.peer, notifications);
4451
4452 Ok(())
4453}
4454
4455/// Mark a channel message as read
4456async fn acknowledge_channel_message(
4457 request: proto::AckChannelMessage,
4458 session: UserSession,
4459) -> Result<()> {
4460 let channel_id = ChannelId::from_proto(request.channel_id);
4461 let message_id = MessageId::from_proto(request.message_id);
4462 let notifications = session
4463 .db()
4464 .await
4465 .observe_channel_message(channel_id, session.user_id(), message_id)
4466 .await?;
4467 send_notifications(
4468 &*session.connection_pool().await,
4469 &session.peer,
4470 notifications,
4471 );
4472 Ok(())
4473}
4474
4475/// Mark a buffer version as synced
4476async fn acknowledge_buffer_version(
4477 request: proto::AckBufferOperation,
4478 session: UserSession,
4479) -> Result<()> {
4480 let buffer_id = BufferId::from_proto(request.buffer_id);
4481 session
4482 .db()
4483 .await
4484 .observe_buffer_version(
4485 buffer_id,
4486 session.user_id(),
4487 request.epoch as i32,
4488 &request.version,
4489 )
4490 .await?;
4491 Ok(())
4492}
4493
4494async fn count_language_model_tokens(
4495 request: proto::CountLanguageModelTokens,
4496 response: Response<proto::CountLanguageModelTokens>,
4497 session: Session,
4498 config: &Config,
4499) -> Result<()> {
4500 let Some(session) = session.for_user() else {
4501 return Err(anyhow!("user not found"))?;
4502 };
4503 authorize_access_to_legacy_llm_endpoints(&session).await?;
4504
4505 let rate_limit: Box<dyn RateLimit> = match session.current_plan(session.db().await).await? {
4506 proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
4507 proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
4508 };
4509
4510 session
4511 .app_state
4512 .rate_limiter
4513 .check(&*rate_limit, session.user_id())
4514 .await?;
4515
4516 let result = match proto::LanguageModelProvider::from_i32(request.provider) {
4517 Some(proto::LanguageModelProvider::Google) => {
4518 let api_key = config
4519 .google_ai_api_key
4520 .as_ref()
4521 .context("no Google AI API key configured on the server")?;
4522 google_ai::count_tokens(
4523 session.http_client.as_ref(),
4524 google_ai::API_URL,
4525 api_key,
4526 serde_json::from_str(&request.request)?,
4527 )
4528 .await?
4529 }
4530 _ => return Err(anyhow!("unsupported provider"))?,
4531 };
4532
4533 response.send(proto::CountLanguageModelTokensResponse {
4534 token_count: result.total_tokens as u32,
4535 })?;
4536
4537 Ok(())
4538}
4539
4540struct ZedProCountLanguageModelTokensRateLimit;
4541
4542impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
4543 fn capacity(&self) -> usize {
4544 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
4545 .ok()
4546 .and_then(|v| v.parse().ok())
4547 .unwrap_or(600) // Picked arbitrarily
4548 }
4549
4550 fn refill_duration(&self) -> chrono::Duration {
4551 chrono::Duration::hours(1)
4552 }
4553
4554 fn db_name(&self) -> &'static str {
4555 "zed-pro:count-language-model-tokens"
4556 }
4557}
4558
4559struct FreeCountLanguageModelTokensRateLimit;
4560
4561impl RateLimit for FreeCountLanguageModelTokensRateLimit {
4562 fn capacity(&self) -> usize {
4563 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
4564 .ok()
4565 .and_then(|v| v.parse().ok())
4566 .unwrap_or(600 / 10) // Picked arbitrarily
4567 }
4568
4569 fn refill_duration(&self) -> chrono::Duration {
4570 chrono::Duration::hours(1)
4571 }
4572
4573 fn db_name(&self) -> &'static str {
4574 "free:count-language-model-tokens"
4575 }
4576}
4577
4578struct ZedProComputeEmbeddingsRateLimit;
4579
4580impl RateLimit for ZedProComputeEmbeddingsRateLimit {
4581 fn capacity(&self) -> usize {
4582 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4583 .ok()
4584 .and_then(|v| v.parse().ok())
4585 .unwrap_or(5000) // Picked arbitrarily
4586 }
4587
4588 fn refill_duration(&self) -> chrono::Duration {
4589 chrono::Duration::hours(1)
4590 }
4591
4592 fn db_name(&self) -> &'static str {
4593 "zed-pro:compute-embeddings"
4594 }
4595}
4596
4597struct FreeComputeEmbeddingsRateLimit;
4598
4599impl RateLimit for FreeComputeEmbeddingsRateLimit {
4600 fn capacity(&self) -> usize {
4601 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
4602 .ok()
4603 .and_then(|v| v.parse().ok())
4604 .unwrap_or(5000 / 10) // Picked arbitrarily
4605 }
4606
4607 fn refill_duration(&self) -> chrono::Duration {
4608 chrono::Duration::hours(1)
4609 }
4610
4611 fn db_name(&self) -> &'static str {
4612 "free:compute-embeddings"
4613 }
4614}
4615
4616async fn compute_embeddings(
4617 request: proto::ComputeEmbeddings,
4618 response: Response<proto::ComputeEmbeddings>,
4619 session: UserSession,
4620 api_key: Option<Arc<str>>,
4621) -> Result<()> {
4622 let api_key = api_key.context("no OpenAI API key configured on the server")?;
4623 authorize_access_to_legacy_llm_endpoints(&session).await?;
4624
4625 let rate_limit: Box<dyn RateLimit> = match session.current_plan(session.db().await).await? {
4626 proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
4627 proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
4628 };
4629
4630 session
4631 .app_state
4632 .rate_limiter
4633 .check(&*rate_limit, session.user_id())
4634 .await?;
4635
4636 let embeddings = match request.model.as_str() {
4637 "openai/text-embedding-3-small" => {
4638 open_ai::embed(
4639 session.http_client.as_ref(),
4640 OPEN_AI_API_URL,
4641 &api_key,
4642 OpenAiEmbeddingModel::TextEmbedding3Small,
4643 request.texts.iter().map(|text| text.as_str()),
4644 )
4645 .await?
4646 }
4647 provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
4648 };
4649
4650 let embeddings = request
4651 .texts
4652 .iter()
4653 .map(|text| {
4654 let mut hasher = sha2::Sha256::new();
4655 hasher.update(text.as_bytes());
4656 let result = hasher.finalize();
4657 result.to_vec()
4658 })
4659 .zip(
4660 embeddings
4661 .data
4662 .into_iter()
4663 .map(|embedding| embedding.embedding),
4664 )
4665 .collect::<HashMap<_, _>>();
4666
4667 let db = session.db().await;
4668 db.save_embeddings(&request.model, &embeddings)
4669 .await
4670 .context("failed to save embeddings")
4671 .trace_err();
4672
4673 response.send(proto::ComputeEmbeddingsResponse {
4674 embeddings: embeddings
4675 .into_iter()
4676 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4677 .collect(),
4678 })?;
4679 Ok(())
4680}
4681
4682async fn get_cached_embeddings(
4683 request: proto::GetCachedEmbeddings,
4684 response: Response<proto::GetCachedEmbeddings>,
4685 session: UserSession,
4686) -> Result<()> {
4687 authorize_access_to_legacy_llm_endpoints(&session).await?;
4688
4689 let db = session.db().await;
4690 let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
4691
4692 response.send(proto::GetCachedEmbeddingsResponse {
4693 embeddings: embeddings
4694 .into_iter()
4695 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4696 .collect(),
4697 })?;
4698 Ok(())
4699}
4700
4701/// This is leftover from before the LLM service.
4702///
4703/// The endpoints protected by this check will be moved there eventually.
4704async fn authorize_access_to_legacy_llm_endpoints(session: &UserSession) -> Result<(), Error> {
4705 if session.is_staff() {
4706 Ok(())
4707 } else {
4708 Err(anyhow!("permission denied"))?
4709 }
4710}
4711
4712/// Get a Supermaven API key for the user
4713async fn get_supermaven_api_key(
4714 _request: proto::GetSupermavenApiKey,
4715 response: Response<proto::GetSupermavenApiKey>,
4716 session: UserSession,
4717) -> Result<()> {
4718 let user_id: String = session.user_id().to_string();
4719 if !session.is_staff() {
4720 return Err(anyhow!("supermaven not enabled for this account"))?;
4721 }
4722
4723 let email = session
4724 .email()
4725 .ok_or_else(|| anyhow!("user must have an email"))?;
4726
4727 let supermaven_admin_api = session
4728 .supermaven_client
4729 .as_ref()
4730 .ok_or_else(|| anyhow!("supermaven not configured"))?;
4731
4732 let result = supermaven_admin_api
4733 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4734 .await?;
4735
4736 response.send(proto::GetSupermavenApiKeyResponse {
4737 api_key: result.api_key,
4738 })?;
4739
4740 Ok(())
4741}
4742
4743/// Start receiving chat updates for a channel
4744async fn join_channel_chat(
4745 request: proto::JoinChannelChat,
4746 response: Response<proto::JoinChannelChat>,
4747 session: UserSession,
4748) -> Result<()> {
4749 let channel_id = ChannelId::from_proto(request.channel_id);
4750
4751 let db = session.db().await;
4752 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4753 .await?;
4754 let messages = db
4755 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4756 .await?;
4757 response.send(proto::JoinChannelChatResponse {
4758 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4759 messages,
4760 })?;
4761 Ok(())
4762}
4763
4764/// Stop receiving chat updates for a channel
4765async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
4766 let channel_id = ChannelId::from_proto(request.channel_id);
4767 session
4768 .db()
4769 .await
4770 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4771 .await?;
4772 Ok(())
4773}
4774
4775/// Retrieve the chat history for a channel
4776async fn get_channel_messages(
4777 request: proto::GetChannelMessages,
4778 response: Response<proto::GetChannelMessages>,
4779 session: UserSession,
4780) -> Result<()> {
4781 let channel_id = ChannelId::from_proto(request.channel_id);
4782 let messages = session
4783 .db()
4784 .await
4785 .get_channel_messages(
4786 channel_id,
4787 session.user_id(),
4788 MESSAGE_COUNT_PER_PAGE,
4789 Some(MessageId::from_proto(request.before_message_id)),
4790 )
4791 .await?;
4792 response.send(proto::GetChannelMessagesResponse {
4793 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4794 messages,
4795 })?;
4796 Ok(())
4797}
4798
4799/// Retrieve specific chat messages
4800async fn get_channel_messages_by_id(
4801 request: proto::GetChannelMessagesById,
4802 response: Response<proto::GetChannelMessagesById>,
4803 session: UserSession,
4804) -> Result<()> {
4805 let message_ids = request
4806 .message_ids
4807 .iter()
4808 .map(|id| MessageId::from_proto(*id))
4809 .collect::<Vec<_>>();
4810 let messages = session
4811 .db()
4812 .await
4813 .get_channel_messages_by_id(session.user_id(), &message_ids)
4814 .await?;
4815 response.send(proto::GetChannelMessagesResponse {
4816 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4817 messages,
4818 })?;
4819 Ok(())
4820}
4821
4822/// Retrieve the current users notifications
4823async fn get_notifications(
4824 request: proto::GetNotifications,
4825 response: Response<proto::GetNotifications>,
4826 session: UserSession,
4827) -> Result<()> {
4828 let notifications = session
4829 .db()
4830 .await
4831 .get_notifications(
4832 session.user_id(),
4833 NOTIFICATION_COUNT_PER_PAGE,
4834 request
4835 .before_id
4836 .map(|id| db::NotificationId::from_proto(id)),
4837 )
4838 .await?;
4839 response.send(proto::GetNotificationsResponse {
4840 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4841 notifications,
4842 })?;
4843 Ok(())
4844}
4845
4846/// Mark notifications as read
4847async fn mark_notification_as_read(
4848 request: proto::MarkNotificationRead,
4849 response: Response<proto::MarkNotificationRead>,
4850 session: UserSession,
4851) -> Result<()> {
4852 let database = &session.db().await;
4853 let notifications = database
4854 .mark_notification_as_read_by_id(
4855 session.user_id(),
4856 NotificationId::from_proto(request.notification_id),
4857 )
4858 .await?;
4859 send_notifications(
4860 &*session.connection_pool().await,
4861 &session.peer,
4862 notifications,
4863 );
4864 response.send(proto::Ack {})?;
4865 Ok(())
4866}
4867
4868/// Get the current users information
4869async fn get_private_user_info(
4870 _request: proto::GetPrivateUserInfo,
4871 response: Response<proto::GetPrivateUserInfo>,
4872 session: UserSession,
4873) -> Result<()> {
4874 let db = session.db().await;
4875
4876 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4877 let user = db
4878 .get_user_by_id(session.user_id())
4879 .await?
4880 .ok_or_else(|| anyhow!("user not found"))?;
4881 let flags = db.get_user_flags(session.user_id()).await?;
4882
4883 response.send(proto::GetPrivateUserInfoResponse {
4884 metrics_id,
4885 staff: user.admin,
4886 flags,
4887 accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4888 })?;
4889 Ok(())
4890}
4891
4892/// Accept the terms of service (tos) on behalf of the current user
4893async fn accept_terms_of_service(
4894 _request: proto::AcceptTermsOfService,
4895 response: Response<proto::AcceptTermsOfService>,
4896 session: UserSession,
4897) -> Result<()> {
4898 let db = session.db().await;
4899
4900 let accepted_tos_at = Utc::now();
4901 db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4902 .await?;
4903
4904 response.send(proto::AcceptTermsOfServiceResponse {
4905 accepted_tos_at: accepted_tos_at.timestamp() as u64,
4906 })?;
4907 Ok(())
4908}
4909
/// The minimum account age an account must have in order to use the LLM service.
///
/// `get_llm_api_token` compares this against the time elapsed since the older
/// of the user's `created_at` and their GitHub account's creation time.
const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4912
4913async fn get_llm_api_token(
4914 _request: proto::GetLlmToken,
4915 response: Response<proto::GetLlmToken>,
4916 session: UserSession,
4917) -> Result<()> {
4918 let db = session.db().await;
4919
4920 let flags = db.get_user_flags(session.user_id()).await?;
4921 let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
4922 let has_llm_closed_beta_feature_flag = flags.iter().any(|flag| flag == "llm-closed-beta");
4923
4924 if !session.is_staff() && !has_language_models_feature_flag {
4925 Err(anyhow!("permission denied"))?
4926 }
4927
4928 let user_id = session.user_id();
4929 let user = db
4930 .get_user_by_id(user_id)
4931 .await?
4932 .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4933
4934 if user.accepted_tos_at.is_none() {
4935 Err(anyhow!("terms of service not accepted"))?
4936 }
4937
4938 let mut account_created_at = user.created_at;
4939 if let Some(github_created_at) = user.github_user_created_at {
4940 account_created_at = account_created_at.min(github_created_at);
4941 }
4942 if Utc::now().naive_utc() - account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE {
4943 Err(anyhow!("account too young"))?
4944 }
4945 let token = LlmTokenClaims::create(
4946 user.id,
4947 user.github_login.clone(),
4948 session.is_staff(),
4949 has_llm_closed_beta_feature_flag,
4950 session.current_plan(db).await?,
4951 &session.app_state.config,
4952 )?;
4953 response.send(proto::GetLlmTokenResponse { token })?;
4954 Ok(())
4955}
4956
4957fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4958 let message = match message {
4959 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4960 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4961 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4962 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4963 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4964 code: frame.code.into(),
4965 reason: frame.reason,
4966 })),
4967 // We should never receive a frame while reading the message, according
4968 // to the `tungstenite` maintainers:
4969 //
4970 // > It cannot occur when you read messages from the WebSocket, but it
4971 // > can be used when you want to send the raw frames (e.g. you want to
4972 // > send the frames to the WebSocket without composing the full message first).
4973 // >
4974 // > — https://github.com/snapview/tungstenite-rs/issues/268
4975 TungsteniteMessage::Frame(_) => {
4976 bail!("received an unexpected frame while reading the message")
4977 }
4978 };
4979
4980 Ok(message)
4981}
4982
4983fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4984 match message {
4985 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4986 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4987 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4988 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4989 AxumMessage::Close(frame) => {
4990 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4991 code: frame.code.into(),
4992 reason: frame.reason,
4993 }))
4994 }
4995 }
4996}
4997
4998fn notify_membership_updated(
4999 connection_pool: &mut ConnectionPool,
5000 result: MembershipUpdated,
5001 user_id: UserId,
5002 peer: &Peer,
5003) {
5004 for membership in &result.new_channels.channel_memberships {
5005 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
5006 }
5007 for channel_id in &result.removed_channels {
5008 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
5009 }
5010
5011 let user_channels_update = proto::UpdateUserChannels {
5012 channel_memberships: result
5013 .new_channels
5014 .channel_memberships
5015 .iter()
5016 .map(|cm| proto::ChannelMembership {
5017 channel_id: cm.channel_id.to_proto(),
5018 role: cm.role.into(),
5019 })
5020 .collect(),
5021 ..Default::default()
5022 };
5023
5024 let mut update = build_channels_update(result.new_channels);
5025 update.delete_channels = result
5026 .removed_channels
5027 .into_iter()
5028 .map(|id| id.to_proto())
5029 .collect();
5030 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
5031
5032 for connection_id in connection_pool.user_connection_ids(user_id) {
5033 peer.send(connection_id, user_channels_update.clone())
5034 .trace_err();
5035 peer.send(connection_id, update.clone()).trace_err();
5036 }
5037}
5038
5039fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
5040 proto::UpdateUserChannels {
5041 channel_memberships: channels
5042 .channel_memberships
5043 .iter()
5044 .map(|m| proto::ChannelMembership {
5045 channel_id: m.channel_id.to_proto(),
5046 role: m.role.into(),
5047 })
5048 .collect(),
5049 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
5050 observed_channel_message_id: channels.observed_channel_messages.clone(),
5051 }
5052}
5053
5054fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
5055 let mut update = proto::UpdateChannels::default();
5056
5057 for channel in channels.channels {
5058 update.channels.push(channel.to_proto());
5059 }
5060
5061 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
5062 update.latest_channel_message_ids = channels.latest_channel_messages;
5063
5064 for (channel_id, participants) in channels.channel_participants {
5065 update
5066 .channel_participants
5067 .push(proto::ChannelParticipants {
5068 channel_id: channel_id.to_proto(),
5069 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
5070 });
5071 }
5072
5073 for channel in channels.invited_channels {
5074 update.channel_invitations.push(channel.to_proto());
5075 }
5076
5077 update.hosted_projects = channels.hosted_projects;
5078 update
5079}
5080
5081fn build_initial_contacts_update(
5082 contacts: Vec<db::Contact>,
5083 pool: &ConnectionPool,
5084) -> proto::UpdateContacts {
5085 let mut update = proto::UpdateContacts::default();
5086
5087 for contact in contacts {
5088 match contact {
5089 db::Contact::Accepted { user_id, busy } => {
5090 update.contacts.push(contact_for_user(user_id, busy, &pool));
5091 }
5092 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
5093 db::Contact::Incoming { user_id } => {
5094 update
5095 .incoming_requests
5096 .push(proto::IncomingContactRequest {
5097 requester_id: user_id.to_proto(),
5098 })
5099 }
5100 }
5101 }
5102
5103 update
5104}
5105
5106fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
5107 proto::Contact {
5108 user_id: user_id.to_proto(),
5109 online: pool.is_user_online(user_id),
5110 busy,
5111 }
5112}
5113
5114fn room_updated(room: &proto::Room, peer: &Peer) {
5115 broadcast(
5116 None,
5117 room.participants
5118 .iter()
5119 .filter_map(|participant| Some(participant.peer_id?.into())),
5120 |peer_id| {
5121 peer.send(
5122 peer_id,
5123 proto::RoomUpdated {
5124 room: Some(room.clone()),
5125 },
5126 )
5127 },
5128 );
5129}
5130
5131fn channel_updated(
5132 channel: &db::channel::Model,
5133 room: &proto::Room,
5134 peer: &Peer,
5135 pool: &ConnectionPool,
5136) {
5137 let participants = room
5138 .participants
5139 .iter()
5140 .map(|p| p.user_id)
5141 .collect::<Vec<_>>();
5142
5143 broadcast(
5144 None,
5145 pool.channel_connection_ids(channel.root_id())
5146 .filter_map(|(channel_id, role)| {
5147 role.can_see_channel(channel.visibility).then(|| channel_id)
5148 }),
5149 |peer_id| {
5150 peer.send(
5151 peer_id,
5152 proto::UpdateChannels {
5153 channel_participants: vec![proto::ChannelParticipants {
5154 channel_id: channel.id.to_proto(),
5155 participant_user_ids: participants.clone(),
5156 }],
5157 ..Default::default()
5158 },
5159 )
5160 },
5161 );
5162}
5163
5164async fn send_dev_server_projects_update(
5165 user_id: UserId,
5166 mut status: proto::DevServerProjectsUpdate,
5167 session: &Session,
5168) {
5169 let pool = session.connection_pool().await;
5170 for dev_server in &mut status.dev_servers {
5171 dev_server.status =
5172 pool.dev_server_status(DevServerId(dev_server.dev_server_id as i32)) as i32;
5173 }
5174 let connections = pool.user_connection_ids(user_id);
5175 for connection_id in connections {
5176 session.peer.send(connection_id, status.clone()).trace_err();
5177 }
5178}
5179
5180async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
5181 let db = session.db().await;
5182
5183 let contacts = db.get_contacts(user_id).await?;
5184 let busy = db.is_user_busy(user_id).await?;
5185
5186 let pool = session.connection_pool().await;
5187 let updated_contact = contact_for_user(user_id, busy, &pool);
5188 for contact in contacts {
5189 if let db::Contact::Accepted {
5190 user_id: contact_user_id,
5191 ..
5192 } = contact
5193 {
5194 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
5195 session
5196 .peer
5197 .send(
5198 contact_conn_id,
5199 proto::UpdateContacts {
5200 contacts: vec![updated_contact.clone()],
5201 remove_contacts: Default::default(),
5202 incoming_requests: Default::default(),
5203 remove_incoming_requests: Default::default(),
5204 outgoing_requests: Default::default(),
5205 remove_outgoing_requests: Default::default(),
5206 },
5207 )
5208 .trace_err();
5209 }
5210 }
5211 }
5212 Ok(())
5213}
5214
/// Clean up after a dev server's connection has been lost.
///
/// Unshares every project the database reports as stale for this connection.
/// Then sends a refreshed `DevServerProjectsUpdate` to all connections of the
/// dev server's user.
///
/// # Errors
/// Returns the first database or unshare error. Projects after a failing one
/// are not unshared, and no update is sent.
async fn lost_dev_server_connection(session: &DevServerSession) -> Result<()> {
    log::info!("lost dev server connection, unsharing projects");
    let project_ids = session
        .db()
        .await
        .get_stale_dev_server_projects(session.connection_id)
        .await?;

    for project_id in project_ids {
        // Note: unshare re-checks that the connection ids match, so we can get
        // away without wrapping this loop in a transaction.
        unshare_project_internal(project_id, session.connection_id, None, &session).await?;
    }

    let user_id = session.dev_server().user_id;
    let update = session
        .db()
        .await
        .dev_server_projects_update(user_id)
        .await?;

    send_dev_server_projects_update(user_id, update, session).await;

    Ok(())
}
5239
/// Remove the session's user from the room that `connection_id` is in, and
/// propagate the consequences.
///
/// Steps, in order:
/// 1. Ask the database to leave the room. If the connection was not in a
///    room, return `Ok(())` without doing anything else.
/// 2. Notify collaborators of every project that was left (`project_left`)
///    and broadcast the updated room to its participants.
/// 3. If the room belongs to a channel, broadcast the new participant list to
///    that channel's eligible connections.
/// 4. Send `CallCanceled` to every connection of each user whose pending call
///    was canceled.
/// 5. Refresh contact status for the leaving user and those callees.
/// 6. If LiveKit is configured, remove the user from the LiveKit room. Also
///    delete that room when the database reported the room as deleted.
///
/// # Errors
/// Returns database errors from leaving the room or updating contacts.
/// LiveKit and send failures are only logged.
async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Declared up front and assigned inside the `if let` below so the values
    // taken out of `left_room` outlive that block.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel;

    // NOTE(review): in Rust 2021 the temporary produced by `session.db().await`
    // in this scrutinee lives until the end of the whole `if let`/`else`. If
    // `db()` returns a lock guard, the sends inside the block run while it is
    // held — confirm that is intended.
    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Taken before the room itself is taken, so `room.live_kit_room` is left
        // at its default value from here on.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the connection-pool guard is dropped before the loop below.
    // `update_user_contacts` acquires the connection pool itself.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    // LiveKit cleanup is best effort: failures are logged, not returned.
    if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
5313
5314async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
5315 let left_channel_buffers = session
5316 .db()
5317 .await
5318 .leave_channel_buffers(session.connection_id)
5319 .await?;
5320
5321 for left_buffer in left_channel_buffers {
5322 channel_buffer_updated(
5323 session.connection_id,
5324 left_buffer.connections,
5325 &proto::UpdateChannelBufferCollaborators {
5326 channel_id: left_buffer.channel_id.to_proto(),
5327 collaborators: left_buffer.collaborators,
5328 },
5329 &session.peer,
5330 );
5331 }
5332
5333 Ok(())
5334}
5335
5336fn project_left(project: &db::LeftProject, session: &UserSession) {
5337 for connection_id in &project.connection_ids {
5338 if project.should_unshare {
5339 session
5340 .peer
5341 .send(
5342 *connection_id,
5343 proto::UnshareProject {
5344 project_id: project.id.to_proto(),
5345 },
5346 )
5347 .trace_err();
5348 } else {
5349 session
5350 .peer
5351 .send(
5352 *connection_id,
5353 proto::RemoveProjectCollaborator {
5354 project_id: project.id.to_proto(),
5355 peer_id: Some(session.connection_id.into()),
5356 },
5357 )
5358 .trace_err();
5359 }
5360 }
5361}
5362
/// Extension for best-effort operations: log an error instead of propagating
/// it.
pub trait ResultExt {
    /// The value produced when the operation succeeded.
    type Ok;

    /// Log the error (if there is one) and return the success value, if any.
    fn trace_err(self) -> Option<Self::Ok>;
}
5368
5369impl<T, E> ResultExt for Result<T, E>
5370where
5371 E: std::fmt::Debug,
5372{
5373 type Ok = T;
5374
5375 #[track_caller]
5376 fn trace_err(self) -> Option<T> {
5377 match self {
5378 Ok(value) => Some(value),
5379 Err(error) => {
5380 tracing::error!("{:?}", error);
5381 None
5382 }
5383 }
5384 }
5385}