1mod connection_pool;
2
3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
4use crate::llm::LlmTokenClaims;
5use crate::{
6 auth,
7 db::{
8 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
9 CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
10 NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
11 RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
12 },
13 executor::Executor,
14 AppState, Config, Error, RateLimit, Result,
15};
16use anyhow::{anyhow, bail, Context as _};
17use async_tungstenite::tungstenite::{
18 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
19};
20use axum::{
21 body::Body,
22 extract::{
23 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
24 ConnectInfo, WebSocketUpgrade,
25 },
26 headers::{Header, HeaderName},
27 http::StatusCode,
28 middleware,
29 response::IntoResponse,
30 routing::get,
31 Extension, Router, TypedHeader,
32};
33use chrono::Utc;
34use collections::{HashMap, HashSet};
35pub use connection_pool::{ConnectionPool, ZedVersion};
36use core::fmt::{self, Debug, Formatter};
37use http_client::HttpClient;
38use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
39use reqwest_client::ReqwestClient;
40use sha2::Digest;
41use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
42
43use futures::{
44 channel::oneshot, future::BoxFuture, stream::FuturesUnordered, FutureExt, SinkExt, StreamExt,
45 TryStreamExt,
46};
47use prometheus::{register_int_gauge, IntGauge};
48use rpc::{
49 proto::{
50 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
51 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
52 },
53 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
54};
55use semantic_version::SemanticVersion;
56use serde::{Serialize, Serializer};
57use std::{
58 any::TypeId,
59 future::Future,
60 marker::PhantomData,
61 mem,
62 net::SocketAddr,
63 ops::{Deref, DerefMut},
64 rc::Rc,
65 sync::{
66 atomic::{AtomicBool, Ordering::SeqCst},
67 Arc, OnceLock,
68 },
69 time::{Duration, Instant},
70};
71use time::OffsetDateTime;
72use tokio::sync::{watch, MutexGuard, Semaphore};
73use tower::ServiceBuilder;
74use tracing::{
75 field::{self},
76 info_span, instrument, Instrument,
77};
78
/// Reconnection timeout of 30 seconds.
///
/// NOTE(review): not referenced in this part of the file — presumably the grace period a
/// dropped connection gets to reconnect before its state is released; confirm at the use site.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// How long [`Server::start`] waits before it clears resources left behind by stale servers.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can
/// clean up old resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// Paging and size limits. They are not referenced in this part of the file; the names
// suggest channel-message page size, maximum chat message length, and notification page
// size respectively — TODO confirm at the use sites (including the unit of the length).
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
87
/// Type-erased message handler as stored in `Server::handlers`.
///
/// It receives the boxed envelope together with the [`Session`] of the connection the
/// message arrived on, and returns a boxed future that completes when handling is done.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
90
/// Handle given to a request handler so it can reply to the request of type `R`.
struct Response<R> {
    /// Peer through which the reply is sent.
    peer: Arc<Peer>,
    /// Receipt identifying the request that is being answered.
    receipt: Receipt<R>,
    /// Set to `true` by `send`; `Server::add_request_handler` reads it to detect handlers
    /// that return `Ok` without ever replying.
    responded: Arc<AtomicBool>,
}
96
97impl<R: RequestMessage> Response<R> {
98 fn send(self, payload: R::Response) -> Result<()> {
99 self.responded.store(true, SeqCst);
100 self.peer.respond(self.receipt, payload)?;
101 Ok(())
102 }
103}
104
/// The identity a connection acts as.
#[derive(Clone, Debug)]
pub enum Principal {
    /// A user acting as themselves.
    User(User),
    /// `admin` is acting as `user`; tracing spans record the admin as the "impersonator".
    Impersonated { user: User, admin: User },
}
110
111impl Principal {
112 fn update_span(&self, span: &tracing::Span) {
113 match &self {
114 Principal::User(user) => {
115 span.record("user_id", user.id.0);
116 span.record("login", &user.github_login);
117 }
118 Principal::Impersonated { user, admin } => {
119 span.record("user_id", user.id.0);
120 span.record("login", &user.github_login);
121 span.record("impersonator", &admin.github_login);
122 }
123 }
124 }
125}
126
/// Per-connection state; a clone is handed to every message handler invocation.
#[derive(Clone)]
struct Session {
    /// Identity the connection acts as.
    principal: Principal,
    /// Identifier the peer assigned to this connection.
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; use [`Session::db`] to lock it.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    /// Peer used to send messages and responses.
    peer: Arc<Peer>,
    /// Connection pool shared with the [`Server`]; use [`Session::connection_pool`] to lock it.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// Shared application state.
    app_state: Arc<AppState>,
    /// Supermaven admin API client; `None` when no admin API key is configured.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    /// HTTP client created for this connection.
    http_client: Arc<dyn HttpClient>,
    /// The GeoIP country code for the user.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    /// System identifier passed to `Server::handle_connection` for this connection, if any.
    system_id: Option<String>,
    /// Executor the connection runs on; stored but not read (hence the underscore).
    _executor: Executor,
}
143
144impl Session {
145 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
146 #[cfg(test)]
147 tokio::task::yield_now().await;
148 let guard = self.db.lock().await;
149 #[cfg(test)]
150 tokio::task::yield_now().await;
151 guard
152 }
153
154 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
155 #[cfg(test)]
156 tokio::task::yield_now().await;
157 let guard = self.connection_pool.lock();
158 ConnectionPoolGuard {
159 guard,
160 _not_send: PhantomData,
161 }
162 }
163
164 fn is_staff(&self) -> bool {
165 match &self.principal {
166 Principal::User(user) => user.admin,
167 Principal::Impersonated { .. } => true,
168 }
169 }
170
171 pub async fn has_llm_subscription(
172 &self,
173 db: &MutexGuard<'_, DbHandle>,
174 ) -> anyhow::Result<bool> {
175 if self.is_staff() {
176 return Ok(true);
177 }
178
179 let user_id = self.user_id();
180
181 Ok(db.has_active_billing_subscription(user_id).await?)
182 }
183
184 pub async fn current_plan(
185 &self,
186 _db: &MutexGuard<'_, DbHandle>,
187 ) -> anyhow::Result<proto::Plan> {
188 if self.is_staff() {
189 Ok(proto::Plan::ZedPro)
190 } else {
191 Ok(proto::Plan::Free)
192 }
193 }
194
195 fn user_id(&self) -> UserId {
196 match &self.principal {
197 Principal::User(user) => user.id,
198 Principal::Impersonated { user, .. } => user.id,
199 }
200 }
201
202 pub fn email(&self) -> Option<String> {
203 match &self.principal {
204 Principal::User(user) => user.email_address.clone(),
205 Principal::Impersonated { user, .. } => user.email_address.clone(),
206 }
207 }
208}
209
210impl Debug for Session {
211 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
212 let mut result = f.debug_struct("Session");
213 match &self.principal {
214 Principal::User(user) => {
215 result.field("user", &user.github_login);
216 }
217 Principal::Impersonated { user, admin } => {
218 result.field("user", &user.github_login);
219 result.field("impersonator", &admin.github_login);
220 }
221 }
222 result.field("connection_id", &self.connection_id).finish()
223 }
224}
225
226struct DbHandle(Arc<Database>);
227
228impl Deref for DbHandle {
229 type Target = Database;
230
231 fn deref(&self) -> &Self::Target {
232 self.0.as_ref()
233 }
234}
235
/// The RPC server: owns the peer, the pool of live connections and the table of message
/// handlers.
pub struct Server {
    /// Id of this server instance; behind a mutex so the test-only `reset` can replace it.
    id: parking_lot::Mutex<ServerId>,
    /// Peer that owns the underlying connections.
    peer: Arc<Peer>,
    /// Live connections; shared with every [`Session`].
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// Shared application state.
    app_state: Arc<AppState>,
    /// Registered handlers, keyed by the `TypeId` of the message type they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Teardown signal: `teardown` sends `true`; each connection subscribes to it.
    teardown: watch::Sender<bool>,
}
244
/// Lock guard over the shared [`ConnectionPool`].
pub(crate) struct ConnectionPoolGuard<'a> {
    /// The underlying `parking_lot` mutex guard.
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    /// `Rc` is `!Send`, so this marker makes the guard `!Send` as well —
    /// presumably to stop it being held across an `.await` in a `Send` future; confirm.
    _not_send: PhantomData<Rc<()>>,
}
249
/// Serializable view of the server's peer and connection pool.
///
/// The snapshot holds the connection-pool lock for as long as it is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    /// The server's peer.
    peer: &'a Peer,
    /// Locked connection pool; serialized through the guard via `serialize_deref`.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
256
257pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
258where
259 S: Serializer,
260 T: Deref<Target = U>,
261 U: Serialize,
262{
263 Serialize::serialize(value.deref(), serializer)
264}
265
266impl Server {
    /// Creates the server for `id` and registers the handler for every message type it
    /// serves.
    ///
    /// Handlers added with `add_request_handler` must reply to the sender; those added with
    /// `add_message_handler` only receive the payload.
    ///
    /// # Panics
    ///
    /// Panics if two handlers are registered for the same message type (see `add_handler`).
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            // Only the sender half is kept; connections obtain receivers via `subscribe`.
            teardown: watch::channel(false).0,
        };

        server
            // Rooms and calls.
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(set_room_participant_role)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            // Project sharing and project state updates.
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Project requests registered through the read-only forwarder.
            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
            .add_request_handler(forward_find_search_candidates_request)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::GitBranches>)
            // Project requests registered through the mutating forwarder.
            // NOTE(review): several of these (e.g. `GetCompletions`, `GetCodeActions`,
            // `PrepareRename`, `BlameBuffer`) look like reads by name — presumably they are
            // routed here on purpose; confirm against the forwarder implementations.
            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
            // Buffer traffic and messages broadcast from the project host.
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            // Users and contacts.
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            // Channels, channel buffers, chat and notifications.
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(update_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            .add_request_handler(move_channel)
            // Following, account information and acknowledgements.
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_request_handler(get_private_user_info)
            .add_request_handler(get_llm_api_token)
            .add_request_handler(accept_terms_of_service)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version)
            .add_request_handler(get_supermaven_api_key)
            // Contexts.
            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context)
            // The last handlers need configuration, so they are wrapped in closures that
            // capture a clone of `app_state`.
            .add_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    // Cloned again per call so the returned future owns its own handle.
                    let app_state = app_state.clone();
                    async move {
                        count_language_model_tokens(request, response, session, &app_state.config)
                            .await
                    }
                }
            })
            .add_request_handler(get_cached_embeddings)
            .add_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    compute_embeddings(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                    )
                }
            });

        Arc::new(server)
    }
414
    /// Spawns a detached background task that cleans up after stale servers, then returns
    /// `Ok(())` immediately.
    ///
    /// The task waits for [`CLEANUP_TIMEOUT`], asks the database for the stale room and
    /// channel ids of this environment, and then:
    ///
    /// 1. clears stale collaborators from each channel buffer and sends the refreshed
    ///    collaborator list to the connections the database returns;
    /// 2. clears stale participants from each room, announces the refreshed room (and its
    ///    channel, if any), sends `CallCanceled` for calls that were canceled, pushes a
    ///    contact update for every affected user, and deletes the LiveKit room once no
    ///    participants remain;
    /// 3. deletes the stale server records.
    ///
    /// Every step goes through `trace_err`, so a failing step is skipped instead of
    /// aborting the cleanup.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // The sleep future is created here and awaited inside the spawned task.
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Step 1: channel buffers.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    // Step 2: rooms.
                    for room_id in room_ids {
                        // Collected while refreshing the room, then acted on below once the
                        // refreshed room is no longer borrowed.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            // A room with no participants left also loses its LiveKit room.
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        {
                            // Scoped so the pool lock is released before the next `.await`.
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        for user_id in contacts_to_update {
                            // Both queries finish before the pool is locked.
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    // Only accepted contacts receive the update.
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                // Step 3: runs even when the stale resource ids could not be fetched.
                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
560
561 pub fn teardown(&self) {
562 self.peer.teardown();
563 self.connection_pool.lock().reset();
564 let _ = self.teardown.send(true);
565 }
566
567 #[cfg(test)]
568 pub fn reset(&self, id: ServerId) {
569 self.teardown();
570 *self.id.lock() = id;
571 self.peer.reset(id.0 as u32);
572 let _ = self.teardown.send(false);
573 }
574
575 #[cfg(test)]
576 pub fn id(&self) -> ServerId {
577 *self.id.lock()
578 }
579
    /// Registers `handler` for envelopes carrying messages of type `M`.
    ///
    /// The stored wrapper downcasts the type-erased envelope, logs that the message was
    /// received, invokes `handler`, and — once the returned future completes — logs the
    /// outcome together with three durations in milliseconds:
    ///
    /// * `total_duration_ms`: time since the envelope's `received_at`;
    /// * `processing_duration_ms`: time since the handler was invoked;
    /// * `queue_duration_ms`: the difference between the two.
    ///
    /// # Panics
    ///
    /// Panics if a handler for `M` is already registered. The wrapper itself panics if it
    /// is ever called with an envelope that is not a `TypedEnvelope<M>`.
    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        let prev_handler = self.handlers.insert(
            TypeId::of::<M>(),
            Box::new(move |envelope, session| {
                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
                let received_at = envelope.received_at;
                tracing::info!("message received");
                // Processing time is measured from the moment the handler is invoked.
                let start_time = Instant::now();
                let future = (handler)(*envelope, session);
                async move {
                    let result = future.await;
                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
                    let payload_type = M::NAME;

                    match result {
                        Err(error) => {
                            tracing::error!(
                                ?error,
                                total_duration_ms,
                                processing_duration_ms,
                                queue_duration_ms,
                                payload_type,
                                "error handling message"
                            )
                        }
                        Ok(()) => tracing::info!(
                            total_duration_ms,
                            processing_duration_ms,
                            queue_duration_ms,
                            "finished handling message"
                        ),
                    }
                }
                .boxed()
            }),
        );
        // `insert` returns the previous entry, so `Some` means a duplicate registration.
        if prev_handler.is_some() {
            panic!("registered a handler for the same message twice");
        }
        self
    }
628
629 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
630 where
631 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
632 Fut: 'static + Send + Future<Output = Result<()>>,
633 M: EnvelopedMessage,
634 {
635 self.add_handler(move |envelope, session| handler(envelope.payload, session));
636 self
637 }
638
639 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
640 where
641 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
642 Fut: Send + Future<Output = Result<()>>,
643 M: RequestMessage,
644 {
645 let handler = Arc::new(handler);
646 self.add_handler(move |envelope, session| {
647 let receipt = envelope.receipt();
648 let handler = handler.clone();
649 async move {
650 let peer = session.peer.clone();
651 let responded = Arc::new(AtomicBool::default());
652 let response = Response {
653 peer: peer.clone(),
654 responded: responded.clone(),
655 receipt,
656 };
657 match (handler)(envelope.payload, response, session).await {
658 Ok(()) => {
659 if responded.load(std::sync::atomic::Ordering::SeqCst) {
660 Ok(())
661 } else {
662 Err(anyhow!("handler did not send a response"))?
663 }
664 }
665 Err(error) => {
666 let proto_err = match &error {
667 Error::Internal(err) => err.to_proto(),
668 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
669 };
670 peer.respond_with_error(receipt, proto_err)?;
671 Err(error)
672 }
673 }
674 }
675 })
676 }
677
    /// Returns the future that drives one client connection from start to sign-out.
    ///
    /// The future:
    ///
    /// 1. bails out if the server is already tearing down;
    /// 2. registers `connection` with the peer and builds the connection's [`Session`];
    /// 3. sends the initial client update (see `send_initial_client_update`);
    /// 4. loops, concurrently driving connection I/O, in-flight foreground handlers and the
    ///    dispatch of newly received messages, until the I/O future finishes or the
    ///    incoming stream ends;
    /// 5. runs `connection_lost` to sign the session out.
    ///
    /// A teardown signal observed inside the loop makes the future return immediately,
    /// skipping step 5.
    ///
    /// `send_connection_id`, when provided, receives the connection id right after the
    /// `Hello` message has been sent.
    ///
    /// NOTE(review): the early returns after `add_connection` (HTTP client creation failure,
    /// initial update failure) skip `connection_lost`; whether the peer / pool entries are
    /// released on those paths is not visible here — confirm.
    #[allow(clippy::too_many_arguments)]
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> {
        let this = self.clone();
        // Fields declared as `Empty` are filled in below, as soon as their values are known.
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            geoip_country_code=field::Empty
        );
        principal.update_span(&span);
        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            // Each connection gets its own HTTP client.
            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            // Only available when an admin API key is configured.
            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
                    supermaven_admin_api_key.to_string(),
                    http_client.clone(),
                )));

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                http_client,
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when, e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stopped processing further
            // messages until their respective request completed, neither would get a chance to
            // respond to the other client's request, causing a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages that arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Caps in-flight handlers per connection: a permit is taken before the next message
            // is read and released only when that message's handler has finished.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // `select_biased!` checks its branches in the order written: teardown, then
                // connection I/O, then finished foreground handlers, then the next message.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                // Leave the span before the future is moved; the future itself is
                                // instrumented with the span below.
                                drop(span_enter);

                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background messages run detached on the executor; foreground
                                // ones are driven by this loop.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels foreground handlers that are still in flight.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
827
828 async fn send_initial_client_update(
829 &self,
830 connection_id: ConnectionId,
831 principal: &Principal,
832 zed_version: ZedVersion,
833 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
834 session: &Session,
835 ) -> Result<()> {
836 self.peer.send(
837 connection_id,
838 proto::Hello {
839 peer_id: Some(connection_id.into()),
840 },
841 )?;
842 tracing::info!("sent hello message");
843 if let Some(send_connection_id) = send_connection_id.take() {
844 let _ = send_connection_id.send(connection_id);
845 }
846
847 match principal {
848 Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
849 if !user.connected_once {
850 self.peer.send(connection_id, proto::ShowContacts {})?;
851 self.app_state
852 .db
853 .set_user_connected_once(user.id, true)
854 .await?;
855 }
856
857 update_user_plan(user.id, session).await?;
858
859 let contacts = self.app_state.db.get_contacts(user.id).await?;
860
861 {
862 let mut pool = self.connection_pool.lock();
863 pool.add_connection(connection_id, user.id, user.admin, zed_version);
864 self.peer.send(
865 connection_id,
866 build_initial_contacts_update(contacts, &pool),
867 )?;
868 }
869
870 if should_auto_subscribe_to_channels(zed_version) {
871 subscribe_user_to_channels(user.id, session).await?;
872 }
873
874 if let Some(incoming_call) =
875 self.app_state.db.incoming_call_for_user(user.id).await?
876 {
877 self.peer.send(connection_id, incoming_call)?;
878 }
879
880 update_user_contacts(user.id, session).await?;
881 }
882 }
883
884 Ok(())
885 }
886
887 pub async fn invite_code_redeemed(
888 self: &Arc<Self>,
889 inviter_id: UserId,
890 invitee_id: UserId,
891 ) -> Result<()> {
892 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
893 if let Some(code) = &user.invite_code {
894 let pool = self.connection_pool.lock();
895 let invitee_contact = contact_for_user(invitee_id, false, &pool);
896 for connection_id in pool.user_connection_ids(inviter_id) {
897 self.peer.send(
898 connection_id,
899 proto::UpdateContacts {
900 contacts: vec![invitee_contact.clone()],
901 ..Default::default()
902 },
903 )?;
904 self.peer.send(
905 connection_id,
906 proto::UpdateInviteInfo {
907 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
908 count: user.invite_count as u32,
909 },
910 )?;
911 }
912 }
913 }
914 Ok(())
915 }
916
917 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
918 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
919 if let Some(invite_code) = &user.invite_code {
920 let pool = self.connection_pool.lock();
921 for connection_id in pool.user_connection_ids(user_id) {
922 self.peer.send(
923 connection_id,
924 proto::UpdateInviteInfo {
925 url: format!(
926 "{}{}",
927 self.app_state.config.invite_link_prefix, invite_code
928 ),
929 count: user.invite_count as u32,
930 },
931 )?;
932 }
933 }
934 }
935 Ok(())
936 }
937
938 pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
939 let pool = self.connection_pool.lock();
940 for connection_id in pool.user_connection_ids(user_id) {
941 self.peer
942 .send(connection_id, proto::RefreshLlmToken {})
943 .trace_err();
944 }
945 }
946
947 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
948 ServerSnapshot {
949 connection_pool: ConnectionPoolGuard {
950 guard: self.connection_pool.lock(),
951 _not_send: PhantomData,
952 },
953 peer: &self.peer,
954 }
955 }
956}
957
958impl<'a> Deref for ConnectionPoolGuard<'a> {
959 type Target = ConnectionPool;
960
961 fn deref(&self) -> &Self::Target {
962 &self.guard
963 }
964}
965
966impl<'a> DerefMut for ConnectionPoolGuard<'a> {
967 fn deref_mut(&mut self) -> &mut Self::Target {
968 &mut self.guard
969 }
970}
971
972impl<'a> Drop for ConnectionPoolGuard<'a> {
973 fn drop(&mut self) {
974 #[cfg(test)]
975 self.check_invariants();
976 }
977}
978
979fn broadcast<F>(
980 sender_id: Option<ConnectionId>,
981 receiver_ids: impl IntoIterator<Item = ConnectionId>,
982 mut f: F,
983) where
984 F: FnMut(ConnectionId) -> anyhow::Result<()>,
985{
986 for receiver_id in receiver_ids {
987 if Some(receiver_id) != sender_id {
988 if let Err(error) = f(receiver_id) {
989 tracing::error!("failed to send to {:?} {}", receiver_id, error);
990 }
991 }
992 }
993}
994
995pub struct ProtocolVersion(u32);
996
997impl Header for ProtocolVersion {
998 fn name() -> &'static HeaderName {
999 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1000 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1001 }
1002
1003 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1004 where
1005 Self: Sized,
1006 I: Iterator<Item = &'i axum::http::HeaderValue>,
1007 {
1008 let version = values
1009 .next()
1010 .ok_or_else(axum::headers::Error::invalid)?
1011 .to_str()
1012 .map_err(|_| axum::headers::Error::invalid())?
1013 .parse()
1014 .map_err(|_| axum::headers::Error::invalid())?;
1015 Ok(Self(version))
1016 }
1017
1018 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1019 values.extend([self.0.to_string().parse().unwrap()]);
1020 }
1021}
1022
1023pub struct AppVersionHeader(SemanticVersion);
1024impl Header for AppVersionHeader {
1025 fn name() -> &'static HeaderName {
1026 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1027 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1028 }
1029
1030 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1031 where
1032 Self: Sized,
1033 I: Iterator<Item = &'i axum::http::HeaderValue>,
1034 {
1035 let version = values
1036 .next()
1037 .ok_or_else(axum::headers::Error::invalid)?
1038 .to_str()
1039 .map_err(|_| axum::headers::Error::invalid())?
1040 .parse()
1041 .map_err(|_| axum::headers::Error::invalid())?;
1042 Ok(Self(version))
1043 }
1044
1045 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1046 values.extend([self.0.to_string().parse().unwrap()]);
1047 }
1048}
1049
1050pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1051 Router::new()
1052 .route("/rpc", get(handle_websocket_request))
1053 .layer(
1054 ServiceBuilder::new()
1055 .layer(Extension(server.app_state.clone()))
1056 .layer(middleware::from_fn(auth::validate_header)),
1057 )
1058 .route("/metrics", get(handle_metrics))
1059 .layer(Extension(server))
1060}
1061
1062#[allow(clippy::too_many_arguments)]
1063pub async fn handle_websocket_request(
1064 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1065 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1066 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1067 Extension(server): Extension<Arc<Server>>,
1068 Extension(principal): Extension<Principal>,
1069 country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1070 system_id_header: Option<TypedHeader<SystemIdHeader>>,
1071 ws: WebSocketUpgrade,
1072) -> axum::response::Response {
1073 if protocol_version != rpc::PROTOCOL_VERSION {
1074 return (
1075 StatusCode::UPGRADE_REQUIRED,
1076 "client must be upgraded".to_string(),
1077 )
1078 .into_response();
1079 }
1080
1081 let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1082 return (
1083 StatusCode::UPGRADE_REQUIRED,
1084 "no version header found".to_string(),
1085 )
1086 .into_response();
1087 };
1088
1089 if !version.can_collaborate() {
1090 return (
1091 StatusCode::UPGRADE_REQUIRED,
1092 "client must be upgraded".to_string(),
1093 )
1094 .into_response();
1095 }
1096
1097 let socket_address = socket_address.to_string();
1098 ws.on_upgrade(move |socket| {
1099 let socket = socket
1100 .map_ok(to_tungstenite_message)
1101 .err_into()
1102 .with(|message| async move { to_axum_message(message) });
1103 let connection = Connection::new(Box::pin(socket));
1104 async move {
1105 server
1106 .handle_connection(
1107 connection,
1108 socket_address,
1109 principal,
1110 version,
1111 country_code_header.map(|header| header.to_string()),
1112 system_id_header.map(|header| header.to_string()),
1113 None,
1114 Executor::Production,
1115 )
1116 .await;
1117 }
1118 })
1119}
1120
1121pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1122 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1123 let connections_metric = CONNECTIONS_METRIC
1124 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1125
1126 let connections = server
1127 .connection_pool
1128 .lock()
1129 .connections()
1130 .filter(|connection| !connection.admin)
1131 .count();
1132 connections_metric.set(connections as _);
1133
1134 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1135 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1136 register_int_gauge!(
1137 "shared_projects",
1138 "number of open projects with one or more guests"
1139 )
1140 .unwrap()
1141 });
1142
1143 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1144 shared_projects_metric.set(shared_projects as _);
1145
1146 let encoder = prometheus::TextEncoder::new();
1147 let metric_families = prometheus::gather();
1148 let encoded_metrics = encoder
1149 .encode_to_string(&metric_families)
1150 .map_err(|err| anyhow!("{}", err))?;
1151 Ok(encoded_metrics)
1152}
1153
1154#[instrument(err, skip(executor))]
1155async fn connection_lost(
1156 session: Session,
1157 mut teardown: watch::Receiver<bool>,
1158 executor: Executor,
1159) -> Result<()> {
1160 session.peer.disconnect(session.connection_id);
1161 session
1162 .connection_pool()
1163 .await
1164 .remove_connection(session.connection_id)?;
1165
1166 session
1167 .db()
1168 .await
1169 .connection_lost(session.connection_id)
1170 .await
1171 .trace_err();
1172
1173 futures::select_biased! {
1174 _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1175
1176 log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1177 leave_room_for_session(&session, session.connection_id).await.trace_err();
1178 leave_channel_buffers_for_session(&session)
1179 .await
1180 .trace_err();
1181
1182 if !session
1183 .connection_pool()
1184 .await
1185 .is_user_online(session.user_id())
1186 {
1187 let db = session.db().await;
1188 if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1189 room_updated(&room, &session.peer);
1190 }
1191 }
1192
1193 update_user_contacts(session.user_id(), &session).await?;
1194 },
1195 _ = teardown.changed().fuse() => {}
1196 }
1197
1198 Ok(())
1199}
1200
1201/// Acknowledges a ping from a client, used to keep the connection alive.
1202async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1203 response.send(proto::Ack {})?;
1204 Ok(())
1205}
1206
1207/// Creates a new room for calling (outside of channels)
1208async fn create_room(
1209 _request: proto::CreateRoom,
1210 response: Response<proto::CreateRoom>,
1211 session: Session,
1212) -> Result<()> {
1213 let live_kit_room = nanoid::nanoid!(30);
1214
1215 let live_kit_connection_info = util::maybe!(async {
1216 let live_kit = session.app_state.live_kit_client.as_ref();
1217 let live_kit = live_kit?;
1218 let user_id = session.user_id().to_string();
1219
1220 let token = live_kit
1221 .room_token(&live_kit_room, &user_id.to_string())
1222 .trace_err()?;
1223
1224 Some(proto::LiveKitConnectionInfo {
1225 server_url: live_kit.url().into(),
1226 token,
1227 can_publish: true,
1228 })
1229 })
1230 .await;
1231
1232 let room = session
1233 .db()
1234 .await
1235 .create_room(session.user_id(), session.connection_id, &live_kit_room)
1236 .await?;
1237
1238 response.send(proto::CreateRoomResponse {
1239 room: Some(room.clone()),
1240 live_kit_connection_info,
1241 })?;
1242
1243 update_user_contacts(session.user_id(), &session).await?;
1244 Ok(())
1245}
1246
1247/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1248async fn join_room(
1249 request: proto::JoinRoom,
1250 response: Response<proto::JoinRoom>,
1251 session: Session,
1252) -> Result<()> {
1253 let room_id = RoomId::from_proto(request.id);
1254
1255 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1256
1257 if let Some(channel_id) = channel_id {
1258 return join_channel_internal(channel_id, Box::new(response), session).await;
1259 }
1260
1261 let joined_room = {
1262 let room = session
1263 .db()
1264 .await
1265 .join_room(room_id, session.user_id(), session.connection_id)
1266 .await?;
1267 room_updated(&room.room, &session.peer);
1268 room.into_inner()
1269 };
1270
1271 for connection_id in session
1272 .connection_pool()
1273 .await
1274 .user_connection_ids(session.user_id())
1275 {
1276 session
1277 .peer
1278 .send(
1279 connection_id,
1280 proto::CallCanceled {
1281 room_id: room_id.to_proto(),
1282 },
1283 )
1284 .trace_err();
1285 }
1286
1287 let live_kit_connection_info =
1288 if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
1289 live_kit
1290 .room_token(
1291 &joined_room.room.live_kit_room,
1292 &session.user_id().to_string(),
1293 )
1294 .trace_err()
1295 .map(|token| proto::LiveKitConnectionInfo {
1296 server_url: live_kit.url().into(),
1297 token,
1298 can_publish: true,
1299 })
1300 } else {
1301 None
1302 };
1303
1304 response.send(proto::JoinRoomResponse {
1305 room: Some(joined_room.room),
1306 channel_id: None,
1307 live_kit_connection_info,
1308 })?;
1309
1310 update_user_contacts(session.user_id(), &session).await?;
1311 Ok(())
1312}
1313
1314/// Rejoin room is used to reconnect to a room after connection errors.
1315async fn rejoin_room(
1316 request: proto::RejoinRoom,
1317 response: Response<proto::RejoinRoom>,
1318 session: Session,
1319) -> Result<()> {
1320 let room;
1321 let channel;
1322 {
1323 let mut rejoined_room = session
1324 .db()
1325 .await
1326 .rejoin_room(request, session.user_id(), session.connection_id)
1327 .await?;
1328
1329 response.send(proto::RejoinRoomResponse {
1330 room: Some(rejoined_room.room.clone()),
1331 reshared_projects: rejoined_room
1332 .reshared_projects
1333 .iter()
1334 .map(|project| proto::ResharedProject {
1335 id: project.id.to_proto(),
1336 collaborators: project
1337 .collaborators
1338 .iter()
1339 .map(|collaborator| collaborator.to_proto())
1340 .collect(),
1341 })
1342 .collect(),
1343 rejoined_projects: rejoined_room
1344 .rejoined_projects
1345 .iter()
1346 .map(|rejoined_project| rejoined_project.to_proto())
1347 .collect(),
1348 })?;
1349 room_updated(&rejoined_room.room, &session.peer);
1350
1351 for project in &rejoined_room.reshared_projects {
1352 for collaborator in &project.collaborators {
1353 session
1354 .peer
1355 .send(
1356 collaborator.connection_id,
1357 proto::UpdateProjectCollaborator {
1358 project_id: project.id.to_proto(),
1359 old_peer_id: Some(project.old_connection_id.into()),
1360 new_peer_id: Some(session.connection_id.into()),
1361 },
1362 )
1363 .trace_err();
1364 }
1365
1366 broadcast(
1367 Some(session.connection_id),
1368 project
1369 .collaborators
1370 .iter()
1371 .map(|collaborator| collaborator.connection_id),
1372 |connection_id| {
1373 session.peer.forward_send(
1374 session.connection_id,
1375 connection_id,
1376 proto::UpdateProject {
1377 project_id: project.id.to_proto(),
1378 worktrees: project.worktrees.clone(),
1379 },
1380 )
1381 },
1382 );
1383 }
1384
1385 notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;
1386
1387 let rejoined_room = rejoined_room.into_inner();
1388
1389 room = rejoined_room.room;
1390 channel = rejoined_room.channel;
1391 }
1392
1393 if let Some(channel) = channel {
1394 channel_updated(
1395 &channel,
1396 &room,
1397 &session.peer,
1398 &*session.connection_pool().await,
1399 );
1400 }
1401
1402 update_user_contacts(session.user_id(), &session).await?;
1403 Ok(())
1404}
1405
1406fn notify_rejoined_projects(
1407 rejoined_projects: &mut Vec<RejoinedProject>,
1408 session: &Session,
1409) -> Result<()> {
1410 for project in rejoined_projects.iter() {
1411 for collaborator in &project.collaborators {
1412 session
1413 .peer
1414 .send(
1415 collaborator.connection_id,
1416 proto::UpdateProjectCollaborator {
1417 project_id: project.id.to_proto(),
1418 old_peer_id: Some(project.old_connection_id.into()),
1419 new_peer_id: Some(session.connection_id.into()),
1420 },
1421 )
1422 .trace_err();
1423 }
1424 }
1425
1426 for project in rejoined_projects {
1427 for worktree in mem::take(&mut project.worktrees) {
1428 // Stream this worktree's entries.
1429 let message = proto::UpdateWorktree {
1430 project_id: project.id.to_proto(),
1431 worktree_id: worktree.id,
1432 abs_path: worktree.abs_path.clone(),
1433 root_name: worktree.root_name,
1434 updated_entries: worktree.updated_entries,
1435 removed_entries: worktree.removed_entries,
1436 scan_id: worktree.scan_id,
1437 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1438 updated_repositories: worktree.updated_repositories,
1439 removed_repositories: worktree.removed_repositories,
1440 };
1441 for update in proto::split_worktree_update(message) {
1442 session.peer.send(session.connection_id, update.clone())?;
1443 }
1444
1445 // Stream this worktree's diagnostics.
1446 for summary in worktree.diagnostic_summaries {
1447 session.peer.send(
1448 session.connection_id,
1449 proto::UpdateDiagnosticSummary {
1450 project_id: project.id.to_proto(),
1451 worktree_id: worktree.id,
1452 summary: Some(summary),
1453 },
1454 )?;
1455 }
1456
1457 for settings_file in worktree.settings_files {
1458 session.peer.send(
1459 session.connection_id,
1460 proto::UpdateWorktreeSettings {
1461 project_id: project.id.to_proto(),
1462 worktree_id: worktree.id,
1463 path: settings_file.path,
1464 content: Some(settings_file.content),
1465 kind: Some(settings_file.kind.to_proto().into()),
1466 },
1467 )?;
1468 }
1469 }
1470
1471 for language_server in &project.language_servers {
1472 session.peer.send(
1473 session.connection_id,
1474 proto::UpdateLanguageServer {
1475 project_id: project.id.to_proto(),
1476 language_server_id: language_server.id,
1477 variant: Some(
1478 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1479 proto::LspDiskBasedDiagnosticsUpdated {},
1480 ),
1481 ),
1482 },
1483 )?;
1484 }
1485 }
1486 Ok(())
1487}
1488
1489/// leave room disconnects from the room.
1490async fn leave_room(
1491 _: proto::LeaveRoom,
1492 response: Response<proto::LeaveRoom>,
1493 session: Session,
1494) -> Result<()> {
1495 leave_room_for_session(&session, session.connection_id).await?;
1496 response.send(proto::Ack {})?;
1497 Ok(())
1498}
1499
1500/// Updates the permissions of someone else in the room.
1501async fn set_room_participant_role(
1502 request: proto::SetRoomParticipantRole,
1503 response: Response<proto::SetRoomParticipantRole>,
1504 session: Session,
1505) -> Result<()> {
1506 let user_id = UserId::from_proto(request.user_id);
1507 let role = ChannelRole::from(request.role());
1508
1509 let (live_kit_room, can_publish) = {
1510 let room = session
1511 .db()
1512 .await
1513 .set_room_participant_role(
1514 session.user_id(),
1515 RoomId::from_proto(request.room_id),
1516 user_id,
1517 role,
1518 )
1519 .await?;
1520
1521 let live_kit_room = room.live_kit_room.clone();
1522 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1523 room_updated(&room, &session.peer);
1524 (live_kit_room, can_publish)
1525 };
1526
1527 if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
1528 live_kit
1529 .update_participant(
1530 live_kit_room.clone(),
1531 request.user_id.to_string(),
1532 live_kit_server::proto::ParticipantPermission {
1533 can_subscribe: true,
1534 can_publish,
1535 can_publish_data: can_publish,
1536 hidden: false,
1537 recorder: false,
1538 },
1539 )
1540 .await
1541 .trace_err();
1542 }
1543
1544 response.send(proto::Ack {})?;
1545 Ok(())
1546}
1547
1548/// Call someone else into the current room
1549async fn call(
1550 request: proto::Call,
1551 response: Response<proto::Call>,
1552 session: Session,
1553) -> Result<()> {
1554 let room_id = RoomId::from_proto(request.room_id);
1555 let calling_user_id = session.user_id();
1556 let calling_connection_id = session.connection_id;
1557 let called_user_id = UserId::from_proto(request.called_user_id);
1558 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1559 if !session
1560 .db()
1561 .await
1562 .has_contact(calling_user_id, called_user_id)
1563 .await?
1564 {
1565 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1566 }
1567
1568 let incoming_call = {
1569 let (room, incoming_call) = &mut *session
1570 .db()
1571 .await
1572 .call(
1573 room_id,
1574 calling_user_id,
1575 calling_connection_id,
1576 called_user_id,
1577 initial_project_id,
1578 )
1579 .await?;
1580 room_updated(room, &session.peer);
1581 mem::take(incoming_call)
1582 };
1583 update_user_contacts(called_user_id, &session).await?;
1584
1585 let mut calls = session
1586 .connection_pool()
1587 .await
1588 .user_connection_ids(called_user_id)
1589 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1590 .collect::<FuturesUnordered<_>>();
1591
1592 while let Some(call_response) = calls.next().await {
1593 match call_response.as_ref() {
1594 Ok(_) => {
1595 response.send(proto::Ack {})?;
1596 return Ok(());
1597 }
1598 Err(_) => {
1599 call_response.trace_err();
1600 }
1601 }
1602 }
1603
1604 {
1605 let room = session
1606 .db()
1607 .await
1608 .call_failed(room_id, called_user_id)
1609 .await?;
1610 room_updated(&room, &session.peer);
1611 }
1612 update_user_contacts(called_user_id, &session).await?;
1613
1614 Err(anyhow!("failed to ring user"))?
1615}
1616
1617/// Cancel an outgoing call.
1618async fn cancel_call(
1619 request: proto::CancelCall,
1620 response: Response<proto::CancelCall>,
1621 session: Session,
1622) -> Result<()> {
1623 let called_user_id = UserId::from_proto(request.called_user_id);
1624 let room_id = RoomId::from_proto(request.room_id);
1625 {
1626 let room = session
1627 .db()
1628 .await
1629 .cancel_call(room_id, session.connection_id, called_user_id)
1630 .await?;
1631 room_updated(&room, &session.peer);
1632 }
1633
1634 for connection_id in session
1635 .connection_pool()
1636 .await
1637 .user_connection_ids(called_user_id)
1638 {
1639 session
1640 .peer
1641 .send(
1642 connection_id,
1643 proto::CallCanceled {
1644 room_id: room_id.to_proto(),
1645 },
1646 )
1647 .trace_err();
1648 }
1649 response.send(proto::Ack {})?;
1650
1651 update_user_contacts(called_user_id, &session).await?;
1652 Ok(())
1653}
1654
1655/// Decline an incoming call.
1656async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1657 let room_id = RoomId::from_proto(message.room_id);
1658 {
1659 let room = session
1660 .db()
1661 .await
1662 .decline_call(Some(room_id), session.user_id())
1663 .await?
1664 .ok_or_else(|| anyhow!("failed to decline call"))?;
1665 room_updated(&room, &session.peer);
1666 }
1667
1668 for connection_id in session
1669 .connection_pool()
1670 .await
1671 .user_connection_ids(session.user_id())
1672 {
1673 session
1674 .peer
1675 .send(
1676 connection_id,
1677 proto::CallCanceled {
1678 room_id: room_id.to_proto(),
1679 },
1680 )
1681 .trace_err();
1682 }
1683 update_user_contacts(session.user_id(), &session).await?;
1684 Ok(())
1685}
1686
1687/// Updates other participants in the room with your current location.
1688async fn update_participant_location(
1689 request: proto::UpdateParticipantLocation,
1690 response: Response<proto::UpdateParticipantLocation>,
1691 session: Session,
1692) -> Result<()> {
1693 let room_id = RoomId::from_proto(request.room_id);
1694 let location = request
1695 .location
1696 .ok_or_else(|| anyhow!("invalid location"))?;
1697
1698 let db = session.db().await;
1699 let room = db
1700 .update_room_participant_location(room_id, session.connection_id, location)
1701 .await?;
1702
1703 room_updated(&room, &session.peer);
1704 response.send(proto::Ack {})?;
1705 Ok(())
1706}
1707
1708/// Share a project into the room.
1709async fn share_project(
1710 request: proto::ShareProject,
1711 response: Response<proto::ShareProject>,
1712 session: Session,
1713) -> Result<()> {
1714 let (project_id, room) = &*session
1715 .db()
1716 .await
1717 .share_project(
1718 RoomId::from_proto(request.room_id),
1719 session.connection_id,
1720 &request.worktrees,
1721 request.is_ssh_project,
1722 )
1723 .await?;
1724 response.send(proto::ShareProjectResponse {
1725 project_id: project_id.to_proto(),
1726 })?;
1727 room_updated(room, &session.peer);
1728
1729 Ok(())
1730}
1731
1732/// Unshare a project from the room.
1733async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1734 let project_id = ProjectId::from_proto(message.project_id);
1735 unshare_project_internal(project_id, session.connection_id, &session).await
1736}
1737
1738async fn unshare_project_internal(
1739 project_id: ProjectId,
1740 connection_id: ConnectionId,
1741 session: &Session,
1742) -> Result<()> {
1743 let delete = {
1744 let room_guard = session
1745 .db()
1746 .await
1747 .unshare_project(project_id, connection_id)
1748 .await?;
1749
1750 let (delete, room, guest_connection_ids) = &*room_guard;
1751
1752 let message = proto::UnshareProject {
1753 project_id: project_id.to_proto(),
1754 };
1755
1756 broadcast(
1757 Some(connection_id),
1758 guest_connection_ids.iter().copied(),
1759 |conn_id| session.peer.send(conn_id, message.clone()),
1760 );
1761 if let Some(room) = room {
1762 room_updated(room, &session.peer);
1763 }
1764
1765 *delete
1766 };
1767
1768 if delete {
1769 let db = session.db().await;
1770 db.delete_project(project_id).await?;
1771 }
1772
1773 Ok(())
1774}
1775
1776/// Join someone elses shared project.
1777async fn join_project(
1778 request: proto::JoinProject,
1779 response: Response<proto::JoinProject>,
1780 session: Session,
1781) -> Result<()> {
1782 let project_id = ProjectId::from_proto(request.project_id);
1783
1784 tracing::info!(%project_id, "join project");
1785
1786 let db = session.db().await;
1787 let (project, replica_id) = &mut *db
1788 .join_project(project_id, session.connection_id, session.user_id())
1789 .await?;
1790 drop(db);
1791 tracing::info!(%project_id, "join remote project");
1792 join_project_internal(response, session, project, replica_id)
1793}
1794
1795trait JoinProjectInternalResponse {
1796 fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
1797}
1798impl JoinProjectInternalResponse for Response<proto::JoinProject> {
1799 fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
1800 Response::<proto::JoinProject>::send(self, result)
1801 }
1802}
1803
1804fn join_project_internal(
1805 response: impl JoinProjectInternalResponse,
1806 session: Session,
1807 project: &mut Project,
1808 replica_id: &ReplicaId,
1809) -> Result<()> {
1810 let collaborators = project
1811 .collaborators
1812 .iter()
1813 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1814 .map(|collaborator| collaborator.to_proto())
1815 .collect::<Vec<_>>();
1816 let project_id = project.id;
1817 let guest_user_id = session.user_id();
1818
1819 let worktrees = project
1820 .worktrees
1821 .iter()
1822 .map(|(id, worktree)| proto::WorktreeMetadata {
1823 id: *id,
1824 root_name: worktree.root_name.clone(),
1825 visible: worktree.visible,
1826 abs_path: worktree.abs_path.clone(),
1827 })
1828 .collect::<Vec<_>>();
1829
1830 let add_project_collaborator = proto::AddProjectCollaborator {
1831 project_id: project_id.to_proto(),
1832 collaborator: Some(proto::Collaborator {
1833 peer_id: Some(session.connection_id.into()),
1834 replica_id: replica_id.0 as u32,
1835 user_id: guest_user_id.to_proto(),
1836 is_host: false,
1837 }),
1838 };
1839
1840 for collaborator in &collaborators {
1841 session
1842 .peer
1843 .send(
1844 collaborator.peer_id.unwrap().into(),
1845 add_project_collaborator.clone(),
1846 )
1847 .trace_err();
1848 }
1849
1850 // First, we send the metadata associated with each worktree.
1851 response.send(proto::JoinProjectResponse {
1852 project_id: project.id.0 as u64,
1853 worktrees: worktrees.clone(),
1854 replica_id: replica_id.0 as u32,
1855 collaborators: collaborators.clone(),
1856 language_servers: project.language_servers.clone(),
1857 role: project.role.into(),
1858 })?;
1859
1860 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1861 // Stream this worktree's entries.
1862 let message = proto::UpdateWorktree {
1863 project_id: project_id.to_proto(),
1864 worktree_id,
1865 abs_path: worktree.abs_path.clone(),
1866 root_name: worktree.root_name,
1867 updated_entries: worktree.entries,
1868 removed_entries: Default::default(),
1869 scan_id: worktree.scan_id,
1870 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1871 updated_repositories: worktree.repository_entries.into_values().collect(),
1872 removed_repositories: Default::default(),
1873 };
1874 for update in proto::split_worktree_update(message) {
1875 session.peer.send(session.connection_id, update.clone())?;
1876 }
1877
1878 // Stream this worktree's diagnostics.
1879 for summary in worktree.diagnostic_summaries {
1880 session.peer.send(
1881 session.connection_id,
1882 proto::UpdateDiagnosticSummary {
1883 project_id: project_id.to_proto(),
1884 worktree_id: worktree.id,
1885 summary: Some(summary),
1886 },
1887 )?;
1888 }
1889
1890 for settings_file in worktree.settings_files {
1891 session.peer.send(
1892 session.connection_id,
1893 proto::UpdateWorktreeSettings {
1894 project_id: project_id.to_proto(),
1895 worktree_id: worktree.id,
1896 path: settings_file.path,
1897 content: Some(settings_file.content),
1898 kind: Some(settings_file.kind.to_proto() as i32),
1899 },
1900 )?;
1901 }
1902 }
1903
1904 for language_server in &project.language_servers {
1905 session.peer.send(
1906 session.connection_id,
1907 proto::UpdateLanguageServer {
1908 project_id: project_id.to_proto(),
1909 language_server_id: language_server.id,
1910 variant: Some(
1911 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1912 proto::LspDiskBasedDiagnosticsUpdated {},
1913 ),
1914 ),
1915 },
1916 )?;
1917 }
1918
1919 Ok(())
1920}
1921
1922/// Leave someone elses shared project.
1923async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1924 let sender_id = session.connection_id;
1925 let project_id = ProjectId::from_proto(request.project_id);
1926 let db = session.db().await;
1927
1928 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1929 tracing::info!(
1930 %project_id,
1931 "leave project"
1932 );
1933
1934 project_left(project, &session);
1935 if let Some(room) = room {
1936 room_updated(room, &session.peer);
1937 }
1938
1939 Ok(())
1940}
1941
1942/// Updates other participants with changes to the project
1943async fn update_project(
1944 request: proto::UpdateProject,
1945 response: Response<proto::UpdateProject>,
1946 session: Session,
1947) -> Result<()> {
1948 let project_id = ProjectId::from_proto(request.project_id);
1949 let (room, guest_connection_ids) = &*session
1950 .db()
1951 .await
1952 .update_project(project_id, session.connection_id, &request.worktrees)
1953 .await?;
1954 broadcast(
1955 Some(session.connection_id),
1956 guest_connection_ids.iter().copied(),
1957 |connection_id| {
1958 session
1959 .peer
1960 .forward_send(session.connection_id, connection_id, request.clone())
1961 },
1962 );
1963 if let Some(room) = room {
1964 room_updated(room, &session.peer);
1965 }
1966 response.send(proto::Ack {})?;
1967
1968 Ok(())
1969}
1970
1971/// Updates other participants with changes to the worktree
1972async fn update_worktree(
1973 request: proto::UpdateWorktree,
1974 response: Response<proto::UpdateWorktree>,
1975 session: Session,
1976) -> Result<()> {
1977 let guest_connection_ids = session
1978 .db()
1979 .await
1980 .update_worktree(&request, session.connection_id)
1981 .await?;
1982
1983 broadcast(
1984 Some(session.connection_id),
1985 guest_connection_ids.iter().copied(),
1986 |connection_id| {
1987 session
1988 .peer
1989 .forward_send(session.connection_id, connection_id, request.clone())
1990 },
1991 );
1992 response.send(proto::Ack {})?;
1993 Ok(())
1994}
1995
1996/// Updates other participants with changes to the diagnostics
1997async fn update_diagnostic_summary(
1998 message: proto::UpdateDiagnosticSummary,
1999 session: Session,
2000) -> Result<()> {
2001 let guest_connection_ids = session
2002 .db()
2003 .await
2004 .update_diagnostic_summary(&message, session.connection_id)
2005 .await?;
2006
2007 broadcast(
2008 Some(session.connection_id),
2009 guest_connection_ids.iter().copied(),
2010 |connection_id| {
2011 session
2012 .peer
2013 .forward_send(session.connection_id, connection_id, message.clone())
2014 },
2015 );
2016
2017 Ok(())
2018}
2019
2020/// Updates other participants with changes to the worktree settings
2021async fn update_worktree_settings(
2022 message: proto::UpdateWorktreeSettings,
2023 session: Session,
2024) -> Result<()> {
2025 let guest_connection_ids = session
2026 .db()
2027 .await
2028 .update_worktree_settings(&message, session.connection_id)
2029 .await?;
2030
2031 broadcast(
2032 Some(session.connection_id),
2033 guest_connection_ids.iter().copied(),
2034 |connection_id| {
2035 session
2036 .peer
2037 .forward_send(session.connection_id, connection_id, message.clone())
2038 },
2039 );
2040
2041 Ok(())
2042}
2043
2044/// Notify other participants that a language server has started.
2045async fn start_language_server(
2046 request: proto::StartLanguageServer,
2047 session: Session,
2048) -> Result<()> {
2049 let guest_connection_ids = session
2050 .db()
2051 .await
2052 .start_language_server(&request, session.connection_id)
2053 .await?;
2054
2055 broadcast(
2056 Some(session.connection_id),
2057 guest_connection_ids.iter().copied(),
2058 |connection_id| {
2059 session
2060 .peer
2061 .forward_send(session.connection_id, connection_id, request.clone())
2062 },
2063 );
2064 Ok(())
2065}
2066
2067/// Notify other participants that a language server has changed.
2068async fn update_language_server(
2069 request: proto::UpdateLanguageServer,
2070 session: Session,
2071) -> Result<()> {
2072 let project_id = ProjectId::from_proto(request.project_id);
2073 let project_connection_ids = session
2074 .db()
2075 .await
2076 .project_connection_ids(project_id, session.connection_id, true)
2077 .await?;
2078 broadcast(
2079 Some(session.connection_id),
2080 project_connection_ids.iter().copied(),
2081 |connection_id| {
2082 session
2083 .peer
2084 .forward_send(session.connection_id, connection_id, request.clone())
2085 },
2086 );
2087 Ok(())
2088}
2089
2090/// forward a project request to the host. These requests should be read only
2091/// as guests are allowed to send them.
2092async fn forward_read_only_project_request<T>(
2093 request: T,
2094 response: Response<T>,
2095 session: Session,
2096) -> Result<()>
2097where
2098 T: EntityMessage + RequestMessage,
2099{
2100 let project_id = ProjectId::from_proto(request.remote_entity_id());
2101 let host_connection_id = session
2102 .db()
2103 .await
2104 .host_for_read_only_project_request(project_id, session.connection_id)
2105 .await?;
2106 let payload = session
2107 .peer
2108 .forward_request(session.connection_id, host_connection_id, request)
2109 .await?;
2110 response.send(payload)?;
2111 Ok(())
2112}
2113
2114async fn forward_find_search_candidates_request(
2115 request: proto::FindSearchCandidates,
2116 response: Response<proto::FindSearchCandidates>,
2117 session: Session,
2118) -> Result<()> {
2119 let project_id = ProjectId::from_proto(request.remote_entity_id());
2120 let host_connection_id = session
2121 .db()
2122 .await
2123 .host_for_read_only_project_request(project_id, session.connection_id)
2124 .await?;
2125 let payload = session
2126 .peer
2127 .forward_request(session.connection_id, host_connection_id, request)
2128 .await?;
2129 response.send(payload)?;
2130 Ok(())
2131}
2132
2133/// forward a project request to the host. These requests are disallowed
2134/// for guests.
2135async fn forward_mutating_project_request<T>(
2136 request: T,
2137 response: Response<T>,
2138 session: Session,
2139) -> Result<()>
2140where
2141 T: EntityMessage + RequestMessage,
2142{
2143 let project_id = ProjectId::from_proto(request.remote_entity_id());
2144
2145 let host_connection_id = session
2146 .db()
2147 .await
2148 .host_for_mutating_project_request(project_id, session.connection_id)
2149 .await?;
2150 let payload = session
2151 .peer
2152 .forward_request(session.connection_id, host_connection_id, request)
2153 .await?;
2154 response.send(payload)?;
2155 Ok(())
2156}
2157
2158/// Notify other participants that a new buffer has been created
2159async fn create_buffer_for_peer(
2160 request: proto::CreateBufferForPeer,
2161 session: Session,
2162) -> Result<()> {
2163 session
2164 .db()
2165 .await
2166 .check_user_is_project_host(
2167 ProjectId::from_proto(request.project_id),
2168 session.connection_id,
2169 )
2170 .await?;
2171 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2172 session
2173 .peer
2174 .forward_send(session.connection_id, peer_id.into(), request)?;
2175 Ok(())
2176}
2177
2178/// Notify other participants that a buffer has been updated. This is
2179/// allowed for guests as long as the update is limited to selections.
2180async fn update_buffer(
2181 request: proto::UpdateBuffer,
2182 response: Response<proto::UpdateBuffer>,
2183 session: Session,
2184) -> Result<()> {
2185 let project_id = ProjectId::from_proto(request.project_id);
2186 let mut capability = Capability::ReadOnly;
2187
2188 for op in request.operations.iter() {
2189 match op.variant {
2190 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2191 Some(_) => capability = Capability::ReadWrite,
2192 }
2193 }
2194
2195 let host = {
2196 let guard = session
2197 .db()
2198 .await
2199 .connections_for_buffer_update(project_id, session.connection_id, capability)
2200 .await?;
2201
2202 let (host, guests) = &*guard;
2203
2204 broadcast(
2205 Some(session.connection_id),
2206 guests.clone(),
2207 |connection_id| {
2208 session
2209 .peer
2210 .forward_send(session.connection_id, connection_id, request.clone())
2211 },
2212 );
2213
2214 *host
2215 };
2216
2217 if host != session.connection_id {
2218 session
2219 .peer
2220 .forward_request(session.connection_id, host, request.clone())
2221 .await?;
2222 }
2223
2224 response.send(proto::Ack {})?;
2225 Ok(())
2226}
2227
2228async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2229 let project_id = ProjectId::from_proto(message.project_id);
2230
2231 let operation = message.operation.as_ref().context("invalid operation")?;
2232 let capability = match operation.variant.as_ref() {
2233 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2234 if let Some(buffer_op) = buffer_op.operation.as_ref() {
2235 match buffer_op.variant {
2236 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2237 Capability::ReadOnly
2238 }
2239 _ => Capability::ReadWrite,
2240 }
2241 } else {
2242 Capability::ReadWrite
2243 }
2244 }
2245 Some(_) => Capability::ReadWrite,
2246 None => Capability::ReadOnly,
2247 };
2248
2249 let guard = session
2250 .db()
2251 .await
2252 .connections_for_buffer_update(project_id, session.connection_id, capability)
2253 .await?;
2254
2255 let (host, guests) = &*guard;
2256
2257 broadcast(
2258 Some(session.connection_id),
2259 guests.iter().chain([host]).copied(),
2260 |connection_id| {
2261 session
2262 .peer
2263 .forward_send(session.connection_id, connection_id, message.clone())
2264 },
2265 );
2266
2267 Ok(())
2268}
2269
2270/// Notify other participants that a project has been updated.
2271async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2272 request: T,
2273 session: Session,
2274) -> Result<()> {
2275 let project_id = ProjectId::from_proto(request.remote_entity_id());
2276 let project_connection_ids = session
2277 .db()
2278 .await
2279 .project_connection_ids(project_id, session.connection_id, false)
2280 .await?;
2281
2282 broadcast(
2283 Some(session.connection_id),
2284 project_connection_ids.iter().copied(),
2285 |connection_id| {
2286 session
2287 .peer
2288 .forward_send(session.connection_id, connection_id, request.clone())
2289 },
2290 );
2291 Ok(())
2292}
2293
2294/// Start following another user in a call.
2295async fn follow(
2296 request: proto::Follow,
2297 response: Response<proto::Follow>,
2298 session: Session,
2299) -> Result<()> {
2300 let room_id = RoomId::from_proto(request.room_id);
2301 let project_id = request.project_id.map(ProjectId::from_proto);
2302 let leader_id = request
2303 .leader_id
2304 .ok_or_else(|| anyhow!("invalid leader id"))?
2305 .into();
2306 let follower_id = session.connection_id;
2307
2308 session
2309 .db()
2310 .await
2311 .check_room_participants(room_id, leader_id, session.connection_id)
2312 .await?;
2313
2314 let response_payload = session
2315 .peer
2316 .forward_request(session.connection_id, leader_id, request)
2317 .await?;
2318 response.send(response_payload)?;
2319
2320 if let Some(project_id) = project_id {
2321 let room = session
2322 .db()
2323 .await
2324 .follow(room_id, project_id, leader_id, follower_id)
2325 .await?;
2326 room_updated(&room, &session.peer);
2327 }
2328
2329 Ok(())
2330}
2331
2332/// Stop following another user in a call.
2333async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2334 let room_id = RoomId::from_proto(request.room_id);
2335 let project_id = request.project_id.map(ProjectId::from_proto);
2336 let leader_id = request
2337 .leader_id
2338 .ok_or_else(|| anyhow!("invalid leader id"))?
2339 .into();
2340 let follower_id = session.connection_id;
2341
2342 session
2343 .db()
2344 .await
2345 .check_room_participants(room_id, leader_id, session.connection_id)
2346 .await?;
2347
2348 session
2349 .peer
2350 .forward_send(session.connection_id, leader_id, request)?;
2351
2352 if let Some(project_id) = project_id {
2353 let room = session
2354 .db()
2355 .await
2356 .unfollow(room_id, project_id, leader_id, follower_id)
2357 .await?;
2358 room_updated(&room, &session.peer);
2359 }
2360
2361 Ok(())
2362}
2363
2364/// Notify everyone following you of your current location.
2365async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2366 let room_id = RoomId::from_proto(request.room_id);
2367 let database = session.db.lock().await;
2368
2369 let connection_ids = if let Some(project_id) = request.project_id {
2370 let project_id = ProjectId::from_proto(project_id);
2371 database
2372 .project_connection_ids(project_id, session.connection_id, true)
2373 .await?
2374 } else {
2375 database
2376 .room_connection_ids(room_id, session.connection_id)
2377 .await?
2378 };
2379
2380 // For now, don't send view update messages back to that view's current leader.
2381 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2382 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2383 _ => None,
2384 });
2385
2386 for connection_id in connection_ids.iter().cloned() {
2387 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2388 session
2389 .peer
2390 .forward_send(session.connection_id, connection_id, request.clone())?;
2391 }
2392 }
2393 Ok(())
2394}
2395
2396/// Get public data about users.
2397async fn get_users(
2398 request: proto::GetUsers,
2399 response: Response<proto::GetUsers>,
2400 session: Session,
2401) -> Result<()> {
2402 let user_ids = request
2403 .user_ids
2404 .into_iter()
2405 .map(UserId::from_proto)
2406 .collect();
2407 let users = session
2408 .db()
2409 .await
2410 .get_users_by_ids(user_ids)
2411 .await?
2412 .into_iter()
2413 .map(|user| proto::User {
2414 id: user.id.to_proto(),
2415 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2416 github_login: user.github_login,
2417 })
2418 .collect();
2419 response.send(proto::UsersResponse { users })?;
2420 Ok(())
2421}
2422
2423/// Search for users (to invite) buy Github login
2424async fn fuzzy_search_users(
2425 request: proto::FuzzySearchUsers,
2426 response: Response<proto::FuzzySearchUsers>,
2427 session: Session,
2428) -> Result<()> {
2429 let query = request.query;
2430 let users = match query.len() {
2431 0 => vec![],
2432 1 | 2 => session
2433 .db()
2434 .await
2435 .get_user_by_github_login(&query)
2436 .await?
2437 .into_iter()
2438 .collect(),
2439 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2440 };
2441 let users = users
2442 .into_iter()
2443 .filter(|user| user.id != session.user_id())
2444 .map(|user| proto::User {
2445 id: user.id.to_proto(),
2446 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2447 github_login: user.github_login,
2448 })
2449 .collect();
2450 response.send(proto::UsersResponse { users })?;
2451 Ok(())
2452}
2453
2454/// Send a contact request to another user.
2455async fn request_contact(
2456 request: proto::RequestContact,
2457 response: Response<proto::RequestContact>,
2458 session: Session,
2459) -> Result<()> {
2460 let requester_id = session.user_id();
2461 let responder_id = UserId::from_proto(request.responder_id);
2462 if requester_id == responder_id {
2463 return Err(anyhow!("cannot add yourself as a contact"))?;
2464 }
2465
2466 let notifications = session
2467 .db()
2468 .await
2469 .send_contact_request(requester_id, responder_id)
2470 .await?;
2471
2472 // Update outgoing contact requests of requester
2473 let mut update = proto::UpdateContacts::default();
2474 update.outgoing_requests.push(responder_id.to_proto());
2475 for connection_id in session
2476 .connection_pool()
2477 .await
2478 .user_connection_ids(requester_id)
2479 {
2480 session.peer.send(connection_id, update.clone())?;
2481 }
2482
2483 // Update incoming contact requests of responder
2484 let mut update = proto::UpdateContacts::default();
2485 update
2486 .incoming_requests
2487 .push(proto::IncomingContactRequest {
2488 requester_id: requester_id.to_proto(),
2489 });
2490 let connection_pool = session.connection_pool().await;
2491 for connection_id in connection_pool.user_connection_ids(responder_id) {
2492 session.peer.send(connection_id, update.clone())?;
2493 }
2494
2495 send_notifications(&connection_pool, &session.peer, notifications);
2496
2497 response.send(proto::Ack {})?;
2498 Ok(())
2499}
2500
2501/// Accept or decline a contact request
2502async fn respond_to_contact_request(
2503 request: proto::RespondToContactRequest,
2504 response: Response<proto::RespondToContactRequest>,
2505 session: Session,
2506) -> Result<()> {
2507 let responder_id = session.user_id();
2508 let requester_id = UserId::from_proto(request.requester_id);
2509 let db = session.db().await;
2510 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2511 db.dismiss_contact_notification(responder_id, requester_id)
2512 .await?;
2513 } else {
2514 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2515
2516 let notifications = db
2517 .respond_to_contact_request(responder_id, requester_id, accept)
2518 .await?;
2519 let requester_busy = db.is_user_busy(requester_id).await?;
2520 let responder_busy = db.is_user_busy(responder_id).await?;
2521
2522 let pool = session.connection_pool().await;
2523 // Update responder with new contact
2524 let mut update = proto::UpdateContacts::default();
2525 if accept {
2526 update
2527 .contacts
2528 .push(contact_for_user(requester_id, requester_busy, &pool));
2529 }
2530 update
2531 .remove_incoming_requests
2532 .push(requester_id.to_proto());
2533 for connection_id in pool.user_connection_ids(responder_id) {
2534 session.peer.send(connection_id, update.clone())?;
2535 }
2536
2537 // Update requester with new contact
2538 let mut update = proto::UpdateContacts::default();
2539 if accept {
2540 update
2541 .contacts
2542 .push(contact_for_user(responder_id, responder_busy, &pool));
2543 }
2544 update
2545 .remove_outgoing_requests
2546 .push(responder_id.to_proto());
2547
2548 for connection_id in pool.user_connection_ids(requester_id) {
2549 session.peer.send(connection_id, update.clone())?;
2550 }
2551
2552 send_notifications(&pool, &session.peer, notifications);
2553 }
2554
2555 response.send(proto::Ack {})?;
2556 Ok(())
2557}
2558
2559/// Remove a contact.
2560async fn remove_contact(
2561 request: proto::RemoveContact,
2562 response: Response<proto::RemoveContact>,
2563 session: Session,
2564) -> Result<()> {
2565 let requester_id = session.user_id();
2566 let responder_id = UserId::from_proto(request.user_id);
2567 let db = session.db().await;
2568 let (contact_accepted, deleted_notification_id) =
2569 db.remove_contact(requester_id, responder_id).await?;
2570
2571 let pool = session.connection_pool().await;
2572 // Update outgoing contact requests of requester
2573 let mut update = proto::UpdateContacts::default();
2574 if contact_accepted {
2575 update.remove_contacts.push(responder_id.to_proto());
2576 } else {
2577 update
2578 .remove_outgoing_requests
2579 .push(responder_id.to_proto());
2580 }
2581 for connection_id in pool.user_connection_ids(requester_id) {
2582 session.peer.send(connection_id, update.clone())?;
2583 }
2584
2585 // Update incoming contact requests of responder
2586 let mut update = proto::UpdateContacts::default();
2587 if contact_accepted {
2588 update.remove_contacts.push(requester_id.to_proto());
2589 } else {
2590 update
2591 .remove_incoming_requests
2592 .push(requester_id.to_proto());
2593 }
2594 for connection_id in pool.user_connection_ids(responder_id) {
2595 session.peer.send(connection_id, update.clone())?;
2596 if let Some(notification_id) = deleted_notification_id {
2597 session.peer.send(
2598 connection_id,
2599 proto::DeleteNotification {
2600 notification_id: notification_id.to_proto(),
2601 },
2602 )?;
2603 }
2604 }
2605
2606 response.send(proto::Ack {})?;
2607 Ok(())
2608}
2609
2610fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2611 version.0.minor() < 139
2612}
2613
2614async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
2615 let plan = session.current_plan(&session.db().await).await?;
2616
2617 session
2618 .peer
2619 .send(
2620 session.connection_id,
2621 proto::UpdateUserPlan { plan: plan.into() },
2622 )
2623 .trace_err();
2624
2625 Ok(())
2626}
2627
2628async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2629 subscribe_user_to_channels(session.user_id(), &session).await?;
2630 Ok(())
2631}
2632
2633async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2634 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2635 let mut pool = session.connection_pool().await;
2636 for membership in &channels_for_user.channel_memberships {
2637 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2638 }
2639 session.peer.send(
2640 session.connection_id,
2641 build_update_user_channels(&channels_for_user),
2642 )?;
2643 session.peer.send(
2644 session.connection_id,
2645 build_channels_update(channels_for_user),
2646 )?;
2647 Ok(())
2648}
2649
2650/// Creates a new channel.
2651async fn create_channel(
2652 request: proto::CreateChannel,
2653 response: Response<proto::CreateChannel>,
2654 session: Session,
2655) -> Result<()> {
2656 let db = session.db().await;
2657
2658 let parent_id = request.parent_id.map(ChannelId::from_proto);
2659 let (channel, membership) = db
2660 .create_channel(&request.name, parent_id, session.user_id())
2661 .await?;
2662
2663 let root_id = channel.root_id();
2664 let channel = Channel::from_model(channel);
2665
2666 response.send(proto::CreateChannelResponse {
2667 channel: Some(channel.to_proto()),
2668 parent_id: request.parent_id,
2669 })?;
2670
2671 let mut connection_pool = session.connection_pool().await;
2672 if let Some(membership) = membership {
2673 connection_pool.subscribe_to_channel(
2674 membership.user_id,
2675 membership.channel_id,
2676 membership.role,
2677 );
2678 let update = proto::UpdateUserChannels {
2679 channel_memberships: vec![proto::ChannelMembership {
2680 channel_id: membership.channel_id.to_proto(),
2681 role: membership.role.into(),
2682 }],
2683 ..Default::default()
2684 };
2685 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2686 session.peer.send(connection_id, update.clone())?;
2687 }
2688 }
2689
2690 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2691 if !role.can_see_channel(channel.visibility) {
2692 continue;
2693 }
2694
2695 let update = proto::UpdateChannels {
2696 channels: vec![channel.to_proto()],
2697 ..Default::default()
2698 };
2699 session.peer.send(connection_id, update.clone())?;
2700 }
2701
2702 Ok(())
2703}
2704
2705/// Delete a channel
2706async fn delete_channel(
2707 request: proto::DeleteChannel,
2708 response: Response<proto::DeleteChannel>,
2709 session: Session,
2710) -> Result<()> {
2711 let db = session.db().await;
2712
2713 let channel_id = request.channel_id;
2714 let (root_channel, removed_channels) = db
2715 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2716 .await?;
2717 response.send(proto::Ack {})?;
2718
2719 // Notify members of removed channels
2720 let mut update = proto::UpdateChannels::default();
2721 update
2722 .delete_channels
2723 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2724
2725 let connection_pool = session.connection_pool().await;
2726 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2727 session.peer.send(connection_id, update.clone())?;
2728 }
2729
2730 Ok(())
2731}
2732
2733/// Invite someone to join a channel.
2734async fn invite_channel_member(
2735 request: proto::InviteChannelMember,
2736 response: Response<proto::InviteChannelMember>,
2737 session: Session,
2738) -> Result<()> {
2739 let db = session.db().await;
2740 let channel_id = ChannelId::from_proto(request.channel_id);
2741 let invitee_id = UserId::from_proto(request.user_id);
2742 let InviteMemberResult {
2743 channel,
2744 notifications,
2745 } = db
2746 .invite_channel_member(
2747 channel_id,
2748 invitee_id,
2749 session.user_id(),
2750 request.role().into(),
2751 )
2752 .await?;
2753
2754 let update = proto::UpdateChannels {
2755 channel_invitations: vec![channel.to_proto()],
2756 ..Default::default()
2757 };
2758
2759 let connection_pool = session.connection_pool().await;
2760 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2761 session.peer.send(connection_id, update.clone())?;
2762 }
2763
2764 send_notifications(&connection_pool, &session.peer, notifications);
2765
2766 response.send(proto::Ack {})?;
2767 Ok(())
2768}
2769
2770/// remove someone from a channel
2771async fn remove_channel_member(
2772 request: proto::RemoveChannelMember,
2773 response: Response<proto::RemoveChannelMember>,
2774 session: Session,
2775) -> Result<()> {
2776 let db = session.db().await;
2777 let channel_id = ChannelId::from_proto(request.channel_id);
2778 let member_id = UserId::from_proto(request.user_id);
2779
2780 let RemoveChannelMemberResult {
2781 membership_update,
2782 notification_id,
2783 } = db
2784 .remove_channel_member(channel_id, member_id, session.user_id())
2785 .await?;
2786
2787 let mut connection_pool = session.connection_pool().await;
2788 notify_membership_updated(
2789 &mut connection_pool,
2790 membership_update,
2791 member_id,
2792 &session.peer,
2793 );
2794 for connection_id in connection_pool.user_connection_ids(member_id) {
2795 if let Some(notification_id) = notification_id {
2796 session
2797 .peer
2798 .send(
2799 connection_id,
2800 proto::DeleteNotification {
2801 notification_id: notification_id.to_proto(),
2802 },
2803 )
2804 .trace_err();
2805 }
2806 }
2807
2808 response.send(proto::Ack {})?;
2809 Ok(())
2810}
2811
2812/// Toggle the channel between public and private.
2813/// Care is taken to maintain the invariant that public channels only descend from public channels,
2814/// (though members-only channels can appear at any point in the hierarchy).
2815async fn set_channel_visibility(
2816 request: proto::SetChannelVisibility,
2817 response: Response<proto::SetChannelVisibility>,
2818 session: Session,
2819) -> Result<()> {
2820 let db = session.db().await;
2821 let channel_id = ChannelId::from_proto(request.channel_id);
2822 let visibility = request.visibility().into();
2823
2824 let channel_model = db
2825 .set_channel_visibility(channel_id, visibility, session.user_id())
2826 .await?;
2827 let root_id = channel_model.root_id();
2828 let channel = Channel::from_model(channel_model);
2829
2830 let mut connection_pool = session.connection_pool().await;
2831 for (user_id, role) in connection_pool
2832 .channel_user_ids(root_id)
2833 .collect::<Vec<_>>()
2834 .into_iter()
2835 {
2836 let update = if role.can_see_channel(channel.visibility) {
2837 connection_pool.subscribe_to_channel(user_id, channel_id, role);
2838 proto::UpdateChannels {
2839 channels: vec![channel.to_proto()],
2840 ..Default::default()
2841 }
2842 } else {
2843 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2844 proto::UpdateChannels {
2845 delete_channels: vec![channel.id.to_proto()],
2846 ..Default::default()
2847 }
2848 };
2849
2850 for connection_id in connection_pool.user_connection_ids(user_id) {
2851 session.peer.send(connection_id, update.clone())?;
2852 }
2853 }
2854
2855 response.send(proto::Ack {})?;
2856 Ok(())
2857}
2858
2859/// Alter the role for a user in the channel.
2860async fn set_channel_member_role(
2861 request: proto::SetChannelMemberRole,
2862 response: Response<proto::SetChannelMemberRole>,
2863 session: Session,
2864) -> Result<()> {
2865 let db = session.db().await;
2866 let channel_id = ChannelId::from_proto(request.channel_id);
2867 let member_id = UserId::from_proto(request.user_id);
2868 let result = db
2869 .set_channel_member_role(
2870 channel_id,
2871 session.user_id(),
2872 member_id,
2873 request.role().into(),
2874 )
2875 .await?;
2876
2877 match result {
2878 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2879 let mut connection_pool = session.connection_pool().await;
2880 notify_membership_updated(
2881 &mut connection_pool,
2882 membership_update,
2883 member_id,
2884 &session.peer,
2885 )
2886 }
2887 db::SetMemberRoleResult::InviteUpdated(channel) => {
2888 let update = proto::UpdateChannels {
2889 channel_invitations: vec![channel.to_proto()],
2890 ..Default::default()
2891 };
2892
2893 for connection_id in session
2894 .connection_pool()
2895 .await
2896 .user_connection_ids(member_id)
2897 {
2898 session.peer.send(connection_id, update.clone())?;
2899 }
2900 }
2901 }
2902
2903 response.send(proto::Ack {})?;
2904 Ok(())
2905}
2906
2907/// Change the name of a channel
2908async fn rename_channel(
2909 request: proto::RenameChannel,
2910 response: Response<proto::RenameChannel>,
2911 session: Session,
2912) -> Result<()> {
2913 let db = session.db().await;
2914 let channel_id = ChannelId::from_proto(request.channel_id);
2915 let channel_model = db
2916 .rename_channel(channel_id, session.user_id(), &request.name)
2917 .await?;
2918 let root_id = channel_model.root_id();
2919 let channel = Channel::from_model(channel_model);
2920
2921 response.send(proto::RenameChannelResponse {
2922 channel: Some(channel.to_proto()),
2923 })?;
2924
2925 let connection_pool = session.connection_pool().await;
2926 let update = proto::UpdateChannels {
2927 channels: vec![channel.to_proto()],
2928 ..Default::default()
2929 };
2930 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2931 if role.can_see_channel(channel.visibility) {
2932 session.peer.send(connection_id, update.clone())?;
2933 }
2934 }
2935
2936 Ok(())
2937}
2938
2939/// Move a channel to a new parent.
2940async fn move_channel(
2941 request: proto::MoveChannel,
2942 response: Response<proto::MoveChannel>,
2943 session: Session,
2944) -> Result<()> {
2945 let channel_id = ChannelId::from_proto(request.channel_id);
2946 let to = ChannelId::from_proto(request.to);
2947
2948 let (root_id, channels) = session
2949 .db()
2950 .await
2951 .move_channel(channel_id, to, session.user_id())
2952 .await?;
2953
2954 let connection_pool = session.connection_pool().await;
2955 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2956 let channels = channels
2957 .iter()
2958 .filter_map(|channel| {
2959 if role.can_see_channel(channel.visibility) {
2960 Some(channel.to_proto())
2961 } else {
2962 None
2963 }
2964 })
2965 .collect::<Vec<_>>();
2966 if channels.is_empty() {
2967 continue;
2968 }
2969
2970 let update = proto::UpdateChannels {
2971 channels,
2972 ..Default::default()
2973 };
2974
2975 session.peer.send(connection_id, update.clone())?;
2976 }
2977
2978 response.send(Ack {})?;
2979 Ok(())
2980}
2981
2982/// Get the list of channel members
2983async fn get_channel_members(
2984 request: proto::GetChannelMembers,
2985 response: Response<proto::GetChannelMembers>,
2986 session: Session,
2987) -> Result<()> {
2988 let db = session.db().await;
2989 let channel_id = ChannelId::from_proto(request.channel_id);
2990 let limit = if request.limit == 0 {
2991 u16::MAX as u64
2992 } else {
2993 request.limit
2994 };
2995 let (members, users) = db
2996 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
2997 .await?;
2998 response.send(proto::GetChannelMembersResponse { members, users })?;
2999 Ok(())
3000}
3001
3002/// Accept or decline a channel invitation.
3003async fn respond_to_channel_invite(
3004 request: proto::RespondToChannelInvite,
3005 response: Response<proto::RespondToChannelInvite>,
3006 session: Session,
3007) -> Result<()> {
3008 let db = session.db().await;
3009 let channel_id = ChannelId::from_proto(request.channel_id);
3010 let RespondToChannelInvite {
3011 membership_update,
3012 notifications,
3013 } = db
3014 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3015 .await?;
3016
3017 let mut connection_pool = session.connection_pool().await;
3018 if let Some(membership_update) = membership_update {
3019 notify_membership_updated(
3020 &mut connection_pool,
3021 membership_update,
3022 session.user_id(),
3023 &session.peer,
3024 );
3025 } else {
3026 let update = proto::UpdateChannels {
3027 remove_channel_invitations: vec![channel_id.to_proto()],
3028 ..Default::default()
3029 };
3030
3031 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3032 session.peer.send(connection_id, update.clone())?;
3033 }
3034 };
3035
3036 send_notifications(&connection_pool, &session.peer, notifications);
3037
3038 response.send(proto::Ack {})?;
3039
3040 Ok(())
3041}
3042
3043/// Join the channels' room
3044async fn join_channel(
3045 request: proto::JoinChannel,
3046 response: Response<proto::JoinChannel>,
3047 session: Session,
3048) -> Result<()> {
3049 let channel_id = ChannelId::from_proto(request.channel_id);
3050 join_channel_internal(channel_id, Box::new(response), session).await
3051}
3052
/// Abstraction over the response handles that `join_channel_internal` can
/// answer with a `proto::JoinRoomResponse`.
///
/// It is implemented below for the `JoinChannel` and `JoinRoom` response
/// handles, so the same join logic can serve both request types.
trait JoinChannelInternalResponse {
    /// Consumes the handle and sends `result` as the reply.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Fully qualified so the call targets `Response`'s own `send` and not
        // this trait method of the same name.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Fully qualified so the call targets `Response`'s own `send` and not
        // this trait method of the same name.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3066
/// Joins the session's user to the room of `channel_id` and answers through
/// `response` with the joined room, its channel id, and (when a LiveKit client
/// is configured and a token could be created) LiveKit connection info.
///
/// After responding, it announces any membership change, the updated room,
/// the updated channel, and finally refreshes the user's contacts.
///
/// Fails with "channel not returned" if the joined room has no channel; note
/// that the response has already been sent by that point.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // The database handle is released while leaving the stale room and
            // re-acquired afterwards.
            // NOTE(review): presumably `leave_room_for_session` acquires the
            // database handle itself — confirm.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured, or when the token call's
        // result is turned into `None` by `trace_err` (the `?` inside the
        // closure then short-circuits it).
        let live_kit_connection_info =
            session
                .app_state
                .live_kit_client
                .as_ref()
                .and_then(|live_kit| {
                    // Guests get a guest token and may not publish; every other
                    // role gets a room token and may publish.
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.live_kit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.live_kit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        // Answer the requester before notifying anyone else.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        // `db` and `connection_pool` go out of scope at the end of this block.
        joined_room
    };

    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3162
3163/// Start editing the channel notes
3164async fn join_channel_buffer(
3165 request: proto::JoinChannelBuffer,
3166 response: Response<proto::JoinChannelBuffer>,
3167 session: Session,
3168) -> Result<()> {
3169 let db = session.db().await;
3170 let channel_id = ChannelId::from_proto(request.channel_id);
3171
3172 let open_response = db
3173 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3174 .await?;
3175
3176 let collaborators = open_response.collaborators.clone();
3177 response.send(open_response)?;
3178
3179 let update = UpdateChannelBufferCollaborators {
3180 channel_id: channel_id.to_proto(),
3181 collaborators: collaborators.clone(),
3182 };
3183 channel_buffer_updated(
3184 session.connection_id,
3185 collaborators
3186 .iter()
3187 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3188 &update,
3189 &session.peer,
3190 );
3191
3192 Ok(())
3193}
3194
3195/// Edit the channel notes
3196async fn update_channel_buffer(
3197 request: proto::UpdateChannelBuffer,
3198 session: Session,
3199) -> Result<()> {
3200 let db = session.db().await;
3201 let channel_id = ChannelId::from_proto(request.channel_id);
3202
3203 let (collaborators, epoch, version) = db
3204 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3205 .await?;
3206
3207 channel_buffer_updated(
3208 session.connection_id,
3209 collaborators.clone(),
3210 &proto::UpdateChannelBuffer {
3211 channel_id: channel_id.to_proto(),
3212 operations: request.operations,
3213 },
3214 &session.peer,
3215 );
3216
3217 let pool = &*session.connection_pool().await;
3218
3219 let non_collaborators =
3220 pool.channel_connection_ids(channel_id)
3221 .filter_map(|(connection_id, _)| {
3222 if collaborators.contains(&connection_id) {
3223 None
3224 } else {
3225 Some(connection_id)
3226 }
3227 });
3228
3229 broadcast(None, non_collaborators, |peer_id| {
3230 session.peer.send(
3231 peer_id,
3232 proto::UpdateChannels {
3233 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3234 channel_id: channel_id.to_proto(),
3235 epoch: epoch as u64,
3236 version: version.clone(),
3237 }],
3238 ..Default::default()
3239 },
3240 )
3241 });
3242
3243 Ok(())
3244}
3245
3246/// Rejoin the channel notes after a connection blip
3247async fn rejoin_channel_buffers(
3248 request: proto::RejoinChannelBuffers,
3249 response: Response<proto::RejoinChannelBuffers>,
3250 session: Session,
3251) -> Result<()> {
3252 let db = session.db().await;
3253 let buffers = db
3254 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3255 .await?;
3256
3257 for rejoined_buffer in &buffers {
3258 let collaborators_to_notify = rejoined_buffer
3259 .buffer
3260 .collaborators
3261 .iter()
3262 .filter_map(|c| Some(c.peer_id?.into()));
3263 channel_buffer_updated(
3264 session.connection_id,
3265 collaborators_to_notify,
3266 &proto::UpdateChannelBufferCollaborators {
3267 channel_id: rejoined_buffer.buffer.channel_id,
3268 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3269 },
3270 &session.peer,
3271 );
3272 }
3273
3274 response.send(proto::RejoinChannelBuffersResponse {
3275 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3276 })?;
3277
3278 Ok(())
3279}
3280
3281/// Stop editing the channel notes
3282async fn leave_channel_buffer(
3283 request: proto::LeaveChannelBuffer,
3284 response: Response<proto::LeaveChannelBuffer>,
3285 session: Session,
3286) -> Result<()> {
3287 let db = session.db().await;
3288 let channel_id = ChannelId::from_proto(request.channel_id);
3289
3290 let left_buffer = db
3291 .leave_channel_buffer(channel_id, session.connection_id)
3292 .await?;
3293
3294 response.send(Ack {})?;
3295
3296 channel_buffer_updated(
3297 session.connection_id,
3298 left_buffer.connections,
3299 &proto::UpdateChannelBufferCollaborators {
3300 channel_id: channel_id.to_proto(),
3301 collaborators: left_buffer.collaborators,
3302 },
3303 &session.peer,
3304 );
3305
3306 Ok(())
3307}
3308
3309fn channel_buffer_updated<T: EnvelopedMessage>(
3310 sender_id: ConnectionId,
3311 collaborators: impl IntoIterator<Item = ConnectionId>,
3312 message: &T,
3313 peer: &Peer,
3314) {
3315 broadcast(Some(sender_id), collaborators, |peer_id| {
3316 peer.send(peer_id, message.clone())
3317 });
3318}
3319
3320fn send_notifications(
3321 connection_pool: &ConnectionPool,
3322 peer: &Peer,
3323 notifications: db::NotificationBatch,
3324) {
3325 for (user_id, notification) in notifications {
3326 for connection_id in connection_pool.user_connection_ids(user_id) {
3327 if let Err(error) = peer.send(
3328 connection_id,
3329 proto::AddNotification {
3330 notification: Some(notification.clone()),
3331 },
3332 ) {
3333 tracing::error!(
3334 "failed to send notification to {:?} {}",
3335 connection_id,
3336 error
3337 );
3338 }
3339 }
3340 }
3341}
3342
3343/// Send a message to the channel
3344async fn send_channel_message(
3345 request: proto::SendChannelMessage,
3346 response: Response<proto::SendChannelMessage>,
3347 session: Session,
3348) -> Result<()> {
3349 // Validate the message body.
3350 let body = request.body.trim().to_string();
3351 if body.len() > MAX_MESSAGE_LEN {
3352 return Err(anyhow!("message is too long"))?;
3353 }
3354 if body.is_empty() {
3355 return Err(anyhow!("message can't be blank"))?;
3356 }
3357
3358 // TODO: adjust mentions if body is trimmed
3359
3360 let timestamp = OffsetDateTime::now_utc();
3361 let nonce = request
3362 .nonce
3363 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3364
3365 let channel_id = ChannelId::from_proto(request.channel_id);
3366 let CreatedChannelMessage {
3367 message_id,
3368 participant_connection_ids,
3369 notifications,
3370 } = session
3371 .db()
3372 .await
3373 .create_channel_message(
3374 channel_id,
3375 session.user_id(),
3376 &body,
3377 &request.mentions,
3378 timestamp,
3379 nonce.clone().into(),
3380 request.reply_to_message_id.map(MessageId::from_proto),
3381 )
3382 .await?;
3383
3384 let message = proto::ChannelMessage {
3385 sender_id: session.user_id().to_proto(),
3386 id: message_id.to_proto(),
3387 body,
3388 mentions: request.mentions,
3389 timestamp: timestamp.unix_timestamp() as u64,
3390 nonce: Some(nonce),
3391 reply_to_message_id: request.reply_to_message_id,
3392 edited_at: None,
3393 };
3394 broadcast(
3395 Some(session.connection_id),
3396 participant_connection_ids.clone(),
3397 |connection| {
3398 session.peer.send(
3399 connection,
3400 proto::ChannelMessageSent {
3401 channel_id: channel_id.to_proto(),
3402 message: Some(message.clone()),
3403 },
3404 )
3405 },
3406 );
3407 response.send(proto::SendChannelMessageResponse {
3408 message: Some(message),
3409 })?;
3410
3411 let pool = &*session.connection_pool().await;
3412 let non_participants =
3413 pool.channel_connection_ids(channel_id)
3414 .filter_map(|(connection_id, _)| {
3415 if participant_connection_ids.contains(&connection_id) {
3416 None
3417 } else {
3418 Some(connection_id)
3419 }
3420 });
3421 broadcast(None, non_participants, |peer_id| {
3422 session.peer.send(
3423 peer_id,
3424 proto::UpdateChannels {
3425 latest_channel_message_ids: vec![proto::ChannelMessageId {
3426 channel_id: channel_id.to_proto(),
3427 message_id: message_id.to_proto(),
3428 }],
3429 ..Default::default()
3430 },
3431 )
3432 });
3433 send_notifications(pool, &session.peer, notifications);
3434
3435 Ok(())
3436}
3437
3438/// Delete a channel message
3439async fn remove_channel_message(
3440 request: proto::RemoveChannelMessage,
3441 response: Response<proto::RemoveChannelMessage>,
3442 session: Session,
3443) -> Result<()> {
3444 let channel_id = ChannelId::from_proto(request.channel_id);
3445 let message_id = MessageId::from_proto(request.message_id);
3446 let (connection_ids, existing_notification_ids) = session
3447 .db()
3448 .await
3449 .remove_channel_message(channel_id, message_id, session.user_id())
3450 .await?;
3451
3452 broadcast(
3453 Some(session.connection_id),
3454 connection_ids,
3455 move |connection| {
3456 session.peer.send(connection, request.clone())?;
3457
3458 for notification_id in &existing_notification_ids {
3459 session.peer.send(
3460 connection,
3461 proto::DeleteNotification {
3462 notification_id: (*notification_id).to_proto(),
3463 },
3464 )?;
3465 }
3466
3467 Ok(())
3468 },
3469 );
3470 response.send(proto::Ack {})?;
3471 Ok(())
3472}
3473
3474async fn update_channel_message(
3475 request: proto::UpdateChannelMessage,
3476 response: Response<proto::UpdateChannelMessage>,
3477 session: Session,
3478) -> Result<()> {
3479 let channel_id = ChannelId::from_proto(request.channel_id);
3480 let message_id = MessageId::from_proto(request.message_id);
3481 let updated_at = OffsetDateTime::now_utc();
3482 let UpdatedChannelMessage {
3483 message_id,
3484 participant_connection_ids,
3485 notifications,
3486 reply_to_message_id,
3487 timestamp,
3488 deleted_mention_notification_ids,
3489 updated_mention_notifications,
3490 } = session
3491 .db()
3492 .await
3493 .update_channel_message(
3494 channel_id,
3495 message_id,
3496 session.user_id(),
3497 request.body.as_str(),
3498 &request.mentions,
3499 updated_at,
3500 )
3501 .await?;
3502
3503 let nonce = request
3504 .nonce
3505 .clone()
3506 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3507
3508 let message = proto::ChannelMessage {
3509 sender_id: session.user_id().to_proto(),
3510 id: message_id.to_proto(),
3511 body: request.body.clone(),
3512 mentions: request.mentions.clone(),
3513 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3514 nonce: Some(nonce),
3515 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3516 edited_at: Some(updated_at.unix_timestamp() as u64),
3517 };
3518
3519 response.send(proto::Ack {})?;
3520
3521 let pool = &*session.connection_pool().await;
3522 broadcast(
3523 Some(session.connection_id),
3524 participant_connection_ids,
3525 |connection| {
3526 session.peer.send(
3527 connection,
3528 proto::ChannelMessageUpdate {
3529 channel_id: channel_id.to_proto(),
3530 message: Some(message.clone()),
3531 },
3532 )?;
3533
3534 for notification_id in &deleted_mention_notification_ids {
3535 session.peer.send(
3536 connection,
3537 proto::DeleteNotification {
3538 notification_id: (*notification_id).to_proto(),
3539 },
3540 )?;
3541 }
3542
3543 for notification in &updated_mention_notifications {
3544 session.peer.send(
3545 connection,
3546 proto::UpdateNotification {
3547 notification: Some(notification.clone()),
3548 },
3549 )?;
3550 }
3551
3552 Ok(())
3553 },
3554 );
3555
3556 send_notifications(pool, &session.peer, notifications);
3557
3558 Ok(())
3559}
3560
3561/// Mark a channel message as read
3562async fn acknowledge_channel_message(
3563 request: proto::AckChannelMessage,
3564 session: Session,
3565) -> Result<()> {
3566 let channel_id = ChannelId::from_proto(request.channel_id);
3567 let message_id = MessageId::from_proto(request.message_id);
3568 let notifications = session
3569 .db()
3570 .await
3571 .observe_channel_message(channel_id, session.user_id(), message_id)
3572 .await?;
3573 send_notifications(
3574 &*session.connection_pool().await,
3575 &session.peer,
3576 notifications,
3577 );
3578 Ok(())
3579}
3580
3581/// Mark a buffer version as synced
3582async fn acknowledge_buffer_version(
3583 request: proto::AckBufferOperation,
3584 session: Session,
3585) -> Result<()> {
3586 let buffer_id = BufferId::from_proto(request.buffer_id);
3587 session
3588 .db()
3589 .await
3590 .observe_buffer_version(
3591 buffer_id,
3592 session.user_id(),
3593 request.epoch as i32,
3594 &request.version,
3595 )
3596 .await?;
3597 Ok(())
3598}
3599
3600async fn count_language_model_tokens(
3601 request: proto::CountLanguageModelTokens,
3602 response: Response<proto::CountLanguageModelTokens>,
3603 session: Session,
3604 config: &Config,
3605) -> Result<()> {
3606 authorize_access_to_legacy_llm_endpoints(&session).await?;
3607
3608 let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3609 proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
3610 proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
3611 };
3612
3613 session
3614 .app_state
3615 .rate_limiter
3616 .check(&*rate_limit, session.user_id())
3617 .await?;
3618
3619 let result = match proto::LanguageModelProvider::from_i32(request.provider) {
3620 Some(proto::LanguageModelProvider::Google) => {
3621 let api_key = config
3622 .google_ai_api_key
3623 .as_ref()
3624 .context("no Google AI API key configured on the server")?;
3625 google_ai::count_tokens(
3626 session.http_client.as_ref(),
3627 google_ai::API_URL,
3628 api_key,
3629 serde_json::from_str(&request.request)?,
3630 )
3631 .await?
3632 }
3633 _ => return Err(anyhow!("unsupported provider"))?,
3634 };
3635
3636 response.send(proto::CountLanguageModelTokensResponse {
3637 token_count: result.total_tokens as u32,
3638 })?;
3639
3640 Ok(())
3641}
3642
3643struct ZedProCountLanguageModelTokensRateLimit;
3644
3645impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
3646 fn capacity(&self) -> usize {
3647 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
3648 .ok()
3649 .and_then(|v| v.parse().ok())
3650 .unwrap_or(600) // Picked arbitrarily
3651 }
3652
3653 fn refill_duration(&self) -> chrono::Duration {
3654 chrono::Duration::hours(1)
3655 }
3656
3657 fn db_name(&self) -> &'static str {
3658 "zed-pro:count-language-model-tokens"
3659 }
3660}
3661
3662struct FreeCountLanguageModelTokensRateLimit;
3663
3664impl RateLimit for FreeCountLanguageModelTokensRateLimit {
3665 fn capacity(&self) -> usize {
3666 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
3667 .ok()
3668 .and_then(|v| v.parse().ok())
3669 .unwrap_or(600 / 10) // Picked arbitrarily
3670 }
3671
3672 fn refill_duration(&self) -> chrono::Duration {
3673 chrono::Duration::hours(1)
3674 }
3675
3676 fn db_name(&self) -> &'static str {
3677 "free:count-language-model-tokens"
3678 }
3679}
3680
3681struct ZedProComputeEmbeddingsRateLimit;
3682
3683impl RateLimit for ZedProComputeEmbeddingsRateLimit {
3684 fn capacity(&self) -> usize {
3685 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
3686 .ok()
3687 .and_then(|v| v.parse().ok())
3688 .unwrap_or(5000) // Picked arbitrarily
3689 }
3690
3691 fn refill_duration(&self) -> chrono::Duration {
3692 chrono::Duration::hours(1)
3693 }
3694
3695 fn db_name(&self) -> &'static str {
3696 "zed-pro:compute-embeddings"
3697 }
3698}
3699
3700struct FreeComputeEmbeddingsRateLimit;
3701
3702impl RateLimit for FreeComputeEmbeddingsRateLimit {
3703 fn capacity(&self) -> usize {
3704 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
3705 .ok()
3706 .and_then(|v| v.parse().ok())
3707 .unwrap_or(5000 / 10) // Picked arbitrarily
3708 }
3709
3710 fn refill_duration(&self) -> chrono::Duration {
3711 chrono::Duration::hours(1)
3712 }
3713
3714 fn db_name(&self) -> &'static str {
3715 "free:compute-embeddings"
3716 }
3717}
3718
3719async fn compute_embeddings(
3720 request: proto::ComputeEmbeddings,
3721 response: Response<proto::ComputeEmbeddings>,
3722 session: Session,
3723 api_key: Option<Arc<str>>,
3724) -> Result<()> {
3725 let api_key = api_key.context("no OpenAI API key configured on the server")?;
3726 authorize_access_to_legacy_llm_endpoints(&session).await?;
3727
3728 let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3729 proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
3730 proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
3731 };
3732
3733 session
3734 .app_state
3735 .rate_limiter
3736 .check(&*rate_limit, session.user_id())
3737 .await?;
3738
3739 let embeddings = match request.model.as_str() {
3740 "openai/text-embedding-3-small" => {
3741 open_ai::embed(
3742 session.http_client.as_ref(),
3743 OPEN_AI_API_URL,
3744 &api_key,
3745 OpenAiEmbeddingModel::TextEmbedding3Small,
3746 request.texts.iter().map(|text| text.as_str()),
3747 )
3748 .await?
3749 }
3750 provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
3751 };
3752
3753 let embeddings = request
3754 .texts
3755 .iter()
3756 .map(|text| {
3757 let mut hasher = sha2::Sha256::new();
3758 hasher.update(text.as_bytes());
3759 let result = hasher.finalize();
3760 result.to_vec()
3761 })
3762 .zip(
3763 embeddings
3764 .data
3765 .into_iter()
3766 .map(|embedding| embedding.embedding),
3767 )
3768 .collect::<HashMap<_, _>>();
3769
3770 let db = session.db().await;
3771 db.save_embeddings(&request.model, &embeddings)
3772 .await
3773 .context("failed to save embeddings")
3774 .trace_err();
3775
3776 response.send(proto::ComputeEmbeddingsResponse {
3777 embeddings: embeddings
3778 .into_iter()
3779 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3780 .collect(),
3781 })?;
3782 Ok(())
3783}
3784
3785async fn get_cached_embeddings(
3786 request: proto::GetCachedEmbeddings,
3787 response: Response<proto::GetCachedEmbeddings>,
3788 session: Session,
3789) -> Result<()> {
3790 authorize_access_to_legacy_llm_endpoints(&session).await?;
3791
3792 let db = session.db().await;
3793 let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
3794
3795 response.send(proto::GetCachedEmbeddingsResponse {
3796 embeddings: embeddings
3797 .into_iter()
3798 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3799 .collect(),
3800 })?;
3801 Ok(())
3802}
3803
3804/// This is leftover from before the LLM service.
3805///
3806/// The endpoints protected by this check will be moved there eventually.
3807async fn authorize_access_to_legacy_llm_endpoints(session: &Session) -> Result<(), Error> {
3808 if session.is_staff() {
3809 Ok(())
3810 } else {
3811 Err(anyhow!("permission denied"))?
3812 }
3813}
3814
3815/// Get a Supermaven API key for the user
3816async fn get_supermaven_api_key(
3817 _request: proto::GetSupermavenApiKey,
3818 response: Response<proto::GetSupermavenApiKey>,
3819 session: Session,
3820) -> Result<()> {
3821 let user_id: String = session.user_id().to_string();
3822 if !session.is_staff() {
3823 return Err(anyhow!("supermaven not enabled for this account"))?;
3824 }
3825
3826 let email = session
3827 .email()
3828 .ok_or_else(|| anyhow!("user must have an email"))?;
3829
3830 let supermaven_admin_api = session
3831 .supermaven_client
3832 .as_ref()
3833 .ok_or_else(|| anyhow!("supermaven not configured"))?;
3834
3835 let result = supermaven_admin_api
3836 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3837 .await?;
3838
3839 response.send(proto::GetSupermavenApiKeyResponse {
3840 api_key: result.api_key,
3841 })?;
3842
3843 Ok(())
3844}
3845
3846/// Start receiving chat updates for a channel
3847async fn join_channel_chat(
3848 request: proto::JoinChannelChat,
3849 response: Response<proto::JoinChannelChat>,
3850 session: Session,
3851) -> Result<()> {
3852 let channel_id = ChannelId::from_proto(request.channel_id);
3853
3854 let db = session.db().await;
3855 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3856 .await?;
3857 let messages = db
3858 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3859 .await?;
3860 response.send(proto::JoinChannelChatResponse {
3861 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3862 messages,
3863 })?;
3864 Ok(())
3865}
3866
3867/// Stop receiving chat updates for a channel
3868async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3869 let channel_id = ChannelId::from_proto(request.channel_id);
3870 session
3871 .db()
3872 .await
3873 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3874 .await?;
3875 Ok(())
3876}
3877
3878/// Retrieve the chat history for a channel
3879async fn get_channel_messages(
3880 request: proto::GetChannelMessages,
3881 response: Response<proto::GetChannelMessages>,
3882 session: Session,
3883) -> Result<()> {
3884 let channel_id = ChannelId::from_proto(request.channel_id);
3885 let messages = session
3886 .db()
3887 .await
3888 .get_channel_messages(
3889 channel_id,
3890 session.user_id(),
3891 MESSAGE_COUNT_PER_PAGE,
3892 Some(MessageId::from_proto(request.before_message_id)),
3893 )
3894 .await?;
3895 response.send(proto::GetChannelMessagesResponse {
3896 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3897 messages,
3898 })?;
3899 Ok(())
3900}
3901
3902/// Retrieve specific chat messages
3903async fn get_channel_messages_by_id(
3904 request: proto::GetChannelMessagesById,
3905 response: Response<proto::GetChannelMessagesById>,
3906 session: Session,
3907) -> Result<()> {
3908 let message_ids = request
3909 .message_ids
3910 .iter()
3911 .map(|id| MessageId::from_proto(*id))
3912 .collect::<Vec<_>>();
3913 let messages = session
3914 .db()
3915 .await
3916 .get_channel_messages_by_id(session.user_id(), &message_ids)
3917 .await?;
3918 response.send(proto::GetChannelMessagesResponse {
3919 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3920 messages,
3921 })?;
3922 Ok(())
3923}
3924
3925/// Retrieve the current users notifications
3926async fn get_notifications(
3927 request: proto::GetNotifications,
3928 response: Response<proto::GetNotifications>,
3929 session: Session,
3930) -> Result<()> {
3931 let notifications = session
3932 .db()
3933 .await
3934 .get_notifications(
3935 session.user_id(),
3936 NOTIFICATION_COUNT_PER_PAGE,
3937 request.before_id.map(db::NotificationId::from_proto),
3938 )
3939 .await?;
3940 response.send(proto::GetNotificationsResponse {
3941 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3942 notifications,
3943 })?;
3944 Ok(())
3945}
3946
3947/// Mark notifications as read
3948async fn mark_notification_as_read(
3949 request: proto::MarkNotificationRead,
3950 response: Response<proto::MarkNotificationRead>,
3951 session: Session,
3952) -> Result<()> {
3953 let database = &session.db().await;
3954 let notifications = database
3955 .mark_notification_as_read_by_id(
3956 session.user_id(),
3957 NotificationId::from_proto(request.notification_id),
3958 )
3959 .await?;
3960 send_notifications(
3961 &*session.connection_pool().await,
3962 &session.peer,
3963 notifications,
3964 );
3965 response.send(proto::Ack {})?;
3966 Ok(())
3967}
3968
3969/// Get the current users information
3970async fn get_private_user_info(
3971 _request: proto::GetPrivateUserInfo,
3972 response: Response<proto::GetPrivateUserInfo>,
3973 session: Session,
3974) -> Result<()> {
3975 let db = session.db().await;
3976
3977 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
3978 let user = db
3979 .get_user_by_id(session.user_id())
3980 .await?
3981 .ok_or_else(|| anyhow!("user not found"))?;
3982 let flags = db.get_user_flags(session.user_id()).await?;
3983
3984 response.send(proto::GetPrivateUserInfoResponse {
3985 metrics_id,
3986 staff: user.admin,
3987 flags,
3988 accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
3989 })?;
3990 Ok(())
3991}
3992
3993/// Accept the terms of service (tos) on behalf of the current user
3994async fn accept_terms_of_service(
3995 _request: proto::AcceptTermsOfService,
3996 response: Response<proto::AcceptTermsOfService>,
3997 session: Session,
3998) -> Result<()> {
3999 let db = session.db().await;
4000
4001 let accepted_tos_at = Utc::now();
4002 db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4003 .await?;
4004
4005 response.send(proto::AcceptTermsOfServiceResponse {
4006 accepted_tos_at: accepted_tos_at.timestamp() as u64,
4007 })?;
4008 Ok(())
4009}
4010
4011/// The minimum account age an account must have in order to use the LLM service.
4012const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4013
4014async fn get_llm_api_token(
4015 _request: proto::GetLlmToken,
4016 response: Response<proto::GetLlmToken>,
4017 session: Session,
4018) -> Result<()> {
4019 let db = session.db().await;
4020
4021 let flags = db.get_user_flags(session.user_id()).await?;
4022 let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
4023 let has_llm_closed_beta_feature_flag = flags.iter().any(|flag| flag == "llm-closed-beta");
4024
4025 if !session.is_staff() && !has_language_models_feature_flag {
4026 Err(anyhow!("permission denied"))?
4027 }
4028
4029 let user_id = session.user_id();
4030 let user = db
4031 .get_user_by_id(user_id)
4032 .await?
4033 .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4034
4035 if user.accepted_tos_at.is_none() {
4036 Err(anyhow!("terms of service not accepted"))?
4037 }
4038
4039 let has_llm_subscription = session.has_llm_subscription(&db).await?;
4040
4041 let bypass_account_age_check =
4042 has_llm_subscription || flags.iter().any(|flag| flag == "bypass-account-age-check");
4043 if !bypass_account_age_check {
4044 let mut account_created_at = user.created_at;
4045 if let Some(github_created_at) = user.github_user_created_at {
4046 account_created_at = account_created_at.min(github_created_at);
4047 }
4048 if Utc::now().naive_utc() - account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE {
4049 Err(anyhow!("account too young"))?
4050 }
4051 }
4052
4053 let billing_preferences = db.get_billing_preferences(user.id).await?;
4054
4055 let token = LlmTokenClaims::create(
4056 &user,
4057 session.is_staff(),
4058 billing_preferences,
4059 has_llm_closed_beta_feature_flag,
4060 has_llm_subscription,
4061 session.current_plan(&db).await?,
4062 session.system_id.clone(),
4063 &session.app_state.config,
4064 )?;
4065 response.send(proto::GetLlmTokenResponse { token })?;
4066 Ok(())
4067}
4068
4069fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4070 let message = match message {
4071 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4072 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4073 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4074 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4075 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4076 code: frame.code.into(),
4077 reason: frame.reason,
4078 })),
4079 // We should never receive a frame while reading the message, according
4080 // to the `tungstenite` maintainers:
4081 //
4082 // > It cannot occur when you read messages from the WebSocket, but it
4083 // > can be used when you want to send the raw frames (e.g. you want to
4084 // > send the frames to the WebSocket without composing the full message first).
4085 // >
4086 // > — https://github.com/snapview/tungstenite-rs/issues/268
4087 TungsteniteMessage::Frame(_) => {
4088 bail!("received an unexpected frame while reading the message")
4089 }
4090 };
4091
4092 Ok(message)
4093}
4094
4095fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4096 match message {
4097 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4098 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4099 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4100 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4101 AxumMessage::Close(frame) => {
4102 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4103 code: frame.code.into(),
4104 reason: frame.reason,
4105 }))
4106 }
4107 }
4108}
4109
4110fn notify_membership_updated(
4111 connection_pool: &mut ConnectionPool,
4112 result: MembershipUpdated,
4113 user_id: UserId,
4114 peer: &Peer,
4115) {
4116 for membership in &result.new_channels.channel_memberships {
4117 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4118 }
4119 for channel_id in &result.removed_channels {
4120 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4121 }
4122
4123 let user_channels_update = proto::UpdateUserChannels {
4124 channel_memberships: result
4125 .new_channels
4126 .channel_memberships
4127 .iter()
4128 .map(|cm| proto::ChannelMembership {
4129 channel_id: cm.channel_id.to_proto(),
4130 role: cm.role.into(),
4131 })
4132 .collect(),
4133 ..Default::default()
4134 };
4135
4136 let mut update = build_channels_update(result.new_channels);
4137 update.delete_channels = result
4138 .removed_channels
4139 .into_iter()
4140 .map(|id| id.to_proto())
4141 .collect();
4142 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4143
4144 for connection_id in connection_pool.user_connection_ids(user_id) {
4145 peer.send(connection_id, user_channels_update.clone())
4146 .trace_err();
4147 peer.send(connection_id, update.clone()).trace_err();
4148 }
4149}
4150
4151fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4152 proto::UpdateUserChannels {
4153 channel_memberships: channels
4154 .channel_memberships
4155 .iter()
4156 .map(|m| proto::ChannelMembership {
4157 channel_id: m.channel_id.to_proto(),
4158 role: m.role.into(),
4159 })
4160 .collect(),
4161 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4162 observed_channel_message_id: channels.observed_channel_messages.clone(),
4163 }
4164}
4165
4166fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4167 let mut update = proto::UpdateChannels::default();
4168
4169 for channel in channels.channels {
4170 update.channels.push(channel.to_proto());
4171 }
4172
4173 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4174 update.latest_channel_message_ids = channels.latest_channel_messages;
4175
4176 for (channel_id, participants) in channels.channel_participants {
4177 update
4178 .channel_participants
4179 .push(proto::ChannelParticipants {
4180 channel_id: channel_id.to_proto(),
4181 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4182 });
4183 }
4184
4185 for channel in channels.invited_channels {
4186 update.channel_invitations.push(channel.to_proto());
4187 }
4188
4189 update
4190}
4191
4192fn build_initial_contacts_update(
4193 contacts: Vec<db::Contact>,
4194 pool: &ConnectionPool,
4195) -> proto::UpdateContacts {
4196 let mut update = proto::UpdateContacts::default();
4197
4198 for contact in contacts {
4199 match contact {
4200 db::Contact::Accepted { user_id, busy } => {
4201 update.contacts.push(contact_for_user(user_id, busy, pool));
4202 }
4203 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4204 db::Contact::Incoming { user_id } => {
4205 update
4206 .incoming_requests
4207 .push(proto::IncomingContactRequest {
4208 requester_id: user_id.to_proto(),
4209 })
4210 }
4211 }
4212 }
4213
4214 update
4215}
4216
4217fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4218 proto::Contact {
4219 user_id: user_id.to_proto(),
4220 online: pool.is_user_online(user_id),
4221 busy,
4222 }
4223}
4224
4225fn room_updated(room: &proto::Room, peer: &Peer) {
4226 broadcast(
4227 None,
4228 room.participants
4229 .iter()
4230 .filter_map(|participant| Some(participant.peer_id?.into())),
4231 |peer_id| {
4232 peer.send(
4233 peer_id,
4234 proto::RoomUpdated {
4235 room: Some(room.clone()),
4236 },
4237 )
4238 },
4239 );
4240}
4241
4242fn channel_updated(
4243 channel: &db::channel::Model,
4244 room: &proto::Room,
4245 peer: &Peer,
4246 pool: &ConnectionPool,
4247) {
4248 let participants = room
4249 .participants
4250 .iter()
4251 .map(|p| p.user_id)
4252 .collect::<Vec<_>>();
4253
4254 broadcast(
4255 None,
4256 pool.channel_connection_ids(channel.root_id())
4257 .filter_map(|(channel_id, role)| {
4258 role.can_see_channel(channel.visibility)
4259 .then_some(channel_id)
4260 }),
4261 |peer_id| {
4262 peer.send(
4263 peer_id,
4264 proto::UpdateChannels {
4265 channel_participants: vec![proto::ChannelParticipants {
4266 channel_id: channel.id.to_proto(),
4267 participant_user_ids: participants.clone(),
4268 }],
4269 ..Default::default()
4270 },
4271 )
4272 },
4273 );
4274}
4275
4276async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4277 let db = session.db().await;
4278
4279 let contacts = db.get_contacts(user_id).await?;
4280 let busy = db.is_user_busy(user_id).await?;
4281
4282 let pool = session.connection_pool().await;
4283 let updated_contact = contact_for_user(user_id, busy, &pool);
4284 for contact in contacts {
4285 if let db::Contact::Accepted {
4286 user_id: contact_user_id,
4287 ..
4288 } = contact
4289 {
4290 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4291 session
4292 .peer
4293 .send(
4294 contact_conn_id,
4295 proto::UpdateContacts {
4296 contacts: vec![updated_contact.clone()],
4297 remove_contacts: Default::default(),
4298 incoming_requests: Default::default(),
4299 remove_incoming_requests: Default::default(),
4300 outgoing_requests: Default::default(),
4301 remove_outgoing_requests: Default::default(),
4302 },
4303 )
4304 .trace_err();
4305 }
4306 }
4307 }
4308 Ok(())
4309}
4310
4311async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
4312 let mut contacts_to_update = HashSet::default();
4313
4314 let room_id;
4315 let canceled_calls_to_user_ids;
4316 let live_kit_room;
4317 let delete_live_kit_room;
4318 let room;
4319 let channel;
4320
4321 if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
4322 contacts_to_update.insert(session.user_id());
4323
4324 for project in left_room.left_projects.values() {
4325 project_left(project, session);
4326 }
4327
4328 room_id = RoomId::from_proto(left_room.room.id);
4329 canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
4330 live_kit_room = mem::take(&mut left_room.room.live_kit_room);
4331 delete_live_kit_room = left_room.deleted;
4332 room = mem::take(&mut left_room.room);
4333 channel = mem::take(&mut left_room.channel);
4334
4335 room_updated(&room, &session.peer);
4336 } else {
4337 return Ok(());
4338 }
4339
4340 if let Some(channel) = channel {
4341 channel_updated(
4342 &channel,
4343 &room,
4344 &session.peer,
4345 &*session.connection_pool().await,
4346 );
4347 }
4348
4349 {
4350 let pool = session.connection_pool().await;
4351 for canceled_user_id in canceled_calls_to_user_ids {
4352 for connection_id in pool.user_connection_ids(canceled_user_id) {
4353 session
4354 .peer
4355 .send(
4356 connection_id,
4357 proto::CallCanceled {
4358 room_id: room_id.to_proto(),
4359 },
4360 )
4361 .trace_err();
4362 }
4363 contacts_to_update.insert(canceled_user_id);
4364 }
4365 }
4366
4367 for contact_user_id in contacts_to_update {
4368 update_user_contacts(contact_user_id, session).await?;
4369 }
4370
4371 if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
4372 live_kit
4373 .remove_participant(live_kit_room.clone(), session.user_id().to_string())
4374 .await
4375 .trace_err();
4376
4377 if delete_live_kit_room {
4378 live_kit.delete_room(live_kit_room).await.trace_err();
4379 }
4380 }
4381
4382 Ok(())
4383}
4384
4385async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4386 let left_channel_buffers = session
4387 .db()
4388 .await
4389 .leave_channel_buffers(session.connection_id)
4390 .await?;
4391
4392 for left_buffer in left_channel_buffers {
4393 channel_buffer_updated(
4394 session.connection_id,
4395 left_buffer.connections,
4396 &proto::UpdateChannelBufferCollaborators {
4397 channel_id: left_buffer.channel_id.to_proto(),
4398 collaborators: left_buffer.collaborators,
4399 },
4400 &session.peer,
4401 );
4402 }
4403
4404 Ok(())
4405}
4406
4407fn project_left(project: &db::LeftProject, session: &Session) {
4408 for connection_id in &project.connection_ids {
4409 if project.should_unshare {
4410 session
4411 .peer
4412 .send(
4413 *connection_id,
4414 proto::UnshareProject {
4415 project_id: project.id.to_proto(),
4416 },
4417 )
4418 .trace_err();
4419 } else {
4420 session
4421 .peer
4422 .send(
4423 *connection_id,
4424 proto::RemoveProjectCollaborator {
4425 project_id: project.id.to_proto(),
4426 peer_id: Some(session.connection_id.into()),
4427 },
4428 )
4429 .trace_err();
4430 }
4431 }
4432}
4433
4434pub trait ResultExt {
4435 type Ok;
4436
4437 fn trace_err(self) -> Option<Self::Ok>;
4438}
4439
4440impl<T, E> ResultExt for Result<T, E>
4441where
4442 E: std::fmt::Debug,
4443{
4444 type Ok = T;
4445
4446 #[track_caller]
4447 fn trace_err(self) -> Option<T> {
4448 match self {
4449 Ok(value) => Some(value),
4450 Err(error) => {
4451 tracing::error!("{:?}", error);
4452 None
4453 }
4454 }
4455 }
4456}