1mod connection_pool;
2
3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
4use crate::llm::LlmTokenClaims;
5use crate::{
6 auth,
7 db::{
8 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
9 CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
10 NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
11 RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
12 },
13 executor::Executor,
14 AppState, Config, Error, RateLimit, Result,
15};
16use anyhow::{anyhow, bail, Context as _};
17use async_tungstenite::tungstenite::{
18 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
19};
20use axum::{
21 body::Body,
22 extract::{
23 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
24 ConnectInfo, WebSocketUpgrade,
25 },
26 headers::{Header, HeaderName},
27 http::StatusCode,
28 middleware,
29 response::IntoResponse,
30 routing::get,
31 Extension, Router, TypedHeader,
32};
33use chrono::Utc;
34use collections::{HashMap, HashSet};
35pub use connection_pool::{ConnectionPool, ZedVersion};
36use core::fmt::{self, Debug, Formatter};
37use http_client::HttpClient;
38use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
39use reqwest_client::ReqwestClient;
40use sha2::Digest;
41use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
42
43use futures::{
44 channel::oneshot, future::BoxFuture, stream::FuturesUnordered, FutureExt, SinkExt, StreamExt,
45 TryStreamExt,
46};
47use prometheus::{register_int_gauge, IntGauge};
48use rpc::{
49 proto::{
50 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
51 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
52 },
53 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
54};
55use semantic_version::SemanticVersion;
56use serde::{Serialize, Serializer};
57use std::{
58 any::TypeId,
59 future::Future,
60 marker::PhantomData,
61 mem,
62 net::SocketAddr,
63 ops::{Deref, DerefMut},
64 rc::Rc,
65 sync::{
66 atomic::{AtomicBool, Ordering::SeqCst},
67 Arc, OnceLock,
68 },
69 time::{Duration, Instant},
70};
71use time::OffsetDateTime;
72use tokio::sync::{watch, MutexGuard, Semaphore};
73use tower::ServiceBuilder;
74use tracing::{
75 field::{self},
76 info_span, instrument, Instrument,
77};
78
// ---- Timeouts -------------------------------------------------------------

/// Grace period associated with client reconnection (consumed outside this view —
/// presumably how long a dropped client may take to reconnect before cleanup).
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// Delay before a freshly started server cleans up stale resources.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully; waiting a little longer
/// than that means the old pods are gone before their resources are cleaned up.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// ---- Limits / page sizes ---------------------------------------------------

/// Page size used when fetching channel messages (presumably; consumed outside this view).
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Page size used when fetching notifications (presumably; consumed outside this view).
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
/// Upper bound on a channel message body (units not visible here — TODO confirm bytes vs chars).
const MAX_MESSAGE_LEN: usize = 1_024;
87
88type MessageHandler =
89 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
90
91struct Response<R> {
92 peer: Arc<Peer>,
93 receipt: Receipt<R>,
94 responded: Arc<AtomicBool>,
95}
96
97impl<R: RequestMessage> Response<R> {
98 fn send(self, payload: R::Response) -> Result<()> {
99 self.responded.store(true, SeqCst);
100 self.peer.respond(self.receipt, payload)?;
101 Ok(())
102 }
103}
104
105#[derive(Clone, Debug)]
106pub enum Principal {
107 User(User),
108 Impersonated { user: User, admin: User },
109}
110
111impl Principal {
112 fn update_span(&self, span: &tracing::Span) {
113 match &self {
114 Principal::User(user) => {
115 span.record("user_id", user.id.0);
116 span.record("login", &user.github_login);
117 }
118 Principal::Impersonated { user, admin } => {
119 span.record("user_id", user.id.0);
120 span.record("login", &user.github_login);
121 span.record("impersonator", &admin.github_login);
122 }
123 }
124 }
125}
126
/// Per-connection state passed to every message handler.
///
/// `handle_connection` clones it for each incoming message, so shared state is held
/// behind `Arc`s and cloning is cheap.
#[derive(Clone)]
struct Session {
    /// The authenticated user, or an admin impersonating a user.
    principal: Principal,
    /// Identifies this client's connection within `peer`.
    connection_id: ConnectionId,
    /// Database handle; acquire it through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Server-wide connection pool (the same `Arc` the `Server` holds).
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// `Some` only when a Supermaven admin API key is configured.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    http_client: Arc<dyn HttpClient>,
    /// The GeoIP country code for the user.
    // Not read by any code visible here, hence `allow(unused)`; NOTE(review): confirm it is
    // still needed before removing.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    system_id: Option<String>,
    // Not read directly; held so the session owns a handle to the connection's executor.
    _executor: Executor,
}
143
144impl Session {
145 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
146 #[cfg(test)]
147 tokio::task::yield_now().await;
148 let guard = self.db.lock().await;
149 #[cfg(test)]
150 tokio::task::yield_now().await;
151 guard
152 }
153
154 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
155 #[cfg(test)]
156 tokio::task::yield_now().await;
157 let guard = self.connection_pool.lock();
158 ConnectionPoolGuard {
159 guard,
160 _not_send: PhantomData,
161 }
162 }
163
164 fn is_staff(&self) -> bool {
165 match &self.principal {
166 Principal::User(user) => user.admin,
167 Principal::Impersonated { .. } => true,
168 }
169 }
170
171 pub async fn has_llm_subscription(
172 &self,
173 db: &MutexGuard<'_, DbHandle>,
174 ) -> anyhow::Result<bool> {
175 if self.is_staff() {
176 return Ok(true);
177 }
178
179 let user_id = self.user_id();
180
181 Ok(db.has_active_billing_subscription(user_id).await?)
182 }
183
184 pub async fn current_plan(
185 &self,
186 _db: &MutexGuard<'_, DbHandle>,
187 ) -> anyhow::Result<proto::Plan> {
188 if self.is_staff() {
189 Ok(proto::Plan::ZedPro)
190 } else {
191 Ok(proto::Plan::Free)
192 }
193 }
194
195 fn user_id(&self) -> UserId {
196 match &self.principal {
197 Principal::User(user) => user.id,
198 Principal::Impersonated { user, .. } => user.id,
199 }
200 }
201
202 pub fn email(&self) -> Option<String> {
203 match &self.principal {
204 Principal::User(user) => user.email_address.clone(),
205 Principal::Impersonated { user, .. } => user.email_address.clone(),
206 }
207 }
208}
209
210impl Debug for Session {
211 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
212 let mut result = f.debug_struct("Session");
213 match &self.principal {
214 Principal::User(user) => {
215 result.field("user", &user.github_login);
216 }
217 Principal::Impersonated { user, admin } => {
218 result.field("user", &user.github_login);
219 result.field("impersonator", &admin.github_login);
220 }
221 }
222 result.field("connection_id", &self.connection_id).finish()
223 }
224}
225
226struct DbHandle(Arc<Database>);
227
228impl Deref for DbHandle {
229 type Target = Database;
230
231 fn deref(&self) -> &Self::Target {
232 self.0.as_ref()
233 }
234}
235
/// The collaboration RPC server.
///
/// Owns the `Peer` that carries client connections, the shared connection pool, and the
/// table of message handlers populated in `Server::new`.
pub struct Server {
    /// Behind a mutex so the test-only `reset` can swap in a new id.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Handlers keyed by the `TypeId` of the message type they handle.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Set to `true` on teardown; connection loops subscribe to it and stop when it changes.
    teardown: watch::Sender<bool>,
}
244
/// Lock guard over the shared `ConnectionPool`.
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send`, so it cannot be held across an
/// `.await` inside the `Send` futures that message handlers return. In test builds, dropping
/// the guard runs `ConnectionPool::check_invariants` (see the `Drop` impl).
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
249
/// Serializable view of the server's peer and connection pool, produced by
/// `Server::snapshot`.
///
/// The connection-pool lock is held for as long as the snapshot lives. Field order is kept
/// as-is because serde's derive serializes fields in declaration order.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized as the guard's `Deref` target, i.e. the `ConnectionPool` itself.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
256
257pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
258where
259 S: Serializer,
260 T: Deref<Target = U>,
261 U: Serialize,
262{
263 Serialize::serialize(value.deref(), serializer)
264}
265
266impl Server {
267 pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
268 let mut server = Self {
269 id: parking_lot::Mutex::new(id),
270 peer: Peer::new(id.0 as u32),
271 app_state: app_state.clone(),
272 connection_pool: Default::default(),
273 handlers: Default::default(),
274 teardown: watch::channel(false).0,
275 };
276
277 server
278 .add_request_handler(ping)
279 .add_request_handler(create_room)
280 .add_request_handler(join_room)
281 .add_request_handler(rejoin_room)
282 .add_request_handler(leave_room)
283 .add_request_handler(set_room_participant_role)
284 .add_request_handler(call)
285 .add_request_handler(cancel_call)
286 .add_message_handler(decline_call)
287 .add_request_handler(update_participant_location)
288 .add_request_handler(share_project)
289 .add_message_handler(unshare_project)
290 .add_request_handler(join_project)
291 .add_message_handler(leave_project)
292 .add_request_handler(update_project)
293 .add_request_handler(update_worktree)
294 .add_message_handler(start_language_server)
295 .add_message_handler(update_language_server)
296 .add_message_handler(update_diagnostic_summary)
297 .add_message_handler(update_worktree_settings)
298 .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
299 .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
300 .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
301 .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
302 .add_request_handler(forward_find_search_candidates_request)
303 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
304 .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
305 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
306 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
307 .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
308 .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
309 .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
310 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
311 .add_request_handler(forward_read_only_project_request::<proto::GitBranches>)
312 .add_request_handler(forward_read_only_project_request::<proto::GetStagedText>)
313 .add_request_handler(
314 forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
315 )
316 .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
317 .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
318 .add_request_handler(
319 forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
320 )
321 .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
322 .add_request_handler(
323 forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
324 )
325 .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
326 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
327 .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
328 .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
329 .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
330 .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
331 .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
332 .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
333 .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
334 .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
335 .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
336 .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
337 .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
338 .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
339 .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
340 .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
341 .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
342 .add_message_handler(create_buffer_for_peer)
343 .add_request_handler(update_buffer)
344 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
345 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
346 .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
347 .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
348 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
349 .add_request_handler(get_users)
350 .add_request_handler(fuzzy_search_users)
351 .add_request_handler(request_contact)
352 .add_request_handler(remove_contact)
353 .add_request_handler(respond_to_contact_request)
354 .add_message_handler(subscribe_to_channels)
355 .add_request_handler(create_channel)
356 .add_request_handler(delete_channel)
357 .add_request_handler(invite_channel_member)
358 .add_request_handler(remove_channel_member)
359 .add_request_handler(set_channel_member_role)
360 .add_request_handler(set_channel_visibility)
361 .add_request_handler(rename_channel)
362 .add_request_handler(join_channel_buffer)
363 .add_request_handler(leave_channel_buffer)
364 .add_message_handler(update_channel_buffer)
365 .add_request_handler(rejoin_channel_buffers)
366 .add_request_handler(get_channel_members)
367 .add_request_handler(respond_to_channel_invite)
368 .add_request_handler(join_channel)
369 .add_request_handler(join_channel_chat)
370 .add_message_handler(leave_channel_chat)
371 .add_request_handler(send_channel_message)
372 .add_request_handler(remove_channel_message)
373 .add_request_handler(update_channel_message)
374 .add_request_handler(get_channel_messages)
375 .add_request_handler(get_channel_messages_by_id)
376 .add_request_handler(get_notifications)
377 .add_request_handler(mark_notification_as_read)
378 .add_request_handler(move_channel)
379 .add_request_handler(follow)
380 .add_message_handler(unfollow)
381 .add_message_handler(update_followers)
382 .add_request_handler(get_private_user_info)
383 .add_request_handler(get_llm_api_token)
384 .add_request_handler(accept_terms_of_service)
385 .add_message_handler(acknowledge_channel_message)
386 .add_message_handler(acknowledge_buffer_version)
387 .add_request_handler(get_supermaven_api_key)
388 .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
389 .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
390 .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
391 .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
392 .add_message_handler(update_context)
393 .add_request_handler({
394 let app_state = app_state.clone();
395 move |request, response, session| {
396 let app_state = app_state.clone();
397 async move {
398 count_language_model_tokens(request, response, session, &app_state.config)
399 .await
400 }
401 }
402 })
403 .add_request_handler(get_cached_embeddings)
404 .add_request_handler({
405 let app_state = app_state.clone();
406 move |request, response, session| {
407 compute_embeddings(
408 request,
409 response,
410 session,
411 app_state.config.openai_api_key.clone(),
412 )
413 }
414 });
415
416 Arc::new(server)
417 }
418
419 pub async fn start(&self) -> Result<()> {
420 let server_id = *self.id.lock();
421 let app_state = self.app_state.clone();
422 let peer = self.peer.clone();
423 let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
424 let pool = self.connection_pool.clone();
425 let livekit_client = self.app_state.livekit_client.clone();
426
427 let span = info_span!("start server");
428 self.app_state.executor.spawn_detached(
429 async move {
430 tracing::info!("waiting for cleanup timeout");
431 timeout.await;
432 tracing::info!("cleanup timeout expired, retrieving stale rooms");
433 if let Some((room_ids, channel_ids)) = app_state
434 .db
435 .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
436 .await
437 .trace_err()
438 {
439 tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
440 tracing::info!(
441 stale_channel_buffer_count = channel_ids.len(),
442 "retrieved stale channel buffers"
443 );
444
445 for channel_id in channel_ids {
446 if let Some(refreshed_channel_buffer) = app_state
447 .db
448 .clear_stale_channel_buffer_collaborators(channel_id, server_id)
449 .await
450 .trace_err()
451 {
452 for connection_id in refreshed_channel_buffer.connection_ids {
453 peer.send(
454 connection_id,
455 proto::UpdateChannelBufferCollaborators {
456 channel_id: channel_id.to_proto(),
457 collaborators: refreshed_channel_buffer
458 .collaborators
459 .clone(),
460 },
461 )
462 .trace_err();
463 }
464 }
465 }
466
467 for room_id in room_ids {
468 let mut contacts_to_update = HashSet::default();
469 let mut canceled_calls_to_user_ids = Vec::new();
470 let mut livekit_room = String::new();
471 let mut delete_livekit_room = false;
472
473 if let Some(mut refreshed_room) = app_state
474 .db
475 .clear_stale_room_participants(room_id, server_id)
476 .await
477 .trace_err()
478 {
479 tracing::info!(
480 room_id = room_id.0,
481 new_participant_count = refreshed_room.room.participants.len(),
482 "refreshed room"
483 );
484 room_updated(&refreshed_room.room, &peer);
485 if let Some(channel) = refreshed_room.channel.as_ref() {
486 channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
487 }
488 contacts_to_update
489 .extend(refreshed_room.stale_participant_user_ids.iter().copied());
490 contacts_to_update
491 .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
492 canceled_calls_to_user_ids =
493 mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
494 livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
495 delete_livekit_room = refreshed_room.room.participants.is_empty();
496 }
497
498 {
499 let pool = pool.lock();
500 for canceled_user_id in canceled_calls_to_user_ids {
501 for connection_id in pool.user_connection_ids(canceled_user_id) {
502 peer.send(
503 connection_id,
504 proto::CallCanceled {
505 room_id: room_id.to_proto(),
506 },
507 )
508 .trace_err();
509 }
510 }
511 }
512
513 for user_id in contacts_to_update {
514 let busy = app_state.db.is_user_busy(user_id).await.trace_err();
515 let contacts = app_state.db.get_contacts(user_id).await.trace_err();
516 if let Some((busy, contacts)) = busy.zip(contacts) {
517 let pool = pool.lock();
518 let updated_contact = contact_for_user(user_id, busy, &pool);
519 for contact in contacts {
520 if let db::Contact::Accepted {
521 user_id: contact_user_id,
522 ..
523 } = contact
524 {
525 for contact_conn_id in
526 pool.user_connection_ids(contact_user_id)
527 {
528 peer.send(
529 contact_conn_id,
530 proto::UpdateContacts {
531 contacts: vec![updated_contact.clone()],
532 remove_contacts: Default::default(),
533 incoming_requests: Default::default(),
534 remove_incoming_requests: Default::default(),
535 outgoing_requests: Default::default(),
536 remove_outgoing_requests: Default::default(),
537 },
538 )
539 .trace_err();
540 }
541 }
542 }
543 }
544 }
545
546 if let Some(live_kit) = livekit_client.as_ref() {
547 if delete_livekit_room {
548 live_kit.delete_room(livekit_room).await.trace_err();
549 }
550 }
551 }
552 }
553
554 app_state
555 .db
556 .delete_stale_servers(&app_state.config.zed_environment, server_id)
557 .await
558 .trace_err();
559 }
560 .instrument(span),
561 );
562 Ok(())
563 }
564
565 pub fn teardown(&self) {
566 self.peer.teardown();
567 self.connection_pool.lock().reset();
568 let _ = self.teardown.send(true);
569 }
570
571 #[cfg(test)]
572 pub fn reset(&self, id: ServerId) {
573 self.teardown();
574 *self.id.lock() = id;
575 self.peer.reset(id.0 as u32);
576 let _ = self.teardown.send(false);
577 }
578
579 #[cfg(test)]
580 pub fn id(&self) -> ServerId {
581 *self.id.lock()
582 }
583
584 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
585 where
586 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
587 Fut: 'static + Send + Future<Output = Result<()>>,
588 M: EnvelopedMessage,
589 {
590 let prev_handler = self.handlers.insert(
591 TypeId::of::<M>(),
592 Box::new(move |envelope, session| {
593 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
594 let received_at = envelope.received_at;
595 tracing::info!("message received");
596 let start_time = Instant::now();
597 let future = (handler)(*envelope, session);
598 async move {
599 let result = future.await;
600 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
601 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
602 let queue_duration_ms = total_duration_ms - processing_duration_ms;
603 let payload_type = M::NAME;
604
605 match result {
606 Err(error) => {
607 tracing::error!(
608 ?error,
609 total_duration_ms,
610 processing_duration_ms,
611 queue_duration_ms,
612 payload_type,
613 "error handling message"
614 )
615 }
616 Ok(()) => tracing::info!(
617 total_duration_ms,
618 processing_duration_ms,
619 queue_duration_ms,
620 "finished handling message"
621 ),
622 }
623 }
624 .boxed()
625 }),
626 );
627 if prev_handler.is_some() {
628 panic!("registered a handler for the same message twice");
629 }
630 self
631 }
632
633 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
634 where
635 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
636 Fut: 'static + Send + Future<Output = Result<()>>,
637 M: EnvelopedMessage,
638 {
639 self.add_handler(move |envelope, session| handler(envelope.payload, session));
640 self
641 }
642
643 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
644 where
645 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
646 Fut: Send + Future<Output = Result<()>>,
647 M: RequestMessage,
648 {
649 let handler = Arc::new(handler);
650 self.add_handler(move |envelope, session| {
651 let receipt = envelope.receipt();
652 let handler = handler.clone();
653 async move {
654 let peer = session.peer.clone();
655 let responded = Arc::new(AtomicBool::default());
656 let response = Response {
657 peer: peer.clone(),
658 responded: responded.clone(),
659 receipt,
660 };
661 match (handler)(envelope.payload, response, session).await {
662 Ok(()) => {
663 if responded.load(std::sync::atomic::Ordering::SeqCst) {
664 Ok(())
665 } else {
666 Err(anyhow!("handler did not send a response"))?
667 }
668 }
669 Err(error) => {
670 let proto_err = match &error {
671 Error::Internal(err) => err.to_proto(),
672 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
673 };
674 peer.respond_with_error(receipt, proto_err)?;
675 Err(error)
676 }
677 }
678 }
679 })
680 }
681
682 #[allow(clippy::too_many_arguments)]
683 pub fn handle_connection(
684 self: &Arc<Self>,
685 connection: Connection,
686 address: String,
687 principal: Principal,
688 zed_version: ZedVersion,
689 geoip_country_code: Option<String>,
690 system_id: Option<String>,
691 send_connection_id: Option<oneshot::Sender<ConnectionId>>,
692 executor: Executor,
693 ) -> impl Future<Output = ()> {
694 let this = self.clone();
695 let span = info_span!("handle connection", %address,
696 connection_id=field::Empty,
697 user_id=field::Empty,
698 login=field::Empty,
699 impersonator=field::Empty,
700 geoip_country_code=field::Empty
701 );
702 principal.update_span(&span);
703 if let Some(country_code) = geoip_country_code.as_ref() {
704 span.record("geoip_country_code", country_code);
705 }
706
707 let mut teardown = self.teardown.subscribe();
708 async move {
709 if *teardown.borrow() {
710 tracing::error!("server is tearing down");
711 return
712 }
713 let (connection_id, handle_io, mut incoming_rx) = this
714 .peer
715 .add_connection(connection, {
716 let executor = executor.clone();
717 move |duration| executor.sleep(duration)
718 });
719 tracing::Span::current().record("connection_id", format!("{}", connection_id));
720
721 tracing::info!("connection opened");
722
723 let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
724 let http_client = match ReqwestClient::user_agent(&user_agent) {
725 Ok(http_client) => Arc::new(http_client),
726 Err(error) => {
727 tracing::error!(?error, "failed to create HTTP client");
728 return;
729 }
730 };
731
732 let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
733 supermaven_admin_api_key.to_string(),
734 http_client.clone(),
735 )));
736
737 let session = Session {
738 principal: principal.clone(),
739 connection_id,
740 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
741 peer: this.peer.clone(),
742 connection_pool: this.connection_pool.clone(),
743 app_state: this.app_state.clone(),
744 http_client,
745 geoip_country_code,
746 system_id,
747 _executor: executor.clone(),
748 supermaven_client,
749 };
750
751 if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
752 tracing::error!(?error, "failed to send initial client update");
753 return;
754 }
755
756 let handle_io = handle_io.fuse();
757 futures::pin_mut!(handle_io);
758
759 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
760 // This prevents deadlocks when e.g., client A performs a request to client B and
761 // client B performs a request to client A. If both clients stop processing further
762 // messages until their respective request completes, they won't have a chance to
763 // respond to the other client's request and cause a deadlock.
764 //
765 // This arrangement ensures we will attempt to process earlier messages first, but fall
766 // back to processing messages arrived later in the spirit of making progress.
767 let mut foreground_message_handlers = FuturesUnordered::new();
768 let concurrent_handlers = Arc::new(Semaphore::new(256));
769 loop {
770 let next_message = async {
771 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
772 let message = incoming_rx.next().await;
773 (permit, message)
774 }.fuse();
775 futures::pin_mut!(next_message);
776 futures::select_biased! {
777 _ = teardown.changed().fuse() => return,
778 result = handle_io => {
779 if let Err(error) = result {
780 tracing::error!(?error, "error handling I/O");
781 }
782 break;
783 }
784 _ = foreground_message_handlers.next() => {}
785 next_message = next_message => {
786 let (permit, message) = next_message;
787 if let Some(message) = message {
788 let type_name = message.payload_type_name();
789 // note: we copy all the fields from the parent span so we can query them in the logs.
790 // (https://github.com/tokio-rs/tracing/issues/2670).
791 let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
792 user_id=field::Empty,
793 login=field::Empty,
794 impersonator=field::Empty,
795 );
796 principal.update_span(&span);
797 let span_enter = span.enter();
798 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
799 let is_background = message.is_background();
800 let handle_message = (handler)(message, session.clone());
801 drop(span_enter);
802
803 let handle_message = async move {
804 handle_message.await;
805 drop(permit);
806 }.instrument(span);
807 if is_background {
808 executor.spawn_detached(handle_message);
809 } else {
810 foreground_message_handlers.push(handle_message);
811 }
812 } else {
813 tracing::error!("no message handler");
814 }
815 } else {
816 tracing::info!("connection closed");
817 break;
818 }
819 }
820 }
821 }
822
823 drop(foreground_message_handlers);
824 tracing::info!("signing out");
825 if let Err(error) = connection_lost(session, teardown, executor).await {
826 tracing::error!(?error, "error signing out");
827 }
828
829 }.instrument(span)
830 }
831
832 async fn send_initial_client_update(
833 &self,
834 connection_id: ConnectionId,
835 principal: &Principal,
836 zed_version: ZedVersion,
837 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
838 session: &Session,
839 ) -> Result<()> {
840 self.peer.send(
841 connection_id,
842 proto::Hello {
843 peer_id: Some(connection_id.into()),
844 },
845 )?;
846 tracing::info!("sent hello message");
847 if let Some(send_connection_id) = send_connection_id.take() {
848 let _ = send_connection_id.send(connection_id);
849 }
850
851 match principal {
852 Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
853 if !user.connected_once {
854 self.peer.send(connection_id, proto::ShowContacts {})?;
855 self.app_state
856 .db
857 .set_user_connected_once(user.id, true)
858 .await?;
859 }
860
861 update_user_plan(user.id, session).await?;
862
863 let contacts = self.app_state.db.get_contacts(user.id).await?;
864
865 {
866 let mut pool = self.connection_pool.lock();
867 pool.add_connection(connection_id, user.id, user.admin, zed_version);
868 self.peer.send(
869 connection_id,
870 build_initial_contacts_update(contacts, &pool),
871 )?;
872 }
873
874 if should_auto_subscribe_to_channels(zed_version) {
875 subscribe_user_to_channels(user.id, session).await?;
876 }
877
878 if let Some(incoming_call) =
879 self.app_state.db.incoming_call_for_user(user.id).await?
880 {
881 self.peer.send(connection_id, incoming_call)?;
882 }
883
884 update_user_contacts(user.id, session).await?;
885 }
886 }
887
888 Ok(())
889 }
890
891 pub async fn invite_code_redeemed(
892 self: &Arc<Self>,
893 inviter_id: UserId,
894 invitee_id: UserId,
895 ) -> Result<()> {
896 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
897 if let Some(code) = &user.invite_code {
898 let pool = self.connection_pool.lock();
899 let invitee_contact = contact_for_user(invitee_id, false, &pool);
900 for connection_id in pool.user_connection_ids(inviter_id) {
901 self.peer.send(
902 connection_id,
903 proto::UpdateContacts {
904 contacts: vec![invitee_contact.clone()],
905 ..Default::default()
906 },
907 )?;
908 self.peer.send(
909 connection_id,
910 proto::UpdateInviteInfo {
911 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
912 count: user.invite_count as u32,
913 },
914 )?;
915 }
916 }
917 }
918 Ok(())
919 }
920
921 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
922 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
923 if let Some(invite_code) = &user.invite_code {
924 let pool = self.connection_pool.lock();
925 for connection_id in pool.user_connection_ids(user_id) {
926 self.peer.send(
927 connection_id,
928 proto::UpdateInviteInfo {
929 url: format!(
930 "{}{}",
931 self.app_state.config.invite_link_prefix, invite_code
932 ),
933 count: user.invite_count as u32,
934 },
935 )?;
936 }
937 }
938 }
939 Ok(())
940 }
941
942 pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
943 let pool = self.connection_pool.lock();
944 for connection_id in pool.user_connection_ids(user_id) {
945 self.peer
946 .send(connection_id, proto::RefreshLlmToken {})
947 .trace_err();
948 }
949 }
950
951 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
952 ServerSnapshot {
953 connection_pool: ConnectionPoolGuard {
954 guard: self.connection_pool.lock(),
955 _not_send: PhantomData,
956 },
957 peer: &self.peer,
958 }
959 }
960}
961
962impl<'a> Deref for ConnectionPoolGuard<'a> {
963 type Target = ConnectionPool;
964
965 fn deref(&self) -> &Self::Target {
966 &self.guard
967 }
968}
969
970impl<'a> DerefMut for ConnectionPoolGuard<'a> {
971 fn deref_mut(&mut self) -> &mut Self::Target {
972 &mut self.guard
973 }
974}
975
976impl<'a> Drop for ConnectionPoolGuard<'a> {
977 fn drop(&mut self) {
978 #[cfg(test)]
979 self.check_invariants();
980 }
981}
982
983fn broadcast<F>(
984 sender_id: Option<ConnectionId>,
985 receiver_ids: impl IntoIterator<Item = ConnectionId>,
986 mut f: F,
987) where
988 F: FnMut(ConnectionId) -> anyhow::Result<()>,
989{
990 for receiver_id in receiver_ids {
991 if Some(receiver_id) != sender_id {
992 if let Err(error) = f(receiver_id) {
993 tracing::error!("failed to send to {:?} {}", receiver_id, error);
994 }
995 }
996 }
997}
998
999pub struct ProtocolVersion(u32);
1000
1001impl Header for ProtocolVersion {
1002 fn name() -> &'static HeaderName {
1003 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1004 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1005 }
1006
1007 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1008 where
1009 Self: Sized,
1010 I: Iterator<Item = &'i axum::http::HeaderValue>,
1011 {
1012 let version = values
1013 .next()
1014 .ok_or_else(axum::headers::Error::invalid)?
1015 .to_str()
1016 .map_err(|_| axum::headers::Error::invalid())?
1017 .parse()
1018 .map_err(|_| axum::headers::Error::invalid())?;
1019 Ok(Self(version))
1020 }
1021
1022 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1023 values.extend([self.0.to_string().parse().unwrap()]);
1024 }
1025}
1026
1027pub struct AppVersionHeader(SemanticVersion);
1028impl Header for AppVersionHeader {
1029 fn name() -> &'static HeaderName {
1030 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1031 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1032 }
1033
1034 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1035 where
1036 Self: Sized,
1037 I: Iterator<Item = &'i axum::http::HeaderValue>,
1038 {
1039 let version = values
1040 .next()
1041 .ok_or_else(axum::headers::Error::invalid)?
1042 .to_str()
1043 .map_err(|_| axum::headers::Error::invalid())?
1044 .parse()
1045 .map_err(|_| axum::headers::Error::invalid())?;
1046 Ok(Self(version))
1047 }
1048
1049 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1050 values.extend([self.0.to_string().parse().unwrap()]);
1051 }
1052}
1053
/// Builds the collab RPC router.
///
/// Two GET routes are exposed:
/// - `/rpc`: the websocket endpoint, wrapped in a layer that supplies the
///   `AppState` extension and runs the `auth::validate_header` middleware;
/// - `/metrics`: Prometheus text-format metrics.
///
/// Layer order matters: `Router::layer` only wraps routes that were added
/// before it is called. The auth middleware therefore guards `/rpc` but not
/// `/metrics`, while the trailing `Extension(server)` layer applies to both.
pub fn routes(server: Arc<Server>) -> Router<(), Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
1065
/// Axum handler for `GET /rpc`: checks the client's versions and upgrades the
/// request to a websocket served by `Server::handle_connection`.
///
/// Responds with `426 Upgrade Required`, without upgrading, when:
/// - the protocol-version header differs from `rpc::PROTOCOL_VERSION`;
/// - no usable app-version header is present (axum's `Option` extractor maps
///   a missing or, presumably, unparsable header to `None`);
/// - the app version reports that it cannot collaborate.
///
/// After the upgrade, axum websocket messages are converted to and from
/// tungstenite messages so the socket can be wrapped in an `rpc::Connection`.
/// The optional Cloudflare country and system-id headers are forwarded as
/// strings.
#[allow(clippy::too_many_arguments)]
pub async fn handle_websocket_request(
    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
    app_version_header: Option<TypedHeader<AppVersionHeader>>,
    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
    Extension(server): Extension<Arc<Server>>,
    Extension(principal): Extension<Principal>,
    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
    system_id_header: Option<TypedHeader<SystemIdHeader>>,
    ws: WebSocketUpgrade,
) -> axum::response::Response {
    if protocol_version != rpc::PROTOCOL_VERSION {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }

    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "no version header found".to_string(),
        )
            .into_response();
    };

    if !version.can_collaborate() {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }

    let socket_address = socket_address.to_string();
    ws.on_upgrade(move |socket| {
        // Adapt the axum websocket to the tungstenite message type that
        // `rpc::Connection` works with, in both directions.
        let socket = socket
            .map_ok(to_tungstenite_message)
            .err_into()
            .with(|message| async move { to_axum_message(message) });
        let connection = Connection::new(Box::pin(socket));
        async move {
            server
                .handle_connection(
                    connection,
                    socket_address,
                    principal,
                    version,
                    country_code_header.map(|header| header.to_string()),
                    system_id_header.map(|header| header.to_string()),
                    None,
                    Executor::Production,
                )
                .await;
        }
    })
}
1124
1125pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1126 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1127 let connections_metric = CONNECTIONS_METRIC
1128 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1129
1130 let connections = server
1131 .connection_pool
1132 .lock()
1133 .connections()
1134 .filter(|connection| !connection.admin)
1135 .count();
1136 connections_metric.set(connections as _);
1137
1138 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1139 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1140 register_int_gauge!(
1141 "shared_projects",
1142 "number of open projects with one or more guests"
1143 )
1144 .unwrap()
1145 });
1146
1147 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1148 shared_projects_metric.set(shared_projects as _);
1149
1150 let encoder = prometheus::TextEncoder::new();
1151 let metric_families = prometheus::gather();
1152 let encoded_metrics = encoder
1153 .encode_to_string(&metric_families)
1154 .map_err(|err| anyhow!("{}", err))?;
1155 Ok(encoded_metrics)
1156}
1157
/// Cleans up after a client connection drops.
///
/// These steps run immediately:
/// - disconnect the peer;
/// - remove the connection from the pool, returning an error if that fails;
/// - record the loss in the database, where errors are only logged.
///
/// The function then waits for `RECONNECT_TIMEOUT`, presumably to give the
/// client a chance to reconnect. If `teardown` resolves first, it returns
/// without further cleanup. Otherwise, when the timeout elapses, it does the
/// following (only the last step's error is propagated; the others are logged):
/// - leaves the session's room and channel buffers;
/// - if the user has no other connection online, declines any pending call;
/// - refreshes the user's contacts.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls the branches in the order written, so if both
    // are ready at once the timeout (cleanup) branch wins.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline a pending call when no other connection of this
            // user remains online to answer it.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1204
1205/// Acknowledges a ping from a client, used to keep the connection alive.
1206async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1207 response.send(proto::Ack {})?;
1208 Ok(())
1209}
1210
1211/// Creates a new room for calling (outside of channels)
1212async fn create_room(
1213 _request: proto::CreateRoom,
1214 response: Response<proto::CreateRoom>,
1215 session: Session,
1216) -> Result<()> {
1217 let livekit_room = nanoid::nanoid!(30);
1218
1219 let live_kit_connection_info = util::maybe!(async {
1220 let live_kit = session.app_state.livekit_client.as_ref();
1221 let live_kit = live_kit?;
1222 let user_id = session.user_id().to_string();
1223
1224 let token = live_kit
1225 .room_token(&livekit_room, &user_id.to_string())
1226 .trace_err()?;
1227
1228 Some(proto::LiveKitConnectionInfo {
1229 server_url: live_kit.url().into(),
1230 token,
1231 can_publish: true,
1232 })
1233 })
1234 .await;
1235
1236 let room = session
1237 .db()
1238 .await
1239 .create_room(session.user_id(), session.connection_id, &livekit_room)
1240 .await?;
1241
1242 response.send(proto::CreateRoomResponse {
1243 room: Some(room.clone()),
1244 live_kit_connection_info,
1245 })?;
1246
1247 update_user_contacts(session.user_id(), &session).await?;
1248 Ok(())
1249}
1250
/// Join a room from an invitation. Equivalent to joining a channel if there is one.
///
/// For a room that is not tied to a channel:
/// - the user is added to the room in the database and the updated room is
///   broadcast;
/// - every connection of the user is sent `CallCanceled` for this room
///   (presumably so the user's other devices stop ringing);
/// - the reply carries the room, plus LiveKit connection info when a LiveKit
///   client is configured (a token failure is logged and yields no info);
/// - finally, `update_user_contacts` is run for the user.
async fn join_room(
    request: proto::JoinRoom,
    response: Response<proto::JoinRoom>,
    session: Session,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.id);

    let channel_id = session.db().await.channel_id_for_room(room_id).await?;

    // Channel rooms are joined through the channel path instead.
    if let Some(channel_id) = channel_id {
        return join_channel_internal(channel_id, Box::new(response), session).await;
    }

    // Broadcast the room update before unwrapping the database result with
    // `into_inner`.
    let joined_room = {
        let room = session
            .db()
            .await
            .join_room(room_id, session.user_id(), session.connection_id)
            .await?;
        room_updated(&room.room, &session.peer);
        room.into_inner()
    };

    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(session.user_id())
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }

    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
    {
        live_kit
            .room_token(
                &joined_room.room.livekit_room,
                &session.user_id().to_string(),
            )
            .trace_err()
            .map(|token| proto::LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish: true,
            })
    } else {
        None
    };

    response.send(proto::JoinRoomResponse {
        room: Some(joined_room.room),
        channel_id: None,
        live_kit_connection_info,
    })?;

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1317
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// While the database result is held (the inner block), this:
/// 1. replies with the room plus the reshared and rejoined projects;
/// 2. broadcasts the updated room;
/// 3. for each reshared project, sends every collaborator an
///    `UpdateProjectCollaborator` (old connection -> this connection), and
///    forwards an `UpdateProject` with the project's worktrees to every
///    collaborator except this connection;
/// 4. streams rejoined projects' state via `notify_rejoined_projects`.
///
/// After the block, the room's channel (if any) is re-broadcast via
/// `channel_updated`, and `update_user_contacts` is run for the user.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell collaborators this client's peer id changed.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1409
1410fn notify_rejoined_projects(
1411 rejoined_projects: &mut Vec<RejoinedProject>,
1412 session: &Session,
1413) -> Result<()> {
1414 for project in rejoined_projects.iter() {
1415 for collaborator in &project.collaborators {
1416 session
1417 .peer
1418 .send(
1419 collaborator.connection_id,
1420 proto::UpdateProjectCollaborator {
1421 project_id: project.id.to_proto(),
1422 old_peer_id: Some(project.old_connection_id.into()),
1423 new_peer_id: Some(session.connection_id.into()),
1424 },
1425 )
1426 .trace_err();
1427 }
1428 }
1429
1430 for project in rejoined_projects {
1431 for worktree in mem::take(&mut project.worktrees) {
1432 // Stream this worktree's entries.
1433 let message = proto::UpdateWorktree {
1434 project_id: project.id.to_proto(),
1435 worktree_id: worktree.id,
1436 abs_path: worktree.abs_path.clone(),
1437 root_name: worktree.root_name,
1438 updated_entries: worktree.updated_entries,
1439 removed_entries: worktree.removed_entries,
1440 scan_id: worktree.scan_id,
1441 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1442 updated_repositories: worktree.updated_repositories,
1443 removed_repositories: worktree.removed_repositories,
1444 };
1445 for update in proto::split_worktree_update(message) {
1446 session.peer.send(session.connection_id, update.clone())?;
1447 }
1448
1449 // Stream this worktree's diagnostics.
1450 for summary in worktree.diagnostic_summaries {
1451 session.peer.send(
1452 session.connection_id,
1453 proto::UpdateDiagnosticSummary {
1454 project_id: project.id.to_proto(),
1455 worktree_id: worktree.id,
1456 summary: Some(summary),
1457 },
1458 )?;
1459 }
1460
1461 for settings_file in worktree.settings_files {
1462 session.peer.send(
1463 session.connection_id,
1464 proto::UpdateWorktreeSettings {
1465 project_id: project.id.to_proto(),
1466 worktree_id: worktree.id,
1467 path: settings_file.path,
1468 content: Some(settings_file.content),
1469 kind: Some(settings_file.kind.to_proto().into()),
1470 },
1471 )?;
1472 }
1473 }
1474
1475 for language_server in &project.language_servers {
1476 session.peer.send(
1477 session.connection_id,
1478 proto::UpdateLanguageServer {
1479 project_id: project.id.to_proto(),
1480 language_server_id: language_server.id,
1481 variant: Some(
1482 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1483 proto::LspDiskBasedDiagnosticsUpdated {},
1484 ),
1485 ),
1486 },
1487 )?;
1488 }
1489 }
1490 Ok(())
1491}
1492
1493/// leave room disconnects from the room.
1494async fn leave_room(
1495 _: proto::LeaveRoom,
1496 response: Response<proto::LeaveRoom>,
1497 session: Session,
1498) -> Result<()> {
1499 leave_room_for_session(&session, session.connection_id).await?;
1500 response.send(proto::Ack {})?;
1501 Ok(())
1502}
1503
1504/// Updates the permissions of someone else in the room.
1505async fn set_room_participant_role(
1506 request: proto::SetRoomParticipantRole,
1507 response: Response<proto::SetRoomParticipantRole>,
1508 session: Session,
1509) -> Result<()> {
1510 let user_id = UserId::from_proto(request.user_id);
1511 let role = ChannelRole::from(request.role());
1512
1513 let (livekit_room, can_publish) = {
1514 let room = session
1515 .db()
1516 .await
1517 .set_room_participant_role(
1518 session.user_id(),
1519 RoomId::from_proto(request.room_id),
1520 user_id,
1521 role,
1522 )
1523 .await?;
1524
1525 let livekit_room = room.livekit_room.clone();
1526 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1527 room_updated(&room, &session.peer);
1528 (livekit_room, can_publish)
1529 };
1530
1531 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1532 live_kit
1533 .update_participant(
1534 livekit_room.clone(),
1535 request.user_id.to_string(),
1536 livekit_server::proto::ParticipantPermission {
1537 can_subscribe: true,
1538 can_publish,
1539 can_publish_data: can_publish,
1540 hidden: false,
1541 recorder: false,
1542 },
1543 )
1544 .await
1545 .trace_err();
1546 }
1547
1548 response.send(proto::Ack {})?;
1549 Ok(())
1550}
1551
1552/// Call someone else into the current room
1553async fn call(
1554 request: proto::Call,
1555 response: Response<proto::Call>,
1556 session: Session,
1557) -> Result<()> {
1558 let room_id = RoomId::from_proto(request.room_id);
1559 let calling_user_id = session.user_id();
1560 let calling_connection_id = session.connection_id;
1561 let called_user_id = UserId::from_proto(request.called_user_id);
1562 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1563 if !session
1564 .db()
1565 .await
1566 .has_contact(calling_user_id, called_user_id)
1567 .await?
1568 {
1569 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1570 }
1571
1572 let incoming_call = {
1573 let (room, incoming_call) = &mut *session
1574 .db()
1575 .await
1576 .call(
1577 room_id,
1578 calling_user_id,
1579 calling_connection_id,
1580 called_user_id,
1581 initial_project_id,
1582 )
1583 .await?;
1584 room_updated(room, &session.peer);
1585 mem::take(incoming_call)
1586 };
1587 update_user_contacts(called_user_id, &session).await?;
1588
1589 let mut calls = session
1590 .connection_pool()
1591 .await
1592 .user_connection_ids(called_user_id)
1593 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1594 .collect::<FuturesUnordered<_>>();
1595
1596 while let Some(call_response) = calls.next().await {
1597 match call_response.as_ref() {
1598 Ok(_) => {
1599 response.send(proto::Ack {})?;
1600 return Ok(());
1601 }
1602 Err(_) => {
1603 call_response.trace_err();
1604 }
1605 }
1606 }
1607
1608 {
1609 let room = session
1610 .db()
1611 .await
1612 .call_failed(room_id, called_user_id)
1613 .await?;
1614 room_updated(&room, &session.peer);
1615 }
1616 update_user_contacts(called_user_id, &session).await?;
1617
1618 Err(anyhow!("failed to ring user"))?
1619}
1620
/// Cancel an outgoing call.
///
/// Steps:
/// 1. Remove the pending call for `called_user_id` from the room, passing the
///    caller's connection id to the database.
/// 2. Broadcast the updated room.
/// 3. Send `CallCanceled` to every connection of the called user
///    (best-effort).
/// 4. Acknowledge the request.
/// 5. Run `update_user_contacts` for the called user.
async fn cancel_call(
    request: proto::CancelCall,
    response: Response<proto::CancelCall>,
    session: Session,
) -> Result<()> {
    let called_user_id = UserId::from_proto(request.called_user_id);
    let room_id = RoomId::from_proto(request.room_id);
    // Scoped so the room value is dropped before the pool is locked below.
    {
        let room = session
            .db()
            .await
            .cancel_call(room_id, session.connection_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }
    response.send(proto::Ack {})?;

    update_user_contacts(called_user_id, &session).await?;
    Ok(())
}
1658
/// Decline an incoming call.
///
/// This handles a one-way message, so there is no response to send.
///
/// - If the database returns no room for the decline, the handler fails with
///   "failed to decline call".
/// - Otherwise, the updated room is broadcast.
/// - Every connection of the declining user is then sent `CallCanceled`
///   (presumably so the user's other devices stop ringing).
/// - Finally, `update_user_contacts` is run for the user.
async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
    let room_id = RoomId::from_proto(message.room_id);
    {
        let room = session
            .db()
            .await
            .decline_call(Some(room_id), session.user_id())
            .await?
            .ok_or_else(|| anyhow!("failed to decline call"))?;
        room_updated(&room, &session.peer);
    }

    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(session.user_id())
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }
    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1690
1691/// Updates other participants in the room with your current location.
1692async fn update_participant_location(
1693 request: proto::UpdateParticipantLocation,
1694 response: Response<proto::UpdateParticipantLocation>,
1695 session: Session,
1696) -> Result<()> {
1697 let room_id = RoomId::from_proto(request.room_id);
1698 let location = request
1699 .location
1700 .ok_or_else(|| anyhow!("invalid location"))?;
1701
1702 let db = session.db().await;
1703 let room = db
1704 .update_room_participant_location(room_id, session.connection_id, location)
1705 .await?;
1706
1707 room_updated(&room, &session.peer);
1708 response.send(proto::Ack {})?;
1709 Ok(())
1710}
1711
/// Share a project into the room.
///
/// The database call `share_project` receives the room, this connection, the
/// request's worktrees and its `is_ssh_project` flag. The handler then replies
/// with the new project id and broadcasts the updated room.
///
/// The `(project_id, room)` pair is borrowed out of the database's return
/// value via `&*`, which keeps that value alive until the function returns.
async fn share_project(
    request: proto::ShareProject,
    response: Response<proto::ShareProject>,
    session: Session,
) -> Result<()> {
    let (project_id, room) = &*session
        .db()
        .await
        .share_project(
            RoomId::from_proto(request.room_id),
            session.connection_id,
            &request.worktrees,
            request.is_ssh_project,
        )
        .await?;
    response.send(proto::ShareProjectResponse {
        project_id: project_id.to_proto(),
    })?;
    room_updated(room, &session.peer);

    Ok(())
}
1735
1736/// Unshare a project from the room.
1737async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1738 let project_id = ProjectId::from_proto(message.project_id);
1739 unshare_project_internal(project_id, session.connection_id, &session).await
1740}
1741
/// Stops sharing `project_id` on behalf of `connection_id`.
///
/// Every guest connection other than `connection_id` is sent `UnshareProject`
/// (failures are logged by `broadcast`). The room is re-broadcast when the
/// database returns one.
///
/// If the database reports that the project should be deleted, the deletion
/// happens only after the first database result has been dropped, since the
/// scoped block ends before the second `db()` call.
async fn unshare_project_internal(
    project_id: ProjectId,
    connection_id: ConnectionId,
    session: &Session,
) -> Result<()> {
    let delete = {
        let room_guard = session
            .db()
            .await
            .unshare_project(project_id, connection_id)
            .await?;

        let (delete, room, guest_connection_ids) = &*room_guard;

        let message = proto::UnshareProject {
            project_id: project_id.to_proto(),
        };

        broadcast(
            Some(connection_id),
            guest_connection_ids.iter().copied(),
            |conn_id| session.peer.send(conn_id, message.clone()),
        );
        if let Some(room) = room {
            room_updated(room, &session.peer);
        }

        *delete
    };

    if delete {
        let db = session.db().await;
        db.delete_project(project_id).await?;
    }

    Ok(())
}
1779
/// Join someone else's shared project.
///
/// Adds this connection to the project in the database. It then explicitly
/// drops the database handle (`drop(db)`) before handing off to
/// `join_project_internal`, which announces the guest to the other
/// collaborators and streams the project state to it.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: Session,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    drop(db);
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
1798
/// Something that can deliver a `JoinProjectResponse`. This keeps
/// `join_project_internal` independent of any one concrete responder type.
trait JoinProjectInternalResponse {
    /// Sends `result` as the reply to the join request.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
/// Replies through the regular RPC response of a `JoinProject` request.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // The path-qualified call resolves to `Response`'s inherent `send`
        // (inherent items take precedence), not back into this trait method.
        Response::<proto::JoinProject>::send(self, result)
    }
}
1807
1808fn join_project_internal(
1809 response: impl JoinProjectInternalResponse,
1810 session: Session,
1811 project: &mut Project,
1812 replica_id: &ReplicaId,
1813) -> Result<()> {
1814 let collaborators = project
1815 .collaborators
1816 .iter()
1817 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1818 .map(|collaborator| collaborator.to_proto())
1819 .collect::<Vec<_>>();
1820 let project_id = project.id;
1821 let guest_user_id = session.user_id();
1822
1823 let worktrees = project
1824 .worktrees
1825 .iter()
1826 .map(|(id, worktree)| proto::WorktreeMetadata {
1827 id: *id,
1828 root_name: worktree.root_name.clone(),
1829 visible: worktree.visible,
1830 abs_path: worktree.abs_path.clone(),
1831 })
1832 .collect::<Vec<_>>();
1833
1834 let add_project_collaborator = proto::AddProjectCollaborator {
1835 project_id: project_id.to_proto(),
1836 collaborator: Some(proto::Collaborator {
1837 peer_id: Some(session.connection_id.into()),
1838 replica_id: replica_id.0 as u32,
1839 user_id: guest_user_id.to_proto(),
1840 is_host: false,
1841 }),
1842 };
1843
1844 for collaborator in &collaborators {
1845 session
1846 .peer
1847 .send(
1848 collaborator.peer_id.unwrap().into(),
1849 add_project_collaborator.clone(),
1850 )
1851 .trace_err();
1852 }
1853
1854 // First, we send the metadata associated with each worktree.
1855 response.send(proto::JoinProjectResponse {
1856 project_id: project.id.0 as u64,
1857 worktrees: worktrees.clone(),
1858 replica_id: replica_id.0 as u32,
1859 collaborators: collaborators.clone(),
1860 language_servers: project.language_servers.clone(),
1861 role: project.role.into(),
1862 })?;
1863
1864 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1865 // Stream this worktree's entries.
1866 let message = proto::UpdateWorktree {
1867 project_id: project_id.to_proto(),
1868 worktree_id,
1869 abs_path: worktree.abs_path.clone(),
1870 root_name: worktree.root_name,
1871 updated_entries: worktree.entries,
1872 removed_entries: Default::default(),
1873 scan_id: worktree.scan_id,
1874 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1875 updated_repositories: worktree.repository_entries.into_values().collect(),
1876 removed_repositories: Default::default(),
1877 };
1878 for update in proto::split_worktree_update(message) {
1879 session.peer.send(session.connection_id, update.clone())?;
1880 }
1881
1882 // Stream this worktree's diagnostics.
1883 for summary in worktree.diagnostic_summaries {
1884 session.peer.send(
1885 session.connection_id,
1886 proto::UpdateDiagnosticSummary {
1887 project_id: project_id.to_proto(),
1888 worktree_id: worktree.id,
1889 summary: Some(summary),
1890 },
1891 )?;
1892 }
1893
1894 for settings_file in worktree.settings_files {
1895 session.peer.send(
1896 session.connection_id,
1897 proto::UpdateWorktreeSettings {
1898 project_id: project_id.to_proto(),
1899 worktree_id: worktree.id,
1900 path: settings_file.path,
1901 content: Some(settings_file.content),
1902 kind: Some(settings_file.kind.to_proto() as i32),
1903 },
1904 )?;
1905 }
1906 }
1907
1908 for language_server in &project.language_servers {
1909 session.peer.send(
1910 session.connection_id,
1911 proto::UpdateLanguageServer {
1912 project_id: project_id.to_proto(),
1913 language_server_id: language_server.id,
1914 variant: Some(
1915 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1916 proto::LspDiskBasedDiagnosticsUpdated {},
1917 ),
1918 ),
1919 },
1920 )?;
1921 }
1922
1923 Ok(())
1924}
1925
/// Leave someone else's shared project.
///
/// Removes this connection from the project in the database, calls
/// `project_left` (presumably to notify the remaining collaborators), and
/// re-broadcasts the room when the database returns one. Both `db` and the
/// borrowed `(room, project)` result stay alive until the function returns.
async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
    let sender_id = session.connection_id;
    let project_id = ProjectId::from_proto(request.project_id);
    let db = session.db().await;

    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
    tracing::info!(
        %project_id,
        "leave project"
    );

    project_left(project, &session);
    if let Some(room) = room {
        room_updated(room, &session.peer);
    }

    Ok(())
}
1945
/// Updates other participants with changes to the project.
///
/// Passes the request's worktrees to the database's `update_project`, which
/// returns the room (if any) and the guest connections. The handler then:
/// - forwards the original `UpdateProject` message to every guest except the
///   sender;
/// - re-broadcasts the room when there is one;
/// - acknowledges the request.
async fn update_project(
    request: proto::UpdateProject,
    response: Response<proto::UpdateProject>,
    session: Session,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    let (room, guest_connection_ids) = &*session
        .db()
        .await
        .update_project(project_id, session.connection_id, &request.worktrees)
        .await?;
    broadcast(
        Some(session.connection_id),
        guest_connection_ids.iter().copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())
        },
    );
    if let Some(room) = room {
        room_updated(room, &session.peer);
    }
    response.send(proto::Ack {})?;

    Ok(())
}
1974
1975/// Updates other participants with changes to the worktree
1976async fn update_worktree(
1977 request: proto::UpdateWorktree,
1978 response: Response<proto::UpdateWorktree>,
1979 session: Session,
1980) -> Result<()> {
1981 let guest_connection_ids = session
1982 .db()
1983 .await
1984 .update_worktree(&request, session.connection_id)
1985 .await?;
1986
1987 broadcast(
1988 Some(session.connection_id),
1989 guest_connection_ids.iter().copied(),
1990 |connection_id| {
1991 session
1992 .peer
1993 .forward_send(session.connection_id, connection_id, request.clone())
1994 },
1995 );
1996 response.send(proto::Ack {})?;
1997 Ok(())
1998}
1999
2000/// Updates other participants with changes to the diagnostics
2001async fn update_diagnostic_summary(
2002 message: proto::UpdateDiagnosticSummary,
2003 session: Session,
2004) -> Result<()> {
2005 let guest_connection_ids = session
2006 .db()
2007 .await
2008 .update_diagnostic_summary(&message, session.connection_id)
2009 .await?;
2010
2011 broadcast(
2012 Some(session.connection_id),
2013 guest_connection_ids.iter().copied(),
2014 |connection_id| {
2015 session
2016 .peer
2017 .forward_send(session.connection_id, connection_id, message.clone())
2018 },
2019 );
2020
2021 Ok(())
2022}
2023
2024/// Updates other participants with changes to the worktree settings
2025async fn update_worktree_settings(
2026 message: proto::UpdateWorktreeSettings,
2027 session: Session,
2028) -> Result<()> {
2029 let guest_connection_ids = session
2030 .db()
2031 .await
2032 .update_worktree_settings(&message, session.connection_id)
2033 .await?;
2034
2035 broadcast(
2036 Some(session.connection_id),
2037 guest_connection_ids.iter().copied(),
2038 |connection_id| {
2039 session
2040 .peer
2041 .forward_send(session.connection_id, connection_id, message.clone())
2042 },
2043 );
2044
2045 Ok(())
2046}
2047
2048/// Notify other participants that a language server has started.
2049async fn start_language_server(
2050 request: proto::StartLanguageServer,
2051 session: Session,
2052) -> Result<()> {
2053 let guest_connection_ids = session
2054 .db()
2055 .await
2056 .start_language_server(&request, session.connection_id)
2057 .await?;
2058
2059 broadcast(
2060 Some(session.connection_id),
2061 guest_connection_ids.iter().copied(),
2062 |connection_id| {
2063 session
2064 .peer
2065 .forward_send(session.connection_id, connection_id, request.clone())
2066 },
2067 );
2068 Ok(())
2069}
2070
2071/// Notify other participants that a language server has changed.
2072async fn update_language_server(
2073 request: proto::UpdateLanguageServer,
2074 session: Session,
2075) -> Result<()> {
2076 let project_id = ProjectId::from_proto(request.project_id);
2077 let project_connection_ids = session
2078 .db()
2079 .await
2080 .project_connection_ids(project_id, session.connection_id, true)
2081 .await?;
2082 broadcast(
2083 Some(session.connection_id),
2084 project_connection_ids.iter().copied(),
2085 |connection_id| {
2086 session
2087 .peer
2088 .forward_send(session.connection_id, connection_id, request.clone())
2089 },
2090 );
2091 Ok(())
2092}
2093
2094/// forward a project request to the host. These requests should be read only
2095/// as guests are allowed to send them.
2096async fn forward_read_only_project_request<T>(
2097 request: T,
2098 response: Response<T>,
2099 session: Session,
2100) -> Result<()>
2101where
2102 T: EntityMessage + RequestMessage,
2103{
2104 let project_id = ProjectId::from_proto(request.remote_entity_id());
2105 let host_connection_id = session
2106 .db()
2107 .await
2108 .host_for_read_only_project_request(project_id, session.connection_id)
2109 .await?;
2110 let payload = session
2111 .peer
2112 .forward_request(session.connection_id, host_connection_id, request)
2113 .await?;
2114 response.send(payload)?;
2115 Ok(())
2116}
2117
2118async fn forward_find_search_candidates_request(
2119 request: proto::FindSearchCandidates,
2120 response: Response<proto::FindSearchCandidates>,
2121 session: Session,
2122) -> Result<()> {
2123 let project_id = ProjectId::from_proto(request.remote_entity_id());
2124 let host_connection_id = session
2125 .db()
2126 .await
2127 .host_for_read_only_project_request(project_id, session.connection_id)
2128 .await?;
2129 let payload = session
2130 .peer
2131 .forward_request(session.connection_id, host_connection_id, request)
2132 .await?;
2133 response.send(payload)?;
2134 Ok(())
2135}
2136
2137/// forward a project request to the host. These requests are disallowed
2138/// for guests.
2139async fn forward_mutating_project_request<T>(
2140 request: T,
2141 response: Response<T>,
2142 session: Session,
2143) -> Result<()>
2144where
2145 T: EntityMessage + RequestMessage,
2146{
2147 let project_id = ProjectId::from_proto(request.remote_entity_id());
2148
2149 let host_connection_id = session
2150 .db()
2151 .await
2152 .host_for_mutating_project_request(project_id, session.connection_id)
2153 .await?;
2154 let payload = session
2155 .peer
2156 .forward_request(session.connection_id, host_connection_id, request)
2157 .await?;
2158 response.send(payload)?;
2159 Ok(())
2160}
2161
2162/// Notify other participants that a new buffer has been created
2163async fn create_buffer_for_peer(
2164 request: proto::CreateBufferForPeer,
2165 session: Session,
2166) -> Result<()> {
2167 session
2168 .db()
2169 .await
2170 .check_user_is_project_host(
2171 ProjectId::from_proto(request.project_id),
2172 session.connection_id,
2173 )
2174 .await?;
2175 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2176 session
2177 .peer
2178 .forward_send(session.connection_id, peer_id.into(), request)?;
2179 Ok(())
2180}
2181
/// Notify other participants that a buffer has been updated. This is
/// allowed for guests as long as the update is limited to selections.
///
/// Guests receive the update via `forward_send`. The host (unless the sender
/// is the host) receives it via `forward_request`, which is awaited before
/// the sender is acknowledged with `Ack`.
async fn update_buffer(
    request: proto::UpdateBuffer,
    response: Response<proto::UpdateBuffer>,
    session: Session,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    // Start from the weakest capability and escalate to read-write as soon as
    // any operation does more than update selections.
    let mut capability = Capability::ReadOnly;

    for op in request.operations.iter() {
        match op.variant {
            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
            Some(_) => capability = Capability::ReadWrite,
        }
    }

    // The guard returned by `connections_for_buffer_update` is confined to this
    // block, so it is dropped before the host request below is awaited.
    let host = {
        let guard = session
            .db()
            .await
            .connections_for_buffer_update(project_id, session.connection_id, capability)
            .await?;

        let (host, guests) = &*guard;

        // Guests get the update without waiting for any reply.
        broadcast(
            Some(session.connection_id),
            guests.clone(),
            |connection_id| {
                session
                    .peer
                    .forward_send(session.connection_id, connection_id, request.clone())
            },
        );

        *host
    };

    // The host's copy is a request, so the sender is only acknowledged after
    // the host has responded.
    if host != session.connection_id {
        session
            .peer
            .forward_request(session.connection_id, host, request.clone())
            .await?;
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2231
2232async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2233 let project_id = ProjectId::from_proto(message.project_id);
2234
2235 let operation = message.operation.as_ref().context("invalid operation")?;
2236 let capability = match operation.variant.as_ref() {
2237 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2238 if let Some(buffer_op) = buffer_op.operation.as_ref() {
2239 match buffer_op.variant {
2240 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2241 Capability::ReadOnly
2242 }
2243 _ => Capability::ReadWrite,
2244 }
2245 } else {
2246 Capability::ReadWrite
2247 }
2248 }
2249 Some(_) => Capability::ReadWrite,
2250 None => Capability::ReadOnly,
2251 };
2252
2253 let guard = session
2254 .db()
2255 .await
2256 .connections_for_buffer_update(project_id, session.connection_id, capability)
2257 .await?;
2258
2259 let (host, guests) = &*guard;
2260
2261 broadcast(
2262 Some(session.connection_id),
2263 guests.iter().chain([host]).copied(),
2264 |connection_id| {
2265 session
2266 .peer
2267 .forward_send(session.connection_id, connection_id, message.clone())
2268 },
2269 );
2270
2271 Ok(())
2272}
2273
2274/// Notify other participants that a project has been updated.
2275async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2276 request: T,
2277 session: Session,
2278) -> Result<()> {
2279 let project_id = ProjectId::from_proto(request.remote_entity_id());
2280 let project_connection_ids = session
2281 .db()
2282 .await
2283 .project_connection_ids(project_id, session.connection_id, false)
2284 .await?;
2285
2286 broadcast(
2287 Some(session.connection_id),
2288 project_connection_ids.iter().copied(),
2289 |connection_id| {
2290 session
2291 .peer
2292 .forward_send(session.connection_id, connection_id, request.clone())
2293 },
2294 );
2295 Ok(())
2296}
2297
2298/// Start following another user in a call.
2299async fn follow(
2300 request: proto::Follow,
2301 response: Response<proto::Follow>,
2302 session: Session,
2303) -> Result<()> {
2304 let room_id = RoomId::from_proto(request.room_id);
2305 let project_id = request.project_id.map(ProjectId::from_proto);
2306 let leader_id = request
2307 .leader_id
2308 .ok_or_else(|| anyhow!("invalid leader id"))?
2309 .into();
2310 let follower_id = session.connection_id;
2311
2312 session
2313 .db()
2314 .await
2315 .check_room_participants(room_id, leader_id, session.connection_id)
2316 .await?;
2317
2318 let response_payload = session
2319 .peer
2320 .forward_request(session.connection_id, leader_id, request)
2321 .await?;
2322 response.send(response_payload)?;
2323
2324 if let Some(project_id) = project_id {
2325 let room = session
2326 .db()
2327 .await
2328 .follow(room_id, project_id, leader_id, follower_id)
2329 .await?;
2330 room_updated(&room, &session.peer);
2331 }
2332
2333 Ok(())
2334}
2335
2336/// Stop following another user in a call.
2337async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2338 let room_id = RoomId::from_proto(request.room_id);
2339 let project_id = request.project_id.map(ProjectId::from_proto);
2340 let leader_id = request
2341 .leader_id
2342 .ok_or_else(|| anyhow!("invalid leader id"))?
2343 .into();
2344 let follower_id = session.connection_id;
2345
2346 session
2347 .db()
2348 .await
2349 .check_room_participants(room_id, leader_id, session.connection_id)
2350 .await?;
2351
2352 session
2353 .peer
2354 .forward_send(session.connection_id, leader_id, request)?;
2355
2356 if let Some(project_id) = project_id {
2357 let room = session
2358 .db()
2359 .await
2360 .unfollow(room_id, project_id, leader_id, follower_id)
2361 .await?;
2362 room_updated(&room, &session.peer);
2363 }
2364
2365 Ok(())
2366}
2367
2368/// Notify everyone following you of your current location.
2369async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2370 let room_id = RoomId::from_proto(request.room_id);
2371 let database = session.db.lock().await;
2372
2373 let connection_ids = if let Some(project_id) = request.project_id {
2374 let project_id = ProjectId::from_proto(project_id);
2375 database
2376 .project_connection_ids(project_id, session.connection_id, true)
2377 .await?
2378 } else {
2379 database
2380 .room_connection_ids(room_id, session.connection_id)
2381 .await?
2382 };
2383
2384 // For now, don't send view update messages back to that view's current leader.
2385 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2386 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2387 _ => None,
2388 });
2389
2390 for connection_id in connection_ids.iter().cloned() {
2391 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2392 session
2393 .peer
2394 .forward_send(session.connection_id, connection_id, request.clone())?;
2395 }
2396 }
2397 Ok(())
2398}
2399
2400/// Get public data about users.
2401async fn get_users(
2402 request: proto::GetUsers,
2403 response: Response<proto::GetUsers>,
2404 session: Session,
2405) -> Result<()> {
2406 let user_ids = request
2407 .user_ids
2408 .into_iter()
2409 .map(UserId::from_proto)
2410 .collect();
2411 let users = session
2412 .db()
2413 .await
2414 .get_users_by_ids(user_ids)
2415 .await?
2416 .into_iter()
2417 .map(|user| proto::User {
2418 id: user.id.to_proto(),
2419 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2420 github_login: user.github_login,
2421 email: user.email_address,
2422 name: user.name,
2423 })
2424 .collect();
2425 response.send(proto::UsersResponse { users })?;
2426 Ok(())
2427}
2428
2429/// Search for users (to invite) buy Github login
2430async fn fuzzy_search_users(
2431 request: proto::FuzzySearchUsers,
2432 response: Response<proto::FuzzySearchUsers>,
2433 session: Session,
2434) -> Result<()> {
2435 let query = request.query;
2436 let users = match query.len() {
2437 0 => vec![],
2438 1 | 2 => session
2439 .db()
2440 .await
2441 .get_user_by_github_login(&query)
2442 .await?
2443 .into_iter()
2444 .collect(),
2445 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2446 };
2447 let users = users
2448 .into_iter()
2449 .filter(|user| user.id != session.user_id())
2450 .map(|user| proto::User {
2451 id: user.id.to_proto(),
2452 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2453 github_login: user.github_login,
2454 name: user.name,
2455 email: user.email_address,
2456 })
2457 .collect();
2458 response.send(proto::UsersResponse { users })?;
2459 Ok(())
2460}
2461
2462/// Send a contact request to another user.
2463async fn request_contact(
2464 request: proto::RequestContact,
2465 response: Response<proto::RequestContact>,
2466 session: Session,
2467) -> Result<()> {
2468 let requester_id = session.user_id();
2469 let responder_id = UserId::from_proto(request.responder_id);
2470 if requester_id == responder_id {
2471 return Err(anyhow!("cannot add yourself as a contact"))?;
2472 }
2473
2474 let notifications = session
2475 .db()
2476 .await
2477 .send_contact_request(requester_id, responder_id)
2478 .await?;
2479
2480 // Update outgoing contact requests of requester
2481 let mut update = proto::UpdateContacts::default();
2482 update.outgoing_requests.push(responder_id.to_proto());
2483 for connection_id in session
2484 .connection_pool()
2485 .await
2486 .user_connection_ids(requester_id)
2487 {
2488 session.peer.send(connection_id, update.clone())?;
2489 }
2490
2491 // Update incoming contact requests of responder
2492 let mut update = proto::UpdateContacts::default();
2493 update
2494 .incoming_requests
2495 .push(proto::IncomingContactRequest {
2496 requester_id: requester_id.to_proto(),
2497 });
2498 let connection_pool = session.connection_pool().await;
2499 for connection_id in connection_pool.user_connection_ids(responder_id) {
2500 session.peer.send(connection_id, update.clone())?;
2501 }
2502
2503 send_notifications(&connection_pool, &session.peer, notifications);
2504
2505 response.send(proto::Ack {})?;
2506 Ok(())
2507}
2508
2509/// Accept or decline a contact request
2510async fn respond_to_contact_request(
2511 request: proto::RespondToContactRequest,
2512 response: Response<proto::RespondToContactRequest>,
2513 session: Session,
2514) -> Result<()> {
2515 let responder_id = session.user_id();
2516 let requester_id = UserId::from_proto(request.requester_id);
2517 let db = session.db().await;
2518 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2519 db.dismiss_contact_notification(responder_id, requester_id)
2520 .await?;
2521 } else {
2522 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2523
2524 let notifications = db
2525 .respond_to_contact_request(responder_id, requester_id, accept)
2526 .await?;
2527 let requester_busy = db.is_user_busy(requester_id).await?;
2528 let responder_busy = db.is_user_busy(responder_id).await?;
2529
2530 let pool = session.connection_pool().await;
2531 // Update responder with new contact
2532 let mut update = proto::UpdateContacts::default();
2533 if accept {
2534 update
2535 .contacts
2536 .push(contact_for_user(requester_id, requester_busy, &pool));
2537 }
2538 update
2539 .remove_incoming_requests
2540 .push(requester_id.to_proto());
2541 for connection_id in pool.user_connection_ids(responder_id) {
2542 session.peer.send(connection_id, update.clone())?;
2543 }
2544
2545 // Update requester with new contact
2546 let mut update = proto::UpdateContacts::default();
2547 if accept {
2548 update
2549 .contacts
2550 .push(contact_for_user(responder_id, responder_busy, &pool));
2551 }
2552 update
2553 .remove_outgoing_requests
2554 .push(responder_id.to_proto());
2555
2556 for connection_id in pool.user_connection_ids(requester_id) {
2557 session.peer.send(connection_id, update.clone())?;
2558 }
2559
2560 send_notifications(&pool, &session.peer, notifications);
2561 }
2562
2563 response.send(proto::Ack {})?;
2564 Ok(())
2565}
2566
2567/// Remove a contact.
2568async fn remove_contact(
2569 request: proto::RemoveContact,
2570 response: Response<proto::RemoveContact>,
2571 session: Session,
2572) -> Result<()> {
2573 let requester_id = session.user_id();
2574 let responder_id = UserId::from_proto(request.user_id);
2575 let db = session.db().await;
2576 let (contact_accepted, deleted_notification_id) =
2577 db.remove_contact(requester_id, responder_id).await?;
2578
2579 let pool = session.connection_pool().await;
2580 // Update outgoing contact requests of requester
2581 let mut update = proto::UpdateContacts::default();
2582 if contact_accepted {
2583 update.remove_contacts.push(responder_id.to_proto());
2584 } else {
2585 update
2586 .remove_outgoing_requests
2587 .push(responder_id.to_proto());
2588 }
2589 for connection_id in pool.user_connection_ids(requester_id) {
2590 session.peer.send(connection_id, update.clone())?;
2591 }
2592
2593 // Update incoming contact requests of responder
2594 let mut update = proto::UpdateContacts::default();
2595 if contact_accepted {
2596 update.remove_contacts.push(requester_id.to_proto());
2597 } else {
2598 update
2599 .remove_incoming_requests
2600 .push(requester_id.to_proto());
2601 }
2602 for connection_id in pool.user_connection_ids(responder_id) {
2603 session.peer.send(connection_id, update.clone())?;
2604 if let Some(notification_id) = deleted_notification_id {
2605 session.peer.send(
2606 connection_id,
2607 proto::DeleteNotification {
2608 notification_id: notification_id.to_proto(),
2609 },
2610 )?;
2611 }
2612 }
2613
2614 response.send(proto::Ack {})?;
2615 Ok(())
2616}
2617
2618fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2619 version.0.minor() < 139
2620}
2621
2622async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
2623 let plan = session.current_plan(&session.db().await).await?;
2624
2625 session
2626 .peer
2627 .send(
2628 session.connection_id,
2629 proto::UpdateUserPlan { plan: plan.into() },
2630 )
2631 .trace_err();
2632
2633 Ok(())
2634}
2635
2636async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2637 subscribe_user_to_channels(session.user_id(), &session).await?;
2638 Ok(())
2639}
2640
2641async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2642 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2643 let mut pool = session.connection_pool().await;
2644 for membership in &channels_for_user.channel_memberships {
2645 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2646 }
2647 session.peer.send(
2648 session.connection_id,
2649 build_update_user_channels(&channels_for_user),
2650 )?;
2651 session.peer.send(
2652 session.connection_id,
2653 build_channels_update(channels_for_user),
2654 )?;
2655 Ok(())
2656}
2657
2658/// Creates a new channel.
2659async fn create_channel(
2660 request: proto::CreateChannel,
2661 response: Response<proto::CreateChannel>,
2662 session: Session,
2663) -> Result<()> {
2664 let db = session.db().await;
2665
2666 let parent_id = request.parent_id.map(ChannelId::from_proto);
2667 let (channel, membership) = db
2668 .create_channel(&request.name, parent_id, session.user_id())
2669 .await?;
2670
2671 let root_id = channel.root_id();
2672 let channel = Channel::from_model(channel);
2673
2674 response.send(proto::CreateChannelResponse {
2675 channel: Some(channel.to_proto()),
2676 parent_id: request.parent_id,
2677 })?;
2678
2679 let mut connection_pool = session.connection_pool().await;
2680 if let Some(membership) = membership {
2681 connection_pool.subscribe_to_channel(
2682 membership.user_id,
2683 membership.channel_id,
2684 membership.role,
2685 );
2686 let update = proto::UpdateUserChannels {
2687 channel_memberships: vec![proto::ChannelMembership {
2688 channel_id: membership.channel_id.to_proto(),
2689 role: membership.role.into(),
2690 }],
2691 ..Default::default()
2692 };
2693 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2694 session.peer.send(connection_id, update.clone())?;
2695 }
2696 }
2697
2698 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2699 if !role.can_see_channel(channel.visibility) {
2700 continue;
2701 }
2702
2703 let update = proto::UpdateChannels {
2704 channels: vec![channel.to_proto()],
2705 ..Default::default()
2706 };
2707 session.peer.send(connection_id, update.clone())?;
2708 }
2709
2710 Ok(())
2711}
2712
2713/// Delete a channel
2714async fn delete_channel(
2715 request: proto::DeleteChannel,
2716 response: Response<proto::DeleteChannel>,
2717 session: Session,
2718) -> Result<()> {
2719 let db = session.db().await;
2720
2721 let channel_id = request.channel_id;
2722 let (root_channel, removed_channels) = db
2723 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2724 .await?;
2725 response.send(proto::Ack {})?;
2726
2727 // Notify members of removed channels
2728 let mut update = proto::UpdateChannels::default();
2729 update
2730 .delete_channels
2731 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2732
2733 let connection_pool = session.connection_pool().await;
2734 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2735 session.peer.send(connection_id, update.clone())?;
2736 }
2737
2738 Ok(())
2739}
2740
2741/// Invite someone to join a channel.
2742async fn invite_channel_member(
2743 request: proto::InviteChannelMember,
2744 response: Response<proto::InviteChannelMember>,
2745 session: Session,
2746) -> Result<()> {
2747 let db = session.db().await;
2748 let channel_id = ChannelId::from_proto(request.channel_id);
2749 let invitee_id = UserId::from_proto(request.user_id);
2750 let InviteMemberResult {
2751 channel,
2752 notifications,
2753 } = db
2754 .invite_channel_member(
2755 channel_id,
2756 invitee_id,
2757 session.user_id(),
2758 request.role().into(),
2759 )
2760 .await?;
2761
2762 let update = proto::UpdateChannels {
2763 channel_invitations: vec![channel.to_proto()],
2764 ..Default::default()
2765 };
2766
2767 let connection_pool = session.connection_pool().await;
2768 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2769 session.peer.send(connection_id, update.clone())?;
2770 }
2771
2772 send_notifications(&connection_pool, &session.peer, notifications);
2773
2774 response.send(proto::Ack {})?;
2775 Ok(())
2776}
2777
2778/// remove someone from a channel
2779async fn remove_channel_member(
2780 request: proto::RemoveChannelMember,
2781 response: Response<proto::RemoveChannelMember>,
2782 session: Session,
2783) -> Result<()> {
2784 let db = session.db().await;
2785 let channel_id = ChannelId::from_proto(request.channel_id);
2786 let member_id = UserId::from_proto(request.user_id);
2787
2788 let RemoveChannelMemberResult {
2789 membership_update,
2790 notification_id,
2791 } = db
2792 .remove_channel_member(channel_id, member_id, session.user_id())
2793 .await?;
2794
2795 let mut connection_pool = session.connection_pool().await;
2796 notify_membership_updated(
2797 &mut connection_pool,
2798 membership_update,
2799 member_id,
2800 &session.peer,
2801 );
2802 for connection_id in connection_pool.user_connection_ids(member_id) {
2803 if let Some(notification_id) = notification_id {
2804 session
2805 .peer
2806 .send(
2807 connection_id,
2808 proto::DeleteNotification {
2809 notification_id: notification_id.to_proto(),
2810 },
2811 )
2812 .trace_err();
2813 }
2814 }
2815
2816 response.send(proto::Ack {})?;
2817 Ok(())
2818}
2819
2820/// Toggle the channel between public and private.
2821/// Care is taken to maintain the invariant that public channels only descend from public channels,
2822/// (though members-only channels can appear at any point in the hierarchy).
2823async fn set_channel_visibility(
2824 request: proto::SetChannelVisibility,
2825 response: Response<proto::SetChannelVisibility>,
2826 session: Session,
2827) -> Result<()> {
2828 let db = session.db().await;
2829 let channel_id = ChannelId::from_proto(request.channel_id);
2830 let visibility = request.visibility().into();
2831
2832 let channel_model = db
2833 .set_channel_visibility(channel_id, visibility, session.user_id())
2834 .await?;
2835 let root_id = channel_model.root_id();
2836 let channel = Channel::from_model(channel_model);
2837
2838 let mut connection_pool = session.connection_pool().await;
2839 for (user_id, role) in connection_pool
2840 .channel_user_ids(root_id)
2841 .collect::<Vec<_>>()
2842 .into_iter()
2843 {
2844 let update = if role.can_see_channel(channel.visibility) {
2845 connection_pool.subscribe_to_channel(user_id, channel_id, role);
2846 proto::UpdateChannels {
2847 channels: vec![channel.to_proto()],
2848 ..Default::default()
2849 }
2850 } else {
2851 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2852 proto::UpdateChannels {
2853 delete_channels: vec![channel.id.to_proto()],
2854 ..Default::default()
2855 }
2856 };
2857
2858 for connection_id in connection_pool.user_connection_ids(user_id) {
2859 session.peer.send(connection_id, update.clone())?;
2860 }
2861 }
2862
2863 response.send(proto::Ack {})?;
2864 Ok(())
2865}
2866
2867/// Alter the role for a user in the channel.
2868async fn set_channel_member_role(
2869 request: proto::SetChannelMemberRole,
2870 response: Response<proto::SetChannelMemberRole>,
2871 session: Session,
2872) -> Result<()> {
2873 let db = session.db().await;
2874 let channel_id = ChannelId::from_proto(request.channel_id);
2875 let member_id = UserId::from_proto(request.user_id);
2876 let result = db
2877 .set_channel_member_role(
2878 channel_id,
2879 session.user_id(),
2880 member_id,
2881 request.role().into(),
2882 )
2883 .await?;
2884
2885 match result {
2886 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2887 let mut connection_pool = session.connection_pool().await;
2888 notify_membership_updated(
2889 &mut connection_pool,
2890 membership_update,
2891 member_id,
2892 &session.peer,
2893 )
2894 }
2895 db::SetMemberRoleResult::InviteUpdated(channel) => {
2896 let update = proto::UpdateChannels {
2897 channel_invitations: vec![channel.to_proto()],
2898 ..Default::default()
2899 };
2900
2901 for connection_id in session
2902 .connection_pool()
2903 .await
2904 .user_connection_ids(member_id)
2905 {
2906 session.peer.send(connection_id, update.clone())?;
2907 }
2908 }
2909 }
2910
2911 response.send(proto::Ack {})?;
2912 Ok(())
2913}
2914
2915/// Change the name of a channel
2916async fn rename_channel(
2917 request: proto::RenameChannel,
2918 response: Response<proto::RenameChannel>,
2919 session: Session,
2920) -> Result<()> {
2921 let db = session.db().await;
2922 let channel_id = ChannelId::from_proto(request.channel_id);
2923 let channel_model = db
2924 .rename_channel(channel_id, session.user_id(), &request.name)
2925 .await?;
2926 let root_id = channel_model.root_id();
2927 let channel = Channel::from_model(channel_model);
2928
2929 response.send(proto::RenameChannelResponse {
2930 channel: Some(channel.to_proto()),
2931 })?;
2932
2933 let connection_pool = session.connection_pool().await;
2934 let update = proto::UpdateChannels {
2935 channels: vec![channel.to_proto()],
2936 ..Default::default()
2937 };
2938 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2939 if role.can_see_channel(channel.visibility) {
2940 session.peer.send(connection_id, update.clone())?;
2941 }
2942 }
2943
2944 Ok(())
2945}
2946
2947/// Move a channel to a new parent.
2948async fn move_channel(
2949 request: proto::MoveChannel,
2950 response: Response<proto::MoveChannel>,
2951 session: Session,
2952) -> Result<()> {
2953 let channel_id = ChannelId::from_proto(request.channel_id);
2954 let to = ChannelId::from_proto(request.to);
2955
2956 let (root_id, channels) = session
2957 .db()
2958 .await
2959 .move_channel(channel_id, to, session.user_id())
2960 .await?;
2961
2962 let connection_pool = session.connection_pool().await;
2963 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2964 let channels = channels
2965 .iter()
2966 .filter_map(|channel| {
2967 if role.can_see_channel(channel.visibility) {
2968 Some(channel.to_proto())
2969 } else {
2970 None
2971 }
2972 })
2973 .collect::<Vec<_>>();
2974 if channels.is_empty() {
2975 continue;
2976 }
2977
2978 let update = proto::UpdateChannels {
2979 channels,
2980 ..Default::default()
2981 };
2982
2983 session.peer.send(connection_id, update.clone())?;
2984 }
2985
2986 response.send(Ack {})?;
2987 Ok(())
2988}
2989
2990/// Get the list of channel members
2991async fn get_channel_members(
2992 request: proto::GetChannelMembers,
2993 response: Response<proto::GetChannelMembers>,
2994 session: Session,
2995) -> Result<()> {
2996 let db = session.db().await;
2997 let channel_id = ChannelId::from_proto(request.channel_id);
2998 let limit = if request.limit == 0 {
2999 u16::MAX as u64
3000 } else {
3001 request.limit
3002 };
3003 let (members, users) = db
3004 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3005 .await?;
3006 response.send(proto::GetChannelMembersResponse { members, users })?;
3007 Ok(())
3008}
3009
3010/// Accept or decline a channel invitation.
3011async fn respond_to_channel_invite(
3012 request: proto::RespondToChannelInvite,
3013 response: Response<proto::RespondToChannelInvite>,
3014 session: Session,
3015) -> Result<()> {
3016 let db = session.db().await;
3017 let channel_id = ChannelId::from_proto(request.channel_id);
3018 let RespondToChannelInvite {
3019 membership_update,
3020 notifications,
3021 } = db
3022 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3023 .await?;
3024
3025 let mut connection_pool = session.connection_pool().await;
3026 if let Some(membership_update) = membership_update {
3027 notify_membership_updated(
3028 &mut connection_pool,
3029 membership_update,
3030 session.user_id(),
3031 &session.peer,
3032 );
3033 } else {
3034 let update = proto::UpdateChannels {
3035 remove_channel_invitations: vec![channel_id.to_proto()],
3036 ..Default::default()
3037 };
3038
3039 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3040 session.peer.send(connection_id, update.clone())?;
3041 }
3042 };
3043
3044 send_notifications(&connection_pool, &session.peer, notifications);
3045
3046 response.send(proto::Ack {})?;
3047
3048 Ok(())
3049}
3050
3051/// Join the channels' room
3052async fn join_channel(
3053 request: proto::JoinChannel,
3054 response: Response<proto::JoinChannel>,
3055 session: Session,
3056) -> Result<()> {
3057 let channel_id = ChannelId::from_proto(request.channel_id);
3058 join_channel_internal(channel_id, Box::new(response), session).await
3059}
3060
/// Abstracts over response handles that reply with a `proto::JoinRoomResponse`.
/// This lets `join_channel_internal` answer either a `JoinChannel` or a
/// `JoinRoom` request.
trait JoinChannelInternalResponse {
    /// Send `result` back to the client that made the request.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// Both impls call `Response::<T>::send` via a fully qualified path. This
// presumably resolves to `Response`'s own `send` rather than recursing into
// this trait method.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3074
/// Shared implementation for joining a channel's call room. The reply goes
/// through any response type implementing `JoinChannelInternalResponse`.
///
/// Steps:
/// 1. Clean up any stale room connection left by this user.
/// 2. Join the channel's room in the database.
/// 3. Create a LiveKit token if a LiveKit client is configured.
/// 4. Reply to the client.
/// 5. Send membership, room, channel and contact updates.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // `db` and `connection_pool` are confined to this block. Only `joined_room`
    // escapes, so both guards are dropped before `channel_updated` below.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database guard while `leave_room_for_session` runs
            // (it receives the whole session, presumably so it can take the
            // guard itself), then re-acquire it.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // Guests get a `guest_token` with `can_publish: false`; every other
        // role gets a `room_token` with `can_publish: true`. If no LiveKit
        // client is configured, or token creation fails (the error goes to
        // `trace_err`), no connection info is sent.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        // Answer the joining client before notifying anyone else.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // A channel join is expected to return its channel. A missing channel is
    // reported as an error, after the client has already been answered.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3170
3171/// Start editing the channel notes
3172async fn join_channel_buffer(
3173 request: proto::JoinChannelBuffer,
3174 response: Response<proto::JoinChannelBuffer>,
3175 session: Session,
3176) -> Result<()> {
3177 let db = session.db().await;
3178 let channel_id = ChannelId::from_proto(request.channel_id);
3179
3180 let open_response = db
3181 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3182 .await?;
3183
3184 let collaborators = open_response.collaborators.clone();
3185 response.send(open_response)?;
3186
3187 let update = UpdateChannelBufferCollaborators {
3188 channel_id: channel_id.to_proto(),
3189 collaborators: collaborators.clone(),
3190 };
3191 channel_buffer_updated(
3192 session.connection_id,
3193 collaborators
3194 .iter()
3195 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3196 &update,
3197 &session.peer,
3198 );
3199
3200 Ok(())
3201}
3202
3203/// Edit the channel notes
3204async fn update_channel_buffer(
3205 request: proto::UpdateChannelBuffer,
3206 session: Session,
3207) -> Result<()> {
3208 let db = session.db().await;
3209 let channel_id = ChannelId::from_proto(request.channel_id);
3210
3211 let (collaborators, epoch, version) = db
3212 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3213 .await?;
3214
3215 channel_buffer_updated(
3216 session.connection_id,
3217 collaborators.clone(),
3218 &proto::UpdateChannelBuffer {
3219 channel_id: channel_id.to_proto(),
3220 operations: request.operations,
3221 },
3222 &session.peer,
3223 );
3224
3225 let pool = &*session.connection_pool().await;
3226
3227 let non_collaborators =
3228 pool.channel_connection_ids(channel_id)
3229 .filter_map(|(connection_id, _)| {
3230 if collaborators.contains(&connection_id) {
3231 None
3232 } else {
3233 Some(connection_id)
3234 }
3235 });
3236
3237 broadcast(None, non_collaborators, |peer_id| {
3238 session.peer.send(
3239 peer_id,
3240 proto::UpdateChannels {
3241 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3242 channel_id: channel_id.to_proto(),
3243 epoch: epoch as u64,
3244 version: version.clone(),
3245 }],
3246 ..Default::default()
3247 },
3248 )
3249 });
3250
3251 Ok(())
3252}
3253
3254/// Rejoin the channel notes after a connection blip
3255async fn rejoin_channel_buffers(
3256 request: proto::RejoinChannelBuffers,
3257 response: Response<proto::RejoinChannelBuffers>,
3258 session: Session,
3259) -> Result<()> {
3260 let db = session.db().await;
3261 let buffers = db
3262 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3263 .await?;
3264
3265 for rejoined_buffer in &buffers {
3266 let collaborators_to_notify = rejoined_buffer
3267 .buffer
3268 .collaborators
3269 .iter()
3270 .filter_map(|c| Some(c.peer_id?.into()));
3271 channel_buffer_updated(
3272 session.connection_id,
3273 collaborators_to_notify,
3274 &proto::UpdateChannelBufferCollaborators {
3275 channel_id: rejoined_buffer.buffer.channel_id,
3276 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3277 },
3278 &session.peer,
3279 );
3280 }
3281
3282 response.send(proto::RejoinChannelBuffersResponse {
3283 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3284 })?;
3285
3286 Ok(())
3287}
3288
3289/// Stop editing the channel notes
3290async fn leave_channel_buffer(
3291 request: proto::LeaveChannelBuffer,
3292 response: Response<proto::LeaveChannelBuffer>,
3293 session: Session,
3294) -> Result<()> {
3295 let db = session.db().await;
3296 let channel_id = ChannelId::from_proto(request.channel_id);
3297
3298 let left_buffer = db
3299 .leave_channel_buffer(channel_id, session.connection_id)
3300 .await?;
3301
3302 response.send(Ack {})?;
3303
3304 channel_buffer_updated(
3305 session.connection_id,
3306 left_buffer.connections,
3307 &proto::UpdateChannelBufferCollaborators {
3308 channel_id: channel_id.to_proto(),
3309 collaborators: left_buffer.collaborators,
3310 },
3311 &session.peer,
3312 );
3313
3314 Ok(())
3315}
3316
3317fn channel_buffer_updated<T: EnvelopedMessage>(
3318 sender_id: ConnectionId,
3319 collaborators: impl IntoIterator<Item = ConnectionId>,
3320 message: &T,
3321 peer: &Peer,
3322) {
3323 broadcast(Some(sender_id), collaborators, |peer_id| {
3324 peer.send(peer_id, message.clone())
3325 });
3326}
3327
3328fn send_notifications(
3329 connection_pool: &ConnectionPool,
3330 peer: &Peer,
3331 notifications: db::NotificationBatch,
3332) {
3333 for (user_id, notification) in notifications {
3334 for connection_id in connection_pool.user_connection_ids(user_id) {
3335 if let Err(error) = peer.send(
3336 connection_id,
3337 proto::AddNotification {
3338 notification: Some(notification.clone()),
3339 },
3340 ) {
3341 tracing::error!(
3342 "failed to send notification to {:?} {}",
3343 connection_id,
3344 error
3345 );
3346 }
3347 }
3348 }
3349}
3350
3351/// Send a message to the channel
3352async fn send_channel_message(
3353 request: proto::SendChannelMessage,
3354 response: Response<proto::SendChannelMessage>,
3355 session: Session,
3356) -> Result<()> {
3357 // Validate the message body.
3358 let body = request.body.trim().to_string();
3359 if body.len() > MAX_MESSAGE_LEN {
3360 return Err(anyhow!("message is too long"))?;
3361 }
3362 if body.is_empty() {
3363 return Err(anyhow!("message can't be blank"))?;
3364 }
3365
3366 // TODO: adjust mentions if body is trimmed
3367
3368 let timestamp = OffsetDateTime::now_utc();
3369 let nonce = request
3370 .nonce
3371 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3372
3373 let channel_id = ChannelId::from_proto(request.channel_id);
3374 let CreatedChannelMessage {
3375 message_id,
3376 participant_connection_ids,
3377 notifications,
3378 } = session
3379 .db()
3380 .await
3381 .create_channel_message(
3382 channel_id,
3383 session.user_id(),
3384 &body,
3385 &request.mentions,
3386 timestamp,
3387 nonce.clone().into(),
3388 request.reply_to_message_id.map(MessageId::from_proto),
3389 )
3390 .await?;
3391
3392 let message = proto::ChannelMessage {
3393 sender_id: session.user_id().to_proto(),
3394 id: message_id.to_proto(),
3395 body,
3396 mentions: request.mentions,
3397 timestamp: timestamp.unix_timestamp() as u64,
3398 nonce: Some(nonce),
3399 reply_to_message_id: request.reply_to_message_id,
3400 edited_at: None,
3401 };
3402 broadcast(
3403 Some(session.connection_id),
3404 participant_connection_ids.clone(),
3405 |connection| {
3406 session.peer.send(
3407 connection,
3408 proto::ChannelMessageSent {
3409 channel_id: channel_id.to_proto(),
3410 message: Some(message.clone()),
3411 },
3412 )
3413 },
3414 );
3415 response.send(proto::SendChannelMessageResponse {
3416 message: Some(message),
3417 })?;
3418
3419 let pool = &*session.connection_pool().await;
3420 let non_participants =
3421 pool.channel_connection_ids(channel_id)
3422 .filter_map(|(connection_id, _)| {
3423 if participant_connection_ids.contains(&connection_id) {
3424 None
3425 } else {
3426 Some(connection_id)
3427 }
3428 });
3429 broadcast(None, non_participants, |peer_id| {
3430 session.peer.send(
3431 peer_id,
3432 proto::UpdateChannels {
3433 latest_channel_message_ids: vec![proto::ChannelMessageId {
3434 channel_id: channel_id.to_proto(),
3435 message_id: message_id.to_proto(),
3436 }],
3437 ..Default::default()
3438 },
3439 )
3440 });
3441 send_notifications(pool, &session.peer, notifications);
3442
3443 Ok(())
3444}
3445
3446/// Delete a channel message
3447async fn remove_channel_message(
3448 request: proto::RemoveChannelMessage,
3449 response: Response<proto::RemoveChannelMessage>,
3450 session: Session,
3451) -> Result<()> {
3452 let channel_id = ChannelId::from_proto(request.channel_id);
3453 let message_id = MessageId::from_proto(request.message_id);
3454 let (connection_ids, existing_notification_ids) = session
3455 .db()
3456 .await
3457 .remove_channel_message(channel_id, message_id, session.user_id())
3458 .await?;
3459
3460 broadcast(
3461 Some(session.connection_id),
3462 connection_ids,
3463 move |connection| {
3464 session.peer.send(connection, request.clone())?;
3465
3466 for notification_id in &existing_notification_ids {
3467 session.peer.send(
3468 connection,
3469 proto::DeleteNotification {
3470 notification_id: (*notification_id).to_proto(),
3471 },
3472 )?;
3473 }
3474
3475 Ok(())
3476 },
3477 );
3478 response.send(proto::Ack {})?;
3479 Ok(())
3480}
3481
3482async fn update_channel_message(
3483 request: proto::UpdateChannelMessage,
3484 response: Response<proto::UpdateChannelMessage>,
3485 session: Session,
3486) -> Result<()> {
3487 let channel_id = ChannelId::from_proto(request.channel_id);
3488 let message_id = MessageId::from_proto(request.message_id);
3489 let updated_at = OffsetDateTime::now_utc();
3490 let UpdatedChannelMessage {
3491 message_id,
3492 participant_connection_ids,
3493 notifications,
3494 reply_to_message_id,
3495 timestamp,
3496 deleted_mention_notification_ids,
3497 updated_mention_notifications,
3498 } = session
3499 .db()
3500 .await
3501 .update_channel_message(
3502 channel_id,
3503 message_id,
3504 session.user_id(),
3505 request.body.as_str(),
3506 &request.mentions,
3507 updated_at,
3508 )
3509 .await?;
3510
3511 let nonce = request
3512 .nonce
3513 .clone()
3514 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3515
3516 let message = proto::ChannelMessage {
3517 sender_id: session.user_id().to_proto(),
3518 id: message_id.to_proto(),
3519 body: request.body.clone(),
3520 mentions: request.mentions.clone(),
3521 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3522 nonce: Some(nonce),
3523 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3524 edited_at: Some(updated_at.unix_timestamp() as u64),
3525 };
3526
3527 response.send(proto::Ack {})?;
3528
3529 let pool = &*session.connection_pool().await;
3530 broadcast(
3531 Some(session.connection_id),
3532 participant_connection_ids,
3533 |connection| {
3534 session.peer.send(
3535 connection,
3536 proto::ChannelMessageUpdate {
3537 channel_id: channel_id.to_proto(),
3538 message: Some(message.clone()),
3539 },
3540 )?;
3541
3542 for notification_id in &deleted_mention_notification_ids {
3543 session.peer.send(
3544 connection,
3545 proto::DeleteNotification {
3546 notification_id: (*notification_id).to_proto(),
3547 },
3548 )?;
3549 }
3550
3551 for notification in &updated_mention_notifications {
3552 session.peer.send(
3553 connection,
3554 proto::UpdateNotification {
3555 notification: Some(notification.clone()),
3556 },
3557 )?;
3558 }
3559
3560 Ok(())
3561 },
3562 );
3563
3564 send_notifications(pool, &session.peer, notifications);
3565
3566 Ok(())
3567}
3568
3569/// Mark a channel message as read
3570async fn acknowledge_channel_message(
3571 request: proto::AckChannelMessage,
3572 session: Session,
3573) -> Result<()> {
3574 let channel_id = ChannelId::from_proto(request.channel_id);
3575 let message_id = MessageId::from_proto(request.message_id);
3576 let notifications = session
3577 .db()
3578 .await
3579 .observe_channel_message(channel_id, session.user_id(), message_id)
3580 .await?;
3581 send_notifications(
3582 &*session.connection_pool().await,
3583 &session.peer,
3584 notifications,
3585 );
3586 Ok(())
3587}
3588
3589/// Mark a buffer version as synced
3590async fn acknowledge_buffer_version(
3591 request: proto::AckBufferOperation,
3592 session: Session,
3593) -> Result<()> {
3594 let buffer_id = BufferId::from_proto(request.buffer_id);
3595 session
3596 .db()
3597 .await
3598 .observe_buffer_version(
3599 buffer_id,
3600 session.user_id(),
3601 request.epoch as i32,
3602 &request.version,
3603 )
3604 .await?;
3605 Ok(())
3606}
3607
3608async fn count_language_model_tokens(
3609 request: proto::CountLanguageModelTokens,
3610 response: Response<proto::CountLanguageModelTokens>,
3611 session: Session,
3612 config: &Config,
3613) -> Result<()> {
3614 authorize_access_to_legacy_llm_endpoints(&session).await?;
3615
3616 let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3617 proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
3618 proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
3619 };
3620
3621 session
3622 .app_state
3623 .rate_limiter
3624 .check(&*rate_limit, session.user_id())
3625 .await?;
3626
3627 let result = match proto::LanguageModelProvider::from_i32(request.provider) {
3628 Some(proto::LanguageModelProvider::Google) => {
3629 let api_key = config
3630 .google_ai_api_key
3631 .as_ref()
3632 .context("no Google AI API key configured on the server")?;
3633 google_ai::count_tokens(
3634 session.http_client.as_ref(),
3635 google_ai::API_URL,
3636 api_key,
3637 serde_json::from_str(&request.request)?,
3638 )
3639 .await?
3640 }
3641 _ => return Err(anyhow!("unsupported provider"))?,
3642 };
3643
3644 response.send(proto::CountLanguageModelTokensResponse {
3645 token_count: result.total_tokens as u32,
3646 })?;
3647
3648 Ok(())
3649}
3650
3651struct ZedProCountLanguageModelTokensRateLimit;
3652
3653impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
3654 fn capacity(&self) -> usize {
3655 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
3656 .ok()
3657 .and_then(|v| v.parse().ok())
3658 .unwrap_or(600) // Picked arbitrarily
3659 }
3660
3661 fn refill_duration(&self) -> chrono::Duration {
3662 chrono::Duration::hours(1)
3663 }
3664
3665 fn db_name(&self) -> &'static str {
3666 "zed-pro:count-language-model-tokens"
3667 }
3668}
3669
3670struct FreeCountLanguageModelTokensRateLimit;
3671
3672impl RateLimit for FreeCountLanguageModelTokensRateLimit {
3673 fn capacity(&self) -> usize {
3674 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
3675 .ok()
3676 .and_then(|v| v.parse().ok())
3677 .unwrap_or(600 / 10) // Picked arbitrarily
3678 }
3679
3680 fn refill_duration(&self) -> chrono::Duration {
3681 chrono::Duration::hours(1)
3682 }
3683
3684 fn db_name(&self) -> &'static str {
3685 "free:count-language-model-tokens"
3686 }
3687}
3688
3689struct ZedProComputeEmbeddingsRateLimit;
3690
3691impl RateLimit for ZedProComputeEmbeddingsRateLimit {
3692 fn capacity(&self) -> usize {
3693 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
3694 .ok()
3695 .and_then(|v| v.parse().ok())
3696 .unwrap_or(5000) // Picked arbitrarily
3697 }
3698
3699 fn refill_duration(&self) -> chrono::Duration {
3700 chrono::Duration::hours(1)
3701 }
3702
3703 fn db_name(&self) -> &'static str {
3704 "zed-pro:compute-embeddings"
3705 }
3706}
3707
3708struct FreeComputeEmbeddingsRateLimit;
3709
3710impl RateLimit for FreeComputeEmbeddingsRateLimit {
3711 fn capacity(&self) -> usize {
3712 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
3713 .ok()
3714 .and_then(|v| v.parse().ok())
3715 .unwrap_or(5000 / 10) // Picked arbitrarily
3716 }
3717
3718 fn refill_duration(&self) -> chrono::Duration {
3719 chrono::Duration::hours(1)
3720 }
3721
3722 fn db_name(&self) -> &'static str {
3723 "free:compute-embeddings"
3724 }
3725}
3726
3727async fn compute_embeddings(
3728 request: proto::ComputeEmbeddings,
3729 response: Response<proto::ComputeEmbeddings>,
3730 session: Session,
3731 api_key: Option<Arc<str>>,
3732) -> Result<()> {
3733 let api_key = api_key.context("no OpenAI API key configured on the server")?;
3734 authorize_access_to_legacy_llm_endpoints(&session).await?;
3735
3736 let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3737 proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
3738 proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
3739 };
3740
3741 session
3742 .app_state
3743 .rate_limiter
3744 .check(&*rate_limit, session.user_id())
3745 .await?;
3746
3747 let embeddings = match request.model.as_str() {
3748 "openai/text-embedding-3-small" => {
3749 open_ai::embed(
3750 session.http_client.as_ref(),
3751 OPEN_AI_API_URL,
3752 &api_key,
3753 OpenAiEmbeddingModel::TextEmbedding3Small,
3754 request.texts.iter().map(|text| text.as_str()),
3755 )
3756 .await?
3757 }
3758 provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
3759 };
3760
3761 let embeddings = request
3762 .texts
3763 .iter()
3764 .map(|text| {
3765 let mut hasher = sha2::Sha256::new();
3766 hasher.update(text.as_bytes());
3767 let result = hasher.finalize();
3768 result.to_vec()
3769 })
3770 .zip(
3771 embeddings
3772 .data
3773 .into_iter()
3774 .map(|embedding| embedding.embedding),
3775 )
3776 .collect::<HashMap<_, _>>();
3777
3778 let db = session.db().await;
3779 db.save_embeddings(&request.model, &embeddings)
3780 .await
3781 .context("failed to save embeddings")
3782 .trace_err();
3783
3784 response.send(proto::ComputeEmbeddingsResponse {
3785 embeddings: embeddings
3786 .into_iter()
3787 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3788 .collect(),
3789 })?;
3790 Ok(())
3791}
3792
3793async fn get_cached_embeddings(
3794 request: proto::GetCachedEmbeddings,
3795 response: Response<proto::GetCachedEmbeddings>,
3796 session: Session,
3797) -> Result<()> {
3798 authorize_access_to_legacy_llm_endpoints(&session).await?;
3799
3800 let db = session.db().await;
3801 let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
3802
3803 response.send(proto::GetCachedEmbeddingsResponse {
3804 embeddings: embeddings
3805 .into_iter()
3806 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3807 .collect(),
3808 })?;
3809 Ok(())
3810}
3811
3812/// This is leftover from before the LLM service.
3813///
3814/// The endpoints protected by this check will be moved there eventually.
3815async fn authorize_access_to_legacy_llm_endpoints(session: &Session) -> Result<(), Error> {
3816 if session.is_staff() {
3817 Ok(())
3818 } else {
3819 Err(anyhow!("permission denied"))?
3820 }
3821}
3822
3823/// Get a Supermaven API key for the user
3824async fn get_supermaven_api_key(
3825 _request: proto::GetSupermavenApiKey,
3826 response: Response<proto::GetSupermavenApiKey>,
3827 session: Session,
3828) -> Result<()> {
3829 let user_id: String = session.user_id().to_string();
3830 if !session.is_staff() {
3831 return Err(anyhow!("supermaven not enabled for this account"))?;
3832 }
3833
3834 let email = session
3835 .email()
3836 .ok_or_else(|| anyhow!("user must have an email"))?;
3837
3838 let supermaven_admin_api = session
3839 .supermaven_client
3840 .as_ref()
3841 .ok_or_else(|| anyhow!("supermaven not configured"))?;
3842
3843 let result = supermaven_admin_api
3844 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3845 .await?;
3846
3847 response.send(proto::GetSupermavenApiKeyResponse {
3848 api_key: result.api_key,
3849 })?;
3850
3851 Ok(())
3852}
3853
3854/// Start receiving chat updates for a channel
3855async fn join_channel_chat(
3856 request: proto::JoinChannelChat,
3857 response: Response<proto::JoinChannelChat>,
3858 session: Session,
3859) -> Result<()> {
3860 let channel_id = ChannelId::from_proto(request.channel_id);
3861
3862 let db = session.db().await;
3863 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3864 .await?;
3865 let messages = db
3866 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3867 .await?;
3868 response.send(proto::JoinChannelChatResponse {
3869 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3870 messages,
3871 })?;
3872 Ok(())
3873}
3874
3875/// Stop receiving chat updates for a channel
3876async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3877 let channel_id = ChannelId::from_proto(request.channel_id);
3878 session
3879 .db()
3880 .await
3881 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3882 .await?;
3883 Ok(())
3884}
3885
3886/// Retrieve the chat history for a channel
3887async fn get_channel_messages(
3888 request: proto::GetChannelMessages,
3889 response: Response<proto::GetChannelMessages>,
3890 session: Session,
3891) -> Result<()> {
3892 let channel_id = ChannelId::from_proto(request.channel_id);
3893 let messages = session
3894 .db()
3895 .await
3896 .get_channel_messages(
3897 channel_id,
3898 session.user_id(),
3899 MESSAGE_COUNT_PER_PAGE,
3900 Some(MessageId::from_proto(request.before_message_id)),
3901 )
3902 .await?;
3903 response.send(proto::GetChannelMessagesResponse {
3904 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3905 messages,
3906 })?;
3907 Ok(())
3908}
3909
3910/// Retrieve specific chat messages
3911async fn get_channel_messages_by_id(
3912 request: proto::GetChannelMessagesById,
3913 response: Response<proto::GetChannelMessagesById>,
3914 session: Session,
3915) -> Result<()> {
3916 let message_ids = request
3917 .message_ids
3918 .iter()
3919 .map(|id| MessageId::from_proto(*id))
3920 .collect::<Vec<_>>();
3921 let messages = session
3922 .db()
3923 .await
3924 .get_channel_messages_by_id(session.user_id(), &message_ids)
3925 .await?;
3926 response.send(proto::GetChannelMessagesResponse {
3927 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3928 messages,
3929 })?;
3930 Ok(())
3931}
3932
3933/// Retrieve the current users notifications
3934async fn get_notifications(
3935 request: proto::GetNotifications,
3936 response: Response<proto::GetNotifications>,
3937 session: Session,
3938) -> Result<()> {
3939 let notifications = session
3940 .db()
3941 .await
3942 .get_notifications(
3943 session.user_id(),
3944 NOTIFICATION_COUNT_PER_PAGE,
3945 request.before_id.map(db::NotificationId::from_proto),
3946 )
3947 .await?;
3948 response.send(proto::GetNotificationsResponse {
3949 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3950 notifications,
3951 })?;
3952 Ok(())
3953}
3954
3955/// Mark notifications as read
3956async fn mark_notification_as_read(
3957 request: proto::MarkNotificationRead,
3958 response: Response<proto::MarkNotificationRead>,
3959 session: Session,
3960) -> Result<()> {
3961 let database = &session.db().await;
3962 let notifications = database
3963 .mark_notification_as_read_by_id(
3964 session.user_id(),
3965 NotificationId::from_proto(request.notification_id),
3966 )
3967 .await?;
3968 send_notifications(
3969 &*session.connection_pool().await,
3970 &session.peer,
3971 notifications,
3972 );
3973 response.send(proto::Ack {})?;
3974 Ok(())
3975}
3976
3977/// Get the current users information
3978async fn get_private_user_info(
3979 _request: proto::GetPrivateUserInfo,
3980 response: Response<proto::GetPrivateUserInfo>,
3981 session: Session,
3982) -> Result<()> {
3983 let db = session.db().await;
3984
3985 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
3986 let user = db
3987 .get_user_by_id(session.user_id())
3988 .await?
3989 .ok_or_else(|| anyhow!("user not found"))?;
3990 let flags = db.get_user_flags(session.user_id()).await?;
3991
3992 response.send(proto::GetPrivateUserInfoResponse {
3993 metrics_id,
3994 staff: user.admin,
3995 flags,
3996 accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
3997 })?;
3998 Ok(())
3999}
4000
4001/// Accept the terms of service (tos) on behalf of the current user
4002async fn accept_terms_of_service(
4003 _request: proto::AcceptTermsOfService,
4004 response: Response<proto::AcceptTermsOfService>,
4005 session: Session,
4006) -> Result<()> {
4007 let db = session.db().await;
4008
4009 let accepted_tos_at = Utc::now();
4010 db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4011 .await?;
4012
4013 response.send(proto::AcceptTermsOfServiceResponse {
4014 accepted_tos_at: accepted_tos_at.timestamp() as u64,
4015 })?;
4016 Ok(())
4017}
4018
/// The minimum account age an account must have in order to use the LLM service.
///
/// Checked in `get_llm_api_token`. Age is measured from the earlier of the
/// user record's `created_at` and the GitHub account's creation time, when
/// that is known. Users with an LLM subscription or the
/// `bypass-account-age-check` feature flag skip the check.
const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4021
4022async fn get_llm_api_token(
4023 _request: proto::GetLlmToken,
4024 response: Response<proto::GetLlmToken>,
4025 session: Session,
4026) -> Result<()> {
4027 let db = session.db().await;
4028
4029 let flags = db.get_user_flags(session.user_id()).await?;
4030 let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
4031 let has_llm_closed_beta_feature_flag = flags.iter().any(|flag| flag == "llm-closed-beta");
4032 let has_predict_edits_feature_flag = flags.iter().any(|flag| flag == "predict-edits");
4033
4034 if !session.is_staff() && !has_language_models_feature_flag {
4035 Err(anyhow!("permission denied"))?
4036 }
4037
4038 let user_id = session.user_id();
4039 let user = db
4040 .get_user_by_id(user_id)
4041 .await?
4042 .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4043
4044 if user.accepted_tos_at.is_none() {
4045 Err(anyhow!("terms of service not accepted"))?
4046 }
4047
4048 let has_llm_subscription = session.has_llm_subscription(&db).await?;
4049
4050 let bypass_account_age_check =
4051 has_llm_subscription || flags.iter().any(|flag| flag == "bypass-account-age-check");
4052 if !bypass_account_age_check {
4053 let mut account_created_at = user.created_at;
4054 if let Some(github_created_at) = user.github_user_created_at {
4055 account_created_at = account_created_at.min(github_created_at);
4056 }
4057 if Utc::now().naive_utc() - account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE {
4058 Err(anyhow!("account too young"))?
4059 }
4060 }
4061
4062 let billing_preferences = db.get_billing_preferences(user.id).await?;
4063
4064 let token = LlmTokenClaims::create(
4065 &user,
4066 session.is_staff(),
4067 billing_preferences,
4068 has_llm_closed_beta_feature_flag,
4069 has_predict_edits_feature_flag,
4070 has_llm_subscription,
4071 session.current_plan(&db).await?,
4072 session.system_id.clone(),
4073 &session.app_state.config,
4074 )?;
4075 response.send(proto::GetLlmTokenResponse { token })?;
4076 Ok(())
4077}
4078
4079fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4080 let message = match message {
4081 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4082 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4083 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4084 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4085 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4086 code: frame.code.into(),
4087 reason: frame.reason,
4088 })),
4089 // We should never receive a frame while reading the message, according
4090 // to the `tungstenite` maintainers:
4091 //
4092 // > It cannot occur when you read messages from the WebSocket, but it
4093 // > can be used when you want to send the raw frames (e.g. you want to
4094 // > send the frames to the WebSocket without composing the full message first).
4095 // >
4096 // > — https://github.com/snapview/tungstenite-rs/issues/268
4097 TungsteniteMessage::Frame(_) => {
4098 bail!("received an unexpected frame while reading the message")
4099 }
4100 };
4101
4102 Ok(message)
4103}
4104
4105fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4106 match message {
4107 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4108 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4109 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4110 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4111 AxumMessage::Close(frame) => {
4112 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4113 code: frame.code.into(),
4114 reason: frame.reason,
4115 }))
4116 }
4117 }
4118}
4119
4120fn notify_membership_updated(
4121 connection_pool: &mut ConnectionPool,
4122 result: MembershipUpdated,
4123 user_id: UserId,
4124 peer: &Peer,
4125) {
4126 for membership in &result.new_channels.channel_memberships {
4127 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4128 }
4129 for channel_id in &result.removed_channels {
4130 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4131 }
4132
4133 let user_channels_update = proto::UpdateUserChannels {
4134 channel_memberships: result
4135 .new_channels
4136 .channel_memberships
4137 .iter()
4138 .map(|cm| proto::ChannelMembership {
4139 channel_id: cm.channel_id.to_proto(),
4140 role: cm.role.into(),
4141 })
4142 .collect(),
4143 ..Default::default()
4144 };
4145
4146 let mut update = build_channels_update(result.new_channels);
4147 update.delete_channels = result
4148 .removed_channels
4149 .into_iter()
4150 .map(|id| id.to_proto())
4151 .collect();
4152 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4153
4154 for connection_id in connection_pool.user_connection_ids(user_id) {
4155 peer.send(connection_id, user_channels_update.clone())
4156 .trace_err();
4157 peer.send(connection_id, update.clone()).trace_err();
4158 }
4159}
4160
4161fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4162 proto::UpdateUserChannels {
4163 channel_memberships: channels
4164 .channel_memberships
4165 .iter()
4166 .map(|m| proto::ChannelMembership {
4167 channel_id: m.channel_id.to_proto(),
4168 role: m.role.into(),
4169 })
4170 .collect(),
4171 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4172 observed_channel_message_id: channels.observed_channel_messages.clone(),
4173 }
4174}
4175
4176fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4177 let mut update = proto::UpdateChannels::default();
4178
4179 for channel in channels.channels {
4180 update.channels.push(channel.to_proto());
4181 }
4182
4183 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4184 update.latest_channel_message_ids = channels.latest_channel_messages;
4185
4186 for (channel_id, participants) in channels.channel_participants {
4187 update
4188 .channel_participants
4189 .push(proto::ChannelParticipants {
4190 channel_id: channel_id.to_proto(),
4191 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4192 });
4193 }
4194
4195 for channel in channels.invited_channels {
4196 update.channel_invitations.push(channel.to_proto());
4197 }
4198
4199 update
4200}
4201
4202fn build_initial_contacts_update(
4203 contacts: Vec<db::Contact>,
4204 pool: &ConnectionPool,
4205) -> proto::UpdateContacts {
4206 let mut update = proto::UpdateContacts::default();
4207
4208 for contact in contacts {
4209 match contact {
4210 db::Contact::Accepted { user_id, busy } => {
4211 update.contacts.push(contact_for_user(user_id, busy, pool));
4212 }
4213 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4214 db::Contact::Incoming { user_id } => {
4215 update
4216 .incoming_requests
4217 .push(proto::IncomingContactRequest {
4218 requester_id: user_id.to_proto(),
4219 })
4220 }
4221 }
4222 }
4223
4224 update
4225}
4226
4227fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4228 proto::Contact {
4229 user_id: user_id.to_proto(),
4230 online: pool.is_user_online(user_id),
4231 busy,
4232 }
4233}
4234
4235fn room_updated(room: &proto::Room, peer: &Peer) {
4236 broadcast(
4237 None,
4238 room.participants
4239 .iter()
4240 .filter_map(|participant| Some(participant.peer_id?.into())),
4241 |peer_id| {
4242 peer.send(
4243 peer_id,
4244 proto::RoomUpdated {
4245 room: Some(room.clone()),
4246 },
4247 )
4248 },
4249 );
4250}
4251
4252fn channel_updated(
4253 channel: &db::channel::Model,
4254 room: &proto::Room,
4255 peer: &Peer,
4256 pool: &ConnectionPool,
4257) {
4258 let participants = room
4259 .participants
4260 .iter()
4261 .map(|p| p.user_id)
4262 .collect::<Vec<_>>();
4263
4264 broadcast(
4265 None,
4266 pool.channel_connection_ids(channel.root_id())
4267 .filter_map(|(channel_id, role)| {
4268 role.can_see_channel(channel.visibility)
4269 .then_some(channel_id)
4270 }),
4271 |peer_id| {
4272 peer.send(
4273 peer_id,
4274 proto::UpdateChannels {
4275 channel_participants: vec![proto::ChannelParticipants {
4276 channel_id: channel.id.to_proto(),
4277 participant_user_ids: participants.clone(),
4278 }],
4279 ..Default::default()
4280 },
4281 )
4282 },
4283 );
4284}
4285
4286async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4287 let db = session.db().await;
4288
4289 let contacts = db.get_contacts(user_id).await?;
4290 let busy = db.is_user_busy(user_id).await?;
4291
4292 let pool = session.connection_pool().await;
4293 let updated_contact = contact_for_user(user_id, busy, &pool);
4294 for contact in contacts {
4295 if let db::Contact::Accepted {
4296 user_id: contact_user_id,
4297 ..
4298 } = contact
4299 {
4300 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4301 session
4302 .peer
4303 .send(
4304 contact_conn_id,
4305 proto::UpdateContacts {
4306 contacts: vec![updated_contact.clone()],
4307 remove_contacts: Default::default(),
4308 incoming_requests: Default::default(),
4309 remove_incoming_requests: Default::default(),
4310 outgoing_requests: Default::default(),
4311 remove_outgoing_requests: Default::default(),
4312 },
4313 )
4314 .trace_err();
4315 }
4316 }
4317 }
4318 Ok(())
4319}
4320
/// Removes `connection_id` from the room it is in and propagates the
/// consequences.
///
/// Returns `Ok(())` without doing anything else when the database reports no
/// room for the connection. Otherwise it:
/// - calls `project_left` for every project the connection left;
/// - sends the updated room to the remaining participants (`room_updated`);
/// - if the room belongs to a channel, sends the channel's new participant
///   list (`channel_updated`);
/// - sends `CallCanceled` to every connection of each user whose pending call
///   the database reports as canceled;
/// - refreshes contact status for this user and for each of those callees;
/// - if a LiveKit client is configured, removes the user from the LiveKit room
///   and deletes the room when the database marked it as deleted.
///
/// Database errors (including those from `update_user_contacts`) are
/// propagated. Peer-send and LiveKit failures are passed to `trace_err` and
/// not propagated.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    // This user plus every callee whose call gets canceled; each is refreshed below.
    let mut contacts_to_update = HashSet::default();

    // Deferred initialization: each of these is assigned exactly once in the
    // `Some` branch below; the `else` branch returns early.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    // NOTE(review): temporaries in an `if let` scrutinee live until the end of
    // the whole `if let … else`. The value from `session.db().await` (presumably
    // a lock guard) is therefore held while `project_left` and `room_updated`
    // run. Keep this in mind before restructuring.
    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // NOTE(review): `livekit_room` is taken out of `left_room.room` before
        // `room` itself is taken, so the `room` broadcast below carries an
        // empty `livekit_room`. Presumably intentional; confirm.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the connection-pool guard is released before
    // `update_user_contacts` acquires it again below.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            // `connection_id` here shadows the function parameter: it is one
            // of the canceled callee's connections.
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4394
4395async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4396 let left_channel_buffers = session
4397 .db()
4398 .await
4399 .leave_channel_buffers(session.connection_id)
4400 .await?;
4401
4402 for left_buffer in left_channel_buffers {
4403 channel_buffer_updated(
4404 session.connection_id,
4405 left_buffer.connections,
4406 &proto::UpdateChannelBufferCollaborators {
4407 channel_id: left_buffer.channel_id.to_proto(),
4408 collaborators: left_buffer.collaborators,
4409 },
4410 &session.peer,
4411 );
4412 }
4413
4414 Ok(())
4415}
4416
4417fn project_left(project: &db::LeftProject, session: &Session) {
4418 for connection_id in &project.connection_ids {
4419 if project.should_unshare {
4420 session
4421 .peer
4422 .send(
4423 *connection_id,
4424 proto::UnshareProject {
4425 project_id: project.id.to_proto(),
4426 },
4427 )
4428 .trace_err();
4429 } else {
4430 session
4431 .peer
4432 .send(
4433 *connection_id,
4434 proto::RemoveProjectCollaborator {
4435 project_id: project.id.to_proto(),
4436 peer_id: Some(session.connection_id.into()),
4437 },
4438 )
4439 .trace_err();
4440 }
4441 }
4442}
4443
4444pub trait ResultExt {
4445 type Ok;
4446
4447 fn trace_err(self) -> Option<Self::Ok>;
4448}
4449
4450impl<T, E> ResultExt for Result<T, E>
4451where
4452 E: std::fmt::Debug,
4453{
4454 type Ok = T;
4455
4456 #[track_caller]
4457 fn trace_err(self) -> Option<T> {
4458 match self {
4459 Ok(value) => Some(value),
4460 Err(error) => {
4461 tracing::error!("{:?}", error);
4462 None
4463 }
4464 }
4465 }
4466}