1mod connection_pool;
2
3use crate::{
4 auth,
5 db::{
6 self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
7 CreatedChannelMessage, Database, DevServerId, DevServerProjectId, InviteMemberResult,
8 MembershipUpdated, MessageId, NotificationId, PrincipalId, Project, ProjectId,
9 RejoinedProject, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
10 ServerId, UpdatedChannelMessage, User, UserId,
11 },
12 executor::Executor,
13 AppState, Error, RateLimit, RateLimiter, Result,
14};
15use anyhow::{anyhow, Context as _};
16use async_tungstenite::tungstenite::{
17 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
18};
19use axum::{
20 body::Body,
21 extract::{
22 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
23 ConnectInfo, WebSocketUpgrade,
24 },
25 headers::{Header, HeaderName},
26 http::StatusCode,
27 middleware,
28 response::IntoResponse,
29 routing::get,
30 Extension, Router, TypedHeader,
31};
32use collections::{HashMap, HashSet};
33pub use connection_pool::{ConnectionPool, ZedVersion};
34use core::fmt::{self, Debug, Formatter};
35use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
36use sha2::Digest;
37use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
38
39use futures::{
40 channel::oneshot,
41 future::{self, BoxFuture},
42 stream::FuturesUnordered,
43 FutureExt, SinkExt, StreamExt, TryStreamExt,
44};
45use http::IsahcHttpClient;
46use prometheus::{register_int_gauge, IntGauge};
47use rpc::{
48 proto::{
49 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LanguageModelRole,
50 LiveKitConnectionInfo, RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
51 },
52 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
53};
54use semantic_version::SemanticVersion;
55use serde::{Serialize, Serializer};
56use std::{
57 any::TypeId,
58 future::Future,
59 marker::PhantomData,
60 mem,
61 net::SocketAddr,
62 ops::{Deref, DerefMut},
63 rc::Rc,
64 sync::{
65 atomic::{AtomicBool, Ordering::SeqCst},
66 Arc, OnceLock,
67 },
68 time::{Duration, Instant},
69};
70use time::OffsetDateTime;
71use tokio::sync::{watch, Semaphore};
72use tower::ServiceBuilder;
73use tracing::{
74 field::{self},
75 info_span, instrument, Instrument,
76};
77
78use self::connection_pool::VersionedMessage;
79
/// Reconnection grace period (30 seconds).
///
/// NOTE(review): presumably how long a disconnected client may take to reconnect before its
/// state is dropped; the use sites are outside this view, so confirm there.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_millis(30_000);

/// Delay before `Server::start` cleans up resources left behind by stale servers.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully. Waiting a little longer
/// ensures they are gone before we clean up their old resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_millis(15_000);

/// Page size for channel-message queries (use sites not visible here).
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Maximum channel-message length (use sites not visible here; units to be confirmed).
const MAX_MESSAGE_LEN: usize = 1_024;
/// Page size for notification queries (use sites not visible here).
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
88
89type MessageHandler =
90 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
91
92struct Response<R> {
93 peer: Arc<Peer>,
94 receipt: Receipt<R>,
95 responded: Arc<AtomicBool>,
96}
97
98impl<R: RequestMessage> Response<R> {
99 fn send(self, payload: R::Response) -> Result<()> {
100 self.responded.store(true, SeqCst);
101 self.peer.respond(self.receipt, payload)?;
102 Ok(())
103 }
104}
105
106struct StreamingResponse<R: RequestMessage> {
107 peer: Arc<Peer>,
108 receipt: Receipt<R>,
109}
110
111impl<R: RequestMessage> StreamingResponse<R> {
112 fn send(&self, payload: R::Response) -> Result<()> {
113 self.peer.respond(self.receipt, payload)?;
114 Ok(())
115 }
116}
117
118#[derive(Clone, Debug)]
119pub enum Principal {
120 User(User),
121 Impersonated { user: User, admin: User },
122 DevServer(dev_server::Model),
123}
124
125impl Principal {
126 fn update_span(&self, span: &tracing::Span) {
127 match &self {
128 Principal::User(user) => {
129 span.record("user_id", &user.id.0);
130 span.record("login", &user.github_login);
131 }
132 Principal::Impersonated { user, admin } => {
133 span.record("user_id", &user.id.0);
134 span.record("login", &user.github_login);
135 span.record("impersonator", &admin.github_login);
136 }
137 Principal::DevServer(dev_server) => {
138 span.record("dev_server_id", &dev_server.id.0);
139 }
140 }
141 }
142}
143
144#[derive(Clone)]
145struct Session {
146 principal: Principal,
147 connection_id: ConnectionId,
148 db: Arc<tokio::sync::Mutex<DbHandle>>,
149 peer: Arc<Peer>,
150 connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
151 live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
152 supermaven_client: Option<Arc<SupermavenAdminApi>>,
153 http_client: Arc<IsahcHttpClient>,
154 rate_limiter: Arc<RateLimiter>,
155 _executor: Executor,
156}
157
158impl Session {
159 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
160 #[cfg(test)]
161 tokio::task::yield_now().await;
162 let guard = self.db.lock().await;
163 #[cfg(test)]
164 tokio::task::yield_now().await;
165 guard
166 }
167
168 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
169 #[cfg(test)]
170 tokio::task::yield_now().await;
171 let guard = self.connection_pool.lock();
172 ConnectionPoolGuard {
173 guard,
174 _not_send: PhantomData,
175 }
176 }
177
178 fn for_user(self) -> Option<UserSession> {
179 UserSession::new(self)
180 }
181
182 fn for_dev_server(self) -> Option<DevServerSession> {
183 DevServerSession::new(self)
184 }
185
186 fn user_id(&self) -> Option<UserId> {
187 match &self.principal {
188 Principal::User(user) => Some(user.id),
189 Principal::Impersonated { user, .. } => Some(user.id),
190 Principal::DevServer(_) => None,
191 }
192 }
193
194 fn is_staff(&self) -> bool {
195 match &self.principal {
196 Principal::User(user) => user.admin,
197 Principal::Impersonated { .. } => true,
198 Principal::DevServer(_) => false,
199 }
200 }
201
202 fn dev_server_id(&self) -> Option<DevServerId> {
203 match &self.principal {
204 Principal::User(_) | Principal::Impersonated { .. } => None,
205 Principal::DevServer(dev_server) => Some(dev_server.id),
206 }
207 }
208
209 fn principal_id(&self) -> PrincipalId {
210 match &self.principal {
211 Principal::User(user) => PrincipalId::UserId(user.id),
212 Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
213 Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
214 }
215 }
216}
217
218impl Debug for Session {
219 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
220 let mut result = f.debug_struct("Session");
221 match &self.principal {
222 Principal::User(user) => {
223 result.field("user", &user.github_login);
224 }
225 Principal::Impersonated { user, admin } => {
226 result.field("user", &user.github_login);
227 result.field("impersonator", &admin.github_login);
228 }
229 Principal::DevServer(dev_server) => {
230 result.field("dev_server", &dev_server.id);
231 }
232 }
233 result.field("connection_id", &self.connection_id).finish()
234 }
235}
236
237struct UserSession(Session);
238
239impl UserSession {
240 pub fn new(s: Session) -> Option<Self> {
241 s.user_id().map(|_| UserSession(s))
242 }
243 pub fn user_id(&self) -> UserId {
244 self.0.user_id().unwrap()
245 }
246
247 pub fn email(&self) -> Option<String> {
248 match &self.0.principal {
249 Principal::User(user) => user.email_address.clone(),
250 Principal::Impersonated { user, .. } => user.email_address.clone(),
251 Principal::DevServer(..) => None,
252 }
253 }
254}
255
256impl Deref for UserSession {
257 type Target = Session;
258
259 fn deref(&self) -> &Self::Target {
260 &self.0
261 }
262}
263impl DerefMut for UserSession {
264 fn deref_mut(&mut self) -> &mut Self::Target {
265 &mut self.0
266 }
267}
268
269struct DevServerSession(Session);
270
271impl DevServerSession {
272 pub fn new(s: Session) -> Option<Self> {
273 s.dev_server_id().map(|_| DevServerSession(s))
274 }
275 pub fn dev_server_id(&self) -> DevServerId {
276 self.0.dev_server_id().unwrap()
277 }
278
279 fn dev_server(&self) -> &dev_server::Model {
280 match &self.0.principal {
281 Principal::DevServer(dev_server) => dev_server,
282 _ => unreachable!(),
283 }
284 }
285}
286
287impl Deref for DevServerSession {
288 type Target = Session;
289
290 fn deref(&self) -> &Self::Target {
291 &self.0
292 }
293}
294impl DerefMut for DevServerSession {
295 fn deref_mut(&mut self) -> &mut Self::Target {
296 &mut self.0
297 }
298}
299
300fn user_handler<M: RequestMessage, Fut>(
301 handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
302) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
303where
304 Fut: Send + Future<Output = Result<()>>,
305{
306 let handler = Arc::new(handler);
307 move |message, response, session| {
308 let handler = handler.clone();
309 Box::pin(async move {
310 if let Some(user_session) = session.for_user() {
311 Ok(handler(message, response, user_session).await?)
312 } else {
313 Err(Error::Internal(anyhow!(
314 "must be a user to call {}",
315 M::NAME
316 )))
317 }
318 })
319 }
320}
321
322fn dev_server_handler<M: RequestMessage, Fut>(
323 handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
324) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
325where
326 Fut: Send + Future<Output = Result<()>>,
327{
328 let handler = Arc::new(handler);
329 move |message, response, session| {
330 let handler = handler.clone();
331 Box::pin(async move {
332 if let Some(dev_server_session) = session.for_dev_server() {
333 Ok(handler(message, response, dev_server_session).await?)
334 } else {
335 Err(Error::Internal(anyhow!(
336 "must be a dev server to call {}",
337 M::NAME
338 )))
339 }
340 })
341 }
342}
343
344fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
345 handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
346) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
347where
348 InnertRetFut: Send + Future<Output = Result<()>>,
349{
350 let handler = Arc::new(handler);
351 move |message, session| {
352 let handler = handler.clone();
353 Box::pin(async move {
354 if let Some(user_session) = session.for_user() {
355 Ok(handler(message, user_session).await?)
356 } else {
357 Err(Error::Internal(anyhow!(
358 "must be a user to call {}",
359 M::NAME
360 )))
361 }
362 })
363 }
364}
365
366struct DbHandle(Arc<Database>);
367
368impl Deref for DbHandle {
369 type Target = Database;
370
371 fn deref(&self) -> &Self::Target {
372 self.0.as_ref()
373 }
374}
375
/// The collaboration RPC server: owns the peer, the connection pool and the handler table.
pub struct Server {
    // Behind a mutex so the test-only `reset` can replace it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    // Dispatch table: message `TypeId` -> type-erased handler (filled in `Server::new`).
    handlers: HashMap<TypeId, MessageHandler>,
    // `true` once `teardown` has run. Connection tasks subscribe and stop when it changes.
    teardown: watch::Sender<bool>,
}
384
/// Lock guard over the shared `ConnectionPool`, returned by `Session::connection_pool`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    // `Rc<()>` is not `Send`, which makes the guard `!Send`.
    // NOTE(review): presumably this stops it being held across `.await` in `Send` futures.
    _not_send: PhantomData<Rc<()>>,
}
389
/// Serializable view of the server's peer and (locked) connection pool.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized through the guard's `Deref` target rather than the guard itself.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
396
397pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
398where
399 S: Serializer,
400 T: Deref<Target = U>,
401 U: Serialize,
402{
403 Serialize::serialize(value.deref(), serializer)
404}
405
406impl Server {
    /// Creates a server with the given `id` and registers every RPC handler.
    ///
    /// Handlers are keyed by message type. Registering two handlers for the same type panics
    /// (see `add_handler`). Handlers wrapped in `user_handler`, `user_message_handler` or
    /// `dev_server_handler` reject calls whose session principal is not of the required kind.
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            // Only the sender is kept; receivers are created per connection via `subscribe`.
            teardown: watch::channel(false).0,
        };

        server
            .add_request_handler(ping)
            // Rooms and calls.
            .add_request_handler(user_handler(create_room))
            .add_request_handler(user_handler(join_room))
            .add_request_handler(user_handler(rejoin_room))
            .add_request_handler(user_handler(leave_room))
            .add_request_handler(user_handler(set_room_participant_role))
            .add_request_handler(user_handler(call))
            .add_request_handler(user_handler(cancel_call))
            .add_message_handler(user_message_handler(decline_call))
            .add_request_handler(user_handler(update_participant_location))
            // Projects and dev servers.
            .add_request_handler(user_handler(share_project))
            .add_message_handler(unshare_project)
            .add_request_handler(user_handler(join_project))
            .add_request_handler(user_handler(join_hosted_project))
            .add_request_handler(user_handler(rejoin_dev_server_projects))
            .add_request_handler(user_handler(create_dev_server_project))
            .add_request_handler(user_handler(delete_dev_server_project))
            .add_request_handler(user_handler(create_dev_server))
            .add_request_handler(user_handler(regenerate_dev_server_token))
            .add_request_handler(user_handler(rename_dev_server))
            .add_request_handler(user_handler(delete_dev_server))
            .add_request_handler(dev_server_handler(share_dev_server_project))
            .add_request_handler(dev_server_handler(shutdown_dev_server))
            .add_request_handler(dev_server_handler(reconnect_dev_server))
            .add_message_handler(user_message_handler(leave_project))
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Project requests forwarded to another peer.
            .add_request_handler(user_handler(
                forward_project_request_for_owner::<proto::TaskContextForLocation>,
            ))
            .add_request_handler(user_handler(
                forward_project_request_for_owner::<proto::TaskTemplates>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetHover>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetTypeDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetReferences>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::SearchProject>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDocumentHighlights>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetProjectSymbols>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferForSymbol>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferById>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::SynchronizeBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::InlayHints>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferByPath>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCompletions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            ))
            .add_request_handler(user_handler(
                forward_versioned_mutating_project_request::<proto::OpenNewBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCodeActions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCodeAction>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PrepareRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PerformRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ReloadBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::FormatBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CreateProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RenameProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CopyProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::DeleteProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ExpandProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OnTypeFormatting>,
            ))
            .add_request_handler(user_handler(
                forward_versioned_mutating_project_request::<proto::SaveBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::BlameBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::MultiLspQuery>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RestartLanguageServers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::LinkedEditingRange>,
            ))
            // Buffer updates and host broadcasts.
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            // Users and contacts.
            .add_request_handler(get_users)
            .add_request_handler(user_handler(fuzzy_search_users))
            .add_request_handler(user_handler(request_contact))
            .add_request_handler(user_handler(remove_contact))
            .add_request_handler(user_handler(respond_to_contact_request))
            // Channels, channel buffers, chat and notifications.
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(user_handler(create_channel))
            .add_request_handler(user_handler(delete_channel))
            .add_request_handler(user_handler(invite_channel_member))
            .add_request_handler(user_handler(remove_channel_member))
            .add_request_handler(user_handler(set_channel_member_role))
            .add_request_handler(user_handler(set_channel_visibility))
            .add_request_handler(user_handler(rename_channel))
            .add_request_handler(user_handler(join_channel_buffer))
            .add_request_handler(user_handler(leave_channel_buffer))
            .add_message_handler(user_message_handler(update_channel_buffer))
            .add_request_handler(user_handler(rejoin_channel_buffers))
            .add_request_handler(user_handler(get_channel_members))
            .add_request_handler(user_handler(respond_to_channel_invite))
            .add_request_handler(user_handler(join_channel))
            .add_request_handler(user_handler(join_channel_chat))
            .add_message_handler(user_message_handler(leave_channel_chat))
            .add_request_handler(user_handler(send_channel_message))
            .add_request_handler(user_handler(remove_channel_message))
            .add_request_handler(user_handler(update_channel_message))
            .add_request_handler(user_handler(get_channel_messages))
            .add_request_handler(user_handler(get_channel_messages_by_id))
            .add_request_handler(user_handler(get_notifications))
            .add_request_handler(user_handler(mark_notification_as_read))
            .add_request_handler(user_handler(move_channel))
            // Following.
            .add_request_handler(user_handler(follow))
            .add_message_handler(user_message_handler(unfollow))
            .add_message_handler(user_message_handler(update_followers))
            .add_request_handler(user_handler(get_private_user_info))
            .add_message_handler(user_message_handler(acknowledge_channel_message))
            .add_message_handler(user_message_handler(acknowledge_buffer_version))
            // AI / completion services.
            .add_request_handler(user_handler(get_supermaven_api_key))
            // Not wrapped in `user_handler`: `complete_with_language_model` receives the raw
            // `Session` together with the configured provider API keys.
            .add_streaming_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    complete_with_language_model(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                        app_state.config.google_ai_api_key.clone(),
                        app_state.config.anthropic_api_key.clone(),
                    )
                }
            })
            .add_request_handler({
                let app_state = app_state.clone();
                user_handler(move |request, response, session| {
                    count_tokens_with_language_model(
                        request,
                        response,
                        session,
                        app_state.config.google_ai_api_key.clone(),
                    )
                })
            })
            .add_request_handler({
                user_handler(move |request, response, session| {
                    get_cached_embeddings(request, response, session)
                })
            })
            .add_request_handler({
                let app_state = app_state.clone();
                user_handler(move |request, response, session| {
                    compute_embeddings(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                    )
                })
            });

        Arc::new(server)
    }
641
    /// Spawns a detached background task that cleans up resources recorded as stale for this
    /// server's environment, then returns immediately with `Ok(())`.
    ///
    /// After sleeping for `CLEANUP_TIMEOUT`, the task:
    /// 1. clears stale collaborators from channel buffers and notifies remaining members;
    /// 2. clears stale room participants, broadcasts room/channel updates, notifies users
    ///    whose calls were canceled, refreshes those users' contacts, and deletes the LiveKit
    ///    room once it has no participants;
    /// 3. deletes the stale server records.
    ///
    /// Database and send failures are logged via `trace_err` and do not abort the task.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Step 1: drop stale collaborators from each channel buffer and send the
                    // refreshed collaborator list to every remaining connection.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    // Step 2: clean up each stale room.
                    for room_id in room_ids {
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            // Both removed participants and users whose calls were canceled
                            // get their contact status refreshed below.
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the pool lock is released before the awaits below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push each affected user's updated contact entry (including the busy
                        // flag) to all connections of their accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Delete the LiveKit room only if the refresh left it with no participants.
                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                // Step 3: remove the stale server records themselves.
                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
787
788 pub fn teardown(&self) {
789 self.peer.teardown();
790 self.connection_pool.lock().reset();
791 let _ = self.teardown.send(true);
792 }
793
794 #[cfg(test)]
795 pub fn reset(&self, id: ServerId) {
796 self.teardown();
797 *self.id.lock() = id;
798 self.peer.reset(id.0 as u32);
799 let _ = self.teardown.send(false);
800 }
801
    /// Test-only: the server's current id (it may have been replaced by `reset`).
    #[cfg(test)]
    pub fn id(&self) -> ServerId {
        *self.id.lock()
    }
806
807 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
808 where
809 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
810 Fut: 'static + Send + Future<Output = Result<()>>,
811 M: EnvelopedMessage,
812 {
813 let prev_handler = self.handlers.insert(
814 TypeId::of::<M>(),
815 Box::new(move |envelope, session| {
816 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
817 let received_at = envelope.received_at;
818 tracing::info!("message received");
819 let start_time = Instant::now();
820 let future = (handler)(*envelope, session);
821 async move {
822 let result = future.await;
823 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
824 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
825 let queue_duration_ms = total_duration_ms - processing_duration_ms;
826 let payload_type = M::NAME;
827
828 match result {
829 Err(error) => {
830 tracing::error!(
831 ?error,
832 total_duration_ms,
833 processing_duration_ms,
834 queue_duration_ms,
835 payload_type,
836 "error handling message"
837 )
838 }
839 Ok(()) => tracing::info!(
840 total_duration_ms,
841 processing_duration_ms,
842 queue_duration_ms,
843 "finished handling message"
844 ),
845 }
846 }
847 .boxed()
848 }),
849 );
850 if prev_handler.is_some() {
851 panic!("registered a handler for the same message twice");
852 }
853 self
854 }
855
856 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
857 where
858 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
859 Fut: 'static + Send + Future<Output = Result<()>>,
860 M: EnvelopedMessage,
861 {
862 self.add_handler(move |envelope, session| handler(envelope.payload, session));
863 self
864 }
865
866 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
867 where
868 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
869 Fut: Send + Future<Output = Result<()>>,
870 M: RequestMessage,
871 {
872 let handler = Arc::new(handler);
873 self.add_handler(move |envelope, session| {
874 let receipt = envelope.receipt();
875 let handler = handler.clone();
876 async move {
877 let peer = session.peer.clone();
878 let responded = Arc::new(AtomicBool::default());
879 let response = Response {
880 peer: peer.clone(),
881 responded: responded.clone(),
882 receipt,
883 };
884 match (handler)(envelope.payload, response, session).await {
885 Ok(()) => {
886 if responded.load(std::sync::atomic::Ordering::SeqCst) {
887 Ok(())
888 } else {
889 Err(anyhow!("handler did not send a response"))?
890 }
891 }
892 Err(error) => {
893 let proto_err = match &error {
894 Error::Internal(err) => err.to_proto(),
895 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
896 };
897 peer.respond_with_error(receipt, proto_err)?;
898 Err(error)
899 }
900 }
901 }
902 })
903 }
904
905 fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
906 where
907 F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
908 Fut: Send + Future<Output = Result<()>>,
909 M: RequestMessage,
910 {
911 let handler = Arc::new(handler);
912 self.add_handler(move |envelope, session| {
913 let receipt = envelope.receipt();
914 let handler = handler.clone();
915 async move {
916 let peer = session.peer.clone();
917 let response = StreamingResponse {
918 peer: peer.clone(),
919 receipt,
920 };
921 match (handler)(envelope.payload, response, session).await {
922 Ok(()) => {
923 peer.end_stream(receipt)?;
924 Ok(())
925 }
926 Err(error) => {
927 let proto_err = match &error {
928 Error::Internal(err) => err.to_proto(),
929 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
930 };
931 peer.respond_with_error(receipt, proto_err)?;
932 Err(error)
933 }
934 }
935 }
936 })
937 }
938
939 #[allow(clippy::too_many_arguments)]
940 pub fn handle_connection(
941 self: &Arc<Self>,
942 connection: Connection,
943 address: String,
944 principal: Principal,
945 zed_version: ZedVersion,
946 send_connection_id: Option<oneshot::Sender<ConnectionId>>,
947 executor: Executor,
948 ) -> impl Future<Output = ()> {
949 let this = self.clone();
950 let span = info_span!("handle connection", %address,
951 connection_id=field::Empty,
952 user_id=field::Empty,
953 login=field::Empty,
954 impersonator=field::Empty,
955 dev_server_id=field::Empty
956 );
957 principal.update_span(&span);
958
959 let mut teardown = self.teardown.subscribe();
960 async move {
961 if *teardown.borrow() {
962 tracing::error!("server is tearing down");
963 return
964 }
965 let (connection_id, handle_io, mut incoming_rx) = this
966 .peer
967 .add_connection(connection, {
968 let executor = executor.clone();
969 move |duration| executor.sleep(duration)
970 });
971 tracing::Span::current().record("connection_id", format!("{}", connection_id));
972 tracing::info!("connection opened");
973
974 let http_client = match IsahcHttpClient::new() {
975 Ok(http_client) => Arc::new(http_client),
976 Err(error) => {
977 tracing::error!(?error, "failed to create HTTP client");
978 return;
979 }
980 };
981
982 let supermaven_client = if let Some(supermaven_admin_api_key) = this.app_state.config.supermaven_admin_api_key.clone() {
983 Some(Arc::new(SupermavenAdminApi::new(
984 supermaven_admin_api_key.to_string(),
985 http_client.clone(),
986 )))
987 } else {
988 None
989 };
990
991 let session = Session {
992 principal: principal.clone(),
993 connection_id,
994 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
995 peer: this.peer.clone(),
996 connection_pool: this.connection_pool.clone(),
997 live_kit_client: this.app_state.live_kit_client.clone(),
998 http_client,
999 rate_limiter: this.app_state.rate_limiter.clone(),
1000 _executor: executor.clone(),
1001 supermaven_client,
1002 };
1003
1004 if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
1005 tracing::error!(?error, "failed to send initial client update");
1006 return;
1007 }
1008
1009 let handle_io = handle_io.fuse();
1010 futures::pin_mut!(handle_io);
1011
1012 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
1013 // This prevents deadlocks when e.g., client A performs a request to client B and
1014 // client B performs a request to client A. If both clients stop processing further
1015 // messages until their respective request completes, they won't have a chance to
1016 // respond to the other client's request and cause a deadlock.
1017 //
1018 // This arrangement ensures we will attempt to process earlier messages first, but fall
1019 // back to processing messages arrived later in the spirit of making progress.
1020 let mut foreground_message_handlers = FuturesUnordered::new();
1021 let concurrent_handlers = Arc::new(Semaphore::new(256));
1022 loop {
1023 let next_message = async {
1024 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
1025 let message = incoming_rx.next().await;
1026 (permit, message)
1027 }.fuse();
1028 futures::pin_mut!(next_message);
1029 futures::select_biased! {
1030 _ = teardown.changed().fuse() => return,
1031 result = handle_io => {
1032 if let Err(error) = result {
1033 tracing::error!(?error, "error handling I/O");
1034 }
1035 break;
1036 }
1037 _ = foreground_message_handlers.next() => {}
1038 next_message = next_message => {
1039 let (permit, message) = next_message;
1040 if let Some(message) = message {
1041 let type_name = message.payload_type_name();
1042 // note: we copy all the fields from the parent span so we can query them in the logs.
1043 // (https://github.com/tokio-rs/tracing/issues/2670).
1044 let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
1045 user_id=field::Empty,
1046 login=field::Empty,
1047 impersonator=field::Empty,
1048 dev_server_id=field::Empty
1049 );
1050 principal.update_span(&span);
1051 let span_enter = span.enter();
1052 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
1053 let is_background = message.is_background();
1054 let handle_message = (handler)(message, session.clone());
1055 drop(span_enter);
1056
1057 let handle_message = async move {
1058 handle_message.await;
1059 drop(permit);
1060 }.instrument(span);
1061 if is_background {
1062 executor.spawn_detached(handle_message);
1063 } else {
1064 foreground_message_handlers.push(handle_message);
1065 }
1066 } else {
1067 tracing::error!("no message handler");
1068 }
1069 } else {
1070 tracing::info!("connection closed");
1071 break;
1072 }
1073 }
1074 }
1075 }
1076
1077 drop(foreground_message_handlers);
1078 tracing::info!("signing out");
1079 if let Err(error) = connection_lost(session, teardown, executor).await {
1080 tracing::error!(?error, "error signing out");
1081 }
1082
1083 }.instrument(span)
1084 }
1085
    /// Performs the initial handshake for a freshly registered connection.
    ///
    /// Sends `proto::Hello` carrying the connection's peer id, then reports the
    /// connection id through `send_connection_id` if the caller supplied one.
    /// The remaining steps depend on the principal:
    ///
    /// * Users (including impersonated users): on the user's first-ever
    ///   connection, sends `ShowContacts` and records `connected_once`. Then:
    ///   - registers the connection in the pool and sends the initial contacts
    ///     snapshot;
    ///   - subscribes to channels when `zed_version` supports
    ///     auto-subscription;
    ///   - sends the dev server projects status and any pending incoming call;
    ///   - finally calls `update_user_contacts` for the user.
    /// * Dev servers: asks any connection already registered for the same dev
    ///   server to shut down and evicts it from the pool, then:
    ///   - registers this connection;
    ///   - sends the dev server its projects (`DevServerInstructions`);
    ///   - pushes a projects status update to `dev_server.user_id`.
    ///
    /// # Errors
    ///
    /// Returns the first send or database error encountered; steps after the
    /// failing one are not performed.
    async fn send_initial_client_update(
        &self,
        connection_id: ConnectionId,
        principal: &Principal,
        zed_version: ZedVersion,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        session: &Session,
    ) -> Result<()> {
        self.peer.send(
            connection_id,
            proto::Hello {
                peer_id: Some(connection_id.into()),
            },
        )?;
        tracing::info!("sent hello message");
        // The oneshot send result is deliberately ignored: it only fails when
        // the receiving side has already been dropped.
        if let Some(send_connection_id) = send_connection_id.take() {
            let _ = send_connection_id.send(connection_id);
        }

        match principal {
            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
                // First connection ever for this user: prompt the client to
                // show contacts and remember that it has happened.
                if !user.connected_once {
                    self.peer.send(connection_id, proto::ShowContacts {})?;
                    self.app_state
                        .db
                        .set_user_connected_once(user.id, true)
                        .await?;
                }

                // Both queries run concurrently; the first error aborts.
                let (contacts, dev_server_projects) = future::try_join(
                    self.app_state.db.get_contacts(user.id),
                    self.app_state.db.dev_server_projects_update(user.id),
                )
                .await?;

                // The connection is added and the contacts update built under
                // one pool lock; `build_initial_contacts_update` reads the pool
                // (presumably for online status — verify in that helper).
                {
                    let mut pool = self.connection_pool.lock();
                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
                    self.peer.send(
                        connection_id,
                        build_initial_contacts_update(contacts, &pool),
                    )?;
                }

                if should_auto_subscribe_to_channels(zed_version) {
                    subscribe_user_to_channels(user.id, session).await?;
                }

                send_dev_server_projects_update(user.id, dev_server_projects, session).await;

                // Re-deliver a call that was ringing while the user was offline.
                if let Some(incoming_call) =
                    self.app_state.db.incoming_call_for_user(user.id).await?
                {
                    self.peer.send(connection_id, incoming_call)?;
                }

                update_user_contacts(user.id, &session).await?;
            }
            Principal::DevServer(dev_server) => {
                {
                    let mut pool = self.connection_pool.lock();
                    // Evict any connection already registered for this dev
                    // server before registering the new one.
                    // NOTE(review): if the send to the stale connection fails,
                    // `?` returns before the stale connection is removed and
                    // before this connection is registered — confirm intended.
                    if let Some(stale_connection_id) = pool.dev_server_connection_id(dev_server.id)
                    {
                        self.peer.send(
                            stale_connection_id,
                            proto::ShutdownDevServer {
                                reason: Some(
                                    "another dev server connected with the same token".to_string(),
                                ),
                            },
                        )?;
                        pool.remove_connection(stale_connection_id)?;
                    };
                    pool.add_dev_server(connection_id, dev_server.id, zed_version);
                }

                let projects = self
                    .app_state
                    .db
                    .get_projects_for_dev_server(dev_server.id)
                    .await?;
                self.peer
                    .send(connection_id, proto::DevServerInstructions { projects })?;

                // Let the dev server's user see the refreshed projects status.
                let status = self
                    .app_state
                    .db
                    .dev_server_projects_update(dev_server.user_id)
                    .await?;
                send_dev_server_projects_update(dev_server.user_id, status, &session).await;
            }
        }

        Ok(())
    }
1181
1182 pub async fn invite_code_redeemed(
1183 self: &Arc<Self>,
1184 inviter_id: UserId,
1185 invitee_id: UserId,
1186 ) -> Result<()> {
1187 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1188 if let Some(code) = &user.invite_code {
1189 let pool = self.connection_pool.lock();
1190 let invitee_contact = contact_for_user(invitee_id, false, &pool);
1191 for connection_id in pool.user_connection_ids(inviter_id) {
1192 self.peer.send(
1193 connection_id,
1194 proto::UpdateContacts {
1195 contacts: vec![invitee_contact.clone()],
1196 ..Default::default()
1197 },
1198 )?;
1199 self.peer.send(
1200 connection_id,
1201 proto::UpdateInviteInfo {
1202 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1203 count: user.invite_count as u32,
1204 },
1205 )?;
1206 }
1207 }
1208 }
1209 Ok(())
1210 }
1211
1212 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1213 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1214 if let Some(invite_code) = &user.invite_code {
1215 let pool = self.connection_pool.lock();
1216 for connection_id in pool.user_connection_ids(user_id) {
1217 self.peer.send(
1218 connection_id,
1219 proto::UpdateInviteInfo {
1220 url: format!(
1221 "{}{}",
1222 self.app_state.config.invite_link_prefix, invite_code
1223 ),
1224 count: user.invite_count as u32,
1225 },
1226 )?;
1227 }
1228 }
1229 }
1230 Ok(())
1231 }
1232
1233 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1234 ServerSnapshot {
1235 connection_pool: ConnectionPoolGuard {
1236 guard: self.connection_pool.lock(),
1237 _not_send: PhantomData,
1238 },
1239 peer: &self.peer,
1240 }
1241 }
1242}
1243
1244impl<'a> Deref for ConnectionPoolGuard<'a> {
1245 type Target = ConnectionPool;
1246
1247 fn deref(&self) -> &Self::Target {
1248 &self.guard
1249 }
1250}
1251
1252impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1253 fn deref_mut(&mut self) -> &mut Self::Target {
1254 &mut self.guard
1255 }
1256}
1257
/// In test builds, re-validates the connection pool each time a guard is
/// released, so tests exercise the pool's invariants after every locked
/// section (assumes `check_invariants` panics on violation — not visible here).
impl<'a> Drop for ConnectionPoolGuard<'a> {
    fn drop(&mut self) {
        // Compiled out entirely outside `cfg(test)`; release builds pay nothing.
        #[cfg(test)]
        self.check_invariants();
    }
}
1264
1265fn broadcast<F>(
1266 sender_id: Option<ConnectionId>,
1267 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1268 mut f: F,
1269) where
1270 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1271{
1272 for receiver_id in receiver_ids {
1273 if Some(receiver_id) != sender_id {
1274 if let Err(error) = f(receiver_id) {
1275 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1276 }
1277 }
1278 }
1279}
1280
1281pub struct ProtocolVersion(u32);
1282
1283impl Header for ProtocolVersion {
1284 fn name() -> &'static HeaderName {
1285 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1286 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1287 }
1288
1289 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1290 where
1291 Self: Sized,
1292 I: Iterator<Item = &'i axum::http::HeaderValue>,
1293 {
1294 let version = values
1295 .next()
1296 .ok_or_else(axum::headers::Error::invalid)?
1297 .to_str()
1298 .map_err(|_| axum::headers::Error::invalid())?
1299 .parse()
1300 .map_err(|_| axum::headers::Error::invalid())?;
1301 Ok(Self(version))
1302 }
1303
1304 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1305 values.extend([self.0.to_string().parse().unwrap()]);
1306 }
1307}
1308
1309pub struct AppVersionHeader(SemanticVersion);
1310impl Header for AppVersionHeader {
1311 fn name() -> &'static HeaderName {
1312 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1313 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1314 }
1315
1316 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1317 where
1318 Self: Sized,
1319 I: Iterator<Item = &'i axum::http::HeaderValue>,
1320 {
1321 let version = values
1322 .next()
1323 .ok_or_else(axum::headers::Error::invalid)?
1324 .to_str()
1325 .map_err(|_| axum::headers::Error::invalid())?
1326 .parse()
1327 .map_err(|_| axum::headers::Error::invalid())?;
1328 Ok(Self(version))
1329 }
1330
1331 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1332 values.extend([self.0.to_string().parse().unwrap()]);
1333 }
1334}
1335
/// Builds the router serving the collab RPC endpoints.
///
/// `/rpc` is registered before the auth `ServiceBuilder` layer, so only that
/// route runs `auth::validate_header` (with `AppState` available as an
/// extension). `/metrics` is added after that layer and is therefore not
/// behind it. The final `Extension(server)` layer applies to both routes.
pub fn routes(server: Arc<Server>) -> Router<(), Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        // Order matters: `Router::layer` wraps only the routes added before it.
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
1347
1348pub async fn handle_websocket_request(
1349 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1350 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1351 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1352 Extension(server): Extension<Arc<Server>>,
1353 Extension(principal): Extension<Principal>,
1354 ws: WebSocketUpgrade,
1355) -> axum::response::Response {
1356 if protocol_version != rpc::PROTOCOL_VERSION {
1357 return (
1358 StatusCode::UPGRADE_REQUIRED,
1359 "client must be upgraded".to_string(),
1360 )
1361 .into_response();
1362 }
1363
1364 let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1365 return (
1366 StatusCode::UPGRADE_REQUIRED,
1367 "no version header found".to_string(),
1368 )
1369 .into_response();
1370 };
1371
1372 if !version.can_collaborate() {
1373 return (
1374 StatusCode::UPGRADE_REQUIRED,
1375 "client must be upgraded".to_string(),
1376 )
1377 .into_response();
1378 }
1379
1380 let socket_address = socket_address.to_string();
1381 ws.on_upgrade(move |socket| {
1382 let socket = socket
1383 .map_ok(to_tungstenite_message)
1384 .err_into()
1385 .with(|message| async move { Ok(to_axum_message(message)) });
1386 let connection = Connection::new(Box::pin(socket));
1387 async move {
1388 server
1389 .handle_connection(
1390 connection,
1391 socket_address,
1392 principal,
1393 version,
1394 None,
1395 Executor::Production,
1396 )
1397 .await;
1398 }
1399 })
1400}
1401
1402pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1403 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1404 let connections_metric = CONNECTIONS_METRIC
1405 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1406
1407 let connections = server
1408 .connection_pool
1409 .lock()
1410 .connections()
1411 .filter(|connection| !connection.admin)
1412 .count();
1413 connections_metric.set(connections as _);
1414
1415 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1416 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1417 register_int_gauge!(
1418 "shared_projects",
1419 "number of open projects with one or more guests"
1420 )
1421 .unwrap()
1422 });
1423
1424 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1425 shared_projects_metric.set(shared_projects as _);
1426
1427 let encoder = prometheus::TextEncoder::new();
1428 let metric_families = prometheus::gather();
1429 let encoded_metrics = encoder
1430 .encode_to_string(&metric_families)
1431 .map_err(|err| anyhow!("{}", err))?;
1432 Ok(encoded_metrics)
1433}
1434
1435#[instrument(err, skip(executor))]
1436async fn connection_lost(
1437 session: Session,
1438 mut teardown: watch::Receiver<bool>,
1439 executor: Executor,
1440) -> Result<()> {
1441 session.peer.disconnect(session.connection_id);
1442 session
1443 .connection_pool()
1444 .await
1445 .remove_connection(session.connection_id)?;
1446
1447 session
1448 .db()
1449 .await
1450 .connection_lost(session.connection_id)
1451 .await
1452 .trace_err();
1453
1454 futures::select_biased! {
1455 _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1456 match &session.principal {
1457 Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
1458 let session = session.for_user().unwrap();
1459
1460 log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1461 leave_room_for_session(&session, session.connection_id).await.trace_err();
1462 leave_channel_buffers_for_session(&session)
1463 .await
1464 .trace_err();
1465
1466 if !session
1467 .connection_pool()
1468 .await
1469 .is_user_online(session.user_id())
1470 {
1471 let db = session.db().await;
1472 if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1473 room_updated(&room, &session.peer);
1474 }
1475 }
1476
1477 update_user_contacts(session.user_id(), &session).await?;
1478 },
1479 Principal::DevServer(_) => {
1480 lost_dev_server_connection(&session.for_dev_server().unwrap()).await?;
1481 },
1482 }
1483 },
1484 _ = teardown.changed().fuse() => {}
1485 }
1486
1487 Ok(())
1488}
1489
1490/// Acknowledges a ping from a client, used to keep the connection alive.
1491async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1492 response.send(proto::Ack {})?;
1493 Ok(())
1494}
1495
1496/// Creates a new room for calling (outside of channels)
1497async fn create_room(
1498 _request: proto::CreateRoom,
1499 response: Response<proto::CreateRoom>,
1500 session: UserSession,
1501) -> Result<()> {
1502 let live_kit_room = nanoid::nanoid!(30);
1503
1504 let live_kit_connection_info = util::maybe!(async {
1505 let live_kit = session.live_kit_client.as_ref();
1506 let live_kit = live_kit?;
1507 let user_id = session.user_id().to_string();
1508
1509 let token = live_kit
1510 .room_token(&live_kit_room, &user_id.to_string())
1511 .trace_err()?;
1512
1513 Some(proto::LiveKitConnectionInfo {
1514 server_url: live_kit.url().into(),
1515 token,
1516 can_publish: true,
1517 })
1518 })
1519 .await;
1520
1521 let room = session
1522 .db()
1523 .await
1524 .create_room(session.user_id(), session.connection_id, &live_kit_room)
1525 .await?;
1526
1527 response.send(proto::CreateRoomResponse {
1528 room: Some(room.clone()),
1529 live_kit_connection_info,
1530 })?;
1531
1532 update_user_contacts(session.user_id(), &session).await?;
1533 Ok(())
1534}
1535
1536/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1537async fn join_room(
1538 request: proto::JoinRoom,
1539 response: Response<proto::JoinRoom>,
1540 session: UserSession,
1541) -> Result<()> {
1542 let room_id = RoomId::from_proto(request.id);
1543
1544 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1545
1546 if let Some(channel_id) = channel_id {
1547 return join_channel_internal(channel_id, Box::new(response), session).await;
1548 }
1549
1550 let joined_room = {
1551 let room = session
1552 .db()
1553 .await
1554 .join_room(room_id, session.user_id(), session.connection_id)
1555 .await?;
1556 room_updated(&room.room, &session.peer);
1557 room.into_inner()
1558 };
1559
1560 for connection_id in session
1561 .connection_pool()
1562 .await
1563 .user_connection_ids(session.user_id())
1564 {
1565 session
1566 .peer
1567 .send(
1568 connection_id,
1569 proto::CallCanceled {
1570 room_id: room_id.to_proto(),
1571 },
1572 )
1573 .trace_err();
1574 }
1575
1576 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1577 if let Some(token) = live_kit
1578 .room_token(
1579 &joined_room.room.live_kit_room,
1580 &session.user_id().to_string(),
1581 )
1582 .trace_err()
1583 {
1584 Some(proto::LiveKitConnectionInfo {
1585 server_url: live_kit.url().into(),
1586 token,
1587 can_publish: true,
1588 })
1589 } else {
1590 None
1591 }
1592 } else {
1593 None
1594 };
1595
1596 response.send(proto::JoinRoomResponse {
1597 room: Some(joined_room.room),
1598 channel_id: None,
1599 live_kit_connection_info,
1600 })?;
1601
1602 update_user_contacts(session.user_id(), &session).await?;
1603 Ok(())
1604}
1605
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Re-associates the user's new connection with the room in the database and
/// replies with three things:
///
/// * the room;
/// * the "reshared" projects, with their collaborators;
/// * the "rejoined" projects (presumably the ones the user hosts and the ones
///   they were a guest in, respectively — confirm in the db layer).
///
/// It then:
///
/// 1. Broadcasts the room update.
/// 2. Tells each reshared project's collaborators the user's new peer id and
///    forwards them the project's worktrees.
/// 3. Streams the rejoined projects' state to this client via
///    `notify_rejoined_projects`.
/// 4. Refreshes the room's channel (if any) and the user's contacts.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: UserSession,
) -> Result<()> {
    // Declared outside the block below so they outlive the database guard.
    let room;
    let channel;
    {
        // `rejoined_room` is confined to this block so it is dropped before
        // the connection pool is locked and contacts are refreshed below.
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell each collaborator that this user's peer id changed (best-effort).
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the current worktree list to every collaborator except
            // the rejoining connection itself.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        // Move the room and channel out of the guard so they survive the block.
        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1697
1698fn notify_rejoined_projects(
1699 rejoined_projects: &mut Vec<RejoinedProject>,
1700 session: &UserSession,
1701) -> Result<()> {
1702 for project in rejoined_projects.iter() {
1703 for collaborator in &project.collaborators {
1704 session
1705 .peer
1706 .send(
1707 collaborator.connection_id,
1708 proto::UpdateProjectCollaborator {
1709 project_id: project.id.to_proto(),
1710 old_peer_id: Some(project.old_connection_id.into()),
1711 new_peer_id: Some(session.connection_id.into()),
1712 },
1713 )
1714 .trace_err();
1715 }
1716 }
1717
1718 for project in rejoined_projects {
1719 for worktree in mem::take(&mut project.worktrees) {
1720 #[cfg(any(test, feature = "test-support"))]
1721 const MAX_CHUNK_SIZE: usize = 2;
1722 #[cfg(not(any(test, feature = "test-support")))]
1723 const MAX_CHUNK_SIZE: usize = 256;
1724
1725 // Stream this worktree's entries.
1726 let message = proto::UpdateWorktree {
1727 project_id: project.id.to_proto(),
1728 worktree_id: worktree.id,
1729 abs_path: worktree.abs_path.clone(),
1730 root_name: worktree.root_name,
1731 updated_entries: worktree.updated_entries,
1732 removed_entries: worktree.removed_entries,
1733 scan_id: worktree.scan_id,
1734 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1735 updated_repositories: worktree.updated_repositories,
1736 removed_repositories: worktree.removed_repositories,
1737 };
1738 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1739 session.peer.send(session.connection_id, update.clone())?;
1740 }
1741
1742 // Stream this worktree's diagnostics.
1743 for summary in worktree.diagnostic_summaries {
1744 session.peer.send(
1745 session.connection_id,
1746 proto::UpdateDiagnosticSummary {
1747 project_id: project.id.to_proto(),
1748 worktree_id: worktree.id,
1749 summary: Some(summary),
1750 },
1751 )?;
1752 }
1753
1754 for settings_file in worktree.settings_files {
1755 session.peer.send(
1756 session.connection_id,
1757 proto::UpdateWorktreeSettings {
1758 project_id: project.id.to_proto(),
1759 worktree_id: worktree.id,
1760 path: settings_file.path,
1761 content: Some(settings_file.content),
1762 },
1763 )?;
1764 }
1765 }
1766
1767 for language_server in &project.language_servers {
1768 session.peer.send(
1769 session.connection_id,
1770 proto::UpdateLanguageServer {
1771 project_id: project.id.to_proto(),
1772 language_server_id: language_server.id,
1773 variant: Some(
1774 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1775 proto::LspDiskBasedDiagnosticsUpdated {},
1776 ),
1777 ),
1778 },
1779 )?;
1780 }
1781 }
1782 Ok(())
1783}
1784
1785/// leave room disconnects from the room.
1786async fn leave_room(
1787 _: proto::LeaveRoom,
1788 response: Response<proto::LeaveRoom>,
1789 session: UserSession,
1790) -> Result<()> {
1791 leave_room_for_session(&session, session.connection_id).await?;
1792 response.send(proto::Ack {})?;
1793 Ok(())
1794}
1795
1796/// Updates the permissions of someone else in the room.
1797async fn set_room_participant_role(
1798 request: proto::SetRoomParticipantRole,
1799 response: Response<proto::SetRoomParticipantRole>,
1800 session: UserSession,
1801) -> Result<()> {
1802 let user_id = UserId::from_proto(request.user_id);
1803 let role = ChannelRole::from(request.role());
1804
1805 let (live_kit_room, can_publish) = {
1806 let room = session
1807 .db()
1808 .await
1809 .set_room_participant_role(
1810 session.user_id(),
1811 RoomId::from_proto(request.room_id),
1812 user_id,
1813 role,
1814 )
1815 .await?;
1816
1817 let live_kit_room = room.live_kit_room.clone();
1818 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1819 room_updated(&room, &session.peer);
1820 (live_kit_room, can_publish)
1821 };
1822
1823 if let Some(live_kit) = session.live_kit_client.as_ref() {
1824 live_kit
1825 .update_participant(
1826 live_kit_room.clone(),
1827 request.user_id.to_string(),
1828 live_kit_server::proto::ParticipantPermission {
1829 can_subscribe: true,
1830 can_publish,
1831 can_publish_data: can_publish,
1832 hidden: false,
1833 recorder: false,
1834 },
1835 )
1836 .await
1837 .trace_err();
1838 }
1839
1840 response.send(proto::Ack {})?;
1841 Ok(())
1842}
1843
1844/// Call someone else into the current room
1845async fn call(
1846 request: proto::Call,
1847 response: Response<proto::Call>,
1848 session: UserSession,
1849) -> Result<()> {
1850 let room_id = RoomId::from_proto(request.room_id);
1851 let calling_user_id = session.user_id();
1852 let calling_connection_id = session.connection_id;
1853 let called_user_id = UserId::from_proto(request.called_user_id);
1854 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1855 if !session
1856 .db()
1857 .await
1858 .has_contact(calling_user_id, called_user_id)
1859 .await?
1860 {
1861 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1862 }
1863
1864 let incoming_call = {
1865 let (room, incoming_call) = &mut *session
1866 .db()
1867 .await
1868 .call(
1869 room_id,
1870 calling_user_id,
1871 calling_connection_id,
1872 called_user_id,
1873 initial_project_id,
1874 )
1875 .await?;
1876 room_updated(&room, &session.peer);
1877 mem::take(incoming_call)
1878 };
1879 update_user_contacts(called_user_id, &session).await?;
1880
1881 let mut calls = session
1882 .connection_pool()
1883 .await
1884 .user_connection_ids(called_user_id)
1885 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1886 .collect::<FuturesUnordered<_>>();
1887
1888 while let Some(call_response) = calls.next().await {
1889 match call_response.as_ref() {
1890 Ok(_) => {
1891 response.send(proto::Ack {})?;
1892 return Ok(());
1893 }
1894 Err(_) => {
1895 call_response.trace_err();
1896 }
1897 }
1898 }
1899
1900 {
1901 let room = session
1902 .db()
1903 .await
1904 .call_failed(room_id, called_user_id)
1905 .await?;
1906 room_updated(&room, &session.peer);
1907 }
1908 update_user_contacts(called_user_id, &session).await?;
1909
1910 Err(anyhow!("failed to ring user"))?
1911}
1912
1913/// Cancel an outgoing call.
1914async fn cancel_call(
1915 request: proto::CancelCall,
1916 response: Response<proto::CancelCall>,
1917 session: UserSession,
1918) -> Result<()> {
1919 let called_user_id = UserId::from_proto(request.called_user_id);
1920 let room_id = RoomId::from_proto(request.room_id);
1921 {
1922 let room = session
1923 .db()
1924 .await
1925 .cancel_call(room_id, session.connection_id, called_user_id)
1926 .await?;
1927 room_updated(&room, &session.peer);
1928 }
1929
1930 for connection_id in session
1931 .connection_pool()
1932 .await
1933 .user_connection_ids(called_user_id)
1934 {
1935 session
1936 .peer
1937 .send(
1938 connection_id,
1939 proto::CallCanceled {
1940 room_id: room_id.to_proto(),
1941 },
1942 )
1943 .trace_err();
1944 }
1945 response.send(proto::Ack {})?;
1946
1947 update_user_contacts(called_user_id, &session).await?;
1948 Ok(())
1949}
1950
1951/// Decline an incoming call.
1952async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1953 let room_id = RoomId::from_proto(message.room_id);
1954 {
1955 let room = session
1956 .db()
1957 .await
1958 .decline_call(Some(room_id), session.user_id())
1959 .await?
1960 .ok_or_else(|| anyhow!("failed to decline call"))?;
1961 room_updated(&room, &session.peer);
1962 }
1963
1964 for connection_id in session
1965 .connection_pool()
1966 .await
1967 .user_connection_ids(session.user_id())
1968 {
1969 session
1970 .peer
1971 .send(
1972 connection_id,
1973 proto::CallCanceled {
1974 room_id: room_id.to_proto(),
1975 },
1976 )
1977 .trace_err();
1978 }
1979 update_user_contacts(session.user_id(), &session).await?;
1980 Ok(())
1981}
1982
1983/// Updates other participants in the room with your current location.
1984async fn update_participant_location(
1985 request: proto::UpdateParticipantLocation,
1986 response: Response<proto::UpdateParticipantLocation>,
1987 session: UserSession,
1988) -> Result<()> {
1989 let room_id = RoomId::from_proto(request.room_id);
1990 let location = request
1991 .location
1992 .ok_or_else(|| anyhow!("invalid location"))?;
1993
1994 let db = session.db().await;
1995 let room = db
1996 .update_room_participant_location(room_id, session.connection_id, location)
1997 .await?;
1998
1999 room_updated(&room, &session.peer);
2000 response.send(proto::Ack {})?;
2001 Ok(())
2002}
2003
2004/// Share a project into the room.
2005async fn share_project(
2006 request: proto::ShareProject,
2007 response: Response<proto::ShareProject>,
2008 session: UserSession,
2009) -> Result<()> {
2010 let (project_id, room) = &*session
2011 .db()
2012 .await
2013 .share_project(
2014 RoomId::from_proto(request.room_id),
2015 session.connection_id,
2016 &request.worktrees,
2017 request
2018 .dev_server_project_id
2019 .map(|id| DevServerProjectId::from_proto(id)),
2020 )
2021 .await?;
2022 response.send(proto::ShareProjectResponse {
2023 project_id: project_id.to_proto(),
2024 })?;
2025 room_updated(&room, &session.peer);
2026
2027 Ok(())
2028}
2029
2030/// Unshare a project from the room.
2031async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
2032 let project_id = ProjectId::from_proto(message.project_id);
2033 unshare_project_internal(
2034 project_id,
2035 session.connection_id,
2036 session.user_id(),
2037 &session,
2038 )
2039 .await
2040}
2041
2042async fn unshare_project_internal(
2043 project_id: ProjectId,
2044 connection_id: ConnectionId,
2045 user_id: Option<UserId>,
2046 session: &Session,
2047) -> Result<()> {
2048 let delete = {
2049 let room_guard = session
2050 .db()
2051 .await
2052 .unshare_project(project_id, connection_id, user_id)
2053 .await?;
2054
2055 let (delete, room, guest_connection_ids) = &*room_guard;
2056
2057 let message = proto::UnshareProject {
2058 project_id: project_id.to_proto(),
2059 };
2060
2061 broadcast(
2062 Some(connection_id),
2063 guest_connection_ids.iter().copied(),
2064 |conn_id| session.peer.send(conn_id, message.clone()),
2065 );
2066 if let Some(room) = room {
2067 room_updated(room, &session.peer);
2068 }
2069
2070 *delete
2071 };
2072
2073 if delete {
2074 let db = session.db().await;
2075 db.delete_project(project_id).await?;
2076 }
2077
2078 Ok(())
2079}
2080
2081/// DevServer makes a project available online
2082async fn share_dev_server_project(
2083 request: proto::ShareDevServerProject,
2084 response: Response<proto::ShareDevServerProject>,
2085 session: DevServerSession,
2086) -> Result<()> {
2087 let (dev_server_project, user_id, status) = session
2088 .db()
2089 .await
2090 .share_dev_server_project(
2091 DevServerProjectId::from_proto(request.dev_server_project_id),
2092 session.dev_server_id(),
2093 session.connection_id,
2094 &request.worktrees,
2095 )
2096 .await?;
2097 let Some(project_id) = dev_server_project.project_id else {
2098 return Err(anyhow!("failed to share remote project"))?;
2099 };
2100
2101 send_dev_server_projects_update(user_id, status, &session).await;
2102
2103 response.send(proto::ShareProjectResponse { project_id })?;
2104
2105 Ok(())
2106}
2107
2108/// Join someone elses shared project.
2109async fn join_project(
2110 request: proto::JoinProject,
2111 response: Response<proto::JoinProject>,
2112 session: UserSession,
2113) -> Result<()> {
2114 let project_id = ProjectId::from_proto(request.project_id);
2115
2116 tracing::info!(%project_id, "join project");
2117
2118 let db = session.db().await;
2119 let (project, replica_id) = &mut *db
2120 .join_project(project_id, session.connection_id, session.user_id())
2121 .await?;
2122 drop(db);
2123 tracing::info!(%project_id, "join remote project");
2124 join_project_internal(response, session, project, replica_id)
2125}
2126
/// Abstracts over the response handles of RPC requests that are answered with
/// a `proto::JoinProjectResponse`. This lets `join_project_internal` serve
/// both `JoinProject` and `JoinHostedProject`.
trait JoinProjectInternalResponse {
    /// Sends the join result to the requesting client, consuming the handle.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
// Each impl forwards through a fully qualified path, which makes explicit that
// it calls `Response::send` rather than this trait's `send`.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinProject>::send(self, result)
    }
}
impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinHostedProject>::send(self, result)
    }
}
2140
2141fn join_project_internal(
2142 response: impl JoinProjectInternalResponse,
2143 session: UserSession,
2144 project: &mut Project,
2145 replica_id: &ReplicaId,
2146) -> Result<()> {
2147 let collaborators = project
2148 .collaborators
2149 .iter()
2150 .filter(|collaborator| collaborator.connection_id != session.connection_id)
2151 .map(|collaborator| collaborator.to_proto())
2152 .collect::<Vec<_>>();
2153 let project_id = project.id;
2154 let guest_user_id = session.user_id();
2155
2156 let worktrees = project
2157 .worktrees
2158 .iter()
2159 .map(|(id, worktree)| proto::WorktreeMetadata {
2160 id: *id,
2161 root_name: worktree.root_name.clone(),
2162 visible: worktree.visible,
2163 abs_path: worktree.abs_path.clone(),
2164 })
2165 .collect::<Vec<_>>();
2166
2167 let add_project_collaborator = proto::AddProjectCollaborator {
2168 project_id: project_id.to_proto(),
2169 collaborator: Some(proto::Collaborator {
2170 peer_id: Some(session.connection_id.into()),
2171 replica_id: replica_id.0 as u32,
2172 user_id: guest_user_id.to_proto(),
2173 }),
2174 };
2175
2176 for collaborator in &collaborators {
2177 session
2178 .peer
2179 .send(
2180 collaborator.peer_id.unwrap().into(),
2181 add_project_collaborator.clone(),
2182 )
2183 .trace_err();
2184 }
2185
2186 // First, we send the metadata associated with each worktree.
2187 response.send(proto::JoinProjectResponse {
2188 project_id: project.id.0 as u64,
2189 worktrees: worktrees.clone(),
2190 replica_id: replica_id.0 as u32,
2191 collaborators: collaborators.clone(),
2192 language_servers: project.language_servers.clone(),
2193 role: project.role.into(),
2194 dev_server_project_id: project
2195 .dev_server_project_id
2196 .map(|dev_server_project_id| dev_server_project_id.0 as u64),
2197 })?;
2198
2199 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2200 #[cfg(any(test, feature = "test-support"))]
2201 const MAX_CHUNK_SIZE: usize = 2;
2202 #[cfg(not(any(test, feature = "test-support")))]
2203 const MAX_CHUNK_SIZE: usize = 256;
2204
2205 // Stream this worktree's entries.
2206 let message = proto::UpdateWorktree {
2207 project_id: project_id.to_proto(),
2208 worktree_id,
2209 abs_path: worktree.abs_path.clone(),
2210 root_name: worktree.root_name,
2211 updated_entries: worktree.entries,
2212 removed_entries: Default::default(),
2213 scan_id: worktree.scan_id,
2214 is_last_update: worktree.scan_id == worktree.completed_scan_id,
2215 updated_repositories: worktree.repository_entries.into_values().collect(),
2216 removed_repositories: Default::default(),
2217 };
2218 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
2219 session.peer.send(session.connection_id, update.clone())?;
2220 }
2221
2222 // Stream this worktree's diagnostics.
2223 for summary in worktree.diagnostic_summaries {
2224 session.peer.send(
2225 session.connection_id,
2226 proto::UpdateDiagnosticSummary {
2227 project_id: project_id.to_proto(),
2228 worktree_id: worktree.id,
2229 summary: Some(summary),
2230 },
2231 )?;
2232 }
2233
2234 for settings_file in worktree.settings_files {
2235 session.peer.send(
2236 session.connection_id,
2237 proto::UpdateWorktreeSettings {
2238 project_id: project_id.to_proto(),
2239 worktree_id: worktree.id,
2240 path: settings_file.path,
2241 content: Some(settings_file.content),
2242 },
2243 )?;
2244 }
2245 }
2246
2247 for language_server in &project.language_servers {
2248 session.peer.send(
2249 session.connection_id,
2250 proto::UpdateLanguageServer {
2251 project_id: project_id.to_proto(),
2252 language_server_id: language_server.id,
2253 variant: Some(
2254 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2255 proto::LspDiskBasedDiagnosticsUpdated {},
2256 ),
2257 ),
2258 },
2259 )?;
2260 }
2261
2262 Ok(())
2263}
2264
2265/// Leave someone elses shared project.
2266async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2267 let sender_id = session.connection_id;
2268 let project_id = ProjectId::from_proto(request.project_id);
2269 let db = session.db().await;
2270 if db.is_hosted_project(project_id).await? {
2271 let project = db.leave_hosted_project(project_id, sender_id).await?;
2272 project_left(&project, &session);
2273 return Ok(());
2274 }
2275
2276 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2277 tracing::info!(
2278 %project_id,
2279 "leave project"
2280 );
2281
2282 project_left(&project, &session);
2283 if let Some(room) = room {
2284 room_updated(&room, &session.peer);
2285 }
2286
2287 Ok(())
2288}
2289
2290async fn join_hosted_project(
2291 request: proto::JoinHostedProject,
2292 response: Response<proto::JoinHostedProject>,
2293 session: UserSession,
2294) -> Result<()> {
2295 let (mut project, replica_id) = session
2296 .db()
2297 .await
2298 .join_hosted_project(
2299 ProjectId(request.project_id as i32),
2300 session.user_id(),
2301 session.connection_id,
2302 )
2303 .await?;
2304
2305 join_project_internal(response, session, &mut project, &replica_id)
2306}
2307
2308async fn create_dev_server_project(
2309 request: proto::CreateDevServerProject,
2310 response: Response<proto::CreateDevServerProject>,
2311 session: UserSession,
2312) -> Result<()> {
2313 let dev_server_id = DevServerId(request.dev_server_id as i32);
2314 let dev_server_connection_id = session
2315 .connection_pool()
2316 .await
2317 .dev_server_connection_id(dev_server_id);
2318 let Some(dev_server_connection_id) = dev_server_connection_id else {
2319 Err(ErrorCode::DevServerOffline
2320 .message("Cannot create a remote project when the dev server is offline".to_string())
2321 .anyhow())?
2322 };
2323
2324 let path = request.path.clone();
2325 //Check that the path exists on the dev server
2326 session
2327 .peer
2328 .forward_request(
2329 session.connection_id,
2330 dev_server_connection_id,
2331 proto::ValidateDevServerProjectRequest { path: path.clone() },
2332 )
2333 .await?;
2334
2335 let (dev_server_project, update) = session
2336 .db()
2337 .await
2338 .create_dev_server_project(
2339 DevServerId(request.dev_server_id as i32),
2340 &request.path,
2341 session.user_id(),
2342 )
2343 .await?;
2344
2345 let projects = session
2346 .db()
2347 .await
2348 .get_projects_for_dev_server(dev_server_project.dev_server_id)
2349 .await?;
2350
2351 session.peer.send(
2352 dev_server_connection_id,
2353 proto::DevServerInstructions { projects },
2354 )?;
2355
2356 send_dev_server_projects_update(session.user_id(), update, &session).await;
2357
2358 response.send(proto::CreateDevServerProjectResponse {
2359 dev_server_project: Some(dev_server_project.to_proto(None)),
2360 })?;
2361 Ok(())
2362}
2363
2364async fn create_dev_server(
2365 request: proto::CreateDevServer,
2366 response: Response<proto::CreateDevServer>,
2367 session: UserSession,
2368) -> Result<()> {
2369 let access_token = auth::random_token();
2370 let hashed_access_token = auth::hash_access_token(&access_token);
2371
2372 if request.name.is_empty() {
2373 return Err(proto::ErrorCode::Forbidden
2374 .message("Dev server name cannot be empty".to_string())
2375 .anyhow())?;
2376 }
2377
2378 let (dev_server, status) = session
2379 .db()
2380 .await
2381 .create_dev_server(
2382 &request.name,
2383 request.ssh_connection_string.as_deref(),
2384 &hashed_access_token,
2385 session.user_id(),
2386 )
2387 .await?;
2388
2389 send_dev_server_projects_update(session.user_id(), status, &session).await;
2390
2391 response.send(proto::CreateDevServerResponse {
2392 dev_server_id: dev_server.id.0 as u64,
2393 access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2394 name: request.name,
2395 })?;
2396 Ok(())
2397}
2398
2399async fn regenerate_dev_server_token(
2400 request: proto::RegenerateDevServerToken,
2401 response: Response<proto::RegenerateDevServerToken>,
2402 session: UserSession,
2403) -> Result<()> {
2404 let dev_server_id = DevServerId(request.dev_server_id as i32);
2405 let access_token = auth::random_token();
2406 let hashed_access_token = auth::hash_access_token(&access_token);
2407
2408 let connection_id = session
2409 .connection_pool()
2410 .await
2411 .dev_server_connection_id(dev_server_id);
2412 if let Some(connection_id) = connection_id {
2413 shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2414 session.peer.send(
2415 connection_id,
2416 proto::ShutdownDevServer {
2417 reason: Some("dev server token was regenerated".to_string()),
2418 },
2419 )?;
2420 let _ = remove_dev_server_connection(dev_server_id, &session).await;
2421 }
2422
2423 let status = session
2424 .db()
2425 .await
2426 .update_dev_server_token(dev_server_id, &hashed_access_token, session.user_id())
2427 .await?;
2428
2429 send_dev_server_projects_update(session.user_id(), status, &session).await;
2430
2431 response.send(proto::RegenerateDevServerTokenResponse {
2432 dev_server_id: dev_server_id.to_proto(),
2433 access_token: auth::generate_dev_server_token(dev_server_id.0 as usize, access_token),
2434 })?;
2435 Ok(())
2436}
2437
2438async fn rename_dev_server(
2439 request: proto::RenameDevServer,
2440 response: Response<proto::RenameDevServer>,
2441 session: UserSession,
2442) -> Result<()> {
2443 if request.name.trim().is_empty() {
2444 return Err(proto::ErrorCode::Forbidden
2445 .message("Dev server name cannot be empty".to_string())
2446 .anyhow())?;
2447 }
2448
2449 let dev_server_id = DevServerId(request.dev_server_id as i32);
2450 let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2451 if dev_server.user_id != session.user_id() {
2452 return Err(anyhow!(ErrorCode::Forbidden))?;
2453 }
2454
2455 let status = session
2456 .db()
2457 .await
2458 .rename_dev_server(
2459 dev_server_id,
2460 &request.name,
2461 request.ssh_connection_string.as_deref(),
2462 session.user_id(),
2463 )
2464 .await?;
2465
2466 send_dev_server_projects_update(session.user_id(), status, &session).await;
2467
2468 response.send(proto::Ack {})?;
2469 Ok(())
2470}
2471
2472async fn delete_dev_server(
2473 request: proto::DeleteDevServer,
2474 response: Response<proto::DeleteDevServer>,
2475 session: UserSession,
2476) -> Result<()> {
2477 let dev_server_id = DevServerId(request.dev_server_id as i32);
2478 let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2479 if dev_server.user_id != session.user_id() {
2480 return Err(anyhow!(ErrorCode::Forbidden))?;
2481 }
2482
2483 let connection_id = session
2484 .connection_pool()
2485 .await
2486 .dev_server_connection_id(dev_server_id);
2487 if let Some(connection_id) = connection_id {
2488 shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2489 session.peer.send(
2490 connection_id,
2491 proto::ShutdownDevServer {
2492 reason: Some("dev server was deleted".to_string()),
2493 },
2494 )?;
2495 let _ = remove_dev_server_connection(dev_server_id, &session).await;
2496 }
2497
2498 let status = session
2499 .db()
2500 .await
2501 .delete_dev_server(dev_server_id, session.user_id())
2502 .await?;
2503
2504 send_dev_server_projects_update(session.user_id(), status, &session).await;
2505
2506 response.send(proto::Ack {})?;
2507 Ok(())
2508}
2509
2510async fn delete_dev_server_project(
2511 request: proto::DeleteDevServerProject,
2512 response: Response<proto::DeleteDevServerProject>,
2513 session: UserSession,
2514) -> Result<()> {
2515 let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2516 let dev_server_project = session
2517 .db()
2518 .await
2519 .get_dev_server_project(dev_server_project_id)
2520 .await?;
2521
2522 let dev_server = session
2523 .db()
2524 .await
2525 .get_dev_server(dev_server_project.dev_server_id)
2526 .await?;
2527 if dev_server.user_id != session.user_id() {
2528 return Err(anyhow!(ErrorCode::Forbidden))?;
2529 }
2530
2531 let dev_server_connection_id = session
2532 .connection_pool()
2533 .await
2534 .dev_server_connection_id(dev_server.id);
2535
2536 if let Some(dev_server_connection_id) = dev_server_connection_id {
2537 let project = session
2538 .db()
2539 .await
2540 .find_dev_server_project(dev_server_project_id)
2541 .await;
2542 if let Ok(project) = project {
2543 unshare_project_internal(
2544 project.id,
2545 dev_server_connection_id,
2546 Some(session.user_id()),
2547 &session,
2548 )
2549 .await?;
2550 }
2551 }
2552
2553 let (projects, status) = session
2554 .db()
2555 .await
2556 .delete_dev_server_project(dev_server_project_id, dev_server.id, session.user_id())
2557 .await?;
2558
2559 if let Some(dev_server_connection_id) = dev_server_connection_id {
2560 session.peer.send(
2561 dev_server_connection_id,
2562 proto::DevServerInstructions { projects },
2563 )?;
2564 }
2565
2566 send_dev_server_projects_update(session.user_id(), status, &session).await;
2567
2568 response.send(proto::Ack {})?;
2569 Ok(())
2570}
2571
2572async fn rejoin_dev_server_projects(
2573 request: proto::RejoinRemoteProjects,
2574 response: Response<proto::RejoinRemoteProjects>,
2575 session: UserSession,
2576) -> Result<()> {
2577 let mut rejoined_projects = {
2578 let db = session.db().await;
2579 db.rejoin_dev_server_projects(
2580 &request.rejoined_projects,
2581 session.user_id(),
2582 session.0.connection_id,
2583 )
2584 .await?
2585 };
2586 response.send(proto::RejoinRemoteProjectsResponse {
2587 rejoined_projects: rejoined_projects
2588 .iter()
2589 .map(|project| project.to_proto())
2590 .collect(),
2591 })?;
2592 notify_rejoined_projects(&mut rejoined_projects, &session)
2593}
2594
2595async fn reconnect_dev_server(
2596 request: proto::ReconnectDevServer,
2597 response: Response<proto::ReconnectDevServer>,
2598 session: DevServerSession,
2599) -> Result<()> {
2600 let reshared_projects = {
2601 let db = session.db().await;
2602 db.reshare_dev_server_projects(
2603 &request.reshared_projects,
2604 session.dev_server_id(),
2605 session.0.connection_id,
2606 )
2607 .await?
2608 };
2609
2610 for project in &reshared_projects {
2611 for collaborator in &project.collaborators {
2612 session
2613 .peer
2614 .send(
2615 collaborator.connection_id,
2616 proto::UpdateProjectCollaborator {
2617 project_id: project.id.to_proto(),
2618 old_peer_id: Some(project.old_connection_id.into()),
2619 new_peer_id: Some(session.connection_id.into()),
2620 },
2621 )
2622 .trace_err();
2623 }
2624
2625 broadcast(
2626 Some(session.connection_id),
2627 project
2628 .collaborators
2629 .iter()
2630 .map(|collaborator| collaborator.connection_id),
2631 |connection_id| {
2632 session.peer.forward_send(
2633 session.connection_id,
2634 connection_id,
2635 proto::UpdateProject {
2636 project_id: project.id.to_proto(),
2637 worktrees: project.worktrees.clone(),
2638 },
2639 )
2640 },
2641 );
2642 }
2643
2644 response.send(proto::ReconnectDevServerResponse {
2645 reshared_projects: reshared_projects
2646 .iter()
2647 .map(|project| proto::ResharedProject {
2648 id: project.id.to_proto(),
2649 collaborators: project
2650 .collaborators
2651 .iter()
2652 .map(|collaborator| collaborator.to_proto())
2653 .collect(),
2654 })
2655 .collect(),
2656 })?;
2657
2658 Ok(())
2659}
2660
2661async fn shutdown_dev_server(
2662 _: proto::ShutdownDevServer,
2663 response: Response<proto::ShutdownDevServer>,
2664 session: DevServerSession,
2665) -> Result<()> {
2666 response.send(proto::Ack {})?;
2667 shutdown_dev_server_internal(session.dev_server_id(), session.connection_id, &session).await?;
2668 remove_dev_server_connection(session.dev_server_id(), &session).await
2669}
2670
2671async fn shutdown_dev_server_internal(
2672 dev_server_id: DevServerId,
2673 connection_id: ConnectionId,
2674 session: &Session,
2675) -> Result<()> {
2676 let (dev_server_projects, dev_server) = {
2677 let db = session.db().await;
2678 let dev_server_projects = db.get_projects_for_dev_server(dev_server_id).await?;
2679 let dev_server = db.get_dev_server(dev_server_id).await?;
2680 (dev_server_projects, dev_server)
2681 };
2682
2683 for project_id in dev_server_projects.iter().filter_map(|p| p.project_id) {
2684 unshare_project_internal(
2685 ProjectId::from_proto(project_id),
2686 connection_id,
2687 None,
2688 session,
2689 )
2690 .await?;
2691 }
2692
2693 session
2694 .connection_pool()
2695 .await
2696 .set_dev_server_offline(dev_server_id);
2697
2698 let status = session
2699 .db()
2700 .await
2701 .dev_server_projects_update(dev_server.user_id)
2702 .await?;
2703 send_dev_server_projects_update(dev_server.user_id, status, &session).await;
2704
2705 Ok(())
2706}
2707
2708async fn remove_dev_server_connection(dev_server_id: DevServerId, session: &Session) -> Result<()> {
2709 let dev_server_connection = session
2710 .connection_pool()
2711 .await
2712 .dev_server_connection_id(dev_server_id);
2713
2714 if let Some(dev_server_connection) = dev_server_connection {
2715 session
2716 .connection_pool()
2717 .await
2718 .remove_connection(dev_server_connection)?;
2719 }
2720 Ok(())
2721}
2722
2723/// Updates other participants with changes to the project
2724async fn update_project(
2725 request: proto::UpdateProject,
2726 response: Response<proto::UpdateProject>,
2727 session: Session,
2728) -> Result<()> {
2729 let project_id = ProjectId::from_proto(request.project_id);
2730 let (room, guest_connection_ids) = &*session
2731 .db()
2732 .await
2733 .update_project(project_id, session.connection_id, &request.worktrees)
2734 .await?;
2735 broadcast(
2736 Some(session.connection_id),
2737 guest_connection_ids.iter().copied(),
2738 |connection_id| {
2739 session
2740 .peer
2741 .forward_send(session.connection_id, connection_id, request.clone())
2742 },
2743 );
2744 if let Some(room) = room {
2745 room_updated(&room, &session.peer);
2746 }
2747 response.send(proto::Ack {})?;
2748
2749 Ok(())
2750}
2751
2752/// Updates other participants with changes to the worktree
2753async fn update_worktree(
2754 request: proto::UpdateWorktree,
2755 response: Response<proto::UpdateWorktree>,
2756 session: Session,
2757) -> Result<()> {
2758 let guest_connection_ids = session
2759 .db()
2760 .await
2761 .update_worktree(&request, session.connection_id)
2762 .await?;
2763
2764 broadcast(
2765 Some(session.connection_id),
2766 guest_connection_ids.iter().copied(),
2767 |connection_id| {
2768 session
2769 .peer
2770 .forward_send(session.connection_id, connection_id, request.clone())
2771 },
2772 );
2773 response.send(proto::Ack {})?;
2774 Ok(())
2775}
2776
2777/// Updates other participants with changes to the diagnostics
2778async fn update_diagnostic_summary(
2779 message: proto::UpdateDiagnosticSummary,
2780 session: Session,
2781) -> Result<()> {
2782 let guest_connection_ids = session
2783 .db()
2784 .await
2785 .update_diagnostic_summary(&message, session.connection_id)
2786 .await?;
2787
2788 broadcast(
2789 Some(session.connection_id),
2790 guest_connection_ids.iter().copied(),
2791 |connection_id| {
2792 session
2793 .peer
2794 .forward_send(session.connection_id, connection_id, message.clone())
2795 },
2796 );
2797
2798 Ok(())
2799}
2800
2801/// Updates other participants with changes to the worktree settings
2802async fn update_worktree_settings(
2803 message: proto::UpdateWorktreeSettings,
2804 session: Session,
2805) -> Result<()> {
2806 let guest_connection_ids = session
2807 .db()
2808 .await
2809 .update_worktree_settings(&message, session.connection_id)
2810 .await?;
2811
2812 broadcast(
2813 Some(session.connection_id),
2814 guest_connection_ids.iter().copied(),
2815 |connection_id| {
2816 session
2817 .peer
2818 .forward_send(session.connection_id, connection_id, message.clone())
2819 },
2820 );
2821
2822 Ok(())
2823}
2824
2825/// Notify other participants that a language server has started.
2826async fn start_language_server(
2827 request: proto::StartLanguageServer,
2828 session: Session,
2829) -> Result<()> {
2830 let guest_connection_ids = session
2831 .db()
2832 .await
2833 .start_language_server(&request, session.connection_id)
2834 .await?;
2835
2836 broadcast(
2837 Some(session.connection_id),
2838 guest_connection_ids.iter().copied(),
2839 |connection_id| {
2840 session
2841 .peer
2842 .forward_send(session.connection_id, connection_id, request.clone())
2843 },
2844 );
2845 Ok(())
2846}
2847
2848/// Notify other participants that a language server has changed.
2849async fn update_language_server(
2850 request: proto::UpdateLanguageServer,
2851 session: Session,
2852) -> Result<()> {
2853 let project_id = ProjectId::from_proto(request.project_id);
2854 let project_connection_ids = session
2855 .db()
2856 .await
2857 .project_connection_ids(project_id, session.connection_id, true)
2858 .await?;
2859 broadcast(
2860 Some(session.connection_id),
2861 project_connection_ids.iter().copied(),
2862 |connection_id| {
2863 session
2864 .peer
2865 .forward_send(session.connection_id, connection_id, request.clone())
2866 },
2867 );
2868 Ok(())
2869}
2870
2871/// forward a project request to the host. These requests should be read only
2872/// as guests are allowed to send them.
2873async fn forward_read_only_project_request<T>(
2874 request: T,
2875 response: Response<T>,
2876 session: UserSession,
2877) -> Result<()>
2878where
2879 T: EntityMessage + RequestMessage,
2880{
2881 let project_id = ProjectId::from_proto(request.remote_entity_id());
2882 let host_connection_id = session
2883 .db()
2884 .await
2885 .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2886 .await?;
2887 let payload = session
2888 .peer
2889 .forward_request(session.connection_id, host_connection_id, request)
2890 .await?;
2891 response.send(payload)?;
2892 Ok(())
2893}
2894
2895/// forward a project request to the dev server. Only allowed
2896/// if it's your dev server.
2897async fn forward_project_request_for_owner<T>(
2898 request: T,
2899 response: Response<T>,
2900 session: UserSession,
2901) -> Result<()>
2902where
2903 T: EntityMessage + RequestMessage,
2904{
2905 let project_id = ProjectId::from_proto(request.remote_entity_id());
2906
2907 let host_connection_id = session
2908 .db()
2909 .await
2910 .host_for_owner_project_request(project_id, session.connection_id, session.user_id())
2911 .await?;
2912 let payload = session
2913 .peer
2914 .forward_request(session.connection_id, host_connection_id, request)
2915 .await?;
2916 response.send(payload)?;
2917 Ok(())
2918}
2919
2920/// forward a project request to the host. These requests are disallowed
2921/// for guests.
2922async fn forward_mutating_project_request<T>(
2923 request: T,
2924 response: Response<T>,
2925 session: UserSession,
2926) -> Result<()>
2927where
2928 T: EntityMessage + RequestMessage,
2929{
2930 let project_id = ProjectId::from_proto(request.remote_entity_id());
2931
2932 let host_connection_id = session
2933 .db()
2934 .await
2935 .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
2936 .await?;
2937 let payload = session
2938 .peer
2939 .forward_request(session.connection_id, host_connection_id, request)
2940 .await?;
2941 response.send(payload)?;
2942 Ok(())
2943}
2944
2945/// forward a project request to the host. These requests are disallowed
2946/// for guests.
2947async fn forward_versioned_mutating_project_request<T>(
2948 request: T,
2949 response: Response<T>,
2950 session: UserSession,
2951) -> Result<()>
2952where
2953 T: EntityMessage + RequestMessage + VersionedMessage,
2954{
2955 let project_id = ProjectId::from_proto(request.remote_entity_id());
2956
2957 let host_connection_id = session
2958 .db()
2959 .await
2960 .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
2961 .await?;
2962 if let Some(host_version) = session
2963 .connection_pool()
2964 .await
2965 .connection(host_connection_id)
2966 .map(|c| c.zed_version)
2967 {
2968 if let Some(min_required_version) = request.required_host_version() {
2969 if min_required_version > host_version {
2970 return Err(anyhow!(ErrorCode::RemoteUpgradeRequired
2971 .with_tag("required", &min_required_version.to_string())))?;
2972 }
2973 }
2974 }
2975
2976 let payload = session
2977 .peer
2978 .forward_request(session.connection_id, host_connection_id, request)
2979 .await?;
2980 response.send(payload)?;
2981 Ok(())
2982}
2983
2984/// Notify other participants that a new buffer has been created
2985async fn create_buffer_for_peer(
2986 request: proto::CreateBufferForPeer,
2987 session: Session,
2988) -> Result<()> {
2989 session
2990 .db()
2991 .await
2992 .check_user_is_project_host(
2993 ProjectId::from_proto(request.project_id),
2994 session.connection_id,
2995 )
2996 .await?;
2997 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2998 session
2999 .peer
3000 .forward_send(session.connection_id, peer_id.into(), request)?;
3001 Ok(())
3002}
3003
/// Notify other participants that a buffer has been updated. This is
/// allowed for guests as long as the update is limited to selections.
///
/// Steps:
/// 1. The required capability is derived from the operations.
/// 2. The database returns the project's host and guest connections for
///    that capability.
/// 3. The update is forwarded to the guests.
/// 4. Unless the sender is the host, it is sent to the host as a request and
///    awaited.
/// 5. Only after the host replies is the sender acknowledged.
async fn update_buffer(
    request: proto::UpdateBuffer,
    response: Response<proto::UpdateBuffer>,
    session: Session,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    // Empty and selection-only operations need just read access; any other
    // operation kind upgrades the requirement to read-write.
    let mut capability = Capability::ReadOnly;

    for op in request.operations.iter() {
        match op.variant {
            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
            Some(_) => capability = Capability::ReadWrite,
        }
    }

    // This block scopes the database guard so that it is dropped before the
    // host request below is awaited.
    let host = {
        // NOTE(review): presumably this call is where the capability is
        // enforced against the sender's role — confirm in the db layer.
        let guard = session
            .db()
            .await
            .connections_for_buffer_update(
                project_id,
                session.principal_id(),
                session.connection_id,
                capability,
            )
            .await?;

        let (host, guests) = &*guard;

        broadcast(
            Some(session.connection_id),
            guests.clone(),
            |connection_id| {
                session
                    .peer
                    .forward_send(session.connection_id, connection_id, request.clone())
            },
        );

        *host
    };

    // The host is sent a request (not a plain message), so this handler waits
    // for the host's reply before acknowledging the sender.
    if host != session.connection_id {
        session
            .peer
            .forward_request(session.connection_id, host, request.clone())
            .await?;
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3058
3059/// Notify other participants that a project has been updated.
3060async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
3061 request: T,
3062 session: Session,
3063) -> Result<()> {
3064 let project_id = ProjectId::from_proto(request.remote_entity_id());
3065 let project_connection_ids = session
3066 .db()
3067 .await
3068 .project_connection_ids(project_id, session.connection_id, false)
3069 .await?;
3070
3071 broadcast(
3072 Some(session.connection_id),
3073 project_connection_ids.iter().copied(),
3074 |connection_id| {
3075 session
3076 .peer
3077 .forward_send(session.connection_id, connection_id, request.clone())
3078 },
3079 );
3080 Ok(())
3081}
3082
3083/// Start following another user in a call.
3084async fn follow(
3085 request: proto::Follow,
3086 response: Response<proto::Follow>,
3087 session: UserSession,
3088) -> Result<()> {
3089 let room_id = RoomId::from_proto(request.room_id);
3090 let project_id = request.project_id.map(ProjectId::from_proto);
3091 let leader_id = request
3092 .leader_id
3093 .ok_or_else(|| anyhow!("invalid leader id"))?
3094 .into();
3095 let follower_id = session.connection_id;
3096
3097 session
3098 .db()
3099 .await
3100 .check_room_participants(room_id, leader_id, session.connection_id)
3101 .await?;
3102
3103 let response_payload = session
3104 .peer
3105 .forward_request(session.connection_id, leader_id, request)
3106 .await?;
3107 response.send(response_payload)?;
3108
3109 if let Some(project_id) = project_id {
3110 let room = session
3111 .db()
3112 .await
3113 .follow(room_id, project_id, leader_id, follower_id)
3114 .await?;
3115 room_updated(&room, &session.peer);
3116 }
3117
3118 Ok(())
3119}
3120
3121/// Stop following another user in a call.
3122async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
3123 let room_id = RoomId::from_proto(request.room_id);
3124 let project_id = request.project_id.map(ProjectId::from_proto);
3125 let leader_id = request
3126 .leader_id
3127 .ok_or_else(|| anyhow!("invalid leader id"))?
3128 .into();
3129 let follower_id = session.connection_id;
3130
3131 session
3132 .db()
3133 .await
3134 .check_room_participants(room_id, leader_id, session.connection_id)
3135 .await?;
3136
3137 session
3138 .peer
3139 .forward_send(session.connection_id, leader_id, request)?;
3140
3141 if let Some(project_id) = project_id {
3142 let room = session
3143 .db()
3144 .await
3145 .unfollow(room_id, project_id, leader_id, follower_id)
3146 .await?;
3147 room_updated(&room, &session.peer);
3148 }
3149
3150 Ok(())
3151}
3152
3153/// Notify everyone following you of your current location.
3154async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
3155 let room_id = RoomId::from_proto(request.room_id);
3156 let database = session.db.lock().await;
3157
3158 let connection_ids = if let Some(project_id) = request.project_id {
3159 let project_id = ProjectId::from_proto(project_id);
3160 database
3161 .project_connection_ids(project_id, session.connection_id, true)
3162 .await?
3163 } else {
3164 database
3165 .room_connection_ids(room_id, session.connection_id)
3166 .await?
3167 };
3168
3169 // For now, don't send view update messages back to that view's current leader.
3170 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
3171 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
3172 _ => None,
3173 });
3174
3175 for connection_id in connection_ids.iter().cloned() {
3176 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
3177 session
3178 .peer
3179 .forward_send(session.connection_id, connection_id, request.clone())?;
3180 }
3181 }
3182 Ok(())
3183}
3184
3185/// Get public data about users.
3186async fn get_users(
3187 request: proto::GetUsers,
3188 response: Response<proto::GetUsers>,
3189 session: Session,
3190) -> Result<()> {
3191 let user_ids = request
3192 .user_ids
3193 .into_iter()
3194 .map(UserId::from_proto)
3195 .collect();
3196 let users = session
3197 .db()
3198 .await
3199 .get_users_by_ids(user_ids)
3200 .await?
3201 .into_iter()
3202 .map(|user| proto::User {
3203 id: user.id.to_proto(),
3204 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3205 github_login: user.github_login,
3206 })
3207 .collect();
3208 response.send(proto::UsersResponse { users })?;
3209 Ok(())
3210}
3211
3212/// Search for users (to invite) buy Github login
3213async fn fuzzy_search_users(
3214 request: proto::FuzzySearchUsers,
3215 response: Response<proto::FuzzySearchUsers>,
3216 session: UserSession,
3217) -> Result<()> {
3218 let query = request.query;
3219 let users = match query.len() {
3220 0 => vec![],
3221 1 | 2 => session
3222 .db()
3223 .await
3224 .get_user_by_github_login(&query)
3225 .await?
3226 .into_iter()
3227 .collect(),
3228 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
3229 };
3230 let users = users
3231 .into_iter()
3232 .filter(|user| user.id != session.user_id())
3233 .map(|user| proto::User {
3234 id: user.id.to_proto(),
3235 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3236 github_login: user.github_login,
3237 })
3238 .collect();
3239 response.send(proto::UsersResponse { users })?;
3240 Ok(())
3241}
3242
3243/// Send a contact request to another user.
3244async fn request_contact(
3245 request: proto::RequestContact,
3246 response: Response<proto::RequestContact>,
3247 session: UserSession,
3248) -> Result<()> {
3249 let requester_id = session.user_id();
3250 let responder_id = UserId::from_proto(request.responder_id);
3251 if requester_id == responder_id {
3252 return Err(anyhow!("cannot add yourself as a contact"))?;
3253 }
3254
3255 let notifications = session
3256 .db()
3257 .await
3258 .send_contact_request(requester_id, responder_id)
3259 .await?;
3260
3261 // Update outgoing contact requests of requester
3262 let mut update = proto::UpdateContacts::default();
3263 update.outgoing_requests.push(responder_id.to_proto());
3264 for connection_id in session
3265 .connection_pool()
3266 .await
3267 .user_connection_ids(requester_id)
3268 {
3269 session.peer.send(connection_id, update.clone())?;
3270 }
3271
3272 // Update incoming contact requests of responder
3273 let mut update = proto::UpdateContacts::default();
3274 update
3275 .incoming_requests
3276 .push(proto::IncomingContactRequest {
3277 requester_id: requester_id.to_proto(),
3278 });
3279 let connection_pool = session.connection_pool().await;
3280 for connection_id in connection_pool.user_connection_ids(responder_id) {
3281 session.peer.send(connection_id, update.clone())?;
3282 }
3283
3284 send_notifications(&connection_pool, &session.peer, notifications);
3285
3286 response.send(proto::Ack {})?;
3287 Ok(())
3288}
3289
3290/// Accept or decline a contact request
3291async fn respond_to_contact_request(
3292 request: proto::RespondToContactRequest,
3293 response: Response<proto::RespondToContactRequest>,
3294 session: UserSession,
3295) -> Result<()> {
3296 let responder_id = session.user_id();
3297 let requester_id = UserId::from_proto(request.requester_id);
3298 let db = session.db().await;
3299 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
3300 db.dismiss_contact_notification(responder_id, requester_id)
3301 .await?;
3302 } else {
3303 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
3304
3305 let notifications = db
3306 .respond_to_contact_request(responder_id, requester_id, accept)
3307 .await?;
3308 let requester_busy = db.is_user_busy(requester_id).await?;
3309 let responder_busy = db.is_user_busy(responder_id).await?;
3310
3311 let pool = session.connection_pool().await;
3312 // Update responder with new contact
3313 let mut update = proto::UpdateContacts::default();
3314 if accept {
3315 update
3316 .contacts
3317 .push(contact_for_user(requester_id, requester_busy, &pool));
3318 }
3319 update
3320 .remove_incoming_requests
3321 .push(requester_id.to_proto());
3322 for connection_id in pool.user_connection_ids(responder_id) {
3323 session.peer.send(connection_id, update.clone())?;
3324 }
3325
3326 // Update requester with new contact
3327 let mut update = proto::UpdateContacts::default();
3328 if accept {
3329 update
3330 .contacts
3331 .push(contact_for_user(responder_id, responder_busy, &pool));
3332 }
3333 update
3334 .remove_outgoing_requests
3335 .push(responder_id.to_proto());
3336
3337 for connection_id in pool.user_connection_ids(requester_id) {
3338 session.peer.send(connection_id, update.clone())?;
3339 }
3340
3341 send_notifications(&pool, &session.peer, notifications);
3342 }
3343
3344 response.send(proto::Ack {})?;
3345 Ok(())
3346}
3347
3348/// Remove a contact.
3349async fn remove_contact(
3350 request: proto::RemoveContact,
3351 response: Response<proto::RemoveContact>,
3352 session: UserSession,
3353) -> Result<()> {
3354 let requester_id = session.user_id();
3355 let responder_id = UserId::from_proto(request.user_id);
3356 let db = session.db().await;
3357 let (contact_accepted, deleted_notification_id) =
3358 db.remove_contact(requester_id, responder_id).await?;
3359
3360 let pool = session.connection_pool().await;
3361 // Update outgoing contact requests of requester
3362 let mut update = proto::UpdateContacts::default();
3363 if contact_accepted {
3364 update.remove_contacts.push(responder_id.to_proto());
3365 } else {
3366 update
3367 .remove_outgoing_requests
3368 .push(responder_id.to_proto());
3369 }
3370 for connection_id in pool.user_connection_ids(requester_id) {
3371 session.peer.send(connection_id, update.clone())?;
3372 }
3373
3374 // Update incoming contact requests of responder
3375 let mut update = proto::UpdateContacts::default();
3376 if contact_accepted {
3377 update.remove_contacts.push(requester_id.to_proto());
3378 } else {
3379 update
3380 .remove_incoming_requests
3381 .push(requester_id.to_proto());
3382 }
3383 for connection_id in pool.user_connection_ids(responder_id) {
3384 session.peer.send(connection_id, update.clone())?;
3385 if let Some(notification_id) = deleted_notification_id {
3386 session.peer.send(
3387 connection_id,
3388 proto::DeleteNotification {
3389 notification_id: notification_id.to_proto(),
3390 },
3391 )?;
3392 }
3393 }
3394
3395 response.send(proto::Ack {})?;
3396 Ok(())
3397}
3398
3399fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
3400 version.0.minor() < 139
3401}
3402
3403async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
3404 subscribe_user_to_channels(
3405 session.user_id().ok_or_else(|| anyhow!("must be a user"))?,
3406 &session,
3407 )
3408 .await?;
3409 Ok(())
3410}
3411
3412async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
3413 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3414 let mut pool = session.connection_pool().await;
3415 for membership in &channels_for_user.channel_memberships {
3416 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3417 }
3418 session.peer.send(
3419 session.connection_id,
3420 build_update_user_channels(&channels_for_user),
3421 )?;
3422 session.peer.send(
3423 session.connection_id,
3424 build_channels_update(channels_for_user),
3425 )?;
3426 Ok(())
3427}
3428
3429/// Creates a new channel.
3430async fn create_channel(
3431 request: proto::CreateChannel,
3432 response: Response<proto::CreateChannel>,
3433 session: UserSession,
3434) -> Result<()> {
3435 let db = session.db().await;
3436
3437 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
3438 let (channel, membership) = db
3439 .create_channel(&request.name, parent_id, session.user_id())
3440 .await?;
3441
3442 let root_id = channel.root_id();
3443 let channel = Channel::from_model(channel);
3444
3445 response.send(proto::CreateChannelResponse {
3446 channel: Some(channel.to_proto()),
3447 parent_id: request.parent_id,
3448 })?;
3449
3450 let mut connection_pool = session.connection_pool().await;
3451 if let Some(membership) = membership {
3452 connection_pool.subscribe_to_channel(
3453 membership.user_id,
3454 membership.channel_id,
3455 membership.role,
3456 );
3457 let update = proto::UpdateUserChannels {
3458 channel_memberships: vec![proto::ChannelMembership {
3459 channel_id: membership.channel_id.to_proto(),
3460 role: membership.role.into(),
3461 }],
3462 ..Default::default()
3463 };
3464 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3465 session.peer.send(connection_id, update.clone())?;
3466 }
3467 }
3468
3469 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3470 if !role.can_see_channel(channel.visibility) {
3471 continue;
3472 }
3473
3474 let update = proto::UpdateChannels {
3475 channels: vec![channel.to_proto()],
3476 ..Default::default()
3477 };
3478 session.peer.send(connection_id, update.clone())?;
3479 }
3480
3481 Ok(())
3482}
3483
3484/// Delete a channel
3485async fn delete_channel(
3486 request: proto::DeleteChannel,
3487 response: Response<proto::DeleteChannel>,
3488 session: UserSession,
3489) -> Result<()> {
3490 let db = session.db().await;
3491
3492 let channel_id = request.channel_id;
3493 let (root_channel, removed_channels) = db
3494 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3495 .await?;
3496 response.send(proto::Ack {})?;
3497
3498 // Notify members of removed channels
3499 let mut update = proto::UpdateChannels::default();
3500 update
3501 .delete_channels
3502 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3503
3504 let connection_pool = session.connection_pool().await;
3505 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3506 session.peer.send(connection_id, update.clone())?;
3507 }
3508
3509 Ok(())
3510}
3511
3512/// Invite someone to join a channel.
3513async fn invite_channel_member(
3514 request: proto::InviteChannelMember,
3515 response: Response<proto::InviteChannelMember>,
3516 session: UserSession,
3517) -> Result<()> {
3518 let db = session.db().await;
3519 let channel_id = ChannelId::from_proto(request.channel_id);
3520 let invitee_id = UserId::from_proto(request.user_id);
3521 let InviteMemberResult {
3522 channel,
3523 notifications,
3524 } = db
3525 .invite_channel_member(
3526 channel_id,
3527 invitee_id,
3528 session.user_id(),
3529 request.role().into(),
3530 )
3531 .await?;
3532
3533 let update = proto::UpdateChannels {
3534 channel_invitations: vec![channel.to_proto()],
3535 ..Default::default()
3536 };
3537
3538 let connection_pool = session.connection_pool().await;
3539 for connection_id in connection_pool.user_connection_ids(invitee_id) {
3540 session.peer.send(connection_id, update.clone())?;
3541 }
3542
3543 send_notifications(&connection_pool, &session.peer, notifications);
3544
3545 response.send(proto::Ack {})?;
3546 Ok(())
3547}
3548
3549/// remove someone from a channel
3550async fn remove_channel_member(
3551 request: proto::RemoveChannelMember,
3552 response: Response<proto::RemoveChannelMember>,
3553 session: UserSession,
3554) -> Result<()> {
3555 let db = session.db().await;
3556 let channel_id = ChannelId::from_proto(request.channel_id);
3557 let member_id = UserId::from_proto(request.user_id);
3558
3559 let RemoveChannelMemberResult {
3560 membership_update,
3561 notification_id,
3562 } = db
3563 .remove_channel_member(channel_id, member_id, session.user_id())
3564 .await?;
3565
3566 let mut connection_pool = session.connection_pool().await;
3567 notify_membership_updated(
3568 &mut connection_pool,
3569 membership_update,
3570 member_id,
3571 &session.peer,
3572 );
3573 for connection_id in connection_pool.user_connection_ids(member_id) {
3574 if let Some(notification_id) = notification_id {
3575 session
3576 .peer
3577 .send(
3578 connection_id,
3579 proto::DeleteNotification {
3580 notification_id: notification_id.to_proto(),
3581 },
3582 )
3583 .trace_err();
3584 }
3585 }
3586
3587 response.send(proto::Ack {})?;
3588 Ok(())
3589}
3590
3591/// Toggle the channel between public and private.
3592/// Care is taken to maintain the invariant that public channels only descend from public channels,
3593/// (though members-only channels can appear at any point in the hierarchy).
3594async fn set_channel_visibility(
3595 request: proto::SetChannelVisibility,
3596 response: Response<proto::SetChannelVisibility>,
3597 session: UserSession,
3598) -> Result<()> {
3599 let db = session.db().await;
3600 let channel_id = ChannelId::from_proto(request.channel_id);
3601 let visibility = request.visibility().into();
3602
3603 let channel_model = db
3604 .set_channel_visibility(channel_id, visibility, session.user_id())
3605 .await?;
3606 let root_id = channel_model.root_id();
3607 let channel = Channel::from_model(channel_model);
3608
3609 let mut connection_pool = session.connection_pool().await;
3610 for (user_id, role) in connection_pool
3611 .channel_user_ids(root_id)
3612 .collect::<Vec<_>>()
3613 .into_iter()
3614 {
3615 let update = if role.can_see_channel(channel.visibility) {
3616 connection_pool.subscribe_to_channel(user_id, channel_id, role);
3617 proto::UpdateChannels {
3618 channels: vec![channel.to_proto()],
3619 ..Default::default()
3620 }
3621 } else {
3622 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3623 proto::UpdateChannels {
3624 delete_channels: vec![channel.id.to_proto()],
3625 ..Default::default()
3626 }
3627 };
3628
3629 for connection_id in connection_pool.user_connection_ids(user_id) {
3630 session.peer.send(connection_id, update.clone())?;
3631 }
3632 }
3633
3634 response.send(proto::Ack {})?;
3635 Ok(())
3636}
3637
3638/// Alter the role for a user in the channel.
3639async fn set_channel_member_role(
3640 request: proto::SetChannelMemberRole,
3641 response: Response<proto::SetChannelMemberRole>,
3642 session: UserSession,
3643) -> Result<()> {
3644 let db = session.db().await;
3645 let channel_id = ChannelId::from_proto(request.channel_id);
3646 let member_id = UserId::from_proto(request.user_id);
3647 let result = db
3648 .set_channel_member_role(
3649 channel_id,
3650 session.user_id(),
3651 member_id,
3652 request.role().into(),
3653 )
3654 .await?;
3655
3656 match result {
3657 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3658 let mut connection_pool = session.connection_pool().await;
3659 notify_membership_updated(
3660 &mut connection_pool,
3661 membership_update,
3662 member_id,
3663 &session.peer,
3664 )
3665 }
3666 db::SetMemberRoleResult::InviteUpdated(channel) => {
3667 let update = proto::UpdateChannels {
3668 channel_invitations: vec![channel.to_proto()],
3669 ..Default::default()
3670 };
3671
3672 for connection_id in session
3673 .connection_pool()
3674 .await
3675 .user_connection_ids(member_id)
3676 {
3677 session.peer.send(connection_id, update.clone())?;
3678 }
3679 }
3680 }
3681
3682 response.send(proto::Ack {})?;
3683 Ok(())
3684}
3685
3686/// Change the name of a channel
3687async fn rename_channel(
3688 request: proto::RenameChannel,
3689 response: Response<proto::RenameChannel>,
3690 session: UserSession,
3691) -> Result<()> {
3692 let db = session.db().await;
3693 let channel_id = ChannelId::from_proto(request.channel_id);
3694 let channel_model = db
3695 .rename_channel(channel_id, session.user_id(), &request.name)
3696 .await?;
3697 let root_id = channel_model.root_id();
3698 let channel = Channel::from_model(channel_model);
3699
3700 response.send(proto::RenameChannelResponse {
3701 channel: Some(channel.to_proto()),
3702 })?;
3703
3704 let connection_pool = session.connection_pool().await;
3705 let update = proto::UpdateChannels {
3706 channels: vec![channel.to_proto()],
3707 ..Default::default()
3708 };
3709 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3710 if role.can_see_channel(channel.visibility) {
3711 session.peer.send(connection_id, update.clone())?;
3712 }
3713 }
3714
3715 Ok(())
3716}
3717
3718/// Move a channel to a new parent.
3719async fn move_channel(
3720 request: proto::MoveChannel,
3721 response: Response<proto::MoveChannel>,
3722 session: UserSession,
3723) -> Result<()> {
3724 let channel_id = ChannelId::from_proto(request.channel_id);
3725 let to = ChannelId::from_proto(request.to);
3726
3727 let (root_id, channels) = session
3728 .db()
3729 .await
3730 .move_channel(channel_id, to, session.user_id())
3731 .await?;
3732
3733 let connection_pool = session.connection_pool().await;
3734 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3735 let channels = channels
3736 .iter()
3737 .filter_map(|channel| {
3738 if role.can_see_channel(channel.visibility) {
3739 Some(channel.to_proto())
3740 } else {
3741 None
3742 }
3743 })
3744 .collect::<Vec<_>>();
3745 if channels.is_empty() {
3746 continue;
3747 }
3748
3749 let update = proto::UpdateChannels {
3750 channels,
3751 ..Default::default()
3752 };
3753
3754 session.peer.send(connection_id, update.clone())?;
3755 }
3756
3757 response.send(Ack {})?;
3758 Ok(())
3759}
3760
3761/// Get the list of channel members
3762async fn get_channel_members(
3763 request: proto::GetChannelMembers,
3764 response: Response<proto::GetChannelMembers>,
3765 session: UserSession,
3766) -> Result<()> {
3767 let db = session.db().await;
3768 let channel_id = ChannelId::from_proto(request.channel_id);
3769 let limit = if request.limit == 0 {
3770 u16::MAX as u64
3771 } else {
3772 request.limit
3773 };
3774 let (members, users) = db
3775 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3776 .await?;
3777 response.send(proto::GetChannelMembersResponse { members, users })?;
3778 Ok(())
3779}
3780
3781/// Accept or decline a channel invitation.
3782async fn respond_to_channel_invite(
3783 request: proto::RespondToChannelInvite,
3784 response: Response<proto::RespondToChannelInvite>,
3785 session: UserSession,
3786) -> Result<()> {
3787 let db = session.db().await;
3788 let channel_id = ChannelId::from_proto(request.channel_id);
3789 let RespondToChannelInvite {
3790 membership_update,
3791 notifications,
3792 } = db
3793 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3794 .await?;
3795
3796 let mut connection_pool = session.connection_pool().await;
3797 if let Some(membership_update) = membership_update {
3798 notify_membership_updated(
3799 &mut connection_pool,
3800 membership_update,
3801 session.user_id(),
3802 &session.peer,
3803 );
3804 } else {
3805 let update = proto::UpdateChannels {
3806 remove_channel_invitations: vec![channel_id.to_proto()],
3807 ..Default::default()
3808 };
3809
3810 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3811 session.peer.send(connection_id, update.clone())?;
3812 }
3813 };
3814
3815 send_notifications(&connection_pool, &session.peer, notifications);
3816
3817 response.send(proto::Ack {})?;
3818
3819 Ok(())
3820}
3821
3822/// Join the channels' room
3823async fn join_channel(
3824 request: proto::JoinChannel,
3825 response: Response<proto::JoinChannel>,
3826 session: UserSession,
3827) -> Result<()> {
3828 let channel_id = ChannelId::from_proto(request.channel_id);
3829 join_channel_internal(channel_id, Box::new(response), session).await
3830}
3831
/// A response that can answer a channel join with a `proto::JoinRoomResponse`.
///
/// Implemented for both `Response<proto::JoinChannel>` and `Response<proto::JoinRoom>`, so
/// either can be passed to `join_channel_internal`.
trait JoinChannelInternalResponse {
    /// Send `result` as the reply to the originating request.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Explicit path to `Response`'s inherent `send` (inherent items take precedence over
        // trait items), so this forwards rather than recursing into this trait method.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Same forwarding as above, for `JoinRoom` requests.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3845
/// Shared implementation for joining a channel's room.
///
/// Steps:
/// 1. If the database reports a stale room connection for this user, leave that room first.
/// 2. Join the channel in the database. This yields the joined room, an optional membership
///    update, and the user's role in the channel.
/// 3. When a LiveKit client is configured, mint a token for the room. A `Guest` role gets a
///    guest token with `can_publish = false`; any other role gets a room token with
///    `can_publish = true`. If token creation fails, the error goes through `trace_err()` and
///    the response carries no LiveKit connection info.
/// 4. Reply to the caller, propagate any membership update, and call `room_updated`.
/// 5. After the database and connection-pool guards are released, call `channel_updated`
///    for the joined channel and refresh the user's contacts.
///
/// # Errors
/// Returns an error if a database call fails, if leaving the stale room fails, if sending
/// the response fails, if the database did not return a channel for the joined room, or if
/// updating the user's contacts fails.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: UserSession,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database guard while leaving the stale room.
            // `leave_room_for_session` receives the whole session and presumably acquires the
            // database itself. Re-acquire the guard afterwards.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured or when token creation fails.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            // Guests receive a guest token and may not publish; every other role gets a
            // regular room token with publishing enabled.
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        // The database reports a membership update when joining changed the user's
        // membership; propagate it while the pool is locked.
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        // Leaving this block drops both the database and connection-pool guards.
        joined_room
    };

    // A channel join is expected to return its channel; treat its absence as an error.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3936
3937/// Start editing the channel notes
3938async fn join_channel_buffer(
3939 request: proto::JoinChannelBuffer,
3940 response: Response<proto::JoinChannelBuffer>,
3941 session: UserSession,
3942) -> Result<()> {
3943 let db = session.db().await;
3944 let channel_id = ChannelId::from_proto(request.channel_id);
3945
3946 let open_response = db
3947 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3948 .await?;
3949
3950 let collaborators = open_response.collaborators.clone();
3951 response.send(open_response)?;
3952
3953 let update = UpdateChannelBufferCollaborators {
3954 channel_id: channel_id.to_proto(),
3955 collaborators: collaborators.clone(),
3956 };
3957 channel_buffer_updated(
3958 session.connection_id,
3959 collaborators
3960 .iter()
3961 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3962 &update,
3963 &session.peer,
3964 );
3965
3966 Ok(())
3967}
3968
3969/// Edit the channel notes
3970async fn update_channel_buffer(
3971 request: proto::UpdateChannelBuffer,
3972 session: UserSession,
3973) -> Result<()> {
3974 let db = session.db().await;
3975 let channel_id = ChannelId::from_proto(request.channel_id);
3976
3977 let (collaborators, epoch, version) = db
3978 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3979 .await?;
3980
3981 channel_buffer_updated(
3982 session.connection_id,
3983 collaborators.clone(),
3984 &proto::UpdateChannelBuffer {
3985 channel_id: channel_id.to_proto(),
3986 operations: request.operations,
3987 },
3988 &session.peer,
3989 );
3990
3991 let pool = &*session.connection_pool().await;
3992
3993 let non_collaborators =
3994 pool.channel_connection_ids(channel_id)
3995 .filter_map(|(connection_id, _)| {
3996 if collaborators.contains(&connection_id) {
3997 None
3998 } else {
3999 Some(connection_id)
4000 }
4001 });
4002
4003 broadcast(None, non_collaborators, |peer_id| {
4004 session.peer.send(
4005 peer_id,
4006 proto::UpdateChannels {
4007 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
4008 channel_id: channel_id.to_proto(),
4009 epoch: epoch as u64,
4010 version: version.clone(),
4011 }],
4012 ..Default::default()
4013 },
4014 )
4015 });
4016
4017 Ok(())
4018}
4019
4020/// Rejoin the channel notes after a connection blip
4021async fn rejoin_channel_buffers(
4022 request: proto::RejoinChannelBuffers,
4023 response: Response<proto::RejoinChannelBuffers>,
4024 session: UserSession,
4025) -> Result<()> {
4026 let db = session.db().await;
4027 let buffers = db
4028 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
4029 .await?;
4030
4031 for rejoined_buffer in &buffers {
4032 let collaborators_to_notify = rejoined_buffer
4033 .buffer
4034 .collaborators
4035 .iter()
4036 .filter_map(|c| Some(c.peer_id?.into()));
4037 channel_buffer_updated(
4038 session.connection_id,
4039 collaborators_to_notify,
4040 &proto::UpdateChannelBufferCollaborators {
4041 channel_id: rejoined_buffer.buffer.channel_id,
4042 collaborators: rejoined_buffer.buffer.collaborators.clone(),
4043 },
4044 &session.peer,
4045 );
4046 }
4047
4048 response.send(proto::RejoinChannelBuffersResponse {
4049 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
4050 })?;
4051
4052 Ok(())
4053}
4054
4055/// Stop editing the channel notes
4056async fn leave_channel_buffer(
4057 request: proto::LeaveChannelBuffer,
4058 response: Response<proto::LeaveChannelBuffer>,
4059 session: UserSession,
4060) -> Result<()> {
4061 let db = session.db().await;
4062 let channel_id = ChannelId::from_proto(request.channel_id);
4063
4064 let left_buffer = db
4065 .leave_channel_buffer(channel_id, session.connection_id)
4066 .await?;
4067
4068 response.send(Ack {})?;
4069
4070 channel_buffer_updated(
4071 session.connection_id,
4072 left_buffer.connections,
4073 &proto::UpdateChannelBufferCollaborators {
4074 channel_id: channel_id.to_proto(),
4075 collaborators: left_buffer.collaborators,
4076 },
4077 &session.peer,
4078 );
4079
4080 Ok(())
4081}
4082
4083fn channel_buffer_updated<T: EnvelopedMessage>(
4084 sender_id: ConnectionId,
4085 collaborators: impl IntoIterator<Item = ConnectionId>,
4086 message: &T,
4087 peer: &Peer,
4088) {
4089 broadcast(Some(sender_id), collaborators, |peer_id| {
4090 peer.send(peer_id, message.clone())
4091 });
4092}
4093
4094fn send_notifications(
4095 connection_pool: &ConnectionPool,
4096 peer: &Peer,
4097 notifications: db::NotificationBatch,
4098) {
4099 for (user_id, notification) in notifications {
4100 for connection_id in connection_pool.user_connection_ids(user_id) {
4101 if let Err(error) = peer.send(
4102 connection_id,
4103 proto::AddNotification {
4104 notification: Some(notification.clone()),
4105 },
4106 ) {
4107 tracing::error!(
4108 "failed to send notification to {:?} {}",
4109 connection_id,
4110 error
4111 );
4112 }
4113 }
4114 }
4115}
4116
4117/// Send a message to the channel
4118async fn send_channel_message(
4119 request: proto::SendChannelMessage,
4120 response: Response<proto::SendChannelMessage>,
4121 session: UserSession,
4122) -> Result<()> {
4123 // Validate the message body.
4124 let body = request.body.trim().to_string();
4125 if body.len() > MAX_MESSAGE_LEN {
4126 return Err(anyhow!("message is too long"))?;
4127 }
4128 if body.is_empty() {
4129 return Err(anyhow!("message can't be blank"))?;
4130 }
4131
4132 // TODO: adjust mentions if body is trimmed
4133
4134 let timestamp = OffsetDateTime::now_utc();
4135 let nonce = request
4136 .nonce
4137 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4138
4139 let channel_id = ChannelId::from_proto(request.channel_id);
4140 let CreatedChannelMessage {
4141 message_id,
4142 participant_connection_ids,
4143 notifications,
4144 } = session
4145 .db()
4146 .await
4147 .create_channel_message(
4148 channel_id,
4149 session.user_id(),
4150 &body,
4151 &request.mentions,
4152 timestamp,
4153 nonce.clone().into(),
4154 match request.reply_to_message_id {
4155 Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
4156 None => None,
4157 },
4158 )
4159 .await?;
4160
4161 let message = proto::ChannelMessage {
4162 sender_id: session.user_id().to_proto(),
4163 id: message_id.to_proto(),
4164 body,
4165 mentions: request.mentions,
4166 timestamp: timestamp.unix_timestamp() as u64,
4167 nonce: Some(nonce),
4168 reply_to_message_id: request.reply_to_message_id,
4169 edited_at: None,
4170 };
4171 broadcast(
4172 Some(session.connection_id),
4173 participant_connection_ids.clone(),
4174 |connection| {
4175 session.peer.send(
4176 connection,
4177 proto::ChannelMessageSent {
4178 channel_id: channel_id.to_proto(),
4179 message: Some(message.clone()),
4180 },
4181 )
4182 },
4183 );
4184 response.send(proto::SendChannelMessageResponse {
4185 message: Some(message),
4186 })?;
4187
4188 let pool = &*session.connection_pool().await;
4189 let non_participants =
4190 pool.channel_connection_ids(channel_id)
4191 .filter_map(|(connection_id, _)| {
4192 if participant_connection_ids.contains(&connection_id) {
4193 None
4194 } else {
4195 Some(connection_id)
4196 }
4197 });
4198 broadcast(None, non_participants, |peer_id| {
4199 session.peer.send(
4200 peer_id,
4201 proto::UpdateChannels {
4202 latest_channel_message_ids: vec![proto::ChannelMessageId {
4203 channel_id: channel_id.to_proto(),
4204 message_id: message_id.to_proto(),
4205 }],
4206 ..Default::default()
4207 },
4208 )
4209 });
4210 send_notifications(pool, &session.peer, notifications);
4211
4212 Ok(())
4213}
4214
4215/// Delete a channel message
4216async fn remove_channel_message(
4217 request: proto::RemoveChannelMessage,
4218 response: Response<proto::RemoveChannelMessage>,
4219 session: UserSession,
4220) -> Result<()> {
4221 let channel_id = ChannelId::from_proto(request.channel_id);
4222 let message_id = MessageId::from_proto(request.message_id);
4223 let (connection_ids, existing_notification_ids) = session
4224 .db()
4225 .await
4226 .remove_channel_message(channel_id, message_id, session.user_id())
4227 .await?;
4228
4229 broadcast(
4230 Some(session.connection_id),
4231 connection_ids,
4232 move |connection| {
4233 session.peer.send(connection, request.clone())?;
4234
4235 for notification_id in &existing_notification_ids {
4236 session.peer.send(
4237 connection,
4238 proto::DeleteNotification {
4239 notification_id: (*notification_id).to_proto(),
4240 },
4241 )?;
4242 }
4243
4244 Ok(())
4245 },
4246 );
4247 response.send(proto::Ack {})?;
4248 Ok(())
4249}
4250
4251async fn update_channel_message(
4252 request: proto::UpdateChannelMessage,
4253 response: Response<proto::UpdateChannelMessage>,
4254 session: UserSession,
4255) -> Result<()> {
4256 let channel_id = ChannelId::from_proto(request.channel_id);
4257 let message_id = MessageId::from_proto(request.message_id);
4258 let updated_at = OffsetDateTime::now_utc();
4259 let UpdatedChannelMessage {
4260 message_id,
4261 participant_connection_ids,
4262 notifications,
4263 reply_to_message_id,
4264 timestamp,
4265 deleted_mention_notification_ids,
4266 updated_mention_notifications,
4267 } = session
4268 .db()
4269 .await
4270 .update_channel_message(
4271 channel_id,
4272 message_id,
4273 session.user_id(),
4274 request.body.as_str(),
4275 &request.mentions,
4276 updated_at,
4277 )
4278 .await?;
4279
4280 let nonce = request
4281 .nonce
4282 .clone()
4283 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4284
4285 let message = proto::ChannelMessage {
4286 sender_id: session.user_id().to_proto(),
4287 id: message_id.to_proto(),
4288 body: request.body.clone(),
4289 mentions: request.mentions.clone(),
4290 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
4291 nonce: Some(nonce),
4292 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
4293 edited_at: Some(updated_at.unix_timestamp() as u64),
4294 };
4295
4296 response.send(proto::Ack {})?;
4297
4298 let pool = &*session.connection_pool().await;
4299 broadcast(
4300 Some(session.connection_id),
4301 participant_connection_ids,
4302 |connection| {
4303 session.peer.send(
4304 connection,
4305 proto::ChannelMessageUpdate {
4306 channel_id: channel_id.to_proto(),
4307 message: Some(message.clone()),
4308 },
4309 )?;
4310
4311 for notification_id in &deleted_mention_notification_ids {
4312 session.peer.send(
4313 connection,
4314 proto::DeleteNotification {
4315 notification_id: (*notification_id).to_proto(),
4316 },
4317 )?;
4318 }
4319
4320 for notification in &updated_mention_notifications {
4321 session.peer.send(
4322 connection,
4323 proto::UpdateNotification {
4324 notification: Some(notification.clone()),
4325 },
4326 )?;
4327 }
4328
4329 Ok(())
4330 },
4331 );
4332
4333 send_notifications(pool, &session.peer, notifications);
4334
4335 Ok(())
4336}
4337
4338/// Mark a channel message as read
4339async fn acknowledge_channel_message(
4340 request: proto::AckChannelMessage,
4341 session: UserSession,
4342) -> Result<()> {
4343 let channel_id = ChannelId::from_proto(request.channel_id);
4344 let message_id = MessageId::from_proto(request.message_id);
4345 let notifications = session
4346 .db()
4347 .await
4348 .observe_channel_message(channel_id, session.user_id(), message_id)
4349 .await?;
4350 send_notifications(
4351 &*session.connection_pool().await,
4352 &session.peer,
4353 notifications,
4354 );
4355 Ok(())
4356}
4357
4358/// Mark a buffer version as synced
4359async fn acknowledge_buffer_version(
4360 request: proto::AckBufferOperation,
4361 session: UserSession,
4362) -> Result<()> {
4363 let buffer_id = BufferId::from_proto(request.buffer_id);
4364 session
4365 .db()
4366 .await
4367 .observe_buffer_version(
4368 buffer_id,
4369 session.user_id(),
4370 request.epoch as i32,
4371 &request.version,
4372 )
4373 .await?;
4374 Ok(())
4375}
4376
4377struct CompleteWithLanguageModelRateLimit;
4378
4379impl RateLimit for CompleteWithLanguageModelRateLimit {
4380 fn capacity() -> usize {
4381 std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4382 .ok()
4383 .and_then(|v| v.parse().ok())
4384 .unwrap_or(120) // Picked arbitrarily
4385 }
4386
4387 fn refill_duration() -> chrono::Duration {
4388 chrono::Duration::hours(1)
4389 }
4390
4391 fn db_name() -> &'static str {
4392 "complete-with-language-model"
4393 }
4394}
4395
4396async fn complete_with_language_model(
4397 request: proto::CompleteWithLanguageModel,
4398 response: StreamingResponse<proto::CompleteWithLanguageModel>,
4399 session: Session,
4400 open_ai_api_key: Option<Arc<str>>,
4401 google_ai_api_key: Option<Arc<str>>,
4402 anthropic_api_key: Option<Arc<str>>,
4403) -> Result<()> {
4404 let Some(session) = session.for_user() else {
4405 return Err(anyhow!("user not found"))?;
4406 };
4407 authorize_access_to_language_models(&session).await?;
4408 session
4409 .rate_limiter
4410 .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
4411 .await?;
4412
4413 if request.model.starts_with("gpt") {
4414 let api_key =
4415 open_ai_api_key.ok_or_else(|| anyhow!("no OpenAI API key configured on the server"))?;
4416 complete_with_open_ai(request, response, session, api_key).await?;
4417 } else if request.model.starts_with("gemini") {
4418 let api_key = google_ai_api_key
4419 .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
4420 complete_with_google_ai(request, response, session, api_key).await?;
4421 } else if request.model.starts_with("claude") {
4422 let api_key = anthropic_api_key
4423 .ok_or_else(|| anyhow!("no Anthropic AI API key configured on the server"))?;
4424 complete_with_anthropic(request, response, session, api_key).await?;
4425 }
4426
4427 Ok(())
4428}
4429
4430async fn complete_with_open_ai(
4431 request: proto::CompleteWithLanguageModel,
4432 response: StreamingResponse<proto::CompleteWithLanguageModel>,
4433 session: UserSession,
4434 api_key: Arc<str>,
4435) -> Result<()> {
4436 let mut completion_stream = open_ai::stream_completion(
4437 session.http_client.as_ref(),
4438 OPEN_AI_API_URL,
4439 &api_key,
4440 crate::ai::language_model_request_to_open_ai(request)?,
4441 None,
4442 )
4443 .await
4444 .context("open_ai::stream_completion request failed within collab")?;
4445
4446 while let Some(event) = completion_stream.next().await {
4447 let event = event?;
4448 response.send(proto::LanguageModelResponse {
4449 choices: event
4450 .choices
4451 .into_iter()
4452 .map(|choice| proto::LanguageModelChoiceDelta {
4453 index: choice.index,
4454 delta: Some(proto::LanguageModelResponseMessage {
4455 role: choice.delta.role.map(|role| match role {
4456 open_ai::Role::User => LanguageModelRole::LanguageModelUser,
4457 open_ai::Role::Assistant => LanguageModelRole::LanguageModelAssistant,
4458 open_ai::Role::System => LanguageModelRole::LanguageModelSystem,
4459 open_ai::Role::Tool => LanguageModelRole::LanguageModelTool,
4460 } as i32),
4461 content: choice.delta.content,
4462 tool_calls: choice
4463 .delta
4464 .tool_calls
4465 .into_iter()
4466 .map(|delta| proto::ToolCallDelta {
4467 index: delta.index as u32,
4468 id: delta.id,
4469 variant: match delta.function {
4470 Some(function) => {
4471 let name = function.name;
4472 let arguments = function.arguments;
4473
4474 Some(proto::tool_call_delta::Variant::Function(
4475 proto::tool_call_delta::FunctionCallDelta {
4476 name,
4477 arguments,
4478 },
4479 ))
4480 }
4481 None => None,
4482 },
4483 })
4484 .collect(),
4485 }),
4486 finish_reason: choice.finish_reason,
4487 })
4488 .collect(),
4489 })?;
4490 }
4491
4492 Ok(())
4493}
4494
4495async fn complete_with_google_ai(
4496 request: proto::CompleteWithLanguageModel,
4497 response: StreamingResponse<proto::CompleteWithLanguageModel>,
4498 session: UserSession,
4499 api_key: Arc<str>,
4500) -> Result<()> {
4501 let mut stream = google_ai::stream_generate_content(
4502 session.http_client.clone(),
4503 google_ai::API_URL,
4504 api_key.as_ref(),
4505 &request.model.clone(),
4506 crate::ai::language_model_request_to_google_ai(request)?,
4507 )
4508 .await
4509 .context("google_ai::stream_generate_content request failed")?;
4510
4511 while let Some(event) = stream.next().await {
4512 let event = event?;
4513 response.send(proto::LanguageModelResponse {
4514 choices: event
4515 .candidates
4516 .unwrap_or_default()
4517 .into_iter()
4518 .map(|candidate| proto::LanguageModelChoiceDelta {
4519 index: candidate.index as u32,
4520 delta: Some(proto::LanguageModelResponseMessage {
4521 role: Some(match candidate.content.role {
4522 google_ai::Role::User => LanguageModelRole::LanguageModelUser,
4523 google_ai::Role::Model => LanguageModelRole::LanguageModelAssistant,
4524 } as i32),
4525 content: Some(
4526 candidate
4527 .content
4528 .parts
4529 .into_iter()
4530 .filter_map(|part| match part {
4531 google_ai::Part::TextPart(part) => Some(part.text),
4532 google_ai::Part::InlineDataPart(_) => None,
4533 })
4534 .collect(),
4535 ),
4536 // Tool calls are not supported for Google
4537 tool_calls: Vec::new(),
4538 }),
4539 finish_reason: candidate.finish_reason.map(|reason| reason.to_string()),
4540 })
4541 .collect(),
4542 })?;
4543 }
4544
4545 Ok(())
4546}
4547
4548async fn complete_with_anthropic(
4549 request: proto::CompleteWithLanguageModel,
4550 response: StreamingResponse<proto::CompleteWithLanguageModel>,
4551 session: UserSession,
4552 api_key: Arc<str>,
4553) -> Result<()> {
4554 let model = anthropic::Model::from_id(&request.model)?;
4555
4556 let mut system_message = String::new();
4557 let messages = request
4558 .messages
4559 .into_iter()
4560 .filter_map(|message| {
4561 match message.role() {
4562 LanguageModelRole::LanguageModelUser => Some(anthropic::RequestMessage {
4563 role: anthropic::Role::User,
4564 content: message.content,
4565 }),
4566 LanguageModelRole::LanguageModelAssistant => Some(anthropic::RequestMessage {
4567 role: anthropic::Role::Assistant,
4568 content: message.content,
4569 }),
4570 // Anthropic's API breaks system instructions out as a separate field rather
4571 // than having a system message role.
4572 LanguageModelRole::LanguageModelSystem => {
4573 if !system_message.is_empty() {
4574 system_message.push_str("\n\n");
4575 }
4576 system_message.push_str(&message.content);
4577
4578 None
4579 }
4580 // We don't yet support tool calls for Anthropic
4581 LanguageModelRole::LanguageModelTool => None,
4582 }
4583 })
4584 .collect();
4585
4586 let mut stream = anthropic::stream_completion(
4587 session.http_client.as_ref(),
4588 anthropic::ANTHROPIC_API_URL,
4589 &api_key,
4590 anthropic::Request {
4591 model,
4592 messages,
4593 stream: true,
4594 system: system_message,
4595 max_tokens: 4092,
4596 },
4597 None,
4598 )
4599 .await?;
4600
4601 let mut current_role = proto::LanguageModelRole::LanguageModelAssistant;
4602
4603 while let Some(event) = stream.next().await {
4604 let event = event?;
4605
4606 match event {
4607 anthropic::ResponseEvent::MessageStart { message } => {
4608 if let Some(role) = message.role {
4609 if role == "assistant" {
4610 current_role = proto::LanguageModelRole::LanguageModelAssistant;
4611 } else if role == "user" {
4612 current_role = proto::LanguageModelRole::LanguageModelUser;
4613 }
4614 }
4615 }
4616 anthropic::ResponseEvent::ContentBlockStart { content_block, .. } => {
4617 match content_block {
4618 anthropic::ContentBlock::Text { text } => {
4619 if !text.is_empty() {
4620 response.send(proto::LanguageModelResponse {
4621 choices: vec![proto::LanguageModelChoiceDelta {
4622 index: 0,
4623 delta: Some(proto::LanguageModelResponseMessage {
4624 role: Some(current_role as i32),
4625 content: Some(text),
4626 tool_calls: Vec::new(),
4627 }),
4628 finish_reason: None,
4629 }],
4630 })?;
4631 }
4632 }
4633 }
4634 }
4635 anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => match delta {
4636 anthropic::TextDelta::TextDelta { text } => {
4637 response.send(proto::LanguageModelResponse {
4638 choices: vec![proto::LanguageModelChoiceDelta {
4639 index: 0,
4640 delta: Some(proto::LanguageModelResponseMessage {
4641 role: Some(current_role as i32),
4642 content: Some(text),
4643 tool_calls: Vec::new(),
4644 }),
4645 finish_reason: None,
4646 }],
4647 })?;
4648 }
4649 },
4650 anthropic::ResponseEvent::MessageDelta { delta, .. } => {
4651 if let Some(stop_reason) = delta.stop_reason {
4652 response.send(proto::LanguageModelResponse {
4653 choices: vec![proto::LanguageModelChoiceDelta {
4654 index: 0,
4655 delta: None,
4656 finish_reason: Some(stop_reason),
4657 }],
4658 })?;
4659 }
4660 }
4661 anthropic::ResponseEvent::ContentBlockStop { .. } => {}
4662 anthropic::ResponseEvent::MessageStop {} => {}
4663 anthropic::ResponseEvent::Ping {} => {}
4664 }
4665 }
4666
4667 Ok(())
4668}
4669
4670struct CountTokensWithLanguageModelRateLimit;
4671
4672impl RateLimit for CountTokensWithLanguageModelRateLimit {
4673 fn capacity() -> usize {
4674 std::env::var("COUNT_TOKENS_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4675 .ok()
4676 .and_then(|v| v.parse().ok())
4677 .unwrap_or(600) // Picked arbitrarily
4678 }
4679
4680 fn refill_duration() -> chrono::Duration {
4681 chrono::Duration::hours(1)
4682 }
4683
4684 fn db_name() -> &'static str {
4685 "count-tokens-with-language-model"
4686 }
4687}
4688
4689async fn count_tokens_with_language_model(
4690 request: proto::CountTokensWithLanguageModel,
4691 response: Response<proto::CountTokensWithLanguageModel>,
4692 session: UserSession,
4693 google_ai_api_key: Option<Arc<str>>,
4694) -> Result<()> {
4695 authorize_access_to_language_models(&session).await?;
4696
4697 if !request.model.starts_with("gemini") {
4698 return Err(anyhow!(
4699 "counting tokens for model: {:?} is not supported",
4700 request.model
4701 ))?;
4702 }
4703
4704 session
4705 .rate_limiter
4706 .check::<CountTokensWithLanguageModelRateLimit>(session.user_id())
4707 .await?;
4708
4709 let api_key = google_ai_api_key
4710 .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
4711 let tokens_response = google_ai::count_tokens(
4712 session.http_client.as_ref(),
4713 google_ai::API_URL,
4714 &api_key,
4715 crate::ai::count_tokens_request_to_google_ai(request)?,
4716 )
4717 .await?;
4718 response.send(proto::CountTokensResponse {
4719 token_count: tokens_response.total_tokens as u32,
4720 })?;
4721 Ok(())
4722}
4723
4724struct ComputeEmbeddingsRateLimit;
4725
4726impl RateLimit for ComputeEmbeddingsRateLimit {
4727 fn capacity() -> usize {
4728 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4729 .ok()
4730 .and_then(|v| v.parse().ok())
4731 .unwrap_or(5000) // Picked arbitrarily
4732 }
4733
4734 fn refill_duration() -> chrono::Duration {
4735 chrono::Duration::hours(1)
4736 }
4737
4738 fn db_name() -> &'static str {
4739 "compute-embeddings"
4740 }
4741}
4742
4743async fn compute_embeddings(
4744 request: proto::ComputeEmbeddings,
4745 response: Response<proto::ComputeEmbeddings>,
4746 session: UserSession,
4747 api_key: Option<Arc<str>>,
4748) -> Result<()> {
4749 let api_key = api_key.context("no OpenAI API key configured on the server")?;
4750 authorize_access_to_language_models(&session).await?;
4751
4752 session
4753 .rate_limiter
4754 .check::<ComputeEmbeddingsRateLimit>(session.user_id())
4755 .await?;
4756
4757 let embeddings = match request.model.as_str() {
4758 "openai/text-embedding-3-small" => {
4759 open_ai::embed(
4760 session.http_client.as_ref(),
4761 OPEN_AI_API_URL,
4762 &api_key,
4763 OpenAiEmbeddingModel::TextEmbedding3Small,
4764 request.texts.iter().map(|text| text.as_str()),
4765 )
4766 .await?
4767 }
4768 provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
4769 };
4770
4771 let embeddings = request
4772 .texts
4773 .iter()
4774 .map(|text| {
4775 let mut hasher = sha2::Sha256::new();
4776 hasher.update(text.as_bytes());
4777 let result = hasher.finalize();
4778 result.to_vec()
4779 })
4780 .zip(
4781 embeddings
4782 .data
4783 .into_iter()
4784 .map(|embedding| embedding.embedding),
4785 )
4786 .collect::<HashMap<_, _>>();
4787
4788 let db = session.db().await;
4789 db.save_embeddings(&request.model, &embeddings)
4790 .await
4791 .context("failed to save embeddings")
4792 .trace_err();
4793
4794 response.send(proto::ComputeEmbeddingsResponse {
4795 embeddings: embeddings
4796 .into_iter()
4797 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4798 .collect(),
4799 })?;
4800 Ok(())
4801}
4802
4803async fn get_cached_embeddings(
4804 request: proto::GetCachedEmbeddings,
4805 response: Response<proto::GetCachedEmbeddings>,
4806 session: UserSession,
4807) -> Result<()> {
4808 authorize_access_to_language_models(&session).await?;
4809
4810 let db = session.db().await;
4811 let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
4812
4813 response.send(proto::GetCachedEmbeddingsResponse {
4814 embeddings: embeddings
4815 .into_iter()
4816 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4817 .collect(),
4818 })?;
4819 Ok(())
4820}
4821
4822async fn authorize_access_to_language_models(session: &UserSession) -> Result<(), Error> {
4823 let db = session.db().await;
4824 let flags = db.get_user_flags(session.user_id()).await?;
4825 if flags.iter().any(|flag| flag == "language-models") {
4826 Ok(())
4827 } else {
4828 Err(anyhow!("permission denied"))?
4829 }
4830}
4831
4832/// Get a Supermaven API key for the user
4833async fn get_supermaven_api_key(
4834 _request: proto::GetSupermavenApiKey,
4835 response: Response<proto::GetSupermavenApiKey>,
4836 session: UserSession,
4837) -> Result<()> {
4838 let user_id: String = session.user_id().to_string();
4839 if !session.is_staff() {
4840 return Err(anyhow!("supermaven not enabled for this account"))?;
4841 }
4842
4843 let email = session
4844 .email()
4845 .ok_or_else(|| anyhow!("user must have an email"))?;
4846
4847 let supermaven_admin_api = session
4848 .supermaven_client
4849 .as_ref()
4850 .ok_or_else(|| anyhow!("supermaven not configured"))?;
4851
4852 let result = supermaven_admin_api
4853 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4854 .await?;
4855
4856 response.send(proto::GetSupermavenApiKeyResponse {
4857 api_key: result.api_key,
4858 })?;
4859
4860 Ok(())
4861}
4862
4863/// Start receiving chat updates for a channel
4864async fn join_channel_chat(
4865 request: proto::JoinChannelChat,
4866 response: Response<proto::JoinChannelChat>,
4867 session: UserSession,
4868) -> Result<()> {
4869 let channel_id = ChannelId::from_proto(request.channel_id);
4870
4871 let db = session.db().await;
4872 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4873 .await?;
4874 let messages = db
4875 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4876 .await?;
4877 response.send(proto::JoinChannelChatResponse {
4878 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4879 messages,
4880 })?;
4881 Ok(())
4882}
4883
4884/// Stop receiving chat updates for a channel
4885async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
4886 let channel_id = ChannelId::from_proto(request.channel_id);
4887 session
4888 .db()
4889 .await
4890 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4891 .await?;
4892 Ok(())
4893}
4894
4895/// Retrieve the chat history for a channel
4896async fn get_channel_messages(
4897 request: proto::GetChannelMessages,
4898 response: Response<proto::GetChannelMessages>,
4899 session: UserSession,
4900) -> Result<()> {
4901 let channel_id = ChannelId::from_proto(request.channel_id);
4902 let messages = session
4903 .db()
4904 .await
4905 .get_channel_messages(
4906 channel_id,
4907 session.user_id(),
4908 MESSAGE_COUNT_PER_PAGE,
4909 Some(MessageId::from_proto(request.before_message_id)),
4910 )
4911 .await?;
4912 response.send(proto::GetChannelMessagesResponse {
4913 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4914 messages,
4915 })?;
4916 Ok(())
4917}
4918
4919/// Retrieve specific chat messages
4920async fn get_channel_messages_by_id(
4921 request: proto::GetChannelMessagesById,
4922 response: Response<proto::GetChannelMessagesById>,
4923 session: UserSession,
4924) -> Result<()> {
4925 let message_ids = request
4926 .message_ids
4927 .iter()
4928 .map(|id| MessageId::from_proto(*id))
4929 .collect::<Vec<_>>();
4930 let messages = session
4931 .db()
4932 .await
4933 .get_channel_messages_by_id(session.user_id(), &message_ids)
4934 .await?;
4935 response.send(proto::GetChannelMessagesResponse {
4936 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4937 messages,
4938 })?;
4939 Ok(())
4940}
4941
4942/// Retrieve the current users notifications
4943async fn get_notifications(
4944 request: proto::GetNotifications,
4945 response: Response<proto::GetNotifications>,
4946 session: UserSession,
4947) -> Result<()> {
4948 let notifications = session
4949 .db()
4950 .await
4951 .get_notifications(
4952 session.user_id(),
4953 NOTIFICATION_COUNT_PER_PAGE,
4954 request
4955 .before_id
4956 .map(|id| db::NotificationId::from_proto(id)),
4957 )
4958 .await?;
4959 response.send(proto::GetNotificationsResponse {
4960 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4961 notifications,
4962 })?;
4963 Ok(())
4964}
4965
4966/// Mark notifications as read
4967async fn mark_notification_as_read(
4968 request: proto::MarkNotificationRead,
4969 response: Response<proto::MarkNotificationRead>,
4970 session: UserSession,
4971) -> Result<()> {
4972 let database = &session.db().await;
4973 let notifications = database
4974 .mark_notification_as_read_by_id(
4975 session.user_id(),
4976 NotificationId::from_proto(request.notification_id),
4977 )
4978 .await?;
4979 send_notifications(
4980 &*session.connection_pool().await,
4981 &session.peer,
4982 notifications,
4983 );
4984 response.send(proto::Ack {})?;
4985 Ok(())
4986}
4987
4988/// Get the current users information
4989async fn get_private_user_info(
4990 _request: proto::GetPrivateUserInfo,
4991 response: Response<proto::GetPrivateUserInfo>,
4992 session: UserSession,
4993) -> Result<()> {
4994 let db = session.db().await;
4995
4996 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4997 let user = db
4998 .get_user_by_id(session.user_id())
4999 .await?
5000 .ok_or_else(|| anyhow!("user not found"))?;
5001 let flags = db.get_user_flags(session.user_id()).await?;
5002
5003 response.send(proto::GetPrivateUserInfoResponse {
5004 metrics_id,
5005 staff: user.admin,
5006 flags,
5007 })?;
5008 Ok(())
5009}
5010
5011fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
5012 match message {
5013 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
5014 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
5015 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
5016 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
5017 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
5018 code: frame.code.into(),
5019 reason: frame.reason,
5020 })),
5021 }
5022}
5023
5024fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
5025 match message {
5026 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
5027 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
5028 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
5029 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
5030 AxumMessage::Close(frame) => {
5031 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
5032 code: frame.code.into(),
5033 reason: frame.reason,
5034 }))
5035 }
5036 }
5037}
5038
5039fn notify_membership_updated(
5040 connection_pool: &mut ConnectionPool,
5041 result: MembershipUpdated,
5042 user_id: UserId,
5043 peer: &Peer,
5044) {
5045 for membership in &result.new_channels.channel_memberships {
5046 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
5047 }
5048 for channel_id in &result.removed_channels {
5049 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
5050 }
5051
5052 let user_channels_update = proto::UpdateUserChannels {
5053 channel_memberships: result
5054 .new_channels
5055 .channel_memberships
5056 .iter()
5057 .map(|cm| proto::ChannelMembership {
5058 channel_id: cm.channel_id.to_proto(),
5059 role: cm.role.into(),
5060 })
5061 .collect(),
5062 ..Default::default()
5063 };
5064
5065 let mut update = build_channels_update(result.new_channels);
5066 update.delete_channels = result
5067 .removed_channels
5068 .into_iter()
5069 .map(|id| id.to_proto())
5070 .collect();
5071 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
5072
5073 for connection_id in connection_pool.user_connection_ids(user_id) {
5074 peer.send(connection_id, user_channels_update.clone())
5075 .trace_err();
5076 peer.send(connection_id, update.clone()).trace_err();
5077 }
5078}
5079
5080fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
5081 proto::UpdateUserChannels {
5082 channel_memberships: channels
5083 .channel_memberships
5084 .iter()
5085 .map(|m| proto::ChannelMembership {
5086 channel_id: m.channel_id.to_proto(),
5087 role: m.role.into(),
5088 })
5089 .collect(),
5090 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
5091 observed_channel_message_id: channels.observed_channel_messages.clone(),
5092 }
5093}
5094
5095fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
5096 let mut update = proto::UpdateChannels::default();
5097
5098 for channel in channels.channels {
5099 update.channels.push(channel.to_proto());
5100 }
5101
5102 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
5103 update.latest_channel_message_ids = channels.latest_channel_messages;
5104
5105 for (channel_id, participants) in channels.channel_participants {
5106 update
5107 .channel_participants
5108 .push(proto::ChannelParticipants {
5109 channel_id: channel_id.to_proto(),
5110 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
5111 });
5112 }
5113
5114 for channel in channels.invited_channels {
5115 update.channel_invitations.push(channel.to_proto());
5116 }
5117
5118 update.hosted_projects = channels.hosted_projects;
5119 update
5120}
5121
5122fn build_initial_contacts_update(
5123 contacts: Vec<db::Contact>,
5124 pool: &ConnectionPool,
5125) -> proto::UpdateContacts {
5126 let mut update = proto::UpdateContacts::default();
5127
5128 for contact in contacts {
5129 match contact {
5130 db::Contact::Accepted { user_id, busy } => {
5131 update.contacts.push(contact_for_user(user_id, busy, &pool));
5132 }
5133 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
5134 db::Contact::Incoming { user_id } => {
5135 update
5136 .incoming_requests
5137 .push(proto::IncomingContactRequest {
5138 requester_id: user_id.to_proto(),
5139 })
5140 }
5141 }
5142 }
5143
5144 update
5145}
5146
5147fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
5148 proto::Contact {
5149 user_id: user_id.to_proto(),
5150 online: pool.is_user_online(user_id),
5151 busy,
5152 }
5153}
5154
5155fn room_updated(room: &proto::Room, peer: &Peer) {
5156 broadcast(
5157 None,
5158 room.participants
5159 .iter()
5160 .filter_map(|participant| Some(participant.peer_id?.into())),
5161 |peer_id| {
5162 peer.send(
5163 peer_id,
5164 proto::RoomUpdated {
5165 room: Some(room.clone()),
5166 },
5167 )
5168 },
5169 );
5170}
5171
5172fn channel_updated(
5173 channel: &db::channel::Model,
5174 room: &proto::Room,
5175 peer: &Peer,
5176 pool: &ConnectionPool,
5177) {
5178 let participants = room
5179 .participants
5180 .iter()
5181 .map(|p| p.user_id)
5182 .collect::<Vec<_>>();
5183
5184 broadcast(
5185 None,
5186 pool.channel_connection_ids(channel.root_id())
5187 .filter_map(|(channel_id, role)| {
5188 role.can_see_channel(channel.visibility).then(|| channel_id)
5189 }),
5190 |peer_id| {
5191 peer.send(
5192 peer_id,
5193 proto::UpdateChannels {
5194 channel_participants: vec![proto::ChannelParticipants {
5195 channel_id: channel.id.to_proto(),
5196 participant_user_ids: participants.clone(),
5197 }],
5198 ..Default::default()
5199 },
5200 )
5201 },
5202 );
5203}
5204
5205async fn send_dev_server_projects_update(
5206 user_id: UserId,
5207 mut status: proto::DevServerProjectsUpdate,
5208 session: &Session,
5209) {
5210 let pool = session.connection_pool().await;
5211 for dev_server in &mut status.dev_servers {
5212 dev_server.status =
5213 pool.dev_server_status(DevServerId(dev_server.dev_server_id as i32)) as i32;
5214 }
5215 let connections = pool.user_connection_ids(user_id);
5216 for connection_id in connections {
5217 session.peer.send(connection_id, status.clone()).trace_err();
5218 }
5219}
5220
5221async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
5222 let db = session.db().await;
5223
5224 let contacts = db.get_contacts(user_id).await?;
5225 let busy = db.is_user_busy(user_id).await?;
5226
5227 let pool = session.connection_pool().await;
5228 let updated_contact = contact_for_user(user_id, busy, &pool);
5229 for contact in contacts {
5230 if let db::Contact::Accepted {
5231 user_id: contact_user_id,
5232 ..
5233 } = contact
5234 {
5235 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
5236 session
5237 .peer
5238 .send(
5239 contact_conn_id,
5240 proto::UpdateContacts {
5241 contacts: vec![updated_contact.clone()],
5242 remove_contacts: Default::default(),
5243 incoming_requests: Default::default(),
5244 remove_incoming_requests: Default::default(),
5245 outgoing_requests: Default::default(),
5246 remove_outgoing_requests: Default::default(),
5247 },
5248 )
5249 .trace_err();
5250 }
5251 }
5252 }
5253 Ok(())
5254}
5255
5256async fn lost_dev_server_connection(session: &DevServerSession) -> Result<()> {
5257 log::info!("lost dev server connection, unsharing projects");
5258 let project_ids = session
5259 .db()
5260 .await
5261 .get_stale_dev_server_projects(session.connection_id)
5262 .await?;
5263
5264 for project_id in project_ids {
5265 // not unshare re-checks the connection ids match, so we get away with no transaction
5266 unshare_project_internal(project_id, session.connection_id, None, &session).await?;
5267 }
5268
5269 let user_id = session.dev_server().user_id;
5270 let update = session
5271 .db()
5272 .await
5273 .dev_server_projects_update(user_id)
5274 .await?;
5275
5276 send_dev_server_projects_update(user_id, update, session).await;
5277
5278 Ok(())
5279}
5280
/// Removes `connection_id` from the room it is in, if any, and propagates the
/// consequences to other clients.
///
/// When the database reports that the connection left a room, this:
/// - notifies collaborators of every project the user left (`project_left`),
/// - broadcasts the updated room, plus the channel's participants if the
///   database returned a channel for the room,
/// - sends `CallCanceled` to each user listed in `canceled_calls_to_user_ids`,
/// - refreshes contact status for the leaving user and those users,
/// - removes the user from the LiveKit room, and deletes that room if the
///   database marked it as deleted. LiveKit failures are only logged.
///
/// Returns `Ok(())` early if the database reports that the connection was not
/// in a room.
async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Assigned inside the `if let` below. Its `else` branch returns early, so
    // all of these are initialized on every path that reaches their uses.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        // Move the needed pieces out of `left_room`, leaving defaults behind.
        // `live_kit_room` is taken from `left_room.room` before the room itself
        // is taken.
        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so `pool` is dropped before `update_user_contacts`, which calls
    // `connection_pool()` itself.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    // Best-effort LiveKit cleanup: errors are logged via `trace_err`, not returned.
    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
5354
5355async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
5356 let left_channel_buffers = session
5357 .db()
5358 .await
5359 .leave_channel_buffers(session.connection_id)
5360 .await?;
5361
5362 for left_buffer in left_channel_buffers {
5363 channel_buffer_updated(
5364 session.connection_id,
5365 left_buffer.connections,
5366 &proto::UpdateChannelBufferCollaborators {
5367 channel_id: left_buffer.channel_id.to_proto(),
5368 collaborators: left_buffer.collaborators,
5369 },
5370 &session.peer,
5371 );
5372 }
5373
5374 Ok(())
5375}
5376
5377fn project_left(project: &db::LeftProject, session: &UserSession) {
5378 for connection_id in &project.connection_ids {
5379 if project.should_unshare {
5380 session
5381 .peer
5382 .send(
5383 *connection_id,
5384 proto::UnshareProject {
5385 project_id: project.id.to_proto(),
5386 },
5387 )
5388 .trace_err();
5389 } else {
5390 session
5391 .peer
5392 .send(
5393 *connection_id,
5394 proto::RemoveProjectCollaborator {
5395 project_id: project.id.to_proto(),
5396 peer_id: Some(session.connection_id.into()),
5397 },
5398 )
5399 .trace_err();
5400 }
5401 }
5402}
5403
5404pub trait ResultExt {
5405 type Ok;
5406
5407 fn trace_err(self) -> Option<Self::Ok>;
5408}
5409
5410impl<T, E> ResultExt for Result<T, E>
5411where
5412 E: std::fmt::Debug,
5413{
5414 type Ok = T;
5415
5416 #[track_caller]
5417 fn trace_err(self) -> Option<T> {
5418 match self {
5419 Ok(value) => Some(value),
5420 Err(error) => {
5421 tracing::error!("{:?}", error);
5422 None
5423 }
5424 }
5425 }
5426}