1mod connection_pool;
2
3use crate::{
4 auth,
5 db::{
6 self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
7 CreatedChannelMessage, Database, DevServerId, DevServerProjectId, InviteMemberResult,
8 MembershipUpdated, MessageId, NotificationId, PrincipalId, Project, ProjectId,
9 RejoinedProject, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
10 ServerId, UpdatedChannelMessage, User, UserId,
11 },
12 executor::Executor,
13 AppState, Error, RateLimit, RateLimiter, Result,
14};
15use anyhow::{anyhow, Context as _};
16use async_tungstenite::tungstenite::{
17 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
18};
19use axum::{
20 body::Body,
21 extract::{
22 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
23 ConnectInfo, WebSocketUpgrade,
24 },
25 headers::{Header, HeaderName},
26 http::StatusCode,
27 middleware,
28 response::IntoResponse,
29 routing::get,
30 Extension, Router, TypedHeader,
31};
32use collections::{HashMap, HashSet};
33pub use connection_pool::{ConnectionPool, ZedVersion};
34use core::fmt::{self, Debug, Formatter};
35use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
36use sha2::Digest;
37
38use futures::{
39 channel::oneshot,
40 future::{self, BoxFuture},
41 stream::FuturesUnordered,
42 FutureExt, SinkExt, StreamExt, TryStreamExt,
43};
44use prometheus::{register_int_gauge, IntGauge};
45use rpc::{
46 proto::{
47 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LanguageModelRole,
48 LiveKitConnectionInfo, RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
49 },
50 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
51};
52use semantic_version::SemanticVersion;
53use serde::{Serialize, Serializer};
54use std::{
55 any::TypeId,
56 future::Future,
57 marker::PhantomData,
58 mem,
59 net::SocketAddr,
60 ops::{Deref, DerefMut},
61 rc::Rc,
62 sync::{
63 atomic::{AtomicBool, Ordering::SeqCst},
64 Arc, OnceLock,
65 },
66 time::{Duration, Instant},
67};
68use time::OffsetDateTime;
69use tokio::sync::{watch, Semaphore};
70use tower::ServiceBuilder;
71use tracing::{
72 field::{self},
73 info_span, instrument, Instrument,
74};
75use util::http::IsahcHttpClient;
76
77use self::connection_pool::VersionedMessage;
78
/// Reconnection window: 30 seconds.
///
/// NOTE(review): no use of this constant is visible in this part of the file; presumably it
/// bounds how long a dropped client may take to reconnect. Confirm at the usage site.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can clean up old resources.
// `Server::start` sleeps for this long before it clears stale rooms and channel buffers.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// Paging and size limits.
// NOTE(review): their uses are not visible in this part of the file; the names suggest
// page sizes for channel messages / notifications and a maximum chat message length.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
87
/// Type-erased handler stored in `Server::handlers`, keyed by the message's `TypeId`.
///
/// It receives the boxed envelope together with the `Session` of the connection that sent it
/// and returns a boxed future that resolves once the message has been handled.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
90
91struct Response<R> {
92 peer: Arc<Peer>,
93 receipt: Receipt<R>,
94 responded: Arc<AtomicBool>,
95}
96
97impl<R: RequestMessage> Response<R> {
98 fn send(self, payload: R::Response) -> Result<()> {
99 self.responded.store(true, SeqCst);
100 self.peer.respond(self.receipt, payload)?;
101 Ok(())
102 }
103}
104
105struct StreamingResponse<R: RequestMessage> {
106 peer: Arc<Peer>,
107 receipt: Receipt<R>,
108}
109
110impl<R: RequestMessage> StreamingResponse<R> {
111 fn send(&self, payload: R::Response) -> Result<()> {
112 self.peer.respond(self.receipt, payload)?;
113 Ok(())
114 }
115}
116
117#[derive(Clone, Debug)]
118pub enum Principal {
119 User(User),
120 Impersonated { user: User, admin: User },
121 DevServer(dev_server::Model),
122}
123
124impl Principal {
125 fn update_span(&self, span: &tracing::Span) {
126 match &self {
127 Principal::User(user) => {
128 span.record("user_id", &user.id.0);
129 span.record("login", &user.github_login);
130 }
131 Principal::Impersonated { user, admin } => {
132 span.record("user_id", &user.id.0);
133 span.record("login", &user.github_login);
134 span.record("impersonator", &admin.github_login);
135 }
136 Principal::DevServer(dev_server) => {
137 span.record("dev_server_id", &dev_server.id.0);
138 }
139 }
140 }
141}
142
/// Per-connection context handed to every message handler.
///
/// `Server::handle_connection` builds one per connection and clones it for each incoming message.
#[derive(Clone)]
struct Session {
    /// Who this connection acts as.
    principal: Principal,
    /// Id the peer assigned to this connection.
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; access it through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    /// The server's peer, used to send messages and responses.
    peer: Arc<Peer>,
    /// The server-wide connection pool; access it through `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// LiveKit API client, if one is configured.
    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
    /// HTTP client created when the connection was opened.
    http_client: IsahcHttpClient,
    /// The application-wide rate limiter.
    rate_limiter: Arc<RateLimiter>,
    /// Executor the connection runs on (not read in the visible part of this file).
    _executor: Executor,
}
155
156impl Session {
157 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
158 #[cfg(test)]
159 tokio::task::yield_now().await;
160 let guard = self.db.lock().await;
161 #[cfg(test)]
162 tokio::task::yield_now().await;
163 guard
164 }
165
166 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
167 #[cfg(test)]
168 tokio::task::yield_now().await;
169 let guard = self.connection_pool.lock();
170 ConnectionPoolGuard {
171 guard,
172 _not_send: PhantomData,
173 }
174 }
175
176 fn for_user(self) -> Option<UserSession> {
177 UserSession::new(self)
178 }
179
180 fn for_dev_server(self) -> Option<DevServerSession> {
181 DevServerSession::new(self)
182 }
183
184 fn user_id(&self) -> Option<UserId> {
185 match &self.principal {
186 Principal::User(user) => Some(user.id),
187 Principal::Impersonated { user, .. } => Some(user.id),
188 Principal::DevServer(_) => None,
189 }
190 }
191
192 fn dev_server_id(&self) -> Option<DevServerId> {
193 match &self.principal {
194 Principal::User(_) | Principal::Impersonated { .. } => None,
195 Principal::DevServer(dev_server) => Some(dev_server.id),
196 }
197 }
198
199 fn principal_id(&self) -> PrincipalId {
200 match &self.principal {
201 Principal::User(user) => PrincipalId::UserId(user.id),
202 Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
203 Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
204 }
205 }
206}
207
208impl Debug for Session {
209 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
210 let mut result = f.debug_struct("Session");
211 match &self.principal {
212 Principal::User(user) => {
213 result.field("user", &user.github_login);
214 }
215 Principal::Impersonated { user, admin } => {
216 result.field("user", &user.github_login);
217 result.field("impersonator", &admin.github_login);
218 }
219 Principal::DevServer(dev_server) => {
220 result.field("dev_server", &dev_server.id);
221 }
222 }
223 result.field("connection_id", &self.connection_id).finish()
224 }
225}
226
227struct UserSession(Session);
228
229impl UserSession {
230 pub fn new(s: Session) -> Option<Self> {
231 s.user_id().map(|_| UserSession(s))
232 }
233 pub fn user_id(&self) -> UserId {
234 self.0.user_id().unwrap()
235 }
236}
237
238impl Deref for UserSession {
239 type Target = Session;
240
241 fn deref(&self) -> &Self::Target {
242 &self.0
243 }
244}
245impl DerefMut for UserSession {
246 fn deref_mut(&mut self) -> &mut Self::Target {
247 &mut self.0
248 }
249}
250
251struct DevServerSession(Session);
252
253impl DevServerSession {
254 pub fn new(s: Session) -> Option<Self> {
255 s.dev_server_id().map(|_| DevServerSession(s))
256 }
257 pub fn dev_server_id(&self) -> DevServerId {
258 self.0.dev_server_id().unwrap()
259 }
260
261 fn dev_server(&self) -> &dev_server::Model {
262 match &self.0.principal {
263 Principal::DevServer(dev_server) => dev_server,
264 _ => unreachable!(),
265 }
266 }
267}
268
269impl Deref for DevServerSession {
270 type Target = Session;
271
272 fn deref(&self) -> &Self::Target {
273 &self.0
274 }
275}
276impl DerefMut for DevServerSession {
277 fn deref_mut(&mut self) -> &mut Self::Target {
278 &mut self.0
279 }
280}
281
282fn user_handler<M: RequestMessage, Fut>(
283 handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
284) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
285where
286 Fut: Send + Future<Output = Result<()>>,
287{
288 let handler = Arc::new(handler);
289 move |message, response, session| {
290 let handler = handler.clone();
291 Box::pin(async move {
292 if let Some(user_session) = session.for_user() {
293 Ok(handler(message, response, user_session).await?)
294 } else {
295 Err(Error::Internal(anyhow!(
296 "must be a user to call {}",
297 M::NAME
298 )))
299 }
300 })
301 }
302}
303
304fn dev_server_handler<M: RequestMessage, Fut>(
305 handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
306) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
307where
308 Fut: Send + Future<Output = Result<()>>,
309{
310 let handler = Arc::new(handler);
311 move |message, response, session| {
312 let handler = handler.clone();
313 Box::pin(async move {
314 if let Some(dev_server_session) = session.for_dev_server() {
315 Ok(handler(message, response, dev_server_session).await?)
316 } else {
317 Err(Error::Internal(anyhow!(
318 "must be a dev server to call {}",
319 M::NAME
320 )))
321 }
322 })
323 }
324}
325
326fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
327 handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
328) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
329where
330 InnertRetFut: Send + Future<Output = Result<()>>,
331{
332 let handler = Arc::new(handler);
333 move |message, session| {
334 let handler = handler.clone();
335 Box::pin(async move {
336 if let Some(user_session) = session.for_user() {
337 Ok(handler(message, user_session).await?)
338 } else {
339 Err(Error::Internal(anyhow!(
340 "must be a user to call {}",
341 M::NAME
342 )))
343 }
344 })
345 }
346}
347
348struct DbHandle(Arc<Database>);
349
350impl Deref for DbHandle {
351 type Target = Database;
352
353 fn deref(&self) -> &Self::Target {
354 self.0.as_ref()
355 }
356}
357
/// The RPC server: owns the peer, the connection pool and the table of message handlers.
pub struct Server {
    /// This server's id. Behind a mutex because the test-only `reset` replaces it.
    id: parking_lot::Mutex<ServerId>,
    /// Peer that manages the individual connections.
    peer: Arc<Peer>,
    /// Pool of live connections, shared with every `Session`.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// Shared application state (database, configuration, executor, ...).
    app_state: Arc<AppState>,
    /// Handlers keyed by the `TypeId` of the message type they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Broadcasts `true` when the server is tearing down; connection loops subscribe to it.
    teardown: watch::Sender<bool>,
}
366
/// Lock guard over the `ConnectionPool`.
pub(crate) struct ConnectionPoolGuard<'a> {
    /// The underlying `parking_lot` guard.
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    /// `Rc` is `!Send`, so this marker makes the guard `!Send` as well.
    /// NOTE(review): presumably so the guard cannot be held across an `.await` in a future
    /// that must be `Send` — confirm.
    _not_send: PhantomData<Rc<()>>,
}
371
/// Serializable view of the server: its peer plus the (locked) connection pool.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized through `serialize_deref`, i.e. the value the guard dereferences to is
    // serialized rather than the guard itself.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
378
379pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
380where
381 S: Serializer,
382 T: Deref<Target = U>,
383 U: Serialize,
384{
385 Serialize::serialize(value.deref(), serializer)
386}
387
388impl Server {
/// Creates the server and registers a handler for every supported message type.
///
/// Handlers wrapped in `user_handler` / `user_message_handler` require a user principal and
/// those wrapped in `dev_server_handler` require a dev server principal; unwrapped handlers
/// receive the plain `Session`.
///
/// # Panics
///
/// Panics if two handlers are registered for the same message type (see `add_handler`).
pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
    let mut server = Self {
        id: parking_lot::Mutex::new(id),
        peer: Peer::new(id.0 as u32),
        app_state: app_state.clone(),
        connection_pool: Default::default(),
        handlers: Default::default(),
        // Only the sender is kept; receivers are created on demand via `subscribe`.
        teardown: watch::channel(false).0,
    };

    server
        .add_request_handler(ping)
        // Rooms and calls.
        .add_request_handler(user_handler(create_room))
        .add_request_handler(user_handler(join_room))
        .add_request_handler(user_handler(rejoin_room))
        .add_request_handler(user_handler(leave_room))
        .add_request_handler(user_handler(set_room_participant_role))
        .add_request_handler(user_handler(call))
        .add_request_handler(user_handler(cancel_call))
        .add_message_handler(user_message_handler(decline_call))
        .add_request_handler(user_handler(update_participant_location))
        // Projects and dev servers.
        .add_request_handler(user_handler(share_project))
        .add_message_handler(unshare_project)
        .add_request_handler(user_handler(join_project))
        .add_request_handler(user_handler(join_hosted_project))
        .add_request_handler(user_handler(rejoin_dev_server_projects))
        .add_request_handler(user_handler(create_dev_server_project))
        .add_request_handler(user_handler(delete_dev_server_project))
        .add_request_handler(user_handler(create_dev_server))
        .add_request_handler(user_handler(delete_dev_server))
        .add_request_handler(dev_server_handler(share_dev_server_project))
        .add_request_handler(dev_server_handler(shutdown_dev_server))
        .add_request_handler(dev_server_handler(reconnect_dev_server))
        .add_message_handler(user_message_handler(leave_project))
        .add_request_handler(update_project)
        .add_request_handler(update_worktree)
        .add_message_handler(start_language_server)
        .add_message_handler(update_language_server)
        .add_message_handler(update_diagnostic_summary)
        .add_message_handler(update_worktree_settings)
        // Project requests forwarded through `forward_read_only_project_request`.
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::GetHover>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::GetDefinition>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::GetTypeDefinition>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::GetReferences>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::SearchProject>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::GetDocumentHighlights>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::GetProjectSymbols>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::OpenBufferForSymbol>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::OpenBufferById>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::SynchronizeBuffers>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::InlayHints>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::OpenBufferByPath>,
        ))
        // Project requests forwarded through the (versioned) mutating forwarders.
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::GetCompletions>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
        ))
        .add_request_handler(user_handler(
            forward_versioned_mutating_project_request::<proto::OpenNewBuffer>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::GetCodeActions>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::ApplyCodeAction>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::PrepareRename>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::PerformRename>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::ReloadBuffers>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::FormatBuffers>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::CreateProjectEntry>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::RenameProjectEntry>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::CopyProjectEntry>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::DeleteProjectEntry>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::ExpandProjectEntry>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::OnTypeFormatting>,
        ))
        .add_request_handler(user_handler(
            forward_versioned_mutating_project_request::<proto::SaveBuffer>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::BlameBuffer>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::MultiLspQuery>,
        ))
        // Buffer updates and host broadcasts.
        .add_message_handler(create_buffer_for_peer)
        .add_request_handler(update_buffer)
        .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
        .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
        .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
        .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
        .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
        // Users, contacts, channels, chat and notifications.
        .add_request_handler(get_users)
        .add_request_handler(user_handler(fuzzy_search_users))
        .add_request_handler(user_handler(request_contact))
        .add_request_handler(user_handler(remove_contact))
        .add_request_handler(user_handler(respond_to_contact_request))
        .add_request_handler(user_handler(create_channel))
        .add_request_handler(user_handler(delete_channel))
        .add_request_handler(user_handler(invite_channel_member))
        .add_request_handler(user_handler(remove_channel_member))
        .add_request_handler(user_handler(set_channel_member_role))
        .add_request_handler(user_handler(set_channel_visibility))
        .add_request_handler(user_handler(rename_channel))
        .add_request_handler(user_handler(join_channel_buffer))
        .add_request_handler(user_handler(leave_channel_buffer))
        .add_message_handler(user_message_handler(update_channel_buffer))
        .add_request_handler(user_handler(rejoin_channel_buffers))
        .add_request_handler(user_handler(get_channel_members))
        .add_request_handler(user_handler(respond_to_channel_invite))
        .add_request_handler(user_handler(join_channel))
        .add_request_handler(user_handler(join_channel_chat))
        .add_message_handler(user_message_handler(leave_channel_chat))
        .add_request_handler(user_handler(send_channel_message))
        .add_request_handler(user_handler(remove_channel_message))
        .add_request_handler(user_handler(update_channel_message))
        .add_request_handler(user_handler(get_channel_messages))
        .add_request_handler(user_handler(get_channel_messages_by_id))
        .add_request_handler(user_handler(get_notifications))
        .add_request_handler(user_handler(mark_notification_as_read))
        .add_request_handler(user_handler(move_channel))
        .add_request_handler(user_handler(follow))
        .add_message_handler(user_message_handler(unfollow))
        .add_message_handler(user_message_handler(update_followers))
        .add_request_handler(user_handler(get_private_user_info))
        .add_message_handler(user_message_handler(acknowledge_channel_message))
        .add_message_handler(user_message_handler(acknowledge_buffer_version))
        // Language model and embedding requests; API keys are read from the configuration
        // each time a request is handled.
        .add_streaming_request_handler({
            let app_state = app_state.clone();
            move |request, response, session| {
                complete_with_language_model(
                    request,
                    response,
                    session,
                    app_state.config.openai_api_key.clone(),
                    app_state.config.google_ai_api_key.clone(),
                    app_state.config.anthropic_api_key.clone(),
                )
            }
        })
        .add_request_handler({
            let app_state = app_state.clone();
            user_handler(move |request, response, session| {
                count_tokens_with_language_model(
                    request,
                    response,
                    session,
                    app_state.config.google_ai_api_key.clone(),
                )
            })
        })
        .add_request_handler({
            user_handler(move |request, response, session| {
                get_cached_embeddings(request, response, session)
            })
        })
        .add_request_handler({
            let app_state = app_state.clone();
            user_handler(move |request, response, session| {
                compute_embeddings(
                    request,
                    response,
                    session,
                    app_state.config.openai_api_key.clone(),
                )
            })
        });

    Arc::new(server)
}
607
/// Schedules a one-time cleanup of resources left behind by stale servers.
///
/// Spawns a detached task and returns `Ok(())` right away. After `CLEANUP_TIMEOUT` the task:
/// 1. fetches the stale room and channel ids from the database;
/// 2. clears stale channel buffer collaborators and notifies the remaining connections;
/// 3. clears stale room participants, notifies room/channel members, cancels pending calls,
///    refreshes contacts, and deletes the LiveKit room when no participant remains;
/// 4. deletes the stale server records.
///
/// Every fallible step goes through `trace_err`, so a failure skips that step instead of
/// aborting the task.
pub async fn start(&self) -> Result<()> {
    let server_id = *self.id.lock();
    let app_state = self.app_state.clone();
    let peer = self.peer.clone();
    // The sleep future is created here, before the task is spawned.
    let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
    let pool = self.connection_pool.clone();
    let live_kit_client = self.app_state.live_kit_client.clone();

    let span = info_span!("start server");
    self.app_state.executor.spawn_detached(
        async move {
            tracing::info!("waiting for cleanup timeout");
            timeout.await;
            tracing::info!("cleanup timeout expired, retrieving stale rooms");
            if let Some((room_ids, channel_ids)) = app_state
                .db
                .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                .await
                .trace_err()
            {
                tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                tracing::info!(
                    stale_channel_buffer_count = channel_ids.len(),
                    "retrieved stale channel buffers"
                );

                // Channel buffers: tell each remaining connection the refreshed collaborators.
                for channel_id in channel_ids {
                    if let Some(refreshed_channel_buffer) = app_state
                        .db
                        .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                        .await
                        .trace_err()
                    {
                        for connection_id in refreshed_channel_buffer.connection_ids {
                            peer.send(
                                connection_id,
                                proto::UpdateChannelBufferCollaborators {
                                    channel_id: channel_id.to_proto(),
                                    collaborators: refreshed_channel_buffer
                                        .collaborators
                                        .clone(),
                                },
                            )
                            .trace_err();
                        }
                    }
                }

                for room_id in room_ids {
                    // These keep their empty defaults when refreshing the room fails, which
                    // turns the follow-up steps below into no-ops for this room.
                    let mut contacts_to_update = HashSet::default();
                    let mut canceled_calls_to_user_ids = Vec::new();
                    let mut live_kit_room = String::new();
                    let mut delete_live_kit_room = false;

                    if let Some(mut refreshed_room) = app_state
                        .db
                        .clear_stale_room_participants(room_id, server_id)
                        .await
                        .trace_err()
                    {
                        tracing::info!(
                            room_id = room_id.0,
                            new_participant_count = refreshed_room.room.participants.len(),
                            "refreshed room"
                        );
                        room_updated(&refreshed_room.room, &peer);
                        if let Some(channel) = refreshed_room.channel.as_ref() {
                            // The pool lock is held only for the duration of this call.
                            channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                        }
                        contacts_to_update
                            .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                        contacts_to_update
                            .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                        canceled_calls_to_user_ids =
                            mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                        live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                        delete_live_kit_room = refreshed_room.room.participants.is_empty();
                    }

                    // Scoped so the pool guard is dropped before the next `.await`.
                    {
                        let pool = pool.lock();
                        for canceled_user_id in canceled_calls_to_user_ids {
                            for connection_id in pool.user_connection_ids(canceled_user_id) {
                                peer.send(
                                    connection_id,
                                    proto::CallCanceled {
                                        room_id: room_id.to_proto(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    // Push the refreshed contact entry of every affected user to the
                    // connections of their accepted contacts.
                    for user_id in contacts_to_update {
                        let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                        let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                        if let Some((busy, contacts)) = busy.zip(contacts) {
                            // Locked after both database calls; no `.await` happens while the
                            // guard is alive.
                            let pool = pool.lock();
                            let updated_contact = contact_for_user(user_id, busy, &pool);
                            for contact in contacts {
                                if let db::Contact::Accepted {
                                    user_id: contact_user_id,
                                    ..
                                } = contact
                                {
                                    for contact_conn_id in
                                        pool.user_connection_ids(contact_user_id)
                                    {
                                        peer.send(
                                            contact_conn_id,
                                            proto::UpdateContacts {
                                                contacts: vec![updated_contact.clone()],
                                                remove_contacts: Default::default(),
                                                incoming_requests: Default::default(),
                                                remove_incoming_requests: Default::default(),
                                                outgoing_requests: Default::default(),
                                                remove_outgoing_requests: Default::default(),
                                            },
                                        )
                                        .trace_err();
                                    }
                                }
                            }
                        }
                    }

                    // Delete the LiveKit room only when the refreshed room has no participants.
                    if let Some(live_kit) = live_kit_client.as_ref() {
                        if delete_live_kit_room {
                            live_kit.delete_room(live_kit_room).await.trace_err();
                        }
                    }
                }
            }

            // Runs even when fetching the stale resource ids failed.
            app_state
                .db
                .delete_stale_servers(&app_state.config.zed_environment, server_id)
                .await
                .trace_err();
        }
        .instrument(span),
    );
    Ok(())
}
753
/// Shuts the server down: tears down the peer, resets the connection pool and publishes
/// `true` on the teardown channel, which connection loops in `handle_connection` watch.
pub fn teardown(&self) {
    self.peer.teardown();
    self.connection_pool.lock().reset();
    // The send result is deliberately ignored.
    let _ = self.teardown.send(true);
}
759
/// Test-only: tears the server down, switches it to `id`, resets the peer with the new id
/// and publishes `false` on the teardown channel so `handle_connection` accepts
/// connections again.
#[cfg(test)]
pub fn reset(&self, id: ServerId) {
    self.teardown();
    *self.id.lock() = id;
    self.peer.reset(id.0 as u32);
    let _ = self.teardown.send(false);
}
767
/// Test-only: returns the server's current id.
#[cfg(test)]
pub fn id(&self) -> ServerId {
    *self.id.lock()
}
772
773 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
774 where
775 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
776 Fut: 'static + Send + Future<Output = Result<()>>,
777 M: EnvelopedMessage,
778 {
779 let prev_handler = self.handlers.insert(
780 TypeId::of::<M>(),
781 Box::new(move |envelope, session| {
782 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
783 let received_at = envelope.received_at;
784 tracing::info!("message received");
785 let start_time = Instant::now();
786 let future = (handler)(*envelope, session);
787 async move {
788 let result = future.await;
789 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
790 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
791 let queue_duration_ms = total_duration_ms - processing_duration_ms;
792 let payload_type = M::NAME;
793
794 match result {
795 Err(error) => {
796 tracing::error!(
797 ?error,
798 total_duration_ms,
799 processing_duration_ms,
800 queue_duration_ms,
801 payload_type,
802 "error handling message"
803 )
804 }
805 Ok(()) => tracing::info!(
806 total_duration_ms,
807 processing_duration_ms,
808 queue_duration_ms,
809 "finished handling message"
810 ),
811 }
812 }
813 .boxed()
814 }),
815 );
816 if prev_handler.is_some() {
817 panic!("registered a handler for the same message twice");
818 }
819 self
820 }
821
822 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
823 where
824 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
825 Fut: 'static + Send + Future<Output = Result<()>>,
826 M: EnvelopedMessage,
827 {
828 self.add_handler(move |envelope, session| handler(envelope.payload, session));
829 self
830 }
831
832 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
833 where
834 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
835 Fut: Send + Future<Output = Result<()>>,
836 M: RequestMessage,
837 {
838 let handler = Arc::new(handler);
839 self.add_handler(move |envelope, session| {
840 let receipt = envelope.receipt();
841 let handler = handler.clone();
842 async move {
843 let peer = session.peer.clone();
844 let responded = Arc::new(AtomicBool::default());
845 let response = Response {
846 peer: peer.clone(),
847 responded: responded.clone(),
848 receipt,
849 };
850 match (handler)(envelope.payload, response, session).await {
851 Ok(()) => {
852 if responded.load(std::sync::atomic::Ordering::SeqCst) {
853 Ok(())
854 } else {
855 Err(anyhow!("handler did not send a response"))?
856 }
857 }
858 Err(error) => {
859 let proto_err = match &error {
860 Error::Internal(err) => err.to_proto(),
861 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
862 };
863 peer.respond_with_error(receipt, proto_err)?;
864 Err(error)
865 }
866 }
867 }
868 })
869 }
870
871 fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
872 where
873 F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
874 Fut: Send + Future<Output = Result<()>>,
875 M: RequestMessage,
876 {
877 let handler = Arc::new(handler);
878 self.add_handler(move |envelope, session| {
879 let receipt = envelope.receipt();
880 let handler = handler.clone();
881 async move {
882 let peer = session.peer.clone();
883 let response = StreamingResponse {
884 peer: peer.clone(),
885 receipt,
886 };
887 match (handler)(envelope.payload, response, session).await {
888 Ok(()) => {
889 peer.end_stream(receipt)?;
890 Ok(())
891 }
892 Err(error) => {
893 let proto_err = match &error {
894 Error::Internal(err) => err.to_proto(),
895 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
896 };
897 peer.respond_with_error(receipt, proto_err)?;
898 Err(error)
899 }
900 }
901 }
902 })
903 }
904
/// Returns the future that drives a single client connection until it ends.
///
/// The future registers the connection with the peer, builds the `Session`, sends the
/// initial client update, and then dispatches incoming messages to the registered handlers
/// until the I/O task finishes, the message stream ends, or the server tears down. On a
/// normal exit it calls `connection_lost`; on server teardown it returns without doing so.
///
/// `send_connection_id`, when provided, is handed on to `send_initial_client_update`.
#[allow(clippy::too_many_arguments)]
pub fn handle_connection(
    self: &Arc<Self>,
    connection: Connection,
    address: String,
    principal: Principal,
    zed_version: ZedVersion,
    send_connection_id: Option<oneshot::Sender<ConnectionId>>,
    executor: Executor,
) -> impl Future<Output = ()> {
    let this = self.clone();
    // Identity fields start out empty and are filled in by `update_span` / `record` below.
    let span = info_span!("handle connection", %address,
        connection_id=field::Empty,
        user_id=field::Empty,
        login=field::Empty,
        impersonator=field::Empty,
        dev_server_id=field::Empty
    );
    principal.update_span(&span);

    let mut teardown = self.teardown.subscribe();
    async move {
        // Refuse new connections while the server is tearing down.
        if *teardown.borrow() {
            tracing::error!("server is tearing down");
            return
        }
        let (connection_id, handle_io, mut incoming_rx) = this
            .peer
            .add_connection(connection, {
                let executor = executor.clone();
                move |duration| executor.sleep(duration)
            });
        tracing::Span::current().record("connection_id", format!("{}", connection_id));
        tracing::info!("connection opened");

        let http_client = match IsahcHttpClient::new() {
            Ok(http_client) => http_client,
            Err(error) => {
                tracing::error!(?error, "failed to create HTTP client");
                return;
            }
        };

        let session = Session {
            principal: principal.clone(),
            connection_id,
            db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
            peer: this.peer.clone(),
            connection_pool: this.connection_pool.clone(),
            live_kit_client: this.app_state.live_kit_client.clone(),
            http_client,
            rate_limiter: this.app_state.rate_limiter.clone(),
            _executor: executor.clone(),
        };

        if let Err(error) = this
            .send_initial_client_update(
                connection_id,
                &principal,
                zed_version,
                send_connection_id,
                &session,
            )
            .await
        {
            tracing::error!(?error, "failed to send initial client update");
            return;
        }

        let handle_io = handle_io.fuse();
        futures::pin_mut!(handle_io);

        // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
        // This prevents deadlocks when e.g., client A performs a request to client B and
        // client B performs a request to client A. If both clients stop processing further
        // messages until their respective request completes, they won't have a chance to
        // respond to the other client's request and cause a deadlock.
        //
        // This arrangement ensures we will attempt to process earlier messages first, but fall
        // back to processing messages arrived later in the spirit of making progress.
        let mut foreground_message_handlers = FuturesUnordered::new();
        // At most 256 handlers run at once for this connection: a permit is acquired before
        // the next message is read and released when that message's handler finishes.
        let concurrent_handlers = Arc::new(Semaphore::new(256));
        loop {
            let next_message = async {
                let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                let message = incoming_rx.next().await;
                (permit, message)
            }.fuse();
            futures::pin_mut!(next_message);
            // Biased: branches are polled in the order written, so teardown wins over I/O
            // completion, which wins over handler progress, which wins over new messages.
            futures::select_biased! {
                // Server teardown: leave immediately, skipping `connection_lost`.
                _ = teardown.changed().fuse() => return,
                result = handle_io => {
                    if let Err(error) = result {
                        tracing::error!(?error, "error handling I/O");
                    }
                    break;
                }
                // Drives the pending foreground handlers.
                _ = foreground_message_handlers.next() => {}
                next_message = next_message => {
                    let (permit, message) = next_message;
                    if let Some(message) = message {
                        let type_name = message.payload_type_name();
                        // note: we copy all the fields from the parent span so we can query them in the logs.
                        // (https://github.com/tokio-rs/tracing/issues/2670).
                        let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                            user_id=field::Empty,
                            login=field::Empty,
                            impersonator=field::Empty,
                            dev_server_id=field::Empty
                        );
                        principal.update_span(&span);
                        let span_enter = span.enter();
                        if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                            let is_background = message.is_background();
                            let handle_message = (handler)(message, session.clone());
                            // Leave the span before the future is instrumented with it below.
                            drop(span_enter);

                            let handle_message = async move {
                                handle_message.await;
                                // Frees the concurrency slot once the handler is done.
                                drop(permit);
                            }.instrument(span);
                            if is_background {
                                executor.spawn_detached(handle_message);
                            } else {
                                foreground_message_handlers.push(handle_message);
                            }
                        } else {
                            tracing::error!("no message handler");
                        }
                    } else {
                        tracing::info!("connection closed");
                        break;
                    }
                }
            }
        }

        // Dropping the set cancels foreground handlers that have not completed.
        drop(foreground_message_handlers);
        tracing::info!("signing out");
        if let Err(error) = connection_lost(session, teardown, executor).await {
            tracing::error!(?error, "error signing out");
        }

    }.instrument(span)
}
1041
1042 async fn send_initial_client_update(
1043 &self,
1044 connection_id: ConnectionId,
1045 principal: &Principal,
1046 zed_version: ZedVersion,
1047 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
1048 session: &Session,
1049 ) -> Result<()> {
1050 self.peer.send(
1051 connection_id,
1052 proto::Hello {
1053 peer_id: Some(connection_id.into()),
1054 },
1055 )?;
1056 tracing::info!("sent hello message");
1057 if let Some(send_connection_id) = send_connection_id.take() {
1058 let _ = send_connection_id.send(connection_id);
1059 }
1060
1061 match principal {
1062 Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
1063 if !user.connected_once {
1064 self.peer.send(connection_id, proto::ShowContacts {})?;
1065 self.app_state
1066 .db
1067 .set_user_connected_once(user.id, true)
1068 .await?;
1069 }
1070
1071 let (contacts, channels_for_user, channel_invites, dev_server_projects) =
1072 future::try_join4(
1073 self.app_state.db.get_contacts(user.id),
1074 self.app_state.db.get_channels_for_user(user.id),
1075 self.app_state.db.get_channel_invites_for_user(user.id),
1076 self.app_state.db.dev_server_projects_update(user.id),
1077 )
1078 .await?;
1079
1080 {
1081 let mut pool = self.connection_pool.lock();
1082 pool.add_connection(connection_id, user.id, user.admin, zed_version);
1083 for membership in &channels_for_user.channel_memberships {
1084 pool.subscribe_to_channel(user.id, membership.channel_id, membership.role)
1085 }
1086 self.peer.send(
1087 connection_id,
1088 build_initial_contacts_update(contacts, &pool),
1089 )?;
1090 self.peer.send(
1091 connection_id,
1092 build_update_user_channels(&channels_for_user),
1093 )?;
1094 self.peer.send(
1095 connection_id,
1096 build_channels_update(channels_for_user, channel_invites),
1097 )?;
1098 }
1099 send_dev_server_projects_update(user.id, dev_server_projects, session).await;
1100
1101 if let Some(incoming_call) =
1102 self.app_state.db.incoming_call_for_user(user.id).await?
1103 {
1104 self.peer.send(connection_id, incoming_call)?;
1105 }
1106
1107 update_user_contacts(user.id, &session).await?;
1108 }
1109 Principal::DevServer(dev_server) => {
1110 {
1111 let mut pool = self.connection_pool.lock();
1112 if pool.dev_server_connection_id(dev_server.id).is_some() {
1113 return Err(anyhow!(ErrorCode::DevServerAlreadyOnline))?;
1114 };
1115 pool.add_dev_server(connection_id, dev_server.id, zed_version);
1116 }
1117
1118 let projects = self
1119 .app_state
1120 .db
1121 .get_projects_for_dev_server(dev_server.id)
1122 .await?;
1123 self.peer
1124 .send(connection_id, proto::DevServerInstructions { projects })?;
1125
1126 let status = self
1127 .app_state
1128 .db
1129 .dev_server_projects_update(dev_server.user_id)
1130 .await?;
1131 send_dev_server_projects_update(dev_server.user_id, status, &session).await;
1132 }
1133 }
1134
1135 Ok(())
1136 }
1137
1138 pub async fn invite_code_redeemed(
1139 self: &Arc<Self>,
1140 inviter_id: UserId,
1141 invitee_id: UserId,
1142 ) -> Result<()> {
1143 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1144 if let Some(code) = &user.invite_code {
1145 let pool = self.connection_pool.lock();
1146 let invitee_contact = contact_for_user(invitee_id, false, &pool);
1147 for connection_id in pool.user_connection_ids(inviter_id) {
1148 self.peer.send(
1149 connection_id,
1150 proto::UpdateContacts {
1151 contacts: vec![invitee_contact.clone()],
1152 ..Default::default()
1153 },
1154 )?;
1155 self.peer.send(
1156 connection_id,
1157 proto::UpdateInviteInfo {
1158 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1159 count: user.invite_count as u32,
1160 },
1161 )?;
1162 }
1163 }
1164 }
1165 Ok(())
1166 }
1167
1168 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1169 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1170 if let Some(invite_code) = &user.invite_code {
1171 let pool = self.connection_pool.lock();
1172 for connection_id in pool.user_connection_ids(user_id) {
1173 self.peer.send(
1174 connection_id,
1175 proto::UpdateInviteInfo {
1176 url: format!(
1177 "{}{}",
1178 self.app_state.config.invite_link_prefix, invite_code
1179 ),
1180 count: user.invite_count as u32,
1181 },
1182 )?;
1183 }
1184 }
1185 }
1186 Ok(())
1187 }
1188
1189 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1190 ServerSnapshot {
1191 connection_pool: ConnectionPoolGuard {
1192 guard: self.connection_pool.lock(),
1193 _not_send: PhantomData,
1194 },
1195 peer: &self.peer,
1196 }
1197 }
1198}
1199
1200impl<'a> Deref for ConnectionPoolGuard<'a> {
1201 type Target = ConnectionPool;
1202
1203 fn deref(&self) -> &Self::Target {
1204 &self.guard
1205 }
1206}
1207
1208impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1209 fn deref_mut(&mut self) -> &mut Self::Target {
1210 &mut self.guard
1211 }
1212}
1213
1214impl<'a> Drop for ConnectionPoolGuard<'a> {
1215 fn drop(&mut self) {
1216 #[cfg(test)]
1217 self.check_invariants();
1218 }
1219}
1220
1221fn broadcast<F>(
1222 sender_id: Option<ConnectionId>,
1223 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1224 mut f: F,
1225) where
1226 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1227{
1228 for receiver_id in receiver_ids {
1229 if Some(receiver_id) != sender_id {
1230 if let Err(error) = f(receiver_id) {
1231 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1232 }
1233 }
1234 }
1235}
1236
1237pub struct ProtocolVersion(u32);
1238
1239impl Header for ProtocolVersion {
1240 fn name() -> &'static HeaderName {
1241 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1242 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1243 }
1244
1245 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1246 where
1247 Self: Sized,
1248 I: Iterator<Item = &'i axum::http::HeaderValue>,
1249 {
1250 let version = values
1251 .next()
1252 .ok_or_else(axum::headers::Error::invalid)?
1253 .to_str()
1254 .map_err(|_| axum::headers::Error::invalid())?
1255 .parse()
1256 .map_err(|_| axum::headers::Error::invalid())?;
1257 Ok(Self(version))
1258 }
1259
1260 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1261 values.extend([self.0.to_string().parse().unwrap()]);
1262 }
1263}
1264
1265pub struct AppVersionHeader(SemanticVersion);
1266impl Header for AppVersionHeader {
1267 fn name() -> &'static HeaderName {
1268 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1269 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1270 }
1271
1272 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1273 where
1274 Self: Sized,
1275 I: Iterator<Item = &'i axum::http::HeaderValue>,
1276 {
1277 let version = values
1278 .next()
1279 .ok_or_else(axum::headers::Error::invalid)?
1280 .to_str()
1281 .map_err(|_| axum::headers::Error::invalid())?
1282 .parse()
1283 .map_err(|_| axum::headers::Error::invalid())?;
1284 Ok(Self(version))
1285 }
1286
1287 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1288 values.extend([self.0.to_string().parse().unwrap()]);
1289 }
1290}
1291
1292pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1293 Router::new()
1294 .route("/rpc", get(handle_websocket_request))
1295 .layer(
1296 ServiceBuilder::new()
1297 .layer(Extension(server.app_state.clone()))
1298 .layer(middleware::from_fn(auth::validate_header)),
1299 )
1300 .route("/metrics", get(handle_metrics))
1301 .layer(Extension(server))
1302}
1303
1304pub async fn handle_websocket_request(
1305 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1306 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1307 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1308 Extension(server): Extension<Arc<Server>>,
1309 Extension(principal): Extension<Principal>,
1310 ws: WebSocketUpgrade,
1311) -> axum::response::Response {
1312 if protocol_version != rpc::PROTOCOL_VERSION {
1313 return (
1314 StatusCode::UPGRADE_REQUIRED,
1315 "client must be upgraded".to_string(),
1316 )
1317 .into_response();
1318 }
1319
1320 let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1321 return (
1322 StatusCode::UPGRADE_REQUIRED,
1323 "no version header found".to_string(),
1324 )
1325 .into_response();
1326 };
1327
1328 if !version.can_collaborate() {
1329 return (
1330 StatusCode::UPGRADE_REQUIRED,
1331 "client must be upgraded".to_string(),
1332 )
1333 .into_response();
1334 }
1335
1336 let socket_address = socket_address.to_string();
1337 ws.on_upgrade(move |socket| {
1338 let socket = socket
1339 .map_ok(to_tungstenite_message)
1340 .err_into()
1341 .with(|message| async move { Ok(to_axum_message(message)) });
1342 let connection = Connection::new(Box::pin(socket));
1343 async move {
1344 server
1345 .handle_connection(
1346 connection,
1347 socket_address,
1348 principal,
1349 version,
1350 None,
1351 Executor::Production,
1352 )
1353 .await;
1354 }
1355 })
1356}
1357
1358pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1359 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1360 let connections_metric = CONNECTIONS_METRIC
1361 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1362
1363 let connections = server
1364 .connection_pool
1365 .lock()
1366 .connections()
1367 .filter(|connection| !connection.admin)
1368 .count();
1369 connections_metric.set(connections as _);
1370
1371 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1372 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1373 register_int_gauge!(
1374 "shared_projects",
1375 "number of open projects with one or more guests"
1376 )
1377 .unwrap()
1378 });
1379
1380 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1381 shared_projects_metric.set(shared_projects as _);
1382
1383 let encoder = prometheus::TextEncoder::new();
1384 let metric_families = prometheus::gather();
1385 let encoded_metrics = encoder
1386 .encode_to_string(&metric_families)
1387 .map_err(|err| anyhow!("{}", err))?;
1388 Ok(encoded_metrics)
1389}
1390
/// Cleans up after a connection has gone away.
///
/// Immediately disconnects the peer, removes the connection from the pool and
/// reports the loss to the database (a database failure there goes through
/// `trace_err` instead of aborting). The remaining cleanup is deferred: it
/// runs only once `RECONNECT_TIMEOUT` has elapsed, and is skipped entirely if
/// `teardown` changes first.
///
/// Deferred cleanup for users: leave the room and channel buffers held by
/// this connection, decline a pending call when the user has no connection
/// left online, then refresh the user's contacts. For dev servers it is
/// delegated to `lost_dev_server_connection`.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls its arms in the order written, so the timeout
    // arm is checked before the teardown arm.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            match &session.principal {
                Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
                    let session = session.for_user().unwrap();

                    log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
                    // Failures of these two steps do not stop the cleanup.
                    leave_room_for_session(&session, session.connection_id).await.trace_err();
                    leave_channel_buffers_for_session(&session)
                        .await
                        .trace_err();

                    // Only decline a call when no connection of this user is
                    // online any more.
                    if !session
                        .connection_pool()
                        .await
                        .is_user_online(session.user_id())
                    {
                        let db = session.db().await;
                        if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                            room_updated(&room, &session.peer);
                        }
                    }

                    update_user_contacts(session.user_id(), &session).await?;
                },
                Principal::DevServer(_) => {
                    lost_dev_server_connection(&session.for_dev_server().unwrap()).await?;
                },
            }
        },
        // Teardown was signalled before the timeout: skip the deferred cleanup.
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1445
1446/// Acknowledges a ping from a client, used to keep the connection alive.
1447async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1448 response.send(proto::Ack {})?;
1449 Ok(())
1450}
1451
1452/// Creates a new room for calling (outside of channels)
1453async fn create_room(
1454 _request: proto::CreateRoom,
1455 response: Response<proto::CreateRoom>,
1456 session: UserSession,
1457) -> Result<()> {
1458 let live_kit_room = nanoid::nanoid!(30);
1459
1460 let live_kit_connection_info = util::maybe!(async {
1461 let live_kit = session.live_kit_client.as_ref();
1462 let live_kit = live_kit?;
1463 let user_id = session.user_id().to_string();
1464
1465 let token = live_kit
1466 .room_token(&live_kit_room, &user_id.to_string())
1467 .trace_err()?;
1468
1469 Some(proto::LiveKitConnectionInfo {
1470 server_url: live_kit.url().into(),
1471 token,
1472 can_publish: true,
1473 })
1474 })
1475 .await;
1476
1477 let room = session
1478 .db()
1479 .await
1480 .create_room(session.user_id(), session.connection_id, &live_kit_room)
1481 .await?;
1482
1483 response.send(proto::CreateRoomResponse {
1484 room: Some(room.clone()),
1485 live_kit_connection_info,
1486 })?;
1487
1488 update_user_contacts(session.user_id(), &session).await?;
1489 Ok(())
1490}
1491
1492/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1493async fn join_room(
1494 request: proto::JoinRoom,
1495 response: Response<proto::JoinRoom>,
1496 session: UserSession,
1497) -> Result<()> {
1498 let room_id = RoomId::from_proto(request.id);
1499
1500 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1501
1502 if let Some(channel_id) = channel_id {
1503 return join_channel_internal(channel_id, Box::new(response), session).await;
1504 }
1505
1506 let joined_room = {
1507 let room = session
1508 .db()
1509 .await
1510 .join_room(room_id, session.user_id(), session.connection_id)
1511 .await?;
1512 room_updated(&room.room, &session.peer);
1513 room.into_inner()
1514 };
1515
1516 for connection_id in session
1517 .connection_pool()
1518 .await
1519 .user_connection_ids(session.user_id())
1520 {
1521 session
1522 .peer
1523 .send(
1524 connection_id,
1525 proto::CallCanceled {
1526 room_id: room_id.to_proto(),
1527 },
1528 )
1529 .trace_err();
1530 }
1531
1532 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1533 if let Some(token) = live_kit
1534 .room_token(
1535 &joined_room.room.live_kit_room,
1536 &session.user_id().to_string(),
1537 )
1538 .trace_err()
1539 {
1540 Some(proto::LiveKitConnectionInfo {
1541 server_url: live_kit.url().into(),
1542 token,
1543 can_publish: true,
1544 })
1545 } else {
1546 None
1547 }
1548 } else {
1549 None
1550 };
1551
1552 response.send(proto::JoinRoomResponse {
1553 room: Some(joined_room.room),
1554 channel_id: None,
1555 live_kit_connection_info,
1556 })?;
1557
1558 update_user_contacts(session.user_id(), &session).await?;
1559 Ok(())
1560}
1561
1562/// Rejoin room is used to reconnect to a room after connection errors.
1563async fn rejoin_room(
1564 request: proto::RejoinRoom,
1565 response: Response<proto::RejoinRoom>,
1566 session: UserSession,
1567) -> Result<()> {
1568 let room;
1569 let channel;
1570 {
1571 let mut rejoined_room = session
1572 .db()
1573 .await
1574 .rejoin_room(request, session.user_id(), session.connection_id)
1575 .await?;
1576
1577 response.send(proto::RejoinRoomResponse {
1578 room: Some(rejoined_room.room.clone()),
1579 reshared_projects: rejoined_room
1580 .reshared_projects
1581 .iter()
1582 .map(|project| proto::ResharedProject {
1583 id: project.id.to_proto(),
1584 collaborators: project
1585 .collaborators
1586 .iter()
1587 .map(|collaborator| collaborator.to_proto())
1588 .collect(),
1589 })
1590 .collect(),
1591 rejoined_projects: rejoined_room
1592 .rejoined_projects
1593 .iter()
1594 .map(|rejoined_project| rejoined_project.to_proto())
1595 .collect(),
1596 })?;
1597 room_updated(&rejoined_room.room, &session.peer);
1598
1599 for project in &rejoined_room.reshared_projects {
1600 for collaborator in &project.collaborators {
1601 session
1602 .peer
1603 .send(
1604 collaborator.connection_id,
1605 proto::UpdateProjectCollaborator {
1606 project_id: project.id.to_proto(),
1607 old_peer_id: Some(project.old_connection_id.into()),
1608 new_peer_id: Some(session.connection_id.into()),
1609 },
1610 )
1611 .trace_err();
1612 }
1613
1614 broadcast(
1615 Some(session.connection_id),
1616 project
1617 .collaborators
1618 .iter()
1619 .map(|collaborator| collaborator.connection_id),
1620 |connection_id| {
1621 session.peer.forward_send(
1622 session.connection_id,
1623 connection_id,
1624 proto::UpdateProject {
1625 project_id: project.id.to_proto(),
1626 worktrees: project.worktrees.clone(),
1627 },
1628 )
1629 },
1630 );
1631 }
1632
1633 notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;
1634
1635 let rejoined_room = rejoined_room.into_inner();
1636
1637 room = rejoined_room.room;
1638 channel = rejoined_room.channel;
1639 }
1640
1641 if let Some(channel) = channel {
1642 channel_updated(
1643 &channel,
1644 &room,
1645 &session.peer,
1646 &*session.connection_pool().await,
1647 );
1648 }
1649
1650 update_user_contacts(session.user_id(), &session).await?;
1651 Ok(())
1652}
1653
1654fn notify_rejoined_projects(
1655 rejoined_projects: &mut Vec<RejoinedProject>,
1656 session: &UserSession,
1657) -> Result<()> {
1658 for project in rejoined_projects.iter() {
1659 for collaborator in &project.collaborators {
1660 session
1661 .peer
1662 .send(
1663 collaborator.connection_id,
1664 proto::UpdateProjectCollaborator {
1665 project_id: project.id.to_proto(),
1666 old_peer_id: Some(project.old_connection_id.into()),
1667 new_peer_id: Some(session.connection_id.into()),
1668 },
1669 )
1670 .trace_err();
1671 }
1672 }
1673
1674 for project in rejoined_projects {
1675 for worktree in mem::take(&mut project.worktrees) {
1676 #[cfg(any(test, feature = "test-support"))]
1677 const MAX_CHUNK_SIZE: usize = 2;
1678 #[cfg(not(any(test, feature = "test-support")))]
1679 const MAX_CHUNK_SIZE: usize = 256;
1680
1681 // Stream this worktree's entries.
1682 let message = proto::UpdateWorktree {
1683 project_id: project.id.to_proto(),
1684 worktree_id: worktree.id,
1685 abs_path: worktree.abs_path.clone(),
1686 root_name: worktree.root_name,
1687 updated_entries: worktree.updated_entries,
1688 removed_entries: worktree.removed_entries,
1689 scan_id: worktree.scan_id,
1690 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1691 updated_repositories: worktree.updated_repositories,
1692 removed_repositories: worktree.removed_repositories,
1693 };
1694 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1695 session.peer.send(session.connection_id, update.clone())?;
1696 }
1697
1698 // Stream this worktree's diagnostics.
1699 for summary in worktree.diagnostic_summaries {
1700 session.peer.send(
1701 session.connection_id,
1702 proto::UpdateDiagnosticSummary {
1703 project_id: project.id.to_proto(),
1704 worktree_id: worktree.id,
1705 summary: Some(summary),
1706 },
1707 )?;
1708 }
1709
1710 for settings_file in worktree.settings_files {
1711 session.peer.send(
1712 session.connection_id,
1713 proto::UpdateWorktreeSettings {
1714 project_id: project.id.to_proto(),
1715 worktree_id: worktree.id,
1716 path: settings_file.path,
1717 content: Some(settings_file.content),
1718 },
1719 )?;
1720 }
1721 }
1722
1723 for language_server in &project.language_servers {
1724 session.peer.send(
1725 session.connection_id,
1726 proto::UpdateLanguageServer {
1727 project_id: project.id.to_proto(),
1728 language_server_id: language_server.id,
1729 variant: Some(
1730 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1731 proto::LspDiskBasedDiagnosticsUpdated {},
1732 ),
1733 ),
1734 },
1735 )?;
1736 }
1737 }
1738 Ok(())
1739}
1740
1741/// leave room disconnects from the room.
1742async fn leave_room(
1743 _: proto::LeaveRoom,
1744 response: Response<proto::LeaveRoom>,
1745 session: UserSession,
1746) -> Result<()> {
1747 leave_room_for_session(&session, session.connection_id).await?;
1748 response.send(proto::Ack {})?;
1749 Ok(())
1750}
1751
1752/// Updates the permissions of someone else in the room.
1753async fn set_room_participant_role(
1754 request: proto::SetRoomParticipantRole,
1755 response: Response<proto::SetRoomParticipantRole>,
1756 session: UserSession,
1757) -> Result<()> {
1758 let user_id = UserId::from_proto(request.user_id);
1759 let role = ChannelRole::from(request.role());
1760
1761 let (live_kit_room, can_publish) = {
1762 let room = session
1763 .db()
1764 .await
1765 .set_room_participant_role(
1766 session.user_id(),
1767 RoomId::from_proto(request.room_id),
1768 user_id,
1769 role,
1770 )
1771 .await?;
1772
1773 let live_kit_room = room.live_kit_room.clone();
1774 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1775 room_updated(&room, &session.peer);
1776 (live_kit_room, can_publish)
1777 };
1778
1779 if let Some(live_kit) = session.live_kit_client.as_ref() {
1780 live_kit
1781 .update_participant(
1782 live_kit_room.clone(),
1783 request.user_id.to_string(),
1784 live_kit_server::proto::ParticipantPermission {
1785 can_subscribe: true,
1786 can_publish,
1787 can_publish_data: can_publish,
1788 hidden: false,
1789 recorder: false,
1790 },
1791 )
1792 .await
1793 .trace_err();
1794 }
1795
1796 response.send(proto::Ack {})?;
1797 Ok(())
1798}
1799
1800/// Call someone else into the current room
1801async fn call(
1802 request: proto::Call,
1803 response: Response<proto::Call>,
1804 session: UserSession,
1805) -> Result<()> {
1806 let room_id = RoomId::from_proto(request.room_id);
1807 let calling_user_id = session.user_id();
1808 let calling_connection_id = session.connection_id;
1809 let called_user_id = UserId::from_proto(request.called_user_id);
1810 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1811 if !session
1812 .db()
1813 .await
1814 .has_contact(calling_user_id, called_user_id)
1815 .await?
1816 {
1817 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1818 }
1819
1820 let incoming_call = {
1821 let (room, incoming_call) = &mut *session
1822 .db()
1823 .await
1824 .call(
1825 room_id,
1826 calling_user_id,
1827 calling_connection_id,
1828 called_user_id,
1829 initial_project_id,
1830 )
1831 .await?;
1832 room_updated(&room, &session.peer);
1833 mem::take(incoming_call)
1834 };
1835 update_user_contacts(called_user_id, &session).await?;
1836
1837 let mut calls = session
1838 .connection_pool()
1839 .await
1840 .user_connection_ids(called_user_id)
1841 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1842 .collect::<FuturesUnordered<_>>();
1843
1844 while let Some(call_response) = calls.next().await {
1845 match call_response.as_ref() {
1846 Ok(_) => {
1847 response.send(proto::Ack {})?;
1848 return Ok(());
1849 }
1850 Err(_) => {
1851 call_response.trace_err();
1852 }
1853 }
1854 }
1855
1856 {
1857 let room = session
1858 .db()
1859 .await
1860 .call_failed(room_id, called_user_id)
1861 .await?;
1862 room_updated(&room, &session.peer);
1863 }
1864 update_user_contacts(called_user_id, &session).await?;
1865
1866 Err(anyhow!("failed to ring user"))?
1867}
1868
1869/// Cancel an outgoing call.
1870async fn cancel_call(
1871 request: proto::CancelCall,
1872 response: Response<proto::CancelCall>,
1873 session: UserSession,
1874) -> Result<()> {
1875 let called_user_id = UserId::from_proto(request.called_user_id);
1876 let room_id = RoomId::from_proto(request.room_id);
1877 {
1878 let room = session
1879 .db()
1880 .await
1881 .cancel_call(room_id, session.connection_id, called_user_id)
1882 .await?;
1883 room_updated(&room, &session.peer);
1884 }
1885
1886 for connection_id in session
1887 .connection_pool()
1888 .await
1889 .user_connection_ids(called_user_id)
1890 {
1891 session
1892 .peer
1893 .send(
1894 connection_id,
1895 proto::CallCanceled {
1896 room_id: room_id.to_proto(),
1897 },
1898 )
1899 .trace_err();
1900 }
1901 response.send(proto::Ack {})?;
1902
1903 update_user_contacts(called_user_id, &session).await?;
1904 Ok(())
1905}
1906
1907/// Decline an incoming call.
1908async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1909 let room_id = RoomId::from_proto(message.room_id);
1910 {
1911 let room = session
1912 .db()
1913 .await
1914 .decline_call(Some(room_id), session.user_id())
1915 .await?
1916 .ok_or_else(|| anyhow!("failed to decline call"))?;
1917 room_updated(&room, &session.peer);
1918 }
1919
1920 for connection_id in session
1921 .connection_pool()
1922 .await
1923 .user_connection_ids(session.user_id())
1924 {
1925 session
1926 .peer
1927 .send(
1928 connection_id,
1929 proto::CallCanceled {
1930 room_id: room_id.to_proto(),
1931 },
1932 )
1933 .trace_err();
1934 }
1935 update_user_contacts(session.user_id(), &session).await?;
1936 Ok(())
1937}
1938
1939/// Updates other participants in the room with your current location.
1940async fn update_participant_location(
1941 request: proto::UpdateParticipantLocation,
1942 response: Response<proto::UpdateParticipantLocation>,
1943 session: UserSession,
1944) -> Result<()> {
1945 let room_id = RoomId::from_proto(request.room_id);
1946 let location = request
1947 .location
1948 .ok_or_else(|| anyhow!("invalid location"))?;
1949
1950 let db = session.db().await;
1951 let room = db
1952 .update_room_participant_location(room_id, session.connection_id, location)
1953 .await?;
1954
1955 room_updated(&room, &session.peer);
1956 response.send(proto::Ack {})?;
1957 Ok(())
1958}
1959
1960/// Share a project into the room.
1961async fn share_project(
1962 request: proto::ShareProject,
1963 response: Response<proto::ShareProject>,
1964 session: UserSession,
1965) -> Result<()> {
1966 let (project_id, room) = &*session
1967 .db()
1968 .await
1969 .share_project(
1970 RoomId::from_proto(request.room_id),
1971 session.connection_id,
1972 &request.worktrees,
1973 request
1974 .dev_server_project_id
1975 .map(|id| DevServerProjectId::from_proto(id)),
1976 )
1977 .await?;
1978 response.send(proto::ShareProjectResponse {
1979 project_id: project_id.to_proto(),
1980 })?;
1981 room_updated(&room, &session.peer);
1982
1983 Ok(())
1984}
1985
1986/// Unshare a project from the room.
1987async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1988 let project_id = ProjectId::from_proto(message.project_id);
1989 unshare_project_internal(
1990 project_id,
1991 session.connection_id,
1992 session.user_id(),
1993 &session,
1994 )
1995 .await
1996}
1997
1998async fn unshare_project_internal(
1999 project_id: ProjectId,
2000 connection_id: ConnectionId,
2001 user_id: Option<UserId>,
2002 session: &Session,
2003) -> Result<()> {
2004 let (room, guest_connection_ids) = &*session
2005 .db()
2006 .await
2007 .unshare_project(project_id, connection_id, user_id)
2008 .await?;
2009
2010 let message = proto::UnshareProject {
2011 project_id: project_id.to_proto(),
2012 };
2013
2014 broadcast(
2015 Some(connection_id),
2016 guest_connection_ids.iter().copied(),
2017 |conn_id| session.peer.send(conn_id, message.clone()),
2018 );
2019 if let Some(room) = room {
2020 room_updated(room, &session.peer);
2021 }
2022
2023 Ok(())
2024}
2025
2026/// DevServer makes a project available online
2027async fn share_dev_server_project(
2028 request: proto::ShareDevServerProject,
2029 response: Response<proto::ShareDevServerProject>,
2030 session: DevServerSession,
2031) -> Result<()> {
2032 let (dev_server_project, user_id, status) = session
2033 .db()
2034 .await
2035 .share_dev_server_project(
2036 DevServerProjectId::from_proto(request.dev_server_project_id),
2037 session.dev_server_id(),
2038 session.connection_id,
2039 &request.worktrees,
2040 )
2041 .await?;
2042 let Some(project_id) = dev_server_project.project_id else {
2043 return Err(anyhow!("failed to share remote project"))?;
2044 };
2045
2046 send_dev_server_projects_update(user_id, status, &session).await;
2047
2048 response.send(proto::ShareProjectResponse { project_id })?;
2049
2050 Ok(())
2051}
2052
2053/// Join someone elses shared project.
2054async fn join_project(
2055 request: proto::JoinProject,
2056 response: Response<proto::JoinProject>,
2057 session: UserSession,
2058) -> Result<()> {
2059 let project_id = ProjectId::from_proto(request.project_id);
2060
2061 tracing::info!(%project_id, "join project");
2062
2063 let db = session.db().await;
2064 let (project, replica_id) = &mut *db
2065 .join_project(project_id, session.connection_id, session.user_id())
2066 .await?;
2067 drop(db);
2068 tracing::info!(%project_id, "join remote project");
2069 join_project_internal(response, session, project, replica_id)
2070}
2071
2072trait JoinProjectInternalResponse {
2073 fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
2074}
2075impl JoinProjectInternalResponse for Response<proto::JoinProject> {
2076 fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
2077 Response::<proto::JoinProject>::send(self, result)
2078 }
2079}
2080impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
2081 fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
2082 Response::<proto::JoinHostedProject>::send(self, result)
2083 }
2084}
2085
2086fn join_project_internal(
2087 response: impl JoinProjectInternalResponse,
2088 session: UserSession,
2089 project: &mut Project,
2090 replica_id: &ReplicaId,
2091) -> Result<()> {
2092 let collaborators = project
2093 .collaborators
2094 .iter()
2095 .filter(|collaborator| collaborator.connection_id != session.connection_id)
2096 .map(|collaborator| collaborator.to_proto())
2097 .collect::<Vec<_>>();
2098 let project_id = project.id;
2099 let guest_user_id = session.user_id();
2100
2101 let worktrees = project
2102 .worktrees
2103 .iter()
2104 .map(|(id, worktree)| proto::WorktreeMetadata {
2105 id: *id,
2106 root_name: worktree.root_name.clone(),
2107 visible: worktree.visible,
2108 abs_path: worktree.abs_path.clone(),
2109 })
2110 .collect::<Vec<_>>();
2111
2112 let add_project_collaborator = proto::AddProjectCollaborator {
2113 project_id: project_id.to_proto(),
2114 collaborator: Some(proto::Collaborator {
2115 peer_id: Some(session.connection_id.into()),
2116 replica_id: replica_id.0 as u32,
2117 user_id: guest_user_id.to_proto(),
2118 }),
2119 };
2120
2121 for collaborator in &collaborators {
2122 session
2123 .peer
2124 .send(
2125 collaborator.peer_id.unwrap().into(),
2126 add_project_collaborator.clone(),
2127 )
2128 .trace_err();
2129 }
2130
2131 // First, we send the metadata associated with each worktree.
2132 response.send(proto::JoinProjectResponse {
2133 project_id: project.id.0 as u64,
2134 worktrees: worktrees.clone(),
2135 replica_id: replica_id.0 as u32,
2136 collaborators: collaborators.clone(),
2137 language_servers: project.language_servers.clone(),
2138 role: project.role.into(),
2139 dev_server_project_id: project
2140 .dev_server_project_id
2141 .map(|dev_server_project_id| dev_server_project_id.0 as u64),
2142 })?;
2143
2144 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2145 #[cfg(any(test, feature = "test-support"))]
2146 const MAX_CHUNK_SIZE: usize = 2;
2147 #[cfg(not(any(test, feature = "test-support")))]
2148 const MAX_CHUNK_SIZE: usize = 256;
2149
2150 // Stream this worktree's entries.
2151 let message = proto::UpdateWorktree {
2152 project_id: project_id.to_proto(),
2153 worktree_id,
2154 abs_path: worktree.abs_path.clone(),
2155 root_name: worktree.root_name,
2156 updated_entries: worktree.entries,
2157 removed_entries: Default::default(),
2158 scan_id: worktree.scan_id,
2159 is_last_update: worktree.scan_id == worktree.completed_scan_id,
2160 updated_repositories: worktree.repository_entries.into_values().collect(),
2161 removed_repositories: Default::default(),
2162 };
2163 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
2164 session.peer.send(session.connection_id, update.clone())?;
2165 }
2166
2167 // Stream this worktree's diagnostics.
2168 for summary in worktree.diagnostic_summaries {
2169 session.peer.send(
2170 session.connection_id,
2171 proto::UpdateDiagnosticSummary {
2172 project_id: project_id.to_proto(),
2173 worktree_id: worktree.id,
2174 summary: Some(summary),
2175 },
2176 )?;
2177 }
2178
2179 for settings_file in worktree.settings_files {
2180 session.peer.send(
2181 session.connection_id,
2182 proto::UpdateWorktreeSettings {
2183 project_id: project_id.to_proto(),
2184 worktree_id: worktree.id,
2185 path: settings_file.path,
2186 content: Some(settings_file.content),
2187 },
2188 )?;
2189 }
2190 }
2191
2192 for language_server in &project.language_servers {
2193 session.peer.send(
2194 session.connection_id,
2195 proto::UpdateLanguageServer {
2196 project_id: project_id.to_proto(),
2197 language_server_id: language_server.id,
2198 variant: Some(
2199 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2200 proto::LspDiskBasedDiagnosticsUpdated {},
2201 ),
2202 ),
2203 },
2204 )?;
2205 }
2206
2207 Ok(())
2208}
2209
2210/// Leave someone elses shared project.
2211async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2212 let sender_id = session.connection_id;
2213 let project_id = ProjectId::from_proto(request.project_id);
2214 let db = session.db().await;
2215 if db.is_hosted_project(project_id).await? {
2216 let project = db.leave_hosted_project(project_id, sender_id).await?;
2217 project_left(&project, &session);
2218 return Ok(());
2219 }
2220
2221 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2222 tracing::info!(
2223 %project_id,
2224 "leave project"
2225 );
2226
2227 project_left(&project, &session);
2228 if let Some(room) = room {
2229 room_updated(&room, &session.peer);
2230 }
2231
2232 Ok(())
2233}
2234
2235async fn join_hosted_project(
2236 request: proto::JoinHostedProject,
2237 response: Response<proto::JoinHostedProject>,
2238 session: UserSession,
2239) -> Result<()> {
2240 let (mut project, replica_id) = session
2241 .db()
2242 .await
2243 .join_hosted_project(
2244 ProjectId(request.project_id as i32),
2245 session.user_id(),
2246 session.connection_id,
2247 )
2248 .await?;
2249
2250 join_project_internal(response, session, &mut project, &replica_id)
2251}
2252
2253async fn create_dev_server_project(
2254 request: proto::CreateDevServerProject,
2255 response: Response<proto::CreateDevServerProject>,
2256 session: UserSession,
2257) -> Result<()> {
2258 let dev_server_id = DevServerId(request.dev_server_id as i32);
2259 let dev_server_connection_id = session
2260 .connection_pool()
2261 .await
2262 .dev_server_connection_id(dev_server_id);
2263 let Some(dev_server_connection_id) = dev_server_connection_id else {
2264 Err(ErrorCode::DevServerOffline
2265 .message("Cannot create a remote project when the dev server is offline".to_string())
2266 .anyhow())?
2267 };
2268
2269 let path = request.path.clone();
2270 //Check that the path exists on the dev server
2271 session
2272 .peer
2273 .forward_request(
2274 session.connection_id,
2275 dev_server_connection_id,
2276 proto::ValidateDevServerProjectRequest { path: path.clone() },
2277 )
2278 .await?;
2279
2280 let (dev_server_project, update) = session
2281 .db()
2282 .await
2283 .create_dev_server_project(
2284 DevServerId(request.dev_server_id as i32),
2285 &request.path,
2286 session.user_id(),
2287 )
2288 .await?;
2289
2290 let projects = session
2291 .db()
2292 .await
2293 .get_projects_for_dev_server(dev_server_project.dev_server_id)
2294 .await?;
2295
2296 session.peer.send(
2297 dev_server_connection_id,
2298 proto::DevServerInstructions { projects },
2299 )?;
2300
2301 send_dev_server_projects_update(session.user_id(), update, &session).await;
2302
2303 response.send(proto::CreateDevServerProjectResponse {
2304 dev_server_project: Some(dev_server_project.to_proto(None)),
2305 })?;
2306 Ok(())
2307}
2308
2309async fn create_dev_server(
2310 request: proto::CreateDevServer,
2311 response: Response<proto::CreateDevServer>,
2312 session: UserSession,
2313) -> Result<()> {
2314 let access_token = auth::random_token();
2315 let hashed_access_token = auth::hash_access_token(&access_token);
2316
2317 let (dev_server, status) = session
2318 .db()
2319 .await
2320 .create_dev_server(&request.name, &hashed_access_token, session.user_id())
2321 .await?;
2322
2323 send_dev_server_projects_update(session.user_id(), status, &session).await;
2324
2325 response.send(proto::CreateDevServerResponse {
2326 dev_server_id: dev_server.id.0 as u64,
2327 access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2328 name: request.name.clone(),
2329 })?;
2330 Ok(())
2331}
2332
2333async fn delete_dev_server(
2334 request: proto::DeleteDevServer,
2335 response: Response<proto::DeleteDevServer>,
2336 session: UserSession,
2337) -> Result<()> {
2338 let dev_server_id = DevServerId(request.dev_server_id as i32);
2339 let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2340 if dev_server.user_id != session.user_id() {
2341 return Err(anyhow!(ErrorCode::Forbidden))?;
2342 }
2343
2344 let connection_id = session
2345 .connection_pool()
2346 .await
2347 .dev_server_connection_id(dev_server_id);
2348 if let Some(connection_id) = connection_id {
2349 shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2350 session
2351 .peer
2352 .send(connection_id, proto::ShutdownDevServer {})?;
2353 }
2354
2355 let status = session
2356 .db()
2357 .await
2358 .delete_dev_server(dev_server_id, session.user_id())
2359 .await?;
2360
2361 send_dev_server_projects_update(session.user_id(), status, &session).await;
2362
2363 response.send(proto::Ack {})?;
2364 Ok(())
2365}
2366
2367async fn delete_dev_server_project(
2368 request: proto::DeleteDevServerProject,
2369 response: Response<proto::DeleteDevServerProject>,
2370 session: UserSession,
2371) -> Result<()> {
2372 let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2373 let dev_server_project = session
2374 .db()
2375 .await
2376 .get_dev_server_project(dev_server_project_id)
2377 .await?;
2378
2379 let dev_server = session
2380 .db()
2381 .await
2382 .get_dev_server(dev_server_project.dev_server_id)
2383 .await?;
2384 if dev_server.user_id != session.user_id() {
2385 return Err(anyhow!(ErrorCode::Forbidden))?;
2386 }
2387
2388 let dev_server_connection_id = session
2389 .connection_pool()
2390 .await
2391 .dev_server_connection_id(dev_server.id);
2392
2393 if let Some(dev_server_connection_id) = dev_server_connection_id {
2394 let project = session
2395 .db()
2396 .await
2397 .find_dev_server_project(dev_server_project_id)
2398 .await;
2399 if let Ok(project) = project {
2400 unshare_project_internal(
2401 project.id,
2402 dev_server_connection_id,
2403 Some(session.user_id()),
2404 &session,
2405 )
2406 .await?;
2407 }
2408 }
2409
2410 let (projects, status) = session
2411 .db()
2412 .await
2413 .delete_dev_server_project(dev_server_project_id, dev_server.id, session.user_id())
2414 .await?;
2415
2416 if let Some(dev_server_connection_id) = dev_server_connection_id {
2417 session.peer.send(
2418 dev_server_connection_id,
2419 proto::DevServerInstructions { projects },
2420 )?;
2421 }
2422
2423 send_dev_server_projects_update(session.user_id(), status, &session).await;
2424
2425 response.send(proto::Ack {})?;
2426 Ok(())
2427}
2428
2429async fn rejoin_dev_server_projects(
2430 request: proto::RejoinRemoteProjects,
2431 response: Response<proto::RejoinRemoteProjects>,
2432 session: UserSession,
2433) -> Result<()> {
2434 let mut rejoined_projects = {
2435 let db = session.db().await;
2436 db.rejoin_dev_server_projects(
2437 &request.rejoined_projects,
2438 session.user_id(),
2439 session.0.connection_id,
2440 )
2441 .await?
2442 };
2443 notify_rejoined_projects(&mut rejoined_projects, &session)?;
2444
2445 response.send(proto::RejoinRemoteProjectsResponse {
2446 rejoined_projects: rejoined_projects
2447 .into_iter()
2448 .map(|project| project.to_proto())
2449 .collect(),
2450 })
2451}
2452
2453async fn reconnect_dev_server(
2454 request: proto::ReconnectDevServer,
2455 response: Response<proto::ReconnectDevServer>,
2456 session: DevServerSession,
2457) -> Result<()> {
2458 let reshared_projects = {
2459 let db = session.db().await;
2460 db.reshare_dev_server_projects(
2461 &request.reshared_projects,
2462 session.dev_server_id(),
2463 session.0.connection_id,
2464 )
2465 .await?
2466 };
2467
2468 for project in &reshared_projects {
2469 for collaborator in &project.collaborators {
2470 session
2471 .peer
2472 .send(
2473 collaborator.connection_id,
2474 proto::UpdateProjectCollaborator {
2475 project_id: project.id.to_proto(),
2476 old_peer_id: Some(project.old_connection_id.into()),
2477 new_peer_id: Some(session.connection_id.into()),
2478 },
2479 )
2480 .trace_err();
2481 }
2482
2483 broadcast(
2484 Some(session.connection_id),
2485 project
2486 .collaborators
2487 .iter()
2488 .map(|collaborator| collaborator.connection_id),
2489 |connection_id| {
2490 session.peer.forward_send(
2491 session.connection_id,
2492 connection_id,
2493 proto::UpdateProject {
2494 project_id: project.id.to_proto(),
2495 worktrees: project.worktrees.clone(),
2496 },
2497 )
2498 },
2499 );
2500 }
2501
2502 response.send(proto::ReconnectDevServerResponse {
2503 reshared_projects: reshared_projects
2504 .iter()
2505 .map(|project| proto::ResharedProject {
2506 id: project.id.to_proto(),
2507 collaborators: project
2508 .collaborators
2509 .iter()
2510 .map(|collaborator| collaborator.to_proto())
2511 .collect(),
2512 })
2513 .collect(),
2514 })?;
2515
2516 Ok(())
2517}
2518
2519async fn shutdown_dev_server(
2520 _: proto::ShutdownDevServer,
2521 response: Response<proto::ShutdownDevServer>,
2522 session: DevServerSession,
2523) -> Result<()> {
2524 response.send(proto::Ack {})?;
2525 shutdown_dev_server_internal(session.dev_server_id(), session.connection_id, &session).await
2526}
2527
2528async fn shutdown_dev_server_internal(
2529 dev_server_id: DevServerId,
2530 connection_id: ConnectionId,
2531 session: &Session,
2532) -> Result<()> {
2533 let (dev_server_projects, dev_server) = {
2534 let db = session.db().await;
2535 let dev_server_projects = db.get_projects_for_dev_server(dev_server_id).await?;
2536 let dev_server = db.get_dev_server(dev_server_id).await?;
2537 (dev_server_projects, dev_server)
2538 };
2539
2540 for project_id in dev_server_projects.iter().filter_map(|p| p.project_id) {
2541 unshare_project_internal(
2542 ProjectId::from_proto(project_id),
2543 connection_id,
2544 None,
2545 session,
2546 )
2547 .await?;
2548 }
2549
2550 session
2551 .connection_pool()
2552 .await
2553 .set_dev_server_offline(dev_server_id);
2554
2555 let status = session
2556 .db()
2557 .await
2558 .dev_server_projects_update(dev_server.user_id)
2559 .await?;
2560 send_dev_server_projects_update(dev_server.user_id, status, &session).await;
2561
2562 Ok(())
2563}
2564
2565/// Updates other participants with changes to the project
2566async fn update_project(
2567 request: proto::UpdateProject,
2568 response: Response<proto::UpdateProject>,
2569 session: Session,
2570) -> Result<()> {
2571 let project_id = ProjectId::from_proto(request.project_id);
2572 let (room, guest_connection_ids) = &*session
2573 .db()
2574 .await
2575 .update_project(project_id, session.connection_id, &request.worktrees)
2576 .await?;
2577 broadcast(
2578 Some(session.connection_id),
2579 guest_connection_ids.iter().copied(),
2580 |connection_id| {
2581 session
2582 .peer
2583 .forward_send(session.connection_id, connection_id, request.clone())
2584 },
2585 );
2586 if let Some(room) = room {
2587 room_updated(&room, &session.peer);
2588 }
2589 response.send(proto::Ack {})?;
2590
2591 Ok(())
2592}
2593
2594/// Updates other participants with changes to the worktree
2595async fn update_worktree(
2596 request: proto::UpdateWorktree,
2597 response: Response<proto::UpdateWorktree>,
2598 session: Session,
2599) -> Result<()> {
2600 let guest_connection_ids = session
2601 .db()
2602 .await
2603 .update_worktree(&request, session.connection_id)
2604 .await?;
2605
2606 broadcast(
2607 Some(session.connection_id),
2608 guest_connection_ids.iter().copied(),
2609 |connection_id| {
2610 session
2611 .peer
2612 .forward_send(session.connection_id, connection_id, request.clone())
2613 },
2614 );
2615 response.send(proto::Ack {})?;
2616 Ok(())
2617}
2618
2619/// Updates other participants with changes to the diagnostics
2620async fn update_diagnostic_summary(
2621 message: proto::UpdateDiagnosticSummary,
2622 session: Session,
2623) -> Result<()> {
2624 let guest_connection_ids = session
2625 .db()
2626 .await
2627 .update_diagnostic_summary(&message, session.connection_id)
2628 .await?;
2629
2630 broadcast(
2631 Some(session.connection_id),
2632 guest_connection_ids.iter().copied(),
2633 |connection_id| {
2634 session
2635 .peer
2636 .forward_send(session.connection_id, connection_id, message.clone())
2637 },
2638 );
2639
2640 Ok(())
2641}
2642
2643/// Updates other participants with changes to the worktree settings
2644async fn update_worktree_settings(
2645 message: proto::UpdateWorktreeSettings,
2646 session: Session,
2647) -> Result<()> {
2648 let guest_connection_ids = session
2649 .db()
2650 .await
2651 .update_worktree_settings(&message, session.connection_id)
2652 .await?;
2653
2654 broadcast(
2655 Some(session.connection_id),
2656 guest_connection_ids.iter().copied(),
2657 |connection_id| {
2658 session
2659 .peer
2660 .forward_send(session.connection_id, connection_id, message.clone())
2661 },
2662 );
2663
2664 Ok(())
2665}
2666
2667/// Notify other participants that a language server has started.
2668async fn start_language_server(
2669 request: proto::StartLanguageServer,
2670 session: Session,
2671) -> Result<()> {
2672 let guest_connection_ids = session
2673 .db()
2674 .await
2675 .start_language_server(&request, session.connection_id)
2676 .await?;
2677
2678 broadcast(
2679 Some(session.connection_id),
2680 guest_connection_ids.iter().copied(),
2681 |connection_id| {
2682 session
2683 .peer
2684 .forward_send(session.connection_id, connection_id, request.clone())
2685 },
2686 );
2687 Ok(())
2688}
2689
2690/// Notify other participants that a language server has changed.
2691async fn update_language_server(
2692 request: proto::UpdateLanguageServer,
2693 session: Session,
2694) -> Result<()> {
2695 let project_id = ProjectId::from_proto(request.project_id);
2696 let project_connection_ids = session
2697 .db()
2698 .await
2699 .project_connection_ids(project_id, session.connection_id, true)
2700 .await?;
2701 broadcast(
2702 Some(session.connection_id),
2703 project_connection_ids.iter().copied(),
2704 |connection_id| {
2705 session
2706 .peer
2707 .forward_send(session.connection_id, connection_id, request.clone())
2708 },
2709 );
2710 Ok(())
2711}
2712
2713/// forward a project request to the host. These requests should be read only
2714/// as guests are allowed to send them.
2715async fn forward_read_only_project_request<T>(
2716 request: T,
2717 response: Response<T>,
2718 session: UserSession,
2719) -> Result<()>
2720where
2721 T: EntityMessage + RequestMessage,
2722{
2723 let project_id = ProjectId::from_proto(request.remote_entity_id());
2724 let host_connection_id = session
2725 .db()
2726 .await
2727 .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2728 .await?;
2729 let payload = session
2730 .peer
2731 .forward_request(session.connection_id, host_connection_id, request)
2732 .await?;
2733 response.send(payload)?;
2734 Ok(())
2735}
2736
2737/// forward a project request to the host. These requests are disallowed
2738/// for guests.
2739async fn forward_mutating_project_request<T>(
2740 request: T,
2741 response: Response<T>,
2742 session: UserSession,
2743) -> Result<()>
2744where
2745 T: EntityMessage + RequestMessage,
2746{
2747 let project_id = ProjectId::from_proto(request.remote_entity_id());
2748
2749 let host_connection_id = session
2750 .db()
2751 .await
2752 .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
2753 .await?;
2754 let payload = session
2755 .peer
2756 .forward_request(session.connection_id, host_connection_id, request)
2757 .await?;
2758 response.send(payload)?;
2759 Ok(())
2760}
2761
2762/// forward a project request to the host. These requests are disallowed
2763/// for guests.
2764async fn forward_versioned_mutating_project_request<T>(
2765 request: T,
2766 response: Response<T>,
2767 session: UserSession,
2768) -> Result<()>
2769where
2770 T: EntityMessage + RequestMessage + VersionedMessage,
2771{
2772 let project_id = ProjectId::from_proto(request.remote_entity_id());
2773
2774 let host_connection_id = session
2775 .db()
2776 .await
2777 .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
2778 .await?;
2779 if let Some(host_version) = session
2780 .connection_pool()
2781 .await
2782 .connection(host_connection_id)
2783 .map(|c| c.zed_version)
2784 {
2785 if let Some(min_required_version) = request.required_host_version() {
2786 if min_required_version > host_version {
2787 return Err(anyhow!(ErrorCode::RemoteUpgradeRequired
2788 .with_tag("required", &min_required_version.to_string())))?;
2789 }
2790 }
2791 }
2792
2793 let payload = session
2794 .peer
2795 .forward_request(session.connection_id, host_connection_id, request)
2796 .await?;
2797 response.send(payload)?;
2798 Ok(())
2799}
2800
2801/// Notify other participants that a new buffer has been created
2802async fn create_buffer_for_peer(
2803 request: proto::CreateBufferForPeer,
2804 session: Session,
2805) -> Result<()> {
2806 session
2807 .db()
2808 .await
2809 .check_user_is_project_host(
2810 ProjectId::from_proto(request.project_id),
2811 session.connection_id,
2812 )
2813 .await?;
2814 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2815 session
2816 .peer
2817 .forward_send(session.connection_id, peer_id.into(), request)?;
2818 Ok(())
2819}
2820
2821/// Notify other participants that a buffer has been updated. This is
2822/// allowed for guests as long as the update is limited to selections.
2823async fn update_buffer(
2824 request: proto::UpdateBuffer,
2825 response: Response<proto::UpdateBuffer>,
2826 session: Session,
2827) -> Result<()> {
2828 let project_id = ProjectId::from_proto(request.project_id);
2829 let mut capability = Capability::ReadOnly;
2830
2831 for op in request.operations.iter() {
2832 match op.variant {
2833 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2834 Some(_) => capability = Capability::ReadWrite,
2835 }
2836 }
2837
2838 let host = {
2839 let guard = session
2840 .db()
2841 .await
2842 .connections_for_buffer_update(
2843 project_id,
2844 session.principal_id(),
2845 session.connection_id,
2846 capability,
2847 )
2848 .await?;
2849
2850 let (host, guests) = &*guard;
2851
2852 broadcast(
2853 Some(session.connection_id),
2854 guests.clone(),
2855 |connection_id| {
2856 session
2857 .peer
2858 .forward_send(session.connection_id, connection_id, request.clone())
2859 },
2860 );
2861
2862 *host
2863 };
2864
2865 if host != session.connection_id {
2866 session
2867 .peer
2868 .forward_request(session.connection_id, host, request.clone())
2869 .await?;
2870 }
2871
2872 response.send(proto::Ack {})?;
2873 Ok(())
2874}
2875
2876/// Notify other participants that a project has been updated.
2877async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2878 request: T,
2879 session: Session,
2880) -> Result<()> {
2881 let project_id = ProjectId::from_proto(request.remote_entity_id());
2882 let project_connection_ids = session
2883 .db()
2884 .await
2885 .project_connection_ids(project_id, session.connection_id, false)
2886 .await?;
2887
2888 broadcast(
2889 Some(session.connection_id),
2890 project_connection_ids.iter().copied(),
2891 |connection_id| {
2892 session
2893 .peer
2894 .forward_send(session.connection_id, connection_id, request.clone())
2895 },
2896 );
2897 Ok(())
2898}
2899
2900/// Start following another user in a call.
2901async fn follow(
2902 request: proto::Follow,
2903 response: Response<proto::Follow>,
2904 session: UserSession,
2905) -> Result<()> {
2906 let room_id = RoomId::from_proto(request.room_id);
2907 let project_id = request.project_id.map(ProjectId::from_proto);
2908 let leader_id = request
2909 .leader_id
2910 .ok_or_else(|| anyhow!("invalid leader id"))?
2911 .into();
2912 let follower_id = session.connection_id;
2913
2914 session
2915 .db()
2916 .await
2917 .check_room_participants(room_id, leader_id, session.connection_id)
2918 .await?;
2919
2920 let response_payload = session
2921 .peer
2922 .forward_request(session.connection_id, leader_id, request)
2923 .await?;
2924 response.send(response_payload)?;
2925
2926 if let Some(project_id) = project_id {
2927 let room = session
2928 .db()
2929 .await
2930 .follow(room_id, project_id, leader_id, follower_id)
2931 .await?;
2932 room_updated(&room, &session.peer);
2933 }
2934
2935 Ok(())
2936}
2937
2938/// Stop following another user in a call.
2939async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
2940 let room_id = RoomId::from_proto(request.room_id);
2941 let project_id = request.project_id.map(ProjectId::from_proto);
2942 let leader_id = request
2943 .leader_id
2944 .ok_or_else(|| anyhow!("invalid leader id"))?
2945 .into();
2946 let follower_id = session.connection_id;
2947
2948 session
2949 .db()
2950 .await
2951 .check_room_participants(room_id, leader_id, session.connection_id)
2952 .await?;
2953
2954 session
2955 .peer
2956 .forward_send(session.connection_id, leader_id, request)?;
2957
2958 if let Some(project_id) = project_id {
2959 let room = session
2960 .db()
2961 .await
2962 .unfollow(room_id, project_id, leader_id, follower_id)
2963 .await?;
2964 room_updated(&room, &session.peer);
2965 }
2966
2967 Ok(())
2968}
2969
2970/// Notify everyone following you of your current location.
2971async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
2972 let room_id = RoomId::from_proto(request.room_id);
2973 let database = session.db.lock().await;
2974
2975 let connection_ids = if let Some(project_id) = request.project_id {
2976 let project_id = ProjectId::from_proto(project_id);
2977 database
2978 .project_connection_ids(project_id, session.connection_id, true)
2979 .await?
2980 } else {
2981 database
2982 .room_connection_ids(room_id, session.connection_id)
2983 .await?
2984 };
2985
2986 // For now, don't send view update messages back to that view's current leader.
2987 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2988 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2989 _ => None,
2990 });
2991
2992 for connection_id in connection_ids.iter().cloned() {
2993 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2994 session
2995 .peer
2996 .forward_send(session.connection_id, connection_id, request.clone())?;
2997 }
2998 }
2999 Ok(())
3000}
3001
3002/// Get public data about users.
3003async fn get_users(
3004 request: proto::GetUsers,
3005 response: Response<proto::GetUsers>,
3006 session: Session,
3007) -> Result<()> {
3008 let user_ids = request
3009 .user_ids
3010 .into_iter()
3011 .map(UserId::from_proto)
3012 .collect();
3013 let users = session
3014 .db()
3015 .await
3016 .get_users_by_ids(user_ids)
3017 .await?
3018 .into_iter()
3019 .map(|user| proto::User {
3020 id: user.id.to_proto(),
3021 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3022 github_login: user.github_login,
3023 })
3024 .collect();
3025 response.send(proto::UsersResponse { users })?;
3026 Ok(())
3027}
3028
3029/// Search for users (to invite) buy Github login
3030async fn fuzzy_search_users(
3031 request: proto::FuzzySearchUsers,
3032 response: Response<proto::FuzzySearchUsers>,
3033 session: UserSession,
3034) -> Result<()> {
3035 let query = request.query;
3036 let users = match query.len() {
3037 0 => vec![],
3038 1 | 2 => session
3039 .db()
3040 .await
3041 .get_user_by_github_login(&query)
3042 .await?
3043 .into_iter()
3044 .collect(),
3045 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
3046 };
3047 let users = users
3048 .into_iter()
3049 .filter(|user| user.id != session.user_id())
3050 .map(|user| proto::User {
3051 id: user.id.to_proto(),
3052 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3053 github_login: user.github_login,
3054 })
3055 .collect();
3056 response.send(proto::UsersResponse { users })?;
3057 Ok(())
3058}
3059
3060/// Send a contact request to another user.
3061async fn request_contact(
3062 request: proto::RequestContact,
3063 response: Response<proto::RequestContact>,
3064 session: UserSession,
3065) -> Result<()> {
3066 let requester_id = session.user_id();
3067 let responder_id = UserId::from_proto(request.responder_id);
3068 if requester_id == responder_id {
3069 return Err(anyhow!("cannot add yourself as a contact"))?;
3070 }
3071
3072 let notifications = session
3073 .db()
3074 .await
3075 .send_contact_request(requester_id, responder_id)
3076 .await?;
3077
3078 // Update outgoing contact requests of requester
3079 let mut update = proto::UpdateContacts::default();
3080 update.outgoing_requests.push(responder_id.to_proto());
3081 for connection_id in session
3082 .connection_pool()
3083 .await
3084 .user_connection_ids(requester_id)
3085 {
3086 session.peer.send(connection_id, update.clone())?;
3087 }
3088
3089 // Update incoming contact requests of responder
3090 let mut update = proto::UpdateContacts::default();
3091 update
3092 .incoming_requests
3093 .push(proto::IncomingContactRequest {
3094 requester_id: requester_id.to_proto(),
3095 });
3096 let connection_pool = session.connection_pool().await;
3097 for connection_id in connection_pool.user_connection_ids(responder_id) {
3098 session.peer.send(connection_id, update.clone())?;
3099 }
3100
3101 send_notifications(&connection_pool, &session.peer, notifications);
3102
3103 response.send(proto::Ack {})?;
3104 Ok(())
3105}
3106
3107/// Accept or decline a contact request
3108async fn respond_to_contact_request(
3109 request: proto::RespondToContactRequest,
3110 response: Response<proto::RespondToContactRequest>,
3111 session: UserSession,
3112) -> Result<()> {
3113 let responder_id = session.user_id();
3114 let requester_id = UserId::from_proto(request.requester_id);
3115 let db = session.db().await;
3116 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
3117 db.dismiss_contact_notification(responder_id, requester_id)
3118 .await?;
3119 } else {
3120 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
3121
3122 let notifications = db
3123 .respond_to_contact_request(responder_id, requester_id, accept)
3124 .await?;
3125 let requester_busy = db.is_user_busy(requester_id).await?;
3126 let responder_busy = db.is_user_busy(responder_id).await?;
3127
3128 let pool = session.connection_pool().await;
3129 // Update responder with new contact
3130 let mut update = proto::UpdateContacts::default();
3131 if accept {
3132 update
3133 .contacts
3134 .push(contact_for_user(requester_id, requester_busy, &pool));
3135 }
3136 update
3137 .remove_incoming_requests
3138 .push(requester_id.to_proto());
3139 for connection_id in pool.user_connection_ids(responder_id) {
3140 session.peer.send(connection_id, update.clone())?;
3141 }
3142
3143 // Update requester with new contact
3144 let mut update = proto::UpdateContacts::default();
3145 if accept {
3146 update
3147 .contacts
3148 .push(contact_for_user(responder_id, responder_busy, &pool));
3149 }
3150 update
3151 .remove_outgoing_requests
3152 .push(responder_id.to_proto());
3153
3154 for connection_id in pool.user_connection_ids(requester_id) {
3155 session.peer.send(connection_id, update.clone())?;
3156 }
3157
3158 send_notifications(&pool, &session.peer, notifications);
3159 }
3160
3161 response.send(proto::Ack {})?;
3162 Ok(())
3163}
3164
3165/// Remove a contact.
3166async fn remove_contact(
3167 request: proto::RemoveContact,
3168 response: Response<proto::RemoveContact>,
3169 session: UserSession,
3170) -> Result<()> {
3171 let requester_id = session.user_id();
3172 let responder_id = UserId::from_proto(request.user_id);
3173 let db = session.db().await;
3174 let (contact_accepted, deleted_notification_id) =
3175 db.remove_contact(requester_id, responder_id).await?;
3176
3177 let pool = session.connection_pool().await;
3178 // Update outgoing contact requests of requester
3179 let mut update = proto::UpdateContacts::default();
3180 if contact_accepted {
3181 update.remove_contacts.push(responder_id.to_proto());
3182 } else {
3183 update
3184 .remove_outgoing_requests
3185 .push(responder_id.to_proto());
3186 }
3187 for connection_id in pool.user_connection_ids(requester_id) {
3188 session.peer.send(connection_id, update.clone())?;
3189 }
3190
3191 // Update incoming contact requests of responder
3192 let mut update = proto::UpdateContacts::default();
3193 if contact_accepted {
3194 update.remove_contacts.push(requester_id.to_proto());
3195 } else {
3196 update
3197 .remove_incoming_requests
3198 .push(requester_id.to_proto());
3199 }
3200 for connection_id in pool.user_connection_ids(responder_id) {
3201 session.peer.send(connection_id, update.clone())?;
3202 if let Some(notification_id) = deleted_notification_id {
3203 session.peer.send(
3204 connection_id,
3205 proto::DeleteNotification {
3206 notification_id: notification_id.to_proto(),
3207 },
3208 )?;
3209 }
3210 }
3211
3212 response.send(proto::Ack {})?;
3213 Ok(())
3214}
3215
3216/// Creates a new channel.
3217async fn create_channel(
3218 request: proto::CreateChannel,
3219 response: Response<proto::CreateChannel>,
3220 session: UserSession,
3221) -> Result<()> {
3222 let db = session.db().await;
3223
3224 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
3225 let (channel, membership) = db
3226 .create_channel(&request.name, parent_id, session.user_id())
3227 .await?;
3228
3229 let root_id = channel.root_id();
3230 let channel = Channel::from_model(channel);
3231
3232 response.send(proto::CreateChannelResponse {
3233 channel: Some(channel.to_proto()),
3234 parent_id: request.parent_id,
3235 })?;
3236
3237 let mut connection_pool = session.connection_pool().await;
3238 if let Some(membership) = membership {
3239 connection_pool.subscribe_to_channel(
3240 membership.user_id,
3241 membership.channel_id,
3242 membership.role,
3243 );
3244 let update = proto::UpdateUserChannels {
3245 channel_memberships: vec![proto::ChannelMembership {
3246 channel_id: membership.channel_id.to_proto(),
3247 role: membership.role.into(),
3248 }],
3249 ..Default::default()
3250 };
3251 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3252 session.peer.send(connection_id, update.clone())?;
3253 }
3254 }
3255
3256 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3257 if !role.can_see_channel(channel.visibility) {
3258 continue;
3259 }
3260
3261 let update = proto::UpdateChannels {
3262 channels: vec![channel.to_proto()],
3263 ..Default::default()
3264 };
3265 session.peer.send(connection_id, update.clone())?;
3266 }
3267
3268 Ok(())
3269}
3270
3271/// Delete a channel
3272async fn delete_channel(
3273 request: proto::DeleteChannel,
3274 response: Response<proto::DeleteChannel>,
3275 session: UserSession,
3276) -> Result<()> {
3277 let db = session.db().await;
3278
3279 let channel_id = request.channel_id;
3280 let (root_channel, removed_channels) = db
3281 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3282 .await?;
3283 response.send(proto::Ack {})?;
3284
3285 // Notify members of removed channels
3286 let mut update = proto::UpdateChannels::default();
3287 update
3288 .delete_channels
3289 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3290
3291 let connection_pool = session.connection_pool().await;
3292 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3293 session.peer.send(connection_id, update.clone())?;
3294 }
3295
3296 Ok(())
3297}
3298
3299/// Invite someone to join a channel.
3300async fn invite_channel_member(
3301 request: proto::InviteChannelMember,
3302 response: Response<proto::InviteChannelMember>,
3303 session: UserSession,
3304) -> Result<()> {
3305 let db = session.db().await;
3306 let channel_id = ChannelId::from_proto(request.channel_id);
3307 let invitee_id = UserId::from_proto(request.user_id);
3308 let InviteMemberResult {
3309 channel,
3310 notifications,
3311 } = db
3312 .invite_channel_member(
3313 channel_id,
3314 invitee_id,
3315 session.user_id(),
3316 request.role().into(),
3317 )
3318 .await?;
3319
3320 let update = proto::UpdateChannels {
3321 channel_invitations: vec![channel.to_proto()],
3322 ..Default::default()
3323 };
3324
3325 let connection_pool = session.connection_pool().await;
3326 for connection_id in connection_pool.user_connection_ids(invitee_id) {
3327 session.peer.send(connection_id, update.clone())?;
3328 }
3329
3330 send_notifications(&connection_pool, &session.peer, notifications);
3331
3332 response.send(proto::Ack {})?;
3333 Ok(())
3334}
3335
3336/// remove someone from a channel
3337async fn remove_channel_member(
3338 request: proto::RemoveChannelMember,
3339 response: Response<proto::RemoveChannelMember>,
3340 session: UserSession,
3341) -> Result<()> {
3342 let db = session.db().await;
3343 let channel_id = ChannelId::from_proto(request.channel_id);
3344 let member_id = UserId::from_proto(request.user_id);
3345
3346 let RemoveChannelMemberResult {
3347 membership_update,
3348 notification_id,
3349 } = db
3350 .remove_channel_member(channel_id, member_id, session.user_id())
3351 .await?;
3352
3353 let mut connection_pool = session.connection_pool().await;
3354 notify_membership_updated(
3355 &mut connection_pool,
3356 membership_update,
3357 member_id,
3358 &session.peer,
3359 );
3360 for connection_id in connection_pool.user_connection_ids(member_id) {
3361 if let Some(notification_id) = notification_id {
3362 session
3363 .peer
3364 .send(
3365 connection_id,
3366 proto::DeleteNotification {
3367 notification_id: notification_id.to_proto(),
3368 },
3369 )
3370 .trace_err();
3371 }
3372 }
3373
3374 response.send(proto::Ack {})?;
3375 Ok(())
3376}
3377
3378/// Toggle the channel between public and private.
3379/// Care is taken to maintain the invariant that public channels only descend from public channels,
3380/// (though members-only channels can appear at any point in the hierarchy).
3381async fn set_channel_visibility(
3382 request: proto::SetChannelVisibility,
3383 response: Response<proto::SetChannelVisibility>,
3384 session: UserSession,
3385) -> Result<()> {
3386 let db = session.db().await;
3387 let channel_id = ChannelId::from_proto(request.channel_id);
3388 let visibility = request.visibility().into();
3389
3390 let channel_model = db
3391 .set_channel_visibility(channel_id, visibility, session.user_id())
3392 .await?;
3393 let root_id = channel_model.root_id();
3394 let channel = Channel::from_model(channel_model);
3395
3396 let mut connection_pool = session.connection_pool().await;
3397 for (user_id, role) in connection_pool
3398 .channel_user_ids(root_id)
3399 .collect::<Vec<_>>()
3400 .into_iter()
3401 {
3402 let update = if role.can_see_channel(channel.visibility) {
3403 connection_pool.subscribe_to_channel(user_id, channel_id, role);
3404 proto::UpdateChannels {
3405 channels: vec![channel.to_proto()],
3406 ..Default::default()
3407 }
3408 } else {
3409 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3410 proto::UpdateChannels {
3411 delete_channels: vec![channel.id.to_proto()],
3412 ..Default::default()
3413 }
3414 };
3415
3416 for connection_id in connection_pool.user_connection_ids(user_id) {
3417 session.peer.send(connection_id, update.clone())?;
3418 }
3419 }
3420
3421 response.send(proto::Ack {})?;
3422 Ok(())
3423}
3424
3425/// Alter the role for a user in the channel.
3426async fn set_channel_member_role(
3427 request: proto::SetChannelMemberRole,
3428 response: Response<proto::SetChannelMemberRole>,
3429 session: UserSession,
3430) -> Result<()> {
3431 let db = session.db().await;
3432 let channel_id = ChannelId::from_proto(request.channel_id);
3433 let member_id = UserId::from_proto(request.user_id);
3434 let result = db
3435 .set_channel_member_role(
3436 channel_id,
3437 session.user_id(),
3438 member_id,
3439 request.role().into(),
3440 )
3441 .await?;
3442
3443 match result {
3444 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3445 let mut connection_pool = session.connection_pool().await;
3446 notify_membership_updated(
3447 &mut connection_pool,
3448 membership_update,
3449 member_id,
3450 &session.peer,
3451 )
3452 }
3453 db::SetMemberRoleResult::InviteUpdated(channel) => {
3454 let update = proto::UpdateChannels {
3455 channel_invitations: vec![channel.to_proto()],
3456 ..Default::default()
3457 };
3458
3459 for connection_id in session
3460 .connection_pool()
3461 .await
3462 .user_connection_ids(member_id)
3463 {
3464 session.peer.send(connection_id, update.clone())?;
3465 }
3466 }
3467 }
3468
3469 response.send(proto::Ack {})?;
3470 Ok(())
3471}
3472
3473/// Change the name of a channel
3474async fn rename_channel(
3475 request: proto::RenameChannel,
3476 response: Response<proto::RenameChannel>,
3477 session: UserSession,
3478) -> Result<()> {
3479 let db = session.db().await;
3480 let channel_id = ChannelId::from_proto(request.channel_id);
3481 let channel_model = db
3482 .rename_channel(channel_id, session.user_id(), &request.name)
3483 .await?;
3484 let root_id = channel_model.root_id();
3485 let channel = Channel::from_model(channel_model);
3486
3487 response.send(proto::RenameChannelResponse {
3488 channel: Some(channel.to_proto()),
3489 })?;
3490
3491 let connection_pool = session.connection_pool().await;
3492 let update = proto::UpdateChannels {
3493 channels: vec![channel.to_proto()],
3494 ..Default::default()
3495 };
3496 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3497 if role.can_see_channel(channel.visibility) {
3498 session.peer.send(connection_id, update.clone())?;
3499 }
3500 }
3501
3502 Ok(())
3503}
3504
3505/// Move a channel to a new parent.
3506async fn move_channel(
3507 request: proto::MoveChannel,
3508 response: Response<proto::MoveChannel>,
3509 session: UserSession,
3510) -> Result<()> {
3511 let channel_id = ChannelId::from_proto(request.channel_id);
3512 let to = ChannelId::from_proto(request.to);
3513
3514 let (root_id, channels) = session
3515 .db()
3516 .await
3517 .move_channel(channel_id, to, session.user_id())
3518 .await?;
3519
3520 let connection_pool = session.connection_pool().await;
3521 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3522 let channels = channels
3523 .iter()
3524 .filter_map(|channel| {
3525 if role.can_see_channel(channel.visibility) {
3526 Some(channel.to_proto())
3527 } else {
3528 None
3529 }
3530 })
3531 .collect::<Vec<_>>();
3532 if channels.is_empty() {
3533 continue;
3534 }
3535
3536 let update = proto::UpdateChannels {
3537 channels,
3538 ..Default::default()
3539 };
3540
3541 session.peer.send(connection_id, update.clone())?;
3542 }
3543
3544 response.send(Ack {})?;
3545 Ok(())
3546}
3547
3548/// Get the list of channel members
3549async fn get_channel_members(
3550 request: proto::GetChannelMembers,
3551 response: Response<proto::GetChannelMembers>,
3552 session: UserSession,
3553) -> Result<()> {
3554 let db = session.db().await;
3555 let channel_id = ChannelId::from_proto(request.channel_id);
3556 let members = db
3557 .get_channel_participant_details(channel_id, session.user_id())
3558 .await?;
3559 response.send(proto::GetChannelMembersResponse { members })?;
3560 Ok(())
3561}
3562
3563/// Accept or decline a channel invitation.
3564async fn respond_to_channel_invite(
3565 request: proto::RespondToChannelInvite,
3566 response: Response<proto::RespondToChannelInvite>,
3567 session: UserSession,
3568) -> Result<()> {
3569 let db = session.db().await;
3570 let channel_id = ChannelId::from_proto(request.channel_id);
3571 let RespondToChannelInvite {
3572 membership_update,
3573 notifications,
3574 } = db
3575 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3576 .await?;
3577
3578 let mut connection_pool = session.connection_pool().await;
3579 if let Some(membership_update) = membership_update {
3580 notify_membership_updated(
3581 &mut connection_pool,
3582 membership_update,
3583 session.user_id(),
3584 &session.peer,
3585 );
3586 } else {
3587 let update = proto::UpdateChannels {
3588 remove_channel_invitations: vec![channel_id.to_proto()],
3589 ..Default::default()
3590 };
3591
3592 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3593 session.peer.send(connection_id, update.clone())?;
3594 }
3595 };
3596
3597 send_notifications(&connection_pool, &session.peer, notifications);
3598
3599 response.send(proto::Ack {})?;
3600
3601 Ok(())
3602}
3603
3604/// Join the channels' room
3605async fn join_channel(
3606 request: proto::JoinChannel,
3607 response: Response<proto::JoinChannel>,
3608 session: UserSession,
3609) -> Result<()> {
3610 let channel_id = ChannelId::from_proto(request.channel_id);
3611 join_channel_internal(channel_id, Box::new(response), session).await
3612}
3613
3614trait JoinChannelInternalResponse {
3615 fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
3616}
3617impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
3618 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3619 Response::<proto::JoinChannel>::send(self, result)
3620 }
3621}
3622impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
3623 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3624 Response::<proto::JoinRoom>::send(self, result)
3625 }
3626}
3627
3628async fn join_channel_internal(
3629 channel_id: ChannelId,
3630 response: Box<impl JoinChannelInternalResponse>,
3631 session: UserSession,
3632) -> Result<()> {
3633 let joined_room = {
3634 let mut db = session.db().await;
3635 // If zed quits without leaving the room, and the user re-opens zed before the
3636 // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
3637 // room they were in.
3638 if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
3639 tracing::info!(
3640 stale_connection_id = %connection,
3641 "cleaning up stale connection",
3642 );
3643 drop(db);
3644 leave_room_for_session(&session, connection).await?;
3645 db = session.db().await;
3646 }
3647
3648 let (joined_room, membership_updated, role) = db
3649 .join_channel(channel_id, session.user_id(), session.connection_id)
3650 .await?;
3651
3652 let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
3653 let (can_publish, token) = if role == ChannelRole::Guest {
3654 (
3655 false,
3656 live_kit
3657 .guest_token(
3658 &joined_room.room.live_kit_room,
3659 &session.user_id().to_string(),
3660 )
3661 .trace_err()?,
3662 )
3663 } else {
3664 (
3665 true,
3666 live_kit
3667 .room_token(
3668 &joined_room.room.live_kit_room,
3669 &session.user_id().to_string(),
3670 )
3671 .trace_err()?,
3672 )
3673 };
3674
3675 Some(LiveKitConnectionInfo {
3676 server_url: live_kit.url().into(),
3677 token,
3678 can_publish,
3679 })
3680 });
3681
3682 response.send(proto::JoinRoomResponse {
3683 room: Some(joined_room.room.clone()),
3684 channel_id: joined_room
3685 .channel
3686 .as_ref()
3687 .map(|channel| channel.id.to_proto()),
3688 live_kit_connection_info,
3689 })?;
3690
3691 let mut connection_pool = session.connection_pool().await;
3692 if let Some(membership_updated) = membership_updated {
3693 notify_membership_updated(
3694 &mut connection_pool,
3695 membership_updated,
3696 session.user_id(),
3697 &session.peer,
3698 );
3699 }
3700
3701 room_updated(&joined_room.room, &session.peer);
3702
3703 joined_room
3704 };
3705
3706 channel_updated(
3707 &joined_room
3708 .channel
3709 .ok_or_else(|| anyhow!("channel not returned"))?,
3710 &joined_room.room,
3711 &session.peer,
3712 &*session.connection_pool().await,
3713 );
3714
3715 update_user_contacts(session.user_id(), &session).await?;
3716 Ok(())
3717}
3718
3719/// Start editing the channel notes
3720async fn join_channel_buffer(
3721 request: proto::JoinChannelBuffer,
3722 response: Response<proto::JoinChannelBuffer>,
3723 session: UserSession,
3724) -> Result<()> {
3725 let db = session.db().await;
3726 let channel_id = ChannelId::from_proto(request.channel_id);
3727
3728 let open_response = db
3729 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3730 .await?;
3731
3732 let collaborators = open_response.collaborators.clone();
3733 response.send(open_response)?;
3734
3735 let update = UpdateChannelBufferCollaborators {
3736 channel_id: channel_id.to_proto(),
3737 collaborators: collaborators.clone(),
3738 };
3739 channel_buffer_updated(
3740 session.connection_id,
3741 collaborators
3742 .iter()
3743 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3744 &update,
3745 &session.peer,
3746 );
3747
3748 Ok(())
3749}
3750
3751/// Edit the channel notes
3752async fn update_channel_buffer(
3753 request: proto::UpdateChannelBuffer,
3754 session: UserSession,
3755) -> Result<()> {
3756 let db = session.db().await;
3757 let channel_id = ChannelId::from_proto(request.channel_id);
3758
3759 let (collaborators, non_collaborators, epoch, version) = db
3760 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3761 .await?;
3762
3763 channel_buffer_updated(
3764 session.connection_id,
3765 collaborators,
3766 &proto::UpdateChannelBuffer {
3767 channel_id: channel_id.to_proto(),
3768 operations: request.operations,
3769 },
3770 &session.peer,
3771 );
3772
3773 let pool = &*session.connection_pool().await;
3774
3775 broadcast(
3776 None,
3777 non_collaborators
3778 .iter()
3779 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3780 |peer_id| {
3781 session.peer.send(
3782 peer_id,
3783 proto::UpdateChannels {
3784 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3785 channel_id: channel_id.to_proto(),
3786 epoch: epoch as u64,
3787 version: version.clone(),
3788 }],
3789 ..Default::default()
3790 },
3791 )
3792 },
3793 );
3794
3795 Ok(())
3796}
3797
3798/// Rejoin the channel notes after a connection blip
3799async fn rejoin_channel_buffers(
3800 request: proto::RejoinChannelBuffers,
3801 response: Response<proto::RejoinChannelBuffers>,
3802 session: UserSession,
3803) -> Result<()> {
3804 let db = session.db().await;
3805 let buffers = db
3806 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3807 .await?;
3808
3809 for rejoined_buffer in &buffers {
3810 let collaborators_to_notify = rejoined_buffer
3811 .buffer
3812 .collaborators
3813 .iter()
3814 .filter_map(|c| Some(c.peer_id?.into()));
3815 channel_buffer_updated(
3816 session.connection_id,
3817 collaborators_to_notify,
3818 &proto::UpdateChannelBufferCollaborators {
3819 channel_id: rejoined_buffer.buffer.channel_id,
3820 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3821 },
3822 &session.peer,
3823 );
3824 }
3825
3826 response.send(proto::RejoinChannelBuffersResponse {
3827 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3828 })?;
3829
3830 Ok(())
3831}
3832
3833/// Stop editing the channel notes
3834async fn leave_channel_buffer(
3835 request: proto::LeaveChannelBuffer,
3836 response: Response<proto::LeaveChannelBuffer>,
3837 session: UserSession,
3838) -> Result<()> {
3839 let db = session.db().await;
3840 let channel_id = ChannelId::from_proto(request.channel_id);
3841
3842 let left_buffer = db
3843 .leave_channel_buffer(channel_id, session.connection_id)
3844 .await?;
3845
3846 response.send(Ack {})?;
3847
3848 channel_buffer_updated(
3849 session.connection_id,
3850 left_buffer.connections,
3851 &proto::UpdateChannelBufferCollaborators {
3852 channel_id: channel_id.to_proto(),
3853 collaborators: left_buffer.collaborators,
3854 },
3855 &session.peer,
3856 );
3857
3858 Ok(())
3859}
3860
3861fn channel_buffer_updated<T: EnvelopedMessage>(
3862 sender_id: ConnectionId,
3863 collaborators: impl IntoIterator<Item = ConnectionId>,
3864 message: &T,
3865 peer: &Peer,
3866) {
3867 broadcast(Some(sender_id), collaborators, |peer_id| {
3868 peer.send(peer_id, message.clone())
3869 });
3870}
3871
3872fn send_notifications(
3873 connection_pool: &ConnectionPool,
3874 peer: &Peer,
3875 notifications: db::NotificationBatch,
3876) {
3877 for (user_id, notification) in notifications {
3878 for connection_id in connection_pool.user_connection_ids(user_id) {
3879 if let Err(error) = peer.send(
3880 connection_id,
3881 proto::AddNotification {
3882 notification: Some(notification.clone()),
3883 },
3884 ) {
3885 tracing::error!(
3886 "failed to send notification to {:?} {}",
3887 connection_id,
3888 error
3889 );
3890 }
3891 }
3892 }
3893}
3894
3895/// Send a message to the channel
3896async fn send_channel_message(
3897 request: proto::SendChannelMessage,
3898 response: Response<proto::SendChannelMessage>,
3899 session: UserSession,
3900) -> Result<()> {
3901 // Validate the message body.
3902 let body = request.body.trim().to_string();
3903 if body.len() > MAX_MESSAGE_LEN {
3904 return Err(anyhow!("message is too long"))?;
3905 }
3906 if body.is_empty() {
3907 return Err(anyhow!("message can't be blank"))?;
3908 }
3909
3910 // TODO: adjust mentions if body is trimmed
3911
3912 let timestamp = OffsetDateTime::now_utc();
3913 let nonce = request
3914 .nonce
3915 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3916
3917 let channel_id = ChannelId::from_proto(request.channel_id);
3918 let CreatedChannelMessage {
3919 message_id,
3920 participant_connection_ids,
3921 channel_members,
3922 notifications,
3923 } = session
3924 .db()
3925 .await
3926 .create_channel_message(
3927 channel_id,
3928 session.user_id(),
3929 &body,
3930 &request.mentions,
3931 timestamp,
3932 nonce.clone().into(),
3933 match request.reply_to_message_id {
3934 Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
3935 None => None,
3936 },
3937 )
3938 .await?;
3939
3940 let message = proto::ChannelMessage {
3941 sender_id: session.user_id().to_proto(),
3942 id: message_id.to_proto(),
3943 body,
3944 mentions: request.mentions,
3945 timestamp: timestamp.unix_timestamp() as u64,
3946 nonce: Some(nonce),
3947 reply_to_message_id: request.reply_to_message_id,
3948 edited_at: None,
3949 };
3950 broadcast(
3951 Some(session.connection_id),
3952 participant_connection_ids,
3953 |connection| {
3954 session.peer.send(
3955 connection,
3956 proto::ChannelMessageSent {
3957 channel_id: channel_id.to_proto(),
3958 message: Some(message.clone()),
3959 },
3960 )
3961 },
3962 );
3963 response.send(proto::SendChannelMessageResponse {
3964 message: Some(message),
3965 })?;
3966
3967 let pool = &*session.connection_pool().await;
3968 broadcast(
3969 None,
3970 channel_members
3971 .iter()
3972 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3973 |peer_id| {
3974 session.peer.send(
3975 peer_id,
3976 proto::UpdateChannels {
3977 latest_channel_message_ids: vec![proto::ChannelMessageId {
3978 channel_id: channel_id.to_proto(),
3979 message_id: message_id.to_proto(),
3980 }],
3981 ..Default::default()
3982 },
3983 )
3984 },
3985 );
3986 send_notifications(pool, &session.peer, notifications);
3987
3988 Ok(())
3989}
3990
3991/// Delete a channel message
3992async fn remove_channel_message(
3993 request: proto::RemoveChannelMessage,
3994 response: Response<proto::RemoveChannelMessage>,
3995 session: UserSession,
3996) -> Result<()> {
3997 let channel_id = ChannelId::from_proto(request.channel_id);
3998 let message_id = MessageId::from_proto(request.message_id);
3999 let (connection_ids, existing_notification_ids) = session
4000 .db()
4001 .await
4002 .remove_channel_message(channel_id, message_id, session.user_id())
4003 .await?;
4004
4005 broadcast(
4006 Some(session.connection_id),
4007 connection_ids,
4008 move |connection| {
4009 session.peer.send(connection, request.clone())?;
4010
4011 for notification_id in &existing_notification_ids {
4012 session.peer.send(
4013 connection,
4014 proto::DeleteNotification {
4015 notification_id: (*notification_id).to_proto(),
4016 },
4017 )?;
4018 }
4019
4020 Ok(())
4021 },
4022 );
4023 response.send(proto::Ack {})?;
4024 Ok(())
4025}
4026
4027async fn update_channel_message(
4028 request: proto::UpdateChannelMessage,
4029 response: Response<proto::UpdateChannelMessage>,
4030 session: UserSession,
4031) -> Result<()> {
4032 let channel_id = ChannelId::from_proto(request.channel_id);
4033 let message_id = MessageId::from_proto(request.message_id);
4034 let updated_at = OffsetDateTime::now_utc();
4035 let UpdatedChannelMessage {
4036 message_id,
4037 participant_connection_ids,
4038 notifications,
4039 reply_to_message_id,
4040 timestamp,
4041 deleted_mention_notification_ids,
4042 updated_mention_notifications,
4043 } = session
4044 .db()
4045 .await
4046 .update_channel_message(
4047 channel_id,
4048 message_id,
4049 session.user_id(),
4050 request.body.as_str(),
4051 &request.mentions,
4052 updated_at,
4053 )
4054 .await?;
4055
4056 let nonce = request
4057 .nonce
4058 .clone()
4059 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4060
4061 let message = proto::ChannelMessage {
4062 sender_id: session.user_id().to_proto(),
4063 id: message_id.to_proto(),
4064 body: request.body.clone(),
4065 mentions: request.mentions.clone(),
4066 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
4067 nonce: Some(nonce),
4068 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
4069 edited_at: Some(updated_at.unix_timestamp() as u64),
4070 };
4071
4072 response.send(proto::Ack {})?;
4073
4074 let pool = &*session.connection_pool().await;
4075 broadcast(
4076 Some(session.connection_id),
4077 participant_connection_ids,
4078 |connection| {
4079 session.peer.send(
4080 connection,
4081 proto::ChannelMessageUpdate {
4082 channel_id: channel_id.to_proto(),
4083 message: Some(message.clone()),
4084 },
4085 )?;
4086
4087 for notification_id in &deleted_mention_notification_ids {
4088 session.peer.send(
4089 connection,
4090 proto::DeleteNotification {
4091 notification_id: (*notification_id).to_proto(),
4092 },
4093 )?;
4094 }
4095
4096 for notification in &updated_mention_notifications {
4097 session.peer.send(
4098 connection,
4099 proto::UpdateNotification {
4100 notification: Some(notification.clone()),
4101 },
4102 )?;
4103 }
4104
4105 Ok(())
4106 },
4107 );
4108
4109 send_notifications(pool, &session.peer, notifications);
4110
4111 Ok(())
4112}
4113
4114/// Mark a channel message as read
4115async fn acknowledge_channel_message(
4116 request: proto::AckChannelMessage,
4117 session: UserSession,
4118) -> Result<()> {
4119 let channel_id = ChannelId::from_proto(request.channel_id);
4120 let message_id = MessageId::from_proto(request.message_id);
4121 let notifications = session
4122 .db()
4123 .await
4124 .observe_channel_message(channel_id, session.user_id(), message_id)
4125 .await?;
4126 send_notifications(
4127 &*session.connection_pool().await,
4128 &session.peer,
4129 notifications,
4130 );
4131 Ok(())
4132}
4133
4134/// Mark a buffer version as synced
4135async fn acknowledge_buffer_version(
4136 request: proto::AckBufferOperation,
4137 session: UserSession,
4138) -> Result<()> {
4139 let buffer_id = BufferId::from_proto(request.buffer_id);
4140 session
4141 .db()
4142 .await
4143 .observe_buffer_version(
4144 buffer_id,
4145 session.user_id(),
4146 request.epoch as i32,
4147 &request.version,
4148 )
4149 .await?;
4150 Ok(())
4151}
4152
4153struct CompleteWithLanguageModelRateLimit;
4154
4155impl RateLimit for CompleteWithLanguageModelRateLimit {
4156 fn capacity() -> usize {
4157 std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4158 .ok()
4159 .and_then(|v| v.parse().ok())
4160 .unwrap_or(120) // Picked arbitrarily
4161 }
4162
4163 fn refill_duration() -> chrono::Duration {
4164 chrono::Duration::hours(1)
4165 }
4166
4167 fn db_name() -> &'static str {
4168 "complete-with-language-model"
4169 }
4170}
4171
4172async fn complete_with_language_model(
4173 request: proto::CompleteWithLanguageModel,
4174 response: StreamingResponse<proto::CompleteWithLanguageModel>,
4175 session: Session,
4176 open_ai_api_key: Option<Arc<str>>,
4177 google_ai_api_key: Option<Arc<str>>,
4178 anthropic_api_key: Option<Arc<str>>,
4179) -> Result<()> {
4180 let Some(session) = session.for_user() else {
4181 return Err(anyhow!("user not found"))?;
4182 };
4183 authorize_access_to_language_models(&session).await?;
4184 session
4185 .rate_limiter
4186 .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
4187 .await?;
4188
4189 if request.model.starts_with("gpt") {
4190 let api_key =
4191 open_ai_api_key.ok_or_else(|| anyhow!("no OpenAI API key configured on the server"))?;
4192 complete_with_open_ai(request, response, session, api_key).await?;
4193 } else if request.model.starts_with("gemini") {
4194 let api_key = google_ai_api_key
4195 .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
4196 complete_with_google_ai(request, response, session, api_key).await?;
4197 } else if request.model.starts_with("claude") {
4198 let api_key = anthropic_api_key
4199 .ok_or_else(|| anyhow!("no Anthropic AI API key configured on the server"))?;
4200 complete_with_anthropic(request, response, session, api_key).await?;
4201 }
4202
4203 Ok(())
4204}
4205
4206async fn complete_with_open_ai(
4207 request: proto::CompleteWithLanguageModel,
4208 response: StreamingResponse<proto::CompleteWithLanguageModel>,
4209 session: UserSession,
4210 api_key: Arc<str>,
4211) -> Result<()> {
4212 let mut completion_stream = open_ai::stream_completion(
4213 &session.http_client,
4214 OPEN_AI_API_URL,
4215 &api_key,
4216 crate::ai::language_model_request_to_open_ai(request)?,
4217 )
4218 .await
4219 .context("open_ai::stream_completion request failed within collab")?;
4220
4221 while let Some(event) = completion_stream.next().await {
4222 let event = event?;
4223 response.send(proto::LanguageModelResponse {
4224 choices: event
4225 .choices
4226 .into_iter()
4227 .map(|choice| proto::LanguageModelChoiceDelta {
4228 index: choice.index,
4229 delta: Some(proto::LanguageModelResponseMessage {
4230 role: choice.delta.role.map(|role| match role {
4231 open_ai::Role::User => LanguageModelRole::LanguageModelUser,
4232 open_ai::Role::Assistant => LanguageModelRole::LanguageModelAssistant,
4233 open_ai::Role::System => LanguageModelRole::LanguageModelSystem,
4234 open_ai::Role::Tool => LanguageModelRole::LanguageModelTool,
4235 } as i32),
4236 content: choice.delta.content,
4237 tool_calls: choice
4238 .delta
4239 .tool_calls
4240 .into_iter()
4241 .map(|delta| proto::ToolCallDelta {
4242 index: delta.index as u32,
4243 id: delta.id,
4244 variant: match delta.function {
4245 Some(function) => {
4246 let name = function.name;
4247 let arguments = function.arguments;
4248
4249 Some(proto::tool_call_delta::Variant::Function(
4250 proto::tool_call_delta::FunctionCallDelta {
4251 name,
4252 arguments,
4253 },
4254 ))
4255 }
4256 None => None,
4257 },
4258 })
4259 .collect(),
4260 }),
4261 finish_reason: choice.finish_reason,
4262 })
4263 .collect(),
4264 })?;
4265 }
4266
4267 Ok(())
4268}
4269
4270async fn complete_with_google_ai(
4271 request: proto::CompleteWithLanguageModel,
4272 response: StreamingResponse<proto::CompleteWithLanguageModel>,
4273 session: UserSession,
4274 api_key: Arc<str>,
4275) -> Result<()> {
4276 let mut stream = google_ai::stream_generate_content(
4277 &session.http_client,
4278 google_ai::API_URL,
4279 api_key.as_ref(),
4280 crate::ai::language_model_request_to_google_ai(request)?,
4281 )
4282 .await
4283 .context("google_ai::stream_generate_content request failed")?;
4284
4285 while let Some(event) = stream.next().await {
4286 let event = event?;
4287 response.send(proto::LanguageModelResponse {
4288 choices: event
4289 .candidates
4290 .unwrap_or_default()
4291 .into_iter()
4292 .map(|candidate| proto::LanguageModelChoiceDelta {
4293 index: candidate.index as u32,
4294 delta: Some(proto::LanguageModelResponseMessage {
4295 role: Some(match candidate.content.role {
4296 google_ai::Role::User => LanguageModelRole::LanguageModelUser,
4297 google_ai::Role::Model => LanguageModelRole::LanguageModelAssistant,
4298 } as i32),
4299 content: Some(
4300 candidate
4301 .content
4302 .parts
4303 .into_iter()
4304 .filter_map(|part| match part {
4305 google_ai::Part::TextPart(part) => Some(part.text),
4306 google_ai::Part::InlineDataPart(_) => None,
4307 })
4308 .collect(),
4309 ),
4310 // Tool calls are not supported for Google
4311 tool_calls: Vec::new(),
4312 }),
4313 finish_reason: candidate.finish_reason.map(|reason| reason.to_string()),
4314 })
4315 .collect(),
4316 })?;
4317 }
4318
4319 Ok(())
4320}
4321
4322async fn complete_with_anthropic(
4323 request: proto::CompleteWithLanguageModel,
4324 response: StreamingResponse<proto::CompleteWithLanguageModel>,
4325 session: UserSession,
4326 api_key: Arc<str>,
4327) -> Result<()> {
4328 let model = anthropic::Model::from_id(&request.model)?;
4329
4330 let mut system_message = String::new();
4331 let messages = request
4332 .messages
4333 .into_iter()
4334 .filter_map(|message| {
4335 match message.role() {
4336 LanguageModelRole::LanguageModelUser => Some(anthropic::RequestMessage {
4337 role: anthropic::Role::User,
4338 content: message.content,
4339 }),
4340 LanguageModelRole::LanguageModelAssistant => Some(anthropic::RequestMessage {
4341 role: anthropic::Role::Assistant,
4342 content: message.content,
4343 }),
4344 // Anthropic's API breaks system instructions out as a separate field rather
4345 // than having a system message role.
4346 LanguageModelRole::LanguageModelSystem => {
4347 if !system_message.is_empty() {
4348 system_message.push_str("\n\n");
4349 }
4350 system_message.push_str(&message.content);
4351
4352 None
4353 }
4354 // We don't yet support tool calls for Anthropic
4355 LanguageModelRole::LanguageModelTool => None,
4356 }
4357 })
4358 .collect();
4359
4360 let mut stream = anthropic::stream_completion(
4361 &session.http_client,
4362 "https://api.anthropic.com",
4363 &api_key,
4364 anthropic::Request {
4365 model,
4366 messages,
4367 stream: true,
4368 system: system_message,
4369 max_tokens: 4092,
4370 },
4371 )
4372 .await?;
4373
4374 let mut current_role = proto::LanguageModelRole::LanguageModelAssistant;
4375
4376 while let Some(event) = stream.next().await {
4377 let event = event?;
4378
4379 match event {
4380 anthropic::ResponseEvent::MessageStart { message } => {
4381 if let Some(role) = message.role {
4382 if role == "assistant" {
4383 current_role = proto::LanguageModelRole::LanguageModelAssistant;
4384 } else if role == "user" {
4385 current_role = proto::LanguageModelRole::LanguageModelUser;
4386 }
4387 }
4388 }
4389 anthropic::ResponseEvent::ContentBlockStart { content_block, .. } => {
4390 match content_block {
4391 anthropic::ContentBlock::Text { text } => {
4392 if !text.is_empty() {
4393 response.send(proto::LanguageModelResponse {
4394 choices: vec![proto::LanguageModelChoiceDelta {
4395 index: 0,
4396 delta: Some(proto::LanguageModelResponseMessage {
4397 role: Some(current_role as i32),
4398 content: Some(text),
4399 tool_calls: Vec::new(),
4400 }),
4401 finish_reason: None,
4402 }],
4403 })?;
4404 }
4405 }
4406 }
4407 }
4408 anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => match delta {
4409 anthropic::TextDelta::TextDelta { text } => {
4410 response.send(proto::LanguageModelResponse {
4411 choices: vec![proto::LanguageModelChoiceDelta {
4412 index: 0,
4413 delta: Some(proto::LanguageModelResponseMessage {
4414 role: Some(current_role as i32),
4415 content: Some(text),
4416 tool_calls: Vec::new(),
4417 }),
4418 finish_reason: None,
4419 }],
4420 })?;
4421 }
4422 },
4423 anthropic::ResponseEvent::MessageDelta { delta, .. } => {
4424 if let Some(stop_reason) = delta.stop_reason {
4425 response.send(proto::LanguageModelResponse {
4426 choices: vec![proto::LanguageModelChoiceDelta {
4427 index: 0,
4428 delta: None,
4429 finish_reason: Some(stop_reason),
4430 }],
4431 })?;
4432 }
4433 }
4434 anthropic::ResponseEvent::ContentBlockStop { .. } => {}
4435 anthropic::ResponseEvent::MessageStop {} => {}
4436 anthropic::ResponseEvent::Ping {} => {}
4437 }
4438 }
4439
4440 Ok(())
4441}
4442
4443struct CountTokensWithLanguageModelRateLimit;
4444
4445impl RateLimit for CountTokensWithLanguageModelRateLimit {
4446 fn capacity() -> usize {
4447 std::env::var("COUNT_TOKENS_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4448 .ok()
4449 .and_then(|v| v.parse().ok())
4450 .unwrap_or(600) // Picked arbitrarily
4451 }
4452
4453 fn refill_duration() -> chrono::Duration {
4454 chrono::Duration::hours(1)
4455 }
4456
4457 fn db_name() -> &'static str {
4458 "count-tokens-with-language-model"
4459 }
4460}
4461
4462async fn count_tokens_with_language_model(
4463 request: proto::CountTokensWithLanguageModel,
4464 response: Response<proto::CountTokensWithLanguageModel>,
4465 session: UserSession,
4466 google_ai_api_key: Option<Arc<str>>,
4467) -> Result<()> {
4468 authorize_access_to_language_models(&session).await?;
4469
4470 if !request.model.starts_with("gemini") {
4471 return Err(anyhow!(
4472 "counting tokens for model: {:?} is not supported",
4473 request.model
4474 ))?;
4475 }
4476
4477 session
4478 .rate_limiter
4479 .check::<CountTokensWithLanguageModelRateLimit>(session.user_id())
4480 .await?;
4481
4482 let api_key = google_ai_api_key
4483 .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
4484 let tokens_response = google_ai::count_tokens(
4485 &session.http_client,
4486 google_ai::API_URL,
4487 &api_key,
4488 crate::ai::count_tokens_request_to_google_ai(request)?,
4489 )
4490 .await?;
4491 response.send(proto::CountTokensResponse {
4492 token_count: tokens_response.total_tokens as u32,
4493 })?;
4494 Ok(())
4495}
4496
4497struct ComputeEmbeddingsRateLimit;
4498
4499impl RateLimit for ComputeEmbeddingsRateLimit {
4500 fn capacity() -> usize {
4501 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4502 .ok()
4503 .and_then(|v| v.parse().ok())
4504 .unwrap_or(5000) // Picked arbitrarily
4505 }
4506
4507 fn refill_duration() -> chrono::Duration {
4508 chrono::Duration::hours(1)
4509 }
4510
4511 fn db_name() -> &'static str {
4512 "compute-embeddings"
4513 }
4514}
4515
4516async fn compute_embeddings(
4517 request: proto::ComputeEmbeddings,
4518 response: Response<proto::ComputeEmbeddings>,
4519 session: UserSession,
4520 api_key: Option<Arc<str>>,
4521) -> Result<()> {
4522 let api_key = api_key.context("no OpenAI API key configured on the server")?;
4523 authorize_access_to_language_models(&session).await?;
4524
4525 session
4526 .rate_limiter
4527 .check::<ComputeEmbeddingsRateLimit>(session.user_id())
4528 .await?;
4529
4530 let embeddings = match request.model.as_str() {
4531 "openai/text-embedding-3-small" => {
4532 open_ai::embed(
4533 &session.http_client,
4534 OPEN_AI_API_URL,
4535 &api_key,
4536 OpenAiEmbeddingModel::TextEmbedding3Small,
4537 request.texts.iter().map(|text| text.as_str()),
4538 )
4539 .await?
4540 }
4541 provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
4542 };
4543
4544 let embeddings = request
4545 .texts
4546 .iter()
4547 .map(|text| {
4548 let mut hasher = sha2::Sha256::new();
4549 hasher.update(text.as_bytes());
4550 let result = hasher.finalize();
4551 result.to_vec()
4552 })
4553 .zip(
4554 embeddings
4555 .data
4556 .into_iter()
4557 .map(|embedding| embedding.embedding),
4558 )
4559 .collect::<HashMap<_, _>>();
4560
4561 let db = session.db().await;
4562 db.save_embeddings(&request.model, &embeddings)
4563 .await
4564 .context("failed to save embeddings")
4565 .trace_err();
4566
4567 response.send(proto::ComputeEmbeddingsResponse {
4568 embeddings: embeddings
4569 .into_iter()
4570 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4571 .collect(),
4572 })?;
4573 Ok(())
4574}
4575
4576async fn get_cached_embeddings(
4577 request: proto::GetCachedEmbeddings,
4578 response: Response<proto::GetCachedEmbeddings>,
4579 session: UserSession,
4580) -> Result<()> {
4581 authorize_access_to_language_models(&session).await?;
4582
4583 let db = session.db().await;
4584 let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
4585
4586 response.send(proto::GetCachedEmbeddingsResponse {
4587 embeddings: embeddings
4588 .into_iter()
4589 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4590 .collect(),
4591 })?;
4592 Ok(())
4593}
4594
4595async fn authorize_access_to_language_models(session: &UserSession) -> Result<(), Error> {
4596 let db = session.db().await;
4597 let flags = db.get_user_flags(session.user_id()).await?;
4598 if flags.iter().any(|flag| flag == "language-models") {
4599 Ok(())
4600 } else {
4601 Err(anyhow!("permission denied"))?
4602 }
4603}
4604
4605/// Start receiving chat updates for a channel
4606async fn join_channel_chat(
4607 request: proto::JoinChannelChat,
4608 response: Response<proto::JoinChannelChat>,
4609 session: UserSession,
4610) -> Result<()> {
4611 let channel_id = ChannelId::from_proto(request.channel_id);
4612
4613 let db = session.db().await;
4614 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4615 .await?;
4616 let messages = db
4617 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4618 .await?;
4619 response.send(proto::JoinChannelChatResponse {
4620 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4621 messages,
4622 })?;
4623 Ok(())
4624}
4625
4626/// Stop receiving chat updates for a channel
4627async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
4628 let channel_id = ChannelId::from_proto(request.channel_id);
4629 session
4630 .db()
4631 .await
4632 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4633 .await?;
4634 Ok(())
4635}
4636
4637/// Retrieve the chat history for a channel
4638async fn get_channel_messages(
4639 request: proto::GetChannelMessages,
4640 response: Response<proto::GetChannelMessages>,
4641 session: UserSession,
4642) -> Result<()> {
4643 let channel_id = ChannelId::from_proto(request.channel_id);
4644 let messages = session
4645 .db()
4646 .await
4647 .get_channel_messages(
4648 channel_id,
4649 session.user_id(),
4650 MESSAGE_COUNT_PER_PAGE,
4651 Some(MessageId::from_proto(request.before_message_id)),
4652 )
4653 .await?;
4654 response.send(proto::GetChannelMessagesResponse {
4655 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4656 messages,
4657 })?;
4658 Ok(())
4659}
4660
4661/// Retrieve specific chat messages
4662async fn get_channel_messages_by_id(
4663 request: proto::GetChannelMessagesById,
4664 response: Response<proto::GetChannelMessagesById>,
4665 session: UserSession,
4666) -> Result<()> {
4667 let message_ids = request
4668 .message_ids
4669 .iter()
4670 .map(|id| MessageId::from_proto(*id))
4671 .collect::<Vec<_>>();
4672 let messages = session
4673 .db()
4674 .await
4675 .get_channel_messages_by_id(session.user_id(), &message_ids)
4676 .await?;
4677 response.send(proto::GetChannelMessagesResponse {
4678 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4679 messages,
4680 })?;
4681 Ok(())
4682}
4683
4684/// Retrieve the current users notifications
4685async fn get_notifications(
4686 request: proto::GetNotifications,
4687 response: Response<proto::GetNotifications>,
4688 session: UserSession,
4689) -> Result<()> {
4690 let notifications = session
4691 .db()
4692 .await
4693 .get_notifications(
4694 session.user_id(),
4695 NOTIFICATION_COUNT_PER_PAGE,
4696 request
4697 .before_id
4698 .map(|id| db::NotificationId::from_proto(id)),
4699 )
4700 .await?;
4701 response.send(proto::GetNotificationsResponse {
4702 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4703 notifications,
4704 })?;
4705 Ok(())
4706}
4707
4708/// Mark notifications as read
4709async fn mark_notification_as_read(
4710 request: proto::MarkNotificationRead,
4711 response: Response<proto::MarkNotificationRead>,
4712 session: UserSession,
4713) -> Result<()> {
4714 let database = &session.db().await;
4715 let notifications = database
4716 .mark_notification_as_read_by_id(
4717 session.user_id(),
4718 NotificationId::from_proto(request.notification_id),
4719 )
4720 .await?;
4721 send_notifications(
4722 &*session.connection_pool().await,
4723 &session.peer,
4724 notifications,
4725 );
4726 response.send(proto::Ack {})?;
4727 Ok(())
4728}
4729
4730/// Get the current users information
4731async fn get_private_user_info(
4732 _request: proto::GetPrivateUserInfo,
4733 response: Response<proto::GetPrivateUserInfo>,
4734 session: UserSession,
4735) -> Result<()> {
4736 let db = session.db().await;
4737
4738 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4739 let user = db
4740 .get_user_by_id(session.user_id())
4741 .await?
4742 .ok_or_else(|| anyhow!("user not found"))?;
4743 let flags = db.get_user_flags(session.user_id()).await?;
4744
4745 response.send(proto::GetPrivateUserInfoResponse {
4746 metrics_id,
4747 staff: user.admin,
4748 flags,
4749 })?;
4750 Ok(())
4751}
4752
4753fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
4754 match message {
4755 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4756 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4757 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4758 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4759 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4760 code: frame.code.into(),
4761 reason: frame.reason,
4762 })),
4763 }
4764}
4765
4766fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4767 match message {
4768 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4769 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4770 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4771 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4772 AxumMessage::Close(frame) => {
4773 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4774 code: frame.code.into(),
4775 reason: frame.reason,
4776 }))
4777 }
4778 }
4779}
4780
4781fn notify_membership_updated(
4782 connection_pool: &mut ConnectionPool,
4783 result: MembershipUpdated,
4784 user_id: UserId,
4785 peer: &Peer,
4786) {
4787 for membership in &result.new_channels.channel_memberships {
4788 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4789 }
4790 for channel_id in &result.removed_channels {
4791 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4792 }
4793
4794 let user_channels_update = proto::UpdateUserChannels {
4795 channel_memberships: result
4796 .new_channels
4797 .channel_memberships
4798 .iter()
4799 .map(|cm| proto::ChannelMembership {
4800 channel_id: cm.channel_id.to_proto(),
4801 role: cm.role.into(),
4802 })
4803 .collect(),
4804 ..Default::default()
4805 };
4806
4807 let mut update = build_channels_update(result.new_channels, vec![]);
4808 update.delete_channels = result
4809 .removed_channels
4810 .into_iter()
4811 .map(|id| id.to_proto())
4812 .collect();
4813 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4814
4815 for connection_id in connection_pool.user_connection_ids(user_id) {
4816 peer.send(connection_id, user_channels_update.clone())
4817 .trace_err();
4818 peer.send(connection_id, update.clone()).trace_err();
4819 }
4820}
4821
4822fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4823 proto::UpdateUserChannels {
4824 channel_memberships: channels
4825 .channel_memberships
4826 .iter()
4827 .map(|m| proto::ChannelMembership {
4828 channel_id: m.channel_id.to_proto(),
4829 role: m.role.into(),
4830 })
4831 .collect(),
4832 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4833 observed_channel_message_id: channels.observed_channel_messages.clone(),
4834 }
4835}
4836
4837fn build_channels_update(
4838 channels: ChannelsForUser,
4839 channel_invites: Vec<db::Channel>,
4840) -> proto::UpdateChannels {
4841 let mut update = proto::UpdateChannels::default();
4842
4843 for channel in channels.channels {
4844 update.channels.push(channel.to_proto());
4845 }
4846
4847 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4848 update.latest_channel_message_ids = channels.latest_channel_messages;
4849
4850 for (channel_id, participants) in channels.channel_participants {
4851 update
4852 .channel_participants
4853 .push(proto::ChannelParticipants {
4854 channel_id: channel_id.to_proto(),
4855 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4856 });
4857 }
4858
4859 for channel in channel_invites {
4860 update.channel_invitations.push(channel.to_proto());
4861 }
4862
4863 update.hosted_projects = channels.hosted_projects;
4864 update
4865}
4866
4867fn build_initial_contacts_update(
4868 contacts: Vec<db::Contact>,
4869 pool: &ConnectionPool,
4870) -> proto::UpdateContacts {
4871 let mut update = proto::UpdateContacts::default();
4872
4873 for contact in contacts {
4874 match contact {
4875 db::Contact::Accepted { user_id, busy } => {
4876 update.contacts.push(contact_for_user(user_id, busy, &pool));
4877 }
4878 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4879 db::Contact::Incoming { user_id } => {
4880 update
4881 .incoming_requests
4882 .push(proto::IncomingContactRequest {
4883 requester_id: user_id.to_proto(),
4884 })
4885 }
4886 }
4887 }
4888
4889 update
4890}
4891
4892fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4893 proto::Contact {
4894 user_id: user_id.to_proto(),
4895 online: pool.is_user_online(user_id),
4896 busy,
4897 }
4898}
4899
4900fn room_updated(room: &proto::Room, peer: &Peer) {
4901 broadcast(
4902 None,
4903 room.participants
4904 .iter()
4905 .filter_map(|participant| Some(participant.peer_id?.into())),
4906 |peer_id| {
4907 peer.send(
4908 peer_id,
4909 proto::RoomUpdated {
4910 room: Some(room.clone()),
4911 },
4912 )
4913 },
4914 );
4915}
4916
4917fn channel_updated(
4918 channel: &db::channel::Model,
4919 room: &proto::Room,
4920 peer: &Peer,
4921 pool: &ConnectionPool,
4922) {
4923 let participants = room
4924 .participants
4925 .iter()
4926 .map(|p| p.user_id)
4927 .collect::<Vec<_>>();
4928
4929 broadcast(
4930 None,
4931 pool.channel_connection_ids(channel.root_id())
4932 .filter_map(|(channel_id, role)| {
4933 role.can_see_channel(channel.visibility).then(|| channel_id)
4934 }),
4935 |peer_id| {
4936 peer.send(
4937 peer_id,
4938 proto::UpdateChannels {
4939 channel_participants: vec![proto::ChannelParticipants {
4940 channel_id: channel.id.to_proto(),
4941 participant_user_ids: participants.clone(),
4942 }],
4943 ..Default::default()
4944 },
4945 )
4946 },
4947 );
4948}
4949
4950async fn send_dev_server_projects_update(
4951 user_id: UserId,
4952 mut status: proto::DevServerProjectsUpdate,
4953 session: &Session,
4954) {
4955 let pool = session.connection_pool().await;
4956 for dev_server in &mut status.dev_servers {
4957 dev_server.status =
4958 pool.dev_server_status(DevServerId(dev_server.dev_server_id as i32)) as i32;
4959 }
4960 let connections = pool.user_connection_ids(user_id);
4961 for connection_id in connections {
4962 session.peer.send(connection_id, status.clone()).trace_err();
4963 }
4964}
4965
4966async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4967 let db = session.db().await;
4968
4969 let contacts = db.get_contacts(user_id).await?;
4970 let busy = db.is_user_busy(user_id).await?;
4971
4972 let pool = session.connection_pool().await;
4973 let updated_contact = contact_for_user(user_id, busy, &pool);
4974 for contact in contacts {
4975 if let db::Contact::Accepted {
4976 user_id: contact_user_id,
4977 ..
4978 } = contact
4979 {
4980 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4981 session
4982 .peer
4983 .send(
4984 contact_conn_id,
4985 proto::UpdateContacts {
4986 contacts: vec![updated_contact.clone()],
4987 remove_contacts: Default::default(),
4988 incoming_requests: Default::default(),
4989 remove_incoming_requests: Default::default(),
4990 outgoing_requests: Default::default(),
4991 remove_outgoing_requests: Default::default(),
4992 },
4993 )
4994 .trace_err();
4995 }
4996 }
4997 }
4998 Ok(())
4999}
5000
5001async fn lost_dev_server_connection(session: &DevServerSession) -> Result<()> {
5002 log::info!("lost dev server connection, unsharing projects");
5003 let project_ids = session
5004 .db()
5005 .await
5006 .get_stale_dev_server_projects(session.connection_id)
5007 .await?;
5008
5009 for project_id in project_ids {
5010 // not unshare re-checks the connection ids match, so we get away with no transaction
5011 unshare_project_internal(project_id, session.connection_id, None, &session).await?;
5012 }
5013
5014 let user_id = session.dev_server().user_id;
5015 let update = session
5016 .db()
5017 .await
5018 .dev_server_projects_update(user_id)
5019 .await?;
5020
5021 send_dev_server_projects_update(user_id, update, session).await;
5022
5023 Ok(())
5024}
5025
/// Removes `connection_id` from its room and notifies everyone affected.
///
/// Does nothing when `leave_room` reports no room for the connection.
/// Otherwise: notifies collaborators of each left project, broadcasts the
/// updated room (and the channel's participants, if the room has a channel),
/// sends `CallCanceled` to users whose calls were canceled, refreshes contact
/// state for the leaving user and those users, and finally removes the
/// participant from the LiveKit room, deleting that room if `leave_room`
/// marked it as deleted. LiveKit failures are logged rather than returned.
async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // These are assigned inside the `if let` below so the values survive after
    // `left_room` goes out of scope.
    // NOTE(review): presumably `left_room` holds a database guard that must be
    // released before the awaits further down — confirm against `leave_room`.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Taken before `room` itself, so `room.live_kit_room` is left at its
        // default value in the `room` that is broadcast below.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        // The connection was not in a room; nothing to clean up.
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the connection pool guard is dropped before the awaits below.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    // LiveKit cleanup is best-effort: failures are logged and ignored.
    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
5099
5100async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
5101 let left_channel_buffers = session
5102 .db()
5103 .await
5104 .leave_channel_buffers(session.connection_id)
5105 .await?;
5106
5107 for left_buffer in left_channel_buffers {
5108 channel_buffer_updated(
5109 session.connection_id,
5110 left_buffer.connections,
5111 &proto::UpdateChannelBufferCollaborators {
5112 channel_id: left_buffer.channel_id.to_proto(),
5113 collaborators: left_buffer.collaborators,
5114 },
5115 &session.peer,
5116 );
5117 }
5118
5119 Ok(())
5120}
5121
5122fn project_left(project: &db::LeftProject, session: &UserSession) {
5123 for connection_id in &project.connection_ids {
5124 if project.should_unshare {
5125 session
5126 .peer
5127 .send(
5128 *connection_id,
5129 proto::UnshareProject {
5130 project_id: project.id.to_proto(),
5131 },
5132 )
5133 .trace_err();
5134 } else {
5135 session
5136 .peer
5137 .send(
5138 *connection_id,
5139 proto::RemoveProjectCollaborator {
5140 project_id: project.id.to_proto(),
5141 peer_id: Some(session.connection_id.into()),
5142 },
5143 )
5144 .trace_err();
5145 }
5146 }
5147}
5148
5149pub trait ResultExt {
5150 type Ok;
5151
5152 fn trace_err(self) -> Option<Self::Ok>;
5153}
5154
5155impl<T, E> ResultExt for Result<T, E>
5156where
5157 E: std::fmt::Debug,
5158{
5159 type Ok = T;
5160
5161 #[track_caller]
5162 fn trace_err(self) -> Option<T> {
5163 match self {
5164 Ok(value) => Some(value),
5165 Err(error) => {
5166 tracing::error!("{:?}", error);
5167 None
5168 }
5169 }
5170 }
5171}