1mod connection_pool;
2
3use crate::api::CloudflareIpCountryHeader;
4use crate::{
5 auth,
6 db::{
7 self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
8 CreatedChannelMessage, Database, DevServerId, DevServerProjectId, InviteMemberResult,
9 MembershipUpdated, MessageId, NotificationId, PrincipalId, Project, ProjectId,
10 RejoinedProject, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
11 ServerId, UpdatedChannelMessage, User, UserId,
12 },
13 executor::Executor,
14 AppState, Config, Error, RateLimit, RateLimiter, Result,
15};
16use anyhow::{anyhow, bail, Context as _};
17use async_tungstenite::tungstenite::{
18 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
19};
20use axum::{
21 body::Body,
22 extract::{
23 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
24 ConnectInfo, WebSocketUpgrade,
25 },
26 headers::{Header, HeaderName},
27 http::StatusCode,
28 middleware,
29 response::IntoResponse,
30 routing::get,
31 Extension, Router, TypedHeader,
32};
33use collections::{HashMap, HashSet};
34pub use connection_pool::{ConnectionPool, ZedVersion};
35use core::fmt::{self, Debug, Formatter};
36use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
37use sha2::Digest;
38use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
39
40use futures::{
41 channel::oneshot,
42 future::{self, BoxFuture},
43 stream::FuturesUnordered,
44 FutureExt, SinkExt, StreamExt, TryStreamExt,
45};
46use http_client::IsahcHttpClient;
47use prometheus::{register_int_gauge, IntGauge};
48use rpc::{
49 proto::{
50 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
51 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
52 },
53 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
54};
55use semantic_version::SemanticVersion;
56use serde::{Serialize, Serializer};
57use std::{
58 any::TypeId,
59 future::Future,
60 marker::PhantomData,
61 mem,
62 net::SocketAddr,
63 ops::{Deref, DerefMut},
64 rc::Rc,
65 sync::{
66 atomic::{AtomicBool, Ordering::SeqCst},
67 Arc, OnceLock,
68 },
69 time::{Duration, Instant},
70};
71use time::OffsetDateTime;
72use tokio::sync::{watch, Semaphore};
73use tower::ServiceBuilder;
74use tracing::{
75 field::{self},
76 info_span, instrument, Instrument,
77};
78
79use self::connection_pool::VersionedMessage;
80
/// Grace period granted to a client to reconnect.
///
/// NOTE(review): not referenced in this part of the file; presumably the window
/// before a disconnected client's state is discarded — confirm against callers.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// Delay before a freshly started server cleans up resources left behind by
/// previous server instances (see `Server::start`).
///
/// Kubernetes gives terminated pods 10s to shut down gracefully; waiting a bit
/// longer ensures the old pods are gone before their resources are cleaned up.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

/// Page size for channel-message queries (usage is outside this view).
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Maximum length accepted for a channel message.
/// NOTE(review): unit (bytes vs. chars) is decided by the caller — confirm.
const MAX_MESSAGE_LEN: usize = 1024;
/// Page size for notification queries (usage is outside this view).
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
89
90type MessageHandler =
91 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
92
93struct Response<R> {
94 peer: Arc<Peer>,
95 receipt: Receipt<R>,
96 responded: Arc<AtomicBool>,
97}
98
99impl<R: RequestMessage> Response<R> {
100 fn send(self, payload: R::Response) -> Result<()> {
101 self.responded.store(true, SeqCst);
102 self.peer.respond(self.receipt, payload)?;
103 Ok(())
104 }
105}
106
107struct StreamingResponse<R: RequestMessage> {
108 peer: Arc<Peer>,
109 receipt: Receipt<R>,
110}
111
112impl<R: RequestMessage> StreamingResponse<R> {
113 fn send(&self, payload: R::Response) -> Result<()> {
114 self.peer.respond(self.receipt, payload)?;
115 Ok(())
116 }
117}
118
119#[derive(Clone, Debug)]
120pub enum Principal {
121 User(User),
122 Impersonated { user: User, admin: User },
123 DevServer(dev_server::Model),
124}
125
126impl Principal {
127 fn update_span(&self, span: &tracing::Span) {
128 match &self {
129 Principal::User(user) => {
130 span.record("user_id", &user.id.0);
131 span.record("login", &user.github_login);
132 }
133 Principal::Impersonated { user, admin } => {
134 span.record("user_id", &user.id.0);
135 span.record("login", &user.github_login);
136 span.record("impersonator", &admin.github_login);
137 }
138 Principal::DevServer(dev_server) => {
139 span.record("dev_server_id", &dev_server.id.0);
140 }
141 }
142 }
143}
144
/// Per-connection state passed to message handlers.
///
/// Shared resources are held behind `Arc`s, so clones of a session refer to the
/// same database handle, peer, connection pool and clients.
#[derive(Clone)]
struct Session {
    /// Who is on the other end of this connection.
    principal: Principal,
    /// This connection's id within the [`Peer`].
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; acquire it with [`Session::db`].
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    /// RPC peer used to send messages and responses.
    peer: Arc<Peer>,
    /// Server-wide connection pool; lock it with [`Session::connection_pool`].
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// LiveKit API client, if LiveKit is configured.
    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
    /// Supermaven admin API client, if an admin API key is configured.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    /// HTTP client for outbound requests made on behalf of this connection.
    http_client: Arc<IsahcHttpClient>,
    /// Shared rate limiter.
    rate_limiter: Arc<RateLimiter>,
    /// The GeoIP country code for the user.
    // Not read in this part of the file, hence the lint allowance.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    /// Executor for this connection; stored but not read here (leading underscore).
    _executor: Executor,
}
161
162impl Session {
163 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
164 #[cfg(test)]
165 tokio::task::yield_now().await;
166 let guard = self.db.lock().await;
167 #[cfg(test)]
168 tokio::task::yield_now().await;
169 guard
170 }
171
172 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
173 #[cfg(test)]
174 tokio::task::yield_now().await;
175 let guard = self.connection_pool.lock();
176 ConnectionPoolGuard {
177 guard,
178 _not_send: PhantomData,
179 }
180 }
181
182 fn for_user(self) -> Option<UserSession> {
183 UserSession::new(self)
184 }
185
186 fn for_dev_server(self) -> Option<DevServerSession> {
187 DevServerSession::new(self)
188 }
189
190 fn user_id(&self) -> Option<UserId> {
191 match &self.principal {
192 Principal::User(user) => Some(user.id),
193 Principal::Impersonated { user, .. } => Some(user.id),
194 Principal::DevServer(_) => None,
195 }
196 }
197
198 fn is_staff(&self) -> bool {
199 match &self.principal {
200 Principal::User(user) => user.admin,
201 Principal::Impersonated { .. } => true,
202 Principal::DevServer(_) => false,
203 }
204 }
205
206 pub async fn current_plan(&self) -> anyhow::Result<proto::Plan> {
207 if self.is_staff() {
208 return Ok(proto::Plan::ZedPro);
209 }
210
211 let Some(user_id) = self.user_id() else {
212 return Ok(proto::Plan::Free);
213 };
214
215 let db = self.db().await;
216 if db.has_active_billing_subscription(user_id).await? {
217 Ok(proto::Plan::ZedPro)
218 } else {
219 Ok(proto::Plan::Free)
220 }
221 }
222
223 fn dev_server_id(&self) -> Option<DevServerId> {
224 match &self.principal {
225 Principal::User(_) | Principal::Impersonated { .. } => None,
226 Principal::DevServer(dev_server) => Some(dev_server.id),
227 }
228 }
229
230 fn principal_id(&self) -> PrincipalId {
231 match &self.principal {
232 Principal::User(user) => PrincipalId::UserId(user.id),
233 Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
234 Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
235 }
236 }
237}
238
239impl Debug for Session {
240 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
241 let mut result = f.debug_struct("Session");
242 match &self.principal {
243 Principal::User(user) => {
244 result.field("user", &user.github_login);
245 }
246 Principal::Impersonated { user, admin } => {
247 result.field("user", &user.github_login);
248 result.field("impersonator", &admin.github_login);
249 }
250 Principal::DevServer(dev_server) => {
251 result.field("dev_server", &dev_server.id);
252 }
253 }
254 result.field("connection_id", &self.connection_id).finish()
255 }
256}
257
258struct UserSession(Session);
259
260impl UserSession {
261 pub fn new(s: Session) -> Option<Self> {
262 s.user_id().map(|_| UserSession(s))
263 }
264 pub fn user_id(&self) -> UserId {
265 self.0.user_id().unwrap()
266 }
267
268 pub fn email(&self) -> Option<String> {
269 match &self.0.principal {
270 Principal::User(user) => user.email_address.clone(),
271 Principal::Impersonated { user, .. } => user.email_address.clone(),
272 Principal::DevServer(..) => None,
273 }
274 }
275}
276
277impl Deref for UserSession {
278 type Target = Session;
279
280 fn deref(&self) -> &Self::Target {
281 &self.0
282 }
283}
284impl DerefMut for UserSession {
285 fn deref_mut(&mut self) -> &mut Self::Target {
286 &mut self.0
287 }
288}
289
290struct DevServerSession(Session);
291
292impl DevServerSession {
293 pub fn new(s: Session) -> Option<Self> {
294 s.dev_server_id().map(|_| DevServerSession(s))
295 }
296 pub fn dev_server_id(&self) -> DevServerId {
297 self.0.dev_server_id().unwrap()
298 }
299
300 fn dev_server(&self) -> &dev_server::Model {
301 match &self.0.principal {
302 Principal::DevServer(dev_server) => dev_server,
303 _ => unreachable!(),
304 }
305 }
306}
307
308impl Deref for DevServerSession {
309 type Target = Session;
310
311 fn deref(&self) -> &Self::Target {
312 &self.0
313 }
314}
315impl DerefMut for DevServerSession {
316 fn deref_mut(&mut self) -> &mut Self::Target {
317 &mut self.0
318 }
319}
320
321fn user_handler<M: RequestMessage, Fut>(
322 handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
323) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
324where
325 Fut: Send + Future<Output = Result<()>>,
326{
327 let handler = Arc::new(handler);
328 move |message, response, session| {
329 let handler = handler.clone();
330 Box::pin(async move {
331 if let Some(user_session) = session.for_user() {
332 Ok(handler(message, response, user_session).await?)
333 } else {
334 Err(Error::Internal(anyhow!(
335 "must be a user to call {}",
336 M::NAME
337 )))
338 }
339 })
340 }
341}
342
343fn dev_server_handler<M: RequestMessage, Fut>(
344 handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
345) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
346where
347 Fut: Send + Future<Output = Result<()>>,
348{
349 let handler = Arc::new(handler);
350 move |message, response, session| {
351 let handler = handler.clone();
352 Box::pin(async move {
353 if let Some(dev_server_session) = session.for_dev_server() {
354 Ok(handler(message, response, dev_server_session).await?)
355 } else {
356 Err(Error::Internal(anyhow!(
357 "must be a dev server to call {}",
358 M::NAME
359 )))
360 }
361 })
362 }
363}
364
365fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
366 handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
367) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
368where
369 InnertRetFut: Send + Future<Output = Result<()>>,
370{
371 let handler = Arc::new(handler);
372 move |message, session| {
373 let handler = handler.clone();
374 Box::pin(async move {
375 if let Some(user_session) = session.for_user() {
376 Ok(handler(message, user_session).await?)
377 } else {
378 Err(Error::Internal(anyhow!(
379 "must be a user to call {}",
380 M::NAME
381 )))
382 }
383 })
384 }
385}
386
387struct DbHandle(Arc<Database>);
388
389impl Deref for DbHandle {
390 type Target = Database;
391
392 fn deref(&self) -> &Self::Target {
393 self.0.as_ref()
394 }
395}
396
/// The collaboration RPC server: owns the peer, the connection pool and the
/// message-type → handler dispatch table.
pub struct Server {
    /// This server's id; behind a mutex so test-only `reset` can replace it.
    id: parking_lot::Mutex<ServerId>,
    /// RPC peer that all connections are attached to.
    peer: Arc<Peer>,
    /// Connections currently known to this server; shared with every session.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Registered handlers keyed by the `TypeId` of their message type.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Set to `true` on teardown; new connections check it and bail out.
    teardown: watch::Sender<bool>,
}
405
/// Lock guard for the shared [`ConnectionPool`].
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    // `Rc` makes this guard `!Send`, so a future that holds it across an
    // `.await` is not `Send` and cannot be spawned on a multi-threaded executor.
    _not_send: PhantomData<Rc<()>>,
}
410
/// Serializable view of the server's peer and connection pool.
///
/// Holds the connection-pool lock for as long as the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialize the pool itself (the guard's `Deref` target), not the guard.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
417
418pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
419where
420 S: Serializer,
421 T: Deref<Target = U>,
422 U: Serialize,
423{
424 Serialize::serialize(value.deref(), serializer)
425}
426
427impl Server {
/// Creates a server with the given `id` and registers every RPC handler.
///
/// Handlers wrapped in `user_handler` / `user_message_handler` reject sessions
/// that do not belong to a user; those wrapped in `dev_server_handler` reject
/// sessions that do not belong to a dev server; unwrapped handlers receive the
/// raw `Session`.
///
/// # Panics
/// `add_handler` panics on duplicate registration, so this panics if the list
/// below registers the same message type twice.
pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
    let mut server = Self {
        id: parking_lot::Mutex::new(id),
        // NOTE(review): `as u32` truncates if a ServerId ever exceeds u32::MAX.
        peer: Peer::new(id.0 as u32),
        app_state: app_state.clone(),
        connection_pool: Default::default(),
        handlers: Default::default(),
        // The initial receiver is dropped; connections call `subscribe()`.
        teardown: watch::channel(false).0,
    };

    server
        // Rooms and calls.
        .add_request_handler(ping)
        .add_request_handler(user_handler(create_room))
        .add_request_handler(user_handler(join_room))
        .add_request_handler(user_handler(rejoin_room))
        .add_request_handler(user_handler(leave_room))
        .add_request_handler(user_handler(set_room_participant_role))
        .add_request_handler(user_handler(call))
        .add_request_handler(user_handler(cancel_call))
        .add_message_handler(user_message_handler(decline_call))
        .add_request_handler(user_handler(update_participant_location))
        // Projects and dev servers.
        .add_request_handler(user_handler(share_project))
        .add_message_handler(unshare_project)
        .add_request_handler(user_handler(join_project))
        .add_request_handler(user_handler(join_hosted_project))
        .add_request_handler(user_handler(rejoin_dev_server_projects))
        .add_request_handler(user_handler(create_dev_server_project))
        .add_request_handler(user_handler(update_dev_server_project))
        .add_request_handler(user_handler(delete_dev_server_project))
        .add_request_handler(user_handler(create_dev_server))
        .add_request_handler(user_handler(regenerate_dev_server_token))
        .add_request_handler(user_handler(rename_dev_server))
        .add_request_handler(user_handler(delete_dev_server))
        .add_request_handler(user_handler(list_remote_directory))
        .add_request_handler(dev_server_handler(share_dev_server_project))
        .add_request_handler(dev_server_handler(shutdown_dev_server))
        .add_request_handler(dev_server_handler(reconnect_dev_server))
        .add_message_handler(user_message_handler(leave_project))
        .add_request_handler(update_project)
        .add_request_handler(update_worktree)
        .add_message_handler(start_language_server)
        .add_message_handler(update_language_server)
        .add_message_handler(update_diagnostic_summary)
        .add_message_handler(update_worktree_settings)
        // Requests forwarded to the project host.
        .add_request_handler(user_handler(
            forward_project_request_for_owner::<proto::TaskContextForLocation>,
        ))
        .add_request_handler(user_handler(
            forward_project_request_for_owner::<proto::TaskTemplates>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::GetHover>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::GetDefinition>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::GetTypeDefinition>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::GetReferences>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::SearchProject>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::GetDocumentHighlights>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::GetProjectSymbols>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::OpenBufferForSymbol>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::OpenBufferById>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::SynchronizeBuffers>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::InlayHints>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::OpenBufferByPath>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::GetCompletions>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
        ))
        .add_request_handler(user_handler(
            forward_versioned_mutating_project_request::<proto::OpenNewBuffer>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::GetCodeActions>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::ApplyCodeAction>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::PrepareRename>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::PerformRename>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::ReloadBuffers>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::FormatBuffers>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::CreateProjectEntry>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::RenameProjectEntry>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::CopyProjectEntry>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::DeleteProjectEntry>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::ExpandProjectEntry>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::OnTypeFormatting>,
        ))
        .add_request_handler(user_handler(
            forward_versioned_mutating_project_request::<proto::SaveBuffer>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::BlameBuffer>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::MultiLspQuery>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::RestartLanguageServers>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::LinkedEditingRange>,
        ))
        // Buffers and host-originated broadcasts.
        .add_message_handler(create_buffer_for_peer)
        .add_request_handler(update_buffer)
        .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
        .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
        .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
        .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
        .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
        // Users and contacts.
        .add_request_handler(get_users)
        .add_request_handler(user_handler(fuzzy_search_users))
        .add_request_handler(user_handler(request_contact))
        .add_request_handler(user_handler(remove_contact))
        .add_request_handler(user_handler(respond_to_contact_request))
        // Channels, channel buffers and channel chat.
        .add_message_handler(subscribe_to_channels)
        .add_request_handler(user_handler(create_channel))
        .add_request_handler(user_handler(delete_channel))
        .add_request_handler(user_handler(invite_channel_member))
        .add_request_handler(user_handler(remove_channel_member))
        .add_request_handler(user_handler(set_channel_member_role))
        .add_request_handler(user_handler(set_channel_visibility))
        .add_request_handler(user_handler(rename_channel))
        .add_request_handler(user_handler(join_channel_buffer))
        .add_request_handler(user_handler(leave_channel_buffer))
        .add_message_handler(user_message_handler(update_channel_buffer))
        .add_request_handler(user_handler(rejoin_channel_buffers))
        .add_request_handler(user_handler(get_channel_members))
        .add_request_handler(user_handler(respond_to_channel_invite))
        .add_request_handler(user_handler(join_channel))
        .add_request_handler(user_handler(join_channel_chat))
        .add_message_handler(user_message_handler(leave_channel_chat))
        .add_request_handler(user_handler(send_channel_message))
        .add_request_handler(user_handler(remove_channel_message))
        .add_request_handler(user_handler(update_channel_message))
        .add_request_handler(user_handler(get_channel_messages))
        .add_request_handler(user_handler(get_channel_messages_by_id))
        .add_request_handler(user_handler(get_notifications))
        .add_request_handler(user_handler(mark_notification_as_read))
        .add_request_handler(user_handler(move_channel))
        // Following.
        .add_request_handler(user_handler(follow))
        .add_message_handler(user_message_handler(unfollow))
        .add_message_handler(user_message_handler(update_followers))
        .add_request_handler(user_handler(get_private_user_info))
        .add_message_handler(user_message_handler(acknowledge_channel_message))
        .add_message_handler(user_message_handler(acknowledge_buffer_version))
        .add_request_handler(user_handler(get_supermaven_api_key))
        // Assistant contexts.
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::OpenContext>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::CreateContext>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::SynchronizeContexts>,
        ))
        .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
        .add_message_handler(update_context)
        // Language-model handlers take the raw `Session` plus the app config.
        .add_request_handler({
            let app_state = app_state.clone();
            move |request, response, session| {
                let app_state = app_state.clone();
                async move {
                    complete_with_language_model(request, response, session, &app_state.config)
                        .await
                }
            }
        })
        .add_streaming_request_handler({
            let app_state = app_state.clone();
            move |request, response, session| {
                let app_state = app_state.clone();
                async move {
                    stream_complete_with_language_model(
                        request,
                        response,
                        session,
                        &app_state.config,
                    )
                    .await
                }
            }
        })
        .add_request_handler({
            let app_state = app_state.clone();
            move |request, response, session| {
                let app_state = app_state.clone();
                async move {
                    count_language_model_tokens(request, response, session, &app_state.config)
                        .await
                }
            }
        })
        // Embeddings.
        .add_request_handler({
            user_handler(move |request, response, session| {
                get_cached_embeddings(request, response, session)
            })
        })
        .add_request_handler({
            let app_state = app_state.clone();
            user_handler(move |request, response, session| {
                compute_embeddings(
                    request,
                    response,
                    session,
                    app_state.config.openai_api_key.clone(),
                )
            })
        });

    Arc::new(server)
}
686
/// Starts background cleanup of resources left behind by earlier servers.
///
/// Spawns a detached task and returns `Ok(())` immediately. After
/// `CLEANUP_TIMEOUT`, the task:
/// 1. fetches stale room and channel-buffer ids for this environment/server id;
/// 2. clears stale channel-buffer collaborators and pushes the refreshed
///    collaborator list to the remaining connections;
/// 3. for each stale room: clears stale participants, broadcasts room/channel
///    updates, notifies users whose calls were cancelled, re-sends contact
///    status to affected users' accepted contacts, and deletes the LiveKit room
///    when no participants remain;
/// 4. deletes stale servers (this runs even if step 1 failed).
///
/// Individual failures are logged via `trace_err` and skipped.
pub async fn start(&self) -> Result<()> {
    let server_id = *self.id.lock();
    let app_state = self.app_state.clone();
    let peer = self.peer.clone();
    let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
    let pool = self.connection_pool.clone();
    let live_kit_client = self.app_state.live_kit_client.clone();

    let span = info_span!("start server");
    self.app_state.executor.spawn_detached(
        async move {
            tracing::info!("waiting for cleanup timeout");
            timeout.await;
            tracing::info!("cleanup timeout expired, retrieving stale rooms");
            if let Some((room_ids, channel_ids)) = app_state
                .db
                .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                .await
                .trace_err()
            {
                tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                tracing::info!(
                    stale_channel_buffer_count = channel_ids.len(),
                    "retrieved stale channel buffers"
                );

                // Drop collaborators from stale servers and tell everyone still in
                // the buffer about the new collaborator list.
                for channel_id in channel_ids {
                    if let Some(refreshed_channel_buffer) = app_state
                        .db
                        .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                        .await
                        .trace_err()
                    {
                        for connection_id in refreshed_channel_buffer.connection_ids {
                            peer.send(
                                connection_id,
                                proto::UpdateChannelBufferCollaborators {
                                    channel_id: channel_id.to_proto(),
                                    collaborators: refreshed_channel_buffer
                                        .collaborators
                                        .clone(),
                                },
                            )
                            .trace_err();
                        }
                    }
                }

                for room_id in room_ids {
                    // Collected from the refreshed room, then acted on below.
                    let mut contacts_to_update = HashSet::default();
                    let mut canceled_calls_to_user_ids = Vec::new();
                    let mut live_kit_room = String::new();
                    let mut delete_live_kit_room = false;

                    if let Some(mut refreshed_room) = app_state
                        .db
                        .clear_stale_room_participants(room_id, server_id)
                        .await
                        .trace_err()
                    {
                        tracing::info!(
                            room_id = room_id.0,
                            new_participant_count = refreshed_room.room.participants.len(),
                            "refreshed room"
                        );
                        room_updated(&refreshed_room.room, &peer);
                        if let Some(channel) = refreshed_room.channel.as_ref() {
                            channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                        }
                        contacts_to_update
                            .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                        contacts_to_update
                            .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                        canceled_calls_to_user_ids =
                            mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                        live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                        // Only delete the LiveKit room once nobody is left in it.
                        delete_live_kit_room = refreshed_room.room.participants.is_empty();
                    }

                    // Scoped so the pool lock is released before the awaits below.
                    {
                        let pool = pool.lock();
                        for canceled_user_id in canceled_calls_to_user_ids {
                            for connection_id in pool.user_connection_ids(canceled_user_id) {
                                peer.send(
                                    connection_id,
                                    proto::CallCanceled {
                                        room_id: room_id.to_proto(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for user_id in contacts_to_update {
                        // Query first, then lock: the pool guard is never held across `.await`.
                        let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                        let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                        if let Some((busy, contacts)) = busy.zip(contacts) {
                            let pool = pool.lock();
                            let updated_contact = contact_for_user(user_id, busy, &pool);
                            for contact in contacts {
                                // Only accepted contacts receive the status update.
                                if let db::Contact::Accepted {
                                    user_id: contact_user_id,
                                    ..
                                } = contact
                                {
                                    for contact_conn_id in
                                        pool.user_connection_ids(contact_user_id)
                                    {
                                        peer.send(
                                            contact_conn_id,
                                            proto::UpdateContacts {
                                                contacts: vec![updated_contact.clone()],
                                                remove_contacts: Default::default(),
                                                incoming_requests: Default::default(),
                                                remove_incoming_requests: Default::default(),
                                                outgoing_requests: Default::default(),
                                                remove_outgoing_requests: Default::default(),
                                            },
                                        )
                                        .trace_err();
                                    }
                                }
                            }
                        }
                    }

                    if let Some(live_kit) = live_kit_client.as_ref() {
                        if delete_live_kit_room {
                            live_kit.delete_room(live_kit_room).await.trace_err();
                        }
                    }
                }
            }

            // Runs regardless of whether the stale-resource lookup succeeded.
            app_state
                .db
                .delete_stale_servers(&app_state.config.zed_environment, server_id)
                .await
                .trace_err();
        }
        .instrument(span),
    );
    Ok(())
}
832
833 pub fn teardown(&self) {
834 self.peer.teardown();
835 self.connection_pool.lock().reset();
836 let _ = self.teardown.send(true);
837 }
838
/// Test-only: tears the server down, then reinitializes it under `id` and
/// clears the teardown flag so connections are accepted again.
#[cfg(test)]
pub fn reset(&self, id: ServerId) {
    self.teardown();
    {
        let mut current_id = self.id.lock();
        *current_id = id;
    }
    self.peer.reset(id.0 as u32);
    let _ = self.teardown.send(false);
}
846
/// Test-only: the server's current id.
#[cfg(test)]
pub fn id(&self) -> ServerId {
    let current = self.id.lock();
    *current
}
851
852 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
853 where
854 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
855 Fut: 'static + Send + Future<Output = Result<()>>,
856 M: EnvelopedMessage,
857 {
858 let prev_handler = self.handlers.insert(
859 TypeId::of::<M>(),
860 Box::new(move |envelope, session| {
861 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
862 let received_at = envelope.received_at;
863 tracing::info!("message received");
864 let start_time = Instant::now();
865 let future = (handler)(*envelope, session);
866 async move {
867 let result = future.await;
868 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
869 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
870 let queue_duration_ms = total_duration_ms - processing_duration_ms;
871 let payload_type = M::NAME;
872
873 match result {
874 Err(error) => {
875 tracing::error!(
876 ?error,
877 total_duration_ms,
878 processing_duration_ms,
879 queue_duration_ms,
880 payload_type,
881 "error handling message"
882 )
883 }
884 Ok(()) => tracing::info!(
885 total_duration_ms,
886 processing_duration_ms,
887 queue_duration_ms,
888 "finished handling message"
889 ),
890 }
891 }
892 .boxed()
893 }),
894 );
895 if prev_handler.is_some() {
896 panic!("registered a handler for the same message twice");
897 }
898 self
899 }
900
901 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
902 where
903 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
904 Fut: 'static + Send + Future<Output = Result<()>>,
905 M: EnvelopedMessage,
906 {
907 self.add_handler(move |envelope, session| handler(envelope.payload, session));
908 self
909 }
910
911 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
912 where
913 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
914 Fut: Send + Future<Output = Result<()>>,
915 M: RequestMessage,
916 {
917 let handler = Arc::new(handler);
918 self.add_handler(move |envelope, session| {
919 let receipt = envelope.receipt();
920 let handler = handler.clone();
921 async move {
922 let peer = session.peer.clone();
923 let responded = Arc::new(AtomicBool::default());
924 let response = Response {
925 peer: peer.clone(),
926 responded: responded.clone(),
927 receipt,
928 };
929 match (handler)(envelope.payload, response, session).await {
930 Ok(()) => {
931 if responded.load(std::sync::atomic::Ordering::SeqCst) {
932 Ok(())
933 } else {
934 Err(anyhow!("handler did not send a response"))?
935 }
936 }
937 Err(error) => {
938 let proto_err = match &error {
939 Error::Internal(err) => err.to_proto(),
940 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
941 };
942 peer.respond_with_error(receipt, proto_err)?;
943 Err(error)
944 }
945 }
946 }
947 })
948 }
949
950 fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
951 where
952 F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
953 Fut: Send + Future<Output = Result<()>>,
954 M: RequestMessage,
955 {
956 let handler = Arc::new(handler);
957 self.add_handler(move |envelope, session| {
958 let receipt = envelope.receipt();
959 let handler = handler.clone();
960 async move {
961 let peer = session.peer.clone();
962 let response = StreamingResponse {
963 peer: peer.clone(),
964 receipt,
965 };
966 match (handler)(envelope.payload, response, session).await {
967 Ok(()) => {
968 peer.end_stream(receipt)?;
969 Ok(())
970 }
971 Err(error) => {
972 let proto_err = match &error {
973 Error::Internal(err) => err.to_proto(),
974 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
975 };
976 peer.respond_with_error(receipt, proto_err)?;
977 Err(error)
978 }
979 }
980 }
981 })
982 }
983
984 #[allow(clippy::too_many_arguments)]
985 pub fn handle_connection(
986 self: &Arc<Self>,
987 connection: Connection,
988 address: String,
989 principal: Principal,
990 zed_version: ZedVersion,
991 geoip_country_code: Option<String>,
992 send_connection_id: Option<oneshot::Sender<ConnectionId>>,
993 executor: Executor,
994 ) -> impl Future<Output = ()> {
995 let this = self.clone();
996 let span = info_span!("handle connection", %address,
997 connection_id=field::Empty,
998 user_id=field::Empty,
999 login=field::Empty,
1000 impersonator=field::Empty,
1001 dev_server_id=field::Empty
1002 );
1003 principal.update_span(&span);
1004
1005 let mut teardown = self.teardown.subscribe();
1006 async move {
1007 if *teardown.borrow() {
1008 tracing::error!("server is tearing down");
1009 return
1010 }
1011 let (connection_id, handle_io, mut incoming_rx) = this
1012 .peer
1013 .add_connection(connection, {
1014 let executor = executor.clone();
1015 move |duration| executor.sleep(duration)
1016 });
1017 tracing::Span::current()
1018 .record("connection_id", format!("{}", connection_id))
1019 .record("geoip_country_code", geoip_country_code.clone());
1020
1021 tracing::info!("connection opened");
1022
1023 let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
1024 let http_client = match IsahcHttpClient::builder().default_header("User-Agent", user_agent).build() {
1025 Ok(http_client) => Arc::new(http_client),
1026 Err(error) => {
1027 tracing::error!(?error, "failed to create HTTP client");
1028 return;
1029 }
1030 };
1031
1032 let supermaven_client = if let Some(supermaven_admin_api_key) = this.app_state.config.supermaven_admin_api_key.clone() {
1033 Some(Arc::new(SupermavenAdminApi::new(
1034 supermaven_admin_api_key.to_string(),
1035 http_client.clone(),
1036 )))
1037 } else {
1038 None
1039 };
1040
1041 let session = Session {
1042 principal: principal.clone(),
1043 connection_id,
1044 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
1045 peer: this.peer.clone(),
1046 connection_pool: this.connection_pool.clone(),
1047 live_kit_client: this.app_state.live_kit_client.clone(),
1048 http_client,
1049 rate_limiter: this.app_state.rate_limiter.clone(),
1050 geoip_country_code,
1051 _executor: executor.clone(),
1052 supermaven_client,
1053 };
1054
1055 if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
1056 tracing::error!(?error, "failed to send initial client update");
1057 return;
1058 }
1059
1060 let handle_io = handle_io.fuse();
1061 futures::pin_mut!(handle_io);
1062
1063 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
1064 // This prevents deadlocks when e.g., client A performs a request to client B and
1065 // client B performs a request to client A. If both clients stop processing further
1066 // messages until their respective request completes, they won't have a chance to
1067 // respond to the other client's request and cause a deadlock.
1068 //
1069 // This arrangement ensures we will attempt to process earlier messages first, but fall
1070 // back to processing messages arrived later in the spirit of making progress.
1071 let mut foreground_message_handlers = FuturesUnordered::new();
1072 let concurrent_handlers = Arc::new(Semaphore::new(256));
1073 loop {
1074 let next_message = async {
1075 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
1076 let message = incoming_rx.next().await;
1077 (permit, message)
1078 }.fuse();
1079 futures::pin_mut!(next_message);
1080 futures::select_biased! {
1081 _ = teardown.changed().fuse() => return,
1082 result = handle_io => {
1083 if let Err(error) = result {
1084 tracing::error!(?error, "error handling I/O");
1085 }
1086 break;
1087 }
1088 _ = foreground_message_handlers.next() => {}
1089 next_message = next_message => {
1090 let (permit, message) = next_message;
1091 if let Some(message) = message {
1092 let type_name = message.payload_type_name();
1093 // note: we copy all the fields from the parent span so we can query them in the logs.
1094 // (https://github.com/tokio-rs/tracing/issues/2670).
1095 let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
1096 user_id=field::Empty,
1097 login=field::Empty,
1098 impersonator=field::Empty,
1099 dev_server_id=field::Empty
1100 );
1101 principal.update_span(&span);
1102 let span_enter = span.enter();
1103 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
1104 let is_background = message.is_background();
1105 let handle_message = (handler)(message, session.clone());
1106 drop(span_enter);
1107
1108 let handle_message = async move {
1109 handle_message.await;
1110 drop(permit);
1111 }.instrument(span);
1112 if is_background {
1113 executor.spawn_detached(handle_message);
1114 } else {
1115 foreground_message_handlers.push(handle_message);
1116 }
1117 } else {
1118 tracing::error!("no message handler");
1119 }
1120 } else {
1121 tracing::info!("connection closed");
1122 break;
1123 }
1124 }
1125 }
1126 }
1127
1128 drop(foreground_message_handlers);
1129 tracing::info!("signing out");
1130 if let Err(error) = connection_lost(session, teardown, executor).await {
1131 tracing::error!(?error, "error signing out");
1132 }
1133
1134 }.instrument(span)
1135 }
1136
    /// Sends the first batch of messages a freshly registered connection needs.
    ///
    /// Every connection first gets a `Hello` carrying its peer id. If `send_connection_id`
    /// is provided, the connection id is then handed back through it. What follows
    /// depends on the principal:
    ///
    /// - **User / impersonated user**:
    ///   - shows contacts and marks the user as connected on the first ever connection;
    ///   - refreshes the user's plan;
    ///   - registers the connection in the pool and sends the initial contacts update;
    ///   - subscribes to channels when `should_auto_subscribe_to_channels` allows it;
    ///   - sends dev server project status and re-delivers any pending incoming call;
    ///   - refreshes the user's contacts.
    /// - **Dev server**:
    ///   - if a connection for the same dev server already exists, sends it a
    ///     `ShutdownDevServer` and removes it from the pool;
    ///   - registers this connection;
    ///   - sends the dev server its project instructions;
    ///   - pushes updated project status to the owning user.
    ///
    /// # Errors
    /// Returns the first error from sending a message or from the database calls
    /// awaited with `?`.
    async fn send_initial_client_update(
        &self,
        connection_id: ConnectionId,
        principal: &Principal,
        zed_version: ZedVersion,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        session: &Session,
    ) -> Result<()> {
        self.peer.send(
            connection_id,
            proto::Hello {
                peer_id: Some(connection_id.into()),
            },
        )?;
        tracing::info!("sent hello message");
        // The receiver may already be gone; a failed hand-off is deliberately ignored.
        if let Some(send_connection_id) = send_connection_id.take() {
            let _ = send_connection_id.send(connection_id);
        }

        match principal {
            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
                if !user.connected_once {
                    self.peer.send(connection_id, proto::ShowContacts {})?;
                    self.app_state
                        .db
                        .set_user_connected_once(user.id, true)
                        .await?;
                }

                update_user_plan(user.id, session).await?;

                // Fetch contacts and dev server project status concurrently.
                let (contacts, dev_server_projects) = future::try_join(
                    self.app_state.db.get_contacts(user.id),
                    self.app_state.db.dev_server_projects_update(user.id),
                )
                .await?;

                // Register the connection and send the initial contacts while holding the
                // pool lock, so the contacts update is built from the same pool state that
                // includes this connection.
                {
                    let mut pool = self.connection_pool.lock();
                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
                    self.peer.send(
                        connection_id,
                        build_initial_contacts_update(contacts, &pool),
                    )?;
                }

                if should_auto_subscribe_to_channels(zed_version) {
                    subscribe_user_to_channels(user.id, session).await?;
                }

                send_dev_server_projects_update(user.id, dev_server_projects, session).await;

                if let Some(incoming_call) =
                    self.app_state.db.incoming_call_for_user(user.id).await?
                {
                    self.peer.send(connection_id, incoming_call)?;
                }

                update_user_contacts(user.id, &session).await?;
            }
            Principal::DevServer(dev_server) => {
                {
                    let mut pool = self.connection_pool.lock();
                    // Only one connection per dev server is kept: evict the older one.
                    if let Some(stale_connection_id) = pool.dev_server_connection_id(dev_server.id)
                    {
                        self.peer.send(
                            stale_connection_id,
                            proto::ShutdownDevServer {
                                reason: Some(
                                    "another dev server connected with the same token".to_string(),
                                ),
                            },
                        )?;
                        pool.remove_connection(stale_connection_id)?;
                    };
                    pool.add_dev_server(connection_id, dev_server.id, zed_version);
                }

                let projects = self
                    .app_state
                    .db
                    .get_projects_for_dev_server(dev_server.id)
                    .await?;
                self.peer
                    .send(connection_id, proto::DevServerInstructions { projects })?;

                let status = self
                    .app_state
                    .db
                    .dev_server_projects_update(dev_server.user_id)
                    .await?;
                send_dev_server_projects_update(dev_server.user_id, status, &session).await;
            }
        }

        Ok(())
    }
1234
1235 pub async fn invite_code_redeemed(
1236 self: &Arc<Self>,
1237 inviter_id: UserId,
1238 invitee_id: UserId,
1239 ) -> Result<()> {
1240 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1241 if let Some(code) = &user.invite_code {
1242 let pool = self.connection_pool.lock();
1243 let invitee_contact = contact_for_user(invitee_id, false, &pool);
1244 for connection_id in pool.user_connection_ids(inviter_id) {
1245 self.peer.send(
1246 connection_id,
1247 proto::UpdateContacts {
1248 contacts: vec![invitee_contact.clone()],
1249 ..Default::default()
1250 },
1251 )?;
1252 self.peer.send(
1253 connection_id,
1254 proto::UpdateInviteInfo {
1255 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1256 count: user.invite_count as u32,
1257 },
1258 )?;
1259 }
1260 }
1261 }
1262 Ok(())
1263 }
1264
1265 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1266 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1267 if let Some(invite_code) = &user.invite_code {
1268 let pool = self.connection_pool.lock();
1269 for connection_id in pool.user_connection_ids(user_id) {
1270 self.peer.send(
1271 connection_id,
1272 proto::UpdateInviteInfo {
1273 url: format!(
1274 "{}{}",
1275 self.app_state.config.invite_link_prefix, invite_code
1276 ),
1277 count: user.invite_count as u32,
1278 },
1279 )?;
1280 }
1281 }
1282 }
1283 Ok(())
1284 }
1285
1286 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1287 ServerSnapshot {
1288 connection_pool: ConnectionPoolGuard {
1289 guard: self.connection_pool.lock(),
1290 _not_send: PhantomData,
1291 },
1292 peer: &self.peer,
1293 }
1294 }
1295}
1296
1297impl<'a> Deref for ConnectionPoolGuard<'a> {
1298 type Target = ConnectionPool;
1299
1300 fn deref(&self) -> &Self::Target {
1301 &self.guard
1302 }
1303}
1304
1305impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1306 fn deref_mut(&mut self) -> &mut Self::Target {
1307 &mut self.guard
1308 }
1309}
1310
/// In test builds, verifies the pool's invariants when the guard is dropped. This runs
/// inside `drop`, before the `guard` field (and therefore the lock) is released.
/// Non-test builds do nothing extra here.
impl<'a> Drop for ConnectionPoolGuard<'a> {
    fn drop(&mut self) {
        #[cfg(test)]
        self.check_invariants();
    }
}
1317
1318fn broadcast<F>(
1319 sender_id: Option<ConnectionId>,
1320 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1321 mut f: F,
1322) where
1323 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1324{
1325 for receiver_id in receiver_ids {
1326 if Some(receiver_id) != sender_id {
1327 if let Err(error) = f(receiver_id) {
1328 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1329 }
1330 }
1331 }
1332}
1333
1334pub struct ProtocolVersion(u32);
1335
1336impl Header for ProtocolVersion {
1337 fn name() -> &'static HeaderName {
1338 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1339 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1340 }
1341
1342 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1343 where
1344 Self: Sized,
1345 I: Iterator<Item = &'i axum::http::HeaderValue>,
1346 {
1347 let version = values
1348 .next()
1349 .ok_or_else(axum::headers::Error::invalid)?
1350 .to_str()
1351 .map_err(|_| axum::headers::Error::invalid())?
1352 .parse()
1353 .map_err(|_| axum::headers::Error::invalid())?;
1354 Ok(Self(version))
1355 }
1356
1357 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1358 values.extend([self.0.to_string().parse().unwrap()]);
1359 }
1360}
1361
1362pub struct AppVersionHeader(SemanticVersion);
1363impl Header for AppVersionHeader {
1364 fn name() -> &'static HeaderName {
1365 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1366 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1367 }
1368
1369 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1370 where
1371 Self: Sized,
1372 I: Iterator<Item = &'i axum::http::HeaderValue>,
1373 {
1374 let version = values
1375 .next()
1376 .ok_or_else(axum::headers::Error::invalid)?
1377 .to_str()
1378 .map_err(|_| axum::headers::Error::invalid())?
1379 .parse()
1380 .map_err(|_| axum::headers::Error::invalid())?;
1381 Ok(Self(version))
1382 }
1383
1384 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1385 values.extend([self.0.to_string().parse().unwrap()]);
1386 }
1387}
1388
/// Builds the router for the collab RPC endpoints.
///
/// `/rpc` (the WebSocket endpoint) is wrapped by a `ServiceBuilder`. That builder first
/// adds the app-state extension and then the `auth::validate_header` middleware, so the
/// extension is in place when the middleware runs.
///
/// `Router::layer` only wraps routes added before it. `/metrics` is registered after
/// that layer and is therefore not behind the auth middleware. The final
/// `Extension(server)` layer applies to both routes.
pub fn routes(server: Arc<Server>) -> Router<(), Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
1400
/// Upgrades an `/rpc` request to a WebSocket and hands the socket to
/// [`Server::handle_connection`].
///
/// Instead of upgrading, it responds with `426 Upgrade Required` when any of these hold:
/// - the client's protocol version differs from `rpc::PROTOCOL_VERSION`;
/// - the app version header is missing;
/// - the app version cannot collaborate (`ZedVersion::can_collaborate`).
///
/// The `Principal` is read from a request extension (presumably inserted by the auth
/// middleware; confirm against `routes`). The optional Cloudflare country header is
/// passed on as the connection's GeoIP country code.
pub async fn handle_websocket_request(
    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
    app_version_header: Option<TypedHeader<AppVersionHeader>>,
    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
    Extension(server): Extension<Arc<Server>>,
    Extension(principal): Extension<Principal>,
    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
    ws: WebSocketUpgrade,
) -> axum::response::Response {
    if protocol_version != rpc::PROTOCOL_VERSION {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }

    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "no version header found".to_string(),
        )
            .into_response();
    };

    if !version.can_collaborate() {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }

    let socket_address = socket_address.to_string();
    ws.on_upgrade(move |socket| {
        // Adapt the axum WebSocket to the tungstenite message type that `Connection` expects,
        // in both directions.
        let socket = socket
            .map_ok(to_tungstenite_message)
            .err_into()
            .with(|message| async move { to_axum_message(message) });
        let connection = Connection::new(Box::pin(socket));
        async move {
            server
                .handle_connection(
                    connection,
                    socket_address,
                    principal,
                    version,
                    country_code_header.map(|header| header.to_string()),
                    None,
                    Executor::Production,
                )
                .await;
        }
    })
}
1456
1457pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1458 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1459 let connections_metric = CONNECTIONS_METRIC
1460 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1461
1462 let connections = server
1463 .connection_pool
1464 .lock()
1465 .connections()
1466 .filter(|connection| !connection.admin)
1467 .count();
1468 connections_metric.set(connections as _);
1469
1470 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1471 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1472 register_int_gauge!(
1473 "shared_projects",
1474 "number of open projects with one or more guests"
1475 )
1476 .unwrap()
1477 });
1478
1479 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1480 shared_projects_metric.set(shared_projects as _);
1481
1482 let encoder = prometheus::TextEncoder::new();
1483 let metric_families = prometheus::gather();
1484 let encoded_metrics = encoder
1485 .encode_to_string(&metric_families)
1486 .map_err(|err| anyhow!("{}", err))?;
1487 Ok(encoded_metrics)
1488}
1489
/// Cleans up after a connection has ended.
///
/// First, and unconditionally:
/// - disconnects the peer;
/// - removes the connection from the pool;
/// - tells the database the connection was lost. An error from that database call is
///   not propagated (`.trace_err()`).
///
/// It then waits for `RECONNECT_TIMEOUT`, presumably to give the client a chance to
/// reconnect. If `teardown` changes first, nothing further happens. Once the timeout
/// elapses:
/// - for a user or impersonated-user session, it leaves the room and channel buffers.
///   If `is_user_online` then reports the user offline, it declines any pending call.
///   Finally it refreshes the user's contacts.
/// - for a dev server session, it runs `lost_dev_server_connection`.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            match &session.principal {
                Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
                    let session = session.for_user().unwrap();

                    log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
                    leave_room_for_session(&session, session.connection_id).await.trace_err();
                    leave_channel_buffers_for_session(&session)
                        .await
                        .trace_err();

                    // Only decline a pending call if no other connection keeps the user online.
                    if !session
                        .connection_pool()
                        .await
                        .is_user_online(session.user_id())
                    {
                        let db = session.db().await;
                        if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                            room_updated(&room, &session.peer);
                        }
                    }

                    update_user_contacts(session.user_id(), &session).await?;
                },
                Principal::DevServer(_) => {
                    lost_dev_server_connection(&session.for_dev_server().unwrap()).await?;
                },
            }
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1544
1545/// Acknowledges a ping from a client, used to keep the connection alive.
1546async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1547 response.send(proto::Ack {})?;
1548 Ok(())
1549}
1550
1551/// Creates a new room for calling (outside of channels)
1552async fn create_room(
1553 _request: proto::CreateRoom,
1554 response: Response<proto::CreateRoom>,
1555 session: UserSession,
1556) -> Result<()> {
1557 let live_kit_room = nanoid::nanoid!(30);
1558
1559 let live_kit_connection_info = util::maybe!(async {
1560 let live_kit = session.live_kit_client.as_ref();
1561 let live_kit = live_kit?;
1562 let user_id = session.user_id().to_string();
1563
1564 let token = live_kit
1565 .room_token(&live_kit_room, &user_id.to_string())
1566 .trace_err()?;
1567
1568 Some(proto::LiveKitConnectionInfo {
1569 server_url: live_kit.url().into(),
1570 token,
1571 can_publish: true,
1572 })
1573 })
1574 .await;
1575
1576 let room = session
1577 .db()
1578 .await
1579 .create_room(session.user_id(), session.connection_id, &live_kit_room)
1580 .await?;
1581
1582 response.send(proto::CreateRoomResponse {
1583 room: Some(room.clone()),
1584 live_kit_connection_info,
1585 })?;
1586
1587 update_user_contacts(session.user_id(), &session).await?;
1588 Ok(())
1589}
1590
1591/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1592async fn join_room(
1593 request: proto::JoinRoom,
1594 response: Response<proto::JoinRoom>,
1595 session: UserSession,
1596) -> Result<()> {
1597 let room_id = RoomId::from_proto(request.id);
1598
1599 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1600
1601 if let Some(channel_id) = channel_id {
1602 return join_channel_internal(channel_id, Box::new(response), session).await;
1603 }
1604
1605 let joined_room = {
1606 let room = session
1607 .db()
1608 .await
1609 .join_room(room_id, session.user_id(), session.connection_id)
1610 .await?;
1611 room_updated(&room.room, &session.peer);
1612 room.into_inner()
1613 };
1614
1615 for connection_id in session
1616 .connection_pool()
1617 .await
1618 .user_connection_ids(session.user_id())
1619 {
1620 session
1621 .peer
1622 .send(
1623 connection_id,
1624 proto::CallCanceled {
1625 room_id: room_id.to_proto(),
1626 },
1627 )
1628 .trace_err();
1629 }
1630
1631 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1632 if let Some(token) = live_kit
1633 .room_token(
1634 &joined_room.room.live_kit_room,
1635 &session.user_id().to_string(),
1636 )
1637 .trace_err()
1638 {
1639 Some(proto::LiveKitConnectionInfo {
1640 server_url: live_kit.url().into(),
1641 token,
1642 can_publish: true,
1643 })
1644 } else {
1645 None
1646 }
1647 } else {
1648 None
1649 };
1650
1651 response.send(proto::JoinRoomResponse {
1652 room: Some(joined_room.room),
1653 channel_id: None,
1654 live_kit_connection_info,
1655 })?;
1656
1657 update_user_contacts(session.user_id(), &session).await?;
1658 Ok(())
1659}
1660
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Steps:
/// 1. Re-attaches the caller's new connection to the room in the database.
/// 2. Replies with the room, its `reshared_projects` and its `rejoined_projects`
///    (presumably the projects the caller hosts and the ones it had joined as a guest;
///    confirm against `Database::rejoin_room`).
/// 3. Broadcasts the room.
/// 4. For each re-shared project, tells collaborators about the caller's new peer id
///    and forwards the project's worktrees to the other collaborators.
/// 5. For re-joined projects, `notify_rejoined_projects` does the equivalent.
/// 6. Refreshes the channel (if any) and the caller's contacts.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: UserSession,
) -> Result<()> {
    let room;
    let channel;
    // Scope the rejoined-room guard so it is dropped before the awaits below.
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell every collaborator that the host's peer id changed.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the re-shared project's worktree list to everyone except the caller.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1752
1753fn notify_rejoined_projects(
1754 rejoined_projects: &mut Vec<RejoinedProject>,
1755 session: &UserSession,
1756) -> Result<()> {
1757 for project in rejoined_projects.iter() {
1758 for collaborator in &project.collaborators {
1759 session
1760 .peer
1761 .send(
1762 collaborator.connection_id,
1763 proto::UpdateProjectCollaborator {
1764 project_id: project.id.to_proto(),
1765 old_peer_id: Some(project.old_connection_id.into()),
1766 new_peer_id: Some(session.connection_id.into()),
1767 },
1768 )
1769 .trace_err();
1770 }
1771 }
1772
1773 for project in rejoined_projects {
1774 for worktree in mem::take(&mut project.worktrees) {
1775 #[cfg(any(test, feature = "test-support"))]
1776 const MAX_CHUNK_SIZE: usize = 2;
1777 #[cfg(not(any(test, feature = "test-support")))]
1778 const MAX_CHUNK_SIZE: usize = 256;
1779
1780 // Stream this worktree's entries.
1781 let message = proto::UpdateWorktree {
1782 project_id: project.id.to_proto(),
1783 worktree_id: worktree.id,
1784 abs_path: worktree.abs_path.clone(),
1785 root_name: worktree.root_name,
1786 updated_entries: worktree.updated_entries,
1787 removed_entries: worktree.removed_entries,
1788 scan_id: worktree.scan_id,
1789 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1790 updated_repositories: worktree.updated_repositories,
1791 removed_repositories: worktree.removed_repositories,
1792 };
1793 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1794 session.peer.send(session.connection_id, update.clone())?;
1795 }
1796
1797 // Stream this worktree's diagnostics.
1798 for summary in worktree.diagnostic_summaries {
1799 session.peer.send(
1800 session.connection_id,
1801 proto::UpdateDiagnosticSummary {
1802 project_id: project.id.to_proto(),
1803 worktree_id: worktree.id,
1804 summary: Some(summary),
1805 },
1806 )?;
1807 }
1808
1809 for settings_file in worktree.settings_files {
1810 session.peer.send(
1811 session.connection_id,
1812 proto::UpdateWorktreeSettings {
1813 project_id: project.id.to_proto(),
1814 worktree_id: worktree.id,
1815 path: settings_file.path,
1816 content: Some(settings_file.content),
1817 },
1818 )?;
1819 }
1820 }
1821
1822 for language_server in &project.language_servers {
1823 session.peer.send(
1824 session.connection_id,
1825 proto::UpdateLanguageServer {
1826 project_id: project.id.to_proto(),
1827 language_server_id: language_server.id,
1828 variant: Some(
1829 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1830 proto::LspDiskBasedDiagnosticsUpdated {},
1831 ),
1832 ),
1833 },
1834 )?;
1835 }
1836 }
1837 Ok(())
1838}
1839
1840/// leave room disconnects from the room.
1841async fn leave_room(
1842 _: proto::LeaveRoom,
1843 response: Response<proto::LeaveRoom>,
1844 session: UserSession,
1845) -> Result<()> {
1846 leave_room_for_session(&session, session.connection_id).await?;
1847 response.send(proto::Ack {})?;
1848 Ok(())
1849}
1850
1851/// Updates the permissions of someone else in the room.
1852async fn set_room_participant_role(
1853 request: proto::SetRoomParticipantRole,
1854 response: Response<proto::SetRoomParticipantRole>,
1855 session: UserSession,
1856) -> Result<()> {
1857 let user_id = UserId::from_proto(request.user_id);
1858 let role = ChannelRole::from(request.role());
1859
1860 let (live_kit_room, can_publish) = {
1861 let room = session
1862 .db()
1863 .await
1864 .set_room_participant_role(
1865 session.user_id(),
1866 RoomId::from_proto(request.room_id),
1867 user_id,
1868 role,
1869 )
1870 .await?;
1871
1872 let live_kit_room = room.live_kit_room.clone();
1873 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1874 room_updated(&room, &session.peer);
1875 (live_kit_room, can_publish)
1876 };
1877
1878 if let Some(live_kit) = session.live_kit_client.as_ref() {
1879 live_kit
1880 .update_participant(
1881 live_kit_room.clone(),
1882 request.user_id.to_string(),
1883 live_kit_server::proto::ParticipantPermission {
1884 can_subscribe: true,
1885 can_publish,
1886 can_publish_data: can_publish,
1887 hidden: false,
1888 recorder: false,
1889 },
1890 )
1891 .await
1892 .trace_err();
1893 }
1894
1895 response.send(proto::Ack {})?;
1896 Ok(())
1897}
1898
/// Call someone else into the current room.
///
/// The callee must be one of the caller's contacts. The call is recorded in the
/// database, the room is broadcast, and the callee's contacts are refreshed. Every
/// connection of the callee then receives an incoming-call request, and the request is
/// acknowledged as soon as any one of them answers successfully. If none does, the call
/// is marked as failed, the room is broadcast again, the callee's contacts are
/// refreshed, and an error is returned.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: UserSession,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(&room, &session.peer);
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring every connection of the callee concurrently; the first success wins.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    // No connection accepted the incoming call: roll it back.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1967
1968/// Cancel an outgoing call.
1969async fn cancel_call(
1970 request: proto::CancelCall,
1971 response: Response<proto::CancelCall>,
1972 session: UserSession,
1973) -> Result<()> {
1974 let called_user_id = UserId::from_proto(request.called_user_id);
1975 let room_id = RoomId::from_proto(request.room_id);
1976 {
1977 let room = session
1978 .db()
1979 .await
1980 .cancel_call(room_id, session.connection_id, called_user_id)
1981 .await?;
1982 room_updated(&room, &session.peer);
1983 }
1984
1985 for connection_id in session
1986 .connection_pool()
1987 .await
1988 .user_connection_ids(called_user_id)
1989 {
1990 session
1991 .peer
1992 .send(
1993 connection_id,
1994 proto::CallCanceled {
1995 room_id: room_id.to_proto(),
1996 },
1997 )
1998 .trace_err();
1999 }
2000 response.send(proto::Ack {})?;
2001
2002 update_user_contacts(called_user_id, &session).await?;
2003 Ok(())
2004}
2005
2006/// Decline an incoming call.
2007async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
2008 let room_id = RoomId::from_proto(message.room_id);
2009 {
2010 let room = session
2011 .db()
2012 .await
2013 .decline_call(Some(room_id), session.user_id())
2014 .await?
2015 .ok_or_else(|| anyhow!("failed to decline call"))?;
2016 room_updated(&room, &session.peer);
2017 }
2018
2019 for connection_id in session
2020 .connection_pool()
2021 .await
2022 .user_connection_ids(session.user_id())
2023 {
2024 session
2025 .peer
2026 .send(
2027 connection_id,
2028 proto::CallCanceled {
2029 room_id: room_id.to_proto(),
2030 },
2031 )
2032 .trace_err();
2033 }
2034 update_user_contacts(session.user_id(), &session).await?;
2035 Ok(())
2036}
2037
2038/// Updates other participants in the room with your current location.
2039async fn update_participant_location(
2040 request: proto::UpdateParticipantLocation,
2041 response: Response<proto::UpdateParticipantLocation>,
2042 session: UserSession,
2043) -> Result<()> {
2044 let room_id = RoomId::from_proto(request.room_id);
2045 let location = request
2046 .location
2047 .ok_or_else(|| anyhow!("invalid location"))?;
2048
2049 let db = session.db().await;
2050 let room = db
2051 .update_room_participant_location(room_id, session.connection_id, location)
2052 .await?;
2053
2054 room_updated(&room, &session.peer);
2055 response.send(proto::Ack {})?;
2056 Ok(())
2057}
2058
2059/// Share a project into the room.
2060async fn share_project(
2061 request: proto::ShareProject,
2062 response: Response<proto::ShareProject>,
2063 session: UserSession,
2064) -> Result<()> {
2065 let (project_id, room) = &*session
2066 .db()
2067 .await
2068 .share_project(
2069 RoomId::from_proto(request.room_id),
2070 session.connection_id,
2071 &request.worktrees,
2072 request
2073 .dev_server_project_id
2074 .map(|id| DevServerProjectId::from_proto(id)),
2075 )
2076 .await?;
2077 response.send(proto::ShareProjectResponse {
2078 project_id: project_id.to_proto(),
2079 })?;
2080 room_updated(&room, &session.peer);
2081
2082 Ok(())
2083}
2084
2085/// Unshare a project from the room.
2086async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
2087 let project_id = ProjectId::from_proto(message.project_id);
2088 unshare_project_internal(
2089 project_id,
2090 session.connection_id,
2091 session.user_id(),
2092 &session,
2093 )
2094 .await
2095}
2096
2097async fn unshare_project_internal(
2098 project_id: ProjectId,
2099 connection_id: ConnectionId,
2100 user_id: Option<UserId>,
2101 session: &Session,
2102) -> Result<()> {
2103 let delete = {
2104 let room_guard = session
2105 .db()
2106 .await
2107 .unshare_project(project_id, connection_id, user_id)
2108 .await?;
2109
2110 let (delete, room, guest_connection_ids) = &*room_guard;
2111
2112 let message = proto::UnshareProject {
2113 project_id: project_id.to_proto(),
2114 };
2115
2116 broadcast(
2117 Some(connection_id),
2118 guest_connection_ids.iter().copied(),
2119 |conn_id| session.peer.send(conn_id, message.clone()),
2120 );
2121 if let Some(room) = room {
2122 room_updated(room, &session.peer);
2123 }
2124
2125 *delete
2126 };
2127
2128 if delete {
2129 let db = session.db().await;
2130 db.delete_project(project_id).await?;
2131 }
2132
2133 Ok(())
2134}
2135
2136/// DevServer makes a project available online
2137async fn share_dev_server_project(
2138 request: proto::ShareDevServerProject,
2139 response: Response<proto::ShareDevServerProject>,
2140 session: DevServerSession,
2141) -> Result<()> {
2142 let (dev_server_project, user_id, status) = session
2143 .db()
2144 .await
2145 .share_dev_server_project(
2146 DevServerProjectId::from_proto(request.dev_server_project_id),
2147 session.dev_server_id(),
2148 session.connection_id,
2149 &request.worktrees,
2150 )
2151 .await?;
2152 let Some(project_id) = dev_server_project.project_id else {
2153 return Err(anyhow!("failed to share remote project"))?;
2154 };
2155
2156 send_dev_server_projects_update(user_id, status, &session).await;
2157
2158 response.send(proto::ShareProjectResponse { project_id })?;
2159
2160 Ok(())
2161}
2162
2163/// Join someone elses shared project.
2164async fn join_project(
2165 request: proto::JoinProject,
2166 response: Response<proto::JoinProject>,
2167 session: UserSession,
2168) -> Result<()> {
2169 let project_id = ProjectId::from_proto(request.project_id);
2170
2171 tracing::info!(%project_id, "join project");
2172
2173 let db = session.db().await;
2174 let (project, replica_id) = &mut *db
2175 .join_project(project_id, session.connection_id, session.user_id())
2176 .await?;
2177 drop(db);
2178 tracing::info!(%project_id, "join remote project");
2179 join_project_internal(response, session, project, replica_id)
2180}
2181
/// A response handle that can answer a join with a `proto::JoinProjectResponse`.
///
/// This lets `join_project_internal` serve both `JoinProject` and `JoinHostedProject`
/// requests.
trait JoinProjectInternalResponse {
    /// Sends `result` as the reply to the pending join request.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
// The fully qualified `Response::<…>::send` paths make it explicit that these impls
// delegate to `Response`'s own `send`, not recursively to this trait's method.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinProject>::send(self, result)
    }
}
impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinHostedProject>::send(self, result)
    }
}
2195
2196fn join_project_internal(
2197 response: impl JoinProjectInternalResponse,
2198 session: UserSession,
2199 project: &mut Project,
2200 replica_id: &ReplicaId,
2201) -> Result<()> {
2202 let collaborators = project
2203 .collaborators
2204 .iter()
2205 .filter(|collaborator| collaborator.connection_id != session.connection_id)
2206 .map(|collaborator| collaborator.to_proto())
2207 .collect::<Vec<_>>();
2208 let project_id = project.id;
2209 let guest_user_id = session.user_id();
2210
2211 let worktrees = project
2212 .worktrees
2213 .iter()
2214 .map(|(id, worktree)| proto::WorktreeMetadata {
2215 id: *id,
2216 root_name: worktree.root_name.clone(),
2217 visible: worktree.visible,
2218 abs_path: worktree.abs_path.clone(),
2219 })
2220 .collect::<Vec<_>>();
2221
2222 let add_project_collaborator = proto::AddProjectCollaborator {
2223 project_id: project_id.to_proto(),
2224 collaborator: Some(proto::Collaborator {
2225 peer_id: Some(session.connection_id.into()),
2226 replica_id: replica_id.0 as u32,
2227 user_id: guest_user_id.to_proto(),
2228 }),
2229 };
2230
2231 for collaborator in &collaborators {
2232 session
2233 .peer
2234 .send(
2235 collaborator.peer_id.unwrap().into(),
2236 add_project_collaborator.clone(),
2237 )
2238 .trace_err();
2239 }
2240
2241 // First, we send the metadata associated with each worktree.
2242 response.send(proto::JoinProjectResponse {
2243 project_id: project.id.0 as u64,
2244 worktrees: worktrees.clone(),
2245 replica_id: replica_id.0 as u32,
2246 collaborators: collaborators.clone(),
2247 language_servers: project.language_servers.clone(),
2248 role: project.role.into(),
2249 dev_server_project_id: project
2250 .dev_server_project_id
2251 .map(|dev_server_project_id| dev_server_project_id.0 as u64),
2252 })?;
2253
2254 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2255 #[cfg(any(test, feature = "test-support"))]
2256 const MAX_CHUNK_SIZE: usize = 2;
2257 #[cfg(not(any(test, feature = "test-support")))]
2258 const MAX_CHUNK_SIZE: usize = 256;
2259
2260 // Stream this worktree's entries.
2261 let message = proto::UpdateWorktree {
2262 project_id: project_id.to_proto(),
2263 worktree_id,
2264 abs_path: worktree.abs_path.clone(),
2265 root_name: worktree.root_name,
2266 updated_entries: worktree.entries,
2267 removed_entries: Default::default(),
2268 scan_id: worktree.scan_id,
2269 is_last_update: worktree.scan_id == worktree.completed_scan_id,
2270 updated_repositories: worktree.repository_entries.into_values().collect(),
2271 removed_repositories: Default::default(),
2272 };
2273 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
2274 session.peer.send(session.connection_id, update.clone())?;
2275 }
2276
2277 // Stream this worktree's diagnostics.
2278 for summary in worktree.diagnostic_summaries {
2279 session.peer.send(
2280 session.connection_id,
2281 proto::UpdateDiagnosticSummary {
2282 project_id: project_id.to_proto(),
2283 worktree_id: worktree.id,
2284 summary: Some(summary),
2285 },
2286 )?;
2287 }
2288
2289 for settings_file in worktree.settings_files {
2290 session.peer.send(
2291 session.connection_id,
2292 proto::UpdateWorktreeSettings {
2293 project_id: project_id.to_proto(),
2294 worktree_id: worktree.id,
2295 path: settings_file.path,
2296 content: Some(settings_file.content),
2297 },
2298 )?;
2299 }
2300 }
2301
2302 for language_server in &project.language_servers {
2303 session.peer.send(
2304 session.connection_id,
2305 proto::UpdateLanguageServer {
2306 project_id: project_id.to_proto(),
2307 language_server_id: language_server.id,
2308 variant: Some(
2309 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2310 proto::LspDiskBasedDiagnosticsUpdated {},
2311 ),
2312 ),
2313 },
2314 )?;
2315 }
2316
2317 Ok(())
2318}
2319
2320/// Leave someone elses shared project.
2321async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2322 let sender_id = session.connection_id;
2323 let project_id = ProjectId::from_proto(request.project_id);
2324 let db = session.db().await;
2325 if db.is_hosted_project(project_id).await? {
2326 let project = db.leave_hosted_project(project_id, sender_id).await?;
2327 project_left(&project, &session);
2328 return Ok(());
2329 }
2330
2331 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2332 tracing::info!(
2333 %project_id,
2334 "leave project"
2335 );
2336
2337 project_left(&project, &session);
2338 if let Some(room) = room {
2339 room_updated(&room, &session.peer);
2340 }
2341
2342 Ok(())
2343}
2344
2345async fn join_hosted_project(
2346 request: proto::JoinHostedProject,
2347 response: Response<proto::JoinHostedProject>,
2348 session: UserSession,
2349) -> Result<()> {
2350 let (mut project, replica_id) = session
2351 .db()
2352 .await
2353 .join_hosted_project(
2354 ProjectId(request.project_id as i32),
2355 session.user_id(),
2356 session.connection_id,
2357 )
2358 .await?;
2359
2360 join_project_internal(response, session, &mut project, &replica_id)
2361}
2362
2363async fn list_remote_directory(
2364 request: proto::ListRemoteDirectory,
2365 response: Response<proto::ListRemoteDirectory>,
2366 session: UserSession,
2367) -> Result<()> {
2368 let dev_server_id = DevServerId(request.dev_server_id as i32);
2369 let dev_server_connection_id = session
2370 .connection_pool()
2371 .await
2372 .dev_server_connection_id_supporting(dev_server_id, ZedVersion::with_list_directory())?;
2373
2374 session
2375 .db()
2376 .await
2377 .get_dev_server_for_user(dev_server_id, session.user_id())
2378 .await?;
2379
2380 response.send(
2381 session
2382 .peer
2383 .forward_request(session.connection_id, dev_server_connection_id, request)
2384 .await?,
2385 )?;
2386 Ok(())
2387}
2388
/// Updates the paths of a dev-server project owned by the requesting user.
///
/// Persists the new paths, sends the dev server its full, refreshed project list
/// (`DevServerInstructions`), sends the user's clients a refreshed dev-server project
/// list, and acknowledges the request.
async fn update_dev_server_project(
    request: proto::UpdateDevServerProject,
    response: Response<proto::UpdateDevServerProject>,
    session: UserSession,
) -> Result<()> {
    let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);

    let (dev_server_project, update) = session
        .db()
        .await
        .update_dev_server_project(dev_server_project_id, &request.paths, session.user_id())
        .await?;

    let projects = session
        .db()
        .await
        .get_projects_for_dev_server(dev_server_project.dev_server_id)
        .await?;

    // NOTE(review): this check runs after the database update above. If the dev server
    // is offline or too old, the handler returns an error even though the new paths were
    // already persisted, and `send_dev_server_projects_update` below is skipped —
    // confirm this is intended.
    let dev_server_connection_id = session
        .connection_pool()
        .await
        .dev_server_connection_id_supporting(
            dev_server_project.dev_server_id,
            ZedVersion::with_list_directory(),
        )?;

    session.peer.send(
        dev_server_connection_id,
        proto::DevServerInstructions { projects },
    )?;

    send_dev_server_projects_update(session.user_id(), update, &session).await;

    response.send(proto::Ack {})
}
2425
2426async fn create_dev_server_project(
2427 request: proto::CreateDevServerProject,
2428 response: Response<proto::CreateDevServerProject>,
2429 session: UserSession,
2430) -> Result<()> {
2431 let dev_server_id = DevServerId(request.dev_server_id as i32);
2432 let dev_server_connection_id = session
2433 .connection_pool()
2434 .await
2435 .dev_server_connection_id(dev_server_id);
2436 let Some(dev_server_connection_id) = dev_server_connection_id else {
2437 Err(ErrorCode::DevServerOffline
2438 .message("Cannot create a remote project when the dev server is offline".to_string())
2439 .anyhow())?
2440 };
2441
2442 let path = request.path.clone();
2443 //Check that the path exists on the dev server
2444 session
2445 .peer
2446 .forward_request(
2447 session.connection_id,
2448 dev_server_connection_id,
2449 proto::ValidateDevServerProjectRequest { path: path.clone() },
2450 )
2451 .await?;
2452
2453 let (dev_server_project, update) = session
2454 .db()
2455 .await
2456 .create_dev_server_project(
2457 DevServerId(request.dev_server_id as i32),
2458 &request.path,
2459 session.user_id(),
2460 )
2461 .await?;
2462
2463 let projects = session
2464 .db()
2465 .await
2466 .get_projects_for_dev_server(dev_server_project.dev_server_id)
2467 .await?;
2468
2469 session.peer.send(
2470 dev_server_connection_id,
2471 proto::DevServerInstructions { projects },
2472 )?;
2473
2474 send_dev_server_projects_update(session.user_id(), update, &session).await;
2475
2476 response.send(proto::CreateDevServerProjectResponse {
2477 dev_server_project: Some(dev_server_project.to_proto(None)),
2478 })?;
2479 Ok(())
2480}
2481
2482async fn create_dev_server(
2483 request: proto::CreateDevServer,
2484 response: Response<proto::CreateDevServer>,
2485 session: UserSession,
2486) -> Result<()> {
2487 let access_token = auth::random_token();
2488 let hashed_access_token = auth::hash_access_token(&access_token);
2489
2490 if request.name.is_empty() {
2491 return Err(proto::ErrorCode::Forbidden
2492 .message("Dev server name cannot be empty".to_string())
2493 .anyhow())?;
2494 }
2495
2496 let (dev_server, status) = session
2497 .db()
2498 .await
2499 .create_dev_server(
2500 &request.name,
2501 request.ssh_connection_string.as_deref(),
2502 &hashed_access_token,
2503 session.user_id(),
2504 )
2505 .await?;
2506
2507 send_dev_server_projects_update(session.user_id(), status, &session).await;
2508
2509 response.send(proto::CreateDevServerResponse {
2510 dev_server_id: dev_server.id.0 as u64,
2511 access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2512 name: request.name,
2513 })?;
2514 Ok(())
2515}
2516
2517async fn regenerate_dev_server_token(
2518 request: proto::RegenerateDevServerToken,
2519 response: Response<proto::RegenerateDevServerToken>,
2520 session: UserSession,
2521) -> Result<()> {
2522 let dev_server_id = DevServerId(request.dev_server_id as i32);
2523 let access_token = auth::random_token();
2524 let hashed_access_token = auth::hash_access_token(&access_token);
2525
2526 let connection_id = session
2527 .connection_pool()
2528 .await
2529 .dev_server_connection_id(dev_server_id);
2530 if let Some(connection_id) = connection_id {
2531 shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2532 session.peer.send(
2533 connection_id,
2534 proto::ShutdownDevServer {
2535 reason: Some("dev server token was regenerated".to_string()),
2536 },
2537 )?;
2538 let _ = remove_dev_server_connection(dev_server_id, &session).await;
2539 }
2540
2541 let status = session
2542 .db()
2543 .await
2544 .update_dev_server_token(dev_server_id, &hashed_access_token, session.user_id())
2545 .await?;
2546
2547 send_dev_server_projects_update(session.user_id(), status, &session).await;
2548
2549 response.send(proto::RegenerateDevServerTokenResponse {
2550 dev_server_id: dev_server_id.to_proto(),
2551 access_token: auth::generate_dev_server_token(dev_server_id.0 as usize, access_token),
2552 })?;
2553 Ok(())
2554}
2555
2556async fn rename_dev_server(
2557 request: proto::RenameDevServer,
2558 response: Response<proto::RenameDevServer>,
2559 session: UserSession,
2560) -> Result<()> {
2561 if request.name.trim().is_empty() {
2562 return Err(proto::ErrorCode::Forbidden
2563 .message("Dev server name cannot be empty".to_string())
2564 .anyhow())?;
2565 }
2566
2567 let dev_server_id = DevServerId(request.dev_server_id as i32);
2568 let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2569 if dev_server.user_id != session.user_id() {
2570 return Err(anyhow!(ErrorCode::Forbidden))?;
2571 }
2572
2573 let status = session
2574 .db()
2575 .await
2576 .rename_dev_server(
2577 dev_server_id,
2578 &request.name,
2579 request.ssh_connection_string.as_deref(),
2580 session.user_id(),
2581 )
2582 .await?;
2583
2584 send_dev_server_projects_update(session.user_id(), status, &session).await;
2585
2586 response.send(proto::Ack {})?;
2587 Ok(())
2588}
2589
2590async fn delete_dev_server(
2591 request: proto::DeleteDevServer,
2592 response: Response<proto::DeleteDevServer>,
2593 session: UserSession,
2594) -> Result<()> {
2595 let dev_server_id = DevServerId(request.dev_server_id as i32);
2596 let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2597 if dev_server.user_id != session.user_id() {
2598 return Err(anyhow!(ErrorCode::Forbidden))?;
2599 }
2600
2601 let connection_id = session
2602 .connection_pool()
2603 .await
2604 .dev_server_connection_id(dev_server_id);
2605 if let Some(connection_id) = connection_id {
2606 shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2607 session.peer.send(
2608 connection_id,
2609 proto::ShutdownDevServer {
2610 reason: Some("dev server was deleted".to_string()),
2611 },
2612 )?;
2613 let _ = remove_dev_server_connection(dev_server_id, &session).await;
2614 }
2615
2616 let status = session
2617 .db()
2618 .await
2619 .delete_dev_server(dev_server_id, session.user_id())
2620 .await?;
2621
2622 send_dev_server_projects_update(session.user_id(), status, &session).await;
2623
2624 response.send(proto::Ack {})?;
2625 Ok(())
2626}
2627
/// Deletes a dev-server project owned by the requesting user.
///
/// Returns `Forbidden` if the project's dev server belongs to another user. If the dev
/// server is connected and the project can be found, the project is unshared first.
/// After the database row is removed, a connected dev server is sent its remaining
/// projects (`DevServerInstructions`), and the user's clients are sent a refreshed
/// dev-server project list.
async fn delete_dev_server_project(
    request: proto::DeleteDevServerProject,
    response: Response<proto::DeleteDevServerProject>,
    session: UserSession,
) -> Result<()> {
    let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
    let dev_server_project = session
        .db()
        .await
        .get_dev_server_project(dev_server_project_id)
        .await?;

    // Ownership is decided by the dev server that hosts the project.
    let dev_server = session
        .db()
        .await
        .get_dev_server(dev_server_project.dev_server_id)
        .await?;
    if dev_server.user_id != session.user_id() {
        return Err(anyhow!(ErrorCode::Forbidden))?;
    }

    let dev_server_connection_id = session
        .connection_pool()
        .await
        .dev_server_connection_id(dev_server.id);

    if let Some(dev_server_connection_id) = dev_server_connection_id {
        let project = session
            .db()
            .await
            .find_dev_server_project(dev_server_project_id)
            .await;
        // A failed lookup is ignored; presumably it means the project is not currently
        // shared — TODO confirm against `find_dev_server_project`.
        if let Ok(project) = project {
            unshare_project_internal(
                project.id,
                dev_server_connection_id,
                Some(session.user_id()),
                &session,
            )
            .await?;
        }
    }

    let (projects, status) = session
        .db()
        .await
        .delete_dev_server_project(dev_server_project_id, dev_server.id, session.user_id())
        .await?;

    // Tell a connected dev server which projects it still has.
    if let Some(dev_server_connection_id) = dev_server_connection_id {
        session.peer.send(
            dev_server_connection_id,
            proto::DevServerInstructions { projects },
        )?;
    }

    send_dev_server_projects_update(session.user_id(), status, &session).await;

    response.send(proto::Ack {})?;
    Ok(())
}
2689
2690async fn rejoin_dev_server_projects(
2691 request: proto::RejoinRemoteProjects,
2692 response: Response<proto::RejoinRemoteProjects>,
2693 session: UserSession,
2694) -> Result<()> {
2695 let mut rejoined_projects = {
2696 let db = session.db().await;
2697 db.rejoin_dev_server_projects(
2698 &request.rejoined_projects,
2699 session.user_id(),
2700 session.0.connection_id,
2701 )
2702 .await?
2703 };
2704 response.send(proto::RejoinRemoteProjectsResponse {
2705 rejoined_projects: rejoined_projects
2706 .iter()
2707 .map(|project| project.to_proto())
2708 .collect(),
2709 })?;
2710 notify_rejoined_projects(&mut rejoined_projects, &session)
2711}
2712
/// Handles a dev server reconnecting.
///
/// The database reshares the projects listed in the request under the dev server's new
/// connection. For each reshared project, every collaborator is first sent
/// `UpdateProjectCollaborator` with the host's old and new peer ids (failures are only
/// traced), and then the project's worktree list (`UpdateProject`) via `broadcast`.
/// Finally, the dev server is sent the reshared projects and their collaborators.
async fn reconnect_dev_server(
    request: proto::ReconnectDevServer,
    response: Response<proto::ReconnectDevServer>,
    session: DevServerSession,
) -> Result<()> {
    // The database handle is scoped to this block, so it is released before any
    // messages are sent.
    let reshared_projects = {
        let db = session.db().await;
        db.reshare_dev_server_projects(
            &request.reshared_projects,
            session.dev_server_id(),
            session.0.connection_id,
        )
        .await?
    };

    for project in &reshared_projects {
        for collaborator in &project.collaborators {
            session
                .peer
                .send(
                    collaborator.connection_id,
                    proto::UpdateProjectCollaborator {
                        project_id: project.id.to_proto(),
                        old_peer_id: Some(project.old_connection_id.into()),
                        new_peer_id: Some(session.connection_id.into()),
                    },
                )
                .trace_err();
        }

        broadcast(
            Some(session.connection_id),
            project
                .collaborators
                .iter()
                .map(|collaborator| collaborator.connection_id),
            |connection_id| {
                session.peer.forward_send(
                    session.connection_id,
                    connection_id,
                    proto::UpdateProject {
                        project_id: project.id.to_proto(),
                        worktrees: project.worktrees.clone(),
                    },
                )
            },
        );
    }

    response.send(proto::ReconnectDevServerResponse {
        reshared_projects: reshared_projects
            .iter()
            .map(|project| proto::ResharedProject {
                id: project.id.to_proto(),
                collaborators: project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.to_proto())
                    .collect(),
            })
            .collect(),
    })?;

    Ok(())
}
2778
2779async fn shutdown_dev_server(
2780 _: proto::ShutdownDevServer,
2781 response: Response<proto::ShutdownDevServer>,
2782 session: DevServerSession,
2783) -> Result<()> {
2784 response.send(proto::Ack {})?;
2785 shutdown_dev_server_internal(session.dev_server_id(), session.connection_id, &session).await?;
2786 remove_dev_server_connection(session.dev_server_id(), &session).await
2787}
2788
/// Takes a dev server offline.
///
/// Unshares every live project of the dev server (acting as `connection_id`, with no
/// user), marks the dev server offline in the connection pool, and sends its owner a
/// refreshed dev-server project list. This function does not call
/// `remove_dev_server_connection`; the callers in this file do that afterwards.
async fn shutdown_dev_server_internal(
    dev_server_id: DevServerId,
    connection_id: ConnectionId,
    session: &Session,
) -> Result<()> {
    // Both lookups share one database handle, which is released at the end of this block.
    let (dev_server_projects, dev_server) = {
        let db = session.db().await;
        let dev_server_projects = db.get_projects_for_dev_server(dev_server_id).await?;
        let dev_server = db.get_dev_server(dev_server_id).await?;
        (dev_server_projects, dev_server)
    };

    // Only dev-server projects with a `project_id` are unshared.
    for project_id in dev_server_projects.iter().filter_map(|p| p.project_id) {
        unshare_project_internal(
            ProjectId::from_proto(project_id),
            connection_id,
            None,
            session,
        )
        .await?;
    }

    session
        .connection_pool()
        .await
        .set_dev_server_offline(dev_server_id);

    let status = session
        .db()
        .await
        .dev_server_projects_update(dev_server.user_id)
        .await?;
    send_dev_server_projects_update(dev_server.user_id, status, &session).await;

    Ok(())
}
2825
2826async fn remove_dev_server_connection(dev_server_id: DevServerId, session: &Session) -> Result<()> {
2827 let dev_server_connection = session
2828 .connection_pool()
2829 .await
2830 .dev_server_connection_id(dev_server_id);
2831
2832 if let Some(dev_server_connection) = dev_server_connection {
2833 session
2834 .connection_pool()
2835 .await
2836 .remove_connection(dev_server_connection)?;
2837 }
2838 Ok(())
2839}
2840
2841/// Updates other participants with changes to the project
2842async fn update_project(
2843 request: proto::UpdateProject,
2844 response: Response<proto::UpdateProject>,
2845 session: Session,
2846) -> Result<()> {
2847 let project_id = ProjectId::from_proto(request.project_id);
2848 let (room, guest_connection_ids) = &*session
2849 .db()
2850 .await
2851 .update_project(project_id, session.connection_id, &request.worktrees)
2852 .await?;
2853 broadcast(
2854 Some(session.connection_id),
2855 guest_connection_ids.iter().copied(),
2856 |connection_id| {
2857 session
2858 .peer
2859 .forward_send(session.connection_id, connection_id, request.clone())
2860 },
2861 );
2862 if let Some(room) = room {
2863 room_updated(&room, &session.peer);
2864 }
2865 response.send(proto::Ack {})?;
2866
2867 Ok(())
2868}
2869
2870/// Updates other participants with changes to the worktree
2871async fn update_worktree(
2872 request: proto::UpdateWorktree,
2873 response: Response<proto::UpdateWorktree>,
2874 session: Session,
2875) -> Result<()> {
2876 let guest_connection_ids = session
2877 .db()
2878 .await
2879 .update_worktree(&request, session.connection_id)
2880 .await?;
2881
2882 broadcast(
2883 Some(session.connection_id),
2884 guest_connection_ids.iter().copied(),
2885 |connection_id| {
2886 session
2887 .peer
2888 .forward_send(session.connection_id, connection_id, request.clone())
2889 },
2890 );
2891 response.send(proto::Ack {})?;
2892 Ok(())
2893}
2894
2895/// Updates other participants with changes to the diagnostics
2896async fn update_diagnostic_summary(
2897 message: proto::UpdateDiagnosticSummary,
2898 session: Session,
2899) -> Result<()> {
2900 let guest_connection_ids = session
2901 .db()
2902 .await
2903 .update_diagnostic_summary(&message, session.connection_id)
2904 .await?;
2905
2906 broadcast(
2907 Some(session.connection_id),
2908 guest_connection_ids.iter().copied(),
2909 |connection_id| {
2910 session
2911 .peer
2912 .forward_send(session.connection_id, connection_id, message.clone())
2913 },
2914 );
2915
2916 Ok(())
2917}
2918
2919/// Updates other participants with changes to the worktree settings
2920async fn update_worktree_settings(
2921 message: proto::UpdateWorktreeSettings,
2922 session: Session,
2923) -> Result<()> {
2924 let guest_connection_ids = session
2925 .db()
2926 .await
2927 .update_worktree_settings(&message, session.connection_id)
2928 .await?;
2929
2930 broadcast(
2931 Some(session.connection_id),
2932 guest_connection_ids.iter().copied(),
2933 |connection_id| {
2934 session
2935 .peer
2936 .forward_send(session.connection_id, connection_id, message.clone())
2937 },
2938 );
2939
2940 Ok(())
2941}
2942
2943/// Notify other participants that a language server has started.
2944async fn start_language_server(
2945 request: proto::StartLanguageServer,
2946 session: Session,
2947) -> Result<()> {
2948 let guest_connection_ids = session
2949 .db()
2950 .await
2951 .start_language_server(&request, session.connection_id)
2952 .await?;
2953
2954 broadcast(
2955 Some(session.connection_id),
2956 guest_connection_ids.iter().copied(),
2957 |connection_id| {
2958 session
2959 .peer
2960 .forward_send(session.connection_id, connection_id, request.clone())
2961 },
2962 );
2963 Ok(())
2964}
2965
2966/// Notify other participants that a language server has changed.
2967async fn update_language_server(
2968 request: proto::UpdateLanguageServer,
2969 session: Session,
2970) -> Result<()> {
2971 let project_id = ProjectId::from_proto(request.project_id);
2972 let project_connection_ids = session
2973 .db()
2974 .await
2975 .project_connection_ids(project_id, session.connection_id, true)
2976 .await?;
2977 broadcast(
2978 Some(session.connection_id),
2979 project_connection_ids.iter().copied(),
2980 |connection_id| {
2981 session
2982 .peer
2983 .forward_send(session.connection_id, connection_id, request.clone())
2984 },
2985 );
2986 Ok(())
2987}
2988
2989/// forward a project request to the host. These requests should be read only
2990/// as guests are allowed to send them.
2991async fn forward_read_only_project_request<T>(
2992 request: T,
2993 response: Response<T>,
2994 session: UserSession,
2995) -> Result<()>
2996where
2997 T: EntityMessage + RequestMessage,
2998{
2999 let project_id = ProjectId::from_proto(request.remote_entity_id());
3000 let host_connection_id = session
3001 .db()
3002 .await
3003 .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
3004 .await?;
3005 let payload = session
3006 .peer
3007 .forward_request(session.connection_id, host_connection_id, request)
3008 .await?;
3009 response.send(payload)?;
3010 Ok(())
3011}
3012
3013/// forward a project request to the dev server. Only allowed
3014/// if it's your dev server.
3015async fn forward_project_request_for_owner<T>(
3016 request: T,
3017 response: Response<T>,
3018 session: UserSession,
3019) -> Result<()>
3020where
3021 T: EntityMessage + RequestMessage,
3022{
3023 let project_id = ProjectId::from_proto(request.remote_entity_id());
3024
3025 let host_connection_id = session
3026 .db()
3027 .await
3028 .host_for_owner_project_request(project_id, session.connection_id, session.user_id())
3029 .await?;
3030 let payload = session
3031 .peer
3032 .forward_request(session.connection_id, host_connection_id, request)
3033 .await?;
3034 response.send(payload)?;
3035 Ok(())
3036}
3037
3038/// forward a project request to the host. These requests are disallowed
3039/// for guests.
3040async fn forward_mutating_project_request<T>(
3041 request: T,
3042 response: Response<T>,
3043 session: UserSession,
3044) -> Result<()>
3045where
3046 T: EntityMessage + RequestMessage,
3047{
3048 let project_id = ProjectId::from_proto(request.remote_entity_id());
3049
3050 let host_connection_id = session
3051 .db()
3052 .await
3053 .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
3054 .await?;
3055 let payload = session
3056 .peer
3057 .forward_request(session.connection_id, host_connection_id, request)
3058 .await?;
3059 response.send(payload)?;
3060 Ok(())
3061}
3062
3063/// forward a project request to the host. These requests are disallowed
3064/// for guests.
3065async fn forward_versioned_mutating_project_request<T>(
3066 request: T,
3067 response: Response<T>,
3068 session: UserSession,
3069) -> Result<()>
3070where
3071 T: EntityMessage + RequestMessage + VersionedMessage,
3072{
3073 let project_id = ProjectId::from_proto(request.remote_entity_id());
3074
3075 let host_connection_id = session
3076 .db()
3077 .await
3078 .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
3079 .await?;
3080 if let Some(host_version) = session
3081 .connection_pool()
3082 .await
3083 .connection(host_connection_id)
3084 .map(|c| c.zed_version)
3085 {
3086 if let Some(min_required_version) = request.required_host_version() {
3087 if min_required_version > host_version {
3088 return Err(anyhow!(ErrorCode::RemoteUpgradeRequired
3089 .with_tag("required", &min_required_version.to_string())))?;
3090 }
3091 }
3092 }
3093
3094 let payload = session
3095 .peer
3096 .forward_request(session.connection_id, host_connection_id, request)
3097 .await?;
3098 response.send(payload)?;
3099 Ok(())
3100}
3101
3102/// Notify other participants that a new buffer has been created
3103async fn create_buffer_for_peer(
3104 request: proto::CreateBufferForPeer,
3105 session: Session,
3106) -> Result<()> {
3107 session
3108 .db()
3109 .await
3110 .check_user_is_project_host(
3111 ProjectId::from_proto(request.project_id),
3112 session.connection_id,
3113 )
3114 .await?;
3115 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
3116 session
3117 .peer
3118 .forward_send(session.connection_id, peer_id.into(), request)?;
3119 Ok(())
3120}
3121
3122/// Notify other participants that a buffer has been updated. This is
3123/// allowed for guests as long as the update is limited to selections.
3124async fn update_buffer(
3125 request: proto::UpdateBuffer,
3126 response: Response<proto::UpdateBuffer>,
3127 session: Session,
3128) -> Result<()> {
3129 let project_id = ProjectId::from_proto(request.project_id);
3130 let mut capability = Capability::ReadOnly;
3131
3132 for op in request.operations.iter() {
3133 match op.variant {
3134 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
3135 Some(_) => capability = Capability::ReadWrite,
3136 }
3137 }
3138
3139 let host = {
3140 let guard = session
3141 .db()
3142 .await
3143 .connections_for_buffer_update(
3144 project_id,
3145 session.principal_id(),
3146 session.connection_id,
3147 capability,
3148 )
3149 .await?;
3150
3151 let (host, guests) = &*guard;
3152
3153 broadcast(
3154 Some(session.connection_id),
3155 guests.clone(),
3156 |connection_id| {
3157 session
3158 .peer
3159 .forward_send(session.connection_id, connection_id, request.clone())
3160 },
3161 );
3162
3163 *host
3164 };
3165
3166 if host != session.connection_id {
3167 session
3168 .peer
3169 .forward_request(session.connection_id, host, request.clone())
3170 .await?;
3171 }
3172
3173 response.send(proto::Ack {})?;
3174 Ok(())
3175}
3176
3177async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
3178 let project_id = ProjectId::from_proto(message.project_id);
3179
3180 let operation = message.operation.as_ref().context("invalid operation")?;
3181 let capability = match operation.variant.as_ref() {
3182 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
3183 if let Some(buffer_op) = buffer_op.operation.as_ref() {
3184 match buffer_op.variant {
3185 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
3186 Capability::ReadOnly
3187 }
3188 _ => Capability::ReadWrite,
3189 }
3190 } else {
3191 Capability::ReadWrite
3192 }
3193 }
3194 Some(_) => Capability::ReadWrite,
3195 None => Capability::ReadOnly,
3196 };
3197
3198 let guard = session
3199 .db()
3200 .await
3201 .connections_for_buffer_update(
3202 project_id,
3203 session.principal_id(),
3204 session.connection_id,
3205 capability,
3206 )
3207 .await?;
3208
3209 let (host, guests) = &*guard;
3210
3211 broadcast(
3212 Some(session.connection_id),
3213 guests.iter().chain([host]).copied(),
3214 |connection_id| {
3215 session
3216 .peer
3217 .forward_send(session.connection_id, connection_id, message.clone())
3218 },
3219 );
3220
3221 Ok(())
3222}
3223
3224/// Notify other participants that a project has been updated.
3225async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
3226 request: T,
3227 session: Session,
3228) -> Result<()> {
3229 let project_id = ProjectId::from_proto(request.remote_entity_id());
3230 let project_connection_ids = session
3231 .db()
3232 .await
3233 .project_connection_ids(project_id, session.connection_id, false)
3234 .await?;
3235
3236 broadcast(
3237 Some(session.connection_id),
3238 project_connection_ids.iter().copied(),
3239 |connection_id| {
3240 session
3241 .peer
3242 .forward_send(session.connection_id, connection_id, request.clone())
3243 },
3244 );
3245 Ok(())
3246}
3247
3248/// Start following another user in a call.
3249async fn follow(
3250 request: proto::Follow,
3251 response: Response<proto::Follow>,
3252 session: UserSession,
3253) -> Result<()> {
3254 let room_id = RoomId::from_proto(request.room_id);
3255 let project_id = request.project_id.map(ProjectId::from_proto);
3256 let leader_id = request
3257 .leader_id
3258 .ok_or_else(|| anyhow!("invalid leader id"))?
3259 .into();
3260 let follower_id = session.connection_id;
3261
3262 session
3263 .db()
3264 .await
3265 .check_room_participants(room_id, leader_id, session.connection_id)
3266 .await?;
3267
3268 let response_payload = session
3269 .peer
3270 .forward_request(session.connection_id, leader_id, request)
3271 .await?;
3272 response.send(response_payload)?;
3273
3274 if let Some(project_id) = project_id {
3275 let room = session
3276 .db()
3277 .await
3278 .follow(room_id, project_id, leader_id, follower_id)
3279 .await?;
3280 room_updated(&room, &session.peer);
3281 }
3282
3283 Ok(())
3284}
3285
3286/// Stop following another user in a call.
3287async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
3288 let room_id = RoomId::from_proto(request.room_id);
3289 let project_id = request.project_id.map(ProjectId::from_proto);
3290 let leader_id = request
3291 .leader_id
3292 .ok_or_else(|| anyhow!("invalid leader id"))?
3293 .into();
3294 let follower_id = session.connection_id;
3295
3296 session
3297 .db()
3298 .await
3299 .check_room_participants(room_id, leader_id, session.connection_id)
3300 .await?;
3301
3302 session
3303 .peer
3304 .forward_send(session.connection_id, leader_id, request)?;
3305
3306 if let Some(project_id) = project_id {
3307 let room = session
3308 .db()
3309 .await
3310 .unfollow(room_id, project_id, leader_id, follower_id)
3311 .await?;
3312 room_updated(&room, &session.peer);
3313 }
3314
3315 Ok(())
3316}
3317
3318/// Notify everyone following you of your current location.
3319async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
3320 let room_id = RoomId::from_proto(request.room_id);
3321 let database = session.db.lock().await;
3322
3323 let connection_ids = if let Some(project_id) = request.project_id {
3324 let project_id = ProjectId::from_proto(project_id);
3325 database
3326 .project_connection_ids(project_id, session.connection_id, true)
3327 .await?
3328 } else {
3329 database
3330 .room_connection_ids(room_id, session.connection_id)
3331 .await?
3332 };
3333
3334 // For now, don't send view update messages back to that view's current leader.
3335 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
3336 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
3337 _ => None,
3338 });
3339
3340 for connection_id in connection_ids.iter().cloned() {
3341 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
3342 session
3343 .peer
3344 .forward_send(session.connection_id, connection_id, request.clone())?;
3345 }
3346 }
3347 Ok(())
3348}
3349
3350/// Get public data about users.
3351async fn get_users(
3352 request: proto::GetUsers,
3353 response: Response<proto::GetUsers>,
3354 session: Session,
3355) -> Result<()> {
3356 let user_ids = request
3357 .user_ids
3358 .into_iter()
3359 .map(UserId::from_proto)
3360 .collect();
3361 let users = session
3362 .db()
3363 .await
3364 .get_users_by_ids(user_ids)
3365 .await?
3366 .into_iter()
3367 .map(|user| proto::User {
3368 id: user.id.to_proto(),
3369 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3370 github_login: user.github_login,
3371 })
3372 .collect();
3373 response.send(proto::UsersResponse { users })?;
3374 Ok(())
3375}
3376
3377/// Search for users (to invite) buy Github login
3378async fn fuzzy_search_users(
3379 request: proto::FuzzySearchUsers,
3380 response: Response<proto::FuzzySearchUsers>,
3381 session: UserSession,
3382) -> Result<()> {
3383 let query = request.query;
3384 let users = match query.len() {
3385 0 => vec![],
3386 1 | 2 => session
3387 .db()
3388 .await
3389 .get_user_by_github_login(&query)
3390 .await?
3391 .into_iter()
3392 .collect(),
3393 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
3394 };
3395 let users = users
3396 .into_iter()
3397 .filter(|user| user.id != session.user_id())
3398 .map(|user| proto::User {
3399 id: user.id.to_proto(),
3400 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3401 github_login: user.github_login,
3402 })
3403 .collect();
3404 response.send(proto::UsersResponse { users })?;
3405 Ok(())
3406}
3407
/// Send a contact request to another user.
///
/// Records the request in the database and then sends `UpdateContacts` messages. The
/// requester's connections get a new outgoing request and the responder's connections get
/// a new incoming request. Any notifications returned by the database are then delivered,
/// and finally the request is acked.
///
/// # Errors
/// Fails if the requester targets themselves, if the database call fails, or if sending an
/// `UpdateContacts` message to any connection fails. Notification delivery failures are
/// only logged by `send_notifications`.
async fn request_contact(
    request: proto::RequestContact,
    response: Response<proto::RequestContact>,
    session: UserSession,
) -> Result<()> {
    let requester_id = session.user_id();
    let responder_id = UserId::from_proto(request.responder_id);
    if requester_id == responder_id {
        return Err(anyhow!("cannot add yourself as a contact"))?;
    }

    let notifications = session
        .db()
        .await
        .send_contact_request(requester_id, responder_id)
        .await?;

    // Update outgoing contact requests of requester.
    // The connection-pool guard here is a temporary of the `for` header, so it is released
    // when this loop finishes and acquired again for the responder below.
    let mut update = proto::UpdateContacts::default();
    update.outgoing_requests.push(responder_id.to_proto());
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(requester_id)
    {
        session.peer.send(connection_id, update.clone())?;
    }

    // Update incoming contact requests of responder
    let mut update = proto::UpdateContacts::default();
    update
        .incoming_requests
        .push(proto::IncomingContactRequest {
            requester_id: requester_id.to_proto(),
        });
    // This guard stays held through notification delivery and the ack.
    let connection_pool = session.connection_pool().await;
    for connection_id in connection_pool.user_connection_ids(responder_id) {
        session.peer.send(connection_id, update.clone())?;
    }

    send_notifications(&connection_pool, &session.peer, notifications);

    response.send(proto::Ack {})?;
    Ok(())
}
3454
/// Accept, decline, or dismiss a contact request addressed to the session's user.
///
/// `Dismiss` only clears the contact notification in the database. Any other response is
/// recorded as accepted or not, and then both users' connections get an `UpdateContacts`:
/// - the pending request is removed on each side (incoming for the responder, outgoing for
///   the requester);
/// - when accepted, each side also gains the other as a contact, built by `contact_for_user`
///   with that user's current busy flag.
///
/// Notifications returned by the database are delivered afterwards, then the request is acked.
async fn respond_to_contact_request(
    request: proto::RespondToContactRequest,
    response: Response<proto::RespondToContactRequest>,
    session: UserSession,
) -> Result<()> {
    let responder_id = session.user_id();
    let requester_id = UserId::from_proto(request.requester_id);
    // The database guard is held for the rest of the handler, including while the
    // connection pool is locked below.
    let db = session.db().await;
    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
        db.dismiss_contact_notification(responder_id, requester_id)
            .await?;
    } else {
        // Every non-Dismiss response other than Accept is recorded as "not accepted".
        let accept = request.response == proto::ContactRequestResponse::Accept as i32;

        let notifications = db
            .respond_to_contact_request(responder_id, requester_id, accept)
            .await?;
        let requester_busy = db.is_user_busy(requester_id).await?;
        let responder_busy = db.is_user_busy(responder_id).await?;

        let pool = session.connection_pool().await;
        // Update responder with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(requester_id, requester_busy, &pool));
        }
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
        for connection_id in pool.user_connection_ids(responder_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        // Update requester with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(responder_id, responder_busy, &pool));
        }
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());

        for connection_id in pool.user_connection_ids(requester_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        send_notifications(&pool, &session.peer, notifications);
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3512
/// Remove a contact, or withdraw a contact request that was not yet accepted.
///
/// The session's user is the requester and `request.user_id` is the responder. The
/// `contact_accepted` flag returned by the database decides what both sides are told:
/// - accepted: each user is removed from the other's contact list;
/// - not accepted: the requester drops the outgoing request and the responder drops the
///   incoming one.
///
/// If the database also reports a deleted notification, each of the responder's
/// connections is told to delete it.
async fn remove_contact(
    request: proto::RemoveContact,
    response: Response<proto::RemoveContact>,
    session: UserSession,
) -> Result<()> {
    let requester_id = session.user_id();
    let responder_id = UserId::from_proto(request.user_id);
    // The database guard is held for the rest of the handler.
    let db = session.db().await;
    let (contact_accepted, deleted_notification_id) =
        db.remove_contact(requester_id, responder_id).await?;

    let pool = session.connection_pool().await;
    // Update outgoing contact requests of requester
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(responder_id.to_proto());
    } else {
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(requester_id) {
        session.peer.send(connection_id, update.clone())?;
    }

    // Update incoming contact requests of responder
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(requester_id.to_proto());
    } else {
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(responder_id) {
        session.peer.send(connection_id, update.clone())?;
        // Retract the notification the database deleted, if any, on the same connection.
        if let Some(notification_id) = deleted_notification_id {
            session.peer.send(
                connection_id,
                proto::DeleteNotification {
                    notification_id: notification_id.to_proto(),
                },
            )?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3563
3564fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
3565 version.0.minor() < 139
3566}
3567
3568async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
3569 let plan = session.current_plan().await?;
3570
3571 session
3572 .peer
3573 .send(
3574 session.connection_id,
3575 proto::UpdateUserPlan { plan: plan.into() },
3576 )
3577 .trace_err();
3578
3579 Ok(())
3580}
3581
3582async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
3583 subscribe_user_to_channels(
3584 session.user_id().ok_or_else(|| anyhow!("must be a user"))?,
3585 &session,
3586 )
3587 .await?;
3588 Ok(())
3589}
3590
3591async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
3592 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3593 let mut pool = session.connection_pool().await;
3594 for membership in &channels_for_user.channel_memberships {
3595 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3596 }
3597 session.peer.send(
3598 session.connection_id,
3599 build_update_user_channels(&channels_for_user),
3600 )?;
3601 session.peer.send(
3602 session.connection_id,
3603 build_channels_update(channels_for_user),
3604 )?;
3605 Ok(())
3606}
3607
3608/// Creates a new channel.
3609async fn create_channel(
3610 request: proto::CreateChannel,
3611 response: Response<proto::CreateChannel>,
3612 session: UserSession,
3613) -> Result<()> {
3614 let db = session.db().await;
3615
3616 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
3617 let (channel, membership) = db
3618 .create_channel(&request.name, parent_id, session.user_id())
3619 .await?;
3620
3621 let root_id = channel.root_id();
3622 let channel = Channel::from_model(channel);
3623
3624 response.send(proto::CreateChannelResponse {
3625 channel: Some(channel.to_proto()),
3626 parent_id: request.parent_id,
3627 })?;
3628
3629 let mut connection_pool = session.connection_pool().await;
3630 if let Some(membership) = membership {
3631 connection_pool.subscribe_to_channel(
3632 membership.user_id,
3633 membership.channel_id,
3634 membership.role,
3635 );
3636 let update = proto::UpdateUserChannels {
3637 channel_memberships: vec![proto::ChannelMembership {
3638 channel_id: membership.channel_id.to_proto(),
3639 role: membership.role.into(),
3640 }],
3641 ..Default::default()
3642 };
3643 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3644 session.peer.send(connection_id, update.clone())?;
3645 }
3646 }
3647
3648 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3649 if !role.can_see_channel(channel.visibility) {
3650 continue;
3651 }
3652
3653 let update = proto::UpdateChannels {
3654 channels: vec![channel.to_proto()],
3655 ..Default::default()
3656 };
3657 session.peer.send(connection_id, update.clone())?;
3658 }
3659
3660 Ok(())
3661}
3662
3663/// Delete a channel
3664async fn delete_channel(
3665 request: proto::DeleteChannel,
3666 response: Response<proto::DeleteChannel>,
3667 session: UserSession,
3668) -> Result<()> {
3669 let db = session.db().await;
3670
3671 let channel_id = request.channel_id;
3672 let (root_channel, removed_channels) = db
3673 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3674 .await?;
3675 response.send(proto::Ack {})?;
3676
3677 // Notify members of removed channels
3678 let mut update = proto::UpdateChannels::default();
3679 update
3680 .delete_channels
3681 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3682
3683 let connection_pool = session.connection_pool().await;
3684 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3685 session.peer.send(connection_id, update.clone())?;
3686 }
3687
3688 Ok(())
3689}
3690
3691/// Invite someone to join a channel.
3692async fn invite_channel_member(
3693 request: proto::InviteChannelMember,
3694 response: Response<proto::InviteChannelMember>,
3695 session: UserSession,
3696) -> Result<()> {
3697 let db = session.db().await;
3698 let channel_id = ChannelId::from_proto(request.channel_id);
3699 let invitee_id = UserId::from_proto(request.user_id);
3700 let InviteMemberResult {
3701 channel,
3702 notifications,
3703 } = db
3704 .invite_channel_member(
3705 channel_id,
3706 invitee_id,
3707 session.user_id(),
3708 request.role().into(),
3709 )
3710 .await?;
3711
3712 let update = proto::UpdateChannels {
3713 channel_invitations: vec![channel.to_proto()],
3714 ..Default::default()
3715 };
3716
3717 let connection_pool = session.connection_pool().await;
3718 for connection_id in connection_pool.user_connection_ids(invitee_id) {
3719 session.peer.send(connection_id, update.clone())?;
3720 }
3721
3722 send_notifications(&connection_pool, &session.peer, notifications);
3723
3724 response.send(proto::Ack {})?;
3725 Ok(())
3726}
3727
3728/// remove someone from a channel
3729async fn remove_channel_member(
3730 request: proto::RemoveChannelMember,
3731 response: Response<proto::RemoveChannelMember>,
3732 session: UserSession,
3733) -> Result<()> {
3734 let db = session.db().await;
3735 let channel_id = ChannelId::from_proto(request.channel_id);
3736 let member_id = UserId::from_proto(request.user_id);
3737
3738 let RemoveChannelMemberResult {
3739 membership_update,
3740 notification_id,
3741 } = db
3742 .remove_channel_member(channel_id, member_id, session.user_id())
3743 .await?;
3744
3745 let mut connection_pool = session.connection_pool().await;
3746 notify_membership_updated(
3747 &mut connection_pool,
3748 membership_update,
3749 member_id,
3750 &session.peer,
3751 );
3752 for connection_id in connection_pool.user_connection_ids(member_id) {
3753 if let Some(notification_id) = notification_id {
3754 session
3755 .peer
3756 .send(
3757 connection_id,
3758 proto::DeleteNotification {
3759 notification_id: notification_id.to_proto(),
3760 },
3761 )
3762 .trace_err();
3763 }
3764 }
3765
3766 response.send(proto::Ack {})?;
3767 Ok(())
3768}
3769
/// Set a channel's visibility (public or members-only).
///
/// The invariant that public channels only descend from public channels (members-only
/// channels may appear anywhere in the hierarchy) is expected to be enforced by
/// `Database::set_channel_visibility`. Presumably it rejects violating changes; confirm there.
///
/// Afterwards, every user subscribed to the channel's root is handled based on whether
/// their role can see the new visibility:
/// - if it can, they are (re)subscribed and sent the channel;
/// - if not, they are unsubscribed and told to delete it.
async fn set_channel_visibility(
    request: proto::SetChannelVisibility,
    response: Response<proto::SetChannelVisibility>,
    session: UserSession,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let visibility = request.visibility().into();

    let channel_model = db
        .set_channel_visibility(channel_id, visibility, session.user_id())
        .await?;
    let root_id = channel_model.root_id();
    let channel = Channel::from_model(channel_model);

    let mut connection_pool = session.connection_pool().await;
    // Snapshot the users first: the loop body mutates `connection_pool` (subscribe /
    // unsubscribe), so it cannot iterate over a borrow of the pool at the same time.
    for (user_id, role) in connection_pool
        .channel_user_ids(root_id)
        .collect::<Vec<_>>()
        .into_iter()
    {
        let update = if role.can_see_channel(channel.visibility) {
            connection_pool.subscribe_to_channel(user_id, channel_id, role);
            proto::UpdateChannels {
                channels: vec![channel.to_proto()],
                ..Default::default()
            }
        } else {
            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
            proto::UpdateChannels {
                delete_channels: vec![channel.id.to_proto()],
                ..Default::default()
            }
        };

        for connection_id in connection_pool.user_connection_ids(user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3816
3817/// Alter the role for a user in the channel.
3818async fn set_channel_member_role(
3819 request: proto::SetChannelMemberRole,
3820 response: Response<proto::SetChannelMemberRole>,
3821 session: UserSession,
3822) -> Result<()> {
3823 let db = session.db().await;
3824 let channel_id = ChannelId::from_proto(request.channel_id);
3825 let member_id = UserId::from_proto(request.user_id);
3826 let result = db
3827 .set_channel_member_role(
3828 channel_id,
3829 session.user_id(),
3830 member_id,
3831 request.role().into(),
3832 )
3833 .await?;
3834
3835 match result {
3836 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3837 let mut connection_pool = session.connection_pool().await;
3838 notify_membership_updated(
3839 &mut connection_pool,
3840 membership_update,
3841 member_id,
3842 &session.peer,
3843 )
3844 }
3845 db::SetMemberRoleResult::InviteUpdated(channel) => {
3846 let update = proto::UpdateChannels {
3847 channel_invitations: vec![channel.to_proto()],
3848 ..Default::default()
3849 };
3850
3851 for connection_id in session
3852 .connection_pool()
3853 .await
3854 .user_connection_ids(member_id)
3855 {
3856 session.peer.send(connection_id, update.clone())?;
3857 }
3858 }
3859 }
3860
3861 response.send(proto::Ack {})?;
3862 Ok(())
3863}
3864
3865/// Change the name of a channel
3866async fn rename_channel(
3867 request: proto::RenameChannel,
3868 response: Response<proto::RenameChannel>,
3869 session: UserSession,
3870) -> Result<()> {
3871 let db = session.db().await;
3872 let channel_id = ChannelId::from_proto(request.channel_id);
3873 let channel_model = db
3874 .rename_channel(channel_id, session.user_id(), &request.name)
3875 .await?;
3876 let root_id = channel_model.root_id();
3877 let channel = Channel::from_model(channel_model);
3878
3879 response.send(proto::RenameChannelResponse {
3880 channel: Some(channel.to_proto()),
3881 })?;
3882
3883 let connection_pool = session.connection_pool().await;
3884 let update = proto::UpdateChannels {
3885 channels: vec![channel.to_proto()],
3886 ..Default::default()
3887 };
3888 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3889 if role.can_see_channel(channel.visibility) {
3890 session.peer.send(connection_id, update.clone())?;
3891 }
3892 }
3893
3894 Ok(())
3895}
3896
3897/// Move a channel to a new parent.
3898async fn move_channel(
3899 request: proto::MoveChannel,
3900 response: Response<proto::MoveChannel>,
3901 session: UserSession,
3902) -> Result<()> {
3903 let channel_id = ChannelId::from_proto(request.channel_id);
3904 let to = ChannelId::from_proto(request.to);
3905
3906 let (root_id, channels) = session
3907 .db()
3908 .await
3909 .move_channel(channel_id, to, session.user_id())
3910 .await?;
3911
3912 let connection_pool = session.connection_pool().await;
3913 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3914 let channels = channels
3915 .iter()
3916 .filter_map(|channel| {
3917 if role.can_see_channel(channel.visibility) {
3918 Some(channel.to_proto())
3919 } else {
3920 None
3921 }
3922 })
3923 .collect::<Vec<_>>();
3924 if channels.is_empty() {
3925 continue;
3926 }
3927
3928 let update = proto::UpdateChannels {
3929 channels,
3930 ..Default::default()
3931 };
3932
3933 session.peer.send(connection_id, update.clone())?;
3934 }
3935
3936 response.send(Ack {})?;
3937 Ok(())
3938}
3939
3940/// Get the list of channel members
3941async fn get_channel_members(
3942 request: proto::GetChannelMembers,
3943 response: Response<proto::GetChannelMembers>,
3944 session: UserSession,
3945) -> Result<()> {
3946 let db = session.db().await;
3947 let channel_id = ChannelId::from_proto(request.channel_id);
3948 let limit = if request.limit == 0 {
3949 u16::MAX as u64
3950 } else {
3951 request.limit
3952 };
3953 let (members, users) = db
3954 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3955 .await?;
3956 response.send(proto::GetChannelMembersResponse { members, users })?;
3957 Ok(())
3958}
3959
3960/// Accept or decline a channel invitation.
3961async fn respond_to_channel_invite(
3962 request: proto::RespondToChannelInvite,
3963 response: Response<proto::RespondToChannelInvite>,
3964 session: UserSession,
3965) -> Result<()> {
3966 let db = session.db().await;
3967 let channel_id = ChannelId::from_proto(request.channel_id);
3968 let RespondToChannelInvite {
3969 membership_update,
3970 notifications,
3971 } = db
3972 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3973 .await?;
3974
3975 let mut connection_pool = session.connection_pool().await;
3976 if let Some(membership_update) = membership_update {
3977 notify_membership_updated(
3978 &mut connection_pool,
3979 membership_update,
3980 session.user_id(),
3981 &session.peer,
3982 );
3983 } else {
3984 let update = proto::UpdateChannels {
3985 remove_channel_invitations: vec![channel_id.to_proto()],
3986 ..Default::default()
3987 };
3988
3989 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3990 session.peer.send(connection_id, update.clone())?;
3991 }
3992 };
3993
3994 send_notifications(&connection_pool, &session.peer, notifications);
3995
3996 response.send(proto::Ack {})?;
3997
3998 Ok(())
3999}
4000
4001/// Join the channels' room
4002async fn join_channel(
4003 request: proto::JoinChannel,
4004 response: Response<proto::JoinChannel>,
4005 session: UserSession,
4006) -> Result<()> {
4007 let channel_id = ChannelId::from_proto(request.channel_id);
4008 join_channel_internal(channel_id, Box::new(response), session).await
4009}
4010
/// A reply handle that `join_channel_internal` can complete with a `JoinRoomResponse`.
///
/// It is implemented for both `Response<proto::JoinChannel>` and
/// `Response<proto::JoinRoom>`, so the same join logic can answer either request type.
trait JoinChannelInternalResponse {
    /// Send `result` as the reply to the originating request.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// Forwards to `Response::send`, named through a type-qualified path to make clear this is
// not a recursive call to the trait method (presumably an inherent method; confirm).
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
// Same forwarding as above, for replies to `proto::JoinRoom`.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
4024
/// Shared implementation for joining a channel's room.
///
/// `response` may be a reply to either `JoinChannel` or `JoinRoom`. The steps, in order:
/// 1. If the database reports a stale room connection for this user, clean it up with
///    `leave_room_for_session`.
/// 2. Join the channel in the database. This yields the joined room, an optional
///    membership update, and the user's role.
/// 3. Reply with the room, its channel id, and LiveKit connection info when a LiveKit client
///    is configured. Guests get a guest token and `can_publish: false`; everyone else gets a
///    room token and `can_publish: true`.
/// 4. Apply the membership update (if any) to the connection pool and broadcast the room.
/// 5. With the database and pool guards released, broadcast the channel and refresh the
///    user's contacts.
///
/// # Errors
/// Fails on database errors, on a failed reply, if the joined room carries no channel
/// (checked after the reply was sent), or if refreshing contacts fails. A failure to mint a
/// LiveKit token is only traced and produces `live_kit_connection_info: None`.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: UserSession,
) -> Result<()> {
    // Guards taken inside this block (`db`, `connection_pool`) are dropped when it ends.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database guard before calling `leave_room_for_session`
            // (presumably it acquires the database itself), then re-acquire it.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `trace_err()?` makes the whole closure return `None` on a token error, so a
        // LiveKit failure does not fail the join.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // Locks the connection pool again. This must happen outside the block above, where
    // the pool guard was still held.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
4115
4116/// Start editing the channel notes
4117async fn join_channel_buffer(
4118 request: proto::JoinChannelBuffer,
4119 response: Response<proto::JoinChannelBuffer>,
4120 session: UserSession,
4121) -> Result<()> {
4122 let db = session.db().await;
4123 let channel_id = ChannelId::from_proto(request.channel_id);
4124
4125 let open_response = db
4126 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
4127 .await?;
4128
4129 let collaborators = open_response.collaborators.clone();
4130 response.send(open_response)?;
4131
4132 let update = UpdateChannelBufferCollaborators {
4133 channel_id: channel_id.to_proto(),
4134 collaborators: collaborators.clone(),
4135 };
4136 channel_buffer_updated(
4137 session.connection_id,
4138 collaborators
4139 .iter()
4140 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
4141 &update,
4142 &session.peer,
4143 );
4144
4145 Ok(())
4146}
4147
/// Edit the channel notes.
///
/// Applies `request.operations` to the channel buffer in the database. The same operations
/// are forwarded to the collaborators the database returns, with this connection as the
/// `broadcast` sender. Every connection subscribed to the channel that is not one of those
/// collaborators is instead sent an `UpdateChannels` carrying the buffer's new epoch and
/// version.
async fn update_channel_buffer(
    request: proto::UpdateChannelBuffer,
    session: UserSession,
) -> Result<()> {
    // Held until the end of the handler, including while the connection pool is locked.
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);

    let (collaborators, epoch, version) = db
        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
        .await?;

    channel_buffer_updated(
        session.connection_id,
        collaborators.clone(),
        &proto::UpdateChannelBuffer {
            channel_id: channel_id.to_proto(),
            operations: request.operations,
        },
        &session.peer,
    );

    let pool = &*session.connection_pool().await;

    // Subscribed connections that are not editing the buffer only learn the latest version.
    // `contains` is a linear scan of `collaborators` per subscribed connection.
    let non_collaborators =
        pool.channel_connection_ids(channel_id)
            .filter_map(|(connection_id, _)| {
                if collaborators.contains(&connection_id) {
                    None
                } else {
                    Some(connection_id)
                }
            });

    broadcast(None, non_collaborators, |peer_id| {
        session.peer.send(
            peer_id,
            proto::UpdateChannels {
                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
                    channel_id: channel_id.to_proto(),
                    // NOTE(review): `as` would wrap a negative epoch; presumably never negative — confirm.
                    epoch: epoch as u64,
                    version: version.clone(),
                }],
                ..Default::default()
            },
        )
    });

    Ok(())
}
4198
4199/// Rejoin the channel notes after a connection blip
4200async fn rejoin_channel_buffers(
4201 request: proto::RejoinChannelBuffers,
4202 response: Response<proto::RejoinChannelBuffers>,
4203 session: UserSession,
4204) -> Result<()> {
4205 let db = session.db().await;
4206 let buffers = db
4207 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
4208 .await?;
4209
4210 for rejoined_buffer in &buffers {
4211 let collaborators_to_notify = rejoined_buffer
4212 .buffer
4213 .collaborators
4214 .iter()
4215 .filter_map(|c| Some(c.peer_id?.into()));
4216 channel_buffer_updated(
4217 session.connection_id,
4218 collaborators_to_notify,
4219 &proto::UpdateChannelBufferCollaborators {
4220 channel_id: rejoined_buffer.buffer.channel_id,
4221 collaborators: rejoined_buffer.buffer.collaborators.clone(),
4222 },
4223 &session.peer,
4224 );
4225 }
4226
4227 response.send(proto::RejoinChannelBuffersResponse {
4228 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
4229 })?;
4230
4231 Ok(())
4232}
4233
4234/// Stop editing the channel notes
4235async fn leave_channel_buffer(
4236 request: proto::LeaveChannelBuffer,
4237 response: Response<proto::LeaveChannelBuffer>,
4238 session: UserSession,
4239) -> Result<()> {
4240 let db = session.db().await;
4241 let channel_id = ChannelId::from_proto(request.channel_id);
4242
4243 let left_buffer = db
4244 .leave_channel_buffer(channel_id, session.connection_id)
4245 .await?;
4246
4247 response.send(Ack {})?;
4248
4249 channel_buffer_updated(
4250 session.connection_id,
4251 left_buffer.connections,
4252 &proto::UpdateChannelBufferCollaborators {
4253 channel_id: channel_id.to_proto(),
4254 collaborators: left_buffer.collaborators,
4255 },
4256 &session.peer,
4257 );
4258
4259 Ok(())
4260}
4261
4262fn channel_buffer_updated<T: EnvelopedMessage>(
4263 sender_id: ConnectionId,
4264 collaborators: impl IntoIterator<Item = ConnectionId>,
4265 message: &T,
4266 peer: &Peer,
4267) {
4268 broadcast(Some(sender_id), collaborators, |peer_id| {
4269 peer.send(peer_id, message.clone())
4270 });
4271}
4272
4273fn send_notifications(
4274 connection_pool: &ConnectionPool,
4275 peer: &Peer,
4276 notifications: db::NotificationBatch,
4277) {
4278 for (user_id, notification) in notifications {
4279 for connection_id in connection_pool.user_connection_ids(user_id) {
4280 if let Err(error) = peer.send(
4281 connection_id,
4282 proto::AddNotification {
4283 notification: Some(notification.clone()),
4284 },
4285 ) {
4286 tracing::error!(
4287 "failed to send notification to {:?} {}",
4288 connection_id,
4289 error
4290 );
4291 }
4292 }
4293 }
4294}
4295
4296/// Send a message to the channel
4297async fn send_channel_message(
4298 request: proto::SendChannelMessage,
4299 response: Response<proto::SendChannelMessage>,
4300 session: UserSession,
4301) -> Result<()> {
4302 // Validate the message body.
4303 let body = request.body.trim().to_string();
4304 if body.len() > MAX_MESSAGE_LEN {
4305 return Err(anyhow!("message is too long"))?;
4306 }
4307 if body.is_empty() {
4308 return Err(anyhow!("message can't be blank"))?;
4309 }
4310
4311 // TODO: adjust mentions if body is trimmed
4312
4313 let timestamp = OffsetDateTime::now_utc();
4314 let nonce = request
4315 .nonce
4316 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4317
4318 let channel_id = ChannelId::from_proto(request.channel_id);
4319 let CreatedChannelMessage {
4320 message_id,
4321 participant_connection_ids,
4322 notifications,
4323 } = session
4324 .db()
4325 .await
4326 .create_channel_message(
4327 channel_id,
4328 session.user_id(),
4329 &body,
4330 &request.mentions,
4331 timestamp,
4332 nonce.clone().into(),
4333 match request.reply_to_message_id {
4334 Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
4335 None => None,
4336 },
4337 )
4338 .await?;
4339
4340 let message = proto::ChannelMessage {
4341 sender_id: session.user_id().to_proto(),
4342 id: message_id.to_proto(),
4343 body,
4344 mentions: request.mentions,
4345 timestamp: timestamp.unix_timestamp() as u64,
4346 nonce: Some(nonce),
4347 reply_to_message_id: request.reply_to_message_id,
4348 edited_at: None,
4349 };
4350 broadcast(
4351 Some(session.connection_id),
4352 participant_connection_ids.clone(),
4353 |connection| {
4354 session.peer.send(
4355 connection,
4356 proto::ChannelMessageSent {
4357 channel_id: channel_id.to_proto(),
4358 message: Some(message.clone()),
4359 },
4360 )
4361 },
4362 );
4363 response.send(proto::SendChannelMessageResponse {
4364 message: Some(message),
4365 })?;
4366
4367 let pool = &*session.connection_pool().await;
4368 let non_participants =
4369 pool.channel_connection_ids(channel_id)
4370 .filter_map(|(connection_id, _)| {
4371 if participant_connection_ids.contains(&connection_id) {
4372 None
4373 } else {
4374 Some(connection_id)
4375 }
4376 });
4377 broadcast(None, non_participants, |peer_id| {
4378 session.peer.send(
4379 peer_id,
4380 proto::UpdateChannels {
4381 latest_channel_message_ids: vec![proto::ChannelMessageId {
4382 channel_id: channel_id.to_proto(),
4383 message_id: message_id.to_proto(),
4384 }],
4385 ..Default::default()
4386 },
4387 )
4388 });
4389 send_notifications(pool, &session.peer, notifications);
4390
4391 Ok(())
4392}
4393
4394/// Delete a channel message
4395async fn remove_channel_message(
4396 request: proto::RemoveChannelMessage,
4397 response: Response<proto::RemoveChannelMessage>,
4398 session: UserSession,
4399) -> Result<()> {
4400 let channel_id = ChannelId::from_proto(request.channel_id);
4401 let message_id = MessageId::from_proto(request.message_id);
4402 let (connection_ids, existing_notification_ids) = session
4403 .db()
4404 .await
4405 .remove_channel_message(channel_id, message_id, session.user_id())
4406 .await?;
4407
4408 broadcast(
4409 Some(session.connection_id),
4410 connection_ids,
4411 move |connection| {
4412 session.peer.send(connection, request.clone())?;
4413
4414 for notification_id in &existing_notification_ids {
4415 session.peer.send(
4416 connection,
4417 proto::DeleteNotification {
4418 notification_id: (*notification_id).to_proto(),
4419 },
4420 )?;
4421 }
4422
4423 Ok(())
4424 },
4425 );
4426 response.send(proto::Ack {})?;
4427 Ok(())
4428}
4429
4430async fn update_channel_message(
4431 request: proto::UpdateChannelMessage,
4432 response: Response<proto::UpdateChannelMessage>,
4433 session: UserSession,
4434) -> Result<()> {
4435 let channel_id = ChannelId::from_proto(request.channel_id);
4436 let message_id = MessageId::from_proto(request.message_id);
4437 let updated_at = OffsetDateTime::now_utc();
4438 let UpdatedChannelMessage {
4439 message_id,
4440 participant_connection_ids,
4441 notifications,
4442 reply_to_message_id,
4443 timestamp,
4444 deleted_mention_notification_ids,
4445 updated_mention_notifications,
4446 } = session
4447 .db()
4448 .await
4449 .update_channel_message(
4450 channel_id,
4451 message_id,
4452 session.user_id(),
4453 request.body.as_str(),
4454 &request.mentions,
4455 updated_at,
4456 )
4457 .await?;
4458
4459 let nonce = request
4460 .nonce
4461 .clone()
4462 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4463
4464 let message = proto::ChannelMessage {
4465 sender_id: session.user_id().to_proto(),
4466 id: message_id.to_proto(),
4467 body: request.body.clone(),
4468 mentions: request.mentions.clone(),
4469 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
4470 nonce: Some(nonce),
4471 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
4472 edited_at: Some(updated_at.unix_timestamp() as u64),
4473 };
4474
4475 response.send(proto::Ack {})?;
4476
4477 let pool = &*session.connection_pool().await;
4478 broadcast(
4479 Some(session.connection_id),
4480 participant_connection_ids,
4481 |connection| {
4482 session.peer.send(
4483 connection,
4484 proto::ChannelMessageUpdate {
4485 channel_id: channel_id.to_proto(),
4486 message: Some(message.clone()),
4487 },
4488 )?;
4489
4490 for notification_id in &deleted_mention_notification_ids {
4491 session.peer.send(
4492 connection,
4493 proto::DeleteNotification {
4494 notification_id: (*notification_id).to_proto(),
4495 },
4496 )?;
4497 }
4498
4499 for notification in &updated_mention_notifications {
4500 session.peer.send(
4501 connection,
4502 proto::UpdateNotification {
4503 notification: Some(notification.clone()),
4504 },
4505 )?;
4506 }
4507
4508 Ok(())
4509 },
4510 );
4511
4512 send_notifications(pool, &session.peer, notifications);
4513
4514 Ok(())
4515}
4516
4517/// Mark a channel message as read
4518async fn acknowledge_channel_message(
4519 request: proto::AckChannelMessage,
4520 session: UserSession,
4521) -> Result<()> {
4522 let channel_id = ChannelId::from_proto(request.channel_id);
4523 let message_id = MessageId::from_proto(request.message_id);
4524 let notifications = session
4525 .db()
4526 .await
4527 .observe_channel_message(channel_id, session.user_id(), message_id)
4528 .await?;
4529 send_notifications(
4530 &*session.connection_pool().await,
4531 &session.peer,
4532 notifications,
4533 );
4534 Ok(())
4535}
4536
4537/// Mark a buffer version as synced
4538async fn acknowledge_buffer_version(
4539 request: proto::AckBufferOperation,
4540 session: UserSession,
4541) -> Result<()> {
4542 let buffer_id = BufferId::from_proto(request.buffer_id);
4543 session
4544 .db()
4545 .await
4546 .observe_buffer_version(
4547 buffer_id,
4548 session.user_id(),
4549 request.epoch as i32,
4550 &request.version,
4551 )
4552 .await?;
4553 Ok(())
4554}
4555
4556struct ZedProCompleteWithLanguageModelRateLimit;
4557
4558impl RateLimit for ZedProCompleteWithLanguageModelRateLimit {
4559 fn capacity(&self) -> usize {
4560 std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4561 .ok()
4562 .and_then(|v| v.parse().ok())
4563 .unwrap_or(120) // Picked arbitrarily
4564 }
4565
4566 fn refill_duration(&self) -> chrono::Duration {
4567 chrono::Duration::hours(1)
4568 }
4569
4570 fn db_name(&self) -> &'static str {
4571 "zed-pro:complete-with-language-model"
4572 }
4573}
4574
4575struct FreeCompleteWithLanguageModelRateLimit;
4576
4577impl RateLimit for FreeCompleteWithLanguageModelRateLimit {
4578 fn capacity(&self) -> usize {
4579 std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR_FREE")
4580 .ok()
4581 .and_then(|v| v.parse().ok())
4582 .unwrap_or(120 / 10) // Picked arbitrarily
4583 }
4584
4585 fn refill_duration(&self) -> chrono::Duration {
4586 chrono::Duration::hours(1)
4587 }
4588
4589 fn db_name(&self) -> &'static str {
4590 "free:complete-with-language-model"
4591 }
4592}
4593
4594async fn complete_with_language_model(
4595 request: proto::CompleteWithLanguageModel,
4596 response: Response<proto::CompleteWithLanguageModel>,
4597 session: Session,
4598 config: &Config,
4599) -> Result<()> {
4600 let Some(session) = session.for_user() else {
4601 return Err(anyhow!("user not found"))?;
4602 };
4603 authorize_access_to_language_models(&session).await?;
4604
4605 let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
4606 proto::Plan::ZedPro => Box::new(ZedProCompleteWithLanguageModelRateLimit),
4607 proto::Plan::Free => Box::new(FreeCompleteWithLanguageModelRateLimit),
4608 };
4609
4610 session
4611 .rate_limiter
4612 .check(&*rate_limit, session.user_id())
4613 .await?;
4614
4615 let result = match proto::LanguageModelProvider::from_i32(request.provider) {
4616 Some(proto::LanguageModelProvider::Anthropic) => {
4617 let api_key = config
4618 .anthropic_api_key
4619 .as_ref()
4620 .context("no Anthropic AI API key configured on the server")?;
4621 anthropic::complete(
4622 session.http_client.as_ref(),
4623 anthropic::ANTHROPIC_API_URL,
4624 api_key,
4625 serde_json::from_str(&request.request)?,
4626 )
4627 .await?
4628 }
4629 _ => return Err(anyhow!("unsupported provider"))?,
4630 };
4631
4632 response.send(proto::CompleteWithLanguageModelResponse {
4633 completion: serde_json::to_string(&result)?,
4634 })?;
4635
4636 Ok(())
4637}
4638
4639async fn stream_complete_with_language_model(
4640 request: proto::StreamCompleteWithLanguageModel,
4641 response: StreamingResponse<proto::StreamCompleteWithLanguageModel>,
4642 session: Session,
4643 config: &Config,
4644) -> Result<()> {
4645 let Some(session) = session.for_user() else {
4646 return Err(anyhow!("user not found"))?;
4647 };
4648 authorize_access_to_language_models(&session).await?;
4649
4650 let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
4651 proto::Plan::ZedPro => Box::new(ZedProCompleteWithLanguageModelRateLimit),
4652 proto::Plan::Free => Box::new(FreeCompleteWithLanguageModelRateLimit),
4653 };
4654
4655 session
4656 .rate_limiter
4657 .check(&*rate_limit, session.user_id())
4658 .await?;
4659
4660 match proto::LanguageModelProvider::from_i32(request.provider) {
4661 Some(proto::LanguageModelProvider::Anthropic) => {
4662 let api_key = config
4663 .anthropic_api_key
4664 .as_ref()
4665 .context("no Anthropic AI API key configured on the server")?;
4666 let mut chunks = anthropic::stream_completion(
4667 session.http_client.as_ref(),
4668 anthropic::ANTHROPIC_API_URL,
4669 api_key,
4670 serde_json::from_str(&request.request)?,
4671 None,
4672 )
4673 .await?;
4674 while let Some(event) = chunks.next().await {
4675 let chunk = event?;
4676 response.send(proto::StreamCompleteWithLanguageModelResponse {
4677 event: serde_json::to_string(&chunk)?,
4678 })?;
4679 }
4680 }
4681 Some(proto::LanguageModelProvider::OpenAi) => {
4682 let api_key = config
4683 .openai_api_key
4684 .as_ref()
4685 .context("no OpenAI API key configured on the server")?;
4686 let mut events = open_ai::stream_completion(
4687 session.http_client.as_ref(),
4688 open_ai::OPEN_AI_API_URL,
4689 api_key,
4690 serde_json::from_str(&request.request)?,
4691 None,
4692 )
4693 .await?;
4694 while let Some(event) = events.next().await {
4695 let event = event?;
4696 response.send(proto::StreamCompleteWithLanguageModelResponse {
4697 event: serde_json::to_string(&event)?,
4698 })?;
4699 }
4700 }
4701 Some(proto::LanguageModelProvider::Google) => {
4702 let api_key = config
4703 .google_ai_api_key
4704 .as_ref()
4705 .context("no Google AI API key configured on the server")?;
4706 let mut events = google_ai::stream_generate_content(
4707 session.http_client.as_ref(),
4708 google_ai::API_URL,
4709 api_key,
4710 serde_json::from_str(&request.request)?,
4711 )
4712 .await?;
4713 while let Some(event) = events.next().await {
4714 let event = event?;
4715 response.send(proto::StreamCompleteWithLanguageModelResponse {
4716 event: serde_json::to_string(&event)?,
4717 })?;
4718 }
4719 }
4720 Some(proto::LanguageModelProvider::Zed) => {
4721 let api_key = config
4722 .qwen2_7b_api_key
4723 .as_ref()
4724 .context("no Qwen2-7B API key configured on the server")?;
4725 let api_url = config
4726 .qwen2_7b_api_url
4727 .as_ref()
4728 .context("no Qwen2-7B URL configured on the server")?;
4729 let mut events = open_ai::stream_completion(
4730 session.http_client.as_ref(),
4731 &api_url,
4732 api_key,
4733 serde_json::from_str(&request.request)?,
4734 None,
4735 )
4736 .await?;
4737 while let Some(event) = events.next().await {
4738 let event = event?;
4739 response.send(proto::StreamCompleteWithLanguageModelResponse {
4740 event: serde_json::to_string(&event)?,
4741 })?;
4742 }
4743 }
4744 None => return Err(anyhow!("unknown provider"))?,
4745 }
4746
4747 Ok(())
4748}
4749
4750async fn count_language_model_tokens(
4751 request: proto::CountLanguageModelTokens,
4752 response: Response<proto::CountLanguageModelTokens>,
4753 session: Session,
4754 config: &Config,
4755) -> Result<()> {
4756 let Some(session) = session.for_user() else {
4757 return Err(anyhow!("user not found"))?;
4758 };
4759 authorize_access_to_language_models(&session).await?;
4760
4761 let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
4762 proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
4763 proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
4764 };
4765
4766 session
4767 .rate_limiter
4768 .check(&*rate_limit, session.user_id())
4769 .await?;
4770
4771 let result = match proto::LanguageModelProvider::from_i32(request.provider) {
4772 Some(proto::LanguageModelProvider::Google) => {
4773 let api_key = config
4774 .google_ai_api_key
4775 .as_ref()
4776 .context("no Google AI API key configured on the server")?;
4777 google_ai::count_tokens(
4778 session.http_client.as_ref(),
4779 google_ai::API_URL,
4780 api_key,
4781 serde_json::from_str(&request.request)?,
4782 )
4783 .await?
4784 }
4785 _ => return Err(anyhow!("unsupported provider"))?,
4786 };
4787
4788 response.send(proto::CountLanguageModelTokensResponse {
4789 token_count: result.total_tokens as u32,
4790 })?;
4791
4792 Ok(())
4793}
4794
4795struct ZedProCountLanguageModelTokensRateLimit;
4796
4797impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
4798 fn capacity(&self) -> usize {
4799 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
4800 .ok()
4801 .and_then(|v| v.parse().ok())
4802 .unwrap_or(600) // Picked arbitrarily
4803 }
4804
4805 fn refill_duration(&self) -> chrono::Duration {
4806 chrono::Duration::hours(1)
4807 }
4808
4809 fn db_name(&self) -> &'static str {
4810 "zed-pro:count-language-model-tokens"
4811 }
4812}
4813
4814struct FreeCountLanguageModelTokensRateLimit;
4815
4816impl RateLimit for FreeCountLanguageModelTokensRateLimit {
4817 fn capacity(&self) -> usize {
4818 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
4819 .ok()
4820 .and_then(|v| v.parse().ok())
4821 .unwrap_or(600 / 10) // Picked arbitrarily
4822 }
4823
4824 fn refill_duration(&self) -> chrono::Duration {
4825 chrono::Duration::hours(1)
4826 }
4827
4828 fn db_name(&self) -> &'static str {
4829 "free:count-language-model-tokens"
4830 }
4831}
4832
4833struct ZedProComputeEmbeddingsRateLimit;
4834
4835impl RateLimit for ZedProComputeEmbeddingsRateLimit {
4836 fn capacity(&self) -> usize {
4837 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4838 .ok()
4839 .and_then(|v| v.parse().ok())
4840 .unwrap_or(5000) // Picked arbitrarily
4841 }
4842
4843 fn refill_duration(&self) -> chrono::Duration {
4844 chrono::Duration::hours(1)
4845 }
4846
4847 fn db_name(&self) -> &'static str {
4848 "zed-pro:compute-embeddings"
4849 }
4850}
4851
4852struct FreeComputeEmbeddingsRateLimit;
4853
4854impl RateLimit for FreeComputeEmbeddingsRateLimit {
4855 fn capacity(&self) -> usize {
4856 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
4857 .ok()
4858 .and_then(|v| v.parse().ok())
4859 .unwrap_or(5000 / 10) // Picked arbitrarily
4860 }
4861
4862 fn refill_duration(&self) -> chrono::Duration {
4863 chrono::Duration::hours(1)
4864 }
4865
4866 fn db_name(&self) -> &'static str {
4867 "free:compute-embeddings"
4868 }
4869}
4870
4871async fn compute_embeddings(
4872 request: proto::ComputeEmbeddings,
4873 response: Response<proto::ComputeEmbeddings>,
4874 session: UserSession,
4875 api_key: Option<Arc<str>>,
4876) -> Result<()> {
4877 let api_key = api_key.context("no OpenAI API key configured on the server")?;
4878 authorize_access_to_language_models(&session).await?;
4879
4880 let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
4881 proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
4882 proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
4883 };
4884
4885 session
4886 .rate_limiter
4887 .check(&*rate_limit, session.user_id())
4888 .await?;
4889
4890 let embeddings = match request.model.as_str() {
4891 "openai/text-embedding-3-small" => {
4892 open_ai::embed(
4893 session.http_client.as_ref(),
4894 OPEN_AI_API_URL,
4895 &api_key,
4896 OpenAiEmbeddingModel::TextEmbedding3Small,
4897 request.texts.iter().map(|text| text.as_str()),
4898 )
4899 .await?
4900 }
4901 provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
4902 };
4903
4904 let embeddings = request
4905 .texts
4906 .iter()
4907 .map(|text| {
4908 let mut hasher = sha2::Sha256::new();
4909 hasher.update(text.as_bytes());
4910 let result = hasher.finalize();
4911 result.to_vec()
4912 })
4913 .zip(
4914 embeddings
4915 .data
4916 .into_iter()
4917 .map(|embedding| embedding.embedding),
4918 )
4919 .collect::<HashMap<_, _>>();
4920
4921 let db = session.db().await;
4922 db.save_embeddings(&request.model, &embeddings)
4923 .await
4924 .context("failed to save embeddings")
4925 .trace_err();
4926
4927 response.send(proto::ComputeEmbeddingsResponse {
4928 embeddings: embeddings
4929 .into_iter()
4930 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4931 .collect(),
4932 })?;
4933 Ok(())
4934}
4935
4936async fn get_cached_embeddings(
4937 request: proto::GetCachedEmbeddings,
4938 response: Response<proto::GetCachedEmbeddings>,
4939 session: UserSession,
4940) -> Result<()> {
4941 authorize_access_to_language_models(&session).await?;
4942
4943 let db = session.db().await;
4944 let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
4945
4946 response.send(proto::GetCachedEmbeddingsResponse {
4947 embeddings: embeddings
4948 .into_iter()
4949 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4950 .collect(),
4951 })?;
4952 Ok(())
4953}
4954
4955async fn authorize_access_to_language_models(session: &UserSession) -> Result<(), Error> {
4956 let db = session.db().await;
4957 let flags = db.get_user_flags(session.user_id()).await?;
4958 if flags.iter().any(|flag| flag == "language-models") {
4959 return Ok(());
4960 }
4961
4962 Err(anyhow!("permission denied"))?
4963}
4964
4965/// Get a Supermaven API key for the user
4966async fn get_supermaven_api_key(
4967 _request: proto::GetSupermavenApiKey,
4968 response: Response<proto::GetSupermavenApiKey>,
4969 session: UserSession,
4970) -> Result<()> {
4971 let user_id: String = session.user_id().to_string();
4972 if !session.is_staff() {
4973 return Err(anyhow!("supermaven not enabled for this account"))?;
4974 }
4975
4976 let email = session
4977 .email()
4978 .ok_or_else(|| anyhow!("user must have an email"))?;
4979
4980 let supermaven_admin_api = session
4981 .supermaven_client
4982 .as_ref()
4983 .ok_or_else(|| anyhow!("supermaven not configured"))?;
4984
4985 let result = supermaven_admin_api
4986 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4987 .await?;
4988
4989 response.send(proto::GetSupermavenApiKeyResponse {
4990 api_key: result.api_key,
4991 })?;
4992
4993 Ok(())
4994}
4995
4996/// Start receiving chat updates for a channel
4997async fn join_channel_chat(
4998 request: proto::JoinChannelChat,
4999 response: Response<proto::JoinChannelChat>,
5000 session: UserSession,
5001) -> Result<()> {
5002 let channel_id = ChannelId::from_proto(request.channel_id);
5003
5004 let db = session.db().await;
5005 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
5006 .await?;
5007 let messages = db
5008 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
5009 .await?;
5010 response.send(proto::JoinChannelChatResponse {
5011 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
5012 messages,
5013 })?;
5014 Ok(())
5015}
5016
5017/// Stop receiving chat updates for a channel
5018async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
5019 let channel_id = ChannelId::from_proto(request.channel_id);
5020 session
5021 .db()
5022 .await
5023 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
5024 .await?;
5025 Ok(())
5026}
5027
5028/// Retrieve the chat history for a channel
5029async fn get_channel_messages(
5030 request: proto::GetChannelMessages,
5031 response: Response<proto::GetChannelMessages>,
5032 session: UserSession,
5033) -> Result<()> {
5034 let channel_id = ChannelId::from_proto(request.channel_id);
5035 let messages = session
5036 .db()
5037 .await
5038 .get_channel_messages(
5039 channel_id,
5040 session.user_id(),
5041 MESSAGE_COUNT_PER_PAGE,
5042 Some(MessageId::from_proto(request.before_message_id)),
5043 )
5044 .await?;
5045 response.send(proto::GetChannelMessagesResponse {
5046 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
5047 messages,
5048 })?;
5049 Ok(())
5050}
5051
5052/// Retrieve specific chat messages
5053async fn get_channel_messages_by_id(
5054 request: proto::GetChannelMessagesById,
5055 response: Response<proto::GetChannelMessagesById>,
5056 session: UserSession,
5057) -> Result<()> {
5058 let message_ids = request
5059 .message_ids
5060 .iter()
5061 .map(|id| MessageId::from_proto(*id))
5062 .collect::<Vec<_>>();
5063 let messages = session
5064 .db()
5065 .await
5066 .get_channel_messages_by_id(session.user_id(), &message_ids)
5067 .await?;
5068 response.send(proto::GetChannelMessagesResponse {
5069 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
5070 messages,
5071 })?;
5072 Ok(())
5073}
5074
5075/// Retrieve the current users notifications
5076async fn get_notifications(
5077 request: proto::GetNotifications,
5078 response: Response<proto::GetNotifications>,
5079 session: UserSession,
5080) -> Result<()> {
5081 let notifications = session
5082 .db()
5083 .await
5084 .get_notifications(
5085 session.user_id(),
5086 NOTIFICATION_COUNT_PER_PAGE,
5087 request
5088 .before_id
5089 .map(|id| db::NotificationId::from_proto(id)),
5090 )
5091 .await?;
5092 response.send(proto::GetNotificationsResponse {
5093 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
5094 notifications,
5095 })?;
5096 Ok(())
5097}
5098
5099/// Mark notifications as read
5100async fn mark_notification_as_read(
5101 request: proto::MarkNotificationRead,
5102 response: Response<proto::MarkNotificationRead>,
5103 session: UserSession,
5104) -> Result<()> {
5105 let database = &session.db().await;
5106 let notifications = database
5107 .mark_notification_as_read_by_id(
5108 session.user_id(),
5109 NotificationId::from_proto(request.notification_id),
5110 )
5111 .await?;
5112 send_notifications(
5113 &*session.connection_pool().await,
5114 &session.peer,
5115 notifications,
5116 );
5117 response.send(proto::Ack {})?;
5118 Ok(())
5119}
5120
5121/// Get the current users information
5122async fn get_private_user_info(
5123 _request: proto::GetPrivateUserInfo,
5124 response: Response<proto::GetPrivateUserInfo>,
5125 session: UserSession,
5126) -> Result<()> {
5127 let db = session.db().await;
5128
5129 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
5130 let user = db
5131 .get_user_by_id(session.user_id())
5132 .await?
5133 .ok_or_else(|| anyhow!("user not found"))?;
5134 let flags = db.get_user_flags(session.user_id()).await?;
5135
5136 response.send(proto::GetPrivateUserInfoResponse {
5137 metrics_id,
5138 staff: user.admin,
5139 flags,
5140 })?;
5141 Ok(())
5142}
5143
5144fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
5145 let message = match message {
5146 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
5147 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
5148 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
5149 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
5150 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
5151 code: frame.code.into(),
5152 reason: frame.reason,
5153 })),
5154 // We should never receive a frame while reading the message, according
5155 // to the `tungstenite` maintainers:
5156 //
5157 // > It cannot occur when you read messages from the WebSocket, but it
5158 // > can be used when you want to send the raw frames (e.g. you want to
5159 // > send the frames to the WebSocket without composing the full message first).
5160 // >
5161 // > — https://github.com/snapview/tungstenite-rs/issues/268
5162 TungsteniteMessage::Frame(_) => {
5163 bail!("received an unexpected frame while reading the message")
5164 }
5165 };
5166
5167 Ok(message)
5168}
5169
5170fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
5171 match message {
5172 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
5173 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
5174 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
5175 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
5176 AxumMessage::Close(frame) => {
5177 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
5178 code: frame.code.into(),
5179 reason: frame.reason,
5180 }))
5181 }
5182 }
5183}
5184
5185fn notify_membership_updated(
5186 connection_pool: &mut ConnectionPool,
5187 result: MembershipUpdated,
5188 user_id: UserId,
5189 peer: &Peer,
5190) {
5191 for membership in &result.new_channels.channel_memberships {
5192 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
5193 }
5194 for channel_id in &result.removed_channels {
5195 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
5196 }
5197
5198 let user_channels_update = proto::UpdateUserChannels {
5199 channel_memberships: result
5200 .new_channels
5201 .channel_memberships
5202 .iter()
5203 .map(|cm| proto::ChannelMembership {
5204 channel_id: cm.channel_id.to_proto(),
5205 role: cm.role.into(),
5206 })
5207 .collect(),
5208 ..Default::default()
5209 };
5210
5211 let mut update = build_channels_update(result.new_channels);
5212 update.delete_channels = result
5213 .removed_channels
5214 .into_iter()
5215 .map(|id| id.to_proto())
5216 .collect();
5217 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
5218
5219 for connection_id in connection_pool.user_connection_ids(user_id) {
5220 peer.send(connection_id, user_channels_update.clone())
5221 .trace_err();
5222 peer.send(connection_id, update.clone()).trace_err();
5223 }
5224}
5225
5226fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
5227 proto::UpdateUserChannels {
5228 channel_memberships: channels
5229 .channel_memberships
5230 .iter()
5231 .map(|m| proto::ChannelMembership {
5232 channel_id: m.channel_id.to_proto(),
5233 role: m.role.into(),
5234 })
5235 .collect(),
5236 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
5237 observed_channel_message_id: channels.observed_channel_messages.clone(),
5238 }
5239}
5240
5241fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
5242 let mut update = proto::UpdateChannels::default();
5243
5244 for channel in channels.channels {
5245 update.channels.push(channel.to_proto());
5246 }
5247
5248 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
5249 update.latest_channel_message_ids = channels.latest_channel_messages;
5250
5251 for (channel_id, participants) in channels.channel_participants {
5252 update
5253 .channel_participants
5254 .push(proto::ChannelParticipants {
5255 channel_id: channel_id.to_proto(),
5256 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
5257 });
5258 }
5259
5260 for channel in channels.invited_channels {
5261 update.channel_invitations.push(channel.to_proto());
5262 }
5263
5264 update.hosted_projects = channels.hosted_projects;
5265 update
5266}
5267
5268fn build_initial_contacts_update(
5269 contacts: Vec<db::Contact>,
5270 pool: &ConnectionPool,
5271) -> proto::UpdateContacts {
5272 let mut update = proto::UpdateContacts::default();
5273
5274 for contact in contacts {
5275 match contact {
5276 db::Contact::Accepted { user_id, busy } => {
5277 update.contacts.push(contact_for_user(user_id, busy, &pool));
5278 }
5279 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
5280 db::Contact::Incoming { user_id } => {
5281 update
5282 .incoming_requests
5283 .push(proto::IncomingContactRequest {
5284 requester_id: user_id.to_proto(),
5285 })
5286 }
5287 }
5288 }
5289
5290 update
5291}
5292
5293fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
5294 proto::Contact {
5295 user_id: user_id.to_proto(),
5296 online: pool.is_user_online(user_id),
5297 busy,
5298 }
5299}
5300
5301fn room_updated(room: &proto::Room, peer: &Peer) {
5302 broadcast(
5303 None,
5304 room.participants
5305 .iter()
5306 .filter_map(|participant| Some(participant.peer_id?.into())),
5307 |peer_id| {
5308 peer.send(
5309 peer_id,
5310 proto::RoomUpdated {
5311 room: Some(room.clone()),
5312 },
5313 )
5314 },
5315 );
5316}
5317
5318fn channel_updated(
5319 channel: &db::channel::Model,
5320 room: &proto::Room,
5321 peer: &Peer,
5322 pool: &ConnectionPool,
5323) {
5324 let participants = room
5325 .participants
5326 .iter()
5327 .map(|p| p.user_id)
5328 .collect::<Vec<_>>();
5329
5330 broadcast(
5331 None,
5332 pool.channel_connection_ids(channel.root_id())
5333 .filter_map(|(channel_id, role)| {
5334 role.can_see_channel(channel.visibility).then(|| channel_id)
5335 }),
5336 |peer_id| {
5337 peer.send(
5338 peer_id,
5339 proto::UpdateChannels {
5340 channel_participants: vec![proto::ChannelParticipants {
5341 channel_id: channel.id.to_proto(),
5342 participant_user_ids: participants.clone(),
5343 }],
5344 ..Default::default()
5345 },
5346 )
5347 },
5348 );
5349}
5350
5351async fn send_dev_server_projects_update(
5352 user_id: UserId,
5353 mut status: proto::DevServerProjectsUpdate,
5354 session: &Session,
5355) {
5356 let pool = session.connection_pool().await;
5357 for dev_server in &mut status.dev_servers {
5358 dev_server.status =
5359 pool.dev_server_status(DevServerId(dev_server.dev_server_id as i32)) as i32;
5360 }
5361 let connections = pool.user_connection_ids(user_id);
5362 for connection_id in connections {
5363 session.peer.send(connection_id, status.clone()).trace_err();
5364 }
5365}
5366
5367async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
5368 let db = session.db().await;
5369
5370 let contacts = db.get_contacts(user_id).await?;
5371 let busy = db.is_user_busy(user_id).await?;
5372
5373 let pool = session.connection_pool().await;
5374 let updated_contact = contact_for_user(user_id, busy, &pool);
5375 for contact in contacts {
5376 if let db::Contact::Accepted {
5377 user_id: contact_user_id,
5378 ..
5379 } = contact
5380 {
5381 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
5382 session
5383 .peer
5384 .send(
5385 contact_conn_id,
5386 proto::UpdateContacts {
5387 contacts: vec![updated_contact.clone()],
5388 remove_contacts: Default::default(),
5389 incoming_requests: Default::default(),
5390 remove_incoming_requests: Default::default(),
5391 outgoing_requests: Default::default(),
5392 remove_outgoing_requests: Default::default(),
5393 },
5394 )
5395 .trace_err();
5396 }
5397 }
5398 }
5399 Ok(())
5400}
5401
5402async fn lost_dev_server_connection(session: &DevServerSession) -> Result<()> {
5403 log::info!("lost dev server connection, unsharing projects");
5404 let project_ids = session
5405 .db()
5406 .await
5407 .get_stale_dev_server_projects(session.connection_id)
5408 .await?;
5409
5410 for project_id in project_ids {
5411 // not unshare re-checks the connection ids match, so we get away with no transaction
5412 unshare_project_internal(project_id, session.connection_id, None, &session).await?;
5413 }
5414
5415 let user_id = session.dev_server().user_id;
5416 let update = session
5417 .db()
5418 .await
5419 .dev_server_projects_update(user_id)
5420 .await?;
5421
5422 send_dev_server_projects_update(user_id, update, session).await;
5423
5424 Ok(())
5425}
5426
/// Remove `connection_id` from its room and notify everyone affected.
///
/// If `leave_room` returns `None` (the connection was not in a room), this
/// returns `Ok(())` without doing anything else. Otherwise it:
/// - calls `project_left` for every project the connection left,
/// - broadcasts the updated room to its participants and, if the room belongs
///   to a channel, the new participant list to that channel's members,
/// - sends `CallCanceled` to every connection of users whose calls were canceled,
/// - refreshes the contact status of the leaving user and of those users,
/// - removes the user from the LiveKit room, deleting the LiveKit room too
///   when the database reports the room as deleted.
///
/// # Errors
/// Propagates errors from `leave_room` and `update_user_contacts`. LiveKit
/// failures are passed to `trace_err` instead of being returned.
async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Deferred initialization: these are moved out of `left_room` below and are
    // only used on the path where the connection actually left a room.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel;

    // NOTE(review): the `session.db()` guard created in this `if let` scrutinee
    // is a temporary that (in edition 2021) lives until the end of the whole
    // `if let` statement, so `project_left`/`room_updated` run while it is held.
    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scope the connection-pool guard so it is released before
    // `update_user_contacts`, which calls `session.connection_pool()` itself.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    // LiveKit cleanup is best-effort; failures are only passed to `trace_err`.
    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
5500
5501async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
5502 let left_channel_buffers = session
5503 .db()
5504 .await
5505 .leave_channel_buffers(session.connection_id)
5506 .await?;
5507
5508 for left_buffer in left_channel_buffers {
5509 channel_buffer_updated(
5510 session.connection_id,
5511 left_buffer.connections,
5512 &proto::UpdateChannelBufferCollaborators {
5513 channel_id: left_buffer.channel_id.to_proto(),
5514 collaborators: left_buffer.collaborators,
5515 },
5516 &session.peer,
5517 );
5518 }
5519
5520 Ok(())
5521}
5522
5523fn project_left(project: &db::LeftProject, session: &UserSession) {
5524 for connection_id in &project.connection_ids {
5525 if project.should_unshare {
5526 session
5527 .peer
5528 .send(
5529 *connection_id,
5530 proto::UnshareProject {
5531 project_id: project.id.to_proto(),
5532 },
5533 )
5534 .trace_err();
5535 } else {
5536 session
5537 .peer
5538 .send(
5539 *connection_id,
5540 proto::RemoveProjectCollaborator {
5541 project_id: project.id.to_proto(),
5542 peer_id: Some(session.connection_id.into()),
5543 },
5544 )
5545 .trace_err();
5546 }
5547 }
5548}
5549
/// Extension trait for consuming a fallible value as an `Option`.
///
/// The implementation for `Result` in this module logs the error with
/// `tracing::error!` instead of propagating it. This is meant for best-effort
/// operations whose failure should not abort the caller.
pub trait ResultExt {
    /// The success type produced when there is no error.
    type Ok;

    /// Returns the success value as `Some`, or records the error and returns `None`.
    fn trace_err(self) -> Option<Self::Ok>;
}
5555
5556impl<T, E> ResultExt for Result<T, E>
5557where
5558 E: std::fmt::Debug,
5559{
5560 type Ok = T;
5561
5562 #[track_caller]
5563 fn trace_err(self) -> Option<T> {
5564 match self {
5565 Ok(value) => Some(value),
5566 Err(error) => {
5567 tracing::error!("{:?}", error);
5568 None
5569 }
5570 }
5571 }
5572}