1mod connection_pool;
2
3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
4use crate::llm::LlmTokenClaims;
5use crate::{
6 auth,
7 db::{
8 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
9 CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
10 NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
11 RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
12 },
13 executor::Executor,
14 AppState, Config, Error, RateLimit, Result,
15};
16use anyhow::{anyhow, bail, Context as _};
17use async_tungstenite::tungstenite::{
18 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
19};
20use axum::{
21 body::Body,
22 extract::{
23 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
24 ConnectInfo, WebSocketUpgrade,
25 },
26 headers::{Header, HeaderName},
27 http::StatusCode,
28 middleware,
29 response::IntoResponse,
30 routing::get,
31 Extension, Router, TypedHeader,
32};
33use chrono::Utc;
34use collections::{HashMap, HashSet};
35pub use connection_pool::{ConnectionPool, ZedVersion};
36use core::fmt::{self, Debug, Formatter};
37use http_client::HttpClient;
38use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
39use reqwest_client::ReqwestClient;
40use sha2::Digest;
41use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
42
43use futures::{
44 channel::oneshot, future::BoxFuture, stream::FuturesUnordered, FutureExt, SinkExt, StreamExt,
45 TryStreamExt,
46};
47use prometheus::{register_int_gauge, IntGauge};
48use rpc::{
49 proto::{
50 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
51 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
52 },
53 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
54};
55use semantic_version::SemanticVersion;
56use serde::{Serialize, Serializer};
57use std::{
58 any::TypeId,
59 future::Future,
60 marker::PhantomData,
61 mem,
62 net::SocketAddr,
63 ops::{Deref, DerefMut},
64 rc::Rc,
65 sync::{
66 atomic::{AtomicBool, Ordering::SeqCst},
67 Arc, OnceLock,
68 },
69 time::{Duration, Instant},
70};
71use time::OffsetDateTime;
72use tokio::sync::{watch, MutexGuard, Semaphore};
73use tower::ServiceBuilder;
74use tracing::{
75 field::{self},
76 info_span, instrument, Instrument,
77};
78
/// Reconnect window exposed to the rest of the crate (30 seconds).
///
/// NOTE(review): how callers use this is not visible in this module chunk.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// Delay before a freshly started server cleans up resources left by stale servers.
///
/// Kubernetes grants terminated pods 10 seconds to shut down gracefully; waiting a little
/// longer than that means those pods are gone before their old resources are reclaimed.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

/// Page size for message listings (presumably channel chat history; used outside this view).
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Upper bound on a message's length (units/usage not visible here — TODO confirm).
const MAX_MESSAGE_LEN: usize = 1024;
/// Page size for notification listings (used outside this view).
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
87
88type MessageHandler =
89 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
90
91struct Response<R> {
92 peer: Arc<Peer>,
93 receipt: Receipt<R>,
94 responded: Arc<AtomicBool>,
95}
96
97impl<R: RequestMessage> Response<R> {
98 fn send(self, payload: R::Response) -> Result<()> {
99 self.responded.store(true, SeqCst);
100 self.peer.respond(self.receipt, payload)?;
101 Ok(())
102 }
103}
104
105#[derive(Clone, Debug)]
106pub enum Principal {
107 User(User),
108 Impersonated { user: User, admin: User },
109}
110
111impl Principal {
112 fn update_span(&self, span: &tracing::Span) {
113 match &self {
114 Principal::User(user) => {
115 span.record("user_id", user.id.0);
116 span.record("login", &user.github_login);
117 }
118 Principal::Impersonated { user, admin } => {
119 span.record("user_id", user.id.0);
120 span.record("login", &user.github_login);
121 span.record("impersonator", &admin.github_login);
122 }
123 }
124 }
125}
126
127#[derive(Clone)]
128struct Session {
129 principal: Principal,
130 connection_id: ConnectionId,
131 db: Arc<tokio::sync::Mutex<DbHandle>>,
132 peer: Arc<Peer>,
133 connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
134 app_state: Arc<AppState>,
135 supermaven_client: Option<Arc<SupermavenAdminApi>>,
136 http_client: Arc<dyn HttpClient>,
137 /// The GeoIP country code for the user.
138 #[allow(unused)]
139 geoip_country_code: Option<String>,
140 system_id: Option<String>,
141 _executor: Executor,
142}
143
144impl Session {
145 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
146 #[cfg(test)]
147 tokio::task::yield_now().await;
148 let guard = self.db.lock().await;
149 #[cfg(test)]
150 tokio::task::yield_now().await;
151 guard
152 }
153
154 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
155 #[cfg(test)]
156 tokio::task::yield_now().await;
157 let guard = self.connection_pool.lock();
158 ConnectionPoolGuard {
159 guard,
160 _not_send: PhantomData,
161 }
162 }
163
164 fn is_staff(&self) -> bool {
165 match &self.principal {
166 Principal::User(user) => user.admin,
167 Principal::Impersonated { .. } => true,
168 }
169 }
170
171 pub async fn has_llm_subscription(
172 &self,
173 db: &MutexGuard<'_, DbHandle>,
174 ) -> anyhow::Result<bool> {
175 if self.is_staff() {
176 return Ok(true);
177 }
178
179 let user_id = self.user_id();
180
181 Ok(db.has_active_billing_subscription(user_id).await?)
182 }
183
184 pub async fn current_plan(
185 &self,
186 _db: &MutexGuard<'_, DbHandle>,
187 ) -> anyhow::Result<proto::Plan> {
188 if self.is_staff() {
189 Ok(proto::Plan::ZedPro)
190 } else {
191 Ok(proto::Plan::Free)
192 }
193 }
194
195 fn user_id(&self) -> UserId {
196 match &self.principal {
197 Principal::User(user) => user.id,
198 Principal::Impersonated { user, .. } => user.id,
199 }
200 }
201
202 pub fn email(&self) -> Option<String> {
203 match &self.principal {
204 Principal::User(user) => user.email_address.clone(),
205 Principal::Impersonated { user, .. } => user.email_address.clone(),
206 }
207 }
208}
209
210impl Debug for Session {
211 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
212 let mut result = f.debug_struct("Session");
213 match &self.principal {
214 Principal::User(user) => {
215 result.field("user", &user.github_login);
216 }
217 Principal::Impersonated { user, admin } => {
218 result.field("user", &user.github_login);
219 result.field("impersonator", &admin.github_login);
220 }
221 }
222 result.field("connection_id", &self.connection_id).finish()
223 }
224}
225
226struct DbHandle(Arc<Database>);
227
228impl Deref for DbHandle {
229 type Target = Database;
230
231 fn deref(&self) -> &Self::Target {
232 self.0.as_ref()
233 }
234}
235
236pub struct Server {
237 id: parking_lot::Mutex<ServerId>,
238 peer: Arc<Peer>,
239 pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
240 app_state: Arc<AppState>,
241 handlers: HashMap<TypeId, MessageHandler>,
242 teardown: watch::Sender<bool>,
243}
244
245pub(crate) struct ConnectionPoolGuard<'a> {
246 guard: parking_lot::MutexGuard<'a, ConnectionPool>,
247 _not_send: PhantomData<Rc<()>>,
248}
249
250#[derive(Serialize)]
251pub struct ServerSnapshot<'a> {
252 peer: &'a Peer,
253 #[serde(serialize_with = "serialize_deref")]
254 connection_pool: ConnectionPoolGuard<'a>,
255}
256
257pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
258where
259 S: Serializer,
260 T: Deref<Target = U>,
261 U: Serialize,
262{
263 Serialize::serialize(value.deref(), serializer)
264}
265
impl Server {
    /// Creates a server identified by `id` and registers every RPC message handler.
    ///
    /// Request handlers (`add_request_handler`) must reply through their `Response`
    /// before returning `Ok`; message handlers (`add_message_handler`) send no reply.
    /// Registering two handlers for the same message type panics (see `add_handler`).
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            // NOTE(review): `as u32` truncates; assumes the server id fits in a u32.
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            // The initial receiver is dropped; connections subscribe later via
            // `teardown.subscribe()`.
            teardown: watch::channel(false).0,
        };

        server
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(set_room_participant_role)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
            .add_request_handler(forward_find_search_candidates_request)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::GitBranches>)
            .add_request_handler(forward_read_only_project_request::<proto::GetStagedText>)
            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(update_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            .add_request_handler(move_channel)
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_request_handler(get_private_user_info)
            .add_request_handler(get_llm_api_token)
            .add_request_handler(accept_terms_of_service)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version)
            .add_request_handler(get_supermaven_api_key)
            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context)
            // Closure wrapper: `count_language_model_tokens` also needs `app_state.config`.
            .add_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    let app_state = app_state.clone();
                    async move {
                        count_language_model_tokens(request, response, session, &app_state.config)
                            .await
                    }
                }
            })
            .add_request_handler(get_cached_embeddings)
            // Closure wrapper: `compute_embeddings` also needs the configured OpenAI key.
            .add_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    compute_embeddings(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                    )
                }
            });

        Arc::new(server)
    }

    /// Spawns a detached cleanup task and returns immediately.
    ///
    /// After `CLEANUP_TIMEOUT`, the task:
    /// - asks the database for stale room and channel-buffer ids (`stale_server_resource_ids`);
    /// - clears stale channel-buffer collaborators and stale room participants;
    /// - notifies the affected clients and contacts;
    /// - deletes the stale server records.
    ///
    /// Errors inside the task are logged via `trace_err` and never propagated, so this
    /// always returns `Ok(())`.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // Created before spawning, so the wait is measured from the call to `start`.
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Channel buffers: clear stale collaborators, then push the refreshed
                    // collaborator list to every connection still in the buffer.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    // Rooms: clear stale participants, broadcast the refreshed room, then
                    // notify canceled callees and the contacts of everyone affected.
                    for room_id in room_ids {
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            // Only an emptied room gets its LiveKit room deleted below.
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the pool lock is released before the database awaits below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Re-send each affected user's contact entry (busy status
                        // included) to all of their accepted contacts' connections.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = livekit_client.as_ref() {
                            if delete_livekit_room {
                                live_kit.delete_room(livekit_room).await.trace_err();
                            }
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }

    /// Tears down the peer, clears the connection pool, and broadcasts `true` on the
    /// teardown channel; each connection loop in `handle_connection` returns on the change.
    pub fn teardown(&self) {
        self.peer.teardown();
        self.connection_pool.lock().reset();
        // `send` fails only when there are no subscribers, which is fine to ignore.
        let _ = self.teardown.send(true);
    }

    /// Test-only: tears down, then re-arms the server under a new `id`.
    #[cfg(test)]
    pub fn reset(&self, id: ServerId) {
        self.teardown();
        *self.id.lock() = id;
        self.peer.reset(id.0 as u32);
        let _ = self.teardown.send(false);
    }

    /// Test-only accessor for the current server id.
    #[cfg(test)]
    pub fn id(&self) -> ServerId {
        *self.id.lock()
    }

    /// Registers `handler` for message type `M`, wrapping it so each invocation logs its
    /// queue time, processing time and outcome.
    ///
    /// # Panics
    /// Panics if a handler for `M` is already registered.
    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        let prev_handler = self.handlers.insert(
            TypeId::of::<M>(),
            Box::new(move |envelope, session| {
                // The map is keyed by `TypeId::of::<M>()`, so envelopes dispatched here are
                // presumably always `TypedEnvelope<M>`; a mismatch would be a bug and panics.
                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
                let received_at = envelope.received_at;
                tracing::info!("message received");
                let start_time = Instant::now();
                let future = (handler)(*envelope, session);
                async move {
                    let result = future.await;
                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
                    // Time between the message's `received_at` and the handler being invoked.
                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
                    let payload_type = M::NAME;

                    match result {
                        Err(error) => {
                            tracing::error!(
                                ?error,
                                total_duration_ms,
                                processing_duration_ms,
                                queue_duration_ms,
                                payload_type,
                                "error handling message"
                            )
                        }
                        Ok(()) => tracing::info!(
                            total_duration_ms,
                            processing_duration_ms,
                            queue_duration_ms,
                            "finished handling message"
                        ),
                    }
                }
                .boxed()
            }),
        );
        if prev_handler.is_some() {
            panic!("registered a handler for the same message twice");
        }
        self
    }

    /// Registers a handler that receives only the message payload and sends no reply.
    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        self.add_handler(move |envelope, session| handler(envelope.payload, session));
        self
    }

    /// Registers a request handler that replies through a `Response`.
    ///
    /// If the handler returns `Ok` without having called `Response::send`, an error is
    /// logged. If it returns `Err`, the error is sent back to the requester:
    /// `Error::Internal` errors as-is, anything else wrapped as `ErrorCode::Internal`.
    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
        Fut: Send + Future<Output = Result<()>>,
        M: RequestMessage,
    {
        let handler = Arc::new(handler);
        self.add_handler(move |envelope, session| {
            let receipt = envelope.receipt();
            let handler = handler.clone();
            async move {
                let peer = session.peer.clone();
                let responded = Arc::new(AtomicBool::default());
                let response = Response {
                    peer: peer.clone(),
                    responded: responded.clone(),
                    receipt,
                };
                match (handler)(envelope.payload, response, session).await {
                    Ok(()) => {
                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
                            Ok(())
                        } else {
                            Err(anyhow!("handler did not send a response"))?
                        }
                    }
                    Err(error) => {
                        let proto_err = match &error {
                            Error::Internal(err) => err.to_proto(),
                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
                        };
                        // `receipt` is reused here after being copied into `response`.
                        peer.respond_with_error(receipt, proto_err)?;
                        Err(error)
                    }
                }
            }
        })
    }

    /// Builds the future that serves one client connection until it closes or the
    /// server tears down.
    ///
    /// The future:
    /// - registers the connection with the peer;
    /// - sends the initial client state (see `send_initial_client_update`);
    /// - dispatches incoming messages to registered handlers;
    /// - on I/O error or stream end, runs `connection_lost`.
    ///
    /// On teardown the future returns without running `connection_lost`.
    #[allow(clippy::too_many_arguments)]
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> {
        let this = self.clone();
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            geoip_country_code=field::Empty
        );
        principal.update_span(&span);
        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        let mut teardown = self.teardown.subscribe();
        async move {
            // Refuse new connections once teardown has been signaled.
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            // A fresh HTTP client is built per connection; failure aborts the connection.
            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            // Only built when a Supermaven admin API key is configured.
            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(
                |supermaven_admin_api_key| {
                    Arc::new(SupermavenAdminApi::new(
                        supermaven_admin_api_key.to_string(),
                        http_client.clone(),
                    ))
                },
            );

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                http_client,
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this
                .send_initial_client_update(
                    connection_id,
                    &principal,
                    zed_version,
                    send_connection_id,
                    &session,
                )
                .await
            {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Caps in-flight handlers (foreground and background) per connection at 256:
            // a permit is taken before reading each message and released when its
            // handler finishes.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Biased: teardown first, then I/O, then draining handlers, then new messages.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                // Leave the span before the handler runs elsewhere; the
                                // future below is instrumented with `span` instead.
                                drop(span_enter);

                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Cancel any foreground handlers still pending before signing out.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }

    /// Sends the initial state a newly connected client needs, in this order:
    /// - `Hello` (and the connection id, if requested);
    /// - `ShowContacts` on the user's first connection;
    /// - the user's plan;
    /// - the initial contact list (after adding the connection to the pool);
    /// - channel subscriptions, for client versions that auto-subscribe;
    /// - any pending incoming call;
    /// - contact updates for the user.
    ///
    /// # Errors
    /// Returns the first send or database error encountered.
    async fn send_initial_client_update(
        &self,
        connection_id: ConnectionId,
        principal: &Principal,
        zed_version: ZedVersion,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        session: &Session,
    ) -> Result<()> {
        self.peer.send(
            connection_id,
            proto::Hello {
                peer_id: Some(connection_id.into()),
            },
        )?;
        tracing::info!("sent hello message");
        if let Some(send_connection_id) = send_connection_id.take() {
            // The receiver may have gone away; that is not an error here.
            let _ = send_connection_id.send(connection_id);
        }

        match principal {
            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
                if !user.connected_once {
                    self.peer.send(connection_id, proto::ShowContacts {})?;
                    self.app_state
                        .db
                        .set_user_connected_once(user.id, true)
                        .await?;
                }

                update_user_plan(user.id, session).await?;

                let contacts = self.app_state.db.get_contacts(user.id).await?;

                // Add the connection and build the contacts update under one pool lock.
                {
                    let mut pool = self.connection_pool.lock();
                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
                    self.peer.send(
                        connection_id,
                        build_initial_contacts_update(contacts, &pool),
                    )?;
                }

                if should_auto_subscribe_to_channels(zed_version) {
                    subscribe_user_to_channels(user.id, session).await?;
                }

                if let Some(incoming_call) =
                    self.app_state.db.incoming_call_for_user(user.id).await?
                {
                    self.peer.send(connection_id, incoming_call)?;
                }

                update_user_contacts(user.id, session).await?;
            }
        }

        Ok(())
    }

    /// Tells every connection of `inviter_id` that `invitee_id` is now a contact, and
    /// refreshes the inviter's invite link and count. Does nothing if the inviter has
    /// no invite code.
    ///
    /// # Errors
    /// Returns on the first database or send failure; remaining connections are skipped.
    pub async fn invite_code_redeemed(
        self: &Arc<Self>,
        inviter_id: UserId,
        invitee_id: UserId,
    ) -> Result<()> {
        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
            if let Some(code) = &user.invite_code {
                let pool = self.connection_pool.lock();
                let invitee_contact = contact_for_user(invitee_id, false, &pool);
                for connection_id in pool.user_connection_ids(inviter_id) {
                    self.peer.send(
                        connection_id,
                        proto::UpdateContacts {
                            contacts: vec![invitee_contact.clone()],
                            ..Default::default()
                        },
                    )?;
                    // NOTE(review): `as u32` would wrap a negative invite count.
                    self.peer.send(
                        connection_id,
                        proto::UpdateInviteInfo {
                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
                            count: user.invite_count as u32,
                        },
                    )?;
                }
            }
        }
        Ok(())
    }

    /// Pushes the user's current invite link and count to all of their connections.
    /// Does nothing if the user has no invite code.
    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
            if let Some(invite_code) = &user.invite_code {
                let pool = self.connection_pool.lock();
                for connection_id in pool.user_connection_ids(user_id) {
                    self.peer.send(
                        connection_id,
                        proto::UpdateInviteInfo {
                            url: format!(
                                "{}{}",
                                self.app_state.config.invite_link_prefix, invite_code
                            ),
                            count: user.invite_count as u32,
                        },
                    )?;
                }
            }
        }
        Ok(())
    }

    /// Asks every connection of `user_id` to refresh its LLM token; send failures are
    /// logged and skipped.
    pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
        let pool = self.connection_pool.lock();
        for connection_id in pool.user_connection_ids(user_id) {
            self.peer
                .send(connection_id, proto::RefreshLlmToken {})
                .trace_err();
        }
    }

    /// Captures a serializable snapshot; the connection pool stays locked until the
    /// snapshot is dropped.
    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
        ServerSnapshot {
            connection_pool: ConnectionPoolGuard {
                guard: self.connection_pool.lock(),
                _not_send: PhantomData,
            },
            peer: &self.peer,
        }
    }
}
958
959impl<'a> Deref for ConnectionPoolGuard<'a> {
960 type Target = ConnectionPool;
961
962 fn deref(&self) -> &Self::Target {
963 &self.guard
964 }
965}
966
967impl<'a> DerefMut for ConnectionPoolGuard<'a> {
968 fn deref_mut(&mut self) -> &mut Self::Target {
969 &mut self.guard
970 }
971}
972
973impl<'a> Drop for ConnectionPoolGuard<'a> {
974 fn drop(&mut self) {
975 #[cfg(test)]
976 self.check_invariants();
977 }
978}
979
980fn broadcast<F>(
981 sender_id: Option<ConnectionId>,
982 receiver_ids: impl IntoIterator<Item = ConnectionId>,
983 mut f: F,
984) where
985 F: FnMut(ConnectionId) -> anyhow::Result<()>,
986{
987 for receiver_id in receiver_ids {
988 if Some(receiver_id) != sender_id {
989 if let Err(error) = f(receiver_id) {
990 tracing::error!("failed to send to {:?} {}", receiver_id, error);
991 }
992 }
993 }
994}
995
996pub struct ProtocolVersion(u32);
997
998impl Header for ProtocolVersion {
999 fn name() -> &'static HeaderName {
1000 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1001 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1002 }
1003
1004 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1005 where
1006 Self: Sized,
1007 I: Iterator<Item = &'i axum::http::HeaderValue>,
1008 {
1009 let version = values
1010 .next()
1011 .ok_or_else(axum::headers::Error::invalid)?
1012 .to_str()
1013 .map_err(|_| axum::headers::Error::invalid())?
1014 .parse()
1015 .map_err(|_| axum::headers::Error::invalid())?;
1016 Ok(Self(version))
1017 }
1018
1019 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1020 values.extend([self.0.to_string().parse().unwrap()]);
1021 }
1022}
1023
1024pub struct AppVersionHeader(SemanticVersion);
1025impl Header for AppVersionHeader {
1026 fn name() -> &'static HeaderName {
1027 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1028 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1029 }
1030
1031 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1032 where
1033 Self: Sized,
1034 I: Iterator<Item = &'i axum::http::HeaderValue>,
1035 {
1036 let version = values
1037 .next()
1038 .ok_or_else(axum::headers::Error::invalid)?
1039 .to_str()
1040 .map_err(|_| axum::headers::Error::invalid())?
1041 .parse()
1042 .map_err(|_| axum::headers::Error::invalid())?;
1043 Ok(Self(version))
1044 }
1045
1046 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1047 values.extend([self.0.to_string().parse().unwrap()]);
1048 }
1049}
1050
1051pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1052 Router::new()
1053 .route("/rpc", get(handle_websocket_request))
1054 .layer(
1055 ServiceBuilder::new()
1056 .layer(Extension(server.app_state.clone()))
1057 .layer(middleware::from_fn(auth::validate_header)),
1058 )
1059 .route("/metrics", get(handle_metrics))
1060 .layer(Extension(server))
1061}
1062
1063#[allow(clippy::too_many_arguments)]
1064pub async fn handle_websocket_request(
1065 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1066 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1067 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1068 Extension(server): Extension<Arc<Server>>,
1069 Extension(principal): Extension<Principal>,
1070 country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1071 system_id_header: Option<TypedHeader<SystemIdHeader>>,
1072 ws: WebSocketUpgrade,
1073) -> axum::response::Response {
1074 if protocol_version != rpc::PROTOCOL_VERSION {
1075 return (
1076 StatusCode::UPGRADE_REQUIRED,
1077 "client must be upgraded".to_string(),
1078 )
1079 .into_response();
1080 }
1081
1082 let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1083 return (
1084 StatusCode::UPGRADE_REQUIRED,
1085 "no version header found".to_string(),
1086 )
1087 .into_response();
1088 };
1089
1090 if !version.can_collaborate() {
1091 return (
1092 StatusCode::UPGRADE_REQUIRED,
1093 "client must be upgraded".to_string(),
1094 )
1095 .into_response();
1096 }
1097
1098 let socket_address = socket_address.to_string();
1099 ws.on_upgrade(move |socket| {
1100 let socket = socket
1101 .map_ok(to_tungstenite_message)
1102 .err_into()
1103 .with(|message| async move { to_axum_message(message) });
1104 let connection = Connection::new(Box::pin(socket));
1105 async move {
1106 server
1107 .handle_connection(
1108 connection,
1109 socket_address,
1110 principal,
1111 version,
1112 country_code_header.map(|header| header.to_string()),
1113 system_id_header.map(|header| header.to_string()),
1114 None,
1115 Executor::Production,
1116 )
1117 .await;
1118 }
1119 })
1120}
1121
1122pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1123 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1124 let connections_metric = CONNECTIONS_METRIC
1125 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1126
1127 let connections = server
1128 .connection_pool
1129 .lock()
1130 .connections()
1131 .filter(|connection| !connection.admin)
1132 .count();
1133 connections_metric.set(connections as _);
1134
1135 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1136 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1137 register_int_gauge!(
1138 "shared_projects",
1139 "number of open projects with one or more guests"
1140 )
1141 .unwrap()
1142 });
1143
1144 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1145 shared_projects_metric.set(shared_projects as _);
1146
1147 let encoder = prometheus::TextEncoder::new();
1148 let metric_families = prometheus::gather();
1149 let encoded_metrics = encoder
1150 .encode_to_string(&metric_families)
1151 .map_err(|err| anyhow!("{}", err))?;
1152 Ok(encoded_metrics)
1153}
1154
/// Cleans up after a client's connection drops.
///
/// Immediately disconnects the connection from the peer, removes it from the
/// connection pool, and reports it via the database's `connection_lost`. A
/// failure of that database call is not fatal.
///
/// It then waits for whichever comes first:
///
/// - `RECONNECT_TIMEOUT` elapses on `executor`: the connection leaves its room
///   and channel buffers. If the user is no longer online,
///   `decline_call(None, ..)` runs for them, and any room it returns is
///   broadcast. Finally, `update_user_contacts` runs for the user.
/// - `teardown` changes: the wait ends without further cleanup.
///
/// `select_biased!` polls the timeout branch first when both are ready.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // Give the client `RECONNECT_TIMEOUT` to come back before tearing down its
    // room and channel-buffer state; a change on `teardown` abandons the wait.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline pending calls once the user has no connection left.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1201
1202/// Acknowledges a ping from a client, used to keep the connection alive.
1203async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1204 response.send(proto::Ack {})?;
1205 Ok(())
1206}
1207
1208/// Creates a new room for calling (outside of channels)
1209async fn create_room(
1210 _request: proto::CreateRoom,
1211 response: Response<proto::CreateRoom>,
1212 session: Session,
1213) -> Result<()> {
1214 let livekit_room = nanoid::nanoid!(30);
1215
1216 let live_kit_connection_info = util::maybe!(async {
1217 let live_kit = session.app_state.livekit_client.as_ref();
1218 let live_kit = live_kit?;
1219 let user_id = session.user_id().to_string();
1220
1221 let token = live_kit
1222 .room_token(&livekit_room, &user_id.to_string())
1223 .trace_err()?;
1224
1225 Some(proto::LiveKitConnectionInfo {
1226 server_url: live_kit.url().into(),
1227 token,
1228 can_publish: true,
1229 })
1230 })
1231 .await;
1232
1233 let room = session
1234 .db()
1235 .await
1236 .create_room(session.user_id(), session.connection_id, &livekit_room)
1237 .await?;
1238
1239 response.send(proto::CreateRoomResponse {
1240 room: Some(room.clone()),
1241 live_kit_connection_info,
1242 })?;
1243
1244 update_user_contacts(session.user_id(), &session).await?;
1245 Ok(())
1246}
1247
1248/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1249async fn join_room(
1250 request: proto::JoinRoom,
1251 response: Response<proto::JoinRoom>,
1252 session: Session,
1253) -> Result<()> {
1254 let room_id = RoomId::from_proto(request.id);
1255
1256 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1257
1258 if let Some(channel_id) = channel_id {
1259 return join_channel_internal(channel_id, Box::new(response), session).await;
1260 }
1261
1262 let joined_room = {
1263 let room = session
1264 .db()
1265 .await
1266 .join_room(room_id, session.user_id(), session.connection_id)
1267 .await?;
1268 room_updated(&room.room, &session.peer);
1269 room.into_inner()
1270 };
1271
1272 for connection_id in session
1273 .connection_pool()
1274 .await
1275 .user_connection_ids(session.user_id())
1276 {
1277 session
1278 .peer
1279 .send(
1280 connection_id,
1281 proto::CallCanceled {
1282 room_id: room_id.to_proto(),
1283 },
1284 )
1285 .trace_err();
1286 }
1287
1288 let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1289 {
1290 live_kit
1291 .room_token(
1292 &joined_room.room.livekit_room,
1293 &session.user_id().to_string(),
1294 )
1295 .trace_err()
1296 .map(|token| proto::LiveKitConnectionInfo {
1297 server_url: live_kit.url().into(),
1298 token,
1299 can_publish: true,
1300 })
1301 } else {
1302 None
1303 };
1304
1305 response.send(proto::JoinRoomResponse {
1306 room: Some(joined_room.room),
1307 channel_id: None,
1308 live_kit_connection_info,
1309 })?;
1310
1311 update_user_contacts(session.user_id(), &session).await?;
1312 Ok(())
1313}
1314
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Restores the caller's room membership in the database and replies with the
/// room, the reshared projects (with their collaborators), and the rejoined
/// projects. It then:
///
/// - broadcasts the room update;
/// - for each reshared project, tells its collaborators that the caller's peer
///   id changed and forwards them an `UpdateProject` with the current
///   worktrees (failures there are not fatal);
/// - re-streams the rejoined projects via `notify_rejoined_projects`;
/// - calls `channel_updated` when the database returns a channel;
/// - finally runs `update_user_contacts` for the caller.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    let room;
    let channel;
    // Confine the value returned by the database call to this block; only the
    // final `room` and `channel` are carried out of it.
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Point each collaborator at the caller's new connection.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Resend the project's worktree list to everyone but the caller.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1406
1407fn notify_rejoined_projects(
1408 rejoined_projects: &mut Vec<RejoinedProject>,
1409 session: &Session,
1410) -> Result<()> {
1411 for project in rejoined_projects.iter() {
1412 for collaborator in &project.collaborators {
1413 session
1414 .peer
1415 .send(
1416 collaborator.connection_id,
1417 proto::UpdateProjectCollaborator {
1418 project_id: project.id.to_proto(),
1419 old_peer_id: Some(project.old_connection_id.into()),
1420 new_peer_id: Some(session.connection_id.into()),
1421 },
1422 )
1423 .trace_err();
1424 }
1425 }
1426
1427 for project in rejoined_projects {
1428 for worktree in mem::take(&mut project.worktrees) {
1429 // Stream this worktree's entries.
1430 let message = proto::UpdateWorktree {
1431 project_id: project.id.to_proto(),
1432 worktree_id: worktree.id,
1433 abs_path: worktree.abs_path.clone(),
1434 root_name: worktree.root_name,
1435 updated_entries: worktree.updated_entries,
1436 removed_entries: worktree.removed_entries,
1437 scan_id: worktree.scan_id,
1438 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1439 updated_repositories: worktree.updated_repositories,
1440 removed_repositories: worktree.removed_repositories,
1441 };
1442 for update in proto::split_worktree_update(message) {
1443 session.peer.send(session.connection_id, update.clone())?;
1444 }
1445
1446 // Stream this worktree's diagnostics.
1447 for summary in worktree.diagnostic_summaries {
1448 session.peer.send(
1449 session.connection_id,
1450 proto::UpdateDiagnosticSummary {
1451 project_id: project.id.to_proto(),
1452 worktree_id: worktree.id,
1453 summary: Some(summary),
1454 },
1455 )?;
1456 }
1457
1458 for settings_file in worktree.settings_files {
1459 session.peer.send(
1460 session.connection_id,
1461 proto::UpdateWorktreeSettings {
1462 project_id: project.id.to_proto(),
1463 worktree_id: worktree.id,
1464 path: settings_file.path,
1465 content: Some(settings_file.content),
1466 kind: Some(settings_file.kind.to_proto().into()),
1467 },
1468 )?;
1469 }
1470 }
1471
1472 for language_server in &project.language_servers {
1473 session.peer.send(
1474 session.connection_id,
1475 proto::UpdateLanguageServer {
1476 project_id: project.id.to_proto(),
1477 language_server_id: language_server.id,
1478 variant: Some(
1479 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1480 proto::LspDiskBasedDiagnosticsUpdated {},
1481 ),
1482 ),
1483 },
1484 )?;
1485 }
1486 }
1487 Ok(())
1488}
1489
1490/// leave room disconnects from the room.
1491async fn leave_room(
1492 _: proto::LeaveRoom,
1493 response: Response<proto::LeaveRoom>,
1494 session: Session,
1495) -> Result<()> {
1496 leave_room_for_session(&session, session.connection_id).await?;
1497 response.send(proto::Ack {})?;
1498 Ok(())
1499}
1500
1501/// Updates the permissions of someone else in the room.
1502async fn set_room_participant_role(
1503 request: proto::SetRoomParticipantRole,
1504 response: Response<proto::SetRoomParticipantRole>,
1505 session: Session,
1506) -> Result<()> {
1507 let user_id = UserId::from_proto(request.user_id);
1508 let role = ChannelRole::from(request.role());
1509
1510 let (livekit_room, can_publish) = {
1511 let room = session
1512 .db()
1513 .await
1514 .set_room_participant_role(
1515 session.user_id(),
1516 RoomId::from_proto(request.room_id),
1517 user_id,
1518 role,
1519 )
1520 .await?;
1521
1522 let livekit_room = room.livekit_room.clone();
1523 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1524 room_updated(&room, &session.peer);
1525 (livekit_room, can_publish)
1526 };
1527
1528 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1529 live_kit
1530 .update_participant(
1531 livekit_room.clone(),
1532 request.user_id.to_string(),
1533 livekit_server::proto::ParticipantPermission {
1534 can_subscribe: true,
1535 can_publish,
1536 can_publish_data: can_publish,
1537 hidden: false,
1538 recorder: false,
1539 },
1540 )
1541 .await
1542 .trace_err();
1543 }
1544
1545 response.send(proto::Ack {})?;
1546 Ok(())
1547}
1548
1549/// Call someone else into the current room
1550async fn call(
1551 request: proto::Call,
1552 response: Response<proto::Call>,
1553 session: Session,
1554) -> Result<()> {
1555 let room_id = RoomId::from_proto(request.room_id);
1556 let calling_user_id = session.user_id();
1557 let calling_connection_id = session.connection_id;
1558 let called_user_id = UserId::from_proto(request.called_user_id);
1559 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1560 if !session
1561 .db()
1562 .await
1563 .has_contact(calling_user_id, called_user_id)
1564 .await?
1565 {
1566 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1567 }
1568
1569 let incoming_call = {
1570 let (room, incoming_call) = &mut *session
1571 .db()
1572 .await
1573 .call(
1574 room_id,
1575 calling_user_id,
1576 calling_connection_id,
1577 called_user_id,
1578 initial_project_id,
1579 )
1580 .await?;
1581 room_updated(room, &session.peer);
1582 mem::take(incoming_call)
1583 };
1584 update_user_contacts(called_user_id, &session).await?;
1585
1586 let mut calls = session
1587 .connection_pool()
1588 .await
1589 .user_connection_ids(called_user_id)
1590 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1591 .collect::<FuturesUnordered<_>>();
1592
1593 while let Some(call_response) = calls.next().await {
1594 match call_response.as_ref() {
1595 Ok(_) => {
1596 response.send(proto::Ack {})?;
1597 return Ok(());
1598 }
1599 Err(_) => {
1600 call_response.trace_err();
1601 }
1602 }
1603 }
1604
1605 {
1606 let room = session
1607 .db()
1608 .await
1609 .call_failed(room_id, called_user_id)
1610 .await?;
1611 room_updated(&room, &session.peer);
1612 }
1613 update_user_contacts(called_user_id, &session).await?;
1614
1615 Err(anyhow!("failed to ring user"))?
1616}
1617
1618/// Cancel an outgoing call.
1619async fn cancel_call(
1620 request: proto::CancelCall,
1621 response: Response<proto::CancelCall>,
1622 session: Session,
1623) -> Result<()> {
1624 let called_user_id = UserId::from_proto(request.called_user_id);
1625 let room_id = RoomId::from_proto(request.room_id);
1626 {
1627 let room = session
1628 .db()
1629 .await
1630 .cancel_call(room_id, session.connection_id, called_user_id)
1631 .await?;
1632 room_updated(&room, &session.peer);
1633 }
1634
1635 for connection_id in session
1636 .connection_pool()
1637 .await
1638 .user_connection_ids(called_user_id)
1639 {
1640 session
1641 .peer
1642 .send(
1643 connection_id,
1644 proto::CallCanceled {
1645 room_id: room_id.to_proto(),
1646 },
1647 )
1648 .trace_err();
1649 }
1650 response.send(proto::Ack {})?;
1651
1652 update_user_contacts(called_user_id, &session).await?;
1653 Ok(())
1654}
1655
1656/// Decline an incoming call.
1657async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1658 let room_id = RoomId::from_proto(message.room_id);
1659 {
1660 let room = session
1661 .db()
1662 .await
1663 .decline_call(Some(room_id), session.user_id())
1664 .await?
1665 .ok_or_else(|| anyhow!("failed to decline call"))?;
1666 room_updated(&room, &session.peer);
1667 }
1668
1669 for connection_id in session
1670 .connection_pool()
1671 .await
1672 .user_connection_ids(session.user_id())
1673 {
1674 session
1675 .peer
1676 .send(
1677 connection_id,
1678 proto::CallCanceled {
1679 room_id: room_id.to_proto(),
1680 },
1681 )
1682 .trace_err();
1683 }
1684 update_user_contacts(session.user_id(), &session).await?;
1685 Ok(())
1686}
1687
1688/// Updates other participants in the room with your current location.
1689async fn update_participant_location(
1690 request: proto::UpdateParticipantLocation,
1691 response: Response<proto::UpdateParticipantLocation>,
1692 session: Session,
1693) -> Result<()> {
1694 let room_id = RoomId::from_proto(request.room_id);
1695 let location = request
1696 .location
1697 .ok_or_else(|| anyhow!("invalid location"))?;
1698
1699 let db = session.db().await;
1700 let room = db
1701 .update_room_participant_location(room_id, session.connection_id, location)
1702 .await?;
1703
1704 room_updated(&room, &session.peer);
1705 response.send(proto::Ack {})?;
1706 Ok(())
1707}
1708
1709/// Share a project into the room.
1710async fn share_project(
1711 request: proto::ShareProject,
1712 response: Response<proto::ShareProject>,
1713 session: Session,
1714) -> Result<()> {
1715 let (project_id, room) = &*session
1716 .db()
1717 .await
1718 .share_project(
1719 RoomId::from_proto(request.room_id),
1720 session.connection_id,
1721 &request.worktrees,
1722 request.is_ssh_project,
1723 )
1724 .await?;
1725 response.send(proto::ShareProjectResponse {
1726 project_id: project_id.to_proto(),
1727 })?;
1728 room_updated(room, &session.peer);
1729
1730 Ok(())
1731}
1732
1733/// Unshare a project from the room.
1734async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1735 let project_id = ProjectId::from_proto(message.project_id);
1736 unshare_project_internal(project_id, session.connection_id, &session).await
1737}
1738
1739async fn unshare_project_internal(
1740 project_id: ProjectId,
1741 connection_id: ConnectionId,
1742 session: &Session,
1743) -> Result<()> {
1744 let delete = {
1745 let room_guard = session
1746 .db()
1747 .await
1748 .unshare_project(project_id, connection_id)
1749 .await?;
1750
1751 let (delete, room, guest_connection_ids) = &*room_guard;
1752
1753 let message = proto::UnshareProject {
1754 project_id: project_id.to_proto(),
1755 };
1756
1757 broadcast(
1758 Some(connection_id),
1759 guest_connection_ids.iter().copied(),
1760 |conn_id| session.peer.send(conn_id, message.clone()),
1761 );
1762 if let Some(room) = room {
1763 room_updated(room, &session.peer);
1764 }
1765
1766 *delete
1767 };
1768
1769 if delete {
1770 let db = session.db().await;
1771 db.delete_project(project_id).await?;
1772 }
1773
1774 Ok(())
1775}
1776
/// Join someone else's shared project.
///
/// Adds the caller's connection to the project in the database and releases
/// the database handle. It then lets `join_project_internal` notify the
/// existing collaborators and stream the project to the caller.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: Session,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    // Release the database handle before streaming project state to the caller.
    drop(db);
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
1795
/// Delivery hook for a `proto::JoinProjectResponse`, so `join_project_internal`
/// can send its reply without depending on one concrete response type.
trait JoinProjectInternalResponse {
    /// Sends `result` as the reply, consuming the responder.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
/// The ordinary RPC responder replies through its own `send`.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Fully qualified so this resolves to `Response`'s own `send`, not
        // back to this trait method.
        Response::<proto::JoinProject>::send(self, result)
    }
}
1804
1805fn join_project_internal(
1806 response: impl JoinProjectInternalResponse,
1807 session: Session,
1808 project: &mut Project,
1809 replica_id: &ReplicaId,
1810) -> Result<()> {
1811 let collaborators = project
1812 .collaborators
1813 .iter()
1814 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1815 .map(|collaborator| collaborator.to_proto())
1816 .collect::<Vec<_>>();
1817 let project_id = project.id;
1818 let guest_user_id = session.user_id();
1819
1820 let worktrees = project
1821 .worktrees
1822 .iter()
1823 .map(|(id, worktree)| proto::WorktreeMetadata {
1824 id: *id,
1825 root_name: worktree.root_name.clone(),
1826 visible: worktree.visible,
1827 abs_path: worktree.abs_path.clone(),
1828 })
1829 .collect::<Vec<_>>();
1830
1831 let add_project_collaborator = proto::AddProjectCollaborator {
1832 project_id: project_id.to_proto(),
1833 collaborator: Some(proto::Collaborator {
1834 peer_id: Some(session.connection_id.into()),
1835 replica_id: replica_id.0 as u32,
1836 user_id: guest_user_id.to_proto(),
1837 is_host: false,
1838 }),
1839 };
1840
1841 for collaborator in &collaborators {
1842 session
1843 .peer
1844 .send(
1845 collaborator.peer_id.unwrap().into(),
1846 add_project_collaborator.clone(),
1847 )
1848 .trace_err();
1849 }
1850
1851 // First, we send the metadata associated with each worktree.
1852 response.send(proto::JoinProjectResponse {
1853 project_id: project.id.0 as u64,
1854 worktrees: worktrees.clone(),
1855 replica_id: replica_id.0 as u32,
1856 collaborators: collaborators.clone(),
1857 language_servers: project.language_servers.clone(),
1858 role: project.role.into(),
1859 })?;
1860
1861 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1862 // Stream this worktree's entries.
1863 let message = proto::UpdateWorktree {
1864 project_id: project_id.to_proto(),
1865 worktree_id,
1866 abs_path: worktree.abs_path.clone(),
1867 root_name: worktree.root_name,
1868 updated_entries: worktree.entries,
1869 removed_entries: Default::default(),
1870 scan_id: worktree.scan_id,
1871 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1872 updated_repositories: worktree.repository_entries.into_values().collect(),
1873 removed_repositories: Default::default(),
1874 };
1875 for update in proto::split_worktree_update(message) {
1876 session.peer.send(session.connection_id, update.clone())?;
1877 }
1878
1879 // Stream this worktree's diagnostics.
1880 for summary in worktree.diagnostic_summaries {
1881 session.peer.send(
1882 session.connection_id,
1883 proto::UpdateDiagnosticSummary {
1884 project_id: project_id.to_proto(),
1885 worktree_id: worktree.id,
1886 summary: Some(summary),
1887 },
1888 )?;
1889 }
1890
1891 for settings_file in worktree.settings_files {
1892 session.peer.send(
1893 session.connection_id,
1894 proto::UpdateWorktreeSettings {
1895 project_id: project_id.to_proto(),
1896 worktree_id: worktree.id,
1897 path: settings_file.path,
1898 content: Some(settings_file.content),
1899 kind: Some(settings_file.kind.to_proto() as i32),
1900 },
1901 )?;
1902 }
1903 }
1904
1905 for language_server in &project.language_servers {
1906 session.peer.send(
1907 session.connection_id,
1908 proto::UpdateLanguageServer {
1909 project_id: project_id.to_proto(),
1910 language_server_id: language_server.id,
1911 variant: Some(
1912 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1913 proto::LspDiskBasedDiagnosticsUpdated {},
1914 ),
1915 ),
1916 },
1917 )?;
1918 }
1919
1920 Ok(())
1921}
1922
1923/// Leave someone elses shared project.
1924async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1925 let sender_id = session.connection_id;
1926 let project_id = ProjectId::from_proto(request.project_id);
1927 let db = session.db().await;
1928
1929 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1930 tracing::info!(
1931 %project_id,
1932 "leave project"
1933 );
1934
1935 project_left(project, &session);
1936 if let Some(room) = room {
1937 room_updated(room, &session.peer);
1938 }
1939
1940 Ok(())
1941}
1942
1943/// Updates other participants with changes to the project
1944async fn update_project(
1945 request: proto::UpdateProject,
1946 response: Response<proto::UpdateProject>,
1947 session: Session,
1948) -> Result<()> {
1949 let project_id = ProjectId::from_proto(request.project_id);
1950 let (room, guest_connection_ids) = &*session
1951 .db()
1952 .await
1953 .update_project(project_id, session.connection_id, &request.worktrees)
1954 .await?;
1955 broadcast(
1956 Some(session.connection_id),
1957 guest_connection_ids.iter().copied(),
1958 |connection_id| {
1959 session
1960 .peer
1961 .forward_send(session.connection_id, connection_id, request.clone())
1962 },
1963 );
1964 if let Some(room) = room {
1965 room_updated(room, &session.peer);
1966 }
1967 response.send(proto::Ack {})?;
1968
1969 Ok(())
1970}
1971
1972/// Updates other participants with changes to the worktree
1973async fn update_worktree(
1974 request: proto::UpdateWorktree,
1975 response: Response<proto::UpdateWorktree>,
1976 session: Session,
1977) -> Result<()> {
1978 let guest_connection_ids = session
1979 .db()
1980 .await
1981 .update_worktree(&request, session.connection_id)
1982 .await?;
1983
1984 broadcast(
1985 Some(session.connection_id),
1986 guest_connection_ids.iter().copied(),
1987 |connection_id| {
1988 session
1989 .peer
1990 .forward_send(session.connection_id, connection_id, request.clone())
1991 },
1992 );
1993 response.send(proto::Ack {})?;
1994 Ok(())
1995}
1996
1997/// Updates other participants with changes to the diagnostics
1998async fn update_diagnostic_summary(
1999 message: proto::UpdateDiagnosticSummary,
2000 session: Session,
2001) -> Result<()> {
2002 let guest_connection_ids = session
2003 .db()
2004 .await
2005 .update_diagnostic_summary(&message, session.connection_id)
2006 .await?;
2007
2008 broadcast(
2009 Some(session.connection_id),
2010 guest_connection_ids.iter().copied(),
2011 |connection_id| {
2012 session
2013 .peer
2014 .forward_send(session.connection_id, connection_id, message.clone())
2015 },
2016 );
2017
2018 Ok(())
2019}
2020
2021/// Updates other participants with changes to the worktree settings
2022async fn update_worktree_settings(
2023 message: proto::UpdateWorktreeSettings,
2024 session: Session,
2025) -> Result<()> {
2026 let guest_connection_ids = session
2027 .db()
2028 .await
2029 .update_worktree_settings(&message, session.connection_id)
2030 .await?;
2031
2032 broadcast(
2033 Some(session.connection_id),
2034 guest_connection_ids.iter().copied(),
2035 |connection_id| {
2036 session
2037 .peer
2038 .forward_send(session.connection_id, connection_id, message.clone())
2039 },
2040 );
2041
2042 Ok(())
2043}
2044
2045/// Notify other participants that a language server has started.
2046async fn start_language_server(
2047 request: proto::StartLanguageServer,
2048 session: Session,
2049) -> Result<()> {
2050 let guest_connection_ids = session
2051 .db()
2052 .await
2053 .start_language_server(&request, session.connection_id)
2054 .await?;
2055
2056 broadcast(
2057 Some(session.connection_id),
2058 guest_connection_ids.iter().copied(),
2059 |connection_id| {
2060 session
2061 .peer
2062 .forward_send(session.connection_id, connection_id, request.clone())
2063 },
2064 );
2065 Ok(())
2066}
2067
2068/// Notify other participants that a language server has changed.
2069async fn update_language_server(
2070 request: proto::UpdateLanguageServer,
2071 session: Session,
2072) -> Result<()> {
2073 let project_id = ProjectId::from_proto(request.project_id);
2074 let project_connection_ids = session
2075 .db()
2076 .await
2077 .project_connection_ids(project_id, session.connection_id, true)
2078 .await?;
2079 broadcast(
2080 Some(session.connection_id),
2081 project_connection_ids.iter().copied(),
2082 |connection_id| {
2083 session
2084 .peer
2085 .forward_send(session.connection_id, connection_id, request.clone())
2086 },
2087 );
2088 Ok(())
2089}
2090
2091/// forward a project request to the host. These requests should be read only
2092/// as guests are allowed to send them.
2093async fn forward_read_only_project_request<T>(
2094 request: T,
2095 response: Response<T>,
2096 session: Session,
2097) -> Result<()>
2098where
2099 T: EntityMessage + RequestMessage,
2100{
2101 let project_id = ProjectId::from_proto(request.remote_entity_id());
2102 let host_connection_id = session
2103 .db()
2104 .await
2105 .host_for_read_only_project_request(project_id, session.connection_id)
2106 .await?;
2107 let payload = session
2108 .peer
2109 .forward_request(session.connection_id, host_connection_id, request)
2110 .await?;
2111 response.send(payload)?;
2112 Ok(())
2113}
2114
2115async fn forward_find_search_candidates_request(
2116 request: proto::FindSearchCandidates,
2117 response: Response<proto::FindSearchCandidates>,
2118 session: Session,
2119) -> Result<()> {
2120 let project_id = ProjectId::from_proto(request.remote_entity_id());
2121 let host_connection_id = session
2122 .db()
2123 .await
2124 .host_for_read_only_project_request(project_id, session.connection_id)
2125 .await?;
2126 let payload = session
2127 .peer
2128 .forward_request(session.connection_id, host_connection_id, request)
2129 .await?;
2130 response.send(payload)?;
2131 Ok(())
2132}
2133
2134/// forward a project request to the host. These requests are disallowed
2135/// for guests.
2136async fn forward_mutating_project_request<T>(
2137 request: T,
2138 response: Response<T>,
2139 session: Session,
2140) -> Result<()>
2141where
2142 T: EntityMessage + RequestMessage,
2143{
2144 let project_id = ProjectId::from_proto(request.remote_entity_id());
2145
2146 let host_connection_id = session
2147 .db()
2148 .await
2149 .host_for_mutating_project_request(project_id, session.connection_id)
2150 .await?;
2151 let payload = session
2152 .peer
2153 .forward_request(session.connection_id, host_connection_id, request)
2154 .await?;
2155 response.send(payload)?;
2156 Ok(())
2157}
2158
2159/// Notify other participants that a new buffer has been created
2160async fn create_buffer_for_peer(
2161 request: proto::CreateBufferForPeer,
2162 session: Session,
2163) -> Result<()> {
2164 session
2165 .db()
2166 .await
2167 .check_user_is_project_host(
2168 ProjectId::from_proto(request.project_id),
2169 session.connection_id,
2170 )
2171 .await?;
2172 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2173 session
2174 .peer
2175 .forward_send(session.connection_id, peer_id.into(), request)?;
2176 Ok(())
2177}
2178
/// Notify other participants that a buffer has been updated. This is
/// allowed for guests as long as the update is limited to selections.
///
/// Guest connections receive the update via `forward_send`. The host (unless
/// the host is the sender) receives it via `forward_request`, and this handler
/// awaits the host's reply before acknowledging the sender.
async fn update_buffer(
    request: proto::UpdateBuffer,
    response: Response<proto::UpdateBuffer>,
    session: Session,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    // Start at the weakest capability; any operation other than an empty one
    // or a selection update requires write access.
    let mut capability = Capability::ReadOnly;

    for op in request.operations.iter() {
        match op.variant {
            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
            Some(_) => capability = Capability::ReadWrite,
        }
    }

    // The database guard is confined to this block so it is dropped before the
    // host request below is awaited.
    let host = {
        let guard = session
            .db()
            .await
            .connections_for_buffer_update(project_id, session.connection_id, capability)
            .await?;

        let (host, guests) = &*guard;

        // `Some(session.connection_id)` is passed as the sender, presumably so
        // `broadcast` skips it.
        broadcast(
            Some(session.connection_id),
            guests.clone(),
            |connection_id| {
                session
                    .peer
                    .forward_send(session.connection_id, connection_id, request.clone())
            },
        );

        *host
    };

    // A host-originated update doesn't need to be forwarded back to itself.
    if host != session.connection_id {
        session
            .peer
            .forward_request(session.connection_id, host, request.clone())
            .await?;
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2228
2229async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2230 let project_id = ProjectId::from_proto(message.project_id);
2231
2232 let operation = message.operation.as_ref().context("invalid operation")?;
2233 let capability = match operation.variant.as_ref() {
2234 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2235 if let Some(buffer_op) = buffer_op.operation.as_ref() {
2236 match buffer_op.variant {
2237 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2238 Capability::ReadOnly
2239 }
2240 _ => Capability::ReadWrite,
2241 }
2242 } else {
2243 Capability::ReadWrite
2244 }
2245 }
2246 Some(_) => Capability::ReadWrite,
2247 None => Capability::ReadOnly,
2248 };
2249
2250 let guard = session
2251 .db()
2252 .await
2253 .connections_for_buffer_update(project_id, session.connection_id, capability)
2254 .await?;
2255
2256 let (host, guests) = &*guard;
2257
2258 broadcast(
2259 Some(session.connection_id),
2260 guests.iter().chain([host]).copied(),
2261 |connection_id| {
2262 session
2263 .peer
2264 .forward_send(session.connection_id, connection_id, message.clone())
2265 },
2266 );
2267
2268 Ok(())
2269}
2270
2271/// Notify other participants that a project has been updated.
2272async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2273 request: T,
2274 session: Session,
2275) -> Result<()> {
2276 let project_id = ProjectId::from_proto(request.remote_entity_id());
2277 let project_connection_ids = session
2278 .db()
2279 .await
2280 .project_connection_ids(project_id, session.connection_id, false)
2281 .await?;
2282
2283 broadcast(
2284 Some(session.connection_id),
2285 project_connection_ids.iter().copied(),
2286 |connection_id| {
2287 session
2288 .peer
2289 .forward_send(session.connection_id, connection_id, request.clone())
2290 },
2291 );
2292 Ok(())
2293}
2294
2295/// Start following another user in a call.
2296async fn follow(
2297 request: proto::Follow,
2298 response: Response<proto::Follow>,
2299 session: Session,
2300) -> Result<()> {
2301 let room_id = RoomId::from_proto(request.room_id);
2302 let project_id = request.project_id.map(ProjectId::from_proto);
2303 let leader_id = request
2304 .leader_id
2305 .ok_or_else(|| anyhow!("invalid leader id"))?
2306 .into();
2307 let follower_id = session.connection_id;
2308
2309 session
2310 .db()
2311 .await
2312 .check_room_participants(room_id, leader_id, session.connection_id)
2313 .await?;
2314
2315 let response_payload = session
2316 .peer
2317 .forward_request(session.connection_id, leader_id, request)
2318 .await?;
2319 response.send(response_payload)?;
2320
2321 if let Some(project_id) = project_id {
2322 let room = session
2323 .db()
2324 .await
2325 .follow(room_id, project_id, leader_id, follower_id)
2326 .await?;
2327 room_updated(&room, &session.peer);
2328 }
2329
2330 Ok(())
2331}
2332
2333/// Stop following another user in a call.
2334async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2335 let room_id = RoomId::from_proto(request.room_id);
2336 let project_id = request.project_id.map(ProjectId::from_proto);
2337 let leader_id = request
2338 .leader_id
2339 .ok_or_else(|| anyhow!("invalid leader id"))?
2340 .into();
2341 let follower_id = session.connection_id;
2342
2343 session
2344 .db()
2345 .await
2346 .check_room_participants(room_id, leader_id, session.connection_id)
2347 .await?;
2348
2349 session
2350 .peer
2351 .forward_send(session.connection_id, leader_id, request)?;
2352
2353 if let Some(project_id) = project_id {
2354 let room = session
2355 .db()
2356 .await
2357 .unfollow(room_id, project_id, leader_id, follower_id)
2358 .await?;
2359 room_updated(&room, &session.peer);
2360 }
2361
2362 Ok(())
2363}
2364
2365/// Notify everyone following you of your current location.
2366async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2367 let room_id = RoomId::from_proto(request.room_id);
2368 let database = session.db.lock().await;
2369
2370 let connection_ids = if let Some(project_id) = request.project_id {
2371 let project_id = ProjectId::from_proto(project_id);
2372 database
2373 .project_connection_ids(project_id, session.connection_id, true)
2374 .await?
2375 } else {
2376 database
2377 .room_connection_ids(room_id, session.connection_id)
2378 .await?
2379 };
2380
2381 // For now, don't send view update messages back to that view's current leader.
2382 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2383 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2384 _ => None,
2385 });
2386
2387 for connection_id in connection_ids.iter().cloned() {
2388 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2389 session
2390 .peer
2391 .forward_send(session.connection_id, connection_id, request.clone())?;
2392 }
2393 }
2394 Ok(())
2395}
2396
2397/// Get public data about users.
2398async fn get_users(
2399 request: proto::GetUsers,
2400 response: Response<proto::GetUsers>,
2401 session: Session,
2402) -> Result<()> {
2403 let user_ids = request
2404 .user_ids
2405 .into_iter()
2406 .map(UserId::from_proto)
2407 .collect();
2408 let users = session
2409 .db()
2410 .await
2411 .get_users_by_ids(user_ids)
2412 .await?
2413 .into_iter()
2414 .map(|user| proto::User {
2415 id: user.id.to_proto(),
2416 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2417 github_login: user.github_login,
2418 })
2419 .collect();
2420 response.send(proto::UsersResponse { users })?;
2421 Ok(())
2422}
2423
2424/// Search for users (to invite) buy Github login
2425async fn fuzzy_search_users(
2426 request: proto::FuzzySearchUsers,
2427 response: Response<proto::FuzzySearchUsers>,
2428 session: Session,
2429) -> Result<()> {
2430 let query = request.query;
2431 let users = match query.len() {
2432 0 => vec![],
2433 1 | 2 => session
2434 .db()
2435 .await
2436 .get_user_by_github_login(&query)
2437 .await?
2438 .into_iter()
2439 .collect(),
2440 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2441 };
2442 let users = users
2443 .into_iter()
2444 .filter(|user| user.id != session.user_id())
2445 .map(|user| proto::User {
2446 id: user.id.to_proto(),
2447 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2448 github_login: user.github_login,
2449 })
2450 .collect();
2451 response.send(proto::UsersResponse { users })?;
2452 Ok(())
2453}
2454
2455/// Send a contact request to another user.
2456async fn request_contact(
2457 request: proto::RequestContact,
2458 response: Response<proto::RequestContact>,
2459 session: Session,
2460) -> Result<()> {
2461 let requester_id = session.user_id();
2462 let responder_id = UserId::from_proto(request.responder_id);
2463 if requester_id == responder_id {
2464 return Err(anyhow!("cannot add yourself as a contact"))?;
2465 }
2466
2467 let notifications = session
2468 .db()
2469 .await
2470 .send_contact_request(requester_id, responder_id)
2471 .await?;
2472
2473 // Update outgoing contact requests of requester
2474 let mut update = proto::UpdateContacts::default();
2475 update.outgoing_requests.push(responder_id.to_proto());
2476 for connection_id in session
2477 .connection_pool()
2478 .await
2479 .user_connection_ids(requester_id)
2480 {
2481 session.peer.send(connection_id, update.clone())?;
2482 }
2483
2484 // Update incoming contact requests of responder
2485 let mut update = proto::UpdateContacts::default();
2486 update
2487 .incoming_requests
2488 .push(proto::IncomingContactRequest {
2489 requester_id: requester_id.to_proto(),
2490 });
2491 let connection_pool = session.connection_pool().await;
2492 for connection_id in connection_pool.user_connection_ids(responder_id) {
2493 session.peer.send(connection_id, update.clone())?;
2494 }
2495
2496 send_notifications(&connection_pool, &session.peer, notifications);
2497
2498 response.send(proto::Ack {})?;
2499 Ok(())
2500}
2501
2502/// Accept or decline a contact request
2503async fn respond_to_contact_request(
2504 request: proto::RespondToContactRequest,
2505 response: Response<proto::RespondToContactRequest>,
2506 session: Session,
2507) -> Result<()> {
2508 let responder_id = session.user_id();
2509 let requester_id = UserId::from_proto(request.requester_id);
2510 let db = session.db().await;
2511 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2512 db.dismiss_contact_notification(responder_id, requester_id)
2513 .await?;
2514 } else {
2515 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2516
2517 let notifications = db
2518 .respond_to_contact_request(responder_id, requester_id, accept)
2519 .await?;
2520 let requester_busy = db.is_user_busy(requester_id).await?;
2521 let responder_busy = db.is_user_busy(responder_id).await?;
2522
2523 let pool = session.connection_pool().await;
2524 // Update responder with new contact
2525 let mut update = proto::UpdateContacts::default();
2526 if accept {
2527 update
2528 .contacts
2529 .push(contact_for_user(requester_id, requester_busy, &pool));
2530 }
2531 update
2532 .remove_incoming_requests
2533 .push(requester_id.to_proto());
2534 for connection_id in pool.user_connection_ids(responder_id) {
2535 session.peer.send(connection_id, update.clone())?;
2536 }
2537
2538 // Update requester with new contact
2539 let mut update = proto::UpdateContacts::default();
2540 if accept {
2541 update
2542 .contacts
2543 .push(contact_for_user(responder_id, responder_busy, &pool));
2544 }
2545 update
2546 .remove_outgoing_requests
2547 .push(responder_id.to_proto());
2548
2549 for connection_id in pool.user_connection_ids(requester_id) {
2550 session.peer.send(connection_id, update.clone())?;
2551 }
2552
2553 send_notifications(&pool, &session.peer, notifications);
2554 }
2555
2556 response.send(proto::Ack {})?;
2557 Ok(())
2558}
2559
2560/// Remove a contact.
2561async fn remove_contact(
2562 request: proto::RemoveContact,
2563 response: Response<proto::RemoveContact>,
2564 session: Session,
2565) -> Result<()> {
2566 let requester_id = session.user_id();
2567 let responder_id = UserId::from_proto(request.user_id);
2568 let db = session.db().await;
2569 let (contact_accepted, deleted_notification_id) =
2570 db.remove_contact(requester_id, responder_id).await?;
2571
2572 let pool = session.connection_pool().await;
2573 // Update outgoing contact requests of requester
2574 let mut update = proto::UpdateContacts::default();
2575 if contact_accepted {
2576 update.remove_contacts.push(responder_id.to_proto());
2577 } else {
2578 update
2579 .remove_outgoing_requests
2580 .push(responder_id.to_proto());
2581 }
2582 for connection_id in pool.user_connection_ids(requester_id) {
2583 session.peer.send(connection_id, update.clone())?;
2584 }
2585
2586 // Update incoming contact requests of responder
2587 let mut update = proto::UpdateContacts::default();
2588 if contact_accepted {
2589 update.remove_contacts.push(requester_id.to_proto());
2590 } else {
2591 update
2592 .remove_incoming_requests
2593 .push(requester_id.to_proto());
2594 }
2595 for connection_id in pool.user_connection_ids(responder_id) {
2596 session.peer.send(connection_id, update.clone())?;
2597 if let Some(notification_id) = deleted_notification_id {
2598 session.peer.send(
2599 connection_id,
2600 proto::DeleteNotification {
2601 notification_id: notification_id.to_proto(),
2602 },
2603 )?;
2604 }
2605 }
2606
2607 response.send(proto::Ack {})?;
2608 Ok(())
2609}
2610
2611fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2612 version.0.minor() < 139
2613}
2614
2615async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
2616 let plan = session.current_plan(&session.db().await).await?;
2617
2618 session
2619 .peer
2620 .send(
2621 session.connection_id,
2622 proto::UpdateUserPlan { plan: plan.into() },
2623 )
2624 .trace_err();
2625
2626 Ok(())
2627}
2628
2629async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2630 subscribe_user_to_channels(session.user_id(), &session).await?;
2631 Ok(())
2632}
2633
2634async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2635 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2636 let mut pool = session.connection_pool().await;
2637 for membership in &channels_for_user.channel_memberships {
2638 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2639 }
2640 session.peer.send(
2641 session.connection_id,
2642 build_update_user_channels(&channels_for_user),
2643 )?;
2644 session.peer.send(
2645 session.connection_id,
2646 build_channels_update(channels_for_user),
2647 )?;
2648 Ok(())
2649}
2650
2651/// Creates a new channel.
2652async fn create_channel(
2653 request: proto::CreateChannel,
2654 response: Response<proto::CreateChannel>,
2655 session: Session,
2656) -> Result<()> {
2657 let db = session.db().await;
2658
2659 let parent_id = request.parent_id.map(ChannelId::from_proto);
2660 let (channel, membership) = db
2661 .create_channel(&request.name, parent_id, session.user_id())
2662 .await?;
2663
2664 let root_id = channel.root_id();
2665 let channel = Channel::from_model(channel);
2666
2667 response.send(proto::CreateChannelResponse {
2668 channel: Some(channel.to_proto()),
2669 parent_id: request.parent_id,
2670 })?;
2671
2672 let mut connection_pool = session.connection_pool().await;
2673 if let Some(membership) = membership {
2674 connection_pool.subscribe_to_channel(
2675 membership.user_id,
2676 membership.channel_id,
2677 membership.role,
2678 );
2679 let update = proto::UpdateUserChannels {
2680 channel_memberships: vec![proto::ChannelMembership {
2681 channel_id: membership.channel_id.to_proto(),
2682 role: membership.role.into(),
2683 }],
2684 ..Default::default()
2685 };
2686 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2687 session.peer.send(connection_id, update.clone())?;
2688 }
2689 }
2690
2691 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2692 if !role.can_see_channel(channel.visibility) {
2693 continue;
2694 }
2695
2696 let update = proto::UpdateChannels {
2697 channels: vec![channel.to_proto()],
2698 ..Default::default()
2699 };
2700 session.peer.send(connection_id, update.clone())?;
2701 }
2702
2703 Ok(())
2704}
2705
2706/// Delete a channel
2707async fn delete_channel(
2708 request: proto::DeleteChannel,
2709 response: Response<proto::DeleteChannel>,
2710 session: Session,
2711) -> Result<()> {
2712 let db = session.db().await;
2713
2714 let channel_id = request.channel_id;
2715 let (root_channel, removed_channels) = db
2716 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2717 .await?;
2718 response.send(proto::Ack {})?;
2719
2720 // Notify members of removed channels
2721 let mut update = proto::UpdateChannels::default();
2722 update
2723 .delete_channels
2724 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2725
2726 let connection_pool = session.connection_pool().await;
2727 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2728 session.peer.send(connection_id, update.clone())?;
2729 }
2730
2731 Ok(())
2732}
2733
2734/// Invite someone to join a channel.
2735async fn invite_channel_member(
2736 request: proto::InviteChannelMember,
2737 response: Response<proto::InviteChannelMember>,
2738 session: Session,
2739) -> Result<()> {
2740 let db = session.db().await;
2741 let channel_id = ChannelId::from_proto(request.channel_id);
2742 let invitee_id = UserId::from_proto(request.user_id);
2743 let InviteMemberResult {
2744 channel,
2745 notifications,
2746 } = db
2747 .invite_channel_member(
2748 channel_id,
2749 invitee_id,
2750 session.user_id(),
2751 request.role().into(),
2752 )
2753 .await?;
2754
2755 let update = proto::UpdateChannels {
2756 channel_invitations: vec![channel.to_proto()],
2757 ..Default::default()
2758 };
2759
2760 let connection_pool = session.connection_pool().await;
2761 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2762 session.peer.send(connection_id, update.clone())?;
2763 }
2764
2765 send_notifications(&connection_pool, &session.peer, notifications);
2766
2767 response.send(proto::Ack {})?;
2768 Ok(())
2769}
2770
2771/// remove someone from a channel
2772async fn remove_channel_member(
2773 request: proto::RemoveChannelMember,
2774 response: Response<proto::RemoveChannelMember>,
2775 session: Session,
2776) -> Result<()> {
2777 let db = session.db().await;
2778 let channel_id = ChannelId::from_proto(request.channel_id);
2779 let member_id = UserId::from_proto(request.user_id);
2780
2781 let RemoveChannelMemberResult {
2782 membership_update,
2783 notification_id,
2784 } = db
2785 .remove_channel_member(channel_id, member_id, session.user_id())
2786 .await?;
2787
2788 let mut connection_pool = session.connection_pool().await;
2789 notify_membership_updated(
2790 &mut connection_pool,
2791 membership_update,
2792 member_id,
2793 &session.peer,
2794 );
2795 for connection_id in connection_pool.user_connection_ids(member_id) {
2796 if let Some(notification_id) = notification_id {
2797 session
2798 .peer
2799 .send(
2800 connection_id,
2801 proto::DeleteNotification {
2802 notification_id: notification_id.to_proto(),
2803 },
2804 )
2805 .trace_err();
2806 }
2807 }
2808
2809 response.send(proto::Ack {})?;
2810 Ok(())
2811}
2812
2813/// Toggle the channel between public and private.
2814/// Care is taken to maintain the invariant that public channels only descend from public channels,
2815/// (though members-only channels can appear at any point in the hierarchy).
2816async fn set_channel_visibility(
2817 request: proto::SetChannelVisibility,
2818 response: Response<proto::SetChannelVisibility>,
2819 session: Session,
2820) -> Result<()> {
2821 let db = session.db().await;
2822 let channel_id = ChannelId::from_proto(request.channel_id);
2823 let visibility = request.visibility().into();
2824
2825 let channel_model = db
2826 .set_channel_visibility(channel_id, visibility, session.user_id())
2827 .await?;
2828 let root_id = channel_model.root_id();
2829 let channel = Channel::from_model(channel_model);
2830
2831 let mut connection_pool = session.connection_pool().await;
2832 for (user_id, role) in connection_pool
2833 .channel_user_ids(root_id)
2834 .collect::<Vec<_>>()
2835 .into_iter()
2836 {
2837 let update = if role.can_see_channel(channel.visibility) {
2838 connection_pool.subscribe_to_channel(user_id, channel_id, role);
2839 proto::UpdateChannels {
2840 channels: vec![channel.to_proto()],
2841 ..Default::default()
2842 }
2843 } else {
2844 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2845 proto::UpdateChannels {
2846 delete_channels: vec![channel.id.to_proto()],
2847 ..Default::default()
2848 }
2849 };
2850
2851 for connection_id in connection_pool.user_connection_ids(user_id) {
2852 session.peer.send(connection_id, update.clone())?;
2853 }
2854 }
2855
2856 response.send(proto::Ack {})?;
2857 Ok(())
2858}
2859
2860/// Alter the role for a user in the channel.
2861async fn set_channel_member_role(
2862 request: proto::SetChannelMemberRole,
2863 response: Response<proto::SetChannelMemberRole>,
2864 session: Session,
2865) -> Result<()> {
2866 let db = session.db().await;
2867 let channel_id = ChannelId::from_proto(request.channel_id);
2868 let member_id = UserId::from_proto(request.user_id);
2869 let result = db
2870 .set_channel_member_role(
2871 channel_id,
2872 session.user_id(),
2873 member_id,
2874 request.role().into(),
2875 )
2876 .await?;
2877
2878 match result {
2879 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2880 let mut connection_pool = session.connection_pool().await;
2881 notify_membership_updated(
2882 &mut connection_pool,
2883 membership_update,
2884 member_id,
2885 &session.peer,
2886 )
2887 }
2888 db::SetMemberRoleResult::InviteUpdated(channel) => {
2889 let update = proto::UpdateChannels {
2890 channel_invitations: vec![channel.to_proto()],
2891 ..Default::default()
2892 };
2893
2894 for connection_id in session
2895 .connection_pool()
2896 .await
2897 .user_connection_ids(member_id)
2898 {
2899 session.peer.send(connection_id, update.clone())?;
2900 }
2901 }
2902 }
2903
2904 response.send(proto::Ack {})?;
2905 Ok(())
2906}
2907
2908/// Change the name of a channel
2909async fn rename_channel(
2910 request: proto::RenameChannel,
2911 response: Response<proto::RenameChannel>,
2912 session: Session,
2913) -> Result<()> {
2914 let db = session.db().await;
2915 let channel_id = ChannelId::from_proto(request.channel_id);
2916 let channel_model = db
2917 .rename_channel(channel_id, session.user_id(), &request.name)
2918 .await?;
2919 let root_id = channel_model.root_id();
2920 let channel = Channel::from_model(channel_model);
2921
2922 response.send(proto::RenameChannelResponse {
2923 channel: Some(channel.to_proto()),
2924 })?;
2925
2926 let connection_pool = session.connection_pool().await;
2927 let update = proto::UpdateChannels {
2928 channels: vec![channel.to_proto()],
2929 ..Default::default()
2930 };
2931 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2932 if role.can_see_channel(channel.visibility) {
2933 session.peer.send(connection_id, update.clone())?;
2934 }
2935 }
2936
2937 Ok(())
2938}
2939
2940/// Move a channel to a new parent.
2941async fn move_channel(
2942 request: proto::MoveChannel,
2943 response: Response<proto::MoveChannel>,
2944 session: Session,
2945) -> Result<()> {
2946 let channel_id = ChannelId::from_proto(request.channel_id);
2947 let to = ChannelId::from_proto(request.to);
2948
2949 let (root_id, channels) = session
2950 .db()
2951 .await
2952 .move_channel(channel_id, to, session.user_id())
2953 .await?;
2954
2955 let connection_pool = session.connection_pool().await;
2956 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2957 let channels = channels
2958 .iter()
2959 .filter_map(|channel| {
2960 if role.can_see_channel(channel.visibility) {
2961 Some(channel.to_proto())
2962 } else {
2963 None
2964 }
2965 })
2966 .collect::<Vec<_>>();
2967 if channels.is_empty() {
2968 continue;
2969 }
2970
2971 let update = proto::UpdateChannels {
2972 channels,
2973 ..Default::default()
2974 };
2975
2976 session.peer.send(connection_id, update.clone())?;
2977 }
2978
2979 response.send(Ack {})?;
2980 Ok(())
2981}
2982
2983/// Get the list of channel members
2984async fn get_channel_members(
2985 request: proto::GetChannelMembers,
2986 response: Response<proto::GetChannelMembers>,
2987 session: Session,
2988) -> Result<()> {
2989 let db = session.db().await;
2990 let channel_id = ChannelId::from_proto(request.channel_id);
2991 let limit = if request.limit == 0 {
2992 u16::MAX as u64
2993 } else {
2994 request.limit
2995 };
2996 let (members, users) = db
2997 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
2998 .await?;
2999 response.send(proto::GetChannelMembersResponse { members, users })?;
3000 Ok(())
3001}
3002
3003/// Accept or decline a channel invitation.
3004async fn respond_to_channel_invite(
3005 request: proto::RespondToChannelInvite,
3006 response: Response<proto::RespondToChannelInvite>,
3007 session: Session,
3008) -> Result<()> {
3009 let db = session.db().await;
3010 let channel_id = ChannelId::from_proto(request.channel_id);
3011 let RespondToChannelInvite {
3012 membership_update,
3013 notifications,
3014 } = db
3015 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3016 .await?;
3017
3018 let mut connection_pool = session.connection_pool().await;
3019 if let Some(membership_update) = membership_update {
3020 notify_membership_updated(
3021 &mut connection_pool,
3022 membership_update,
3023 session.user_id(),
3024 &session.peer,
3025 );
3026 } else {
3027 let update = proto::UpdateChannels {
3028 remove_channel_invitations: vec![channel_id.to_proto()],
3029 ..Default::default()
3030 };
3031
3032 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3033 session.peer.send(connection_id, update.clone())?;
3034 }
3035 };
3036
3037 send_notifications(&connection_pool, &session.peer, notifications);
3038
3039 response.send(proto::Ack {})?;
3040
3041 Ok(())
3042}
3043
3044/// Join the channels' room
3045async fn join_channel(
3046 request: proto::JoinChannel,
3047 response: Response<proto::JoinChannel>,
3048 session: Session,
3049) -> Result<()> {
3050 let channel_id = ChannelId::from_proto(request.channel_id);
3051 join_channel_internal(channel_id, Box::new(response), session).await
3052}
3053
/// Abstracts over the response handles that `join_channel_internal` can reply
/// through. Both `JoinChannel` and `JoinRoom` reply with a
/// `proto::JoinRoomResponse`.
trait JoinChannelInternalResponse {
    /// Send the join result back to the requester.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Delegate to `Response`'s own `send`, named explicitly so the call
        // cannot be confused with this trait method.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Same delegation as above, for the `JoinRoom` response handle.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3067
/// Shared implementation for joining a channel's room, replying through any
/// `JoinChannelInternalResponse` handle.
///
/// Steps:
/// 1. Clean up a stale room connection left behind by this user.
/// 2. Join the channel in the database.
/// 3. Mint a LiveKit token if LiveKit is configured. `Guest` role users get a
///    guest token and `can_publish: false`.
/// 4. Reply to the caller.
/// 5. Publish membership and room updates.
/// 6. Broadcast the channel update and refresh the user's contacts.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // This block scopes the database guard and the first connection-pool guard;
    // both are dropped before `channel_updated` re-acquires the pool below.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the guard while leaving the stale room (which presumably
            // takes the database itself), then re-acquire it.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when LiveKit is not configured or token minting fails (the
        // failure is logged by `trace_err`).
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // A joined room is expected to carry its channel here; its absence is an error.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3163
3164/// Start editing the channel notes
3165async fn join_channel_buffer(
3166 request: proto::JoinChannelBuffer,
3167 response: Response<proto::JoinChannelBuffer>,
3168 session: Session,
3169) -> Result<()> {
3170 let db = session.db().await;
3171 let channel_id = ChannelId::from_proto(request.channel_id);
3172
3173 let open_response = db
3174 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3175 .await?;
3176
3177 let collaborators = open_response.collaborators.clone();
3178 response.send(open_response)?;
3179
3180 let update = UpdateChannelBufferCollaborators {
3181 channel_id: channel_id.to_proto(),
3182 collaborators: collaborators.clone(),
3183 };
3184 channel_buffer_updated(
3185 session.connection_id,
3186 collaborators
3187 .iter()
3188 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3189 &update,
3190 &session.peer,
3191 );
3192
3193 Ok(())
3194}
3195
3196/// Edit the channel notes
3197async fn update_channel_buffer(
3198 request: proto::UpdateChannelBuffer,
3199 session: Session,
3200) -> Result<()> {
3201 let db = session.db().await;
3202 let channel_id = ChannelId::from_proto(request.channel_id);
3203
3204 let (collaborators, epoch, version) = db
3205 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3206 .await?;
3207
3208 channel_buffer_updated(
3209 session.connection_id,
3210 collaborators.clone(),
3211 &proto::UpdateChannelBuffer {
3212 channel_id: channel_id.to_proto(),
3213 operations: request.operations,
3214 },
3215 &session.peer,
3216 );
3217
3218 let pool = &*session.connection_pool().await;
3219
3220 let non_collaborators =
3221 pool.channel_connection_ids(channel_id)
3222 .filter_map(|(connection_id, _)| {
3223 if collaborators.contains(&connection_id) {
3224 None
3225 } else {
3226 Some(connection_id)
3227 }
3228 });
3229
3230 broadcast(None, non_collaborators, |peer_id| {
3231 session.peer.send(
3232 peer_id,
3233 proto::UpdateChannels {
3234 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3235 channel_id: channel_id.to_proto(),
3236 epoch: epoch as u64,
3237 version: version.clone(),
3238 }],
3239 ..Default::default()
3240 },
3241 )
3242 });
3243
3244 Ok(())
3245}
3246
3247/// Rejoin the channel notes after a connection blip
3248async fn rejoin_channel_buffers(
3249 request: proto::RejoinChannelBuffers,
3250 response: Response<proto::RejoinChannelBuffers>,
3251 session: Session,
3252) -> Result<()> {
3253 let db = session.db().await;
3254 let buffers = db
3255 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3256 .await?;
3257
3258 for rejoined_buffer in &buffers {
3259 let collaborators_to_notify = rejoined_buffer
3260 .buffer
3261 .collaborators
3262 .iter()
3263 .filter_map(|c| Some(c.peer_id?.into()));
3264 channel_buffer_updated(
3265 session.connection_id,
3266 collaborators_to_notify,
3267 &proto::UpdateChannelBufferCollaborators {
3268 channel_id: rejoined_buffer.buffer.channel_id,
3269 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3270 },
3271 &session.peer,
3272 );
3273 }
3274
3275 response.send(proto::RejoinChannelBuffersResponse {
3276 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3277 })?;
3278
3279 Ok(())
3280}
3281
3282/// Stop editing the channel notes
3283async fn leave_channel_buffer(
3284 request: proto::LeaveChannelBuffer,
3285 response: Response<proto::LeaveChannelBuffer>,
3286 session: Session,
3287) -> Result<()> {
3288 let db = session.db().await;
3289 let channel_id = ChannelId::from_proto(request.channel_id);
3290
3291 let left_buffer = db
3292 .leave_channel_buffer(channel_id, session.connection_id)
3293 .await?;
3294
3295 response.send(Ack {})?;
3296
3297 channel_buffer_updated(
3298 session.connection_id,
3299 left_buffer.connections,
3300 &proto::UpdateChannelBufferCollaborators {
3301 channel_id: channel_id.to_proto(),
3302 collaborators: left_buffer.collaborators,
3303 },
3304 &session.peer,
3305 );
3306
3307 Ok(())
3308}
3309
3310fn channel_buffer_updated<T: EnvelopedMessage>(
3311 sender_id: ConnectionId,
3312 collaborators: impl IntoIterator<Item = ConnectionId>,
3313 message: &T,
3314 peer: &Peer,
3315) {
3316 broadcast(Some(sender_id), collaborators, |peer_id| {
3317 peer.send(peer_id, message.clone())
3318 });
3319}
3320
3321fn send_notifications(
3322 connection_pool: &ConnectionPool,
3323 peer: &Peer,
3324 notifications: db::NotificationBatch,
3325) {
3326 for (user_id, notification) in notifications {
3327 for connection_id in connection_pool.user_connection_ids(user_id) {
3328 if let Err(error) = peer.send(
3329 connection_id,
3330 proto::AddNotification {
3331 notification: Some(notification.clone()),
3332 },
3333 ) {
3334 tracing::error!(
3335 "failed to send notification to {:?} {}",
3336 connection_id,
3337 error
3338 );
3339 }
3340 }
3341 }
3342}
3343
3344/// Send a message to the channel
3345async fn send_channel_message(
3346 request: proto::SendChannelMessage,
3347 response: Response<proto::SendChannelMessage>,
3348 session: Session,
3349) -> Result<()> {
3350 // Validate the message body.
3351 let body = request.body.trim().to_string();
3352 if body.len() > MAX_MESSAGE_LEN {
3353 return Err(anyhow!("message is too long"))?;
3354 }
3355 if body.is_empty() {
3356 return Err(anyhow!("message can't be blank"))?;
3357 }
3358
3359 // TODO: adjust mentions if body is trimmed
3360
3361 let timestamp = OffsetDateTime::now_utc();
3362 let nonce = request
3363 .nonce
3364 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3365
3366 let channel_id = ChannelId::from_proto(request.channel_id);
3367 let CreatedChannelMessage {
3368 message_id,
3369 participant_connection_ids,
3370 notifications,
3371 } = session
3372 .db()
3373 .await
3374 .create_channel_message(
3375 channel_id,
3376 session.user_id(),
3377 &body,
3378 &request.mentions,
3379 timestamp,
3380 nonce.clone().into(),
3381 request.reply_to_message_id.map(MessageId::from_proto),
3382 )
3383 .await?;
3384
3385 let message = proto::ChannelMessage {
3386 sender_id: session.user_id().to_proto(),
3387 id: message_id.to_proto(),
3388 body,
3389 mentions: request.mentions,
3390 timestamp: timestamp.unix_timestamp() as u64,
3391 nonce: Some(nonce),
3392 reply_to_message_id: request.reply_to_message_id,
3393 edited_at: None,
3394 };
3395 broadcast(
3396 Some(session.connection_id),
3397 participant_connection_ids.clone(),
3398 |connection| {
3399 session.peer.send(
3400 connection,
3401 proto::ChannelMessageSent {
3402 channel_id: channel_id.to_proto(),
3403 message: Some(message.clone()),
3404 },
3405 )
3406 },
3407 );
3408 response.send(proto::SendChannelMessageResponse {
3409 message: Some(message),
3410 })?;
3411
3412 let pool = &*session.connection_pool().await;
3413 let non_participants =
3414 pool.channel_connection_ids(channel_id)
3415 .filter_map(|(connection_id, _)| {
3416 if participant_connection_ids.contains(&connection_id) {
3417 None
3418 } else {
3419 Some(connection_id)
3420 }
3421 });
3422 broadcast(None, non_participants, |peer_id| {
3423 session.peer.send(
3424 peer_id,
3425 proto::UpdateChannels {
3426 latest_channel_message_ids: vec![proto::ChannelMessageId {
3427 channel_id: channel_id.to_proto(),
3428 message_id: message_id.to_proto(),
3429 }],
3430 ..Default::default()
3431 },
3432 )
3433 });
3434 send_notifications(pool, &session.peer, notifications);
3435
3436 Ok(())
3437}
3438
3439/// Delete a channel message
3440async fn remove_channel_message(
3441 request: proto::RemoveChannelMessage,
3442 response: Response<proto::RemoveChannelMessage>,
3443 session: Session,
3444) -> Result<()> {
3445 let channel_id = ChannelId::from_proto(request.channel_id);
3446 let message_id = MessageId::from_proto(request.message_id);
3447 let (connection_ids, existing_notification_ids) = session
3448 .db()
3449 .await
3450 .remove_channel_message(channel_id, message_id, session.user_id())
3451 .await?;
3452
3453 broadcast(
3454 Some(session.connection_id),
3455 connection_ids,
3456 move |connection| {
3457 session.peer.send(connection, request.clone())?;
3458
3459 for notification_id in &existing_notification_ids {
3460 session.peer.send(
3461 connection,
3462 proto::DeleteNotification {
3463 notification_id: (*notification_id).to_proto(),
3464 },
3465 )?;
3466 }
3467
3468 Ok(())
3469 },
3470 );
3471 response.send(proto::Ack {})?;
3472 Ok(())
3473}
3474
3475async fn update_channel_message(
3476 request: proto::UpdateChannelMessage,
3477 response: Response<proto::UpdateChannelMessage>,
3478 session: Session,
3479) -> Result<()> {
3480 let channel_id = ChannelId::from_proto(request.channel_id);
3481 let message_id = MessageId::from_proto(request.message_id);
3482 let updated_at = OffsetDateTime::now_utc();
3483 let UpdatedChannelMessage {
3484 message_id,
3485 participant_connection_ids,
3486 notifications,
3487 reply_to_message_id,
3488 timestamp,
3489 deleted_mention_notification_ids,
3490 updated_mention_notifications,
3491 } = session
3492 .db()
3493 .await
3494 .update_channel_message(
3495 channel_id,
3496 message_id,
3497 session.user_id(),
3498 request.body.as_str(),
3499 &request.mentions,
3500 updated_at,
3501 )
3502 .await?;
3503
3504 let nonce = request
3505 .nonce
3506 .clone()
3507 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3508
3509 let message = proto::ChannelMessage {
3510 sender_id: session.user_id().to_proto(),
3511 id: message_id.to_proto(),
3512 body: request.body.clone(),
3513 mentions: request.mentions.clone(),
3514 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3515 nonce: Some(nonce),
3516 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3517 edited_at: Some(updated_at.unix_timestamp() as u64),
3518 };
3519
3520 response.send(proto::Ack {})?;
3521
3522 let pool = &*session.connection_pool().await;
3523 broadcast(
3524 Some(session.connection_id),
3525 participant_connection_ids,
3526 |connection| {
3527 session.peer.send(
3528 connection,
3529 proto::ChannelMessageUpdate {
3530 channel_id: channel_id.to_proto(),
3531 message: Some(message.clone()),
3532 },
3533 )?;
3534
3535 for notification_id in &deleted_mention_notification_ids {
3536 session.peer.send(
3537 connection,
3538 proto::DeleteNotification {
3539 notification_id: (*notification_id).to_proto(),
3540 },
3541 )?;
3542 }
3543
3544 for notification in &updated_mention_notifications {
3545 session.peer.send(
3546 connection,
3547 proto::UpdateNotification {
3548 notification: Some(notification.clone()),
3549 },
3550 )?;
3551 }
3552
3553 Ok(())
3554 },
3555 );
3556
3557 send_notifications(pool, &session.peer, notifications);
3558
3559 Ok(())
3560}
3561
3562/// Mark a channel message as read
3563async fn acknowledge_channel_message(
3564 request: proto::AckChannelMessage,
3565 session: Session,
3566) -> Result<()> {
3567 let channel_id = ChannelId::from_proto(request.channel_id);
3568 let message_id = MessageId::from_proto(request.message_id);
3569 let notifications = session
3570 .db()
3571 .await
3572 .observe_channel_message(channel_id, session.user_id(), message_id)
3573 .await?;
3574 send_notifications(
3575 &*session.connection_pool().await,
3576 &session.peer,
3577 notifications,
3578 );
3579 Ok(())
3580}
3581
3582/// Mark a buffer version as synced
3583async fn acknowledge_buffer_version(
3584 request: proto::AckBufferOperation,
3585 session: Session,
3586) -> Result<()> {
3587 let buffer_id = BufferId::from_proto(request.buffer_id);
3588 session
3589 .db()
3590 .await
3591 .observe_buffer_version(
3592 buffer_id,
3593 session.user_id(),
3594 request.epoch as i32,
3595 &request.version,
3596 )
3597 .await?;
3598 Ok(())
3599}
3600
3601async fn count_language_model_tokens(
3602 request: proto::CountLanguageModelTokens,
3603 response: Response<proto::CountLanguageModelTokens>,
3604 session: Session,
3605 config: &Config,
3606) -> Result<()> {
3607 authorize_access_to_legacy_llm_endpoints(&session).await?;
3608
3609 let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3610 proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
3611 proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
3612 };
3613
3614 session
3615 .app_state
3616 .rate_limiter
3617 .check(&*rate_limit, session.user_id())
3618 .await?;
3619
3620 let result = match proto::LanguageModelProvider::from_i32(request.provider) {
3621 Some(proto::LanguageModelProvider::Google) => {
3622 let api_key = config
3623 .google_ai_api_key
3624 .as_ref()
3625 .context("no Google AI API key configured on the server")?;
3626 google_ai::count_tokens(
3627 session.http_client.as_ref(),
3628 google_ai::API_URL,
3629 api_key,
3630 serde_json::from_str(&request.request)?,
3631 )
3632 .await?
3633 }
3634 _ => return Err(anyhow!("unsupported provider"))?,
3635 };
3636
3637 response.send(proto::CountLanguageModelTokensResponse {
3638 token_count: result.total_tokens as u32,
3639 })?;
3640
3641 Ok(())
3642}
3643
3644struct ZedProCountLanguageModelTokensRateLimit;
3645
3646impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
3647 fn capacity(&self) -> usize {
3648 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
3649 .ok()
3650 .and_then(|v| v.parse().ok())
3651 .unwrap_or(600) // Picked arbitrarily
3652 }
3653
3654 fn refill_duration(&self) -> chrono::Duration {
3655 chrono::Duration::hours(1)
3656 }
3657
3658 fn db_name(&self) -> &'static str {
3659 "zed-pro:count-language-model-tokens"
3660 }
3661}
3662
3663struct FreeCountLanguageModelTokensRateLimit;
3664
3665impl RateLimit for FreeCountLanguageModelTokensRateLimit {
3666 fn capacity(&self) -> usize {
3667 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
3668 .ok()
3669 .and_then(|v| v.parse().ok())
3670 .unwrap_or(600 / 10) // Picked arbitrarily
3671 }
3672
3673 fn refill_duration(&self) -> chrono::Duration {
3674 chrono::Duration::hours(1)
3675 }
3676
3677 fn db_name(&self) -> &'static str {
3678 "free:count-language-model-tokens"
3679 }
3680}
3681
3682struct ZedProComputeEmbeddingsRateLimit;
3683
3684impl RateLimit for ZedProComputeEmbeddingsRateLimit {
3685 fn capacity(&self) -> usize {
3686 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
3687 .ok()
3688 .and_then(|v| v.parse().ok())
3689 .unwrap_or(5000) // Picked arbitrarily
3690 }
3691
3692 fn refill_duration(&self) -> chrono::Duration {
3693 chrono::Duration::hours(1)
3694 }
3695
3696 fn db_name(&self) -> &'static str {
3697 "zed-pro:compute-embeddings"
3698 }
3699}
3700
3701struct FreeComputeEmbeddingsRateLimit;
3702
3703impl RateLimit for FreeComputeEmbeddingsRateLimit {
3704 fn capacity(&self) -> usize {
3705 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
3706 .ok()
3707 .and_then(|v| v.parse().ok())
3708 .unwrap_or(5000 / 10) // Picked arbitrarily
3709 }
3710
3711 fn refill_duration(&self) -> chrono::Duration {
3712 chrono::Duration::hours(1)
3713 }
3714
3715 fn db_name(&self) -> &'static str {
3716 "free:compute-embeddings"
3717 }
3718}
3719
3720async fn compute_embeddings(
3721 request: proto::ComputeEmbeddings,
3722 response: Response<proto::ComputeEmbeddings>,
3723 session: Session,
3724 api_key: Option<Arc<str>>,
3725) -> Result<()> {
3726 let api_key = api_key.context("no OpenAI API key configured on the server")?;
3727 authorize_access_to_legacy_llm_endpoints(&session).await?;
3728
3729 let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3730 proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
3731 proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
3732 };
3733
3734 session
3735 .app_state
3736 .rate_limiter
3737 .check(&*rate_limit, session.user_id())
3738 .await?;
3739
3740 let embeddings = match request.model.as_str() {
3741 "openai/text-embedding-3-small" => {
3742 open_ai::embed(
3743 session.http_client.as_ref(),
3744 OPEN_AI_API_URL,
3745 &api_key,
3746 OpenAiEmbeddingModel::TextEmbedding3Small,
3747 request.texts.iter().map(|text| text.as_str()),
3748 )
3749 .await?
3750 }
3751 provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
3752 };
3753
3754 let embeddings = request
3755 .texts
3756 .iter()
3757 .map(|text| {
3758 let mut hasher = sha2::Sha256::new();
3759 hasher.update(text.as_bytes());
3760 let result = hasher.finalize();
3761 result.to_vec()
3762 })
3763 .zip(
3764 embeddings
3765 .data
3766 .into_iter()
3767 .map(|embedding| embedding.embedding),
3768 )
3769 .collect::<HashMap<_, _>>();
3770
3771 let db = session.db().await;
3772 db.save_embeddings(&request.model, &embeddings)
3773 .await
3774 .context("failed to save embeddings")
3775 .trace_err();
3776
3777 response.send(proto::ComputeEmbeddingsResponse {
3778 embeddings: embeddings
3779 .into_iter()
3780 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3781 .collect(),
3782 })?;
3783 Ok(())
3784}
3785
3786async fn get_cached_embeddings(
3787 request: proto::GetCachedEmbeddings,
3788 response: Response<proto::GetCachedEmbeddings>,
3789 session: Session,
3790) -> Result<()> {
3791 authorize_access_to_legacy_llm_endpoints(&session).await?;
3792
3793 let db = session.db().await;
3794 let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
3795
3796 response.send(proto::GetCachedEmbeddingsResponse {
3797 embeddings: embeddings
3798 .into_iter()
3799 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3800 .collect(),
3801 })?;
3802 Ok(())
3803}
3804
3805/// This is leftover from before the LLM service.
3806///
3807/// The endpoints protected by this check will be moved there eventually.
3808async fn authorize_access_to_legacy_llm_endpoints(session: &Session) -> Result<(), Error> {
3809 if session.is_staff() {
3810 Ok(())
3811 } else {
3812 Err(anyhow!("permission denied"))?
3813 }
3814}
3815
3816/// Get a Supermaven API key for the user
3817async fn get_supermaven_api_key(
3818 _request: proto::GetSupermavenApiKey,
3819 response: Response<proto::GetSupermavenApiKey>,
3820 session: Session,
3821) -> Result<()> {
3822 let user_id: String = session.user_id().to_string();
3823 if !session.is_staff() {
3824 return Err(anyhow!("supermaven not enabled for this account"))?;
3825 }
3826
3827 let email = session
3828 .email()
3829 .ok_or_else(|| anyhow!("user must have an email"))?;
3830
3831 let supermaven_admin_api = session
3832 .supermaven_client
3833 .as_ref()
3834 .ok_or_else(|| anyhow!("supermaven not configured"))?;
3835
3836 let result = supermaven_admin_api
3837 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3838 .await?;
3839
3840 response.send(proto::GetSupermavenApiKeyResponse {
3841 api_key: result.api_key,
3842 })?;
3843
3844 Ok(())
3845}
3846
3847/// Start receiving chat updates for a channel
3848async fn join_channel_chat(
3849 request: proto::JoinChannelChat,
3850 response: Response<proto::JoinChannelChat>,
3851 session: Session,
3852) -> Result<()> {
3853 let channel_id = ChannelId::from_proto(request.channel_id);
3854
3855 let db = session.db().await;
3856 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3857 .await?;
3858 let messages = db
3859 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3860 .await?;
3861 response.send(proto::JoinChannelChatResponse {
3862 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3863 messages,
3864 })?;
3865 Ok(())
3866}
3867
3868/// Stop receiving chat updates for a channel
3869async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3870 let channel_id = ChannelId::from_proto(request.channel_id);
3871 session
3872 .db()
3873 .await
3874 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3875 .await?;
3876 Ok(())
3877}
3878
3879/// Retrieve the chat history for a channel
3880async fn get_channel_messages(
3881 request: proto::GetChannelMessages,
3882 response: Response<proto::GetChannelMessages>,
3883 session: Session,
3884) -> Result<()> {
3885 let channel_id = ChannelId::from_proto(request.channel_id);
3886 let messages = session
3887 .db()
3888 .await
3889 .get_channel_messages(
3890 channel_id,
3891 session.user_id(),
3892 MESSAGE_COUNT_PER_PAGE,
3893 Some(MessageId::from_proto(request.before_message_id)),
3894 )
3895 .await?;
3896 response.send(proto::GetChannelMessagesResponse {
3897 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3898 messages,
3899 })?;
3900 Ok(())
3901}
3902
3903/// Retrieve specific chat messages
3904async fn get_channel_messages_by_id(
3905 request: proto::GetChannelMessagesById,
3906 response: Response<proto::GetChannelMessagesById>,
3907 session: Session,
3908) -> Result<()> {
3909 let message_ids = request
3910 .message_ids
3911 .iter()
3912 .map(|id| MessageId::from_proto(*id))
3913 .collect::<Vec<_>>();
3914 let messages = session
3915 .db()
3916 .await
3917 .get_channel_messages_by_id(session.user_id(), &message_ids)
3918 .await?;
3919 response.send(proto::GetChannelMessagesResponse {
3920 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3921 messages,
3922 })?;
3923 Ok(())
3924}
3925
3926/// Retrieve the current users notifications
3927async fn get_notifications(
3928 request: proto::GetNotifications,
3929 response: Response<proto::GetNotifications>,
3930 session: Session,
3931) -> Result<()> {
3932 let notifications = session
3933 .db()
3934 .await
3935 .get_notifications(
3936 session.user_id(),
3937 NOTIFICATION_COUNT_PER_PAGE,
3938 request.before_id.map(db::NotificationId::from_proto),
3939 )
3940 .await?;
3941 response.send(proto::GetNotificationsResponse {
3942 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3943 notifications,
3944 })?;
3945 Ok(())
3946}
3947
3948/// Mark notifications as read
3949async fn mark_notification_as_read(
3950 request: proto::MarkNotificationRead,
3951 response: Response<proto::MarkNotificationRead>,
3952 session: Session,
3953) -> Result<()> {
3954 let database = &session.db().await;
3955 let notifications = database
3956 .mark_notification_as_read_by_id(
3957 session.user_id(),
3958 NotificationId::from_proto(request.notification_id),
3959 )
3960 .await?;
3961 send_notifications(
3962 &*session.connection_pool().await,
3963 &session.peer,
3964 notifications,
3965 );
3966 response.send(proto::Ack {})?;
3967 Ok(())
3968}
3969
3970/// Get the current users information
3971async fn get_private_user_info(
3972 _request: proto::GetPrivateUserInfo,
3973 response: Response<proto::GetPrivateUserInfo>,
3974 session: Session,
3975) -> Result<()> {
3976 let db = session.db().await;
3977
3978 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
3979 let user = db
3980 .get_user_by_id(session.user_id())
3981 .await?
3982 .ok_or_else(|| anyhow!("user not found"))?;
3983 let flags = db.get_user_flags(session.user_id()).await?;
3984
3985 response.send(proto::GetPrivateUserInfoResponse {
3986 metrics_id,
3987 staff: user.admin,
3988 flags,
3989 accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
3990 })?;
3991 Ok(())
3992}
3993
3994/// Accept the terms of service (tos) on behalf of the current user
3995async fn accept_terms_of_service(
3996 _request: proto::AcceptTermsOfService,
3997 response: Response<proto::AcceptTermsOfService>,
3998 session: Session,
3999) -> Result<()> {
4000 let db = session.db().await;
4001
4002 let accepted_tos_at = Utc::now();
4003 db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4004 .await?;
4005
4006 response.send(proto::AcceptTermsOfServiceResponse {
4007 accepted_tos_at: accepted_tos_at.timestamp() as u64,
4008 })?;
4009 Ok(())
4010}
4011
/// The minimum account age an account must have in order to use the LLM service.
///
/// `get_llm_api_token` compares this against the older of the user's
/// `created_at` and `github_user_created_at`. The check is skipped for users
/// with an LLM subscription or the `bypass-account-age-check` feature flag.
const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4014
4015async fn get_llm_api_token(
4016 _request: proto::GetLlmToken,
4017 response: Response<proto::GetLlmToken>,
4018 session: Session,
4019) -> Result<()> {
4020 let db = session.db().await;
4021
4022 let flags = db.get_user_flags(session.user_id()).await?;
4023 let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
4024 let has_llm_closed_beta_feature_flag = flags.iter().any(|flag| flag == "llm-closed-beta");
4025
4026 if !session.is_staff() && !has_language_models_feature_flag {
4027 Err(anyhow!("permission denied"))?
4028 }
4029
4030 let user_id = session.user_id();
4031 let user = db
4032 .get_user_by_id(user_id)
4033 .await?
4034 .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4035
4036 if user.accepted_tos_at.is_none() {
4037 Err(anyhow!("terms of service not accepted"))?
4038 }
4039
4040 let has_llm_subscription = session.has_llm_subscription(&db).await?;
4041
4042 let bypass_account_age_check =
4043 has_llm_subscription || flags.iter().any(|flag| flag == "bypass-account-age-check");
4044 if !bypass_account_age_check {
4045 let mut account_created_at = user.created_at;
4046 if let Some(github_created_at) = user.github_user_created_at {
4047 account_created_at = account_created_at.min(github_created_at);
4048 }
4049 if Utc::now().naive_utc() - account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE {
4050 Err(anyhow!("account too young"))?
4051 }
4052 }
4053
4054 let billing_preferences = db.get_billing_preferences(user.id).await?;
4055
4056 let token = LlmTokenClaims::create(
4057 &user,
4058 session.is_staff(),
4059 billing_preferences,
4060 has_llm_closed_beta_feature_flag,
4061 has_llm_subscription,
4062 session.current_plan(&db).await?,
4063 session.system_id.clone(),
4064 &session.app_state.config,
4065 )?;
4066 response.send(proto::GetLlmTokenResponse { token })?;
4067 Ok(())
4068}
4069
4070fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4071 let message = match message {
4072 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4073 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4074 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4075 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4076 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4077 code: frame.code.into(),
4078 reason: frame.reason,
4079 })),
4080 // We should never receive a frame while reading the message, according
4081 // to the `tungstenite` maintainers:
4082 //
4083 // > It cannot occur when you read messages from the WebSocket, but it
4084 // > can be used when you want to send the raw frames (e.g. you want to
4085 // > send the frames to the WebSocket without composing the full message first).
4086 // >
4087 // > — https://github.com/snapview/tungstenite-rs/issues/268
4088 TungsteniteMessage::Frame(_) => {
4089 bail!("received an unexpected frame while reading the message")
4090 }
4091 };
4092
4093 Ok(message)
4094}
4095
4096fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4097 match message {
4098 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4099 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4100 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4101 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4102 AxumMessage::Close(frame) => {
4103 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4104 code: frame.code.into(),
4105 reason: frame.reason,
4106 }))
4107 }
4108 }
4109}
4110
4111fn notify_membership_updated(
4112 connection_pool: &mut ConnectionPool,
4113 result: MembershipUpdated,
4114 user_id: UserId,
4115 peer: &Peer,
4116) {
4117 for membership in &result.new_channels.channel_memberships {
4118 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4119 }
4120 for channel_id in &result.removed_channels {
4121 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4122 }
4123
4124 let user_channels_update = proto::UpdateUserChannels {
4125 channel_memberships: result
4126 .new_channels
4127 .channel_memberships
4128 .iter()
4129 .map(|cm| proto::ChannelMembership {
4130 channel_id: cm.channel_id.to_proto(),
4131 role: cm.role.into(),
4132 })
4133 .collect(),
4134 ..Default::default()
4135 };
4136
4137 let mut update = build_channels_update(result.new_channels);
4138 update.delete_channels = result
4139 .removed_channels
4140 .into_iter()
4141 .map(|id| id.to_proto())
4142 .collect();
4143 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4144
4145 for connection_id in connection_pool.user_connection_ids(user_id) {
4146 peer.send(connection_id, user_channels_update.clone())
4147 .trace_err();
4148 peer.send(connection_id, update.clone()).trace_err();
4149 }
4150}
4151
4152fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4153 proto::UpdateUserChannels {
4154 channel_memberships: channels
4155 .channel_memberships
4156 .iter()
4157 .map(|m| proto::ChannelMembership {
4158 channel_id: m.channel_id.to_proto(),
4159 role: m.role.into(),
4160 })
4161 .collect(),
4162 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4163 observed_channel_message_id: channels.observed_channel_messages.clone(),
4164 }
4165}
4166
4167fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4168 let mut update = proto::UpdateChannels::default();
4169
4170 for channel in channels.channels {
4171 update.channels.push(channel.to_proto());
4172 }
4173
4174 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4175 update.latest_channel_message_ids = channels.latest_channel_messages;
4176
4177 for (channel_id, participants) in channels.channel_participants {
4178 update
4179 .channel_participants
4180 .push(proto::ChannelParticipants {
4181 channel_id: channel_id.to_proto(),
4182 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4183 });
4184 }
4185
4186 for channel in channels.invited_channels {
4187 update.channel_invitations.push(channel.to_proto());
4188 }
4189
4190 update
4191}
4192
4193fn build_initial_contacts_update(
4194 contacts: Vec<db::Contact>,
4195 pool: &ConnectionPool,
4196) -> proto::UpdateContacts {
4197 let mut update = proto::UpdateContacts::default();
4198
4199 for contact in contacts {
4200 match contact {
4201 db::Contact::Accepted { user_id, busy } => {
4202 update.contacts.push(contact_for_user(user_id, busy, pool));
4203 }
4204 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4205 db::Contact::Incoming { user_id } => {
4206 update
4207 .incoming_requests
4208 .push(proto::IncomingContactRequest {
4209 requester_id: user_id.to_proto(),
4210 })
4211 }
4212 }
4213 }
4214
4215 update
4216}
4217
4218fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4219 proto::Contact {
4220 user_id: user_id.to_proto(),
4221 online: pool.is_user_online(user_id),
4222 busy,
4223 }
4224}
4225
4226fn room_updated(room: &proto::Room, peer: &Peer) {
4227 broadcast(
4228 None,
4229 room.participants
4230 .iter()
4231 .filter_map(|participant| Some(participant.peer_id?.into())),
4232 |peer_id| {
4233 peer.send(
4234 peer_id,
4235 proto::RoomUpdated {
4236 room: Some(room.clone()),
4237 },
4238 )
4239 },
4240 );
4241}
4242
4243fn channel_updated(
4244 channel: &db::channel::Model,
4245 room: &proto::Room,
4246 peer: &Peer,
4247 pool: &ConnectionPool,
4248) {
4249 let participants = room
4250 .participants
4251 .iter()
4252 .map(|p| p.user_id)
4253 .collect::<Vec<_>>();
4254
4255 broadcast(
4256 None,
4257 pool.channel_connection_ids(channel.root_id())
4258 .filter_map(|(channel_id, role)| {
4259 role.can_see_channel(channel.visibility)
4260 .then_some(channel_id)
4261 }),
4262 |peer_id| {
4263 peer.send(
4264 peer_id,
4265 proto::UpdateChannels {
4266 channel_participants: vec![proto::ChannelParticipants {
4267 channel_id: channel.id.to_proto(),
4268 participant_user_ids: participants.clone(),
4269 }],
4270 ..Default::default()
4271 },
4272 )
4273 },
4274 );
4275}
4276
4277async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4278 let db = session.db().await;
4279
4280 let contacts = db.get_contacts(user_id).await?;
4281 let busy = db.is_user_busy(user_id).await?;
4282
4283 let pool = session.connection_pool().await;
4284 let updated_contact = contact_for_user(user_id, busy, &pool);
4285 for contact in contacts {
4286 if let db::Contact::Accepted {
4287 user_id: contact_user_id,
4288 ..
4289 } = contact
4290 {
4291 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4292 session
4293 .peer
4294 .send(
4295 contact_conn_id,
4296 proto::UpdateContacts {
4297 contacts: vec![updated_contact.clone()],
4298 remove_contacts: Default::default(),
4299 incoming_requests: Default::default(),
4300 remove_incoming_requests: Default::default(),
4301 outgoing_requests: Default::default(),
4302 remove_outgoing_requests: Default::default(),
4303 },
4304 )
4305 .trace_err();
4306 }
4307 }
4308 }
4309 Ok(())
4310}
4311
/// Removes `connection_id` from the room it is in, if any, and notifies
/// everyone affected.
///
/// Returns `Ok(())` immediately when the database reports no room for the
/// connection. Otherwise, in order:
///
/// 1. Handles each project left along with the room (`project_left`).
/// 2. Broadcasts the updated room to its participants.
/// 3. Refreshes the channel's participant list when the room belongs to a
///    channel.
/// 4. Sends `CallCanceled` to users whose calls were canceled.
/// 5. Refreshes the contact status of the session's user and of those users.
/// 6. Removes the user from the LiveKit room, deleting the LiveKit room when
///    the database reports the room was deleted.
///
/// NOTE(review): an error from `update_user_contacts` returns early via `?`,
/// which skips the LiveKit cleanup at the end. Confirm this is acceptable.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    // Users whose contact entries must be re-sent to their contacts.
    let mut contacts_to_update = HashSet::default();

    // Filled in from the database result below. They are declared up front so
    // they outlive the `if let` that assigns them.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // NOTE(review): the LiveKit room name is taken out of `left_room.room`
        // before the room itself is taken. The `room` broadcast below therefore
        // carries the field's `Default` value instead of the name. Confirm this
        // is intended.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        // The connection was not in a room, so there is nothing to clean up.
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the connection pool guard is dropped before
    // `update_user_contacts` below, which acquires the pool again.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    // LiveKit cleanup failures are passed to `trace_err` rather than propagated.
    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4385
4386async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4387 let left_channel_buffers = session
4388 .db()
4389 .await
4390 .leave_channel_buffers(session.connection_id)
4391 .await?;
4392
4393 for left_buffer in left_channel_buffers {
4394 channel_buffer_updated(
4395 session.connection_id,
4396 left_buffer.connections,
4397 &proto::UpdateChannelBufferCollaborators {
4398 channel_id: left_buffer.channel_id.to_proto(),
4399 collaborators: left_buffer.collaborators,
4400 },
4401 &session.peer,
4402 );
4403 }
4404
4405 Ok(())
4406}
4407
4408fn project_left(project: &db::LeftProject, session: &Session) {
4409 for connection_id in &project.connection_ids {
4410 if project.should_unshare {
4411 session
4412 .peer
4413 .send(
4414 *connection_id,
4415 proto::UnshareProject {
4416 project_id: project.id.to_proto(),
4417 },
4418 )
4419 .trace_err();
4420 } else {
4421 session
4422 .peer
4423 .send(
4424 *connection_id,
4425 proto::RemoveProjectCollaborator {
4426 project_id: project.id.to_proto(),
4427 peer_id: Some(session.connection_id.into()),
4428 },
4429 )
4430 .trace_err();
4431 }
4432 }
4433}
4434
4435pub trait ResultExt {
4436 type Ok;
4437
4438 fn trace_err(self) -> Option<Self::Ok>;
4439}
4440
4441impl<T, E> ResultExt for Result<T, E>
4442where
4443 E: std::fmt::Debug,
4444{
4445 type Ok = T;
4446
4447 #[track_caller]
4448 fn trace_err(self) -> Option<T> {
4449 match self {
4450 Ok(value) => Some(value),
4451 Err(error) => {
4452 tracing::error!("{:?}", error);
4453 None
4454 }
4455 }
4456 }
4457}