1mod connection_pool;
2
3use crate::{
4 auth,
5 db::{
6 self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
7 CreatedChannelMessage, Database, DevServerId, DevServerProjectId, InviteMemberResult,
8 MembershipUpdated, MessageId, NotificationId, PrincipalId, Project, ProjectId,
9 RejoinedProject, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
10 ServerId, UpdatedChannelMessage, User, UserId,
11 },
12 executor::Executor,
13 AppState, Error, RateLimit, RateLimiter, Result,
14};
15use anyhow::{anyhow, Context as _};
16use async_tungstenite::tungstenite::{
17 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
18};
19use axum::{
20 body::Body,
21 extract::{
22 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
23 ConnectInfo, WebSocketUpgrade,
24 },
25 headers::{Header, HeaderName},
26 http::StatusCode,
27 middleware,
28 response::IntoResponse,
29 routing::get,
30 Extension, Router, TypedHeader,
31};
32use collections::{HashMap, HashSet};
33pub use connection_pool::{ConnectionPool, ZedVersion};
34use core::fmt::{self, Debug, Formatter};
35use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
36use sha2::Digest;
37use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
38
39use futures::{
40 channel::oneshot,
41 future::{self, BoxFuture},
42 stream::FuturesUnordered,
43 FutureExt, SinkExt, StreamExt, TryStreamExt,
44};
45use http::IsahcHttpClient;
46use prometheus::{register_int_gauge, IntGauge};
47use rpc::{
48 proto::{
49 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LanguageModelRole,
50 LiveKitConnectionInfo, RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
51 },
52 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
53};
54use semantic_version::SemanticVersion;
55use serde::{Serialize, Serializer};
56use std::{
57 any::TypeId,
58 future::Future,
59 marker::PhantomData,
60 mem,
61 net::SocketAddr,
62 ops::{Deref, DerefMut},
63 rc::Rc,
64 sync::{
65 atomic::{AtomicBool, Ordering::SeqCst},
66 Arc, OnceLock,
67 },
68 time::{Duration, Instant},
69};
70use time::OffsetDateTime;
71use tokio::sync::{watch, Semaphore};
72use tower::ServiceBuilder;
73use tracing::{
74 field::{self},
75 info_span, instrument, Instrument,
76};
77
78use self::connection_pool::VersionedMessage;
79
/// Reconnection timeout (30 seconds).
///
/// NOTE(review): presumably the window a disconnected client has to reconnect before its
/// state is released — confirm at the use sites.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// Delay before `Server::start` cleans up resources left behind by previous server instances.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can
/// clean up old resources, so this waits somewhat longer than that grace period.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

/// Page size for channel-message queries. NOTE(review): presumably; confirm at the use site.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Maximum accepted channel-message length. NOTE(review): confirm whether the check at the
/// use site counts bytes or characters.
const MAX_MESSAGE_LEN: usize = 1024;
/// Page size for notification queries. NOTE(review): presumably; confirm at the use site.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;

/// A type-erased message handler as stored in `Server::handlers`.
///
/// It takes the boxed incoming envelope plus the connection's `Session` and returns a
/// `'static` boxed future that performs the handling.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
91
92struct Response<R> {
93 peer: Arc<Peer>,
94 receipt: Receipt<R>,
95 responded: Arc<AtomicBool>,
96}
97
98impl<R: RequestMessage> Response<R> {
99 fn send(self, payload: R::Response) -> Result<()> {
100 self.responded.store(true, SeqCst);
101 self.peer.respond(self.receipt, payload)?;
102 Ok(())
103 }
104}
105
106struct StreamingResponse<R: RequestMessage> {
107 peer: Arc<Peer>,
108 receipt: Receipt<R>,
109}
110
111impl<R: RequestMessage> StreamingResponse<R> {
112 fn send(&self, payload: R::Response) -> Result<()> {
113 self.peer.respond(self.receipt, payload)?;
114 Ok(())
115 }
116}
117
118#[derive(Clone, Debug)]
119pub enum Principal {
120 User(User),
121 Impersonated { user: User, admin: User },
122 DevServer(dev_server::Model),
123}
124
125impl Principal {
126 fn update_span(&self, span: &tracing::Span) {
127 match &self {
128 Principal::User(user) => {
129 span.record("user_id", &user.id.0);
130 span.record("login", &user.github_login);
131 }
132 Principal::Impersonated { user, admin } => {
133 span.record("user_id", &user.id.0);
134 span.record("login", &user.github_login);
135 span.record("impersonator", &admin.github_login);
136 }
137 Principal::DevServer(dev_server) => {
138 span.record("dev_server_id", &dev_server.id.0);
139 }
140 }
141 }
142}
143
/// Per-connection state handed to every message handler.
///
/// Cloned for each incoming message that has a registered handler.
#[derive(Clone)]
struct Session {
    /// The authenticated entity on the other end of this connection.
    principal: Principal,
    connection_id: ConnectionId,
    /// Database handle behind an async mutex (one mutex per connection); lock via `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// The server-wide connection pool shared with `Server`; lock via `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
    /// Only present when a Supermaven admin API key is configured.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    http_client: Arc<IsahcHttpClient>,
    rate_limiter: Arc<RateLimiter>,
    // Not read in the code shown here; the leading underscore silences the dead-code lint.
    _executor: Executor,
}
157
158impl Session {
159 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
160 #[cfg(test)]
161 tokio::task::yield_now().await;
162 let guard = self.db.lock().await;
163 #[cfg(test)]
164 tokio::task::yield_now().await;
165 guard
166 }
167
168 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
169 #[cfg(test)]
170 tokio::task::yield_now().await;
171 let guard = self.connection_pool.lock();
172 ConnectionPoolGuard {
173 guard,
174 _not_send: PhantomData,
175 }
176 }
177
178 fn for_user(self) -> Option<UserSession> {
179 UserSession::new(self)
180 }
181
182 fn for_dev_server(self) -> Option<DevServerSession> {
183 DevServerSession::new(self)
184 }
185
186 fn user_id(&self) -> Option<UserId> {
187 match &self.principal {
188 Principal::User(user) => Some(user.id),
189 Principal::Impersonated { user, .. } => Some(user.id),
190 Principal::DevServer(_) => None,
191 }
192 }
193
194 fn is_staff(&self) -> bool {
195 match &self.principal {
196 Principal::User(user) => user.admin,
197 Principal::Impersonated { .. } => true,
198 Principal::DevServer(_) => false,
199 }
200 }
201
202 fn dev_server_id(&self) -> Option<DevServerId> {
203 match &self.principal {
204 Principal::User(_) | Principal::Impersonated { .. } => None,
205 Principal::DevServer(dev_server) => Some(dev_server.id),
206 }
207 }
208
209 fn principal_id(&self) -> PrincipalId {
210 match &self.principal {
211 Principal::User(user) => PrincipalId::UserId(user.id),
212 Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
213 Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
214 }
215 }
216}
217
218impl Debug for Session {
219 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
220 let mut result = f.debug_struct("Session");
221 match &self.principal {
222 Principal::User(user) => {
223 result.field("user", &user.github_login);
224 }
225 Principal::Impersonated { user, admin } => {
226 result.field("user", &user.github_login);
227 result.field("impersonator", &admin.github_login);
228 }
229 Principal::DevServer(dev_server) => {
230 result.field("dev_server", &dev_server.id);
231 }
232 }
233 result.field("connection_id", &self.connection_id).finish()
234 }
235}
236
237struct UserSession(Session);
238
239impl UserSession {
240 pub fn new(s: Session) -> Option<Self> {
241 s.user_id().map(|_| UserSession(s))
242 }
243 pub fn user_id(&self) -> UserId {
244 self.0.user_id().unwrap()
245 }
246
247 pub fn email(&self) -> Option<String> {
248 match &self.0.principal {
249 Principal::User(user) => user.email_address.clone(),
250 Principal::Impersonated { user, .. } => user.email_address.clone(),
251 Principal::DevServer(..) => None,
252 }
253 }
254}
255
256impl Deref for UserSession {
257 type Target = Session;
258
259 fn deref(&self) -> &Self::Target {
260 &self.0
261 }
262}
263impl DerefMut for UserSession {
264 fn deref_mut(&mut self) -> &mut Self::Target {
265 &mut self.0
266 }
267}
268
269struct DevServerSession(Session);
270
271impl DevServerSession {
272 pub fn new(s: Session) -> Option<Self> {
273 s.dev_server_id().map(|_| DevServerSession(s))
274 }
275 pub fn dev_server_id(&self) -> DevServerId {
276 self.0.dev_server_id().unwrap()
277 }
278
279 fn dev_server(&self) -> &dev_server::Model {
280 match &self.0.principal {
281 Principal::DevServer(dev_server) => dev_server,
282 _ => unreachable!(),
283 }
284 }
285}
286
287impl Deref for DevServerSession {
288 type Target = Session;
289
290 fn deref(&self) -> &Self::Target {
291 &self.0
292 }
293}
294impl DerefMut for DevServerSession {
295 fn deref_mut(&mut self) -> &mut Self::Target {
296 &mut self.0
297 }
298}
299
300fn user_handler<M: RequestMessage, Fut>(
301 handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
302) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
303where
304 Fut: Send + Future<Output = Result<()>>,
305{
306 let handler = Arc::new(handler);
307 move |message, response, session| {
308 let handler = handler.clone();
309 Box::pin(async move {
310 if let Some(user_session) = session.for_user() {
311 Ok(handler(message, response, user_session).await?)
312 } else {
313 Err(Error::Internal(anyhow!(
314 "must be a user to call {}",
315 M::NAME
316 )))
317 }
318 })
319 }
320}
321
322fn dev_server_handler<M: RequestMessage, Fut>(
323 handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
324) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
325where
326 Fut: Send + Future<Output = Result<()>>,
327{
328 let handler = Arc::new(handler);
329 move |message, response, session| {
330 let handler = handler.clone();
331 Box::pin(async move {
332 if let Some(dev_server_session) = session.for_dev_server() {
333 Ok(handler(message, response, dev_server_session).await?)
334 } else {
335 Err(Error::Internal(anyhow!(
336 "must be a dev server to call {}",
337 M::NAME
338 )))
339 }
340 })
341 }
342}
343
344fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
345 handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
346) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
347where
348 InnertRetFut: Send + Future<Output = Result<()>>,
349{
350 let handler = Arc::new(handler);
351 move |message, session| {
352 let handler = handler.clone();
353 Box::pin(async move {
354 if let Some(user_session) = session.for_user() {
355 Ok(handler(message, user_session).await?)
356 } else {
357 Err(Error::Internal(anyhow!(
358 "must be a user to call {}",
359 M::NAME
360 )))
361 }
362 })
363 }
364}
365
366struct DbHandle(Arc<Database>);
367
368impl Deref for DbHandle {
369 type Target = Database;
370
371 fn deref(&self) -> &Self::Target {
372 self.0.as_ref()
373 }
374}
375
/// The collaboration RPC server: owns the peer, the shared connection pool and the
/// per-message-type handler table.
pub struct Server {
    /// This server instance's id; replaced by the test-only `reset`.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    /// Connection pool shared with every `Session` created by this server.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Type-erased handlers keyed by the `TypeId` of the message type they handle.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Publishes whether the server is tearing down; each connection subscribes to it.
    teardown: watch::Sender<bool>,
}
384
/// Connection-pool lock guard that is deliberately `!Send`.
///
/// `PhantomData<Rc<()>>` makes the type `!Send`, so any future holding this guard across an
/// `.await` is itself `!Send` (presumably to keep the synchronous `parking_lot` lock from
/// being held across suspension points in spawned tasks).
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}

/// Serializable view of server state: the peer plus the locked connection pool.
///
/// The pool lock is held for as long as the snapshot lives. Field order is the
/// serialized field order.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized through the guard's `Deref` target via `serialize_deref`.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
396
397pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
398where
399 S: Serializer,
400 T: Deref<Target = U>,
401 U: Serialize,
402{
403 Serialize::serialize(value.deref(), serializer)
404}
405
406impl Server {
/// Creates the server and registers the handler for every RPC message it serves.
///
/// Each message type may be registered only once (`add_handler` panics on a duplicate).
/// Wrappers used below:
/// - `user_handler` / `user_message_handler`: the handler receives a `UserSession`; when the
///   connection's principal is not a user, an internal error is returned without running it.
/// - `dev_server_handler`: the same, but only dev-server principals are accepted.
/// - unwrapped handlers receive the raw `Session`; no principal check happens at this layer.
pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
    let mut server = Self {
        id: parking_lot::Mutex::new(id),
        // NOTE(review): `as u32` is a lossy cast if `ServerId`'s inner integer can fall
        // outside the `u32` range — confirm.
        peer: Peer::new(id.0 as u32),
        app_state: app_state.clone(),
        connection_pool: Default::default(),
        handlers: Default::default(),
        // Only the sender is kept; each connection creates its own receiver via `subscribe()`.
        teardown: watch::channel(false).0,
    };

    server
        .add_request_handler(ping)
        // Rooms and calls.
        .add_request_handler(user_handler(create_room))
        .add_request_handler(user_handler(join_room))
        .add_request_handler(user_handler(rejoin_room))
        .add_request_handler(user_handler(leave_room))
        .add_request_handler(user_handler(set_room_participant_role))
        .add_request_handler(user_handler(call))
        .add_request_handler(user_handler(cancel_call))
        .add_message_handler(user_message_handler(decline_call))
        .add_request_handler(user_handler(update_participant_location))
        // Projects, dev servers and dev-server projects.
        .add_request_handler(user_handler(share_project))
        .add_message_handler(unshare_project)
        .add_request_handler(user_handler(join_project))
        .add_request_handler(user_handler(join_hosted_project))
        .add_request_handler(user_handler(rejoin_dev_server_projects))
        .add_request_handler(user_handler(create_dev_server_project))
        .add_request_handler(user_handler(delete_dev_server_project))
        .add_request_handler(user_handler(create_dev_server))
        .add_request_handler(user_handler(regenerate_dev_server_token))
        .add_request_handler(user_handler(rename_dev_server))
        .add_request_handler(user_handler(delete_dev_server))
        // Only callable by dev-server principals.
        .add_request_handler(dev_server_handler(share_dev_server_project))
        .add_request_handler(dev_server_handler(shutdown_dev_server))
        .add_request_handler(dev_server_handler(reconnect_dev_server))
        .add_message_handler(user_message_handler(leave_project))
        // Project/worktree state updates.
        .add_request_handler(update_project)
        .add_request_handler(update_worktree)
        .add_message_handler(start_language_server)
        .add_message_handler(update_language_server)
        .add_message_handler(update_diagnostic_summary)
        .add_message_handler(update_worktree_settings)
        // Project requests relayed through the `forward_*` helpers (defined elsewhere); the
        // helper names presumably distinguish owner-only, read-only, mutating and
        // version-gated forwarding — confirm against their definitions.
        .add_request_handler(user_handler(
            forward_project_request_for_owner::<proto::TaskContextForLocation>,
        ))
        .add_request_handler(user_handler(
            forward_project_request_for_owner::<proto::TaskTemplates>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::GetHover>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::GetDefinition>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::GetTypeDefinition>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::GetReferences>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::SearchProject>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::GetDocumentHighlights>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::GetProjectSymbols>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::OpenBufferForSymbol>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::OpenBufferById>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::SynchronizeBuffers>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::InlayHints>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::OpenBufferByPath>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::GetCompletions>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
        ))
        .add_request_handler(user_handler(
            forward_versioned_mutating_project_request::<proto::OpenNewBuffer>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::GetCodeActions>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::ApplyCodeAction>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::PrepareRename>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::PerformRename>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::ReloadBuffers>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::FormatBuffers>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::CreateProjectEntry>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::RenameProjectEntry>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::CopyProjectEntry>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::DeleteProjectEntry>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::ExpandProjectEntry>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::OnTypeFormatting>,
        ))
        .add_request_handler(user_handler(
            forward_versioned_mutating_project_request::<proto::SaveBuffer>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::BlameBuffer>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::MultiLspQuery>,
        ))
        // Buffer traffic and messages broadcast from the project host.
        .add_message_handler(create_buffer_for_peer)
        .add_request_handler(update_buffer)
        .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
        .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
        .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
        .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
        .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
        // Users and contacts.
        .add_request_handler(get_users)
        .add_request_handler(user_handler(fuzzy_search_users))
        .add_request_handler(user_handler(request_contact))
        .add_request_handler(user_handler(remove_contact))
        .add_request_handler(user_handler(respond_to_contact_request))
        // Channels, channel buffers and channel chat.
        .add_message_handler(subscribe_to_channels)
        .add_request_handler(user_handler(create_channel))
        .add_request_handler(user_handler(delete_channel))
        .add_request_handler(user_handler(invite_channel_member))
        .add_request_handler(user_handler(remove_channel_member))
        .add_request_handler(user_handler(set_channel_member_role))
        .add_request_handler(user_handler(set_channel_visibility))
        .add_request_handler(user_handler(rename_channel))
        .add_request_handler(user_handler(join_channel_buffer))
        .add_request_handler(user_handler(leave_channel_buffer))
        .add_message_handler(user_message_handler(update_channel_buffer))
        .add_request_handler(user_handler(rejoin_channel_buffers))
        .add_request_handler(user_handler(get_channel_members))
        .add_request_handler(user_handler(respond_to_channel_invite))
        .add_request_handler(user_handler(join_channel))
        .add_request_handler(user_handler(join_channel_chat))
        .add_message_handler(user_message_handler(leave_channel_chat))
        .add_request_handler(user_handler(send_channel_message))
        .add_request_handler(user_handler(remove_channel_message))
        .add_request_handler(user_handler(update_channel_message))
        .add_request_handler(user_handler(get_channel_messages))
        .add_request_handler(user_handler(get_channel_messages_by_id))
        // Notifications, channel moves, following, user info and acknowledgements.
        .add_request_handler(user_handler(get_notifications))
        .add_request_handler(user_handler(mark_notification_as_read))
        .add_request_handler(user_handler(move_channel))
        .add_request_handler(user_handler(follow))
        .add_message_handler(user_message_handler(unfollow))
        .add_message_handler(user_message_handler(update_followers))
        .add_request_handler(user_handler(get_private_user_info))
        .add_message_handler(user_message_handler(acknowledge_channel_message))
        .add_message_handler(user_message_handler(acknowledge_buffer_version))
        .add_request_handler(user_handler(get_supermaven_api_key))
        // Language-model and embedding endpoints. The provider API keys are cloned out of
        // `app_state.config` on every call. The streaming completion handler is not wrapped
        // in `user_handler`, so it receives the raw `Session`.
        .add_streaming_request_handler({
            let app_state = app_state.clone();
            move |request, response, session| {
                complete_with_language_model(
                    request,
                    response,
                    session,
                    app_state.config.openai_api_key.clone(),
                    app_state.config.google_ai_api_key.clone(),
                    app_state.config.anthropic_api_key.clone(),
                )
            }
        })
        .add_request_handler({
            let app_state = app_state.clone();
            user_handler(move |request, response, session| {
                count_tokens_with_language_model(
                    request,
                    response,
                    session,
                    app_state.config.google_ai_api_key.clone(),
                )
            })
        })
        .add_request_handler({
            user_handler(move |request, response, session| {
                get_cached_embeddings(request, response, session)
            })
        })
        .add_request_handler({
            let app_state = app_state.clone();
            user_handler(move |request, response, session| {
                compute_embeddings(
                    request,
                    response,
                    session,
                    app_state.config.openai_api_key.clone(),
                )
            })
        });

    Arc::new(server)
}
635
/// Schedules cleanup of resources left behind by previous server instances.
///
/// Returns `Ok(())` immediately. The actual work runs on a detached task that:
/// 1. sleeps for [`CLEANUP_TIMEOUT`];
/// 2. asks the database for stale room ids and stale channel-buffer ids for this
///    environment and server id;
/// 3. for each stale channel buffer, clears stale collaborators and sends the refreshed
///    collaborator list to the buffer's remaining connections;
/// 4. for each stale room, clears stale participants and broadcasts the room (and its
///    channel, if any). It then sends `CallCanceled` to users whose pending calls were
///    canceled and pushes updated contact entries to those users' accepted contacts.
///    Finally it deletes the LiveKit room if no participants remain and a LiveKit client
///    is configured;
/// 5. deletes stale server records (presumably every server other than `server_id` in this
///    environment — confirm in `delete_stale_servers`).
///
/// Failed database or peer operations go through `trace_err()`, which turns them into
/// `None`/ignored results so the remaining work still runs. Presumably it also logs them;
/// that helper is not visible here.
pub async fn start(&self) -> Result<()> {
    let server_id = *self.id.lock();
    let app_state = self.app_state.clone();
    let peer = self.peer.clone();
    let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
    let pool = self.connection_pool.clone();
    let live_kit_client = self.app_state.live_kit_client.clone();

    let span = info_span!("start server");
    self.app_state.executor.spawn_detached(
        async move {
            tracing::info!("waiting for cleanup timeout");
            timeout.await;
            tracing::info!("cleanup timeout expired, retrieving stale rooms");
            if let Some((room_ids, channel_ids)) = app_state
                .db
                .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                .await
                .trace_err()
            {
                tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                tracing::info!(
                    stale_channel_buffer_count = channel_ids.len(),
                    "retrieved stale channel buffers"
                );

                // Channel buffers: drop stale collaborators, then tell everyone still
                // connected to the buffer who the collaborators are now.
                for channel_id in channel_ids {
                    if let Some(refreshed_channel_buffer) = app_state
                        .db
                        .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                        .await
                        .trace_err()
                    {
                        for connection_id in refreshed_channel_buffer.connection_ids {
                            peer.send(
                                connection_id,
                                proto::UpdateChannelBufferCollaborators {
                                    channel_id: channel_id.to_proto(),
                                    collaborators: refreshed_channel_buffer
                                        .collaborators
                                        .clone(),
                                },
                            )
                            .trace_err();
                        }
                    }
                }

                for room_id in room_ids {
                    // Defaults are used when clearing the room fails, so the steps below
                    // become no-ops for this room.
                    let mut contacts_to_update = HashSet::default();
                    let mut canceled_calls_to_user_ids = Vec::new();
                    let mut live_kit_room = String::new();
                    let mut delete_live_kit_room = false;

                    if let Some(mut refreshed_room) = app_state
                        .db
                        .clear_stale_room_participants(room_id, server_id)
                        .await
                        .trace_err()
                    {
                        tracing::info!(
                            room_id = room_id.0,
                            new_participant_count = refreshed_room.room.participants.len(),
                            "refreshed room"
                        );
                        room_updated(&refreshed_room.room, &peer);
                        if let Some(channel) = refreshed_room.channel.as_ref() {
                            channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                        }
                        // Both removed participants and callees whose calls were canceled
                        // get their contact status re-broadcast below.
                        contacts_to_update
                            .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                        contacts_to_update
                            .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                        canceled_calls_to_user_ids =
                            mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                        live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                        delete_live_kit_room = refreshed_room.room.participants.is_empty();
                    }

                    // The pool is a synchronous (parking_lot) mutex; it is only locked in
                    // scopes that contain no `.await`.
                    {
                        let pool = pool.lock();
                        for canceled_user_id in canceled_calls_to_user_ids {
                            for connection_id in pool.user_connection_ids(canceled_user_id) {
                                peer.send(
                                    connection_id,
                                    proto::CallCanceled {
                                        room_id: room_id.to_proto(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for user_id in contacts_to_update {
                        // Both queries complete before the pool is locked.
                        let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                        let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                        if let Some((busy, contacts)) = busy.zip(contacts) {
                            let pool = pool.lock();
                            let updated_contact = contact_for_user(user_id, busy, &pool);
                            // Only accepted contacts are notified.
                            for contact in contacts {
                                if let db::Contact::Accepted {
                                    user_id: contact_user_id,
                                    ..
                                } = contact
                                {
                                    for contact_conn_id in
                                        pool.user_connection_ids(contact_user_id)
                                    {
                                        peer.send(
                                            contact_conn_id,
                                            proto::UpdateContacts {
                                                contacts: vec![updated_contact.clone()],
                                                remove_contacts: Default::default(),
                                                incoming_requests: Default::default(),
                                                remove_incoming_requests: Default::default(),
                                                outgoing_requests: Default::default(),
                                                remove_outgoing_requests: Default::default(),
                                            },
                                        )
                                        .trace_err();
                                    }
                                }
                            }
                        }
                    }

                    if let Some(live_kit) = live_kit_client.as_ref() {
                        if delete_live_kit_room {
                            live_kit.delete_room(live_kit_room).await.trace_err();
                        }
                    }
                }
            }

            // Runs even if fetching the stale resource ids failed.
            app_state
                .db
                .delete_stale_servers(&app_state.config.zed_environment, server_id)
                .await
                .trace_err();
        }
        .instrument(span),
    );
    Ok(())
}
781
782 pub fn teardown(&self) {
783 self.peer.teardown();
784 self.connection_pool.lock().reset();
785 let _ = self.teardown.send(true);
786 }
787
788 #[cfg(test)]
789 pub fn reset(&self, id: ServerId) {
790 self.teardown();
791 *self.id.lock() = id;
792 self.peer.reset(id.0 as u32);
793 let _ = self.teardown.send(false);
794 }
795
796 #[cfg(test)]
797 pub fn id(&self) -> ServerId {
798 *self.id.lock()
799 }
800
801 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
802 where
803 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
804 Fut: 'static + Send + Future<Output = Result<()>>,
805 M: EnvelopedMessage,
806 {
807 let prev_handler = self.handlers.insert(
808 TypeId::of::<M>(),
809 Box::new(move |envelope, session| {
810 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
811 let received_at = envelope.received_at;
812 tracing::info!("message received");
813 let start_time = Instant::now();
814 let future = (handler)(*envelope, session);
815 async move {
816 let result = future.await;
817 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
818 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
819 let queue_duration_ms = total_duration_ms - processing_duration_ms;
820 let payload_type = M::NAME;
821
822 match result {
823 Err(error) => {
824 tracing::error!(
825 ?error,
826 total_duration_ms,
827 processing_duration_ms,
828 queue_duration_ms,
829 payload_type,
830 "error handling message"
831 )
832 }
833 Ok(()) => tracing::info!(
834 total_duration_ms,
835 processing_duration_ms,
836 queue_duration_ms,
837 "finished handling message"
838 ),
839 }
840 }
841 .boxed()
842 }),
843 );
844 if prev_handler.is_some() {
845 panic!("registered a handler for the same message twice");
846 }
847 self
848 }
849
850 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
851 where
852 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
853 Fut: 'static + Send + Future<Output = Result<()>>,
854 M: EnvelopedMessage,
855 {
856 self.add_handler(move |envelope, session| handler(envelope.payload, session));
857 self
858 }
859
860 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
861 where
862 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
863 Fut: Send + Future<Output = Result<()>>,
864 M: RequestMessage,
865 {
866 let handler = Arc::new(handler);
867 self.add_handler(move |envelope, session| {
868 let receipt = envelope.receipt();
869 let handler = handler.clone();
870 async move {
871 let peer = session.peer.clone();
872 let responded = Arc::new(AtomicBool::default());
873 let response = Response {
874 peer: peer.clone(),
875 responded: responded.clone(),
876 receipt,
877 };
878 match (handler)(envelope.payload, response, session).await {
879 Ok(()) => {
880 if responded.load(std::sync::atomic::Ordering::SeqCst) {
881 Ok(())
882 } else {
883 Err(anyhow!("handler did not send a response"))?
884 }
885 }
886 Err(error) => {
887 let proto_err = match &error {
888 Error::Internal(err) => err.to_proto(),
889 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
890 };
891 peer.respond_with_error(receipt, proto_err)?;
892 Err(error)
893 }
894 }
895 }
896 })
897 }
898
899 fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
900 where
901 F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
902 Fut: Send + Future<Output = Result<()>>,
903 M: RequestMessage,
904 {
905 let handler = Arc::new(handler);
906 self.add_handler(move |envelope, session| {
907 let receipt = envelope.receipt();
908 let handler = handler.clone();
909 async move {
910 let peer = session.peer.clone();
911 let response = StreamingResponse {
912 peer: peer.clone(),
913 receipt,
914 };
915 match (handler)(envelope.payload, response, session).await {
916 Ok(()) => {
917 peer.end_stream(receipt)?;
918 Ok(())
919 }
920 Err(error) => {
921 let proto_err = match &error {
922 Error::Internal(err) => err.to_proto(),
923 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
924 };
925 peer.respond_with_error(receipt, proto_err)?;
926 Err(error)
927 }
928 }
929 }
930 })
931 }
932
933 #[allow(clippy::too_many_arguments)]
934 pub fn handle_connection(
935 self: &Arc<Self>,
936 connection: Connection,
937 address: String,
938 principal: Principal,
939 zed_version: ZedVersion,
940 send_connection_id: Option<oneshot::Sender<ConnectionId>>,
941 executor: Executor,
942 ) -> impl Future<Output = ()> {
943 let this = self.clone();
944 let span = info_span!("handle connection", %address,
945 connection_id=field::Empty,
946 user_id=field::Empty,
947 login=field::Empty,
948 impersonator=field::Empty,
949 dev_server_id=field::Empty
950 );
951 principal.update_span(&span);
952
953 let mut teardown = self.teardown.subscribe();
954 async move {
955 if *teardown.borrow() {
956 tracing::error!("server is tearing down");
957 return
958 }
959 let (connection_id, handle_io, mut incoming_rx) = this
960 .peer
961 .add_connection(connection, {
962 let executor = executor.clone();
963 move |duration| executor.sleep(duration)
964 });
965 tracing::Span::current().record("connection_id", format!("{}", connection_id));
966 tracing::info!("connection opened");
967
968 let http_client = match IsahcHttpClient::new() {
969 Ok(http_client) => Arc::new(http_client),
970 Err(error) => {
971 tracing::error!(?error, "failed to create HTTP client");
972 return;
973 }
974 };
975
976 let supermaven_client = if let Some(supermaven_admin_api_key) = this.app_state.config.supermaven_admin_api_key.clone() {
977 Some(Arc::new(SupermavenAdminApi::new(
978 supermaven_admin_api_key.to_string(),
979 http_client.clone(),
980 )))
981 } else {
982 None
983 };
984
985 let session = Session {
986 principal: principal.clone(),
987 connection_id,
988 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
989 peer: this.peer.clone(),
990 connection_pool: this.connection_pool.clone(),
991 live_kit_client: this.app_state.live_kit_client.clone(),
992 http_client,
993 rate_limiter: this.app_state.rate_limiter.clone(),
994 _executor: executor.clone(),
995 supermaven_client,
996 };
997
998 if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
999 tracing::error!(?error, "failed to send initial client update");
1000 return;
1001 }
1002
1003 let handle_io = handle_io.fuse();
1004 futures::pin_mut!(handle_io);
1005
1006 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
1007 // This prevents deadlocks when e.g., client A performs a request to client B and
1008 // client B performs a request to client A. If both clients stop processing further
1009 // messages until their respective request completes, they won't have a chance to
1010 // respond to the other client's request and cause a deadlock.
1011 //
1012 // This arrangement ensures we will attempt to process earlier messages first, but fall
1013 // back to processing messages arrived later in the spirit of making progress.
1014 let mut foreground_message_handlers = FuturesUnordered::new();
1015 let concurrent_handlers = Arc::new(Semaphore::new(256));
1016 loop {
1017 let next_message = async {
1018 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
1019 let message = incoming_rx.next().await;
1020 (permit, message)
1021 }.fuse();
1022 futures::pin_mut!(next_message);
1023 futures::select_biased! {
1024 _ = teardown.changed().fuse() => return,
1025 result = handle_io => {
1026 if let Err(error) = result {
1027 tracing::error!(?error, "error handling I/O");
1028 }
1029 break;
1030 }
1031 _ = foreground_message_handlers.next() => {}
1032 next_message = next_message => {
1033 let (permit, message) = next_message;
1034 if let Some(message) = message {
1035 let type_name = message.payload_type_name();
1036 // note: we copy all the fields from the parent span so we can query them in the logs.
1037 // (https://github.com/tokio-rs/tracing/issues/2670).
1038 let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
1039 user_id=field::Empty,
1040 login=field::Empty,
1041 impersonator=field::Empty,
1042 dev_server_id=field::Empty
1043 );
1044 principal.update_span(&span);
1045 let span_enter = span.enter();
1046 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
1047 let is_background = message.is_background();
1048 let handle_message = (handler)(message, session.clone());
1049 drop(span_enter);
1050
1051 let handle_message = async move {
1052 handle_message.await;
1053 drop(permit);
1054 }.instrument(span);
1055 if is_background {
1056 executor.spawn_detached(handle_message);
1057 } else {
1058 foreground_message_handlers.push(handle_message);
1059 }
1060 } else {
1061 tracing::error!("no message handler");
1062 }
1063 } else {
1064 tracing::info!("connection closed");
1065 break;
1066 }
1067 }
1068 }
1069 }
1070
1071 drop(foreground_message_handlers);
1072 tracing::info!("signing out");
1073 if let Err(error) = connection_lost(session, teardown, executor).await {
1074 tracing::error!(?error, "error signing out");
1075 }
1076
1077 }.instrument(span)
1078 }
1079
    /// Performs the initial handshake for a freshly accepted connection.
    ///
    /// Always sends `Hello` (carrying the assigned peer id) and, if present,
    /// reports `connection_id` through `send_connection_id`. Then, depending on
    /// the principal:
    ///
    /// * user / impersonated user: on the first-ever connection sends
    ///   `ShowContacts` and persists `connected_once`; registers the connection
    ///   in the pool and sends the initial contacts; subscribes to channels when
    ///   the client version calls for it; sends dev-server project status, any
    ///   pending incoming call, and finally calls `update_user_contacts`.
    /// * dev server: evicts any other connection registered for the same dev
    ///   server (after sending it `ShutdownDevServer`), registers this one, sends
    ///   `DevServerInstructions`, and sends the owning user's dev-server project
    ///   status.
    ///
    /// # Errors
    /// Returns the first error from a peer send, a pool removal, or a database
    /// query. Earlier steps are not rolled back.
    async fn send_initial_client_update(
        &self,
        connection_id: ConnectionId,
        principal: &Principal,
        zed_version: ZedVersion,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        session: &Session,
    ) -> Result<()> {
        self.peer.send(
            connection_id,
            proto::Hello {
                peer_id: Some(connection_id.into()),
            },
        )?;
        tracing::info!("sent hello message");
        // The receiver may already be gone; a failed send is deliberately ignored.
        if let Some(send_connection_id) = send_connection_id.take() {
            let _ = send_connection_id.send(connection_id);
        }

        match principal {
            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
                // First-ever connection for this user: show the contacts panel,
                // then persist the flag so the check fails next time.
                if !user.connected_once {
                    self.peer.send(connection_id, proto::ShowContacts {})?;
                    self.app_state
                        .db
                        .set_user_connected_once(user.id, true)
                        .await?;
                }

                // Both queries are independent, so they run concurrently.
                let (contacts, dev_server_projects) = future::try_join(
                    self.app_state.db.get_contacts(user.id),
                    self.app_state.db.dev_server_projects_update(user.id),
                )
                .await?;

                // Scoped so the pool lock is released before the `.await`s below.
                // The contacts update is built from the pool while it is locked.
                {
                    let mut pool = self.connection_pool.lock();
                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
                    self.peer.send(
                        connection_id,
                        build_initial_contacts_update(contacts, &pool),
                    )?;
                }

                if should_auto_subscribe_to_channels(zed_version) {
                    subscribe_user_to_channels(user.id, session).await?;
                }

                send_dev_server_projects_update(user.id, dev_server_projects, session).await;

                if let Some(incoming_call) =
                    self.app_state.db.incoming_call_for_user(user.id).await?
                {
                    self.peer.send(connection_id, incoming_call)?;
                }

                update_user_contacts(user.id, &session).await?;
            }
            Principal::DevServer(dev_server) => {
                // Only one connection per dev server is kept: a previously
                // registered connection is told to shut down and removed before
                // this one is added. Scoped to release the lock before awaiting.
                {
                    let mut pool = self.connection_pool.lock();
                    if let Some(stale_connection_id) = pool.dev_server_connection_id(dev_server.id)
                    {
                        self.peer.send(
                            stale_connection_id,
                            proto::ShutdownDevServer {
                                reason: Some(
                                    "another dev server connected with the same token".to_string(),
                                ),
                            },
                        )?;
                        pool.remove_connection(stale_connection_id)?;
                    };
                    pool.add_dev_server(connection_id, dev_server.id, zed_version);
                }

                let projects = self
                    .app_state
                    .db
                    .get_projects_for_dev_server(dev_server.id)
                    .await?;
                self.peer
                    .send(connection_id, proto::DevServerInstructions { projects })?;

                // Status is sent for the user who owns this dev server.
                let status = self
                    .app_state
                    .db
                    .dev_server_projects_update(dev_server.user_id)
                    .await?;
                send_dev_server_projects_update(dev_server.user_id, status, &session).await;
            }
        }

        Ok(())
    }
1175
1176 pub async fn invite_code_redeemed(
1177 self: &Arc<Self>,
1178 inviter_id: UserId,
1179 invitee_id: UserId,
1180 ) -> Result<()> {
1181 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1182 if let Some(code) = &user.invite_code {
1183 let pool = self.connection_pool.lock();
1184 let invitee_contact = contact_for_user(invitee_id, false, &pool);
1185 for connection_id in pool.user_connection_ids(inviter_id) {
1186 self.peer.send(
1187 connection_id,
1188 proto::UpdateContacts {
1189 contacts: vec![invitee_contact.clone()],
1190 ..Default::default()
1191 },
1192 )?;
1193 self.peer.send(
1194 connection_id,
1195 proto::UpdateInviteInfo {
1196 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1197 count: user.invite_count as u32,
1198 },
1199 )?;
1200 }
1201 }
1202 }
1203 Ok(())
1204 }
1205
1206 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1207 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1208 if let Some(invite_code) = &user.invite_code {
1209 let pool = self.connection_pool.lock();
1210 for connection_id in pool.user_connection_ids(user_id) {
1211 self.peer.send(
1212 connection_id,
1213 proto::UpdateInviteInfo {
1214 url: format!(
1215 "{}{}",
1216 self.app_state.config.invite_link_prefix, invite_code
1217 ),
1218 count: user.invite_count as u32,
1219 },
1220 )?;
1221 }
1222 }
1223 }
1224 Ok(())
1225 }
1226
1227 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1228 ServerSnapshot {
1229 connection_pool: ConnectionPoolGuard {
1230 guard: self.connection_pool.lock(),
1231 _not_send: PhantomData,
1232 },
1233 peer: &self.peer,
1234 }
1235 }
1236}
1237
1238impl<'a> Deref for ConnectionPoolGuard<'a> {
1239 type Target = ConnectionPool;
1240
1241 fn deref(&self) -> &Self::Target {
1242 &self.guard
1243 }
1244}
1245
1246impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1247 fn deref_mut(&mut self) -> &mut Self::Target {
1248 &mut self.guard
1249 }
1250}
1251
1252impl<'a> Drop for ConnectionPoolGuard<'a> {
1253 fn drop(&mut self) {
1254 #[cfg(test)]
1255 self.check_invariants();
1256 }
1257}
1258
1259fn broadcast<F>(
1260 sender_id: Option<ConnectionId>,
1261 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1262 mut f: F,
1263) where
1264 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1265{
1266 for receiver_id in receiver_ids {
1267 if Some(receiver_id) != sender_id {
1268 if let Err(error) = f(receiver_id) {
1269 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1270 }
1271 }
1272 }
1273}
1274
1275pub struct ProtocolVersion(u32);
1276
1277impl Header for ProtocolVersion {
1278 fn name() -> &'static HeaderName {
1279 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1280 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1281 }
1282
1283 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1284 where
1285 Self: Sized,
1286 I: Iterator<Item = &'i axum::http::HeaderValue>,
1287 {
1288 let version = values
1289 .next()
1290 .ok_or_else(axum::headers::Error::invalid)?
1291 .to_str()
1292 .map_err(|_| axum::headers::Error::invalid())?
1293 .parse()
1294 .map_err(|_| axum::headers::Error::invalid())?;
1295 Ok(Self(version))
1296 }
1297
1298 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1299 values.extend([self.0.to_string().parse().unwrap()]);
1300 }
1301}
1302
1303pub struct AppVersionHeader(SemanticVersion);
1304impl Header for AppVersionHeader {
1305 fn name() -> &'static HeaderName {
1306 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1307 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1308 }
1309
1310 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1311 where
1312 Self: Sized,
1313 I: Iterator<Item = &'i axum::http::HeaderValue>,
1314 {
1315 let version = values
1316 .next()
1317 .ok_or_else(axum::headers::Error::invalid)?
1318 .to_str()
1319 .map_err(|_| axum::headers::Error::invalid())?
1320 .parse()
1321 .map_err(|_| axum::headers::Error::invalid())?;
1322 Ok(Self(version))
1323 }
1324
1325 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1326 values.extend([self.0.to_string().parse().unwrap()]);
1327 }
1328}
1329
1330pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1331 Router::new()
1332 .route("/rpc", get(handle_websocket_request))
1333 .layer(
1334 ServiceBuilder::new()
1335 .layer(Extension(server.app_state.clone()))
1336 .layer(middleware::from_fn(auth::validate_header)),
1337 )
1338 .route("/metrics", get(handle_metrics))
1339 .layer(Extension(server))
1340}
1341
1342pub async fn handle_websocket_request(
1343 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1344 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1345 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1346 Extension(server): Extension<Arc<Server>>,
1347 Extension(principal): Extension<Principal>,
1348 ws: WebSocketUpgrade,
1349) -> axum::response::Response {
1350 if protocol_version != rpc::PROTOCOL_VERSION {
1351 return (
1352 StatusCode::UPGRADE_REQUIRED,
1353 "client must be upgraded".to_string(),
1354 )
1355 .into_response();
1356 }
1357
1358 let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1359 return (
1360 StatusCode::UPGRADE_REQUIRED,
1361 "no version header found".to_string(),
1362 )
1363 .into_response();
1364 };
1365
1366 if !version.can_collaborate() {
1367 return (
1368 StatusCode::UPGRADE_REQUIRED,
1369 "client must be upgraded".to_string(),
1370 )
1371 .into_response();
1372 }
1373
1374 let socket_address = socket_address.to_string();
1375 ws.on_upgrade(move |socket| {
1376 let socket = socket
1377 .map_ok(to_tungstenite_message)
1378 .err_into()
1379 .with(|message| async move { Ok(to_axum_message(message)) });
1380 let connection = Connection::new(Box::pin(socket));
1381 async move {
1382 server
1383 .handle_connection(
1384 connection,
1385 socket_address,
1386 principal,
1387 version,
1388 None,
1389 Executor::Production,
1390 )
1391 .await;
1392 }
1393 })
1394}
1395
1396pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1397 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1398 let connections_metric = CONNECTIONS_METRIC
1399 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1400
1401 let connections = server
1402 .connection_pool
1403 .lock()
1404 .connections()
1405 .filter(|connection| !connection.admin)
1406 .count();
1407 connections_metric.set(connections as _);
1408
1409 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1410 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1411 register_int_gauge!(
1412 "shared_projects",
1413 "number of open projects with one or more guests"
1414 )
1415 .unwrap()
1416 });
1417
1418 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1419 shared_projects_metric.set(shared_projects as _);
1420
1421 let encoder = prometheus::TextEncoder::new();
1422 let metric_families = prometheus::gather();
1423 let encoded_metrics = encoder
1424 .encode_to_string(&metric_families)
1425 .map_err(|err| anyhow!("{}", err))?;
1426 Ok(encoded_metrics)
1427}
1428
/// Cleans up after a connection goes away.
///
/// Immediately:
/// * disconnects the connection from the peer;
/// * removes it from the connection pool;
/// * marks it lost in the database (an error there is only logged).
///
/// It then races a `RECONNECT_TIMEOUT` sleep against a change on `teardown`.
/// If `teardown` changes first, it returns without further cleanup. If the
/// timeout elapses first, the principal's remaining resources are released:
/// * users: leave the room and channel buffers (errors logged); if no other
///   connection of the user is online, decline any pending call; then call
///   `update_user_contacts`;
/// * dev servers: delegate to `lost_dev_server_connection`.
///
/// # Errors
/// Propagates failures from pool removal, `update_user_contacts`, and
/// `lost_dev_server_connection`.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // Best effort: a database failure here is logged, not propagated.
    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls its branches in the order written; whichever
    // completes first wins and the other future is dropped.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            match &session.principal {
                Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
                    let session = session.for_user().unwrap();

                    log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
                    leave_room_for_session(&session, session.connection_id).await.trace_err();
                    leave_channel_buffers_for_session(&session)
                        .await
                        .trace_err();

                    // Only decline a pending call if no other connection of this user is online.
                    if !session
                        .connection_pool()
                        .await
                        .is_user_online(session.user_id())
                    {
                        let db = session.db().await;
                        if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                            room_updated(&room, &session.peer);
                        }
                    }

                    update_user_contacts(session.user_id(), &session).await?;
                },
                Principal::DevServer(_) => {
                    lost_dev_server_connection(&session.for_dev_server().unwrap()).await?;
                },
            }
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1483
1484/// Acknowledges a ping from a client, used to keep the connection alive.
1485async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1486 response.send(proto::Ack {})?;
1487 Ok(())
1488}
1489
1490/// Creates a new room for calling (outside of channels)
1491async fn create_room(
1492 _request: proto::CreateRoom,
1493 response: Response<proto::CreateRoom>,
1494 session: UserSession,
1495) -> Result<()> {
1496 let live_kit_room = nanoid::nanoid!(30);
1497
1498 let live_kit_connection_info = util::maybe!(async {
1499 let live_kit = session.live_kit_client.as_ref();
1500 let live_kit = live_kit?;
1501 let user_id = session.user_id().to_string();
1502
1503 let token = live_kit
1504 .room_token(&live_kit_room, &user_id.to_string())
1505 .trace_err()?;
1506
1507 Some(proto::LiveKitConnectionInfo {
1508 server_url: live_kit.url().into(),
1509 token,
1510 can_publish: true,
1511 })
1512 })
1513 .await;
1514
1515 let room = session
1516 .db()
1517 .await
1518 .create_room(session.user_id(), session.connection_id, &live_kit_room)
1519 .await?;
1520
1521 response.send(proto::CreateRoomResponse {
1522 room: Some(room.clone()),
1523 live_kit_connection_info,
1524 })?;
1525
1526 update_user_contacts(session.user_id(), &session).await?;
1527 Ok(())
1528}
1529
1530/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1531async fn join_room(
1532 request: proto::JoinRoom,
1533 response: Response<proto::JoinRoom>,
1534 session: UserSession,
1535) -> Result<()> {
1536 let room_id = RoomId::from_proto(request.id);
1537
1538 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1539
1540 if let Some(channel_id) = channel_id {
1541 return join_channel_internal(channel_id, Box::new(response), session).await;
1542 }
1543
1544 let joined_room = {
1545 let room = session
1546 .db()
1547 .await
1548 .join_room(room_id, session.user_id(), session.connection_id)
1549 .await?;
1550 room_updated(&room.room, &session.peer);
1551 room.into_inner()
1552 };
1553
1554 for connection_id in session
1555 .connection_pool()
1556 .await
1557 .user_connection_ids(session.user_id())
1558 {
1559 session
1560 .peer
1561 .send(
1562 connection_id,
1563 proto::CallCanceled {
1564 room_id: room_id.to_proto(),
1565 },
1566 )
1567 .trace_err();
1568 }
1569
1570 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1571 if let Some(token) = live_kit
1572 .room_token(
1573 &joined_room.room.live_kit_room,
1574 &session.user_id().to_string(),
1575 )
1576 .trace_err()
1577 {
1578 Some(proto::LiveKitConnectionInfo {
1579 server_url: live_kit.url().into(),
1580 token,
1581 can_publish: true,
1582 })
1583 } else {
1584 None
1585 }
1586 } else {
1587 None
1588 };
1589
1590 response.send(proto::JoinRoomResponse {
1591 room: Some(joined_room.room),
1592 channel_id: None,
1593 live_kit_connection_info,
1594 })?;
1595
1596 update_user_contacts(session.user_id(), &session).await?;
1597 Ok(())
1598}
1599
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// The database rejoins this user and connection to the room. The reply to
/// the client contains:
/// * the room;
/// * the reshared projects with their collaborators;
/// * the rejoined projects.
///
/// Then:
/// * the room is passed to `room_updated`;
/// * each reshared project's collaborators are told this user's peer id
///   changed, and receive the project's worktrees;
/// * rejoined projects are re-streamed via `notify_rejoined_projects`;
/// * if the room belongs to a channel, `channel_updated` is called;
/// * `update_user_contacts` is called last.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: UserSession,
) -> Result<()> {
    let room;
    let channel;
    // The inner scope bounds the lifetime of `rejoined_room` (a guard from the
    // database layer, unwrapped by `into_inner` below). It ends before
    // `channel_updated` takes the connection pool.
    // NOTE(review): presumably the guard holds a room-level lock; confirm in db.
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell every collaborator that this user's peer id moved from the
            // old connection to the current one. Send failures are only logged.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the project's worktrees to every collaborator except this connection.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1691
1692fn notify_rejoined_projects(
1693 rejoined_projects: &mut Vec<RejoinedProject>,
1694 session: &UserSession,
1695) -> Result<()> {
1696 for project in rejoined_projects.iter() {
1697 for collaborator in &project.collaborators {
1698 session
1699 .peer
1700 .send(
1701 collaborator.connection_id,
1702 proto::UpdateProjectCollaborator {
1703 project_id: project.id.to_proto(),
1704 old_peer_id: Some(project.old_connection_id.into()),
1705 new_peer_id: Some(session.connection_id.into()),
1706 },
1707 )
1708 .trace_err();
1709 }
1710 }
1711
1712 for project in rejoined_projects {
1713 for worktree in mem::take(&mut project.worktrees) {
1714 #[cfg(any(test, feature = "test-support"))]
1715 const MAX_CHUNK_SIZE: usize = 2;
1716 #[cfg(not(any(test, feature = "test-support")))]
1717 const MAX_CHUNK_SIZE: usize = 256;
1718
1719 // Stream this worktree's entries.
1720 let message = proto::UpdateWorktree {
1721 project_id: project.id.to_proto(),
1722 worktree_id: worktree.id,
1723 abs_path: worktree.abs_path.clone(),
1724 root_name: worktree.root_name,
1725 updated_entries: worktree.updated_entries,
1726 removed_entries: worktree.removed_entries,
1727 scan_id: worktree.scan_id,
1728 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1729 updated_repositories: worktree.updated_repositories,
1730 removed_repositories: worktree.removed_repositories,
1731 };
1732 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1733 session.peer.send(session.connection_id, update.clone())?;
1734 }
1735
1736 // Stream this worktree's diagnostics.
1737 for summary in worktree.diagnostic_summaries {
1738 session.peer.send(
1739 session.connection_id,
1740 proto::UpdateDiagnosticSummary {
1741 project_id: project.id.to_proto(),
1742 worktree_id: worktree.id,
1743 summary: Some(summary),
1744 },
1745 )?;
1746 }
1747
1748 for settings_file in worktree.settings_files {
1749 session.peer.send(
1750 session.connection_id,
1751 proto::UpdateWorktreeSettings {
1752 project_id: project.id.to_proto(),
1753 worktree_id: worktree.id,
1754 path: settings_file.path,
1755 content: Some(settings_file.content),
1756 },
1757 )?;
1758 }
1759 }
1760
1761 for language_server in &project.language_servers {
1762 session.peer.send(
1763 session.connection_id,
1764 proto::UpdateLanguageServer {
1765 project_id: project.id.to_proto(),
1766 language_server_id: language_server.id,
1767 variant: Some(
1768 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1769 proto::LspDiskBasedDiagnosticsUpdated {},
1770 ),
1771 ),
1772 },
1773 )?;
1774 }
1775 }
1776 Ok(())
1777}
1778
1779/// leave room disconnects from the room.
1780async fn leave_room(
1781 _: proto::LeaveRoom,
1782 response: Response<proto::LeaveRoom>,
1783 session: UserSession,
1784) -> Result<()> {
1785 leave_room_for_session(&session, session.connection_id).await?;
1786 response.send(proto::Ack {})?;
1787 Ok(())
1788}
1789
1790/// Updates the permissions of someone else in the room.
1791async fn set_room_participant_role(
1792 request: proto::SetRoomParticipantRole,
1793 response: Response<proto::SetRoomParticipantRole>,
1794 session: UserSession,
1795) -> Result<()> {
1796 let user_id = UserId::from_proto(request.user_id);
1797 let role = ChannelRole::from(request.role());
1798
1799 let (live_kit_room, can_publish) = {
1800 let room = session
1801 .db()
1802 .await
1803 .set_room_participant_role(
1804 session.user_id(),
1805 RoomId::from_proto(request.room_id),
1806 user_id,
1807 role,
1808 )
1809 .await?;
1810
1811 let live_kit_room = room.live_kit_room.clone();
1812 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1813 room_updated(&room, &session.peer);
1814 (live_kit_room, can_publish)
1815 };
1816
1817 if let Some(live_kit) = session.live_kit_client.as_ref() {
1818 live_kit
1819 .update_participant(
1820 live_kit_room.clone(),
1821 request.user_id.to_string(),
1822 live_kit_server::proto::ParticipantPermission {
1823 can_subscribe: true,
1824 can_publish,
1825 can_publish_data: can_publish,
1826 hidden: false,
1827 recorder: false,
1828 },
1829 )
1830 .await
1831 .trace_err();
1832 }
1833
1834 response.send(proto::Ack {})?;
1835 Ok(())
1836}
1837
1838/// Call someone else into the current room
1839async fn call(
1840 request: proto::Call,
1841 response: Response<proto::Call>,
1842 session: UserSession,
1843) -> Result<()> {
1844 let room_id = RoomId::from_proto(request.room_id);
1845 let calling_user_id = session.user_id();
1846 let calling_connection_id = session.connection_id;
1847 let called_user_id = UserId::from_proto(request.called_user_id);
1848 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1849 if !session
1850 .db()
1851 .await
1852 .has_contact(calling_user_id, called_user_id)
1853 .await?
1854 {
1855 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1856 }
1857
1858 let incoming_call = {
1859 let (room, incoming_call) = &mut *session
1860 .db()
1861 .await
1862 .call(
1863 room_id,
1864 calling_user_id,
1865 calling_connection_id,
1866 called_user_id,
1867 initial_project_id,
1868 )
1869 .await?;
1870 room_updated(&room, &session.peer);
1871 mem::take(incoming_call)
1872 };
1873 update_user_contacts(called_user_id, &session).await?;
1874
1875 let mut calls = session
1876 .connection_pool()
1877 .await
1878 .user_connection_ids(called_user_id)
1879 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1880 .collect::<FuturesUnordered<_>>();
1881
1882 while let Some(call_response) = calls.next().await {
1883 match call_response.as_ref() {
1884 Ok(_) => {
1885 response.send(proto::Ack {})?;
1886 return Ok(());
1887 }
1888 Err(_) => {
1889 call_response.trace_err();
1890 }
1891 }
1892 }
1893
1894 {
1895 let room = session
1896 .db()
1897 .await
1898 .call_failed(room_id, called_user_id)
1899 .await?;
1900 room_updated(&room, &session.peer);
1901 }
1902 update_user_contacts(called_user_id, &session).await?;
1903
1904 Err(anyhow!("failed to ring user"))?
1905}
1906
1907/// Cancel an outgoing call.
1908async fn cancel_call(
1909 request: proto::CancelCall,
1910 response: Response<proto::CancelCall>,
1911 session: UserSession,
1912) -> Result<()> {
1913 let called_user_id = UserId::from_proto(request.called_user_id);
1914 let room_id = RoomId::from_proto(request.room_id);
1915 {
1916 let room = session
1917 .db()
1918 .await
1919 .cancel_call(room_id, session.connection_id, called_user_id)
1920 .await?;
1921 room_updated(&room, &session.peer);
1922 }
1923
1924 for connection_id in session
1925 .connection_pool()
1926 .await
1927 .user_connection_ids(called_user_id)
1928 {
1929 session
1930 .peer
1931 .send(
1932 connection_id,
1933 proto::CallCanceled {
1934 room_id: room_id.to_proto(),
1935 },
1936 )
1937 .trace_err();
1938 }
1939 response.send(proto::Ack {})?;
1940
1941 update_user_contacts(called_user_id, &session).await?;
1942 Ok(())
1943}
1944
1945/// Decline an incoming call.
1946async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1947 let room_id = RoomId::from_proto(message.room_id);
1948 {
1949 let room = session
1950 .db()
1951 .await
1952 .decline_call(Some(room_id), session.user_id())
1953 .await?
1954 .ok_or_else(|| anyhow!("failed to decline call"))?;
1955 room_updated(&room, &session.peer);
1956 }
1957
1958 for connection_id in session
1959 .connection_pool()
1960 .await
1961 .user_connection_ids(session.user_id())
1962 {
1963 session
1964 .peer
1965 .send(
1966 connection_id,
1967 proto::CallCanceled {
1968 room_id: room_id.to_proto(),
1969 },
1970 )
1971 .trace_err();
1972 }
1973 update_user_contacts(session.user_id(), &session).await?;
1974 Ok(())
1975}
1976
1977/// Updates other participants in the room with your current location.
1978async fn update_participant_location(
1979 request: proto::UpdateParticipantLocation,
1980 response: Response<proto::UpdateParticipantLocation>,
1981 session: UserSession,
1982) -> Result<()> {
1983 let room_id = RoomId::from_proto(request.room_id);
1984 let location = request
1985 .location
1986 .ok_or_else(|| anyhow!("invalid location"))?;
1987
1988 let db = session.db().await;
1989 let room = db
1990 .update_room_participant_location(room_id, session.connection_id, location)
1991 .await?;
1992
1993 room_updated(&room, &session.peer);
1994 response.send(proto::Ack {})?;
1995 Ok(())
1996}
1997
1998/// Share a project into the room.
1999async fn share_project(
2000 request: proto::ShareProject,
2001 response: Response<proto::ShareProject>,
2002 session: UserSession,
2003) -> Result<()> {
2004 let (project_id, room) = &*session
2005 .db()
2006 .await
2007 .share_project(
2008 RoomId::from_proto(request.room_id),
2009 session.connection_id,
2010 &request.worktrees,
2011 request
2012 .dev_server_project_id
2013 .map(|id| DevServerProjectId::from_proto(id)),
2014 )
2015 .await?;
2016 response.send(proto::ShareProjectResponse {
2017 project_id: project_id.to_proto(),
2018 })?;
2019 room_updated(&room, &session.peer);
2020
2021 Ok(())
2022}
2023
2024/// Unshare a project from the room.
2025async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
2026 let project_id = ProjectId::from_proto(message.project_id);
2027 unshare_project_internal(
2028 project_id,
2029 session.connection_id,
2030 session.user_id(),
2031 &session,
2032 )
2033 .await
2034}
2035
2036async fn unshare_project_internal(
2037 project_id: ProjectId,
2038 connection_id: ConnectionId,
2039 user_id: Option<UserId>,
2040 session: &Session,
2041) -> Result<()> {
2042 let delete = {
2043 let room_guard = session
2044 .db()
2045 .await
2046 .unshare_project(project_id, connection_id, user_id)
2047 .await?;
2048
2049 let (delete, room, guest_connection_ids) = &*room_guard;
2050
2051 let message = proto::UnshareProject {
2052 project_id: project_id.to_proto(),
2053 };
2054
2055 broadcast(
2056 Some(connection_id),
2057 guest_connection_ids.iter().copied(),
2058 |conn_id| session.peer.send(conn_id, message.clone()),
2059 );
2060 if let Some(room) = room {
2061 room_updated(room, &session.peer);
2062 }
2063
2064 *delete
2065 };
2066
2067 if delete {
2068 let db = session.db().await;
2069 db.delete_project(project_id).await?;
2070 }
2071
2072 Ok(())
2073}
2074
2075/// DevServer makes a project available online
2076async fn share_dev_server_project(
2077 request: proto::ShareDevServerProject,
2078 response: Response<proto::ShareDevServerProject>,
2079 session: DevServerSession,
2080) -> Result<()> {
2081 let (dev_server_project, user_id, status) = session
2082 .db()
2083 .await
2084 .share_dev_server_project(
2085 DevServerProjectId::from_proto(request.dev_server_project_id),
2086 session.dev_server_id(),
2087 session.connection_id,
2088 &request.worktrees,
2089 )
2090 .await?;
2091 let Some(project_id) = dev_server_project.project_id else {
2092 return Err(anyhow!("failed to share remote project"))?;
2093 };
2094
2095 send_dev_server_projects_update(user_id, status, &session).await;
2096
2097 response.send(proto::ShareProjectResponse { project_id })?;
2098
2099 Ok(())
2100}
2101
/// Join someone else's shared project.
///
/// Registers this connection as a collaborator in the database, then hands the
/// project and assigned replica id to `join_project_internal`.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: UserSession,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    // The guard returned by `join_project` is lifetime-extended by `&mut *`, so
    // `project` and `replica_id` stay valid through `join_project_internal`.
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    // Release the database handle before streaming the project to the client.
    drop(db);
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
2120
/// Response handle for any RPC that replies with a `proto::JoinProjectResponse`.
///
/// Lets `join_project_internal` serve both `JoinProject` and `JoinHostedProject`.
trait JoinProjectInternalResponse {
    /// Sends `result` as the reply to the originating request, consuming the handle.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
2124impl JoinProjectInternalResponse for Response<proto::JoinProject> {
2125 fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
2126 Response::<proto::JoinProject>::send(self, result)
2127 }
2128}
2129impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
2130 fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
2131 Response::<proto::JoinHostedProject>::send(self, result)
2132 }
2133}
2134
2135fn join_project_internal(
2136 response: impl JoinProjectInternalResponse,
2137 session: UserSession,
2138 project: &mut Project,
2139 replica_id: &ReplicaId,
2140) -> Result<()> {
2141 let collaborators = project
2142 .collaborators
2143 .iter()
2144 .filter(|collaborator| collaborator.connection_id != session.connection_id)
2145 .map(|collaborator| collaborator.to_proto())
2146 .collect::<Vec<_>>();
2147 let project_id = project.id;
2148 let guest_user_id = session.user_id();
2149
2150 let worktrees = project
2151 .worktrees
2152 .iter()
2153 .map(|(id, worktree)| proto::WorktreeMetadata {
2154 id: *id,
2155 root_name: worktree.root_name.clone(),
2156 visible: worktree.visible,
2157 abs_path: worktree.abs_path.clone(),
2158 })
2159 .collect::<Vec<_>>();
2160
2161 let add_project_collaborator = proto::AddProjectCollaborator {
2162 project_id: project_id.to_proto(),
2163 collaborator: Some(proto::Collaborator {
2164 peer_id: Some(session.connection_id.into()),
2165 replica_id: replica_id.0 as u32,
2166 user_id: guest_user_id.to_proto(),
2167 }),
2168 };
2169
2170 for collaborator in &collaborators {
2171 session
2172 .peer
2173 .send(
2174 collaborator.peer_id.unwrap().into(),
2175 add_project_collaborator.clone(),
2176 )
2177 .trace_err();
2178 }
2179
2180 // First, we send the metadata associated with each worktree.
2181 response.send(proto::JoinProjectResponse {
2182 project_id: project.id.0 as u64,
2183 worktrees: worktrees.clone(),
2184 replica_id: replica_id.0 as u32,
2185 collaborators: collaborators.clone(),
2186 language_servers: project.language_servers.clone(),
2187 role: project.role.into(),
2188 dev_server_project_id: project
2189 .dev_server_project_id
2190 .map(|dev_server_project_id| dev_server_project_id.0 as u64),
2191 })?;
2192
2193 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2194 #[cfg(any(test, feature = "test-support"))]
2195 const MAX_CHUNK_SIZE: usize = 2;
2196 #[cfg(not(any(test, feature = "test-support")))]
2197 const MAX_CHUNK_SIZE: usize = 256;
2198
2199 // Stream this worktree's entries.
2200 let message = proto::UpdateWorktree {
2201 project_id: project_id.to_proto(),
2202 worktree_id,
2203 abs_path: worktree.abs_path.clone(),
2204 root_name: worktree.root_name,
2205 updated_entries: worktree.entries,
2206 removed_entries: Default::default(),
2207 scan_id: worktree.scan_id,
2208 is_last_update: worktree.scan_id == worktree.completed_scan_id,
2209 updated_repositories: worktree.repository_entries.into_values().collect(),
2210 removed_repositories: Default::default(),
2211 };
2212 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
2213 session.peer.send(session.connection_id, update.clone())?;
2214 }
2215
2216 // Stream this worktree's diagnostics.
2217 for summary in worktree.diagnostic_summaries {
2218 session.peer.send(
2219 session.connection_id,
2220 proto::UpdateDiagnosticSummary {
2221 project_id: project_id.to_proto(),
2222 worktree_id: worktree.id,
2223 summary: Some(summary),
2224 },
2225 )?;
2226 }
2227
2228 for settings_file in worktree.settings_files {
2229 session.peer.send(
2230 session.connection_id,
2231 proto::UpdateWorktreeSettings {
2232 project_id: project_id.to_proto(),
2233 worktree_id: worktree.id,
2234 path: settings_file.path,
2235 content: Some(settings_file.content),
2236 },
2237 )?;
2238 }
2239 }
2240
2241 for language_server in &project.language_servers {
2242 session.peer.send(
2243 session.connection_id,
2244 proto::UpdateLanguageServer {
2245 project_id: project_id.to_proto(),
2246 language_server_id: language_server.id,
2247 variant: Some(
2248 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2249 proto::LspDiskBasedDiagnosticsUpdated {},
2250 ),
2251 ),
2252 },
2253 )?;
2254 }
2255
2256 Ok(())
2257}
2258
2259/// Leave someone elses shared project.
2260async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2261 let sender_id = session.connection_id;
2262 let project_id = ProjectId::from_proto(request.project_id);
2263 let db = session.db().await;
2264 if db.is_hosted_project(project_id).await? {
2265 let project = db.leave_hosted_project(project_id, sender_id).await?;
2266 project_left(&project, &session);
2267 return Ok(());
2268 }
2269
2270 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2271 tracing::info!(
2272 %project_id,
2273 "leave project"
2274 );
2275
2276 project_left(&project, &session);
2277 if let Some(room) = room {
2278 room_updated(&room, &session.peer);
2279 }
2280
2281 Ok(())
2282}
2283
2284async fn join_hosted_project(
2285 request: proto::JoinHostedProject,
2286 response: Response<proto::JoinHostedProject>,
2287 session: UserSession,
2288) -> Result<()> {
2289 let (mut project, replica_id) = session
2290 .db()
2291 .await
2292 .join_hosted_project(
2293 ProjectId(request.project_id as i32),
2294 session.user_id(),
2295 session.connection_id,
2296 )
2297 .await?;
2298
2299 join_project_internal(response, session, &mut project, &replica_id)
2300}
2301
2302async fn create_dev_server_project(
2303 request: proto::CreateDevServerProject,
2304 response: Response<proto::CreateDevServerProject>,
2305 session: UserSession,
2306) -> Result<()> {
2307 let dev_server_id = DevServerId(request.dev_server_id as i32);
2308 let dev_server_connection_id = session
2309 .connection_pool()
2310 .await
2311 .dev_server_connection_id(dev_server_id);
2312 let Some(dev_server_connection_id) = dev_server_connection_id else {
2313 Err(ErrorCode::DevServerOffline
2314 .message("Cannot create a remote project when the dev server is offline".to_string())
2315 .anyhow())?
2316 };
2317
2318 let path = request.path.clone();
2319 //Check that the path exists on the dev server
2320 session
2321 .peer
2322 .forward_request(
2323 session.connection_id,
2324 dev_server_connection_id,
2325 proto::ValidateDevServerProjectRequest { path: path.clone() },
2326 )
2327 .await?;
2328
2329 let (dev_server_project, update) = session
2330 .db()
2331 .await
2332 .create_dev_server_project(
2333 DevServerId(request.dev_server_id as i32),
2334 &request.path,
2335 session.user_id(),
2336 )
2337 .await?;
2338
2339 let projects = session
2340 .db()
2341 .await
2342 .get_projects_for_dev_server(dev_server_project.dev_server_id)
2343 .await?;
2344
2345 session.peer.send(
2346 dev_server_connection_id,
2347 proto::DevServerInstructions { projects },
2348 )?;
2349
2350 send_dev_server_projects_update(session.user_id(), update, &session).await;
2351
2352 response.send(proto::CreateDevServerProjectResponse {
2353 dev_server_project: Some(dev_server_project.to_proto(None)),
2354 })?;
2355 Ok(())
2356}
2357
2358async fn create_dev_server(
2359 request: proto::CreateDevServer,
2360 response: Response<proto::CreateDevServer>,
2361 session: UserSession,
2362) -> Result<()> {
2363 let access_token = auth::random_token();
2364 let hashed_access_token = auth::hash_access_token(&access_token);
2365
2366 if request.name.is_empty() {
2367 return Err(proto::ErrorCode::Forbidden
2368 .message("Dev server name cannot be empty".to_string())
2369 .anyhow())?;
2370 }
2371
2372 let (dev_server, status) = session
2373 .db()
2374 .await
2375 .create_dev_server(
2376 &request.name,
2377 request.ssh_connection_string.as_deref(),
2378 &hashed_access_token,
2379 session.user_id(),
2380 )
2381 .await?;
2382
2383 send_dev_server_projects_update(session.user_id(), status, &session).await;
2384
2385 response.send(proto::CreateDevServerResponse {
2386 dev_server_id: dev_server.id.0 as u64,
2387 access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2388 name: request.name,
2389 })?;
2390 Ok(())
2391}
2392
2393async fn regenerate_dev_server_token(
2394 request: proto::RegenerateDevServerToken,
2395 response: Response<proto::RegenerateDevServerToken>,
2396 session: UserSession,
2397) -> Result<()> {
2398 let dev_server_id = DevServerId(request.dev_server_id as i32);
2399 let access_token = auth::random_token();
2400 let hashed_access_token = auth::hash_access_token(&access_token);
2401
2402 let connection_id = session
2403 .connection_pool()
2404 .await
2405 .dev_server_connection_id(dev_server_id);
2406 if let Some(connection_id) = connection_id {
2407 shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2408 session.peer.send(
2409 connection_id,
2410 proto::ShutdownDevServer {
2411 reason: Some("dev server token was regenerated".to_string()),
2412 },
2413 )?;
2414 let _ = remove_dev_server_connection(dev_server_id, &session).await;
2415 }
2416
2417 let status = session
2418 .db()
2419 .await
2420 .update_dev_server_token(dev_server_id, &hashed_access_token, session.user_id())
2421 .await?;
2422
2423 send_dev_server_projects_update(session.user_id(), status, &session).await;
2424
2425 response.send(proto::RegenerateDevServerTokenResponse {
2426 dev_server_id: dev_server_id.to_proto(),
2427 access_token: auth::generate_dev_server_token(dev_server_id.0 as usize, access_token),
2428 })?;
2429 Ok(())
2430}
2431
2432async fn rename_dev_server(
2433 request: proto::RenameDevServer,
2434 response: Response<proto::RenameDevServer>,
2435 session: UserSession,
2436) -> Result<()> {
2437 if request.name.trim().is_empty() {
2438 return Err(proto::ErrorCode::Forbidden
2439 .message("Dev server name cannot be empty".to_string())
2440 .anyhow())?;
2441 }
2442
2443 let dev_server_id = DevServerId(request.dev_server_id as i32);
2444 let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2445 if dev_server.user_id != session.user_id() {
2446 return Err(anyhow!(ErrorCode::Forbidden))?;
2447 }
2448
2449 let status = session
2450 .db()
2451 .await
2452 .rename_dev_server(
2453 dev_server_id,
2454 &request.name,
2455 request.ssh_connection_string.as_deref(),
2456 session.user_id(),
2457 )
2458 .await?;
2459
2460 send_dev_server_projects_update(session.user_id(), status, &session).await;
2461
2462 response.send(proto::Ack {})?;
2463 Ok(())
2464}
2465
2466async fn delete_dev_server(
2467 request: proto::DeleteDevServer,
2468 response: Response<proto::DeleteDevServer>,
2469 session: UserSession,
2470) -> Result<()> {
2471 let dev_server_id = DevServerId(request.dev_server_id as i32);
2472 let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2473 if dev_server.user_id != session.user_id() {
2474 return Err(anyhow!(ErrorCode::Forbidden))?;
2475 }
2476
2477 let connection_id = session
2478 .connection_pool()
2479 .await
2480 .dev_server_connection_id(dev_server_id);
2481 if let Some(connection_id) = connection_id {
2482 shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2483 session.peer.send(
2484 connection_id,
2485 proto::ShutdownDevServer {
2486 reason: Some("dev server was deleted".to_string()),
2487 },
2488 )?;
2489 let _ = remove_dev_server_connection(dev_server_id, &session).await;
2490 }
2491
2492 let status = session
2493 .db()
2494 .await
2495 .delete_dev_server(dev_server_id, session.user_id())
2496 .await?;
2497
2498 send_dev_server_projects_update(session.user_id(), status, &session).await;
2499
2500 response.send(proto::Ack {})?;
2501 Ok(())
2502}
2503
2504async fn delete_dev_server_project(
2505 request: proto::DeleteDevServerProject,
2506 response: Response<proto::DeleteDevServerProject>,
2507 session: UserSession,
2508) -> Result<()> {
2509 let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2510 let dev_server_project = session
2511 .db()
2512 .await
2513 .get_dev_server_project(dev_server_project_id)
2514 .await?;
2515
2516 let dev_server = session
2517 .db()
2518 .await
2519 .get_dev_server(dev_server_project.dev_server_id)
2520 .await?;
2521 if dev_server.user_id != session.user_id() {
2522 return Err(anyhow!(ErrorCode::Forbidden))?;
2523 }
2524
2525 let dev_server_connection_id = session
2526 .connection_pool()
2527 .await
2528 .dev_server_connection_id(dev_server.id);
2529
2530 if let Some(dev_server_connection_id) = dev_server_connection_id {
2531 let project = session
2532 .db()
2533 .await
2534 .find_dev_server_project(dev_server_project_id)
2535 .await;
2536 if let Ok(project) = project {
2537 unshare_project_internal(
2538 project.id,
2539 dev_server_connection_id,
2540 Some(session.user_id()),
2541 &session,
2542 )
2543 .await?;
2544 }
2545 }
2546
2547 let (projects, status) = session
2548 .db()
2549 .await
2550 .delete_dev_server_project(dev_server_project_id, dev_server.id, session.user_id())
2551 .await?;
2552
2553 if let Some(dev_server_connection_id) = dev_server_connection_id {
2554 session.peer.send(
2555 dev_server_connection_id,
2556 proto::DevServerInstructions { projects },
2557 )?;
2558 }
2559
2560 send_dev_server_projects_update(session.user_id(), status, &session).await;
2561
2562 response.send(proto::Ack {})?;
2563 Ok(())
2564}
2565
2566async fn rejoin_dev_server_projects(
2567 request: proto::RejoinRemoteProjects,
2568 response: Response<proto::RejoinRemoteProjects>,
2569 session: UserSession,
2570) -> Result<()> {
2571 let mut rejoined_projects = {
2572 let db = session.db().await;
2573 db.rejoin_dev_server_projects(
2574 &request.rejoined_projects,
2575 session.user_id(),
2576 session.0.connection_id,
2577 )
2578 .await?
2579 };
2580 notify_rejoined_projects(&mut rejoined_projects, &session)?;
2581
2582 response.send(proto::RejoinRemoteProjectsResponse {
2583 rejoined_projects: rejoined_projects
2584 .into_iter()
2585 .map(|project| project.to_proto())
2586 .collect(),
2587 })
2588}
2589
2590async fn reconnect_dev_server(
2591 request: proto::ReconnectDevServer,
2592 response: Response<proto::ReconnectDevServer>,
2593 session: DevServerSession,
2594) -> Result<()> {
2595 let reshared_projects = {
2596 let db = session.db().await;
2597 db.reshare_dev_server_projects(
2598 &request.reshared_projects,
2599 session.dev_server_id(),
2600 session.0.connection_id,
2601 )
2602 .await?
2603 };
2604
2605 for project in &reshared_projects {
2606 for collaborator in &project.collaborators {
2607 session
2608 .peer
2609 .send(
2610 collaborator.connection_id,
2611 proto::UpdateProjectCollaborator {
2612 project_id: project.id.to_proto(),
2613 old_peer_id: Some(project.old_connection_id.into()),
2614 new_peer_id: Some(session.connection_id.into()),
2615 },
2616 )
2617 .trace_err();
2618 }
2619
2620 broadcast(
2621 Some(session.connection_id),
2622 project
2623 .collaborators
2624 .iter()
2625 .map(|collaborator| collaborator.connection_id),
2626 |connection_id| {
2627 session.peer.forward_send(
2628 session.connection_id,
2629 connection_id,
2630 proto::UpdateProject {
2631 project_id: project.id.to_proto(),
2632 worktrees: project.worktrees.clone(),
2633 },
2634 )
2635 },
2636 );
2637 }
2638
2639 response.send(proto::ReconnectDevServerResponse {
2640 reshared_projects: reshared_projects
2641 .iter()
2642 .map(|project| proto::ResharedProject {
2643 id: project.id.to_proto(),
2644 collaborators: project
2645 .collaborators
2646 .iter()
2647 .map(|collaborator| collaborator.to_proto())
2648 .collect(),
2649 })
2650 .collect(),
2651 })?;
2652
2653 Ok(())
2654}
2655
2656async fn shutdown_dev_server(
2657 _: proto::ShutdownDevServer,
2658 response: Response<proto::ShutdownDevServer>,
2659 session: DevServerSession,
2660) -> Result<()> {
2661 response.send(proto::Ack {})?;
2662 shutdown_dev_server_internal(session.dev_server_id(), session.connection_id, &session).await?;
2663 remove_dev_server_connection(session.dev_server_id(), &session).await
2664}
2665
2666async fn shutdown_dev_server_internal(
2667 dev_server_id: DevServerId,
2668 connection_id: ConnectionId,
2669 session: &Session,
2670) -> Result<()> {
2671 let (dev_server_projects, dev_server) = {
2672 let db = session.db().await;
2673 let dev_server_projects = db.get_projects_for_dev_server(dev_server_id).await?;
2674 let dev_server = db.get_dev_server(dev_server_id).await?;
2675 (dev_server_projects, dev_server)
2676 };
2677
2678 for project_id in dev_server_projects.iter().filter_map(|p| p.project_id) {
2679 unshare_project_internal(
2680 ProjectId::from_proto(project_id),
2681 connection_id,
2682 None,
2683 session,
2684 )
2685 .await?;
2686 }
2687
2688 session
2689 .connection_pool()
2690 .await
2691 .set_dev_server_offline(dev_server_id);
2692
2693 let status = session
2694 .db()
2695 .await
2696 .dev_server_projects_update(dev_server.user_id)
2697 .await?;
2698 send_dev_server_projects_update(dev_server.user_id, status, &session).await;
2699
2700 Ok(())
2701}
2702
2703async fn remove_dev_server_connection(dev_server_id: DevServerId, session: &Session) -> Result<()> {
2704 let dev_server_connection = session
2705 .connection_pool()
2706 .await
2707 .dev_server_connection_id(dev_server_id);
2708
2709 if let Some(dev_server_connection) = dev_server_connection {
2710 session
2711 .connection_pool()
2712 .await
2713 .remove_connection(dev_server_connection)?;
2714 }
2715 Ok(())
2716}
2717
2718/// Updates other participants with changes to the project
2719async fn update_project(
2720 request: proto::UpdateProject,
2721 response: Response<proto::UpdateProject>,
2722 session: Session,
2723) -> Result<()> {
2724 let project_id = ProjectId::from_proto(request.project_id);
2725 let (room, guest_connection_ids) = &*session
2726 .db()
2727 .await
2728 .update_project(project_id, session.connection_id, &request.worktrees)
2729 .await?;
2730 broadcast(
2731 Some(session.connection_id),
2732 guest_connection_ids.iter().copied(),
2733 |connection_id| {
2734 session
2735 .peer
2736 .forward_send(session.connection_id, connection_id, request.clone())
2737 },
2738 );
2739 if let Some(room) = room {
2740 room_updated(&room, &session.peer);
2741 }
2742 response.send(proto::Ack {})?;
2743
2744 Ok(())
2745}
2746
2747/// Updates other participants with changes to the worktree
2748async fn update_worktree(
2749 request: proto::UpdateWorktree,
2750 response: Response<proto::UpdateWorktree>,
2751 session: Session,
2752) -> Result<()> {
2753 let guest_connection_ids = session
2754 .db()
2755 .await
2756 .update_worktree(&request, session.connection_id)
2757 .await?;
2758
2759 broadcast(
2760 Some(session.connection_id),
2761 guest_connection_ids.iter().copied(),
2762 |connection_id| {
2763 session
2764 .peer
2765 .forward_send(session.connection_id, connection_id, request.clone())
2766 },
2767 );
2768 response.send(proto::Ack {})?;
2769 Ok(())
2770}
2771
2772/// Updates other participants with changes to the diagnostics
2773async fn update_diagnostic_summary(
2774 message: proto::UpdateDiagnosticSummary,
2775 session: Session,
2776) -> Result<()> {
2777 let guest_connection_ids = session
2778 .db()
2779 .await
2780 .update_diagnostic_summary(&message, session.connection_id)
2781 .await?;
2782
2783 broadcast(
2784 Some(session.connection_id),
2785 guest_connection_ids.iter().copied(),
2786 |connection_id| {
2787 session
2788 .peer
2789 .forward_send(session.connection_id, connection_id, message.clone())
2790 },
2791 );
2792
2793 Ok(())
2794}
2795
2796/// Updates other participants with changes to the worktree settings
2797async fn update_worktree_settings(
2798 message: proto::UpdateWorktreeSettings,
2799 session: Session,
2800) -> Result<()> {
2801 let guest_connection_ids = session
2802 .db()
2803 .await
2804 .update_worktree_settings(&message, session.connection_id)
2805 .await?;
2806
2807 broadcast(
2808 Some(session.connection_id),
2809 guest_connection_ids.iter().copied(),
2810 |connection_id| {
2811 session
2812 .peer
2813 .forward_send(session.connection_id, connection_id, message.clone())
2814 },
2815 );
2816
2817 Ok(())
2818}
2819
2820/// Notify other participants that a language server has started.
2821async fn start_language_server(
2822 request: proto::StartLanguageServer,
2823 session: Session,
2824) -> Result<()> {
2825 let guest_connection_ids = session
2826 .db()
2827 .await
2828 .start_language_server(&request, session.connection_id)
2829 .await?;
2830
2831 broadcast(
2832 Some(session.connection_id),
2833 guest_connection_ids.iter().copied(),
2834 |connection_id| {
2835 session
2836 .peer
2837 .forward_send(session.connection_id, connection_id, request.clone())
2838 },
2839 );
2840 Ok(())
2841}
2842
2843/// Notify other participants that a language server has changed.
2844async fn update_language_server(
2845 request: proto::UpdateLanguageServer,
2846 session: Session,
2847) -> Result<()> {
2848 let project_id = ProjectId::from_proto(request.project_id);
2849 let project_connection_ids = session
2850 .db()
2851 .await
2852 .project_connection_ids(project_id, session.connection_id, true)
2853 .await?;
2854 broadcast(
2855 Some(session.connection_id),
2856 project_connection_ids.iter().copied(),
2857 |connection_id| {
2858 session
2859 .peer
2860 .forward_send(session.connection_id, connection_id, request.clone())
2861 },
2862 );
2863 Ok(())
2864}
2865
2866/// forward a project request to the host. These requests should be read only
2867/// as guests are allowed to send them.
2868async fn forward_read_only_project_request<T>(
2869 request: T,
2870 response: Response<T>,
2871 session: UserSession,
2872) -> Result<()>
2873where
2874 T: EntityMessage + RequestMessage,
2875{
2876 let project_id = ProjectId::from_proto(request.remote_entity_id());
2877 let host_connection_id = session
2878 .db()
2879 .await
2880 .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2881 .await?;
2882 let payload = session
2883 .peer
2884 .forward_request(session.connection_id, host_connection_id, request)
2885 .await?;
2886 response.send(payload)?;
2887 Ok(())
2888}
2889
2890/// forward a project request to the dev server. Only allowed
2891/// if it's your dev server.
2892async fn forward_project_request_for_owner<T>(
2893 request: T,
2894 response: Response<T>,
2895 session: UserSession,
2896) -> Result<()>
2897where
2898 T: EntityMessage + RequestMessage,
2899{
2900 let project_id = ProjectId::from_proto(request.remote_entity_id());
2901
2902 let host_connection_id = session
2903 .db()
2904 .await
2905 .host_for_owner_project_request(project_id, session.connection_id, session.user_id())
2906 .await?;
2907 let payload = session
2908 .peer
2909 .forward_request(session.connection_id, host_connection_id, request)
2910 .await?;
2911 response.send(payload)?;
2912 Ok(())
2913}
2914
2915/// forward a project request to the host. These requests are disallowed
2916/// for guests.
2917async fn forward_mutating_project_request<T>(
2918 request: T,
2919 response: Response<T>,
2920 session: UserSession,
2921) -> Result<()>
2922where
2923 T: EntityMessage + RequestMessage,
2924{
2925 let project_id = ProjectId::from_proto(request.remote_entity_id());
2926
2927 let host_connection_id = session
2928 .db()
2929 .await
2930 .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
2931 .await?;
2932 let payload = session
2933 .peer
2934 .forward_request(session.connection_id, host_connection_id, request)
2935 .await?;
2936 response.send(payload)?;
2937 Ok(())
2938}
2939
2940/// forward a project request to the host. These requests are disallowed
2941/// for guests.
2942async fn forward_versioned_mutating_project_request<T>(
2943 request: T,
2944 response: Response<T>,
2945 session: UserSession,
2946) -> Result<()>
2947where
2948 T: EntityMessage + RequestMessage + VersionedMessage,
2949{
2950 let project_id = ProjectId::from_proto(request.remote_entity_id());
2951
2952 let host_connection_id = session
2953 .db()
2954 .await
2955 .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
2956 .await?;
2957 if let Some(host_version) = session
2958 .connection_pool()
2959 .await
2960 .connection(host_connection_id)
2961 .map(|c| c.zed_version)
2962 {
2963 if let Some(min_required_version) = request.required_host_version() {
2964 if min_required_version > host_version {
2965 return Err(anyhow!(ErrorCode::RemoteUpgradeRequired
2966 .with_tag("required", &min_required_version.to_string())))?;
2967 }
2968 }
2969 }
2970
2971 let payload = session
2972 .peer
2973 .forward_request(session.connection_id, host_connection_id, request)
2974 .await?;
2975 response.send(payload)?;
2976 Ok(())
2977}
2978
2979/// Notify other participants that a new buffer has been created
2980async fn create_buffer_for_peer(
2981 request: proto::CreateBufferForPeer,
2982 session: Session,
2983) -> Result<()> {
2984 session
2985 .db()
2986 .await
2987 .check_user_is_project_host(
2988 ProjectId::from_proto(request.project_id),
2989 session.connection_id,
2990 )
2991 .await?;
2992 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2993 session
2994 .peer
2995 .forward_send(session.connection_id, peer_id.into(), request)?;
2996 Ok(())
2997}
2998
2999/// Notify other participants that a buffer has been updated. This is
3000/// allowed for guests as long as the update is limited to selections.
3001async fn update_buffer(
3002 request: proto::UpdateBuffer,
3003 response: Response<proto::UpdateBuffer>,
3004 session: Session,
3005) -> Result<()> {
3006 let project_id = ProjectId::from_proto(request.project_id);
3007 let mut capability = Capability::ReadOnly;
3008
3009 for op in request.operations.iter() {
3010 match op.variant {
3011 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
3012 Some(_) => capability = Capability::ReadWrite,
3013 }
3014 }
3015
3016 let host = {
3017 let guard = session
3018 .db()
3019 .await
3020 .connections_for_buffer_update(
3021 project_id,
3022 session.principal_id(),
3023 session.connection_id,
3024 capability,
3025 )
3026 .await?;
3027
3028 let (host, guests) = &*guard;
3029
3030 broadcast(
3031 Some(session.connection_id),
3032 guests.clone(),
3033 |connection_id| {
3034 session
3035 .peer
3036 .forward_send(session.connection_id, connection_id, request.clone())
3037 },
3038 );
3039
3040 *host
3041 };
3042
3043 if host != session.connection_id {
3044 session
3045 .peer
3046 .forward_request(session.connection_id, host, request.clone())
3047 .await?;
3048 }
3049
3050 response.send(proto::Ack {})?;
3051 Ok(())
3052}
3053
3054/// Notify other participants that a project has been updated.
3055async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
3056 request: T,
3057 session: Session,
3058) -> Result<()> {
3059 let project_id = ProjectId::from_proto(request.remote_entity_id());
3060 let project_connection_ids = session
3061 .db()
3062 .await
3063 .project_connection_ids(project_id, session.connection_id, false)
3064 .await?;
3065
3066 broadcast(
3067 Some(session.connection_id),
3068 project_connection_ids.iter().copied(),
3069 |connection_id| {
3070 session
3071 .peer
3072 .forward_send(session.connection_id, connection_id, request.clone())
3073 },
3074 );
3075 Ok(())
3076}
3077
3078/// Start following another user in a call.
3079async fn follow(
3080 request: proto::Follow,
3081 response: Response<proto::Follow>,
3082 session: UserSession,
3083) -> Result<()> {
3084 let room_id = RoomId::from_proto(request.room_id);
3085 let project_id = request.project_id.map(ProjectId::from_proto);
3086 let leader_id = request
3087 .leader_id
3088 .ok_or_else(|| anyhow!("invalid leader id"))?
3089 .into();
3090 let follower_id = session.connection_id;
3091
3092 session
3093 .db()
3094 .await
3095 .check_room_participants(room_id, leader_id, session.connection_id)
3096 .await?;
3097
3098 let response_payload = session
3099 .peer
3100 .forward_request(session.connection_id, leader_id, request)
3101 .await?;
3102 response.send(response_payload)?;
3103
3104 if let Some(project_id) = project_id {
3105 let room = session
3106 .db()
3107 .await
3108 .follow(room_id, project_id, leader_id, follower_id)
3109 .await?;
3110 room_updated(&room, &session.peer);
3111 }
3112
3113 Ok(())
3114}
3115
3116/// Stop following another user in a call.
3117async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
3118 let room_id = RoomId::from_proto(request.room_id);
3119 let project_id = request.project_id.map(ProjectId::from_proto);
3120 let leader_id = request
3121 .leader_id
3122 .ok_or_else(|| anyhow!("invalid leader id"))?
3123 .into();
3124 let follower_id = session.connection_id;
3125
3126 session
3127 .db()
3128 .await
3129 .check_room_participants(room_id, leader_id, session.connection_id)
3130 .await?;
3131
3132 session
3133 .peer
3134 .forward_send(session.connection_id, leader_id, request)?;
3135
3136 if let Some(project_id) = project_id {
3137 let room = session
3138 .db()
3139 .await
3140 .unfollow(room_id, project_id, leader_id, follower_id)
3141 .await?;
3142 room_updated(&room, &session.peer);
3143 }
3144
3145 Ok(())
3146}
3147
3148/// Notify everyone following you of your current location.
3149async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
3150 let room_id = RoomId::from_proto(request.room_id);
3151 let database = session.db.lock().await;
3152
3153 let connection_ids = if let Some(project_id) = request.project_id {
3154 let project_id = ProjectId::from_proto(project_id);
3155 database
3156 .project_connection_ids(project_id, session.connection_id, true)
3157 .await?
3158 } else {
3159 database
3160 .room_connection_ids(room_id, session.connection_id)
3161 .await?
3162 };
3163
3164 // For now, don't send view update messages back to that view's current leader.
3165 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
3166 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
3167 _ => None,
3168 });
3169
3170 for connection_id in connection_ids.iter().cloned() {
3171 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
3172 session
3173 .peer
3174 .forward_send(session.connection_id, connection_id, request.clone())?;
3175 }
3176 }
3177 Ok(())
3178}
3179
3180/// Get public data about users.
3181async fn get_users(
3182 request: proto::GetUsers,
3183 response: Response<proto::GetUsers>,
3184 session: Session,
3185) -> Result<()> {
3186 let user_ids = request
3187 .user_ids
3188 .into_iter()
3189 .map(UserId::from_proto)
3190 .collect();
3191 let users = session
3192 .db()
3193 .await
3194 .get_users_by_ids(user_ids)
3195 .await?
3196 .into_iter()
3197 .map(|user| proto::User {
3198 id: user.id.to_proto(),
3199 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3200 github_login: user.github_login,
3201 })
3202 .collect();
3203 response.send(proto::UsersResponse { users })?;
3204 Ok(())
3205}
3206
3207/// Search for users (to invite) buy Github login
3208async fn fuzzy_search_users(
3209 request: proto::FuzzySearchUsers,
3210 response: Response<proto::FuzzySearchUsers>,
3211 session: UserSession,
3212) -> Result<()> {
3213 let query = request.query;
3214 let users = match query.len() {
3215 0 => vec![],
3216 1 | 2 => session
3217 .db()
3218 .await
3219 .get_user_by_github_login(&query)
3220 .await?
3221 .into_iter()
3222 .collect(),
3223 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
3224 };
3225 let users = users
3226 .into_iter()
3227 .filter(|user| user.id != session.user_id())
3228 .map(|user| proto::User {
3229 id: user.id.to_proto(),
3230 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3231 github_login: user.github_login,
3232 })
3233 .collect();
3234 response.send(proto::UsersResponse { users })?;
3235 Ok(())
3236}
3237
3238/// Send a contact request to another user.
3239async fn request_contact(
3240 request: proto::RequestContact,
3241 response: Response<proto::RequestContact>,
3242 session: UserSession,
3243) -> Result<()> {
3244 let requester_id = session.user_id();
3245 let responder_id = UserId::from_proto(request.responder_id);
3246 if requester_id == responder_id {
3247 return Err(anyhow!("cannot add yourself as a contact"))?;
3248 }
3249
3250 let notifications = session
3251 .db()
3252 .await
3253 .send_contact_request(requester_id, responder_id)
3254 .await?;
3255
3256 // Update outgoing contact requests of requester
3257 let mut update = proto::UpdateContacts::default();
3258 update.outgoing_requests.push(responder_id.to_proto());
3259 for connection_id in session
3260 .connection_pool()
3261 .await
3262 .user_connection_ids(requester_id)
3263 {
3264 session.peer.send(connection_id, update.clone())?;
3265 }
3266
3267 // Update incoming contact requests of responder
3268 let mut update = proto::UpdateContacts::default();
3269 update
3270 .incoming_requests
3271 .push(proto::IncomingContactRequest {
3272 requester_id: requester_id.to_proto(),
3273 });
3274 let connection_pool = session.connection_pool().await;
3275 for connection_id in connection_pool.user_connection_ids(responder_id) {
3276 session.peer.send(connection_id, update.clone())?;
3277 }
3278
3279 send_notifications(&connection_pool, &session.peer, notifications);
3280
3281 response.send(proto::Ack {})?;
3282 Ok(())
3283}
3284
3285/// Accept or decline a contact request
3286async fn respond_to_contact_request(
3287 request: proto::RespondToContactRequest,
3288 response: Response<proto::RespondToContactRequest>,
3289 session: UserSession,
3290) -> Result<()> {
3291 let responder_id = session.user_id();
3292 let requester_id = UserId::from_proto(request.requester_id);
3293 let db = session.db().await;
3294 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
3295 db.dismiss_contact_notification(responder_id, requester_id)
3296 .await?;
3297 } else {
3298 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
3299
3300 let notifications = db
3301 .respond_to_contact_request(responder_id, requester_id, accept)
3302 .await?;
3303 let requester_busy = db.is_user_busy(requester_id).await?;
3304 let responder_busy = db.is_user_busy(responder_id).await?;
3305
3306 let pool = session.connection_pool().await;
3307 // Update responder with new contact
3308 let mut update = proto::UpdateContacts::default();
3309 if accept {
3310 update
3311 .contacts
3312 .push(contact_for_user(requester_id, requester_busy, &pool));
3313 }
3314 update
3315 .remove_incoming_requests
3316 .push(requester_id.to_proto());
3317 for connection_id in pool.user_connection_ids(responder_id) {
3318 session.peer.send(connection_id, update.clone())?;
3319 }
3320
3321 // Update requester with new contact
3322 let mut update = proto::UpdateContacts::default();
3323 if accept {
3324 update
3325 .contacts
3326 .push(contact_for_user(responder_id, responder_busy, &pool));
3327 }
3328 update
3329 .remove_outgoing_requests
3330 .push(responder_id.to_proto());
3331
3332 for connection_id in pool.user_connection_ids(requester_id) {
3333 session.peer.send(connection_id, update.clone())?;
3334 }
3335
3336 send_notifications(&pool, &session.peer, notifications);
3337 }
3338
3339 response.send(proto::Ack {})?;
3340 Ok(())
3341}
3342
3343/// Remove a contact.
3344async fn remove_contact(
3345 request: proto::RemoveContact,
3346 response: Response<proto::RemoveContact>,
3347 session: UserSession,
3348) -> Result<()> {
3349 let requester_id = session.user_id();
3350 let responder_id = UserId::from_proto(request.user_id);
3351 let db = session.db().await;
3352 let (contact_accepted, deleted_notification_id) =
3353 db.remove_contact(requester_id, responder_id).await?;
3354
3355 let pool = session.connection_pool().await;
3356 // Update outgoing contact requests of requester
3357 let mut update = proto::UpdateContacts::default();
3358 if contact_accepted {
3359 update.remove_contacts.push(responder_id.to_proto());
3360 } else {
3361 update
3362 .remove_outgoing_requests
3363 .push(responder_id.to_proto());
3364 }
3365 for connection_id in pool.user_connection_ids(requester_id) {
3366 session.peer.send(connection_id, update.clone())?;
3367 }
3368
3369 // Update incoming contact requests of responder
3370 let mut update = proto::UpdateContacts::default();
3371 if contact_accepted {
3372 update.remove_contacts.push(requester_id.to_proto());
3373 } else {
3374 update
3375 .remove_incoming_requests
3376 .push(requester_id.to_proto());
3377 }
3378 for connection_id in pool.user_connection_ids(responder_id) {
3379 session.peer.send(connection_id, update.clone())?;
3380 if let Some(notification_id) = deleted_notification_id {
3381 session.peer.send(
3382 connection_id,
3383 proto::DeleteNotification {
3384 notification_id: notification_id.to_proto(),
3385 },
3386 )?;
3387 }
3388 }
3389
3390 response.send(proto::Ack {})?;
3391 Ok(())
3392}
3393
3394fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
3395 version.0.minor() < 139
3396}
3397
3398async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
3399 subscribe_user_to_channels(
3400 session.user_id().ok_or_else(|| anyhow!("must be a user"))?,
3401 &session,
3402 )
3403 .await?;
3404 Ok(())
3405}
3406
3407async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
3408 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3409 let mut pool = session.connection_pool().await;
3410 for membership in &channels_for_user.channel_memberships {
3411 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3412 }
3413 session.peer.send(
3414 session.connection_id,
3415 build_update_user_channels(&channels_for_user),
3416 )?;
3417 session.peer.send(
3418 session.connection_id,
3419 build_channels_update(channels_for_user),
3420 )?;
3421 Ok(())
3422}
3423
3424/// Creates a new channel.
3425async fn create_channel(
3426 request: proto::CreateChannel,
3427 response: Response<proto::CreateChannel>,
3428 session: UserSession,
3429) -> Result<()> {
3430 let db = session.db().await;
3431
3432 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
3433 let (channel, membership) = db
3434 .create_channel(&request.name, parent_id, session.user_id())
3435 .await?;
3436
3437 let root_id = channel.root_id();
3438 let channel = Channel::from_model(channel);
3439
3440 response.send(proto::CreateChannelResponse {
3441 channel: Some(channel.to_proto()),
3442 parent_id: request.parent_id,
3443 })?;
3444
3445 let mut connection_pool = session.connection_pool().await;
3446 if let Some(membership) = membership {
3447 connection_pool.subscribe_to_channel(
3448 membership.user_id,
3449 membership.channel_id,
3450 membership.role,
3451 );
3452 let update = proto::UpdateUserChannels {
3453 channel_memberships: vec![proto::ChannelMembership {
3454 channel_id: membership.channel_id.to_proto(),
3455 role: membership.role.into(),
3456 }],
3457 ..Default::default()
3458 };
3459 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3460 session.peer.send(connection_id, update.clone())?;
3461 }
3462 }
3463
3464 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3465 if !role.can_see_channel(channel.visibility) {
3466 continue;
3467 }
3468
3469 let update = proto::UpdateChannels {
3470 channels: vec![channel.to_proto()],
3471 ..Default::default()
3472 };
3473 session.peer.send(connection_id, update.clone())?;
3474 }
3475
3476 Ok(())
3477}
3478
3479/// Delete a channel
3480async fn delete_channel(
3481 request: proto::DeleteChannel,
3482 response: Response<proto::DeleteChannel>,
3483 session: UserSession,
3484) -> Result<()> {
3485 let db = session.db().await;
3486
3487 let channel_id = request.channel_id;
3488 let (root_channel, removed_channels) = db
3489 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3490 .await?;
3491 response.send(proto::Ack {})?;
3492
3493 // Notify members of removed channels
3494 let mut update = proto::UpdateChannels::default();
3495 update
3496 .delete_channels
3497 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3498
3499 let connection_pool = session.connection_pool().await;
3500 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3501 session.peer.send(connection_id, update.clone())?;
3502 }
3503
3504 Ok(())
3505}
3506
3507/// Invite someone to join a channel.
3508async fn invite_channel_member(
3509 request: proto::InviteChannelMember,
3510 response: Response<proto::InviteChannelMember>,
3511 session: UserSession,
3512) -> Result<()> {
3513 let db = session.db().await;
3514 let channel_id = ChannelId::from_proto(request.channel_id);
3515 let invitee_id = UserId::from_proto(request.user_id);
3516 let InviteMemberResult {
3517 channel,
3518 notifications,
3519 } = db
3520 .invite_channel_member(
3521 channel_id,
3522 invitee_id,
3523 session.user_id(),
3524 request.role().into(),
3525 )
3526 .await?;
3527
3528 let update = proto::UpdateChannels {
3529 channel_invitations: vec![channel.to_proto()],
3530 ..Default::default()
3531 };
3532
3533 let connection_pool = session.connection_pool().await;
3534 for connection_id in connection_pool.user_connection_ids(invitee_id) {
3535 session.peer.send(connection_id, update.clone())?;
3536 }
3537
3538 send_notifications(&connection_pool, &session.peer, notifications);
3539
3540 response.send(proto::Ack {})?;
3541 Ok(())
3542}
3543
3544/// remove someone from a channel
3545async fn remove_channel_member(
3546 request: proto::RemoveChannelMember,
3547 response: Response<proto::RemoveChannelMember>,
3548 session: UserSession,
3549) -> Result<()> {
3550 let db = session.db().await;
3551 let channel_id = ChannelId::from_proto(request.channel_id);
3552 let member_id = UserId::from_proto(request.user_id);
3553
3554 let RemoveChannelMemberResult {
3555 membership_update,
3556 notification_id,
3557 } = db
3558 .remove_channel_member(channel_id, member_id, session.user_id())
3559 .await?;
3560
3561 let mut connection_pool = session.connection_pool().await;
3562 notify_membership_updated(
3563 &mut connection_pool,
3564 membership_update,
3565 member_id,
3566 &session.peer,
3567 );
3568 for connection_id in connection_pool.user_connection_ids(member_id) {
3569 if let Some(notification_id) = notification_id {
3570 session
3571 .peer
3572 .send(
3573 connection_id,
3574 proto::DeleteNotification {
3575 notification_id: notification_id.to_proto(),
3576 },
3577 )
3578 .trace_err();
3579 }
3580 }
3581
3582 response.send(proto::Ack {})?;
3583 Ok(())
3584}
3585
3586/// Toggle the channel between public and private.
3587/// Care is taken to maintain the invariant that public channels only descend from public channels,
3588/// (though members-only channels can appear at any point in the hierarchy).
3589async fn set_channel_visibility(
3590 request: proto::SetChannelVisibility,
3591 response: Response<proto::SetChannelVisibility>,
3592 session: UserSession,
3593) -> Result<()> {
3594 let db = session.db().await;
3595 let channel_id = ChannelId::from_proto(request.channel_id);
3596 let visibility = request.visibility().into();
3597
3598 let channel_model = db
3599 .set_channel_visibility(channel_id, visibility, session.user_id())
3600 .await?;
3601 let root_id = channel_model.root_id();
3602 let channel = Channel::from_model(channel_model);
3603
3604 let mut connection_pool = session.connection_pool().await;
3605 for (user_id, role) in connection_pool
3606 .channel_user_ids(root_id)
3607 .collect::<Vec<_>>()
3608 .into_iter()
3609 {
3610 let update = if role.can_see_channel(channel.visibility) {
3611 connection_pool.subscribe_to_channel(user_id, channel_id, role);
3612 proto::UpdateChannels {
3613 channels: vec![channel.to_proto()],
3614 ..Default::default()
3615 }
3616 } else {
3617 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3618 proto::UpdateChannels {
3619 delete_channels: vec![channel.id.to_proto()],
3620 ..Default::default()
3621 }
3622 };
3623
3624 for connection_id in connection_pool.user_connection_ids(user_id) {
3625 session.peer.send(connection_id, update.clone())?;
3626 }
3627 }
3628
3629 response.send(proto::Ack {})?;
3630 Ok(())
3631}
3632
3633/// Alter the role for a user in the channel.
3634async fn set_channel_member_role(
3635 request: proto::SetChannelMemberRole,
3636 response: Response<proto::SetChannelMemberRole>,
3637 session: UserSession,
3638) -> Result<()> {
3639 let db = session.db().await;
3640 let channel_id = ChannelId::from_proto(request.channel_id);
3641 let member_id = UserId::from_proto(request.user_id);
3642 let result = db
3643 .set_channel_member_role(
3644 channel_id,
3645 session.user_id(),
3646 member_id,
3647 request.role().into(),
3648 )
3649 .await?;
3650
3651 match result {
3652 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3653 let mut connection_pool = session.connection_pool().await;
3654 notify_membership_updated(
3655 &mut connection_pool,
3656 membership_update,
3657 member_id,
3658 &session.peer,
3659 )
3660 }
3661 db::SetMemberRoleResult::InviteUpdated(channel) => {
3662 let update = proto::UpdateChannels {
3663 channel_invitations: vec![channel.to_proto()],
3664 ..Default::default()
3665 };
3666
3667 for connection_id in session
3668 .connection_pool()
3669 .await
3670 .user_connection_ids(member_id)
3671 {
3672 session.peer.send(connection_id, update.clone())?;
3673 }
3674 }
3675 }
3676
3677 response.send(proto::Ack {})?;
3678 Ok(())
3679}
3680
3681/// Change the name of a channel
3682async fn rename_channel(
3683 request: proto::RenameChannel,
3684 response: Response<proto::RenameChannel>,
3685 session: UserSession,
3686) -> Result<()> {
3687 let db = session.db().await;
3688 let channel_id = ChannelId::from_proto(request.channel_id);
3689 let channel_model = db
3690 .rename_channel(channel_id, session.user_id(), &request.name)
3691 .await?;
3692 let root_id = channel_model.root_id();
3693 let channel = Channel::from_model(channel_model);
3694
3695 response.send(proto::RenameChannelResponse {
3696 channel: Some(channel.to_proto()),
3697 })?;
3698
3699 let connection_pool = session.connection_pool().await;
3700 let update = proto::UpdateChannels {
3701 channels: vec![channel.to_proto()],
3702 ..Default::default()
3703 };
3704 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3705 if role.can_see_channel(channel.visibility) {
3706 session.peer.send(connection_id, update.clone())?;
3707 }
3708 }
3709
3710 Ok(())
3711}
3712
3713/// Move a channel to a new parent.
3714async fn move_channel(
3715 request: proto::MoveChannel,
3716 response: Response<proto::MoveChannel>,
3717 session: UserSession,
3718) -> Result<()> {
3719 let channel_id = ChannelId::from_proto(request.channel_id);
3720 let to = ChannelId::from_proto(request.to);
3721
3722 let (root_id, channels) = session
3723 .db()
3724 .await
3725 .move_channel(channel_id, to, session.user_id())
3726 .await?;
3727
3728 let connection_pool = session.connection_pool().await;
3729 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3730 let channels = channels
3731 .iter()
3732 .filter_map(|channel| {
3733 if role.can_see_channel(channel.visibility) {
3734 Some(channel.to_proto())
3735 } else {
3736 None
3737 }
3738 })
3739 .collect::<Vec<_>>();
3740 if channels.is_empty() {
3741 continue;
3742 }
3743
3744 let update = proto::UpdateChannels {
3745 channels,
3746 ..Default::default()
3747 };
3748
3749 session.peer.send(connection_id, update.clone())?;
3750 }
3751
3752 response.send(Ack {})?;
3753 Ok(())
3754}
3755
3756/// Get the list of channel members
3757async fn get_channel_members(
3758 request: proto::GetChannelMembers,
3759 response: Response<proto::GetChannelMembers>,
3760 session: UserSession,
3761) -> Result<()> {
3762 let db = session.db().await;
3763 let channel_id = ChannelId::from_proto(request.channel_id);
3764 let limit = if request.limit == 0 {
3765 u16::MAX as u64
3766 } else {
3767 request.limit
3768 };
3769 let (members, users) = db
3770 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3771 .await?;
3772 response.send(proto::GetChannelMembersResponse { members, users })?;
3773 Ok(())
3774}
3775
3776/// Accept or decline a channel invitation.
3777async fn respond_to_channel_invite(
3778 request: proto::RespondToChannelInvite,
3779 response: Response<proto::RespondToChannelInvite>,
3780 session: UserSession,
3781) -> Result<()> {
3782 let db = session.db().await;
3783 let channel_id = ChannelId::from_proto(request.channel_id);
3784 let RespondToChannelInvite {
3785 membership_update,
3786 notifications,
3787 } = db
3788 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3789 .await?;
3790
3791 let mut connection_pool = session.connection_pool().await;
3792 if let Some(membership_update) = membership_update {
3793 notify_membership_updated(
3794 &mut connection_pool,
3795 membership_update,
3796 session.user_id(),
3797 &session.peer,
3798 );
3799 } else {
3800 let update = proto::UpdateChannels {
3801 remove_channel_invitations: vec![channel_id.to_proto()],
3802 ..Default::default()
3803 };
3804
3805 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3806 session.peer.send(connection_id, update.clone())?;
3807 }
3808 };
3809
3810 send_notifications(&connection_pool, &session.peer, notifications);
3811
3812 response.send(proto::Ack {})?;
3813
3814 Ok(())
3815}
3816
3817/// Join the channels' room
3818async fn join_channel(
3819 request: proto::JoinChannel,
3820 response: Response<proto::JoinChannel>,
3821 session: UserSession,
3822) -> Result<()> {
3823 let channel_id = ChannelId::from_proto(request.channel_id);
3824 join_channel_internal(channel_id, Box::new(response), session).await
3825}
3826
/// A response handle that can answer a channel join with a
/// `proto::JoinRoomResponse`.
///
/// It is implemented for both `Response<proto::JoinChannel>` and
/// `Response<proto::JoinRoom>`, so `join_channel_internal` can serve either
/// request type.
trait JoinChannelInternalResponse {
    /// Reply to the originating request with `result`.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Explicit path to `Response`'s own `send` (presumably its inherent method),
        // keeping the call unambiguous with this trait's identically named method.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Same as above: dispatch explicitly to `Response<proto::JoinRoom>`'s `send`.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3840
/// Shared implementation for joining a channel's room.
///
/// It is called by `join_channel`, and any other caller that can supply a
/// `JoinChannelInternalResponse`. Steps:
/// 1. If the database reports a stale room connection for this user, leave that
///    room first.
/// 2. Join the channel in the database, which yields the joined room, an
///    optional membership update, and the user's role.
/// 3. If a LiveKit client is configured, mint credentials. Guests get a guest
///    token with `can_publish = false`; every other role gets a room token with
///    `can_publish = true`. If minting fails, no connection info is sent.
/// 4. Reply with the `JoinRoomResponse`, then propagate any membership update
///    and the room update.
/// 5. With the inner guards released, broadcast the channel update and refresh
///    the user's contacts.
///
/// # Errors
///
/// Fails if leaving the stale room, joining, or sending the response fails. It
/// also fails with "channel not returned" if the joined room has no channel;
/// note that the response has already been sent by then.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: UserSession,
) -> Result<()> {
    // Inner scope: the `db` and `connection_pool` guards taken here are dropped at
    // its end, before the connection pool is locked again below.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database guard while leaving the old room
            // (NOTE(review): presumably `leave_room_for_session` takes it itself), then reacquire it.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `trace_err()` turns a token error into `None`, so a token failure makes the whole
        // closure yield `None` and the client receives no LiveKit connection info.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3931
3932/// Start editing the channel notes
3933async fn join_channel_buffer(
3934 request: proto::JoinChannelBuffer,
3935 response: Response<proto::JoinChannelBuffer>,
3936 session: UserSession,
3937) -> Result<()> {
3938 let db = session.db().await;
3939 let channel_id = ChannelId::from_proto(request.channel_id);
3940
3941 let open_response = db
3942 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3943 .await?;
3944
3945 let collaborators = open_response.collaborators.clone();
3946 response.send(open_response)?;
3947
3948 let update = UpdateChannelBufferCollaborators {
3949 channel_id: channel_id.to_proto(),
3950 collaborators: collaborators.clone(),
3951 };
3952 channel_buffer_updated(
3953 session.connection_id,
3954 collaborators
3955 .iter()
3956 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3957 &update,
3958 &session.peer,
3959 );
3960
3961 Ok(())
3962}
3963
3964/// Edit the channel notes
3965async fn update_channel_buffer(
3966 request: proto::UpdateChannelBuffer,
3967 session: UserSession,
3968) -> Result<()> {
3969 let db = session.db().await;
3970 let channel_id = ChannelId::from_proto(request.channel_id);
3971
3972 let (collaborators, epoch, version) = db
3973 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3974 .await?;
3975
3976 channel_buffer_updated(
3977 session.connection_id,
3978 collaborators.clone(),
3979 &proto::UpdateChannelBuffer {
3980 channel_id: channel_id.to_proto(),
3981 operations: request.operations,
3982 },
3983 &session.peer,
3984 );
3985
3986 let pool = &*session.connection_pool().await;
3987
3988 let non_collaborators =
3989 pool.channel_connection_ids(channel_id)
3990 .filter_map(|(connection_id, _)| {
3991 if collaborators.contains(&connection_id) {
3992 None
3993 } else {
3994 Some(connection_id)
3995 }
3996 });
3997
3998 broadcast(None, non_collaborators, |peer_id| {
3999 session.peer.send(
4000 peer_id,
4001 proto::UpdateChannels {
4002 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
4003 channel_id: channel_id.to_proto(),
4004 epoch: epoch as u64,
4005 version: version.clone(),
4006 }],
4007 ..Default::default()
4008 },
4009 )
4010 });
4011
4012 Ok(())
4013}
4014
4015/// Rejoin the channel notes after a connection blip
4016async fn rejoin_channel_buffers(
4017 request: proto::RejoinChannelBuffers,
4018 response: Response<proto::RejoinChannelBuffers>,
4019 session: UserSession,
4020) -> Result<()> {
4021 let db = session.db().await;
4022 let buffers = db
4023 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
4024 .await?;
4025
4026 for rejoined_buffer in &buffers {
4027 let collaborators_to_notify = rejoined_buffer
4028 .buffer
4029 .collaborators
4030 .iter()
4031 .filter_map(|c| Some(c.peer_id?.into()));
4032 channel_buffer_updated(
4033 session.connection_id,
4034 collaborators_to_notify,
4035 &proto::UpdateChannelBufferCollaborators {
4036 channel_id: rejoined_buffer.buffer.channel_id,
4037 collaborators: rejoined_buffer.buffer.collaborators.clone(),
4038 },
4039 &session.peer,
4040 );
4041 }
4042
4043 response.send(proto::RejoinChannelBuffersResponse {
4044 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
4045 })?;
4046
4047 Ok(())
4048}
4049
4050/// Stop editing the channel notes
4051async fn leave_channel_buffer(
4052 request: proto::LeaveChannelBuffer,
4053 response: Response<proto::LeaveChannelBuffer>,
4054 session: UserSession,
4055) -> Result<()> {
4056 let db = session.db().await;
4057 let channel_id = ChannelId::from_proto(request.channel_id);
4058
4059 let left_buffer = db
4060 .leave_channel_buffer(channel_id, session.connection_id)
4061 .await?;
4062
4063 response.send(Ack {})?;
4064
4065 channel_buffer_updated(
4066 session.connection_id,
4067 left_buffer.connections,
4068 &proto::UpdateChannelBufferCollaborators {
4069 channel_id: channel_id.to_proto(),
4070 collaborators: left_buffer.collaborators,
4071 },
4072 &session.peer,
4073 );
4074
4075 Ok(())
4076}
4077
4078fn channel_buffer_updated<T: EnvelopedMessage>(
4079 sender_id: ConnectionId,
4080 collaborators: impl IntoIterator<Item = ConnectionId>,
4081 message: &T,
4082 peer: &Peer,
4083) {
4084 broadcast(Some(sender_id), collaborators, |peer_id| {
4085 peer.send(peer_id, message.clone())
4086 });
4087}
4088
4089fn send_notifications(
4090 connection_pool: &ConnectionPool,
4091 peer: &Peer,
4092 notifications: db::NotificationBatch,
4093) {
4094 for (user_id, notification) in notifications {
4095 for connection_id in connection_pool.user_connection_ids(user_id) {
4096 if let Err(error) = peer.send(
4097 connection_id,
4098 proto::AddNotification {
4099 notification: Some(notification.clone()),
4100 },
4101 ) {
4102 tracing::error!(
4103 "failed to send notification to {:?} {}",
4104 connection_id,
4105 error
4106 );
4107 }
4108 }
4109 }
4110}
4111
4112/// Send a message to the channel
4113async fn send_channel_message(
4114 request: proto::SendChannelMessage,
4115 response: Response<proto::SendChannelMessage>,
4116 session: UserSession,
4117) -> Result<()> {
4118 // Validate the message body.
4119 let body = request.body.trim().to_string();
4120 if body.len() > MAX_MESSAGE_LEN {
4121 return Err(anyhow!("message is too long"))?;
4122 }
4123 if body.is_empty() {
4124 return Err(anyhow!("message can't be blank"))?;
4125 }
4126
4127 // TODO: adjust mentions if body is trimmed
4128
4129 let timestamp = OffsetDateTime::now_utc();
4130 let nonce = request
4131 .nonce
4132 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4133
4134 let channel_id = ChannelId::from_proto(request.channel_id);
4135 let CreatedChannelMessage {
4136 message_id,
4137 participant_connection_ids,
4138 notifications,
4139 } = session
4140 .db()
4141 .await
4142 .create_channel_message(
4143 channel_id,
4144 session.user_id(),
4145 &body,
4146 &request.mentions,
4147 timestamp,
4148 nonce.clone().into(),
4149 match request.reply_to_message_id {
4150 Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
4151 None => None,
4152 },
4153 )
4154 .await?;
4155
4156 let message = proto::ChannelMessage {
4157 sender_id: session.user_id().to_proto(),
4158 id: message_id.to_proto(),
4159 body,
4160 mentions: request.mentions,
4161 timestamp: timestamp.unix_timestamp() as u64,
4162 nonce: Some(nonce),
4163 reply_to_message_id: request.reply_to_message_id,
4164 edited_at: None,
4165 };
4166 broadcast(
4167 Some(session.connection_id),
4168 participant_connection_ids.clone(),
4169 |connection| {
4170 session.peer.send(
4171 connection,
4172 proto::ChannelMessageSent {
4173 channel_id: channel_id.to_proto(),
4174 message: Some(message.clone()),
4175 },
4176 )
4177 },
4178 );
4179 response.send(proto::SendChannelMessageResponse {
4180 message: Some(message),
4181 })?;
4182
4183 let pool = &*session.connection_pool().await;
4184 let non_participants =
4185 pool.channel_connection_ids(channel_id)
4186 .filter_map(|(connection_id, _)| {
4187 if participant_connection_ids.contains(&connection_id) {
4188 None
4189 } else {
4190 Some(connection_id)
4191 }
4192 });
4193 broadcast(None, non_participants, |peer_id| {
4194 session.peer.send(
4195 peer_id,
4196 proto::UpdateChannels {
4197 latest_channel_message_ids: vec![proto::ChannelMessageId {
4198 channel_id: channel_id.to_proto(),
4199 message_id: message_id.to_proto(),
4200 }],
4201 ..Default::default()
4202 },
4203 )
4204 });
4205 send_notifications(pool, &session.peer, notifications);
4206
4207 Ok(())
4208}
4209
4210/// Delete a channel message
4211async fn remove_channel_message(
4212 request: proto::RemoveChannelMessage,
4213 response: Response<proto::RemoveChannelMessage>,
4214 session: UserSession,
4215) -> Result<()> {
4216 let channel_id = ChannelId::from_proto(request.channel_id);
4217 let message_id = MessageId::from_proto(request.message_id);
4218 let (connection_ids, existing_notification_ids) = session
4219 .db()
4220 .await
4221 .remove_channel_message(channel_id, message_id, session.user_id())
4222 .await?;
4223
4224 broadcast(
4225 Some(session.connection_id),
4226 connection_ids,
4227 move |connection| {
4228 session.peer.send(connection, request.clone())?;
4229
4230 for notification_id in &existing_notification_ids {
4231 session.peer.send(
4232 connection,
4233 proto::DeleteNotification {
4234 notification_id: (*notification_id).to_proto(),
4235 },
4236 )?;
4237 }
4238
4239 Ok(())
4240 },
4241 );
4242 response.send(proto::Ack {})?;
4243 Ok(())
4244}
4245
4246async fn update_channel_message(
4247 request: proto::UpdateChannelMessage,
4248 response: Response<proto::UpdateChannelMessage>,
4249 session: UserSession,
4250) -> Result<()> {
4251 let channel_id = ChannelId::from_proto(request.channel_id);
4252 let message_id = MessageId::from_proto(request.message_id);
4253 let updated_at = OffsetDateTime::now_utc();
4254 let UpdatedChannelMessage {
4255 message_id,
4256 participant_connection_ids,
4257 notifications,
4258 reply_to_message_id,
4259 timestamp,
4260 deleted_mention_notification_ids,
4261 updated_mention_notifications,
4262 } = session
4263 .db()
4264 .await
4265 .update_channel_message(
4266 channel_id,
4267 message_id,
4268 session.user_id(),
4269 request.body.as_str(),
4270 &request.mentions,
4271 updated_at,
4272 )
4273 .await?;
4274
4275 let nonce = request
4276 .nonce
4277 .clone()
4278 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4279
4280 let message = proto::ChannelMessage {
4281 sender_id: session.user_id().to_proto(),
4282 id: message_id.to_proto(),
4283 body: request.body.clone(),
4284 mentions: request.mentions.clone(),
4285 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
4286 nonce: Some(nonce),
4287 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
4288 edited_at: Some(updated_at.unix_timestamp() as u64),
4289 };
4290
4291 response.send(proto::Ack {})?;
4292
4293 let pool = &*session.connection_pool().await;
4294 broadcast(
4295 Some(session.connection_id),
4296 participant_connection_ids,
4297 |connection| {
4298 session.peer.send(
4299 connection,
4300 proto::ChannelMessageUpdate {
4301 channel_id: channel_id.to_proto(),
4302 message: Some(message.clone()),
4303 },
4304 )?;
4305
4306 for notification_id in &deleted_mention_notification_ids {
4307 session.peer.send(
4308 connection,
4309 proto::DeleteNotification {
4310 notification_id: (*notification_id).to_proto(),
4311 },
4312 )?;
4313 }
4314
4315 for notification in &updated_mention_notifications {
4316 session.peer.send(
4317 connection,
4318 proto::UpdateNotification {
4319 notification: Some(notification.clone()),
4320 },
4321 )?;
4322 }
4323
4324 Ok(())
4325 },
4326 );
4327
4328 send_notifications(pool, &session.peer, notifications);
4329
4330 Ok(())
4331}
4332
4333/// Mark a channel message as read
4334async fn acknowledge_channel_message(
4335 request: proto::AckChannelMessage,
4336 session: UserSession,
4337) -> Result<()> {
4338 let channel_id = ChannelId::from_proto(request.channel_id);
4339 let message_id = MessageId::from_proto(request.message_id);
4340 let notifications = session
4341 .db()
4342 .await
4343 .observe_channel_message(channel_id, session.user_id(), message_id)
4344 .await?;
4345 send_notifications(
4346 &*session.connection_pool().await,
4347 &session.peer,
4348 notifications,
4349 );
4350 Ok(())
4351}
4352
4353/// Mark a buffer version as synced
4354async fn acknowledge_buffer_version(
4355 request: proto::AckBufferOperation,
4356 session: UserSession,
4357) -> Result<()> {
4358 let buffer_id = BufferId::from_proto(request.buffer_id);
4359 session
4360 .db()
4361 .await
4362 .observe_buffer_version(
4363 buffer_id,
4364 session.user_id(),
4365 request.epoch as i32,
4366 &request.version,
4367 )
4368 .await?;
4369 Ok(())
4370}
4371
4372struct CompleteWithLanguageModelRateLimit;
4373
4374impl RateLimit for CompleteWithLanguageModelRateLimit {
4375 fn capacity() -> usize {
4376 std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4377 .ok()
4378 .and_then(|v| v.parse().ok())
4379 .unwrap_or(120) // Picked arbitrarily
4380 }
4381
4382 fn refill_duration() -> chrono::Duration {
4383 chrono::Duration::hours(1)
4384 }
4385
4386 fn db_name() -> &'static str {
4387 "complete-with-language-model"
4388 }
4389}
4390
4391async fn complete_with_language_model(
4392 request: proto::CompleteWithLanguageModel,
4393 response: StreamingResponse<proto::CompleteWithLanguageModel>,
4394 session: Session,
4395 open_ai_api_key: Option<Arc<str>>,
4396 google_ai_api_key: Option<Arc<str>>,
4397 anthropic_api_key: Option<Arc<str>>,
4398) -> Result<()> {
4399 let Some(session) = session.for_user() else {
4400 return Err(anyhow!("user not found"))?;
4401 };
4402 authorize_access_to_language_models(&session).await?;
4403 session
4404 .rate_limiter
4405 .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
4406 .await?;
4407
4408 if request.model.starts_with("gpt") {
4409 let api_key =
4410 open_ai_api_key.ok_or_else(|| anyhow!("no OpenAI API key configured on the server"))?;
4411 complete_with_open_ai(request, response, session, api_key).await?;
4412 } else if request.model.starts_with("gemini") {
4413 let api_key = google_ai_api_key
4414 .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
4415 complete_with_google_ai(request, response, session, api_key).await?;
4416 } else if request.model.starts_with("claude") {
4417 let api_key = anthropic_api_key
4418 .ok_or_else(|| anyhow!("no Anthropic AI API key configured on the server"))?;
4419 complete_with_anthropic(request, response, session, api_key).await?;
4420 }
4421
4422 Ok(())
4423}
4424
4425async fn complete_with_open_ai(
4426 request: proto::CompleteWithLanguageModel,
4427 response: StreamingResponse<proto::CompleteWithLanguageModel>,
4428 session: UserSession,
4429 api_key: Arc<str>,
4430) -> Result<()> {
4431 let mut completion_stream = open_ai::stream_completion(
4432 session.http_client.as_ref(),
4433 OPEN_AI_API_URL,
4434 &api_key,
4435 crate::ai::language_model_request_to_open_ai(request)?,
4436 None,
4437 )
4438 .await
4439 .context("open_ai::stream_completion request failed within collab")?;
4440
4441 while let Some(event) = completion_stream.next().await {
4442 let event = event?;
4443 response.send(proto::LanguageModelResponse {
4444 choices: event
4445 .choices
4446 .into_iter()
4447 .map(|choice| proto::LanguageModelChoiceDelta {
4448 index: choice.index,
4449 delta: Some(proto::LanguageModelResponseMessage {
4450 role: choice.delta.role.map(|role| match role {
4451 open_ai::Role::User => LanguageModelRole::LanguageModelUser,
4452 open_ai::Role::Assistant => LanguageModelRole::LanguageModelAssistant,
4453 open_ai::Role::System => LanguageModelRole::LanguageModelSystem,
4454 open_ai::Role::Tool => LanguageModelRole::LanguageModelTool,
4455 } as i32),
4456 content: choice.delta.content,
4457 tool_calls: choice
4458 .delta
4459 .tool_calls
4460 .into_iter()
4461 .map(|delta| proto::ToolCallDelta {
4462 index: delta.index as u32,
4463 id: delta.id,
4464 variant: match delta.function {
4465 Some(function) => {
4466 let name = function.name;
4467 let arguments = function.arguments;
4468
4469 Some(proto::tool_call_delta::Variant::Function(
4470 proto::tool_call_delta::FunctionCallDelta {
4471 name,
4472 arguments,
4473 },
4474 ))
4475 }
4476 None => None,
4477 },
4478 })
4479 .collect(),
4480 }),
4481 finish_reason: choice.finish_reason,
4482 })
4483 .collect(),
4484 })?;
4485 }
4486
4487 Ok(())
4488}
4489
4490async fn complete_with_google_ai(
4491 request: proto::CompleteWithLanguageModel,
4492 response: StreamingResponse<proto::CompleteWithLanguageModel>,
4493 session: UserSession,
4494 api_key: Arc<str>,
4495) -> Result<()> {
4496 let mut stream = google_ai::stream_generate_content(
4497 session.http_client.clone(),
4498 google_ai::API_URL,
4499 api_key.as_ref(),
4500 crate::ai::language_model_request_to_google_ai(request)?,
4501 )
4502 .await
4503 .context("google_ai::stream_generate_content request failed")?;
4504
4505 while let Some(event) = stream.next().await {
4506 let event = event?;
4507 response.send(proto::LanguageModelResponse {
4508 choices: event
4509 .candidates
4510 .unwrap_or_default()
4511 .into_iter()
4512 .map(|candidate| proto::LanguageModelChoiceDelta {
4513 index: candidate.index as u32,
4514 delta: Some(proto::LanguageModelResponseMessage {
4515 role: Some(match candidate.content.role {
4516 google_ai::Role::User => LanguageModelRole::LanguageModelUser,
4517 google_ai::Role::Model => LanguageModelRole::LanguageModelAssistant,
4518 } as i32),
4519 content: Some(
4520 candidate
4521 .content
4522 .parts
4523 .into_iter()
4524 .filter_map(|part| match part {
4525 google_ai::Part::TextPart(part) => Some(part.text),
4526 google_ai::Part::InlineDataPart(_) => None,
4527 })
4528 .collect(),
4529 ),
4530 // Tool calls are not supported for Google
4531 tool_calls: Vec::new(),
4532 }),
4533 finish_reason: candidate.finish_reason.map(|reason| reason.to_string()),
4534 })
4535 .collect(),
4536 })?;
4537 }
4538
4539 Ok(())
4540}
4541
/// Stream a chat completion from Anthropic and relay it to the client.
///
/// The protocol messages are converted for Anthropic's API:
/// - user and assistant messages are passed through;
/// - system messages are joined with blank lines into the request's
///   separate `system` field;
/// - tool messages are dropped.
///
/// While streaming, every text chunk is forwarded as a single choice at
/// index 0, tagged with the most recent role seen in a `MessageStart` event
/// (assistant by default). A stop reason reported in a `MessageDelta` event
/// is forwarded as a choice with no delta. Other events are ignored.
///
/// # Errors
///
/// Fails if `anthropic::Model::from_id` rejects `request.model`, if the
/// request or any streamed event fails, or if sending to the client fails.
async fn complete_with_anthropic(
    request: proto::CompleteWithLanguageModel,
    response: StreamingResponse<proto::CompleteWithLanguageModel>,
    session: UserSession,
    api_key: Arc<str>,
) -> Result<()> {
    let model = anthropic::Model::from_id(&request.model)?;

    // Accumulated as a side effect of the `filter_map` below.
    let mut system_message = String::new();
    let messages = request
        .messages
        .into_iter()
        .filter_map(|message| {
            match message.role() {
                LanguageModelRole::LanguageModelUser => Some(anthropic::RequestMessage {
                    role: anthropic::Role::User,
                    content: message.content,
                }),
                LanguageModelRole::LanguageModelAssistant => Some(anthropic::RequestMessage {
                    role: anthropic::Role::Assistant,
                    content: message.content,
                }),
                // Anthropic's API breaks system instructions out as a separate field rather
                // than having a system message role.
                LanguageModelRole::LanguageModelSystem => {
                    if !system_message.is_empty() {
                        system_message.push_str("\n\n");
                    }
                    system_message.push_str(&message.content);

                    None
                }
                // We don't yet support tool calls for Anthropic
                LanguageModelRole::LanguageModelTool => None,
            }
        })
        .collect();

    let mut stream = anthropic::stream_completion(
        session.http_client.as_ref(),
        anthropic::ANTHROPIC_API_URL,
        &api_key,
        anthropic::Request {
            model,
            messages,
            stream: true,
            system: system_message,
            // NOTE(review): 4092 looks like a possible typo for 4096 (the
            // output-token cap commonly documented for Claude 3 models);
            // confirm the intended value.
            max_tokens: 4092,
        },
        None,
    )
    .await?;

    // Role attached to forwarded deltas; updated by `MessageStart` events.
    let mut current_role = proto::LanguageModelRole::LanguageModelAssistant;

    while let Some(event) = stream.next().await {
        let event = event?;

        match event {
            anthropic::ResponseEvent::MessageStart { message } => {
                // Roles other than "assistant" / "user" leave `current_role` unchanged.
                if let Some(role) = message.role {
                    if role == "assistant" {
                        current_role = proto::LanguageModelRole::LanguageModelAssistant;
                    } else if role == "user" {
                        current_role = proto::LanguageModelRole::LanguageModelUser;
                    }
                }
            }
            anthropic::ResponseEvent::ContentBlockStart { content_block, .. } => {
                match content_block {
                    // A block may open with initial text; forward it only if non-empty.
                    anthropic::ContentBlock::Text { text } => {
                        if !text.is_empty() {
                            response.send(proto::LanguageModelResponse {
                                choices: vec![proto::LanguageModelChoiceDelta {
                                    index: 0,
                                    delta: Some(proto::LanguageModelResponseMessage {
                                        role: Some(current_role as i32),
                                        content: Some(text),
                                        tool_calls: Vec::new(),
                                    }),
                                    finish_reason: None,
                                }],
                            })?;
                        }
                    }
                }
            }
            anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => match delta {
                anthropic::TextDelta::TextDelta { text } => {
                    response.send(proto::LanguageModelResponse {
                        choices: vec![proto::LanguageModelChoiceDelta {
                            index: 0,
                            delta: Some(proto::LanguageModelResponseMessage {
                                role: Some(current_role as i32),
                                content: Some(text),
                                tool_calls: Vec::new(),
                            }),
                            finish_reason: None,
                        }],
                    })?;
                }
            },
            anthropic::ResponseEvent::MessageDelta { delta, .. } => {
                // Only the stop reason is forwarded from message-level deltas.
                if let Some(stop_reason) = delta.stop_reason {
                    response.send(proto::LanguageModelResponse {
                        choices: vec![proto::LanguageModelChoiceDelta {
                            index: 0,
                            delta: None,
                            finish_reason: Some(stop_reason),
                        }],
                    })?;
                }
            }
            anthropic::ResponseEvent::ContentBlockStop { .. } => {}
            anthropic::ResponseEvent::MessageStop {} => {}
            anthropic::ResponseEvent::Ping {} => {}
        }
    }

    Ok(())
}
4663
4664struct CountTokensWithLanguageModelRateLimit;
4665
4666impl RateLimit for CountTokensWithLanguageModelRateLimit {
4667 fn capacity() -> usize {
4668 std::env::var("COUNT_TOKENS_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4669 .ok()
4670 .and_then(|v| v.parse().ok())
4671 .unwrap_or(600) // Picked arbitrarily
4672 }
4673
4674 fn refill_duration() -> chrono::Duration {
4675 chrono::Duration::hours(1)
4676 }
4677
4678 fn db_name() -> &'static str {
4679 "count-tokens-with-language-model"
4680 }
4681}
4682
4683async fn count_tokens_with_language_model(
4684 request: proto::CountTokensWithLanguageModel,
4685 response: Response<proto::CountTokensWithLanguageModel>,
4686 session: UserSession,
4687 google_ai_api_key: Option<Arc<str>>,
4688) -> Result<()> {
4689 authorize_access_to_language_models(&session).await?;
4690
4691 if !request.model.starts_with("gemini") {
4692 return Err(anyhow!(
4693 "counting tokens for model: {:?} is not supported",
4694 request.model
4695 ))?;
4696 }
4697
4698 session
4699 .rate_limiter
4700 .check::<CountTokensWithLanguageModelRateLimit>(session.user_id())
4701 .await?;
4702
4703 let api_key = google_ai_api_key
4704 .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
4705 let tokens_response = google_ai::count_tokens(
4706 session.http_client.as_ref(),
4707 google_ai::API_URL,
4708 &api_key,
4709 crate::ai::count_tokens_request_to_google_ai(request)?,
4710 )
4711 .await?;
4712 response.send(proto::CountTokensResponse {
4713 token_count: tokens_response.total_tokens as u32,
4714 })?;
4715 Ok(())
4716}
4717
4718struct ComputeEmbeddingsRateLimit;
4719
4720impl RateLimit for ComputeEmbeddingsRateLimit {
4721 fn capacity() -> usize {
4722 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4723 .ok()
4724 .and_then(|v| v.parse().ok())
4725 .unwrap_or(5000) // Picked arbitrarily
4726 }
4727
4728 fn refill_duration() -> chrono::Duration {
4729 chrono::Duration::hours(1)
4730 }
4731
4732 fn db_name() -> &'static str {
4733 "compute-embeddings"
4734 }
4735}
4736
4737async fn compute_embeddings(
4738 request: proto::ComputeEmbeddings,
4739 response: Response<proto::ComputeEmbeddings>,
4740 session: UserSession,
4741 api_key: Option<Arc<str>>,
4742) -> Result<()> {
4743 let api_key = api_key.context("no OpenAI API key configured on the server")?;
4744 authorize_access_to_language_models(&session).await?;
4745
4746 session
4747 .rate_limiter
4748 .check::<ComputeEmbeddingsRateLimit>(session.user_id())
4749 .await?;
4750
4751 let embeddings = match request.model.as_str() {
4752 "openai/text-embedding-3-small" => {
4753 open_ai::embed(
4754 session.http_client.as_ref(),
4755 OPEN_AI_API_URL,
4756 &api_key,
4757 OpenAiEmbeddingModel::TextEmbedding3Small,
4758 request.texts.iter().map(|text| text.as_str()),
4759 )
4760 .await?
4761 }
4762 provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
4763 };
4764
4765 let embeddings = request
4766 .texts
4767 .iter()
4768 .map(|text| {
4769 let mut hasher = sha2::Sha256::new();
4770 hasher.update(text.as_bytes());
4771 let result = hasher.finalize();
4772 result.to_vec()
4773 })
4774 .zip(
4775 embeddings
4776 .data
4777 .into_iter()
4778 .map(|embedding| embedding.embedding),
4779 )
4780 .collect::<HashMap<_, _>>();
4781
4782 let db = session.db().await;
4783 db.save_embeddings(&request.model, &embeddings)
4784 .await
4785 .context("failed to save embeddings")
4786 .trace_err();
4787
4788 response.send(proto::ComputeEmbeddingsResponse {
4789 embeddings: embeddings
4790 .into_iter()
4791 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4792 .collect(),
4793 })?;
4794 Ok(())
4795}
4796
4797async fn get_cached_embeddings(
4798 request: proto::GetCachedEmbeddings,
4799 response: Response<proto::GetCachedEmbeddings>,
4800 session: UserSession,
4801) -> Result<()> {
4802 authorize_access_to_language_models(&session).await?;
4803
4804 let db = session.db().await;
4805 let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
4806
4807 response.send(proto::GetCachedEmbeddingsResponse {
4808 embeddings: embeddings
4809 .into_iter()
4810 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4811 .collect(),
4812 })?;
4813 Ok(())
4814}
4815
4816async fn authorize_access_to_language_models(session: &UserSession) -> Result<(), Error> {
4817 let db = session.db().await;
4818 let flags = db.get_user_flags(session.user_id()).await?;
4819 if flags.iter().any(|flag| flag == "language-models") {
4820 Ok(())
4821 } else {
4822 Err(anyhow!("permission denied"))?
4823 }
4824}
4825
4826/// Get a Supermaven API key for the user
4827async fn get_supermaven_api_key(
4828 _request: proto::GetSupermavenApiKey,
4829 response: Response<proto::GetSupermavenApiKey>,
4830 session: UserSession,
4831) -> Result<()> {
4832 let user_id: String = session.user_id().to_string();
4833 if !session.is_staff() {
4834 return Err(anyhow!("supermaven not enabled for this account"))?;
4835 }
4836
4837 let email = session
4838 .email()
4839 .ok_or_else(|| anyhow!("user must have an email"))?;
4840
4841 let supermaven_admin_api = session
4842 .supermaven_client
4843 .as_ref()
4844 .ok_or_else(|| anyhow!("supermaven not configured"))?;
4845
4846 let result = supermaven_admin_api
4847 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4848 .await?;
4849
4850 response.send(proto::GetSupermavenApiKeyResponse {
4851 api_key: result.api_key,
4852 })?;
4853
4854 Ok(())
4855}
4856
4857/// Start receiving chat updates for a channel
4858async fn join_channel_chat(
4859 request: proto::JoinChannelChat,
4860 response: Response<proto::JoinChannelChat>,
4861 session: UserSession,
4862) -> Result<()> {
4863 let channel_id = ChannelId::from_proto(request.channel_id);
4864
4865 let db = session.db().await;
4866 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4867 .await?;
4868 let messages = db
4869 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4870 .await?;
4871 response.send(proto::JoinChannelChatResponse {
4872 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4873 messages,
4874 })?;
4875 Ok(())
4876}
4877
4878/// Stop receiving chat updates for a channel
4879async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
4880 let channel_id = ChannelId::from_proto(request.channel_id);
4881 session
4882 .db()
4883 .await
4884 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4885 .await?;
4886 Ok(())
4887}
4888
4889/// Retrieve the chat history for a channel
4890async fn get_channel_messages(
4891 request: proto::GetChannelMessages,
4892 response: Response<proto::GetChannelMessages>,
4893 session: UserSession,
4894) -> Result<()> {
4895 let channel_id = ChannelId::from_proto(request.channel_id);
4896 let messages = session
4897 .db()
4898 .await
4899 .get_channel_messages(
4900 channel_id,
4901 session.user_id(),
4902 MESSAGE_COUNT_PER_PAGE,
4903 Some(MessageId::from_proto(request.before_message_id)),
4904 )
4905 .await?;
4906 response.send(proto::GetChannelMessagesResponse {
4907 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4908 messages,
4909 })?;
4910 Ok(())
4911}
4912
4913/// Retrieve specific chat messages
4914async fn get_channel_messages_by_id(
4915 request: proto::GetChannelMessagesById,
4916 response: Response<proto::GetChannelMessagesById>,
4917 session: UserSession,
4918) -> Result<()> {
4919 let message_ids = request
4920 .message_ids
4921 .iter()
4922 .map(|id| MessageId::from_proto(*id))
4923 .collect::<Vec<_>>();
4924 let messages = session
4925 .db()
4926 .await
4927 .get_channel_messages_by_id(session.user_id(), &message_ids)
4928 .await?;
4929 response.send(proto::GetChannelMessagesResponse {
4930 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4931 messages,
4932 })?;
4933 Ok(())
4934}
4935
4936/// Retrieve the current users notifications
4937async fn get_notifications(
4938 request: proto::GetNotifications,
4939 response: Response<proto::GetNotifications>,
4940 session: UserSession,
4941) -> Result<()> {
4942 let notifications = session
4943 .db()
4944 .await
4945 .get_notifications(
4946 session.user_id(),
4947 NOTIFICATION_COUNT_PER_PAGE,
4948 request
4949 .before_id
4950 .map(|id| db::NotificationId::from_proto(id)),
4951 )
4952 .await?;
4953 response.send(proto::GetNotificationsResponse {
4954 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4955 notifications,
4956 })?;
4957 Ok(())
4958}
4959
4960/// Mark notifications as read
4961async fn mark_notification_as_read(
4962 request: proto::MarkNotificationRead,
4963 response: Response<proto::MarkNotificationRead>,
4964 session: UserSession,
4965) -> Result<()> {
4966 let database = &session.db().await;
4967 let notifications = database
4968 .mark_notification_as_read_by_id(
4969 session.user_id(),
4970 NotificationId::from_proto(request.notification_id),
4971 )
4972 .await?;
4973 send_notifications(
4974 &*session.connection_pool().await,
4975 &session.peer,
4976 notifications,
4977 );
4978 response.send(proto::Ack {})?;
4979 Ok(())
4980}
4981
4982/// Get the current users information
4983async fn get_private_user_info(
4984 _request: proto::GetPrivateUserInfo,
4985 response: Response<proto::GetPrivateUserInfo>,
4986 session: UserSession,
4987) -> Result<()> {
4988 let db = session.db().await;
4989
4990 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4991 let user = db
4992 .get_user_by_id(session.user_id())
4993 .await?
4994 .ok_or_else(|| anyhow!("user not found"))?;
4995 let flags = db.get_user_flags(session.user_id()).await?;
4996
4997 response.send(proto::GetPrivateUserInfoResponse {
4998 metrics_id,
4999 staff: user.admin,
5000 flags,
5001 })?;
5002 Ok(())
5003}
5004
5005fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
5006 match message {
5007 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
5008 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
5009 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
5010 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
5011 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
5012 code: frame.code.into(),
5013 reason: frame.reason,
5014 })),
5015 }
5016}
5017
5018fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
5019 match message {
5020 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
5021 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
5022 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
5023 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
5024 AxumMessage::Close(frame) => {
5025 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
5026 code: frame.code.into(),
5027 reason: frame.reason,
5028 }))
5029 }
5030 }
5031}
5032
5033fn notify_membership_updated(
5034 connection_pool: &mut ConnectionPool,
5035 result: MembershipUpdated,
5036 user_id: UserId,
5037 peer: &Peer,
5038) {
5039 for membership in &result.new_channels.channel_memberships {
5040 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
5041 }
5042 for channel_id in &result.removed_channels {
5043 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
5044 }
5045
5046 let user_channels_update = proto::UpdateUserChannels {
5047 channel_memberships: result
5048 .new_channels
5049 .channel_memberships
5050 .iter()
5051 .map(|cm| proto::ChannelMembership {
5052 channel_id: cm.channel_id.to_proto(),
5053 role: cm.role.into(),
5054 })
5055 .collect(),
5056 ..Default::default()
5057 };
5058
5059 let mut update = build_channels_update(result.new_channels);
5060 update.delete_channels = result
5061 .removed_channels
5062 .into_iter()
5063 .map(|id| id.to_proto())
5064 .collect();
5065 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
5066
5067 for connection_id in connection_pool.user_connection_ids(user_id) {
5068 peer.send(connection_id, user_channels_update.clone())
5069 .trace_err();
5070 peer.send(connection_id, update.clone()).trace_err();
5071 }
5072}
5073
5074fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
5075 proto::UpdateUserChannels {
5076 channel_memberships: channels
5077 .channel_memberships
5078 .iter()
5079 .map(|m| proto::ChannelMembership {
5080 channel_id: m.channel_id.to_proto(),
5081 role: m.role.into(),
5082 })
5083 .collect(),
5084 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
5085 observed_channel_message_id: channels.observed_channel_messages.clone(),
5086 }
5087}
5088
5089fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
5090 let mut update = proto::UpdateChannels::default();
5091
5092 for channel in channels.channels {
5093 update.channels.push(channel.to_proto());
5094 }
5095
5096 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
5097 update.latest_channel_message_ids = channels.latest_channel_messages;
5098
5099 for (channel_id, participants) in channels.channel_participants {
5100 update
5101 .channel_participants
5102 .push(proto::ChannelParticipants {
5103 channel_id: channel_id.to_proto(),
5104 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
5105 });
5106 }
5107
5108 for channel in channels.invited_channels {
5109 update.channel_invitations.push(channel.to_proto());
5110 }
5111
5112 update.hosted_projects = channels.hosted_projects;
5113 update
5114}
5115
5116fn build_initial_contacts_update(
5117 contacts: Vec<db::Contact>,
5118 pool: &ConnectionPool,
5119) -> proto::UpdateContacts {
5120 let mut update = proto::UpdateContacts::default();
5121
5122 for contact in contacts {
5123 match contact {
5124 db::Contact::Accepted { user_id, busy } => {
5125 update.contacts.push(contact_for_user(user_id, busy, &pool));
5126 }
5127 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
5128 db::Contact::Incoming { user_id } => {
5129 update
5130 .incoming_requests
5131 .push(proto::IncomingContactRequest {
5132 requester_id: user_id.to_proto(),
5133 })
5134 }
5135 }
5136 }
5137
5138 update
5139}
5140
5141fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
5142 proto::Contact {
5143 user_id: user_id.to_proto(),
5144 online: pool.is_user_online(user_id),
5145 busy,
5146 }
5147}
5148
5149fn room_updated(room: &proto::Room, peer: &Peer) {
5150 broadcast(
5151 None,
5152 room.participants
5153 .iter()
5154 .filter_map(|participant| Some(participant.peer_id?.into())),
5155 |peer_id| {
5156 peer.send(
5157 peer_id,
5158 proto::RoomUpdated {
5159 room: Some(room.clone()),
5160 },
5161 )
5162 },
5163 );
5164}
5165
5166fn channel_updated(
5167 channel: &db::channel::Model,
5168 room: &proto::Room,
5169 peer: &Peer,
5170 pool: &ConnectionPool,
5171) {
5172 let participants = room
5173 .participants
5174 .iter()
5175 .map(|p| p.user_id)
5176 .collect::<Vec<_>>();
5177
5178 broadcast(
5179 None,
5180 pool.channel_connection_ids(channel.root_id())
5181 .filter_map(|(channel_id, role)| {
5182 role.can_see_channel(channel.visibility).then(|| channel_id)
5183 }),
5184 |peer_id| {
5185 peer.send(
5186 peer_id,
5187 proto::UpdateChannels {
5188 channel_participants: vec![proto::ChannelParticipants {
5189 channel_id: channel.id.to_proto(),
5190 participant_user_ids: participants.clone(),
5191 }],
5192 ..Default::default()
5193 },
5194 )
5195 },
5196 );
5197}
5198
5199async fn send_dev_server_projects_update(
5200 user_id: UserId,
5201 mut status: proto::DevServerProjectsUpdate,
5202 session: &Session,
5203) {
5204 let pool = session.connection_pool().await;
5205 for dev_server in &mut status.dev_servers {
5206 dev_server.status =
5207 pool.dev_server_status(DevServerId(dev_server.dev_server_id as i32)) as i32;
5208 }
5209 let connections = pool.user_connection_ids(user_id);
5210 for connection_id in connections {
5211 session.peer.send(connection_id, status.clone()).trace_err();
5212 }
5213}
5214
5215async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
5216 let db = session.db().await;
5217
5218 let contacts = db.get_contacts(user_id).await?;
5219 let busy = db.is_user_busy(user_id).await?;
5220
5221 let pool = session.connection_pool().await;
5222 let updated_contact = contact_for_user(user_id, busy, &pool);
5223 for contact in contacts {
5224 if let db::Contact::Accepted {
5225 user_id: contact_user_id,
5226 ..
5227 } = contact
5228 {
5229 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
5230 session
5231 .peer
5232 .send(
5233 contact_conn_id,
5234 proto::UpdateContacts {
5235 contacts: vec![updated_contact.clone()],
5236 remove_contacts: Default::default(),
5237 incoming_requests: Default::default(),
5238 remove_incoming_requests: Default::default(),
5239 outgoing_requests: Default::default(),
5240 remove_outgoing_requests: Default::default(),
5241 },
5242 )
5243 .trace_err();
5244 }
5245 }
5246 }
5247 Ok(())
5248}
5249
5250async fn lost_dev_server_connection(session: &DevServerSession) -> Result<()> {
5251 log::info!("lost dev server connection, unsharing projects");
5252 let project_ids = session
5253 .db()
5254 .await
5255 .get_stale_dev_server_projects(session.connection_id)
5256 .await?;
5257
5258 for project_id in project_ids {
5259 // not unshare re-checks the connection ids match, so we get away with no transaction
5260 unshare_project_internal(project_id, session.connection_id, None, &session).await?;
5261 }
5262
5263 let user_id = session.dev_server().user_id;
5264 let update = session
5265 .db()
5266 .await
5267 .dev_server_projects_update(user_id)
5268 .await?;
5269
5270 send_dev_server_projects_update(user_id, update, session).await;
5271
5272 Ok(())
5273}
5274
/// Remove `connection_id` from the room it is in, if any, and notify everyone affected.
///
/// If the database reports that no room was left, this returns `Ok(())`
/// without doing anything. Otherwise it:
/// 1. notifies each left project's connections via `project_left`, and
///    broadcasts the updated room to its participants;
/// 2. broadcasts the updated participant list to the room's channel, if any;
/// 3. sends `CallCanceled` to every connection of each user in
///    `canceled_calls_to_user_ids`;
/// 4. refreshes contact status for the leaving user and those canceled users;
/// 5. removes the user from the LiveKit room, and deletes the room when the
///    database marked it as deleted.
///
/// LiveKit failures and individual send failures are logged and ignored.
async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Assigned from `left_room` inside the `if let` below so the values
    // remain available after it.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the pool guard is released before `update_user_contacts`,
    // which acquires the connection pool itself.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
5348
5349async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
5350 let left_channel_buffers = session
5351 .db()
5352 .await
5353 .leave_channel_buffers(session.connection_id)
5354 .await?;
5355
5356 for left_buffer in left_channel_buffers {
5357 channel_buffer_updated(
5358 session.connection_id,
5359 left_buffer.connections,
5360 &proto::UpdateChannelBufferCollaborators {
5361 channel_id: left_buffer.channel_id.to_proto(),
5362 collaborators: left_buffer.collaborators,
5363 },
5364 &session.peer,
5365 );
5366 }
5367
5368 Ok(())
5369}
5370
5371fn project_left(project: &db::LeftProject, session: &UserSession) {
5372 for connection_id in &project.connection_ids {
5373 if project.should_unshare {
5374 session
5375 .peer
5376 .send(
5377 *connection_id,
5378 proto::UnshareProject {
5379 project_id: project.id.to_proto(),
5380 },
5381 )
5382 .trace_err();
5383 } else {
5384 session
5385 .peer
5386 .send(
5387 *connection_id,
5388 proto::RemoveProjectCollaborator {
5389 project_id: project.id.to_proto(),
5390 peer_id: Some(session.connection_id.into()),
5391 },
5392 )
5393 .trace_err();
5394 }
5395 }
5396}
5397
5398pub trait ResultExt {
5399 type Ok;
5400
5401 fn trace_err(self) -> Option<Self::Ok>;
5402}
5403
5404impl<T, E> ResultExt for Result<T, E>
5405where
5406 E: std::fmt::Debug,
5407{
5408 type Ok = T;
5409
5410 #[track_caller]
5411 fn trace_err(self) -> Option<T> {
5412 match self {
5413 Ok(value) => Some(value),
5414 Err(error) => {
5415 tracing::error!("{:?}", error);
5416 None
5417 }
5418 }
5419 }
5420}