1mod connection_pool;
2
3use crate::api::billing::find_or_create_billing_customer;
4use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
5use crate::db::billing_subscription::SubscriptionKind;
6use crate::llm::db::LlmDatabase;
7use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, LlmTokenClaims};
8use crate::{
9 AppState, Error, Result, auth,
10 db::{
11 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
12 CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
13 NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
14 RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
15 },
16 executor::Executor,
17};
18use anyhow::{Context as _, anyhow, bail};
19use async_tungstenite::tungstenite::{
20 Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
21};
22use axum::{
23 Extension, Router, TypedHeader,
24 body::Body,
25 extract::{
26 ConnectInfo, WebSocketUpgrade,
27 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
28 },
29 headers::{Header, HeaderName},
30 http::StatusCode,
31 middleware,
32 response::IntoResponse,
33 routing::get,
34};
35use chrono::Utc;
36use collections::{HashMap, HashSet};
37pub use connection_pool::{ConnectionPool, ZedVersion};
38use core::fmt::{self, Debug, Formatter};
39use reqwest_client::ReqwestClient;
40use rpc::proto::split_repository_update;
41use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
42
43use futures::{
44 FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
45 stream::FuturesUnordered,
46};
47use prometheus::{IntGauge, register_int_gauge};
48use rpc::{
49 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
50 proto::{
51 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
52 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
53 },
54};
55use semantic_version::SemanticVersion;
56use serde::{Serialize, Serializer};
57use std::{
58 any::TypeId,
59 future::Future,
60 marker::PhantomData,
61 mem,
62 net::SocketAddr,
63 ops::{Deref, DerefMut},
64 rc::Rc,
65 sync::{
66 Arc, OnceLock,
67 atomic::{AtomicBool, Ordering::SeqCst},
68 },
69 time::{Duration, Instant},
70};
71use time::OffsetDateTime;
72use tokio::sync::{Semaphore, watch};
73use tower::ServiceBuilder;
74use tracing::{
75 Instrument,
76 field::{self},
77 info_span, instrument,
78};
79
/// Timeout (30 seconds) associated with client reconnection.
///
/// NOTE(review): the code that consumes this constant is outside this section — confirm the
/// exact semantics against its users.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can clean
// up old resources.
/// Delay (15 seconds) that [`Server::start`] waits before clearing stale resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

/// Page size for messages. NOTE(review): presumably channel chat messages; used outside this
/// section.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Upper bound on a message's length. NOTE(review): the unit (bytes vs. characters) is decided
/// by the code that enforces it, which is outside this section.
const MAX_MESSAGE_LEN: usize = 1024;
/// Page size for notifications; used outside this section.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
88
89type MessageHandler =
90 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
91
92struct Response<R> {
93 peer: Arc<Peer>,
94 receipt: Receipt<R>,
95 responded: Arc<AtomicBool>,
96}
97
98impl<R: RequestMessage> Response<R> {
99 fn send(self, payload: R::Response) -> Result<()> {
100 self.responded.store(true, SeqCst);
101 self.peer.respond(self.receipt, payload)?;
102 Ok(())
103 }
104}
105
106#[derive(Clone, Debug)]
107pub enum Principal {
108 User(User),
109 Impersonated { user: User, admin: User },
110}
111
112impl Principal {
113 fn update_span(&self, span: &tracing::Span) {
114 match &self {
115 Principal::User(user) => {
116 span.record("user_id", user.id.0);
117 span.record("login", &user.github_login);
118 }
119 Principal::Impersonated { user, admin } => {
120 span.record("user_id", user.id.0);
121 span.record("login", &user.github_login);
122 span.record("impersonator", &admin.github_login);
123 }
124 }
125 }
126}
127
128#[derive(Clone)]
129struct Session {
130 principal: Principal,
131 connection_id: ConnectionId,
132 db: Arc<tokio::sync::Mutex<DbHandle>>,
133 peer: Arc<Peer>,
134 connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
135 app_state: Arc<AppState>,
136 supermaven_client: Option<Arc<SupermavenAdminApi>>,
137 /// The GeoIP country code for the user.
138 #[allow(unused)]
139 geoip_country_code: Option<String>,
140 system_id: Option<String>,
141 _executor: Executor,
142}
143
144impl Session {
145 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
146 #[cfg(test)]
147 tokio::task::yield_now().await;
148 let guard = self.db.lock().await;
149 #[cfg(test)]
150 tokio::task::yield_now().await;
151 guard
152 }
153
154 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
155 #[cfg(test)]
156 tokio::task::yield_now().await;
157 let guard = self.connection_pool.lock();
158 ConnectionPoolGuard {
159 guard,
160 _not_send: PhantomData,
161 }
162 }
163
164 fn is_staff(&self) -> bool {
165 match &self.principal {
166 Principal::User(user) => user.admin,
167 Principal::Impersonated { .. } => true,
168 }
169 }
170
171 fn user_id(&self) -> UserId {
172 match &self.principal {
173 Principal::User(user) => user.id,
174 Principal::Impersonated { user, .. } => user.id,
175 }
176 }
177
178 pub fn email(&self) -> Option<String> {
179 match &self.principal {
180 Principal::User(user) => user.email_address.clone(),
181 Principal::Impersonated { user, .. } => user.email_address.clone(),
182 }
183 }
184}
185
186impl Debug for Session {
187 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
188 let mut result = f.debug_struct("Session");
189 match &self.principal {
190 Principal::User(user) => {
191 result.field("user", &user.github_login);
192 }
193 Principal::Impersonated { user, admin } => {
194 result.field("user", &user.github_login);
195 result.field("impersonator", &admin.github_login);
196 }
197 }
198 result.field("connection_id", &self.connection_id).finish()
199 }
200}
201
202struct DbHandle(Arc<Database>);
203
204impl Deref for DbHandle {
205 type Target = Database;
206
207 fn deref(&self) -> &Self::Target {
208 self.0.as_ref()
209 }
210}
211
212pub struct Server {
213 id: parking_lot::Mutex<ServerId>,
214 peer: Arc<Peer>,
215 pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
216 app_state: Arc<AppState>,
217 handlers: HashMap<TypeId, MessageHandler>,
218 teardown: watch::Sender<bool>,
219}
220
221pub(crate) struct ConnectionPoolGuard<'a> {
222 guard: parking_lot::MutexGuard<'a, ConnectionPool>,
223 _not_send: PhantomData<Rc<()>>,
224}
225
226#[derive(Serialize)]
227pub struct ServerSnapshot<'a> {
228 peer: &'a Peer,
229 #[serde(serialize_with = "serialize_deref")]
230 connection_pool: ConnectionPoolGuard<'a>,
231}
232
233pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
234where
235 S: Serializer,
236 T: Deref<Target = U>,
237 U: Serialize,
238{
239 Serialize::serialize(value.deref(), serializer)
240}
241
242impl Server {
243 pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
244 let mut server = Self {
245 id: parking_lot::Mutex::new(id),
246 peer: Peer::new(id.0 as u32),
247 app_state: app_state.clone(),
248 connection_pool: Default::default(),
249 handlers: Default::default(),
250 teardown: watch::channel(false).0,
251 };
252
253 server
254 .add_request_handler(ping)
255 .add_request_handler(create_room)
256 .add_request_handler(join_room)
257 .add_request_handler(rejoin_room)
258 .add_request_handler(leave_room)
259 .add_request_handler(set_room_participant_role)
260 .add_request_handler(call)
261 .add_request_handler(cancel_call)
262 .add_message_handler(decline_call)
263 .add_request_handler(update_participant_location)
264 .add_request_handler(share_project)
265 .add_message_handler(unshare_project)
266 .add_request_handler(join_project)
267 .add_message_handler(leave_project)
268 .add_request_handler(update_project)
269 .add_request_handler(update_worktree)
270 .add_request_handler(update_repository)
271 .add_request_handler(remove_repository)
272 .add_message_handler(start_language_server)
273 .add_message_handler(update_language_server)
274 .add_message_handler(update_diagnostic_summary)
275 .add_message_handler(update_worktree_settings)
276 .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
277 .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
278 .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
279 .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
280 .add_request_handler(forward_find_search_candidates_request)
281 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
282 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
283 .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
284 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
285 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
286 .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
287 .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
288 .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
289 .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
290 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
291 .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
292 .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
293 .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
294 .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
295 .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
296 .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
297 .add_request_handler(
298 forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
299 )
300 .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
301 .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
302 .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
303 .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
304 .add_request_handler(
305 forward_read_only_project_request::<proto::LanguageServerIdForName>,
306 )
307 .add_request_handler(
308 forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
309 )
310 .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
311 .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
312 .add_request_handler(
313 forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
314 )
315 .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
316 .add_request_handler(
317 forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
318 )
319 .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
320 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
321 .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
322 .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
323 .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
324 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
325 .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
326 .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
327 .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
328 .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
329 .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
330 .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
331 .add_request_handler(
332 forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
333 )
334 .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
335 .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
336 .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
337 .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
338 .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
339 .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
340 .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
341 .add_message_handler(create_buffer_for_peer)
342 .add_request_handler(update_buffer)
343 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
344 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
345 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
346 .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
347 .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
348 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
349 .add_request_handler(get_users)
350 .add_request_handler(fuzzy_search_users)
351 .add_request_handler(request_contact)
352 .add_request_handler(remove_contact)
353 .add_request_handler(respond_to_contact_request)
354 .add_message_handler(subscribe_to_channels)
355 .add_request_handler(create_channel)
356 .add_request_handler(delete_channel)
357 .add_request_handler(invite_channel_member)
358 .add_request_handler(remove_channel_member)
359 .add_request_handler(set_channel_member_role)
360 .add_request_handler(set_channel_visibility)
361 .add_request_handler(rename_channel)
362 .add_request_handler(join_channel_buffer)
363 .add_request_handler(leave_channel_buffer)
364 .add_message_handler(update_channel_buffer)
365 .add_request_handler(rejoin_channel_buffers)
366 .add_request_handler(get_channel_members)
367 .add_request_handler(respond_to_channel_invite)
368 .add_request_handler(join_channel)
369 .add_request_handler(join_channel_chat)
370 .add_message_handler(leave_channel_chat)
371 .add_request_handler(send_channel_message)
372 .add_request_handler(remove_channel_message)
373 .add_request_handler(update_channel_message)
374 .add_request_handler(get_channel_messages)
375 .add_request_handler(get_channel_messages_by_id)
376 .add_request_handler(get_notifications)
377 .add_request_handler(mark_notification_as_read)
378 .add_request_handler(move_channel)
379 .add_request_handler(follow)
380 .add_message_handler(unfollow)
381 .add_message_handler(update_followers)
382 .add_request_handler(get_private_user_info)
383 .add_request_handler(get_llm_api_token)
384 .add_request_handler(accept_terms_of_service)
385 .add_message_handler(acknowledge_channel_message)
386 .add_message_handler(acknowledge_buffer_version)
387 .add_request_handler(get_supermaven_api_key)
388 .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
389 .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
390 .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
391 .add_request_handler(forward_mutating_project_request::<proto::Stage>)
392 .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
393 .add_request_handler(forward_mutating_project_request::<proto::Commit>)
394 .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
395 .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
396 .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
397 .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
398 .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
399 .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
400 .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
401 .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
402 .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
403 .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
404 .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
405 .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
406 .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
407 .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
408 .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
409 .add_message_handler(update_context);
410
411 Arc::new(server)
412 }
413
414 pub async fn start(&self) -> Result<()> {
415 let server_id = *self.id.lock();
416 let app_state = self.app_state.clone();
417 let peer = self.peer.clone();
418 let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
419 let pool = self.connection_pool.clone();
420 let livekit_client = self.app_state.livekit_client.clone();
421
422 let span = info_span!("start server");
423 self.app_state.executor.spawn_detached(
424 async move {
425 tracing::info!("waiting for cleanup timeout");
426 timeout.await;
427 tracing::info!("cleanup timeout expired, retrieving stale rooms");
428 if let Some((room_ids, channel_ids)) = app_state
429 .db
430 .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
431 .await
432 .trace_err()
433 {
434 tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
435 tracing::info!(
436 stale_channel_buffer_count = channel_ids.len(),
437 "retrieved stale channel buffers"
438 );
439
440 for channel_id in channel_ids {
441 if let Some(refreshed_channel_buffer) = app_state
442 .db
443 .clear_stale_channel_buffer_collaborators(channel_id, server_id)
444 .await
445 .trace_err()
446 {
447 for connection_id in refreshed_channel_buffer.connection_ids {
448 peer.send(
449 connection_id,
450 proto::UpdateChannelBufferCollaborators {
451 channel_id: channel_id.to_proto(),
452 collaborators: refreshed_channel_buffer
453 .collaborators
454 .clone(),
455 },
456 )
457 .trace_err();
458 }
459 }
460 }
461
462 for room_id in room_ids {
463 let mut contacts_to_update = HashSet::default();
464 let mut canceled_calls_to_user_ids = Vec::new();
465 let mut livekit_room = String::new();
466 let mut delete_livekit_room = false;
467
468 if let Some(mut refreshed_room) = app_state
469 .db
470 .clear_stale_room_participants(room_id, server_id)
471 .await
472 .trace_err()
473 {
474 tracing::info!(
475 room_id = room_id.0,
476 new_participant_count = refreshed_room.room.participants.len(),
477 "refreshed room"
478 );
479 room_updated(&refreshed_room.room, &peer);
480 if let Some(channel) = refreshed_room.channel.as_ref() {
481 channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
482 }
483 contacts_to_update
484 .extend(refreshed_room.stale_participant_user_ids.iter().copied());
485 contacts_to_update
486 .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
487 canceled_calls_to_user_ids =
488 mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
489 livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
490 delete_livekit_room = refreshed_room.room.participants.is_empty();
491 }
492
493 {
494 let pool = pool.lock();
495 for canceled_user_id in canceled_calls_to_user_ids {
496 for connection_id in pool.user_connection_ids(canceled_user_id) {
497 peer.send(
498 connection_id,
499 proto::CallCanceled {
500 room_id: room_id.to_proto(),
501 },
502 )
503 .trace_err();
504 }
505 }
506 }
507
508 for user_id in contacts_to_update {
509 let busy = app_state.db.is_user_busy(user_id).await.trace_err();
510 let contacts = app_state.db.get_contacts(user_id).await.trace_err();
511 if let Some((busy, contacts)) = busy.zip(contacts) {
512 let pool = pool.lock();
513 let updated_contact = contact_for_user(user_id, busy, &pool);
514 for contact in contacts {
515 if let db::Contact::Accepted {
516 user_id: contact_user_id,
517 ..
518 } = contact
519 {
520 for contact_conn_id in
521 pool.user_connection_ids(contact_user_id)
522 {
523 peer.send(
524 contact_conn_id,
525 proto::UpdateContacts {
526 contacts: vec![updated_contact.clone()],
527 remove_contacts: Default::default(),
528 incoming_requests: Default::default(),
529 remove_incoming_requests: Default::default(),
530 outgoing_requests: Default::default(),
531 remove_outgoing_requests: Default::default(),
532 },
533 )
534 .trace_err();
535 }
536 }
537 }
538 }
539 }
540
541 if let Some(live_kit) = livekit_client.as_ref() {
542 if delete_livekit_room {
543 live_kit.delete_room(livekit_room).await.trace_err();
544 }
545 }
546 }
547 }
548
549 app_state
550 .db
551 .delete_stale_servers(&app_state.config.zed_environment, server_id)
552 .await
553 .trace_err();
554 }
555 .instrument(span),
556 );
557 Ok(())
558 }
559
560 pub fn teardown(&self) {
561 self.peer.teardown();
562 self.connection_pool.lock().reset();
563 let _ = self.teardown.send(true);
564 }
565
566 #[cfg(test)]
567 pub fn reset(&self, id: ServerId) {
568 self.teardown();
569 *self.id.lock() = id;
570 self.peer.reset(id.0 as u32);
571 let _ = self.teardown.send(false);
572 }
573
574 #[cfg(test)]
575 pub fn id(&self) -> ServerId {
576 *self.id.lock()
577 }
578
579 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
580 where
581 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
582 Fut: 'static + Send + Future<Output = Result<()>>,
583 M: EnvelopedMessage,
584 {
585 let prev_handler = self.handlers.insert(
586 TypeId::of::<M>(),
587 Box::new(move |envelope, session| {
588 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
589 let received_at = envelope.received_at;
590 tracing::info!("message received");
591 let start_time = Instant::now();
592 let future = (handler)(*envelope, session);
593 async move {
594 let result = future.await;
595 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
596 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
597 let queue_duration_ms = total_duration_ms - processing_duration_ms;
598 let payload_type = M::NAME;
599
600 match result {
601 Err(error) => {
602 tracing::error!(
603 ?error,
604 total_duration_ms,
605 processing_duration_ms,
606 queue_duration_ms,
607 payload_type,
608 "error handling message"
609 )
610 }
611 Ok(()) => tracing::info!(
612 total_duration_ms,
613 processing_duration_ms,
614 queue_duration_ms,
615 "finished handling message"
616 ),
617 }
618 }
619 .boxed()
620 }),
621 );
622 if prev_handler.is_some() {
623 panic!("registered a handler for the same message twice");
624 }
625 self
626 }
627
628 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
629 where
630 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
631 Fut: 'static + Send + Future<Output = Result<()>>,
632 M: EnvelopedMessage,
633 {
634 self.add_handler(move |envelope, session| handler(envelope.payload, session));
635 self
636 }
637
638 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
639 where
640 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
641 Fut: Send + Future<Output = Result<()>>,
642 M: RequestMessage,
643 {
644 let handler = Arc::new(handler);
645 self.add_handler(move |envelope, session| {
646 let receipt = envelope.receipt();
647 let handler = handler.clone();
648 async move {
649 let peer = session.peer.clone();
650 let responded = Arc::new(AtomicBool::default());
651 let response = Response {
652 peer: peer.clone(),
653 responded: responded.clone(),
654 receipt,
655 };
656 match (handler)(envelope.payload, response, session).await {
657 Ok(()) => {
658 if responded.load(std::sync::atomic::Ordering::SeqCst) {
659 Ok(())
660 } else {
661 Err(anyhow!("handler did not send a response"))?
662 }
663 }
664 Err(error) => {
665 let proto_err = match &error {
666 Error::Internal(err) => err.to_proto(),
667 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
668 };
669 peer.respond_with_error(receipt, proto_err)?;
670 Err(error)
671 }
672 }
673 }
674 })
675 }
676
677 pub fn handle_connection(
678 self: &Arc<Self>,
679 connection: Connection,
680 address: String,
681 principal: Principal,
682 zed_version: ZedVersion,
683 geoip_country_code: Option<String>,
684 system_id: Option<String>,
685 send_connection_id: Option<oneshot::Sender<ConnectionId>>,
686 executor: Executor,
687 ) -> impl Future<Output = ()> + use<> {
688 let this = self.clone();
689 let span = info_span!("handle connection", %address,
690 connection_id=field::Empty,
691 user_id=field::Empty,
692 login=field::Empty,
693 impersonator=field::Empty,
694 geoip_country_code=field::Empty
695 );
696 principal.update_span(&span);
697 if let Some(country_code) = geoip_country_code.as_ref() {
698 span.record("geoip_country_code", country_code);
699 }
700
701 let mut teardown = self.teardown.subscribe();
702 async move {
703 if *teardown.borrow() {
704 tracing::error!("server is tearing down");
705 return
706 }
707 let (connection_id, handle_io, mut incoming_rx) = this
708 .peer
709 .add_connection(connection, {
710 let executor = executor.clone();
711 move |duration| executor.sleep(duration)
712 });
713 tracing::Span::current().record("connection_id", format!("{}", connection_id));
714
715 tracing::info!("connection opened");
716
717 let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
718 let http_client = match ReqwestClient::user_agent(&user_agent) {
719 Ok(http_client) => Arc::new(http_client),
720 Err(error) => {
721 tracing::error!(?error, "failed to create HTTP client");
722 return;
723 }
724 };
725
726 let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
727 supermaven_admin_api_key.to_string(),
728 http_client.clone(),
729 )));
730
731 let session = Session {
732 principal: principal.clone(),
733 connection_id,
734 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
735 peer: this.peer.clone(),
736 connection_pool: this.connection_pool.clone(),
737 app_state: this.app_state.clone(),
738 geoip_country_code,
739 system_id,
740 _executor: executor.clone(),
741 supermaven_client,
742 };
743
744 if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
745 tracing::error!(?error, "failed to send initial client update");
746 return;
747 }
748
749 let handle_io = handle_io.fuse();
750 futures::pin_mut!(handle_io);
751
752 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
753 // This prevents deadlocks when e.g., client A performs a request to client B and
754 // client B performs a request to client A. If both clients stop processing further
755 // messages until their respective request completes, they won't have a chance to
756 // respond to the other client's request and cause a deadlock.
757 //
758 // This arrangement ensures we will attempt to process earlier messages first, but fall
759 // back to processing messages arrived later in the spirit of making progress.
760 let mut foreground_message_handlers = FuturesUnordered::new();
761 let concurrent_handlers = Arc::new(Semaphore::new(256));
762 loop {
763 let next_message = async {
764 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
765 let message = incoming_rx.next().await;
766 (permit, message)
767 }.fuse();
768 futures::pin_mut!(next_message);
769 futures::select_biased! {
770 _ = teardown.changed().fuse() => return,
771 result = handle_io => {
772 if let Err(error) = result {
773 tracing::error!(?error, "error handling I/O");
774 }
775 break;
776 }
777 _ = foreground_message_handlers.next() => {}
778 next_message = next_message => {
779 let (permit, message) = next_message;
780 if let Some(message) = message {
781 let type_name = message.payload_type_name();
782 // note: we copy all the fields from the parent span so we can query them in the logs.
783 // (https://github.com/tokio-rs/tracing/issues/2670).
784 let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
785 user_id=field::Empty,
786 login=field::Empty,
787 impersonator=field::Empty,
788 );
789 principal.update_span(&span);
790 let span_enter = span.enter();
791 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
792 let is_background = message.is_background();
793 let handle_message = (handler)(message, session.clone());
794 drop(span_enter);
795
796 let handle_message = async move {
797 handle_message.await;
798 drop(permit);
799 }.instrument(span);
800 if is_background {
801 executor.spawn_detached(handle_message);
802 } else {
803 foreground_message_handlers.push(handle_message);
804 }
805 } else {
806 tracing::error!("no message handler");
807 }
808 } else {
809 tracing::info!("connection closed");
810 break;
811 }
812 }
813 }
814 }
815
816 drop(foreground_message_handlers);
817 tracing::info!("signing out");
818 if let Err(error) = connection_lost(session, teardown, executor).await {
819 tracing::error!(?error, "error signing out");
820 }
821
822 }.instrument(span)
823 }
824
825 async fn send_initial_client_update(
826 &self,
827 connection_id: ConnectionId,
828 principal: &Principal,
829 zed_version: ZedVersion,
830 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
831 session: &Session,
832 ) -> Result<()> {
833 self.peer.send(
834 connection_id,
835 proto::Hello {
836 peer_id: Some(connection_id.into()),
837 },
838 )?;
839 tracing::info!("sent hello message");
840 if let Some(send_connection_id) = send_connection_id.take() {
841 let _ = send_connection_id.send(connection_id);
842 }
843
844 match principal {
845 Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
846 if !user.connected_once {
847 self.peer.send(connection_id, proto::ShowContacts {})?;
848 self.app_state
849 .db
850 .set_user_connected_once(user.id, true)
851 .await?;
852 }
853
854 update_user_plan(user.id, session).await?;
855
856 let contacts = self.app_state.db.get_contacts(user.id).await?;
857
858 {
859 let mut pool = self.connection_pool.lock();
860 pool.add_connection(connection_id, user.id, user.admin, zed_version);
861 self.peer.send(
862 connection_id,
863 build_initial_contacts_update(contacts, &pool),
864 )?;
865 }
866
867 if should_auto_subscribe_to_channels(zed_version) {
868 subscribe_user_to_channels(user.id, session).await?;
869 }
870
871 if let Some(incoming_call) =
872 self.app_state.db.incoming_call_for_user(user.id).await?
873 {
874 self.peer.send(connection_id, incoming_call)?;
875 }
876
877 update_user_contacts(user.id, session).await?;
878 }
879 }
880
881 Ok(())
882 }
883
884 pub async fn invite_code_redeemed(
885 self: &Arc<Self>,
886 inviter_id: UserId,
887 invitee_id: UserId,
888 ) -> Result<()> {
889 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
890 if let Some(code) = &user.invite_code {
891 let pool = self.connection_pool.lock();
892 let invitee_contact = contact_for_user(invitee_id, false, &pool);
893 for connection_id in pool.user_connection_ids(inviter_id) {
894 self.peer.send(
895 connection_id,
896 proto::UpdateContacts {
897 contacts: vec![invitee_contact.clone()],
898 ..Default::default()
899 },
900 )?;
901 self.peer.send(
902 connection_id,
903 proto::UpdateInviteInfo {
904 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
905 count: user.invite_count as u32,
906 },
907 )?;
908 }
909 }
910 }
911 Ok(())
912 }
913
914 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
915 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
916 if let Some(invite_code) = &user.invite_code {
917 let pool = self.connection_pool.lock();
918 for connection_id in pool.user_connection_ids(user_id) {
919 self.peer.send(
920 connection_id,
921 proto::UpdateInviteInfo {
922 url: format!(
923 "{}{}",
924 self.app_state.config.invite_link_prefix, invite_code
925 ),
926 count: user.invite_count as u32,
927 },
928 )?;
929 }
930 }
931 }
932 Ok(())
933 }
934
935 pub async fn update_plan_for_user(self: &Arc<Self>, user_id: UserId) -> Result<()> {
936 let user = self
937 .app_state
938 .db
939 .get_user_by_id(user_id)
940 .await?
941 .ok_or_else(|| anyhow!("user not found"))?;
942
943 let update_user_plan = make_update_user_plan_message(
944 &self.app_state.db,
945 self.app_state.llm_db.clone(),
946 user_id,
947 user.admin,
948 )
949 .await?;
950
951 let pool = self.connection_pool.lock();
952 for connection_id in pool.user_connection_ids(user_id) {
953 self.peer
954 .send(connection_id, update_user_plan.clone())
955 .trace_err();
956 }
957
958 Ok(())
959 }
960
961 pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
962 let pool = self.connection_pool.lock();
963 for connection_id in pool.user_connection_ids(user_id) {
964 self.peer
965 .send(connection_id, proto::RefreshLlmToken {})
966 .trace_err();
967 }
968 }
969
970 pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot {
971 ServerSnapshot {
972 connection_pool: ConnectionPoolGuard {
973 guard: self.connection_pool.lock(),
974 _not_send: PhantomData,
975 },
976 peer: &self.peer,
977 }
978 }
979}
980
981impl Deref for ConnectionPoolGuard<'_> {
982 type Target = ConnectionPool;
983
984 fn deref(&self) -> &Self::Target {
985 &self.guard
986 }
987}
988
989impl DerefMut for ConnectionPoolGuard<'_> {
990 fn deref_mut(&mut self) -> &mut Self::Target {
991 &mut self.guard
992 }
993}
994
995impl Drop for ConnectionPoolGuard<'_> {
996 fn drop(&mut self) {
997 #[cfg(test)]
998 self.check_invariants();
999 }
1000}
1001
1002fn broadcast<F>(
1003 sender_id: Option<ConnectionId>,
1004 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1005 mut f: F,
1006) where
1007 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1008{
1009 for receiver_id in receiver_ids {
1010 if Some(receiver_id) != sender_id {
1011 if let Err(error) = f(receiver_id) {
1012 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1013 }
1014 }
1015 }
1016}
1017
1018pub struct ProtocolVersion(u32);
1019
1020impl Header for ProtocolVersion {
1021 fn name() -> &'static HeaderName {
1022 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1023 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1024 }
1025
1026 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1027 where
1028 Self: Sized,
1029 I: Iterator<Item = &'i axum::http::HeaderValue>,
1030 {
1031 let version = values
1032 .next()
1033 .ok_or_else(axum::headers::Error::invalid)?
1034 .to_str()
1035 .map_err(|_| axum::headers::Error::invalid())?
1036 .parse()
1037 .map_err(|_| axum::headers::Error::invalid())?;
1038 Ok(Self(version))
1039 }
1040
1041 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1042 values.extend([self.0.to_string().parse().unwrap()]);
1043 }
1044}
1045
1046pub struct AppVersionHeader(SemanticVersion);
1047impl Header for AppVersionHeader {
1048 fn name() -> &'static HeaderName {
1049 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1050 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1051 }
1052
1053 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1054 where
1055 Self: Sized,
1056 I: Iterator<Item = &'i axum::http::HeaderValue>,
1057 {
1058 let version = values
1059 .next()
1060 .ok_or_else(axum::headers::Error::invalid)?
1061 .to_str()
1062 .map_err(|_| axum::headers::Error::invalid())?
1063 .parse()
1064 .map_err(|_| axum::headers::Error::invalid())?;
1065 Ok(Self(version))
1066 }
1067
1068 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1069 values.extend([self.0.to_string().parse().unwrap()]);
1070 }
1071}
1072
1073pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1074 Router::new()
1075 .route("/rpc", get(handle_websocket_request))
1076 .layer(
1077 ServiceBuilder::new()
1078 .layer(Extension(server.app_state.clone()))
1079 .layer(middleware::from_fn(auth::validate_header)),
1080 )
1081 .route("/metrics", get(handle_metrics))
1082 .layer(Extension(server))
1083}
1084
1085pub async fn handle_websocket_request(
1086 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1087 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1088 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1089 Extension(server): Extension<Arc<Server>>,
1090 Extension(principal): Extension<Principal>,
1091 country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1092 system_id_header: Option<TypedHeader<SystemIdHeader>>,
1093 ws: WebSocketUpgrade,
1094) -> axum::response::Response {
1095 if protocol_version != rpc::PROTOCOL_VERSION {
1096 return (
1097 StatusCode::UPGRADE_REQUIRED,
1098 "client must be upgraded".to_string(),
1099 )
1100 .into_response();
1101 }
1102
1103 let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1104 return (
1105 StatusCode::UPGRADE_REQUIRED,
1106 "no version header found".to_string(),
1107 )
1108 .into_response();
1109 };
1110
1111 if !version.can_collaborate() {
1112 return (
1113 StatusCode::UPGRADE_REQUIRED,
1114 "client must be upgraded".to_string(),
1115 )
1116 .into_response();
1117 }
1118
1119 let socket_address = socket_address.to_string();
1120 ws.on_upgrade(move |socket| {
1121 let socket = socket
1122 .map_ok(to_tungstenite_message)
1123 .err_into()
1124 .with(|message| async move { to_axum_message(message) });
1125 let connection = Connection::new(Box::pin(socket));
1126 async move {
1127 server
1128 .handle_connection(
1129 connection,
1130 socket_address,
1131 principal,
1132 version,
1133 country_code_header.map(|header| header.to_string()),
1134 system_id_header.map(|header| header.to_string()),
1135 None,
1136 Executor::Production,
1137 )
1138 .await;
1139 }
1140 })
1141}
1142
1143pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1144 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1145 let connections_metric = CONNECTIONS_METRIC
1146 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1147
1148 let connections = server
1149 .connection_pool
1150 .lock()
1151 .connections()
1152 .filter(|connection| !connection.admin)
1153 .count();
1154 connections_metric.set(connections as _);
1155
1156 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1157 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1158 register_int_gauge!(
1159 "shared_projects",
1160 "number of open projects with one or more guests"
1161 )
1162 .unwrap()
1163 });
1164
1165 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1166 shared_projects_metric.set(shared_projects as _);
1167
1168 let encoder = prometheus::TextEncoder::new();
1169 let metric_families = prometheus::gather();
1170 let encoded_metrics = encoder
1171 .encode_to_string(&metric_families)
1172 .map_err(|err| anyhow!("{}", err))?;
1173 Ok(encoded_metrics)
1174}
1175
/// Cleans up after a connection has dropped.
///
/// The connection is disconnected from the peer and removed from the
/// connection pool immediately, and the database is told that the connection
/// was lost. The function then waits for whichever happens first:
/// `RECONNECT_TIMEOUT` elapsing or the `teardown` watch channel changing.
/// Only when the timeout branch runs are the session's room and channel
/// buffers left, a pending call declined (if the user has no other
/// connection online) and the user's contacts updated.
///
/// Returns an error if removing the connection from the pool fails or if the
/// final contacts update fails; the other steps pass their errors to
/// `trace_err` instead of propagating them.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // Failure to record the lost connection is not fatal for the teardown.
    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls its branches in the order written, so the timeout
    // branch is checked before the teardown branch.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline a pending call when the user has no other connection
            // left in the pool.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        // A teardown signal ends the wait without running the cleanup above.
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1222
1223/// Acknowledges a ping from a client, used to keep the connection alive.
1224async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1225 response.send(proto::Ack {})?;
1226 Ok(())
1227}
1228
1229/// Creates a new room for calling (outside of channels)
1230async fn create_room(
1231 _request: proto::CreateRoom,
1232 response: Response<proto::CreateRoom>,
1233 session: Session,
1234) -> Result<()> {
1235 let livekit_room = nanoid::nanoid!(30);
1236
1237 let live_kit_connection_info = util::maybe!(async {
1238 let live_kit = session.app_state.livekit_client.as_ref();
1239 let live_kit = live_kit?;
1240 let user_id = session.user_id().to_string();
1241
1242 let token = live_kit
1243 .room_token(&livekit_room, &user_id.to_string())
1244 .trace_err()?;
1245
1246 Some(proto::LiveKitConnectionInfo {
1247 server_url: live_kit.url().into(),
1248 token,
1249 can_publish: true,
1250 })
1251 })
1252 .await;
1253
1254 let room = session
1255 .db()
1256 .await
1257 .create_room(session.user_id(), session.connection_id, &livekit_room)
1258 .await?;
1259
1260 response.send(proto::CreateRoomResponse {
1261 room: Some(room.clone()),
1262 live_kit_connection_info,
1263 })?;
1264
1265 update_user_contacts(session.user_id(), &session).await?;
1266 Ok(())
1267}
1268
1269/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1270async fn join_room(
1271 request: proto::JoinRoom,
1272 response: Response<proto::JoinRoom>,
1273 session: Session,
1274) -> Result<()> {
1275 let room_id = RoomId::from_proto(request.id);
1276
1277 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1278
1279 if let Some(channel_id) = channel_id {
1280 return join_channel_internal(channel_id, Box::new(response), session).await;
1281 }
1282
1283 let joined_room = {
1284 let room = session
1285 .db()
1286 .await
1287 .join_room(room_id, session.user_id(), session.connection_id)
1288 .await?;
1289 room_updated(&room.room, &session.peer);
1290 room.into_inner()
1291 };
1292
1293 for connection_id in session
1294 .connection_pool()
1295 .await
1296 .user_connection_ids(session.user_id())
1297 {
1298 session
1299 .peer
1300 .send(
1301 connection_id,
1302 proto::CallCanceled {
1303 room_id: room_id.to_proto(),
1304 },
1305 )
1306 .trace_err();
1307 }
1308
1309 let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1310 {
1311 live_kit
1312 .room_token(
1313 &joined_room.room.livekit_room,
1314 &session.user_id().to_string(),
1315 )
1316 .trace_err()
1317 .map(|token| proto::LiveKitConnectionInfo {
1318 server_url: live_kit.url().into(),
1319 token,
1320 can_publish: true,
1321 })
1322 } else {
1323 None
1324 };
1325
1326 response.send(proto::JoinRoomResponse {
1327 room: Some(joined_room.room),
1328 channel_id: None,
1329 live_kit_connection_info,
1330 })?;
1331
1332 update_user_contacts(session.user_id(), &session).await?;
1333 Ok(())
1334}
1335
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// The response (room, reshared projects, rejoined projects) is sent first.
/// Then the room's participants are notified, collaborators of every reshared
/// project are told about the new peer id and receive the project's worktree
/// metadata, and the rejoined projects' state is streamed back to this
/// connection. Finally the channel (if any) and the user's contacts are
/// updated.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    // Declared up front so they outlive the block below, in which the value
    // returned by `rejoin_room` is consumed.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell each collaborator that this session's old peer id has been
            // replaced by the new connection id.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the project's worktrees to the collaborators; this
            // session's own connection is skipped by `broadcast`.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1427
/// Announces a reconnecting guest to project collaborators and replays project
/// state to the guest's new connection.
///
/// First, every collaborator of every rejoined project is told that the old
/// peer id has been replaced by `session.connection_id`; failures there are
/// passed to `trace_err`. Then, for each project, the worktree updates,
/// diagnostic summaries, worktree settings, repository updates and repository
/// removals are sent to the session's own connection; the first failure in
/// this phase is returned as the error.
///
/// The worktree and repository collections of each project are emptied
/// (`mem::take`) as they are sent.
fn notify_rejoined_projects(
    rejoined_projects: &mut Vec<RejoinedProject>,
    session: &Session,
) -> Result<()> {
    for project in rejoined_projects.iter() {
        for collaborator in &project.collaborators {
            session
                .peer
                .send(
                    collaborator.connection_id,
                    proto::UpdateProjectCollaborator {
                        project_id: project.id.to_proto(),
                        old_peer_id: Some(project.old_connection_id.into()),
                        new_peer_id: Some(session.connection_id.into()),
                    },
                )
                .trace_err();
        }
    }

    for project in rejoined_projects {
        for worktree in mem::take(&mut project.worktrees) {
            // Stream this worktree's entries.
            let message = proto::UpdateWorktree {
                project_id: project.id.to_proto(),
                worktree_id: worktree.id,
                abs_path: worktree.abs_path.clone(),
                root_name: worktree.root_name,
                updated_entries: worktree.updated_entries,
                removed_entries: worktree.removed_entries,
                scan_id: worktree.scan_id,
                is_last_update: worktree.completed_scan_id == worktree.scan_id,
                updated_repositories: worktree.updated_repositories,
                removed_repositories: worktree.removed_repositories,
            };
            // The update is split into several messages before being sent.
            for update in proto::split_worktree_update(message) {
                session.peer.send(session.connection_id, update)?;
            }

            // Stream this worktree's diagnostics.
            for summary in worktree.diagnostic_summaries {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateDiagnosticSummary {
                        project_id: project.id.to_proto(),
                        worktree_id: worktree.id,
                        summary: Some(summary),
                    },
                )?;
            }

            // Stream this worktree's settings files.
            for settings_file in worktree.settings_files {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateWorktreeSettings {
                        project_id: project.id.to_proto(),
                        worktree_id: worktree.id,
                        path: settings_file.path,
                        content: Some(settings_file.content),
                        kind: Some(settings_file.kind.to_proto().into()),
                    },
                )?;
            }
        }

        // Send the repositories that were updated.
        for repository in mem::take(&mut project.updated_repositories) {
            for update in split_repository_update(repository) {
                session.peer.send(session.connection_id, update)?;
            }
        }

        // Send the ids of the repositories that were removed.
        for id in mem::take(&mut project.removed_repositories) {
            session.peer.send(
                session.connection_id,
                proto::RemoveRepository {
                    project_id: project.id.to_proto(),
                    id,
                },
            )?;
        }
    }

    Ok(())
}
1512
1513/// leave room disconnects from the room.
1514async fn leave_room(
1515 _: proto::LeaveRoom,
1516 response: Response<proto::LeaveRoom>,
1517 session: Session,
1518) -> Result<()> {
1519 leave_room_for_session(&session, session.connection_id).await?;
1520 response.send(proto::Ack {})?;
1521 Ok(())
1522}
1523
1524/// Updates the permissions of someone else in the room.
1525async fn set_room_participant_role(
1526 request: proto::SetRoomParticipantRole,
1527 response: Response<proto::SetRoomParticipantRole>,
1528 session: Session,
1529) -> Result<()> {
1530 let user_id = UserId::from_proto(request.user_id);
1531 let role = ChannelRole::from(request.role());
1532
1533 let (livekit_room, can_publish) = {
1534 let room = session
1535 .db()
1536 .await
1537 .set_room_participant_role(
1538 session.user_id(),
1539 RoomId::from_proto(request.room_id),
1540 user_id,
1541 role,
1542 )
1543 .await?;
1544
1545 let livekit_room = room.livekit_room.clone();
1546 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1547 room_updated(&room, &session.peer);
1548 (livekit_room, can_publish)
1549 };
1550
1551 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1552 live_kit
1553 .update_participant(
1554 livekit_room.clone(),
1555 request.user_id.to_string(),
1556 livekit_api::proto::ParticipantPermission {
1557 can_subscribe: true,
1558 can_publish,
1559 can_publish_data: can_publish,
1560 hidden: false,
1561 recorder: false,
1562 },
1563 )
1564 .await
1565 .trace_err();
1566 }
1567
1568 response.send(proto::Ack {})?;
1569 Ok(())
1570}
1571
/// Call someone else into the current room
///
/// The called user must be a contact of the caller. The call is recorded in
/// the database, the room's participants and the called user's contacts are
/// updated, and the incoming call is sent as a request to every connection of
/// the called user. The request is acknowledged as soon as any one of those
/// connections responds successfully. If none does, the call is marked as
/// failed, the room and contacts are updated again, and an error is returned.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: Session,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(room, &session.peer);
        // Move the message out of the borrowed tuple, leaving a default value behind.
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Build one request per connection of the called user. The requests are
    // collected into a `FuturesUnordered`, so responses are handled in
    // whatever order they complete.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            // The first successful response is enough to acknowledge the call.
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    // No connection answered successfully (or the user had none).
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1640
1641/// Cancel an outgoing call.
1642async fn cancel_call(
1643 request: proto::CancelCall,
1644 response: Response<proto::CancelCall>,
1645 session: Session,
1646) -> Result<()> {
1647 let called_user_id = UserId::from_proto(request.called_user_id);
1648 let room_id = RoomId::from_proto(request.room_id);
1649 {
1650 let room = session
1651 .db()
1652 .await
1653 .cancel_call(room_id, session.connection_id, called_user_id)
1654 .await?;
1655 room_updated(&room, &session.peer);
1656 }
1657
1658 for connection_id in session
1659 .connection_pool()
1660 .await
1661 .user_connection_ids(called_user_id)
1662 {
1663 session
1664 .peer
1665 .send(
1666 connection_id,
1667 proto::CallCanceled {
1668 room_id: room_id.to_proto(),
1669 },
1670 )
1671 .trace_err();
1672 }
1673 response.send(proto::Ack {})?;
1674
1675 update_user_contacts(called_user_id, &session).await?;
1676 Ok(())
1677}
1678
1679/// Decline an incoming call.
1680async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1681 let room_id = RoomId::from_proto(message.room_id);
1682 {
1683 let room = session
1684 .db()
1685 .await
1686 .decline_call(Some(room_id), session.user_id())
1687 .await?
1688 .ok_or_else(|| anyhow!("failed to decline call"))?;
1689 room_updated(&room, &session.peer);
1690 }
1691
1692 for connection_id in session
1693 .connection_pool()
1694 .await
1695 .user_connection_ids(session.user_id())
1696 {
1697 session
1698 .peer
1699 .send(
1700 connection_id,
1701 proto::CallCanceled {
1702 room_id: room_id.to_proto(),
1703 },
1704 )
1705 .trace_err();
1706 }
1707 update_user_contacts(session.user_id(), &session).await?;
1708 Ok(())
1709}
1710
1711/// Updates other participants in the room with your current location.
1712async fn update_participant_location(
1713 request: proto::UpdateParticipantLocation,
1714 response: Response<proto::UpdateParticipantLocation>,
1715 session: Session,
1716) -> Result<()> {
1717 let room_id = RoomId::from_proto(request.room_id);
1718 let location = request
1719 .location
1720 .ok_or_else(|| anyhow!("invalid location"))?;
1721
1722 let db = session.db().await;
1723 let room = db
1724 .update_room_participant_location(room_id, session.connection_id, location)
1725 .await?;
1726
1727 room_updated(&room, &session.peer);
1728 response.send(proto::Ack {})?;
1729 Ok(())
1730}
1731
1732/// Share a project into the room.
1733async fn share_project(
1734 request: proto::ShareProject,
1735 response: Response<proto::ShareProject>,
1736 session: Session,
1737) -> Result<()> {
1738 let (project_id, room) = &*session
1739 .db()
1740 .await
1741 .share_project(
1742 RoomId::from_proto(request.room_id),
1743 session.connection_id,
1744 &request.worktrees,
1745 request.is_ssh_project,
1746 )
1747 .await?;
1748 response.send(proto::ShareProjectResponse {
1749 project_id: project_id.to_proto(),
1750 })?;
1751 room_updated(room, &session.peer);
1752
1753 Ok(())
1754}
1755
1756/// Unshare a project from the room.
1757async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1758 let project_id = ProjectId::from_proto(message.project_id);
1759 unshare_project_internal(project_id, session.connection_id, &session).await
1760}
1761
1762async fn unshare_project_internal(
1763 project_id: ProjectId,
1764 connection_id: ConnectionId,
1765 session: &Session,
1766) -> Result<()> {
1767 let delete = {
1768 let room_guard = session
1769 .db()
1770 .await
1771 .unshare_project(project_id, connection_id)
1772 .await?;
1773
1774 let (delete, room, guest_connection_ids) = &*room_guard;
1775
1776 let message = proto::UnshareProject {
1777 project_id: project_id.to_proto(),
1778 };
1779
1780 broadcast(
1781 Some(connection_id),
1782 guest_connection_ids.iter().copied(),
1783 |conn_id| session.peer.send(conn_id, message.clone()),
1784 );
1785 if let Some(room) = room {
1786 room_updated(room, &session.peer);
1787 }
1788
1789 *delete
1790 };
1791
1792 if delete {
1793 let db = session.db().await;
1794 db.delete_project(project_id).await?;
1795 }
1796
1797 Ok(())
1798}
1799
1800/// Join someone elses shared project.
1801async fn join_project(
1802 request: proto::JoinProject,
1803 response: Response<proto::JoinProject>,
1804 session: Session,
1805) -> Result<()> {
1806 let project_id = ProjectId::from_proto(request.project_id);
1807
1808 tracing::info!(%project_id, "join project");
1809
1810 let db = session.db().await;
1811 let (project, replica_id) = &mut *db
1812 .join_project(project_id, session.connection_id, session.user_id())
1813 .await?;
1814 drop(db);
1815 tracing::info!(%project_id, "join remote project");
1816 join_project_internal(response, session, project, replica_id)
1817}
1818
/// Abstraction over "something that can deliver a `JoinProjectResponse`", so
/// that `join_project_internal` does not depend on a concrete response type.
trait JoinProjectInternalResponse {
    /// Consumes the responder and delivers `result`.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Path-qualified call; presumably this resolves to `Response`'s own
        // `send` rather than recursing into this trait method. Verify before
        // rewriting it as `self.send(result)`.
        Response::<proto::JoinProject>::send(self, result)
    }
}
1827
1828fn join_project_internal(
1829 response: impl JoinProjectInternalResponse,
1830 session: Session,
1831 project: &mut Project,
1832 replica_id: &ReplicaId,
1833) -> Result<()> {
1834 let collaborators = project
1835 .collaborators
1836 .iter()
1837 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1838 .map(|collaborator| collaborator.to_proto())
1839 .collect::<Vec<_>>();
1840 let project_id = project.id;
1841 let guest_user_id = session.user_id();
1842
1843 let worktrees = project
1844 .worktrees
1845 .iter()
1846 .map(|(id, worktree)| proto::WorktreeMetadata {
1847 id: *id,
1848 root_name: worktree.root_name.clone(),
1849 visible: worktree.visible,
1850 abs_path: worktree.abs_path.clone(),
1851 })
1852 .collect::<Vec<_>>();
1853
1854 let add_project_collaborator = proto::AddProjectCollaborator {
1855 project_id: project_id.to_proto(),
1856 collaborator: Some(proto::Collaborator {
1857 peer_id: Some(session.connection_id.into()),
1858 replica_id: replica_id.0 as u32,
1859 user_id: guest_user_id.to_proto(),
1860 is_host: false,
1861 }),
1862 };
1863
1864 for collaborator in &collaborators {
1865 session
1866 .peer
1867 .send(
1868 collaborator.peer_id.unwrap().into(),
1869 add_project_collaborator.clone(),
1870 )
1871 .trace_err();
1872 }
1873
1874 // First, we send the metadata associated with each worktree.
1875 response.send(proto::JoinProjectResponse {
1876 project_id: project.id.0 as u64,
1877 worktrees: worktrees.clone(),
1878 replica_id: replica_id.0 as u32,
1879 collaborators: collaborators.clone(),
1880 language_servers: project.language_servers.clone(),
1881 role: project.role.into(),
1882 })?;
1883
1884 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1885 // Stream this worktree's entries.
1886 let message = proto::UpdateWorktree {
1887 project_id: project_id.to_proto(),
1888 worktree_id,
1889 abs_path: worktree.abs_path.clone(),
1890 root_name: worktree.root_name,
1891 updated_entries: worktree.entries,
1892 removed_entries: Default::default(),
1893 scan_id: worktree.scan_id,
1894 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1895 updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1896 removed_repositories: Default::default(),
1897 };
1898 for update in proto::split_worktree_update(message) {
1899 session.peer.send(session.connection_id, update.clone())?;
1900 }
1901
1902 // Stream this worktree's diagnostics.
1903 for summary in worktree.diagnostic_summaries {
1904 session.peer.send(
1905 session.connection_id,
1906 proto::UpdateDiagnosticSummary {
1907 project_id: project_id.to_proto(),
1908 worktree_id: worktree.id,
1909 summary: Some(summary),
1910 },
1911 )?;
1912 }
1913
1914 for settings_file in worktree.settings_files {
1915 session.peer.send(
1916 session.connection_id,
1917 proto::UpdateWorktreeSettings {
1918 project_id: project_id.to_proto(),
1919 worktree_id: worktree.id,
1920 path: settings_file.path,
1921 content: Some(settings_file.content),
1922 kind: Some(settings_file.kind.to_proto() as i32),
1923 },
1924 )?;
1925 }
1926 }
1927
1928 for repository in mem::take(&mut project.repositories) {
1929 for update in split_repository_update(repository) {
1930 session.peer.send(session.connection_id, update)?;
1931 }
1932 }
1933
1934 for language_server in &project.language_servers {
1935 session.peer.send(
1936 session.connection_id,
1937 proto::UpdateLanguageServer {
1938 project_id: project_id.to_proto(),
1939 language_server_id: language_server.id,
1940 variant: Some(
1941 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1942 proto::LspDiskBasedDiagnosticsUpdated {},
1943 ),
1944 ),
1945 },
1946 )?;
1947 }
1948
1949 Ok(())
1950}
1951
1952/// Leave someone elses shared project.
1953async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1954 let sender_id = session.connection_id;
1955 let project_id = ProjectId::from_proto(request.project_id);
1956 let db = session.db().await;
1957
1958 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1959 tracing::info!(
1960 %project_id,
1961 "leave project"
1962 );
1963
1964 project_left(project, &session);
1965 if let Some(room) = room {
1966 room_updated(room, &session.peer);
1967 }
1968
1969 Ok(())
1970}
1971
1972/// Updates other participants with changes to the project
1973async fn update_project(
1974 request: proto::UpdateProject,
1975 response: Response<proto::UpdateProject>,
1976 session: Session,
1977) -> Result<()> {
1978 let project_id = ProjectId::from_proto(request.project_id);
1979 let (room, guest_connection_ids) = &*session
1980 .db()
1981 .await
1982 .update_project(project_id, session.connection_id, &request.worktrees)
1983 .await?;
1984 broadcast(
1985 Some(session.connection_id),
1986 guest_connection_ids.iter().copied(),
1987 |connection_id| {
1988 session
1989 .peer
1990 .forward_send(session.connection_id, connection_id, request.clone())
1991 },
1992 );
1993 if let Some(room) = room {
1994 room_updated(room, &session.peer);
1995 }
1996 response.send(proto::Ack {})?;
1997
1998 Ok(())
1999}
2000
2001/// Updates other participants with changes to the worktree
2002async fn update_worktree(
2003 request: proto::UpdateWorktree,
2004 response: Response<proto::UpdateWorktree>,
2005 session: Session,
2006) -> Result<()> {
2007 let guest_connection_ids = session
2008 .db()
2009 .await
2010 .update_worktree(&request, session.connection_id)
2011 .await?;
2012
2013 broadcast(
2014 Some(session.connection_id),
2015 guest_connection_ids.iter().copied(),
2016 |connection_id| {
2017 session
2018 .peer
2019 .forward_send(session.connection_id, connection_id, request.clone())
2020 },
2021 );
2022 response.send(proto::Ack {})?;
2023 Ok(())
2024}
2025
2026async fn update_repository(
2027 request: proto::UpdateRepository,
2028 response: Response<proto::UpdateRepository>,
2029 session: Session,
2030) -> Result<()> {
2031 let guest_connection_ids = session
2032 .db()
2033 .await
2034 .update_repository(&request, session.connection_id)
2035 .await?;
2036
2037 broadcast(
2038 Some(session.connection_id),
2039 guest_connection_ids.iter().copied(),
2040 |connection_id| {
2041 session
2042 .peer
2043 .forward_send(session.connection_id, connection_id, request.clone())
2044 },
2045 );
2046 response.send(proto::Ack {})?;
2047 Ok(())
2048}
2049
2050async fn remove_repository(
2051 request: proto::RemoveRepository,
2052 response: Response<proto::RemoveRepository>,
2053 session: Session,
2054) -> Result<()> {
2055 let guest_connection_ids = session
2056 .db()
2057 .await
2058 .remove_repository(&request, session.connection_id)
2059 .await?;
2060
2061 broadcast(
2062 Some(session.connection_id),
2063 guest_connection_ids.iter().copied(),
2064 |connection_id| {
2065 session
2066 .peer
2067 .forward_send(session.connection_id, connection_id, request.clone())
2068 },
2069 );
2070 response.send(proto::Ack {})?;
2071 Ok(())
2072}
2073
2074/// Updates other participants with changes to the diagnostics
2075async fn update_diagnostic_summary(
2076 message: proto::UpdateDiagnosticSummary,
2077 session: Session,
2078) -> Result<()> {
2079 let guest_connection_ids = session
2080 .db()
2081 .await
2082 .update_diagnostic_summary(&message, session.connection_id)
2083 .await?;
2084
2085 broadcast(
2086 Some(session.connection_id),
2087 guest_connection_ids.iter().copied(),
2088 |connection_id| {
2089 session
2090 .peer
2091 .forward_send(session.connection_id, connection_id, message.clone())
2092 },
2093 );
2094
2095 Ok(())
2096}
2097
2098/// Updates other participants with changes to the worktree settings
2099async fn update_worktree_settings(
2100 message: proto::UpdateWorktreeSettings,
2101 session: Session,
2102) -> Result<()> {
2103 let guest_connection_ids = session
2104 .db()
2105 .await
2106 .update_worktree_settings(&message, session.connection_id)
2107 .await?;
2108
2109 broadcast(
2110 Some(session.connection_id),
2111 guest_connection_ids.iter().copied(),
2112 |connection_id| {
2113 session
2114 .peer
2115 .forward_send(session.connection_id, connection_id, message.clone())
2116 },
2117 );
2118
2119 Ok(())
2120}
2121
2122/// Notify other participants that a language server has started.
2123async fn start_language_server(
2124 request: proto::StartLanguageServer,
2125 session: Session,
2126) -> Result<()> {
2127 let guest_connection_ids = session
2128 .db()
2129 .await
2130 .start_language_server(&request, session.connection_id)
2131 .await?;
2132
2133 broadcast(
2134 Some(session.connection_id),
2135 guest_connection_ids.iter().copied(),
2136 |connection_id| {
2137 session
2138 .peer
2139 .forward_send(session.connection_id, connection_id, request.clone())
2140 },
2141 );
2142 Ok(())
2143}
2144
2145/// Notify other participants that a language server has changed.
2146async fn update_language_server(
2147 request: proto::UpdateLanguageServer,
2148 session: Session,
2149) -> Result<()> {
2150 let project_id = ProjectId::from_proto(request.project_id);
2151 let project_connection_ids = session
2152 .db()
2153 .await
2154 .project_connection_ids(project_id, session.connection_id, true)
2155 .await?;
2156 broadcast(
2157 Some(session.connection_id),
2158 project_connection_ids.iter().copied(),
2159 |connection_id| {
2160 session
2161 .peer
2162 .forward_send(session.connection_id, connection_id, request.clone())
2163 },
2164 );
2165 Ok(())
2166}
2167
2168/// forward a project request to the host. These requests should be read only
2169/// as guests are allowed to send them.
2170async fn forward_read_only_project_request<T>(
2171 request: T,
2172 response: Response<T>,
2173 session: Session,
2174) -> Result<()>
2175where
2176 T: EntityMessage + RequestMessage,
2177{
2178 let project_id = ProjectId::from_proto(request.remote_entity_id());
2179 let host_connection_id = session
2180 .db()
2181 .await
2182 .host_for_read_only_project_request(project_id, session.connection_id)
2183 .await?;
2184 let payload = session
2185 .peer
2186 .forward_request(session.connection_id, host_connection_id, request)
2187 .await?;
2188 response.send(payload)?;
2189 Ok(())
2190}
2191
2192async fn forward_find_search_candidates_request(
2193 request: proto::FindSearchCandidates,
2194 response: Response<proto::FindSearchCandidates>,
2195 session: Session,
2196) -> Result<()> {
2197 let project_id = ProjectId::from_proto(request.remote_entity_id());
2198 let host_connection_id = session
2199 .db()
2200 .await
2201 .host_for_read_only_project_request(project_id, session.connection_id)
2202 .await?;
2203 let payload = session
2204 .peer
2205 .forward_request(session.connection_id, host_connection_id, request)
2206 .await?;
2207 response.send(payload)?;
2208 Ok(())
2209}
2210
2211/// forward a project request to the host. These requests are disallowed
2212/// for guests.
2213async fn forward_mutating_project_request<T>(
2214 request: T,
2215 response: Response<T>,
2216 session: Session,
2217) -> Result<()>
2218where
2219 T: EntityMessage + RequestMessage,
2220{
2221 let project_id = ProjectId::from_proto(request.remote_entity_id());
2222
2223 let host_connection_id = session
2224 .db()
2225 .await
2226 .host_for_mutating_project_request(project_id, session.connection_id)
2227 .await?;
2228 let payload = session
2229 .peer
2230 .forward_request(session.connection_id, host_connection_id, request)
2231 .await?;
2232 response.send(payload)?;
2233 Ok(())
2234}
2235
2236/// Notify other participants that a new buffer has been created
2237async fn create_buffer_for_peer(
2238 request: proto::CreateBufferForPeer,
2239 session: Session,
2240) -> Result<()> {
2241 session
2242 .db()
2243 .await
2244 .check_user_is_project_host(
2245 ProjectId::from_proto(request.project_id),
2246 session.connection_id,
2247 )
2248 .await?;
2249 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2250 session
2251 .peer
2252 .forward_send(session.connection_id, peer_id.into(), request)?;
2253 Ok(())
2254}
2255
2256/// Notify other participants that a buffer has been updated. This is
2257/// allowed for guests as long as the update is limited to selections.
2258async fn update_buffer(
2259 request: proto::UpdateBuffer,
2260 response: Response<proto::UpdateBuffer>,
2261 session: Session,
2262) -> Result<()> {
2263 let project_id = ProjectId::from_proto(request.project_id);
2264 let mut capability = Capability::ReadOnly;
2265
2266 for op in request.operations.iter() {
2267 match op.variant {
2268 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2269 Some(_) => capability = Capability::ReadWrite,
2270 }
2271 }
2272
2273 let host = {
2274 let guard = session
2275 .db()
2276 .await
2277 .connections_for_buffer_update(project_id, session.connection_id, capability)
2278 .await?;
2279
2280 let (host, guests) = &*guard;
2281
2282 broadcast(
2283 Some(session.connection_id),
2284 guests.clone(),
2285 |connection_id| {
2286 session
2287 .peer
2288 .forward_send(session.connection_id, connection_id, request.clone())
2289 },
2290 );
2291
2292 *host
2293 };
2294
2295 if host != session.connection_id {
2296 session
2297 .peer
2298 .forward_request(session.connection_id, host, request.clone())
2299 .await?;
2300 }
2301
2302 response.send(proto::Ack {})?;
2303 Ok(())
2304}
2305
2306async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2307 let project_id = ProjectId::from_proto(message.project_id);
2308
2309 let operation = message.operation.as_ref().context("invalid operation")?;
2310 let capability = match operation.variant.as_ref() {
2311 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2312 if let Some(buffer_op) = buffer_op.operation.as_ref() {
2313 match buffer_op.variant {
2314 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2315 Capability::ReadOnly
2316 }
2317 _ => Capability::ReadWrite,
2318 }
2319 } else {
2320 Capability::ReadWrite
2321 }
2322 }
2323 Some(_) => Capability::ReadWrite,
2324 None => Capability::ReadOnly,
2325 };
2326
2327 let guard = session
2328 .db()
2329 .await
2330 .connections_for_buffer_update(project_id, session.connection_id, capability)
2331 .await?;
2332
2333 let (host, guests) = &*guard;
2334
2335 broadcast(
2336 Some(session.connection_id),
2337 guests.iter().chain([host]).copied(),
2338 |connection_id| {
2339 session
2340 .peer
2341 .forward_send(session.connection_id, connection_id, message.clone())
2342 },
2343 );
2344
2345 Ok(())
2346}
2347
2348/// Notify other participants that a project has been updated.
2349async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2350 request: T,
2351 session: Session,
2352) -> Result<()> {
2353 let project_id = ProjectId::from_proto(request.remote_entity_id());
2354 let project_connection_ids = session
2355 .db()
2356 .await
2357 .project_connection_ids(project_id, session.connection_id, false)
2358 .await?;
2359
2360 broadcast(
2361 Some(session.connection_id),
2362 project_connection_ids.iter().copied(),
2363 |connection_id| {
2364 session
2365 .peer
2366 .forward_send(session.connection_id, connection_id, request.clone())
2367 },
2368 );
2369 Ok(())
2370}
2371
2372/// Start following another user in a call.
2373async fn follow(
2374 request: proto::Follow,
2375 response: Response<proto::Follow>,
2376 session: Session,
2377) -> Result<()> {
2378 let room_id = RoomId::from_proto(request.room_id);
2379 let project_id = request.project_id.map(ProjectId::from_proto);
2380 let leader_id = request
2381 .leader_id
2382 .ok_or_else(|| anyhow!("invalid leader id"))?
2383 .into();
2384 let follower_id = session.connection_id;
2385
2386 session
2387 .db()
2388 .await
2389 .check_room_participants(room_id, leader_id, session.connection_id)
2390 .await?;
2391
2392 let response_payload = session
2393 .peer
2394 .forward_request(session.connection_id, leader_id, request)
2395 .await?;
2396 response.send(response_payload)?;
2397
2398 if let Some(project_id) = project_id {
2399 let room = session
2400 .db()
2401 .await
2402 .follow(room_id, project_id, leader_id, follower_id)
2403 .await?;
2404 room_updated(&room, &session.peer);
2405 }
2406
2407 Ok(())
2408}
2409
2410/// Stop following another user in a call.
2411async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2412 let room_id = RoomId::from_proto(request.room_id);
2413 let project_id = request.project_id.map(ProjectId::from_proto);
2414 let leader_id = request
2415 .leader_id
2416 .ok_or_else(|| anyhow!("invalid leader id"))?
2417 .into();
2418 let follower_id = session.connection_id;
2419
2420 session
2421 .db()
2422 .await
2423 .check_room_participants(room_id, leader_id, session.connection_id)
2424 .await?;
2425
2426 session
2427 .peer
2428 .forward_send(session.connection_id, leader_id, request)?;
2429
2430 if let Some(project_id) = project_id {
2431 let room = session
2432 .db()
2433 .await
2434 .unfollow(room_id, project_id, leader_id, follower_id)
2435 .await?;
2436 room_updated(&room, &session.peer);
2437 }
2438
2439 Ok(())
2440}
2441
2442/// Notify everyone following you of your current location.
2443async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2444 let room_id = RoomId::from_proto(request.room_id);
2445 let database = session.db.lock().await;
2446
2447 let connection_ids = if let Some(project_id) = request.project_id {
2448 let project_id = ProjectId::from_proto(project_id);
2449 database
2450 .project_connection_ids(project_id, session.connection_id, true)
2451 .await?
2452 } else {
2453 database
2454 .room_connection_ids(room_id, session.connection_id)
2455 .await?
2456 };
2457
2458 // For now, don't send view update messages back to that view's current leader.
2459 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2460 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2461 _ => None,
2462 });
2463
2464 for connection_id in connection_ids.iter().cloned() {
2465 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2466 session
2467 .peer
2468 .forward_send(session.connection_id, connection_id, request.clone())?;
2469 }
2470 }
2471 Ok(())
2472}
2473
2474/// Get public data about users.
2475async fn get_users(
2476 request: proto::GetUsers,
2477 response: Response<proto::GetUsers>,
2478 session: Session,
2479) -> Result<()> {
2480 let user_ids = request
2481 .user_ids
2482 .into_iter()
2483 .map(UserId::from_proto)
2484 .collect();
2485 let users = session
2486 .db()
2487 .await
2488 .get_users_by_ids(user_ids)
2489 .await?
2490 .into_iter()
2491 .map(|user| proto::User {
2492 id: user.id.to_proto(),
2493 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2494 github_login: user.github_login,
2495 email: user.email_address,
2496 name: user.name,
2497 })
2498 .collect();
2499 response.send(proto::UsersResponse { users })?;
2500 Ok(())
2501}
2502
2503/// Search for users (to invite) buy Github login
2504async fn fuzzy_search_users(
2505 request: proto::FuzzySearchUsers,
2506 response: Response<proto::FuzzySearchUsers>,
2507 session: Session,
2508) -> Result<()> {
2509 let query = request.query;
2510 let users = match query.len() {
2511 0 => vec![],
2512 1 | 2 => session
2513 .db()
2514 .await
2515 .get_user_by_github_login(&query)
2516 .await?
2517 .into_iter()
2518 .collect(),
2519 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2520 };
2521 let users = users
2522 .into_iter()
2523 .filter(|user| user.id != session.user_id())
2524 .map(|user| proto::User {
2525 id: user.id.to_proto(),
2526 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2527 github_login: user.github_login,
2528 name: user.name,
2529 email: user.email_address,
2530 })
2531 .collect();
2532 response.send(proto::UsersResponse { users })?;
2533 Ok(())
2534}
2535
2536/// Send a contact request to another user.
2537async fn request_contact(
2538 request: proto::RequestContact,
2539 response: Response<proto::RequestContact>,
2540 session: Session,
2541) -> Result<()> {
2542 let requester_id = session.user_id();
2543 let responder_id = UserId::from_proto(request.responder_id);
2544 if requester_id == responder_id {
2545 return Err(anyhow!("cannot add yourself as a contact"))?;
2546 }
2547
2548 let notifications = session
2549 .db()
2550 .await
2551 .send_contact_request(requester_id, responder_id)
2552 .await?;
2553
2554 // Update outgoing contact requests of requester
2555 let mut update = proto::UpdateContacts::default();
2556 update.outgoing_requests.push(responder_id.to_proto());
2557 for connection_id in session
2558 .connection_pool()
2559 .await
2560 .user_connection_ids(requester_id)
2561 {
2562 session.peer.send(connection_id, update.clone())?;
2563 }
2564
2565 // Update incoming contact requests of responder
2566 let mut update = proto::UpdateContacts::default();
2567 update
2568 .incoming_requests
2569 .push(proto::IncomingContactRequest {
2570 requester_id: requester_id.to_proto(),
2571 });
2572 let connection_pool = session.connection_pool().await;
2573 for connection_id in connection_pool.user_connection_ids(responder_id) {
2574 session.peer.send(connection_id, update.clone())?;
2575 }
2576
2577 send_notifications(&connection_pool, &session.peer, notifications);
2578
2579 response.send(proto::Ack {})?;
2580 Ok(())
2581}
2582
2583/// Accept or decline a contact request
2584async fn respond_to_contact_request(
2585 request: proto::RespondToContactRequest,
2586 response: Response<proto::RespondToContactRequest>,
2587 session: Session,
2588) -> Result<()> {
2589 let responder_id = session.user_id();
2590 let requester_id = UserId::from_proto(request.requester_id);
2591 let db = session.db().await;
2592 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2593 db.dismiss_contact_notification(responder_id, requester_id)
2594 .await?;
2595 } else {
2596 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2597
2598 let notifications = db
2599 .respond_to_contact_request(responder_id, requester_id, accept)
2600 .await?;
2601 let requester_busy = db.is_user_busy(requester_id).await?;
2602 let responder_busy = db.is_user_busy(responder_id).await?;
2603
2604 let pool = session.connection_pool().await;
2605 // Update responder with new contact
2606 let mut update = proto::UpdateContacts::default();
2607 if accept {
2608 update
2609 .contacts
2610 .push(contact_for_user(requester_id, requester_busy, &pool));
2611 }
2612 update
2613 .remove_incoming_requests
2614 .push(requester_id.to_proto());
2615 for connection_id in pool.user_connection_ids(responder_id) {
2616 session.peer.send(connection_id, update.clone())?;
2617 }
2618
2619 // Update requester with new contact
2620 let mut update = proto::UpdateContacts::default();
2621 if accept {
2622 update
2623 .contacts
2624 .push(contact_for_user(responder_id, responder_busy, &pool));
2625 }
2626 update
2627 .remove_outgoing_requests
2628 .push(responder_id.to_proto());
2629
2630 for connection_id in pool.user_connection_ids(requester_id) {
2631 session.peer.send(connection_id, update.clone())?;
2632 }
2633
2634 send_notifications(&pool, &session.peer, notifications);
2635 }
2636
2637 response.send(proto::Ack {})?;
2638 Ok(())
2639}
2640
2641/// Remove a contact.
2642async fn remove_contact(
2643 request: proto::RemoveContact,
2644 response: Response<proto::RemoveContact>,
2645 session: Session,
2646) -> Result<()> {
2647 let requester_id = session.user_id();
2648 let responder_id = UserId::from_proto(request.user_id);
2649 let db = session.db().await;
2650 let (contact_accepted, deleted_notification_id) =
2651 db.remove_contact(requester_id, responder_id).await?;
2652
2653 let pool = session.connection_pool().await;
2654 // Update outgoing contact requests of requester
2655 let mut update = proto::UpdateContacts::default();
2656 if contact_accepted {
2657 update.remove_contacts.push(responder_id.to_proto());
2658 } else {
2659 update
2660 .remove_outgoing_requests
2661 .push(responder_id.to_proto());
2662 }
2663 for connection_id in pool.user_connection_ids(requester_id) {
2664 session.peer.send(connection_id, update.clone())?;
2665 }
2666
2667 // Update incoming contact requests of responder
2668 let mut update = proto::UpdateContacts::default();
2669 if contact_accepted {
2670 update.remove_contacts.push(requester_id.to_proto());
2671 } else {
2672 update
2673 .remove_incoming_requests
2674 .push(requester_id.to_proto());
2675 }
2676 for connection_id in pool.user_connection_ids(responder_id) {
2677 session.peer.send(connection_id, update.clone())?;
2678 if let Some(notification_id) = deleted_notification_id {
2679 session.peer.send(
2680 connection_id,
2681 proto::DeleteNotification {
2682 notification_id: notification_id.to_proto(),
2683 },
2684 )?;
2685 }
2686 }
2687
2688 response.send(proto::Ack {})?;
2689 Ok(())
2690}
2691
2692fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2693 version.0.minor() < 139
2694}
2695
2696async fn current_plan(db: &Arc<Database>, user_id: UserId, is_staff: bool) -> Result<proto::Plan> {
2697 if is_staff {
2698 return Ok(proto::Plan::ZedPro);
2699 }
2700
2701 let subscription = db.get_active_billing_subscription(user_id).await?;
2702 let subscription_kind = subscription.and_then(|subscription| subscription.kind);
2703
2704 let plan = if let Some(subscription_kind) = subscription_kind {
2705 match subscription_kind {
2706 SubscriptionKind::ZedPro => proto::Plan::ZedPro,
2707 SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
2708 SubscriptionKind::ZedFree => proto::Plan::Free,
2709 }
2710 } else {
2711 proto::Plan::Free
2712 };
2713
2714 Ok(plan)
2715}
2716
2717async fn make_update_user_plan_message(
2718 db: &Arc<Database>,
2719 llm_db: Option<Arc<LlmDatabase>>,
2720 user_id: UserId,
2721 is_staff: bool,
2722) -> Result<proto::UpdateUserPlan> {
2723 let feature_flags = db.get_user_flags(user_id).await?;
2724 let plan = current_plan(db, user_id, is_staff).await?;
2725 let billing_customer = db.get_billing_customer_by_user_id(user_id).await?;
2726 let billing_preferences = db.get_billing_preferences(user_id).await?;
2727
2728 let (subscription_period, usage) = if let Some(llm_db) = llm_db {
2729 let subscription = db.get_active_billing_subscription(user_id).await?;
2730
2731 let subscription_period =
2732 crate::db::billing_subscription::Model::current_period(subscription, is_staff);
2733
2734 let usage = if let Some((period_start_at, period_end_at)) = subscription_period {
2735 llm_db
2736 .get_subscription_usage_for_period(user_id, period_start_at, period_end_at)
2737 .await?
2738 } else {
2739 None
2740 };
2741
2742 (subscription_period, usage)
2743 } else {
2744 (None, None)
2745 };
2746
2747 Ok(proto::UpdateUserPlan {
2748 plan: plan.into(),
2749 trial_started_at: billing_customer
2750 .and_then(|billing_customer| billing_customer.trial_started_at)
2751 .map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64),
2752 is_usage_based_billing_enabled: if is_staff {
2753 Some(true)
2754 } else {
2755 billing_preferences.map(|preferences| preferences.model_request_overages_enabled)
2756 },
2757 subscription_period: subscription_period.map(|(started_at, ended_at)| {
2758 proto::SubscriptionPeriod {
2759 started_at: started_at.timestamp() as u64,
2760 ended_at: ended_at.timestamp() as u64,
2761 }
2762 }),
2763 usage: usage.map(|usage| {
2764 let plan = match plan {
2765 proto::Plan::Free => zed_llm_client::Plan::ZedFree,
2766 proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro,
2767 proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
2768 };
2769
2770 let model_requests_limit = match plan.model_requests_limit() {
2771 zed_llm_client::UsageLimit::Limited(limit) => {
2772 let limit = if plan == zed_llm_client::Plan::ZedProTrial
2773 && feature_flags
2774 .iter()
2775 .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
2776 {
2777 1_000
2778 } else {
2779 limit
2780 };
2781
2782 zed_llm_client::UsageLimit::Limited(limit)
2783 }
2784 zed_llm_client::UsageLimit::Unlimited => zed_llm_client::UsageLimit::Unlimited,
2785 };
2786
2787 proto::SubscriptionUsage {
2788 model_requests_usage_amount: usage.model_requests as u32,
2789 model_requests_usage_limit: Some(proto::UsageLimit {
2790 variant: Some(match model_requests_limit {
2791 zed_llm_client::UsageLimit::Limited(limit) => {
2792 proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2793 limit: limit as u32,
2794 })
2795 }
2796 zed_llm_client::UsageLimit::Unlimited => {
2797 proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2798 }
2799 }),
2800 }),
2801 edit_predictions_usage_amount: usage.edit_predictions as u32,
2802 edit_predictions_usage_limit: Some(proto::UsageLimit {
2803 variant: Some(match plan.edit_predictions_limit() {
2804 zed_llm_client::UsageLimit::Limited(limit) => {
2805 proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2806 limit: limit as u32,
2807 })
2808 }
2809 zed_llm_client::UsageLimit::Unlimited => {
2810 proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2811 }
2812 }),
2813 }),
2814 }
2815 }),
2816 })
2817}
2818
2819async fn update_user_plan(user_id: UserId, session: &Session) -> Result<()> {
2820 let db = session.db().await;
2821
2822 let update_user_plan = make_update_user_plan_message(
2823 &db.0,
2824 session.app_state.llm_db.clone(),
2825 user_id,
2826 session.is_staff(),
2827 )
2828 .await?;
2829
2830 session
2831 .peer
2832 .send(session.connection_id, update_user_plan)
2833 .trace_err();
2834
2835 Ok(())
2836}
2837
2838async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2839 subscribe_user_to_channels(session.user_id(), &session).await?;
2840 Ok(())
2841}
2842
2843async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2844 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2845 let mut pool = session.connection_pool().await;
2846 for membership in &channels_for_user.channel_memberships {
2847 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2848 }
2849 session.peer.send(
2850 session.connection_id,
2851 build_update_user_channels(&channels_for_user),
2852 )?;
2853 session.peer.send(
2854 session.connection_id,
2855 build_channels_update(channels_for_user),
2856 )?;
2857 Ok(())
2858}
2859
2860/// Creates a new channel.
2861async fn create_channel(
2862 request: proto::CreateChannel,
2863 response: Response<proto::CreateChannel>,
2864 session: Session,
2865) -> Result<()> {
2866 let db = session.db().await;
2867
2868 let parent_id = request.parent_id.map(ChannelId::from_proto);
2869 let (channel, membership) = db
2870 .create_channel(&request.name, parent_id, session.user_id())
2871 .await?;
2872
2873 let root_id = channel.root_id();
2874 let channel = Channel::from_model(channel);
2875
2876 response.send(proto::CreateChannelResponse {
2877 channel: Some(channel.to_proto()),
2878 parent_id: request.parent_id,
2879 })?;
2880
2881 let mut connection_pool = session.connection_pool().await;
2882 if let Some(membership) = membership {
2883 connection_pool.subscribe_to_channel(
2884 membership.user_id,
2885 membership.channel_id,
2886 membership.role,
2887 );
2888 let update = proto::UpdateUserChannels {
2889 channel_memberships: vec![proto::ChannelMembership {
2890 channel_id: membership.channel_id.to_proto(),
2891 role: membership.role.into(),
2892 }],
2893 ..Default::default()
2894 };
2895 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2896 session.peer.send(connection_id, update.clone())?;
2897 }
2898 }
2899
2900 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2901 if !role.can_see_channel(channel.visibility) {
2902 continue;
2903 }
2904
2905 let update = proto::UpdateChannels {
2906 channels: vec![channel.to_proto()],
2907 ..Default::default()
2908 };
2909 session.peer.send(connection_id, update.clone())?;
2910 }
2911
2912 Ok(())
2913}
2914
2915/// Delete a channel
2916async fn delete_channel(
2917 request: proto::DeleteChannel,
2918 response: Response<proto::DeleteChannel>,
2919 session: Session,
2920) -> Result<()> {
2921 let db = session.db().await;
2922
2923 let channel_id = request.channel_id;
2924 let (root_channel, removed_channels) = db
2925 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2926 .await?;
2927 response.send(proto::Ack {})?;
2928
2929 // Notify members of removed channels
2930 let mut update = proto::UpdateChannels::default();
2931 update
2932 .delete_channels
2933 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2934
2935 let connection_pool = session.connection_pool().await;
2936 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2937 session.peer.send(connection_id, update.clone())?;
2938 }
2939
2940 Ok(())
2941}
2942
2943/// Invite someone to join a channel.
2944async fn invite_channel_member(
2945 request: proto::InviteChannelMember,
2946 response: Response<proto::InviteChannelMember>,
2947 session: Session,
2948) -> Result<()> {
2949 let db = session.db().await;
2950 let channel_id = ChannelId::from_proto(request.channel_id);
2951 let invitee_id = UserId::from_proto(request.user_id);
2952 let InviteMemberResult {
2953 channel,
2954 notifications,
2955 } = db
2956 .invite_channel_member(
2957 channel_id,
2958 invitee_id,
2959 session.user_id(),
2960 request.role().into(),
2961 )
2962 .await?;
2963
2964 let update = proto::UpdateChannels {
2965 channel_invitations: vec![channel.to_proto()],
2966 ..Default::default()
2967 };
2968
2969 let connection_pool = session.connection_pool().await;
2970 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2971 session.peer.send(connection_id, update.clone())?;
2972 }
2973
2974 send_notifications(&connection_pool, &session.peer, notifications);
2975
2976 response.send(proto::Ack {})?;
2977 Ok(())
2978}
2979
2980/// remove someone from a channel
2981async fn remove_channel_member(
2982 request: proto::RemoveChannelMember,
2983 response: Response<proto::RemoveChannelMember>,
2984 session: Session,
2985) -> Result<()> {
2986 let db = session.db().await;
2987 let channel_id = ChannelId::from_proto(request.channel_id);
2988 let member_id = UserId::from_proto(request.user_id);
2989
2990 let RemoveChannelMemberResult {
2991 membership_update,
2992 notification_id,
2993 } = db
2994 .remove_channel_member(channel_id, member_id, session.user_id())
2995 .await?;
2996
2997 let mut connection_pool = session.connection_pool().await;
2998 notify_membership_updated(
2999 &mut connection_pool,
3000 membership_update,
3001 member_id,
3002 &session.peer,
3003 );
3004 for connection_id in connection_pool.user_connection_ids(member_id) {
3005 if let Some(notification_id) = notification_id {
3006 session
3007 .peer
3008 .send(
3009 connection_id,
3010 proto::DeleteNotification {
3011 notification_id: notification_id.to_proto(),
3012 },
3013 )
3014 .trace_err();
3015 }
3016 }
3017
3018 response.send(proto::Ack {})?;
3019 Ok(())
3020}
3021
3022/// Toggle the channel between public and private.
3023/// Care is taken to maintain the invariant that public channels only descend from public channels,
3024/// (though members-only channels can appear at any point in the hierarchy).
3025async fn set_channel_visibility(
3026 request: proto::SetChannelVisibility,
3027 response: Response<proto::SetChannelVisibility>,
3028 session: Session,
3029) -> Result<()> {
3030 let db = session.db().await;
3031 let channel_id = ChannelId::from_proto(request.channel_id);
3032 let visibility = request.visibility().into();
3033
3034 let channel_model = db
3035 .set_channel_visibility(channel_id, visibility, session.user_id())
3036 .await?;
3037 let root_id = channel_model.root_id();
3038 let channel = Channel::from_model(channel_model);
3039
3040 let mut connection_pool = session.connection_pool().await;
3041 for (user_id, role) in connection_pool
3042 .channel_user_ids(root_id)
3043 .collect::<Vec<_>>()
3044 .into_iter()
3045 {
3046 let update = if role.can_see_channel(channel.visibility) {
3047 connection_pool.subscribe_to_channel(user_id, channel_id, role);
3048 proto::UpdateChannels {
3049 channels: vec![channel.to_proto()],
3050 ..Default::default()
3051 }
3052 } else {
3053 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3054 proto::UpdateChannels {
3055 delete_channels: vec![channel.id.to_proto()],
3056 ..Default::default()
3057 }
3058 };
3059
3060 for connection_id in connection_pool.user_connection_ids(user_id) {
3061 session.peer.send(connection_id, update.clone())?;
3062 }
3063 }
3064
3065 response.send(proto::Ack {})?;
3066 Ok(())
3067}
3068
3069/// Alter the role for a user in the channel.
3070async fn set_channel_member_role(
3071 request: proto::SetChannelMemberRole,
3072 response: Response<proto::SetChannelMemberRole>,
3073 session: Session,
3074) -> Result<()> {
3075 let db = session.db().await;
3076 let channel_id = ChannelId::from_proto(request.channel_id);
3077 let member_id = UserId::from_proto(request.user_id);
3078 let result = db
3079 .set_channel_member_role(
3080 channel_id,
3081 session.user_id(),
3082 member_id,
3083 request.role().into(),
3084 )
3085 .await?;
3086
3087 match result {
3088 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3089 let mut connection_pool = session.connection_pool().await;
3090 notify_membership_updated(
3091 &mut connection_pool,
3092 membership_update,
3093 member_id,
3094 &session.peer,
3095 )
3096 }
3097 db::SetMemberRoleResult::InviteUpdated(channel) => {
3098 let update = proto::UpdateChannels {
3099 channel_invitations: vec![channel.to_proto()],
3100 ..Default::default()
3101 };
3102
3103 for connection_id in session
3104 .connection_pool()
3105 .await
3106 .user_connection_ids(member_id)
3107 {
3108 session.peer.send(connection_id, update.clone())?;
3109 }
3110 }
3111 }
3112
3113 response.send(proto::Ack {})?;
3114 Ok(())
3115}
3116
3117/// Change the name of a channel
3118async fn rename_channel(
3119 request: proto::RenameChannel,
3120 response: Response<proto::RenameChannel>,
3121 session: Session,
3122) -> Result<()> {
3123 let db = session.db().await;
3124 let channel_id = ChannelId::from_proto(request.channel_id);
3125 let channel_model = db
3126 .rename_channel(channel_id, session.user_id(), &request.name)
3127 .await?;
3128 let root_id = channel_model.root_id();
3129 let channel = Channel::from_model(channel_model);
3130
3131 response.send(proto::RenameChannelResponse {
3132 channel: Some(channel.to_proto()),
3133 })?;
3134
3135 let connection_pool = session.connection_pool().await;
3136 let update = proto::UpdateChannels {
3137 channels: vec![channel.to_proto()],
3138 ..Default::default()
3139 };
3140 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3141 if role.can_see_channel(channel.visibility) {
3142 session.peer.send(connection_id, update.clone())?;
3143 }
3144 }
3145
3146 Ok(())
3147}
3148
3149/// Move a channel to a new parent.
3150async fn move_channel(
3151 request: proto::MoveChannel,
3152 response: Response<proto::MoveChannel>,
3153 session: Session,
3154) -> Result<()> {
3155 let channel_id = ChannelId::from_proto(request.channel_id);
3156 let to = ChannelId::from_proto(request.to);
3157
3158 let (root_id, channels) = session
3159 .db()
3160 .await
3161 .move_channel(channel_id, to, session.user_id())
3162 .await?;
3163
3164 let connection_pool = session.connection_pool().await;
3165 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3166 let channels = channels
3167 .iter()
3168 .filter_map(|channel| {
3169 if role.can_see_channel(channel.visibility) {
3170 Some(channel.to_proto())
3171 } else {
3172 None
3173 }
3174 })
3175 .collect::<Vec<_>>();
3176 if channels.is_empty() {
3177 continue;
3178 }
3179
3180 let update = proto::UpdateChannels {
3181 channels,
3182 ..Default::default()
3183 };
3184
3185 session.peer.send(connection_id, update.clone())?;
3186 }
3187
3188 response.send(Ack {})?;
3189 Ok(())
3190}
3191
3192/// Get the list of channel members
3193async fn get_channel_members(
3194 request: proto::GetChannelMembers,
3195 response: Response<proto::GetChannelMembers>,
3196 session: Session,
3197) -> Result<()> {
3198 let db = session.db().await;
3199 let channel_id = ChannelId::from_proto(request.channel_id);
3200 let limit = if request.limit == 0 {
3201 u16::MAX as u64
3202 } else {
3203 request.limit
3204 };
3205 let (members, users) = db
3206 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3207 .await?;
3208 response.send(proto::GetChannelMembersResponse { members, users })?;
3209 Ok(())
3210}
3211
3212/// Accept or decline a channel invitation.
3213async fn respond_to_channel_invite(
3214 request: proto::RespondToChannelInvite,
3215 response: Response<proto::RespondToChannelInvite>,
3216 session: Session,
3217) -> Result<()> {
3218 let db = session.db().await;
3219 let channel_id = ChannelId::from_proto(request.channel_id);
3220 let RespondToChannelInvite {
3221 membership_update,
3222 notifications,
3223 } = db
3224 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3225 .await?;
3226
3227 let mut connection_pool = session.connection_pool().await;
3228 if let Some(membership_update) = membership_update {
3229 notify_membership_updated(
3230 &mut connection_pool,
3231 membership_update,
3232 session.user_id(),
3233 &session.peer,
3234 );
3235 } else {
3236 let update = proto::UpdateChannels {
3237 remove_channel_invitations: vec![channel_id.to_proto()],
3238 ..Default::default()
3239 };
3240
3241 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3242 session.peer.send(connection_id, update.clone())?;
3243 }
3244 };
3245
3246 send_notifications(&connection_pool, &session.peer, notifications);
3247
3248 response.send(proto::Ack {})?;
3249
3250 Ok(())
3251}
3252
3253/// Join the channels' room
3254async fn join_channel(
3255 request: proto::JoinChannel,
3256 response: Response<proto::JoinChannel>,
3257 session: Session,
3258) -> Result<()> {
3259 let channel_id = ChannelId::from_proto(request.channel_id);
3260 join_channel_internal(channel_id, Box::new(response), session).await
3261}
3262
3263trait JoinChannelInternalResponse {
3264 fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
3265}
3266impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
3267 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3268 Response::<proto::JoinChannel>::send(self, result)
3269 }
3270}
3271impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
3272 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3273 Response::<proto::JoinRoom>::send(self, result)
3274 }
3275}
3276
3277async fn join_channel_internal(
3278 channel_id: ChannelId,
3279 response: Box<impl JoinChannelInternalResponse>,
3280 session: Session,
3281) -> Result<()> {
3282 let joined_room = {
3283 let mut db = session.db().await;
3284 // If zed quits without leaving the room, and the user re-opens zed before the
3285 // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
3286 // room they were in.
3287 if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
3288 tracing::info!(
3289 stale_connection_id = %connection,
3290 "cleaning up stale connection",
3291 );
3292 drop(db);
3293 leave_room_for_session(&session, connection).await?;
3294 db = session.db().await;
3295 }
3296
3297 let (joined_room, membership_updated, role) = db
3298 .join_channel(channel_id, session.user_id(), session.connection_id)
3299 .await?;
3300
3301 let live_kit_connection_info =
3302 session
3303 .app_state
3304 .livekit_client
3305 .as_ref()
3306 .and_then(|live_kit| {
3307 let (can_publish, token) = if role == ChannelRole::Guest {
3308 (
3309 false,
3310 live_kit
3311 .guest_token(
3312 &joined_room.room.livekit_room,
3313 &session.user_id().to_string(),
3314 )
3315 .trace_err()?,
3316 )
3317 } else {
3318 (
3319 true,
3320 live_kit
3321 .room_token(
3322 &joined_room.room.livekit_room,
3323 &session.user_id().to_string(),
3324 )
3325 .trace_err()?,
3326 )
3327 };
3328
3329 Some(LiveKitConnectionInfo {
3330 server_url: live_kit.url().into(),
3331 token,
3332 can_publish,
3333 })
3334 });
3335
3336 response.send(proto::JoinRoomResponse {
3337 room: Some(joined_room.room.clone()),
3338 channel_id: joined_room
3339 .channel
3340 .as_ref()
3341 .map(|channel| channel.id.to_proto()),
3342 live_kit_connection_info,
3343 })?;
3344
3345 let mut connection_pool = session.connection_pool().await;
3346 if let Some(membership_updated) = membership_updated {
3347 notify_membership_updated(
3348 &mut connection_pool,
3349 membership_updated,
3350 session.user_id(),
3351 &session.peer,
3352 );
3353 }
3354
3355 room_updated(&joined_room.room, &session.peer);
3356
3357 joined_room
3358 };
3359
3360 channel_updated(
3361 &joined_room
3362 .channel
3363 .ok_or_else(|| anyhow!("channel not returned"))?,
3364 &joined_room.room,
3365 &session.peer,
3366 &*session.connection_pool().await,
3367 );
3368
3369 update_user_contacts(session.user_id(), &session).await?;
3370 Ok(())
3371}
3372
3373/// Start editing the channel notes
3374async fn join_channel_buffer(
3375 request: proto::JoinChannelBuffer,
3376 response: Response<proto::JoinChannelBuffer>,
3377 session: Session,
3378) -> Result<()> {
3379 let db = session.db().await;
3380 let channel_id = ChannelId::from_proto(request.channel_id);
3381
3382 let open_response = db
3383 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3384 .await?;
3385
3386 let collaborators = open_response.collaborators.clone();
3387 response.send(open_response)?;
3388
3389 let update = UpdateChannelBufferCollaborators {
3390 channel_id: channel_id.to_proto(),
3391 collaborators: collaborators.clone(),
3392 };
3393 channel_buffer_updated(
3394 session.connection_id,
3395 collaborators
3396 .iter()
3397 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3398 &update,
3399 &session.peer,
3400 );
3401
3402 Ok(())
3403}
3404
3405/// Edit the channel notes
3406async fn update_channel_buffer(
3407 request: proto::UpdateChannelBuffer,
3408 session: Session,
3409) -> Result<()> {
3410 let db = session.db().await;
3411 let channel_id = ChannelId::from_proto(request.channel_id);
3412
3413 let (collaborators, epoch, version) = db
3414 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3415 .await?;
3416
3417 channel_buffer_updated(
3418 session.connection_id,
3419 collaborators.clone(),
3420 &proto::UpdateChannelBuffer {
3421 channel_id: channel_id.to_proto(),
3422 operations: request.operations,
3423 },
3424 &session.peer,
3425 );
3426
3427 let pool = &*session.connection_pool().await;
3428
3429 let non_collaborators =
3430 pool.channel_connection_ids(channel_id)
3431 .filter_map(|(connection_id, _)| {
3432 if collaborators.contains(&connection_id) {
3433 None
3434 } else {
3435 Some(connection_id)
3436 }
3437 });
3438
3439 broadcast(None, non_collaborators, |peer_id| {
3440 session.peer.send(
3441 peer_id,
3442 proto::UpdateChannels {
3443 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3444 channel_id: channel_id.to_proto(),
3445 epoch: epoch as u64,
3446 version: version.clone(),
3447 }],
3448 ..Default::default()
3449 },
3450 )
3451 });
3452
3453 Ok(())
3454}
3455
3456/// Rejoin the channel notes after a connection blip
3457async fn rejoin_channel_buffers(
3458 request: proto::RejoinChannelBuffers,
3459 response: Response<proto::RejoinChannelBuffers>,
3460 session: Session,
3461) -> Result<()> {
3462 let db = session.db().await;
3463 let buffers = db
3464 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3465 .await?;
3466
3467 for rejoined_buffer in &buffers {
3468 let collaborators_to_notify = rejoined_buffer
3469 .buffer
3470 .collaborators
3471 .iter()
3472 .filter_map(|c| Some(c.peer_id?.into()));
3473 channel_buffer_updated(
3474 session.connection_id,
3475 collaborators_to_notify,
3476 &proto::UpdateChannelBufferCollaborators {
3477 channel_id: rejoined_buffer.buffer.channel_id,
3478 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3479 },
3480 &session.peer,
3481 );
3482 }
3483
3484 response.send(proto::RejoinChannelBuffersResponse {
3485 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3486 })?;
3487
3488 Ok(())
3489}
3490
3491/// Stop editing the channel notes
3492async fn leave_channel_buffer(
3493 request: proto::LeaveChannelBuffer,
3494 response: Response<proto::LeaveChannelBuffer>,
3495 session: Session,
3496) -> Result<()> {
3497 let db = session.db().await;
3498 let channel_id = ChannelId::from_proto(request.channel_id);
3499
3500 let left_buffer = db
3501 .leave_channel_buffer(channel_id, session.connection_id)
3502 .await?;
3503
3504 response.send(Ack {})?;
3505
3506 channel_buffer_updated(
3507 session.connection_id,
3508 left_buffer.connections,
3509 &proto::UpdateChannelBufferCollaborators {
3510 channel_id: channel_id.to_proto(),
3511 collaborators: left_buffer.collaborators,
3512 },
3513 &session.peer,
3514 );
3515
3516 Ok(())
3517}
3518
3519fn channel_buffer_updated<T: EnvelopedMessage>(
3520 sender_id: ConnectionId,
3521 collaborators: impl IntoIterator<Item = ConnectionId>,
3522 message: &T,
3523 peer: &Peer,
3524) {
3525 broadcast(Some(sender_id), collaborators, |peer_id| {
3526 peer.send(peer_id, message.clone())
3527 });
3528}
3529
3530fn send_notifications(
3531 connection_pool: &ConnectionPool,
3532 peer: &Peer,
3533 notifications: db::NotificationBatch,
3534) {
3535 for (user_id, notification) in notifications {
3536 for connection_id in connection_pool.user_connection_ids(user_id) {
3537 if let Err(error) = peer.send(
3538 connection_id,
3539 proto::AddNotification {
3540 notification: Some(notification.clone()),
3541 },
3542 ) {
3543 tracing::error!(
3544 "failed to send notification to {:?} {}",
3545 connection_id,
3546 error
3547 );
3548 }
3549 }
3550 }
3551}
3552
3553/// Send a message to the channel
3554async fn send_channel_message(
3555 request: proto::SendChannelMessage,
3556 response: Response<proto::SendChannelMessage>,
3557 session: Session,
3558) -> Result<()> {
3559 // Validate the message body.
3560 let body = request.body.trim().to_string();
3561 if body.len() > MAX_MESSAGE_LEN {
3562 return Err(anyhow!("message is too long"))?;
3563 }
3564 if body.is_empty() {
3565 return Err(anyhow!("message can't be blank"))?;
3566 }
3567
3568 // TODO: adjust mentions if body is trimmed
3569
3570 let timestamp = OffsetDateTime::now_utc();
3571 let nonce = request
3572 .nonce
3573 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3574
3575 let channel_id = ChannelId::from_proto(request.channel_id);
3576 let CreatedChannelMessage {
3577 message_id,
3578 participant_connection_ids,
3579 notifications,
3580 } = session
3581 .db()
3582 .await
3583 .create_channel_message(
3584 channel_id,
3585 session.user_id(),
3586 &body,
3587 &request.mentions,
3588 timestamp,
3589 nonce.clone().into(),
3590 request.reply_to_message_id.map(MessageId::from_proto),
3591 )
3592 .await?;
3593
3594 let message = proto::ChannelMessage {
3595 sender_id: session.user_id().to_proto(),
3596 id: message_id.to_proto(),
3597 body,
3598 mentions: request.mentions,
3599 timestamp: timestamp.unix_timestamp() as u64,
3600 nonce: Some(nonce),
3601 reply_to_message_id: request.reply_to_message_id,
3602 edited_at: None,
3603 };
3604 broadcast(
3605 Some(session.connection_id),
3606 participant_connection_ids.clone(),
3607 |connection| {
3608 session.peer.send(
3609 connection,
3610 proto::ChannelMessageSent {
3611 channel_id: channel_id.to_proto(),
3612 message: Some(message.clone()),
3613 },
3614 )
3615 },
3616 );
3617 response.send(proto::SendChannelMessageResponse {
3618 message: Some(message),
3619 })?;
3620
3621 let pool = &*session.connection_pool().await;
3622 let non_participants =
3623 pool.channel_connection_ids(channel_id)
3624 .filter_map(|(connection_id, _)| {
3625 if participant_connection_ids.contains(&connection_id) {
3626 None
3627 } else {
3628 Some(connection_id)
3629 }
3630 });
3631 broadcast(None, non_participants, |peer_id| {
3632 session.peer.send(
3633 peer_id,
3634 proto::UpdateChannels {
3635 latest_channel_message_ids: vec![proto::ChannelMessageId {
3636 channel_id: channel_id.to_proto(),
3637 message_id: message_id.to_proto(),
3638 }],
3639 ..Default::default()
3640 },
3641 )
3642 });
3643 send_notifications(pool, &session.peer, notifications);
3644
3645 Ok(())
3646}
3647
3648/// Delete a channel message
3649async fn remove_channel_message(
3650 request: proto::RemoveChannelMessage,
3651 response: Response<proto::RemoveChannelMessage>,
3652 session: Session,
3653) -> Result<()> {
3654 let channel_id = ChannelId::from_proto(request.channel_id);
3655 let message_id = MessageId::from_proto(request.message_id);
3656 let (connection_ids, existing_notification_ids) = session
3657 .db()
3658 .await
3659 .remove_channel_message(channel_id, message_id, session.user_id())
3660 .await?;
3661
3662 broadcast(
3663 Some(session.connection_id),
3664 connection_ids,
3665 move |connection| {
3666 session.peer.send(connection, request.clone())?;
3667
3668 for notification_id in &existing_notification_ids {
3669 session.peer.send(
3670 connection,
3671 proto::DeleteNotification {
3672 notification_id: (*notification_id).to_proto(),
3673 },
3674 )?;
3675 }
3676
3677 Ok(())
3678 },
3679 );
3680 response.send(proto::Ack {})?;
3681 Ok(())
3682}
3683
3684async fn update_channel_message(
3685 request: proto::UpdateChannelMessage,
3686 response: Response<proto::UpdateChannelMessage>,
3687 session: Session,
3688) -> Result<()> {
3689 let channel_id = ChannelId::from_proto(request.channel_id);
3690 let message_id = MessageId::from_proto(request.message_id);
3691 let updated_at = OffsetDateTime::now_utc();
3692 let UpdatedChannelMessage {
3693 message_id,
3694 participant_connection_ids,
3695 notifications,
3696 reply_to_message_id,
3697 timestamp,
3698 deleted_mention_notification_ids,
3699 updated_mention_notifications,
3700 } = session
3701 .db()
3702 .await
3703 .update_channel_message(
3704 channel_id,
3705 message_id,
3706 session.user_id(),
3707 request.body.as_str(),
3708 &request.mentions,
3709 updated_at,
3710 )
3711 .await?;
3712
3713 let nonce = request
3714 .nonce
3715 .clone()
3716 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3717
3718 let message = proto::ChannelMessage {
3719 sender_id: session.user_id().to_proto(),
3720 id: message_id.to_proto(),
3721 body: request.body.clone(),
3722 mentions: request.mentions.clone(),
3723 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3724 nonce: Some(nonce),
3725 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3726 edited_at: Some(updated_at.unix_timestamp() as u64),
3727 };
3728
3729 response.send(proto::Ack {})?;
3730
3731 let pool = &*session.connection_pool().await;
3732 broadcast(
3733 Some(session.connection_id),
3734 participant_connection_ids,
3735 |connection| {
3736 session.peer.send(
3737 connection,
3738 proto::ChannelMessageUpdate {
3739 channel_id: channel_id.to_proto(),
3740 message: Some(message.clone()),
3741 },
3742 )?;
3743
3744 for notification_id in &deleted_mention_notification_ids {
3745 session.peer.send(
3746 connection,
3747 proto::DeleteNotification {
3748 notification_id: (*notification_id).to_proto(),
3749 },
3750 )?;
3751 }
3752
3753 for notification in &updated_mention_notifications {
3754 session.peer.send(
3755 connection,
3756 proto::UpdateNotification {
3757 notification: Some(notification.clone()),
3758 },
3759 )?;
3760 }
3761
3762 Ok(())
3763 },
3764 );
3765
3766 send_notifications(pool, &session.peer, notifications);
3767
3768 Ok(())
3769}
3770
3771/// Mark a channel message as read
3772async fn acknowledge_channel_message(
3773 request: proto::AckChannelMessage,
3774 session: Session,
3775) -> Result<()> {
3776 let channel_id = ChannelId::from_proto(request.channel_id);
3777 let message_id = MessageId::from_proto(request.message_id);
3778 let notifications = session
3779 .db()
3780 .await
3781 .observe_channel_message(channel_id, session.user_id(), message_id)
3782 .await?;
3783 send_notifications(
3784 &*session.connection_pool().await,
3785 &session.peer,
3786 notifications,
3787 );
3788 Ok(())
3789}
3790
3791/// Mark a buffer version as synced
3792async fn acknowledge_buffer_version(
3793 request: proto::AckBufferOperation,
3794 session: Session,
3795) -> Result<()> {
3796 let buffer_id = BufferId::from_proto(request.buffer_id);
3797 session
3798 .db()
3799 .await
3800 .observe_buffer_version(
3801 buffer_id,
3802 session.user_id(),
3803 request.epoch as i32,
3804 &request.version,
3805 )
3806 .await?;
3807 Ok(())
3808}
3809
3810/// Get a Supermaven API key for the user
3811async fn get_supermaven_api_key(
3812 _request: proto::GetSupermavenApiKey,
3813 response: Response<proto::GetSupermavenApiKey>,
3814 session: Session,
3815) -> Result<()> {
3816 let user_id: String = session.user_id().to_string();
3817 if !session.is_staff() {
3818 return Err(anyhow!("supermaven not enabled for this account"))?;
3819 }
3820
3821 let email = session
3822 .email()
3823 .ok_or_else(|| anyhow!("user must have an email"))?;
3824
3825 let supermaven_admin_api = session
3826 .supermaven_client
3827 .as_ref()
3828 .ok_or_else(|| anyhow!("supermaven not configured"))?;
3829
3830 let result = supermaven_admin_api
3831 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3832 .await?;
3833
3834 response.send(proto::GetSupermavenApiKeyResponse {
3835 api_key: result.api_key,
3836 })?;
3837
3838 Ok(())
3839}
3840
3841/// Start receiving chat updates for a channel
3842async fn join_channel_chat(
3843 request: proto::JoinChannelChat,
3844 response: Response<proto::JoinChannelChat>,
3845 session: Session,
3846) -> Result<()> {
3847 let channel_id = ChannelId::from_proto(request.channel_id);
3848
3849 let db = session.db().await;
3850 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3851 .await?;
3852 let messages = db
3853 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3854 .await?;
3855 response.send(proto::JoinChannelChatResponse {
3856 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3857 messages,
3858 })?;
3859 Ok(())
3860}
3861
3862/// Stop receiving chat updates for a channel
3863async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3864 let channel_id = ChannelId::from_proto(request.channel_id);
3865 session
3866 .db()
3867 .await
3868 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3869 .await?;
3870 Ok(())
3871}
3872
3873/// Retrieve the chat history for a channel
3874async fn get_channel_messages(
3875 request: proto::GetChannelMessages,
3876 response: Response<proto::GetChannelMessages>,
3877 session: Session,
3878) -> Result<()> {
3879 let channel_id = ChannelId::from_proto(request.channel_id);
3880 let messages = session
3881 .db()
3882 .await
3883 .get_channel_messages(
3884 channel_id,
3885 session.user_id(),
3886 MESSAGE_COUNT_PER_PAGE,
3887 Some(MessageId::from_proto(request.before_message_id)),
3888 )
3889 .await?;
3890 response.send(proto::GetChannelMessagesResponse {
3891 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3892 messages,
3893 })?;
3894 Ok(())
3895}
3896
3897/// Retrieve specific chat messages
3898async fn get_channel_messages_by_id(
3899 request: proto::GetChannelMessagesById,
3900 response: Response<proto::GetChannelMessagesById>,
3901 session: Session,
3902) -> Result<()> {
3903 let message_ids = request
3904 .message_ids
3905 .iter()
3906 .map(|id| MessageId::from_proto(*id))
3907 .collect::<Vec<_>>();
3908 let messages = session
3909 .db()
3910 .await
3911 .get_channel_messages_by_id(session.user_id(), &message_ids)
3912 .await?;
3913 response.send(proto::GetChannelMessagesResponse {
3914 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3915 messages,
3916 })?;
3917 Ok(())
3918}
3919
3920/// Retrieve the current users notifications
3921async fn get_notifications(
3922 request: proto::GetNotifications,
3923 response: Response<proto::GetNotifications>,
3924 session: Session,
3925) -> Result<()> {
3926 let notifications = session
3927 .db()
3928 .await
3929 .get_notifications(
3930 session.user_id(),
3931 NOTIFICATION_COUNT_PER_PAGE,
3932 request.before_id.map(db::NotificationId::from_proto),
3933 )
3934 .await?;
3935 response.send(proto::GetNotificationsResponse {
3936 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3937 notifications,
3938 })?;
3939 Ok(())
3940}
3941
3942/// Mark notifications as read
3943async fn mark_notification_as_read(
3944 request: proto::MarkNotificationRead,
3945 response: Response<proto::MarkNotificationRead>,
3946 session: Session,
3947) -> Result<()> {
3948 let database = &session.db().await;
3949 let notifications = database
3950 .mark_notification_as_read_by_id(
3951 session.user_id(),
3952 NotificationId::from_proto(request.notification_id),
3953 )
3954 .await?;
3955 send_notifications(
3956 &*session.connection_pool().await,
3957 &session.peer,
3958 notifications,
3959 );
3960 response.send(proto::Ack {})?;
3961 Ok(())
3962}
3963
3964/// Get the current users information
3965async fn get_private_user_info(
3966 _request: proto::GetPrivateUserInfo,
3967 response: Response<proto::GetPrivateUserInfo>,
3968 session: Session,
3969) -> Result<()> {
3970 let db = session.db().await;
3971
3972 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
3973 let user = db
3974 .get_user_by_id(session.user_id())
3975 .await?
3976 .ok_or_else(|| anyhow!("user not found"))?;
3977 let flags = db.get_user_flags(session.user_id()).await?;
3978
3979 response.send(proto::GetPrivateUserInfoResponse {
3980 metrics_id,
3981 staff: user.admin,
3982 flags,
3983 accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
3984 })?;
3985 Ok(())
3986}
3987
3988/// Accept the terms of service (tos) on behalf of the current user
3989async fn accept_terms_of_service(
3990 _request: proto::AcceptTermsOfService,
3991 response: Response<proto::AcceptTermsOfService>,
3992 session: Session,
3993) -> Result<()> {
3994 let db = session.db().await;
3995
3996 let accepted_tos_at = Utc::now();
3997 db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
3998 .await?;
3999
4000 response.send(proto::AcceptTermsOfServiceResponse {
4001 accepted_tos_at: accepted_tos_at.timestamp() as u64,
4002 })?;
4003 Ok(())
4004}
4005
/// The minimum age an account must have reached before it may use the LLM service.
/// Currently 30 days.
pub const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4008
4009async fn get_llm_api_token(
4010 _request: proto::GetLlmToken,
4011 response: Response<proto::GetLlmToken>,
4012 session: Session,
4013) -> Result<()> {
4014 let db = session.db().await;
4015
4016 let flags = db.get_user_flags(session.user_id()).await?;
4017
4018 let user_id = session.user_id();
4019 let user = db
4020 .get_user_by_id(user_id)
4021 .await?
4022 .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4023
4024 if user.accepted_tos_at.is_none() {
4025 Err(anyhow!("terms of service not accepted"))?
4026 }
4027
4028 let Some(stripe_client) = session.app_state.stripe_client.as_ref() else {
4029 Err(anyhow!("failed to retrieve Stripe client"))?
4030 };
4031
4032 let Some(stripe_billing) = session.app_state.stripe_billing.as_ref() else {
4033 Err(anyhow!("failed to retrieve Stripe billing object"))?
4034 };
4035
4036 let billing_customer =
4037 if let Some(billing_customer) = db.get_billing_customer_by_user_id(user.id).await? {
4038 billing_customer
4039 } else {
4040 let customer_id = stripe_billing
4041 .find_or_create_customer_by_email(user.email_address.as_deref())
4042 .await?;
4043
4044 find_or_create_billing_customer(
4045 &session.app_state,
4046 &stripe_client,
4047 stripe::Expandable::Id(customer_id),
4048 )
4049 .await?
4050 .ok_or_else(|| anyhow!("billing customer not found"))?
4051 };
4052
4053 let billing_subscription =
4054 if let Some(billing_subscription) = db.get_active_billing_subscription(user.id).await? {
4055 billing_subscription
4056 } else {
4057 let stripe_customer_id = billing_customer
4058 .stripe_customer_id
4059 .parse::<stripe::CustomerId>()
4060 .context("failed to parse Stripe customer ID from database")?;
4061
4062 let stripe_subscription = stripe_billing
4063 .subscribe_to_zed_free(stripe_customer_id)
4064 .await?;
4065
4066 db.create_billing_subscription(&db::CreateBillingSubscriptionParams {
4067 billing_customer_id: billing_customer.id,
4068 kind: Some(SubscriptionKind::ZedFree),
4069 stripe_subscription_id: stripe_subscription.id.to_string(),
4070 stripe_subscription_status: stripe_subscription.status.into(),
4071 stripe_cancellation_reason: None,
4072 stripe_current_period_start: Some(stripe_subscription.current_period_start),
4073 stripe_current_period_end: Some(stripe_subscription.current_period_end),
4074 })
4075 .await?
4076 };
4077
4078 let billing_preferences = db.get_billing_preferences(user.id).await?;
4079
4080 let token = LlmTokenClaims::create(
4081 &user,
4082 session.is_staff(),
4083 billing_preferences,
4084 &flags,
4085 billing_subscription,
4086 session.system_id.clone(),
4087 &session.app_state.config,
4088 )?;
4089 response.send(proto::GetLlmTokenResponse { token })?;
4090 Ok(())
4091}
4092
4093fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4094 let message = match message {
4095 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
4096 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
4097 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
4098 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
4099 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4100 code: frame.code.into(),
4101 reason: frame.reason.as_str().to_owned().into(),
4102 })),
4103 // We should never receive a frame while reading the message, according
4104 // to the `tungstenite` maintainers:
4105 //
4106 // > It cannot occur when you read messages from the WebSocket, but it
4107 // > can be used when you want to send the raw frames (e.g. you want to
4108 // > send the frames to the WebSocket without composing the full message first).
4109 // >
4110 // > — https://github.com/snapview/tungstenite-rs/issues/268
4111 TungsteniteMessage::Frame(_) => {
4112 bail!("received an unexpected frame while reading the message")
4113 }
4114 };
4115
4116 Ok(message)
4117}
4118
4119fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4120 match message {
4121 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
4122 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
4123 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
4124 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
4125 AxumMessage::Close(frame) => {
4126 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4127 code: frame.code.into(),
4128 reason: frame.reason.as_ref().into(),
4129 }))
4130 }
4131 }
4132}
4133
4134fn notify_membership_updated(
4135 connection_pool: &mut ConnectionPool,
4136 result: MembershipUpdated,
4137 user_id: UserId,
4138 peer: &Peer,
4139) {
4140 for membership in &result.new_channels.channel_memberships {
4141 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4142 }
4143 for channel_id in &result.removed_channels {
4144 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4145 }
4146
4147 let user_channels_update = proto::UpdateUserChannels {
4148 channel_memberships: result
4149 .new_channels
4150 .channel_memberships
4151 .iter()
4152 .map(|cm| proto::ChannelMembership {
4153 channel_id: cm.channel_id.to_proto(),
4154 role: cm.role.into(),
4155 })
4156 .collect(),
4157 ..Default::default()
4158 };
4159
4160 let mut update = build_channels_update(result.new_channels);
4161 update.delete_channels = result
4162 .removed_channels
4163 .into_iter()
4164 .map(|id| id.to_proto())
4165 .collect();
4166 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4167
4168 for connection_id in connection_pool.user_connection_ids(user_id) {
4169 peer.send(connection_id, user_channels_update.clone())
4170 .trace_err();
4171 peer.send(connection_id, update.clone()).trace_err();
4172 }
4173}
4174
4175fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4176 proto::UpdateUserChannels {
4177 channel_memberships: channels
4178 .channel_memberships
4179 .iter()
4180 .map(|m| proto::ChannelMembership {
4181 channel_id: m.channel_id.to_proto(),
4182 role: m.role.into(),
4183 })
4184 .collect(),
4185 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4186 observed_channel_message_id: channels.observed_channel_messages.clone(),
4187 }
4188}
4189
4190fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4191 let mut update = proto::UpdateChannels::default();
4192
4193 for channel in channels.channels {
4194 update.channels.push(channel.to_proto());
4195 }
4196
4197 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4198 update.latest_channel_message_ids = channels.latest_channel_messages;
4199
4200 for (channel_id, participants) in channels.channel_participants {
4201 update
4202 .channel_participants
4203 .push(proto::ChannelParticipants {
4204 channel_id: channel_id.to_proto(),
4205 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4206 });
4207 }
4208
4209 for channel in channels.invited_channels {
4210 update.channel_invitations.push(channel.to_proto());
4211 }
4212
4213 update
4214}
4215
4216fn build_initial_contacts_update(
4217 contacts: Vec<db::Contact>,
4218 pool: &ConnectionPool,
4219) -> proto::UpdateContacts {
4220 let mut update = proto::UpdateContacts::default();
4221
4222 for contact in contacts {
4223 match contact {
4224 db::Contact::Accepted { user_id, busy } => {
4225 update.contacts.push(contact_for_user(user_id, busy, pool));
4226 }
4227 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4228 db::Contact::Incoming { user_id } => {
4229 update
4230 .incoming_requests
4231 .push(proto::IncomingContactRequest {
4232 requester_id: user_id.to_proto(),
4233 })
4234 }
4235 }
4236 }
4237
4238 update
4239}
4240
4241fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4242 proto::Contact {
4243 user_id: user_id.to_proto(),
4244 online: pool.is_user_online(user_id),
4245 busy,
4246 }
4247}
4248
4249fn room_updated(room: &proto::Room, peer: &Peer) {
4250 broadcast(
4251 None,
4252 room.participants
4253 .iter()
4254 .filter_map(|participant| Some(participant.peer_id?.into())),
4255 |peer_id| {
4256 peer.send(
4257 peer_id,
4258 proto::RoomUpdated {
4259 room: Some(room.clone()),
4260 },
4261 )
4262 },
4263 );
4264}
4265
4266fn channel_updated(
4267 channel: &db::channel::Model,
4268 room: &proto::Room,
4269 peer: &Peer,
4270 pool: &ConnectionPool,
4271) {
4272 let participants = room
4273 .participants
4274 .iter()
4275 .map(|p| p.user_id)
4276 .collect::<Vec<_>>();
4277
4278 broadcast(
4279 None,
4280 pool.channel_connection_ids(channel.root_id())
4281 .filter_map(|(channel_id, role)| {
4282 role.can_see_channel(channel.visibility)
4283 .then_some(channel_id)
4284 }),
4285 |peer_id| {
4286 peer.send(
4287 peer_id,
4288 proto::UpdateChannels {
4289 channel_participants: vec![proto::ChannelParticipants {
4290 channel_id: channel.id.to_proto(),
4291 participant_user_ids: participants.clone(),
4292 }],
4293 ..Default::default()
4294 },
4295 )
4296 },
4297 );
4298}
4299
4300async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4301 let db = session.db().await;
4302
4303 let contacts = db.get_contacts(user_id).await?;
4304 let busy = db.is_user_busy(user_id).await?;
4305
4306 let pool = session.connection_pool().await;
4307 let updated_contact = contact_for_user(user_id, busy, &pool);
4308 for contact in contacts {
4309 if let db::Contact::Accepted {
4310 user_id: contact_user_id,
4311 ..
4312 } = contact
4313 {
4314 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4315 session
4316 .peer
4317 .send(
4318 contact_conn_id,
4319 proto::UpdateContacts {
4320 contacts: vec![updated_contact.clone()],
4321 remove_contacts: Default::default(),
4322 incoming_requests: Default::default(),
4323 remove_incoming_requests: Default::default(),
4324 outgoing_requests: Default::default(),
4325 remove_outgoing_requests: Default::default(),
4326 },
4327 )
4328 .trace_err();
4329 }
4330 }
4331 }
4332 Ok(())
4333}
4334
4335async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
4336 let mut contacts_to_update = HashSet::default();
4337
4338 let room_id;
4339 let canceled_calls_to_user_ids;
4340 let livekit_room;
4341 let delete_livekit_room;
4342 let room;
4343 let channel;
4344
4345 if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
4346 contacts_to_update.insert(session.user_id());
4347
4348 for project in left_room.left_projects.values() {
4349 project_left(project, session);
4350 }
4351
4352 room_id = RoomId::from_proto(left_room.room.id);
4353 canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
4354 livekit_room = mem::take(&mut left_room.room.livekit_room);
4355 delete_livekit_room = left_room.deleted;
4356 room = mem::take(&mut left_room.room);
4357 channel = mem::take(&mut left_room.channel);
4358
4359 room_updated(&room, &session.peer);
4360 } else {
4361 return Ok(());
4362 }
4363
4364 if let Some(channel) = channel {
4365 channel_updated(
4366 &channel,
4367 &room,
4368 &session.peer,
4369 &*session.connection_pool().await,
4370 );
4371 }
4372
4373 {
4374 let pool = session.connection_pool().await;
4375 for canceled_user_id in canceled_calls_to_user_ids {
4376 for connection_id in pool.user_connection_ids(canceled_user_id) {
4377 session
4378 .peer
4379 .send(
4380 connection_id,
4381 proto::CallCanceled {
4382 room_id: room_id.to_proto(),
4383 },
4384 )
4385 .trace_err();
4386 }
4387 contacts_to_update.insert(canceled_user_id);
4388 }
4389 }
4390
4391 for contact_user_id in contacts_to_update {
4392 update_user_contacts(contact_user_id, session).await?;
4393 }
4394
4395 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
4396 live_kit
4397 .remove_participant(livekit_room.clone(), session.user_id().to_string())
4398 .await
4399 .trace_err();
4400
4401 if delete_livekit_room {
4402 live_kit.delete_room(livekit_room).await.trace_err();
4403 }
4404 }
4405
4406 Ok(())
4407}
4408
4409async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4410 let left_channel_buffers = session
4411 .db()
4412 .await
4413 .leave_channel_buffers(session.connection_id)
4414 .await?;
4415
4416 for left_buffer in left_channel_buffers {
4417 channel_buffer_updated(
4418 session.connection_id,
4419 left_buffer.connections,
4420 &proto::UpdateChannelBufferCollaborators {
4421 channel_id: left_buffer.channel_id.to_proto(),
4422 collaborators: left_buffer.collaborators,
4423 },
4424 &session.peer,
4425 );
4426 }
4427
4428 Ok(())
4429}
4430
4431fn project_left(project: &db::LeftProject, session: &Session) {
4432 for connection_id in &project.connection_ids {
4433 if project.should_unshare {
4434 session
4435 .peer
4436 .send(
4437 *connection_id,
4438 proto::UnshareProject {
4439 project_id: project.id.to_proto(),
4440 },
4441 )
4442 .trace_err();
4443 } else {
4444 session
4445 .peer
4446 .send(
4447 *connection_id,
4448 proto::RemoveProjectCollaborator {
4449 project_id: project.id.to_proto(),
4450 peer_id: Some(session.connection_id.into()),
4451 },
4452 )
4453 .trace_err();
4454 }
4455 }
4456}
4457
4458pub trait ResultExt {
4459 type Ok;
4460
4461 fn trace_err(self) -> Option<Self::Ok>;
4462}
4463
4464impl<T, E> ResultExt for Result<T, E>
4465where
4466 E: std::fmt::Debug,
4467{
4468 type Ok = T;
4469
4470 #[track_caller]
4471 fn trace_err(self) -> Option<T> {
4472 match self {
4473 Ok(value) => Some(value),
4474 Err(error) => {
4475 tracing::error!("{:?}", error);
4476 None
4477 }
4478 }
4479 }
4480}