1mod connection_pool;
2
3use crate::{
4 auth,
5 db::{
6 self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
7 CreatedChannelMessage, Database, DevServerId, DevServerProjectId, InviteMemberResult,
8 MembershipUpdated, MessageId, NotificationId, PrincipalId, Project, ProjectId,
9 RejoinedProject, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
10 ServerId, UpdatedChannelMessage, User, UserId,
11 },
12 executor::Executor,
13 AppState, Config, Error, RateLimit, RateLimiter, Result,
14};
15use anyhow::{anyhow, bail, Context as _};
16use async_tungstenite::tungstenite::{
17 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
18};
19use axum::{
20 body::Body,
21 extract::{
22 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
23 ConnectInfo, WebSocketUpgrade,
24 },
25 headers::{Header, HeaderName},
26 http::StatusCode,
27 middleware,
28 response::IntoResponse,
29 routing::get,
30 Extension, Router, TypedHeader,
31};
32use collections::{HashMap, HashSet};
33pub use connection_pool::{ConnectionPool, ZedVersion};
34use core::fmt::{self, Debug, Formatter};
35use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
36use sha2::Digest;
37use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
38
39use futures::{
40 channel::oneshot,
41 future::{self, BoxFuture},
42 stream::FuturesUnordered,
43 FutureExt, SinkExt, StreamExt, TryStreamExt,
44};
45use http_client::IsahcHttpClient;
46use prometheus::{register_int_gauge, IntGauge};
47use rpc::{
48 proto::{
49 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
50 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
51 },
52 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
53};
54use semantic_version::SemanticVersion;
55use serde::{Serialize, Serializer};
56use std::{
57 any::TypeId,
58 future::Future,
59 marker::PhantomData,
60 mem,
61 net::SocketAddr,
62 ops::{Deref, DerefMut},
63 rc::Rc,
64 sync::{
65 atomic::{AtomicBool, Ordering::SeqCst},
66 Arc, OnceLock,
67 },
68 time::{Duration, Instant},
69};
70use time::OffsetDateTime;
71use tokio::sync::{watch, Semaphore};
72use tower::ServiceBuilder;
73use tracing::{
74 field::{self},
75 info_span, instrument, Instrument,
76};
77
78use self::connection_pool::VersionedMessage;
79
// NOTE(review): not referenced in this view; presumably how long a disconnected client has to
// reconnect before its server-side state is released — confirm at the use sites.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can
/// clean up old resources.
///
/// `Server::start` sleeps for this long before clearing state left behind by stale servers;
/// 15s leaves a margin beyond that 10s window.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// NOTE(review): the limits below are used outside this view; their exact meaning is inferred
// from the names only (channel-message page size, maximum message length, notification page
// size) — confirm at the use sites.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;

/// A type-erased message handler, as stored in `Server::handlers`.
///
/// It receives the boxed envelope (downcast to the concrete message type inside
/// `Server::add_handler`) together with the connection's `Session`, and returns a boxed
/// `'static` future that performs the handling and its logging.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
91
92struct Response<R> {
93 peer: Arc<Peer>,
94 receipt: Receipt<R>,
95 responded: Arc<AtomicBool>,
96}
97
98impl<R: RequestMessage> Response<R> {
99 fn send(self, payload: R::Response) -> Result<()> {
100 self.responded.store(true, SeqCst);
101 self.peer.respond(self.receipt, payload)?;
102 Ok(())
103 }
104}
105
106struct StreamingResponse<R: RequestMessage> {
107 peer: Arc<Peer>,
108 receipt: Receipt<R>,
109}
110
111impl<R: RequestMessage> StreamingResponse<R> {
112 fn send(&self, payload: R::Response) -> Result<()> {
113 self.peer.respond(self.receipt, payload)?;
114 Ok(())
115 }
116}
117
118#[derive(Clone, Debug)]
119pub enum Principal {
120 User(User),
121 Impersonated { user: User, admin: User },
122 DevServer(dev_server::Model),
123}
124
125impl Principal {
126 fn update_span(&self, span: &tracing::Span) {
127 match &self {
128 Principal::User(user) => {
129 span.record("user_id", &user.id.0);
130 span.record("login", &user.github_login);
131 }
132 Principal::Impersonated { user, admin } => {
133 span.record("user_id", &user.id.0);
134 span.record("login", &user.github_login);
135 span.record("impersonator", &admin.github_login);
136 }
137 Principal::DevServer(dev_server) => {
138 span.record("dev_server_id", &dev_server.id.0);
139 }
140 }
141 }
142}
143
/// Per-connection state handed to every message handler.
///
/// One `Session` is built per connection in `Server::handle_connection` and cloned into each
/// handler invocation; shared state is held behind `Arc`s.
#[derive(Clone)]
struct Session {
    /// Who is on the other end: a user, an impersonated user, or a dev server.
    principal: Principal,
    /// The peer-level id of this connection.
    connection_id: ConnectionId,
    /// Database handle behind an async mutex. A new mutex is created per connection (around the
    /// shared database), so access is serialized within a connection. Acquire via `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Server-wide connection registry; acquire via `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// Cloned from the app state; `None` when the app state has no LiveKit client.
    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
    /// `None` when no Supermaven admin API key is configured.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    /// HTTP client created per connection with a `Zed Server/<version>` User-Agent.
    http_client: Arc<IsahcHttpClient>,
    rate_limiter: Arc<RateLimiter>,
    /// Not read by any code in this view (hence the leading underscore).
    _executor: Executor,
}
157
158impl Session {
159 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
160 #[cfg(test)]
161 tokio::task::yield_now().await;
162 let guard = self.db.lock().await;
163 #[cfg(test)]
164 tokio::task::yield_now().await;
165 guard
166 }
167
168 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
169 #[cfg(test)]
170 tokio::task::yield_now().await;
171 let guard = self.connection_pool.lock();
172 ConnectionPoolGuard {
173 guard,
174 _not_send: PhantomData,
175 }
176 }
177
178 fn for_user(self) -> Option<UserSession> {
179 UserSession::new(self)
180 }
181
182 fn for_dev_server(self) -> Option<DevServerSession> {
183 DevServerSession::new(self)
184 }
185
186 fn user_id(&self) -> Option<UserId> {
187 match &self.principal {
188 Principal::User(user) => Some(user.id),
189 Principal::Impersonated { user, .. } => Some(user.id),
190 Principal::DevServer(_) => None,
191 }
192 }
193
194 fn is_staff(&self) -> bool {
195 match &self.principal {
196 Principal::User(user) => user.admin,
197 Principal::Impersonated { .. } => true,
198 Principal::DevServer(_) => false,
199 }
200 }
201
202 fn dev_server_id(&self) -> Option<DevServerId> {
203 match &self.principal {
204 Principal::User(_) | Principal::Impersonated { .. } => None,
205 Principal::DevServer(dev_server) => Some(dev_server.id),
206 }
207 }
208
209 fn principal_id(&self) -> PrincipalId {
210 match &self.principal {
211 Principal::User(user) => PrincipalId::UserId(user.id),
212 Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
213 Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
214 }
215 }
216}
217
218impl Debug for Session {
219 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
220 let mut result = f.debug_struct("Session");
221 match &self.principal {
222 Principal::User(user) => {
223 result.field("user", &user.github_login);
224 }
225 Principal::Impersonated { user, admin } => {
226 result.field("user", &user.github_login);
227 result.field("impersonator", &admin.github_login);
228 }
229 Principal::DevServer(dev_server) => {
230 result.field("dev_server", &dev_server.id);
231 }
232 }
233 result.field("connection_id", &self.connection_id).finish()
234 }
235}
236
237struct UserSession(Session);
238
239impl UserSession {
240 pub fn new(s: Session) -> Option<Self> {
241 s.user_id().map(|_| UserSession(s))
242 }
243 pub fn user_id(&self) -> UserId {
244 self.0.user_id().unwrap()
245 }
246
247 pub fn email(&self) -> Option<String> {
248 match &self.0.principal {
249 Principal::User(user) => user.email_address.clone(),
250 Principal::Impersonated { user, .. } => user.email_address.clone(),
251 Principal::DevServer(..) => None,
252 }
253 }
254}
255
256impl Deref for UserSession {
257 type Target = Session;
258
259 fn deref(&self) -> &Self::Target {
260 &self.0
261 }
262}
263impl DerefMut for UserSession {
264 fn deref_mut(&mut self) -> &mut Self::Target {
265 &mut self.0
266 }
267}
268
269struct DevServerSession(Session);
270
271impl DevServerSession {
272 pub fn new(s: Session) -> Option<Self> {
273 s.dev_server_id().map(|_| DevServerSession(s))
274 }
275 pub fn dev_server_id(&self) -> DevServerId {
276 self.0.dev_server_id().unwrap()
277 }
278
279 fn dev_server(&self) -> &dev_server::Model {
280 match &self.0.principal {
281 Principal::DevServer(dev_server) => dev_server,
282 _ => unreachable!(),
283 }
284 }
285}
286
287impl Deref for DevServerSession {
288 type Target = Session;
289
290 fn deref(&self) -> &Self::Target {
291 &self.0
292 }
293}
294impl DerefMut for DevServerSession {
295 fn deref_mut(&mut self) -> &mut Self::Target {
296 &mut self.0
297 }
298}
299
300fn user_handler<M: RequestMessage, Fut>(
301 handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
302) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
303where
304 Fut: Send + Future<Output = Result<()>>,
305{
306 let handler = Arc::new(handler);
307 move |message, response, session| {
308 let handler = handler.clone();
309 Box::pin(async move {
310 if let Some(user_session) = session.for_user() {
311 Ok(handler(message, response, user_session).await?)
312 } else {
313 Err(Error::Internal(anyhow!(
314 "must be a user to call {}",
315 M::NAME
316 )))
317 }
318 })
319 }
320}
321
322fn dev_server_handler<M: RequestMessage, Fut>(
323 handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
324) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
325where
326 Fut: Send + Future<Output = Result<()>>,
327{
328 let handler = Arc::new(handler);
329 move |message, response, session| {
330 let handler = handler.clone();
331 Box::pin(async move {
332 if let Some(dev_server_session) = session.for_dev_server() {
333 Ok(handler(message, response, dev_server_session).await?)
334 } else {
335 Err(Error::Internal(anyhow!(
336 "must be a dev server to call {}",
337 M::NAME
338 )))
339 }
340 })
341 }
342}
343
344fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
345 handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
346) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
347where
348 InnertRetFut: Send + Future<Output = Result<()>>,
349{
350 let handler = Arc::new(handler);
351 move |message, session| {
352 let handler = handler.clone();
353 Box::pin(async move {
354 if let Some(user_session) = session.for_user() {
355 Ok(handler(message, user_session).await?)
356 } else {
357 Err(Error::Internal(anyhow!(
358 "must be a user to call {}",
359 M::NAME
360 )))
361 }
362 })
363 }
364}
365
366struct DbHandle(Arc<Database>);
367
368impl Deref for DbHandle {
369 type Target = Database;
370
371 fn deref(&self) -> &Self::Target {
372 self.0.as_ref()
373 }
374}
375
/// The collab RPC server: owns the peer, the connection pool, and the table of registered
/// message handlers.
pub struct Server {
    /// Id of this server instance. It sits behind a mutex because the test-only `reset`
    /// replaces it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Type-erased handlers keyed by the `TypeId` of the message type they handle.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Set to `true` by `teardown`. Connection loops subscribe to it and stop when it changes.
    teardown: watch::Sender<bool>,
}
384
/// Lock guard over the shared [`ConnectionPool`].
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send`. As a result, a future that holds it
/// across an `.await` is itself `!Send` and cannot be boxed into a `BoxFuture` handler.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
389
/// Serializable view of the server's peer and connection pool.
///
/// It holds the connection-pool lock (through `ConnectionPoolGuard`) for as long as it lives.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized through the guard's `Deref` target rather than the guard itself.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
396
397pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
398where
399 S: Serializer,
400 T: Deref<Target = U>,
401 U: Serialize,
402{
403 Serialize::serialize(value.deref(), serializer)
404}
405
406impl Server {
/// Creates a server with the given id and registers every RPC handler.
///
/// Handlers wrapped in `user_handler`/`user_message_handler` reject sessions whose principal
/// is not a user. Handlers wrapped in `dev_server_handler` reject sessions that are not dev
/// servers. Registering two handlers for the same message type panics (see `add_handler`).
pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
    let mut server = Self {
        id: parking_lot::Mutex::new(id),
        // NOTE(review): `as u32` silently narrows/reinterprets `id.0` if its type is wider or
        // signed — confirm server ids stay in range.
        peer: Peer::new(id.0 as u32),
        app_state: app_state.clone(),
        connection_pool: Default::default(),
        handlers: Default::default(),
        // Only the sender is kept; receivers are created later via `subscribe`.
        teardown: watch::channel(false).0,
    };

    server
        // Rooms and calls.
        .add_request_handler(ping)
        .add_request_handler(user_handler(create_room))
        .add_request_handler(user_handler(join_room))
        .add_request_handler(user_handler(rejoin_room))
        .add_request_handler(user_handler(leave_room))
        .add_request_handler(user_handler(set_room_participant_role))
        .add_request_handler(user_handler(call))
        .add_request_handler(user_handler(cancel_call))
        .add_message_handler(user_message_handler(decline_call))
        .add_request_handler(user_handler(update_participant_location))
        // Projects and dev servers.
        .add_request_handler(user_handler(share_project))
        .add_message_handler(unshare_project)
        .add_request_handler(user_handler(join_project))
        .add_request_handler(user_handler(join_hosted_project))
        .add_request_handler(user_handler(rejoin_dev_server_projects))
        .add_request_handler(user_handler(create_dev_server_project))
        .add_request_handler(user_handler(update_dev_server_project))
        .add_request_handler(user_handler(delete_dev_server_project))
        .add_request_handler(user_handler(create_dev_server))
        .add_request_handler(user_handler(regenerate_dev_server_token))
        .add_request_handler(user_handler(rename_dev_server))
        .add_request_handler(user_handler(delete_dev_server))
        .add_request_handler(user_handler(list_remote_directory))
        .add_request_handler(dev_server_handler(share_dev_server_project))
        .add_request_handler(dev_server_handler(shutdown_dev_server))
        .add_request_handler(dev_server_handler(reconnect_dev_server))
        .add_message_handler(user_message_handler(leave_project))
        .add_request_handler(update_project)
        .add_request_handler(update_worktree)
        .add_message_handler(start_language_server)
        .add_message_handler(update_language_server)
        .add_message_handler(update_diagnostic_summary)
        .add_message_handler(update_worktree_settings)
        // Project requests dispatched through the `forward_*` helpers (defined outside this
        // view).
        .add_request_handler(user_handler(
            forward_project_request_for_owner::<proto::TaskContextForLocation>,
        ))
        .add_request_handler(user_handler(
            forward_project_request_for_owner::<proto::TaskTemplates>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::GetHover>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::GetDefinition>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::GetTypeDefinition>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::GetReferences>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::SearchProject>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::GetDocumentHighlights>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::GetProjectSymbols>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::OpenBufferForSymbol>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::OpenBufferById>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::SynchronizeBuffers>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::InlayHints>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::OpenBufferByPath>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::GetCompletions>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
        ))
        .add_request_handler(user_handler(
            forward_versioned_mutating_project_request::<proto::OpenNewBuffer>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::GetCodeActions>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::ApplyCodeAction>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::PrepareRename>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::PerformRename>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::ReloadBuffers>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::FormatBuffers>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::CreateProjectEntry>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::RenameProjectEntry>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::CopyProjectEntry>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::DeleteProjectEntry>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::ExpandProjectEntry>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::OnTypeFormatting>,
        ))
        .add_request_handler(user_handler(
            forward_versioned_mutating_project_request::<proto::SaveBuffer>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::BlameBuffer>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::MultiLspQuery>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::RestartLanguageServers>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::LinkedEditingRange>,
        ))
        .add_message_handler(create_buffer_for_peer)
        .add_request_handler(update_buffer)
        .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
        .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
        .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
        .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
        .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
        // Users and contacts.
        .add_request_handler(get_users)
        .add_request_handler(user_handler(fuzzy_search_users))
        .add_request_handler(user_handler(request_contact))
        .add_request_handler(user_handler(remove_contact))
        .add_request_handler(user_handler(respond_to_contact_request))
        // Channels, channel buffers, and channel chat.
        .add_message_handler(subscribe_to_channels)
        .add_request_handler(user_handler(create_channel))
        .add_request_handler(user_handler(delete_channel))
        .add_request_handler(user_handler(invite_channel_member))
        .add_request_handler(user_handler(remove_channel_member))
        .add_request_handler(user_handler(set_channel_member_role))
        .add_request_handler(user_handler(set_channel_visibility))
        .add_request_handler(user_handler(rename_channel))
        .add_request_handler(user_handler(join_channel_buffer))
        .add_request_handler(user_handler(leave_channel_buffer))
        .add_message_handler(user_message_handler(update_channel_buffer))
        .add_request_handler(user_handler(rejoin_channel_buffers))
        .add_request_handler(user_handler(get_channel_members))
        .add_request_handler(user_handler(respond_to_channel_invite))
        .add_request_handler(user_handler(join_channel))
        .add_request_handler(user_handler(join_channel_chat))
        .add_message_handler(user_message_handler(leave_channel_chat))
        .add_request_handler(user_handler(send_channel_message))
        .add_request_handler(user_handler(remove_channel_message))
        .add_request_handler(user_handler(update_channel_message))
        .add_request_handler(user_handler(get_channel_messages))
        .add_request_handler(user_handler(get_channel_messages_by_id))
        .add_request_handler(user_handler(get_notifications))
        .add_request_handler(user_handler(mark_notification_as_read))
        .add_request_handler(user_handler(move_channel))
        // Following.
        .add_request_handler(user_handler(follow))
        .add_message_handler(user_message_handler(unfollow))
        .add_message_handler(user_message_handler(update_followers))
        .add_request_handler(user_handler(get_private_user_info))
        .add_message_handler(user_message_handler(acknowledge_channel_message))
        .add_message_handler(user_message_handler(acknowledge_buffer_version))
        .add_request_handler(user_handler(get_supermaven_api_key))
        // Assistant contexts.
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::OpenContext>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::CreateContext>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::SynchronizeContexts>,
        ))
        .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
        .add_message_handler(update_context)
        // Language-model handlers take the raw `Session` (not wrapped in `user_handler`) and
        // need the app config, so each closure captures its own clone of `app_state`.
        .add_request_handler({
            let app_state = app_state.clone();
            move |request, response, session| {
                let app_state = app_state.clone();
                async move {
                    complete_with_language_model(request, response, session, &app_state.config)
                        .await
                }
            }
        })
        .add_streaming_request_handler({
            let app_state = app_state.clone();
            move |request, response, session| {
                let app_state = app_state.clone();
                async move {
                    stream_complete_with_language_model(
                        request,
                        response,
                        session,
                        &app_state.config,
                    )
                    .await
                }
            }
        })
        .add_request_handler({
            let app_state = app_state.clone();
            move |request, response, session| {
                let app_state = app_state.clone();
                async move {
                    count_language_model_tokens(request, response, session, &app_state.config)
                        .await
                }
            }
        })
        // Embeddings (user sessions only).
        .add_request_handler({
            user_handler(move |request, response, session| {
                get_cached_embeddings(request, response, session)
            })
        })
        .add_request_handler({
            let app_state = app_state.clone();
            user_handler(move |request, response, session| {
                compute_embeddings(
                    request,
                    response,
                    session,
                    app_state.config.openai_api_key.clone(),
                )
            })
        });

    Arc::new(server)
}
665
/// Schedules cleanup of state left behind by previous (stale) server instances.
///
/// The work runs in a detached task, and this function returns `Ok(())` immediately. After
/// `CLEANUP_TIMEOUT`, the task:
/// 1. fetches the room and channel-buffer ids tied to stale servers in this environment,
/// 2. clears stale channel-buffer collaborators and notifies the remaining connections,
/// 3. clears stale room participants, then notifies rooms, channels, canceled callees, and
///    affected contacts, and deletes the LiveKit room once a room has no participants left,
/// 4. deletes the stale server records.
///
/// Each fallible step goes through `trace_err()`; a failed step is skipped rather than
/// aborting the cleanup.
pub async fn start(&self) -> Result<()> {
    // Capture owned handles so the detached task does not borrow `self`.
    let server_id = *self.id.lock();
    let app_state = self.app_state.clone();
    let peer = self.peer.clone();
    let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
    let pool = self.connection_pool.clone();
    let live_kit_client = self.app_state.live_kit_client.clone();

    let span = info_span!("start server");
    self.app_state.executor.spawn_detached(
        async move {
            tracing::info!("waiting for cleanup timeout");
            timeout.await;
            tracing::info!("cleanup timeout expired, retrieving stale rooms");
            if let Some((room_ids, channel_ids)) = app_state
                .db
                .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                .await
                .trace_err()
            {
                tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                tracing::info!(
                    stale_channel_buffer_count = channel_ids.len(),
                    "retrieved stale channel buffers"
                );

                // Drop stale collaborators from each channel buffer and tell the connections
                // returned by the database about the new collaborator list.
                for channel_id in channel_ids {
                    if let Some(refreshed_channel_buffer) = app_state
                        .db
                        .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                        .await
                        .trace_err()
                    {
                        for connection_id in refreshed_channel_buffer.connection_ids {
                            peer.send(
                                connection_id,
                                proto::UpdateChannelBufferCollaborators {
                                    channel_id: channel_id.to_proto(),
                                    collaborators: refreshed_channel_buffer
                                        .collaborators
                                        .clone(),
                                },
                            )
                            .trace_err();
                        }
                    }
                }

                for room_id in room_ids {
                    // Defaults used when clearing the room fails: nothing to notify, and the
                    // LiveKit room is left alone.
                    let mut contacts_to_update = HashSet::default();
                    let mut canceled_calls_to_user_ids = Vec::new();
                    let mut live_kit_room = String::new();
                    let mut delete_live_kit_room = false;

                    if let Some(mut refreshed_room) = app_state
                        .db
                        .clear_stale_room_participants(room_id, server_id)
                        .await
                        .trace_err()
                    {
                        tracing::info!(
                            room_id = room_id.0,
                            new_participant_count = refreshed_room.room.participants.len(),
                            "refreshed room"
                        );
                        room_updated(&refreshed_room.room, &peer);
                        if let Some(channel) = refreshed_room.channel.as_ref() {
                            channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                        }
                        contacts_to_update
                            .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                        contacts_to_update
                            .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                        canceled_calls_to_user_ids =
                            mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                        live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                        delete_live_kit_room = refreshed_room.room.participants.is_empty();
                    }

                    // Scoped so the (synchronous) pool lock is released before the awaits below.
                    {
                        let pool = pool.lock();
                        for canceled_user_id in canceled_calls_to_user_ids {
                            for connection_id in pool.user_connection_ids(canceled_user_id) {
                                peer.send(
                                    connection_id,
                                    proto::CallCanceled {
                                        room_id: room_id.to_proto(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    // Push the refreshed contact entry of every affected user to the
                    // connections of each of their accepted contacts.
                    for user_id in contacts_to_update {
                        let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                        let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                        if let Some((busy, contacts)) = busy.zip(contacts) {
                            let pool = pool.lock();
                            let updated_contact = contact_for_user(user_id, busy, &pool);
                            for contact in contacts {
                                if let db::Contact::Accepted {
                                    user_id: contact_user_id,
                                    ..
                                } = contact
                                {
                                    for contact_conn_id in
                                        pool.user_connection_ids(contact_user_id)
                                    {
                                        peer.send(
                                            contact_conn_id,
                                            proto::UpdateContacts {
                                                contacts: vec![updated_contact.clone()],
                                                remove_contacts: Default::default(),
                                                incoming_requests: Default::default(),
                                                remove_incoming_requests: Default::default(),
                                                outgoing_requests: Default::default(),
                                                remove_outgoing_requests: Default::default(),
                                            },
                                        )
                                        .trace_err();
                                    }
                                }
                            }
                        }
                    }

                    if let Some(live_kit) = live_kit_client.as_ref() {
                        if delete_live_kit_room {
                            live_kit.delete_room(live_kit_room).await.trace_err();
                        }
                    }
                }
            }

            app_state
                .db
                .delete_stale_servers(&app_state.config.zed_environment, server_id)
                .await
                .trace_err();
        }
        .instrument(span),
    );
    Ok(())
}
811
/// Shuts the server down.
///
/// Tears down the peer and resets the connection pool, then broadcasts `true` on the teardown
/// channel. Connection loops subscribed to that channel return when it changes, and
/// `handle_connection` refuses new connections while it is `true`.
pub fn teardown(&self) {
    self.peer.teardown();
    self.connection_pool.lock().reset();
    // `watch::Sender::send` only fails when there are no receivers, which is fine here.
    let _ = self.teardown.send(true);
}
817
/// Test-only: tears the server down and brings it back under a new id.
///
/// The teardown flag is cleared at the end so `handle_connection` accepts connections again.
#[cfg(test)]
pub fn reset(&self, id: ServerId) {
    self.teardown();
    *self.id.lock() = id;
    self.peer.reset(id.0 as u32);
    let _ = self.teardown.send(false);
}
825
/// Test-only: returns the server's current id.
#[cfg(test)]
pub fn id(&self) -> ServerId {
    let current = self.id.lock();
    *current
}
830
831 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
832 where
833 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
834 Fut: 'static + Send + Future<Output = Result<()>>,
835 M: EnvelopedMessage,
836 {
837 let prev_handler = self.handlers.insert(
838 TypeId::of::<M>(),
839 Box::new(move |envelope, session| {
840 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
841 let received_at = envelope.received_at;
842 tracing::info!("message received");
843 let start_time = Instant::now();
844 let future = (handler)(*envelope, session);
845 async move {
846 let result = future.await;
847 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
848 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
849 let queue_duration_ms = total_duration_ms - processing_duration_ms;
850 let payload_type = M::NAME;
851
852 match result {
853 Err(error) => {
854 tracing::error!(
855 ?error,
856 total_duration_ms,
857 processing_duration_ms,
858 queue_duration_ms,
859 payload_type,
860 "error handling message"
861 )
862 }
863 Ok(()) => tracing::info!(
864 total_duration_ms,
865 processing_duration_ms,
866 queue_duration_ms,
867 "finished handling message"
868 ),
869 }
870 }
871 .boxed()
872 }),
873 );
874 if prev_handler.is_some() {
875 panic!("registered a handler for the same message twice");
876 }
877 self
878 }
879
880 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
881 where
882 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
883 Fut: 'static + Send + Future<Output = Result<()>>,
884 M: EnvelopedMessage,
885 {
886 self.add_handler(move |envelope, session| handler(envelope.payload, session));
887 self
888 }
889
890 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
891 where
892 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
893 Fut: Send + Future<Output = Result<()>>,
894 M: RequestMessage,
895 {
896 let handler = Arc::new(handler);
897 self.add_handler(move |envelope, session| {
898 let receipt = envelope.receipt();
899 let handler = handler.clone();
900 async move {
901 let peer = session.peer.clone();
902 let responded = Arc::new(AtomicBool::default());
903 let response = Response {
904 peer: peer.clone(),
905 responded: responded.clone(),
906 receipt,
907 };
908 match (handler)(envelope.payload, response, session).await {
909 Ok(()) => {
910 if responded.load(std::sync::atomic::Ordering::SeqCst) {
911 Ok(())
912 } else {
913 Err(anyhow!("handler did not send a response"))?
914 }
915 }
916 Err(error) => {
917 let proto_err = match &error {
918 Error::Internal(err) => err.to_proto(),
919 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
920 };
921 peer.respond_with_error(receipt, proto_err)?;
922 Err(error)
923 }
924 }
925 }
926 })
927 }
928
929 fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
930 where
931 F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
932 Fut: Send + Future<Output = Result<()>>,
933 M: RequestMessage,
934 {
935 let handler = Arc::new(handler);
936 self.add_handler(move |envelope, session| {
937 let receipt = envelope.receipt();
938 let handler = handler.clone();
939 async move {
940 let peer = session.peer.clone();
941 let response = StreamingResponse {
942 peer: peer.clone(),
943 receipt,
944 };
945 match (handler)(envelope.payload, response, session).await {
946 Ok(()) => {
947 peer.end_stream(receipt)?;
948 Ok(())
949 }
950 Err(error) => {
951 let proto_err = match &error {
952 Error::Internal(err) => err.to_proto(),
953 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
954 };
955 peer.respond_with_error(receipt, proto_err)?;
956 Err(error)
957 }
958 }
959 }
960 })
961 }
962
    /// Runs a single client connection from handshake until sign-out.
    ///
    /// The returned future registers `connection` with the peer, builds the
    /// connection's [`Session`], sends the initial client update, and then
    /// dispatches incoming messages to the registered handlers until the I/O task
    /// finishes, the incoming stream ends, or the server starts tearing down.
    /// Unless it exits because of teardown, it finishes by running
    /// `connection_lost`. Errors are logged rather than returned.
    ///
    /// `send_connection_id`, if provided, is handed to
    /// `send_initial_client_update`, which reports the assigned connection id
    /// through it after the hello message.
    #[allow(clippy::too_many_arguments)]
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> {
        let this = self.clone();
        // `connection_id` is recorded below once the peer assigns it; the
        // principal-related fields are presumably filled by `update_span`.
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            dev_server_id=field::Empty
        );
        principal.update_span(&span);

        // Subscribed eagerly: the current value is checked as soon as the future
        // runs, and `changed()` is awaited inside the message loop.
        let mut teardown = self.teardown.subscribe();
        async move {
            // Refuse the connection outright once teardown has been signalled.
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));
            tracing::info!("connection opened");

            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match IsahcHttpClient::builder().default_header("User-Agent", user_agent).build() {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            // The Supermaven admin client exists only when an admin API key is configured.
            let supermaven_client = if let Some(supermaven_admin_api_key) = this.app_state.config.supermaven_admin_api_key.clone() {
                Some(Arc::new(SupermavenAdminApi::new(
                    supermaven_admin_api_key.to_string(),
                    http_client.clone(),
                )))
            } else {
                None
            };

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                http_client,
                rate_limiter: this.app_state.rate_limiter.clone(),
                _executor: executor.clone(),
                supermaven_client,
            };

            // A failed handshake ends the connection without running `connection_lost`.
            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Caps in-flight handlers (foreground and background) at 256: a permit is
            // taken before reading each message and released when its handler finishes.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Biased: teardown first, then I/O completion, then progress on running
                // foreground handlers, and only then reading a new message.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                dev_server_id=field::Empty
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                // The permit moves into the handler future so the slot stays
                                // occupied until the handler has completed.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    // Background handlers run detached on the executor.
                                    executor.spawn_detached(handle_message);
                                } else {
                                    // Foreground handlers are driven by this loop.
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Cancels any foreground handlers that have not completed yet.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
1110
1111 async fn send_initial_client_update(
1112 &self,
1113 connection_id: ConnectionId,
1114 principal: &Principal,
1115 zed_version: ZedVersion,
1116 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
1117 session: &Session,
1118 ) -> Result<()> {
1119 self.peer.send(
1120 connection_id,
1121 proto::Hello {
1122 peer_id: Some(connection_id.into()),
1123 },
1124 )?;
1125 tracing::info!("sent hello message");
1126 if let Some(send_connection_id) = send_connection_id.take() {
1127 let _ = send_connection_id.send(connection_id);
1128 }
1129
1130 match principal {
1131 Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
1132 if !user.connected_once {
1133 self.peer.send(connection_id, proto::ShowContacts {})?;
1134 self.app_state
1135 .db
1136 .set_user_connected_once(user.id, true)
1137 .await?;
1138 }
1139
1140 update_user_plan(user.id, session).await?;
1141
1142 let (contacts, dev_server_projects) = future::try_join(
1143 self.app_state.db.get_contacts(user.id),
1144 self.app_state.db.dev_server_projects_update(user.id),
1145 )
1146 .await?;
1147
1148 {
1149 let mut pool = self.connection_pool.lock();
1150 pool.add_connection(connection_id, user.id, user.admin, zed_version);
1151 self.peer.send(
1152 connection_id,
1153 build_initial_contacts_update(contacts, &pool),
1154 )?;
1155 }
1156
1157 if should_auto_subscribe_to_channels(zed_version) {
1158 subscribe_user_to_channels(user.id, session).await?;
1159 }
1160
1161 send_dev_server_projects_update(user.id, dev_server_projects, session).await;
1162
1163 if let Some(incoming_call) =
1164 self.app_state.db.incoming_call_for_user(user.id).await?
1165 {
1166 self.peer.send(connection_id, incoming_call)?;
1167 }
1168
1169 update_user_contacts(user.id, &session).await?;
1170 }
1171 Principal::DevServer(dev_server) => {
1172 {
1173 let mut pool = self.connection_pool.lock();
1174 if let Some(stale_connection_id) = pool.dev_server_connection_id(dev_server.id)
1175 {
1176 self.peer.send(
1177 stale_connection_id,
1178 proto::ShutdownDevServer {
1179 reason: Some(
1180 "another dev server connected with the same token".to_string(),
1181 ),
1182 },
1183 )?;
1184 pool.remove_connection(stale_connection_id)?;
1185 };
1186 pool.add_dev_server(connection_id, dev_server.id, zed_version);
1187 }
1188
1189 let projects = self
1190 .app_state
1191 .db
1192 .get_projects_for_dev_server(dev_server.id)
1193 .await?;
1194 self.peer
1195 .send(connection_id, proto::DevServerInstructions { projects })?;
1196
1197 let status = self
1198 .app_state
1199 .db
1200 .dev_server_projects_update(dev_server.user_id)
1201 .await?;
1202 send_dev_server_projects_update(dev_server.user_id, status, &session).await;
1203 }
1204 }
1205
1206 Ok(())
1207 }
1208
1209 pub async fn invite_code_redeemed(
1210 self: &Arc<Self>,
1211 inviter_id: UserId,
1212 invitee_id: UserId,
1213 ) -> Result<()> {
1214 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1215 if let Some(code) = &user.invite_code {
1216 let pool = self.connection_pool.lock();
1217 let invitee_contact = contact_for_user(invitee_id, false, &pool);
1218 for connection_id in pool.user_connection_ids(inviter_id) {
1219 self.peer.send(
1220 connection_id,
1221 proto::UpdateContacts {
1222 contacts: vec![invitee_contact.clone()],
1223 ..Default::default()
1224 },
1225 )?;
1226 self.peer.send(
1227 connection_id,
1228 proto::UpdateInviteInfo {
1229 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1230 count: user.invite_count as u32,
1231 },
1232 )?;
1233 }
1234 }
1235 }
1236 Ok(())
1237 }
1238
1239 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1240 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1241 if let Some(invite_code) = &user.invite_code {
1242 let pool = self.connection_pool.lock();
1243 for connection_id in pool.user_connection_ids(user_id) {
1244 self.peer.send(
1245 connection_id,
1246 proto::UpdateInviteInfo {
1247 url: format!(
1248 "{}{}",
1249 self.app_state.config.invite_link_prefix, invite_code
1250 ),
1251 count: user.invite_count as u32,
1252 },
1253 )?;
1254 }
1255 }
1256 }
1257 Ok(())
1258 }
1259
1260 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1261 ServerSnapshot {
1262 connection_pool: ConnectionPoolGuard {
1263 guard: self.connection_pool.lock(),
1264 _not_send: PhantomData,
1265 },
1266 peer: &self.peer,
1267 }
1268 }
1269}
1270
1271impl<'a> Deref for ConnectionPoolGuard<'a> {
1272 type Target = ConnectionPool;
1273
1274 fn deref(&self) -> &Self::Target {
1275 &self.guard
1276 }
1277}
1278
1279impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1280 fn deref_mut(&mut self) -> &mut Self::Target {
1281 &mut self.guard
1282 }
1283}
1284
/// In test builds, dropping the guard first runs `check_invariants` on the pool.
/// This happens before the inner lock guard field is dropped, i.e. while the pool is
/// still locked. In non-test builds this is a no-op.
impl<'a> Drop for ConnectionPoolGuard<'a> {
    fn drop(&mut self) {
        #[cfg(test)]
        self.check_invariants();
    }
}
1291
1292fn broadcast<F>(
1293 sender_id: Option<ConnectionId>,
1294 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1295 mut f: F,
1296) where
1297 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1298{
1299 for receiver_id in receiver_ids {
1300 if Some(receiver_id) != sender_id {
1301 if let Err(error) = f(receiver_id) {
1302 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1303 }
1304 }
1305 }
1306}
1307
1308pub struct ProtocolVersion(u32);
1309
1310impl Header for ProtocolVersion {
1311 fn name() -> &'static HeaderName {
1312 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1313 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1314 }
1315
1316 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1317 where
1318 Self: Sized,
1319 I: Iterator<Item = &'i axum::http::HeaderValue>,
1320 {
1321 let version = values
1322 .next()
1323 .ok_or_else(axum::headers::Error::invalid)?
1324 .to_str()
1325 .map_err(|_| axum::headers::Error::invalid())?
1326 .parse()
1327 .map_err(|_| axum::headers::Error::invalid())?;
1328 Ok(Self(version))
1329 }
1330
1331 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1332 values.extend([self.0.to_string().parse().unwrap()]);
1333 }
1334}
1335
1336pub struct AppVersionHeader(SemanticVersion);
1337impl Header for AppVersionHeader {
1338 fn name() -> &'static HeaderName {
1339 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1340 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1341 }
1342
1343 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1344 where
1345 Self: Sized,
1346 I: Iterator<Item = &'i axum::http::HeaderValue>,
1347 {
1348 let version = values
1349 .next()
1350 .ok_or_else(axum::headers::Error::invalid)?
1351 .to_str()
1352 .map_err(|_| axum::headers::Error::invalid())?
1353 .parse()
1354 .map_err(|_| axum::headers::Error::invalid())?;
1355 Ok(Self(version))
1356 }
1357
1358 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1359 values.extend([self.0.to_string().parse().unwrap()]);
1360 }
1361}
1362
1363pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1364 Router::new()
1365 .route("/rpc", get(handle_websocket_request))
1366 .layer(
1367 ServiceBuilder::new()
1368 .layer(Extension(server.app_state.clone()))
1369 .layer(middleware::from_fn(auth::validate_header)),
1370 )
1371 .route("/metrics", get(handle_metrics))
1372 .layer(Extension(server))
1373}
1374
1375pub async fn handle_websocket_request(
1376 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1377 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1378 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1379 Extension(server): Extension<Arc<Server>>,
1380 Extension(principal): Extension<Principal>,
1381 ws: WebSocketUpgrade,
1382) -> axum::response::Response {
1383 if protocol_version != rpc::PROTOCOL_VERSION {
1384 return (
1385 StatusCode::UPGRADE_REQUIRED,
1386 "client must be upgraded".to_string(),
1387 )
1388 .into_response();
1389 }
1390
1391 let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1392 return (
1393 StatusCode::UPGRADE_REQUIRED,
1394 "no version header found".to_string(),
1395 )
1396 .into_response();
1397 };
1398
1399 if !version.can_collaborate() {
1400 return (
1401 StatusCode::UPGRADE_REQUIRED,
1402 "client must be upgraded".to_string(),
1403 )
1404 .into_response();
1405 }
1406
1407 let socket_address = socket_address.to_string();
1408 ws.on_upgrade(move |socket| {
1409 let socket = socket
1410 .map_ok(to_tungstenite_message)
1411 .err_into()
1412 .with(|message| async move { to_axum_message(message) });
1413 let connection = Connection::new(Box::pin(socket));
1414 async move {
1415 server
1416 .handle_connection(
1417 connection,
1418 socket_address,
1419 principal,
1420 version,
1421 None,
1422 Executor::Production,
1423 )
1424 .await;
1425 }
1426 })
1427}
1428
1429pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1430 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1431 let connections_metric = CONNECTIONS_METRIC
1432 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1433
1434 let connections = server
1435 .connection_pool
1436 .lock()
1437 .connections()
1438 .filter(|connection| !connection.admin)
1439 .count();
1440 connections_metric.set(connections as _);
1441
1442 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1443 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1444 register_int_gauge!(
1445 "shared_projects",
1446 "number of open projects with one or more guests"
1447 )
1448 .unwrap()
1449 });
1450
1451 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1452 shared_projects_metric.set(shared_projects as _);
1453
1454 let encoder = prometheus::TextEncoder::new();
1455 let metric_families = prometheus::gather();
1456 let encoded_metrics = encoder
1457 .encode_to_string(&metric_families)
1458 .map_err(|err| anyhow!("{}", err))?;
1459 Ok(encoded_metrics)
1460}
1461
1462#[instrument(err, skip(executor))]
1463async fn connection_lost(
1464 session: Session,
1465 mut teardown: watch::Receiver<bool>,
1466 executor: Executor,
1467) -> Result<()> {
1468 session.peer.disconnect(session.connection_id);
1469 session
1470 .connection_pool()
1471 .await
1472 .remove_connection(session.connection_id)?;
1473
1474 session
1475 .db()
1476 .await
1477 .connection_lost(session.connection_id)
1478 .await
1479 .trace_err();
1480
1481 futures::select_biased! {
1482 _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1483 match &session.principal {
1484 Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
1485 let session = session.for_user().unwrap();
1486
1487 log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1488 leave_room_for_session(&session, session.connection_id).await.trace_err();
1489 leave_channel_buffers_for_session(&session)
1490 .await
1491 .trace_err();
1492
1493 if !session
1494 .connection_pool()
1495 .await
1496 .is_user_online(session.user_id())
1497 {
1498 let db = session.db().await;
1499 if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1500 room_updated(&room, &session.peer);
1501 }
1502 }
1503
1504 update_user_contacts(session.user_id(), &session).await?;
1505 },
1506 Principal::DevServer(_) => {
1507 lost_dev_server_connection(&session.for_dev_server().unwrap()).await?;
1508 },
1509 }
1510 },
1511 _ = teardown.changed().fuse() => {}
1512 }
1513
1514 Ok(())
1515}
1516
1517/// Acknowledges a ping from a client, used to keep the connection alive.
1518async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1519 response.send(proto::Ack {})?;
1520 Ok(())
1521}
1522
1523/// Creates a new room for calling (outside of channels)
1524async fn create_room(
1525 _request: proto::CreateRoom,
1526 response: Response<proto::CreateRoom>,
1527 session: UserSession,
1528) -> Result<()> {
1529 let live_kit_room = nanoid::nanoid!(30);
1530
1531 let live_kit_connection_info = util::maybe!(async {
1532 let live_kit = session.live_kit_client.as_ref();
1533 let live_kit = live_kit?;
1534 let user_id = session.user_id().to_string();
1535
1536 let token = live_kit
1537 .room_token(&live_kit_room, &user_id.to_string())
1538 .trace_err()?;
1539
1540 Some(proto::LiveKitConnectionInfo {
1541 server_url: live_kit.url().into(),
1542 token,
1543 can_publish: true,
1544 })
1545 })
1546 .await;
1547
1548 let room = session
1549 .db()
1550 .await
1551 .create_room(session.user_id(), session.connection_id, &live_kit_room)
1552 .await?;
1553
1554 response.send(proto::CreateRoomResponse {
1555 room: Some(room.clone()),
1556 live_kit_connection_info,
1557 })?;
1558
1559 update_user_contacts(session.user_id(), &session).await?;
1560 Ok(())
1561}
1562
1563/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1564async fn join_room(
1565 request: proto::JoinRoom,
1566 response: Response<proto::JoinRoom>,
1567 session: UserSession,
1568) -> Result<()> {
1569 let room_id = RoomId::from_proto(request.id);
1570
1571 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1572
1573 if let Some(channel_id) = channel_id {
1574 return join_channel_internal(channel_id, Box::new(response), session).await;
1575 }
1576
1577 let joined_room = {
1578 let room = session
1579 .db()
1580 .await
1581 .join_room(room_id, session.user_id(), session.connection_id)
1582 .await?;
1583 room_updated(&room.room, &session.peer);
1584 room.into_inner()
1585 };
1586
1587 for connection_id in session
1588 .connection_pool()
1589 .await
1590 .user_connection_ids(session.user_id())
1591 {
1592 session
1593 .peer
1594 .send(
1595 connection_id,
1596 proto::CallCanceled {
1597 room_id: room_id.to_proto(),
1598 },
1599 )
1600 .trace_err();
1601 }
1602
1603 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1604 if let Some(token) = live_kit
1605 .room_token(
1606 &joined_room.room.live_kit_room,
1607 &session.user_id().to_string(),
1608 )
1609 .trace_err()
1610 {
1611 Some(proto::LiveKitConnectionInfo {
1612 server_url: live_kit.url().into(),
1613 token,
1614 can_publish: true,
1615 })
1616 } else {
1617 None
1618 }
1619 } else {
1620 None
1621 };
1622
1623 response.send(proto::JoinRoomResponse {
1624 room: Some(joined_room.room),
1625 channel_id: None,
1626 live_kit_connection_info,
1627 })?;
1628
1629 update_user_contacts(session.user_id(), &session).await?;
1630 Ok(())
1631}
1632
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Re-admits the caller to the room through the database. The reply contains the
/// room, the caller's reshared projects (with their collaborators), and its rejoined
/// projects. After replying it:
/// - broadcasts the room update;
/// - for every reshared project, tells each collaborator that the caller's old peer
///   id was replaced by its new one, and forwards the project's worktrees to all
///   collaborators except the caller;
/// - streams each rejoined project's state back to the caller via
///   `notify_rejoined_projects`;
/// - calls `channel_updated` if the room belongs to a channel;
/// - updates the caller's contacts.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: UserSession,
) -> Result<()> {
    // Assigned inside the block below and used after it.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Point collaborators at the caller's new peer id (failures are only logged).
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Re-send the reshared project's worktrees to everyone but the caller.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1724
1725fn notify_rejoined_projects(
1726 rejoined_projects: &mut Vec<RejoinedProject>,
1727 session: &UserSession,
1728) -> Result<()> {
1729 for project in rejoined_projects.iter() {
1730 for collaborator in &project.collaborators {
1731 session
1732 .peer
1733 .send(
1734 collaborator.connection_id,
1735 proto::UpdateProjectCollaborator {
1736 project_id: project.id.to_proto(),
1737 old_peer_id: Some(project.old_connection_id.into()),
1738 new_peer_id: Some(session.connection_id.into()),
1739 },
1740 )
1741 .trace_err();
1742 }
1743 }
1744
1745 for project in rejoined_projects {
1746 for worktree in mem::take(&mut project.worktrees) {
1747 #[cfg(any(test, feature = "test-support"))]
1748 const MAX_CHUNK_SIZE: usize = 2;
1749 #[cfg(not(any(test, feature = "test-support")))]
1750 const MAX_CHUNK_SIZE: usize = 256;
1751
1752 // Stream this worktree's entries.
1753 let message = proto::UpdateWorktree {
1754 project_id: project.id.to_proto(),
1755 worktree_id: worktree.id,
1756 abs_path: worktree.abs_path.clone(),
1757 root_name: worktree.root_name,
1758 updated_entries: worktree.updated_entries,
1759 removed_entries: worktree.removed_entries,
1760 scan_id: worktree.scan_id,
1761 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1762 updated_repositories: worktree.updated_repositories,
1763 removed_repositories: worktree.removed_repositories,
1764 };
1765 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1766 session.peer.send(session.connection_id, update.clone())?;
1767 }
1768
1769 // Stream this worktree's diagnostics.
1770 for summary in worktree.diagnostic_summaries {
1771 session.peer.send(
1772 session.connection_id,
1773 proto::UpdateDiagnosticSummary {
1774 project_id: project.id.to_proto(),
1775 worktree_id: worktree.id,
1776 summary: Some(summary),
1777 },
1778 )?;
1779 }
1780
1781 for settings_file in worktree.settings_files {
1782 session.peer.send(
1783 session.connection_id,
1784 proto::UpdateWorktreeSettings {
1785 project_id: project.id.to_proto(),
1786 worktree_id: worktree.id,
1787 path: settings_file.path,
1788 content: Some(settings_file.content),
1789 },
1790 )?;
1791 }
1792 }
1793
1794 for language_server in &project.language_servers {
1795 session.peer.send(
1796 session.connection_id,
1797 proto::UpdateLanguageServer {
1798 project_id: project.id.to_proto(),
1799 language_server_id: language_server.id,
1800 variant: Some(
1801 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1802 proto::LspDiskBasedDiagnosticsUpdated {},
1803 ),
1804 ),
1805 },
1806 )?;
1807 }
1808 }
1809 Ok(())
1810}
1811
1812/// leave room disconnects from the room.
1813async fn leave_room(
1814 _: proto::LeaveRoom,
1815 response: Response<proto::LeaveRoom>,
1816 session: UserSession,
1817) -> Result<()> {
1818 leave_room_for_session(&session, session.connection_id).await?;
1819 response.send(proto::Ack {})?;
1820 Ok(())
1821}
1822
1823/// Updates the permissions of someone else in the room.
1824async fn set_room_participant_role(
1825 request: proto::SetRoomParticipantRole,
1826 response: Response<proto::SetRoomParticipantRole>,
1827 session: UserSession,
1828) -> Result<()> {
1829 let user_id = UserId::from_proto(request.user_id);
1830 let role = ChannelRole::from(request.role());
1831
1832 let (live_kit_room, can_publish) = {
1833 let room = session
1834 .db()
1835 .await
1836 .set_room_participant_role(
1837 session.user_id(),
1838 RoomId::from_proto(request.room_id),
1839 user_id,
1840 role,
1841 )
1842 .await?;
1843
1844 let live_kit_room = room.live_kit_room.clone();
1845 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1846 room_updated(&room, &session.peer);
1847 (live_kit_room, can_publish)
1848 };
1849
1850 if let Some(live_kit) = session.live_kit_client.as_ref() {
1851 live_kit
1852 .update_participant(
1853 live_kit_room.clone(),
1854 request.user_id.to_string(),
1855 live_kit_server::proto::ParticipantPermission {
1856 can_subscribe: true,
1857 can_publish,
1858 can_publish_data: can_publish,
1859 hidden: false,
1860 recorder: false,
1861 },
1862 )
1863 .await
1864 .trace_err();
1865 }
1866
1867 response.send(proto::Ack {})?;
1868 Ok(())
1869}
1870
/// Call someone else into the current room
///
/// The callee must be one of the caller's contacts. The call is recorded in the
/// database, the room update is broadcast, and the callee's contacts are updated.
/// Then every connection of the callee is rung concurrently. The first ring
/// request that succeeds completes this request with an `Ack`. If none succeeds
/// (including when the callee has no connections), the call is marked as failed,
/// the room and the callee's contacts are updated again, and an error is returned.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: UserSession,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(&room, &session.peer);
        // Move the incoming-call message out of the guard before the block ends.
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring all of the callee's connections at once.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                // Log the failed ring and keep waiting for the other connections.
                call_response.trace_err();
            }
        }
    }

    // No connection answered: roll the call back.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1939
1940/// Cancel an outgoing call.
1941async fn cancel_call(
1942 request: proto::CancelCall,
1943 response: Response<proto::CancelCall>,
1944 session: UserSession,
1945) -> Result<()> {
1946 let called_user_id = UserId::from_proto(request.called_user_id);
1947 let room_id = RoomId::from_proto(request.room_id);
1948 {
1949 let room = session
1950 .db()
1951 .await
1952 .cancel_call(room_id, session.connection_id, called_user_id)
1953 .await?;
1954 room_updated(&room, &session.peer);
1955 }
1956
1957 for connection_id in session
1958 .connection_pool()
1959 .await
1960 .user_connection_ids(called_user_id)
1961 {
1962 session
1963 .peer
1964 .send(
1965 connection_id,
1966 proto::CallCanceled {
1967 room_id: room_id.to_proto(),
1968 },
1969 )
1970 .trace_err();
1971 }
1972 response.send(proto::Ack {})?;
1973
1974 update_user_contacts(called_user_id, &session).await?;
1975 Ok(())
1976}
1977
1978/// Decline an incoming call.
1979async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1980 let room_id = RoomId::from_proto(message.room_id);
1981 {
1982 let room = session
1983 .db()
1984 .await
1985 .decline_call(Some(room_id), session.user_id())
1986 .await?
1987 .ok_or_else(|| anyhow!("failed to decline call"))?;
1988 room_updated(&room, &session.peer);
1989 }
1990
1991 for connection_id in session
1992 .connection_pool()
1993 .await
1994 .user_connection_ids(session.user_id())
1995 {
1996 session
1997 .peer
1998 .send(
1999 connection_id,
2000 proto::CallCanceled {
2001 room_id: room_id.to_proto(),
2002 },
2003 )
2004 .trace_err();
2005 }
2006 update_user_contacts(session.user_id(), &session).await?;
2007 Ok(())
2008}
2009
2010/// Updates other participants in the room with your current location.
2011async fn update_participant_location(
2012 request: proto::UpdateParticipantLocation,
2013 response: Response<proto::UpdateParticipantLocation>,
2014 session: UserSession,
2015) -> Result<()> {
2016 let room_id = RoomId::from_proto(request.room_id);
2017 let location = request
2018 .location
2019 .ok_or_else(|| anyhow!("invalid location"))?;
2020
2021 let db = session.db().await;
2022 let room = db
2023 .update_room_participant_location(room_id, session.connection_id, location)
2024 .await?;
2025
2026 room_updated(&room, &session.peer);
2027 response.send(proto::Ack {})?;
2028 Ok(())
2029}
2030
2031/// Share a project into the room.
2032async fn share_project(
2033 request: proto::ShareProject,
2034 response: Response<proto::ShareProject>,
2035 session: UserSession,
2036) -> Result<()> {
2037 let (project_id, room) = &*session
2038 .db()
2039 .await
2040 .share_project(
2041 RoomId::from_proto(request.room_id),
2042 session.connection_id,
2043 &request.worktrees,
2044 request
2045 .dev_server_project_id
2046 .map(|id| DevServerProjectId::from_proto(id)),
2047 )
2048 .await?;
2049 response.send(proto::ShareProjectResponse {
2050 project_id: project_id.to_proto(),
2051 })?;
2052 room_updated(&room, &session.peer);
2053
2054 Ok(())
2055}
2056
2057/// Unshare a project from the room.
2058async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
2059 let project_id = ProjectId::from_proto(message.project_id);
2060 unshare_project_internal(
2061 project_id,
2062 session.connection_id,
2063 session.user_id(),
2064 &session,
2065 )
2066 .await
2067}
2068
2069async fn unshare_project_internal(
2070 project_id: ProjectId,
2071 connection_id: ConnectionId,
2072 user_id: Option<UserId>,
2073 session: &Session,
2074) -> Result<()> {
2075 let delete = {
2076 let room_guard = session
2077 .db()
2078 .await
2079 .unshare_project(project_id, connection_id, user_id)
2080 .await?;
2081
2082 let (delete, room, guest_connection_ids) = &*room_guard;
2083
2084 let message = proto::UnshareProject {
2085 project_id: project_id.to_proto(),
2086 };
2087
2088 broadcast(
2089 Some(connection_id),
2090 guest_connection_ids.iter().copied(),
2091 |conn_id| session.peer.send(conn_id, message.clone()),
2092 );
2093 if let Some(room) = room {
2094 room_updated(room, &session.peer);
2095 }
2096
2097 *delete
2098 };
2099
2100 if delete {
2101 let db = session.db().await;
2102 db.delete_project(project_id).await?;
2103 }
2104
2105 Ok(())
2106}
2107
2108/// DevServer makes a project available online
2109async fn share_dev_server_project(
2110 request: proto::ShareDevServerProject,
2111 response: Response<proto::ShareDevServerProject>,
2112 session: DevServerSession,
2113) -> Result<()> {
2114 let (dev_server_project, user_id, status) = session
2115 .db()
2116 .await
2117 .share_dev_server_project(
2118 DevServerProjectId::from_proto(request.dev_server_project_id),
2119 session.dev_server_id(),
2120 session.connection_id,
2121 &request.worktrees,
2122 )
2123 .await?;
2124 let Some(project_id) = dev_server_project.project_id else {
2125 return Err(anyhow!("failed to share remote project"))?;
2126 };
2127
2128 send_dev_server_projects_update(user_id, status, &session).await;
2129
2130 response.send(proto::ShareProjectResponse { project_id })?;
2131
2132 Ok(())
2133}
2134
/// Join someone else's shared project.
///
/// Calls the database's `join_project` for this connection and user. It then
/// hands off to [`join_project_internal`], which announces the new guest and
/// streams the project state.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: UserSession,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    // `project` and `replica_id` borrow from the guard returned by
    // `join_project`. That guard is kept alive for the rest of the function;
    // only the database handle `db` is released early by `drop` below.
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    drop(db);
    // NOTE(review): this second log line fires for every join, not only for
    // remote/dev-server projects; the message looks stale.
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
2153
/// Abstracts over the response handles for `proto::JoinProject` and
/// `proto::JoinHostedProject`. Both reply with a `proto::JoinProjectResponse`,
/// so [`join_project_internal`] can serve either request.
trait JoinProjectInternalResponse {
    /// Send the join response back to the requesting client.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Delegates to the inherent `Response::send` for this request type.
        Response::<proto::JoinProject>::send(self, result)
    }
}
impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Delegates to the inherent `Response::send` for this request type.
        Response::<proto::JoinHostedProject>::send(self, result)
    }
}
2167
2168fn join_project_internal(
2169 response: impl JoinProjectInternalResponse,
2170 session: UserSession,
2171 project: &mut Project,
2172 replica_id: &ReplicaId,
2173) -> Result<()> {
2174 let collaborators = project
2175 .collaborators
2176 .iter()
2177 .filter(|collaborator| collaborator.connection_id != session.connection_id)
2178 .map(|collaborator| collaborator.to_proto())
2179 .collect::<Vec<_>>();
2180 let project_id = project.id;
2181 let guest_user_id = session.user_id();
2182
2183 let worktrees = project
2184 .worktrees
2185 .iter()
2186 .map(|(id, worktree)| proto::WorktreeMetadata {
2187 id: *id,
2188 root_name: worktree.root_name.clone(),
2189 visible: worktree.visible,
2190 abs_path: worktree.abs_path.clone(),
2191 })
2192 .collect::<Vec<_>>();
2193
2194 let add_project_collaborator = proto::AddProjectCollaborator {
2195 project_id: project_id.to_proto(),
2196 collaborator: Some(proto::Collaborator {
2197 peer_id: Some(session.connection_id.into()),
2198 replica_id: replica_id.0 as u32,
2199 user_id: guest_user_id.to_proto(),
2200 }),
2201 };
2202
2203 for collaborator in &collaborators {
2204 session
2205 .peer
2206 .send(
2207 collaborator.peer_id.unwrap().into(),
2208 add_project_collaborator.clone(),
2209 )
2210 .trace_err();
2211 }
2212
2213 // First, we send the metadata associated with each worktree.
2214 response.send(proto::JoinProjectResponse {
2215 project_id: project.id.0 as u64,
2216 worktrees: worktrees.clone(),
2217 replica_id: replica_id.0 as u32,
2218 collaborators: collaborators.clone(),
2219 language_servers: project.language_servers.clone(),
2220 role: project.role.into(),
2221 dev_server_project_id: project
2222 .dev_server_project_id
2223 .map(|dev_server_project_id| dev_server_project_id.0 as u64),
2224 })?;
2225
2226 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2227 #[cfg(any(test, feature = "test-support"))]
2228 const MAX_CHUNK_SIZE: usize = 2;
2229 #[cfg(not(any(test, feature = "test-support")))]
2230 const MAX_CHUNK_SIZE: usize = 256;
2231
2232 // Stream this worktree's entries.
2233 let message = proto::UpdateWorktree {
2234 project_id: project_id.to_proto(),
2235 worktree_id,
2236 abs_path: worktree.abs_path.clone(),
2237 root_name: worktree.root_name,
2238 updated_entries: worktree.entries,
2239 removed_entries: Default::default(),
2240 scan_id: worktree.scan_id,
2241 is_last_update: worktree.scan_id == worktree.completed_scan_id,
2242 updated_repositories: worktree.repository_entries.into_values().collect(),
2243 removed_repositories: Default::default(),
2244 };
2245 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
2246 session.peer.send(session.connection_id, update.clone())?;
2247 }
2248
2249 // Stream this worktree's diagnostics.
2250 for summary in worktree.diagnostic_summaries {
2251 session.peer.send(
2252 session.connection_id,
2253 proto::UpdateDiagnosticSummary {
2254 project_id: project_id.to_proto(),
2255 worktree_id: worktree.id,
2256 summary: Some(summary),
2257 },
2258 )?;
2259 }
2260
2261 for settings_file in worktree.settings_files {
2262 session.peer.send(
2263 session.connection_id,
2264 proto::UpdateWorktreeSettings {
2265 project_id: project_id.to_proto(),
2266 worktree_id: worktree.id,
2267 path: settings_file.path,
2268 content: Some(settings_file.content),
2269 },
2270 )?;
2271 }
2272 }
2273
2274 for language_server in &project.language_servers {
2275 session.peer.send(
2276 session.connection_id,
2277 proto::UpdateLanguageServer {
2278 project_id: project_id.to_proto(),
2279 language_server_id: language_server.id,
2280 variant: Some(
2281 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2282 proto::LspDiskBasedDiagnosticsUpdated {},
2283 ),
2284 ),
2285 },
2286 )?;
2287 }
2288
2289 Ok(())
2290}
2291
2292/// Leave someone elses shared project.
2293async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2294 let sender_id = session.connection_id;
2295 let project_id = ProjectId::from_proto(request.project_id);
2296 let db = session.db().await;
2297 if db.is_hosted_project(project_id).await? {
2298 let project = db.leave_hosted_project(project_id, sender_id).await?;
2299 project_left(&project, &session);
2300 return Ok(());
2301 }
2302
2303 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2304 tracing::info!(
2305 %project_id,
2306 "leave project"
2307 );
2308
2309 project_left(&project, &session);
2310 if let Some(room) = room {
2311 room_updated(&room, &session.peer);
2312 }
2313
2314 Ok(())
2315}
2316
2317async fn join_hosted_project(
2318 request: proto::JoinHostedProject,
2319 response: Response<proto::JoinHostedProject>,
2320 session: UserSession,
2321) -> Result<()> {
2322 let (mut project, replica_id) = session
2323 .db()
2324 .await
2325 .join_hosted_project(
2326 ProjectId(request.project_id as i32),
2327 session.user_id(),
2328 session.connection_id,
2329 )
2330 .await?;
2331
2332 join_project_internal(response, session, &mut project, &replica_id)
2333}
2334
2335async fn list_remote_directory(
2336 request: proto::ListRemoteDirectory,
2337 response: Response<proto::ListRemoteDirectory>,
2338 session: UserSession,
2339) -> Result<()> {
2340 let dev_server_id = DevServerId(request.dev_server_id as i32);
2341 let dev_server_connection_id = session
2342 .connection_pool()
2343 .await
2344 .dev_server_connection_id_supporting(dev_server_id, ZedVersion::with_list_directory())?;
2345
2346 session
2347 .db()
2348 .await
2349 .get_dev_server_for_user(dev_server_id, session.user_id())
2350 .await?;
2351
2352 response.send(
2353 session
2354 .peer
2355 .forward_request(session.connection_id, dev_server_connection_id, request)
2356 .await?,
2357 )?;
2358 Ok(())
2359}
2360
2361async fn update_dev_server_project(
2362 request: proto::UpdateDevServerProject,
2363 response: Response<proto::UpdateDevServerProject>,
2364 session: UserSession,
2365) -> Result<()> {
2366 let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2367
2368 let (dev_server_project, update) = session
2369 .db()
2370 .await
2371 .update_dev_server_project(dev_server_project_id, &request.paths, session.user_id())
2372 .await?;
2373
2374 let projects = session
2375 .db()
2376 .await
2377 .get_projects_for_dev_server(dev_server_project.dev_server_id)
2378 .await?;
2379
2380 let dev_server_connection_id = session
2381 .connection_pool()
2382 .await
2383 .dev_server_connection_id_supporting(
2384 dev_server_project.dev_server_id,
2385 ZedVersion::with_list_directory(),
2386 )?;
2387
2388 session.peer.send(
2389 dev_server_connection_id,
2390 proto::DevServerInstructions { projects },
2391 )?;
2392
2393 send_dev_server_projects_update(session.user_id(), update, &session).await;
2394
2395 response.send(proto::Ack {})
2396}
2397
2398async fn create_dev_server_project(
2399 request: proto::CreateDevServerProject,
2400 response: Response<proto::CreateDevServerProject>,
2401 session: UserSession,
2402) -> Result<()> {
2403 let dev_server_id = DevServerId(request.dev_server_id as i32);
2404 let dev_server_connection_id = session
2405 .connection_pool()
2406 .await
2407 .dev_server_connection_id(dev_server_id);
2408 let Some(dev_server_connection_id) = dev_server_connection_id else {
2409 Err(ErrorCode::DevServerOffline
2410 .message("Cannot create a remote project when the dev server is offline".to_string())
2411 .anyhow())?
2412 };
2413
2414 let path = request.path.clone();
2415 //Check that the path exists on the dev server
2416 session
2417 .peer
2418 .forward_request(
2419 session.connection_id,
2420 dev_server_connection_id,
2421 proto::ValidateDevServerProjectRequest { path: path.clone() },
2422 )
2423 .await?;
2424
2425 let (dev_server_project, update) = session
2426 .db()
2427 .await
2428 .create_dev_server_project(
2429 DevServerId(request.dev_server_id as i32),
2430 &request.path,
2431 session.user_id(),
2432 )
2433 .await?;
2434
2435 let projects = session
2436 .db()
2437 .await
2438 .get_projects_for_dev_server(dev_server_project.dev_server_id)
2439 .await?;
2440
2441 session.peer.send(
2442 dev_server_connection_id,
2443 proto::DevServerInstructions { projects },
2444 )?;
2445
2446 send_dev_server_projects_update(session.user_id(), update, &session).await;
2447
2448 response.send(proto::CreateDevServerProjectResponse {
2449 dev_server_project: Some(dev_server_project.to_proto(None)),
2450 })?;
2451 Ok(())
2452}
2453
2454async fn create_dev_server(
2455 request: proto::CreateDevServer,
2456 response: Response<proto::CreateDevServer>,
2457 session: UserSession,
2458) -> Result<()> {
2459 let access_token = auth::random_token();
2460 let hashed_access_token = auth::hash_access_token(&access_token);
2461
2462 if request.name.is_empty() {
2463 return Err(proto::ErrorCode::Forbidden
2464 .message("Dev server name cannot be empty".to_string())
2465 .anyhow())?;
2466 }
2467
2468 let (dev_server, status) = session
2469 .db()
2470 .await
2471 .create_dev_server(
2472 &request.name,
2473 request.ssh_connection_string.as_deref(),
2474 &hashed_access_token,
2475 session.user_id(),
2476 )
2477 .await?;
2478
2479 send_dev_server_projects_update(session.user_id(), status, &session).await;
2480
2481 response.send(proto::CreateDevServerResponse {
2482 dev_server_id: dev_server.id.0 as u64,
2483 access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2484 name: request.name,
2485 })?;
2486 Ok(())
2487}
2488
2489async fn regenerate_dev_server_token(
2490 request: proto::RegenerateDevServerToken,
2491 response: Response<proto::RegenerateDevServerToken>,
2492 session: UserSession,
2493) -> Result<()> {
2494 let dev_server_id = DevServerId(request.dev_server_id as i32);
2495 let access_token = auth::random_token();
2496 let hashed_access_token = auth::hash_access_token(&access_token);
2497
2498 let connection_id = session
2499 .connection_pool()
2500 .await
2501 .dev_server_connection_id(dev_server_id);
2502 if let Some(connection_id) = connection_id {
2503 shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2504 session.peer.send(
2505 connection_id,
2506 proto::ShutdownDevServer {
2507 reason: Some("dev server token was regenerated".to_string()),
2508 },
2509 )?;
2510 let _ = remove_dev_server_connection(dev_server_id, &session).await;
2511 }
2512
2513 let status = session
2514 .db()
2515 .await
2516 .update_dev_server_token(dev_server_id, &hashed_access_token, session.user_id())
2517 .await?;
2518
2519 send_dev_server_projects_update(session.user_id(), status, &session).await;
2520
2521 response.send(proto::RegenerateDevServerTokenResponse {
2522 dev_server_id: dev_server_id.to_proto(),
2523 access_token: auth::generate_dev_server_token(dev_server_id.0 as usize, access_token),
2524 })?;
2525 Ok(())
2526}
2527
2528async fn rename_dev_server(
2529 request: proto::RenameDevServer,
2530 response: Response<proto::RenameDevServer>,
2531 session: UserSession,
2532) -> Result<()> {
2533 if request.name.trim().is_empty() {
2534 return Err(proto::ErrorCode::Forbidden
2535 .message("Dev server name cannot be empty".to_string())
2536 .anyhow())?;
2537 }
2538
2539 let dev_server_id = DevServerId(request.dev_server_id as i32);
2540 let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2541 if dev_server.user_id != session.user_id() {
2542 return Err(anyhow!(ErrorCode::Forbidden))?;
2543 }
2544
2545 let status = session
2546 .db()
2547 .await
2548 .rename_dev_server(
2549 dev_server_id,
2550 &request.name,
2551 request.ssh_connection_string.as_deref(),
2552 session.user_id(),
2553 )
2554 .await?;
2555
2556 send_dev_server_projects_update(session.user_id(), status, &session).await;
2557
2558 response.send(proto::Ack {})?;
2559 Ok(())
2560}
2561
2562async fn delete_dev_server(
2563 request: proto::DeleteDevServer,
2564 response: Response<proto::DeleteDevServer>,
2565 session: UserSession,
2566) -> Result<()> {
2567 let dev_server_id = DevServerId(request.dev_server_id as i32);
2568 let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2569 if dev_server.user_id != session.user_id() {
2570 return Err(anyhow!(ErrorCode::Forbidden))?;
2571 }
2572
2573 let connection_id = session
2574 .connection_pool()
2575 .await
2576 .dev_server_connection_id(dev_server_id);
2577 if let Some(connection_id) = connection_id {
2578 shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2579 session.peer.send(
2580 connection_id,
2581 proto::ShutdownDevServer {
2582 reason: Some("dev server was deleted".to_string()),
2583 },
2584 )?;
2585 let _ = remove_dev_server_connection(dev_server_id, &session).await;
2586 }
2587
2588 let status = session
2589 .db()
2590 .await
2591 .delete_dev_server(dev_server_id, session.user_id())
2592 .await?;
2593
2594 send_dev_server_projects_update(session.user_id(), status, &session).await;
2595
2596 response.send(proto::Ack {})?;
2597 Ok(())
2598}
2599
/// Delete a dev server project that lives on one of the current user's dev
/// servers.
///
/// If the dev server is online and the project can be found, it is unshared
/// first. After the database deletion, an online dev server is sent the
/// project list returned by the database. The user's dev-server-projects view
/// is then refreshed.
async fn delete_dev_server_project(
    request: proto::DeleteDevServerProject,
    response: Response<proto::DeleteDevServerProject>,
    session: UserSession,
) -> Result<()> {
    let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
    let dev_server_project = session
        .db()
        .await
        .get_dev_server_project(dev_server_project_id)
        .await?;

    // Only the owner of the hosting dev server may delete its projects.
    let dev_server = session
        .db()
        .await
        .get_dev_server(dev_server_project.dev_server_id)
        .await?;
    if dev_server.user_id != session.user_id() {
        return Err(anyhow!(ErrorCode::Forbidden))?;
    }

    let dev_server_connection_id = session
        .connection_pool()
        .await
        .dev_server_connection_id(dev_server.id);

    if let Some(dev_server_connection_id) = dev_server_connection_id {
        // A lookup error is ignored here. NOTE(review): presumably it means the
        // project is not currently shared, so there is nothing to unshare.
        let project = session
            .db()
            .await
            .find_dev_server_project(dev_server_project_id)
            .await;
        if let Ok(project) = project {
            unshare_project_internal(
                project.id,
                dev_server_connection_id,
                Some(session.user_id()),
                &session,
            )
            .await?;
        }
    }

    // `projects` is forwarded to the dev server below; presumably it is the
    // list of projects the dev server should still serve after the deletion.
    let (projects, status) = session
        .db()
        .await
        .delete_dev_server_project(dev_server_project_id, dev_server.id, session.user_id())
        .await?;

    if let Some(dev_server_connection_id) = dev_server_connection_id {
        session.peer.send(
            dev_server_connection_id,
            proto::DevServerInstructions { projects },
        )?;
    }

    send_dev_server_projects_update(session.user_id(), status, &session).await;

    response.send(proto::Ack {})?;
    Ok(())
}
2661
2662async fn rejoin_dev_server_projects(
2663 request: proto::RejoinRemoteProjects,
2664 response: Response<proto::RejoinRemoteProjects>,
2665 session: UserSession,
2666) -> Result<()> {
2667 let mut rejoined_projects = {
2668 let db = session.db().await;
2669 db.rejoin_dev_server_projects(
2670 &request.rejoined_projects,
2671 session.user_id(),
2672 session.0.connection_id,
2673 )
2674 .await?
2675 };
2676 response.send(proto::RejoinRemoteProjectsResponse {
2677 rejoined_projects: rejoined_projects
2678 .iter()
2679 .map(|project| project.to_proto())
2680 .collect(),
2681 })?;
2682 notify_rejoined_projects(&mut rejoined_projects, &session)
2683}
2684
/// Handle a dev server that has reconnected with the projects it was sharing.
///
/// `reshare_dev_server_projects` is called with the dev server's new
/// connection id. For each project it returns, every collaborator is told
/// that the dev server's peer id changed from the old connection to the new
/// one, and is sent the project's current worktrees. The reply lists each
/// reshared project's collaborators.
async fn reconnect_dev_server(
    request: proto::ReconnectDevServer,
    response: Response<proto::ReconnectDevServer>,
    session: DevServerSession,
) -> Result<()> {
    // The database handle is scoped to this block and released before any
    // messages are sent.
    let reshared_projects = {
        let db = session.db().await;
        db.reshare_dev_server_projects(
            &request.reshared_projects,
            session.dev_server_id(),
            session.0.connection_id,
        )
        .await?
    };

    for project in &reshared_projects {
        // Best effort: failed sends go through `trace_err` and are skipped.
        for collaborator in &project.collaborators {
            session
                .peer
                .send(
                    collaborator.connection_id,
                    proto::UpdateProjectCollaborator {
                        project_id: project.id.to_proto(),
                        old_peer_id: Some(project.old_connection_id.into()),
                        new_peer_id: Some(session.connection_id.into()),
                    },
                )
                .trace_err();
        }

        broadcast(
            Some(session.connection_id),
            project
                .collaborators
                .iter()
                .map(|collaborator| collaborator.connection_id),
            |connection_id| {
                session.peer.forward_send(
                    session.connection_id,
                    connection_id,
                    proto::UpdateProject {
                        project_id: project.id.to_proto(),
                        worktrees: project.worktrees.clone(),
                    },
                )
            },
        );
    }

    response.send(proto::ReconnectDevServerResponse {
        reshared_projects: reshared_projects
            .iter()
            .map(|project| proto::ResharedProject {
                id: project.id.to_proto(),
                collaborators: project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.to_proto())
                    .collect(),
            })
            .collect(),
    })?;

    Ok(())
}
2750
2751async fn shutdown_dev_server(
2752 _: proto::ShutdownDevServer,
2753 response: Response<proto::ShutdownDevServer>,
2754 session: DevServerSession,
2755) -> Result<()> {
2756 response.send(proto::Ack {})?;
2757 shutdown_dev_server_internal(session.dev_server_id(), session.connection_id, &session).await?;
2758 remove_dev_server_connection(session.dev_server_id(), &session).await
2759}
2760
/// Take a dev server offline:
/// - unshare every project it hosts that has a project id;
/// - mark it offline in the connection pool;
/// - send its owner a refreshed dev-server-projects status.
///
/// This does not remove the connection itself. The callers in this file
/// follow up with `remove_dev_server_connection`.
async fn shutdown_dev_server_internal(
    dev_server_id: DevServerId,
    connection_id: ConnectionId,
    session: &Session,
) -> Result<()> {
    // Both lookups share one database handle, released before the unshare
    // loop (which acquires the database itself).
    let (dev_server_projects, dev_server) = {
        let db = session.db().await;
        let dev_server_projects = db.get_projects_for_dev_server(dev_server_id).await?;
        let dev_server = db.get_dev_server(dev_server_id).await?;
        (dev_server_projects, dev_server)
    };

    // Only projects that currently have a project id are unshared.
    for project_id in dev_server_projects.iter().filter_map(|p| p.project_id) {
        unshare_project_internal(
            ProjectId::from_proto(project_id),
            connection_id,
            None,
            session,
        )
        .await?;
    }

    session
        .connection_pool()
        .await
        .set_dev_server_offline(dev_server_id);

    let status = session
        .db()
        .await
        .dev_server_projects_update(dev_server.user_id)
        .await?;
    send_dev_server_projects_update(dev_server.user_id, status, &session).await;

    Ok(())
}
2797
2798async fn remove_dev_server_connection(dev_server_id: DevServerId, session: &Session) -> Result<()> {
2799 let dev_server_connection = session
2800 .connection_pool()
2801 .await
2802 .dev_server_connection_id(dev_server_id);
2803
2804 if let Some(dev_server_connection) = dev_server_connection {
2805 session
2806 .connection_pool()
2807 .await
2808 .remove_connection(dev_server_connection)?;
2809 }
2810 Ok(())
2811}
2812
2813/// Updates other participants with changes to the project
2814async fn update_project(
2815 request: proto::UpdateProject,
2816 response: Response<proto::UpdateProject>,
2817 session: Session,
2818) -> Result<()> {
2819 let project_id = ProjectId::from_proto(request.project_id);
2820 let (room, guest_connection_ids) = &*session
2821 .db()
2822 .await
2823 .update_project(project_id, session.connection_id, &request.worktrees)
2824 .await?;
2825 broadcast(
2826 Some(session.connection_id),
2827 guest_connection_ids.iter().copied(),
2828 |connection_id| {
2829 session
2830 .peer
2831 .forward_send(session.connection_id, connection_id, request.clone())
2832 },
2833 );
2834 if let Some(room) = room {
2835 room_updated(&room, &session.peer);
2836 }
2837 response.send(proto::Ack {})?;
2838
2839 Ok(())
2840}
2841
2842/// Updates other participants with changes to the worktree
2843async fn update_worktree(
2844 request: proto::UpdateWorktree,
2845 response: Response<proto::UpdateWorktree>,
2846 session: Session,
2847) -> Result<()> {
2848 let guest_connection_ids = session
2849 .db()
2850 .await
2851 .update_worktree(&request, session.connection_id)
2852 .await?;
2853
2854 broadcast(
2855 Some(session.connection_id),
2856 guest_connection_ids.iter().copied(),
2857 |connection_id| {
2858 session
2859 .peer
2860 .forward_send(session.connection_id, connection_id, request.clone())
2861 },
2862 );
2863 response.send(proto::Ack {})?;
2864 Ok(())
2865}
2866
2867/// Updates other participants with changes to the diagnostics
2868async fn update_diagnostic_summary(
2869 message: proto::UpdateDiagnosticSummary,
2870 session: Session,
2871) -> Result<()> {
2872 let guest_connection_ids = session
2873 .db()
2874 .await
2875 .update_diagnostic_summary(&message, session.connection_id)
2876 .await?;
2877
2878 broadcast(
2879 Some(session.connection_id),
2880 guest_connection_ids.iter().copied(),
2881 |connection_id| {
2882 session
2883 .peer
2884 .forward_send(session.connection_id, connection_id, message.clone())
2885 },
2886 );
2887
2888 Ok(())
2889}
2890
2891/// Updates other participants with changes to the worktree settings
2892async fn update_worktree_settings(
2893 message: proto::UpdateWorktreeSettings,
2894 session: Session,
2895) -> Result<()> {
2896 let guest_connection_ids = session
2897 .db()
2898 .await
2899 .update_worktree_settings(&message, session.connection_id)
2900 .await?;
2901
2902 broadcast(
2903 Some(session.connection_id),
2904 guest_connection_ids.iter().copied(),
2905 |connection_id| {
2906 session
2907 .peer
2908 .forward_send(session.connection_id, connection_id, message.clone())
2909 },
2910 );
2911
2912 Ok(())
2913}
2914
2915/// Notify other participants that a language server has started.
2916async fn start_language_server(
2917 request: proto::StartLanguageServer,
2918 session: Session,
2919) -> Result<()> {
2920 let guest_connection_ids = session
2921 .db()
2922 .await
2923 .start_language_server(&request, session.connection_id)
2924 .await?;
2925
2926 broadcast(
2927 Some(session.connection_id),
2928 guest_connection_ids.iter().copied(),
2929 |connection_id| {
2930 session
2931 .peer
2932 .forward_send(session.connection_id, connection_id, request.clone())
2933 },
2934 );
2935 Ok(())
2936}
2937
2938/// Notify other participants that a language server has changed.
2939async fn update_language_server(
2940 request: proto::UpdateLanguageServer,
2941 session: Session,
2942) -> Result<()> {
2943 let project_id = ProjectId::from_proto(request.project_id);
2944 let project_connection_ids = session
2945 .db()
2946 .await
2947 .project_connection_ids(project_id, session.connection_id, true)
2948 .await?;
2949 broadcast(
2950 Some(session.connection_id),
2951 project_connection_ids.iter().copied(),
2952 |connection_id| {
2953 session
2954 .peer
2955 .forward_send(session.connection_id, connection_id, request.clone())
2956 },
2957 );
2958 Ok(())
2959}
2960
2961/// forward a project request to the host. These requests should be read only
2962/// as guests are allowed to send them.
2963async fn forward_read_only_project_request<T>(
2964 request: T,
2965 response: Response<T>,
2966 session: UserSession,
2967) -> Result<()>
2968where
2969 T: EntityMessage + RequestMessage,
2970{
2971 let project_id = ProjectId::from_proto(request.remote_entity_id());
2972 let host_connection_id = session
2973 .db()
2974 .await
2975 .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2976 .await?;
2977 let payload = session
2978 .peer
2979 .forward_request(session.connection_id, host_connection_id, request)
2980 .await?;
2981 response.send(payload)?;
2982 Ok(())
2983}
2984
2985/// forward a project request to the dev server. Only allowed
2986/// if it's your dev server.
2987async fn forward_project_request_for_owner<T>(
2988 request: T,
2989 response: Response<T>,
2990 session: UserSession,
2991) -> Result<()>
2992where
2993 T: EntityMessage + RequestMessage,
2994{
2995 let project_id = ProjectId::from_proto(request.remote_entity_id());
2996
2997 let host_connection_id = session
2998 .db()
2999 .await
3000 .host_for_owner_project_request(project_id, session.connection_id, session.user_id())
3001 .await?;
3002 let payload = session
3003 .peer
3004 .forward_request(session.connection_id, host_connection_id, request)
3005 .await?;
3006 response.send(payload)?;
3007 Ok(())
3008}
3009
3010/// forward a project request to the host. These requests are disallowed
3011/// for guests.
3012async fn forward_mutating_project_request<T>(
3013 request: T,
3014 response: Response<T>,
3015 session: UserSession,
3016) -> Result<()>
3017where
3018 T: EntityMessage + RequestMessage,
3019{
3020 let project_id = ProjectId::from_proto(request.remote_entity_id());
3021
3022 let host_connection_id = session
3023 .db()
3024 .await
3025 .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
3026 .await?;
3027 let payload = session
3028 .peer
3029 .forward_request(session.connection_id, host_connection_id, request)
3030 .await?;
3031 response.send(payload)?;
3032 Ok(())
3033}
3034
3035/// forward a project request to the host. These requests are disallowed
3036/// for guests.
3037async fn forward_versioned_mutating_project_request<T>(
3038 request: T,
3039 response: Response<T>,
3040 session: UserSession,
3041) -> Result<()>
3042where
3043 T: EntityMessage + RequestMessage + VersionedMessage,
3044{
3045 let project_id = ProjectId::from_proto(request.remote_entity_id());
3046
3047 let host_connection_id = session
3048 .db()
3049 .await
3050 .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
3051 .await?;
3052 if let Some(host_version) = session
3053 .connection_pool()
3054 .await
3055 .connection(host_connection_id)
3056 .map(|c| c.zed_version)
3057 {
3058 if let Some(min_required_version) = request.required_host_version() {
3059 if min_required_version > host_version {
3060 return Err(anyhow!(ErrorCode::RemoteUpgradeRequired
3061 .with_tag("required", &min_required_version.to_string())))?;
3062 }
3063 }
3064 }
3065
3066 let payload = session
3067 .peer
3068 .forward_request(session.connection_id, host_connection_id, request)
3069 .await?;
3070 response.send(payload)?;
3071 Ok(())
3072}
3073
3074/// Notify other participants that a new buffer has been created
3075async fn create_buffer_for_peer(
3076 request: proto::CreateBufferForPeer,
3077 session: Session,
3078) -> Result<()> {
3079 session
3080 .db()
3081 .await
3082 .check_user_is_project_host(
3083 ProjectId::from_proto(request.project_id),
3084 session.connection_id,
3085 )
3086 .await?;
3087 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
3088 session
3089 .peer
3090 .forward_send(session.connection_id, peer_id.into(), request)?;
3091 Ok(())
3092}
3093
3094/// Notify other participants that a buffer has been updated. This is
3095/// allowed for guests as long as the update is limited to selections.
3096async fn update_buffer(
3097 request: proto::UpdateBuffer,
3098 response: Response<proto::UpdateBuffer>,
3099 session: Session,
3100) -> Result<()> {
3101 let project_id = ProjectId::from_proto(request.project_id);
3102 let mut capability = Capability::ReadOnly;
3103
3104 for op in request.operations.iter() {
3105 match op.variant {
3106 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
3107 Some(_) => capability = Capability::ReadWrite,
3108 }
3109 }
3110
3111 let host = {
3112 let guard = session
3113 .db()
3114 .await
3115 .connections_for_buffer_update(
3116 project_id,
3117 session.principal_id(),
3118 session.connection_id,
3119 capability,
3120 )
3121 .await?;
3122
3123 let (host, guests) = &*guard;
3124
3125 broadcast(
3126 Some(session.connection_id),
3127 guests.clone(),
3128 |connection_id| {
3129 session
3130 .peer
3131 .forward_send(session.connection_id, connection_id, request.clone())
3132 },
3133 );
3134
3135 *host
3136 };
3137
3138 if host != session.connection_id {
3139 session
3140 .peer
3141 .forward_request(session.connection_id, host, request.clone())
3142 .await?;
3143 }
3144
3145 response.send(proto::Ack {})?;
3146 Ok(())
3147}
3148
3149async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
3150 let project_id = ProjectId::from_proto(message.project_id);
3151
3152 let operation = message.operation.as_ref().context("invalid operation")?;
3153 let capability = match operation.variant.as_ref() {
3154 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
3155 if let Some(buffer_op) = buffer_op.operation.as_ref() {
3156 match buffer_op.variant {
3157 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
3158 Capability::ReadOnly
3159 }
3160 _ => Capability::ReadWrite,
3161 }
3162 } else {
3163 Capability::ReadWrite
3164 }
3165 }
3166 Some(_) => Capability::ReadWrite,
3167 None => Capability::ReadOnly,
3168 };
3169
3170 let guard = session
3171 .db()
3172 .await
3173 .connections_for_buffer_update(
3174 project_id,
3175 session.principal_id(),
3176 session.connection_id,
3177 capability,
3178 )
3179 .await?;
3180
3181 let (host, guests) = &*guard;
3182
3183 broadcast(
3184 Some(session.connection_id),
3185 guests.iter().chain([host]).copied(),
3186 |connection_id| {
3187 session
3188 .peer
3189 .forward_send(session.connection_id, connection_id, message.clone())
3190 },
3191 );
3192
3193 Ok(())
3194}
3195
3196/// Notify other participants that a project has been updated.
3197async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
3198 request: T,
3199 session: Session,
3200) -> Result<()> {
3201 let project_id = ProjectId::from_proto(request.remote_entity_id());
3202 let project_connection_ids = session
3203 .db()
3204 .await
3205 .project_connection_ids(project_id, session.connection_id, false)
3206 .await?;
3207
3208 broadcast(
3209 Some(session.connection_id),
3210 project_connection_ids.iter().copied(),
3211 |connection_id| {
3212 session
3213 .peer
3214 .forward_send(session.connection_id, connection_id, request.clone())
3215 },
3216 );
3217 Ok(())
3218}
3219
3220/// Start following another user in a call.
3221async fn follow(
3222 request: proto::Follow,
3223 response: Response<proto::Follow>,
3224 session: UserSession,
3225) -> Result<()> {
3226 let room_id = RoomId::from_proto(request.room_id);
3227 let project_id = request.project_id.map(ProjectId::from_proto);
3228 let leader_id = request
3229 .leader_id
3230 .ok_or_else(|| anyhow!("invalid leader id"))?
3231 .into();
3232 let follower_id = session.connection_id;
3233
3234 session
3235 .db()
3236 .await
3237 .check_room_participants(room_id, leader_id, session.connection_id)
3238 .await?;
3239
3240 let response_payload = session
3241 .peer
3242 .forward_request(session.connection_id, leader_id, request)
3243 .await?;
3244 response.send(response_payload)?;
3245
3246 if let Some(project_id) = project_id {
3247 let room = session
3248 .db()
3249 .await
3250 .follow(room_id, project_id, leader_id, follower_id)
3251 .await?;
3252 room_updated(&room, &session.peer);
3253 }
3254
3255 Ok(())
3256}
3257
3258/// Stop following another user in a call.
3259async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
3260 let room_id = RoomId::from_proto(request.room_id);
3261 let project_id = request.project_id.map(ProjectId::from_proto);
3262 let leader_id = request
3263 .leader_id
3264 .ok_or_else(|| anyhow!("invalid leader id"))?
3265 .into();
3266 let follower_id = session.connection_id;
3267
3268 session
3269 .db()
3270 .await
3271 .check_room_participants(room_id, leader_id, session.connection_id)
3272 .await?;
3273
3274 session
3275 .peer
3276 .forward_send(session.connection_id, leader_id, request)?;
3277
3278 if let Some(project_id) = project_id {
3279 let room = session
3280 .db()
3281 .await
3282 .unfollow(room_id, project_id, leader_id, follower_id)
3283 .await?;
3284 room_updated(&room, &session.peer);
3285 }
3286
3287 Ok(())
3288}
3289
3290/// Notify everyone following you of your current location.
3291async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
3292 let room_id = RoomId::from_proto(request.room_id);
3293 let database = session.db.lock().await;
3294
3295 let connection_ids = if let Some(project_id) = request.project_id {
3296 let project_id = ProjectId::from_proto(project_id);
3297 database
3298 .project_connection_ids(project_id, session.connection_id, true)
3299 .await?
3300 } else {
3301 database
3302 .room_connection_ids(room_id, session.connection_id)
3303 .await?
3304 };
3305
3306 // For now, don't send view update messages back to that view's current leader.
3307 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
3308 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
3309 _ => None,
3310 });
3311
3312 for connection_id in connection_ids.iter().cloned() {
3313 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
3314 session
3315 .peer
3316 .forward_send(session.connection_id, connection_id, request.clone())?;
3317 }
3318 }
3319 Ok(())
3320}
3321
3322/// Get public data about users.
3323async fn get_users(
3324 request: proto::GetUsers,
3325 response: Response<proto::GetUsers>,
3326 session: Session,
3327) -> Result<()> {
3328 let user_ids = request
3329 .user_ids
3330 .into_iter()
3331 .map(UserId::from_proto)
3332 .collect();
3333 let users = session
3334 .db()
3335 .await
3336 .get_users_by_ids(user_ids)
3337 .await?
3338 .into_iter()
3339 .map(|user| proto::User {
3340 id: user.id.to_proto(),
3341 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3342 github_login: user.github_login,
3343 })
3344 .collect();
3345 response.send(proto::UsersResponse { users })?;
3346 Ok(())
3347}
3348
3349/// Search for users (to invite) buy Github login
3350async fn fuzzy_search_users(
3351 request: proto::FuzzySearchUsers,
3352 response: Response<proto::FuzzySearchUsers>,
3353 session: UserSession,
3354) -> Result<()> {
3355 let query = request.query;
3356 let users = match query.len() {
3357 0 => vec![],
3358 1 | 2 => session
3359 .db()
3360 .await
3361 .get_user_by_github_login(&query)
3362 .await?
3363 .into_iter()
3364 .collect(),
3365 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
3366 };
3367 let users = users
3368 .into_iter()
3369 .filter(|user| user.id != session.user_id())
3370 .map(|user| proto::User {
3371 id: user.id.to_proto(),
3372 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3373 github_login: user.github_login,
3374 })
3375 .collect();
3376 response.send(proto::UsersResponse { users })?;
3377 Ok(())
3378}
3379
3380/// Send a contact request to another user.
3381async fn request_contact(
3382 request: proto::RequestContact,
3383 response: Response<proto::RequestContact>,
3384 session: UserSession,
3385) -> Result<()> {
3386 let requester_id = session.user_id();
3387 let responder_id = UserId::from_proto(request.responder_id);
3388 if requester_id == responder_id {
3389 return Err(anyhow!("cannot add yourself as a contact"))?;
3390 }
3391
3392 let notifications = session
3393 .db()
3394 .await
3395 .send_contact_request(requester_id, responder_id)
3396 .await?;
3397
3398 // Update outgoing contact requests of requester
3399 let mut update = proto::UpdateContacts::default();
3400 update.outgoing_requests.push(responder_id.to_proto());
3401 for connection_id in session
3402 .connection_pool()
3403 .await
3404 .user_connection_ids(requester_id)
3405 {
3406 session.peer.send(connection_id, update.clone())?;
3407 }
3408
3409 // Update incoming contact requests of responder
3410 let mut update = proto::UpdateContacts::default();
3411 update
3412 .incoming_requests
3413 .push(proto::IncomingContactRequest {
3414 requester_id: requester_id.to_proto(),
3415 });
3416 let connection_pool = session.connection_pool().await;
3417 for connection_id in connection_pool.user_connection_ids(responder_id) {
3418 session.peer.send(connection_id, update.clone())?;
3419 }
3420
3421 send_notifications(&connection_pool, &session.peer, notifications);
3422
3423 response.send(proto::Ack {})?;
3424 Ok(())
3425}
3426
3427/// Accept or decline a contact request
3428async fn respond_to_contact_request(
3429 request: proto::RespondToContactRequest,
3430 response: Response<proto::RespondToContactRequest>,
3431 session: UserSession,
3432) -> Result<()> {
3433 let responder_id = session.user_id();
3434 let requester_id = UserId::from_proto(request.requester_id);
3435 let db = session.db().await;
3436 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
3437 db.dismiss_contact_notification(responder_id, requester_id)
3438 .await?;
3439 } else {
3440 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
3441
3442 let notifications = db
3443 .respond_to_contact_request(responder_id, requester_id, accept)
3444 .await?;
3445 let requester_busy = db.is_user_busy(requester_id).await?;
3446 let responder_busy = db.is_user_busy(responder_id).await?;
3447
3448 let pool = session.connection_pool().await;
3449 // Update responder with new contact
3450 let mut update = proto::UpdateContacts::default();
3451 if accept {
3452 update
3453 .contacts
3454 .push(contact_for_user(requester_id, requester_busy, &pool));
3455 }
3456 update
3457 .remove_incoming_requests
3458 .push(requester_id.to_proto());
3459 for connection_id in pool.user_connection_ids(responder_id) {
3460 session.peer.send(connection_id, update.clone())?;
3461 }
3462
3463 // Update requester with new contact
3464 let mut update = proto::UpdateContacts::default();
3465 if accept {
3466 update
3467 .contacts
3468 .push(contact_for_user(responder_id, responder_busy, &pool));
3469 }
3470 update
3471 .remove_outgoing_requests
3472 .push(responder_id.to_proto());
3473
3474 for connection_id in pool.user_connection_ids(requester_id) {
3475 session.peer.send(connection_id, update.clone())?;
3476 }
3477
3478 send_notifications(&pool, &session.peer, notifications);
3479 }
3480
3481 response.send(proto::Ack {})?;
3482 Ok(())
3483}
3484
3485/// Remove a contact.
3486async fn remove_contact(
3487 request: proto::RemoveContact,
3488 response: Response<proto::RemoveContact>,
3489 session: UserSession,
3490) -> Result<()> {
3491 let requester_id = session.user_id();
3492 let responder_id = UserId::from_proto(request.user_id);
3493 let db = session.db().await;
3494 let (contact_accepted, deleted_notification_id) =
3495 db.remove_contact(requester_id, responder_id).await?;
3496
3497 let pool = session.connection_pool().await;
3498 // Update outgoing contact requests of requester
3499 let mut update = proto::UpdateContacts::default();
3500 if contact_accepted {
3501 update.remove_contacts.push(responder_id.to_proto());
3502 } else {
3503 update
3504 .remove_outgoing_requests
3505 .push(responder_id.to_proto());
3506 }
3507 for connection_id in pool.user_connection_ids(requester_id) {
3508 session.peer.send(connection_id, update.clone())?;
3509 }
3510
3511 // Update incoming contact requests of responder
3512 let mut update = proto::UpdateContacts::default();
3513 if contact_accepted {
3514 update.remove_contacts.push(requester_id.to_proto());
3515 } else {
3516 update
3517 .remove_incoming_requests
3518 .push(requester_id.to_proto());
3519 }
3520 for connection_id in pool.user_connection_ids(responder_id) {
3521 session.peer.send(connection_id, update.clone())?;
3522 if let Some(notification_id) = deleted_notification_id {
3523 session.peer.send(
3524 connection_id,
3525 proto::DeleteNotification {
3526 notification_id: notification_id.to_proto(),
3527 },
3528 )?;
3529 }
3530 }
3531
3532 response.send(proto::Ack {})?;
3533 Ok(())
3534}
3535
3536fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
3537 version.0.minor() < 139
3538}
3539
3540async fn update_user_plan(user_id: UserId, session: &Session) -> Result<()> {
3541 let db = session.db().await;
3542 let active_subscriptions = db.get_active_billing_subscriptions(user_id).await?;
3543
3544 let plan = if session.is_staff() || !active_subscriptions.is_empty() {
3545 proto::Plan::ZedPro
3546 } else {
3547 proto::Plan::Free
3548 };
3549
3550 session
3551 .peer
3552 .send(
3553 session.connection_id,
3554 proto::UpdateUserPlan { plan: plan.into() },
3555 )
3556 .trace_err();
3557
3558 Ok(())
3559}
3560
3561async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
3562 subscribe_user_to_channels(
3563 session.user_id().ok_or_else(|| anyhow!("must be a user"))?,
3564 &session,
3565 )
3566 .await?;
3567 Ok(())
3568}
3569
3570async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
3571 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3572 let mut pool = session.connection_pool().await;
3573 for membership in &channels_for_user.channel_memberships {
3574 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3575 }
3576 session.peer.send(
3577 session.connection_id,
3578 build_update_user_channels(&channels_for_user),
3579 )?;
3580 session.peer.send(
3581 session.connection_id,
3582 build_channels_update(channels_for_user),
3583 )?;
3584 Ok(())
3585}
3586
3587/// Creates a new channel.
3588async fn create_channel(
3589 request: proto::CreateChannel,
3590 response: Response<proto::CreateChannel>,
3591 session: UserSession,
3592) -> Result<()> {
3593 let db = session.db().await;
3594
3595 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
3596 let (channel, membership) = db
3597 .create_channel(&request.name, parent_id, session.user_id())
3598 .await?;
3599
3600 let root_id = channel.root_id();
3601 let channel = Channel::from_model(channel);
3602
3603 response.send(proto::CreateChannelResponse {
3604 channel: Some(channel.to_proto()),
3605 parent_id: request.parent_id,
3606 })?;
3607
3608 let mut connection_pool = session.connection_pool().await;
3609 if let Some(membership) = membership {
3610 connection_pool.subscribe_to_channel(
3611 membership.user_id,
3612 membership.channel_id,
3613 membership.role,
3614 );
3615 let update = proto::UpdateUserChannels {
3616 channel_memberships: vec![proto::ChannelMembership {
3617 channel_id: membership.channel_id.to_proto(),
3618 role: membership.role.into(),
3619 }],
3620 ..Default::default()
3621 };
3622 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3623 session.peer.send(connection_id, update.clone())?;
3624 }
3625 }
3626
3627 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3628 if !role.can_see_channel(channel.visibility) {
3629 continue;
3630 }
3631
3632 let update = proto::UpdateChannels {
3633 channels: vec![channel.to_proto()],
3634 ..Default::default()
3635 };
3636 session.peer.send(connection_id, update.clone())?;
3637 }
3638
3639 Ok(())
3640}
3641
3642/// Delete a channel
3643async fn delete_channel(
3644 request: proto::DeleteChannel,
3645 response: Response<proto::DeleteChannel>,
3646 session: UserSession,
3647) -> Result<()> {
3648 let db = session.db().await;
3649
3650 let channel_id = request.channel_id;
3651 let (root_channel, removed_channels) = db
3652 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3653 .await?;
3654 response.send(proto::Ack {})?;
3655
3656 // Notify members of removed channels
3657 let mut update = proto::UpdateChannels::default();
3658 update
3659 .delete_channels
3660 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3661
3662 let connection_pool = session.connection_pool().await;
3663 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3664 session.peer.send(connection_id, update.clone())?;
3665 }
3666
3667 Ok(())
3668}
3669
3670/// Invite someone to join a channel.
3671async fn invite_channel_member(
3672 request: proto::InviteChannelMember,
3673 response: Response<proto::InviteChannelMember>,
3674 session: UserSession,
3675) -> Result<()> {
3676 let db = session.db().await;
3677 let channel_id = ChannelId::from_proto(request.channel_id);
3678 let invitee_id = UserId::from_proto(request.user_id);
3679 let InviteMemberResult {
3680 channel,
3681 notifications,
3682 } = db
3683 .invite_channel_member(
3684 channel_id,
3685 invitee_id,
3686 session.user_id(),
3687 request.role().into(),
3688 )
3689 .await?;
3690
3691 let update = proto::UpdateChannels {
3692 channel_invitations: vec![channel.to_proto()],
3693 ..Default::default()
3694 };
3695
3696 let connection_pool = session.connection_pool().await;
3697 for connection_id in connection_pool.user_connection_ids(invitee_id) {
3698 session.peer.send(connection_id, update.clone())?;
3699 }
3700
3701 send_notifications(&connection_pool, &session.peer, notifications);
3702
3703 response.send(proto::Ack {})?;
3704 Ok(())
3705}
3706
3707/// remove someone from a channel
3708async fn remove_channel_member(
3709 request: proto::RemoveChannelMember,
3710 response: Response<proto::RemoveChannelMember>,
3711 session: UserSession,
3712) -> Result<()> {
3713 let db = session.db().await;
3714 let channel_id = ChannelId::from_proto(request.channel_id);
3715 let member_id = UserId::from_proto(request.user_id);
3716
3717 let RemoveChannelMemberResult {
3718 membership_update,
3719 notification_id,
3720 } = db
3721 .remove_channel_member(channel_id, member_id, session.user_id())
3722 .await?;
3723
3724 let mut connection_pool = session.connection_pool().await;
3725 notify_membership_updated(
3726 &mut connection_pool,
3727 membership_update,
3728 member_id,
3729 &session.peer,
3730 );
3731 for connection_id in connection_pool.user_connection_ids(member_id) {
3732 if let Some(notification_id) = notification_id {
3733 session
3734 .peer
3735 .send(
3736 connection_id,
3737 proto::DeleteNotification {
3738 notification_id: notification_id.to_proto(),
3739 },
3740 )
3741 .trace_err();
3742 }
3743 }
3744
3745 response.send(proto::Ack {})?;
3746 Ok(())
3747}
3748
3749/// Toggle the channel between public and private.
3750/// Care is taken to maintain the invariant that public channels only descend from public channels,
3751/// (though members-only channels can appear at any point in the hierarchy).
3752async fn set_channel_visibility(
3753 request: proto::SetChannelVisibility,
3754 response: Response<proto::SetChannelVisibility>,
3755 session: UserSession,
3756) -> Result<()> {
3757 let db = session.db().await;
3758 let channel_id = ChannelId::from_proto(request.channel_id);
3759 let visibility = request.visibility().into();
3760
3761 let channel_model = db
3762 .set_channel_visibility(channel_id, visibility, session.user_id())
3763 .await?;
3764 let root_id = channel_model.root_id();
3765 let channel = Channel::from_model(channel_model);
3766
3767 let mut connection_pool = session.connection_pool().await;
3768 for (user_id, role) in connection_pool
3769 .channel_user_ids(root_id)
3770 .collect::<Vec<_>>()
3771 .into_iter()
3772 {
3773 let update = if role.can_see_channel(channel.visibility) {
3774 connection_pool.subscribe_to_channel(user_id, channel_id, role);
3775 proto::UpdateChannels {
3776 channels: vec![channel.to_proto()],
3777 ..Default::default()
3778 }
3779 } else {
3780 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3781 proto::UpdateChannels {
3782 delete_channels: vec![channel.id.to_proto()],
3783 ..Default::default()
3784 }
3785 };
3786
3787 for connection_id in connection_pool.user_connection_ids(user_id) {
3788 session.peer.send(connection_id, update.clone())?;
3789 }
3790 }
3791
3792 response.send(proto::Ack {})?;
3793 Ok(())
3794}
3795
3796/// Alter the role for a user in the channel.
3797async fn set_channel_member_role(
3798 request: proto::SetChannelMemberRole,
3799 response: Response<proto::SetChannelMemberRole>,
3800 session: UserSession,
3801) -> Result<()> {
3802 let db = session.db().await;
3803 let channel_id = ChannelId::from_proto(request.channel_id);
3804 let member_id = UserId::from_proto(request.user_id);
3805 let result = db
3806 .set_channel_member_role(
3807 channel_id,
3808 session.user_id(),
3809 member_id,
3810 request.role().into(),
3811 )
3812 .await?;
3813
3814 match result {
3815 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3816 let mut connection_pool = session.connection_pool().await;
3817 notify_membership_updated(
3818 &mut connection_pool,
3819 membership_update,
3820 member_id,
3821 &session.peer,
3822 )
3823 }
3824 db::SetMemberRoleResult::InviteUpdated(channel) => {
3825 let update = proto::UpdateChannels {
3826 channel_invitations: vec![channel.to_proto()],
3827 ..Default::default()
3828 };
3829
3830 for connection_id in session
3831 .connection_pool()
3832 .await
3833 .user_connection_ids(member_id)
3834 {
3835 session.peer.send(connection_id, update.clone())?;
3836 }
3837 }
3838 }
3839
3840 response.send(proto::Ack {})?;
3841 Ok(())
3842}
3843
3844/// Change the name of a channel
3845async fn rename_channel(
3846 request: proto::RenameChannel,
3847 response: Response<proto::RenameChannel>,
3848 session: UserSession,
3849) -> Result<()> {
3850 let db = session.db().await;
3851 let channel_id = ChannelId::from_proto(request.channel_id);
3852 let channel_model = db
3853 .rename_channel(channel_id, session.user_id(), &request.name)
3854 .await?;
3855 let root_id = channel_model.root_id();
3856 let channel = Channel::from_model(channel_model);
3857
3858 response.send(proto::RenameChannelResponse {
3859 channel: Some(channel.to_proto()),
3860 })?;
3861
3862 let connection_pool = session.connection_pool().await;
3863 let update = proto::UpdateChannels {
3864 channels: vec![channel.to_proto()],
3865 ..Default::default()
3866 };
3867 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3868 if role.can_see_channel(channel.visibility) {
3869 session.peer.send(connection_id, update.clone())?;
3870 }
3871 }
3872
3873 Ok(())
3874}
3875
3876/// Move a channel to a new parent.
3877async fn move_channel(
3878 request: proto::MoveChannel,
3879 response: Response<proto::MoveChannel>,
3880 session: UserSession,
3881) -> Result<()> {
3882 let channel_id = ChannelId::from_proto(request.channel_id);
3883 let to = ChannelId::from_proto(request.to);
3884
3885 let (root_id, channels) = session
3886 .db()
3887 .await
3888 .move_channel(channel_id, to, session.user_id())
3889 .await?;
3890
3891 let connection_pool = session.connection_pool().await;
3892 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3893 let channels = channels
3894 .iter()
3895 .filter_map(|channel| {
3896 if role.can_see_channel(channel.visibility) {
3897 Some(channel.to_proto())
3898 } else {
3899 None
3900 }
3901 })
3902 .collect::<Vec<_>>();
3903 if channels.is_empty() {
3904 continue;
3905 }
3906
3907 let update = proto::UpdateChannels {
3908 channels,
3909 ..Default::default()
3910 };
3911
3912 session.peer.send(connection_id, update.clone())?;
3913 }
3914
3915 response.send(Ack {})?;
3916 Ok(())
3917}
3918
3919/// Get the list of channel members
3920async fn get_channel_members(
3921 request: proto::GetChannelMembers,
3922 response: Response<proto::GetChannelMembers>,
3923 session: UserSession,
3924) -> Result<()> {
3925 let db = session.db().await;
3926 let channel_id = ChannelId::from_proto(request.channel_id);
3927 let limit = if request.limit == 0 {
3928 u16::MAX as u64
3929 } else {
3930 request.limit
3931 };
3932 let (members, users) = db
3933 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3934 .await?;
3935 response.send(proto::GetChannelMembersResponse { members, users })?;
3936 Ok(())
3937}
3938
3939/// Accept or decline a channel invitation.
3940async fn respond_to_channel_invite(
3941 request: proto::RespondToChannelInvite,
3942 response: Response<proto::RespondToChannelInvite>,
3943 session: UserSession,
3944) -> Result<()> {
3945 let db = session.db().await;
3946 let channel_id = ChannelId::from_proto(request.channel_id);
3947 let RespondToChannelInvite {
3948 membership_update,
3949 notifications,
3950 } = db
3951 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3952 .await?;
3953
3954 let mut connection_pool = session.connection_pool().await;
3955 if let Some(membership_update) = membership_update {
3956 notify_membership_updated(
3957 &mut connection_pool,
3958 membership_update,
3959 session.user_id(),
3960 &session.peer,
3961 );
3962 } else {
3963 let update = proto::UpdateChannels {
3964 remove_channel_invitations: vec![channel_id.to_proto()],
3965 ..Default::default()
3966 };
3967
3968 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3969 session.peer.send(connection_id, update.clone())?;
3970 }
3971 };
3972
3973 send_notifications(&connection_pool, &session.peer, notifications);
3974
3975 response.send(proto::Ack {})?;
3976
3977 Ok(())
3978}
3979
3980/// Join the channels' room
3981async fn join_channel(
3982 request: proto::JoinChannel,
3983 response: Response<proto::JoinChannel>,
3984 session: UserSession,
3985) -> Result<()> {
3986 let channel_id = ChannelId::from_proto(request.channel_id);
3987 join_channel_internal(channel_id, Box::new(response), session).await
3988}
3989
/// A response handle that can carry a `proto::JoinRoomResponse`.
///
/// This lets `join_channel_internal` reply to either a `JoinChannel` or a `JoinRoom` request
/// without caring which one it is serving.
trait JoinChannelInternalResponse {
    /// Send `result` as the reply to the originating request, consuming the handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
/// Replies to a `JoinChannel` request by delegating to `Response::send`.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
/// Replies to a `JoinRoom` request by delegating to `Response::send`.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
4003
/// Join the room associated with `channel_id` on behalf of the session's user.
///
/// Called by `join_channel`. The generic `response` lets other handlers holding a
/// `JoinChannelInternalResponse` reuse it (presumably a `JoinRoom` handler; TODO confirm).
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, leave that room first.
/// 2. Join the channel's room in the database. This yields the joined room, an optional
///    membership update, and the caller's `ChannelRole`.
/// 3. If a LiveKit client is configured, mint a token. Guests get a `guest_token` with
///    `can_publish = false`; everyone else gets a `room_token` with `can_publish = true`.
///    A token error is logged (`trace_err`) and yields no LiveKit info rather than a failure.
/// 4. Reply to the caller, apply any membership update to the connection pool, and broadcast
///    the room update.
/// 5. Once the database and pool guards from the inner block are released, broadcast the
///    channel update and refresh the user's contacts.
///
/// # Errors
/// Fails if a database call fails, if leaving the stale room fails, if sending the response
/// fails, or if the joined room carries no channel ("channel not returned").
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: UserSession,
) -> Result<()> {
    // The database guard (and later the pool guard) live only inside this block.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the guard before leaving the stale room, then re-acquire it.
            // NOTE(review): presumably `leave_room_for_session` takes the database itself,
            // so holding the guard here would deadlock; confirm.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // Guests may listen but not publish. A token-minting error is logged and turns into
        // `None` (no LiveKit info) instead of failing the join.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // The pool guard is re-acquired here because the one above was dropped with the block.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
4094
4095/// Start editing the channel notes
4096async fn join_channel_buffer(
4097 request: proto::JoinChannelBuffer,
4098 response: Response<proto::JoinChannelBuffer>,
4099 session: UserSession,
4100) -> Result<()> {
4101 let db = session.db().await;
4102 let channel_id = ChannelId::from_proto(request.channel_id);
4103
4104 let open_response = db
4105 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
4106 .await?;
4107
4108 let collaborators = open_response.collaborators.clone();
4109 response.send(open_response)?;
4110
4111 let update = UpdateChannelBufferCollaborators {
4112 channel_id: channel_id.to_proto(),
4113 collaborators: collaborators.clone(),
4114 };
4115 channel_buffer_updated(
4116 session.connection_id,
4117 collaborators
4118 .iter()
4119 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
4120 &update,
4121 &session.peer,
4122 );
4123
4124 Ok(())
4125}
4126
4127/// Edit the channel notes
4128async fn update_channel_buffer(
4129 request: proto::UpdateChannelBuffer,
4130 session: UserSession,
4131) -> Result<()> {
4132 let db = session.db().await;
4133 let channel_id = ChannelId::from_proto(request.channel_id);
4134
4135 let (collaborators, epoch, version) = db
4136 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
4137 .await?;
4138
4139 channel_buffer_updated(
4140 session.connection_id,
4141 collaborators.clone(),
4142 &proto::UpdateChannelBuffer {
4143 channel_id: channel_id.to_proto(),
4144 operations: request.operations,
4145 },
4146 &session.peer,
4147 );
4148
4149 let pool = &*session.connection_pool().await;
4150
4151 let non_collaborators =
4152 pool.channel_connection_ids(channel_id)
4153 .filter_map(|(connection_id, _)| {
4154 if collaborators.contains(&connection_id) {
4155 None
4156 } else {
4157 Some(connection_id)
4158 }
4159 });
4160
4161 broadcast(None, non_collaborators, |peer_id| {
4162 session.peer.send(
4163 peer_id,
4164 proto::UpdateChannels {
4165 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
4166 channel_id: channel_id.to_proto(),
4167 epoch: epoch as u64,
4168 version: version.clone(),
4169 }],
4170 ..Default::default()
4171 },
4172 )
4173 });
4174
4175 Ok(())
4176}
4177
4178/// Rejoin the channel notes after a connection blip
4179async fn rejoin_channel_buffers(
4180 request: proto::RejoinChannelBuffers,
4181 response: Response<proto::RejoinChannelBuffers>,
4182 session: UserSession,
4183) -> Result<()> {
4184 let db = session.db().await;
4185 let buffers = db
4186 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
4187 .await?;
4188
4189 for rejoined_buffer in &buffers {
4190 let collaborators_to_notify = rejoined_buffer
4191 .buffer
4192 .collaborators
4193 .iter()
4194 .filter_map(|c| Some(c.peer_id?.into()));
4195 channel_buffer_updated(
4196 session.connection_id,
4197 collaborators_to_notify,
4198 &proto::UpdateChannelBufferCollaborators {
4199 channel_id: rejoined_buffer.buffer.channel_id,
4200 collaborators: rejoined_buffer.buffer.collaborators.clone(),
4201 },
4202 &session.peer,
4203 );
4204 }
4205
4206 response.send(proto::RejoinChannelBuffersResponse {
4207 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
4208 })?;
4209
4210 Ok(())
4211}
4212
4213/// Stop editing the channel notes
4214async fn leave_channel_buffer(
4215 request: proto::LeaveChannelBuffer,
4216 response: Response<proto::LeaveChannelBuffer>,
4217 session: UserSession,
4218) -> Result<()> {
4219 let db = session.db().await;
4220 let channel_id = ChannelId::from_proto(request.channel_id);
4221
4222 let left_buffer = db
4223 .leave_channel_buffer(channel_id, session.connection_id)
4224 .await?;
4225
4226 response.send(Ack {})?;
4227
4228 channel_buffer_updated(
4229 session.connection_id,
4230 left_buffer.connections,
4231 &proto::UpdateChannelBufferCollaborators {
4232 channel_id: channel_id.to_proto(),
4233 collaborators: left_buffer.collaborators,
4234 },
4235 &session.peer,
4236 );
4237
4238 Ok(())
4239}
4240
4241fn channel_buffer_updated<T: EnvelopedMessage>(
4242 sender_id: ConnectionId,
4243 collaborators: impl IntoIterator<Item = ConnectionId>,
4244 message: &T,
4245 peer: &Peer,
4246) {
4247 broadcast(Some(sender_id), collaborators, |peer_id| {
4248 peer.send(peer_id, message.clone())
4249 });
4250}
4251
4252fn send_notifications(
4253 connection_pool: &ConnectionPool,
4254 peer: &Peer,
4255 notifications: db::NotificationBatch,
4256) {
4257 for (user_id, notification) in notifications {
4258 for connection_id in connection_pool.user_connection_ids(user_id) {
4259 if let Err(error) = peer.send(
4260 connection_id,
4261 proto::AddNotification {
4262 notification: Some(notification.clone()),
4263 },
4264 ) {
4265 tracing::error!(
4266 "failed to send notification to {:?} {}",
4267 connection_id,
4268 error
4269 );
4270 }
4271 }
4272 }
4273}
4274
4275/// Send a message to the channel
4276async fn send_channel_message(
4277 request: proto::SendChannelMessage,
4278 response: Response<proto::SendChannelMessage>,
4279 session: UserSession,
4280) -> Result<()> {
4281 // Validate the message body.
4282 let body = request.body.trim().to_string();
4283 if body.len() > MAX_MESSAGE_LEN {
4284 return Err(anyhow!("message is too long"))?;
4285 }
4286 if body.is_empty() {
4287 return Err(anyhow!("message can't be blank"))?;
4288 }
4289
4290 // TODO: adjust mentions if body is trimmed
4291
4292 let timestamp = OffsetDateTime::now_utc();
4293 let nonce = request
4294 .nonce
4295 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4296
4297 let channel_id = ChannelId::from_proto(request.channel_id);
4298 let CreatedChannelMessage {
4299 message_id,
4300 participant_connection_ids,
4301 notifications,
4302 } = session
4303 .db()
4304 .await
4305 .create_channel_message(
4306 channel_id,
4307 session.user_id(),
4308 &body,
4309 &request.mentions,
4310 timestamp,
4311 nonce.clone().into(),
4312 match request.reply_to_message_id {
4313 Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
4314 None => None,
4315 },
4316 )
4317 .await?;
4318
4319 let message = proto::ChannelMessage {
4320 sender_id: session.user_id().to_proto(),
4321 id: message_id.to_proto(),
4322 body,
4323 mentions: request.mentions,
4324 timestamp: timestamp.unix_timestamp() as u64,
4325 nonce: Some(nonce),
4326 reply_to_message_id: request.reply_to_message_id,
4327 edited_at: None,
4328 };
4329 broadcast(
4330 Some(session.connection_id),
4331 participant_connection_ids.clone(),
4332 |connection| {
4333 session.peer.send(
4334 connection,
4335 proto::ChannelMessageSent {
4336 channel_id: channel_id.to_proto(),
4337 message: Some(message.clone()),
4338 },
4339 )
4340 },
4341 );
4342 response.send(proto::SendChannelMessageResponse {
4343 message: Some(message),
4344 })?;
4345
4346 let pool = &*session.connection_pool().await;
4347 let non_participants =
4348 pool.channel_connection_ids(channel_id)
4349 .filter_map(|(connection_id, _)| {
4350 if participant_connection_ids.contains(&connection_id) {
4351 None
4352 } else {
4353 Some(connection_id)
4354 }
4355 });
4356 broadcast(None, non_participants, |peer_id| {
4357 session.peer.send(
4358 peer_id,
4359 proto::UpdateChannels {
4360 latest_channel_message_ids: vec![proto::ChannelMessageId {
4361 channel_id: channel_id.to_proto(),
4362 message_id: message_id.to_proto(),
4363 }],
4364 ..Default::default()
4365 },
4366 )
4367 });
4368 send_notifications(pool, &session.peer, notifications);
4369
4370 Ok(())
4371}
4372
4373/// Delete a channel message
4374async fn remove_channel_message(
4375 request: proto::RemoveChannelMessage,
4376 response: Response<proto::RemoveChannelMessage>,
4377 session: UserSession,
4378) -> Result<()> {
4379 let channel_id = ChannelId::from_proto(request.channel_id);
4380 let message_id = MessageId::from_proto(request.message_id);
4381 let (connection_ids, existing_notification_ids) = session
4382 .db()
4383 .await
4384 .remove_channel_message(channel_id, message_id, session.user_id())
4385 .await?;
4386
4387 broadcast(
4388 Some(session.connection_id),
4389 connection_ids,
4390 move |connection| {
4391 session.peer.send(connection, request.clone())?;
4392
4393 for notification_id in &existing_notification_ids {
4394 session.peer.send(
4395 connection,
4396 proto::DeleteNotification {
4397 notification_id: (*notification_id).to_proto(),
4398 },
4399 )?;
4400 }
4401
4402 Ok(())
4403 },
4404 );
4405 response.send(proto::Ack {})?;
4406 Ok(())
4407}
4408
4409async fn update_channel_message(
4410 request: proto::UpdateChannelMessage,
4411 response: Response<proto::UpdateChannelMessage>,
4412 session: UserSession,
4413) -> Result<()> {
4414 let channel_id = ChannelId::from_proto(request.channel_id);
4415 let message_id = MessageId::from_proto(request.message_id);
4416 let updated_at = OffsetDateTime::now_utc();
4417 let UpdatedChannelMessage {
4418 message_id,
4419 participant_connection_ids,
4420 notifications,
4421 reply_to_message_id,
4422 timestamp,
4423 deleted_mention_notification_ids,
4424 updated_mention_notifications,
4425 } = session
4426 .db()
4427 .await
4428 .update_channel_message(
4429 channel_id,
4430 message_id,
4431 session.user_id(),
4432 request.body.as_str(),
4433 &request.mentions,
4434 updated_at,
4435 )
4436 .await?;
4437
4438 let nonce = request
4439 .nonce
4440 .clone()
4441 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4442
4443 let message = proto::ChannelMessage {
4444 sender_id: session.user_id().to_proto(),
4445 id: message_id.to_proto(),
4446 body: request.body.clone(),
4447 mentions: request.mentions.clone(),
4448 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
4449 nonce: Some(nonce),
4450 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
4451 edited_at: Some(updated_at.unix_timestamp() as u64),
4452 };
4453
4454 response.send(proto::Ack {})?;
4455
4456 let pool = &*session.connection_pool().await;
4457 broadcast(
4458 Some(session.connection_id),
4459 participant_connection_ids,
4460 |connection| {
4461 session.peer.send(
4462 connection,
4463 proto::ChannelMessageUpdate {
4464 channel_id: channel_id.to_proto(),
4465 message: Some(message.clone()),
4466 },
4467 )?;
4468
4469 for notification_id in &deleted_mention_notification_ids {
4470 session.peer.send(
4471 connection,
4472 proto::DeleteNotification {
4473 notification_id: (*notification_id).to_proto(),
4474 },
4475 )?;
4476 }
4477
4478 for notification in &updated_mention_notifications {
4479 session.peer.send(
4480 connection,
4481 proto::UpdateNotification {
4482 notification: Some(notification.clone()),
4483 },
4484 )?;
4485 }
4486
4487 Ok(())
4488 },
4489 );
4490
4491 send_notifications(pool, &session.peer, notifications);
4492
4493 Ok(())
4494}
4495
4496/// Mark a channel message as read
4497async fn acknowledge_channel_message(
4498 request: proto::AckChannelMessage,
4499 session: UserSession,
4500) -> Result<()> {
4501 let channel_id = ChannelId::from_proto(request.channel_id);
4502 let message_id = MessageId::from_proto(request.message_id);
4503 let notifications = session
4504 .db()
4505 .await
4506 .observe_channel_message(channel_id, session.user_id(), message_id)
4507 .await?;
4508 send_notifications(
4509 &*session.connection_pool().await,
4510 &session.peer,
4511 notifications,
4512 );
4513 Ok(())
4514}
4515
4516/// Mark a buffer version as synced
4517async fn acknowledge_buffer_version(
4518 request: proto::AckBufferOperation,
4519 session: UserSession,
4520) -> Result<()> {
4521 let buffer_id = BufferId::from_proto(request.buffer_id);
4522 session
4523 .db()
4524 .await
4525 .observe_buffer_version(
4526 buffer_id,
4527 session.user_id(),
4528 request.epoch as i32,
4529 &request.version,
4530 )
4531 .await?;
4532 Ok(())
4533}
4534
4535struct CompleteWithLanguageModelRateLimit;
4536
4537impl RateLimit for CompleteWithLanguageModelRateLimit {
4538 fn capacity() -> usize {
4539 std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4540 .ok()
4541 .and_then(|v| v.parse().ok())
4542 .unwrap_or(120) // Picked arbitrarily
4543 }
4544
4545 fn refill_duration() -> chrono::Duration {
4546 chrono::Duration::hours(1)
4547 }
4548
4549 fn db_name() -> &'static str {
4550 "complete-with-language-model"
4551 }
4552}
4553
4554async fn complete_with_language_model(
4555 request: proto::CompleteWithLanguageModel,
4556 response: Response<proto::CompleteWithLanguageModel>,
4557 session: Session,
4558 config: &Config,
4559) -> Result<()> {
4560 let Some(session) = session.for_user() else {
4561 return Err(anyhow!("user not found"))?;
4562 };
4563 authorize_access_to_language_models(&session).await?;
4564
4565 session
4566 .rate_limiter
4567 .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
4568 .await?;
4569
4570 let result = match proto::LanguageModelProvider::from_i32(request.provider) {
4571 Some(proto::LanguageModelProvider::Anthropic) => {
4572 let api_key = config
4573 .anthropic_api_key
4574 .as_ref()
4575 .context("no Anthropic AI API key configured on the server")?;
4576 anthropic::complete(
4577 session.http_client.as_ref(),
4578 anthropic::ANTHROPIC_API_URL,
4579 api_key,
4580 serde_json::from_str(&request.request)?,
4581 )
4582 .await?
4583 }
4584 _ => return Err(anyhow!("unsupported provider"))?,
4585 };
4586
4587 response.send(proto::CompleteWithLanguageModelResponse {
4588 completion: serde_json::to_string(&result)?,
4589 })?;
4590
4591 Ok(())
4592}
4593
/// Stream a language-model completion for the current user, forwarding each
/// provider event to the client as a JSON string.
///
/// Supports the Anthropic, OpenAI and Google AI providers, each using the API
/// key from `config`. A missing key for the selected provider, an unknown
/// provider value, a request body that fails to deserialize, or a provider
/// error all end the handler with an error. Errors inside the event loop
/// (`event?`) abort the stream after any events already sent.
///
/// The caller must be a user session with the `language-models` flag and must
/// pass `CompleteWithLanguageModelRateLimit`.
async fn stream_complete_with_language_model(
    request: proto::StreamCompleteWithLanguageModel,
    response: StreamingResponse<proto::StreamCompleteWithLanguageModel>,
    session: Session,
    config: &Config,
) -> Result<()> {
    let Some(session) = session.for_user() else {
        return Err(anyhow!("user not found"))?;
    };
    authorize_access_to_language_models(&session).await?;

    // Shares its rate limit with the non-streaming completion endpoint.
    session
        .rate_limiter
        .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
        .await?;

    // `request.request` is the provider-specific request encoded as JSON; each
    // branch deserializes it into that provider's request type.
    match proto::LanguageModelProvider::from_i32(request.provider) {
        Some(proto::LanguageModelProvider::Anthropic) => {
            let api_key = config
                .anthropic_api_key
                .as_ref()
                .context("no Anthropic AI API key configured on the server")?;
            let mut chunks = anthropic::stream_completion(
                session.http_client.as_ref(),
                anthropic::ANTHROPIC_API_URL,
                api_key,
                serde_json::from_str(&request.request)?,
                None,
            )
            .await?;
            // Relay every streamed chunk to the client as it arrives.
            while let Some(event) = chunks.next().await {
                let chunk = event?;
                response.send(proto::StreamCompleteWithLanguageModelResponse {
                    event: serde_json::to_string(&chunk)?,
                })?;
            }
        }
        Some(proto::LanguageModelProvider::OpenAi) => {
            let api_key = config
                .openai_api_key
                .as_ref()
                .context("no OpenAI API key configured on the server")?;
            let mut events = open_ai::stream_completion(
                session.http_client.as_ref(),
                open_ai::OPEN_AI_API_URL,
                api_key,
                serde_json::from_str(&request.request)?,
                None,
            )
            .await?;
            while let Some(event) = events.next().await {
                let event = event?;
                response.send(proto::StreamCompleteWithLanguageModelResponse {
                    event: serde_json::to_string(&event)?,
                })?;
            }
        }
        Some(proto::LanguageModelProvider::Google) => {
            let api_key = config
                .google_ai_api_key
                .as_ref()
                .context("no Google AI API key configured on the server")?;
            let mut events = google_ai::stream_generate_content(
                session.http_client.as_ref(),
                google_ai::API_URL,
                api_key,
                serde_json::from_str(&request.request)?,
            )
            .await?;
            while let Some(event) = events.next().await {
                let event = event?;
                response.send(proto::StreamCompleteWithLanguageModelResponse {
                    event: serde_json::to_string(&event)?,
                })?;
            }
        }
        // The integer did not map to any known provider variant.
        None => return Err(anyhow!("unknown provider"))?,
    }

    Ok(())
}
4675
4676async fn count_language_model_tokens(
4677 request: proto::CountLanguageModelTokens,
4678 response: Response<proto::CountLanguageModelTokens>,
4679 session: Session,
4680 config: &Config,
4681) -> Result<()> {
4682 let Some(session) = session.for_user() else {
4683 return Err(anyhow!("user not found"))?;
4684 };
4685 authorize_access_to_language_models(&session).await?;
4686
4687 session
4688 .rate_limiter
4689 .check::<CountLanguageModelTokensRateLimit>(session.user_id())
4690 .await?;
4691
4692 let result = match proto::LanguageModelProvider::from_i32(request.provider) {
4693 Some(proto::LanguageModelProvider::Google) => {
4694 let api_key = config
4695 .google_ai_api_key
4696 .as_ref()
4697 .context("no Google AI API key configured on the server")?;
4698 google_ai::count_tokens(
4699 session.http_client.as_ref(),
4700 google_ai::API_URL,
4701 api_key,
4702 serde_json::from_str(&request.request)?,
4703 )
4704 .await?
4705 }
4706 _ => return Err(anyhow!("unsupported provider"))?,
4707 };
4708
4709 response.send(proto::CountLanguageModelTokensResponse {
4710 token_count: result.total_tokens as u32,
4711 })?;
4712
4713 Ok(())
4714}
4715
4716struct CountLanguageModelTokensRateLimit;
4717
4718impl RateLimit for CountLanguageModelTokensRateLimit {
4719 fn capacity() -> usize {
4720 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
4721 .ok()
4722 .and_then(|v| v.parse().ok())
4723 .unwrap_or(600) // Picked arbitrarily
4724 }
4725
4726 fn refill_duration() -> chrono::Duration {
4727 chrono::Duration::hours(1)
4728 }
4729
4730 fn db_name() -> &'static str {
4731 "count-language-model-tokens"
4732 }
4733}
4734
4735struct ComputeEmbeddingsRateLimit;
4736
4737impl RateLimit for ComputeEmbeddingsRateLimit {
4738 fn capacity() -> usize {
4739 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4740 .ok()
4741 .and_then(|v| v.parse().ok())
4742 .unwrap_or(5000) // Picked arbitrarily
4743 }
4744
4745 fn refill_duration() -> chrono::Duration {
4746 chrono::Duration::hours(1)
4747 }
4748
4749 fn db_name() -> &'static str {
4750 "compute-embeddings"
4751 }
4752}
4753
4754async fn compute_embeddings(
4755 request: proto::ComputeEmbeddings,
4756 response: Response<proto::ComputeEmbeddings>,
4757 session: UserSession,
4758 api_key: Option<Arc<str>>,
4759) -> Result<()> {
4760 let api_key = api_key.context("no OpenAI API key configured on the server")?;
4761 authorize_access_to_language_models(&session).await?;
4762
4763 session
4764 .rate_limiter
4765 .check::<ComputeEmbeddingsRateLimit>(session.user_id())
4766 .await?;
4767
4768 let embeddings = match request.model.as_str() {
4769 "openai/text-embedding-3-small" => {
4770 open_ai::embed(
4771 session.http_client.as_ref(),
4772 OPEN_AI_API_URL,
4773 &api_key,
4774 OpenAiEmbeddingModel::TextEmbedding3Small,
4775 request.texts.iter().map(|text| text.as_str()),
4776 )
4777 .await?
4778 }
4779 provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
4780 };
4781
4782 let embeddings = request
4783 .texts
4784 .iter()
4785 .map(|text| {
4786 let mut hasher = sha2::Sha256::new();
4787 hasher.update(text.as_bytes());
4788 let result = hasher.finalize();
4789 result.to_vec()
4790 })
4791 .zip(
4792 embeddings
4793 .data
4794 .into_iter()
4795 .map(|embedding| embedding.embedding),
4796 )
4797 .collect::<HashMap<_, _>>();
4798
4799 let db = session.db().await;
4800 db.save_embeddings(&request.model, &embeddings)
4801 .await
4802 .context("failed to save embeddings")
4803 .trace_err();
4804
4805 response.send(proto::ComputeEmbeddingsResponse {
4806 embeddings: embeddings
4807 .into_iter()
4808 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4809 .collect(),
4810 })?;
4811 Ok(())
4812}
4813
4814async fn get_cached_embeddings(
4815 request: proto::GetCachedEmbeddings,
4816 response: Response<proto::GetCachedEmbeddings>,
4817 session: UserSession,
4818) -> Result<()> {
4819 authorize_access_to_language_models(&session).await?;
4820
4821 let db = session.db().await;
4822 let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
4823
4824 response.send(proto::GetCachedEmbeddingsResponse {
4825 embeddings: embeddings
4826 .into_iter()
4827 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4828 .collect(),
4829 })?;
4830 Ok(())
4831}
4832
4833async fn authorize_access_to_language_models(session: &UserSession) -> Result<(), Error> {
4834 let db = session.db().await;
4835 let flags = db.get_user_flags(session.user_id()).await?;
4836 if flags.iter().any(|flag| flag == "language-models") {
4837 Ok(())
4838 } else {
4839 Err(anyhow!("permission denied"))?
4840 }
4841}
4842
4843/// Get a Supermaven API key for the user
4844async fn get_supermaven_api_key(
4845 _request: proto::GetSupermavenApiKey,
4846 response: Response<proto::GetSupermavenApiKey>,
4847 session: UserSession,
4848) -> Result<()> {
4849 let user_id: String = session.user_id().to_string();
4850 if !session.is_staff() {
4851 return Err(anyhow!("supermaven not enabled for this account"))?;
4852 }
4853
4854 let email = session
4855 .email()
4856 .ok_or_else(|| anyhow!("user must have an email"))?;
4857
4858 let supermaven_admin_api = session
4859 .supermaven_client
4860 .as_ref()
4861 .ok_or_else(|| anyhow!("supermaven not configured"))?;
4862
4863 let result = supermaven_admin_api
4864 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4865 .await?;
4866
4867 response.send(proto::GetSupermavenApiKeyResponse {
4868 api_key: result.api_key,
4869 })?;
4870
4871 Ok(())
4872}
4873
4874/// Start receiving chat updates for a channel
4875async fn join_channel_chat(
4876 request: proto::JoinChannelChat,
4877 response: Response<proto::JoinChannelChat>,
4878 session: UserSession,
4879) -> Result<()> {
4880 let channel_id = ChannelId::from_proto(request.channel_id);
4881
4882 let db = session.db().await;
4883 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4884 .await?;
4885 let messages = db
4886 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4887 .await?;
4888 response.send(proto::JoinChannelChatResponse {
4889 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4890 messages,
4891 })?;
4892 Ok(())
4893}
4894
4895/// Stop receiving chat updates for a channel
4896async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
4897 let channel_id = ChannelId::from_proto(request.channel_id);
4898 session
4899 .db()
4900 .await
4901 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4902 .await?;
4903 Ok(())
4904}
4905
4906/// Retrieve the chat history for a channel
4907async fn get_channel_messages(
4908 request: proto::GetChannelMessages,
4909 response: Response<proto::GetChannelMessages>,
4910 session: UserSession,
4911) -> Result<()> {
4912 let channel_id = ChannelId::from_proto(request.channel_id);
4913 let messages = session
4914 .db()
4915 .await
4916 .get_channel_messages(
4917 channel_id,
4918 session.user_id(),
4919 MESSAGE_COUNT_PER_PAGE,
4920 Some(MessageId::from_proto(request.before_message_id)),
4921 )
4922 .await?;
4923 response.send(proto::GetChannelMessagesResponse {
4924 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4925 messages,
4926 })?;
4927 Ok(())
4928}
4929
4930/// Retrieve specific chat messages
4931async fn get_channel_messages_by_id(
4932 request: proto::GetChannelMessagesById,
4933 response: Response<proto::GetChannelMessagesById>,
4934 session: UserSession,
4935) -> Result<()> {
4936 let message_ids = request
4937 .message_ids
4938 .iter()
4939 .map(|id| MessageId::from_proto(*id))
4940 .collect::<Vec<_>>();
4941 let messages = session
4942 .db()
4943 .await
4944 .get_channel_messages_by_id(session.user_id(), &message_ids)
4945 .await?;
4946 response.send(proto::GetChannelMessagesResponse {
4947 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4948 messages,
4949 })?;
4950 Ok(())
4951}
4952
4953/// Retrieve the current users notifications
4954async fn get_notifications(
4955 request: proto::GetNotifications,
4956 response: Response<proto::GetNotifications>,
4957 session: UserSession,
4958) -> Result<()> {
4959 let notifications = session
4960 .db()
4961 .await
4962 .get_notifications(
4963 session.user_id(),
4964 NOTIFICATION_COUNT_PER_PAGE,
4965 request
4966 .before_id
4967 .map(|id| db::NotificationId::from_proto(id)),
4968 )
4969 .await?;
4970 response.send(proto::GetNotificationsResponse {
4971 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4972 notifications,
4973 })?;
4974 Ok(())
4975}
4976
4977/// Mark notifications as read
4978async fn mark_notification_as_read(
4979 request: proto::MarkNotificationRead,
4980 response: Response<proto::MarkNotificationRead>,
4981 session: UserSession,
4982) -> Result<()> {
4983 let database = &session.db().await;
4984 let notifications = database
4985 .mark_notification_as_read_by_id(
4986 session.user_id(),
4987 NotificationId::from_proto(request.notification_id),
4988 )
4989 .await?;
4990 send_notifications(
4991 &*session.connection_pool().await,
4992 &session.peer,
4993 notifications,
4994 );
4995 response.send(proto::Ack {})?;
4996 Ok(())
4997}
4998
4999/// Get the current users information
5000async fn get_private_user_info(
5001 _request: proto::GetPrivateUserInfo,
5002 response: Response<proto::GetPrivateUserInfo>,
5003 session: UserSession,
5004) -> Result<()> {
5005 let db = session.db().await;
5006
5007 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
5008 let user = db
5009 .get_user_by_id(session.user_id())
5010 .await?
5011 .ok_or_else(|| anyhow!("user not found"))?;
5012 let flags = db.get_user_flags(session.user_id()).await?;
5013
5014 response.send(proto::GetPrivateUserInfoResponse {
5015 metrics_id,
5016 staff: user.admin,
5017 flags,
5018 })?;
5019 Ok(())
5020}
5021
5022fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
5023 let message = match message {
5024 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
5025 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
5026 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
5027 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
5028 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
5029 code: frame.code.into(),
5030 reason: frame.reason,
5031 })),
5032 // We should never receive a frame while reading the message, according
5033 // to the `tungstenite` maintainers:
5034 //
5035 // > It cannot occur when you read messages from the WebSocket, but it
5036 // > can be used when you want to send the raw frames (e.g. you want to
5037 // > send the frames to the WebSocket without composing the full message first).
5038 // >
5039 // > — https://github.com/snapview/tungstenite-rs/issues/268
5040 TungsteniteMessage::Frame(_) => {
5041 bail!("received an unexpected frame while reading the message")
5042 }
5043 };
5044
5045 Ok(message)
5046}
5047
5048fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
5049 match message {
5050 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
5051 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
5052 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
5053 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
5054 AxumMessage::Close(frame) => {
5055 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
5056 code: frame.code.into(),
5057 reason: frame.reason,
5058 }))
5059 }
5060 }
5061}
5062
/// Apply a channel-membership change for `user_id` and tell that user's clients.
///
/// First updates the connection pool's subscriptions:
/// - subscribes to every channel in `result.new_channels`, using the
///   membership's role;
/// - unsubscribes from every channel in `result.removed_channels`.
///
/// Then each of the user's connections is sent two messages:
/// - an `UpdateUserChannels` with the new memberships;
/// - an `UpdateChannels` that adds the new channels, deletes the removed ones,
///   and removes the invitation for `result.channel_id`.
///
/// Send failures are logged and ignored.
fn notify_membership_updated(
    connection_pool: &mut ConnectionPool,
    result: MembershipUpdated,
    user_id: UserId,
    peer: &Peer,
) {
    for membership in &result.new_channels.channel_memberships {
        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
    }
    for channel_id in &result.removed_channels {
        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
    }

    // Only memberships are filled in here. The observed buffer-version and
    // message-id fields stay at their defaults (unlike `build_update_user_channels`).
    let user_channels_update = proto::UpdateUserChannels {
        channel_memberships: result
            .new_channels
            .channel_memberships
            .iter()
            .map(|cm| proto::ChannelMembership {
                channel_id: cm.channel_id.to_proto(),
                role: cm.role.into(),
            })
            .collect(),
        ..Default::default()
    };

    // Consumes `result.new_channels`, so it must come after the loops and the
    // `UpdateUserChannels` construction above, which borrow it.
    let mut update = build_channels_update(result.new_channels);
    update.delete_channels = result
        .removed_channels
        .into_iter()
        .map(|id| id.to_proto())
        .collect();
    update.remove_channel_invitations = vec![result.channel_id.to_proto()];

    for connection_id in connection_pool.user_connection_ids(user_id) {
        peer.send(connection_id, user_channels_update.clone())
            .trace_err();
        peer.send(connection_id, update.clone()).trace_err();
    }
}
5103
5104fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
5105 proto::UpdateUserChannels {
5106 channel_memberships: channels
5107 .channel_memberships
5108 .iter()
5109 .map(|m| proto::ChannelMembership {
5110 channel_id: m.channel_id.to_proto(),
5111 role: m.role.into(),
5112 })
5113 .collect(),
5114 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
5115 observed_channel_message_id: channels.observed_channel_messages.clone(),
5116 }
5117}
5118
5119fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
5120 let mut update = proto::UpdateChannels::default();
5121
5122 for channel in channels.channels {
5123 update.channels.push(channel.to_proto());
5124 }
5125
5126 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
5127 update.latest_channel_message_ids = channels.latest_channel_messages;
5128
5129 for (channel_id, participants) in channels.channel_participants {
5130 update
5131 .channel_participants
5132 .push(proto::ChannelParticipants {
5133 channel_id: channel_id.to_proto(),
5134 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
5135 });
5136 }
5137
5138 for channel in channels.invited_channels {
5139 update.channel_invitations.push(channel.to_proto());
5140 }
5141
5142 update.hosted_projects = channels.hosted_projects;
5143 update
5144}
5145
5146fn build_initial_contacts_update(
5147 contacts: Vec<db::Contact>,
5148 pool: &ConnectionPool,
5149) -> proto::UpdateContacts {
5150 let mut update = proto::UpdateContacts::default();
5151
5152 for contact in contacts {
5153 match contact {
5154 db::Contact::Accepted { user_id, busy } => {
5155 update.contacts.push(contact_for_user(user_id, busy, &pool));
5156 }
5157 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
5158 db::Contact::Incoming { user_id } => {
5159 update
5160 .incoming_requests
5161 .push(proto::IncomingContactRequest {
5162 requester_id: user_id.to_proto(),
5163 })
5164 }
5165 }
5166 }
5167
5168 update
5169}
5170
5171fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
5172 proto::Contact {
5173 user_id: user_id.to_proto(),
5174 online: pool.is_user_online(user_id),
5175 busy,
5176 }
5177}
5178
5179fn room_updated(room: &proto::Room, peer: &Peer) {
5180 broadcast(
5181 None,
5182 room.participants
5183 .iter()
5184 .filter_map(|participant| Some(participant.peer_id?.into())),
5185 |peer_id| {
5186 peer.send(
5187 peer_id,
5188 proto::RoomUpdated {
5189 room: Some(room.clone()),
5190 },
5191 )
5192 },
5193 );
5194}
5195
5196fn channel_updated(
5197 channel: &db::channel::Model,
5198 room: &proto::Room,
5199 peer: &Peer,
5200 pool: &ConnectionPool,
5201) {
5202 let participants = room
5203 .participants
5204 .iter()
5205 .map(|p| p.user_id)
5206 .collect::<Vec<_>>();
5207
5208 broadcast(
5209 None,
5210 pool.channel_connection_ids(channel.root_id())
5211 .filter_map(|(channel_id, role)| {
5212 role.can_see_channel(channel.visibility).then(|| channel_id)
5213 }),
5214 |peer_id| {
5215 peer.send(
5216 peer_id,
5217 proto::UpdateChannels {
5218 channel_participants: vec![proto::ChannelParticipants {
5219 channel_id: channel.id.to_proto(),
5220 participant_user_ids: participants.clone(),
5221 }],
5222 ..Default::default()
5223 },
5224 )
5225 },
5226 );
5227}
5228
5229async fn send_dev_server_projects_update(
5230 user_id: UserId,
5231 mut status: proto::DevServerProjectsUpdate,
5232 session: &Session,
5233) {
5234 let pool = session.connection_pool().await;
5235 for dev_server in &mut status.dev_servers {
5236 dev_server.status =
5237 pool.dev_server_status(DevServerId(dev_server.dev_server_id as i32)) as i32;
5238 }
5239 let connections = pool.user_connection_ids(user_id);
5240 for connection_id in connections {
5241 session.peer.send(connection_id, status.clone()).trace_err();
5242 }
5243}
5244
5245async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
5246 let db = session.db().await;
5247
5248 let contacts = db.get_contacts(user_id).await?;
5249 let busy = db.is_user_busy(user_id).await?;
5250
5251 let pool = session.connection_pool().await;
5252 let updated_contact = contact_for_user(user_id, busy, &pool);
5253 for contact in contacts {
5254 if let db::Contact::Accepted {
5255 user_id: contact_user_id,
5256 ..
5257 } = contact
5258 {
5259 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
5260 session
5261 .peer
5262 .send(
5263 contact_conn_id,
5264 proto::UpdateContacts {
5265 contacts: vec![updated_contact.clone()],
5266 remove_contacts: Default::default(),
5267 incoming_requests: Default::default(),
5268 remove_incoming_requests: Default::default(),
5269 outgoing_requests: Default::default(),
5270 remove_outgoing_requests: Default::default(),
5271 },
5272 )
5273 .trace_err();
5274 }
5275 }
5276 }
5277 Ok(())
5278}
5279
5280async fn lost_dev_server_connection(session: &DevServerSession) -> Result<()> {
5281 log::info!("lost dev server connection, unsharing projects");
5282 let project_ids = session
5283 .db()
5284 .await
5285 .get_stale_dev_server_projects(session.connection_id)
5286 .await?;
5287
5288 for project_id in project_ids {
5289 // not unshare re-checks the connection ids match, so we get away with no transaction
5290 unshare_project_internal(project_id, session.connection_id, None, &session).await?;
5291 }
5292
5293 let user_id = session.dev_server().user_id;
5294 let update = session
5295 .db()
5296 .await
5297 .dev_server_projects_update(user_id)
5298 .await?;
5299
5300 send_dev_server_projects_update(user_id, update, session).await;
5301
5302 Ok(())
5303}
5304
5305async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
5306 let mut contacts_to_update = HashSet::default();
5307
5308 let room_id;
5309 let canceled_calls_to_user_ids;
5310 let live_kit_room;
5311 let delete_live_kit_room;
5312 let room;
5313 let channel;
5314
5315 if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
5316 contacts_to_update.insert(session.user_id());
5317
5318 for project in left_room.left_projects.values() {
5319 project_left(project, session);
5320 }
5321
5322 room_id = RoomId::from_proto(left_room.room.id);
5323 canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
5324 live_kit_room = mem::take(&mut left_room.room.live_kit_room);
5325 delete_live_kit_room = left_room.deleted;
5326 room = mem::take(&mut left_room.room);
5327 channel = mem::take(&mut left_room.channel);
5328
5329 room_updated(&room, &session.peer);
5330 } else {
5331 return Ok(());
5332 }
5333
5334 if let Some(channel) = channel {
5335 channel_updated(
5336 &channel,
5337 &room,
5338 &session.peer,
5339 &*session.connection_pool().await,
5340 );
5341 }
5342
5343 {
5344 let pool = session.connection_pool().await;
5345 for canceled_user_id in canceled_calls_to_user_ids {
5346 for connection_id in pool.user_connection_ids(canceled_user_id) {
5347 session
5348 .peer
5349 .send(
5350 connection_id,
5351 proto::CallCanceled {
5352 room_id: room_id.to_proto(),
5353 },
5354 )
5355 .trace_err();
5356 }
5357 contacts_to_update.insert(canceled_user_id);
5358 }
5359 }
5360
5361 for contact_user_id in contacts_to_update {
5362 update_user_contacts(contact_user_id, &session).await?;
5363 }
5364
5365 if let Some(live_kit) = session.live_kit_client.as_ref() {
5366 live_kit
5367 .remove_participant(live_kit_room.clone(), session.user_id().to_string())
5368 .await
5369 .trace_err();
5370
5371 if delete_live_kit_room {
5372 live_kit.delete_room(live_kit_room).await.trace_err();
5373 }
5374 }
5375
5376 Ok(())
5377}
5378
5379async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
5380 let left_channel_buffers = session
5381 .db()
5382 .await
5383 .leave_channel_buffers(session.connection_id)
5384 .await?;
5385
5386 for left_buffer in left_channel_buffers {
5387 channel_buffer_updated(
5388 session.connection_id,
5389 left_buffer.connections,
5390 &proto::UpdateChannelBufferCollaborators {
5391 channel_id: left_buffer.channel_id.to_proto(),
5392 collaborators: left_buffer.collaborators,
5393 },
5394 &session.peer,
5395 );
5396 }
5397
5398 Ok(())
5399}
5400
5401fn project_left(project: &db::LeftProject, session: &UserSession) {
5402 for connection_id in &project.connection_ids {
5403 if project.should_unshare {
5404 session
5405 .peer
5406 .send(
5407 *connection_id,
5408 proto::UnshareProject {
5409 project_id: project.id.to_proto(),
5410 },
5411 )
5412 .trace_err();
5413 } else {
5414 session
5415 .peer
5416 .send(
5417 *connection_id,
5418 proto::RemoveProjectCollaborator {
5419 project_id: project.id.to_proto(),
5420 peer_id: Some(session.connection_id.into()),
5421 },
5422 )
5423 .trace_err();
5424 }
5425 }
5426}
5427
/// Extension trait for logging an error and discarding it instead of
/// propagating it.
pub trait ResultExt {
    /// The success value type carried by the implementing type.
    type Ok;

    /// Convert into an `Option`, logging the error (if any) before dropping it.
    fn trace_err(self) -> Option<Self::Ok>;
}
5433
5434impl<T, E> ResultExt for Result<T, E>
5435where
5436 E: std::fmt::Debug,
5437{
5438 type Ok = T;
5439
5440 #[track_caller]
5441 fn trace_err(self) -> Option<T> {
5442 match self {
5443 Ok(value) => Some(value),
5444 Err(error) => {
5445 tracing::error!("{:?}", error);
5446 None
5447 }
5448 }
5449 }
5450}