1mod connection_pool;
2
3use crate::{
4 auth,
5 db::{
6 self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
7 CreatedChannelMessage, Database, DevServerId, DevServerProjectId, InviteMemberResult,
8 MembershipUpdated, MessageId, NotificationId, PrincipalId, Project, ProjectId,
9 RejoinedProject, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
10 ServerId, UpdatedChannelMessage, User, UserId,
11 },
12 executor::Executor,
13 AppState, Config, Error, RateLimit, RateLimiter, Result,
14};
15use anyhow::{anyhow, bail, Context as _};
16use async_tungstenite::tungstenite::{
17 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
18};
19use axum::{
20 body::Body,
21 extract::{
22 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
23 ConnectInfo, WebSocketUpgrade,
24 },
25 headers::{Header, HeaderName},
26 http::StatusCode,
27 middleware,
28 response::IntoResponse,
29 routing::get,
30 Extension, Router, TypedHeader,
31};
32use collections::{HashMap, HashSet};
33pub use connection_pool::{ConnectionPool, ZedVersion};
34use core::fmt::{self, Debug, Formatter};
35use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
36use sha2::Digest;
37use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
38
39use futures::{
40 channel::oneshot,
41 future::{self, BoxFuture},
42 stream::FuturesUnordered,
43 FutureExt, SinkExt, StreamExt, TryStreamExt,
44};
45use http_client::IsahcHttpClient;
46use prometheus::{register_int_gauge, IntGauge};
47use rpc::{
48 proto::{
49 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
50 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
51 },
52 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
53};
54use semantic_version::SemanticVersion;
55use serde::{Serialize, Serializer};
56use std::{
57 any::TypeId,
58 future::Future,
59 marker::PhantomData,
60 mem,
61 net::SocketAddr,
62 ops::{Deref, DerefMut},
63 rc::Rc,
64 sync::{
65 atomic::{AtomicBool, Ordering::SeqCst},
66 Arc, OnceLock,
67 },
68 time::{Duration, Instant},
69};
70use time::OffsetDateTime;
71use tokio::sync::{watch, Semaphore};
72use tower::ServiceBuilder;
73use tracing::{
74 field::{self},
75 info_span, instrument, Instrument,
76};
77
78use self::connection_pool::VersionedMessage;
79
/// Reconnection window of 30 seconds.
///
/// NOTE(review): presumably how long a disconnected client may take to
/// reconnect before its state is released — confirm at the use sites.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// Delay before `Server::start` cleans up resources reported as stale.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully; waiting 15s
/// means those pods are gone before their old resources are cleaned up.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

/// Page size used for channel message history (presumably; confirm at use sites).
const MESSAGE_COUNT_PER_PAGE: usize = 100;

/// Maximum allowed message length (unit — bytes vs. chars — not visible here).
const MAX_MESSAGE_LEN: usize = 1024;

/// Page size used for notification queries (presumably; confirm at use sites).
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
88
89type MessageHandler =
90 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
91
92struct Response<R> {
93 peer: Arc<Peer>,
94 receipt: Receipt<R>,
95 responded: Arc<AtomicBool>,
96}
97
98impl<R: RequestMessage> Response<R> {
99 fn send(self, payload: R::Response) -> Result<()> {
100 self.responded.store(true, SeqCst);
101 self.peer.respond(self.receipt, payload)?;
102 Ok(())
103 }
104}
105
106struct StreamingResponse<R: RequestMessage> {
107 peer: Arc<Peer>,
108 receipt: Receipt<R>,
109}
110
111impl<R: RequestMessage> StreamingResponse<R> {
112 fn send(&self, payload: R::Response) -> Result<()> {
113 self.peer.respond(self.receipt, payload)?;
114 Ok(())
115 }
116}
117
118#[derive(Clone, Debug)]
119pub enum Principal {
120 User(User),
121 Impersonated { user: User, admin: User },
122 DevServer(dev_server::Model),
123}
124
125impl Principal {
126 fn update_span(&self, span: &tracing::Span) {
127 match &self {
128 Principal::User(user) => {
129 span.record("user_id", &user.id.0);
130 span.record("login", &user.github_login);
131 }
132 Principal::Impersonated { user, admin } => {
133 span.record("user_id", &user.id.0);
134 span.record("login", &user.github_login);
135 span.record("impersonator", &admin.github_login);
136 }
137 Principal::DevServer(dev_server) => {
138 span.record("dev_server_id", &dev_server.id.0);
139 }
140 }
141 }
142}
143
/// State for one connection, handed (by value, as a clone) to every message
/// handler invoked for that connection.
#[derive(Clone)]
struct Session {
    /// The identity on the other end of this connection.
    principal: Principal,
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; access it through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// The server-wide connection pool; access it through
    /// `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
    /// `Some` only when a Supermaven admin API key is configured.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    http_client: Arc<IsahcHttpClient>,
    rate_limiter: Arc<RateLimiter>,
    _executor: Executor,
}
157
158impl Session {
159 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
160 #[cfg(test)]
161 tokio::task::yield_now().await;
162 let guard = self.db.lock().await;
163 #[cfg(test)]
164 tokio::task::yield_now().await;
165 guard
166 }
167
168 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
169 #[cfg(test)]
170 tokio::task::yield_now().await;
171 let guard = self.connection_pool.lock();
172 ConnectionPoolGuard {
173 guard,
174 _not_send: PhantomData,
175 }
176 }
177
178 fn for_user(self) -> Option<UserSession> {
179 UserSession::new(self)
180 }
181
182 fn for_dev_server(self) -> Option<DevServerSession> {
183 DevServerSession::new(self)
184 }
185
186 fn user_id(&self) -> Option<UserId> {
187 match &self.principal {
188 Principal::User(user) => Some(user.id),
189 Principal::Impersonated { user, .. } => Some(user.id),
190 Principal::DevServer(_) => None,
191 }
192 }
193
194 fn is_staff(&self) -> bool {
195 match &self.principal {
196 Principal::User(user) => user.admin,
197 Principal::Impersonated { .. } => true,
198 Principal::DevServer(_) => false,
199 }
200 }
201
202 pub async fn current_plan(&self) -> anyhow::Result<proto::Plan> {
203 if self.is_staff() {
204 return Ok(proto::Plan::ZedPro);
205 }
206
207 let Some(user_id) = self.user_id() else {
208 return Ok(proto::Plan::Free);
209 };
210
211 let db = self.db().await;
212 if db.has_active_billing_subscription(user_id).await? {
213 Ok(proto::Plan::ZedPro)
214 } else {
215 Ok(proto::Plan::Free)
216 }
217 }
218
219 fn dev_server_id(&self) -> Option<DevServerId> {
220 match &self.principal {
221 Principal::User(_) | Principal::Impersonated { .. } => None,
222 Principal::DevServer(dev_server) => Some(dev_server.id),
223 }
224 }
225
226 fn principal_id(&self) -> PrincipalId {
227 match &self.principal {
228 Principal::User(user) => PrincipalId::UserId(user.id),
229 Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
230 Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
231 }
232 }
233}
234
235impl Debug for Session {
236 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
237 let mut result = f.debug_struct("Session");
238 match &self.principal {
239 Principal::User(user) => {
240 result.field("user", &user.github_login);
241 }
242 Principal::Impersonated { user, admin } => {
243 result.field("user", &user.github_login);
244 result.field("impersonator", &admin.github_login);
245 }
246 Principal::DevServer(dev_server) => {
247 result.field("dev_server", &dev_server.id);
248 }
249 }
250 result.field("connection_id", &self.connection_id).finish()
251 }
252}
253
254struct UserSession(Session);
255
256impl UserSession {
257 pub fn new(s: Session) -> Option<Self> {
258 s.user_id().map(|_| UserSession(s))
259 }
260 pub fn user_id(&self) -> UserId {
261 self.0.user_id().unwrap()
262 }
263
264 pub fn email(&self) -> Option<String> {
265 match &self.0.principal {
266 Principal::User(user) => user.email_address.clone(),
267 Principal::Impersonated { user, .. } => user.email_address.clone(),
268 Principal::DevServer(..) => None,
269 }
270 }
271}
272
273impl Deref for UserSession {
274 type Target = Session;
275
276 fn deref(&self) -> &Self::Target {
277 &self.0
278 }
279}
280impl DerefMut for UserSession {
281 fn deref_mut(&mut self) -> &mut Self::Target {
282 &mut self.0
283 }
284}
285
286struct DevServerSession(Session);
287
288impl DevServerSession {
289 pub fn new(s: Session) -> Option<Self> {
290 s.dev_server_id().map(|_| DevServerSession(s))
291 }
292 pub fn dev_server_id(&self) -> DevServerId {
293 self.0.dev_server_id().unwrap()
294 }
295
296 fn dev_server(&self) -> &dev_server::Model {
297 match &self.0.principal {
298 Principal::DevServer(dev_server) => dev_server,
299 _ => unreachable!(),
300 }
301 }
302}
303
304impl Deref for DevServerSession {
305 type Target = Session;
306
307 fn deref(&self) -> &Self::Target {
308 &self.0
309 }
310}
311impl DerefMut for DevServerSession {
312 fn deref_mut(&mut self) -> &mut Self::Target {
313 &mut self.0
314 }
315}
316
317fn user_handler<M: RequestMessage, Fut>(
318 handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
319) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
320where
321 Fut: Send + Future<Output = Result<()>>,
322{
323 let handler = Arc::new(handler);
324 move |message, response, session| {
325 let handler = handler.clone();
326 Box::pin(async move {
327 if let Some(user_session) = session.for_user() {
328 Ok(handler(message, response, user_session).await?)
329 } else {
330 Err(Error::Internal(anyhow!(
331 "must be a user to call {}",
332 M::NAME
333 )))
334 }
335 })
336 }
337}
338
339fn dev_server_handler<M: RequestMessage, Fut>(
340 handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
341) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
342where
343 Fut: Send + Future<Output = Result<()>>,
344{
345 let handler = Arc::new(handler);
346 move |message, response, session| {
347 let handler = handler.clone();
348 Box::pin(async move {
349 if let Some(dev_server_session) = session.for_dev_server() {
350 Ok(handler(message, response, dev_server_session).await?)
351 } else {
352 Err(Error::Internal(anyhow!(
353 "must be a dev server to call {}",
354 M::NAME
355 )))
356 }
357 })
358 }
359}
360
361fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
362 handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
363) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
364where
365 InnertRetFut: Send + Future<Output = Result<()>>,
366{
367 let handler = Arc::new(handler);
368 move |message, session| {
369 let handler = handler.clone();
370 Box::pin(async move {
371 if let Some(user_session) = session.for_user() {
372 Ok(handler(message, user_session).await?)
373 } else {
374 Err(Error::Internal(anyhow!(
375 "must be a user to call {}",
376 M::NAME
377 )))
378 }
379 })
380 }
381}
382
383struct DbHandle(Arc<Database>);
384
385impl Deref for DbHandle {
386 type Target = Database;
387
388 fn deref(&self) -> &Self::Target {
389 self.0.as_ref()
390 }
391}
392
/// The RPC server: owns the [`Peer`], the shared [`ConnectionPool`], and the
/// table of message handlers.
pub struct Server {
    /// Behind a mutex so it can be replaced (see the test-only `reset`).
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Message handlers keyed by the `TypeId` of the message type they handle.
    handlers: HashMap<TypeId, MessageHandler>,
    /// `teardown` sends `true` on this channel; `handle_connection` checks the
    /// current value and refuses the connection when it is `true`.
    teardown: watch::Sender<bool>,
}
401
/// Lock guard over the shared [`ConnectionPool`].
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send`, so a future that
/// holds it across an `.await` cannot be sent between threads.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
406
/// Serializable view of the server: its [`Peer`] plus the locked
/// [`ConnectionPool`].
///
/// The snapshot holds the pool's lock guard, so the pool stays locked for as
/// long as the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized through the guard's `Deref` target rather than the guard itself.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
413
414pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
415where
416 S: Serializer,
417 T: Deref<Target = U>,
418 U: Serialize,
419{
420 Serialize::serialize(value.deref(), serializer)
421}
422
423impl Server {
    /// Creates a server with the given `id` and registers every RPC handler.
    ///
    /// Handlers wrapped in `user_handler`, `dev_server_handler`, or
    /// `user_message_handler` reject sessions whose principal is of the wrong
    /// kind; the unwrapped handlers receive the raw [`Session`].
    ///
    /// # Panics
    ///
    /// Panics (via `add_handler`) if two handlers are registered for the same
    /// message type.
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            // Only the sender is kept; receivers are created later via `subscribe`.
            teardown: watch::channel(false).0,
        };

        server
            .add_request_handler(ping)
            // Rooms and calls.
            .add_request_handler(user_handler(create_room))
            .add_request_handler(user_handler(join_room))
            .add_request_handler(user_handler(rejoin_room))
            .add_request_handler(user_handler(leave_room))
            .add_request_handler(user_handler(set_room_participant_role))
            .add_request_handler(user_handler(call))
            .add_request_handler(user_handler(cancel_call))
            .add_message_handler(user_message_handler(decline_call))
            .add_request_handler(user_handler(update_participant_location))
            // Projects and dev servers.
            .add_request_handler(user_handler(share_project))
            .add_message_handler(unshare_project)
            .add_request_handler(user_handler(join_project))
            .add_request_handler(user_handler(join_hosted_project))
            .add_request_handler(user_handler(rejoin_dev_server_projects))
            .add_request_handler(user_handler(create_dev_server_project))
            .add_request_handler(user_handler(update_dev_server_project))
            .add_request_handler(user_handler(delete_dev_server_project))
            .add_request_handler(user_handler(create_dev_server))
            .add_request_handler(user_handler(regenerate_dev_server_token))
            .add_request_handler(user_handler(rename_dev_server))
            .add_request_handler(user_handler(delete_dev_server))
            .add_request_handler(user_handler(list_remote_directory))
            .add_request_handler(dev_server_handler(share_dev_server_project))
            .add_request_handler(dev_server_handler(shutdown_dev_server))
            .add_request_handler(dev_server_handler(reconnect_dev_server))
            .add_message_handler(user_message_handler(leave_project))
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Project requests routed through the `forward_*_project_request` helpers.
            .add_request_handler(user_handler(
                forward_project_request_for_owner::<proto::TaskContextForLocation>,
            ))
            .add_request_handler(user_handler(
                forward_project_request_for_owner::<proto::TaskTemplates>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetHover>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetTypeDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetReferences>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::SearchProject>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDocumentHighlights>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetProjectSymbols>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferForSymbol>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferById>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::SynchronizeBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::InlayHints>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferByPath>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCompletions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            ))
            .add_request_handler(user_handler(
                forward_versioned_mutating_project_request::<proto::OpenNewBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCodeActions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCodeAction>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PrepareRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PerformRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ReloadBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::FormatBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CreateProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RenameProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CopyProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::DeleteProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ExpandProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OnTypeFormatting>,
            ))
            .add_request_handler(user_handler(
                forward_versioned_mutating_project_request::<proto::SaveBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::BlameBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::MultiLspQuery>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RestartLanguageServers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::LinkedEditingRange>,
            ))
            // Buffers.
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            // Users and contacts.
            .add_request_handler(get_users)
            .add_request_handler(user_handler(fuzzy_search_users))
            .add_request_handler(user_handler(request_contact))
            .add_request_handler(user_handler(remove_contact))
            .add_request_handler(user_handler(respond_to_contact_request))
            // Channels, channel buffers, and channel chat.
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(user_handler(create_channel))
            .add_request_handler(user_handler(delete_channel))
            .add_request_handler(user_handler(invite_channel_member))
            .add_request_handler(user_handler(remove_channel_member))
            .add_request_handler(user_handler(set_channel_member_role))
            .add_request_handler(user_handler(set_channel_visibility))
            .add_request_handler(user_handler(rename_channel))
            .add_request_handler(user_handler(join_channel_buffer))
            .add_request_handler(user_handler(leave_channel_buffer))
            .add_message_handler(user_message_handler(update_channel_buffer))
            .add_request_handler(user_handler(rejoin_channel_buffers))
            .add_request_handler(user_handler(get_channel_members))
            .add_request_handler(user_handler(respond_to_channel_invite))
            .add_request_handler(user_handler(join_channel))
            .add_request_handler(user_handler(join_channel_chat))
            .add_message_handler(user_message_handler(leave_channel_chat))
            .add_request_handler(user_handler(send_channel_message))
            .add_request_handler(user_handler(remove_channel_message))
            .add_request_handler(user_handler(update_channel_message))
            .add_request_handler(user_handler(get_channel_messages))
            .add_request_handler(user_handler(get_channel_messages_by_id))
            .add_request_handler(user_handler(get_notifications))
            .add_request_handler(user_handler(mark_notification_as_read))
            .add_request_handler(user_handler(move_channel))
            // Following.
            .add_request_handler(user_handler(follow))
            .add_message_handler(user_message_handler(unfollow))
            .add_message_handler(user_message_handler(update_followers))
            .add_request_handler(user_handler(get_private_user_info))
            .add_message_handler(user_message_handler(acknowledge_channel_message))
            .add_message_handler(user_message_handler(acknowledge_buffer_version))
            .add_request_handler(user_handler(get_supermaven_api_key))
            // Contexts.
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OpenContext>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CreateContext>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::SynchronizeContexts>,
            ))
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context)
            // Language-model and embedding handlers. Those that need the app
            // config capture their own clone of `app_state`.
            .add_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    let app_state = app_state.clone();
                    async move {
                        complete_with_language_model(request, response, session, &app_state.config)
                            .await
                    }
                }
            })
            .add_streaming_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    let app_state = app_state.clone();
                    async move {
                        stream_complete_with_language_model(
                            request,
                            response,
                            session,
                            &app_state.config,
                        )
                        .await
                    }
                }
            })
            .add_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    let app_state = app_state.clone();
                    async move {
                        count_language_model_tokens(request, response, session, &app_state.config)
                            .await
                    }
                }
            })
            .add_request_handler({
                user_handler(move |request, response, session| {
                    get_cached_embeddings(request, response, session)
                })
            })
            .add_request_handler({
                let app_state = app_state.clone();
                user_handler(move |request, response, session| {
                    compute_embeddings(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                    )
                })
            });

        Arc::new(server)
    }
682
    /// Spawns a detached background task that waits [`CLEANUP_TIMEOUT`] and then
    /// cleans up resources the database reports as stale for this server's
    /// environment (presumably left behind by servers that were replaced):
    ///
    /// 1. removes stale collaborators from channel buffers and sends the
    ///    refreshed collaborator list to the remaining connections;
    /// 2. removes stale participants from rooms, broadcasts room/channel
    ///    updates, notifies users whose calls were canceled, pushes updated
    ///    contact entries to affected users' accepted contacts, and deletes the
    ///    LiveKit room when no participants remain;
    /// 3. deletes the stale server records.
    ///
    /// Failures inside the task are logged via `trace_err` and skipped. This
    /// method itself always returns `Ok(())`.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // Created up front so the delay is measured from when `start` is called.
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Drop stale collaborators from each channel buffer and tell the
                    // remaining connections about the new collaborator list.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Defaults used when clearing the room fails: nothing to
                        // notify and no LiveKit room to delete.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scope the (synchronous) pool lock so it is released before
                        // the `.await`s below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push each affected user's refreshed contact entry (busy
                        // status included) to all of their accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Only delete the LiveKit room once nobody is left in it.
                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
828
829 pub fn teardown(&self) {
830 self.peer.teardown();
831 self.connection_pool.lock().reset();
832 let _ = self.teardown.send(true);
833 }
834
835 #[cfg(test)]
836 pub fn reset(&self, id: ServerId) {
837 self.teardown();
838 *self.id.lock() = id;
839 self.peer.reset(id.0 as u32);
840 let _ = self.teardown.send(false);
841 }
842
843 #[cfg(test)]
844 pub fn id(&self) -> ServerId {
845 *self.id.lock()
846 }
847
848 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
849 where
850 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
851 Fut: 'static + Send + Future<Output = Result<()>>,
852 M: EnvelopedMessage,
853 {
854 let prev_handler = self.handlers.insert(
855 TypeId::of::<M>(),
856 Box::new(move |envelope, session| {
857 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
858 let received_at = envelope.received_at;
859 tracing::info!("message received");
860 let start_time = Instant::now();
861 let future = (handler)(*envelope, session);
862 async move {
863 let result = future.await;
864 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
865 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
866 let queue_duration_ms = total_duration_ms - processing_duration_ms;
867 let payload_type = M::NAME;
868
869 match result {
870 Err(error) => {
871 tracing::error!(
872 ?error,
873 total_duration_ms,
874 processing_duration_ms,
875 queue_duration_ms,
876 payload_type,
877 "error handling message"
878 )
879 }
880 Ok(()) => tracing::info!(
881 total_duration_ms,
882 processing_duration_ms,
883 queue_duration_ms,
884 "finished handling message"
885 ),
886 }
887 }
888 .boxed()
889 }),
890 );
891 if prev_handler.is_some() {
892 panic!("registered a handler for the same message twice");
893 }
894 self
895 }
896
897 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
898 where
899 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
900 Fut: 'static + Send + Future<Output = Result<()>>,
901 M: EnvelopedMessage,
902 {
903 self.add_handler(move |envelope, session| handler(envelope.payload, session));
904 self
905 }
906
907 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
908 where
909 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
910 Fut: Send + Future<Output = Result<()>>,
911 M: RequestMessage,
912 {
913 let handler = Arc::new(handler);
914 self.add_handler(move |envelope, session| {
915 let receipt = envelope.receipt();
916 let handler = handler.clone();
917 async move {
918 let peer = session.peer.clone();
919 let responded = Arc::new(AtomicBool::default());
920 let response = Response {
921 peer: peer.clone(),
922 responded: responded.clone(),
923 receipt,
924 };
925 match (handler)(envelope.payload, response, session).await {
926 Ok(()) => {
927 if responded.load(std::sync::atomic::Ordering::SeqCst) {
928 Ok(())
929 } else {
930 Err(anyhow!("handler did not send a response"))?
931 }
932 }
933 Err(error) => {
934 let proto_err = match &error {
935 Error::Internal(err) => err.to_proto(),
936 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
937 };
938 peer.respond_with_error(receipt, proto_err)?;
939 Err(error)
940 }
941 }
942 }
943 })
944 }
945
946 fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
947 where
948 F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
949 Fut: Send + Future<Output = Result<()>>,
950 M: RequestMessage,
951 {
952 let handler = Arc::new(handler);
953 self.add_handler(move |envelope, session| {
954 let receipt = envelope.receipt();
955 let handler = handler.clone();
956 async move {
957 let peer = session.peer.clone();
958 let response = StreamingResponse {
959 peer: peer.clone(),
960 receipt,
961 };
962 match (handler)(envelope.payload, response, session).await {
963 Ok(()) => {
964 peer.end_stream(receipt)?;
965 Ok(())
966 }
967 Err(error) => {
968 let proto_err = match &error {
969 Error::Internal(err) => err.to_proto(),
970 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
971 };
972 peer.respond_with_error(receipt, proto_err)?;
973 Err(error)
974 }
975 }
976 }
977 })
978 }
979
    /// Drives one client connection from registration until sign-out.
    ///
    /// The returned future is instrumented with a "handle connection" span and:
    ///
    /// 1. returns immediately if the server is already tearing down;
    /// 2. registers `connection` with the peer and records the resulting
    ///    connection id on the span;
    /// 3. builds the connection's [`Session`] (HTTP client, optional Supermaven
    ///    admin client, database handle, rate limiter, ...);
    /// 4. sends the initial client update via `send_initial_client_update`
    ///    (which also forwards the connection id to `send_connection_id`, if any);
    /// 5. dispatches incoming messages to the registered handlers until the
    ///    connection closes, the I/O future completes, or teardown is signalled;
    /// 6. calls `connection_lost` to clean up after the session.
    ///
    /// Teardown during the message loop, and failures in steps 3-4, return early
    /// without reaching step 6.
    #[allow(clippy::too_many_arguments)]
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> {
        let this = self.clone();
        // The `Empty` fields are recorded later: `connection_id` once the peer
        // assigns one, the others presumably by `principal.update_span`.
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            dev_server_id=field::Empty
        );
        principal.update_span(&span);

        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));
            tracing::info!("connection opened");

            // NOTE(review): a fresh HTTP client is built for every connection.
            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match IsahcHttpClient::builder().default_header("User-Agent", user_agent).build() {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    // NOTE(review): the connection was already added to the peer
                    // above, but `connection_lost` is not run on this path.
                    return;
                }
            };

            // Only created when a Supermaven admin API key is configured.
            let supermaven_client = if let Some(supermaven_admin_api_key) = this.app_state.config.supermaven_admin_api_key.clone() {
                Some(Arc::new(SupermavenAdminApi::new(
                    supermaven_admin_api_key.to_string(),
                    http_client.clone(),
                )))
            } else {
                None
            };

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                http_client,
                rate_limiter: this.app_state.rate_limiter.clone(),
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                // NOTE(review): `connection_lost` is not run on this path either, so
                // anything registered before the failure (peer connection, and
                // possibly a connection-pool entry added by
                // `send_initial_client_update`) is not cleaned up here. Confirm
                // that this is intended.
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Every dispatched handler holds one permit until it finishes, which caps
            // in-flight handlers for this connection at 256. The permit is acquired
            // *before* reading the next message, so reading pauses at the cap.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Biased polling order: teardown first, then I/O completion, then
                // progress on in-flight foreground handlers, and only then new messages.
                futures::select_biased! {
                    // Teardown: return without running `connection_lost`.
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                dev_server_id=field::Empty
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                let handle_message = async move {
                                    handle_message.await;
                                    // Release the concurrency slot only once the handler is done.
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    // Background messages run detached from this loop.
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers still in flight.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
1127
1128 async fn send_initial_client_update(
1129 &self,
1130 connection_id: ConnectionId,
1131 principal: &Principal,
1132 zed_version: ZedVersion,
1133 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
1134 session: &Session,
1135 ) -> Result<()> {
1136 self.peer.send(
1137 connection_id,
1138 proto::Hello {
1139 peer_id: Some(connection_id.into()),
1140 },
1141 )?;
1142 tracing::info!("sent hello message");
1143 if let Some(send_connection_id) = send_connection_id.take() {
1144 let _ = send_connection_id.send(connection_id);
1145 }
1146
1147 match principal {
1148 Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
1149 if !user.connected_once {
1150 self.peer.send(connection_id, proto::ShowContacts {})?;
1151 self.app_state
1152 .db
1153 .set_user_connected_once(user.id, true)
1154 .await?;
1155 }
1156
1157 update_user_plan(user.id, session).await?;
1158
1159 let (contacts, dev_server_projects) = future::try_join(
1160 self.app_state.db.get_contacts(user.id),
1161 self.app_state.db.dev_server_projects_update(user.id),
1162 )
1163 .await?;
1164
1165 {
1166 let mut pool = self.connection_pool.lock();
1167 pool.add_connection(connection_id, user.id, user.admin, zed_version);
1168 self.peer.send(
1169 connection_id,
1170 build_initial_contacts_update(contacts, &pool),
1171 )?;
1172 }
1173
1174 if should_auto_subscribe_to_channels(zed_version) {
1175 subscribe_user_to_channels(user.id, session).await?;
1176 }
1177
1178 send_dev_server_projects_update(user.id, dev_server_projects, session).await;
1179
1180 if let Some(incoming_call) =
1181 self.app_state.db.incoming_call_for_user(user.id).await?
1182 {
1183 self.peer.send(connection_id, incoming_call)?;
1184 }
1185
1186 update_user_contacts(user.id, &session).await?;
1187 }
1188 Principal::DevServer(dev_server) => {
1189 {
1190 let mut pool = self.connection_pool.lock();
1191 if let Some(stale_connection_id) = pool.dev_server_connection_id(dev_server.id)
1192 {
1193 self.peer.send(
1194 stale_connection_id,
1195 proto::ShutdownDevServer {
1196 reason: Some(
1197 "another dev server connected with the same token".to_string(),
1198 ),
1199 },
1200 )?;
1201 pool.remove_connection(stale_connection_id)?;
1202 };
1203 pool.add_dev_server(connection_id, dev_server.id, zed_version);
1204 }
1205
1206 let projects = self
1207 .app_state
1208 .db
1209 .get_projects_for_dev_server(dev_server.id)
1210 .await?;
1211 self.peer
1212 .send(connection_id, proto::DevServerInstructions { projects })?;
1213
1214 let status = self
1215 .app_state
1216 .db
1217 .dev_server_projects_update(dev_server.user_id)
1218 .await?;
1219 send_dev_server_projects_update(dev_server.user_id, status, &session).await;
1220 }
1221 }
1222
1223 Ok(())
1224 }
1225
1226 pub async fn invite_code_redeemed(
1227 self: &Arc<Self>,
1228 inviter_id: UserId,
1229 invitee_id: UserId,
1230 ) -> Result<()> {
1231 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1232 if let Some(code) = &user.invite_code {
1233 let pool = self.connection_pool.lock();
1234 let invitee_contact = contact_for_user(invitee_id, false, &pool);
1235 for connection_id in pool.user_connection_ids(inviter_id) {
1236 self.peer.send(
1237 connection_id,
1238 proto::UpdateContacts {
1239 contacts: vec![invitee_contact.clone()],
1240 ..Default::default()
1241 },
1242 )?;
1243 self.peer.send(
1244 connection_id,
1245 proto::UpdateInviteInfo {
1246 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1247 count: user.invite_count as u32,
1248 },
1249 )?;
1250 }
1251 }
1252 }
1253 Ok(())
1254 }
1255
1256 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1257 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1258 if let Some(invite_code) = &user.invite_code {
1259 let pool = self.connection_pool.lock();
1260 for connection_id in pool.user_connection_ids(user_id) {
1261 self.peer.send(
1262 connection_id,
1263 proto::UpdateInviteInfo {
1264 url: format!(
1265 "{}{}",
1266 self.app_state.config.invite_link_prefix, invite_code
1267 ),
1268 count: user.invite_count as u32,
1269 },
1270 )?;
1271 }
1272 }
1273 }
1274 Ok(())
1275 }
1276
1277 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1278 ServerSnapshot {
1279 connection_pool: ConnectionPoolGuard {
1280 guard: self.connection_pool.lock(),
1281 _not_send: PhantomData,
1282 },
1283 peer: &self.peer,
1284 }
1285 }
1286}
1287
1288impl<'a> Deref for ConnectionPoolGuard<'a> {
1289 type Target = ConnectionPool;
1290
1291 fn deref(&self) -> &Self::Target {
1292 &self.guard
1293 }
1294}
1295
1296impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1297 fn deref_mut(&mut self) -> &mut Self::Target {
1298 &mut self.guard
1299 }
1300}
1301
// In test builds, releasing a connection-pool guard runs `check_invariants`
// on the pool. Outside tests, dropping the guard does nothing extra.
impl<'a> Drop for ConnectionPoolGuard<'a> {
    fn drop(&mut self) {
        #[cfg(test)]
        self.check_invariants();
    }
}
1308
1309fn broadcast<F>(
1310 sender_id: Option<ConnectionId>,
1311 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1312 mut f: F,
1313) where
1314 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1315{
1316 for receiver_id in receiver_ids {
1317 if Some(receiver_id) != sender_id {
1318 if let Err(error) = f(receiver_id) {
1319 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1320 }
1321 }
1322 }
1323}
1324
1325pub struct ProtocolVersion(u32);
1326
1327impl Header for ProtocolVersion {
1328 fn name() -> &'static HeaderName {
1329 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1330 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1331 }
1332
1333 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1334 where
1335 Self: Sized,
1336 I: Iterator<Item = &'i axum::http::HeaderValue>,
1337 {
1338 let version = values
1339 .next()
1340 .ok_or_else(axum::headers::Error::invalid)?
1341 .to_str()
1342 .map_err(|_| axum::headers::Error::invalid())?
1343 .parse()
1344 .map_err(|_| axum::headers::Error::invalid())?;
1345 Ok(Self(version))
1346 }
1347
1348 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1349 values.extend([self.0.to_string().parse().unwrap()]);
1350 }
1351}
1352
1353pub struct AppVersionHeader(SemanticVersion);
1354impl Header for AppVersionHeader {
1355 fn name() -> &'static HeaderName {
1356 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1357 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1358 }
1359
1360 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1361 where
1362 Self: Sized,
1363 I: Iterator<Item = &'i axum::http::HeaderValue>,
1364 {
1365 let version = values
1366 .next()
1367 .ok_or_else(axum::headers::Error::invalid)?
1368 .to_str()
1369 .map_err(|_| axum::headers::Error::invalid())?
1370 .parse()
1371 .map_err(|_| axum::headers::Error::invalid())?;
1372 Ok(Self(version))
1373 }
1374
1375 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1376 values.extend([self.0.to_string().parse().unwrap()]);
1377 }
1378}
1379
1380pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1381 Router::new()
1382 .route("/rpc", get(handle_websocket_request))
1383 .layer(
1384 ServiceBuilder::new()
1385 .layer(Extension(server.app_state.clone()))
1386 .layer(middleware::from_fn(auth::validate_header)),
1387 )
1388 .route("/metrics", get(handle_metrics))
1389 .layer(Extension(server))
1390}
1391
1392pub async fn handle_websocket_request(
1393 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1394 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1395 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1396 Extension(server): Extension<Arc<Server>>,
1397 Extension(principal): Extension<Principal>,
1398 ws: WebSocketUpgrade,
1399) -> axum::response::Response {
1400 if protocol_version != rpc::PROTOCOL_VERSION {
1401 return (
1402 StatusCode::UPGRADE_REQUIRED,
1403 "client must be upgraded".to_string(),
1404 )
1405 .into_response();
1406 }
1407
1408 let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1409 return (
1410 StatusCode::UPGRADE_REQUIRED,
1411 "no version header found".to_string(),
1412 )
1413 .into_response();
1414 };
1415
1416 if !version.can_collaborate() {
1417 return (
1418 StatusCode::UPGRADE_REQUIRED,
1419 "client must be upgraded".to_string(),
1420 )
1421 .into_response();
1422 }
1423
1424 let socket_address = socket_address.to_string();
1425 ws.on_upgrade(move |socket| {
1426 let socket = socket
1427 .map_ok(to_tungstenite_message)
1428 .err_into()
1429 .with(|message| async move { to_axum_message(message) });
1430 let connection = Connection::new(Box::pin(socket));
1431 async move {
1432 server
1433 .handle_connection(
1434 connection,
1435 socket_address,
1436 principal,
1437 version,
1438 None,
1439 Executor::Production,
1440 )
1441 .await;
1442 }
1443 })
1444}
1445
1446pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1447 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1448 let connections_metric = CONNECTIONS_METRIC
1449 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1450
1451 let connections = server
1452 .connection_pool
1453 .lock()
1454 .connections()
1455 .filter(|connection| !connection.admin)
1456 .count();
1457 connections_metric.set(connections as _);
1458
1459 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1460 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1461 register_int_gauge!(
1462 "shared_projects",
1463 "number of open projects with one or more guests"
1464 )
1465 .unwrap()
1466 });
1467
1468 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1469 shared_projects_metric.set(shared_projects as _);
1470
1471 let encoder = prometheus::TextEncoder::new();
1472 let metric_families = prometheus::gather();
1473 let encoded_metrics = encoder
1474 .encode_to_string(&metric_families)
1475 .map_err(|err| anyhow!("{}", err))?;
1476 Ok(encoded_metrics)
1477}
1478
1479#[instrument(err, skip(executor))]
1480async fn connection_lost(
1481 session: Session,
1482 mut teardown: watch::Receiver<bool>,
1483 executor: Executor,
1484) -> Result<()> {
1485 session.peer.disconnect(session.connection_id);
1486 session
1487 .connection_pool()
1488 .await
1489 .remove_connection(session.connection_id)?;
1490
1491 session
1492 .db()
1493 .await
1494 .connection_lost(session.connection_id)
1495 .await
1496 .trace_err();
1497
1498 futures::select_biased! {
1499 _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1500 match &session.principal {
1501 Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
1502 let session = session.for_user().unwrap();
1503
1504 log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1505 leave_room_for_session(&session, session.connection_id).await.trace_err();
1506 leave_channel_buffers_for_session(&session)
1507 .await
1508 .trace_err();
1509
1510 if !session
1511 .connection_pool()
1512 .await
1513 .is_user_online(session.user_id())
1514 {
1515 let db = session.db().await;
1516 if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1517 room_updated(&room, &session.peer);
1518 }
1519 }
1520
1521 update_user_contacts(session.user_id(), &session).await?;
1522 },
1523 Principal::DevServer(_) => {
1524 lost_dev_server_connection(&session.for_dev_server().unwrap()).await?;
1525 },
1526 }
1527 },
1528 _ = teardown.changed().fuse() => {}
1529 }
1530
1531 Ok(())
1532}
1533
1534/// Acknowledges a ping from a client, used to keep the connection alive.
1535async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1536 response.send(proto::Ack {})?;
1537 Ok(())
1538}
1539
1540/// Creates a new room for calling (outside of channels)
1541async fn create_room(
1542 _request: proto::CreateRoom,
1543 response: Response<proto::CreateRoom>,
1544 session: UserSession,
1545) -> Result<()> {
1546 let live_kit_room = nanoid::nanoid!(30);
1547
1548 let live_kit_connection_info = util::maybe!(async {
1549 let live_kit = session.live_kit_client.as_ref();
1550 let live_kit = live_kit?;
1551 let user_id = session.user_id().to_string();
1552
1553 let token = live_kit
1554 .room_token(&live_kit_room, &user_id.to_string())
1555 .trace_err()?;
1556
1557 Some(proto::LiveKitConnectionInfo {
1558 server_url: live_kit.url().into(),
1559 token,
1560 can_publish: true,
1561 })
1562 })
1563 .await;
1564
1565 let room = session
1566 .db()
1567 .await
1568 .create_room(session.user_id(), session.connection_id, &live_kit_room)
1569 .await?;
1570
1571 response.send(proto::CreateRoomResponse {
1572 room: Some(room.clone()),
1573 live_kit_connection_info,
1574 })?;
1575
1576 update_user_contacts(session.user_id(), &session).await?;
1577 Ok(())
1578}
1579
1580/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1581async fn join_room(
1582 request: proto::JoinRoom,
1583 response: Response<proto::JoinRoom>,
1584 session: UserSession,
1585) -> Result<()> {
1586 let room_id = RoomId::from_proto(request.id);
1587
1588 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1589
1590 if let Some(channel_id) = channel_id {
1591 return join_channel_internal(channel_id, Box::new(response), session).await;
1592 }
1593
1594 let joined_room = {
1595 let room = session
1596 .db()
1597 .await
1598 .join_room(room_id, session.user_id(), session.connection_id)
1599 .await?;
1600 room_updated(&room.room, &session.peer);
1601 room.into_inner()
1602 };
1603
1604 for connection_id in session
1605 .connection_pool()
1606 .await
1607 .user_connection_ids(session.user_id())
1608 {
1609 session
1610 .peer
1611 .send(
1612 connection_id,
1613 proto::CallCanceled {
1614 room_id: room_id.to_proto(),
1615 },
1616 )
1617 .trace_err();
1618 }
1619
1620 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1621 if let Some(token) = live_kit
1622 .room_token(
1623 &joined_room.room.live_kit_room,
1624 &session.user_id().to_string(),
1625 )
1626 .trace_err()
1627 {
1628 Some(proto::LiveKitConnectionInfo {
1629 server_url: live_kit.url().into(),
1630 token,
1631 can_publish: true,
1632 })
1633 } else {
1634 None
1635 }
1636 } else {
1637 None
1638 };
1639
1640 response.send(proto::JoinRoomResponse {
1641 room: Some(joined_room.room),
1642 channel_id: None,
1643 live_kit_connection_info,
1644 })?;
1645
1646 update_user_contacts(session.user_id(), &session).await?;
1647 Ok(())
1648}
1649
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// While the rejoined-room guard returned by the database is held:
/// * reply with the room and the reshared/rejoined projects;
/// * call `room_updated`;
/// * for each reshared project, tell its collaborators that this client's peer
///   id changed, and forward the project's worktrees to them;
/// * stream the rejoined projects back to this client
///   (`notify_rejoined_projects`).
///
/// After the guard is released, the room's channel (if any) gets a
/// `channel_updated`, and the user's contacts are updated.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: UserSession,
) -> Result<()> {
    // Filled in inside the block below so the rejoined-room guard can be
    // dropped before the channel and contacts updates.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Point collaborators at this client's new peer id (best-effort).
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Re-send the project's worktrees to everyone except this connection.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1741
1742fn notify_rejoined_projects(
1743 rejoined_projects: &mut Vec<RejoinedProject>,
1744 session: &UserSession,
1745) -> Result<()> {
1746 for project in rejoined_projects.iter() {
1747 for collaborator in &project.collaborators {
1748 session
1749 .peer
1750 .send(
1751 collaborator.connection_id,
1752 proto::UpdateProjectCollaborator {
1753 project_id: project.id.to_proto(),
1754 old_peer_id: Some(project.old_connection_id.into()),
1755 new_peer_id: Some(session.connection_id.into()),
1756 },
1757 )
1758 .trace_err();
1759 }
1760 }
1761
1762 for project in rejoined_projects {
1763 for worktree in mem::take(&mut project.worktrees) {
1764 #[cfg(any(test, feature = "test-support"))]
1765 const MAX_CHUNK_SIZE: usize = 2;
1766 #[cfg(not(any(test, feature = "test-support")))]
1767 const MAX_CHUNK_SIZE: usize = 256;
1768
1769 // Stream this worktree's entries.
1770 let message = proto::UpdateWorktree {
1771 project_id: project.id.to_proto(),
1772 worktree_id: worktree.id,
1773 abs_path: worktree.abs_path.clone(),
1774 root_name: worktree.root_name,
1775 updated_entries: worktree.updated_entries,
1776 removed_entries: worktree.removed_entries,
1777 scan_id: worktree.scan_id,
1778 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1779 updated_repositories: worktree.updated_repositories,
1780 removed_repositories: worktree.removed_repositories,
1781 };
1782 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1783 session.peer.send(session.connection_id, update.clone())?;
1784 }
1785
1786 // Stream this worktree's diagnostics.
1787 for summary in worktree.diagnostic_summaries {
1788 session.peer.send(
1789 session.connection_id,
1790 proto::UpdateDiagnosticSummary {
1791 project_id: project.id.to_proto(),
1792 worktree_id: worktree.id,
1793 summary: Some(summary),
1794 },
1795 )?;
1796 }
1797
1798 for settings_file in worktree.settings_files {
1799 session.peer.send(
1800 session.connection_id,
1801 proto::UpdateWorktreeSettings {
1802 project_id: project.id.to_proto(),
1803 worktree_id: worktree.id,
1804 path: settings_file.path,
1805 content: Some(settings_file.content),
1806 },
1807 )?;
1808 }
1809 }
1810
1811 for language_server in &project.language_servers {
1812 session.peer.send(
1813 session.connection_id,
1814 proto::UpdateLanguageServer {
1815 project_id: project.id.to_proto(),
1816 language_server_id: language_server.id,
1817 variant: Some(
1818 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1819 proto::LspDiskBasedDiagnosticsUpdated {},
1820 ),
1821 ),
1822 },
1823 )?;
1824 }
1825 }
1826 Ok(())
1827}
1828
1829/// leave room disconnects from the room.
1830async fn leave_room(
1831 _: proto::LeaveRoom,
1832 response: Response<proto::LeaveRoom>,
1833 session: UserSession,
1834) -> Result<()> {
1835 leave_room_for_session(&session, session.connection_id).await?;
1836 response.send(proto::Ack {})?;
1837 Ok(())
1838}
1839
1840/// Updates the permissions of someone else in the room.
1841async fn set_room_participant_role(
1842 request: proto::SetRoomParticipantRole,
1843 response: Response<proto::SetRoomParticipantRole>,
1844 session: UserSession,
1845) -> Result<()> {
1846 let user_id = UserId::from_proto(request.user_id);
1847 let role = ChannelRole::from(request.role());
1848
1849 let (live_kit_room, can_publish) = {
1850 let room = session
1851 .db()
1852 .await
1853 .set_room_participant_role(
1854 session.user_id(),
1855 RoomId::from_proto(request.room_id),
1856 user_id,
1857 role,
1858 )
1859 .await?;
1860
1861 let live_kit_room = room.live_kit_room.clone();
1862 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1863 room_updated(&room, &session.peer);
1864 (live_kit_room, can_publish)
1865 };
1866
1867 if let Some(live_kit) = session.live_kit_client.as_ref() {
1868 live_kit
1869 .update_participant(
1870 live_kit_room.clone(),
1871 request.user_id.to_string(),
1872 live_kit_server::proto::ParticipantPermission {
1873 can_subscribe: true,
1874 can_publish,
1875 can_publish_data: can_publish,
1876 hidden: false,
1877 recorder: false,
1878 },
1879 )
1880 .await
1881 .trace_err();
1882 }
1883
1884 response.send(proto::Ack {})?;
1885 Ok(())
1886}
1887
1888/// Call someone else into the current room
1889async fn call(
1890 request: proto::Call,
1891 response: Response<proto::Call>,
1892 session: UserSession,
1893) -> Result<()> {
1894 let room_id = RoomId::from_proto(request.room_id);
1895 let calling_user_id = session.user_id();
1896 let calling_connection_id = session.connection_id;
1897 let called_user_id = UserId::from_proto(request.called_user_id);
1898 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1899 if !session
1900 .db()
1901 .await
1902 .has_contact(calling_user_id, called_user_id)
1903 .await?
1904 {
1905 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1906 }
1907
1908 let incoming_call = {
1909 let (room, incoming_call) = &mut *session
1910 .db()
1911 .await
1912 .call(
1913 room_id,
1914 calling_user_id,
1915 calling_connection_id,
1916 called_user_id,
1917 initial_project_id,
1918 )
1919 .await?;
1920 room_updated(&room, &session.peer);
1921 mem::take(incoming_call)
1922 };
1923 update_user_contacts(called_user_id, &session).await?;
1924
1925 let mut calls = session
1926 .connection_pool()
1927 .await
1928 .user_connection_ids(called_user_id)
1929 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1930 .collect::<FuturesUnordered<_>>();
1931
1932 while let Some(call_response) = calls.next().await {
1933 match call_response.as_ref() {
1934 Ok(_) => {
1935 response.send(proto::Ack {})?;
1936 return Ok(());
1937 }
1938 Err(_) => {
1939 call_response.trace_err();
1940 }
1941 }
1942 }
1943
1944 {
1945 let room = session
1946 .db()
1947 .await
1948 .call_failed(room_id, called_user_id)
1949 .await?;
1950 room_updated(&room, &session.peer);
1951 }
1952 update_user_contacts(called_user_id, &session).await?;
1953
1954 Err(anyhow!("failed to ring user"))?
1955}
1956
1957/// Cancel an outgoing call.
1958async fn cancel_call(
1959 request: proto::CancelCall,
1960 response: Response<proto::CancelCall>,
1961 session: UserSession,
1962) -> Result<()> {
1963 let called_user_id = UserId::from_proto(request.called_user_id);
1964 let room_id = RoomId::from_proto(request.room_id);
1965 {
1966 let room = session
1967 .db()
1968 .await
1969 .cancel_call(room_id, session.connection_id, called_user_id)
1970 .await?;
1971 room_updated(&room, &session.peer);
1972 }
1973
1974 for connection_id in session
1975 .connection_pool()
1976 .await
1977 .user_connection_ids(called_user_id)
1978 {
1979 session
1980 .peer
1981 .send(
1982 connection_id,
1983 proto::CallCanceled {
1984 room_id: room_id.to_proto(),
1985 },
1986 )
1987 .trace_err();
1988 }
1989 response.send(proto::Ack {})?;
1990
1991 update_user_contacts(called_user_id, &session).await?;
1992 Ok(())
1993}
1994
1995/// Decline an incoming call.
1996async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1997 let room_id = RoomId::from_proto(message.room_id);
1998 {
1999 let room = session
2000 .db()
2001 .await
2002 .decline_call(Some(room_id), session.user_id())
2003 .await?
2004 .ok_or_else(|| anyhow!("failed to decline call"))?;
2005 room_updated(&room, &session.peer);
2006 }
2007
2008 for connection_id in session
2009 .connection_pool()
2010 .await
2011 .user_connection_ids(session.user_id())
2012 {
2013 session
2014 .peer
2015 .send(
2016 connection_id,
2017 proto::CallCanceled {
2018 room_id: room_id.to_proto(),
2019 },
2020 )
2021 .trace_err();
2022 }
2023 update_user_contacts(session.user_id(), &session).await?;
2024 Ok(())
2025}
2026
2027/// Updates other participants in the room with your current location.
2028async fn update_participant_location(
2029 request: proto::UpdateParticipantLocation,
2030 response: Response<proto::UpdateParticipantLocation>,
2031 session: UserSession,
2032) -> Result<()> {
2033 let room_id = RoomId::from_proto(request.room_id);
2034 let location = request
2035 .location
2036 .ok_or_else(|| anyhow!("invalid location"))?;
2037
2038 let db = session.db().await;
2039 let room = db
2040 .update_room_participant_location(room_id, session.connection_id, location)
2041 .await?;
2042
2043 room_updated(&room, &session.peer);
2044 response.send(proto::Ack {})?;
2045 Ok(())
2046}
2047
2048/// Share a project into the room.
2049async fn share_project(
2050 request: proto::ShareProject,
2051 response: Response<proto::ShareProject>,
2052 session: UserSession,
2053) -> Result<()> {
2054 let (project_id, room) = &*session
2055 .db()
2056 .await
2057 .share_project(
2058 RoomId::from_proto(request.room_id),
2059 session.connection_id,
2060 &request.worktrees,
2061 request
2062 .dev_server_project_id
2063 .map(|id| DevServerProjectId::from_proto(id)),
2064 )
2065 .await?;
2066 response.send(proto::ShareProjectResponse {
2067 project_id: project_id.to_proto(),
2068 })?;
2069 room_updated(&room, &session.peer);
2070
2071 Ok(())
2072}
2073
2074/// Unshare a project from the room.
2075async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
2076 let project_id = ProjectId::from_proto(message.project_id);
2077 unshare_project_internal(
2078 project_id,
2079 session.connection_id,
2080 session.user_id(),
2081 &session,
2082 )
2083 .await
2084}
2085
/// Stops sharing `project_id` on behalf of `connection_id` / `user_id`.
///
/// Steps:
/// * send `UnshareProject` to the project's guest connections, excluding
///   `connection_id` (failures are logged by `broadcast`);
/// * call `room_updated` if the database returned a room;
/// * delete the project when the database's `delete` flag is set.
async fn unshare_project_internal(
    project_id: ProjectId,
    connection_id: ConnectionId,
    user_id: Option<UserId>,
    session: &Session,
) -> Result<()> {
    // `room_guard` lives only inside this block, so it is dropped before the
    // database is locked again for `delete_project` below.
    let delete = {
        let room_guard = session
            .db()
            .await
            .unshare_project(project_id, connection_id, user_id)
            .await?;

        let (delete, room, guest_connection_ids) = &*room_guard;

        let message = proto::UnshareProject {
            project_id: project_id.to_proto(),
        };

        broadcast(
            Some(connection_id),
            guest_connection_ids.iter().copied(),
            |conn_id| session.peer.send(conn_id, message.clone()),
        );
        if let Some(room) = room {
            room_updated(room, &session.peer);
        }

        *delete
    };

    if delete {
        let db = session.db().await;
        db.delete_project(project_id).await?;
    }

    Ok(())
}
2124
2125/// DevServer makes a project available online
2126async fn share_dev_server_project(
2127 request: proto::ShareDevServerProject,
2128 response: Response<proto::ShareDevServerProject>,
2129 session: DevServerSession,
2130) -> Result<()> {
2131 let (dev_server_project, user_id, status) = session
2132 .db()
2133 .await
2134 .share_dev_server_project(
2135 DevServerProjectId::from_proto(request.dev_server_project_id),
2136 session.dev_server_id(),
2137 session.connection_id,
2138 &request.worktrees,
2139 )
2140 .await?;
2141 let Some(project_id) = dev_server_project.project_id else {
2142 return Err(anyhow!("failed to share remote project"))?;
2143 };
2144
2145 send_dev_server_projects_update(user_id, status, &session).await;
2146
2147 response.send(proto::ShareProjectResponse { project_id })?;
2148
2149 Ok(())
2150}
2151
/// Join someone else's shared project.
///
/// Registers the requesting connection as a participant of
/// `request.project_id` in the database, releases the database handle, and
/// then hands off to `join_project_internal`. That function notifies the
/// existing collaborators and streams the project state to the new guest.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: UserSession,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    // The guard returned by `join_project` stays alive for the rest of the
    // function (it is borrowed through `&mut *`), while `db` itself is
    // released below before the potentially long streaming step.
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    drop(db);
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
2170
2171trait JoinProjectInternalResponse {
2172 fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
2173}
2174impl JoinProjectInternalResponse for Response<proto::JoinProject> {
2175 fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
2176 Response::<proto::JoinProject>::send(self, result)
2177 }
2178}
2179impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
2180 fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
2181 Response::<proto::JoinHostedProject>::send(self, result)
2182 }
2183}
2184
2185fn join_project_internal(
2186 response: impl JoinProjectInternalResponse,
2187 session: UserSession,
2188 project: &mut Project,
2189 replica_id: &ReplicaId,
2190) -> Result<()> {
2191 let collaborators = project
2192 .collaborators
2193 .iter()
2194 .filter(|collaborator| collaborator.connection_id != session.connection_id)
2195 .map(|collaborator| collaborator.to_proto())
2196 .collect::<Vec<_>>();
2197 let project_id = project.id;
2198 let guest_user_id = session.user_id();
2199
2200 let worktrees = project
2201 .worktrees
2202 .iter()
2203 .map(|(id, worktree)| proto::WorktreeMetadata {
2204 id: *id,
2205 root_name: worktree.root_name.clone(),
2206 visible: worktree.visible,
2207 abs_path: worktree.abs_path.clone(),
2208 })
2209 .collect::<Vec<_>>();
2210
2211 let add_project_collaborator = proto::AddProjectCollaborator {
2212 project_id: project_id.to_proto(),
2213 collaborator: Some(proto::Collaborator {
2214 peer_id: Some(session.connection_id.into()),
2215 replica_id: replica_id.0 as u32,
2216 user_id: guest_user_id.to_proto(),
2217 }),
2218 };
2219
2220 for collaborator in &collaborators {
2221 session
2222 .peer
2223 .send(
2224 collaborator.peer_id.unwrap().into(),
2225 add_project_collaborator.clone(),
2226 )
2227 .trace_err();
2228 }
2229
2230 // First, we send the metadata associated with each worktree.
2231 response.send(proto::JoinProjectResponse {
2232 project_id: project.id.0 as u64,
2233 worktrees: worktrees.clone(),
2234 replica_id: replica_id.0 as u32,
2235 collaborators: collaborators.clone(),
2236 language_servers: project.language_servers.clone(),
2237 role: project.role.into(),
2238 dev_server_project_id: project
2239 .dev_server_project_id
2240 .map(|dev_server_project_id| dev_server_project_id.0 as u64),
2241 })?;
2242
2243 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2244 #[cfg(any(test, feature = "test-support"))]
2245 const MAX_CHUNK_SIZE: usize = 2;
2246 #[cfg(not(any(test, feature = "test-support")))]
2247 const MAX_CHUNK_SIZE: usize = 256;
2248
2249 // Stream this worktree's entries.
2250 let message = proto::UpdateWorktree {
2251 project_id: project_id.to_proto(),
2252 worktree_id,
2253 abs_path: worktree.abs_path.clone(),
2254 root_name: worktree.root_name,
2255 updated_entries: worktree.entries,
2256 removed_entries: Default::default(),
2257 scan_id: worktree.scan_id,
2258 is_last_update: worktree.scan_id == worktree.completed_scan_id,
2259 updated_repositories: worktree.repository_entries.into_values().collect(),
2260 removed_repositories: Default::default(),
2261 };
2262 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
2263 session.peer.send(session.connection_id, update.clone())?;
2264 }
2265
2266 // Stream this worktree's diagnostics.
2267 for summary in worktree.diagnostic_summaries {
2268 session.peer.send(
2269 session.connection_id,
2270 proto::UpdateDiagnosticSummary {
2271 project_id: project_id.to_proto(),
2272 worktree_id: worktree.id,
2273 summary: Some(summary),
2274 },
2275 )?;
2276 }
2277
2278 for settings_file in worktree.settings_files {
2279 session.peer.send(
2280 session.connection_id,
2281 proto::UpdateWorktreeSettings {
2282 project_id: project_id.to_proto(),
2283 worktree_id: worktree.id,
2284 path: settings_file.path,
2285 content: Some(settings_file.content),
2286 },
2287 )?;
2288 }
2289 }
2290
2291 for language_server in &project.language_servers {
2292 session.peer.send(
2293 session.connection_id,
2294 proto::UpdateLanguageServer {
2295 project_id: project_id.to_proto(),
2296 language_server_id: language_server.id,
2297 variant: Some(
2298 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2299 proto::LspDiskBasedDiagnosticsUpdated {},
2300 ),
2301 ),
2302 },
2303 )?;
2304 }
2305
2306 Ok(())
2307}
2308
2309/// Leave someone elses shared project.
2310async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2311 let sender_id = session.connection_id;
2312 let project_id = ProjectId::from_proto(request.project_id);
2313 let db = session.db().await;
2314 if db.is_hosted_project(project_id).await? {
2315 let project = db.leave_hosted_project(project_id, sender_id).await?;
2316 project_left(&project, &session);
2317 return Ok(());
2318 }
2319
2320 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2321 tracing::info!(
2322 %project_id,
2323 "leave project"
2324 );
2325
2326 project_left(&project, &session);
2327 if let Some(room) = room {
2328 room_updated(&room, &session.peer);
2329 }
2330
2331 Ok(())
2332}
2333
2334async fn join_hosted_project(
2335 request: proto::JoinHostedProject,
2336 response: Response<proto::JoinHostedProject>,
2337 session: UserSession,
2338) -> Result<()> {
2339 let (mut project, replica_id) = session
2340 .db()
2341 .await
2342 .join_hosted_project(
2343 ProjectId(request.project_id as i32),
2344 session.user_id(),
2345 session.connection_id,
2346 )
2347 .await?;
2348
2349 join_project_internal(response, session, &mut project, &replica_id)
2350}
2351
2352async fn list_remote_directory(
2353 request: proto::ListRemoteDirectory,
2354 response: Response<proto::ListRemoteDirectory>,
2355 session: UserSession,
2356) -> Result<()> {
2357 let dev_server_id = DevServerId(request.dev_server_id as i32);
2358 let dev_server_connection_id = session
2359 .connection_pool()
2360 .await
2361 .dev_server_connection_id_supporting(dev_server_id, ZedVersion::with_list_directory())?;
2362
2363 session
2364 .db()
2365 .await
2366 .get_dev_server_for_user(dev_server_id, session.user_id())
2367 .await?;
2368
2369 response.send(
2370 session
2371 .peer
2372 .forward_request(session.connection_id, dev_server_connection_id, request)
2373 .await?,
2374 )?;
2375 Ok(())
2376}
2377
2378async fn update_dev_server_project(
2379 request: proto::UpdateDevServerProject,
2380 response: Response<proto::UpdateDevServerProject>,
2381 session: UserSession,
2382) -> Result<()> {
2383 let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2384
2385 let (dev_server_project, update) = session
2386 .db()
2387 .await
2388 .update_dev_server_project(dev_server_project_id, &request.paths, session.user_id())
2389 .await?;
2390
2391 let projects = session
2392 .db()
2393 .await
2394 .get_projects_for_dev_server(dev_server_project.dev_server_id)
2395 .await?;
2396
2397 let dev_server_connection_id = session
2398 .connection_pool()
2399 .await
2400 .dev_server_connection_id_supporting(
2401 dev_server_project.dev_server_id,
2402 ZedVersion::with_list_directory(),
2403 )?;
2404
2405 session.peer.send(
2406 dev_server_connection_id,
2407 proto::DevServerInstructions { projects },
2408 )?;
2409
2410 send_dev_server_projects_update(session.user_id(), update, &session).await;
2411
2412 response.send(proto::Ack {})
2413}
2414
2415async fn create_dev_server_project(
2416 request: proto::CreateDevServerProject,
2417 response: Response<proto::CreateDevServerProject>,
2418 session: UserSession,
2419) -> Result<()> {
2420 let dev_server_id = DevServerId(request.dev_server_id as i32);
2421 let dev_server_connection_id = session
2422 .connection_pool()
2423 .await
2424 .dev_server_connection_id(dev_server_id);
2425 let Some(dev_server_connection_id) = dev_server_connection_id else {
2426 Err(ErrorCode::DevServerOffline
2427 .message("Cannot create a remote project when the dev server is offline".to_string())
2428 .anyhow())?
2429 };
2430
2431 let path = request.path.clone();
2432 //Check that the path exists on the dev server
2433 session
2434 .peer
2435 .forward_request(
2436 session.connection_id,
2437 dev_server_connection_id,
2438 proto::ValidateDevServerProjectRequest { path: path.clone() },
2439 )
2440 .await?;
2441
2442 let (dev_server_project, update) = session
2443 .db()
2444 .await
2445 .create_dev_server_project(
2446 DevServerId(request.dev_server_id as i32),
2447 &request.path,
2448 session.user_id(),
2449 )
2450 .await?;
2451
2452 let projects = session
2453 .db()
2454 .await
2455 .get_projects_for_dev_server(dev_server_project.dev_server_id)
2456 .await?;
2457
2458 session.peer.send(
2459 dev_server_connection_id,
2460 proto::DevServerInstructions { projects },
2461 )?;
2462
2463 send_dev_server_projects_update(session.user_id(), update, &session).await;
2464
2465 response.send(proto::CreateDevServerProjectResponse {
2466 dev_server_project: Some(dev_server_project.to_proto(None)),
2467 })?;
2468 Ok(())
2469}
2470
2471async fn create_dev_server(
2472 request: proto::CreateDevServer,
2473 response: Response<proto::CreateDevServer>,
2474 session: UserSession,
2475) -> Result<()> {
2476 let access_token = auth::random_token();
2477 let hashed_access_token = auth::hash_access_token(&access_token);
2478
2479 if request.name.is_empty() {
2480 return Err(proto::ErrorCode::Forbidden
2481 .message("Dev server name cannot be empty".to_string())
2482 .anyhow())?;
2483 }
2484
2485 let (dev_server, status) = session
2486 .db()
2487 .await
2488 .create_dev_server(
2489 &request.name,
2490 request.ssh_connection_string.as_deref(),
2491 &hashed_access_token,
2492 session.user_id(),
2493 )
2494 .await?;
2495
2496 send_dev_server_projects_update(session.user_id(), status, &session).await;
2497
2498 response.send(proto::CreateDevServerResponse {
2499 dev_server_id: dev_server.id.0 as u64,
2500 access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2501 name: request.name,
2502 })?;
2503 Ok(())
2504}
2505
2506async fn regenerate_dev_server_token(
2507 request: proto::RegenerateDevServerToken,
2508 response: Response<proto::RegenerateDevServerToken>,
2509 session: UserSession,
2510) -> Result<()> {
2511 let dev_server_id = DevServerId(request.dev_server_id as i32);
2512 let access_token = auth::random_token();
2513 let hashed_access_token = auth::hash_access_token(&access_token);
2514
2515 let connection_id = session
2516 .connection_pool()
2517 .await
2518 .dev_server_connection_id(dev_server_id);
2519 if let Some(connection_id) = connection_id {
2520 shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2521 session.peer.send(
2522 connection_id,
2523 proto::ShutdownDevServer {
2524 reason: Some("dev server token was regenerated".to_string()),
2525 },
2526 )?;
2527 let _ = remove_dev_server_connection(dev_server_id, &session).await;
2528 }
2529
2530 let status = session
2531 .db()
2532 .await
2533 .update_dev_server_token(dev_server_id, &hashed_access_token, session.user_id())
2534 .await?;
2535
2536 send_dev_server_projects_update(session.user_id(), status, &session).await;
2537
2538 response.send(proto::RegenerateDevServerTokenResponse {
2539 dev_server_id: dev_server_id.to_proto(),
2540 access_token: auth::generate_dev_server_token(dev_server_id.0 as usize, access_token),
2541 })?;
2542 Ok(())
2543}
2544
2545async fn rename_dev_server(
2546 request: proto::RenameDevServer,
2547 response: Response<proto::RenameDevServer>,
2548 session: UserSession,
2549) -> Result<()> {
2550 if request.name.trim().is_empty() {
2551 return Err(proto::ErrorCode::Forbidden
2552 .message("Dev server name cannot be empty".to_string())
2553 .anyhow())?;
2554 }
2555
2556 let dev_server_id = DevServerId(request.dev_server_id as i32);
2557 let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2558 if dev_server.user_id != session.user_id() {
2559 return Err(anyhow!(ErrorCode::Forbidden))?;
2560 }
2561
2562 let status = session
2563 .db()
2564 .await
2565 .rename_dev_server(
2566 dev_server_id,
2567 &request.name,
2568 request.ssh_connection_string.as_deref(),
2569 session.user_id(),
2570 )
2571 .await?;
2572
2573 send_dev_server_projects_update(session.user_id(), status, &session).await;
2574
2575 response.send(proto::Ack {})?;
2576 Ok(())
2577}
2578
2579async fn delete_dev_server(
2580 request: proto::DeleteDevServer,
2581 response: Response<proto::DeleteDevServer>,
2582 session: UserSession,
2583) -> Result<()> {
2584 let dev_server_id = DevServerId(request.dev_server_id as i32);
2585 let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2586 if dev_server.user_id != session.user_id() {
2587 return Err(anyhow!(ErrorCode::Forbidden))?;
2588 }
2589
2590 let connection_id = session
2591 .connection_pool()
2592 .await
2593 .dev_server_connection_id(dev_server_id);
2594 if let Some(connection_id) = connection_id {
2595 shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2596 session.peer.send(
2597 connection_id,
2598 proto::ShutdownDevServer {
2599 reason: Some("dev server was deleted".to_string()),
2600 },
2601 )?;
2602 let _ = remove_dev_server_connection(dev_server_id, &session).await;
2603 }
2604
2605 let status = session
2606 .db()
2607 .await
2608 .delete_dev_server(dev_server_id, session.user_id())
2609 .await?;
2610
2611 send_dev_server_projects_update(session.user_id(), status, &session).await;
2612
2613 response.send(proto::Ack {})?;
2614 Ok(())
2615}
2616
2617async fn delete_dev_server_project(
2618 request: proto::DeleteDevServerProject,
2619 response: Response<proto::DeleteDevServerProject>,
2620 session: UserSession,
2621) -> Result<()> {
2622 let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2623 let dev_server_project = session
2624 .db()
2625 .await
2626 .get_dev_server_project(dev_server_project_id)
2627 .await?;
2628
2629 let dev_server = session
2630 .db()
2631 .await
2632 .get_dev_server(dev_server_project.dev_server_id)
2633 .await?;
2634 if dev_server.user_id != session.user_id() {
2635 return Err(anyhow!(ErrorCode::Forbidden))?;
2636 }
2637
2638 let dev_server_connection_id = session
2639 .connection_pool()
2640 .await
2641 .dev_server_connection_id(dev_server.id);
2642
2643 if let Some(dev_server_connection_id) = dev_server_connection_id {
2644 let project = session
2645 .db()
2646 .await
2647 .find_dev_server_project(dev_server_project_id)
2648 .await;
2649 if let Ok(project) = project {
2650 unshare_project_internal(
2651 project.id,
2652 dev_server_connection_id,
2653 Some(session.user_id()),
2654 &session,
2655 )
2656 .await?;
2657 }
2658 }
2659
2660 let (projects, status) = session
2661 .db()
2662 .await
2663 .delete_dev_server_project(dev_server_project_id, dev_server.id, session.user_id())
2664 .await?;
2665
2666 if let Some(dev_server_connection_id) = dev_server_connection_id {
2667 session.peer.send(
2668 dev_server_connection_id,
2669 proto::DevServerInstructions { projects },
2670 )?;
2671 }
2672
2673 send_dev_server_projects_update(session.user_id(), status, &session).await;
2674
2675 response.send(proto::Ack {})?;
2676 Ok(())
2677}
2678
2679async fn rejoin_dev_server_projects(
2680 request: proto::RejoinRemoteProjects,
2681 response: Response<proto::RejoinRemoteProjects>,
2682 session: UserSession,
2683) -> Result<()> {
2684 let mut rejoined_projects = {
2685 let db = session.db().await;
2686 db.rejoin_dev_server_projects(
2687 &request.rejoined_projects,
2688 session.user_id(),
2689 session.0.connection_id,
2690 )
2691 .await?
2692 };
2693 response.send(proto::RejoinRemoteProjectsResponse {
2694 rejoined_projects: rejoined_projects
2695 .iter()
2696 .map(|project| project.to_proto())
2697 .collect(),
2698 })?;
2699 notify_rejoined_projects(&mut rejoined_projects, &session)
2700}
2701
2702async fn reconnect_dev_server(
2703 request: proto::ReconnectDevServer,
2704 response: Response<proto::ReconnectDevServer>,
2705 session: DevServerSession,
2706) -> Result<()> {
2707 let reshared_projects = {
2708 let db = session.db().await;
2709 db.reshare_dev_server_projects(
2710 &request.reshared_projects,
2711 session.dev_server_id(),
2712 session.0.connection_id,
2713 )
2714 .await?
2715 };
2716
2717 for project in &reshared_projects {
2718 for collaborator in &project.collaborators {
2719 session
2720 .peer
2721 .send(
2722 collaborator.connection_id,
2723 proto::UpdateProjectCollaborator {
2724 project_id: project.id.to_proto(),
2725 old_peer_id: Some(project.old_connection_id.into()),
2726 new_peer_id: Some(session.connection_id.into()),
2727 },
2728 )
2729 .trace_err();
2730 }
2731
2732 broadcast(
2733 Some(session.connection_id),
2734 project
2735 .collaborators
2736 .iter()
2737 .map(|collaborator| collaborator.connection_id),
2738 |connection_id| {
2739 session.peer.forward_send(
2740 session.connection_id,
2741 connection_id,
2742 proto::UpdateProject {
2743 project_id: project.id.to_proto(),
2744 worktrees: project.worktrees.clone(),
2745 },
2746 )
2747 },
2748 );
2749 }
2750
2751 response.send(proto::ReconnectDevServerResponse {
2752 reshared_projects: reshared_projects
2753 .iter()
2754 .map(|project| proto::ResharedProject {
2755 id: project.id.to_proto(),
2756 collaborators: project
2757 .collaborators
2758 .iter()
2759 .map(|collaborator| collaborator.to_proto())
2760 .collect(),
2761 })
2762 .collect(),
2763 })?;
2764
2765 Ok(())
2766}
2767
2768async fn shutdown_dev_server(
2769 _: proto::ShutdownDevServer,
2770 response: Response<proto::ShutdownDevServer>,
2771 session: DevServerSession,
2772) -> Result<()> {
2773 response.send(proto::Ack {})?;
2774 shutdown_dev_server_internal(session.dev_server_id(), session.connection_id, &session).await?;
2775 remove_dev_server_connection(session.dev_server_id(), &session).await
2776}
2777
2778async fn shutdown_dev_server_internal(
2779 dev_server_id: DevServerId,
2780 connection_id: ConnectionId,
2781 session: &Session,
2782) -> Result<()> {
2783 let (dev_server_projects, dev_server) = {
2784 let db = session.db().await;
2785 let dev_server_projects = db.get_projects_for_dev_server(dev_server_id).await?;
2786 let dev_server = db.get_dev_server(dev_server_id).await?;
2787 (dev_server_projects, dev_server)
2788 };
2789
2790 for project_id in dev_server_projects.iter().filter_map(|p| p.project_id) {
2791 unshare_project_internal(
2792 ProjectId::from_proto(project_id),
2793 connection_id,
2794 None,
2795 session,
2796 )
2797 .await?;
2798 }
2799
2800 session
2801 .connection_pool()
2802 .await
2803 .set_dev_server_offline(dev_server_id);
2804
2805 let status = session
2806 .db()
2807 .await
2808 .dev_server_projects_update(dev_server.user_id)
2809 .await?;
2810 send_dev_server_projects_update(dev_server.user_id, status, &session).await;
2811
2812 Ok(())
2813}
2814
2815async fn remove_dev_server_connection(dev_server_id: DevServerId, session: &Session) -> Result<()> {
2816 let dev_server_connection = session
2817 .connection_pool()
2818 .await
2819 .dev_server_connection_id(dev_server_id);
2820
2821 if let Some(dev_server_connection) = dev_server_connection {
2822 session
2823 .connection_pool()
2824 .await
2825 .remove_connection(dev_server_connection)?;
2826 }
2827 Ok(())
2828}
2829
2830/// Updates other participants with changes to the project
2831async fn update_project(
2832 request: proto::UpdateProject,
2833 response: Response<proto::UpdateProject>,
2834 session: Session,
2835) -> Result<()> {
2836 let project_id = ProjectId::from_proto(request.project_id);
2837 let (room, guest_connection_ids) = &*session
2838 .db()
2839 .await
2840 .update_project(project_id, session.connection_id, &request.worktrees)
2841 .await?;
2842 broadcast(
2843 Some(session.connection_id),
2844 guest_connection_ids.iter().copied(),
2845 |connection_id| {
2846 session
2847 .peer
2848 .forward_send(session.connection_id, connection_id, request.clone())
2849 },
2850 );
2851 if let Some(room) = room {
2852 room_updated(&room, &session.peer);
2853 }
2854 response.send(proto::Ack {})?;
2855
2856 Ok(())
2857}
2858
2859/// Updates other participants with changes to the worktree
2860async fn update_worktree(
2861 request: proto::UpdateWorktree,
2862 response: Response<proto::UpdateWorktree>,
2863 session: Session,
2864) -> Result<()> {
2865 let guest_connection_ids = session
2866 .db()
2867 .await
2868 .update_worktree(&request, session.connection_id)
2869 .await?;
2870
2871 broadcast(
2872 Some(session.connection_id),
2873 guest_connection_ids.iter().copied(),
2874 |connection_id| {
2875 session
2876 .peer
2877 .forward_send(session.connection_id, connection_id, request.clone())
2878 },
2879 );
2880 response.send(proto::Ack {})?;
2881 Ok(())
2882}
2883
2884/// Updates other participants with changes to the diagnostics
2885async fn update_diagnostic_summary(
2886 message: proto::UpdateDiagnosticSummary,
2887 session: Session,
2888) -> Result<()> {
2889 let guest_connection_ids = session
2890 .db()
2891 .await
2892 .update_diagnostic_summary(&message, session.connection_id)
2893 .await?;
2894
2895 broadcast(
2896 Some(session.connection_id),
2897 guest_connection_ids.iter().copied(),
2898 |connection_id| {
2899 session
2900 .peer
2901 .forward_send(session.connection_id, connection_id, message.clone())
2902 },
2903 );
2904
2905 Ok(())
2906}
2907
2908/// Updates other participants with changes to the worktree settings
2909async fn update_worktree_settings(
2910 message: proto::UpdateWorktreeSettings,
2911 session: Session,
2912) -> Result<()> {
2913 let guest_connection_ids = session
2914 .db()
2915 .await
2916 .update_worktree_settings(&message, session.connection_id)
2917 .await?;
2918
2919 broadcast(
2920 Some(session.connection_id),
2921 guest_connection_ids.iter().copied(),
2922 |connection_id| {
2923 session
2924 .peer
2925 .forward_send(session.connection_id, connection_id, message.clone())
2926 },
2927 );
2928
2929 Ok(())
2930}
2931
2932/// Notify other participants that a language server has started.
2933async fn start_language_server(
2934 request: proto::StartLanguageServer,
2935 session: Session,
2936) -> Result<()> {
2937 let guest_connection_ids = session
2938 .db()
2939 .await
2940 .start_language_server(&request, session.connection_id)
2941 .await?;
2942
2943 broadcast(
2944 Some(session.connection_id),
2945 guest_connection_ids.iter().copied(),
2946 |connection_id| {
2947 session
2948 .peer
2949 .forward_send(session.connection_id, connection_id, request.clone())
2950 },
2951 );
2952 Ok(())
2953}
2954
2955/// Notify other participants that a language server has changed.
2956async fn update_language_server(
2957 request: proto::UpdateLanguageServer,
2958 session: Session,
2959) -> Result<()> {
2960 let project_id = ProjectId::from_proto(request.project_id);
2961 let project_connection_ids = session
2962 .db()
2963 .await
2964 .project_connection_ids(project_id, session.connection_id, true)
2965 .await?;
2966 broadcast(
2967 Some(session.connection_id),
2968 project_connection_ids.iter().copied(),
2969 |connection_id| {
2970 session
2971 .peer
2972 .forward_send(session.connection_id, connection_id, request.clone())
2973 },
2974 );
2975 Ok(())
2976}
2977
2978/// forward a project request to the host. These requests should be read only
2979/// as guests are allowed to send them.
2980async fn forward_read_only_project_request<T>(
2981 request: T,
2982 response: Response<T>,
2983 session: UserSession,
2984) -> Result<()>
2985where
2986 T: EntityMessage + RequestMessage,
2987{
2988 let project_id = ProjectId::from_proto(request.remote_entity_id());
2989 let host_connection_id = session
2990 .db()
2991 .await
2992 .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2993 .await?;
2994 let payload = session
2995 .peer
2996 .forward_request(session.connection_id, host_connection_id, request)
2997 .await?;
2998 response.send(payload)?;
2999 Ok(())
3000}
3001
3002/// forward a project request to the dev server. Only allowed
3003/// if it's your dev server.
3004async fn forward_project_request_for_owner<T>(
3005 request: T,
3006 response: Response<T>,
3007 session: UserSession,
3008) -> Result<()>
3009where
3010 T: EntityMessage + RequestMessage,
3011{
3012 let project_id = ProjectId::from_proto(request.remote_entity_id());
3013
3014 let host_connection_id = session
3015 .db()
3016 .await
3017 .host_for_owner_project_request(project_id, session.connection_id, session.user_id())
3018 .await?;
3019 let payload = session
3020 .peer
3021 .forward_request(session.connection_id, host_connection_id, request)
3022 .await?;
3023 response.send(payload)?;
3024 Ok(())
3025}
3026
3027/// forward a project request to the host. These requests are disallowed
3028/// for guests.
3029async fn forward_mutating_project_request<T>(
3030 request: T,
3031 response: Response<T>,
3032 session: UserSession,
3033) -> Result<()>
3034where
3035 T: EntityMessage + RequestMessage,
3036{
3037 let project_id = ProjectId::from_proto(request.remote_entity_id());
3038
3039 let host_connection_id = session
3040 .db()
3041 .await
3042 .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
3043 .await?;
3044 let payload = session
3045 .peer
3046 .forward_request(session.connection_id, host_connection_id, request)
3047 .await?;
3048 response.send(payload)?;
3049 Ok(())
3050}
3051
3052/// forward a project request to the host. These requests are disallowed
3053/// for guests.
3054async fn forward_versioned_mutating_project_request<T>(
3055 request: T,
3056 response: Response<T>,
3057 session: UserSession,
3058) -> Result<()>
3059where
3060 T: EntityMessage + RequestMessage + VersionedMessage,
3061{
3062 let project_id = ProjectId::from_proto(request.remote_entity_id());
3063
3064 let host_connection_id = session
3065 .db()
3066 .await
3067 .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
3068 .await?;
3069 if let Some(host_version) = session
3070 .connection_pool()
3071 .await
3072 .connection(host_connection_id)
3073 .map(|c| c.zed_version)
3074 {
3075 if let Some(min_required_version) = request.required_host_version() {
3076 if min_required_version > host_version {
3077 return Err(anyhow!(ErrorCode::RemoteUpgradeRequired
3078 .with_tag("required", &min_required_version.to_string())))?;
3079 }
3080 }
3081 }
3082
3083 let payload = session
3084 .peer
3085 .forward_request(session.connection_id, host_connection_id, request)
3086 .await?;
3087 response.send(payload)?;
3088 Ok(())
3089}
3090
3091/// Notify other participants that a new buffer has been created
3092async fn create_buffer_for_peer(
3093 request: proto::CreateBufferForPeer,
3094 session: Session,
3095) -> Result<()> {
3096 session
3097 .db()
3098 .await
3099 .check_user_is_project_host(
3100 ProjectId::from_proto(request.project_id),
3101 session.connection_id,
3102 )
3103 .await?;
3104 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
3105 session
3106 .peer
3107 .forward_send(session.connection_id, peer_id.into(), request)?;
3108 Ok(())
3109}
3110
3111/// Notify other participants that a buffer has been updated. This is
3112/// allowed for guests as long as the update is limited to selections.
3113async fn update_buffer(
3114 request: proto::UpdateBuffer,
3115 response: Response<proto::UpdateBuffer>,
3116 session: Session,
3117) -> Result<()> {
3118 let project_id = ProjectId::from_proto(request.project_id);
3119 let mut capability = Capability::ReadOnly;
3120
3121 for op in request.operations.iter() {
3122 match op.variant {
3123 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
3124 Some(_) => capability = Capability::ReadWrite,
3125 }
3126 }
3127
3128 let host = {
3129 let guard = session
3130 .db()
3131 .await
3132 .connections_for_buffer_update(
3133 project_id,
3134 session.principal_id(),
3135 session.connection_id,
3136 capability,
3137 )
3138 .await?;
3139
3140 let (host, guests) = &*guard;
3141
3142 broadcast(
3143 Some(session.connection_id),
3144 guests.clone(),
3145 |connection_id| {
3146 session
3147 .peer
3148 .forward_send(session.connection_id, connection_id, request.clone())
3149 },
3150 );
3151
3152 *host
3153 };
3154
3155 if host != session.connection_id {
3156 session
3157 .peer
3158 .forward_request(session.connection_id, host, request.clone())
3159 .await?;
3160 }
3161
3162 response.send(proto::Ack {})?;
3163 Ok(())
3164}
3165
3166async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
3167 let project_id = ProjectId::from_proto(message.project_id);
3168
3169 let operation = message.operation.as_ref().context("invalid operation")?;
3170 let capability = match operation.variant.as_ref() {
3171 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
3172 if let Some(buffer_op) = buffer_op.operation.as_ref() {
3173 match buffer_op.variant {
3174 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
3175 Capability::ReadOnly
3176 }
3177 _ => Capability::ReadWrite,
3178 }
3179 } else {
3180 Capability::ReadWrite
3181 }
3182 }
3183 Some(_) => Capability::ReadWrite,
3184 None => Capability::ReadOnly,
3185 };
3186
3187 let guard = session
3188 .db()
3189 .await
3190 .connections_for_buffer_update(
3191 project_id,
3192 session.principal_id(),
3193 session.connection_id,
3194 capability,
3195 )
3196 .await?;
3197
3198 let (host, guests) = &*guard;
3199
3200 broadcast(
3201 Some(session.connection_id),
3202 guests.iter().chain([host]).copied(),
3203 |connection_id| {
3204 session
3205 .peer
3206 .forward_send(session.connection_id, connection_id, message.clone())
3207 },
3208 );
3209
3210 Ok(())
3211}
3212
3213/// Notify other participants that a project has been updated.
3214async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
3215 request: T,
3216 session: Session,
3217) -> Result<()> {
3218 let project_id = ProjectId::from_proto(request.remote_entity_id());
3219 let project_connection_ids = session
3220 .db()
3221 .await
3222 .project_connection_ids(project_id, session.connection_id, false)
3223 .await?;
3224
3225 broadcast(
3226 Some(session.connection_id),
3227 project_connection_ids.iter().copied(),
3228 |connection_id| {
3229 session
3230 .peer
3231 .forward_send(session.connection_id, connection_id, request.clone())
3232 },
3233 );
3234 Ok(())
3235}
3236
3237/// Start following another user in a call.
3238async fn follow(
3239 request: proto::Follow,
3240 response: Response<proto::Follow>,
3241 session: UserSession,
3242) -> Result<()> {
3243 let room_id = RoomId::from_proto(request.room_id);
3244 let project_id = request.project_id.map(ProjectId::from_proto);
3245 let leader_id = request
3246 .leader_id
3247 .ok_or_else(|| anyhow!("invalid leader id"))?
3248 .into();
3249 let follower_id = session.connection_id;
3250
3251 session
3252 .db()
3253 .await
3254 .check_room_participants(room_id, leader_id, session.connection_id)
3255 .await?;
3256
3257 let response_payload = session
3258 .peer
3259 .forward_request(session.connection_id, leader_id, request)
3260 .await?;
3261 response.send(response_payload)?;
3262
3263 if let Some(project_id) = project_id {
3264 let room = session
3265 .db()
3266 .await
3267 .follow(room_id, project_id, leader_id, follower_id)
3268 .await?;
3269 room_updated(&room, &session.peer);
3270 }
3271
3272 Ok(())
3273}
3274
3275/// Stop following another user in a call.
3276async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
3277 let room_id = RoomId::from_proto(request.room_id);
3278 let project_id = request.project_id.map(ProjectId::from_proto);
3279 let leader_id = request
3280 .leader_id
3281 .ok_or_else(|| anyhow!("invalid leader id"))?
3282 .into();
3283 let follower_id = session.connection_id;
3284
3285 session
3286 .db()
3287 .await
3288 .check_room_participants(room_id, leader_id, session.connection_id)
3289 .await?;
3290
3291 session
3292 .peer
3293 .forward_send(session.connection_id, leader_id, request)?;
3294
3295 if let Some(project_id) = project_id {
3296 let room = session
3297 .db()
3298 .await
3299 .unfollow(room_id, project_id, leader_id, follower_id)
3300 .await?;
3301 room_updated(&room, &session.peer);
3302 }
3303
3304 Ok(())
3305}
3306
3307/// Notify everyone following you of your current location.
3308async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
3309 let room_id = RoomId::from_proto(request.room_id);
3310 let database = session.db.lock().await;
3311
3312 let connection_ids = if let Some(project_id) = request.project_id {
3313 let project_id = ProjectId::from_proto(project_id);
3314 database
3315 .project_connection_ids(project_id, session.connection_id, true)
3316 .await?
3317 } else {
3318 database
3319 .room_connection_ids(room_id, session.connection_id)
3320 .await?
3321 };
3322
3323 // For now, don't send view update messages back to that view's current leader.
3324 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
3325 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
3326 _ => None,
3327 });
3328
3329 for connection_id in connection_ids.iter().cloned() {
3330 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
3331 session
3332 .peer
3333 .forward_send(session.connection_id, connection_id, request.clone())?;
3334 }
3335 }
3336 Ok(())
3337}
3338
3339/// Get public data about users.
3340async fn get_users(
3341 request: proto::GetUsers,
3342 response: Response<proto::GetUsers>,
3343 session: Session,
3344) -> Result<()> {
3345 let user_ids = request
3346 .user_ids
3347 .into_iter()
3348 .map(UserId::from_proto)
3349 .collect();
3350 let users = session
3351 .db()
3352 .await
3353 .get_users_by_ids(user_ids)
3354 .await?
3355 .into_iter()
3356 .map(|user| proto::User {
3357 id: user.id.to_proto(),
3358 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3359 github_login: user.github_login,
3360 })
3361 .collect();
3362 response.send(proto::UsersResponse { users })?;
3363 Ok(())
3364}
3365
3366/// Search for users (to invite) buy Github login
3367async fn fuzzy_search_users(
3368 request: proto::FuzzySearchUsers,
3369 response: Response<proto::FuzzySearchUsers>,
3370 session: UserSession,
3371) -> Result<()> {
3372 let query = request.query;
3373 let users = match query.len() {
3374 0 => vec![],
3375 1 | 2 => session
3376 .db()
3377 .await
3378 .get_user_by_github_login(&query)
3379 .await?
3380 .into_iter()
3381 .collect(),
3382 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
3383 };
3384 let users = users
3385 .into_iter()
3386 .filter(|user| user.id != session.user_id())
3387 .map(|user| proto::User {
3388 id: user.id.to_proto(),
3389 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3390 github_login: user.github_login,
3391 })
3392 .collect();
3393 response.send(proto::UsersResponse { users })?;
3394 Ok(())
3395}
3396
3397/// Send a contact request to another user.
3398async fn request_contact(
3399 request: proto::RequestContact,
3400 response: Response<proto::RequestContact>,
3401 session: UserSession,
3402) -> Result<()> {
3403 let requester_id = session.user_id();
3404 let responder_id = UserId::from_proto(request.responder_id);
3405 if requester_id == responder_id {
3406 return Err(anyhow!("cannot add yourself as a contact"))?;
3407 }
3408
3409 let notifications = session
3410 .db()
3411 .await
3412 .send_contact_request(requester_id, responder_id)
3413 .await?;
3414
3415 // Update outgoing contact requests of requester
3416 let mut update = proto::UpdateContacts::default();
3417 update.outgoing_requests.push(responder_id.to_proto());
3418 for connection_id in session
3419 .connection_pool()
3420 .await
3421 .user_connection_ids(requester_id)
3422 {
3423 session.peer.send(connection_id, update.clone())?;
3424 }
3425
3426 // Update incoming contact requests of responder
3427 let mut update = proto::UpdateContacts::default();
3428 update
3429 .incoming_requests
3430 .push(proto::IncomingContactRequest {
3431 requester_id: requester_id.to_proto(),
3432 });
3433 let connection_pool = session.connection_pool().await;
3434 for connection_id in connection_pool.user_connection_ids(responder_id) {
3435 session.peer.send(connection_id, update.clone())?;
3436 }
3437
3438 send_notifications(&connection_pool, &session.peer, notifications);
3439
3440 response.send(proto::Ack {})?;
3441 Ok(())
3442}
3443
3444/// Accept or decline a contact request
3445async fn respond_to_contact_request(
3446 request: proto::RespondToContactRequest,
3447 response: Response<proto::RespondToContactRequest>,
3448 session: UserSession,
3449) -> Result<()> {
3450 let responder_id = session.user_id();
3451 let requester_id = UserId::from_proto(request.requester_id);
3452 let db = session.db().await;
3453 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
3454 db.dismiss_contact_notification(responder_id, requester_id)
3455 .await?;
3456 } else {
3457 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
3458
3459 let notifications = db
3460 .respond_to_contact_request(responder_id, requester_id, accept)
3461 .await?;
3462 let requester_busy = db.is_user_busy(requester_id).await?;
3463 let responder_busy = db.is_user_busy(responder_id).await?;
3464
3465 let pool = session.connection_pool().await;
3466 // Update responder with new contact
3467 let mut update = proto::UpdateContacts::default();
3468 if accept {
3469 update
3470 .contacts
3471 .push(contact_for_user(requester_id, requester_busy, &pool));
3472 }
3473 update
3474 .remove_incoming_requests
3475 .push(requester_id.to_proto());
3476 for connection_id in pool.user_connection_ids(responder_id) {
3477 session.peer.send(connection_id, update.clone())?;
3478 }
3479
3480 // Update requester with new contact
3481 let mut update = proto::UpdateContacts::default();
3482 if accept {
3483 update
3484 .contacts
3485 .push(contact_for_user(responder_id, responder_busy, &pool));
3486 }
3487 update
3488 .remove_outgoing_requests
3489 .push(responder_id.to_proto());
3490
3491 for connection_id in pool.user_connection_ids(requester_id) {
3492 session.peer.send(connection_id, update.clone())?;
3493 }
3494
3495 send_notifications(&pool, &session.peer, notifications);
3496 }
3497
3498 response.send(proto::Ack {})?;
3499 Ok(())
3500}
3501
3502/// Remove a contact.
3503async fn remove_contact(
3504 request: proto::RemoveContact,
3505 response: Response<proto::RemoveContact>,
3506 session: UserSession,
3507) -> Result<()> {
3508 let requester_id = session.user_id();
3509 let responder_id = UserId::from_proto(request.user_id);
3510 let db = session.db().await;
3511 let (contact_accepted, deleted_notification_id) =
3512 db.remove_contact(requester_id, responder_id).await?;
3513
3514 let pool = session.connection_pool().await;
3515 // Update outgoing contact requests of requester
3516 let mut update = proto::UpdateContacts::default();
3517 if contact_accepted {
3518 update.remove_contacts.push(responder_id.to_proto());
3519 } else {
3520 update
3521 .remove_outgoing_requests
3522 .push(responder_id.to_proto());
3523 }
3524 for connection_id in pool.user_connection_ids(requester_id) {
3525 session.peer.send(connection_id, update.clone())?;
3526 }
3527
3528 // Update incoming contact requests of responder
3529 let mut update = proto::UpdateContacts::default();
3530 if contact_accepted {
3531 update.remove_contacts.push(requester_id.to_proto());
3532 } else {
3533 update
3534 .remove_incoming_requests
3535 .push(requester_id.to_proto());
3536 }
3537 for connection_id in pool.user_connection_ids(responder_id) {
3538 session.peer.send(connection_id, update.clone())?;
3539 if let Some(notification_id) = deleted_notification_id {
3540 session.peer.send(
3541 connection_id,
3542 proto::DeleteNotification {
3543 notification_id: notification_id.to_proto(),
3544 },
3545 )?;
3546 }
3547 }
3548
3549 response.send(proto::Ack {})?;
3550 Ok(())
3551}
3552
3553fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
3554 version.0.minor() < 139
3555}
3556
3557async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
3558 let plan = session.current_plan().await?;
3559
3560 session
3561 .peer
3562 .send(
3563 session.connection_id,
3564 proto::UpdateUserPlan { plan: plan.into() },
3565 )
3566 .trace_err();
3567
3568 Ok(())
3569}
3570
3571async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
3572 subscribe_user_to_channels(
3573 session.user_id().ok_or_else(|| anyhow!("must be a user"))?,
3574 &session,
3575 )
3576 .await?;
3577 Ok(())
3578}
3579
3580async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
3581 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3582 let mut pool = session.connection_pool().await;
3583 for membership in &channels_for_user.channel_memberships {
3584 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3585 }
3586 session.peer.send(
3587 session.connection_id,
3588 build_update_user_channels(&channels_for_user),
3589 )?;
3590 session.peer.send(
3591 session.connection_id,
3592 build_channels_update(channels_for_user),
3593 )?;
3594 Ok(())
3595}
3596
3597/// Creates a new channel.
3598async fn create_channel(
3599 request: proto::CreateChannel,
3600 response: Response<proto::CreateChannel>,
3601 session: UserSession,
3602) -> Result<()> {
3603 let db = session.db().await;
3604
3605 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
3606 let (channel, membership) = db
3607 .create_channel(&request.name, parent_id, session.user_id())
3608 .await?;
3609
3610 let root_id = channel.root_id();
3611 let channel = Channel::from_model(channel);
3612
3613 response.send(proto::CreateChannelResponse {
3614 channel: Some(channel.to_proto()),
3615 parent_id: request.parent_id,
3616 })?;
3617
3618 let mut connection_pool = session.connection_pool().await;
3619 if let Some(membership) = membership {
3620 connection_pool.subscribe_to_channel(
3621 membership.user_id,
3622 membership.channel_id,
3623 membership.role,
3624 );
3625 let update = proto::UpdateUserChannels {
3626 channel_memberships: vec![proto::ChannelMembership {
3627 channel_id: membership.channel_id.to_proto(),
3628 role: membership.role.into(),
3629 }],
3630 ..Default::default()
3631 };
3632 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3633 session.peer.send(connection_id, update.clone())?;
3634 }
3635 }
3636
3637 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3638 if !role.can_see_channel(channel.visibility) {
3639 continue;
3640 }
3641
3642 let update = proto::UpdateChannels {
3643 channels: vec![channel.to_proto()],
3644 ..Default::default()
3645 };
3646 session.peer.send(connection_id, update.clone())?;
3647 }
3648
3649 Ok(())
3650}
3651
3652/// Delete a channel
3653async fn delete_channel(
3654 request: proto::DeleteChannel,
3655 response: Response<proto::DeleteChannel>,
3656 session: UserSession,
3657) -> Result<()> {
3658 let db = session.db().await;
3659
3660 let channel_id = request.channel_id;
3661 let (root_channel, removed_channels) = db
3662 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3663 .await?;
3664 response.send(proto::Ack {})?;
3665
3666 // Notify members of removed channels
3667 let mut update = proto::UpdateChannels::default();
3668 update
3669 .delete_channels
3670 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3671
3672 let connection_pool = session.connection_pool().await;
3673 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3674 session.peer.send(connection_id, update.clone())?;
3675 }
3676
3677 Ok(())
3678}
3679
3680/// Invite someone to join a channel.
3681async fn invite_channel_member(
3682 request: proto::InviteChannelMember,
3683 response: Response<proto::InviteChannelMember>,
3684 session: UserSession,
3685) -> Result<()> {
3686 let db = session.db().await;
3687 let channel_id = ChannelId::from_proto(request.channel_id);
3688 let invitee_id = UserId::from_proto(request.user_id);
3689 let InviteMemberResult {
3690 channel,
3691 notifications,
3692 } = db
3693 .invite_channel_member(
3694 channel_id,
3695 invitee_id,
3696 session.user_id(),
3697 request.role().into(),
3698 )
3699 .await?;
3700
3701 let update = proto::UpdateChannels {
3702 channel_invitations: vec![channel.to_proto()],
3703 ..Default::default()
3704 };
3705
3706 let connection_pool = session.connection_pool().await;
3707 for connection_id in connection_pool.user_connection_ids(invitee_id) {
3708 session.peer.send(connection_id, update.clone())?;
3709 }
3710
3711 send_notifications(&connection_pool, &session.peer, notifications);
3712
3713 response.send(proto::Ack {})?;
3714 Ok(())
3715}
3716
3717/// remove someone from a channel
3718async fn remove_channel_member(
3719 request: proto::RemoveChannelMember,
3720 response: Response<proto::RemoveChannelMember>,
3721 session: UserSession,
3722) -> Result<()> {
3723 let db = session.db().await;
3724 let channel_id = ChannelId::from_proto(request.channel_id);
3725 let member_id = UserId::from_proto(request.user_id);
3726
3727 let RemoveChannelMemberResult {
3728 membership_update,
3729 notification_id,
3730 } = db
3731 .remove_channel_member(channel_id, member_id, session.user_id())
3732 .await?;
3733
3734 let mut connection_pool = session.connection_pool().await;
3735 notify_membership_updated(
3736 &mut connection_pool,
3737 membership_update,
3738 member_id,
3739 &session.peer,
3740 );
3741 for connection_id in connection_pool.user_connection_ids(member_id) {
3742 if let Some(notification_id) = notification_id {
3743 session
3744 .peer
3745 .send(
3746 connection_id,
3747 proto::DeleteNotification {
3748 notification_id: notification_id.to_proto(),
3749 },
3750 )
3751 .trace_err();
3752 }
3753 }
3754
3755 response.send(proto::Ack {})?;
3756 Ok(())
3757}
3758
3759/// Toggle the channel between public and private.
3760/// Care is taken to maintain the invariant that public channels only descend from public channels,
3761/// (though members-only channels can appear at any point in the hierarchy).
3762async fn set_channel_visibility(
3763 request: proto::SetChannelVisibility,
3764 response: Response<proto::SetChannelVisibility>,
3765 session: UserSession,
3766) -> Result<()> {
3767 let db = session.db().await;
3768 let channel_id = ChannelId::from_proto(request.channel_id);
3769 let visibility = request.visibility().into();
3770
3771 let channel_model = db
3772 .set_channel_visibility(channel_id, visibility, session.user_id())
3773 .await?;
3774 let root_id = channel_model.root_id();
3775 let channel = Channel::from_model(channel_model);
3776
3777 let mut connection_pool = session.connection_pool().await;
3778 for (user_id, role) in connection_pool
3779 .channel_user_ids(root_id)
3780 .collect::<Vec<_>>()
3781 .into_iter()
3782 {
3783 let update = if role.can_see_channel(channel.visibility) {
3784 connection_pool.subscribe_to_channel(user_id, channel_id, role);
3785 proto::UpdateChannels {
3786 channels: vec![channel.to_proto()],
3787 ..Default::default()
3788 }
3789 } else {
3790 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3791 proto::UpdateChannels {
3792 delete_channels: vec![channel.id.to_proto()],
3793 ..Default::default()
3794 }
3795 };
3796
3797 for connection_id in connection_pool.user_connection_ids(user_id) {
3798 session.peer.send(connection_id, update.clone())?;
3799 }
3800 }
3801
3802 response.send(proto::Ack {})?;
3803 Ok(())
3804}
3805
3806/// Alter the role for a user in the channel.
3807async fn set_channel_member_role(
3808 request: proto::SetChannelMemberRole,
3809 response: Response<proto::SetChannelMemberRole>,
3810 session: UserSession,
3811) -> Result<()> {
3812 let db = session.db().await;
3813 let channel_id = ChannelId::from_proto(request.channel_id);
3814 let member_id = UserId::from_proto(request.user_id);
3815 let result = db
3816 .set_channel_member_role(
3817 channel_id,
3818 session.user_id(),
3819 member_id,
3820 request.role().into(),
3821 )
3822 .await?;
3823
3824 match result {
3825 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3826 let mut connection_pool = session.connection_pool().await;
3827 notify_membership_updated(
3828 &mut connection_pool,
3829 membership_update,
3830 member_id,
3831 &session.peer,
3832 )
3833 }
3834 db::SetMemberRoleResult::InviteUpdated(channel) => {
3835 let update = proto::UpdateChannels {
3836 channel_invitations: vec![channel.to_proto()],
3837 ..Default::default()
3838 };
3839
3840 for connection_id in session
3841 .connection_pool()
3842 .await
3843 .user_connection_ids(member_id)
3844 {
3845 session.peer.send(connection_id, update.clone())?;
3846 }
3847 }
3848 }
3849
3850 response.send(proto::Ack {})?;
3851 Ok(())
3852}
3853
3854/// Change the name of a channel
3855async fn rename_channel(
3856 request: proto::RenameChannel,
3857 response: Response<proto::RenameChannel>,
3858 session: UserSession,
3859) -> Result<()> {
3860 let db = session.db().await;
3861 let channel_id = ChannelId::from_proto(request.channel_id);
3862 let channel_model = db
3863 .rename_channel(channel_id, session.user_id(), &request.name)
3864 .await?;
3865 let root_id = channel_model.root_id();
3866 let channel = Channel::from_model(channel_model);
3867
3868 response.send(proto::RenameChannelResponse {
3869 channel: Some(channel.to_proto()),
3870 })?;
3871
3872 let connection_pool = session.connection_pool().await;
3873 let update = proto::UpdateChannels {
3874 channels: vec![channel.to_proto()],
3875 ..Default::default()
3876 };
3877 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3878 if role.can_see_channel(channel.visibility) {
3879 session.peer.send(connection_id, update.clone())?;
3880 }
3881 }
3882
3883 Ok(())
3884}
3885
3886/// Move a channel to a new parent.
3887async fn move_channel(
3888 request: proto::MoveChannel,
3889 response: Response<proto::MoveChannel>,
3890 session: UserSession,
3891) -> Result<()> {
3892 let channel_id = ChannelId::from_proto(request.channel_id);
3893 let to = ChannelId::from_proto(request.to);
3894
3895 let (root_id, channels) = session
3896 .db()
3897 .await
3898 .move_channel(channel_id, to, session.user_id())
3899 .await?;
3900
3901 let connection_pool = session.connection_pool().await;
3902 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3903 let channels = channels
3904 .iter()
3905 .filter_map(|channel| {
3906 if role.can_see_channel(channel.visibility) {
3907 Some(channel.to_proto())
3908 } else {
3909 None
3910 }
3911 })
3912 .collect::<Vec<_>>();
3913 if channels.is_empty() {
3914 continue;
3915 }
3916
3917 let update = proto::UpdateChannels {
3918 channels,
3919 ..Default::default()
3920 };
3921
3922 session.peer.send(connection_id, update.clone())?;
3923 }
3924
3925 response.send(Ack {})?;
3926 Ok(())
3927}
3928
3929/// Get the list of channel members
3930async fn get_channel_members(
3931 request: proto::GetChannelMembers,
3932 response: Response<proto::GetChannelMembers>,
3933 session: UserSession,
3934) -> Result<()> {
3935 let db = session.db().await;
3936 let channel_id = ChannelId::from_proto(request.channel_id);
3937 let limit = if request.limit == 0 {
3938 u16::MAX as u64
3939 } else {
3940 request.limit
3941 };
3942 let (members, users) = db
3943 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3944 .await?;
3945 response.send(proto::GetChannelMembersResponse { members, users })?;
3946 Ok(())
3947}
3948
3949/// Accept or decline a channel invitation.
3950async fn respond_to_channel_invite(
3951 request: proto::RespondToChannelInvite,
3952 response: Response<proto::RespondToChannelInvite>,
3953 session: UserSession,
3954) -> Result<()> {
3955 let db = session.db().await;
3956 let channel_id = ChannelId::from_proto(request.channel_id);
3957 let RespondToChannelInvite {
3958 membership_update,
3959 notifications,
3960 } = db
3961 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3962 .await?;
3963
3964 let mut connection_pool = session.connection_pool().await;
3965 if let Some(membership_update) = membership_update {
3966 notify_membership_updated(
3967 &mut connection_pool,
3968 membership_update,
3969 session.user_id(),
3970 &session.peer,
3971 );
3972 } else {
3973 let update = proto::UpdateChannels {
3974 remove_channel_invitations: vec![channel_id.to_proto()],
3975 ..Default::default()
3976 };
3977
3978 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3979 session.peer.send(connection_id, update.clone())?;
3980 }
3981 };
3982
3983 send_notifications(&connection_pool, &session.peer, notifications);
3984
3985 response.send(proto::Ack {})?;
3986
3987 Ok(())
3988}
3989
3990/// Join the channels' room
3991async fn join_channel(
3992 request: proto::JoinChannel,
3993 response: Response<proto::JoinChannel>,
3994 session: UserSession,
3995) -> Result<()> {
3996 let channel_id = ChannelId::from_proto(request.channel_id);
3997 join_channel_internal(channel_id, Box::new(response), session).await
3998}
3999
/// A response handle that can answer a channel join with a `proto::JoinRoomResponse`.
///
/// Implemented for both `Response<proto::JoinChannel>` and `Response<proto::JoinRoom>`, so that
/// `join_channel_internal` can serve either request type.
trait JoinChannelInternalResponse {
    /// Sends `result` as the reply to the original request, consuming the handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Delegate to `Response`'s own `send`. The call is written fully qualified to make the
        // target explicit next to this identically-named trait method.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Same delegation as above, for `JoinRoom` requests.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
4013
/// Shared implementation for joining a channel's room.
///
/// `response` is the handle of either a `proto::JoinChannel` or a `proto::JoinRoom` request
/// (see the `JoinChannelInternalResponse` impls). In order, this:
/// 1. removes this user's stale room connection, if the database reports one;
/// 2. joins the channel's room in the database, which also returns any membership change and
///    the user's role in the channel;
/// 3. mints LiveKit credentials when a LiveKit client is configured. Guests get a guest token
///    and `can_publish: false`; all other roles get a room token and `can_publish: true`. A
///    token error is traced and results in no connection info;
/// 4. replies with a `proto::JoinRoomResponse`, propagates any membership change, and notifies
///    the room;
/// 5. after the database guard has been dropped, notifies the channel and refreshes the user's
///    contacts.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: UserSession,
) -> Result<()> {
    // The database guard, and the connection-pool guard taken further down, live only inside
    // this block. Both are therefore released before the channel/contacts updates at the end.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database guard while leaving the old room, then reacquire it.
            // `leave_room_for_session` receives the whole session, so it presumably locks the
            // database itself.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured. Also `None` when minting a token fails:
        // `.trace_err()?` logs the error and returns `None` from this closure.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            // Guests receive a guest token and may not publish; every other role receives a
            // regular room token and may publish.
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // The joined room is expected to carry its channel. Treat a missing one as an error rather
    // than panicking.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
4104
4105/// Start editing the channel notes
4106async fn join_channel_buffer(
4107 request: proto::JoinChannelBuffer,
4108 response: Response<proto::JoinChannelBuffer>,
4109 session: UserSession,
4110) -> Result<()> {
4111 let db = session.db().await;
4112 let channel_id = ChannelId::from_proto(request.channel_id);
4113
4114 let open_response = db
4115 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
4116 .await?;
4117
4118 let collaborators = open_response.collaborators.clone();
4119 response.send(open_response)?;
4120
4121 let update = UpdateChannelBufferCollaborators {
4122 channel_id: channel_id.to_proto(),
4123 collaborators: collaborators.clone(),
4124 };
4125 channel_buffer_updated(
4126 session.connection_id,
4127 collaborators
4128 .iter()
4129 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
4130 &update,
4131 &session.peer,
4132 );
4133
4134 Ok(())
4135}
4136
4137/// Edit the channel notes
4138async fn update_channel_buffer(
4139 request: proto::UpdateChannelBuffer,
4140 session: UserSession,
4141) -> Result<()> {
4142 let db = session.db().await;
4143 let channel_id = ChannelId::from_proto(request.channel_id);
4144
4145 let (collaborators, epoch, version) = db
4146 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
4147 .await?;
4148
4149 channel_buffer_updated(
4150 session.connection_id,
4151 collaborators.clone(),
4152 &proto::UpdateChannelBuffer {
4153 channel_id: channel_id.to_proto(),
4154 operations: request.operations,
4155 },
4156 &session.peer,
4157 );
4158
4159 let pool = &*session.connection_pool().await;
4160
4161 let non_collaborators =
4162 pool.channel_connection_ids(channel_id)
4163 .filter_map(|(connection_id, _)| {
4164 if collaborators.contains(&connection_id) {
4165 None
4166 } else {
4167 Some(connection_id)
4168 }
4169 });
4170
4171 broadcast(None, non_collaborators, |peer_id| {
4172 session.peer.send(
4173 peer_id,
4174 proto::UpdateChannels {
4175 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
4176 channel_id: channel_id.to_proto(),
4177 epoch: epoch as u64,
4178 version: version.clone(),
4179 }],
4180 ..Default::default()
4181 },
4182 )
4183 });
4184
4185 Ok(())
4186}
4187
4188/// Rejoin the channel notes after a connection blip
4189async fn rejoin_channel_buffers(
4190 request: proto::RejoinChannelBuffers,
4191 response: Response<proto::RejoinChannelBuffers>,
4192 session: UserSession,
4193) -> Result<()> {
4194 let db = session.db().await;
4195 let buffers = db
4196 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
4197 .await?;
4198
4199 for rejoined_buffer in &buffers {
4200 let collaborators_to_notify = rejoined_buffer
4201 .buffer
4202 .collaborators
4203 .iter()
4204 .filter_map(|c| Some(c.peer_id?.into()));
4205 channel_buffer_updated(
4206 session.connection_id,
4207 collaborators_to_notify,
4208 &proto::UpdateChannelBufferCollaborators {
4209 channel_id: rejoined_buffer.buffer.channel_id,
4210 collaborators: rejoined_buffer.buffer.collaborators.clone(),
4211 },
4212 &session.peer,
4213 );
4214 }
4215
4216 response.send(proto::RejoinChannelBuffersResponse {
4217 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
4218 })?;
4219
4220 Ok(())
4221}
4222
4223/// Stop editing the channel notes
4224async fn leave_channel_buffer(
4225 request: proto::LeaveChannelBuffer,
4226 response: Response<proto::LeaveChannelBuffer>,
4227 session: UserSession,
4228) -> Result<()> {
4229 let db = session.db().await;
4230 let channel_id = ChannelId::from_proto(request.channel_id);
4231
4232 let left_buffer = db
4233 .leave_channel_buffer(channel_id, session.connection_id)
4234 .await?;
4235
4236 response.send(Ack {})?;
4237
4238 channel_buffer_updated(
4239 session.connection_id,
4240 left_buffer.connections,
4241 &proto::UpdateChannelBufferCollaborators {
4242 channel_id: channel_id.to_proto(),
4243 collaborators: left_buffer.collaborators,
4244 },
4245 &session.peer,
4246 );
4247
4248 Ok(())
4249}
4250
4251fn channel_buffer_updated<T: EnvelopedMessage>(
4252 sender_id: ConnectionId,
4253 collaborators: impl IntoIterator<Item = ConnectionId>,
4254 message: &T,
4255 peer: &Peer,
4256) {
4257 broadcast(Some(sender_id), collaborators, |peer_id| {
4258 peer.send(peer_id, message.clone())
4259 });
4260}
4261
4262fn send_notifications(
4263 connection_pool: &ConnectionPool,
4264 peer: &Peer,
4265 notifications: db::NotificationBatch,
4266) {
4267 for (user_id, notification) in notifications {
4268 for connection_id in connection_pool.user_connection_ids(user_id) {
4269 if let Err(error) = peer.send(
4270 connection_id,
4271 proto::AddNotification {
4272 notification: Some(notification.clone()),
4273 },
4274 ) {
4275 tracing::error!(
4276 "failed to send notification to {:?} {}",
4277 connection_id,
4278 error
4279 );
4280 }
4281 }
4282 }
4283}
4284
4285/// Send a message to the channel
4286async fn send_channel_message(
4287 request: proto::SendChannelMessage,
4288 response: Response<proto::SendChannelMessage>,
4289 session: UserSession,
4290) -> Result<()> {
4291 // Validate the message body.
4292 let body = request.body.trim().to_string();
4293 if body.len() > MAX_MESSAGE_LEN {
4294 return Err(anyhow!("message is too long"))?;
4295 }
4296 if body.is_empty() {
4297 return Err(anyhow!("message can't be blank"))?;
4298 }
4299
4300 // TODO: adjust mentions if body is trimmed
4301
4302 let timestamp = OffsetDateTime::now_utc();
4303 let nonce = request
4304 .nonce
4305 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4306
4307 let channel_id = ChannelId::from_proto(request.channel_id);
4308 let CreatedChannelMessage {
4309 message_id,
4310 participant_connection_ids,
4311 notifications,
4312 } = session
4313 .db()
4314 .await
4315 .create_channel_message(
4316 channel_id,
4317 session.user_id(),
4318 &body,
4319 &request.mentions,
4320 timestamp,
4321 nonce.clone().into(),
4322 match request.reply_to_message_id {
4323 Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
4324 None => None,
4325 },
4326 )
4327 .await?;
4328
4329 let message = proto::ChannelMessage {
4330 sender_id: session.user_id().to_proto(),
4331 id: message_id.to_proto(),
4332 body,
4333 mentions: request.mentions,
4334 timestamp: timestamp.unix_timestamp() as u64,
4335 nonce: Some(nonce),
4336 reply_to_message_id: request.reply_to_message_id,
4337 edited_at: None,
4338 };
4339 broadcast(
4340 Some(session.connection_id),
4341 participant_connection_ids.clone(),
4342 |connection| {
4343 session.peer.send(
4344 connection,
4345 proto::ChannelMessageSent {
4346 channel_id: channel_id.to_proto(),
4347 message: Some(message.clone()),
4348 },
4349 )
4350 },
4351 );
4352 response.send(proto::SendChannelMessageResponse {
4353 message: Some(message),
4354 })?;
4355
4356 let pool = &*session.connection_pool().await;
4357 let non_participants =
4358 pool.channel_connection_ids(channel_id)
4359 .filter_map(|(connection_id, _)| {
4360 if participant_connection_ids.contains(&connection_id) {
4361 None
4362 } else {
4363 Some(connection_id)
4364 }
4365 });
4366 broadcast(None, non_participants, |peer_id| {
4367 session.peer.send(
4368 peer_id,
4369 proto::UpdateChannels {
4370 latest_channel_message_ids: vec![proto::ChannelMessageId {
4371 channel_id: channel_id.to_proto(),
4372 message_id: message_id.to_proto(),
4373 }],
4374 ..Default::default()
4375 },
4376 )
4377 });
4378 send_notifications(pool, &session.peer, notifications);
4379
4380 Ok(())
4381}
4382
4383/// Delete a channel message
4384async fn remove_channel_message(
4385 request: proto::RemoveChannelMessage,
4386 response: Response<proto::RemoveChannelMessage>,
4387 session: UserSession,
4388) -> Result<()> {
4389 let channel_id = ChannelId::from_proto(request.channel_id);
4390 let message_id = MessageId::from_proto(request.message_id);
4391 let (connection_ids, existing_notification_ids) = session
4392 .db()
4393 .await
4394 .remove_channel_message(channel_id, message_id, session.user_id())
4395 .await?;
4396
4397 broadcast(
4398 Some(session.connection_id),
4399 connection_ids,
4400 move |connection| {
4401 session.peer.send(connection, request.clone())?;
4402
4403 for notification_id in &existing_notification_ids {
4404 session.peer.send(
4405 connection,
4406 proto::DeleteNotification {
4407 notification_id: (*notification_id).to_proto(),
4408 },
4409 )?;
4410 }
4411
4412 Ok(())
4413 },
4414 );
4415 response.send(proto::Ack {})?;
4416 Ok(())
4417}
4418
4419async fn update_channel_message(
4420 request: proto::UpdateChannelMessage,
4421 response: Response<proto::UpdateChannelMessage>,
4422 session: UserSession,
4423) -> Result<()> {
4424 let channel_id = ChannelId::from_proto(request.channel_id);
4425 let message_id = MessageId::from_proto(request.message_id);
4426 let updated_at = OffsetDateTime::now_utc();
4427 let UpdatedChannelMessage {
4428 message_id,
4429 participant_connection_ids,
4430 notifications,
4431 reply_to_message_id,
4432 timestamp,
4433 deleted_mention_notification_ids,
4434 updated_mention_notifications,
4435 } = session
4436 .db()
4437 .await
4438 .update_channel_message(
4439 channel_id,
4440 message_id,
4441 session.user_id(),
4442 request.body.as_str(),
4443 &request.mentions,
4444 updated_at,
4445 )
4446 .await?;
4447
4448 let nonce = request
4449 .nonce
4450 .clone()
4451 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4452
4453 let message = proto::ChannelMessage {
4454 sender_id: session.user_id().to_proto(),
4455 id: message_id.to_proto(),
4456 body: request.body.clone(),
4457 mentions: request.mentions.clone(),
4458 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
4459 nonce: Some(nonce),
4460 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
4461 edited_at: Some(updated_at.unix_timestamp() as u64),
4462 };
4463
4464 response.send(proto::Ack {})?;
4465
4466 let pool = &*session.connection_pool().await;
4467 broadcast(
4468 Some(session.connection_id),
4469 participant_connection_ids,
4470 |connection| {
4471 session.peer.send(
4472 connection,
4473 proto::ChannelMessageUpdate {
4474 channel_id: channel_id.to_proto(),
4475 message: Some(message.clone()),
4476 },
4477 )?;
4478
4479 for notification_id in &deleted_mention_notification_ids {
4480 session.peer.send(
4481 connection,
4482 proto::DeleteNotification {
4483 notification_id: (*notification_id).to_proto(),
4484 },
4485 )?;
4486 }
4487
4488 for notification in &updated_mention_notifications {
4489 session.peer.send(
4490 connection,
4491 proto::UpdateNotification {
4492 notification: Some(notification.clone()),
4493 },
4494 )?;
4495 }
4496
4497 Ok(())
4498 },
4499 );
4500
4501 send_notifications(pool, &session.peer, notifications);
4502
4503 Ok(())
4504}
4505
4506/// Mark a channel message as read
4507async fn acknowledge_channel_message(
4508 request: proto::AckChannelMessage,
4509 session: UserSession,
4510) -> Result<()> {
4511 let channel_id = ChannelId::from_proto(request.channel_id);
4512 let message_id = MessageId::from_proto(request.message_id);
4513 let notifications = session
4514 .db()
4515 .await
4516 .observe_channel_message(channel_id, session.user_id(), message_id)
4517 .await?;
4518 send_notifications(
4519 &*session.connection_pool().await,
4520 &session.peer,
4521 notifications,
4522 );
4523 Ok(())
4524}
4525
4526/// Mark a buffer version as synced
4527async fn acknowledge_buffer_version(
4528 request: proto::AckBufferOperation,
4529 session: UserSession,
4530) -> Result<()> {
4531 let buffer_id = BufferId::from_proto(request.buffer_id);
4532 session
4533 .db()
4534 .await
4535 .observe_buffer_version(
4536 buffer_id,
4537 session.user_id(),
4538 request.epoch as i32,
4539 &request.version,
4540 )
4541 .await?;
4542 Ok(())
4543}
4544
4545struct ZedProCompleteWithLanguageModelRateLimit;
4546
4547impl RateLimit for ZedProCompleteWithLanguageModelRateLimit {
4548 fn capacity(&self) -> usize {
4549 std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4550 .ok()
4551 .and_then(|v| v.parse().ok())
4552 .unwrap_or(120) // Picked arbitrarily
4553 }
4554
4555 fn refill_duration(&self) -> chrono::Duration {
4556 chrono::Duration::hours(1)
4557 }
4558
4559 fn db_name(&self) -> &'static str {
4560 "zed-pro:complete-with-language-model"
4561 }
4562}
4563
4564struct FreeCompleteWithLanguageModelRateLimit;
4565
4566impl RateLimit for FreeCompleteWithLanguageModelRateLimit {
4567 fn capacity(&self) -> usize {
4568 std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR_FREE")
4569 .ok()
4570 .and_then(|v| v.parse().ok())
4571 .unwrap_or(120 / 10) // Picked arbitrarily
4572 }
4573
4574 fn refill_duration(&self) -> chrono::Duration {
4575 chrono::Duration::hours(1)
4576 }
4577
4578 fn db_name(&self) -> &'static str {
4579 "free:complete-with-language-model"
4580 }
4581}
4582
4583async fn complete_with_language_model(
4584 request: proto::CompleteWithLanguageModel,
4585 response: Response<proto::CompleteWithLanguageModel>,
4586 session: Session,
4587 config: &Config,
4588) -> Result<()> {
4589 let Some(session) = session.for_user() else {
4590 return Err(anyhow!("user not found"))?;
4591 };
4592 authorize_access_to_language_models(&session).await?;
4593
4594 let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
4595 proto::Plan::ZedPro => Box::new(ZedProCompleteWithLanguageModelRateLimit),
4596 proto::Plan::Free => Box::new(FreeCompleteWithLanguageModelRateLimit),
4597 };
4598
4599 session
4600 .rate_limiter
4601 .check(&*rate_limit, session.user_id())
4602 .await?;
4603
4604 let result = match proto::LanguageModelProvider::from_i32(request.provider) {
4605 Some(proto::LanguageModelProvider::Anthropic) => {
4606 let api_key = config
4607 .anthropic_api_key
4608 .as_ref()
4609 .context("no Anthropic AI API key configured on the server")?;
4610 anthropic::complete(
4611 session.http_client.as_ref(),
4612 anthropic::ANTHROPIC_API_URL,
4613 api_key,
4614 serde_json::from_str(&request.request)?,
4615 )
4616 .await?
4617 }
4618 _ => return Err(anyhow!("unsupported provider"))?,
4619 };
4620
4621 response.send(proto::CompleteWithLanguageModelResponse {
4622 completion: serde_json::to_string(&result)?,
4623 })?;
4624
4625 Ok(())
4626}
4627
4628async fn stream_complete_with_language_model(
4629 request: proto::StreamCompleteWithLanguageModel,
4630 response: StreamingResponse<proto::StreamCompleteWithLanguageModel>,
4631 session: Session,
4632 config: &Config,
4633) -> Result<()> {
4634 let Some(session) = session.for_user() else {
4635 return Err(anyhow!("user not found"))?;
4636 };
4637 authorize_access_to_language_models(&session).await?;
4638
4639 let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
4640 proto::Plan::ZedPro => Box::new(ZedProCompleteWithLanguageModelRateLimit),
4641 proto::Plan::Free => Box::new(FreeCompleteWithLanguageModelRateLimit),
4642 };
4643
4644 session
4645 .rate_limiter
4646 .check(&*rate_limit, session.user_id())
4647 .await?;
4648
4649 match proto::LanguageModelProvider::from_i32(request.provider) {
4650 Some(proto::LanguageModelProvider::Anthropic) => {
4651 let api_key = config
4652 .anthropic_api_key
4653 .as_ref()
4654 .context("no Anthropic AI API key configured on the server")?;
4655 let mut chunks = anthropic::stream_completion(
4656 session.http_client.as_ref(),
4657 anthropic::ANTHROPIC_API_URL,
4658 api_key,
4659 serde_json::from_str(&request.request)?,
4660 None,
4661 )
4662 .await?;
4663 while let Some(event) = chunks.next().await {
4664 let chunk = event?;
4665 response.send(proto::StreamCompleteWithLanguageModelResponse {
4666 event: serde_json::to_string(&chunk)?,
4667 })?;
4668 }
4669 }
4670 Some(proto::LanguageModelProvider::OpenAi) => {
4671 let api_key = config
4672 .openai_api_key
4673 .as_ref()
4674 .context("no OpenAI API key configured on the server")?;
4675 let mut events = open_ai::stream_completion(
4676 session.http_client.as_ref(),
4677 open_ai::OPEN_AI_API_URL,
4678 api_key,
4679 serde_json::from_str(&request.request)?,
4680 None,
4681 )
4682 .await?;
4683 while let Some(event) = events.next().await {
4684 let event = event?;
4685 response.send(proto::StreamCompleteWithLanguageModelResponse {
4686 event: serde_json::to_string(&event)?,
4687 })?;
4688 }
4689 }
4690 Some(proto::LanguageModelProvider::Google) => {
4691 let api_key = config
4692 .google_ai_api_key
4693 .as_ref()
4694 .context("no Google AI API key configured on the server")?;
4695 let mut events = google_ai::stream_generate_content(
4696 session.http_client.as_ref(),
4697 google_ai::API_URL,
4698 api_key,
4699 serde_json::from_str(&request.request)?,
4700 )
4701 .await?;
4702 while let Some(event) = events.next().await {
4703 let event = event?;
4704 response.send(proto::StreamCompleteWithLanguageModelResponse {
4705 event: serde_json::to_string(&event)?,
4706 })?;
4707 }
4708 }
4709 Some(proto::LanguageModelProvider::Zed) => {
4710 let api_key = config
4711 .qwen2_7b_api_key
4712 .as_ref()
4713 .context("no Qwen2-7B API key configured on the server")?;
4714 let api_url = config
4715 .qwen2_7b_api_url
4716 .as_ref()
4717 .context("no Qwen2-7B URL configured on the server")?;
4718 let mut events = open_ai::stream_completion(
4719 session.http_client.as_ref(),
4720 &api_url,
4721 api_key,
4722 serde_json::from_str(&request.request)?,
4723 None,
4724 )
4725 .await?;
4726 while let Some(event) = events.next().await {
4727 let event = event?;
4728 response.send(proto::StreamCompleteWithLanguageModelResponse {
4729 event: serde_json::to_string(&event)?,
4730 })?;
4731 }
4732 }
4733 None => return Err(anyhow!("unknown provider"))?,
4734 }
4735
4736 Ok(())
4737}
4738
4739async fn count_language_model_tokens(
4740 request: proto::CountLanguageModelTokens,
4741 response: Response<proto::CountLanguageModelTokens>,
4742 session: Session,
4743 config: &Config,
4744) -> Result<()> {
4745 let Some(session) = session.for_user() else {
4746 return Err(anyhow!("user not found"))?;
4747 };
4748 authorize_access_to_language_models(&session).await?;
4749
4750 let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
4751 proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
4752 proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
4753 };
4754
4755 session
4756 .rate_limiter
4757 .check(&*rate_limit, session.user_id())
4758 .await?;
4759
4760 let result = match proto::LanguageModelProvider::from_i32(request.provider) {
4761 Some(proto::LanguageModelProvider::Google) => {
4762 let api_key = config
4763 .google_ai_api_key
4764 .as_ref()
4765 .context("no Google AI API key configured on the server")?;
4766 google_ai::count_tokens(
4767 session.http_client.as_ref(),
4768 google_ai::API_URL,
4769 api_key,
4770 serde_json::from_str(&request.request)?,
4771 )
4772 .await?
4773 }
4774 _ => return Err(anyhow!("unsupported provider"))?,
4775 };
4776
4777 response.send(proto::CountLanguageModelTokensResponse {
4778 token_count: result.total_tokens as u32,
4779 })?;
4780
4781 Ok(())
4782}
4783
4784struct ZedProCountLanguageModelTokensRateLimit;
4785
4786impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
4787 fn capacity(&self) -> usize {
4788 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
4789 .ok()
4790 .and_then(|v| v.parse().ok())
4791 .unwrap_or(600) // Picked arbitrarily
4792 }
4793
4794 fn refill_duration(&self) -> chrono::Duration {
4795 chrono::Duration::hours(1)
4796 }
4797
4798 fn db_name(&self) -> &'static str {
4799 "zed-pro:count-language-model-tokens"
4800 }
4801}
4802
4803struct FreeCountLanguageModelTokensRateLimit;
4804
4805impl RateLimit for FreeCountLanguageModelTokensRateLimit {
4806 fn capacity(&self) -> usize {
4807 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
4808 .ok()
4809 .and_then(|v| v.parse().ok())
4810 .unwrap_or(600 / 10) // Picked arbitrarily
4811 }
4812
4813 fn refill_duration(&self) -> chrono::Duration {
4814 chrono::Duration::hours(1)
4815 }
4816
4817 fn db_name(&self) -> &'static str {
4818 "free:count-language-model-tokens"
4819 }
4820}
4821
4822struct ZedProComputeEmbeddingsRateLimit;
4823
4824impl RateLimit for ZedProComputeEmbeddingsRateLimit {
4825 fn capacity(&self) -> usize {
4826 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4827 .ok()
4828 .and_then(|v| v.parse().ok())
4829 .unwrap_or(5000) // Picked arbitrarily
4830 }
4831
4832 fn refill_duration(&self) -> chrono::Duration {
4833 chrono::Duration::hours(1)
4834 }
4835
4836 fn db_name(&self) -> &'static str {
4837 "zed-pro:compute-embeddings"
4838 }
4839}
4840
4841struct FreeComputeEmbeddingsRateLimit;
4842
4843impl RateLimit for FreeComputeEmbeddingsRateLimit {
4844 fn capacity(&self) -> usize {
4845 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
4846 .ok()
4847 .and_then(|v| v.parse().ok())
4848 .unwrap_or(5000 / 10) // Picked arbitrarily
4849 }
4850
4851 fn refill_duration(&self) -> chrono::Duration {
4852 chrono::Duration::hours(1)
4853 }
4854
4855 fn db_name(&self) -> &'static str {
4856 "free:compute-embeddings"
4857 }
4858}
4859
4860async fn compute_embeddings(
4861 request: proto::ComputeEmbeddings,
4862 response: Response<proto::ComputeEmbeddings>,
4863 session: UserSession,
4864 api_key: Option<Arc<str>>,
4865) -> Result<()> {
4866 let api_key = api_key.context("no OpenAI API key configured on the server")?;
4867 authorize_access_to_language_models(&session).await?;
4868
4869 let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
4870 proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
4871 proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
4872 };
4873
4874 session
4875 .rate_limiter
4876 .check(&*rate_limit, session.user_id())
4877 .await?;
4878
4879 let embeddings = match request.model.as_str() {
4880 "openai/text-embedding-3-small" => {
4881 open_ai::embed(
4882 session.http_client.as_ref(),
4883 OPEN_AI_API_URL,
4884 &api_key,
4885 OpenAiEmbeddingModel::TextEmbedding3Small,
4886 request.texts.iter().map(|text| text.as_str()),
4887 )
4888 .await?
4889 }
4890 provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
4891 };
4892
4893 let embeddings = request
4894 .texts
4895 .iter()
4896 .map(|text| {
4897 let mut hasher = sha2::Sha256::new();
4898 hasher.update(text.as_bytes());
4899 let result = hasher.finalize();
4900 result.to_vec()
4901 })
4902 .zip(
4903 embeddings
4904 .data
4905 .into_iter()
4906 .map(|embedding| embedding.embedding),
4907 )
4908 .collect::<HashMap<_, _>>();
4909
4910 let db = session.db().await;
4911 db.save_embeddings(&request.model, &embeddings)
4912 .await
4913 .context("failed to save embeddings")
4914 .trace_err();
4915
4916 response.send(proto::ComputeEmbeddingsResponse {
4917 embeddings: embeddings
4918 .into_iter()
4919 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4920 .collect(),
4921 })?;
4922 Ok(())
4923}
4924
4925async fn get_cached_embeddings(
4926 request: proto::GetCachedEmbeddings,
4927 response: Response<proto::GetCachedEmbeddings>,
4928 session: UserSession,
4929) -> Result<()> {
4930 authorize_access_to_language_models(&session).await?;
4931
4932 let db = session.db().await;
4933 let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
4934
4935 response.send(proto::GetCachedEmbeddingsResponse {
4936 embeddings: embeddings
4937 .into_iter()
4938 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4939 .collect(),
4940 })?;
4941 Ok(())
4942}
4943
4944async fn authorize_access_to_language_models(session: &UserSession) -> Result<(), Error> {
4945 let db = session.db().await;
4946 let flags = db.get_user_flags(session.user_id()).await?;
4947 if flags.iter().any(|flag| flag == "language-models") {
4948 return Ok(());
4949 }
4950
4951 Err(anyhow!("permission denied"))?
4952}
4953
4954/// Get a Supermaven API key for the user
4955async fn get_supermaven_api_key(
4956 _request: proto::GetSupermavenApiKey,
4957 response: Response<proto::GetSupermavenApiKey>,
4958 session: UserSession,
4959) -> Result<()> {
4960 let user_id: String = session.user_id().to_string();
4961 if !session.is_staff() {
4962 return Err(anyhow!("supermaven not enabled for this account"))?;
4963 }
4964
4965 let email = session
4966 .email()
4967 .ok_or_else(|| anyhow!("user must have an email"))?;
4968
4969 let supermaven_admin_api = session
4970 .supermaven_client
4971 .as_ref()
4972 .ok_or_else(|| anyhow!("supermaven not configured"))?;
4973
4974 let result = supermaven_admin_api
4975 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4976 .await?;
4977
4978 response.send(proto::GetSupermavenApiKeyResponse {
4979 api_key: result.api_key,
4980 })?;
4981
4982 Ok(())
4983}
4984
4985/// Start receiving chat updates for a channel
4986async fn join_channel_chat(
4987 request: proto::JoinChannelChat,
4988 response: Response<proto::JoinChannelChat>,
4989 session: UserSession,
4990) -> Result<()> {
4991 let channel_id = ChannelId::from_proto(request.channel_id);
4992
4993 let db = session.db().await;
4994 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4995 .await?;
4996 let messages = db
4997 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4998 .await?;
4999 response.send(proto::JoinChannelChatResponse {
5000 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
5001 messages,
5002 })?;
5003 Ok(())
5004}
5005
5006/// Stop receiving chat updates for a channel
5007async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
5008 let channel_id = ChannelId::from_proto(request.channel_id);
5009 session
5010 .db()
5011 .await
5012 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
5013 .await?;
5014 Ok(())
5015}
5016
5017/// Retrieve the chat history for a channel
5018async fn get_channel_messages(
5019 request: proto::GetChannelMessages,
5020 response: Response<proto::GetChannelMessages>,
5021 session: UserSession,
5022) -> Result<()> {
5023 let channel_id = ChannelId::from_proto(request.channel_id);
5024 let messages = session
5025 .db()
5026 .await
5027 .get_channel_messages(
5028 channel_id,
5029 session.user_id(),
5030 MESSAGE_COUNT_PER_PAGE,
5031 Some(MessageId::from_proto(request.before_message_id)),
5032 )
5033 .await?;
5034 response.send(proto::GetChannelMessagesResponse {
5035 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
5036 messages,
5037 })?;
5038 Ok(())
5039}
5040
5041/// Retrieve specific chat messages
5042async fn get_channel_messages_by_id(
5043 request: proto::GetChannelMessagesById,
5044 response: Response<proto::GetChannelMessagesById>,
5045 session: UserSession,
5046) -> Result<()> {
5047 let message_ids = request
5048 .message_ids
5049 .iter()
5050 .map(|id| MessageId::from_proto(*id))
5051 .collect::<Vec<_>>();
5052 let messages = session
5053 .db()
5054 .await
5055 .get_channel_messages_by_id(session.user_id(), &message_ids)
5056 .await?;
5057 response.send(proto::GetChannelMessagesResponse {
5058 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
5059 messages,
5060 })?;
5061 Ok(())
5062}
5063
5064/// Retrieve the current users notifications
5065async fn get_notifications(
5066 request: proto::GetNotifications,
5067 response: Response<proto::GetNotifications>,
5068 session: UserSession,
5069) -> Result<()> {
5070 let notifications = session
5071 .db()
5072 .await
5073 .get_notifications(
5074 session.user_id(),
5075 NOTIFICATION_COUNT_PER_PAGE,
5076 request
5077 .before_id
5078 .map(|id| db::NotificationId::from_proto(id)),
5079 )
5080 .await?;
5081 response.send(proto::GetNotificationsResponse {
5082 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
5083 notifications,
5084 })?;
5085 Ok(())
5086}
5087
5088/// Mark notifications as read
5089async fn mark_notification_as_read(
5090 request: proto::MarkNotificationRead,
5091 response: Response<proto::MarkNotificationRead>,
5092 session: UserSession,
5093) -> Result<()> {
5094 let database = &session.db().await;
5095 let notifications = database
5096 .mark_notification_as_read_by_id(
5097 session.user_id(),
5098 NotificationId::from_proto(request.notification_id),
5099 )
5100 .await?;
5101 send_notifications(
5102 &*session.connection_pool().await,
5103 &session.peer,
5104 notifications,
5105 );
5106 response.send(proto::Ack {})?;
5107 Ok(())
5108}
5109
5110/// Get the current users information
5111async fn get_private_user_info(
5112 _request: proto::GetPrivateUserInfo,
5113 response: Response<proto::GetPrivateUserInfo>,
5114 session: UserSession,
5115) -> Result<()> {
5116 let db = session.db().await;
5117
5118 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
5119 let user = db
5120 .get_user_by_id(session.user_id())
5121 .await?
5122 .ok_or_else(|| anyhow!("user not found"))?;
5123 let flags = db.get_user_flags(session.user_id()).await?;
5124
5125 response.send(proto::GetPrivateUserInfoResponse {
5126 metrics_id,
5127 staff: user.admin,
5128 flags,
5129 })?;
5130 Ok(())
5131}
5132
5133fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
5134 let message = match message {
5135 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
5136 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
5137 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
5138 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
5139 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
5140 code: frame.code.into(),
5141 reason: frame.reason,
5142 })),
5143 // We should never receive a frame while reading the message, according
5144 // to the `tungstenite` maintainers:
5145 //
5146 // > It cannot occur when you read messages from the WebSocket, but it
5147 // > can be used when you want to send the raw frames (e.g. you want to
5148 // > send the frames to the WebSocket without composing the full message first).
5149 // >
5150 // > — https://github.com/snapview/tungstenite-rs/issues/268
5151 TungsteniteMessage::Frame(_) => {
5152 bail!("received an unexpected frame while reading the message")
5153 }
5154 };
5155
5156 Ok(message)
5157}
5158
5159fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
5160 match message {
5161 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
5162 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
5163 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
5164 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
5165 AxumMessage::Close(frame) => {
5166 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
5167 code: frame.code.into(),
5168 reason: frame.reason,
5169 }))
5170 }
5171 }
5172}
5173
5174fn notify_membership_updated(
5175 connection_pool: &mut ConnectionPool,
5176 result: MembershipUpdated,
5177 user_id: UserId,
5178 peer: &Peer,
5179) {
5180 for membership in &result.new_channels.channel_memberships {
5181 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
5182 }
5183 for channel_id in &result.removed_channels {
5184 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
5185 }
5186
5187 let user_channels_update = proto::UpdateUserChannels {
5188 channel_memberships: result
5189 .new_channels
5190 .channel_memberships
5191 .iter()
5192 .map(|cm| proto::ChannelMembership {
5193 channel_id: cm.channel_id.to_proto(),
5194 role: cm.role.into(),
5195 })
5196 .collect(),
5197 ..Default::default()
5198 };
5199
5200 let mut update = build_channels_update(result.new_channels);
5201 update.delete_channels = result
5202 .removed_channels
5203 .into_iter()
5204 .map(|id| id.to_proto())
5205 .collect();
5206 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
5207
5208 for connection_id in connection_pool.user_connection_ids(user_id) {
5209 peer.send(connection_id, user_channels_update.clone())
5210 .trace_err();
5211 peer.send(connection_id, update.clone()).trace_err();
5212 }
5213}
5214
5215fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
5216 proto::UpdateUserChannels {
5217 channel_memberships: channels
5218 .channel_memberships
5219 .iter()
5220 .map(|m| proto::ChannelMembership {
5221 channel_id: m.channel_id.to_proto(),
5222 role: m.role.into(),
5223 })
5224 .collect(),
5225 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
5226 observed_channel_message_id: channels.observed_channel_messages.clone(),
5227 }
5228}
5229
5230fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
5231 let mut update = proto::UpdateChannels::default();
5232
5233 for channel in channels.channels {
5234 update.channels.push(channel.to_proto());
5235 }
5236
5237 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
5238 update.latest_channel_message_ids = channels.latest_channel_messages;
5239
5240 for (channel_id, participants) in channels.channel_participants {
5241 update
5242 .channel_participants
5243 .push(proto::ChannelParticipants {
5244 channel_id: channel_id.to_proto(),
5245 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
5246 });
5247 }
5248
5249 for channel in channels.invited_channels {
5250 update.channel_invitations.push(channel.to_proto());
5251 }
5252
5253 update.hosted_projects = channels.hosted_projects;
5254 update
5255}
5256
5257fn build_initial_contacts_update(
5258 contacts: Vec<db::Contact>,
5259 pool: &ConnectionPool,
5260) -> proto::UpdateContacts {
5261 let mut update = proto::UpdateContacts::default();
5262
5263 for contact in contacts {
5264 match contact {
5265 db::Contact::Accepted { user_id, busy } => {
5266 update.contacts.push(contact_for_user(user_id, busy, &pool));
5267 }
5268 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
5269 db::Contact::Incoming { user_id } => {
5270 update
5271 .incoming_requests
5272 .push(proto::IncomingContactRequest {
5273 requester_id: user_id.to_proto(),
5274 })
5275 }
5276 }
5277 }
5278
5279 update
5280}
5281
5282fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
5283 proto::Contact {
5284 user_id: user_id.to_proto(),
5285 online: pool.is_user_online(user_id),
5286 busy,
5287 }
5288}
5289
5290fn room_updated(room: &proto::Room, peer: &Peer) {
5291 broadcast(
5292 None,
5293 room.participants
5294 .iter()
5295 .filter_map(|participant| Some(participant.peer_id?.into())),
5296 |peer_id| {
5297 peer.send(
5298 peer_id,
5299 proto::RoomUpdated {
5300 room: Some(room.clone()),
5301 },
5302 )
5303 },
5304 );
5305}
5306
5307fn channel_updated(
5308 channel: &db::channel::Model,
5309 room: &proto::Room,
5310 peer: &Peer,
5311 pool: &ConnectionPool,
5312) {
5313 let participants = room
5314 .participants
5315 .iter()
5316 .map(|p| p.user_id)
5317 .collect::<Vec<_>>();
5318
5319 broadcast(
5320 None,
5321 pool.channel_connection_ids(channel.root_id())
5322 .filter_map(|(channel_id, role)| {
5323 role.can_see_channel(channel.visibility).then(|| channel_id)
5324 }),
5325 |peer_id| {
5326 peer.send(
5327 peer_id,
5328 proto::UpdateChannels {
5329 channel_participants: vec![proto::ChannelParticipants {
5330 channel_id: channel.id.to_proto(),
5331 participant_user_ids: participants.clone(),
5332 }],
5333 ..Default::default()
5334 },
5335 )
5336 },
5337 );
5338}
5339
5340async fn send_dev_server_projects_update(
5341 user_id: UserId,
5342 mut status: proto::DevServerProjectsUpdate,
5343 session: &Session,
5344) {
5345 let pool = session.connection_pool().await;
5346 for dev_server in &mut status.dev_servers {
5347 dev_server.status =
5348 pool.dev_server_status(DevServerId(dev_server.dev_server_id as i32)) as i32;
5349 }
5350 let connections = pool.user_connection_ids(user_id);
5351 for connection_id in connections {
5352 session.peer.send(connection_id, status.clone()).trace_err();
5353 }
5354}
5355
5356async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
5357 let db = session.db().await;
5358
5359 let contacts = db.get_contacts(user_id).await?;
5360 let busy = db.is_user_busy(user_id).await?;
5361
5362 let pool = session.connection_pool().await;
5363 let updated_contact = contact_for_user(user_id, busy, &pool);
5364 for contact in contacts {
5365 if let db::Contact::Accepted {
5366 user_id: contact_user_id,
5367 ..
5368 } = contact
5369 {
5370 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
5371 session
5372 .peer
5373 .send(
5374 contact_conn_id,
5375 proto::UpdateContacts {
5376 contacts: vec![updated_contact.clone()],
5377 remove_contacts: Default::default(),
5378 incoming_requests: Default::default(),
5379 remove_incoming_requests: Default::default(),
5380 outgoing_requests: Default::default(),
5381 remove_outgoing_requests: Default::default(),
5382 },
5383 )
5384 .trace_err();
5385 }
5386 }
5387 }
5388 Ok(())
5389}
5390
5391async fn lost_dev_server_connection(session: &DevServerSession) -> Result<()> {
5392 log::info!("lost dev server connection, unsharing projects");
5393 let project_ids = session
5394 .db()
5395 .await
5396 .get_stale_dev_server_projects(session.connection_id)
5397 .await?;
5398
5399 for project_id in project_ids {
5400 // not unshare re-checks the connection ids match, so we get away with no transaction
5401 unshare_project_internal(project_id, session.connection_id, None, &session).await?;
5402 }
5403
5404 let user_id = session.dev_server().user_id;
5405 let update = session
5406 .db()
5407 .await
5408 .dev_server_projects_update(user_id)
5409 .await?;
5410
5411 send_dev_server_projects_update(user_id, update, session).await;
5412
5413 Ok(())
5414}
5415
/// Removes `connection_id` from the room it is in and notifies everyone affected.
///
/// Returns `Ok(())` without doing anything else when `leave_room` reports that the
/// connection was not in a room. Otherwise it:
/// - notifies the connections of every project the user left (see `project_left`),
/// - broadcasts the updated room, and the channel if `leave_room` returned one,
/// - sends `CallCanceled` to every connection of users whose pending call was canceled,
/// - refreshes contacts for the leaving user and for those canceled callees,
/// - when a LiveKit client is configured, removes the user from the LiveKit room and
///   deletes that room if `leave_room` marked it as deleted.
///
/// # Errors
///
/// Propagates errors from the database and from `update_user_contacts`. Message-send
/// and LiveKit failures are only logged via `trace_err`.
async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Assigned inside the `if let` below and used after it; the `else` branch
    // returns early, so every binding is initialized on the path that continues.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        // Fields are moved out with `mem::take` instead of being cloned.
        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // NOTE(review): `live_kit_room` is taken out of `left_room.room` before the room
        // itself is taken, so the `room` broadcast below carries the default value for
        // that field — confirm this is intended.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so `pool` (presumably a connection-pool lock guard) is dropped before
    // the awaits that follow.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    if let Some(live_kit) = session.live_kit_client.as_ref() {
        // Cloned because `live_kit_room` is still needed by `delete_room` below.
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
5489
5490async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
5491 let left_channel_buffers = session
5492 .db()
5493 .await
5494 .leave_channel_buffers(session.connection_id)
5495 .await?;
5496
5497 for left_buffer in left_channel_buffers {
5498 channel_buffer_updated(
5499 session.connection_id,
5500 left_buffer.connections,
5501 &proto::UpdateChannelBufferCollaborators {
5502 channel_id: left_buffer.channel_id.to_proto(),
5503 collaborators: left_buffer.collaborators,
5504 },
5505 &session.peer,
5506 );
5507 }
5508
5509 Ok(())
5510}
5511
5512fn project_left(project: &db::LeftProject, session: &UserSession) {
5513 for connection_id in &project.connection_ids {
5514 if project.should_unshare {
5515 session
5516 .peer
5517 .send(
5518 *connection_id,
5519 proto::UnshareProject {
5520 project_id: project.id.to_proto(),
5521 },
5522 )
5523 .trace_err();
5524 } else {
5525 session
5526 .peer
5527 .send(
5528 *connection_id,
5529 proto::RemoveProjectCollaborator {
5530 project_id: project.id.to_proto(),
5531 peer_id: Some(session.connection_id.into()),
5532 },
5533 )
5534 .trace_err();
5535 }
5536 }
5537}
5538
/// Extension for consuming a fallible value as an `Option`: the error is
/// reported rather than propagated.
pub trait ResultExt {
    /// The success type produced when there is no error.
    type Ok;

    /// Returns the success value, or `None` after reporting the error.
    fn trace_err(self) -> Option<Self::Ok>;
}
5544
5545impl<T, E> ResultExt for Result<T, E>
5546where
5547 E: std::fmt::Debug,
5548{
5549 type Ok = T;
5550
5551 #[track_caller]
5552 fn trace_err(self) -> Option<T> {
5553 match self {
5554 Ok(value) => Some(value),
5555 Err(error) => {
5556 tracing::error!("{:?}", error);
5557 None
5558 }
5559 }
5560 }
5561}