1mod connection_pool;
2
3use crate::api::billing::find_or_create_billing_customer;
4use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
5use crate::db::billing_subscription::SubscriptionKind;
6use crate::llm::db::LlmDatabase;
7use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, LlmTokenClaims};
8use crate::stripe_client::StripeCustomerId;
9use crate::{
10 AppState, Error, Result, auth,
11 db::{
12 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
13 CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
14 NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
15 RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
16 },
17 executor::Executor,
18};
19use anyhow::{Context as _, anyhow, bail};
20use async_tungstenite::tungstenite::{
21 Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
22};
23use axum::{
24 Extension, Router, TypedHeader,
25 body::Body,
26 extract::{
27 ConnectInfo, WebSocketUpgrade,
28 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
29 },
30 headers::{Header, HeaderName},
31 http::StatusCode,
32 middleware,
33 response::IntoResponse,
34 routing::get,
35};
36use chrono::Utc;
37use collections::{HashMap, HashSet};
38pub use connection_pool::{ConnectionPool, ZedVersion};
39use core::fmt::{self, Debug, Formatter};
40use reqwest_client::ReqwestClient;
41use rpc::proto::split_repository_update;
42use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
43
44use futures::{
45 FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
46 stream::FuturesUnordered,
47};
48use prometheus::{IntGauge, register_int_gauge};
49use rpc::{
50 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
51 proto::{
52 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
53 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
54 },
55};
56use semantic_version::SemanticVersion;
57use serde::{Serialize, Serializer};
58use std::{
59 any::TypeId,
60 future::Future,
61 marker::PhantomData,
62 mem,
63 net::SocketAddr,
64 ops::{Deref, DerefMut},
65 rc::Rc,
66 sync::{
67 Arc, OnceLock,
68 atomic::{AtomicBool, Ordering::SeqCst},
69 },
70 time::{Duration, Instant},
71};
72use time::OffsetDateTime;
73use tokio::sync::{Semaphore, watch};
74use tower::ServiceBuilder;
75use tracing::{
76 Instrument,
77 field::{self},
78 info_span, instrument,
79};
80
/// Reconnection grace period for clients.
///
/// NOTE(review): its consumers live outside this section of the file; confirm the
/// exact semantics there.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// Delay before a freshly started server reclaims resources left by stale servers.
///
/// Kubernetes gives terminated pods 10 seconds to shut down gracefully. Waiting a bit
/// longer (15 seconds) means the old pods are gone before `Server::start` cleans up
/// their leftover resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

/// Page size for channel-message queries (presumably; used outside this view).
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Upper bound on a channel message's length (unit — bytes vs. chars — not visible
/// here; TODO confirm at the use site).
const MAX_MESSAGE_LEN: usize = 1024;
/// Page size for notification queries (presumably; used outside this view).
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
89
90type MessageHandler =
91 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
92
93struct Response<R> {
94 peer: Arc<Peer>,
95 receipt: Receipt<R>,
96 responded: Arc<AtomicBool>,
97}
98
99impl<R: RequestMessage> Response<R> {
100 fn send(self, payload: R::Response) -> Result<()> {
101 self.responded.store(true, SeqCst);
102 self.peer.respond(self.receipt, payload)?;
103 Ok(())
104 }
105}
106
107#[derive(Clone, Debug)]
108pub enum Principal {
109 User(User),
110 Impersonated { user: User, admin: User },
111}
112
113impl Principal {
114 fn user(&self) -> &User {
115 match self {
116 Principal::User(user) => user,
117 Principal::Impersonated { user, .. } => user,
118 }
119 }
120
121 fn update_span(&self, span: &tracing::Span) {
122 match &self {
123 Principal::User(user) => {
124 span.record("user_id", user.id.0);
125 span.record("login", &user.github_login);
126 }
127 Principal::Impersonated { user, admin } => {
128 span.record("user_id", user.id.0);
129 span.record("login", &user.github_login);
130 span.record("impersonator", &admin.github_login);
131 }
132 }
133 }
134}
135
136#[derive(Clone)]
137struct Session {
138 principal: Principal,
139 connection_id: ConnectionId,
140 db: Arc<tokio::sync::Mutex<DbHandle>>,
141 peer: Arc<Peer>,
142 connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
143 app_state: Arc<AppState>,
144 supermaven_client: Option<Arc<SupermavenAdminApi>>,
145 /// The GeoIP country code for the user.
146 #[allow(unused)]
147 geoip_country_code: Option<String>,
148 system_id: Option<String>,
149 _executor: Executor,
150}
151
152impl Session {
153 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
154 #[cfg(test)]
155 tokio::task::yield_now().await;
156 let guard = self.db.lock().await;
157 #[cfg(test)]
158 tokio::task::yield_now().await;
159 guard
160 }
161
162 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
163 #[cfg(test)]
164 tokio::task::yield_now().await;
165 let guard = self.connection_pool.lock();
166 ConnectionPoolGuard {
167 guard,
168 _not_send: PhantomData,
169 }
170 }
171
172 fn is_staff(&self) -> bool {
173 match &self.principal {
174 Principal::User(user) => user.admin,
175 Principal::Impersonated { .. } => true,
176 }
177 }
178
179 fn user_id(&self) -> UserId {
180 match &self.principal {
181 Principal::User(user) => user.id,
182 Principal::Impersonated { user, .. } => user.id,
183 }
184 }
185
186 pub fn email(&self) -> Option<String> {
187 match &self.principal {
188 Principal::User(user) => user.email_address.clone(),
189 Principal::Impersonated { user, .. } => user.email_address.clone(),
190 }
191 }
192}
193
194impl Debug for Session {
195 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
196 let mut result = f.debug_struct("Session");
197 match &self.principal {
198 Principal::User(user) => {
199 result.field("user", &user.github_login);
200 }
201 Principal::Impersonated { user, admin } => {
202 result.field("user", &user.github_login);
203 result.field("impersonator", &admin.github_login);
204 }
205 }
206 result.field("connection_id", &self.connection_id).finish()
207 }
208}
209
210struct DbHandle(Arc<Database>);
211
212impl Deref for DbHandle {
213 type Target = Database;
214
215 fn deref(&self) -> &Self::Target {
216 self.0.as_ref()
217 }
218}
219
220pub struct Server {
221 id: parking_lot::Mutex<ServerId>,
222 peer: Arc<Peer>,
223 pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
224 app_state: Arc<AppState>,
225 handlers: HashMap<TypeId, MessageHandler>,
226 teardown: watch::Sender<bool>,
227}
228
229pub(crate) struct ConnectionPoolGuard<'a> {
230 guard: parking_lot::MutexGuard<'a, ConnectionPool>,
231 _not_send: PhantomData<Rc<()>>,
232}
233
234#[derive(Serialize)]
235pub struct ServerSnapshot<'a> {
236 peer: &'a Peer,
237 #[serde(serialize_with = "serialize_deref")]
238 connection_pool: ConnectionPoolGuard<'a>,
239}
240
241pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
242where
243 S: Serializer,
244 T: Deref<Target = U>,
245 U: Serialize,
246{
247 Serialize::serialize(value.deref(), serializer)
248}
249
250impl Server {
251 pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
252 let mut server = Self {
253 id: parking_lot::Mutex::new(id),
254 peer: Peer::new(id.0 as u32),
255 app_state: app_state.clone(),
256 connection_pool: Default::default(),
257 handlers: Default::default(),
258 teardown: watch::channel(false).0,
259 };
260
261 server
262 .add_request_handler(ping)
263 .add_request_handler(create_room)
264 .add_request_handler(join_room)
265 .add_request_handler(rejoin_room)
266 .add_request_handler(leave_room)
267 .add_request_handler(set_room_participant_role)
268 .add_request_handler(call)
269 .add_request_handler(cancel_call)
270 .add_message_handler(decline_call)
271 .add_request_handler(update_participant_location)
272 .add_request_handler(share_project)
273 .add_message_handler(unshare_project)
274 .add_request_handler(join_project)
275 .add_message_handler(leave_project)
276 .add_request_handler(update_project)
277 .add_request_handler(update_worktree)
278 .add_request_handler(update_repository)
279 .add_request_handler(remove_repository)
280 .add_message_handler(start_language_server)
281 .add_message_handler(update_language_server)
282 .add_message_handler(update_diagnostic_summary)
283 .add_message_handler(update_worktree_settings)
284 .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
285 .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
286 .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
287 .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
288 .add_request_handler(forward_find_search_candidates_request)
289 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
290 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
291 .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
292 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
293 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
294 .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
295 .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
296 .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
297 .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
298 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
299 .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
300 .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
301 .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
302 .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
303 .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
304 .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
305 .add_request_handler(
306 forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
307 )
308 .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
309 .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
310 .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
311 .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
312 .add_request_handler(
313 forward_read_only_project_request::<proto::LanguageServerIdForName>,
314 )
315 .add_request_handler(
316 forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
317 )
318 .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
319 .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
320 .add_request_handler(
321 forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
322 )
323 .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
324 .add_request_handler(
325 forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
326 )
327 .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
328 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
329 .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
330 .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
331 .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
332 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
333 .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
334 .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
335 .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
336 .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
337 .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
338 .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
339 .add_request_handler(
340 forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
341 )
342 .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
343 .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
344 .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
345 .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
346 .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
347 .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
348 .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
349 .add_message_handler(create_buffer_for_peer)
350 .add_request_handler(update_buffer)
351 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
352 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
353 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
354 .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
355 .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
356 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
357 .add_request_handler(get_users)
358 .add_request_handler(fuzzy_search_users)
359 .add_request_handler(request_contact)
360 .add_request_handler(remove_contact)
361 .add_request_handler(respond_to_contact_request)
362 .add_message_handler(subscribe_to_channels)
363 .add_request_handler(create_channel)
364 .add_request_handler(delete_channel)
365 .add_request_handler(invite_channel_member)
366 .add_request_handler(remove_channel_member)
367 .add_request_handler(set_channel_member_role)
368 .add_request_handler(set_channel_visibility)
369 .add_request_handler(rename_channel)
370 .add_request_handler(join_channel_buffer)
371 .add_request_handler(leave_channel_buffer)
372 .add_message_handler(update_channel_buffer)
373 .add_request_handler(rejoin_channel_buffers)
374 .add_request_handler(get_channel_members)
375 .add_request_handler(respond_to_channel_invite)
376 .add_request_handler(join_channel)
377 .add_request_handler(join_channel_chat)
378 .add_message_handler(leave_channel_chat)
379 .add_request_handler(send_channel_message)
380 .add_request_handler(remove_channel_message)
381 .add_request_handler(update_channel_message)
382 .add_request_handler(get_channel_messages)
383 .add_request_handler(get_channel_messages_by_id)
384 .add_request_handler(get_notifications)
385 .add_request_handler(mark_notification_as_read)
386 .add_request_handler(move_channel)
387 .add_request_handler(follow)
388 .add_message_handler(unfollow)
389 .add_message_handler(update_followers)
390 .add_request_handler(get_private_user_info)
391 .add_request_handler(get_llm_api_token)
392 .add_request_handler(accept_terms_of_service)
393 .add_message_handler(acknowledge_channel_message)
394 .add_message_handler(acknowledge_buffer_version)
395 .add_request_handler(get_supermaven_api_key)
396 .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
397 .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
398 .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
399 .add_request_handler(forward_mutating_project_request::<proto::Stage>)
400 .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
401 .add_request_handler(forward_mutating_project_request::<proto::Commit>)
402 .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
403 .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
404 .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
405 .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
406 .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
407 .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
408 .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
409 .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
410 .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
411 .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
412 .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
413 .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
414 .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
415 .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
416 .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
417 .add_message_handler(update_context);
418
419 Arc::new(server)
420 }
421
    /// Starts the server's background cleanup and returns immediately.
    ///
    /// It spawns a detached task that sleeps for [`CLEANUP_TIMEOUT`] and then asks the
    /// database for resources it reports as stale relative to this `server_id`. For
    /// each one the task:
    /// - clears stale collaborators from channel buffers and pushes the refreshed
    ///   collaborator list to the remaining connections;
    /// - clears stale room participants, broadcasts the updated room/channel,
    ///   notifies callees whose calls were canceled, refreshes affected users'
    ///   contacts, and deletes the LiveKit room once it has no participants left.
    ///
    /// Finally it deletes stale channel-chat participants, old worktree entries, and
    /// stale servers. Every database or send failure is logged via `trace_err()` and
    /// otherwise ignored, so one failure does not stop the remaining cleanup.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // Created here, awaited inside the spawned task.
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                // Give terminated pods time to disappear before reclaiming their state
                // (see CLEANUP_TIMEOUT).
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                // NOTE(review): this same call runs again after the room/buffer cleanup
                // below; presumably intentional (to catch participants left by that
                // cleanup) — confirm before deduplicating.
                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Drop stale collaborators from each channel buffer and tell the
                    // remaining connections about the new collaborator list.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Filled in from the refreshed room; they stay empty/false if the
                        // database call fails.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            // Both removed participants and canceled callees need their
                            // contact status refreshed.
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the pool lock is released before the next `.await`.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push each affected user's updated contact entry (including busy
                        // state) to all connections of their accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Only delete the LiveKit room when LiveKit is configured and
                        // nobody is left in the room.
                        if let Some(live_kit) = livekit_client.as_ref() {
                            if delete_livekit_room {
                                live_kit.delete_room(livekit_room).await.trace_err();
                            }
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
592
593 pub fn teardown(&self) {
594 self.peer.teardown();
595 self.connection_pool.lock().reset();
596 let _ = self.teardown.send(true);
597 }
598
599 #[cfg(test)]
600 pub fn reset(&self, id: ServerId) {
601 self.teardown();
602 *self.id.lock() = id;
603 self.peer.reset(id.0 as u32);
604 let _ = self.teardown.send(false);
605 }
606
607 #[cfg(test)]
608 pub fn id(&self) -> ServerId {
609 *self.id.lock()
610 }
611
612 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
613 where
614 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
615 Fut: 'static + Send + Future<Output = Result<()>>,
616 M: EnvelopedMessage,
617 {
618 let prev_handler = self.handlers.insert(
619 TypeId::of::<M>(),
620 Box::new(move |envelope, session| {
621 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
622 let received_at = envelope.received_at;
623 tracing::info!("message received");
624 let start_time = Instant::now();
625 let future = (handler)(*envelope, session);
626 async move {
627 let result = future.await;
628 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
629 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
630 let queue_duration_ms = total_duration_ms - processing_duration_ms;
631 let payload_type = M::NAME;
632
633 match result {
634 Err(error) => {
635 tracing::error!(
636 ?error,
637 total_duration_ms,
638 processing_duration_ms,
639 queue_duration_ms,
640 payload_type,
641 "error handling message"
642 )
643 }
644 Ok(()) => tracing::info!(
645 total_duration_ms,
646 processing_duration_ms,
647 queue_duration_ms,
648 "finished handling message"
649 ),
650 }
651 }
652 .boxed()
653 }),
654 );
655 if prev_handler.is_some() {
656 panic!("registered a handler for the same message twice");
657 }
658 self
659 }
660
661 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
662 where
663 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
664 Fut: 'static + Send + Future<Output = Result<()>>,
665 M: EnvelopedMessage,
666 {
667 self.add_handler(move |envelope, session| handler(envelope.payload, session));
668 self
669 }
670
671 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
672 where
673 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
674 Fut: Send + Future<Output = Result<()>>,
675 M: RequestMessage,
676 {
677 let handler = Arc::new(handler);
678 self.add_handler(move |envelope, session| {
679 let receipt = envelope.receipt();
680 let handler = handler.clone();
681 async move {
682 let peer = session.peer.clone();
683 let responded = Arc::new(AtomicBool::default());
684 let response = Response {
685 peer: peer.clone(),
686 responded: responded.clone(),
687 receipt,
688 };
689 match (handler)(envelope.payload, response, session).await {
690 Ok(()) => {
691 if responded.load(std::sync::atomic::Ordering::SeqCst) {
692 Ok(())
693 } else {
694 Err(anyhow!("handler did not send a response"))?
695 }
696 }
697 Err(error) => {
698 let proto_err = match &error {
699 Error::Internal(err) => err.to_proto(),
700 _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
701 };
702 peer.respond_with_error(receipt, proto_err)?;
703 Err(error)
704 }
705 }
706 }
707 })
708 }
709
    /// Builds the future that services one client connection for its whole lifetime.
    ///
    /// When polled, the future:
    /// 1. returns immediately if the server is already tearing down;
    /// 2. registers `connection` with the peer and records the connection id on the span;
    /// 3. builds the per-connection [`Session`] and sends the initial client update,
    ///    returning early if that (or HTTP-client creation) fails;
    /// 4. dispatches incoming messages to registered handlers until the connection
    ///    closes or the I/O future finishes;
    /// 5. signs out via `connection_lost`.
    ///
    /// If teardown is signalled while the loop runs, the future returns directly and
    /// skips step 5. `send_connection_id`, when provided, receives the new connection
    /// id right after the `Hello` message is sent.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        // Fields start empty and are recorded once their values are known.
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            geoip_country_code=field::Empty
        );
        principal.update_span(&span);
        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            // Only build a Supermaven admin client when an admin API key is configured.
            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
                supermaven_admin_api_key.to_string(),
                http_client.clone(),
            )));

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this.send_initial_client_update(connection_id, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // At most 256 handlers (foreground and background) run at once per connection.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                // A permit is taken *before* reading the next message, so reading pauses
                // while 256 handlers are in flight. The semaphore is never closed in this
                // function, so `unwrap` cannot fire here.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Branches are polled in order: teardown, I/O completion, finishing a
                // foreground handler, then reading the next message.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                            );
                            principal.update_span(&span);
                            // Enter the span only while the handler is invoked synchronously;
                            // the returned future is instrumented with the span instead.
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                // The permit is held until the handler completes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Cancel any still-pending foreground handlers before signing out.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
857
858 async fn send_initial_client_update(
859 &self,
860 connection_id: ConnectionId,
861 zed_version: ZedVersion,
862 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
863 session: &Session,
864 ) -> Result<()> {
865 self.peer.send(
866 connection_id,
867 proto::Hello {
868 peer_id: Some(connection_id.into()),
869 },
870 )?;
871 tracing::info!("sent hello message");
872 if let Some(send_connection_id) = send_connection_id.take() {
873 let _ = send_connection_id.send(connection_id);
874 }
875
876 match &session.principal {
877 Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
878 if !user.connected_once {
879 self.peer.send(connection_id, proto::ShowContacts {})?;
880 self.app_state
881 .db
882 .set_user_connected_once(user.id, true)
883 .await?;
884 }
885
886 update_user_plan(session).await?;
887
888 let contacts = self.app_state.db.get_contacts(user.id).await?;
889
890 {
891 let mut pool = self.connection_pool.lock();
892 pool.add_connection(connection_id, user.id, user.admin, zed_version);
893 self.peer.send(
894 connection_id,
895 build_initial_contacts_update(contacts, &pool),
896 )?;
897 }
898
899 if should_auto_subscribe_to_channels(zed_version) {
900 subscribe_user_to_channels(user.id, session).await?;
901 }
902
903 if let Some(incoming_call) =
904 self.app_state.db.incoming_call_for_user(user.id).await?
905 {
906 self.peer.send(connection_id, incoming_call)?;
907 }
908
909 update_user_contacts(user.id, session).await?;
910 }
911 }
912
913 Ok(())
914 }
915
916 pub async fn invite_code_redeemed(
917 self: &Arc<Self>,
918 inviter_id: UserId,
919 invitee_id: UserId,
920 ) -> Result<()> {
921 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
922 if let Some(code) = &user.invite_code {
923 let pool = self.connection_pool.lock();
924 let invitee_contact = contact_for_user(invitee_id, false, &pool);
925 for connection_id in pool.user_connection_ids(inviter_id) {
926 self.peer.send(
927 connection_id,
928 proto::UpdateContacts {
929 contacts: vec![invitee_contact.clone()],
930 ..Default::default()
931 },
932 )?;
933 self.peer.send(
934 connection_id,
935 proto::UpdateInviteInfo {
936 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
937 count: user.invite_count as u32,
938 },
939 )?;
940 }
941 }
942 }
943 Ok(())
944 }
945
946 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
947 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
948 if let Some(invite_code) = &user.invite_code {
949 let pool = self.connection_pool.lock();
950 for connection_id in pool.user_connection_ids(user_id) {
951 self.peer.send(
952 connection_id,
953 proto::UpdateInviteInfo {
954 url: format!(
955 "{}{}",
956 self.app_state.config.invite_link_prefix, invite_code
957 ),
958 count: user.invite_count as u32,
959 },
960 )?;
961 }
962 }
963 }
964 Ok(())
965 }
966
967 pub async fn update_plan_for_user(self: &Arc<Self>, user_id: UserId) -> Result<()> {
968 let user = self
969 .app_state
970 .db
971 .get_user_by_id(user_id)
972 .await?
973 .context("user not found")?;
974
975 let update_user_plan = make_update_user_plan_message(
976 &user,
977 user.admin,
978 &self.app_state.db,
979 self.app_state.llm_db.clone(),
980 )
981 .await?;
982
983 let pool = self.connection_pool.lock();
984 for connection_id in pool.user_connection_ids(user_id) {
985 self.peer
986 .send(connection_id, update_user_plan.clone())
987 .trace_err();
988 }
989
990 Ok(())
991 }
992
993 pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
994 let pool = self.connection_pool.lock();
995 for connection_id in pool.user_connection_ids(user_id) {
996 self.peer
997 .send(connection_id, proto::RefreshLlmToken {})
998 .trace_err();
999 }
1000 }
1001
1002 pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot {
1003 ServerSnapshot {
1004 connection_pool: ConnectionPoolGuard {
1005 guard: self.connection_pool.lock(),
1006 _not_send: PhantomData,
1007 },
1008 peer: &self.peer,
1009 }
1010 }
1011}
1012
/// Lets a `ConnectionPoolGuard` be read as the `ConnectionPool` it has locked.
impl Deref for ConnectionPoolGuard<'_> {
    type Target = ConnectionPool;

    fn deref(&self) -> &Self::Target {
        &self.guard
    }
}
1020
/// Lets a `ConnectionPoolGuard` mutate the `ConnectionPool` it has locked.
impl DerefMut for ConnectionPoolGuard<'_> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.guard
    }
}
1026
/// In test builds, validates the pool's invariants every time a guard is released.
/// In non-test builds the `#[cfg(test)]` call is compiled out, so dropping does
/// nothing beyond releasing the inner guard.
impl Drop for ConnectionPoolGuard<'_> {
    fn drop(&mut self) {
        #[cfg(test)]
        self.check_invariants();
    }
}
1033
1034fn broadcast<F>(
1035 sender_id: Option<ConnectionId>,
1036 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1037 mut f: F,
1038) where
1039 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1040{
1041 for receiver_id in receiver_ids {
1042 if Some(receiver_id) != sender_id {
1043 if let Err(error) = f(receiver_id) {
1044 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1045 }
1046 }
1047 }
1048}
1049
1050pub struct ProtocolVersion(u32);
1051
1052impl Header for ProtocolVersion {
1053 fn name() -> &'static HeaderName {
1054 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1055 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1056 }
1057
1058 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1059 where
1060 Self: Sized,
1061 I: Iterator<Item = &'i axum::http::HeaderValue>,
1062 {
1063 let version = values
1064 .next()
1065 .ok_or_else(axum::headers::Error::invalid)?
1066 .to_str()
1067 .map_err(|_| axum::headers::Error::invalid())?
1068 .parse()
1069 .map_err(|_| axum::headers::Error::invalid())?;
1070 Ok(Self(version))
1071 }
1072
1073 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1074 values.extend([self.0.to_string().parse().unwrap()]);
1075 }
1076}
1077
1078pub struct AppVersionHeader(SemanticVersion);
1079impl Header for AppVersionHeader {
1080 fn name() -> &'static HeaderName {
1081 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1082 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1083 }
1084
1085 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1086 where
1087 Self: Sized,
1088 I: Iterator<Item = &'i axum::http::HeaderValue>,
1089 {
1090 let version = values
1091 .next()
1092 .ok_or_else(axum::headers::Error::invalid)?
1093 .to_str()
1094 .map_err(|_| axum::headers::Error::invalid())?
1095 .parse()
1096 .map_err(|_| axum::headers::Error::invalid())?;
1097 Ok(Self(version))
1098 }
1099
1100 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1101 values.extend([self.0.to_string().parse().unwrap()]);
1102 }
1103}
1104
/// Builds the router for the collab RPC endpoints.
///
/// `/rpc` serves the websocket upgrade. `/metrics` serves Prometheus metrics.
/// The `ServiceBuilder` layer (app-state extension plus `auth::validate_header`
/// middleware) is added after `/rpc` but before `/metrics`. Axum's `Router::layer`
/// wraps only the routes registered before it, so authentication applies to
/// `/rpc` only. The final `Extension(server)` layer wraps both routes.
pub fn routes(server: Arc<Server>) -> Router<(), Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
1116
1117pub async fn handle_websocket_request(
1118 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1119 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1120 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1121 Extension(server): Extension<Arc<Server>>,
1122 Extension(principal): Extension<Principal>,
1123 country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1124 system_id_header: Option<TypedHeader<SystemIdHeader>>,
1125 ws: WebSocketUpgrade,
1126) -> axum::response::Response {
1127 if protocol_version != rpc::PROTOCOL_VERSION {
1128 return (
1129 StatusCode::UPGRADE_REQUIRED,
1130 "client must be upgraded".to_string(),
1131 )
1132 .into_response();
1133 }
1134
1135 let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1136 return (
1137 StatusCode::UPGRADE_REQUIRED,
1138 "no version header found".to_string(),
1139 )
1140 .into_response();
1141 };
1142
1143 if !version.can_collaborate() {
1144 return (
1145 StatusCode::UPGRADE_REQUIRED,
1146 "client must be upgraded".to_string(),
1147 )
1148 .into_response();
1149 }
1150
1151 let socket_address = socket_address.to_string();
1152 ws.on_upgrade(move |socket| {
1153 let socket = socket
1154 .map_ok(to_tungstenite_message)
1155 .err_into()
1156 .with(|message| async move { to_axum_message(message) });
1157 let connection = Connection::new(Box::pin(socket));
1158 async move {
1159 server
1160 .handle_connection(
1161 connection,
1162 socket_address,
1163 principal,
1164 version,
1165 country_code_header.map(|header| header.to_string()),
1166 system_id_header.map(|header| header.to_string()),
1167 None,
1168 Executor::Production,
1169 )
1170 .await;
1171 }
1172 })
1173}
1174
/// Serves Prometheus metrics in the text exposition format.
///
/// Before encoding everything returned by `prometheus::gather()`, it refreshes two gauges:
/// - `connections`: pool connections whose `admin` flag is false;
/// - `shared_projects`: the value of `project_count_excluding_admins()`.
///
/// Both gauges are registered lazily on the first request (via `OnceLock`).
/// The `.unwrap()` calls panic if registration fails.
pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
    let connections_metric = CONNECTIONS_METRIC
        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());

    // Count only non-admin connections; the pool lock is released at the end of
    // this statement.
    let connections = server
        .connection_pool
        .lock()
        .connections()
        .filter(|connection| !connection.admin)
        .count();
    connections_metric.set(connections as _);

    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
        register_int_gauge!(
            "shared_projects",
            "number of open projects with one or more guests"
        )
        .unwrap()
    });

    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
    shared_projects_metric.set(shared_projects as _);

    let encoder = prometheus::TextEncoder::new();
    let metric_families = prometheus::gather();
    let encoded_metrics = encoder
        .encode_to_string(&metric_families)
        .map_err(|err| anyhow!("{err}"))?;
    Ok(encoded_metrics)
}
1207
/// Cleans up after a connection drops.
///
/// Immediately disconnects the peer, removes the connection from the pool, and
/// informs the database (database errors are only logged). It then waits for
/// whichever comes first: `RECONNECT_TIMEOUT` elapsing or `teardown` changing.
///
/// If the timeout wins, the session's room and channel buffers are left. A pending
/// call is declined if the user no longer has any online connection, and contacts
/// are refreshed. If `teardown` changes first, that delayed cleanup is skipped.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls the timeout arm before the teardown arm.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline a pending call when no other connection of this user
            // is still online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1254
1255/// Acknowledges a ping from a client, used to keep the connection alive.
1256async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1257 response.send(proto::Ack {})?;
1258 Ok(())
1259}
1260
1261/// Creates a new room for calling (outside of channels)
1262async fn create_room(
1263 _request: proto::CreateRoom,
1264 response: Response<proto::CreateRoom>,
1265 session: Session,
1266) -> Result<()> {
1267 let livekit_room = nanoid::nanoid!(30);
1268
1269 let live_kit_connection_info = util::maybe!(async {
1270 let live_kit = session.app_state.livekit_client.as_ref();
1271 let live_kit = live_kit?;
1272 let user_id = session.user_id().to_string();
1273
1274 let token = live_kit
1275 .room_token(&livekit_room, &user_id.to_string())
1276 .trace_err()?;
1277
1278 Some(proto::LiveKitConnectionInfo {
1279 server_url: live_kit.url().into(),
1280 token,
1281 can_publish: true,
1282 })
1283 })
1284 .await;
1285
1286 let room = session
1287 .db()
1288 .await
1289 .create_room(session.user_id(), session.connection_id, &livekit_room)
1290 .await?;
1291
1292 response.send(proto::CreateRoomResponse {
1293 room: Some(room.clone()),
1294 live_kit_connection_info,
1295 })?;
1296
1297 update_user_contacts(session.user_id(), &session).await?;
1298 Ok(())
1299}
1300
1301/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1302async fn join_room(
1303 request: proto::JoinRoom,
1304 response: Response<proto::JoinRoom>,
1305 session: Session,
1306) -> Result<()> {
1307 let room_id = RoomId::from_proto(request.id);
1308
1309 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1310
1311 if let Some(channel_id) = channel_id {
1312 return join_channel_internal(channel_id, Box::new(response), session).await;
1313 }
1314
1315 let joined_room = {
1316 let room = session
1317 .db()
1318 .await
1319 .join_room(room_id, session.user_id(), session.connection_id)
1320 .await?;
1321 room_updated(&room.room, &session.peer);
1322 room.into_inner()
1323 };
1324
1325 for connection_id in session
1326 .connection_pool()
1327 .await
1328 .user_connection_ids(session.user_id())
1329 {
1330 session
1331 .peer
1332 .send(
1333 connection_id,
1334 proto::CallCanceled {
1335 room_id: room_id.to_proto(),
1336 },
1337 )
1338 .trace_err();
1339 }
1340
1341 let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1342 {
1343 live_kit
1344 .room_token(
1345 &joined_room.room.livekit_room,
1346 &session.user_id().to_string(),
1347 )
1348 .trace_err()
1349 .map(|token| proto::LiveKitConnectionInfo {
1350 server_url: live_kit.url().into(),
1351 token,
1352 can_publish: true,
1353 })
1354 } else {
1355 None
1356 };
1357
1358 response.send(proto::JoinRoomResponse {
1359 room: Some(joined_room.room),
1360 channel_id: None,
1361 live_kit_connection_info,
1362 })?;
1363
1364 update_user_contacts(session.user_id(), &session).await?;
1365 Ok(())
1366}
1367
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Restores the caller's room state for the current connection and replies with:
/// - the room;
/// - the reshared projects with their collaborators;
/// - the rejoined projects, as returned by the database.
///
/// It then broadcasts the room update and tells collaborators that the caller's
/// peer id moved from the old connection to this one. Finally it notifies the
/// channel (if the room has one) and refreshes the caller's contacts.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    let room;
    let channel;
    // Scope the database guard so it is released before the connection pool is
    // locked for the channel update below.
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        // For each reshared project: announce the caller's new peer id to every
        // collaborator (best effort), then forward the project's worktree list to
        // everyone except the caller.
        for project in &rejoined_room.reshared_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        // Announce the caller to collaborators of rejoined projects and stream
        // those projects' state back to the caller's connection.
        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1459
1460fn notify_rejoined_projects(
1461 rejoined_projects: &mut Vec<RejoinedProject>,
1462 session: &Session,
1463) -> Result<()> {
1464 for project in rejoined_projects.iter() {
1465 for collaborator in &project.collaborators {
1466 session
1467 .peer
1468 .send(
1469 collaborator.connection_id,
1470 proto::UpdateProjectCollaborator {
1471 project_id: project.id.to_proto(),
1472 old_peer_id: Some(project.old_connection_id.into()),
1473 new_peer_id: Some(session.connection_id.into()),
1474 },
1475 )
1476 .trace_err();
1477 }
1478 }
1479
1480 for project in rejoined_projects {
1481 for worktree in mem::take(&mut project.worktrees) {
1482 // Stream this worktree's entries.
1483 let message = proto::UpdateWorktree {
1484 project_id: project.id.to_proto(),
1485 worktree_id: worktree.id,
1486 abs_path: worktree.abs_path.clone(),
1487 root_name: worktree.root_name,
1488 updated_entries: worktree.updated_entries,
1489 removed_entries: worktree.removed_entries,
1490 scan_id: worktree.scan_id,
1491 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1492 updated_repositories: worktree.updated_repositories,
1493 removed_repositories: worktree.removed_repositories,
1494 };
1495 for update in proto::split_worktree_update(message) {
1496 session.peer.send(session.connection_id, update)?;
1497 }
1498
1499 // Stream this worktree's diagnostics.
1500 for summary in worktree.diagnostic_summaries {
1501 session.peer.send(
1502 session.connection_id,
1503 proto::UpdateDiagnosticSummary {
1504 project_id: project.id.to_proto(),
1505 worktree_id: worktree.id,
1506 summary: Some(summary),
1507 },
1508 )?;
1509 }
1510
1511 for settings_file in worktree.settings_files {
1512 session.peer.send(
1513 session.connection_id,
1514 proto::UpdateWorktreeSettings {
1515 project_id: project.id.to_proto(),
1516 worktree_id: worktree.id,
1517 path: settings_file.path,
1518 content: Some(settings_file.content),
1519 kind: Some(settings_file.kind.to_proto().into()),
1520 },
1521 )?;
1522 }
1523 }
1524
1525 for repository in mem::take(&mut project.updated_repositories) {
1526 for update in split_repository_update(repository) {
1527 session.peer.send(session.connection_id, update)?;
1528 }
1529 }
1530
1531 for id in mem::take(&mut project.removed_repositories) {
1532 session.peer.send(
1533 session.connection_id,
1534 proto::RemoveRepository {
1535 project_id: project.id.to_proto(),
1536 id,
1537 },
1538 )?;
1539 }
1540 }
1541
1542 Ok(())
1543}
1544
1545/// leave room disconnects from the room.
1546async fn leave_room(
1547 _: proto::LeaveRoom,
1548 response: Response<proto::LeaveRoom>,
1549 session: Session,
1550) -> Result<()> {
1551 leave_room_for_session(&session, session.connection_id).await?;
1552 response.send(proto::Ack {})?;
1553 Ok(())
1554}
1555
1556/// Updates the permissions of someone else in the room.
1557async fn set_room_participant_role(
1558 request: proto::SetRoomParticipantRole,
1559 response: Response<proto::SetRoomParticipantRole>,
1560 session: Session,
1561) -> Result<()> {
1562 let user_id = UserId::from_proto(request.user_id);
1563 let role = ChannelRole::from(request.role());
1564
1565 let (livekit_room, can_publish) = {
1566 let room = session
1567 .db()
1568 .await
1569 .set_room_participant_role(
1570 session.user_id(),
1571 RoomId::from_proto(request.room_id),
1572 user_id,
1573 role,
1574 )
1575 .await?;
1576
1577 let livekit_room = room.livekit_room.clone();
1578 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1579 room_updated(&room, &session.peer);
1580 (livekit_room, can_publish)
1581 };
1582
1583 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1584 live_kit
1585 .update_participant(
1586 livekit_room.clone(),
1587 request.user_id.to_string(),
1588 livekit_api::proto::ParticipantPermission {
1589 can_subscribe: true,
1590 can_publish,
1591 can_publish_data: can_publish,
1592 hidden: false,
1593 recorder: false,
1594 },
1595 )
1596 .await
1597 .trace_err();
1598 }
1599
1600 response.send(proto::Ack {})?;
1601 Ok(())
1602}
1603
/// Call someone else into the current room
///
/// The callee must be a contact of the caller. The call is recorded, the room is
/// broadcast, and the callee's contacts are refreshed. Then every connection of
/// the callee is rung concurrently. The first successful response produces an
/// `Ack`. If every ring fails (or the callee has no connections), the call is
/// marked failed, the room and contacts are updated again, and
/// "failed to ring user" is returned.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: Session,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    // Record the call and take the incoming-call message out of the guard so the
    // database guard is released at the end of this block.
    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(room, &session.peer);
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring all of the callee's connections concurrently.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    // No connection accepted the ring.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1672
1673/// Cancel an outgoing call.
1674async fn cancel_call(
1675 request: proto::CancelCall,
1676 response: Response<proto::CancelCall>,
1677 session: Session,
1678) -> Result<()> {
1679 let called_user_id = UserId::from_proto(request.called_user_id);
1680 let room_id = RoomId::from_proto(request.room_id);
1681 {
1682 let room = session
1683 .db()
1684 .await
1685 .cancel_call(room_id, session.connection_id, called_user_id)
1686 .await?;
1687 room_updated(&room, &session.peer);
1688 }
1689
1690 for connection_id in session
1691 .connection_pool()
1692 .await
1693 .user_connection_ids(called_user_id)
1694 {
1695 session
1696 .peer
1697 .send(
1698 connection_id,
1699 proto::CallCanceled {
1700 room_id: room_id.to_proto(),
1701 },
1702 )
1703 .trace_err();
1704 }
1705 response.send(proto::Ack {})?;
1706
1707 update_user_contacts(called_user_id, &session).await?;
1708 Ok(())
1709}
1710
1711/// Decline an incoming call.
1712async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1713 let room_id = RoomId::from_proto(message.room_id);
1714 {
1715 let room = session
1716 .db()
1717 .await
1718 .decline_call(Some(room_id), session.user_id())
1719 .await?
1720 .context("declining call")?;
1721 room_updated(&room, &session.peer);
1722 }
1723
1724 for connection_id in session
1725 .connection_pool()
1726 .await
1727 .user_connection_ids(session.user_id())
1728 {
1729 session
1730 .peer
1731 .send(
1732 connection_id,
1733 proto::CallCanceled {
1734 room_id: room_id.to_proto(),
1735 },
1736 )
1737 .trace_err();
1738 }
1739 update_user_contacts(session.user_id(), &session).await?;
1740 Ok(())
1741}
1742
1743/// Updates other participants in the room with your current location.
1744async fn update_participant_location(
1745 request: proto::UpdateParticipantLocation,
1746 response: Response<proto::UpdateParticipantLocation>,
1747 session: Session,
1748) -> Result<()> {
1749 let room_id = RoomId::from_proto(request.room_id);
1750 let location = request.location.context("invalid location")?;
1751
1752 let db = session.db().await;
1753 let room = db
1754 .update_room_participant_location(room_id, session.connection_id, location)
1755 .await?;
1756
1757 room_updated(&room, &session.peer);
1758 response.send(proto::Ack {})?;
1759 Ok(())
1760}
1761
1762/// Share a project into the room.
1763async fn share_project(
1764 request: proto::ShareProject,
1765 response: Response<proto::ShareProject>,
1766 session: Session,
1767) -> Result<()> {
1768 let (project_id, room) = &*session
1769 .db()
1770 .await
1771 .share_project(
1772 RoomId::from_proto(request.room_id),
1773 session.connection_id,
1774 &request.worktrees,
1775 request.is_ssh_project,
1776 )
1777 .await?;
1778 response.send(proto::ShareProjectResponse {
1779 project_id: project_id.to_proto(),
1780 })?;
1781 room_updated(room, &session.peer);
1782
1783 Ok(())
1784}
1785
1786/// Unshare a project from the room.
1787async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1788 let project_id = ProjectId::from_proto(message.project_id);
1789 unshare_project_internal(project_id, session.connection_id, &session).await
1790}
1791
/// Unshares `project_id` on behalf of `connection_id`.
///
/// Sends `UnshareProject` to the guest connections reported by the database
/// (excluding `connection_id` itself) and broadcasts the room update if there is
/// one. Afterwards the project is deleted when the database reported that it
/// should be.
async fn unshare_project_internal(
    project_id: ProjectId,
    connection_id: ConnectionId,
    session: &Session,
) -> Result<()> {
    // Scoped so `room_guard` is released before `delete_project` acquires the
    // database again below.
    let delete = {
        let room_guard = session
            .db()
            .await
            .unshare_project(project_id, connection_id)
            .await?;

        let (delete, room, guest_connection_ids) = &*room_guard;

        let message = proto::UnshareProject {
            project_id: project_id.to_proto(),
        };

        broadcast(
            Some(connection_id),
            guest_connection_ids.iter().copied(),
            |conn_id| session.peer.send(conn_id, message.clone()),
        );
        if let Some(room) = room {
            room_updated(room, &session.peer);
        }

        *delete
    };

    if delete {
        let db = session.db().await;
        db.delete_project(project_id).await?;
    }

    Ok(())
}
1829
/// Join someone else's shared project.
///
/// Records the caller as a collaborator, then delegates replying and streaming
/// the project state to `join_project_internal`.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: Session,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    // Release the database handle before the (synchronous) streaming below.
    drop(db);
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
1848
/// Reply channel used by `join_project_internal`, so the join logic does not
/// depend on one concrete `Response` type. In this view only
/// `Response<proto::JoinProject>` implements it.
trait JoinProjectInternalResponse {
    /// Sends the join result, consuming the responder.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
/// Forwards to `Response::<proto::JoinProject>::send`.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinProject>::send(self, result)
    }
}
1857
1858fn join_project_internal(
1859 response: impl JoinProjectInternalResponse,
1860 session: Session,
1861 project: &mut Project,
1862 replica_id: &ReplicaId,
1863) -> Result<()> {
1864 let collaborators = project
1865 .collaborators
1866 .iter()
1867 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1868 .map(|collaborator| collaborator.to_proto())
1869 .collect::<Vec<_>>();
1870 let project_id = project.id;
1871 let guest_user_id = session.user_id();
1872
1873 let worktrees = project
1874 .worktrees
1875 .iter()
1876 .map(|(id, worktree)| proto::WorktreeMetadata {
1877 id: *id,
1878 root_name: worktree.root_name.clone(),
1879 visible: worktree.visible,
1880 abs_path: worktree.abs_path.clone(),
1881 })
1882 .collect::<Vec<_>>();
1883
1884 let add_project_collaborator = proto::AddProjectCollaborator {
1885 project_id: project_id.to_proto(),
1886 collaborator: Some(proto::Collaborator {
1887 peer_id: Some(session.connection_id.into()),
1888 replica_id: replica_id.0 as u32,
1889 user_id: guest_user_id.to_proto(),
1890 is_host: false,
1891 }),
1892 };
1893
1894 for collaborator in &collaborators {
1895 session
1896 .peer
1897 .send(
1898 collaborator.peer_id.unwrap().into(),
1899 add_project_collaborator.clone(),
1900 )
1901 .trace_err();
1902 }
1903
1904 // First, we send the metadata associated with each worktree.
1905 response.send(proto::JoinProjectResponse {
1906 project_id: project.id.0 as u64,
1907 worktrees: worktrees.clone(),
1908 replica_id: replica_id.0 as u32,
1909 collaborators: collaborators.clone(),
1910 language_servers: project.language_servers.clone(),
1911 role: project.role.into(),
1912 })?;
1913
1914 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1915 // Stream this worktree's entries.
1916 let message = proto::UpdateWorktree {
1917 project_id: project_id.to_proto(),
1918 worktree_id,
1919 abs_path: worktree.abs_path.clone(),
1920 root_name: worktree.root_name,
1921 updated_entries: worktree.entries,
1922 removed_entries: Default::default(),
1923 scan_id: worktree.scan_id,
1924 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1925 updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1926 removed_repositories: Default::default(),
1927 };
1928 for update in proto::split_worktree_update(message) {
1929 session.peer.send(session.connection_id, update.clone())?;
1930 }
1931
1932 // Stream this worktree's diagnostics.
1933 for summary in worktree.diagnostic_summaries {
1934 session.peer.send(
1935 session.connection_id,
1936 proto::UpdateDiagnosticSummary {
1937 project_id: project_id.to_proto(),
1938 worktree_id: worktree.id,
1939 summary: Some(summary),
1940 },
1941 )?;
1942 }
1943
1944 for settings_file in worktree.settings_files {
1945 session.peer.send(
1946 session.connection_id,
1947 proto::UpdateWorktreeSettings {
1948 project_id: project_id.to_proto(),
1949 worktree_id: worktree.id,
1950 path: settings_file.path,
1951 content: Some(settings_file.content),
1952 kind: Some(settings_file.kind.to_proto() as i32),
1953 },
1954 )?;
1955 }
1956 }
1957
1958 for repository in mem::take(&mut project.repositories) {
1959 for update in split_repository_update(repository) {
1960 session.peer.send(session.connection_id, update)?;
1961 }
1962 }
1963
1964 for language_server in &project.language_servers {
1965 session.peer.send(
1966 session.connection_id,
1967 proto::UpdateLanguageServer {
1968 project_id: project_id.to_proto(),
1969 language_server_id: language_server.id,
1970 variant: Some(
1971 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1972 proto::LspDiskBasedDiagnosticsUpdated {},
1973 ),
1974 ),
1975 },
1976 )?;
1977 }
1978
1979 Ok(())
1980}
1981
1982/// Leave someone elses shared project.
1983async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1984 let sender_id = session.connection_id;
1985 let project_id = ProjectId::from_proto(request.project_id);
1986 let db = session.db().await;
1987
1988 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1989 tracing::info!(
1990 %project_id,
1991 "leave project"
1992 );
1993
1994 project_left(project, &session);
1995 if let Some(room) = room {
1996 room_updated(room, &session.peer);
1997 }
1998
1999 Ok(())
2000}
2001
2002/// Updates other participants with changes to the project
2003async fn update_project(
2004 request: proto::UpdateProject,
2005 response: Response<proto::UpdateProject>,
2006 session: Session,
2007) -> Result<()> {
2008 let project_id = ProjectId::from_proto(request.project_id);
2009 let (room, guest_connection_ids) = &*session
2010 .db()
2011 .await
2012 .update_project(project_id, session.connection_id, &request.worktrees)
2013 .await?;
2014 broadcast(
2015 Some(session.connection_id),
2016 guest_connection_ids.iter().copied(),
2017 |connection_id| {
2018 session
2019 .peer
2020 .forward_send(session.connection_id, connection_id, request.clone())
2021 },
2022 );
2023 if let Some(room) = room {
2024 room_updated(room, &session.peer);
2025 }
2026 response.send(proto::Ack {})?;
2027
2028 Ok(())
2029}
2030
2031/// Updates other participants with changes to the worktree
2032async fn update_worktree(
2033 request: proto::UpdateWorktree,
2034 response: Response<proto::UpdateWorktree>,
2035 session: Session,
2036) -> Result<()> {
2037 let guest_connection_ids = session
2038 .db()
2039 .await
2040 .update_worktree(&request, session.connection_id)
2041 .await?;
2042
2043 broadcast(
2044 Some(session.connection_id),
2045 guest_connection_ids.iter().copied(),
2046 |connection_id| {
2047 session
2048 .peer
2049 .forward_send(session.connection_id, connection_id, request.clone())
2050 },
2051 );
2052 response.send(proto::Ack {})?;
2053 Ok(())
2054}
2055
2056async fn update_repository(
2057 request: proto::UpdateRepository,
2058 response: Response<proto::UpdateRepository>,
2059 session: Session,
2060) -> Result<()> {
2061 let guest_connection_ids = session
2062 .db()
2063 .await
2064 .update_repository(&request, session.connection_id)
2065 .await?;
2066
2067 broadcast(
2068 Some(session.connection_id),
2069 guest_connection_ids.iter().copied(),
2070 |connection_id| {
2071 session
2072 .peer
2073 .forward_send(session.connection_id, connection_id, request.clone())
2074 },
2075 );
2076 response.send(proto::Ack {})?;
2077 Ok(())
2078}
2079
2080async fn remove_repository(
2081 request: proto::RemoveRepository,
2082 response: Response<proto::RemoveRepository>,
2083 session: Session,
2084) -> Result<()> {
2085 let guest_connection_ids = session
2086 .db()
2087 .await
2088 .remove_repository(&request, session.connection_id)
2089 .await?;
2090
2091 broadcast(
2092 Some(session.connection_id),
2093 guest_connection_ids.iter().copied(),
2094 |connection_id| {
2095 session
2096 .peer
2097 .forward_send(session.connection_id, connection_id, request.clone())
2098 },
2099 );
2100 response.send(proto::Ack {})?;
2101 Ok(())
2102}
2103
2104/// Updates other participants with changes to the diagnostics
2105async fn update_diagnostic_summary(
2106 message: proto::UpdateDiagnosticSummary,
2107 session: Session,
2108) -> Result<()> {
2109 let guest_connection_ids = session
2110 .db()
2111 .await
2112 .update_diagnostic_summary(&message, session.connection_id)
2113 .await?;
2114
2115 broadcast(
2116 Some(session.connection_id),
2117 guest_connection_ids.iter().copied(),
2118 |connection_id| {
2119 session
2120 .peer
2121 .forward_send(session.connection_id, connection_id, message.clone())
2122 },
2123 );
2124
2125 Ok(())
2126}
2127
2128/// Updates other participants with changes to the worktree settings
2129async fn update_worktree_settings(
2130 message: proto::UpdateWorktreeSettings,
2131 session: Session,
2132) -> Result<()> {
2133 let guest_connection_ids = session
2134 .db()
2135 .await
2136 .update_worktree_settings(&message, session.connection_id)
2137 .await?;
2138
2139 broadcast(
2140 Some(session.connection_id),
2141 guest_connection_ids.iter().copied(),
2142 |connection_id| {
2143 session
2144 .peer
2145 .forward_send(session.connection_id, connection_id, message.clone())
2146 },
2147 );
2148
2149 Ok(())
2150}
2151
2152/// Notify other participants that a language server has started.
2153async fn start_language_server(
2154 request: proto::StartLanguageServer,
2155 session: Session,
2156) -> Result<()> {
2157 let guest_connection_ids = session
2158 .db()
2159 .await
2160 .start_language_server(&request, session.connection_id)
2161 .await?;
2162
2163 broadcast(
2164 Some(session.connection_id),
2165 guest_connection_ids.iter().copied(),
2166 |connection_id| {
2167 session
2168 .peer
2169 .forward_send(session.connection_id, connection_id, request.clone())
2170 },
2171 );
2172 Ok(())
2173}
2174
2175/// Notify other participants that a language server has changed.
2176async fn update_language_server(
2177 request: proto::UpdateLanguageServer,
2178 session: Session,
2179) -> Result<()> {
2180 let project_id = ProjectId::from_proto(request.project_id);
2181 let project_connection_ids = session
2182 .db()
2183 .await
2184 .project_connection_ids(project_id, session.connection_id, true)
2185 .await?;
2186 broadcast(
2187 Some(session.connection_id),
2188 project_connection_ids.iter().copied(),
2189 |connection_id| {
2190 session
2191 .peer
2192 .forward_send(session.connection_id, connection_id, request.clone())
2193 },
2194 );
2195 Ok(())
2196}
2197
2198/// forward a project request to the host. These requests should be read only
2199/// as guests are allowed to send them.
2200async fn forward_read_only_project_request<T>(
2201 request: T,
2202 response: Response<T>,
2203 session: Session,
2204) -> Result<()>
2205where
2206 T: EntityMessage + RequestMessage,
2207{
2208 let project_id = ProjectId::from_proto(request.remote_entity_id());
2209 let host_connection_id = session
2210 .db()
2211 .await
2212 .host_for_read_only_project_request(project_id, session.connection_id)
2213 .await?;
2214 let payload = session
2215 .peer
2216 .forward_request(session.connection_id, host_connection_id, request)
2217 .await?;
2218 response.send(payload)?;
2219 Ok(())
2220}
2221
2222async fn forward_find_search_candidates_request(
2223 request: proto::FindSearchCandidates,
2224 response: Response<proto::FindSearchCandidates>,
2225 session: Session,
2226) -> Result<()> {
2227 let project_id = ProjectId::from_proto(request.remote_entity_id());
2228 let host_connection_id = session
2229 .db()
2230 .await
2231 .host_for_read_only_project_request(project_id, session.connection_id)
2232 .await?;
2233 let payload = session
2234 .peer
2235 .forward_request(session.connection_id, host_connection_id, request)
2236 .await?;
2237 response.send(payload)?;
2238 Ok(())
2239}
2240
2241/// forward a project request to the host. These requests are disallowed
2242/// for guests.
2243async fn forward_mutating_project_request<T>(
2244 request: T,
2245 response: Response<T>,
2246 session: Session,
2247) -> Result<()>
2248where
2249 T: EntityMessage + RequestMessage,
2250{
2251 let project_id = ProjectId::from_proto(request.remote_entity_id());
2252
2253 let host_connection_id = session
2254 .db()
2255 .await
2256 .host_for_mutating_project_request(project_id, session.connection_id)
2257 .await?;
2258 let payload = session
2259 .peer
2260 .forward_request(session.connection_id, host_connection_id, request)
2261 .await?;
2262 response.send(payload)?;
2263 Ok(())
2264}
2265
2266/// Notify other participants that a new buffer has been created
2267async fn create_buffer_for_peer(
2268 request: proto::CreateBufferForPeer,
2269 session: Session,
2270) -> Result<()> {
2271 session
2272 .db()
2273 .await
2274 .check_user_is_project_host(
2275 ProjectId::from_proto(request.project_id),
2276 session.connection_id,
2277 )
2278 .await?;
2279 let peer_id = request.peer_id.context("invalid peer id")?;
2280 session
2281 .peer
2282 .forward_send(session.connection_id, peer_id.into(), request)?;
2283 Ok(())
2284}
2285
2286/// Notify other participants that a buffer has been updated. This is
2287/// allowed for guests as long as the update is limited to selections.
2288async fn update_buffer(
2289 request: proto::UpdateBuffer,
2290 response: Response<proto::UpdateBuffer>,
2291 session: Session,
2292) -> Result<()> {
2293 let project_id = ProjectId::from_proto(request.project_id);
2294 let mut capability = Capability::ReadOnly;
2295
2296 for op in request.operations.iter() {
2297 match op.variant {
2298 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2299 Some(_) => capability = Capability::ReadWrite,
2300 }
2301 }
2302
2303 let host = {
2304 let guard = session
2305 .db()
2306 .await
2307 .connections_for_buffer_update(project_id, session.connection_id, capability)
2308 .await?;
2309
2310 let (host, guests) = &*guard;
2311
2312 broadcast(
2313 Some(session.connection_id),
2314 guests.clone(),
2315 |connection_id| {
2316 session
2317 .peer
2318 .forward_send(session.connection_id, connection_id, request.clone())
2319 },
2320 );
2321
2322 *host
2323 };
2324
2325 if host != session.connection_id {
2326 session
2327 .peer
2328 .forward_request(session.connection_id, host, request.clone())
2329 .await?;
2330 }
2331
2332 response.send(proto::Ack {})?;
2333 Ok(())
2334}
2335
2336async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2337 let project_id = ProjectId::from_proto(message.project_id);
2338
2339 let operation = message.operation.as_ref().context("invalid operation")?;
2340 let capability = match operation.variant.as_ref() {
2341 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2342 if let Some(buffer_op) = buffer_op.operation.as_ref() {
2343 match buffer_op.variant {
2344 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2345 Capability::ReadOnly
2346 }
2347 _ => Capability::ReadWrite,
2348 }
2349 } else {
2350 Capability::ReadWrite
2351 }
2352 }
2353 Some(_) => Capability::ReadWrite,
2354 None => Capability::ReadOnly,
2355 };
2356
2357 let guard = session
2358 .db()
2359 .await
2360 .connections_for_buffer_update(project_id, session.connection_id, capability)
2361 .await?;
2362
2363 let (host, guests) = &*guard;
2364
2365 broadcast(
2366 Some(session.connection_id),
2367 guests.iter().chain([host]).copied(),
2368 |connection_id| {
2369 session
2370 .peer
2371 .forward_send(session.connection_id, connection_id, message.clone())
2372 },
2373 );
2374
2375 Ok(())
2376}
2377
2378/// Notify other participants that a project has been updated.
2379async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2380 request: T,
2381 session: Session,
2382) -> Result<()> {
2383 let project_id = ProjectId::from_proto(request.remote_entity_id());
2384 let project_connection_ids = session
2385 .db()
2386 .await
2387 .project_connection_ids(project_id, session.connection_id, false)
2388 .await?;
2389
2390 broadcast(
2391 Some(session.connection_id),
2392 project_connection_ids.iter().copied(),
2393 |connection_id| {
2394 session
2395 .peer
2396 .forward_send(session.connection_id, connection_id, request.clone())
2397 },
2398 );
2399 Ok(())
2400}
2401
2402/// Start following another user in a call.
2403async fn follow(
2404 request: proto::Follow,
2405 response: Response<proto::Follow>,
2406 session: Session,
2407) -> Result<()> {
2408 let room_id = RoomId::from_proto(request.room_id);
2409 let project_id = request.project_id.map(ProjectId::from_proto);
2410 let leader_id = request.leader_id.context("invalid leader id")?.into();
2411 let follower_id = session.connection_id;
2412
2413 session
2414 .db()
2415 .await
2416 .check_room_participants(room_id, leader_id, session.connection_id)
2417 .await?;
2418
2419 let response_payload = session
2420 .peer
2421 .forward_request(session.connection_id, leader_id, request)
2422 .await?;
2423 response.send(response_payload)?;
2424
2425 if let Some(project_id) = project_id {
2426 let room = session
2427 .db()
2428 .await
2429 .follow(room_id, project_id, leader_id, follower_id)
2430 .await?;
2431 room_updated(&room, &session.peer);
2432 }
2433
2434 Ok(())
2435}
2436
2437/// Stop following another user in a call.
2438async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2439 let room_id = RoomId::from_proto(request.room_id);
2440 let project_id = request.project_id.map(ProjectId::from_proto);
2441 let leader_id = request.leader_id.context("invalid leader id")?.into();
2442 let follower_id = session.connection_id;
2443
2444 session
2445 .db()
2446 .await
2447 .check_room_participants(room_id, leader_id, session.connection_id)
2448 .await?;
2449
2450 session
2451 .peer
2452 .forward_send(session.connection_id, leader_id, request)?;
2453
2454 if let Some(project_id) = project_id {
2455 let room = session
2456 .db()
2457 .await
2458 .unfollow(room_id, project_id, leader_id, follower_id)
2459 .await?;
2460 room_updated(&room, &session.peer);
2461 }
2462
2463 Ok(())
2464}
2465
2466/// Notify everyone following you of your current location.
2467async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2468 let room_id = RoomId::from_proto(request.room_id);
2469 let database = session.db.lock().await;
2470
2471 let connection_ids = if let Some(project_id) = request.project_id {
2472 let project_id = ProjectId::from_proto(project_id);
2473 database
2474 .project_connection_ids(project_id, session.connection_id, true)
2475 .await?
2476 } else {
2477 database
2478 .room_connection_ids(room_id, session.connection_id)
2479 .await?
2480 };
2481
2482 // For now, don't send view update messages back to that view's current leader.
2483 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2484 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2485 _ => None,
2486 });
2487
2488 for connection_id in connection_ids.iter().cloned() {
2489 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2490 session
2491 .peer
2492 .forward_send(session.connection_id, connection_id, request.clone())?;
2493 }
2494 }
2495 Ok(())
2496}
2497
2498/// Get public data about users.
2499async fn get_users(
2500 request: proto::GetUsers,
2501 response: Response<proto::GetUsers>,
2502 session: Session,
2503) -> Result<()> {
2504 let user_ids = request
2505 .user_ids
2506 .into_iter()
2507 .map(UserId::from_proto)
2508 .collect();
2509 let users = session
2510 .db()
2511 .await
2512 .get_users_by_ids(user_ids)
2513 .await?
2514 .into_iter()
2515 .map(|user| proto::User {
2516 id: user.id.to_proto(),
2517 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2518 github_login: user.github_login,
2519 email: user.email_address,
2520 name: user.name,
2521 })
2522 .collect();
2523 response.send(proto::UsersResponse { users })?;
2524 Ok(())
2525}
2526
2527/// Search for users (to invite) buy Github login
2528async fn fuzzy_search_users(
2529 request: proto::FuzzySearchUsers,
2530 response: Response<proto::FuzzySearchUsers>,
2531 session: Session,
2532) -> Result<()> {
2533 let query = request.query;
2534 let users = match query.len() {
2535 0 => vec![],
2536 1 | 2 => session
2537 .db()
2538 .await
2539 .get_user_by_github_login(&query)
2540 .await?
2541 .into_iter()
2542 .collect(),
2543 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2544 };
2545 let users = users
2546 .into_iter()
2547 .filter(|user| user.id != session.user_id())
2548 .map(|user| proto::User {
2549 id: user.id.to_proto(),
2550 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2551 github_login: user.github_login,
2552 name: user.name,
2553 email: user.email_address,
2554 })
2555 .collect();
2556 response.send(proto::UsersResponse { users })?;
2557 Ok(())
2558}
2559
2560/// Send a contact request to another user.
2561async fn request_contact(
2562 request: proto::RequestContact,
2563 response: Response<proto::RequestContact>,
2564 session: Session,
2565) -> Result<()> {
2566 let requester_id = session.user_id();
2567 let responder_id = UserId::from_proto(request.responder_id);
2568 if requester_id == responder_id {
2569 return Err(anyhow!("cannot add yourself as a contact"))?;
2570 }
2571
2572 let notifications = session
2573 .db()
2574 .await
2575 .send_contact_request(requester_id, responder_id)
2576 .await?;
2577
2578 // Update outgoing contact requests of requester
2579 let mut update = proto::UpdateContacts::default();
2580 update.outgoing_requests.push(responder_id.to_proto());
2581 for connection_id in session
2582 .connection_pool()
2583 .await
2584 .user_connection_ids(requester_id)
2585 {
2586 session.peer.send(connection_id, update.clone())?;
2587 }
2588
2589 // Update incoming contact requests of responder
2590 let mut update = proto::UpdateContacts::default();
2591 update
2592 .incoming_requests
2593 .push(proto::IncomingContactRequest {
2594 requester_id: requester_id.to_proto(),
2595 });
2596 let connection_pool = session.connection_pool().await;
2597 for connection_id in connection_pool.user_connection_ids(responder_id) {
2598 session.peer.send(connection_id, update.clone())?;
2599 }
2600
2601 send_notifications(&connection_pool, &session.peer, notifications);
2602
2603 response.send(proto::Ack {})?;
2604 Ok(())
2605}
2606
2607/// Accept or decline a contact request
2608async fn respond_to_contact_request(
2609 request: proto::RespondToContactRequest,
2610 response: Response<proto::RespondToContactRequest>,
2611 session: Session,
2612) -> Result<()> {
2613 let responder_id = session.user_id();
2614 let requester_id = UserId::from_proto(request.requester_id);
2615 let db = session.db().await;
2616 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2617 db.dismiss_contact_notification(responder_id, requester_id)
2618 .await?;
2619 } else {
2620 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2621
2622 let notifications = db
2623 .respond_to_contact_request(responder_id, requester_id, accept)
2624 .await?;
2625 let requester_busy = db.is_user_busy(requester_id).await?;
2626 let responder_busy = db.is_user_busy(responder_id).await?;
2627
2628 let pool = session.connection_pool().await;
2629 // Update responder with new contact
2630 let mut update = proto::UpdateContacts::default();
2631 if accept {
2632 update
2633 .contacts
2634 .push(contact_for_user(requester_id, requester_busy, &pool));
2635 }
2636 update
2637 .remove_incoming_requests
2638 .push(requester_id.to_proto());
2639 for connection_id in pool.user_connection_ids(responder_id) {
2640 session.peer.send(connection_id, update.clone())?;
2641 }
2642
2643 // Update requester with new contact
2644 let mut update = proto::UpdateContacts::default();
2645 if accept {
2646 update
2647 .contacts
2648 .push(contact_for_user(responder_id, responder_busy, &pool));
2649 }
2650 update
2651 .remove_outgoing_requests
2652 .push(responder_id.to_proto());
2653
2654 for connection_id in pool.user_connection_ids(requester_id) {
2655 session.peer.send(connection_id, update.clone())?;
2656 }
2657
2658 send_notifications(&pool, &session.peer, notifications);
2659 }
2660
2661 response.send(proto::Ack {})?;
2662 Ok(())
2663}
2664
2665/// Remove a contact.
2666async fn remove_contact(
2667 request: proto::RemoveContact,
2668 response: Response<proto::RemoveContact>,
2669 session: Session,
2670) -> Result<()> {
2671 let requester_id = session.user_id();
2672 let responder_id = UserId::from_proto(request.user_id);
2673 let db = session.db().await;
2674 let (contact_accepted, deleted_notification_id) =
2675 db.remove_contact(requester_id, responder_id).await?;
2676
2677 let pool = session.connection_pool().await;
2678 // Update outgoing contact requests of requester
2679 let mut update = proto::UpdateContacts::default();
2680 if contact_accepted {
2681 update.remove_contacts.push(responder_id.to_proto());
2682 } else {
2683 update
2684 .remove_outgoing_requests
2685 .push(responder_id.to_proto());
2686 }
2687 for connection_id in pool.user_connection_ids(requester_id) {
2688 session.peer.send(connection_id, update.clone())?;
2689 }
2690
2691 // Update incoming contact requests of responder
2692 let mut update = proto::UpdateContacts::default();
2693 if contact_accepted {
2694 update.remove_contacts.push(requester_id.to_proto());
2695 } else {
2696 update
2697 .remove_incoming_requests
2698 .push(requester_id.to_proto());
2699 }
2700 for connection_id in pool.user_connection_ids(responder_id) {
2701 session.peer.send(connection_id, update.clone())?;
2702 if let Some(notification_id) = deleted_notification_id {
2703 session.peer.send(
2704 connection_id,
2705 proto::DeleteNotification {
2706 notification_id: notification_id.to_proto(),
2707 },
2708 )?;
2709 }
2710 }
2711
2712 response.send(proto::Ack {})?;
2713 Ok(())
2714}
2715
2716fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2717 version.0.minor() < 139
2718}
2719
2720async fn current_plan(db: &Arc<Database>, user_id: UserId, is_staff: bool) -> Result<proto::Plan> {
2721 if is_staff {
2722 return Ok(proto::Plan::ZedPro);
2723 }
2724
2725 let subscription = db.get_active_billing_subscription(user_id).await?;
2726 let subscription_kind = subscription.and_then(|subscription| subscription.kind);
2727
2728 let plan = if let Some(subscription_kind) = subscription_kind {
2729 match subscription_kind {
2730 SubscriptionKind::ZedPro => proto::Plan::ZedPro,
2731 SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
2732 SubscriptionKind::ZedFree => proto::Plan::Free,
2733 }
2734 } else {
2735 proto::Plan::Free
2736 };
2737
2738 Ok(plan)
2739}
2740
/// Builds the `UpdateUserPlan` message describing `user`'s plan, trial start,
/// billing state and, when an LLM database is available, the current
/// subscription period and the usage within it.
///
/// * `is_staff`: forwarded to `current_plan` and `current_period`. Staff are
///   also always reported as having usage-based billing enabled.
/// * `llm_db`: when `None`, `subscription_period` and `usage` are omitted.
///
/// # Errors
///
/// Returns the first error from any of the database queries.
async fn make_update_user_plan_message(
    user: &User,
    is_staff: bool,
    db: &Arc<Database>,
    llm_db: Option<Arc<LlmDatabase>>,
) -> Result<proto::UpdateUserPlan> {
    let feature_flags = db.get_user_flags(user.id).await?;
    let plan = current_plan(db, user.id, is_staff).await?;
    let billing_customer = db.get_billing_customer_by_user_id(user.id).await?;
    let billing_preferences = db.get_billing_preferences(user.id).await?;

    // Period and usage are only computed when an LLM database is configured.
    let (subscription_period, usage) = if let Some(llm_db) = llm_db {
        // NOTE(review): for non-staff users `current_plan` has already fetched
        // the active subscription; this repeats that lookup.
        let subscription = db.get_active_billing_subscription(user.id).await?;

        let subscription_period =
            crate::db::billing_subscription::Model::current_period(subscription, is_staff);

        // Usage is only queried when there is a current period to scope it to.
        let usage = if let Some((period_start_at, period_end_at)) = subscription_period {
            llm_db
                .get_subscription_usage_for_period(user.id, period_start_at, period_end_at)
                .await?
        } else {
            None
        };

        (subscription_period, usage)
    } else {
        (None, None)
    };

    // Every plan other than `ZedPro` (the trial included) is subject to the
    // minimum account age.
    let account_too_young =
        !matches!(plan, proto::Plan::ZedPro) && user.account_age() < MIN_ACCOUNT_AGE_FOR_LLM_USE;

    Ok(proto::UpdateUserPlan {
        plan: plan.into(),
        // Seconds since the Unix epoch.
        // NOTE(review): `as u64` would wrap a pre-1970 timestamp.
        trial_started_at: billing_customer
            .as_ref()
            .and_then(|billing_customer| billing_customer.trial_started_at)
            .map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64),
        is_usage_based_billing_enabled: if is_staff {
            Some(true)
        } else {
            billing_preferences.map(|preferences| preferences.model_request_overages_enabled)
        },
        subscription_period: subscription_period.map(|(started_at, ended_at)| {
            proto::SubscriptionPeriod {
                started_at: started_at.timestamp() as u64,
                ended_at: ended_at.timestamp() as u64,
            }
        }),
        account_too_young: Some(account_too_young),
        has_overdue_invoices: billing_customer
            .map(|billing_customer| billing_customer.has_overdue_invoices),
        usage: usage.map(|usage| {
            // Mirror the protobuf plan into the LLM client's plan type so its
            // `model_requests_limit` / `edit_predictions_limit` can be used.
            let plan = match plan {
                proto::Plan::Free => zed_llm_client::Plan::ZedFree,
                proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro,
                proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
            };

            let model_requests_limit = match plan.model_requests_limit() {
                zed_llm_client::UsageLimit::Limited(limit) => {
                    // Trial users with the extended-trial feature flag get a
                    // fixed limit of 1,000 model requests instead of the plan's.
                    let limit = if plan == zed_llm_client::Plan::ZedProTrial
                        && feature_flags
                            .iter()
                            .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
                    {
                        1_000
                    } else {
                        limit
                    };

                    zed_llm_client::UsageLimit::Limited(limit)
                }
                zed_llm_client::UsageLimit::Unlimited => zed_llm_client::UsageLimit::Unlimited,
            };

            // NOTE(review): the `as u32` casts below truncate or wrap if a
            // usage count or limit falls outside the u32 range. Confirm the
            // source types.
            proto::SubscriptionUsage {
                model_requests_usage_amount: usage.model_requests as u32,
                model_requests_usage_limit: Some(proto::UsageLimit {
                    variant: Some(match model_requests_limit {
                        zed_llm_client::UsageLimit::Limited(limit) => {
                            proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
                                limit: limit as u32,
                            })
                        }
                        zed_llm_client::UsageLimit::Unlimited => {
                            proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
                        }
                    }),
                }),
                edit_predictions_usage_amount: usage.edit_predictions as u32,
                edit_predictions_usage_limit: Some(proto::UsageLimit {
                    variant: Some(match plan.edit_predictions_limit() {
                        zed_llm_client::UsageLimit::Limited(limit) => {
                            proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
                                limit: limit as u32,
                            })
                        }
                        zed_llm_client::UsageLimit::Unlimited => {
                            proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
                        }
                    }),
                }),
            }
        }),
    })
}
2849
2850async fn update_user_plan(session: &Session) -> Result<()> {
2851 let db = session.db().await;
2852
2853 let update_user_plan = make_update_user_plan_message(
2854 session.principal.user(),
2855 session.is_staff(),
2856 &db.0,
2857 session.app_state.llm_db.clone(),
2858 )
2859 .await?;
2860
2861 session
2862 .peer
2863 .send(session.connection_id, update_user_plan)
2864 .trace_err();
2865
2866 Ok(())
2867}
2868
2869async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2870 subscribe_user_to_channels(session.user_id(), &session).await?;
2871 Ok(())
2872}
2873
2874async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2875 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2876 let mut pool = session.connection_pool().await;
2877 for membership in &channels_for_user.channel_memberships {
2878 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2879 }
2880 session.peer.send(
2881 session.connection_id,
2882 build_update_user_channels(&channels_for_user),
2883 )?;
2884 session.peer.send(
2885 session.connection_id,
2886 build_channels_update(channels_for_user),
2887 )?;
2888 Ok(())
2889}
2890
2891/// Creates a new channel.
2892async fn create_channel(
2893 request: proto::CreateChannel,
2894 response: Response<proto::CreateChannel>,
2895 session: Session,
2896) -> Result<()> {
2897 let db = session.db().await;
2898
2899 let parent_id = request.parent_id.map(ChannelId::from_proto);
2900 let (channel, membership) = db
2901 .create_channel(&request.name, parent_id, session.user_id())
2902 .await?;
2903
2904 let root_id = channel.root_id();
2905 let channel = Channel::from_model(channel);
2906
2907 response.send(proto::CreateChannelResponse {
2908 channel: Some(channel.to_proto()),
2909 parent_id: request.parent_id,
2910 })?;
2911
2912 let mut connection_pool = session.connection_pool().await;
2913 if let Some(membership) = membership {
2914 connection_pool.subscribe_to_channel(
2915 membership.user_id,
2916 membership.channel_id,
2917 membership.role,
2918 );
2919 let update = proto::UpdateUserChannels {
2920 channel_memberships: vec![proto::ChannelMembership {
2921 channel_id: membership.channel_id.to_proto(),
2922 role: membership.role.into(),
2923 }],
2924 ..Default::default()
2925 };
2926 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2927 session.peer.send(connection_id, update.clone())?;
2928 }
2929 }
2930
2931 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2932 if !role.can_see_channel(channel.visibility) {
2933 continue;
2934 }
2935
2936 let update = proto::UpdateChannels {
2937 channels: vec![channel.to_proto()],
2938 ..Default::default()
2939 };
2940 session.peer.send(connection_id, update.clone())?;
2941 }
2942
2943 Ok(())
2944}
2945
2946/// Delete a channel
2947async fn delete_channel(
2948 request: proto::DeleteChannel,
2949 response: Response<proto::DeleteChannel>,
2950 session: Session,
2951) -> Result<()> {
2952 let db = session.db().await;
2953
2954 let channel_id = request.channel_id;
2955 let (root_channel, removed_channels) = db
2956 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2957 .await?;
2958 response.send(proto::Ack {})?;
2959
2960 // Notify members of removed channels
2961 let mut update = proto::UpdateChannels::default();
2962 update
2963 .delete_channels
2964 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2965
2966 let connection_pool = session.connection_pool().await;
2967 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2968 session.peer.send(connection_id, update.clone())?;
2969 }
2970
2971 Ok(())
2972}
2973
2974/// Invite someone to join a channel.
2975async fn invite_channel_member(
2976 request: proto::InviteChannelMember,
2977 response: Response<proto::InviteChannelMember>,
2978 session: Session,
2979) -> Result<()> {
2980 let db = session.db().await;
2981 let channel_id = ChannelId::from_proto(request.channel_id);
2982 let invitee_id = UserId::from_proto(request.user_id);
2983 let InviteMemberResult {
2984 channel,
2985 notifications,
2986 } = db
2987 .invite_channel_member(
2988 channel_id,
2989 invitee_id,
2990 session.user_id(),
2991 request.role().into(),
2992 )
2993 .await?;
2994
2995 let update = proto::UpdateChannels {
2996 channel_invitations: vec![channel.to_proto()],
2997 ..Default::default()
2998 };
2999
3000 let connection_pool = session.connection_pool().await;
3001 for connection_id in connection_pool.user_connection_ids(invitee_id) {
3002 session.peer.send(connection_id, update.clone())?;
3003 }
3004
3005 send_notifications(&connection_pool, &session.peer, notifications);
3006
3007 response.send(proto::Ack {})?;
3008 Ok(())
3009}
3010
3011/// remove someone from a channel
3012async fn remove_channel_member(
3013 request: proto::RemoveChannelMember,
3014 response: Response<proto::RemoveChannelMember>,
3015 session: Session,
3016) -> Result<()> {
3017 let db = session.db().await;
3018 let channel_id = ChannelId::from_proto(request.channel_id);
3019 let member_id = UserId::from_proto(request.user_id);
3020
3021 let RemoveChannelMemberResult {
3022 membership_update,
3023 notification_id,
3024 } = db
3025 .remove_channel_member(channel_id, member_id, session.user_id())
3026 .await?;
3027
3028 let mut connection_pool = session.connection_pool().await;
3029 notify_membership_updated(
3030 &mut connection_pool,
3031 membership_update,
3032 member_id,
3033 &session.peer,
3034 );
3035 for connection_id in connection_pool.user_connection_ids(member_id) {
3036 if let Some(notification_id) = notification_id {
3037 session
3038 .peer
3039 .send(
3040 connection_id,
3041 proto::DeleteNotification {
3042 notification_id: notification_id.to_proto(),
3043 },
3044 )
3045 .trace_err();
3046 }
3047 }
3048
3049 response.send(proto::Ack {})?;
3050 Ok(())
3051}
3052
3053/// Toggle the channel between public and private.
3054/// Care is taken to maintain the invariant that public channels only descend from public channels,
3055/// (though members-only channels can appear at any point in the hierarchy).
3056async fn set_channel_visibility(
3057 request: proto::SetChannelVisibility,
3058 response: Response<proto::SetChannelVisibility>,
3059 session: Session,
3060) -> Result<()> {
3061 let db = session.db().await;
3062 let channel_id = ChannelId::from_proto(request.channel_id);
3063 let visibility = request.visibility().into();
3064
3065 let channel_model = db
3066 .set_channel_visibility(channel_id, visibility, session.user_id())
3067 .await?;
3068 let root_id = channel_model.root_id();
3069 let channel = Channel::from_model(channel_model);
3070
3071 let mut connection_pool = session.connection_pool().await;
3072 for (user_id, role) in connection_pool
3073 .channel_user_ids(root_id)
3074 .collect::<Vec<_>>()
3075 .into_iter()
3076 {
3077 let update = if role.can_see_channel(channel.visibility) {
3078 connection_pool.subscribe_to_channel(user_id, channel_id, role);
3079 proto::UpdateChannels {
3080 channels: vec![channel.to_proto()],
3081 ..Default::default()
3082 }
3083 } else {
3084 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3085 proto::UpdateChannels {
3086 delete_channels: vec![channel.id.to_proto()],
3087 ..Default::default()
3088 }
3089 };
3090
3091 for connection_id in connection_pool.user_connection_ids(user_id) {
3092 session.peer.send(connection_id, update.clone())?;
3093 }
3094 }
3095
3096 response.send(proto::Ack {})?;
3097 Ok(())
3098}
3099
3100/// Alter the role for a user in the channel.
3101async fn set_channel_member_role(
3102 request: proto::SetChannelMemberRole,
3103 response: Response<proto::SetChannelMemberRole>,
3104 session: Session,
3105) -> Result<()> {
3106 let db = session.db().await;
3107 let channel_id = ChannelId::from_proto(request.channel_id);
3108 let member_id = UserId::from_proto(request.user_id);
3109 let result = db
3110 .set_channel_member_role(
3111 channel_id,
3112 session.user_id(),
3113 member_id,
3114 request.role().into(),
3115 )
3116 .await?;
3117
3118 match result {
3119 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3120 let mut connection_pool = session.connection_pool().await;
3121 notify_membership_updated(
3122 &mut connection_pool,
3123 membership_update,
3124 member_id,
3125 &session.peer,
3126 )
3127 }
3128 db::SetMemberRoleResult::InviteUpdated(channel) => {
3129 let update = proto::UpdateChannels {
3130 channel_invitations: vec![channel.to_proto()],
3131 ..Default::default()
3132 };
3133
3134 for connection_id in session
3135 .connection_pool()
3136 .await
3137 .user_connection_ids(member_id)
3138 {
3139 session.peer.send(connection_id, update.clone())?;
3140 }
3141 }
3142 }
3143
3144 response.send(proto::Ack {})?;
3145 Ok(())
3146}
3147
3148/// Change the name of a channel
3149async fn rename_channel(
3150 request: proto::RenameChannel,
3151 response: Response<proto::RenameChannel>,
3152 session: Session,
3153) -> Result<()> {
3154 let db = session.db().await;
3155 let channel_id = ChannelId::from_proto(request.channel_id);
3156 let channel_model = db
3157 .rename_channel(channel_id, session.user_id(), &request.name)
3158 .await?;
3159 let root_id = channel_model.root_id();
3160 let channel = Channel::from_model(channel_model);
3161
3162 response.send(proto::RenameChannelResponse {
3163 channel: Some(channel.to_proto()),
3164 })?;
3165
3166 let connection_pool = session.connection_pool().await;
3167 let update = proto::UpdateChannels {
3168 channels: vec![channel.to_proto()],
3169 ..Default::default()
3170 };
3171 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3172 if role.can_see_channel(channel.visibility) {
3173 session.peer.send(connection_id, update.clone())?;
3174 }
3175 }
3176
3177 Ok(())
3178}
3179
3180/// Move a channel to a new parent.
3181async fn move_channel(
3182 request: proto::MoveChannel,
3183 response: Response<proto::MoveChannel>,
3184 session: Session,
3185) -> Result<()> {
3186 let channel_id = ChannelId::from_proto(request.channel_id);
3187 let to = ChannelId::from_proto(request.to);
3188
3189 let (root_id, channels) = session
3190 .db()
3191 .await
3192 .move_channel(channel_id, to, session.user_id())
3193 .await?;
3194
3195 let connection_pool = session.connection_pool().await;
3196 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3197 let channels = channels
3198 .iter()
3199 .filter_map(|channel| {
3200 if role.can_see_channel(channel.visibility) {
3201 Some(channel.to_proto())
3202 } else {
3203 None
3204 }
3205 })
3206 .collect::<Vec<_>>();
3207 if channels.is_empty() {
3208 continue;
3209 }
3210
3211 let update = proto::UpdateChannels {
3212 channels,
3213 ..Default::default()
3214 };
3215
3216 session.peer.send(connection_id, update.clone())?;
3217 }
3218
3219 response.send(Ack {})?;
3220 Ok(())
3221}
3222
3223/// Get the list of channel members
3224async fn get_channel_members(
3225 request: proto::GetChannelMembers,
3226 response: Response<proto::GetChannelMembers>,
3227 session: Session,
3228) -> Result<()> {
3229 let db = session.db().await;
3230 let channel_id = ChannelId::from_proto(request.channel_id);
3231 let limit = if request.limit == 0 {
3232 u16::MAX as u64
3233 } else {
3234 request.limit
3235 };
3236 let (members, users) = db
3237 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3238 .await?;
3239 response.send(proto::GetChannelMembersResponse { members, users })?;
3240 Ok(())
3241}
3242
3243/// Accept or decline a channel invitation.
3244async fn respond_to_channel_invite(
3245 request: proto::RespondToChannelInvite,
3246 response: Response<proto::RespondToChannelInvite>,
3247 session: Session,
3248) -> Result<()> {
3249 let db = session.db().await;
3250 let channel_id = ChannelId::from_proto(request.channel_id);
3251 let RespondToChannelInvite {
3252 membership_update,
3253 notifications,
3254 } = db
3255 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3256 .await?;
3257
3258 let mut connection_pool = session.connection_pool().await;
3259 if let Some(membership_update) = membership_update {
3260 notify_membership_updated(
3261 &mut connection_pool,
3262 membership_update,
3263 session.user_id(),
3264 &session.peer,
3265 );
3266 } else {
3267 let update = proto::UpdateChannels {
3268 remove_channel_invitations: vec![channel_id.to_proto()],
3269 ..Default::default()
3270 };
3271
3272 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3273 session.peer.send(connection_id, update.clone())?;
3274 }
3275 };
3276
3277 send_notifications(&connection_pool, &session.peer, notifications);
3278
3279 response.send(proto::Ack {})?;
3280
3281 Ok(())
3282}
3283
3284/// Join the channels' room
3285async fn join_channel(
3286 request: proto::JoinChannel,
3287 response: Response<proto::JoinChannel>,
3288 session: Session,
3289) -> Result<()> {
3290 let channel_id = ChannelId::from_proto(request.channel_id);
3291 join_channel_internal(channel_id, Box::new(response), session).await
3292}
3293
/// A responder that can complete a channel join with a `proto::JoinRoomResponse`.
///
/// It is implemented for both the `JoinChannel` and the `JoinRoom` responders,
/// so `join_channel_internal` can reply to either request.
trait JoinChannelInternalResponse {
    /// Send `result` as the reply, consuming the responder.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Delegate to `Response`'s own `send` for `JoinChannel` via an explicit path.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Delegate to `Response`'s own `send` for `JoinRoom` via an explicit path.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3307
/// Join the room attached to `channel_id` on behalf of the session's user.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, leave
///    that room first.
/// 2. Join the channel's room in the database.
/// 3. Reply through `response` with the room, the channel id, and optional
///    LiveKit credentials. Guests get a guest token and `can_publish = false`;
///    everyone else gets a room token and `can_publish = true`.
/// 4. Push any membership change to the user's connections and broadcast the
///    updated room to its participants.
/// 5. Broadcast the channel's participant list and refresh the user's contacts.
///
/// `response` is any `JoinChannelInternalResponse`, so this serves both the
/// `JoinChannel` and the `JoinRoom` requests.
///
/// # Errors
/// Propagates database and send errors. It also errors with
/// "channel not returned" if the joined room has no channel, which happens only
/// after the response has already been sent.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the db guard while leaving the old room and re-acquire it
            // afterwards. Presumably `leave_room_for_session` takes the guard
            // itself; confirm.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured or when token creation
        // fails. `trace_err()?` turns a token error into `None` for the whole
        // connection info.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        // The `db` and `connection_pool` guards are dropped when this block ends.
        joined_room
    };

    // The response was sent above, so a missing channel here only surfaces
    // as an error from this handler.
    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3401
3402/// Start editing the channel notes
3403async fn join_channel_buffer(
3404 request: proto::JoinChannelBuffer,
3405 response: Response<proto::JoinChannelBuffer>,
3406 session: Session,
3407) -> Result<()> {
3408 let db = session.db().await;
3409 let channel_id = ChannelId::from_proto(request.channel_id);
3410
3411 let open_response = db
3412 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3413 .await?;
3414
3415 let collaborators = open_response.collaborators.clone();
3416 response.send(open_response)?;
3417
3418 let update = UpdateChannelBufferCollaborators {
3419 channel_id: channel_id.to_proto(),
3420 collaborators: collaborators.clone(),
3421 };
3422 channel_buffer_updated(
3423 session.connection_id,
3424 collaborators
3425 .iter()
3426 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3427 &update,
3428 &session.peer,
3429 );
3430
3431 Ok(())
3432}
3433
3434/// Edit the channel notes
3435async fn update_channel_buffer(
3436 request: proto::UpdateChannelBuffer,
3437 session: Session,
3438) -> Result<()> {
3439 let db = session.db().await;
3440 let channel_id = ChannelId::from_proto(request.channel_id);
3441
3442 let (collaborators, epoch, version) = db
3443 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3444 .await?;
3445
3446 channel_buffer_updated(
3447 session.connection_id,
3448 collaborators.clone(),
3449 &proto::UpdateChannelBuffer {
3450 channel_id: channel_id.to_proto(),
3451 operations: request.operations,
3452 },
3453 &session.peer,
3454 );
3455
3456 let pool = &*session.connection_pool().await;
3457
3458 let non_collaborators =
3459 pool.channel_connection_ids(channel_id)
3460 .filter_map(|(connection_id, _)| {
3461 if collaborators.contains(&connection_id) {
3462 None
3463 } else {
3464 Some(connection_id)
3465 }
3466 });
3467
3468 broadcast(None, non_collaborators, |peer_id| {
3469 session.peer.send(
3470 peer_id,
3471 proto::UpdateChannels {
3472 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3473 channel_id: channel_id.to_proto(),
3474 epoch: epoch as u64,
3475 version: version.clone(),
3476 }],
3477 ..Default::default()
3478 },
3479 )
3480 });
3481
3482 Ok(())
3483}
3484
3485/// Rejoin the channel notes after a connection blip
3486async fn rejoin_channel_buffers(
3487 request: proto::RejoinChannelBuffers,
3488 response: Response<proto::RejoinChannelBuffers>,
3489 session: Session,
3490) -> Result<()> {
3491 let db = session.db().await;
3492 let buffers = db
3493 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3494 .await?;
3495
3496 for rejoined_buffer in &buffers {
3497 let collaborators_to_notify = rejoined_buffer
3498 .buffer
3499 .collaborators
3500 .iter()
3501 .filter_map(|c| Some(c.peer_id?.into()));
3502 channel_buffer_updated(
3503 session.connection_id,
3504 collaborators_to_notify,
3505 &proto::UpdateChannelBufferCollaborators {
3506 channel_id: rejoined_buffer.buffer.channel_id,
3507 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3508 },
3509 &session.peer,
3510 );
3511 }
3512
3513 response.send(proto::RejoinChannelBuffersResponse {
3514 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3515 })?;
3516
3517 Ok(())
3518}
3519
3520/// Stop editing the channel notes
3521async fn leave_channel_buffer(
3522 request: proto::LeaveChannelBuffer,
3523 response: Response<proto::LeaveChannelBuffer>,
3524 session: Session,
3525) -> Result<()> {
3526 let db = session.db().await;
3527 let channel_id = ChannelId::from_proto(request.channel_id);
3528
3529 let left_buffer = db
3530 .leave_channel_buffer(channel_id, session.connection_id)
3531 .await?;
3532
3533 response.send(Ack {})?;
3534
3535 channel_buffer_updated(
3536 session.connection_id,
3537 left_buffer.connections,
3538 &proto::UpdateChannelBufferCollaborators {
3539 channel_id: channel_id.to_proto(),
3540 collaborators: left_buffer.collaborators,
3541 },
3542 &session.peer,
3543 );
3544
3545 Ok(())
3546}
3547
3548fn channel_buffer_updated<T: EnvelopedMessage>(
3549 sender_id: ConnectionId,
3550 collaborators: impl IntoIterator<Item = ConnectionId>,
3551 message: &T,
3552 peer: &Peer,
3553) {
3554 broadcast(Some(sender_id), collaborators, |peer_id| {
3555 peer.send(peer_id, message.clone())
3556 });
3557}
3558
3559fn send_notifications(
3560 connection_pool: &ConnectionPool,
3561 peer: &Peer,
3562 notifications: db::NotificationBatch,
3563) {
3564 for (user_id, notification) in notifications {
3565 for connection_id in connection_pool.user_connection_ids(user_id) {
3566 if let Err(error) = peer.send(
3567 connection_id,
3568 proto::AddNotification {
3569 notification: Some(notification.clone()),
3570 },
3571 ) {
3572 tracing::error!(
3573 "failed to send notification to {:?} {}",
3574 connection_id,
3575 error
3576 );
3577 }
3578 }
3579 }
3580}
3581
/// Send a message to the channel.
///
/// The body is trimmed, then rejected if it is blank or longer than
/// `MAX_MESSAGE_LEN` bytes. Once the message is stored it is:
/// 1. relayed to the connections participating in the channel chat
///    (presumably excluding the sender's own connection, per `broadcast`),
/// 2. returned to the sender in the response,
/// 3. announced as the channel's latest message id to connections subscribed
///    to the channel that are not chat participants, and
/// 4. accompanied by delivery of the notifications the database created
///    (presumably mention notifications; confirm).
///
/// # Errors
/// Fails on an invalid body, a missing nonce, or a database/send error.
async fn send_channel_message(
    request: proto::SendChannelMessage,
    response: Response<proto::SendChannelMessage>,
    session: Session,
) -> Result<()> {
    // Validate the message body. `len()` counts bytes, not characters.
    let body = request.body.trim().to_string();
    if body.len() > MAX_MESSAGE_LEN {
        return Err(anyhow!("message is too long"))?;
    }
    if body.is_empty() {
        return Err(anyhow!("message can't be blank"))?;
    }

    // TODO: adjust mentions if body is trimmed
    // NOTE(review): `request.mentions` is stored and broadcast unchanged, so any
    // positional data it carries still refers to the untrimmed body.

    let timestamp = OffsetDateTime::now_utc();
    // Client-supplied nonce. It is echoed back in the broadcast message,
    // presumably so the sender can match it with its pending message.
    let nonce = request.nonce.context("nonce can't be blank")?;

    let channel_id = ChannelId::from_proto(request.channel_id);
    let CreatedChannelMessage {
        message_id,
        participant_connection_ids,
        notifications,
    } = session
        .db()
        .await
        .create_channel_message(
            channel_id,
            session.user_id(),
            &body,
            &request.mentions,
            timestamp,
            nonce.clone().into(),
            request.reply_to_message_id.map(MessageId::from_proto),
        )
        .await?;

    let message = proto::ChannelMessage {
        sender_id: session.user_id().to_proto(),
        id: message_id.to_proto(),
        body,
        mentions: request.mentions,
        timestamp: timestamp.unix_timestamp() as u64,
        nonce: Some(nonce),
        reply_to_message_id: request.reply_to_message_id,
        edited_at: None,
    };
    // Relay the full message to everyone currently in the channel chat.
    broadcast(
        Some(session.connection_id),
        participant_connection_ids.clone(),
        |connection| {
            session.peer.send(
                connection,
                proto::ChannelMessageSent {
                    channel_id: channel_id.to_proto(),
                    message: Some(message.clone()),
                },
            )
        },
    );
    response.send(proto::SendChannelMessageResponse {
        message: Some(message),
    })?;

    // Channel subscribers that are not in the chat only learn the newest
    // message id, not the message itself.
    let pool = &*session.connection_pool().await;
    let non_participants =
        pool.channel_connection_ids(channel_id)
            .filter_map(|(connection_id, _)| {
                if participant_connection_ids.contains(&connection_id) {
                    None
                } else {
                    Some(connection_id)
                }
            });
    broadcast(None, non_participants, |peer_id| {
        session.peer.send(
            peer_id,
            proto::UpdateChannels {
                latest_channel_message_ids: vec![proto::ChannelMessageId {
                    channel_id: channel_id.to_proto(),
                    message_id: message_id.to_proto(),
                }],
                ..Default::default()
            },
        )
    });
    send_notifications(pool, &session.peer, notifications);

    Ok(())
}
3674
3675/// Delete a channel message
3676async fn remove_channel_message(
3677 request: proto::RemoveChannelMessage,
3678 response: Response<proto::RemoveChannelMessage>,
3679 session: Session,
3680) -> Result<()> {
3681 let channel_id = ChannelId::from_proto(request.channel_id);
3682 let message_id = MessageId::from_proto(request.message_id);
3683 let (connection_ids, existing_notification_ids) = session
3684 .db()
3685 .await
3686 .remove_channel_message(channel_id, message_id, session.user_id())
3687 .await?;
3688
3689 broadcast(
3690 Some(session.connection_id),
3691 connection_ids,
3692 move |connection| {
3693 session.peer.send(connection, request.clone())?;
3694
3695 for notification_id in &existing_notification_ids {
3696 session.peer.send(
3697 connection,
3698 proto::DeleteNotification {
3699 notification_id: (*notification_id).to_proto(),
3700 },
3701 )?;
3702 }
3703
3704 Ok(())
3705 },
3706 );
3707 response.send(proto::Ack {})?;
3708 Ok(())
3709}
3710
3711async fn update_channel_message(
3712 request: proto::UpdateChannelMessage,
3713 response: Response<proto::UpdateChannelMessage>,
3714 session: Session,
3715) -> Result<()> {
3716 let channel_id = ChannelId::from_proto(request.channel_id);
3717 let message_id = MessageId::from_proto(request.message_id);
3718 let updated_at = OffsetDateTime::now_utc();
3719 let UpdatedChannelMessage {
3720 message_id,
3721 participant_connection_ids,
3722 notifications,
3723 reply_to_message_id,
3724 timestamp,
3725 deleted_mention_notification_ids,
3726 updated_mention_notifications,
3727 } = session
3728 .db()
3729 .await
3730 .update_channel_message(
3731 channel_id,
3732 message_id,
3733 session.user_id(),
3734 request.body.as_str(),
3735 &request.mentions,
3736 updated_at,
3737 )
3738 .await?;
3739
3740 let nonce = request.nonce.clone().context("nonce can't be blank")?;
3741
3742 let message = proto::ChannelMessage {
3743 sender_id: session.user_id().to_proto(),
3744 id: message_id.to_proto(),
3745 body: request.body.clone(),
3746 mentions: request.mentions.clone(),
3747 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3748 nonce: Some(nonce),
3749 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3750 edited_at: Some(updated_at.unix_timestamp() as u64),
3751 };
3752
3753 response.send(proto::Ack {})?;
3754
3755 let pool = &*session.connection_pool().await;
3756 broadcast(
3757 Some(session.connection_id),
3758 participant_connection_ids,
3759 |connection| {
3760 session.peer.send(
3761 connection,
3762 proto::ChannelMessageUpdate {
3763 channel_id: channel_id.to_proto(),
3764 message: Some(message.clone()),
3765 },
3766 )?;
3767
3768 for notification_id in &deleted_mention_notification_ids {
3769 session.peer.send(
3770 connection,
3771 proto::DeleteNotification {
3772 notification_id: (*notification_id).to_proto(),
3773 },
3774 )?;
3775 }
3776
3777 for notification in &updated_mention_notifications {
3778 session.peer.send(
3779 connection,
3780 proto::UpdateNotification {
3781 notification: Some(notification.clone()),
3782 },
3783 )?;
3784 }
3785
3786 Ok(())
3787 },
3788 );
3789
3790 send_notifications(pool, &session.peer, notifications);
3791
3792 Ok(())
3793}
3794
3795/// Mark a channel message as read
3796async fn acknowledge_channel_message(
3797 request: proto::AckChannelMessage,
3798 session: Session,
3799) -> Result<()> {
3800 let channel_id = ChannelId::from_proto(request.channel_id);
3801 let message_id = MessageId::from_proto(request.message_id);
3802 let notifications = session
3803 .db()
3804 .await
3805 .observe_channel_message(channel_id, session.user_id(), message_id)
3806 .await?;
3807 send_notifications(
3808 &*session.connection_pool().await,
3809 &session.peer,
3810 notifications,
3811 );
3812 Ok(())
3813}
3814
3815/// Mark a buffer version as synced
3816async fn acknowledge_buffer_version(
3817 request: proto::AckBufferOperation,
3818 session: Session,
3819) -> Result<()> {
3820 let buffer_id = BufferId::from_proto(request.buffer_id);
3821 session
3822 .db()
3823 .await
3824 .observe_buffer_version(
3825 buffer_id,
3826 session.user_id(),
3827 request.epoch as i32,
3828 &request.version,
3829 )
3830 .await?;
3831 Ok(())
3832}
3833
3834/// Get a Supermaven API key for the user
3835async fn get_supermaven_api_key(
3836 _request: proto::GetSupermavenApiKey,
3837 response: Response<proto::GetSupermavenApiKey>,
3838 session: Session,
3839) -> Result<()> {
3840 let user_id: String = session.user_id().to_string();
3841 if !session.is_staff() {
3842 return Err(anyhow!("supermaven not enabled for this account"))?;
3843 }
3844
3845 let email = session.email().context("user must have an email")?;
3846
3847 let supermaven_admin_api = session
3848 .supermaven_client
3849 .as_ref()
3850 .context("supermaven not configured")?;
3851
3852 let result = supermaven_admin_api
3853 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3854 .await?;
3855
3856 response.send(proto::GetSupermavenApiKeyResponse {
3857 api_key: result.api_key,
3858 })?;
3859
3860 Ok(())
3861}
3862
3863/// Start receiving chat updates for a channel
3864async fn join_channel_chat(
3865 request: proto::JoinChannelChat,
3866 response: Response<proto::JoinChannelChat>,
3867 session: Session,
3868) -> Result<()> {
3869 let channel_id = ChannelId::from_proto(request.channel_id);
3870
3871 let db = session.db().await;
3872 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3873 .await?;
3874 let messages = db
3875 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3876 .await?;
3877 response.send(proto::JoinChannelChatResponse {
3878 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3879 messages,
3880 })?;
3881 Ok(())
3882}
3883
3884/// Stop receiving chat updates for a channel
3885async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3886 let channel_id = ChannelId::from_proto(request.channel_id);
3887 session
3888 .db()
3889 .await
3890 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3891 .await?;
3892 Ok(())
3893}
3894
3895/// Retrieve the chat history for a channel
3896async fn get_channel_messages(
3897 request: proto::GetChannelMessages,
3898 response: Response<proto::GetChannelMessages>,
3899 session: Session,
3900) -> Result<()> {
3901 let channel_id = ChannelId::from_proto(request.channel_id);
3902 let messages = session
3903 .db()
3904 .await
3905 .get_channel_messages(
3906 channel_id,
3907 session.user_id(),
3908 MESSAGE_COUNT_PER_PAGE,
3909 Some(MessageId::from_proto(request.before_message_id)),
3910 )
3911 .await?;
3912 response.send(proto::GetChannelMessagesResponse {
3913 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3914 messages,
3915 })?;
3916 Ok(())
3917}
3918
3919/// Retrieve specific chat messages
3920async fn get_channel_messages_by_id(
3921 request: proto::GetChannelMessagesById,
3922 response: Response<proto::GetChannelMessagesById>,
3923 session: Session,
3924) -> Result<()> {
3925 let message_ids = request
3926 .message_ids
3927 .iter()
3928 .map(|id| MessageId::from_proto(*id))
3929 .collect::<Vec<_>>();
3930 let messages = session
3931 .db()
3932 .await
3933 .get_channel_messages_by_id(session.user_id(), &message_ids)
3934 .await?;
3935 response.send(proto::GetChannelMessagesResponse {
3936 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3937 messages,
3938 })?;
3939 Ok(())
3940}
3941
3942/// Retrieve the current users notifications
3943async fn get_notifications(
3944 request: proto::GetNotifications,
3945 response: Response<proto::GetNotifications>,
3946 session: Session,
3947) -> Result<()> {
3948 let notifications = session
3949 .db()
3950 .await
3951 .get_notifications(
3952 session.user_id(),
3953 NOTIFICATION_COUNT_PER_PAGE,
3954 request.before_id.map(db::NotificationId::from_proto),
3955 )
3956 .await?;
3957 response.send(proto::GetNotificationsResponse {
3958 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3959 notifications,
3960 })?;
3961 Ok(())
3962}
3963
3964/// Mark notifications as read
3965async fn mark_notification_as_read(
3966 request: proto::MarkNotificationRead,
3967 response: Response<proto::MarkNotificationRead>,
3968 session: Session,
3969) -> Result<()> {
3970 let database = &session.db().await;
3971 let notifications = database
3972 .mark_notification_as_read_by_id(
3973 session.user_id(),
3974 NotificationId::from_proto(request.notification_id),
3975 )
3976 .await?;
3977 send_notifications(
3978 &*session.connection_pool().await,
3979 &session.peer,
3980 notifications,
3981 );
3982 response.send(proto::Ack {})?;
3983 Ok(())
3984}
3985
3986/// Get the current users information
3987async fn get_private_user_info(
3988 _request: proto::GetPrivateUserInfo,
3989 response: Response<proto::GetPrivateUserInfo>,
3990 session: Session,
3991) -> Result<()> {
3992 let db = session.db().await;
3993
3994 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
3995 let user = db
3996 .get_user_by_id(session.user_id())
3997 .await?
3998 .context("user not found")?;
3999 let flags = db.get_user_flags(session.user_id()).await?;
4000
4001 response.send(proto::GetPrivateUserInfoResponse {
4002 metrics_id,
4003 staff: user.admin,
4004 flags,
4005 accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4006 })?;
4007 Ok(())
4008}
4009
4010/// Accept the terms of service (tos) on behalf of the current user
4011async fn accept_terms_of_service(
4012 _request: proto::AcceptTermsOfService,
4013 response: Response<proto::AcceptTermsOfService>,
4014 session: Session,
4015) -> Result<()> {
4016 let db = session.db().await;
4017
4018 let accepted_tos_at = Utc::now();
4019 db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4020 .await?;
4021
4022 response.send(proto::AcceptTermsOfServiceResponse {
4023 accepted_tos_at: accepted_tos_at.timestamp() as u64,
4024 })?;
4025 Ok(())
4026}
4027
/// The minimum age (30 days) an account must reach before it may use the LLM
/// service.
pub const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4030
/// Issue an LLM API token for the current user.
///
/// Steps, in order:
/// 1. Load the user and their feature flags, and require that the terms of
///    service were accepted.
/// 2. Require that a Stripe client and Stripe billing are configured.
/// 3. Ensure a billing customer exists. The DB record is used if present;
///    otherwise a Stripe customer is found or created by email and the billing
///    customer is resolved from it.
/// 4. Ensure an active billing subscription exists. If none is found, the
///    customer is subscribed to Zed Free in Stripe and the subscription is
///    recorded in the database.
/// 5. Mint a token from the user, billing state, flags, and system id.
///
/// # Errors
/// Fails if the user is missing, the terms of service were not accepted,
/// billing is not configured, or any database/Stripe/token step fails.
async fn get_llm_api_token(
    _request: proto::GetLlmToken,
    response: Response<proto::GetLlmToken>,
    session: Session,
) -> Result<()> {
    let db = session.db().await;

    let flags = db.get_user_flags(session.user_id()).await?;

    let user_id = session.user_id();
    let user = db
        .get_user_by_id(user_id)
        .await?
        .with_context(|| format!("user {user_id} not found"))?;

    if user.accepted_tos_at.is_none() {
        Err(anyhow!("terms of service not accepted"))?
    }

    let stripe_client = session
        .app_state
        .stripe_client
        .as_ref()
        .context("failed to retrieve Stripe client")?;

    let stripe_billing = session
        .app_state
        .stripe_billing
        .as_ref()
        .context("failed to retrieve Stripe billing object")?;

    // Reuse the stored billing customer, or create/attach one through Stripe.
    let billing_customer = if let Some(billing_customer) =
        db.get_billing_customer_by_user_id(user.id).await?
    {
        billing_customer
    } else {
        let customer_id = stripe_billing
            .find_or_create_customer_by_email(user.email_address.as_deref())
            .await?;

        find_or_create_billing_customer(&session.app_state, stripe_client.as_ref(), &customer_id)
            .await?
            .context("billing customer not found")?
    };

    // Users without an active subscription are placed on Zed Free.
    // NOTE(review): the Stripe subscription is created before the DB row. If
    // `create_billing_subscription` fails, the Stripe subscription is left
    // without a local record, and concurrent requests could each subscribe.
    // Confirm whether this is reconciled elsewhere.
    let billing_subscription =
        if let Some(billing_subscription) = db.get_active_billing_subscription(user.id).await? {
            billing_subscription
        } else {
            let stripe_customer_id =
                StripeCustomerId(billing_customer.stripe_customer_id.clone().into());

            let stripe_subscription = stripe_billing
                .subscribe_to_zed_free(stripe_customer_id)
                .await?;

            db.create_billing_subscription(&db::CreateBillingSubscriptionParams {
                billing_customer_id: billing_customer.id,
                kind: Some(SubscriptionKind::ZedFree),
                stripe_subscription_id: stripe_subscription.id.to_string(),
                stripe_subscription_status: stripe_subscription.status.into(),
                stripe_cancellation_reason: None,
                stripe_current_period_start: Some(stripe_subscription.current_period_start),
                stripe_current_period_end: Some(stripe_subscription.current_period_end),
            })
            .await?
        };

    let billing_preferences = db.get_billing_preferences(user.id).await?;

    let token = LlmTokenClaims::create(
        &user,
        session.is_staff(),
        billing_customer,
        billing_preferences,
        &flags,
        billing_subscription,
        session.system_id.clone(),
        &session.app_state.config,
    )?;
    response.send(proto::GetLlmTokenResponse { token })?;
    Ok(())
}
4114
4115fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4116 let message = match message {
4117 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
4118 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
4119 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
4120 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
4121 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4122 code: frame.code.into(),
4123 reason: frame.reason.as_str().to_owned().into(),
4124 })),
4125 // We should never receive a frame while reading the message, according
4126 // to the `tungstenite` maintainers:
4127 //
4128 // > It cannot occur when you read messages from the WebSocket, but it
4129 // > can be used when you want to send the raw frames (e.g. you want to
4130 // > send the frames to the WebSocket without composing the full message first).
4131 // >
4132 // > — https://github.com/snapview/tungstenite-rs/issues/268
4133 TungsteniteMessage::Frame(_) => {
4134 bail!("received an unexpected frame while reading the message")
4135 }
4136 };
4137
4138 Ok(message)
4139}
4140
4141fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4142 match message {
4143 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
4144 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
4145 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
4146 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
4147 AxumMessage::Close(frame) => {
4148 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4149 code: frame.code.into(),
4150 reason: frame.reason.as_ref().into(),
4151 }))
4152 }
4153 }
4154}
4155
4156fn notify_membership_updated(
4157 connection_pool: &mut ConnectionPool,
4158 result: MembershipUpdated,
4159 user_id: UserId,
4160 peer: &Peer,
4161) {
4162 for membership in &result.new_channels.channel_memberships {
4163 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4164 }
4165 for channel_id in &result.removed_channels {
4166 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4167 }
4168
4169 let user_channels_update = proto::UpdateUserChannels {
4170 channel_memberships: result
4171 .new_channels
4172 .channel_memberships
4173 .iter()
4174 .map(|cm| proto::ChannelMembership {
4175 channel_id: cm.channel_id.to_proto(),
4176 role: cm.role.into(),
4177 })
4178 .collect(),
4179 ..Default::default()
4180 };
4181
4182 let mut update = build_channels_update(result.new_channels);
4183 update.delete_channels = result
4184 .removed_channels
4185 .into_iter()
4186 .map(|id| id.to_proto())
4187 .collect();
4188 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4189
4190 for connection_id in connection_pool.user_connection_ids(user_id) {
4191 peer.send(connection_id, user_channels_update.clone())
4192 .trace_err();
4193 peer.send(connection_id, update.clone()).trace_err();
4194 }
4195}
4196
4197fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4198 proto::UpdateUserChannels {
4199 channel_memberships: channels
4200 .channel_memberships
4201 .iter()
4202 .map(|m| proto::ChannelMembership {
4203 channel_id: m.channel_id.to_proto(),
4204 role: m.role.into(),
4205 })
4206 .collect(),
4207 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4208 observed_channel_message_id: channels.observed_channel_messages.clone(),
4209 }
4210}
4211
4212fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4213 let mut update = proto::UpdateChannels::default();
4214
4215 for channel in channels.channels {
4216 update.channels.push(channel.to_proto());
4217 }
4218
4219 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4220 update.latest_channel_message_ids = channels.latest_channel_messages;
4221
4222 for (channel_id, participants) in channels.channel_participants {
4223 update
4224 .channel_participants
4225 .push(proto::ChannelParticipants {
4226 channel_id: channel_id.to_proto(),
4227 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4228 });
4229 }
4230
4231 for channel in channels.invited_channels {
4232 update.channel_invitations.push(channel.to_proto());
4233 }
4234
4235 update
4236}
4237
4238fn build_initial_contacts_update(
4239 contacts: Vec<db::Contact>,
4240 pool: &ConnectionPool,
4241) -> proto::UpdateContacts {
4242 let mut update = proto::UpdateContacts::default();
4243
4244 for contact in contacts {
4245 match contact {
4246 db::Contact::Accepted { user_id, busy } => {
4247 update.contacts.push(contact_for_user(user_id, busy, pool));
4248 }
4249 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4250 db::Contact::Incoming { user_id } => {
4251 update
4252 .incoming_requests
4253 .push(proto::IncomingContactRequest {
4254 requester_id: user_id.to_proto(),
4255 })
4256 }
4257 }
4258 }
4259
4260 update
4261}
4262
4263fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4264 proto::Contact {
4265 user_id: user_id.to_proto(),
4266 online: pool.is_user_online(user_id),
4267 busy,
4268 }
4269}
4270
4271fn room_updated(room: &proto::Room, peer: &Peer) {
4272 broadcast(
4273 None,
4274 room.participants
4275 .iter()
4276 .filter_map(|participant| Some(participant.peer_id?.into())),
4277 |peer_id| {
4278 peer.send(
4279 peer_id,
4280 proto::RoomUpdated {
4281 room: Some(room.clone()),
4282 },
4283 )
4284 },
4285 );
4286}
4287
4288fn channel_updated(
4289 channel: &db::channel::Model,
4290 room: &proto::Room,
4291 peer: &Peer,
4292 pool: &ConnectionPool,
4293) {
4294 let participants = room
4295 .participants
4296 .iter()
4297 .map(|p| p.user_id)
4298 .collect::<Vec<_>>();
4299
4300 broadcast(
4301 None,
4302 pool.channel_connection_ids(channel.root_id())
4303 .filter_map(|(channel_id, role)| {
4304 role.can_see_channel(channel.visibility)
4305 .then_some(channel_id)
4306 }),
4307 |peer_id| {
4308 peer.send(
4309 peer_id,
4310 proto::UpdateChannels {
4311 channel_participants: vec![proto::ChannelParticipants {
4312 channel_id: channel.id.to_proto(),
4313 participant_user_ids: participants.clone(),
4314 }],
4315 ..Default::default()
4316 },
4317 )
4318 },
4319 );
4320}
4321
4322async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4323 let db = session.db().await;
4324
4325 let contacts = db.get_contacts(user_id).await?;
4326 let busy = db.is_user_busy(user_id).await?;
4327
4328 let pool = session.connection_pool().await;
4329 let updated_contact = contact_for_user(user_id, busy, &pool);
4330 for contact in contacts {
4331 if let db::Contact::Accepted {
4332 user_id: contact_user_id,
4333 ..
4334 } = contact
4335 {
4336 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4337 session
4338 .peer
4339 .send(
4340 contact_conn_id,
4341 proto::UpdateContacts {
4342 contacts: vec![updated_contact.clone()],
4343 remove_contacts: Default::default(),
4344 incoming_requests: Default::default(),
4345 remove_incoming_requests: Default::default(),
4346 outgoing_requests: Default::default(),
4347 remove_outgoing_requests: Default::default(),
4348 },
4349 )
4350 .trace_err();
4351 }
4352 }
4353 }
4354 Ok(())
4355}
4356
/// Removes `connection_id` from the room it is in and fans out the resulting
/// notifications.
///
/// If `leave_room` returns `None`, this returns `Ok(())` without doing anything
/// else. Otherwise it:
/// - calls `project_left` for every project in `left_room.left_projects`,
/// - broadcasts the updated room to its participants (`room_updated`),
/// - if the room has an associated channel, broadcasts the new participant
///   list to that channel's connections (`channel_updated`),
/// - sends `CallCanceled` to every connection of each user listed in
///   `canceled_calls_to_user_ids`,
/// - refreshes the contact entry of the leaving user and of each of those
///   users (`update_user_contacts`),
/// - removes the user from the LiveKit room, and also deletes that LiveKit
///   room when `left_room.deleted` is true.
///
/// # Errors
///
/// Propagates errors from `leave_room` and `update_user_contacts`.
/// NOTE(review): an error from `update_user_contacts` returns before the
/// LiveKit cleanup at the end of this function runs. LiveKit and send failures
/// themselves are only logged via `trace_err`.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    // Users whose contact entry must be re-broadcast once the room is left.
    let mut contacts_to_update = HashSet::default();

    // Deferred initialization: these are moved out of `left_room` inside the
    // `if let` below so they remain usable after that scope ends.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    // NOTE(review): the value produced by `session.db().await` is a temporary of
    // the `if let` scrutinee, so it stays alive for the body of this `if let`
    // (presumably a DB lock guard) — keep this in mind before restructuring.
    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Taken before `room` is moved out, so the `room` broadcast below carries
        // a default (empty) `livekit_room`. NOTE(review): confirm intentional.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so `pool` is dropped before `update_user_contacts` below, which
    // acquires the connection pool again.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    // LiveKit cleanup is best-effort: failures are logged, not propagated.
    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4430
4431async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4432 let left_channel_buffers = session
4433 .db()
4434 .await
4435 .leave_channel_buffers(session.connection_id)
4436 .await?;
4437
4438 for left_buffer in left_channel_buffers {
4439 channel_buffer_updated(
4440 session.connection_id,
4441 left_buffer.connections,
4442 &proto::UpdateChannelBufferCollaborators {
4443 channel_id: left_buffer.channel_id.to_proto(),
4444 collaborators: left_buffer.collaborators,
4445 },
4446 &session.peer,
4447 );
4448 }
4449
4450 Ok(())
4451}
4452
4453fn project_left(project: &db::LeftProject, session: &Session) {
4454 for connection_id in &project.connection_ids {
4455 if project.should_unshare {
4456 session
4457 .peer
4458 .send(
4459 *connection_id,
4460 proto::UnshareProject {
4461 project_id: project.id.to_proto(),
4462 },
4463 )
4464 .trace_err();
4465 } else {
4466 session
4467 .peer
4468 .send(
4469 *connection_id,
4470 proto::RemoveProjectCollaborator {
4471 project_id: project.id.to_proto(),
4472 peer_id: Some(session.connection_id.into()),
4473 },
4474 )
4475 .trace_err();
4476 }
4477 }
4478}
4479
4480pub trait ResultExt {
4481 type Ok;
4482
4483 fn trace_err(self) -> Option<Self::Ok>;
4484}
4485
4486impl<T, E> ResultExt for Result<T, E>
4487where
4488 E: std::fmt::Debug,
4489{
4490 type Ok = T;
4491
4492 #[track_caller]
4493 fn trace_err(self) -> Option<T> {
4494 match self {
4495 Ok(value) => Some(value),
4496 Err(error) => {
4497 tracing::error!("{:?}", error);
4498 None
4499 }
4500 }
4501 }
4502}