1mod connection_pool;
2
3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
4use crate::llm::LlmTokenClaims;
5use crate::{
6 auth,
7 db::{
8 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
9 CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
10 NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
11 RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
12 },
13 executor::Executor,
14 AppState, Config, Error, RateLimit, Result,
15};
16use anyhow::{anyhow, bail, Context as _};
17use async_tungstenite::tungstenite::{
18 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
19};
20use axum::{
21 body::Body,
22 extract::{
23 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
24 ConnectInfo, WebSocketUpgrade,
25 },
26 headers::{Header, HeaderName},
27 http::StatusCode,
28 middleware,
29 response::IntoResponse,
30 routing::get,
31 Extension, Router, TypedHeader,
32};
33use chrono::Utc;
34use collections::{HashMap, HashSet};
35pub use connection_pool::{ConnectionPool, ZedVersion};
36use core::fmt::{self, Debug, Formatter};
37use http_client::HttpClient;
38use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
39use reqwest_client::ReqwestClient;
40use sha2::Digest;
41use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
42
43use futures::{
44 channel::oneshot, future::BoxFuture, stream::FuturesUnordered, FutureExt, SinkExt, StreamExt,
45 TryStreamExt,
46};
47use prometheus::{register_int_gauge, IntGauge};
48use rpc::{
49 proto::{
50 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
51 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
52 },
53 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
54};
55use semantic_version::SemanticVersion;
56use serde::{Serialize, Serializer};
57use std::{
58 any::TypeId,
59 future::Future,
60 marker::PhantomData,
61 mem,
62 net::SocketAddr,
63 ops::{Deref, DerefMut},
64 rc::Rc,
65 sync::{
66 atomic::{AtomicBool, Ordering::SeqCst},
67 Arc, OnceLock,
68 },
69 time::{Duration, Instant},
70};
71use time::OffsetDateTime;
72use tokio::sync::{watch, MutexGuard, Semaphore};
73use tower::ServiceBuilder;
74use tracing::{
75 field::{self},
76 info_span, instrument, Instrument,
77};
78
// NOTE(review): not referenced in the visible part of this file; by its name this is
// presumably the grace period a dropped client has to reconnect — confirm at its call sites.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
// `Server::start` sleeps for this long before reclaiming rooms and channel buffers that still
// belong to stale servers.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// NOTE(review): the three limits below are not used in the visible part of this file;
// presumably page sizes / length caps for channel-message and notification handlers — confirm.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;

/// Type-erased message handler stored in `Server::handlers`: it receives a boxed envelope plus
/// the connection's [`Session`] and returns a boxed future that performs the handling.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
90
91struct Response<R> {
92 peer: Arc<Peer>,
93 receipt: Receipt<R>,
94 responded: Arc<AtomicBool>,
95}
96
97impl<R: RequestMessage> Response<R> {
98 fn send(self, payload: R::Response) -> Result<()> {
99 self.responded.store(true, SeqCst);
100 self.peer.respond(self.receipt, payload)?;
101 Ok(())
102 }
103}
104
105#[derive(Clone, Debug)]
106pub enum Principal {
107 User(User),
108 Impersonated { user: User, admin: User },
109}
110
111impl Principal {
112 fn update_span(&self, span: &tracing::Span) {
113 match &self {
114 Principal::User(user) => {
115 span.record("user_id", user.id.0);
116 span.record("login", &user.github_login);
117 }
118 Principal::Impersonated { user, admin } => {
119 span.record("user_id", user.id.0);
120 span.record("login", &user.github_login);
121 span.record("impersonator", &admin.github_login);
122 }
123 }
124 }
125}
126
/// Per-connection state. `handle_connection` builds one per connection and hands a clone of it
/// to every message handler invoked for that connection.
#[derive(Clone)]
struct Session {
    /// Who this connection is authenticated as.
    principal: Principal,
    /// This connection's id within the server's [`Peer`].
    connection_id: ConnectionId,
    /// Database access, serialized through an async mutex; use [`Session::db`] to lock it.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// The server-wide connection pool; use [`Session::connection_pool`] to lock it.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Present only when `supermaven_admin_api_key` is configured.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    http_client: Arc<dyn HttpClient>,
    /// The GeoIP country code for the user.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    system_id: Option<String>,
    _executor: Executor,
}
143
144impl Session {
145 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
146 #[cfg(test)]
147 tokio::task::yield_now().await;
148 let guard = self.db.lock().await;
149 #[cfg(test)]
150 tokio::task::yield_now().await;
151 guard
152 }
153
154 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
155 #[cfg(test)]
156 tokio::task::yield_now().await;
157 let guard = self.connection_pool.lock();
158 ConnectionPoolGuard {
159 guard,
160 _not_send: PhantomData,
161 }
162 }
163
164 fn is_staff(&self) -> bool {
165 match &self.principal {
166 Principal::User(user) => user.admin,
167 Principal::Impersonated { .. } => true,
168 }
169 }
170
171 pub async fn has_llm_subscription(
172 &self,
173 db: &MutexGuard<'_, DbHandle>,
174 ) -> anyhow::Result<bool> {
175 if self.is_staff() {
176 return Ok(true);
177 }
178
179 let user_id = self.user_id();
180
181 Ok(db.has_active_billing_subscription(user_id).await?)
182 }
183
184 pub async fn current_plan(
185 &self,
186 _db: &MutexGuard<'_, DbHandle>,
187 ) -> anyhow::Result<proto::Plan> {
188 if self.is_staff() {
189 Ok(proto::Plan::ZedPro)
190 } else {
191 Ok(proto::Plan::Free)
192 }
193 }
194
195 fn user_id(&self) -> UserId {
196 match &self.principal {
197 Principal::User(user) => user.id,
198 Principal::Impersonated { user, .. } => user.id,
199 }
200 }
201
202 pub fn email(&self) -> Option<String> {
203 match &self.principal {
204 Principal::User(user) => user.email_address.clone(),
205 Principal::Impersonated { user, .. } => user.email_address.clone(),
206 }
207 }
208}
209
210impl Debug for Session {
211 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
212 let mut result = f.debug_struct("Session");
213 match &self.principal {
214 Principal::User(user) => {
215 result.field("user", &user.github_login);
216 }
217 Principal::Impersonated { user, admin } => {
218 result.field("user", &user.github_login);
219 result.field("impersonator", &admin.github_login);
220 }
221 }
222 result.field("connection_id", &self.connection_id).finish()
223 }
224}
225
226struct DbHandle(Arc<Database>);
227
228impl Deref for DbHandle {
229 type Target = Database;
230
231 fn deref(&self) -> &Self::Target {
232 self.0.as_ref()
233 }
234}
235
/// The collaboration RPC server: owns the [`Peer`], the connection pool, and the table of
/// message handlers registered in [`Server::new`].
pub struct Server {
    /// This server's id; behind a mutex so the test-only `reset` can replace it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Registered handlers, keyed by the `TypeId` of the message type they handle.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Teardown flag; `handle_connection` subscribes to it and stops serving when it changes.
    teardown: watch::Sender<bool>,
}
244
/// Lock guard over the shared [`ConnectionPool`].
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send`, so a future holding it is not
/// `Send` — presumably to keep the synchronous lock from being held across `.await` in
/// spawned handlers. In test builds, dropping the guard runs the pool's invariant checks.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
249
/// Serializable view of the server's peer and (locked) connection pool, produced by
/// [`Server::snapshot`]. The pool stays locked for as long as the snapshot lives.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized through the guard's `Deref` target, i.e. the `ConnectionPool` itself.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
256
257pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
258where
259 S: Serializer,
260 T: Deref<Target = U>,
261 U: Serialize,
262{
263 Serialize::serialize(value.deref(), serializer)
264}
265
266impl Server {
    /// Creates a server with the given `id` and registers every RPC message handler.
    ///
    /// `add_request_handler` is used for request messages (the handler gets a [`Response`]
    /// to reply through) and `add_message_handler` for messages without a reply.
    /// Registering two handlers for the same message type panics (see `add_handler`).
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            // NOTE(review): `as u32` silently truncates larger ids — presumably server ids
            // always fit; confirm.
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            // Only the sender is kept; connections obtain receivers via `subscribe()`.
            teardown: watch::channel(false).0,
        };

        server
            // Rooms and calls.
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(set_room_participant_role)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            // Projects and worktrees.
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Project requests forwarded to the host. NOTE(review): the read-only/mutating
            // split presumably gates which guests may send them — confirm in the forwarders.
            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
            .add_request_handler(forward_find_search_candidates_request)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::GitBranches>)
            .add_request_handler(forward_read_only_project_request::<proto::GetStagedText>)
            .add_request_handler(
                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
            // Buffers.
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            // Users and contacts.
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            // Channels, channel buffers and channel chat.
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(update_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            .add_request_handler(move_channel)
            // Following.
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            // Account, LLM and misc.
            .add_request_handler(get_private_user_info)
            .add_request_handler(get_llm_api_token)
            .add_request_handler(accept_terms_of_service)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version)
            .add_request_handler(get_supermaven_api_key)
            // Assistant contexts.
            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context)
            // Handlers that need server configuration capture their own clone of `app_state`.
            .add_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    let app_state = app_state.clone();
                    async move {
                        count_language_model_tokens(request, response, session, &app_state.config)
                            .await
                    }
                }
            })
            .add_request_handler(get_cached_embeddings)
            .add_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    compute_embeddings(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                    )
                }
            });

        Arc::new(server)
    }
421
    /// Starts background cleanup of resources left behind by stale servers.
    ///
    /// Spawns a detached task that first waits [`CLEANUP_TIMEOUT`]. It then asks the database
    /// (`stale_server_resource_ids`) for channel buffers and rooms still tied to stale
    /// servers in this environment, and for each one:
    ///
    /// - clears the stale collaborators or participants;
    /// - notifies the affected connections;
    /// - sends `CallCanceled` for calls that were cancelled;
    /// - pushes refreshed contact info to affected users' accepted contacts;
    /// - deletes the LiveKit room once no participants remain.
    ///
    /// Finally it deletes the stale server records. Individual failures are logged via
    /// `trace_err` and do not abort the rest of the cleanup. The method itself returns `Ok(())`
    /// immediately.
    pub async fn start(&self) -> Result<()> {
        // Everything the detached task needs is cloned up front so it does not borrow `self`.
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Drop stale collaborators from channel buffers and tell the remaining
                    // collaborators about the new collaborator list.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        // Remove stale participants, broadcast the updated room (and its
                        // channel, if any), and collect the follow-up work for this room.
                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Tell every connection of each callee that its pending call is gone.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push each affected user's refreshed busy/online state to all of
                        // their accepted contacts. The DB queries run before the pool is
                        // locked, so the synchronous lock is never held across `.await`.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Only delete the LiveKit room when the refreshed room ended up empty.
                        if let Some(live_kit) = livekit_client.as_ref() {
                            if delete_livekit_room {
                                live_kit.delete_room(livekit_room).await.trace_err();
                            }
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
567
568 pub fn teardown(&self) {
569 self.peer.teardown();
570 self.connection_pool.lock().reset();
571 let _ = self.teardown.send(true);
572 }
573
574 #[cfg(test)]
575 pub fn reset(&self, id: ServerId) {
576 self.teardown();
577 *self.id.lock() = id;
578 self.peer.reset(id.0 as u32);
579 let _ = self.teardown.send(false);
580 }
581
582 #[cfg(test)]
583 pub fn id(&self) -> ServerId {
584 *self.id.lock()
585 }
586
587 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
588 where
589 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
590 Fut: 'static + Send + Future<Output = Result<()>>,
591 M: EnvelopedMessage,
592 {
593 let prev_handler = self.handlers.insert(
594 TypeId::of::<M>(),
595 Box::new(move |envelope, session| {
596 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
597 let received_at = envelope.received_at;
598 tracing::info!("message received");
599 let start_time = Instant::now();
600 let future = (handler)(*envelope, session);
601 async move {
602 let result = future.await;
603 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
604 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
605 let queue_duration_ms = total_duration_ms - processing_duration_ms;
606 let payload_type = M::NAME;
607
608 match result {
609 Err(error) => {
610 tracing::error!(
611 ?error,
612 total_duration_ms,
613 processing_duration_ms,
614 queue_duration_ms,
615 payload_type,
616 "error handling message"
617 )
618 }
619 Ok(()) => tracing::info!(
620 total_duration_ms,
621 processing_duration_ms,
622 queue_duration_ms,
623 "finished handling message"
624 ),
625 }
626 }
627 .boxed()
628 }),
629 );
630 if prev_handler.is_some() {
631 panic!("registered a handler for the same message twice");
632 }
633 self
634 }
635
636 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
637 where
638 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
639 Fut: 'static + Send + Future<Output = Result<()>>,
640 M: EnvelopedMessage,
641 {
642 self.add_handler(move |envelope, session| handler(envelope.payload, session));
643 self
644 }
645
646 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
647 where
648 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
649 Fut: Send + Future<Output = Result<()>>,
650 M: RequestMessage,
651 {
652 let handler = Arc::new(handler);
653 self.add_handler(move |envelope, session| {
654 let receipt = envelope.receipt();
655 let handler = handler.clone();
656 async move {
657 let peer = session.peer.clone();
658 let responded = Arc::new(AtomicBool::default());
659 let response = Response {
660 peer: peer.clone(),
661 responded: responded.clone(),
662 receipt,
663 };
664 match (handler)(envelope.payload, response, session).await {
665 Ok(()) => {
666 if responded.load(std::sync::atomic::Ordering::SeqCst) {
667 Ok(())
668 } else {
669 Err(anyhow!("handler did not send a response"))?
670 }
671 }
672 Err(error) => {
673 let proto_err = match &error {
674 Error::Internal(err) => err.to_proto(),
675 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
676 };
677 peer.respond_with_error(receipt, proto_err)?;
678 Err(error)
679 }
680 }
681 }
682 })
683 }
684
    /// Returns a future that serves one client connection.
    ///
    /// When polled, the future:
    /// 1. returns immediately if the server is already tearing down;
    /// 2. registers `connection` with the peer and builds this connection's [`Session`];
    /// 3. sends the initial client update (`send_initial_client_update`);
    /// 4. dispatches incoming messages to the registered handlers until the connection
    ///    closes, the I/O future finishes, or the server is torn down;
    /// 5. unless the server was torn down, runs `connection_lost` for the session.
    ///
    /// Background messages are spawned detached; foreground messages run concurrently on this
    /// task. A semaphore caps in-flight handlers (of both kinds) at 256 per connection.
    /// `send_connection_id`, if given, receives the new [`ConnectionId`] after the hello
    /// message is sent.
    #[allow(clippy::too_many_arguments)]
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> {
        let this = self.clone();
        // Fields left `Empty` are filled in once known (principal, connection id, country).
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            geoip_country_code=field::Empty
        );
        principal.update_span(&span);
        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        // Subscribe eagerly (before the future is first polled) so a teardown that happens in
        // between is observed via `changed()`.
        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            // NOTE(review): a fresh HTTP client is built for every connection — confirm this
            // is intended rather than sharing one via `AppState`.
            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            // Only available when a Supermaven admin key is configured.
            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
                supermaven_admin_api_key.to_string(),
                http_client.clone(),
            )));

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                http_client,
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            // NOTE(review): `send_initial_client_update` adds the connection to the pool
            // part-way through; if it fails after that point we return here without running
            // `connection_lost`, so the pool entry (and the peer connection) are presumably
            // left behind. Confirm and consider cleaning up on this path.
            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // A permit is taken before reading each message and released when its handler
            // finishes, bounding the number of in-flight handlers for this connection.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Biased: teardown first, then I/O completion, then driving in-flight
                // foreground handlers, and only then reading a new message.
                futures::select_biased! {
                    // On teardown, exit without signing out (`Server::teardown` resets the pool).
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                // The permit moves into the handler future so it is released
                                // only when handling completes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Cancel any foreground handlers still in flight before signing out.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
834
835 async fn send_initial_client_update(
836 &self,
837 connection_id: ConnectionId,
838 principal: &Principal,
839 zed_version: ZedVersion,
840 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
841 session: &Session,
842 ) -> Result<()> {
843 self.peer.send(
844 connection_id,
845 proto::Hello {
846 peer_id: Some(connection_id.into()),
847 },
848 )?;
849 tracing::info!("sent hello message");
850 if let Some(send_connection_id) = send_connection_id.take() {
851 let _ = send_connection_id.send(connection_id);
852 }
853
854 match principal {
855 Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
856 if !user.connected_once {
857 self.peer.send(connection_id, proto::ShowContacts {})?;
858 self.app_state
859 .db
860 .set_user_connected_once(user.id, true)
861 .await?;
862 }
863
864 update_user_plan(user.id, session).await?;
865
866 let contacts = self.app_state.db.get_contacts(user.id).await?;
867
868 {
869 let mut pool = self.connection_pool.lock();
870 pool.add_connection(connection_id, user.id, user.admin, zed_version);
871 self.peer.send(
872 connection_id,
873 build_initial_contacts_update(contacts, &pool),
874 )?;
875 }
876
877 if should_auto_subscribe_to_channels(zed_version) {
878 subscribe_user_to_channels(user.id, session).await?;
879 }
880
881 if let Some(incoming_call) =
882 self.app_state.db.incoming_call_for_user(user.id).await?
883 {
884 self.peer.send(connection_id, incoming_call)?;
885 }
886
887 update_user_contacts(user.id, session).await?;
888 }
889 }
890
891 Ok(())
892 }
893
894 pub async fn invite_code_redeemed(
895 self: &Arc<Self>,
896 inviter_id: UserId,
897 invitee_id: UserId,
898 ) -> Result<()> {
899 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
900 if let Some(code) = &user.invite_code {
901 let pool = self.connection_pool.lock();
902 let invitee_contact = contact_for_user(invitee_id, false, &pool);
903 for connection_id in pool.user_connection_ids(inviter_id) {
904 self.peer.send(
905 connection_id,
906 proto::UpdateContacts {
907 contacts: vec![invitee_contact.clone()],
908 ..Default::default()
909 },
910 )?;
911 self.peer.send(
912 connection_id,
913 proto::UpdateInviteInfo {
914 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
915 count: user.invite_count as u32,
916 },
917 )?;
918 }
919 }
920 }
921 Ok(())
922 }
923
924 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
925 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
926 if let Some(invite_code) = &user.invite_code {
927 let pool = self.connection_pool.lock();
928 for connection_id in pool.user_connection_ids(user_id) {
929 self.peer.send(
930 connection_id,
931 proto::UpdateInviteInfo {
932 url: format!(
933 "{}{}",
934 self.app_state.config.invite_link_prefix, invite_code
935 ),
936 count: user.invite_count as u32,
937 },
938 )?;
939 }
940 }
941 }
942 Ok(())
943 }
944
945 pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
946 let pool = self.connection_pool.lock();
947 for connection_id in pool.user_connection_ids(user_id) {
948 self.peer
949 .send(connection_id, proto::RefreshLlmToken {})
950 .trace_err();
951 }
952 }
953
954 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
955 ServerSnapshot {
956 connection_pool: ConnectionPoolGuard {
957 guard: self.connection_pool.lock(),
958 _not_send: PhantomData,
959 },
960 peer: &self.peer,
961 }
962 }
963}
964
965impl<'a> Deref for ConnectionPoolGuard<'a> {
966 type Target = ConnectionPool;
967
968 fn deref(&self) -> &Self::Target {
969 &self.guard
970 }
971}
972
973impl<'a> DerefMut for ConnectionPoolGuard<'a> {
974 fn deref_mut(&mut self) -> &mut Self::Target {
975 &mut self.guard
976 }
977}
978
979impl<'a> Drop for ConnectionPoolGuard<'a> {
980 fn drop(&mut self) {
981 #[cfg(test)]
982 self.check_invariants();
983 }
984}
985
986fn broadcast<F>(
987 sender_id: Option<ConnectionId>,
988 receiver_ids: impl IntoIterator<Item = ConnectionId>,
989 mut f: F,
990) where
991 F: FnMut(ConnectionId) -> anyhow::Result<()>,
992{
993 for receiver_id in receiver_ids {
994 if Some(receiver_id) != sender_id {
995 if let Err(error) = f(receiver_id) {
996 tracing::error!("failed to send to {:?} {}", receiver_id, error);
997 }
998 }
999 }
1000}
1001
1002pub struct ProtocolVersion(u32);
1003
1004impl Header for ProtocolVersion {
1005 fn name() -> &'static HeaderName {
1006 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1007 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1008 }
1009
1010 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1011 where
1012 Self: Sized,
1013 I: Iterator<Item = &'i axum::http::HeaderValue>,
1014 {
1015 let version = values
1016 .next()
1017 .ok_or_else(axum::headers::Error::invalid)?
1018 .to_str()
1019 .map_err(|_| axum::headers::Error::invalid())?
1020 .parse()
1021 .map_err(|_| axum::headers::Error::invalid())?;
1022 Ok(Self(version))
1023 }
1024
1025 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1026 values.extend([self.0.to_string().parse().unwrap()]);
1027 }
1028}
1029
1030pub struct AppVersionHeader(SemanticVersion);
1031impl Header for AppVersionHeader {
1032 fn name() -> &'static HeaderName {
1033 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1034 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1035 }
1036
1037 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1038 where
1039 Self: Sized,
1040 I: Iterator<Item = &'i axum::http::HeaderValue>,
1041 {
1042 let version = values
1043 .next()
1044 .ok_or_else(axum::headers::Error::invalid)?
1045 .to_str()
1046 .map_err(|_| axum::headers::Error::invalid())?
1047 .parse()
1048 .map_err(|_| axum::headers::Error::invalid())?;
1049 Ok(Self(version))
1050 }
1051
1052 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1053 values.extend([self.0.to_string().parse().unwrap()]);
1054 }
1055}
1056
/// Builds the router for the collab RPC endpoints.
///
/// `Router::layer` only wraps routes registered *before* the call, so:
/// - `GET /rpc` gets the `AppState` extension and the `auth::validate_header`
///   middleware;
/// - `GET /metrics` is registered afterwards and is NOT behind
///   `auth::validate_header`;
/// - the final `Extension(server)` layer applies to both routes.
pub fn routes(server: Arc<Server>) -> Router<(), Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
1068
/// Handles `GET /rpc`: checks the client's protocol and app versions, then
/// upgrades the request to a websocket and hands the connection to
/// `Server::handle_connection`.
///
/// Responds with `426 Upgrade Required`, without upgrading, when:
/// - the protocol version differs from `rpc::PROTOCOL_VERSION`;
/// - the app-version header is missing or could not be parsed (the extractor
///   is optional and yields `None` in both cases);
/// - the reported app version cannot collaborate.
///
/// The optional Cloudflare country header and system-id header are forwarded
/// to `handle_connection` as strings.
#[allow(clippy::too_many_arguments)]
pub async fn handle_websocket_request(
    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
    app_version_header: Option<TypedHeader<AppVersionHeader>>,
    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
    Extension(server): Extension<Arc<Server>>,
    Extension(principal): Extension<Principal>,
    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
    system_id_header: Option<TypedHeader<SystemIdHeader>>,
    ws: WebSocketUpgrade,
) -> axum::response::Response {
    if protocol_version != rpc::PROTOCOL_VERSION {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }

    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "no version header found".to_string(),
        )
            .into_response();
    };

    if !version.can_collaborate() {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }

    let socket_address = socket_address.to_string();
    ws.on_upgrade(move |socket| {
        // Bridge message types: incoming axum messages are mapped with
        // `to_tungstenite_message`, outgoing ones with `to_axum_message`,
        // before the socket is wrapped in an rpc `Connection`.
        let socket = socket
            .map_ok(to_tungstenite_message)
            .err_into()
            .with(|message| async move { to_axum_message(message) });
        let connection = Connection::new(Box::pin(socket));
        async move {
            server
                .handle_connection(
                    connection,
                    socket_address,
                    principal,
                    version,
                    country_code_header.map(|header| header.to_string()),
                    system_id_header.map(|header| header.to_string()),
                    None,
                    Executor::Production,
                )
                .await;
        }
    })
}
1127
1128pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1129 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1130 let connections_metric = CONNECTIONS_METRIC
1131 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1132
1133 let connections = server
1134 .connection_pool
1135 .lock()
1136 .connections()
1137 .filter(|connection| !connection.admin)
1138 .count();
1139 connections_metric.set(connections as _);
1140
1141 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1142 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1143 register_int_gauge!(
1144 "shared_projects",
1145 "number of open projects with one or more guests"
1146 )
1147 .unwrap()
1148 });
1149
1150 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1151 shared_projects_metric.set(shared_projects as _);
1152
1153 let encoder = prometheus::TextEncoder::new();
1154 let metric_families = prometheus::gather();
1155 let encoded_metrics = encoder
1156 .encode_to_string(&metric_families)
1157 .map_err(|err| anyhow!("{}", err))?;
1158 Ok(encoded_metrics)
1159}
1160
/// Cleans up a session whose connection has been lost.
///
/// Done immediately:
/// - disconnect the peer;
/// - remove the connection from the connection pool (an error here is returned);
/// - report the loss to the database (an error here is only logged).
///
/// The rest waits for `RECONNECT_TIMEOUT`, presumably to give the client a
/// chance to reconnect. If the timeout elapses first, the session:
/// - leaves its room and channel buffers;
/// - if the user has no connection still online, calls `decline_call(None, ..)`
///   for them and broadcasts any returned room;
/// - refreshes the user's contacts.
///
/// If `teardown` changes first, none of that deferred cleanup runs.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls its branches in order, so the timeout branch is
    // preferred when both are ready at the same time.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1207
1208/// Acknowledges a ping from a client, used to keep the connection alive.
1209async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1210 response.send(proto::Ack {})?;
1211 Ok(())
1212}
1213
1214/// Creates a new room for calling (outside of channels)
1215async fn create_room(
1216 _request: proto::CreateRoom,
1217 response: Response<proto::CreateRoom>,
1218 session: Session,
1219) -> Result<()> {
1220 let livekit_room = nanoid::nanoid!(30);
1221
1222 let live_kit_connection_info = util::maybe!(async {
1223 let live_kit = session.app_state.livekit_client.as_ref();
1224 let live_kit = live_kit?;
1225 let user_id = session.user_id().to_string();
1226
1227 let token = live_kit
1228 .room_token(&livekit_room, &user_id.to_string())
1229 .trace_err()?;
1230
1231 Some(proto::LiveKitConnectionInfo {
1232 server_url: live_kit.url().into(),
1233 token,
1234 can_publish: true,
1235 })
1236 })
1237 .await;
1238
1239 let room = session
1240 .db()
1241 .await
1242 .create_room(session.user_id(), session.connection_id, &livekit_room)
1243 .await?;
1244
1245 response.send(proto::CreateRoomResponse {
1246 room: Some(room.clone()),
1247 live_kit_connection_info,
1248 })?;
1249
1250 update_user_contacts(session.user_id(), &session).await?;
1251 Ok(())
1252}
1253
/// Joins a room from an invitation.
///
/// If the room belongs to a channel, this delegates to `join_channel_internal`.
/// Otherwise:
/// - the user is added to the room and the room update is broadcast;
/// - every connection of the joining user is sent `CallCanceled` for this room;
/// - the response includes LiveKit connection info when a LiveKit client is
///   configured and a token could be minted;
/// - the user's contacts are refreshed.
async fn join_room(
    request: proto::JoinRoom,
    response: Response<proto::JoinRoom>,
    session: Session,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.id);

    let channel_id = session.db().await.channel_id_for_room(room_id).await?;

    // Channel rooms are joined through the channel path instead.
    if let Some(channel_id) = channel_id {
        return join_channel_internal(channel_id, Box::new(response), session).await;
    }

    // Scoped so the DB result is consumed (`into_inner`) before the awaits below.
    let joined_room = {
        let room = session
            .db()
            .await
            .join_room(room_id, session.user_id(), session.connection_id)
            .await?;
        room_updated(&room.room, &session.peer);
        room.into_inner()
    };

    // Tell all of this user's connections that the call for this room is over.
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(session.user_id())
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }

    // Token failures are logged and result in `None` rather than an error.
    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
    {
        live_kit
            .room_token(
                &joined_room.room.livekit_room,
                &session.user_id().to_string(),
            )
            .trace_err()
            .map(|token| proto::LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish: true,
            })
    } else {
        None
    };

    response.send(proto::JoinRoomResponse {
        room: Some(joined_room.room),
        channel_id: None,
        live_kit_connection_info,
    })?;

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1320
/// Rejoins a room after a connection error.
///
/// The database call restores the caller's participation under its new
/// connection id. The caller is immediately sent the room, the projects it
/// re-shares, and the projects it re-joins. Then:
/// - the room update is broadcast;
/// - for each re-shared project, its collaborators learn the new peer id and
///   receive the project's worktrees via `UpdateProject`;
/// - the state of each re-joined project is streamed back to the caller by
///   `notify_rejoined_projects`.
///
/// After the DB result is released (end of the inner block), the room's channel,
/// if any, is broadcast via `channel_updated`, and the caller's contacts are
/// refreshed.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Replace the project's old connection id with the caller's new one
            // in every collaborator's view.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the re-shared project's worktree list to everyone but the caller.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        // Only the room and channel are needed once the guard is consumed.
        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1412
1413fn notify_rejoined_projects(
1414 rejoined_projects: &mut Vec<RejoinedProject>,
1415 session: &Session,
1416) -> Result<()> {
1417 for project in rejoined_projects.iter() {
1418 for collaborator in &project.collaborators {
1419 session
1420 .peer
1421 .send(
1422 collaborator.connection_id,
1423 proto::UpdateProjectCollaborator {
1424 project_id: project.id.to_proto(),
1425 old_peer_id: Some(project.old_connection_id.into()),
1426 new_peer_id: Some(session.connection_id.into()),
1427 },
1428 )
1429 .trace_err();
1430 }
1431 }
1432
1433 for project in rejoined_projects {
1434 for worktree in mem::take(&mut project.worktrees) {
1435 // Stream this worktree's entries.
1436 let message = proto::UpdateWorktree {
1437 project_id: project.id.to_proto(),
1438 worktree_id: worktree.id,
1439 abs_path: worktree.abs_path.clone(),
1440 root_name: worktree.root_name,
1441 updated_entries: worktree.updated_entries,
1442 removed_entries: worktree.removed_entries,
1443 scan_id: worktree.scan_id,
1444 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1445 updated_repositories: worktree.updated_repositories,
1446 removed_repositories: worktree.removed_repositories,
1447 };
1448 for update in proto::split_worktree_update(message) {
1449 session.peer.send(session.connection_id, update.clone())?;
1450 }
1451
1452 // Stream this worktree's diagnostics.
1453 for summary in worktree.diagnostic_summaries {
1454 session.peer.send(
1455 session.connection_id,
1456 proto::UpdateDiagnosticSummary {
1457 project_id: project.id.to_proto(),
1458 worktree_id: worktree.id,
1459 summary: Some(summary),
1460 },
1461 )?;
1462 }
1463
1464 for settings_file in worktree.settings_files {
1465 session.peer.send(
1466 session.connection_id,
1467 proto::UpdateWorktreeSettings {
1468 project_id: project.id.to_proto(),
1469 worktree_id: worktree.id,
1470 path: settings_file.path,
1471 content: Some(settings_file.content),
1472 kind: Some(settings_file.kind.to_proto().into()),
1473 },
1474 )?;
1475 }
1476 }
1477
1478 for language_server in &project.language_servers {
1479 session.peer.send(
1480 session.connection_id,
1481 proto::UpdateLanguageServer {
1482 project_id: project.id.to_proto(),
1483 language_server_id: language_server.id,
1484 variant: Some(
1485 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1486 proto::LspDiskBasedDiagnosticsUpdated {},
1487 ),
1488 ),
1489 },
1490 )?;
1491 }
1492 }
1493 Ok(())
1494}
1495
1496/// leave room disconnects from the room.
1497async fn leave_room(
1498 _: proto::LeaveRoom,
1499 response: Response<proto::LeaveRoom>,
1500 session: Session,
1501) -> Result<()> {
1502 leave_room_for_session(&session, session.connection_id).await?;
1503 response.send(proto::Ack {})?;
1504 Ok(())
1505}
1506
1507/// Updates the permissions of someone else in the room.
1508async fn set_room_participant_role(
1509 request: proto::SetRoomParticipantRole,
1510 response: Response<proto::SetRoomParticipantRole>,
1511 session: Session,
1512) -> Result<()> {
1513 let user_id = UserId::from_proto(request.user_id);
1514 let role = ChannelRole::from(request.role());
1515
1516 let (livekit_room, can_publish) = {
1517 let room = session
1518 .db()
1519 .await
1520 .set_room_participant_role(
1521 session.user_id(),
1522 RoomId::from_proto(request.room_id),
1523 user_id,
1524 role,
1525 )
1526 .await?;
1527
1528 let livekit_room = room.livekit_room.clone();
1529 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1530 room_updated(&room, &session.peer);
1531 (livekit_room, can_publish)
1532 };
1533
1534 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1535 live_kit
1536 .update_participant(
1537 livekit_room.clone(),
1538 request.user_id.to_string(),
1539 livekit_server::proto::ParticipantPermission {
1540 can_subscribe: true,
1541 can_publish,
1542 can_publish_data: can_publish,
1543 hidden: false,
1544 recorder: false,
1545 },
1546 )
1547 .await
1548 .trace_err();
1549 }
1550
1551 response.send(proto::Ack {})?;
1552 Ok(())
1553}
1554
/// Calls another user into the current room.
///
/// The callee must be a contact of the caller. The call is recorded in the
/// database, the room update is broadcast, and the callee's contacts are
/// refreshed. Then the incoming call is sent as a request to every connection
/// of the callee, concurrently.
///
/// The first request that completes successfully acknowledges this call and
/// returns. If none succeed, including when the callee has no connections:
/// - the call is recorded as failed;
/// - the room update is broadcast again;
/// - the callee's contacts are refreshed;
/// - an error is returned.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: Session,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    // Scoped so the DB result is dropped before ringing the callee; the call
    // payload is moved out with `mem::take`.
    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(room, &session.peer);
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring every connection of the callee concurrently.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    // First success wins; failed attempts are logged and skipped.
    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    // No connection could be rung: record the failure and tell the room.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1623
/// Cancels an outgoing call.
///
/// Records the cancellation in the database and broadcasts the room update.
/// It then sends `CallCanceled` to every connection of the called user,
/// acknowledges the request, and refreshes the called user's contacts.
async fn cancel_call(
    request: proto::CancelCall,
    response: Response<proto::CancelCall>,
    session: Session,
) -> Result<()> {
    let called_user_id = UserId::from_proto(request.called_user_id);
    let room_id = RoomId::from_proto(request.room_id);
    // Scoped so the DB result is dropped before the connection pool is locked.
    {
        let room = session
            .db()
            .await
            .cancel_call(room_id, session.connection_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    // Send failures to individual connections are only logged.
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }
    response.send(proto::Ack {})?;

    update_user_contacts(called_user_id, &session).await?;
    Ok(())
}
1661
/// Declines an incoming call for the room in `message`.
///
/// Fails with "failed to decline call" when the database returns no room for
/// the decline. Otherwise it:
/// - broadcasts the room update;
/// - sends `CallCanceled` for this room to all of the declining user's
///   connections;
/// - refreshes the user's contacts.
async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
    let room_id = RoomId::from_proto(message.room_id);
    // Scoped so the DB result is dropped before the connection pool is locked.
    {
        let room = session
            .db()
            .await
            .decline_call(Some(room_id), session.user_id())
            .await?
            .ok_or_else(|| anyhow!("failed to decline call"))?;
        room_updated(&room, &session.peer);
    }

    // Send failures to individual connections are only logged.
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(session.user_id())
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }
    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1693
1694/// Updates other participants in the room with your current location.
1695async fn update_participant_location(
1696 request: proto::UpdateParticipantLocation,
1697 response: Response<proto::UpdateParticipantLocation>,
1698 session: Session,
1699) -> Result<()> {
1700 let room_id = RoomId::from_proto(request.room_id);
1701 let location = request
1702 .location
1703 .ok_or_else(|| anyhow!("invalid location"))?;
1704
1705 let db = session.db().await;
1706 let room = db
1707 .update_room_participant_location(room_id, session.connection_id, location)
1708 .await?;
1709
1710 room_updated(&room, &session.peer);
1711 response.send(proto::Ack {})?;
1712 Ok(())
1713}
1714
1715/// Share a project into the room.
1716async fn share_project(
1717 request: proto::ShareProject,
1718 response: Response<proto::ShareProject>,
1719 session: Session,
1720) -> Result<()> {
1721 let (project_id, room) = &*session
1722 .db()
1723 .await
1724 .share_project(
1725 RoomId::from_proto(request.room_id),
1726 session.connection_id,
1727 &request.worktrees,
1728 request.is_ssh_project,
1729 )
1730 .await?;
1731 response.send(proto::ShareProjectResponse {
1732 project_id: project_id.to_proto(),
1733 })?;
1734 room_updated(room, &session.peer);
1735
1736 Ok(())
1737}
1738
1739/// Unshare a project from the room.
1740async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1741 let project_id = ProjectId::from_proto(message.project_id);
1742 unshare_project_internal(project_id, session.connection_id, &session).await
1743}
1744
/// Unshares `project_id` on behalf of `connection_id`.
///
/// The project's guests (excluding `connection_id`) are sent `UnshareProject`.
/// If the project was in a room, the room update is broadcast. When the
/// database result says the project should be deleted, it is deleted
/// afterwards, once the unshare guard has gone out of scope.
async fn unshare_project_internal(
    project_id: ProjectId,
    connection_id: ConnectionId,
    session: &Session,
) -> Result<()> {
    // Scoped so `room_guard` is dropped before the deletion below re-acquires the DB.
    let delete = {
        let room_guard = session
            .db()
            .await
            .unshare_project(project_id, connection_id)
            .await?;

        let (delete, room, guest_connection_ids) = &*room_guard;

        let message = proto::UnshareProject {
            project_id: project_id.to_proto(),
        };

        broadcast(
            Some(connection_id),
            guest_connection_ids.iter().copied(),
            |conn_id| session.peer.send(conn_id, message.clone()),
        );
        if let Some(room) = room {
            room_updated(room, &session.peer);
        }

        *delete
    };

    if delete {
        let db = session.db().await;
        db.delete_project(project_id).await?;
    }

    Ok(())
}
1782
/// Joins someone else's shared project.
///
/// Adds the caller to the project in the database, then hands off to
/// `join_project_internal` to announce the caller and stream the project's
/// state. The `db` handle is explicitly dropped before that hand-off, while
/// the join result (`project`, `replica_id`) stays borrowed.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: Session,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    drop(db);
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
1801
/// How `join_project_internal` replies to the joining client, so that function
/// does not depend on the concrete `Response` type it answers.
trait JoinProjectInternalResponse {
    /// Sends `result` as the reply, consuming the responder.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Resolves to `Response`'s own `send` (the one used for every other
        // request type in this file), not back into this trait method.
        Response::<proto::JoinProject>::send(self, result)
    }
}
1810
1811fn join_project_internal(
1812 response: impl JoinProjectInternalResponse,
1813 session: Session,
1814 project: &mut Project,
1815 replica_id: &ReplicaId,
1816) -> Result<()> {
1817 let collaborators = project
1818 .collaborators
1819 .iter()
1820 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1821 .map(|collaborator| collaborator.to_proto())
1822 .collect::<Vec<_>>();
1823 let project_id = project.id;
1824 let guest_user_id = session.user_id();
1825
1826 let worktrees = project
1827 .worktrees
1828 .iter()
1829 .map(|(id, worktree)| proto::WorktreeMetadata {
1830 id: *id,
1831 root_name: worktree.root_name.clone(),
1832 visible: worktree.visible,
1833 abs_path: worktree.abs_path.clone(),
1834 })
1835 .collect::<Vec<_>>();
1836
1837 let add_project_collaborator = proto::AddProjectCollaborator {
1838 project_id: project_id.to_proto(),
1839 collaborator: Some(proto::Collaborator {
1840 peer_id: Some(session.connection_id.into()),
1841 replica_id: replica_id.0 as u32,
1842 user_id: guest_user_id.to_proto(),
1843 is_host: false,
1844 }),
1845 };
1846
1847 for collaborator in &collaborators {
1848 session
1849 .peer
1850 .send(
1851 collaborator.peer_id.unwrap().into(),
1852 add_project_collaborator.clone(),
1853 )
1854 .trace_err();
1855 }
1856
1857 // First, we send the metadata associated with each worktree.
1858 response.send(proto::JoinProjectResponse {
1859 project_id: project.id.0 as u64,
1860 worktrees: worktrees.clone(),
1861 replica_id: replica_id.0 as u32,
1862 collaborators: collaborators.clone(),
1863 language_servers: project.language_servers.clone(),
1864 role: project.role.into(),
1865 })?;
1866
1867 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1868 // Stream this worktree's entries.
1869 let message = proto::UpdateWorktree {
1870 project_id: project_id.to_proto(),
1871 worktree_id,
1872 abs_path: worktree.abs_path.clone(),
1873 root_name: worktree.root_name,
1874 updated_entries: worktree.entries,
1875 removed_entries: Default::default(),
1876 scan_id: worktree.scan_id,
1877 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1878 updated_repositories: worktree.repository_entries.into_values().collect(),
1879 removed_repositories: Default::default(),
1880 };
1881 for update in proto::split_worktree_update(message) {
1882 session.peer.send(session.connection_id, update.clone())?;
1883 }
1884
1885 // Stream this worktree's diagnostics.
1886 for summary in worktree.diagnostic_summaries {
1887 session.peer.send(
1888 session.connection_id,
1889 proto::UpdateDiagnosticSummary {
1890 project_id: project_id.to_proto(),
1891 worktree_id: worktree.id,
1892 summary: Some(summary),
1893 },
1894 )?;
1895 }
1896
1897 for settings_file in worktree.settings_files {
1898 session.peer.send(
1899 session.connection_id,
1900 proto::UpdateWorktreeSettings {
1901 project_id: project_id.to_proto(),
1902 worktree_id: worktree.id,
1903 path: settings_file.path,
1904 content: Some(settings_file.content),
1905 kind: Some(settings_file.kind.to_proto() as i32),
1906 },
1907 )?;
1908 }
1909 }
1910
1911 for language_server in &project.language_servers {
1912 session.peer.send(
1913 session.connection_id,
1914 proto::UpdateLanguageServer {
1915 project_id: project_id.to_proto(),
1916 language_server_id: language_server.id,
1917 variant: Some(
1918 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1919 proto::LspDiskBasedDiagnosticsUpdated {},
1920 ),
1921 ),
1922 },
1923 )?;
1924 }
1925
1926 Ok(())
1927}
1928
1929/// Leave someone elses shared project.
1930async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1931 let sender_id = session.connection_id;
1932 let project_id = ProjectId::from_proto(request.project_id);
1933 let db = session.db().await;
1934
1935 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1936 tracing::info!(
1937 %project_id,
1938 "leave project"
1939 );
1940
1941 project_left(project, &session);
1942 if let Some(room) = room {
1943 room_updated(room, &session.peer);
1944 }
1945
1946 Ok(())
1947}
1948
1949/// Updates other participants with changes to the project
1950async fn update_project(
1951 request: proto::UpdateProject,
1952 response: Response<proto::UpdateProject>,
1953 session: Session,
1954) -> Result<()> {
1955 let project_id = ProjectId::from_proto(request.project_id);
1956 let (room, guest_connection_ids) = &*session
1957 .db()
1958 .await
1959 .update_project(project_id, session.connection_id, &request.worktrees)
1960 .await?;
1961 broadcast(
1962 Some(session.connection_id),
1963 guest_connection_ids.iter().copied(),
1964 |connection_id| {
1965 session
1966 .peer
1967 .forward_send(session.connection_id, connection_id, request.clone())
1968 },
1969 );
1970 if let Some(room) = room {
1971 room_updated(room, &session.peer);
1972 }
1973 response.send(proto::Ack {})?;
1974
1975 Ok(())
1976}
1977
1978/// Updates other participants with changes to the worktree
1979async fn update_worktree(
1980 request: proto::UpdateWorktree,
1981 response: Response<proto::UpdateWorktree>,
1982 session: Session,
1983) -> Result<()> {
1984 let guest_connection_ids = session
1985 .db()
1986 .await
1987 .update_worktree(&request, session.connection_id)
1988 .await?;
1989
1990 broadcast(
1991 Some(session.connection_id),
1992 guest_connection_ids.iter().copied(),
1993 |connection_id| {
1994 session
1995 .peer
1996 .forward_send(session.connection_id, connection_id, request.clone())
1997 },
1998 );
1999 response.send(proto::Ack {})?;
2000 Ok(())
2001}
2002
2003/// Updates other participants with changes to the diagnostics
2004async fn update_diagnostic_summary(
2005 message: proto::UpdateDiagnosticSummary,
2006 session: Session,
2007) -> Result<()> {
2008 let guest_connection_ids = session
2009 .db()
2010 .await
2011 .update_diagnostic_summary(&message, session.connection_id)
2012 .await?;
2013
2014 broadcast(
2015 Some(session.connection_id),
2016 guest_connection_ids.iter().copied(),
2017 |connection_id| {
2018 session
2019 .peer
2020 .forward_send(session.connection_id, connection_id, message.clone())
2021 },
2022 );
2023
2024 Ok(())
2025}
2026
2027/// Updates other participants with changes to the worktree settings
2028async fn update_worktree_settings(
2029 message: proto::UpdateWorktreeSettings,
2030 session: Session,
2031) -> Result<()> {
2032 let guest_connection_ids = session
2033 .db()
2034 .await
2035 .update_worktree_settings(&message, session.connection_id)
2036 .await?;
2037
2038 broadcast(
2039 Some(session.connection_id),
2040 guest_connection_ids.iter().copied(),
2041 |connection_id| {
2042 session
2043 .peer
2044 .forward_send(session.connection_id, connection_id, message.clone())
2045 },
2046 );
2047
2048 Ok(())
2049}
2050
2051/// Notify other participants that a language server has started.
2052async fn start_language_server(
2053 request: proto::StartLanguageServer,
2054 session: Session,
2055) -> Result<()> {
2056 let guest_connection_ids = session
2057 .db()
2058 .await
2059 .start_language_server(&request, session.connection_id)
2060 .await?;
2061
2062 broadcast(
2063 Some(session.connection_id),
2064 guest_connection_ids.iter().copied(),
2065 |connection_id| {
2066 session
2067 .peer
2068 .forward_send(session.connection_id, connection_id, request.clone())
2069 },
2070 );
2071 Ok(())
2072}
2073
2074/// Notify other participants that a language server has changed.
2075async fn update_language_server(
2076 request: proto::UpdateLanguageServer,
2077 session: Session,
2078) -> Result<()> {
2079 let project_id = ProjectId::from_proto(request.project_id);
2080 let project_connection_ids = session
2081 .db()
2082 .await
2083 .project_connection_ids(project_id, session.connection_id, true)
2084 .await?;
2085 broadcast(
2086 Some(session.connection_id),
2087 project_connection_ids.iter().copied(),
2088 |connection_id| {
2089 session
2090 .peer
2091 .forward_send(session.connection_id, connection_id, request.clone())
2092 },
2093 );
2094 Ok(())
2095}
2096
2097/// forward a project request to the host. These requests should be read only
2098/// as guests are allowed to send them.
2099async fn forward_read_only_project_request<T>(
2100 request: T,
2101 response: Response<T>,
2102 session: Session,
2103) -> Result<()>
2104where
2105 T: EntityMessage + RequestMessage,
2106{
2107 let project_id = ProjectId::from_proto(request.remote_entity_id());
2108 let host_connection_id = session
2109 .db()
2110 .await
2111 .host_for_read_only_project_request(project_id, session.connection_id)
2112 .await?;
2113 let payload = session
2114 .peer
2115 .forward_request(session.connection_id, host_connection_id, request)
2116 .await?;
2117 response.send(payload)?;
2118 Ok(())
2119}
2120
2121async fn forward_find_search_candidates_request(
2122 request: proto::FindSearchCandidates,
2123 response: Response<proto::FindSearchCandidates>,
2124 session: Session,
2125) -> Result<()> {
2126 let project_id = ProjectId::from_proto(request.remote_entity_id());
2127 let host_connection_id = session
2128 .db()
2129 .await
2130 .host_for_read_only_project_request(project_id, session.connection_id)
2131 .await?;
2132 let payload = session
2133 .peer
2134 .forward_request(session.connection_id, host_connection_id, request)
2135 .await?;
2136 response.send(payload)?;
2137 Ok(())
2138}
2139
2140/// forward a project request to the host. These requests are disallowed
2141/// for guests.
2142async fn forward_mutating_project_request<T>(
2143 request: T,
2144 response: Response<T>,
2145 session: Session,
2146) -> Result<()>
2147where
2148 T: EntityMessage + RequestMessage,
2149{
2150 let project_id = ProjectId::from_proto(request.remote_entity_id());
2151
2152 let host_connection_id = session
2153 .db()
2154 .await
2155 .host_for_mutating_project_request(project_id, session.connection_id)
2156 .await?;
2157 let payload = session
2158 .peer
2159 .forward_request(session.connection_id, host_connection_id, request)
2160 .await?;
2161 response.send(payload)?;
2162 Ok(())
2163}
2164
2165/// Notify other participants that a new buffer has been created
2166async fn create_buffer_for_peer(
2167 request: proto::CreateBufferForPeer,
2168 session: Session,
2169) -> Result<()> {
2170 session
2171 .db()
2172 .await
2173 .check_user_is_project_host(
2174 ProjectId::from_proto(request.project_id),
2175 session.connection_id,
2176 )
2177 .await?;
2178 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2179 session
2180 .peer
2181 .forward_send(session.connection_id, peer_id.into(), request)?;
2182 Ok(())
2183}
2184
2185/// Notify other participants that a buffer has been updated. This is
2186/// allowed for guests as long as the update is limited to selections.
2187async fn update_buffer(
2188 request: proto::UpdateBuffer,
2189 response: Response<proto::UpdateBuffer>,
2190 session: Session,
2191) -> Result<()> {
2192 let project_id = ProjectId::from_proto(request.project_id);
2193 let mut capability = Capability::ReadOnly;
2194
2195 for op in request.operations.iter() {
2196 match op.variant {
2197 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2198 Some(_) => capability = Capability::ReadWrite,
2199 }
2200 }
2201
2202 let host = {
2203 let guard = session
2204 .db()
2205 .await
2206 .connections_for_buffer_update(project_id, session.connection_id, capability)
2207 .await?;
2208
2209 let (host, guests) = &*guard;
2210
2211 broadcast(
2212 Some(session.connection_id),
2213 guests.clone(),
2214 |connection_id| {
2215 session
2216 .peer
2217 .forward_send(session.connection_id, connection_id, request.clone())
2218 },
2219 );
2220
2221 *host
2222 };
2223
2224 if host != session.connection_id {
2225 session
2226 .peer
2227 .forward_request(session.connection_id, host, request.clone())
2228 .await?;
2229 }
2230
2231 response.send(proto::Ack {})?;
2232 Ok(())
2233}
2234
2235async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2236 let project_id = ProjectId::from_proto(message.project_id);
2237
2238 let operation = message.operation.as_ref().context("invalid operation")?;
2239 let capability = match operation.variant.as_ref() {
2240 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2241 if let Some(buffer_op) = buffer_op.operation.as_ref() {
2242 match buffer_op.variant {
2243 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2244 Capability::ReadOnly
2245 }
2246 _ => Capability::ReadWrite,
2247 }
2248 } else {
2249 Capability::ReadWrite
2250 }
2251 }
2252 Some(_) => Capability::ReadWrite,
2253 None => Capability::ReadOnly,
2254 };
2255
2256 let guard = session
2257 .db()
2258 .await
2259 .connections_for_buffer_update(project_id, session.connection_id, capability)
2260 .await?;
2261
2262 let (host, guests) = &*guard;
2263
2264 broadcast(
2265 Some(session.connection_id),
2266 guests.iter().chain([host]).copied(),
2267 |connection_id| {
2268 session
2269 .peer
2270 .forward_send(session.connection_id, connection_id, message.clone())
2271 },
2272 );
2273
2274 Ok(())
2275}
2276
2277/// Notify other participants that a project has been updated.
2278async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2279 request: T,
2280 session: Session,
2281) -> Result<()> {
2282 let project_id = ProjectId::from_proto(request.remote_entity_id());
2283 let project_connection_ids = session
2284 .db()
2285 .await
2286 .project_connection_ids(project_id, session.connection_id, false)
2287 .await?;
2288
2289 broadcast(
2290 Some(session.connection_id),
2291 project_connection_ids.iter().copied(),
2292 |connection_id| {
2293 session
2294 .peer
2295 .forward_send(session.connection_id, connection_id, request.clone())
2296 },
2297 );
2298 Ok(())
2299}
2300
2301/// Start following another user in a call.
2302async fn follow(
2303 request: proto::Follow,
2304 response: Response<proto::Follow>,
2305 session: Session,
2306) -> Result<()> {
2307 let room_id = RoomId::from_proto(request.room_id);
2308 let project_id = request.project_id.map(ProjectId::from_proto);
2309 let leader_id = request
2310 .leader_id
2311 .ok_or_else(|| anyhow!("invalid leader id"))?
2312 .into();
2313 let follower_id = session.connection_id;
2314
2315 session
2316 .db()
2317 .await
2318 .check_room_participants(room_id, leader_id, session.connection_id)
2319 .await?;
2320
2321 let response_payload = session
2322 .peer
2323 .forward_request(session.connection_id, leader_id, request)
2324 .await?;
2325 response.send(response_payload)?;
2326
2327 if let Some(project_id) = project_id {
2328 let room = session
2329 .db()
2330 .await
2331 .follow(room_id, project_id, leader_id, follower_id)
2332 .await?;
2333 room_updated(&room, &session.peer);
2334 }
2335
2336 Ok(())
2337}
2338
2339/// Stop following another user in a call.
2340async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2341 let room_id = RoomId::from_proto(request.room_id);
2342 let project_id = request.project_id.map(ProjectId::from_proto);
2343 let leader_id = request
2344 .leader_id
2345 .ok_or_else(|| anyhow!("invalid leader id"))?
2346 .into();
2347 let follower_id = session.connection_id;
2348
2349 session
2350 .db()
2351 .await
2352 .check_room_participants(room_id, leader_id, session.connection_id)
2353 .await?;
2354
2355 session
2356 .peer
2357 .forward_send(session.connection_id, leader_id, request)?;
2358
2359 if let Some(project_id) = project_id {
2360 let room = session
2361 .db()
2362 .await
2363 .unfollow(room_id, project_id, leader_id, follower_id)
2364 .await?;
2365 room_updated(&room, &session.peer);
2366 }
2367
2368 Ok(())
2369}
2370
2371/// Notify everyone following you of your current location.
2372async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2373 let room_id = RoomId::from_proto(request.room_id);
2374 let database = session.db.lock().await;
2375
2376 let connection_ids = if let Some(project_id) = request.project_id {
2377 let project_id = ProjectId::from_proto(project_id);
2378 database
2379 .project_connection_ids(project_id, session.connection_id, true)
2380 .await?
2381 } else {
2382 database
2383 .room_connection_ids(room_id, session.connection_id)
2384 .await?
2385 };
2386
2387 // For now, don't send view update messages back to that view's current leader.
2388 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2389 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2390 _ => None,
2391 });
2392
2393 for connection_id in connection_ids.iter().cloned() {
2394 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2395 session
2396 .peer
2397 .forward_send(session.connection_id, connection_id, request.clone())?;
2398 }
2399 }
2400 Ok(())
2401}
2402
2403/// Get public data about users.
2404async fn get_users(
2405 request: proto::GetUsers,
2406 response: Response<proto::GetUsers>,
2407 session: Session,
2408) -> Result<()> {
2409 let user_ids = request
2410 .user_ids
2411 .into_iter()
2412 .map(UserId::from_proto)
2413 .collect();
2414 let users = session
2415 .db()
2416 .await
2417 .get_users_by_ids(user_ids)
2418 .await?
2419 .into_iter()
2420 .map(|user| proto::User {
2421 id: user.id.to_proto(),
2422 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2423 github_login: user.github_login,
2424 email: user.email_address,
2425 name: user.name,
2426 })
2427 .collect();
2428 response.send(proto::UsersResponse { users })?;
2429 Ok(())
2430}
2431
2432/// Search for users (to invite) buy Github login
2433async fn fuzzy_search_users(
2434 request: proto::FuzzySearchUsers,
2435 response: Response<proto::FuzzySearchUsers>,
2436 session: Session,
2437) -> Result<()> {
2438 let query = request.query;
2439 let users = match query.len() {
2440 0 => vec![],
2441 1 | 2 => session
2442 .db()
2443 .await
2444 .get_user_by_github_login(&query)
2445 .await?
2446 .into_iter()
2447 .collect(),
2448 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2449 };
2450 let users = users
2451 .into_iter()
2452 .filter(|user| user.id != session.user_id())
2453 .map(|user| proto::User {
2454 id: user.id.to_proto(),
2455 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2456 github_login: user.github_login,
2457 name: user.name,
2458 email: user.email_address,
2459 })
2460 .collect();
2461 response.send(proto::UsersResponse { users })?;
2462 Ok(())
2463}
2464
2465/// Send a contact request to another user.
2466async fn request_contact(
2467 request: proto::RequestContact,
2468 response: Response<proto::RequestContact>,
2469 session: Session,
2470) -> Result<()> {
2471 let requester_id = session.user_id();
2472 let responder_id = UserId::from_proto(request.responder_id);
2473 if requester_id == responder_id {
2474 return Err(anyhow!("cannot add yourself as a contact"))?;
2475 }
2476
2477 let notifications = session
2478 .db()
2479 .await
2480 .send_contact_request(requester_id, responder_id)
2481 .await?;
2482
2483 // Update outgoing contact requests of requester
2484 let mut update = proto::UpdateContacts::default();
2485 update.outgoing_requests.push(responder_id.to_proto());
2486 for connection_id in session
2487 .connection_pool()
2488 .await
2489 .user_connection_ids(requester_id)
2490 {
2491 session.peer.send(connection_id, update.clone())?;
2492 }
2493
2494 // Update incoming contact requests of responder
2495 let mut update = proto::UpdateContacts::default();
2496 update
2497 .incoming_requests
2498 .push(proto::IncomingContactRequest {
2499 requester_id: requester_id.to_proto(),
2500 });
2501 let connection_pool = session.connection_pool().await;
2502 for connection_id in connection_pool.user_connection_ids(responder_id) {
2503 session.peer.send(connection_id, update.clone())?;
2504 }
2505
2506 send_notifications(&connection_pool, &session.peer, notifications);
2507
2508 response.send(proto::Ack {})?;
2509 Ok(())
2510}
2511
2512/// Accept or decline a contact request
2513async fn respond_to_contact_request(
2514 request: proto::RespondToContactRequest,
2515 response: Response<proto::RespondToContactRequest>,
2516 session: Session,
2517) -> Result<()> {
2518 let responder_id = session.user_id();
2519 let requester_id = UserId::from_proto(request.requester_id);
2520 let db = session.db().await;
2521 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2522 db.dismiss_contact_notification(responder_id, requester_id)
2523 .await?;
2524 } else {
2525 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2526
2527 let notifications = db
2528 .respond_to_contact_request(responder_id, requester_id, accept)
2529 .await?;
2530 let requester_busy = db.is_user_busy(requester_id).await?;
2531 let responder_busy = db.is_user_busy(responder_id).await?;
2532
2533 let pool = session.connection_pool().await;
2534 // Update responder with new contact
2535 let mut update = proto::UpdateContacts::default();
2536 if accept {
2537 update
2538 .contacts
2539 .push(contact_for_user(requester_id, requester_busy, &pool));
2540 }
2541 update
2542 .remove_incoming_requests
2543 .push(requester_id.to_proto());
2544 for connection_id in pool.user_connection_ids(responder_id) {
2545 session.peer.send(connection_id, update.clone())?;
2546 }
2547
2548 // Update requester with new contact
2549 let mut update = proto::UpdateContacts::default();
2550 if accept {
2551 update
2552 .contacts
2553 .push(contact_for_user(responder_id, responder_busy, &pool));
2554 }
2555 update
2556 .remove_outgoing_requests
2557 .push(responder_id.to_proto());
2558
2559 for connection_id in pool.user_connection_ids(requester_id) {
2560 session.peer.send(connection_id, update.clone())?;
2561 }
2562
2563 send_notifications(&pool, &session.peer, notifications);
2564 }
2565
2566 response.send(proto::Ack {})?;
2567 Ok(())
2568}
2569
2570/// Remove a contact.
2571async fn remove_contact(
2572 request: proto::RemoveContact,
2573 response: Response<proto::RemoveContact>,
2574 session: Session,
2575) -> Result<()> {
2576 let requester_id = session.user_id();
2577 let responder_id = UserId::from_proto(request.user_id);
2578 let db = session.db().await;
2579 let (contact_accepted, deleted_notification_id) =
2580 db.remove_contact(requester_id, responder_id).await?;
2581
2582 let pool = session.connection_pool().await;
2583 // Update outgoing contact requests of requester
2584 let mut update = proto::UpdateContacts::default();
2585 if contact_accepted {
2586 update.remove_contacts.push(responder_id.to_proto());
2587 } else {
2588 update
2589 .remove_outgoing_requests
2590 .push(responder_id.to_proto());
2591 }
2592 for connection_id in pool.user_connection_ids(requester_id) {
2593 session.peer.send(connection_id, update.clone())?;
2594 }
2595
2596 // Update incoming contact requests of responder
2597 let mut update = proto::UpdateContacts::default();
2598 if contact_accepted {
2599 update.remove_contacts.push(requester_id.to_proto());
2600 } else {
2601 update
2602 .remove_incoming_requests
2603 .push(requester_id.to_proto());
2604 }
2605 for connection_id in pool.user_connection_ids(responder_id) {
2606 session.peer.send(connection_id, update.clone())?;
2607 if let Some(notification_id) = deleted_notification_id {
2608 session.peer.send(
2609 connection_id,
2610 proto::DeleteNotification {
2611 notification_id: notification_id.to_proto(),
2612 },
2613 )?;
2614 }
2615 }
2616
2617 response.send(proto::Ack {})?;
2618 Ok(())
2619}
2620
2621fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2622 version.0.minor() < 139
2623}
2624
2625async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
2626 let plan = session.current_plan(&session.db().await).await?;
2627
2628 session
2629 .peer
2630 .send(
2631 session.connection_id,
2632 proto::UpdateUserPlan { plan: plan.into() },
2633 )
2634 .trace_err();
2635
2636 Ok(())
2637}
2638
2639async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2640 subscribe_user_to_channels(session.user_id(), &session).await?;
2641 Ok(())
2642}
2643
2644async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2645 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2646 let mut pool = session.connection_pool().await;
2647 for membership in &channels_for_user.channel_memberships {
2648 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2649 }
2650 session.peer.send(
2651 session.connection_id,
2652 build_update_user_channels(&channels_for_user),
2653 )?;
2654 session.peer.send(
2655 session.connection_id,
2656 build_channels_update(channels_for_user),
2657 )?;
2658 Ok(())
2659}
2660
2661/// Creates a new channel.
2662async fn create_channel(
2663 request: proto::CreateChannel,
2664 response: Response<proto::CreateChannel>,
2665 session: Session,
2666) -> Result<()> {
2667 let db = session.db().await;
2668
2669 let parent_id = request.parent_id.map(ChannelId::from_proto);
2670 let (channel, membership) = db
2671 .create_channel(&request.name, parent_id, session.user_id())
2672 .await?;
2673
2674 let root_id = channel.root_id();
2675 let channel = Channel::from_model(channel);
2676
2677 response.send(proto::CreateChannelResponse {
2678 channel: Some(channel.to_proto()),
2679 parent_id: request.parent_id,
2680 })?;
2681
2682 let mut connection_pool = session.connection_pool().await;
2683 if let Some(membership) = membership {
2684 connection_pool.subscribe_to_channel(
2685 membership.user_id,
2686 membership.channel_id,
2687 membership.role,
2688 );
2689 let update = proto::UpdateUserChannels {
2690 channel_memberships: vec![proto::ChannelMembership {
2691 channel_id: membership.channel_id.to_proto(),
2692 role: membership.role.into(),
2693 }],
2694 ..Default::default()
2695 };
2696 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2697 session.peer.send(connection_id, update.clone())?;
2698 }
2699 }
2700
2701 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2702 if !role.can_see_channel(channel.visibility) {
2703 continue;
2704 }
2705
2706 let update = proto::UpdateChannels {
2707 channels: vec![channel.to_proto()],
2708 ..Default::default()
2709 };
2710 session.peer.send(connection_id, update.clone())?;
2711 }
2712
2713 Ok(())
2714}
2715
2716/// Delete a channel
2717async fn delete_channel(
2718 request: proto::DeleteChannel,
2719 response: Response<proto::DeleteChannel>,
2720 session: Session,
2721) -> Result<()> {
2722 let db = session.db().await;
2723
2724 let channel_id = request.channel_id;
2725 let (root_channel, removed_channels) = db
2726 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2727 .await?;
2728 response.send(proto::Ack {})?;
2729
2730 // Notify members of removed channels
2731 let mut update = proto::UpdateChannels::default();
2732 update
2733 .delete_channels
2734 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2735
2736 let connection_pool = session.connection_pool().await;
2737 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2738 session.peer.send(connection_id, update.clone())?;
2739 }
2740
2741 Ok(())
2742}
2743
2744/// Invite someone to join a channel.
2745async fn invite_channel_member(
2746 request: proto::InviteChannelMember,
2747 response: Response<proto::InviteChannelMember>,
2748 session: Session,
2749) -> Result<()> {
2750 let db = session.db().await;
2751 let channel_id = ChannelId::from_proto(request.channel_id);
2752 let invitee_id = UserId::from_proto(request.user_id);
2753 let InviteMemberResult {
2754 channel,
2755 notifications,
2756 } = db
2757 .invite_channel_member(
2758 channel_id,
2759 invitee_id,
2760 session.user_id(),
2761 request.role().into(),
2762 )
2763 .await?;
2764
2765 let update = proto::UpdateChannels {
2766 channel_invitations: vec![channel.to_proto()],
2767 ..Default::default()
2768 };
2769
2770 let connection_pool = session.connection_pool().await;
2771 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2772 session.peer.send(connection_id, update.clone())?;
2773 }
2774
2775 send_notifications(&connection_pool, &session.peer, notifications);
2776
2777 response.send(proto::Ack {})?;
2778 Ok(())
2779}
2780
2781/// remove someone from a channel
2782async fn remove_channel_member(
2783 request: proto::RemoveChannelMember,
2784 response: Response<proto::RemoveChannelMember>,
2785 session: Session,
2786) -> Result<()> {
2787 let db = session.db().await;
2788 let channel_id = ChannelId::from_proto(request.channel_id);
2789 let member_id = UserId::from_proto(request.user_id);
2790
2791 let RemoveChannelMemberResult {
2792 membership_update,
2793 notification_id,
2794 } = db
2795 .remove_channel_member(channel_id, member_id, session.user_id())
2796 .await?;
2797
2798 let mut connection_pool = session.connection_pool().await;
2799 notify_membership_updated(
2800 &mut connection_pool,
2801 membership_update,
2802 member_id,
2803 &session.peer,
2804 );
2805 for connection_id in connection_pool.user_connection_ids(member_id) {
2806 if let Some(notification_id) = notification_id {
2807 session
2808 .peer
2809 .send(
2810 connection_id,
2811 proto::DeleteNotification {
2812 notification_id: notification_id.to_proto(),
2813 },
2814 )
2815 .trace_err();
2816 }
2817 }
2818
2819 response.send(proto::Ack {})?;
2820 Ok(())
2821}
2822
2823/// Toggle the channel between public and private.
2824/// Care is taken to maintain the invariant that public channels only descend from public channels,
2825/// (though members-only channels can appear at any point in the hierarchy).
2826async fn set_channel_visibility(
2827 request: proto::SetChannelVisibility,
2828 response: Response<proto::SetChannelVisibility>,
2829 session: Session,
2830) -> Result<()> {
2831 let db = session.db().await;
2832 let channel_id = ChannelId::from_proto(request.channel_id);
2833 let visibility = request.visibility().into();
2834
2835 let channel_model = db
2836 .set_channel_visibility(channel_id, visibility, session.user_id())
2837 .await?;
2838 let root_id = channel_model.root_id();
2839 let channel = Channel::from_model(channel_model);
2840
2841 let mut connection_pool = session.connection_pool().await;
2842 for (user_id, role) in connection_pool
2843 .channel_user_ids(root_id)
2844 .collect::<Vec<_>>()
2845 .into_iter()
2846 {
2847 let update = if role.can_see_channel(channel.visibility) {
2848 connection_pool.subscribe_to_channel(user_id, channel_id, role);
2849 proto::UpdateChannels {
2850 channels: vec![channel.to_proto()],
2851 ..Default::default()
2852 }
2853 } else {
2854 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2855 proto::UpdateChannels {
2856 delete_channels: vec![channel.id.to_proto()],
2857 ..Default::default()
2858 }
2859 };
2860
2861 for connection_id in connection_pool.user_connection_ids(user_id) {
2862 session.peer.send(connection_id, update.clone())?;
2863 }
2864 }
2865
2866 response.send(proto::Ack {})?;
2867 Ok(())
2868}
2869
2870/// Alter the role for a user in the channel.
2871async fn set_channel_member_role(
2872 request: proto::SetChannelMemberRole,
2873 response: Response<proto::SetChannelMemberRole>,
2874 session: Session,
2875) -> Result<()> {
2876 let db = session.db().await;
2877 let channel_id = ChannelId::from_proto(request.channel_id);
2878 let member_id = UserId::from_proto(request.user_id);
2879 let result = db
2880 .set_channel_member_role(
2881 channel_id,
2882 session.user_id(),
2883 member_id,
2884 request.role().into(),
2885 )
2886 .await?;
2887
2888 match result {
2889 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2890 let mut connection_pool = session.connection_pool().await;
2891 notify_membership_updated(
2892 &mut connection_pool,
2893 membership_update,
2894 member_id,
2895 &session.peer,
2896 )
2897 }
2898 db::SetMemberRoleResult::InviteUpdated(channel) => {
2899 let update = proto::UpdateChannels {
2900 channel_invitations: vec![channel.to_proto()],
2901 ..Default::default()
2902 };
2903
2904 for connection_id in session
2905 .connection_pool()
2906 .await
2907 .user_connection_ids(member_id)
2908 {
2909 session.peer.send(connection_id, update.clone())?;
2910 }
2911 }
2912 }
2913
2914 response.send(proto::Ack {})?;
2915 Ok(())
2916}
2917
2918/// Change the name of a channel
2919async fn rename_channel(
2920 request: proto::RenameChannel,
2921 response: Response<proto::RenameChannel>,
2922 session: Session,
2923) -> Result<()> {
2924 let db = session.db().await;
2925 let channel_id = ChannelId::from_proto(request.channel_id);
2926 let channel_model = db
2927 .rename_channel(channel_id, session.user_id(), &request.name)
2928 .await?;
2929 let root_id = channel_model.root_id();
2930 let channel = Channel::from_model(channel_model);
2931
2932 response.send(proto::RenameChannelResponse {
2933 channel: Some(channel.to_proto()),
2934 })?;
2935
2936 let connection_pool = session.connection_pool().await;
2937 let update = proto::UpdateChannels {
2938 channels: vec![channel.to_proto()],
2939 ..Default::default()
2940 };
2941 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2942 if role.can_see_channel(channel.visibility) {
2943 session.peer.send(connection_id, update.clone())?;
2944 }
2945 }
2946
2947 Ok(())
2948}
2949
2950/// Move a channel to a new parent.
2951async fn move_channel(
2952 request: proto::MoveChannel,
2953 response: Response<proto::MoveChannel>,
2954 session: Session,
2955) -> Result<()> {
2956 let channel_id = ChannelId::from_proto(request.channel_id);
2957 let to = ChannelId::from_proto(request.to);
2958
2959 let (root_id, channels) = session
2960 .db()
2961 .await
2962 .move_channel(channel_id, to, session.user_id())
2963 .await?;
2964
2965 let connection_pool = session.connection_pool().await;
2966 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2967 let channels = channels
2968 .iter()
2969 .filter_map(|channel| {
2970 if role.can_see_channel(channel.visibility) {
2971 Some(channel.to_proto())
2972 } else {
2973 None
2974 }
2975 })
2976 .collect::<Vec<_>>();
2977 if channels.is_empty() {
2978 continue;
2979 }
2980
2981 let update = proto::UpdateChannels {
2982 channels,
2983 ..Default::default()
2984 };
2985
2986 session.peer.send(connection_id, update.clone())?;
2987 }
2988
2989 response.send(Ack {})?;
2990 Ok(())
2991}
2992
2993/// Get the list of channel members
2994async fn get_channel_members(
2995 request: proto::GetChannelMembers,
2996 response: Response<proto::GetChannelMembers>,
2997 session: Session,
2998) -> Result<()> {
2999 let db = session.db().await;
3000 let channel_id = ChannelId::from_proto(request.channel_id);
3001 let limit = if request.limit == 0 {
3002 u16::MAX as u64
3003 } else {
3004 request.limit
3005 };
3006 let (members, users) = db
3007 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3008 .await?;
3009 response.send(proto::GetChannelMembersResponse { members, users })?;
3010 Ok(())
3011}
3012
3013/// Accept or decline a channel invitation.
3014async fn respond_to_channel_invite(
3015 request: proto::RespondToChannelInvite,
3016 response: Response<proto::RespondToChannelInvite>,
3017 session: Session,
3018) -> Result<()> {
3019 let db = session.db().await;
3020 let channel_id = ChannelId::from_proto(request.channel_id);
3021 let RespondToChannelInvite {
3022 membership_update,
3023 notifications,
3024 } = db
3025 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3026 .await?;
3027
3028 let mut connection_pool = session.connection_pool().await;
3029 if let Some(membership_update) = membership_update {
3030 notify_membership_updated(
3031 &mut connection_pool,
3032 membership_update,
3033 session.user_id(),
3034 &session.peer,
3035 );
3036 } else {
3037 let update = proto::UpdateChannels {
3038 remove_channel_invitations: vec![channel_id.to_proto()],
3039 ..Default::default()
3040 };
3041
3042 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3043 session.peer.send(connection_id, update.clone())?;
3044 }
3045 };
3046
3047 send_notifications(&connection_pool, &session.peer, notifications);
3048
3049 response.send(proto::Ack {})?;
3050
3051 Ok(())
3052}
3053
3054/// Join the channels' room
3055async fn join_channel(
3056 request: proto::JoinChannel,
3057 response: Response<proto::JoinChannel>,
3058 session: Session,
3059) -> Result<()> {
3060 let channel_id = ChannelId::from_proto(request.channel_id);
3061 join_channel_internal(channel_id, Box::new(response), session).await
3062}
3063
/// Response handles that can carry a `proto::JoinRoomResponse`.
///
/// This lets `join_channel_internal` reply to either a `JoinChannel` or a
/// `JoinRoom` request.
trait JoinChannelInternalResponse {
    /// Sends `result` as the reply to the original request.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Fully-qualified call to `Response`'s own `send`, so this does not
        // recurse into this trait method.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Fully-qualified call to `Response`'s own `send`, so this does not
        // recurse into this trait method.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3077
/// Shared implementation for joining a channel's room.
///
/// Generic over [`JoinChannelInternalResponse`] so it can reply to either a
/// `JoinChannel` or a `JoinRoom` request.
///
/// Steps:
/// 1. If `stale_room_connection` reports an old connection for this user,
///    leave that room first.
/// 2. Join the channel's room in the database.
/// 3. Build LiveKit connection info, if a LiveKit client is configured.
///    Guests get a guest token with `can_publish: false`; other roles get a
///    room token with `can_publish: true`. A token error goes through
///    `trace_err` and yields no LiveKit info.
/// 4. Reply to the caller, apply any membership update, and broadcast the room.
/// 5. Once the database guard is dropped (end of the inner block), broadcast
///    the channel update and call `update_user_contacts` for the user.
///
/// # Errors
/// Returns an error if:
/// - a database call fails;
/// - sending the response fails;
/// - leaving the stale room fails;
/// - the database did not return the joined channel. The response has
///   already been sent at that point.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database guard before leaving the stale room.
            // Presumably `leave_room_for_session` acquires the database
            // itself. The guard is re-acquired afterwards.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured, or when token
        // generation fails (the error is passed to `trace_err`).
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    // Guests get a guest token and may not publish; all other
                    // roles get a room token and may publish.
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // The database guard and the first connection-pool guard were dropped at
    // the end of the block above. A fresh pool guard is taken here.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3173
3174/// Start editing the channel notes
3175async fn join_channel_buffer(
3176 request: proto::JoinChannelBuffer,
3177 response: Response<proto::JoinChannelBuffer>,
3178 session: Session,
3179) -> Result<()> {
3180 let db = session.db().await;
3181 let channel_id = ChannelId::from_proto(request.channel_id);
3182
3183 let open_response = db
3184 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3185 .await?;
3186
3187 let collaborators = open_response.collaborators.clone();
3188 response.send(open_response)?;
3189
3190 let update = UpdateChannelBufferCollaborators {
3191 channel_id: channel_id.to_proto(),
3192 collaborators: collaborators.clone(),
3193 };
3194 channel_buffer_updated(
3195 session.connection_id,
3196 collaborators
3197 .iter()
3198 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3199 &update,
3200 &session.peer,
3201 );
3202
3203 Ok(())
3204}
3205
3206/// Edit the channel notes
3207async fn update_channel_buffer(
3208 request: proto::UpdateChannelBuffer,
3209 session: Session,
3210) -> Result<()> {
3211 let db = session.db().await;
3212 let channel_id = ChannelId::from_proto(request.channel_id);
3213
3214 let (collaborators, epoch, version) = db
3215 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3216 .await?;
3217
3218 channel_buffer_updated(
3219 session.connection_id,
3220 collaborators.clone(),
3221 &proto::UpdateChannelBuffer {
3222 channel_id: channel_id.to_proto(),
3223 operations: request.operations,
3224 },
3225 &session.peer,
3226 );
3227
3228 let pool = &*session.connection_pool().await;
3229
3230 let non_collaborators =
3231 pool.channel_connection_ids(channel_id)
3232 .filter_map(|(connection_id, _)| {
3233 if collaborators.contains(&connection_id) {
3234 None
3235 } else {
3236 Some(connection_id)
3237 }
3238 });
3239
3240 broadcast(None, non_collaborators, |peer_id| {
3241 session.peer.send(
3242 peer_id,
3243 proto::UpdateChannels {
3244 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3245 channel_id: channel_id.to_proto(),
3246 epoch: epoch as u64,
3247 version: version.clone(),
3248 }],
3249 ..Default::default()
3250 },
3251 )
3252 });
3253
3254 Ok(())
3255}
3256
3257/// Rejoin the channel notes after a connection blip
3258async fn rejoin_channel_buffers(
3259 request: proto::RejoinChannelBuffers,
3260 response: Response<proto::RejoinChannelBuffers>,
3261 session: Session,
3262) -> Result<()> {
3263 let db = session.db().await;
3264 let buffers = db
3265 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3266 .await?;
3267
3268 for rejoined_buffer in &buffers {
3269 let collaborators_to_notify = rejoined_buffer
3270 .buffer
3271 .collaborators
3272 .iter()
3273 .filter_map(|c| Some(c.peer_id?.into()));
3274 channel_buffer_updated(
3275 session.connection_id,
3276 collaborators_to_notify,
3277 &proto::UpdateChannelBufferCollaborators {
3278 channel_id: rejoined_buffer.buffer.channel_id,
3279 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3280 },
3281 &session.peer,
3282 );
3283 }
3284
3285 response.send(proto::RejoinChannelBuffersResponse {
3286 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3287 })?;
3288
3289 Ok(())
3290}
3291
3292/// Stop editing the channel notes
3293async fn leave_channel_buffer(
3294 request: proto::LeaveChannelBuffer,
3295 response: Response<proto::LeaveChannelBuffer>,
3296 session: Session,
3297) -> Result<()> {
3298 let db = session.db().await;
3299 let channel_id = ChannelId::from_proto(request.channel_id);
3300
3301 let left_buffer = db
3302 .leave_channel_buffer(channel_id, session.connection_id)
3303 .await?;
3304
3305 response.send(Ack {})?;
3306
3307 channel_buffer_updated(
3308 session.connection_id,
3309 left_buffer.connections,
3310 &proto::UpdateChannelBufferCollaborators {
3311 channel_id: channel_id.to_proto(),
3312 collaborators: left_buffer.collaborators,
3313 },
3314 &session.peer,
3315 );
3316
3317 Ok(())
3318}
3319
3320fn channel_buffer_updated<T: EnvelopedMessage>(
3321 sender_id: ConnectionId,
3322 collaborators: impl IntoIterator<Item = ConnectionId>,
3323 message: &T,
3324 peer: &Peer,
3325) {
3326 broadcast(Some(sender_id), collaborators, |peer_id| {
3327 peer.send(peer_id, message.clone())
3328 });
3329}
3330
3331fn send_notifications(
3332 connection_pool: &ConnectionPool,
3333 peer: &Peer,
3334 notifications: db::NotificationBatch,
3335) {
3336 for (user_id, notification) in notifications {
3337 for connection_id in connection_pool.user_connection_ids(user_id) {
3338 if let Err(error) = peer.send(
3339 connection_id,
3340 proto::AddNotification {
3341 notification: Some(notification.clone()),
3342 },
3343 ) {
3344 tracing::error!(
3345 "failed to send notification to {:?} {}",
3346 connection_id,
3347 error
3348 );
3349 }
3350 }
3351 }
3352}
3353
3354/// Send a message to the channel
3355async fn send_channel_message(
3356 request: proto::SendChannelMessage,
3357 response: Response<proto::SendChannelMessage>,
3358 session: Session,
3359) -> Result<()> {
3360 // Validate the message body.
3361 let body = request.body.trim().to_string();
3362 if body.len() > MAX_MESSAGE_LEN {
3363 return Err(anyhow!("message is too long"))?;
3364 }
3365 if body.is_empty() {
3366 return Err(anyhow!("message can't be blank"))?;
3367 }
3368
3369 // TODO: adjust mentions if body is trimmed
3370
3371 let timestamp = OffsetDateTime::now_utc();
3372 let nonce = request
3373 .nonce
3374 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3375
3376 let channel_id = ChannelId::from_proto(request.channel_id);
3377 let CreatedChannelMessage {
3378 message_id,
3379 participant_connection_ids,
3380 notifications,
3381 } = session
3382 .db()
3383 .await
3384 .create_channel_message(
3385 channel_id,
3386 session.user_id(),
3387 &body,
3388 &request.mentions,
3389 timestamp,
3390 nonce.clone().into(),
3391 request.reply_to_message_id.map(MessageId::from_proto),
3392 )
3393 .await?;
3394
3395 let message = proto::ChannelMessage {
3396 sender_id: session.user_id().to_proto(),
3397 id: message_id.to_proto(),
3398 body,
3399 mentions: request.mentions,
3400 timestamp: timestamp.unix_timestamp() as u64,
3401 nonce: Some(nonce),
3402 reply_to_message_id: request.reply_to_message_id,
3403 edited_at: None,
3404 };
3405 broadcast(
3406 Some(session.connection_id),
3407 participant_connection_ids.clone(),
3408 |connection| {
3409 session.peer.send(
3410 connection,
3411 proto::ChannelMessageSent {
3412 channel_id: channel_id.to_proto(),
3413 message: Some(message.clone()),
3414 },
3415 )
3416 },
3417 );
3418 response.send(proto::SendChannelMessageResponse {
3419 message: Some(message),
3420 })?;
3421
3422 let pool = &*session.connection_pool().await;
3423 let non_participants =
3424 pool.channel_connection_ids(channel_id)
3425 .filter_map(|(connection_id, _)| {
3426 if participant_connection_ids.contains(&connection_id) {
3427 None
3428 } else {
3429 Some(connection_id)
3430 }
3431 });
3432 broadcast(None, non_participants, |peer_id| {
3433 session.peer.send(
3434 peer_id,
3435 proto::UpdateChannels {
3436 latest_channel_message_ids: vec![proto::ChannelMessageId {
3437 channel_id: channel_id.to_proto(),
3438 message_id: message_id.to_proto(),
3439 }],
3440 ..Default::default()
3441 },
3442 )
3443 });
3444 send_notifications(pool, &session.peer, notifications);
3445
3446 Ok(())
3447}
3448
3449/// Delete a channel message
3450async fn remove_channel_message(
3451 request: proto::RemoveChannelMessage,
3452 response: Response<proto::RemoveChannelMessage>,
3453 session: Session,
3454) -> Result<()> {
3455 let channel_id = ChannelId::from_proto(request.channel_id);
3456 let message_id = MessageId::from_proto(request.message_id);
3457 let (connection_ids, existing_notification_ids) = session
3458 .db()
3459 .await
3460 .remove_channel_message(channel_id, message_id, session.user_id())
3461 .await?;
3462
3463 broadcast(
3464 Some(session.connection_id),
3465 connection_ids,
3466 move |connection| {
3467 session.peer.send(connection, request.clone())?;
3468
3469 for notification_id in &existing_notification_ids {
3470 session.peer.send(
3471 connection,
3472 proto::DeleteNotification {
3473 notification_id: (*notification_id).to_proto(),
3474 },
3475 )?;
3476 }
3477
3478 Ok(())
3479 },
3480 );
3481 response.send(proto::Ack {})?;
3482 Ok(())
3483}
3484
3485async fn update_channel_message(
3486 request: proto::UpdateChannelMessage,
3487 response: Response<proto::UpdateChannelMessage>,
3488 session: Session,
3489) -> Result<()> {
3490 let channel_id = ChannelId::from_proto(request.channel_id);
3491 let message_id = MessageId::from_proto(request.message_id);
3492 let updated_at = OffsetDateTime::now_utc();
3493 let UpdatedChannelMessage {
3494 message_id,
3495 participant_connection_ids,
3496 notifications,
3497 reply_to_message_id,
3498 timestamp,
3499 deleted_mention_notification_ids,
3500 updated_mention_notifications,
3501 } = session
3502 .db()
3503 .await
3504 .update_channel_message(
3505 channel_id,
3506 message_id,
3507 session.user_id(),
3508 request.body.as_str(),
3509 &request.mentions,
3510 updated_at,
3511 )
3512 .await?;
3513
3514 let nonce = request
3515 .nonce
3516 .clone()
3517 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3518
3519 let message = proto::ChannelMessage {
3520 sender_id: session.user_id().to_proto(),
3521 id: message_id.to_proto(),
3522 body: request.body.clone(),
3523 mentions: request.mentions.clone(),
3524 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3525 nonce: Some(nonce),
3526 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3527 edited_at: Some(updated_at.unix_timestamp() as u64),
3528 };
3529
3530 response.send(proto::Ack {})?;
3531
3532 let pool = &*session.connection_pool().await;
3533 broadcast(
3534 Some(session.connection_id),
3535 participant_connection_ids,
3536 |connection| {
3537 session.peer.send(
3538 connection,
3539 proto::ChannelMessageUpdate {
3540 channel_id: channel_id.to_proto(),
3541 message: Some(message.clone()),
3542 },
3543 )?;
3544
3545 for notification_id in &deleted_mention_notification_ids {
3546 session.peer.send(
3547 connection,
3548 proto::DeleteNotification {
3549 notification_id: (*notification_id).to_proto(),
3550 },
3551 )?;
3552 }
3553
3554 for notification in &updated_mention_notifications {
3555 session.peer.send(
3556 connection,
3557 proto::UpdateNotification {
3558 notification: Some(notification.clone()),
3559 },
3560 )?;
3561 }
3562
3563 Ok(())
3564 },
3565 );
3566
3567 send_notifications(pool, &session.peer, notifications);
3568
3569 Ok(())
3570}
3571
3572/// Mark a channel message as read
3573async fn acknowledge_channel_message(
3574 request: proto::AckChannelMessage,
3575 session: Session,
3576) -> Result<()> {
3577 let channel_id = ChannelId::from_proto(request.channel_id);
3578 let message_id = MessageId::from_proto(request.message_id);
3579 let notifications = session
3580 .db()
3581 .await
3582 .observe_channel_message(channel_id, session.user_id(), message_id)
3583 .await?;
3584 send_notifications(
3585 &*session.connection_pool().await,
3586 &session.peer,
3587 notifications,
3588 );
3589 Ok(())
3590}
3591
3592/// Mark a buffer version as synced
3593async fn acknowledge_buffer_version(
3594 request: proto::AckBufferOperation,
3595 session: Session,
3596) -> Result<()> {
3597 let buffer_id = BufferId::from_proto(request.buffer_id);
3598 session
3599 .db()
3600 .await
3601 .observe_buffer_version(
3602 buffer_id,
3603 session.user_id(),
3604 request.epoch as i32,
3605 &request.version,
3606 )
3607 .await?;
3608 Ok(())
3609}
3610
3611async fn count_language_model_tokens(
3612 request: proto::CountLanguageModelTokens,
3613 response: Response<proto::CountLanguageModelTokens>,
3614 session: Session,
3615 config: &Config,
3616) -> Result<()> {
3617 authorize_access_to_legacy_llm_endpoints(&session).await?;
3618
3619 let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3620 proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
3621 proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
3622 };
3623
3624 session
3625 .app_state
3626 .rate_limiter
3627 .check(&*rate_limit, session.user_id())
3628 .await?;
3629
3630 let result = match proto::LanguageModelProvider::from_i32(request.provider) {
3631 Some(proto::LanguageModelProvider::Google) => {
3632 let api_key = config
3633 .google_ai_api_key
3634 .as_ref()
3635 .context("no Google AI API key configured on the server")?;
3636 google_ai::count_tokens(
3637 session.http_client.as_ref(),
3638 google_ai::API_URL,
3639 api_key,
3640 serde_json::from_str(&request.request)?,
3641 )
3642 .await?
3643 }
3644 _ => return Err(anyhow!("unsupported provider"))?,
3645 };
3646
3647 response.send(proto::CountLanguageModelTokensResponse {
3648 token_count: result.total_tokens as u32,
3649 })?;
3650
3651 Ok(())
3652}
3653
3654struct ZedProCountLanguageModelTokensRateLimit;
3655
3656impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
3657 fn capacity(&self) -> usize {
3658 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
3659 .ok()
3660 .and_then(|v| v.parse().ok())
3661 .unwrap_or(600) // Picked arbitrarily
3662 }
3663
3664 fn refill_duration(&self) -> chrono::Duration {
3665 chrono::Duration::hours(1)
3666 }
3667
3668 fn db_name(&self) -> &'static str {
3669 "zed-pro:count-language-model-tokens"
3670 }
3671}
3672
3673struct FreeCountLanguageModelTokensRateLimit;
3674
3675impl RateLimit for FreeCountLanguageModelTokensRateLimit {
3676 fn capacity(&self) -> usize {
3677 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
3678 .ok()
3679 .and_then(|v| v.parse().ok())
3680 .unwrap_or(600 / 10) // Picked arbitrarily
3681 }
3682
3683 fn refill_duration(&self) -> chrono::Duration {
3684 chrono::Duration::hours(1)
3685 }
3686
3687 fn db_name(&self) -> &'static str {
3688 "free:count-language-model-tokens"
3689 }
3690}
3691
3692struct ZedProComputeEmbeddingsRateLimit;
3693
3694impl RateLimit for ZedProComputeEmbeddingsRateLimit {
3695 fn capacity(&self) -> usize {
3696 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
3697 .ok()
3698 .and_then(|v| v.parse().ok())
3699 .unwrap_or(5000) // Picked arbitrarily
3700 }
3701
3702 fn refill_duration(&self) -> chrono::Duration {
3703 chrono::Duration::hours(1)
3704 }
3705
3706 fn db_name(&self) -> &'static str {
3707 "zed-pro:compute-embeddings"
3708 }
3709}
3710
3711struct FreeComputeEmbeddingsRateLimit;
3712
3713impl RateLimit for FreeComputeEmbeddingsRateLimit {
3714 fn capacity(&self) -> usize {
3715 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
3716 .ok()
3717 .and_then(|v| v.parse().ok())
3718 .unwrap_or(5000 / 10) // Picked arbitrarily
3719 }
3720
3721 fn refill_duration(&self) -> chrono::Duration {
3722 chrono::Duration::hours(1)
3723 }
3724
3725 fn db_name(&self) -> &'static str {
3726 "free:compute-embeddings"
3727 }
3728}
3729
3730async fn compute_embeddings(
3731 request: proto::ComputeEmbeddings,
3732 response: Response<proto::ComputeEmbeddings>,
3733 session: Session,
3734 api_key: Option<Arc<str>>,
3735) -> Result<()> {
3736 let api_key = api_key.context("no OpenAI API key configured on the server")?;
3737 authorize_access_to_legacy_llm_endpoints(&session).await?;
3738
3739 let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3740 proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
3741 proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
3742 };
3743
3744 session
3745 .app_state
3746 .rate_limiter
3747 .check(&*rate_limit, session.user_id())
3748 .await?;
3749
3750 let embeddings = match request.model.as_str() {
3751 "openai/text-embedding-3-small" => {
3752 open_ai::embed(
3753 session.http_client.as_ref(),
3754 OPEN_AI_API_URL,
3755 &api_key,
3756 OpenAiEmbeddingModel::TextEmbedding3Small,
3757 request.texts.iter().map(|text| text.as_str()),
3758 )
3759 .await?
3760 }
3761 provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
3762 };
3763
3764 let embeddings = request
3765 .texts
3766 .iter()
3767 .map(|text| {
3768 let mut hasher = sha2::Sha256::new();
3769 hasher.update(text.as_bytes());
3770 let result = hasher.finalize();
3771 result.to_vec()
3772 })
3773 .zip(
3774 embeddings
3775 .data
3776 .into_iter()
3777 .map(|embedding| embedding.embedding),
3778 )
3779 .collect::<HashMap<_, _>>();
3780
3781 let db = session.db().await;
3782 db.save_embeddings(&request.model, &embeddings)
3783 .await
3784 .context("failed to save embeddings")
3785 .trace_err();
3786
3787 response.send(proto::ComputeEmbeddingsResponse {
3788 embeddings: embeddings
3789 .into_iter()
3790 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3791 .collect(),
3792 })?;
3793 Ok(())
3794}
3795
3796async fn get_cached_embeddings(
3797 request: proto::GetCachedEmbeddings,
3798 response: Response<proto::GetCachedEmbeddings>,
3799 session: Session,
3800) -> Result<()> {
3801 authorize_access_to_legacy_llm_endpoints(&session).await?;
3802
3803 let db = session.db().await;
3804 let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
3805
3806 response.send(proto::GetCachedEmbeddingsResponse {
3807 embeddings: embeddings
3808 .into_iter()
3809 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3810 .collect(),
3811 })?;
3812 Ok(())
3813}
3814
3815/// This is leftover from before the LLM service.
3816///
3817/// The endpoints protected by this check will be moved there eventually.
3818async fn authorize_access_to_legacy_llm_endpoints(session: &Session) -> Result<(), Error> {
3819 if session.is_staff() {
3820 Ok(())
3821 } else {
3822 Err(anyhow!("permission denied"))?
3823 }
3824}
3825
3826/// Get a Supermaven API key for the user
3827async fn get_supermaven_api_key(
3828 _request: proto::GetSupermavenApiKey,
3829 response: Response<proto::GetSupermavenApiKey>,
3830 session: Session,
3831) -> Result<()> {
3832 let user_id: String = session.user_id().to_string();
3833 if !session.is_staff() {
3834 return Err(anyhow!("supermaven not enabled for this account"))?;
3835 }
3836
3837 let email = session
3838 .email()
3839 .ok_or_else(|| anyhow!("user must have an email"))?;
3840
3841 let supermaven_admin_api = session
3842 .supermaven_client
3843 .as_ref()
3844 .ok_or_else(|| anyhow!("supermaven not configured"))?;
3845
3846 let result = supermaven_admin_api
3847 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3848 .await?;
3849
3850 response.send(proto::GetSupermavenApiKeyResponse {
3851 api_key: result.api_key,
3852 })?;
3853
3854 Ok(())
3855}
3856
3857/// Start receiving chat updates for a channel
3858async fn join_channel_chat(
3859 request: proto::JoinChannelChat,
3860 response: Response<proto::JoinChannelChat>,
3861 session: Session,
3862) -> Result<()> {
3863 let channel_id = ChannelId::from_proto(request.channel_id);
3864
3865 let db = session.db().await;
3866 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3867 .await?;
3868 let messages = db
3869 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3870 .await?;
3871 response.send(proto::JoinChannelChatResponse {
3872 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3873 messages,
3874 })?;
3875 Ok(())
3876}
3877
3878/// Stop receiving chat updates for a channel
3879async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3880 let channel_id = ChannelId::from_proto(request.channel_id);
3881 session
3882 .db()
3883 .await
3884 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3885 .await?;
3886 Ok(())
3887}
3888
3889/// Retrieve the chat history for a channel
3890async fn get_channel_messages(
3891 request: proto::GetChannelMessages,
3892 response: Response<proto::GetChannelMessages>,
3893 session: Session,
3894) -> Result<()> {
3895 let channel_id = ChannelId::from_proto(request.channel_id);
3896 let messages = session
3897 .db()
3898 .await
3899 .get_channel_messages(
3900 channel_id,
3901 session.user_id(),
3902 MESSAGE_COUNT_PER_PAGE,
3903 Some(MessageId::from_proto(request.before_message_id)),
3904 )
3905 .await?;
3906 response.send(proto::GetChannelMessagesResponse {
3907 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3908 messages,
3909 })?;
3910 Ok(())
3911}
3912
3913/// Retrieve specific chat messages
3914async fn get_channel_messages_by_id(
3915 request: proto::GetChannelMessagesById,
3916 response: Response<proto::GetChannelMessagesById>,
3917 session: Session,
3918) -> Result<()> {
3919 let message_ids = request
3920 .message_ids
3921 .iter()
3922 .map(|id| MessageId::from_proto(*id))
3923 .collect::<Vec<_>>();
3924 let messages = session
3925 .db()
3926 .await
3927 .get_channel_messages_by_id(session.user_id(), &message_ids)
3928 .await?;
3929 response.send(proto::GetChannelMessagesResponse {
3930 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3931 messages,
3932 })?;
3933 Ok(())
3934}
3935
3936/// Retrieve the current users notifications
3937async fn get_notifications(
3938 request: proto::GetNotifications,
3939 response: Response<proto::GetNotifications>,
3940 session: Session,
3941) -> Result<()> {
3942 let notifications = session
3943 .db()
3944 .await
3945 .get_notifications(
3946 session.user_id(),
3947 NOTIFICATION_COUNT_PER_PAGE,
3948 request.before_id.map(db::NotificationId::from_proto),
3949 )
3950 .await?;
3951 response.send(proto::GetNotificationsResponse {
3952 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3953 notifications,
3954 })?;
3955 Ok(())
3956}
3957
3958/// Mark notifications as read
3959async fn mark_notification_as_read(
3960 request: proto::MarkNotificationRead,
3961 response: Response<proto::MarkNotificationRead>,
3962 session: Session,
3963) -> Result<()> {
3964 let database = &session.db().await;
3965 let notifications = database
3966 .mark_notification_as_read_by_id(
3967 session.user_id(),
3968 NotificationId::from_proto(request.notification_id),
3969 )
3970 .await?;
3971 send_notifications(
3972 &*session.connection_pool().await,
3973 &session.peer,
3974 notifications,
3975 );
3976 response.send(proto::Ack {})?;
3977 Ok(())
3978}
3979
3980/// Get the current users information
3981async fn get_private_user_info(
3982 _request: proto::GetPrivateUserInfo,
3983 response: Response<proto::GetPrivateUserInfo>,
3984 session: Session,
3985) -> Result<()> {
3986 let db = session.db().await;
3987
3988 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
3989 let user = db
3990 .get_user_by_id(session.user_id())
3991 .await?
3992 .ok_or_else(|| anyhow!("user not found"))?;
3993 let flags = db.get_user_flags(session.user_id()).await?;
3994
3995 response.send(proto::GetPrivateUserInfoResponse {
3996 metrics_id,
3997 staff: user.admin,
3998 flags,
3999 accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4000 })?;
4001 Ok(())
4002}
4003
4004/// Accept the terms of service (tos) on behalf of the current user
4005async fn accept_terms_of_service(
4006 _request: proto::AcceptTermsOfService,
4007 response: Response<proto::AcceptTermsOfService>,
4008 session: Session,
4009) -> Result<()> {
4010 let db = session.db().await;
4011
4012 let accepted_tos_at = Utc::now();
4013 db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4014 .await?;
4015
4016 response.send(proto::AcceptTermsOfServiceResponse {
4017 accepted_tos_at: accepted_tos_at.timestamp() as u64,
4018 })?;
4019 Ok(())
4020}
4021
/// The minimum account age an account must have in order to use the LLM service.
///
/// `get_llm_api_token` compares this against the older of the user's
/// `created_at` and `github_user_created_at`. Users with an LLM subscription
/// or the `bypass-account-age-check` feature flag skip the check.
const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4024
4025async fn get_llm_api_token(
4026 _request: proto::GetLlmToken,
4027 response: Response<proto::GetLlmToken>,
4028 session: Session,
4029) -> Result<()> {
4030 let db = session.db().await;
4031
4032 let flags = db.get_user_flags(session.user_id()).await?;
4033 let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
4034 let has_llm_closed_beta_feature_flag = flags.iter().any(|flag| flag == "llm-closed-beta");
4035 let has_predict_edits_feature_flag = flags.iter().any(|flag| flag == "predict-edits");
4036
4037 if !session.is_staff() && !has_language_models_feature_flag {
4038 Err(anyhow!("permission denied"))?
4039 }
4040
4041 let user_id = session.user_id();
4042 let user = db
4043 .get_user_by_id(user_id)
4044 .await?
4045 .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4046
4047 if user.accepted_tos_at.is_none() {
4048 Err(anyhow!("terms of service not accepted"))?
4049 }
4050
4051 let has_llm_subscription = session.has_llm_subscription(&db).await?;
4052
4053 let bypass_account_age_check =
4054 has_llm_subscription || flags.iter().any(|flag| flag == "bypass-account-age-check");
4055 if !bypass_account_age_check {
4056 let mut account_created_at = user.created_at;
4057 if let Some(github_created_at) = user.github_user_created_at {
4058 account_created_at = account_created_at.min(github_created_at);
4059 }
4060 if Utc::now().naive_utc() - account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE {
4061 Err(anyhow!("account too young"))?
4062 }
4063 }
4064
4065 let billing_preferences = db.get_billing_preferences(user.id).await?;
4066
4067 let token = LlmTokenClaims::create(
4068 &user,
4069 session.is_staff(),
4070 billing_preferences,
4071 has_llm_closed_beta_feature_flag,
4072 has_predict_edits_feature_flag,
4073 has_llm_subscription,
4074 session.current_plan(&db).await?,
4075 session.system_id.clone(),
4076 &session.app_state.config,
4077 )?;
4078 response.send(proto::GetLlmTokenResponse { token })?;
4079 Ok(())
4080}
4081
4082fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4083 let message = match message {
4084 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4085 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4086 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4087 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4088 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4089 code: frame.code.into(),
4090 reason: frame.reason,
4091 })),
4092 // We should never receive a frame while reading the message, according
4093 // to the `tungstenite` maintainers:
4094 //
4095 // > It cannot occur when you read messages from the WebSocket, but it
4096 // > can be used when you want to send the raw frames (e.g. you want to
4097 // > send the frames to the WebSocket without composing the full message first).
4098 // >
4099 // > — https://github.com/snapview/tungstenite-rs/issues/268
4100 TungsteniteMessage::Frame(_) => {
4101 bail!("received an unexpected frame while reading the message")
4102 }
4103 };
4104
4105 Ok(message)
4106}
4107
4108fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4109 match message {
4110 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4111 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4112 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4113 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4114 AxumMessage::Close(frame) => {
4115 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4116 code: frame.code.into(),
4117 reason: frame.reason,
4118 }))
4119 }
4120 }
4121}
4122
4123fn notify_membership_updated(
4124 connection_pool: &mut ConnectionPool,
4125 result: MembershipUpdated,
4126 user_id: UserId,
4127 peer: &Peer,
4128) {
4129 for membership in &result.new_channels.channel_memberships {
4130 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4131 }
4132 for channel_id in &result.removed_channels {
4133 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4134 }
4135
4136 let user_channels_update = proto::UpdateUserChannels {
4137 channel_memberships: result
4138 .new_channels
4139 .channel_memberships
4140 .iter()
4141 .map(|cm| proto::ChannelMembership {
4142 channel_id: cm.channel_id.to_proto(),
4143 role: cm.role.into(),
4144 })
4145 .collect(),
4146 ..Default::default()
4147 };
4148
4149 let mut update = build_channels_update(result.new_channels);
4150 update.delete_channels = result
4151 .removed_channels
4152 .into_iter()
4153 .map(|id| id.to_proto())
4154 .collect();
4155 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4156
4157 for connection_id in connection_pool.user_connection_ids(user_id) {
4158 peer.send(connection_id, user_channels_update.clone())
4159 .trace_err();
4160 peer.send(connection_id, update.clone()).trace_err();
4161 }
4162}
4163
4164fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4165 proto::UpdateUserChannels {
4166 channel_memberships: channels
4167 .channel_memberships
4168 .iter()
4169 .map(|m| proto::ChannelMembership {
4170 channel_id: m.channel_id.to_proto(),
4171 role: m.role.into(),
4172 })
4173 .collect(),
4174 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4175 observed_channel_message_id: channels.observed_channel_messages.clone(),
4176 }
4177}
4178
4179fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4180 let mut update = proto::UpdateChannels::default();
4181
4182 for channel in channels.channels {
4183 update.channels.push(channel.to_proto());
4184 }
4185
4186 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4187 update.latest_channel_message_ids = channels.latest_channel_messages;
4188
4189 for (channel_id, participants) in channels.channel_participants {
4190 update
4191 .channel_participants
4192 .push(proto::ChannelParticipants {
4193 channel_id: channel_id.to_proto(),
4194 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4195 });
4196 }
4197
4198 for channel in channels.invited_channels {
4199 update.channel_invitations.push(channel.to_proto());
4200 }
4201
4202 update
4203}
4204
4205fn build_initial_contacts_update(
4206 contacts: Vec<db::Contact>,
4207 pool: &ConnectionPool,
4208) -> proto::UpdateContacts {
4209 let mut update = proto::UpdateContacts::default();
4210
4211 for contact in contacts {
4212 match contact {
4213 db::Contact::Accepted { user_id, busy } => {
4214 update.contacts.push(contact_for_user(user_id, busy, pool));
4215 }
4216 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4217 db::Contact::Incoming { user_id } => {
4218 update
4219 .incoming_requests
4220 .push(proto::IncomingContactRequest {
4221 requester_id: user_id.to_proto(),
4222 })
4223 }
4224 }
4225 }
4226
4227 update
4228}
4229
4230fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4231 proto::Contact {
4232 user_id: user_id.to_proto(),
4233 online: pool.is_user_online(user_id),
4234 busy,
4235 }
4236}
4237
4238fn room_updated(room: &proto::Room, peer: &Peer) {
4239 broadcast(
4240 None,
4241 room.participants
4242 .iter()
4243 .filter_map(|participant| Some(participant.peer_id?.into())),
4244 |peer_id| {
4245 peer.send(
4246 peer_id,
4247 proto::RoomUpdated {
4248 room: Some(room.clone()),
4249 },
4250 )
4251 },
4252 );
4253}
4254
4255fn channel_updated(
4256 channel: &db::channel::Model,
4257 room: &proto::Room,
4258 peer: &Peer,
4259 pool: &ConnectionPool,
4260) {
4261 let participants = room
4262 .participants
4263 .iter()
4264 .map(|p| p.user_id)
4265 .collect::<Vec<_>>();
4266
4267 broadcast(
4268 None,
4269 pool.channel_connection_ids(channel.root_id())
4270 .filter_map(|(channel_id, role)| {
4271 role.can_see_channel(channel.visibility)
4272 .then_some(channel_id)
4273 }),
4274 |peer_id| {
4275 peer.send(
4276 peer_id,
4277 proto::UpdateChannels {
4278 channel_participants: vec![proto::ChannelParticipants {
4279 channel_id: channel.id.to_proto(),
4280 participant_user_ids: participants.clone(),
4281 }],
4282 ..Default::default()
4283 },
4284 )
4285 },
4286 );
4287}
4288
4289async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4290 let db = session.db().await;
4291
4292 let contacts = db.get_contacts(user_id).await?;
4293 let busy = db.is_user_busy(user_id).await?;
4294
4295 let pool = session.connection_pool().await;
4296 let updated_contact = contact_for_user(user_id, busy, &pool);
4297 for contact in contacts {
4298 if let db::Contact::Accepted {
4299 user_id: contact_user_id,
4300 ..
4301 } = contact
4302 {
4303 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4304 session
4305 .peer
4306 .send(
4307 contact_conn_id,
4308 proto::UpdateContacts {
4309 contacts: vec![updated_contact.clone()],
4310 remove_contacts: Default::default(),
4311 incoming_requests: Default::default(),
4312 remove_incoming_requests: Default::default(),
4313 outgoing_requests: Default::default(),
4314 remove_outgoing_requests: Default::default(),
4315 },
4316 )
4317 .trace_err();
4318 }
4319 }
4320 }
4321 Ok(())
4322}
4323
/// Removes `connection_id` from the room it is in, if any, and propagates the
/// consequences.
///
/// When `leave_room` returns a left room, this:
/// - calls `project_left` for every project in `left_projects`,
/// - broadcasts the updated room and, if `leave_room` returned a channel, that
///   channel's participant list,
/// - sends `CallCanceled` to every connection of each user in
///   `canceled_calls_to_user_ids`,
/// - refreshes contact entries for this user and those users,
/// - removes the user from the LiveKit room via `remove_participant`, and
///   deletes that LiveKit room when the left room is marked `deleted`.
///   LiveKit errors are only logged (`trace_err`).
///
/// Returns `Ok(())` without doing anything when `leave_room` returns `None`.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Declared here and assigned inside the `if let` below.
    // NOTE(review): presumably done so the value produced by
    // `session.db().await` (likely a lock guard — confirm) lives only for the
    // `if let` rather than for the rest of this function.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Must be taken before `left_room.room` itself is moved out below.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        // `leave_room` returned `None`: nothing to clean up.
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Tell each callee whose call was canceled, on all of their connections.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    // LiveKit cleanup is best-effort; failures are logged, not returned.
    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4397
4398async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4399 let left_channel_buffers = session
4400 .db()
4401 .await
4402 .leave_channel_buffers(session.connection_id)
4403 .await?;
4404
4405 for left_buffer in left_channel_buffers {
4406 channel_buffer_updated(
4407 session.connection_id,
4408 left_buffer.connections,
4409 &proto::UpdateChannelBufferCollaborators {
4410 channel_id: left_buffer.channel_id.to_proto(),
4411 collaborators: left_buffer.collaborators,
4412 },
4413 &session.peer,
4414 );
4415 }
4416
4417 Ok(())
4418}
4419
4420fn project_left(project: &db::LeftProject, session: &Session) {
4421 for connection_id in &project.connection_ids {
4422 if project.should_unshare {
4423 session
4424 .peer
4425 .send(
4426 *connection_id,
4427 proto::UnshareProject {
4428 project_id: project.id.to_proto(),
4429 },
4430 )
4431 .trace_err();
4432 } else {
4433 session
4434 .peer
4435 .send(
4436 *connection_id,
4437 proto::RemoveProjectCollaborator {
4438 project_id: project.id.to_proto(),
4439 peer_id: Some(session.connection_id.into()),
4440 },
4441 )
4442 .trace_err();
4443 }
4444 }
4445}
4446
4447pub trait ResultExt {
4448 type Ok;
4449
4450 fn trace_err(self) -> Option<Self::Ok>;
4451}
4452
4453impl<T, E> ResultExt for Result<T, E>
4454where
4455 E: std::fmt::Debug,
4456{
4457 type Ok = T;
4458
4459 #[track_caller]
4460 fn trace_err(self) -> Option<T> {
4461 match self {
4462 Ok(value) => Some(value),
4463 Err(error) => {
4464 tracing::error!("{:?}", error);
4465 None
4466 }
4467 }
4468 }
4469}