1mod connection_pool;
2
3use crate::{
4 auth,
5 db::{
6 self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
7 CreatedChannelMessage, Database, DevServerId, DevServerProjectId, InviteMemberResult,
8 MembershipUpdated, MessageId, NotificationId, PrincipalId, Project, ProjectId,
9 RejoinedProject, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
10 ServerId, UpdatedChannelMessage, User, UserId,
11 },
12 executor::Executor,
13 AppState, Error, RateLimit, RateLimiter, Result,
14};
15use anyhow::{anyhow, Context as _};
16use async_tungstenite::tungstenite::{
17 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
18};
19use axum::{
20 body::Body,
21 extract::{
22 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
23 ConnectInfo, WebSocketUpgrade,
24 },
25 headers::{Header, HeaderName},
26 http::StatusCode,
27 middleware,
28 response::IntoResponse,
29 routing::get,
30 Extension, Router, TypedHeader,
31};
32use collections::{HashMap, HashSet};
33pub use connection_pool::{ConnectionPool, ZedVersion};
34use core::fmt::{self, Debug, Formatter};
35use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
36use sha2::Digest;
37use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
38
39use futures::{
40 channel::oneshot,
41 future::{self, BoxFuture},
42 stream::FuturesUnordered,
43 FutureExt, SinkExt, StreamExt, TryStreamExt,
44};
45use http::IsahcHttpClient;
46use prometheus::{register_int_gauge, IntGauge};
47use rpc::{
48 proto::{
49 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LanguageModelRole,
50 LiveKitConnectionInfo, RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
51 },
52 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
53};
54use semantic_version::SemanticVersion;
55use serde::{Serialize, Serializer};
56use std::{
57 any::TypeId,
58 future::Future,
59 marker::PhantomData,
60 mem,
61 net::SocketAddr,
62 ops::{Deref, DerefMut},
63 rc::Rc,
64 sync::{
65 atomic::{AtomicBool, Ordering::SeqCst},
66 Arc, OnceLock,
67 },
68 time::{Duration, Instant},
69};
70use time::OffsetDateTime;
71use tokio::sync::{watch, Semaphore};
72use tower::ServiceBuilder;
73use tracing::{
74 field::{self},
75 info_span, instrument, Instrument,
76};
77
78use self::connection_pool::VersionedMessage;
79
/// Reconnection grace period (30s). Its use is outside this section of the file;
/// presumably how long a dropped client may take to reconnect — TODO confirm at use site.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// Delay before `Server::start` cleans up resources left behind by stale servers.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we
/// can clean up old resources, so this waits a little longer than that grace period.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// NOTE(review): the three constants below are not used in this section of the file;
// the descriptions are inferred from their names — confirm at their use sites.
/// Page size, presumably for channel message history requests.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Upper bound on a message's length (unit — bytes vs. chars — not visible here).
const MAX_MESSAGE_LEN: usize = 1024;
/// Page size, presumably for notification listing requests.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;

/// Type-erased handler stored in `Server::handlers`: receives a boxed envelope plus
/// the connection's `Session` and returns a boxed future that performs the work.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
91
92struct Response<R> {
93 peer: Arc<Peer>,
94 receipt: Receipt<R>,
95 responded: Arc<AtomicBool>,
96}
97
98impl<R: RequestMessage> Response<R> {
99 fn send(self, payload: R::Response) -> Result<()> {
100 self.responded.store(true, SeqCst);
101 self.peer.respond(self.receipt, payload)?;
102 Ok(())
103 }
104}
105
106struct StreamingResponse<R: RequestMessage> {
107 peer: Arc<Peer>,
108 receipt: Receipt<R>,
109}
110
111impl<R: RequestMessage> StreamingResponse<R> {
112 fn send(&self, payload: R::Response) -> Result<()> {
113 self.peer.respond(self.receipt, payload)?;
114 Ok(())
115 }
116}
117
118#[derive(Clone, Debug)]
119pub enum Principal {
120 User(User),
121 Impersonated { user: User, admin: User },
122 DevServer(dev_server::Model),
123}
124
125impl Principal {
126 fn update_span(&self, span: &tracing::Span) {
127 match &self {
128 Principal::User(user) => {
129 span.record("user_id", &user.id.0);
130 span.record("login", &user.github_login);
131 }
132 Principal::Impersonated { user, admin } => {
133 span.record("user_id", &user.id.0);
134 span.record("login", &user.github_login);
135 span.record("impersonator", &admin.github_login);
136 }
137 Principal::DevServer(dev_server) => {
138 span.record("dev_server_id", &dev_server.id.0);
139 }
140 }
141 }
142}
143
/// Per-connection context passed (cloned) to every message handler for that connection.
#[derive(Clone)]
struct Session {
    /// The authenticated party that owns this connection.
    principal: Principal,
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; acquire it through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Shared connection pool; acquire it through `Session::connection_pool`, which
    /// wraps the guard so it is `!Send`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
    /// `None` when no Supermaven admin API key is configured.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    http_client: Arc<IsahcHttpClient>,
    rate_limiter: Arc<RateLimiter>,
    /// Not read in this section of the file (hence the leading underscore).
    _executor: Executor,
}
157
158impl Session {
159 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
160 #[cfg(test)]
161 tokio::task::yield_now().await;
162 let guard = self.db.lock().await;
163 #[cfg(test)]
164 tokio::task::yield_now().await;
165 guard
166 }
167
168 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
169 #[cfg(test)]
170 tokio::task::yield_now().await;
171 let guard = self.connection_pool.lock();
172 ConnectionPoolGuard {
173 guard,
174 _not_send: PhantomData,
175 }
176 }
177
178 fn for_user(self) -> Option<UserSession> {
179 UserSession::new(self)
180 }
181
182 fn for_dev_server(self) -> Option<DevServerSession> {
183 DevServerSession::new(self)
184 }
185
186 fn user_id(&self) -> Option<UserId> {
187 match &self.principal {
188 Principal::User(user) => Some(user.id),
189 Principal::Impersonated { user, .. } => Some(user.id),
190 Principal::DevServer(_) => None,
191 }
192 }
193
194 fn is_staff(&self) -> bool {
195 match &self.principal {
196 Principal::User(user) => user.admin,
197 Principal::Impersonated { .. } => true,
198 Principal::DevServer(_) => false,
199 }
200 }
201
202 fn dev_server_id(&self) -> Option<DevServerId> {
203 match &self.principal {
204 Principal::User(_) | Principal::Impersonated { .. } => None,
205 Principal::DevServer(dev_server) => Some(dev_server.id),
206 }
207 }
208
209 fn principal_id(&self) -> PrincipalId {
210 match &self.principal {
211 Principal::User(user) => PrincipalId::UserId(user.id),
212 Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
213 Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
214 }
215 }
216}
217
218impl Debug for Session {
219 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
220 let mut result = f.debug_struct("Session");
221 match &self.principal {
222 Principal::User(user) => {
223 result.field("user", &user.github_login);
224 }
225 Principal::Impersonated { user, admin } => {
226 result.field("user", &user.github_login);
227 result.field("impersonator", &admin.github_login);
228 }
229 Principal::DevServer(dev_server) => {
230 result.field("dev_server", &dev_server.id);
231 }
232 }
233 result.field("connection_id", &self.connection_id).finish()
234 }
235}
236
237struct UserSession(Session);
238
239impl UserSession {
240 pub fn new(s: Session) -> Option<Self> {
241 s.user_id().map(|_| UserSession(s))
242 }
243 pub fn user_id(&self) -> UserId {
244 self.0.user_id().unwrap()
245 }
246
247 pub fn email(&self) -> Option<String> {
248 match &self.0.principal {
249 Principal::User(user) => user.email_address.clone(),
250 Principal::Impersonated { user, .. } => user.email_address.clone(),
251 Principal::DevServer(..) => None,
252 }
253 }
254}
255
256impl Deref for UserSession {
257 type Target = Session;
258
259 fn deref(&self) -> &Self::Target {
260 &self.0
261 }
262}
263impl DerefMut for UserSession {
264 fn deref_mut(&mut self) -> &mut Self::Target {
265 &mut self.0
266 }
267}
268
269struct DevServerSession(Session);
270
271impl DevServerSession {
272 pub fn new(s: Session) -> Option<Self> {
273 s.dev_server_id().map(|_| DevServerSession(s))
274 }
275 pub fn dev_server_id(&self) -> DevServerId {
276 self.0.dev_server_id().unwrap()
277 }
278
279 fn dev_server(&self) -> &dev_server::Model {
280 match &self.0.principal {
281 Principal::DevServer(dev_server) => dev_server,
282 _ => unreachable!(),
283 }
284 }
285}
286
287impl Deref for DevServerSession {
288 type Target = Session;
289
290 fn deref(&self) -> &Self::Target {
291 &self.0
292 }
293}
294impl DerefMut for DevServerSession {
295 fn deref_mut(&mut self) -> &mut Self::Target {
296 &mut self.0
297 }
298}
299
300fn user_handler<M: RequestMessage, Fut>(
301 handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
302) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
303where
304 Fut: Send + Future<Output = Result<()>>,
305{
306 let handler = Arc::new(handler);
307 move |message, response, session| {
308 let handler = handler.clone();
309 Box::pin(async move {
310 if let Some(user_session) = session.for_user() {
311 Ok(handler(message, response, user_session).await?)
312 } else {
313 Err(Error::Internal(anyhow!(
314 "must be a user to call {}",
315 M::NAME
316 )))
317 }
318 })
319 }
320}
321
322fn dev_server_handler<M: RequestMessage, Fut>(
323 handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
324) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
325where
326 Fut: Send + Future<Output = Result<()>>,
327{
328 let handler = Arc::new(handler);
329 move |message, response, session| {
330 let handler = handler.clone();
331 Box::pin(async move {
332 if let Some(dev_server_session) = session.for_dev_server() {
333 Ok(handler(message, response, dev_server_session).await?)
334 } else {
335 Err(Error::Internal(anyhow!(
336 "must be a dev server to call {}",
337 M::NAME
338 )))
339 }
340 })
341 }
342}
343
344fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
345 handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
346) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
347where
348 InnertRetFut: Send + Future<Output = Result<()>>,
349{
350 let handler = Arc::new(handler);
351 move |message, session| {
352 let handler = handler.clone();
353 Box::pin(async move {
354 if let Some(user_session) = session.for_user() {
355 Ok(handler(message, user_session).await?)
356 } else {
357 Err(Error::Internal(anyhow!(
358 "must be a user to call {}",
359 M::NAME
360 )))
361 }
362 })
363 }
364}
365
366struct DbHandle(Arc<Database>);
367
368impl Deref for DbHandle {
369 type Target = Database;
370
371 fn deref(&self) -> &Self::Target {
372 self.0.as_ref()
373 }
374}
375
/// RPC server state: the peer, the shared connection pool, application state, and the
/// table of registered message handlers.
pub struct Server {
    /// This instance's id; behind a mutex so test builds can replace it in `reset`.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Handlers keyed by the `TypeId` of the message type they handle.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Teardown flag; each connection task subscribes and checks it.
    teardown: watch::Sender<bool>,
}
384
/// Lock guard over the connection pool that is deliberately `!Send`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    /// `Rc` is `!Send`, so this marker makes the guard `!Send`: it cannot be held across
    /// an `.await` inside a future that must be `Send`.
    _not_send: PhantomData<Rc<()>>,
}
389
/// Serializable view of the server. It holds the connection-pool lock for as long as
/// the snapshot lives.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized through the guard's `Deref` target rather than the guard itself.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
396
397pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
398where
399 S: Serializer,
400 T: Deref<Target = U>,
401 U: Serialize,
402{
403 Serialize::serialize(value.deref(), serializer)
404}
405
406impl Server {
407 pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
408 let mut server = Self {
409 id: parking_lot::Mutex::new(id),
410 peer: Peer::new(id.0 as u32),
411 app_state: app_state.clone(),
412 connection_pool: Default::default(),
413 handlers: Default::default(),
414 teardown: watch::channel(false).0,
415 };
416
417 server
418 .add_request_handler(ping)
419 .add_request_handler(user_handler(create_room))
420 .add_request_handler(user_handler(join_room))
421 .add_request_handler(user_handler(rejoin_room))
422 .add_request_handler(user_handler(leave_room))
423 .add_request_handler(user_handler(set_room_participant_role))
424 .add_request_handler(user_handler(call))
425 .add_request_handler(user_handler(cancel_call))
426 .add_message_handler(user_message_handler(decline_call))
427 .add_request_handler(user_handler(update_participant_location))
428 .add_request_handler(user_handler(share_project))
429 .add_message_handler(unshare_project)
430 .add_request_handler(user_handler(join_project))
431 .add_request_handler(user_handler(join_hosted_project))
432 .add_request_handler(user_handler(rejoin_dev_server_projects))
433 .add_request_handler(user_handler(create_dev_server_project))
434 .add_request_handler(user_handler(delete_dev_server_project))
435 .add_request_handler(user_handler(create_dev_server))
436 .add_request_handler(user_handler(regenerate_dev_server_token))
437 .add_request_handler(user_handler(rename_dev_server))
438 .add_request_handler(user_handler(delete_dev_server))
439 .add_request_handler(dev_server_handler(share_dev_server_project))
440 .add_request_handler(dev_server_handler(shutdown_dev_server))
441 .add_request_handler(dev_server_handler(reconnect_dev_server))
442 .add_message_handler(user_message_handler(leave_project))
443 .add_request_handler(update_project)
444 .add_request_handler(update_worktree)
445 .add_message_handler(start_language_server)
446 .add_message_handler(update_language_server)
447 .add_message_handler(update_diagnostic_summary)
448 .add_message_handler(update_worktree_settings)
449 .add_request_handler(user_handler(
450 forward_project_request_for_owner::<proto::TaskContextForLocation>,
451 ))
452 .add_request_handler(user_handler(
453 forward_project_request_for_owner::<proto::TaskTemplates>,
454 ))
455 .add_request_handler(user_handler(
456 forward_read_only_project_request::<proto::GetHover>,
457 ))
458 .add_request_handler(user_handler(
459 forward_read_only_project_request::<proto::GetDefinition>,
460 ))
461 .add_request_handler(user_handler(
462 forward_read_only_project_request::<proto::GetTypeDefinition>,
463 ))
464 .add_request_handler(user_handler(
465 forward_read_only_project_request::<proto::GetReferences>,
466 ))
467 .add_request_handler(user_handler(
468 forward_read_only_project_request::<proto::SearchProject>,
469 ))
470 .add_request_handler(user_handler(
471 forward_read_only_project_request::<proto::GetDocumentHighlights>,
472 ))
473 .add_request_handler(user_handler(
474 forward_read_only_project_request::<proto::GetProjectSymbols>,
475 ))
476 .add_request_handler(user_handler(
477 forward_read_only_project_request::<proto::OpenBufferForSymbol>,
478 ))
479 .add_request_handler(user_handler(
480 forward_read_only_project_request::<proto::OpenBufferById>,
481 ))
482 .add_request_handler(user_handler(
483 forward_read_only_project_request::<proto::SynchronizeBuffers>,
484 ))
485 .add_request_handler(user_handler(
486 forward_read_only_project_request::<proto::InlayHints>,
487 ))
488 .add_request_handler(user_handler(
489 forward_read_only_project_request::<proto::OpenBufferByPath>,
490 ))
491 .add_request_handler(user_handler(
492 forward_mutating_project_request::<proto::GetCompletions>,
493 ))
494 .add_request_handler(user_handler(
495 forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
496 ))
497 .add_request_handler(user_handler(
498 forward_versioned_mutating_project_request::<proto::OpenNewBuffer>,
499 ))
500 .add_request_handler(user_handler(
501 forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
502 ))
503 .add_request_handler(user_handler(
504 forward_mutating_project_request::<proto::GetCodeActions>,
505 ))
506 .add_request_handler(user_handler(
507 forward_mutating_project_request::<proto::ApplyCodeAction>,
508 ))
509 .add_request_handler(user_handler(
510 forward_mutating_project_request::<proto::PrepareRename>,
511 ))
512 .add_request_handler(user_handler(
513 forward_mutating_project_request::<proto::PerformRename>,
514 ))
515 .add_request_handler(user_handler(
516 forward_mutating_project_request::<proto::ReloadBuffers>,
517 ))
518 .add_request_handler(user_handler(
519 forward_mutating_project_request::<proto::FormatBuffers>,
520 ))
521 .add_request_handler(user_handler(
522 forward_mutating_project_request::<proto::CreateProjectEntry>,
523 ))
524 .add_request_handler(user_handler(
525 forward_mutating_project_request::<proto::RenameProjectEntry>,
526 ))
527 .add_request_handler(user_handler(
528 forward_mutating_project_request::<proto::CopyProjectEntry>,
529 ))
530 .add_request_handler(user_handler(
531 forward_mutating_project_request::<proto::DeleteProjectEntry>,
532 ))
533 .add_request_handler(user_handler(
534 forward_mutating_project_request::<proto::ExpandProjectEntry>,
535 ))
536 .add_request_handler(user_handler(
537 forward_mutating_project_request::<proto::OnTypeFormatting>,
538 ))
539 .add_request_handler(user_handler(
540 forward_versioned_mutating_project_request::<proto::SaveBuffer>,
541 ))
542 .add_request_handler(user_handler(
543 forward_mutating_project_request::<proto::BlameBuffer>,
544 ))
545 .add_request_handler(user_handler(
546 forward_mutating_project_request::<proto::MultiLspQuery>,
547 ))
548 .add_request_handler(user_handler(
549 forward_mutating_project_request::<proto::RestartLanguageServers>,
550 ))
551 .add_message_handler(create_buffer_for_peer)
552 .add_request_handler(update_buffer)
553 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
554 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
555 .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
556 .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
557 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
558 .add_request_handler(get_users)
559 .add_request_handler(user_handler(fuzzy_search_users))
560 .add_request_handler(user_handler(request_contact))
561 .add_request_handler(user_handler(remove_contact))
562 .add_request_handler(user_handler(respond_to_contact_request))
563 .add_message_handler(subscribe_to_channels)
564 .add_request_handler(user_handler(create_channel))
565 .add_request_handler(user_handler(delete_channel))
566 .add_request_handler(user_handler(invite_channel_member))
567 .add_request_handler(user_handler(remove_channel_member))
568 .add_request_handler(user_handler(set_channel_member_role))
569 .add_request_handler(user_handler(set_channel_visibility))
570 .add_request_handler(user_handler(rename_channel))
571 .add_request_handler(user_handler(join_channel_buffer))
572 .add_request_handler(user_handler(leave_channel_buffer))
573 .add_message_handler(user_message_handler(update_channel_buffer))
574 .add_request_handler(user_handler(rejoin_channel_buffers))
575 .add_request_handler(user_handler(get_channel_members))
576 .add_request_handler(user_handler(respond_to_channel_invite))
577 .add_request_handler(user_handler(join_channel))
578 .add_request_handler(user_handler(join_channel_chat))
579 .add_message_handler(user_message_handler(leave_channel_chat))
580 .add_request_handler(user_handler(send_channel_message))
581 .add_request_handler(user_handler(remove_channel_message))
582 .add_request_handler(user_handler(update_channel_message))
583 .add_request_handler(user_handler(get_channel_messages))
584 .add_request_handler(user_handler(get_channel_messages_by_id))
585 .add_request_handler(user_handler(get_notifications))
586 .add_request_handler(user_handler(mark_notification_as_read))
587 .add_request_handler(user_handler(move_channel))
588 .add_request_handler(user_handler(follow))
589 .add_message_handler(user_message_handler(unfollow))
590 .add_message_handler(user_message_handler(update_followers))
591 .add_request_handler(user_handler(get_private_user_info))
592 .add_message_handler(user_message_handler(acknowledge_channel_message))
593 .add_message_handler(user_message_handler(acknowledge_buffer_version))
594 .add_request_handler(user_handler(get_supermaven_api_key))
595 .add_streaming_request_handler({
596 let app_state = app_state.clone();
597 move |request, response, session| {
598 complete_with_language_model(
599 request,
600 response,
601 session,
602 app_state.config.openai_api_key.clone(),
603 app_state.config.google_ai_api_key.clone(),
604 app_state.config.anthropic_api_key.clone(),
605 )
606 }
607 })
608 .add_request_handler({
609 let app_state = app_state.clone();
610 user_handler(move |request, response, session| {
611 count_tokens_with_language_model(
612 request,
613 response,
614 session,
615 app_state.config.google_ai_api_key.clone(),
616 )
617 })
618 })
619 .add_request_handler({
620 user_handler(move |request, response, session| {
621 get_cached_embeddings(request, response, session)
622 })
623 })
624 .add_request_handler({
625 let app_state = app_state.clone();
626 user_handler(move |request, response, session| {
627 compute_embeddings(
628 request,
629 response,
630 session,
631 app_state.config.openai_api_key.clone(),
632 )
633 })
634 });
635
636 Arc::new(server)
637 }
638
    /// Schedules background cleanup of resources left behind by stale server instances.
    ///
    /// Spawns a detached task (instrumented with the "start server" span) that:
    /// 1. sleeps for `CLEANUP_TIMEOUT`;
    /// 2. fetches room and channel ids via `stale_server_resource_ids` for this
    ///    environment and `server_id`;
    /// 3. for each channel, clears stale buffer collaborators and sends the refreshed
    ///    collaborator list to the buffer's remaining connections;
    /// 4. for each room, clears stale participants, broadcasts the room/channel update,
    ///    notifies users whose calls were canceled, pushes updated contact info to those
    ///    users' accepted contacts, and deletes the LiveKit room if it is now empty;
    /// 5. calls `delete_stale_servers`.
    ///
    /// Every database or send error goes through `trace_err()`, which turns it into
    /// `None`/ignored, so failures skip that item rather than abort the task. This
    /// method itself always returns `Ok(())` once the task is spawned.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // Created now, awaited inside the spawned task.
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            // Tell every remaining connection on this buffer who is still
                            // collaborating.
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Filled from the refreshed room below; they stay empty/false if
                        // clearing the room failed.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            // Only delete the LiveKit room once nobody is left in it.
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the pool guard is dropped before the awaits below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            // Skip this user if either lookup failed.
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                // Taken after the awaits above and released at the end of
                                // this block, before the next iteration awaits again.
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    // Only accepted contacts receive the update.
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                // Runs even if fetching the stale resource ids failed.
                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
784
785 pub fn teardown(&self) {
786 self.peer.teardown();
787 self.connection_pool.lock().reset();
788 let _ = self.teardown.send(true);
789 }
790
791 #[cfg(test)]
792 pub fn reset(&self, id: ServerId) {
793 self.teardown();
794 *self.id.lock() = id;
795 self.peer.reset(id.0 as u32);
796 let _ = self.teardown.send(false);
797 }
798
799 #[cfg(test)]
800 pub fn id(&self) -> ServerId {
801 *self.id.lock()
802 }
803
804 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
805 where
806 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
807 Fut: 'static + Send + Future<Output = Result<()>>,
808 M: EnvelopedMessage,
809 {
810 let prev_handler = self.handlers.insert(
811 TypeId::of::<M>(),
812 Box::new(move |envelope, session| {
813 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
814 let received_at = envelope.received_at;
815 tracing::info!("message received");
816 let start_time = Instant::now();
817 let future = (handler)(*envelope, session);
818 async move {
819 let result = future.await;
820 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
821 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
822 let queue_duration_ms = total_duration_ms - processing_duration_ms;
823 let payload_type = M::NAME;
824
825 match result {
826 Err(error) => {
827 tracing::error!(
828 ?error,
829 total_duration_ms,
830 processing_duration_ms,
831 queue_duration_ms,
832 payload_type,
833 "error handling message"
834 )
835 }
836 Ok(()) => tracing::info!(
837 total_duration_ms,
838 processing_duration_ms,
839 queue_duration_ms,
840 "finished handling message"
841 ),
842 }
843 }
844 .boxed()
845 }),
846 );
847 if prev_handler.is_some() {
848 panic!("registered a handler for the same message twice");
849 }
850 self
851 }
852
853 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
854 where
855 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
856 Fut: 'static + Send + Future<Output = Result<()>>,
857 M: EnvelopedMessage,
858 {
859 self.add_handler(move |envelope, session| handler(envelope.payload, session));
860 self
861 }
862
863 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
864 where
865 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
866 Fut: Send + Future<Output = Result<()>>,
867 M: RequestMessage,
868 {
869 let handler = Arc::new(handler);
870 self.add_handler(move |envelope, session| {
871 let receipt = envelope.receipt();
872 let handler = handler.clone();
873 async move {
874 let peer = session.peer.clone();
875 let responded = Arc::new(AtomicBool::default());
876 let response = Response {
877 peer: peer.clone(),
878 responded: responded.clone(),
879 receipt,
880 };
881 match (handler)(envelope.payload, response, session).await {
882 Ok(()) => {
883 if responded.load(std::sync::atomic::Ordering::SeqCst) {
884 Ok(())
885 } else {
886 Err(anyhow!("handler did not send a response"))?
887 }
888 }
889 Err(error) => {
890 let proto_err = match &error {
891 Error::Internal(err) => err.to_proto(),
892 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
893 };
894 peer.respond_with_error(receipt, proto_err)?;
895 Err(error)
896 }
897 }
898 }
899 })
900 }
901
902 fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
903 where
904 F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
905 Fut: Send + Future<Output = Result<()>>,
906 M: RequestMessage,
907 {
908 let handler = Arc::new(handler);
909 self.add_handler(move |envelope, session| {
910 let receipt = envelope.receipt();
911 let handler = handler.clone();
912 async move {
913 let peer = session.peer.clone();
914 let response = StreamingResponse {
915 peer: peer.clone(),
916 receipt,
917 };
918 match (handler)(envelope.payload, response, session).await {
919 Ok(()) => {
920 peer.end_stream(receipt)?;
921 Ok(())
922 }
923 Err(error) => {
924 let proto_err = match &error {
925 Error::Internal(err) => err.to_proto(),
926 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
927 };
928 peer.respond_with_error(receipt, proto_err)?;
929 Err(error)
930 }
931 }
932 }
933 })
934 }
935
936 #[allow(clippy::too_many_arguments)]
937 pub fn handle_connection(
938 self: &Arc<Self>,
939 connection: Connection,
940 address: String,
941 principal: Principal,
942 zed_version: ZedVersion,
943 send_connection_id: Option<oneshot::Sender<ConnectionId>>,
944 executor: Executor,
945 ) -> impl Future<Output = ()> {
946 let this = self.clone();
947 let span = info_span!("handle connection", %address,
948 connection_id=field::Empty,
949 user_id=field::Empty,
950 login=field::Empty,
951 impersonator=field::Empty,
952 dev_server_id=field::Empty
953 );
954 principal.update_span(&span);
955
956 let mut teardown = self.teardown.subscribe();
957 async move {
958 if *teardown.borrow() {
959 tracing::error!("server is tearing down");
960 return
961 }
962 let (connection_id, handle_io, mut incoming_rx) = this
963 .peer
964 .add_connection(connection, {
965 let executor = executor.clone();
966 move |duration| executor.sleep(duration)
967 });
968 tracing::Span::current().record("connection_id", format!("{}", connection_id));
969 tracing::info!("connection opened");
970
971 let http_client = match IsahcHttpClient::new() {
972 Ok(http_client) => Arc::new(http_client),
973 Err(error) => {
974 tracing::error!(?error, "failed to create HTTP client");
975 return;
976 }
977 };
978
979 let supermaven_client = if let Some(supermaven_admin_api_key) = this.app_state.config.supermaven_admin_api_key.clone() {
980 Some(Arc::new(SupermavenAdminApi::new(
981 supermaven_admin_api_key.to_string(),
982 http_client.clone(),
983 )))
984 } else {
985 None
986 };
987
988 let session = Session {
989 principal: principal.clone(),
990 connection_id,
991 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
992 peer: this.peer.clone(),
993 connection_pool: this.connection_pool.clone(),
994 live_kit_client: this.app_state.live_kit_client.clone(),
995 http_client,
996 rate_limiter: this.app_state.rate_limiter.clone(),
997 _executor: executor.clone(),
998 supermaven_client,
999 };
1000
1001 if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
1002 tracing::error!(?error, "failed to send initial client update");
1003 return;
1004 }
1005
1006 let handle_io = handle_io.fuse();
1007 futures::pin_mut!(handle_io);
1008
1009 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
1010 // This prevents deadlocks when e.g., client A performs a request to client B and
1011 // client B performs a request to client A. If both clients stop processing further
1012 // messages until their respective request completes, they won't have a chance to
1013 // respond to the other client's request and cause a deadlock.
1014 //
1015 // This arrangement ensures we will attempt to process earlier messages first, but fall
1016 // back to processing messages arrived later in the spirit of making progress.
1017 let mut foreground_message_handlers = FuturesUnordered::new();
1018 let concurrent_handlers = Arc::new(Semaphore::new(256));
1019 loop {
1020 let next_message = async {
1021 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
1022 let message = incoming_rx.next().await;
1023 (permit, message)
1024 }.fuse();
1025 futures::pin_mut!(next_message);
1026 futures::select_biased! {
1027 _ = teardown.changed().fuse() => return,
1028 result = handle_io => {
1029 if let Err(error) = result {
1030 tracing::error!(?error, "error handling I/O");
1031 }
1032 break;
1033 }
1034 _ = foreground_message_handlers.next() => {}
1035 next_message = next_message => {
1036 let (permit, message) = next_message;
1037 if let Some(message) = message {
1038 let type_name = message.payload_type_name();
1039 // note: we copy all the fields from the parent span so we can query them in the logs.
1040 // (https://github.com/tokio-rs/tracing/issues/2670).
1041 let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
1042 user_id=field::Empty,
1043 login=field::Empty,
1044 impersonator=field::Empty,
1045 dev_server_id=field::Empty
1046 );
1047 principal.update_span(&span);
1048 let span_enter = span.enter();
1049 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
1050 let is_background = message.is_background();
1051 let handle_message = (handler)(message, session.clone());
1052 drop(span_enter);
1053
1054 let handle_message = async move {
1055 handle_message.await;
1056 drop(permit);
1057 }.instrument(span);
1058 if is_background {
1059 executor.spawn_detached(handle_message);
1060 } else {
1061 foreground_message_handlers.push(handle_message);
1062 }
1063 } else {
1064 tracing::error!("no message handler");
1065 }
1066 } else {
1067 tracing::info!("connection closed");
1068 break;
1069 }
1070 }
1071 }
1072 }
1073
1074 drop(foreground_message_handlers);
1075 tracing::info!("signing out");
1076 if let Err(error) = connection_lost(session, teardown, executor).await {
1077 tracing::error!(?error, "error signing out");
1078 }
1079
1080 }.instrument(span)
1081 }
1082
    /// Performs the initial handshake for a freshly accepted connection.
    ///
    /// Always sends `proto::Hello` carrying the new peer id first and, if a sender
    /// was supplied, reports `connection_id` through `send_connection_id`. The rest
    /// depends on the principal:
    ///
    /// * Users (including impersonated ones): on the first-ever connection a
    ///   `ShowContacts` message is sent and `connected_once` is persisted. Contacts
    ///   and dev-server project status are then loaded concurrently; the connection
    ///   is registered in the pool and sent its initial contacts update under the
    ///   same pool lock; channel subscriptions are pushed when the client version
    ///   supports auto-subscription; dev-server project status is sent; any pending
    ///   incoming call is forwarded; and contacts are refreshed.
    /// * Dev servers: an existing connection for the same dev server is sent
    ///   `ShutdownDevServer` and removed from the pool before the new connection is
    ///   registered; the dev server then receives its `DevServerInstructions`, and
    ///   the dev-server project status for `dev_server.user_id` is refreshed.
    ///
    /// # Errors
    ///
    /// Returns the first send or database error encountered.
    async fn send_initial_client_update(
        &self,
        connection_id: ConnectionId,
        principal: &Principal,
        zed_version: ZedVersion,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        session: &Session,
    ) -> Result<()> {
        self.peer.send(
            connection_id,
            proto::Hello {
                peer_id: Some(connection_id.into()),
            },
        )?;
        tracing::info!("sent hello message");
        // The receiver may already be gone; the send result is intentionally ignored.
        if let Some(send_connection_id) = send_connection_id.take() {
            let _ = send_connection_id.send(connection_id);
        }

        match principal {
            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
                // First-ever connection for this user: ask the client to show contacts
                // and remember that it has now connected once.
                if !user.connected_once {
                    self.peer.send(connection_id, proto::ShowContacts {})?;
                    self.app_state
                        .db
                        .set_user_connected_once(user.id, true)
                        .await?;
                }

                let (contacts, dev_server_projects) = future::try_join(
                    self.app_state.db.get_contacts(user.id),
                    self.app_state.db.dev_server_projects_update(user.id),
                )
                .await?;

                // Register the connection and build the contacts update from the same
                // locked pool, so the update reflects the pool including this connection.
                {
                    let mut pool = self.connection_pool.lock();
                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
                    self.peer.send(
                        connection_id,
                        build_initial_contacts_update(contacts, &pool),
                    )?;
                }

                if should_auto_subscribe_to_channels(zed_version) {
                    subscribe_user_to_channels(user.id, session).await?;
                }

                send_dev_server_projects_update(user.id, dev_server_projects, session).await;

                if let Some(incoming_call) =
                    self.app_state.db.incoming_call_for_user(user.id).await?
                {
                    self.peer.send(connection_id, incoming_call)?;
                }

                update_user_contacts(user.id, &session).await?;
            }
            Principal::DevServer(dev_server) => {
                {
                    let mut pool = self.connection_pool.lock();
                    // Only one connection per dev server is kept: evict the previous one.
                    // NOTE(review): if sending `ShutdownDevServer` to the stale connection
                    // fails, `?` returns here and the new dev server is never registered —
                    // confirm this is intended.
                    if let Some(stale_connection_id) = pool.dev_server_connection_id(dev_server.id)
                    {
                        self.peer.send(
                            stale_connection_id,
                            proto::ShutdownDevServer {
                                reason: Some(
                                    "another dev server connected with the same token".to_string(),
                                ),
                            },
                        )?;
                        pool.remove_connection(stale_connection_id)?;
                    };
                    pool.add_dev_server(connection_id, dev_server.id, zed_version);
                }

                let projects = self
                    .app_state
                    .db
                    .get_projects_for_dev_server(dev_server.id)
                    .await?;
                self.peer
                    .send(connection_id, proto::DevServerInstructions { projects })?;

                let status = self
                    .app_state
                    .db
                    .dev_server_projects_update(dev_server.user_id)
                    .await?;
                send_dev_server_projects_update(dev_server.user_id, status, &session).await;
            }
        }

        Ok(())
    }
1178
1179 pub async fn invite_code_redeemed(
1180 self: &Arc<Self>,
1181 inviter_id: UserId,
1182 invitee_id: UserId,
1183 ) -> Result<()> {
1184 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1185 if let Some(code) = &user.invite_code {
1186 let pool = self.connection_pool.lock();
1187 let invitee_contact = contact_for_user(invitee_id, false, &pool);
1188 for connection_id in pool.user_connection_ids(inviter_id) {
1189 self.peer.send(
1190 connection_id,
1191 proto::UpdateContacts {
1192 contacts: vec![invitee_contact.clone()],
1193 ..Default::default()
1194 },
1195 )?;
1196 self.peer.send(
1197 connection_id,
1198 proto::UpdateInviteInfo {
1199 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1200 count: user.invite_count as u32,
1201 },
1202 )?;
1203 }
1204 }
1205 }
1206 Ok(())
1207 }
1208
1209 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1210 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1211 if let Some(invite_code) = &user.invite_code {
1212 let pool = self.connection_pool.lock();
1213 for connection_id in pool.user_connection_ids(user_id) {
1214 self.peer.send(
1215 connection_id,
1216 proto::UpdateInviteInfo {
1217 url: format!(
1218 "{}{}",
1219 self.app_state.config.invite_link_prefix, invite_code
1220 ),
1221 count: user.invite_count as u32,
1222 },
1223 )?;
1224 }
1225 }
1226 }
1227 Ok(())
1228 }
1229
1230 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1231 ServerSnapshot {
1232 connection_pool: ConnectionPoolGuard {
1233 guard: self.connection_pool.lock(),
1234 _not_send: PhantomData,
1235 },
1236 peer: &self.peer,
1237 }
1238 }
1239}
1240
1241impl<'a> Deref for ConnectionPoolGuard<'a> {
1242 type Target = ConnectionPool;
1243
1244 fn deref(&self) -> &Self::Target {
1245 &self.guard
1246 }
1247}
1248
1249impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1250 fn deref_mut(&mut self) -> &mut Self::Target {
1251 &mut self.guard
1252 }
1253}
1254
/// Verifies the pool's invariants whenever a guard is released, in test builds only.
impl<'a> Drop for ConnectionPoolGuard<'a> {
    fn drop(&mut self) {
        // Outside `cfg(test)` this body is empty.
        #[cfg(test)]
        self.check_invariants();
    }
}
1261
1262fn broadcast<F>(
1263 sender_id: Option<ConnectionId>,
1264 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1265 mut f: F,
1266) where
1267 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1268{
1269 for receiver_id in receiver_ids {
1270 if Some(receiver_id) != sender_id {
1271 if let Err(error) = f(receiver_id) {
1272 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1273 }
1274 }
1275 }
1276}
1277
1278pub struct ProtocolVersion(u32);
1279
1280impl Header for ProtocolVersion {
1281 fn name() -> &'static HeaderName {
1282 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1283 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1284 }
1285
1286 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1287 where
1288 Self: Sized,
1289 I: Iterator<Item = &'i axum::http::HeaderValue>,
1290 {
1291 let version = values
1292 .next()
1293 .ok_or_else(axum::headers::Error::invalid)?
1294 .to_str()
1295 .map_err(|_| axum::headers::Error::invalid())?
1296 .parse()
1297 .map_err(|_| axum::headers::Error::invalid())?;
1298 Ok(Self(version))
1299 }
1300
1301 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1302 values.extend([self.0.to_string().parse().unwrap()]);
1303 }
1304}
1305
1306pub struct AppVersionHeader(SemanticVersion);
1307impl Header for AppVersionHeader {
1308 fn name() -> &'static HeaderName {
1309 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1310 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1311 }
1312
1313 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1314 where
1315 Self: Sized,
1316 I: Iterator<Item = &'i axum::http::HeaderValue>,
1317 {
1318 let version = values
1319 .next()
1320 .ok_or_else(axum::headers::Error::invalid)?
1321 .to_str()
1322 .map_err(|_| axum::headers::Error::invalid())?
1323 .parse()
1324 .map_err(|_| axum::headers::Error::invalid())?;
1325 Ok(Self(version))
1326 }
1327
1328 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1329 values.extend([self.0.to_string().parse().unwrap()]);
1330 }
1331}
1332
/// Builds the collab RPC router.
///
/// `/rpc` is wrapped in a layer stack that provides the `AppState` extension and
/// runs `auth::validate_header`. axum's `Router::layer` applies only to routes
/// registered before it, so `/metrics` (added afterwards) is not behind that auth
/// middleware. The final `Extension(server)` layer applies to both routes.
pub fn routes(server: Arc<Server>) -> Router<(), Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        // Registered after the auth layer above, so it is not authenticated by it.
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
1344
1345pub async fn handle_websocket_request(
1346 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1347 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1348 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1349 Extension(server): Extension<Arc<Server>>,
1350 Extension(principal): Extension<Principal>,
1351 ws: WebSocketUpgrade,
1352) -> axum::response::Response {
1353 if protocol_version != rpc::PROTOCOL_VERSION {
1354 return (
1355 StatusCode::UPGRADE_REQUIRED,
1356 "client must be upgraded".to_string(),
1357 )
1358 .into_response();
1359 }
1360
1361 let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1362 return (
1363 StatusCode::UPGRADE_REQUIRED,
1364 "no version header found".to_string(),
1365 )
1366 .into_response();
1367 };
1368
1369 if !version.can_collaborate() {
1370 return (
1371 StatusCode::UPGRADE_REQUIRED,
1372 "client must be upgraded".to_string(),
1373 )
1374 .into_response();
1375 }
1376
1377 let socket_address = socket_address.to_string();
1378 ws.on_upgrade(move |socket| {
1379 let socket = socket
1380 .map_ok(to_tungstenite_message)
1381 .err_into()
1382 .with(|message| async move { Ok(to_axum_message(message)) });
1383 let connection = Connection::new(Box::pin(socket));
1384 async move {
1385 server
1386 .handle_connection(
1387 connection,
1388 socket_address,
1389 principal,
1390 version,
1391 None,
1392 Executor::Production,
1393 )
1394 .await;
1395 }
1396 })
1397}
1398
1399pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1400 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1401 let connections_metric = CONNECTIONS_METRIC
1402 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1403
1404 let connections = server
1405 .connection_pool
1406 .lock()
1407 .connections()
1408 .filter(|connection| !connection.admin)
1409 .count();
1410 connections_metric.set(connections as _);
1411
1412 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1413 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1414 register_int_gauge!(
1415 "shared_projects",
1416 "number of open projects with one or more guests"
1417 )
1418 .unwrap()
1419 });
1420
1421 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1422 shared_projects_metric.set(shared_projects as _);
1423
1424 let encoder = prometheus::TextEncoder::new();
1425 let metric_families = prometheus::gather();
1426 let encoded_metrics = encoder
1427 .encode_to_string(&metric_families)
1428 .map_err(|err| anyhow!("{}", err))?;
1429 Ok(encoded_metrics)
1430}
1431
/// Cleans up after a connection ends.
///
/// Immediately disconnects the peer, removes the connection from the pool, and
/// records the lost connection in the database (a database error here is only
/// logged via `trace_err`). It then waits `RECONNECT_TIMEOUT` — presumably to give
/// the client time to reconnect — before releasing longer-lived resources:
///
/// * users: leave the room and channel buffers, decline any pending call if the
///   user has no other online connection, and refresh contacts;
/// * dev servers: run `lost_dev_server_connection`.
///
/// If `teardown` changes before the timeout elapses, that delayed cleanup is skipped.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // Biased: if both are ready, the timeout branch is polled first.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            match &session.principal {
                Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
                    let session = session.for_user().unwrap();

                    log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
                    leave_room_for_session(&session, session.connection_id).await.trace_err();
                    leave_channel_buffers_for_session(&session)
                        .await
                        .trace_err();

                    // Only decline the pending call when no other connection of this user is online.
                    if !session
                        .connection_pool()
                        .await
                        .is_user_online(session.user_id())
                    {
                        let db = session.db().await;
                        if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                            room_updated(&room, &session.peer);
                        }
                    }

                    update_user_contacts(session.user_id(), &session).await?;
                },
                Principal::DevServer(_) => {
                    lost_dev_server_connection(&session.for_dev_server().unwrap()).await?;
                },
            }
        },
        // Server is shutting down: skip the delayed cleanup.
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1486
1487/// Acknowledges a ping from a client, used to keep the connection alive.
1488async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1489 response.send(proto::Ack {})?;
1490 Ok(())
1491}
1492
1493/// Creates a new room for calling (outside of channels)
1494async fn create_room(
1495 _request: proto::CreateRoom,
1496 response: Response<proto::CreateRoom>,
1497 session: UserSession,
1498) -> Result<()> {
1499 let live_kit_room = nanoid::nanoid!(30);
1500
1501 let live_kit_connection_info = util::maybe!(async {
1502 let live_kit = session.live_kit_client.as_ref();
1503 let live_kit = live_kit?;
1504 let user_id = session.user_id().to_string();
1505
1506 let token = live_kit
1507 .room_token(&live_kit_room, &user_id.to_string())
1508 .trace_err()?;
1509
1510 Some(proto::LiveKitConnectionInfo {
1511 server_url: live_kit.url().into(),
1512 token,
1513 can_publish: true,
1514 })
1515 })
1516 .await;
1517
1518 let room = session
1519 .db()
1520 .await
1521 .create_room(session.user_id(), session.connection_id, &live_kit_room)
1522 .await?;
1523
1524 response.send(proto::CreateRoomResponse {
1525 room: Some(room.clone()),
1526 live_kit_connection_info,
1527 })?;
1528
1529 update_user_contacts(session.user_id(), &session).await?;
1530 Ok(())
1531}
1532
1533/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1534async fn join_room(
1535 request: proto::JoinRoom,
1536 response: Response<proto::JoinRoom>,
1537 session: UserSession,
1538) -> Result<()> {
1539 let room_id = RoomId::from_proto(request.id);
1540
1541 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1542
1543 if let Some(channel_id) = channel_id {
1544 return join_channel_internal(channel_id, Box::new(response), session).await;
1545 }
1546
1547 let joined_room = {
1548 let room = session
1549 .db()
1550 .await
1551 .join_room(room_id, session.user_id(), session.connection_id)
1552 .await?;
1553 room_updated(&room.room, &session.peer);
1554 room.into_inner()
1555 };
1556
1557 for connection_id in session
1558 .connection_pool()
1559 .await
1560 .user_connection_ids(session.user_id())
1561 {
1562 session
1563 .peer
1564 .send(
1565 connection_id,
1566 proto::CallCanceled {
1567 room_id: room_id.to_proto(),
1568 },
1569 )
1570 .trace_err();
1571 }
1572
1573 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1574 if let Some(token) = live_kit
1575 .room_token(
1576 &joined_room.room.live_kit_room,
1577 &session.user_id().to_string(),
1578 )
1579 .trace_err()
1580 {
1581 Some(proto::LiveKitConnectionInfo {
1582 server_url: live_kit.url().into(),
1583 token,
1584 can_publish: true,
1585 })
1586 } else {
1587 None
1588 }
1589 } else {
1590 None
1591 };
1592
1593 response.send(proto::JoinRoomResponse {
1594 room: Some(joined_room.room),
1595 channel_id: None,
1596 live_kit_connection_info,
1597 })?;
1598
1599 update_user_contacts(session.user_id(), &session).await?;
1600 Ok(())
1601}
1602
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Restores the caller's room membership in the database and replies with the
/// room, the projects it re-shared, and the projects it re-joined. Then it
/// broadcasts the room update. For each re-shared project, collaborators are
/// told the caller's peer id changed and are forwarded the project's worktree
/// list. State for every re-joined project is streamed back to this client.
/// Finally, the channel (if the room belongs to one) and the caller's contacts
/// are refreshed.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: UserSession,
) -> Result<()> {
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell each collaborator that this client's peer id changed.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the project's worktree list to everyone except this client.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        // Unwrap the database result at the end of this scope so only plain
        // `room`/`channel` values are used by the follow-up updates below.
        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1694
1695fn notify_rejoined_projects(
1696 rejoined_projects: &mut Vec<RejoinedProject>,
1697 session: &UserSession,
1698) -> Result<()> {
1699 for project in rejoined_projects.iter() {
1700 for collaborator in &project.collaborators {
1701 session
1702 .peer
1703 .send(
1704 collaborator.connection_id,
1705 proto::UpdateProjectCollaborator {
1706 project_id: project.id.to_proto(),
1707 old_peer_id: Some(project.old_connection_id.into()),
1708 new_peer_id: Some(session.connection_id.into()),
1709 },
1710 )
1711 .trace_err();
1712 }
1713 }
1714
1715 for project in rejoined_projects {
1716 for worktree in mem::take(&mut project.worktrees) {
1717 #[cfg(any(test, feature = "test-support"))]
1718 const MAX_CHUNK_SIZE: usize = 2;
1719 #[cfg(not(any(test, feature = "test-support")))]
1720 const MAX_CHUNK_SIZE: usize = 256;
1721
1722 // Stream this worktree's entries.
1723 let message = proto::UpdateWorktree {
1724 project_id: project.id.to_proto(),
1725 worktree_id: worktree.id,
1726 abs_path: worktree.abs_path.clone(),
1727 root_name: worktree.root_name,
1728 updated_entries: worktree.updated_entries,
1729 removed_entries: worktree.removed_entries,
1730 scan_id: worktree.scan_id,
1731 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1732 updated_repositories: worktree.updated_repositories,
1733 removed_repositories: worktree.removed_repositories,
1734 };
1735 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1736 session.peer.send(session.connection_id, update.clone())?;
1737 }
1738
1739 // Stream this worktree's diagnostics.
1740 for summary in worktree.diagnostic_summaries {
1741 session.peer.send(
1742 session.connection_id,
1743 proto::UpdateDiagnosticSummary {
1744 project_id: project.id.to_proto(),
1745 worktree_id: worktree.id,
1746 summary: Some(summary),
1747 },
1748 )?;
1749 }
1750
1751 for settings_file in worktree.settings_files {
1752 session.peer.send(
1753 session.connection_id,
1754 proto::UpdateWorktreeSettings {
1755 project_id: project.id.to_proto(),
1756 worktree_id: worktree.id,
1757 path: settings_file.path,
1758 content: Some(settings_file.content),
1759 },
1760 )?;
1761 }
1762 }
1763
1764 for language_server in &project.language_servers {
1765 session.peer.send(
1766 session.connection_id,
1767 proto::UpdateLanguageServer {
1768 project_id: project.id.to_proto(),
1769 language_server_id: language_server.id,
1770 variant: Some(
1771 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1772 proto::LspDiskBasedDiagnosticsUpdated {},
1773 ),
1774 ),
1775 },
1776 )?;
1777 }
1778 }
1779 Ok(())
1780}
1781
1782/// leave room disconnects from the room.
1783async fn leave_room(
1784 _: proto::LeaveRoom,
1785 response: Response<proto::LeaveRoom>,
1786 session: UserSession,
1787) -> Result<()> {
1788 leave_room_for_session(&session, session.connection_id).await?;
1789 response.send(proto::Ack {})?;
1790 Ok(())
1791}
1792
1793/// Updates the permissions of someone else in the room.
1794async fn set_room_participant_role(
1795 request: proto::SetRoomParticipantRole,
1796 response: Response<proto::SetRoomParticipantRole>,
1797 session: UserSession,
1798) -> Result<()> {
1799 let user_id = UserId::from_proto(request.user_id);
1800 let role = ChannelRole::from(request.role());
1801
1802 let (live_kit_room, can_publish) = {
1803 let room = session
1804 .db()
1805 .await
1806 .set_room_participant_role(
1807 session.user_id(),
1808 RoomId::from_proto(request.room_id),
1809 user_id,
1810 role,
1811 )
1812 .await?;
1813
1814 let live_kit_room = room.live_kit_room.clone();
1815 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1816 room_updated(&room, &session.peer);
1817 (live_kit_room, can_publish)
1818 };
1819
1820 if let Some(live_kit) = session.live_kit_client.as_ref() {
1821 live_kit
1822 .update_participant(
1823 live_kit_room.clone(),
1824 request.user_id.to_string(),
1825 live_kit_server::proto::ParticipantPermission {
1826 can_subscribe: true,
1827 can_publish,
1828 can_publish_data: can_publish,
1829 hidden: false,
1830 recorder: false,
1831 },
1832 )
1833 .await
1834 .trace_err();
1835 }
1836
1837 response.send(proto::Ack {})?;
1838 Ok(())
1839}
1840
/// Call someone else into the current room
///
/// Only contacts can be called. The call is recorded in the database and the room
/// update is broadcast before ringing. Then every connection of the called user is
/// sent the incoming call concurrently. The first connection whose request returns
/// `Ok` causes an `Ack` to be sent and the handler to return. If no connection
/// succeeds (including when the user has none), the call is marked failed, the room
/// and contacts are updated, and an error is returned.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: UserSession,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    // Move the incoming-call message out of the database result so the result
    // can be dropped at the end of this block, before ringing starts.
    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(&room, &session.peer);
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring all of the called user's connections concurrently.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    // No connection accepted the call.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1909
1910/// Cancel an outgoing call.
1911async fn cancel_call(
1912 request: proto::CancelCall,
1913 response: Response<proto::CancelCall>,
1914 session: UserSession,
1915) -> Result<()> {
1916 let called_user_id = UserId::from_proto(request.called_user_id);
1917 let room_id = RoomId::from_proto(request.room_id);
1918 {
1919 let room = session
1920 .db()
1921 .await
1922 .cancel_call(room_id, session.connection_id, called_user_id)
1923 .await?;
1924 room_updated(&room, &session.peer);
1925 }
1926
1927 for connection_id in session
1928 .connection_pool()
1929 .await
1930 .user_connection_ids(called_user_id)
1931 {
1932 session
1933 .peer
1934 .send(
1935 connection_id,
1936 proto::CallCanceled {
1937 room_id: room_id.to_proto(),
1938 },
1939 )
1940 .trace_err();
1941 }
1942 response.send(proto::Ack {})?;
1943
1944 update_user_contacts(called_user_id, &session).await?;
1945 Ok(())
1946}
1947
1948/// Decline an incoming call.
1949async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1950 let room_id = RoomId::from_proto(message.room_id);
1951 {
1952 let room = session
1953 .db()
1954 .await
1955 .decline_call(Some(room_id), session.user_id())
1956 .await?
1957 .ok_or_else(|| anyhow!("failed to decline call"))?;
1958 room_updated(&room, &session.peer);
1959 }
1960
1961 for connection_id in session
1962 .connection_pool()
1963 .await
1964 .user_connection_ids(session.user_id())
1965 {
1966 session
1967 .peer
1968 .send(
1969 connection_id,
1970 proto::CallCanceled {
1971 room_id: room_id.to_proto(),
1972 },
1973 )
1974 .trace_err();
1975 }
1976 update_user_contacts(session.user_id(), &session).await?;
1977 Ok(())
1978}
1979
1980/// Updates other participants in the room with your current location.
1981async fn update_participant_location(
1982 request: proto::UpdateParticipantLocation,
1983 response: Response<proto::UpdateParticipantLocation>,
1984 session: UserSession,
1985) -> Result<()> {
1986 let room_id = RoomId::from_proto(request.room_id);
1987 let location = request
1988 .location
1989 .ok_or_else(|| anyhow!("invalid location"))?;
1990
1991 let db = session.db().await;
1992 let room = db
1993 .update_room_participant_location(room_id, session.connection_id, location)
1994 .await?;
1995
1996 room_updated(&room, &session.peer);
1997 response.send(proto::Ack {})?;
1998 Ok(())
1999}
2000
2001/// Share a project into the room.
2002async fn share_project(
2003 request: proto::ShareProject,
2004 response: Response<proto::ShareProject>,
2005 session: UserSession,
2006) -> Result<()> {
2007 let (project_id, room) = &*session
2008 .db()
2009 .await
2010 .share_project(
2011 RoomId::from_proto(request.room_id),
2012 session.connection_id,
2013 &request.worktrees,
2014 request
2015 .dev_server_project_id
2016 .map(|id| DevServerProjectId::from_proto(id)),
2017 )
2018 .await?;
2019 response.send(proto::ShareProjectResponse {
2020 project_id: project_id.to_proto(),
2021 })?;
2022 room_updated(&room, &session.peer);
2023
2024 Ok(())
2025}
2026
2027/// Unshare a project from the room.
2028async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
2029 let project_id = ProjectId::from_proto(message.project_id);
2030 unshare_project_internal(
2031 project_id,
2032 session.connection_id,
2033 session.user_id(),
2034 &session,
2035 )
2036 .await
2037}
2038
/// Unshares `project_id` on behalf of `connection_id` / `user_id`.
///
/// Sends `UnshareProject` to every guest connection except `connection_id`. The
/// room update is broadcast when the database returns a room. When the database
/// reports the project should be deleted, it is deleted afterwards using a fresh
/// database handle.
async fn unshare_project_internal(
    project_id: ProjectId,
    connection_id: ConnectionId,
    user_id: Option<UserId>,
    session: &Session,
) -> Result<()> {
    // Scoped so the value returned by `unshare_project` is dropped before
    // `delete_project` runs below; only the `delete` flag escapes the block.
    let delete = {
        let room_guard = session
            .db()
            .await
            .unshare_project(project_id, connection_id, user_id)
            .await?;

        let (delete, room, guest_connection_ids) = &*room_guard;

        let message = proto::UnshareProject {
            project_id: project_id.to_proto(),
        };

        broadcast(
            Some(connection_id),
            guest_connection_ids.iter().copied(),
            |conn_id| session.peer.send(conn_id, message.clone()),
        );
        if let Some(room) = room {
            room_updated(room, &session.peer);
        }

        *delete
    };

    if delete {
        let db = session.db().await;
        db.delete_project(project_id).await?;
    }

    Ok(())
}
2077
2078/// DevServer makes a project available online
2079async fn share_dev_server_project(
2080 request: proto::ShareDevServerProject,
2081 response: Response<proto::ShareDevServerProject>,
2082 session: DevServerSession,
2083) -> Result<()> {
2084 let (dev_server_project, user_id, status) = session
2085 .db()
2086 .await
2087 .share_dev_server_project(
2088 DevServerProjectId::from_proto(request.dev_server_project_id),
2089 session.dev_server_id(),
2090 session.connection_id,
2091 &request.worktrees,
2092 )
2093 .await?;
2094 let Some(project_id) = dev_server_project.project_id else {
2095 return Err(anyhow!("failed to share remote project"))?;
2096 };
2097
2098 send_dev_server_projects_update(user_id, status, &session).await;
2099
2100 response.send(proto::ShareProjectResponse { project_id })?;
2101
2102 Ok(())
2103}
2104
/// Join someone else's shared project.
///
/// Adds this connection to the project in the database, which yields the project
/// state and the replica id assigned to this guest. It then releases the session's
/// database handle and hands off to `join_project_internal`, which notifies
/// collaborators and sends the project to the client.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: UserSession,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    // `&mut *` borrows from the value returned by `join_project`; that temporary
    // lives to the end of the function, independently of `db`.
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    // Release the session's database handle before the sends in `join_project_internal`.
    drop(db);
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
2123
/// A response handle that can deliver a `JoinProjectResponse`, letting
/// `join_project_internal` serve both `JoinProject` and `JoinHostedProject` requests.
trait JoinProjectInternalResponse {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Delegates to `Response::send`; the explicit path makes clear this is not
        // a recursive call to the trait method.
        Response::<proto::JoinProject>::send(self, result)
    }
}
impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Delegates to `Response::send`; see the impl above.
        Response::<proto::JoinHostedProject>::send(self, result)
    }
}
2137
2138fn join_project_internal(
2139 response: impl JoinProjectInternalResponse,
2140 session: UserSession,
2141 project: &mut Project,
2142 replica_id: &ReplicaId,
2143) -> Result<()> {
2144 let collaborators = project
2145 .collaborators
2146 .iter()
2147 .filter(|collaborator| collaborator.connection_id != session.connection_id)
2148 .map(|collaborator| collaborator.to_proto())
2149 .collect::<Vec<_>>();
2150 let project_id = project.id;
2151 let guest_user_id = session.user_id();
2152
2153 let worktrees = project
2154 .worktrees
2155 .iter()
2156 .map(|(id, worktree)| proto::WorktreeMetadata {
2157 id: *id,
2158 root_name: worktree.root_name.clone(),
2159 visible: worktree.visible,
2160 abs_path: worktree.abs_path.clone(),
2161 })
2162 .collect::<Vec<_>>();
2163
2164 let add_project_collaborator = proto::AddProjectCollaborator {
2165 project_id: project_id.to_proto(),
2166 collaborator: Some(proto::Collaborator {
2167 peer_id: Some(session.connection_id.into()),
2168 replica_id: replica_id.0 as u32,
2169 user_id: guest_user_id.to_proto(),
2170 }),
2171 };
2172
2173 for collaborator in &collaborators {
2174 session
2175 .peer
2176 .send(
2177 collaborator.peer_id.unwrap().into(),
2178 add_project_collaborator.clone(),
2179 )
2180 .trace_err();
2181 }
2182
2183 // First, we send the metadata associated with each worktree.
2184 response.send(proto::JoinProjectResponse {
2185 project_id: project.id.0 as u64,
2186 worktrees: worktrees.clone(),
2187 replica_id: replica_id.0 as u32,
2188 collaborators: collaborators.clone(),
2189 language_servers: project.language_servers.clone(),
2190 role: project.role.into(),
2191 dev_server_project_id: project
2192 .dev_server_project_id
2193 .map(|dev_server_project_id| dev_server_project_id.0 as u64),
2194 })?;
2195
2196 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2197 #[cfg(any(test, feature = "test-support"))]
2198 const MAX_CHUNK_SIZE: usize = 2;
2199 #[cfg(not(any(test, feature = "test-support")))]
2200 const MAX_CHUNK_SIZE: usize = 256;
2201
2202 // Stream this worktree's entries.
2203 let message = proto::UpdateWorktree {
2204 project_id: project_id.to_proto(),
2205 worktree_id,
2206 abs_path: worktree.abs_path.clone(),
2207 root_name: worktree.root_name,
2208 updated_entries: worktree.entries,
2209 removed_entries: Default::default(),
2210 scan_id: worktree.scan_id,
2211 is_last_update: worktree.scan_id == worktree.completed_scan_id,
2212 updated_repositories: worktree.repository_entries.into_values().collect(),
2213 removed_repositories: Default::default(),
2214 };
2215 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
2216 session.peer.send(session.connection_id, update.clone())?;
2217 }
2218
2219 // Stream this worktree's diagnostics.
2220 for summary in worktree.diagnostic_summaries {
2221 session.peer.send(
2222 session.connection_id,
2223 proto::UpdateDiagnosticSummary {
2224 project_id: project_id.to_proto(),
2225 worktree_id: worktree.id,
2226 summary: Some(summary),
2227 },
2228 )?;
2229 }
2230
2231 for settings_file in worktree.settings_files {
2232 session.peer.send(
2233 session.connection_id,
2234 proto::UpdateWorktreeSettings {
2235 project_id: project_id.to_proto(),
2236 worktree_id: worktree.id,
2237 path: settings_file.path,
2238 content: Some(settings_file.content),
2239 },
2240 )?;
2241 }
2242 }
2243
2244 for language_server in &project.language_servers {
2245 session.peer.send(
2246 session.connection_id,
2247 proto::UpdateLanguageServer {
2248 project_id: project_id.to_proto(),
2249 language_server_id: language_server.id,
2250 variant: Some(
2251 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2252 proto::LspDiskBasedDiagnosticsUpdated {},
2253 ),
2254 ),
2255 },
2256 )?;
2257 }
2258
2259 Ok(())
2260}
2261
2262/// Leave someone elses shared project.
2263async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2264 let sender_id = session.connection_id;
2265 let project_id = ProjectId::from_proto(request.project_id);
2266 let db = session.db().await;
2267 if db.is_hosted_project(project_id).await? {
2268 let project = db.leave_hosted_project(project_id, sender_id).await?;
2269 project_left(&project, &session);
2270 return Ok(());
2271 }
2272
2273 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2274 tracing::info!(
2275 %project_id,
2276 "leave project"
2277 );
2278
2279 project_left(&project, &session);
2280 if let Some(room) = room {
2281 room_updated(&room, &session.peer);
2282 }
2283
2284 Ok(())
2285}
2286
2287async fn join_hosted_project(
2288 request: proto::JoinHostedProject,
2289 response: Response<proto::JoinHostedProject>,
2290 session: UserSession,
2291) -> Result<()> {
2292 let (mut project, replica_id) = session
2293 .db()
2294 .await
2295 .join_hosted_project(
2296 ProjectId(request.project_id as i32),
2297 session.user_id(),
2298 session.connection_id,
2299 )
2300 .await?;
2301
2302 join_project_internal(response, session, &mut project, &replica_id)
2303}
2304
2305async fn create_dev_server_project(
2306 request: proto::CreateDevServerProject,
2307 response: Response<proto::CreateDevServerProject>,
2308 session: UserSession,
2309) -> Result<()> {
2310 let dev_server_id = DevServerId(request.dev_server_id as i32);
2311 let dev_server_connection_id = session
2312 .connection_pool()
2313 .await
2314 .dev_server_connection_id(dev_server_id);
2315 let Some(dev_server_connection_id) = dev_server_connection_id else {
2316 Err(ErrorCode::DevServerOffline
2317 .message("Cannot create a remote project when the dev server is offline".to_string())
2318 .anyhow())?
2319 };
2320
2321 let path = request.path.clone();
2322 //Check that the path exists on the dev server
2323 session
2324 .peer
2325 .forward_request(
2326 session.connection_id,
2327 dev_server_connection_id,
2328 proto::ValidateDevServerProjectRequest { path: path.clone() },
2329 )
2330 .await?;
2331
2332 let (dev_server_project, update) = session
2333 .db()
2334 .await
2335 .create_dev_server_project(
2336 DevServerId(request.dev_server_id as i32),
2337 &request.path,
2338 session.user_id(),
2339 )
2340 .await?;
2341
2342 let projects = session
2343 .db()
2344 .await
2345 .get_projects_for_dev_server(dev_server_project.dev_server_id)
2346 .await?;
2347
2348 session.peer.send(
2349 dev_server_connection_id,
2350 proto::DevServerInstructions { projects },
2351 )?;
2352
2353 send_dev_server_projects_update(session.user_id(), update, &session).await;
2354
2355 response.send(proto::CreateDevServerProjectResponse {
2356 dev_server_project: Some(dev_server_project.to_proto(None)),
2357 })?;
2358 Ok(())
2359}
2360
2361async fn create_dev_server(
2362 request: proto::CreateDevServer,
2363 response: Response<proto::CreateDevServer>,
2364 session: UserSession,
2365) -> Result<()> {
2366 let access_token = auth::random_token();
2367 let hashed_access_token = auth::hash_access_token(&access_token);
2368
2369 if request.name.is_empty() {
2370 return Err(proto::ErrorCode::Forbidden
2371 .message("Dev server name cannot be empty".to_string())
2372 .anyhow())?;
2373 }
2374
2375 let (dev_server, status) = session
2376 .db()
2377 .await
2378 .create_dev_server(
2379 &request.name,
2380 request.ssh_connection_string.as_deref(),
2381 &hashed_access_token,
2382 session.user_id(),
2383 )
2384 .await?;
2385
2386 send_dev_server_projects_update(session.user_id(), status, &session).await;
2387
2388 response.send(proto::CreateDevServerResponse {
2389 dev_server_id: dev_server.id.0 as u64,
2390 access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2391 name: request.name,
2392 })?;
2393 Ok(())
2394}
2395
2396async fn regenerate_dev_server_token(
2397 request: proto::RegenerateDevServerToken,
2398 response: Response<proto::RegenerateDevServerToken>,
2399 session: UserSession,
2400) -> Result<()> {
2401 let dev_server_id = DevServerId(request.dev_server_id as i32);
2402 let access_token = auth::random_token();
2403 let hashed_access_token = auth::hash_access_token(&access_token);
2404
2405 let connection_id = session
2406 .connection_pool()
2407 .await
2408 .dev_server_connection_id(dev_server_id);
2409 if let Some(connection_id) = connection_id {
2410 shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2411 session.peer.send(
2412 connection_id,
2413 proto::ShutdownDevServer {
2414 reason: Some("dev server token was regenerated".to_string()),
2415 },
2416 )?;
2417 let _ = remove_dev_server_connection(dev_server_id, &session).await;
2418 }
2419
2420 let status = session
2421 .db()
2422 .await
2423 .update_dev_server_token(dev_server_id, &hashed_access_token, session.user_id())
2424 .await?;
2425
2426 send_dev_server_projects_update(session.user_id(), status, &session).await;
2427
2428 response.send(proto::RegenerateDevServerTokenResponse {
2429 dev_server_id: dev_server_id.to_proto(),
2430 access_token: auth::generate_dev_server_token(dev_server_id.0 as usize, access_token),
2431 })?;
2432 Ok(())
2433}
2434
2435async fn rename_dev_server(
2436 request: proto::RenameDevServer,
2437 response: Response<proto::RenameDevServer>,
2438 session: UserSession,
2439) -> Result<()> {
2440 if request.name.trim().is_empty() {
2441 return Err(proto::ErrorCode::Forbidden
2442 .message("Dev server name cannot be empty".to_string())
2443 .anyhow())?;
2444 }
2445
2446 let dev_server_id = DevServerId(request.dev_server_id as i32);
2447 let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2448 if dev_server.user_id != session.user_id() {
2449 return Err(anyhow!(ErrorCode::Forbidden))?;
2450 }
2451
2452 let status = session
2453 .db()
2454 .await
2455 .rename_dev_server(
2456 dev_server_id,
2457 &request.name,
2458 request.ssh_connection_string.as_deref(),
2459 session.user_id(),
2460 )
2461 .await?;
2462
2463 send_dev_server_projects_update(session.user_id(), status, &session).await;
2464
2465 response.send(proto::Ack {})?;
2466 Ok(())
2467}
2468
2469async fn delete_dev_server(
2470 request: proto::DeleteDevServer,
2471 response: Response<proto::DeleteDevServer>,
2472 session: UserSession,
2473) -> Result<()> {
2474 let dev_server_id = DevServerId(request.dev_server_id as i32);
2475 let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2476 if dev_server.user_id != session.user_id() {
2477 return Err(anyhow!(ErrorCode::Forbidden))?;
2478 }
2479
2480 let connection_id = session
2481 .connection_pool()
2482 .await
2483 .dev_server_connection_id(dev_server_id);
2484 if let Some(connection_id) = connection_id {
2485 shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2486 session.peer.send(
2487 connection_id,
2488 proto::ShutdownDevServer {
2489 reason: Some("dev server was deleted".to_string()),
2490 },
2491 )?;
2492 let _ = remove_dev_server_connection(dev_server_id, &session).await;
2493 }
2494
2495 let status = session
2496 .db()
2497 .await
2498 .delete_dev_server(dev_server_id, session.user_id())
2499 .await?;
2500
2501 send_dev_server_projects_update(session.user_id(), status, &session).await;
2502
2503 response.send(proto::Ack {})?;
2504 Ok(())
2505}
2506
2507async fn delete_dev_server_project(
2508 request: proto::DeleteDevServerProject,
2509 response: Response<proto::DeleteDevServerProject>,
2510 session: UserSession,
2511) -> Result<()> {
2512 let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2513 let dev_server_project = session
2514 .db()
2515 .await
2516 .get_dev_server_project(dev_server_project_id)
2517 .await?;
2518
2519 let dev_server = session
2520 .db()
2521 .await
2522 .get_dev_server(dev_server_project.dev_server_id)
2523 .await?;
2524 if dev_server.user_id != session.user_id() {
2525 return Err(anyhow!(ErrorCode::Forbidden))?;
2526 }
2527
2528 let dev_server_connection_id = session
2529 .connection_pool()
2530 .await
2531 .dev_server_connection_id(dev_server.id);
2532
2533 if let Some(dev_server_connection_id) = dev_server_connection_id {
2534 let project = session
2535 .db()
2536 .await
2537 .find_dev_server_project(dev_server_project_id)
2538 .await;
2539 if let Ok(project) = project {
2540 unshare_project_internal(
2541 project.id,
2542 dev_server_connection_id,
2543 Some(session.user_id()),
2544 &session,
2545 )
2546 .await?;
2547 }
2548 }
2549
2550 let (projects, status) = session
2551 .db()
2552 .await
2553 .delete_dev_server_project(dev_server_project_id, dev_server.id, session.user_id())
2554 .await?;
2555
2556 if let Some(dev_server_connection_id) = dev_server_connection_id {
2557 session.peer.send(
2558 dev_server_connection_id,
2559 proto::DevServerInstructions { projects },
2560 )?;
2561 }
2562
2563 send_dev_server_projects_update(session.user_id(), status, &session).await;
2564
2565 response.send(proto::Ack {})?;
2566 Ok(())
2567}
2568
2569async fn rejoin_dev_server_projects(
2570 request: proto::RejoinRemoteProjects,
2571 response: Response<proto::RejoinRemoteProjects>,
2572 session: UserSession,
2573) -> Result<()> {
2574 let mut rejoined_projects = {
2575 let db = session.db().await;
2576 db.rejoin_dev_server_projects(
2577 &request.rejoined_projects,
2578 session.user_id(),
2579 session.0.connection_id,
2580 )
2581 .await?
2582 };
2583 notify_rejoined_projects(&mut rejoined_projects, &session)?;
2584
2585 response.send(proto::RejoinRemoteProjectsResponse {
2586 rejoined_projects: rejoined_projects
2587 .into_iter()
2588 .map(|project| project.to_proto())
2589 .collect(),
2590 })
2591}
2592
2593async fn reconnect_dev_server(
2594 request: proto::ReconnectDevServer,
2595 response: Response<proto::ReconnectDevServer>,
2596 session: DevServerSession,
2597) -> Result<()> {
2598 let reshared_projects = {
2599 let db = session.db().await;
2600 db.reshare_dev_server_projects(
2601 &request.reshared_projects,
2602 session.dev_server_id(),
2603 session.0.connection_id,
2604 )
2605 .await?
2606 };
2607
2608 for project in &reshared_projects {
2609 for collaborator in &project.collaborators {
2610 session
2611 .peer
2612 .send(
2613 collaborator.connection_id,
2614 proto::UpdateProjectCollaborator {
2615 project_id: project.id.to_proto(),
2616 old_peer_id: Some(project.old_connection_id.into()),
2617 new_peer_id: Some(session.connection_id.into()),
2618 },
2619 )
2620 .trace_err();
2621 }
2622
2623 broadcast(
2624 Some(session.connection_id),
2625 project
2626 .collaborators
2627 .iter()
2628 .map(|collaborator| collaborator.connection_id),
2629 |connection_id| {
2630 session.peer.forward_send(
2631 session.connection_id,
2632 connection_id,
2633 proto::UpdateProject {
2634 project_id: project.id.to_proto(),
2635 worktrees: project.worktrees.clone(),
2636 },
2637 )
2638 },
2639 );
2640 }
2641
2642 response.send(proto::ReconnectDevServerResponse {
2643 reshared_projects: reshared_projects
2644 .iter()
2645 .map(|project| proto::ResharedProject {
2646 id: project.id.to_proto(),
2647 collaborators: project
2648 .collaborators
2649 .iter()
2650 .map(|collaborator| collaborator.to_proto())
2651 .collect(),
2652 })
2653 .collect(),
2654 })?;
2655
2656 Ok(())
2657}
2658
2659async fn shutdown_dev_server(
2660 _: proto::ShutdownDevServer,
2661 response: Response<proto::ShutdownDevServer>,
2662 session: DevServerSession,
2663) -> Result<()> {
2664 response.send(proto::Ack {})?;
2665 shutdown_dev_server_internal(session.dev_server_id(), session.connection_id, &session).await?;
2666 remove_dev_server_connection(session.dev_server_id(), &session).await
2667}
2668
2669async fn shutdown_dev_server_internal(
2670 dev_server_id: DevServerId,
2671 connection_id: ConnectionId,
2672 session: &Session,
2673) -> Result<()> {
2674 let (dev_server_projects, dev_server) = {
2675 let db = session.db().await;
2676 let dev_server_projects = db.get_projects_for_dev_server(dev_server_id).await?;
2677 let dev_server = db.get_dev_server(dev_server_id).await?;
2678 (dev_server_projects, dev_server)
2679 };
2680
2681 for project_id in dev_server_projects.iter().filter_map(|p| p.project_id) {
2682 unshare_project_internal(
2683 ProjectId::from_proto(project_id),
2684 connection_id,
2685 None,
2686 session,
2687 )
2688 .await?;
2689 }
2690
2691 session
2692 .connection_pool()
2693 .await
2694 .set_dev_server_offline(dev_server_id);
2695
2696 let status = session
2697 .db()
2698 .await
2699 .dev_server_projects_update(dev_server.user_id)
2700 .await?;
2701 send_dev_server_projects_update(dev_server.user_id, status, &session).await;
2702
2703 Ok(())
2704}
2705
2706async fn remove_dev_server_connection(dev_server_id: DevServerId, session: &Session) -> Result<()> {
2707 let dev_server_connection = session
2708 .connection_pool()
2709 .await
2710 .dev_server_connection_id(dev_server_id);
2711
2712 if let Some(dev_server_connection) = dev_server_connection {
2713 session
2714 .connection_pool()
2715 .await
2716 .remove_connection(dev_server_connection)?;
2717 }
2718 Ok(())
2719}
2720
2721/// Updates other participants with changes to the project
2722async fn update_project(
2723 request: proto::UpdateProject,
2724 response: Response<proto::UpdateProject>,
2725 session: Session,
2726) -> Result<()> {
2727 let project_id = ProjectId::from_proto(request.project_id);
2728 let (room, guest_connection_ids) = &*session
2729 .db()
2730 .await
2731 .update_project(project_id, session.connection_id, &request.worktrees)
2732 .await?;
2733 broadcast(
2734 Some(session.connection_id),
2735 guest_connection_ids.iter().copied(),
2736 |connection_id| {
2737 session
2738 .peer
2739 .forward_send(session.connection_id, connection_id, request.clone())
2740 },
2741 );
2742 if let Some(room) = room {
2743 room_updated(&room, &session.peer);
2744 }
2745 response.send(proto::Ack {})?;
2746
2747 Ok(())
2748}
2749
2750/// Updates other participants with changes to the worktree
2751async fn update_worktree(
2752 request: proto::UpdateWorktree,
2753 response: Response<proto::UpdateWorktree>,
2754 session: Session,
2755) -> Result<()> {
2756 let guest_connection_ids = session
2757 .db()
2758 .await
2759 .update_worktree(&request, session.connection_id)
2760 .await?;
2761
2762 broadcast(
2763 Some(session.connection_id),
2764 guest_connection_ids.iter().copied(),
2765 |connection_id| {
2766 session
2767 .peer
2768 .forward_send(session.connection_id, connection_id, request.clone())
2769 },
2770 );
2771 response.send(proto::Ack {})?;
2772 Ok(())
2773}
2774
2775/// Updates other participants with changes to the diagnostics
2776async fn update_diagnostic_summary(
2777 message: proto::UpdateDiagnosticSummary,
2778 session: Session,
2779) -> Result<()> {
2780 let guest_connection_ids = session
2781 .db()
2782 .await
2783 .update_diagnostic_summary(&message, session.connection_id)
2784 .await?;
2785
2786 broadcast(
2787 Some(session.connection_id),
2788 guest_connection_ids.iter().copied(),
2789 |connection_id| {
2790 session
2791 .peer
2792 .forward_send(session.connection_id, connection_id, message.clone())
2793 },
2794 );
2795
2796 Ok(())
2797}
2798
2799/// Updates other participants with changes to the worktree settings
2800async fn update_worktree_settings(
2801 message: proto::UpdateWorktreeSettings,
2802 session: Session,
2803) -> Result<()> {
2804 let guest_connection_ids = session
2805 .db()
2806 .await
2807 .update_worktree_settings(&message, session.connection_id)
2808 .await?;
2809
2810 broadcast(
2811 Some(session.connection_id),
2812 guest_connection_ids.iter().copied(),
2813 |connection_id| {
2814 session
2815 .peer
2816 .forward_send(session.connection_id, connection_id, message.clone())
2817 },
2818 );
2819
2820 Ok(())
2821}
2822
2823/// Notify other participants that a language server has started.
2824async fn start_language_server(
2825 request: proto::StartLanguageServer,
2826 session: Session,
2827) -> Result<()> {
2828 let guest_connection_ids = session
2829 .db()
2830 .await
2831 .start_language_server(&request, session.connection_id)
2832 .await?;
2833
2834 broadcast(
2835 Some(session.connection_id),
2836 guest_connection_ids.iter().copied(),
2837 |connection_id| {
2838 session
2839 .peer
2840 .forward_send(session.connection_id, connection_id, request.clone())
2841 },
2842 );
2843 Ok(())
2844}
2845
2846/// Notify other participants that a language server has changed.
2847async fn update_language_server(
2848 request: proto::UpdateLanguageServer,
2849 session: Session,
2850) -> Result<()> {
2851 let project_id = ProjectId::from_proto(request.project_id);
2852 let project_connection_ids = session
2853 .db()
2854 .await
2855 .project_connection_ids(project_id, session.connection_id, true)
2856 .await?;
2857 broadcast(
2858 Some(session.connection_id),
2859 project_connection_ids.iter().copied(),
2860 |connection_id| {
2861 session
2862 .peer
2863 .forward_send(session.connection_id, connection_id, request.clone())
2864 },
2865 );
2866 Ok(())
2867}
2868
2869/// forward a project request to the host. These requests should be read only
2870/// as guests are allowed to send them.
2871async fn forward_read_only_project_request<T>(
2872 request: T,
2873 response: Response<T>,
2874 session: UserSession,
2875) -> Result<()>
2876where
2877 T: EntityMessage + RequestMessage,
2878{
2879 let project_id = ProjectId::from_proto(request.remote_entity_id());
2880 let host_connection_id = session
2881 .db()
2882 .await
2883 .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2884 .await?;
2885 let payload = session
2886 .peer
2887 .forward_request(session.connection_id, host_connection_id, request)
2888 .await?;
2889 response.send(payload)?;
2890 Ok(())
2891}
2892
2893/// forward a project request to the dev server. Only allowed
2894/// if it's your dev server.
2895async fn forward_project_request_for_owner<T>(
2896 request: T,
2897 response: Response<T>,
2898 session: UserSession,
2899) -> Result<()>
2900where
2901 T: EntityMessage + RequestMessage,
2902{
2903 let project_id = ProjectId::from_proto(request.remote_entity_id());
2904
2905 let host_connection_id = session
2906 .db()
2907 .await
2908 .host_for_owner_project_request(project_id, session.connection_id, session.user_id())
2909 .await?;
2910 let payload = session
2911 .peer
2912 .forward_request(session.connection_id, host_connection_id, request)
2913 .await?;
2914 response.send(payload)?;
2915 Ok(())
2916}
2917
2918/// forward a project request to the host. These requests are disallowed
2919/// for guests.
2920async fn forward_mutating_project_request<T>(
2921 request: T,
2922 response: Response<T>,
2923 session: UserSession,
2924) -> Result<()>
2925where
2926 T: EntityMessage + RequestMessage,
2927{
2928 let project_id = ProjectId::from_proto(request.remote_entity_id());
2929
2930 let host_connection_id = session
2931 .db()
2932 .await
2933 .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
2934 .await?;
2935 let payload = session
2936 .peer
2937 .forward_request(session.connection_id, host_connection_id, request)
2938 .await?;
2939 response.send(payload)?;
2940 Ok(())
2941}
2942
2943/// forward a project request to the host. These requests are disallowed
2944/// for guests.
2945async fn forward_versioned_mutating_project_request<T>(
2946 request: T,
2947 response: Response<T>,
2948 session: UserSession,
2949) -> Result<()>
2950where
2951 T: EntityMessage + RequestMessage + VersionedMessage,
2952{
2953 let project_id = ProjectId::from_proto(request.remote_entity_id());
2954
2955 let host_connection_id = session
2956 .db()
2957 .await
2958 .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
2959 .await?;
2960 if let Some(host_version) = session
2961 .connection_pool()
2962 .await
2963 .connection(host_connection_id)
2964 .map(|c| c.zed_version)
2965 {
2966 if let Some(min_required_version) = request.required_host_version() {
2967 if min_required_version > host_version {
2968 return Err(anyhow!(ErrorCode::RemoteUpgradeRequired
2969 .with_tag("required", &min_required_version.to_string())))?;
2970 }
2971 }
2972 }
2973
2974 let payload = session
2975 .peer
2976 .forward_request(session.connection_id, host_connection_id, request)
2977 .await?;
2978 response.send(payload)?;
2979 Ok(())
2980}
2981
2982/// Notify other participants that a new buffer has been created
2983async fn create_buffer_for_peer(
2984 request: proto::CreateBufferForPeer,
2985 session: Session,
2986) -> Result<()> {
2987 session
2988 .db()
2989 .await
2990 .check_user_is_project_host(
2991 ProjectId::from_proto(request.project_id),
2992 session.connection_id,
2993 )
2994 .await?;
2995 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2996 session
2997 .peer
2998 .forward_send(session.connection_id, peer_id.into(), request)?;
2999 Ok(())
3000}
3001
3002/// Notify other participants that a buffer has been updated. This is
3003/// allowed for guests as long as the update is limited to selections.
3004async fn update_buffer(
3005 request: proto::UpdateBuffer,
3006 response: Response<proto::UpdateBuffer>,
3007 session: Session,
3008) -> Result<()> {
3009 let project_id = ProjectId::from_proto(request.project_id);
3010 let mut capability = Capability::ReadOnly;
3011
3012 for op in request.operations.iter() {
3013 match op.variant {
3014 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
3015 Some(_) => capability = Capability::ReadWrite,
3016 }
3017 }
3018
3019 let host = {
3020 let guard = session
3021 .db()
3022 .await
3023 .connections_for_buffer_update(
3024 project_id,
3025 session.principal_id(),
3026 session.connection_id,
3027 capability,
3028 )
3029 .await?;
3030
3031 let (host, guests) = &*guard;
3032
3033 broadcast(
3034 Some(session.connection_id),
3035 guests.clone(),
3036 |connection_id| {
3037 session
3038 .peer
3039 .forward_send(session.connection_id, connection_id, request.clone())
3040 },
3041 );
3042
3043 *host
3044 };
3045
3046 if host != session.connection_id {
3047 session
3048 .peer
3049 .forward_request(session.connection_id, host, request.clone())
3050 .await?;
3051 }
3052
3053 response.send(proto::Ack {})?;
3054 Ok(())
3055}
3056
3057/// Notify other participants that a project has been updated.
3058async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
3059 request: T,
3060 session: Session,
3061) -> Result<()> {
3062 let project_id = ProjectId::from_proto(request.remote_entity_id());
3063 let project_connection_ids = session
3064 .db()
3065 .await
3066 .project_connection_ids(project_id, session.connection_id, false)
3067 .await?;
3068
3069 broadcast(
3070 Some(session.connection_id),
3071 project_connection_ids.iter().copied(),
3072 |connection_id| {
3073 session
3074 .peer
3075 .forward_send(session.connection_id, connection_id, request.clone())
3076 },
3077 );
3078 Ok(())
3079}
3080
3081/// Start following another user in a call.
3082async fn follow(
3083 request: proto::Follow,
3084 response: Response<proto::Follow>,
3085 session: UserSession,
3086) -> Result<()> {
3087 let room_id = RoomId::from_proto(request.room_id);
3088 let project_id = request.project_id.map(ProjectId::from_proto);
3089 let leader_id = request
3090 .leader_id
3091 .ok_or_else(|| anyhow!("invalid leader id"))?
3092 .into();
3093 let follower_id = session.connection_id;
3094
3095 session
3096 .db()
3097 .await
3098 .check_room_participants(room_id, leader_id, session.connection_id)
3099 .await?;
3100
3101 let response_payload = session
3102 .peer
3103 .forward_request(session.connection_id, leader_id, request)
3104 .await?;
3105 response.send(response_payload)?;
3106
3107 if let Some(project_id) = project_id {
3108 let room = session
3109 .db()
3110 .await
3111 .follow(room_id, project_id, leader_id, follower_id)
3112 .await?;
3113 room_updated(&room, &session.peer);
3114 }
3115
3116 Ok(())
3117}
3118
3119/// Stop following another user in a call.
3120async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
3121 let room_id = RoomId::from_proto(request.room_id);
3122 let project_id = request.project_id.map(ProjectId::from_proto);
3123 let leader_id = request
3124 .leader_id
3125 .ok_or_else(|| anyhow!("invalid leader id"))?
3126 .into();
3127 let follower_id = session.connection_id;
3128
3129 session
3130 .db()
3131 .await
3132 .check_room_participants(room_id, leader_id, session.connection_id)
3133 .await?;
3134
3135 session
3136 .peer
3137 .forward_send(session.connection_id, leader_id, request)?;
3138
3139 if let Some(project_id) = project_id {
3140 let room = session
3141 .db()
3142 .await
3143 .unfollow(room_id, project_id, leader_id, follower_id)
3144 .await?;
3145 room_updated(&room, &session.peer);
3146 }
3147
3148 Ok(())
3149}
3150
3151/// Notify everyone following you of your current location.
3152async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
3153 let room_id = RoomId::from_proto(request.room_id);
3154 let database = session.db.lock().await;
3155
3156 let connection_ids = if let Some(project_id) = request.project_id {
3157 let project_id = ProjectId::from_proto(project_id);
3158 database
3159 .project_connection_ids(project_id, session.connection_id, true)
3160 .await?
3161 } else {
3162 database
3163 .room_connection_ids(room_id, session.connection_id)
3164 .await?
3165 };
3166
3167 // For now, don't send view update messages back to that view's current leader.
3168 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
3169 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
3170 _ => None,
3171 });
3172
3173 for connection_id in connection_ids.iter().cloned() {
3174 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
3175 session
3176 .peer
3177 .forward_send(session.connection_id, connection_id, request.clone())?;
3178 }
3179 }
3180 Ok(())
3181}
3182
3183/// Get public data about users.
3184async fn get_users(
3185 request: proto::GetUsers,
3186 response: Response<proto::GetUsers>,
3187 session: Session,
3188) -> Result<()> {
3189 let user_ids = request
3190 .user_ids
3191 .into_iter()
3192 .map(UserId::from_proto)
3193 .collect();
3194 let users = session
3195 .db()
3196 .await
3197 .get_users_by_ids(user_ids)
3198 .await?
3199 .into_iter()
3200 .map(|user| proto::User {
3201 id: user.id.to_proto(),
3202 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3203 github_login: user.github_login,
3204 })
3205 .collect();
3206 response.send(proto::UsersResponse { users })?;
3207 Ok(())
3208}
3209
3210/// Search for users (to invite) buy Github login
3211async fn fuzzy_search_users(
3212 request: proto::FuzzySearchUsers,
3213 response: Response<proto::FuzzySearchUsers>,
3214 session: UserSession,
3215) -> Result<()> {
3216 let query = request.query;
3217 let users = match query.len() {
3218 0 => vec![],
3219 1 | 2 => session
3220 .db()
3221 .await
3222 .get_user_by_github_login(&query)
3223 .await?
3224 .into_iter()
3225 .collect(),
3226 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
3227 };
3228 let users = users
3229 .into_iter()
3230 .filter(|user| user.id != session.user_id())
3231 .map(|user| proto::User {
3232 id: user.id.to_proto(),
3233 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3234 github_login: user.github_login,
3235 })
3236 .collect();
3237 response.send(proto::UsersResponse { users })?;
3238 Ok(())
3239}
3240
3241/// Send a contact request to another user.
3242async fn request_contact(
3243 request: proto::RequestContact,
3244 response: Response<proto::RequestContact>,
3245 session: UserSession,
3246) -> Result<()> {
3247 let requester_id = session.user_id();
3248 let responder_id = UserId::from_proto(request.responder_id);
3249 if requester_id == responder_id {
3250 return Err(anyhow!("cannot add yourself as a contact"))?;
3251 }
3252
3253 let notifications = session
3254 .db()
3255 .await
3256 .send_contact_request(requester_id, responder_id)
3257 .await?;
3258
3259 // Update outgoing contact requests of requester
3260 let mut update = proto::UpdateContacts::default();
3261 update.outgoing_requests.push(responder_id.to_proto());
3262 for connection_id in session
3263 .connection_pool()
3264 .await
3265 .user_connection_ids(requester_id)
3266 {
3267 session.peer.send(connection_id, update.clone())?;
3268 }
3269
3270 // Update incoming contact requests of responder
3271 let mut update = proto::UpdateContacts::default();
3272 update
3273 .incoming_requests
3274 .push(proto::IncomingContactRequest {
3275 requester_id: requester_id.to_proto(),
3276 });
3277 let connection_pool = session.connection_pool().await;
3278 for connection_id in connection_pool.user_connection_ids(responder_id) {
3279 session.peer.send(connection_id, update.clone())?;
3280 }
3281
3282 send_notifications(&connection_pool, &session.peer, notifications);
3283
3284 response.send(proto::Ack {})?;
3285 Ok(())
3286}
3287
3288/// Accept or decline a contact request
3289async fn respond_to_contact_request(
3290 request: proto::RespondToContactRequest,
3291 response: Response<proto::RespondToContactRequest>,
3292 session: UserSession,
3293) -> Result<()> {
3294 let responder_id = session.user_id();
3295 let requester_id = UserId::from_proto(request.requester_id);
3296 let db = session.db().await;
3297 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
3298 db.dismiss_contact_notification(responder_id, requester_id)
3299 .await?;
3300 } else {
3301 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
3302
3303 let notifications = db
3304 .respond_to_contact_request(responder_id, requester_id, accept)
3305 .await?;
3306 let requester_busy = db.is_user_busy(requester_id).await?;
3307 let responder_busy = db.is_user_busy(responder_id).await?;
3308
3309 let pool = session.connection_pool().await;
3310 // Update responder with new contact
3311 let mut update = proto::UpdateContacts::default();
3312 if accept {
3313 update
3314 .contacts
3315 .push(contact_for_user(requester_id, requester_busy, &pool));
3316 }
3317 update
3318 .remove_incoming_requests
3319 .push(requester_id.to_proto());
3320 for connection_id in pool.user_connection_ids(responder_id) {
3321 session.peer.send(connection_id, update.clone())?;
3322 }
3323
3324 // Update requester with new contact
3325 let mut update = proto::UpdateContacts::default();
3326 if accept {
3327 update
3328 .contacts
3329 .push(contact_for_user(responder_id, responder_busy, &pool));
3330 }
3331 update
3332 .remove_outgoing_requests
3333 .push(responder_id.to_proto());
3334
3335 for connection_id in pool.user_connection_ids(requester_id) {
3336 session.peer.send(connection_id, update.clone())?;
3337 }
3338
3339 send_notifications(&pool, &session.peer, notifications);
3340 }
3341
3342 response.send(proto::Ack {})?;
3343 Ok(())
3344}
3345
3346/// Remove a contact.
3347async fn remove_contact(
3348 request: proto::RemoveContact,
3349 response: Response<proto::RemoveContact>,
3350 session: UserSession,
3351) -> Result<()> {
3352 let requester_id = session.user_id();
3353 let responder_id = UserId::from_proto(request.user_id);
3354 let db = session.db().await;
3355 let (contact_accepted, deleted_notification_id) =
3356 db.remove_contact(requester_id, responder_id).await?;
3357
3358 let pool = session.connection_pool().await;
3359 // Update outgoing contact requests of requester
3360 let mut update = proto::UpdateContacts::default();
3361 if contact_accepted {
3362 update.remove_contacts.push(responder_id.to_proto());
3363 } else {
3364 update
3365 .remove_outgoing_requests
3366 .push(responder_id.to_proto());
3367 }
3368 for connection_id in pool.user_connection_ids(requester_id) {
3369 session.peer.send(connection_id, update.clone())?;
3370 }
3371
3372 // Update incoming contact requests of responder
3373 let mut update = proto::UpdateContacts::default();
3374 if contact_accepted {
3375 update.remove_contacts.push(requester_id.to_proto());
3376 } else {
3377 update
3378 .remove_incoming_requests
3379 .push(requester_id.to_proto());
3380 }
3381 for connection_id in pool.user_connection_ids(responder_id) {
3382 session.peer.send(connection_id, update.clone())?;
3383 if let Some(notification_id) = deleted_notification_id {
3384 session.peer.send(
3385 connection_id,
3386 proto::DeleteNotification {
3387 notification_id: notification_id.to_proto(),
3388 },
3389 )?;
3390 }
3391 }
3392
3393 response.send(proto::Ack {})?;
3394 Ok(())
3395}
3396
3397fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
3398 version.0.minor() < 139
3399}
3400
3401async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
3402 subscribe_user_to_channels(
3403 session.user_id().ok_or_else(|| anyhow!("must be a user"))?,
3404 &session,
3405 )
3406 .await?;
3407 Ok(())
3408}
3409
3410async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
3411 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3412 let mut pool = session.connection_pool().await;
3413 for membership in &channels_for_user.channel_memberships {
3414 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3415 }
3416 session.peer.send(
3417 session.connection_id,
3418 build_update_user_channels(&channels_for_user),
3419 )?;
3420 session.peer.send(
3421 session.connection_id,
3422 build_channels_update(channels_for_user),
3423 )?;
3424 Ok(())
3425}
3426
3427/// Creates a new channel.
3428async fn create_channel(
3429 request: proto::CreateChannel,
3430 response: Response<proto::CreateChannel>,
3431 session: UserSession,
3432) -> Result<()> {
3433 let db = session.db().await;
3434
3435 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
3436 let (channel, membership) = db
3437 .create_channel(&request.name, parent_id, session.user_id())
3438 .await?;
3439
3440 let root_id = channel.root_id();
3441 let channel = Channel::from_model(channel);
3442
3443 response.send(proto::CreateChannelResponse {
3444 channel: Some(channel.to_proto()),
3445 parent_id: request.parent_id,
3446 })?;
3447
3448 let mut connection_pool = session.connection_pool().await;
3449 if let Some(membership) = membership {
3450 connection_pool.subscribe_to_channel(
3451 membership.user_id,
3452 membership.channel_id,
3453 membership.role,
3454 );
3455 let update = proto::UpdateUserChannels {
3456 channel_memberships: vec![proto::ChannelMembership {
3457 channel_id: membership.channel_id.to_proto(),
3458 role: membership.role.into(),
3459 }],
3460 ..Default::default()
3461 };
3462 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3463 session.peer.send(connection_id, update.clone())?;
3464 }
3465 }
3466
3467 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3468 if !role.can_see_channel(channel.visibility) {
3469 continue;
3470 }
3471
3472 let update = proto::UpdateChannels {
3473 channels: vec![channel.to_proto()],
3474 ..Default::default()
3475 };
3476 session.peer.send(connection_id, update.clone())?;
3477 }
3478
3479 Ok(())
3480}
3481
3482/// Delete a channel
3483async fn delete_channel(
3484 request: proto::DeleteChannel,
3485 response: Response<proto::DeleteChannel>,
3486 session: UserSession,
3487) -> Result<()> {
3488 let db = session.db().await;
3489
3490 let channel_id = request.channel_id;
3491 let (root_channel, removed_channels) = db
3492 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3493 .await?;
3494 response.send(proto::Ack {})?;
3495
3496 // Notify members of removed channels
3497 let mut update = proto::UpdateChannels::default();
3498 update
3499 .delete_channels
3500 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3501
3502 let connection_pool = session.connection_pool().await;
3503 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3504 session.peer.send(connection_id, update.clone())?;
3505 }
3506
3507 Ok(())
3508}
3509
3510/// Invite someone to join a channel.
3511async fn invite_channel_member(
3512 request: proto::InviteChannelMember,
3513 response: Response<proto::InviteChannelMember>,
3514 session: UserSession,
3515) -> Result<()> {
3516 let db = session.db().await;
3517 let channel_id = ChannelId::from_proto(request.channel_id);
3518 let invitee_id = UserId::from_proto(request.user_id);
3519 let InviteMemberResult {
3520 channel,
3521 notifications,
3522 } = db
3523 .invite_channel_member(
3524 channel_id,
3525 invitee_id,
3526 session.user_id(),
3527 request.role().into(),
3528 )
3529 .await?;
3530
3531 let update = proto::UpdateChannels {
3532 channel_invitations: vec![channel.to_proto()],
3533 ..Default::default()
3534 };
3535
3536 let connection_pool = session.connection_pool().await;
3537 for connection_id in connection_pool.user_connection_ids(invitee_id) {
3538 session.peer.send(connection_id, update.clone())?;
3539 }
3540
3541 send_notifications(&connection_pool, &session.peer, notifications);
3542
3543 response.send(proto::Ack {})?;
3544 Ok(())
3545}
3546
3547/// remove someone from a channel
3548async fn remove_channel_member(
3549 request: proto::RemoveChannelMember,
3550 response: Response<proto::RemoveChannelMember>,
3551 session: UserSession,
3552) -> Result<()> {
3553 let db = session.db().await;
3554 let channel_id = ChannelId::from_proto(request.channel_id);
3555 let member_id = UserId::from_proto(request.user_id);
3556
3557 let RemoveChannelMemberResult {
3558 membership_update,
3559 notification_id,
3560 } = db
3561 .remove_channel_member(channel_id, member_id, session.user_id())
3562 .await?;
3563
3564 let mut connection_pool = session.connection_pool().await;
3565 notify_membership_updated(
3566 &mut connection_pool,
3567 membership_update,
3568 member_id,
3569 &session.peer,
3570 );
3571 for connection_id in connection_pool.user_connection_ids(member_id) {
3572 if let Some(notification_id) = notification_id {
3573 session
3574 .peer
3575 .send(
3576 connection_id,
3577 proto::DeleteNotification {
3578 notification_id: notification_id.to_proto(),
3579 },
3580 )
3581 .trace_err();
3582 }
3583 }
3584
3585 response.send(proto::Ack {})?;
3586 Ok(())
3587}
3588
3589/// Toggle the channel between public and private.
3590/// Care is taken to maintain the invariant that public channels only descend from public channels,
3591/// (though members-only channels can appear at any point in the hierarchy).
3592async fn set_channel_visibility(
3593 request: proto::SetChannelVisibility,
3594 response: Response<proto::SetChannelVisibility>,
3595 session: UserSession,
3596) -> Result<()> {
3597 let db = session.db().await;
3598 let channel_id = ChannelId::from_proto(request.channel_id);
3599 let visibility = request.visibility().into();
3600
3601 let channel_model = db
3602 .set_channel_visibility(channel_id, visibility, session.user_id())
3603 .await?;
3604 let root_id = channel_model.root_id();
3605 let channel = Channel::from_model(channel_model);
3606
3607 let mut connection_pool = session.connection_pool().await;
3608 for (user_id, role) in connection_pool
3609 .channel_user_ids(root_id)
3610 .collect::<Vec<_>>()
3611 .into_iter()
3612 {
3613 let update = if role.can_see_channel(channel.visibility) {
3614 connection_pool.subscribe_to_channel(user_id, channel_id, role);
3615 proto::UpdateChannels {
3616 channels: vec![channel.to_proto()],
3617 ..Default::default()
3618 }
3619 } else {
3620 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3621 proto::UpdateChannels {
3622 delete_channels: vec![channel.id.to_proto()],
3623 ..Default::default()
3624 }
3625 };
3626
3627 for connection_id in connection_pool.user_connection_ids(user_id) {
3628 session.peer.send(connection_id, update.clone())?;
3629 }
3630 }
3631
3632 response.send(proto::Ack {})?;
3633 Ok(())
3634}
3635
3636/// Alter the role for a user in the channel.
3637async fn set_channel_member_role(
3638 request: proto::SetChannelMemberRole,
3639 response: Response<proto::SetChannelMemberRole>,
3640 session: UserSession,
3641) -> Result<()> {
3642 let db = session.db().await;
3643 let channel_id = ChannelId::from_proto(request.channel_id);
3644 let member_id = UserId::from_proto(request.user_id);
3645 let result = db
3646 .set_channel_member_role(
3647 channel_id,
3648 session.user_id(),
3649 member_id,
3650 request.role().into(),
3651 )
3652 .await?;
3653
3654 match result {
3655 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3656 let mut connection_pool = session.connection_pool().await;
3657 notify_membership_updated(
3658 &mut connection_pool,
3659 membership_update,
3660 member_id,
3661 &session.peer,
3662 )
3663 }
3664 db::SetMemberRoleResult::InviteUpdated(channel) => {
3665 let update = proto::UpdateChannels {
3666 channel_invitations: vec![channel.to_proto()],
3667 ..Default::default()
3668 };
3669
3670 for connection_id in session
3671 .connection_pool()
3672 .await
3673 .user_connection_ids(member_id)
3674 {
3675 session.peer.send(connection_id, update.clone())?;
3676 }
3677 }
3678 }
3679
3680 response.send(proto::Ack {})?;
3681 Ok(())
3682}
3683
3684/// Change the name of a channel
3685async fn rename_channel(
3686 request: proto::RenameChannel,
3687 response: Response<proto::RenameChannel>,
3688 session: UserSession,
3689) -> Result<()> {
3690 let db = session.db().await;
3691 let channel_id = ChannelId::from_proto(request.channel_id);
3692 let channel_model = db
3693 .rename_channel(channel_id, session.user_id(), &request.name)
3694 .await?;
3695 let root_id = channel_model.root_id();
3696 let channel = Channel::from_model(channel_model);
3697
3698 response.send(proto::RenameChannelResponse {
3699 channel: Some(channel.to_proto()),
3700 })?;
3701
3702 let connection_pool = session.connection_pool().await;
3703 let update = proto::UpdateChannels {
3704 channels: vec![channel.to_proto()],
3705 ..Default::default()
3706 };
3707 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3708 if role.can_see_channel(channel.visibility) {
3709 session.peer.send(connection_id, update.clone())?;
3710 }
3711 }
3712
3713 Ok(())
3714}
3715
3716/// Move a channel to a new parent.
3717async fn move_channel(
3718 request: proto::MoveChannel,
3719 response: Response<proto::MoveChannel>,
3720 session: UserSession,
3721) -> Result<()> {
3722 let channel_id = ChannelId::from_proto(request.channel_id);
3723 let to = ChannelId::from_proto(request.to);
3724
3725 let (root_id, channels) = session
3726 .db()
3727 .await
3728 .move_channel(channel_id, to, session.user_id())
3729 .await?;
3730
3731 let connection_pool = session.connection_pool().await;
3732 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3733 let channels = channels
3734 .iter()
3735 .filter_map(|channel| {
3736 if role.can_see_channel(channel.visibility) {
3737 Some(channel.to_proto())
3738 } else {
3739 None
3740 }
3741 })
3742 .collect::<Vec<_>>();
3743 if channels.is_empty() {
3744 continue;
3745 }
3746
3747 let update = proto::UpdateChannels {
3748 channels,
3749 ..Default::default()
3750 };
3751
3752 session.peer.send(connection_id, update.clone())?;
3753 }
3754
3755 response.send(Ack {})?;
3756 Ok(())
3757}
3758
3759/// Get the list of channel members
3760async fn get_channel_members(
3761 request: proto::GetChannelMembers,
3762 response: Response<proto::GetChannelMembers>,
3763 session: UserSession,
3764) -> Result<()> {
3765 let db = session.db().await;
3766 let channel_id = ChannelId::from_proto(request.channel_id);
3767 let limit = if request.limit == 0 {
3768 u16::MAX as u64
3769 } else {
3770 request.limit
3771 };
3772 let (members, users) = db
3773 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3774 .await?;
3775 response.send(proto::GetChannelMembersResponse { members, users })?;
3776 Ok(())
3777}
3778
3779/// Accept or decline a channel invitation.
3780async fn respond_to_channel_invite(
3781 request: proto::RespondToChannelInvite,
3782 response: Response<proto::RespondToChannelInvite>,
3783 session: UserSession,
3784) -> Result<()> {
3785 let db = session.db().await;
3786 let channel_id = ChannelId::from_proto(request.channel_id);
3787 let RespondToChannelInvite {
3788 membership_update,
3789 notifications,
3790 } = db
3791 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3792 .await?;
3793
3794 let mut connection_pool = session.connection_pool().await;
3795 if let Some(membership_update) = membership_update {
3796 notify_membership_updated(
3797 &mut connection_pool,
3798 membership_update,
3799 session.user_id(),
3800 &session.peer,
3801 );
3802 } else {
3803 let update = proto::UpdateChannels {
3804 remove_channel_invitations: vec![channel_id.to_proto()],
3805 ..Default::default()
3806 };
3807
3808 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3809 session.peer.send(connection_id, update.clone())?;
3810 }
3811 };
3812
3813 send_notifications(&connection_pool, &session.peer, notifications);
3814
3815 response.send(proto::Ack {})?;
3816
3817 Ok(())
3818}
3819
3820/// Join the channels' room
3821async fn join_channel(
3822 request: proto::JoinChannel,
3823 response: Response<proto::JoinChannel>,
3824 session: UserSession,
3825) -> Result<()> {
3826 let channel_id = ChannelId::from_proto(request.channel_id);
3827 join_channel_internal(channel_id, Box::new(response), session).await
3828}
3829
/// A response handle that `join_channel_internal` can reply through with a
/// `proto::JoinRoomResponse`.
///
/// Implemented for both `Response<proto::JoinChannel>` and
/// `Response<proto::JoinRoom>`, so both request types can share the same join
/// logic.
trait JoinChannelInternalResponse {
    /// Send `result` as the reply to the originating request.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
/// `JoinChannel` requests are answered with a `JoinRoomResponse`.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Fully qualified so this resolves to `Response`'s own `send` rather
        // than recursing into this trait method.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
/// `JoinRoom` requests are answered with a `JoinRoomResponse`. Presumably
/// this is used when joining a room that belongs to a channel; the caller is
/// not visible here.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Fully qualified so this resolves to `Response`'s own `send` rather
        // than recursing into this trait method.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3843
/// Join the room of `channel_id` on behalf of the session's user and reply
/// through `response`.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, leave
///    that room first.
/// 2. Join the channel's room in the database, which also yields the user's
///    role and any membership change.
/// 3. If a LiveKit client is configured, mint a token. Guests get a
///    `guest_token` with `can_publish: false`; everyone else gets a
///    `room_token` with `can_publish: true`. If minting fails, `trace_err()?`
///    makes the closure return `None`, so the reply carries no LiveKit info
///    instead of failing the join.
/// 4. Reply with the joined room, propagate any membership change, and send a
///    room update.
/// 5. Send a channel update and refresh the user's contacts.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: UserSession,
) -> Result<()> {
    // The database and connection-pool guards are scoped to this block, so
    // both are released before `channel_updated` re-acquires the pool below.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database guard while leaving the stale room, then
            // re-acquire it. NOTE(review): presumably needed because
            // `leave_room_for_session` acquires the database itself; confirm.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // Guests may join but not publish; all other roles may publish.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // A channel join is expected to return its channel; a missing channel is
    // reported as an error (after the reply has already been sent).
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3934
3935/// Start editing the channel notes
3936async fn join_channel_buffer(
3937 request: proto::JoinChannelBuffer,
3938 response: Response<proto::JoinChannelBuffer>,
3939 session: UserSession,
3940) -> Result<()> {
3941 let db = session.db().await;
3942 let channel_id = ChannelId::from_proto(request.channel_id);
3943
3944 let open_response = db
3945 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3946 .await?;
3947
3948 let collaborators = open_response.collaborators.clone();
3949 response.send(open_response)?;
3950
3951 let update = UpdateChannelBufferCollaborators {
3952 channel_id: channel_id.to_proto(),
3953 collaborators: collaborators.clone(),
3954 };
3955 channel_buffer_updated(
3956 session.connection_id,
3957 collaborators
3958 .iter()
3959 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3960 &update,
3961 &session.peer,
3962 );
3963
3964 Ok(())
3965}
3966
3967/// Edit the channel notes
3968async fn update_channel_buffer(
3969 request: proto::UpdateChannelBuffer,
3970 session: UserSession,
3971) -> Result<()> {
3972 let db = session.db().await;
3973 let channel_id = ChannelId::from_proto(request.channel_id);
3974
3975 let (collaborators, epoch, version) = db
3976 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3977 .await?;
3978
3979 channel_buffer_updated(
3980 session.connection_id,
3981 collaborators.clone(),
3982 &proto::UpdateChannelBuffer {
3983 channel_id: channel_id.to_proto(),
3984 operations: request.operations,
3985 },
3986 &session.peer,
3987 );
3988
3989 let pool = &*session.connection_pool().await;
3990
3991 let non_collaborators =
3992 pool.channel_connection_ids(channel_id)
3993 .filter_map(|(connection_id, _)| {
3994 if collaborators.contains(&connection_id) {
3995 None
3996 } else {
3997 Some(connection_id)
3998 }
3999 });
4000
4001 broadcast(None, non_collaborators, |peer_id| {
4002 session.peer.send(
4003 peer_id,
4004 proto::UpdateChannels {
4005 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
4006 channel_id: channel_id.to_proto(),
4007 epoch: epoch as u64,
4008 version: version.clone(),
4009 }],
4010 ..Default::default()
4011 },
4012 )
4013 });
4014
4015 Ok(())
4016}
4017
4018/// Rejoin the channel notes after a connection blip
4019async fn rejoin_channel_buffers(
4020 request: proto::RejoinChannelBuffers,
4021 response: Response<proto::RejoinChannelBuffers>,
4022 session: UserSession,
4023) -> Result<()> {
4024 let db = session.db().await;
4025 let buffers = db
4026 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
4027 .await?;
4028
4029 for rejoined_buffer in &buffers {
4030 let collaborators_to_notify = rejoined_buffer
4031 .buffer
4032 .collaborators
4033 .iter()
4034 .filter_map(|c| Some(c.peer_id?.into()));
4035 channel_buffer_updated(
4036 session.connection_id,
4037 collaborators_to_notify,
4038 &proto::UpdateChannelBufferCollaborators {
4039 channel_id: rejoined_buffer.buffer.channel_id,
4040 collaborators: rejoined_buffer.buffer.collaborators.clone(),
4041 },
4042 &session.peer,
4043 );
4044 }
4045
4046 response.send(proto::RejoinChannelBuffersResponse {
4047 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
4048 })?;
4049
4050 Ok(())
4051}
4052
4053/// Stop editing the channel notes
4054async fn leave_channel_buffer(
4055 request: proto::LeaveChannelBuffer,
4056 response: Response<proto::LeaveChannelBuffer>,
4057 session: UserSession,
4058) -> Result<()> {
4059 let db = session.db().await;
4060 let channel_id = ChannelId::from_proto(request.channel_id);
4061
4062 let left_buffer = db
4063 .leave_channel_buffer(channel_id, session.connection_id)
4064 .await?;
4065
4066 response.send(Ack {})?;
4067
4068 channel_buffer_updated(
4069 session.connection_id,
4070 left_buffer.connections,
4071 &proto::UpdateChannelBufferCollaborators {
4072 channel_id: channel_id.to_proto(),
4073 collaborators: left_buffer.collaborators,
4074 },
4075 &session.peer,
4076 );
4077
4078 Ok(())
4079}
4080
4081fn channel_buffer_updated<T: EnvelopedMessage>(
4082 sender_id: ConnectionId,
4083 collaborators: impl IntoIterator<Item = ConnectionId>,
4084 message: &T,
4085 peer: &Peer,
4086) {
4087 broadcast(Some(sender_id), collaborators, |peer_id| {
4088 peer.send(peer_id, message.clone())
4089 });
4090}
4091
4092fn send_notifications(
4093 connection_pool: &ConnectionPool,
4094 peer: &Peer,
4095 notifications: db::NotificationBatch,
4096) {
4097 for (user_id, notification) in notifications {
4098 for connection_id in connection_pool.user_connection_ids(user_id) {
4099 if let Err(error) = peer.send(
4100 connection_id,
4101 proto::AddNotification {
4102 notification: Some(notification.clone()),
4103 },
4104 ) {
4105 tracing::error!(
4106 "failed to send notification to {:?} {}",
4107 connection_id,
4108 error
4109 );
4110 }
4111 }
4112 }
4113}
4114
4115/// Send a message to the channel
4116async fn send_channel_message(
4117 request: proto::SendChannelMessage,
4118 response: Response<proto::SendChannelMessage>,
4119 session: UserSession,
4120) -> Result<()> {
4121 // Validate the message body.
4122 let body = request.body.trim().to_string();
4123 if body.len() > MAX_MESSAGE_LEN {
4124 return Err(anyhow!("message is too long"))?;
4125 }
4126 if body.is_empty() {
4127 return Err(anyhow!("message can't be blank"))?;
4128 }
4129
4130 // TODO: adjust mentions if body is trimmed
4131
4132 let timestamp = OffsetDateTime::now_utc();
4133 let nonce = request
4134 .nonce
4135 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4136
4137 let channel_id = ChannelId::from_proto(request.channel_id);
4138 let CreatedChannelMessage {
4139 message_id,
4140 participant_connection_ids,
4141 notifications,
4142 } = session
4143 .db()
4144 .await
4145 .create_channel_message(
4146 channel_id,
4147 session.user_id(),
4148 &body,
4149 &request.mentions,
4150 timestamp,
4151 nonce.clone().into(),
4152 match request.reply_to_message_id {
4153 Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
4154 None => None,
4155 },
4156 )
4157 .await?;
4158
4159 let message = proto::ChannelMessage {
4160 sender_id: session.user_id().to_proto(),
4161 id: message_id.to_proto(),
4162 body,
4163 mentions: request.mentions,
4164 timestamp: timestamp.unix_timestamp() as u64,
4165 nonce: Some(nonce),
4166 reply_to_message_id: request.reply_to_message_id,
4167 edited_at: None,
4168 };
4169 broadcast(
4170 Some(session.connection_id),
4171 participant_connection_ids.clone(),
4172 |connection| {
4173 session.peer.send(
4174 connection,
4175 proto::ChannelMessageSent {
4176 channel_id: channel_id.to_proto(),
4177 message: Some(message.clone()),
4178 },
4179 )
4180 },
4181 );
4182 response.send(proto::SendChannelMessageResponse {
4183 message: Some(message),
4184 })?;
4185
4186 let pool = &*session.connection_pool().await;
4187 let non_participants =
4188 pool.channel_connection_ids(channel_id)
4189 .filter_map(|(connection_id, _)| {
4190 if participant_connection_ids.contains(&connection_id) {
4191 None
4192 } else {
4193 Some(connection_id)
4194 }
4195 });
4196 broadcast(None, non_participants, |peer_id| {
4197 session.peer.send(
4198 peer_id,
4199 proto::UpdateChannels {
4200 latest_channel_message_ids: vec![proto::ChannelMessageId {
4201 channel_id: channel_id.to_proto(),
4202 message_id: message_id.to_proto(),
4203 }],
4204 ..Default::default()
4205 },
4206 )
4207 });
4208 send_notifications(pool, &session.peer, notifications);
4209
4210 Ok(())
4211}
4212
4213/// Delete a channel message
4214async fn remove_channel_message(
4215 request: proto::RemoveChannelMessage,
4216 response: Response<proto::RemoveChannelMessage>,
4217 session: UserSession,
4218) -> Result<()> {
4219 let channel_id = ChannelId::from_proto(request.channel_id);
4220 let message_id = MessageId::from_proto(request.message_id);
4221 let (connection_ids, existing_notification_ids) = session
4222 .db()
4223 .await
4224 .remove_channel_message(channel_id, message_id, session.user_id())
4225 .await?;
4226
4227 broadcast(
4228 Some(session.connection_id),
4229 connection_ids,
4230 move |connection| {
4231 session.peer.send(connection, request.clone())?;
4232
4233 for notification_id in &existing_notification_ids {
4234 session.peer.send(
4235 connection,
4236 proto::DeleteNotification {
4237 notification_id: (*notification_id).to_proto(),
4238 },
4239 )?;
4240 }
4241
4242 Ok(())
4243 },
4244 );
4245 response.send(proto::Ack {})?;
4246 Ok(())
4247}
4248
4249async fn update_channel_message(
4250 request: proto::UpdateChannelMessage,
4251 response: Response<proto::UpdateChannelMessage>,
4252 session: UserSession,
4253) -> Result<()> {
4254 let channel_id = ChannelId::from_proto(request.channel_id);
4255 let message_id = MessageId::from_proto(request.message_id);
4256 let updated_at = OffsetDateTime::now_utc();
4257 let UpdatedChannelMessage {
4258 message_id,
4259 participant_connection_ids,
4260 notifications,
4261 reply_to_message_id,
4262 timestamp,
4263 deleted_mention_notification_ids,
4264 updated_mention_notifications,
4265 } = session
4266 .db()
4267 .await
4268 .update_channel_message(
4269 channel_id,
4270 message_id,
4271 session.user_id(),
4272 request.body.as_str(),
4273 &request.mentions,
4274 updated_at,
4275 )
4276 .await?;
4277
4278 let nonce = request
4279 .nonce
4280 .clone()
4281 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4282
4283 let message = proto::ChannelMessage {
4284 sender_id: session.user_id().to_proto(),
4285 id: message_id.to_proto(),
4286 body: request.body.clone(),
4287 mentions: request.mentions.clone(),
4288 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
4289 nonce: Some(nonce),
4290 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
4291 edited_at: Some(updated_at.unix_timestamp() as u64),
4292 };
4293
4294 response.send(proto::Ack {})?;
4295
4296 let pool = &*session.connection_pool().await;
4297 broadcast(
4298 Some(session.connection_id),
4299 participant_connection_ids,
4300 |connection| {
4301 session.peer.send(
4302 connection,
4303 proto::ChannelMessageUpdate {
4304 channel_id: channel_id.to_proto(),
4305 message: Some(message.clone()),
4306 },
4307 )?;
4308
4309 for notification_id in &deleted_mention_notification_ids {
4310 session.peer.send(
4311 connection,
4312 proto::DeleteNotification {
4313 notification_id: (*notification_id).to_proto(),
4314 },
4315 )?;
4316 }
4317
4318 for notification in &updated_mention_notifications {
4319 session.peer.send(
4320 connection,
4321 proto::UpdateNotification {
4322 notification: Some(notification.clone()),
4323 },
4324 )?;
4325 }
4326
4327 Ok(())
4328 },
4329 );
4330
4331 send_notifications(pool, &session.peer, notifications);
4332
4333 Ok(())
4334}
4335
4336/// Mark a channel message as read
4337async fn acknowledge_channel_message(
4338 request: proto::AckChannelMessage,
4339 session: UserSession,
4340) -> Result<()> {
4341 let channel_id = ChannelId::from_proto(request.channel_id);
4342 let message_id = MessageId::from_proto(request.message_id);
4343 let notifications = session
4344 .db()
4345 .await
4346 .observe_channel_message(channel_id, session.user_id(), message_id)
4347 .await?;
4348 send_notifications(
4349 &*session.connection_pool().await,
4350 &session.peer,
4351 notifications,
4352 );
4353 Ok(())
4354}
4355
4356/// Mark a buffer version as synced
4357async fn acknowledge_buffer_version(
4358 request: proto::AckBufferOperation,
4359 session: UserSession,
4360) -> Result<()> {
4361 let buffer_id = BufferId::from_proto(request.buffer_id);
4362 session
4363 .db()
4364 .await
4365 .observe_buffer_version(
4366 buffer_id,
4367 session.user_id(),
4368 request.epoch as i32,
4369 &request.version,
4370 )
4371 .await?;
4372 Ok(())
4373}
4374
4375struct CompleteWithLanguageModelRateLimit;
4376
4377impl RateLimit for CompleteWithLanguageModelRateLimit {
4378 fn capacity() -> usize {
4379 std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4380 .ok()
4381 .and_then(|v| v.parse().ok())
4382 .unwrap_or(120) // Picked arbitrarily
4383 }
4384
4385 fn refill_duration() -> chrono::Duration {
4386 chrono::Duration::hours(1)
4387 }
4388
4389 fn db_name() -> &'static str {
4390 "complete-with-language-model"
4391 }
4392}
4393
4394async fn complete_with_language_model(
4395 request: proto::CompleteWithLanguageModel,
4396 response: StreamingResponse<proto::CompleteWithLanguageModel>,
4397 session: Session,
4398 open_ai_api_key: Option<Arc<str>>,
4399 google_ai_api_key: Option<Arc<str>>,
4400 anthropic_api_key: Option<Arc<str>>,
4401) -> Result<()> {
4402 let Some(session) = session.for_user() else {
4403 return Err(anyhow!("user not found"))?;
4404 };
4405 authorize_access_to_language_models(&session).await?;
4406 session
4407 .rate_limiter
4408 .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
4409 .await?;
4410
4411 if request.model.starts_with("gpt") {
4412 let api_key =
4413 open_ai_api_key.ok_or_else(|| anyhow!("no OpenAI API key configured on the server"))?;
4414 complete_with_open_ai(request, response, session, api_key).await?;
4415 } else if request.model.starts_with("gemini") {
4416 let api_key = google_ai_api_key
4417 .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
4418 complete_with_google_ai(request, response, session, api_key).await?;
4419 } else if request.model.starts_with("claude") {
4420 let api_key = anthropic_api_key
4421 .ok_or_else(|| anyhow!("no Anthropic AI API key configured on the server"))?;
4422 complete_with_anthropic(request, response, session, api_key).await?;
4423 }
4424
4425 Ok(())
4426}
4427
4428async fn complete_with_open_ai(
4429 request: proto::CompleteWithLanguageModel,
4430 response: StreamingResponse<proto::CompleteWithLanguageModel>,
4431 session: UserSession,
4432 api_key: Arc<str>,
4433) -> Result<()> {
4434 let mut completion_stream = open_ai::stream_completion(
4435 session.http_client.as_ref(),
4436 OPEN_AI_API_URL,
4437 &api_key,
4438 crate::ai::language_model_request_to_open_ai(request)?,
4439 None,
4440 )
4441 .await
4442 .context("open_ai::stream_completion request failed within collab")?;
4443
4444 while let Some(event) = completion_stream.next().await {
4445 let event = event?;
4446 response.send(proto::LanguageModelResponse {
4447 choices: event
4448 .choices
4449 .into_iter()
4450 .map(|choice| proto::LanguageModelChoiceDelta {
4451 index: choice.index,
4452 delta: Some(proto::LanguageModelResponseMessage {
4453 role: choice.delta.role.map(|role| match role {
4454 open_ai::Role::User => LanguageModelRole::LanguageModelUser,
4455 open_ai::Role::Assistant => LanguageModelRole::LanguageModelAssistant,
4456 open_ai::Role::System => LanguageModelRole::LanguageModelSystem,
4457 open_ai::Role::Tool => LanguageModelRole::LanguageModelTool,
4458 } as i32),
4459 content: choice.delta.content,
4460 tool_calls: choice
4461 .delta
4462 .tool_calls
4463 .into_iter()
4464 .map(|delta| proto::ToolCallDelta {
4465 index: delta.index as u32,
4466 id: delta.id,
4467 variant: match delta.function {
4468 Some(function) => {
4469 let name = function.name;
4470 let arguments = function.arguments;
4471
4472 Some(proto::tool_call_delta::Variant::Function(
4473 proto::tool_call_delta::FunctionCallDelta {
4474 name,
4475 arguments,
4476 },
4477 ))
4478 }
4479 None => None,
4480 },
4481 })
4482 .collect(),
4483 }),
4484 finish_reason: choice.finish_reason,
4485 })
4486 .collect(),
4487 })?;
4488 }
4489
4490 Ok(())
4491}
4492
4493async fn complete_with_google_ai(
4494 request: proto::CompleteWithLanguageModel,
4495 response: StreamingResponse<proto::CompleteWithLanguageModel>,
4496 session: UserSession,
4497 api_key: Arc<str>,
4498) -> Result<()> {
4499 let mut stream = google_ai::stream_generate_content(
4500 session.http_client.clone(),
4501 google_ai::API_URL,
4502 api_key.as_ref(),
4503 crate::ai::language_model_request_to_google_ai(request)?,
4504 )
4505 .await
4506 .context("google_ai::stream_generate_content request failed")?;
4507
4508 while let Some(event) = stream.next().await {
4509 let event = event?;
4510 response.send(proto::LanguageModelResponse {
4511 choices: event
4512 .candidates
4513 .unwrap_or_default()
4514 .into_iter()
4515 .map(|candidate| proto::LanguageModelChoiceDelta {
4516 index: candidate.index as u32,
4517 delta: Some(proto::LanguageModelResponseMessage {
4518 role: Some(match candidate.content.role {
4519 google_ai::Role::User => LanguageModelRole::LanguageModelUser,
4520 google_ai::Role::Model => LanguageModelRole::LanguageModelAssistant,
4521 } as i32),
4522 content: Some(
4523 candidate
4524 .content
4525 .parts
4526 .into_iter()
4527 .filter_map(|part| match part {
4528 google_ai::Part::TextPart(part) => Some(part.text),
4529 google_ai::Part::InlineDataPart(_) => None,
4530 })
4531 .collect(),
4532 ),
4533 // Tool calls are not supported for Google
4534 tool_calls: Vec::new(),
4535 }),
4536 finish_reason: candidate.finish_reason.map(|reason| reason.to_string()),
4537 })
4538 .collect(),
4539 })?;
4540 }
4541
4542 Ok(())
4543}
4544
4545async fn complete_with_anthropic(
4546 request: proto::CompleteWithLanguageModel,
4547 response: StreamingResponse<proto::CompleteWithLanguageModel>,
4548 session: UserSession,
4549 api_key: Arc<str>,
4550) -> Result<()> {
4551 let model = anthropic::Model::from_id(&request.model)?;
4552
4553 let mut system_message = String::new();
4554 let messages = request
4555 .messages
4556 .into_iter()
4557 .filter_map(|message| {
4558 match message.role() {
4559 LanguageModelRole::LanguageModelUser => Some(anthropic::RequestMessage {
4560 role: anthropic::Role::User,
4561 content: message.content,
4562 }),
4563 LanguageModelRole::LanguageModelAssistant => Some(anthropic::RequestMessage {
4564 role: anthropic::Role::Assistant,
4565 content: message.content,
4566 }),
4567 // Anthropic's API breaks system instructions out as a separate field rather
4568 // than having a system message role.
4569 LanguageModelRole::LanguageModelSystem => {
4570 if !system_message.is_empty() {
4571 system_message.push_str("\n\n");
4572 }
4573 system_message.push_str(&message.content);
4574
4575 None
4576 }
4577 // We don't yet support tool calls for Anthropic
4578 LanguageModelRole::LanguageModelTool => None,
4579 }
4580 })
4581 .collect();
4582
4583 let mut stream = anthropic::stream_completion(
4584 session.http_client.as_ref(),
4585 anthropic::ANTHROPIC_API_URL,
4586 &api_key,
4587 anthropic::Request {
4588 model,
4589 messages,
4590 stream: true,
4591 system: system_message,
4592 max_tokens: 4092,
4593 },
4594 None,
4595 )
4596 .await?;
4597
4598 let mut current_role = proto::LanguageModelRole::LanguageModelAssistant;
4599
4600 while let Some(event) = stream.next().await {
4601 let event = event?;
4602
4603 match event {
4604 anthropic::ResponseEvent::MessageStart { message } => {
4605 if let Some(role) = message.role {
4606 if role == "assistant" {
4607 current_role = proto::LanguageModelRole::LanguageModelAssistant;
4608 } else if role == "user" {
4609 current_role = proto::LanguageModelRole::LanguageModelUser;
4610 }
4611 }
4612 }
4613 anthropic::ResponseEvent::ContentBlockStart { content_block, .. } => {
4614 match content_block {
4615 anthropic::ContentBlock::Text { text } => {
4616 if !text.is_empty() {
4617 response.send(proto::LanguageModelResponse {
4618 choices: vec![proto::LanguageModelChoiceDelta {
4619 index: 0,
4620 delta: Some(proto::LanguageModelResponseMessage {
4621 role: Some(current_role as i32),
4622 content: Some(text),
4623 tool_calls: Vec::new(),
4624 }),
4625 finish_reason: None,
4626 }],
4627 })?;
4628 }
4629 }
4630 }
4631 }
4632 anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => match delta {
4633 anthropic::TextDelta::TextDelta { text } => {
4634 response.send(proto::LanguageModelResponse {
4635 choices: vec![proto::LanguageModelChoiceDelta {
4636 index: 0,
4637 delta: Some(proto::LanguageModelResponseMessage {
4638 role: Some(current_role as i32),
4639 content: Some(text),
4640 tool_calls: Vec::new(),
4641 }),
4642 finish_reason: None,
4643 }],
4644 })?;
4645 }
4646 },
4647 anthropic::ResponseEvent::MessageDelta { delta, .. } => {
4648 if let Some(stop_reason) = delta.stop_reason {
4649 response.send(proto::LanguageModelResponse {
4650 choices: vec![proto::LanguageModelChoiceDelta {
4651 index: 0,
4652 delta: None,
4653 finish_reason: Some(stop_reason),
4654 }],
4655 })?;
4656 }
4657 }
4658 anthropic::ResponseEvent::ContentBlockStop { .. } => {}
4659 anthropic::ResponseEvent::MessageStop {} => {}
4660 anthropic::ResponseEvent::Ping {} => {}
4661 }
4662 }
4663
4664 Ok(())
4665}
4666
4667struct CountTokensWithLanguageModelRateLimit;
4668
4669impl RateLimit for CountTokensWithLanguageModelRateLimit {
4670 fn capacity() -> usize {
4671 std::env::var("COUNT_TOKENS_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4672 .ok()
4673 .and_then(|v| v.parse().ok())
4674 .unwrap_or(600) // Picked arbitrarily
4675 }
4676
4677 fn refill_duration() -> chrono::Duration {
4678 chrono::Duration::hours(1)
4679 }
4680
4681 fn db_name() -> &'static str {
4682 "count-tokens-with-language-model"
4683 }
4684}
4685
4686async fn count_tokens_with_language_model(
4687 request: proto::CountTokensWithLanguageModel,
4688 response: Response<proto::CountTokensWithLanguageModel>,
4689 session: UserSession,
4690 google_ai_api_key: Option<Arc<str>>,
4691) -> Result<()> {
4692 authorize_access_to_language_models(&session).await?;
4693
4694 if !request.model.starts_with("gemini") {
4695 return Err(anyhow!(
4696 "counting tokens for model: {:?} is not supported",
4697 request.model
4698 ))?;
4699 }
4700
4701 session
4702 .rate_limiter
4703 .check::<CountTokensWithLanguageModelRateLimit>(session.user_id())
4704 .await?;
4705
4706 let api_key = google_ai_api_key
4707 .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
4708 let tokens_response = google_ai::count_tokens(
4709 session.http_client.as_ref(),
4710 google_ai::API_URL,
4711 &api_key,
4712 crate::ai::count_tokens_request_to_google_ai(request)?,
4713 )
4714 .await?;
4715 response.send(proto::CountTokensResponse {
4716 token_count: tokens_response.total_tokens as u32,
4717 })?;
4718 Ok(())
4719}
4720
4721struct ComputeEmbeddingsRateLimit;
4722
4723impl RateLimit for ComputeEmbeddingsRateLimit {
4724 fn capacity() -> usize {
4725 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4726 .ok()
4727 .and_then(|v| v.parse().ok())
4728 .unwrap_or(5000) // Picked arbitrarily
4729 }
4730
4731 fn refill_duration() -> chrono::Duration {
4732 chrono::Duration::hours(1)
4733 }
4734
4735 fn db_name() -> &'static str {
4736 "compute-embeddings"
4737 }
4738}
4739
4740async fn compute_embeddings(
4741 request: proto::ComputeEmbeddings,
4742 response: Response<proto::ComputeEmbeddings>,
4743 session: UserSession,
4744 api_key: Option<Arc<str>>,
4745) -> Result<()> {
4746 let api_key = api_key.context("no OpenAI API key configured on the server")?;
4747 authorize_access_to_language_models(&session).await?;
4748
4749 session
4750 .rate_limiter
4751 .check::<ComputeEmbeddingsRateLimit>(session.user_id())
4752 .await?;
4753
4754 let embeddings = match request.model.as_str() {
4755 "openai/text-embedding-3-small" => {
4756 open_ai::embed(
4757 session.http_client.as_ref(),
4758 OPEN_AI_API_URL,
4759 &api_key,
4760 OpenAiEmbeddingModel::TextEmbedding3Small,
4761 request.texts.iter().map(|text| text.as_str()),
4762 )
4763 .await?
4764 }
4765 provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
4766 };
4767
4768 let embeddings = request
4769 .texts
4770 .iter()
4771 .map(|text| {
4772 let mut hasher = sha2::Sha256::new();
4773 hasher.update(text.as_bytes());
4774 let result = hasher.finalize();
4775 result.to_vec()
4776 })
4777 .zip(
4778 embeddings
4779 .data
4780 .into_iter()
4781 .map(|embedding| embedding.embedding),
4782 )
4783 .collect::<HashMap<_, _>>();
4784
4785 let db = session.db().await;
4786 db.save_embeddings(&request.model, &embeddings)
4787 .await
4788 .context("failed to save embeddings")
4789 .trace_err();
4790
4791 response.send(proto::ComputeEmbeddingsResponse {
4792 embeddings: embeddings
4793 .into_iter()
4794 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4795 .collect(),
4796 })?;
4797 Ok(())
4798}
4799
4800async fn get_cached_embeddings(
4801 request: proto::GetCachedEmbeddings,
4802 response: Response<proto::GetCachedEmbeddings>,
4803 session: UserSession,
4804) -> Result<()> {
4805 authorize_access_to_language_models(&session).await?;
4806
4807 let db = session.db().await;
4808 let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
4809
4810 response.send(proto::GetCachedEmbeddingsResponse {
4811 embeddings: embeddings
4812 .into_iter()
4813 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4814 .collect(),
4815 })?;
4816 Ok(())
4817}
4818
4819async fn authorize_access_to_language_models(session: &UserSession) -> Result<(), Error> {
4820 let db = session.db().await;
4821 let flags = db.get_user_flags(session.user_id()).await?;
4822 if flags.iter().any(|flag| flag == "language-models") {
4823 Ok(())
4824 } else {
4825 Err(anyhow!("permission denied"))?
4826 }
4827}
4828
4829/// Get a Supermaven API key for the user
4830async fn get_supermaven_api_key(
4831 _request: proto::GetSupermavenApiKey,
4832 response: Response<proto::GetSupermavenApiKey>,
4833 session: UserSession,
4834) -> Result<()> {
4835 let user_id: String = session.user_id().to_string();
4836 if !session.is_staff() {
4837 return Err(anyhow!("supermaven not enabled for this account"))?;
4838 }
4839
4840 let email = session
4841 .email()
4842 .ok_or_else(|| anyhow!("user must have an email"))?;
4843
4844 let supermaven_admin_api = session
4845 .supermaven_client
4846 .as_ref()
4847 .ok_or_else(|| anyhow!("supermaven not configured"))?;
4848
4849 let result = supermaven_admin_api
4850 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4851 .await?;
4852
4853 response.send(proto::GetSupermavenApiKeyResponse {
4854 api_key: result.api_key,
4855 })?;
4856
4857 Ok(())
4858}
4859
4860/// Start receiving chat updates for a channel
4861async fn join_channel_chat(
4862 request: proto::JoinChannelChat,
4863 response: Response<proto::JoinChannelChat>,
4864 session: UserSession,
4865) -> Result<()> {
4866 let channel_id = ChannelId::from_proto(request.channel_id);
4867
4868 let db = session.db().await;
4869 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4870 .await?;
4871 let messages = db
4872 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4873 .await?;
4874 response.send(proto::JoinChannelChatResponse {
4875 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4876 messages,
4877 })?;
4878 Ok(())
4879}
4880
4881/// Stop receiving chat updates for a channel
4882async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
4883 let channel_id = ChannelId::from_proto(request.channel_id);
4884 session
4885 .db()
4886 .await
4887 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4888 .await?;
4889 Ok(())
4890}
4891
4892/// Retrieve the chat history for a channel
4893async fn get_channel_messages(
4894 request: proto::GetChannelMessages,
4895 response: Response<proto::GetChannelMessages>,
4896 session: UserSession,
4897) -> Result<()> {
4898 let channel_id = ChannelId::from_proto(request.channel_id);
4899 let messages = session
4900 .db()
4901 .await
4902 .get_channel_messages(
4903 channel_id,
4904 session.user_id(),
4905 MESSAGE_COUNT_PER_PAGE,
4906 Some(MessageId::from_proto(request.before_message_id)),
4907 )
4908 .await?;
4909 response.send(proto::GetChannelMessagesResponse {
4910 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4911 messages,
4912 })?;
4913 Ok(())
4914}
4915
4916/// Retrieve specific chat messages
4917async fn get_channel_messages_by_id(
4918 request: proto::GetChannelMessagesById,
4919 response: Response<proto::GetChannelMessagesById>,
4920 session: UserSession,
4921) -> Result<()> {
4922 let message_ids = request
4923 .message_ids
4924 .iter()
4925 .map(|id| MessageId::from_proto(*id))
4926 .collect::<Vec<_>>();
4927 let messages = session
4928 .db()
4929 .await
4930 .get_channel_messages_by_id(session.user_id(), &message_ids)
4931 .await?;
4932 response.send(proto::GetChannelMessagesResponse {
4933 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4934 messages,
4935 })?;
4936 Ok(())
4937}
4938
4939/// Retrieve the current users notifications
4940async fn get_notifications(
4941 request: proto::GetNotifications,
4942 response: Response<proto::GetNotifications>,
4943 session: UserSession,
4944) -> Result<()> {
4945 let notifications = session
4946 .db()
4947 .await
4948 .get_notifications(
4949 session.user_id(),
4950 NOTIFICATION_COUNT_PER_PAGE,
4951 request
4952 .before_id
4953 .map(|id| db::NotificationId::from_proto(id)),
4954 )
4955 .await?;
4956 response.send(proto::GetNotificationsResponse {
4957 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4958 notifications,
4959 })?;
4960 Ok(())
4961}
4962
4963/// Mark notifications as read
4964async fn mark_notification_as_read(
4965 request: proto::MarkNotificationRead,
4966 response: Response<proto::MarkNotificationRead>,
4967 session: UserSession,
4968) -> Result<()> {
4969 let database = &session.db().await;
4970 let notifications = database
4971 .mark_notification_as_read_by_id(
4972 session.user_id(),
4973 NotificationId::from_proto(request.notification_id),
4974 )
4975 .await?;
4976 send_notifications(
4977 &*session.connection_pool().await,
4978 &session.peer,
4979 notifications,
4980 );
4981 response.send(proto::Ack {})?;
4982 Ok(())
4983}
4984
4985/// Get the current users information
4986async fn get_private_user_info(
4987 _request: proto::GetPrivateUserInfo,
4988 response: Response<proto::GetPrivateUserInfo>,
4989 session: UserSession,
4990) -> Result<()> {
4991 let db = session.db().await;
4992
4993 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4994 let user = db
4995 .get_user_by_id(session.user_id())
4996 .await?
4997 .ok_or_else(|| anyhow!("user not found"))?;
4998 let flags = db.get_user_flags(session.user_id()).await?;
4999
5000 response.send(proto::GetPrivateUserInfoResponse {
5001 metrics_id,
5002 staff: user.admin,
5003 flags,
5004 })?;
5005 Ok(())
5006}
5007
5008fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
5009 match message {
5010 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
5011 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
5012 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
5013 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
5014 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
5015 code: frame.code.into(),
5016 reason: frame.reason,
5017 })),
5018 }
5019}
5020
5021fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
5022 match message {
5023 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
5024 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
5025 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
5026 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
5027 AxumMessage::Close(frame) => {
5028 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
5029 code: frame.code.into(),
5030 reason: frame.reason,
5031 }))
5032 }
5033 }
5034}
5035
5036fn notify_membership_updated(
5037 connection_pool: &mut ConnectionPool,
5038 result: MembershipUpdated,
5039 user_id: UserId,
5040 peer: &Peer,
5041) {
5042 for membership in &result.new_channels.channel_memberships {
5043 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
5044 }
5045 for channel_id in &result.removed_channels {
5046 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
5047 }
5048
5049 let user_channels_update = proto::UpdateUserChannels {
5050 channel_memberships: result
5051 .new_channels
5052 .channel_memberships
5053 .iter()
5054 .map(|cm| proto::ChannelMembership {
5055 channel_id: cm.channel_id.to_proto(),
5056 role: cm.role.into(),
5057 })
5058 .collect(),
5059 ..Default::default()
5060 };
5061
5062 let mut update = build_channels_update(result.new_channels);
5063 update.delete_channels = result
5064 .removed_channels
5065 .into_iter()
5066 .map(|id| id.to_proto())
5067 .collect();
5068 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
5069
5070 for connection_id in connection_pool.user_connection_ids(user_id) {
5071 peer.send(connection_id, user_channels_update.clone())
5072 .trace_err();
5073 peer.send(connection_id, update.clone()).trace_err();
5074 }
5075}
5076
5077fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
5078 proto::UpdateUserChannels {
5079 channel_memberships: channels
5080 .channel_memberships
5081 .iter()
5082 .map(|m| proto::ChannelMembership {
5083 channel_id: m.channel_id.to_proto(),
5084 role: m.role.into(),
5085 })
5086 .collect(),
5087 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
5088 observed_channel_message_id: channels.observed_channel_messages.clone(),
5089 }
5090}
5091
5092fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
5093 let mut update = proto::UpdateChannels::default();
5094
5095 for channel in channels.channels {
5096 update.channels.push(channel.to_proto());
5097 }
5098
5099 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
5100 update.latest_channel_message_ids = channels.latest_channel_messages;
5101
5102 for (channel_id, participants) in channels.channel_participants {
5103 update
5104 .channel_participants
5105 .push(proto::ChannelParticipants {
5106 channel_id: channel_id.to_proto(),
5107 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
5108 });
5109 }
5110
5111 for channel in channels.invited_channels {
5112 update.channel_invitations.push(channel.to_proto());
5113 }
5114
5115 update.hosted_projects = channels.hosted_projects;
5116 update
5117}
5118
5119fn build_initial_contacts_update(
5120 contacts: Vec<db::Contact>,
5121 pool: &ConnectionPool,
5122) -> proto::UpdateContacts {
5123 let mut update = proto::UpdateContacts::default();
5124
5125 for contact in contacts {
5126 match contact {
5127 db::Contact::Accepted { user_id, busy } => {
5128 update.contacts.push(contact_for_user(user_id, busy, &pool));
5129 }
5130 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
5131 db::Contact::Incoming { user_id } => {
5132 update
5133 .incoming_requests
5134 .push(proto::IncomingContactRequest {
5135 requester_id: user_id.to_proto(),
5136 })
5137 }
5138 }
5139 }
5140
5141 update
5142}
5143
5144fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
5145 proto::Contact {
5146 user_id: user_id.to_proto(),
5147 online: pool.is_user_online(user_id),
5148 busy,
5149 }
5150}
5151
5152fn room_updated(room: &proto::Room, peer: &Peer) {
5153 broadcast(
5154 None,
5155 room.participants
5156 .iter()
5157 .filter_map(|participant| Some(participant.peer_id?.into())),
5158 |peer_id| {
5159 peer.send(
5160 peer_id,
5161 proto::RoomUpdated {
5162 room: Some(room.clone()),
5163 },
5164 )
5165 },
5166 );
5167}
5168
5169fn channel_updated(
5170 channel: &db::channel::Model,
5171 room: &proto::Room,
5172 peer: &Peer,
5173 pool: &ConnectionPool,
5174) {
5175 let participants = room
5176 .participants
5177 .iter()
5178 .map(|p| p.user_id)
5179 .collect::<Vec<_>>();
5180
5181 broadcast(
5182 None,
5183 pool.channel_connection_ids(channel.root_id())
5184 .filter_map(|(channel_id, role)| {
5185 role.can_see_channel(channel.visibility).then(|| channel_id)
5186 }),
5187 |peer_id| {
5188 peer.send(
5189 peer_id,
5190 proto::UpdateChannels {
5191 channel_participants: vec![proto::ChannelParticipants {
5192 channel_id: channel.id.to_proto(),
5193 participant_user_ids: participants.clone(),
5194 }],
5195 ..Default::default()
5196 },
5197 )
5198 },
5199 );
5200}
5201
5202async fn send_dev_server_projects_update(
5203 user_id: UserId,
5204 mut status: proto::DevServerProjectsUpdate,
5205 session: &Session,
5206) {
5207 let pool = session.connection_pool().await;
5208 for dev_server in &mut status.dev_servers {
5209 dev_server.status =
5210 pool.dev_server_status(DevServerId(dev_server.dev_server_id as i32)) as i32;
5211 }
5212 let connections = pool.user_connection_ids(user_id);
5213 for connection_id in connections {
5214 session.peer.send(connection_id, status.clone()).trace_err();
5215 }
5216}
5217
5218async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
5219 let db = session.db().await;
5220
5221 let contacts = db.get_contacts(user_id).await?;
5222 let busy = db.is_user_busy(user_id).await?;
5223
5224 let pool = session.connection_pool().await;
5225 let updated_contact = contact_for_user(user_id, busy, &pool);
5226 for contact in contacts {
5227 if let db::Contact::Accepted {
5228 user_id: contact_user_id,
5229 ..
5230 } = contact
5231 {
5232 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
5233 session
5234 .peer
5235 .send(
5236 contact_conn_id,
5237 proto::UpdateContacts {
5238 contacts: vec![updated_contact.clone()],
5239 remove_contacts: Default::default(),
5240 incoming_requests: Default::default(),
5241 remove_incoming_requests: Default::default(),
5242 outgoing_requests: Default::default(),
5243 remove_outgoing_requests: Default::default(),
5244 },
5245 )
5246 .trace_err();
5247 }
5248 }
5249 }
5250 Ok(())
5251}
5252
5253async fn lost_dev_server_connection(session: &DevServerSession) -> Result<()> {
5254 log::info!("lost dev server connection, unsharing projects");
5255 let project_ids = session
5256 .db()
5257 .await
5258 .get_stale_dev_server_projects(session.connection_id)
5259 .await?;
5260
5261 for project_id in project_ids {
5262 // not unshare re-checks the connection ids match, so we get away with no transaction
5263 unshare_project_internal(project_id, session.connection_id, None, &session).await?;
5264 }
5265
5266 let user_id = session.dev_server().user_id;
5267 let update = session
5268 .db()
5269 .await
5270 .dev_server_projects_update(user_id)
5271 .await?;
5272
5273 send_dev_server_projects_update(user_id, update, session).await;
5274
5275 Ok(())
5276}
5277
/// Removes `connection_id` from its current room and notifies everyone affected.
///
/// If the database reports no room for the connection, this returns `Ok(())`
/// and does nothing else. Otherwise it:
/// - notifies collaborators of each project the user left (via `project_left`);
/// - broadcasts the updated room to its participants;
/// - if the room belongs to a channel, broadcasts the channel's new participant
///   list (via `channel_updated`);
/// - sends `CallCanceled` to every user in `canceled_calls_to_user_ids`;
/// - calls `update_user_contacts` for the leaving user and for those users;
/// - if LiveKit is configured, removes the user from the LiveKit room, and
///   deletes that room when the database marked it as deleted.
///
/// LiveKit and send failures are logged and ignored. Database and
/// contact-update errors are propagated.
async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // These are each assigned exactly once inside the `if let` below. Declaring
    // them up front lets the values be moved out of `left_room`. That way
    // `left_room`, and the database handle temporary in the `if let` scrutinee,
    // are dropped at the end of the `if let`, before any of the `.await`s that
    // follow.
    // NOTE(review): `left_room` may be a guard that holds a lock on the room.
    // Confirm before restructuring this into a `let ... else`, which would keep
    // it alive for the rest of the function.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Taken before `room` below, so this reads the room's LiveKit name
        // before the whole room is moved out.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the connection-pool guard is released before the
    // `update_user_contacts` awaits below.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    // LiveKit cleanup is best-effort: failures are logged, not propagated.
    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
5351
5352async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
5353 let left_channel_buffers = session
5354 .db()
5355 .await
5356 .leave_channel_buffers(session.connection_id)
5357 .await?;
5358
5359 for left_buffer in left_channel_buffers {
5360 channel_buffer_updated(
5361 session.connection_id,
5362 left_buffer.connections,
5363 &proto::UpdateChannelBufferCollaborators {
5364 channel_id: left_buffer.channel_id.to_proto(),
5365 collaborators: left_buffer.collaborators,
5366 },
5367 &session.peer,
5368 );
5369 }
5370
5371 Ok(())
5372}
5373
5374fn project_left(project: &db::LeftProject, session: &UserSession) {
5375 for connection_id in &project.connection_ids {
5376 if project.should_unshare {
5377 session
5378 .peer
5379 .send(
5380 *connection_id,
5381 proto::UnshareProject {
5382 project_id: project.id.to_proto(),
5383 },
5384 )
5385 .trace_err();
5386 } else {
5387 session
5388 .peer
5389 .send(
5390 *connection_id,
5391 proto::RemoveProjectCollaborator {
5392 project_id: project.id.to_proto(),
5393 peer_id: Some(session.connection_id.into()),
5394 },
5395 )
5396 .trace_err();
5397 }
5398 }
5399}
5400
5401pub trait ResultExt {
5402 type Ok;
5403
5404 fn trace_err(self) -> Option<Self::Ok>;
5405}
5406
5407impl<T, E> ResultExt for Result<T, E>
5408where
5409 E: std::fmt::Debug,
5410{
5411 type Ok = T;
5412
5413 #[track_caller]
5414 fn trace_err(self) -> Option<T> {
5415 match self {
5416 Ok(value) => Some(value),
5417 Err(error) => {
5418 tracing::error!("{:?}", error);
5419 None
5420 }
5421 }
5422 }
5423}