1mod connection_pool;
2
3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
4use crate::llm::LlmTokenClaims;
5use crate::{
6 AppState, Config, Error, RateLimit, Result, auth,
7 db::{
8 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
9 CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
10 NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
11 RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
12 },
13 executor::Executor,
14};
15use anyhow::{Context as _, anyhow, bail};
16use async_tungstenite::tungstenite::{
17 Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
18};
19use axum::{
20 Extension, Router, TypedHeader,
21 body::Body,
22 extract::{
23 ConnectInfo, WebSocketUpgrade,
24 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
25 },
26 headers::{Header, HeaderName},
27 http::StatusCode,
28 middleware,
29 response::IntoResponse,
30 routing::get,
31};
32use chrono::Utc;
33use collections::{HashMap, HashSet};
34pub use connection_pool::{ConnectionPool, ZedVersion};
35use core::fmt::{self, Debug, Formatter};
36use http_client::HttpClient;
37use reqwest_client::ReqwestClient;
38use rpc::proto::split_repository_update;
39use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
40
41use futures::{
42 FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
43 stream::FuturesUnordered,
44};
45use prometheus::{IntGauge, register_int_gauge};
46use rpc::{
47 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
48 proto::{
49 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
50 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
51 },
52};
53use semantic_version::SemanticVersion;
54use serde::{Serialize, Serializer};
55use std::{
56 any::TypeId,
57 future::Future,
58 marker::PhantomData,
59 mem,
60 net::SocketAddr,
61 ops::{Deref, DerefMut},
62 rc::Rc,
63 sync::{
64 Arc, OnceLock,
65 atomic::{AtomicBool, Ordering::SeqCst},
66 },
67 time::{Duration, Instant},
68};
69use time::OffsetDateTime;
70use tokio::sync::{MutexGuard, Semaphore, watch};
71use tower::ServiceBuilder;
72use tracing::{
73 Instrument,
74 field::{self},
75 info_span, instrument,
76};
77
/// Reconnection grace period. Its consumers are outside this view; presumably
/// this is how long a dropped client may take to come back.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// Delay before `Server::start` begins cleaning up resources left by stale servers.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully; once they are
/// gone, their old resources can safely be cleaned up.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// Page sizes and limits used by message/notification handlers elsewhere in this file.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
const MAX_MESSAGE_LEN: usize = 1024;
86
87type MessageHandler =
88 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
89
90struct Response<R> {
91 peer: Arc<Peer>,
92 receipt: Receipt<R>,
93 responded: Arc<AtomicBool>,
94}
95
96impl<R: RequestMessage> Response<R> {
97 fn send(self, payload: R::Response) -> Result<()> {
98 self.responded.store(true, SeqCst);
99 self.peer.respond(self.receipt, payload)?;
100 Ok(())
101 }
102}
103
104#[derive(Clone, Debug)]
105pub enum Principal {
106 User(User),
107 Impersonated { user: User, admin: User },
108}
109
110impl Principal {
111 fn update_span(&self, span: &tracing::Span) {
112 match &self {
113 Principal::User(user) => {
114 span.record("user_id", user.id.0);
115 span.record("login", &user.github_login);
116 }
117 Principal::Impersonated { user, admin } => {
118 span.record("user_id", user.id.0);
119 span.record("login", &user.github_login);
120 span.record("impersonator", &admin.github_login);
121 }
122 }
123 }
124}
125
/// Per-connection state handed to every message handler for that connection.
///
/// A clone is passed to each dispatched handler, so every field is cheap to
/// clone (`Arc`s, small `Option<String>`s).
#[derive(Clone)]
struct Session {
    /// Who this connection acts as (a user, or an admin impersonating one).
    principal: Principal,
    /// Peer-level id of this connection.
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; each connection creates its own
    /// mutex, so handlers on the same connection take turns using it.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// The server-wide connection pool, shared with `Server`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Present only when a Supermaven admin API key is configured.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    http_client: Arc<dyn HttpClient>,
    /// The GeoIP country code for the user.
    // Kept on the session even if nothing reads it; `allow(unused)` silences the
    // dead-code warning.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    system_id: Option<String>,
    _executor: Executor,
}
142
143impl Session {
144 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
145 #[cfg(test)]
146 tokio::task::yield_now().await;
147 let guard = self.db.lock().await;
148 #[cfg(test)]
149 tokio::task::yield_now().await;
150 guard
151 }
152
153 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
154 #[cfg(test)]
155 tokio::task::yield_now().await;
156 let guard = self.connection_pool.lock();
157 ConnectionPoolGuard {
158 guard,
159 _not_send: PhantomData,
160 }
161 }
162
163 fn is_staff(&self) -> bool {
164 match &self.principal {
165 Principal::User(user) => user.admin,
166 Principal::Impersonated { .. } => true,
167 }
168 }
169
170 pub async fn has_llm_subscription(
171 &self,
172 db: &MutexGuard<'_, DbHandle>,
173 ) -> anyhow::Result<bool> {
174 if self.is_staff() {
175 return Ok(true);
176 }
177
178 let user_id = self.user_id();
179
180 Ok(db.has_active_billing_subscription(user_id).await?)
181 }
182
183 pub async fn current_plan(
184 &self,
185 _db: &MutexGuard<'_, DbHandle>,
186 ) -> anyhow::Result<proto::Plan> {
187 if self.is_staff() {
188 Ok(proto::Plan::ZedPro)
189 } else {
190 Ok(proto::Plan::Free)
191 }
192 }
193
194 fn user_id(&self) -> UserId {
195 match &self.principal {
196 Principal::User(user) => user.id,
197 Principal::Impersonated { user, .. } => user.id,
198 }
199 }
200
201 pub fn email(&self) -> Option<String> {
202 match &self.principal {
203 Principal::User(user) => user.email_address.clone(),
204 Principal::Impersonated { user, .. } => user.email_address.clone(),
205 }
206 }
207}
208
209impl Debug for Session {
210 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
211 let mut result = f.debug_struct("Session");
212 match &self.principal {
213 Principal::User(user) => {
214 result.field("user", &user.github_login);
215 }
216 Principal::Impersonated { user, admin } => {
217 result.field("user", &user.github_login);
218 result.field("impersonator", &admin.github_login);
219 }
220 }
221 result.field("connection_id", &self.connection_id).finish()
222 }
223}
224
225struct DbHandle(Arc<Database>);
226
227impl Deref for DbHandle {
228 type Target = Database;
229
230 fn deref(&self) -> &Self::Target {
231 self.0.as_ref()
232 }
233}
234
/// The collaboration RPC server.
///
/// Owns the RPC peer, the table of registered message handlers, and the pool of
/// live client connections.
pub struct Server {
    /// This server's id. It sits behind a mutex so the test-only `reset` can
    /// replace it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    /// Live client connections, shared with every `Session`.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Message handlers keyed by the `TypeId` of the message type they handle.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Teardown flag: `teardown()` sends `true` on it, and each connection
    /// subscribes to it and stops its loop when it changes.
    teardown: watch::Sender<bool>,
}
243
/// Lock guard over the shared `ConnectionPool`.
///
/// Dereferences to the pool. In test builds, dropping it checks the pool's
/// invariants.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    // `Rc<()>` is neither `Send` nor `Sync`, so this marker makes the guard
    // non-`Send` and keeps it from being moved to another thread.
    _not_send: PhantomData<Rc<()>>,
}
248
/// Serializable view of the server's peer and connection pool.
///
/// Holds the connection-pool lock for as long as the snapshot lives.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized as the `ConnectionPool` the guard dereferences to.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
255
256pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
257where
258 S: Serializer,
259 T: Deref<Target = U>,
260 U: Serialize,
261{
262 Serialize::serialize(value.deref(), serializer)
263}
264
265impl Server {
266 pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
267 let mut server = Self {
268 id: parking_lot::Mutex::new(id),
269 peer: Peer::new(id.0 as u32),
270 app_state: app_state.clone(),
271 connection_pool: Default::default(),
272 handlers: Default::default(),
273 teardown: watch::channel(false).0,
274 };
275
276 server
277 .add_request_handler(ping)
278 .add_request_handler(create_room)
279 .add_request_handler(join_room)
280 .add_request_handler(rejoin_room)
281 .add_request_handler(leave_room)
282 .add_request_handler(set_room_participant_role)
283 .add_request_handler(call)
284 .add_request_handler(cancel_call)
285 .add_message_handler(decline_call)
286 .add_request_handler(update_participant_location)
287 .add_request_handler(share_project)
288 .add_message_handler(unshare_project)
289 .add_request_handler(join_project)
290 .add_message_handler(leave_project)
291 .add_request_handler(update_project)
292 .add_request_handler(update_worktree)
293 .add_request_handler(update_repository)
294 .add_request_handler(remove_repository)
295 .add_message_handler(start_language_server)
296 .add_message_handler(update_language_server)
297 .add_message_handler(update_diagnostic_summary)
298 .add_message_handler(update_worktree_settings)
299 .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
300 .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
301 .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
302 .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
303 .add_request_handler(forward_find_search_candidates_request)
304 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
305 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
306 .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
307 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
308 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
309 .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
310 .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
311 .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
312 .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
313 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
314 .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
315 .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
316 .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
317 .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
318 .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
319 .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
320 .add_request_handler(
321 forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
322 )
323 .add_request_handler(
324 forward_read_only_project_request::<proto::LanguageServerIdForName>,
325 )
326 .add_request_handler(
327 forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
328 )
329 .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
330 .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
331 .add_request_handler(
332 forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
333 )
334 .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
335 .add_request_handler(
336 forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
337 )
338 .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
339 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
340 .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
341 .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
342 .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
343 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
344 .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
345 .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
346 .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
347 .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
348 .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
349 .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
350 .add_request_handler(
351 forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
352 )
353 .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
354 .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
355 .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
356 .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
357 .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
358 .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
359 .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
360 .add_message_handler(create_buffer_for_peer)
361 .add_request_handler(update_buffer)
362 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
363 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
364 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
365 .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
366 .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
367 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
368 .add_request_handler(get_users)
369 .add_request_handler(fuzzy_search_users)
370 .add_request_handler(request_contact)
371 .add_request_handler(remove_contact)
372 .add_request_handler(respond_to_contact_request)
373 .add_message_handler(subscribe_to_channels)
374 .add_request_handler(create_channel)
375 .add_request_handler(delete_channel)
376 .add_request_handler(invite_channel_member)
377 .add_request_handler(remove_channel_member)
378 .add_request_handler(set_channel_member_role)
379 .add_request_handler(set_channel_visibility)
380 .add_request_handler(rename_channel)
381 .add_request_handler(join_channel_buffer)
382 .add_request_handler(leave_channel_buffer)
383 .add_message_handler(update_channel_buffer)
384 .add_request_handler(rejoin_channel_buffers)
385 .add_request_handler(get_channel_members)
386 .add_request_handler(respond_to_channel_invite)
387 .add_request_handler(join_channel)
388 .add_request_handler(join_channel_chat)
389 .add_message_handler(leave_channel_chat)
390 .add_request_handler(send_channel_message)
391 .add_request_handler(remove_channel_message)
392 .add_request_handler(update_channel_message)
393 .add_request_handler(get_channel_messages)
394 .add_request_handler(get_channel_messages_by_id)
395 .add_request_handler(get_notifications)
396 .add_request_handler(mark_notification_as_read)
397 .add_request_handler(move_channel)
398 .add_request_handler(follow)
399 .add_message_handler(unfollow)
400 .add_message_handler(update_followers)
401 .add_request_handler(get_private_user_info)
402 .add_request_handler(get_llm_api_token)
403 .add_request_handler(accept_terms_of_service)
404 .add_message_handler(acknowledge_channel_message)
405 .add_message_handler(acknowledge_buffer_version)
406 .add_request_handler(get_supermaven_api_key)
407 .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
408 .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
409 .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
410 .add_request_handler(forward_mutating_project_request::<proto::Stage>)
411 .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
412 .add_request_handler(forward_mutating_project_request::<proto::Commit>)
413 .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
414 .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
415 .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
416 .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
417 .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
418 .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
419 .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
420 .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
421 .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
422 .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
423 .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
424 .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
425 .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
426 .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
427 .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
428 .add_message_handler(update_context)
429 .add_request_handler({
430 let app_state = app_state.clone();
431 move |request, response, session| {
432 let app_state = app_state.clone();
433 async move {
434 count_language_model_tokens(request, response, session, &app_state.config)
435 .await
436 }
437 }
438 });
439
440 Arc::new(server)
441 }
442
443 pub async fn start(&self) -> Result<()> {
444 let server_id = *self.id.lock();
445 let app_state = self.app_state.clone();
446 let peer = self.peer.clone();
447 let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
448 let pool = self.connection_pool.clone();
449 let livekit_client = self.app_state.livekit_client.clone();
450
451 let span = info_span!("start server");
452 self.app_state.executor.spawn_detached(
453 async move {
454 tracing::info!("waiting for cleanup timeout");
455 timeout.await;
456 tracing::info!("cleanup timeout expired, retrieving stale rooms");
457 if let Some((room_ids, channel_ids)) = app_state
458 .db
459 .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
460 .await
461 .trace_err()
462 {
463 tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
464 tracing::info!(
465 stale_channel_buffer_count = channel_ids.len(),
466 "retrieved stale channel buffers"
467 );
468
469 for channel_id in channel_ids {
470 if let Some(refreshed_channel_buffer) = app_state
471 .db
472 .clear_stale_channel_buffer_collaborators(channel_id, server_id)
473 .await
474 .trace_err()
475 {
476 for connection_id in refreshed_channel_buffer.connection_ids {
477 peer.send(
478 connection_id,
479 proto::UpdateChannelBufferCollaborators {
480 channel_id: channel_id.to_proto(),
481 collaborators: refreshed_channel_buffer
482 .collaborators
483 .clone(),
484 },
485 )
486 .trace_err();
487 }
488 }
489 }
490
491 for room_id in room_ids {
492 let mut contacts_to_update = HashSet::default();
493 let mut canceled_calls_to_user_ids = Vec::new();
494 let mut livekit_room = String::new();
495 let mut delete_livekit_room = false;
496
497 if let Some(mut refreshed_room) = app_state
498 .db
499 .clear_stale_room_participants(room_id, server_id)
500 .await
501 .trace_err()
502 {
503 tracing::info!(
504 room_id = room_id.0,
505 new_participant_count = refreshed_room.room.participants.len(),
506 "refreshed room"
507 );
508 room_updated(&refreshed_room.room, &peer);
509 if let Some(channel) = refreshed_room.channel.as_ref() {
510 channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
511 }
512 contacts_to_update
513 .extend(refreshed_room.stale_participant_user_ids.iter().copied());
514 contacts_to_update
515 .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
516 canceled_calls_to_user_ids =
517 mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
518 livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
519 delete_livekit_room = refreshed_room.room.participants.is_empty();
520 }
521
522 {
523 let pool = pool.lock();
524 for canceled_user_id in canceled_calls_to_user_ids {
525 for connection_id in pool.user_connection_ids(canceled_user_id) {
526 peer.send(
527 connection_id,
528 proto::CallCanceled {
529 room_id: room_id.to_proto(),
530 },
531 )
532 .trace_err();
533 }
534 }
535 }
536
537 for user_id in contacts_to_update {
538 let busy = app_state.db.is_user_busy(user_id).await.trace_err();
539 let contacts = app_state.db.get_contacts(user_id).await.trace_err();
540 if let Some((busy, contacts)) = busy.zip(contacts) {
541 let pool = pool.lock();
542 let updated_contact = contact_for_user(user_id, busy, &pool);
543 for contact in contacts {
544 if let db::Contact::Accepted {
545 user_id: contact_user_id,
546 ..
547 } = contact
548 {
549 for contact_conn_id in
550 pool.user_connection_ids(contact_user_id)
551 {
552 peer.send(
553 contact_conn_id,
554 proto::UpdateContacts {
555 contacts: vec![updated_contact.clone()],
556 remove_contacts: Default::default(),
557 incoming_requests: Default::default(),
558 remove_incoming_requests: Default::default(),
559 outgoing_requests: Default::default(),
560 remove_outgoing_requests: Default::default(),
561 },
562 )
563 .trace_err();
564 }
565 }
566 }
567 }
568 }
569
570 if let Some(live_kit) = livekit_client.as_ref() {
571 if delete_livekit_room {
572 live_kit.delete_room(livekit_room).await.trace_err();
573 }
574 }
575 }
576 }
577
578 app_state
579 .db
580 .delete_stale_servers(&app_state.config.zed_environment, server_id)
581 .await
582 .trace_err();
583 }
584 .instrument(span),
585 );
586 Ok(())
587 }
588
589 pub fn teardown(&self) {
590 self.peer.teardown();
591 self.connection_pool.lock().reset();
592 let _ = self.teardown.send(true);
593 }
594
595 #[cfg(test)]
596 pub fn reset(&self, id: ServerId) {
597 self.teardown();
598 *self.id.lock() = id;
599 self.peer.reset(id.0 as u32);
600 let _ = self.teardown.send(false);
601 }
602
603 #[cfg(test)]
604 pub fn id(&self) -> ServerId {
605 *self.id.lock()
606 }
607
608 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
609 where
610 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
611 Fut: 'static + Send + Future<Output = Result<()>>,
612 M: EnvelopedMessage,
613 {
614 let prev_handler = self.handlers.insert(
615 TypeId::of::<M>(),
616 Box::new(move |envelope, session| {
617 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
618 let received_at = envelope.received_at;
619 tracing::info!("message received");
620 let start_time = Instant::now();
621 let future = (handler)(*envelope, session);
622 async move {
623 let result = future.await;
624 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
625 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
626 let queue_duration_ms = total_duration_ms - processing_duration_ms;
627 let payload_type = M::NAME;
628
629 match result {
630 Err(error) => {
631 tracing::error!(
632 ?error,
633 total_duration_ms,
634 processing_duration_ms,
635 queue_duration_ms,
636 payload_type,
637 "error handling message"
638 )
639 }
640 Ok(()) => tracing::info!(
641 total_duration_ms,
642 processing_duration_ms,
643 queue_duration_ms,
644 "finished handling message"
645 ),
646 }
647 }
648 .boxed()
649 }),
650 );
651 if prev_handler.is_some() {
652 panic!("registered a handler for the same message twice");
653 }
654 self
655 }
656
657 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
658 where
659 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
660 Fut: 'static + Send + Future<Output = Result<()>>,
661 M: EnvelopedMessage,
662 {
663 self.add_handler(move |envelope, session| handler(envelope.payload, session));
664 self
665 }
666
667 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
668 where
669 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
670 Fut: Send + Future<Output = Result<()>>,
671 M: RequestMessage,
672 {
673 let handler = Arc::new(handler);
674 self.add_handler(move |envelope, session| {
675 let receipt = envelope.receipt();
676 let handler = handler.clone();
677 async move {
678 let peer = session.peer.clone();
679 let responded = Arc::new(AtomicBool::default());
680 let response = Response {
681 peer: peer.clone(),
682 responded: responded.clone(),
683 receipt,
684 };
685 match (handler)(envelope.payload, response, session).await {
686 Ok(()) => {
687 if responded.load(std::sync::atomic::Ordering::SeqCst) {
688 Ok(())
689 } else {
690 Err(anyhow!("handler did not send a response"))?
691 }
692 }
693 Err(error) => {
694 let proto_err = match &error {
695 Error::Internal(err) => err.to_proto(),
696 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
697 };
698 peer.respond_with_error(receipt, proto_err)?;
699 Err(error)
700 }
701 }
702 }
703 })
704 }
705
706 pub fn handle_connection(
707 self: &Arc<Self>,
708 connection: Connection,
709 address: String,
710 principal: Principal,
711 zed_version: ZedVersion,
712 geoip_country_code: Option<String>,
713 system_id: Option<String>,
714 send_connection_id: Option<oneshot::Sender<ConnectionId>>,
715 executor: Executor,
716 ) -> impl Future<Output = ()> + use<> {
717 let this = self.clone();
718 let span = info_span!("handle connection", %address,
719 connection_id=field::Empty,
720 user_id=field::Empty,
721 login=field::Empty,
722 impersonator=field::Empty,
723 geoip_country_code=field::Empty
724 );
725 principal.update_span(&span);
726 if let Some(country_code) = geoip_country_code.as_ref() {
727 span.record("geoip_country_code", country_code);
728 }
729
730 let mut teardown = self.teardown.subscribe();
731 async move {
732 if *teardown.borrow() {
733 tracing::error!("server is tearing down");
734 return
735 }
736 let (connection_id, handle_io, mut incoming_rx) = this
737 .peer
738 .add_connection(connection, {
739 let executor = executor.clone();
740 move |duration| executor.sleep(duration)
741 });
742 tracing::Span::current().record("connection_id", format!("{}", connection_id));
743
744 tracing::info!("connection opened");
745
746 let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
747 let http_client = match ReqwestClient::user_agent(&user_agent) {
748 Ok(http_client) => Arc::new(http_client),
749 Err(error) => {
750 tracing::error!(?error, "failed to create HTTP client");
751 return;
752 }
753 };
754
755 let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
756 supermaven_admin_api_key.to_string(),
757 http_client.clone(),
758 )));
759
760 let session = Session {
761 principal: principal.clone(),
762 connection_id,
763 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
764 peer: this.peer.clone(),
765 connection_pool: this.connection_pool.clone(),
766 app_state: this.app_state.clone(),
767 http_client,
768 geoip_country_code,
769 system_id,
770 _executor: executor.clone(),
771 supermaven_client,
772 };
773
774 if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
775 tracing::error!(?error, "failed to send initial client update");
776 return;
777 }
778
779 let handle_io = handle_io.fuse();
780 futures::pin_mut!(handle_io);
781
782 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
783 // This prevents deadlocks when e.g., client A performs a request to client B and
784 // client B performs a request to client A. If both clients stop processing further
785 // messages until their respective request completes, they won't have a chance to
786 // respond to the other client's request and cause a deadlock.
787 //
788 // This arrangement ensures we will attempt to process earlier messages first, but fall
789 // back to processing messages arrived later in the spirit of making progress.
790 let mut foreground_message_handlers = FuturesUnordered::new();
791 let concurrent_handlers = Arc::new(Semaphore::new(256));
792 loop {
793 let next_message = async {
794 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
795 let message = incoming_rx.next().await;
796 (permit, message)
797 }.fuse();
798 futures::pin_mut!(next_message);
799 futures::select_biased! {
800 _ = teardown.changed().fuse() => return,
801 result = handle_io => {
802 if let Err(error) = result {
803 tracing::error!(?error, "error handling I/O");
804 }
805 break;
806 }
807 _ = foreground_message_handlers.next() => {}
808 next_message = next_message => {
809 let (permit, message) = next_message;
810 if let Some(message) = message {
811 let type_name = message.payload_type_name();
812 // note: we copy all the fields from the parent span so we can query them in the logs.
813 // (https://github.com/tokio-rs/tracing/issues/2670).
814 let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
815 user_id=field::Empty,
816 login=field::Empty,
817 impersonator=field::Empty,
818 );
819 principal.update_span(&span);
820 let span_enter = span.enter();
821 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
822 let is_background = message.is_background();
823 let handle_message = (handler)(message, session.clone());
824 drop(span_enter);
825
826 let handle_message = async move {
827 handle_message.await;
828 drop(permit);
829 }.instrument(span);
830 if is_background {
831 executor.spawn_detached(handle_message);
832 } else {
833 foreground_message_handlers.push(handle_message);
834 }
835 } else {
836 tracing::error!("no message handler");
837 }
838 } else {
839 tracing::info!("connection closed");
840 break;
841 }
842 }
843 }
844 }
845
846 drop(foreground_message_handlers);
847 tracing::info!("signing out");
848 if let Err(error) = connection_lost(session, teardown, executor).await {
849 tracing::error!(?error, "error signing out");
850 }
851
852 }.instrument(span)
853 }
854
855 async fn send_initial_client_update(
856 &self,
857 connection_id: ConnectionId,
858 principal: &Principal,
859 zed_version: ZedVersion,
860 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
861 session: &Session,
862 ) -> Result<()> {
863 self.peer.send(
864 connection_id,
865 proto::Hello {
866 peer_id: Some(connection_id.into()),
867 },
868 )?;
869 tracing::info!("sent hello message");
870 if let Some(send_connection_id) = send_connection_id.take() {
871 let _ = send_connection_id.send(connection_id);
872 }
873
874 match principal {
875 Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
876 if !user.connected_once {
877 self.peer.send(connection_id, proto::ShowContacts {})?;
878 self.app_state
879 .db
880 .set_user_connected_once(user.id, true)
881 .await?;
882 }
883
884 update_user_plan(user.id, session).await?;
885
886 let contacts = self.app_state.db.get_contacts(user.id).await?;
887
888 {
889 let mut pool = self.connection_pool.lock();
890 pool.add_connection(connection_id, user.id, user.admin, zed_version);
891 self.peer.send(
892 connection_id,
893 build_initial_contacts_update(contacts, &pool),
894 )?;
895 }
896
897 if should_auto_subscribe_to_channels(zed_version) {
898 subscribe_user_to_channels(user.id, session).await?;
899 }
900
901 if let Some(incoming_call) =
902 self.app_state.db.incoming_call_for_user(user.id).await?
903 {
904 self.peer.send(connection_id, incoming_call)?;
905 }
906
907 update_user_contacts(user.id, session).await?;
908 }
909 }
910
911 Ok(())
912 }
913
914 pub async fn invite_code_redeemed(
915 self: &Arc<Self>,
916 inviter_id: UserId,
917 invitee_id: UserId,
918 ) -> Result<()> {
919 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
920 if let Some(code) = &user.invite_code {
921 let pool = self.connection_pool.lock();
922 let invitee_contact = contact_for_user(invitee_id, false, &pool);
923 for connection_id in pool.user_connection_ids(inviter_id) {
924 self.peer.send(
925 connection_id,
926 proto::UpdateContacts {
927 contacts: vec![invitee_contact.clone()],
928 ..Default::default()
929 },
930 )?;
931 self.peer.send(
932 connection_id,
933 proto::UpdateInviteInfo {
934 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
935 count: user.invite_count as u32,
936 },
937 )?;
938 }
939 }
940 }
941 Ok(())
942 }
943
944 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
945 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
946 if let Some(invite_code) = &user.invite_code {
947 let pool = self.connection_pool.lock();
948 for connection_id in pool.user_connection_ids(user_id) {
949 self.peer.send(
950 connection_id,
951 proto::UpdateInviteInfo {
952 url: format!(
953 "{}{}",
954 self.app_state.config.invite_link_prefix, invite_code
955 ),
956 count: user.invite_count as u32,
957 },
958 )?;
959 }
960 }
961 }
962 Ok(())
963 }
964
965 pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
966 let pool = self.connection_pool.lock();
967 for connection_id in pool.user_connection_ids(user_id) {
968 self.peer
969 .send(connection_id, proto::RefreshLlmToken {})
970 .trace_err();
971 }
972 }
973
974 pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot {
975 ServerSnapshot {
976 connection_pool: ConnectionPoolGuard {
977 guard: self.connection_pool.lock(),
978 _not_send: PhantomData,
979 },
980 peer: &self.peer,
981 }
982 }
983}
984
985impl Deref for ConnectionPoolGuard<'_> {
986 type Target = ConnectionPool;
987
988 fn deref(&self) -> &Self::Target {
989 &self.guard
990 }
991}
992
993impl DerefMut for ConnectionPoolGuard<'_> {
994 fn deref_mut(&mut self) -> &mut Self::Target {
995 &mut self.guard
996 }
997}
998
/// In test builds, checks the connection pool's invariants whenever a guard is released.
///
/// `drop` runs before the struct's fields are dropped, so the check executes while the
/// inner `guard` (and therefore the lock) is still held. In non-test builds this is a no-op.
impl Drop for ConnectionPoolGuard<'_> {
    fn drop(&mut self) {
        // Presumably resolves to `ConnectionPool::check_invariants` via `Deref`.
        #[cfg(test)]
        self.check_invariants();
    }
}
1005
1006fn broadcast<F>(
1007 sender_id: Option<ConnectionId>,
1008 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1009 mut f: F,
1010) where
1011 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1012{
1013 for receiver_id in receiver_ids {
1014 if Some(receiver_id) != sender_id {
1015 if let Err(error) = f(receiver_id) {
1016 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1017 }
1018 }
1019 }
1020}
1021
1022pub struct ProtocolVersion(u32);
1023
1024impl Header for ProtocolVersion {
1025 fn name() -> &'static HeaderName {
1026 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1027 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1028 }
1029
1030 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1031 where
1032 Self: Sized,
1033 I: Iterator<Item = &'i axum::http::HeaderValue>,
1034 {
1035 let version = values
1036 .next()
1037 .ok_or_else(axum::headers::Error::invalid)?
1038 .to_str()
1039 .map_err(|_| axum::headers::Error::invalid())?
1040 .parse()
1041 .map_err(|_| axum::headers::Error::invalid())?;
1042 Ok(Self(version))
1043 }
1044
1045 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1046 values.extend([self.0.to_string().parse().unwrap()]);
1047 }
1048}
1049
1050pub struct AppVersionHeader(SemanticVersion);
1051impl Header for AppVersionHeader {
1052 fn name() -> &'static HeaderName {
1053 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1054 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1055 }
1056
1057 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1058 where
1059 Self: Sized,
1060 I: Iterator<Item = &'i axum::http::HeaderValue>,
1061 {
1062 let version = values
1063 .next()
1064 .ok_or_else(axum::headers::Error::invalid)?
1065 .to_str()
1066 .map_err(|_| axum::headers::Error::invalid())?
1067 .parse()
1068 .map_err(|_| axum::headers::Error::invalid())?;
1069 Ok(Self(version))
1070 }
1071
1072 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1073 values.extend([self.0.to_string().parse().unwrap()]);
1074 }
1075}
1076
/// Builds the collab RPC router.
///
/// * `/rpc` upgrades to a websocket via [`handle_websocket_request`].
/// * `/metrics` serves Prometheus metrics via [`handle_metrics`].
///
/// NOTE: `Router::layer` only wraps routes registered *before* it is called. The
/// `ServiceBuilder` stack therefore applies to `/rpc` only. That stack adds the
/// `AppState` extension and runs `auth::validate_header`, which presumably supplies the
/// `Principal` extension used by `handle_websocket_request`. `/metrics` is not wrapped.
/// The final `Extension(server)` layer is added after both routes and applies to both.
/// Reordering these calls changes which endpoints are authenticated.
pub fn routes(server: Arc<Server>) -> Router<(), Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        // Applies to `/rpc` only (routes registered above this call).
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        // Applies to both routes.
        .layer(Extension(server))
}
1088
/// Axum handler for `GET /rpc`: checks the client's versions and upgrades to a websocket.
///
/// `TypedHeader<ProtocolVersion>` is a required extractor. If that header is absent or
/// unparsable, axum rejects the request before this body runs. Otherwise the request is
/// rejected with `426 Upgrade Required` in three cases:
/// * the protocol version differs from `rpc::PROTOCOL_VERSION`;
/// * the `x-zed-app-version` header is missing;
/// * the app version's `can_collaborate()` returns false.
///
/// On success the websocket is adapted into an RPC `Connection`. The connection is handed
/// to `Server::handle_connection` together with:
/// * the peer's socket address;
/// * the principal and the app version;
/// * the optional Cloudflare country and system-id headers, as strings.
pub async fn handle_websocket_request(
    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
    app_version_header: Option<TypedHeader<AppVersionHeader>>,
    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
    Extension(server): Extension<Arc<Server>>,
    Extension(principal): Extension<Principal>,
    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
    system_id_header: Option<TypedHeader<SystemIdHeader>>,
    ws: WebSocketUpgrade,
) -> axum::response::Response {
    if protocol_version != rpc::PROTOCOL_VERSION {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }

    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "no version header found".to_string(),
        )
            .into_response();
    };

    if !version.can_collaborate() {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }

    let socket_address = socket_address.to_string();
    ws.on_upgrade(move |socket| {
        // Incoming axum messages are mapped with `to_tungstenite_message`; outgoing
        // messages are converted back with `to_axum_message` before hitting the socket.
        let socket = socket
            .map_ok(to_tungstenite_message)
            .err_into()
            .with(|message| async move { to_axum_message(message) });
        let connection = Connection::new(Box::pin(socket));
        async move {
            server
                .handle_connection(
                    connection,
                    socket_address,
                    principal,
                    version,
                    country_code_header.map(|header| header.to_string()),
                    system_id_header.map(|header| header.to_string()),
                    // Presumably the optional connection-id notifier; unused in production.
                    None,
                    Executor::Production,
                )
                .await;
        }
    })
}
1146
1147pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1148 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1149 let connections_metric = CONNECTIONS_METRIC
1150 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1151
1152 let connections = server
1153 .connection_pool
1154 .lock()
1155 .connections()
1156 .filter(|connection| !connection.admin)
1157 .count();
1158 connections_metric.set(connections as _);
1159
1160 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1161 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1162 register_int_gauge!(
1163 "shared_projects",
1164 "number of open projects with one or more guests"
1165 )
1166 .unwrap()
1167 });
1168
1169 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1170 shared_projects_metric.set(shared_projects as _);
1171
1172 let encoder = prometheus::TextEncoder::new();
1173 let metric_families = prometheus::gather();
1174 let encoded_metrics = encoder
1175 .encode_to_string(&metric_families)
1176 .map_err(|err| anyhow!("{}", err))?;
1177 Ok(encoded_metrics)
1178}
1179
/// Cleans up after a client connection drops.
///
/// Immediate cleanup:
/// * disconnect the peer;
/// * remove the connection from the pool (this error is propagated);
/// * record the loss in the database (this error is only logged).
///
/// It then waits, biased toward the first branch, for whichever happens first:
/// * `RECONNECT_TIMEOUT` elapses. The client did not come back in time, so:
///   * the session leaves its room and channel buffers (failures are logged);
///   * if the user has no other online connection, any pending call is declined
///     (`decline_call(None, ..)`, presumably meaning "in any room");
///   * the user's contacts are updated.
/// * `teardown` changes. This is presumably server shutdown; no further cleanup happens.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            // The grace period expired without a reconnect: release everything the
            // session still holds.
            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline pending calls if no other connection of this user is online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1226
1227/// Acknowledges a ping from a client, used to keep the connection alive.
1228async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1229 response.send(proto::Ack {})?;
1230 Ok(())
1231}
1232
1233/// Creates a new room for calling (outside of channels)
1234async fn create_room(
1235 _request: proto::CreateRoom,
1236 response: Response<proto::CreateRoom>,
1237 session: Session,
1238) -> Result<()> {
1239 let livekit_room = nanoid::nanoid!(30);
1240
1241 let live_kit_connection_info = util::maybe!(async {
1242 let live_kit = session.app_state.livekit_client.as_ref();
1243 let live_kit = live_kit?;
1244 let user_id = session.user_id().to_string();
1245
1246 let token = live_kit
1247 .room_token(&livekit_room, &user_id.to_string())
1248 .trace_err()?;
1249
1250 Some(proto::LiveKitConnectionInfo {
1251 server_url: live_kit.url().into(),
1252 token,
1253 can_publish: true,
1254 })
1255 })
1256 .await;
1257
1258 let room = session
1259 .db()
1260 .await
1261 .create_room(session.user_id(), session.connection_id, &livekit_room)
1262 .await?;
1263
1264 response.send(proto::CreateRoomResponse {
1265 room: Some(room.clone()),
1266 live_kit_connection_info,
1267 })?;
1268
1269 update_user_contacts(session.user_id(), &session).await?;
1270 Ok(())
1271}
1272
/// Join a room from an invitation. Equivalent to joining a channel if there is one.
///
/// For a room that belongs to a channel, this delegates entirely to
/// `join_channel_internal`. Otherwise the steps are:
/// 1. The caller joins the room in the database and the room update is broadcast.
/// 2. Every connection of the caller is sent `CallCanceled` for this room (presumably to
///    dismiss the incoming-call prompt on the caller's other devices). Send failures
///    are logged.
/// 3. If LiveKit is configured, a publish-capable token is minted. A failure is logged
///    and the connection info is omitted.
/// 4. The room is returned to the caller and the caller's contacts are updated.
async fn join_room(
    request: proto::JoinRoom,
    response: Response<proto::JoinRoom>,
    session: Session,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.id);

    let channel_id = session.db().await.channel_id_for_room(room_id).await?;

    if let Some(channel_id) = channel_id {
        return join_channel_internal(channel_id, Box::new(response), session).await;
    }

    // Scoped so the value returned by `Database::join_room` is unwrapped (`into_inner`)
    // and released before anything else is sent.
    let joined_room = {
        let room = session
            .db()
            .await
            .join_room(room_id, session.user_id(), session.connection_id)
            .await?;
        room_updated(&room.room, &session.peer);
        room.into_inner()
    };

    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(session.user_id())
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }

    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
    {
        live_kit
            .room_token(
                &joined_room.room.livekit_room,
                &session.user_id().to_string(),
            )
            .trace_err()
            .map(|token| proto::LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish: true,
            })
    } else {
        None
    };

    response.send(proto::JoinRoomResponse {
        room: Some(joined_room.room),
        channel_id: None,
        live_kit_connection_info,
    })?;

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1339
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Inside one scope (so the database result is released via `into_inner` before the
/// channel/contact updates that follow):
/// 1. `Database::rejoin_room` runs. The room, the reshared projects and the rejoined
///    projects are returned to the caller.
/// 2. The room update is broadcast.
/// 3. For each reshared project (presumably projects the caller hosts), every listed
///    collaborator is told the caller's peer id changed from `old_connection_id`. Every
///    collaborator except the caller is also forwarded the project's worktrees.
/// 4. `notify_rejoined_projects` resyncs the rejoined projects (presumably projects the
///    caller had joined as a guest).
///
/// Afterwards, if the room belongs to a channel, the channel update is broadcast. Finally
/// the caller's contacts are updated.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Best-effort: tell each collaborator about the caller's new peer id.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the current worktree list on behalf of the caller.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1431
1432fn notify_rejoined_projects(
1433 rejoined_projects: &mut Vec<RejoinedProject>,
1434 session: &Session,
1435) -> Result<()> {
1436 for project in rejoined_projects.iter() {
1437 for collaborator in &project.collaborators {
1438 session
1439 .peer
1440 .send(
1441 collaborator.connection_id,
1442 proto::UpdateProjectCollaborator {
1443 project_id: project.id.to_proto(),
1444 old_peer_id: Some(project.old_connection_id.into()),
1445 new_peer_id: Some(session.connection_id.into()),
1446 },
1447 )
1448 .trace_err();
1449 }
1450 }
1451
1452 for project in rejoined_projects {
1453 for worktree in mem::take(&mut project.worktrees) {
1454 // Stream this worktree's entries.
1455 let message = proto::UpdateWorktree {
1456 project_id: project.id.to_proto(),
1457 worktree_id: worktree.id,
1458 abs_path: worktree.abs_path.clone(),
1459 root_name: worktree.root_name,
1460 updated_entries: worktree.updated_entries,
1461 removed_entries: worktree.removed_entries,
1462 scan_id: worktree.scan_id,
1463 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1464 updated_repositories: worktree.updated_repositories,
1465 removed_repositories: worktree.removed_repositories,
1466 };
1467 for update in proto::split_worktree_update(message) {
1468 session.peer.send(session.connection_id, update)?;
1469 }
1470
1471 // Stream this worktree's diagnostics.
1472 for summary in worktree.diagnostic_summaries {
1473 session.peer.send(
1474 session.connection_id,
1475 proto::UpdateDiagnosticSummary {
1476 project_id: project.id.to_proto(),
1477 worktree_id: worktree.id,
1478 summary: Some(summary),
1479 },
1480 )?;
1481 }
1482
1483 for settings_file in worktree.settings_files {
1484 session.peer.send(
1485 session.connection_id,
1486 proto::UpdateWorktreeSettings {
1487 project_id: project.id.to_proto(),
1488 worktree_id: worktree.id,
1489 path: settings_file.path,
1490 content: Some(settings_file.content),
1491 kind: Some(settings_file.kind.to_proto().into()),
1492 },
1493 )?;
1494 }
1495 }
1496
1497 for repository in mem::take(&mut project.updated_repositories) {
1498 for update in split_repository_update(repository) {
1499 session.peer.send(session.connection_id, update)?;
1500 }
1501 }
1502
1503 for id in mem::take(&mut project.removed_repositories) {
1504 session.peer.send(
1505 session.connection_id,
1506 proto::RemoveRepository {
1507 project_id: project.id.to_proto(),
1508 id,
1509 },
1510 )?;
1511 }
1512 }
1513
1514 Ok(())
1515}
1516
1517/// leave room disconnects from the room.
1518async fn leave_room(
1519 _: proto::LeaveRoom,
1520 response: Response<proto::LeaveRoom>,
1521 session: Session,
1522) -> Result<()> {
1523 leave_room_for_session(&session, session.connection_id).await?;
1524 response.send(proto::Ack {})?;
1525 Ok(())
1526}
1527
1528/// Updates the permissions of someone else in the room.
1529async fn set_room_participant_role(
1530 request: proto::SetRoomParticipantRole,
1531 response: Response<proto::SetRoomParticipantRole>,
1532 session: Session,
1533) -> Result<()> {
1534 let user_id = UserId::from_proto(request.user_id);
1535 let role = ChannelRole::from(request.role());
1536
1537 let (livekit_room, can_publish) = {
1538 let room = session
1539 .db()
1540 .await
1541 .set_room_participant_role(
1542 session.user_id(),
1543 RoomId::from_proto(request.room_id),
1544 user_id,
1545 role,
1546 )
1547 .await?;
1548
1549 let livekit_room = room.livekit_room.clone();
1550 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1551 room_updated(&room, &session.peer);
1552 (livekit_room, can_publish)
1553 };
1554
1555 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1556 live_kit
1557 .update_participant(
1558 livekit_room.clone(),
1559 request.user_id.to_string(),
1560 livekit_api::proto::ParticipantPermission {
1561 can_subscribe: true,
1562 can_publish,
1563 can_publish_data: can_publish,
1564 hidden: false,
1565 recorder: false,
1566 },
1567 )
1568 .await
1569 .trace_err();
1570 }
1571
1572 response.send(proto::Ack {})?;
1573 Ok(())
1574}
1575
/// Call someone else into the current room.
///
/// The callee must be one of the caller's contacts, otherwise an error is returned.
///
/// 1. The call is recorded in the database and the room update is broadcast.
/// 2. The callee's contacts are refreshed.
/// 3. An incoming-call request is sent concurrently to every connection of the callee.
///    The first request that completes successfully makes the handler reply `Ack`.
///    Failed requests are logged.
/// 4. If no request succeeds (including when the callee has no connections), the call
///    is marked failed in the database. The room and the callee's contacts are updated
///    again, and "failed to ring user" is returned.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: Session,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        // `?` on an `Err` converts the anyhow error into this function's error type.
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    // Scoped so the database result is released before contacts are updated.
    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(room, &session.peer);
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring all of the callee's connections at once.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    // No connection accepted the ring: roll the call back.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1644
/// Cancel an outgoing call.
///
/// The call to `called_user_id` is cancelled in the database and the room update is
/// broadcast. Every connection of the callee is sent `CallCanceled` (failures are
/// logged). The caller receives an `Ack`, and the callee's contacts are updated.
async fn cancel_call(
    request: proto::CancelCall,
    response: Response<proto::CancelCall>,
    session: Session,
) -> Result<()> {
    let called_user_id = UserId::from_proto(request.called_user_id);
    let room_id = RoomId::from_proto(request.room_id);
    // Scoped so the database result is released before messaging the callee.
    {
        let room = session
            .db()
            .await
            .cancel_call(room_id, session.connection_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }
    response.send(proto::Ack {})?;

    update_user_contacts(called_user_id, &session).await?;
    Ok(())
}
1682
/// Decline an incoming call.
///
/// The caller declines the call for `room_id` in the database; an error is returned if
/// the database reports nothing to decline. The room update is broadcast. All of the
/// caller's own connections are sent `CallCanceled` (presumably so other devices stop
/// ringing), and the caller's contacts are updated.
async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
    let room_id = RoomId::from_proto(message.room_id);
    // Scoped so the database result is released before messaging connections.
    {
        let room = session
            .db()
            .await
            .decline_call(Some(room_id), session.user_id())
            .await?
            .ok_or_else(|| anyhow!("failed to decline call"))?;
        room_updated(&room, &session.peer);
    }

    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(session.user_id())
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }
    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1714
/// Updates other participants in the room with your current location.
///
/// Fails if the request carries no location. Otherwise the location is stored for the
/// caller's connection, the resulting room is broadcast, and the caller receives an `Ack`.
async fn update_participant_location(
    request: proto::UpdateParticipantLocation,
    response: Response<proto::UpdateParticipantLocation>,
    session: Session,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let location = request
        .location
        .ok_or_else(|| anyhow!("invalid location"))?;

    // NOTE: `db` stays alive until the end of the function, i.e. through the broadcast
    // and the acknowledgement below.
    let db = session.db().await;
    let room = db
        .update_room_participant_location(room_id, session.connection_id, location)
        .await?;

    room_updated(&room, &session.peer);
    response.send(proto::Ack {})?;
    Ok(())
}
1735
/// Share a project into the room.
///
/// Records the caller's project (its worktrees and whether it is an SSH project) in the
/// database. The new project id is returned to the caller before the room update is
/// broadcast.
async fn share_project(
    request: proto::ShareProject,
    response: Response<proto::ShareProject>,
    session: Session,
) -> Result<()> {
    let (project_id, room) = &*session
        .db()
        .await
        .share_project(
            RoomId::from_proto(request.room_id),
            session.connection_id,
            &request.worktrees,
            request.is_ssh_project,
        )
        .await?;
    response.send(proto::ShareProjectResponse {
        project_id: project_id.to_proto(),
    })?;
    room_updated(room, &session.peer);

    Ok(())
}
1759
1760/// Unshare a project from the room.
1761async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1762 let project_id = ProjectId::from_proto(message.project_id);
1763 unshare_project_internal(project_id, session.connection_id, &session).await
1764}
1765
/// Stops sharing `project_id` on behalf of `connection_id`.
///
/// 1. Every guest connection returned by `Database::unshare_project`, except
///    `connection_id`, is sent `UnshareProject`. Per-guest failures are logged by
///    `broadcast`.
/// 2. The room update is broadcast if the project belonged to a room.
/// 3. If the database reports that the project should be deleted, it is deleted
///    afterwards, once the unshare result has been released.
async fn unshare_project_internal(
    project_id: ProjectId,
    connection_id: ConnectionId,
    session: &Session,
) -> Result<()> {
    let delete = {
        let room_guard = session
            .db()
            .await
            .unshare_project(project_id, connection_id)
            .await?;

        let (delete, room, guest_connection_ids) = &*room_guard;

        let message = proto::UnshareProject {
            project_id: project_id.to_proto(),
        };

        broadcast(
            Some(connection_id),
            guest_connection_ids.iter().copied(),
            |conn_id| session.peer.send(conn_id, message.clone()),
        );
        if let Some(room) = room {
            room_updated(room, &session.peer);
        }

        *delete
    };

    // `room_guard` has been dropped at this point, before the project is deleted.
    if delete {
        let db = session.db().await;
        db.delete_project(project_id).await?;
    }

    Ok(())
}
1803
/// Join someone else's shared project.
///
/// Registers the caller as a guest of the project in the database, which yields the
/// project and the caller's replica id. The rest of the work is delegated to
/// `join_project_internal`.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: Session,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    // Release the database handle early; `project`/`replica_id` borrow from the
    // (still alive) value returned by `join_project`, not from `db`.
    drop(db);
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
1822
/// Abstracts how a `proto::JoinProjectResponse` is delivered, so `join_project_internal`
/// is not tied to one concrete response type.
trait JoinProjectInternalResponse {
    /// Consumes the responder and sends `result`.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
/// Delivers the join result as the reply to a `proto::JoinProject` request.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Fully qualified so it is explicit that the inherent `Response::send` is invoked,
        // not this trait method.
        Response::<proto::JoinProject>::send(self, result)
    }
}
1831
1832fn join_project_internal(
1833 response: impl JoinProjectInternalResponse,
1834 session: Session,
1835 project: &mut Project,
1836 replica_id: &ReplicaId,
1837) -> Result<()> {
1838 let collaborators = project
1839 .collaborators
1840 .iter()
1841 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1842 .map(|collaborator| collaborator.to_proto())
1843 .collect::<Vec<_>>();
1844 let project_id = project.id;
1845 let guest_user_id = session.user_id();
1846
1847 let worktrees = project
1848 .worktrees
1849 .iter()
1850 .map(|(id, worktree)| proto::WorktreeMetadata {
1851 id: *id,
1852 root_name: worktree.root_name.clone(),
1853 visible: worktree.visible,
1854 abs_path: worktree.abs_path.clone(),
1855 })
1856 .collect::<Vec<_>>();
1857
1858 let add_project_collaborator = proto::AddProjectCollaborator {
1859 project_id: project_id.to_proto(),
1860 collaborator: Some(proto::Collaborator {
1861 peer_id: Some(session.connection_id.into()),
1862 replica_id: replica_id.0 as u32,
1863 user_id: guest_user_id.to_proto(),
1864 is_host: false,
1865 }),
1866 };
1867
1868 for collaborator in &collaborators {
1869 session
1870 .peer
1871 .send(
1872 collaborator.peer_id.unwrap().into(),
1873 add_project_collaborator.clone(),
1874 )
1875 .trace_err();
1876 }
1877
1878 // First, we send the metadata associated with each worktree.
1879 response.send(proto::JoinProjectResponse {
1880 project_id: project.id.0 as u64,
1881 worktrees: worktrees.clone(),
1882 replica_id: replica_id.0 as u32,
1883 collaborators: collaborators.clone(),
1884 language_servers: project.language_servers.clone(),
1885 role: project.role.into(),
1886 })?;
1887
1888 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1889 // Stream this worktree's entries.
1890 let message = proto::UpdateWorktree {
1891 project_id: project_id.to_proto(),
1892 worktree_id,
1893 abs_path: worktree.abs_path.clone(),
1894 root_name: worktree.root_name,
1895 updated_entries: worktree.entries,
1896 removed_entries: Default::default(),
1897 scan_id: worktree.scan_id,
1898 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1899 updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1900 removed_repositories: Default::default(),
1901 };
1902 for update in proto::split_worktree_update(message) {
1903 session.peer.send(session.connection_id, update.clone())?;
1904 }
1905
1906 // Stream this worktree's diagnostics.
1907 for summary in worktree.diagnostic_summaries {
1908 session.peer.send(
1909 session.connection_id,
1910 proto::UpdateDiagnosticSummary {
1911 project_id: project_id.to_proto(),
1912 worktree_id: worktree.id,
1913 summary: Some(summary),
1914 },
1915 )?;
1916 }
1917
1918 for settings_file in worktree.settings_files {
1919 session.peer.send(
1920 session.connection_id,
1921 proto::UpdateWorktreeSettings {
1922 project_id: project_id.to_proto(),
1923 worktree_id: worktree.id,
1924 path: settings_file.path,
1925 content: Some(settings_file.content),
1926 kind: Some(settings_file.kind.to_proto() as i32),
1927 },
1928 )?;
1929 }
1930 }
1931
1932 for repository in mem::take(&mut project.repositories) {
1933 for update in split_repository_update(repository) {
1934 session.peer.send(session.connection_id, update)?;
1935 }
1936 }
1937
1938 for language_server in &project.language_servers {
1939 session.peer.send(
1940 session.connection_id,
1941 proto::UpdateLanguageServer {
1942 project_id: project_id.to_proto(),
1943 language_server_id: language_server.id,
1944 variant: Some(
1945 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1946 proto::LspDiskBasedDiagnosticsUpdated {},
1947 ),
1948 ),
1949 },
1950 )?;
1951 }
1952
1953 Ok(())
1954}
1955
/// Leave someone else's shared project.
///
/// Removes the caller's connection from the project in the database. It then calls
/// `project_left` for the project and broadcasts the room update if the project
/// belonged to a room.
async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
    let sender_id = session.connection_id;
    let project_id = ProjectId::from_proto(request.project_id);
    // NOTE: `db` stays alive until the end of the function.
    let db = session.db().await;

    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
    tracing::info!(
        %project_id,
        "leave project"
    );

    project_left(project, &session);
    if let Some(room) = room {
        room_updated(room, &session.peer);
    }

    Ok(())
}
1975
/// Updates other participants with changes to the project.
///
/// The caller's new worktree list is stored in the database. The request is then
/// forwarded to every returned guest connection except the caller; per-guest failures
/// are logged by `broadcast`. The room update is broadcast if there is a room, and the
/// caller receives an `Ack`.
async fn update_project(
    request: proto::UpdateProject,
    response: Response<proto::UpdateProject>,
    session: Session,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    let (room, guest_connection_ids) = &*session
        .db()
        .await
        .update_project(project_id, session.connection_id, &request.worktrees)
        .await?;
    broadcast(
        Some(session.connection_id),
        guest_connection_ids.iter().copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())
        },
    );
    if let Some(room) = room {
        room_updated(room, &session.peer);
    }
    response.send(proto::Ack {})?;

    Ok(())
}
2004
2005/// Updates other participants with changes to the worktree
2006async fn update_worktree(
2007 request: proto::UpdateWorktree,
2008 response: Response<proto::UpdateWorktree>,
2009 session: Session,
2010) -> Result<()> {
2011 let guest_connection_ids = session
2012 .db()
2013 .await
2014 .update_worktree(&request, session.connection_id)
2015 .await?;
2016
2017 broadcast(
2018 Some(session.connection_id),
2019 guest_connection_ids.iter().copied(),
2020 |connection_id| {
2021 session
2022 .peer
2023 .forward_send(session.connection_id, connection_id, request.clone())
2024 },
2025 );
2026 response.send(proto::Ack {})?;
2027 Ok(())
2028}
2029
2030async fn update_repository(
2031 request: proto::UpdateRepository,
2032 response: Response<proto::UpdateRepository>,
2033 session: Session,
2034) -> Result<()> {
2035 let guest_connection_ids = session
2036 .db()
2037 .await
2038 .update_repository(&request, session.connection_id)
2039 .await?;
2040
2041 broadcast(
2042 Some(session.connection_id),
2043 guest_connection_ids.iter().copied(),
2044 |connection_id| {
2045 session
2046 .peer
2047 .forward_send(session.connection_id, connection_id, request.clone())
2048 },
2049 );
2050 response.send(proto::Ack {})?;
2051 Ok(())
2052}
2053
2054async fn remove_repository(
2055 request: proto::RemoveRepository,
2056 response: Response<proto::RemoveRepository>,
2057 session: Session,
2058) -> Result<()> {
2059 let guest_connection_ids = session
2060 .db()
2061 .await
2062 .remove_repository(&request, session.connection_id)
2063 .await?;
2064
2065 broadcast(
2066 Some(session.connection_id),
2067 guest_connection_ids.iter().copied(),
2068 |connection_id| {
2069 session
2070 .peer
2071 .forward_send(session.connection_id, connection_id, request.clone())
2072 },
2073 );
2074 response.send(proto::Ack {})?;
2075 Ok(())
2076}
2077
2078/// Updates other participants with changes to the diagnostics
2079async fn update_diagnostic_summary(
2080 message: proto::UpdateDiagnosticSummary,
2081 session: Session,
2082) -> Result<()> {
2083 let guest_connection_ids = session
2084 .db()
2085 .await
2086 .update_diagnostic_summary(&message, session.connection_id)
2087 .await?;
2088
2089 broadcast(
2090 Some(session.connection_id),
2091 guest_connection_ids.iter().copied(),
2092 |connection_id| {
2093 session
2094 .peer
2095 .forward_send(session.connection_id, connection_id, message.clone())
2096 },
2097 );
2098
2099 Ok(())
2100}
2101
2102/// Updates other participants with changes to the worktree settings
2103async fn update_worktree_settings(
2104 message: proto::UpdateWorktreeSettings,
2105 session: Session,
2106) -> Result<()> {
2107 let guest_connection_ids = session
2108 .db()
2109 .await
2110 .update_worktree_settings(&message, session.connection_id)
2111 .await?;
2112
2113 broadcast(
2114 Some(session.connection_id),
2115 guest_connection_ids.iter().copied(),
2116 |connection_id| {
2117 session
2118 .peer
2119 .forward_send(session.connection_id, connection_id, message.clone())
2120 },
2121 );
2122
2123 Ok(())
2124}
2125
2126/// Notify other participants that a language server has started.
2127async fn start_language_server(
2128 request: proto::StartLanguageServer,
2129 session: Session,
2130) -> Result<()> {
2131 let guest_connection_ids = session
2132 .db()
2133 .await
2134 .start_language_server(&request, session.connection_id)
2135 .await?;
2136
2137 broadcast(
2138 Some(session.connection_id),
2139 guest_connection_ids.iter().copied(),
2140 |connection_id| {
2141 session
2142 .peer
2143 .forward_send(session.connection_id, connection_id, request.clone())
2144 },
2145 );
2146 Ok(())
2147}
2148
2149/// Notify other participants that a language server has changed.
2150async fn update_language_server(
2151 request: proto::UpdateLanguageServer,
2152 session: Session,
2153) -> Result<()> {
2154 let project_id = ProjectId::from_proto(request.project_id);
2155 let project_connection_ids = session
2156 .db()
2157 .await
2158 .project_connection_ids(project_id, session.connection_id, true)
2159 .await?;
2160 broadcast(
2161 Some(session.connection_id),
2162 project_connection_ids.iter().copied(),
2163 |connection_id| {
2164 session
2165 .peer
2166 .forward_send(session.connection_id, connection_id, request.clone())
2167 },
2168 );
2169 Ok(())
2170}
2171
2172/// forward a project request to the host. These requests should be read only
2173/// as guests are allowed to send them.
2174async fn forward_read_only_project_request<T>(
2175 request: T,
2176 response: Response<T>,
2177 session: Session,
2178) -> Result<()>
2179where
2180 T: EntityMessage + RequestMessage,
2181{
2182 let project_id = ProjectId::from_proto(request.remote_entity_id());
2183 let host_connection_id = session
2184 .db()
2185 .await
2186 .host_for_read_only_project_request(project_id, session.connection_id)
2187 .await?;
2188 let payload = session
2189 .peer
2190 .forward_request(session.connection_id, host_connection_id, request)
2191 .await?;
2192 response.send(payload)?;
2193 Ok(())
2194}
2195
2196async fn forward_find_search_candidates_request(
2197 request: proto::FindSearchCandidates,
2198 response: Response<proto::FindSearchCandidates>,
2199 session: Session,
2200) -> Result<()> {
2201 let project_id = ProjectId::from_proto(request.remote_entity_id());
2202 let host_connection_id = session
2203 .db()
2204 .await
2205 .host_for_read_only_project_request(project_id, session.connection_id)
2206 .await?;
2207 let payload = session
2208 .peer
2209 .forward_request(session.connection_id, host_connection_id, request)
2210 .await?;
2211 response.send(payload)?;
2212 Ok(())
2213}
2214
2215/// forward a project request to the host. These requests are disallowed
2216/// for guests.
2217async fn forward_mutating_project_request<T>(
2218 request: T,
2219 response: Response<T>,
2220 session: Session,
2221) -> Result<()>
2222where
2223 T: EntityMessage + RequestMessage,
2224{
2225 let project_id = ProjectId::from_proto(request.remote_entity_id());
2226
2227 let host_connection_id = session
2228 .db()
2229 .await
2230 .host_for_mutating_project_request(project_id, session.connection_id)
2231 .await?;
2232 let payload = session
2233 .peer
2234 .forward_request(session.connection_id, host_connection_id, request)
2235 .await?;
2236 response.send(payload)?;
2237 Ok(())
2238}
2239
2240/// Notify other participants that a new buffer has been created
2241async fn create_buffer_for_peer(
2242 request: proto::CreateBufferForPeer,
2243 session: Session,
2244) -> Result<()> {
2245 session
2246 .db()
2247 .await
2248 .check_user_is_project_host(
2249 ProjectId::from_proto(request.project_id),
2250 session.connection_id,
2251 )
2252 .await?;
2253 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2254 session
2255 .peer
2256 .forward_send(session.connection_id, peer_id.into(), request)?;
2257 Ok(())
2258}
2259
2260/// Notify other participants that a buffer has been updated. This is
2261/// allowed for guests as long as the update is limited to selections.
2262async fn update_buffer(
2263 request: proto::UpdateBuffer,
2264 response: Response<proto::UpdateBuffer>,
2265 session: Session,
2266) -> Result<()> {
2267 let project_id = ProjectId::from_proto(request.project_id);
2268 let mut capability = Capability::ReadOnly;
2269
2270 for op in request.operations.iter() {
2271 match op.variant {
2272 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2273 Some(_) => capability = Capability::ReadWrite,
2274 }
2275 }
2276
2277 let host = {
2278 let guard = session
2279 .db()
2280 .await
2281 .connections_for_buffer_update(project_id, session.connection_id, capability)
2282 .await?;
2283
2284 let (host, guests) = &*guard;
2285
2286 broadcast(
2287 Some(session.connection_id),
2288 guests.clone(),
2289 |connection_id| {
2290 session
2291 .peer
2292 .forward_send(session.connection_id, connection_id, request.clone())
2293 },
2294 );
2295
2296 *host
2297 };
2298
2299 if host != session.connection_id {
2300 session
2301 .peer
2302 .forward_request(session.connection_id, host, request.clone())
2303 .await?;
2304 }
2305
2306 response.send(proto::Ack {})?;
2307 Ok(())
2308}
2309
2310async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2311 let project_id = ProjectId::from_proto(message.project_id);
2312
2313 let operation = message.operation.as_ref().context("invalid operation")?;
2314 let capability = match operation.variant.as_ref() {
2315 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2316 if let Some(buffer_op) = buffer_op.operation.as_ref() {
2317 match buffer_op.variant {
2318 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2319 Capability::ReadOnly
2320 }
2321 _ => Capability::ReadWrite,
2322 }
2323 } else {
2324 Capability::ReadWrite
2325 }
2326 }
2327 Some(_) => Capability::ReadWrite,
2328 None => Capability::ReadOnly,
2329 };
2330
2331 let guard = session
2332 .db()
2333 .await
2334 .connections_for_buffer_update(project_id, session.connection_id, capability)
2335 .await?;
2336
2337 let (host, guests) = &*guard;
2338
2339 broadcast(
2340 Some(session.connection_id),
2341 guests.iter().chain([host]).copied(),
2342 |connection_id| {
2343 session
2344 .peer
2345 .forward_send(session.connection_id, connection_id, message.clone())
2346 },
2347 );
2348
2349 Ok(())
2350}
2351
2352/// Notify other participants that a project has been updated.
2353async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2354 request: T,
2355 session: Session,
2356) -> Result<()> {
2357 let project_id = ProjectId::from_proto(request.remote_entity_id());
2358 let project_connection_ids = session
2359 .db()
2360 .await
2361 .project_connection_ids(project_id, session.connection_id, false)
2362 .await?;
2363
2364 broadcast(
2365 Some(session.connection_id),
2366 project_connection_ids.iter().copied(),
2367 |connection_id| {
2368 session
2369 .peer
2370 .forward_send(session.connection_id, connection_id, request.clone())
2371 },
2372 );
2373 Ok(())
2374}
2375
/// Start following another user in a call.
///
/// The handler runs these steps in order:
/// 1. Checks the leader and the requesting connection against the room with
///    `check_room_participants`.
/// 2. Forwards the `Follow` request to the leader.
/// 3. Relays the leader's reply back to the follower.
/// 4. If the request names a project, records the follow in the database and
///    pushes the resulting room state out with `room_updated`.
///
/// Fails if `leader_id` is missing, if the room check fails, or if the
/// forwarded request to the leader fails.
async fn follow(
    request: proto::Follow,
    response: Response<proto::Follow>,
    session: Session,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let project_id = request.project_id.map(ProjectId::from_proto);
    let leader_id = request
        .leader_id
        .ok_or_else(|| anyhow!("invalid leader id"))?
        .into();
    let follower_id = session.connection_id;

    session
        .db()
        .await
        .check_room_participants(room_id, leader_id, session.connection_id)
        .await?;

    // The leader answers the follow request itself; its payload is passed back
    // to the follower unchanged.
    let response_payload = session
        .peer
        .forward_request(session.connection_id, leader_id, request)
        .await?;
    response.send(response_payload)?;

    // The follow is persisted only after the follower has been answered, so an
    // error from here on is returned from the handler, not via `response`.
    if let Some(project_id) = project_id {
        let room = session
            .db()
            .await
            .follow(room_id, project_id, leader_id, follower_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    Ok(())
}
2413
/// Stop following another user in a call.
///
/// Checks the leader and the requesting connection against the room with
/// `check_room_participants`, then forwards the `Unfollow` message to the
/// leader as a one-way send. If the message names a project, the unfollow is
/// then recorded in the database and the resulting room state is pushed out
/// with `room_updated`.
async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let project_id = request.project_id.map(ProjectId::from_proto);
    let leader_id = request
        .leader_id
        .ok_or_else(|| anyhow!("invalid leader id"))?
        .into();
    let follower_id = session.connection_id;

    session
        .db()
        .await
        .check_room_participants(room_id, leader_id, session.connection_id)
        .await?;

    // Unlike `follow`, nothing is awaited from the leader here.
    session
        .peer
        .forward_send(session.connection_id, leader_id, request)?;

    if let Some(project_id) = project_id {
        let room = session
            .db()
            .await
            .unfollow(room_id, project_id, leader_id, follower_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    Ok(())
}
2445
2446/// Notify everyone following you of your current location.
2447async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2448 let room_id = RoomId::from_proto(request.room_id);
2449 let database = session.db.lock().await;
2450
2451 let connection_ids = if let Some(project_id) = request.project_id {
2452 let project_id = ProjectId::from_proto(project_id);
2453 database
2454 .project_connection_ids(project_id, session.connection_id, true)
2455 .await?
2456 } else {
2457 database
2458 .room_connection_ids(room_id, session.connection_id)
2459 .await?
2460 };
2461
2462 // For now, don't send view update messages back to that view's current leader.
2463 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2464 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2465 _ => None,
2466 });
2467
2468 for connection_id in connection_ids.iter().cloned() {
2469 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2470 session
2471 .peer
2472 .forward_send(session.connection_id, connection_id, request.clone())?;
2473 }
2474 }
2475 Ok(())
2476}
2477
2478/// Get public data about users.
2479async fn get_users(
2480 request: proto::GetUsers,
2481 response: Response<proto::GetUsers>,
2482 session: Session,
2483) -> Result<()> {
2484 let user_ids = request
2485 .user_ids
2486 .into_iter()
2487 .map(UserId::from_proto)
2488 .collect();
2489 let users = session
2490 .db()
2491 .await
2492 .get_users_by_ids(user_ids)
2493 .await?
2494 .into_iter()
2495 .map(|user| proto::User {
2496 id: user.id.to_proto(),
2497 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2498 github_login: user.github_login,
2499 email: user.email_address,
2500 name: user.name,
2501 })
2502 .collect();
2503 response.send(proto::UsersResponse { users })?;
2504 Ok(())
2505}
2506
2507/// Search for users (to invite) buy Github login
2508async fn fuzzy_search_users(
2509 request: proto::FuzzySearchUsers,
2510 response: Response<proto::FuzzySearchUsers>,
2511 session: Session,
2512) -> Result<()> {
2513 let query = request.query;
2514 let users = match query.len() {
2515 0 => vec![],
2516 1 | 2 => session
2517 .db()
2518 .await
2519 .get_user_by_github_login(&query)
2520 .await?
2521 .into_iter()
2522 .collect(),
2523 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2524 };
2525 let users = users
2526 .into_iter()
2527 .filter(|user| user.id != session.user_id())
2528 .map(|user| proto::User {
2529 id: user.id.to_proto(),
2530 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2531 github_login: user.github_login,
2532 name: user.name,
2533 email: user.email_address,
2534 })
2535 .collect();
2536 response.send(proto::UsersResponse { users })?;
2537 Ok(())
2538}
2539
/// Send a contact request to another user.
///
/// Rejects a request addressed to the requester themself and records the
/// request in the database. It then sends `UpdateContacts` to:
/// - every connection of the requester (new outgoing request);
/// - every connection of the responder (new incoming request).
///
/// Finally it delivers the notifications the database returned and acks.
async fn request_contact(
    request: proto::RequestContact,
    response: Response<proto::RequestContact>,
    session: Session,
) -> Result<()> {
    let requester_id = session.user_id();
    let responder_id = UserId::from_proto(request.responder_id);
    if requester_id == responder_id {
        return Err(anyhow!("cannot add yourself as a contact"))?;
    }

    let notifications = session
        .db()
        .await
        .send_contact_request(requester_id, responder_id)
        .await?;

    // Update outgoing contact requests of requester
    let mut update = proto::UpdateContacts::default();
    update.outgoing_requests.push(responder_id.to_proto());
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(requester_id)
    {
        session.peer.send(connection_id, update.clone())?;
    }

    // Update incoming contact requests of responder
    // (the connection pool is acquired again here and kept for the
    // notifications below)
    let mut update = proto::UpdateContacts::default();
    update
        .incoming_requests
        .push(proto::IncomingContactRequest {
            requester_id: requester_id.to_proto(),
        });
    let connection_pool = session.connection_pool().await;
    for connection_id in connection_pool.user_connection_ids(responder_id) {
        session.peer.send(connection_id, update.clone())?;
    }

    send_notifications(&connection_pool, &session.peer, notifications);

    response.send(proto::Ack {})?;
    Ok(())
}
2586
/// Accept, decline, or dismiss a contact request.
///
/// `Dismiss` only clears the related notification in the database. Accept and
/// decline record the answer, and both users' connections receive an
/// `UpdateContacts` that removes the pending request. On accept, that update
/// also adds the other user as a contact, with their busy state. Any
/// notifications the database returned are then delivered.
async fn respond_to_contact_request(
    request: proto::RespondToContactRequest,
    response: Response<proto::RespondToContactRequest>,
    session: Session,
) -> Result<()> {
    let responder_id = session.user_id();
    let requester_id = UserId::from_proto(request.requester_id);
    // The database guard is held for the rest of the handler, including while
    // the connection pool is locked below.
    let db = session.db().await;
    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
        db.dismiss_contact_notification(responder_id, requester_id)
            .await?;
    } else {
        let accept = request.response == proto::ContactRequestResponse::Accept as i32;

        let notifications = db
            .respond_to_contact_request(responder_id, requester_id, accept)
            .await?;
        // Busy state is included in the contact entries sent on accept.
        let requester_busy = db.is_user_busy(requester_id).await?;
        let responder_busy = db.is_user_busy(responder_id).await?;

        let pool = session.connection_pool().await;
        // Update responder with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(requester_id, requester_busy, &pool));
        }
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
        for connection_id in pool.user_connection_ids(responder_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        // Update requester with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(responder_id, responder_busy, &pool));
        }
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());

        for connection_id in pool.user_connection_ids(requester_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        send_notifications(&pool, &session.peer, notifications);
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2644
/// Remove a contact, or withdraw a pending contact request.
///
/// The database reports whether the contact had been accepted. Accordingly,
/// both users' connections receive an `UpdateContacts` that removes either the
/// contact or the pending outgoing/incoming request. If the database deleted a
/// notification, the other user's connections are also told to delete it.
async fn remove_contact(
    request: proto::RemoveContact,
    response: Response<proto::RemoveContact>,
    session: Session,
) -> Result<()> {
    let requester_id = session.user_id();
    let responder_id = UserId::from_proto(request.user_id);
    let db = session.db().await;
    let (contact_accepted, deleted_notification_id) =
        db.remove_contact(requester_id, responder_id).await?;

    let pool = session.connection_pool().await;
    // Update outgoing contact requests of requester
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(responder_id.to_proto());
    } else {
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(requester_id) {
        session.peer.send(connection_id, update.clone())?;
    }

    // Update incoming contact requests of responder
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(requester_id.to_proto());
    } else {
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(responder_id) {
        session.peer.send(connection_id, update.clone())?;
        // Only the responder's connections are told to drop the deleted notification.
        if let Some(notification_id) = deleted_notification_id {
            session.peer.send(
                connection_id,
                proto::DeleteNotification {
                    notification_id: notification_id.to_proto(),
                },
            )?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2695
2696fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2697 version.0.minor() < 139
2698}
2699
2700async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
2701 let plan = session.current_plan(&session.db().await).await?;
2702
2703 session
2704 .peer
2705 .send(
2706 session.connection_id,
2707 proto::UpdateUserPlan { plan: plan.into() },
2708 )
2709 .trace_err();
2710
2711 Ok(())
2712}
2713
2714async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2715 subscribe_user_to_channels(session.user_id(), &session).await?;
2716 Ok(())
2717}
2718
2719async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2720 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2721 let mut pool = session.connection_pool().await;
2722 for membership in &channels_for_user.channel_memberships {
2723 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2724 }
2725 session.peer.send(
2726 session.connection_id,
2727 build_update_user_channels(&channels_for_user),
2728 )?;
2729 session.peer.send(
2730 session.connection_id,
2731 build_channels_update(channels_for_user),
2732 )?;
2733 Ok(())
2734}
2735
2736/// Creates a new channel.
2737async fn create_channel(
2738 request: proto::CreateChannel,
2739 response: Response<proto::CreateChannel>,
2740 session: Session,
2741) -> Result<()> {
2742 let db = session.db().await;
2743
2744 let parent_id = request.parent_id.map(ChannelId::from_proto);
2745 let (channel, membership) = db
2746 .create_channel(&request.name, parent_id, session.user_id())
2747 .await?;
2748
2749 let root_id = channel.root_id();
2750 let channel = Channel::from_model(channel);
2751
2752 response.send(proto::CreateChannelResponse {
2753 channel: Some(channel.to_proto()),
2754 parent_id: request.parent_id,
2755 })?;
2756
2757 let mut connection_pool = session.connection_pool().await;
2758 if let Some(membership) = membership {
2759 connection_pool.subscribe_to_channel(
2760 membership.user_id,
2761 membership.channel_id,
2762 membership.role,
2763 );
2764 let update = proto::UpdateUserChannels {
2765 channel_memberships: vec![proto::ChannelMembership {
2766 channel_id: membership.channel_id.to_proto(),
2767 role: membership.role.into(),
2768 }],
2769 ..Default::default()
2770 };
2771 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2772 session.peer.send(connection_id, update.clone())?;
2773 }
2774 }
2775
2776 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2777 if !role.can_see_channel(channel.visibility) {
2778 continue;
2779 }
2780
2781 let update = proto::UpdateChannels {
2782 channels: vec![channel.to_proto()],
2783 ..Default::default()
2784 };
2785 session.peer.send(connection_id, update.clone())?;
2786 }
2787
2788 Ok(())
2789}
2790
2791/// Delete a channel
2792async fn delete_channel(
2793 request: proto::DeleteChannel,
2794 response: Response<proto::DeleteChannel>,
2795 session: Session,
2796) -> Result<()> {
2797 let db = session.db().await;
2798
2799 let channel_id = request.channel_id;
2800 let (root_channel, removed_channels) = db
2801 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2802 .await?;
2803 response.send(proto::Ack {})?;
2804
2805 // Notify members of removed channels
2806 let mut update = proto::UpdateChannels::default();
2807 update
2808 .delete_channels
2809 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2810
2811 let connection_pool = session.connection_pool().await;
2812 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2813 session.peer.send(connection_id, update.clone())?;
2814 }
2815
2816 Ok(())
2817}
2818
2819/// Invite someone to join a channel.
2820async fn invite_channel_member(
2821 request: proto::InviteChannelMember,
2822 response: Response<proto::InviteChannelMember>,
2823 session: Session,
2824) -> Result<()> {
2825 let db = session.db().await;
2826 let channel_id = ChannelId::from_proto(request.channel_id);
2827 let invitee_id = UserId::from_proto(request.user_id);
2828 let InviteMemberResult {
2829 channel,
2830 notifications,
2831 } = db
2832 .invite_channel_member(
2833 channel_id,
2834 invitee_id,
2835 session.user_id(),
2836 request.role().into(),
2837 )
2838 .await?;
2839
2840 let update = proto::UpdateChannels {
2841 channel_invitations: vec![channel.to_proto()],
2842 ..Default::default()
2843 };
2844
2845 let connection_pool = session.connection_pool().await;
2846 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2847 session.peer.send(connection_id, update.clone())?;
2848 }
2849
2850 send_notifications(&connection_pool, &session.peer, notifications);
2851
2852 response.send(proto::Ack {})?;
2853 Ok(())
2854}
2855
2856/// remove someone from a channel
2857async fn remove_channel_member(
2858 request: proto::RemoveChannelMember,
2859 response: Response<proto::RemoveChannelMember>,
2860 session: Session,
2861) -> Result<()> {
2862 let db = session.db().await;
2863 let channel_id = ChannelId::from_proto(request.channel_id);
2864 let member_id = UserId::from_proto(request.user_id);
2865
2866 let RemoveChannelMemberResult {
2867 membership_update,
2868 notification_id,
2869 } = db
2870 .remove_channel_member(channel_id, member_id, session.user_id())
2871 .await?;
2872
2873 let mut connection_pool = session.connection_pool().await;
2874 notify_membership_updated(
2875 &mut connection_pool,
2876 membership_update,
2877 member_id,
2878 &session.peer,
2879 );
2880 for connection_id in connection_pool.user_connection_ids(member_id) {
2881 if let Some(notification_id) = notification_id {
2882 session
2883 .peer
2884 .send(
2885 connection_id,
2886 proto::DeleteNotification {
2887 notification_id: notification_id.to_proto(),
2888 },
2889 )
2890 .trace_err();
2891 }
2892 }
2893
2894 response.send(proto::Ack {})?;
2895 Ok(())
2896}
2897
/// Toggle the channel between public and private.
/// Care is taken to maintain the invariant that public channels only descend from public channels,
/// (though members-only channels can appear at any point in the hierarchy).
///
/// After the database update, every user subscribed to the channel's root is
/// either subscribed to the channel and sent it (if their role can now see it),
/// or unsubscribed from it and told to delete it.
async fn set_channel_visibility(
    request: proto::SetChannelVisibility,
    response: Response<proto::SetChannelVisibility>,
    session: Session,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let visibility = request.visibility().into();

    let channel_model = db
        .set_channel_visibility(channel_id, visibility, session.user_id())
        .await?;
    let root_id = channel_model.root_id();
    let channel = Channel::from_model(channel_model);

    let mut connection_pool = session.connection_pool().await;
    // Collect first: the loop body mutates `connection_pool` (subscribe /
    // unsubscribe), which cannot happen while `channel_user_ids` borrows it.
    for (user_id, role) in connection_pool
        .channel_user_ids(root_id)
        .collect::<Vec<_>>()
        .into_iter()
    {
        let update = if role.can_see_channel(channel.visibility) {
            connection_pool.subscribe_to_channel(user_id, channel_id, role);
            proto::UpdateChannels {
                channels: vec![channel.to_proto()],
                ..Default::default()
            }
        } else {
            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
            proto::UpdateChannels {
                delete_channels: vec![channel.id.to_proto()],
                ..Default::default()
            }
        };

        for connection_id in connection_pool.user_connection_ids(user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2944
/// Alter the role for a user in the channel.
///
/// The database reports one of two outcomes:
/// - the member's existing membership changed: connected clients are notified
///   via `notify_membership_updated`;
/// - only a pending invitation changed: the updated invitation is pushed to the
///   member's connections.
async fn set_channel_member_role(
    request: proto::SetChannelMemberRole,
    response: Response<proto::SetChannelMemberRole>,
    session: Session,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let member_id = UserId::from_proto(request.user_id);
    let result = db
        .set_channel_member_role(
            channel_id,
            session.user_id(),
            member_id,
            request.role().into(),
        )
        .await?;

    match result {
        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
            let mut connection_pool = session.connection_pool().await;
            notify_membership_updated(
                &mut connection_pool,
                membership_update,
                member_id,
                &session.peer,
            )
        }
        db::SetMemberRoleResult::InviteUpdated(channel) => {
            // The member has not joined yet, so they only see the refreshed invitation.
            let update = proto::UpdateChannels {
                channel_invitations: vec![channel.to_proto()],
                ..Default::default()
            };

            for connection_id in session
                .connection_pool()
                .await
                .user_connection_ids(member_id)
            {
                session.peer.send(connection_id, update.clone())?;
            }
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2992
2993/// Change the name of a channel
2994async fn rename_channel(
2995 request: proto::RenameChannel,
2996 response: Response<proto::RenameChannel>,
2997 session: Session,
2998) -> Result<()> {
2999 let db = session.db().await;
3000 let channel_id = ChannelId::from_proto(request.channel_id);
3001 let channel_model = db
3002 .rename_channel(channel_id, session.user_id(), &request.name)
3003 .await?;
3004 let root_id = channel_model.root_id();
3005 let channel = Channel::from_model(channel_model);
3006
3007 response.send(proto::RenameChannelResponse {
3008 channel: Some(channel.to_proto()),
3009 })?;
3010
3011 let connection_pool = session.connection_pool().await;
3012 let update = proto::UpdateChannels {
3013 channels: vec![channel.to_proto()],
3014 ..Default::default()
3015 };
3016 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3017 if role.can_see_channel(channel.visibility) {
3018 session.peer.send(connection_id, update.clone())?;
3019 }
3020 }
3021
3022 Ok(())
3023}
3024
3025/// Move a channel to a new parent.
3026async fn move_channel(
3027 request: proto::MoveChannel,
3028 response: Response<proto::MoveChannel>,
3029 session: Session,
3030) -> Result<()> {
3031 let channel_id = ChannelId::from_proto(request.channel_id);
3032 let to = ChannelId::from_proto(request.to);
3033
3034 let (root_id, channels) = session
3035 .db()
3036 .await
3037 .move_channel(channel_id, to, session.user_id())
3038 .await?;
3039
3040 let connection_pool = session.connection_pool().await;
3041 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3042 let channels = channels
3043 .iter()
3044 .filter_map(|channel| {
3045 if role.can_see_channel(channel.visibility) {
3046 Some(channel.to_proto())
3047 } else {
3048 None
3049 }
3050 })
3051 .collect::<Vec<_>>();
3052 if channels.is_empty() {
3053 continue;
3054 }
3055
3056 let update = proto::UpdateChannels {
3057 channels,
3058 ..Default::default()
3059 };
3060
3061 session.peer.send(connection_id, update.clone())?;
3062 }
3063
3064 response.send(Ack {})?;
3065 Ok(())
3066}
3067
3068/// Get the list of channel members
3069async fn get_channel_members(
3070 request: proto::GetChannelMembers,
3071 response: Response<proto::GetChannelMembers>,
3072 session: Session,
3073) -> Result<()> {
3074 let db = session.db().await;
3075 let channel_id = ChannelId::from_proto(request.channel_id);
3076 let limit = if request.limit == 0 {
3077 u16::MAX as u64
3078 } else {
3079 request.limit
3080 };
3081 let (members, users) = db
3082 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3083 .await?;
3084 response.send(proto::GetChannelMembersResponse { members, users })?;
3085 Ok(())
3086}
3087
/// Accept or decline a channel invitation.
///
/// If the database returns a membership update, connected clients are notified
/// through `notify_membership_updated`. Otherwise (presumably a declined
/// invite), the invitation is removed from the user's connections. Any
/// notifications the database produced are delivered in both cases.
async fn respond_to_channel_invite(
    request: proto::RespondToChannelInvite,
    response: Response<proto::RespondToChannelInvite>,
    session: Session,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let RespondToChannelInvite {
        membership_update,
        notifications,
    } = db
        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
        .await?;

    let mut connection_pool = session.connection_pool().await;
    if let Some(membership_update) = membership_update {
        notify_membership_updated(
            &mut connection_pool,
            membership_update,
            session.user_id(),
            &session.peer,
        );
    } else {
        let update = proto::UpdateChannels {
            remove_channel_invitations: vec![channel_id.to_proto()],
            ..Default::default()
        };

        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
            session.peer.send(connection_id, update.clone())?;
        }
    };

    send_notifications(&connection_pool, &session.peer, notifications);

    response.send(proto::Ack {})?;

    Ok(())
}
3128
3129/// Join the channels' room
3130async fn join_channel(
3131 request: proto::JoinChannel,
3132 response: Response<proto::JoinChannel>,
3133 session: Session,
3134) -> Result<()> {
3135 let channel_id = ChannelId::from_proto(request.channel_id);
3136 join_channel_internal(channel_id, Box::new(response), session).await
3137}
3138
/// A response handle that can deliver a `proto::JoinRoomResponse`.
///
/// Implemented for both `Response<proto::JoinChannel>` and
/// `Response<proto::JoinRoom>`, so `join_channel_internal` can answer either
/// request type.
trait JoinChannelInternalResponse {
    /// Consumes the handle and sends `result` to the requester.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Fully qualified to call the inherent `Response::send`, not this trait method.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Fully qualified to call the inherent `Response::send`, not this trait method.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3152
/// Joins the room of `channel_id` for the session's user and replies through
/// `response`.
///
/// Steps:
/// 1. Leaves any stale room connection left over from a previous session.
/// 2. Joins the channel in the database.
/// 3. Builds LiveKit connection info when a LiveKit client is configured.
///    Guests get a guest token and cannot publish; others get a room token.
///    A token error is traced and yields no connection info.
/// 4. Replies, notifies membership changes, and broadcasts the room and
///    channel updates.
/// 5. Refreshes the user's contacts.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database guard while leaving the stale room, then
            // re-acquire it for the join below.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured or token creation failed
        // (the error is traced by `trace_err`).
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        // The database and connection-pool guards are released when this block ends.
        joined_room
    };

    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3248
3249/// Start editing the channel notes
3250async fn join_channel_buffer(
3251 request: proto::JoinChannelBuffer,
3252 response: Response<proto::JoinChannelBuffer>,
3253 session: Session,
3254) -> Result<()> {
3255 let db = session.db().await;
3256 let channel_id = ChannelId::from_proto(request.channel_id);
3257
3258 let open_response = db
3259 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3260 .await?;
3261
3262 let collaborators = open_response.collaborators.clone();
3263 response.send(open_response)?;
3264
3265 let update = UpdateChannelBufferCollaborators {
3266 channel_id: channel_id.to_proto(),
3267 collaborators: collaborators.clone(),
3268 };
3269 channel_buffer_updated(
3270 session.connection_id,
3271 collaborators
3272 .iter()
3273 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3274 &update,
3275 &session.peer,
3276 );
3277
3278 Ok(())
3279}
3280
3281/// Edit the channel notes
3282async fn update_channel_buffer(
3283 request: proto::UpdateChannelBuffer,
3284 session: Session,
3285) -> Result<()> {
3286 let db = session.db().await;
3287 let channel_id = ChannelId::from_proto(request.channel_id);
3288
3289 let (collaborators, epoch, version) = db
3290 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3291 .await?;
3292
3293 channel_buffer_updated(
3294 session.connection_id,
3295 collaborators.clone(),
3296 &proto::UpdateChannelBuffer {
3297 channel_id: channel_id.to_proto(),
3298 operations: request.operations,
3299 },
3300 &session.peer,
3301 );
3302
3303 let pool = &*session.connection_pool().await;
3304
3305 let non_collaborators =
3306 pool.channel_connection_ids(channel_id)
3307 .filter_map(|(connection_id, _)| {
3308 if collaborators.contains(&connection_id) {
3309 None
3310 } else {
3311 Some(connection_id)
3312 }
3313 });
3314
3315 broadcast(None, non_collaborators, |peer_id| {
3316 session.peer.send(
3317 peer_id,
3318 proto::UpdateChannels {
3319 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3320 channel_id: channel_id.to_proto(),
3321 epoch: epoch as u64,
3322 version: version.clone(),
3323 }],
3324 ..Default::default()
3325 },
3326 )
3327 });
3328
3329 Ok(())
3330}
3331
3332/// Rejoin the channel notes after a connection blip
3333async fn rejoin_channel_buffers(
3334 request: proto::RejoinChannelBuffers,
3335 response: Response<proto::RejoinChannelBuffers>,
3336 session: Session,
3337) -> Result<()> {
3338 let db = session.db().await;
3339 let buffers = db
3340 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3341 .await?;
3342
3343 for rejoined_buffer in &buffers {
3344 let collaborators_to_notify = rejoined_buffer
3345 .buffer
3346 .collaborators
3347 .iter()
3348 .filter_map(|c| Some(c.peer_id?.into()));
3349 channel_buffer_updated(
3350 session.connection_id,
3351 collaborators_to_notify,
3352 &proto::UpdateChannelBufferCollaborators {
3353 channel_id: rejoined_buffer.buffer.channel_id,
3354 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3355 },
3356 &session.peer,
3357 );
3358 }
3359
3360 response.send(proto::RejoinChannelBuffersResponse {
3361 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3362 })?;
3363
3364 Ok(())
3365}
3366
3367/// Stop editing the channel notes
3368async fn leave_channel_buffer(
3369 request: proto::LeaveChannelBuffer,
3370 response: Response<proto::LeaveChannelBuffer>,
3371 session: Session,
3372) -> Result<()> {
3373 let db = session.db().await;
3374 let channel_id = ChannelId::from_proto(request.channel_id);
3375
3376 let left_buffer = db
3377 .leave_channel_buffer(channel_id, session.connection_id)
3378 .await?;
3379
3380 response.send(Ack {})?;
3381
3382 channel_buffer_updated(
3383 session.connection_id,
3384 left_buffer.connections,
3385 &proto::UpdateChannelBufferCollaborators {
3386 channel_id: channel_id.to_proto(),
3387 collaborators: left_buffer.collaborators,
3388 },
3389 &session.peer,
3390 );
3391
3392 Ok(())
3393}
3394
3395fn channel_buffer_updated<T: EnvelopedMessage>(
3396 sender_id: ConnectionId,
3397 collaborators: impl IntoIterator<Item = ConnectionId>,
3398 message: &T,
3399 peer: &Peer,
3400) {
3401 broadcast(Some(sender_id), collaborators, |peer_id| {
3402 peer.send(peer_id, message.clone())
3403 });
3404}
3405
3406fn send_notifications(
3407 connection_pool: &ConnectionPool,
3408 peer: &Peer,
3409 notifications: db::NotificationBatch,
3410) {
3411 for (user_id, notification) in notifications {
3412 for connection_id in connection_pool.user_connection_ids(user_id) {
3413 if let Err(error) = peer.send(
3414 connection_id,
3415 proto::AddNotification {
3416 notification: Some(notification.clone()),
3417 },
3418 ) {
3419 tracing::error!(
3420 "failed to send notification to {:?} {}",
3421 connection_id,
3422 error
3423 );
3424 }
3425 }
3426 }
3427}
3428
/// Send a message to the channel
///
/// The body is trimmed first. It must then be non-empty and at most
/// `MAX_MESSAGE_LEN` bytes, and the request must carry a `nonce`. After the
/// message is stored in the database:
/// - the chat participants returned by the database receive `ChannelMessageSent`
///   (the sender's connection is passed to `broadcast` as the sender),
/// - the sender gets the stored message back in the response,
/// - connections subscribed to the channel that are not chat participants
///   receive `UpdateChannels` with the channel's latest message id, and
/// - any notifications returned by the database are delivered.
async fn send_channel_message(
    request: proto::SendChannelMessage,
    response: Response<proto::SendChannelMessage>,
    session: Session,
) -> Result<()> {
    // Validate the message body.
    let body = request.body.trim().to_string();
    // `String::len` counts bytes, so this is a byte limit on the trimmed body.
    if body.len() > MAX_MESSAGE_LEN {
        return Err(anyhow!("message is too long"))?;
    }
    if body.is_empty() {
        return Err(anyhow!("message can't be blank"))?;
    }

    // TODO: adjust mentions if body is trimmed
    // NOTE(review): `request.mentions` is forwarded unchanged, so if mentions are
    // offsets into the untrimmed body they may be shifted — confirm.

    let timestamp = OffsetDateTime::now_utc();
    let nonce = request
        .nonce
        .ok_or_else(|| anyhow!("nonce can't be blank"))?;

    let channel_id = ChannelId::from_proto(request.channel_id);
    let CreatedChannelMessage {
        message_id,
        participant_connection_ids,
        notifications,
    } = session
        .db()
        .await
        .create_channel_message(
            channel_id,
            session.user_id(),
            &body,
            &request.mentions,
            timestamp,
            nonce.clone().into(),
            request.reply_to_message_id.map(MessageId::from_proto),
        )
        .await?;

    let message = proto::ChannelMessage {
        sender_id: session.user_id().to_proto(),
        id: message_id.to_proto(),
        body,
        mentions: request.mentions,
        timestamp: timestamp.unix_timestamp() as u64,
        nonce: Some(nonce),
        reply_to_message_id: request.reply_to_message_id,
        edited_at: None,
    };
    broadcast(
        Some(session.connection_id),
        participant_connection_ids.clone(),
        |connection| {
            session.peer.send(
                connection,
                proto::ChannelMessageSent {
                    channel_id: channel_id.to_proto(),
                    message: Some(message.clone()),
                },
            )
        },
    );
    response.send(proto::SendChannelMessageResponse {
        message: Some(message),
    })?;

    // Channel subscribers that are not in the chat only learn the latest message
    // id, not the message itself.
    let pool = &*session.connection_pool().await;
    let non_participants =
        pool.channel_connection_ids(channel_id)
            .filter_map(|(connection_id, _)| {
                if participant_connection_ids.contains(&connection_id) {
                    None
                } else {
                    Some(connection_id)
                }
            });
    broadcast(None, non_participants, |peer_id| {
        session.peer.send(
            peer_id,
            proto::UpdateChannels {
                latest_channel_message_ids: vec![proto::ChannelMessageId {
                    channel_id: channel_id.to_proto(),
                    message_id: message_id.to_proto(),
                }],
                ..Default::default()
            },
        )
    });
    send_notifications(pool, &session.peer, notifications);

    Ok(())
}
3523
3524/// Delete a channel message
3525async fn remove_channel_message(
3526 request: proto::RemoveChannelMessage,
3527 response: Response<proto::RemoveChannelMessage>,
3528 session: Session,
3529) -> Result<()> {
3530 let channel_id = ChannelId::from_proto(request.channel_id);
3531 let message_id = MessageId::from_proto(request.message_id);
3532 let (connection_ids, existing_notification_ids) = session
3533 .db()
3534 .await
3535 .remove_channel_message(channel_id, message_id, session.user_id())
3536 .await?;
3537
3538 broadcast(
3539 Some(session.connection_id),
3540 connection_ids,
3541 move |connection| {
3542 session.peer.send(connection, request.clone())?;
3543
3544 for notification_id in &existing_notification_ids {
3545 session.peer.send(
3546 connection,
3547 proto::DeleteNotification {
3548 notification_id: (*notification_id).to_proto(),
3549 },
3550 )?;
3551 }
3552
3553 Ok(())
3554 },
3555 );
3556 response.send(proto::Ack {})?;
3557 Ok(())
3558}
3559
3560async fn update_channel_message(
3561 request: proto::UpdateChannelMessage,
3562 response: Response<proto::UpdateChannelMessage>,
3563 session: Session,
3564) -> Result<()> {
3565 let channel_id = ChannelId::from_proto(request.channel_id);
3566 let message_id = MessageId::from_proto(request.message_id);
3567 let updated_at = OffsetDateTime::now_utc();
3568 let UpdatedChannelMessage {
3569 message_id,
3570 participant_connection_ids,
3571 notifications,
3572 reply_to_message_id,
3573 timestamp,
3574 deleted_mention_notification_ids,
3575 updated_mention_notifications,
3576 } = session
3577 .db()
3578 .await
3579 .update_channel_message(
3580 channel_id,
3581 message_id,
3582 session.user_id(),
3583 request.body.as_str(),
3584 &request.mentions,
3585 updated_at,
3586 )
3587 .await?;
3588
3589 let nonce = request
3590 .nonce
3591 .clone()
3592 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3593
3594 let message = proto::ChannelMessage {
3595 sender_id: session.user_id().to_proto(),
3596 id: message_id.to_proto(),
3597 body: request.body.clone(),
3598 mentions: request.mentions.clone(),
3599 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3600 nonce: Some(nonce),
3601 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3602 edited_at: Some(updated_at.unix_timestamp() as u64),
3603 };
3604
3605 response.send(proto::Ack {})?;
3606
3607 let pool = &*session.connection_pool().await;
3608 broadcast(
3609 Some(session.connection_id),
3610 participant_connection_ids,
3611 |connection| {
3612 session.peer.send(
3613 connection,
3614 proto::ChannelMessageUpdate {
3615 channel_id: channel_id.to_proto(),
3616 message: Some(message.clone()),
3617 },
3618 )?;
3619
3620 for notification_id in &deleted_mention_notification_ids {
3621 session.peer.send(
3622 connection,
3623 proto::DeleteNotification {
3624 notification_id: (*notification_id).to_proto(),
3625 },
3626 )?;
3627 }
3628
3629 for notification in &updated_mention_notifications {
3630 session.peer.send(
3631 connection,
3632 proto::UpdateNotification {
3633 notification: Some(notification.clone()),
3634 },
3635 )?;
3636 }
3637
3638 Ok(())
3639 },
3640 );
3641
3642 send_notifications(pool, &session.peer, notifications);
3643
3644 Ok(())
3645}
3646
3647/// Mark a channel message as read
3648async fn acknowledge_channel_message(
3649 request: proto::AckChannelMessage,
3650 session: Session,
3651) -> Result<()> {
3652 let channel_id = ChannelId::from_proto(request.channel_id);
3653 let message_id = MessageId::from_proto(request.message_id);
3654 let notifications = session
3655 .db()
3656 .await
3657 .observe_channel_message(channel_id, session.user_id(), message_id)
3658 .await?;
3659 send_notifications(
3660 &*session.connection_pool().await,
3661 &session.peer,
3662 notifications,
3663 );
3664 Ok(())
3665}
3666
3667/// Mark a buffer version as synced
3668async fn acknowledge_buffer_version(
3669 request: proto::AckBufferOperation,
3670 session: Session,
3671) -> Result<()> {
3672 let buffer_id = BufferId::from_proto(request.buffer_id);
3673 session
3674 .db()
3675 .await
3676 .observe_buffer_version(
3677 buffer_id,
3678 session.user_id(),
3679 request.epoch as i32,
3680 &request.version,
3681 )
3682 .await?;
3683 Ok(())
3684}
3685
3686async fn count_language_model_tokens(
3687 request: proto::CountLanguageModelTokens,
3688 response: Response<proto::CountLanguageModelTokens>,
3689 session: Session,
3690 config: &Config,
3691) -> Result<()> {
3692 authorize_access_to_legacy_llm_endpoints(&session).await?;
3693
3694 let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3695 proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
3696 proto::Plan::Free | proto::Plan::ZedProTrial => {
3697 Box::new(FreeCountLanguageModelTokensRateLimit)
3698 }
3699 };
3700
3701 session
3702 .app_state
3703 .rate_limiter
3704 .check(&*rate_limit, session.user_id())
3705 .await?;
3706
3707 let result = match proto::LanguageModelProvider::from_i32(request.provider) {
3708 Some(proto::LanguageModelProvider::Google) => {
3709 let api_key = config
3710 .google_ai_api_key
3711 .as_ref()
3712 .context("no Google AI API key configured on the server")?;
3713 google_ai::count_tokens(
3714 session.http_client.as_ref(),
3715 google_ai::API_URL,
3716 api_key,
3717 serde_json::from_str(&request.request)?,
3718 )
3719 .await?
3720 }
3721 _ => return Err(anyhow!("unsupported provider"))?,
3722 };
3723
3724 response.send(proto::CountLanguageModelTokensResponse {
3725 token_count: result.total_tokens as u32,
3726 })?;
3727
3728 Ok(())
3729}
3730
3731struct ZedProCountLanguageModelTokensRateLimit;
3732
3733impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
3734 fn capacity(&self) -> usize {
3735 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
3736 .ok()
3737 .and_then(|v| v.parse().ok())
3738 .unwrap_or(600) // Picked arbitrarily
3739 }
3740
3741 fn refill_duration(&self) -> chrono::Duration {
3742 chrono::Duration::hours(1)
3743 }
3744
3745 fn db_name(&self) -> &'static str {
3746 "zed-pro:count-language-model-tokens"
3747 }
3748}
3749
3750struct FreeCountLanguageModelTokensRateLimit;
3751
3752impl RateLimit for FreeCountLanguageModelTokensRateLimit {
3753 fn capacity(&self) -> usize {
3754 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
3755 .ok()
3756 .and_then(|v| v.parse().ok())
3757 .unwrap_or(600 / 10) // Picked arbitrarily
3758 }
3759
3760 fn refill_duration(&self) -> chrono::Duration {
3761 chrono::Duration::hours(1)
3762 }
3763
3764 fn db_name(&self) -> &'static str {
3765 "free:count-language-model-tokens"
3766 }
3767}
3768
3769/// This is leftover from before the LLM service.
3770///
3771/// The endpoints protected by this check will be moved there eventually.
3772async fn authorize_access_to_legacy_llm_endpoints(session: &Session) -> Result<(), Error> {
3773 if session.is_staff() {
3774 Ok(())
3775 } else {
3776 Err(anyhow!("permission denied"))?
3777 }
3778}
3779
3780/// Get a Supermaven API key for the user
3781async fn get_supermaven_api_key(
3782 _request: proto::GetSupermavenApiKey,
3783 response: Response<proto::GetSupermavenApiKey>,
3784 session: Session,
3785) -> Result<()> {
3786 let user_id: String = session.user_id().to_string();
3787 if !session.is_staff() {
3788 return Err(anyhow!("supermaven not enabled for this account"))?;
3789 }
3790
3791 let email = session
3792 .email()
3793 .ok_or_else(|| anyhow!("user must have an email"))?;
3794
3795 let supermaven_admin_api = session
3796 .supermaven_client
3797 .as_ref()
3798 .ok_or_else(|| anyhow!("supermaven not configured"))?;
3799
3800 let result = supermaven_admin_api
3801 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3802 .await?;
3803
3804 response.send(proto::GetSupermavenApiKeyResponse {
3805 api_key: result.api_key,
3806 })?;
3807
3808 Ok(())
3809}
3810
3811/// Start receiving chat updates for a channel
3812async fn join_channel_chat(
3813 request: proto::JoinChannelChat,
3814 response: Response<proto::JoinChannelChat>,
3815 session: Session,
3816) -> Result<()> {
3817 let channel_id = ChannelId::from_proto(request.channel_id);
3818
3819 let db = session.db().await;
3820 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3821 .await?;
3822 let messages = db
3823 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3824 .await?;
3825 response.send(proto::JoinChannelChatResponse {
3826 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3827 messages,
3828 })?;
3829 Ok(())
3830}
3831
3832/// Stop receiving chat updates for a channel
3833async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3834 let channel_id = ChannelId::from_proto(request.channel_id);
3835 session
3836 .db()
3837 .await
3838 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3839 .await?;
3840 Ok(())
3841}
3842
3843/// Retrieve the chat history for a channel
3844async fn get_channel_messages(
3845 request: proto::GetChannelMessages,
3846 response: Response<proto::GetChannelMessages>,
3847 session: Session,
3848) -> Result<()> {
3849 let channel_id = ChannelId::from_proto(request.channel_id);
3850 let messages = session
3851 .db()
3852 .await
3853 .get_channel_messages(
3854 channel_id,
3855 session.user_id(),
3856 MESSAGE_COUNT_PER_PAGE,
3857 Some(MessageId::from_proto(request.before_message_id)),
3858 )
3859 .await?;
3860 response.send(proto::GetChannelMessagesResponse {
3861 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3862 messages,
3863 })?;
3864 Ok(())
3865}
3866
3867/// Retrieve specific chat messages
3868async fn get_channel_messages_by_id(
3869 request: proto::GetChannelMessagesById,
3870 response: Response<proto::GetChannelMessagesById>,
3871 session: Session,
3872) -> Result<()> {
3873 let message_ids = request
3874 .message_ids
3875 .iter()
3876 .map(|id| MessageId::from_proto(*id))
3877 .collect::<Vec<_>>();
3878 let messages = session
3879 .db()
3880 .await
3881 .get_channel_messages_by_id(session.user_id(), &message_ids)
3882 .await?;
3883 response.send(proto::GetChannelMessagesResponse {
3884 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3885 messages,
3886 })?;
3887 Ok(())
3888}
3889
3890/// Retrieve the current users notifications
3891async fn get_notifications(
3892 request: proto::GetNotifications,
3893 response: Response<proto::GetNotifications>,
3894 session: Session,
3895) -> Result<()> {
3896 let notifications = session
3897 .db()
3898 .await
3899 .get_notifications(
3900 session.user_id(),
3901 NOTIFICATION_COUNT_PER_PAGE,
3902 request.before_id.map(db::NotificationId::from_proto),
3903 )
3904 .await?;
3905 response.send(proto::GetNotificationsResponse {
3906 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3907 notifications,
3908 })?;
3909 Ok(())
3910}
3911
3912/// Mark notifications as read
3913async fn mark_notification_as_read(
3914 request: proto::MarkNotificationRead,
3915 response: Response<proto::MarkNotificationRead>,
3916 session: Session,
3917) -> Result<()> {
3918 let database = &session.db().await;
3919 let notifications = database
3920 .mark_notification_as_read_by_id(
3921 session.user_id(),
3922 NotificationId::from_proto(request.notification_id),
3923 )
3924 .await?;
3925 send_notifications(
3926 &*session.connection_pool().await,
3927 &session.peer,
3928 notifications,
3929 );
3930 response.send(proto::Ack {})?;
3931 Ok(())
3932}
3933
3934/// Get the current users information
3935async fn get_private_user_info(
3936 _request: proto::GetPrivateUserInfo,
3937 response: Response<proto::GetPrivateUserInfo>,
3938 session: Session,
3939) -> Result<()> {
3940 let db = session.db().await;
3941
3942 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
3943 let user = db
3944 .get_user_by_id(session.user_id())
3945 .await?
3946 .ok_or_else(|| anyhow!("user not found"))?;
3947 let flags = db.get_user_flags(session.user_id()).await?;
3948
3949 response.send(proto::GetPrivateUserInfoResponse {
3950 metrics_id,
3951 staff: user.admin,
3952 flags,
3953 accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
3954 })?;
3955 Ok(())
3956}
3957
3958/// Accept the terms of service (tos) on behalf of the current user
3959async fn accept_terms_of_service(
3960 _request: proto::AcceptTermsOfService,
3961 response: Response<proto::AcceptTermsOfService>,
3962 session: Session,
3963) -> Result<()> {
3964 let db = session.db().await;
3965
3966 let accepted_tos_at = Utc::now();
3967 db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
3968 .await?;
3969
3970 response.send(proto::AcceptTermsOfServiceResponse {
3971 accepted_tos_at: accepted_tos_at.timestamp() as u64,
3972 })?;
3973 Ok(())
3974}
3975
/// The minimum account age an account must have in order to use the LLM service.
///
/// Thirty days, expressed as a `chrono::Duration`.
pub const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
3978
3979async fn get_llm_api_token(
3980 _request: proto::GetLlmToken,
3981 response: Response<proto::GetLlmToken>,
3982 session: Session,
3983) -> Result<()> {
3984 let db = session.db().await;
3985
3986 let flags = db.get_user_flags(session.user_id()).await?;
3987 let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
3988
3989 if !session.is_staff() && !has_language_models_feature_flag {
3990 Err(anyhow!("permission denied"))?
3991 }
3992
3993 let user_id = session.user_id();
3994 let user = db
3995 .get_user_by_id(user_id)
3996 .await?
3997 .ok_or_else(|| anyhow!("user {} not found", user_id))?;
3998
3999 if user.accepted_tos_at.is_none() {
4000 Err(anyhow!("terms of service not accepted"))?
4001 }
4002
4003 let has_legacy_llm_subscription = session.has_llm_subscription(&db).await?;
4004 let billing_subscription = db.get_active_billing_subscription(user.id).await?;
4005 let billing_preferences = db.get_billing_preferences(user.id).await?;
4006
4007 let token = LlmTokenClaims::create(
4008 &user,
4009 session.is_staff(),
4010 billing_preferences,
4011 &flags,
4012 has_legacy_llm_subscription,
4013 billing_subscription,
4014 session.system_id.clone(),
4015 &session.app_state.config,
4016 )?;
4017 response.send(proto::GetLlmTokenResponse { token })?;
4018 Ok(())
4019}
4020
4021fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4022 let message = match message {
4023 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
4024 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
4025 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
4026 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
4027 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4028 code: frame.code.into(),
4029 reason: frame.reason.as_str().to_owned().into(),
4030 })),
4031 // We should never receive a frame while reading the message, according
4032 // to the `tungstenite` maintainers:
4033 //
4034 // > It cannot occur when you read messages from the WebSocket, but it
4035 // > can be used when you want to send the raw frames (e.g. you want to
4036 // > send the frames to the WebSocket without composing the full message first).
4037 // >
4038 // > — https://github.com/snapview/tungstenite-rs/issues/268
4039 TungsteniteMessage::Frame(_) => {
4040 bail!("received an unexpected frame while reading the message")
4041 }
4042 };
4043
4044 Ok(message)
4045}
4046
4047fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4048 match message {
4049 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
4050 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
4051 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
4052 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
4053 AxumMessage::Close(frame) => {
4054 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4055 code: frame.code.into(),
4056 reason: frame.reason.as_ref().into(),
4057 }))
4058 }
4059 }
4060}
4061
/// Applies a channel-membership change for `user_id` to the connection pool and
/// pushes it to every connection of that user.
///
/// First, the user is subscribed to each channel in `result.new_channels` (with
/// its role) and unsubscribed from each channel in `result.removed_channels`.
/// Each of the user's connections then receives two messages:
/// - `UpdateUserChannels` listing the user's memberships in the new channels, and
/// - `UpdateChannels` carrying the new channels' data, the removed channels in
///   `delete_channels`, and the invitation for `result.channel_id` removed.
///
/// Send failures are passed to `trace_err()` and otherwise ignored.
fn notify_membership_updated(
    connection_pool: &mut ConnectionPool,
    result: MembershipUpdated,
    user_id: UserId,
    peer: &Peer,
) {
    for membership in &result.new_channels.channel_memberships {
        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
    }
    for channel_id in &result.removed_channels {
        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
    }

    // Built from borrowed memberships before `result.new_channels` is moved
    // into `build_channels_update` below.
    let user_channels_update = proto::UpdateUserChannels {
        channel_memberships: result
            .new_channels
            .channel_memberships
            .iter()
            .map(|cm| proto::ChannelMembership {
                channel_id: cm.channel_id.to_proto(),
                role: cm.role.into(),
            })
            .collect(),
        ..Default::default()
    };

    let mut update = build_channels_update(result.new_channels);
    update.delete_channels = result
        .removed_channels
        .into_iter()
        .map(|id| id.to_proto())
        .collect();
    update.remove_channel_invitations = vec![result.channel_id.to_proto()];

    for connection_id in connection_pool.user_connection_ids(user_id) {
        peer.send(connection_id, user_channels_update.clone())
            .trace_err();
        peer.send(connection_id, update.clone()).trace_err();
    }
}
4102
4103fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4104 proto::UpdateUserChannels {
4105 channel_memberships: channels
4106 .channel_memberships
4107 .iter()
4108 .map(|m| proto::ChannelMembership {
4109 channel_id: m.channel_id.to_proto(),
4110 role: m.role.into(),
4111 })
4112 .collect(),
4113 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4114 observed_channel_message_id: channels.observed_channel_messages.clone(),
4115 }
4116}
4117
4118fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4119 let mut update = proto::UpdateChannels::default();
4120
4121 for channel in channels.channels {
4122 update.channels.push(channel.to_proto());
4123 }
4124
4125 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4126 update.latest_channel_message_ids = channels.latest_channel_messages;
4127
4128 for (channel_id, participants) in channels.channel_participants {
4129 update
4130 .channel_participants
4131 .push(proto::ChannelParticipants {
4132 channel_id: channel_id.to_proto(),
4133 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4134 });
4135 }
4136
4137 for channel in channels.invited_channels {
4138 update.channel_invitations.push(channel.to_proto());
4139 }
4140
4141 update
4142}
4143
4144fn build_initial_contacts_update(
4145 contacts: Vec<db::Contact>,
4146 pool: &ConnectionPool,
4147) -> proto::UpdateContacts {
4148 let mut update = proto::UpdateContacts::default();
4149
4150 for contact in contacts {
4151 match contact {
4152 db::Contact::Accepted { user_id, busy } => {
4153 update.contacts.push(contact_for_user(user_id, busy, pool));
4154 }
4155 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4156 db::Contact::Incoming { user_id } => {
4157 update
4158 .incoming_requests
4159 .push(proto::IncomingContactRequest {
4160 requester_id: user_id.to_proto(),
4161 })
4162 }
4163 }
4164 }
4165
4166 update
4167}
4168
4169fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4170 proto::Contact {
4171 user_id: user_id.to_proto(),
4172 online: pool.is_user_online(user_id),
4173 busy,
4174 }
4175}
4176
4177fn room_updated(room: &proto::Room, peer: &Peer) {
4178 broadcast(
4179 None,
4180 room.participants
4181 .iter()
4182 .filter_map(|participant| Some(participant.peer_id?.into())),
4183 |peer_id| {
4184 peer.send(
4185 peer_id,
4186 proto::RoomUpdated {
4187 room: Some(room.clone()),
4188 },
4189 )
4190 },
4191 );
4192}
4193
4194fn channel_updated(
4195 channel: &db::channel::Model,
4196 room: &proto::Room,
4197 peer: &Peer,
4198 pool: &ConnectionPool,
4199) {
4200 let participants = room
4201 .participants
4202 .iter()
4203 .map(|p| p.user_id)
4204 .collect::<Vec<_>>();
4205
4206 broadcast(
4207 None,
4208 pool.channel_connection_ids(channel.root_id())
4209 .filter_map(|(channel_id, role)| {
4210 role.can_see_channel(channel.visibility)
4211 .then_some(channel_id)
4212 }),
4213 |peer_id| {
4214 peer.send(
4215 peer_id,
4216 proto::UpdateChannels {
4217 channel_participants: vec![proto::ChannelParticipants {
4218 channel_id: channel.id.to_proto(),
4219 participant_user_ids: participants.clone(),
4220 }],
4221 ..Default::default()
4222 },
4223 )
4224 },
4225 );
4226}
4227
4228async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4229 let db = session.db().await;
4230
4231 let contacts = db.get_contacts(user_id).await?;
4232 let busy = db.is_user_busy(user_id).await?;
4233
4234 let pool = session.connection_pool().await;
4235 let updated_contact = contact_for_user(user_id, busy, &pool);
4236 for contact in contacts {
4237 if let db::Contact::Accepted {
4238 user_id: contact_user_id,
4239 ..
4240 } = contact
4241 {
4242 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4243 session
4244 .peer
4245 .send(
4246 contact_conn_id,
4247 proto::UpdateContacts {
4248 contacts: vec![updated_contact.clone()],
4249 remove_contacts: Default::default(),
4250 incoming_requests: Default::default(),
4251 remove_incoming_requests: Default::default(),
4252 outgoing_requests: Default::default(),
4253 remove_outgoing_requests: Default::default(),
4254 },
4255 )
4256 .trace_err();
4257 }
4258 }
4259 }
4260 Ok(())
4261}
4262
/// Removes `connection_id` from the room it is in and notifies everyone
/// affected. Returns `Ok(())` without doing anything if the database reports
/// that the connection left no room.
///
/// When a room was left, this:
/// - handles each project the connection left (`project_left`),
/// - sends the updated room to its participants and, for channel rooms, the
///   updated participant list to channel subscribers,
/// - sends `CallCanceled` to every user whose pending call into this room was
///   canceled,
/// - refreshes the contacts of the session's user and of those users, and
/// - removes the session's user from the LiveKit room. The LiveKit room itself
///   is also deleted when the database marks the room as deleted.
///
/// LiveKit failures are passed to `trace_err()` and otherwise ignored.
///
/// NOTE(review): the contact refresh and LiveKit removal use
/// `session.user_id()`, although `connection_id` may differ from
/// `session.connection_id` (e.g. stale connections cleaned up by
/// `join_channel_internal`). Presumably both belong to the same user — confirm.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Declared here and assigned inside the `if let` below so the values outlive
    // the database guard held by that statement.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        // `mem::take` moves these fields out of `left_room`, leaving their
        // defaults behind.
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    {
        // Scoped so the connection pool guard is released before
        // `update_user_contacts`, which locks the pool again.
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4336
4337async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4338 let left_channel_buffers = session
4339 .db()
4340 .await
4341 .leave_channel_buffers(session.connection_id)
4342 .await?;
4343
4344 for left_buffer in left_channel_buffers {
4345 channel_buffer_updated(
4346 session.connection_id,
4347 left_buffer.connections,
4348 &proto::UpdateChannelBufferCollaborators {
4349 channel_id: left_buffer.channel_id.to_proto(),
4350 collaborators: left_buffer.collaborators,
4351 },
4352 &session.peer,
4353 );
4354 }
4355
4356 Ok(())
4357}
4358
4359fn project_left(project: &db::LeftProject, session: &Session) {
4360 for connection_id in &project.connection_ids {
4361 if project.should_unshare {
4362 session
4363 .peer
4364 .send(
4365 *connection_id,
4366 proto::UnshareProject {
4367 project_id: project.id.to_proto(),
4368 },
4369 )
4370 .trace_err();
4371 } else {
4372 session
4373 .peer
4374 .send(
4375 *connection_id,
4376 proto::RemoveProjectCollaborator {
4377 project_id: project.id.to_proto(),
4378 peer_id: Some(session.connection_id.into()),
4379 },
4380 )
4381 .trace_err();
4382 }
4383 }
4384}
4385
/// Extension for fallible values that turns a failure into `None`
/// instead of propagating it.
///
/// Implementations are expected to report the error they discard (the
/// `Result` implementation logs it at error level).
pub trait ResultExt {
    /// Returns the success value, or reports the error and returns `None`.
    fn trace_err(self) -> Option<<Self as ResultExt>::Ok>;

    /// The value produced on success.
    type Ok;
}
4391
4392impl<T, E> ResultExt for Result<T, E>
4393where
4394 E: std::fmt::Debug,
4395{
4396 type Ok = T;
4397
4398 #[track_caller]
4399 fn trace_err(self) -> Option<T> {
4400 match self {
4401 Ok(value) => Some(value),
4402 Err(error) => {
4403 tracing::error!("{:?}", error);
4404 None
4405 }
4406 }
4407 }
4408}