1mod connection_pool;
2
3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
4use crate::db::billing_subscription::SubscriptionKind;
5use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, LlmTokenClaims};
6use crate::{
7 AppState, Error, Result, auth,
8 db::{
9 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
10 CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
11 NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
12 RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
13 },
14 executor::Executor,
15};
16use anyhow::{Context as _, anyhow, bail};
17use async_tungstenite::tungstenite::{
18 Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
19};
20use axum::{
21 Extension, Router, TypedHeader,
22 body::Body,
23 extract::{
24 ConnectInfo, WebSocketUpgrade,
25 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
26 },
27 headers::{Header, HeaderName},
28 http::StatusCode,
29 middleware,
30 response::IntoResponse,
31 routing::get,
32};
33use chrono::Utc;
34use collections::{HashMap, HashSet};
35pub use connection_pool::{ConnectionPool, ZedVersion};
36use core::fmt::{self, Debug, Formatter};
37use reqwest_client::ReqwestClient;
38use rpc::proto::split_repository_update;
39use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
40use util::maybe;
41
42use futures::{
43 FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
44 stream::FuturesUnordered,
45};
46use prometheus::{IntGauge, register_int_gauge};
47use rpc::{
48 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
49 proto::{
50 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
51 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
52 },
53};
54use semantic_version::SemanticVersion;
55use serde::{Serialize, Serializer};
56use std::{
57 any::TypeId,
58 future::Future,
59 marker::PhantomData,
60 mem,
61 net::SocketAddr,
62 ops::{Deref, DerefMut},
63 rc::Rc,
64 sync::{
65 Arc, OnceLock,
66 atomic::{AtomicBool, Ordering::SeqCst},
67 },
68 time::{Duration, Instant},
69};
70use time::OffsetDateTime;
71use tokio::sync::{MutexGuard, Semaphore, watch};
72use tower::ServiceBuilder;
73use tracing::{
74 Instrument,
75 field::{self},
76 info_span, instrument,
77};
78
/// Reconnection grace period (30 seconds). Its consumers live outside this section of the file.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// Delay before `Server::start` begins cleaning up resources left behind by earlier servers.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully; once they are gone we can
/// safely clean up their old resources, so we wait a little longer than that.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// Presumably paging/size limits for channel chat and notifications; their consumers are
// elsewhere in this file — NOTE(review): confirm.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1 << 10;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
87
88type MessageHandler =
89 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
90
91struct Response<R> {
92 peer: Arc<Peer>,
93 receipt: Receipt<R>,
94 responded: Arc<AtomicBool>,
95}
96
97impl<R: RequestMessage> Response<R> {
98 fn send(self, payload: R::Response) -> Result<()> {
99 self.responded.store(true, SeqCst);
100 self.peer.respond(self.receipt, payload)?;
101 Ok(())
102 }
103}
104
105#[derive(Clone, Debug)]
106pub enum Principal {
107 User(User),
108 Impersonated { user: User, admin: User },
109}
110
111impl Principal {
112 fn update_span(&self, span: &tracing::Span) {
113 match &self {
114 Principal::User(user) => {
115 span.record("user_id", user.id.0);
116 span.record("login", &user.github_login);
117 }
118 Principal::Impersonated { user, admin } => {
119 span.record("user_id", user.id.0);
120 span.record("login", &user.github_login);
121 span.record("impersonator", &admin.github_login);
122 }
123 }
124 }
125}
126
/// Per-connection state passed to every message handler.
///
/// A clone is handed to each incoming message's handler; the shared state lives behind `Arc`s.
#[derive(Clone)]
struct Session {
    /// The identity this connection acts as.
    principal: Principal,
    /// The id the peer assigned to this connection.
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; acquire it through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Connection pool shared with the server; lock it through `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Supermaven admin client; `None` when no Supermaven admin API key is configured.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    /// The GeoIP country code for the user.
    // Not read by the code in this section, hence the `allow` — NOTE(review): confirm it is
    // still needed.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    system_id: Option<String>,
    _executor: Executor,
}
142
143impl Session {
144 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
145 #[cfg(test)]
146 tokio::task::yield_now().await;
147 let guard = self.db.lock().await;
148 #[cfg(test)]
149 tokio::task::yield_now().await;
150 guard
151 }
152
153 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
154 #[cfg(test)]
155 tokio::task::yield_now().await;
156 let guard = self.connection_pool.lock();
157 ConnectionPoolGuard {
158 guard,
159 _not_send: PhantomData,
160 }
161 }
162
163 fn is_staff(&self) -> bool {
164 match &self.principal {
165 Principal::User(user) => user.admin,
166 Principal::Impersonated { .. } => true,
167 }
168 }
169
170 pub async fn has_llm_subscription(
171 &self,
172 db: &MutexGuard<'_, DbHandle>,
173 ) -> anyhow::Result<bool> {
174 if self.is_staff() {
175 return Ok(true);
176 }
177
178 let user_id = self.user_id();
179
180 Ok(db.has_active_billing_subscription(user_id).await?)
181 }
182
183 pub async fn current_plan(&self, db: &MutexGuard<'_, DbHandle>) -> anyhow::Result<proto::Plan> {
184 let user_id = self.user_id();
185
186 let subscription = db.get_active_billing_subscription(user_id).await?;
187 let subscription_kind = subscription.and_then(|subscription| subscription.kind);
188
189 let plan = if let Some(subscription_kind) = subscription_kind {
190 match subscription_kind {
191 SubscriptionKind::ZedPro => proto::Plan::ZedPro,
192 SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
193 SubscriptionKind::ZedFree => proto::Plan::Free,
194 }
195 } else {
196 proto::Plan::Free
197 };
198
199 Ok(plan)
200 }
201
202 fn user_id(&self) -> UserId {
203 match &self.principal {
204 Principal::User(user) => user.id,
205 Principal::Impersonated { user, .. } => user.id,
206 }
207 }
208
209 pub fn email(&self) -> Option<String> {
210 match &self.principal {
211 Principal::User(user) => user.email_address.clone(),
212 Principal::Impersonated { user, .. } => user.email_address.clone(),
213 }
214 }
215}
216
217impl Debug for Session {
218 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
219 let mut result = f.debug_struct("Session");
220 match &self.principal {
221 Principal::User(user) => {
222 result.field("user", &user.github_login);
223 }
224 Principal::Impersonated { user, admin } => {
225 result.field("user", &user.github_login);
226 result.field("impersonator", &admin.github_login);
227 }
228 }
229 result.field("connection_id", &self.connection_id).finish()
230 }
231}
232
233struct DbHandle(Arc<Database>);
234
235impl Deref for DbHandle {
236 type Target = Database;
237
238 fn deref(&self) -> &Self::Target {
239 self.0.as_ref()
240 }
241}
242
/// The collaboration RPC server: owns the peer, the connection pool and the handler table.
pub struct Server {
    /// Behind a mutex so that `reset` can replace it in tests.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Registered handlers, keyed by the `TypeId` of the message type they handle.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Set to `true` by `teardown`; connections check it before starting and watch it for changes.
    teardown: watch::Sender<bool>,
}
251
/// Lock guard over the shared `ConnectionPool`.
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send`, so it cannot be held across an
/// `.await` inside the `Send` futures that message handlers must return. In test builds,
/// dropping the guard checks the pool's invariants.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
256
/// Serializable view of the server's peer and connection pool.
///
/// The connection pool stays locked for as long as the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialize the pool itself rather than the guard wrapping it.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
263
264pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
265where
266 S: Serializer,
267 T: Deref<Target = U>,
268 U: Serialize,
269{
270 Serialize::serialize(value.deref(), serializer)
271}
272
273impl Server {
274 pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
275 let mut server = Self {
276 id: parking_lot::Mutex::new(id),
277 peer: Peer::new(id.0 as u32),
278 app_state: app_state.clone(),
279 connection_pool: Default::default(),
280 handlers: Default::default(),
281 teardown: watch::channel(false).0,
282 };
283
284 server
285 .add_request_handler(ping)
286 .add_request_handler(create_room)
287 .add_request_handler(join_room)
288 .add_request_handler(rejoin_room)
289 .add_request_handler(leave_room)
290 .add_request_handler(set_room_participant_role)
291 .add_request_handler(call)
292 .add_request_handler(cancel_call)
293 .add_message_handler(decline_call)
294 .add_request_handler(update_participant_location)
295 .add_request_handler(share_project)
296 .add_message_handler(unshare_project)
297 .add_request_handler(join_project)
298 .add_message_handler(leave_project)
299 .add_request_handler(update_project)
300 .add_request_handler(update_worktree)
301 .add_request_handler(update_repository)
302 .add_request_handler(remove_repository)
303 .add_message_handler(start_language_server)
304 .add_message_handler(update_language_server)
305 .add_message_handler(update_diagnostic_summary)
306 .add_message_handler(update_worktree_settings)
307 .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
308 .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
309 .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
310 .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
311 .add_request_handler(forward_find_search_candidates_request)
312 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
313 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
314 .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
315 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
316 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
317 .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
318 .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
319 .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
320 .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
321 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
322 .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
323 .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
324 .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
325 .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
326 .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
327 .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
328 .add_request_handler(
329 forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
330 )
331 .add_request_handler(
332 forward_read_only_project_request::<proto::LanguageServerIdForName>,
333 )
334 .add_request_handler(
335 forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
336 )
337 .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
338 .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
339 .add_request_handler(
340 forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
341 )
342 .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
343 .add_request_handler(
344 forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
345 )
346 .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
347 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
348 .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
349 .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
350 .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
351 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
352 .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
353 .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
354 .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
355 .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
356 .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
357 .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
358 .add_request_handler(
359 forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
360 )
361 .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
362 .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
363 .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
364 .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
365 .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
366 .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
367 .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
368 .add_message_handler(create_buffer_for_peer)
369 .add_request_handler(update_buffer)
370 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
371 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
372 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
373 .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
374 .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
375 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
376 .add_request_handler(get_users)
377 .add_request_handler(fuzzy_search_users)
378 .add_request_handler(request_contact)
379 .add_request_handler(remove_contact)
380 .add_request_handler(respond_to_contact_request)
381 .add_message_handler(subscribe_to_channels)
382 .add_request_handler(create_channel)
383 .add_request_handler(delete_channel)
384 .add_request_handler(invite_channel_member)
385 .add_request_handler(remove_channel_member)
386 .add_request_handler(set_channel_member_role)
387 .add_request_handler(set_channel_visibility)
388 .add_request_handler(rename_channel)
389 .add_request_handler(join_channel_buffer)
390 .add_request_handler(leave_channel_buffer)
391 .add_message_handler(update_channel_buffer)
392 .add_request_handler(rejoin_channel_buffers)
393 .add_request_handler(get_channel_members)
394 .add_request_handler(respond_to_channel_invite)
395 .add_request_handler(join_channel)
396 .add_request_handler(join_channel_chat)
397 .add_message_handler(leave_channel_chat)
398 .add_request_handler(send_channel_message)
399 .add_request_handler(remove_channel_message)
400 .add_request_handler(update_channel_message)
401 .add_request_handler(get_channel_messages)
402 .add_request_handler(get_channel_messages_by_id)
403 .add_request_handler(get_notifications)
404 .add_request_handler(mark_notification_as_read)
405 .add_request_handler(move_channel)
406 .add_request_handler(follow)
407 .add_message_handler(unfollow)
408 .add_message_handler(update_followers)
409 .add_request_handler(get_private_user_info)
410 .add_request_handler(get_llm_api_token)
411 .add_request_handler(accept_terms_of_service)
412 .add_message_handler(acknowledge_channel_message)
413 .add_message_handler(acknowledge_buffer_version)
414 .add_request_handler(get_supermaven_api_key)
415 .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
416 .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
417 .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
418 .add_request_handler(forward_mutating_project_request::<proto::Stage>)
419 .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
420 .add_request_handler(forward_mutating_project_request::<proto::Commit>)
421 .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
422 .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
423 .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
424 .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
425 .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
426 .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
427 .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
428 .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
429 .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
430 .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
431 .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
432 .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
433 .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
434 .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
435 .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
436 .add_message_handler(update_context);
437
438 Arc::new(server)
439 }
440
441 pub async fn start(&self) -> Result<()> {
442 let server_id = *self.id.lock();
443 let app_state = self.app_state.clone();
444 let peer = self.peer.clone();
445 let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
446 let pool = self.connection_pool.clone();
447 let livekit_client = self.app_state.livekit_client.clone();
448
449 let span = info_span!("start server");
450 self.app_state.executor.spawn_detached(
451 async move {
452 tracing::info!("waiting for cleanup timeout");
453 timeout.await;
454 tracing::info!("cleanup timeout expired, retrieving stale rooms");
455 if let Some((room_ids, channel_ids)) = app_state
456 .db
457 .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
458 .await
459 .trace_err()
460 {
461 tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
462 tracing::info!(
463 stale_channel_buffer_count = channel_ids.len(),
464 "retrieved stale channel buffers"
465 );
466
467 for channel_id in channel_ids {
468 if let Some(refreshed_channel_buffer) = app_state
469 .db
470 .clear_stale_channel_buffer_collaborators(channel_id, server_id)
471 .await
472 .trace_err()
473 {
474 for connection_id in refreshed_channel_buffer.connection_ids {
475 peer.send(
476 connection_id,
477 proto::UpdateChannelBufferCollaborators {
478 channel_id: channel_id.to_proto(),
479 collaborators: refreshed_channel_buffer
480 .collaborators
481 .clone(),
482 },
483 )
484 .trace_err();
485 }
486 }
487 }
488
489 for room_id in room_ids {
490 let mut contacts_to_update = HashSet::default();
491 let mut canceled_calls_to_user_ids = Vec::new();
492 let mut livekit_room = String::new();
493 let mut delete_livekit_room = false;
494
495 if let Some(mut refreshed_room) = app_state
496 .db
497 .clear_stale_room_participants(room_id, server_id)
498 .await
499 .trace_err()
500 {
501 tracing::info!(
502 room_id = room_id.0,
503 new_participant_count = refreshed_room.room.participants.len(),
504 "refreshed room"
505 );
506 room_updated(&refreshed_room.room, &peer);
507 if let Some(channel) = refreshed_room.channel.as_ref() {
508 channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
509 }
510 contacts_to_update
511 .extend(refreshed_room.stale_participant_user_ids.iter().copied());
512 contacts_to_update
513 .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
514 canceled_calls_to_user_ids =
515 mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
516 livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
517 delete_livekit_room = refreshed_room.room.participants.is_empty();
518 }
519
520 {
521 let pool = pool.lock();
522 for canceled_user_id in canceled_calls_to_user_ids {
523 for connection_id in pool.user_connection_ids(canceled_user_id) {
524 peer.send(
525 connection_id,
526 proto::CallCanceled {
527 room_id: room_id.to_proto(),
528 },
529 )
530 .trace_err();
531 }
532 }
533 }
534
535 for user_id in contacts_to_update {
536 let busy = app_state.db.is_user_busy(user_id).await.trace_err();
537 let contacts = app_state.db.get_contacts(user_id).await.trace_err();
538 if let Some((busy, contacts)) = busy.zip(contacts) {
539 let pool = pool.lock();
540 let updated_contact = contact_for_user(user_id, busy, &pool);
541 for contact in contacts {
542 if let db::Contact::Accepted {
543 user_id: contact_user_id,
544 ..
545 } = contact
546 {
547 for contact_conn_id in
548 pool.user_connection_ids(contact_user_id)
549 {
550 peer.send(
551 contact_conn_id,
552 proto::UpdateContacts {
553 contacts: vec![updated_contact.clone()],
554 remove_contacts: Default::default(),
555 incoming_requests: Default::default(),
556 remove_incoming_requests: Default::default(),
557 outgoing_requests: Default::default(),
558 remove_outgoing_requests: Default::default(),
559 },
560 )
561 .trace_err();
562 }
563 }
564 }
565 }
566 }
567
568 if let Some(live_kit) = livekit_client.as_ref() {
569 if delete_livekit_room {
570 live_kit.delete_room(livekit_room).await.trace_err();
571 }
572 }
573 }
574 }
575
576 app_state
577 .db
578 .delete_stale_servers(&app_state.config.zed_environment, server_id)
579 .await
580 .trace_err();
581 }
582 .instrument(span),
583 );
584 Ok(())
585 }
586
587 pub fn teardown(&self) {
588 self.peer.teardown();
589 self.connection_pool.lock().reset();
590 let _ = self.teardown.send(true);
591 }
592
593 #[cfg(test)]
594 pub fn reset(&self, id: ServerId) {
595 self.teardown();
596 *self.id.lock() = id;
597 self.peer.reset(id.0 as u32);
598 let _ = self.teardown.send(false);
599 }
600
601 #[cfg(test)]
602 pub fn id(&self) -> ServerId {
603 *self.id.lock()
604 }
605
606 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
607 where
608 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
609 Fut: 'static + Send + Future<Output = Result<()>>,
610 M: EnvelopedMessage,
611 {
612 let prev_handler = self.handlers.insert(
613 TypeId::of::<M>(),
614 Box::new(move |envelope, session| {
615 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
616 let received_at = envelope.received_at;
617 tracing::info!("message received");
618 let start_time = Instant::now();
619 let future = (handler)(*envelope, session);
620 async move {
621 let result = future.await;
622 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
623 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
624 let queue_duration_ms = total_duration_ms - processing_duration_ms;
625 let payload_type = M::NAME;
626
627 match result {
628 Err(error) => {
629 tracing::error!(
630 ?error,
631 total_duration_ms,
632 processing_duration_ms,
633 queue_duration_ms,
634 payload_type,
635 "error handling message"
636 )
637 }
638 Ok(()) => tracing::info!(
639 total_duration_ms,
640 processing_duration_ms,
641 queue_duration_ms,
642 "finished handling message"
643 ),
644 }
645 }
646 .boxed()
647 }),
648 );
649 if prev_handler.is_some() {
650 panic!("registered a handler for the same message twice");
651 }
652 self
653 }
654
655 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
656 where
657 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
658 Fut: 'static + Send + Future<Output = Result<()>>,
659 M: EnvelopedMessage,
660 {
661 self.add_handler(move |envelope, session| handler(envelope.payload, session));
662 self
663 }
664
665 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
666 where
667 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
668 Fut: Send + Future<Output = Result<()>>,
669 M: RequestMessage,
670 {
671 let handler = Arc::new(handler);
672 self.add_handler(move |envelope, session| {
673 let receipt = envelope.receipt();
674 let handler = handler.clone();
675 async move {
676 let peer = session.peer.clone();
677 let responded = Arc::new(AtomicBool::default());
678 let response = Response {
679 peer: peer.clone(),
680 responded: responded.clone(),
681 receipt,
682 };
683 match (handler)(envelope.payload, response, session).await {
684 Ok(()) => {
685 if responded.load(std::sync::atomic::Ordering::SeqCst) {
686 Ok(())
687 } else {
688 Err(anyhow!("handler did not send a response"))?
689 }
690 }
691 Err(error) => {
692 let proto_err = match &error {
693 Error::Internal(err) => err.to_proto(),
694 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
695 };
696 peer.respond_with_error(receipt, proto_err)?;
697 Err(error)
698 }
699 }
700 }
701 })
702 }
703
704 pub fn handle_connection(
705 self: &Arc<Self>,
706 connection: Connection,
707 address: String,
708 principal: Principal,
709 zed_version: ZedVersion,
710 geoip_country_code: Option<String>,
711 system_id: Option<String>,
712 send_connection_id: Option<oneshot::Sender<ConnectionId>>,
713 executor: Executor,
714 ) -> impl Future<Output = ()> + use<> {
715 let this = self.clone();
716 let span = info_span!("handle connection", %address,
717 connection_id=field::Empty,
718 user_id=field::Empty,
719 login=field::Empty,
720 impersonator=field::Empty,
721 geoip_country_code=field::Empty
722 );
723 principal.update_span(&span);
724 if let Some(country_code) = geoip_country_code.as_ref() {
725 span.record("geoip_country_code", country_code);
726 }
727
728 let mut teardown = self.teardown.subscribe();
729 async move {
730 if *teardown.borrow() {
731 tracing::error!("server is tearing down");
732 return
733 }
734 let (connection_id, handle_io, mut incoming_rx) = this
735 .peer
736 .add_connection(connection, {
737 let executor = executor.clone();
738 move |duration| executor.sleep(duration)
739 });
740 tracing::Span::current().record("connection_id", format!("{}", connection_id));
741
742 tracing::info!("connection opened");
743
744 let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
745 let http_client = match ReqwestClient::user_agent(&user_agent) {
746 Ok(http_client) => Arc::new(http_client),
747 Err(error) => {
748 tracing::error!(?error, "failed to create HTTP client");
749 return;
750 }
751 };
752
753 let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
754 supermaven_admin_api_key.to_string(),
755 http_client.clone(),
756 )));
757
758 let session = Session {
759 principal: principal.clone(),
760 connection_id,
761 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
762 peer: this.peer.clone(),
763 connection_pool: this.connection_pool.clone(),
764 app_state: this.app_state.clone(),
765 geoip_country_code,
766 system_id,
767 _executor: executor.clone(),
768 supermaven_client,
769 };
770
771 if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
772 tracing::error!(?error, "failed to send initial client update");
773 return;
774 }
775
776 let handle_io = handle_io.fuse();
777 futures::pin_mut!(handle_io);
778
779 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
780 // This prevents deadlocks when e.g., client A performs a request to client B and
781 // client B performs a request to client A. If both clients stop processing further
782 // messages until their respective request completes, they won't have a chance to
783 // respond to the other client's request and cause a deadlock.
784 //
785 // This arrangement ensures we will attempt to process earlier messages first, but fall
786 // back to processing messages arrived later in the spirit of making progress.
787 let mut foreground_message_handlers = FuturesUnordered::new();
788 let concurrent_handlers = Arc::new(Semaphore::new(256));
789 loop {
790 let next_message = async {
791 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
792 let message = incoming_rx.next().await;
793 (permit, message)
794 }.fuse();
795 futures::pin_mut!(next_message);
796 futures::select_biased! {
797 _ = teardown.changed().fuse() => return,
798 result = handle_io => {
799 if let Err(error) = result {
800 tracing::error!(?error, "error handling I/O");
801 }
802 break;
803 }
804 _ = foreground_message_handlers.next() => {}
805 next_message = next_message => {
806 let (permit, message) = next_message;
807 if let Some(message) = message {
808 let type_name = message.payload_type_name();
809 // note: we copy all the fields from the parent span so we can query them in the logs.
810 // (https://github.com/tokio-rs/tracing/issues/2670).
811 let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
812 user_id=field::Empty,
813 login=field::Empty,
814 impersonator=field::Empty,
815 );
816 principal.update_span(&span);
817 let span_enter = span.enter();
818 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
819 let is_background = message.is_background();
820 let handle_message = (handler)(message, session.clone());
821 drop(span_enter);
822
823 let handle_message = async move {
824 handle_message.await;
825 drop(permit);
826 }.instrument(span);
827 if is_background {
828 executor.spawn_detached(handle_message);
829 } else {
830 foreground_message_handlers.push(handle_message);
831 }
832 } else {
833 tracing::error!("no message handler");
834 }
835 } else {
836 tracing::info!("connection closed");
837 break;
838 }
839 }
840 }
841 }
842
843 drop(foreground_message_handlers);
844 tracing::info!("signing out");
845 if let Err(error) = connection_lost(session, teardown, executor).await {
846 tracing::error!(?error, "error signing out");
847 }
848
849 }.instrument(span)
850 }
851
852 async fn send_initial_client_update(
853 &self,
854 connection_id: ConnectionId,
855 principal: &Principal,
856 zed_version: ZedVersion,
857 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
858 session: &Session,
859 ) -> Result<()> {
860 self.peer.send(
861 connection_id,
862 proto::Hello {
863 peer_id: Some(connection_id.into()),
864 },
865 )?;
866 tracing::info!("sent hello message");
867 if let Some(send_connection_id) = send_connection_id.take() {
868 let _ = send_connection_id.send(connection_id);
869 }
870
871 match principal {
872 Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
873 if !user.connected_once {
874 self.peer.send(connection_id, proto::ShowContacts {})?;
875 self.app_state
876 .db
877 .set_user_connected_once(user.id, true)
878 .await?;
879 }
880
881 update_user_plan(user.id, session).await?;
882
883 let contacts = self.app_state.db.get_contacts(user.id).await?;
884
885 {
886 let mut pool = self.connection_pool.lock();
887 pool.add_connection(connection_id, user.id, user.admin, zed_version);
888 self.peer.send(
889 connection_id,
890 build_initial_contacts_update(contacts, &pool),
891 )?;
892 }
893
894 if should_auto_subscribe_to_channels(zed_version) {
895 subscribe_user_to_channels(user.id, session).await?;
896 }
897
898 if let Some(incoming_call) =
899 self.app_state.db.incoming_call_for_user(user.id).await?
900 {
901 self.peer.send(connection_id, incoming_call)?;
902 }
903
904 update_user_contacts(user.id, session).await?;
905 }
906 }
907
908 Ok(())
909 }
910
911 pub async fn invite_code_redeemed(
912 self: &Arc<Self>,
913 inviter_id: UserId,
914 invitee_id: UserId,
915 ) -> Result<()> {
916 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
917 if let Some(code) = &user.invite_code {
918 let pool = self.connection_pool.lock();
919 let invitee_contact = contact_for_user(invitee_id, false, &pool);
920 for connection_id in pool.user_connection_ids(inviter_id) {
921 self.peer.send(
922 connection_id,
923 proto::UpdateContacts {
924 contacts: vec![invitee_contact.clone()],
925 ..Default::default()
926 },
927 )?;
928 self.peer.send(
929 connection_id,
930 proto::UpdateInviteInfo {
931 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
932 count: user.invite_count as u32,
933 },
934 )?;
935 }
936 }
937 }
938 Ok(())
939 }
940
941 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
942 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
943 if let Some(invite_code) = &user.invite_code {
944 let pool = self.connection_pool.lock();
945 for connection_id in pool.user_connection_ids(user_id) {
946 self.peer.send(
947 connection_id,
948 proto::UpdateInviteInfo {
949 url: format!(
950 "{}{}",
951 self.app_state.config.invite_link_prefix, invite_code
952 ),
953 count: user.invite_count as u32,
954 },
955 )?;
956 }
957 }
958 }
959 Ok(())
960 }
961
962 pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
963 let pool = self.connection_pool.lock();
964 for connection_id in pool.user_connection_ids(user_id) {
965 self.peer
966 .send(connection_id, proto::RefreshLlmToken {})
967 .trace_err();
968 }
969 }
970
971 pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot {
972 ServerSnapshot {
973 connection_pool: ConnectionPoolGuard {
974 guard: self.connection_pool.lock(),
975 _not_send: PhantomData,
976 },
977 peer: &self.peer,
978 }
979 }
980}
981
982impl Deref for ConnectionPoolGuard<'_> {
983 type Target = ConnectionPool;
984
985 fn deref(&self) -> &Self::Target {
986 &self.guard
987 }
988}
989
990impl DerefMut for ConnectionPoolGuard<'_> {
991 fn deref_mut(&mut self) -> &mut Self::Target {
992 &mut self.guard
993 }
994}
995
996impl Drop for ConnectionPoolGuard<'_> {
997 fn drop(&mut self) {
998 #[cfg(test)]
999 self.check_invariants();
1000 }
1001}
1002
1003fn broadcast<F>(
1004 sender_id: Option<ConnectionId>,
1005 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1006 mut f: F,
1007) where
1008 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1009{
1010 for receiver_id in receiver_ids {
1011 if Some(receiver_id) != sender_id {
1012 if let Err(error) = f(receiver_id) {
1013 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1014 }
1015 }
1016 }
1017}
1018
1019pub struct ProtocolVersion(u32);
1020
1021impl Header for ProtocolVersion {
1022 fn name() -> &'static HeaderName {
1023 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1024 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1025 }
1026
1027 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1028 where
1029 Self: Sized,
1030 I: Iterator<Item = &'i axum::http::HeaderValue>,
1031 {
1032 let version = values
1033 .next()
1034 .ok_or_else(axum::headers::Error::invalid)?
1035 .to_str()
1036 .map_err(|_| axum::headers::Error::invalid())?
1037 .parse()
1038 .map_err(|_| axum::headers::Error::invalid())?;
1039 Ok(Self(version))
1040 }
1041
1042 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1043 values.extend([self.0.to_string().parse().unwrap()]);
1044 }
1045}
1046
1047pub struct AppVersionHeader(SemanticVersion);
1048impl Header for AppVersionHeader {
1049 fn name() -> &'static HeaderName {
1050 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1051 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1052 }
1053
1054 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1055 where
1056 Self: Sized,
1057 I: Iterator<Item = &'i axum::http::HeaderValue>,
1058 {
1059 let version = values
1060 .next()
1061 .ok_or_else(axum::headers::Error::invalid)?
1062 .to_str()
1063 .map_err(|_| axum::headers::Error::invalid())?
1064 .parse()
1065 .map_err(|_| axum::headers::Error::invalid())?;
1066 Ok(Self(version))
1067 }
1068
1069 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1070 values.extend([self.0.to_string().parse().unwrap()]);
1071 }
1072}
1073
/// Builds the collab RPC router.
///
/// Order matters here: `Router::layer` only wraps routes that were registered before it.
/// `/rpc` is therefore covered by the `ServiceBuilder` layer (the `AppState` extension plus
/// `auth::validate_header`). `/metrics` is added afterwards and is not. The final
/// `Extension(server)` layer is applied after both routes, so both receive the server.
pub fn routes(server: Arc<Server>) -> Router<(), Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        // Applies only to `/rpc`, the sole route registered so far.
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        // Registered after the auth layer, so it is served without header validation.
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
1085
1086pub async fn handle_websocket_request(
1087 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1088 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1089 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1090 Extension(server): Extension<Arc<Server>>,
1091 Extension(principal): Extension<Principal>,
1092 country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1093 system_id_header: Option<TypedHeader<SystemIdHeader>>,
1094 ws: WebSocketUpgrade,
1095) -> axum::response::Response {
1096 if protocol_version != rpc::PROTOCOL_VERSION {
1097 return (
1098 StatusCode::UPGRADE_REQUIRED,
1099 "client must be upgraded".to_string(),
1100 )
1101 .into_response();
1102 }
1103
1104 let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1105 return (
1106 StatusCode::UPGRADE_REQUIRED,
1107 "no version header found".to_string(),
1108 )
1109 .into_response();
1110 };
1111
1112 if !version.can_collaborate() {
1113 return (
1114 StatusCode::UPGRADE_REQUIRED,
1115 "client must be upgraded".to_string(),
1116 )
1117 .into_response();
1118 }
1119
1120 let socket_address = socket_address.to_string();
1121 ws.on_upgrade(move |socket| {
1122 let socket = socket
1123 .map_ok(to_tungstenite_message)
1124 .err_into()
1125 .with(|message| async move { to_axum_message(message) });
1126 let connection = Connection::new(Box::pin(socket));
1127 async move {
1128 server
1129 .handle_connection(
1130 connection,
1131 socket_address,
1132 principal,
1133 version,
1134 country_code_header.map(|header| header.to_string()),
1135 system_id_header.map(|header| header.to_string()),
1136 None,
1137 Executor::Production,
1138 )
1139 .await;
1140 }
1141 })
1142}
1143
1144pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1145 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1146 let connections_metric = CONNECTIONS_METRIC
1147 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1148
1149 let connections = server
1150 .connection_pool
1151 .lock()
1152 .connections()
1153 .filter(|connection| !connection.admin)
1154 .count();
1155 connections_metric.set(connections as _);
1156
1157 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1158 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1159 register_int_gauge!(
1160 "shared_projects",
1161 "number of open projects with one or more guests"
1162 )
1163 .unwrap()
1164 });
1165
1166 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1167 shared_projects_metric.set(shared_projects as _);
1168
1169 let encoder = prometheus::TextEncoder::new();
1170 let metric_families = prometheus::gather();
1171 let encoded_metrics = encoder
1172 .encode_to_string(&metric_families)
1173 .map_err(|err| anyhow!("{}", err))?;
1174 Ok(encoded_metrics)
1175}
1176
/// Cleans up server-side state after a client's connection is lost.
///
/// Immediate steps:
/// - disconnect the peer;
/// - remove the connection from the connection pool (an error here is returned);
/// - record the loss in the database (an error here only goes through `trace_err()`).
///
/// The function then races a `RECONNECT_TIMEOUT` sleep against `teardown.changed()`. If the
/// teardown future completes first, it returns without further cleanup. If the timeout fires
/// first, it:
/// - leaves the session's room and channel buffers;
/// - declines any pending call when the user has no remaining online connection;
/// - refreshes the user's contacts.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls its branches in the order written: the timeout branch is
    // checked before the teardown signal on every poll.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            // Cleanup failures here are logged and do not abort the remaining steps.
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline pending calls when no other connection of this user is online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        // The teardown signal fired first: skip the delayed cleanup entirely.
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1223
1224/// Acknowledges a ping from a client, used to keep the connection alive.
1225async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1226 response.send(proto::Ack {})?;
1227 Ok(())
1228}
1229
1230/// Creates a new room for calling (outside of channels)
1231async fn create_room(
1232 _request: proto::CreateRoom,
1233 response: Response<proto::CreateRoom>,
1234 session: Session,
1235) -> Result<()> {
1236 let livekit_room = nanoid::nanoid!(30);
1237
1238 let live_kit_connection_info = util::maybe!(async {
1239 let live_kit = session.app_state.livekit_client.as_ref();
1240 let live_kit = live_kit?;
1241 let user_id = session.user_id().to_string();
1242
1243 let token = live_kit
1244 .room_token(&livekit_room, &user_id.to_string())
1245 .trace_err()?;
1246
1247 Some(proto::LiveKitConnectionInfo {
1248 server_url: live_kit.url().into(),
1249 token,
1250 can_publish: true,
1251 })
1252 })
1253 .await;
1254
1255 let room = session
1256 .db()
1257 .await
1258 .create_room(session.user_id(), session.connection_id, &livekit_room)
1259 .await?;
1260
1261 response.send(proto::CreateRoomResponse {
1262 room: Some(room.clone()),
1263 live_kit_connection_info,
1264 })?;
1265
1266 update_user_contacts(session.user_id(), &session).await?;
1267 Ok(())
1268}
1269
1270/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1271async fn join_room(
1272 request: proto::JoinRoom,
1273 response: Response<proto::JoinRoom>,
1274 session: Session,
1275) -> Result<()> {
1276 let room_id = RoomId::from_proto(request.id);
1277
1278 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1279
1280 if let Some(channel_id) = channel_id {
1281 return join_channel_internal(channel_id, Box::new(response), session).await;
1282 }
1283
1284 let joined_room = {
1285 let room = session
1286 .db()
1287 .await
1288 .join_room(room_id, session.user_id(), session.connection_id)
1289 .await?;
1290 room_updated(&room.room, &session.peer);
1291 room.into_inner()
1292 };
1293
1294 for connection_id in session
1295 .connection_pool()
1296 .await
1297 .user_connection_ids(session.user_id())
1298 {
1299 session
1300 .peer
1301 .send(
1302 connection_id,
1303 proto::CallCanceled {
1304 room_id: room_id.to_proto(),
1305 },
1306 )
1307 .trace_err();
1308 }
1309
1310 let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1311 {
1312 live_kit
1313 .room_token(
1314 &joined_room.room.livekit_room,
1315 &session.user_id().to_string(),
1316 )
1317 .trace_err()
1318 .map(|token| proto::LiveKitConnectionInfo {
1319 server_url: live_kit.url().into(),
1320 token,
1321 can_publish: true,
1322 })
1323 } else {
1324 None
1325 };
1326
1327 response.send(proto::JoinRoomResponse {
1328 room: Some(joined_room.room),
1329 channel_id: None,
1330 live_kit_connection_info,
1331 })?;
1332
1333 update_user_contacts(session.user_id(), &session).await?;
1334 Ok(())
1335}
1336
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// While the database's rejoin result is held, this:
/// - replies with the room and the reshared and rejoined projects;
/// - broadcasts the updated room;
/// - for every reshared project, tells each collaborator that the session's old peer id
///   (`project.old_connection_id`) now maps to its current connection, and forwards the
///   project's worktrees to the other collaborators;
/// - streams rejoined-project state via `notify_rejoined_projects`.
///
/// Once that result has been consumed, it notifies the channel (if any) and refreshes the
/// user's contacts.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    // Values extracted from the rejoin result so they outlive the scope below.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Point each collaborator at the session's new peer id; failures are only logged.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the reshared project's worktree list to everyone except this session.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        // `into_inner` consumes the guard, so the rest of the function no longer holds it.
        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1428
1429fn notify_rejoined_projects(
1430 rejoined_projects: &mut Vec<RejoinedProject>,
1431 session: &Session,
1432) -> Result<()> {
1433 for project in rejoined_projects.iter() {
1434 for collaborator in &project.collaborators {
1435 session
1436 .peer
1437 .send(
1438 collaborator.connection_id,
1439 proto::UpdateProjectCollaborator {
1440 project_id: project.id.to_proto(),
1441 old_peer_id: Some(project.old_connection_id.into()),
1442 new_peer_id: Some(session.connection_id.into()),
1443 },
1444 )
1445 .trace_err();
1446 }
1447 }
1448
1449 for project in rejoined_projects {
1450 for worktree in mem::take(&mut project.worktrees) {
1451 // Stream this worktree's entries.
1452 let message = proto::UpdateWorktree {
1453 project_id: project.id.to_proto(),
1454 worktree_id: worktree.id,
1455 abs_path: worktree.abs_path.clone(),
1456 root_name: worktree.root_name,
1457 updated_entries: worktree.updated_entries,
1458 removed_entries: worktree.removed_entries,
1459 scan_id: worktree.scan_id,
1460 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1461 updated_repositories: worktree.updated_repositories,
1462 removed_repositories: worktree.removed_repositories,
1463 };
1464 for update in proto::split_worktree_update(message) {
1465 session.peer.send(session.connection_id, update)?;
1466 }
1467
1468 // Stream this worktree's diagnostics.
1469 for summary in worktree.diagnostic_summaries {
1470 session.peer.send(
1471 session.connection_id,
1472 proto::UpdateDiagnosticSummary {
1473 project_id: project.id.to_proto(),
1474 worktree_id: worktree.id,
1475 summary: Some(summary),
1476 },
1477 )?;
1478 }
1479
1480 for settings_file in worktree.settings_files {
1481 session.peer.send(
1482 session.connection_id,
1483 proto::UpdateWorktreeSettings {
1484 project_id: project.id.to_proto(),
1485 worktree_id: worktree.id,
1486 path: settings_file.path,
1487 content: Some(settings_file.content),
1488 kind: Some(settings_file.kind.to_proto().into()),
1489 },
1490 )?;
1491 }
1492 }
1493
1494 for repository in mem::take(&mut project.updated_repositories) {
1495 for update in split_repository_update(repository) {
1496 session.peer.send(session.connection_id, update)?;
1497 }
1498 }
1499
1500 for id in mem::take(&mut project.removed_repositories) {
1501 session.peer.send(
1502 session.connection_id,
1503 proto::RemoveRepository {
1504 project_id: project.id.to_proto(),
1505 id,
1506 },
1507 )?;
1508 }
1509 }
1510
1511 Ok(())
1512}
1513
1514/// leave room disconnects from the room.
1515async fn leave_room(
1516 _: proto::LeaveRoom,
1517 response: Response<proto::LeaveRoom>,
1518 session: Session,
1519) -> Result<()> {
1520 leave_room_for_session(&session, session.connection_id).await?;
1521 response.send(proto::Ack {})?;
1522 Ok(())
1523}
1524
1525/// Updates the permissions of someone else in the room.
1526async fn set_room_participant_role(
1527 request: proto::SetRoomParticipantRole,
1528 response: Response<proto::SetRoomParticipantRole>,
1529 session: Session,
1530) -> Result<()> {
1531 let user_id = UserId::from_proto(request.user_id);
1532 let role = ChannelRole::from(request.role());
1533
1534 let (livekit_room, can_publish) = {
1535 let room = session
1536 .db()
1537 .await
1538 .set_room_participant_role(
1539 session.user_id(),
1540 RoomId::from_proto(request.room_id),
1541 user_id,
1542 role,
1543 )
1544 .await?;
1545
1546 let livekit_room = room.livekit_room.clone();
1547 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1548 room_updated(&room, &session.peer);
1549 (livekit_room, can_publish)
1550 };
1551
1552 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1553 live_kit
1554 .update_participant(
1555 livekit_room.clone(),
1556 request.user_id.to_string(),
1557 livekit_api::proto::ParticipantPermission {
1558 can_subscribe: true,
1559 can_publish,
1560 can_publish_data: can_publish,
1561 hidden: false,
1562 recorder: false,
1563 },
1564 )
1565 .await
1566 .trace_err();
1567 }
1568
1569 response.send(proto::Ack {})?;
1570 Ok(())
1571}
1572
1573/// Call someone else into the current room
1574async fn call(
1575 request: proto::Call,
1576 response: Response<proto::Call>,
1577 session: Session,
1578) -> Result<()> {
1579 let room_id = RoomId::from_proto(request.room_id);
1580 let calling_user_id = session.user_id();
1581 let calling_connection_id = session.connection_id;
1582 let called_user_id = UserId::from_proto(request.called_user_id);
1583 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1584 if !session
1585 .db()
1586 .await
1587 .has_contact(calling_user_id, called_user_id)
1588 .await?
1589 {
1590 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1591 }
1592
1593 let incoming_call = {
1594 let (room, incoming_call) = &mut *session
1595 .db()
1596 .await
1597 .call(
1598 room_id,
1599 calling_user_id,
1600 calling_connection_id,
1601 called_user_id,
1602 initial_project_id,
1603 )
1604 .await?;
1605 room_updated(room, &session.peer);
1606 mem::take(incoming_call)
1607 };
1608 update_user_contacts(called_user_id, &session).await?;
1609
1610 let mut calls = session
1611 .connection_pool()
1612 .await
1613 .user_connection_ids(called_user_id)
1614 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1615 .collect::<FuturesUnordered<_>>();
1616
1617 while let Some(call_response) = calls.next().await {
1618 match call_response.as_ref() {
1619 Ok(_) => {
1620 response.send(proto::Ack {})?;
1621 return Ok(());
1622 }
1623 Err(_) => {
1624 call_response.trace_err();
1625 }
1626 }
1627 }
1628
1629 {
1630 let room = session
1631 .db()
1632 .await
1633 .call_failed(room_id, called_user_id)
1634 .await?;
1635 room_updated(&room, &session.peer);
1636 }
1637 update_user_contacts(called_user_id, &session).await?;
1638
1639 Err(anyhow!("failed to ring user"))?
1640}
1641
1642/// Cancel an outgoing call.
1643async fn cancel_call(
1644 request: proto::CancelCall,
1645 response: Response<proto::CancelCall>,
1646 session: Session,
1647) -> Result<()> {
1648 let called_user_id = UserId::from_proto(request.called_user_id);
1649 let room_id = RoomId::from_proto(request.room_id);
1650 {
1651 let room = session
1652 .db()
1653 .await
1654 .cancel_call(room_id, session.connection_id, called_user_id)
1655 .await?;
1656 room_updated(&room, &session.peer);
1657 }
1658
1659 for connection_id in session
1660 .connection_pool()
1661 .await
1662 .user_connection_ids(called_user_id)
1663 {
1664 session
1665 .peer
1666 .send(
1667 connection_id,
1668 proto::CallCanceled {
1669 room_id: room_id.to_proto(),
1670 },
1671 )
1672 .trace_err();
1673 }
1674 response.send(proto::Ack {})?;
1675
1676 update_user_contacts(called_user_id, &session).await?;
1677 Ok(())
1678}
1679
1680/// Decline an incoming call.
1681async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1682 let room_id = RoomId::from_proto(message.room_id);
1683 {
1684 let room = session
1685 .db()
1686 .await
1687 .decline_call(Some(room_id), session.user_id())
1688 .await?
1689 .ok_or_else(|| anyhow!("failed to decline call"))?;
1690 room_updated(&room, &session.peer);
1691 }
1692
1693 for connection_id in session
1694 .connection_pool()
1695 .await
1696 .user_connection_ids(session.user_id())
1697 {
1698 session
1699 .peer
1700 .send(
1701 connection_id,
1702 proto::CallCanceled {
1703 room_id: room_id.to_proto(),
1704 },
1705 )
1706 .trace_err();
1707 }
1708 update_user_contacts(session.user_id(), &session).await?;
1709 Ok(())
1710}
1711
1712/// Updates other participants in the room with your current location.
1713async fn update_participant_location(
1714 request: proto::UpdateParticipantLocation,
1715 response: Response<proto::UpdateParticipantLocation>,
1716 session: Session,
1717) -> Result<()> {
1718 let room_id = RoomId::from_proto(request.room_id);
1719 let location = request
1720 .location
1721 .ok_or_else(|| anyhow!("invalid location"))?;
1722
1723 let db = session.db().await;
1724 let room = db
1725 .update_room_participant_location(room_id, session.connection_id, location)
1726 .await?;
1727
1728 room_updated(&room, &session.peer);
1729 response.send(proto::Ack {})?;
1730 Ok(())
1731}
1732
1733/// Share a project into the room.
1734async fn share_project(
1735 request: proto::ShareProject,
1736 response: Response<proto::ShareProject>,
1737 session: Session,
1738) -> Result<()> {
1739 let (project_id, room) = &*session
1740 .db()
1741 .await
1742 .share_project(
1743 RoomId::from_proto(request.room_id),
1744 session.connection_id,
1745 &request.worktrees,
1746 request.is_ssh_project,
1747 )
1748 .await?;
1749 response.send(proto::ShareProjectResponse {
1750 project_id: project_id.to_proto(),
1751 })?;
1752 room_updated(room, &session.peer);
1753
1754 Ok(())
1755}
1756
1757/// Unshare a project from the room.
1758async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1759 let project_id = ProjectId::from_proto(message.project_id);
1760 unshare_project_internal(project_id, session.connection_id, &session).await
1761}
1762
1763async fn unshare_project_internal(
1764 project_id: ProjectId,
1765 connection_id: ConnectionId,
1766 session: &Session,
1767) -> Result<()> {
1768 let delete = {
1769 let room_guard = session
1770 .db()
1771 .await
1772 .unshare_project(project_id, connection_id)
1773 .await?;
1774
1775 let (delete, room, guest_connection_ids) = &*room_guard;
1776
1777 let message = proto::UnshareProject {
1778 project_id: project_id.to_proto(),
1779 };
1780
1781 broadcast(
1782 Some(connection_id),
1783 guest_connection_ids.iter().copied(),
1784 |conn_id| session.peer.send(conn_id, message.clone()),
1785 );
1786 if let Some(room) = room {
1787 room_updated(room, &session.peer);
1788 }
1789
1790 *delete
1791 };
1792
1793 if delete {
1794 let db = session.db().await;
1795 db.delete_project(project_id).await?;
1796 }
1797
1798 Ok(())
1799}
1800
1801/// Join someone elses shared project.
1802async fn join_project(
1803 request: proto::JoinProject,
1804 response: Response<proto::JoinProject>,
1805 session: Session,
1806) -> Result<()> {
1807 let project_id = ProjectId::from_proto(request.project_id);
1808
1809 tracing::info!(%project_id, "join project");
1810
1811 let db = session.db().await;
1812 let (project, replica_id) = &mut *db
1813 .join_project(project_id, session.connection_id, session.user_id())
1814 .await?;
1815 drop(db);
1816 tracing::info!(%project_id, "join remote project");
1817 join_project_internal(response, session, project, replica_id)
1818}
1819
/// The reply channel `join_project_internal` uses to answer a join.
///
/// `join_project_internal` accepts any type implementing this trait.
trait JoinProjectInternalResponse {
    /// Sends the join result back to the requester.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Spelled out in full, presumably so this targets `Response`'s own `send` and not
        // this trait method (TODO confirm against `Response`'s definition).
        Response::<proto::JoinProject>::send(self, result)
    }
}
1828
1829fn join_project_internal(
1830 response: impl JoinProjectInternalResponse,
1831 session: Session,
1832 project: &mut Project,
1833 replica_id: &ReplicaId,
1834) -> Result<()> {
1835 let collaborators = project
1836 .collaborators
1837 .iter()
1838 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1839 .map(|collaborator| collaborator.to_proto())
1840 .collect::<Vec<_>>();
1841 let project_id = project.id;
1842 let guest_user_id = session.user_id();
1843
1844 let worktrees = project
1845 .worktrees
1846 .iter()
1847 .map(|(id, worktree)| proto::WorktreeMetadata {
1848 id: *id,
1849 root_name: worktree.root_name.clone(),
1850 visible: worktree.visible,
1851 abs_path: worktree.abs_path.clone(),
1852 })
1853 .collect::<Vec<_>>();
1854
1855 let add_project_collaborator = proto::AddProjectCollaborator {
1856 project_id: project_id.to_proto(),
1857 collaborator: Some(proto::Collaborator {
1858 peer_id: Some(session.connection_id.into()),
1859 replica_id: replica_id.0 as u32,
1860 user_id: guest_user_id.to_proto(),
1861 is_host: false,
1862 }),
1863 };
1864
1865 for collaborator in &collaborators {
1866 session
1867 .peer
1868 .send(
1869 collaborator.peer_id.unwrap().into(),
1870 add_project_collaborator.clone(),
1871 )
1872 .trace_err();
1873 }
1874
1875 // First, we send the metadata associated with each worktree.
1876 response.send(proto::JoinProjectResponse {
1877 project_id: project.id.0 as u64,
1878 worktrees: worktrees.clone(),
1879 replica_id: replica_id.0 as u32,
1880 collaborators: collaborators.clone(),
1881 language_servers: project.language_servers.clone(),
1882 role: project.role.into(),
1883 })?;
1884
1885 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1886 // Stream this worktree's entries.
1887 let message = proto::UpdateWorktree {
1888 project_id: project_id.to_proto(),
1889 worktree_id,
1890 abs_path: worktree.abs_path.clone(),
1891 root_name: worktree.root_name,
1892 updated_entries: worktree.entries,
1893 removed_entries: Default::default(),
1894 scan_id: worktree.scan_id,
1895 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1896 updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1897 removed_repositories: Default::default(),
1898 };
1899 for update in proto::split_worktree_update(message) {
1900 session.peer.send(session.connection_id, update.clone())?;
1901 }
1902
1903 // Stream this worktree's diagnostics.
1904 for summary in worktree.diagnostic_summaries {
1905 session.peer.send(
1906 session.connection_id,
1907 proto::UpdateDiagnosticSummary {
1908 project_id: project_id.to_proto(),
1909 worktree_id: worktree.id,
1910 summary: Some(summary),
1911 },
1912 )?;
1913 }
1914
1915 for settings_file in worktree.settings_files {
1916 session.peer.send(
1917 session.connection_id,
1918 proto::UpdateWorktreeSettings {
1919 project_id: project_id.to_proto(),
1920 worktree_id: worktree.id,
1921 path: settings_file.path,
1922 content: Some(settings_file.content),
1923 kind: Some(settings_file.kind.to_proto() as i32),
1924 },
1925 )?;
1926 }
1927 }
1928
1929 for repository in mem::take(&mut project.repositories) {
1930 for update in split_repository_update(repository) {
1931 session.peer.send(session.connection_id, update)?;
1932 }
1933 }
1934
1935 for language_server in &project.language_servers {
1936 session.peer.send(
1937 session.connection_id,
1938 proto::UpdateLanguageServer {
1939 project_id: project_id.to_proto(),
1940 language_server_id: language_server.id,
1941 variant: Some(
1942 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1943 proto::LspDiskBasedDiagnosticsUpdated {},
1944 ),
1945 ),
1946 },
1947 )?;
1948 }
1949
1950 Ok(())
1951}
1952
1953/// Leave someone elses shared project.
1954async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1955 let sender_id = session.connection_id;
1956 let project_id = ProjectId::from_proto(request.project_id);
1957 let db = session.db().await;
1958
1959 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1960 tracing::info!(
1961 %project_id,
1962 "leave project"
1963 );
1964
1965 project_left(project, &session);
1966 if let Some(room) = room {
1967 room_updated(room, &session.peer);
1968 }
1969
1970 Ok(())
1971}
1972
1973/// Updates other participants with changes to the project
1974async fn update_project(
1975 request: proto::UpdateProject,
1976 response: Response<proto::UpdateProject>,
1977 session: Session,
1978) -> Result<()> {
1979 let project_id = ProjectId::from_proto(request.project_id);
1980 let (room, guest_connection_ids) = &*session
1981 .db()
1982 .await
1983 .update_project(project_id, session.connection_id, &request.worktrees)
1984 .await?;
1985 broadcast(
1986 Some(session.connection_id),
1987 guest_connection_ids.iter().copied(),
1988 |connection_id| {
1989 session
1990 .peer
1991 .forward_send(session.connection_id, connection_id, request.clone())
1992 },
1993 );
1994 if let Some(room) = room {
1995 room_updated(room, &session.peer);
1996 }
1997 response.send(proto::Ack {})?;
1998
1999 Ok(())
2000}
2001
2002/// Updates other participants with changes to the worktree
2003async fn update_worktree(
2004 request: proto::UpdateWorktree,
2005 response: Response<proto::UpdateWorktree>,
2006 session: Session,
2007) -> Result<()> {
2008 let guest_connection_ids = session
2009 .db()
2010 .await
2011 .update_worktree(&request, session.connection_id)
2012 .await?;
2013
2014 broadcast(
2015 Some(session.connection_id),
2016 guest_connection_ids.iter().copied(),
2017 |connection_id| {
2018 session
2019 .peer
2020 .forward_send(session.connection_id, connection_id, request.clone())
2021 },
2022 );
2023 response.send(proto::Ack {})?;
2024 Ok(())
2025}
2026
2027async fn update_repository(
2028 request: proto::UpdateRepository,
2029 response: Response<proto::UpdateRepository>,
2030 session: Session,
2031) -> Result<()> {
2032 let guest_connection_ids = session
2033 .db()
2034 .await
2035 .update_repository(&request, session.connection_id)
2036 .await?;
2037
2038 broadcast(
2039 Some(session.connection_id),
2040 guest_connection_ids.iter().copied(),
2041 |connection_id| {
2042 session
2043 .peer
2044 .forward_send(session.connection_id, connection_id, request.clone())
2045 },
2046 );
2047 response.send(proto::Ack {})?;
2048 Ok(())
2049}
2050
2051async fn remove_repository(
2052 request: proto::RemoveRepository,
2053 response: Response<proto::RemoveRepository>,
2054 session: Session,
2055) -> Result<()> {
2056 let guest_connection_ids = session
2057 .db()
2058 .await
2059 .remove_repository(&request, session.connection_id)
2060 .await?;
2061
2062 broadcast(
2063 Some(session.connection_id),
2064 guest_connection_ids.iter().copied(),
2065 |connection_id| {
2066 session
2067 .peer
2068 .forward_send(session.connection_id, connection_id, request.clone())
2069 },
2070 );
2071 response.send(proto::Ack {})?;
2072 Ok(())
2073}
2074
2075/// Updates other participants with changes to the diagnostics
2076async fn update_diagnostic_summary(
2077 message: proto::UpdateDiagnosticSummary,
2078 session: Session,
2079) -> Result<()> {
2080 let guest_connection_ids = session
2081 .db()
2082 .await
2083 .update_diagnostic_summary(&message, session.connection_id)
2084 .await?;
2085
2086 broadcast(
2087 Some(session.connection_id),
2088 guest_connection_ids.iter().copied(),
2089 |connection_id| {
2090 session
2091 .peer
2092 .forward_send(session.connection_id, connection_id, message.clone())
2093 },
2094 );
2095
2096 Ok(())
2097}
2098
2099/// Updates other participants with changes to the worktree settings
2100async fn update_worktree_settings(
2101 message: proto::UpdateWorktreeSettings,
2102 session: Session,
2103) -> Result<()> {
2104 let guest_connection_ids = session
2105 .db()
2106 .await
2107 .update_worktree_settings(&message, session.connection_id)
2108 .await?;
2109
2110 broadcast(
2111 Some(session.connection_id),
2112 guest_connection_ids.iter().copied(),
2113 |connection_id| {
2114 session
2115 .peer
2116 .forward_send(session.connection_id, connection_id, message.clone())
2117 },
2118 );
2119
2120 Ok(())
2121}
2122
2123/// Notify other participants that a language server has started.
2124async fn start_language_server(
2125 request: proto::StartLanguageServer,
2126 session: Session,
2127) -> Result<()> {
2128 let guest_connection_ids = session
2129 .db()
2130 .await
2131 .start_language_server(&request, session.connection_id)
2132 .await?;
2133
2134 broadcast(
2135 Some(session.connection_id),
2136 guest_connection_ids.iter().copied(),
2137 |connection_id| {
2138 session
2139 .peer
2140 .forward_send(session.connection_id, connection_id, request.clone())
2141 },
2142 );
2143 Ok(())
2144}
2145
2146/// Notify other participants that a language server has changed.
2147async fn update_language_server(
2148 request: proto::UpdateLanguageServer,
2149 session: Session,
2150) -> Result<()> {
2151 let project_id = ProjectId::from_proto(request.project_id);
2152 let project_connection_ids = session
2153 .db()
2154 .await
2155 .project_connection_ids(project_id, session.connection_id, true)
2156 .await?;
2157 broadcast(
2158 Some(session.connection_id),
2159 project_connection_ids.iter().copied(),
2160 |connection_id| {
2161 session
2162 .peer
2163 .forward_send(session.connection_id, connection_id, request.clone())
2164 },
2165 );
2166 Ok(())
2167}
2168
2169/// forward a project request to the host. These requests should be read only
2170/// as guests are allowed to send them.
2171async fn forward_read_only_project_request<T>(
2172 request: T,
2173 response: Response<T>,
2174 session: Session,
2175) -> Result<()>
2176where
2177 T: EntityMessage + RequestMessage,
2178{
2179 let project_id = ProjectId::from_proto(request.remote_entity_id());
2180 let host_connection_id = session
2181 .db()
2182 .await
2183 .host_for_read_only_project_request(project_id, session.connection_id)
2184 .await?;
2185 let payload = session
2186 .peer
2187 .forward_request(session.connection_id, host_connection_id, request)
2188 .await?;
2189 response.send(payload)?;
2190 Ok(())
2191}
2192
2193async fn forward_find_search_candidates_request(
2194 request: proto::FindSearchCandidates,
2195 response: Response<proto::FindSearchCandidates>,
2196 session: Session,
2197) -> Result<()> {
2198 let project_id = ProjectId::from_proto(request.remote_entity_id());
2199 let host_connection_id = session
2200 .db()
2201 .await
2202 .host_for_read_only_project_request(project_id, session.connection_id)
2203 .await?;
2204 let payload = session
2205 .peer
2206 .forward_request(session.connection_id, host_connection_id, request)
2207 .await?;
2208 response.send(payload)?;
2209 Ok(())
2210}
2211
2212/// forward a project request to the host. These requests are disallowed
2213/// for guests.
2214async fn forward_mutating_project_request<T>(
2215 request: T,
2216 response: Response<T>,
2217 session: Session,
2218) -> Result<()>
2219where
2220 T: EntityMessage + RequestMessage,
2221{
2222 let project_id = ProjectId::from_proto(request.remote_entity_id());
2223
2224 let host_connection_id = session
2225 .db()
2226 .await
2227 .host_for_mutating_project_request(project_id, session.connection_id)
2228 .await?;
2229 let payload = session
2230 .peer
2231 .forward_request(session.connection_id, host_connection_id, request)
2232 .await?;
2233 response.send(payload)?;
2234 Ok(())
2235}
2236
2237/// Notify other participants that a new buffer has been created
2238async fn create_buffer_for_peer(
2239 request: proto::CreateBufferForPeer,
2240 session: Session,
2241) -> Result<()> {
2242 session
2243 .db()
2244 .await
2245 .check_user_is_project_host(
2246 ProjectId::from_proto(request.project_id),
2247 session.connection_id,
2248 )
2249 .await?;
2250 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2251 session
2252 .peer
2253 .forward_send(session.connection_id, peer_id.into(), request)?;
2254 Ok(())
2255}
2256
2257/// Notify other participants that a buffer has been updated. This is
2258/// allowed for guests as long as the update is limited to selections.
2259async fn update_buffer(
2260 request: proto::UpdateBuffer,
2261 response: Response<proto::UpdateBuffer>,
2262 session: Session,
2263) -> Result<()> {
2264 let project_id = ProjectId::from_proto(request.project_id);
2265 let mut capability = Capability::ReadOnly;
2266
2267 for op in request.operations.iter() {
2268 match op.variant {
2269 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2270 Some(_) => capability = Capability::ReadWrite,
2271 }
2272 }
2273
2274 let host = {
2275 let guard = session
2276 .db()
2277 .await
2278 .connections_for_buffer_update(project_id, session.connection_id, capability)
2279 .await?;
2280
2281 let (host, guests) = &*guard;
2282
2283 broadcast(
2284 Some(session.connection_id),
2285 guests.clone(),
2286 |connection_id| {
2287 session
2288 .peer
2289 .forward_send(session.connection_id, connection_id, request.clone())
2290 },
2291 );
2292
2293 *host
2294 };
2295
2296 if host != session.connection_id {
2297 session
2298 .peer
2299 .forward_request(session.connection_id, host, request.clone())
2300 .await?;
2301 }
2302
2303 response.send(proto::Ack {})?;
2304 Ok(())
2305}
2306
2307async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2308 let project_id = ProjectId::from_proto(message.project_id);
2309
2310 let operation = message.operation.as_ref().context("invalid operation")?;
2311 let capability = match operation.variant.as_ref() {
2312 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2313 if let Some(buffer_op) = buffer_op.operation.as_ref() {
2314 match buffer_op.variant {
2315 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2316 Capability::ReadOnly
2317 }
2318 _ => Capability::ReadWrite,
2319 }
2320 } else {
2321 Capability::ReadWrite
2322 }
2323 }
2324 Some(_) => Capability::ReadWrite,
2325 None => Capability::ReadOnly,
2326 };
2327
2328 let guard = session
2329 .db()
2330 .await
2331 .connections_for_buffer_update(project_id, session.connection_id, capability)
2332 .await?;
2333
2334 let (host, guests) = &*guard;
2335
2336 broadcast(
2337 Some(session.connection_id),
2338 guests.iter().chain([host]).copied(),
2339 |connection_id| {
2340 session
2341 .peer
2342 .forward_send(session.connection_id, connection_id, message.clone())
2343 },
2344 );
2345
2346 Ok(())
2347}
2348
2349/// Notify other participants that a project has been updated.
2350async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2351 request: T,
2352 session: Session,
2353) -> Result<()> {
2354 let project_id = ProjectId::from_proto(request.remote_entity_id());
2355 let project_connection_ids = session
2356 .db()
2357 .await
2358 .project_connection_ids(project_id, session.connection_id, false)
2359 .await?;
2360
2361 broadcast(
2362 Some(session.connection_id),
2363 project_connection_ids.iter().copied(),
2364 |connection_id| {
2365 session
2366 .peer
2367 .forward_send(session.connection_id, connection_id, request.clone())
2368 },
2369 );
2370 Ok(())
2371}
2372
2373/// Start following another user in a call.
2374async fn follow(
2375 request: proto::Follow,
2376 response: Response<proto::Follow>,
2377 session: Session,
2378) -> Result<()> {
2379 let room_id = RoomId::from_proto(request.room_id);
2380 let project_id = request.project_id.map(ProjectId::from_proto);
2381 let leader_id = request
2382 .leader_id
2383 .ok_or_else(|| anyhow!("invalid leader id"))?
2384 .into();
2385 let follower_id = session.connection_id;
2386
2387 session
2388 .db()
2389 .await
2390 .check_room_participants(room_id, leader_id, session.connection_id)
2391 .await?;
2392
2393 let response_payload = session
2394 .peer
2395 .forward_request(session.connection_id, leader_id, request)
2396 .await?;
2397 response.send(response_payload)?;
2398
2399 if let Some(project_id) = project_id {
2400 let room = session
2401 .db()
2402 .await
2403 .follow(room_id, project_id, leader_id, follower_id)
2404 .await?;
2405 room_updated(&room, &session.peer);
2406 }
2407
2408 Ok(())
2409}
2410
2411/// Stop following another user in a call.
2412async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2413 let room_id = RoomId::from_proto(request.room_id);
2414 let project_id = request.project_id.map(ProjectId::from_proto);
2415 let leader_id = request
2416 .leader_id
2417 .ok_or_else(|| anyhow!("invalid leader id"))?
2418 .into();
2419 let follower_id = session.connection_id;
2420
2421 session
2422 .db()
2423 .await
2424 .check_room_participants(room_id, leader_id, session.connection_id)
2425 .await?;
2426
2427 session
2428 .peer
2429 .forward_send(session.connection_id, leader_id, request)?;
2430
2431 if let Some(project_id) = project_id {
2432 let room = session
2433 .db()
2434 .await
2435 .unfollow(room_id, project_id, leader_id, follower_id)
2436 .await?;
2437 room_updated(&room, &session.peer);
2438 }
2439
2440 Ok(())
2441}
2442
2443/// Notify everyone following you of your current location.
2444async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2445 let room_id = RoomId::from_proto(request.room_id);
2446 let database = session.db.lock().await;
2447
2448 let connection_ids = if let Some(project_id) = request.project_id {
2449 let project_id = ProjectId::from_proto(project_id);
2450 database
2451 .project_connection_ids(project_id, session.connection_id, true)
2452 .await?
2453 } else {
2454 database
2455 .room_connection_ids(room_id, session.connection_id)
2456 .await?
2457 };
2458
2459 // For now, don't send view update messages back to that view's current leader.
2460 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2461 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2462 _ => None,
2463 });
2464
2465 for connection_id in connection_ids.iter().cloned() {
2466 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2467 session
2468 .peer
2469 .forward_send(session.connection_id, connection_id, request.clone())?;
2470 }
2471 }
2472 Ok(())
2473}
2474
2475/// Get public data about users.
2476async fn get_users(
2477 request: proto::GetUsers,
2478 response: Response<proto::GetUsers>,
2479 session: Session,
2480) -> Result<()> {
2481 let user_ids = request
2482 .user_ids
2483 .into_iter()
2484 .map(UserId::from_proto)
2485 .collect();
2486 let users = session
2487 .db()
2488 .await
2489 .get_users_by_ids(user_ids)
2490 .await?
2491 .into_iter()
2492 .map(|user| proto::User {
2493 id: user.id.to_proto(),
2494 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2495 github_login: user.github_login,
2496 email: user.email_address,
2497 name: user.name,
2498 })
2499 .collect();
2500 response.send(proto::UsersResponse { users })?;
2501 Ok(())
2502}
2503
2504/// Search for users (to invite) buy Github login
2505async fn fuzzy_search_users(
2506 request: proto::FuzzySearchUsers,
2507 response: Response<proto::FuzzySearchUsers>,
2508 session: Session,
2509) -> Result<()> {
2510 let query = request.query;
2511 let users = match query.len() {
2512 0 => vec![],
2513 1 | 2 => session
2514 .db()
2515 .await
2516 .get_user_by_github_login(&query)
2517 .await?
2518 .into_iter()
2519 .collect(),
2520 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2521 };
2522 let users = users
2523 .into_iter()
2524 .filter(|user| user.id != session.user_id())
2525 .map(|user| proto::User {
2526 id: user.id.to_proto(),
2527 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2528 github_login: user.github_login,
2529 name: user.name,
2530 email: user.email_address,
2531 })
2532 .collect();
2533 response.send(proto::UsersResponse { users })?;
2534 Ok(())
2535}
2536
2537/// Send a contact request to another user.
2538async fn request_contact(
2539 request: proto::RequestContact,
2540 response: Response<proto::RequestContact>,
2541 session: Session,
2542) -> Result<()> {
2543 let requester_id = session.user_id();
2544 let responder_id = UserId::from_proto(request.responder_id);
2545 if requester_id == responder_id {
2546 return Err(anyhow!("cannot add yourself as a contact"))?;
2547 }
2548
2549 let notifications = session
2550 .db()
2551 .await
2552 .send_contact_request(requester_id, responder_id)
2553 .await?;
2554
2555 // Update outgoing contact requests of requester
2556 let mut update = proto::UpdateContacts::default();
2557 update.outgoing_requests.push(responder_id.to_proto());
2558 for connection_id in session
2559 .connection_pool()
2560 .await
2561 .user_connection_ids(requester_id)
2562 {
2563 session.peer.send(connection_id, update.clone())?;
2564 }
2565
2566 // Update incoming contact requests of responder
2567 let mut update = proto::UpdateContacts::default();
2568 update
2569 .incoming_requests
2570 .push(proto::IncomingContactRequest {
2571 requester_id: requester_id.to_proto(),
2572 });
2573 let connection_pool = session.connection_pool().await;
2574 for connection_id in connection_pool.user_connection_ids(responder_id) {
2575 session.peer.send(connection_id, update.clone())?;
2576 }
2577
2578 send_notifications(&connection_pool, &session.peer, notifications);
2579
2580 response.send(proto::Ack {})?;
2581 Ok(())
2582}
2583
2584/// Accept or decline a contact request
2585async fn respond_to_contact_request(
2586 request: proto::RespondToContactRequest,
2587 response: Response<proto::RespondToContactRequest>,
2588 session: Session,
2589) -> Result<()> {
2590 let responder_id = session.user_id();
2591 let requester_id = UserId::from_proto(request.requester_id);
2592 let db = session.db().await;
2593 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2594 db.dismiss_contact_notification(responder_id, requester_id)
2595 .await?;
2596 } else {
2597 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2598
2599 let notifications = db
2600 .respond_to_contact_request(responder_id, requester_id, accept)
2601 .await?;
2602 let requester_busy = db.is_user_busy(requester_id).await?;
2603 let responder_busy = db.is_user_busy(responder_id).await?;
2604
2605 let pool = session.connection_pool().await;
2606 // Update responder with new contact
2607 let mut update = proto::UpdateContacts::default();
2608 if accept {
2609 update
2610 .contacts
2611 .push(contact_for_user(requester_id, requester_busy, &pool));
2612 }
2613 update
2614 .remove_incoming_requests
2615 .push(requester_id.to_proto());
2616 for connection_id in pool.user_connection_ids(responder_id) {
2617 session.peer.send(connection_id, update.clone())?;
2618 }
2619
2620 // Update requester with new contact
2621 let mut update = proto::UpdateContacts::default();
2622 if accept {
2623 update
2624 .contacts
2625 .push(contact_for_user(responder_id, responder_busy, &pool));
2626 }
2627 update
2628 .remove_outgoing_requests
2629 .push(responder_id.to_proto());
2630
2631 for connection_id in pool.user_connection_ids(requester_id) {
2632 session.peer.send(connection_id, update.clone())?;
2633 }
2634
2635 send_notifications(&pool, &session.peer, notifications);
2636 }
2637
2638 response.send(proto::Ack {})?;
2639 Ok(())
2640}
2641
2642/// Remove a contact.
2643async fn remove_contact(
2644 request: proto::RemoveContact,
2645 response: Response<proto::RemoveContact>,
2646 session: Session,
2647) -> Result<()> {
2648 let requester_id = session.user_id();
2649 let responder_id = UserId::from_proto(request.user_id);
2650 let db = session.db().await;
2651 let (contact_accepted, deleted_notification_id) =
2652 db.remove_contact(requester_id, responder_id).await?;
2653
2654 let pool = session.connection_pool().await;
2655 // Update outgoing contact requests of requester
2656 let mut update = proto::UpdateContacts::default();
2657 if contact_accepted {
2658 update.remove_contacts.push(responder_id.to_proto());
2659 } else {
2660 update
2661 .remove_outgoing_requests
2662 .push(responder_id.to_proto());
2663 }
2664 for connection_id in pool.user_connection_ids(requester_id) {
2665 session.peer.send(connection_id, update.clone())?;
2666 }
2667
2668 // Update incoming contact requests of responder
2669 let mut update = proto::UpdateContacts::default();
2670 if contact_accepted {
2671 update.remove_contacts.push(requester_id.to_proto());
2672 } else {
2673 update
2674 .remove_incoming_requests
2675 .push(requester_id.to_proto());
2676 }
2677 for connection_id in pool.user_connection_ids(responder_id) {
2678 session.peer.send(connection_id, update.clone())?;
2679 if let Some(notification_id) = deleted_notification_id {
2680 session.peer.send(
2681 connection_id,
2682 proto::DeleteNotification {
2683 notification_id: notification_id.to_proto(),
2684 },
2685 )?;
2686 }
2687 }
2688
2689 response.send(proto::Ack {})?;
2690 Ok(())
2691}
2692
2693fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2694 version.0.minor() < 139
2695}
2696
2697async fn update_user_plan(user_id: UserId, session: &Session) -> Result<()> {
2698 let db = session.db().await;
2699
2700 let feature_flags = db.get_user_flags(user_id).await?;
2701 let plan = session.current_plan(&db).await?;
2702 let billing_customer = db.get_billing_customer_by_user_id(user_id).await?;
2703 let billing_preferences = db.get_billing_preferences(user_id).await?;
2704
2705 let (subscription_period, usage) = if let Some(llm_db) = session.app_state.llm_db.clone() {
2706 let subscription = db.get_active_billing_subscription(user_id).await?;
2707
2708 let subscription_period = maybe!({
2709 let subscription = subscription?;
2710 let period_start_at = subscription.current_period_start_at()?;
2711 let period_end_at = subscription.current_period_end_at()?;
2712
2713 Some((period_start_at, period_end_at))
2714 });
2715
2716 let usage = if let Some((period_start_at, period_end_at)) = subscription_period {
2717 llm_db
2718 .get_subscription_usage_for_period(user_id, period_start_at, period_end_at)
2719 .await?
2720 } else {
2721 None
2722 };
2723
2724 (subscription_period, usage)
2725 } else {
2726 (None, None)
2727 };
2728
2729 session
2730 .peer
2731 .send(
2732 session.connection_id,
2733 proto::UpdateUserPlan {
2734 plan: plan.into(),
2735 trial_started_at: billing_customer
2736 .and_then(|billing_customer| billing_customer.trial_started_at)
2737 .map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64),
2738 is_usage_based_billing_enabled: if session.is_staff() {
2739 Some(true)
2740 } else {
2741 billing_preferences
2742 .map(|preferences| preferences.model_request_overages_enabled)
2743 },
2744 subscription_period: subscription_period.map(|(started_at, ended_at)| {
2745 proto::SubscriptionPeriod {
2746 started_at: started_at.timestamp() as u64,
2747 ended_at: ended_at.timestamp() as u64,
2748 }
2749 }),
2750 usage: usage.map(|usage| {
2751 let plan = match plan {
2752 proto::Plan::Free => zed_llm_client::Plan::Free,
2753 proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro,
2754 proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
2755 };
2756
2757 let model_requests_limit = match plan.model_requests_limit() {
2758 zed_llm_client::UsageLimit::Limited(limit) => {
2759 let limit = if plan == zed_llm_client::Plan::ZedProTrial
2760 && feature_flags
2761 .iter()
2762 .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
2763 {
2764 1_000
2765 } else {
2766 limit
2767 };
2768
2769 zed_llm_client::UsageLimit::Limited(limit)
2770 }
2771 zed_llm_client::UsageLimit::Unlimited => {
2772 zed_llm_client::UsageLimit::Unlimited
2773 }
2774 };
2775
2776 proto::SubscriptionUsage {
2777 model_requests_usage_amount: usage.model_requests as u32,
2778 model_requests_usage_limit: Some(proto::UsageLimit {
2779 variant: Some(match model_requests_limit {
2780 zed_llm_client::UsageLimit::Limited(limit) => {
2781 proto::usage_limit::Variant::Limited(
2782 proto::usage_limit::Limited {
2783 limit: limit as u32,
2784 },
2785 )
2786 }
2787 zed_llm_client::UsageLimit::Unlimited => {
2788 proto::usage_limit::Variant::Unlimited(
2789 proto::usage_limit::Unlimited {},
2790 )
2791 }
2792 }),
2793 }),
2794 edit_predictions_usage_amount: usage.edit_predictions as u32,
2795 edit_predictions_usage_limit: Some(proto::UsageLimit {
2796 variant: Some(match plan.edit_predictions_limit() {
2797 zed_llm_client::UsageLimit::Limited(limit) => {
2798 proto::usage_limit::Variant::Limited(
2799 proto::usage_limit::Limited {
2800 limit: limit as u32,
2801 },
2802 )
2803 }
2804 zed_llm_client::UsageLimit::Unlimited => {
2805 proto::usage_limit::Variant::Unlimited(
2806 proto::usage_limit::Unlimited {},
2807 )
2808 }
2809 }),
2810 }),
2811 }
2812 }),
2813 },
2814 )
2815 .trace_err();
2816
2817 Ok(())
2818}
2819
2820async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2821 subscribe_user_to_channels(session.user_id(), &session).await?;
2822 Ok(())
2823}
2824
2825async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2826 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2827 let mut pool = session.connection_pool().await;
2828 for membership in &channels_for_user.channel_memberships {
2829 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2830 }
2831 session.peer.send(
2832 session.connection_id,
2833 build_update_user_channels(&channels_for_user),
2834 )?;
2835 session.peer.send(
2836 session.connection_id,
2837 build_channels_update(channels_for_user),
2838 )?;
2839 Ok(())
2840}
2841
2842/// Creates a new channel.
2843async fn create_channel(
2844 request: proto::CreateChannel,
2845 response: Response<proto::CreateChannel>,
2846 session: Session,
2847) -> Result<()> {
2848 let db = session.db().await;
2849
2850 let parent_id = request.parent_id.map(ChannelId::from_proto);
2851 let (channel, membership) = db
2852 .create_channel(&request.name, parent_id, session.user_id())
2853 .await?;
2854
2855 let root_id = channel.root_id();
2856 let channel = Channel::from_model(channel);
2857
2858 response.send(proto::CreateChannelResponse {
2859 channel: Some(channel.to_proto()),
2860 parent_id: request.parent_id,
2861 })?;
2862
2863 let mut connection_pool = session.connection_pool().await;
2864 if let Some(membership) = membership {
2865 connection_pool.subscribe_to_channel(
2866 membership.user_id,
2867 membership.channel_id,
2868 membership.role,
2869 );
2870 let update = proto::UpdateUserChannels {
2871 channel_memberships: vec![proto::ChannelMembership {
2872 channel_id: membership.channel_id.to_proto(),
2873 role: membership.role.into(),
2874 }],
2875 ..Default::default()
2876 };
2877 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2878 session.peer.send(connection_id, update.clone())?;
2879 }
2880 }
2881
2882 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2883 if !role.can_see_channel(channel.visibility) {
2884 continue;
2885 }
2886
2887 let update = proto::UpdateChannels {
2888 channels: vec![channel.to_proto()],
2889 ..Default::default()
2890 };
2891 session.peer.send(connection_id, update.clone())?;
2892 }
2893
2894 Ok(())
2895}
2896
2897/// Delete a channel
2898async fn delete_channel(
2899 request: proto::DeleteChannel,
2900 response: Response<proto::DeleteChannel>,
2901 session: Session,
2902) -> Result<()> {
2903 let db = session.db().await;
2904
2905 let channel_id = request.channel_id;
2906 let (root_channel, removed_channels) = db
2907 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2908 .await?;
2909 response.send(proto::Ack {})?;
2910
2911 // Notify members of removed channels
2912 let mut update = proto::UpdateChannels::default();
2913 update
2914 .delete_channels
2915 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2916
2917 let connection_pool = session.connection_pool().await;
2918 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2919 session.peer.send(connection_id, update.clone())?;
2920 }
2921
2922 Ok(())
2923}
2924
2925/// Invite someone to join a channel.
2926async fn invite_channel_member(
2927 request: proto::InviteChannelMember,
2928 response: Response<proto::InviteChannelMember>,
2929 session: Session,
2930) -> Result<()> {
2931 let db = session.db().await;
2932 let channel_id = ChannelId::from_proto(request.channel_id);
2933 let invitee_id = UserId::from_proto(request.user_id);
2934 let InviteMemberResult {
2935 channel,
2936 notifications,
2937 } = db
2938 .invite_channel_member(
2939 channel_id,
2940 invitee_id,
2941 session.user_id(),
2942 request.role().into(),
2943 )
2944 .await?;
2945
2946 let update = proto::UpdateChannels {
2947 channel_invitations: vec![channel.to_proto()],
2948 ..Default::default()
2949 };
2950
2951 let connection_pool = session.connection_pool().await;
2952 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2953 session.peer.send(connection_id, update.clone())?;
2954 }
2955
2956 send_notifications(&connection_pool, &session.peer, notifications);
2957
2958 response.send(proto::Ack {})?;
2959 Ok(())
2960}
2961
2962/// remove someone from a channel
2963async fn remove_channel_member(
2964 request: proto::RemoveChannelMember,
2965 response: Response<proto::RemoveChannelMember>,
2966 session: Session,
2967) -> Result<()> {
2968 let db = session.db().await;
2969 let channel_id = ChannelId::from_proto(request.channel_id);
2970 let member_id = UserId::from_proto(request.user_id);
2971
2972 let RemoveChannelMemberResult {
2973 membership_update,
2974 notification_id,
2975 } = db
2976 .remove_channel_member(channel_id, member_id, session.user_id())
2977 .await?;
2978
2979 let mut connection_pool = session.connection_pool().await;
2980 notify_membership_updated(
2981 &mut connection_pool,
2982 membership_update,
2983 member_id,
2984 &session.peer,
2985 );
2986 for connection_id in connection_pool.user_connection_ids(member_id) {
2987 if let Some(notification_id) = notification_id {
2988 session
2989 .peer
2990 .send(
2991 connection_id,
2992 proto::DeleteNotification {
2993 notification_id: notification_id.to_proto(),
2994 },
2995 )
2996 .trace_err();
2997 }
2998 }
2999
3000 response.send(proto::Ack {})?;
3001 Ok(())
3002}
3003
3004/// Toggle the channel between public and private.
3005/// Care is taken to maintain the invariant that public channels only descend from public channels,
3006/// (though members-only channels can appear at any point in the hierarchy).
3007async fn set_channel_visibility(
3008 request: proto::SetChannelVisibility,
3009 response: Response<proto::SetChannelVisibility>,
3010 session: Session,
3011) -> Result<()> {
3012 let db = session.db().await;
3013 let channel_id = ChannelId::from_proto(request.channel_id);
3014 let visibility = request.visibility().into();
3015
3016 let channel_model = db
3017 .set_channel_visibility(channel_id, visibility, session.user_id())
3018 .await?;
3019 let root_id = channel_model.root_id();
3020 let channel = Channel::from_model(channel_model);
3021
3022 let mut connection_pool = session.connection_pool().await;
3023 for (user_id, role) in connection_pool
3024 .channel_user_ids(root_id)
3025 .collect::<Vec<_>>()
3026 .into_iter()
3027 {
3028 let update = if role.can_see_channel(channel.visibility) {
3029 connection_pool.subscribe_to_channel(user_id, channel_id, role);
3030 proto::UpdateChannels {
3031 channels: vec![channel.to_proto()],
3032 ..Default::default()
3033 }
3034 } else {
3035 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3036 proto::UpdateChannels {
3037 delete_channels: vec![channel.id.to_proto()],
3038 ..Default::default()
3039 }
3040 };
3041
3042 for connection_id in connection_pool.user_connection_ids(user_id) {
3043 session.peer.send(connection_id, update.clone())?;
3044 }
3045 }
3046
3047 response.send(proto::Ack {})?;
3048 Ok(())
3049}
3050
3051/// Alter the role for a user in the channel.
3052async fn set_channel_member_role(
3053 request: proto::SetChannelMemberRole,
3054 response: Response<proto::SetChannelMemberRole>,
3055 session: Session,
3056) -> Result<()> {
3057 let db = session.db().await;
3058 let channel_id = ChannelId::from_proto(request.channel_id);
3059 let member_id = UserId::from_proto(request.user_id);
3060 let result = db
3061 .set_channel_member_role(
3062 channel_id,
3063 session.user_id(),
3064 member_id,
3065 request.role().into(),
3066 )
3067 .await?;
3068
3069 match result {
3070 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3071 let mut connection_pool = session.connection_pool().await;
3072 notify_membership_updated(
3073 &mut connection_pool,
3074 membership_update,
3075 member_id,
3076 &session.peer,
3077 )
3078 }
3079 db::SetMemberRoleResult::InviteUpdated(channel) => {
3080 let update = proto::UpdateChannels {
3081 channel_invitations: vec![channel.to_proto()],
3082 ..Default::default()
3083 };
3084
3085 for connection_id in session
3086 .connection_pool()
3087 .await
3088 .user_connection_ids(member_id)
3089 {
3090 session.peer.send(connection_id, update.clone())?;
3091 }
3092 }
3093 }
3094
3095 response.send(proto::Ack {})?;
3096 Ok(())
3097}
3098
3099/// Change the name of a channel
3100async fn rename_channel(
3101 request: proto::RenameChannel,
3102 response: Response<proto::RenameChannel>,
3103 session: Session,
3104) -> Result<()> {
3105 let db = session.db().await;
3106 let channel_id = ChannelId::from_proto(request.channel_id);
3107 let channel_model = db
3108 .rename_channel(channel_id, session.user_id(), &request.name)
3109 .await?;
3110 let root_id = channel_model.root_id();
3111 let channel = Channel::from_model(channel_model);
3112
3113 response.send(proto::RenameChannelResponse {
3114 channel: Some(channel.to_proto()),
3115 })?;
3116
3117 let connection_pool = session.connection_pool().await;
3118 let update = proto::UpdateChannels {
3119 channels: vec![channel.to_proto()],
3120 ..Default::default()
3121 };
3122 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3123 if role.can_see_channel(channel.visibility) {
3124 session.peer.send(connection_id, update.clone())?;
3125 }
3126 }
3127
3128 Ok(())
3129}
3130
3131/// Move a channel to a new parent.
3132async fn move_channel(
3133 request: proto::MoveChannel,
3134 response: Response<proto::MoveChannel>,
3135 session: Session,
3136) -> Result<()> {
3137 let channel_id = ChannelId::from_proto(request.channel_id);
3138 let to = ChannelId::from_proto(request.to);
3139
3140 let (root_id, channels) = session
3141 .db()
3142 .await
3143 .move_channel(channel_id, to, session.user_id())
3144 .await?;
3145
3146 let connection_pool = session.connection_pool().await;
3147 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3148 let channels = channels
3149 .iter()
3150 .filter_map(|channel| {
3151 if role.can_see_channel(channel.visibility) {
3152 Some(channel.to_proto())
3153 } else {
3154 None
3155 }
3156 })
3157 .collect::<Vec<_>>();
3158 if channels.is_empty() {
3159 continue;
3160 }
3161
3162 let update = proto::UpdateChannels {
3163 channels,
3164 ..Default::default()
3165 };
3166
3167 session.peer.send(connection_id, update.clone())?;
3168 }
3169
3170 response.send(Ack {})?;
3171 Ok(())
3172}
3173
3174/// Get the list of channel members
3175async fn get_channel_members(
3176 request: proto::GetChannelMembers,
3177 response: Response<proto::GetChannelMembers>,
3178 session: Session,
3179) -> Result<()> {
3180 let db = session.db().await;
3181 let channel_id = ChannelId::from_proto(request.channel_id);
3182 let limit = if request.limit == 0 {
3183 u16::MAX as u64
3184 } else {
3185 request.limit
3186 };
3187 let (members, users) = db
3188 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3189 .await?;
3190 response.send(proto::GetChannelMembersResponse { members, users })?;
3191 Ok(())
3192}
3193
3194/// Accept or decline a channel invitation.
3195async fn respond_to_channel_invite(
3196 request: proto::RespondToChannelInvite,
3197 response: Response<proto::RespondToChannelInvite>,
3198 session: Session,
3199) -> Result<()> {
3200 let db = session.db().await;
3201 let channel_id = ChannelId::from_proto(request.channel_id);
3202 let RespondToChannelInvite {
3203 membership_update,
3204 notifications,
3205 } = db
3206 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3207 .await?;
3208
3209 let mut connection_pool = session.connection_pool().await;
3210 if let Some(membership_update) = membership_update {
3211 notify_membership_updated(
3212 &mut connection_pool,
3213 membership_update,
3214 session.user_id(),
3215 &session.peer,
3216 );
3217 } else {
3218 let update = proto::UpdateChannels {
3219 remove_channel_invitations: vec![channel_id.to_proto()],
3220 ..Default::default()
3221 };
3222
3223 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3224 session.peer.send(connection_id, update.clone())?;
3225 }
3226 };
3227
3228 send_notifications(&connection_pool, &session.peer, notifications);
3229
3230 response.send(proto::Ack {})?;
3231
3232 Ok(())
3233}
3234
3235/// Join the channels' room
3236async fn join_channel(
3237 request: proto::JoinChannel,
3238 response: Response<proto::JoinChannel>,
3239 session: Session,
3240) -> Result<()> {
3241 let channel_id = ChannelId::from_proto(request.channel_id);
3242 join_channel_internal(channel_id, Box::new(response), session).await
3243}
3244
3245trait JoinChannelInternalResponse {
3246 fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
3247}
3248impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
3249 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3250 Response::<proto::JoinChannel>::send(self, result)
3251 }
3252}
3253impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
3254 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3255 Response::<proto::JoinRoom>::send(self, result)
3256 }
3257}
3258
/// Shared implementation for joining a channel's call room.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, that
///    connection is first removed from its room. The `db` guard is dropped
///    around that call and re-acquired afterwards.
/// 2. The user joins the channel's room in the database. This also yields any
///    membership change and the user's role in the channel.
/// 3. When a LiveKit client is configured, a token is minted. Guests get a
///    guest token with `can_publish = false`; every other role gets a room
///    token with `can_publish = true`. If minting fails, `trace_err()` turns
///    the error into `None`, so the response carries no connection info.
/// 4. The join response is sent. Membership changes are pushed to the user's
///    connections, and the room's participants receive the updated room.
/// 5. Connections subscribed to the channel receive its new participant
///    list. The user's contacts are sent the user's refreshed contact entry.
///
/// # Errors
/// Propagates database and send failures. Also errors if the join did not
/// return a channel.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database guard while leaving the stale room:
            // `leave_room_for_session` acquires it itself via `session.db()`.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured or token minting failed.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    // Guests may only listen; every other role may publish.
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        // The pool guard (and the db guard) are released at the end of this
        // block, before `channel_updated` re-acquires the pool below.
        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3354
3355/// Start editing the channel notes
3356async fn join_channel_buffer(
3357 request: proto::JoinChannelBuffer,
3358 response: Response<proto::JoinChannelBuffer>,
3359 session: Session,
3360) -> Result<()> {
3361 let db = session.db().await;
3362 let channel_id = ChannelId::from_proto(request.channel_id);
3363
3364 let open_response = db
3365 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3366 .await?;
3367
3368 let collaborators = open_response.collaborators.clone();
3369 response.send(open_response)?;
3370
3371 let update = UpdateChannelBufferCollaborators {
3372 channel_id: channel_id.to_proto(),
3373 collaborators: collaborators.clone(),
3374 };
3375 channel_buffer_updated(
3376 session.connection_id,
3377 collaborators
3378 .iter()
3379 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3380 &update,
3381 &session.peer,
3382 );
3383
3384 Ok(())
3385}
3386
3387/// Edit the channel notes
3388async fn update_channel_buffer(
3389 request: proto::UpdateChannelBuffer,
3390 session: Session,
3391) -> Result<()> {
3392 let db = session.db().await;
3393 let channel_id = ChannelId::from_proto(request.channel_id);
3394
3395 let (collaborators, epoch, version) = db
3396 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3397 .await?;
3398
3399 channel_buffer_updated(
3400 session.connection_id,
3401 collaborators.clone(),
3402 &proto::UpdateChannelBuffer {
3403 channel_id: channel_id.to_proto(),
3404 operations: request.operations,
3405 },
3406 &session.peer,
3407 );
3408
3409 let pool = &*session.connection_pool().await;
3410
3411 let non_collaborators =
3412 pool.channel_connection_ids(channel_id)
3413 .filter_map(|(connection_id, _)| {
3414 if collaborators.contains(&connection_id) {
3415 None
3416 } else {
3417 Some(connection_id)
3418 }
3419 });
3420
3421 broadcast(None, non_collaborators, |peer_id| {
3422 session.peer.send(
3423 peer_id,
3424 proto::UpdateChannels {
3425 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3426 channel_id: channel_id.to_proto(),
3427 epoch: epoch as u64,
3428 version: version.clone(),
3429 }],
3430 ..Default::default()
3431 },
3432 )
3433 });
3434
3435 Ok(())
3436}
3437
3438/// Rejoin the channel notes after a connection blip
3439async fn rejoin_channel_buffers(
3440 request: proto::RejoinChannelBuffers,
3441 response: Response<proto::RejoinChannelBuffers>,
3442 session: Session,
3443) -> Result<()> {
3444 let db = session.db().await;
3445 let buffers = db
3446 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3447 .await?;
3448
3449 for rejoined_buffer in &buffers {
3450 let collaborators_to_notify = rejoined_buffer
3451 .buffer
3452 .collaborators
3453 .iter()
3454 .filter_map(|c| Some(c.peer_id?.into()));
3455 channel_buffer_updated(
3456 session.connection_id,
3457 collaborators_to_notify,
3458 &proto::UpdateChannelBufferCollaborators {
3459 channel_id: rejoined_buffer.buffer.channel_id,
3460 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3461 },
3462 &session.peer,
3463 );
3464 }
3465
3466 response.send(proto::RejoinChannelBuffersResponse {
3467 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3468 })?;
3469
3470 Ok(())
3471}
3472
3473/// Stop editing the channel notes
3474async fn leave_channel_buffer(
3475 request: proto::LeaveChannelBuffer,
3476 response: Response<proto::LeaveChannelBuffer>,
3477 session: Session,
3478) -> Result<()> {
3479 let db = session.db().await;
3480 let channel_id = ChannelId::from_proto(request.channel_id);
3481
3482 let left_buffer = db
3483 .leave_channel_buffer(channel_id, session.connection_id)
3484 .await?;
3485
3486 response.send(Ack {})?;
3487
3488 channel_buffer_updated(
3489 session.connection_id,
3490 left_buffer.connections,
3491 &proto::UpdateChannelBufferCollaborators {
3492 channel_id: channel_id.to_proto(),
3493 collaborators: left_buffer.collaborators,
3494 },
3495 &session.peer,
3496 );
3497
3498 Ok(())
3499}
3500
3501fn channel_buffer_updated<T: EnvelopedMessage>(
3502 sender_id: ConnectionId,
3503 collaborators: impl IntoIterator<Item = ConnectionId>,
3504 message: &T,
3505 peer: &Peer,
3506) {
3507 broadcast(Some(sender_id), collaborators, |peer_id| {
3508 peer.send(peer_id, message.clone())
3509 });
3510}
3511
3512fn send_notifications(
3513 connection_pool: &ConnectionPool,
3514 peer: &Peer,
3515 notifications: db::NotificationBatch,
3516) {
3517 for (user_id, notification) in notifications {
3518 for connection_id in connection_pool.user_connection_ids(user_id) {
3519 if let Err(error) = peer.send(
3520 connection_id,
3521 proto::AddNotification {
3522 notification: Some(notification.clone()),
3523 },
3524 ) {
3525 tracing::error!(
3526 "failed to send notification to {:?} {}",
3527 connection_id,
3528 error
3529 );
3530 }
3531 }
3532 }
3533}
3534
3535/// Send a message to the channel
3536async fn send_channel_message(
3537 request: proto::SendChannelMessage,
3538 response: Response<proto::SendChannelMessage>,
3539 session: Session,
3540) -> Result<()> {
3541 // Validate the message body.
3542 let body = request.body.trim().to_string();
3543 if body.len() > MAX_MESSAGE_LEN {
3544 return Err(anyhow!("message is too long"))?;
3545 }
3546 if body.is_empty() {
3547 return Err(anyhow!("message can't be blank"))?;
3548 }
3549
3550 // TODO: adjust mentions if body is trimmed
3551
3552 let timestamp = OffsetDateTime::now_utc();
3553 let nonce = request
3554 .nonce
3555 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3556
3557 let channel_id = ChannelId::from_proto(request.channel_id);
3558 let CreatedChannelMessage {
3559 message_id,
3560 participant_connection_ids,
3561 notifications,
3562 } = session
3563 .db()
3564 .await
3565 .create_channel_message(
3566 channel_id,
3567 session.user_id(),
3568 &body,
3569 &request.mentions,
3570 timestamp,
3571 nonce.clone().into(),
3572 request.reply_to_message_id.map(MessageId::from_proto),
3573 )
3574 .await?;
3575
3576 let message = proto::ChannelMessage {
3577 sender_id: session.user_id().to_proto(),
3578 id: message_id.to_proto(),
3579 body,
3580 mentions: request.mentions,
3581 timestamp: timestamp.unix_timestamp() as u64,
3582 nonce: Some(nonce),
3583 reply_to_message_id: request.reply_to_message_id,
3584 edited_at: None,
3585 };
3586 broadcast(
3587 Some(session.connection_id),
3588 participant_connection_ids.clone(),
3589 |connection| {
3590 session.peer.send(
3591 connection,
3592 proto::ChannelMessageSent {
3593 channel_id: channel_id.to_proto(),
3594 message: Some(message.clone()),
3595 },
3596 )
3597 },
3598 );
3599 response.send(proto::SendChannelMessageResponse {
3600 message: Some(message),
3601 })?;
3602
3603 let pool = &*session.connection_pool().await;
3604 let non_participants =
3605 pool.channel_connection_ids(channel_id)
3606 .filter_map(|(connection_id, _)| {
3607 if participant_connection_ids.contains(&connection_id) {
3608 None
3609 } else {
3610 Some(connection_id)
3611 }
3612 });
3613 broadcast(None, non_participants, |peer_id| {
3614 session.peer.send(
3615 peer_id,
3616 proto::UpdateChannels {
3617 latest_channel_message_ids: vec![proto::ChannelMessageId {
3618 channel_id: channel_id.to_proto(),
3619 message_id: message_id.to_proto(),
3620 }],
3621 ..Default::default()
3622 },
3623 )
3624 });
3625 send_notifications(pool, &session.peer, notifications);
3626
3627 Ok(())
3628}
3629
3630/// Delete a channel message
3631async fn remove_channel_message(
3632 request: proto::RemoveChannelMessage,
3633 response: Response<proto::RemoveChannelMessage>,
3634 session: Session,
3635) -> Result<()> {
3636 let channel_id = ChannelId::from_proto(request.channel_id);
3637 let message_id = MessageId::from_proto(request.message_id);
3638 let (connection_ids, existing_notification_ids) = session
3639 .db()
3640 .await
3641 .remove_channel_message(channel_id, message_id, session.user_id())
3642 .await?;
3643
3644 broadcast(
3645 Some(session.connection_id),
3646 connection_ids,
3647 move |connection| {
3648 session.peer.send(connection, request.clone())?;
3649
3650 for notification_id in &existing_notification_ids {
3651 session.peer.send(
3652 connection,
3653 proto::DeleteNotification {
3654 notification_id: (*notification_id).to_proto(),
3655 },
3656 )?;
3657 }
3658
3659 Ok(())
3660 },
3661 );
3662 response.send(proto::Ack {})?;
3663 Ok(())
3664}
3665
3666async fn update_channel_message(
3667 request: proto::UpdateChannelMessage,
3668 response: Response<proto::UpdateChannelMessage>,
3669 session: Session,
3670) -> Result<()> {
3671 let channel_id = ChannelId::from_proto(request.channel_id);
3672 let message_id = MessageId::from_proto(request.message_id);
3673 let updated_at = OffsetDateTime::now_utc();
3674 let UpdatedChannelMessage {
3675 message_id,
3676 participant_connection_ids,
3677 notifications,
3678 reply_to_message_id,
3679 timestamp,
3680 deleted_mention_notification_ids,
3681 updated_mention_notifications,
3682 } = session
3683 .db()
3684 .await
3685 .update_channel_message(
3686 channel_id,
3687 message_id,
3688 session.user_id(),
3689 request.body.as_str(),
3690 &request.mentions,
3691 updated_at,
3692 )
3693 .await?;
3694
3695 let nonce = request
3696 .nonce
3697 .clone()
3698 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3699
3700 let message = proto::ChannelMessage {
3701 sender_id: session.user_id().to_proto(),
3702 id: message_id.to_proto(),
3703 body: request.body.clone(),
3704 mentions: request.mentions.clone(),
3705 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3706 nonce: Some(nonce),
3707 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3708 edited_at: Some(updated_at.unix_timestamp() as u64),
3709 };
3710
3711 response.send(proto::Ack {})?;
3712
3713 let pool = &*session.connection_pool().await;
3714 broadcast(
3715 Some(session.connection_id),
3716 participant_connection_ids,
3717 |connection| {
3718 session.peer.send(
3719 connection,
3720 proto::ChannelMessageUpdate {
3721 channel_id: channel_id.to_proto(),
3722 message: Some(message.clone()),
3723 },
3724 )?;
3725
3726 for notification_id in &deleted_mention_notification_ids {
3727 session.peer.send(
3728 connection,
3729 proto::DeleteNotification {
3730 notification_id: (*notification_id).to_proto(),
3731 },
3732 )?;
3733 }
3734
3735 for notification in &updated_mention_notifications {
3736 session.peer.send(
3737 connection,
3738 proto::UpdateNotification {
3739 notification: Some(notification.clone()),
3740 },
3741 )?;
3742 }
3743
3744 Ok(())
3745 },
3746 );
3747
3748 send_notifications(pool, &session.peer, notifications);
3749
3750 Ok(())
3751}
3752
3753/// Mark a channel message as read
3754async fn acknowledge_channel_message(
3755 request: proto::AckChannelMessage,
3756 session: Session,
3757) -> Result<()> {
3758 let channel_id = ChannelId::from_proto(request.channel_id);
3759 let message_id = MessageId::from_proto(request.message_id);
3760 let notifications = session
3761 .db()
3762 .await
3763 .observe_channel_message(channel_id, session.user_id(), message_id)
3764 .await?;
3765 send_notifications(
3766 &*session.connection_pool().await,
3767 &session.peer,
3768 notifications,
3769 );
3770 Ok(())
3771}
3772
3773/// Mark a buffer version as synced
3774async fn acknowledge_buffer_version(
3775 request: proto::AckBufferOperation,
3776 session: Session,
3777) -> Result<()> {
3778 let buffer_id = BufferId::from_proto(request.buffer_id);
3779 session
3780 .db()
3781 .await
3782 .observe_buffer_version(
3783 buffer_id,
3784 session.user_id(),
3785 request.epoch as i32,
3786 &request.version,
3787 )
3788 .await?;
3789 Ok(())
3790}
3791
3792/// Get a Supermaven API key for the user
3793async fn get_supermaven_api_key(
3794 _request: proto::GetSupermavenApiKey,
3795 response: Response<proto::GetSupermavenApiKey>,
3796 session: Session,
3797) -> Result<()> {
3798 let user_id: String = session.user_id().to_string();
3799 if !session.is_staff() {
3800 return Err(anyhow!("supermaven not enabled for this account"))?;
3801 }
3802
3803 let email = session
3804 .email()
3805 .ok_or_else(|| anyhow!("user must have an email"))?;
3806
3807 let supermaven_admin_api = session
3808 .supermaven_client
3809 .as_ref()
3810 .ok_or_else(|| anyhow!("supermaven not configured"))?;
3811
3812 let result = supermaven_admin_api
3813 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3814 .await?;
3815
3816 response.send(proto::GetSupermavenApiKeyResponse {
3817 api_key: result.api_key,
3818 })?;
3819
3820 Ok(())
3821}
3822
3823/// Start receiving chat updates for a channel
3824async fn join_channel_chat(
3825 request: proto::JoinChannelChat,
3826 response: Response<proto::JoinChannelChat>,
3827 session: Session,
3828) -> Result<()> {
3829 let channel_id = ChannelId::from_proto(request.channel_id);
3830
3831 let db = session.db().await;
3832 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3833 .await?;
3834 let messages = db
3835 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3836 .await?;
3837 response.send(proto::JoinChannelChatResponse {
3838 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3839 messages,
3840 })?;
3841 Ok(())
3842}
3843
3844/// Stop receiving chat updates for a channel
3845async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3846 let channel_id = ChannelId::from_proto(request.channel_id);
3847 session
3848 .db()
3849 .await
3850 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3851 .await?;
3852 Ok(())
3853}
3854
3855/// Retrieve the chat history for a channel
3856async fn get_channel_messages(
3857 request: proto::GetChannelMessages,
3858 response: Response<proto::GetChannelMessages>,
3859 session: Session,
3860) -> Result<()> {
3861 let channel_id = ChannelId::from_proto(request.channel_id);
3862 let messages = session
3863 .db()
3864 .await
3865 .get_channel_messages(
3866 channel_id,
3867 session.user_id(),
3868 MESSAGE_COUNT_PER_PAGE,
3869 Some(MessageId::from_proto(request.before_message_id)),
3870 )
3871 .await?;
3872 response.send(proto::GetChannelMessagesResponse {
3873 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3874 messages,
3875 })?;
3876 Ok(())
3877}
3878
3879/// Retrieve specific chat messages
3880async fn get_channel_messages_by_id(
3881 request: proto::GetChannelMessagesById,
3882 response: Response<proto::GetChannelMessagesById>,
3883 session: Session,
3884) -> Result<()> {
3885 let message_ids = request
3886 .message_ids
3887 .iter()
3888 .map(|id| MessageId::from_proto(*id))
3889 .collect::<Vec<_>>();
3890 let messages = session
3891 .db()
3892 .await
3893 .get_channel_messages_by_id(session.user_id(), &message_ids)
3894 .await?;
3895 response.send(proto::GetChannelMessagesResponse {
3896 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3897 messages,
3898 })?;
3899 Ok(())
3900}
3901
3902/// Retrieve the current users notifications
3903async fn get_notifications(
3904 request: proto::GetNotifications,
3905 response: Response<proto::GetNotifications>,
3906 session: Session,
3907) -> Result<()> {
3908 let notifications = session
3909 .db()
3910 .await
3911 .get_notifications(
3912 session.user_id(),
3913 NOTIFICATION_COUNT_PER_PAGE,
3914 request.before_id.map(db::NotificationId::from_proto),
3915 )
3916 .await?;
3917 response.send(proto::GetNotificationsResponse {
3918 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3919 notifications,
3920 })?;
3921 Ok(())
3922}
3923
3924/// Mark notifications as read
3925async fn mark_notification_as_read(
3926 request: proto::MarkNotificationRead,
3927 response: Response<proto::MarkNotificationRead>,
3928 session: Session,
3929) -> Result<()> {
3930 let database = &session.db().await;
3931 let notifications = database
3932 .mark_notification_as_read_by_id(
3933 session.user_id(),
3934 NotificationId::from_proto(request.notification_id),
3935 )
3936 .await?;
3937 send_notifications(
3938 &*session.connection_pool().await,
3939 &session.peer,
3940 notifications,
3941 );
3942 response.send(proto::Ack {})?;
3943 Ok(())
3944}
3945
3946/// Get the current users information
3947async fn get_private_user_info(
3948 _request: proto::GetPrivateUserInfo,
3949 response: Response<proto::GetPrivateUserInfo>,
3950 session: Session,
3951) -> Result<()> {
3952 let db = session.db().await;
3953
3954 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
3955 let user = db
3956 .get_user_by_id(session.user_id())
3957 .await?
3958 .ok_or_else(|| anyhow!("user not found"))?;
3959 let flags = db.get_user_flags(session.user_id()).await?;
3960
3961 response.send(proto::GetPrivateUserInfoResponse {
3962 metrics_id,
3963 staff: user.admin,
3964 flags,
3965 accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
3966 })?;
3967 Ok(())
3968}
3969
3970/// Accept the terms of service (tos) on behalf of the current user
3971async fn accept_terms_of_service(
3972 _request: proto::AcceptTermsOfService,
3973 response: Response<proto::AcceptTermsOfService>,
3974 session: Session,
3975) -> Result<()> {
3976 let db = session.db().await;
3977
3978 let accepted_tos_at = Utc::now();
3979 db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
3980 .await?;
3981
3982 response.send(proto::AcceptTermsOfServiceResponse {
3983 accepted_tos_at: accepted_tos_at.timestamp() as u64,
3984 })?;
3985 Ok(())
3986}
3987
/// The minimum account age an account must have in order to use the LLM service.
///
/// Thirty days, as a `chrono::Duration`. NOTE(review): this constant is not
/// read within `get_llm_api_token` itself; presumably it is enforced
/// elsewhere (e.g. during token creation). Confirm before relying on it.
pub const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
3990
3991async fn get_llm_api_token(
3992 _request: proto::GetLlmToken,
3993 response: Response<proto::GetLlmToken>,
3994 session: Session,
3995) -> Result<()> {
3996 let db = session.db().await;
3997
3998 let flags = db.get_user_flags(session.user_id()).await?;
3999 let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
4000
4001 if !session.is_staff() && !has_language_models_feature_flag {
4002 Err(anyhow!("permission denied"))?
4003 }
4004
4005 let user_id = session.user_id();
4006 let user = db
4007 .get_user_by_id(user_id)
4008 .await?
4009 .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4010
4011 if user.accepted_tos_at.is_none() {
4012 Err(anyhow!("terms of service not accepted"))?
4013 }
4014
4015 let has_legacy_llm_subscription = session.has_llm_subscription(&db).await?;
4016 let billing_subscription = db.get_active_billing_subscription(user.id).await?;
4017 let billing_preferences = db.get_billing_preferences(user.id).await?;
4018
4019 let token = LlmTokenClaims::create(
4020 &user,
4021 session.is_staff(),
4022 billing_preferences,
4023 &flags,
4024 has_legacy_llm_subscription,
4025 billing_subscription,
4026 session.system_id.clone(),
4027 &session.app_state.config,
4028 )?;
4029 response.send(proto::GetLlmTokenResponse { token })?;
4030 Ok(())
4031}
4032
4033fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4034 let message = match message {
4035 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
4036 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
4037 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
4038 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
4039 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4040 code: frame.code.into(),
4041 reason: frame.reason.as_str().to_owned().into(),
4042 })),
4043 // We should never receive a frame while reading the message, according
4044 // to the `tungstenite` maintainers:
4045 //
4046 // > It cannot occur when you read messages from the WebSocket, but it
4047 // > can be used when you want to send the raw frames (e.g. you want to
4048 // > send the frames to the WebSocket without composing the full message first).
4049 // >
4050 // > — https://github.com/snapview/tungstenite-rs/issues/268
4051 TungsteniteMessage::Frame(_) => {
4052 bail!("received an unexpected frame while reading the message")
4053 }
4054 };
4055
4056 Ok(message)
4057}
4058
4059fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4060 match message {
4061 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
4062 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
4063 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
4064 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
4065 AxumMessage::Close(frame) => {
4066 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4067 code: frame.code.into(),
4068 reason: frame.reason.as_ref().into(),
4069 }))
4070 }
4071 }
4072}
4073
4074fn notify_membership_updated(
4075 connection_pool: &mut ConnectionPool,
4076 result: MembershipUpdated,
4077 user_id: UserId,
4078 peer: &Peer,
4079) {
4080 for membership in &result.new_channels.channel_memberships {
4081 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4082 }
4083 for channel_id in &result.removed_channels {
4084 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4085 }
4086
4087 let user_channels_update = proto::UpdateUserChannels {
4088 channel_memberships: result
4089 .new_channels
4090 .channel_memberships
4091 .iter()
4092 .map(|cm| proto::ChannelMembership {
4093 channel_id: cm.channel_id.to_proto(),
4094 role: cm.role.into(),
4095 })
4096 .collect(),
4097 ..Default::default()
4098 };
4099
4100 let mut update = build_channels_update(result.new_channels);
4101 update.delete_channels = result
4102 .removed_channels
4103 .into_iter()
4104 .map(|id| id.to_proto())
4105 .collect();
4106 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4107
4108 for connection_id in connection_pool.user_connection_ids(user_id) {
4109 peer.send(connection_id, user_channels_update.clone())
4110 .trace_err();
4111 peer.send(connection_id, update.clone()).trace_err();
4112 }
4113}
4114
4115fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4116 proto::UpdateUserChannels {
4117 channel_memberships: channels
4118 .channel_memberships
4119 .iter()
4120 .map(|m| proto::ChannelMembership {
4121 channel_id: m.channel_id.to_proto(),
4122 role: m.role.into(),
4123 })
4124 .collect(),
4125 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4126 observed_channel_message_id: channels.observed_channel_messages.clone(),
4127 }
4128}
4129
4130fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4131 let mut update = proto::UpdateChannels::default();
4132
4133 for channel in channels.channels {
4134 update.channels.push(channel.to_proto());
4135 }
4136
4137 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4138 update.latest_channel_message_ids = channels.latest_channel_messages;
4139
4140 for (channel_id, participants) in channels.channel_participants {
4141 update
4142 .channel_participants
4143 .push(proto::ChannelParticipants {
4144 channel_id: channel_id.to_proto(),
4145 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4146 });
4147 }
4148
4149 for channel in channels.invited_channels {
4150 update.channel_invitations.push(channel.to_proto());
4151 }
4152
4153 update
4154}
4155
4156fn build_initial_contacts_update(
4157 contacts: Vec<db::Contact>,
4158 pool: &ConnectionPool,
4159) -> proto::UpdateContacts {
4160 let mut update = proto::UpdateContacts::default();
4161
4162 for contact in contacts {
4163 match contact {
4164 db::Contact::Accepted { user_id, busy } => {
4165 update.contacts.push(contact_for_user(user_id, busy, pool));
4166 }
4167 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4168 db::Contact::Incoming { user_id } => {
4169 update
4170 .incoming_requests
4171 .push(proto::IncomingContactRequest {
4172 requester_id: user_id.to_proto(),
4173 })
4174 }
4175 }
4176 }
4177
4178 update
4179}
4180
4181fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4182 proto::Contact {
4183 user_id: user_id.to_proto(),
4184 online: pool.is_user_online(user_id),
4185 busy,
4186 }
4187}
4188
4189fn room_updated(room: &proto::Room, peer: &Peer) {
4190 broadcast(
4191 None,
4192 room.participants
4193 .iter()
4194 .filter_map(|participant| Some(participant.peer_id?.into())),
4195 |peer_id| {
4196 peer.send(
4197 peer_id,
4198 proto::RoomUpdated {
4199 room: Some(room.clone()),
4200 },
4201 )
4202 },
4203 );
4204}
4205
4206fn channel_updated(
4207 channel: &db::channel::Model,
4208 room: &proto::Room,
4209 peer: &Peer,
4210 pool: &ConnectionPool,
4211) {
4212 let participants = room
4213 .participants
4214 .iter()
4215 .map(|p| p.user_id)
4216 .collect::<Vec<_>>();
4217
4218 broadcast(
4219 None,
4220 pool.channel_connection_ids(channel.root_id())
4221 .filter_map(|(channel_id, role)| {
4222 role.can_see_channel(channel.visibility)
4223 .then_some(channel_id)
4224 }),
4225 |peer_id| {
4226 peer.send(
4227 peer_id,
4228 proto::UpdateChannels {
4229 channel_participants: vec![proto::ChannelParticipants {
4230 channel_id: channel.id.to_proto(),
4231 participant_user_ids: participants.clone(),
4232 }],
4233 ..Default::default()
4234 },
4235 )
4236 },
4237 );
4238}
4239
4240async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4241 let db = session.db().await;
4242
4243 let contacts = db.get_contacts(user_id).await?;
4244 let busy = db.is_user_busy(user_id).await?;
4245
4246 let pool = session.connection_pool().await;
4247 let updated_contact = contact_for_user(user_id, busy, &pool);
4248 for contact in contacts {
4249 if let db::Contact::Accepted {
4250 user_id: contact_user_id,
4251 ..
4252 } = contact
4253 {
4254 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4255 session
4256 .peer
4257 .send(
4258 contact_conn_id,
4259 proto::UpdateContacts {
4260 contacts: vec![updated_contact.clone()],
4261 remove_contacts: Default::default(),
4262 incoming_requests: Default::default(),
4263 remove_incoming_requests: Default::default(),
4264 outgoing_requests: Default::default(),
4265 remove_outgoing_requests: Default::default(),
4266 },
4267 )
4268 .trace_err();
4269 }
4270 }
4271 }
4272 Ok(())
4273}
4274
4275async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
4276 let mut contacts_to_update = HashSet::default();
4277
4278 let room_id;
4279 let canceled_calls_to_user_ids;
4280 let livekit_room;
4281 let delete_livekit_room;
4282 let room;
4283 let channel;
4284
4285 if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
4286 contacts_to_update.insert(session.user_id());
4287
4288 for project in left_room.left_projects.values() {
4289 project_left(project, session);
4290 }
4291
4292 room_id = RoomId::from_proto(left_room.room.id);
4293 canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
4294 livekit_room = mem::take(&mut left_room.room.livekit_room);
4295 delete_livekit_room = left_room.deleted;
4296 room = mem::take(&mut left_room.room);
4297 channel = mem::take(&mut left_room.channel);
4298
4299 room_updated(&room, &session.peer);
4300 } else {
4301 return Ok(());
4302 }
4303
4304 if let Some(channel) = channel {
4305 channel_updated(
4306 &channel,
4307 &room,
4308 &session.peer,
4309 &*session.connection_pool().await,
4310 );
4311 }
4312
4313 {
4314 let pool = session.connection_pool().await;
4315 for canceled_user_id in canceled_calls_to_user_ids {
4316 for connection_id in pool.user_connection_ids(canceled_user_id) {
4317 session
4318 .peer
4319 .send(
4320 connection_id,
4321 proto::CallCanceled {
4322 room_id: room_id.to_proto(),
4323 },
4324 )
4325 .trace_err();
4326 }
4327 contacts_to_update.insert(canceled_user_id);
4328 }
4329 }
4330
4331 for contact_user_id in contacts_to_update {
4332 update_user_contacts(contact_user_id, session).await?;
4333 }
4334
4335 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
4336 live_kit
4337 .remove_participant(livekit_room.clone(), session.user_id().to_string())
4338 .await
4339 .trace_err();
4340
4341 if delete_livekit_room {
4342 live_kit.delete_room(livekit_room).await.trace_err();
4343 }
4344 }
4345
4346 Ok(())
4347}
4348
4349async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4350 let left_channel_buffers = session
4351 .db()
4352 .await
4353 .leave_channel_buffers(session.connection_id)
4354 .await?;
4355
4356 for left_buffer in left_channel_buffers {
4357 channel_buffer_updated(
4358 session.connection_id,
4359 left_buffer.connections,
4360 &proto::UpdateChannelBufferCollaborators {
4361 channel_id: left_buffer.channel_id.to_proto(),
4362 collaborators: left_buffer.collaborators,
4363 },
4364 &session.peer,
4365 );
4366 }
4367
4368 Ok(())
4369}
4370
4371fn project_left(project: &db::LeftProject, session: &Session) {
4372 for connection_id in &project.connection_ids {
4373 if project.should_unshare {
4374 session
4375 .peer
4376 .send(
4377 *connection_id,
4378 proto::UnshareProject {
4379 project_id: project.id.to_proto(),
4380 },
4381 )
4382 .trace_err();
4383 } else {
4384 session
4385 .peer
4386 .send(
4387 *connection_id,
4388 proto::RemoveProjectCollaborator {
4389 project_id: project.id.to_proto(),
4390 peer_id: Some(session.connection_id.into()),
4391 },
4392 )
4393 .trace_err();
4394 }
4395 }
4396}
4397
4398pub trait ResultExt {
4399 type Ok;
4400
4401 fn trace_err(self) -> Option<Self::Ok>;
4402}
4403
4404impl<T, E> ResultExt for Result<T, E>
4405where
4406 E: std::fmt::Debug,
4407{
4408 type Ok = T;
4409
4410 #[track_caller]
4411 fn trace_err(self) -> Option<T> {
4412 match self {
4413 Ok(value) => Some(value),
4414 Err(error) => {
4415 tracing::error!("{:?}", error);
4416 None
4417 }
4418 }
4419 }
4420}