1mod connection_pool;
2
3use crate::{
4 auth,
5 db::{
6 self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
7 CreatedChannelMessage, Database, DevServerId, DevServerProjectId, InviteMemberResult,
8 MembershipUpdated, MessageId, NotificationId, PrincipalId, Project, ProjectId,
9 RejoinedProject, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
10 ServerId, UpdatedChannelMessage, User, UserId,
11 },
12 executor::Executor,
13 AppState, Error, RateLimit, RateLimiter, Result,
14};
15use anyhow::{anyhow, Context as _};
16use async_tungstenite::tungstenite::{
17 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
18};
19use axum::{
20 body::Body,
21 extract::{
22 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
23 ConnectInfo, WebSocketUpgrade,
24 },
25 headers::{Header, HeaderName},
26 http::StatusCode,
27 middleware,
28 response::IntoResponse,
29 routing::get,
30 Extension, Router, TypedHeader,
31};
32use collections::{HashMap, HashSet};
33pub use connection_pool::{ConnectionPool, ZedVersion};
34use core::fmt::{self, Debug, Formatter};
35use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
36use sha2::Digest;
37use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
38
39use futures::{
40 channel::oneshot,
41 future::{self, BoxFuture},
42 stream::FuturesUnordered,
43 FutureExt, SinkExt, StreamExt, TryStreamExt,
44};
45use http::IsahcHttpClient;
46use prometheus::{register_int_gauge, IntGauge};
47use rpc::{
48 proto::{
49 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LanguageModelRole,
50 LiveKitConnectionInfo, RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
51 },
52 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
53};
54use semantic_version::SemanticVersion;
55use serde::{Serialize, Serializer};
56use std::{
57 any::TypeId,
58 future::Future,
59 marker::PhantomData,
60 mem,
61 net::SocketAddr,
62 ops::{Deref, DerefMut},
63 rc::Rc,
64 sync::{
65 atomic::{AtomicBool, Ordering::SeqCst},
66 Arc, OnceLock,
67 },
68 time::{Duration, Instant},
69};
70use time::OffsetDateTime;
71use tokio::sync::{watch, Semaphore};
72use tower::ServiceBuilder;
73use tracing::{
74 field::{self},
75 info_span, instrument, Instrument,
76};
77
78use self::connection_pool::VersionedMessage;
79
/// Reconnection grace period. Presumably this is how long a dropped client may take to
/// reconnect before its state is discarded; its usage is outside this chunk, so confirm there.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// Delay before [`Server::start`] cleans up resources left by stale servers.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully; waiting
/// longer than that means the old pods are gone before their resources are
/// cleaned up.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

/// Page size for channel-message queries (presumably; usage is outside this chunk).
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Maximum channel-message length (presumably in bytes or chars; confirm at the use site).
const MAX_MESSAGE_LEN: usize = 1024;
/// Page size for notification queries (presumably; usage is outside this chunk).
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
88
89type MessageHandler =
90 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
91
92struct Response<R> {
93 peer: Arc<Peer>,
94 receipt: Receipt<R>,
95 responded: Arc<AtomicBool>,
96}
97
98impl<R: RequestMessage> Response<R> {
99 fn send(self, payload: R::Response) -> Result<()> {
100 self.responded.store(true, SeqCst);
101 self.peer.respond(self.receipt, payload)?;
102 Ok(())
103 }
104}
105
106struct StreamingResponse<R: RequestMessage> {
107 peer: Arc<Peer>,
108 receipt: Receipt<R>,
109}
110
111impl<R: RequestMessage> StreamingResponse<R> {
112 fn send(&self, payload: R::Response) -> Result<()> {
113 self.peer.respond(self.receipt, payload)?;
114 Ok(())
115 }
116}
117
118#[derive(Clone, Debug)]
119pub enum Principal {
120 User(User),
121 Impersonated { user: User, admin: User },
122 DevServer(dev_server::Model),
123}
124
125impl Principal {
126 fn update_span(&self, span: &tracing::Span) {
127 match &self {
128 Principal::User(user) => {
129 span.record("user_id", &user.id.0);
130 span.record("login", &user.github_login);
131 }
132 Principal::Impersonated { user, admin } => {
133 span.record("user_id", &user.id.0);
134 span.record("login", &user.github_login);
135 span.record("impersonator", &admin.github_login);
136 }
137 Principal::DevServer(dev_server) => {
138 span.record("dev_server_id", &dev_server.id.0);
139 }
140 }
141 }
142}
143
144#[derive(Clone)]
145struct Session {
146 principal: Principal,
147 connection_id: ConnectionId,
148 db: Arc<tokio::sync::Mutex<DbHandle>>,
149 peer: Arc<Peer>,
150 connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
151 live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
152 supermaven_client: Option<Arc<SupermavenAdminApi>>,
153 http_client: Arc<IsahcHttpClient>,
154 rate_limiter: Arc<RateLimiter>,
155 _executor: Executor,
156}
157
158impl Session {
159 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
160 #[cfg(test)]
161 tokio::task::yield_now().await;
162 let guard = self.db.lock().await;
163 #[cfg(test)]
164 tokio::task::yield_now().await;
165 guard
166 }
167
168 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
169 #[cfg(test)]
170 tokio::task::yield_now().await;
171 let guard = self.connection_pool.lock();
172 ConnectionPoolGuard {
173 guard,
174 _not_send: PhantomData,
175 }
176 }
177
178 fn for_user(self) -> Option<UserSession> {
179 UserSession::new(self)
180 }
181
182 fn for_dev_server(self) -> Option<DevServerSession> {
183 DevServerSession::new(self)
184 }
185
186 fn user_id(&self) -> Option<UserId> {
187 match &self.principal {
188 Principal::User(user) => Some(user.id),
189 Principal::Impersonated { user, .. } => Some(user.id),
190 Principal::DevServer(_) => None,
191 }
192 }
193
194 fn is_staff(&self) -> bool {
195 match &self.principal {
196 Principal::User(user) => user.admin,
197 Principal::Impersonated { .. } => true,
198 Principal::DevServer(_) => false,
199 }
200 }
201
202 fn dev_server_id(&self) -> Option<DevServerId> {
203 match &self.principal {
204 Principal::User(_) | Principal::Impersonated { .. } => None,
205 Principal::DevServer(dev_server) => Some(dev_server.id),
206 }
207 }
208
209 fn principal_id(&self) -> PrincipalId {
210 match &self.principal {
211 Principal::User(user) => PrincipalId::UserId(user.id),
212 Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
213 Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
214 }
215 }
216}
217
218impl Debug for Session {
219 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
220 let mut result = f.debug_struct("Session");
221 match &self.principal {
222 Principal::User(user) => {
223 result.field("user", &user.github_login);
224 }
225 Principal::Impersonated { user, admin } => {
226 result.field("user", &user.github_login);
227 result.field("impersonator", &admin.github_login);
228 }
229 Principal::DevServer(dev_server) => {
230 result.field("dev_server", &dev_server.id);
231 }
232 }
233 result.field("connection_id", &self.connection_id).finish()
234 }
235}
236
237struct UserSession(Session);
238
239impl UserSession {
240 pub fn new(s: Session) -> Option<Self> {
241 s.user_id().map(|_| UserSession(s))
242 }
243 pub fn user_id(&self) -> UserId {
244 self.0.user_id().unwrap()
245 }
246
247 pub fn email(&self) -> Option<String> {
248 match &self.0.principal {
249 Principal::User(user) => user.email_address.clone(),
250 Principal::Impersonated { user, .. } => user.email_address.clone(),
251 Principal::DevServer(..) => None,
252 }
253 }
254}
255
256impl Deref for UserSession {
257 type Target = Session;
258
259 fn deref(&self) -> &Self::Target {
260 &self.0
261 }
262}
263impl DerefMut for UserSession {
264 fn deref_mut(&mut self) -> &mut Self::Target {
265 &mut self.0
266 }
267}
268
269struct DevServerSession(Session);
270
271impl DevServerSession {
272 pub fn new(s: Session) -> Option<Self> {
273 s.dev_server_id().map(|_| DevServerSession(s))
274 }
275 pub fn dev_server_id(&self) -> DevServerId {
276 self.0.dev_server_id().unwrap()
277 }
278
279 fn dev_server(&self) -> &dev_server::Model {
280 match &self.0.principal {
281 Principal::DevServer(dev_server) => dev_server,
282 _ => unreachable!(),
283 }
284 }
285}
286
287impl Deref for DevServerSession {
288 type Target = Session;
289
290 fn deref(&self) -> &Self::Target {
291 &self.0
292 }
293}
294impl DerefMut for DevServerSession {
295 fn deref_mut(&mut self) -> &mut Self::Target {
296 &mut self.0
297 }
298}
299
300fn user_handler<M: RequestMessage, Fut>(
301 handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
302) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
303where
304 Fut: Send + Future<Output = Result<()>>,
305{
306 let handler = Arc::new(handler);
307 move |message, response, session| {
308 let handler = handler.clone();
309 Box::pin(async move {
310 if let Some(user_session) = session.for_user() {
311 Ok(handler(message, response, user_session).await?)
312 } else {
313 Err(Error::Internal(anyhow!(
314 "must be a user to call {}",
315 M::NAME
316 )))
317 }
318 })
319 }
320}
321
322fn dev_server_handler<M: RequestMessage, Fut>(
323 handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
324) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
325where
326 Fut: Send + Future<Output = Result<()>>,
327{
328 let handler = Arc::new(handler);
329 move |message, response, session| {
330 let handler = handler.clone();
331 Box::pin(async move {
332 if let Some(dev_server_session) = session.for_dev_server() {
333 Ok(handler(message, response, dev_server_session).await?)
334 } else {
335 Err(Error::Internal(anyhow!(
336 "must be a dev server to call {}",
337 M::NAME
338 )))
339 }
340 })
341 }
342}
343
344fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
345 handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
346) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
347where
348 InnertRetFut: Send + Future<Output = Result<()>>,
349{
350 let handler = Arc::new(handler);
351 move |message, session| {
352 let handler = handler.clone();
353 Box::pin(async move {
354 if let Some(user_session) = session.for_user() {
355 Ok(handler(message, user_session).await?)
356 } else {
357 Err(Error::Internal(anyhow!(
358 "must be a user to call {}",
359 M::NAME
360 )))
361 }
362 })
363 }
364}
365
366struct DbHandle(Arc<Database>);
367
368impl Deref for DbHandle {
369 type Target = Database;
370
371 fn deref(&self) -> &Self::Target {
372 self.0.as_ref()
373 }
374}
375
376pub struct Server {
377 id: parking_lot::Mutex<ServerId>,
378 peer: Arc<Peer>,
379 pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
380 app_state: Arc<AppState>,
381 handlers: HashMap<TypeId, MessageHandler>,
382 teardown: watch::Sender<bool>,
383}
384
385pub(crate) struct ConnectionPoolGuard<'a> {
386 guard: parking_lot::MutexGuard<'a, ConnectionPool>,
387 _not_send: PhantomData<Rc<()>>,
388}
389
390#[derive(Serialize)]
391pub struct ServerSnapshot<'a> {
392 peer: &'a Peer,
393 #[serde(serialize_with = "serialize_deref")]
394 connection_pool: ConnectionPoolGuard<'a>,
395}
396
397pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
398where
399 S: Serializer,
400 T: Deref<Target = U>,
401 U: Serialize,
402{
403 Serialize::serialize(value.deref(), serializer)
404}
405
406impl Server {
    /// Creates a server with id `id` and registers every RPC handler it serves.
    ///
    /// Handlers are stored by message type, and registering the same message
    /// type twice panics (see `add_handler`). The wrappers used below are:
    /// - `user_handler` / `user_message_handler`: reply with an internal error
    ///   unless the session's principal is a user.
    /// - `dev_server_handler`: replies with an internal error unless the
    ///   principal is a dev server.
    /// - No wrapper: the handler receives the raw `Session`, whatever the
    ///   principal.
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            // NOTE(review): if ServerId's inner integer is wider than u32, `as`
            // truncates silently — confirm the id range.
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            // Only the sender is kept; each connection calls `subscribe()` on it.
            teardown: watch::channel(false).0,
        };

        server
            // Liveness.
            .add_request_handler(ping)
            // Rooms and calls.
            .add_request_handler(user_handler(create_room))
            .add_request_handler(user_handler(join_room))
            .add_request_handler(user_handler(rejoin_room))
            .add_request_handler(user_handler(leave_room))
            .add_request_handler(user_handler(set_room_participant_role))
            .add_request_handler(user_handler(call))
            .add_request_handler(user_handler(cancel_call))
            .add_message_handler(user_message_handler(decline_call))
            .add_request_handler(user_handler(update_participant_location))
            // Project sharing and joining.
            .add_request_handler(user_handler(share_project))
            .add_message_handler(unshare_project)
            .add_request_handler(user_handler(join_project))
            .add_request_handler(user_handler(join_hosted_project))
            // Dev servers: management requests come from users...
            .add_request_handler(user_handler(rejoin_dev_server_projects))
            .add_request_handler(user_handler(create_dev_server_project))
            .add_request_handler(user_handler(delete_dev_server_project))
            .add_request_handler(user_handler(create_dev_server))
            .add_request_handler(user_handler(regenerate_dev_server_token))
            .add_request_handler(user_handler(rename_dev_server))
            .add_request_handler(user_handler(delete_dev_server))
            // ...while these may only be called by the dev server itself.
            .add_request_handler(dev_server_handler(share_dev_server_project))
            .add_request_handler(dev_server_handler(shutdown_dev_server))
            .add_request_handler(dev_server_handler(reconnect_dev_server))
            .add_message_handler(user_message_handler(leave_project))
            // Project/worktree state updates (unwrapped: any principal).
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Requests forwarded to the project owner/host. The helper chosen per
            // message (owner / read-only / mutating / versioned mutating)
            // presumably encodes the required permission — confirm in the helpers.
            .add_request_handler(user_handler(
                forward_project_request_for_owner::<proto::TaskContextForLocation>,
            ))
            .add_request_handler(user_handler(
                forward_project_request_for_owner::<proto::TaskTemplates>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetHover>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetTypeDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetReferences>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::SearchProject>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDocumentHighlights>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetProjectSymbols>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferForSymbol>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferById>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::SynchronizeBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::InlayHints>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferByPath>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCompletions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            ))
            .add_request_handler(user_handler(
                forward_versioned_mutating_project_request::<proto::OpenNewBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCodeActions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCodeAction>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PrepareRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PerformRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ReloadBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::FormatBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CreateProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RenameProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CopyProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::DeleteProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ExpandProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OnTypeFormatting>,
            ))
            .add_request_handler(user_handler(
                forward_versioned_mutating_project_request::<proto::SaveBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::BlameBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::MultiLspQuery>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RestartLanguageServers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::LinkedEditingRange>,
            ))
            // Buffer traffic and host-to-guest broadcasts (unwrapped).
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            // Users and contacts.
            .add_request_handler(get_users)
            .add_request_handler(user_handler(fuzzy_search_users))
            .add_request_handler(user_handler(request_contact))
            .add_request_handler(user_handler(remove_contact))
            .add_request_handler(user_handler(respond_to_contact_request))
            // Channels, channel buffers and channel chat.
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(user_handler(create_channel))
            .add_request_handler(user_handler(delete_channel))
            .add_request_handler(user_handler(invite_channel_member))
            .add_request_handler(user_handler(remove_channel_member))
            .add_request_handler(user_handler(set_channel_member_role))
            .add_request_handler(user_handler(set_channel_visibility))
            .add_request_handler(user_handler(rename_channel))
            .add_request_handler(user_handler(join_channel_buffer))
            .add_request_handler(user_handler(leave_channel_buffer))
            .add_message_handler(user_message_handler(update_channel_buffer))
            .add_request_handler(user_handler(rejoin_channel_buffers))
            .add_request_handler(user_handler(get_channel_members))
            .add_request_handler(user_handler(respond_to_channel_invite))
            .add_request_handler(user_handler(join_channel))
            .add_request_handler(user_handler(join_channel_chat))
            .add_message_handler(user_message_handler(leave_channel_chat))
            .add_request_handler(user_handler(send_channel_message))
            .add_request_handler(user_handler(remove_channel_message))
            .add_request_handler(user_handler(update_channel_message))
            .add_request_handler(user_handler(get_channel_messages))
            .add_request_handler(user_handler(get_channel_messages_by_id))
            // Notifications.
            .add_request_handler(user_handler(get_notifications))
            .add_request_handler(user_handler(mark_notification_as_read))
            .add_request_handler(user_handler(move_channel))
            // Following.
            .add_request_handler(user_handler(follow))
            .add_message_handler(user_message_handler(unfollow))
            .add_message_handler(user_message_handler(update_followers))
            .add_request_handler(user_handler(get_private_user_info))
            .add_message_handler(user_message_handler(acknowledge_channel_message))
            .add_message_handler(user_message_handler(acknowledge_buffer_version))
            .add_request_handler(user_handler(get_supermaven_api_key))
            // Assistant contexts.
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OpenContext>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::SynchronizeContexts>,
            ))
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context)
            // Language-model completion streams its reply. It is not wrapped in
            // `user_handler`, so it receives the raw `Session`. API keys are
            // cloned from config on every request.
            .add_streaming_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    complete_with_language_model(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                        app_state.config.google_ai_api_key.clone(),
                        app_state.config.anthropic_api_key.clone(),
                    )
                }
            })
            .add_request_handler({
                let app_state = app_state.clone();
                user_handler(move |request, response, session| {
                    count_tokens_with_language_model(
                        request,
                        response,
                        session,
                        app_state.config.google_ai_api_key.clone(),
                    )
                })
            })
            .add_request_handler({
                user_handler(move |request, response, session| {
                    get_cached_embeddings(request, response, session)
                })
            })
            .add_request_handler({
                let app_state = app_state.clone();
                user_handler(move |request, response, session| {
                    compute_embeddings(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                    )
                })
            });

        Arc::new(server)
    }
649
    /// Spawns a detached background task that cleans up resources left behind
    /// by stale servers, then returns immediately with `Ok(())`.
    ///
    /// The task:
    /// 1. waits [`CLEANUP_TIMEOUT`];
    /// 2. fetches stale room and channel ids for this environment and `server_id`;
    /// 3. clears stale channel-buffer collaborators and notifies the remaining
    ///    connections;
    /// 4. for each stale room, clears stale participants, broadcasts room and
    ///    channel updates, sends `CallCanceled` to the callees of canceled calls,
    ///    pushes refreshed contact state to those users' accepted contacts, and
    ///    deletes the LiveKit room once it has no participants;
    /// 5. deletes stale server records. This step runs even if step 2 failed.
    ///
    /// Every database or send failure is logged via `trace_err` and skipped, so
    /// cleanup is best-effort.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // The sleep future is created now, before the task is spawned.
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Drop collaborators from stale servers and tell the remaining
                    // connections about the new collaborator list.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // These defaults are used when refreshing the room fails:
                        // nobody is notified and the LiveKit room is kept.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // The synchronous pool lock is held only for this scope;
                        // nothing here awaits.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        for user_id in contacts_to_update {
                            // The db calls are awaited before the pool lock is taken.
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    // Only accepted contacts receive the update.
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
795
796 pub fn teardown(&self) {
797 self.peer.teardown();
798 self.connection_pool.lock().reset();
799 let _ = self.teardown.send(true);
800 }
801
802 #[cfg(test)]
803 pub fn reset(&self, id: ServerId) {
804 self.teardown();
805 *self.id.lock() = id;
806 self.peer.reset(id.0 as u32);
807 let _ = self.teardown.send(false);
808 }
809
810 #[cfg(test)]
811 pub fn id(&self) -> ServerId {
812 *self.id.lock()
813 }
814
815 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
816 where
817 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
818 Fut: 'static + Send + Future<Output = Result<()>>,
819 M: EnvelopedMessage,
820 {
821 let prev_handler = self.handlers.insert(
822 TypeId::of::<M>(),
823 Box::new(move |envelope, session| {
824 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
825 let received_at = envelope.received_at;
826 tracing::info!("message received");
827 let start_time = Instant::now();
828 let future = (handler)(*envelope, session);
829 async move {
830 let result = future.await;
831 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
832 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
833 let queue_duration_ms = total_duration_ms - processing_duration_ms;
834 let payload_type = M::NAME;
835
836 match result {
837 Err(error) => {
838 tracing::error!(
839 ?error,
840 total_duration_ms,
841 processing_duration_ms,
842 queue_duration_ms,
843 payload_type,
844 "error handling message"
845 )
846 }
847 Ok(()) => tracing::info!(
848 total_duration_ms,
849 processing_duration_ms,
850 queue_duration_ms,
851 "finished handling message"
852 ),
853 }
854 }
855 .boxed()
856 }),
857 );
858 if prev_handler.is_some() {
859 panic!("registered a handler for the same message twice");
860 }
861 self
862 }
863
864 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
865 where
866 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
867 Fut: 'static + Send + Future<Output = Result<()>>,
868 M: EnvelopedMessage,
869 {
870 self.add_handler(move |envelope, session| handler(envelope.payload, session));
871 self
872 }
873
874 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
875 where
876 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
877 Fut: Send + Future<Output = Result<()>>,
878 M: RequestMessage,
879 {
880 let handler = Arc::new(handler);
881 self.add_handler(move |envelope, session| {
882 let receipt = envelope.receipt();
883 let handler = handler.clone();
884 async move {
885 let peer = session.peer.clone();
886 let responded = Arc::new(AtomicBool::default());
887 let response = Response {
888 peer: peer.clone(),
889 responded: responded.clone(),
890 receipt,
891 };
892 match (handler)(envelope.payload, response, session).await {
893 Ok(()) => {
894 if responded.load(std::sync::atomic::Ordering::SeqCst) {
895 Ok(())
896 } else {
897 Err(anyhow!("handler did not send a response"))?
898 }
899 }
900 Err(error) => {
901 let proto_err = match &error {
902 Error::Internal(err) => err.to_proto(),
903 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
904 };
905 peer.respond_with_error(receipt, proto_err)?;
906 Err(error)
907 }
908 }
909 }
910 })
911 }
912
913 fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
914 where
915 F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
916 Fut: Send + Future<Output = Result<()>>,
917 M: RequestMessage,
918 {
919 let handler = Arc::new(handler);
920 self.add_handler(move |envelope, session| {
921 let receipt = envelope.receipt();
922 let handler = handler.clone();
923 async move {
924 let peer = session.peer.clone();
925 let response = StreamingResponse {
926 peer: peer.clone(),
927 receipt,
928 };
929 match (handler)(envelope.payload, response, session).await {
930 Ok(()) => {
931 peer.end_stream(receipt)?;
932 Ok(())
933 }
934 Err(error) => {
935 let proto_err = match &error {
936 Error::Internal(err) => err.to_proto(),
937 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
938 };
939 peer.respond_with_error(receipt, proto_err)?;
940 Err(error)
941 }
942 }
943 }
944 })
945 }
946
947 #[allow(clippy::too_many_arguments)]
948 pub fn handle_connection(
949 self: &Arc<Self>,
950 connection: Connection,
951 address: String,
952 principal: Principal,
953 zed_version: ZedVersion,
954 send_connection_id: Option<oneshot::Sender<ConnectionId>>,
955 executor: Executor,
956 ) -> impl Future<Output = ()> {
957 let this = self.clone();
958 let span = info_span!("handle connection", %address,
959 connection_id=field::Empty,
960 user_id=field::Empty,
961 login=field::Empty,
962 impersonator=field::Empty,
963 dev_server_id=field::Empty
964 );
965 principal.update_span(&span);
966
967 let mut teardown = self.teardown.subscribe();
968 async move {
969 if *teardown.borrow() {
970 tracing::error!("server is tearing down");
971 return
972 }
973 let (connection_id, handle_io, mut incoming_rx) = this
974 .peer
975 .add_connection(connection, {
976 let executor = executor.clone();
977 move |duration| executor.sleep(duration)
978 });
979 tracing::Span::current().record("connection_id", format!("{}", connection_id));
980 tracing::info!("connection opened");
981
982 let http_client = match IsahcHttpClient::new() {
983 Ok(http_client) => Arc::new(http_client),
984 Err(error) => {
985 tracing::error!(?error, "failed to create HTTP client");
986 return;
987 }
988 };
989
990 let supermaven_client = if let Some(supermaven_admin_api_key) = this.app_state.config.supermaven_admin_api_key.clone() {
991 Some(Arc::new(SupermavenAdminApi::new(
992 supermaven_admin_api_key.to_string(),
993 http_client.clone(),
994 )))
995 } else {
996 None
997 };
998
999 let session = Session {
1000 principal: principal.clone(),
1001 connection_id,
1002 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
1003 peer: this.peer.clone(),
1004 connection_pool: this.connection_pool.clone(),
1005 live_kit_client: this.app_state.live_kit_client.clone(),
1006 http_client,
1007 rate_limiter: this.app_state.rate_limiter.clone(),
1008 _executor: executor.clone(),
1009 supermaven_client,
1010 };
1011
1012 if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
1013 tracing::error!(?error, "failed to send initial client update");
1014 return;
1015 }
1016
1017 let handle_io = handle_io.fuse();
1018 futures::pin_mut!(handle_io);
1019
1020 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
1021 // This prevents deadlocks when e.g., client A performs a request to client B and
1022 // client B performs a request to client A. If both clients stop processing further
1023 // messages until their respective request completes, they won't have a chance to
1024 // respond to the other client's request and cause a deadlock.
1025 //
1026 // This arrangement ensures we will attempt to process earlier messages first, but fall
1027 // back to processing messages arrived later in the spirit of making progress.
1028 let mut foreground_message_handlers = FuturesUnordered::new();
1029 let concurrent_handlers = Arc::new(Semaphore::new(256));
1030 loop {
1031 let next_message = async {
1032 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
1033 let message = incoming_rx.next().await;
1034 (permit, message)
1035 }.fuse();
1036 futures::pin_mut!(next_message);
1037 futures::select_biased! {
1038 _ = teardown.changed().fuse() => return,
1039 result = handle_io => {
1040 if let Err(error) = result {
1041 tracing::error!(?error, "error handling I/O");
1042 }
1043 break;
1044 }
1045 _ = foreground_message_handlers.next() => {}
1046 next_message = next_message => {
1047 let (permit, message) = next_message;
1048 if let Some(message) = message {
1049 let type_name = message.payload_type_name();
1050 // note: we copy all the fields from the parent span so we can query them in the logs.
1051 // (https://github.com/tokio-rs/tracing/issues/2670).
1052 let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
1053 user_id=field::Empty,
1054 login=field::Empty,
1055 impersonator=field::Empty,
1056 dev_server_id=field::Empty
1057 );
1058 principal.update_span(&span);
1059 let span_enter = span.enter();
1060 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
1061 let is_background = message.is_background();
1062 let handle_message = (handler)(message, session.clone());
1063 drop(span_enter);
1064
1065 let handle_message = async move {
1066 handle_message.await;
1067 drop(permit);
1068 }.instrument(span);
1069 if is_background {
1070 executor.spawn_detached(handle_message);
1071 } else {
1072 foreground_message_handlers.push(handle_message);
1073 }
1074 } else {
1075 tracing::error!("no message handler");
1076 }
1077 } else {
1078 tracing::info!("connection closed");
1079 break;
1080 }
1081 }
1082 }
1083 }
1084
1085 drop(foreground_message_handlers);
1086 tracing::info!("signing out");
1087 if let Err(error) = connection_lost(session, teardown, executor).await {
1088 tracing::error!(?error, "error signing out");
1089 }
1090
1091 }.instrument(span)
1092 }
1093
/// Performs the server side of the handshake for a newly accepted connection
/// and pushes the initial state the client needs.
///
/// A `Hello` carrying the connection id is always sent first; if the caller
/// supplied `send_connection_id`, the id is also reported through that oneshot
/// channel. What follows depends on `principal`:
///
/// * Users (plain or impersonated): on the user's first-ever connection a
///   `ShowContacts` is sent and the `connected_once` flag is persisted. The
///   connection is then registered in the pool, the initial contacts are sent,
///   channels are subscribed to when the client version calls for it, the
///   dev-server project status is pushed, any pending incoming call is
///   forwarded, and the user's contacts are updated.
/// * Dev servers: a previous connection for the same dev server is sent
///   `ShutdownDevServer` and removed from the pool before this connection is
///   registered. The dev server then receives its projects, and the owning
///   user receives an updated dev-server project status.
///
/// # Errors
///
/// Returns the first error from a send or a database query; the remaining
/// steps are skipped in that case.
async fn send_initial_client_update(
    &self,
    connection_id: ConnectionId,
    principal: &Principal,
    zed_version: ZedVersion,
    mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
    session: &Session,
) -> Result<()> {
    self.peer.send(
        connection_id,
        proto::Hello {
            peer_id: Some(connection_id.into()),
        },
    )?;
    tracing::info!("sent hello message");
    // Best effort: the receiving end may already be gone, which is not an error.
    if let Some(send_connection_id) = send_connection_id.take() {
        let _ = send_connection_id.send(connection_id);
    }

    match principal {
        Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
            // First connection ever: ask the client to show contacts, then
            // persist that the user has connected at least once.
            if !user.connected_once {
                self.peer.send(connection_id, proto::ShowContacts {})?;
                self.app_state
                    .db
                    .set_user_connected_once(user.id, true)
                    .await?;
            }

            // Both queries are independent, so run them concurrently.
            let (contacts, dev_server_projects) = future::try_join(
                self.app_state.db.get_contacts(user.id),
                self.app_state.db.dev_server_projects_update(user.id),
            )
            .await?;

            // Register the connection and build the contacts update under the
            // same pool lock; the pool is passed to
            // `build_initial_contacts_update` (presumably to compute online
            // status — confirm against that helper).
            {
                let mut pool = self.connection_pool.lock();
                pool.add_connection(connection_id, user.id, user.admin, zed_version);
                self.peer.send(
                    connection_id,
                    build_initial_contacts_update(contacts, &pool),
                )?;
            }

            if should_auto_subscribe_to_channels(zed_version) {
                subscribe_user_to_channels(user.id, session).await?;
            }

            send_dev_server_projects_update(user.id, dev_server_projects, session).await;

            // Replay a call that was ringing before this connection existed.
            if let Some(incoming_call) =
                self.app_state.db.incoming_call_for_user(user.id).await?
            {
                self.peer.send(connection_id, incoming_call)?;
            }

            update_user_contacts(user.id, &session).await?;
        }
        Principal::DevServer(dev_server) => {
            // Only one connection per dev server is kept: shut down and evict
            // the previous one before registering this connection.
            {
                let mut pool = self.connection_pool.lock();
                if let Some(stale_connection_id) = pool.dev_server_connection_id(dev_server.id)
                {
                    self.peer.send(
                        stale_connection_id,
                        proto::ShutdownDevServer {
                            reason: Some(
                                "another dev server connected with the same token".to_string(),
                            ),
                        },
                    )?;
                    pool.remove_connection(stale_connection_id)?;
                };
                pool.add_dev_server(connection_id, dev_server.id, zed_version);
            }

            // Tell the dev server which projects it should host.
            let projects = self
                .app_state
                .db
                .get_projects_for_dev_server(dev_server.id)
                .await?;
            self.peer
                .send(connection_id, proto::DevServerInstructions { projects })?;

            // Let the owning user see the dev server's refreshed project status.
            let status = self
                .app_state
                .db
                .dev_server_projects_update(dev_server.user_id)
                .await?;
            send_dev_server_projects_update(dev_server.user_id, status, &session).await;
        }
    }

    Ok(())
}
1189
1190 pub async fn invite_code_redeemed(
1191 self: &Arc<Self>,
1192 inviter_id: UserId,
1193 invitee_id: UserId,
1194 ) -> Result<()> {
1195 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1196 if let Some(code) = &user.invite_code {
1197 let pool = self.connection_pool.lock();
1198 let invitee_contact = contact_for_user(invitee_id, false, &pool);
1199 for connection_id in pool.user_connection_ids(inviter_id) {
1200 self.peer.send(
1201 connection_id,
1202 proto::UpdateContacts {
1203 contacts: vec![invitee_contact.clone()],
1204 ..Default::default()
1205 },
1206 )?;
1207 self.peer.send(
1208 connection_id,
1209 proto::UpdateInviteInfo {
1210 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1211 count: user.invite_count as u32,
1212 },
1213 )?;
1214 }
1215 }
1216 }
1217 Ok(())
1218 }
1219
1220 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1221 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1222 if let Some(invite_code) = &user.invite_code {
1223 let pool = self.connection_pool.lock();
1224 for connection_id in pool.user_connection_ids(user_id) {
1225 self.peer.send(
1226 connection_id,
1227 proto::UpdateInviteInfo {
1228 url: format!(
1229 "{}{}",
1230 self.app_state.config.invite_link_prefix, invite_code
1231 ),
1232 count: user.invite_count as u32,
1233 },
1234 )?;
1235 }
1236 }
1237 }
1238 Ok(())
1239 }
1240
1241 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1242 ServerSnapshot {
1243 connection_pool: ConnectionPoolGuard {
1244 guard: self.connection_pool.lock(),
1245 _not_send: PhantomData,
1246 },
1247 peer: &self.peer,
1248 }
1249 }
1250}
1251
1252impl<'a> Deref for ConnectionPoolGuard<'a> {
1253 type Target = ConnectionPool;
1254
1255 fn deref(&self) -> &Self::Target {
1256 &self.guard
1257 }
1258}
1259
1260impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1261 fn deref_mut(&mut self) -> &mut Self::Target {
1262 &mut self.guard
1263 }
1264}
1265
/// In test builds, verifies the connection pool's invariants every time a
/// guard is released. Outside tests the body is compiled out and dropping
/// adds no extra work.
impl<'a> Drop for ConnectionPoolGuard<'a> {
    fn drop(&mut self) {
        // `#[cfg(test)]` removes this statement from non-test builds.
        #[cfg(test)]
        self.check_invariants();
    }
}
1272
1273fn broadcast<F>(
1274 sender_id: Option<ConnectionId>,
1275 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1276 mut f: F,
1277) where
1278 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1279{
1280 for receiver_id in receiver_ids {
1281 if Some(receiver_id) != sender_id {
1282 if let Err(error) = f(receiver_id) {
1283 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1284 }
1285 }
1286 }
1287}
1288
1289pub struct ProtocolVersion(u32);
1290
1291impl Header for ProtocolVersion {
1292 fn name() -> &'static HeaderName {
1293 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1294 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1295 }
1296
1297 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1298 where
1299 Self: Sized,
1300 I: Iterator<Item = &'i axum::http::HeaderValue>,
1301 {
1302 let version = values
1303 .next()
1304 .ok_or_else(axum::headers::Error::invalid)?
1305 .to_str()
1306 .map_err(|_| axum::headers::Error::invalid())?
1307 .parse()
1308 .map_err(|_| axum::headers::Error::invalid())?;
1309 Ok(Self(version))
1310 }
1311
1312 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1313 values.extend([self.0.to_string().parse().unwrap()]);
1314 }
1315}
1316
1317pub struct AppVersionHeader(SemanticVersion);
1318impl Header for AppVersionHeader {
1319 fn name() -> &'static HeaderName {
1320 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1321 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1322 }
1323
1324 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1325 where
1326 Self: Sized,
1327 I: Iterator<Item = &'i axum::http::HeaderValue>,
1328 {
1329 let version = values
1330 .next()
1331 .ok_or_else(axum::headers::Error::invalid)?
1332 .to_str()
1333 .map_err(|_| axum::headers::Error::invalid())?
1334 .parse()
1335 .map_err(|_| axum::headers::Error::invalid())?;
1336 Ok(Self(version))
1337 }
1338
1339 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1340 values.extend([self.0.to_string().parse().unwrap()]);
1341 }
1342}
1343
/// Builds the router for the collab server's RPC and metrics endpoints.
///
/// The order of `route`/`layer` calls matters: a `layer` only wraps routes
/// registered before it.
/// * `/rpc` is wrapped by the `AppState` extension and the
///   `auth::validate_header` middleware.
/// * `/metrics` is added after that layer, so it is not behind the auth
///   middleware.
/// * The final `Extension(server)` layer wraps both routes, so both handlers
///   can extract `Arc<Server>`.
pub fn routes(server: Arc<Server>) -> Router<(), Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        // Applies to `/rpc` only (the sole route registered so far).
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        // Applies to every route above.
        .layer(Extension(server))
}
1355
1356pub async fn handle_websocket_request(
1357 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1358 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1359 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1360 Extension(server): Extension<Arc<Server>>,
1361 Extension(principal): Extension<Principal>,
1362 ws: WebSocketUpgrade,
1363) -> axum::response::Response {
1364 if protocol_version != rpc::PROTOCOL_VERSION {
1365 return (
1366 StatusCode::UPGRADE_REQUIRED,
1367 "client must be upgraded".to_string(),
1368 )
1369 .into_response();
1370 }
1371
1372 let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1373 return (
1374 StatusCode::UPGRADE_REQUIRED,
1375 "no version header found".to_string(),
1376 )
1377 .into_response();
1378 };
1379
1380 if !version.can_collaborate() {
1381 return (
1382 StatusCode::UPGRADE_REQUIRED,
1383 "client must be upgraded".to_string(),
1384 )
1385 .into_response();
1386 }
1387
1388 let socket_address = socket_address.to_string();
1389 ws.on_upgrade(move |socket| {
1390 let socket = socket
1391 .map_ok(to_tungstenite_message)
1392 .err_into()
1393 .with(|message| async move { Ok(to_axum_message(message)) });
1394 let connection = Connection::new(Box::pin(socket));
1395 async move {
1396 server
1397 .handle_connection(
1398 connection,
1399 socket_address,
1400 principal,
1401 version,
1402 None,
1403 Executor::Production,
1404 )
1405 .await;
1406 }
1407 })
1408}
1409
1410pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1411 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1412 let connections_metric = CONNECTIONS_METRIC
1413 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1414
1415 let connections = server
1416 .connection_pool
1417 .lock()
1418 .connections()
1419 .filter(|connection| !connection.admin)
1420 .count();
1421 connections_metric.set(connections as _);
1422
1423 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1424 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1425 register_int_gauge!(
1426 "shared_projects",
1427 "number of open projects with one or more guests"
1428 )
1429 .unwrap()
1430 });
1431
1432 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1433 shared_projects_metric.set(shared_projects as _);
1434
1435 let encoder = prometheus::TextEncoder::new();
1436 let metric_families = prometheus::gather();
1437 let encoded_metrics = encoder
1438 .encode_to_string(&metric_families)
1439 .map_err(|err| anyhow!("{}", err))?;
1440 Ok(encoded_metrics)
1441}
1442
/// Cleans up after a client connection ends.
///
/// Immediate steps:
/// * disconnect the peer;
/// * remove the connection from the pool;
/// * record the loss in the database (failures only logged).
///
/// It then waits `RECONNECT_TIMEOUT` (presumably a grace period for the
/// client to reconnect — confirm against the constant's definition). If
/// `teardown` changes first, the deferred cleanup is skipped entirely.
///
/// Deferred cleanup for users (plain or impersonated):
/// * leave the room and channel buffers for this connection;
/// * decline any pending call when the user has no other connection online;
/// * update the user's contacts.
///
/// Deferred cleanup for dev servers: run `lost_dev_server_connection`.
///
/// # Errors
///
/// Fails if the connection cannot be removed from the pool, or if the final
/// contacts/dev-server update fails. Errors are also recorded on the tracing
/// span via `#[instrument(err)]`.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls the timeout branch first on each wake-up.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            match &session.principal {
                Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
                    // Expected to succeed: the principal was just matched as a user.
                    let session = session.for_user().unwrap();

                    log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
                    leave_room_for_session(&session, session.connection_id).await.trace_err();
                    leave_channel_buffers_for_session(&session)
                        .await
                        .trace_err();

                    // Only decline a ringing call when no other connection of
                    // this user is still online.
                    if !session
                        .connection_pool()
                        .await
                        .is_user_online(session.user_id())
                    {
                        let db = session.db().await;
                        if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                            room_updated(&room, &session.peer);
                        }
                    }

                    update_user_contacts(session.user_id(), &session).await?;
                },
                Principal::DevServer(_) => {
                    // Expected to succeed: the principal was just matched as a dev server.
                    lost_dev_server_connection(&session.for_dev_server().unwrap()).await?;
                },
            }
        },
        // Server is shutting down: skip the deferred cleanup.
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1497
1498/// Acknowledges a ping from a client, used to keep the connection alive.
1499async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1500 response.send(proto::Ack {})?;
1501 Ok(())
1502}
1503
1504/// Creates a new room for calling (outside of channels)
1505async fn create_room(
1506 _request: proto::CreateRoom,
1507 response: Response<proto::CreateRoom>,
1508 session: UserSession,
1509) -> Result<()> {
1510 let live_kit_room = nanoid::nanoid!(30);
1511
1512 let live_kit_connection_info = util::maybe!(async {
1513 let live_kit = session.live_kit_client.as_ref();
1514 let live_kit = live_kit?;
1515 let user_id = session.user_id().to_string();
1516
1517 let token = live_kit
1518 .room_token(&live_kit_room, &user_id.to_string())
1519 .trace_err()?;
1520
1521 Some(proto::LiveKitConnectionInfo {
1522 server_url: live_kit.url().into(),
1523 token,
1524 can_publish: true,
1525 })
1526 })
1527 .await;
1528
1529 let room = session
1530 .db()
1531 .await
1532 .create_room(session.user_id(), session.connection_id, &live_kit_room)
1533 .await?;
1534
1535 response.send(proto::CreateRoomResponse {
1536 room: Some(room.clone()),
1537 live_kit_connection_info,
1538 })?;
1539
1540 update_user_contacts(session.user_id(), &session).await?;
1541 Ok(())
1542}
1543
1544/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1545async fn join_room(
1546 request: proto::JoinRoom,
1547 response: Response<proto::JoinRoom>,
1548 session: UserSession,
1549) -> Result<()> {
1550 let room_id = RoomId::from_proto(request.id);
1551
1552 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1553
1554 if let Some(channel_id) = channel_id {
1555 return join_channel_internal(channel_id, Box::new(response), session).await;
1556 }
1557
1558 let joined_room = {
1559 let room = session
1560 .db()
1561 .await
1562 .join_room(room_id, session.user_id(), session.connection_id)
1563 .await?;
1564 room_updated(&room.room, &session.peer);
1565 room.into_inner()
1566 };
1567
1568 for connection_id in session
1569 .connection_pool()
1570 .await
1571 .user_connection_ids(session.user_id())
1572 {
1573 session
1574 .peer
1575 .send(
1576 connection_id,
1577 proto::CallCanceled {
1578 room_id: room_id.to_proto(),
1579 },
1580 )
1581 .trace_err();
1582 }
1583
1584 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1585 if let Some(token) = live_kit
1586 .room_token(
1587 &joined_room.room.live_kit_room,
1588 &session.user_id().to_string(),
1589 )
1590 .trace_err()
1591 {
1592 Some(proto::LiveKitConnectionInfo {
1593 server_url: live_kit.url().into(),
1594 token,
1595 can_publish: true,
1596 })
1597 } else {
1598 None
1599 }
1600 } else {
1601 None
1602 };
1603
1604 response.send(proto::JoinRoomResponse {
1605 room: Some(joined_room.room),
1606 channel_id: None,
1607 live_kit_connection_info,
1608 })?;
1609
1610 update_user_contacts(session.user_id(), &session).await?;
1611 Ok(())
1612}
1613
/// Reconnects the caller to a room after a connection error.
///
/// Inside the inner scope:
/// 1. the database rejoins the room for the caller's new connection;
/// 2. the caller receives the room, the projects it re-shares (with their
///    collaborators), and the projects it rejoins;
/// 3. the room update is broadcast;
/// 4. for each re-shared project, collaborators are told the caller's peer id
///    changed from the old connection to the current one, and are forwarded
///    the project's worktree metadata;
/// 5. `notify_rejoined_projects` resyncs the rejoined projects.
///
/// After the scope, channel members are notified if the room belongs to a
/// channel, and the caller's contacts are updated.
///
/// # Errors
///
/// Fails on a database error, on failure to send the response, or on errors
/// from `notify_rejoined_projects` / `update_user_contacts`. Per-collaborator
/// send failures are only logged.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: UserSession,
) -> Result<()> {
    let room;
    let channel;
    // The value returned by `rejoin_room` is confined to this scope;
    // `into_inner` unwraps it so only the room and channel escape to the
    // code that follows (which awaits and locks the connection pool).
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Re-point collaborators from the caller's old peer id to the new one.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the project's worktree metadata to everyone but the caller.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1705
1706fn notify_rejoined_projects(
1707 rejoined_projects: &mut Vec<RejoinedProject>,
1708 session: &UserSession,
1709) -> Result<()> {
1710 for project in rejoined_projects.iter() {
1711 for collaborator in &project.collaborators {
1712 session
1713 .peer
1714 .send(
1715 collaborator.connection_id,
1716 proto::UpdateProjectCollaborator {
1717 project_id: project.id.to_proto(),
1718 old_peer_id: Some(project.old_connection_id.into()),
1719 new_peer_id: Some(session.connection_id.into()),
1720 },
1721 )
1722 .trace_err();
1723 }
1724 }
1725
1726 for project in rejoined_projects {
1727 for worktree in mem::take(&mut project.worktrees) {
1728 #[cfg(any(test, feature = "test-support"))]
1729 const MAX_CHUNK_SIZE: usize = 2;
1730 #[cfg(not(any(test, feature = "test-support")))]
1731 const MAX_CHUNK_SIZE: usize = 256;
1732
1733 // Stream this worktree's entries.
1734 let message = proto::UpdateWorktree {
1735 project_id: project.id.to_proto(),
1736 worktree_id: worktree.id,
1737 abs_path: worktree.abs_path.clone(),
1738 root_name: worktree.root_name,
1739 updated_entries: worktree.updated_entries,
1740 removed_entries: worktree.removed_entries,
1741 scan_id: worktree.scan_id,
1742 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1743 updated_repositories: worktree.updated_repositories,
1744 removed_repositories: worktree.removed_repositories,
1745 };
1746 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1747 session.peer.send(session.connection_id, update.clone())?;
1748 }
1749
1750 // Stream this worktree's diagnostics.
1751 for summary in worktree.diagnostic_summaries {
1752 session.peer.send(
1753 session.connection_id,
1754 proto::UpdateDiagnosticSummary {
1755 project_id: project.id.to_proto(),
1756 worktree_id: worktree.id,
1757 summary: Some(summary),
1758 },
1759 )?;
1760 }
1761
1762 for settings_file in worktree.settings_files {
1763 session.peer.send(
1764 session.connection_id,
1765 proto::UpdateWorktreeSettings {
1766 project_id: project.id.to_proto(),
1767 worktree_id: worktree.id,
1768 path: settings_file.path,
1769 content: Some(settings_file.content),
1770 },
1771 )?;
1772 }
1773 }
1774
1775 for language_server in &project.language_servers {
1776 session.peer.send(
1777 session.connection_id,
1778 proto::UpdateLanguageServer {
1779 project_id: project.id.to_proto(),
1780 language_server_id: language_server.id,
1781 variant: Some(
1782 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1783 proto::LspDiskBasedDiagnosticsUpdated {},
1784 ),
1785 ),
1786 },
1787 )?;
1788 }
1789 }
1790 Ok(())
1791}
1792
1793/// leave room disconnects from the room.
1794async fn leave_room(
1795 _: proto::LeaveRoom,
1796 response: Response<proto::LeaveRoom>,
1797 session: UserSession,
1798) -> Result<()> {
1799 leave_room_for_session(&session, session.connection_id).await?;
1800 response.send(proto::Ack {})?;
1801 Ok(())
1802}
1803
1804/// Updates the permissions of someone else in the room.
1805async fn set_room_participant_role(
1806 request: proto::SetRoomParticipantRole,
1807 response: Response<proto::SetRoomParticipantRole>,
1808 session: UserSession,
1809) -> Result<()> {
1810 let user_id = UserId::from_proto(request.user_id);
1811 let role = ChannelRole::from(request.role());
1812
1813 let (live_kit_room, can_publish) = {
1814 let room = session
1815 .db()
1816 .await
1817 .set_room_participant_role(
1818 session.user_id(),
1819 RoomId::from_proto(request.room_id),
1820 user_id,
1821 role,
1822 )
1823 .await?;
1824
1825 let live_kit_room = room.live_kit_room.clone();
1826 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1827 room_updated(&room, &session.peer);
1828 (live_kit_room, can_publish)
1829 };
1830
1831 if let Some(live_kit) = session.live_kit_client.as_ref() {
1832 live_kit
1833 .update_participant(
1834 live_kit_room.clone(),
1835 request.user_id.to_string(),
1836 live_kit_server::proto::ParticipantPermission {
1837 can_subscribe: true,
1838 can_publish,
1839 can_publish_data: can_publish,
1840 hidden: false,
1841 recorder: false,
1842 },
1843 )
1844 .await
1845 .trace_err();
1846 }
1847
1848 response.send(proto::Ack {})?;
1849 Ok(())
1850}
1851
1852/// Call someone else into the current room
1853async fn call(
1854 request: proto::Call,
1855 response: Response<proto::Call>,
1856 session: UserSession,
1857) -> Result<()> {
1858 let room_id = RoomId::from_proto(request.room_id);
1859 let calling_user_id = session.user_id();
1860 let calling_connection_id = session.connection_id;
1861 let called_user_id = UserId::from_proto(request.called_user_id);
1862 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1863 if !session
1864 .db()
1865 .await
1866 .has_contact(calling_user_id, called_user_id)
1867 .await?
1868 {
1869 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1870 }
1871
1872 let incoming_call = {
1873 let (room, incoming_call) = &mut *session
1874 .db()
1875 .await
1876 .call(
1877 room_id,
1878 calling_user_id,
1879 calling_connection_id,
1880 called_user_id,
1881 initial_project_id,
1882 )
1883 .await?;
1884 room_updated(&room, &session.peer);
1885 mem::take(incoming_call)
1886 };
1887 update_user_contacts(called_user_id, &session).await?;
1888
1889 let mut calls = session
1890 .connection_pool()
1891 .await
1892 .user_connection_ids(called_user_id)
1893 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1894 .collect::<FuturesUnordered<_>>();
1895
1896 while let Some(call_response) = calls.next().await {
1897 match call_response.as_ref() {
1898 Ok(_) => {
1899 response.send(proto::Ack {})?;
1900 return Ok(());
1901 }
1902 Err(_) => {
1903 call_response.trace_err();
1904 }
1905 }
1906 }
1907
1908 {
1909 let room = session
1910 .db()
1911 .await
1912 .call_failed(room_id, called_user_id)
1913 .await?;
1914 room_updated(&room, &session.peer);
1915 }
1916 update_user_contacts(called_user_id, &session).await?;
1917
1918 Err(anyhow!("failed to ring user"))?
1919}
1920
1921/// Cancel an outgoing call.
1922async fn cancel_call(
1923 request: proto::CancelCall,
1924 response: Response<proto::CancelCall>,
1925 session: UserSession,
1926) -> Result<()> {
1927 let called_user_id = UserId::from_proto(request.called_user_id);
1928 let room_id = RoomId::from_proto(request.room_id);
1929 {
1930 let room = session
1931 .db()
1932 .await
1933 .cancel_call(room_id, session.connection_id, called_user_id)
1934 .await?;
1935 room_updated(&room, &session.peer);
1936 }
1937
1938 for connection_id in session
1939 .connection_pool()
1940 .await
1941 .user_connection_ids(called_user_id)
1942 {
1943 session
1944 .peer
1945 .send(
1946 connection_id,
1947 proto::CallCanceled {
1948 room_id: room_id.to_proto(),
1949 },
1950 )
1951 .trace_err();
1952 }
1953 response.send(proto::Ack {})?;
1954
1955 update_user_contacts(called_user_id, &session).await?;
1956 Ok(())
1957}
1958
1959/// Decline an incoming call.
1960async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1961 let room_id = RoomId::from_proto(message.room_id);
1962 {
1963 let room = session
1964 .db()
1965 .await
1966 .decline_call(Some(room_id), session.user_id())
1967 .await?
1968 .ok_or_else(|| anyhow!("failed to decline call"))?;
1969 room_updated(&room, &session.peer);
1970 }
1971
1972 for connection_id in session
1973 .connection_pool()
1974 .await
1975 .user_connection_ids(session.user_id())
1976 {
1977 session
1978 .peer
1979 .send(
1980 connection_id,
1981 proto::CallCanceled {
1982 room_id: room_id.to_proto(),
1983 },
1984 )
1985 .trace_err();
1986 }
1987 update_user_contacts(session.user_id(), &session).await?;
1988 Ok(())
1989}
1990
1991/// Updates other participants in the room with your current location.
1992async fn update_participant_location(
1993 request: proto::UpdateParticipantLocation,
1994 response: Response<proto::UpdateParticipantLocation>,
1995 session: UserSession,
1996) -> Result<()> {
1997 let room_id = RoomId::from_proto(request.room_id);
1998 let location = request
1999 .location
2000 .ok_or_else(|| anyhow!("invalid location"))?;
2001
2002 let db = session.db().await;
2003 let room = db
2004 .update_room_participant_location(room_id, session.connection_id, location)
2005 .await?;
2006
2007 room_updated(&room, &session.peer);
2008 response.send(proto::Ack {})?;
2009 Ok(())
2010}
2011
2012/// Share a project into the room.
2013async fn share_project(
2014 request: proto::ShareProject,
2015 response: Response<proto::ShareProject>,
2016 session: UserSession,
2017) -> Result<()> {
2018 let (project_id, room) = &*session
2019 .db()
2020 .await
2021 .share_project(
2022 RoomId::from_proto(request.room_id),
2023 session.connection_id,
2024 &request.worktrees,
2025 request
2026 .dev_server_project_id
2027 .map(|id| DevServerProjectId::from_proto(id)),
2028 )
2029 .await?;
2030 response.send(proto::ShareProjectResponse {
2031 project_id: project_id.to_proto(),
2032 })?;
2033 room_updated(&room, &session.peer);
2034
2035 Ok(())
2036}
2037
2038/// Unshare a project from the room.
2039async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
2040 let project_id = ProjectId::from_proto(message.project_id);
2041 unshare_project_internal(
2042 project_id,
2043 session.connection_id,
2044 session.user_id(),
2045 &session,
2046 )
2047 .await
2048}
2049
/// Stops sharing `project_id` on behalf of `connection_id` / `user_id`.
///
/// The database returns three things:
/// * a `delete` flag;
/// * the affected room, if any;
/// * the guest connection ids.
///
/// Every guest connection other than `connection_id` is sent
/// `UnshareProject`, and the room update is broadcast when there is a room.
/// If `delete` is set, the project is deleted afterwards in a separate
/// database call.
///
/// # Errors
///
/// Fails if either database call fails. Per-guest send failures are only
/// logged by `broadcast`.
async fn unshare_project_internal(
    project_id: ProjectId,
    connection_id: ConnectionId,
    user_id: Option<UserId>,
    session: &Session,
) -> Result<()> {
    // Scoped so `room_guard` is dropped before `delete_project` reacquires
    // the database below.
    let delete = {
        let room_guard = session
            .db()
            .await
            .unshare_project(project_id, connection_id, user_id)
            .await?;

        let (delete, room, guest_connection_ids) = &*room_guard;

        let message = proto::UnshareProject {
            project_id: project_id.to_proto(),
        };

        broadcast(
            Some(connection_id),
            guest_connection_ids.iter().copied(),
            |conn_id| session.peer.send(conn_id, message.clone()),
        );
        if let Some(room) = room {
            room_updated(room, &session.peer);
        }

        *delete
    };

    if delete {
        let db = session.db().await;
        db.delete_project(project_id).await?;
    }

    Ok(())
}
2088
2089/// DevServer makes a project available online
2090async fn share_dev_server_project(
2091 request: proto::ShareDevServerProject,
2092 response: Response<proto::ShareDevServerProject>,
2093 session: DevServerSession,
2094) -> Result<()> {
2095 let (dev_server_project, user_id, status) = session
2096 .db()
2097 .await
2098 .share_dev_server_project(
2099 DevServerProjectId::from_proto(request.dev_server_project_id),
2100 session.dev_server_id(),
2101 session.connection_id,
2102 &request.worktrees,
2103 )
2104 .await?;
2105 let Some(project_id) = dev_server_project.project_id else {
2106 return Err(anyhow!("failed to share remote project"))?;
2107 };
2108
2109 send_dev_server_projects_update(user_id, status, &session).await;
2110
2111 response.send(proto::ShareProjectResponse { project_id })?;
2112
2113 Ok(())
2114}
2115
/// Joins someone else's shared project.
///
/// The join is recorded in the database, and the database handle (`db`) is
/// then dropped explicitly. `project` and `replica_id` stay valid afterwards
/// because the `&mut *…` borrow in the `let` extends the join result's
/// lifetime to the end of the function. The rest of the work happens
/// synchronously in `join_project_internal`.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: UserSession,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    // Temporary lifetime extension keeps the join result alive after `db` is dropped.
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    drop(db);
    // NOTE(review): this message says "remote" even for projects that are
    // not dev-server projects; consider rewording.
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
2134
/// Response handle for any request that replies with a `JoinProjectResponse`.
///
/// Lets `join_project_internal` serve both `JoinProject` and
/// `JoinHostedProject` requests.
trait JoinProjectInternalResponse {
    /// Sends `result` as the reply to the originating request.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Fully qualified call to the inherent `Response::send`, making it
        // explicit that this does not recurse into the trait method.
        Response::<proto::JoinProject>::send(self, result)
    }
}
impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Fully qualified call to the inherent `Response::send`, making it
        // explicit that this does not recurse into the trait method.
        Response::<proto::JoinHostedProject>::send(self, result)
    }
}
2148
2149fn join_project_internal(
2150 response: impl JoinProjectInternalResponse,
2151 session: UserSession,
2152 project: &mut Project,
2153 replica_id: &ReplicaId,
2154) -> Result<()> {
2155 let collaborators = project
2156 .collaborators
2157 .iter()
2158 .filter(|collaborator| collaborator.connection_id != session.connection_id)
2159 .map(|collaborator| collaborator.to_proto())
2160 .collect::<Vec<_>>();
2161 let project_id = project.id;
2162 let guest_user_id = session.user_id();
2163
2164 let worktrees = project
2165 .worktrees
2166 .iter()
2167 .map(|(id, worktree)| proto::WorktreeMetadata {
2168 id: *id,
2169 root_name: worktree.root_name.clone(),
2170 visible: worktree.visible,
2171 abs_path: worktree.abs_path.clone(),
2172 })
2173 .collect::<Vec<_>>();
2174
2175 let add_project_collaborator = proto::AddProjectCollaborator {
2176 project_id: project_id.to_proto(),
2177 collaborator: Some(proto::Collaborator {
2178 peer_id: Some(session.connection_id.into()),
2179 replica_id: replica_id.0 as u32,
2180 user_id: guest_user_id.to_proto(),
2181 }),
2182 };
2183
2184 for collaborator in &collaborators {
2185 session
2186 .peer
2187 .send(
2188 collaborator.peer_id.unwrap().into(),
2189 add_project_collaborator.clone(),
2190 )
2191 .trace_err();
2192 }
2193
2194 // First, we send the metadata associated with each worktree.
2195 response.send(proto::JoinProjectResponse {
2196 project_id: project.id.0 as u64,
2197 worktrees: worktrees.clone(),
2198 replica_id: replica_id.0 as u32,
2199 collaborators: collaborators.clone(),
2200 language_servers: project.language_servers.clone(),
2201 role: project.role.into(),
2202 dev_server_project_id: project
2203 .dev_server_project_id
2204 .map(|dev_server_project_id| dev_server_project_id.0 as u64),
2205 })?;
2206
2207 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2208 #[cfg(any(test, feature = "test-support"))]
2209 const MAX_CHUNK_SIZE: usize = 2;
2210 #[cfg(not(any(test, feature = "test-support")))]
2211 const MAX_CHUNK_SIZE: usize = 256;
2212
2213 // Stream this worktree's entries.
2214 let message = proto::UpdateWorktree {
2215 project_id: project_id.to_proto(),
2216 worktree_id,
2217 abs_path: worktree.abs_path.clone(),
2218 root_name: worktree.root_name,
2219 updated_entries: worktree.entries,
2220 removed_entries: Default::default(),
2221 scan_id: worktree.scan_id,
2222 is_last_update: worktree.scan_id == worktree.completed_scan_id,
2223 updated_repositories: worktree.repository_entries.into_values().collect(),
2224 removed_repositories: Default::default(),
2225 };
2226 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
2227 session.peer.send(session.connection_id, update.clone())?;
2228 }
2229
2230 // Stream this worktree's diagnostics.
2231 for summary in worktree.diagnostic_summaries {
2232 session.peer.send(
2233 session.connection_id,
2234 proto::UpdateDiagnosticSummary {
2235 project_id: project_id.to_proto(),
2236 worktree_id: worktree.id,
2237 summary: Some(summary),
2238 },
2239 )?;
2240 }
2241
2242 for settings_file in worktree.settings_files {
2243 session.peer.send(
2244 session.connection_id,
2245 proto::UpdateWorktreeSettings {
2246 project_id: project_id.to_proto(),
2247 worktree_id: worktree.id,
2248 path: settings_file.path,
2249 content: Some(settings_file.content),
2250 },
2251 )?;
2252 }
2253 }
2254
2255 for language_server in &project.language_servers {
2256 session.peer.send(
2257 session.connection_id,
2258 proto::UpdateLanguageServer {
2259 project_id: project_id.to_proto(),
2260 language_server_id: language_server.id,
2261 variant: Some(
2262 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2263 proto::LspDiskBasedDiagnosticsUpdated {},
2264 ),
2265 ),
2266 },
2267 )?;
2268 }
2269
2270 Ok(())
2271}
2272
2273/// Leave someone elses shared project.
2274async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2275 let sender_id = session.connection_id;
2276 let project_id = ProjectId::from_proto(request.project_id);
2277 let db = session.db().await;
2278 if db.is_hosted_project(project_id).await? {
2279 let project = db.leave_hosted_project(project_id, sender_id).await?;
2280 project_left(&project, &session);
2281 return Ok(());
2282 }
2283
2284 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2285 tracing::info!(
2286 %project_id,
2287 "leave project"
2288 );
2289
2290 project_left(&project, &session);
2291 if let Some(room) = room {
2292 room_updated(&room, &session.peer);
2293 }
2294
2295 Ok(())
2296}
2297
2298async fn join_hosted_project(
2299 request: proto::JoinHostedProject,
2300 response: Response<proto::JoinHostedProject>,
2301 session: UserSession,
2302) -> Result<()> {
2303 let (mut project, replica_id) = session
2304 .db()
2305 .await
2306 .join_hosted_project(
2307 ProjectId(request.project_id as i32),
2308 session.user_id(),
2309 session.connection_id,
2310 )
2311 .await?;
2312
2313 join_project_internal(response, session, &mut project, &replica_id)
2314}
2315
2316async fn create_dev_server_project(
2317 request: proto::CreateDevServerProject,
2318 response: Response<proto::CreateDevServerProject>,
2319 session: UserSession,
2320) -> Result<()> {
2321 let dev_server_id = DevServerId(request.dev_server_id as i32);
2322 let dev_server_connection_id = session
2323 .connection_pool()
2324 .await
2325 .dev_server_connection_id(dev_server_id);
2326 let Some(dev_server_connection_id) = dev_server_connection_id else {
2327 Err(ErrorCode::DevServerOffline
2328 .message("Cannot create a remote project when the dev server is offline".to_string())
2329 .anyhow())?
2330 };
2331
2332 let path = request.path.clone();
2333 //Check that the path exists on the dev server
2334 session
2335 .peer
2336 .forward_request(
2337 session.connection_id,
2338 dev_server_connection_id,
2339 proto::ValidateDevServerProjectRequest { path: path.clone() },
2340 )
2341 .await?;
2342
2343 let (dev_server_project, update) = session
2344 .db()
2345 .await
2346 .create_dev_server_project(
2347 DevServerId(request.dev_server_id as i32),
2348 &request.path,
2349 session.user_id(),
2350 )
2351 .await?;
2352
2353 let projects = session
2354 .db()
2355 .await
2356 .get_projects_for_dev_server(dev_server_project.dev_server_id)
2357 .await?;
2358
2359 session.peer.send(
2360 dev_server_connection_id,
2361 proto::DevServerInstructions { projects },
2362 )?;
2363
2364 send_dev_server_projects_update(session.user_id(), update, &session).await;
2365
2366 response.send(proto::CreateDevServerProjectResponse {
2367 dev_server_project: Some(dev_server_project.to_proto(None)),
2368 })?;
2369 Ok(())
2370}
2371
2372async fn create_dev_server(
2373 request: proto::CreateDevServer,
2374 response: Response<proto::CreateDevServer>,
2375 session: UserSession,
2376) -> Result<()> {
2377 let access_token = auth::random_token();
2378 let hashed_access_token = auth::hash_access_token(&access_token);
2379
2380 if request.name.is_empty() {
2381 return Err(proto::ErrorCode::Forbidden
2382 .message("Dev server name cannot be empty".to_string())
2383 .anyhow())?;
2384 }
2385
2386 let (dev_server, status) = session
2387 .db()
2388 .await
2389 .create_dev_server(
2390 &request.name,
2391 request.ssh_connection_string.as_deref(),
2392 &hashed_access_token,
2393 session.user_id(),
2394 )
2395 .await?;
2396
2397 send_dev_server_projects_update(session.user_id(), status, &session).await;
2398
2399 response.send(proto::CreateDevServerResponse {
2400 dev_server_id: dev_server.id.0 as u64,
2401 access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2402 name: request.name,
2403 })?;
2404 Ok(())
2405}
2406
2407async fn regenerate_dev_server_token(
2408 request: proto::RegenerateDevServerToken,
2409 response: Response<proto::RegenerateDevServerToken>,
2410 session: UserSession,
2411) -> Result<()> {
2412 let dev_server_id = DevServerId(request.dev_server_id as i32);
2413 let access_token = auth::random_token();
2414 let hashed_access_token = auth::hash_access_token(&access_token);
2415
2416 let connection_id = session
2417 .connection_pool()
2418 .await
2419 .dev_server_connection_id(dev_server_id);
2420 if let Some(connection_id) = connection_id {
2421 shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2422 session.peer.send(
2423 connection_id,
2424 proto::ShutdownDevServer {
2425 reason: Some("dev server token was regenerated".to_string()),
2426 },
2427 )?;
2428 let _ = remove_dev_server_connection(dev_server_id, &session).await;
2429 }
2430
2431 let status = session
2432 .db()
2433 .await
2434 .update_dev_server_token(dev_server_id, &hashed_access_token, session.user_id())
2435 .await?;
2436
2437 send_dev_server_projects_update(session.user_id(), status, &session).await;
2438
2439 response.send(proto::RegenerateDevServerTokenResponse {
2440 dev_server_id: dev_server_id.to_proto(),
2441 access_token: auth::generate_dev_server_token(dev_server_id.0 as usize, access_token),
2442 })?;
2443 Ok(())
2444}
2445
2446async fn rename_dev_server(
2447 request: proto::RenameDevServer,
2448 response: Response<proto::RenameDevServer>,
2449 session: UserSession,
2450) -> Result<()> {
2451 if request.name.trim().is_empty() {
2452 return Err(proto::ErrorCode::Forbidden
2453 .message("Dev server name cannot be empty".to_string())
2454 .anyhow())?;
2455 }
2456
2457 let dev_server_id = DevServerId(request.dev_server_id as i32);
2458 let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2459 if dev_server.user_id != session.user_id() {
2460 return Err(anyhow!(ErrorCode::Forbidden))?;
2461 }
2462
2463 let status = session
2464 .db()
2465 .await
2466 .rename_dev_server(
2467 dev_server_id,
2468 &request.name,
2469 request.ssh_connection_string.as_deref(),
2470 session.user_id(),
2471 )
2472 .await?;
2473
2474 send_dev_server_projects_update(session.user_id(), status, &session).await;
2475
2476 response.send(proto::Ack {})?;
2477 Ok(())
2478}
2479
2480async fn delete_dev_server(
2481 request: proto::DeleteDevServer,
2482 response: Response<proto::DeleteDevServer>,
2483 session: UserSession,
2484) -> Result<()> {
2485 let dev_server_id = DevServerId(request.dev_server_id as i32);
2486 let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2487 if dev_server.user_id != session.user_id() {
2488 return Err(anyhow!(ErrorCode::Forbidden))?;
2489 }
2490
2491 let connection_id = session
2492 .connection_pool()
2493 .await
2494 .dev_server_connection_id(dev_server_id);
2495 if let Some(connection_id) = connection_id {
2496 shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2497 session.peer.send(
2498 connection_id,
2499 proto::ShutdownDevServer {
2500 reason: Some("dev server was deleted".to_string()),
2501 },
2502 )?;
2503 let _ = remove_dev_server_connection(dev_server_id, &session).await;
2504 }
2505
2506 let status = session
2507 .db()
2508 .await
2509 .delete_dev_server(dev_server_id, session.user_id())
2510 .await?;
2511
2512 send_dev_server_projects_update(session.user_id(), status, &session).await;
2513
2514 response.send(proto::Ack {})?;
2515 Ok(())
2516}
2517
2518async fn delete_dev_server_project(
2519 request: proto::DeleteDevServerProject,
2520 response: Response<proto::DeleteDevServerProject>,
2521 session: UserSession,
2522) -> Result<()> {
2523 let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2524 let dev_server_project = session
2525 .db()
2526 .await
2527 .get_dev_server_project(dev_server_project_id)
2528 .await?;
2529
2530 let dev_server = session
2531 .db()
2532 .await
2533 .get_dev_server(dev_server_project.dev_server_id)
2534 .await?;
2535 if dev_server.user_id != session.user_id() {
2536 return Err(anyhow!(ErrorCode::Forbidden))?;
2537 }
2538
2539 let dev_server_connection_id = session
2540 .connection_pool()
2541 .await
2542 .dev_server_connection_id(dev_server.id);
2543
2544 if let Some(dev_server_connection_id) = dev_server_connection_id {
2545 let project = session
2546 .db()
2547 .await
2548 .find_dev_server_project(dev_server_project_id)
2549 .await;
2550 if let Ok(project) = project {
2551 unshare_project_internal(
2552 project.id,
2553 dev_server_connection_id,
2554 Some(session.user_id()),
2555 &session,
2556 )
2557 .await?;
2558 }
2559 }
2560
2561 let (projects, status) = session
2562 .db()
2563 .await
2564 .delete_dev_server_project(dev_server_project_id, dev_server.id, session.user_id())
2565 .await?;
2566
2567 if let Some(dev_server_connection_id) = dev_server_connection_id {
2568 session.peer.send(
2569 dev_server_connection_id,
2570 proto::DevServerInstructions { projects },
2571 )?;
2572 }
2573
2574 send_dev_server_projects_update(session.user_id(), status, &session).await;
2575
2576 response.send(proto::Ack {})?;
2577 Ok(())
2578}
2579
2580async fn rejoin_dev_server_projects(
2581 request: proto::RejoinRemoteProjects,
2582 response: Response<proto::RejoinRemoteProjects>,
2583 session: UserSession,
2584) -> Result<()> {
2585 let mut rejoined_projects = {
2586 let db = session.db().await;
2587 db.rejoin_dev_server_projects(
2588 &request.rejoined_projects,
2589 session.user_id(),
2590 session.0.connection_id,
2591 )
2592 .await?
2593 };
2594 response.send(proto::RejoinRemoteProjectsResponse {
2595 rejoined_projects: rejoined_projects
2596 .iter()
2597 .map(|project| project.to_proto())
2598 .collect(),
2599 })?;
2600 notify_rejoined_projects(&mut rejoined_projects, &session)
2601}
2602
2603async fn reconnect_dev_server(
2604 request: proto::ReconnectDevServer,
2605 response: Response<proto::ReconnectDevServer>,
2606 session: DevServerSession,
2607) -> Result<()> {
2608 let reshared_projects = {
2609 let db = session.db().await;
2610 db.reshare_dev_server_projects(
2611 &request.reshared_projects,
2612 session.dev_server_id(),
2613 session.0.connection_id,
2614 )
2615 .await?
2616 };
2617
2618 for project in &reshared_projects {
2619 for collaborator in &project.collaborators {
2620 session
2621 .peer
2622 .send(
2623 collaborator.connection_id,
2624 proto::UpdateProjectCollaborator {
2625 project_id: project.id.to_proto(),
2626 old_peer_id: Some(project.old_connection_id.into()),
2627 new_peer_id: Some(session.connection_id.into()),
2628 },
2629 )
2630 .trace_err();
2631 }
2632
2633 broadcast(
2634 Some(session.connection_id),
2635 project
2636 .collaborators
2637 .iter()
2638 .map(|collaborator| collaborator.connection_id),
2639 |connection_id| {
2640 session.peer.forward_send(
2641 session.connection_id,
2642 connection_id,
2643 proto::UpdateProject {
2644 project_id: project.id.to_proto(),
2645 worktrees: project.worktrees.clone(),
2646 },
2647 )
2648 },
2649 );
2650 }
2651
2652 response.send(proto::ReconnectDevServerResponse {
2653 reshared_projects: reshared_projects
2654 .iter()
2655 .map(|project| proto::ResharedProject {
2656 id: project.id.to_proto(),
2657 collaborators: project
2658 .collaborators
2659 .iter()
2660 .map(|collaborator| collaborator.to_proto())
2661 .collect(),
2662 })
2663 .collect(),
2664 })?;
2665
2666 Ok(())
2667}
2668
2669async fn shutdown_dev_server(
2670 _: proto::ShutdownDevServer,
2671 response: Response<proto::ShutdownDevServer>,
2672 session: DevServerSession,
2673) -> Result<()> {
2674 response.send(proto::Ack {})?;
2675 shutdown_dev_server_internal(session.dev_server_id(), session.connection_id, &session).await?;
2676 remove_dev_server_connection(session.dev_server_id(), &session).await
2677}
2678
2679async fn shutdown_dev_server_internal(
2680 dev_server_id: DevServerId,
2681 connection_id: ConnectionId,
2682 session: &Session,
2683) -> Result<()> {
2684 let (dev_server_projects, dev_server) = {
2685 let db = session.db().await;
2686 let dev_server_projects = db.get_projects_for_dev_server(dev_server_id).await?;
2687 let dev_server = db.get_dev_server(dev_server_id).await?;
2688 (dev_server_projects, dev_server)
2689 };
2690
2691 for project_id in dev_server_projects.iter().filter_map(|p| p.project_id) {
2692 unshare_project_internal(
2693 ProjectId::from_proto(project_id),
2694 connection_id,
2695 None,
2696 session,
2697 )
2698 .await?;
2699 }
2700
2701 session
2702 .connection_pool()
2703 .await
2704 .set_dev_server_offline(dev_server_id);
2705
2706 let status = session
2707 .db()
2708 .await
2709 .dev_server_projects_update(dev_server.user_id)
2710 .await?;
2711 send_dev_server_projects_update(dev_server.user_id, status, &session).await;
2712
2713 Ok(())
2714}
2715
2716async fn remove_dev_server_connection(dev_server_id: DevServerId, session: &Session) -> Result<()> {
2717 let dev_server_connection = session
2718 .connection_pool()
2719 .await
2720 .dev_server_connection_id(dev_server_id);
2721
2722 if let Some(dev_server_connection) = dev_server_connection {
2723 session
2724 .connection_pool()
2725 .await
2726 .remove_connection(dev_server_connection)?;
2727 }
2728 Ok(())
2729}
2730
2731/// Updates other participants with changes to the project
2732async fn update_project(
2733 request: proto::UpdateProject,
2734 response: Response<proto::UpdateProject>,
2735 session: Session,
2736) -> Result<()> {
2737 let project_id = ProjectId::from_proto(request.project_id);
2738 let (room, guest_connection_ids) = &*session
2739 .db()
2740 .await
2741 .update_project(project_id, session.connection_id, &request.worktrees)
2742 .await?;
2743 broadcast(
2744 Some(session.connection_id),
2745 guest_connection_ids.iter().copied(),
2746 |connection_id| {
2747 session
2748 .peer
2749 .forward_send(session.connection_id, connection_id, request.clone())
2750 },
2751 );
2752 if let Some(room) = room {
2753 room_updated(&room, &session.peer);
2754 }
2755 response.send(proto::Ack {})?;
2756
2757 Ok(())
2758}
2759
2760/// Updates other participants with changes to the worktree
2761async fn update_worktree(
2762 request: proto::UpdateWorktree,
2763 response: Response<proto::UpdateWorktree>,
2764 session: Session,
2765) -> Result<()> {
2766 let guest_connection_ids = session
2767 .db()
2768 .await
2769 .update_worktree(&request, session.connection_id)
2770 .await?;
2771
2772 broadcast(
2773 Some(session.connection_id),
2774 guest_connection_ids.iter().copied(),
2775 |connection_id| {
2776 session
2777 .peer
2778 .forward_send(session.connection_id, connection_id, request.clone())
2779 },
2780 );
2781 response.send(proto::Ack {})?;
2782 Ok(())
2783}
2784
2785/// Updates other participants with changes to the diagnostics
2786async fn update_diagnostic_summary(
2787 message: proto::UpdateDiagnosticSummary,
2788 session: Session,
2789) -> Result<()> {
2790 let guest_connection_ids = session
2791 .db()
2792 .await
2793 .update_diagnostic_summary(&message, session.connection_id)
2794 .await?;
2795
2796 broadcast(
2797 Some(session.connection_id),
2798 guest_connection_ids.iter().copied(),
2799 |connection_id| {
2800 session
2801 .peer
2802 .forward_send(session.connection_id, connection_id, message.clone())
2803 },
2804 );
2805
2806 Ok(())
2807}
2808
2809/// Updates other participants with changes to the worktree settings
2810async fn update_worktree_settings(
2811 message: proto::UpdateWorktreeSettings,
2812 session: Session,
2813) -> Result<()> {
2814 let guest_connection_ids = session
2815 .db()
2816 .await
2817 .update_worktree_settings(&message, session.connection_id)
2818 .await?;
2819
2820 broadcast(
2821 Some(session.connection_id),
2822 guest_connection_ids.iter().copied(),
2823 |connection_id| {
2824 session
2825 .peer
2826 .forward_send(session.connection_id, connection_id, message.clone())
2827 },
2828 );
2829
2830 Ok(())
2831}
2832
2833/// Notify other participants that a language server has started.
2834async fn start_language_server(
2835 request: proto::StartLanguageServer,
2836 session: Session,
2837) -> Result<()> {
2838 let guest_connection_ids = session
2839 .db()
2840 .await
2841 .start_language_server(&request, session.connection_id)
2842 .await?;
2843
2844 broadcast(
2845 Some(session.connection_id),
2846 guest_connection_ids.iter().copied(),
2847 |connection_id| {
2848 session
2849 .peer
2850 .forward_send(session.connection_id, connection_id, request.clone())
2851 },
2852 );
2853 Ok(())
2854}
2855
2856/// Notify other participants that a language server has changed.
2857async fn update_language_server(
2858 request: proto::UpdateLanguageServer,
2859 session: Session,
2860) -> Result<()> {
2861 let project_id = ProjectId::from_proto(request.project_id);
2862 let project_connection_ids = session
2863 .db()
2864 .await
2865 .project_connection_ids(project_id, session.connection_id, true)
2866 .await?;
2867 broadcast(
2868 Some(session.connection_id),
2869 project_connection_ids.iter().copied(),
2870 |connection_id| {
2871 session
2872 .peer
2873 .forward_send(session.connection_id, connection_id, request.clone())
2874 },
2875 );
2876 Ok(())
2877}
2878
2879/// forward a project request to the host. These requests should be read only
2880/// as guests are allowed to send them.
2881async fn forward_read_only_project_request<T>(
2882 request: T,
2883 response: Response<T>,
2884 session: UserSession,
2885) -> Result<()>
2886where
2887 T: EntityMessage + RequestMessage,
2888{
2889 let project_id = ProjectId::from_proto(request.remote_entity_id());
2890 let host_connection_id = session
2891 .db()
2892 .await
2893 .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2894 .await?;
2895 let payload = session
2896 .peer
2897 .forward_request(session.connection_id, host_connection_id, request)
2898 .await?;
2899 response.send(payload)?;
2900 Ok(())
2901}
2902
2903/// forward a project request to the dev server. Only allowed
2904/// if it's your dev server.
2905async fn forward_project_request_for_owner<T>(
2906 request: T,
2907 response: Response<T>,
2908 session: UserSession,
2909) -> Result<()>
2910where
2911 T: EntityMessage + RequestMessage,
2912{
2913 let project_id = ProjectId::from_proto(request.remote_entity_id());
2914
2915 let host_connection_id = session
2916 .db()
2917 .await
2918 .host_for_owner_project_request(project_id, session.connection_id, session.user_id())
2919 .await?;
2920 let payload = session
2921 .peer
2922 .forward_request(session.connection_id, host_connection_id, request)
2923 .await?;
2924 response.send(payload)?;
2925 Ok(())
2926}
2927
2928/// forward a project request to the host. These requests are disallowed
2929/// for guests.
2930async fn forward_mutating_project_request<T>(
2931 request: T,
2932 response: Response<T>,
2933 session: UserSession,
2934) -> Result<()>
2935where
2936 T: EntityMessage + RequestMessage,
2937{
2938 let project_id = ProjectId::from_proto(request.remote_entity_id());
2939
2940 let host_connection_id = session
2941 .db()
2942 .await
2943 .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
2944 .await?;
2945 let payload = session
2946 .peer
2947 .forward_request(session.connection_id, host_connection_id, request)
2948 .await?;
2949 response.send(payload)?;
2950 Ok(())
2951}
2952
2953/// forward a project request to the host. These requests are disallowed
2954/// for guests.
2955async fn forward_versioned_mutating_project_request<T>(
2956 request: T,
2957 response: Response<T>,
2958 session: UserSession,
2959) -> Result<()>
2960where
2961 T: EntityMessage + RequestMessage + VersionedMessage,
2962{
2963 let project_id = ProjectId::from_proto(request.remote_entity_id());
2964
2965 let host_connection_id = session
2966 .db()
2967 .await
2968 .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
2969 .await?;
2970 if let Some(host_version) = session
2971 .connection_pool()
2972 .await
2973 .connection(host_connection_id)
2974 .map(|c| c.zed_version)
2975 {
2976 if let Some(min_required_version) = request.required_host_version() {
2977 if min_required_version > host_version {
2978 return Err(anyhow!(ErrorCode::RemoteUpgradeRequired
2979 .with_tag("required", &min_required_version.to_string())))?;
2980 }
2981 }
2982 }
2983
2984 let payload = session
2985 .peer
2986 .forward_request(session.connection_id, host_connection_id, request)
2987 .await?;
2988 response.send(payload)?;
2989 Ok(())
2990}
2991
2992/// Notify other participants that a new buffer has been created
2993async fn create_buffer_for_peer(
2994 request: proto::CreateBufferForPeer,
2995 session: Session,
2996) -> Result<()> {
2997 session
2998 .db()
2999 .await
3000 .check_user_is_project_host(
3001 ProjectId::from_proto(request.project_id),
3002 session.connection_id,
3003 )
3004 .await?;
3005 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
3006 session
3007 .peer
3008 .forward_send(session.connection_id, peer_id.into(), request)?;
3009 Ok(())
3010}
3011
3012/// Notify other participants that a buffer has been updated. This is
3013/// allowed for guests as long as the update is limited to selections.
3014async fn update_buffer(
3015 request: proto::UpdateBuffer,
3016 response: Response<proto::UpdateBuffer>,
3017 session: Session,
3018) -> Result<()> {
3019 let project_id = ProjectId::from_proto(request.project_id);
3020 let mut capability = Capability::ReadOnly;
3021
3022 for op in request.operations.iter() {
3023 match op.variant {
3024 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
3025 Some(_) => capability = Capability::ReadWrite,
3026 }
3027 }
3028
3029 let host = {
3030 let guard = session
3031 .db()
3032 .await
3033 .connections_for_buffer_update(
3034 project_id,
3035 session.principal_id(),
3036 session.connection_id,
3037 capability,
3038 )
3039 .await?;
3040
3041 let (host, guests) = &*guard;
3042
3043 broadcast(
3044 Some(session.connection_id),
3045 guests.clone(),
3046 |connection_id| {
3047 session
3048 .peer
3049 .forward_send(session.connection_id, connection_id, request.clone())
3050 },
3051 );
3052
3053 *host
3054 };
3055
3056 if host != session.connection_id {
3057 session
3058 .peer
3059 .forward_request(session.connection_id, host, request.clone())
3060 .await?;
3061 }
3062
3063 response.send(proto::Ack {})?;
3064 Ok(())
3065}
3066
3067async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
3068 let project_id = ProjectId::from_proto(message.project_id);
3069
3070 let operation = message.operation.as_ref().context("invalid operation")?;
3071 let capability = match operation.variant.as_ref() {
3072 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
3073 if let Some(buffer_op) = buffer_op.operation.as_ref() {
3074 match buffer_op.variant {
3075 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
3076 Capability::ReadOnly
3077 }
3078 _ => Capability::ReadWrite,
3079 }
3080 } else {
3081 Capability::ReadWrite
3082 }
3083 }
3084 Some(_) => Capability::ReadWrite,
3085 None => Capability::ReadOnly,
3086 };
3087
3088 let guard = session
3089 .db()
3090 .await
3091 .connections_for_buffer_update(
3092 project_id,
3093 session.principal_id(),
3094 session.connection_id,
3095 capability,
3096 )
3097 .await?;
3098
3099 let (host, guests) = &*guard;
3100
3101 broadcast(
3102 Some(session.connection_id),
3103 guests.iter().chain([host]).copied(),
3104 |connection_id| {
3105 session
3106 .peer
3107 .forward_send(session.connection_id, connection_id, message.clone())
3108 },
3109 );
3110
3111 Ok(())
3112}
3113
3114/// Notify other participants that a project has been updated.
3115async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
3116 request: T,
3117 session: Session,
3118) -> Result<()> {
3119 let project_id = ProjectId::from_proto(request.remote_entity_id());
3120 let project_connection_ids = session
3121 .db()
3122 .await
3123 .project_connection_ids(project_id, session.connection_id, false)
3124 .await?;
3125
3126 broadcast(
3127 Some(session.connection_id),
3128 project_connection_ids.iter().copied(),
3129 |connection_id| {
3130 session
3131 .peer
3132 .forward_send(session.connection_id, connection_id, request.clone())
3133 },
3134 );
3135 Ok(())
3136}
3137
3138/// Start following another user in a call.
3139async fn follow(
3140 request: proto::Follow,
3141 response: Response<proto::Follow>,
3142 session: UserSession,
3143) -> Result<()> {
3144 let room_id = RoomId::from_proto(request.room_id);
3145 let project_id = request.project_id.map(ProjectId::from_proto);
3146 let leader_id = request
3147 .leader_id
3148 .ok_or_else(|| anyhow!("invalid leader id"))?
3149 .into();
3150 let follower_id = session.connection_id;
3151
3152 session
3153 .db()
3154 .await
3155 .check_room_participants(room_id, leader_id, session.connection_id)
3156 .await?;
3157
3158 let response_payload = session
3159 .peer
3160 .forward_request(session.connection_id, leader_id, request)
3161 .await?;
3162 response.send(response_payload)?;
3163
3164 if let Some(project_id) = project_id {
3165 let room = session
3166 .db()
3167 .await
3168 .follow(room_id, project_id, leader_id, follower_id)
3169 .await?;
3170 room_updated(&room, &session.peer);
3171 }
3172
3173 Ok(())
3174}
3175
3176/// Stop following another user in a call.
3177async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
3178 let room_id = RoomId::from_proto(request.room_id);
3179 let project_id = request.project_id.map(ProjectId::from_proto);
3180 let leader_id = request
3181 .leader_id
3182 .ok_or_else(|| anyhow!("invalid leader id"))?
3183 .into();
3184 let follower_id = session.connection_id;
3185
3186 session
3187 .db()
3188 .await
3189 .check_room_participants(room_id, leader_id, session.connection_id)
3190 .await?;
3191
3192 session
3193 .peer
3194 .forward_send(session.connection_id, leader_id, request)?;
3195
3196 if let Some(project_id) = project_id {
3197 let room = session
3198 .db()
3199 .await
3200 .unfollow(room_id, project_id, leader_id, follower_id)
3201 .await?;
3202 room_updated(&room, &session.peer);
3203 }
3204
3205 Ok(())
3206}
3207
3208/// Notify everyone following you of your current location.
3209async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
3210 let room_id = RoomId::from_proto(request.room_id);
3211 let database = session.db.lock().await;
3212
3213 let connection_ids = if let Some(project_id) = request.project_id {
3214 let project_id = ProjectId::from_proto(project_id);
3215 database
3216 .project_connection_ids(project_id, session.connection_id, true)
3217 .await?
3218 } else {
3219 database
3220 .room_connection_ids(room_id, session.connection_id)
3221 .await?
3222 };
3223
3224 // For now, don't send view update messages back to that view's current leader.
3225 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
3226 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
3227 _ => None,
3228 });
3229
3230 for connection_id in connection_ids.iter().cloned() {
3231 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
3232 session
3233 .peer
3234 .forward_send(session.connection_id, connection_id, request.clone())?;
3235 }
3236 }
3237 Ok(())
3238}
3239
3240/// Get public data about users.
3241async fn get_users(
3242 request: proto::GetUsers,
3243 response: Response<proto::GetUsers>,
3244 session: Session,
3245) -> Result<()> {
3246 let user_ids = request
3247 .user_ids
3248 .into_iter()
3249 .map(UserId::from_proto)
3250 .collect();
3251 let users = session
3252 .db()
3253 .await
3254 .get_users_by_ids(user_ids)
3255 .await?
3256 .into_iter()
3257 .map(|user| proto::User {
3258 id: user.id.to_proto(),
3259 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3260 github_login: user.github_login,
3261 })
3262 .collect();
3263 response.send(proto::UsersResponse { users })?;
3264 Ok(())
3265}
3266
3267/// Search for users (to invite) buy Github login
3268async fn fuzzy_search_users(
3269 request: proto::FuzzySearchUsers,
3270 response: Response<proto::FuzzySearchUsers>,
3271 session: UserSession,
3272) -> Result<()> {
3273 let query = request.query;
3274 let users = match query.len() {
3275 0 => vec![],
3276 1 | 2 => session
3277 .db()
3278 .await
3279 .get_user_by_github_login(&query)
3280 .await?
3281 .into_iter()
3282 .collect(),
3283 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
3284 };
3285 let users = users
3286 .into_iter()
3287 .filter(|user| user.id != session.user_id())
3288 .map(|user| proto::User {
3289 id: user.id.to_proto(),
3290 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3291 github_login: user.github_login,
3292 })
3293 .collect();
3294 response.send(proto::UsersResponse { users })?;
3295 Ok(())
3296}
3297
3298/// Send a contact request to another user.
3299async fn request_contact(
3300 request: proto::RequestContact,
3301 response: Response<proto::RequestContact>,
3302 session: UserSession,
3303) -> Result<()> {
3304 let requester_id = session.user_id();
3305 let responder_id = UserId::from_proto(request.responder_id);
3306 if requester_id == responder_id {
3307 return Err(anyhow!("cannot add yourself as a contact"))?;
3308 }
3309
3310 let notifications = session
3311 .db()
3312 .await
3313 .send_contact_request(requester_id, responder_id)
3314 .await?;
3315
3316 // Update outgoing contact requests of requester
3317 let mut update = proto::UpdateContacts::default();
3318 update.outgoing_requests.push(responder_id.to_proto());
3319 for connection_id in session
3320 .connection_pool()
3321 .await
3322 .user_connection_ids(requester_id)
3323 {
3324 session.peer.send(connection_id, update.clone())?;
3325 }
3326
3327 // Update incoming contact requests of responder
3328 let mut update = proto::UpdateContacts::default();
3329 update
3330 .incoming_requests
3331 .push(proto::IncomingContactRequest {
3332 requester_id: requester_id.to_proto(),
3333 });
3334 let connection_pool = session.connection_pool().await;
3335 for connection_id in connection_pool.user_connection_ids(responder_id) {
3336 session.peer.send(connection_id, update.clone())?;
3337 }
3338
3339 send_notifications(&connection_pool, &session.peer, notifications);
3340
3341 response.send(proto::Ack {})?;
3342 Ok(())
3343}
3344
3345/// Accept or decline a contact request
3346async fn respond_to_contact_request(
3347 request: proto::RespondToContactRequest,
3348 response: Response<proto::RespondToContactRequest>,
3349 session: UserSession,
3350) -> Result<()> {
3351 let responder_id = session.user_id();
3352 let requester_id = UserId::from_proto(request.requester_id);
3353 let db = session.db().await;
3354 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
3355 db.dismiss_contact_notification(responder_id, requester_id)
3356 .await?;
3357 } else {
3358 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
3359
3360 let notifications = db
3361 .respond_to_contact_request(responder_id, requester_id, accept)
3362 .await?;
3363 let requester_busy = db.is_user_busy(requester_id).await?;
3364 let responder_busy = db.is_user_busy(responder_id).await?;
3365
3366 let pool = session.connection_pool().await;
3367 // Update responder with new contact
3368 let mut update = proto::UpdateContacts::default();
3369 if accept {
3370 update
3371 .contacts
3372 .push(contact_for_user(requester_id, requester_busy, &pool));
3373 }
3374 update
3375 .remove_incoming_requests
3376 .push(requester_id.to_proto());
3377 for connection_id in pool.user_connection_ids(responder_id) {
3378 session.peer.send(connection_id, update.clone())?;
3379 }
3380
3381 // Update requester with new contact
3382 let mut update = proto::UpdateContacts::default();
3383 if accept {
3384 update
3385 .contacts
3386 .push(contact_for_user(responder_id, responder_busy, &pool));
3387 }
3388 update
3389 .remove_outgoing_requests
3390 .push(responder_id.to_proto());
3391
3392 for connection_id in pool.user_connection_ids(requester_id) {
3393 session.peer.send(connection_id, update.clone())?;
3394 }
3395
3396 send_notifications(&pool, &session.peer, notifications);
3397 }
3398
3399 response.send(proto::Ack {})?;
3400 Ok(())
3401}
3402
3403/// Remove a contact.
3404async fn remove_contact(
3405 request: proto::RemoveContact,
3406 response: Response<proto::RemoveContact>,
3407 session: UserSession,
3408) -> Result<()> {
3409 let requester_id = session.user_id();
3410 let responder_id = UserId::from_proto(request.user_id);
3411 let db = session.db().await;
3412 let (contact_accepted, deleted_notification_id) =
3413 db.remove_contact(requester_id, responder_id).await?;
3414
3415 let pool = session.connection_pool().await;
3416 // Update outgoing contact requests of requester
3417 let mut update = proto::UpdateContacts::default();
3418 if contact_accepted {
3419 update.remove_contacts.push(responder_id.to_proto());
3420 } else {
3421 update
3422 .remove_outgoing_requests
3423 .push(responder_id.to_proto());
3424 }
3425 for connection_id in pool.user_connection_ids(requester_id) {
3426 session.peer.send(connection_id, update.clone())?;
3427 }
3428
3429 // Update incoming contact requests of responder
3430 let mut update = proto::UpdateContacts::default();
3431 if contact_accepted {
3432 update.remove_contacts.push(requester_id.to_proto());
3433 } else {
3434 update
3435 .remove_incoming_requests
3436 .push(requester_id.to_proto());
3437 }
3438 for connection_id in pool.user_connection_ids(responder_id) {
3439 session.peer.send(connection_id, update.clone())?;
3440 if let Some(notification_id) = deleted_notification_id {
3441 session.peer.send(
3442 connection_id,
3443 proto::DeleteNotification {
3444 notification_id: notification_id.to_proto(),
3445 },
3446 )?;
3447 }
3448 }
3449
3450 response.send(proto::Ack {})?;
3451 Ok(())
3452}
3453
3454fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
3455 version.0.minor() < 139
3456}
3457
3458async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
3459 subscribe_user_to_channels(
3460 session.user_id().ok_or_else(|| anyhow!("must be a user"))?,
3461 &session,
3462 )
3463 .await?;
3464 Ok(())
3465}
3466
3467async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
3468 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3469 let mut pool = session.connection_pool().await;
3470 for membership in &channels_for_user.channel_memberships {
3471 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3472 }
3473 session.peer.send(
3474 session.connection_id,
3475 build_update_user_channels(&channels_for_user),
3476 )?;
3477 session.peer.send(
3478 session.connection_id,
3479 build_channels_update(channels_for_user),
3480 )?;
3481 Ok(())
3482}
3483
3484/// Creates a new channel.
3485async fn create_channel(
3486 request: proto::CreateChannel,
3487 response: Response<proto::CreateChannel>,
3488 session: UserSession,
3489) -> Result<()> {
3490 let db = session.db().await;
3491
3492 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
3493 let (channel, membership) = db
3494 .create_channel(&request.name, parent_id, session.user_id())
3495 .await?;
3496
3497 let root_id = channel.root_id();
3498 let channel = Channel::from_model(channel);
3499
3500 response.send(proto::CreateChannelResponse {
3501 channel: Some(channel.to_proto()),
3502 parent_id: request.parent_id,
3503 })?;
3504
3505 let mut connection_pool = session.connection_pool().await;
3506 if let Some(membership) = membership {
3507 connection_pool.subscribe_to_channel(
3508 membership.user_id,
3509 membership.channel_id,
3510 membership.role,
3511 );
3512 let update = proto::UpdateUserChannels {
3513 channel_memberships: vec![proto::ChannelMembership {
3514 channel_id: membership.channel_id.to_proto(),
3515 role: membership.role.into(),
3516 }],
3517 ..Default::default()
3518 };
3519 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3520 session.peer.send(connection_id, update.clone())?;
3521 }
3522 }
3523
3524 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3525 if !role.can_see_channel(channel.visibility) {
3526 continue;
3527 }
3528
3529 let update = proto::UpdateChannels {
3530 channels: vec![channel.to_proto()],
3531 ..Default::default()
3532 };
3533 session.peer.send(connection_id, update.clone())?;
3534 }
3535
3536 Ok(())
3537}
3538
3539/// Delete a channel
3540async fn delete_channel(
3541 request: proto::DeleteChannel,
3542 response: Response<proto::DeleteChannel>,
3543 session: UserSession,
3544) -> Result<()> {
3545 let db = session.db().await;
3546
3547 let channel_id = request.channel_id;
3548 let (root_channel, removed_channels) = db
3549 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3550 .await?;
3551 response.send(proto::Ack {})?;
3552
3553 // Notify members of removed channels
3554 let mut update = proto::UpdateChannels::default();
3555 update
3556 .delete_channels
3557 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3558
3559 let connection_pool = session.connection_pool().await;
3560 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3561 session.peer.send(connection_id, update.clone())?;
3562 }
3563
3564 Ok(())
3565}
3566
3567/// Invite someone to join a channel.
3568async fn invite_channel_member(
3569 request: proto::InviteChannelMember,
3570 response: Response<proto::InviteChannelMember>,
3571 session: UserSession,
3572) -> Result<()> {
3573 let db = session.db().await;
3574 let channel_id = ChannelId::from_proto(request.channel_id);
3575 let invitee_id = UserId::from_proto(request.user_id);
3576 let InviteMemberResult {
3577 channel,
3578 notifications,
3579 } = db
3580 .invite_channel_member(
3581 channel_id,
3582 invitee_id,
3583 session.user_id(),
3584 request.role().into(),
3585 )
3586 .await?;
3587
3588 let update = proto::UpdateChannels {
3589 channel_invitations: vec![channel.to_proto()],
3590 ..Default::default()
3591 };
3592
3593 let connection_pool = session.connection_pool().await;
3594 for connection_id in connection_pool.user_connection_ids(invitee_id) {
3595 session.peer.send(connection_id, update.clone())?;
3596 }
3597
3598 send_notifications(&connection_pool, &session.peer, notifications);
3599
3600 response.send(proto::Ack {})?;
3601 Ok(())
3602}
3603
3604/// remove someone from a channel
3605async fn remove_channel_member(
3606 request: proto::RemoveChannelMember,
3607 response: Response<proto::RemoveChannelMember>,
3608 session: UserSession,
3609) -> Result<()> {
3610 let db = session.db().await;
3611 let channel_id = ChannelId::from_proto(request.channel_id);
3612 let member_id = UserId::from_proto(request.user_id);
3613
3614 let RemoveChannelMemberResult {
3615 membership_update,
3616 notification_id,
3617 } = db
3618 .remove_channel_member(channel_id, member_id, session.user_id())
3619 .await?;
3620
3621 let mut connection_pool = session.connection_pool().await;
3622 notify_membership_updated(
3623 &mut connection_pool,
3624 membership_update,
3625 member_id,
3626 &session.peer,
3627 );
3628 for connection_id in connection_pool.user_connection_ids(member_id) {
3629 if let Some(notification_id) = notification_id {
3630 session
3631 .peer
3632 .send(
3633 connection_id,
3634 proto::DeleteNotification {
3635 notification_id: notification_id.to_proto(),
3636 },
3637 )
3638 .trace_err();
3639 }
3640 }
3641
3642 response.send(proto::Ack {})?;
3643 Ok(())
3644}
3645
3646/// Toggle the channel between public and private.
3647/// Care is taken to maintain the invariant that public channels only descend from public channels,
3648/// (though members-only channels can appear at any point in the hierarchy).
3649async fn set_channel_visibility(
3650 request: proto::SetChannelVisibility,
3651 response: Response<proto::SetChannelVisibility>,
3652 session: UserSession,
3653) -> Result<()> {
3654 let db = session.db().await;
3655 let channel_id = ChannelId::from_proto(request.channel_id);
3656 let visibility = request.visibility().into();
3657
3658 let channel_model = db
3659 .set_channel_visibility(channel_id, visibility, session.user_id())
3660 .await?;
3661 let root_id = channel_model.root_id();
3662 let channel = Channel::from_model(channel_model);
3663
3664 let mut connection_pool = session.connection_pool().await;
3665 for (user_id, role) in connection_pool
3666 .channel_user_ids(root_id)
3667 .collect::<Vec<_>>()
3668 .into_iter()
3669 {
3670 let update = if role.can_see_channel(channel.visibility) {
3671 connection_pool.subscribe_to_channel(user_id, channel_id, role);
3672 proto::UpdateChannels {
3673 channels: vec![channel.to_proto()],
3674 ..Default::default()
3675 }
3676 } else {
3677 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3678 proto::UpdateChannels {
3679 delete_channels: vec![channel.id.to_proto()],
3680 ..Default::default()
3681 }
3682 };
3683
3684 for connection_id in connection_pool.user_connection_ids(user_id) {
3685 session.peer.send(connection_id, update.clone())?;
3686 }
3687 }
3688
3689 response.send(proto::Ack {})?;
3690 Ok(())
3691}
3692
3693/// Alter the role for a user in the channel.
3694async fn set_channel_member_role(
3695 request: proto::SetChannelMemberRole,
3696 response: Response<proto::SetChannelMemberRole>,
3697 session: UserSession,
3698) -> Result<()> {
3699 let db = session.db().await;
3700 let channel_id = ChannelId::from_proto(request.channel_id);
3701 let member_id = UserId::from_proto(request.user_id);
3702 let result = db
3703 .set_channel_member_role(
3704 channel_id,
3705 session.user_id(),
3706 member_id,
3707 request.role().into(),
3708 )
3709 .await?;
3710
3711 match result {
3712 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3713 let mut connection_pool = session.connection_pool().await;
3714 notify_membership_updated(
3715 &mut connection_pool,
3716 membership_update,
3717 member_id,
3718 &session.peer,
3719 )
3720 }
3721 db::SetMemberRoleResult::InviteUpdated(channel) => {
3722 let update = proto::UpdateChannels {
3723 channel_invitations: vec![channel.to_proto()],
3724 ..Default::default()
3725 };
3726
3727 for connection_id in session
3728 .connection_pool()
3729 .await
3730 .user_connection_ids(member_id)
3731 {
3732 session.peer.send(connection_id, update.clone())?;
3733 }
3734 }
3735 }
3736
3737 response.send(proto::Ack {})?;
3738 Ok(())
3739}
3740
3741/// Change the name of a channel
3742async fn rename_channel(
3743 request: proto::RenameChannel,
3744 response: Response<proto::RenameChannel>,
3745 session: UserSession,
3746) -> Result<()> {
3747 let db = session.db().await;
3748 let channel_id = ChannelId::from_proto(request.channel_id);
3749 let channel_model = db
3750 .rename_channel(channel_id, session.user_id(), &request.name)
3751 .await?;
3752 let root_id = channel_model.root_id();
3753 let channel = Channel::from_model(channel_model);
3754
3755 response.send(proto::RenameChannelResponse {
3756 channel: Some(channel.to_proto()),
3757 })?;
3758
3759 let connection_pool = session.connection_pool().await;
3760 let update = proto::UpdateChannels {
3761 channels: vec![channel.to_proto()],
3762 ..Default::default()
3763 };
3764 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3765 if role.can_see_channel(channel.visibility) {
3766 session.peer.send(connection_id, update.clone())?;
3767 }
3768 }
3769
3770 Ok(())
3771}
3772
3773/// Move a channel to a new parent.
3774async fn move_channel(
3775 request: proto::MoveChannel,
3776 response: Response<proto::MoveChannel>,
3777 session: UserSession,
3778) -> Result<()> {
3779 let channel_id = ChannelId::from_proto(request.channel_id);
3780 let to = ChannelId::from_proto(request.to);
3781
3782 let (root_id, channels) = session
3783 .db()
3784 .await
3785 .move_channel(channel_id, to, session.user_id())
3786 .await?;
3787
3788 let connection_pool = session.connection_pool().await;
3789 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3790 let channels = channels
3791 .iter()
3792 .filter_map(|channel| {
3793 if role.can_see_channel(channel.visibility) {
3794 Some(channel.to_proto())
3795 } else {
3796 None
3797 }
3798 })
3799 .collect::<Vec<_>>();
3800 if channels.is_empty() {
3801 continue;
3802 }
3803
3804 let update = proto::UpdateChannels {
3805 channels,
3806 ..Default::default()
3807 };
3808
3809 session.peer.send(connection_id, update.clone())?;
3810 }
3811
3812 response.send(Ack {})?;
3813 Ok(())
3814}
3815
3816/// Get the list of channel members
3817async fn get_channel_members(
3818 request: proto::GetChannelMembers,
3819 response: Response<proto::GetChannelMembers>,
3820 session: UserSession,
3821) -> Result<()> {
3822 let db = session.db().await;
3823 let channel_id = ChannelId::from_proto(request.channel_id);
3824 let limit = if request.limit == 0 {
3825 u16::MAX as u64
3826 } else {
3827 request.limit
3828 };
3829 let (members, users) = db
3830 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3831 .await?;
3832 response.send(proto::GetChannelMembersResponse { members, users })?;
3833 Ok(())
3834}
3835
3836/// Accept or decline a channel invitation.
3837async fn respond_to_channel_invite(
3838 request: proto::RespondToChannelInvite,
3839 response: Response<proto::RespondToChannelInvite>,
3840 session: UserSession,
3841) -> Result<()> {
3842 let db = session.db().await;
3843 let channel_id = ChannelId::from_proto(request.channel_id);
3844 let RespondToChannelInvite {
3845 membership_update,
3846 notifications,
3847 } = db
3848 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3849 .await?;
3850
3851 let mut connection_pool = session.connection_pool().await;
3852 if let Some(membership_update) = membership_update {
3853 notify_membership_updated(
3854 &mut connection_pool,
3855 membership_update,
3856 session.user_id(),
3857 &session.peer,
3858 );
3859 } else {
3860 let update = proto::UpdateChannels {
3861 remove_channel_invitations: vec![channel_id.to_proto()],
3862 ..Default::default()
3863 };
3864
3865 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3866 session.peer.send(connection_id, update.clone())?;
3867 }
3868 };
3869
3870 send_notifications(&connection_pool, &session.peer, notifications);
3871
3872 response.send(proto::Ack {})?;
3873
3874 Ok(())
3875}
3876
3877/// Join the channels' room
3878async fn join_channel(
3879 request: proto::JoinChannel,
3880 response: Response<proto::JoinChannel>,
3881 session: UserSession,
3882) -> Result<()> {
3883 let channel_id = ChannelId::from_proto(request.channel_id);
3884 join_channel_internal(channel_id, Box::new(response), session).await
3885}
3886
/// A response handle that can deliver a `proto::JoinRoomResponse`.
///
/// This lets `join_channel_internal` serve more than one request type:
/// `Response<proto::JoinChannel>` and `Response<proto::JoinRoom>` implement it
/// just below.
trait JoinChannelInternalResponse {
    /// Consume the handle and send `result` to the requester.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
/// Lets a `JoinChannel` response handle reply with a `JoinRoomResponse`.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Name `Response`'s own `send` explicitly. A bare `self.send(..)`
        // would read as (and could be mistaken for) a recursive call to this
        // trait method.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
/// Lets a `JoinRoom` response handle be passed to `join_channel_internal`.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Name `Response`'s own `send` explicitly. A bare `self.send(..)`
        // would read as (and could be mistaken for) a recursive call to this
        // trait method.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3900
/// Shared implementation for joining a channel's room. It is used with any
/// response handle implementing `JoinChannelInternalResponse` (the
/// `JoinChannel` and `JoinRoom` responses).
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, leave
///    that room first.
/// 2. Join the channel in the database. This returns the room, an optional
///    membership update, and the user's role.
/// 3. If a LiveKit client is configured, build connection info:
///    - guests get a guest token with `can_publish: false`;
///    - everyone else gets a room token with `can_publish: true`;
///    - if token creation fails, the error goes through `trace_err()` and the
///      reply carries no LiveKit info.
/// 4. Reply to the requester, propagate any membership update, and broadcast
///    the updated room.
/// 5. Once the database guard has been released, broadcast the channel update
///    and refresh the user's contacts.
///
/// # Errors
/// Propagates database, response and contact-update errors. Fails with
/// "channel not returned" if the joined room carries no channel.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: UserSession,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release our database guard before leaving the stale room.
            // `leave_room_for_session` receives the whole session and
            // presumably takes the database itself. Re-acquire afterwards.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // Guests may listen but not publish. `?` inside this closure yields
        // `None` (no LiveKit info) when token creation fails.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        // Leaving this block drops both the database and connection-pool
        // guards acquired above.
        joined_room
    };

    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3991
3992/// Start editing the channel notes
3993async fn join_channel_buffer(
3994 request: proto::JoinChannelBuffer,
3995 response: Response<proto::JoinChannelBuffer>,
3996 session: UserSession,
3997) -> Result<()> {
3998 let db = session.db().await;
3999 let channel_id = ChannelId::from_proto(request.channel_id);
4000
4001 let open_response = db
4002 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
4003 .await?;
4004
4005 let collaborators = open_response.collaborators.clone();
4006 response.send(open_response)?;
4007
4008 let update = UpdateChannelBufferCollaborators {
4009 channel_id: channel_id.to_proto(),
4010 collaborators: collaborators.clone(),
4011 };
4012 channel_buffer_updated(
4013 session.connection_id,
4014 collaborators
4015 .iter()
4016 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
4017 &update,
4018 &session.peer,
4019 );
4020
4021 Ok(())
4022}
4023
4024/// Edit the channel notes
4025async fn update_channel_buffer(
4026 request: proto::UpdateChannelBuffer,
4027 session: UserSession,
4028) -> Result<()> {
4029 let db = session.db().await;
4030 let channel_id = ChannelId::from_proto(request.channel_id);
4031
4032 let (collaborators, epoch, version) = db
4033 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
4034 .await?;
4035
4036 channel_buffer_updated(
4037 session.connection_id,
4038 collaborators.clone(),
4039 &proto::UpdateChannelBuffer {
4040 channel_id: channel_id.to_proto(),
4041 operations: request.operations,
4042 },
4043 &session.peer,
4044 );
4045
4046 let pool = &*session.connection_pool().await;
4047
4048 let non_collaborators =
4049 pool.channel_connection_ids(channel_id)
4050 .filter_map(|(connection_id, _)| {
4051 if collaborators.contains(&connection_id) {
4052 None
4053 } else {
4054 Some(connection_id)
4055 }
4056 });
4057
4058 broadcast(None, non_collaborators, |peer_id| {
4059 session.peer.send(
4060 peer_id,
4061 proto::UpdateChannels {
4062 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
4063 channel_id: channel_id.to_proto(),
4064 epoch: epoch as u64,
4065 version: version.clone(),
4066 }],
4067 ..Default::default()
4068 },
4069 )
4070 });
4071
4072 Ok(())
4073}
4074
4075/// Rejoin the channel notes after a connection blip
4076async fn rejoin_channel_buffers(
4077 request: proto::RejoinChannelBuffers,
4078 response: Response<proto::RejoinChannelBuffers>,
4079 session: UserSession,
4080) -> Result<()> {
4081 let db = session.db().await;
4082 let buffers = db
4083 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
4084 .await?;
4085
4086 for rejoined_buffer in &buffers {
4087 let collaborators_to_notify = rejoined_buffer
4088 .buffer
4089 .collaborators
4090 .iter()
4091 .filter_map(|c| Some(c.peer_id?.into()));
4092 channel_buffer_updated(
4093 session.connection_id,
4094 collaborators_to_notify,
4095 &proto::UpdateChannelBufferCollaborators {
4096 channel_id: rejoined_buffer.buffer.channel_id,
4097 collaborators: rejoined_buffer.buffer.collaborators.clone(),
4098 },
4099 &session.peer,
4100 );
4101 }
4102
4103 response.send(proto::RejoinChannelBuffersResponse {
4104 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
4105 })?;
4106
4107 Ok(())
4108}
4109
4110/// Stop editing the channel notes
4111async fn leave_channel_buffer(
4112 request: proto::LeaveChannelBuffer,
4113 response: Response<proto::LeaveChannelBuffer>,
4114 session: UserSession,
4115) -> Result<()> {
4116 let db = session.db().await;
4117 let channel_id = ChannelId::from_proto(request.channel_id);
4118
4119 let left_buffer = db
4120 .leave_channel_buffer(channel_id, session.connection_id)
4121 .await?;
4122
4123 response.send(Ack {})?;
4124
4125 channel_buffer_updated(
4126 session.connection_id,
4127 left_buffer.connections,
4128 &proto::UpdateChannelBufferCollaborators {
4129 channel_id: channel_id.to_proto(),
4130 collaborators: left_buffer.collaborators,
4131 },
4132 &session.peer,
4133 );
4134
4135 Ok(())
4136}
4137
4138fn channel_buffer_updated<T: EnvelopedMessage>(
4139 sender_id: ConnectionId,
4140 collaborators: impl IntoIterator<Item = ConnectionId>,
4141 message: &T,
4142 peer: &Peer,
4143) {
4144 broadcast(Some(sender_id), collaborators, |peer_id| {
4145 peer.send(peer_id, message.clone())
4146 });
4147}
4148
4149fn send_notifications(
4150 connection_pool: &ConnectionPool,
4151 peer: &Peer,
4152 notifications: db::NotificationBatch,
4153) {
4154 for (user_id, notification) in notifications {
4155 for connection_id in connection_pool.user_connection_ids(user_id) {
4156 if let Err(error) = peer.send(
4157 connection_id,
4158 proto::AddNotification {
4159 notification: Some(notification.clone()),
4160 },
4161 ) {
4162 tracing::error!(
4163 "failed to send notification to {:?} {}",
4164 connection_id,
4165 error
4166 );
4167 }
4168 }
4169 }
4170}
4171
4172/// Send a message to the channel
4173async fn send_channel_message(
4174 request: proto::SendChannelMessage,
4175 response: Response<proto::SendChannelMessage>,
4176 session: UserSession,
4177) -> Result<()> {
4178 // Validate the message body.
4179 let body = request.body.trim().to_string();
4180 if body.len() > MAX_MESSAGE_LEN {
4181 return Err(anyhow!("message is too long"))?;
4182 }
4183 if body.is_empty() {
4184 return Err(anyhow!("message can't be blank"))?;
4185 }
4186
4187 // TODO: adjust mentions if body is trimmed
4188
4189 let timestamp = OffsetDateTime::now_utc();
4190 let nonce = request
4191 .nonce
4192 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4193
4194 let channel_id = ChannelId::from_proto(request.channel_id);
4195 let CreatedChannelMessage {
4196 message_id,
4197 participant_connection_ids,
4198 notifications,
4199 } = session
4200 .db()
4201 .await
4202 .create_channel_message(
4203 channel_id,
4204 session.user_id(),
4205 &body,
4206 &request.mentions,
4207 timestamp,
4208 nonce.clone().into(),
4209 match request.reply_to_message_id {
4210 Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
4211 None => None,
4212 },
4213 )
4214 .await?;
4215
4216 let message = proto::ChannelMessage {
4217 sender_id: session.user_id().to_proto(),
4218 id: message_id.to_proto(),
4219 body,
4220 mentions: request.mentions,
4221 timestamp: timestamp.unix_timestamp() as u64,
4222 nonce: Some(nonce),
4223 reply_to_message_id: request.reply_to_message_id,
4224 edited_at: None,
4225 };
4226 broadcast(
4227 Some(session.connection_id),
4228 participant_connection_ids.clone(),
4229 |connection| {
4230 session.peer.send(
4231 connection,
4232 proto::ChannelMessageSent {
4233 channel_id: channel_id.to_proto(),
4234 message: Some(message.clone()),
4235 },
4236 )
4237 },
4238 );
4239 response.send(proto::SendChannelMessageResponse {
4240 message: Some(message),
4241 })?;
4242
4243 let pool = &*session.connection_pool().await;
4244 let non_participants =
4245 pool.channel_connection_ids(channel_id)
4246 .filter_map(|(connection_id, _)| {
4247 if participant_connection_ids.contains(&connection_id) {
4248 None
4249 } else {
4250 Some(connection_id)
4251 }
4252 });
4253 broadcast(None, non_participants, |peer_id| {
4254 session.peer.send(
4255 peer_id,
4256 proto::UpdateChannels {
4257 latest_channel_message_ids: vec![proto::ChannelMessageId {
4258 channel_id: channel_id.to_proto(),
4259 message_id: message_id.to_proto(),
4260 }],
4261 ..Default::default()
4262 },
4263 )
4264 });
4265 send_notifications(pool, &session.peer, notifications);
4266
4267 Ok(())
4268}
4269
4270/// Delete a channel message
4271async fn remove_channel_message(
4272 request: proto::RemoveChannelMessage,
4273 response: Response<proto::RemoveChannelMessage>,
4274 session: UserSession,
4275) -> Result<()> {
4276 let channel_id = ChannelId::from_proto(request.channel_id);
4277 let message_id = MessageId::from_proto(request.message_id);
4278 let (connection_ids, existing_notification_ids) = session
4279 .db()
4280 .await
4281 .remove_channel_message(channel_id, message_id, session.user_id())
4282 .await?;
4283
4284 broadcast(
4285 Some(session.connection_id),
4286 connection_ids,
4287 move |connection| {
4288 session.peer.send(connection, request.clone())?;
4289
4290 for notification_id in &existing_notification_ids {
4291 session.peer.send(
4292 connection,
4293 proto::DeleteNotification {
4294 notification_id: (*notification_id).to_proto(),
4295 },
4296 )?;
4297 }
4298
4299 Ok(())
4300 },
4301 );
4302 response.send(proto::Ack {})?;
4303 Ok(())
4304}
4305
4306async fn update_channel_message(
4307 request: proto::UpdateChannelMessage,
4308 response: Response<proto::UpdateChannelMessage>,
4309 session: UserSession,
4310) -> Result<()> {
4311 let channel_id = ChannelId::from_proto(request.channel_id);
4312 let message_id = MessageId::from_proto(request.message_id);
4313 let updated_at = OffsetDateTime::now_utc();
4314 let UpdatedChannelMessage {
4315 message_id,
4316 participant_connection_ids,
4317 notifications,
4318 reply_to_message_id,
4319 timestamp,
4320 deleted_mention_notification_ids,
4321 updated_mention_notifications,
4322 } = session
4323 .db()
4324 .await
4325 .update_channel_message(
4326 channel_id,
4327 message_id,
4328 session.user_id(),
4329 request.body.as_str(),
4330 &request.mentions,
4331 updated_at,
4332 )
4333 .await?;
4334
4335 let nonce = request
4336 .nonce
4337 .clone()
4338 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4339
4340 let message = proto::ChannelMessage {
4341 sender_id: session.user_id().to_proto(),
4342 id: message_id.to_proto(),
4343 body: request.body.clone(),
4344 mentions: request.mentions.clone(),
4345 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
4346 nonce: Some(nonce),
4347 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
4348 edited_at: Some(updated_at.unix_timestamp() as u64),
4349 };
4350
4351 response.send(proto::Ack {})?;
4352
4353 let pool = &*session.connection_pool().await;
4354 broadcast(
4355 Some(session.connection_id),
4356 participant_connection_ids,
4357 |connection| {
4358 session.peer.send(
4359 connection,
4360 proto::ChannelMessageUpdate {
4361 channel_id: channel_id.to_proto(),
4362 message: Some(message.clone()),
4363 },
4364 )?;
4365
4366 for notification_id in &deleted_mention_notification_ids {
4367 session.peer.send(
4368 connection,
4369 proto::DeleteNotification {
4370 notification_id: (*notification_id).to_proto(),
4371 },
4372 )?;
4373 }
4374
4375 for notification in &updated_mention_notifications {
4376 session.peer.send(
4377 connection,
4378 proto::UpdateNotification {
4379 notification: Some(notification.clone()),
4380 },
4381 )?;
4382 }
4383
4384 Ok(())
4385 },
4386 );
4387
4388 send_notifications(pool, &session.peer, notifications);
4389
4390 Ok(())
4391}
4392
4393/// Mark a channel message as read
4394async fn acknowledge_channel_message(
4395 request: proto::AckChannelMessage,
4396 session: UserSession,
4397) -> Result<()> {
4398 let channel_id = ChannelId::from_proto(request.channel_id);
4399 let message_id = MessageId::from_proto(request.message_id);
4400 let notifications = session
4401 .db()
4402 .await
4403 .observe_channel_message(channel_id, session.user_id(), message_id)
4404 .await?;
4405 send_notifications(
4406 &*session.connection_pool().await,
4407 &session.peer,
4408 notifications,
4409 );
4410 Ok(())
4411}
4412
4413/// Mark a buffer version as synced
4414async fn acknowledge_buffer_version(
4415 request: proto::AckBufferOperation,
4416 session: UserSession,
4417) -> Result<()> {
4418 let buffer_id = BufferId::from_proto(request.buffer_id);
4419 session
4420 .db()
4421 .await
4422 .observe_buffer_version(
4423 buffer_id,
4424 session.user_id(),
4425 request.epoch as i32,
4426 &request.version,
4427 )
4428 .await?;
4429 Ok(())
4430}
4431
4432struct CompleteWithLanguageModelRateLimit;
4433
4434impl RateLimit for CompleteWithLanguageModelRateLimit {
4435 fn capacity() -> usize {
4436 std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4437 .ok()
4438 .and_then(|v| v.parse().ok())
4439 .unwrap_or(120) // Picked arbitrarily
4440 }
4441
4442 fn refill_duration() -> chrono::Duration {
4443 chrono::Duration::hours(1)
4444 }
4445
4446 fn db_name() -> &'static str {
4447 "complete-with-language-model"
4448 }
4449}
4450
4451async fn complete_with_language_model(
4452 request: proto::CompleteWithLanguageModel,
4453 response: StreamingResponse<proto::CompleteWithLanguageModel>,
4454 session: Session,
4455 open_ai_api_key: Option<Arc<str>>,
4456 google_ai_api_key: Option<Arc<str>>,
4457 anthropic_api_key: Option<Arc<str>>,
4458) -> Result<()> {
4459 let Some(session) = session.for_user() else {
4460 return Err(anyhow!("user not found"))?;
4461 };
4462 authorize_access_to_language_models(&session).await?;
4463 session
4464 .rate_limiter
4465 .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
4466 .await?;
4467
4468 if request.model.starts_with("gpt") {
4469 let api_key =
4470 open_ai_api_key.ok_or_else(|| anyhow!("no OpenAI API key configured on the server"))?;
4471 complete_with_open_ai(request, response, session, api_key).await?;
4472 } else if request.model.starts_with("gemini") {
4473 let api_key = google_ai_api_key
4474 .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
4475 complete_with_google_ai(request, response, session, api_key).await?;
4476 } else if request.model.starts_with("claude") {
4477 let api_key = anthropic_api_key
4478 .ok_or_else(|| anyhow!("no Anthropic AI API key configured on the server"))?;
4479 complete_with_anthropic(request, response, session, api_key).await?;
4480 }
4481
4482 Ok(())
4483}
4484
/// Stream a chat completion from OpenAI and forward each streamed event to
/// the client as a `LanguageModelResponse`.
///
/// The proto request is converted with `crate::ai::language_model_request_to_open_ai`.
/// Each OpenAI choice becomes one `LanguageModelChoiceDelta`:
/// - OpenAI roles are mapped onto the proto `LanguageModelRole` enum.
/// - Content and finish reason are passed through unchanged.
/// - Tool-call deltas are converted to `proto::ToolCallDelta`.
///
/// Errors from the initial request (with added context), from any stream
/// event, or from sending to the client are returned.
async fn complete_with_open_ai(
    request: proto::CompleteWithLanguageModel,
    response: StreamingResponse<proto::CompleteWithLanguageModel>,
    session: UserSession,
    api_key: Arc<str>,
) -> Result<()> {
    let mut completion_stream = open_ai::stream_completion(
        session.http_client.as_ref(),
        OPEN_AI_API_URL,
        &api_key,
        crate::ai::language_model_request_to_open_ai(request)?,
        None,
    )
    .await
    .context("open_ai::stream_completion request failed within collab")?;

    while let Some(event) = completion_stream.next().await {
        let event = event?;
        response.send(proto::LanguageModelResponse {
            choices: event
                .choices
                .into_iter()
                .map(|choice| proto::LanguageModelChoiceDelta {
                    index: choice.index,
                    delta: Some(proto::LanguageModelResponseMessage {
                        // The role is optional in a delta; when present it is
                        // translated to the proto enum and sent as its i32 tag.
                        role: choice.delta.role.map(|role| match role {
                            open_ai::Role::User => LanguageModelRole::LanguageModelUser,
                            open_ai::Role::Assistant => LanguageModelRole::LanguageModelAssistant,
                            open_ai::Role::System => LanguageModelRole::LanguageModelSystem,
                            open_ai::Role::Tool => LanguageModelRole::LanguageModelTool,
                        } as i32),
                        content: choice.delta.content,
                        // A missing tool-call list is treated as empty.
                        tool_calls: choice
                            .delta
                            .tool_calls
                            .unwrap_or_default()
                            .into_iter()
                            .map(|delta| proto::ToolCallDelta {
                                index: delta.index as u32,
                                id: delta.id,
                                // Only function-call deltas carry a variant;
                                // deltas without `function` get `None`.
                                variant: match delta.function {
                                    Some(function) => {
                                        let name = function.name;
                                        let arguments = function.arguments;

                                        Some(proto::tool_call_delta::Variant::Function(
                                            proto::tool_call_delta::FunctionCallDelta {
                                                name,
                                                arguments,
                                            },
                                        ))
                                    }
                                    None => None,
                                },
                            })
                            .collect(),
                    }),
                    finish_reason: choice.finish_reason,
                })
                .collect(),
        })?;
    }

    Ok(())
}
4550
4551async fn complete_with_google_ai(
4552 request: proto::CompleteWithLanguageModel,
4553 response: StreamingResponse<proto::CompleteWithLanguageModel>,
4554 session: UserSession,
4555 api_key: Arc<str>,
4556) -> Result<()> {
4557 let mut stream = google_ai::stream_generate_content(
4558 session.http_client.clone(),
4559 google_ai::API_URL,
4560 api_key.as_ref(),
4561 &request.model.clone(),
4562 crate::ai::language_model_request_to_google_ai(request)?,
4563 )
4564 .await
4565 .context("google_ai::stream_generate_content request failed")?;
4566
4567 while let Some(event) = stream.next().await {
4568 let event = event?;
4569 response.send(proto::LanguageModelResponse {
4570 choices: event
4571 .candidates
4572 .unwrap_or_default()
4573 .into_iter()
4574 .map(|candidate| proto::LanguageModelChoiceDelta {
4575 index: candidate.index as u32,
4576 delta: Some(proto::LanguageModelResponseMessage {
4577 role: Some(match candidate.content.role {
4578 google_ai::Role::User => LanguageModelRole::LanguageModelUser,
4579 google_ai::Role::Model => LanguageModelRole::LanguageModelAssistant,
4580 } as i32),
4581 content: Some(
4582 candidate
4583 .content
4584 .parts
4585 .into_iter()
4586 .filter_map(|part| match part {
4587 google_ai::Part::TextPart(part) => Some(part.text),
4588 google_ai::Part::InlineDataPart(_) => None,
4589 })
4590 .collect(),
4591 ),
4592 // Tool calls are not supported for Google
4593 tool_calls: Vec::new(),
4594 }),
4595 finish_reason: candidate.finish_reason.map(|reason| reason.to_string()),
4596 })
4597 .collect(),
4598 })?;
4599 }
4600
4601 Ok(())
4602}
4603
/// Stream a chat completion from Anthropic and forward it to the client as
/// `LanguageModelResponse` deltas.
///
/// Proto messages are translated to Anthropic's format:
/// - User and assistant messages map one-to-one.
/// - System messages are concatenated, separated by blank lines, into
///   Anthropic's top-level `system` field.
/// - Tool messages are dropped.
///
/// `request.model` must be accepted by `anthropic::Model::from_id`; otherwise
/// its error is returned.
///
/// Streamed events are forwarded as single-choice (`index: 0`) deltas:
/// - Non-empty initial text from `ContentBlockStart` and every
///   `ContentBlockDelta` text are sent with the current role.
/// - A `MessageDelta` stop reason is sent as the `finish_reason`.
/// - Block-stop, message-stop and ping events are ignored.
async fn complete_with_anthropic(
    request: proto::CompleteWithLanguageModel,
    response: StreamingResponse<proto::CompleteWithLanguageModel>,
    session: UserSession,
    api_key: Arc<str>,
) -> Result<()> {
    let model = anthropic::Model::from_id(&request.model)?;

    // Filled as a side effect of the `filter_map` below. `collect()` drains the
    // iterator before `system_message` is moved into the request.
    let mut system_message = String::new();
    let messages = request
        .messages
        .into_iter()
        .filter_map(|message| {
            match message.role() {
                LanguageModelRole::LanguageModelUser => Some(anthropic::RequestMessage {
                    role: anthropic::Role::User,
                    content: message.content,
                }),
                LanguageModelRole::LanguageModelAssistant => Some(anthropic::RequestMessage {
                    role: anthropic::Role::Assistant,
                    content: message.content,
                }),
                // Anthropic's API breaks system instructions out as a separate field rather
                // than having a system message role.
                LanguageModelRole::LanguageModelSystem => {
                    if !system_message.is_empty() {
                        system_message.push_str("\n\n");
                    }
                    system_message.push_str(&message.content);

                    None
                }
                // We don't yet support tool calls for Anthropic
                LanguageModelRole::LanguageModelTool => None,
            }
        })
        .collect();

    let mut stream = anthropic::stream_completion(
        session.http_client.as_ref(),
        anthropic::ANTHROPIC_API_URL,
        &api_key,
        anthropic::Request {
            model,
            messages,
            stream: true,
            system: system_message,
            // NOTE(review): 4092 may have been intended as 4096; confirm.
            max_tokens: 4092,
        },
        None,
    )
    .await?;

    // Role attached to outgoing deltas. It starts as assistant and switches
    // when a `MessageStart` event announces "assistant" or "user"; any other
    // announced role leaves it unchanged.
    let mut current_role = proto::LanguageModelRole::LanguageModelAssistant;

    while let Some(event) = stream.next().await {
        let event = event?;

        match event {
            anthropic::ResponseEvent::MessageStart { message } => {
                if let Some(role) = message.role {
                    if role == "assistant" {
                        current_role = proto::LanguageModelRole::LanguageModelAssistant;
                    } else if role == "user" {
                        current_role = proto::LanguageModelRole::LanguageModelUser;
                    }
                }
            }
            anthropic::ResponseEvent::ContentBlockStart { content_block, .. } => {
                match content_block {
                    anthropic::ContentBlock::Text { text } => {
                        // Skip empty initial blocks so the client only sees real text.
                        if !text.is_empty() {
                            response.send(proto::LanguageModelResponse {
                                choices: vec![proto::LanguageModelChoiceDelta {
                                    index: 0,
                                    delta: Some(proto::LanguageModelResponseMessage {
                                        role: Some(current_role as i32),
                                        content: Some(text),
                                        tool_calls: Vec::new(),
                                    }),
                                    finish_reason: None,
                                }],
                            })?;
                        }
                    }
                }
            }
            anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => match delta {
                anthropic::TextDelta::TextDelta { text } => {
                    response.send(proto::LanguageModelResponse {
                        choices: vec![proto::LanguageModelChoiceDelta {
                            index: 0,
                            delta: Some(proto::LanguageModelResponseMessage {
                                role: Some(current_role as i32),
                                content: Some(text),
                                tool_calls: Vec::new(),
                            }),
                            finish_reason: None,
                        }],
                    })?;
                }
            },
            // Only a delta carrying a stop reason produces output: a choice
            // with no content delta and the stop reason as `finish_reason`.
            anthropic::ResponseEvent::MessageDelta { delta, .. } => {
                if let Some(stop_reason) = delta.stop_reason {
                    response.send(proto::LanguageModelResponse {
                        choices: vec![proto::LanguageModelChoiceDelta {
                            index: 0,
                            delta: None,
                            finish_reason: Some(stop_reason),
                        }],
                    })?;
                }
            }
            anthropic::ResponseEvent::ContentBlockStop { .. } => {}
            anthropic::ResponseEvent::MessageStop {} => {}
            anthropic::ResponseEvent::Ping {} => {}
        }
    }

    Ok(())
}
4725
4726struct CountTokensWithLanguageModelRateLimit;
4727
4728impl RateLimit for CountTokensWithLanguageModelRateLimit {
4729 fn capacity() -> usize {
4730 std::env::var("COUNT_TOKENS_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4731 .ok()
4732 .and_then(|v| v.parse().ok())
4733 .unwrap_or(600) // Picked arbitrarily
4734 }
4735
4736 fn refill_duration() -> chrono::Duration {
4737 chrono::Duration::hours(1)
4738 }
4739
4740 fn db_name() -> &'static str {
4741 "count-tokens-with-language-model"
4742 }
4743}
4744
4745async fn count_tokens_with_language_model(
4746 request: proto::CountTokensWithLanguageModel,
4747 response: Response<proto::CountTokensWithLanguageModel>,
4748 session: UserSession,
4749 google_ai_api_key: Option<Arc<str>>,
4750) -> Result<()> {
4751 authorize_access_to_language_models(&session).await?;
4752
4753 if !request.model.starts_with("gemini") {
4754 return Err(anyhow!(
4755 "counting tokens for model: {:?} is not supported",
4756 request.model
4757 ))?;
4758 }
4759
4760 session
4761 .rate_limiter
4762 .check::<CountTokensWithLanguageModelRateLimit>(session.user_id())
4763 .await?;
4764
4765 let api_key = google_ai_api_key
4766 .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
4767 let tokens_response = google_ai::count_tokens(
4768 session.http_client.as_ref(),
4769 google_ai::API_URL,
4770 &api_key,
4771 crate::ai::count_tokens_request_to_google_ai(request)?,
4772 )
4773 .await?;
4774 response.send(proto::CountTokensResponse {
4775 token_count: tokens_response.total_tokens as u32,
4776 })?;
4777 Ok(())
4778}
4779
4780struct ComputeEmbeddingsRateLimit;
4781
4782impl RateLimit for ComputeEmbeddingsRateLimit {
4783 fn capacity() -> usize {
4784 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4785 .ok()
4786 .and_then(|v| v.parse().ok())
4787 .unwrap_or(5000) // Picked arbitrarily
4788 }
4789
4790 fn refill_duration() -> chrono::Duration {
4791 chrono::Duration::hours(1)
4792 }
4793
4794 fn db_name() -> &'static str {
4795 "compute-embeddings"
4796 }
4797}
4798
4799async fn compute_embeddings(
4800 request: proto::ComputeEmbeddings,
4801 response: Response<proto::ComputeEmbeddings>,
4802 session: UserSession,
4803 api_key: Option<Arc<str>>,
4804) -> Result<()> {
4805 let api_key = api_key.context("no OpenAI API key configured on the server")?;
4806 authorize_access_to_language_models(&session).await?;
4807
4808 session
4809 .rate_limiter
4810 .check::<ComputeEmbeddingsRateLimit>(session.user_id())
4811 .await?;
4812
4813 let embeddings = match request.model.as_str() {
4814 "openai/text-embedding-3-small" => {
4815 open_ai::embed(
4816 session.http_client.as_ref(),
4817 OPEN_AI_API_URL,
4818 &api_key,
4819 OpenAiEmbeddingModel::TextEmbedding3Small,
4820 request.texts.iter().map(|text| text.as_str()),
4821 )
4822 .await?
4823 }
4824 provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
4825 };
4826
4827 let embeddings = request
4828 .texts
4829 .iter()
4830 .map(|text| {
4831 let mut hasher = sha2::Sha256::new();
4832 hasher.update(text.as_bytes());
4833 let result = hasher.finalize();
4834 result.to_vec()
4835 })
4836 .zip(
4837 embeddings
4838 .data
4839 .into_iter()
4840 .map(|embedding| embedding.embedding),
4841 )
4842 .collect::<HashMap<_, _>>();
4843
4844 let db = session.db().await;
4845 db.save_embeddings(&request.model, &embeddings)
4846 .await
4847 .context("failed to save embeddings")
4848 .trace_err();
4849
4850 response.send(proto::ComputeEmbeddingsResponse {
4851 embeddings: embeddings
4852 .into_iter()
4853 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4854 .collect(),
4855 })?;
4856 Ok(())
4857}
4858
4859async fn get_cached_embeddings(
4860 request: proto::GetCachedEmbeddings,
4861 response: Response<proto::GetCachedEmbeddings>,
4862 session: UserSession,
4863) -> Result<()> {
4864 authorize_access_to_language_models(&session).await?;
4865
4866 let db = session.db().await;
4867 let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
4868
4869 response.send(proto::GetCachedEmbeddingsResponse {
4870 embeddings: embeddings
4871 .into_iter()
4872 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4873 .collect(),
4874 })?;
4875 Ok(())
4876}
4877
4878async fn authorize_access_to_language_models(session: &UserSession) -> Result<(), Error> {
4879 let db = session.db().await;
4880 let flags = db.get_user_flags(session.user_id()).await?;
4881 if flags.iter().any(|flag| flag == "language-models") {
4882 Ok(())
4883 } else {
4884 Err(anyhow!("permission denied"))?
4885 }
4886}
4887
4888/// Get a Supermaven API key for the user
4889async fn get_supermaven_api_key(
4890 _request: proto::GetSupermavenApiKey,
4891 response: Response<proto::GetSupermavenApiKey>,
4892 session: UserSession,
4893) -> Result<()> {
4894 let user_id: String = session.user_id().to_string();
4895 if !session.is_staff() {
4896 return Err(anyhow!("supermaven not enabled for this account"))?;
4897 }
4898
4899 let email = session
4900 .email()
4901 .ok_or_else(|| anyhow!("user must have an email"))?;
4902
4903 let supermaven_admin_api = session
4904 .supermaven_client
4905 .as_ref()
4906 .ok_or_else(|| anyhow!("supermaven not configured"))?;
4907
4908 let result = supermaven_admin_api
4909 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4910 .await?;
4911
4912 response.send(proto::GetSupermavenApiKeyResponse {
4913 api_key: result.api_key,
4914 })?;
4915
4916 Ok(())
4917}
4918
4919/// Start receiving chat updates for a channel
4920async fn join_channel_chat(
4921 request: proto::JoinChannelChat,
4922 response: Response<proto::JoinChannelChat>,
4923 session: UserSession,
4924) -> Result<()> {
4925 let channel_id = ChannelId::from_proto(request.channel_id);
4926
4927 let db = session.db().await;
4928 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4929 .await?;
4930 let messages = db
4931 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4932 .await?;
4933 response.send(proto::JoinChannelChatResponse {
4934 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4935 messages,
4936 })?;
4937 Ok(())
4938}
4939
4940/// Stop receiving chat updates for a channel
4941async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
4942 let channel_id = ChannelId::from_proto(request.channel_id);
4943 session
4944 .db()
4945 .await
4946 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4947 .await?;
4948 Ok(())
4949}
4950
4951/// Retrieve the chat history for a channel
4952async fn get_channel_messages(
4953 request: proto::GetChannelMessages,
4954 response: Response<proto::GetChannelMessages>,
4955 session: UserSession,
4956) -> Result<()> {
4957 let channel_id = ChannelId::from_proto(request.channel_id);
4958 let messages = session
4959 .db()
4960 .await
4961 .get_channel_messages(
4962 channel_id,
4963 session.user_id(),
4964 MESSAGE_COUNT_PER_PAGE,
4965 Some(MessageId::from_proto(request.before_message_id)),
4966 )
4967 .await?;
4968 response.send(proto::GetChannelMessagesResponse {
4969 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4970 messages,
4971 })?;
4972 Ok(())
4973}
4974
4975/// Retrieve specific chat messages
4976async fn get_channel_messages_by_id(
4977 request: proto::GetChannelMessagesById,
4978 response: Response<proto::GetChannelMessagesById>,
4979 session: UserSession,
4980) -> Result<()> {
4981 let message_ids = request
4982 .message_ids
4983 .iter()
4984 .map(|id| MessageId::from_proto(*id))
4985 .collect::<Vec<_>>();
4986 let messages = session
4987 .db()
4988 .await
4989 .get_channel_messages_by_id(session.user_id(), &message_ids)
4990 .await?;
4991 response.send(proto::GetChannelMessagesResponse {
4992 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4993 messages,
4994 })?;
4995 Ok(())
4996}
4997
4998/// Retrieve the current users notifications
4999async fn get_notifications(
5000 request: proto::GetNotifications,
5001 response: Response<proto::GetNotifications>,
5002 session: UserSession,
5003) -> Result<()> {
5004 let notifications = session
5005 .db()
5006 .await
5007 .get_notifications(
5008 session.user_id(),
5009 NOTIFICATION_COUNT_PER_PAGE,
5010 request
5011 .before_id
5012 .map(|id| db::NotificationId::from_proto(id)),
5013 )
5014 .await?;
5015 response.send(proto::GetNotificationsResponse {
5016 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
5017 notifications,
5018 })?;
5019 Ok(())
5020}
5021
5022/// Mark notifications as read
5023async fn mark_notification_as_read(
5024 request: proto::MarkNotificationRead,
5025 response: Response<proto::MarkNotificationRead>,
5026 session: UserSession,
5027) -> Result<()> {
5028 let database = &session.db().await;
5029 let notifications = database
5030 .mark_notification_as_read_by_id(
5031 session.user_id(),
5032 NotificationId::from_proto(request.notification_id),
5033 )
5034 .await?;
5035 send_notifications(
5036 &*session.connection_pool().await,
5037 &session.peer,
5038 notifications,
5039 );
5040 response.send(proto::Ack {})?;
5041 Ok(())
5042}
5043
5044/// Get the current users information
5045async fn get_private_user_info(
5046 _request: proto::GetPrivateUserInfo,
5047 response: Response<proto::GetPrivateUserInfo>,
5048 session: UserSession,
5049) -> Result<()> {
5050 let db = session.db().await;
5051
5052 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
5053 let user = db
5054 .get_user_by_id(session.user_id())
5055 .await?
5056 .ok_or_else(|| anyhow!("user not found"))?;
5057 let flags = db.get_user_flags(session.user_id()).await?;
5058
5059 response.send(proto::GetPrivateUserInfoResponse {
5060 metrics_id,
5061 staff: user.admin,
5062 flags,
5063 })?;
5064 Ok(())
5065}
5066
5067fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
5068 match message {
5069 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
5070 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
5071 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
5072 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
5073 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
5074 code: frame.code.into(),
5075 reason: frame.reason,
5076 })),
5077 }
5078}
5079
5080fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
5081 match message {
5082 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
5083 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
5084 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
5085 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
5086 AxumMessage::Close(frame) => {
5087 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
5088 code: frame.code.into(),
5089 reason: frame.reason,
5090 }))
5091 }
5092 }
5093}
5094
5095fn notify_membership_updated(
5096 connection_pool: &mut ConnectionPool,
5097 result: MembershipUpdated,
5098 user_id: UserId,
5099 peer: &Peer,
5100) {
5101 for membership in &result.new_channels.channel_memberships {
5102 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
5103 }
5104 for channel_id in &result.removed_channels {
5105 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
5106 }
5107
5108 let user_channels_update = proto::UpdateUserChannels {
5109 channel_memberships: result
5110 .new_channels
5111 .channel_memberships
5112 .iter()
5113 .map(|cm| proto::ChannelMembership {
5114 channel_id: cm.channel_id.to_proto(),
5115 role: cm.role.into(),
5116 })
5117 .collect(),
5118 ..Default::default()
5119 };
5120
5121 let mut update = build_channels_update(result.new_channels);
5122 update.delete_channels = result
5123 .removed_channels
5124 .into_iter()
5125 .map(|id| id.to_proto())
5126 .collect();
5127 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
5128
5129 for connection_id in connection_pool.user_connection_ids(user_id) {
5130 peer.send(connection_id, user_channels_update.clone())
5131 .trace_err();
5132 peer.send(connection_id, update.clone()).trace_err();
5133 }
5134}
5135
5136fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
5137 proto::UpdateUserChannels {
5138 channel_memberships: channels
5139 .channel_memberships
5140 .iter()
5141 .map(|m| proto::ChannelMembership {
5142 channel_id: m.channel_id.to_proto(),
5143 role: m.role.into(),
5144 })
5145 .collect(),
5146 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
5147 observed_channel_message_id: channels.observed_channel_messages.clone(),
5148 }
5149}
5150
5151fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
5152 let mut update = proto::UpdateChannels::default();
5153
5154 for channel in channels.channels {
5155 update.channels.push(channel.to_proto());
5156 }
5157
5158 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
5159 update.latest_channel_message_ids = channels.latest_channel_messages;
5160
5161 for (channel_id, participants) in channels.channel_participants {
5162 update
5163 .channel_participants
5164 .push(proto::ChannelParticipants {
5165 channel_id: channel_id.to_proto(),
5166 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
5167 });
5168 }
5169
5170 for channel in channels.invited_channels {
5171 update.channel_invitations.push(channel.to_proto());
5172 }
5173
5174 update.hosted_projects = channels.hosted_projects;
5175 update
5176}
5177
5178fn build_initial_contacts_update(
5179 contacts: Vec<db::Contact>,
5180 pool: &ConnectionPool,
5181) -> proto::UpdateContacts {
5182 let mut update = proto::UpdateContacts::default();
5183
5184 for contact in contacts {
5185 match contact {
5186 db::Contact::Accepted { user_id, busy } => {
5187 update.contacts.push(contact_for_user(user_id, busy, &pool));
5188 }
5189 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
5190 db::Contact::Incoming { user_id } => {
5191 update
5192 .incoming_requests
5193 .push(proto::IncomingContactRequest {
5194 requester_id: user_id.to_proto(),
5195 })
5196 }
5197 }
5198 }
5199
5200 update
5201}
5202
5203fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
5204 proto::Contact {
5205 user_id: user_id.to_proto(),
5206 online: pool.is_user_online(user_id),
5207 busy,
5208 }
5209}
5210
5211fn room_updated(room: &proto::Room, peer: &Peer) {
5212 broadcast(
5213 None,
5214 room.participants
5215 .iter()
5216 .filter_map(|participant| Some(participant.peer_id?.into())),
5217 |peer_id| {
5218 peer.send(
5219 peer_id,
5220 proto::RoomUpdated {
5221 room: Some(room.clone()),
5222 },
5223 )
5224 },
5225 );
5226}
5227
5228fn channel_updated(
5229 channel: &db::channel::Model,
5230 room: &proto::Room,
5231 peer: &Peer,
5232 pool: &ConnectionPool,
5233) {
5234 let participants = room
5235 .participants
5236 .iter()
5237 .map(|p| p.user_id)
5238 .collect::<Vec<_>>();
5239
5240 broadcast(
5241 None,
5242 pool.channel_connection_ids(channel.root_id())
5243 .filter_map(|(channel_id, role)| {
5244 role.can_see_channel(channel.visibility).then(|| channel_id)
5245 }),
5246 |peer_id| {
5247 peer.send(
5248 peer_id,
5249 proto::UpdateChannels {
5250 channel_participants: vec![proto::ChannelParticipants {
5251 channel_id: channel.id.to_proto(),
5252 participant_user_ids: participants.clone(),
5253 }],
5254 ..Default::default()
5255 },
5256 )
5257 },
5258 );
5259}
5260
5261async fn send_dev_server_projects_update(
5262 user_id: UserId,
5263 mut status: proto::DevServerProjectsUpdate,
5264 session: &Session,
5265) {
5266 let pool = session.connection_pool().await;
5267 for dev_server in &mut status.dev_servers {
5268 dev_server.status =
5269 pool.dev_server_status(DevServerId(dev_server.dev_server_id as i32)) as i32;
5270 }
5271 let connections = pool.user_connection_ids(user_id);
5272 for connection_id in connections {
5273 session.peer.send(connection_id, status.clone()).trace_err();
5274 }
5275}
5276
5277async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
5278 let db = session.db().await;
5279
5280 let contacts = db.get_contacts(user_id).await?;
5281 let busy = db.is_user_busy(user_id).await?;
5282
5283 let pool = session.connection_pool().await;
5284 let updated_contact = contact_for_user(user_id, busy, &pool);
5285 for contact in contacts {
5286 if let db::Contact::Accepted {
5287 user_id: contact_user_id,
5288 ..
5289 } = contact
5290 {
5291 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
5292 session
5293 .peer
5294 .send(
5295 contact_conn_id,
5296 proto::UpdateContacts {
5297 contacts: vec![updated_contact.clone()],
5298 remove_contacts: Default::default(),
5299 incoming_requests: Default::default(),
5300 remove_incoming_requests: Default::default(),
5301 outgoing_requests: Default::default(),
5302 remove_outgoing_requests: Default::default(),
5303 },
5304 )
5305 .trace_err();
5306 }
5307 }
5308 }
5309 Ok(())
5310}
5311
5312async fn lost_dev_server_connection(session: &DevServerSession) -> Result<()> {
5313 log::info!("lost dev server connection, unsharing projects");
5314 let project_ids = session
5315 .db()
5316 .await
5317 .get_stale_dev_server_projects(session.connection_id)
5318 .await?;
5319
5320 for project_id in project_ids {
5321 // not unshare re-checks the connection ids match, so we get away with no transaction
5322 unshare_project_internal(project_id, session.connection_id, None, &session).await?;
5323 }
5324
5325 let user_id = session.dev_server().user_id;
5326 let update = session
5327 .db()
5328 .await
5329 .dev_server_projects_update(user_id)
5330 .await?;
5331
5332 send_dev_server_projects_update(user_id, update, session).await;
5333
5334 Ok(())
5335}
5336
/// Remove `connection_id` from the room it is in and propagate the fallout.
///
/// If the database reports no room was left, this returns `Ok(())` without
/// doing anything. Otherwise it:
/// 1. notifies the connections of every project the user left;
/// 2. broadcasts the updated room to its participants;
/// 3. if the room belongs to a channel, refreshes that channel's participant
///    list;
/// 4. sends `CallCanceled` to each user whose pending call was canceled;
/// 5. refreshes the contact entries of this user and of those callees;
/// 6. if a LiveKit client is configured, removes the user from the LiveKit
///    room and deletes that room when the database marked it deleted.
///
/// LiveKit and message-send failures are logged, not returned.
async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Assigned inside the `if let` below so the values outlive `left_room`.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Taken before `room` is moved out, so the `room` broadcast below has an
        // empty `live_kit_room`. Presumably intentional; confirm.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the connection-pool guard is released before the awaits below.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
5410
5411async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
5412 let left_channel_buffers = session
5413 .db()
5414 .await
5415 .leave_channel_buffers(session.connection_id)
5416 .await?;
5417
5418 for left_buffer in left_channel_buffers {
5419 channel_buffer_updated(
5420 session.connection_id,
5421 left_buffer.connections,
5422 &proto::UpdateChannelBufferCollaborators {
5423 channel_id: left_buffer.channel_id.to_proto(),
5424 collaborators: left_buffer.collaborators,
5425 },
5426 &session.peer,
5427 );
5428 }
5429
5430 Ok(())
5431}
5432
5433fn project_left(project: &db::LeftProject, session: &UserSession) {
5434 for connection_id in &project.connection_ids {
5435 if project.should_unshare {
5436 session
5437 .peer
5438 .send(
5439 *connection_id,
5440 proto::UnshareProject {
5441 project_id: project.id.to_proto(),
5442 },
5443 )
5444 .trace_err();
5445 } else {
5446 session
5447 .peer
5448 .send(
5449 *connection_id,
5450 proto::RemoveProjectCollaborator {
5451 project_id: project.id.to_proto(),
5452 peer_id: Some(session.connection_id.into()),
5453 },
5454 )
5455 .trace_err();
5456 }
5457 }
5458}
5459
/// Extension for fallible values that should be logged rather than
/// propagated.
///
/// `trace_err` turns `self` into an `Option`: the success value when there is
/// one, and `None` otherwise. Implementations are expected to report the
/// discarded failure.
pub trait ResultExt {
    /// The value produced on success.
    type Ok;

    /// Return the success value, or `None` after reporting the failure.
    fn trace_err(self) -> Option<Self::Ok>;
}
5465
5466impl<T, E> ResultExt for Result<T, E>
5467where
5468 E: std::fmt::Debug,
5469{
5470 type Ok = T;
5471
5472 #[track_caller]
5473 fn trace_err(self) -> Option<T> {
5474 match self {
5475 Ok(value) => Some(value),
5476 Err(error) => {
5477 tracing::error!("{:?}", error);
5478 None
5479 }
5480 }
5481 }
5482}