1mod connection_pool;
2
3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
4use crate::db::billing_subscription::SubscriptionKind;
5use crate::llm::db::LlmDatabase;
6use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, LlmTokenClaims};
7use crate::{
8 AppState, Error, Result, auth,
9 db::{
10 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
11 CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
12 NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
13 RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
14 },
15 executor::Executor,
16};
17use anyhow::{Context as _, anyhow, bail};
18use async_tungstenite::tungstenite::{
19 Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
20};
21use axum::{
22 Extension, Router, TypedHeader,
23 body::Body,
24 extract::{
25 ConnectInfo, WebSocketUpgrade,
26 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
27 },
28 headers::{Header, HeaderName},
29 http::StatusCode,
30 middleware,
31 response::IntoResponse,
32 routing::get,
33};
34use chrono::Utc;
35use collections::{HashMap, HashSet};
36pub use connection_pool::{ConnectionPool, ZedVersion};
37use core::fmt::{self, Debug, Formatter};
38use reqwest_client::ReqwestClient;
39use rpc::proto::split_repository_update;
40use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
41
42use futures::{
43 FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
44 stream::FuturesUnordered,
45};
46use prometheus::{IntGauge, register_int_gauge};
47use rpc::{
48 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
49 proto::{
50 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
51 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
52 },
53};
54use semantic_version::SemanticVersion;
55use serde::{Serialize, Serializer};
56use std::{
57 any::TypeId,
58 future::Future,
59 marker::PhantomData,
60 mem,
61 net::SocketAddr,
62 ops::{Deref, DerefMut},
63 rc::Rc,
64 sync::{
65 Arc, OnceLock,
66 atomic::{AtomicBool, Ordering::SeqCst},
67 },
68 time::{Duration, Instant},
69};
70use time::OffsetDateTime;
71use tokio::sync::{Semaphore, watch};
72use tower::ServiceBuilder;
73use tracing::{
74 Instrument,
75 field::{self},
76 info_span, instrument,
77};
78
/// Reconnection grace period (presumably how long a dropped client may reconnect
/// before its state is released — the users of this constant are outside this section).
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// Delay before `Server::start` reclaims resources left behind by previous server
/// instances.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully; waiting 15s means
/// those pods are gone before their leftover resources are cleaned up.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

/// Page size for channel chat message history (per the name; usage is elsewhere).
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Maximum channel message length (per the name; enforcement is elsewhere).
const MAX_MESSAGE_LEN: usize = 1024;
/// Page size for notification queries (per the name; usage is elsewhere).
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
87
88type MessageHandler =
89 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
90
91struct Response<R> {
92 peer: Arc<Peer>,
93 receipt: Receipt<R>,
94 responded: Arc<AtomicBool>,
95}
96
97impl<R: RequestMessage> Response<R> {
98 fn send(self, payload: R::Response) -> Result<()> {
99 self.responded.store(true, SeqCst);
100 self.peer.respond(self.receipt, payload)?;
101 Ok(())
102 }
103}
104
105#[derive(Clone, Debug)]
106pub enum Principal {
107 User(User),
108 Impersonated { user: User, admin: User },
109}
110
111impl Principal {
112 fn update_span(&self, span: &tracing::Span) {
113 match &self {
114 Principal::User(user) => {
115 span.record("user_id", user.id.0);
116 span.record("login", &user.github_login);
117 }
118 Principal::Impersonated { user, admin } => {
119 span.record("user_id", user.id.0);
120 span.record("login", &user.github_login);
121 span.record("impersonator", &admin.github_login);
122 }
123 }
124 }
125}
126
127#[derive(Clone)]
128struct Session {
129 principal: Principal,
130 connection_id: ConnectionId,
131 db: Arc<tokio::sync::Mutex<DbHandle>>,
132 peer: Arc<Peer>,
133 connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
134 app_state: Arc<AppState>,
135 supermaven_client: Option<Arc<SupermavenAdminApi>>,
136 /// The GeoIP country code for the user.
137 #[allow(unused)]
138 geoip_country_code: Option<String>,
139 system_id: Option<String>,
140 _executor: Executor,
141}
142
143impl Session {
144 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
145 #[cfg(test)]
146 tokio::task::yield_now().await;
147 let guard = self.db.lock().await;
148 #[cfg(test)]
149 tokio::task::yield_now().await;
150 guard
151 }
152
153 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
154 #[cfg(test)]
155 tokio::task::yield_now().await;
156 let guard = self.connection_pool.lock();
157 ConnectionPoolGuard {
158 guard,
159 _not_send: PhantomData,
160 }
161 }
162
163 fn is_staff(&self) -> bool {
164 match &self.principal {
165 Principal::User(user) => user.admin,
166 Principal::Impersonated { .. } => true,
167 }
168 }
169
170 fn user_id(&self) -> UserId {
171 match &self.principal {
172 Principal::User(user) => user.id,
173 Principal::Impersonated { user, .. } => user.id,
174 }
175 }
176
177 pub fn email(&self) -> Option<String> {
178 match &self.principal {
179 Principal::User(user) => user.email_address.clone(),
180 Principal::Impersonated { user, .. } => user.email_address.clone(),
181 }
182 }
183}
184
185impl Debug for Session {
186 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
187 let mut result = f.debug_struct("Session");
188 match &self.principal {
189 Principal::User(user) => {
190 result.field("user", &user.github_login);
191 }
192 Principal::Impersonated { user, admin } => {
193 result.field("user", &user.github_login);
194 result.field("impersonator", &admin.github_login);
195 }
196 }
197 result.field("connection_id", &self.connection_id).finish()
198 }
199}
200
201struct DbHandle(Arc<Database>);
202
203impl Deref for DbHandle {
204 type Target = Database;
205
206 fn deref(&self) -> &Self::Target {
207 self.0.as_ref()
208 }
209}
210
211pub struct Server {
212 id: parking_lot::Mutex<ServerId>,
213 peer: Arc<Peer>,
214 pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
215 app_state: Arc<AppState>,
216 handlers: HashMap<TypeId, MessageHandler>,
217 teardown: watch::Sender<bool>,
218}
219
220pub(crate) struct ConnectionPoolGuard<'a> {
221 guard: parking_lot::MutexGuard<'a, ConnectionPool>,
222 _not_send: PhantomData<Rc<()>>,
223}
224
225#[derive(Serialize)]
226pub struct ServerSnapshot<'a> {
227 peer: &'a Peer,
228 #[serde(serialize_with = "serialize_deref")]
229 connection_pool: ConnectionPoolGuard<'a>,
230}
231
232pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
233where
234 S: Serializer,
235 T: Deref<Target = U>,
236 U: Serialize,
237{
238 Serialize::serialize(value.deref(), serializer)
239}
240
241impl Server {
242 pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
243 let mut server = Self {
244 id: parking_lot::Mutex::new(id),
245 peer: Peer::new(id.0 as u32),
246 app_state: app_state.clone(),
247 connection_pool: Default::default(),
248 handlers: Default::default(),
249 teardown: watch::channel(false).0,
250 };
251
252 server
253 .add_request_handler(ping)
254 .add_request_handler(create_room)
255 .add_request_handler(join_room)
256 .add_request_handler(rejoin_room)
257 .add_request_handler(leave_room)
258 .add_request_handler(set_room_participant_role)
259 .add_request_handler(call)
260 .add_request_handler(cancel_call)
261 .add_message_handler(decline_call)
262 .add_request_handler(update_participant_location)
263 .add_request_handler(share_project)
264 .add_message_handler(unshare_project)
265 .add_request_handler(join_project)
266 .add_message_handler(leave_project)
267 .add_request_handler(update_project)
268 .add_request_handler(update_worktree)
269 .add_request_handler(update_repository)
270 .add_request_handler(remove_repository)
271 .add_message_handler(start_language_server)
272 .add_message_handler(update_language_server)
273 .add_message_handler(update_diagnostic_summary)
274 .add_message_handler(update_worktree_settings)
275 .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
276 .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
277 .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
278 .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
279 .add_request_handler(forward_find_search_candidates_request)
280 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
281 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
282 .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
283 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
284 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
285 .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
286 .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
287 .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
288 .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
289 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
290 .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
291 .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
292 .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
293 .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
294 .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
295 .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
296 .add_request_handler(
297 forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
298 )
299 .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
300 .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
301 .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
302 .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
303 .add_request_handler(
304 forward_read_only_project_request::<proto::LanguageServerIdForName>,
305 )
306 .add_request_handler(
307 forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
308 )
309 .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
310 .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
311 .add_request_handler(
312 forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
313 )
314 .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
315 .add_request_handler(
316 forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
317 )
318 .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
319 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
320 .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
321 .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
322 .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
323 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
324 .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
325 .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
326 .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
327 .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
328 .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
329 .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
330 .add_request_handler(
331 forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
332 )
333 .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
334 .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
335 .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
336 .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
337 .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
338 .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
339 .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
340 .add_message_handler(create_buffer_for_peer)
341 .add_request_handler(update_buffer)
342 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
343 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
344 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
345 .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
346 .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
347 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
348 .add_request_handler(get_users)
349 .add_request_handler(fuzzy_search_users)
350 .add_request_handler(request_contact)
351 .add_request_handler(remove_contact)
352 .add_request_handler(respond_to_contact_request)
353 .add_message_handler(subscribe_to_channels)
354 .add_request_handler(create_channel)
355 .add_request_handler(delete_channel)
356 .add_request_handler(invite_channel_member)
357 .add_request_handler(remove_channel_member)
358 .add_request_handler(set_channel_member_role)
359 .add_request_handler(set_channel_visibility)
360 .add_request_handler(rename_channel)
361 .add_request_handler(join_channel_buffer)
362 .add_request_handler(leave_channel_buffer)
363 .add_message_handler(update_channel_buffer)
364 .add_request_handler(rejoin_channel_buffers)
365 .add_request_handler(get_channel_members)
366 .add_request_handler(respond_to_channel_invite)
367 .add_request_handler(join_channel)
368 .add_request_handler(join_channel_chat)
369 .add_message_handler(leave_channel_chat)
370 .add_request_handler(send_channel_message)
371 .add_request_handler(remove_channel_message)
372 .add_request_handler(update_channel_message)
373 .add_request_handler(get_channel_messages)
374 .add_request_handler(get_channel_messages_by_id)
375 .add_request_handler(get_notifications)
376 .add_request_handler(mark_notification_as_read)
377 .add_request_handler(move_channel)
378 .add_request_handler(follow)
379 .add_message_handler(unfollow)
380 .add_message_handler(update_followers)
381 .add_request_handler(get_private_user_info)
382 .add_request_handler(get_llm_api_token)
383 .add_request_handler(accept_terms_of_service)
384 .add_message_handler(acknowledge_channel_message)
385 .add_message_handler(acknowledge_buffer_version)
386 .add_request_handler(get_supermaven_api_key)
387 .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
388 .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
389 .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
390 .add_request_handler(forward_mutating_project_request::<proto::Stage>)
391 .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
392 .add_request_handler(forward_mutating_project_request::<proto::Commit>)
393 .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
394 .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
395 .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
396 .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
397 .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
398 .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
399 .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
400 .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
401 .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
402 .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
403 .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
404 .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
405 .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
406 .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
407 .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
408 .add_message_handler(update_context);
409
410 Arc::new(server)
411 }
412
    /// Schedules cleanup of resources left behind by previous server instances.
    ///
    /// Spawns a detached task, instrumented with a "start server" span, that sleeps for
    /// [`CLEANUP_TIMEOUT`]. It then asks the database for the room and channel ids
    /// reported as stale for this `zed_environment` and server id, and processes them:
    /// - Channel buffers: clears stale collaborators and sends the refreshed collaborator
    ///   list to every connection the database returns for that buffer.
    /// - Rooms: clears stale participants, then:
    ///   - broadcasts the room (and its channel, if any);
    ///   - sends `CallCanceled` to users whose calls were canceled;
    ///   - pushes updated contact info for affected users to their accepted contacts;
    ///   - deletes the LiveKit room once no participants remain.
    ///
    /// Finally it deletes stale server records. Every fallible step goes through
    /// `trace_err` (presumably logging), so one failure does not stop the rest.
    ///
    /// Returns `Ok(())` right away; the cleanup itself runs in the background.
    pub async fn start(&self) -> Result<()> {
        // Capture owned handles so the detached task does not borrow `self`.
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Channel buffers: drop stale collaborators, then tell the connections
                    // returned by the database who is still collaborating.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // These stay empty/false unless clearing the room succeeds below, in
                        // which case the follow-up notifications and deletion do nothing.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            // Both removed participants and canceled callees need their
                            // contact status refreshed.
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            // An empty room no longer needs its LiveKit counterpart.
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the pool lock is released before the awaits below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        for user_id in contacts_to_update {
                            // Both lookups must succeed before any contact is notified; the
                            // pool is locked only after both awaits complete.
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    // Only accepted contacts receive the status update.
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = livekit_client.as_ref() {
                            if delete_livekit_room {
                                live_kit.delete_room(livekit_room).await.trace_err();
                            }
                        }
                    }
                }

                // Runs even if fetching the stale resource ids failed above.
                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
558
559 pub fn teardown(&self) {
560 self.peer.teardown();
561 self.connection_pool.lock().reset();
562 let _ = self.teardown.send(true);
563 }
564
565 #[cfg(test)]
566 pub fn reset(&self, id: ServerId) {
567 self.teardown();
568 *self.id.lock() = id;
569 self.peer.reset(id.0 as u32);
570 let _ = self.teardown.send(false);
571 }
572
573 #[cfg(test)]
574 pub fn id(&self) -> ServerId {
575 *self.id.lock()
576 }
577
578 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
579 where
580 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
581 Fut: 'static + Send + Future<Output = Result<()>>,
582 M: EnvelopedMessage,
583 {
584 let prev_handler = self.handlers.insert(
585 TypeId::of::<M>(),
586 Box::new(move |envelope, session| {
587 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
588 let received_at = envelope.received_at;
589 tracing::info!("message received");
590 let start_time = Instant::now();
591 let future = (handler)(*envelope, session);
592 async move {
593 let result = future.await;
594 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
595 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
596 let queue_duration_ms = total_duration_ms - processing_duration_ms;
597 let payload_type = M::NAME;
598
599 match result {
600 Err(error) => {
601 tracing::error!(
602 ?error,
603 total_duration_ms,
604 processing_duration_ms,
605 queue_duration_ms,
606 payload_type,
607 "error handling message"
608 )
609 }
610 Ok(()) => tracing::info!(
611 total_duration_ms,
612 processing_duration_ms,
613 queue_duration_ms,
614 "finished handling message"
615 ),
616 }
617 }
618 .boxed()
619 }),
620 );
621 if prev_handler.is_some() {
622 panic!("registered a handler for the same message twice");
623 }
624 self
625 }
626
627 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
628 where
629 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
630 Fut: 'static + Send + Future<Output = Result<()>>,
631 M: EnvelopedMessage,
632 {
633 self.add_handler(move |envelope, session| handler(envelope.payload, session));
634 self
635 }
636
637 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
638 where
639 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
640 Fut: Send + Future<Output = Result<()>>,
641 M: RequestMessage,
642 {
643 let handler = Arc::new(handler);
644 self.add_handler(move |envelope, session| {
645 let receipt = envelope.receipt();
646 let handler = handler.clone();
647 async move {
648 let peer = session.peer.clone();
649 let responded = Arc::new(AtomicBool::default());
650 let response = Response {
651 peer: peer.clone(),
652 responded: responded.clone(),
653 receipt,
654 };
655 match (handler)(envelope.payload, response, session).await {
656 Ok(()) => {
657 if responded.load(std::sync::atomic::Ordering::SeqCst) {
658 Ok(())
659 } else {
660 Err(anyhow!("handler did not send a response"))?
661 }
662 }
663 Err(error) => {
664 let proto_err = match &error {
665 Error::Internal(err) => err.to_proto(),
666 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
667 };
668 peer.respond_with_error(receipt, proto_err)?;
669 Err(error)
670 }
671 }
672 }
673 })
674 }
675
    /// Drives one client connection from handshake to sign-out.
    ///
    /// The returned future (`use<>`: it borrows nothing from `self`) does the following:
    /// 1. Bails out if the server is already tearing down.
    /// 2. Registers `connection` with the peer and records its id on the span.
    /// 3. Builds the connection's [`Session`], including a Supermaven admin client when
    ///    `supermaven_admin_api_key` is configured, and sends the initial client update.
    /// 4. Dispatches incoming messages to their registered handlers until one of these
    ///    happens: the I/O future finishes, the incoming stream ends, or teardown is
    ///    signalled.
    /// 5. Runs `connection_lost` on the first two exits. On teardown it returns
    ///    immediately without running `connection_lost`.
    ///
    /// Failures are logged rather than returned.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        // Fields declared `Empty` are recorded below once their values are known.
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            geoip_country_code=field::Empty
        );
        principal.update_span(&span);
        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        // Subscribed now, so teardown signals sent after this point are observed below.
        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            // The closure hands the peer a sleep primitive backed by this executor.
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            // In this function the HTTP client is only used by the Supermaven admin API.
            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
                supermaven_admin_api_key.to_string(),
                http_client.clone(),
            )));

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // At most 256 handlers (foreground or background) are in flight per connection.
            // A permit is taken before the next message is read and released when that
            // message's handler finishes.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // `select_biased!` polls arms in order: teardown first, then I/O, then
                // in-flight foreground handlers, and only then the next incoming message.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                            );
                            principal.update_span(&span);
                            // Entered only while the handler builds its future; the future
                            // itself is instrumented with `span` below.
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background handlers run detached from this loop; foreground
                                // ones are polled by the `select_biased!` above.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // In-flight foreground handlers are dropped (cancelled) before signing out.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
823
824 async fn send_initial_client_update(
825 &self,
826 connection_id: ConnectionId,
827 principal: &Principal,
828 zed_version: ZedVersion,
829 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
830 session: &Session,
831 ) -> Result<()> {
832 self.peer.send(
833 connection_id,
834 proto::Hello {
835 peer_id: Some(connection_id.into()),
836 },
837 )?;
838 tracing::info!("sent hello message");
839 if let Some(send_connection_id) = send_connection_id.take() {
840 let _ = send_connection_id.send(connection_id);
841 }
842
843 match principal {
844 Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
845 if !user.connected_once {
846 self.peer.send(connection_id, proto::ShowContacts {})?;
847 self.app_state
848 .db
849 .set_user_connected_once(user.id, true)
850 .await?;
851 }
852
853 update_user_plan(user.id, session).await?;
854
855 let contacts = self.app_state.db.get_contacts(user.id).await?;
856
857 {
858 let mut pool = self.connection_pool.lock();
859 pool.add_connection(connection_id, user.id, user.admin, zed_version);
860 self.peer.send(
861 connection_id,
862 build_initial_contacts_update(contacts, &pool),
863 )?;
864 }
865
866 if should_auto_subscribe_to_channels(zed_version) {
867 subscribe_user_to_channels(user.id, session).await?;
868 }
869
870 if let Some(incoming_call) =
871 self.app_state.db.incoming_call_for_user(user.id).await?
872 {
873 self.peer.send(connection_id, incoming_call)?;
874 }
875
876 update_user_contacts(user.id, session).await?;
877 }
878 }
879
880 Ok(())
881 }
882
883 pub async fn invite_code_redeemed(
884 self: &Arc<Self>,
885 inviter_id: UserId,
886 invitee_id: UserId,
887 ) -> Result<()> {
888 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
889 if let Some(code) = &user.invite_code {
890 let pool = self.connection_pool.lock();
891 let invitee_contact = contact_for_user(invitee_id, false, &pool);
892 for connection_id in pool.user_connection_ids(inviter_id) {
893 self.peer.send(
894 connection_id,
895 proto::UpdateContacts {
896 contacts: vec![invitee_contact.clone()],
897 ..Default::default()
898 },
899 )?;
900 self.peer.send(
901 connection_id,
902 proto::UpdateInviteInfo {
903 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
904 count: user.invite_count as u32,
905 },
906 )?;
907 }
908 }
909 }
910 Ok(())
911 }
912
913 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
914 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
915 if let Some(invite_code) = &user.invite_code {
916 let pool = self.connection_pool.lock();
917 for connection_id in pool.user_connection_ids(user_id) {
918 self.peer.send(
919 connection_id,
920 proto::UpdateInviteInfo {
921 url: format!(
922 "{}{}",
923 self.app_state.config.invite_link_prefix, invite_code
924 ),
925 count: user.invite_count as u32,
926 },
927 )?;
928 }
929 }
930 }
931 Ok(())
932 }
933
934 pub async fn update_plan_for_user(self: &Arc<Self>, user_id: UserId) -> Result<()> {
935 let user = self
936 .app_state
937 .db
938 .get_user_by_id(user_id)
939 .await?
940 .ok_or_else(|| anyhow!("user not found"))?;
941
942 let update_user_plan = make_update_user_plan_message(
943 &self.app_state.db,
944 self.app_state.llm_db.clone(),
945 user_id,
946 user.admin,
947 )
948 .await?;
949
950 let pool = self.connection_pool.lock();
951 for connection_id in pool.user_connection_ids(user_id) {
952 self.peer
953 .send(connection_id, update_user_plan.clone())
954 .trace_err();
955 }
956
957 Ok(())
958 }
959
960 pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
961 let pool = self.connection_pool.lock();
962 for connection_id in pool.user_connection_ids(user_id) {
963 self.peer
964 .send(connection_id, proto::RefreshLlmToken {})
965 .trace_err();
966 }
967 }
968
969 pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot {
970 ServerSnapshot {
971 connection_pool: ConnectionPoolGuard {
972 guard: self.connection_pool.lock(),
973 _not_send: PhantomData,
974 },
975 peer: &self.peer,
976 }
977 }
978}
979
980impl Deref for ConnectionPoolGuard<'_> {
981 type Target = ConnectionPool;
982
983 fn deref(&self) -> &Self::Target {
984 &self.guard
985 }
986}
987
988impl DerefMut for ConnectionPoolGuard<'_> {
989 fn deref_mut(&mut self) -> &mut Self::Target {
990 &mut self.guard
991 }
992}
993
impl Drop for ConnectionPoolGuard<'_> {
    /// In test builds, runs the pool's `check_invariants` each time a guard is
    /// released. Presumably this catches an inconsistent pool immediately after
    /// the code that mutated it. In non-test builds the body compiles to nothing.
    fn drop(&mut self) {
        #[cfg(test)]
        self.check_invariants();
    }
}
1000
1001fn broadcast<F>(
1002 sender_id: Option<ConnectionId>,
1003 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1004 mut f: F,
1005) where
1006 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1007{
1008 for receiver_id in receiver_ids {
1009 if Some(receiver_id) != sender_id {
1010 if let Err(error) = f(receiver_id) {
1011 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1012 }
1013 }
1014 }
1015}
1016
1017pub struct ProtocolVersion(u32);
1018
1019impl Header for ProtocolVersion {
1020 fn name() -> &'static HeaderName {
1021 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1022 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1023 }
1024
1025 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1026 where
1027 Self: Sized,
1028 I: Iterator<Item = &'i axum::http::HeaderValue>,
1029 {
1030 let version = values
1031 .next()
1032 .ok_or_else(axum::headers::Error::invalid)?
1033 .to_str()
1034 .map_err(|_| axum::headers::Error::invalid())?
1035 .parse()
1036 .map_err(|_| axum::headers::Error::invalid())?;
1037 Ok(Self(version))
1038 }
1039
1040 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1041 values.extend([self.0.to_string().parse().unwrap()]);
1042 }
1043}
1044
1045pub struct AppVersionHeader(SemanticVersion);
1046impl Header for AppVersionHeader {
1047 fn name() -> &'static HeaderName {
1048 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1049 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1050 }
1051
1052 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1053 where
1054 Self: Sized,
1055 I: Iterator<Item = &'i axum::http::HeaderValue>,
1056 {
1057 let version = values
1058 .next()
1059 .ok_or_else(axum::headers::Error::invalid)?
1060 .to_str()
1061 .map_err(|_| axum::headers::Error::invalid())?
1062 .parse()
1063 .map_err(|_| axum::headers::Error::invalid())?;
1064 Ok(Self(version))
1065 }
1066
1067 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1068 values.extend([self.0.to_string().parse().unwrap()]);
1069 }
1070}
1071
/// Builds the router for the collab RPC endpoints: `/rpc` (websocket upgrade)
/// and `/metrics` (Prometheus text output).
///
/// Layer order is significant. axum's `Router::layer` wraps only the routes
/// registered before the call. As a result, the app-state extension and the
/// `auth::validate_header` middleware apply to `/rpc` but not to `/metrics`,
/// which is registered afterwards. The final `Extension(server)` layer applies
/// to both routes.
pub fn routes(server: Arc<Server>) -> Router<(), Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        // Authenticated stack: applies only to the route(s) added above.
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
1083
1084pub async fn handle_websocket_request(
1085 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1086 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1087 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1088 Extension(server): Extension<Arc<Server>>,
1089 Extension(principal): Extension<Principal>,
1090 country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1091 system_id_header: Option<TypedHeader<SystemIdHeader>>,
1092 ws: WebSocketUpgrade,
1093) -> axum::response::Response {
1094 if protocol_version != rpc::PROTOCOL_VERSION {
1095 return (
1096 StatusCode::UPGRADE_REQUIRED,
1097 "client must be upgraded".to_string(),
1098 )
1099 .into_response();
1100 }
1101
1102 let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1103 return (
1104 StatusCode::UPGRADE_REQUIRED,
1105 "no version header found".to_string(),
1106 )
1107 .into_response();
1108 };
1109
1110 if !version.can_collaborate() {
1111 return (
1112 StatusCode::UPGRADE_REQUIRED,
1113 "client must be upgraded".to_string(),
1114 )
1115 .into_response();
1116 }
1117
1118 let socket_address = socket_address.to_string();
1119 ws.on_upgrade(move |socket| {
1120 let socket = socket
1121 .map_ok(to_tungstenite_message)
1122 .err_into()
1123 .with(|message| async move { to_axum_message(message) });
1124 let connection = Connection::new(Box::pin(socket));
1125 async move {
1126 server
1127 .handle_connection(
1128 connection,
1129 socket_address,
1130 principal,
1131 version,
1132 country_code_header.map(|header| header.to_string()),
1133 system_id_header.map(|header| header.to_string()),
1134 None,
1135 Executor::Production,
1136 )
1137 .await;
1138 }
1139 })
1140}
1141
1142pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1143 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1144 let connections_metric = CONNECTIONS_METRIC
1145 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1146
1147 let connections = server
1148 .connection_pool
1149 .lock()
1150 .connections()
1151 .filter(|connection| !connection.admin)
1152 .count();
1153 connections_metric.set(connections as _);
1154
1155 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1156 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1157 register_int_gauge!(
1158 "shared_projects",
1159 "number of open projects with one or more guests"
1160 )
1161 .unwrap()
1162 });
1163
1164 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1165 shared_projects_metric.set(shared_projects as _);
1166
1167 let encoder = prometheus::TextEncoder::new();
1168 let metric_families = prometheus::gather();
1169 let encoded_metrics = encoder
1170 .encode_to_string(&metric_families)
1171 .map_err(|err| anyhow!("{}", err))?;
1172 Ok(encoded_metrics)
1173}
1174
/// Cleans up after a client connection drops.
///
/// Immediately, it:
/// - disconnects the peer;
/// - removes the connection from the pool;
/// - calls the database's `connection_lost` (best-effort: errors are only logged).
///
/// It then races `RECONNECT_TIMEOUT` against a change on `teardown`. If the
/// timeout fires first, it:
/// - leaves the session's room and channel buffers;
/// - declines the user's call via `decline_call`, but only if the user has no
///   other connection online;
/// - refreshes the user's contacts.
///
/// If `teardown` changes first, that cleanup is skipped. Presumably this
/// signals that the server is shutting down; confirm with the caller.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls branches in the order written, so if both are
    // ready at once the reconnect-timeout branch is taken.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline if none of the user's other connections is still online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1221
1222/// Acknowledges a ping from a client, used to keep the connection alive.
1223async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1224 response.send(proto::Ack {})?;
1225 Ok(())
1226}
1227
1228/// Creates a new room for calling (outside of channels)
1229async fn create_room(
1230 _request: proto::CreateRoom,
1231 response: Response<proto::CreateRoom>,
1232 session: Session,
1233) -> Result<()> {
1234 let livekit_room = nanoid::nanoid!(30);
1235
1236 let live_kit_connection_info = util::maybe!(async {
1237 let live_kit = session.app_state.livekit_client.as_ref();
1238 let live_kit = live_kit?;
1239 let user_id = session.user_id().to_string();
1240
1241 let token = live_kit
1242 .room_token(&livekit_room, &user_id.to_string())
1243 .trace_err()?;
1244
1245 Some(proto::LiveKitConnectionInfo {
1246 server_url: live_kit.url().into(),
1247 token,
1248 can_publish: true,
1249 })
1250 })
1251 .await;
1252
1253 let room = session
1254 .db()
1255 .await
1256 .create_room(session.user_id(), session.connection_id, &livekit_room)
1257 .await?;
1258
1259 response.send(proto::CreateRoomResponse {
1260 room: Some(room.clone()),
1261 live_kit_connection_info,
1262 })?;
1263
1264 update_user_contacts(session.user_id(), &session).await?;
1265 Ok(())
1266}
1267
1268/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1269async fn join_room(
1270 request: proto::JoinRoom,
1271 response: Response<proto::JoinRoom>,
1272 session: Session,
1273) -> Result<()> {
1274 let room_id = RoomId::from_proto(request.id);
1275
1276 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1277
1278 if let Some(channel_id) = channel_id {
1279 return join_channel_internal(channel_id, Box::new(response), session).await;
1280 }
1281
1282 let joined_room = {
1283 let room = session
1284 .db()
1285 .await
1286 .join_room(room_id, session.user_id(), session.connection_id)
1287 .await?;
1288 room_updated(&room.room, &session.peer);
1289 room.into_inner()
1290 };
1291
1292 for connection_id in session
1293 .connection_pool()
1294 .await
1295 .user_connection_ids(session.user_id())
1296 {
1297 session
1298 .peer
1299 .send(
1300 connection_id,
1301 proto::CallCanceled {
1302 room_id: room_id.to_proto(),
1303 },
1304 )
1305 .trace_err();
1306 }
1307
1308 let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1309 {
1310 live_kit
1311 .room_token(
1312 &joined_room.room.livekit_room,
1313 &session.user_id().to_string(),
1314 )
1315 .trace_err()
1316 .map(|token| proto::LiveKitConnectionInfo {
1317 server_url: live_kit.url().into(),
1318 token,
1319 can_publish: true,
1320 })
1321 } else {
1322 None
1323 };
1324
1325 response.send(proto::JoinRoomResponse {
1326 room: Some(joined_room.room),
1327 channel_id: None,
1328 live_kit_connection_info,
1329 })?;
1330
1331 update_user_contacts(session.user_id(), &session).await?;
1332 Ok(())
1333}
1334
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Restores this connection's room membership in the database. The reply
/// contains the room, the `reshared_projects` (presumably projects this
/// connection hosts) and the `rejoined_projects` (presumably projects it had
/// joined as a guest).
///
/// Then, for each reshared project:
/// - every listed collaborator is told to replace the old peer id with the
///   session's current connection id;
/// - the project's worktree list is forwarded to them.
///
/// `notify_rejoined_projects` re-streams the rejoined projects' state to this
/// connection. Finally, the channel (if the room has one) is re-broadcast and
/// the user's contacts are refreshed.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    let room;
    let channel;
    // Keep the value returned by the database (unwrapped with `into_inner()`
    // below) confined to this block, so it is dropped before the awaits that
    // follow; only `room` and `channel` escape.
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Best-effort: point each collaborator at the new connection id.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the project's worktree list to everyone but the sender.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1426
1427fn notify_rejoined_projects(
1428 rejoined_projects: &mut Vec<RejoinedProject>,
1429 session: &Session,
1430) -> Result<()> {
1431 for project in rejoined_projects.iter() {
1432 for collaborator in &project.collaborators {
1433 session
1434 .peer
1435 .send(
1436 collaborator.connection_id,
1437 proto::UpdateProjectCollaborator {
1438 project_id: project.id.to_proto(),
1439 old_peer_id: Some(project.old_connection_id.into()),
1440 new_peer_id: Some(session.connection_id.into()),
1441 },
1442 )
1443 .trace_err();
1444 }
1445 }
1446
1447 for project in rejoined_projects {
1448 for worktree in mem::take(&mut project.worktrees) {
1449 // Stream this worktree's entries.
1450 let message = proto::UpdateWorktree {
1451 project_id: project.id.to_proto(),
1452 worktree_id: worktree.id,
1453 abs_path: worktree.abs_path.clone(),
1454 root_name: worktree.root_name,
1455 updated_entries: worktree.updated_entries,
1456 removed_entries: worktree.removed_entries,
1457 scan_id: worktree.scan_id,
1458 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1459 updated_repositories: worktree.updated_repositories,
1460 removed_repositories: worktree.removed_repositories,
1461 };
1462 for update in proto::split_worktree_update(message) {
1463 session.peer.send(session.connection_id, update)?;
1464 }
1465
1466 // Stream this worktree's diagnostics.
1467 for summary in worktree.diagnostic_summaries {
1468 session.peer.send(
1469 session.connection_id,
1470 proto::UpdateDiagnosticSummary {
1471 project_id: project.id.to_proto(),
1472 worktree_id: worktree.id,
1473 summary: Some(summary),
1474 },
1475 )?;
1476 }
1477
1478 for settings_file in worktree.settings_files {
1479 session.peer.send(
1480 session.connection_id,
1481 proto::UpdateWorktreeSettings {
1482 project_id: project.id.to_proto(),
1483 worktree_id: worktree.id,
1484 path: settings_file.path,
1485 content: Some(settings_file.content),
1486 kind: Some(settings_file.kind.to_proto().into()),
1487 },
1488 )?;
1489 }
1490 }
1491
1492 for repository in mem::take(&mut project.updated_repositories) {
1493 for update in split_repository_update(repository) {
1494 session.peer.send(session.connection_id, update)?;
1495 }
1496 }
1497
1498 for id in mem::take(&mut project.removed_repositories) {
1499 session.peer.send(
1500 session.connection_id,
1501 proto::RemoveRepository {
1502 project_id: project.id.to_proto(),
1503 id,
1504 },
1505 )?;
1506 }
1507 }
1508
1509 Ok(())
1510}
1511
1512/// leave room disconnects from the room.
1513async fn leave_room(
1514 _: proto::LeaveRoom,
1515 response: Response<proto::LeaveRoom>,
1516 session: Session,
1517) -> Result<()> {
1518 leave_room_for_session(&session, session.connection_id).await?;
1519 response.send(proto::Ack {})?;
1520 Ok(())
1521}
1522
1523/// Updates the permissions of someone else in the room.
1524async fn set_room_participant_role(
1525 request: proto::SetRoomParticipantRole,
1526 response: Response<proto::SetRoomParticipantRole>,
1527 session: Session,
1528) -> Result<()> {
1529 let user_id = UserId::from_proto(request.user_id);
1530 let role = ChannelRole::from(request.role());
1531
1532 let (livekit_room, can_publish) = {
1533 let room = session
1534 .db()
1535 .await
1536 .set_room_participant_role(
1537 session.user_id(),
1538 RoomId::from_proto(request.room_id),
1539 user_id,
1540 role,
1541 )
1542 .await?;
1543
1544 let livekit_room = room.livekit_room.clone();
1545 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1546 room_updated(&room, &session.peer);
1547 (livekit_room, can_publish)
1548 };
1549
1550 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1551 live_kit
1552 .update_participant(
1553 livekit_room.clone(),
1554 request.user_id.to_string(),
1555 livekit_api::proto::ParticipantPermission {
1556 can_subscribe: true,
1557 can_publish,
1558 can_publish_data: can_publish,
1559 hidden: false,
1560 recorder: false,
1561 },
1562 )
1563 .await
1564 .trace_err();
1565 }
1566
1567 response.send(proto::Ack {})?;
1568 Ok(())
1569}
1570
1571/// Call someone else into the current room
1572async fn call(
1573 request: proto::Call,
1574 response: Response<proto::Call>,
1575 session: Session,
1576) -> Result<()> {
1577 let room_id = RoomId::from_proto(request.room_id);
1578 let calling_user_id = session.user_id();
1579 let calling_connection_id = session.connection_id;
1580 let called_user_id = UserId::from_proto(request.called_user_id);
1581 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1582 if !session
1583 .db()
1584 .await
1585 .has_contact(calling_user_id, called_user_id)
1586 .await?
1587 {
1588 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1589 }
1590
1591 let incoming_call = {
1592 let (room, incoming_call) = &mut *session
1593 .db()
1594 .await
1595 .call(
1596 room_id,
1597 calling_user_id,
1598 calling_connection_id,
1599 called_user_id,
1600 initial_project_id,
1601 )
1602 .await?;
1603 room_updated(room, &session.peer);
1604 mem::take(incoming_call)
1605 };
1606 update_user_contacts(called_user_id, &session).await?;
1607
1608 let mut calls = session
1609 .connection_pool()
1610 .await
1611 .user_connection_ids(called_user_id)
1612 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1613 .collect::<FuturesUnordered<_>>();
1614
1615 while let Some(call_response) = calls.next().await {
1616 match call_response.as_ref() {
1617 Ok(_) => {
1618 response.send(proto::Ack {})?;
1619 return Ok(());
1620 }
1621 Err(_) => {
1622 call_response.trace_err();
1623 }
1624 }
1625 }
1626
1627 {
1628 let room = session
1629 .db()
1630 .await
1631 .call_failed(room_id, called_user_id)
1632 .await?;
1633 room_updated(&room, &session.peer);
1634 }
1635 update_user_contacts(called_user_id, &session).await?;
1636
1637 Err(anyhow!("failed to ring user"))?
1638}
1639
1640/// Cancel an outgoing call.
1641async fn cancel_call(
1642 request: proto::CancelCall,
1643 response: Response<proto::CancelCall>,
1644 session: Session,
1645) -> Result<()> {
1646 let called_user_id = UserId::from_proto(request.called_user_id);
1647 let room_id = RoomId::from_proto(request.room_id);
1648 {
1649 let room = session
1650 .db()
1651 .await
1652 .cancel_call(room_id, session.connection_id, called_user_id)
1653 .await?;
1654 room_updated(&room, &session.peer);
1655 }
1656
1657 for connection_id in session
1658 .connection_pool()
1659 .await
1660 .user_connection_ids(called_user_id)
1661 {
1662 session
1663 .peer
1664 .send(
1665 connection_id,
1666 proto::CallCanceled {
1667 room_id: room_id.to_proto(),
1668 },
1669 )
1670 .trace_err();
1671 }
1672 response.send(proto::Ack {})?;
1673
1674 update_user_contacts(called_user_id, &session).await?;
1675 Ok(())
1676}
1677
1678/// Decline an incoming call.
1679async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1680 let room_id = RoomId::from_proto(message.room_id);
1681 {
1682 let room = session
1683 .db()
1684 .await
1685 .decline_call(Some(room_id), session.user_id())
1686 .await?
1687 .ok_or_else(|| anyhow!("failed to decline call"))?;
1688 room_updated(&room, &session.peer);
1689 }
1690
1691 for connection_id in session
1692 .connection_pool()
1693 .await
1694 .user_connection_ids(session.user_id())
1695 {
1696 session
1697 .peer
1698 .send(
1699 connection_id,
1700 proto::CallCanceled {
1701 room_id: room_id.to_proto(),
1702 },
1703 )
1704 .trace_err();
1705 }
1706 update_user_contacts(session.user_id(), &session).await?;
1707 Ok(())
1708}
1709
1710/// Updates other participants in the room with your current location.
1711async fn update_participant_location(
1712 request: proto::UpdateParticipantLocation,
1713 response: Response<proto::UpdateParticipantLocation>,
1714 session: Session,
1715) -> Result<()> {
1716 let room_id = RoomId::from_proto(request.room_id);
1717 let location = request
1718 .location
1719 .ok_or_else(|| anyhow!("invalid location"))?;
1720
1721 let db = session.db().await;
1722 let room = db
1723 .update_room_participant_location(room_id, session.connection_id, location)
1724 .await?;
1725
1726 room_updated(&room, &session.peer);
1727 response.send(proto::Ack {})?;
1728 Ok(())
1729}
1730
1731/// Share a project into the room.
1732async fn share_project(
1733 request: proto::ShareProject,
1734 response: Response<proto::ShareProject>,
1735 session: Session,
1736) -> Result<()> {
1737 let (project_id, room) = &*session
1738 .db()
1739 .await
1740 .share_project(
1741 RoomId::from_proto(request.room_id),
1742 session.connection_id,
1743 &request.worktrees,
1744 request.is_ssh_project,
1745 )
1746 .await?;
1747 response.send(proto::ShareProjectResponse {
1748 project_id: project_id.to_proto(),
1749 })?;
1750 room_updated(room, &session.peer);
1751
1752 Ok(())
1753}
1754
1755/// Unshare a project from the room.
1756async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1757 let project_id = ProjectId::from_proto(message.project_id);
1758 unshare_project_internal(project_id, session.connection_id, &session).await
1759}
1760
1761async fn unshare_project_internal(
1762 project_id: ProjectId,
1763 connection_id: ConnectionId,
1764 session: &Session,
1765) -> Result<()> {
1766 let delete = {
1767 let room_guard = session
1768 .db()
1769 .await
1770 .unshare_project(project_id, connection_id)
1771 .await?;
1772
1773 let (delete, room, guest_connection_ids) = &*room_guard;
1774
1775 let message = proto::UnshareProject {
1776 project_id: project_id.to_proto(),
1777 };
1778
1779 broadcast(
1780 Some(connection_id),
1781 guest_connection_ids.iter().copied(),
1782 |conn_id| session.peer.send(conn_id, message.clone()),
1783 );
1784 if let Some(room) = room {
1785 room_updated(room, &session.peer);
1786 }
1787
1788 *delete
1789 };
1790
1791 if delete {
1792 let db = session.db().await;
1793 db.delete_project(project_id).await?;
1794 }
1795
1796 Ok(())
1797}
1798
1799/// Join someone elses shared project.
1800async fn join_project(
1801 request: proto::JoinProject,
1802 response: Response<proto::JoinProject>,
1803 session: Session,
1804) -> Result<()> {
1805 let project_id = ProjectId::from_proto(request.project_id);
1806
1807 tracing::info!(%project_id, "join project");
1808
1809 let db = session.db().await;
1810 let (project, replica_id) = &mut *db
1811 .join_project(project_id, session.connection_id, session.user_id())
1812 .await?;
1813 drop(db);
1814 tracing::info!(%project_id, "join remote project");
1815 join_project_internal(response, session, project, replica_id)
1816}
1817
/// Reply channel used by `join_project_internal` to deliver the final
/// `JoinProjectResponse`. Abstracted as a trait, presumably so join logic can
/// be shared by more than one request type; only the `JoinProject` impl is
/// visible here.
trait JoinProjectInternalResponse {
    /// Sends the join result to the joining client.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Fully qualified so this presumably resolves to `Response`'s own
        // `send` rather than recursing into this trait method.
        Response::<proto::JoinProject>::send(self, result)
    }
}
1826
1827fn join_project_internal(
1828 response: impl JoinProjectInternalResponse,
1829 session: Session,
1830 project: &mut Project,
1831 replica_id: &ReplicaId,
1832) -> Result<()> {
1833 let collaborators = project
1834 .collaborators
1835 .iter()
1836 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1837 .map(|collaborator| collaborator.to_proto())
1838 .collect::<Vec<_>>();
1839 let project_id = project.id;
1840 let guest_user_id = session.user_id();
1841
1842 let worktrees = project
1843 .worktrees
1844 .iter()
1845 .map(|(id, worktree)| proto::WorktreeMetadata {
1846 id: *id,
1847 root_name: worktree.root_name.clone(),
1848 visible: worktree.visible,
1849 abs_path: worktree.abs_path.clone(),
1850 })
1851 .collect::<Vec<_>>();
1852
1853 let add_project_collaborator = proto::AddProjectCollaborator {
1854 project_id: project_id.to_proto(),
1855 collaborator: Some(proto::Collaborator {
1856 peer_id: Some(session.connection_id.into()),
1857 replica_id: replica_id.0 as u32,
1858 user_id: guest_user_id.to_proto(),
1859 is_host: false,
1860 }),
1861 };
1862
1863 for collaborator in &collaborators {
1864 session
1865 .peer
1866 .send(
1867 collaborator.peer_id.unwrap().into(),
1868 add_project_collaborator.clone(),
1869 )
1870 .trace_err();
1871 }
1872
1873 // First, we send the metadata associated with each worktree.
1874 response.send(proto::JoinProjectResponse {
1875 project_id: project.id.0 as u64,
1876 worktrees: worktrees.clone(),
1877 replica_id: replica_id.0 as u32,
1878 collaborators: collaborators.clone(),
1879 language_servers: project.language_servers.clone(),
1880 role: project.role.into(),
1881 })?;
1882
1883 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1884 // Stream this worktree's entries.
1885 let message = proto::UpdateWorktree {
1886 project_id: project_id.to_proto(),
1887 worktree_id,
1888 abs_path: worktree.abs_path.clone(),
1889 root_name: worktree.root_name,
1890 updated_entries: worktree.entries,
1891 removed_entries: Default::default(),
1892 scan_id: worktree.scan_id,
1893 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1894 updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1895 removed_repositories: Default::default(),
1896 };
1897 for update in proto::split_worktree_update(message) {
1898 session.peer.send(session.connection_id, update.clone())?;
1899 }
1900
1901 // Stream this worktree's diagnostics.
1902 for summary in worktree.diagnostic_summaries {
1903 session.peer.send(
1904 session.connection_id,
1905 proto::UpdateDiagnosticSummary {
1906 project_id: project_id.to_proto(),
1907 worktree_id: worktree.id,
1908 summary: Some(summary),
1909 },
1910 )?;
1911 }
1912
1913 for settings_file in worktree.settings_files {
1914 session.peer.send(
1915 session.connection_id,
1916 proto::UpdateWorktreeSettings {
1917 project_id: project_id.to_proto(),
1918 worktree_id: worktree.id,
1919 path: settings_file.path,
1920 content: Some(settings_file.content),
1921 kind: Some(settings_file.kind.to_proto() as i32),
1922 },
1923 )?;
1924 }
1925 }
1926
1927 for repository in mem::take(&mut project.repositories) {
1928 for update in split_repository_update(repository) {
1929 session.peer.send(session.connection_id, update)?;
1930 }
1931 }
1932
1933 for language_server in &project.language_servers {
1934 session.peer.send(
1935 session.connection_id,
1936 proto::UpdateLanguageServer {
1937 project_id: project_id.to_proto(),
1938 language_server_id: language_server.id,
1939 variant: Some(
1940 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1941 proto::LspDiskBasedDiagnosticsUpdated {},
1942 ),
1943 ),
1944 },
1945 )?;
1946 }
1947
1948 Ok(())
1949}
1950
1951/// Leave someone elses shared project.
1952async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1953 let sender_id = session.connection_id;
1954 let project_id = ProjectId::from_proto(request.project_id);
1955 let db = session.db().await;
1956
1957 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1958 tracing::info!(
1959 %project_id,
1960 "leave project"
1961 );
1962
1963 project_left(project, &session);
1964 if let Some(room) = room {
1965 room_updated(room, &session.peer);
1966 }
1967
1968 Ok(())
1969}
1970
1971/// Updates other participants with changes to the project
1972async fn update_project(
1973 request: proto::UpdateProject,
1974 response: Response<proto::UpdateProject>,
1975 session: Session,
1976) -> Result<()> {
1977 let project_id = ProjectId::from_proto(request.project_id);
1978 let (room, guest_connection_ids) = &*session
1979 .db()
1980 .await
1981 .update_project(project_id, session.connection_id, &request.worktrees)
1982 .await?;
1983 broadcast(
1984 Some(session.connection_id),
1985 guest_connection_ids.iter().copied(),
1986 |connection_id| {
1987 session
1988 .peer
1989 .forward_send(session.connection_id, connection_id, request.clone())
1990 },
1991 );
1992 if let Some(room) = room {
1993 room_updated(room, &session.peer);
1994 }
1995 response.send(proto::Ack {})?;
1996
1997 Ok(())
1998}
1999
2000/// Updates other participants with changes to the worktree
2001async fn update_worktree(
2002 request: proto::UpdateWorktree,
2003 response: Response<proto::UpdateWorktree>,
2004 session: Session,
2005) -> Result<()> {
2006 let guest_connection_ids = session
2007 .db()
2008 .await
2009 .update_worktree(&request, session.connection_id)
2010 .await?;
2011
2012 broadcast(
2013 Some(session.connection_id),
2014 guest_connection_ids.iter().copied(),
2015 |connection_id| {
2016 session
2017 .peer
2018 .forward_send(session.connection_id, connection_id, request.clone())
2019 },
2020 );
2021 response.send(proto::Ack {})?;
2022 Ok(())
2023}
2024
2025async fn update_repository(
2026 request: proto::UpdateRepository,
2027 response: Response<proto::UpdateRepository>,
2028 session: Session,
2029) -> Result<()> {
2030 let guest_connection_ids = session
2031 .db()
2032 .await
2033 .update_repository(&request, session.connection_id)
2034 .await?;
2035
2036 broadcast(
2037 Some(session.connection_id),
2038 guest_connection_ids.iter().copied(),
2039 |connection_id| {
2040 session
2041 .peer
2042 .forward_send(session.connection_id, connection_id, request.clone())
2043 },
2044 );
2045 response.send(proto::Ack {})?;
2046 Ok(())
2047}
2048
2049async fn remove_repository(
2050 request: proto::RemoveRepository,
2051 response: Response<proto::RemoveRepository>,
2052 session: Session,
2053) -> Result<()> {
2054 let guest_connection_ids = session
2055 .db()
2056 .await
2057 .remove_repository(&request, session.connection_id)
2058 .await?;
2059
2060 broadcast(
2061 Some(session.connection_id),
2062 guest_connection_ids.iter().copied(),
2063 |connection_id| {
2064 session
2065 .peer
2066 .forward_send(session.connection_id, connection_id, request.clone())
2067 },
2068 );
2069 response.send(proto::Ack {})?;
2070 Ok(())
2071}
2072
2073/// Updates other participants with changes to the diagnostics
2074async fn update_diagnostic_summary(
2075 message: proto::UpdateDiagnosticSummary,
2076 session: Session,
2077) -> Result<()> {
2078 let guest_connection_ids = session
2079 .db()
2080 .await
2081 .update_diagnostic_summary(&message, session.connection_id)
2082 .await?;
2083
2084 broadcast(
2085 Some(session.connection_id),
2086 guest_connection_ids.iter().copied(),
2087 |connection_id| {
2088 session
2089 .peer
2090 .forward_send(session.connection_id, connection_id, message.clone())
2091 },
2092 );
2093
2094 Ok(())
2095}
2096
2097/// Updates other participants with changes to the worktree settings
2098async fn update_worktree_settings(
2099 message: proto::UpdateWorktreeSettings,
2100 session: Session,
2101) -> Result<()> {
2102 let guest_connection_ids = session
2103 .db()
2104 .await
2105 .update_worktree_settings(&message, session.connection_id)
2106 .await?;
2107
2108 broadcast(
2109 Some(session.connection_id),
2110 guest_connection_ids.iter().copied(),
2111 |connection_id| {
2112 session
2113 .peer
2114 .forward_send(session.connection_id, connection_id, message.clone())
2115 },
2116 );
2117
2118 Ok(())
2119}
2120
2121/// Notify other participants that a language server has started.
2122async fn start_language_server(
2123 request: proto::StartLanguageServer,
2124 session: Session,
2125) -> Result<()> {
2126 let guest_connection_ids = session
2127 .db()
2128 .await
2129 .start_language_server(&request, session.connection_id)
2130 .await?;
2131
2132 broadcast(
2133 Some(session.connection_id),
2134 guest_connection_ids.iter().copied(),
2135 |connection_id| {
2136 session
2137 .peer
2138 .forward_send(session.connection_id, connection_id, request.clone())
2139 },
2140 );
2141 Ok(())
2142}
2143
2144/// Notify other participants that a language server has changed.
2145async fn update_language_server(
2146 request: proto::UpdateLanguageServer,
2147 session: Session,
2148) -> Result<()> {
2149 let project_id = ProjectId::from_proto(request.project_id);
2150 let project_connection_ids = session
2151 .db()
2152 .await
2153 .project_connection_ids(project_id, session.connection_id, true)
2154 .await?;
2155 broadcast(
2156 Some(session.connection_id),
2157 project_connection_ids.iter().copied(),
2158 |connection_id| {
2159 session
2160 .peer
2161 .forward_send(session.connection_id, connection_id, request.clone())
2162 },
2163 );
2164 Ok(())
2165}
2166
2167/// forward a project request to the host. These requests should be read only
2168/// as guests are allowed to send them.
2169async fn forward_read_only_project_request<T>(
2170 request: T,
2171 response: Response<T>,
2172 session: Session,
2173) -> Result<()>
2174where
2175 T: EntityMessage + RequestMessage,
2176{
2177 let project_id = ProjectId::from_proto(request.remote_entity_id());
2178 let host_connection_id = session
2179 .db()
2180 .await
2181 .host_for_read_only_project_request(project_id, session.connection_id)
2182 .await?;
2183 let payload = session
2184 .peer
2185 .forward_request(session.connection_id, host_connection_id, request)
2186 .await?;
2187 response.send(payload)?;
2188 Ok(())
2189}
2190
2191async fn forward_find_search_candidates_request(
2192 request: proto::FindSearchCandidates,
2193 response: Response<proto::FindSearchCandidates>,
2194 session: Session,
2195) -> Result<()> {
2196 let project_id = ProjectId::from_proto(request.remote_entity_id());
2197 let host_connection_id = session
2198 .db()
2199 .await
2200 .host_for_read_only_project_request(project_id, session.connection_id)
2201 .await?;
2202 let payload = session
2203 .peer
2204 .forward_request(session.connection_id, host_connection_id, request)
2205 .await?;
2206 response.send(payload)?;
2207 Ok(())
2208}
2209
2210/// forward a project request to the host. These requests are disallowed
2211/// for guests.
2212async fn forward_mutating_project_request<T>(
2213 request: T,
2214 response: Response<T>,
2215 session: Session,
2216) -> Result<()>
2217where
2218 T: EntityMessage + RequestMessage,
2219{
2220 let project_id = ProjectId::from_proto(request.remote_entity_id());
2221
2222 let host_connection_id = session
2223 .db()
2224 .await
2225 .host_for_mutating_project_request(project_id, session.connection_id)
2226 .await?;
2227 let payload = session
2228 .peer
2229 .forward_request(session.connection_id, host_connection_id, request)
2230 .await?;
2231 response.send(payload)?;
2232 Ok(())
2233}
2234
2235/// Notify other participants that a new buffer has been created
2236async fn create_buffer_for_peer(
2237 request: proto::CreateBufferForPeer,
2238 session: Session,
2239) -> Result<()> {
2240 session
2241 .db()
2242 .await
2243 .check_user_is_project_host(
2244 ProjectId::from_proto(request.project_id),
2245 session.connection_id,
2246 )
2247 .await?;
2248 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2249 session
2250 .peer
2251 .forward_send(session.connection_id, peer_id.into(), request)?;
2252 Ok(())
2253}
2254
2255/// Notify other participants that a buffer has been updated. This is
2256/// allowed for guests as long as the update is limited to selections.
2257async fn update_buffer(
2258 request: proto::UpdateBuffer,
2259 response: Response<proto::UpdateBuffer>,
2260 session: Session,
2261) -> Result<()> {
2262 let project_id = ProjectId::from_proto(request.project_id);
2263 let mut capability = Capability::ReadOnly;
2264
2265 for op in request.operations.iter() {
2266 match op.variant {
2267 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2268 Some(_) => capability = Capability::ReadWrite,
2269 }
2270 }
2271
2272 let host = {
2273 let guard = session
2274 .db()
2275 .await
2276 .connections_for_buffer_update(project_id, session.connection_id, capability)
2277 .await?;
2278
2279 let (host, guests) = &*guard;
2280
2281 broadcast(
2282 Some(session.connection_id),
2283 guests.clone(),
2284 |connection_id| {
2285 session
2286 .peer
2287 .forward_send(session.connection_id, connection_id, request.clone())
2288 },
2289 );
2290
2291 *host
2292 };
2293
2294 if host != session.connection_id {
2295 session
2296 .peer
2297 .forward_request(session.connection_id, host, request.clone())
2298 .await?;
2299 }
2300
2301 response.send(proto::Ack {})?;
2302 Ok(())
2303}
2304
2305async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2306 let project_id = ProjectId::from_proto(message.project_id);
2307
2308 let operation = message.operation.as_ref().context("invalid operation")?;
2309 let capability = match operation.variant.as_ref() {
2310 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2311 if let Some(buffer_op) = buffer_op.operation.as_ref() {
2312 match buffer_op.variant {
2313 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2314 Capability::ReadOnly
2315 }
2316 _ => Capability::ReadWrite,
2317 }
2318 } else {
2319 Capability::ReadWrite
2320 }
2321 }
2322 Some(_) => Capability::ReadWrite,
2323 None => Capability::ReadOnly,
2324 };
2325
2326 let guard = session
2327 .db()
2328 .await
2329 .connections_for_buffer_update(project_id, session.connection_id, capability)
2330 .await?;
2331
2332 let (host, guests) = &*guard;
2333
2334 broadcast(
2335 Some(session.connection_id),
2336 guests.iter().chain([host]).copied(),
2337 |connection_id| {
2338 session
2339 .peer
2340 .forward_send(session.connection_id, connection_id, message.clone())
2341 },
2342 );
2343
2344 Ok(())
2345}
2346
2347/// Notify other participants that a project has been updated.
2348async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2349 request: T,
2350 session: Session,
2351) -> Result<()> {
2352 let project_id = ProjectId::from_proto(request.remote_entity_id());
2353 let project_connection_ids = session
2354 .db()
2355 .await
2356 .project_connection_ids(project_id, session.connection_id, false)
2357 .await?;
2358
2359 broadcast(
2360 Some(session.connection_id),
2361 project_connection_ids.iter().copied(),
2362 |connection_id| {
2363 session
2364 .peer
2365 .forward_send(session.connection_id, connection_id, request.clone())
2366 },
2367 );
2368 Ok(())
2369}
2370
2371/// Start following another user in a call.
2372async fn follow(
2373 request: proto::Follow,
2374 response: Response<proto::Follow>,
2375 session: Session,
2376) -> Result<()> {
2377 let room_id = RoomId::from_proto(request.room_id);
2378 let project_id = request.project_id.map(ProjectId::from_proto);
2379 let leader_id = request
2380 .leader_id
2381 .ok_or_else(|| anyhow!("invalid leader id"))?
2382 .into();
2383 let follower_id = session.connection_id;
2384
2385 session
2386 .db()
2387 .await
2388 .check_room_participants(room_id, leader_id, session.connection_id)
2389 .await?;
2390
2391 let response_payload = session
2392 .peer
2393 .forward_request(session.connection_id, leader_id, request)
2394 .await?;
2395 response.send(response_payload)?;
2396
2397 if let Some(project_id) = project_id {
2398 let room = session
2399 .db()
2400 .await
2401 .follow(room_id, project_id, leader_id, follower_id)
2402 .await?;
2403 room_updated(&room, &session.peer);
2404 }
2405
2406 Ok(())
2407}
2408
2409/// Stop following another user in a call.
2410async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2411 let room_id = RoomId::from_proto(request.room_id);
2412 let project_id = request.project_id.map(ProjectId::from_proto);
2413 let leader_id = request
2414 .leader_id
2415 .ok_or_else(|| anyhow!("invalid leader id"))?
2416 .into();
2417 let follower_id = session.connection_id;
2418
2419 session
2420 .db()
2421 .await
2422 .check_room_participants(room_id, leader_id, session.connection_id)
2423 .await?;
2424
2425 session
2426 .peer
2427 .forward_send(session.connection_id, leader_id, request)?;
2428
2429 if let Some(project_id) = project_id {
2430 let room = session
2431 .db()
2432 .await
2433 .unfollow(room_id, project_id, leader_id, follower_id)
2434 .await?;
2435 room_updated(&room, &session.peer);
2436 }
2437
2438 Ok(())
2439}
2440
2441/// Notify everyone following you of your current location.
2442async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2443 let room_id = RoomId::from_proto(request.room_id);
2444 let database = session.db.lock().await;
2445
2446 let connection_ids = if let Some(project_id) = request.project_id {
2447 let project_id = ProjectId::from_proto(project_id);
2448 database
2449 .project_connection_ids(project_id, session.connection_id, true)
2450 .await?
2451 } else {
2452 database
2453 .room_connection_ids(room_id, session.connection_id)
2454 .await?
2455 };
2456
2457 // For now, don't send view update messages back to that view's current leader.
2458 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2459 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2460 _ => None,
2461 });
2462
2463 for connection_id in connection_ids.iter().cloned() {
2464 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2465 session
2466 .peer
2467 .forward_send(session.connection_id, connection_id, request.clone())?;
2468 }
2469 }
2470 Ok(())
2471}
2472
2473/// Get public data about users.
2474async fn get_users(
2475 request: proto::GetUsers,
2476 response: Response<proto::GetUsers>,
2477 session: Session,
2478) -> Result<()> {
2479 let user_ids = request
2480 .user_ids
2481 .into_iter()
2482 .map(UserId::from_proto)
2483 .collect();
2484 let users = session
2485 .db()
2486 .await
2487 .get_users_by_ids(user_ids)
2488 .await?
2489 .into_iter()
2490 .map(|user| proto::User {
2491 id: user.id.to_proto(),
2492 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2493 github_login: user.github_login,
2494 email: user.email_address,
2495 name: user.name,
2496 })
2497 .collect();
2498 response.send(proto::UsersResponse { users })?;
2499 Ok(())
2500}
2501
2502/// Search for users (to invite) buy Github login
2503async fn fuzzy_search_users(
2504 request: proto::FuzzySearchUsers,
2505 response: Response<proto::FuzzySearchUsers>,
2506 session: Session,
2507) -> Result<()> {
2508 let query = request.query;
2509 let users = match query.len() {
2510 0 => vec![],
2511 1 | 2 => session
2512 .db()
2513 .await
2514 .get_user_by_github_login(&query)
2515 .await?
2516 .into_iter()
2517 .collect(),
2518 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2519 };
2520 let users = users
2521 .into_iter()
2522 .filter(|user| user.id != session.user_id())
2523 .map(|user| proto::User {
2524 id: user.id.to_proto(),
2525 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2526 github_login: user.github_login,
2527 name: user.name,
2528 email: user.email_address,
2529 })
2530 .collect();
2531 response.send(proto::UsersResponse { users })?;
2532 Ok(())
2533}
2534
2535/// Send a contact request to another user.
2536async fn request_contact(
2537 request: proto::RequestContact,
2538 response: Response<proto::RequestContact>,
2539 session: Session,
2540) -> Result<()> {
2541 let requester_id = session.user_id();
2542 let responder_id = UserId::from_proto(request.responder_id);
2543 if requester_id == responder_id {
2544 return Err(anyhow!("cannot add yourself as a contact"))?;
2545 }
2546
2547 let notifications = session
2548 .db()
2549 .await
2550 .send_contact_request(requester_id, responder_id)
2551 .await?;
2552
2553 // Update outgoing contact requests of requester
2554 let mut update = proto::UpdateContacts::default();
2555 update.outgoing_requests.push(responder_id.to_proto());
2556 for connection_id in session
2557 .connection_pool()
2558 .await
2559 .user_connection_ids(requester_id)
2560 {
2561 session.peer.send(connection_id, update.clone())?;
2562 }
2563
2564 // Update incoming contact requests of responder
2565 let mut update = proto::UpdateContacts::default();
2566 update
2567 .incoming_requests
2568 .push(proto::IncomingContactRequest {
2569 requester_id: requester_id.to_proto(),
2570 });
2571 let connection_pool = session.connection_pool().await;
2572 for connection_id in connection_pool.user_connection_ids(responder_id) {
2573 session.peer.send(connection_id, update.clone())?;
2574 }
2575
2576 send_notifications(&connection_pool, &session.peer, notifications);
2577
2578 response.send(proto::Ack {})?;
2579 Ok(())
2580}
2581
2582/// Accept or decline a contact request
2583async fn respond_to_contact_request(
2584 request: proto::RespondToContactRequest,
2585 response: Response<proto::RespondToContactRequest>,
2586 session: Session,
2587) -> Result<()> {
2588 let responder_id = session.user_id();
2589 let requester_id = UserId::from_proto(request.requester_id);
2590 let db = session.db().await;
2591 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2592 db.dismiss_contact_notification(responder_id, requester_id)
2593 .await?;
2594 } else {
2595 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2596
2597 let notifications = db
2598 .respond_to_contact_request(responder_id, requester_id, accept)
2599 .await?;
2600 let requester_busy = db.is_user_busy(requester_id).await?;
2601 let responder_busy = db.is_user_busy(responder_id).await?;
2602
2603 let pool = session.connection_pool().await;
2604 // Update responder with new contact
2605 let mut update = proto::UpdateContacts::default();
2606 if accept {
2607 update
2608 .contacts
2609 .push(contact_for_user(requester_id, requester_busy, &pool));
2610 }
2611 update
2612 .remove_incoming_requests
2613 .push(requester_id.to_proto());
2614 for connection_id in pool.user_connection_ids(responder_id) {
2615 session.peer.send(connection_id, update.clone())?;
2616 }
2617
2618 // Update requester with new contact
2619 let mut update = proto::UpdateContacts::default();
2620 if accept {
2621 update
2622 .contacts
2623 .push(contact_for_user(responder_id, responder_busy, &pool));
2624 }
2625 update
2626 .remove_outgoing_requests
2627 .push(responder_id.to_proto());
2628
2629 for connection_id in pool.user_connection_ids(requester_id) {
2630 session.peer.send(connection_id, update.clone())?;
2631 }
2632
2633 send_notifications(&pool, &session.peer, notifications);
2634 }
2635
2636 response.send(proto::Ack {})?;
2637 Ok(())
2638}
2639
2640/// Remove a contact.
2641async fn remove_contact(
2642 request: proto::RemoveContact,
2643 response: Response<proto::RemoveContact>,
2644 session: Session,
2645) -> Result<()> {
2646 let requester_id = session.user_id();
2647 let responder_id = UserId::from_proto(request.user_id);
2648 let db = session.db().await;
2649 let (contact_accepted, deleted_notification_id) =
2650 db.remove_contact(requester_id, responder_id).await?;
2651
2652 let pool = session.connection_pool().await;
2653 // Update outgoing contact requests of requester
2654 let mut update = proto::UpdateContacts::default();
2655 if contact_accepted {
2656 update.remove_contacts.push(responder_id.to_proto());
2657 } else {
2658 update
2659 .remove_outgoing_requests
2660 .push(responder_id.to_proto());
2661 }
2662 for connection_id in pool.user_connection_ids(requester_id) {
2663 session.peer.send(connection_id, update.clone())?;
2664 }
2665
2666 // Update incoming contact requests of responder
2667 let mut update = proto::UpdateContacts::default();
2668 if contact_accepted {
2669 update.remove_contacts.push(requester_id.to_proto());
2670 } else {
2671 update
2672 .remove_incoming_requests
2673 .push(requester_id.to_proto());
2674 }
2675 for connection_id in pool.user_connection_ids(responder_id) {
2676 session.peer.send(connection_id, update.clone())?;
2677 if let Some(notification_id) = deleted_notification_id {
2678 session.peer.send(
2679 connection_id,
2680 proto::DeleteNotification {
2681 notification_id: notification_id.to_proto(),
2682 },
2683 )?;
2684 }
2685 }
2686
2687 response.send(proto::Ack {})?;
2688 Ok(())
2689}
2690
2691fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2692 version.0.minor() < 139
2693}
2694
2695async fn current_plan(db: &Arc<Database>, user_id: UserId, is_staff: bool) -> Result<proto::Plan> {
2696 if is_staff {
2697 return Ok(proto::Plan::ZedPro);
2698 }
2699
2700 let subscription = db.get_active_billing_subscription(user_id).await?;
2701 let subscription_kind = subscription.and_then(|subscription| subscription.kind);
2702
2703 let plan = if let Some(subscription_kind) = subscription_kind {
2704 match subscription_kind {
2705 SubscriptionKind::ZedPro => proto::Plan::ZedPro,
2706 SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
2707 SubscriptionKind::ZedFree => proto::Plan::Free,
2708 }
2709 } else {
2710 proto::Plan::Free
2711 };
2712
2713 Ok(plan)
2714}
2715
2716async fn make_update_user_plan_message(
2717 db: &Arc<Database>,
2718 llm_db: Option<Arc<LlmDatabase>>,
2719 user_id: UserId,
2720 is_staff: bool,
2721) -> Result<proto::UpdateUserPlan> {
2722 let feature_flags = db.get_user_flags(user_id).await?;
2723 let plan = current_plan(db, user_id, is_staff).await?;
2724 let billing_customer = db.get_billing_customer_by_user_id(user_id).await?;
2725 let billing_preferences = db.get_billing_preferences(user_id).await?;
2726
2727 let (subscription_period, usage) = if let Some(llm_db) = llm_db {
2728 let subscription = db.get_active_billing_subscription(user_id).await?;
2729
2730 let subscription_period =
2731 crate::db::billing_subscription::Model::current_period(subscription, is_staff);
2732
2733 let usage = if let Some((period_start_at, period_end_at)) = subscription_period {
2734 llm_db
2735 .get_subscription_usage_for_period(user_id, period_start_at, period_end_at)
2736 .await?
2737 } else {
2738 None
2739 };
2740
2741 (subscription_period, usage)
2742 } else {
2743 (None, None)
2744 };
2745
2746 Ok(proto::UpdateUserPlan {
2747 plan: plan.into(),
2748 trial_started_at: billing_customer
2749 .and_then(|billing_customer| billing_customer.trial_started_at)
2750 .map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64),
2751 is_usage_based_billing_enabled: if is_staff {
2752 Some(true)
2753 } else {
2754 billing_preferences.map(|preferences| preferences.model_request_overages_enabled)
2755 },
2756 subscription_period: subscription_period.map(|(started_at, ended_at)| {
2757 proto::SubscriptionPeriod {
2758 started_at: started_at.timestamp() as u64,
2759 ended_at: ended_at.timestamp() as u64,
2760 }
2761 }),
2762 usage: usage.map(|usage| {
2763 let plan = match plan {
2764 proto::Plan::Free => zed_llm_client::Plan::ZedFree,
2765 proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro,
2766 proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
2767 };
2768
2769 let model_requests_limit = match plan.model_requests_limit() {
2770 zed_llm_client::UsageLimit::Limited(limit) => {
2771 let limit = if plan == zed_llm_client::Plan::ZedProTrial
2772 && feature_flags
2773 .iter()
2774 .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
2775 {
2776 1_000
2777 } else {
2778 limit
2779 };
2780
2781 zed_llm_client::UsageLimit::Limited(limit)
2782 }
2783 zed_llm_client::UsageLimit::Unlimited => zed_llm_client::UsageLimit::Unlimited,
2784 };
2785
2786 proto::SubscriptionUsage {
2787 model_requests_usage_amount: usage.model_requests as u32,
2788 model_requests_usage_limit: Some(proto::UsageLimit {
2789 variant: Some(match model_requests_limit {
2790 zed_llm_client::UsageLimit::Limited(limit) => {
2791 proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2792 limit: limit as u32,
2793 })
2794 }
2795 zed_llm_client::UsageLimit::Unlimited => {
2796 proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2797 }
2798 }),
2799 }),
2800 edit_predictions_usage_amount: usage.edit_predictions as u32,
2801 edit_predictions_usage_limit: Some(proto::UsageLimit {
2802 variant: Some(match plan.edit_predictions_limit() {
2803 zed_llm_client::UsageLimit::Limited(limit) => {
2804 proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2805 limit: limit as u32,
2806 })
2807 }
2808 zed_llm_client::UsageLimit::Unlimited => {
2809 proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2810 }
2811 }),
2812 }),
2813 }
2814 }),
2815 })
2816}
2817
2818async fn update_user_plan(user_id: UserId, session: &Session) -> Result<()> {
2819 let db = session.db().await;
2820
2821 let update_user_plan = make_update_user_plan_message(
2822 &db.0,
2823 session.app_state.llm_db.clone(),
2824 user_id,
2825 session.is_staff(),
2826 )
2827 .await?;
2828
2829 session
2830 .peer
2831 .send(session.connection_id, update_user_plan)
2832 .trace_err();
2833
2834 Ok(())
2835}
2836
2837async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2838 subscribe_user_to_channels(session.user_id(), &session).await?;
2839 Ok(())
2840}
2841
2842async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2843 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2844 let mut pool = session.connection_pool().await;
2845 for membership in &channels_for_user.channel_memberships {
2846 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2847 }
2848 session.peer.send(
2849 session.connection_id,
2850 build_update_user_channels(&channels_for_user),
2851 )?;
2852 session.peer.send(
2853 session.connection_id,
2854 build_channels_update(channels_for_user),
2855 )?;
2856 Ok(())
2857}
2858
2859/// Creates a new channel.
2860async fn create_channel(
2861 request: proto::CreateChannel,
2862 response: Response<proto::CreateChannel>,
2863 session: Session,
2864) -> Result<()> {
2865 let db = session.db().await;
2866
2867 let parent_id = request.parent_id.map(ChannelId::from_proto);
2868 let (channel, membership) = db
2869 .create_channel(&request.name, parent_id, session.user_id())
2870 .await?;
2871
2872 let root_id = channel.root_id();
2873 let channel = Channel::from_model(channel);
2874
2875 response.send(proto::CreateChannelResponse {
2876 channel: Some(channel.to_proto()),
2877 parent_id: request.parent_id,
2878 })?;
2879
2880 let mut connection_pool = session.connection_pool().await;
2881 if let Some(membership) = membership {
2882 connection_pool.subscribe_to_channel(
2883 membership.user_id,
2884 membership.channel_id,
2885 membership.role,
2886 );
2887 let update = proto::UpdateUserChannels {
2888 channel_memberships: vec![proto::ChannelMembership {
2889 channel_id: membership.channel_id.to_proto(),
2890 role: membership.role.into(),
2891 }],
2892 ..Default::default()
2893 };
2894 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2895 session.peer.send(connection_id, update.clone())?;
2896 }
2897 }
2898
2899 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2900 if !role.can_see_channel(channel.visibility) {
2901 continue;
2902 }
2903
2904 let update = proto::UpdateChannels {
2905 channels: vec![channel.to_proto()],
2906 ..Default::default()
2907 };
2908 session.peer.send(connection_id, update.clone())?;
2909 }
2910
2911 Ok(())
2912}
2913
2914/// Delete a channel
2915async fn delete_channel(
2916 request: proto::DeleteChannel,
2917 response: Response<proto::DeleteChannel>,
2918 session: Session,
2919) -> Result<()> {
2920 let db = session.db().await;
2921
2922 let channel_id = request.channel_id;
2923 let (root_channel, removed_channels) = db
2924 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2925 .await?;
2926 response.send(proto::Ack {})?;
2927
2928 // Notify members of removed channels
2929 let mut update = proto::UpdateChannels::default();
2930 update
2931 .delete_channels
2932 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2933
2934 let connection_pool = session.connection_pool().await;
2935 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2936 session.peer.send(connection_id, update.clone())?;
2937 }
2938
2939 Ok(())
2940}
2941
2942/// Invite someone to join a channel.
2943async fn invite_channel_member(
2944 request: proto::InviteChannelMember,
2945 response: Response<proto::InviteChannelMember>,
2946 session: Session,
2947) -> Result<()> {
2948 let db = session.db().await;
2949 let channel_id = ChannelId::from_proto(request.channel_id);
2950 let invitee_id = UserId::from_proto(request.user_id);
2951 let InviteMemberResult {
2952 channel,
2953 notifications,
2954 } = db
2955 .invite_channel_member(
2956 channel_id,
2957 invitee_id,
2958 session.user_id(),
2959 request.role().into(),
2960 )
2961 .await?;
2962
2963 let update = proto::UpdateChannels {
2964 channel_invitations: vec![channel.to_proto()],
2965 ..Default::default()
2966 };
2967
2968 let connection_pool = session.connection_pool().await;
2969 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2970 session.peer.send(connection_id, update.clone())?;
2971 }
2972
2973 send_notifications(&connection_pool, &session.peer, notifications);
2974
2975 response.send(proto::Ack {})?;
2976 Ok(())
2977}
2978
2979/// remove someone from a channel
2980async fn remove_channel_member(
2981 request: proto::RemoveChannelMember,
2982 response: Response<proto::RemoveChannelMember>,
2983 session: Session,
2984) -> Result<()> {
2985 let db = session.db().await;
2986 let channel_id = ChannelId::from_proto(request.channel_id);
2987 let member_id = UserId::from_proto(request.user_id);
2988
2989 let RemoveChannelMemberResult {
2990 membership_update,
2991 notification_id,
2992 } = db
2993 .remove_channel_member(channel_id, member_id, session.user_id())
2994 .await?;
2995
2996 let mut connection_pool = session.connection_pool().await;
2997 notify_membership_updated(
2998 &mut connection_pool,
2999 membership_update,
3000 member_id,
3001 &session.peer,
3002 );
3003 for connection_id in connection_pool.user_connection_ids(member_id) {
3004 if let Some(notification_id) = notification_id {
3005 session
3006 .peer
3007 .send(
3008 connection_id,
3009 proto::DeleteNotification {
3010 notification_id: notification_id.to_proto(),
3011 },
3012 )
3013 .trace_err();
3014 }
3015 }
3016
3017 response.send(proto::Ack {})?;
3018 Ok(())
3019}
3020
3021/// Toggle the channel between public and private.
3022/// Care is taken to maintain the invariant that public channels only descend from public channels,
3023/// (though members-only channels can appear at any point in the hierarchy).
3024async fn set_channel_visibility(
3025 request: proto::SetChannelVisibility,
3026 response: Response<proto::SetChannelVisibility>,
3027 session: Session,
3028) -> Result<()> {
3029 let db = session.db().await;
3030 let channel_id = ChannelId::from_proto(request.channel_id);
3031 let visibility = request.visibility().into();
3032
3033 let channel_model = db
3034 .set_channel_visibility(channel_id, visibility, session.user_id())
3035 .await?;
3036 let root_id = channel_model.root_id();
3037 let channel = Channel::from_model(channel_model);
3038
3039 let mut connection_pool = session.connection_pool().await;
3040 for (user_id, role) in connection_pool
3041 .channel_user_ids(root_id)
3042 .collect::<Vec<_>>()
3043 .into_iter()
3044 {
3045 let update = if role.can_see_channel(channel.visibility) {
3046 connection_pool.subscribe_to_channel(user_id, channel_id, role);
3047 proto::UpdateChannels {
3048 channels: vec![channel.to_proto()],
3049 ..Default::default()
3050 }
3051 } else {
3052 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3053 proto::UpdateChannels {
3054 delete_channels: vec![channel.id.to_proto()],
3055 ..Default::default()
3056 }
3057 };
3058
3059 for connection_id in connection_pool.user_connection_ids(user_id) {
3060 session.peer.send(connection_id, update.clone())?;
3061 }
3062 }
3063
3064 response.send(proto::Ack {})?;
3065 Ok(())
3066}
3067
3068/// Alter the role for a user in the channel.
3069async fn set_channel_member_role(
3070 request: proto::SetChannelMemberRole,
3071 response: Response<proto::SetChannelMemberRole>,
3072 session: Session,
3073) -> Result<()> {
3074 let db = session.db().await;
3075 let channel_id = ChannelId::from_proto(request.channel_id);
3076 let member_id = UserId::from_proto(request.user_id);
3077 let result = db
3078 .set_channel_member_role(
3079 channel_id,
3080 session.user_id(),
3081 member_id,
3082 request.role().into(),
3083 )
3084 .await?;
3085
3086 match result {
3087 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3088 let mut connection_pool = session.connection_pool().await;
3089 notify_membership_updated(
3090 &mut connection_pool,
3091 membership_update,
3092 member_id,
3093 &session.peer,
3094 )
3095 }
3096 db::SetMemberRoleResult::InviteUpdated(channel) => {
3097 let update = proto::UpdateChannels {
3098 channel_invitations: vec![channel.to_proto()],
3099 ..Default::default()
3100 };
3101
3102 for connection_id in session
3103 .connection_pool()
3104 .await
3105 .user_connection_ids(member_id)
3106 {
3107 session.peer.send(connection_id, update.clone())?;
3108 }
3109 }
3110 }
3111
3112 response.send(proto::Ack {})?;
3113 Ok(())
3114}
3115
3116/// Change the name of a channel
3117async fn rename_channel(
3118 request: proto::RenameChannel,
3119 response: Response<proto::RenameChannel>,
3120 session: Session,
3121) -> Result<()> {
3122 let db = session.db().await;
3123 let channel_id = ChannelId::from_proto(request.channel_id);
3124 let channel_model = db
3125 .rename_channel(channel_id, session.user_id(), &request.name)
3126 .await?;
3127 let root_id = channel_model.root_id();
3128 let channel = Channel::from_model(channel_model);
3129
3130 response.send(proto::RenameChannelResponse {
3131 channel: Some(channel.to_proto()),
3132 })?;
3133
3134 let connection_pool = session.connection_pool().await;
3135 let update = proto::UpdateChannels {
3136 channels: vec![channel.to_proto()],
3137 ..Default::default()
3138 };
3139 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3140 if role.can_see_channel(channel.visibility) {
3141 session.peer.send(connection_id, update.clone())?;
3142 }
3143 }
3144
3145 Ok(())
3146}
3147
3148/// Move a channel to a new parent.
3149async fn move_channel(
3150 request: proto::MoveChannel,
3151 response: Response<proto::MoveChannel>,
3152 session: Session,
3153) -> Result<()> {
3154 let channel_id = ChannelId::from_proto(request.channel_id);
3155 let to = ChannelId::from_proto(request.to);
3156
3157 let (root_id, channels) = session
3158 .db()
3159 .await
3160 .move_channel(channel_id, to, session.user_id())
3161 .await?;
3162
3163 let connection_pool = session.connection_pool().await;
3164 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3165 let channels = channels
3166 .iter()
3167 .filter_map(|channel| {
3168 if role.can_see_channel(channel.visibility) {
3169 Some(channel.to_proto())
3170 } else {
3171 None
3172 }
3173 })
3174 .collect::<Vec<_>>();
3175 if channels.is_empty() {
3176 continue;
3177 }
3178
3179 let update = proto::UpdateChannels {
3180 channels,
3181 ..Default::default()
3182 };
3183
3184 session.peer.send(connection_id, update.clone())?;
3185 }
3186
3187 response.send(Ack {})?;
3188 Ok(())
3189}
3190
3191/// Get the list of channel members
3192async fn get_channel_members(
3193 request: proto::GetChannelMembers,
3194 response: Response<proto::GetChannelMembers>,
3195 session: Session,
3196) -> Result<()> {
3197 let db = session.db().await;
3198 let channel_id = ChannelId::from_proto(request.channel_id);
3199 let limit = if request.limit == 0 {
3200 u16::MAX as u64
3201 } else {
3202 request.limit
3203 };
3204 let (members, users) = db
3205 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3206 .await?;
3207 response.send(proto::GetChannelMembersResponse { members, users })?;
3208 Ok(())
3209}
3210
3211/// Accept or decline a channel invitation.
3212async fn respond_to_channel_invite(
3213 request: proto::RespondToChannelInvite,
3214 response: Response<proto::RespondToChannelInvite>,
3215 session: Session,
3216) -> Result<()> {
3217 let db = session.db().await;
3218 let channel_id = ChannelId::from_proto(request.channel_id);
3219 let RespondToChannelInvite {
3220 membership_update,
3221 notifications,
3222 } = db
3223 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3224 .await?;
3225
3226 let mut connection_pool = session.connection_pool().await;
3227 if let Some(membership_update) = membership_update {
3228 notify_membership_updated(
3229 &mut connection_pool,
3230 membership_update,
3231 session.user_id(),
3232 &session.peer,
3233 );
3234 } else {
3235 let update = proto::UpdateChannels {
3236 remove_channel_invitations: vec![channel_id.to_proto()],
3237 ..Default::default()
3238 };
3239
3240 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3241 session.peer.send(connection_id, update.clone())?;
3242 }
3243 };
3244
3245 send_notifications(&connection_pool, &session.peer, notifications);
3246
3247 response.send(proto::Ack {})?;
3248
3249 Ok(())
3250}
3251
3252/// Join the channels' room
3253async fn join_channel(
3254 request: proto::JoinChannel,
3255 response: Response<proto::JoinChannel>,
3256 session: Session,
3257) -> Result<()> {
3258 let channel_id = ChannelId::from_proto(request.channel_id);
3259 join_channel_internal(channel_id, Box::new(response), session).await
3260}
3261
/// A responder that can answer with a `proto::JoinRoomResponse`.
///
/// Both the `JoinChannel` and `JoinRoom` requests reply with `JoinRoomResponse`. This
/// trait lets `join_channel_internal` serve either one.
trait JoinChannelInternalResponse {
    /// Send `result` as the reply to the original request.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// Each impl forwards to `Response::send` for its concrete request type.
// NOTE(review): the fully-qualified path presumably makes it explicit that the inherent
// `Response::send` is meant, not this trait's method of the same name.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3275
/// Join the room attached to `channel_id` on behalf of the current session.
///
/// Shared by the `JoinChannel` and `JoinRoom` handlers. `response` is the responder of
/// whichever request arrived.
///
/// In order:
/// 1. If the database reports a stale room connection for this user, that connection
///    is removed from its room first (via `leave_room_for_session`).
/// 2. The database joins the user to the channel's room. It returns the joined room,
///    an optional membership update, and the user's role in the channel.
/// 3. If a LiveKit client is configured, a media token is minted. Guests get a guest
///    token with `can_publish: false`; every other role gets a room token with
///    `can_publish: true`. A token error is logged (`trace_err`) and leaves the
///    connection info as `None`.
/// 4. The requester receives the `JoinRoomResponse`. Any membership update is pushed
///    to the user's connections, and the room's participants get the updated room.
/// 5. Connections watching the channel get the new participant list, and the user's
///    accepted contacts get the user's refreshed contact entry.
///
/// # Errors
/// Fails if the database calls fail, if sending the response fails, or if the joined
/// room carries no channel ("channel not returned").
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // The `db` and `connection_pool` guards live only inside this block. NOTE(review):
    // presumably this ensures both are released before `channel_updated` and
    // `update_user_contacts` below acquire them again.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // `leave_room_for_session` opens its own database handle, so release ours
            // first and re-acquire it afterwards.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when LiveKit is not configured or minting the token failed.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3371
3372/// Start editing the channel notes
3373async fn join_channel_buffer(
3374 request: proto::JoinChannelBuffer,
3375 response: Response<proto::JoinChannelBuffer>,
3376 session: Session,
3377) -> Result<()> {
3378 let db = session.db().await;
3379 let channel_id = ChannelId::from_proto(request.channel_id);
3380
3381 let open_response = db
3382 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3383 .await?;
3384
3385 let collaborators = open_response.collaborators.clone();
3386 response.send(open_response)?;
3387
3388 let update = UpdateChannelBufferCollaborators {
3389 channel_id: channel_id.to_proto(),
3390 collaborators: collaborators.clone(),
3391 };
3392 channel_buffer_updated(
3393 session.connection_id,
3394 collaborators
3395 .iter()
3396 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3397 &update,
3398 &session.peer,
3399 );
3400
3401 Ok(())
3402}
3403
3404/// Edit the channel notes
3405async fn update_channel_buffer(
3406 request: proto::UpdateChannelBuffer,
3407 session: Session,
3408) -> Result<()> {
3409 let db = session.db().await;
3410 let channel_id = ChannelId::from_proto(request.channel_id);
3411
3412 let (collaborators, epoch, version) = db
3413 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3414 .await?;
3415
3416 channel_buffer_updated(
3417 session.connection_id,
3418 collaborators.clone(),
3419 &proto::UpdateChannelBuffer {
3420 channel_id: channel_id.to_proto(),
3421 operations: request.operations,
3422 },
3423 &session.peer,
3424 );
3425
3426 let pool = &*session.connection_pool().await;
3427
3428 let non_collaborators =
3429 pool.channel_connection_ids(channel_id)
3430 .filter_map(|(connection_id, _)| {
3431 if collaborators.contains(&connection_id) {
3432 None
3433 } else {
3434 Some(connection_id)
3435 }
3436 });
3437
3438 broadcast(None, non_collaborators, |peer_id| {
3439 session.peer.send(
3440 peer_id,
3441 proto::UpdateChannels {
3442 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3443 channel_id: channel_id.to_proto(),
3444 epoch: epoch as u64,
3445 version: version.clone(),
3446 }],
3447 ..Default::default()
3448 },
3449 )
3450 });
3451
3452 Ok(())
3453}
3454
3455/// Rejoin the channel notes after a connection blip
3456async fn rejoin_channel_buffers(
3457 request: proto::RejoinChannelBuffers,
3458 response: Response<proto::RejoinChannelBuffers>,
3459 session: Session,
3460) -> Result<()> {
3461 let db = session.db().await;
3462 let buffers = db
3463 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3464 .await?;
3465
3466 for rejoined_buffer in &buffers {
3467 let collaborators_to_notify = rejoined_buffer
3468 .buffer
3469 .collaborators
3470 .iter()
3471 .filter_map(|c| Some(c.peer_id?.into()));
3472 channel_buffer_updated(
3473 session.connection_id,
3474 collaborators_to_notify,
3475 &proto::UpdateChannelBufferCollaborators {
3476 channel_id: rejoined_buffer.buffer.channel_id,
3477 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3478 },
3479 &session.peer,
3480 );
3481 }
3482
3483 response.send(proto::RejoinChannelBuffersResponse {
3484 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3485 })?;
3486
3487 Ok(())
3488}
3489
3490/// Stop editing the channel notes
3491async fn leave_channel_buffer(
3492 request: proto::LeaveChannelBuffer,
3493 response: Response<proto::LeaveChannelBuffer>,
3494 session: Session,
3495) -> Result<()> {
3496 let db = session.db().await;
3497 let channel_id = ChannelId::from_proto(request.channel_id);
3498
3499 let left_buffer = db
3500 .leave_channel_buffer(channel_id, session.connection_id)
3501 .await?;
3502
3503 response.send(Ack {})?;
3504
3505 channel_buffer_updated(
3506 session.connection_id,
3507 left_buffer.connections,
3508 &proto::UpdateChannelBufferCollaborators {
3509 channel_id: channel_id.to_proto(),
3510 collaborators: left_buffer.collaborators,
3511 },
3512 &session.peer,
3513 );
3514
3515 Ok(())
3516}
3517
3518fn channel_buffer_updated<T: EnvelopedMessage>(
3519 sender_id: ConnectionId,
3520 collaborators: impl IntoIterator<Item = ConnectionId>,
3521 message: &T,
3522 peer: &Peer,
3523) {
3524 broadcast(Some(sender_id), collaborators, |peer_id| {
3525 peer.send(peer_id, message.clone())
3526 });
3527}
3528
3529fn send_notifications(
3530 connection_pool: &ConnectionPool,
3531 peer: &Peer,
3532 notifications: db::NotificationBatch,
3533) {
3534 for (user_id, notification) in notifications {
3535 for connection_id in connection_pool.user_connection_ids(user_id) {
3536 if let Err(error) = peer.send(
3537 connection_id,
3538 proto::AddNotification {
3539 notification: Some(notification.clone()),
3540 },
3541 ) {
3542 tracing::error!(
3543 "failed to send notification to {:?} {}",
3544 connection_id,
3545 error
3546 );
3547 }
3548 }
3549 }
3550}
3551
/// Send a message to the channel
///
/// The body is trimmed and must be non-blank and at most `MAX_MESSAGE_LEN` bytes. A
/// nonce is required. After the database stores the message:
/// * the connections reported in `participant_connection_ids` receive
///   `ChannelMessageSent`;
/// * the requester receives the stored message (with its new id) as the response;
/// * other connections subscribed to the channel only receive the latest message id
///   via `UpdateChannels`;
/// * notifications returned by the database are delivered.
async fn send_channel_message(
    request: proto::SendChannelMessage,
    response: Response<proto::SendChannelMessage>,
    session: Session,
) -> Result<()> {
    // Validate the message body.
    // Note: the length limit is measured in bytes (`str::len`), after trimming.
    let body = request.body.trim().to_string();
    if body.len() > MAX_MESSAGE_LEN {
        return Err(anyhow!("message is too long"))?;
    }
    if body.is_empty() {
        return Err(anyhow!("message can't be blank"))?;
    }

    // TODO: adjust mentions if body is trimmed
    // NOTE(review): `request.mentions` is passed through unchanged below. If mention
    // ranges are offsets into the untrimmed body, trimming leading whitespace shifts
    // them — confirm.

    let timestamp = OffsetDateTime::now_utc();
    let nonce = request
        .nonce
        .ok_or_else(|| anyhow!("nonce can't be blank"))?;

    let channel_id = ChannelId::from_proto(request.channel_id);
    let CreatedChannelMessage {
        message_id,
        participant_connection_ids,
        notifications,
    } = session
        .db()
        .await
        .create_channel_message(
            channel_id,
            session.user_id(),
            &body,
            &request.mentions,
            timestamp,
            nonce.clone().into(),
            request.reply_to_message_id.map(MessageId::from_proto),
        )
        .await?;

    let message = proto::ChannelMessage {
        sender_id: session.user_id().to_proto(),
        id: message_id.to_proto(),
        body,
        mentions: request.mentions,
        timestamp: timestamp.unix_timestamp() as u64,
        nonce: Some(nonce),
        reply_to_message_id: request.reply_to_message_id,
        edited_at: None,
    };
    // The requester's connection is passed to `broadcast` as the sender. NOTE(review):
    // presumably so it is skipped, since it gets the message via the response below.
    broadcast(
        Some(session.connection_id),
        participant_connection_ids.clone(),
        |connection| {
            session.peer.send(
                connection,
                proto::ChannelMessageSent {
                    channel_id: channel_id.to_proto(),
                    message: Some(message.clone()),
                },
            )
        },
    );
    response.send(proto::SendChannelMessageResponse {
        message: Some(message),
    })?;

    // Channel subscribers that are not in `participant_connection_ids` only learn
    // the id of the newest message.
    let pool = &*session.connection_pool().await;
    let non_participants =
        pool.channel_connection_ids(channel_id)
            .filter_map(|(connection_id, _)| {
                if participant_connection_ids.contains(&connection_id) {
                    None
                } else {
                    Some(connection_id)
                }
            });
    broadcast(None, non_participants, |peer_id| {
        session.peer.send(
            peer_id,
            proto::UpdateChannels {
                latest_channel_message_ids: vec![proto::ChannelMessageId {
                    channel_id: channel_id.to_proto(),
                    message_id: message_id.to_proto(),
                }],
                ..Default::default()
            },
        )
    });
    send_notifications(pool, &session.peer, notifications);

    Ok(())
}
3646
3647/// Delete a channel message
3648async fn remove_channel_message(
3649 request: proto::RemoveChannelMessage,
3650 response: Response<proto::RemoveChannelMessage>,
3651 session: Session,
3652) -> Result<()> {
3653 let channel_id = ChannelId::from_proto(request.channel_id);
3654 let message_id = MessageId::from_proto(request.message_id);
3655 let (connection_ids, existing_notification_ids) = session
3656 .db()
3657 .await
3658 .remove_channel_message(channel_id, message_id, session.user_id())
3659 .await?;
3660
3661 broadcast(
3662 Some(session.connection_id),
3663 connection_ids,
3664 move |connection| {
3665 session.peer.send(connection, request.clone())?;
3666
3667 for notification_id in &existing_notification_ids {
3668 session.peer.send(
3669 connection,
3670 proto::DeleteNotification {
3671 notification_id: (*notification_id).to_proto(),
3672 },
3673 )?;
3674 }
3675
3676 Ok(())
3677 },
3678 );
3679 response.send(proto::Ack {})?;
3680 Ok(())
3681}
3682
3683async fn update_channel_message(
3684 request: proto::UpdateChannelMessage,
3685 response: Response<proto::UpdateChannelMessage>,
3686 session: Session,
3687) -> Result<()> {
3688 let channel_id = ChannelId::from_proto(request.channel_id);
3689 let message_id = MessageId::from_proto(request.message_id);
3690 let updated_at = OffsetDateTime::now_utc();
3691 let UpdatedChannelMessage {
3692 message_id,
3693 participant_connection_ids,
3694 notifications,
3695 reply_to_message_id,
3696 timestamp,
3697 deleted_mention_notification_ids,
3698 updated_mention_notifications,
3699 } = session
3700 .db()
3701 .await
3702 .update_channel_message(
3703 channel_id,
3704 message_id,
3705 session.user_id(),
3706 request.body.as_str(),
3707 &request.mentions,
3708 updated_at,
3709 )
3710 .await?;
3711
3712 let nonce = request
3713 .nonce
3714 .clone()
3715 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3716
3717 let message = proto::ChannelMessage {
3718 sender_id: session.user_id().to_proto(),
3719 id: message_id.to_proto(),
3720 body: request.body.clone(),
3721 mentions: request.mentions.clone(),
3722 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3723 nonce: Some(nonce),
3724 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3725 edited_at: Some(updated_at.unix_timestamp() as u64),
3726 };
3727
3728 response.send(proto::Ack {})?;
3729
3730 let pool = &*session.connection_pool().await;
3731 broadcast(
3732 Some(session.connection_id),
3733 participant_connection_ids,
3734 |connection| {
3735 session.peer.send(
3736 connection,
3737 proto::ChannelMessageUpdate {
3738 channel_id: channel_id.to_proto(),
3739 message: Some(message.clone()),
3740 },
3741 )?;
3742
3743 for notification_id in &deleted_mention_notification_ids {
3744 session.peer.send(
3745 connection,
3746 proto::DeleteNotification {
3747 notification_id: (*notification_id).to_proto(),
3748 },
3749 )?;
3750 }
3751
3752 for notification in &updated_mention_notifications {
3753 session.peer.send(
3754 connection,
3755 proto::UpdateNotification {
3756 notification: Some(notification.clone()),
3757 },
3758 )?;
3759 }
3760
3761 Ok(())
3762 },
3763 );
3764
3765 send_notifications(pool, &session.peer, notifications);
3766
3767 Ok(())
3768}
3769
3770/// Mark a channel message as read
3771async fn acknowledge_channel_message(
3772 request: proto::AckChannelMessage,
3773 session: Session,
3774) -> Result<()> {
3775 let channel_id = ChannelId::from_proto(request.channel_id);
3776 let message_id = MessageId::from_proto(request.message_id);
3777 let notifications = session
3778 .db()
3779 .await
3780 .observe_channel_message(channel_id, session.user_id(), message_id)
3781 .await?;
3782 send_notifications(
3783 &*session.connection_pool().await,
3784 &session.peer,
3785 notifications,
3786 );
3787 Ok(())
3788}
3789
3790/// Mark a buffer version as synced
3791async fn acknowledge_buffer_version(
3792 request: proto::AckBufferOperation,
3793 session: Session,
3794) -> Result<()> {
3795 let buffer_id = BufferId::from_proto(request.buffer_id);
3796 session
3797 .db()
3798 .await
3799 .observe_buffer_version(
3800 buffer_id,
3801 session.user_id(),
3802 request.epoch as i32,
3803 &request.version,
3804 )
3805 .await?;
3806 Ok(())
3807}
3808
3809/// Get a Supermaven API key for the user
3810async fn get_supermaven_api_key(
3811 _request: proto::GetSupermavenApiKey,
3812 response: Response<proto::GetSupermavenApiKey>,
3813 session: Session,
3814) -> Result<()> {
3815 let user_id: String = session.user_id().to_string();
3816 if !session.is_staff() {
3817 return Err(anyhow!("supermaven not enabled for this account"))?;
3818 }
3819
3820 let email = session
3821 .email()
3822 .ok_or_else(|| anyhow!("user must have an email"))?;
3823
3824 let supermaven_admin_api = session
3825 .supermaven_client
3826 .as_ref()
3827 .ok_or_else(|| anyhow!("supermaven not configured"))?;
3828
3829 let result = supermaven_admin_api
3830 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3831 .await?;
3832
3833 response.send(proto::GetSupermavenApiKeyResponse {
3834 api_key: result.api_key,
3835 })?;
3836
3837 Ok(())
3838}
3839
3840/// Start receiving chat updates for a channel
3841async fn join_channel_chat(
3842 request: proto::JoinChannelChat,
3843 response: Response<proto::JoinChannelChat>,
3844 session: Session,
3845) -> Result<()> {
3846 let channel_id = ChannelId::from_proto(request.channel_id);
3847
3848 let db = session.db().await;
3849 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3850 .await?;
3851 let messages = db
3852 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3853 .await?;
3854 response.send(proto::JoinChannelChatResponse {
3855 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3856 messages,
3857 })?;
3858 Ok(())
3859}
3860
3861/// Stop receiving chat updates for a channel
3862async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3863 let channel_id = ChannelId::from_proto(request.channel_id);
3864 session
3865 .db()
3866 .await
3867 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3868 .await?;
3869 Ok(())
3870}
3871
3872/// Retrieve the chat history for a channel
3873async fn get_channel_messages(
3874 request: proto::GetChannelMessages,
3875 response: Response<proto::GetChannelMessages>,
3876 session: Session,
3877) -> Result<()> {
3878 let channel_id = ChannelId::from_proto(request.channel_id);
3879 let messages = session
3880 .db()
3881 .await
3882 .get_channel_messages(
3883 channel_id,
3884 session.user_id(),
3885 MESSAGE_COUNT_PER_PAGE,
3886 Some(MessageId::from_proto(request.before_message_id)),
3887 )
3888 .await?;
3889 response.send(proto::GetChannelMessagesResponse {
3890 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3891 messages,
3892 })?;
3893 Ok(())
3894}
3895
3896/// Retrieve specific chat messages
3897async fn get_channel_messages_by_id(
3898 request: proto::GetChannelMessagesById,
3899 response: Response<proto::GetChannelMessagesById>,
3900 session: Session,
3901) -> Result<()> {
3902 let message_ids = request
3903 .message_ids
3904 .iter()
3905 .map(|id| MessageId::from_proto(*id))
3906 .collect::<Vec<_>>();
3907 let messages = session
3908 .db()
3909 .await
3910 .get_channel_messages_by_id(session.user_id(), &message_ids)
3911 .await?;
3912 response.send(proto::GetChannelMessagesResponse {
3913 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3914 messages,
3915 })?;
3916 Ok(())
3917}
3918
3919/// Retrieve the current users notifications
3920async fn get_notifications(
3921 request: proto::GetNotifications,
3922 response: Response<proto::GetNotifications>,
3923 session: Session,
3924) -> Result<()> {
3925 let notifications = session
3926 .db()
3927 .await
3928 .get_notifications(
3929 session.user_id(),
3930 NOTIFICATION_COUNT_PER_PAGE,
3931 request.before_id.map(db::NotificationId::from_proto),
3932 )
3933 .await?;
3934 response.send(proto::GetNotificationsResponse {
3935 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3936 notifications,
3937 })?;
3938 Ok(())
3939}
3940
3941/// Mark notifications as read
3942async fn mark_notification_as_read(
3943 request: proto::MarkNotificationRead,
3944 response: Response<proto::MarkNotificationRead>,
3945 session: Session,
3946) -> Result<()> {
3947 let database = &session.db().await;
3948 let notifications = database
3949 .mark_notification_as_read_by_id(
3950 session.user_id(),
3951 NotificationId::from_proto(request.notification_id),
3952 )
3953 .await?;
3954 send_notifications(
3955 &*session.connection_pool().await,
3956 &session.peer,
3957 notifications,
3958 );
3959 response.send(proto::Ack {})?;
3960 Ok(())
3961}
3962
3963/// Get the current users information
3964async fn get_private_user_info(
3965 _request: proto::GetPrivateUserInfo,
3966 response: Response<proto::GetPrivateUserInfo>,
3967 session: Session,
3968) -> Result<()> {
3969 let db = session.db().await;
3970
3971 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
3972 let user = db
3973 .get_user_by_id(session.user_id())
3974 .await?
3975 .ok_or_else(|| anyhow!("user not found"))?;
3976 let flags = db.get_user_flags(session.user_id()).await?;
3977
3978 response.send(proto::GetPrivateUserInfoResponse {
3979 metrics_id,
3980 staff: user.admin,
3981 flags,
3982 accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
3983 })?;
3984 Ok(())
3985}
3986
3987/// Accept the terms of service (tos) on behalf of the current user
3988async fn accept_terms_of_service(
3989 _request: proto::AcceptTermsOfService,
3990 response: Response<proto::AcceptTermsOfService>,
3991 session: Session,
3992) -> Result<()> {
3993 let db = session.db().await;
3994
3995 let accepted_tos_at = Utc::now();
3996 db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
3997 .await?;
3998
3999 response.send(proto::AcceptTermsOfServiceResponse {
4000 accepted_tos_at: accepted_tos_at.timestamp() as u64,
4001 })?;
4002 Ok(())
4003}
4004
/// The minimum account age an account must have in order to use the LLM service.
///
/// Currently 30 days. NOTE(review): this constant is not referenced in this part of
/// the file. The age check presumably happens elsewhere — confirm where it is enforced.
pub const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4007
4008async fn get_llm_api_token(
4009 _request: proto::GetLlmToken,
4010 response: Response<proto::GetLlmToken>,
4011 session: Session,
4012) -> Result<()> {
4013 let db = session.db().await;
4014
4015 let flags = db.get_user_flags(session.user_id()).await?;
4016
4017 let user_id = session.user_id();
4018 let user = db
4019 .get_user_by_id(user_id)
4020 .await?
4021 .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4022
4023 if user.accepted_tos_at.is_none() {
4024 Err(anyhow!("terms of service not accepted"))?
4025 }
4026
4027 let billing_subscription = db.get_active_billing_subscription(user.id).await?;
4028 let billing_preferences = db.get_billing_preferences(user.id).await?;
4029
4030 let token = LlmTokenClaims::create(
4031 &user,
4032 session.is_staff(),
4033 billing_preferences,
4034 &flags,
4035 billing_subscription,
4036 session.system_id.clone(),
4037 &session.app_state.config,
4038 )?;
4039 response.send(proto::GetLlmTokenResponse { token })?;
4040 Ok(())
4041}
4042
4043fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4044 let message = match message {
4045 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
4046 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
4047 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
4048 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
4049 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4050 code: frame.code.into(),
4051 reason: frame.reason.as_str().to_owned().into(),
4052 })),
4053 // We should never receive a frame while reading the message, according
4054 // to the `tungstenite` maintainers:
4055 //
4056 // > It cannot occur when you read messages from the WebSocket, but it
4057 // > can be used when you want to send the raw frames (e.g. you want to
4058 // > send the frames to the WebSocket without composing the full message first).
4059 // >
4060 // > — https://github.com/snapview/tungstenite-rs/issues/268
4061 TungsteniteMessage::Frame(_) => {
4062 bail!("received an unexpected frame while reading the message")
4063 }
4064 };
4065
4066 Ok(message)
4067}
4068
4069fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4070 match message {
4071 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
4072 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
4073 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
4074 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
4075 AxumMessage::Close(frame) => {
4076 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4077 code: frame.code.into(),
4078 reason: frame.reason.as_ref().into(),
4079 }))
4080 }
4081 }
4082}
4083
4084fn notify_membership_updated(
4085 connection_pool: &mut ConnectionPool,
4086 result: MembershipUpdated,
4087 user_id: UserId,
4088 peer: &Peer,
4089) {
4090 for membership in &result.new_channels.channel_memberships {
4091 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4092 }
4093 for channel_id in &result.removed_channels {
4094 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4095 }
4096
4097 let user_channels_update = proto::UpdateUserChannels {
4098 channel_memberships: result
4099 .new_channels
4100 .channel_memberships
4101 .iter()
4102 .map(|cm| proto::ChannelMembership {
4103 channel_id: cm.channel_id.to_proto(),
4104 role: cm.role.into(),
4105 })
4106 .collect(),
4107 ..Default::default()
4108 };
4109
4110 let mut update = build_channels_update(result.new_channels);
4111 update.delete_channels = result
4112 .removed_channels
4113 .into_iter()
4114 .map(|id| id.to_proto())
4115 .collect();
4116 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4117
4118 for connection_id in connection_pool.user_connection_ids(user_id) {
4119 peer.send(connection_id, user_channels_update.clone())
4120 .trace_err();
4121 peer.send(connection_id, update.clone()).trace_err();
4122 }
4123}
4124
4125fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4126 proto::UpdateUserChannels {
4127 channel_memberships: channels
4128 .channel_memberships
4129 .iter()
4130 .map(|m| proto::ChannelMembership {
4131 channel_id: m.channel_id.to_proto(),
4132 role: m.role.into(),
4133 })
4134 .collect(),
4135 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4136 observed_channel_message_id: channels.observed_channel_messages.clone(),
4137 }
4138}
4139
4140fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4141 let mut update = proto::UpdateChannels::default();
4142
4143 for channel in channels.channels {
4144 update.channels.push(channel.to_proto());
4145 }
4146
4147 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4148 update.latest_channel_message_ids = channels.latest_channel_messages;
4149
4150 for (channel_id, participants) in channels.channel_participants {
4151 update
4152 .channel_participants
4153 .push(proto::ChannelParticipants {
4154 channel_id: channel_id.to_proto(),
4155 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4156 });
4157 }
4158
4159 for channel in channels.invited_channels {
4160 update.channel_invitations.push(channel.to_proto());
4161 }
4162
4163 update
4164}
4165
4166fn build_initial_contacts_update(
4167 contacts: Vec<db::Contact>,
4168 pool: &ConnectionPool,
4169) -> proto::UpdateContacts {
4170 let mut update = proto::UpdateContacts::default();
4171
4172 for contact in contacts {
4173 match contact {
4174 db::Contact::Accepted { user_id, busy } => {
4175 update.contacts.push(contact_for_user(user_id, busy, pool));
4176 }
4177 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4178 db::Contact::Incoming { user_id } => {
4179 update
4180 .incoming_requests
4181 .push(proto::IncomingContactRequest {
4182 requester_id: user_id.to_proto(),
4183 })
4184 }
4185 }
4186 }
4187
4188 update
4189}
4190
4191fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4192 proto::Contact {
4193 user_id: user_id.to_proto(),
4194 online: pool.is_user_online(user_id),
4195 busy,
4196 }
4197}
4198
4199fn room_updated(room: &proto::Room, peer: &Peer) {
4200 broadcast(
4201 None,
4202 room.participants
4203 .iter()
4204 .filter_map(|participant| Some(participant.peer_id?.into())),
4205 |peer_id| {
4206 peer.send(
4207 peer_id,
4208 proto::RoomUpdated {
4209 room: Some(room.clone()),
4210 },
4211 )
4212 },
4213 );
4214}
4215
4216fn channel_updated(
4217 channel: &db::channel::Model,
4218 room: &proto::Room,
4219 peer: &Peer,
4220 pool: &ConnectionPool,
4221) {
4222 let participants = room
4223 .participants
4224 .iter()
4225 .map(|p| p.user_id)
4226 .collect::<Vec<_>>();
4227
4228 broadcast(
4229 None,
4230 pool.channel_connection_ids(channel.root_id())
4231 .filter_map(|(channel_id, role)| {
4232 role.can_see_channel(channel.visibility)
4233 .then_some(channel_id)
4234 }),
4235 |peer_id| {
4236 peer.send(
4237 peer_id,
4238 proto::UpdateChannels {
4239 channel_participants: vec![proto::ChannelParticipants {
4240 channel_id: channel.id.to_proto(),
4241 participant_user_ids: participants.clone(),
4242 }],
4243 ..Default::default()
4244 },
4245 )
4246 },
4247 );
4248}
4249
4250async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4251 let db = session.db().await;
4252
4253 let contacts = db.get_contacts(user_id).await?;
4254 let busy = db.is_user_busy(user_id).await?;
4255
4256 let pool = session.connection_pool().await;
4257 let updated_contact = contact_for_user(user_id, busy, &pool);
4258 for contact in contacts {
4259 if let db::Contact::Accepted {
4260 user_id: contact_user_id,
4261 ..
4262 } = contact
4263 {
4264 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4265 session
4266 .peer
4267 .send(
4268 contact_conn_id,
4269 proto::UpdateContacts {
4270 contacts: vec![updated_contact.clone()],
4271 remove_contacts: Default::default(),
4272 incoming_requests: Default::default(),
4273 remove_incoming_requests: Default::default(),
4274 outgoing_requests: Default::default(),
4275 remove_outgoing_requests: Default::default(),
4276 },
4277 )
4278 .trace_err();
4279 }
4280 }
4281 }
4282 Ok(())
4283}
4284
/// Removes `connection_id` from the room it is currently in (if any) and
/// propagates the change to everyone affected.
///
/// In order, this:
/// 1. asks the database to leave the room, returning `Ok(())` right away when
///    `leave_room` reports no room for this connection;
/// 2. notifies the connections of each project left along with the room
///    (via `project_left`) and broadcasts the updated room to its participants;
/// 3. if the room belongs to a channel, broadcasts the channel's new
///    participant list (via `channel_updated`);
/// 4. sends `CallCanceled` to every connection of each user whose pending call
///    was canceled by the database;
/// 5. refreshes contact status for the leaving user and those canceled callees;
/// 6. if a LiveKit client is configured, removes the user from the LiveKit
///    room, and also deletes that LiveKit room when `left_room.deleted` is set.
///
/// # Errors
///
/// Propagates errors from `leave_room` and `update_user_contacts`. Failures
/// while sending messages or talking to LiveKit are only logged (`trace_err`).
/// NOTE(review): an `update_user_contacts` error returns early via `?` and so
/// skips the LiveKit cleanup at the end — confirm this is acceptable.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Declared up front and assigned inside the `if let` below so these values
    // remain usable after it.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // NOTE(review): the LiveKit room name is taken out of `left_room.room`
        // before `room` is moved out, so the `room` broadcast below carries a
        // default (presumably empty) `livekit_room` — confirm this is intended.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        // The connection was not in a room; nothing to clean up.
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so this connection-pool guard is released before
    // `update_user_contacts`, which acquires the connection pool itself.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4358
4359async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4360 let left_channel_buffers = session
4361 .db()
4362 .await
4363 .leave_channel_buffers(session.connection_id)
4364 .await?;
4365
4366 for left_buffer in left_channel_buffers {
4367 channel_buffer_updated(
4368 session.connection_id,
4369 left_buffer.connections,
4370 &proto::UpdateChannelBufferCollaborators {
4371 channel_id: left_buffer.channel_id.to_proto(),
4372 collaborators: left_buffer.collaborators,
4373 },
4374 &session.peer,
4375 );
4376 }
4377
4378 Ok(())
4379}
4380
4381fn project_left(project: &db::LeftProject, session: &Session) {
4382 for connection_id in &project.connection_ids {
4383 if project.should_unshare {
4384 session
4385 .peer
4386 .send(
4387 *connection_id,
4388 proto::UnshareProject {
4389 project_id: project.id.to_proto(),
4390 },
4391 )
4392 .trace_err();
4393 } else {
4394 session
4395 .peer
4396 .send(
4397 *connection_id,
4398 proto::RemoveProjectCollaborator {
4399 project_id: project.id.to_proto(),
4400 peer_id: Some(session.connection_id.into()),
4401 },
4402 )
4403 .trace_err();
4404 }
4405 }
4406}
4407
4408pub trait ResultExt {
4409 type Ok;
4410
4411 fn trace_err(self) -> Option<Self::Ok>;
4412}
4413
4414impl<T, E> ResultExt for Result<T, E>
4415where
4416 E: std::fmt::Debug,
4417{
4418 type Ok = T;
4419
4420 #[track_caller]
4421 fn trace_err(self) -> Option<T> {
4422 match self {
4423 Ok(value) => Some(value),
4424 Err(error) => {
4425 tracing::error!("{:?}", error);
4426 None
4427 }
4428 }
4429 }
4430}