1mod connection_pool;
2
3use crate::{
4 auth,
5 db::{
6 self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
7 CreatedChannelMessage, Database, DevServerId, DevServerProjectId, InviteMemberResult,
8 MembershipUpdated, MessageId, NotificationId, PrincipalId, Project, ProjectId,
9 RejoinedProject, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
10 ServerId, UpdatedChannelMessage, User, UserId,
11 },
12 executor::Executor,
13 AppState, Config, Error, RateLimit, RateLimiter, Result,
14};
15use anyhow::{anyhow, bail, Context as _};
16use async_tungstenite::tungstenite::{
17 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
18};
19use axum::{
20 body::Body,
21 extract::{
22 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
23 ConnectInfo, WebSocketUpgrade,
24 },
25 headers::{Header, HeaderName},
26 http::StatusCode,
27 middleware,
28 response::IntoResponse,
29 routing::get,
30 Extension, Router, TypedHeader,
31};
32use collections::{HashMap, HashSet};
33pub use connection_pool::{ConnectionPool, ZedVersion};
34use core::fmt::{self, Debug, Formatter};
35use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
36use sha2::Digest;
37use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
38
39use futures::{
40 channel::oneshot,
41 future::{self, BoxFuture},
42 stream::FuturesUnordered,
43 FutureExt, SinkExt, StreamExt, TryStreamExt,
44};
45use http_client::IsahcHttpClient;
46use prometheus::{register_int_gauge, IntGauge};
47use rpc::{
48 proto::{
49 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
50 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
51 },
52 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
53};
54use semantic_version::SemanticVersion;
55use serde::{Serialize, Serializer};
56use std::{
57 any::TypeId,
58 future::Future,
59 marker::PhantomData,
60 mem,
61 net::SocketAddr,
62 ops::{Deref, DerefMut},
63 rc::Rc,
64 sync::{
65 atomic::{AtomicBool, Ordering::SeqCst},
66 Arc, OnceLock,
67 },
68 time::{Duration, Instant},
69};
70use time::OffsetDateTime;
71use tokio::sync::{watch, Semaphore};
72use tower::ServiceBuilder;
73use tracing::{
74 field::{self},
75 info_span, instrument, Instrument,
76};
77
78use self::connection_pool::VersionedMessage;
79
/// Reconnection window of 30 seconds. Its uses are outside this excerpt; presumably this
/// is how long a disconnected client may take to reconnect.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_millis(30_000);

/// Delay before `Server::start` cleans up stale rooms and channel buffers.
///
/// Kubernetes gives terminated pods 10 seconds to shut down gracefully; waiting a little
/// longer than that means the old pods are gone before their resources are reclaimed.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_millis(15_000);

/// Page size for channel message queries (uses are outside this excerpt).
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Maximum channel message length (uses are outside this excerpt; unit not shown here).
const MAX_MESSAGE_LEN: usize = 1_024;
/// Page size for notification queries (uses are outside this excerpt).
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
88
89type MessageHandler =
90 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
91
92struct Response<R> {
93 peer: Arc<Peer>,
94 receipt: Receipt<R>,
95 responded: Arc<AtomicBool>,
96}
97
98impl<R: RequestMessage> Response<R> {
99 fn send(self, payload: R::Response) -> Result<()> {
100 self.responded.store(true, SeqCst);
101 self.peer.respond(self.receipt, payload)?;
102 Ok(())
103 }
104}
105
106struct StreamingResponse<R: RequestMessage> {
107 peer: Arc<Peer>,
108 receipt: Receipt<R>,
109}
110
111impl<R: RequestMessage> StreamingResponse<R> {
112 fn send(&self, payload: R::Response) -> Result<()> {
113 self.peer.respond(self.receipt, payload)?;
114 Ok(())
115 }
116}
117
118#[derive(Clone, Debug)]
119pub enum Principal {
120 User(User),
121 Impersonated { user: User, admin: User },
122 DevServer(dev_server::Model),
123}
124
125impl Principal {
126 fn update_span(&self, span: &tracing::Span) {
127 match &self {
128 Principal::User(user) => {
129 span.record("user_id", &user.id.0);
130 span.record("login", &user.github_login);
131 }
132 Principal::Impersonated { user, admin } => {
133 span.record("user_id", &user.id.0);
134 span.record("login", &user.github_login);
135 span.record("impersonator", &admin.github_login);
136 }
137 Principal::DevServer(dev_server) => {
138 span.record("dev_server_id", &dev_server.id.0);
139 }
140 }
141 }
142}
143
144#[derive(Clone)]
145struct Session {
146 principal: Principal,
147 connection_id: ConnectionId,
148 db: Arc<tokio::sync::Mutex<DbHandle>>,
149 peer: Arc<Peer>,
150 connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
151 live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
152 supermaven_client: Option<Arc<SupermavenAdminApi>>,
153 http_client: Arc<IsahcHttpClient>,
154 rate_limiter: Arc<RateLimiter>,
155 _executor: Executor,
156}
157
158impl Session {
159 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
160 #[cfg(test)]
161 tokio::task::yield_now().await;
162 let guard = self.db.lock().await;
163 #[cfg(test)]
164 tokio::task::yield_now().await;
165 guard
166 }
167
168 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
169 #[cfg(test)]
170 tokio::task::yield_now().await;
171 let guard = self.connection_pool.lock();
172 ConnectionPoolGuard {
173 guard,
174 _not_send: PhantomData,
175 }
176 }
177
178 fn for_user(self) -> Option<UserSession> {
179 UserSession::new(self)
180 }
181
182 fn for_dev_server(self) -> Option<DevServerSession> {
183 DevServerSession::new(self)
184 }
185
186 fn user_id(&self) -> Option<UserId> {
187 match &self.principal {
188 Principal::User(user) => Some(user.id),
189 Principal::Impersonated { user, .. } => Some(user.id),
190 Principal::DevServer(_) => None,
191 }
192 }
193
194 fn is_staff(&self) -> bool {
195 match &self.principal {
196 Principal::User(user) => user.admin,
197 Principal::Impersonated { .. } => true,
198 Principal::DevServer(_) => false,
199 }
200 }
201
202 fn dev_server_id(&self) -> Option<DevServerId> {
203 match &self.principal {
204 Principal::User(_) | Principal::Impersonated { .. } => None,
205 Principal::DevServer(dev_server) => Some(dev_server.id),
206 }
207 }
208
209 fn principal_id(&self) -> PrincipalId {
210 match &self.principal {
211 Principal::User(user) => PrincipalId::UserId(user.id),
212 Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
213 Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
214 }
215 }
216}
217
218impl Debug for Session {
219 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
220 let mut result = f.debug_struct("Session");
221 match &self.principal {
222 Principal::User(user) => {
223 result.field("user", &user.github_login);
224 }
225 Principal::Impersonated { user, admin } => {
226 result.field("user", &user.github_login);
227 result.field("impersonator", &admin.github_login);
228 }
229 Principal::DevServer(dev_server) => {
230 result.field("dev_server", &dev_server.id);
231 }
232 }
233 result.field("connection_id", &self.connection_id).finish()
234 }
235}
236
237struct UserSession(Session);
238
239impl UserSession {
240 pub fn new(s: Session) -> Option<Self> {
241 s.user_id().map(|_| UserSession(s))
242 }
243 pub fn user_id(&self) -> UserId {
244 self.0.user_id().unwrap()
245 }
246
247 pub fn email(&self) -> Option<String> {
248 match &self.0.principal {
249 Principal::User(user) => user.email_address.clone(),
250 Principal::Impersonated { user, .. } => user.email_address.clone(),
251 Principal::DevServer(..) => None,
252 }
253 }
254}
255
256impl Deref for UserSession {
257 type Target = Session;
258
259 fn deref(&self) -> &Self::Target {
260 &self.0
261 }
262}
263impl DerefMut for UserSession {
264 fn deref_mut(&mut self) -> &mut Self::Target {
265 &mut self.0
266 }
267}
268
269struct DevServerSession(Session);
270
271impl DevServerSession {
272 pub fn new(s: Session) -> Option<Self> {
273 s.dev_server_id().map(|_| DevServerSession(s))
274 }
275 pub fn dev_server_id(&self) -> DevServerId {
276 self.0.dev_server_id().unwrap()
277 }
278
279 fn dev_server(&self) -> &dev_server::Model {
280 match &self.0.principal {
281 Principal::DevServer(dev_server) => dev_server,
282 _ => unreachable!(),
283 }
284 }
285}
286
287impl Deref for DevServerSession {
288 type Target = Session;
289
290 fn deref(&self) -> &Self::Target {
291 &self.0
292 }
293}
294impl DerefMut for DevServerSession {
295 fn deref_mut(&mut self) -> &mut Self::Target {
296 &mut self.0
297 }
298}
299
300fn user_handler<M: RequestMessage, Fut>(
301 handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
302) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
303where
304 Fut: Send + Future<Output = Result<()>>,
305{
306 let handler = Arc::new(handler);
307 move |message, response, session| {
308 let handler = handler.clone();
309 Box::pin(async move {
310 if let Some(user_session) = session.for_user() {
311 Ok(handler(message, response, user_session).await?)
312 } else {
313 Err(Error::Internal(anyhow!(
314 "must be a user to call {}",
315 M::NAME
316 )))
317 }
318 })
319 }
320}
321
322fn dev_server_handler<M: RequestMessage, Fut>(
323 handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
324) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
325where
326 Fut: Send + Future<Output = Result<()>>,
327{
328 let handler = Arc::new(handler);
329 move |message, response, session| {
330 let handler = handler.clone();
331 Box::pin(async move {
332 if let Some(dev_server_session) = session.for_dev_server() {
333 Ok(handler(message, response, dev_server_session).await?)
334 } else {
335 Err(Error::Internal(anyhow!(
336 "must be a dev server to call {}",
337 M::NAME
338 )))
339 }
340 })
341 }
342}
343
344fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
345 handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
346) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
347where
348 InnertRetFut: Send + Future<Output = Result<()>>,
349{
350 let handler = Arc::new(handler);
351 move |message, session| {
352 let handler = handler.clone();
353 Box::pin(async move {
354 if let Some(user_session) = session.for_user() {
355 Ok(handler(message, user_session).await?)
356 } else {
357 Err(Error::Internal(anyhow!(
358 "must be a user to call {}",
359 M::NAME
360 )))
361 }
362 })
363 }
364}
365
366struct DbHandle(Arc<Database>);
367
368impl Deref for DbHandle {
369 type Target = Database;
370
371 fn deref(&self) -> &Self::Target {
372 self.0.as_ref()
373 }
374}
375
376pub struct Server {
377 id: parking_lot::Mutex<ServerId>,
378 peer: Arc<Peer>,
379 pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
380 app_state: Arc<AppState>,
381 handlers: HashMap<TypeId, MessageHandler>,
382 teardown: watch::Sender<bool>,
383}
384
385pub(crate) struct ConnectionPoolGuard<'a> {
386 guard: parking_lot::MutexGuard<'a, ConnectionPool>,
387 _not_send: PhantomData<Rc<()>>,
388}
389
390#[derive(Serialize)]
391pub struct ServerSnapshot<'a> {
392 peer: &'a Peer,
393 #[serde(serialize_with = "serialize_deref")]
394 connection_pool: ConnectionPoolGuard<'a>,
395}
396
397pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
398where
399 S: Serializer,
400 T: Deref<Target = U>,
401 U: Serialize,
402{
403 Serialize::serialize(value.deref(), serializer)
404}
405
406impl Server {
    /// Creates a collaboration server with the given `id` and registers every RPC handler
    /// it serves.
    ///
    /// Handlers wrapped in `user_handler` / `user_message_handler` reject sessions whose
    /// principal is not a user, and those wrapped in `dev_server_handler` reject sessions
    /// whose principal is not a dev server; the remaining handlers receive the raw
    /// `Session`.
    ///
    /// # Panics
    ///
    /// Panics (via `add_handler`) if two handlers are registered for the same message type.
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            // The initial receiver is dropped immediately; connections subscribe later.
            teardown: watch::channel(false).0,
        };

        server
            .add_request_handler(ping)
            // Rooms and calls.
            .add_request_handler(user_handler(create_room))
            .add_request_handler(user_handler(join_room))
            .add_request_handler(user_handler(rejoin_room))
            .add_request_handler(user_handler(leave_room))
            .add_request_handler(user_handler(set_room_participant_role))
            .add_request_handler(user_handler(call))
            .add_request_handler(user_handler(cancel_call))
            .add_message_handler(user_message_handler(decline_call))
            .add_request_handler(user_handler(update_participant_location))
            // Projects and dev servers.
            .add_request_handler(user_handler(share_project))
            .add_message_handler(unshare_project)
            .add_request_handler(user_handler(join_project))
            .add_request_handler(user_handler(join_hosted_project))
            .add_request_handler(user_handler(rejoin_dev_server_projects))
            .add_request_handler(user_handler(create_dev_server_project))
            .add_request_handler(user_handler(update_dev_server_project))
            .add_request_handler(user_handler(delete_dev_server_project))
            .add_request_handler(user_handler(create_dev_server))
            .add_request_handler(user_handler(regenerate_dev_server_token))
            .add_request_handler(user_handler(rename_dev_server))
            .add_request_handler(user_handler(delete_dev_server))
            .add_request_handler(user_handler(list_remote_directory))
            .add_request_handler(dev_server_handler(share_dev_server_project))
            .add_request_handler(dev_server_handler(shutdown_dev_server))
            .add_request_handler(dev_server_handler(reconnect_dev_server))
            .add_message_handler(user_message_handler(leave_project))
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Project requests relayed through the `forward_*` helpers defined elsewhere
            // in this file.
            .add_request_handler(user_handler(
                forward_project_request_for_owner::<proto::TaskContextForLocation>,
            ))
            .add_request_handler(user_handler(
                forward_project_request_for_owner::<proto::TaskTemplates>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetHover>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetTypeDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetReferences>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::SearchProject>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDocumentHighlights>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetProjectSymbols>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferForSymbol>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferById>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::SynchronizeBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::InlayHints>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferByPath>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCompletions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            ))
            .add_request_handler(user_handler(
                forward_versioned_mutating_project_request::<proto::OpenNewBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCodeActions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCodeAction>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PrepareRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PerformRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ReloadBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::FormatBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CreateProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RenameProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CopyProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::DeleteProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ExpandProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OnTypeFormatting>,
            ))
            .add_request_handler(user_handler(
                forward_versioned_mutating_project_request::<proto::SaveBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::BlameBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::MultiLspQuery>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RestartLanguageServers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::LinkedEditingRange>,
            ))
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            // Users and contacts.
            .add_request_handler(get_users)
            .add_request_handler(user_handler(fuzzy_search_users))
            .add_request_handler(user_handler(request_contact))
            .add_request_handler(user_handler(remove_contact))
            .add_request_handler(user_handler(respond_to_contact_request))
            // Channels, channel buffers and channel chat.
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(user_handler(create_channel))
            .add_request_handler(user_handler(delete_channel))
            .add_request_handler(user_handler(invite_channel_member))
            .add_request_handler(user_handler(remove_channel_member))
            .add_request_handler(user_handler(set_channel_member_role))
            .add_request_handler(user_handler(set_channel_visibility))
            .add_request_handler(user_handler(rename_channel))
            .add_request_handler(user_handler(join_channel_buffer))
            .add_request_handler(user_handler(leave_channel_buffer))
            .add_message_handler(user_message_handler(update_channel_buffer))
            .add_request_handler(user_handler(rejoin_channel_buffers))
            .add_request_handler(user_handler(get_channel_members))
            .add_request_handler(user_handler(respond_to_channel_invite))
            .add_request_handler(user_handler(join_channel))
            .add_request_handler(user_handler(join_channel_chat))
            .add_message_handler(user_message_handler(leave_channel_chat))
            .add_request_handler(user_handler(send_channel_message))
            .add_request_handler(user_handler(remove_channel_message))
            .add_request_handler(user_handler(update_channel_message))
            .add_request_handler(user_handler(get_channel_messages))
            .add_request_handler(user_handler(get_channel_messages_by_id))
            // Notifications.
            .add_request_handler(user_handler(get_notifications))
            .add_request_handler(user_handler(mark_notification_as_read))
            .add_request_handler(user_handler(move_channel))
            // Following.
            .add_request_handler(user_handler(follow))
            .add_message_handler(user_message_handler(unfollow))
            .add_message_handler(user_message_handler(update_followers))
            .add_request_handler(user_handler(get_private_user_info))
            .add_message_handler(user_message_handler(acknowledge_channel_message))
            .add_message_handler(user_message_handler(acknowledge_buffer_version))
            .add_request_handler(user_handler(get_supermaven_api_key))
            // Assistant contexts.
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OpenContext>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::SynchronizeContexts>,
            ))
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context)
            // Language-model handlers: each closure captures its own clone of `app_state`
            // so it can pass `app_state.config` to the underlying function.
            .add_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    let app_state = app_state.clone();
                    async move {
                        complete_with_language_model(request, response, session, &app_state.config)
                            .await
                    }
                }
            })
            .add_streaming_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    let app_state = app_state.clone();
                    async move {
                        stream_complete_with_language_model(
                            request,
                            response,
                            session,
                            &app_state.config,
                        )
                        .await
                    }
                }
            })
            .add_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    let app_state = app_state.clone();
                    async move {
                        count_language_model_tokens(request, response, session, &app_state.config)
                            .await
                    }
                }
            })
            // Embeddings.
            .add_request_handler({
                user_handler(move |request, response, session| {
                    get_cached_embeddings(request, response, session)
                })
            })
            .add_request_handler({
                let app_state = app_state.clone();
                user_handler(move |request, response, session| {
                    compute_embeddings(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                    )
                })
            });

        Arc::new(server)
    }
662
    /// Spawns a detached background task that reclaims stale resources, then returns
    /// immediately with `Ok(())`.
    ///
    /// After sleeping for `CLEANUP_TIMEOUT`, the task asks the database for the room and
    /// channel-buffer ids it reports as stale for this environment and server id
    /// (presumably resources left behind by previous server instances). It then:
    /// - clears stale channel-buffer collaborators and sends the refreshed collaborator
    ///   list to the remaining connections;
    /// - clears stale room participants, broadcasts room (and channel) updates, sends
    ///   `CallCanceled` for canceled calls, and pushes refreshed contact entries to the
    ///   affected users' accepted contacts;
    /// - deletes the LiveKit room when a refreshed room has no participants left;
    /// - finally deletes the stale server records.
    ///
    /// Fallible steps go through `trace_err()` (which presumably logs the error), so one
    /// failure does not abort the rest of the cleanup.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // Created now, awaited inside the spawned task.
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Channel buffers: drop stale collaborators and tell the remaining
                    // connections about the new collaborator list.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    // Rooms: drop stale participants, then notify everyone affected.
                    for room_id in room_ids {
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            // The LiveKit room is only deleted once nobody is left in it.
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the connection-pool lock is released before the
                        // database awaits below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push each affected user's refreshed contact entry to the
                        // connections of their accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
808
809 pub fn teardown(&self) {
810 self.peer.teardown();
811 self.connection_pool.lock().reset();
812 let _ = self.teardown.send(true);
813 }
814
815 #[cfg(test)]
816 pub fn reset(&self, id: ServerId) {
817 self.teardown();
818 *self.id.lock() = id;
819 self.peer.reset(id.0 as u32);
820 let _ = self.teardown.send(false);
821 }
822
823 #[cfg(test)]
824 pub fn id(&self) -> ServerId {
825 *self.id.lock()
826 }
827
828 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
829 where
830 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
831 Fut: 'static + Send + Future<Output = Result<()>>,
832 M: EnvelopedMessage,
833 {
834 let prev_handler = self.handlers.insert(
835 TypeId::of::<M>(),
836 Box::new(move |envelope, session| {
837 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
838 let received_at = envelope.received_at;
839 tracing::info!("message received");
840 let start_time = Instant::now();
841 let future = (handler)(*envelope, session);
842 async move {
843 let result = future.await;
844 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
845 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
846 let queue_duration_ms = total_duration_ms - processing_duration_ms;
847 let payload_type = M::NAME;
848
849 match result {
850 Err(error) => {
851 tracing::error!(
852 ?error,
853 total_duration_ms,
854 processing_duration_ms,
855 queue_duration_ms,
856 payload_type,
857 "error handling message"
858 )
859 }
860 Ok(()) => tracing::info!(
861 total_duration_ms,
862 processing_duration_ms,
863 queue_duration_ms,
864 "finished handling message"
865 ),
866 }
867 }
868 .boxed()
869 }),
870 );
871 if prev_handler.is_some() {
872 panic!("registered a handler for the same message twice");
873 }
874 self
875 }
876
877 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
878 where
879 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
880 Fut: 'static + Send + Future<Output = Result<()>>,
881 M: EnvelopedMessage,
882 {
883 self.add_handler(move |envelope, session| handler(envelope.payload, session));
884 self
885 }
886
887 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
888 where
889 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
890 Fut: Send + Future<Output = Result<()>>,
891 M: RequestMessage,
892 {
893 let handler = Arc::new(handler);
894 self.add_handler(move |envelope, session| {
895 let receipt = envelope.receipt();
896 let handler = handler.clone();
897 async move {
898 let peer = session.peer.clone();
899 let responded = Arc::new(AtomicBool::default());
900 let response = Response {
901 peer: peer.clone(),
902 responded: responded.clone(),
903 receipt,
904 };
905 match (handler)(envelope.payload, response, session).await {
906 Ok(()) => {
907 if responded.load(std::sync::atomic::Ordering::SeqCst) {
908 Ok(())
909 } else {
910 Err(anyhow!("handler did not send a response"))?
911 }
912 }
913 Err(error) => {
914 let proto_err = match &error {
915 Error::Internal(err) => err.to_proto(),
916 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
917 };
918 peer.respond_with_error(receipt, proto_err)?;
919 Err(error)
920 }
921 }
922 }
923 })
924 }
925
926 fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
927 where
928 F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
929 Fut: Send + Future<Output = Result<()>>,
930 M: RequestMessage,
931 {
932 let handler = Arc::new(handler);
933 self.add_handler(move |envelope, session| {
934 let receipt = envelope.receipt();
935 let handler = handler.clone();
936 async move {
937 let peer = session.peer.clone();
938 let response = StreamingResponse {
939 peer: peer.clone(),
940 receipt,
941 };
942 match (handler)(envelope.payload, response, session).await {
943 Ok(()) => {
944 peer.end_stream(receipt)?;
945 Ok(())
946 }
947 Err(error) => {
948 let proto_err = match &error {
949 Error::Internal(err) => err.to_proto(),
950 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
951 };
952 peer.respond_with_error(receipt, proto_err)?;
953 Err(error)
954 }
955 }
956 }
957 })
958 }
959
    /// Drives a single client connection from handshake to sign-out.
    ///
    /// The returned future, instrumented with a `handle connection` span:
    /// 1. returns immediately if the server is already tearing down;
    /// 2. registers `connection` with the peer and builds the per-connection [`Session`];
    /// 3. runs `send_initial_client_update`, which also reports the assigned
    ///    [`ConnectionId`] through `send_connection_id` when one is given;
    /// 4. dispatches incoming messages to their registered handlers until the connection
    ///    closes, I/O fails, or teardown is signalled;
    /// 5. on close or I/O failure (but not on teardown), cancels in-flight foreground handlers
    ///    and runs `connection_lost`.
    ///
    /// Failures while building the HTTP client or sending the initial update are logged and
    /// end the future without running `connection_lost`.
    #[allow(clippy::too_many_arguments)]
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> {
        let this = self.clone();
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            dev_server_id=field::Empty
        );
        principal.update_span(&span);

        // Teardown flag: checked once up front and raced against every loop iteration below.
        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));
            tracing::info!("connection opened");

            // NOTE(review): an HTTP client (and, when an admin key is configured, a Supermaven
            // admin client) is constructed for every connection; consider sharing them via
            // `AppState` if construction turns out to be costly.
            let http_client = match IsahcHttpClient::new() {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            let supermaven_client = if let Some(supermaven_admin_api_key) = this.app_state.config.supermaven_admin_api_key.clone() {
                Some(Arc::new(SupermavenAdminApi::new(
                    supermaven_admin_api_key.to_string(),
                    http_client.clone(),
                )))
            } else {
                None
            };

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                http_client,
                rate_limiter: this.app_state.rate_limiter.clone(),
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Caps in-flight handlers (foreground and background) for this connection at 256.
            // A permit is acquired *before* the next message is read and released when its
            // handler completes, so the loop stops pulling messages once the cap is reached.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    // The semaphore is local and never closed, so acquiring cannot fail.
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Biased: teardown first, then connection I/O, then progress on in-flight
                // foreground handlers, and only then the next incoming message.
                futures::select_biased! {
                    // Server teardown: stop immediately, skipping `connection_lost`.
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                dev_server_id=field::Empty
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                // The permit moves into the handler future and is released only
                                // once the handler has finished.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers that are still pending.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
1106
1107 async fn send_initial_client_update(
1108 &self,
1109 connection_id: ConnectionId,
1110 principal: &Principal,
1111 zed_version: ZedVersion,
1112 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
1113 session: &Session,
1114 ) -> Result<()> {
1115 self.peer.send(
1116 connection_id,
1117 proto::Hello {
1118 peer_id: Some(connection_id.into()),
1119 },
1120 )?;
1121 tracing::info!("sent hello message");
1122 if let Some(send_connection_id) = send_connection_id.take() {
1123 let _ = send_connection_id.send(connection_id);
1124 }
1125
1126 match principal {
1127 Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
1128 if !user.connected_once {
1129 self.peer.send(connection_id, proto::ShowContacts {})?;
1130 self.app_state
1131 .db
1132 .set_user_connected_once(user.id, true)
1133 .await?;
1134 }
1135
1136 let (contacts, dev_server_projects) = future::try_join(
1137 self.app_state.db.get_contacts(user.id),
1138 self.app_state.db.dev_server_projects_update(user.id),
1139 )
1140 .await?;
1141
1142 {
1143 let mut pool = self.connection_pool.lock();
1144 pool.add_connection(connection_id, user.id, user.admin, zed_version);
1145 self.peer.send(
1146 connection_id,
1147 build_initial_contacts_update(contacts, &pool),
1148 )?;
1149 }
1150
1151 if should_auto_subscribe_to_channels(zed_version) {
1152 subscribe_user_to_channels(user.id, session).await?;
1153 }
1154
1155 send_dev_server_projects_update(user.id, dev_server_projects, session).await;
1156
1157 if let Some(incoming_call) =
1158 self.app_state.db.incoming_call_for_user(user.id).await?
1159 {
1160 self.peer.send(connection_id, incoming_call)?;
1161 }
1162
1163 update_user_contacts(user.id, &session).await?;
1164 }
1165 Principal::DevServer(dev_server) => {
1166 {
1167 let mut pool = self.connection_pool.lock();
1168 if let Some(stale_connection_id) = pool.dev_server_connection_id(dev_server.id)
1169 {
1170 self.peer.send(
1171 stale_connection_id,
1172 proto::ShutdownDevServer {
1173 reason: Some(
1174 "another dev server connected with the same token".to_string(),
1175 ),
1176 },
1177 )?;
1178 pool.remove_connection(stale_connection_id)?;
1179 };
1180 pool.add_dev_server(connection_id, dev_server.id, zed_version);
1181 }
1182
1183 let projects = self
1184 .app_state
1185 .db
1186 .get_projects_for_dev_server(dev_server.id)
1187 .await?;
1188 self.peer
1189 .send(connection_id, proto::DevServerInstructions { projects })?;
1190
1191 let status = self
1192 .app_state
1193 .db
1194 .dev_server_projects_update(dev_server.user_id)
1195 .await?;
1196 send_dev_server_projects_update(dev_server.user_id, status, &session).await;
1197 }
1198 }
1199
1200 Ok(())
1201 }
1202
1203 pub async fn invite_code_redeemed(
1204 self: &Arc<Self>,
1205 inviter_id: UserId,
1206 invitee_id: UserId,
1207 ) -> Result<()> {
1208 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1209 if let Some(code) = &user.invite_code {
1210 let pool = self.connection_pool.lock();
1211 let invitee_contact = contact_for_user(invitee_id, false, &pool);
1212 for connection_id in pool.user_connection_ids(inviter_id) {
1213 self.peer.send(
1214 connection_id,
1215 proto::UpdateContacts {
1216 contacts: vec![invitee_contact.clone()],
1217 ..Default::default()
1218 },
1219 )?;
1220 self.peer.send(
1221 connection_id,
1222 proto::UpdateInviteInfo {
1223 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1224 count: user.invite_count as u32,
1225 },
1226 )?;
1227 }
1228 }
1229 }
1230 Ok(())
1231 }
1232
1233 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1234 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1235 if let Some(invite_code) = &user.invite_code {
1236 let pool = self.connection_pool.lock();
1237 for connection_id in pool.user_connection_ids(user_id) {
1238 self.peer.send(
1239 connection_id,
1240 proto::UpdateInviteInfo {
1241 url: format!(
1242 "{}{}",
1243 self.app_state.config.invite_link_prefix, invite_code
1244 ),
1245 count: user.invite_count as u32,
1246 },
1247 )?;
1248 }
1249 }
1250 }
1251 Ok(())
1252 }
1253
1254 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1255 ServerSnapshot {
1256 connection_pool: ConnectionPoolGuard {
1257 guard: self.connection_pool.lock(),
1258 _not_send: PhantomData,
1259 },
1260 peer: &self.peer,
1261 }
1262 }
1263}
1264
1265impl<'a> Deref for ConnectionPoolGuard<'a> {
1266 type Target = ConnectionPool;
1267
1268 fn deref(&self) -> &Self::Target {
1269 &self.guard
1270 }
1271}
1272
1273impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1274 fn deref_mut(&mut self) -> &mut Self::Target {
1275 &mut self.guard
1276 }
1277}
1278
1279impl<'a> Drop for ConnectionPoolGuard<'a> {
1280 fn drop(&mut self) {
1281 #[cfg(test)]
1282 self.check_invariants();
1283 }
1284}
1285
1286fn broadcast<F>(
1287 sender_id: Option<ConnectionId>,
1288 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1289 mut f: F,
1290) where
1291 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1292{
1293 for receiver_id in receiver_ids {
1294 if Some(receiver_id) != sender_id {
1295 if let Err(error) = f(receiver_id) {
1296 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1297 }
1298 }
1299 }
1300}
1301
1302pub struct ProtocolVersion(u32);
1303
1304impl Header for ProtocolVersion {
1305 fn name() -> &'static HeaderName {
1306 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1307 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1308 }
1309
1310 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1311 where
1312 Self: Sized,
1313 I: Iterator<Item = &'i axum::http::HeaderValue>,
1314 {
1315 let version = values
1316 .next()
1317 .ok_or_else(axum::headers::Error::invalid)?
1318 .to_str()
1319 .map_err(|_| axum::headers::Error::invalid())?
1320 .parse()
1321 .map_err(|_| axum::headers::Error::invalid())?;
1322 Ok(Self(version))
1323 }
1324
1325 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1326 values.extend([self.0.to_string().parse().unwrap()]);
1327 }
1328}
1329
1330pub struct AppVersionHeader(SemanticVersion);
1331impl Header for AppVersionHeader {
1332 fn name() -> &'static HeaderName {
1333 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1334 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1335 }
1336
1337 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1338 where
1339 Self: Sized,
1340 I: Iterator<Item = &'i axum::http::HeaderValue>,
1341 {
1342 let version = values
1343 .next()
1344 .ok_or_else(axum::headers::Error::invalid)?
1345 .to_str()
1346 .map_err(|_| axum::headers::Error::invalid())?
1347 .parse()
1348 .map_err(|_| axum::headers::Error::invalid())?;
1349 Ok(Self(version))
1350 }
1351
1352 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1353 values.extend([self.0.to_string().parse().unwrap()]);
1354 }
1355}
1356
/// Builds the collab server's RPC/metrics router.
///
/// Layer order is significant: `Router::layer` wraps only the routes registered before it.
/// - `/rpc` is wrapped by the `ServiceBuilder` stack. That stack inserts the `AppState`
///   extension and then runs `auth::validate_header`, which presumably authenticates the
///   request and provides the `Principal` extension that `handle_websocket_request`
///   extracts — TODO confirm.
/// - `/metrics` is registered after that layer, so it does not go through
///   `validate_header`.
/// - The final `Extension(server)` layer applies to both routes.
pub fn routes(server: Arc<Server>) -> Router<(), Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        .layer(
            // `ServiceBuilder` applies layers top to bottom, so the `AppState` extension is
            // already present when `validate_header` runs.
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
1368
1369pub async fn handle_websocket_request(
1370 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1371 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1372 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1373 Extension(server): Extension<Arc<Server>>,
1374 Extension(principal): Extension<Principal>,
1375 ws: WebSocketUpgrade,
1376) -> axum::response::Response {
1377 if protocol_version != rpc::PROTOCOL_VERSION {
1378 return (
1379 StatusCode::UPGRADE_REQUIRED,
1380 "client must be upgraded".to_string(),
1381 )
1382 .into_response();
1383 }
1384
1385 let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1386 return (
1387 StatusCode::UPGRADE_REQUIRED,
1388 "no version header found".to_string(),
1389 )
1390 .into_response();
1391 };
1392
1393 if !version.can_collaborate() {
1394 return (
1395 StatusCode::UPGRADE_REQUIRED,
1396 "client must be upgraded".to_string(),
1397 )
1398 .into_response();
1399 }
1400
1401 let socket_address = socket_address.to_string();
1402 ws.on_upgrade(move |socket| {
1403 let socket = socket
1404 .map_ok(to_tungstenite_message)
1405 .err_into()
1406 .with(|message| async move { to_axum_message(message) });
1407 let connection = Connection::new(Box::pin(socket));
1408 async move {
1409 server
1410 .handle_connection(
1411 connection,
1412 socket_address,
1413 principal,
1414 version,
1415 None,
1416 Executor::Production,
1417 )
1418 .await;
1419 }
1420 })
1421}
1422
1423pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1424 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1425 let connections_metric = CONNECTIONS_METRIC
1426 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1427
1428 let connections = server
1429 .connection_pool
1430 .lock()
1431 .connections()
1432 .filter(|connection| !connection.admin)
1433 .count();
1434 connections_metric.set(connections as _);
1435
1436 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1437 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1438 register_int_gauge!(
1439 "shared_projects",
1440 "number of open projects with one or more guests"
1441 )
1442 .unwrap()
1443 });
1444
1445 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1446 shared_projects_metric.set(shared_projects as _);
1447
1448 let encoder = prometheus::TextEncoder::new();
1449 let metric_families = prometheus::gather();
1450 let encoded_metrics = encoder
1451 .encode_to_string(&metric_families)
1452 .map_err(|err| anyhow!("{}", err))?;
1453 Ok(encoded_metrics)
1454}
1455
/// Cleans up after a connection has gone away.
///
/// Immediate steps:
/// - disconnect the connection from the peer;
/// - remove it from the connection pool;
/// - report it to the database via `connection_lost` (best effort; errors are only logged).
///
/// The rest of the cleanup waits for `RECONNECT_TIMEOUT`, presumably so that a quickly
/// reconnecting client can rejoin before its resources are released. It is skipped entirely
/// if server teardown is signalled first.
/// - Users leave their room and channel buffers. If the user has no other connection
///   online, `decline_call(None, ..)` is issued. Then `update_user_contacts` runs.
/// - Dev servers are handed to `lost_dev_server_connection`.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // Biased: if the timeout and teardown are both ready, the timeout branch wins.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            match &session.principal {
                Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
                    // Cannot fail: the principal was just matched as a (possibly impersonated) user.
                    let session = session.for_user().unwrap();

                    log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
                    leave_room_for_session(&session, session.connection_id).await.trace_err();
                    leave_channel_buffers_for_session(&session)
                        .await
                        .trace_err();

                    // Only decline calls when none of the user's other connections remain online.
                    if !session
                        .connection_pool()
                        .await
                        .is_user_online(session.user_id())
                    {
                        let db = session.db().await;
                        if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                            room_updated(&room, &session.peer);
                        }
                    }

                    update_user_contacts(session.user_id(), &session).await?;
                },
                Principal::DevServer(_) => {
                    // Cannot fail: the principal was just matched as a dev server.
                    lost_dev_server_connection(&session.for_dev_server().unwrap()).await?;
                },
            }
        },
        // Server teardown: skip the deferred cleanup.
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1510
1511/// Acknowledges a ping from a client, used to keep the connection alive.
1512async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1513 response.send(proto::Ack {})?;
1514 Ok(())
1515}
1516
1517/// Creates a new room for calling (outside of channels)
1518async fn create_room(
1519 _request: proto::CreateRoom,
1520 response: Response<proto::CreateRoom>,
1521 session: UserSession,
1522) -> Result<()> {
1523 let live_kit_room = nanoid::nanoid!(30);
1524
1525 let live_kit_connection_info = util::maybe!(async {
1526 let live_kit = session.live_kit_client.as_ref();
1527 let live_kit = live_kit?;
1528 let user_id = session.user_id().to_string();
1529
1530 let token = live_kit
1531 .room_token(&live_kit_room, &user_id.to_string())
1532 .trace_err()?;
1533
1534 Some(proto::LiveKitConnectionInfo {
1535 server_url: live_kit.url().into(),
1536 token,
1537 can_publish: true,
1538 })
1539 })
1540 .await;
1541
1542 let room = session
1543 .db()
1544 .await
1545 .create_room(session.user_id(), session.connection_id, &live_kit_room)
1546 .await?;
1547
1548 response.send(proto::CreateRoomResponse {
1549 room: Some(room.clone()),
1550 live_kit_connection_info,
1551 })?;
1552
1553 update_user_contacts(session.user_id(), &session).await?;
1554 Ok(())
1555}
1556
1557/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1558async fn join_room(
1559 request: proto::JoinRoom,
1560 response: Response<proto::JoinRoom>,
1561 session: UserSession,
1562) -> Result<()> {
1563 let room_id = RoomId::from_proto(request.id);
1564
1565 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1566
1567 if let Some(channel_id) = channel_id {
1568 return join_channel_internal(channel_id, Box::new(response), session).await;
1569 }
1570
1571 let joined_room = {
1572 let room = session
1573 .db()
1574 .await
1575 .join_room(room_id, session.user_id(), session.connection_id)
1576 .await?;
1577 room_updated(&room.room, &session.peer);
1578 room.into_inner()
1579 };
1580
1581 for connection_id in session
1582 .connection_pool()
1583 .await
1584 .user_connection_ids(session.user_id())
1585 {
1586 session
1587 .peer
1588 .send(
1589 connection_id,
1590 proto::CallCanceled {
1591 room_id: room_id.to_proto(),
1592 },
1593 )
1594 .trace_err();
1595 }
1596
1597 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1598 if let Some(token) = live_kit
1599 .room_token(
1600 &joined_room.room.live_kit_room,
1601 &session.user_id().to_string(),
1602 )
1603 .trace_err()
1604 {
1605 Some(proto::LiveKitConnectionInfo {
1606 server_url: live_kit.url().into(),
1607 token,
1608 can_publish: true,
1609 })
1610 } else {
1611 None
1612 }
1613 } else {
1614 None
1615 };
1616
1617 response.send(proto::JoinRoomResponse {
1618 room: Some(joined_room.room),
1619 channel_id: None,
1620 live_kit_connection_info,
1621 })?;
1622
1623 update_user_contacts(session.user_id(), &session).await?;
1624 Ok(())
1625}
1626
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// While holding the value returned by the database's `rejoin_room` (released by
/// `into_inner`), this:
/// 1. replies with the room and the reshared and rejoined projects;
/// 2. broadcasts the updated room;
/// 3. for each reshared project, tells its collaborators that the old peer id was replaced
///    by this connection, and forwards the project's worktree list to them;
/// 4. streams rejoined-project state back via `notify_rejoined_projects`.
///
/// Afterwards, the room's channel (if any) is updated and `update_user_contacts` runs for
/// the user.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: UserSession,
) -> Result<()> {
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Let every collaborator swap the old peer id for this connection's id.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Resend the project's worktree list to everyone but this connection.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        // Release the database-held value; everything after this block runs without it.
        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1718
1719fn notify_rejoined_projects(
1720 rejoined_projects: &mut Vec<RejoinedProject>,
1721 session: &UserSession,
1722) -> Result<()> {
1723 for project in rejoined_projects.iter() {
1724 for collaborator in &project.collaborators {
1725 session
1726 .peer
1727 .send(
1728 collaborator.connection_id,
1729 proto::UpdateProjectCollaborator {
1730 project_id: project.id.to_proto(),
1731 old_peer_id: Some(project.old_connection_id.into()),
1732 new_peer_id: Some(session.connection_id.into()),
1733 },
1734 )
1735 .trace_err();
1736 }
1737 }
1738
1739 for project in rejoined_projects {
1740 for worktree in mem::take(&mut project.worktrees) {
1741 #[cfg(any(test, feature = "test-support"))]
1742 const MAX_CHUNK_SIZE: usize = 2;
1743 #[cfg(not(any(test, feature = "test-support")))]
1744 const MAX_CHUNK_SIZE: usize = 256;
1745
1746 // Stream this worktree's entries.
1747 let message = proto::UpdateWorktree {
1748 project_id: project.id.to_proto(),
1749 worktree_id: worktree.id,
1750 abs_path: worktree.abs_path.clone(),
1751 root_name: worktree.root_name,
1752 updated_entries: worktree.updated_entries,
1753 removed_entries: worktree.removed_entries,
1754 scan_id: worktree.scan_id,
1755 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1756 updated_repositories: worktree.updated_repositories,
1757 removed_repositories: worktree.removed_repositories,
1758 };
1759 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1760 session.peer.send(session.connection_id, update.clone())?;
1761 }
1762
1763 // Stream this worktree's diagnostics.
1764 for summary in worktree.diagnostic_summaries {
1765 session.peer.send(
1766 session.connection_id,
1767 proto::UpdateDiagnosticSummary {
1768 project_id: project.id.to_proto(),
1769 worktree_id: worktree.id,
1770 summary: Some(summary),
1771 },
1772 )?;
1773 }
1774
1775 for settings_file in worktree.settings_files {
1776 session.peer.send(
1777 session.connection_id,
1778 proto::UpdateWorktreeSettings {
1779 project_id: project.id.to_proto(),
1780 worktree_id: worktree.id,
1781 path: settings_file.path,
1782 content: Some(settings_file.content),
1783 },
1784 )?;
1785 }
1786 }
1787
1788 for language_server in &project.language_servers {
1789 session.peer.send(
1790 session.connection_id,
1791 proto::UpdateLanguageServer {
1792 project_id: project.id.to_proto(),
1793 language_server_id: language_server.id,
1794 variant: Some(
1795 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1796 proto::LspDiskBasedDiagnosticsUpdated {},
1797 ),
1798 ),
1799 },
1800 )?;
1801 }
1802 }
1803 Ok(())
1804}
1805
1806/// leave room disconnects from the room.
1807async fn leave_room(
1808 _: proto::LeaveRoom,
1809 response: Response<proto::LeaveRoom>,
1810 session: UserSession,
1811) -> Result<()> {
1812 leave_room_for_session(&session, session.connection_id).await?;
1813 response.send(proto::Ack {})?;
1814 Ok(())
1815}
1816
1817/// Updates the permissions of someone else in the room.
1818async fn set_room_participant_role(
1819 request: proto::SetRoomParticipantRole,
1820 response: Response<proto::SetRoomParticipantRole>,
1821 session: UserSession,
1822) -> Result<()> {
1823 let user_id = UserId::from_proto(request.user_id);
1824 let role = ChannelRole::from(request.role());
1825
1826 let (live_kit_room, can_publish) = {
1827 let room = session
1828 .db()
1829 .await
1830 .set_room_participant_role(
1831 session.user_id(),
1832 RoomId::from_proto(request.room_id),
1833 user_id,
1834 role,
1835 )
1836 .await?;
1837
1838 let live_kit_room = room.live_kit_room.clone();
1839 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1840 room_updated(&room, &session.peer);
1841 (live_kit_room, can_publish)
1842 };
1843
1844 if let Some(live_kit) = session.live_kit_client.as_ref() {
1845 live_kit
1846 .update_participant(
1847 live_kit_room.clone(),
1848 request.user_id.to_string(),
1849 live_kit_server::proto::ParticipantPermission {
1850 can_subscribe: true,
1851 can_publish,
1852 can_publish_data: can_publish,
1853 hidden: false,
1854 recorder: false,
1855 },
1856 )
1857 .await
1858 .trace_err();
1859 }
1860
1861 response.send(proto::Ack {})?;
1862 Ok(())
1863}
1864
/// Call someone else into the current room
///
/// The callee must be a contact of the caller. The call is recorded in the database, the
/// room is broadcast, and `update_user_contacts` runs for the callee. Then an incoming-call
/// request goes to every connection of the callee, and the first one that answers
/// successfully completes the call.
///
/// If no connection answers successfully:
/// - the call is marked as failed and the room is re-broadcast;
/// - the callee's contacts are updated again;
/// - `"failed to ring user"` is returned.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: UserSession,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    // Scoped so the database-held room is released before contacts are updated.
    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(&room, &session.peer);
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring every connection of the callee concurrently; the first success wins.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    // Nobody answered successfully: record the failure and notify the room.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1933
1934/// Cancel an outgoing call.
1935async fn cancel_call(
1936 request: proto::CancelCall,
1937 response: Response<proto::CancelCall>,
1938 session: UserSession,
1939) -> Result<()> {
1940 let called_user_id = UserId::from_proto(request.called_user_id);
1941 let room_id = RoomId::from_proto(request.room_id);
1942 {
1943 let room = session
1944 .db()
1945 .await
1946 .cancel_call(room_id, session.connection_id, called_user_id)
1947 .await?;
1948 room_updated(&room, &session.peer);
1949 }
1950
1951 for connection_id in session
1952 .connection_pool()
1953 .await
1954 .user_connection_ids(called_user_id)
1955 {
1956 session
1957 .peer
1958 .send(
1959 connection_id,
1960 proto::CallCanceled {
1961 room_id: room_id.to_proto(),
1962 },
1963 )
1964 .trace_err();
1965 }
1966 response.send(proto::Ack {})?;
1967
1968 update_user_contacts(called_user_id, &session).await?;
1969 Ok(())
1970}
1971
1972/// Decline an incoming call.
1973async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1974 let room_id = RoomId::from_proto(message.room_id);
1975 {
1976 let room = session
1977 .db()
1978 .await
1979 .decline_call(Some(room_id), session.user_id())
1980 .await?
1981 .ok_or_else(|| anyhow!("failed to decline call"))?;
1982 room_updated(&room, &session.peer);
1983 }
1984
1985 for connection_id in session
1986 .connection_pool()
1987 .await
1988 .user_connection_ids(session.user_id())
1989 {
1990 session
1991 .peer
1992 .send(
1993 connection_id,
1994 proto::CallCanceled {
1995 room_id: room_id.to_proto(),
1996 },
1997 )
1998 .trace_err();
1999 }
2000 update_user_contacts(session.user_id(), &session).await?;
2001 Ok(())
2002}
2003
2004/// Updates other participants in the room with your current location.
2005async fn update_participant_location(
2006 request: proto::UpdateParticipantLocation,
2007 response: Response<proto::UpdateParticipantLocation>,
2008 session: UserSession,
2009) -> Result<()> {
2010 let room_id = RoomId::from_proto(request.room_id);
2011 let location = request
2012 .location
2013 .ok_or_else(|| anyhow!("invalid location"))?;
2014
2015 let db = session.db().await;
2016 let room = db
2017 .update_room_participant_location(room_id, session.connection_id, location)
2018 .await?;
2019
2020 room_updated(&room, &session.peer);
2021 response.send(proto::Ack {})?;
2022 Ok(())
2023}
2024
2025/// Share a project into the room.
2026async fn share_project(
2027 request: proto::ShareProject,
2028 response: Response<proto::ShareProject>,
2029 session: UserSession,
2030) -> Result<()> {
2031 let (project_id, room) = &*session
2032 .db()
2033 .await
2034 .share_project(
2035 RoomId::from_proto(request.room_id),
2036 session.connection_id,
2037 &request.worktrees,
2038 request
2039 .dev_server_project_id
2040 .map(|id| DevServerProjectId::from_proto(id)),
2041 )
2042 .await?;
2043 response.send(proto::ShareProjectResponse {
2044 project_id: project_id.to_proto(),
2045 })?;
2046 room_updated(&room, &session.peer);
2047
2048 Ok(())
2049}
2050
2051/// Unshare a project from the room.
2052async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
2053 let project_id = ProjectId::from_proto(message.project_id);
2054 unshare_project_internal(
2055 project_id,
2056 session.connection_id,
2057 session.user_id(),
2058 &session,
2059 )
2060 .await
2061}
2062
2063async fn unshare_project_internal(
2064 project_id: ProjectId,
2065 connection_id: ConnectionId,
2066 user_id: Option<UserId>,
2067 session: &Session,
2068) -> Result<()> {
2069 let delete = {
2070 let room_guard = session
2071 .db()
2072 .await
2073 .unshare_project(project_id, connection_id, user_id)
2074 .await?;
2075
2076 let (delete, room, guest_connection_ids) = &*room_guard;
2077
2078 let message = proto::UnshareProject {
2079 project_id: project_id.to_proto(),
2080 };
2081
2082 broadcast(
2083 Some(connection_id),
2084 guest_connection_ids.iter().copied(),
2085 |conn_id| session.peer.send(conn_id, message.clone()),
2086 );
2087 if let Some(room) = room {
2088 room_updated(room, &session.peer);
2089 }
2090
2091 *delete
2092 };
2093
2094 if delete {
2095 let db = session.db().await;
2096 db.delete_project(project_id).await?;
2097 }
2098
2099 Ok(())
2100}
2101
2102/// DevServer makes a project available online
2103async fn share_dev_server_project(
2104 request: proto::ShareDevServerProject,
2105 response: Response<proto::ShareDevServerProject>,
2106 session: DevServerSession,
2107) -> Result<()> {
2108 let (dev_server_project, user_id, status) = session
2109 .db()
2110 .await
2111 .share_dev_server_project(
2112 DevServerProjectId::from_proto(request.dev_server_project_id),
2113 session.dev_server_id(),
2114 session.connection_id,
2115 &request.worktrees,
2116 )
2117 .await?;
2118 let Some(project_id) = dev_server_project.project_id else {
2119 return Err(anyhow!("failed to share remote project"))?;
2120 };
2121
2122 send_dev_server_projects_update(user_id, status, &session).await;
2123
2124 response.send(proto::ShareProjectResponse { project_id })?;
2125
2126 Ok(())
2127}
2128
2129/// Join someone elses shared project.
2130async fn join_project(
2131 request: proto::JoinProject,
2132 response: Response<proto::JoinProject>,
2133 session: UserSession,
2134) -> Result<()> {
2135 let project_id = ProjectId::from_proto(request.project_id);
2136
2137 tracing::info!(%project_id, "join project");
2138
2139 let db = session.db().await;
2140 let (project, replica_id) = &mut *db
2141 .join_project(project_id, session.connection_id, session.user_id())
2142 .await?;
2143 drop(db);
2144 tracing::info!(%project_id, "join remote project");
2145 join_project_internal(response, session, project, replica_id)
2146}
2147
/// Reply handle for the requests that are served by `join_project_internal`.
///
/// Both `JoinProject` and `JoinHostedProject` are answered with a
/// `proto::JoinProjectResponse`; this trait lets the shared streaming code
/// send that reply without knowing which request it is answering.
trait JoinProjectInternalResponse {
    /// Sends `result` as the reply to the originating request.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
/// Replies to a `JoinProject` request by delegating to
/// `Response::<proto::JoinProject>::send`.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinProject>::send(self, result)
    }
}
/// Replies to a `JoinHostedProject` request by delegating to
/// `Response::<proto::JoinHostedProject>::send`.
impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinHostedProject>::send(self, result)
    }
}
2161
2162fn join_project_internal(
2163 response: impl JoinProjectInternalResponse,
2164 session: UserSession,
2165 project: &mut Project,
2166 replica_id: &ReplicaId,
2167) -> Result<()> {
2168 let collaborators = project
2169 .collaborators
2170 .iter()
2171 .filter(|collaborator| collaborator.connection_id != session.connection_id)
2172 .map(|collaborator| collaborator.to_proto())
2173 .collect::<Vec<_>>();
2174 let project_id = project.id;
2175 let guest_user_id = session.user_id();
2176
2177 let worktrees = project
2178 .worktrees
2179 .iter()
2180 .map(|(id, worktree)| proto::WorktreeMetadata {
2181 id: *id,
2182 root_name: worktree.root_name.clone(),
2183 visible: worktree.visible,
2184 abs_path: worktree.abs_path.clone(),
2185 })
2186 .collect::<Vec<_>>();
2187
2188 let add_project_collaborator = proto::AddProjectCollaborator {
2189 project_id: project_id.to_proto(),
2190 collaborator: Some(proto::Collaborator {
2191 peer_id: Some(session.connection_id.into()),
2192 replica_id: replica_id.0 as u32,
2193 user_id: guest_user_id.to_proto(),
2194 }),
2195 };
2196
2197 for collaborator in &collaborators {
2198 session
2199 .peer
2200 .send(
2201 collaborator.peer_id.unwrap().into(),
2202 add_project_collaborator.clone(),
2203 )
2204 .trace_err();
2205 }
2206
2207 // First, we send the metadata associated with each worktree.
2208 response.send(proto::JoinProjectResponse {
2209 project_id: project.id.0 as u64,
2210 worktrees: worktrees.clone(),
2211 replica_id: replica_id.0 as u32,
2212 collaborators: collaborators.clone(),
2213 language_servers: project.language_servers.clone(),
2214 role: project.role.into(),
2215 dev_server_project_id: project
2216 .dev_server_project_id
2217 .map(|dev_server_project_id| dev_server_project_id.0 as u64),
2218 })?;
2219
2220 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2221 #[cfg(any(test, feature = "test-support"))]
2222 const MAX_CHUNK_SIZE: usize = 2;
2223 #[cfg(not(any(test, feature = "test-support")))]
2224 const MAX_CHUNK_SIZE: usize = 256;
2225
2226 // Stream this worktree's entries.
2227 let message = proto::UpdateWorktree {
2228 project_id: project_id.to_proto(),
2229 worktree_id,
2230 abs_path: worktree.abs_path.clone(),
2231 root_name: worktree.root_name,
2232 updated_entries: worktree.entries,
2233 removed_entries: Default::default(),
2234 scan_id: worktree.scan_id,
2235 is_last_update: worktree.scan_id == worktree.completed_scan_id,
2236 updated_repositories: worktree.repository_entries.into_values().collect(),
2237 removed_repositories: Default::default(),
2238 };
2239 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
2240 session.peer.send(session.connection_id, update.clone())?;
2241 }
2242
2243 // Stream this worktree's diagnostics.
2244 for summary in worktree.diagnostic_summaries {
2245 session.peer.send(
2246 session.connection_id,
2247 proto::UpdateDiagnosticSummary {
2248 project_id: project_id.to_proto(),
2249 worktree_id: worktree.id,
2250 summary: Some(summary),
2251 },
2252 )?;
2253 }
2254
2255 for settings_file in worktree.settings_files {
2256 session.peer.send(
2257 session.connection_id,
2258 proto::UpdateWorktreeSettings {
2259 project_id: project_id.to_proto(),
2260 worktree_id: worktree.id,
2261 path: settings_file.path,
2262 content: Some(settings_file.content),
2263 },
2264 )?;
2265 }
2266 }
2267
2268 for language_server in &project.language_servers {
2269 session.peer.send(
2270 session.connection_id,
2271 proto::UpdateLanguageServer {
2272 project_id: project_id.to_proto(),
2273 language_server_id: language_server.id,
2274 variant: Some(
2275 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2276 proto::LspDiskBasedDiagnosticsUpdated {},
2277 ),
2278 ),
2279 },
2280 )?;
2281 }
2282
2283 Ok(())
2284}
2285
2286/// Leave someone elses shared project.
2287async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2288 let sender_id = session.connection_id;
2289 let project_id = ProjectId::from_proto(request.project_id);
2290 let db = session.db().await;
2291 if db.is_hosted_project(project_id).await? {
2292 let project = db.leave_hosted_project(project_id, sender_id).await?;
2293 project_left(&project, &session);
2294 return Ok(());
2295 }
2296
2297 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2298 tracing::info!(
2299 %project_id,
2300 "leave project"
2301 );
2302
2303 project_left(&project, &session);
2304 if let Some(room) = room {
2305 room_updated(&room, &session.peer);
2306 }
2307
2308 Ok(())
2309}
2310
2311async fn join_hosted_project(
2312 request: proto::JoinHostedProject,
2313 response: Response<proto::JoinHostedProject>,
2314 session: UserSession,
2315) -> Result<()> {
2316 let (mut project, replica_id) = session
2317 .db()
2318 .await
2319 .join_hosted_project(
2320 ProjectId(request.project_id as i32),
2321 session.user_id(),
2322 session.connection_id,
2323 )
2324 .await?;
2325
2326 join_project_internal(response, session, &mut project, &replica_id)
2327}
2328
2329async fn list_remote_directory(
2330 request: proto::ListRemoteDirectory,
2331 response: Response<proto::ListRemoteDirectory>,
2332 session: UserSession,
2333) -> Result<()> {
2334 let dev_server_id = DevServerId(request.dev_server_id as i32);
2335 let dev_server_connection_id = session
2336 .connection_pool()
2337 .await
2338 .dev_server_connection_id_supporting(dev_server_id, ZedVersion::with_list_directory())?;
2339
2340 session
2341 .db()
2342 .await
2343 .get_dev_server_for_user(dev_server_id, session.user_id())
2344 .await?;
2345
2346 response.send(
2347 session
2348 .peer
2349 .forward_request(session.connection_id, dev_server_connection_id, request)
2350 .await?,
2351 )?;
2352 Ok(())
2353}
2354
2355async fn update_dev_server_project(
2356 request: proto::UpdateDevServerProject,
2357 response: Response<proto::UpdateDevServerProject>,
2358 session: UserSession,
2359) -> Result<()> {
2360 let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2361
2362 let (dev_server_project, update) = session
2363 .db()
2364 .await
2365 .update_dev_server_project(dev_server_project_id, &request.paths, session.user_id())
2366 .await?;
2367
2368 let projects = session
2369 .db()
2370 .await
2371 .get_projects_for_dev_server(dev_server_project.dev_server_id)
2372 .await?;
2373
2374 let dev_server_connection_id = session
2375 .connection_pool()
2376 .await
2377 .dev_server_connection_id_supporting(
2378 dev_server_project.dev_server_id,
2379 ZedVersion::with_list_directory(),
2380 )?;
2381
2382 session.peer.send(
2383 dev_server_connection_id,
2384 proto::DevServerInstructions { projects },
2385 )?;
2386
2387 send_dev_server_projects_update(session.user_id(), update, &session).await;
2388
2389 response.send(proto::Ack {})
2390}
2391
2392async fn create_dev_server_project(
2393 request: proto::CreateDevServerProject,
2394 response: Response<proto::CreateDevServerProject>,
2395 session: UserSession,
2396) -> Result<()> {
2397 let dev_server_id = DevServerId(request.dev_server_id as i32);
2398 let dev_server_connection_id = session
2399 .connection_pool()
2400 .await
2401 .dev_server_connection_id(dev_server_id);
2402 let Some(dev_server_connection_id) = dev_server_connection_id else {
2403 Err(ErrorCode::DevServerOffline
2404 .message("Cannot create a remote project when the dev server is offline".to_string())
2405 .anyhow())?
2406 };
2407
2408 let path = request.path.clone();
2409 //Check that the path exists on the dev server
2410 session
2411 .peer
2412 .forward_request(
2413 session.connection_id,
2414 dev_server_connection_id,
2415 proto::ValidateDevServerProjectRequest { path: path.clone() },
2416 )
2417 .await?;
2418
2419 let (dev_server_project, update) = session
2420 .db()
2421 .await
2422 .create_dev_server_project(
2423 DevServerId(request.dev_server_id as i32),
2424 &request.path,
2425 session.user_id(),
2426 )
2427 .await?;
2428
2429 let projects = session
2430 .db()
2431 .await
2432 .get_projects_for_dev_server(dev_server_project.dev_server_id)
2433 .await?;
2434
2435 session.peer.send(
2436 dev_server_connection_id,
2437 proto::DevServerInstructions { projects },
2438 )?;
2439
2440 send_dev_server_projects_update(session.user_id(), update, &session).await;
2441
2442 response.send(proto::CreateDevServerProjectResponse {
2443 dev_server_project: Some(dev_server_project.to_proto(None)),
2444 })?;
2445 Ok(())
2446}
2447
2448async fn create_dev_server(
2449 request: proto::CreateDevServer,
2450 response: Response<proto::CreateDevServer>,
2451 session: UserSession,
2452) -> Result<()> {
2453 let access_token = auth::random_token();
2454 let hashed_access_token = auth::hash_access_token(&access_token);
2455
2456 if request.name.is_empty() {
2457 return Err(proto::ErrorCode::Forbidden
2458 .message("Dev server name cannot be empty".to_string())
2459 .anyhow())?;
2460 }
2461
2462 let (dev_server, status) = session
2463 .db()
2464 .await
2465 .create_dev_server(
2466 &request.name,
2467 request.ssh_connection_string.as_deref(),
2468 &hashed_access_token,
2469 session.user_id(),
2470 )
2471 .await?;
2472
2473 send_dev_server_projects_update(session.user_id(), status, &session).await;
2474
2475 response.send(proto::CreateDevServerResponse {
2476 dev_server_id: dev_server.id.0 as u64,
2477 access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2478 name: request.name,
2479 })?;
2480 Ok(())
2481}
2482
2483async fn regenerate_dev_server_token(
2484 request: proto::RegenerateDevServerToken,
2485 response: Response<proto::RegenerateDevServerToken>,
2486 session: UserSession,
2487) -> Result<()> {
2488 let dev_server_id = DevServerId(request.dev_server_id as i32);
2489 let access_token = auth::random_token();
2490 let hashed_access_token = auth::hash_access_token(&access_token);
2491
2492 let connection_id = session
2493 .connection_pool()
2494 .await
2495 .dev_server_connection_id(dev_server_id);
2496 if let Some(connection_id) = connection_id {
2497 shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2498 session.peer.send(
2499 connection_id,
2500 proto::ShutdownDevServer {
2501 reason: Some("dev server token was regenerated".to_string()),
2502 },
2503 )?;
2504 let _ = remove_dev_server_connection(dev_server_id, &session).await;
2505 }
2506
2507 let status = session
2508 .db()
2509 .await
2510 .update_dev_server_token(dev_server_id, &hashed_access_token, session.user_id())
2511 .await?;
2512
2513 send_dev_server_projects_update(session.user_id(), status, &session).await;
2514
2515 response.send(proto::RegenerateDevServerTokenResponse {
2516 dev_server_id: dev_server_id.to_proto(),
2517 access_token: auth::generate_dev_server_token(dev_server_id.0 as usize, access_token),
2518 })?;
2519 Ok(())
2520}
2521
2522async fn rename_dev_server(
2523 request: proto::RenameDevServer,
2524 response: Response<proto::RenameDevServer>,
2525 session: UserSession,
2526) -> Result<()> {
2527 if request.name.trim().is_empty() {
2528 return Err(proto::ErrorCode::Forbidden
2529 .message("Dev server name cannot be empty".to_string())
2530 .anyhow())?;
2531 }
2532
2533 let dev_server_id = DevServerId(request.dev_server_id as i32);
2534 let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2535 if dev_server.user_id != session.user_id() {
2536 return Err(anyhow!(ErrorCode::Forbidden))?;
2537 }
2538
2539 let status = session
2540 .db()
2541 .await
2542 .rename_dev_server(
2543 dev_server_id,
2544 &request.name,
2545 request.ssh_connection_string.as_deref(),
2546 session.user_id(),
2547 )
2548 .await?;
2549
2550 send_dev_server_projects_update(session.user_id(), status, &session).await;
2551
2552 response.send(proto::Ack {})?;
2553 Ok(())
2554}
2555
2556async fn delete_dev_server(
2557 request: proto::DeleteDevServer,
2558 response: Response<proto::DeleteDevServer>,
2559 session: UserSession,
2560) -> Result<()> {
2561 let dev_server_id = DevServerId(request.dev_server_id as i32);
2562 let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2563 if dev_server.user_id != session.user_id() {
2564 return Err(anyhow!(ErrorCode::Forbidden))?;
2565 }
2566
2567 let connection_id = session
2568 .connection_pool()
2569 .await
2570 .dev_server_connection_id(dev_server_id);
2571 if let Some(connection_id) = connection_id {
2572 shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2573 session.peer.send(
2574 connection_id,
2575 proto::ShutdownDevServer {
2576 reason: Some("dev server was deleted".to_string()),
2577 },
2578 )?;
2579 let _ = remove_dev_server_connection(dev_server_id, &session).await;
2580 }
2581
2582 let status = session
2583 .db()
2584 .await
2585 .delete_dev_server(dev_server_id, session.user_id())
2586 .await?;
2587
2588 send_dev_server_projects_update(session.user_id(), status, &session).await;
2589
2590 response.send(proto::Ack {})?;
2591 Ok(())
2592}
2593
2594async fn delete_dev_server_project(
2595 request: proto::DeleteDevServerProject,
2596 response: Response<proto::DeleteDevServerProject>,
2597 session: UserSession,
2598) -> Result<()> {
2599 let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2600 let dev_server_project = session
2601 .db()
2602 .await
2603 .get_dev_server_project(dev_server_project_id)
2604 .await?;
2605
2606 let dev_server = session
2607 .db()
2608 .await
2609 .get_dev_server(dev_server_project.dev_server_id)
2610 .await?;
2611 if dev_server.user_id != session.user_id() {
2612 return Err(anyhow!(ErrorCode::Forbidden))?;
2613 }
2614
2615 let dev_server_connection_id = session
2616 .connection_pool()
2617 .await
2618 .dev_server_connection_id(dev_server.id);
2619
2620 if let Some(dev_server_connection_id) = dev_server_connection_id {
2621 let project = session
2622 .db()
2623 .await
2624 .find_dev_server_project(dev_server_project_id)
2625 .await;
2626 if let Ok(project) = project {
2627 unshare_project_internal(
2628 project.id,
2629 dev_server_connection_id,
2630 Some(session.user_id()),
2631 &session,
2632 )
2633 .await?;
2634 }
2635 }
2636
2637 let (projects, status) = session
2638 .db()
2639 .await
2640 .delete_dev_server_project(dev_server_project_id, dev_server.id, session.user_id())
2641 .await?;
2642
2643 if let Some(dev_server_connection_id) = dev_server_connection_id {
2644 session.peer.send(
2645 dev_server_connection_id,
2646 proto::DevServerInstructions { projects },
2647 )?;
2648 }
2649
2650 send_dev_server_projects_update(session.user_id(), status, &session).await;
2651
2652 response.send(proto::Ack {})?;
2653 Ok(())
2654}
2655
2656async fn rejoin_dev_server_projects(
2657 request: proto::RejoinRemoteProjects,
2658 response: Response<proto::RejoinRemoteProjects>,
2659 session: UserSession,
2660) -> Result<()> {
2661 let mut rejoined_projects = {
2662 let db = session.db().await;
2663 db.rejoin_dev_server_projects(
2664 &request.rejoined_projects,
2665 session.user_id(),
2666 session.0.connection_id,
2667 )
2668 .await?
2669 };
2670 response.send(proto::RejoinRemoteProjectsResponse {
2671 rejoined_projects: rejoined_projects
2672 .iter()
2673 .map(|project| project.to_proto())
2674 .collect(),
2675 })?;
2676 notify_rejoined_projects(&mut rejoined_projects, &session)
2677}
2678
2679async fn reconnect_dev_server(
2680 request: proto::ReconnectDevServer,
2681 response: Response<proto::ReconnectDevServer>,
2682 session: DevServerSession,
2683) -> Result<()> {
2684 let reshared_projects = {
2685 let db = session.db().await;
2686 db.reshare_dev_server_projects(
2687 &request.reshared_projects,
2688 session.dev_server_id(),
2689 session.0.connection_id,
2690 )
2691 .await?
2692 };
2693
2694 for project in &reshared_projects {
2695 for collaborator in &project.collaborators {
2696 session
2697 .peer
2698 .send(
2699 collaborator.connection_id,
2700 proto::UpdateProjectCollaborator {
2701 project_id: project.id.to_proto(),
2702 old_peer_id: Some(project.old_connection_id.into()),
2703 new_peer_id: Some(session.connection_id.into()),
2704 },
2705 )
2706 .trace_err();
2707 }
2708
2709 broadcast(
2710 Some(session.connection_id),
2711 project
2712 .collaborators
2713 .iter()
2714 .map(|collaborator| collaborator.connection_id),
2715 |connection_id| {
2716 session.peer.forward_send(
2717 session.connection_id,
2718 connection_id,
2719 proto::UpdateProject {
2720 project_id: project.id.to_proto(),
2721 worktrees: project.worktrees.clone(),
2722 },
2723 )
2724 },
2725 );
2726 }
2727
2728 response.send(proto::ReconnectDevServerResponse {
2729 reshared_projects: reshared_projects
2730 .iter()
2731 .map(|project| proto::ResharedProject {
2732 id: project.id.to_proto(),
2733 collaborators: project
2734 .collaborators
2735 .iter()
2736 .map(|collaborator| collaborator.to_proto())
2737 .collect(),
2738 })
2739 .collect(),
2740 })?;
2741
2742 Ok(())
2743}
2744
2745async fn shutdown_dev_server(
2746 _: proto::ShutdownDevServer,
2747 response: Response<proto::ShutdownDevServer>,
2748 session: DevServerSession,
2749) -> Result<()> {
2750 response.send(proto::Ack {})?;
2751 shutdown_dev_server_internal(session.dev_server_id(), session.connection_id, &session).await?;
2752 remove_dev_server_connection(session.dev_server_id(), &session).await
2753}
2754
2755async fn shutdown_dev_server_internal(
2756 dev_server_id: DevServerId,
2757 connection_id: ConnectionId,
2758 session: &Session,
2759) -> Result<()> {
2760 let (dev_server_projects, dev_server) = {
2761 let db = session.db().await;
2762 let dev_server_projects = db.get_projects_for_dev_server(dev_server_id).await?;
2763 let dev_server = db.get_dev_server(dev_server_id).await?;
2764 (dev_server_projects, dev_server)
2765 };
2766
2767 for project_id in dev_server_projects.iter().filter_map(|p| p.project_id) {
2768 unshare_project_internal(
2769 ProjectId::from_proto(project_id),
2770 connection_id,
2771 None,
2772 session,
2773 )
2774 .await?;
2775 }
2776
2777 session
2778 .connection_pool()
2779 .await
2780 .set_dev_server_offline(dev_server_id);
2781
2782 let status = session
2783 .db()
2784 .await
2785 .dev_server_projects_update(dev_server.user_id)
2786 .await?;
2787 send_dev_server_projects_update(dev_server.user_id, status, &session).await;
2788
2789 Ok(())
2790}
2791
2792async fn remove_dev_server_connection(dev_server_id: DevServerId, session: &Session) -> Result<()> {
2793 let dev_server_connection = session
2794 .connection_pool()
2795 .await
2796 .dev_server_connection_id(dev_server_id);
2797
2798 if let Some(dev_server_connection) = dev_server_connection {
2799 session
2800 .connection_pool()
2801 .await
2802 .remove_connection(dev_server_connection)?;
2803 }
2804 Ok(())
2805}
2806
2807/// Updates other participants with changes to the project
2808async fn update_project(
2809 request: proto::UpdateProject,
2810 response: Response<proto::UpdateProject>,
2811 session: Session,
2812) -> Result<()> {
2813 let project_id = ProjectId::from_proto(request.project_id);
2814 let (room, guest_connection_ids) = &*session
2815 .db()
2816 .await
2817 .update_project(project_id, session.connection_id, &request.worktrees)
2818 .await?;
2819 broadcast(
2820 Some(session.connection_id),
2821 guest_connection_ids.iter().copied(),
2822 |connection_id| {
2823 session
2824 .peer
2825 .forward_send(session.connection_id, connection_id, request.clone())
2826 },
2827 );
2828 if let Some(room) = room {
2829 room_updated(&room, &session.peer);
2830 }
2831 response.send(proto::Ack {})?;
2832
2833 Ok(())
2834}
2835
2836/// Updates other participants with changes to the worktree
2837async fn update_worktree(
2838 request: proto::UpdateWorktree,
2839 response: Response<proto::UpdateWorktree>,
2840 session: Session,
2841) -> Result<()> {
2842 let guest_connection_ids = session
2843 .db()
2844 .await
2845 .update_worktree(&request, session.connection_id)
2846 .await?;
2847
2848 broadcast(
2849 Some(session.connection_id),
2850 guest_connection_ids.iter().copied(),
2851 |connection_id| {
2852 session
2853 .peer
2854 .forward_send(session.connection_id, connection_id, request.clone())
2855 },
2856 );
2857 response.send(proto::Ack {})?;
2858 Ok(())
2859}
2860
2861/// Updates other participants with changes to the diagnostics
2862async fn update_diagnostic_summary(
2863 message: proto::UpdateDiagnosticSummary,
2864 session: Session,
2865) -> Result<()> {
2866 let guest_connection_ids = session
2867 .db()
2868 .await
2869 .update_diagnostic_summary(&message, session.connection_id)
2870 .await?;
2871
2872 broadcast(
2873 Some(session.connection_id),
2874 guest_connection_ids.iter().copied(),
2875 |connection_id| {
2876 session
2877 .peer
2878 .forward_send(session.connection_id, connection_id, message.clone())
2879 },
2880 );
2881
2882 Ok(())
2883}
2884
2885/// Updates other participants with changes to the worktree settings
2886async fn update_worktree_settings(
2887 message: proto::UpdateWorktreeSettings,
2888 session: Session,
2889) -> Result<()> {
2890 let guest_connection_ids = session
2891 .db()
2892 .await
2893 .update_worktree_settings(&message, session.connection_id)
2894 .await?;
2895
2896 broadcast(
2897 Some(session.connection_id),
2898 guest_connection_ids.iter().copied(),
2899 |connection_id| {
2900 session
2901 .peer
2902 .forward_send(session.connection_id, connection_id, message.clone())
2903 },
2904 );
2905
2906 Ok(())
2907}
2908
2909/// Notify other participants that a language server has started.
2910async fn start_language_server(
2911 request: proto::StartLanguageServer,
2912 session: Session,
2913) -> Result<()> {
2914 let guest_connection_ids = session
2915 .db()
2916 .await
2917 .start_language_server(&request, session.connection_id)
2918 .await?;
2919
2920 broadcast(
2921 Some(session.connection_id),
2922 guest_connection_ids.iter().copied(),
2923 |connection_id| {
2924 session
2925 .peer
2926 .forward_send(session.connection_id, connection_id, request.clone())
2927 },
2928 );
2929 Ok(())
2930}
2931
2932/// Notify other participants that a language server has changed.
2933async fn update_language_server(
2934 request: proto::UpdateLanguageServer,
2935 session: Session,
2936) -> Result<()> {
2937 let project_id = ProjectId::from_proto(request.project_id);
2938 let project_connection_ids = session
2939 .db()
2940 .await
2941 .project_connection_ids(project_id, session.connection_id, true)
2942 .await?;
2943 broadcast(
2944 Some(session.connection_id),
2945 project_connection_ids.iter().copied(),
2946 |connection_id| {
2947 session
2948 .peer
2949 .forward_send(session.connection_id, connection_id, request.clone())
2950 },
2951 );
2952 Ok(())
2953}
2954
2955/// forward a project request to the host. These requests should be read only
2956/// as guests are allowed to send them.
2957async fn forward_read_only_project_request<T>(
2958 request: T,
2959 response: Response<T>,
2960 session: UserSession,
2961) -> Result<()>
2962where
2963 T: EntityMessage + RequestMessage,
2964{
2965 let project_id = ProjectId::from_proto(request.remote_entity_id());
2966 let host_connection_id = session
2967 .db()
2968 .await
2969 .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2970 .await?;
2971 let payload = session
2972 .peer
2973 .forward_request(session.connection_id, host_connection_id, request)
2974 .await?;
2975 response.send(payload)?;
2976 Ok(())
2977}
2978
2979/// forward a project request to the dev server. Only allowed
2980/// if it's your dev server.
2981async fn forward_project_request_for_owner<T>(
2982 request: T,
2983 response: Response<T>,
2984 session: UserSession,
2985) -> Result<()>
2986where
2987 T: EntityMessage + RequestMessage,
2988{
2989 let project_id = ProjectId::from_proto(request.remote_entity_id());
2990
2991 let host_connection_id = session
2992 .db()
2993 .await
2994 .host_for_owner_project_request(project_id, session.connection_id, session.user_id())
2995 .await?;
2996 let payload = session
2997 .peer
2998 .forward_request(session.connection_id, host_connection_id, request)
2999 .await?;
3000 response.send(payload)?;
3001 Ok(())
3002}
3003
3004/// forward a project request to the host. These requests are disallowed
3005/// for guests.
3006async fn forward_mutating_project_request<T>(
3007 request: T,
3008 response: Response<T>,
3009 session: UserSession,
3010) -> Result<()>
3011where
3012 T: EntityMessage + RequestMessage,
3013{
3014 let project_id = ProjectId::from_proto(request.remote_entity_id());
3015
3016 let host_connection_id = session
3017 .db()
3018 .await
3019 .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
3020 .await?;
3021 let payload = session
3022 .peer
3023 .forward_request(session.connection_id, host_connection_id, request)
3024 .await?;
3025 response.send(payload)?;
3026 Ok(())
3027}
3028
3029/// forward a project request to the host. These requests are disallowed
3030/// for guests.
3031async fn forward_versioned_mutating_project_request<T>(
3032 request: T,
3033 response: Response<T>,
3034 session: UserSession,
3035) -> Result<()>
3036where
3037 T: EntityMessage + RequestMessage + VersionedMessage,
3038{
3039 let project_id = ProjectId::from_proto(request.remote_entity_id());
3040
3041 let host_connection_id = session
3042 .db()
3043 .await
3044 .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
3045 .await?;
3046 if let Some(host_version) = session
3047 .connection_pool()
3048 .await
3049 .connection(host_connection_id)
3050 .map(|c| c.zed_version)
3051 {
3052 if let Some(min_required_version) = request.required_host_version() {
3053 if min_required_version > host_version {
3054 return Err(anyhow!(ErrorCode::RemoteUpgradeRequired
3055 .with_tag("required", &min_required_version.to_string())))?;
3056 }
3057 }
3058 }
3059
3060 let payload = session
3061 .peer
3062 .forward_request(session.connection_id, host_connection_id, request)
3063 .await?;
3064 response.send(payload)?;
3065 Ok(())
3066}
3067
3068/// Notify other participants that a new buffer has been created
3069async fn create_buffer_for_peer(
3070 request: proto::CreateBufferForPeer,
3071 session: Session,
3072) -> Result<()> {
3073 session
3074 .db()
3075 .await
3076 .check_user_is_project_host(
3077 ProjectId::from_proto(request.project_id),
3078 session.connection_id,
3079 )
3080 .await?;
3081 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
3082 session
3083 .peer
3084 .forward_send(session.connection_id, peer_id.into(), request)?;
3085 Ok(())
3086}
3087
3088/// Notify other participants that a buffer has been updated. This is
3089/// allowed for guests as long as the update is limited to selections.
3090async fn update_buffer(
3091 request: proto::UpdateBuffer,
3092 response: Response<proto::UpdateBuffer>,
3093 session: Session,
3094) -> Result<()> {
3095 let project_id = ProjectId::from_proto(request.project_id);
3096 let mut capability = Capability::ReadOnly;
3097
3098 for op in request.operations.iter() {
3099 match op.variant {
3100 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
3101 Some(_) => capability = Capability::ReadWrite,
3102 }
3103 }
3104
3105 let host = {
3106 let guard = session
3107 .db()
3108 .await
3109 .connections_for_buffer_update(
3110 project_id,
3111 session.principal_id(),
3112 session.connection_id,
3113 capability,
3114 )
3115 .await?;
3116
3117 let (host, guests) = &*guard;
3118
3119 broadcast(
3120 Some(session.connection_id),
3121 guests.clone(),
3122 |connection_id| {
3123 session
3124 .peer
3125 .forward_send(session.connection_id, connection_id, request.clone())
3126 },
3127 );
3128
3129 *host
3130 };
3131
3132 if host != session.connection_id {
3133 session
3134 .peer
3135 .forward_request(session.connection_id, host, request.clone())
3136 .await?;
3137 }
3138
3139 response.send(proto::Ack {})?;
3140 Ok(())
3141}
3142
3143async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
3144 let project_id = ProjectId::from_proto(message.project_id);
3145
3146 let operation = message.operation.as_ref().context("invalid operation")?;
3147 let capability = match operation.variant.as_ref() {
3148 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
3149 if let Some(buffer_op) = buffer_op.operation.as_ref() {
3150 match buffer_op.variant {
3151 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
3152 Capability::ReadOnly
3153 }
3154 _ => Capability::ReadWrite,
3155 }
3156 } else {
3157 Capability::ReadWrite
3158 }
3159 }
3160 Some(_) => Capability::ReadWrite,
3161 None => Capability::ReadOnly,
3162 };
3163
3164 let guard = session
3165 .db()
3166 .await
3167 .connections_for_buffer_update(
3168 project_id,
3169 session.principal_id(),
3170 session.connection_id,
3171 capability,
3172 )
3173 .await?;
3174
3175 let (host, guests) = &*guard;
3176
3177 broadcast(
3178 Some(session.connection_id),
3179 guests.iter().chain([host]).copied(),
3180 |connection_id| {
3181 session
3182 .peer
3183 .forward_send(session.connection_id, connection_id, message.clone())
3184 },
3185 );
3186
3187 Ok(())
3188}
3189
3190/// Notify other participants that a project has been updated.
3191async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
3192 request: T,
3193 session: Session,
3194) -> Result<()> {
3195 let project_id = ProjectId::from_proto(request.remote_entity_id());
3196 let project_connection_ids = session
3197 .db()
3198 .await
3199 .project_connection_ids(project_id, session.connection_id, false)
3200 .await?;
3201
3202 broadcast(
3203 Some(session.connection_id),
3204 project_connection_ids.iter().copied(),
3205 |connection_id| {
3206 session
3207 .peer
3208 .forward_send(session.connection_id, connection_id, request.clone())
3209 },
3210 );
3211 Ok(())
3212}
3213
3214/// Start following another user in a call.
3215async fn follow(
3216 request: proto::Follow,
3217 response: Response<proto::Follow>,
3218 session: UserSession,
3219) -> Result<()> {
3220 let room_id = RoomId::from_proto(request.room_id);
3221 let project_id = request.project_id.map(ProjectId::from_proto);
3222 let leader_id = request
3223 .leader_id
3224 .ok_or_else(|| anyhow!("invalid leader id"))?
3225 .into();
3226 let follower_id = session.connection_id;
3227
3228 session
3229 .db()
3230 .await
3231 .check_room_participants(room_id, leader_id, session.connection_id)
3232 .await?;
3233
3234 let response_payload = session
3235 .peer
3236 .forward_request(session.connection_id, leader_id, request)
3237 .await?;
3238 response.send(response_payload)?;
3239
3240 if let Some(project_id) = project_id {
3241 let room = session
3242 .db()
3243 .await
3244 .follow(room_id, project_id, leader_id, follower_id)
3245 .await?;
3246 room_updated(&room, &session.peer);
3247 }
3248
3249 Ok(())
3250}
3251
3252/// Stop following another user in a call.
3253async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
3254 let room_id = RoomId::from_proto(request.room_id);
3255 let project_id = request.project_id.map(ProjectId::from_proto);
3256 let leader_id = request
3257 .leader_id
3258 .ok_or_else(|| anyhow!("invalid leader id"))?
3259 .into();
3260 let follower_id = session.connection_id;
3261
3262 session
3263 .db()
3264 .await
3265 .check_room_participants(room_id, leader_id, session.connection_id)
3266 .await?;
3267
3268 session
3269 .peer
3270 .forward_send(session.connection_id, leader_id, request)?;
3271
3272 if let Some(project_id) = project_id {
3273 let room = session
3274 .db()
3275 .await
3276 .unfollow(room_id, project_id, leader_id, follower_id)
3277 .await?;
3278 room_updated(&room, &session.peer);
3279 }
3280
3281 Ok(())
3282}
3283
3284/// Notify everyone following you of your current location.
3285async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
3286 let room_id = RoomId::from_proto(request.room_id);
3287 let database = session.db.lock().await;
3288
3289 let connection_ids = if let Some(project_id) = request.project_id {
3290 let project_id = ProjectId::from_proto(project_id);
3291 database
3292 .project_connection_ids(project_id, session.connection_id, true)
3293 .await?
3294 } else {
3295 database
3296 .room_connection_ids(room_id, session.connection_id)
3297 .await?
3298 };
3299
3300 // For now, don't send view update messages back to that view's current leader.
3301 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
3302 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
3303 _ => None,
3304 });
3305
3306 for connection_id in connection_ids.iter().cloned() {
3307 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
3308 session
3309 .peer
3310 .forward_send(session.connection_id, connection_id, request.clone())?;
3311 }
3312 }
3313 Ok(())
3314}
3315
3316/// Get public data about users.
3317async fn get_users(
3318 request: proto::GetUsers,
3319 response: Response<proto::GetUsers>,
3320 session: Session,
3321) -> Result<()> {
3322 let user_ids = request
3323 .user_ids
3324 .into_iter()
3325 .map(UserId::from_proto)
3326 .collect();
3327 let users = session
3328 .db()
3329 .await
3330 .get_users_by_ids(user_ids)
3331 .await?
3332 .into_iter()
3333 .map(|user| proto::User {
3334 id: user.id.to_proto(),
3335 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3336 github_login: user.github_login,
3337 })
3338 .collect();
3339 response.send(proto::UsersResponse { users })?;
3340 Ok(())
3341}
3342
3343/// Search for users (to invite) buy Github login
3344async fn fuzzy_search_users(
3345 request: proto::FuzzySearchUsers,
3346 response: Response<proto::FuzzySearchUsers>,
3347 session: UserSession,
3348) -> Result<()> {
3349 let query = request.query;
3350 let users = match query.len() {
3351 0 => vec![],
3352 1 | 2 => session
3353 .db()
3354 .await
3355 .get_user_by_github_login(&query)
3356 .await?
3357 .into_iter()
3358 .collect(),
3359 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
3360 };
3361 let users = users
3362 .into_iter()
3363 .filter(|user| user.id != session.user_id())
3364 .map(|user| proto::User {
3365 id: user.id.to_proto(),
3366 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3367 github_login: user.github_login,
3368 })
3369 .collect();
3370 response.send(proto::UsersResponse { users })?;
3371 Ok(())
3372}
3373
3374/// Send a contact request to another user.
3375async fn request_contact(
3376 request: proto::RequestContact,
3377 response: Response<proto::RequestContact>,
3378 session: UserSession,
3379) -> Result<()> {
3380 let requester_id = session.user_id();
3381 let responder_id = UserId::from_proto(request.responder_id);
3382 if requester_id == responder_id {
3383 return Err(anyhow!("cannot add yourself as a contact"))?;
3384 }
3385
3386 let notifications = session
3387 .db()
3388 .await
3389 .send_contact_request(requester_id, responder_id)
3390 .await?;
3391
3392 // Update outgoing contact requests of requester
3393 let mut update = proto::UpdateContacts::default();
3394 update.outgoing_requests.push(responder_id.to_proto());
3395 for connection_id in session
3396 .connection_pool()
3397 .await
3398 .user_connection_ids(requester_id)
3399 {
3400 session.peer.send(connection_id, update.clone())?;
3401 }
3402
3403 // Update incoming contact requests of responder
3404 let mut update = proto::UpdateContacts::default();
3405 update
3406 .incoming_requests
3407 .push(proto::IncomingContactRequest {
3408 requester_id: requester_id.to_proto(),
3409 });
3410 let connection_pool = session.connection_pool().await;
3411 for connection_id in connection_pool.user_connection_ids(responder_id) {
3412 session.peer.send(connection_id, update.clone())?;
3413 }
3414
3415 send_notifications(&connection_pool, &session.peer, notifications);
3416
3417 response.send(proto::Ack {})?;
3418 Ok(())
3419}
3420
3421/// Accept or decline a contact request
3422async fn respond_to_contact_request(
3423 request: proto::RespondToContactRequest,
3424 response: Response<proto::RespondToContactRequest>,
3425 session: UserSession,
3426) -> Result<()> {
3427 let responder_id = session.user_id();
3428 let requester_id = UserId::from_proto(request.requester_id);
3429 let db = session.db().await;
3430 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
3431 db.dismiss_contact_notification(responder_id, requester_id)
3432 .await?;
3433 } else {
3434 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
3435
3436 let notifications = db
3437 .respond_to_contact_request(responder_id, requester_id, accept)
3438 .await?;
3439 let requester_busy = db.is_user_busy(requester_id).await?;
3440 let responder_busy = db.is_user_busy(responder_id).await?;
3441
3442 let pool = session.connection_pool().await;
3443 // Update responder with new contact
3444 let mut update = proto::UpdateContacts::default();
3445 if accept {
3446 update
3447 .contacts
3448 .push(contact_for_user(requester_id, requester_busy, &pool));
3449 }
3450 update
3451 .remove_incoming_requests
3452 .push(requester_id.to_proto());
3453 for connection_id in pool.user_connection_ids(responder_id) {
3454 session.peer.send(connection_id, update.clone())?;
3455 }
3456
3457 // Update requester with new contact
3458 let mut update = proto::UpdateContacts::default();
3459 if accept {
3460 update
3461 .contacts
3462 .push(contact_for_user(responder_id, responder_busy, &pool));
3463 }
3464 update
3465 .remove_outgoing_requests
3466 .push(responder_id.to_proto());
3467
3468 for connection_id in pool.user_connection_ids(requester_id) {
3469 session.peer.send(connection_id, update.clone())?;
3470 }
3471
3472 send_notifications(&pool, &session.peer, notifications);
3473 }
3474
3475 response.send(proto::Ack {})?;
3476 Ok(())
3477}
3478
3479/// Remove a contact.
3480async fn remove_contact(
3481 request: proto::RemoveContact,
3482 response: Response<proto::RemoveContact>,
3483 session: UserSession,
3484) -> Result<()> {
3485 let requester_id = session.user_id();
3486 let responder_id = UserId::from_proto(request.user_id);
3487 let db = session.db().await;
3488 let (contact_accepted, deleted_notification_id) =
3489 db.remove_contact(requester_id, responder_id).await?;
3490
3491 let pool = session.connection_pool().await;
3492 // Update outgoing contact requests of requester
3493 let mut update = proto::UpdateContacts::default();
3494 if contact_accepted {
3495 update.remove_contacts.push(responder_id.to_proto());
3496 } else {
3497 update
3498 .remove_outgoing_requests
3499 .push(responder_id.to_proto());
3500 }
3501 for connection_id in pool.user_connection_ids(requester_id) {
3502 session.peer.send(connection_id, update.clone())?;
3503 }
3504
3505 // Update incoming contact requests of responder
3506 let mut update = proto::UpdateContacts::default();
3507 if contact_accepted {
3508 update.remove_contacts.push(requester_id.to_proto());
3509 } else {
3510 update
3511 .remove_incoming_requests
3512 .push(requester_id.to_proto());
3513 }
3514 for connection_id in pool.user_connection_ids(responder_id) {
3515 session.peer.send(connection_id, update.clone())?;
3516 if let Some(notification_id) = deleted_notification_id {
3517 session.peer.send(
3518 connection_id,
3519 proto::DeleteNotification {
3520 notification_id: notification_id.to_proto(),
3521 },
3522 )?;
3523 }
3524 }
3525
3526 response.send(proto::Ack {})?;
3527 Ok(())
3528}
3529
3530fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
3531 version.0.minor() < 139
3532}
3533
3534async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
3535 subscribe_user_to_channels(
3536 session.user_id().ok_or_else(|| anyhow!("must be a user"))?,
3537 &session,
3538 )
3539 .await?;
3540 Ok(())
3541}
3542
3543async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
3544 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3545 let mut pool = session.connection_pool().await;
3546 for membership in &channels_for_user.channel_memberships {
3547 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3548 }
3549 session.peer.send(
3550 session.connection_id,
3551 build_update_user_channels(&channels_for_user),
3552 )?;
3553 session.peer.send(
3554 session.connection_id,
3555 build_channels_update(channels_for_user),
3556 )?;
3557 Ok(())
3558}
3559
3560/// Creates a new channel.
3561async fn create_channel(
3562 request: proto::CreateChannel,
3563 response: Response<proto::CreateChannel>,
3564 session: UserSession,
3565) -> Result<()> {
3566 let db = session.db().await;
3567
3568 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
3569 let (channel, membership) = db
3570 .create_channel(&request.name, parent_id, session.user_id())
3571 .await?;
3572
3573 let root_id = channel.root_id();
3574 let channel = Channel::from_model(channel);
3575
3576 response.send(proto::CreateChannelResponse {
3577 channel: Some(channel.to_proto()),
3578 parent_id: request.parent_id,
3579 })?;
3580
3581 let mut connection_pool = session.connection_pool().await;
3582 if let Some(membership) = membership {
3583 connection_pool.subscribe_to_channel(
3584 membership.user_id,
3585 membership.channel_id,
3586 membership.role,
3587 );
3588 let update = proto::UpdateUserChannels {
3589 channel_memberships: vec![proto::ChannelMembership {
3590 channel_id: membership.channel_id.to_proto(),
3591 role: membership.role.into(),
3592 }],
3593 ..Default::default()
3594 };
3595 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3596 session.peer.send(connection_id, update.clone())?;
3597 }
3598 }
3599
3600 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3601 if !role.can_see_channel(channel.visibility) {
3602 continue;
3603 }
3604
3605 let update = proto::UpdateChannels {
3606 channels: vec![channel.to_proto()],
3607 ..Default::default()
3608 };
3609 session.peer.send(connection_id, update.clone())?;
3610 }
3611
3612 Ok(())
3613}
3614
3615/// Delete a channel
3616async fn delete_channel(
3617 request: proto::DeleteChannel,
3618 response: Response<proto::DeleteChannel>,
3619 session: UserSession,
3620) -> Result<()> {
3621 let db = session.db().await;
3622
3623 let channel_id = request.channel_id;
3624 let (root_channel, removed_channels) = db
3625 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3626 .await?;
3627 response.send(proto::Ack {})?;
3628
3629 // Notify members of removed channels
3630 let mut update = proto::UpdateChannels::default();
3631 update
3632 .delete_channels
3633 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3634
3635 let connection_pool = session.connection_pool().await;
3636 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3637 session.peer.send(connection_id, update.clone())?;
3638 }
3639
3640 Ok(())
3641}
3642
3643/// Invite someone to join a channel.
3644async fn invite_channel_member(
3645 request: proto::InviteChannelMember,
3646 response: Response<proto::InviteChannelMember>,
3647 session: UserSession,
3648) -> Result<()> {
3649 let db = session.db().await;
3650 let channel_id = ChannelId::from_proto(request.channel_id);
3651 let invitee_id = UserId::from_proto(request.user_id);
3652 let InviteMemberResult {
3653 channel,
3654 notifications,
3655 } = db
3656 .invite_channel_member(
3657 channel_id,
3658 invitee_id,
3659 session.user_id(),
3660 request.role().into(),
3661 )
3662 .await?;
3663
3664 let update = proto::UpdateChannels {
3665 channel_invitations: vec![channel.to_proto()],
3666 ..Default::default()
3667 };
3668
3669 let connection_pool = session.connection_pool().await;
3670 for connection_id in connection_pool.user_connection_ids(invitee_id) {
3671 session.peer.send(connection_id, update.clone())?;
3672 }
3673
3674 send_notifications(&connection_pool, &session.peer, notifications);
3675
3676 response.send(proto::Ack {})?;
3677 Ok(())
3678}
3679
3680/// remove someone from a channel
3681async fn remove_channel_member(
3682 request: proto::RemoveChannelMember,
3683 response: Response<proto::RemoveChannelMember>,
3684 session: UserSession,
3685) -> Result<()> {
3686 let db = session.db().await;
3687 let channel_id = ChannelId::from_proto(request.channel_id);
3688 let member_id = UserId::from_proto(request.user_id);
3689
3690 let RemoveChannelMemberResult {
3691 membership_update,
3692 notification_id,
3693 } = db
3694 .remove_channel_member(channel_id, member_id, session.user_id())
3695 .await?;
3696
3697 let mut connection_pool = session.connection_pool().await;
3698 notify_membership_updated(
3699 &mut connection_pool,
3700 membership_update,
3701 member_id,
3702 &session.peer,
3703 );
3704 for connection_id in connection_pool.user_connection_ids(member_id) {
3705 if let Some(notification_id) = notification_id {
3706 session
3707 .peer
3708 .send(
3709 connection_id,
3710 proto::DeleteNotification {
3711 notification_id: notification_id.to_proto(),
3712 },
3713 )
3714 .trace_err();
3715 }
3716 }
3717
3718 response.send(proto::Ack {})?;
3719 Ok(())
3720}
3721
3722/// Toggle the channel between public and private.
3723/// Care is taken to maintain the invariant that public channels only descend from public channels,
3724/// (though members-only channels can appear at any point in the hierarchy).
3725async fn set_channel_visibility(
3726 request: proto::SetChannelVisibility,
3727 response: Response<proto::SetChannelVisibility>,
3728 session: UserSession,
3729) -> Result<()> {
3730 let db = session.db().await;
3731 let channel_id = ChannelId::from_proto(request.channel_id);
3732 let visibility = request.visibility().into();
3733
3734 let channel_model = db
3735 .set_channel_visibility(channel_id, visibility, session.user_id())
3736 .await?;
3737 let root_id = channel_model.root_id();
3738 let channel = Channel::from_model(channel_model);
3739
3740 let mut connection_pool = session.connection_pool().await;
3741 for (user_id, role) in connection_pool
3742 .channel_user_ids(root_id)
3743 .collect::<Vec<_>>()
3744 .into_iter()
3745 {
3746 let update = if role.can_see_channel(channel.visibility) {
3747 connection_pool.subscribe_to_channel(user_id, channel_id, role);
3748 proto::UpdateChannels {
3749 channels: vec![channel.to_proto()],
3750 ..Default::default()
3751 }
3752 } else {
3753 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3754 proto::UpdateChannels {
3755 delete_channels: vec![channel.id.to_proto()],
3756 ..Default::default()
3757 }
3758 };
3759
3760 for connection_id in connection_pool.user_connection_ids(user_id) {
3761 session.peer.send(connection_id, update.clone())?;
3762 }
3763 }
3764
3765 response.send(proto::Ack {})?;
3766 Ok(())
3767}
3768
3769/// Alter the role for a user in the channel.
3770async fn set_channel_member_role(
3771 request: proto::SetChannelMemberRole,
3772 response: Response<proto::SetChannelMemberRole>,
3773 session: UserSession,
3774) -> Result<()> {
3775 let db = session.db().await;
3776 let channel_id = ChannelId::from_proto(request.channel_id);
3777 let member_id = UserId::from_proto(request.user_id);
3778 let result = db
3779 .set_channel_member_role(
3780 channel_id,
3781 session.user_id(),
3782 member_id,
3783 request.role().into(),
3784 )
3785 .await?;
3786
3787 match result {
3788 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3789 let mut connection_pool = session.connection_pool().await;
3790 notify_membership_updated(
3791 &mut connection_pool,
3792 membership_update,
3793 member_id,
3794 &session.peer,
3795 )
3796 }
3797 db::SetMemberRoleResult::InviteUpdated(channel) => {
3798 let update = proto::UpdateChannels {
3799 channel_invitations: vec![channel.to_proto()],
3800 ..Default::default()
3801 };
3802
3803 for connection_id in session
3804 .connection_pool()
3805 .await
3806 .user_connection_ids(member_id)
3807 {
3808 session.peer.send(connection_id, update.clone())?;
3809 }
3810 }
3811 }
3812
3813 response.send(proto::Ack {})?;
3814 Ok(())
3815}
3816
3817/// Change the name of a channel
3818async fn rename_channel(
3819 request: proto::RenameChannel,
3820 response: Response<proto::RenameChannel>,
3821 session: UserSession,
3822) -> Result<()> {
3823 let db = session.db().await;
3824 let channel_id = ChannelId::from_proto(request.channel_id);
3825 let channel_model = db
3826 .rename_channel(channel_id, session.user_id(), &request.name)
3827 .await?;
3828 let root_id = channel_model.root_id();
3829 let channel = Channel::from_model(channel_model);
3830
3831 response.send(proto::RenameChannelResponse {
3832 channel: Some(channel.to_proto()),
3833 })?;
3834
3835 let connection_pool = session.connection_pool().await;
3836 let update = proto::UpdateChannels {
3837 channels: vec![channel.to_proto()],
3838 ..Default::default()
3839 };
3840 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3841 if role.can_see_channel(channel.visibility) {
3842 session.peer.send(connection_id, update.clone())?;
3843 }
3844 }
3845
3846 Ok(())
3847}
3848
3849/// Move a channel to a new parent.
3850async fn move_channel(
3851 request: proto::MoveChannel,
3852 response: Response<proto::MoveChannel>,
3853 session: UserSession,
3854) -> Result<()> {
3855 let channel_id = ChannelId::from_proto(request.channel_id);
3856 let to = ChannelId::from_proto(request.to);
3857
3858 let (root_id, channels) = session
3859 .db()
3860 .await
3861 .move_channel(channel_id, to, session.user_id())
3862 .await?;
3863
3864 let connection_pool = session.connection_pool().await;
3865 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3866 let channels = channels
3867 .iter()
3868 .filter_map(|channel| {
3869 if role.can_see_channel(channel.visibility) {
3870 Some(channel.to_proto())
3871 } else {
3872 None
3873 }
3874 })
3875 .collect::<Vec<_>>();
3876 if channels.is_empty() {
3877 continue;
3878 }
3879
3880 let update = proto::UpdateChannels {
3881 channels,
3882 ..Default::default()
3883 };
3884
3885 session.peer.send(connection_id, update.clone())?;
3886 }
3887
3888 response.send(Ack {})?;
3889 Ok(())
3890}
3891
3892/// Get the list of channel members
3893async fn get_channel_members(
3894 request: proto::GetChannelMembers,
3895 response: Response<proto::GetChannelMembers>,
3896 session: UserSession,
3897) -> Result<()> {
3898 let db = session.db().await;
3899 let channel_id = ChannelId::from_proto(request.channel_id);
3900 let limit = if request.limit == 0 {
3901 u16::MAX as u64
3902 } else {
3903 request.limit
3904 };
3905 let (members, users) = db
3906 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3907 .await?;
3908 response.send(proto::GetChannelMembersResponse { members, users })?;
3909 Ok(())
3910}
3911
3912/// Accept or decline a channel invitation.
3913async fn respond_to_channel_invite(
3914 request: proto::RespondToChannelInvite,
3915 response: Response<proto::RespondToChannelInvite>,
3916 session: UserSession,
3917) -> Result<()> {
3918 let db = session.db().await;
3919 let channel_id = ChannelId::from_proto(request.channel_id);
3920 let RespondToChannelInvite {
3921 membership_update,
3922 notifications,
3923 } = db
3924 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3925 .await?;
3926
3927 let mut connection_pool = session.connection_pool().await;
3928 if let Some(membership_update) = membership_update {
3929 notify_membership_updated(
3930 &mut connection_pool,
3931 membership_update,
3932 session.user_id(),
3933 &session.peer,
3934 );
3935 } else {
3936 let update = proto::UpdateChannels {
3937 remove_channel_invitations: vec![channel_id.to_proto()],
3938 ..Default::default()
3939 };
3940
3941 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3942 session.peer.send(connection_id, update.clone())?;
3943 }
3944 };
3945
3946 send_notifications(&connection_pool, &session.peer, notifications);
3947
3948 response.send(proto::Ack {})?;
3949
3950 Ok(())
3951}
3952
3953/// Join the channels' room
3954async fn join_channel(
3955 request: proto::JoinChannel,
3956 response: Response<proto::JoinChannel>,
3957 session: UserSession,
3958) -> Result<()> {
3959 let channel_id = ChannelId::from_proto(request.channel_id);
3960 join_channel_internal(channel_id, Box::new(response), session).await
3961}
3962
3963trait JoinChannelInternalResponse {
3964 fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
3965}
3966impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
3967 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3968 Response::<proto::JoinChannel>::send(self, result)
3969 }
3970}
3971impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
3972 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3973 Response::<proto::JoinRoom>::send(self, result)
3974 }
3975}
3976
/// Shared implementation for joining a channel's call room.
///
/// It serves any reply type implementing `JoinChannelInternalResponse` (`JoinChannel` and
/// `JoinRoom`).
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, that connection is
///    removed via `leave_room_for_session`. The database guard is dropped before that call
///    and re-acquired afterwards.
/// 2. The session's connection joins the channel's room in the database.
/// 3. LiveKit credentials are created:
///    - guests get a guest token with `can_publish: false`;
///    - everyone else gets a room token with `can_publish: true`.
///
///    `live_kit_connection_info` is `None` when there is no LiveKit client or when token
///    creation fails (that error is traced).
/// 4. The reply is sent, then any membership change and the room update are broadcast.
/// 5. Once the inner block ends (dropping its database and connection-pool guards), the
///    channel update is broadcast and the user's contacts are refreshed.
///
/// # Errors
/// Returns an error if a database call, `leave_room_for_session`, sending the reply, or
/// `update_user_contacts` fails. It also errors if the joined room carries no channel.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: UserSession,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database guard: `leave_room_for_session` receives the session
            // and re-locks it itself (NOTE(review): presumably via `session.db()`).
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `?` inside this closure short-circuits to `None`, so a token failure results in
        // a reply without LiveKit info rather than an error.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // Both guards from the block above are dropped by now; the pool is locked again here.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
4067
4068/// Start editing the channel notes
4069async fn join_channel_buffer(
4070 request: proto::JoinChannelBuffer,
4071 response: Response<proto::JoinChannelBuffer>,
4072 session: UserSession,
4073) -> Result<()> {
4074 let db = session.db().await;
4075 let channel_id = ChannelId::from_proto(request.channel_id);
4076
4077 let open_response = db
4078 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
4079 .await?;
4080
4081 let collaborators = open_response.collaborators.clone();
4082 response.send(open_response)?;
4083
4084 let update = UpdateChannelBufferCollaborators {
4085 channel_id: channel_id.to_proto(),
4086 collaborators: collaborators.clone(),
4087 };
4088 channel_buffer_updated(
4089 session.connection_id,
4090 collaborators
4091 .iter()
4092 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
4093 &update,
4094 &session.peer,
4095 );
4096
4097 Ok(())
4098}
4099
4100/// Edit the channel notes
4101async fn update_channel_buffer(
4102 request: proto::UpdateChannelBuffer,
4103 session: UserSession,
4104) -> Result<()> {
4105 let db = session.db().await;
4106 let channel_id = ChannelId::from_proto(request.channel_id);
4107
4108 let (collaborators, epoch, version) = db
4109 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
4110 .await?;
4111
4112 channel_buffer_updated(
4113 session.connection_id,
4114 collaborators.clone(),
4115 &proto::UpdateChannelBuffer {
4116 channel_id: channel_id.to_proto(),
4117 operations: request.operations,
4118 },
4119 &session.peer,
4120 );
4121
4122 let pool = &*session.connection_pool().await;
4123
4124 let non_collaborators =
4125 pool.channel_connection_ids(channel_id)
4126 .filter_map(|(connection_id, _)| {
4127 if collaborators.contains(&connection_id) {
4128 None
4129 } else {
4130 Some(connection_id)
4131 }
4132 });
4133
4134 broadcast(None, non_collaborators, |peer_id| {
4135 session.peer.send(
4136 peer_id,
4137 proto::UpdateChannels {
4138 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
4139 channel_id: channel_id.to_proto(),
4140 epoch: epoch as u64,
4141 version: version.clone(),
4142 }],
4143 ..Default::default()
4144 },
4145 )
4146 });
4147
4148 Ok(())
4149}
4150
4151/// Rejoin the channel notes after a connection blip
4152async fn rejoin_channel_buffers(
4153 request: proto::RejoinChannelBuffers,
4154 response: Response<proto::RejoinChannelBuffers>,
4155 session: UserSession,
4156) -> Result<()> {
4157 let db = session.db().await;
4158 let buffers = db
4159 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
4160 .await?;
4161
4162 for rejoined_buffer in &buffers {
4163 let collaborators_to_notify = rejoined_buffer
4164 .buffer
4165 .collaborators
4166 .iter()
4167 .filter_map(|c| Some(c.peer_id?.into()));
4168 channel_buffer_updated(
4169 session.connection_id,
4170 collaborators_to_notify,
4171 &proto::UpdateChannelBufferCollaborators {
4172 channel_id: rejoined_buffer.buffer.channel_id,
4173 collaborators: rejoined_buffer.buffer.collaborators.clone(),
4174 },
4175 &session.peer,
4176 );
4177 }
4178
4179 response.send(proto::RejoinChannelBuffersResponse {
4180 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
4181 })?;
4182
4183 Ok(())
4184}
4185
4186/// Stop editing the channel notes
4187async fn leave_channel_buffer(
4188 request: proto::LeaveChannelBuffer,
4189 response: Response<proto::LeaveChannelBuffer>,
4190 session: UserSession,
4191) -> Result<()> {
4192 let db = session.db().await;
4193 let channel_id = ChannelId::from_proto(request.channel_id);
4194
4195 let left_buffer = db
4196 .leave_channel_buffer(channel_id, session.connection_id)
4197 .await?;
4198
4199 response.send(Ack {})?;
4200
4201 channel_buffer_updated(
4202 session.connection_id,
4203 left_buffer.connections,
4204 &proto::UpdateChannelBufferCollaborators {
4205 channel_id: channel_id.to_proto(),
4206 collaborators: left_buffer.collaborators,
4207 },
4208 &session.peer,
4209 );
4210
4211 Ok(())
4212}
4213
4214fn channel_buffer_updated<T: EnvelopedMessage>(
4215 sender_id: ConnectionId,
4216 collaborators: impl IntoIterator<Item = ConnectionId>,
4217 message: &T,
4218 peer: &Peer,
4219) {
4220 broadcast(Some(sender_id), collaborators, |peer_id| {
4221 peer.send(peer_id, message.clone())
4222 });
4223}
4224
4225fn send_notifications(
4226 connection_pool: &ConnectionPool,
4227 peer: &Peer,
4228 notifications: db::NotificationBatch,
4229) {
4230 for (user_id, notification) in notifications {
4231 for connection_id in connection_pool.user_connection_ids(user_id) {
4232 if let Err(error) = peer.send(
4233 connection_id,
4234 proto::AddNotification {
4235 notification: Some(notification.clone()),
4236 },
4237 ) {
4238 tracing::error!(
4239 "failed to send notification to {:?} {}",
4240 connection_id,
4241 error
4242 );
4243 }
4244 }
4245 }
4246}
4247
4248/// Send a message to the channel
4249async fn send_channel_message(
4250 request: proto::SendChannelMessage,
4251 response: Response<proto::SendChannelMessage>,
4252 session: UserSession,
4253) -> Result<()> {
4254 // Validate the message body.
4255 let body = request.body.trim().to_string();
4256 if body.len() > MAX_MESSAGE_LEN {
4257 return Err(anyhow!("message is too long"))?;
4258 }
4259 if body.is_empty() {
4260 return Err(anyhow!("message can't be blank"))?;
4261 }
4262
4263 // TODO: adjust mentions if body is trimmed
4264
4265 let timestamp = OffsetDateTime::now_utc();
4266 let nonce = request
4267 .nonce
4268 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4269
4270 let channel_id = ChannelId::from_proto(request.channel_id);
4271 let CreatedChannelMessage {
4272 message_id,
4273 participant_connection_ids,
4274 notifications,
4275 } = session
4276 .db()
4277 .await
4278 .create_channel_message(
4279 channel_id,
4280 session.user_id(),
4281 &body,
4282 &request.mentions,
4283 timestamp,
4284 nonce.clone().into(),
4285 match request.reply_to_message_id {
4286 Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
4287 None => None,
4288 },
4289 )
4290 .await?;
4291
4292 let message = proto::ChannelMessage {
4293 sender_id: session.user_id().to_proto(),
4294 id: message_id.to_proto(),
4295 body,
4296 mentions: request.mentions,
4297 timestamp: timestamp.unix_timestamp() as u64,
4298 nonce: Some(nonce),
4299 reply_to_message_id: request.reply_to_message_id,
4300 edited_at: None,
4301 };
4302 broadcast(
4303 Some(session.connection_id),
4304 participant_connection_ids.clone(),
4305 |connection| {
4306 session.peer.send(
4307 connection,
4308 proto::ChannelMessageSent {
4309 channel_id: channel_id.to_proto(),
4310 message: Some(message.clone()),
4311 },
4312 )
4313 },
4314 );
4315 response.send(proto::SendChannelMessageResponse {
4316 message: Some(message),
4317 })?;
4318
4319 let pool = &*session.connection_pool().await;
4320 let non_participants =
4321 pool.channel_connection_ids(channel_id)
4322 .filter_map(|(connection_id, _)| {
4323 if participant_connection_ids.contains(&connection_id) {
4324 None
4325 } else {
4326 Some(connection_id)
4327 }
4328 });
4329 broadcast(None, non_participants, |peer_id| {
4330 session.peer.send(
4331 peer_id,
4332 proto::UpdateChannels {
4333 latest_channel_message_ids: vec![proto::ChannelMessageId {
4334 channel_id: channel_id.to_proto(),
4335 message_id: message_id.to_proto(),
4336 }],
4337 ..Default::default()
4338 },
4339 )
4340 });
4341 send_notifications(pool, &session.peer, notifications);
4342
4343 Ok(())
4344}
4345
4346/// Delete a channel message
4347async fn remove_channel_message(
4348 request: proto::RemoveChannelMessage,
4349 response: Response<proto::RemoveChannelMessage>,
4350 session: UserSession,
4351) -> Result<()> {
4352 let channel_id = ChannelId::from_proto(request.channel_id);
4353 let message_id = MessageId::from_proto(request.message_id);
4354 let (connection_ids, existing_notification_ids) = session
4355 .db()
4356 .await
4357 .remove_channel_message(channel_id, message_id, session.user_id())
4358 .await?;
4359
4360 broadcast(
4361 Some(session.connection_id),
4362 connection_ids,
4363 move |connection| {
4364 session.peer.send(connection, request.clone())?;
4365
4366 for notification_id in &existing_notification_ids {
4367 session.peer.send(
4368 connection,
4369 proto::DeleteNotification {
4370 notification_id: (*notification_id).to_proto(),
4371 },
4372 )?;
4373 }
4374
4375 Ok(())
4376 },
4377 );
4378 response.send(proto::Ack {})?;
4379 Ok(())
4380}
4381
4382async fn update_channel_message(
4383 request: proto::UpdateChannelMessage,
4384 response: Response<proto::UpdateChannelMessage>,
4385 session: UserSession,
4386) -> Result<()> {
4387 let channel_id = ChannelId::from_proto(request.channel_id);
4388 let message_id = MessageId::from_proto(request.message_id);
4389 let updated_at = OffsetDateTime::now_utc();
4390 let UpdatedChannelMessage {
4391 message_id,
4392 participant_connection_ids,
4393 notifications,
4394 reply_to_message_id,
4395 timestamp,
4396 deleted_mention_notification_ids,
4397 updated_mention_notifications,
4398 } = session
4399 .db()
4400 .await
4401 .update_channel_message(
4402 channel_id,
4403 message_id,
4404 session.user_id(),
4405 request.body.as_str(),
4406 &request.mentions,
4407 updated_at,
4408 )
4409 .await?;
4410
4411 let nonce = request
4412 .nonce
4413 .clone()
4414 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4415
4416 let message = proto::ChannelMessage {
4417 sender_id: session.user_id().to_proto(),
4418 id: message_id.to_proto(),
4419 body: request.body.clone(),
4420 mentions: request.mentions.clone(),
4421 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
4422 nonce: Some(nonce),
4423 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
4424 edited_at: Some(updated_at.unix_timestamp() as u64),
4425 };
4426
4427 response.send(proto::Ack {})?;
4428
4429 let pool = &*session.connection_pool().await;
4430 broadcast(
4431 Some(session.connection_id),
4432 participant_connection_ids,
4433 |connection| {
4434 session.peer.send(
4435 connection,
4436 proto::ChannelMessageUpdate {
4437 channel_id: channel_id.to_proto(),
4438 message: Some(message.clone()),
4439 },
4440 )?;
4441
4442 for notification_id in &deleted_mention_notification_ids {
4443 session.peer.send(
4444 connection,
4445 proto::DeleteNotification {
4446 notification_id: (*notification_id).to_proto(),
4447 },
4448 )?;
4449 }
4450
4451 for notification in &updated_mention_notifications {
4452 session.peer.send(
4453 connection,
4454 proto::UpdateNotification {
4455 notification: Some(notification.clone()),
4456 },
4457 )?;
4458 }
4459
4460 Ok(())
4461 },
4462 );
4463
4464 send_notifications(pool, &session.peer, notifications);
4465
4466 Ok(())
4467}
4468
4469/// Mark a channel message as read
4470async fn acknowledge_channel_message(
4471 request: proto::AckChannelMessage,
4472 session: UserSession,
4473) -> Result<()> {
4474 let channel_id = ChannelId::from_proto(request.channel_id);
4475 let message_id = MessageId::from_proto(request.message_id);
4476 let notifications = session
4477 .db()
4478 .await
4479 .observe_channel_message(channel_id, session.user_id(), message_id)
4480 .await?;
4481 send_notifications(
4482 &*session.connection_pool().await,
4483 &session.peer,
4484 notifications,
4485 );
4486 Ok(())
4487}
4488
4489/// Mark a buffer version as synced
4490async fn acknowledge_buffer_version(
4491 request: proto::AckBufferOperation,
4492 session: UserSession,
4493) -> Result<()> {
4494 let buffer_id = BufferId::from_proto(request.buffer_id);
4495 session
4496 .db()
4497 .await
4498 .observe_buffer_version(
4499 buffer_id,
4500 session.user_id(),
4501 request.epoch as i32,
4502 &request.version,
4503 )
4504 .await?;
4505 Ok(())
4506}
4507
4508struct CompleteWithLanguageModelRateLimit;
4509
4510impl RateLimit for CompleteWithLanguageModelRateLimit {
4511 fn capacity() -> usize {
4512 std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4513 .ok()
4514 .and_then(|v| v.parse().ok())
4515 .unwrap_or(120) // Picked arbitrarily
4516 }
4517
4518 fn refill_duration() -> chrono::Duration {
4519 chrono::Duration::hours(1)
4520 }
4521
4522 fn db_name() -> &'static str {
4523 "complete-with-language-model"
4524 }
4525}
4526
4527async fn complete_with_language_model(
4528 request: proto::CompleteWithLanguageModel,
4529 response: Response<proto::CompleteWithLanguageModel>,
4530 session: Session,
4531 config: &Config,
4532) -> Result<()> {
4533 let Some(session) = session.for_user() else {
4534 return Err(anyhow!("user not found"))?;
4535 };
4536 authorize_access_to_language_models(&session).await?;
4537
4538 session
4539 .rate_limiter
4540 .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
4541 .await?;
4542
4543 let result = match proto::LanguageModelProvider::from_i32(request.provider) {
4544 Some(proto::LanguageModelProvider::Anthropic) => {
4545 let api_key = config
4546 .anthropic_api_key
4547 .as_ref()
4548 .context("no Anthropic AI API key configured on the server")?;
4549 anthropic::complete(
4550 session.http_client.as_ref(),
4551 anthropic::ANTHROPIC_API_URL,
4552 api_key,
4553 serde_json::from_str(&request.request)?,
4554 )
4555 .await?
4556 }
4557 _ => return Err(anyhow!("unsupported provider"))?,
4558 };
4559
4560 response.send(proto::CompleteWithLanguageModelResponse {
4561 completion: serde_json::to_string(&result)?,
4562 })?;
4563
4564 Ok(())
4565}
4566
4567async fn stream_complete_with_language_model(
4568 request: proto::StreamCompleteWithLanguageModel,
4569 response: StreamingResponse<proto::StreamCompleteWithLanguageModel>,
4570 session: Session,
4571 config: &Config,
4572) -> Result<()> {
4573 let Some(session) = session.for_user() else {
4574 return Err(anyhow!("user not found"))?;
4575 };
4576 authorize_access_to_language_models(&session).await?;
4577
4578 session
4579 .rate_limiter
4580 .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
4581 .await?;
4582
4583 match proto::LanguageModelProvider::from_i32(request.provider) {
4584 Some(proto::LanguageModelProvider::Anthropic) => {
4585 let api_key = config
4586 .anthropic_api_key
4587 .as_ref()
4588 .context("no Anthropic AI API key configured on the server")?;
4589 let mut chunks = anthropic::stream_completion(
4590 session.http_client.as_ref(),
4591 anthropic::ANTHROPIC_API_URL,
4592 api_key,
4593 serde_json::from_str(&request.request)?,
4594 None,
4595 )
4596 .await?;
4597 while let Some(event) = chunks.next().await {
4598 let chunk = event?;
4599 response.send(proto::StreamCompleteWithLanguageModelResponse {
4600 event: serde_json::to_string(&chunk)?,
4601 })?;
4602 }
4603 }
4604 Some(proto::LanguageModelProvider::OpenAi) => {
4605 let api_key = config
4606 .openai_api_key
4607 .as_ref()
4608 .context("no OpenAI API key configured on the server")?;
4609 let mut events = open_ai::stream_completion(
4610 session.http_client.as_ref(),
4611 open_ai::OPEN_AI_API_URL,
4612 api_key,
4613 serde_json::from_str(&request.request)?,
4614 None,
4615 )
4616 .await?;
4617 while let Some(event) = events.next().await {
4618 let event = event?;
4619 response.send(proto::StreamCompleteWithLanguageModelResponse {
4620 event: serde_json::to_string(&event)?,
4621 })?;
4622 }
4623 }
4624 Some(proto::LanguageModelProvider::Google) => {
4625 let api_key = config
4626 .google_ai_api_key
4627 .as_ref()
4628 .context("no Google AI API key configured on the server")?;
4629 let mut events = google_ai::stream_generate_content(
4630 session.http_client.as_ref(),
4631 google_ai::API_URL,
4632 api_key,
4633 serde_json::from_str(&request.request)?,
4634 )
4635 .await?;
4636 while let Some(event) = events.next().await {
4637 let event = event?;
4638 response.send(proto::StreamCompleteWithLanguageModelResponse {
4639 event: serde_json::to_string(&event)?,
4640 })?;
4641 }
4642 }
4643 None => return Err(anyhow!("unknown provider"))?,
4644 }
4645
4646 Ok(())
4647}
4648
4649async fn count_language_model_tokens(
4650 request: proto::CountLanguageModelTokens,
4651 response: Response<proto::CountLanguageModelTokens>,
4652 session: Session,
4653 config: &Config,
4654) -> Result<()> {
4655 let Some(session) = session.for_user() else {
4656 return Err(anyhow!("user not found"))?;
4657 };
4658 authorize_access_to_language_models(&session).await?;
4659
4660 session
4661 .rate_limiter
4662 .check::<CountLanguageModelTokensRateLimit>(session.user_id())
4663 .await?;
4664
4665 let result = match proto::LanguageModelProvider::from_i32(request.provider) {
4666 Some(proto::LanguageModelProvider::Google) => {
4667 let api_key = config
4668 .google_ai_api_key
4669 .as_ref()
4670 .context("no Google AI API key configured on the server")?;
4671 google_ai::count_tokens(
4672 session.http_client.as_ref(),
4673 google_ai::API_URL,
4674 api_key,
4675 serde_json::from_str(&request.request)?,
4676 )
4677 .await?
4678 }
4679 _ => return Err(anyhow!("unsupported provider"))?,
4680 };
4681
4682 response.send(proto::CountLanguageModelTokensResponse {
4683 token_count: result.total_tokens as u32,
4684 })?;
4685
4686 Ok(())
4687}
4688
4689struct CountLanguageModelTokensRateLimit;
4690
4691impl RateLimit for CountLanguageModelTokensRateLimit {
4692 fn capacity() -> usize {
4693 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
4694 .ok()
4695 .and_then(|v| v.parse().ok())
4696 .unwrap_or(600) // Picked arbitrarily
4697 }
4698
4699 fn refill_duration() -> chrono::Duration {
4700 chrono::Duration::hours(1)
4701 }
4702
4703 fn db_name() -> &'static str {
4704 "count-language-model-tokens"
4705 }
4706}
4707
4708struct ComputeEmbeddingsRateLimit;
4709
4710impl RateLimit for ComputeEmbeddingsRateLimit {
4711 fn capacity() -> usize {
4712 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4713 .ok()
4714 .and_then(|v| v.parse().ok())
4715 .unwrap_or(5000) // Picked arbitrarily
4716 }
4717
4718 fn refill_duration() -> chrono::Duration {
4719 chrono::Duration::hours(1)
4720 }
4721
4722 fn db_name() -> &'static str {
4723 "compute-embeddings"
4724 }
4725}
4726
4727async fn compute_embeddings(
4728 request: proto::ComputeEmbeddings,
4729 response: Response<proto::ComputeEmbeddings>,
4730 session: UserSession,
4731 api_key: Option<Arc<str>>,
4732) -> Result<()> {
4733 let api_key = api_key.context("no OpenAI API key configured on the server")?;
4734 authorize_access_to_language_models(&session).await?;
4735
4736 session
4737 .rate_limiter
4738 .check::<ComputeEmbeddingsRateLimit>(session.user_id())
4739 .await?;
4740
4741 let embeddings = match request.model.as_str() {
4742 "openai/text-embedding-3-small" => {
4743 open_ai::embed(
4744 session.http_client.as_ref(),
4745 OPEN_AI_API_URL,
4746 &api_key,
4747 OpenAiEmbeddingModel::TextEmbedding3Small,
4748 request.texts.iter().map(|text| text.as_str()),
4749 )
4750 .await?
4751 }
4752 provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
4753 };
4754
4755 let embeddings = request
4756 .texts
4757 .iter()
4758 .map(|text| {
4759 let mut hasher = sha2::Sha256::new();
4760 hasher.update(text.as_bytes());
4761 let result = hasher.finalize();
4762 result.to_vec()
4763 })
4764 .zip(
4765 embeddings
4766 .data
4767 .into_iter()
4768 .map(|embedding| embedding.embedding),
4769 )
4770 .collect::<HashMap<_, _>>();
4771
4772 let db = session.db().await;
4773 db.save_embeddings(&request.model, &embeddings)
4774 .await
4775 .context("failed to save embeddings")
4776 .trace_err();
4777
4778 response.send(proto::ComputeEmbeddingsResponse {
4779 embeddings: embeddings
4780 .into_iter()
4781 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4782 .collect(),
4783 })?;
4784 Ok(())
4785}
4786
4787async fn get_cached_embeddings(
4788 request: proto::GetCachedEmbeddings,
4789 response: Response<proto::GetCachedEmbeddings>,
4790 session: UserSession,
4791) -> Result<()> {
4792 authorize_access_to_language_models(&session).await?;
4793
4794 let db = session.db().await;
4795 let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
4796
4797 response.send(proto::GetCachedEmbeddingsResponse {
4798 embeddings: embeddings
4799 .into_iter()
4800 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4801 .collect(),
4802 })?;
4803 Ok(())
4804}
4805
4806async fn authorize_access_to_language_models(session: &UserSession) -> Result<(), Error> {
4807 let db = session.db().await;
4808 let flags = db.get_user_flags(session.user_id()).await?;
4809 if flags.iter().any(|flag| flag == "language-models") {
4810 Ok(())
4811 } else {
4812 Err(anyhow!("permission denied"))?
4813 }
4814}
4815
4816/// Get a Supermaven API key for the user
4817async fn get_supermaven_api_key(
4818 _request: proto::GetSupermavenApiKey,
4819 response: Response<proto::GetSupermavenApiKey>,
4820 session: UserSession,
4821) -> Result<()> {
4822 let user_id: String = session.user_id().to_string();
4823 if !session.is_staff() {
4824 return Err(anyhow!("supermaven not enabled for this account"))?;
4825 }
4826
4827 let email = session
4828 .email()
4829 .ok_or_else(|| anyhow!("user must have an email"))?;
4830
4831 let supermaven_admin_api = session
4832 .supermaven_client
4833 .as_ref()
4834 .ok_or_else(|| anyhow!("supermaven not configured"))?;
4835
4836 let result = supermaven_admin_api
4837 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4838 .await?;
4839
4840 response.send(proto::GetSupermavenApiKeyResponse {
4841 api_key: result.api_key,
4842 })?;
4843
4844 Ok(())
4845}
4846
4847/// Start receiving chat updates for a channel
4848async fn join_channel_chat(
4849 request: proto::JoinChannelChat,
4850 response: Response<proto::JoinChannelChat>,
4851 session: UserSession,
4852) -> Result<()> {
4853 let channel_id = ChannelId::from_proto(request.channel_id);
4854
4855 let db = session.db().await;
4856 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4857 .await?;
4858 let messages = db
4859 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4860 .await?;
4861 response.send(proto::JoinChannelChatResponse {
4862 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4863 messages,
4864 })?;
4865 Ok(())
4866}
4867
4868/// Stop receiving chat updates for a channel
4869async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
4870 let channel_id = ChannelId::from_proto(request.channel_id);
4871 session
4872 .db()
4873 .await
4874 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4875 .await?;
4876 Ok(())
4877}
4878
4879/// Retrieve the chat history for a channel
4880async fn get_channel_messages(
4881 request: proto::GetChannelMessages,
4882 response: Response<proto::GetChannelMessages>,
4883 session: UserSession,
4884) -> Result<()> {
4885 let channel_id = ChannelId::from_proto(request.channel_id);
4886 let messages = session
4887 .db()
4888 .await
4889 .get_channel_messages(
4890 channel_id,
4891 session.user_id(),
4892 MESSAGE_COUNT_PER_PAGE,
4893 Some(MessageId::from_proto(request.before_message_id)),
4894 )
4895 .await?;
4896 response.send(proto::GetChannelMessagesResponse {
4897 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4898 messages,
4899 })?;
4900 Ok(())
4901}
4902
4903/// Retrieve specific chat messages
4904async fn get_channel_messages_by_id(
4905 request: proto::GetChannelMessagesById,
4906 response: Response<proto::GetChannelMessagesById>,
4907 session: UserSession,
4908) -> Result<()> {
4909 let message_ids = request
4910 .message_ids
4911 .iter()
4912 .map(|id| MessageId::from_proto(*id))
4913 .collect::<Vec<_>>();
4914 let messages = session
4915 .db()
4916 .await
4917 .get_channel_messages_by_id(session.user_id(), &message_ids)
4918 .await?;
4919 response.send(proto::GetChannelMessagesResponse {
4920 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4921 messages,
4922 })?;
4923 Ok(())
4924}
4925
4926/// Retrieve the current users notifications
4927async fn get_notifications(
4928 request: proto::GetNotifications,
4929 response: Response<proto::GetNotifications>,
4930 session: UserSession,
4931) -> Result<()> {
4932 let notifications = session
4933 .db()
4934 .await
4935 .get_notifications(
4936 session.user_id(),
4937 NOTIFICATION_COUNT_PER_PAGE,
4938 request
4939 .before_id
4940 .map(|id| db::NotificationId::from_proto(id)),
4941 )
4942 .await?;
4943 response.send(proto::GetNotificationsResponse {
4944 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4945 notifications,
4946 })?;
4947 Ok(())
4948}
4949
4950/// Mark notifications as read
4951async fn mark_notification_as_read(
4952 request: proto::MarkNotificationRead,
4953 response: Response<proto::MarkNotificationRead>,
4954 session: UserSession,
4955) -> Result<()> {
4956 let database = &session.db().await;
4957 let notifications = database
4958 .mark_notification_as_read_by_id(
4959 session.user_id(),
4960 NotificationId::from_proto(request.notification_id),
4961 )
4962 .await?;
4963 send_notifications(
4964 &*session.connection_pool().await,
4965 &session.peer,
4966 notifications,
4967 );
4968 response.send(proto::Ack {})?;
4969 Ok(())
4970}
4971
4972/// Get the current users information
4973async fn get_private_user_info(
4974 _request: proto::GetPrivateUserInfo,
4975 response: Response<proto::GetPrivateUserInfo>,
4976 session: UserSession,
4977) -> Result<()> {
4978 let db = session.db().await;
4979
4980 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4981 let user = db
4982 .get_user_by_id(session.user_id())
4983 .await?
4984 .ok_or_else(|| anyhow!("user not found"))?;
4985 let flags = db.get_user_flags(session.user_id()).await?;
4986
4987 response.send(proto::GetPrivateUserInfoResponse {
4988 metrics_id,
4989 staff: user.admin,
4990 flags,
4991 })?;
4992 Ok(())
4993}
4994
4995fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4996 let message = match message {
4997 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4998 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4999 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
5000 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
5001 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
5002 code: frame.code.into(),
5003 reason: frame.reason,
5004 })),
5005 // We should never receive a frame while reading the message, according
5006 // to the `tungstenite` maintainers:
5007 //
5008 // > It cannot occur when you read messages from the WebSocket, but it
5009 // > can be used when you want to send the raw frames (e.g. you want to
5010 // > send the frames to the WebSocket without composing the full message first).
5011 // >
5012 // > — https://github.com/snapview/tungstenite-rs/issues/268
5013 TungsteniteMessage::Frame(_) => {
5014 bail!("received an unexpected frame while reading the message")
5015 }
5016 };
5017
5018 Ok(message)
5019}
5020
5021fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
5022 match message {
5023 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
5024 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
5025 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
5026 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
5027 AxumMessage::Close(frame) => {
5028 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
5029 code: frame.code.into(),
5030 reason: frame.reason,
5031 }))
5032 }
5033 }
5034}
5035
5036fn notify_membership_updated(
5037 connection_pool: &mut ConnectionPool,
5038 result: MembershipUpdated,
5039 user_id: UserId,
5040 peer: &Peer,
5041) {
5042 for membership in &result.new_channels.channel_memberships {
5043 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
5044 }
5045 for channel_id in &result.removed_channels {
5046 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
5047 }
5048
5049 let user_channels_update = proto::UpdateUserChannels {
5050 channel_memberships: result
5051 .new_channels
5052 .channel_memberships
5053 .iter()
5054 .map(|cm| proto::ChannelMembership {
5055 channel_id: cm.channel_id.to_proto(),
5056 role: cm.role.into(),
5057 })
5058 .collect(),
5059 ..Default::default()
5060 };
5061
5062 let mut update = build_channels_update(result.new_channels);
5063 update.delete_channels = result
5064 .removed_channels
5065 .into_iter()
5066 .map(|id| id.to_proto())
5067 .collect();
5068 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
5069
5070 for connection_id in connection_pool.user_connection_ids(user_id) {
5071 peer.send(connection_id, user_channels_update.clone())
5072 .trace_err();
5073 peer.send(connection_id, update.clone()).trace_err();
5074 }
5075}
5076
5077fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
5078 proto::UpdateUserChannels {
5079 channel_memberships: channels
5080 .channel_memberships
5081 .iter()
5082 .map(|m| proto::ChannelMembership {
5083 channel_id: m.channel_id.to_proto(),
5084 role: m.role.into(),
5085 })
5086 .collect(),
5087 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
5088 observed_channel_message_id: channels.observed_channel_messages.clone(),
5089 }
5090}
5091
5092fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
5093 let mut update = proto::UpdateChannels::default();
5094
5095 for channel in channels.channels {
5096 update.channels.push(channel.to_proto());
5097 }
5098
5099 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
5100 update.latest_channel_message_ids = channels.latest_channel_messages;
5101
5102 for (channel_id, participants) in channels.channel_participants {
5103 update
5104 .channel_participants
5105 .push(proto::ChannelParticipants {
5106 channel_id: channel_id.to_proto(),
5107 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
5108 });
5109 }
5110
5111 for channel in channels.invited_channels {
5112 update.channel_invitations.push(channel.to_proto());
5113 }
5114
5115 update.hosted_projects = channels.hosted_projects;
5116 update
5117}
5118
5119fn build_initial_contacts_update(
5120 contacts: Vec<db::Contact>,
5121 pool: &ConnectionPool,
5122) -> proto::UpdateContacts {
5123 let mut update = proto::UpdateContacts::default();
5124
5125 for contact in contacts {
5126 match contact {
5127 db::Contact::Accepted { user_id, busy } => {
5128 update.contacts.push(contact_for_user(user_id, busy, &pool));
5129 }
5130 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
5131 db::Contact::Incoming { user_id } => {
5132 update
5133 .incoming_requests
5134 .push(proto::IncomingContactRequest {
5135 requester_id: user_id.to_proto(),
5136 })
5137 }
5138 }
5139 }
5140
5141 update
5142}
5143
5144fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
5145 proto::Contact {
5146 user_id: user_id.to_proto(),
5147 online: pool.is_user_online(user_id),
5148 busy,
5149 }
5150}
5151
5152fn room_updated(room: &proto::Room, peer: &Peer) {
5153 broadcast(
5154 None,
5155 room.participants
5156 .iter()
5157 .filter_map(|participant| Some(participant.peer_id?.into())),
5158 |peer_id| {
5159 peer.send(
5160 peer_id,
5161 proto::RoomUpdated {
5162 room: Some(room.clone()),
5163 },
5164 )
5165 },
5166 );
5167}
5168
5169fn channel_updated(
5170 channel: &db::channel::Model,
5171 room: &proto::Room,
5172 peer: &Peer,
5173 pool: &ConnectionPool,
5174) {
5175 let participants = room
5176 .participants
5177 .iter()
5178 .map(|p| p.user_id)
5179 .collect::<Vec<_>>();
5180
5181 broadcast(
5182 None,
5183 pool.channel_connection_ids(channel.root_id())
5184 .filter_map(|(channel_id, role)| {
5185 role.can_see_channel(channel.visibility).then(|| channel_id)
5186 }),
5187 |peer_id| {
5188 peer.send(
5189 peer_id,
5190 proto::UpdateChannels {
5191 channel_participants: vec![proto::ChannelParticipants {
5192 channel_id: channel.id.to_proto(),
5193 participant_user_ids: participants.clone(),
5194 }],
5195 ..Default::default()
5196 },
5197 )
5198 },
5199 );
5200}
5201
5202async fn send_dev_server_projects_update(
5203 user_id: UserId,
5204 mut status: proto::DevServerProjectsUpdate,
5205 session: &Session,
5206) {
5207 let pool = session.connection_pool().await;
5208 for dev_server in &mut status.dev_servers {
5209 dev_server.status =
5210 pool.dev_server_status(DevServerId(dev_server.dev_server_id as i32)) as i32;
5211 }
5212 let connections = pool.user_connection_ids(user_id);
5213 for connection_id in connections {
5214 session.peer.send(connection_id, status.clone()).trace_err();
5215 }
5216}
5217
5218async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
5219 let db = session.db().await;
5220
5221 let contacts = db.get_contacts(user_id).await?;
5222 let busy = db.is_user_busy(user_id).await?;
5223
5224 let pool = session.connection_pool().await;
5225 let updated_contact = contact_for_user(user_id, busy, &pool);
5226 for contact in contacts {
5227 if let db::Contact::Accepted {
5228 user_id: contact_user_id,
5229 ..
5230 } = contact
5231 {
5232 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
5233 session
5234 .peer
5235 .send(
5236 contact_conn_id,
5237 proto::UpdateContacts {
5238 contacts: vec![updated_contact.clone()],
5239 remove_contacts: Default::default(),
5240 incoming_requests: Default::default(),
5241 remove_incoming_requests: Default::default(),
5242 outgoing_requests: Default::default(),
5243 remove_outgoing_requests: Default::default(),
5244 },
5245 )
5246 .trace_err();
5247 }
5248 }
5249 }
5250 Ok(())
5251}
5252
5253async fn lost_dev_server_connection(session: &DevServerSession) -> Result<()> {
5254 log::info!("lost dev server connection, unsharing projects");
5255 let project_ids = session
5256 .db()
5257 .await
5258 .get_stale_dev_server_projects(session.connection_id)
5259 .await?;
5260
5261 for project_id in project_ids {
5262 // not unshare re-checks the connection ids match, so we get away with no transaction
5263 unshare_project_internal(project_id, session.connection_id, None, &session).await?;
5264 }
5265
5266 let user_id = session.dev_server().user_id;
5267 let update = session
5268 .db()
5269 .await
5270 .dev_server_projects_update(user_id)
5271 .await?;
5272
5273 send_dev_server_projects_update(user_id, update, session).await;
5274
5275 Ok(())
5276}
5277
/// Remove `connection_id` from the room it is in and notify everyone affected.
///
/// Returns `Ok(())` without doing anything when the database reports no room
/// left for this connection. Otherwise, in order:
/// 1. notifies the connections of each project the user left (`project_left`);
/// 2. broadcasts the updated room to its participants and, if the room belongs
///    to a channel, the channel's participant list;
/// 3. sends `CallCanceled` to every user in `canceled_calls_to_user_ids`;
/// 4. refreshes the contact entry of the leaving user and of those users;
/// 5. if a LiveKit client is configured, removes the user from the LiveKit room
///    and deletes that room when `left_room.deleted` is set.
///
/// LiveKit failures and send failures are logged via `trace_err`, not returned.
async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Deferred initialization: these are assigned only in the branch that found
    // a room; the `else` branch returns early, so all are initialized below.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        // Move fields out of `left_room` with `mem::take` instead of cloning them.
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the connection-pool guard is released before
    // `update_user_contacts`, which acquires the pool itself.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
5351
5352async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
5353 let left_channel_buffers = session
5354 .db()
5355 .await
5356 .leave_channel_buffers(session.connection_id)
5357 .await?;
5358
5359 for left_buffer in left_channel_buffers {
5360 channel_buffer_updated(
5361 session.connection_id,
5362 left_buffer.connections,
5363 &proto::UpdateChannelBufferCollaborators {
5364 channel_id: left_buffer.channel_id.to_proto(),
5365 collaborators: left_buffer.collaborators,
5366 },
5367 &session.peer,
5368 );
5369 }
5370
5371 Ok(())
5372}
5373
5374fn project_left(project: &db::LeftProject, session: &UserSession) {
5375 for connection_id in &project.connection_ids {
5376 if project.should_unshare {
5377 session
5378 .peer
5379 .send(
5380 *connection_id,
5381 proto::UnshareProject {
5382 project_id: project.id.to_proto(),
5383 },
5384 )
5385 .trace_err();
5386 } else {
5387 session
5388 .peer
5389 .send(
5390 *connection_id,
5391 proto::RemoveProjectCollaborator {
5392 project_id: project.id.to_proto(),
5393 peer_id: Some(session.connection_id.into()),
5394 },
5395 )
5396 .trace_err();
5397 }
5398 }
5399}
5400
/// Extension for turning a fallible value into an `Option`, logging any error
/// instead of propagating it.
pub trait ResultExt {
    /// The success value produced by `trace_err`.
    type Ok;

    /// Convert into `Some(value)` on success. On failure, log the error and
    /// return `None`.
    fn trace_err(self) -> Option<Self::Ok>;
}
5406
5407impl<T, E> ResultExt for Result<T, E>
5408where
5409 E: std::fmt::Debug,
5410{
5411 type Ok = T;
5412
5413 #[track_caller]
5414 fn trace_err(self) -> Option<T> {
5415 match self {
5416 Ok(value) => Some(value),
5417 Err(error) => {
5418 tracing::error!("{:?}", error);
5419 None
5420 }
5421 }
5422 }
5423}