1mod connection_pool;
2
3use crate::{
4 auth,
5 db::{
6 self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
7 CreatedChannelMessage, Database, DevServerId, DevServerProjectId, InviteMemberResult,
8 MembershipUpdated, MessageId, NotificationId, PrincipalId, Project, ProjectId,
9 RejoinedProject, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
10 ServerId, UpdatedChannelMessage, User, UserId,
11 },
12 executor::Executor,
13 AppState, Error, RateLimit, RateLimiter, Result,
14};
15use anyhow::{anyhow, Context as _};
16use async_tungstenite::tungstenite::{
17 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
18};
19use axum::{
20 body::Body,
21 extract::{
22 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
23 ConnectInfo, WebSocketUpgrade,
24 },
25 headers::{Header, HeaderName},
26 http::StatusCode,
27 middleware,
28 response::IntoResponse,
29 routing::get,
30 Extension, Router, TypedHeader,
31};
32use collections::{HashMap, HashSet};
33pub use connection_pool::{ConnectionPool, ZedVersion};
34use core::fmt::{self, Debug, Formatter};
35use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
36use sha2::Digest;
37use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
38
39use futures::{
40 channel::oneshot,
41 future::{self, BoxFuture},
42 stream::FuturesUnordered,
43 FutureExt, SinkExt, StreamExt, TryStreamExt,
44};
45use http::IsahcHttpClient;
46use prometheus::{register_int_gauge, IntGauge};
47use rpc::{
48 proto::{
49 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LanguageModelRole,
50 LiveKitConnectionInfo, RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
51 },
52 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
53};
54use semantic_version::SemanticVersion;
55use serde::{Serialize, Serializer};
56use std::{
57 any::TypeId,
58 future::Future,
59 marker::PhantomData,
60 mem,
61 net::SocketAddr,
62 ops::{Deref, DerefMut},
63 rc::Rc,
64 sync::{
65 atomic::{AtomicBool, Ordering::SeqCst},
66 Arc, OnceLock,
67 },
68 time::{Duration, Instant},
69};
70use time::OffsetDateTime;
71use tokio::sync::{watch, Semaphore};
72use tower::ServiceBuilder;
73use tracing::{
74 field::{self},
75 info_span, instrument, Instrument,
76};
77
78use self::connection_pool::VersionedMessage;
79
/// Presumably how long a disconnected client has to reconnect before its state is
/// discarded (inferred from the name; not referenced in this part of the file).
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// How long [`Server::start`] waits before cleaning up rooms, channel buffers and
/// server records left behind by stale servers.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully. After they're
/// gone, we can clean up old resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// NOTE(review): the limits below are not referenced in this part of the file; the
// meanings given are inferred from their names — confirm at the use sites.
/// Presumably the page size used when fetching channel messages.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Presumably the maximum accepted length of a channel message (units unconfirmed).
const MAX_MESSAGE_LEN: usize = 1024;
/// Presumably the page size used when fetching notifications.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
88
/// Type-erased message handler stored in [`Server`]'s handler table, keyed by the
/// `TypeId` of the message type it handles.
///
/// It receives the incoming envelope (as a trait object) and the sender's
/// [`Session`], and returns a boxed future that performs the handling.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
91
/// Handle a request handler uses to send its single reply to request `R`.
///
/// `responded` is shared with the code in `Server::add_request_handler`, which
/// checks it after the handler finishes to detect handlers that never replied.
struct Response<R> {
    peer: Arc<Peer>,
    receipt: Receipt<R>,
    /// Set to `true` by [`Response::send`].
    responded: Arc<AtomicBool>,
}
97
98impl<R: RequestMessage> Response<R> {
99 fn send(self, payload: R::Response) -> Result<()> {
100 self.responded.store(true, SeqCst);
101 self.peer.respond(self.receipt, payload)?;
102 Ok(())
103 }
104}
105
/// Handle a streaming request handler uses to send any number of replies to
/// request `R`.
///
/// The stream itself is ended by the code in `Server::add_streaming_request_handler`
/// once the handler returns successfully.
struct StreamingResponse<R: RequestMessage> {
    peer: Arc<Peer>,
    receipt: Receipt<R>,
}
110
111impl<R: RequestMessage> StreamingResponse<R> {
112 fn send(&self, payload: R::Response) -> Result<()> {
113 self.peer.respond(self.receipt, payload)?;
114 Ok(())
115 }
116}
117
/// The authenticated identity on the other end of a connection.
#[derive(Clone, Debug)]
pub enum Principal {
    /// A user acting as themselves.
    User(User),
    /// `admin` acting as `user`; the admin is recorded as the impersonator.
    Impersonated { user: User, admin: User },
    /// A dev server rather than a user.
    DevServer(dev_server::Model),
}
124
125impl Principal {
126 fn update_span(&self, span: &tracing::Span) {
127 match &self {
128 Principal::User(user) => {
129 span.record("user_id", &user.id.0);
130 span.record("login", &user.github_login);
131 }
132 Principal::Impersonated { user, admin } => {
133 span.record("user_id", &user.id.0);
134 span.record("login", &user.github_login);
135 span.record("impersonator", &admin.github_login);
136 }
137 Principal::DevServer(dev_server) => {
138 span.record("dev_server_id", &dev_server.id.0);
139 }
140 }
141 }
142}
143
/// Per-connection state handed to every message handler.
#[derive(Clone)]
struct Session {
    /// Who is on the other end of the connection.
    principal: Principal,
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; lock it with [`Session::db`].
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Connection pool shared with the [`Server`]; lock it with
    /// [`Session::connection_pool`].
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// LiveKit API client, when one is configured.
    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
    /// `None` when no Supermaven admin API key is configured.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    http_client: Arc<IsahcHttpClient>,
    rate_limiter: Arc<RateLimiter>,
    _executor: Executor,
}
157
158impl Session {
159 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
160 #[cfg(test)]
161 tokio::task::yield_now().await;
162 let guard = self.db.lock().await;
163 #[cfg(test)]
164 tokio::task::yield_now().await;
165 guard
166 }
167
168 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
169 #[cfg(test)]
170 tokio::task::yield_now().await;
171 let guard = self.connection_pool.lock();
172 ConnectionPoolGuard {
173 guard,
174 _not_send: PhantomData,
175 }
176 }
177
178 fn for_user(self) -> Option<UserSession> {
179 UserSession::new(self)
180 }
181
182 fn for_dev_server(self) -> Option<DevServerSession> {
183 DevServerSession::new(self)
184 }
185
186 fn user_id(&self) -> Option<UserId> {
187 match &self.principal {
188 Principal::User(user) => Some(user.id),
189 Principal::Impersonated { user, .. } => Some(user.id),
190 Principal::DevServer(_) => None,
191 }
192 }
193
194 fn is_staff(&self) -> bool {
195 match &self.principal {
196 Principal::User(user) => user.admin,
197 Principal::Impersonated { .. } => true,
198 Principal::DevServer(_) => false,
199 }
200 }
201
202 fn dev_server_id(&self) -> Option<DevServerId> {
203 match &self.principal {
204 Principal::User(_) | Principal::Impersonated { .. } => None,
205 Principal::DevServer(dev_server) => Some(dev_server.id),
206 }
207 }
208
209 fn principal_id(&self) -> PrincipalId {
210 match &self.principal {
211 Principal::User(user) => PrincipalId::UserId(user.id),
212 Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
213 Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
214 }
215 }
216}
217
218impl Debug for Session {
219 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
220 let mut result = f.debug_struct("Session");
221 match &self.principal {
222 Principal::User(user) => {
223 result.field("user", &user.github_login);
224 }
225 Principal::Impersonated { user, admin } => {
226 result.field("user", &user.github_login);
227 result.field("impersonator", &admin.github_login);
228 }
229 Principal::DevServer(dev_server) => {
230 result.field("dev_server", &dev_server.id);
231 }
232 }
233 result.field("connection_id", &self.connection_id).finish()
234 }
235}
236
237struct UserSession(Session);
238
239impl UserSession {
240 pub fn new(s: Session) -> Option<Self> {
241 s.user_id().map(|_| UserSession(s))
242 }
243 pub fn user_id(&self) -> UserId {
244 self.0.user_id().unwrap()
245 }
246
247 pub fn email(&self) -> Option<String> {
248 match &self.0.principal {
249 Principal::User(user) => user.email_address.clone(),
250 Principal::Impersonated { user, .. } => user.email_address.clone(),
251 Principal::DevServer(..) => None,
252 }
253 }
254}
255
256impl Deref for UserSession {
257 type Target = Session;
258
259 fn deref(&self) -> &Self::Target {
260 &self.0
261 }
262}
263impl DerefMut for UserSession {
264 fn deref_mut(&mut self) -> &mut Self::Target {
265 &mut self.0
266 }
267}
268
269struct DevServerSession(Session);
270
271impl DevServerSession {
272 pub fn new(s: Session) -> Option<Self> {
273 s.dev_server_id().map(|_| DevServerSession(s))
274 }
275 pub fn dev_server_id(&self) -> DevServerId {
276 self.0.dev_server_id().unwrap()
277 }
278
279 fn dev_server(&self) -> &dev_server::Model {
280 match &self.0.principal {
281 Principal::DevServer(dev_server) => dev_server,
282 _ => unreachable!(),
283 }
284 }
285}
286
287impl Deref for DevServerSession {
288 type Target = Session;
289
290 fn deref(&self) -> &Self::Target {
291 &self.0
292 }
293}
294impl DerefMut for DevServerSession {
295 fn deref_mut(&mut self) -> &mut Self::Target {
296 &mut self.0
297 }
298}
299
300fn user_handler<M: RequestMessage, Fut>(
301 handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
302) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
303where
304 Fut: Send + Future<Output = Result<()>>,
305{
306 let handler = Arc::new(handler);
307 move |message, response, session| {
308 let handler = handler.clone();
309 Box::pin(async move {
310 if let Some(user_session) = session.for_user() {
311 Ok(handler(message, response, user_session).await?)
312 } else {
313 Err(Error::Internal(anyhow!(
314 "must be a user to call {}",
315 M::NAME
316 )))
317 }
318 })
319 }
320}
321
322fn dev_server_handler<M: RequestMessage, Fut>(
323 handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
324) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
325where
326 Fut: Send + Future<Output = Result<()>>,
327{
328 let handler = Arc::new(handler);
329 move |message, response, session| {
330 let handler = handler.clone();
331 Box::pin(async move {
332 if let Some(dev_server_session) = session.for_dev_server() {
333 Ok(handler(message, response, dev_server_session).await?)
334 } else {
335 Err(Error::Internal(anyhow!(
336 "must be a dev server to call {}",
337 M::NAME
338 )))
339 }
340 })
341 }
342}
343
344fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
345 handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
346) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
347where
348 InnertRetFut: Send + Future<Output = Result<()>>,
349{
350 let handler = Arc::new(handler);
351 move |message, session| {
352 let handler = handler.clone();
353 Box::pin(async move {
354 if let Some(user_session) = session.for_user() {
355 Ok(handler(message, user_session).await?)
356 } else {
357 Err(Error::Internal(anyhow!(
358 "must be a user to call {}",
359 M::NAME
360 )))
361 }
362 })
363 }
364}
365
366struct DbHandle(Arc<Database>);
367
368impl Deref for DbHandle {
369 type Target = Database;
370
371 fn deref(&self) -> &Self::Target {
372 self.0.as_ref()
373 }
374}
375
/// The collaboration RPC server. It owns the peer, the connection pool shared with
/// every [`Session`], and the table of registered message handlers.
pub struct Server {
    /// This server's id. In the visible code it is only rewritten by the
    /// test-only `reset`.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Registered handlers, keyed by the `TypeId` of the message type they handle.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Set to `true` on teardown. Each connection subscribes to it and stops when it
    /// changes; new connections bail out if it is already `true`.
    teardown: watch::Sender<bool>,
}
384
/// Lock guard over the shared [`ConnectionPool`].
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send`. Presumably this stops
/// the synchronous pool lock from being held across an `.await` inside a future
/// that must be `Send` — confirm.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
389
/// Serializable view of the server's peer and connection pool.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    /// Serialized through the guard's `Deref` target via [`serialize_deref`].
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
396
397pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
398where
399 S: Serializer,
400 T: Deref<Target = U>,
401 U: Serialize,
402{
403 Serialize::serialize(value.deref(), serializer)
404}
405
406impl Server {
    /// Creates a server with the given `id` and registers every RPC handler it serves.
    ///
    /// Each message type may be registered only once; `add_handler` panics on a
    /// duplicate. Handlers wrapped in `user_handler` / `user_message_handler` return
    /// an error for non-user principals, and handlers wrapped in `dev_server_handler`
    /// return an error for non-dev-server principals. Unwrapped handlers receive the
    /// plain [`Session`] whatever the principal.
    ///
    /// This does not start the stale-resource cleanup task; see [`Server::start`].
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            // Only the sender is kept; each connection subscribes its own receiver.
            teardown: watch::channel(false).0,
        };

        server
            // Rooms and calls.
            .add_request_handler(ping)
            .add_request_handler(user_handler(create_room))
            .add_request_handler(user_handler(join_room))
            .add_request_handler(user_handler(rejoin_room))
            .add_request_handler(user_handler(leave_room))
            .add_request_handler(user_handler(set_room_participant_role))
            .add_request_handler(user_handler(call))
            .add_request_handler(user_handler(cancel_call))
            .add_message_handler(user_message_handler(decline_call))
            .add_request_handler(user_handler(update_participant_location))
            // Projects and dev servers.
            .add_request_handler(user_handler(share_project))
            .add_message_handler(unshare_project)
            .add_request_handler(user_handler(join_project))
            .add_request_handler(user_handler(join_hosted_project))
            .add_request_handler(user_handler(rejoin_dev_server_projects))
            .add_request_handler(user_handler(create_dev_server_project))
            .add_request_handler(user_handler(delete_dev_server_project))
            .add_request_handler(user_handler(create_dev_server))
            .add_request_handler(user_handler(regenerate_dev_server_token))
            .add_request_handler(user_handler(rename_dev_server))
            .add_request_handler(user_handler(delete_dev_server))
            .add_request_handler(dev_server_handler(share_dev_server_project))
            .add_request_handler(dev_server_handler(shutdown_dev_server))
            .add_request_handler(dev_server_handler(reconnect_dev_server))
            .add_message_handler(user_message_handler(leave_project))
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Project requests forwarded by the `forward_*` helpers (defined elsewhere).
            .add_request_handler(user_handler(
                forward_project_request_for_owner::<proto::TaskContextForLocation>,
            ))
            .add_request_handler(user_handler(
                forward_project_request_for_owner::<proto::TaskTemplates>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetHover>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetTypeDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetReferences>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::SearchProject>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDocumentHighlights>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetProjectSymbols>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferForSymbol>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferById>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::SynchronizeBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::InlayHints>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferByPath>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCompletions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            ))
            .add_request_handler(user_handler(
                forward_versioned_mutating_project_request::<proto::OpenNewBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCodeActions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCodeAction>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PrepareRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PerformRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ReloadBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::FormatBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CreateProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RenameProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CopyProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::DeleteProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ExpandProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OnTypeFormatting>,
            ))
            .add_request_handler(user_handler(
                forward_versioned_mutating_project_request::<proto::SaveBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::BlameBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::MultiLspQuery>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RestartLanguageServers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::LinkedEditingRange>,
            ))
            // Buffers and host broadcasts.
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            // Users and contacts.
            .add_request_handler(get_users)
            .add_request_handler(user_handler(fuzzy_search_users))
            .add_request_handler(user_handler(request_contact))
            .add_request_handler(user_handler(remove_contact))
            .add_request_handler(user_handler(respond_to_contact_request))
            // Channels, channel buffers, chat and notifications.
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(user_handler(create_channel))
            .add_request_handler(user_handler(delete_channel))
            .add_request_handler(user_handler(invite_channel_member))
            .add_request_handler(user_handler(remove_channel_member))
            .add_request_handler(user_handler(set_channel_member_role))
            .add_request_handler(user_handler(set_channel_visibility))
            .add_request_handler(user_handler(rename_channel))
            .add_request_handler(user_handler(join_channel_buffer))
            .add_request_handler(user_handler(leave_channel_buffer))
            .add_message_handler(user_message_handler(update_channel_buffer))
            .add_request_handler(user_handler(rejoin_channel_buffers))
            .add_request_handler(user_handler(get_channel_members))
            .add_request_handler(user_handler(respond_to_channel_invite))
            .add_request_handler(user_handler(join_channel))
            .add_request_handler(user_handler(join_channel_chat))
            .add_message_handler(user_message_handler(leave_channel_chat))
            .add_request_handler(user_handler(send_channel_message))
            .add_request_handler(user_handler(remove_channel_message))
            .add_request_handler(user_handler(update_channel_message))
            .add_request_handler(user_handler(get_channel_messages))
            .add_request_handler(user_handler(get_channel_messages_by_id))
            .add_request_handler(user_handler(get_notifications))
            .add_request_handler(user_handler(mark_notification_as_read))
            .add_request_handler(user_handler(move_channel))
            // Following.
            .add_request_handler(user_handler(follow))
            .add_message_handler(user_message_handler(unfollow))
            .add_message_handler(user_message_handler(update_followers))
            .add_request_handler(user_handler(get_private_user_info))
            .add_message_handler(user_message_handler(acknowledge_channel_message))
            .add_message_handler(user_message_handler(acknowledge_buffer_version))
            .add_request_handler(user_handler(get_supermaven_api_key))
            // Assistant contexts.
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OpenContext>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::SynchronizeContexts>,
            ))
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context)
            // Language-model completions are streamed. This handler is not wrapped in
            // `user_handler`, so it receives the plain `Session`.
            .add_streaming_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    complete_with_language_model(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                        app_state.config.google_ai_api_key.clone(),
                        app_state.config.anthropic_api_key.clone(),
                    )
                }
            })
            .add_request_handler({
                let app_state = app_state.clone();
                user_handler(move |request, response, session| {
                    count_tokens_with_language_model(
                        request,
                        response,
                        session,
                        app_state.config.google_ai_api_key.clone(),
                    )
                })
            })
            .add_request_handler({
                user_handler(move |request, response, session| {
                    get_cached_embeddings(request, response, session)
                })
            })
            .add_request_handler({
                let app_state = app_state.clone();
                user_handler(move |request, response, session| {
                    compute_embeddings(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                    )
                })
            })
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetSignatureHelp>,
            ));

        Arc::new(server)
    }
652
    /// Spawns a detached background task that cleans up after stale servers, then
    /// returns `Ok(())` without waiting for that task.
    ///
    /// The task first waits for [`CLEANUP_TIMEOUT`]. It then asks the database for
    /// the room and channel ids still tied to stale servers in this environment, and
    /// processes them in order:
    /// - For each channel buffer, it clears stale collaborators and sends the
    ///   refreshed collaborator list to the returned connections.
    /// - For each room, it clears stale participants and broadcasts the updated room
    ///   (and channel, if any).
    /// - It tells users whose calls were canceled that the call is gone.
    /// - It pushes updated contact entries to the affected users' accepted contacts.
    /// - When a LiveKit client is configured and the room is now empty, it deletes
    ///   the LiveKit room.
    ///
    /// Finally, it deletes the stale server records. Individual failures pass
    /// through `trace_err` (presumably logged) and are otherwise skipped.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                // Give the previous servers time to go away before touching their
                // resources (see `CLEANUP_TIMEOUT`).
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Drop stale collaborators from each channel buffer and send the
                    // refreshed collaborator list to the connections returned with it.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Filled in only if the room is refreshed below; otherwise the
                        // later steps for this room do nothing.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Tell every connection of each user whose call was canceled.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push each affected user's updated contact entry to every
                        // connection of their accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                // The synchronous pool lock is taken only after both awaits.
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Delete the LiveKit room only once nobody is left in it.
                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
798
799 pub fn teardown(&self) {
800 self.peer.teardown();
801 self.connection_pool.lock().reset();
802 let _ = self.teardown.send(true);
803 }
804
805 #[cfg(test)]
806 pub fn reset(&self, id: ServerId) {
807 self.teardown();
808 *self.id.lock() = id;
809 self.peer.reset(id.0 as u32);
810 let _ = self.teardown.send(false);
811 }
812
813 #[cfg(test)]
814 pub fn id(&self) -> ServerId {
815 *self.id.lock()
816 }
817
818 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
819 where
820 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
821 Fut: 'static + Send + Future<Output = Result<()>>,
822 M: EnvelopedMessage,
823 {
824 let prev_handler = self.handlers.insert(
825 TypeId::of::<M>(),
826 Box::new(move |envelope, session| {
827 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
828 let received_at = envelope.received_at;
829 tracing::info!("message received");
830 let start_time = Instant::now();
831 let future = (handler)(*envelope, session);
832 async move {
833 let result = future.await;
834 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
835 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
836 let queue_duration_ms = total_duration_ms - processing_duration_ms;
837 let payload_type = M::NAME;
838
839 match result {
840 Err(error) => {
841 tracing::error!(
842 ?error,
843 total_duration_ms,
844 processing_duration_ms,
845 queue_duration_ms,
846 payload_type,
847 "error handling message"
848 )
849 }
850 Ok(()) => tracing::info!(
851 total_duration_ms,
852 processing_duration_ms,
853 queue_duration_ms,
854 "finished handling message"
855 ),
856 }
857 }
858 .boxed()
859 }),
860 );
861 if prev_handler.is_some() {
862 panic!("registered a handler for the same message twice");
863 }
864 self
865 }
866
867 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
868 where
869 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
870 Fut: 'static + Send + Future<Output = Result<()>>,
871 M: EnvelopedMessage,
872 {
873 self.add_handler(move |envelope, session| handler(envelope.payload, session));
874 self
875 }
876
877 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
878 where
879 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
880 Fut: Send + Future<Output = Result<()>>,
881 M: RequestMessage,
882 {
883 let handler = Arc::new(handler);
884 self.add_handler(move |envelope, session| {
885 let receipt = envelope.receipt();
886 let handler = handler.clone();
887 async move {
888 let peer = session.peer.clone();
889 let responded = Arc::new(AtomicBool::default());
890 let response = Response {
891 peer: peer.clone(),
892 responded: responded.clone(),
893 receipt,
894 };
895 match (handler)(envelope.payload, response, session).await {
896 Ok(()) => {
897 if responded.load(std::sync::atomic::Ordering::SeqCst) {
898 Ok(())
899 } else {
900 Err(anyhow!("handler did not send a response"))?
901 }
902 }
903 Err(error) => {
904 let proto_err = match &error {
905 Error::Internal(err) => err.to_proto(),
906 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
907 };
908 peer.respond_with_error(receipt, proto_err)?;
909 Err(error)
910 }
911 }
912 }
913 })
914 }
915
916 fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
917 where
918 F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
919 Fut: Send + Future<Output = Result<()>>,
920 M: RequestMessage,
921 {
922 let handler = Arc::new(handler);
923 self.add_handler(move |envelope, session| {
924 let receipt = envelope.receipt();
925 let handler = handler.clone();
926 async move {
927 let peer = session.peer.clone();
928 let response = StreamingResponse {
929 peer: peer.clone(),
930 receipt,
931 };
932 match (handler)(envelope.payload, response, session).await {
933 Ok(()) => {
934 peer.end_stream(receipt)?;
935 Ok(())
936 }
937 Err(error) => {
938 let proto_err = match &error {
939 Error::Internal(err) => err.to_proto(),
940 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
941 };
942 peer.respond_with_error(receipt, proto_err)?;
943 Err(error)
944 }
945 }
946 }
947 })
948 }
949
950 #[allow(clippy::too_many_arguments)]
951 pub fn handle_connection(
952 self: &Arc<Self>,
953 connection: Connection,
954 address: String,
955 principal: Principal,
956 zed_version: ZedVersion,
957 send_connection_id: Option<oneshot::Sender<ConnectionId>>,
958 executor: Executor,
959 ) -> impl Future<Output = ()> {
960 let this = self.clone();
961 let span = info_span!("handle connection", %address,
962 connection_id=field::Empty,
963 user_id=field::Empty,
964 login=field::Empty,
965 impersonator=field::Empty,
966 dev_server_id=field::Empty
967 );
968 principal.update_span(&span);
969
970 let mut teardown = self.teardown.subscribe();
971 async move {
972 if *teardown.borrow() {
973 tracing::error!("server is tearing down");
974 return
975 }
976 let (connection_id, handle_io, mut incoming_rx) = this
977 .peer
978 .add_connection(connection, {
979 let executor = executor.clone();
980 move |duration| executor.sleep(duration)
981 });
982 tracing::Span::current().record("connection_id", format!("{}", connection_id));
983 tracing::info!("connection opened");
984
985 let http_client = match IsahcHttpClient::new() {
986 Ok(http_client) => Arc::new(http_client),
987 Err(error) => {
988 tracing::error!(?error, "failed to create HTTP client");
989 return;
990 }
991 };
992
993 let supermaven_client = if let Some(supermaven_admin_api_key) = this.app_state.config.supermaven_admin_api_key.clone() {
994 Some(Arc::new(SupermavenAdminApi::new(
995 supermaven_admin_api_key.to_string(),
996 http_client.clone(),
997 )))
998 } else {
999 None
1000 };
1001
1002 let session = Session {
1003 principal: principal.clone(),
1004 connection_id,
1005 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
1006 peer: this.peer.clone(),
1007 connection_pool: this.connection_pool.clone(),
1008 live_kit_client: this.app_state.live_kit_client.clone(),
1009 http_client,
1010 rate_limiter: this.app_state.rate_limiter.clone(),
1011 _executor: executor.clone(),
1012 supermaven_client,
1013 };
1014
1015 if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
1016 tracing::error!(?error, "failed to send initial client update");
1017 return;
1018 }
1019
1020 let handle_io = handle_io.fuse();
1021 futures::pin_mut!(handle_io);
1022
1023 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
1024 // This prevents deadlocks when e.g., client A performs a request to client B and
1025 // client B performs a request to client A. If both clients stop processing further
1026 // messages until their respective request completes, they won't have a chance to
1027 // respond to the other client's request and cause a deadlock.
1028 //
1029 // This arrangement ensures we will attempt to process earlier messages first, but fall
1030 // back to processing messages arrived later in the spirit of making progress.
1031 let mut foreground_message_handlers = FuturesUnordered::new();
1032 let concurrent_handlers = Arc::new(Semaphore::new(256));
1033 loop {
1034 let next_message = async {
1035 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
1036 let message = incoming_rx.next().await;
1037 (permit, message)
1038 }.fuse();
1039 futures::pin_mut!(next_message);
1040 futures::select_biased! {
1041 _ = teardown.changed().fuse() => return,
1042 result = handle_io => {
1043 if let Err(error) = result {
1044 tracing::error!(?error, "error handling I/O");
1045 }
1046 break;
1047 }
1048 _ = foreground_message_handlers.next() => {}
1049 next_message = next_message => {
1050 let (permit, message) = next_message;
1051 if let Some(message) = message {
1052 let type_name = message.payload_type_name();
1053 // note: we copy all the fields from the parent span so we can query them in the logs.
1054 // (https://github.com/tokio-rs/tracing/issues/2670).
1055 let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
1056 user_id=field::Empty,
1057 login=field::Empty,
1058 impersonator=field::Empty,
1059 dev_server_id=field::Empty
1060 );
1061 principal.update_span(&span);
1062 let span_enter = span.enter();
1063 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
1064 let is_background = message.is_background();
1065 let handle_message = (handler)(message, session.clone());
1066 drop(span_enter);
1067
1068 let handle_message = async move {
1069 handle_message.await;
1070 drop(permit);
1071 }.instrument(span);
1072 if is_background {
1073 executor.spawn_detached(handle_message);
1074 } else {
1075 foreground_message_handlers.push(handle_message);
1076 }
1077 } else {
1078 tracing::error!("no message handler");
1079 }
1080 } else {
1081 tracing::info!("connection closed");
1082 break;
1083 }
1084 }
1085 }
1086 }
1087
1088 drop(foreground_message_handlers);
1089 tracing::info!("signing out");
1090 if let Err(error) = connection_lost(session, teardown, executor).await {
1091 tracing::error!(?error, "error signing out");
1092 }
1093
1094 }.instrument(span)
1095 }
1096
1097 async fn send_initial_client_update(
1098 &self,
1099 connection_id: ConnectionId,
1100 principal: &Principal,
1101 zed_version: ZedVersion,
1102 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
1103 session: &Session,
1104 ) -> Result<()> {
1105 self.peer.send(
1106 connection_id,
1107 proto::Hello {
1108 peer_id: Some(connection_id.into()),
1109 },
1110 )?;
1111 tracing::info!("sent hello message");
1112 if let Some(send_connection_id) = send_connection_id.take() {
1113 let _ = send_connection_id.send(connection_id);
1114 }
1115
1116 match principal {
1117 Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
1118 if !user.connected_once {
1119 self.peer.send(connection_id, proto::ShowContacts {})?;
1120 self.app_state
1121 .db
1122 .set_user_connected_once(user.id, true)
1123 .await?;
1124 }
1125
1126 let (contacts, dev_server_projects) = future::try_join(
1127 self.app_state.db.get_contacts(user.id),
1128 self.app_state.db.dev_server_projects_update(user.id),
1129 )
1130 .await?;
1131
1132 {
1133 let mut pool = self.connection_pool.lock();
1134 pool.add_connection(connection_id, user.id, user.admin, zed_version);
1135 self.peer.send(
1136 connection_id,
1137 build_initial_contacts_update(contacts, &pool),
1138 )?;
1139 }
1140
1141 if should_auto_subscribe_to_channels(zed_version) {
1142 subscribe_user_to_channels(user.id, session).await?;
1143 }
1144
1145 send_dev_server_projects_update(user.id, dev_server_projects, session).await;
1146
1147 if let Some(incoming_call) =
1148 self.app_state.db.incoming_call_for_user(user.id).await?
1149 {
1150 self.peer.send(connection_id, incoming_call)?;
1151 }
1152
1153 update_user_contacts(user.id, &session).await?;
1154 }
1155 Principal::DevServer(dev_server) => {
1156 {
1157 let mut pool = self.connection_pool.lock();
1158 if let Some(stale_connection_id) = pool.dev_server_connection_id(dev_server.id)
1159 {
1160 self.peer.send(
1161 stale_connection_id,
1162 proto::ShutdownDevServer {
1163 reason: Some(
1164 "another dev server connected with the same token".to_string(),
1165 ),
1166 },
1167 )?;
1168 pool.remove_connection(stale_connection_id)?;
1169 };
1170 pool.add_dev_server(connection_id, dev_server.id, zed_version);
1171 }
1172
1173 let projects = self
1174 .app_state
1175 .db
1176 .get_projects_for_dev_server(dev_server.id)
1177 .await?;
1178 self.peer
1179 .send(connection_id, proto::DevServerInstructions { projects })?;
1180
1181 let status = self
1182 .app_state
1183 .db
1184 .dev_server_projects_update(dev_server.user_id)
1185 .await?;
1186 send_dev_server_projects_update(dev_server.user_id, status, &session).await;
1187 }
1188 }
1189
1190 Ok(())
1191 }
1192
1193 pub async fn invite_code_redeemed(
1194 self: &Arc<Self>,
1195 inviter_id: UserId,
1196 invitee_id: UserId,
1197 ) -> Result<()> {
1198 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1199 if let Some(code) = &user.invite_code {
1200 let pool = self.connection_pool.lock();
1201 let invitee_contact = contact_for_user(invitee_id, false, &pool);
1202 for connection_id in pool.user_connection_ids(inviter_id) {
1203 self.peer.send(
1204 connection_id,
1205 proto::UpdateContacts {
1206 contacts: vec![invitee_contact.clone()],
1207 ..Default::default()
1208 },
1209 )?;
1210 self.peer.send(
1211 connection_id,
1212 proto::UpdateInviteInfo {
1213 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1214 count: user.invite_count as u32,
1215 },
1216 )?;
1217 }
1218 }
1219 }
1220 Ok(())
1221 }
1222
1223 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1224 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1225 if let Some(invite_code) = &user.invite_code {
1226 let pool = self.connection_pool.lock();
1227 for connection_id in pool.user_connection_ids(user_id) {
1228 self.peer.send(
1229 connection_id,
1230 proto::UpdateInviteInfo {
1231 url: format!(
1232 "{}{}",
1233 self.app_state.config.invite_link_prefix, invite_code
1234 ),
1235 count: user.invite_count as u32,
1236 },
1237 )?;
1238 }
1239 }
1240 }
1241 Ok(())
1242 }
1243
1244 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1245 ServerSnapshot {
1246 connection_pool: ConnectionPoolGuard {
1247 guard: self.connection_pool.lock(),
1248 _not_send: PhantomData,
1249 },
1250 peer: &self.peer,
1251 }
1252 }
1253}
1254
1255impl<'a> Deref for ConnectionPoolGuard<'a> {
1256 type Target = ConnectionPool;
1257
1258 fn deref(&self) -> &Self::Target {
1259 &self.guard
1260 }
1261}
1262
1263impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1264 fn deref_mut(&mut self) -> &mut Self::Target {
1265 &mut self.guard
1266 }
1267}
1268
1269impl<'a> Drop for ConnectionPoolGuard<'a> {
1270 fn drop(&mut self) {
1271 #[cfg(test)]
1272 self.check_invariants();
1273 }
1274}
1275
1276fn broadcast<F>(
1277 sender_id: Option<ConnectionId>,
1278 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1279 mut f: F,
1280) where
1281 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1282{
1283 for receiver_id in receiver_ids {
1284 if Some(receiver_id) != sender_id {
1285 if let Err(error) = f(receiver_id) {
1286 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1287 }
1288 }
1289 }
1290}
1291
1292pub struct ProtocolVersion(u32);
1293
1294impl Header for ProtocolVersion {
1295 fn name() -> &'static HeaderName {
1296 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1297 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1298 }
1299
1300 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1301 where
1302 Self: Sized,
1303 I: Iterator<Item = &'i axum::http::HeaderValue>,
1304 {
1305 let version = values
1306 .next()
1307 .ok_or_else(axum::headers::Error::invalid)?
1308 .to_str()
1309 .map_err(|_| axum::headers::Error::invalid())?
1310 .parse()
1311 .map_err(|_| axum::headers::Error::invalid())?;
1312 Ok(Self(version))
1313 }
1314
1315 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1316 values.extend([self.0.to_string().parse().unwrap()]);
1317 }
1318}
1319
1320pub struct AppVersionHeader(SemanticVersion);
1321impl Header for AppVersionHeader {
1322 fn name() -> &'static HeaderName {
1323 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1324 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1325 }
1326
1327 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1328 where
1329 Self: Sized,
1330 I: Iterator<Item = &'i axum::http::HeaderValue>,
1331 {
1332 let version = values
1333 .next()
1334 .ok_or_else(axum::headers::Error::invalid)?
1335 .to_str()
1336 .map_err(|_| axum::headers::Error::invalid())?
1337 .parse()
1338 .map_err(|_| axum::headers::Error::invalid())?;
1339 Ok(Self(version))
1340 }
1341
1342 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1343 values.extend([self.0.to_string().parse().unwrap()]);
1344 }
1345}
1346
1347pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1348 Router::new()
1349 .route("/rpc", get(handle_websocket_request))
1350 .layer(
1351 ServiceBuilder::new()
1352 .layer(Extension(server.app_state.clone()))
1353 .layer(middleware::from_fn(auth::validate_header)),
1354 )
1355 .route("/metrics", get(handle_metrics))
1356 .layer(Extension(server))
1357}
1358
1359pub async fn handle_websocket_request(
1360 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1361 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1362 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1363 Extension(server): Extension<Arc<Server>>,
1364 Extension(principal): Extension<Principal>,
1365 ws: WebSocketUpgrade,
1366) -> axum::response::Response {
1367 if protocol_version != rpc::PROTOCOL_VERSION {
1368 return (
1369 StatusCode::UPGRADE_REQUIRED,
1370 "client must be upgraded".to_string(),
1371 )
1372 .into_response();
1373 }
1374
1375 let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1376 return (
1377 StatusCode::UPGRADE_REQUIRED,
1378 "no version header found".to_string(),
1379 )
1380 .into_response();
1381 };
1382
1383 if !version.can_collaborate() {
1384 return (
1385 StatusCode::UPGRADE_REQUIRED,
1386 "client must be upgraded".to_string(),
1387 )
1388 .into_response();
1389 }
1390
1391 let socket_address = socket_address.to_string();
1392 ws.on_upgrade(move |socket| {
1393 let socket = socket
1394 .map_ok(to_tungstenite_message)
1395 .err_into()
1396 .with(|message| async move { Ok(to_axum_message(message)) });
1397 let connection = Connection::new(Box::pin(socket));
1398 async move {
1399 server
1400 .handle_connection(
1401 connection,
1402 socket_address,
1403 principal,
1404 version,
1405 None,
1406 Executor::Production,
1407 )
1408 .await;
1409 }
1410 })
1411}
1412
1413pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1414 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1415 let connections_metric = CONNECTIONS_METRIC
1416 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1417
1418 let connections = server
1419 .connection_pool
1420 .lock()
1421 .connections()
1422 .filter(|connection| !connection.admin)
1423 .count();
1424 connections_metric.set(connections as _);
1425
1426 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1427 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1428 register_int_gauge!(
1429 "shared_projects",
1430 "number of open projects with one or more guests"
1431 )
1432 .unwrap()
1433 });
1434
1435 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1436 shared_projects_metric.set(shared_projects as _);
1437
1438 let encoder = prometheus::TextEncoder::new();
1439 let metric_families = prometheus::gather();
1440 let encoded_metrics = encoder
1441 .encode_to_string(&metric_families)
1442 .map_err(|err| anyhow!("{}", err))?;
1443 Ok(encoded_metrics)
1444}
1445
/// Cleans up after a connection that has gone away.
///
/// Immediately, the connection is:
/// * disconnected from the peer;
/// * removed from the connection pool;
/// * reported to the database as lost (errors there are only logged).
///
/// The function then waits `RECONNECT_TIMEOUT`, presumably so a client that reconnects can
/// pick up where it left off. If the server starts tearing down first, the remaining cleanup
/// is skipped. Once the timeout elapses:
///
/// * For users:
///   - the session leaves its room and its channel buffers;
///   - if the user has no other online connection, any pending incoming call is declined;
///   - `update_user_contacts` runs.
/// * For dev servers, `lost_dev_server_connection` runs.
///
/// # Errors
///
/// Returns an error if the connection cannot be removed from the pool, or if updating contacts
/// or handling the lost dev server fails. Other cleanup failures are only logged.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // Wait out the reconnect window unless the server is tearing down; `select_biased!` polls
    // the timeout arm first.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            match &session.principal {
                Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
                    // The principal was just matched as a user, so `for_user` is expected to
                    // succeed here.
                    let session = session.for_user().unwrap();

                    log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
                    leave_room_for_session(&session, session.connection_id).await.trace_err();
                    leave_channel_buffers_for_session(&session)
                        .await
                        .trace_err();

                    // Only decline a pending call once the user has no connection left online.
                    if !session
                        .connection_pool()
                        .await
                        .is_user_online(session.user_id())
                    {
                        let db = session.db().await;
                        if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                            room_updated(&room, &session.peer);
                        }
                    }

                    update_user_contacts(session.user_id(), &session).await?;
                },
                Principal::DevServer(_) => {
                    lost_dev_server_connection(&session.for_dev_server().unwrap()).await?;
                },
            }
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1500
1501/// Acknowledges a ping from a client, used to keep the connection alive.
1502async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1503 response.send(proto::Ack {})?;
1504 Ok(())
1505}
1506
1507/// Creates a new room for calling (outside of channels)
1508async fn create_room(
1509 _request: proto::CreateRoom,
1510 response: Response<proto::CreateRoom>,
1511 session: UserSession,
1512) -> Result<()> {
1513 let live_kit_room = nanoid::nanoid!(30);
1514
1515 let live_kit_connection_info = util::maybe!(async {
1516 let live_kit = session.live_kit_client.as_ref();
1517 let live_kit = live_kit?;
1518 let user_id = session.user_id().to_string();
1519
1520 let token = live_kit
1521 .room_token(&live_kit_room, &user_id.to_string())
1522 .trace_err()?;
1523
1524 Some(proto::LiveKitConnectionInfo {
1525 server_url: live_kit.url().into(),
1526 token,
1527 can_publish: true,
1528 })
1529 })
1530 .await;
1531
1532 let room = session
1533 .db()
1534 .await
1535 .create_room(session.user_id(), session.connection_id, &live_kit_room)
1536 .await?;
1537
1538 response.send(proto::CreateRoomResponse {
1539 room: Some(room.clone()),
1540 live_kit_connection_info,
1541 })?;
1542
1543 update_user_contacts(session.user_id(), &session).await?;
1544 Ok(())
1545}
1546
1547/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1548async fn join_room(
1549 request: proto::JoinRoom,
1550 response: Response<proto::JoinRoom>,
1551 session: UserSession,
1552) -> Result<()> {
1553 let room_id = RoomId::from_proto(request.id);
1554
1555 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1556
1557 if let Some(channel_id) = channel_id {
1558 return join_channel_internal(channel_id, Box::new(response), session).await;
1559 }
1560
1561 let joined_room = {
1562 let room = session
1563 .db()
1564 .await
1565 .join_room(room_id, session.user_id(), session.connection_id)
1566 .await?;
1567 room_updated(&room.room, &session.peer);
1568 room.into_inner()
1569 };
1570
1571 for connection_id in session
1572 .connection_pool()
1573 .await
1574 .user_connection_ids(session.user_id())
1575 {
1576 session
1577 .peer
1578 .send(
1579 connection_id,
1580 proto::CallCanceled {
1581 room_id: room_id.to_proto(),
1582 },
1583 )
1584 .trace_err();
1585 }
1586
1587 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1588 if let Some(token) = live_kit
1589 .room_token(
1590 &joined_room.room.live_kit_room,
1591 &session.user_id().to_string(),
1592 )
1593 .trace_err()
1594 {
1595 Some(proto::LiveKitConnectionInfo {
1596 server_url: live_kit.url().into(),
1597 token,
1598 can_publish: true,
1599 })
1600 } else {
1601 None
1602 }
1603 } else {
1604 None
1605 };
1606
1607 response.send(proto::JoinRoomResponse {
1608 room: Some(joined_room.room),
1609 channel_id: None,
1610 live_kit_connection_info,
1611 })?;
1612
1613 update_user_contacts(session.user_id(), &session).await?;
1614 Ok(())
1615}
1616
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Restores this connection's place in the room described by `request`, then:
///
/// 1. Replies with the room, the "reshared" projects (with their collaborators) and the
///    "rejoined" projects.
/// 2. Broadcasts the updated room.
/// 3. For every reshared project:
///    - tells each collaborator that this user's peer id changed from the project's old
///      connection id to the current one;
///    - forwards the project's worktree list to the other collaborators on behalf of this
///      connection.
/// 4. Re-sends the state of every rejoined project to this client via
///    `notify_rejoined_projects`.
/// 5. If the room belongs to a channel, sends a channel update.
/// 6. Runs `update_user_contacts` for the user.
///
/// The guard returned by the database's `rejoin_room` is held through step 4 and consumed by
/// `into_inner` before the channel update.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: UserSession,
) -> Result<()> {
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell each collaborator that this user's peer id changed (best-effort).
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the reshared worktree list to everyone else, as if sent by this connection.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1708
1709fn notify_rejoined_projects(
1710 rejoined_projects: &mut Vec<RejoinedProject>,
1711 session: &UserSession,
1712) -> Result<()> {
1713 for project in rejoined_projects.iter() {
1714 for collaborator in &project.collaborators {
1715 session
1716 .peer
1717 .send(
1718 collaborator.connection_id,
1719 proto::UpdateProjectCollaborator {
1720 project_id: project.id.to_proto(),
1721 old_peer_id: Some(project.old_connection_id.into()),
1722 new_peer_id: Some(session.connection_id.into()),
1723 },
1724 )
1725 .trace_err();
1726 }
1727 }
1728
1729 for project in rejoined_projects {
1730 for worktree in mem::take(&mut project.worktrees) {
1731 #[cfg(any(test, feature = "test-support"))]
1732 const MAX_CHUNK_SIZE: usize = 2;
1733 #[cfg(not(any(test, feature = "test-support")))]
1734 const MAX_CHUNK_SIZE: usize = 256;
1735
1736 // Stream this worktree's entries.
1737 let message = proto::UpdateWorktree {
1738 project_id: project.id.to_proto(),
1739 worktree_id: worktree.id,
1740 abs_path: worktree.abs_path.clone(),
1741 root_name: worktree.root_name,
1742 updated_entries: worktree.updated_entries,
1743 removed_entries: worktree.removed_entries,
1744 scan_id: worktree.scan_id,
1745 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1746 updated_repositories: worktree.updated_repositories,
1747 removed_repositories: worktree.removed_repositories,
1748 };
1749 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1750 session.peer.send(session.connection_id, update.clone())?;
1751 }
1752
1753 // Stream this worktree's diagnostics.
1754 for summary in worktree.diagnostic_summaries {
1755 session.peer.send(
1756 session.connection_id,
1757 proto::UpdateDiagnosticSummary {
1758 project_id: project.id.to_proto(),
1759 worktree_id: worktree.id,
1760 summary: Some(summary),
1761 },
1762 )?;
1763 }
1764
1765 for settings_file in worktree.settings_files {
1766 session.peer.send(
1767 session.connection_id,
1768 proto::UpdateWorktreeSettings {
1769 project_id: project.id.to_proto(),
1770 worktree_id: worktree.id,
1771 path: settings_file.path,
1772 content: Some(settings_file.content),
1773 },
1774 )?;
1775 }
1776 }
1777
1778 for language_server in &project.language_servers {
1779 session.peer.send(
1780 session.connection_id,
1781 proto::UpdateLanguageServer {
1782 project_id: project.id.to_proto(),
1783 language_server_id: language_server.id,
1784 variant: Some(
1785 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1786 proto::LspDiskBasedDiagnosticsUpdated {},
1787 ),
1788 ),
1789 },
1790 )?;
1791 }
1792 }
1793 Ok(())
1794}
1795
1796/// leave room disconnects from the room.
1797async fn leave_room(
1798 _: proto::LeaveRoom,
1799 response: Response<proto::LeaveRoom>,
1800 session: UserSession,
1801) -> Result<()> {
1802 leave_room_for_session(&session, session.connection_id).await?;
1803 response.send(proto::Ack {})?;
1804 Ok(())
1805}
1806
1807/// Updates the permissions of someone else in the room.
1808async fn set_room_participant_role(
1809 request: proto::SetRoomParticipantRole,
1810 response: Response<proto::SetRoomParticipantRole>,
1811 session: UserSession,
1812) -> Result<()> {
1813 let user_id = UserId::from_proto(request.user_id);
1814 let role = ChannelRole::from(request.role());
1815
1816 let (live_kit_room, can_publish) = {
1817 let room = session
1818 .db()
1819 .await
1820 .set_room_participant_role(
1821 session.user_id(),
1822 RoomId::from_proto(request.room_id),
1823 user_id,
1824 role,
1825 )
1826 .await?;
1827
1828 let live_kit_room = room.live_kit_room.clone();
1829 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1830 room_updated(&room, &session.peer);
1831 (live_kit_room, can_publish)
1832 };
1833
1834 if let Some(live_kit) = session.live_kit_client.as_ref() {
1835 live_kit
1836 .update_participant(
1837 live_kit_room.clone(),
1838 request.user_id.to_string(),
1839 live_kit_server::proto::ParticipantPermission {
1840 can_subscribe: true,
1841 can_publish,
1842 can_publish_data: can_publish,
1843 hidden: false,
1844 recorder: false,
1845 },
1846 )
1847 .await
1848 .trace_err();
1849 }
1850
1851 response.send(proto::Ack {})?;
1852 Ok(())
1853}
1854
1855/// Call someone else into the current room
1856async fn call(
1857 request: proto::Call,
1858 response: Response<proto::Call>,
1859 session: UserSession,
1860) -> Result<()> {
1861 let room_id = RoomId::from_proto(request.room_id);
1862 let calling_user_id = session.user_id();
1863 let calling_connection_id = session.connection_id;
1864 let called_user_id = UserId::from_proto(request.called_user_id);
1865 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1866 if !session
1867 .db()
1868 .await
1869 .has_contact(calling_user_id, called_user_id)
1870 .await?
1871 {
1872 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1873 }
1874
1875 let incoming_call = {
1876 let (room, incoming_call) = &mut *session
1877 .db()
1878 .await
1879 .call(
1880 room_id,
1881 calling_user_id,
1882 calling_connection_id,
1883 called_user_id,
1884 initial_project_id,
1885 )
1886 .await?;
1887 room_updated(&room, &session.peer);
1888 mem::take(incoming_call)
1889 };
1890 update_user_contacts(called_user_id, &session).await?;
1891
1892 let mut calls = session
1893 .connection_pool()
1894 .await
1895 .user_connection_ids(called_user_id)
1896 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1897 .collect::<FuturesUnordered<_>>();
1898
1899 while let Some(call_response) = calls.next().await {
1900 match call_response.as_ref() {
1901 Ok(_) => {
1902 response.send(proto::Ack {})?;
1903 return Ok(());
1904 }
1905 Err(_) => {
1906 call_response.trace_err();
1907 }
1908 }
1909 }
1910
1911 {
1912 let room = session
1913 .db()
1914 .await
1915 .call_failed(room_id, called_user_id)
1916 .await?;
1917 room_updated(&room, &session.peer);
1918 }
1919 update_user_contacts(called_user_id, &session).await?;
1920
1921 Err(anyhow!("failed to ring user"))?
1922}
1923
1924/// Cancel an outgoing call.
1925async fn cancel_call(
1926 request: proto::CancelCall,
1927 response: Response<proto::CancelCall>,
1928 session: UserSession,
1929) -> Result<()> {
1930 let called_user_id = UserId::from_proto(request.called_user_id);
1931 let room_id = RoomId::from_proto(request.room_id);
1932 {
1933 let room = session
1934 .db()
1935 .await
1936 .cancel_call(room_id, session.connection_id, called_user_id)
1937 .await?;
1938 room_updated(&room, &session.peer);
1939 }
1940
1941 for connection_id in session
1942 .connection_pool()
1943 .await
1944 .user_connection_ids(called_user_id)
1945 {
1946 session
1947 .peer
1948 .send(
1949 connection_id,
1950 proto::CallCanceled {
1951 room_id: room_id.to_proto(),
1952 },
1953 )
1954 .trace_err();
1955 }
1956 response.send(proto::Ack {})?;
1957
1958 update_user_contacts(called_user_id, &session).await?;
1959 Ok(())
1960}
1961
1962/// Decline an incoming call.
1963async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1964 let room_id = RoomId::from_proto(message.room_id);
1965 {
1966 let room = session
1967 .db()
1968 .await
1969 .decline_call(Some(room_id), session.user_id())
1970 .await?
1971 .ok_or_else(|| anyhow!("failed to decline call"))?;
1972 room_updated(&room, &session.peer);
1973 }
1974
1975 for connection_id in session
1976 .connection_pool()
1977 .await
1978 .user_connection_ids(session.user_id())
1979 {
1980 session
1981 .peer
1982 .send(
1983 connection_id,
1984 proto::CallCanceled {
1985 room_id: room_id.to_proto(),
1986 },
1987 )
1988 .trace_err();
1989 }
1990 update_user_contacts(session.user_id(), &session).await?;
1991 Ok(())
1992}
1993
1994/// Updates other participants in the room with your current location.
1995async fn update_participant_location(
1996 request: proto::UpdateParticipantLocation,
1997 response: Response<proto::UpdateParticipantLocation>,
1998 session: UserSession,
1999) -> Result<()> {
2000 let room_id = RoomId::from_proto(request.room_id);
2001 let location = request
2002 .location
2003 .ok_or_else(|| anyhow!("invalid location"))?;
2004
2005 let db = session.db().await;
2006 let room = db
2007 .update_room_participant_location(room_id, session.connection_id, location)
2008 .await?;
2009
2010 room_updated(&room, &session.peer);
2011 response.send(proto::Ack {})?;
2012 Ok(())
2013}
2014
2015/// Share a project into the room.
2016async fn share_project(
2017 request: proto::ShareProject,
2018 response: Response<proto::ShareProject>,
2019 session: UserSession,
2020) -> Result<()> {
2021 let (project_id, room) = &*session
2022 .db()
2023 .await
2024 .share_project(
2025 RoomId::from_proto(request.room_id),
2026 session.connection_id,
2027 &request.worktrees,
2028 request
2029 .dev_server_project_id
2030 .map(|id| DevServerProjectId::from_proto(id)),
2031 )
2032 .await?;
2033 response.send(proto::ShareProjectResponse {
2034 project_id: project_id.to_proto(),
2035 })?;
2036 room_updated(&room, &session.peer);
2037
2038 Ok(())
2039}
2040
2041/// Unshare a project from the room.
2042async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
2043 let project_id = ProjectId::from_proto(message.project_id);
2044 unshare_project_internal(
2045 project_id,
2046 session.connection_id,
2047 session.user_id(),
2048 &session,
2049 )
2050 .await
2051}
2052
/// Stops sharing `project_id` on behalf of `connection_id` (and `user_id`, when known).
///
/// While the database guard returned by `unshare_project` is held:
/// * every guest connection it reports, other than `connection_id`, is sent `UnshareProject`
///   (send failures are logged by `broadcast`);
/// * the room is broadcast if the project belonged to one.
///
/// If the database says the project should be deleted, it is deleted afterwards, once that
/// guard has been released.
async fn unshare_project_internal(
    project_id: ProjectId,
    connection_id: ConnectionId,
    user_id: Option<UserId>,
    session: &Session,
) -> Result<()> {
    let delete = {
        let room_guard = session
            .db()
            .await
            .unshare_project(project_id, connection_id, user_id)
            .await?;

        let (delete, room, guest_connection_ids) = &*room_guard;

        let message = proto::UnshareProject {
            project_id: project_id.to_proto(),
        };

        broadcast(
            Some(connection_id),
            guest_connection_ids.iter().copied(),
            |conn_id| session.peer.send(conn_id, message.clone()),
        );
        if let Some(room) = room {
            room_updated(room, &session.peer);
        }

        // Copy the flag out so the guard can be dropped at the end of this block.
        *delete
    };

    if delete {
        let db = session.db().await;
        db.delete_project(project_id).await?;
    }

    Ok(())
}
2091
2092/// DevServer makes a project available online
2093async fn share_dev_server_project(
2094 request: proto::ShareDevServerProject,
2095 response: Response<proto::ShareDevServerProject>,
2096 session: DevServerSession,
2097) -> Result<()> {
2098 let (dev_server_project, user_id, status) = session
2099 .db()
2100 .await
2101 .share_dev_server_project(
2102 DevServerProjectId::from_proto(request.dev_server_project_id),
2103 session.dev_server_id(),
2104 session.connection_id,
2105 &request.worktrees,
2106 )
2107 .await?;
2108 let Some(project_id) = dev_server_project.project_id else {
2109 return Err(anyhow!("failed to share remote project"))?;
2110 };
2111
2112 send_dev_server_projects_update(user_id, status, &session).await;
2113
2114 response.send(proto::ShareProjectResponse { project_id })?;
2115
2116 Ok(())
2117}
2118
/// Join someone else's shared project.
///
/// Registers this connection as a guest of `request.project_id` in the database, then hands
/// the joined project and the assigned replica id to `join_project_internal`.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: UserSession,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    // Release the database handle early. The guard returned by `join_project` does not borrow
    // `db`, so `project` and `replica_id` stay valid.
    drop(db);
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
2137
/// Abstracts over response handles whose reply is a `proto::JoinProjectResponse`, so that
/// `join_project_internal` can answer both `JoinProject` and `JoinHostedProject` requests.
trait JoinProjectInternalResponse {
    /// Sends `result` as the reply to the originating request.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Delegates to `Response`'s own `send` for this request type.
        Response::<proto::JoinProject>::send(self, result)
    }
}
impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Delegates to `Response`'s own `send` for this request type.
        Response::<proto::JoinHostedProject>::send(self, result)
    }
}
2151
2152fn join_project_internal(
2153 response: impl JoinProjectInternalResponse,
2154 session: UserSession,
2155 project: &mut Project,
2156 replica_id: &ReplicaId,
2157) -> Result<()> {
2158 let collaborators = project
2159 .collaborators
2160 .iter()
2161 .filter(|collaborator| collaborator.connection_id != session.connection_id)
2162 .map(|collaborator| collaborator.to_proto())
2163 .collect::<Vec<_>>();
2164 let project_id = project.id;
2165 let guest_user_id = session.user_id();
2166
2167 let worktrees = project
2168 .worktrees
2169 .iter()
2170 .map(|(id, worktree)| proto::WorktreeMetadata {
2171 id: *id,
2172 root_name: worktree.root_name.clone(),
2173 visible: worktree.visible,
2174 abs_path: worktree.abs_path.clone(),
2175 })
2176 .collect::<Vec<_>>();
2177
2178 let add_project_collaborator = proto::AddProjectCollaborator {
2179 project_id: project_id.to_proto(),
2180 collaborator: Some(proto::Collaborator {
2181 peer_id: Some(session.connection_id.into()),
2182 replica_id: replica_id.0 as u32,
2183 user_id: guest_user_id.to_proto(),
2184 }),
2185 };
2186
2187 for collaborator in &collaborators {
2188 session
2189 .peer
2190 .send(
2191 collaborator.peer_id.unwrap().into(),
2192 add_project_collaborator.clone(),
2193 )
2194 .trace_err();
2195 }
2196
2197 // First, we send the metadata associated with each worktree.
2198 response.send(proto::JoinProjectResponse {
2199 project_id: project.id.0 as u64,
2200 worktrees: worktrees.clone(),
2201 replica_id: replica_id.0 as u32,
2202 collaborators: collaborators.clone(),
2203 language_servers: project.language_servers.clone(),
2204 role: project.role.into(),
2205 dev_server_project_id: project
2206 .dev_server_project_id
2207 .map(|dev_server_project_id| dev_server_project_id.0 as u64),
2208 })?;
2209
2210 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2211 #[cfg(any(test, feature = "test-support"))]
2212 const MAX_CHUNK_SIZE: usize = 2;
2213 #[cfg(not(any(test, feature = "test-support")))]
2214 const MAX_CHUNK_SIZE: usize = 256;
2215
2216 // Stream this worktree's entries.
2217 let message = proto::UpdateWorktree {
2218 project_id: project_id.to_proto(),
2219 worktree_id,
2220 abs_path: worktree.abs_path.clone(),
2221 root_name: worktree.root_name,
2222 updated_entries: worktree.entries,
2223 removed_entries: Default::default(),
2224 scan_id: worktree.scan_id,
2225 is_last_update: worktree.scan_id == worktree.completed_scan_id,
2226 updated_repositories: worktree.repository_entries.into_values().collect(),
2227 removed_repositories: Default::default(),
2228 };
2229 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
2230 session.peer.send(session.connection_id, update.clone())?;
2231 }
2232
2233 // Stream this worktree's diagnostics.
2234 for summary in worktree.diagnostic_summaries {
2235 session.peer.send(
2236 session.connection_id,
2237 proto::UpdateDiagnosticSummary {
2238 project_id: project_id.to_proto(),
2239 worktree_id: worktree.id,
2240 summary: Some(summary),
2241 },
2242 )?;
2243 }
2244
2245 for settings_file in worktree.settings_files {
2246 session.peer.send(
2247 session.connection_id,
2248 proto::UpdateWorktreeSettings {
2249 project_id: project_id.to_proto(),
2250 worktree_id: worktree.id,
2251 path: settings_file.path,
2252 content: Some(settings_file.content),
2253 },
2254 )?;
2255 }
2256 }
2257
2258 for language_server in &project.language_servers {
2259 session.peer.send(
2260 session.connection_id,
2261 proto::UpdateLanguageServer {
2262 project_id: project_id.to_proto(),
2263 language_server_id: language_server.id,
2264 variant: Some(
2265 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2266 proto::LspDiskBasedDiagnosticsUpdated {},
2267 ),
2268 ),
2269 },
2270 )?;
2271 }
2272
2273 Ok(())
2274}
2275
2276/// Leave someone elses shared project.
2277async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2278 let sender_id = session.connection_id;
2279 let project_id = ProjectId::from_proto(request.project_id);
2280 let db = session.db().await;
2281 if db.is_hosted_project(project_id).await? {
2282 let project = db.leave_hosted_project(project_id, sender_id).await?;
2283 project_left(&project, &session);
2284 return Ok(());
2285 }
2286
2287 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2288 tracing::info!(
2289 %project_id,
2290 "leave project"
2291 );
2292
2293 project_left(&project, &session);
2294 if let Some(room) = room {
2295 room_updated(&room, &session.peer);
2296 }
2297
2298 Ok(())
2299}
2300
2301async fn join_hosted_project(
2302 request: proto::JoinHostedProject,
2303 response: Response<proto::JoinHostedProject>,
2304 session: UserSession,
2305) -> Result<()> {
2306 let (mut project, replica_id) = session
2307 .db()
2308 .await
2309 .join_hosted_project(
2310 ProjectId(request.project_id as i32),
2311 session.user_id(),
2312 session.connection_id,
2313 )
2314 .await?;
2315
2316 join_project_internal(response, session, &mut project, &replica_id)
2317}
2318
2319async fn create_dev_server_project(
2320 request: proto::CreateDevServerProject,
2321 response: Response<proto::CreateDevServerProject>,
2322 session: UserSession,
2323) -> Result<()> {
2324 let dev_server_id = DevServerId(request.dev_server_id as i32);
2325 let dev_server_connection_id = session
2326 .connection_pool()
2327 .await
2328 .dev_server_connection_id(dev_server_id);
2329 let Some(dev_server_connection_id) = dev_server_connection_id else {
2330 Err(ErrorCode::DevServerOffline
2331 .message("Cannot create a remote project when the dev server is offline".to_string())
2332 .anyhow())?
2333 };
2334
2335 let path = request.path.clone();
2336 //Check that the path exists on the dev server
2337 session
2338 .peer
2339 .forward_request(
2340 session.connection_id,
2341 dev_server_connection_id,
2342 proto::ValidateDevServerProjectRequest { path: path.clone() },
2343 )
2344 .await?;
2345
2346 let (dev_server_project, update) = session
2347 .db()
2348 .await
2349 .create_dev_server_project(
2350 DevServerId(request.dev_server_id as i32),
2351 &request.path,
2352 session.user_id(),
2353 )
2354 .await?;
2355
2356 let projects = session
2357 .db()
2358 .await
2359 .get_projects_for_dev_server(dev_server_project.dev_server_id)
2360 .await?;
2361
2362 session.peer.send(
2363 dev_server_connection_id,
2364 proto::DevServerInstructions { projects },
2365 )?;
2366
2367 send_dev_server_projects_update(session.user_id(), update, &session).await;
2368
2369 response.send(proto::CreateDevServerProjectResponse {
2370 dev_server_project: Some(dev_server_project.to_proto(None)),
2371 })?;
2372 Ok(())
2373}
2374
2375async fn create_dev_server(
2376 request: proto::CreateDevServer,
2377 response: Response<proto::CreateDevServer>,
2378 session: UserSession,
2379) -> Result<()> {
2380 let access_token = auth::random_token();
2381 let hashed_access_token = auth::hash_access_token(&access_token);
2382
2383 if request.name.is_empty() {
2384 return Err(proto::ErrorCode::Forbidden
2385 .message("Dev server name cannot be empty".to_string())
2386 .anyhow())?;
2387 }
2388
2389 let (dev_server, status) = session
2390 .db()
2391 .await
2392 .create_dev_server(
2393 &request.name,
2394 request.ssh_connection_string.as_deref(),
2395 &hashed_access_token,
2396 session.user_id(),
2397 )
2398 .await?;
2399
2400 send_dev_server_projects_update(session.user_id(), status, &session).await;
2401
2402 response.send(proto::CreateDevServerResponse {
2403 dev_server_id: dev_server.id.0 as u64,
2404 access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2405 name: request.name,
2406 })?;
2407 Ok(())
2408}
2409
2410async fn regenerate_dev_server_token(
2411 request: proto::RegenerateDevServerToken,
2412 response: Response<proto::RegenerateDevServerToken>,
2413 session: UserSession,
2414) -> Result<()> {
2415 let dev_server_id = DevServerId(request.dev_server_id as i32);
2416 let access_token = auth::random_token();
2417 let hashed_access_token = auth::hash_access_token(&access_token);
2418
2419 let connection_id = session
2420 .connection_pool()
2421 .await
2422 .dev_server_connection_id(dev_server_id);
2423 if let Some(connection_id) = connection_id {
2424 shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2425 session.peer.send(
2426 connection_id,
2427 proto::ShutdownDevServer {
2428 reason: Some("dev server token was regenerated".to_string()),
2429 },
2430 )?;
2431 let _ = remove_dev_server_connection(dev_server_id, &session).await;
2432 }
2433
2434 let status = session
2435 .db()
2436 .await
2437 .update_dev_server_token(dev_server_id, &hashed_access_token, session.user_id())
2438 .await?;
2439
2440 send_dev_server_projects_update(session.user_id(), status, &session).await;
2441
2442 response.send(proto::RegenerateDevServerTokenResponse {
2443 dev_server_id: dev_server_id.to_proto(),
2444 access_token: auth::generate_dev_server_token(dev_server_id.0 as usize, access_token),
2445 })?;
2446 Ok(())
2447}
2448
2449async fn rename_dev_server(
2450 request: proto::RenameDevServer,
2451 response: Response<proto::RenameDevServer>,
2452 session: UserSession,
2453) -> Result<()> {
2454 if request.name.trim().is_empty() {
2455 return Err(proto::ErrorCode::Forbidden
2456 .message("Dev server name cannot be empty".to_string())
2457 .anyhow())?;
2458 }
2459
2460 let dev_server_id = DevServerId(request.dev_server_id as i32);
2461 let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2462 if dev_server.user_id != session.user_id() {
2463 return Err(anyhow!(ErrorCode::Forbidden))?;
2464 }
2465
2466 let status = session
2467 .db()
2468 .await
2469 .rename_dev_server(
2470 dev_server_id,
2471 &request.name,
2472 request.ssh_connection_string.as_deref(),
2473 session.user_id(),
2474 )
2475 .await?;
2476
2477 send_dev_server_projects_update(session.user_id(), status, &session).await;
2478
2479 response.send(proto::Ack {})?;
2480 Ok(())
2481}
2482
/// Deletes one of the caller's dev servers.
///
/// Only the dev server's owner may delete it. If the dev server is currently connected:
/// - its projects are unshared;
/// - it is sent a `ShutdownDevServer` message;
/// - its connection is removed from the pool (best effort).
///
/// This all happens before the database record is deleted. The owner's dev-server status is
/// broadcast afterwards.
async fn delete_dev_server(
    request: proto::DeleteDevServer,
    response: Response<proto::DeleteDevServer>,
    session: UserSession,
) -> Result<()> {
    let dev_server_id = DevServerId(request.dev_server_id as i32);
    let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
    if dev_server.user_id != session.user_id() {
        return Err(anyhow!(ErrorCode::Forbidden))?;
    }

    // Looked up in its own statement so the guard returned by `connection_pool()` is dropped
    // before `shutdown_dev_server_internal` acquires it again.
    let connection_id = session
        .connection_pool()
        .await
        .dev_server_connection_id(dev_server_id);
    if let Some(connection_id) = connection_id {
        shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
        session.peer.send(
            connection_id,
            proto::ShutdownDevServer {
                reason: Some("dev server was deleted".to_string()),
            },
        )?;
        // Best effort: failing to remove the connection does not abort the deletion.
        let _ = remove_dev_server_connection(dev_server_id, &session).await;
    }

    let status = session
        .db()
        .await
        .delete_dev_server(dev_server_id, session.user_id())
        .await?;

    send_dev_server_projects_update(session.user_id(), status, &session).await;

    response.send(proto::Ack {})?;
    Ok(())
}
2520
/// Deletes a project that belongs to one of the caller's dev servers.
///
/// Only the owner of the dev server may delete its projects. If the dev server is online and
/// `find_dev_server_project` finds a project for it, that project is unshared first. After the
/// database record is removed, an online dev server is sent the project list returned by the
/// database, and the owner's dev-server status is broadcast.
async fn delete_dev_server_project(
    request: proto::DeleteDevServerProject,
    response: Response<proto::DeleteDevServerProject>,
    session: UserSession,
) -> Result<()> {
    let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
    let dev_server_project = session
        .db()
        .await
        .get_dev_server_project(dev_server_project_id)
        .await?;

    // Ownership is checked on the dev server that hosts the project.
    let dev_server = session
        .db()
        .await
        .get_dev_server(dev_server_project.dev_server_id)
        .await?;
    if dev_server.user_id != session.user_id() {
        return Err(anyhow!(ErrorCode::Forbidden))?;
    }

    let dev_server_connection_id = session
        .connection_pool()
        .await
        .dev_server_connection_id(dev_server.id);

    if let Some(dev_server_connection_id) = dev_server_connection_id {
        let project = session
            .db()
            .await
            .find_dev_server_project(dev_server_project_id)
            .await;
        // A lookup error is deliberately ignored here: in that case there is nothing to
        // unshare. NOTE(review): presumably an error means the project is not currently
        // shared — confirm.
        if let Ok(project) = project {
            unshare_project_internal(
                project.id,
                dev_server_connection_id,
                Some(session.user_id()),
                &session,
            )
            .await?;
        }
    }

    let (projects, status) = session
        .db()
        .await
        .delete_dev_server_project(dev_server_project_id, dev_server.id, session.user_id())
        .await?;

    // Send the dev server its updated project list (if it is online).
    if let Some(dev_server_connection_id) = dev_server_connection_id {
        session.peer.send(
            dev_server_connection_id,
            proto::DevServerInstructions { projects },
        )?;
    }

    send_dev_server_projects_update(session.user_id(), status, &session).await;

    response.send(proto::Ack {})?;
    Ok(())
}
2582
2583async fn rejoin_dev_server_projects(
2584 request: proto::RejoinRemoteProjects,
2585 response: Response<proto::RejoinRemoteProjects>,
2586 session: UserSession,
2587) -> Result<()> {
2588 let mut rejoined_projects = {
2589 let db = session.db().await;
2590 db.rejoin_dev_server_projects(
2591 &request.rejoined_projects,
2592 session.user_id(),
2593 session.0.connection_id,
2594 )
2595 .await?
2596 };
2597 response.send(proto::RejoinRemoteProjectsResponse {
2598 rejoined_projects: rejoined_projects
2599 .iter()
2600 .map(|project| project.to_proto())
2601 .collect(),
2602 })?;
2603 notify_rejoined_projects(&mut rejoined_projects, &session)
2604}
2605
/// Handles a dev server reconnecting.
///
/// The database reshares the projects listed in the request under the dev server's current
/// connection. For each reshared project, every collaborator:
/// - is sent an `UpdateProjectCollaborator` replacing the old host peer id with the new one;
/// - is forwarded the project's current worktrees.
///
/// The response lists the reshared projects together with their collaborators.
async fn reconnect_dev_server(
    request: proto::ReconnectDevServer,
    response: Response<proto::ReconnectDevServer>,
    session: DevServerSession,
) -> Result<()> {
    let reshared_projects = {
        let db = session.db().await;
        db.reshare_dev_server_projects(
            &request.reshared_projects,
            session.dev_server_id(),
            session.0.connection_id,
        )
        .await?
    };

    for project in &reshared_projects {
        // Tell each collaborator that the host's peer id changed. A failed send to one
        // collaborator is not propagated (`trace_err`).
        for collaborator in &project.collaborators {
            session
                .peer
                .send(
                    collaborator.connection_id,
                    proto::UpdateProjectCollaborator {
                        project_id: project.id.to_proto(),
                        old_peer_id: Some(project.old_connection_id.into()),
                        new_peer_id: Some(session.connection_id.into()),
                    },
                )
                .trace_err();
        }

        // Re-send the project's worktree list to every collaborator.
        broadcast(
            Some(session.connection_id),
            project
                .collaborators
                .iter()
                .map(|collaborator| collaborator.connection_id),
            |connection_id| {
                session.peer.forward_send(
                    session.connection_id,
                    connection_id,
                    proto::UpdateProject {
                        project_id: project.id.to_proto(),
                        worktrees: project.worktrees.clone(),
                    },
                )
            },
        );
    }

    response.send(proto::ReconnectDevServerResponse {
        reshared_projects: reshared_projects
            .iter()
            .map(|project| proto::ResharedProject {
                id: project.id.to_proto(),
                collaborators: project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.to_proto())
                    .collect(),
            })
            .collect(),
    })?;

    Ok(())
}
2671
/// Handles a dev server announcing that it is shutting down.
///
/// The request is acknowledged first. The dev server is then taken offline
/// (`shutdown_dev_server_internal`) and its connection is removed from the pool.
async fn shutdown_dev_server(
    _: proto::ShutdownDevServer,
    response: Response<proto::ShutdownDevServer>,
    session: DevServerSession,
) -> Result<()> {
    response.send(proto::Ack {})?;
    shutdown_dev_server_internal(session.dev_server_id(), session.connection_id, &session).await?;
    remove_dev_server_connection(session.dev_server_id(), &session).await
}
2681
/// Takes a dev server offline.
///
/// Steps:
/// - unshares each of the dev server's projects that has a `project_id`;
/// - marks the dev server offline in the connection pool;
/// - sends the dev server's owner their updated dev-server status.
///
/// `connection_id` is passed to `unshare_project_internal` for every project. This function
/// does not remove the connection from the pool; callers use `remove_dev_server_connection`
/// for that.
async fn shutdown_dev_server_internal(
    dev_server_id: DevServerId,
    connection_id: ConnectionId,
    session: &Session,
) -> Result<()> {
    // Read everything needed up front; the database handle is released at the end of this block.
    let (dev_server_projects, dev_server) = {
        let db = session.db().await;
        let dev_server_projects = db.get_projects_for_dev_server(dev_server_id).await?;
        let dev_server = db.get_dev_server(dev_server_id).await?;
        (dev_server_projects, dev_server)
    };

    // Only projects that have been assigned a project id are unshared.
    for project_id in dev_server_projects.iter().filter_map(|p| p.project_id) {
        unshare_project_internal(
            ProjectId::from_proto(project_id),
            connection_id,
            None,
            session,
        )
        .await?;
    }

    session
        .connection_pool()
        .await
        .set_dev_server_offline(dev_server_id);

    let status = session
        .db()
        .await
        .dev_server_projects_update(dev_server.user_id)
        .await?;
    send_dev_server_projects_update(dev_server.user_id, status, &session).await;

    Ok(())
}
2718
2719async fn remove_dev_server_connection(dev_server_id: DevServerId, session: &Session) -> Result<()> {
2720 let dev_server_connection = session
2721 .connection_pool()
2722 .await
2723 .dev_server_connection_id(dev_server_id);
2724
2725 if let Some(dev_server_connection) = dev_server_connection {
2726 session
2727 .connection_pool()
2728 .await
2729 .remove_connection(dev_server_connection)?;
2730 }
2731 Ok(())
2732}
2733
2734/// Updates other participants with changes to the project
2735async fn update_project(
2736 request: proto::UpdateProject,
2737 response: Response<proto::UpdateProject>,
2738 session: Session,
2739) -> Result<()> {
2740 let project_id = ProjectId::from_proto(request.project_id);
2741 let (room, guest_connection_ids) = &*session
2742 .db()
2743 .await
2744 .update_project(project_id, session.connection_id, &request.worktrees)
2745 .await?;
2746 broadcast(
2747 Some(session.connection_id),
2748 guest_connection_ids.iter().copied(),
2749 |connection_id| {
2750 session
2751 .peer
2752 .forward_send(session.connection_id, connection_id, request.clone())
2753 },
2754 );
2755 if let Some(room) = room {
2756 room_updated(&room, &session.peer);
2757 }
2758 response.send(proto::Ack {})?;
2759
2760 Ok(())
2761}
2762
2763/// Updates other participants with changes to the worktree
2764async fn update_worktree(
2765 request: proto::UpdateWorktree,
2766 response: Response<proto::UpdateWorktree>,
2767 session: Session,
2768) -> Result<()> {
2769 let guest_connection_ids = session
2770 .db()
2771 .await
2772 .update_worktree(&request, session.connection_id)
2773 .await?;
2774
2775 broadcast(
2776 Some(session.connection_id),
2777 guest_connection_ids.iter().copied(),
2778 |connection_id| {
2779 session
2780 .peer
2781 .forward_send(session.connection_id, connection_id, request.clone())
2782 },
2783 );
2784 response.send(proto::Ack {})?;
2785 Ok(())
2786}
2787
2788/// Updates other participants with changes to the diagnostics
2789async fn update_diagnostic_summary(
2790 message: proto::UpdateDiagnosticSummary,
2791 session: Session,
2792) -> Result<()> {
2793 let guest_connection_ids = session
2794 .db()
2795 .await
2796 .update_diagnostic_summary(&message, session.connection_id)
2797 .await?;
2798
2799 broadcast(
2800 Some(session.connection_id),
2801 guest_connection_ids.iter().copied(),
2802 |connection_id| {
2803 session
2804 .peer
2805 .forward_send(session.connection_id, connection_id, message.clone())
2806 },
2807 );
2808
2809 Ok(())
2810}
2811
2812/// Updates other participants with changes to the worktree settings
2813async fn update_worktree_settings(
2814 message: proto::UpdateWorktreeSettings,
2815 session: Session,
2816) -> Result<()> {
2817 let guest_connection_ids = session
2818 .db()
2819 .await
2820 .update_worktree_settings(&message, session.connection_id)
2821 .await?;
2822
2823 broadcast(
2824 Some(session.connection_id),
2825 guest_connection_ids.iter().copied(),
2826 |connection_id| {
2827 session
2828 .peer
2829 .forward_send(session.connection_id, connection_id, message.clone())
2830 },
2831 );
2832
2833 Ok(())
2834}
2835
2836/// Notify other participants that a language server has started.
2837async fn start_language_server(
2838 request: proto::StartLanguageServer,
2839 session: Session,
2840) -> Result<()> {
2841 let guest_connection_ids = session
2842 .db()
2843 .await
2844 .start_language_server(&request, session.connection_id)
2845 .await?;
2846
2847 broadcast(
2848 Some(session.connection_id),
2849 guest_connection_ids.iter().copied(),
2850 |connection_id| {
2851 session
2852 .peer
2853 .forward_send(session.connection_id, connection_id, request.clone())
2854 },
2855 );
2856 Ok(())
2857}
2858
2859/// Notify other participants that a language server has changed.
2860async fn update_language_server(
2861 request: proto::UpdateLanguageServer,
2862 session: Session,
2863) -> Result<()> {
2864 let project_id = ProjectId::from_proto(request.project_id);
2865 let project_connection_ids = session
2866 .db()
2867 .await
2868 .project_connection_ids(project_id, session.connection_id, true)
2869 .await?;
2870 broadcast(
2871 Some(session.connection_id),
2872 project_connection_ids.iter().copied(),
2873 |connection_id| {
2874 session
2875 .peer
2876 .forward_send(session.connection_id, connection_id, request.clone())
2877 },
2878 );
2879 Ok(())
2880}
2881
2882/// forward a project request to the host. These requests should be read only
2883/// as guests are allowed to send them.
2884async fn forward_read_only_project_request<T>(
2885 request: T,
2886 response: Response<T>,
2887 session: UserSession,
2888) -> Result<()>
2889where
2890 T: EntityMessage + RequestMessage,
2891{
2892 let project_id = ProjectId::from_proto(request.remote_entity_id());
2893 let host_connection_id = session
2894 .db()
2895 .await
2896 .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2897 .await?;
2898 let payload = session
2899 .peer
2900 .forward_request(session.connection_id, host_connection_id, request)
2901 .await?;
2902 response.send(payload)?;
2903 Ok(())
2904}
2905
2906/// forward a project request to the dev server. Only allowed
2907/// if it's your dev server.
2908async fn forward_project_request_for_owner<T>(
2909 request: T,
2910 response: Response<T>,
2911 session: UserSession,
2912) -> Result<()>
2913where
2914 T: EntityMessage + RequestMessage,
2915{
2916 let project_id = ProjectId::from_proto(request.remote_entity_id());
2917
2918 let host_connection_id = session
2919 .db()
2920 .await
2921 .host_for_owner_project_request(project_id, session.connection_id, session.user_id())
2922 .await?;
2923 let payload = session
2924 .peer
2925 .forward_request(session.connection_id, host_connection_id, request)
2926 .await?;
2927 response.send(payload)?;
2928 Ok(())
2929}
2930
2931/// forward a project request to the host. These requests are disallowed
2932/// for guests.
2933async fn forward_mutating_project_request<T>(
2934 request: T,
2935 response: Response<T>,
2936 session: UserSession,
2937) -> Result<()>
2938where
2939 T: EntityMessage + RequestMessage,
2940{
2941 let project_id = ProjectId::from_proto(request.remote_entity_id());
2942
2943 let host_connection_id = session
2944 .db()
2945 .await
2946 .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
2947 .await?;
2948 let payload = session
2949 .peer
2950 .forward_request(session.connection_id, host_connection_id, request)
2951 .await?;
2952 response.send(payload)?;
2953 Ok(())
2954}
2955
/// Forward a project request to the host after checking that the host's Zed version can handle
/// it. Like `forward_mutating_project_request`, these requests are disallowed for guests.
///
/// Returns a `RemoteUpgradeRequired` error, tagged with the required version, when the request
/// declares a minimum host version that is greater than the host connection's version. If the
/// host connection is not found in the pool, or the request declares no minimum version, the
/// version check is skipped.
async fn forward_versioned_mutating_project_request<T>(
    request: T,
    response: Response<T>,
    session: UserSession,
) -> Result<()>
where
    T: EntityMessage + RequestMessage + VersionedMessage,
{
    let project_id = ProjectId::from_proto(request.remote_entity_id());

    let host_connection_id = session
        .db()
        .await
        .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
        .await?;
    // The pool guard created in this `if let` scrutinee lives for the whole `if let` body;
    // nothing inside awaits.
    if let Some(host_version) = session
        .connection_pool()
        .await
        .connection(host_connection_id)
        .map(|c| c.zed_version)
    {
        if let Some(min_required_version) = request.required_host_version() {
            if min_required_version > host_version {
                return Err(anyhow!(ErrorCode::RemoteUpgradeRequired
                    .with_tag("required", &min_required_version.to_string())))?;
            }
        }
    }

    let payload = session
        .peer
        .forward_request(session.connection_id, host_connection_id, request)
        .await?;
    response.send(payload)?;
    Ok(())
}
2994
2995/// Notify other participants that a new buffer has been created
2996async fn create_buffer_for_peer(
2997 request: proto::CreateBufferForPeer,
2998 session: Session,
2999) -> Result<()> {
3000 session
3001 .db()
3002 .await
3003 .check_user_is_project_host(
3004 ProjectId::from_proto(request.project_id),
3005 session.connection_id,
3006 )
3007 .await?;
3008 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
3009 session
3010 .peer
3011 .forward_send(session.connection_id, peer_id.into(), request)?;
3012 Ok(())
3013}
3014
3015/// Notify other participants that a buffer has been updated. This is
3016/// allowed for guests as long as the update is limited to selections.
3017async fn update_buffer(
3018 request: proto::UpdateBuffer,
3019 response: Response<proto::UpdateBuffer>,
3020 session: Session,
3021) -> Result<()> {
3022 let project_id = ProjectId::from_proto(request.project_id);
3023 let mut capability = Capability::ReadOnly;
3024
3025 for op in request.operations.iter() {
3026 match op.variant {
3027 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
3028 Some(_) => capability = Capability::ReadWrite,
3029 }
3030 }
3031
3032 let host = {
3033 let guard = session
3034 .db()
3035 .await
3036 .connections_for_buffer_update(
3037 project_id,
3038 session.principal_id(),
3039 session.connection_id,
3040 capability,
3041 )
3042 .await?;
3043
3044 let (host, guests) = &*guard;
3045
3046 broadcast(
3047 Some(session.connection_id),
3048 guests.clone(),
3049 |connection_id| {
3050 session
3051 .peer
3052 .forward_send(session.connection_id, connection_id, request.clone())
3053 },
3054 );
3055
3056 *host
3057 };
3058
3059 if host != session.connection_id {
3060 session
3061 .peer
3062 .forward_request(session.connection_id, host, request.clone())
3063 .await?;
3064 }
3065
3066 response.send(proto::Ack {})?;
3067 Ok(())
3068}
3069
3070async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
3071 let project_id = ProjectId::from_proto(message.project_id);
3072
3073 let operation = message.operation.as_ref().context("invalid operation")?;
3074 let capability = match operation.variant.as_ref() {
3075 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
3076 if let Some(buffer_op) = buffer_op.operation.as_ref() {
3077 match buffer_op.variant {
3078 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
3079 Capability::ReadOnly
3080 }
3081 _ => Capability::ReadWrite,
3082 }
3083 } else {
3084 Capability::ReadWrite
3085 }
3086 }
3087 Some(_) => Capability::ReadWrite,
3088 None => Capability::ReadOnly,
3089 };
3090
3091 let guard = session
3092 .db()
3093 .await
3094 .connections_for_buffer_update(
3095 project_id,
3096 session.principal_id(),
3097 session.connection_id,
3098 capability,
3099 )
3100 .await?;
3101
3102 let (host, guests) = &*guard;
3103
3104 broadcast(
3105 Some(session.connection_id),
3106 guests.iter().chain([host]).copied(),
3107 |connection_id| {
3108 session
3109 .peer
3110 .forward_send(session.connection_id, connection_id, message.clone())
3111 },
3112 );
3113
3114 Ok(())
3115}
3116
3117/// Notify other participants that a project has been updated.
3118async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
3119 request: T,
3120 session: Session,
3121) -> Result<()> {
3122 let project_id = ProjectId::from_proto(request.remote_entity_id());
3123 let project_connection_ids = session
3124 .db()
3125 .await
3126 .project_connection_ids(project_id, session.connection_id, false)
3127 .await?;
3128
3129 broadcast(
3130 Some(session.connection_id),
3131 project_connection_ids.iter().copied(),
3132 |connection_id| {
3133 session
3134 .peer
3135 .forward_send(session.connection_id, connection_id, request.clone())
3136 },
3137 );
3138 Ok(())
3139}
3140
3141/// Start following another user in a call.
3142async fn follow(
3143 request: proto::Follow,
3144 response: Response<proto::Follow>,
3145 session: UserSession,
3146) -> Result<()> {
3147 let room_id = RoomId::from_proto(request.room_id);
3148 let project_id = request.project_id.map(ProjectId::from_proto);
3149 let leader_id = request
3150 .leader_id
3151 .ok_or_else(|| anyhow!("invalid leader id"))?
3152 .into();
3153 let follower_id = session.connection_id;
3154
3155 session
3156 .db()
3157 .await
3158 .check_room_participants(room_id, leader_id, session.connection_id)
3159 .await?;
3160
3161 let response_payload = session
3162 .peer
3163 .forward_request(session.connection_id, leader_id, request)
3164 .await?;
3165 response.send(response_payload)?;
3166
3167 if let Some(project_id) = project_id {
3168 let room = session
3169 .db()
3170 .await
3171 .follow(room_id, project_id, leader_id, follower_id)
3172 .await?;
3173 room_updated(&room, &session.peer);
3174 }
3175
3176 Ok(())
3177}
3178
3179/// Stop following another user in a call.
3180async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
3181 let room_id = RoomId::from_proto(request.room_id);
3182 let project_id = request.project_id.map(ProjectId::from_proto);
3183 let leader_id = request
3184 .leader_id
3185 .ok_or_else(|| anyhow!("invalid leader id"))?
3186 .into();
3187 let follower_id = session.connection_id;
3188
3189 session
3190 .db()
3191 .await
3192 .check_room_participants(room_id, leader_id, session.connection_id)
3193 .await?;
3194
3195 session
3196 .peer
3197 .forward_send(session.connection_id, leader_id, request)?;
3198
3199 if let Some(project_id) = project_id {
3200 let room = session
3201 .db()
3202 .await
3203 .unfollow(room_id, project_id, leader_id, follower_id)
3204 .await?;
3205 room_updated(&room, &session.peer);
3206 }
3207
3208 Ok(())
3209}
3210
3211/// Notify everyone following you of your current location.
3212async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
3213 let room_id = RoomId::from_proto(request.room_id);
3214 let database = session.db.lock().await;
3215
3216 let connection_ids = if let Some(project_id) = request.project_id {
3217 let project_id = ProjectId::from_proto(project_id);
3218 database
3219 .project_connection_ids(project_id, session.connection_id, true)
3220 .await?
3221 } else {
3222 database
3223 .room_connection_ids(room_id, session.connection_id)
3224 .await?
3225 };
3226
3227 // For now, don't send view update messages back to that view's current leader.
3228 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
3229 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
3230 _ => None,
3231 });
3232
3233 for connection_id in connection_ids.iter().cloned() {
3234 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
3235 session
3236 .peer
3237 .forward_send(session.connection_id, connection_id, request.clone())?;
3238 }
3239 }
3240 Ok(())
3241}
3242
3243/// Get public data about users.
3244async fn get_users(
3245 request: proto::GetUsers,
3246 response: Response<proto::GetUsers>,
3247 session: Session,
3248) -> Result<()> {
3249 let user_ids = request
3250 .user_ids
3251 .into_iter()
3252 .map(UserId::from_proto)
3253 .collect();
3254 let users = session
3255 .db()
3256 .await
3257 .get_users_by_ids(user_ids)
3258 .await?
3259 .into_iter()
3260 .map(|user| proto::User {
3261 id: user.id.to_proto(),
3262 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3263 github_login: user.github_login,
3264 })
3265 .collect();
3266 response.send(proto::UsersResponse { users })?;
3267 Ok(())
3268}
3269
3270/// Search for users (to invite) buy Github login
3271async fn fuzzy_search_users(
3272 request: proto::FuzzySearchUsers,
3273 response: Response<proto::FuzzySearchUsers>,
3274 session: UserSession,
3275) -> Result<()> {
3276 let query = request.query;
3277 let users = match query.len() {
3278 0 => vec![],
3279 1 | 2 => session
3280 .db()
3281 .await
3282 .get_user_by_github_login(&query)
3283 .await?
3284 .into_iter()
3285 .collect(),
3286 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
3287 };
3288 let users = users
3289 .into_iter()
3290 .filter(|user| user.id != session.user_id())
3291 .map(|user| proto::User {
3292 id: user.id.to_proto(),
3293 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3294 github_login: user.github_login,
3295 })
3296 .collect();
3297 response.send(proto::UsersResponse { users })?;
3298 Ok(())
3299}
3300
3301/// Send a contact request to another user.
3302async fn request_contact(
3303 request: proto::RequestContact,
3304 response: Response<proto::RequestContact>,
3305 session: UserSession,
3306) -> Result<()> {
3307 let requester_id = session.user_id();
3308 let responder_id = UserId::from_proto(request.responder_id);
3309 if requester_id == responder_id {
3310 return Err(anyhow!("cannot add yourself as a contact"))?;
3311 }
3312
3313 let notifications = session
3314 .db()
3315 .await
3316 .send_contact_request(requester_id, responder_id)
3317 .await?;
3318
3319 // Update outgoing contact requests of requester
3320 let mut update = proto::UpdateContacts::default();
3321 update.outgoing_requests.push(responder_id.to_proto());
3322 for connection_id in session
3323 .connection_pool()
3324 .await
3325 .user_connection_ids(requester_id)
3326 {
3327 session.peer.send(connection_id, update.clone())?;
3328 }
3329
3330 // Update incoming contact requests of responder
3331 let mut update = proto::UpdateContacts::default();
3332 update
3333 .incoming_requests
3334 .push(proto::IncomingContactRequest {
3335 requester_id: requester_id.to_proto(),
3336 });
3337 let connection_pool = session.connection_pool().await;
3338 for connection_id in connection_pool.user_connection_ids(responder_id) {
3339 session.peer.send(connection_id, update.clone())?;
3340 }
3341
3342 send_notifications(&connection_pool, &session.peer, notifications);
3343
3344 response.send(proto::Ack {})?;
3345 Ok(())
3346}
3347
3348/// Accept or decline a contact request
3349async fn respond_to_contact_request(
3350 request: proto::RespondToContactRequest,
3351 response: Response<proto::RespondToContactRequest>,
3352 session: UserSession,
3353) -> Result<()> {
3354 let responder_id = session.user_id();
3355 let requester_id = UserId::from_proto(request.requester_id);
3356 let db = session.db().await;
3357 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
3358 db.dismiss_contact_notification(responder_id, requester_id)
3359 .await?;
3360 } else {
3361 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
3362
3363 let notifications = db
3364 .respond_to_contact_request(responder_id, requester_id, accept)
3365 .await?;
3366 let requester_busy = db.is_user_busy(requester_id).await?;
3367 let responder_busy = db.is_user_busy(responder_id).await?;
3368
3369 let pool = session.connection_pool().await;
3370 // Update responder with new contact
3371 let mut update = proto::UpdateContacts::default();
3372 if accept {
3373 update
3374 .contacts
3375 .push(contact_for_user(requester_id, requester_busy, &pool));
3376 }
3377 update
3378 .remove_incoming_requests
3379 .push(requester_id.to_proto());
3380 for connection_id in pool.user_connection_ids(responder_id) {
3381 session.peer.send(connection_id, update.clone())?;
3382 }
3383
3384 // Update requester with new contact
3385 let mut update = proto::UpdateContacts::default();
3386 if accept {
3387 update
3388 .contacts
3389 .push(contact_for_user(responder_id, responder_busy, &pool));
3390 }
3391 update
3392 .remove_outgoing_requests
3393 .push(responder_id.to_proto());
3394
3395 for connection_id in pool.user_connection_ids(requester_id) {
3396 session.peer.send(connection_id, update.clone())?;
3397 }
3398
3399 send_notifications(&pool, &session.peer, notifications);
3400 }
3401
3402 response.send(proto::Ack {})?;
3403 Ok(())
3404}
3405
3406/// Remove a contact.
3407async fn remove_contact(
3408 request: proto::RemoveContact,
3409 response: Response<proto::RemoveContact>,
3410 session: UserSession,
3411) -> Result<()> {
3412 let requester_id = session.user_id();
3413 let responder_id = UserId::from_proto(request.user_id);
3414 let db = session.db().await;
3415 let (contact_accepted, deleted_notification_id) =
3416 db.remove_contact(requester_id, responder_id).await?;
3417
3418 let pool = session.connection_pool().await;
3419 // Update outgoing contact requests of requester
3420 let mut update = proto::UpdateContacts::default();
3421 if contact_accepted {
3422 update.remove_contacts.push(responder_id.to_proto());
3423 } else {
3424 update
3425 .remove_outgoing_requests
3426 .push(responder_id.to_proto());
3427 }
3428 for connection_id in pool.user_connection_ids(requester_id) {
3429 session.peer.send(connection_id, update.clone())?;
3430 }
3431
3432 // Update incoming contact requests of responder
3433 let mut update = proto::UpdateContacts::default();
3434 if contact_accepted {
3435 update.remove_contacts.push(requester_id.to_proto());
3436 } else {
3437 update
3438 .remove_incoming_requests
3439 .push(requester_id.to_proto());
3440 }
3441 for connection_id in pool.user_connection_ids(responder_id) {
3442 session.peer.send(connection_id, update.clone())?;
3443 if let Some(notification_id) = deleted_notification_id {
3444 session.peer.send(
3445 connection_id,
3446 proto::DeleteNotification {
3447 notification_id: notification_id.to_proto(),
3448 },
3449 )?;
3450 }
3451 }
3452
3453 response.send(proto::Ack {})?;
3454 Ok(())
3455}
3456
3457fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
3458 version.0.minor() < 139
3459}
3460
3461async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
3462 subscribe_user_to_channels(
3463 session.user_id().ok_or_else(|| anyhow!("must be a user"))?,
3464 &session,
3465 )
3466 .await?;
3467 Ok(())
3468}
3469
3470async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
3471 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3472 let mut pool = session.connection_pool().await;
3473 for membership in &channels_for_user.channel_memberships {
3474 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3475 }
3476 session.peer.send(
3477 session.connection_id,
3478 build_update_user_channels(&channels_for_user),
3479 )?;
3480 session.peer.send(
3481 session.connection_id,
3482 build_channels_update(channels_for_user),
3483 )?;
3484 Ok(())
3485}
3486
3487/// Creates a new channel.
3488async fn create_channel(
3489 request: proto::CreateChannel,
3490 response: Response<proto::CreateChannel>,
3491 session: UserSession,
3492) -> Result<()> {
3493 let db = session.db().await;
3494
3495 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
3496 let (channel, membership) = db
3497 .create_channel(&request.name, parent_id, session.user_id())
3498 .await?;
3499
3500 let root_id = channel.root_id();
3501 let channel = Channel::from_model(channel);
3502
3503 response.send(proto::CreateChannelResponse {
3504 channel: Some(channel.to_proto()),
3505 parent_id: request.parent_id,
3506 })?;
3507
3508 let mut connection_pool = session.connection_pool().await;
3509 if let Some(membership) = membership {
3510 connection_pool.subscribe_to_channel(
3511 membership.user_id,
3512 membership.channel_id,
3513 membership.role,
3514 );
3515 let update = proto::UpdateUserChannels {
3516 channel_memberships: vec![proto::ChannelMembership {
3517 channel_id: membership.channel_id.to_proto(),
3518 role: membership.role.into(),
3519 }],
3520 ..Default::default()
3521 };
3522 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3523 session.peer.send(connection_id, update.clone())?;
3524 }
3525 }
3526
3527 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3528 if !role.can_see_channel(channel.visibility) {
3529 continue;
3530 }
3531
3532 let update = proto::UpdateChannels {
3533 channels: vec![channel.to_proto()],
3534 ..Default::default()
3535 };
3536 session.peer.send(connection_id, update.clone())?;
3537 }
3538
3539 Ok(())
3540}
3541
3542/// Delete a channel
3543async fn delete_channel(
3544 request: proto::DeleteChannel,
3545 response: Response<proto::DeleteChannel>,
3546 session: UserSession,
3547) -> Result<()> {
3548 let db = session.db().await;
3549
3550 let channel_id = request.channel_id;
3551 let (root_channel, removed_channels) = db
3552 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3553 .await?;
3554 response.send(proto::Ack {})?;
3555
3556 // Notify members of removed channels
3557 let mut update = proto::UpdateChannels::default();
3558 update
3559 .delete_channels
3560 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3561
3562 let connection_pool = session.connection_pool().await;
3563 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3564 session.peer.send(connection_id, update.clone())?;
3565 }
3566
3567 Ok(())
3568}
3569
3570/// Invite someone to join a channel.
3571async fn invite_channel_member(
3572 request: proto::InviteChannelMember,
3573 response: Response<proto::InviteChannelMember>,
3574 session: UserSession,
3575) -> Result<()> {
3576 let db = session.db().await;
3577 let channel_id = ChannelId::from_proto(request.channel_id);
3578 let invitee_id = UserId::from_proto(request.user_id);
3579 let InviteMemberResult {
3580 channel,
3581 notifications,
3582 } = db
3583 .invite_channel_member(
3584 channel_id,
3585 invitee_id,
3586 session.user_id(),
3587 request.role().into(),
3588 )
3589 .await?;
3590
3591 let update = proto::UpdateChannels {
3592 channel_invitations: vec![channel.to_proto()],
3593 ..Default::default()
3594 };
3595
3596 let connection_pool = session.connection_pool().await;
3597 for connection_id in connection_pool.user_connection_ids(invitee_id) {
3598 session.peer.send(connection_id, update.clone())?;
3599 }
3600
3601 send_notifications(&connection_pool, &session.peer, notifications);
3602
3603 response.send(proto::Ack {})?;
3604 Ok(())
3605}
3606
3607/// remove someone from a channel
3608async fn remove_channel_member(
3609 request: proto::RemoveChannelMember,
3610 response: Response<proto::RemoveChannelMember>,
3611 session: UserSession,
3612) -> Result<()> {
3613 let db = session.db().await;
3614 let channel_id = ChannelId::from_proto(request.channel_id);
3615 let member_id = UserId::from_proto(request.user_id);
3616
3617 let RemoveChannelMemberResult {
3618 membership_update,
3619 notification_id,
3620 } = db
3621 .remove_channel_member(channel_id, member_id, session.user_id())
3622 .await?;
3623
3624 let mut connection_pool = session.connection_pool().await;
3625 notify_membership_updated(
3626 &mut connection_pool,
3627 membership_update,
3628 member_id,
3629 &session.peer,
3630 );
3631 for connection_id in connection_pool.user_connection_ids(member_id) {
3632 if let Some(notification_id) = notification_id {
3633 session
3634 .peer
3635 .send(
3636 connection_id,
3637 proto::DeleteNotification {
3638 notification_id: notification_id.to_proto(),
3639 },
3640 )
3641 .trace_err();
3642 }
3643 }
3644
3645 response.send(proto::Ack {})?;
3646 Ok(())
3647}
3648
3649/// Toggle the channel between public and private.
3650/// Care is taken to maintain the invariant that public channels only descend from public channels,
3651/// (though members-only channels can appear at any point in the hierarchy).
3652async fn set_channel_visibility(
3653 request: proto::SetChannelVisibility,
3654 response: Response<proto::SetChannelVisibility>,
3655 session: UserSession,
3656) -> Result<()> {
3657 let db = session.db().await;
3658 let channel_id = ChannelId::from_proto(request.channel_id);
3659 let visibility = request.visibility().into();
3660
3661 let channel_model = db
3662 .set_channel_visibility(channel_id, visibility, session.user_id())
3663 .await?;
3664 let root_id = channel_model.root_id();
3665 let channel = Channel::from_model(channel_model);
3666
3667 let mut connection_pool = session.connection_pool().await;
3668 for (user_id, role) in connection_pool
3669 .channel_user_ids(root_id)
3670 .collect::<Vec<_>>()
3671 .into_iter()
3672 {
3673 let update = if role.can_see_channel(channel.visibility) {
3674 connection_pool.subscribe_to_channel(user_id, channel_id, role);
3675 proto::UpdateChannels {
3676 channels: vec![channel.to_proto()],
3677 ..Default::default()
3678 }
3679 } else {
3680 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3681 proto::UpdateChannels {
3682 delete_channels: vec![channel.id.to_proto()],
3683 ..Default::default()
3684 }
3685 };
3686
3687 for connection_id in connection_pool.user_connection_ids(user_id) {
3688 session.peer.send(connection_id, update.clone())?;
3689 }
3690 }
3691
3692 response.send(proto::Ack {})?;
3693 Ok(())
3694}
3695
3696/// Alter the role for a user in the channel.
3697async fn set_channel_member_role(
3698 request: proto::SetChannelMemberRole,
3699 response: Response<proto::SetChannelMemberRole>,
3700 session: UserSession,
3701) -> Result<()> {
3702 let db = session.db().await;
3703 let channel_id = ChannelId::from_proto(request.channel_id);
3704 let member_id = UserId::from_proto(request.user_id);
3705 let result = db
3706 .set_channel_member_role(
3707 channel_id,
3708 session.user_id(),
3709 member_id,
3710 request.role().into(),
3711 )
3712 .await?;
3713
3714 match result {
3715 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3716 let mut connection_pool = session.connection_pool().await;
3717 notify_membership_updated(
3718 &mut connection_pool,
3719 membership_update,
3720 member_id,
3721 &session.peer,
3722 )
3723 }
3724 db::SetMemberRoleResult::InviteUpdated(channel) => {
3725 let update = proto::UpdateChannels {
3726 channel_invitations: vec![channel.to_proto()],
3727 ..Default::default()
3728 };
3729
3730 for connection_id in session
3731 .connection_pool()
3732 .await
3733 .user_connection_ids(member_id)
3734 {
3735 session.peer.send(connection_id, update.clone())?;
3736 }
3737 }
3738 }
3739
3740 response.send(proto::Ack {})?;
3741 Ok(())
3742}
3743
3744/// Change the name of a channel
3745async fn rename_channel(
3746 request: proto::RenameChannel,
3747 response: Response<proto::RenameChannel>,
3748 session: UserSession,
3749) -> Result<()> {
3750 let db = session.db().await;
3751 let channel_id = ChannelId::from_proto(request.channel_id);
3752 let channel_model = db
3753 .rename_channel(channel_id, session.user_id(), &request.name)
3754 .await?;
3755 let root_id = channel_model.root_id();
3756 let channel = Channel::from_model(channel_model);
3757
3758 response.send(proto::RenameChannelResponse {
3759 channel: Some(channel.to_proto()),
3760 })?;
3761
3762 let connection_pool = session.connection_pool().await;
3763 let update = proto::UpdateChannels {
3764 channels: vec![channel.to_proto()],
3765 ..Default::default()
3766 };
3767 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3768 if role.can_see_channel(channel.visibility) {
3769 session.peer.send(connection_id, update.clone())?;
3770 }
3771 }
3772
3773 Ok(())
3774}
3775
3776/// Move a channel to a new parent.
3777async fn move_channel(
3778 request: proto::MoveChannel,
3779 response: Response<proto::MoveChannel>,
3780 session: UserSession,
3781) -> Result<()> {
3782 let channel_id = ChannelId::from_proto(request.channel_id);
3783 let to = ChannelId::from_proto(request.to);
3784
3785 let (root_id, channels) = session
3786 .db()
3787 .await
3788 .move_channel(channel_id, to, session.user_id())
3789 .await?;
3790
3791 let connection_pool = session.connection_pool().await;
3792 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3793 let channels = channels
3794 .iter()
3795 .filter_map(|channel| {
3796 if role.can_see_channel(channel.visibility) {
3797 Some(channel.to_proto())
3798 } else {
3799 None
3800 }
3801 })
3802 .collect::<Vec<_>>();
3803 if channels.is_empty() {
3804 continue;
3805 }
3806
3807 let update = proto::UpdateChannels {
3808 channels,
3809 ..Default::default()
3810 };
3811
3812 session.peer.send(connection_id, update.clone())?;
3813 }
3814
3815 response.send(Ack {})?;
3816 Ok(())
3817}
3818
3819/// Get the list of channel members
3820async fn get_channel_members(
3821 request: proto::GetChannelMembers,
3822 response: Response<proto::GetChannelMembers>,
3823 session: UserSession,
3824) -> Result<()> {
3825 let db = session.db().await;
3826 let channel_id = ChannelId::from_proto(request.channel_id);
3827 let limit = if request.limit == 0 {
3828 u16::MAX as u64
3829 } else {
3830 request.limit
3831 };
3832 let (members, users) = db
3833 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3834 .await?;
3835 response.send(proto::GetChannelMembersResponse { members, users })?;
3836 Ok(())
3837}
3838
3839/// Accept or decline a channel invitation.
3840async fn respond_to_channel_invite(
3841 request: proto::RespondToChannelInvite,
3842 response: Response<proto::RespondToChannelInvite>,
3843 session: UserSession,
3844) -> Result<()> {
3845 let db = session.db().await;
3846 let channel_id = ChannelId::from_proto(request.channel_id);
3847 let RespondToChannelInvite {
3848 membership_update,
3849 notifications,
3850 } = db
3851 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3852 .await?;
3853
3854 let mut connection_pool = session.connection_pool().await;
3855 if let Some(membership_update) = membership_update {
3856 notify_membership_updated(
3857 &mut connection_pool,
3858 membership_update,
3859 session.user_id(),
3860 &session.peer,
3861 );
3862 } else {
3863 let update = proto::UpdateChannels {
3864 remove_channel_invitations: vec![channel_id.to_proto()],
3865 ..Default::default()
3866 };
3867
3868 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3869 session.peer.send(connection_id, update.clone())?;
3870 }
3871 };
3872
3873 send_notifications(&connection_pool, &session.peer, notifications);
3874
3875 response.send(proto::Ack {})?;
3876
3877 Ok(())
3878}
3879
3880/// Join the channels' room
3881async fn join_channel(
3882 request: proto::JoinChannel,
3883 response: Response<proto::JoinChannel>,
3884 session: UserSession,
3885) -> Result<()> {
3886 let channel_id = ChannelId::from_proto(request.channel_id);
3887 join_channel_internal(channel_id, Box::new(response), session).await
3888}
3889
/// A response handle that can be answered with a `proto::JoinRoomResponse`.
///
/// It is implemented for both `Response<proto::JoinChannel>` and
/// `Response<proto::JoinRoom>`, so `join_channel_internal` can reply to either
/// request type.
trait JoinChannelInternalResponse {
    /// Send `result` as the reply, consuming the handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// Delegates to `Response::<proto::JoinChannel>::send`.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
// Delegates to `Response::<proto::JoinRoom>::send`.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3903
/// Shared implementation for joining a channel's room, used with any
/// `JoinChannelInternalResponse`.
///
/// In order:
/// 1. If the database reports a stale room connection for this user, leave that
///    room first (releasing and then re-acquiring the database guard around it).
/// 2. Join the channel in the database. This yields the joined room, an optional
///    membership change, and the user's role.
/// 3. If a LiveKit client is configured, mint a token:
///    - guests get a `guest_token` with `can_publish = false`;
///    - every other role gets a `room_token` with `can_publish = true`.
///    A token error passes through `trace_err` and results in no LiveKit info.
/// 4. Reply with the room, its channel id, and the LiveKit info. Then propagate
///    any membership change and broadcast the room via `room_updated`.
/// 5. Once the database guard has been released (end of the inner block),
///    broadcast the channel via `channel_updated` and call `update_user_contacts`.
///
/// Returns an error if the joined room carries no channel. Note that by that
/// point the join response has already been sent.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: UserSession,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database guard while leaving the stale room, then re-acquire it.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // Guests receive a guest token and may not publish; all other roles may publish.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        // Leaving this block drops both the database guard and the connection pool guard.
        joined_room
    };

    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3994
3995/// Start editing the channel notes
3996async fn join_channel_buffer(
3997 request: proto::JoinChannelBuffer,
3998 response: Response<proto::JoinChannelBuffer>,
3999 session: UserSession,
4000) -> Result<()> {
4001 let db = session.db().await;
4002 let channel_id = ChannelId::from_proto(request.channel_id);
4003
4004 let open_response = db
4005 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
4006 .await?;
4007
4008 let collaborators = open_response.collaborators.clone();
4009 response.send(open_response)?;
4010
4011 let update = UpdateChannelBufferCollaborators {
4012 channel_id: channel_id.to_proto(),
4013 collaborators: collaborators.clone(),
4014 };
4015 channel_buffer_updated(
4016 session.connection_id,
4017 collaborators
4018 .iter()
4019 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
4020 &update,
4021 &session.peer,
4022 );
4023
4024 Ok(())
4025}
4026
4027/// Edit the channel notes
4028async fn update_channel_buffer(
4029 request: proto::UpdateChannelBuffer,
4030 session: UserSession,
4031) -> Result<()> {
4032 let db = session.db().await;
4033 let channel_id = ChannelId::from_proto(request.channel_id);
4034
4035 let (collaborators, epoch, version) = db
4036 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
4037 .await?;
4038
4039 channel_buffer_updated(
4040 session.connection_id,
4041 collaborators.clone(),
4042 &proto::UpdateChannelBuffer {
4043 channel_id: channel_id.to_proto(),
4044 operations: request.operations,
4045 },
4046 &session.peer,
4047 );
4048
4049 let pool = &*session.connection_pool().await;
4050
4051 let non_collaborators =
4052 pool.channel_connection_ids(channel_id)
4053 .filter_map(|(connection_id, _)| {
4054 if collaborators.contains(&connection_id) {
4055 None
4056 } else {
4057 Some(connection_id)
4058 }
4059 });
4060
4061 broadcast(None, non_collaborators, |peer_id| {
4062 session.peer.send(
4063 peer_id,
4064 proto::UpdateChannels {
4065 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
4066 channel_id: channel_id.to_proto(),
4067 epoch: epoch as u64,
4068 version: version.clone(),
4069 }],
4070 ..Default::default()
4071 },
4072 )
4073 });
4074
4075 Ok(())
4076}
4077
4078/// Rejoin the channel notes after a connection blip
4079async fn rejoin_channel_buffers(
4080 request: proto::RejoinChannelBuffers,
4081 response: Response<proto::RejoinChannelBuffers>,
4082 session: UserSession,
4083) -> Result<()> {
4084 let db = session.db().await;
4085 let buffers = db
4086 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
4087 .await?;
4088
4089 for rejoined_buffer in &buffers {
4090 let collaborators_to_notify = rejoined_buffer
4091 .buffer
4092 .collaborators
4093 .iter()
4094 .filter_map(|c| Some(c.peer_id?.into()));
4095 channel_buffer_updated(
4096 session.connection_id,
4097 collaborators_to_notify,
4098 &proto::UpdateChannelBufferCollaborators {
4099 channel_id: rejoined_buffer.buffer.channel_id,
4100 collaborators: rejoined_buffer.buffer.collaborators.clone(),
4101 },
4102 &session.peer,
4103 );
4104 }
4105
4106 response.send(proto::RejoinChannelBuffersResponse {
4107 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
4108 })?;
4109
4110 Ok(())
4111}
4112
4113/// Stop editing the channel notes
4114async fn leave_channel_buffer(
4115 request: proto::LeaveChannelBuffer,
4116 response: Response<proto::LeaveChannelBuffer>,
4117 session: UserSession,
4118) -> Result<()> {
4119 let db = session.db().await;
4120 let channel_id = ChannelId::from_proto(request.channel_id);
4121
4122 let left_buffer = db
4123 .leave_channel_buffer(channel_id, session.connection_id)
4124 .await?;
4125
4126 response.send(Ack {})?;
4127
4128 channel_buffer_updated(
4129 session.connection_id,
4130 left_buffer.connections,
4131 &proto::UpdateChannelBufferCollaborators {
4132 channel_id: channel_id.to_proto(),
4133 collaborators: left_buffer.collaborators,
4134 },
4135 &session.peer,
4136 );
4137
4138 Ok(())
4139}
4140
4141fn channel_buffer_updated<T: EnvelopedMessage>(
4142 sender_id: ConnectionId,
4143 collaborators: impl IntoIterator<Item = ConnectionId>,
4144 message: &T,
4145 peer: &Peer,
4146) {
4147 broadcast(Some(sender_id), collaborators, |peer_id| {
4148 peer.send(peer_id, message.clone())
4149 });
4150}
4151
4152fn send_notifications(
4153 connection_pool: &ConnectionPool,
4154 peer: &Peer,
4155 notifications: db::NotificationBatch,
4156) {
4157 for (user_id, notification) in notifications {
4158 for connection_id in connection_pool.user_connection_ids(user_id) {
4159 if let Err(error) = peer.send(
4160 connection_id,
4161 proto::AddNotification {
4162 notification: Some(notification.clone()),
4163 },
4164 ) {
4165 tracing::error!(
4166 "failed to send notification to {:?} {}",
4167 connection_id,
4168 error
4169 );
4170 }
4171 }
4172 }
4173}
4174
4175/// Send a message to the channel
4176async fn send_channel_message(
4177 request: proto::SendChannelMessage,
4178 response: Response<proto::SendChannelMessage>,
4179 session: UserSession,
4180) -> Result<()> {
4181 // Validate the message body.
4182 let body = request.body.trim().to_string();
4183 if body.len() > MAX_MESSAGE_LEN {
4184 return Err(anyhow!("message is too long"))?;
4185 }
4186 if body.is_empty() {
4187 return Err(anyhow!("message can't be blank"))?;
4188 }
4189
4190 // TODO: adjust mentions if body is trimmed
4191
4192 let timestamp = OffsetDateTime::now_utc();
4193 let nonce = request
4194 .nonce
4195 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4196
4197 let channel_id = ChannelId::from_proto(request.channel_id);
4198 let CreatedChannelMessage {
4199 message_id,
4200 participant_connection_ids,
4201 notifications,
4202 } = session
4203 .db()
4204 .await
4205 .create_channel_message(
4206 channel_id,
4207 session.user_id(),
4208 &body,
4209 &request.mentions,
4210 timestamp,
4211 nonce.clone().into(),
4212 match request.reply_to_message_id {
4213 Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
4214 None => None,
4215 },
4216 )
4217 .await?;
4218
4219 let message = proto::ChannelMessage {
4220 sender_id: session.user_id().to_proto(),
4221 id: message_id.to_proto(),
4222 body,
4223 mentions: request.mentions,
4224 timestamp: timestamp.unix_timestamp() as u64,
4225 nonce: Some(nonce),
4226 reply_to_message_id: request.reply_to_message_id,
4227 edited_at: None,
4228 };
4229 broadcast(
4230 Some(session.connection_id),
4231 participant_connection_ids.clone(),
4232 |connection| {
4233 session.peer.send(
4234 connection,
4235 proto::ChannelMessageSent {
4236 channel_id: channel_id.to_proto(),
4237 message: Some(message.clone()),
4238 },
4239 )
4240 },
4241 );
4242 response.send(proto::SendChannelMessageResponse {
4243 message: Some(message),
4244 })?;
4245
4246 let pool = &*session.connection_pool().await;
4247 let non_participants =
4248 pool.channel_connection_ids(channel_id)
4249 .filter_map(|(connection_id, _)| {
4250 if participant_connection_ids.contains(&connection_id) {
4251 None
4252 } else {
4253 Some(connection_id)
4254 }
4255 });
4256 broadcast(None, non_participants, |peer_id| {
4257 session.peer.send(
4258 peer_id,
4259 proto::UpdateChannels {
4260 latest_channel_message_ids: vec![proto::ChannelMessageId {
4261 channel_id: channel_id.to_proto(),
4262 message_id: message_id.to_proto(),
4263 }],
4264 ..Default::default()
4265 },
4266 )
4267 });
4268 send_notifications(pool, &session.peer, notifications);
4269
4270 Ok(())
4271}
4272
4273/// Delete a channel message
4274async fn remove_channel_message(
4275 request: proto::RemoveChannelMessage,
4276 response: Response<proto::RemoveChannelMessage>,
4277 session: UserSession,
4278) -> Result<()> {
4279 let channel_id = ChannelId::from_proto(request.channel_id);
4280 let message_id = MessageId::from_proto(request.message_id);
4281 let (connection_ids, existing_notification_ids) = session
4282 .db()
4283 .await
4284 .remove_channel_message(channel_id, message_id, session.user_id())
4285 .await?;
4286
4287 broadcast(
4288 Some(session.connection_id),
4289 connection_ids,
4290 move |connection| {
4291 session.peer.send(connection, request.clone())?;
4292
4293 for notification_id in &existing_notification_ids {
4294 session.peer.send(
4295 connection,
4296 proto::DeleteNotification {
4297 notification_id: (*notification_id).to_proto(),
4298 },
4299 )?;
4300 }
4301
4302 Ok(())
4303 },
4304 );
4305 response.send(proto::Ack {})?;
4306 Ok(())
4307}
4308
4309async fn update_channel_message(
4310 request: proto::UpdateChannelMessage,
4311 response: Response<proto::UpdateChannelMessage>,
4312 session: UserSession,
4313) -> Result<()> {
4314 let channel_id = ChannelId::from_proto(request.channel_id);
4315 let message_id = MessageId::from_proto(request.message_id);
4316 let updated_at = OffsetDateTime::now_utc();
4317 let UpdatedChannelMessage {
4318 message_id,
4319 participant_connection_ids,
4320 notifications,
4321 reply_to_message_id,
4322 timestamp,
4323 deleted_mention_notification_ids,
4324 updated_mention_notifications,
4325 } = session
4326 .db()
4327 .await
4328 .update_channel_message(
4329 channel_id,
4330 message_id,
4331 session.user_id(),
4332 request.body.as_str(),
4333 &request.mentions,
4334 updated_at,
4335 )
4336 .await?;
4337
4338 let nonce = request
4339 .nonce
4340 .clone()
4341 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4342
4343 let message = proto::ChannelMessage {
4344 sender_id: session.user_id().to_proto(),
4345 id: message_id.to_proto(),
4346 body: request.body.clone(),
4347 mentions: request.mentions.clone(),
4348 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
4349 nonce: Some(nonce),
4350 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
4351 edited_at: Some(updated_at.unix_timestamp() as u64),
4352 };
4353
4354 response.send(proto::Ack {})?;
4355
4356 let pool = &*session.connection_pool().await;
4357 broadcast(
4358 Some(session.connection_id),
4359 participant_connection_ids,
4360 |connection| {
4361 session.peer.send(
4362 connection,
4363 proto::ChannelMessageUpdate {
4364 channel_id: channel_id.to_proto(),
4365 message: Some(message.clone()),
4366 },
4367 )?;
4368
4369 for notification_id in &deleted_mention_notification_ids {
4370 session.peer.send(
4371 connection,
4372 proto::DeleteNotification {
4373 notification_id: (*notification_id).to_proto(),
4374 },
4375 )?;
4376 }
4377
4378 for notification in &updated_mention_notifications {
4379 session.peer.send(
4380 connection,
4381 proto::UpdateNotification {
4382 notification: Some(notification.clone()),
4383 },
4384 )?;
4385 }
4386
4387 Ok(())
4388 },
4389 );
4390
4391 send_notifications(pool, &session.peer, notifications);
4392
4393 Ok(())
4394}
4395
4396/// Mark a channel message as read
4397async fn acknowledge_channel_message(
4398 request: proto::AckChannelMessage,
4399 session: UserSession,
4400) -> Result<()> {
4401 let channel_id = ChannelId::from_proto(request.channel_id);
4402 let message_id = MessageId::from_proto(request.message_id);
4403 let notifications = session
4404 .db()
4405 .await
4406 .observe_channel_message(channel_id, session.user_id(), message_id)
4407 .await?;
4408 send_notifications(
4409 &*session.connection_pool().await,
4410 &session.peer,
4411 notifications,
4412 );
4413 Ok(())
4414}
4415
4416/// Mark a buffer version as synced
4417async fn acknowledge_buffer_version(
4418 request: proto::AckBufferOperation,
4419 session: UserSession,
4420) -> Result<()> {
4421 let buffer_id = BufferId::from_proto(request.buffer_id);
4422 session
4423 .db()
4424 .await
4425 .observe_buffer_version(
4426 buffer_id,
4427 session.user_id(),
4428 request.epoch as i32,
4429 &request.version,
4430 )
4431 .await?;
4432 Ok(())
4433}
4434
4435struct CompleteWithLanguageModelRateLimit;
4436
4437impl RateLimit for CompleteWithLanguageModelRateLimit {
4438 fn capacity() -> usize {
4439 std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4440 .ok()
4441 .and_then(|v| v.parse().ok())
4442 .unwrap_or(120) // Picked arbitrarily
4443 }
4444
4445 fn refill_duration() -> chrono::Duration {
4446 chrono::Duration::hours(1)
4447 }
4448
4449 fn db_name() -> &'static str {
4450 "complete-with-language-model"
4451 }
4452}
4453
4454async fn complete_with_language_model(
4455 request: proto::CompleteWithLanguageModel,
4456 response: StreamingResponse<proto::CompleteWithLanguageModel>,
4457 session: Session,
4458 open_ai_api_key: Option<Arc<str>>,
4459 google_ai_api_key: Option<Arc<str>>,
4460 anthropic_api_key: Option<Arc<str>>,
4461) -> Result<()> {
4462 let Some(session) = session.for_user() else {
4463 return Err(anyhow!("user not found"))?;
4464 };
4465 authorize_access_to_language_models(&session).await?;
4466 session
4467 .rate_limiter
4468 .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
4469 .await?;
4470
4471 if request.model.starts_with("gpt") {
4472 let api_key =
4473 open_ai_api_key.ok_or_else(|| anyhow!("no OpenAI API key configured on the server"))?;
4474 complete_with_open_ai(request, response, session, api_key).await?;
4475 } else if request.model.starts_with("gemini") {
4476 let api_key = google_ai_api_key
4477 .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
4478 complete_with_google_ai(request, response, session, api_key).await?;
4479 } else if request.model.starts_with("claude") {
4480 let api_key = anthropic_api_key
4481 .ok_or_else(|| anyhow!("no Anthropic AI API key configured on the server"))?;
4482 complete_with_anthropic(request, response, session, api_key).await?;
4483 }
4484
4485 Ok(())
4486}
4487
4488async fn complete_with_open_ai(
4489 request: proto::CompleteWithLanguageModel,
4490 response: StreamingResponse<proto::CompleteWithLanguageModel>,
4491 session: UserSession,
4492 api_key: Arc<str>,
4493) -> Result<()> {
4494 let mut completion_stream = open_ai::stream_completion(
4495 session.http_client.as_ref(),
4496 OPEN_AI_API_URL,
4497 &api_key,
4498 crate::ai::language_model_request_to_open_ai(request)?,
4499 None,
4500 )
4501 .await
4502 .context("open_ai::stream_completion request failed within collab")?;
4503
4504 while let Some(event) = completion_stream.next().await {
4505 let event = event?;
4506 response.send(proto::LanguageModelResponse {
4507 choices: event
4508 .choices
4509 .into_iter()
4510 .map(|choice| proto::LanguageModelChoiceDelta {
4511 index: choice.index,
4512 delta: Some(proto::LanguageModelResponseMessage {
4513 role: choice.delta.role.map(|role| match role {
4514 open_ai::Role::User => LanguageModelRole::LanguageModelUser,
4515 open_ai::Role::Assistant => LanguageModelRole::LanguageModelAssistant,
4516 open_ai::Role::System => LanguageModelRole::LanguageModelSystem,
4517 open_ai::Role::Tool => LanguageModelRole::LanguageModelTool,
4518 } as i32),
4519 content: choice.delta.content,
4520 tool_calls: choice
4521 .delta
4522 .tool_calls
4523 .unwrap_or_default()
4524 .into_iter()
4525 .map(|delta| proto::ToolCallDelta {
4526 index: delta.index as u32,
4527 id: delta.id,
4528 variant: match delta.function {
4529 Some(function) => {
4530 let name = function.name;
4531 let arguments = function.arguments;
4532
4533 Some(proto::tool_call_delta::Variant::Function(
4534 proto::tool_call_delta::FunctionCallDelta {
4535 name,
4536 arguments,
4537 },
4538 ))
4539 }
4540 None => None,
4541 },
4542 })
4543 .collect(),
4544 }),
4545 finish_reason: choice.finish_reason,
4546 })
4547 .collect(),
4548 })?;
4549 }
4550
4551 Ok(())
4552}
4553
4554async fn complete_with_google_ai(
4555 request: proto::CompleteWithLanguageModel,
4556 response: StreamingResponse<proto::CompleteWithLanguageModel>,
4557 session: UserSession,
4558 api_key: Arc<str>,
4559) -> Result<()> {
4560 let mut stream = google_ai::stream_generate_content(
4561 session.http_client.clone(),
4562 google_ai::API_URL,
4563 api_key.as_ref(),
4564 &request.model.clone(),
4565 crate::ai::language_model_request_to_google_ai(request)?,
4566 )
4567 .await
4568 .context("google_ai::stream_generate_content request failed")?;
4569
4570 while let Some(event) = stream.next().await {
4571 let event = event?;
4572 response.send(proto::LanguageModelResponse {
4573 choices: event
4574 .candidates
4575 .unwrap_or_default()
4576 .into_iter()
4577 .map(|candidate| proto::LanguageModelChoiceDelta {
4578 index: candidate.index as u32,
4579 delta: Some(proto::LanguageModelResponseMessage {
4580 role: Some(match candidate.content.role {
4581 google_ai::Role::User => LanguageModelRole::LanguageModelUser,
4582 google_ai::Role::Model => LanguageModelRole::LanguageModelAssistant,
4583 } as i32),
4584 content: Some(
4585 candidate
4586 .content
4587 .parts
4588 .into_iter()
4589 .filter_map(|part| match part {
4590 google_ai::Part::TextPart(part) => Some(part.text),
4591 google_ai::Part::InlineDataPart(_) => None,
4592 })
4593 .collect(),
4594 ),
4595 // Tool calls are not supported for Google
4596 tool_calls: Vec::new(),
4597 }),
4598 finish_reason: candidate.finish_reason.map(|reason| reason.to_string()),
4599 })
4600 .collect(),
4601 })?;
4602 }
4603
4604 Ok(())
4605}
4606
/// Stream a chat completion from Anthropic back to the client.
///
/// `request.model` is resolved with `anthropic::Model::from_id`, and an
/// unknown id is an error. Messages are handled by role:
/// - User and assistant messages are forwarded.
/// - System messages are joined with blank lines into Anthropic's separate
///   `system` field.
/// - Tool messages are dropped.
///
/// Streamed text is sent as single-choice (`index: 0`) responses tagged with
/// `current_role`. A `MessageDelta` that carries a stop reason is forwarded as
/// `finish_reason`.
async fn complete_with_anthropic(
    request: proto::CompleteWithLanguageModel,
    response: StreamingResponse<proto::CompleteWithLanguageModel>,
    session: UserSession,
    api_key: Arc<str>,
) -> Result<()> {
    let model = anthropic::Model::from_id(&request.model)?;

    // Accumulates every system message; filled in by the `filter_map` below.
    let mut system_message = String::new();
    let messages = request
        .messages
        .into_iter()
        .filter_map(|message| {
            match message.role() {
                LanguageModelRole::LanguageModelUser => Some(anthropic::RequestMessage {
                    role: anthropic::Role::User,
                    content: message.content,
                }),
                LanguageModelRole::LanguageModelAssistant => Some(anthropic::RequestMessage {
                    role: anthropic::Role::Assistant,
                    content: message.content,
                }),
                // Anthropic's API breaks system instructions out as a separate field rather
                // than having a system message role.
                LanguageModelRole::LanguageModelSystem => {
                    if !system_message.is_empty() {
                        system_message.push_str("\n\n");
                    }
                    system_message.push_str(&message.content);

                    None
                }
                // We don't yet support tool calls for Anthropic
                LanguageModelRole::LanguageModelTool => None,
            }
        })
        .collect();

    let mut stream = anthropic::stream_completion(
        session.http_client.as_ref(),
        anthropic::ANTHROPIC_API_URL,
        &api_key,
        anthropic::Request {
            model,
            messages,
            stream: true,
            system: system_message,
            // NOTE(review): 4092 may be a typo for 4096; confirm the intended
            // output-token cap.
            max_tokens: 4092,
        },
        None,
    )
    .await?;

    // Role attached to every forwarded text delta. It defaults to assistant
    // and is updated by `MessageStart` events.
    let mut current_role = proto::LanguageModelRole::LanguageModelAssistant;

    while let Some(event) = stream.next().await {
        let event = event?;

        match event {
            anthropic::ResponseEvent::MessageStart { message } => {
                // Only "assistant" and "user" are recognized; any other role
                // leaves `current_role` unchanged.
                if let Some(role) = message.role {
                    if role == "assistant" {
                        current_role = proto::LanguageModelRole::LanguageModelAssistant;
                    } else if role == "user" {
                        current_role = proto::LanguageModelRole::LanguageModelUser;
                    }
                }
            }
            anthropic::ResponseEvent::ContentBlockStart { content_block, .. } => {
                match content_block {
                    anthropic::ContentBlock::Text { text } => {
                        // Skip content blocks that start with empty text.
                        if !text.is_empty() {
                            response.send(proto::LanguageModelResponse {
                                choices: vec![proto::LanguageModelChoiceDelta {
                                    index: 0,
                                    delta: Some(proto::LanguageModelResponseMessage {
                                        role: Some(current_role as i32),
                                        content: Some(text),
                                        tool_calls: Vec::new(),
                                    }),
                                    finish_reason: None,
                                }],
                            })?;
                        }
                    }
                }
            }
            anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => match delta {
                anthropic::TextDelta::TextDelta { text } => {
                    response.send(proto::LanguageModelResponse {
                        choices: vec![proto::LanguageModelChoiceDelta {
                            index: 0,
                            delta: Some(proto::LanguageModelResponseMessage {
                                role: Some(current_role as i32),
                                content: Some(text),
                                tool_calls: Vec::new(),
                            }),
                            finish_reason: None,
                        }],
                    })?;
                }
            },
            anthropic::ResponseEvent::MessageDelta { delta, .. } => {
                // A stop reason is forwarded as a content-less delta with `finish_reason` set.
                if let Some(stop_reason) = delta.stop_reason {
                    response.send(proto::LanguageModelResponse {
                        choices: vec![proto::LanguageModelChoiceDelta {
                            index: 0,
                            delta: None,
                            finish_reason: Some(stop_reason),
                        }],
                    })?;
                }
            }
            // These events carry nothing that is forwarded to the client.
            anthropic::ResponseEvent::ContentBlockStop { .. } => {}
            anthropic::ResponseEvent::MessageStop {} => {}
            anthropic::ResponseEvent::Ping {} => {}
        }
    }

    Ok(())
}
4728
4729struct CountTokensWithLanguageModelRateLimit;
4730
4731impl RateLimit for CountTokensWithLanguageModelRateLimit {
4732 fn capacity() -> usize {
4733 std::env::var("COUNT_TOKENS_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4734 .ok()
4735 .and_then(|v| v.parse().ok())
4736 .unwrap_or(600) // Picked arbitrarily
4737 }
4738
4739 fn refill_duration() -> chrono::Duration {
4740 chrono::Duration::hours(1)
4741 }
4742
4743 fn db_name() -> &'static str {
4744 "count-tokens-with-language-model"
4745 }
4746}
4747
4748async fn count_tokens_with_language_model(
4749 request: proto::CountTokensWithLanguageModel,
4750 response: Response<proto::CountTokensWithLanguageModel>,
4751 session: UserSession,
4752 google_ai_api_key: Option<Arc<str>>,
4753) -> Result<()> {
4754 authorize_access_to_language_models(&session).await?;
4755
4756 if !request.model.starts_with("gemini") {
4757 return Err(anyhow!(
4758 "counting tokens for model: {:?} is not supported",
4759 request.model
4760 ))?;
4761 }
4762
4763 session
4764 .rate_limiter
4765 .check::<CountTokensWithLanguageModelRateLimit>(session.user_id())
4766 .await?;
4767
4768 let api_key = google_ai_api_key
4769 .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
4770 let tokens_response = google_ai::count_tokens(
4771 session.http_client.as_ref(),
4772 google_ai::API_URL,
4773 &api_key,
4774 crate::ai::count_tokens_request_to_google_ai(request)?,
4775 )
4776 .await?;
4777 response.send(proto::CountTokensResponse {
4778 token_count: tokens_response.total_tokens as u32,
4779 })?;
4780 Ok(())
4781}
4782
4783struct ComputeEmbeddingsRateLimit;
4784
4785impl RateLimit for ComputeEmbeddingsRateLimit {
4786 fn capacity() -> usize {
4787 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4788 .ok()
4789 .and_then(|v| v.parse().ok())
4790 .unwrap_or(5000) // Picked arbitrarily
4791 }
4792
4793 fn refill_duration() -> chrono::Duration {
4794 chrono::Duration::hours(1)
4795 }
4796
4797 fn db_name() -> &'static str {
4798 "compute-embeddings"
4799 }
4800}
4801
4802async fn compute_embeddings(
4803 request: proto::ComputeEmbeddings,
4804 response: Response<proto::ComputeEmbeddings>,
4805 session: UserSession,
4806 api_key: Option<Arc<str>>,
4807) -> Result<()> {
4808 let api_key = api_key.context("no OpenAI API key configured on the server")?;
4809 authorize_access_to_language_models(&session).await?;
4810
4811 session
4812 .rate_limiter
4813 .check::<ComputeEmbeddingsRateLimit>(session.user_id())
4814 .await?;
4815
4816 let embeddings = match request.model.as_str() {
4817 "openai/text-embedding-3-small" => {
4818 open_ai::embed(
4819 session.http_client.as_ref(),
4820 OPEN_AI_API_URL,
4821 &api_key,
4822 OpenAiEmbeddingModel::TextEmbedding3Small,
4823 request.texts.iter().map(|text| text.as_str()),
4824 )
4825 .await?
4826 }
4827 provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
4828 };
4829
4830 let embeddings = request
4831 .texts
4832 .iter()
4833 .map(|text| {
4834 let mut hasher = sha2::Sha256::new();
4835 hasher.update(text.as_bytes());
4836 let result = hasher.finalize();
4837 result.to_vec()
4838 })
4839 .zip(
4840 embeddings
4841 .data
4842 .into_iter()
4843 .map(|embedding| embedding.embedding),
4844 )
4845 .collect::<HashMap<_, _>>();
4846
4847 let db = session.db().await;
4848 db.save_embeddings(&request.model, &embeddings)
4849 .await
4850 .context("failed to save embeddings")
4851 .trace_err();
4852
4853 response.send(proto::ComputeEmbeddingsResponse {
4854 embeddings: embeddings
4855 .into_iter()
4856 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4857 .collect(),
4858 })?;
4859 Ok(())
4860}
4861
4862async fn get_cached_embeddings(
4863 request: proto::GetCachedEmbeddings,
4864 response: Response<proto::GetCachedEmbeddings>,
4865 session: UserSession,
4866) -> Result<()> {
4867 authorize_access_to_language_models(&session).await?;
4868
4869 let db = session.db().await;
4870 let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
4871
4872 response.send(proto::GetCachedEmbeddingsResponse {
4873 embeddings: embeddings
4874 .into_iter()
4875 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4876 .collect(),
4877 })?;
4878 Ok(())
4879}
4880
4881async fn authorize_access_to_language_models(session: &UserSession) -> Result<(), Error> {
4882 let db = session.db().await;
4883 let flags = db.get_user_flags(session.user_id()).await?;
4884 if flags.iter().any(|flag| flag == "language-models") {
4885 Ok(())
4886 } else {
4887 Err(anyhow!("permission denied"))?
4888 }
4889}
4890
4891/// Get a Supermaven API key for the user
4892async fn get_supermaven_api_key(
4893 _request: proto::GetSupermavenApiKey,
4894 response: Response<proto::GetSupermavenApiKey>,
4895 session: UserSession,
4896) -> Result<()> {
4897 let user_id: String = session.user_id().to_string();
4898 if !session.is_staff() {
4899 return Err(anyhow!("supermaven not enabled for this account"))?;
4900 }
4901
4902 let email = session
4903 .email()
4904 .ok_or_else(|| anyhow!("user must have an email"))?;
4905
4906 let supermaven_admin_api = session
4907 .supermaven_client
4908 .as_ref()
4909 .ok_or_else(|| anyhow!("supermaven not configured"))?;
4910
4911 let result = supermaven_admin_api
4912 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4913 .await?;
4914
4915 response.send(proto::GetSupermavenApiKeyResponse {
4916 api_key: result.api_key,
4917 })?;
4918
4919 Ok(())
4920}
4921
4922/// Start receiving chat updates for a channel
4923async fn join_channel_chat(
4924 request: proto::JoinChannelChat,
4925 response: Response<proto::JoinChannelChat>,
4926 session: UserSession,
4927) -> Result<()> {
4928 let channel_id = ChannelId::from_proto(request.channel_id);
4929
4930 let db = session.db().await;
4931 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4932 .await?;
4933 let messages = db
4934 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4935 .await?;
4936 response.send(proto::JoinChannelChatResponse {
4937 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4938 messages,
4939 })?;
4940 Ok(())
4941}
4942
4943/// Stop receiving chat updates for a channel
4944async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
4945 let channel_id = ChannelId::from_proto(request.channel_id);
4946 session
4947 .db()
4948 .await
4949 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4950 .await?;
4951 Ok(())
4952}
4953
4954/// Retrieve the chat history for a channel
4955async fn get_channel_messages(
4956 request: proto::GetChannelMessages,
4957 response: Response<proto::GetChannelMessages>,
4958 session: UserSession,
4959) -> Result<()> {
4960 let channel_id = ChannelId::from_proto(request.channel_id);
4961 let messages = session
4962 .db()
4963 .await
4964 .get_channel_messages(
4965 channel_id,
4966 session.user_id(),
4967 MESSAGE_COUNT_PER_PAGE,
4968 Some(MessageId::from_proto(request.before_message_id)),
4969 )
4970 .await?;
4971 response.send(proto::GetChannelMessagesResponse {
4972 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4973 messages,
4974 })?;
4975 Ok(())
4976}
4977
4978/// Retrieve specific chat messages
4979async fn get_channel_messages_by_id(
4980 request: proto::GetChannelMessagesById,
4981 response: Response<proto::GetChannelMessagesById>,
4982 session: UserSession,
4983) -> Result<()> {
4984 let message_ids = request
4985 .message_ids
4986 .iter()
4987 .map(|id| MessageId::from_proto(*id))
4988 .collect::<Vec<_>>();
4989 let messages = session
4990 .db()
4991 .await
4992 .get_channel_messages_by_id(session.user_id(), &message_ids)
4993 .await?;
4994 response.send(proto::GetChannelMessagesResponse {
4995 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4996 messages,
4997 })?;
4998 Ok(())
4999}
5000
5001/// Retrieve the current users notifications
5002async fn get_notifications(
5003 request: proto::GetNotifications,
5004 response: Response<proto::GetNotifications>,
5005 session: UserSession,
5006) -> Result<()> {
5007 let notifications = session
5008 .db()
5009 .await
5010 .get_notifications(
5011 session.user_id(),
5012 NOTIFICATION_COUNT_PER_PAGE,
5013 request
5014 .before_id
5015 .map(|id| db::NotificationId::from_proto(id)),
5016 )
5017 .await?;
5018 response.send(proto::GetNotificationsResponse {
5019 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
5020 notifications,
5021 })?;
5022 Ok(())
5023}
5024
5025/// Mark notifications as read
5026async fn mark_notification_as_read(
5027 request: proto::MarkNotificationRead,
5028 response: Response<proto::MarkNotificationRead>,
5029 session: UserSession,
5030) -> Result<()> {
5031 let database = &session.db().await;
5032 let notifications = database
5033 .mark_notification_as_read_by_id(
5034 session.user_id(),
5035 NotificationId::from_proto(request.notification_id),
5036 )
5037 .await?;
5038 send_notifications(
5039 &*session.connection_pool().await,
5040 &session.peer,
5041 notifications,
5042 );
5043 response.send(proto::Ack {})?;
5044 Ok(())
5045}
5046
5047/// Get the current users information
5048async fn get_private_user_info(
5049 _request: proto::GetPrivateUserInfo,
5050 response: Response<proto::GetPrivateUserInfo>,
5051 session: UserSession,
5052) -> Result<()> {
5053 let db = session.db().await;
5054
5055 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
5056 let user = db
5057 .get_user_by_id(session.user_id())
5058 .await?
5059 .ok_or_else(|| anyhow!("user not found"))?;
5060 let flags = db.get_user_flags(session.user_id()).await?;
5061
5062 response.send(proto::GetPrivateUserInfoResponse {
5063 metrics_id,
5064 staff: user.admin,
5065 flags,
5066 })?;
5067 Ok(())
5068}
5069
5070fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
5071 match message {
5072 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
5073 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
5074 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
5075 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
5076 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
5077 code: frame.code.into(),
5078 reason: frame.reason,
5079 })),
5080 }
5081}
5082
5083fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
5084 match message {
5085 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
5086 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
5087 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
5088 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
5089 AxumMessage::Close(frame) => {
5090 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
5091 code: frame.code.into(),
5092 reason: frame.reason,
5093 }))
5094 }
5095 }
5096}
5097
5098fn notify_membership_updated(
5099 connection_pool: &mut ConnectionPool,
5100 result: MembershipUpdated,
5101 user_id: UserId,
5102 peer: &Peer,
5103) {
5104 for membership in &result.new_channels.channel_memberships {
5105 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
5106 }
5107 for channel_id in &result.removed_channels {
5108 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
5109 }
5110
5111 let user_channels_update = proto::UpdateUserChannels {
5112 channel_memberships: result
5113 .new_channels
5114 .channel_memberships
5115 .iter()
5116 .map(|cm| proto::ChannelMembership {
5117 channel_id: cm.channel_id.to_proto(),
5118 role: cm.role.into(),
5119 })
5120 .collect(),
5121 ..Default::default()
5122 };
5123
5124 let mut update = build_channels_update(result.new_channels);
5125 update.delete_channels = result
5126 .removed_channels
5127 .into_iter()
5128 .map(|id| id.to_proto())
5129 .collect();
5130 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
5131
5132 for connection_id in connection_pool.user_connection_ids(user_id) {
5133 peer.send(connection_id, user_channels_update.clone())
5134 .trace_err();
5135 peer.send(connection_id, update.clone()).trace_err();
5136 }
5137}
5138
5139fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
5140 proto::UpdateUserChannels {
5141 channel_memberships: channels
5142 .channel_memberships
5143 .iter()
5144 .map(|m| proto::ChannelMembership {
5145 channel_id: m.channel_id.to_proto(),
5146 role: m.role.into(),
5147 })
5148 .collect(),
5149 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
5150 observed_channel_message_id: channels.observed_channel_messages.clone(),
5151 }
5152}
5153
5154fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
5155 let mut update = proto::UpdateChannels::default();
5156
5157 for channel in channels.channels {
5158 update.channels.push(channel.to_proto());
5159 }
5160
5161 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
5162 update.latest_channel_message_ids = channels.latest_channel_messages;
5163
5164 for (channel_id, participants) in channels.channel_participants {
5165 update
5166 .channel_participants
5167 .push(proto::ChannelParticipants {
5168 channel_id: channel_id.to_proto(),
5169 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
5170 });
5171 }
5172
5173 for channel in channels.invited_channels {
5174 update.channel_invitations.push(channel.to_proto());
5175 }
5176
5177 update.hosted_projects = channels.hosted_projects;
5178 update
5179}
5180
5181fn build_initial_contacts_update(
5182 contacts: Vec<db::Contact>,
5183 pool: &ConnectionPool,
5184) -> proto::UpdateContacts {
5185 let mut update = proto::UpdateContacts::default();
5186
5187 for contact in contacts {
5188 match contact {
5189 db::Contact::Accepted { user_id, busy } => {
5190 update.contacts.push(contact_for_user(user_id, busy, &pool));
5191 }
5192 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
5193 db::Contact::Incoming { user_id } => {
5194 update
5195 .incoming_requests
5196 .push(proto::IncomingContactRequest {
5197 requester_id: user_id.to_proto(),
5198 })
5199 }
5200 }
5201 }
5202
5203 update
5204}
5205
5206fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
5207 proto::Contact {
5208 user_id: user_id.to_proto(),
5209 online: pool.is_user_online(user_id),
5210 busy,
5211 }
5212}
5213
5214fn room_updated(room: &proto::Room, peer: &Peer) {
5215 broadcast(
5216 None,
5217 room.participants
5218 .iter()
5219 .filter_map(|participant| Some(participant.peer_id?.into())),
5220 |peer_id| {
5221 peer.send(
5222 peer_id,
5223 proto::RoomUpdated {
5224 room: Some(room.clone()),
5225 },
5226 )
5227 },
5228 );
5229}
5230
5231fn channel_updated(
5232 channel: &db::channel::Model,
5233 room: &proto::Room,
5234 peer: &Peer,
5235 pool: &ConnectionPool,
5236) {
5237 let participants = room
5238 .participants
5239 .iter()
5240 .map(|p| p.user_id)
5241 .collect::<Vec<_>>();
5242
5243 broadcast(
5244 None,
5245 pool.channel_connection_ids(channel.root_id())
5246 .filter_map(|(channel_id, role)| {
5247 role.can_see_channel(channel.visibility).then(|| channel_id)
5248 }),
5249 |peer_id| {
5250 peer.send(
5251 peer_id,
5252 proto::UpdateChannels {
5253 channel_participants: vec![proto::ChannelParticipants {
5254 channel_id: channel.id.to_proto(),
5255 participant_user_ids: participants.clone(),
5256 }],
5257 ..Default::default()
5258 },
5259 )
5260 },
5261 );
5262}
5263
5264async fn send_dev_server_projects_update(
5265 user_id: UserId,
5266 mut status: proto::DevServerProjectsUpdate,
5267 session: &Session,
5268) {
5269 let pool = session.connection_pool().await;
5270 for dev_server in &mut status.dev_servers {
5271 dev_server.status =
5272 pool.dev_server_status(DevServerId(dev_server.dev_server_id as i32)) as i32;
5273 }
5274 let connections = pool.user_connection_ids(user_id);
5275 for connection_id in connections {
5276 session.peer.send(connection_id, status.clone()).trace_err();
5277 }
5278}
5279
5280async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
5281 let db = session.db().await;
5282
5283 let contacts = db.get_contacts(user_id).await?;
5284 let busy = db.is_user_busy(user_id).await?;
5285
5286 let pool = session.connection_pool().await;
5287 let updated_contact = contact_for_user(user_id, busy, &pool);
5288 for contact in contacts {
5289 if let db::Contact::Accepted {
5290 user_id: contact_user_id,
5291 ..
5292 } = contact
5293 {
5294 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
5295 session
5296 .peer
5297 .send(
5298 contact_conn_id,
5299 proto::UpdateContacts {
5300 contacts: vec![updated_contact.clone()],
5301 remove_contacts: Default::default(),
5302 incoming_requests: Default::default(),
5303 remove_incoming_requests: Default::default(),
5304 outgoing_requests: Default::default(),
5305 remove_outgoing_requests: Default::default(),
5306 },
5307 )
5308 .trace_err();
5309 }
5310 }
5311 }
5312 Ok(())
5313}
5314
5315async fn lost_dev_server_connection(session: &DevServerSession) -> Result<()> {
5316 log::info!("lost dev server connection, unsharing projects");
5317 let project_ids = session
5318 .db()
5319 .await
5320 .get_stale_dev_server_projects(session.connection_id)
5321 .await?;
5322
5323 for project_id in project_ids {
5324 // not unshare re-checks the connection ids match, so we get away with no transaction
5325 unshare_project_internal(project_id, session.connection_id, None, &session).await?;
5326 }
5327
5328 let user_id = session.dev_server().user_id;
5329 let update = session
5330 .db()
5331 .await
5332 .dev_server_projects_update(user_id)
5333 .await?;
5334
5335 send_dev_server_projects_update(user_id, update, session).await;
5336
5337 Ok(())
5338}
5339
/// Remove `connection_id` from the room it is in, if any, and propagate the
/// consequences.
///
/// When the database reports that the connection left a room, this:
/// - notifies the other connections of each project the user left (`project_left`);
/// - broadcasts the updated room to its participants, and to the channel's
///   visible connections when the room belongs to a channel;
/// - sends `CallCanceled` to every connection of each user whose call was canceled;
/// - refreshes contact status for the leaving user and those callees;
/// - when a LiveKit client is configured, removes the user from the LiveKit
///   room and, if `left_room.deleted` is set, deletes that room (LiveKit
///   errors are only logged).
///
/// Returns `Ok(())` without doing anything when the connection was not in a room.
async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Assigned inside the `if let` below. The `else` branch returns early, so
    // all of these are initialized by the time they are read.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // NOTE(review): the LiveKit room name is taken out of `left_room.room`
        // before the room itself is moved out, so the `room` broadcast below
        // carries a defaulted `live_kit_room`. Confirm this is intended.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the connection-pool guard is released before the awaits below.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    // NOTE(review): an error here returns early via `?`, which skips the
    // LiveKit cleanup below.
    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
5413
5414async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
5415 let left_channel_buffers = session
5416 .db()
5417 .await
5418 .leave_channel_buffers(session.connection_id)
5419 .await?;
5420
5421 for left_buffer in left_channel_buffers {
5422 channel_buffer_updated(
5423 session.connection_id,
5424 left_buffer.connections,
5425 &proto::UpdateChannelBufferCollaborators {
5426 channel_id: left_buffer.channel_id.to_proto(),
5427 collaborators: left_buffer.collaborators,
5428 },
5429 &session.peer,
5430 );
5431 }
5432
5433 Ok(())
5434}
5435
5436fn project_left(project: &db::LeftProject, session: &UserSession) {
5437 for connection_id in &project.connection_ids {
5438 if project.should_unshare {
5439 session
5440 .peer
5441 .send(
5442 *connection_id,
5443 proto::UnshareProject {
5444 project_id: project.id.to_proto(),
5445 },
5446 )
5447 .trace_err();
5448 } else {
5449 session
5450 .peer
5451 .send(
5452 *connection_id,
5453 proto::RemoveProjectCollaborator {
5454 project_id: project.id.to_proto(),
5455 peer_id: Some(session.connection_id.into()),
5456 },
5457 )
5458 .trace_err();
5459 }
5460 }
5461}
5462
/// Extension for fallible values whose errors should be reported rather than propagated.
pub trait ResultExt {
    /// Convert into an `Option`, keeping the success value and discarding the error.
    fn trace_err(self) -> Option<Self::Ok>;

    /// The success type produced by `trace_err`.
    type Ok;
}
5468
5469impl<T, E> ResultExt for Result<T, E>
5470where
5471 E: std::fmt::Debug,
5472{
5473 type Ok = T;
5474
5475 #[track_caller]
5476 fn trace_err(self) -> Option<T> {
5477 match self {
5478 Ok(value) => Some(value),
5479 Err(error) => {
5480 tracing::error!("{:?}", error);
5481 None
5482 }
5483 }
5484 }
5485}