1mod connection_pool;
2
3use crate::api::billing::find_or_create_billing_customer;
4use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
5use crate::db::billing_subscription::SubscriptionKind;
6use crate::llm::db::LlmDatabase;
7use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, LlmTokenClaims};
8use crate::stripe_client::StripeCustomerId;
9use crate::{
10 AppState, Error, Result, auth,
11 db::{
12 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
13 CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
14 NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
15 RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
16 },
17 executor::Executor,
18};
19use anyhow::{Context as _, anyhow, bail};
20use async_tungstenite::tungstenite::{
21 Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
22};
23use axum::{
24 Extension, Router, TypedHeader,
25 body::Body,
26 extract::{
27 ConnectInfo, WebSocketUpgrade,
28 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
29 },
30 headers::{Header, HeaderName},
31 http::StatusCode,
32 middleware,
33 response::IntoResponse,
34 routing::get,
35};
36use chrono::Utc;
37use collections::{HashMap, HashSet};
38pub use connection_pool::{ConnectionPool, ZedVersion};
39use core::fmt::{self, Debug, Formatter};
40use reqwest_client::ReqwestClient;
41use rpc::proto::split_repository_update;
42use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
43
44use futures::{
45 FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
46 stream::FuturesUnordered,
47};
48use prometheus::{IntGauge, register_int_gauge};
49use rpc::{
50 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
51 proto::{
52 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
53 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
54 },
55};
56use semantic_version::SemanticVersion;
57use serde::{Serialize, Serializer};
58use std::{
59 any::TypeId,
60 future::Future,
61 marker::PhantomData,
62 mem,
63 net::SocketAddr,
64 ops::{Deref, DerefMut},
65 rc::Rc,
66 sync::{
67 Arc, OnceLock,
68 atomic::{AtomicBool, Ordering::SeqCst},
69 },
70 time::{Duration, Instant},
71};
72use time::OffsetDateTime;
73use tokio::sync::{Semaphore, watch};
74use tower::ServiceBuilder;
75use tracing::{
76 Instrument,
77 field::{self},
78 info_span, instrument,
79};
80
/// Reconnection grace period for clients.
///
/// NOTE(review): the code that consumes this value is outside this excerpt — confirm its exact
/// semantics at the use site.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// How long `Server::start` waits before cleaning up resources left by stale servers.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully. Waiting a bit longer than
/// that means the old pods are gone before we clean up their resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

/// Page size for channel-message queries (presumably; the consumer is outside this excerpt).
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Upper bound on a message's length (unit — bytes or chars — to be confirmed at the use site).
const MAX_MESSAGE_LEN: usize = 1024;
/// Page size for notification queries (presumably; the consumer is outside this excerpt).
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
89
90type MessageHandler =
91 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
92
93struct Response<R> {
94 peer: Arc<Peer>,
95 receipt: Receipt<R>,
96 responded: Arc<AtomicBool>,
97}
98
99impl<R: RequestMessage> Response<R> {
100 fn send(self, payload: R::Response) -> Result<()> {
101 self.responded.store(true, SeqCst);
102 self.peer.respond(self.receipt, payload)?;
103 Ok(())
104 }
105}
106
107#[derive(Clone, Debug)]
108pub enum Principal {
109 User(User),
110 Impersonated { user: User, admin: User },
111}
112
113impl Principal {
114 fn user(&self) -> &User {
115 match self {
116 Principal::User(user) => user,
117 Principal::Impersonated { user, .. } => user,
118 }
119 }
120
121 fn update_span(&self, span: &tracing::Span) {
122 match &self {
123 Principal::User(user) => {
124 span.record("user_id", user.id.0);
125 span.record("login", &user.github_login);
126 }
127 Principal::Impersonated { user, admin } => {
128 span.record("user_id", user.id.0);
129 span.record("login", &user.github_login);
130 span.record("impersonator", &admin.github_login);
131 }
132 }
133 }
134}
135
/// Per-connection state passed to every message handler.
///
/// A clone of the session is handed to each dispatched handler (see `handle_connection`).
#[derive(Clone)]
struct Session {
    /// The identity this connection acts as (possibly an admin impersonating a user).
    principal: Principal,
    /// Id of the peer connection this session belongs to.
    connection_id: ConnectionId,
    /// Database handle; handlers reach it through `Session::db`, which locks this mutex.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Registry of live connections shared with the `Server`; accessed via
    /// `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// `Some` only when a Supermaven admin API key is configured.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    /// The GeoIP country code for the user.
    // NOTE(review): the lint allowance suggests this field is not currently read — confirm
    // before removing it.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    system_id: Option<String>,
    _executor: Executor,
}
151
152impl Session {
153 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
154 #[cfg(test)]
155 tokio::task::yield_now().await;
156 let guard = self.db.lock().await;
157 #[cfg(test)]
158 tokio::task::yield_now().await;
159 guard
160 }
161
162 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
163 #[cfg(test)]
164 tokio::task::yield_now().await;
165 let guard = self.connection_pool.lock();
166 ConnectionPoolGuard {
167 guard,
168 _not_send: PhantomData,
169 }
170 }
171
172 fn is_staff(&self) -> bool {
173 match &self.principal {
174 Principal::User(user) => user.admin,
175 Principal::Impersonated { .. } => true,
176 }
177 }
178
179 fn user_id(&self) -> UserId {
180 match &self.principal {
181 Principal::User(user) => user.id,
182 Principal::Impersonated { user, .. } => user.id,
183 }
184 }
185
186 pub fn email(&self) -> Option<String> {
187 match &self.principal {
188 Principal::User(user) => user.email_address.clone(),
189 Principal::Impersonated { user, .. } => user.email_address.clone(),
190 }
191 }
192}
193
194impl Debug for Session {
195 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
196 let mut result = f.debug_struct("Session");
197 match &self.principal {
198 Principal::User(user) => {
199 result.field("user", &user.github_login);
200 }
201 Principal::Impersonated { user, admin } => {
202 result.field("user", &user.github_login);
203 result.field("impersonator", &admin.github_login);
204 }
205 }
206 result.field("connection_id", &self.connection_id).finish()
207 }
208}
209
210struct DbHandle(Arc<Database>);
211
212impl Deref for DbHandle {
213 type Target = Database;
214
215 fn deref(&self) -> &Self::Target {
216 self.0.as_ref()
217 }
218}
219
/// The collaboration RPC server.
///
/// Owns the peer, the shared connection pool, and the message-handler table that is populated
/// in `Server::new`.
pub struct Server {
    /// Wrapped in a mutex so the test-only `reset` can replace it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Handlers keyed by the `TypeId` of the message type they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Set to `true` by `teardown`. Each connection loop subscribes and exits when it changes.
    teardown: watch::Sender<bool>,
}
228
/// Lock guard over the shared `ConnectionPool`.
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send` (because `Rc` is `!Send`). Any
/// future that holds it across an `.await` therefore cannot be `Send`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
233
/// Serializable view of the server: the peer plus the (locked) connection pool.
///
/// The connection pool stays locked for as long as the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialize the pool the guard derefs to, rather than the guard itself.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
240
241pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
242where
243 S: Serializer,
244 T: Deref<Target = U>,
245 U: Serialize,
246{
247 Serialize::serialize(value.deref(), serializer)
248}
249
250impl Server {
251 pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
252 let mut server = Self {
253 id: parking_lot::Mutex::new(id),
254 peer: Peer::new(id.0 as u32),
255 app_state: app_state.clone(),
256 connection_pool: Default::default(),
257 handlers: Default::default(),
258 teardown: watch::channel(false).0,
259 };
260
261 server
262 .add_request_handler(ping)
263 .add_request_handler(create_room)
264 .add_request_handler(join_room)
265 .add_request_handler(rejoin_room)
266 .add_request_handler(leave_room)
267 .add_request_handler(set_room_participant_role)
268 .add_request_handler(call)
269 .add_request_handler(cancel_call)
270 .add_message_handler(decline_call)
271 .add_request_handler(update_participant_location)
272 .add_request_handler(share_project)
273 .add_message_handler(unshare_project)
274 .add_request_handler(join_project)
275 .add_message_handler(leave_project)
276 .add_request_handler(update_project)
277 .add_request_handler(update_worktree)
278 .add_request_handler(update_repository)
279 .add_request_handler(remove_repository)
280 .add_message_handler(start_language_server)
281 .add_message_handler(update_language_server)
282 .add_message_handler(update_diagnostic_summary)
283 .add_message_handler(update_worktree_settings)
284 .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
285 .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
286 .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
287 .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
288 .add_request_handler(forward_find_search_candidates_request)
289 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
290 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
291 .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
292 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
293 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
294 .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
295 .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
296 .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
297 .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
298 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
299 .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
300 .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
301 .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
302 .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
303 .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
304 .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
305 .add_request_handler(
306 forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
307 )
308 .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
309 .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
310 .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
311 .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
312 .add_request_handler(
313 forward_read_only_project_request::<proto::LanguageServerIdForName>,
314 )
315 .add_request_handler(
316 forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
317 )
318 .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
319 .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
320 .add_request_handler(
321 forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
322 )
323 .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
324 .add_request_handler(
325 forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
326 )
327 .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
328 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
329 .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
330 .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
331 .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
332 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
333 .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
334 .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
335 .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
336 .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
337 .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
338 .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
339 .add_request_handler(
340 forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
341 )
342 .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
343 .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
344 .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
345 .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
346 .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
347 .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
348 .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
349 .add_message_handler(create_buffer_for_peer)
350 .add_request_handler(update_buffer)
351 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
352 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
353 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
354 .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
355 .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
356 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
357 .add_request_handler(get_users)
358 .add_request_handler(fuzzy_search_users)
359 .add_request_handler(request_contact)
360 .add_request_handler(remove_contact)
361 .add_request_handler(respond_to_contact_request)
362 .add_message_handler(subscribe_to_channels)
363 .add_request_handler(create_channel)
364 .add_request_handler(delete_channel)
365 .add_request_handler(invite_channel_member)
366 .add_request_handler(remove_channel_member)
367 .add_request_handler(set_channel_member_role)
368 .add_request_handler(set_channel_visibility)
369 .add_request_handler(rename_channel)
370 .add_request_handler(join_channel_buffer)
371 .add_request_handler(leave_channel_buffer)
372 .add_message_handler(update_channel_buffer)
373 .add_request_handler(rejoin_channel_buffers)
374 .add_request_handler(get_channel_members)
375 .add_request_handler(respond_to_channel_invite)
376 .add_request_handler(join_channel)
377 .add_request_handler(join_channel_chat)
378 .add_message_handler(leave_channel_chat)
379 .add_request_handler(send_channel_message)
380 .add_request_handler(remove_channel_message)
381 .add_request_handler(update_channel_message)
382 .add_request_handler(get_channel_messages)
383 .add_request_handler(get_channel_messages_by_id)
384 .add_request_handler(get_notifications)
385 .add_request_handler(mark_notification_as_read)
386 .add_request_handler(move_channel)
387 .add_request_handler(follow)
388 .add_message_handler(unfollow)
389 .add_message_handler(update_followers)
390 .add_request_handler(get_private_user_info)
391 .add_request_handler(get_llm_api_token)
392 .add_request_handler(accept_terms_of_service)
393 .add_message_handler(acknowledge_channel_message)
394 .add_message_handler(acknowledge_buffer_version)
395 .add_request_handler(get_supermaven_api_key)
396 .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
397 .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
398 .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
399 .add_request_handler(forward_mutating_project_request::<proto::Stage>)
400 .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
401 .add_request_handler(forward_mutating_project_request::<proto::Commit>)
402 .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
403 .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
404 .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
405 .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
406 .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
407 .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
408 .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
409 .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
410 .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
411 .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
412 .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
413 .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
414 .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
415 .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
416 .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
417 .add_message_handler(update_context);
418
419 Arc::new(server)
420 }
421
422 pub async fn start(&self) -> Result<()> {
423 let server_id = *self.id.lock();
424 let app_state = self.app_state.clone();
425 let peer = self.peer.clone();
426 let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
427 let pool = self.connection_pool.clone();
428 let livekit_client = self.app_state.livekit_client.clone();
429
430 let span = info_span!("start server");
431 self.app_state.executor.spawn_detached(
432 async move {
433 tracing::info!("waiting for cleanup timeout");
434 timeout.await;
435 tracing::info!("cleanup timeout expired, retrieving stale rooms");
436 if let Some((room_ids, channel_ids)) = app_state
437 .db
438 .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
439 .await
440 .trace_err()
441 {
442 tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
443 tracing::info!(
444 stale_channel_buffer_count = channel_ids.len(),
445 "retrieved stale channel buffers"
446 );
447
448 for channel_id in channel_ids {
449 if let Some(refreshed_channel_buffer) = app_state
450 .db
451 .clear_stale_channel_buffer_collaborators(channel_id, server_id)
452 .await
453 .trace_err()
454 {
455 for connection_id in refreshed_channel_buffer.connection_ids {
456 peer.send(
457 connection_id,
458 proto::UpdateChannelBufferCollaborators {
459 channel_id: channel_id.to_proto(),
460 collaborators: refreshed_channel_buffer
461 .collaborators
462 .clone(),
463 },
464 )
465 .trace_err();
466 }
467 }
468 }
469
470 for room_id in room_ids {
471 let mut contacts_to_update = HashSet::default();
472 let mut canceled_calls_to_user_ids = Vec::new();
473 let mut livekit_room = String::new();
474 let mut delete_livekit_room = false;
475
476 if let Some(mut refreshed_room) = app_state
477 .db
478 .clear_stale_room_participants(room_id, server_id)
479 .await
480 .trace_err()
481 {
482 tracing::info!(
483 room_id = room_id.0,
484 new_participant_count = refreshed_room.room.participants.len(),
485 "refreshed room"
486 );
487 room_updated(&refreshed_room.room, &peer);
488 if let Some(channel) = refreshed_room.channel.as_ref() {
489 channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
490 }
491 contacts_to_update
492 .extend(refreshed_room.stale_participant_user_ids.iter().copied());
493 contacts_to_update
494 .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
495 canceled_calls_to_user_ids =
496 mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
497 livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
498 delete_livekit_room = refreshed_room.room.participants.is_empty();
499 }
500
501 {
502 let pool = pool.lock();
503 for canceled_user_id in canceled_calls_to_user_ids {
504 for connection_id in pool.user_connection_ids(canceled_user_id) {
505 peer.send(
506 connection_id,
507 proto::CallCanceled {
508 room_id: room_id.to_proto(),
509 },
510 )
511 .trace_err();
512 }
513 }
514 }
515
516 for user_id in contacts_to_update {
517 let busy = app_state.db.is_user_busy(user_id).await.trace_err();
518 let contacts = app_state.db.get_contacts(user_id).await.trace_err();
519 if let Some((busy, contacts)) = busy.zip(contacts) {
520 let pool = pool.lock();
521 let updated_contact = contact_for_user(user_id, busy, &pool);
522 for contact in contacts {
523 if let db::Contact::Accepted {
524 user_id: contact_user_id,
525 ..
526 } = contact
527 {
528 for contact_conn_id in
529 pool.user_connection_ids(contact_user_id)
530 {
531 peer.send(
532 contact_conn_id,
533 proto::UpdateContacts {
534 contacts: vec![updated_contact.clone()],
535 remove_contacts: Default::default(),
536 incoming_requests: Default::default(),
537 remove_incoming_requests: Default::default(),
538 outgoing_requests: Default::default(),
539 remove_outgoing_requests: Default::default(),
540 },
541 )
542 .trace_err();
543 }
544 }
545 }
546 }
547 }
548
549 if let Some(live_kit) = livekit_client.as_ref() {
550 if delete_livekit_room {
551 live_kit.delete_room(livekit_room).await.trace_err();
552 }
553 }
554 }
555 }
556
557 app_state
558 .db
559 .delete_stale_servers(&app_state.config.zed_environment, server_id)
560 .await
561 .trace_err();
562 }
563 .instrument(span),
564 );
565 Ok(())
566 }
567
568 pub fn teardown(&self) {
569 self.peer.teardown();
570 self.connection_pool.lock().reset();
571 let _ = self.teardown.send(true);
572 }
573
574 #[cfg(test)]
575 pub fn reset(&self, id: ServerId) {
576 self.teardown();
577 *self.id.lock() = id;
578 self.peer.reset(id.0 as u32);
579 let _ = self.teardown.send(false);
580 }
581
582 #[cfg(test)]
583 pub fn id(&self) -> ServerId {
584 *self.id.lock()
585 }
586
587 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
588 where
589 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
590 Fut: 'static + Send + Future<Output = Result<()>>,
591 M: EnvelopedMessage,
592 {
593 let prev_handler = self.handlers.insert(
594 TypeId::of::<M>(),
595 Box::new(move |envelope, session| {
596 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
597 let received_at = envelope.received_at;
598 tracing::info!("message received");
599 let start_time = Instant::now();
600 let future = (handler)(*envelope, session);
601 async move {
602 let result = future.await;
603 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
604 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
605 let queue_duration_ms = total_duration_ms - processing_duration_ms;
606 let payload_type = M::NAME;
607
608 match result {
609 Err(error) => {
610 tracing::error!(
611 ?error,
612 total_duration_ms,
613 processing_duration_ms,
614 queue_duration_ms,
615 payload_type,
616 "error handling message"
617 )
618 }
619 Ok(()) => tracing::info!(
620 total_duration_ms,
621 processing_duration_ms,
622 queue_duration_ms,
623 "finished handling message"
624 ),
625 }
626 }
627 .boxed()
628 }),
629 );
630 if prev_handler.is_some() {
631 panic!("registered a handler for the same message twice");
632 }
633 self
634 }
635
636 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
637 where
638 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
639 Fut: 'static + Send + Future<Output = Result<()>>,
640 M: EnvelopedMessage,
641 {
642 self.add_handler(move |envelope, session| handler(envelope.payload, session));
643 self
644 }
645
646 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
647 where
648 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
649 Fut: Send + Future<Output = Result<()>>,
650 M: RequestMessage,
651 {
652 let handler = Arc::new(handler);
653 self.add_handler(move |envelope, session| {
654 let receipt = envelope.receipt();
655 let handler = handler.clone();
656 async move {
657 let peer = session.peer.clone();
658 let responded = Arc::new(AtomicBool::default());
659 let response = Response {
660 peer: peer.clone(),
661 responded: responded.clone(),
662 receipt,
663 };
664 match (handler)(envelope.payload, response, session).await {
665 Ok(()) => {
666 if responded.load(std::sync::atomic::Ordering::SeqCst) {
667 Ok(())
668 } else {
669 Err(anyhow!("handler did not send a response"))?
670 }
671 }
672 Err(error) => {
673 let proto_err = match &error {
674 Error::Internal(err) => err.to_proto(),
675 _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
676 };
677 peer.respond_with_error(receipt, proto_err)?;
678 Err(error)
679 }
680 }
681 }
682 })
683 }
684
685 pub fn handle_connection(
686 self: &Arc<Self>,
687 connection: Connection,
688 address: String,
689 principal: Principal,
690 zed_version: ZedVersion,
691 geoip_country_code: Option<String>,
692 system_id: Option<String>,
693 send_connection_id: Option<oneshot::Sender<ConnectionId>>,
694 executor: Executor,
695 ) -> impl Future<Output = ()> + use<> {
696 let this = self.clone();
697 let span = info_span!("handle connection", %address,
698 connection_id=field::Empty,
699 user_id=field::Empty,
700 login=field::Empty,
701 impersonator=field::Empty,
702 geoip_country_code=field::Empty
703 );
704 principal.update_span(&span);
705 if let Some(country_code) = geoip_country_code.as_ref() {
706 span.record("geoip_country_code", country_code);
707 }
708
709 let mut teardown = self.teardown.subscribe();
710 async move {
711 if *teardown.borrow() {
712 tracing::error!("server is tearing down");
713 return
714 }
715 let (connection_id, handle_io, mut incoming_rx) = this
716 .peer
717 .add_connection(connection, {
718 let executor = executor.clone();
719 move |duration| executor.sleep(duration)
720 });
721 tracing::Span::current().record("connection_id", format!("{}", connection_id));
722
723 tracing::info!("connection opened");
724
725 let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
726 let http_client = match ReqwestClient::user_agent(&user_agent) {
727 Ok(http_client) => Arc::new(http_client),
728 Err(error) => {
729 tracing::error!(?error, "failed to create HTTP client");
730 return;
731 }
732 };
733
734 let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
735 supermaven_admin_api_key.to_string(),
736 http_client.clone(),
737 )));
738
739 let session = Session {
740 principal: principal.clone(),
741 connection_id,
742 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
743 peer: this.peer.clone(),
744 connection_pool: this.connection_pool.clone(),
745 app_state: this.app_state.clone(),
746 geoip_country_code,
747 system_id,
748 _executor: executor.clone(),
749 supermaven_client,
750 };
751
752 if let Err(error) = this.send_initial_client_update(connection_id, zed_version, send_connection_id, &session).await {
753 tracing::error!(?error, "failed to send initial client update");
754 return;
755 }
756
757 let handle_io = handle_io.fuse();
758 futures::pin_mut!(handle_io);
759
760 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
761 // This prevents deadlocks when e.g., client A performs a request to client B and
762 // client B performs a request to client A. If both clients stop processing further
763 // messages until their respective request completes, they won't have a chance to
764 // respond to the other client's request and cause a deadlock.
765 //
766 // This arrangement ensures we will attempt to process earlier messages first, but fall
767 // back to processing messages arrived later in the spirit of making progress.
768 let mut foreground_message_handlers = FuturesUnordered::new();
769 let concurrent_handlers = Arc::new(Semaphore::new(256));
770 loop {
771 let next_message = async {
772 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
773 let message = incoming_rx.next().await;
774 (permit, message)
775 }.fuse();
776 futures::pin_mut!(next_message);
777 futures::select_biased! {
778 _ = teardown.changed().fuse() => return,
779 result = handle_io => {
780 if let Err(error) = result {
781 tracing::error!(?error, "error handling I/O");
782 }
783 break;
784 }
785 _ = foreground_message_handlers.next() => {}
786 next_message = next_message => {
787 let (permit, message) = next_message;
788 if let Some(message) = message {
789 let type_name = message.payload_type_name();
790 // note: we copy all the fields from the parent span so we can query them in the logs.
791 // (https://github.com/tokio-rs/tracing/issues/2670).
792 let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
793 user_id=field::Empty,
794 login=field::Empty,
795 impersonator=field::Empty,
796 );
797 principal.update_span(&span);
798 let span_enter = span.enter();
799 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
800 let is_background = message.is_background();
801 let handle_message = (handler)(message, session.clone());
802 drop(span_enter);
803
804 let handle_message = async move {
805 handle_message.await;
806 drop(permit);
807 }.instrument(span);
808 if is_background {
809 executor.spawn_detached(handle_message);
810 } else {
811 foreground_message_handlers.push(handle_message);
812 }
813 } else {
814 tracing::error!("no message handler");
815 }
816 } else {
817 tracing::info!("connection closed");
818 break;
819 }
820 }
821 }
822 }
823
824 drop(foreground_message_handlers);
825 tracing::info!("signing out");
826 if let Err(error) = connection_lost(session, teardown, executor).await {
827 tracing::error!(?error, "error signing out");
828 }
829
830 }.instrument(span)
831 }
832
833 async fn send_initial_client_update(
834 &self,
835 connection_id: ConnectionId,
836 zed_version: ZedVersion,
837 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
838 session: &Session,
839 ) -> Result<()> {
840 self.peer.send(
841 connection_id,
842 proto::Hello {
843 peer_id: Some(connection_id.into()),
844 },
845 )?;
846 tracing::info!("sent hello message");
847 if let Some(send_connection_id) = send_connection_id.take() {
848 let _ = send_connection_id.send(connection_id);
849 }
850
851 match &session.principal {
852 Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
853 if !user.connected_once {
854 self.peer.send(connection_id, proto::ShowContacts {})?;
855 self.app_state
856 .db
857 .set_user_connected_once(user.id, true)
858 .await?;
859 }
860
861 update_user_plan(session).await?;
862
863 let contacts = self.app_state.db.get_contacts(user.id).await?;
864
865 {
866 let mut pool = self.connection_pool.lock();
867 pool.add_connection(connection_id, user.id, user.admin, zed_version);
868 self.peer.send(
869 connection_id,
870 build_initial_contacts_update(contacts, &pool),
871 )?;
872 }
873
874 if should_auto_subscribe_to_channels(zed_version) {
875 subscribe_user_to_channels(user.id, session).await?;
876 }
877
878 if let Some(incoming_call) =
879 self.app_state.db.incoming_call_for_user(user.id).await?
880 {
881 self.peer.send(connection_id, incoming_call)?;
882 }
883
884 update_user_contacts(user.id, session).await?;
885 }
886 }
887
888 Ok(())
889 }
890
891 pub async fn invite_code_redeemed(
892 self: &Arc<Self>,
893 inviter_id: UserId,
894 invitee_id: UserId,
895 ) -> Result<()> {
896 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
897 if let Some(code) = &user.invite_code {
898 let pool = self.connection_pool.lock();
899 let invitee_contact = contact_for_user(invitee_id, false, &pool);
900 for connection_id in pool.user_connection_ids(inviter_id) {
901 self.peer.send(
902 connection_id,
903 proto::UpdateContacts {
904 contacts: vec![invitee_contact.clone()],
905 ..Default::default()
906 },
907 )?;
908 self.peer.send(
909 connection_id,
910 proto::UpdateInviteInfo {
911 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
912 count: user.invite_count as u32,
913 },
914 )?;
915 }
916 }
917 }
918 Ok(())
919 }
920
921 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
922 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
923 if let Some(invite_code) = &user.invite_code {
924 let pool = self.connection_pool.lock();
925 for connection_id in pool.user_connection_ids(user_id) {
926 self.peer.send(
927 connection_id,
928 proto::UpdateInviteInfo {
929 url: format!(
930 "{}{}",
931 self.app_state.config.invite_link_prefix, invite_code
932 ),
933 count: user.invite_count as u32,
934 },
935 )?;
936 }
937 }
938 }
939 Ok(())
940 }
941
942 pub async fn update_plan_for_user(self: &Arc<Self>, user_id: UserId) -> Result<()> {
943 let user = self
944 .app_state
945 .db
946 .get_user_by_id(user_id)
947 .await?
948 .context("user not found")?;
949
950 let update_user_plan = make_update_user_plan_message(
951 &user,
952 user.admin,
953 &self.app_state.db,
954 self.app_state.llm_db.clone(),
955 )
956 .await?;
957
958 let pool = self.connection_pool.lock();
959 for connection_id in pool.user_connection_ids(user_id) {
960 self.peer
961 .send(connection_id, update_user_plan.clone())
962 .trace_err();
963 }
964
965 Ok(())
966 }
967
968 pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
969 let pool = self.connection_pool.lock();
970 for connection_id in pool.user_connection_ids(user_id) {
971 self.peer
972 .send(connection_id, proto::RefreshLlmToken {})
973 .trace_err();
974 }
975 }
976
977 pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot {
978 ServerSnapshot {
979 connection_pool: ConnectionPoolGuard {
980 guard: self.connection_pool.lock(),
981 _not_send: PhantomData,
982 },
983 peer: &self.peer,
984 }
985 }
986}
987
988impl Deref for ConnectionPoolGuard<'_> {
989 type Target = ConnectionPool;
990
991 fn deref(&self) -> &Self::Target {
992 &self.guard
993 }
994}
995
996impl DerefMut for ConnectionPoolGuard<'_> {
997 fn deref_mut(&mut self) -> &mut Self::Target {
998 &mut self.guard
999 }
1000}
1001
/// Validates the pool's invariants in test builds whenever the guard is released,
/// so that any mutation made through the guard gets checked.
impl Drop for ConnectionPoolGuard<'_> {
    fn drop(&mut self) {
        // The call is only compiled under `cfg(test)`; in other builds this drop
        // does nothing beyond dropping the fields.
        #[cfg(test)]
        self.check_invariants();
    }
}
1008
1009fn broadcast<F>(
1010 sender_id: Option<ConnectionId>,
1011 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1012 mut f: F,
1013) where
1014 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1015{
1016 for receiver_id in receiver_ids {
1017 if Some(receiver_id) != sender_id {
1018 if let Err(error) = f(receiver_id) {
1019 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1020 }
1021 }
1022 }
1023}
1024
1025pub struct ProtocolVersion(u32);
1026
1027impl Header for ProtocolVersion {
1028 fn name() -> &'static HeaderName {
1029 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1030 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1031 }
1032
1033 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1034 where
1035 Self: Sized,
1036 I: Iterator<Item = &'i axum::http::HeaderValue>,
1037 {
1038 let version = values
1039 .next()
1040 .ok_or_else(axum::headers::Error::invalid)?
1041 .to_str()
1042 .map_err(|_| axum::headers::Error::invalid())?
1043 .parse()
1044 .map_err(|_| axum::headers::Error::invalid())?;
1045 Ok(Self(version))
1046 }
1047
1048 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1049 values.extend([self.0.to_string().parse().unwrap()]);
1050 }
1051}
1052
1053pub struct AppVersionHeader(SemanticVersion);
1054impl Header for AppVersionHeader {
1055 fn name() -> &'static HeaderName {
1056 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1057 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1058 }
1059
1060 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1061 where
1062 Self: Sized,
1063 I: Iterator<Item = &'i axum::http::HeaderValue>,
1064 {
1065 let version = values
1066 .next()
1067 .ok_or_else(axum::headers::Error::invalid)?
1068 .to_str()
1069 .map_err(|_| axum::headers::Error::invalid())?
1070 .parse()
1071 .map_err(|_| axum::headers::Error::invalid())?;
1072 Ok(Self(version))
1073 }
1074
1075 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1076 values.extend([self.0.to_string().parse().unwrap()]);
1077 }
1078}
1079
/// Builds the collab RPC router.
///
/// `/rpc` serves the websocket upgrade. It is wrapped in a layer that provides the
/// `AppState` extension and runs `auth::validate_header`. `/metrics` is registered
/// after that layer. Because `Router::layer` only wraps routes added before it, the
/// auth middleware does not apply to `/metrics`. The final `Extension(server)` layer
/// covers both routes.
pub fn routes(server: Arc<Server>) -> Router<(), Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        // Applies only to `/rpc` (the routes registered so far).
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        // Applies to both `/rpc` and `/metrics`.
        .layer(Extension(server))
}
1091
1092pub async fn handle_websocket_request(
1093 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1094 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1095 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1096 Extension(server): Extension<Arc<Server>>,
1097 Extension(principal): Extension<Principal>,
1098 country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1099 system_id_header: Option<TypedHeader<SystemIdHeader>>,
1100 ws: WebSocketUpgrade,
1101) -> axum::response::Response {
1102 if protocol_version != rpc::PROTOCOL_VERSION {
1103 return (
1104 StatusCode::UPGRADE_REQUIRED,
1105 "client must be upgraded".to_string(),
1106 )
1107 .into_response();
1108 }
1109
1110 let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1111 return (
1112 StatusCode::UPGRADE_REQUIRED,
1113 "no version header found".to_string(),
1114 )
1115 .into_response();
1116 };
1117
1118 if !version.can_collaborate() {
1119 return (
1120 StatusCode::UPGRADE_REQUIRED,
1121 "client must be upgraded".to_string(),
1122 )
1123 .into_response();
1124 }
1125
1126 let socket_address = socket_address.to_string();
1127 ws.on_upgrade(move |socket| {
1128 let socket = socket
1129 .map_ok(to_tungstenite_message)
1130 .err_into()
1131 .with(|message| async move { to_axum_message(message) });
1132 let connection = Connection::new(Box::pin(socket));
1133 async move {
1134 server
1135 .handle_connection(
1136 connection,
1137 socket_address,
1138 principal,
1139 version,
1140 country_code_header.map(|header| header.to_string()),
1141 system_id_header.map(|header| header.to_string()),
1142 None,
1143 Executor::Production,
1144 )
1145 .await;
1146 }
1147 })
1148}
1149
1150pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1151 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1152 let connections_metric = CONNECTIONS_METRIC
1153 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1154
1155 let connections = server
1156 .connection_pool
1157 .lock()
1158 .connections()
1159 .filter(|connection| !connection.admin)
1160 .count();
1161 connections_metric.set(connections as _);
1162
1163 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1164 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1165 register_int_gauge!(
1166 "shared_projects",
1167 "number of open projects with one or more guests"
1168 )
1169 .unwrap()
1170 });
1171
1172 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1173 shared_projects_metric.set(shared_projects as _);
1174
1175 let encoder = prometheus::TextEncoder::new();
1176 let metric_families = prometheus::gather();
1177 let encoded_metrics = encoder
1178 .encode_to_string(&metric_families)
1179 .map_err(|err| anyhow!("{err}"))?;
1180 Ok(encoded_metrics)
1181}
1182
/// Cleans up after a client connection is lost.
///
/// The following happens immediately:
/// - the connection is disconnected from the peer and removed from the connection
///   pool;
/// - the database is told the connection was lost (a database error here is only
///   logged).
///
/// It then waits for whichever of two futures completes first:
/// - `RECONNECT_TIMEOUT` elapses (presumably the grace period for the client to
///   reconnect — confirm). In that case:
///   - the connection leaves its room and its channel buffers;
///   - if the user has no other connection online, a pending call is declined via
///     `decline_call` and the room is broadcast;
///   - `update_user_contacts` is run for the user.
/// - `teardown` changes. Nothing further is done.
///
/// `select_biased!` polls its arms in the order written, so the timeout arm wins if
/// both are ready. `#[instrument(err)]` records a returned error on the tracing span.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // Best-effort: a failure to record the lost connection must not block cleanup.
    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline a pending call if no other connection of this user is online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1229
1230/// Acknowledges a ping from a client, used to keep the connection alive.
1231async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1232 response.send(proto::Ack {})?;
1233 Ok(())
1234}
1235
1236/// Creates a new room for calling (outside of channels)
1237async fn create_room(
1238 _request: proto::CreateRoom,
1239 response: Response<proto::CreateRoom>,
1240 session: Session,
1241) -> Result<()> {
1242 let livekit_room = nanoid::nanoid!(30);
1243
1244 let live_kit_connection_info = util::maybe!(async {
1245 let live_kit = session.app_state.livekit_client.as_ref();
1246 let live_kit = live_kit?;
1247 let user_id = session.user_id().to_string();
1248
1249 let token = live_kit
1250 .room_token(&livekit_room, &user_id.to_string())
1251 .trace_err()?;
1252
1253 Some(proto::LiveKitConnectionInfo {
1254 server_url: live_kit.url().into(),
1255 token,
1256 can_publish: true,
1257 })
1258 })
1259 .await;
1260
1261 let room = session
1262 .db()
1263 .await
1264 .create_room(session.user_id(), session.connection_id, &livekit_room)
1265 .await?;
1266
1267 response.send(proto::CreateRoomResponse {
1268 room: Some(room.clone()),
1269 live_kit_connection_info,
1270 })?;
1271
1272 update_user_contacts(session.user_id(), &session).await?;
1273 Ok(())
1274}
1275
1276/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1277async fn join_room(
1278 request: proto::JoinRoom,
1279 response: Response<proto::JoinRoom>,
1280 session: Session,
1281) -> Result<()> {
1282 let room_id = RoomId::from_proto(request.id);
1283
1284 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1285
1286 if let Some(channel_id) = channel_id {
1287 return join_channel_internal(channel_id, Box::new(response), session).await;
1288 }
1289
1290 let joined_room = {
1291 let room = session
1292 .db()
1293 .await
1294 .join_room(room_id, session.user_id(), session.connection_id)
1295 .await?;
1296 room_updated(&room.room, &session.peer);
1297 room.into_inner()
1298 };
1299
1300 for connection_id in session
1301 .connection_pool()
1302 .await
1303 .user_connection_ids(session.user_id())
1304 {
1305 session
1306 .peer
1307 .send(
1308 connection_id,
1309 proto::CallCanceled {
1310 room_id: room_id.to_proto(),
1311 },
1312 )
1313 .trace_err();
1314 }
1315
1316 let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1317 {
1318 live_kit
1319 .room_token(
1320 &joined_room.room.livekit_room,
1321 &session.user_id().to_string(),
1322 )
1323 .trace_err()
1324 .map(|token| proto::LiveKitConnectionInfo {
1325 server_url: live_kit.url().into(),
1326 token,
1327 can_publish: true,
1328 })
1329 } else {
1330 None
1331 };
1332
1333 response.send(proto::JoinRoomResponse {
1334 room: Some(joined_room.room),
1335 channel_id: None,
1336 live_kit_connection_info,
1337 })?;
1338
1339 update_user_contacts(session.user_id(), &session).await?;
1340 Ok(())
1341}
1342
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// The steps, in order:
/// 1. The caller's membership is restored in the database.
/// 2. The caller is sent the room, its reshared projects and its rejoined projects.
/// 3. The room update is broadcast.
/// 4. For each reshared project:
///    - every listed collaborator is told the caller's peer id moved from
///      `old_connection_id` to the current connection;
///    - `UpdateProject` (with the project's worktrees) is forwarded to every
///      collaborator except the caller.
/// 5. `notify_rejoined_projects` streams rejoined-project state to the caller.
/// 6. The room's channel (if any) is broadcast and `update_user_contacts` is run.
///
/// Sends to other peers are best-effort (logged). A failed reply to the caller or a
/// failure inside `notify_rejoined_projects` is returned.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    let room;
    let channel;
    // Confine the database result guard to this block; only the extracted room and
    // channel are used by the code after it.
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell each collaborator about the caller's new peer id.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the reshared project's worktree list on behalf of the caller.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1434
/// Brings a rejoining connection up to date on the projects it rejoined.
///
/// First pass: every collaborator of every rejoined project is told the caller's
/// peer id moved from `old_connection_id` to the current connection. These sends
/// are best-effort and logged.
///
/// Second pass, per project, the caller is sent:
/// - for each worktree, its entry changes (as one or more `UpdateWorktree`
///   messages), its diagnostic summaries and its settings files;
/// - the updated repositories;
/// - the removed repository ids.
///
/// Each project's `worktrees`, `updated_repositories` and `removed_repositories` are
/// moved out with `mem::take`, leaving them empty. The first failed send to the
/// caller is returned.
fn notify_rejoined_projects(
    rejoined_projects: &mut Vec<RejoinedProject>,
    session: &Session,
) -> Result<()> {
    // Notify all collaborators before streaming any state to the rejoining caller.
    for project in rejoined_projects.iter() {
        for collaborator in &project.collaborators {
            session
                .peer
                .send(
                    collaborator.connection_id,
                    proto::UpdateProjectCollaborator {
                        project_id: project.id.to_proto(),
                        old_peer_id: Some(project.old_connection_id.into()),
                        new_peer_id: Some(session.connection_id.into()),
                    },
                )
                .trace_err();
        }
    }

    for project in rejoined_projects {
        for worktree in mem::take(&mut project.worktrees) {
            // Stream this worktree's entries.
            let message = proto::UpdateWorktree {
                project_id: project.id.to_proto(),
                worktree_id: worktree.id,
                abs_path: worktree.abs_path.clone(),
                root_name: worktree.root_name,
                updated_entries: worktree.updated_entries,
                removed_entries: worktree.removed_entries,
                scan_id: worktree.scan_id,
                // Presumably: the update is final only once the latest scan has completed.
                is_last_update: worktree.completed_scan_id == worktree.scan_id,
                updated_repositories: worktree.updated_repositories,
                removed_repositories: worktree.removed_repositories,
            };
            for update in proto::split_worktree_update(message) {
                session.peer.send(session.connection_id, update)?;
            }

            // Stream this worktree's diagnostics.
            for summary in worktree.diagnostic_summaries {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateDiagnosticSummary {
                        project_id: project.id.to_proto(),
                        worktree_id: worktree.id,
                        summary: Some(summary),
                    },
                )?;
            }

            // Stream this worktree's settings files.
            for settings_file in worktree.settings_files {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateWorktreeSettings {
                        project_id: project.id.to_proto(),
                        worktree_id: worktree.id,
                        path: settings_file.path,
                        content: Some(settings_file.content),
                        kind: Some(settings_file.kind.to_proto().into()),
                    },
                )?;
            }
        }

        // Stream repository updates, split into one or more messages each.
        for repository in mem::take(&mut project.updated_repositories) {
            for update in split_repository_update(repository) {
                session.peer.send(session.connection_id, update)?;
            }
        }

        for id in mem::take(&mut project.removed_repositories) {
            session.peer.send(
                session.connection_id,
                proto::RemoveRepository {
                    project_id: project.id.to_proto(),
                    id,
                },
            )?;
        }
    }

    Ok(())
}
1519
1520/// leave room disconnects from the room.
1521async fn leave_room(
1522 _: proto::LeaveRoom,
1523 response: Response<proto::LeaveRoom>,
1524 session: Session,
1525) -> Result<()> {
1526 leave_room_for_session(&session, session.connection_id).await?;
1527 response.send(proto::Ack {})?;
1528 Ok(())
1529}
1530
1531/// Updates the permissions of someone else in the room.
1532async fn set_room_participant_role(
1533 request: proto::SetRoomParticipantRole,
1534 response: Response<proto::SetRoomParticipantRole>,
1535 session: Session,
1536) -> Result<()> {
1537 let user_id = UserId::from_proto(request.user_id);
1538 let role = ChannelRole::from(request.role());
1539
1540 let (livekit_room, can_publish) = {
1541 let room = session
1542 .db()
1543 .await
1544 .set_room_participant_role(
1545 session.user_id(),
1546 RoomId::from_proto(request.room_id),
1547 user_id,
1548 role,
1549 )
1550 .await?;
1551
1552 let livekit_room = room.livekit_room.clone();
1553 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1554 room_updated(&room, &session.peer);
1555 (livekit_room, can_publish)
1556 };
1557
1558 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1559 live_kit
1560 .update_participant(
1561 livekit_room.clone(),
1562 request.user_id.to_string(),
1563 livekit_api::proto::ParticipantPermission {
1564 can_subscribe: true,
1565 can_publish,
1566 can_publish_data: can_publish,
1567 hidden: false,
1568 recorder: false,
1569 },
1570 )
1571 .await
1572 .trace_err();
1573 }
1574
1575 response.send(proto::Ack {})?;
1576 Ok(())
1577}
1578
1579/// Call someone else into the current room
1580async fn call(
1581 request: proto::Call,
1582 response: Response<proto::Call>,
1583 session: Session,
1584) -> Result<()> {
1585 let room_id = RoomId::from_proto(request.room_id);
1586 let calling_user_id = session.user_id();
1587 let calling_connection_id = session.connection_id;
1588 let called_user_id = UserId::from_proto(request.called_user_id);
1589 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1590 if !session
1591 .db()
1592 .await
1593 .has_contact(calling_user_id, called_user_id)
1594 .await?
1595 {
1596 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1597 }
1598
1599 let incoming_call = {
1600 let (room, incoming_call) = &mut *session
1601 .db()
1602 .await
1603 .call(
1604 room_id,
1605 calling_user_id,
1606 calling_connection_id,
1607 called_user_id,
1608 initial_project_id,
1609 )
1610 .await?;
1611 room_updated(room, &session.peer);
1612 mem::take(incoming_call)
1613 };
1614 update_user_contacts(called_user_id, &session).await?;
1615
1616 let mut calls = session
1617 .connection_pool()
1618 .await
1619 .user_connection_ids(called_user_id)
1620 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1621 .collect::<FuturesUnordered<_>>();
1622
1623 while let Some(call_response) = calls.next().await {
1624 match call_response.as_ref() {
1625 Ok(_) => {
1626 response.send(proto::Ack {})?;
1627 return Ok(());
1628 }
1629 Err(_) => {
1630 call_response.trace_err();
1631 }
1632 }
1633 }
1634
1635 {
1636 let room = session
1637 .db()
1638 .await
1639 .call_failed(room_id, called_user_id)
1640 .await?;
1641 room_updated(&room, &session.peer);
1642 }
1643 update_user_contacts(called_user_id, &session).await?;
1644
1645 Err(anyhow!("failed to ring user"))?
1646}
1647
1648/// Cancel an outgoing call.
1649async fn cancel_call(
1650 request: proto::CancelCall,
1651 response: Response<proto::CancelCall>,
1652 session: Session,
1653) -> Result<()> {
1654 let called_user_id = UserId::from_proto(request.called_user_id);
1655 let room_id = RoomId::from_proto(request.room_id);
1656 {
1657 let room = session
1658 .db()
1659 .await
1660 .cancel_call(room_id, session.connection_id, called_user_id)
1661 .await?;
1662 room_updated(&room, &session.peer);
1663 }
1664
1665 for connection_id in session
1666 .connection_pool()
1667 .await
1668 .user_connection_ids(called_user_id)
1669 {
1670 session
1671 .peer
1672 .send(
1673 connection_id,
1674 proto::CallCanceled {
1675 room_id: room_id.to_proto(),
1676 },
1677 )
1678 .trace_err();
1679 }
1680 response.send(proto::Ack {})?;
1681
1682 update_user_contacts(called_user_id, &session).await?;
1683 Ok(())
1684}
1685
1686/// Decline an incoming call.
1687async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1688 let room_id = RoomId::from_proto(message.room_id);
1689 {
1690 let room = session
1691 .db()
1692 .await
1693 .decline_call(Some(room_id), session.user_id())
1694 .await?
1695 .context("declining call")?;
1696 room_updated(&room, &session.peer);
1697 }
1698
1699 for connection_id in session
1700 .connection_pool()
1701 .await
1702 .user_connection_ids(session.user_id())
1703 {
1704 session
1705 .peer
1706 .send(
1707 connection_id,
1708 proto::CallCanceled {
1709 room_id: room_id.to_proto(),
1710 },
1711 )
1712 .trace_err();
1713 }
1714 update_user_contacts(session.user_id(), &session).await?;
1715 Ok(())
1716}
1717
1718/// Updates other participants in the room with your current location.
1719async fn update_participant_location(
1720 request: proto::UpdateParticipantLocation,
1721 response: Response<proto::UpdateParticipantLocation>,
1722 session: Session,
1723) -> Result<()> {
1724 let room_id = RoomId::from_proto(request.room_id);
1725 let location = request.location.context("invalid location")?;
1726
1727 let db = session.db().await;
1728 let room = db
1729 .update_room_participant_location(room_id, session.connection_id, location)
1730 .await?;
1731
1732 room_updated(&room, &session.peer);
1733 response.send(proto::Ack {})?;
1734 Ok(())
1735}
1736
1737/// Share a project into the room.
1738async fn share_project(
1739 request: proto::ShareProject,
1740 response: Response<proto::ShareProject>,
1741 session: Session,
1742) -> Result<()> {
1743 let (project_id, room) = &*session
1744 .db()
1745 .await
1746 .share_project(
1747 RoomId::from_proto(request.room_id),
1748 session.connection_id,
1749 &request.worktrees,
1750 request.is_ssh_project,
1751 )
1752 .await?;
1753 response.send(proto::ShareProjectResponse {
1754 project_id: project_id.to_proto(),
1755 })?;
1756 room_updated(room, &session.peer);
1757
1758 Ok(())
1759}
1760
1761/// Unshare a project from the room.
1762async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1763 let project_id = ProjectId::from_proto(message.project_id);
1764 unshare_project_internal(project_id, session.connection_id, &session).await
1765}
1766
1767async fn unshare_project_internal(
1768 project_id: ProjectId,
1769 connection_id: ConnectionId,
1770 session: &Session,
1771) -> Result<()> {
1772 let delete = {
1773 let room_guard = session
1774 .db()
1775 .await
1776 .unshare_project(project_id, connection_id)
1777 .await?;
1778
1779 let (delete, room, guest_connection_ids) = &*room_guard;
1780
1781 let message = proto::UnshareProject {
1782 project_id: project_id.to_proto(),
1783 };
1784
1785 broadcast(
1786 Some(connection_id),
1787 guest_connection_ids.iter().copied(),
1788 |conn_id| session.peer.send(conn_id, message.clone()),
1789 );
1790 if let Some(room) = room {
1791 room_updated(room, &session.peer);
1792 }
1793
1794 *delete
1795 };
1796
1797 if delete {
1798 let db = session.db().await;
1799 db.delete_project(project_id).await?;
1800 }
1801
1802 Ok(())
1803}
1804
/// Join someone else's shared project.
///
/// The connection is added to the project in the database. The handler then
/// delegates to `join_project_internal`, which notifies the existing collaborators
/// and streams the project state to this connection.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: Session,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    // Release the database handle before streaming the project to the guest; the
    // joined project data stays alive in the result guard bound above.
    drop(db);
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
1823
/// Sink for the `JoinProjectResponse` produced by `join_project_internal`, so that
/// function does not depend on a concrete `Response` type.
trait JoinProjectInternalResponse {
    /// Delivers the join response to the requester.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
/// Delivers the join response as the reply to a `JoinProject` request.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // The fully qualified path makes explicit that this forwards to `Response`'s
        // own `send` rather than to this trait method.
        Response::<proto::JoinProject>::send(self, result)
    }
}
1832
1833fn join_project_internal(
1834 response: impl JoinProjectInternalResponse,
1835 session: Session,
1836 project: &mut Project,
1837 replica_id: &ReplicaId,
1838) -> Result<()> {
1839 let collaborators = project
1840 .collaborators
1841 .iter()
1842 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1843 .map(|collaborator| collaborator.to_proto())
1844 .collect::<Vec<_>>();
1845 let project_id = project.id;
1846 let guest_user_id = session.user_id();
1847
1848 let worktrees = project
1849 .worktrees
1850 .iter()
1851 .map(|(id, worktree)| proto::WorktreeMetadata {
1852 id: *id,
1853 root_name: worktree.root_name.clone(),
1854 visible: worktree.visible,
1855 abs_path: worktree.abs_path.clone(),
1856 })
1857 .collect::<Vec<_>>();
1858
1859 let add_project_collaborator = proto::AddProjectCollaborator {
1860 project_id: project_id.to_proto(),
1861 collaborator: Some(proto::Collaborator {
1862 peer_id: Some(session.connection_id.into()),
1863 replica_id: replica_id.0 as u32,
1864 user_id: guest_user_id.to_proto(),
1865 is_host: false,
1866 }),
1867 };
1868
1869 for collaborator in &collaborators {
1870 session
1871 .peer
1872 .send(
1873 collaborator.peer_id.unwrap().into(),
1874 add_project_collaborator.clone(),
1875 )
1876 .trace_err();
1877 }
1878
1879 // First, we send the metadata associated with each worktree.
1880 response.send(proto::JoinProjectResponse {
1881 project_id: project.id.0 as u64,
1882 worktrees: worktrees.clone(),
1883 replica_id: replica_id.0 as u32,
1884 collaborators: collaborators.clone(),
1885 language_servers: project.language_servers.clone(),
1886 role: project.role.into(),
1887 })?;
1888
1889 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1890 // Stream this worktree's entries.
1891 let message = proto::UpdateWorktree {
1892 project_id: project_id.to_proto(),
1893 worktree_id,
1894 abs_path: worktree.abs_path.clone(),
1895 root_name: worktree.root_name,
1896 updated_entries: worktree.entries,
1897 removed_entries: Default::default(),
1898 scan_id: worktree.scan_id,
1899 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1900 updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1901 removed_repositories: Default::default(),
1902 };
1903 for update in proto::split_worktree_update(message) {
1904 session.peer.send(session.connection_id, update.clone())?;
1905 }
1906
1907 // Stream this worktree's diagnostics.
1908 for summary in worktree.diagnostic_summaries {
1909 session.peer.send(
1910 session.connection_id,
1911 proto::UpdateDiagnosticSummary {
1912 project_id: project_id.to_proto(),
1913 worktree_id: worktree.id,
1914 summary: Some(summary),
1915 },
1916 )?;
1917 }
1918
1919 for settings_file in worktree.settings_files {
1920 session.peer.send(
1921 session.connection_id,
1922 proto::UpdateWorktreeSettings {
1923 project_id: project_id.to_proto(),
1924 worktree_id: worktree.id,
1925 path: settings_file.path,
1926 content: Some(settings_file.content),
1927 kind: Some(settings_file.kind.to_proto() as i32),
1928 },
1929 )?;
1930 }
1931 }
1932
1933 for repository in mem::take(&mut project.repositories) {
1934 for update in split_repository_update(repository) {
1935 session.peer.send(session.connection_id, update)?;
1936 }
1937 }
1938
1939 for language_server in &project.language_servers {
1940 session.peer.send(
1941 session.connection_id,
1942 proto::UpdateLanguageServer {
1943 project_id: project_id.to_proto(),
1944 language_server_id: language_server.id,
1945 variant: Some(
1946 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1947 proto::LspDiskBasedDiagnosticsUpdated {},
1948 ),
1949 ),
1950 },
1951 )?;
1952 }
1953
1954 Ok(())
1955}
1956
1957/// Leave someone elses shared project.
1958async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1959 let sender_id = session.connection_id;
1960 let project_id = ProjectId::from_proto(request.project_id);
1961 let db = session.db().await;
1962
1963 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1964 tracing::info!(
1965 %project_id,
1966 "leave project"
1967 );
1968
1969 project_left(project, &session);
1970 if let Some(room) = room {
1971 room_updated(room, &session.peer);
1972 }
1973
1974 Ok(())
1975}
1976
1977/// Updates other participants with changes to the project
1978async fn update_project(
1979 request: proto::UpdateProject,
1980 response: Response<proto::UpdateProject>,
1981 session: Session,
1982) -> Result<()> {
1983 let project_id = ProjectId::from_proto(request.project_id);
1984 let (room, guest_connection_ids) = &*session
1985 .db()
1986 .await
1987 .update_project(project_id, session.connection_id, &request.worktrees)
1988 .await?;
1989 broadcast(
1990 Some(session.connection_id),
1991 guest_connection_ids.iter().copied(),
1992 |connection_id| {
1993 session
1994 .peer
1995 .forward_send(session.connection_id, connection_id, request.clone())
1996 },
1997 );
1998 if let Some(room) = room {
1999 room_updated(room, &session.peer);
2000 }
2001 response.send(proto::Ack {})?;
2002
2003 Ok(())
2004}
2005
2006/// Updates other participants with changes to the worktree
2007async fn update_worktree(
2008 request: proto::UpdateWorktree,
2009 response: Response<proto::UpdateWorktree>,
2010 session: Session,
2011) -> Result<()> {
2012 let guest_connection_ids = session
2013 .db()
2014 .await
2015 .update_worktree(&request, session.connection_id)
2016 .await?;
2017
2018 broadcast(
2019 Some(session.connection_id),
2020 guest_connection_ids.iter().copied(),
2021 |connection_id| {
2022 session
2023 .peer
2024 .forward_send(session.connection_id, connection_id, request.clone())
2025 },
2026 );
2027 response.send(proto::Ack {})?;
2028 Ok(())
2029}
2030
2031async fn update_repository(
2032 request: proto::UpdateRepository,
2033 response: Response<proto::UpdateRepository>,
2034 session: Session,
2035) -> Result<()> {
2036 let guest_connection_ids = session
2037 .db()
2038 .await
2039 .update_repository(&request, session.connection_id)
2040 .await?;
2041
2042 broadcast(
2043 Some(session.connection_id),
2044 guest_connection_ids.iter().copied(),
2045 |connection_id| {
2046 session
2047 .peer
2048 .forward_send(session.connection_id, connection_id, request.clone())
2049 },
2050 );
2051 response.send(proto::Ack {})?;
2052 Ok(())
2053}
2054
2055async fn remove_repository(
2056 request: proto::RemoveRepository,
2057 response: Response<proto::RemoveRepository>,
2058 session: Session,
2059) -> Result<()> {
2060 let guest_connection_ids = session
2061 .db()
2062 .await
2063 .remove_repository(&request, session.connection_id)
2064 .await?;
2065
2066 broadcast(
2067 Some(session.connection_id),
2068 guest_connection_ids.iter().copied(),
2069 |connection_id| {
2070 session
2071 .peer
2072 .forward_send(session.connection_id, connection_id, request.clone())
2073 },
2074 );
2075 response.send(proto::Ack {})?;
2076 Ok(())
2077}
2078
2079/// Updates other participants with changes to the diagnostics
2080async fn update_diagnostic_summary(
2081 message: proto::UpdateDiagnosticSummary,
2082 session: Session,
2083) -> Result<()> {
2084 let guest_connection_ids = session
2085 .db()
2086 .await
2087 .update_diagnostic_summary(&message, session.connection_id)
2088 .await?;
2089
2090 broadcast(
2091 Some(session.connection_id),
2092 guest_connection_ids.iter().copied(),
2093 |connection_id| {
2094 session
2095 .peer
2096 .forward_send(session.connection_id, connection_id, message.clone())
2097 },
2098 );
2099
2100 Ok(())
2101}
2102
2103/// Updates other participants with changes to the worktree settings
2104async fn update_worktree_settings(
2105 message: proto::UpdateWorktreeSettings,
2106 session: Session,
2107) -> Result<()> {
2108 let guest_connection_ids = session
2109 .db()
2110 .await
2111 .update_worktree_settings(&message, session.connection_id)
2112 .await?;
2113
2114 broadcast(
2115 Some(session.connection_id),
2116 guest_connection_ids.iter().copied(),
2117 |connection_id| {
2118 session
2119 .peer
2120 .forward_send(session.connection_id, connection_id, message.clone())
2121 },
2122 );
2123
2124 Ok(())
2125}
2126
2127/// Notify other participants that a language server has started.
2128async fn start_language_server(
2129 request: proto::StartLanguageServer,
2130 session: Session,
2131) -> Result<()> {
2132 let guest_connection_ids = session
2133 .db()
2134 .await
2135 .start_language_server(&request, session.connection_id)
2136 .await?;
2137
2138 broadcast(
2139 Some(session.connection_id),
2140 guest_connection_ids.iter().copied(),
2141 |connection_id| {
2142 session
2143 .peer
2144 .forward_send(session.connection_id, connection_id, request.clone())
2145 },
2146 );
2147 Ok(())
2148}
2149
2150/// Notify other participants that a language server has changed.
2151async fn update_language_server(
2152 request: proto::UpdateLanguageServer,
2153 session: Session,
2154) -> Result<()> {
2155 let project_id = ProjectId::from_proto(request.project_id);
2156 let project_connection_ids = session
2157 .db()
2158 .await
2159 .project_connection_ids(project_id, session.connection_id, true)
2160 .await?;
2161 broadcast(
2162 Some(session.connection_id),
2163 project_connection_ids.iter().copied(),
2164 |connection_id| {
2165 session
2166 .peer
2167 .forward_send(session.connection_id, connection_id, request.clone())
2168 },
2169 );
2170 Ok(())
2171}
2172
2173/// forward a project request to the host. These requests should be read only
2174/// as guests are allowed to send them.
2175async fn forward_read_only_project_request<T>(
2176 request: T,
2177 response: Response<T>,
2178 session: Session,
2179) -> Result<()>
2180where
2181 T: EntityMessage + RequestMessage,
2182{
2183 let project_id = ProjectId::from_proto(request.remote_entity_id());
2184 let host_connection_id = session
2185 .db()
2186 .await
2187 .host_for_read_only_project_request(project_id, session.connection_id)
2188 .await?;
2189 let payload = session
2190 .peer
2191 .forward_request(session.connection_id, host_connection_id, request)
2192 .await?;
2193 response.send(payload)?;
2194 Ok(())
2195}
2196
2197async fn forward_find_search_candidates_request(
2198 request: proto::FindSearchCandidates,
2199 response: Response<proto::FindSearchCandidates>,
2200 session: Session,
2201) -> Result<()> {
2202 let project_id = ProjectId::from_proto(request.remote_entity_id());
2203 let host_connection_id = session
2204 .db()
2205 .await
2206 .host_for_read_only_project_request(project_id, session.connection_id)
2207 .await?;
2208 let payload = session
2209 .peer
2210 .forward_request(session.connection_id, host_connection_id, request)
2211 .await?;
2212 response.send(payload)?;
2213 Ok(())
2214}
2215
2216/// forward a project request to the host. These requests are disallowed
2217/// for guests.
2218async fn forward_mutating_project_request<T>(
2219 request: T,
2220 response: Response<T>,
2221 session: Session,
2222) -> Result<()>
2223where
2224 T: EntityMessage + RequestMessage,
2225{
2226 let project_id = ProjectId::from_proto(request.remote_entity_id());
2227
2228 let host_connection_id = session
2229 .db()
2230 .await
2231 .host_for_mutating_project_request(project_id, session.connection_id)
2232 .await?;
2233 let payload = session
2234 .peer
2235 .forward_request(session.connection_id, host_connection_id, request)
2236 .await?;
2237 response.send(payload)?;
2238 Ok(())
2239}
2240
2241/// Notify other participants that a new buffer has been created
2242async fn create_buffer_for_peer(
2243 request: proto::CreateBufferForPeer,
2244 session: Session,
2245) -> Result<()> {
2246 session
2247 .db()
2248 .await
2249 .check_user_is_project_host(
2250 ProjectId::from_proto(request.project_id),
2251 session.connection_id,
2252 )
2253 .await?;
2254 let peer_id = request.peer_id.context("invalid peer id")?;
2255 session
2256 .peer
2257 .forward_send(session.connection_id, peer_id.into(), request)?;
2258 Ok(())
2259}
2260
/// Notify other participants that a buffer has been updated. This is
/// allowed for guests as long as the update is limited to selections.
///
/// The required capability is derived from the operations. If every operation
/// is empty or a selection update, `ReadOnly` is requested; any other operation
/// requires `ReadWrite`. `connections_for_buffer_update` is given that
/// capability and returns the host connection plus the guest connections.
///
/// Guests get the update as a one-way `forward_send`. The host, unless it is
/// the sender, gets it as a request, and that reply is awaited before the
/// sender is acknowledged.
async fn update_buffer(
    request: proto::UpdateBuffer,
    response: Response<proto::UpdateBuffer>,
    session: Session,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    let mut capability = Capability::ReadOnly;

    // Escalate to ReadWrite as soon as any operation is more than a selection change.
    for op in request.operations.iter() {
        match op.variant {
            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
            Some(_) => capability = Capability::ReadWrite,
        }
    }

    // `guard` is scoped to this block so it is dropped before the host
    // request below is awaited.
    let host = {
        let guard = session
            .db()
            .await
            .connections_for_buffer_update(project_id, session.connection_id, capability)
            .await?;

        let (host, guests) = &*guard;

        broadcast(
            Some(session.connection_id),
            guests.clone(),
            |connection_id| {
                session
                    .peer
                    .forward_send(session.connection_id, connection_id, request.clone())
            },
        );

        *host
    };

    // Send to the host as a request so the sender's ack is only returned after
    // the host has responded.
    if host != session.connection_id {
        session
            .peer
            .forward_request(session.connection_id, host, request.clone())
            .await?;
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2310
/// Forwards an `UpdateContext` message to every other connection of the
/// project, including the host.
///
/// The capability checked by `connections_for_buffer_update` depends on the
/// operation:
/// - a buffer operation that is empty or a selection update needs `ReadOnly`;
/// - a buffer operation without an inner operation needs `ReadWrite`;
/// - any other context operation needs `ReadWrite`;
/// - an operation with no variant needs `ReadOnly`.
///
/// Returns an "invalid operation" error if `message.operation` is missing.
async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
    let project_id = ProjectId::from_proto(message.project_id);

    let operation = message.operation.as_ref().context("invalid operation")?;
    let capability = match operation.variant.as_ref() {
        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
            if let Some(buffer_op) = buffer_op.operation.as_ref() {
                match buffer_op.variant {
                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
                        Capability::ReadOnly
                    }
                    _ => Capability::ReadWrite,
                }
            } else {
                Capability::ReadWrite
            }
        }
        Some(_) => Capability::ReadWrite,
        None => Capability::ReadOnly,
    };

    let guard = session
        .db()
        .await
        .connections_for_buffer_update(project_id, session.connection_id, capability)
        .await?;

    let (host, guests) = &*guard;

    // Unlike buffer updates, the host is just another one-way recipient here.
    broadcast(
        Some(session.connection_id),
        guests.iter().chain([host]).copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, message.clone())
        },
    );

    Ok(())
}
2352
2353/// Notify other participants that a project has been updated.
2354async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2355 request: T,
2356 session: Session,
2357) -> Result<()> {
2358 let project_id = ProjectId::from_proto(request.remote_entity_id());
2359 let project_connection_ids = session
2360 .db()
2361 .await
2362 .project_connection_ids(project_id, session.connection_id, false)
2363 .await?;
2364
2365 broadcast(
2366 Some(session.connection_id),
2367 project_connection_ids.iter().copied(),
2368 |connection_id| {
2369 session
2370 .peer
2371 .forward_send(session.connection_id, connection_id, request.clone())
2372 },
2373 );
2374 Ok(())
2375}
2376
/// Start following another user in a call.
///
/// After `check_room_participants` succeeds for the leader and the caller in
/// `request.room_id`, the request is forwarded to the leader and the leader's
/// reply is sent back to the caller. If `request.project_id` is set, the follow
/// is then recorded in the database and the updated room is broadcast via
/// `room_updated`.
///
/// The reply is sent before the follow is recorded. A database error at that
/// point is therefore returned from the handler after the caller has already
/// received a successful response.
async fn follow(
    request: proto::Follow,
    response: Response<proto::Follow>,
    session: Session,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let project_id = request.project_id.map(ProjectId::from_proto);
    let leader_id = request.leader_id.context("invalid leader id")?.into();
    let follower_id = session.connection_id;

    session
        .db()
        .await
        .check_room_participants(room_id, leader_id, session.connection_id)
        .await?;

    let response_payload = session
        .peer
        .forward_request(session.connection_id, leader_id, request)
        .await?;
    response.send(response_payload)?;

    if let Some(project_id) = project_id {
        let room = session
            .db()
            .await
            .follow(room_id, project_id, leader_id, follower_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    Ok(())
}
2411
/// Stop following another user in a call.
///
/// After `check_room_participants` succeeds for the leader and the caller in
/// `request.room_id`, the message is forwarded one-way to the leader. If
/// `request.project_id` is set, the unfollow is then recorded in the database
/// and the updated room is broadcast via `room_updated`.
async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let project_id = request.project_id.map(ProjectId::from_proto);
    let leader_id = request.leader_id.context("invalid leader id")?.into();
    let follower_id = session.connection_id;

    session
        .db()
        .await
        .check_room_participants(room_id, leader_id, session.connection_id)
        .await?;

    session
        .peer
        .forward_send(session.connection_id, leader_id, request)?;

    if let Some(project_id) = project_id {
        let room = session
            .db()
            .await
            .unfollow(room_id, project_id, leader_id, follower_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    Ok(())
}
2440
/// Notify everyone following you of your current location.
///
/// Recipients are the connections of `request.project_id` when it is set, and
/// otherwise the connections of `request.room_id`. The sender is always
/// skipped. For `UpdateView` variants, the view's `leader_id` peer is skipped
/// as well.
///
/// NOTE(review): this handler does not use the `broadcast` helper. A single
/// failed `forward_send` aborts the loop via `?`, so the remaining recipients
/// are not updated. Confirm this is intended. It also locks `session.db`
/// directly rather than through `session.db()`, and the guard is held for the
/// whole function, including while sending.
async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let database = session.db.lock().await;

    let connection_ids = if let Some(project_id) = request.project_id {
        let project_id = ProjectId::from_proto(project_id);
        database
            .project_connection_ids(project_id, session.connection_id, true)
            .await?
    } else {
        database
            .room_connection_ids(room_id, session.connection_id)
            .await?
    };

    // For now, don't send view update messages back to that view's current leader.
    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
        _ => None,
    });

    for connection_id in connection_ids.iter().cloned() {
        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())?;
        }
    }
    Ok(())
}
2472
2473/// Get public data about users.
2474async fn get_users(
2475 request: proto::GetUsers,
2476 response: Response<proto::GetUsers>,
2477 session: Session,
2478) -> Result<()> {
2479 let user_ids = request
2480 .user_ids
2481 .into_iter()
2482 .map(UserId::from_proto)
2483 .collect();
2484 let users = session
2485 .db()
2486 .await
2487 .get_users_by_ids(user_ids)
2488 .await?
2489 .into_iter()
2490 .map(|user| proto::User {
2491 id: user.id.to_proto(),
2492 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2493 github_login: user.github_login,
2494 email: user.email_address,
2495 name: user.name,
2496 })
2497 .collect();
2498 response.send(proto::UsersResponse { users })?;
2499 Ok(())
2500}
2501
2502/// Search for users (to invite) buy Github login
2503async fn fuzzy_search_users(
2504 request: proto::FuzzySearchUsers,
2505 response: Response<proto::FuzzySearchUsers>,
2506 session: Session,
2507) -> Result<()> {
2508 let query = request.query;
2509 let users = match query.len() {
2510 0 => vec![],
2511 1 | 2 => session
2512 .db()
2513 .await
2514 .get_user_by_github_login(&query)
2515 .await?
2516 .into_iter()
2517 .collect(),
2518 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2519 };
2520 let users = users
2521 .into_iter()
2522 .filter(|user| user.id != session.user_id())
2523 .map(|user| proto::User {
2524 id: user.id.to_proto(),
2525 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2526 github_login: user.github_login,
2527 name: user.name,
2528 email: user.email_address,
2529 })
2530 .collect();
2531 response.send(proto::UsersResponse { users })?;
2532 Ok(())
2533}
2534
2535/// Send a contact request to another user.
2536async fn request_contact(
2537 request: proto::RequestContact,
2538 response: Response<proto::RequestContact>,
2539 session: Session,
2540) -> Result<()> {
2541 let requester_id = session.user_id();
2542 let responder_id = UserId::from_proto(request.responder_id);
2543 if requester_id == responder_id {
2544 return Err(anyhow!("cannot add yourself as a contact"))?;
2545 }
2546
2547 let notifications = session
2548 .db()
2549 .await
2550 .send_contact_request(requester_id, responder_id)
2551 .await?;
2552
2553 // Update outgoing contact requests of requester
2554 let mut update = proto::UpdateContacts::default();
2555 update.outgoing_requests.push(responder_id.to_proto());
2556 for connection_id in session
2557 .connection_pool()
2558 .await
2559 .user_connection_ids(requester_id)
2560 {
2561 session.peer.send(connection_id, update.clone())?;
2562 }
2563
2564 // Update incoming contact requests of responder
2565 let mut update = proto::UpdateContacts::default();
2566 update
2567 .incoming_requests
2568 .push(proto::IncomingContactRequest {
2569 requester_id: requester_id.to_proto(),
2570 });
2571 let connection_pool = session.connection_pool().await;
2572 for connection_id in connection_pool.user_connection_ids(responder_id) {
2573 session.peer.send(connection_id, update.clone())?;
2574 }
2575
2576 send_notifications(&connection_pool, &session.peer, notifications);
2577
2578 response.send(proto::Ack {})?;
2579 Ok(())
2580}
2581
/// Accept or decline a contact request
///
/// `request.response` selects the action:
/// - `Dismiss`: only dismisses the contact notification for the responder. No
///   contact updates are pushed.
/// - otherwise: records the response (accepted only when it is `Accept`, so any
///   other value is treated as a decline). Both users' connections are then sent
///   an `UpdateContacts` that removes the pending incoming/outgoing request and,
///   when accepted, adds the other user as a contact together with that user's
///   current busy status. Finally, any notifications from the database are sent.
///
/// The database guard `db` is held for the whole function, including while the
/// connection pool is locked and messages are sent.
async fn respond_to_contact_request(
    request: proto::RespondToContactRequest,
    response: Response<proto::RespondToContactRequest>,
    session: Session,
) -> Result<()> {
    let responder_id = session.user_id();
    let requester_id = UserId::from_proto(request.requester_id);
    let db = session.db().await;
    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
        db.dismiss_contact_notification(responder_id, requester_id)
            .await?;
    } else {
        let accept = request.response == proto::ContactRequestResponse::Accept as i32;

        let notifications = db
            .respond_to_contact_request(responder_id, requester_id, accept)
            .await?;
        // Busy status is included in the contact entries sent below.
        let requester_busy = db.is_user_busy(requester_id).await?;
        let responder_busy = db.is_user_busy(responder_id).await?;

        let pool = session.connection_pool().await;
        // Update responder with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(requester_id, requester_busy, &pool));
        }
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
        for connection_id in pool.user_connection_ids(responder_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        // Update requester with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(responder_id, responder_busy, &pool));
        }
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());

        for connection_id in pool.user_connection_ids(requester_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        send_notifications(&pool, &session.peer, notifications);
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2639
2640/// Remove a contact.
2641async fn remove_contact(
2642 request: proto::RemoveContact,
2643 response: Response<proto::RemoveContact>,
2644 session: Session,
2645) -> Result<()> {
2646 let requester_id = session.user_id();
2647 let responder_id = UserId::from_proto(request.user_id);
2648 let db = session.db().await;
2649 let (contact_accepted, deleted_notification_id) =
2650 db.remove_contact(requester_id, responder_id).await?;
2651
2652 let pool = session.connection_pool().await;
2653 // Update outgoing contact requests of requester
2654 let mut update = proto::UpdateContacts::default();
2655 if contact_accepted {
2656 update.remove_contacts.push(responder_id.to_proto());
2657 } else {
2658 update
2659 .remove_outgoing_requests
2660 .push(responder_id.to_proto());
2661 }
2662 for connection_id in pool.user_connection_ids(requester_id) {
2663 session.peer.send(connection_id, update.clone())?;
2664 }
2665
2666 // Update incoming contact requests of responder
2667 let mut update = proto::UpdateContacts::default();
2668 if contact_accepted {
2669 update.remove_contacts.push(requester_id.to_proto());
2670 } else {
2671 update
2672 .remove_incoming_requests
2673 .push(requester_id.to_proto());
2674 }
2675 for connection_id in pool.user_connection_ids(responder_id) {
2676 session.peer.send(connection_id, update.clone())?;
2677 if let Some(notification_id) = deleted_notification_id {
2678 session.peer.send(
2679 connection_id,
2680 proto::DeleteNotification {
2681 notification_id: notification_id.to_proto(),
2682 },
2683 )?;
2684 }
2685 }
2686
2687 response.send(proto::Ack {})?;
2688 Ok(())
2689}
2690
2691fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2692 version.0.minor() < 139
2693}
2694
2695async fn current_plan(db: &Arc<Database>, user_id: UserId, is_staff: bool) -> Result<proto::Plan> {
2696 if is_staff {
2697 return Ok(proto::Plan::ZedPro);
2698 }
2699
2700 let subscription = db.get_active_billing_subscription(user_id).await?;
2701 let subscription_kind = subscription.and_then(|subscription| subscription.kind);
2702
2703 let plan = if let Some(subscription_kind) = subscription_kind {
2704 match subscription_kind {
2705 SubscriptionKind::ZedPro => proto::Plan::ZedPro,
2706 SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
2707 SubscriptionKind::ZedFree => proto::Plan::Free,
2708 }
2709 } else {
2710 proto::Plan::Free
2711 };
2712
2713 Ok(plan)
2714}
2715
/// Builds the `UpdateUserPlan` message describing `user`'s plan, trial, billing
/// and usage state.
///
/// - `plan` comes from `current_plan` (staff always report `ZedPro`).
/// - Subscription period and usage are only computed when `llm_db` is `Some`.
///   Otherwise both are `None`. Usage is also `None` when there is no current
///   period.
/// - Staff always report usage-based billing as enabled. For others it mirrors
///   `model_request_overages_enabled` from their billing preferences, if any.
/// - `account_too_young` is true for non-`ZedPro` users whose account age is
///   below `MIN_ACCOUNT_AGE_FOR_LLM_USE`.
///
/// NOTE(review): the active billing subscription is fetched here and again
/// inside `current_plan`, so one call makes two identical queries.
async fn make_update_user_plan_message(
    user: &User,
    is_staff: bool,
    db: &Arc<Database>,
    llm_db: Option<Arc<LlmDatabase>>,
) -> Result<proto::UpdateUserPlan> {
    let feature_flags = db.get_user_flags(user.id).await?;
    let plan = current_plan(db, user.id, is_staff).await?;
    let billing_customer = db.get_billing_customer_by_user_id(user.id).await?;
    let billing_preferences = db.get_billing_preferences(user.id).await?;

    let (subscription_period, usage) = if let Some(llm_db) = llm_db {
        let subscription = db.get_active_billing_subscription(user.id).await?;

        let subscription_period =
            crate::db::billing_subscription::Model::current_period(subscription, is_staff);

        let usage = if let Some((period_start_at, period_end_at)) = subscription_period {
            llm_db
                .get_subscription_usage_for_period(user.id, period_start_at, period_end_at)
                .await?
        } else {
            None
        };

        (subscription_period, usage)
    } else {
        (None, None)
    };

    let account_too_young =
        !matches!(plan, proto::Plan::ZedPro) && user.account_age() < MIN_ACCOUNT_AGE_FOR_LLM_USE;

    Ok(proto::UpdateUserPlan {
        plan: plan.into(),
        // Borrow here because `billing_customer` is consumed below for `has_overdue_invoices`.
        trial_started_at: billing_customer
            .as_ref()
            .and_then(|billing_customer| billing_customer.trial_started_at)
            .map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64),
        is_usage_based_billing_enabled: if is_staff {
            Some(true)
        } else {
            billing_preferences.map(|preferences| preferences.model_request_overages_enabled)
        },
        subscription_period: subscription_period.map(|(started_at, ended_at)| {
            proto::SubscriptionPeriod {
                started_at: started_at.timestamp() as u64,
                ended_at: ended_at.timestamp() as u64,
            }
        }),
        account_too_young: Some(account_too_young),
        has_overdue_invoices: billing_customer
            .map(|billing_customer| billing_customer.has_overdue_invoices),
        usage: usage.map(|usage| {
            // Translate the wire plan into the LLM client's plan to look up limits.
            let plan = match plan {
                proto::Plan::Free => zed_llm_client::Plan::ZedFree,
                proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro,
                proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
            };

            // Trial users carrying the extended-trial feature flag get a limit of
            // 1,000 model requests instead of the plan's default limited value.
            let model_requests_limit = match plan.model_requests_limit() {
                zed_llm_client::UsageLimit::Limited(limit) => {
                    let limit = if plan == zed_llm_client::Plan::ZedProTrial
                        && feature_flags
                            .iter()
                            .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
                    {
                        1_000
                    } else {
                        limit
                    };

                    zed_llm_client::UsageLimit::Limited(limit)
                }
                zed_llm_client::UsageLimit::Unlimited => zed_llm_client::UsageLimit::Unlimited,
            };

            proto::SubscriptionUsage {
                model_requests_usage_amount: usage.model_requests as u32,
                model_requests_usage_limit: Some(proto::UsageLimit {
                    variant: Some(match model_requests_limit {
                        zed_llm_client::UsageLimit::Limited(limit) => {
                            proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
                                limit: limit as u32,
                            })
                        }
                        zed_llm_client::UsageLimit::Unlimited => {
                            proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
                        }
                    }),
                }),
                edit_predictions_usage_amount: usage.edit_predictions as u32,
                edit_predictions_usage_limit: Some(proto::UsageLimit {
                    variant: Some(match plan.edit_predictions_limit() {
                        zed_llm_client::UsageLimit::Limited(limit) => {
                            proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
                                limit: limit as u32,
                            })
                        }
                        zed_llm_client::UsageLimit::Unlimited => {
                            proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
                        }
                    }),
                }),
            }
        }),
    })
}
2824
2825async fn update_user_plan(session: &Session) -> Result<()> {
2826 let db = session.db().await;
2827
2828 let update_user_plan = make_update_user_plan_message(
2829 session.principal.user(),
2830 session.is_staff(),
2831 &db.0,
2832 session.app_state.llm_db.clone(),
2833 )
2834 .await?;
2835
2836 session
2837 .peer
2838 .send(session.connection_id, update_user_plan)
2839 .trace_err();
2840
2841 Ok(())
2842}
2843
2844async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2845 subscribe_user_to_channels(session.user_id(), &session).await?;
2846 Ok(())
2847}
2848
/// Subscribes `user_id` in the connection pool to each channel in their
/// channel memberships, then sends this session's connection (only
/// `session.connection_id`) the channel state.
///
/// Two messages are sent, in order: the one built by
/// `build_update_user_channels`, then the one built by `build_channels_update`.
async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
    // The db guard is a temporary and is released at the end of this statement.
    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
    // `pool` is a named guard, so it stays held until return, across both sends below.
    let mut pool = session.connection_pool().await;
    for membership in &channels_for_user.channel_memberships {
        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
    }
    session.peer.send(
        session.connection_id,
        build_update_user_channels(&channels_for_user),
    )?;
    session.peer.send(
        session.connection_id,
        build_channels_update(channels_for_user),
    )?;
    Ok(())
}
2865
2866/// Creates a new channel.
2867async fn create_channel(
2868 request: proto::CreateChannel,
2869 response: Response<proto::CreateChannel>,
2870 session: Session,
2871) -> Result<()> {
2872 let db = session.db().await;
2873
2874 let parent_id = request.parent_id.map(ChannelId::from_proto);
2875 let (channel, membership) = db
2876 .create_channel(&request.name, parent_id, session.user_id())
2877 .await?;
2878
2879 let root_id = channel.root_id();
2880 let channel = Channel::from_model(channel);
2881
2882 response.send(proto::CreateChannelResponse {
2883 channel: Some(channel.to_proto()),
2884 parent_id: request.parent_id,
2885 })?;
2886
2887 let mut connection_pool = session.connection_pool().await;
2888 if let Some(membership) = membership {
2889 connection_pool.subscribe_to_channel(
2890 membership.user_id,
2891 membership.channel_id,
2892 membership.role,
2893 );
2894 let update = proto::UpdateUserChannels {
2895 channel_memberships: vec![proto::ChannelMembership {
2896 channel_id: membership.channel_id.to_proto(),
2897 role: membership.role.into(),
2898 }],
2899 ..Default::default()
2900 };
2901 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2902 session.peer.send(connection_id, update.clone())?;
2903 }
2904 }
2905
2906 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2907 if !role.can_see_channel(channel.visibility) {
2908 continue;
2909 }
2910
2911 let update = proto::UpdateChannels {
2912 channels: vec![channel.to_proto()],
2913 ..Default::default()
2914 };
2915 session.peer.send(connection_id, update.clone())?;
2916 }
2917
2918 Ok(())
2919}
2920
2921/// Delete a channel
2922async fn delete_channel(
2923 request: proto::DeleteChannel,
2924 response: Response<proto::DeleteChannel>,
2925 session: Session,
2926) -> Result<()> {
2927 let db = session.db().await;
2928
2929 let channel_id = request.channel_id;
2930 let (root_channel, removed_channels) = db
2931 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2932 .await?;
2933 response.send(proto::Ack {})?;
2934
2935 // Notify members of removed channels
2936 let mut update = proto::UpdateChannels::default();
2937 update
2938 .delete_channels
2939 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2940
2941 let connection_pool = session.connection_pool().await;
2942 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2943 session.peer.send(connection_id, update.clone())?;
2944 }
2945
2946 Ok(())
2947}
2948
/// Invite someone to join a channel.
///
/// Records the invitation for `request.user_id` with the requested role, on
/// behalf of the caller. Every connection of the invitee is then sent the
/// channel as a channel invitation, any notifications produced by the database
/// are delivered, and finally the request is acknowledged.
async fn invite_channel_member(
    request: proto::InviteChannelMember,
    response: Response<proto::InviteChannelMember>,
    session: Session,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let invitee_id = UserId::from_proto(request.user_id);
    let InviteMemberResult {
        channel,
        notifications,
    } = db
        .invite_channel_member(
            channel_id,
            invitee_id,
            session.user_id(),
            request.role().into(),
        )
        .await?;

    let update = proto::UpdateChannels {
        channel_invitations: vec![channel.to_proto()],
        ..Default::default()
    };

    let connection_pool = session.connection_pool().await;
    for connection_id in connection_pool.user_connection_ids(invitee_id) {
        session.peer.send(connection_id, update.clone())?;
    }

    send_notifications(&connection_pool, &session.peer, notifications);

    response.send(proto::Ack {})?;
    Ok(())
}
2985
2986/// remove someone from a channel
2987async fn remove_channel_member(
2988 request: proto::RemoveChannelMember,
2989 response: Response<proto::RemoveChannelMember>,
2990 session: Session,
2991) -> Result<()> {
2992 let db = session.db().await;
2993 let channel_id = ChannelId::from_proto(request.channel_id);
2994 let member_id = UserId::from_proto(request.user_id);
2995
2996 let RemoveChannelMemberResult {
2997 membership_update,
2998 notification_id,
2999 } = db
3000 .remove_channel_member(channel_id, member_id, session.user_id())
3001 .await?;
3002
3003 let mut connection_pool = session.connection_pool().await;
3004 notify_membership_updated(
3005 &mut connection_pool,
3006 membership_update,
3007 member_id,
3008 &session.peer,
3009 );
3010 for connection_id in connection_pool.user_connection_ids(member_id) {
3011 if let Some(notification_id) = notification_id {
3012 session
3013 .peer
3014 .send(
3015 connection_id,
3016 proto::DeleteNotification {
3017 notification_id: notification_id.to_proto(),
3018 },
3019 )
3020 .trace_err();
3021 }
3022 }
3023
3024 response.send(proto::Ack {})?;
3025 Ok(())
3026}
3027
/// Toggle the channel between public and private.
/// Care is taken to maintain the invariant that public channels only descend from public channels,
/// (though members-only channels can appear at any point in the hierarchy).
///
/// After the database update, every user the pool lists for the root channel is
/// re-evaluated. Users whose role can see the channel at its new visibility are
/// subscribed and sent the channel. All others are unsubscribed and told to
/// delete it. The request is acknowledged last.
async fn set_channel_visibility(
    request: proto::SetChannelVisibility,
    response: Response<proto::SetChannelVisibility>,
    session: Session,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let visibility = request.visibility().into();

    let channel_model = db
        .set_channel_visibility(channel_id, visibility, session.user_id())
        .await?;
    let root_id = channel_model.root_id();
    let channel = Channel::from_model(channel_model);

    let mut connection_pool = session.connection_pool().await;
    // Collect first: the loop body mutates `connection_pool` (subscribe /
    // unsubscribe), which is not possible while iterating a borrow of it.
    for (user_id, role) in connection_pool
        .channel_user_ids(root_id)
        .collect::<Vec<_>>()
        .into_iter()
    {
        let update = if role.can_see_channel(channel.visibility) {
            connection_pool.subscribe_to_channel(user_id, channel_id, role);
            proto::UpdateChannels {
                channels: vec![channel.to_proto()],
                ..Default::default()
            }
        } else {
            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
            proto::UpdateChannels {
                delete_channels: vec![channel.id.to_proto()],
                ..Default::default()
            }
        };

        for connection_id in connection_pool.user_connection_ids(user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3074
/// Alter the role for a user in the channel.
///
/// The database reports one of two outcomes:
/// - `MembershipUpdated`: the member's subscriptions and clients are updated
///   through `notify_membership_updated`;
/// - `InviteUpdated`: the user still has a pending invitation, so each of their
///   connections is re-sent the channel as a channel invitation.
///
/// The request is acknowledged after either branch.
async fn set_channel_member_role(
    request: proto::SetChannelMemberRole,
    response: Response<proto::SetChannelMemberRole>,
    session: Session,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let member_id = UserId::from_proto(request.user_id);
    let result = db
        .set_channel_member_role(
            channel_id,
            session.user_id(),
            member_id,
            request.role().into(),
        )
        .await?;

    match result {
        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
            let mut connection_pool = session.connection_pool().await;
            notify_membership_updated(
                &mut connection_pool,
                membership_update,
                member_id,
                &session.peer,
            )
        }
        db::SetMemberRoleResult::InviteUpdated(channel) => {
            let update = proto::UpdateChannels {
                channel_invitations: vec![channel.to_proto()],
                ..Default::default()
            };

            for connection_id in session
                .connection_pool()
                .await
                .user_connection_ids(member_id)
            {
                session.peer.send(connection_id, update.clone())?;
            }
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3122
3123/// Change the name of a channel
3124async fn rename_channel(
3125 request: proto::RenameChannel,
3126 response: Response<proto::RenameChannel>,
3127 session: Session,
3128) -> Result<()> {
3129 let db = session.db().await;
3130 let channel_id = ChannelId::from_proto(request.channel_id);
3131 let channel_model = db
3132 .rename_channel(channel_id, session.user_id(), &request.name)
3133 .await?;
3134 let root_id = channel_model.root_id();
3135 let channel = Channel::from_model(channel_model);
3136
3137 response.send(proto::RenameChannelResponse {
3138 channel: Some(channel.to_proto()),
3139 })?;
3140
3141 let connection_pool = session.connection_pool().await;
3142 let update = proto::UpdateChannels {
3143 channels: vec![channel.to_proto()],
3144 ..Default::default()
3145 };
3146 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3147 if role.can_see_channel(channel.visibility) {
3148 session.peer.send(connection_id, update.clone())?;
3149 }
3150 }
3151
3152 Ok(())
3153}
3154
3155/// Move a channel to a new parent.
3156async fn move_channel(
3157 request: proto::MoveChannel,
3158 response: Response<proto::MoveChannel>,
3159 session: Session,
3160) -> Result<()> {
3161 let channel_id = ChannelId::from_proto(request.channel_id);
3162 let to = ChannelId::from_proto(request.to);
3163
3164 let (root_id, channels) = session
3165 .db()
3166 .await
3167 .move_channel(channel_id, to, session.user_id())
3168 .await?;
3169
3170 let connection_pool = session.connection_pool().await;
3171 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3172 let channels = channels
3173 .iter()
3174 .filter_map(|channel| {
3175 if role.can_see_channel(channel.visibility) {
3176 Some(channel.to_proto())
3177 } else {
3178 None
3179 }
3180 })
3181 .collect::<Vec<_>>();
3182 if channels.is_empty() {
3183 continue;
3184 }
3185
3186 let update = proto::UpdateChannels {
3187 channels,
3188 ..Default::default()
3189 };
3190
3191 session.peer.send(connection_id, update.clone())?;
3192 }
3193
3194 response.send(Ack {})?;
3195 Ok(())
3196}
3197
3198/// Get the list of channel members
3199async fn get_channel_members(
3200 request: proto::GetChannelMembers,
3201 response: Response<proto::GetChannelMembers>,
3202 session: Session,
3203) -> Result<()> {
3204 let db = session.db().await;
3205 let channel_id = ChannelId::from_proto(request.channel_id);
3206 let limit = if request.limit == 0 {
3207 u16::MAX as u64
3208 } else {
3209 request.limit
3210 };
3211 let (members, users) = db
3212 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3213 .await?;
3214 response.send(proto::GetChannelMembersResponse { members, users })?;
3215 Ok(())
3216}
3217
3218/// Accept or decline a channel invitation.
3219async fn respond_to_channel_invite(
3220 request: proto::RespondToChannelInvite,
3221 response: Response<proto::RespondToChannelInvite>,
3222 session: Session,
3223) -> Result<()> {
3224 let db = session.db().await;
3225 let channel_id = ChannelId::from_proto(request.channel_id);
3226 let RespondToChannelInvite {
3227 membership_update,
3228 notifications,
3229 } = db
3230 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3231 .await?;
3232
3233 let mut connection_pool = session.connection_pool().await;
3234 if let Some(membership_update) = membership_update {
3235 notify_membership_updated(
3236 &mut connection_pool,
3237 membership_update,
3238 session.user_id(),
3239 &session.peer,
3240 );
3241 } else {
3242 let update = proto::UpdateChannels {
3243 remove_channel_invitations: vec![channel_id.to_proto()],
3244 ..Default::default()
3245 };
3246
3247 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3248 session.peer.send(connection_id, update.clone())?;
3249 }
3250 };
3251
3252 send_notifications(&connection_pool, &session.peer, notifications);
3253
3254 response.send(proto::Ack {})?;
3255
3256 Ok(())
3257}
3258
3259/// Join the channels' room
3260async fn join_channel(
3261 request: proto::JoinChannel,
3262 response: Response<proto::JoinChannel>,
3263 session: Session,
3264) -> Result<()> {
3265 let channel_id = ChannelId::from_proto(request.channel_id);
3266 join_channel_internal(channel_id, Box::new(response), session).await
3267}
3268
/// A response handle that `join_channel_internal` can reply through with a
/// `JoinRoomResponse`.
///
/// Implemented below for both `JoinChannel` and `JoinRoom` responses, so the
/// same join logic can answer either request type.
trait JoinChannelInternalResponse {
    /// Send `result` to the requesting client, consuming the handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// The fully-qualified `Response::<..>::send` path forwards to the `send` defined
// on `Response` itself, rather than recursing into this trait method.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3282
/// Join the room belonging to `channel_id` on behalf of the current session.
///
/// `response` is generic over `JoinChannelInternalResponse`, which this file
/// implements for both `JoinChannel` and `JoinRoom` responses.
///
/// Steps, in order:
/// 1. If the user still has a stale room connection, leave that room first.
/// 2. Join the channel in the database and build LiveKit connection info.
/// 3. Send the `JoinRoomResponse` to the caller.
/// 4. Push any membership change to the user's connections, and broadcast the
///    updated room to its participants.
/// 5. Broadcast the channel's new participant list and refresh the user's
///    contacts.
///
/// # Errors
///
/// Returns an error if any database call fails, if sending the response fails,
/// or if the joined room has no channel attached ("channel not returned"). In
/// that last case the response has already been sent.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database guard while leaving the stale room, then
            // re-acquire it. NOTE(review): presumably `leave_room_for_session`
            // takes the session's database guard itself; confirm.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // LiveKit credentials are only produced when a LiveKit client is
        // configured. Guests get a `guest_token` and may not publish; every
        // other role gets a `room_token` and may publish. If token creation
        // fails, the error is traced and no connection info is sent.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        // If joining changed the user's channel memberships, apply that change
        // to the pool and push it to all of the user's connections.
        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        // Tell every room participant about the new room state.
        room_updated(&joined_room.room, &session.peer);

        // Leaving this block drops `db` and `connection_pool`; the pool is
        // acquired again below for `channel_updated`.
        joined_room
    };

    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3376
3377/// Start editing the channel notes
3378async fn join_channel_buffer(
3379 request: proto::JoinChannelBuffer,
3380 response: Response<proto::JoinChannelBuffer>,
3381 session: Session,
3382) -> Result<()> {
3383 let db = session.db().await;
3384 let channel_id = ChannelId::from_proto(request.channel_id);
3385
3386 let open_response = db
3387 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3388 .await?;
3389
3390 let collaborators = open_response.collaborators.clone();
3391 response.send(open_response)?;
3392
3393 let update = UpdateChannelBufferCollaborators {
3394 channel_id: channel_id.to_proto(),
3395 collaborators: collaborators.clone(),
3396 };
3397 channel_buffer_updated(
3398 session.connection_id,
3399 collaborators
3400 .iter()
3401 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3402 &update,
3403 &session.peer,
3404 );
3405
3406 Ok(())
3407}
3408
3409/// Edit the channel notes
3410async fn update_channel_buffer(
3411 request: proto::UpdateChannelBuffer,
3412 session: Session,
3413) -> Result<()> {
3414 let db = session.db().await;
3415 let channel_id = ChannelId::from_proto(request.channel_id);
3416
3417 let (collaborators, epoch, version) = db
3418 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3419 .await?;
3420
3421 channel_buffer_updated(
3422 session.connection_id,
3423 collaborators.clone(),
3424 &proto::UpdateChannelBuffer {
3425 channel_id: channel_id.to_proto(),
3426 operations: request.operations,
3427 },
3428 &session.peer,
3429 );
3430
3431 let pool = &*session.connection_pool().await;
3432
3433 let non_collaborators =
3434 pool.channel_connection_ids(channel_id)
3435 .filter_map(|(connection_id, _)| {
3436 if collaborators.contains(&connection_id) {
3437 None
3438 } else {
3439 Some(connection_id)
3440 }
3441 });
3442
3443 broadcast(None, non_collaborators, |peer_id| {
3444 session.peer.send(
3445 peer_id,
3446 proto::UpdateChannels {
3447 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3448 channel_id: channel_id.to_proto(),
3449 epoch: epoch as u64,
3450 version: version.clone(),
3451 }],
3452 ..Default::default()
3453 },
3454 )
3455 });
3456
3457 Ok(())
3458}
3459
3460/// Rejoin the channel notes after a connection blip
3461async fn rejoin_channel_buffers(
3462 request: proto::RejoinChannelBuffers,
3463 response: Response<proto::RejoinChannelBuffers>,
3464 session: Session,
3465) -> Result<()> {
3466 let db = session.db().await;
3467 let buffers = db
3468 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3469 .await?;
3470
3471 for rejoined_buffer in &buffers {
3472 let collaborators_to_notify = rejoined_buffer
3473 .buffer
3474 .collaborators
3475 .iter()
3476 .filter_map(|c| Some(c.peer_id?.into()));
3477 channel_buffer_updated(
3478 session.connection_id,
3479 collaborators_to_notify,
3480 &proto::UpdateChannelBufferCollaborators {
3481 channel_id: rejoined_buffer.buffer.channel_id,
3482 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3483 },
3484 &session.peer,
3485 );
3486 }
3487
3488 response.send(proto::RejoinChannelBuffersResponse {
3489 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3490 })?;
3491
3492 Ok(())
3493}
3494
3495/// Stop editing the channel notes
3496async fn leave_channel_buffer(
3497 request: proto::LeaveChannelBuffer,
3498 response: Response<proto::LeaveChannelBuffer>,
3499 session: Session,
3500) -> Result<()> {
3501 let db = session.db().await;
3502 let channel_id = ChannelId::from_proto(request.channel_id);
3503
3504 let left_buffer = db
3505 .leave_channel_buffer(channel_id, session.connection_id)
3506 .await?;
3507
3508 response.send(Ack {})?;
3509
3510 channel_buffer_updated(
3511 session.connection_id,
3512 left_buffer.connections,
3513 &proto::UpdateChannelBufferCollaborators {
3514 channel_id: channel_id.to_proto(),
3515 collaborators: left_buffer.collaborators,
3516 },
3517 &session.peer,
3518 );
3519
3520 Ok(())
3521}
3522
3523fn channel_buffer_updated<T: EnvelopedMessage>(
3524 sender_id: ConnectionId,
3525 collaborators: impl IntoIterator<Item = ConnectionId>,
3526 message: &T,
3527 peer: &Peer,
3528) {
3529 broadcast(Some(sender_id), collaborators, |peer_id| {
3530 peer.send(peer_id, message.clone())
3531 });
3532}
3533
3534fn send_notifications(
3535 connection_pool: &ConnectionPool,
3536 peer: &Peer,
3537 notifications: db::NotificationBatch,
3538) {
3539 for (user_id, notification) in notifications {
3540 for connection_id in connection_pool.user_connection_ids(user_id) {
3541 if let Err(error) = peer.send(
3542 connection_id,
3543 proto::AddNotification {
3544 notification: Some(notification.clone()),
3545 },
3546 ) {
3547 tracing::error!(
3548 "failed to send notification to {:?} {}",
3549 connection_id,
3550 error
3551 );
3552 }
3553 }
3554 }
3555}
3556
3557/// Send a message to the channel
3558async fn send_channel_message(
3559 request: proto::SendChannelMessage,
3560 response: Response<proto::SendChannelMessage>,
3561 session: Session,
3562) -> Result<()> {
3563 // Validate the message body.
3564 let body = request.body.trim().to_string();
3565 if body.len() > MAX_MESSAGE_LEN {
3566 return Err(anyhow!("message is too long"))?;
3567 }
3568 if body.is_empty() {
3569 return Err(anyhow!("message can't be blank"))?;
3570 }
3571
3572 // TODO: adjust mentions if body is trimmed
3573
3574 let timestamp = OffsetDateTime::now_utc();
3575 let nonce = request.nonce.context("nonce can't be blank")?;
3576
3577 let channel_id = ChannelId::from_proto(request.channel_id);
3578 let CreatedChannelMessage {
3579 message_id,
3580 participant_connection_ids,
3581 notifications,
3582 } = session
3583 .db()
3584 .await
3585 .create_channel_message(
3586 channel_id,
3587 session.user_id(),
3588 &body,
3589 &request.mentions,
3590 timestamp,
3591 nonce.clone().into(),
3592 request.reply_to_message_id.map(MessageId::from_proto),
3593 )
3594 .await?;
3595
3596 let message = proto::ChannelMessage {
3597 sender_id: session.user_id().to_proto(),
3598 id: message_id.to_proto(),
3599 body,
3600 mentions: request.mentions,
3601 timestamp: timestamp.unix_timestamp() as u64,
3602 nonce: Some(nonce),
3603 reply_to_message_id: request.reply_to_message_id,
3604 edited_at: None,
3605 };
3606 broadcast(
3607 Some(session.connection_id),
3608 participant_connection_ids.clone(),
3609 |connection| {
3610 session.peer.send(
3611 connection,
3612 proto::ChannelMessageSent {
3613 channel_id: channel_id.to_proto(),
3614 message: Some(message.clone()),
3615 },
3616 )
3617 },
3618 );
3619 response.send(proto::SendChannelMessageResponse {
3620 message: Some(message),
3621 })?;
3622
3623 let pool = &*session.connection_pool().await;
3624 let non_participants =
3625 pool.channel_connection_ids(channel_id)
3626 .filter_map(|(connection_id, _)| {
3627 if participant_connection_ids.contains(&connection_id) {
3628 None
3629 } else {
3630 Some(connection_id)
3631 }
3632 });
3633 broadcast(None, non_participants, |peer_id| {
3634 session.peer.send(
3635 peer_id,
3636 proto::UpdateChannels {
3637 latest_channel_message_ids: vec![proto::ChannelMessageId {
3638 channel_id: channel_id.to_proto(),
3639 message_id: message_id.to_proto(),
3640 }],
3641 ..Default::default()
3642 },
3643 )
3644 });
3645 send_notifications(pool, &session.peer, notifications);
3646
3647 Ok(())
3648}
3649
3650/// Delete a channel message
3651async fn remove_channel_message(
3652 request: proto::RemoveChannelMessage,
3653 response: Response<proto::RemoveChannelMessage>,
3654 session: Session,
3655) -> Result<()> {
3656 let channel_id = ChannelId::from_proto(request.channel_id);
3657 let message_id = MessageId::from_proto(request.message_id);
3658 let (connection_ids, existing_notification_ids) = session
3659 .db()
3660 .await
3661 .remove_channel_message(channel_id, message_id, session.user_id())
3662 .await?;
3663
3664 broadcast(
3665 Some(session.connection_id),
3666 connection_ids,
3667 move |connection| {
3668 session.peer.send(connection, request.clone())?;
3669
3670 for notification_id in &existing_notification_ids {
3671 session.peer.send(
3672 connection,
3673 proto::DeleteNotification {
3674 notification_id: (*notification_id).to_proto(),
3675 },
3676 )?;
3677 }
3678
3679 Ok(())
3680 },
3681 );
3682 response.send(proto::Ack {})?;
3683 Ok(())
3684}
3685
3686async fn update_channel_message(
3687 request: proto::UpdateChannelMessage,
3688 response: Response<proto::UpdateChannelMessage>,
3689 session: Session,
3690) -> Result<()> {
3691 let channel_id = ChannelId::from_proto(request.channel_id);
3692 let message_id = MessageId::from_proto(request.message_id);
3693 let updated_at = OffsetDateTime::now_utc();
3694 let UpdatedChannelMessage {
3695 message_id,
3696 participant_connection_ids,
3697 notifications,
3698 reply_to_message_id,
3699 timestamp,
3700 deleted_mention_notification_ids,
3701 updated_mention_notifications,
3702 } = session
3703 .db()
3704 .await
3705 .update_channel_message(
3706 channel_id,
3707 message_id,
3708 session.user_id(),
3709 request.body.as_str(),
3710 &request.mentions,
3711 updated_at,
3712 )
3713 .await?;
3714
3715 let nonce = request.nonce.clone().context("nonce can't be blank")?;
3716
3717 let message = proto::ChannelMessage {
3718 sender_id: session.user_id().to_proto(),
3719 id: message_id.to_proto(),
3720 body: request.body.clone(),
3721 mentions: request.mentions.clone(),
3722 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3723 nonce: Some(nonce),
3724 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3725 edited_at: Some(updated_at.unix_timestamp() as u64),
3726 };
3727
3728 response.send(proto::Ack {})?;
3729
3730 let pool = &*session.connection_pool().await;
3731 broadcast(
3732 Some(session.connection_id),
3733 participant_connection_ids,
3734 |connection| {
3735 session.peer.send(
3736 connection,
3737 proto::ChannelMessageUpdate {
3738 channel_id: channel_id.to_proto(),
3739 message: Some(message.clone()),
3740 },
3741 )?;
3742
3743 for notification_id in &deleted_mention_notification_ids {
3744 session.peer.send(
3745 connection,
3746 proto::DeleteNotification {
3747 notification_id: (*notification_id).to_proto(),
3748 },
3749 )?;
3750 }
3751
3752 for notification in &updated_mention_notifications {
3753 session.peer.send(
3754 connection,
3755 proto::UpdateNotification {
3756 notification: Some(notification.clone()),
3757 },
3758 )?;
3759 }
3760
3761 Ok(())
3762 },
3763 );
3764
3765 send_notifications(pool, &session.peer, notifications);
3766
3767 Ok(())
3768}
3769
3770/// Mark a channel message as read
3771async fn acknowledge_channel_message(
3772 request: proto::AckChannelMessage,
3773 session: Session,
3774) -> Result<()> {
3775 let channel_id = ChannelId::from_proto(request.channel_id);
3776 let message_id = MessageId::from_proto(request.message_id);
3777 let notifications = session
3778 .db()
3779 .await
3780 .observe_channel_message(channel_id, session.user_id(), message_id)
3781 .await?;
3782 send_notifications(
3783 &*session.connection_pool().await,
3784 &session.peer,
3785 notifications,
3786 );
3787 Ok(())
3788}
3789
3790/// Mark a buffer version as synced
3791async fn acknowledge_buffer_version(
3792 request: proto::AckBufferOperation,
3793 session: Session,
3794) -> Result<()> {
3795 let buffer_id = BufferId::from_proto(request.buffer_id);
3796 session
3797 .db()
3798 .await
3799 .observe_buffer_version(
3800 buffer_id,
3801 session.user_id(),
3802 request.epoch as i32,
3803 &request.version,
3804 )
3805 .await?;
3806 Ok(())
3807}
3808
3809/// Get a Supermaven API key for the user
3810async fn get_supermaven_api_key(
3811 _request: proto::GetSupermavenApiKey,
3812 response: Response<proto::GetSupermavenApiKey>,
3813 session: Session,
3814) -> Result<()> {
3815 let user_id: String = session.user_id().to_string();
3816 if !session.is_staff() {
3817 return Err(anyhow!("supermaven not enabled for this account"))?;
3818 }
3819
3820 let email = session.email().context("user must have an email")?;
3821
3822 let supermaven_admin_api = session
3823 .supermaven_client
3824 .as_ref()
3825 .context("supermaven not configured")?;
3826
3827 let result = supermaven_admin_api
3828 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3829 .await?;
3830
3831 response.send(proto::GetSupermavenApiKeyResponse {
3832 api_key: result.api_key,
3833 })?;
3834
3835 Ok(())
3836}
3837
3838/// Start receiving chat updates for a channel
3839async fn join_channel_chat(
3840 request: proto::JoinChannelChat,
3841 response: Response<proto::JoinChannelChat>,
3842 session: Session,
3843) -> Result<()> {
3844 let channel_id = ChannelId::from_proto(request.channel_id);
3845
3846 let db = session.db().await;
3847 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3848 .await?;
3849 let messages = db
3850 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3851 .await?;
3852 response.send(proto::JoinChannelChatResponse {
3853 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3854 messages,
3855 })?;
3856 Ok(())
3857}
3858
3859/// Stop receiving chat updates for a channel
3860async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3861 let channel_id = ChannelId::from_proto(request.channel_id);
3862 session
3863 .db()
3864 .await
3865 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3866 .await?;
3867 Ok(())
3868}
3869
3870/// Retrieve the chat history for a channel
3871async fn get_channel_messages(
3872 request: proto::GetChannelMessages,
3873 response: Response<proto::GetChannelMessages>,
3874 session: Session,
3875) -> Result<()> {
3876 let channel_id = ChannelId::from_proto(request.channel_id);
3877 let messages = session
3878 .db()
3879 .await
3880 .get_channel_messages(
3881 channel_id,
3882 session.user_id(),
3883 MESSAGE_COUNT_PER_PAGE,
3884 Some(MessageId::from_proto(request.before_message_id)),
3885 )
3886 .await?;
3887 response.send(proto::GetChannelMessagesResponse {
3888 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3889 messages,
3890 })?;
3891 Ok(())
3892}
3893
3894/// Retrieve specific chat messages
3895async fn get_channel_messages_by_id(
3896 request: proto::GetChannelMessagesById,
3897 response: Response<proto::GetChannelMessagesById>,
3898 session: Session,
3899) -> Result<()> {
3900 let message_ids = request
3901 .message_ids
3902 .iter()
3903 .map(|id| MessageId::from_proto(*id))
3904 .collect::<Vec<_>>();
3905 let messages = session
3906 .db()
3907 .await
3908 .get_channel_messages_by_id(session.user_id(), &message_ids)
3909 .await?;
3910 response.send(proto::GetChannelMessagesResponse {
3911 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3912 messages,
3913 })?;
3914 Ok(())
3915}
3916
3917/// Retrieve the current users notifications
3918async fn get_notifications(
3919 request: proto::GetNotifications,
3920 response: Response<proto::GetNotifications>,
3921 session: Session,
3922) -> Result<()> {
3923 let notifications = session
3924 .db()
3925 .await
3926 .get_notifications(
3927 session.user_id(),
3928 NOTIFICATION_COUNT_PER_PAGE,
3929 request.before_id.map(db::NotificationId::from_proto),
3930 )
3931 .await?;
3932 response.send(proto::GetNotificationsResponse {
3933 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3934 notifications,
3935 })?;
3936 Ok(())
3937}
3938
3939/// Mark notifications as read
3940async fn mark_notification_as_read(
3941 request: proto::MarkNotificationRead,
3942 response: Response<proto::MarkNotificationRead>,
3943 session: Session,
3944) -> Result<()> {
3945 let database = &session.db().await;
3946 let notifications = database
3947 .mark_notification_as_read_by_id(
3948 session.user_id(),
3949 NotificationId::from_proto(request.notification_id),
3950 )
3951 .await?;
3952 send_notifications(
3953 &*session.connection_pool().await,
3954 &session.peer,
3955 notifications,
3956 );
3957 response.send(proto::Ack {})?;
3958 Ok(())
3959}
3960
3961/// Get the current users information
3962async fn get_private_user_info(
3963 _request: proto::GetPrivateUserInfo,
3964 response: Response<proto::GetPrivateUserInfo>,
3965 session: Session,
3966) -> Result<()> {
3967 let db = session.db().await;
3968
3969 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
3970 let user = db
3971 .get_user_by_id(session.user_id())
3972 .await?
3973 .context("user not found")?;
3974 let flags = db.get_user_flags(session.user_id()).await?;
3975
3976 response.send(proto::GetPrivateUserInfoResponse {
3977 metrics_id,
3978 staff: user.admin,
3979 flags,
3980 accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
3981 })?;
3982 Ok(())
3983}
3984
3985/// Accept the terms of service (tos) on behalf of the current user
3986async fn accept_terms_of_service(
3987 _request: proto::AcceptTermsOfService,
3988 response: Response<proto::AcceptTermsOfService>,
3989 session: Session,
3990) -> Result<()> {
3991 let db = session.db().await;
3992
3993 let accepted_tos_at = Utc::now();
3994 db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
3995 .await?;
3996
3997 response.send(proto::AcceptTermsOfServiceResponse {
3998 accepted_tos_at: accepted_tos_at.timestamp() as u64,
3999 })?;
4000 Ok(())
4001}
4002
/// The minimum account age an account must have in order to use the LLM service.
///
/// Thirty days, built with `chrono::Duration::days` so the value can be a
/// compile-time constant.
pub const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4005
/// Issue an LLM API token for the current user.
///
/// Requirements: the user must exist and have accepted the terms of service,
/// and both a Stripe client and a Stripe billing object must be configured.
///
/// Billing setup:
/// - If the user has no billing customer yet, a Stripe customer is found or
///   created from their email address (if any) and recorded.
/// - If the user has no active billing subscription, the customer is
///   subscribed to Zed Free in Stripe and that subscription is stored.
///
/// The token claims are built from the user, staff status, billing customer,
/// billing preferences, feature flags, billing subscription, system id, and
/// server config.
///
/// # Errors
///
/// Fails if the user is missing, the terms of service are not accepted, Stripe
/// is not configured, any Stripe or database call fails, or the token cannot be
/// created.
async fn get_llm_api_token(
    _request: proto::GetLlmToken,
    response: Response<proto::GetLlmToken>,
    session: Session,
) -> Result<()> {
    // NOTE(review): this guard stays alive across the Stripe network calls
    // below. Confirm that holding it that long is intended.
    let db = session.db().await;

    let flags = db.get_user_flags(session.user_id()).await?;

    let user_id = session.user_id();
    let user = db
        .get_user_by_id(user_id)
        .await?
        .with_context(|| format!("user {user_id} not found"))?;

    if user.accepted_tos_at.is_none() {
        Err(anyhow!("terms of service not accepted"))?
    }

    let stripe_client = session
        .app_state
        .stripe_client
        .as_ref()
        .context("failed to retrieve Stripe client")?;

    let stripe_billing = session
        .app_state
        .stripe_billing
        .as_ref()
        .context("failed to retrieve Stripe billing object")?;

    // Reuse the stored billing customer. Otherwise look up or create the Stripe
    // customer by email and find or create the matching billing customer.
    let billing_customer = if let Some(billing_customer) =
        db.get_billing_customer_by_user_id(user.id).await?
    {
        billing_customer
    } else {
        let customer_id = stripe_billing
            .find_or_create_customer_by_email(user.email_address.as_deref())
            .await?;

        find_or_create_billing_customer(&session.app_state, stripe_client.as_ref(), &customer_id)
            .await?
            .context("billing customer not found")?
    };

    // Reuse the active subscription. Otherwise subscribe the customer to Zed
    // Free in Stripe and record that subscription in the database.
    // NOTE(review): two concurrent requests for a user without a subscription
    // could both take this branch. Confirm duplicates are prevented elsewhere.
    let billing_subscription =
        if let Some(billing_subscription) = db.get_active_billing_subscription(user.id).await? {
            billing_subscription
        } else {
            let stripe_customer_id =
                StripeCustomerId(billing_customer.stripe_customer_id.clone().into());

            let stripe_subscription = stripe_billing
                .subscribe_to_zed_free(stripe_customer_id)
                .await?;

            db.create_billing_subscription(&db::CreateBillingSubscriptionParams {
                billing_customer_id: billing_customer.id,
                kind: Some(SubscriptionKind::ZedFree),
                stripe_subscription_id: stripe_subscription.id.to_string(),
                stripe_subscription_status: stripe_subscription.status.into(),
                stripe_cancellation_reason: None,
                stripe_current_period_start: Some(stripe_subscription.current_period_start),
                stripe_current_period_end: Some(stripe_subscription.current_period_end),
            })
            .await?
        };

    let billing_preferences = db.get_billing_preferences(user.id).await?;

    let token = LlmTokenClaims::create(
        &user,
        session.is_staff(),
        billing_customer,
        billing_preferences,
        &flags,
        billing_subscription,
        session.system_id.clone(),
        &session.app_state.config,
    )?;
    response.send(proto::GetLlmTokenResponse { token })?;
    Ok(())
}
4089
4090fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4091 let message = match message {
4092 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
4093 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
4094 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
4095 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
4096 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4097 code: frame.code.into(),
4098 reason: frame.reason.as_str().to_owned().into(),
4099 })),
4100 // We should never receive a frame while reading the message, according
4101 // to the `tungstenite` maintainers:
4102 //
4103 // > It cannot occur when you read messages from the WebSocket, but it
4104 // > can be used when you want to send the raw frames (e.g. you want to
4105 // > send the frames to the WebSocket without composing the full message first).
4106 // >
4107 // > — https://github.com/snapview/tungstenite-rs/issues/268
4108 TungsteniteMessage::Frame(_) => {
4109 bail!("received an unexpected frame while reading the message")
4110 }
4111 };
4112
4113 Ok(message)
4114}
4115
4116fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4117 match message {
4118 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
4119 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
4120 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
4121 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
4122 AxumMessage::Close(frame) => {
4123 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4124 code: frame.code.into(),
4125 reason: frame.reason.as_ref().into(),
4126 }))
4127 }
4128 }
4129}
4130
4131fn notify_membership_updated(
4132 connection_pool: &mut ConnectionPool,
4133 result: MembershipUpdated,
4134 user_id: UserId,
4135 peer: &Peer,
4136) {
4137 for membership in &result.new_channels.channel_memberships {
4138 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4139 }
4140 for channel_id in &result.removed_channels {
4141 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4142 }
4143
4144 let user_channels_update = proto::UpdateUserChannels {
4145 channel_memberships: result
4146 .new_channels
4147 .channel_memberships
4148 .iter()
4149 .map(|cm| proto::ChannelMembership {
4150 channel_id: cm.channel_id.to_proto(),
4151 role: cm.role.into(),
4152 })
4153 .collect(),
4154 ..Default::default()
4155 };
4156
4157 let mut update = build_channels_update(result.new_channels);
4158 update.delete_channels = result
4159 .removed_channels
4160 .into_iter()
4161 .map(|id| id.to_proto())
4162 .collect();
4163 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4164
4165 for connection_id in connection_pool.user_connection_ids(user_id) {
4166 peer.send(connection_id, user_channels_update.clone())
4167 .trace_err();
4168 peer.send(connection_id, update.clone()).trace_err();
4169 }
4170}
4171
4172fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4173 proto::UpdateUserChannels {
4174 channel_memberships: channels
4175 .channel_memberships
4176 .iter()
4177 .map(|m| proto::ChannelMembership {
4178 channel_id: m.channel_id.to_proto(),
4179 role: m.role.into(),
4180 })
4181 .collect(),
4182 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4183 observed_channel_message_id: channels.observed_channel_messages.clone(),
4184 }
4185}
4186
4187fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4188 let mut update = proto::UpdateChannels::default();
4189
4190 for channel in channels.channels {
4191 update.channels.push(channel.to_proto());
4192 }
4193
4194 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4195 update.latest_channel_message_ids = channels.latest_channel_messages;
4196
4197 for (channel_id, participants) in channels.channel_participants {
4198 update
4199 .channel_participants
4200 .push(proto::ChannelParticipants {
4201 channel_id: channel_id.to_proto(),
4202 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4203 });
4204 }
4205
4206 for channel in channels.invited_channels {
4207 update.channel_invitations.push(channel.to_proto());
4208 }
4209
4210 update
4211}
4212
4213fn build_initial_contacts_update(
4214 contacts: Vec<db::Contact>,
4215 pool: &ConnectionPool,
4216) -> proto::UpdateContacts {
4217 let mut update = proto::UpdateContacts::default();
4218
4219 for contact in contacts {
4220 match contact {
4221 db::Contact::Accepted { user_id, busy } => {
4222 update.contacts.push(contact_for_user(user_id, busy, pool));
4223 }
4224 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4225 db::Contact::Incoming { user_id } => {
4226 update
4227 .incoming_requests
4228 .push(proto::IncomingContactRequest {
4229 requester_id: user_id.to_proto(),
4230 })
4231 }
4232 }
4233 }
4234
4235 update
4236}
4237
4238fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4239 proto::Contact {
4240 user_id: user_id.to_proto(),
4241 online: pool.is_user_online(user_id),
4242 busy,
4243 }
4244}
4245
4246fn room_updated(room: &proto::Room, peer: &Peer) {
4247 broadcast(
4248 None,
4249 room.participants
4250 .iter()
4251 .filter_map(|participant| Some(participant.peer_id?.into())),
4252 |peer_id| {
4253 peer.send(
4254 peer_id,
4255 proto::RoomUpdated {
4256 room: Some(room.clone()),
4257 },
4258 )
4259 },
4260 );
4261}
4262
4263fn channel_updated(
4264 channel: &db::channel::Model,
4265 room: &proto::Room,
4266 peer: &Peer,
4267 pool: &ConnectionPool,
4268) {
4269 let participants = room
4270 .participants
4271 .iter()
4272 .map(|p| p.user_id)
4273 .collect::<Vec<_>>();
4274
4275 broadcast(
4276 None,
4277 pool.channel_connection_ids(channel.root_id())
4278 .filter_map(|(channel_id, role)| {
4279 role.can_see_channel(channel.visibility)
4280 .then_some(channel_id)
4281 }),
4282 |peer_id| {
4283 peer.send(
4284 peer_id,
4285 proto::UpdateChannels {
4286 channel_participants: vec![proto::ChannelParticipants {
4287 channel_id: channel.id.to_proto(),
4288 participant_user_ids: participants.clone(),
4289 }],
4290 ..Default::default()
4291 },
4292 )
4293 },
4294 );
4295}
4296
4297async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4298 let db = session.db().await;
4299
4300 let contacts = db.get_contacts(user_id).await?;
4301 let busy = db.is_user_busy(user_id).await?;
4302
4303 let pool = session.connection_pool().await;
4304 let updated_contact = contact_for_user(user_id, busy, &pool);
4305 for contact in contacts {
4306 if let db::Contact::Accepted {
4307 user_id: contact_user_id,
4308 ..
4309 } = contact
4310 {
4311 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4312 session
4313 .peer
4314 .send(
4315 contact_conn_id,
4316 proto::UpdateContacts {
4317 contacts: vec![updated_contact.clone()],
4318 remove_contacts: Default::default(),
4319 incoming_requests: Default::default(),
4320 remove_incoming_requests: Default::default(),
4321 outgoing_requests: Default::default(),
4322 remove_outgoing_requests: Default::default(),
4323 },
4324 )
4325 .trace_err();
4326 }
4327 }
4328 }
4329 Ok(())
4330}
4331
/// Removes `connection_id` from the room it is currently in, if any, and fans
/// out every resulting update.
///
/// Steps, in order:
/// 1. Ask the database to leave the room. If it returns `None`, return
///    `Ok(())` without doing anything else.
/// 2. Notify the connections of each project left as part of leaving the room
///    (see `project_left`).
/// 3. Broadcast the updated room to its participants.
/// 4. If the database reports a channel for the room, broadcast the room's
///    participant list to that channel's visible connections.
/// 5. Send `CallCanceled` for this room to every connection of each user in
///    `canceled_calls_to_user_ids`.
/// 6. Refresh contact status for the session's user and for each of those
///    canceled users.
/// 7. If a LiveKit client is configured, remove the session's user from the
///    LiveKit room. If the database marked the room as deleted, delete the
///    LiveKit room as well.
///
/// # Errors
///
/// Propagates errors from `leave_room` and `update_user_contacts`. Message
/// send failures and LiveKit failures are only logged via `trace_err`.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Deferred initialization: each of these is assigned inside the `if let`
    // below, whose `else` branch returns early.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    // NOTE(review): the value returned by `session.db().await` is a temporary
    // of this `if let` scrutinee, so it stays alive for the whole body below.
    // It is presumably a lock guard, so keep this structure unless that is
    // confirmed to be unnecessary.
    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // `livekit_room` is taken out of the room *before* the room itself is
        // taken, so the `room` broadcast below carries a defaulted
        // `livekit_room`. NOTE(review): presumably intentional, so the LiveKit
        // room name is not broadcast to clients. Confirm.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so this pool handle is released before `update_user_contacts`,
    // which calls `session.connection_pool()` again.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    // LiveKit cleanup is best-effort: failures are logged, not propagated.
    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4405
4406async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4407 let left_channel_buffers = session
4408 .db()
4409 .await
4410 .leave_channel_buffers(session.connection_id)
4411 .await?;
4412
4413 for left_buffer in left_channel_buffers {
4414 channel_buffer_updated(
4415 session.connection_id,
4416 left_buffer.connections,
4417 &proto::UpdateChannelBufferCollaborators {
4418 channel_id: left_buffer.channel_id.to_proto(),
4419 collaborators: left_buffer.collaborators,
4420 },
4421 &session.peer,
4422 );
4423 }
4424
4425 Ok(())
4426}
4427
4428fn project_left(project: &db::LeftProject, session: &Session) {
4429 for connection_id in &project.connection_ids {
4430 if project.should_unshare {
4431 session
4432 .peer
4433 .send(
4434 *connection_id,
4435 proto::UnshareProject {
4436 project_id: project.id.to_proto(),
4437 },
4438 )
4439 .trace_err();
4440 } else {
4441 session
4442 .peer
4443 .send(
4444 *connection_id,
4445 proto::RemoveProjectCollaborator {
4446 project_id: project.id.to_proto(),
4447 peer_id: Some(session.connection_id.into()),
4448 },
4449 )
4450 .trace_err();
4451 }
4452 }
4453}
4454
4455pub trait ResultExt {
4456 type Ok;
4457
4458 fn trace_err(self) -> Option<Self::Ok>;
4459}
4460
4461impl<T, E> ResultExt for Result<T, E>
4462where
4463 E: std::fmt::Debug,
4464{
4465 type Ok = T;
4466
4467 #[track_caller]
4468 fn trace_err(self) -> Option<T> {
4469 match self {
4470 Ok(value) => Some(value),
4471 Err(error) => {
4472 tracing::error!("{:?}", error);
4473 None
4474 }
4475 }
4476 }
4477}