1mod connection_pool;
2
3use crate::api::CloudflareIpCountryHeader;
4use crate::llm::LlmTokenClaims;
5use crate::{
6 auth,
7 db::{
8 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
9 CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
10 NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
11 RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
12 },
13 executor::Executor,
14 AppState, Config, Error, RateLimit, Result,
15};
16use anyhow::{anyhow, bail, Context as _};
17use async_tungstenite::tungstenite::{
18 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
19};
20use axum::{
21 body::Body,
22 extract::{
23 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
24 ConnectInfo, WebSocketUpgrade,
25 },
26 headers::{Header, HeaderName},
27 http::StatusCode,
28 middleware,
29 response::IntoResponse,
30 routing::get,
31 Extension, Router, TypedHeader,
32};
33use chrono::Utc;
34use collections::{HashMap, HashSet};
35pub use connection_pool::{ConnectionPool, ZedVersion};
36use core::fmt::{self, Debug, Formatter};
37use http_client::HttpClient;
38use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
39use reqwest_client::ReqwestClient;
40use sha2::Digest;
41use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
42
43use futures::{
44 channel::oneshot, future::BoxFuture, stream::FuturesUnordered, FutureExt, SinkExt, StreamExt,
45 TryStreamExt,
46};
47use prometheus::{register_int_gauge, IntGauge};
48use rpc::{
49 proto::{
50 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
51 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
52 },
53 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
54};
55use semantic_version::SemanticVersion;
56use serde::{Serialize, Serializer};
57use std::{
58 any::TypeId,
59 future::Future,
60 marker::PhantomData,
61 mem,
62 net::SocketAddr,
63 ops::{Deref, DerefMut},
64 rc::Rc,
65 sync::{
66 atomic::{AtomicBool, Ordering::SeqCst},
67 Arc, OnceLock,
68 },
69 time::{Duration, Instant},
70};
71use time::OffsetDateTime;
72use tokio::sync::{watch, MutexGuard, Semaphore};
73use tower::ServiceBuilder;
74use tracing::{
75 field::{self},
76 info_span, instrument, Instrument,
77};
78
/// Reconnection grace period (30 seconds).
///
/// NOTE(review): not referenced in this part of the file; presumably the window a dropped
/// client has to reconnect before its state is released — confirm at the use sites.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// How long `Server::start` waits before cleaning up resources owned by stale servers.
///
/// Kubernetes gives terminated pods 10 seconds to shut down gracefully; waiting somewhat
/// longer than that means the old pods are gone by the time their resources are reclaimed.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

/// Presumably the page size for channel-message queries — confirm at the use sites.
const MESSAGE_COUNT_PER_PAGE: usize = 100;

/// Presumably the page size for notification queries — confirm at the use sites.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;

/// Upper bound on a message's length; the unit (bytes vs. chars) is decided at the use
/// site — TODO confirm.
const MAX_MESSAGE_LEN: usize = 1024;
87
88type MessageHandler =
89 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
90
91struct Response<R> {
92 peer: Arc<Peer>,
93 receipt: Receipt<R>,
94 responded: Arc<AtomicBool>,
95}
96
97impl<R: RequestMessage> Response<R> {
98 fn send(self, payload: R::Response) -> Result<()> {
99 self.responded.store(true, SeqCst);
100 self.peer.respond(self.receipt, payload)?;
101 Ok(())
102 }
103}
104
105#[derive(Clone, Debug)]
106pub enum Principal {
107 User(User),
108 Impersonated { user: User, admin: User },
109}
110
111impl Principal {
112 fn update_span(&self, span: &tracing::Span) {
113 match &self {
114 Principal::User(user) => {
115 span.record("user_id", user.id.0);
116 span.record("login", &user.github_login);
117 }
118 Principal::Impersonated { user, admin } => {
119 span.record("user_id", user.id.0);
120 span.record("login", &user.github_login);
121 span.record("impersonator", &admin.github_login);
122 }
123 }
124 }
125}
126
/// Per-connection state handed to every message handler.
///
/// A clone is passed to each dispatched handler; the shared parts are behind `Arc`s.
/// Field order is kept as-is because it also fixes drop order.
#[derive(Clone)]
struct Session {
    /// The identity this connection acts as (possibly an impersonation).
    principal: Principal,
    /// This connection's id on the server's [`Peer`].
    connection_id: ConnectionId,
    /// Database handle behind an async mutex that is created per connection and shared by
    /// that connection's session clones; acquire it via `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// The server-wide connection pool (the same `Arc` the [`Server`] holds).
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Set only when a Supermaven admin API key is configured.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    http_client: Arc<dyn HttpClient>,
    /// The GeoIP country code for the user.
    // NOTE(review): `allow(unused)` suggests no handler reads this yet — confirm before removing.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    // Not read by the methods visible here; presumably held to keep the executor alive
    // for the session's lifetime — TODO confirm.
    _executor: Executor,
}
142
143impl Session {
144 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
145 #[cfg(test)]
146 tokio::task::yield_now().await;
147 let guard = self.db.lock().await;
148 #[cfg(test)]
149 tokio::task::yield_now().await;
150 guard
151 }
152
153 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
154 #[cfg(test)]
155 tokio::task::yield_now().await;
156 let guard = self.connection_pool.lock();
157 ConnectionPoolGuard {
158 guard,
159 _not_send: PhantomData,
160 }
161 }
162
163 fn is_staff(&self) -> bool {
164 match &self.principal {
165 Principal::User(user) => user.admin,
166 Principal::Impersonated { .. } => true,
167 }
168 }
169
170 pub async fn has_llm_subscription(
171 &self,
172 db: &MutexGuard<'_, DbHandle>,
173 ) -> anyhow::Result<bool> {
174 if self.is_staff() {
175 return Ok(true);
176 }
177
178 let user_id = self.user_id();
179
180 Ok(db.has_active_billing_subscription(user_id).await?)
181 }
182
183 pub async fn current_plan(
184 &self,
185 _db: &MutexGuard<'_, DbHandle>,
186 ) -> anyhow::Result<proto::Plan> {
187 if self.is_staff() {
188 Ok(proto::Plan::ZedPro)
189 } else {
190 Ok(proto::Plan::Free)
191 }
192 }
193
194 fn user_id(&self) -> UserId {
195 match &self.principal {
196 Principal::User(user) => user.id,
197 Principal::Impersonated { user, .. } => user.id,
198 }
199 }
200
201 pub fn email(&self) -> Option<String> {
202 match &self.principal {
203 Principal::User(user) => user.email_address.clone(),
204 Principal::Impersonated { user, .. } => user.email_address.clone(),
205 }
206 }
207}
208
209impl Debug for Session {
210 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
211 let mut result = f.debug_struct("Session");
212 match &self.principal {
213 Principal::User(user) => {
214 result.field("user", &user.github_login);
215 }
216 Principal::Impersonated { user, admin } => {
217 result.field("user", &user.github_login);
218 result.field("impersonator", &admin.github_login);
219 }
220 }
221 result.field("connection_id", &self.connection_id).finish()
222 }
223}
224
225struct DbHandle(Arc<Database>);
226
227impl Deref for DbHandle {
228 type Target = Database;
229
230 fn deref(&self) -> &Self::Target {
231 self.0.as_ref()
232 }
233}
234
/// The collaboration RPC server: owns the [`Peer`], the shared connection pool, and the
/// table of message handlers.
pub struct Server {
    /// Behind a mutex so the test-only `reset` can replace it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    /// Shared with every [`Session`].
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Type-erased handlers keyed by the `TypeId` of the message type they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Set to `true` by `teardown`; each connection task subscribes and stops when it flips.
    teardown: watch::Sender<bool>,
}
243
/// A lock guard over the shared [`ConnectionPool`].
///
/// `PhantomData<Rc<()>>` makes the guard `!Send` (because `Rc` is `!Send`), so it cannot be
/// held across an `.await` inside a future that must be `Send`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
248
/// A serializable view of the server's peer and connection pool.
///
/// The connection-pool lock is held for as long as the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized through the guard's `Deref` target, i.e. the pool itself.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
255
256pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
257where
258 S: Serializer,
259 T: Deref<Target = U>,
260 U: Serialize,
261{
262 Serialize::serialize(value.deref(), serializer)
263}
264
impl Server {
    /// Creates the server and registers every RPC handler.
    ///
    /// Handlers registered with `add_request_handler` must reply exactly once through their
    /// `Response`; `add_message_handler` handlers send no reply. Registering two handlers
    /// for the same message type panics (see `add_handler`).
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            // NOTE(review): `as u32` silently truncates if the id does not fit in a u32 —
            // confirm the range of `ServerId`.
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            // Only the sender is kept; connection tasks create receivers via `subscribe()`.
            teardown: watch::channel(false).0,
        };

        server
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(set_room_participant_role)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
            .add_request_handler(forward_find_search_candidates_request)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::GitBranches>)
            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(update_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            .add_request_handler(move_channel)
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_request_handler(get_private_user_info)
            .add_request_handler(get_llm_api_token)
            .add_request_handler(accept_terms_of_service)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version)
            .add_request_handler(get_supermaven_api_key)
            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context)
            // The two handlers below need configuration values, so they are wrapped in
            // closures that capture their own clone of `app_state`.
            .add_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    let app_state = app_state.clone();
                    async move {
                        count_language_model_tokens(request, response, session, &app_state.config)
                            .await
                    }
                }
            })
            .add_request_handler(get_cached_embeddings)
            .add_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    compute_embeddings(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                    )
                }
            });

        Arc::new(server)
    }

    /// Schedules best-effort cleanup of resources that the database reports as stale for
    /// this environment and server id.
    ///
    /// This returns `Ok(())` immediately after spawning a detached task on the app executor.
    /// That task:
    /// 1. Waits [`CLEANUP_TIMEOUT`].
    /// 2. Clears stale channel-buffer collaborators and pushes the refreshed collaborator
    ///    lists to the remaining connections.
    /// 3. Clears stale room participants, broadcasts the updated rooms/channels, and sends
    ///    `CallCanceled` for canceled calls.
    /// 4. Refreshes affected users' contacts and deletes LiveKit rooms left with no
    ///    participants.
    /// 5. Deletes stale server records.
    ///
    /// Every failure is logged via `trace_err` and skipped rather than aborting the cleanup.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Defaults mean "nothing to do", so the steps after the `if let`
                        // become no-ops when clearing this room fails.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            // Copy canceled-call recipients into the contact set *before*
                            // `mem::take` empties that vector on the next statement.
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the synchronous pool lock is released before the
                        // `.await`s in the contact loop below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        for user_id in contacts_to_update {
                            // Both queries finish before the pool is locked; the lock is
                            // then held only for the synchronous sends.
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    // Only accepted contacts are notified.
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Only delete the LiveKit room when no participants remain.
                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }

    /// Stops the server.
    ///
    /// Tears down the peer, clears the connection pool, and broadcasts `true` on the
    /// teardown channel, which makes every `handle_connection` task return. A send error
    /// (no live subscribers) is ignored.
    pub fn teardown(&self) {
        self.peer.teardown();
        self.connection_pool.lock().reset();
        let _ = self.teardown.send(true);
    }

    /// Test-only: tears the server down, gives it a new id, and clears the teardown flag so
    /// new connections are accepted again.
    #[cfg(test)]
    pub fn reset(&self, id: ServerId) {
        self.teardown();
        *self.id.lock() = id;
        self.peer.reset(id.0 as u32);
        let _ = self.teardown.send(false);
    }

    /// Test-only accessor for the current server id.
    #[cfg(test)]
    pub fn id(&self) -> ServerId {
        *self.id.lock()
    }

    /// Registers a type-erased handler for message type `M`, wrapped with timing and logging.
    ///
    /// Timing: `queue_duration_ms` is the time from the envelope's `received_at` until this
    /// handler's future was created. `processing_duration_ms` is everything after that,
    /// including any time the future waits before being polled.
    ///
    /// # Panics
    /// Panics if a handler for `M` was already registered.
    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        let prev_handler = self.handlers.insert(
            TypeId::of::<M>(),
            Box::new(move |envelope, session| {
                // Handlers are keyed by `TypeId::of::<M>()` and looked up by the envelope's
                // payload type id, so this downcast is expected to succeed.
                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
                let received_at = envelope.received_at;
                tracing::info!("message received");
                let start_time = Instant::now();
                let future = (handler)(*envelope, session);
                async move {
                    let result = future.await;
                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
                    let payload_type = M::NAME;

                    match result {
                        Err(error) => {
                            tracing::error!(
                                ?error,
                                total_duration_ms,
                                processing_duration_ms,
                                queue_duration_ms,
                                payload_type,
                                "error handling message"
                            )
                        }
                        // NOTE(review): unlike the error branch, this log omits `payload_type`.
                        Ok(()) => tracing::info!(
                            total_duration_ms,
                            processing_duration_ms,
                            queue_duration_ms,
                            "finished handling message"
                        ),
                    }
                }
                .boxed()
            }),
        );
        if prev_handler.is_some() {
            panic!("registered a handler for the same message twice");
        }
        self
    }

    /// Registers a handler for a fire-and-forget message; it receives only the payload.
    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        self.add_handler(move |envelope, session| handler(envelope.payload, session));
        self
    }

    /// Registers a handler for a request that expects a reply via the given [`Response`].
    ///
    /// - If the handler returns `Ok` without replying, an error is returned (and logged by
    ///   `add_handler`), but nothing is sent to the requester.
    /// - If the handler returns `Err`, the error is sent back to the requester:
    ///   `Error::Internal` uses its own proto form; any other error becomes
    ///   `ErrorCode::Internal` carrying the error's display text.
    ///
    /// NOTE(review): if `respond_with_error` itself fails, `?` returns that failure and the
    /// handler's original error is not reported.
    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
        Fut: Send + Future<Output = Result<()>>,
        M: RequestMessage,
    {
        let handler = Arc::new(handler);
        self.add_handler(move |envelope, session| {
            let receipt = envelope.receipt();
            let handler = handler.clone();
            async move {
                let peer = session.peer.clone();
                let responded = Arc::new(AtomicBool::default());
                let response = Response {
                    peer: peer.clone(),
                    responded: responded.clone(),
                    receipt,
                };
                match (handler)(envelope.payload, response, session).await {
                    Ok(()) => {
                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
                            Ok(())
                        } else {
                            Err(anyhow!("handler did not send a response"))?
                        }
                    }
                    Err(error) => {
                        let proto_err = match &error {
                            Error::Internal(err) => err.to_proto(),
                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
                        };
                        peer.respond_with_error(receipt, proto_err)?;
                        Err(error)
                    }
                }
            }
        })
    }

    /// Returns a future that drives one client connection to completion.
    ///
    /// The future:
    /// 1. Registers the connection with the peer.
    /// 2. Builds the connection's [`Session`] and sends the initial client state (see
    ///    `send_initial_client_update`); `send_connection_id`, if given, receives the new
    ///    connection id.
    /// 3. Dispatches incoming messages to the registered handlers.
    ///
    /// When the connection closes or its I/O fails, it calls `connection_lost`. On server
    /// teardown it returns immediately without doing so.
    ///
    /// NOTE(review): the early returns (HTTP client creation failure, initial update failure)
    /// also skip `connection_lost`, even though the connection was already added to the peer
    /// and possibly the pool — confirm cleanup happens elsewhere.
    #[allow(clippy::too_many_arguments)]
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        geoip_country_code: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> {
        let this = self.clone();
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            geoip_country_code=field::Empty
        );
        principal.update_span(&span);
        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            // A new HTTP client is built for every connection.
            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
                supermaven_admin_api_key.to_string(),
                http_client.clone(),
            )));

            // Each connection gets its own mutex around the shared database, so DB access
            // through `Session::db` is serialized per connection.
            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                http_client,
                geoip_country_code,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // At most 256 handlers per connection may be in flight. A permit is taken before
            // reading each message and released when its handler finishes, so reading stalls
            // (back-pressure) once all permits are in use.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                // `unwrap` on `acquire_owned`: this semaphore is never closed in this function.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Branches are polled in declaration order: teardown, I/O, completing
                // foreground handlers, then the next incoming message.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background messages run detached on the executor; foreground
                                // ones are polled by this loop.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Pending foreground handlers are cancelled (dropped) rather than awaited.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }

    /// Sends the initial state a newly connected client needs, in this order:
    /// 1. `Hello` with its peer id (and the id to `send_connection_id`, if provided).
    /// 2. `ShowContacts` on the user's first ever connection; the user is then marked as
    ///    having connected once.
    /// 3. Runs `update_user_plan`.
    /// 4. Adds the connection to the pool and sends the initial contacts.
    /// 5. Subscribes the user to channels when `should_auto_subscribe_to_channels` allows it
    ///    for this client version.
    /// 6. Forwards any pending incoming call.
    /// 7. Runs `update_user_contacts`.
    ///
    /// The first error aborts the remaining steps.
    async fn send_initial_client_update(
        &self,
        connection_id: ConnectionId,
        principal: &Principal,
        zed_version: ZedVersion,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        session: &Session,
    ) -> Result<()> {
        self.peer.send(
            connection_id,
            proto::Hello {
                peer_id: Some(connection_id.into()),
            },
        )?;
        tracing::info!("sent hello message");
        // The receiver may already be gone; that is not an error here.
        if let Some(send_connection_id) = send_connection_id.take() {
            let _ = send_connection_id.send(connection_id);
        }

        match principal {
            // Impersonated sessions are set up exactly like the impersonated user's own.
            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
                if !user.connected_once {
                    self.peer.send(connection_id, proto::ShowContacts {})?;
                    self.app_state
                        .db
                        .set_user_connected_once(user.id, true)
                        .await?;
                }

                update_user_plan(user.id, session).await?;

                let contacts = self.app_state.db.get_contacts(user.id).await?;

                // Scoped so the synchronous pool lock is released before the awaits below.
                {
                    let mut pool = self.connection_pool.lock();
                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
                    self.peer.send(
                        connection_id,
                        build_initial_contacts_update(contacts, &pool),
                    )?;
                }

                if should_auto_subscribe_to_channels(zed_version) {
                    subscribe_user_to_channels(user.id, session).await?;
                }

                if let Some(incoming_call) =
                    self.app_state.db.incoming_call_for_user(user.id).await?
                {
                    self.peer.send(connection_id, incoming_call)?;
                }

                update_user_contacts(user.id, session).await?;
            }
        }

        Ok(())
    }

    /// Notifies every connection of `inviter_id` that `invitee_id` is now a contact (built
    /// with `busy = false`) and refreshes the inviter's invite info.
    ///
    /// Does nothing if the inviter does not exist or has no invite code. A send failure to
    /// one connection aborts the loop via `?`.
    pub async fn invite_code_redeemed(
        self: &Arc<Self>,
        inviter_id: UserId,
        invitee_id: UserId,
    ) -> Result<()> {
        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
            if let Some(code) = &user.invite_code {
                let pool = self.connection_pool.lock();
                let invitee_contact = contact_for_user(invitee_id, false, &pool);
                for connection_id in pool.user_connection_ids(inviter_id) {
                    self.peer.send(
                        connection_id,
                        proto::UpdateContacts {
                            contacts: vec![invitee_contact.clone()],
                            ..Default::default()
                        },
                    )?;
                    self.peer.send(
                        connection_id,
                        proto::UpdateInviteInfo {
                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
                            count: user.invite_count as u32,
                        },
                    )?;
                }
            }
        }
        Ok(())
    }

    /// Sends updated invite info (link and count) to every connection of `user_id`.
    ///
    /// Does nothing if the user does not exist or has no invite code.
    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
            if let Some(invite_code) = &user.invite_code {
                let pool = self.connection_pool.lock();
                for connection_id in pool.user_connection_ids(user_id) {
                    self.peer.send(
                        connection_id,
                        proto::UpdateInviteInfo {
                            url: format!(
                                "{}{}",
                                self.app_state.config.invite_link_prefix, invite_code
                            ),
                            count: user.invite_count as u32,
                        },
                    )?;
                }
            }
        }
        Ok(())
    }

    /// Asks every connection of `user_id` to refresh its LLM token; send failures are logged
    /// and skipped.
    ///
    /// NOTE(review): declared `async` but never awaits.
    pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
        let pool = self.connection_pool.lock();
        for connection_id in pool.user_connection_ids(user_id) {
            self.peer
                .send(connection_id, proto::RefreshLlmToken {})
                .trace_err();
        }
    }

    /// Captures a serializable view of the peer and connection pool.
    ///
    /// The returned snapshot holds the pool lock until it is dropped.
    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
        ServerSnapshot {
            connection_pool: ConnectionPoolGuard {
                guard: self.connection_pool.lock(),
                _not_send: PhantomData,
            },
            peer: &self.peer,
        }
    }
}
954
955impl<'a> Deref for ConnectionPoolGuard<'a> {
956 type Target = ConnectionPool;
957
958 fn deref(&self) -> &Self::Target {
959 &self.guard
960 }
961}
962
963impl<'a> DerefMut for ConnectionPoolGuard<'a> {
964 fn deref_mut(&mut self) -> &mut Self::Target {
965 &mut self.guard
966 }
967}
968
969impl<'a> Drop for ConnectionPoolGuard<'a> {
970 fn drop(&mut self) {
971 #[cfg(test)]
972 self.check_invariants();
973 }
974}
975
976fn broadcast<F>(
977 sender_id: Option<ConnectionId>,
978 receiver_ids: impl IntoIterator<Item = ConnectionId>,
979 mut f: F,
980) where
981 F: FnMut(ConnectionId) -> anyhow::Result<()>,
982{
983 for receiver_id in receiver_ids {
984 if Some(receiver_id) != sender_id {
985 if let Err(error) = f(receiver_id) {
986 tracing::error!("failed to send to {:?} {}", receiver_id, error);
987 }
988 }
989 }
990}
991
992pub struct ProtocolVersion(u32);
993
994impl Header for ProtocolVersion {
995 fn name() -> &'static HeaderName {
996 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
997 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
998 }
999
1000 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1001 where
1002 Self: Sized,
1003 I: Iterator<Item = &'i axum::http::HeaderValue>,
1004 {
1005 let version = values
1006 .next()
1007 .ok_or_else(axum::headers::Error::invalid)?
1008 .to_str()
1009 .map_err(|_| axum::headers::Error::invalid())?
1010 .parse()
1011 .map_err(|_| axum::headers::Error::invalid())?;
1012 Ok(Self(version))
1013 }
1014
1015 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1016 values.extend([self.0.to_string().parse().unwrap()]);
1017 }
1018}
1019
1020pub struct AppVersionHeader(SemanticVersion);
1021impl Header for AppVersionHeader {
1022 fn name() -> &'static HeaderName {
1023 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1024 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1025 }
1026
1027 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1028 where
1029 Self: Sized,
1030 I: Iterator<Item = &'i axum::http::HeaderValue>,
1031 {
1032 let version = values
1033 .next()
1034 .ok_or_else(axum::headers::Error::invalid)?
1035 .to_str()
1036 .map_err(|_| axum::headers::Error::invalid())?
1037 .parse()
1038 .map_err(|_| axum::headers::Error::invalid())?;
1039 Ok(Self(version))
1040 }
1041
1042 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1043 values.extend([self.0.to_string().parse().unwrap()]);
1044 }
1045}
1046
/// Builds the collab RPC router with the `/rpc` websocket route and the `/metrics` route.
///
/// The `ServiceBuilder` layer (app-state extension + `auth::validate_header` middleware)
/// is applied right after `/rpc` is registered. Axum's `Router::layer` only wraps routes
/// that already exist, so `/metrics`, added afterwards, is not behind that middleware.
/// The final `Extension(server)` layer covers both routes.
pub fn routes(server: Arc<Server>) -> Router<(), Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        // Wraps only the routes registered above this call (i.e. `/rpc`).
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
1058
1059pub async fn handle_websocket_request(
1060 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1061 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1062 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1063 Extension(server): Extension<Arc<Server>>,
1064 Extension(principal): Extension<Principal>,
1065 country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1066 ws: WebSocketUpgrade,
1067) -> axum::response::Response {
1068 if protocol_version != rpc::PROTOCOL_VERSION {
1069 return (
1070 StatusCode::UPGRADE_REQUIRED,
1071 "client must be upgraded".to_string(),
1072 )
1073 .into_response();
1074 }
1075
1076 let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1077 return (
1078 StatusCode::UPGRADE_REQUIRED,
1079 "no version header found".to_string(),
1080 )
1081 .into_response();
1082 };
1083
1084 if !version.can_collaborate() {
1085 return (
1086 StatusCode::UPGRADE_REQUIRED,
1087 "client must be upgraded".to_string(),
1088 )
1089 .into_response();
1090 }
1091
1092 let socket_address = socket_address.to_string();
1093 ws.on_upgrade(move |socket| {
1094 let socket = socket
1095 .map_ok(to_tungstenite_message)
1096 .err_into()
1097 .with(|message| async move { to_axum_message(message) });
1098 let connection = Connection::new(Box::pin(socket));
1099 async move {
1100 server
1101 .handle_connection(
1102 connection,
1103 socket_address,
1104 principal,
1105 version,
1106 country_code_header.map(|header| header.to_string()),
1107 None,
1108 Executor::Production,
1109 )
1110 .await;
1111 }
1112 })
1113}
1114
1115pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1116 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1117 let connections_metric = CONNECTIONS_METRIC
1118 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1119
1120 let connections = server
1121 .connection_pool
1122 .lock()
1123 .connections()
1124 .filter(|connection| !connection.admin)
1125 .count();
1126 connections_metric.set(connections as _);
1127
1128 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1129 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1130 register_int_gauge!(
1131 "shared_projects",
1132 "number of open projects with one or more guests"
1133 )
1134 .unwrap()
1135 });
1136
1137 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1138 shared_projects_metric.set(shared_projects as _);
1139
1140 let encoder = prometheus::TextEncoder::new();
1141 let metric_families = prometheus::gather();
1142 let encoded_metrics = encoder
1143 .encode_to_string(&metric_families)
1144 .map_err(|err| anyhow!("{}", err))?;
1145 Ok(encoded_metrics)
1146}
1147
/// Cleans up after a client's connection is lost.
///
/// These steps happen immediately:
/// - the peer is disconnected,
/// - the connection is removed from the pool,
/// - the loss is recorded in the database (a database failure here is only logged).
///
/// Then the function waits for whichever comes first:
/// - `RECONNECT_TIMEOUT` elapses: the session leaves its room and channel buffers. If the
///   user has no other online connection, `decline_call(None, ..)` is issued and any
///   resulting room update is broadcast. Finally the user's contacts are updated.
/// - `teardown` changes: the function returns without that cleanup.
///
/// `select_biased!` polls the timeout branch first if both are ready.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // Best-effort: a database failure is logged but does not abort cleanup.
    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline calls once the user has no other connection left online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1194
1195/// Acknowledges a ping from a client, used to keep the connection alive.
1196async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1197 response.send(proto::Ack {})?;
1198 Ok(())
1199}
1200
1201/// Creates a new room for calling (outside of channels)
1202async fn create_room(
1203 _request: proto::CreateRoom,
1204 response: Response<proto::CreateRoom>,
1205 session: Session,
1206) -> Result<()> {
1207 let live_kit_room = nanoid::nanoid!(30);
1208
1209 let live_kit_connection_info = util::maybe!(async {
1210 let live_kit = session.app_state.live_kit_client.as_ref();
1211 let live_kit = live_kit?;
1212 let user_id = session.user_id().to_string();
1213
1214 let token = live_kit
1215 .room_token(&live_kit_room, &user_id.to_string())
1216 .trace_err()?;
1217
1218 Some(proto::LiveKitConnectionInfo {
1219 server_url: live_kit.url().into(),
1220 token,
1221 can_publish: true,
1222 })
1223 })
1224 .await;
1225
1226 let room = session
1227 .db()
1228 .await
1229 .create_room(session.user_id(), session.connection_id, &live_kit_room)
1230 .await?;
1231
1232 response.send(proto::CreateRoomResponse {
1233 room: Some(room.clone()),
1234 live_kit_connection_info,
1235 })?;
1236
1237 update_user_contacts(session.user_id(), &session).await?;
1238 Ok(())
1239}
1240
/// Join a room from an invitation. Equivalent to joining a channel if there is one.
///
/// If the room belongs to a channel, the request is delegated to `join_channel_internal`.
/// Otherwise:
/// 1. The user is added to the room in the database and the room update is broadcast.
/// 2. Every connection of the joining user receives `CallCanceled` for this room.
/// 3. The response carries the room and, when a LiveKit client is configured and a token
///    can be minted, publish-capable LiveKit connection info.
/// 4. The user's contacts are updated.
async fn join_room(
    request: proto::JoinRoom,
    response: Response<proto::JoinRoom>,
    session: Session,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.id);

    let channel_id = session.db().await.channel_id_for_room(room_id).await?;

    // Channel rooms are joined through the channel path instead.
    if let Some(channel_id) = channel_id {
        return join_channel_internal(channel_id, Box::new(response), session).await;
    }

    // The block scope drops the guard returned by `db.join_room` before the connection
    // pool is locked below (presumably to avoid holding both at once).
    let joined_room = {
        let room = session
            .db()
            .await
            .join_room(room_id, session.user_id(), session.connection_id)
            .await?;
        room_updated(&room.room, &session.peer);
        room.into_inner()
    };

    // Tell all of the user's connections, including other clients, that the call for this
    // room is no longer pending. Failures are only logged.
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(session.user_id())
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }

    // `None` when LiveKit is not configured or token creation failed (the error is logged).
    let live_kit_connection_info =
        if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
            live_kit
                .room_token(
                    &joined_room.room.live_kit_room,
                    &session.user_id().to_string(),
                )
                .trace_err()
                .map(|token| proto::LiveKitConnectionInfo {
                    server_url: live_kit.url().into(),
                    token,
                    can_publish: true,
                })
        } else {
            None
        };

    response.send(proto::JoinRoomResponse {
        room: Some(joined_room.room),
        channel_id: None,
        live_kit_connection_info,
    })?;

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1307
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// The database rejoin yields the room plus two project lists:
/// - `reshared_projects` (presumably projects this user was hosting),
/// - `rejoined_projects` (presumably projects they had joined as a guest).
///
/// While the rejoin guard is held, this handler:
/// 1. responds to the caller,
/// 2. broadcasts the room update,
/// 3. tells each reshared project's collaborators that the user's peer id changed,
/// 4. forwards those collaborators the project's current worktrees,
/// 5. replays rejoined-project state via `notify_rejoined_projects`.
///
/// After the guard is released it updates the room's channel (if any) and the user's
/// contacts.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    let room;
    let channel;
    // Scope for the database guard returned by `rejoin_room`; `room` and `channel` are
    // moved out so they remain usable after the guard is dropped.
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Re-map the user's old connection id to the new one for every collaborator.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Resend the project's worktrees to everyone but the rejoining connection.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1399
1400fn notify_rejoined_projects(
1401 rejoined_projects: &mut Vec<RejoinedProject>,
1402 session: &Session,
1403) -> Result<()> {
1404 for project in rejoined_projects.iter() {
1405 for collaborator in &project.collaborators {
1406 session
1407 .peer
1408 .send(
1409 collaborator.connection_id,
1410 proto::UpdateProjectCollaborator {
1411 project_id: project.id.to_proto(),
1412 old_peer_id: Some(project.old_connection_id.into()),
1413 new_peer_id: Some(session.connection_id.into()),
1414 },
1415 )
1416 .trace_err();
1417 }
1418 }
1419
1420 for project in rejoined_projects {
1421 for worktree in mem::take(&mut project.worktrees) {
1422 // Stream this worktree's entries.
1423 let message = proto::UpdateWorktree {
1424 project_id: project.id.to_proto(),
1425 worktree_id: worktree.id,
1426 abs_path: worktree.abs_path.clone(),
1427 root_name: worktree.root_name,
1428 updated_entries: worktree.updated_entries,
1429 removed_entries: worktree.removed_entries,
1430 scan_id: worktree.scan_id,
1431 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1432 updated_repositories: worktree.updated_repositories,
1433 removed_repositories: worktree.removed_repositories,
1434 };
1435 for update in proto::split_worktree_update(message) {
1436 session.peer.send(session.connection_id, update.clone())?;
1437 }
1438
1439 // Stream this worktree's diagnostics.
1440 for summary in worktree.diagnostic_summaries {
1441 session.peer.send(
1442 session.connection_id,
1443 proto::UpdateDiagnosticSummary {
1444 project_id: project.id.to_proto(),
1445 worktree_id: worktree.id,
1446 summary: Some(summary),
1447 },
1448 )?;
1449 }
1450
1451 for settings_file in worktree.settings_files {
1452 session.peer.send(
1453 session.connection_id,
1454 proto::UpdateWorktreeSettings {
1455 project_id: project.id.to_proto(),
1456 worktree_id: worktree.id,
1457 path: settings_file.path,
1458 content: Some(settings_file.content),
1459 kind: Some(settings_file.kind.to_proto().into()),
1460 },
1461 )?;
1462 }
1463 }
1464
1465 for language_server in &project.language_servers {
1466 session.peer.send(
1467 session.connection_id,
1468 proto::UpdateLanguageServer {
1469 project_id: project.id.to_proto(),
1470 language_server_id: language_server.id,
1471 variant: Some(
1472 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1473 proto::LspDiskBasedDiagnosticsUpdated {},
1474 ),
1475 ),
1476 },
1477 )?;
1478 }
1479 }
1480 Ok(())
1481}
1482
1483/// leave room disconnects from the room.
1484async fn leave_room(
1485 _: proto::LeaveRoom,
1486 response: Response<proto::LeaveRoom>,
1487 session: Session,
1488) -> Result<()> {
1489 leave_room_for_session(&session, session.connection_id).await?;
1490 response.send(proto::Ack {})?;
1491 Ok(())
1492}
1493
1494/// Updates the permissions of someone else in the room.
1495async fn set_room_participant_role(
1496 request: proto::SetRoomParticipantRole,
1497 response: Response<proto::SetRoomParticipantRole>,
1498 session: Session,
1499) -> Result<()> {
1500 let user_id = UserId::from_proto(request.user_id);
1501 let role = ChannelRole::from(request.role());
1502
1503 let (live_kit_room, can_publish) = {
1504 let room = session
1505 .db()
1506 .await
1507 .set_room_participant_role(
1508 session.user_id(),
1509 RoomId::from_proto(request.room_id),
1510 user_id,
1511 role,
1512 )
1513 .await?;
1514
1515 let live_kit_room = room.live_kit_room.clone();
1516 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1517 room_updated(&room, &session.peer);
1518 (live_kit_room, can_publish)
1519 };
1520
1521 if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
1522 live_kit
1523 .update_participant(
1524 live_kit_room.clone(),
1525 request.user_id.to_string(),
1526 live_kit_server::proto::ParticipantPermission {
1527 can_subscribe: true,
1528 can_publish,
1529 can_publish_data: can_publish,
1530 hidden: false,
1531 recorder: false,
1532 },
1533 )
1534 .await
1535 .trace_err();
1536 }
1537
1538 response.send(proto::Ack {})?;
1539 Ok(())
1540}
1541
/// Call someone else into the current room.
///
/// The callee must be a contact of the caller. The handler then:
/// 1. records the call in the database and broadcasts the room update,
/// 2. updates the callee's contacts,
/// 3. sends the incoming-call request to every connection of the callee concurrently.
///
/// The first connection that answers successfully makes the call succeed and the caller
/// is acknowledged. If every connection fails (or none exist), the call is marked failed
/// in the database, the room and contacts are updated again, and an error is returned.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: Session,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        // `?` converts the anyhow error into the crate's error type before returning.
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    // Scoped so the database guard is dropped before contacts are updated and the callee
    // is rung; the incoming call is moved out of the guard with `mem::take`.
    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(room, &session.peer);
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // One request per callee connection, polled concurrently.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                // Log this connection's failure and keep waiting for the others.
                call_response.trace_err();
            }
        }
    }

    // No connection accepted the call: record the failure and notify everyone.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1610
/// Cancel an outgoing call.
///
/// Cancels the pending call to `called_user_id` in the database and broadcasts the room
/// update. It then sends `CallCanceled` to every connection of the called user (failures
/// are logged), acknowledges the caller, and updates the called user's contacts.
async fn cancel_call(
    request: proto::CancelCall,
    response: Response<proto::CancelCall>,
    session: Session,
) -> Result<()> {
    let called_user_id = UserId::from_proto(request.called_user_id);
    let room_id = RoomId::from_proto(request.room_id);
    // The block scope drops the database guard before the connection pool is locked.
    {
        let room = session
            .db()
            .await
            .cancel_call(room_id, session.connection_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }
    response.send(proto::Ack {})?;

    update_user_contacts(called_user_id, &session).await?;
    Ok(())
}
1648
/// Decline an incoming call.
///
/// Returns an error if the database reports no call to decline for this room. On success:
/// - the room update is broadcast,
/// - `CallCanceled` is sent to all of the declining user's own connections (presumably
///   so every one of their clients stops showing the call),
/// - their contacts are updated.
async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
    let room_id = RoomId::from_proto(message.room_id);
    // The block scope drops the database guard before the connection pool is locked.
    {
        let room = session
            .db()
            .await
            .decline_call(Some(room_id), session.user_id())
            .await?
            .ok_or_else(|| anyhow!("failed to decline call"))?;
        room_updated(&room, &session.peer);
    }

    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(session.user_id())
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }
    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1680
1681/// Updates other participants in the room with your current location.
1682async fn update_participant_location(
1683 request: proto::UpdateParticipantLocation,
1684 response: Response<proto::UpdateParticipantLocation>,
1685 session: Session,
1686) -> Result<()> {
1687 let room_id = RoomId::from_proto(request.room_id);
1688 let location = request
1689 .location
1690 .ok_or_else(|| anyhow!("invalid location"))?;
1691
1692 let db = session.db().await;
1693 let room = db
1694 .update_room_participant_location(room_id, session.connection_id, location)
1695 .await?;
1696
1697 room_updated(&room, &session.peer);
1698 response.send(proto::Ack {})?;
1699 Ok(())
1700}
1701
1702/// Share a project into the room.
1703async fn share_project(
1704 request: proto::ShareProject,
1705 response: Response<proto::ShareProject>,
1706 session: Session,
1707) -> Result<()> {
1708 let (project_id, room) = &*session
1709 .db()
1710 .await
1711 .share_project(
1712 RoomId::from_proto(request.room_id),
1713 session.connection_id,
1714 &request.worktrees,
1715 request.is_ssh_project,
1716 )
1717 .await?;
1718 response.send(proto::ShareProjectResponse {
1719 project_id: project_id.to_proto(),
1720 })?;
1721 room_updated(room, &session.peer);
1722
1723 Ok(())
1724}
1725
1726/// Unshare a project from the room.
1727async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1728 let project_id = ProjectId::from_proto(message.project_id);
1729 unshare_project_internal(project_id, session.connection_id, &session).await
1730}
1731
/// Stops sharing `project_id` on behalf of `connection_id`.
///
/// Sends `UnshareProject` to the project's guest connections (excluding `connection_id`;
/// failures are logged). If the project belonged to a room, the room update is broadcast.
/// If the database reports that the project should be deleted, it is deleted afterwards.
async fn unshare_project_internal(
    project_id: ProjectId,
    connection_id: ConnectionId,
    session: &Session,
) -> Result<()> {
    // The unshare guard lives only inside this block; just the `delete` flag escapes, so
    // the guard is released before `delete_project` runs.
    let delete = {
        let room_guard = session
            .db()
            .await
            .unshare_project(project_id, connection_id)
            .await?;

        let (delete, room, guest_connection_ids) = &*room_guard;

        let message = proto::UnshareProject {
            project_id: project_id.to_proto(),
        };

        broadcast(
            Some(connection_id),
            guest_connection_ids.iter().copied(),
            |conn_id| session.peer.send(conn_id, message.clone()),
        );
        if let Some(room) = room {
            room_updated(room, &session.peer);
        }

        *delete
    };

    if delete {
        let db = session.db().await;
        db.delete_project(project_id).await?;
    }

    Ok(())
}
1769
/// Join someone else's shared project.
///
/// Registers the caller in the project via the database. The resulting project state and
/// assigned replica id are then handed to `join_project_internal`, which notifies existing
/// collaborators and streams the project to the caller.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: Session,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    // `project` and `replica_id` borrow from the guard returned by `join_project`, which
    // temporary lifetime extension keeps alive for the rest of the function.
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    // Release the session's database handle before streaming the project. This compiles
    // only because the join result does not borrow `db`.
    drop(db);
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
1788
/// Something that can deliver a `JoinProjectResponse` to the joining client.
///
/// `join_project_internal` is generic over this trait. Only the `Response<proto::JoinProject>`
/// impl is visible here; presumably the trait exists so other response types can reuse it.
trait JoinProjectInternalResponse {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
/// The ordinary RPC response for `JoinProject`: forwards to its inherent `send`.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Path call to the inherent method (inherent items win over trait items in path
        // resolution), so this does not recurse into the trait method.
        Response::<proto::JoinProject>::send(self, result)
    }
}
1797
1798fn join_project_internal(
1799 response: impl JoinProjectInternalResponse,
1800 session: Session,
1801 project: &mut Project,
1802 replica_id: &ReplicaId,
1803) -> Result<()> {
1804 let collaborators = project
1805 .collaborators
1806 .iter()
1807 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1808 .map(|collaborator| collaborator.to_proto())
1809 .collect::<Vec<_>>();
1810 let project_id = project.id;
1811 let guest_user_id = session.user_id();
1812
1813 let worktrees = project
1814 .worktrees
1815 .iter()
1816 .map(|(id, worktree)| proto::WorktreeMetadata {
1817 id: *id,
1818 root_name: worktree.root_name.clone(),
1819 visible: worktree.visible,
1820 abs_path: worktree.abs_path.clone(),
1821 })
1822 .collect::<Vec<_>>();
1823
1824 let add_project_collaborator = proto::AddProjectCollaborator {
1825 project_id: project_id.to_proto(),
1826 collaborator: Some(proto::Collaborator {
1827 peer_id: Some(session.connection_id.into()),
1828 replica_id: replica_id.0 as u32,
1829 user_id: guest_user_id.to_proto(),
1830 is_host: false,
1831 }),
1832 };
1833
1834 for collaborator in &collaborators {
1835 session
1836 .peer
1837 .send(
1838 collaborator.peer_id.unwrap().into(),
1839 add_project_collaborator.clone(),
1840 )
1841 .trace_err();
1842 }
1843
1844 // First, we send the metadata associated with each worktree.
1845 response.send(proto::JoinProjectResponse {
1846 project_id: project.id.0 as u64,
1847 worktrees: worktrees.clone(),
1848 replica_id: replica_id.0 as u32,
1849 collaborators: collaborators.clone(),
1850 language_servers: project.language_servers.clone(),
1851 role: project.role.into(),
1852 })?;
1853
1854 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1855 // Stream this worktree's entries.
1856 let message = proto::UpdateWorktree {
1857 project_id: project_id.to_proto(),
1858 worktree_id,
1859 abs_path: worktree.abs_path.clone(),
1860 root_name: worktree.root_name,
1861 updated_entries: worktree.entries,
1862 removed_entries: Default::default(),
1863 scan_id: worktree.scan_id,
1864 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1865 updated_repositories: worktree.repository_entries.into_values().collect(),
1866 removed_repositories: Default::default(),
1867 };
1868 for update in proto::split_worktree_update(message) {
1869 session.peer.send(session.connection_id, update.clone())?;
1870 }
1871
1872 // Stream this worktree's diagnostics.
1873 for summary in worktree.diagnostic_summaries {
1874 session.peer.send(
1875 session.connection_id,
1876 proto::UpdateDiagnosticSummary {
1877 project_id: project_id.to_proto(),
1878 worktree_id: worktree.id,
1879 summary: Some(summary),
1880 },
1881 )?;
1882 }
1883
1884 for settings_file in worktree.settings_files {
1885 session.peer.send(
1886 session.connection_id,
1887 proto::UpdateWorktreeSettings {
1888 project_id: project_id.to_proto(),
1889 worktree_id: worktree.id,
1890 path: settings_file.path,
1891 content: Some(settings_file.content),
1892 kind: Some(settings_file.kind.to_proto() as i32),
1893 },
1894 )?;
1895 }
1896 }
1897
1898 for language_server in &project.language_servers {
1899 session.peer.send(
1900 session.connection_id,
1901 proto::UpdateLanguageServer {
1902 project_id: project_id.to_proto(),
1903 language_server_id: language_server.id,
1904 variant: Some(
1905 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1906 proto::LspDiskBasedDiagnosticsUpdated {},
1907 ),
1908 ),
1909 },
1910 )?;
1911 }
1912
1913 Ok(())
1914}
1915
/// Leave someone else's shared project.
///
/// Removes the sender from the project in the database and notifies the remaining
/// participants via `project_left`. If the project belongs to a room, the room update is
/// broadcast. The session's database handle (`db`) stays held until the function returns.
async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
    let sender_id = session.connection_id;
    let project_id = ProjectId::from_proto(request.project_id);
    let db = session.db().await;

    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
    tracing::info!(
        %project_id,
        "leave project"
    );

    project_left(project, &session);
    if let Some(room) = room {
        room_updated(room, &session.peer);
    }

    Ok(())
}
1935
1936/// Updates other participants with changes to the project
1937async fn update_project(
1938 request: proto::UpdateProject,
1939 response: Response<proto::UpdateProject>,
1940 session: Session,
1941) -> Result<()> {
1942 let project_id = ProjectId::from_proto(request.project_id);
1943 let (room, guest_connection_ids) = &*session
1944 .db()
1945 .await
1946 .update_project(project_id, session.connection_id, &request.worktrees)
1947 .await?;
1948 broadcast(
1949 Some(session.connection_id),
1950 guest_connection_ids.iter().copied(),
1951 |connection_id| {
1952 session
1953 .peer
1954 .forward_send(session.connection_id, connection_id, request.clone())
1955 },
1956 );
1957 if let Some(room) = room {
1958 room_updated(room, &session.peer);
1959 }
1960 response.send(proto::Ack {})?;
1961
1962 Ok(())
1963}
1964
1965/// Updates other participants with changes to the worktree
1966async fn update_worktree(
1967 request: proto::UpdateWorktree,
1968 response: Response<proto::UpdateWorktree>,
1969 session: Session,
1970) -> Result<()> {
1971 let guest_connection_ids = session
1972 .db()
1973 .await
1974 .update_worktree(&request, session.connection_id)
1975 .await?;
1976
1977 broadcast(
1978 Some(session.connection_id),
1979 guest_connection_ids.iter().copied(),
1980 |connection_id| {
1981 session
1982 .peer
1983 .forward_send(session.connection_id, connection_id, request.clone())
1984 },
1985 );
1986 response.send(proto::Ack {})?;
1987 Ok(())
1988}
1989
1990/// Updates other participants with changes to the diagnostics
1991async fn update_diagnostic_summary(
1992 message: proto::UpdateDiagnosticSummary,
1993 session: Session,
1994) -> Result<()> {
1995 let guest_connection_ids = session
1996 .db()
1997 .await
1998 .update_diagnostic_summary(&message, session.connection_id)
1999 .await?;
2000
2001 broadcast(
2002 Some(session.connection_id),
2003 guest_connection_ids.iter().copied(),
2004 |connection_id| {
2005 session
2006 .peer
2007 .forward_send(session.connection_id, connection_id, message.clone())
2008 },
2009 );
2010
2011 Ok(())
2012}
2013
2014/// Updates other participants with changes to the worktree settings
2015async fn update_worktree_settings(
2016 message: proto::UpdateWorktreeSettings,
2017 session: Session,
2018) -> Result<()> {
2019 let guest_connection_ids = session
2020 .db()
2021 .await
2022 .update_worktree_settings(&message, session.connection_id)
2023 .await?;
2024
2025 broadcast(
2026 Some(session.connection_id),
2027 guest_connection_ids.iter().copied(),
2028 |connection_id| {
2029 session
2030 .peer
2031 .forward_send(session.connection_id, connection_id, message.clone())
2032 },
2033 );
2034
2035 Ok(())
2036}
2037
2038/// Notify other participants that a language server has started.
2039async fn start_language_server(
2040 request: proto::StartLanguageServer,
2041 session: Session,
2042) -> Result<()> {
2043 let guest_connection_ids = session
2044 .db()
2045 .await
2046 .start_language_server(&request, session.connection_id)
2047 .await?;
2048
2049 broadcast(
2050 Some(session.connection_id),
2051 guest_connection_ids.iter().copied(),
2052 |connection_id| {
2053 session
2054 .peer
2055 .forward_send(session.connection_id, connection_id, request.clone())
2056 },
2057 );
2058 Ok(())
2059}
2060
2061/// Notify other participants that a language server has changed.
2062async fn update_language_server(
2063 request: proto::UpdateLanguageServer,
2064 session: Session,
2065) -> Result<()> {
2066 let project_id = ProjectId::from_proto(request.project_id);
2067 let project_connection_ids = session
2068 .db()
2069 .await
2070 .project_connection_ids(project_id, session.connection_id, true)
2071 .await?;
2072 broadcast(
2073 Some(session.connection_id),
2074 project_connection_ids.iter().copied(),
2075 |connection_id| {
2076 session
2077 .peer
2078 .forward_send(session.connection_id, connection_id, request.clone())
2079 },
2080 );
2081 Ok(())
2082}
2083
2084/// forward a project request to the host. These requests should be read only
2085/// as guests are allowed to send them.
2086async fn forward_read_only_project_request<T>(
2087 request: T,
2088 response: Response<T>,
2089 session: Session,
2090) -> Result<()>
2091where
2092 T: EntityMessage + RequestMessage,
2093{
2094 let project_id = ProjectId::from_proto(request.remote_entity_id());
2095 let host_connection_id = session
2096 .db()
2097 .await
2098 .host_for_read_only_project_request(project_id, session.connection_id)
2099 .await?;
2100 let payload = session
2101 .peer
2102 .forward_request(session.connection_id, host_connection_id, request)
2103 .await?;
2104 response.send(payload)?;
2105 Ok(())
2106}
2107
2108async fn forward_find_search_candidates_request(
2109 request: proto::FindSearchCandidates,
2110 response: Response<proto::FindSearchCandidates>,
2111 session: Session,
2112) -> Result<()> {
2113 let project_id = ProjectId::from_proto(request.remote_entity_id());
2114 let host_connection_id = session
2115 .db()
2116 .await
2117 .host_for_read_only_project_request(project_id, session.connection_id)
2118 .await?;
2119 let payload = session
2120 .peer
2121 .forward_request(session.connection_id, host_connection_id, request)
2122 .await?;
2123 response.send(payload)?;
2124 Ok(())
2125}
2126
2127/// forward a project request to the host. These requests are disallowed
2128/// for guests.
2129async fn forward_mutating_project_request<T>(
2130 request: T,
2131 response: Response<T>,
2132 session: Session,
2133) -> Result<()>
2134where
2135 T: EntityMessage + RequestMessage,
2136{
2137 let project_id = ProjectId::from_proto(request.remote_entity_id());
2138
2139 let host_connection_id = session
2140 .db()
2141 .await
2142 .host_for_mutating_project_request(project_id, session.connection_id)
2143 .await?;
2144 let payload = session
2145 .peer
2146 .forward_request(session.connection_id, host_connection_id, request)
2147 .await?;
2148 response.send(payload)?;
2149 Ok(())
2150}
2151
2152/// Notify other participants that a new buffer has been created
2153async fn create_buffer_for_peer(
2154 request: proto::CreateBufferForPeer,
2155 session: Session,
2156) -> Result<()> {
2157 session
2158 .db()
2159 .await
2160 .check_user_is_project_host(
2161 ProjectId::from_proto(request.project_id),
2162 session.connection_id,
2163 )
2164 .await?;
2165 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2166 session
2167 .peer
2168 .forward_send(session.connection_id, peer_id.into(), request)?;
2169 Ok(())
2170}
2171
2172/// Notify other participants that a buffer has been updated. This is
2173/// allowed for guests as long as the update is limited to selections.
2174async fn update_buffer(
2175 request: proto::UpdateBuffer,
2176 response: Response<proto::UpdateBuffer>,
2177 session: Session,
2178) -> Result<()> {
2179 let project_id = ProjectId::from_proto(request.project_id);
2180 let mut capability = Capability::ReadOnly;
2181
2182 for op in request.operations.iter() {
2183 match op.variant {
2184 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2185 Some(_) => capability = Capability::ReadWrite,
2186 }
2187 }
2188
2189 let host = {
2190 let guard = session
2191 .db()
2192 .await
2193 .connections_for_buffer_update(project_id, session.connection_id, capability)
2194 .await?;
2195
2196 let (host, guests) = &*guard;
2197
2198 broadcast(
2199 Some(session.connection_id),
2200 guests.clone(),
2201 |connection_id| {
2202 session
2203 .peer
2204 .forward_send(session.connection_id, connection_id, request.clone())
2205 },
2206 );
2207
2208 *host
2209 };
2210
2211 if host != session.connection_id {
2212 session
2213 .peer
2214 .forward_request(session.connection_id, host, request.clone())
2215 .await?;
2216 }
2217
2218 response.send(proto::Ack {})?;
2219 Ok(())
2220}
2221
2222async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2223 let project_id = ProjectId::from_proto(message.project_id);
2224
2225 let operation = message.operation.as_ref().context("invalid operation")?;
2226 let capability = match operation.variant.as_ref() {
2227 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2228 if let Some(buffer_op) = buffer_op.operation.as_ref() {
2229 match buffer_op.variant {
2230 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2231 Capability::ReadOnly
2232 }
2233 _ => Capability::ReadWrite,
2234 }
2235 } else {
2236 Capability::ReadWrite
2237 }
2238 }
2239 Some(_) => Capability::ReadWrite,
2240 None => Capability::ReadOnly,
2241 };
2242
2243 let guard = session
2244 .db()
2245 .await
2246 .connections_for_buffer_update(project_id, session.connection_id, capability)
2247 .await?;
2248
2249 let (host, guests) = &*guard;
2250
2251 broadcast(
2252 Some(session.connection_id),
2253 guests.iter().chain([host]).copied(),
2254 |connection_id| {
2255 session
2256 .peer
2257 .forward_send(session.connection_id, connection_id, message.clone())
2258 },
2259 );
2260
2261 Ok(())
2262}
2263
2264/// Notify other participants that a project has been updated.
2265async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2266 request: T,
2267 session: Session,
2268) -> Result<()> {
2269 let project_id = ProjectId::from_proto(request.remote_entity_id());
2270 let project_connection_ids = session
2271 .db()
2272 .await
2273 .project_connection_ids(project_id, session.connection_id, false)
2274 .await?;
2275
2276 broadcast(
2277 Some(session.connection_id),
2278 project_connection_ids.iter().copied(),
2279 |connection_id| {
2280 session
2281 .peer
2282 .forward_send(session.connection_id, connection_id, request.clone())
2283 },
2284 );
2285 Ok(())
2286}
2287
2288/// Start following another user in a call.
2289async fn follow(
2290 request: proto::Follow,
2291 response: Response<proto::Follow>,
2292 session: Session,
2293) -> Result<()> {
2294 let room_id = RoomId::from_proto(request.room_id);
2295 let project_id = request.project_id.map(ProjectId::from_proto);
2296 let leader_id = request
2297 .leader_id
2298 .ok_or_else(|| anyhow!("invalid leader id"))?
2299 .into();
2300 let follower_id = session.connection_id;
2301
2302 session
2303 .db()
2304 .await
2305 .check_room_participants(room_id, leader_id, session.connection_id)
2306 .await?;
2307
2308 let response_payload = session
2309 .peer
2310 .forward_request(session.connection_id, leader_id, request)
2311 .await?;
2312 response.send(response_payload)?;
2313
2314 if let Some(project_id) = project_id {
2315 let room = session
2316 .db()
2317 .await
2318 .follow(room_id, project_id, leader_id, follower_id)
2319 .await?;
2320 room_updated(&room, &session.peer);
2321 }
2322
2323 Ok(())
2324}
2325
2326/// Stop following another user in a call.
2327async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2328 let room_id = RoomId::from_proto(request.room_id);
2329 let project_id = request.project_id.map(ProjectId::from_proto);
2330 let leader_id = request
2331 .leader_id
2332 .ok_or_else(|| anyhow!("invalid leader id"))?
2333 .into();
2334 let follower_id = session.connection_id;
2335
2336 session
2337 .db()
2338 .await
2339 .check_room_participants(room_id, leader_id, session.connection_id)
2340 .await?;
2341
2342 session
2343 .peer
2344 .forward_send(session.connection_id, leader_id, request)?;
2345
2346 if let Some(project_id) = project_id {
2347 let room = session
2348 .db()
2349 .await
2350 .unfollow(room_id, project_id, leader_id, follower_id)
2351 .await?;
2352 room_updated(&room, &session.peer);
2353 }
2354
2355 Ok(())
2356}
2357
2358/// Notify everyone following you of your current location.
2359async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2360 let room_id = RoomId::from_proto(request.room_id);
2361 let database = session.db.lock().await;
2362
2363 let connection_ids = if let Some(project_id) = request.project_id {
2364 let project_id = ProjectId::from_proto(project_id);
2365 database
2366 .project_connection_ids(project_id, session.connection_id, true)
2367 .await?
2368 } else {
2369 database
2370 .room_connection_ids(room_id, session.connection_id)
2371 .await?
2372 };
2373
2374 // For now, don't send view update messages back to that view's current leader.
2375 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2376 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2377 _ => None,
2378 });
2379
2380 for connection_id in connection_ids.iter().cloned() {
2381 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2382 session
2383 .peer
2384 .forward_send(session.connection_id, connection_id, request.clone())?;
2385 }
2386 }
2387 Ok(())
2388}
2389
2390/// Get public data about users.
2391async fn get_users(
2392 request: proto::GetUsers,
2393 response: Response<proto::GetUsers>,
2394 session: Session,
2395) -> Result<()> {
2396 let user_ids = request
2397 .user_ids
2398 .into_iter()
2399 .map(UserId::from_proto)
2400 .collect();
2401 let users = session
2402 .db()
2403 .await
2404 .get_users_by_ids(user_ids)
2405 .await?
2406 .into_iter()
2407 .map(|user| proto::User {
2408 id: user.id.to_proto(),
2409 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2410 github_login: user.github_login,
2411 })
2412 .collect();
2413 response.send(proto::UsersResponse { users })?;
2414 Ok(())
2415}
2416
2417/// Search for users (to invite) buy Github login
2418async fn fuzzy_search_users(
2419 request: proto::FuzzySearchUsers,
2420 response: Response<proto::FuzzySearchUsers>,
2421 session: Session,
2422) -> Result<()> {
2423 let query = request.query;
2424 let users = match query.len() {
2425 0 => vec![],
2426 1 | 2 => session
2427 .db()
2428 .await
2429 .get_user_by_github_login(&query)
2430 .await?
2431 .into_iter()
2432 .collect(),
2433 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2434 };
2435 let users = users
2436 .into_iter()
2437 .filter(|user| user.id != session.user_id())
2438 .map(|user| proto::User {
2439 id: user.id.to_proto(),
2440 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2441 github_login: user.github_login,
2442 })
2443 .collect();
2444 response.send(proto::UsersResponse { users })?;
2445 Ok(())
2446}
2447
2448/// Send a contact request to another user.
2449async fn request_contact(
2450 request: proto::RequestContact,
2451 response: Response<proto::RequestContact>,
2452 session: Session,
2453) -> Result<()> {
2454 let requester_id = session.user_id();
2455 let responder_id = UserId::from_proto(request.responder_id);
2456 if requester_id == responder_id {
2457 return Err(anyhow!("cannot add yourself as a contact"))?;
2458 }
2459
2460 let notifications = session
2461 .db()
2462 .await
2463 .send_contact_request(requester_id, responder_id)
2464 .await?;
2465
2466 // Update outgoing contact requests of requester
2467 let mut update = proto::UpdateContacts::default();
2468 update.outgoing_requests.push(responder_id.to_proto());
2469 for connection_id in session
2470 .connection_pool()
2471 .await
2472 .user_connection_ids(requester_id)
2473 {
2474 session.peer.send(connection_id, update.clone())?;
2475 }
2476
2477 // Update incoming contact requests of responder
2478 let mut update = proto::UpdateContacts::default();
2479 update
2480 .incoming_requests
2481 .push(proto::IncomingContactRequest {
2482 requester_id: requester_id.to_proto(),
2483 });
2484 let connection_pool = session.connection_pool().await;
2485 for connection_id in connection_pool.user_connection_ids(responder_id) {
2486 session.peer.send(connection_id, update.clone())?;
2487 }
2488
2489 send_notifications(&connection_pool, &session.peer, notifications);
2490
2491 response.send(proto::Ack {})?;
2492 Ok(())
2493}
2494
2495/// Accept or decline a contact request
2496async fn respond_to_contact_request(
2497 request: proto::RespondToContactRequest,
2498 response: Response<proto::RespondToContactRequest>,
2499 session: Session,
2500) -> Result<()> {
2501 let responder_id = session.user_id();
2502 let requester_id = UserId::from_proto(request.requester_id);
2503 let db = session.db().await;
2504 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2505 db.dismiss_contact_notification(responder_id, requester_id)
2506 .await?;
2507 } else {
2508 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2509
2510 let notifications = db
2511 .respond_to_contact_request(responder_id, requester_id, accept)
2512 .await?;
2513 let requester_busy = db.is_user_busy(requester_id).await?;
2514 let responder_busy = db.is_user_busy(responder_id).await?;
2515
2516 let pool = session.connection_pool().await;
2517 // Update responder with new contact
2518 let mut update = proto::UpdateContacts::default();
2519 if accept {
2520 update
2521 .contacts
2522 .push(contact_for_user(requester_id, requester_busy, &pool));
2523 }
2524 update
2525 .remove_incoming_requests
2526 .push(requester_id.to_proto());
2527 for connection_id in pool.user_connection_ids(responder_id) {
2528 session.peer.send(connection_id, update.clone())?;
2529 }
2530
2531 // Update requester with new contact
2532 let mut update = proto::UpdateContacts::default();
2533 if accept {
2534 update
2535 .contacts
2536 .push(contact_for_user(responder_id, responder_busy, &pool));
2537 }
2538 update
2539 .remove_outgoing_requests
2540 .push(responder_id.to_proto());
2541
2542 for connection_id in pool.user_connection_ids(requester_id) {
2543 session.peer.send(connection_id, update.clone())?;
2544 }
2545
2546 send_notifications(&pool, &session.peer, notifications);
2547 }
2548
2549 response.send(proto::Ack {})?;
2550 Ok(())
2551}
2552
2553/// Remove a contact.
2554async fn remove_contact(
2555 request: proto::RemoveContact,
2556 response: Response<proto::RemoveContact>,
2557 session: Session,
2558) -> Result<()> {
2559 let requester_id = session.user_id();
2560 let responder_id = UserId::from_proto(request.user_id);
2561 let db = session.db().await;
2562 let (contact_accepted, deleted_notification_id) =
2563 db.remove_contact(requester_id, responder_id).await?;
2564
2565 let pool = session.connection_pool().await;
2566 // Update outgoing contact requests of requester
2567 let mut update = proto::UpdateContacts::default();
2568 if contact_accepted {
2569 update.remove_contacts.push(responder_id.to_proto());
2570 } else {
2571 update
2572 .remove_outgoing_requests
2573 .push(responder_id.to_proto());
2574 }
2575 for connection_id in pool.user_connection_ids(requester_id) {
2576 session.peer.send(connection_id, update.clone())?;
2577 }
2578
2579 // Update incoming contact requests of responder
2580 let mut update = proto::UpdateContacts::default();
2581 if contact_accepted {
2582 update.remove_contacts.push(requester_id.to_proto());
2583 } else {
2584 update
2585 .remove_incoming_requests
2586 .push(requester_id.to_proto());
2587 }
2588 for connection_id in pool.user_connection_ids(responder_id) {
2589 session.peer.send(connection_id, update.clone())?;
2590 if let Some(notification_id) = deleted_notification_id {
2591 session.peer.send(
2592 connection_id,
2593 proto::DeleteNotification {
2594 notification_id: notification_id.to_proto(),
2595 },
2596 )?;
2597 }
2598 }
2599
2600 response.send(proto::Ack {})?;
2601 Ok(())
2602}
2603
2604fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2605 version.0.minor() < 139
2606}
2607
2608async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
2609 let plan = session.current_plan(&session.db().await).await?;
2610
2611 session
2612 .peer
2613 .send(
2614 session.connection_id,
2615 proto::UpdateUserPlan { plan: plan.into() },
2616 )
2617 .trace_err();
2618
2619 Ok(())
2620}
2621
2622async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2623 subscribe_user_to_channels(session.user_id(), &session).await?;
2624 Ok(())
2625}
2626
2627async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2628 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2629 let mut pool = session.connection_pool().await;
2630 for membership in &channels_for_user.channel_memberships {
2631 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2632 }
2633 session.peer.send(
2634 session.connection_id,
2635 build_update_user_channels(&channels_for_user),
2636 )?;
2637 session.peer.send(
2638 session.connection_id,
2639 build_channels_update(channels_for_user),
2640 )?;
2641 Ok(())
2642}
2643
2644/// Creates a new channel.
2645async fn create_channel(
2646 request: proto::CreateChannel,
2647 response: Response<proto::CreateChannel>,
2648 session: Session,
2649) -> Result<()> {
2650 let db = session.db().await;
2651
2652 let parent_id = request.parent_id.map(ChannelId::from_proto);
2653 let (channel, membership) = db
2654 .create_channel(&request.name, parent_id, session.user_id())
2655 .await?;
2656
2657 let root_id = channel.root_id();
2658 let channel = Channel::from_model(channel);
2659
2660 response.send(proto::CreateChannelResponse {
2661 channel: Some(channel.to_proto()),
2662 parent_id: request.parent_id,
2663 })?;
2664
2665 let mut connection_pool = session.connection_pool().await;
2666 if let Some(membership) = membership {
2667 connection_pool.subscribe_to_channel(
2668 membership.user_id,
2669 membership.channel_id,
2670 membership.role,
2671 );
2672 let update = proto::UpdateUserChannels {
2673 channel_memberships: vec![proto::ChannelMembership {
2674 channel_id: membership.channel_id.to_proto(),
2675 role: membership.role.into(),
2676 }],
2677 ..Default::default()
2678 };
2679 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2680 session.peer.send(connection_id, update.clone())?;
2681 }
2682 }
2683
2684 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2685 if !role.can_see_channel(channel.visibility) {
2686 continue;
2687 }
2688
2689 let update = proto::UpdateChannels {
2690 channels: vec![channel.to_proto()],
2691 ..Default::default()
2692 };
2693 session.peer.send(connection_id, update.clone())?;
2694 }
2695
2696 Ok(())
2697}
2698
2699/// Delete a channel
2700async fn delete_channel(
2701 request: proto::DeleteChannel,
2702 response: Response<proto::DeleteChannel>,
2703 session: Session,
2704) -> Result<()> {
2705 let db = session.db().await;
2706
2707 let channel_id = request.channel_id;
2708 let (root_channel, removed_channels) = db
2709 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2710 .await?;
2711 response.send(proto::Ack {})?;
2712
2713 // Notify members of removed channels
2714 let mut update = proto::UpdateChannels::default();
2715 update
2716 .delete_channels
2717 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2718
2719 let connection_pool = session.connection_pool().await;
2720 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2721 session.peer.send(connection_id, update.clone())?;
2722 }
2723
2724 Ok(())
2725}
2726
2727/// Invite someone to join a channel.
2728async fn invite_channel_member(
2729 request: proto::InviteChannelMember,
2730 response: Response<proto::InviteChannelMember>,
2731 session: Session,
2732) -> Result<()> {
2733 let db = session.db().await;
2734 let channel_id = ChannelId::from_proto(request.channel_id);
2735 let invitee_id = UserId::from_proto(request.user_id);
2736 let InviteMemberResult {
2737 channel,
2738 notifications,
2739 } = db
2740 .invite_channel_member(
2741 channel_id,
2742 invitee_id,
2743 session.user_id(),
2744 request.role().into(),
2745 )
2746 .await?;
2747
2748 let update = proto::UpdateChannels {
2749 channel_invitations: vec![channel.to_proto()],
2750 ..Default::default()
2751 };
2752
2753 let connection_pool = session.connection_pool().await;
2754 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2755 session.peer.send(connection_id, update.clone())?;
2756 }
2757
2758 send_notifications(&connection_pool, &session.peer, notifications);
2759
2760 response.send(proto::Ack {})?;
2761 Ok(())
2762}
2763
2764/// remove someone from a channel
2765async fn remove_channel_member(
2766 request: proto::RemoveChannelMember,
2767 response: Response<proto::RemoveChannelMember>,
2768 session: Session,
2769) -> Result<()> {
2770 let db = session.db().await;
2771 let channel_id = ChannelId::from_proto(request.channel_id);
2772 let member_id = UserId::from_proto(request.user_id);
2773
2774 let RemoveChannelMemberResult {
2775 membership_update,
2776 notification_id,
2777 } = db
2778 .remove_channel_member(channel_id, member_id, session.user_id())
2779 .await?;
2780
2781 let mut connection_pool = session.connection_pool().await;
2782 notify_membership_updated(
2783 &mut connection_pool,
2784 membership_update,
2785 member_id,
2786 &session.peer,
2787 );
2788 for connection_id in connection_pool.user_connection_ids(member_id) {
2789 if let Some(notification_id) = notification_id {
2790 session
2791 .peer
2792 .send(
2793 connection_id,
2794 proto::DeleteNotification {
2795 notification_id: notification_id.to_proto(),
2796 },
2797 )
2798 .trace_err();
2799 }
2800 }
2801
2802 response.send(proto::Ack {})?;
2803 Ok(())
2804}
2805
2806/// Toggle the channel between public and private.
2807/// Care is taken to maintain the invariant that public channels only descend from public channels,
2808/// (though members-only channels can appear at any point in the hierarchy).
2809async fn set_channel_visibility(
2810 request: proto::SetChannelVisibility,
2811 response: Response<proto::SetChannelVisibility>,
2812 session: Session,
2813) -> Result<()> {
2814 let db = session.db().await;
2815 let channel_id = ChannelId::from_proto(request.channel_id);
2816 let visibility = request.visibility().into();
2817
2818 let channel_model = db
2819 .set_channel_visibility(channel_id, visibility, session.user_id())
2820 .await?;
2821 let root_id = channel_model.root_id();
2822 let channel = Channel::from_model(channel_model);
2823
2824 let mut connection_pool = session.connection_pool().await;
2825 for (user_id, role) in connection_pool
2826 .channel_user_ids(root_id)
2827 .collect::<Vec<_>>()
2828 .into_iter()
2829 {
2830 let update = if role.can_see_channel(channel.visibility) {
2831 connection_pool.subscribe_to_channel(user_id, channel_id, role);
2832 proto::UpdateChannels {
2833 channels: vec![channel.to_proto()],
2834 ..Default::default()
2835 }
2836 } else {
2837 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2838 proto::UpdateChannels {
2839 delete_channels: vec![channel.id.to_proto()],
2840 ..Default::default()
2841 }
2842 };
2843
2844 for connection_id in connection_pool.user_connection_ids(user_id) {
2845 session.peer.send(connection_id, update.clone())?;
2846 }
2847 }
2848
2849 response.send(proto::Ack {})?;
2850 Ok(())
2851}
2852
2853/// Alter the role for a user in the channel.
2854async fn set_channel_member_role(
2855 request: proto::SetChannelMemberRole,
2856 response: Response<proto::SetChannelMemberRole>,
2857 session: Session,
2858) -> Result<()> {
2859 let db = session.db().await;
2860 let channel_id = ChannelId::from_proto(request.channel_id);
2861 let member_id = UserId::from_proto(request.user_id);
2862 let result = db
2863 .set_channel_member_role(
2864 channel_id,
2865 session.user_id(),
2866 member_id,
2867 request.role().into(),
2868 )
2869 .await?;
2870
2871 match result {
2872 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2873 let mut connection_pool = session.connection_pool().await;
2874 notify_membership_updated(
2875 &mut connection_pool,
2876 membership_update,
2877 member_id,
2878 &session.peer,
2879 )
2880 }
2881 db::SetMemberRoleResult::InviteUpdated(channel) => {
2882 let update = proto::UpdateChannels {
2883 channel_invitations: vec![channel.to_proto()],
2884 ..Default::default()
2885 };
2886
2887 for connection_id in session
2888 .connection_pool()
2889 .await
2890 .user_connection_ids(member_id)
2891 {
2892 session.peer.send(connection_id, update.clone())?;
2893 }
2894 }
2895 }
2896
2897 response.send(proto::Ack {})?;
2898 Ok(())
2899}
2900
2901/// Change the name of a channel
2902async fn rename_channel(
2903 request: proto::RenameChannel,
2904 response: Response<proto::RenameChannel>,
2905 session: Session,
2906) -> Result<()> {
2907 let db = session.db().await;
2908 let channel_id = ChannelId::from_proto(request.channel_id);
2909 let channel_model = db
2910 .rename_channel(channel_id, session.user_id(), &request.name)
2911 .await?;
2912 let root_id = channel_model.root_id();
2913 let channel = Channel::from_model(channel_model);
2914
2915 response.send(proto::RenameChannelResponse {
2916 channel: Some(channel.to_proto()),
2917 })?;
2918
2919 let connection_pool = session.connection_pool().await;
2920 let update = proto::UpdateChannels {
2921 channels: vec![channel.to_proto()],
2922 ..Default::default()
2923 };
2924 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2925 if role.can_see_channel(channel.visibility) {
2926 session.peer.send(connection_id, update.clone())?;
2927 }
2928 }
2929
2930 Ok(())
2931}
2932
2933/// Move a channel to a new parent.
2934async fn move_channel(
2935 request: proto::MoveChannel,
2936 response: Response<proto::MoveChannel>,
2937 session: Session,
2938) -> Result<()> {
2939 let channel_id = ChannelId::from_proto(request.channel_id);
2940 let to = ChannelId::from_proto(request.to);
2941
2942 let (root_id, channels) = session
2943 .db()
2944 .await
2945 .move_channel(channel_id, to, session.user_id())
2946 .await?;
2947
2948 let connection_pool = session.connection_pool().await;
2949 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2950 let channels = channels
2951 .iter()
2952 .filter_map(|channel| {
2953 if role.can_see_channel(channel.visibility) {
2954 Some(channel.to_proto())
2955 } else {
2956 None
2957 }
2958 })
2959 .collect::<Vec<_>>();
2960 if channels.is_empty() {
2961 continue;
2962 }
2963
2964 let update = proto::UpdateChannels {
2965 channels,
2966 ..Default::default()
2967 };
2968
2969 session.peer.send(connection_id, update.clone())?;
2970 }
2971
2972 response.send(Ack {})?;
2973 Ok(())
2974}
2975
2976/// Get the list of channel members
2977async fn get_channel_members(
2978 request: proto::GetChannelMembers,
2979 response: Response<proto::GetChannelMembers>,
2980 session: Session,
2981) -> Result<()> {
2982 let db = session.db().await;
2983 let channel_id = ChannelId::from_proto(request.channel_id);
2984 let limit = if request.limit == 0 {
2985 u16::MAX as u64
2986 } else {
2987 request.limit
2988 };
2989 let (members, users) = db
2990 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
2991 .await?;
2992 response.send(proto::GetChannelMembersResponse { members, users })?;
2993 Ok(())
2994}
2995
2996/// Accept or decline a channel invitation.
2997async fn respond_to_channel_invite(
2998 request: proto::RespondToChannelInvite,
2999 response: Response<proto::RespondToChannelInvite>,
3000 session: Session,
3001) -> Result<()> {
3002 let db = session.db().await;
3003 let channel_id = ChannelId::from_proto(request.channel_id);
3004 let RespondToChannelInvite {
3005 membership_update,
3006 notifications,
3007 } = db
3008 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3009 .await?;
3010
3011 let mut connection_pool = session.connection_pool().await;
3012 if let Some(membership_update) = membership_update {
3013 notify_membership_updated(
3014 &mut connection_pool,
3015 membership_update,
3016 session.user_id(),
3017 &session.peer,
3018 );
3019 } else {
3020 let update = proto::UpdateChannels {
3021 remove_channel_invitations: vec![channel_id.to_proto()],
3022 ..Default::default()
3023 };
3024
3025 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3026 session.peer.send(connection_id, update.clone())?;
3027 }
3028 };
3029
3030 send_notifications(&connection_pool, &session.peer, notifications);
3031
3032 response.send(proto::Ack {})?;
3033
3034 Ok(())
3035}
3036
3037/// Join the channels' room
3038async fn join_channel(
3039 request: proto::JoinChannel,
3040 response: Response<proto::JoinChannel>,
3041 session: Session,
3042) -> Result<()> {
3043 let channel_id = ChannelId::from_proto(request.channel_id);
3044 join_channel_internal(channel_id, Box::new(response), session).await
3045}
3046
3047trait JoinChannelInternalResponse {
3048 fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
3049}
3050impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
3051 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3052 Response::<proto::JoinChannel>::send(self, result)
3053 }
3054}
3055impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
3056 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3057 Response::<proto::JoinRoom>::send(self, result)
3058 }
3059}
3060
/// Shared implementation for joining a channel's call room.
///
/// `JoinChannelInternalResponse` is implemented for both
/// `Response<proto::JoinChannel>` and `Response<proto::JoinRoom>`, so either
/// request's reply handle can be passed here.
///
/// Steps:
/// 1. If the database reports a stale room connection for this user, leave
///    that room first.
/// 2. Join the channel in the database. This yields the room, an optional
///    membership update, and the user's role.
/// 3. If a LiveKit client is configured, create a token for the room. Guests
///    get a guest token with `can_publish: false`; everyone else gets a room
///    token with `can_publish: true`.
/// 4. Reply with the room, the channel id, and the LiveKit info (if any).
/// 5. Propagate any membership update, call `room_updated`, call
///    `channel_updated`, and finally `update_user_contacts`.
///
/// # Errors
/// Returns an error if any database call fails, if leaving the stale room
/// fails, if the reply cannot be sent, or if the joined room comes back
/// without a channel ("channel not returned").
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // The database guard and the connection-pool guard acquired in this block
    // are dropped at its end, before the connection pool is locked again for
    // `channel_updated` below.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database lock while leaving the stale room and
            // re-acquire it afterwards. NOTE(review): presumably
            // `leave_room_for_session` takes the lock itself — confirm.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured, or when token creation
        // fails (`trace_err` logs the error and turns it into `None` here).
        let live_kit_connection_info =
            session
                .app_state
                .live_kit_client
                .as_ref()
                .and_then(|live_kit| {
                    // Guests get a guest token and may not publish; every
                    // other role gets a regular room token.
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.live_kit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.live_kit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // The channel is required here; its absence is reported as an error.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3156
3157/// Start editing the channel notes
3158async fn join_channel_buffer(
3159 request: proto::JoinChannelBuffer,
3160 response: Response<proto::JoinChannelBuffer>,
3161 session: Session,
3162) -> Result<()> {
3163 let db = session.db().await;
3164 let channel_id = ChannelId::from_proto(request.channel_id);
3165
3166 let open_response = db
3167 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3168 .await?;
3169
3170 let collaborators = open_response.collaborators.clone();
3171 response.send(open_response)?;
3172
3173 let update = UpdateChannelBufferCollaborators {
3174 channel_id: channel_id.to_proto(),
3175 collaborators: collaborators.clone(),
3176 };
3177 channel_buffer_updated(
3178 session.connection_id,
3179 collaborators
3180 .iter()
3181 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3182 &update,
3183 &session.peer,
3184 );
3185
3186 Ok(())
3187}
3188
3189/// Edit the channel notes
3190async fn update_channel_buffer(
3191 request: proto::UpdateChannelBuffer,
3192 session: Session,
3193) -> Result<()> {
3194 let db = session.db().await;
3195 let channel_id = ChannelId::from_proto(request.channel_id);
3196
3197 let (collaborators, epoch, version) = db
3198 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3199 .await?;
3200
3201 channel_buffer_updated(
3202 session.connection_id,
3203 collaborators.clone(),
3204 &proto::UpdateChannelBuffer {
3205 channel_id: channel_id.to_proto(),
3206 operations: request.operations,
3207 },
3208 &session.peer,
3209 );
3210
3211 let pool = &*session.connection_pool().await;
3212
3213 let non_collaborators =
3214 pool.channel_connection_ids(channel_id)
3215 .filter_map(|(connection_id, _)| {
3216 if collaborators.contains(&connection_id) {
3217 None
3218 } else {
3219 Some(connection_id)
3220 }
3221 });
3222
3223 broadcast(None, non_collaborators, |peer_id| {
3224 session.peer.send(
3225 peer_id,
3226 proto::UpdateChannels {
3227 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3228 channel_id: channel_id.to_proto(),
3229 epoch: epoch as u64,
3230 version: version.clone(),
3231 }],
3232 ..Default::default()
3233 },
3234 )
3235 });
3236
3237 Ok(())
3238}
3239
3240/// Rejoin the channel notes after a connection blip
3241async fn rejoin_channel_buffers(
3242 request: proto::RejoinChannelBuffers,
3243 response: Response<proto::RejoinChannelBuffers>,
3244 session: Session,
3245) -> Result<()> {
3246 let db = session.db().await;
3247 let buffers = db
3248 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3249 .await?;
3250
3251 for rejoined_buffer in &buffers {
3252 let collaborators_to_notify = rejoined_buffer
3253 .buffer
3254 .collaborators
3255 .iter()
3256 .filter_map(|c| Some(c.peer_id?.into()));
3257 channel_buffer_updated(
3258 session.connection_id,
3259 collaborators_to_notify,
3260 &proto::UpdateChannelBufferCollaborators {
3261 channel_id: rejoined_buffer.buffer.channel_id,
3262 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3263 },
3264 &session.peer,
3265 );
3266 }
3267
3268 response.send(proto::RejoinChannelBuffersResponse {
3269 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3270 })?;
3271
3272 Ok(())
3273}
3274
3275/// Stop editing the channel notes
3276async fn leave_channel_buffer(
3277 request: proto::LeaveChannelBuffer,
3278 response: Response<proto::LeaveChannelBuffer>,
3279 session: Session,
3280) -> Result<()> {
3281 let db = session.db().await;
3282 let channel_id = ChannelId::from_proto(request.channel_id);
3283
3284 let left_buffer = db
3285 .leave_channel_buffer(channel_id, session.connection_id)
3286 .await?;
3287
3288 response.send(Ack {})?;
3289
3290 channel_buffer_updated(
3291 session.connection_id,
3292 left_buffer.connections,
3293 &proto::UpdateChannelBufferCollaborators {
3294 channel_id: channel_id.to_proto(),
3295 collaborators: left_buffer.collaborators,
3296 },
3297 &session.peer,
3298 );
3299
3300 Ok(())
3301}
3302
3303fn channel_buffer_updated<T: EnvelopedMessage>(
3304 sender_id: ConnectionId,
3305 collaborators: impl IntoIterator<Item = ConnectionId>,
3306 message: &T,
3307 peer: &Peer,
3308) {
3309 broadcast(Some(sender_id), collaborators, |peer_id| {
3310 peer.send(peer_id, message.clone())
3311 });
3312}
3313
3314fn send_notifications(
3315 connection_pool: &ConnectionPool,
3316 peer: &Peer,
3317 notifications: db::NotificationBatch,
3318) {
3319 for (user_id, notification) in notifications {
3320 for connection_id in connection_pool.user_connection_ids(user_id) {
3321 if let Err(error) = peer.send(
3322 connection_id,
3323 proto::AddNotification {
3324 notification: Some(notification.clone()),
3325 },
3326 ) {
3327 tracing::error!(
3328 "failed to send notification to {:?} {}",
3329 connection_id,
3330 error
3331 );
3332 }
3333 }
3334 }
3335}
3336
/// Send a message to the channel.
///
/// Steps, in order:
/// 1. Trim the body. Reject it if it is longer than `MAX_MESSAGE_LEN` bytes or
///    empty after trimming, then reject a missing nonce.
/// 2. Persist the message with `create_channel_message`. That call also returns
///    the participant connections and a batch of notifications to deliver.
/// 3. Broadcast `ChannelMessageSent` to those participants (excluding the
///    sender) and reply to the sender with the stored message.
/// 4. Send the new latest message id, via `UpdateChannels`, to every other
///    connection subscribed to the channel that is not a participant.
/// 5. Deliver the notifications returned by the database.
async fn send_channel_message(
    request: proto::SendChannelMessage,
    response: Response<proto::SendChannelMessage>,
    session: Session,
) -> Result<()> {
    // Validate the message body. `len()` is a byte length, measured after trimming.
    let body = request.body.trim().to_string();
    if body.len() > MAX_MESSAGE_LEN {
        return Err(anyhow!("message is too long"))?;
    }
    if body.is_empty() {
        return Err(anyhow!("message can't be blank"))?;
    }

    // TODO: adjust mentions if body is trimmed
    // NOTE(review): `request.mentions` is stored and echoed unchanged below. If
    // mention ranges are offsets into the untrimmed body, trimming leading
    // whitespace shifts them. Confirm against the client.

    let timestamp = OffsetDateTime::now_utc();
    let nonce = request
        .nonce
        .ok_or_else(|| anyhow!("nonce can't be blank"))?;

    let channel_id = ChannelId::from_proto(request.channel_id);
    let CreatedChannelMessage {
        message_id,
        participant_connection_ids,
        notifications,
    } = session
        .db()
        .await
        .create_channel_message(
            channel_id,
            session.user_id(),
            &body,
            &request.mentions,
            timestamp,
            nonce.clone().into(),
            request.reply_to_message_id.map(MessageId::from_proto),
        )
        .await?;

    let message = proto::ChannelMessage {
        sender_id: session.user_id().to_proto(),
        id: message_id.to_proto(),
        body,
        mentions: request.mentions,
        timestamp: timestamp.unix_timestamp() as u64,
        nonce: Some(nonce),
        reply_to_message_id: request.reply_to_message_id,
        edited_at: None,
    };
    // Participants (other than the sender) receive the full message.
    broadcast(
        Some(session.connection_id),
        participant_connection_ids.clone(),
        |connection| {
            session.peer.send(
                connection,
                proto::ChannelMessageSent {
                    channel_id: channel_id.to_proto(),
                    message: Some(message.clone()),
                },
            )
        },
    );
    response.send(proto::SendChannelMessageResponse {
        message: Some(message),
    })?;

    // The pool guard stays alive for the rest of the function; it is also used
    // for delivering notifications below.
    let pool = &*session.connection_pool().await;
    // Channel subscribers that are not participants only get the latest message id.
    let non_participants =
        pool.channel_connection_ids(channel_id)
            .filter_map(|(connection_id, _)| {
                if participant_connection_ids.contains(&connection_id) {
                    None
                } else {
                    Some(connection_id)
                }
            });
    broadcast(None, non_participants, |peer_id| {
        session.peer.send(
            peer_id,
            proto::UpdateChannels {
                latest_channel_message_ids: vec![proto::ChannelMessageId {
                    channel_id: channel_id.to_proto(),
                    message_id: message_id.to_proto(),
                }],
                ..Default::default()
            },
        )
    });
    send_notifications(pool, &session.peer, notifications);

    Ok(())
}
3431
3432/// Delete a channel message
3433async fn remove_channel_message(
3434 request: proto::RemoveChannelMessage,
3435 response: Response<proto::RemoveChannelMessage>,
3436 session: Session,
3437) -> Result<()> {
3438 let channel_id = ChannelId::from_proto(request.channel_id);
3439 let message_id = MessageId::from_proto(request.message_id);
3440 let (connection_ids, existing_notification_ids) = session
3441 .db()
3442 .await
3443 .remove_channel_message(channel_id, message_id, session.user_id())
3444 .await?;
3445
3446 broadcast(
3447 Some(session.connection_id),
3448 connection_ids,
3449 move |connection| {
3450 session.peer.send(connection, request.clone())?;
3451
3452 for notification_id in &existing_notification_ids {
3453 session.peer.send(
3454 connection,
3455 proto::DeleteNotification {
3456 notification_id: (*notification_id).to_proto(),
3457 },
3458 )?;
3459 }
3460
3461 Ok(())
3462 },
3463 );
3464 response.send(proto::Ack {})?;
3465 Ok(())
3466}
3467
3468async fn update_channel_message(
3469 request: proto::UpdateChannelMessage,
3470 response: Response<proto::UpdateChannelMessage>,
3471 session: Session,
3472) -> Result<()> {
3473 let channel_id = ChannelId::from_proto(request.channel_id);
3474 let message_id = MessageId::from_proto(request.message_id);
3475 let updated_at = OffsetDateTime::now_utc();
3476 let UpdatedChannelMessage {
3477 message_id,
3478 participant_connection_ids,
3479 notifications,
3480 reply_to_message_id,
3481 timestamp,
3482 deleted_mention_notification_ids,
3483 updated_mention_notifications,
3484 } = session
3485 .db()
3486 .await
3487 .update_channel_message(
3488 channel_id,
3489 message_id,
3490 session.user_id(),
3491 request.body.as_str(),
3492 &request.mentions,
3493 updated_at,
3494 )
3495 .await?;
3496
3497 let nonce = request
3498 .nonce
3499 .clone()
3500 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3501
3502 let message = proto::ChannelMessage {
3503 sender_id: session.user_id().to_proto(),
3504 id: message_id.to_proto(),
3505 body: request.body.clone(),
3506 mentions: request.mentions.clone(),
3507 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3508 nonce: Some(nonce),
3509 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3510 edited_at: Some(updated_at.unix_timestamp() as u64),
3511 };
3512
3513 response.send(proto::Ack {})?;
3514
3515 let pool = &*session.connection_pool().await;
3516 broadcast(
3517 Some(session.connection_id),
3518 participant_connection_ids,
3519 |connection| {
3520 session.peer.send(
3521 connection,
3522 proto::ChannelMessageUpdate {
3523 channel_id: channel_id.to_proto(),
3524 message: Some(message.clone()),
3525 },
3526 )?;
3527
3528 for notification_id in &deleted_mention_notification_ids {
3529 session.peer.send(
3530 connection,
3531 proto::DeleteNotification {
3532 notification_id: (*notification_id).to_proto(),
3533 },
3534 )?;
3535 }
3536
3537 for notification in &updated_mention_notifications {
3538 session.peer.send(
3539 connection,
3540 proto::UpdateNotification {
3541 notification: Some(notification.clone()),
3542 },
3543 )?;
3544 }
3545
3546 Ok(())
3547 },
3548 );
3549
3550 send_notifications(pool, &session.peer, notifications);
3551
3552 Ok(())
3553}
3554
3555/// Mark a channel message as read
3556async fn acknowledge_channel_message(
3557 request: proto::AckChannelMessage,
3558 session: Session,
3559) -> Result<()> {
3560 let channel_id = ChannelId::from_proto(request.channel_id);
3561 let message_id = MessageId::from_proto(request.message_id);
3562 let notifications = session
3563 .db()
3564 .await
3565 .observe_channel_message(channel_id, session.user_id(), message_id)
3566 .await?;
3567 send_notifications(
3568 &*session.connection_pool().await,
3569 &session.peer,
3570 notifications,
3571 );
3572 Ok(())
3573}
3574
3575/// Mark a buffer version as synced
3576async fn acknowledge_buffer_version(
3577 request: proto::AckBufferOperation,
3578 session: Session,
3579) -> Result<()> {
3580 let buffer_id = BufferId::from_proto(request.buffer_id);
3581 session
3582 .db()
3583 .await
3584 .observe_buffer_version(
3585 buffer_id,
3586 session.user_id(),
3587 request.epoch as i32,
3588 &request.version,
3589 )
3590 .await?;
3591 Ok(())
3592}
3593
3594async fn count_language_model_tokens(
3595 request: proto::CountLanguageModelTokens,
3596 response: Response<proto::CountLanguageModelTokens>,
3597 session: Session,
3598 config: &Config,
3599) -> Result<()> {
3600 authorize_access_to_legacy_llm_endpoints(&session).await?;
3601
3602 let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3603 proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
3604 proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
3605 };
3606
3607 session
3608 .app_state
3609 .rate_limiter
3610 .check(&*rate_limit, session.user_id())
3611 .await?;
3612
3613 let result = match proto::LanguageModelProvider::from_i32(request.provider) {
3614 Some(proto::LanguageModelProvider::Google) => {
3615 let api_key = config
3616 .google_ai_api_key
3617 .as_ref()
3618 .context("no Google AI API key configured on the server")?;
3619 google_ai::count_tokens(
3620 session.http_client.as_ref(),
3621 google_ai::API_URL,
3622 api_key,
3623 serde_json::from_str(&request.request)?,
3624 )
3625 .await?
3626 }
3627 _ => return Err(anyhow!("unsupported provider"))?,
3628 };
3629
3630 response.send(proto::CountLanguageModelTokensResponse {
3631 token_count: result.total_tokens as u32,
3632 })?;
3633
3634 Ok(())
3635}
3636
3637struct ZedProCountLanguageModelTokensRateLimit;
3638
3639impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
3640 fn capacity(&self) -> usize {
3641 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
3642 .ok()
3643 .and_then(|v| v.parse().ok())
3644 .unwrap_or(600) // Picked arbitrarily
3645 }
3646
3647 fn refill_duration(&self) -> chrono::Duration {
3648 chrono::Duration::hours(1)
3649 }
3650
3651 fn db_name(&self) -> &'static str {
3652 "zed-pro:count-language-model-tokens"
3653 }
3654}
3655
3656struct FreeCountLanguageModelTokensRateLimit;
3657
3658impl RateLimit for FreeCountLanguageModelTokensRateLimit {
3659 fn capacity(&self) -> usize {
3660 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
3661 .ok()
3662 .and_then(|v| v.parse().ok())
3663 .unwrap_or(600 / 10) // Picked arbitrarily
3664 }
3665
3666 fn refill_duration(&self) -> chrono::Duration {
3667 chrono::Duration::hours(1)
3668 }
3669
3670 fn db_name(&self) -> &'static str {
3671 "free:count-language-model-tokens"
3672 }
3673}
3674
3675struct ZedProComputeEmbeddingsRateLimit;
3676
3677impl RateLimit for ZedProComputeEmbeddingsRateLimit {
3678 fn capacity(&self) -> usize {
3679 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
3680 .ok()
3681 .and_then(|v| v.parse().ok())
3682 .unwrap_or(5000) // Picked arbitrarily
3683 }
3684
3685 fn refill_duration(&self) -> chrono::Duration {
3686 chrono::Duration::hours(1)
3687 }
3688
3689 fn db_name(&self) -> &'static str {
3690 "zed-pro:compute-embeddings"
3691 }
3692}
3693
3694struct FreeComputeEmbeddingsRateLimit;
3695
3696impl RateLimit for FreeComputeEmbeddingsRateLimit {
3697 fn capacity(&self) -> usize {
3698 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
3699 .ok()
3700 .and_then(|v| v.parse().ok())
3701 .unwrap_or(5000 / 10) // Picked arbitrarily
3702 }
3703
3704 fn refill_duration(&self) -> chrono::Duration {
3705 chrono::Duration::hours(1)
3706 }
3707
3708 fn db_name(&self) -> &'static str {
3709 "free:compute-embeddings"
3710 }
3711}
3712
3713async fn compute_embeddings(
3714 request: proto::ComputeEmbeddings,
3715 response: Response<proto::ComputeEmbeddings>,
3716 session: Session,
3717 api_key: Option<Arc<str>>,
3718) -> Result<()> {
3719 let api_key = api_key.context("no OpenAI API key configured on the server")?;
3720 authorize_access_to_legacy_llm_endpoints(&session).await?;
3721
3722 let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3723 proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
3724 proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
3725 };
3726
3727 session
3728 .app_state
3729 .rate_limiter
3730 .check(&*rate_limit, session.user_id())
3731 .await?;
3732
3733 let embeddings = match request.model.as_str() {
3734 "openai/text-embedding-3-small" => {
3735 open_ai::embed(
3736 session.http_client.as_ref(),
3737 OPEN_AI_API_URL,
3738 &api_key,
3739 OpenAiEmbeddingModel::TextEmbedding3Small,
3740 request.texts.iter().map(|text| text.as_str()),
3741 )
3742 .await?
3743 }
3744 provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
3745 };
3746
3747 let embeddings = request
3748 .texts
3749 .iter()
3750 .map(|text| {
3751 let mut hasher = sha2::Sha256::new();
3752 hasher.update(text.as_bytes());
3753 let result = hasher.finalize();
3754 result.to_vec()
3755 })
3756 .zip(
3757 embeddings
3758 .data
3759 .into_iter()
3760 .map(|embedding| embedding.embedding),
3761 )
3762 .collect::<HashMap<_, _>>();
3763
3764 let db = session.db().await;
3765 db.save_embeddings(&request.model, &embeddings)
3766 .await
3767 .context("failed to save embeddings")
3768 .trace_err();
3769
3770 response.send(proto::ComputeEmbeddingsResponse {
3771 embeddings: embeddings
3772 .into_iter()
3773 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3774 .collect(),
3775 })?;
3776 Ok(())
3777}
3778
3779async fn get_cached_embeddings(
3780 request: proto::GetCachedEmbeddings,
3781 response: Response<proto::GetCachedEmbeddings>,
3782 session: Session,
3783) -> Result<()> {
3784 authorize_access_to_legacy_llm_endpoints(&session).await?;
3785
3786 let db = session.db().await;
3787 let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
3788
3789 response.send(proto::GetCachedEmbeddingsResponse {
3790 embeddings: embeddings
3791 .into_iter()
3792 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3793 .collect(),
3794 })?;
3795 Ok(())
3796}
3797
3798/// This is leftover from before the LLM service.
3799///
3800/// The endpoints protected by this check will be moved there eventually.
3801async fn authorize_access_to_legacy_llm_endpoints(session: &Session) -> Result<(), Error> {
3802 if session.is_staff() {
3803 Ok(())
3804 } else {
3805 Err(anyhow!("permission denied"))?
3806 }
3807}
3808
3809/// Get a Supermaven API key for the user
3810async fn get_supermaven_api_key(
3811 _request: proto::GetSupermavenApiKey,
3812 response: Response<proto::GetSupermavenApiKey>,
3813 session: Session,
3814) -> Result<()> {
3815 let user_id: String = session.user_id().to_string();
3816 if !session.is_staff() {
3817 return Err(anyhow!("supermaven not enabled for this account"))?;
3818 }
3819
3820 let email = session
3821 .email()
3822 .ok_or_else(|| anyhow!("user must have an email"))?;
3823
3824 let supermaven_admin_api = session
3825 .supermaven_client
3826 .as_ref()
3827 .ok_or_else(|| anyhow!("supermaven not configured"))?;
3828
3829 let result = supermaven_admin_api
3830 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3831 .await?;
3832
3833 response.send(proto::GetSupermavenApiKeyResponse {
3834 api_key: result.api_key,
3835 })?;
3836
3837 Ok(())
3838}
3839
3840/// Start receiving chat updates for a channel
3841async fn join_channel_chat(
3842 request: proto::JoinChannelChat,
3843 response: Response<proto::JoinChannelChat>,
3844 session: Session,
3845) -> Result<()> {
3846 let channel_id = ChannelId::from_proto(request.channel_id);
3847
3848 let db = session.db().await;
3849 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3850 .await?;
3851 let messages = db
3852 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3853 .await?;
3854 response.send(proto::JoinChannelChatResponse {
3855 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3856 messages,
3857 })?;
3858 Ok(())
3859}
3860
3861/// Stop receiving chat updates for a channel
3862async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3863 let channel_id = ChannelId::from_proto(request.channel_id);
3864 session
3865 .db()
3866 .await
3867 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3868 .await?;
3869 Ok(())
3870}
3871
3872/// Retrieve the chat history for a channel
3873async fn get_channel_messages(
3874 request: proto::GetChannelMessages,
3875 response: Response<proto::GetChannelMessages>,
3876 session: Session,
3877) -> Result<()> {
3878 let channel_id = ChannelId::from_proto(request.channel_id);
3879 let messages = session
3880 .db()
3881 .await
3882 .get_channel_messages(
3883 channel_id,
3884 session.user_id(),
3885 MESSAGE_COUNT_PER_PAGE,
3886 Some(MessageId::from_proto(request.before_message_id)),
3887 )
3888 .await?;
3889 response.send(proto::GetChannelMessagesResponse {
3890 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3891 messages,
3892 })?;
3893 Ok(())
3894}
3895
3896/// Retrieve specific chat messages
3897async fn get_channel_messages_by_id(
3898 request: proto::GetChannelMessagesById,
3899 response: Response<proto::GetChannelMessagesById>,
3900 session: Session,
3901) -> Result<()> {
3902 let message_ids = request
3903 .message_ids
3904 .iter()
3905 .map(|id| MessageId::from_proto(*id))
3906 .collect::<Vec<_>>();
3907 let messages = session
3908 .db()
3909 .await
3910 .get_channel_messages_by_id(session.user_id(), &message_ids)
3911 .await?;
3912 response.send(proto::GetChannelMessagesResponse {
3913 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3914 messages,
3915 })?;
3916 Ok(())
3917}
3918
3919/// Retrieve the current users notifications
3920async fn get_notifications(
3921 request: proto::GetNotifications,
3922 response: Response<proto::GetNotifications>,
3923 session: Session,
3924) -> Result<()> {
3925 let notifications = session
3926 .db()
3927 .await
3928 .get_notifications(
3929 session.user_id(),
3930 NOTIFICATION_COUNT_PER_PAGE,
3931 request.before_id.map(db::NotificationId::from_proto),
3932 )
3933 .await?;
3934 response.send(proto::GetNotificationsResponse {
3935 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3936 notifications,
3937 })?;
3938 Ok(())
3939}
3940
3941/// Mark notifications as read
3942async fn mark_notification_as_read(
3943 request: proto::MarkNotificationRead,
3944 response: Response<proto::MarkNotificationRead>,
3945 session: Session,
3946) -> Result<()> {
3947 let database = &session.db().await;
3948 let notifications = database
3949 .mark_notification_as_read_by_id(
3950 session.user_id(),
3951 NotificationId::from_proto(request.notification_id),
3952 )
3953 .await?;
3954 send_notifications(
3955 &*session.connection_pool().await,
3956 &session.peer,
3957 notifications,
3958 );
3959 response.send(proto::Ack {})?;
3960 Ok(())
3961}
3962
3963/// Get the current users information
3964async fn get_private_user_info(
3965 _request: proto::GetPrivateUserInfo,
3966 response: Response<proto::GetPrivateUserInfo>,
3967 session: Session,
3968) -> Result<()> {
3969 let db = session.db().await;
3970
3971 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
3972 let user = db
3973 .get_user_by_id(session.user_id())
3974 .await?
3975 .ok_or_else(|| anyhow!("user not found"))?;
3976 let flags = db.get_user_flags(session.user_id()).await?;
3977
3978 response.send(proto::GetPrivateUserInfoResponse {
3979 metrics_id,
3980 staff: user.admin,
3981 flags,
3982 accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
3983 })?;
3984 Ok(())
3985}
3986
3987/// Accept the terms of service (tos) on behalf of the current user
3988async fn accept_terms_of_service(
3989 _request: proto::AcceptTermsOfService,
3990 response: Response<proto::AcceptTermsOfService>,
3991 session: Session,
3992) -> Result<()> {
3993 let db = session.db().await;
3994
3995 let accepted_tos_at = Utc::now();
3996 db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
3997 .await?;
3998
3999 response.send(proto::AcceptTermsOfServiceResponse {
4000 accepted_tos_at: accepted_tos_at.timestamp() as u64,
4001 })?;
4002 Ok(())
4003}
4004
/// The minimum account age an account must have in order to use the LLM service.
///
/// Checked by `get_llm_api_token` against the older of the Zed account's and
/// (when known) the GitHub account's creation time. Users with an LLM
/// subscription or the `bypass-account-age-check` flag skip the check.
const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4007
/// Issue a signed LLM service token for the current user.
///
/// Access checks, in order (each failure returns an error):
/// 1. the user must be staff or have the `language-models` feature flag;
/// 2. the user record must exist and have accepted the terms of service;
/// 3. unless the user has an LLM subscription or the `bypass-account-age-check`
///    flag, the account must be at least `MIN_ACCOUNT_AGE_FOR_LLM_USE` old.
///    Age is measured from the older of the Zed account's and (when known) the
///    GitHub account's creation time.
///
/// The token is built by `LlmTokenClaims::create` from the user, staff status,
/// billing preferences, the `llm-closed-beta` flag, subscription status, the
/// current plan, and the server config.
async fn get_llm_api_token(
    _request: proto::GetLlmToken,
    response: Response<proto::GetLlmToken>,
    session: Session,
) -> Result<()> {
    let db = session.db().await;

    let flags = db.get_user_flags(session.user_id()).await?;
    let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
    let has_llm_closed_beta_feature_flag = flags.iter().any(|flag| flag == "llm-closed-beta");

    if !session.is_staff() && !has_language_models_feature_flag {
        Err(anyhow!("permission denied"))?
    }

    let user_id = session.user_id();
    let user = db
        .get_user_by_id(user_id)
        .await?
        .ok_or_else(|| anyhow!("user {} not found", user_id))?;

    if user.accepted_tos_at.is_none() {
        Err(anyhow!("terms of service not accepted"))?
    }

    let has_llm_subscription = session.has_llm_subscription(&db).await?;

    let bypass_account_age_check =
        has_llm_subscription || flags.iter().any(|flag| flag == "bypass-account-age-check");
    if !bypass_account_age_check {
        // Use the oldest known creation time, so a new Zed account linked to an
        // established GitHub account is not considered "young".
        let mut account_created_at = user.created_at;
        if let Some(github_created_at) = user.github_user_created_at {
            account_created_at = account_created_at.min(github_created_at);
        }
        if Utc::now().naive_utc() - account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE {
            Err(anyhow!("account too young"))?
        }
    }

    let billing_preferences = db.get_billing_preferences(user.id).await?;

    let token = LlmTokenClaims::create(
        &user,
        session.is_staff(),
        billing_preferences,
        has_llm_closed_beta_feature_flag,
        has_llm_subscription,
        session.current_plan(&db).await?,
        &session.app_state.config,
    )?;
    response.send(proto::GetLlmTokenResponse { token })?;
    Ok(())
}
4061
4062fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4063 let message = match message {
4064 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4065 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4066 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4067 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4068 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4069 code: frame.code.into(),
4070 reason: frame.reason,
4071 })),
4072 // We should never receive a frame while reading the message, according
4073 // to the `tungstenite` maintainers:
4074 //
4075 // > It cannot occur when you read messages from the WebSocket, but it
4076 // > can be used when you want to send the raw frames (e.g. you want to
4077 // > send the frames to the WebSocket without composing the full message first).
4078 // >
4079 // > — https://github.com/snapview/tungstenite-rs/issues/268
4080 TungsteniteMessage::Frame(_) => {
4081 bail!("received an unexpected frame while reading the message")
4082 }
4083 };
4084
4085 Ok(message)
4086}
4087
4088fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4089 match message {
4090 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4091 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4092 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4093 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4094 AxumMessage::Close(frame) => {
4095 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4096 code: frame.code.into(),
4097 reason: frame.reason,
4098 }))
4099 }
4100 }
4101}
4102
4103fn notify_membership_updated(
4104 connection_pool: &mut ConnectionPool,
4105 result: MembershipUpdated,
4106 user_id: UserId,
4107 peer: &Peer,
4108) {
4109 for membership in &result.new_channels.channel_memberships {
4110 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4111 }
4112 for channel_id in &result.removed_channels {
4113 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4114 }
4115
4116 let user_channels_update = proto::UpdateUserChannels {
4117 channel_memberships: result
4118 .new_channels
4119 .channel_memberships
4120 .iter()
4121 .map(|cm| proto::ChannelMembership {
4122 channel_id: cm.channel_id.to_proto(),
4123 role: cm.role.into(),
4124 })
4125 .collect(),
4126 ..Default::default()
4127 };
4128
4129 let mut update = build_channels_update(result.new_channels);
4130 update.delete_channels = result
4131 .removed_channels
4132 .into_iter()
4133 .map(|id| id.to_proto())
4134 .collect();
4135 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4136
4137 for connection_id in connection_pool.user_connection_ids(user_id) {
4138 peer.send(connection_id, user_channels_update.clone())
4139 .trace_err();
4140 peer.send(connection_id, update.clone()).trace_err();
4141 }
4142}
4143
4144fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4145 proto::UpdateUserChannels {
4146 channel_memberships: channels
4147 .channel_memberships
4148 .iter()
4149 .map(|m| proto::ChannelMembership {
4150 channel_id: m.channel_id.to_proto(),
4151 role: m.role.into(),
4152 })
4153 .collect(),
4154 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4155 observed_channel_message_id: channels.observed_channel_messages.clone(),
4156 }
4157}
4158
4159fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4160 let mut update = proto::UpdateChannels::default();
4161
4162 for channel in channels.channels {
4163 update.channels.push(channel.to_proto());
4164 }
4165
4166 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4167 update.latest_channel_message_ids = channels.latest_channel_messages;
4168
4169 for (channel_id, participants) in channels.channel_participants {
4170 update
4171 .channel_participants
4172 .push(proto::ChannelParticipants {
4173 channel_id: channel_id.to_proto(),
4174 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4175 });
4176 }
4177
4178 for channel in channels.invited_channels {
4179 update.channel_invitations.push(channel.to_proto());
4180 }
4181
4182 update
4183}
4184
4185fn build_initial_contacts_update(
4186 contacts: Vec<db::Contact>,
4187 pool: &ConnectionPool,
4188) -> proto::UpdateContacts {
4189 let mut update = proto::UpdateContacts::default();
4190
4191 for contact in contacts {
4192 match contact {
4193 db::Contact::Accepted { user_id, busy } => {
4194 update.contacts.push(contact_for_user(user_id, busy, pool));
4195 }
4196 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4197 db::Contact::Incoming { user_id } => {
4198 update
4199 .incoming_requests
4200 .push(proto::IncomingContactRequest {
4201 requester_id: user_id.to_proto(),
4202 })
4203 }
4204 }
4205 }
4206
4207 update
4208}
4209
4210fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4211 proto::Contact {
4212 user_id: user_id.to_proto(),
4213 online: pool.is_user_online(user_id),
4214 busy,
4215 }
4216}
4217
4218fn room_updated(room: &proto::Room, peer: &Peer) {
4219 broadcast(
4220 None,
4221 room.participants
4222 .iter()
4223 .filter_map(|participant| Some(participant.peer_id?.into())),
4224 |peer_id| {
4225 peer.send(
4226 peer_id,
4227 proto::RoomUpdated {
4228 room: Some(room.clone()),
4229 },
4230 )
4231 },
4232 );
4233}
4234
4235fn channel_updated(
4236 channel: &db::channel::Model,
4237 room: &proto::Room,
4238 peer: &Peer,
4239 pool: &ConnectionPool,
4240) {
4241 let participants = room
4242 .participants
4243 .iter()
4244 .map(|p| p.user_id)
4245 .collect::<Vec<_>>();
4246
4247 broadcast(
4248 None,
4249 pool.channel_connection_ids(channel.root_id())
4250 .filter_map(|(channel_id, role)| {
4251 role.can_see_channel(channel.visibility)
4252 .then_some(channel_id)
4253 }),
4254 |peer_id| {
4255 peer.send(
4256 peer_id,
4257 proto::UpdateChannels {
4258 channel_participants: vec![proto::ChannelParticipants {
4259 channel_id: channel.id.to_proto(),
4260 participant_user_ids: participants.clone(),
4261 }],
4262 ..Default::default()
4263 },
4264 )
4265 },
4266 );
4267}
4268
4269async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4270 let db = session.db().await;
4271
4272 let contacts = db.get_contacts(user_id).await?;
4273 let busy = db.is_user_busy(user_id).await?;
4274
4275 let pool = session.connection_pool().await;
4276 let updated_contact = contact_for_user(user_id, busy, &pool);
4277 for contact in contacts {
4278 if let db::Contact::Accepted {
4279 user_id: contact_user_id,
4280 ..
4281 } = contact
4282 {
4283 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4284 session
4285 .peer
4286 .send(
4287 contact_conn_id,
4288 proto::UpdateContacts {
4289 contacts: vec![updated_contact.clone()],
4290 remove_contacts: Default::default(),
4291 incoming_requests: Default::default(),
4292 remove_incoming_requests: Default::default(),
4293 outgoing_requests: Default::default(),
4294 remove_outgoing_requests: Default::default(),
4295 },
4296 )
4297 .trace_err();
4298 }
4299 }
4300 }
4301 Ok(())
4302}
4303
/// Remove `connection_id` from the room it is in, then clean up after it.
///
/// Returns `Ok(())` without doing anything when the database reports no room
/// was left. Otherwise it:
/// - notifies the other connections of each project that was left
///   (see `project_left`);
/// - broadcasts the updated room and, when the room belongs to a channel, the
///   channel's updated participant list;
/// - sends `CallCanceled` to every connection of each user whose pending call
///   was canceled;
/// - refreshes contact status for this session's user and for those users;
/// - removes the user from the LiveKit room, deleting the room when the
///   database marked it as deleted. LiveKit errors are logged, not returned.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Declared here and initialized inside the `if let` below, so the values
    // remain usable after it.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the connection-pool guard is dropped before the
    // `update_user_contacts` calls below, which lock the pool again.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
4377
4378async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4379 let left_channel_buffers = session
4380 .db()
4381 .await
4382 .leave_channel_buffers(session.connection_id)
4383 .await?;
4384
4385 for left_buffer in left_channel_buffers {
4386 channel_buffer_updated(
4387 session.connection_id,
4388 left_buffer.connections,
4389 &proto::UpdateChannelBufferCollaborators {
4390 channel_id: left_buffer.channel_id.to_proto(),
4391 collaborators: left_buffer.collaborators,
4392 },
4393 &session.peer,
4394 );
4395 }
4396
4397 Ok(())
4398}
4399
4400fn project_left(project: &db::LeftProject, session: &Session) {
4401 for connection_id in &project.connection_ids {
4402 if project.should_unshare {
4403 session
4404 .peer
4405 .send(
4406 *connection_id,
4407 proto::UnshareProject {
4408 project_id: project.id.to_proto(),
4409 },
4410 )
4411 .trace_err();
4412 } else {
4413 session
4414 .peer
4415 .send(
4416 *connection_id,
4417 proto::RemoveProjectCollaborator {
4418 project_id: project.id.to_proto(),
4419 peer_id: Some(session.connection_id.into()),
4420 },
4421 )
4422 .trace_err();
4423 }
4424 }
4425}
4426
4427pub trait ResultExt {
4428 type Ok;
4429
4430 fn trace_err(self) -> Option<Self::Ok>;
4431}
4432
4433impl<T, E> ResultExt for Result<T, E>
4434where
4435 E: std::fmt::Debug,
4436{
4437 type Ok = T;
4438
4439 #[track_caller]
4440 fn trace_err(self) -> Option<T> {
4441 match self {
4442 Ok(value) => Some(value),
4443 Err(error) => {
4444 tracing::error!("{:?}", error);
4445 None
4446 }
4447 }
4448 }
4449}