1mod connection_pool;
2
3use crate::{
4 auth,
5 db::{
6 self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
7 CreatedChannelMessage, Database, DevServerId, DevServerProjectId, InviteMemberResult,
8 MembershipUpdated, MessageId, NotificationId, PrincipalId, Project, ProjectId,
9 RejoinedProject, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
10 ServerId, UpdatedChannelMessage, User, UserId,
11 },
12 executor::Executor,
13 AppState, Error, RateLimit, RateLimiter, Result,
14};
15use anyhow::{anyhow, Context as _};
16use async_tungstenite::tungstenite::{
17 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
18};
19use axum::{
20 body::Body,
21 extract::{
22 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
23 ConnectInfo, WebSocketUpgrade,
24 },
25 headers::{Header, HeaderName},
26 http::StatusCode,
27 middleware,
28 response::IntoResponse,
29 routing::get,
30 Extension, Router, TypedHeader,
31};
32use collections::{HashMap, HashSet};
33pub use connection_pool::{ConnectionPool, ZedVersion};
34use core::fmt::{self, Debug, Formatter};
35use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
36use sha2::Digest;
37use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
38
39use futures::{
40 channel::oneshot,
41 future::{self, BoxFuture},
42 stream::FuturesUnordered,
43 FutureExt, SinkExt, StreamExt, TryStreamExt,
44};
45use http::IsahcHttpClient;
46use prometheus::{register_int_gauge, IntGauge};
47use rpc::{
48 proto::{
49 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LanguageModelRole,
50 LiveKitConnectionInfo, RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
51 },
52 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
53};
54use semantic_version::SemanticVersion;
55use serde::{Serialize, Serializer};
56use std::{
57 any::TypeId,
58 future::Future,
59 marker::PhantomData,
60 mem,
61 net::SocketAddr,
62 ops::{Deref, DerefMut},
63 rc::Rc,
64 sync::{
65 atomic::{AtomicBool, Ordering::SeqCst},
66 Arc, OnceLock,
67 },
68 time::{Duration, Instant},
69};
70use time::OffsetDateTime;
71use tokio::sync::{watch, Semaphore};
72use tower::ServiceBuilder;
73use tracing::{
74 field::{self},
75 info_span, instrument, Instrument,
76};
77
78use self::connection_pool::VersionedMessage;
79
/// Reconnection grace period.
///
/// NOTE(review): not used in the visible part of this file; presumably how long
/// a dropped client's state is kept so it can reconnect — confirm at use sites.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// Delay before `Server::start` cleans up resources belonging to stale servers.
// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone,
// we can clean up old resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// NOTE(review): the three limits below are used outside the visible code;
// presumably the page size for channel message history, the maximum channel
// message length (unit — bytes vs. chars — unconfirmed), and the page size for
// notifications.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;

/// Type-erased message handler stored in `Server::handlers`, keyed by the
/// `TypeId` of the message it handles. It receives the boxed envelope and the
/// sender's `Session` and returns a future that performs the handling.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
91
92struct Response<R> {
93 peer: Arc<Peer>,
94 receipt: Receipt<R>,
95 responded: Arc<AtomicBool>,
96}
97
98impl<R: RequestMessage> Response<R> {
99 fn send(self, payload: R::Response) -> Result<()> {
100 self.responded.store(true, SeqCst);
101 self.peer.respond(self.receipt, payload)?;
102 Ok(())
103 }
104}
105
106struct StreamingResponse<R: RequestMessage> {
107 peer: Arc<Peer>,
108 receipt: Receipt<R>,
109}
110
111impl<R: RequestMessage> StreamingResponse<R> {
112 fn send(&self, payload: R::Response) -> Result<()> {
113 self.peer.respond(self.receipt, payload)?;
114 Ok(())
115 }
116}
117
118#[derive(Clone, Debug)]
119pub enum Principal {
120 User(User),
121 Impersonated { user: User, admin: User },
122 DevServer(dev_server::Model),
123}
124
125impl Principal {
126 fn update_span(&self, span: &tracing::Span) {
127 match &self {
128 Principal::User(user) => {
129 span.record("user_id", &user.id.0);
130 span.record("login", &user.github_login);
131 }
132 Principal::Impersonated { user, admin } => {
133 span.record("user_id", &user.id.0);
134 span.record("login", &user.github_login);
135 span.record("impersonator", &admin.github_login);
136 }
137 Principal::DevServer(dev_server) => {
138 span.record("dev_server_id", &dev_server.id.0);
139 }
140 }
141 }
142}
143
144#[derive(Clone)]
145struct Session {
146 principal: Principal,
147 connection_id: ConnectionId,
148 db: Arc<tokio::sync::Mutex<DbHandle>>,
149 peer: Arc<Peer>,
150 connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
151 live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
152 supermaven_client: Option<Arc<SupermavenAdminApi>>,
153 http_client: Arc<IsahcHttpClient>,
154 rate_limiter: Arc<RateLimiter>,
155 _executor: Executor,
156}
157
158impl Session {
159 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
160 #[cfg(test)]
161 tokio::task::yield_now().await;
162 let guard = self.db.lock().await;
163 #[cfg(test)]
164 tokio::task::yield_now().await;
165 guard
166 }
167
168 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
169 #[cfg(test)]
170 tokio::task::yield_now().await;
171 let guard = self.connection_pool.lock();
172 ConnectionPoolGuard {
173 guard,
174 _not_send: PhantomData,
175 }
176 }
177
178 fn for_user(self) -> Option<UserSession> {
179 UserSession::new(self)
180 }
181
182 fn for_dev_server(self) -> Option<DevServerSession> {
183 DevServerSession::new(self)
184 }
185
186 fn user_id(&self) -> Option<UserId> {
187 match &self.principal {
188 Principal::User(user) => Some(user.id),
189 Principal::Impersonated { user, .. } => Some(user.id),
190 Principal::DevServer(_) => None,
191 }
192 }
193
194 fn is_staff(&self) -> bool {
195 match &self.principal {
196 Principal::User(user) => user.admin,
197 Principal::Impersonated { .. } => true,
198 Principal::DevServer(_) => false,
199 }
200 }
201
202 fn dev_server_id(&self) -> Option<DevServerId> {
203 match &self.principal {
204 Principal::User(_) | Principal::Impersonated { .. } => None,
205 Principal::DevServer(dev_server) => Some(dev_server.id),
206 }
207 }
208
209 fn principal_id(&self) -> PrincipalId {
210 match &self.principal {
211 Principal::User(user) => PrincipalId::UserId(user.id),
212 Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
213 Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
214 }
215 }
216}
217
218impl Debug for Session {
219 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
220 let mut result = f.debug_struct("Session");
221 match &self.principal {
222 Principal::User(user) => {
223 result.field("user", &user.github_login);
224 }
225 Principal::Impersonated { user, admin } => {
226 result.field("user", &user.github_login);
227 result.field("impersonator", &admin.github_login);
228 }
229 Principal::DevServer(dev_server) => {
230 result.field("dev_server", &dev_server.id);
231 }
232 }
233 result.field("connection_id", &self.connection_id).finish()
234 }
235}
236
237struct UserSession(Session);
238
239impl UserSession {
240 pub fn new(s: Session) -> Option<Self> {
241 s.user_id().map(|_| UserSession(s))
242 }
243 pub fn user_id(&self) -> UserId {
244 self.0.user_id().unwrap()
245 }
246
247 pub fn email(&self) -> Option<String> {
248 match &self.0.principal {
249 Principal::User(user) => user.email_address.clone(),
250 Principal::Impersonated { user, .. } => user.email_address.clone(),
251 Principal::DevServer(..) => None,
252 }
253 }
254}
255
256impl Deref for UserSession {
257 type Target = Session;
258
259 fn deref(&self) -> &Self::Target {
260 &self.0
261 }
262}
263impl DerefMut for UserSession {
264 fn deref_mut(&mut self) -> &mut Self::Target {
265 &mut self.0
266 }
267}
268
269struct DevServerSession(Session);
270
271impl DevServerSession {
272 pub fn new(s: Session) -> Option<Self> {
273 s.dev_server_id().map(|_| DevServerSession(s))
274 }
275 pub fn dev_server_id(&self) -> DevServerId {
276 self.0.dev_server_id().unwrap()
277 }
278
279 fn dev_server(&self) -> &dev_server::Model {
280 match &self.0.principal {
281 Principal::DevServer(dev_server) => dev_server,
282 _ => unreachable!(),
283 }
284 }
285}
286
287impl Deref for DevServerSession {
288 type Target = Session;
289
290 fn deref(&self) -> &Self::Target {
291 &self.0
292 }
293}
294impl DerefMut for DevServerSession {
295 fn deref_mut(&mut self) -> &mut Self::Target {
296 &mut self.0
297 }
298}
299
300fn user_handler<M: RequestMessage, Fut>(
301 handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
302) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
303where
304 Fut: Send + Future<Output = Result<()>>,
305{
306 let handler = Arc::new(handler);
307 move |message, response, session| {
308 let handler = handler.clone();
309 Box::pin(async move {
310 if let Some(user_session) = session.for_user() {
311 Ok(handler(message, response, user_session).await?)
312 } else {
313 Err(Error::Internal(anyhow!(
314 "must be a user to call {}",
315 M::NAME
316 )))
317 }
318 })
319 }
320}
321
322fn dev_server_handler<M: RequestMessage, Fut>(
323 handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
324) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
325where
326 Fut: Send + Future<Output = Result<()>>,
327{
328 let handler = Arc::new(handler);
329 move |message, response, session| {
330 let handler = handler.clone();
331 Box::pin(async move {
332 if let Some(dev_server_session) = session.for_dev_server() {
333 Ok(handler(message, response, dev_server_session).await?)
334 } else {
335 Err(Error::Internal(anyhow!(
336 "must be a dev server to call {}",
337 M::NAME
338 )))
339 }
340 })
341 }
342}
343
344fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
345 handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
346) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
347where
348 InnertRetFut: Send + Future<Output = Result<()>>,
349{
350 let handler = Arc::new(handler);
351 move |message, session| {
352 let handler = handler.clone();
353 Box::pin(async move {
354 if let Some(user_session) = session.for_user() {
355 Ok(handler(message, user_session).await?)
356 } else {
357 Err(Error::Internal(anyhow!(
358 "must be a user to call {}",
359 M::NAME
360 )))
361 }
362 })
363 }
364}
365
366struct DbHandle(Arc<Database>);
367
368impl Deref for DbHandle {
369 type Target = Database;
370
371 fn deref(&self) -> &Self::Target {
372 self.0.as_ref()
373 }
374}
375
/// The collaboration RPC server: owns the peer, the connection pool, and the
/// table of message handlers.
pub struct Server {
    /// Current server id; behind a mutex so the test-only `reset` can replace it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    /// Shared with every connection's `Session`.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Registered handlers, keyed by the `TypeId` of the message they handle.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Broadcasts `true` when the server is tearing down; connection loops
    /// subscribe to it and exit.
    teardown: watch::Sender<bool>,
}
384
/// Lock guard over the shared [`ConnectionPool`].
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send`, so it cannot be
/// held across an `.await` inside a future that must be `Send`.
// NOTE(review): `Deref`/`DerefMut` impls for this guard are presumably defined
// later in the file (`ServerSnapshot` serializes it through `serialize_deref`).
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
389
/// Serializable point-in-time view of the server: the peer plus the locked
/// connection pool (serialized through its guard's `Deref` target).
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
396
397pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
398where
399 S: Serializer,
400 T: Deref<Target = U>,
401 U: Serialize,
402{
403 Serialize::serialize(value.deref(), serializer)
404}
405
406impl Server {
    /// Builds a server with the given `id` and registers a handler for every
    /// RPC message type it understands.
    ///
    /// Wrappers used in the registrations below:
    /// - `user_handler` / `user_message_handler` reject the message with an
    ///   internal error unless the session's principal is a user;
    /// - `dev_server_handler` rejects it unless the principal is a dev server;
    /// - handlers registered without a wrapper are not filtered by principal at
    ///   registration time.
    ///
    /// # Panics
    ///
    /// Panics (via `add_handler`) if two handlers are registered for the same
    /// message type.
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            // NOTE(review): `as u32` silently reinterprets/truncates the id;
            // assumes server ids fit in a u32 — confirm.
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            teardown: watch::channel(false).0,
        };

        server
            // Rooms and calls.
            .add_request_handler(ping)
            .add_request_handler(user_handler(create_room))
            .add_request_handler(user_handler(join_room))
            .add_request_handler(user_handler(rejoin_room))
            .add_request_handler(user_handler(leave_room))
            .add_request_handler(user_handler(set_room_participant_role))
            .add_request_handler(user_handler(call))
            .add_request_handler(user_handler(cancel_call))
            .add_message_handler(user_message_handler(decline_call))
            .add_request_handler(user_handler(update_participant_location))
            // Projects and dev servers.
            .add_request_handler(user_handler(share_project))
            .add_message_handler(unshare_project)
            .add_request_handler(user_handler(join_project))
            .add_request_handler(user_handler(join_hosted_project))
            .add_request_handler(user_handler(rejoin_dev_server_projects))
            .add_request_handler(user_handler(create_dev_server_project))
            .add_request_handler(user_handler(delete_dev_server_project))
            .add_request_handler(user_handler(create_dev_server))
            .add_request_handler(user_handler(regenerate_dev_server_token))
            .add_request_handler(user_handler(rename_dev_server))
            .add_request_handler(user_handler(delete_dev_server))
            .add_request_handler(dev_server_handler(share_dev_server_project))
            .add_request_handler(dev_server_handler(shutdown_dev_server))
            .add_request_handler(dev_server_handler(reconnect_dev_server))
            .add_message_handler(user_message_handler(leave_project))
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Project requests forwarded by generic helpers. NOTE(review): the
            // `read_only` / `mutating` / `versioned_mutating` / `for_owner`
            // variants are defined outside this chunk; presumably they differ in
            // the permission or client version they require — confirm there.
            .add_request_handler(user_handler(
                forward_project_request_for_owner::<proto::TaskContextForLocation>,
            ))
            .add_request_handler(user_handler(
                forward_project_request_for_owner::<proto::TaskTemplates>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetHover>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetTypeDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetReferences>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::SearchProject>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDocumentHighlights>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetProjectSymbols>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferForSymbol>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferById>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::SynchronizeBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::InlayHints>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferByPath>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCompletions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            ))
            .add_request_handler(user_handler(
                forward_versioned_mutating_project_request::<proto::OpenNewBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCodeActions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCodeAction>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PrepareRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PerformRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ReloadBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::FormatBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CreateProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RenameProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CopyProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::DeleteProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ExpandProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OnTypeFormatting>,
            ))
            .add_request_handler(user_handler(
                forward_versioned_mutating_project_request::<proto::SaveBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::BlameBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::MultiLspQuery>,
            ))
            // Buffers and host-to-guest broadcasts.
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            // Users and contacts.
            .add_request_handler(get_users)
            .add_request_handler(user_handler(fuzzy_search_users))
            .add_request_handler(user_handler(request_contact))
            .add_request_handler(user_handler(remove_contact))
            .add_request_handler(user_handler(respond_to_contact_request))
            // Channels, channel buffers and channel chat.
            .add_request_handler(user_handler(create_channel))
            .add_request_handler(user_handler(delete_channel))
            .add_request_handler(user_handler(invite_channel_member))
            .add_request_handler(user_handler(remove_channel_member))
            .add_request_handler(user_handler(set_channel_member_role))
            .add_request_handler(user_handler(set_channel_visibility))
            .add_request_handler(user_handler(rename_channel))
            .add_request_handler(user_handler(join_channel_buffer))
            .add_request_handler(user_handler(leave_channel_buffer))
            .add_message_handler(user_message_handler(update_channel_buffer))
            .add_request_handler(user_handler(rejoin_channel_buffers))
            .add_request_handler(user_handler(get_channel_members))
            .add_request_handler(user_handler(respond_to_channel_invite))
            .add_request_handler(user_handler(join_channel))
            .add_request_handler(user_handler(join_channel_chat))
            .add_message_handler(user_message_handler(leave_channel_chat))
            .add_request_handler(user_handler(send_channel_message))
            .add_request_handler(user_handler(remove_channel_message))
            .add_request_handler(user_handler(update_channel_message))
            .add_request_handler(user_handler(get_channel_messages))
            .add_request_handler(user_handler(get_channel_messages_by_id))
            // Notifications.
            .add_request_handler(user_handler(get_notifications))
            .add_request_handler(user_handler(mark_notification_as_read))
            .add_request_handler(user_handler(move_channel))
            // Following.
            .add_request_handler(user_handler(follow))
            .add_message_handler(user_message_handler(unfollow))
            .add_message_handler(user_message_handler(update_followers))
            .add_request_handler(user_handler(get_private_user_info))
            .add_message_handler(user_message_handler(acknowledge_channel_message))
            .add_message_handler(user_message_handler(acknowledge_buffer_version))
            .add_request_handler(user_handler(get_supermaven_api_key))
            // Language-model endpoints. These closures capture `app_state` so
            // the configured API keys are read per request. Note that the
            // streaming completion handler is registered without a principal
            // wrapper.
            .add_streaming_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    complete_with_language_model(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                        app_state.config.google_ai_api_key.clone(),
                        app_state.config.anthropic_api_key.clone(),
                    )
                }
            })
            .add_request_handler({
                let app_state = app_state.clone();
                user_handler(move |request, response, session| {
                    count_tokens_with_language_model(
                        request,
                        response,
                        session,
                        app_state.config.google_ai_api_key.clone(),
                    )
                })
            })
            .add_request_handler({
                user_handler(move |request, response, session| {
                    get_cached_embeddings(request, response, session)
                })
            })
            .add_request_handler({
                let app_state = app_state.clone();
                user_handler(move |request, response, session| {
                    compute_embeddings(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                    )
                })
            });

        Arc::new(server)
    }
634
    /// Spawns a detached background task that cleans up resources the
    /// database reports as stale for this server's environment, then returns
    /// `Ok(())` immediately.
    ///
    /// After sleeping for `CLEANUP_TIMEOUT`, the task:
    /// 1. for each stale channel buffer, clears stale collaborators and sends
    ///    the refreshed collaborator list to the remaining connections;
    /// 2. for each stale room, clears stale participants, broadcasts the
    ///    updated room (and its channel, if any), notifies users whose calls
    ///    were canceled, pushes updated contact entries to the accepted
    ///    contacts of affected users, and, when a LiveKit client is configured,
    ///    deletes the LiveKit room if no participants remain;
    /// 3. deletes the stale server records.
    ///
    /// Every fallible step goes through `trace_err()`, so a failure is traced
    /// and the rest of the cleanup still runs.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Step 1: refresh collaborators of stale channel buffers.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    // Step 2: refresh stale rooms.
                    for room_id in room_ids {
                        // Defaults apply when the room could not be refreshed:
                        // nobody to notify and no LiveKit room to delete.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            // Both removed participants and users whose incoming
                            // calls were canceled get updated contact entries.
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the pool lock is released before the
                        // awaits in the contact loop below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            // Skip this user entirely if either lookup failed.
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    // Only accepted contacts are sent the update.
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                // Step 3: remove stale server records. This runs even if the
                // stale-resource lookup above failed.
                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
780
781 pub fn teardown(&self) {
782 self.peer.teardown();
783 self.connection_pool.lock().reset();
784 let _ = self.teardown.send(true);
785 }
786
787 #[cfg(test)]
788 pub fn reset(&self, id: ServerId) {
789 self.teardown();
790 *self.id.lock() = id;
791 self.peer.reset(id.0 as u32);
792 let _ = self.teardown.send(false);
793 }
794
795 #[cfg(test)]
796 pub fn id(&self) -> ServerId {
797 *self.id.lock()
798 }
799
800 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
801 where
802 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
803 Fut: 'static + Send + Future<Output = Result<()>>,
804 M: EnvelopedMessage,
805 {
806 let prev_handler = self.handlers.insert(
807 TypeId::of::<M>(),
808 Box::new(move |envelope, session| {
809 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
810 let received_at = envelope.received_at;
811 tracing::info!("message received");
812 let start_time = Instant::now();
813 let future = (handler)(*envelope, session);
814 async move {
815 let result = future.await;
816 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
817 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
818 let queue_duration_ms = total_duration_ms - processing_duration_ms;
819 let payload_type = M::NAME;
820
821 match result {
822 Err(error) => {
823 tracing::error!(
824 ?error,
825 total_duration_ms,
826 processing_duration_ms,
827 queue_duration_ms,
828 payload_type,
829 "error handling message"
830 )
831 }
832 Ok(()) => tracing::info!(
833 total_duration_ms,
834 processing_duration_ms,
835 queue_duration_ms,
836 "finished handling message"
837 ),
838 }
839 }
840 .boxed()
841 }),
842 );
843 if prev_handler.is_some() {
844 panic!("registered a handler for the same message twice");
845 }
846 self
847 }
848
849 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
850 where
851 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
852 Fut: 'static + Send + Future<Output = Result<()>>,
853 M: EnvelopedMessage,
854 {
855 self.add_handler(move |envelope, session| handler(envelope.payload, session));
856 self
857 }
858
859 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
860 where
861 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
862 Fut: Send + Future<Output = Result<()>>,
863 M: RequestMessage,
864 {
865 let handler = Arc::new(handler);
866 self.add_handler(move |envelope, session| {
867 let receipt = envelope.receipt();
868 let handler = handler.clone();
869 async move {
870 let peer = session.peer.clone();
871 let responded = Arc::new(AtomicBool::default());
872 let response = Response {
873 peer: peer.clone(),
874 responded: responded.clone(),
875 receipt,
876 };
877 match (handler)(envelope.payload, response, session).await {
878 Ok(()) => {
879 if responded.load(std::sync::atomic::Ordering::SeqCst) {
880 Ok(())
881 } else {
882 Err(anyhow!("handler did not send a response"))?
883 }
884 }
885 Err(error) => {
886 let proto_err = match &error {
887 Error::Internal(err) => err.to_proto(),
888 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
889 };
890 peer.respond_with_error(receipt, proto_err)?;
891 Err(error)
892 }
893 }
894 }
895 })
896 }
897
898 fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
899 where
900 F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
901 Fut: Send + Future<Output = Result<()>>,
902 M: RequestMessage,
903 {
904 let handler = Arc::new(handler);
905 self.add_handler(move |envelope, session| {
906 let receipt = envelope.receipt();
907 let handler = handler.clone();
908 async move {
909 let peer = session.peer.clone();
910 let response = StreamingResponse {
911 peer: peer.clone(),
912 receipt,
913 };
914 match (handler)(envelope.payload, response, session).await {
915 Ok(()) => {
916 peer.end_stream(receipt)?;
917 Ok(())
918 }
919 Err(error) => {
920 let proto_err = match &error {
921 Error::Internal(err) => err.to_proto(),
922 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
923 };
924 peer.respond_with_error(receipt, proto_err)?;
925 Err(error)
926 }
927 }
928 }
929 })
930 }
931
932 #[allow(clippy::too_many_arguments)]
933 pub fn handle_connection(
934 self: &Arc<Self>,
935 connection: Connection,
936 address: String,
937 principal: Principal,
938 zed_version: ZedVersion,
939 send_connection_id: Option<oneshot::Sender<ConnectionId>>,
940 executor: Executor,
941 ) -> impl Future<Output = ()> {
942 let this = self.clone();
943 let span = info_span!("handle connection", %address,
944 connection_id=field::Empty,
945 user_id=field::Empty,
946 login=field::Empty,
947 impersonator=field::Empty,
948 dev_server_id=field::Empty
949 );
950 principal.update_span(&span);
951
952 let mut teardown = self.teardown.subscribe();
953 async move {
954 if *teardown.borrow() {
955 tracing::error!("server is tearing down");
956 return
957 }
958 let (connection_id, handle_io, mut incoming_rx) = this
959 .peer
960 .add_connection(connection, {
961 let executor = executor.clone();
962 move |duration| executor.sleep(duration)
963 });
964 tracing::Span::current().record("connection_id", format!("{}", connection_id));
965 tracing::info!("connection opened");
966
967 let http_client = match IsahcHttpClient::new() {
968 Ok(http_client) => Arc::new(http_client),
969 Err(error) => {
970 tracing::error!(?error, "failed to create HTTP client");
971 return;
972 }
973 };
974
975 let supermaven_client = if let Some(supermaven_admin_api_key) = this.app_state.config.supermaven_admin_api_key.clone() {
976 Some(Arc::new(SupermavenAdminApi::new(
977 supermaven_admin_api_key.to_string(),
978 http_client.clone(),
979 )))
980 } else {
981 None
982 };
983
984 let session = Session {
985 principal: principal.clone(),
986 connection_id,
987 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
988 peer: this.peer.clone(),
989 connection_pool: this.connection_pool.clone(),
990 live_kit_client: this.app_state.live_kit_client.clone(),
991 http_client,
992 rate_limiter: this.app_state.rate_limiter.clone(),
993 _executor: executor.clone(),
994 supermaven_client,
995 };
996
997 if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
998 tracing::error!(?error, "failed to send initial client update");
999 return;
1000 }
1001
1002 let handle_io = handle_io.fuse();
1003 futures::pin_mut!(handle_io);
1004
1005 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
1006 // This prevents deadlocks when e.g., client A performs a request to client B and
1007 // client B performs a request to client A. If both clients stop processing further
1008 // messages until their respective request completes, they won't have a chance to
1009 // respond to the other client's request and cause a deadlock.
1010 //
1011 // This arrangement ensures we will attempt to process earlier messages first, but fall
1012 // back to processing messages arrived later in the spirit of making progress.
1013 let mut foreground_message_handlers = FuturesUnordered::new();
1014 let concurrent_handlers = Arc::new(Semaphore::new(256));
1015 loop {
1016 let next_message = async {
1017 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
1018 let message = incoming_rx.next().await;
1019 (permit, message)
1020 }.fuse();
1021 futures::pin_mut!(next_message);
1022 futures::select_biased! {
1023 _ = teardown.changed().fuse() => return,
1024 result = handle_io => {
1025 if let Err(error) = result {
1026 tracing::error!(?error, "error handling I/O");
1027 }
1028 break;
1029 }
1030 _ = foreground_message_handlers.next() => {}
1031 next_message = next_message => {
1032 let (permit, message) = next_message;
1033 if let Some(message) = message {
1034 let type_name = message.payload_type_name();
1035 // note: we copy all the fields from the parent span so we can query them in the logs.
1036 // (https://github.com/tokio-rs/tracing/issues/2670).
1037 let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
1038 user_id=field::Empty,
1039 login=field::Empty,
1040 impersonator=field::Empty,
1041 dev_server_id=field::Empty
1042 );
1043 principal.update_span(&span);
1044 let span_enter = span.enter();
1045 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
1046 let is_background = message.is_background();
1047 let handle_message = (handler)(message, session.clone());
1048 drop(span_enter);
1049
1050 let handle_message = async move {
1051 handle_message.await;
1052 drop(permit);
1053 }.instrument(span);
1054 if is_background {
1055 executor.spawn_detached(handle_message);
1056 } else {
1057 foreground_message_handlers.push(handle_message);
1058 }
1059 } else {
1060 tracing::error!("no message handler");
1061 }
1062 } else {
1063 tracing::info!("connection closed");
1064 break;
1065 }
1066 }
1067 }
1068 }
1069
1070 drop(foreground_message_handlers);
1071 tracing::info!("signing out");
1072 if let Err(error) = connection_lost(session, teardown, executor).await {
1073 tracing::error!(?error, "error signing out");
1074 }
1075
1076 }.instrument(span)
1077 }
1078
    /// Performs the initial handshake for a newly accepted connection.
    ///
    /// Sends `Hello` carrying the connection's peer id and, if `send_connection_id` is
    /// provided, reports the connection id through it. Then sends principal-specific state:
    ///
    /// * User (or impersonated user): on the first-ever connection, `ShowContacts` is sent and
    ///   `connected_once` is persisted. The connection is registered in the pool together with
    ///   its channel subscriptions. The client then receives initial contacts, user-channel and
    ///   channel updates, dev-server project status, and any pending incoming call. Finally the
    ///   user's contacts are refreshed.
    /// * Dev server: any existing connection for the same dev server id is sent
    ///   `ShutdownDevServer` and removed from the pool before the new connection is added. The
    ///   dev server then receives its project instructions, and the owning user receives
    ///   updated dev-server project status.
    ///
    /// # Errors
    /// Returns the first error from sending a message, from a database query, or from removing
    /// a stale dev-server connection from the pool.
    async fn send_initial_client_update(
        &self,
        connection_id: ConnectionId,
        principal: &Principal,
        zed_version: ZedVersion,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        session: &Session,
    ) -> Result<()> {
        self.peer.send(
            connection_id,
            proto::Hello {
                peer_id: Some(connection_id.into()),
            },
        )?;
        tracing::info!("sent hello message");
        // The receiving side may already be gone; the send result is intentionally ignored.
        if let Some(send_connection_id) = send_connection_id.take() {
            let _ = send_connection_id.send(connection_id);
        }

        match principal {
            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
                // First connection ever for this user.
                if !user.connected_once {
                    self.peer.send(connection_id, proto::ShowContacts {})?;
                    self.app_state
                        .db
                        .set_user_connected_once(user.id, true)
                        .await?;
                }

                // Load the initial state concurrently; the first failure aborts the handshake.
                let (contacts, channels_for_user, channel_invites, dev_server_projects) =
                    future::try_join4(
                        self.app_state.db.get_contacts(user.id),
                        self.app_state.db.get_channels_for_user(user.id),
                        self.app_state.db.get_channel_invites_for_user(user.id),
                        self.app_state.db.dev_server_projects_update(user.id),
                    )
                    .await?;

                // The pool lock is scoped to this block, so it is released before the `.await`s
                // that follow. The contacts update is built from the pool after this connection
                // has been added to it.
                {
                    let mut pool = self.connection_pool.lock();
                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
                    for membership in &channels_for_user.channel_memberships {
                        pool.subscribe_to_channel(user.id, membership.channel_id, membership.role)
                    }
                    self.peer.send(
                        connection_id,
                        build_initial_contacts_update(contacts, &pool),
                    )?;
                    self.peer.send(
                        connection_id,
                        build_update_user_channels(&channels_for_user),
                    )?;
                    self.peer.send(
                        connection_id,
                        build_channels_update(channels_for_user, channel_invites),
                    )?;
                }
                send_dev_server_projects_update(user.id, dev_server_projects, session).await;

                if let Some(incoming_call) =
                    self.app_state.db.incoming_call_for_user(user.id).await?
                {
                    self.peer.send(connection_id, incoming_call)?;
                }

                update_user_contacts(user.id, &session).await?;
            }
            Principal::DevServer(dev_server) => {
                // Only one connection per dev server id is kept: an older connection is told to
                // shut down and dropped from the pool before this one is registered.
                {
                    let mut pool = self.connection_pool.lock();
                    if let Some(stale_connection_id) = pool.dev_server_connection_id(dev_server.id)
                    {
                        self.peer.send(
                            stale_connection_id,
                            proto::ShutdownDevServer {
                                reason: Some(
                                    "another dev server connected with the same token".to_string(),
                                ),
                            },
                        )?;
                        pool.remove_connection(stale_connection_id)?;
                    };
                    pool.add_dev_server(connection_id, dev_server.id, zed_version);
                }

                let projects = self
                    .app_state
                    .db
                    .get_projects_for_dev_server(dev_server.id)
                    .await?;
                self.peer
                    .send(connection_id, proto::DevServerInstructions { projects })?;

                // Let the owning user see the dev server's updated project status.
                let status = self
                    .app_state
                    .db
                    .dev_server_projects_update(dev_server.user_id)
                    .await?;
                send_dev_server_projects_update(dev_server.user_id, status, &session).await;
            }
        }

        Ok(())
    }
1183
1184 pub async fn invite_code_redeemed(
1185 self: &Arc<Self>,
1186 inviter_id: UserId,
1187 invitee_id: UserId,
1188 ) -> Result<()> {
1189 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1190 if let Some(code) = &user.invite_code {
1191 let pool = self.connection_pool.lock();
1192 let invitee_contact = contact_for_user(invitee_id, false, &pool);
1193 for connection_id in pool.user_connection_ids(inviter_id) {
1194 self.peer.send(
1195 connection_id,
1196 proto::UpdateContacts {
1197 contacts: vec![invitee_contact.clone()],
1198 ..Default::default()
1199 },
1200 )?;
1201 self.peer.send(
1202 connection_id,
1203 proto::UpdateInviteInfo {
1204 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1205 count: user.invite_count as u32,
1206 },
1207 )?;
1208 }
1209 }
1210 }
1211 Ok(())
1212 }
1213
1214 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1215 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1216 if let Some(invite_code) = &user.invite_code {
1217 let pool = self.connection_pool.lock();
1218 for connection_id in pool.user_connection_ids(user_id) {
1219 self.peer.send(
1220 connection_id,
1221 proto::UpdateInviteInfo {
1222 url: format!(
1223 "{}{}",
1224 self.app_state.config.invite_link_prefix, invite_code
1225 ),
1226 count: user.invite_count as u32,
1227 },
1228 )?;
1229 }
1230 }
1231 }
1232 Ok(())
1233 }
1234
1235 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1236 ServerSnapshot {
1237 connection_pool: ConnectionPoolGuard {
1238 guard: self.connection_pool.lock(),
1239 _not_send: PhantomData,
1240 },
1241 peer: &self.peer,
1242 }
1243 }
1244}
1245
1246impl<'a> Deref for ConnectionPoolGuard<'a> {
1247 type Target = ConnectionPool;
1248
1249 fn deref(&self) -> &Self::Target {
1250 &self.guard
1251 }
1252}
1253
1254impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1255 fn deref_mut(&mut self) -> &mut Self::Target {
1256 &mut self.guard
1257 }
1258}
1259
/// In test builds, verifies the connection pool's invariants whenever a guard is dropped.
impl<'a> Drop for ConnectionPoolGuard<'a> {
    fn drop(&mut self) {
        // Only compiled under `cfg(test)`. In other builds this body is empty, and dropping the
        // guard just drops its inner `guard` field.
        #[cfg(test)]
        self.check_invariants();
    }
}
1266
1267fn broadcast<F>(
1268 sender_id: Option<ConnectionId>,
1269 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1270 mut f: F,
1271) where
1272 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1273{
1274 for receiver_id in receiver_ids {
1275 if Some(receiver_id) != sender_id {
1276 if let Err(error) = f(receiver_id) {
1277 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1278 }
1279 }
1280 }
1281}
1282
1283pub struct ProtocolVersion(u32);
1284
1285impl Header for ProtocolVersion {
1286 fn name() -> &'static HeaderName {
1287 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1288 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1289 }
1290
1291 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1292 where
1293 Self: Sized,
1294 I: Iterator<Item = &'i axum::http::HeaderValue>,
1295 {
1296 let version = values
1297 .next()
1298 .ok_or_else(axum::headers::Error::invalid)?
1299 .to_str()
1300 .map_err(|_| axum::headers::Error::invalid())?
1301 .parse()
1302 .map_err(|_| axum::headers::Error::invalid())?;
1303 Ok(Self(version))
1304 }
1305
1306 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1307 values.extend([self.0.to_string().parse().unwrap()]);
1308 }
1309}
1310
1311pub struct AppVersionHeader(SemanticVersion);
1312impl Header for AppVersionHeader {
1313 fn name() -> &'static HeaderName {
1314 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1315 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1316 }
1317
1318 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1319 where
1320 Self: Sized,
1321 I: Iterator<Item = &'i axum::http::HeaderValue>,
1322 {
1323 let version = values
1324 .next()
1325 .ok_or_else(axum::headers::Error::invalid)?
1326 .to_str()
1327 .map_err(|_| axum::headers::Error::invalid())?
1328 .parse()
1329 .map_err(|_| axum::headers::Error::invalid())?;
1330 Ok(Self(version))
1331 }
1332
1333 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1334 values.extend([self.0.to_string().parse().unwrap()]);
1335 }
1336}
1337
/// Builds the collab server's HTTP routes.
///
/// * `GET /rpc`: WebSocket RPC endpoint. It sits under the `ServiceBuilder` layer, which
///   provides the `AppState` extension and runs [`auth::validate_header`] as middleware.
/// * `GET /metrics`: Prometheus metrics. It is registered after that layer. Axum's
///   `Router::layer` only wraps routes that already exist, so `/metrics` is not behind
///   `auth::validate_header`.
///
/// The final `Extension(server)` layer wraps both routes.
pub fn routes(server: Arc<Server>) -> Router<(), Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        // Applies only to the routes registered above (i.e. `/rpc`).
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
1349
1350pub async fn handle_websocket_request(
1351 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1352 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1353 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1354 Extension(server): Extension<Arc<Server>>,
1355 Extension(principal): Extension<Principal>,
1356 ws: WebSocketUpgrade,
1357) -> axum::response::Response {
1358 if protocol_version != rpc::PROTOCOL_VERSION {
1359 return (
1360 StatusCode::UPGRADE_REQUIRED,
1361 "client must be upgraded".to_string(),
1362 )
1363 .into_response();
1364 }
1365
1366 let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1367 return (
1368 StatusCode::UPGRADE_REQUIRED,
1369 "no version header found".to_string(),
1370 )
1371 .into_response();
1372 };
1373
1374 if !version.can_collaborate() {
1375 return (
1376 StatusCode::UPGRADE_REQUIRED,
1377 "client must be upgraded".to_string(),
1378 )
1379 .into_response();
1380 }
1381
1382 let socket_address = socket_address.to_string();
1383 ws.on_upgrade(move |socket| {
1384 let socket = socket
1385 .map_ok(to_tungstenite_message)
1386 .err_into()
1387 .with(|message| async move { Ok(to_axum_message(message)) });
1388 let connection = Connection::new(Box::pin(socket));
1389 async move {
1390 server
1391 .handle_connection(
1392 connection,
1393 socket_address,
1394 principal,
1395 version,
1396 None,
1397 Executor::Production,
1398 )
1399 .await;
1400 }
1401 })
1402}
1403
1404pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1405 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1406 let connections_metric = CONNECTIONS_METRIC
1407 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1408
1409 let connections = server
1410 .connection_pool
1411 .lock()
1412 .connections()
1413 .filter(|connection| !connection.admin)
1414 .count();
1415 connections_metric.set(connections as _);
1416
1417 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1418 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1419 register_int_gauge!(
1420 "shared_projects",
1421 "number of open projects with one or more guests"
1422 )
1423 .unwrap()
1424 });
1425
1426 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1427 shared_projects_metric.set(shared_projects as _);
1428
1429 let encoder = prometheus::TextEncoder::new();
1430 let metric_families = prometheus::gather();
1431 let encoded_metrics = encoder
1432 .encode_to_string(&metric_families)
1433 .map_err(|err| anyhow!("{}", err))?;
1434 Ok(encoded_metrics)
1435}
1436
/// Cleans up after a connection has gone away.
///
/// Immediately disconnects the peer, removes the connection from the pool, and reports the
/// loss to the database. It then waits for `RECONNECT_TIMEOUT`. If `teardown` changes first,
/// nothing else is done. Otherwise the principal's remaining state is released:
///
/// * Users: the room and channel buffers are left. If the user has no other online
///   connection, a pending call is declined. Contacts are then refreshed.
/// * Dev servers: cleanup is delegated to `lost_dev_server_connection`.
///
/// # Errors
/// Returns an error if removing the connection from the pool fails, or if contact refresh or
/// dev-server cleanup fails in the timeout branch. Errors from the other cleanup steps go
/// through `trace_err()` instead of being propagated.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` gives the timeout branch priority when both branches are ready.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            match &session.principal {
                Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
                    // The principal was just matched as a user, so `for_user` is expected to
                    // succeed here.
                    let session = session.for_user().unwrap();

                    log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
                    leave_room_for_session(&session, session.connection_id).await.trace_err();
                    leave_channel_buffers_for_session(&session)
                        .await
                        .trace_err();

                    // Only decline a pending call if no other connection of this user is online.
                    if !session
                        .connection_pool()
                        .await
                        .is_user_online(session.user_id())
                    {
                        let db = session.db().await;
                        if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                            room_updated(&room, &session.peer);
                        }
                    }

                    update_user_contacts(session.user_id(), &session).await?;
                },
                Principal::DevServer(_) => {
                    lost_dev_server_connection(&session.for_dev_server().unwrap()).await?;
                },
            }
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1491
1492/// Acknowledges a ping from a client, used to keep the connection alive.
1493async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1494 response.send(proto::Ack {})?;
1495 Ok(())
1496}
1497
1498/// Creates a new room for calling (outside of channels)
1499async fn create_room(
1500 _request: proto::CreateRoom,
1501 response: Response<proto::CreateRoom>,
1502 session: UserSession,
1503) -> Result<()> {
1504 let live_kit_room = nanoid::nanoid!(30);
1505
1506 let live_kit_connection_info = util::maybe!(async {
1507 let live_kit = session.live_kit_client.as_ref();
1508 let live_kit = live_kit?;
1509 let user_id = session.user_id().to_string();
1510
1511 let token = live_kit
1512 .room_token(&live_kit_room, &user_id.to_string())
1513 .trace_err()?;
1514
1515 Some(proto::LiveKitConnectionInfo {
1516 server_url: live_kit.url().into(),
1517 token,
1518 can_publish: true,
1519 })
1520 })
1521 .await;
1522
1523 let room = session
1524 .db()
1525 .await
1526 .create_room(session.user_id(), session.connection_id, &live_kit_room)
1527 .await?;
1528
1529 response.send(proto::CreateRoomResponse {
1530 room: Some(room.clone()),
1531 live_kit_connection_info,
1532 })?;
1533
1534 update_user_contacts(session.user_id(), &session).await?;
1535 Ok(())
1536}
1537
1538/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1539async fn join_room(
1540 request: proto::JoinRoom,
1541 response: Response<proto::JoinRoom>,
1542 session: UserSession,
1543) -> Result<()> {
1544 let room_id = RoomId::from_proto(request.id);
1545
1546 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1547
1548 if let Some(channel_id) = channel_id {
1549 return join_channel_internal(channel_id, Box::new(response), session).await;
1550 }
1551
1552 let joined_room = {
1553 let room = session
1554 .db()
1555 .await
1556 .join_room(room_id, session.user_id(), session.connection_id)
1557 .await?;
1558 room_updated(&room.room, &session.peer);
1559 room.into_inner()
1560 };
1561
1562 for connection_id in session
1563 .connection_pool()
1564 .await
1565 .user_connection_ids(session.user_id())
1566 {
1567 session
1568 .peer
1569 .send(
1570 connection_id,
1571 proto::CallCanceled {
1572 room_id: room_id.to_proto(),
1573 },
1574 )
1575 .trace_err();
1576 }
1577
1578 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1579 if let Some(token) = live_kit
1580 .room_token(
1581 &joined_room.room.live_kit_room,
1582 &session.user_id().to_string(),
1583 )
1584 .trace_err()
1585 {
1586 Some(proto::LiveKitConnectionInfo {
1587 server_url: live_kit.url().into(),
1588 token,
1589 can_publish: true,
1590 })
1591 } else {
1592 None
1593 }
1594 } else {
1595 None
1596 };
1597
1598 response.send(proto::JoinRoomResponse {
1599 room: Some(joined_room.room),
1600 channel_id: None,
1601 live_kit_connection_info,
1602 })?;
1603
1604 update_user_contacts(session.user_id(), &session).await?;
1605 Ok(())
1606}
1607
/// Rejoins a room after a connection error, using the state described in `request`.
///
/// Steps, in order:
///
/// * The caller is sent the room, the projects it reshares (with their collaborators), and the
///   projects it rejoins.
/// * The room update is broadcast.
/// * Collaborators of each reshared project learn the caller's new peer id and receive the
///   project's worktree list.
/// * Rejoined projects are streamed to the caller by `notify_rejoined_projects`.
/// * The room's channel (if any) is updated.
/// * The caller's contacts are refreshed.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: UserSession,
) -> Result<()> {
    let room;
    let channel;
    // Everything that needs the `rejoin_room` result happens inside this block; only the room
    // and channel escape it.
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell each collaborator that the caller's peer id changed.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the reshared project's worktrees to every collaborator except the caller.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1699
1700fn notify_rejoined_projects(
1701 rejoined_projects: &mut Vec<RejoinedProject>,
1702 session: &UserSession,
1703) -> Result<()> {
1704 for project in rejoined_projects.iter() {
1705 for collaborator in &project.collaborators {
1706 session
1707 .peer
1708 .send(
1709 collaborator.connection_id,
1710 proto::UpdateProjectCollaborator {
1711 project_id: project.id.to_proto(),
1712 old_peer_id: Some(project.old_connection_id.into()),
1713 new_peer_id: Some(session.connection_id.into()),
1714 },
1715 )
1716 .trace_err();
1717 }
1718 }
1719
1720 for project in rejoined_projects {
1721 for worktree in mem::take(&mut project.worktrees) {
1722 #[cfg(any(test, feature = "test-support"))]
1723 const MAX_CHUNK_SIZE: usize = 2;
1724 #[cfg(not(any(test, feature = "test-support")))]
1725 const MAX_CHUNK_SIZE: usize = 256;
1726
1727 // Stream this worktree's entries.
1728 let message = proto::UpdateWorktree {
1729 project_id: project.id.to_proto(),
1730 worktree_id: worktree.id,
1731 abs_path: worktree.abs_path.clone(),
1732 root_name: worktree.root_name,
1733 updated_entries: worktree.updated_entries,
1734 removed_entries: worktree.removed_entries,
1735 scan_id: worktree.scan_id,
1736 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1737 updated_repositories: worktree.updated_repositories,
1738 removed_repositories: worktree.removed_repositories,
1739 };
1740 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1741 session.peer.send(session.connection_id, update.clone())?;
1742 }
1743
1744 // Stream this worktree's diagnostics.
1745 for summary in worktree.diagnostic_summaries {
1746 session.peer.send(
1747 session.connection_id,
1748 proto::UpdateDiagnosticSummary {
1749 project_id: project.id.to_proto(),
1750 worktree_id: worktree.id,
1751 summary: Some(summary),
1752 },
1753 )?;
1754 }
1755
1756 for settings_file in worktree.settings_files {
1757 session.peer.send(
1758 session.connection_id,
1759 proto::UpdateWorktreeSettings {
1760 project_id: project.id.to_proto(),
1761 worktree_id: worktree.id,
1762 path: settings_file.path,
1763 content: Some(settings_file.content),
1764 },
1765 )?;
1766 }
1767 }
1768
1769 for language_server in &project.language_servers {
1770 session.peer.send(
1771 session.connection_id,
1772 proto::UpdateLanguageServer {
1773 project_id: project.id.to_proto(),
1774 language_server_id: language_server.id,
1775 variant: Some(
1776 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1777 proto::LspDiskBasedDiagnosticsUpdated {},
1778 ),
1779 ),
1780 },
1781 )?;
1782 }
1783 }
1784 Ok(())
1785}
1786
1787/// leave room disconnects from the room.
1788async fn leave_room(
1789 _: proto::LeaveRoom,
1790 response: Response<proto::LeaveRoom>,
1791 session: UserSession,
1792) -> Result<()> {
1793 leave_room_for_session(&session, session.connection_id).await?;
1794 response.send(proto::Ack {})?;
1795 Ok(())
1796}
1797
1798/// Updates the permissions of someone else in the room.
1799async fn set_room_participant_role(
1800 request: proto::SetRoomParticipantRole,
1801 response: Response<proto::SetRoomParticipantRole>,
1802 session: UserSession,
1803) -> Result<()> {
1804 let user_id = UserId::from_proto(request.user_id);
1805 let role = ChannelRole::from(request.role());
1806
1807 let (live_kit_room, can_publish) = {
1808 let room = session
1809 .db()
1810 .await
1811 .set_room_participant_role(
1812 session.user_id(),
1813 RoomId::from_proto(request.room_id),
1814 user_id,
1815 role,
1816 )
1817 .await?;
1818
1819 let live_kit_room = room.live_kit_room.clone();
1820 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1821 room_updated(&room, &session.peer);
1822 (live_kit_room, can_publish)
1823 };
1824
1825 if let Some(live_kit) = session.live_kit_client.as_ref() {
1826 live_kit
1827 .update_participant(
1828 live_kit_room.clone(),
1829 request.user_id.to_string(),
1830 live_kit_server::proto::ParticipantPermission {
1831 can_subscribe: true,
1832 can_publish,
1833 can_publish_data: can_publish,
1834 hidden: false,
1835 recorder: false,
1836 },
1837 )
1838 .await
1839 .trace_err();
1840 }
1841
1842 response.send(proto::Ack {})?;
1843 Ok(())
1844}
1845
/// Calls another user into the caller's room.
///
/// The callee must be one of the caller's contacts. The call is recorded in the database, the
/// room update is broadcast, and the callee's contacts are refreshed. Every connection of the
/// callee is then sent the incoming-call request concurrently; the first successful response
/// completes this request with an `Ack`.
///
/// If no connection responds successfully (including when the callee has none), the call is
/// marked as failed, the room and the callee's contacts are updated again, and an error is
/// returned.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: UserSession,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    // Only contacts may be called.
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    // Record the call. The incoming-call message is taken out of the result (leaving a default
    // in its place) so it can be used after this block.
    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(&room, &session.peer);
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring every connection of the callee concurrently. The pool guard is a temporary of this
    // statement, so it is released before any response is awaited.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    // The first successful response completes the call; failed responses are only logged.
    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    // Nobody answered successfully: record the failure and notify.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1914
1915/// Cancel an outgoing call.
1916async fn cancel_call(
1917 request: proto::CancelCall,
1918 response: Response<proto::CancelCall>,
1919 session: UserSession,
1920) -> Result<()> {
1921 let called_user_id = UserId::from_proto(request.called_user_id);
1922 let room_id = RoomId::from_proto(request.room_id);
1923 {
1924 let room = session
1925 .db()
1926 .await
1927 .cancel_call(room_id, session.connection_id, called_user_id)
1928 .await?;
1929 room_updated(&room, &session.peer);
1930 }
1931
1932 for connection_id in session
1933 .connection_pool()
1934 .await
1935 .user_connection_ids(called_user_id)
1936 {
1937 session
1938 .peer
1939 .send(
1940 connection_id,
1941 proto::CallCanceled {
1942 room_id: room_id.to_proto(),
1943 },
1944 )
1945 .trace_err();
1946 }
1947 response.send(proto::Ack {})?;
1948
1949 update_user_contacts(called_user_id, &session).await?;
1950 Ok(())
1951}
1952
1953/// Decline an incoming call.
1954async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1955 let room_id = RoomId::from_proto(message.room_id);
1956 {
1957 let room = session
1958 .db()
1959 .await
1960 .decline_call(Some(room_id), session.user_id())
1961 .await?
1962 .ok_or_else(|| anyhow!("failed to decline call"))?;
1963 room_updated(&room, &session.peer);
1964 }
1965
1966 for connection_id in session
1967 .connection_pool()
1968 .await
1969 .user_connection_ids(session.user_id())
1970 {
1971 session
1972 .peer
1973 .send(
1974 connection_id,
1975 proto::CallCanceled {
1976 room_id: room_id.to_proto(),
1977 },
1978 )
1979 .trace_err();
1980 }
1981 update_user_contacts(session.user_id(), &session).await?;
1982 Ok(())
1983}
1984
1985/// Updates other participants in the room with your current location.
1986async fn update_participant_location(
1987 request: proto::UpdateParticipantLocation,
1988 response: Response<proto::UpdateParticipantLocation>,
1989 session: UserSession,
1990) -> Result<()> {
1991 let room_id = RoomId::from_proto(request.room_id);
1992 let location = request
1993 .location
1994 .ok_or_else(|| anyhow!("invalid location"))?;
1995
1996 let db = session.db().await;
1997 let room = db
1998 .update_room_participant_location(room_id, session.connection_id, location)
1999 .await?;
2000
2001 room_updated(&room, &session.peer);
2002 response.send(proto::Ack {})?;
2003 Ok(())
2004}
2005
2006/// Share a project into the room.
2007async fn share_project(
2008 request: proto::ShareProject,
2009 response: Response<proto::ShareProject>,
2010 session: UserSession,
2011) -> Result<()> {
2012 let (project_id, room) = &*session
2013 .db()
2014 .await
2015 .share_project(
2016 RoomId::from_proto(request.room_id),
2017 session.connection_id,
2018 &request.worktrees,
2019 request
2020 .dev_server_project_id
2021 .map(|id| DevServerProjectId::from_proto(id)),
2022 )
2023 .await?;
2024 response.send(proto::ShareProjectResponse {
2025 project_id: project_id.to_proto(),
2026 })?;
2027 room_updated(&room, &session.peer);
2028
2029 Ok(())
2030}
2031
2032/// Unshare a project from the room.
2033async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
2034 let project_id = ProjectId::from_proto(message.project_id);
2035 unshare_project_internal(
2036 project_id,
2037 session.connection_id,
2038 session.user_id(),
2039 &session,
2040 )
2041 .await
2042}
2043
/// Unshares `project_id` on behalf of `connection_id` (and `user_id`, when known).
///
/// Sends `UnshareProject` to every guest connection other than `connection_id`. If the project
/// belonged to a room, the room update is broadcast. The project is deleted from the database
/// when `unshare_project` reports that it should be.
///
/// # Errors
/// Propagates database errors from unsharing or deleting the project. Per-guest send failures
/// are only logged by `broadcast`.
async fn unshare_project_internal(
    project_id: ProjectId,
    connection_id: ConnectionId,
    user_id: Option<UserId>,
    session: &Session,
) -> Result<()> {
    // The `unshare_project` result guard is confined to this block, so it is dropped before
    // the database is locked again for the delete below.
    let delete = {
        let room_guard = session
            .db()
            .await
            .unshare_project(project_id, connection_id, user_id)
            .await?;

        let (delete, room, guest_connection_ids) = &*room_guard;

        let message = proto::UnshareProject {
            project_id: project_id.to_proto(),
        };

        broadcast(
            Some(connection_id),
            guest_connection_ids.iter().copied(),
            |conn_id| session.peer.send(conn_id, message.clone()),
        );
        if let Some(room) = room {
            room_updated(room, &session.peer);
        }

        *delete
    };

    if delete {
        let db = session.db().await;
        db.delete_project(project_id).await?;
    }

    Ok(())
}
2082
2083/// DevServer makes a project available online
2084async fn share_dev_server_project(
2085 request: proto::ShareDevServerProject,
2086 response: Response<proto::ShareDevServerProject>,
2087 session: DevServerSession,
2088) -> Result<()> {
2089 let (dev_server_project, user_id, status) = session
2090 .db()
2091 .await
2092 .share_dev_server_project(
2093 DevServerProjectId::from_proto(request.dev_server_project_id),
2094 session.dev_server_id(),
2095 session.connection_id,
2096 &request.worktrees,
2097 )
2098 .await?;
2099 let Some(project_id) = dev_server_project.project_id else {
2100 return Err(anyhow!("failed to share remote project"))?;
2101 };
2102
2103 send_dev_server_projects_update(user_id, status, &session).await;
2104
2105 response.send(proto::ShareProjectResponse { project_id })?;
2106
2107 Ok(())
2108}
2109
/// Joins someone else's shared project.
///
/// Adds the caller's connection to the project in the database, which yields the project and
/// the caller's replica id. It then releases the database handle and hands off to
/// `join_project_internal`, which notifies existing collaborators and sends the project state
/// to the caller.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: UserSession,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    // `project` and `replica_id` borrow from the `join_project` result, not from `db`, so the
    // database handle can be dropped explicitly before continuing.
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    drop(db);
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
2128
/// A response handle that can answer with a `JoinProjectResponse`.
///
/// This lets `join_project_internal` serve both `JoinProject` and `JoinHostedProject` requests.
trait JoinProjectInternalResponse {
    /// Sends `result` as the reply to the originating request.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
// Both impls call the inherent `Response::send` through a fully-qualified path, making it
// explicit that they forward to the inherent method rather than to this trait's `send`.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinProject>::send(self, result)
    }
}
impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinHostedProject>::send(self, result)
    }
}
2142
2143fn join_project_internal(
2144 response: impl JoinProjectInternalResponse,
2145 session: UserSession,
2146 project: &mut Project,
2147 replica_id: &ReplicaId,
2148) -> Result<()> {
2149 let collaborators = project
2150 .collaborators
2151 .iter()
2152 .filter(|collaborator| collaborator.connection_id != session.connection_id)
2153 .map(|collaborator| collaborator.to_proto())
2154 .collect::<Vec<_>>();
2155 let project_id = project.id;
2156 let guest_user_id = session.user_id();
2157
2158 let worktrees = project
2159 .worktrees
2160 .iter()
2161 .map(|(id, worktree)| proto::WorktreeMetadata {
2162 id: *id,
2163 root_name: worktree.root_name.clone(),
2164 visible: worktree.visible,
2165 abs_path: worktree.abs_path.clone(),
2166 })
2167 .collect::<Vec<_>>();
2168
2169 let add_project_collaborator = proto::AddProjectCollaborator {
2170 project_id: project_id.to_proto(),
2171 collaborator: Some(proto::Collaborator {
2172 peer_id: Some(session.connection_id.into()),
2173 replica_id: replica_id.0 as u32,
2174 user_id: guest_user_id.to_proto(),
2175 }),
2176 };
2177
2178 for collaborator in &collaborators {
2179 session
2180 .peer
2181 .send(
2182 collaborator.peer_id.unwrap().into(),
2183 add_project_collaborator.clone(),
2184 )
2185 .trace_err();
2186 }
2187
2188 // First, we send the metadata associated with each worktree.
2189 response.send(proto::JoinProjectResponse {
2190 project_id: project.id.0 as u64,
2191 worktrees: worktrees.clone(),
2192 replica_id: replica_id.0 as u32,
2193 collaborators: collaborators.clone(),
2194 language_servers: project.language_servers.clone(),
2195 role: project.role.into(),
2196 dev_server_project_id: project
2197 .dev_server_project_id
2198 .map(|dev_server_project_id| dev_server_project_id.0 as u64),
2199 })?;
2200
2201 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2202 #[cfg(any(test, feature = "test-support"))]
2203 const MAX_CHUNK_SIZE: usize = 2;
2204 #[cfg(not(any(test, feature = "test-support")))]
2205 const MAX_CHUNK_SIZE: usize = 256;
2206
2207 // Stream this worktree's entries.
2208 let message = proto::UpdateWorktree {
2209 project_id: project_id.to_proto(),
2210 worktree_id,
2211 abs_path: worktree.abs_path.clone(),
2212 root_name: worktree.root_name,
2213 updated_entries: worktree.entries,
2214 removed_entries: Default::default(),
2215 scan_id: worktree.scan_id,
2216 is_last_update: worktree.scan_id == worktree.completed_scan_id,
2217 updated_repositories: worktree.repository_entries.into_values().collect(),
2218 removed_repositories: Default::default(),
2219 };
2220 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
2221 session.peer.send(session.connection_id, update.clone())?;
2222 }
2223
2224 // Stream this worktree's diagnostics.
2225 for summary in worktree.diagnostic_summaries {
2226 session.peer.send(
2227 session.connection_id,
2228 proto::UpdateDiagnosticSummary {
2229 project_id: project_id.to_proto(),
2230 worktree_id: worktree.id,
2231 summary: Some(summary),
2232 },
2233 )?;
2234 }
2235
2236 for settings_file in worktree.settings_files {
2237 session.peer.send(
2238 session.connection_id,
2239 proto::UpdateWorktreeSettings {
2240 project_id: project_id.to_proto(),
2241 worktree_id: worktree.id,
2242 path: settings_file.path,
2243 content: Some(settings_file.content),
2244 },
2245 )?;
2246 }
2247 }
2248
2249 for language_server in &project.language_servers {
2250 session.peer.send(
2251 session.connection_id,
2252 proto::UpdateLanguageServer {
2253 project_id: project_id.to_proto(),
2254 language_server_id: language_server.id,
2255 variant: Some(
2256 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2257 proto::LspDiskBasedDiagnosticsUpdated {},
2258 ),
2259 ),
2260 },
2261 )?;
2262 }
2263
2264 Ok(())
2265}
2266
2267/// Leave someone elses shared project.
2268async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2269 let sender_id = session.connection_id;
2270 let project_id = ProjectId::from_proto(request.project_id);
2271 let db = session.db().await;
2272 if db.is_hosted_project(project_id).await? {
2273 let project = db.leave_hosted_project(project_id, sender_id).await?;
2274 project_left(&project, &session);
2275 return Ok(());
2276 }
2277
2278 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2279 tracing::info!(
2280 %project_id,
2281 "leave project"
2282 );
2283
2284 project_left(&project, &session);
2285 if let Some(room) = room {
2286 room_updated(&room, &session.peer);
2287 }
2288
2289 Ok(())
2290}
2291
2292async fn join_hosted_project(
2293 request: proto::JoinHostedProject,
2294 response: Response<proto::JoinHostedProject>,
2295 session: UserSession,
2296) -> Result<()> {
2297 let (mut project, replica_id) = session
2298 .db()
2299 .await
2300 .join_hosted_project(
2301 ProjectId(request.project_id as i32),
2302 session.user_id(),
2303 session.connection_id,
2304 )
2305 .await?;
2306
2307 join_project_internal(response, session, &mut project, &replica_id)
2308}
2309
2310async fn create_dev_server_project(
2311 request: proto::CreateDevServerProject,
2312 response: Response<proto::CreateDevServerProject>,
2313 session: UserSession,
2314) -> Result<()> {
2315 let dev_server_id = DevServerId(request.dev_server_id as i32);
2316 let dev_server_connection_id = session
2317 .connection_pool()
2318 .await
2319 .dev_server_connection_id(dev_server_id);
2320 let Some(dev_server_connection_id) = dev_server_connection_id else {
2321 Err(ErrorCode::DevServerOffline
2322 .message("Cannot create a remote project when the dev server is offline".to_string())
2323 .anyhow())?
2324 };
2325
2326 let path = request.path.clone();
2327 //Check that the path exists on the dev server
2328 session
2329 .peer
2330 .forward_request(
2331 session.connection_id,
2332 dev_server_connection_id,
2333 proto::ValidateDevServerProjectRequest { path: path.clone() },
2334 )
2335 .await?;
2336
2337 let (dev_server_project, update) = session
2338 .db()
2339 .await
2340 .create_dev_server_project(
2341 DevServerId(request.dev_server_id as i32),
2342 &request.path,
2343 session.user_id(),
2344 )
2345 .await?;
2346
2347 let projects = session
2348 .db()
2349 .await
2350 .get_projects_for_dev_server(dev_server_project.dev_server_id)
2351 .await?;
2352
2353 session.peer.send(
2354 dev_server_connection_id,
2355 proto::DevServerInstructions { projects },
2356 )?;
2357
2358 send_dev_server_projects_update(session.user_id(), update, &session).await;
2359
2360 response.send(proto::CreateDevServerProjectResponse {
2361 dev_server_project: Some(dev_server_project.to_proto(None)),
2362 })?;
2363 Ok(())
2364}
2365
2366async fn create_dev_server(
2367 request: proto::CreateDevServer,
2368 response: Response<proto::CreateDevServer>,
2369 session: UserSession,
2370) -> Result<()> {
2371 let access_token = auth::random_token();
2372 let hashed_access_token = auth::hash_access_token(&access_token);
2373
2374 if request.name.is_empty() {
2375 return Err(proto::ErrorCode::Forbidden
2376 .message("Dev server name cannot be empty".to_string())
2377 .anyhow())?;
2378 }
2379
2380 let (dev_server, status) = session
2381 .db()
2382 .await
2383 .create_dev_server(
2384 &request.name,
2385 request.ssh_connection_string.as_deref(),
2386 &hashed_access_token,
2387 session.user_id(),
2388 )
2389 .await?;
2390
2391 send_dev_server_projects_update(session.user_id(), status, &session).await;
2392
2393 response.send(proto::CreateDevServerResponse {
2394 dev_server_id: dev_server.id.0 as u64,
2395 access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2396 name: request.name,
2397 })?;
2398 Ok(())
2399}
2400
2401async fn regenerate_dev_server_token(
2402 request: proto::RegenerateDevServerToken,
2403 response: Response<proto::RegenerateDevServerToken>,
2404 session: UserSession,
2405) -> Result<()> {
2406 let dev_server_id = DevServerId(request.dev_server_id as i32);
2407 let access_token = auth::random_token();
2408 let hashed_access_token = auth::hash_access_token(&access_token);
2409
2410 let connection_id = session
2411 .connection_pool()
2412 .await
2413 .dev_server_connection_id(dev_server_id);
2414 if let Some(connection_id) = connection_id {
2415 shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2416 session.peer.send(
2417 connection_id,
2418 proto::ShutdownDevServer {
2419 reason: Some("dev server token was regenerated".to_string()),
2420 },
2421 )?;
2422 let _ = remove_dev_server_connection(dev_server_id, &session).await;
2423 }
2424
2425 let status = session
2426 .db()
2427 .await
2428 .update_dev_server_token(dev_server_id, &hashed_access_token, session.user_id())
2429 .await?;
2430
2431 send_dev_server_projects_update(session.user_id(), status, &session).await;
2432
2433 response.send(proto::RegenerateDevServerTokenResponse {
2434 dev_server_id: dev_server_id.to_proto(),
2435 access_token: auth::generate_dev_server_token(dev_server_id.0 as usize, access_token),
2436 })?;
2437 Ok(())
2438}
2439
2440async fn rename_dev_server(
2441 request: proto::RenameDevServer,
2442 response: Response<proto::RenameDevServer>,
2443 session: UserSession,
2444) -> Result<()> {
2445 if request.name.trim().is_empty() {
2446 return Err(proto::ErrorCode::Forbidden
2447 .message("Dev server name cannot be empty".to_string())
2448 .anyhow())?;
2449 }
2450
2451 let dev_server_id = DevServerId(request.dev_server_id as i32);
2452 let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2453 if dev_server.user_id != session.user_id() {
2454 return Err(anyhow!(ErrorCode::Forbidden))?;
2455 }
2456
2457 let status = session
2458 .db()
2459 .await
2460 .rename_dev_server(
2461 dev_server_id,
2462 &request.name,
2463 request.ssh_connection_string.as_deref(),
2464 session.user_id(),
2465 )
2466 .await?;
2467
2468 send_dev_server_projects_update(session.user_id(), status, &session).await;
2469
2470 response.send(proto::Ack {})?;
2471 Ok(())
2472}
2473
2474async fn delete_dev_server(
2475 request: proto::DeleteDevServer,
2476 response: Response<proto::DeleteDevServer>,
2477 session: UserSession,
2478) -> Result<()> {
2479 let dev_server_id = DevServerId(request.dev_server_id as i32);
2480 let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2481 if dev_server.user_id != session.user_id() {
2482 return Err(anyhow!(ErrorCode::Forbidden))?;
2483 }
2484
2485 let connection_id = session
2486 .connection_pool()
2487 .await
2488 .dev_server_connection_id(dev_server_id);
2489 if let Some(connection_id) = connection_id {
2490 shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2491 session.peer.send(
2492 connection_id,
2493 proto::ShutdownDevServer {
2494 reason: Some("dev server was deleted".to_string()),
2495 },
2496 )?;
2497 let _ = remove_dev_server_connection(dev_server_id, &session).await;
2498 }
2499
2500 let status = session
2501 .db()
2502 .await
2503 .delete_dev_server(dev_server_id, session.user_id())
2504 .await?;
2505
2506 send_dev_server_projects_update(session.user_id(), status, &session).await;
2507
2508 response.send(proto::Ack {})?;
2509 Ok(())
2510}
2511
2512async fn delete_dev_server_project(
2513 request: proto::DeleteDevServerProject,
2514 response: Response<proto::DeleteDevServerProject>,
2515 session: UserSession,
2516) -> Result<()> {
2517 let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2518 let dev_server_project = session
2519 .db()
2520 .await
2521 .get_dev_server_project(dev_server_project_id)
2522 .await?;
2523
2524 let dev_server = session
2525 .db()
2526 .await
2527 .get_dev_server(dev_server_project.dev_server_id)
2528 .await?;
2529 if dev_server.user_id != session.user_id() {
2530 return Err(anyhow!(ErrorCode::Forbidden))?;
2531 }
2532
2533 let dev_server_connection_id = session
2534 .connection_pool()
2535 .await
2536 .dev_server_connection_id(dev_server.id);
2537
2538 if let Some(dev_server_connection_id) = dev_server_connection_id {
2539 let project = session
2540 .db()
2541 .await
2542 .find_dev_server_project(dev_server_project_id)
2543 .await;
2544 if let Ok(project) = project {
2545 unshare_project_internal(
2546 project.id,
2547 dev_server_connection_id,
2548 Some(session.user_id()),
2549 &session,
2550 )
2551 .await?;
2552 }
2553 }
2554
2555 let (projects, status) = session
2556 .db()
2557 .await
2558 .delete_dev_server_project(dev_server_project_id, dev_server.id, session.user_id())
2559 .await?;
2560
2561 if let Some(dev_server_connection_id) = dev_server_connection_id {
2562 session.peer.send(
2563 dev_server_connection_id,
2564 proto::DevServerInstructions { projects },
2565 )?;
2566 }
2567
2568 send_dev_server_projects_update(session.user_id(), status, &session).await;
2569
2570 response.send(proto::Ack {})?;
2571 Ok(())
2572}
2573
2574async fn rejoin_dev_server_projects(
2575 request: proto::RejoinRemoteProjects,
2576 response: Response<proto::RejoinRemoteProjects>,
2577 session: UserSession,
2578) -> Result<()> {
2579 let mut rejoined_projects = {
2580 let db = session.db().await;
2581 db.rejoin_dev_server_projects(
2582 &request.rejoined_projects,
2583 session.user_id(),
2584 session.0.connection_id,
2585 )
2586 .await?
2587 };
2588 notify_rejoined_projects(&mut rejoined_projects, &session)?;
2589
2590 response.send(proto::RejoinRemoteProjectsResponse {
2591 rejoined_projects: rejoined_projects
2592 .into_iter()
2593 .map(|project| project.to_proto())
2594 .collect(),
2595 })
2596}
2597
2598async fn reconnect_dev_server(
2599 request: proto::ReconnectDevServer,
2600 response: Response<proto::ReconnectDevServer>,
2601 session: DevServerSession,
2602) -> Result<()> {
2603 let reshared_projects = {
2604 let db = session.db().await;
2605 db.reshare_dev_server_projects(
2606 &request.reshared_projects,
2607 session.dev_server_id(),
2608 session.0.connection_id,
2609 )
2610 .await?
2611 };
2612
2613 for project in &reshared_projects {
2614 for collaborator in &project.collaborators {
2615 session
2616 .peer
2617 .send(
2618 collaborator.connection_id,
2619 proto::UpdateProjectCollaborator {
2620 project_id: project.id.to_proto(),
2621 old_peer_id: Some(project.old_connection_id.into()),
2622 new_peer_id: Some(session.connection_id.into()),
2623 },
2624 )
2625 .trace_err();
2626 }
2627
2628 broadcast(
2629 Some(session.connection_id),
2630 project
2631 .collaborators
2632 .iter()
2633 .map(|collaborator| collaborator.connection_id),
2634 |connection_id| {
2635 session.peer.forward_send(
2636 session.connection_id,
2637 connection_id,
2638 proto::UpdateProject {
2639 project_id: project.id.to_proto(),
2640 worktrees: project.worktrees.clone(),
2641 },
2642 )
2643 },
2644 );
2645 }
2646
2647 response.send(proto::ReconnectDevServerResponse {
2648 reshared_projects: reshared_projects
2649 .iter()
2650 .map(|project| proto::ResharedProject {
2651 id: project.id.to_proto(),
2652 collaborators: project
2653 .collaborators
2654 .iter()
2655 .map(|collaborator| collaborator.to_proto())
2656 .collect(),
2657 })
2658 .collect(),
2659 })?;
2660
2661 Ok(())
2662}
2663
2664async fn shutdown_dev_server(
2665 _: proto::ShutdownDevServer,
2666 response: Response<proto::ShutdownDevServer>,
2667 session: DevServerSession,
2668) -> Result<()> {
2669 response.send(proto::Ack {})?;
2670 shutdown_dev_server_internal(session.dev_server_id(), session.connection_id, &session).await?;
2671 remove_dev_server_connection(session.dev_server_id(), &session).await
2672}
2673
2674async fn shutdown_dev_server_internal(
2675 dev_server_id: DevServerId,
2676 connection_id: ConnectionId,
2677 session: &Session,
2678) -> Result<()> {
2679 let (dev_server_projects, dev_server) = {
2680 let db = session.db().await;
2681 let dev_server_projects = db.get_projects_for_dev_server(dev_server_id).await?;
2682 let dev_server = db.get_dev_server(dev_server_id).await?;
2683 (dev_server_projects, dev_server)
2684 };
2685
2686 for project_id in dev_server_projects.iter().filter_map(|p| p.project_id) {
2687 unshare_project_internal(
2688 ProjectId::from_proto(project_id),
2689 connection_id,
2690 None,
2691 session,
2692 )
2693 .await?;
2694 }
2695
2696 session
2697 .connection_pool()
2698 .await
2699 .set_dev_server_offline(dev_server_id);
2700
2701 let status = session
2702 .db()
2703 .await
2704 .dev_server_projects_update(dev_server.user_id)
2705 .await?;
2706 send_dev_server_projects_update(dev_server.user_id, status, &session).await;
2707
2708 Ok(())
2709}
2710
2711async fn remove_dev_server_connection(dev_server_id: DevServerId, session: &Session) -> Result<()> {
2712 let dev_server_connection = session
2713 .connection_pool()
2714 .await
2715 .dev_server_connection_id(dev_server_id);
2716
2717 if let Some(dev_server_connection) = dev_server_connection {
2718 session
2719 .connection_pool()
2720 .await
2721 .remove_connection(dev_server_connection)?;
2722 }
2723 Ok(())
2724}
2725
2726/// Updates other participants with changes to the project
2727async fn update_project(
2728 request: proto::UpdateProject,
2729 response: Response<proto::UpdateProject>,
2730 session: Session,
2731) -> Result<()> {
2732 let project_id = ProjectId::from_proto(request.project_id);
2733 let (room, guest_connection_ids) = &*session
2734 .db()
2735 .await
2736 .update_project(project_id, session.connection_id, &request.worktrees)
2737 .await?;
2738 broadcast(
2739 Some(session.connection_id),
2740 guest_connection_ids.iter().copied(),
2741 |connection_id| {
2742 session
2743 .peer
2744 .forward_send(session.connection_id, connection_id, request.clone())
2745 },
2746 );
2747 if let Some(room) = room {
2748 room_updated(&room, &session.peer);
2749 }
2750 response.send(proto::Ack {})?;
2751
2752 Ok(())
2753}
2754
2755/// Updates other participants with changes to the worktree
2756async fn update_worktree(
2757 request: proto::UpdateWorktree,
2758 response: Response<proto::UpdateWorktree>,
2759 session: Session,
2760) -> Result<()> {
2761 let guest_connection_ids = session
2762 .db()
2763 .await
2764 .update_worktree(&request, session.connection_id)
2765 .await?;
2766
2767 broadcast(
2768 Some(session.connection_id),
2769 guest_connection_ids.iter().copied(),
2770 |connection_id| {
2771 session
2772 .peer
2773 .forward_send(session.connection_id, connection_id, request.clone())
2774 },
2775 );
2776 response.send(proto::Ack {})?;
2777 Ok(())
2778}
2779
2780/// Updates other participants with changes to the diagnostics
2781async fn update_diagnostic_summary(
2782 message: proto::UpdateDiagnosticSummary,
2783 session: Session,
2784) -> Result<()> {
2785 let guest_connection_ids = session
2786 .db()
2787 .await
2788 .update_diagnostic_summary(&message, session.connection_id)
2789 .await?;
2790
2791 broadcast(
2792 Some(session.connection_id),
2793 guest_connection_ids.iter().copied(),
2794 |connection_id| {
2795 session
2796 .peer
2797 .forward_send(session.connection_id, connection_id, message.clone())
2798 },
2799 );
2800
2801 Ok(())
2802}
2803
2804/// Updates other participants with changes to the worktree settings
2805async fn update_worktree_settings(
2806 message: proto::UpdateWorktreeSettings,
2807 session: Session,
2808) -> Result<()> {
2809 let guest_connection_ids = session
2810 .db()
2811 .await
2812 .update_worktree_settings(&message, session.connection_id)
2813 .await?;
2814
2815 broadcast(
2816 Some(session.connection_id),
2817 guest_connection_ids.iter().copied(),
2818 |connection_id| {
2819 session
2820 .peer
2821 .forward_send(session.connection_id, connection_id, message.clone())
2822 },
2823 );
2824
2825 Ok(())
2826}
2827
2828/// Notify other participants that a language server has started.
2829async fn start_language_server(
2830 request: proto::StartLanguageServer,
2831 session: Session,
2832) -> Result<()> {
2833 let guest_connection_ids = session
2834 .db()
2835 .await
2836 .start_language_server(&request, session.connection_id)
2837 .await?;
2838
2839 broadcast(
2840 Some(session.connection_id),
2841 guest_connection_ids.iter().copied(),
2842 |connection_id| {
2843 session
2844 .peer
2845 .forward_send(session.connection_id, connection_id, request.clone())
2846 },
2847 );
2848 Ok(())
2849}
2850
2851/// Notify other participants that a language server has changed.
2852async fn update_language_server(
2853 request: proto::UpdateLanguageServer,
2854 session: Session,
2855) -> Result<()> {
2856 let project_id = ProjectId::from_proto(request.project_id);
2857 let project_connection_ids = session
2858 .db()
2859 .await
2860 .project_connection_ids(project_id, session.connection_id, true)
2861 .await?;
2862 broadcast(
2863 Some(session.connection_id),
2864 project_connection_ids.iter().copied(),
2865 |connection_id| {
2866 session
2867 .peer
2868 .forward_send(session.connection_id, connection_id, request.clone())
2869 },
2870 );
2871 Ok(())
2872}
2873
2874/// forward a project request to the host. These requests should be read only
2875/// as guests are allowed to send them.
2876async fn forward_read_only_project_request<T>(
2877 request: T,
2878 response: Response<T>,
2879 session: UserSession,
2880) -> Result<()>
2881where
2882 T: EntityMessage + RequestMessage,
2883{
2884 let project_id = ProjectId::from_proto(request.remote_entity_id());
2885 let host_connection_id = session
2886 .db()
2887 .await
2888 .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2889 .await?;
2890 let payload = session
2891 .peer
2892 .forward_request(session.connection_id, host_connection_id, request)
2893 .await?;
2894 response.send(payload)?;
2895 Ok(())
2896}
2897
2898/// forward a project request to the dev server. Only allowed
2899/// if it's your dev server.
2900async fn forward_project_request_for_owner<T>(
2901 request: T,
2902 response: Response<T>,
2903 session: UserSession,
2904) -> Result<()>
2905where
2906 T: EntityMessage + RequestMessage,
2907{
2908 let project_id = ProjectId::from_proto(request.remote_entity_id());
2909
2910 let host_connection_id = session
2911 .db()
2912 .await
2913 .host_for_owner_project_request(project_id, session.connection_id, session.user_id())
2914 .await?;
2915 let payload = session
2916 .peer
2917 .forward_request(session.connection_id, host_connection_id, request)
2918 .await?;
2919 response.send(payload)?;
2920 Ok(())
2921}
2922
2923/// forward a project request to the host. These requests are disallowed
2924/// for guests.
2925async fn forward_mutating_project_request<T>(
2926 request: T,
2927 response: Response<T>,
2928 session: UserSession,
2929) -> Result<()>
2930where
2931 T: EntityMessage + RequestMessage,
2932{
2933 let project_id = ProjectId::from_proto(request.remote_entity_id());
2934
2935 let host_connection_id = session
2936 .db()
2937 .await
2938 .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
2939 .await?;
2940 let payload = session
2941 .peer
2942 .forward_request(session.connection_id, host_connection_id, request)
2943 .await?;
2944 response.send(payload)?;
2945 Ok(())
2946}
2947
2948/// forward a project request to the host. These requests are disallowed
2949/// for guests.
2950async fn forward_versioned_mutating_project_request<T>(
2951 request: T,
2952 response: Response<T>,
2953 session: UserSession,
2954) -> Result<()>
2955where
2956 T: EntityMessage + RequestMessage + VersionedMessage,
2957{
2958 let project_id = ProjectId::from_proto(request.remote_entity_id());
2959
2960 let host_connection_id = session
2961 .db()
2962 .await
2963 .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
2964 .await?;
2965 if let Some(host_version) = session
2966 .connection_pool()
2967 .await
2968 .connection(host_connection_id)
2969 .map(|c| c.zed_version)
2970 {
2971 if let Some(min_required_version) = request.required_host_version() {
2972 if min_required_version > host_version {
2973 return Err(anyhow!(ErrorCode::RemoteUpgradeRequired
2974 .with_tag("required", &min_required_version.to_string())))?;
2975 }
2976 }
2977 }
2978
2979 let payload = session
2980 .peer
2981 .forward_request(session.connection_id, host_connection_id, request)
2982 .await?;
2983 response.send(payload)?;
2984 Ok(())
2985}
2986
2987/// Notify other participants that a new buffer has been created
2988async fn create_buffer_for_peer(
2989 request: proto::CreateBufferForPeer,
2990 session: Session,
2991) -> Result<()> {
2992 session
2993 .db()
2994 .await
2995 .check_user_is_project_host(
2996 ProjectId::from_proto(request.project_id),
2997 session.connection_id,
2998 )
2999 .await?;
3000 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
3001 session
3002 .peer
3003 .forward_send(session.connection_id, peer_id.into(), request)?;
3004 Ok(())
3005}
3006
3007/// Notify other participants that a buffer has been updated. This is
3008/// allowed for guests as long as the update is limited to selections.
3009async fn update_buffer(
3010 request: proto::UpdateBuffer,
3011 response: Response<proto::UpdateBuffer>,
3012 session: Session,
3013) -> Result<()> {
3014 let project_id = ProjectId::from_proto(request.project_id);
3015 let mut capability = Capability::ReadOnly;
3016
3017 for op in request.operations.iter() {
3018 match op.variant {
3019 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
3020 Some(_) => capability = Capability::ReadWrite,
3021 }
3022 }
3023
3024 let host = {
3025 let guard = session
3026 .db()
3027 .await
3028 .connections_for_buffer_update(
3029 project_id,
3030 session.principal_id(),
3031 session.connection_id,
3032 capability,
3033 )
3034 .await?;
3035
3036 let (host, guests) = &*guard;
3037
3038 broadcast(
3039 Some(session.connection_id),
3040 guests.clone(),
3041 |connection_id| {
3042 session
3043 .peer
3044 .forward_send(session.connection_id, connection_id, request.clone())
3045 },
3046 );
3047
3048 *host
3049 };
3050
3051 if host != session.connection_id {
3052 session
3053 .peer
3054 .forward_request(session.connection_id, host, request.clone())
3055 .await?;
3056 }
3057
3058 response.send(proto::Ack {})?;
3059 Ok(())
3060}
3061
3062/// Notify other participants that a project has been updated.
3063async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
3064 request: T,
3065 session: Session,
3066) -> Result<()> {
3067 let project_id = ProjectId::from_proto(request.remote_entity_id());
3068 let project_connection_ids = session
3069 .db()
3070 .await
3071 .project_connection_ids(project_id, session.connection_id, false)
3072 .await?;
3073
3074 broadcast(
3075 Some(session.connection_id),
3076 project_connection_ids.iter().copied(),
3077 |connection_id| {
3078 session
3079 .peer
3080 .forward_send(session.connection_id, connection_id, request.clone())
3081 },
3082 );
3083 Ok(())
3084}
3085
3086/// Start following another user in a call.
3087async fn follow(
3088 request: proto::Follow,
3089 response: Response<proto::Follow>,
3090 session: UserSession,
3091) -> Result<()> {
3092 let room_id = RoomId::from_proto(request.room_id);
3093 let project_id = request.project_id.map(ProjectId::from_proto);
3094 let leader_id = request
3095 .leader_id
3096 .ok_or_else(|| anyhow!("invalid leader id"))?
3097 .into();
3098 let follower_id = session.connection_id;
3099
3100 session
3101 .db()
3102 .await
3103 .check_room_participants(room_id, leader_id, session.connection_id)
3104 .await?;
3105
3106 let response_payload = session
3107 .peer
3108 .forward_request(session.connection_id, leader_id, request)
3109 .await?;
3110 response.send(response_payload)?;
3111
3112 if let Some(project_id) = project_id {
3113 let room = session
3114 .db()
3115 .await
3116 .follow(room_id, project_id, leader_id, follower_id)
3117 .await?;
3118 room_updated(&room, &session.peer);
3119 }
3120
3121 Ok(())
3122}
3123
3124/// Stop following another user in a call.
3125async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
3126 let room_id = RoomId::from_proto(request.room_id);
3127 let project_id = request.project_id.map(ProjectId::from_proto);
3128 let leader_id = request
3129 .leader_id
3130 .ok_or_else(|| anyhow!("invalid leader id"))?
3131 .into();
3132 let follower_id = session.connection_id;
3133
3134 session
3135 .db()
3136 .await
3137 .check_room_participants(room_id, leader_id, session.connection_id)
3138 .await?;
3139
3140 session
3141 .peer
3142 .forward_send(session.connection_id, leader_id, request)?;
3143
3144 if let Some(project_id) = project_id {
3145 let room = session
3146 .db()
3147 .await
3148 .unfollow(room_id, project_id, leader_id, follower_id)
3149 .await?;
3150 room_updated(&room, &session.peer);
3151 }
3152
3153 Ok(())
3154}
3155
3156/// Notify everyone following you of your current location.
3157async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
3158 let room_id = RoomId::from_proto(request.room_id);
3159 let database = session.db.lock().await;
3160
3161 let connection_ids = if let Some(project_id) = request.project_id {
3162 let project_id = ProjectId::from_proto(project_id);
3163 database
3164 .project_connection_ids(project_id, session.connection_id, true)
3165 .await?
3166 } else {
3167 database
3168 .room_connection_ids(room_id, session.connection_id)
3169 .await?
3170 };
3171
3172 // For now, don't send view update messages back to that view's current leader.
3173 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
3174 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
3175 _ => None,
3176 });
3177
3178 for connection_id in connection_ids.iter().cloned() {
3179 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
3180 session
3181 .peer
3182 .forward_send(session.connection_id, connection_id, request.clone())?;
3183 }
3184 }
3185 Ok(())
3186}
3187
3188/// Get public data about users.
3189async fn get_users(
3190 request: proto::GetUsers,
3191 response: Response<proto::GetUsers>,
3192 session: Session,
3193) -> Result<()> {
3194 let user_ids = request
3195 .user_ids
3196 .into_iter()
3197 .map(UserId::from_proto)
3198 .collect();
3199 let users = session
3200 .db()
3201 .await
3202 .get_users_by_ids(user_ids)
3203 .await?
3204 .into_iter()
3205 .map(|user| proto::User {
3206 id: user.id.to_proto(),
3207 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3208 github_login: user.github_login,
3209 })
3210 .collect();
3211 response.send(proto::UsersResponse { users })?;
3212 Ok(())
3213}
3214
3215/// Search for users (to invite) buy Github login
3216async fn fuzzy_search_users(
3217 request: proto::FuzzySearchUsers,
3218 response: Response<proto::FuzzySearchUsers>,
3219 session: UserSession,
3220) -> Result<()> {
3221 let query = request.query;
3222 let users = match query.len() {
3223 0 => vec![],
3224 1 | 2 => session
3225 .db()
3226 .await
3227 .get_user_by_github_login(&query)
3228 .await?
3229 .into_iter()
3230 .collect(),
3231 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
3232 };
3233 let users = users
3234 .into_iter()
3235 .filter(|user| user.id != session.user_id())
3236 .map(|user| proto::User {
3237 id: user.id.to_proto(),
3238 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3239 github_login: user.github_login,
3240 })
3241 .collect();
3242 response.send(proto::UsersResponse { users })?;
3243 Ok(())
3244}
3245
/// Send a contact request to another user.
///
/// Records the request in the database, then pushes `UpdateContacts` messages to every
/// connection of both users: an outgoing request for the requester (the session user) and an
/// incoming request for the responder. It then delivers any notifications the database returned
/// and finally acks the request.
///
/// # Errors
///
/// Fails if the user tries to add themselves or the database call fails. It also fails if a send
/// to any connection fails; in that case the remaining sends, the notifications, and the ack are
/// skipped.
async fn request_contact(
    request: proto::RequestContact,
    response: Response<proto::RequestContact>,
    session: UserSession,
) -> Result<()> {
    let requester_id = session.user_id();
    let responder_id = UserId::from_proto(request.responder_id);
    if requester_id == responder_id {
        return Err(anyhow!("cannot add yourself as a contact"))?;
    }

    // The returned batch is delivered via `send_notifications` once both users' contact lists
    // have been updated below.
    let notifications = session
        .db()
        .await
        .send_contact_request(requester_id, responder_id)
        .await?;

    // Update outgoing contact requests of requester
    let mut update = proto::UpdateContacts::default();
    update.outgoing_requests.push(responder_id.to_proto());
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(requester_id)
    {
        session.peer.send(connection_id, update.clone())?;
    }

    // Update incoming contact requests of responder
    let mut update = proto::UpdateContacts::default();
    update
        .incoming_requests
        .push(proto::IncomingContactRequest {
            requester_id: requester_id.to_proto(),
        });
    let connection_pool = session.connection_pool().await;
    for connection_id in connection_pool.user_connection_ids(responder_id) {
        session.peer.send(connection_id, update.clone())?;
    }

    send_notifications(&connection_pool, &session.peer, notifications);

    response.send(proto::Ack {})?;
    Ok(())
}
3292
/// Accept, decline, or dismiss a contact request sent to the current user.
///
/// A `Dismiss` response only calls `dismiss_contact_notification` and sends no contact updates.
/// Any other response resolves the request in the database, accepting it only when the response
/// is `Accept`. Then:
/// - the responder's connections get the pending incoming request removed, plus the requester as
///   a new contact when accepted;
/// - the requester's connections get the pending outgoing request removed, plus the responder as
///   a new contact when accepted;
/// - notifications returned by the database are delivered.
///
/// The request is acked last. A failed send to any connection returns early without the ack.
async fn respond_to_contact_request(
    request: proto::RespondToContactRequest,
    response: Response<proto::RespondToContactRequest>,
    session: UserSession,
) -> Result<()> {
    let responder_id = session.user_id();
    let requester_id = UserId::from_proto(request.requester_id);
    let db = session.db().await;
    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
        db.dismiss_contact_notification(responder_id, requester_id)
            .await?;
    } else {
        // Any response other than `Accept` (with `Dismiss` handled above) is a decline.
        let accept = request.response == proto::ContactRequestResponse::Accept as i32;

        let notifications = db
            .respond_to_contact_request(responder_id, requester_id, accept)
            .await?;
        // Busy flags are embedded in the contact entries sent to each side below.
        let requester_busy = db.is_user_busy(requester_id).await?;
        let responder_busy = db.is_user_busy(responder_id).await?;

        let pool = session.connection_pool().await;
        // Update responder with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(requester_id, requester_busy, &pool));
        }
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
        for connection_id in pool.user_connection_ids(responder_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        // Update requester with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(responder_id, responder_busy, &pool));
        }
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());

        for connection_id in pool.user_connection_ids(requester_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        send_notifications(&pool, &session.peer, notifications);
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3350
/// Remove a contact, or withdraw a contact request the current user sent.
///
/// The database reports whether the two users were accepted contacts, and which notification (if
/// any) it deleted.
/// - If they were accepted contacts, each user is removed from the other's contact list.
/// - Otherwise the pending request is removed from the requester's outgoing requests and from the
///   responder's incoming requests.
/// - If a notification was deleted, the responder's connections are also told to delete it.
///
/// A failed send to any connection returns early, skipping the remaining sends and the ack.
async fn remove_contact(
    request: proto::RemoveContact,
    response: Response<proto::RemoveContact>,
    session: UserSession,
) -> Result<()> {
    // The session user initiates the removal; `responder_id` is the other party.
    let requester_id = session.user_id();
    let responder_id = UserId::from_proto(request.user_id);
    let db = session.db().await;
    let (contact_accepted, deleted_notification_id) =
        db.remove_contact(requester_id, responder_id).await?;

    let pool = session.connection_pool().await;
    // Update outgoing contact requests of requester
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(responder_id.to_proto());
    } else {
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(requester_id) {
        session.peer.send(connection_id, update.clone())?;
    }

    // Update incoming contact requests of responder
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(requester_id.to_proto());
    } else {
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(responder_id) {
        session.peer.send(connection_id, update.clone())?;
        // Only the responder's connections are sent the notification deletion.
        if let Some(notification_id) = deleted_notification_id {
            session.peer.send(
                connection_id,
                proto::DeleteNotification {
                    notification_id: notification_id.to_proto(),
                },
            )?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3401
3402/// Creates a new channel.
3403async fn create_channel(
3404 request: proto::CreateChannel,
3405 response: Response<proto::CreateChannel>,
3406 session: UserSession,
3407) -> Result<()> {
3408 let db = session.db().await;
3409
3410 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
3411 let (channel, membership) = db
3412 .create_channel(&request.name, parent_id, session.user_id())
3413 .await?;
3414
3415 let root_id = channel.root_id();
3416 let channel = Channel::from_model(channel);
3417
3418 response.send(proto::CreateChannelResponse {
3419 channel: Some(channel.to_proto()),
3420 parent_id: request.parent_id,
3421 })?;
3422
3423 let mut connection_pool = session.connection_pool().await;
3424 if let Some(membership) = membership {
3425 connection_pool.subscribe_to_channel(
3426 membership.user_id,
3427 membership.channel_id,
3428 membership.role,
3429 );
3430 let update = proto::UpdateUserChannels {
3431 channel_memberships: vec![proto::ChannelMembership {
3432 channel_id: membership.channel_id.to_proto(),
3433 role: membership.role.into(),
3434 }],
3435 ..Default::default()
3436 };
3437 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3438 session.peer.send(connection_id, update.clone())?;
3439 }
3440 }
3441
3442 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3443 if !role.can_see_channel(channel.visibility) {
3444 continue;
3445 }
3446
3447 let update = proto::UpdateChannels {
3448 channels: vec![channel.to_proto()],
3449 ..Default::default()
3450 };
3451 session.peer.send(connection_id, update.clone())?;
3452 }
3453
3454 Ok(())
3455}
3456
3457/// Delete a channel
3458async fn delete_channel(
3459 request: proto::DeleteChannel,
3460 response: Response<proto::DeleteChannel>,
3461 session: UserSession,
3462) -> Result<()> {
3463 let db = session.db().await;
3464
3465 let channel_id = request.channel_id;
3466 let (root_channel, removed_channels) = db
3467 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3468 .await?;
3469 response.send(proto::Ack {})?;
3470
3471 // Notify members of removed channels
3472 let mut update = proto::UpdateChannels::default();
3473 update
3474 .delete_channels
3475 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3476
3477 let connection_pool = session.connection_pool().await;
3478 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3479 session.peer.send(connection_id, update.clone())?;
3480 }
3481
3482 Ok(())
3483}
3484
/// Invite someone to join a channel.
///
/// Records the invitation, with the requested role, in the database. Then it pushes the channel
/// into `channel_invitations` on all of the invitee's connections, delivers the notifications the
/// database returned, and acks the request.
///
/// A failed send to any invitee connection returns early, skipping notifications and the ack.
async fn invite_channel_member(
    request: proto::InviteChannelMember,
    response: Response<proto::InviteChannelMember>,
    session: UserSession,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let invitee_id = UserId::from_proto(request.user_id);
    let InviteMemberResult {
        channel,
        notifications,
    } = db
        .invite_channel_member(
            channel_id,
            invitee_id,
            session.user_id(),
            request.role().into(),
        )
        .await?;

    let update = proto::UpdateChannels {
        channel_invitations: vec![channel.to_proto()],
        ..Default::default()
    };

    let connection_pool = session.connection_pool().await;
    for connection_id in connection_pool.user_connection_ids(invitee_id) {
        session.peer.send(connection_id, update.clone())?;
    }

    send_notifications(&connection_pool, &session.peer, notifications);

    response.send(proto::Ack {})?;
    Ok(())
}
3521
3522/// remove someone from a channel
3523async fn remove_channel_member(
3524 request: proto::RemoveChannelMember,
3525 response: Response<proto::RemoveChannelMember>,
3526 session: UserSession,
3527) -> Result<()> {
3528 let db = session.db().await;
3529 let channel_id = ChannelId::from_proto(request.channel_id);
3530 let member_id = UserId::from_proto(request.user_id);
3531
3532 let RemoveChannelMemberResult {
3533 membership_update,
3534 notification_id,
3535 } = db
3536 .remove_channel_member(channel_id, member_id, session.user_id())
3537 .await?;
3538
3539 let mut connection_pool = session.connection_pool().await;
3540 notify_membership_updated(
3541 &mut connection_pool,
3542 membership_update,
3543 member_id,
3544 &session.peer,
3545 );
3546 for connection_id in connection_pool.user_connection_ids(member_id) {
3547 if let Some(notification_id) = notification_id {
3548 session
3549 .peer
3550 .send(
3551 connection_id,
3552 proto::DeleteNotification {
3553 notification_id: notification_id.to_proto(),
3554 },
3555 )
3556 .trace_err();
3557 }
3558 }
3559
3560 response.send(proto::Ack {})?;
3561 Ok(())
3562}
3563
/// Toggle the channel between public and private.
///
/// Public channels must only descend from public channels, though members-only channels can
/// appear at any point in the hierarchy. That invariant is presumably enforced by
/// `db.set_channel_visibility`.
///
/// Afterwards, every user connected to the channel's root is re-evaluated:
/// - users whose role can see the channel at its new visibility are subscribed to it and sent
///   the channel;
/// - all other users are unsubscribed and told to delete it.
///
/// A failed send returns early, skipping the remaining users and the ack.
async fn set_channel_visibility(
    request: proto::SetChannelVisibility,
    response: Response<proto::SetChannelVisibility>,
    session: UserSession,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let visibility = request.visibility().into();

    let channel_model = db
        .set_channel_visibility(channel_id, visibility, session.user_id())
        .await?;
    let root_id = channel_model.root_id();
    let channel = Channel::from_model(channel_model);

    let mut connection_pool = session.connection_pool().await;
    // Collected up front: the loop body mutates `connection_pool` (subscribe/unsubscribe), which
    // is not allowed while an iterator still borrows it.
    for (user_id, role) in connection_pool
        .channel_user_ids(root_id)
        .collect::<Vec<_>>()
        .into_iter()
    {
        let update = if role.can_see_channel(channel.visibility) {
            connection_pool.subscribe_to_channel(user_id, channel_id, role);
            proto::UpdateChannels {
                channels: vec![channel.to_proto()],
                ..Default::default()
            }
        } else {
            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
            proto::UpdateChannels {
                delete_channels: vec![channel.id.to_proto()],
                ..Default::default()
            }
        };

        for connection_id in connection_pool.user_connection_ids(user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3610
/// Alter the role of a user in a channel.
///
/// The database returns one of two outcomes:
/// - `MembershipUpdated` is propagated through `notify_membership_updated`.
/// - `InviteUpdated` (presumably the user only holds a pending invitation) is pushed to all of
///   the user's connections as a `channel_invitations` update.
///
/// The request is acked afterwards.
async fn set_channel_member_role(
    request: proto::SetChannelMemberRole,
    response: Response<proto::SetChannelMemberRole>,
    session: UserSession,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let member_id = UserId::from_proto(request.user_id);
    let result = db
        .set_channel_member_role(
            channel_id,
            session.user_id(),
            member_id,
            request.role().into(),
        )
        .await?;

    match result {
        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
            let mut connection_pool = session.connection_pool().await;
            notify_membership_updated(
                &mut connection_pool,
                membership_update,
                member_id,
                &session.peer,
            )
        }
        db::SetMemberRoleResult::InviteUpdated(channel) => {
            let update = proto::UpdateChannels {
                channel_invitations: vec![channel.to_proto()],
                ..Default::default()
            };

            // The pool guard from the loop header is held until the loop finishes.
            for connection_id in session
                .connection_pool()
                .await
                .user_connection_ids(member_id)
            {
                session.peer.send(connection_id, update.clone())?;
            }
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3658
3659/// Change the name of a channel
3660async fn rename_channel(
3661 request: proto::RenameChannel,
3662 response: Response<proto::RenameChannel>,
3663 session: UserSession,
3664) -> Result<()> {
3665 let db = session.db().await;
3666 let channel_id = ChannelId::from_proto(request.channel_id);
3667 let channel_model = db
3668 .rename_channel(channel_id, session.user_id(), &request.name)
3669 .await?;
3670 let root_id = channel_model.root_id();
3671 let channel = Channel::from_model(channel_model);
3672
3673 response.send(proto::RenameChannelResponse {
3674 channel: Some(channel.to_proto()),
3675 })?;
3676
3677 let connection_pool = session.connection_pool().await;
3678 let update = proto::UpdateChannels {
3679 channels: vec![channel.to_proto()],
3680 ..Default::default()
3681 };
3682 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3683 if role.can_see_channel(channel.visibility) {
3684 session.peer.send(connection_id, update.clone())?;
3685 }
3686 }
3687
3688 Ok(())
3689}
3690
3691/// Move a channel to a new parent.
3692async fn move_channel(
3693 request: proto::MoveChannel,
3694 response: Response<proto::MoveChannel>,
3695 session: UserSession,
3696) -> Result<()> {
3697 let channel_id = ChannelId::from_proto(request.channel_id);
3698 let to = ChannelId::from_proto(request.to);
3699
3700 let (root_id, channels) = session
3701 .db()
3702 .await
3703 .move_channel(channel_id, to, session.user_id())
3704 .await?;
3705
3706 let connection_pool = session.connection_pool().await;
3707 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3708 let channels = channels
3709 .iter()
3710 .filter_map(|channel| {
3711 if role.can_see_channel(channel.visibility) {
3712 Some(channel.to_proto())
3713 } else {
3714 None
3715 }
3716 })
3717 .collect::<Vec<_>>();
3718 if channels.is_empty() {
3719 continue;
3720 }
3721
3722 let update = proto::UpdateChannels {
3723 channels,
3724 ..Default::default()
3725 };
3726
3727 session.peer.send(connection_id, update.clone())?;
3728 }
3729
3730 response.send(Ack {})?;
3731 Ok(())
3732}
3733
3734/// Get the list of channel members
3735async fn get_channel_members(
3736 request: proto::GetChannelMembers,
3737 response: Response<proto::GetChannelMembers>,
3738 session: UserSession,
3739) -> Result<()> {
3740 let db = session.db().await;
3741 let channel_id = ChannelId::from_proto(request.channel_id);
3742 let limit = if request.limit == 0 {
3743 u16::MAX as u64
3744 } else {
3745 request.limit
3746 };
3747 let (members, users) = db
3748 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3749 .await?;
3750 response.send(proto::GetChannelMembersResponse { members, users })?;
3751 Ok(())
3752}
3753
/// Accept or decline a channel invitation.
///
/// When the database returns a membership update (presumably only on acceptance), it is
/// propagated via `notify_membership_updated`. Otherwise the invitation is removed from all of
/// the user's connections. Notifications returned by the database are delivered in both cases,
/// and the request is acked last.
async fn respond_to_channel_invite(
    request: proto::RespondToChannelInvite,
    response: Response<proto::RespondToChannelInvite>,
    session: UserSession,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let RespondToChannelInvite {
        membership_update,
        notifications,
    } = db
        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
        .await?;

    let mut connection_pool = session.connection_pool().await;
    if let Some(membership_update) = membership_update {
        notify_membership_updated(
            &mut connection_pool,
            membership_update,
            session.user_id(),
            &session.peer,
        );
    } else {
        let update = proto::UpdateChannels {
            remove_channel_invitations: vec![channel_id.to_proto()],
            ..Default::default()
        };

        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
            session.peer.send(connection_id, update.clone())?;
        }
    };

    send_notifications(&connection_pool, &session.peer, notifications);

    response.send(proto::Ack {})?;

    Ok(())
}
3794
3795/// Join the channels' room
3796async fn join_channel(
3797 request: proto::JoinChannel,
3798 response: Response<proto::JoinChannel>,
3799 session: UserSession,
3800) -> Result<()> {
3801 let channel_id = ChannelId::from_proto(request.channel_id);
3802 join_channel_internal(channel_id, Box::new(response), session).await
3803}
3804
/// Reply channel used by [`join_channel_internal`].
///
/// It lets the same join logic answer both `JoinChannel` and `JoinRoom` requests with a
/// `JoinRoomResponse`.
trait JoinChannelInternalResponse {
    /// Send `result` as the reply to the originating request, consuming the responder.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// Both impls forward to `Response::send`, written in fully qualified form for the request type.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3818
/// Shared implementation of joining a channel's room.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, leave that room first.
/// 2. Join the channel's room in the database.
/// 3. When a LiveKit client is configured, mint a token. Guests get a guest token with
///    `can_publish = false`; everyone else gets a room token with `can_publish = true`.
/// 4. Reply with the room, the channel id, and the LiveKit info. Then propagate any membership
///    change and broadcast the room update.
/// 5. After the database and pool guards from the inner block are dropped, broadcast the channel
///    update and refresh the user's contacts.
///
/// # Errors
///
/// Fails if stale-connection cleanup, the join, the reply, or the contacts update fails. Also
/// fails if the database did not return the joined channel.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: UserSession,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database guard while leaving the stale room, then re-acquire it.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // A token error goes through `trace_err` and makes the whole value `None`, so the join
        // still succeeds, just without LiveKit connection info.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // `db` and `connection_pool` were dropped at the end of the block above; a fresh pool guard
    // is taken for this broadcast.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3909
3910/// Start editing the channel notes
3911async fn join_channel_buffer(
3912 request: proto::JoinChannelBuffer,
3913 response: Response<proto::JoinChannelBuffer>,
3914 session: UserSession,
3915) -> Result<()> {
3916 let db = session.db().await;
3917 let channel_id = ChannelId::from_proto(request.channel_id);
3918
3919 let open_response = db
3920 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3921 .await?;
3922
3923 let collaborators = open_response.collaborators.clone();
3924 response.send(open_response)?;
3925
3926 let update = UpdateChannelBufferCollaborators {
3927 channel_id: channel_id.to_proto(),
3928 collaborators: collaborators.clone(),
3929 };
3930 channel_buffer_updated(
3931 session.connection_id,
3932 collaborators
3933 .iter()
3934 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3935 &update,
3936 &session.peer,
3937 );
3938
3939 Ok(())
3940}
3941
/// Edit the channel notes.
///
/// Applies the operations through the database. It relays them to the buffer's collaborators
/// via `channel_buffer_updated`, with this connection as the sender. Every other connection
/// subscribed to the channel receives only the buffer's new epoch and version, not the
/// operations.
async fn update_channel_buffer(
    request: proto::UpdateChannelBuffer,
    session: UserSession,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);

    let (collaborators, epoch, version) = db
        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
        .await?;

    // Cloned because `collaborators` is needed again below to exclude them from the
    // version-only update.
    channel_buffer_updated(
        session.connection_id,
        collaborators.clone(),
        &proto::UpdateChannelBuffer {
            channel_id: channel_id.to_proto(),
            operations: request.operations,
        },
        &session.peer,
    );

    let pool = &*session.connection_pool().await;

    // Connections subscribed to the channel that are not among the buffer's collaborators.
    let non_collaborators =
        pool.channel_connection_ids(channel_id)
            .filter_map(|(connection_id, _)| {
                if collaborators.contains(&connection_id) {
                    None
                } else {
                    Some(connection_id)
                }
            });

    broadcast(None, non_collaborators, |peer_id| {
        session.peer.send(
            peer_id,
            proto::UpdateChannels {
                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
                    channel_id: channel_id.to_proto(),
                    epoch: epoch as u64,
                    version: version.clone(),
                }],
                ..Default::default()
            },
        )
    });

    Ok(())
}
3992
3993/// Rejoin the channel notes after a connection blip
3994async fn rejoin_channel_buffers(
3995 request: proto::RejoinChannelBuffers,
3996 response: Response<proto::RejoinChannelBuffers>,
3997 session: UserSession,
3998) -> Result<()> {
3999 let db = session.db().await;
4000 let buffers = db
4001 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
4002 .await?;
4003
4004 for rejoined_buffer in &buffers {
4005 let collaborators_to_notify = rejoined_buffer
4006 .buffer
4007 .collaborators
4008 .iter()
4009 .filter_map(|c| Some(c.peer_id?.into()));
4010 channel_buffer_updated(
4011 session.connection_id,
4012 collaborators_to_notify,
4013 &proto::UpdateChannelBufferCollaborators {
4014 channel_id: rejoined_buffer.buffer.channel_id,
4015 collaborators: rejoined_buffer.buffer.collaborators.clone(),
4016 },
4017 &session.peer,
4018 );
4019 }
4020
4021 response.send(proto::RejoinChannelBuffersResponse {
4022 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
4023 })?;
4024
4025 Ok(())
4026}
4027
4028/// Stop editing the channel notes
4029async fn leave_channel_buffer(
4030 request: proto::LeaveChannelBuffer,
4031 response: Response<proto::LeaveChannelBuffer>,
4032 session: UserSession,
4033) -> Result<()> {
4034 let db = session.db().await;
4035 let channel_id = ChannelId::from_proto(request.channel_id);
4036
4037 let left_buffer = db
4038 .leave_channel_buffer(channel_id, session.connection_id)
4039 .await?;
4040
4041 response.send(Ack {})?;
4042
4043 channel_buffer_updated(
4044 session.connection_id,
4045 left_buffer.connections,
4046 &proto::UpdateChannelBufferCollaborators {
4047 channel_id: channel_id.to_proto(),
4048 collaborators: left_buffer.collaborators,
4049 },
4050 &session.peer,
4051 );
4052
4053 Ok(())
4054}
4055
4056fn channel_buffer_updated<T: EnvelopedMessage>(
4057 sender_id: ConnectionId,
4058 collaborators: impl IntoIterator<Item = ConnectionId>,
4059 message: &T,
4060 peer: &Peer,
4061) {
4062 broadcast(Some(sender_id), collaborators, |peer_id| {
4063 peer.send(peer_id, message.clone())
4064 });
4065}
4066
4067fn send_notifications(
4068 connection_pool: &ConnectionPool,
4069 peer: &Peer,
4070 notifications: db::NotificationBatch,
4071) {
4072 for (user_id, notification) in notifications {
4073 for connection_id in connection_pool.user_connection_ids(user_id) {
4074 if let Err(error) = peer.send(
4075 connection_id,
4076 proto::AddNotification {
4077 notification: Some(notification.clone()),
4078 },
4079 ) {
4080 tracing::error!(
4081 "failed to send notification to {:?} {}",
4082 connection_id,
4083 error
4084 );
4085 }
4086 }
4087 }
4088}
4089
4090/// Send a message to the channel
4091async fn send_channel_message(
4092 request: proto::SendChannelMessage,
4093 response: Response<proto::SendChannelMessage>,
4094 session: UserSession,
4095) -> Result<()> {
4096 // Validate the message body.
4097 let body = request.body.trim().to_string();
4098 if body.len() > MAX_MESSAGE_LEN {
4099 return Err(anyhow!("message is too long"))?;
4100 }
4101 if body.is_empty() {
4102 return Err(anyhow!("message can't be blank"))?;
4103 }
4104
4105 // TODO: adjust mentions if body is trimmed
4106
4107 let timestamp = OffsetDateTime::now_utc();
4108 let nonce = request
4109 .nonce
4110 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4111
4112 let channel_id = ChannelId::from_proto(request.channel_id);
4113 let CreatedChannelMessage {
4114 message_id,
4115 participant_connection_ids,
4116 notifications,
4117 } = session
4118 .db()
4119 .await
4120 .create_channel_message(
4121 channel_id,
4122 session.user_id(),
4123 &body,
4124 &request.mentions,
4125 timestamp,
4126 nonce.clone().into(),
4127 match request.reply_to_message_id {
4128 Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
4129 None => None,
4130 },
4131 )
4132 .await?;
4133
4134 let message = proto::ChannelMessage {
4135 sender_id: session.user_id().to_proto(),
4136 id: message_id.to_proto(),
4137 body,
4138 mentions: request.mentions,
4139 timestamp: timestamp.unix_timestamp() as u64,
4140 nonce: Some(nonce),
4141 reply_to_message_id: request.reply_to_message_id,
4142 edited_at: None,
4143 };
4144 broadcast(
4145 Some(session.connection_id),
4146 participant_connection_ids.clone(),
4147 |connection| {
4148 session.peer.send(
4149 connection,
4150 proto::ChannelMessageSent {
4151 channel_id: channel_id.to_proto(),
4152 message: Some(message.clone()),
4153 },
4154 )
4155 },
4156 );
4157 response.send(proto::SendChannelMessageResponse {
4158 message: Some(message),
4159 })?;
4160
4161 let pool = &*session.connection_pool().await;
4162 let non_participants =
4163 pool.channel_connection_ids(channel_id)
4164 .filter_map(|(connection_id, _)| {
4165 if participant_connection_ids.contains(&connection_id) {
4166 None
4167 } else {
4168 Some(connection_id)
4169 }
4170 });
4171 broadcast(None, non_participants, |peer_id| {
4172 session.peer.send(
4173 peer_id,
4174 proto::UpdateChannels {
4175 latest_channel_message_ids: vec![proto::ChannelMessageId {
4176 channel_id: channel_id.to_proto(),
4177 message_id: message_id.to_proto(),
4178 }],
4179 ..Default::default()
4180 },
4181 )
4182 });
4183 send_notifications(pool, &session.peer, notifications);
4184
4185 Ok(())
4186}
4187
/// Delete a channel message.
///
/// After the database removes the message, `broadcast` runs over the connection ids the database
/// returned, with this connection as the sender. Each recipient is sent the original
/// `RemoveChannelMessage` request, followed by a `DeleteNotification` for every notification id
/// the database returned. The request is acked afterwards.
async fn remove_channel_message(
    request: proto::RemoveChannelMessage,
    response: Response<proto::RemoveChannelMessage>,
    session: UserSession,
) -> Result<()> {
    let channel_id = ChannelId::from_proto(request.channel_id);
    let message_id = MessageId::from_proto(request.message_id);
    let (connection_ids, existing_notification_ids) = session
        .db()
        .await
        .remove_channel_message(channel_id, message_id, session.user_id())
        .await?;

    broadcast(
        Some(session.connection_id),
        connection_ids,
        move |connection| {
            // The request itself is forwarded unchanged as the removal message.
            session.peer.send(connection, request.clone())?;

            for notification_id in &existing_notification_ids {
                session.peer.send(
                    connection,
                    proto::DeleteNotification {
                        notification_id: (*notification_id).to_proto(),
                    },
                )?;
            }

            Ok(())
        },
    );
    response.send(proto::Ack {})?;
    Ok(())
}
4223
4224async fn update_channel_message(
4225 request: proto::UpdateChannelMessage,
4226 response: Response<proto::UpdateChannelMessage>,
4227 session: UserSession,
4228) -> Result<()> {
4229 let channel_id = ChannelId::from_proto(request.channel_id);
4230 let message_id = MessageId::from_proto(request.message_id);
4231 let updated_at = OffsetDateTime::now_utc();
4232 let UpdatedChannelMessage {
4233 message_id,
4234 participant_connection_ids,
4235 notifications,
4236 reply_to_message_id,
4237 timestamp,
4238 deleted_mention_notification_ids,
4239 updated_mention_notifications,
4240 } = session
4241 .db()
4242 .await
4243 .update_channel_message(
4244 channel_id,
4245 message_id,
4246 session.user_id(),
4247 request.body.as_str(),
4248 &request.mentions,
4249 updated_at,
4250 )
4251 .await?;
4252
4253 let nonce = request
4254 .nonce
4255 .clone()
4256 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4257
4258 let message = proto::ChannelMessage {
4259 sender_id: session.user_id().to_proto(),
4260 id: message_id.to_proto(),
4261 body: request.body.clone(),
4262 mentions: request.mentions.clone(),
4263 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
4264 nonce: Some(nonce),
4265 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
4266 edited_at: Some(updated_at.unix_timestamp() as u64),
4267 };
4268
4269 response.send(proto::Ack {})?;
4270
4271 let pool = &*session.connection_pool().await;
4272 broadcast(
4273 Some(session.connection_id),
4274 participant_connection_ids,
4275 |connection| {
4276 session.peer.send(
4277 connection,
4278 proto::ChannelMessageUpdate {
4279 channel_id: channel_id.to_proto(),
4280 message: Some(message.clone()),
4281 },
4282 )?;
4283
4284 for notification_id in &deleted_mention_notification_ids {
4285 session.peer.send(
4286 connection,
4287 proto::DeleteNotification {
4288 notification_id: (*notification_id).to_proto(),
4289 },
4290 )?;
4291 }
4292
4293 for notification in &updated_mention_notifications {
4294 session.peer.send(
4295 connection,
4296 proto::UpdateNotification {
4297 notification: Some(notification.clone()),
4298 },
4299 )?;
4300 }
4301
4302 Ok(())
4303 },
4304 );
4305
4306 send_notifications(pool, &session.peer, notifications);
4307
4308 Ok(())
4309}
4310
4311/// Mark a channel message as read
4312async fn acknowledge_channel_message(
4313 request: proto::AckChannelMessage,
4314 session: UserSession,
4315) -> Result<()> {
4316 let channel_id = ChannelId::from_proto(request.channel_id);
4317 let message_id = MessageId::from_proto(request.message_id);
4318 let notifications = session
4319 .db()
4320 .await
4321 .observe_channel_message(channel_id, session.user_id(), message_id)
4322 .await?;
4323 send_notifications(
4324 &*session.connection_pool().await,
4325 &session.peer,
4326 notifications,
4327 );
4328 Ok(())
4329}
4330
4331/// Mark a buffer version as synced
4332async fn acknowledge_buffer_version(
4333 request: proto::AckBufferOperation,
4334 session: UserSession,
4335) -> Result<()> {
4336 let buffer_id = BufferId::from_proto(request.buffer_id);
4337 session
4338 .db()
4339 .await
4340 .observe_buffer_version(
4341 buffer_id,
4342 session.user_id(),
4343 request.epoch as i32,
4344 &request.version,
4345 )
4346 .await?;
4347 Ok(())
4348}
4349
4350struct CompleteWithLanguageModelRateLimit;
4351
4352impl RateLimit for CompleteWithLanguageModelRateLimit {
4353 fn capacity() -> usize {
4354 std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4355 .ok()
4356 .and_then(|v| v.parse().ok())
4357 .unwrap_or(120) // Picked arbitrarily
4358 }
4359
4360 fn refill_duration() -> chrono::Duration {
4361 chrono::Duration::hours(1)
4362 }
4363
4364 fn db_name() -> &'static str {
4365 "complete-with-language-model"
4366 }
4367}
4368
4369async fn complete_with_language_model(
4370 request: proto::CompleteWithLanguageModel,
4371 response: StreamingResponse<proto::CompleteWithLanguageModel>,
4372 session: Session,
4373 open_ai_api_key: Option<Arc<str>>,
4374 google_ai_api_key: Option<Arc<str>>,
4375 anthropic_api_key: Option<Arc<str>>,
4376) -> Result<()> {
4377 let Some(session) = session.for_user() else {
4378 return Err(anyhow!("user not found"))?;
4379 };
4380 authorize_access_to_language_models(&session).await?;
4381 session
4382 .rate_limiter
4383 .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
4384 .await?;
4385
4386 if request.model.starts_with("gpt") {
4387 let api_key =
4388 open_ai_api_key.ok_or_else(|| anyhow!("no OpenAI API key configured on the server"))?;
4389 complete_with_open_ai(request, response, session, api_key).await?;
4390 } else if request.model.starts_with("gemini") {
4391 let api_key = google_ai_api_key
4392 .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
4393 complete_with_google_ai(request, response, session, api_key).await?;
4394 } else if request.model.starts_with("claude") {
4395 let api_key = anthropic_api_key
4396 .ok_or_else(|| anyhow!("no Anthropic AI API key configured on the server"))?;
4397 complete_with_anthropic(request, response, session, api_key).await?;
4398 }
4399
4400 Ok(())
4401}
4402
4403async fn complete_with_open_ai(
4404 request: proto::CompleteWithLanguageModel,
4405 response: StreamingResponse<proto::CompleteWithLanguageModel>,
4406 session: UserSession,
4407 api_key: Arc<str>,
4408) -> Result<()> {
4409 let mut completion_stream = open_ai::stream_completion(
4410 session.http_client.as_ref(),
4411 OPEN_AI_API_URL,
4412 &api_key,
4413 crate::ai::language_model_request_to_open_ai(request)?,
4414 None,
4415 )
4416 .await
4417 .context("open_ai::stream_completion request failed within collab")?;
4418
4419 while let Some(event) = completion_stream.next().await {
4420 let event = event?;
4421 response.send(proto::LanguageModelResponse {
4422 choices: event
4423 .choices
4424 .into_iter()
4425 .map(|choice| proto::LanguageModelChoiceDelta {
4426 index: choice.index,
4427 delta: Some(proto::LanguageModelResponseMessage {
4428 role: choice.delta.role.map(|role| match role {
4429 open_ai::Role::User => LanguageModelRole::LanguageModelUser,
4430 open_ai::Role::Assistant => LanguageModelRole::LanguageModelAssistant,
4431 open_ai::Role::System => LanguageModelRole::LanguageModelSystem,
4432 open_ai::Role::Tool => LanguageModelRole::LanguageModelTool,
4433 } as i32),
4434 content: choice.delta.content,
4435 tool_calls: choice
4436 .delta
4437 .tool_calls
4438 .into_iter()
4439 .map(|delta| proto::ToolCallDelta {
4440 index: delta.index as u32,
4441 id: delta.id,
4442 variant: match delta.function {
4443 Some(function) => {
4444 let name = function.name;
4445 let arguments = function.arguments;
4446
4447 Some(proto::tool_call_delta::Variant::Function(
4448 proto::tool_call_delta::FunctionCallDelta {
4449 name,
4450 arguments,
4451 },
4452 ))
4453 }
4454 None => None,
4455 },
4456 })
4457 .collect(),
4458 }),
4459 finish_reason: choice.finish_reason,
4460 })
4461 .collect(),
4462 })?;
4463 }
4464
4465 Ok(())
4466}
4467
4468async fn complete_with_google_ai(
4469 request: proto::CompleteWithLanguageModel,
4470 response: StreamingResponse<proto::CompleteWithLanguageModel>,
4471 session: UserSession,
4472 api_key: Arc<str>,
4473) -> Result<()> {
4474 let mut stream = google_ai::stream_generate_content(
4475 session.http_client.clone(),
4476 google_ai::API_URL,
4477 api_key.as_ref(),
4478 crate::ai::language_model_request_to_google_ai(request)?,
4479 )
4480 .await
4481 .context("google_ai::stream_generate_content request failed")?;
4482
4483 while let Some(event) = stream.next().await {
4484 let event = event?;
4485 response.send(proto::LanguageModelResponse {
4486 choices: event
4487 .candidates
4488 .unwrap_or_default()
4489 .into_iter()
4490 .map(|candidate| proto::LanguageModelChoiceDelta {
4491 index: candidate.index as u32,
4492 delta: Some(proto::LanguageModelResponseMessage {
4493 role: Some(match candidate.content.role {
4494 google_ai::Role::User => LanguageModelRole::LanguageModelUser,
4495 google_ai::Role::Model => LanguageModelRole::LanguageModelAssistant,
4496 } as i32),
4497 content: Some(
4498 candidate
4499 .content
4500 .parts
4501 .into_iter()
4502 .filter_map(|part| match part {
4503 google_ai::Part::TextPart(part) => Some(part.text),
4504 google_ai::Part::InlineDataPart(_) => None,
4505 })
4506 .collect(),
4507 ),
4508 // Tool calls are not supported for Google
4509 tool_calls: Vec::new(),
4510 }),
4511 finish_reason: candidate.finish_reason.map(|reason| reason.to_string()),
4512 })
4513 .collect(),
4514 })?;
4515 }
4516
4517 Ok(())
4518}
4519
/// Streams a chat completion from Anthropic and relays the generated text to the
/// client.
///
/// `request.model` must be an id accepted by `anthropic::Model::from_id`.
///
/// Messages are handled by role:
/// - System messages are concatenated, separated by a blank line, into Anthropic's
///   top-level `system` field.
/// - User and assistant messages are forwarded unchanged.
/// - Tool messages are dropped.
///
/// Every chunk sent back to the client uses choice index 0.
async fn complete_with_anthropic(
    request: proto::CompleteWithLanguageModel,
    response: StreamingResponse<proto::CompleteWithLanguageModel>,
    session: UserSession,
    api_key: Arc<str>,
) -> Result<()> {
    let model = anthropic::Model::from_id(&request.model)?;

    // Accumulates the text of every system message; filled in by the `filter_map`
    // below as a side effect while the remaining messages are converted.
    let mut system_message = String::new();
    let messages = request
        .messages
        .into_iter()
        .filter_map(|message| {
            match message.role() {
                LanguageModelRole::LanguageModelUser => Some(anthropic::RequestMessage {
                    role: anthropic::Role::User,
                    content: message.content,
                }),
                LanguageModelRole::LanguageModelAssistant => Some(anthropic::RequestMessage {
                    role: anthropic::Role::Assistant,
                    content: message.content,
                }),
                // Anthropic's API breaks system instructions out as a separate field rather
                // than having a system message role.
                LanguageModelRole::LanguageModelSystem => {
                    if !system_message.is_empty() {
                        system_message.push_str("\n\n");
                    }
                    system_message.push_str(&message.content);

                    None
                }
                // We don't yet support tool calls for Anthropic
                LanguageModelRole::LanguageModelTool => None,
            }
        })
        .collect();

    let mut stream = anthropic::stream_completion(
        session.http_client.as_ref(),
        anthropic::ANTHROPIC_API_URL,
        &api_key,
        anthropic::Request {
            model,
            messages,
            stream: true,
            system: system_message,
            // NOTE(review): 4092 looks like a typo for 4096 (the commonly documented
            // output-token limit for Claude 3 models) — confirm before changing.
            max_tokens: 4092,
        },
        None,
    )
    .await?;

    // Role attached to every outgoing delta. It starts as assistant and is updated by
    // `MessageStart` events whose role is "assistant" or "user"; any other role string
    // leaves it unchanged.
    let mut current_role = proto::LanguageModelRole::LanguageModelAssistant;

    while let Some(event) = stream.next().await {
        let event = event?;

        match event {
            anthropic::ResponseEvent::MessageStart { message } => {
                if let Some(role) = message.role {
                    if role == "assistant" {
                        current_role = proto::LanguageModelRole::LanguageModelAssistant;
                    } else if role == "user" {
                        current_role = proto::LanguageModelRole::LanguageModelUser;
                    }
                }
            }
            // A content block can open with initial text; forward it (when non-empty)
            // so it is not lost before the first `ContentBlockDelta`.
            anthropic::ResponseEvent::ContentBlockStart { content_block, .. } => {
                match content_block {
                    anthropic::ContentBlock::Text { text } => {
                        if !text.is_empty() {
                            response.send(proto::LanguageModelResponse {
                                choices: vec![proto::LanguageModelChoiceDelta {
                                    index: 0,
                                    delta: Some(proto::LanguageModelResponseMessage {
                                        role: Some(current_role as i32),
                                        content: Some(text),
                                        tool_calls: Vec::new(),
                                    }),
                                    finish_reason: None,
                                }],
                            })?;
                        }
                    }
                }
            }
            // Incremental text is forwarded as-is, even when empty.
            anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => match delta {
                anthropic::TextDelta::TextDelta { text } => {
                    response.send(proto::LanguageModelResponse {
                        choices: vec![proto::LanguageModelChoiceDelta {
                            index: 0,
                            delta: Some(proto::LanguageModelResponseMessage {
                                role: Some(current_role as i32),
                                content: Some(text),
                                tool_calls: Vec::new(),
                            }),
                            finish_reason: None,
                        }],
                    })?;
                }
            },
            // Only a message delta carrying a stop reason produces output: the stop
            // reason is sent as the finish reason with no message delta.
            anthropic::ResponseEvent::MessageDelta { delta, .. } => {
                if let Some(stop_reason) = delta.stop_reason {
                    response.send(proto::LanguageModelResponse {
                        choices: vec![proto::LanguageModelChoiceDelta {
                            index: 0,
                            delta: None,
                            finish_reason: Some(stop_reason),
                        }],
                    })?;
                }
            }
            // These events carry nothing the client needs.
            anthropic::ResponseEvent::ContentBlockStop { .. } => {}
            anthropic::ResponseEvent::MessageStop {} => {}
            anthropic::ResponseEvent::Ping {} => {}
        }
    }

    Ok(())
}
4641
4642struct CountTokensWithLanguageModelRateLimit;
4643
4644impl RateLimit for CountTokensWithLanguageModelRateLimit {
4645 fn capacity() -> usize {
4646 std::env::var("COUNT_TOKENS_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4647 .ok()
4648 .and_then(|v| v.parse().ok())
4649 .unwrap_or(600) // Picked arbitrarily
4650 }
4651
4652 fn refill_duration() -> chrono::Duration {
4653 chrono::Duration::hours(1)
4654 }
4655
4656 fn db_name() -> &'static str {
4657 "count-tokens-with-language-model"
4658 }
4659}
4660
4661async fn count_tokens_with_language_model(
4662 request: proto::CountTokensWithLanguageModel,
4663 response: Response<proto::CountTokensWithLanguageModel>,
4664 session: UserSession,
4665 google_ai_api_key: Option<Arc<str>>,
4666) -> Result<()> {
4667 authorize_access_to_language_models(&session).await?;
4668
4669 if !request.model.starts_with("gemini") {
4670 return Err(anyhow!(
4671 "counting tokens for model: {:?} is not supported",
4672 request.model
4673 ))?;
4674 }
4675
4676 session
4677 .rate_limiter
4678 .check::<CountTokensWithLanguageModelRateLimit>(session.user_id())
4679 .await?;
4680
4681 let api_key = google_ai_api_key
4682 .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
4683 let tokens_response = google_ai::count_tokens(
4684 session.http_client.as_ref(),
4685 google_ai::API_URL,
4686 &api_key,
4687 crate::ai::count_tokens_request_to_google_ai(request)?,
4688 )
4689 .await?;
4690 response.send(proto::CountTokensResponse {
4691 token_count: tokens_response.total_tokens as u32,
4692 })?;
4693 Ok(())
4694}
4695
4696struct ComputeEmbeddingsRateLimit;
4697
4698impl RateLimit for ComputeEmbeddingsRateLimit {
4699 fn capacity() -> usize {
4700 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4701 .ok()
4702 .and_then(|v| v.parse().ok())
4703 .unwrap_or(5000) // Picked arbitrarily
4704 }
4705
4706 fn refill_duration() -> chrono::Duration {
4707 chrono::Duration::hours(1)
4708 }
4709
4710 fn db_name() -> &'static str {
4711 "compute-embeddings"
4712 }
4713}
4714
4715async fn compute_embeddings(
4716 request: proto::ComputeEmbeddings,
4717 response: Response<proto::ComputeEmbeddings>,
4718 session: UserSession,
4719 api_key: Option<Arc<str>>,
4720) -> Result<()> {
4721 let api_key = api_key.context("no OpenAI API key configured on the server")?;
4722 authorize_access_to_language_models(&session).await?;
4723
4724 session
4725 .rate_limiter
4726 .check::<ComputeEmbeddingsRateLimit>(session.user_id())
4727 .await?;
4728
4729 let embeddings = match request.model.as_str() {
4730 "openai/text-embedding-3-small" => {
4731 open_ai::embed(
4732 session.http_client.as_ref(),
4733 OPEN_AI_API_URL,
4734 &api_key,
4735 OpenAiEmbeddingModel::TextEmbedding3Small,
4736 request.texts.iter().map(|text| text.as_str()),
4737 )
4738 .await?
4739 }
4740 provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
4741 };
4742
4743 let embeddings = request
4744 .texts
4745 .iter()
4746 .map(|text| {
4747 let mut hasher = sha2::Sha256::new();
4748 hasher.update(text.as_bytes());
4749 let result = hasher.finalize();
4750 result.to_vec()
4751 })
4752 .zip(
4753 embeddings
4754 .data
4755 .into_iter()
4756 .map(|embedding| embedding.embedding),
4757 )
4758 .collect::<HashMap<_, _>>();
4759
4760 let db = session.db().await;
4761 db.save_embeddings(&request.model, &embeddings)
4762 .await
4763 .context("failed to save embeddings")
4764 .trace_err();
4765
4766 response.send(proto::ComputeEmbeddingsResponse {
4767 embeddings: embeddings
4768 .into_iter()
4769 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4770 .collect(),
4771 })?;
4772 Ok(())
4773}
4774
4775async fn get_cached_embeddings(
4776 request: proto::GetCachedEmbeddings,
4777 response: Response<proto::GetCachedEmbeddings>,
4778 session: UserSession,
4779) -> Result<()> {
4780 authorize_access_to_language_models(&session).await?;
4781
4782 let db = session.db().await;
4783 let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
4784
4785 response.send(proto::GetCachedEmbeddingsResponse {
4786 embeddings: embeddings
4787 .into_iter()
4788 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4789 .collect(),
4790 })?;
4791 Ok(())
4792}
4793
4794async fn authorize_access_to_language_models(session: &UserSession) -> Result<(), Error> {
4795 let db = session.db().await;
4796 let flags = db.get_user_flags(session.user_id()).await?;
4797 if flags.iter().any(|flag| flag == "language-models") {
4798 Ok(())
4799 } else {
4800 Err(anyhow!("permission denied"))?
4801 }
4802}
4803
4804/// Get a Supermaven API key for the user
4805async fn get_supermaven_api_key(
4806 _request: proto::GetSupermavenApiKey,
4807 response: Response<proto::GetSupermavenApiKey>,
4808 session: UserSession,
4809) -> Result<()> {
4810 let user_id: String = session.user_id().to_string();
4811 if !session.is_staff() {
4812 return Err(anyhow!("supermaven not enabled for this account"))?;
4813 }
4814
4815 let email = session
4816 .email()
4817 .ok_or_else(|| anyhow!("user must have an email"))?;
4818
4819 let supermaven_admin_api = session
4820 .supermaven_client
4821 .as_ref()
4822 .ok_or_else(|| anyhow!("supermaven not configured"))?;
4823
4824 let result = supermaven_admin_api
4825 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4826 .await?;
4827
4828 response.send(proto::GetSupermavenApiKeyResponse {
4829 api_key: result.api_key,
4830 })?;
4831
4832 Ok(())
4833}
4834
4835/// Start receiving chat updates for a channel
4836async fn join_channel_chat(
4837 request: proto::JoinChannelChat,
4838 response: Response<proto::JoinChannelChat>,
4839 session: UserSession,
4840) -> Result<()> {
4841 let channel_id = ChannelId::from_proto(request.channel_id);
4842
4843 let db = session.db().await;
4844 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4845 .await?;
4846 let messages = db
4847 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4848 .await?;
4849 response.send(proto::JoinChannelChatResponse {
4850 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4851 messages,
4852 })?;
4853 Ok(())
4854}
4855
4856/// Stop receiving chat updates for a channel
4857async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
4858 let channel_id = ChannelId::from_proto(request.channel_id);
4859 session
4860 .db()
4861 .await
4862 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4863 .await?;
4864 Ok(())
4865}
4866
4867/// Retrieve the chat history for a channel
4868async fn get_channel_messages(
4869 request: proto::GetChannelMessages,
4870 response: Response<proto::GetChannelMessages>,
4871 session: UserSession,
4872) -> Result<()> {
4873 let channel_id = ChannelId::from_proto(request.channel_id);
4874 let messages = session
4875 .db()
4876 .await
4877 .get_channel_messages(
4878 channel_id,
4879 session.user_id(),
4880 MESSAGE_COUNT_PER_PAGE,
4881 Some(MessageId::from_proto(request.before_message_id)),
4882 )
4883 .await?;
4884 response.send(proto::GetChannelMessagesResponse {
4885 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4886 messages,
4887 })?;
4888 Ok(())
4889}
4890
4891/// Retrieve specific chat messages
4892async fn get_channel_messages_by_id(
4893 request: proto::GetChannelMessagesById,
4894 response: Response<proto::GetChannelMessagesById>,
4895 session: UserSession,
4896) -> Result<()> {
4897 let message_ids = request
4898 .message_ids
4899 .iter()
4900 .map(|id| MessageId::from_proto(*id))
4901 .collect::<Vec<_>>();
4902 let messages = session
4903 .db()
4904 .await
4905 .get_channel_messages_by_id(session.user_id(), &message_ids)
4906 .await?;
4907 response.send(proto::GetChannelMessagesResponse {
4908 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4909 messages,
4910 })?;
4911 Ok(())
4912}
4913
4914/// Retrieve the current users notifications
4915async fn get_notifications(
4916 request: proto::GetNotifications,
4917 response: Response<proto::GetNotifications>,
4918 session: UserSession,
4919) -> Result<()> {
4920 let notifications = session
4921 .db()
4922 .await
4923 .get_notifications(
4924 session.user_id(),
4925 NOTIFICATION_COUNT_PER_PAGE,
4926 request
4927 .before_id
4928 .map(|id| db::NotificationId::from_proto(id)),
4929 )
4930 .await?;
4931 response.send(proto::GetNotificationsResponse {
4932 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4933 notifications,
4934 })?;
4935 Ok(())
4936}
4937
4938/// Mark notifications as read
4939async fn mark_notification_as_read(
4940 request: proto::MarkNotificationRead,
4941 response: Response<proto::MarkNotificationRead>,
4942 session: UserSession,
4943) -> Result<()> {
4944 let database = &session.db().await;
4945 let notifications = database
4946 .mark_notification_as_read_by_id(
4947 session.user_id(),
4948 NotificationId::from_proto(request.notification_id),
4949 )
4950 .await?;
4951 send_notifications(
4952 &*session.connection_pool().await,
4953 &session.peer,
4954 notifications,
4955 );
4956 response.send(proto::Ack {})?;
4957 Ok(())
4958}
4959
4960/// Get the current users information
4961async fn get_private_user_info(
4962 _request: proto::GetPrivateUserInfo,
4963 response: Response<proto::GetPrivateUserInfo>,
4964 session: UserSession,
4965) -> Result<()> {
4966 let db = session.db().await;
4967
4968 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4969 let user = db
4970 .get_user_by_id(session.user_id())
4971 .await?
4972 .ok_or_else(|| anyhow!("user not found"))?;
4973 let flags = db.get_user_flags(session.user_id()).await?;
4974
4975 response.send(proto::GetPrivateUserInfoResponse {
4976 metrics_id,
4977 staff: user.admin,
4978 flags,
4979 })?;
4980 Ok(())
4981}
4982
4983fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
4984 match message {
4985 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4986 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4987 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4988 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4989 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4990 code: frame.code.into(),
4991 reason: frame.reason,
4992 })),
4993 }
4994}
4995
4996fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4997 match message {
4998 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4999 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
5000 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
5001 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
5002 AxumMessage::Close(frame) => {
5003 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
5004 code: frame.code.into(),
5005 reason: frame.reason,
5006 }))
5007 }
5008 }
5009}
5010
5011fn notify_membership_updated(
5012 connection_pool: &mut ConnectionPool,
5013 result: MembershipUpdated,
5014 user_id: UserId,
5015 peer: &Peer,
5016) {
5017 for membership in &result.new_channels.channel_memberships {
5018 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
5019 }
5020 for channel_id in &result.removed_channels {
5021 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
5022 }
5023
5024 let user_channels_update = proto::UpdateUserChannels {
5025 channel_memberships: result
5026 .new_channels
5027 .channel_memberships
5028 .iter()
5029 .map(|cm| proto::ChannelMembership {
5030 channel_id: cm.channel_id.to_proto(),
5031 role: cm.role.into(),
5032 })
5033 .collect(),
5034 ..Default::default()
5035 };
5036
5037 let mut update = build_channels_update(result.new_channels, vec![]);
5038 update.delete_channels = result
5039 .removed_channels
5040 .into_iter()
5041 .map(|id| id.to_proto())
5042 .collect();
5043 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
5044
5045 for connection_id in connection_pool.user_connection_ids(user_id) {
5046 peer.send(connection_id, user_channels_update.clone())
5047 .trace_err();
5048 peer.send(connection_id, update.clone()).trace_err();
5049 }
5050}
5051
5052fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
5053 proto::UpdateUserChannels {
5054 channel_memberships: channels
5055 .channel_memberships
5056 .iter()
5057 .map(|m| proto::ChannelMembership {
5058 channel_id: m.channel_id.to_proto(),
5059 role: m.role.into(),
5060 })
5061 .collect(),
5062 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
5063 observed_channel_message_id: channels.observed_channel_messages.clone(),
5064 }
5065}
5066
5067fn build_channels_update(
5068 channels: ChannelsForUser,
5069 channel_invites: Vec<db::Channel>,
5070) -> proto::UpdateChannels {
5071 let mut update = proto::UpdateChannels::default();
5072
5073 for channel in channels.channels {
5074 update.channels.push(channel.to_proto());
5075 }
5076
5077 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
5078 update.latest_channel_message_ids = channels.latest_channel_messages;
5079
5080 for (channel_id, participants) in channels.channel_participants {
5081 update
5082 .channel_participants
5083 .push(proto::ChannelParticipants {
5084 channel_id: channel_id.to_proto(),
5085 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
5086 });
5087 }
5088
5089 for channel in channel_invites {
5090 update.channel_invitations.push(channel.to_proto());
5091 }
5092
5093 update.hosted_projects = channels.hosted_projects;
5094 update
5095}
5096
5097fn build_initial_contacts_update(
5098 contacts: Vec<db::Contact>,
5099 pool: &ConnectionPool,
5100) -> proto::UpdateContacts {
5101 let mut update = proto::UpdateContacts::default();
5102
5103 for contact in contacts {
5104 match contact {
5105 db::Contact::Accepted { user_id, busy } => {
5106 update.contacts.push(contact_for_user(user_id, busy, &pool));
5107 }
5108 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
5109 db::Contact::Incoming { user_id } => {
5110 update
5111 .incoming_requests
5112 .push(proto::IncomingContactRequest {
5113 requester_id: user_id.to_proto(),
5114 })
5115 }
5116 }
5117 }
5118
5119 update
5120}
5121
5122fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
5123 proto::Contact {
5124 user_id: user_id.to_proto(),
5125 online: pool.is_user_online(user_id),
5126 busy,
5127 }
5128}
5129
5130fn room_updated(room: &proto::Room, peer: &Peer) {
5131 broadcast(
5132 None,
5133 room.participants
5134 .iter()
5135 .filter_map(|participant| Some(participant.peer_id?.into())),
5136 |peer_id| {
5137 peer.send(
5138 peer_id,
5139 proto::RoomUpdated {
5140 room: Some(room.clone()),
5141 },
5142 )
5143 },
5144 );
5145}
5146
5147fn channel_updated(
5148 channel: &db::channel::Model,
5149 room: &proto::Room,
5150 peer: &Peer,
5151 pool: &ConnectionPool,
5152) {
5153 let participants = room
5154 .participants
5155 .iter()
5156 .map(|p| p.user_id)
5157 .collect::<Vec<_>>();
5158
5159 broadcast(
5160 None,
5161 pool.channel_connection_ids(channel.root_id())
5162 .filter_map(|(channel_id, role)| {
5163 role.can_see_channel(channel.visibility).then(|| channel_id)
5164 }),
5165 |peer_id| {
5166 peer.send(
5167 peer_id,
5168 proto::UpdateChannels {
5169 channel_participants: vec![proto::ChannelParticipants {
5170 channel_id: channel.id.to_proto(),
5171 participant_user_ids: participants.clone(),
5172 }],
5173 ..Default::default()
5174 },
5175 )
5176 },
5177 );
5178}
5179
5180async fn send_dev_server_projects_update(
5181 user_id: UserId,
5182 mut status: proto::DevServerProjectsUpdate,
5183 session: &Session,
5184) {
5185 let pool = session.connection_pool().await;
5186 for dev_server in &mut status.dev_servers {
5187 dev_server.status =
5188 pool.dev_server_status(DevServerId(dev_server.dev_server_id as i32)) as i32;
5189 }
5190 let connections = pool.user_connection_ids(user_id);
5191 for connection_id in connections {
5192 session.peer.send(connection_id, status.clone()).trace_err();
5193 }
5194}
5195
5196async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
5197 let db = session.db().await;
5198
5199 let contacts = db.get_contacts(user_id).await?;
5200 let busy = db.is_user_busy(user_id).await?;
5201
5202 let pool = session.connection_pool().await;
5203 let updated_contact = contact_for_user(user_id, busy, &pool);
5204 for contact in contacts {
5205 if let db::Contact::Accepted {
5206 user_id: contact_user_id,
5207 ..
5208 } = contact
5209 {
5210 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
5211 session
5212 .peer
5213 .send(
5214 contact_conn_id,
5215 proto::UpdateContacts {
5216 contacts: vec![updated_contact.clone()],
5217 remove_contacts: Default::default(),
5218 incoming_requests: Default::default(),
5219 remove_incoming_requests: Default::default(),
5220 outgoing_requests: Default::default(),
5221 remove_outgoing_requests: Default::default(),
5222 },
5223 )
5224 .trace_err();
5225 }
5226 }
5227 }
5228 Ok(())
5229}
5230
5231async fn lost_dev_server_connection(session: &DevServerSession) -> Result<()> {
5232 log::info!("lost dev server connection, unsharing projects");
5233 let project_ids = session
5234 .db()
5235 .await
5236 .get_stale_dev_server_projects(session.connection_id)
5237 .await?;
5238
5239 for project_id in project_ids {
5240 // not unshare re-checks the connection ids match, so we get away with no transaction
5241 unshare_project_internal(project_id, session.connection_id, None, &session).await?;
5242 }
5243
5244 let user_id = session.dev_server().user_id;
5245 let update = session
5246 .db()
5247 .await
5248 .dev_server_projects_update(user_id)
5249 .await?;
5250
5251 send_dev_server_projects_update(user_id, update, session).await;
5252
5253 Ok(())
5254}
5255
/// Removes `connection_id` from whatever room it is in and notifies everyone
/// affected.
///
/// The steps, in order:
/// 1. Notify collaborators of the projects it left.
/// 2. Broadcast the updated room.
/// 3. Update the room's channel, if any.
/// 4. Send `CallCanceled` for the calls the database canceled.
/// 5. Refresh contact state for this user and every canceled callee.
/// 6. Remove the participant from LiveKit, and delete the LiveKit room when the
///    database reports the room as deleted.
///
/// Returns `Ok(())` immediately when the database reports the connection was not in
/// a room. LiveKit failures are logged rather than returned.
async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Assigned exactly once inside the `if let` below; the `else` branch returns, so
    // they are always initialized afterwards.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel;

    // NOTE(review): in editions before 2024, the temporary returned by
    // `session.db().await` in this scrutinee lives until the end of the whole
    // `if let`, so the body below runs while it is still held. Confirm whether that
    // is intended.
    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // NOTE(review): `live_kit_room` is taken out of `left_room.room` before `room`
        // itself is taken, so the room broadcast by `room_updated` below carries the
        // field's default value. Confirm clients do not rely on it here.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the pool guard is dropped before `update_user_contacts`, which
    // acquires the connection pool again itself.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    // LiveKit cleanup is best-effort: errors are logged via `trace_err`.
    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
5329
5330async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
5331 let left_channel_buffers = session
5332 .db()
5333 .await
5334 .leave_channel_buffers(session.connection_id)
5335 .await?;
5336
5337 for left_buffer in left_channel_buffers {
5338 channel_buffer_updated(
5339 session.connection_id,
5340 left_buffer.connections,
5341 &proto::UpdateChannelBufferCollaborators {
5342 channel_id: left_buffer.channel_id.to_proto(),
5343 collaborators: left_buffer.collaborators,
5344 },
5345 &session.peer,
5346 );
5347 }
5348
5349 Ok(())
5350}
5351
5352fn project_left(project: &db::LeftProject, session: &UserSession) {
5353 for connection_id in &project.connection_ids {
5354 if project.should_unshare {
5355 session
5356 .peer
5357 .send(
5358 *connection_id,
5359 proto::UnshareProject {
5360 project_id: project.id.to_proto(),
5361 },
5362 )
5363 .trace_err();
5364 } else {
5365 session
5366 .peer
5367 .send(
5368 *connection_id,
5369 proto::RemoveProjectCollaborator {
5370 project_id: project.id.to_proto(),
5371 peer_id: Some(session.connection_id.into()),
5372 },
5373 )
5374 .trace_err();
5375 }
5376 }
5377}
5378
5379pub trait ResultExt {
5380 type Ok;
5381
5382 fn trace_err(self) -> Option<Self::Ok>;
5383}
5384
5385impl<T, E> ResultExt for Result<T, E>
5386where
5387 E: std::fmt::Debug,
5388{
5389 type Ok = T;
5390
5391 #[track_caller]
5392 fn trace_err(self) -> Option<T> {
5393 match self {
5394 Ok(value) => Some(value),
5395 Err(error) => {
5396 tracing::error!("{:?}", error);
5397 None
5398 }
5399 }
5400 }
5401}