1mod connection_pool;
2
3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
4use crate::llm::LlmTokenClaims;
5use crate::{
6 auth,
7 db::{
8 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
9 CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
10 NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
11 RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
12 },
13 executor::Executor,
14 AppState, Config, Error, RateLimit, Result,
15};
16use anyhow::{anyhow, bail, Context as _};
17use async_tungstenite::tungstenite::{
18 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
19};
20use axum::{
21 body::Body,
22 extract::{
23 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
24 ConnectInfo, WebSocketUpgrade,
25 },
26 headers::{Header, HeaderName},
27 http::StatusCode,
28 middleware,
29 response::IntoResponse,
30 routing::get,
31 Extension, Router, TypedHeader,
32};
33use chrono::Utc;
34use collections::{HashMap, HashSet};
35pub use connection_pool::{ConnectionPool, ZedVersion};
36use core::fmt::{self, Debug, Formatter};
37use http_client::HttpClient;
38use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
39use reqwest_client::ReqwestClient;
40use sha2::Digest;
41use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
42
43use futures::{
44 channel::oneshot, future::BoxFuture, stream::FuturesUnordered, FutureExt, SinkExt, StreamExt,
45 TryStreamExt,
46};
47use prometheus::{register_int_gauge, IntGauge};
48use rpc::{
49 proto::{
50 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
51 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
52 },
53 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
54};
55use semantic_version::SemanticVersion;
56use serde::{Serialize, Serializer};
57use std::{
58 any::TypeId,
59 future::Future,
60 marker::PhantomData,
61 mem,
62 net::SocketAddr,
63 ops::{Deref, DerefMut},
64 rc::Rc,
65 sync::{
66 atomic::{AtomicBool, Ordering::SeqCst},
67 Arc, OnceLock,
68 },
69 time::{Duration, Instant},
70};
71use time::OffsetDateTime;
72use tokio::sync::{watch, MutexGuard, Semaphore};
73use tower::ServiceBuilder;
74use tracing::{
75 field::{self},
76 info_span, instrument, Instrument,
77};
78
79pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
80
// Kubernetes gives terminated pods 10 seconds to shut down gracefully. After they're gone, we can clean up old resources.
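// CLEANUP_TIMEOUT is a little longer than that grace period, so stale rooms and
// channel buffers are only reclaimed once the old pod has actually exited.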
82pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
83
84const MESSAGE_COUNT_PER_PAGE: usize = 100;
85const MAX_MESSAGE_LEN: usize = 1024;
86const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
87
88type MessageHandler =
89 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
90
91struct Response<R> {
92 peer: Arc<Peer>,
93 receipt: Receipt<R>,
94 responded: Arc<AtomicBool>,
95}
96
97impl<R: RequestMessage> Response<R> {
98 fn send(self, payload: R::Response) -> Result<()> {
99 self.responded.store(true, SeqCst);
100 self.peer.respond(self.receipt, payload)?;
101 Ok(())
102 }
103}
104
105#[derive(Clone, Debug)]
106pub enum Principal {
107 User(User),
108 Impersonated { user: User, admin: User },
109}
110
111impl Principal {
112 fn update_span(&self, span: &tracing::Span) {
113 match &self {
114 Principal::User(user) => {
115 span.record("user_id", user.id.0);
116 span.record("login", &user.github_login);
117 }
118 Principal::Impersonated { user, admin } => {
119 span.record("user_id", user.id.0);
120 span.record("login", &user.github_login);
121 span.record("impersonator", &admin.github_login);
122 }
123 }
124 }
125}
126
127#[derive(Clone)]
128struct Session {
129 principal: Principal,
130 connection_id: ConnectionId,
131 db: Arc<tokio::sync::Mutex<DbHandle>>,
132 peer: Arc<Peer>,
133 connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
134 app_state: Arc<AppState>,
135 supermaven_client: Option<Arc<SupermavenAdminApi>>,
136 http_client: Arc<dyn HttpClient>,
137 /// The GeoIP country code for the user.
138 #[allow(unused)]
139 geoip_country_code: Option<String>,
140 system_id: Option<String>,
141 _executor: Executor,
142}
143
144impl Session {
145 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
146 #[cfg(test)]
147 tokio::task::yield_now().await;
148 let guard = self.db.lock().await;
149 #[cfg(test)]
150 tokio::task::yield_now().await;
151 guard
152 }
153
154 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
155 #[cfg(test)]
156 tokio::task::yield_now().await;
157 let guard = self.connection_pool.lock();
158 ConnectionPoolGuard {
159 guard,
160 _not_send: PhantomData,
161 }
162 }
163
164 fn is_staff(&self) -> bool {
165 match &self.principal {
166 Principal::User(user) => user.admin,
167 Principal::Impersonated { .. } => true,
168 }
169 }
170
171 pub async fn has_llm_subscription(
172 &self,
173 db: &MutexGuard<'_, DbHandle>,
174 ) -> anyhow::Result<bool> {
175 if self.is_staff() {
176 return Ok(true);
177 }
178
179 let user_id = self.user_id();
180
181 Ok(db.has_active_billing_subscription(user_id).await?)
182 }
183
184 pub async fn current_plan(
185 &self,
186 _db: &MutexGuard<'_, DbHandle>,
187 ) -> anyhow::Result<proto::Plan> {
188 if self.is_staff() {
189 Ok(proto::Plan::ZedPro)
190 } else {
191 Ok(proto::Plan::Free)
192 }
193 }
194
195 fn user_id(&self) -> UserId {
196 match &self.principal {
197 Principal::User(user) => user.id,
198 Principal::Impersonated { user, .. } => user.id,
199 }
200 }
201
202 pub fn email(&self) -> Option<String> {
203 match &self.principal {
204 Principal::User(user) => user.email_address.clone(),
205 Principal::Impersonated { user, .. } => user.email_address.clone(),
206 }
207 }
208}
209
210impl Debug for Session {
211 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
212 let mut result = f.debug_struct("Session");
213 match &self.principal {
214 Principal::User(user) => {
215 result.field("user", &user.github_login);
216 }
217 Principal::Impersonated { user, admin } => {
218 result.field("user", &user.github_login);
219 result.field("impersonator", &admin.github_login);
220 }
221 }
222 result.field("connection_id", &self.connection_id).finish()
223 }
224}
225
226struct DbHandle(Arc<Database>);
227
228impl Deref for DbHandle {
229 type Target = Database;
230
231 fn deref(&self) -> &Self::Target {
232 self.0.as_ref()
233 }
234}
235
236pub struct Server {
237 id: parking_lot::Mutex<ServerId>,
238 peer: Arc<Peer>,
239 pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
240 app_state: Arc<AppState>,
241 handlers: HashMap<TypeId, MessageHandler>,
242 teardown: watch::Sender<bool>,
243}
244
245pub(crate) struct ConnectionPoolGuard<'a> {
246 guard: parking_lot::MutexGuard<'a, ConnectionPool>,
247 _not_send: PhantomData<Rc<()>>,
248}
249
250#[derive(Serialize)]
251pub struct ServerSnapshot<'a> {
252 peer: &'a Peer,
253 #[serde(serialize_with = "serialize_deref")]
254 connection_pool: ConnectionPoolGuard<'a>,
255}
256
257pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
258where
259 S: Serializer,
260 T: Deref<Target = U>,
261 U: Serialize,
262{
263 Serialize::serialize(value.deref(), serializer)
264}
265
266impl Server {
267 pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
268 let mut server = Self {
269 id: parking_lot::Mutex::new(id),
270 peer: Peer::new(id.0 as u32),
271 app_state: app_state.clone(),
272 connection_pool: Default::default(),
273 handlers: Default::default(),
274 teardown: watch::channel(false).0,
275 };
276
277 server
278 .add_request_handler(ping)
279 .add_request_handler(create_room)
280 .add_request_handler(join_room)
281 .add_request_handler(rejoin_room)
282 .add_request_handler(leave_room)
283 .add_request_handler(set_room_participant_role)
284 .add_request_handler(call)
285 .add_request_handler(cancel_call)
286 .add_message_handler(decline_call)
287 .add_request_handler(update_participant_location)
288 .add_request_handler(share_project)
289 .add_message_handler(unshare_project)
290 .add_request_handler(join_project)
291 .add_message_handler(leave_project)
292 .add_request_handler(update_project)
293 .add_request_handler(update_worktree)
294 .add_message_handler(start_language_server)
295 .add_message_handler(update_language_server)
296 .add_message_handler(update_diagnostic_summary)
297 .add_message_handler(update_worktree_settings)
298 .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
299 .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
300 .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
301 .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
302 .add_request_handler(forward_find_search_candidates_request)
303 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
304 .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
305 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
306 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
307 .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
308 .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
309 .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
310 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
311 .add_request_handler(forward_read_only_project_request::<proto::GitBranches>)
312 .add_request_handler(forward_read_only_project_request::<proto::GetStagedText>)
313 .add_request_handler(
314 forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
315 )
316 .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
317 .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
318 .add_request_handler(
319 forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
320 )
321 .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
322 .add_request_handler(
323 forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
324 )
325 .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
326 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
327 .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
328 .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
329 .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
330 .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
331 .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
332 .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
333 .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
334 .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
335 .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
336 .add_request_handler(
337 forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
338 )
339 .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
340 .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
341 .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
342 .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
343 .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
344 .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
345 .add_message_handler(create_buffer_for_peer)
346 .add_request_handler(update_buffer)
347 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
348 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
349 .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
350 .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
351 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
352 .add_request_handler(get_users)
353 .add_request_handler(fuzzy_search_users)
354 .add_request_handler(request_contact)
355 .add_request_handler(remove_contact)
356 .add_request_handler(respond_to_contact_request)
357 .add_message_handler(subscribe_to_channels)
358 .add_request_handler(create_channel)
359 .add_request_handler(delete_channel)
360 .add_request_handler(invite_channel_member)
361 .add_request_handler(remove_channel_member)
362 .add_request_handler(set_channel_member_role)
363 .add_request_handler(set_channel_visibility)
364 .add_request_handler(rename_channel)
365 .add_request_handler(join_channel_buffer)
366 .add_request_handler(leave_channel_buffer)
367 .add_message_handler(update_channel_buffer)
368 .add_request_handler(rejoin_channel_buffers)
369 .add_request_handler(get_channel_members)
370 .add_request_handler(respond_to_channel_invite)
371 .add_request_handler(join_channel)
372 .add_request_handler(join_channel_chat)
373 .add_message_handler(leave_channel_chat)
374 .add_request_handler(send_channel_message)
375 .add_request_handler(remove_channel_message)
376 .add_request_handler(update_channel_message)
377 .add_request_handler(get_channel_messages)
378 .add_request_handler(get_channel_messages_by_id)
379 .add_request_handler(get_notifications)
380 .add_request_handler(mark_notification_as_read)
381 .add_request_handler(move_channel)
382 .add_request_handler(follow)
383 .add_message_handler(unfollow)
384 .add_message_handler(update_followers)
385 .add_request_handler(get_private_user_info)
386 .add_request_handler(get_llm_api_token)
387 .add_request_handler(accept_terms_of_service)
388 .add_message_handler(acknowledge_channel_message)
389 .add_message_handler(acknowledge_buffer_version)
390 .add_request_handler(get_supermaven_api_key)
391 .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
392 .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
393 .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
394 .add_request_handler(forward_mutating_project_request::<proto::Stage>)
395 .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
396 .add_request_handler(forward_mutating_project_request::<proto::Commit>)
397 .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
398 .add_message_handler(update_context)
399 .add_request_handler({
400 let app_state = app_state.clone();
401 move |request, response, session| {
402 let app_state = app_state.clone();
403 async move {
404 count_language_model_tokens(request, response, session, &app_state.config)
405 .await
406 }
407 }
408 })
409 .add_request_handler(get_cached_embeddings)
410 .add_request_handler({
411 let app_state = app_state.clone();
412 move |request, response, session| {
413 compute_embeddings(
414 request,
415 response,
416 session,
417 app_state.config.openai_api_key.clone(),
418 )
419 }
420 });
421
422 Arc::new(server)
423 }
424
425 pub async fn start(&self) -> Result<()> {
426 let server_id = *self.id.lock();
427 let app_state = self.app_state.clone();
428 let peer = self.peer.clone();
429 let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
430 let pool = self.connection_pool.clone();
431 let livekit_client = self.app_state.livekit_client.clone();
432
433 let span = info_span!("start server");
434 self.app_state.executor.spawn_detached(
435 async move {
436 tracing::info!("waiting for cleanup timeout");
437 timeout.await;
438 tracing::info!("cleanup timeout expired, retrieving stale rooms");
439 if let Some((room_ids, channel_ids)) = app_state
440 .db
441 .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
442 .await
443 .trace_err()
444 {
445 tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
446 tracing::info!(
447 stale_channel_buffer_count = channel_ids.len(),
448 "retrieved stale channel buffers"
449 );
450
451 for channel_id in channel_ids {
452 if let Some(refreshed_channel_buffer) = app_state
453 .db
454 .clear_stale_channel_buffer_collaborators(channel_id, server_id)
455 .await
456 .trace_err()
457 {
458 for connection_id in refreshed_channel_buffer.connection_ids {
459 peer.send(
460 connection_id,
461 proto::UpdateChannelBufferCollaborators {
462 channel_id: channel_id.to_proto(),
463 collaborators: refreshed_channel_buffer
464 .collaborators
465 .clone(),
466 },
467 )
468 .trace_err();
469 }
470 }
471 }
472
473 for room_id in room_ids {
474 let mut contacts_to_update = HashSet::default();
475 let mut canceled_calls_to_user_ids = Vec::new();
476 let mut livekit_room = String::new();
477 let mut delete_livekit_room = false;
478
479 if let Some(mut refreshed_room) = app_state
480 .db
481 .clear_stale_room_participants(room_id, server_id)
482 .await
483 .trace_err()
484 {
485 tracing::info!(
486 room_id = room_id.0,
487 new_participant_count = refreshed_room.room.participants.len(),
488 "refreshed room"
489 );
490 room_updated(&refreshed_room.room, &peer);
491 if let Some(channel) = refreshed_room.channel.as_ref() {
492 channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
493 }
494 contacts_to_update
495 .extend(refreshed_room.stale_participant_user_ids.iter().copied());
496 contacts_to_update
497 .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
498 canceled_calls_to_user_ids =
499 mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
500 livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
501 delete_livekit_room = refreshed_room.room.participants.is_empty();
502 }
503
504 {
505 let pool = pool.lock();
506 for canceled_user_id in canceled_calls_to_user_ids {
507 for connection_id in pool.user_connection_ids(canceled_user_id) {
508 peer.send(
509 connection_id,
510 proto::CallCanceled {
511 room_id: room_id.to_proto(),
512 },
513 )
514 .trace_err();
515 }
516 }
517 }
518
519 for user_id in contacts_to_update {
520 let busy = app_state.db.is_user_busy(user_id).await.trace_err();
521 let contacts = app_state.db.get_contacts(user_id).await.trace_err();
522 if let Some((busy, contacts)) = busy.zip(contacts) {
523 let pool = pool.lock();
524 let updated_contact = contact_for_user(user_id, busy, &pool);
525 for contact in contacts {
526 if let db::Contact::Accepted {
527 user_id: contact_user_id,
528 ..
529 } = contact
530 {
531 for contact_conn_id in
532 pool.user_connection_ids(contact_user_id)
533 {
534 peer.send(
535 contact_conn_id,
536 proto::UpdateContacts {
537 contacts: vec![updated_contact.clone()],
538 remove_contacts: Default::default(),
539 incoming_requests: Default::default(),
540 remove_incoming_requests: Default::default(),
541 outgoing_requests: Default::default(),
542 remove_outgoing_requests: Default::default(),
543 },
544 )
545 .trace_err();
546 }
547 }
548 }
549 }
550 }
551
552 if let Some(live_kit) = livekit_client.as_ref() {
553 if delete_livekit_room {
554 live_kit.delete_room(livekit_room).await.trace_err();
555 }
556 }
557 }
558 }
559
560 app_state
561 .db
562 .delete_stale_servers(&app_state.config.zed_environment, server_id)
563 .await
564 .trace_err();
565 }
566 .instrument(span),
567 );
568 Ok(())
569 }
570
571 pub fn teardown(&self) {
572 self.peer.teardown();
573 self.connection_pool.lock().reset();
574 let _ = self.teardown.send(true);
575 }
576
577 #[cfg(test)]
578 pub fn reset(&self, id: ServerId) {
579 self.teardown();
580 *self.id.lock() = id;
581 self.peer.reset(id.0 as u32);
582 let _ = self.teardown.send(false);
583 }
584
585 #[cfg(test)]
586 pub fn id(&self) -> ServerId {
587 *self.id.lock()
588 }
589
590 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
591 where
592 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
593 Fut: 'static + Send + Future<Output = Result<()>>,
594 M: EnvelopedMessage,
595 {
596 let prev_handler = self.handlers.insert(
597 TypeId::of::<M>(),
598 Box::new(move |envelope, session| {
599 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
600 let received_at = envelope.received_at;
601 tracing::info!("message received");
602 let start_time = Instant::now();
603 let future = (handler)(*envelope, session);
604 async move {
605 let result = future.await;
606 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
607 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
608 let queue_duration_ms = total_duration_ms - processing_duration_ms;
609 let payload_type = M::NAME;
610
611 match result {
612 Err(error) => {
613 tracing::error!(
614 ?error,
615 total_duration_ms,
616 processing_duration_ms,
617 queue_duration_ms,
618 payload_type,
619 "error handling message"
620 )
621 }
622 Ok(()) => tracing::info!(
623 total_duration_ms,
624 processing_duration_ms,
625 queue_duration_ms,
626 "finished handling message"
627 ),
628 }
629 }
630 .boxed()
631 }),
632 );
633 if prev_handler.is_some() {
634 panic!("registered a handler for the same message twice");
635 }
636 self
637 }
638
639 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
640 where
641 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
642 Fut: 'static + Send + Future<Output = Result<()>>,
643 M: EnvelopedMessage,
644 {
645 self.add_handler(move |envelope, session| handler(envelope.payload, session));
646 self
647 }
648
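    /// Registers a handler for a request/response message.
    ///
    /// Handlers receive the decoded payload, a typed [`Response`], and the [`Session`],
    /// and must call `response.send(..)` exactly once on success; otherwise the request
    /// fails with "handler did not send a response". A rough sketch, mirroring the
    /// `ping` handler defined later in this file:
    ///
    /// ```ignore
    /// async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
    ///     response.send(proto::Ack {})?;
    ///     Ok(())
    /// }
    /// ```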
649 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
650 where
651 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
652 Fut: Send + Future<Output = Result<()>>,
653 M: RequestMessage,
654 {
655 let handler = Arc::new(handler);
656 self.add_handler(move |envelope, session| {
657 let receipt = envelope.receipt();
658 let handler = handler.clone();
659 async move {
660 let peer = session.peer.clone();
661 let responded = Arc::new(AtomicBool::default());
662 let response = Response {
663 peer: peer.clone(),
664 responded: responded.clone(),
665 receipt,
666 };
667 match (handler)(envelope.payload, response, session).await {
668 Ok(()) => {
669 if responded.load(std::sync::atomic::Ordering::SeqCst) {
670 Ok(())
671 } else {
672 Err(anyhow!("handler did not send a response"))?
673 }
674 }
675 Err(error) => {
676 let proto_err = match &error {
677 Error::Internal(err) => err.to_proto(),
678 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
679 };
680 peer.respond_with_error(receipt, proto_err)?;
681 Err(error)
682 }
683 }
684 }
685 })
686 }
687
688 #[allow(clippy::too_many_arguments)]
689 pub fn handle_connection(
690 self: &Arc<Self>,
691 connection: Connection,
692 address: String,
693 principal: Principal,
694 zed_version: ZedVersion,
695 geoip_country_code: Option<String>,
696 system_id: Option<String>,
697 send_connection_id: Option<oneshot::Sender<ConnectionId>>,
698 executor: Executor,
699 ) -> impl Future<Output = ()> {
700 let this = self.clone();
701 let span = info_span!("handle connection", %address,
702 connection_id=field::Empty,
703 user_id=field::Empty,
704 login=field::Empty,
705 impersonator=field::Empty,
706 geoip_country_code=field::Empty
707 );
708 principal.update_span(&span);
709 if let Some(country_code) = geoip_country_code.as_ref() {
710 span.record("geoip_country_code", country_code);
711 }
712
713 let mut teardown = self.teardown.subscribe();
714 async move {
715 if *teardown.borrow() {
716 tracing::error!("server is tearing down");
717 return
718 }
719 let (connection_id, handle_io, mut incoming_rx) = this
720 .peer
721 .add_connection(connection, {
722 let executor = executor.clone();
723 move |duration| executor.sleep(duration)
724 });
725 tracing::Span::current().record("connection_id", format!("{}", connection_id));
726
727 tracing::info!("connection opened");
728
729 let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
730 let http_client = match ReqwestClient::user_agent(&user_agent) {
731 Ok(http_client) => Arc::new(http_client),
732 Err(error) => {
733 tracing::error!(?error, "failed to create HTTP client");
734 return;
735 }
736 };
737
            let supermaven_client = this
                .app_state
                .config
                .supermaven_admin_api_key
                .clone()
                .map(|supermaven_admin_api_key| {
                    Arc::new(SupermavenAdminApi::new(
                        supermaven_admin_api_key.to_string(),
                        http_client.clone(),
                    ))
                });
742
743 let session = Session {
744 principal: principal.clone(),
745 connection_id,
746 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
747 peer: this.peer.clone(),
748 connection_pool: this.connection_pool.clone(),
749 app_state: this.app_state.clone(),
750 http_client,
751 geoip_country_code,
752 system_id,
753 _executor: executor.clone(),
754 supermaven_client,
755 };
756
757 if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
758 tracing::error!(?error, "failed to send initial client update");
759 return;
760 }
761
762 let handle_io = handle_io.fuse();
763 futures::pin_mut!(handle_io);
764
            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when, for example, client A performs a request to client B and
            // client B performs a request to client A. If both clients stopped processing further
            // messages until their respective request completed, neither would get a chance to
            // respond to the other's request, resulting in a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages that arrived later in the spirit of making progress.
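            // For example, a forwarded request such as `GetDefinition` keeps its handler alive
            // until the other client responds; processing it concurrently with later messages
            // means this connection can still answer the other client's requests in the meantime.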
773 let mut foreground_message_handlers = FuturesUnordered::new();
774 let concurrent_handlers = Arc::new(Semaphore::new(256));
775 loop {
776 let next_message = async {
777 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
778 let message = incoming_rx.next().await;
779 (permit, message)
780 }.fuse();
781 futures::pin_mut!(next_message);
782 futures::select_biased! {
783 _ = teardown.changed().fuse() => return,
784 result = handle_io => {
785 if let Err(error) = result {
786 tracing::error!(?error, "error handling I/O");
787 }
788 break;
789 }
790 _ = foreground_message_handlers.next() => {}
791 next_message = next_message => {
792 let (permit, message) = next_message;
793 if let Some(message) = message {
794 let type_name = message.payload_type_name();
795 // note: we copy all the fields from the parent span so we can query them in the logs.
796 // (https://github.com/tokio-rs/tracing/issues/2670).
797 let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
798 user_id=field::Empty,
799 login=field::Empty,
800 impersonator=field::Empty,
801 );
802 principal.update_span(&span);
803 let span_enter = span.enter();
804 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
805 let is_background = message.is_background();
806 let handle_message = (handler)(message, session.clone());
807 drop(span_enter);
808
809 let handle_message = async move {
810 handle_message.await;
811 drop(permit);
812 }.instrument(span);
813 if is_background {
814 executor.spawn_detached(handle_message);
815 } else {
816 foreground_message_handlers.push(handle_message);
817 }
818 } else {
819 tracing::error!("no message handler");
820 }
821 } else {
822 tracing::info!("connection closed");
823 break;
824 }
825 }
826 }
827 }
828
829 drop(foreground_message_handlers);
830 tracing::info!("signing out");
831 if let Err(error) = connection_lost(session, teardown, executor).await {
832 tracing::error!(?error, "error signing out");
833 }
834
835 }.instrument(span)
836 }
837
838 async fn send_initial_client_update(
839 &self,
840 connection_id: ConnectionId,
841 principal: &Principal,
842 zed_version: ZedVersion,
843 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
844 session: &Session,
845 ) -> Result<()> {
846 self.peer.send(
847 connection_id,
848 proto::Hello {
849 peer_id: Some(connection_id.into()),
850 },
851 )?;
852 tracing::info!("sent hello message");
853 if let Some(send_connection_id) = send_connection_id.take() {
854 let _ = send_connection_id.send(connection_id);
855 }
856
857 match principal {
858 Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
859 if !user.connected_once {
860 self.peer.send(connection_id, proto::ShowContacts {})?;
861 self.app_state
862 .db
863 .set_user_connected_once(user.id, true)
864 .await?;
865 }
866
867 update_user_plan(user.id, session).await?;
868
869 let contacts = self.app_state.db.get_contacts(user.id).await?;
870
871 {
872 let mut pool = self.connection_pool.lock();
873 pool.add_connection(connection_id, user.id, user.admin, zed_version);
874 self.peer.send(
875 connection_id,
876 build_initial_contacts_update(contacts, &pool),
877 )?;
878 }
879
880 if should_auto_subscribe_to_channels(zed_version) {
881 subscribe_user_to_channels(user.id, session).await?;
882 }
883
884 if let Some(incoming_call) =
885 self.app_state.db.incoming_call_for_user(user.id).await?
886 {
887 self.peer.send(connection_id, incoming_call)?;
888 }
889
890 update_user_contacts(user.id, session).await?;
891 }
892 }
893
894 Ok(())
895 }
896
897 pub async fn invite_code_redeemed(
898 self: &Arc<Self>,
899 inviter_id: UserId,
900 invitee_id: UserId,
901 ) -> Result<()> {
902 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
903 if let Some(code) = &user.invite_code {
904 let pool = self.connection_pool.lock();
905 let invitee_contact = contact_for_user(invitee_id, false, &pool);
906 for connection_id in pool.user_connection_ids(inviter_id) {
907 self.peer.send(
908 connection_id,
909 proto::UpdateContacts {
910 contacts: vec![invitee_contact.clone()],
911 ..Default::default()
912 },
913 )?;
914 self.peer.send(
915 connection_id,
916 proto::UpdateInviteInfo {
917 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
918 count: user.invite_count as u32,
919 },
920 )?;
921 }
922 }
923 }
924 Ok(())
925 }
926
927 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
928 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
929 if let Some(invite_code) = &user.invite_code {
930 let pool = self.connection_pool.lock();
931 for connection_id in pool.user_connection_ids(user_id) {
932 self.peer.send(
933 connection_id,
934 proto::UpdateInviteInfo {
935 url: format!(
936 "{}{}",
937 self.app_state.config.invite_link_prefix, invite_code
938 ),
939 count: user.invite_count as u32,
940 },
941 )?;
942 }
943 }
944 }
945 Ok(())
946 }
947
948 pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
949 let pool = self.connection_pool.lock();
950 for connection_id in pool.user_connection_ids(user_id) {
951 self.peer
952 .send(connection_id, proto::RefreshLlmToken {})
953 .trace_err();
954 }
955 }
956
957 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
958 ServerSnapshot {
959 connection_pool: ConnectionPoolGuard {
960 guard: self.connection_pool.lock(),
961 _not_send: PhantomData,
962 },
963 peer: &self.peer,
964 }
965 }
966}
967
968impl<'a> Deref for ConnectionPoolGuard<'a> {
969 type Target = ConnectionPool;
970
971 fn deref(&self) -> &Self::Target {
972 &self.guard
973 }
974}
975
976impl<'a> DerefMut for ConnectionPoolGuard<'a> {
977 fn deref_mut(&mut self) -> &mut Self::Target {
978 &mut self.guard
979 }
980}
981
982impl<'a> Drop for ConnectionPoolGuard<'a> {
983 fn drop(&mut self) {
984 #[cfg(test)]
985 self.check_invariants();
986 }
987}
988
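/// Sends the message produced by `f` to every connection in `receiver_ids` except
/// `sender_id`, logging (rather than propagating) any per-connection send failure.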
989fn broadcast<F>(
990 sender_id: Option<ConnectionId>,
991 receiver_ids: impl IntoIterator<Item = ConnectionId>,
992 mut f: F,
993) where
994 F: FnMut(ConnectionId) -> anyhow::Result<()>,
995{
996 for receiver_id in receiver_ids {
997 if Some(receiver_id) != sender_id {
998 if let Err(error) = f(receiver_id) {
999 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1000 }
1001 }
1002 }
1003}
1004
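/// The `x-zed-protocol-version` header sent by clients when opening the `/rpc`
/// websocket; requests with a mismatched version are rejected with
/// `426 Upgrade Required` in [`handle_websocket_request`].
///
/// A rough sketch of encoding it by hand (the version value here is made up):
///
/// ```ignore
/// let mut values: Vec<axum::http::HeaderValue> = Vec::new();
/// ProtocolVersion(68).encode(&mut values);
/// assert_eq!(values[0], "68");
/// ```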
1005pub struct ProtocolVersion(u32);
1006
1007impl Header for ProtocolVersion {
1008 fn name() -> &'static HeaderName {
1009 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1010 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1011 }
1012
1013 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1014 where
1015 Self: Sized,
1016 I: Iterator<Item = &'i axum::http::HeaderValue>,
1017 {
1018 let version = values
1019 .next()
1020 .ok_or_else(axum::headers::Error::invalid)?
1021 .to_str()
1022 .map_err(|_| axum::headers::Error::invalid())?
1023 .parse()
1024 .map_err(|_| axum::headers::Error::invalid())?;
1025 Ok(Self(version))
1026 }
1027
1028 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1029 values.extend([self.0.to_string().parse().unwrap()]);
1030 }
1031}
1032
1033pub struct AppVersionHeader(SemanticVersion);
1034impl Header for AppVersionHeader {
1035 fn name() -> &'static HeaderName {
1036 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1037 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1038 }
1039
1040 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1041 where
1042 Self: Sized,
1043 I: Iterator<Item = &'i axum::http::HeaderValue>,
1044 {
1045 let version = values
1046 .next()
1047 .ok_or_else(axum::headers::Error::invalid)?
1048 .to_str()
1049 .map_err(|_| axum::headers::Error::invalid())?
1050 .parse()
1051 .map_err(|_| axum::headers::Error::invalid())?;
1052 Ok(Self(version))
1053 }
1054
1055 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1056 values.extend([self.0.to_string().parse().unwrap()]);
1057 }
1058}
1059
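/// Builds the collab RPC router: `/rpc` (websocket upgrade, gated by
/// `auth::validate_header`) and `/metrics` (Prometheus text format).
///
/// A minimal sketch of mounting it; the binding address and state construction here
/// are illustrative assumptions, not taken from this crate's actual entry point:
///
/// ```ignore
/// let server = Server::new(ServerId(1), app_state.clone());
/// let app = routes(server);
/// axum::Server::bind(&"0.0.0.0:8080".parse().unwrap())
///     .serve(app.into_make_service_with_connect_info::<std::net::SocketAddr>())
///     .await?;
/// ```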
1060pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1061 Router::new()
1062 .route("/rpc", get(handle_websocket_request))
1063 .layer(
1064 ServiceBuilder::new()
1065 .layer(Extension(server.app_state.clone()))
1066 .layer(middleware::from_fn(auth::validate_header)),
1067 )
1068 .route("/metrics", get(handle_metrics))
1069 .layer(Extension(server))
1070}
1071
1072#[allow(clippy::too_many_arguments)]
1073pub async fn handle_websocket_request(
1074 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1075 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1076 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1077 Extension(server): Extension<Arc<Server>>,
1078 Extension(principal): Extension<Principal>,
1079 country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1080 system_id_header: Option<TypedHeader<SystemIdHeader>>,
1081 ws: WebSocketUpgrade,
1082) -> axum::response::Response {
1083 if protocol_version != rpc::PROTOCOL_VERSION {
1084 return (
1085 StatusCode::UPGRADE_REQUIRED,
1086 "client must be upgraded".to_string(),
1087 )
1088 .into_response();
1089 }
1090
1091 let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1092 return (
1093 StatusCode::UPGRADE_REQUIRED,
1094 "no version header found".to_string(),
1095 )
1096 .into_response();
1097 };
1098
1099 if !version.can_collaborate() {
1100 return (
1101 StatusCode::UPGRADE_REQUIRED,
1102 "client must be upgraded".to_string(),
1103 )
1104 .into_response();
1105 }
1106
1107 let socket_address = socket_address.to_string();
1108 ws.on_upgrade(move |socket| {
1109 let socket = socket
1110 .map_ok(to_tungstenite_message)
1111 .err_into()
1112 .with(|message| async move { to_axum_message(message) });
1113 let connection = Connection::new(Box::pin(socket));
1114 async move {
1115 server
1116 .handle_connection(
1117 connection,
1118 socket_address,
1119 principal,
1120 version,
1121 country_code_header.map(|header| header.to_string()),
1122 system_id_header.map(|header| header.to_string()),
1123 None,
1124 Executor::Production,
1125 )
1126 .await;
1127 }
1128 })
1129}
1130
1131pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1132 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1133 let connections_metric = CONNECTIONS_METRIC
1134 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1135
1136 let connections = server
1137 .connection_pool
1138 .lock()
1139 .connections()
1140 .filter(|connection| !connection.admin)
1141 .count();
1142 connections_metric.set(connections as _);
1143
1144 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1145 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1146 register_int_gauge!(
1147 "shared_projects",
1148 "number of open projects with one or more guests"
1149 )
1150 .unwrap()
1151 });
1152
1153 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1154 shared_projects_metric.set(shared_projects as _);
1155
1156 let encoder = prometheus::TextEncoder::new();
1157 let metric_families = prometheus::gather();
1158 let encoded_metrics = encoder
1159 .encode_to_string(&metric_families)
1160 .map_err(|err| anyhow!("{}", err))?;
1161 Ok(encoded_metrics)
1162}
1163
1164#[instrument(err, skip(executor))]
1165async fn connection_lost(
1166 session: Session,
1167 mut teardown: watch::Receiver<bool>,
1168 executor: Executor,
1169) -> Result<()> {
1170 session.peer.disconnect(session.connection_id);
1171 session
1172 .connection_pool()
1173 .await
1174 .remove_connection(session.connection_id)?;
1175
1176 session
1177 .db()
1178 .await
1179 .connection_lost(session.connection_id)
1180 .await
1181 .trace_err();
1182
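    // Give the client RECONNECT_TIMEOUT to reconnect before tearing down its rooms and
    // channel buffers; if the server itself is tearing down, skip that cleanup entirely.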
1183 futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1187 leave_room_for_session(&session, session.connection_id).await.trace_err();
1188 leave_channel_buffers_for_session(&session)
1189 .await
1190 .trace_err();
1191
1192 if !session
1193 .connection_pool()
1194 .await
1195 .is_user_online(session.user_id())
1196 {
1197 let db = session.db().await;
1198 if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1199 room_updated(&room, &session.peer);
1200 }
1201 }
1202
1203 update_user_contacts(session.user_id(), &session).await?;
1204 },
1205 _ = teardown.changed().fuse() => {}
1206 }
1207
1208 Ok(())
1209}
1210
1211/// Acknowledges a ping from a client, used to keep the connection alive.
1212async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1213 response.send(proto::Ack {})?;
1214 Ok(())
1215}
1216
/// Creates a new room for calling (outside of channels).
1218async fn create_room(
1219 _request: proto::CreateRoom,
1220 response: Response<proto::CreateRoom>,
1221 session: Session,
1222) -> Result<()> {
1223 let livekit_room = nanoid::nanoid!(30);
1224
1225 let live_kit_connection_info = util::maybe!(async {
1226 let live_kit = session.app_state.livekit_client.as_ref();
1227 let live_kit = live_kit?;
1228 let user_id = session.user_id().to_string();
1229
1230 let token = live_kit
1231 .room_token(&livekit_room, &user_id.to_string())
1232 .trace_err()?;
1233
1234 Some(proto::LiveKitConnectionInfo {
1235 server_url: live_kit.url().into(),
1236 token,
1237 can_publish: true,
1238 })
1239 })
1240 .await;
1241
1242 let room = session
1243 .db()
1244 .await
1245 .create_room(session.user_id(), session.connection_id, &livekit_room)
1246 .await?;
1247
1248 response.send(proto::CreateRoomResponse {
1249 room: Some(room.clone()),
1250 live_kit_connection_info,
1251 })?;
1252
1253 update_user_contacts(session.user_id(), &session).await?;
1254 Ok(())
1255}
1256
1257/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1258async fn join_room(
1259 request: proto::JoinRoom,
1260 response: Response<proto::JoinRoom>,
1261 session: Session,
1262) -> Result<()> {
1263 let room_id = RoomId::from_proto(request.id);
1264
1265 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1266
1267 if let Some(channel_id) = channel_id {
1268 return join_channel_internal(channel_id, Box::new(response), session).await;
1269 }
1270
1271 let joined_room = {
1272 let room = session
1273 .db()
1274 .await
1275 .join_room(room_id, session.user_id(), session.connection_id)
1276 .await?;
1277 room_updated(&room.room, &session.peer);
1278 room.into_inner()
1279 };
1280
1281 for connection_id in session
1282 .connection_pool()
1283 .await
1284 .user_connection_ids(session.user_id())
1285 {
1286 session
1287 .peer
1288 .send(
1289 connection_id,
1290 proto::CallCanceled {
1291 room_id: room_id.to_proto(),
1292 },
1293 )
1294 .trace_err();
1295 }
1296
1297 let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1298 {
1299 live_kit
1300 .room_token(
1301 &joined_room.room.livekit_room,
1302 &session.user_id().to_string(),
1303 )
1304 .trace_err()
1305 .map(|token| proto::LiveKitConnectionInfo {
1306 server_url: live_kit.url().into(),
1307 token,
1308 can_publish: true,
1309 })
1310 } else {
1311 None
1312 };
1313
1314 response.send(proto::JoinRoomResponse {
1315 room: Some(joined_room.room),
1316 channel_id: None,
1317 live_kit_connection_info,
1318 })?;
1319
1320 update_user_contacts(session.user_id(), &session).await?;
1321 Ok(())
1322}
1323
/// Rejoins a room. Used to reconnect to a room after a connection error.
1325async fn rejoin_room(
1326 request: proto::RejoinRoom,
1327 response: Response<proto::RejoinRoom>,
1328 session: Session,
1329) -> Result<()> {
1330 let room;
1331 let channel;
1332 {
1333 let mut rejoined_room = session
1334 .db()
1335 .await
1336 .rejoin_room(request, session.user_id(), session.connection_id)
1337 .await?;
1338
1339 response.send(proto::RejoinRoomResponse {
1340 room: Some(rejoined_room.room.clone()),
1341 reshared_projects: rejoined_room
1342 .reshared_projects
1343 .iter()
1344 .map(|project| proto::ResharedProject {
1345 id: project.id.to_proto(),
1346 collaborators: project
1347 .collaborators
1348 .iter()
1349 .map(|collaborator| collaborator.to_proto())
1350 .collect(),
1351 })
1352 .collect(),
1353 rejoined_projects: rejoined_room
1354 .rejoined_projects
1355 .iter()
1356 .map(|rejoined_project| rejoined_project.to_proto())
1357 .collect(),
1358 })?;
1359 room_updated(&rejoined_room.room, &session.peer);
1360
1361 for project in &rejoined_room.reshared_projects {
1362 for collaborator in &project.collaborators {
1363 session
1364 .peer
1365 .send(
1366 collaborator.connection_id,
1367 proto::UpdateProjectCollaborator {
1368 project_id: project.id.to_proto(),
1369 old_peer_id: Some(project.old_connection_id.into()),
1370 new_peer_id: Some(session.connection_id.into()),
1371 },
1372 )
1373 .trace_err();
1374 }
1375
1376 broadcast(
1377 Some(session.connection_id),
1378 project
1379 .collaborators
1380 .iter()
1381 .map(|collaborator| collaborator.connection_id),
1382 |connection_id| {
1383 session.peer.forward_send(
1384 session.connection_id,
1385 connection_id,
1386 proto::UpdateProject {
1387 project_id: project.id.to_proto(),
1388 worktrees: project.worktrees.clone(),
1389 },
1390 )
1391 },
1392 );
1393 }
1394
1395 notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;
1396
1397 let rejoined_room = rejoined_room.into_inner();
1398
1399 room = rejoined_room.room;
1400 channel = rejoined_room.channel;
1401 }
1402
1403 if let Some(channel) = channel {
1404 channel_updated(
1405 &channel,
1406 &room,
1407 &session.peer,
1408 &*session.connection_pool().await,
1409 );
1410 }
1411
1412 update_user_contacts(session.user_id(), &session).await?;
1413 Ok(())
1414}
1415
1416fn notify_rejoined_projects(
1417 rejoined_projects: &mut Vec<RejoinedProject>,
1418 session: &Session,
1419) -> Result<()> {
1420 for project in rejoined_projects.iter() {
1421 for collaborator in &project.collaborators {
1422 session
1423 .peer
1424 .send(
1425 collaborator.connection_id,
1426 proto::UpdateProjectCollaborator {
1427 project_id: project.id.to_proto(),
1428 old_peer_id: Some(project.old_connection_id.into()),
1429 new_peer_id: Some(session.connection_id.into()),
1430 },
1431 )
1432 .trace_err();
1433 }
1434 }
1435
1436 for project in rejoined_projects {
1437 for worktree in mem::take(&mut project.worktrees) {
1438 // Stream this worktree's entries.
1439 let message = proto::UpdateWorktree {
1440 project_id: project.id.to_proto(),
1441 worktree_id: worktree.id,
1442 abs_path: worktree.abs_path.clone(),
1443 root_name: worktree.root_name,
1444 updated_entries: worktree.updated_entries,
1445 removed_entries: worktree.removed_entries,
1446 scan_id: worktree.scan_id,
1447 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1448 updated_repositories: worktree.updated_repositories,
1449 removed_repositories: worktree.removed_repositories,
1450 };
1451 for update in proto::split_worktree_update(message) {
1452 session.peer.send(session.connection_id, update.clone())?;
1453 }
1454
1455 // Stream this worktree's diagnostics.
1456 for summary in worktree.diagnostic_summaries {
1457 session.peer.send(
1458 session.connection_id,
1459 proto::UpdateDiagnosticSummary {
1460 project_id: project.id.to_proto(),
1461 worktree_id: worktree.id,
1462 summary: Some(summary),
1463 },
1464 )?;
1465 }
1466
1467 for settings_file in worktree.settings_files {
1468 session.peer.send(
1469 session.connection_id,
1470 proto::UpdateWorktreeSettings {
1471 project_id: project.id.to_proto(),
1472 worktree_id: worktree.id,
1473 path: settings_file.path,
1474 content: Some(settings_file.content),
1475 kind: Some(settings_file.kind.to_proto().into()),
1476 },
1477 )?;
1478 }
1479 }
1480
1481 for language_server in &project.language_servers {
1482 session.peer.send(
1483 session.connection_id,
1484 proto::UpdateLanguageServer {
1485 project_id: project.id.to_proto(),
1486 language_server_id: language_server.id,
1487 variant: Some(
1488 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1489 proto::LspDiskBasedDiagnosticsUpdated {},
1490 ),
1491 ),
1492 },
1493 )?;
1494 }
1495 }
1496 Ok(())
1497}
1498
/// Leaves the current room, disconnecting the caller from it.
1500async fn leave_room(
1501 _: proto::LeaveRoom,
1502 response: Response<proto::LeaveRoom>,
1503 session: Session,
1504) -> Result<()> {
1505 leave_room_for_session(&session, session.connection_id).await?;
1506 response.send(proto::Ack {})?;
1507 Ok(())
1508}
1509
1510/// Updates the permissions of someone else in the room.
1511async fn set_room_participant_role(
1512 request: proto::SetRoomParticipantRole,
1513 response: Response<proto::SetRoomParticipantRole>,
1514 session: Session,
1515) -> Result<()> {
1516 let user_id = UserId::from_proto(request.user_id);
1517 let role = ChannelRole::from(request.role());
1518
1519 let (livekit_room, can_publish) = {
1520 let room = session
1521 .db()
1522 .await
1523 .set_room_participant_role(
1524 session.user_id(),
1525 RoomId::from_proto(request.room_id),
1526 user_id,
1527 role,
1528 )
1529 .await?;
1530
1531 let livekit_room = room.livekit_room.clone();
1532 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1533 room_updated(&room, &session.peer);
1534 (livekit_room, can_publish)
1535 };
1536
1537 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1538 live_kit
1539 .update_participant(
1540 livekit_room.clone(),
1541 request.user_id.to_string(),
1542 livekit_server::proto::ParticipantPermission {
1543 can_subscribe: true,
1544 can_publish,
1545 can_publish_data: can_publish,
1546 hidden: false,
1547 recorder: false,
1548 },
1549 )
1550 .await
1551 .trace_err();
1552 }
1553
1554 response.send(proto::Ack {})?;
1555 Ok(())
1556}
1557
/// Call someone else into the current room.
1559async fn call(
1560 request: proto::Call,
1561 response: Response<proto::Call>,
1562 session: Session,
1563) -> Result<()> {
1564 let room_id = RoomId::from_proto(request.room_id);
1565 let calling_user_id = session.user_id();
1566 let calling_connection_id = session.connection_id;
1567 let called_user_id = UserId::from_proto(request.called_user_id);
1568 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1569 if !session
1570 .db()
1571 .await
1572 .has_contact(calling_user_id, called_user_id)
1573 .await?
1574 {
1575 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1576 }
1577
1578 let incoming_call = {
1579 let (room, incoming_call) = &mut *session
1580 .db()
1581 .await
1582 .call(
1583 room_id,
1584 calling_user_id,
1585 calling_connection_id,
1586 called_user_id,
1587 initial_project_id,
1588 )
1589 .await?;
1590 room_updated(room, &session.peer);
1591 mem::take(incoming_call)
1592 };
1593 update_user_contacts(called_user_id, &session).await?;
1594
1595 let mut calls = session
1596 .connection_pool()
1597 .await
1598 .user_connection_ids(called_user_id)
1599 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1600 .collect::<FuturesUnordered<_>>();
1601
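    // Ring every connection belonging to the called user; the first one to answer with
    // an Ack wins, and connections that fail to respond are merely logged.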
1602 while let Some(call_response) = calls.next().await {
1603 match call_response.as_ref() {
1604 Ok(_) => {
1605 response.send(proto::Ack {})?;
1606 return Ok(());
1607 }
1608 Err(_) => {
1609 call_response.trace_err();
1610 }
1611 }
1612 }
1613
1614 {
1615 let room = session
1616 .db()
1617 .await
1618 .call_failed(room_id, called_user_id)
1619 .await?;
1620 room_updated(&room, &session.peer);
1621 }
1622 update_user_contacts(called_user_id, &session).await?;
1623
1624 Err(anyhow!("failed to ring user"))?
1625}
1626
1627/// Cancel an outgoing call.
1628async fn cancel_call(
1629 request: proto::CancelCall,
1630 response: Response<proto::CancelCall>,
1631 session: Session,
1632) -> Result<()> {
1633 let called_user_id = UserId::from_proto(request.called_user_id);
1634 let room_id = RoomId::from_proto(request.room_id);
1635 {
1636 let room = session
1637 .db()
1638 .await
1639 .cancel_call(room_id, session.connection_id, called_user_id)
1640 .await?;
1641 room_updated(&room, &session.peer);
1642 }
1643
1644 for connection_id in session
1645 .connection_pool()
1646 .await
1647 .user_connection_ids(called_user_id)
1648 {
1649 session
1650 .peer
1651 .send(
1652 connection_id,
1653 proto::CallCanceled {
1654 room_id: room_id.to_proto(),
1655 },
1656 )
1657 .trace_err();
1658 }
1659 response.send(proto::Ack {})?;
1660
1661 update_user_contacts(called_user_id, &session).await?;
1662 Ok(())
1663}
1664
1665/// Decline an incoming call.
1666async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1667 let room_id = RoomId::from_proto(message.room_id);
1668 {
1669 let room = session
1670 .db()
1671 .await
1672 .decline_call(Some(room_id), session.user_id())
1673 .await?
1674 .ok_or_else(|| anyhow!("failed to decline call"))?;
1675 room_updated(&room, &session.peer);
1676 }
1677
1678 for connection_id in session
1679 .connection_pool()
1680 .await
1681 .user_connection_ids(session.user_id())
1682 {
1683 session
1684 .peer
1685 .send(
1686 connection_id,
1687 proto::CallCanceled {
1688 room_id: room_id.to_proto(),
1689 },
1690 )
1691 .trace_err();
1692 }
1693 update_user_contacts(session.user_id(), &session).await?;
1694 Ok(())
1695}
1696
1697/// Updates other participants in the room with your current location.
1698async fn update_participant_location(
1699 request: proto::UpdateParticipantLocation,
1700 response: Response<proto::UpdateParticipantLocation>,
1701 session: Session,
1702) -> Result<()> {
1703 let room_id = RoomId::from_proto(request.room_id);
1704 let location = request
1705 .location
1706 .ok_or_else(|| anyhow!("invalid location"))?;
1707
1708 let db = session.db().await;
1709 let room = db
1710 .update_room_participant_location(room_id, session.connection_id, location)
1711 .await?;
1712
1713 room_updated(&room, &session.peer);
1714 response.send(proto::Ack {})?;
1715 Ok(())
1716}
1717
1718/// Share a project into the room.
1719async fn share_project(
1720 request: proto::ShareProject,
1721 response: Response<proto::ShareProject>,
1722 session: Session,
1723) -> Result<()> {
1724 let (project_id, room) = &*session
1725 .db()
1726 .await
1727 .share_project(
1728 RoomId::from_proto(request.room_id),
1729 session.connection_id,
1730 &request.worktrees,
1731 request.is_ssh_project,
1732 )
1733 .await?;
1734 response.send(proto::ShareProjectResponse {
1735 project_id: project_id.to_proto(),
1736 })?;
1737 room_updated(room, &session.peer);
1738
1739 Ok(())
1740}
1741
1742/// Unshare a project from the room.
1743async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1744 let project_id = ProjectId::from_proto(message.project_id);
1745 unshare_project_internal(project_id, session.connection_id, &session).await
1746}
1747
1748async fn unshare_project_internal(
1749 project_id: ProjectId,
1750 connection_id: ConnectionId,
1751 session: &Session,
1752) -> Result<()> {
1753 let delete = {
1754 let room_guard = session
1755 .db()
1756 .await
1757 .unshare_project(project_id, connection_id)
1758 .await?;
1759
1760 let (delete, room, guest_connection_ids) = &*room_guard;
1761
1762 let message = proto::UnshareProject {
1763 project_id: project_id.to_proto(),
1764 };
1765
1766 broadcast(
1767 Some(connection_id),
1768 guest_connection_ids.iter().copied(),
1769 |conn_id| session.peer.send(conn_id, message.clone()),
1770 );
1771 if let Some(room) = room {
1772 room_updated(room, &session.peer);
1773 }
1774
1775 *delete
1776 };
1777
1778 if delete {
1779 let db = session.db().await;
1780 db.delete_project(project_id).await?;
1781 }
1782
1783 Ok(())
1784}
1785
/// Join someone else's shared project.
1787async fn join_project(
1788 request: proto::JoinProject,
1789 response: Response<proto::JoinProject>,
1790 session: Session,
1791) -> Result<()> {
1792 let project_id = ProjectId::from_proto(request.project_id);
1793
1794 tracing::info!(%project_id, "join project");
1795
1796 let db = session.db().await;
1797 let (project, replica_id) = &mut *db
1798 .join_project(project_id, session.connection_id, session.user_id())
1799 .await?;
1800 drop(db);
1801 tracing::info!(%project_id, "join remote project");
1802 join_project_internal(response, session, project, replica_id)
1803}
1804
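/// Abstraction over the response type through which `join_project_internal` replies.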
1805trait JoinProjectInternalResponse {
1806 fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
1807}
1808impl JoinProjectInternalResponse for Response<proto::JoinProject> {
1809 fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
1810 Response::<proto::JoinProject>::send(self, result)
1811 }
1812}
1813
1814fn join_project_internal(
1815 response: impl JoinProjectInternalResponse,
1816 session: Session,
1817 project: &mut Project,
1818 replica_id: &ReplicaId,
1819) -> Result<()> {
1820 let collaborators = project
1821 .collaborators
1822 .iter()
1823 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1824 .map(|collaborator| collaborator.to_proto())
1825 .collect::<Vec<_>>();
1826 let project_id = project.id;
1827 let guest_user_id = session.user_id();
1828
1829 let worktrees = project
1830 .worktrees
1831 .iter()
1832 .map(|(id, worktree)| proto::WorktreeMetadata {
1833 id: *id,
1834 root_name: worktree.root_name.clone(),
1835 visible: worktree.visible,
1836 abs_path: worktree.abs_path.clone(),
1837 })
1838 .collect::<Vec<_>>();
1839
1840 let add_project_collaborator = proto::AddProjectCollaborator {
1841 project_id: project_id.to_proto(),
1842 collaborator: Some(proto::Collaborator {
1843 peer_id: Some(session.connection_id.into()),
1844 replica_id: replica_id.0 as u32,
1845 user_id: guest_user_id.to_proto(),
1846 is_host: false,
1847 }),
1848 };
1849
1850 for collaborator in &collaborators {
1851 session
1852 .peer
1853 .send(
1854 collaborator.peer_id.unwrap().into(),
1855 add_project_collaborator.clone(),
1856 )
1857 .trace_err();
1858 }
1859
1860 // First, we send the metadata associated with each worktree.
1861 response.send(proto::JoinProjectResponse {
1862 project_id: project.id.0 as u64,
1863 worktrees: worktrees.clone(),
1864 replica_id: replica_id.0 as u32,
1865 collaborators: collaborators.clone(),
1866 language_servers: project.language_servers.clone(),
1867 role: project.role.into(),
1868 })?;
1869
1870 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1871 // Stream this worktree's entries.
1872 let message = proto::UpdateWorktree {
1873 project_id: project_id.to_proto(),
1874 worktree_id,
1875 abs_path: worktree.abs_path.clone(),
1876 root_name: worktree.root_name,
1877 updated_entries: worktree.entries,
1878 removed_entries: Default::default(),
1879 scan_id: worktree.scan_id,
1880 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1881 updated_repositories: worktree.repository_entries.into_values().collect(),
1882 removed_repositories: Default::default(),
1883 };
1884 for update in proto::split_worktree_update(message) {
1885 session.peer.send(session.connection_id, update.clone())?;
1886 }
1887
1888 // Stream this worktree's diagnostics.
1889 for summary in worktree.diagnostic_summaries {
1890 session.peer.send(
1891 session.connection_id,
1892 proto::UpdateDiagnosticSummary {
1893 project_id: project_id.to_proto(),
1894 worktree_id: worktree.id,
1895 summary: Some(summary),
1896 },
1897 )?;
1898 }
1899
1900 for settings_file in worktree.settings_files {
1901 session.peer.send(
1902 session.connection_id,
1903 proto::UpdateWorktreeSettings {
1904 project_id: project_id.to_proto(),
1905 worktree_id: worktree.id,
1906 path: settings_file.path,
1907 content: Some(settings_file.content),
1908 kind: Some(settings_file.kind.to_proto() as i32),
1909 },
1910 )?;
1911 }
1912 }
1913
1914 for language_server in &project.language_servers {
1915 session.peer.send(
1916 session.connection_id,
1917 proto::UpdateLanguageServer {
1918 project_id: project_id.to_proto(),
1919 language_server_id: language_server.id,
1920 variant: Some(
1921 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1922 proto::LspDiskBasedDiagnosticsUpdated {},
1923 ),
1924 ),
1925 },
1926 )?;
1927 }
1928
1929 Ok(())
1930}
1931
/// Leave someone else's shared project.
1933async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1934 let sender_id = session.connection_id;
1935 let project_id = ProjectId::from_proto(request.project_id);
1936 let db = session.db().await;
1937
1938 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1939 tracing::info!(
1940 %project_id,
1941 "leave project"
1942 );
1943
1944 project_left(project, &session);
1945 if let Some(room) = room {
1946 room_updated(room, &session.peer);
1947 }
1948
1949 Ok(())
1950}
1951
1952/// Updates other participants with changes to the project
1953async fn update_project(
1954 request: proto::UpdateProject,
1955 response: Response<proto::UpdateProject>,
1956 session: Session,
1957) -> Result<()> {
1958 let project_id = ProjectId::from_proto(request.project_id);
1959 let (room, guest_connection_ids) = &*session
1960 .db()
1961 .await
1962 .update_project(project_id, session.connection_id, &request.worktrees)
1963 .await?;
1964 broadcast(
1965 Some(session.connection_id),
1966 guest_connection_ids.iter().copied(),
1967 |connection_id| {
1968 session
1969 .peer
1970 .forward_send(session.connection_id, connection_id, request.clone())
1971 },
1972 );
1973 if let Some(room) = room {
1974 room_updated(room, &session.peer);
1975 }
1976 response.send(proto::Ack {})?;
1977
1978 Ok(())
1979}
1980
1981/// Updates other participants with changes to the worktree
1982async fn update_worktree(
1983 request: proto::UpdateWorktree,
1984 response: Response<proto::UpdateWorktree>,
1985 session: Session,
1986) -> Result<()> {
1987 let guest_connection_ids = session
1988 .db()
1989 .await
1990 .update_worktree(&request, session.connection_id)
1991 .await?;
1992
1993 broadcast(
1994 Some(session.connection_id),
1995 guest_connection_ids.iter().copied(),
1996 |connection_id| {
1997 session
1998 .peer
1999 .forward_send(session.connection_id, connection_id, request.clone())
2000 },
2001 );
2002 response.send(proto::Ack {})?;
2003 Ok(())
2004}
2005
2006/// Updates other participants with changes to the diagnostics
2007async fn update_diagnostic_summary(
2008 message: proto::UpdateDiagnosticSummary,
2009 session: Session,
2010) -> Result<()> {
2011 let guest_connection_ids = session
2012 .db()
2013 .await
2014 .update_diagnostic_summary(&message, session.connection_id)
2015 .await?;
2016
2017 broadcast(
2018 Some(session.connection_id),
2019 guest_connection_ids.iter().copied(),
2020 |connection_id| {
2021 session
2022 .peer
2023 .forward_send(session.connection_id, connection_id, message.clone())
2024 },
2025 );
2026
2027 Ok(())
2028}
2029
2030/// Updates other participants with changes to the worktree settings
2031async fn update_worktree_settings(
2032 message: proto::UpdateWorktreeSettings,
2033 session: Session,
2034) -> Result<()> {
2035 let guest_connection_ids = session
2036 .db()
2037 .await
2038 .update_worktree_settings(&message, session.connection_id)
2039 .await?;
2040
2041 broadcast(
2042 Some(session.connection_id),
2043 guest_connection_ids.iter().copied(),
2044 |connection_id| {
2045 session
2046 .peer
2047 .forward_send(session.connection_id, connection_id, message.clone())
2048 },
2049 );
2050
2051 Ok(())
2052}
2053
2054/// Notify other participants that a language server has started.
2055async fn start_language_server(
2056 request: proto::StartLanguageServer,
2057 session: Session,
2058) -> Result<()> {
2059 let guest_connection_ids = session
2060 .db()
2061 .await
2062 .start_language_server(&request, session.connection_id)
2063 .await?;
2064
2065 broadcast(
2066 Some(session.connection_id),
2067 guest_connection_ids.iter().copied(),
2068 |connection_id| {
2069 session
2070 .peer
2071 .forward_send(session.connection_id, connection_id, request.clone())
2072 },
2073 );
2074 Ok(())
2075}
2076
2077/// Notify other participants that a language server has changed.
2078async fn update_language_server(
2079 request: proto::UpdateLanguageServer,
2080 session: Session,
2081) -> Result<()> {
2082 let project_id = ProjectId::from_proto(request.project_id);
2083 let project_connection_ids = session
2084 .db()
2085 .await
2086 .project_connection_ids(project_id, session.connection_id, true)
2087 .await?;
2088 broadcast(
2089 Some(session.connection_id),
2090 project_connection_ids.iter().copied(),
2091 |connection_id| {
2092 session
2093 .peer
2094 .forward_send(session.connection_id, connection_id, request.clone())
2095 },
2096 );
2097 Ok(())
2098}
2099
/// Forward a project request to the host. These requests should be read-only,
/// as guests are allowed to send them.
2102async fn forward_read_only_project_request<T>(
2103 request: T,
2104 response: Response<T>,
2105 session: Session,
2106) -> Result<()>
2107where
2108 T: EntityMessage + RequestMessage,
2109{
2110 let project_id = ProjectId::from_proto(request.remote_entity_id());
2111 let host_connection_id = session
2112 .db()
2113 .await
2114 .host_for_read_only_project_request(project_id, session.connection_id)
2115 .await?;
2116 let payload = session
2117 .peer
2118 .forward_request(session.connection_id, host_connection_id, request)
2119 .await?;
2120 response.send(payload)?;
2121 Ok(())
2122}
2123
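/// Forward a `FindSearchCandidates` request to the project's host, using the
/// read-only authorization path since guests are allowed to search.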
2124async fn forward_find_search_candidates_request(
2125 request: proto::FindSearchCandidates,
2126 response: Response<proto::FindSearchCandidates>,
2127 session: Session,
2128) -> Result<()> {
2129 let project_id = ProjectId::from_proto(request.remote_entity_id());
2130 let host_connection_id = session
2131 .db()
2132 .await
2133 .host_for_read_only_project_request(project_id, session.connection_id)
2134 .await?;
2135 let payload = session
2136 .peer
2137 .forward_request(session.connection_id, host_connection_id, request)
2138 .await?;
2139 response.send(payload)?;
2140 Ok(())
2141}
2142
/// Forward a project request to the host. These requests are disallowed
/// for guests.
2145async fn forward_mutating_project_request<T>(
2146 request: T,
2147 response: Response<T>,
2148 session: Session,
2149) -> Result<()>
2150where
2151 T: EntityMessage + RequestMessage,
2152{
2153 let project_id = ProjectId::from_proto(request.remote_entity_id());
2154
2155 let host_connection_id = session
2156 .db()
2157 .await
2158 .host_for_mutating_project_request(project_id, session.connection_id)
2159 .await?;
2160 let payload = session
2161 .peer
2162 .forward_request(session.connection_id, host_connection_id, request)
2163 .await?;
2164 response.send(payload)?;
2165 Ok(())
2166}
2167
2168/// Notify other participants that a new buffer has been created
2169async fn create_buffer_for_peer(
2170 request: proto::CreateBufferForPeer,
2171 session: Session,
2172) -> Result<()> {
2173 session
2174 .db()
2175 .await
2176 .check_user_is_project_host(
2177 ProjectId::from_proto(request.project_id),
2178 session.connection_id,
2179 )
2180 .await?;
2181 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2182 session
2183 .peer
2184 .forward_send(session.connection_id, peer_id.into(), request)?;
2185 Ok(())
2186}
2187
2188/// Notify other participants that a buffer has been updated. This is
2189/// allowed for guests as long as the update is limited to selections.
2190async fn update_buffer(
2191 request: proto::UpdateBuffer,
2192 response: Response<proto::UpdateBuffer>,
2193 session: Session,
2194) -> Result<()> {
2195 let project_id = ProjectId::from_proto(request.project_id);
2196 let mut capability = Capability::ReadOnly;
2197
2198 for op in request.operations.iter() {
2199 match op.variant {
2200 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2201 Some(_) => capability = Capability::ReadWrite,
2202 }
2203 }
2204
2205 let host = {
2206 let guard = session
2207 .db()
2208 .await
2209 .connections_for_buffer_update(project_id, session.connection_id, capability)
2210 .await?;
2211
2212 let (host, guests) = &*guard;
2213
2214 broadcast(
2215 Some(session.connection_id),
2216 guests.clone(),
2217 |connection_id| {
2218 session
2219 .peer
2220 .forward_send(session.connection_id, connection_id, request.clone())
2221 },
2222 );
2223
2224 *host
2225 };
2226
2227 if host != session.connection_id {
2228 session
2229 .peer
2230 .forward_request(session.connection_id, host, request.clone())
2231 .await?;
2232 }
2233
2234 response.send(proto::Ack {})?;
2235 Ok(())
2236}
2237
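/// Forward a context operation to the project's host and guests. Operations that only
/// update selections require read access; all other operations require write access.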
2238async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2239 let project_id = ProjectId::from_proto(message.project_id);
2240
2241 let operation = message.operation.as_ref().context("invalid operation")?;
2242 let capability = match operation.variant.as_ref() {
2243 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2244 if let Some(buffer_op) = buffer_op.operation.as_ref() {
2245 match buffer_op.variant {
2246 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2247 Capability::ReadOnly
2248 }
2249 _ => Capability::ReadWrite,
2250 }
2251 } else {
2252 Capability::ReadWrite
2253 }
2254 }
2255 Some(_) => Capability::ReadWrite,
2256 None => Capability::ReadOnly,
2257 };
2258
2259 let guard = session
2260 .db()
2261 .await
2262 .connections_for_buffer_update(project_id, session.connection_id, capability)
2263 .await?;
2264
2265 let (host, guests) = &*guard;
2266
2267 broadcast(
2268 Some(session.connection_id),
2269 guests.iter().chain([host]).copied(),
2270 |connection_id| {
2271 session
2272 .peer
2273 .forward_send(session.connection_id, connection_id, message.clone())
2274 },
2275 );
2276
2277 Ok(())
2278}
2279
/// Broadcast a project message from the host to the project's other participants.
2281async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2282 request: T,
2283 session: Session,
2284) -> Result<()> {
2285 let project_id = ProjectId::from_proto(request.remote_entity_id());
2286 let project_connection_ids = session
2287 .db()
2288 .await
2289 .project_connection_ids(project_id, session.connection_id, false)
2290 .await?;
2291
2292 broadcast(
2293 Some(session.connection_id),
2294 project_connection_ids.iter().copied(),
2295 |connection_id| {
2296 session
2297 .peer
2298 .forward_send(session.connection_id, connection_id, request.clone())
2299 },
2300 );
2301 Ok(())
2302}
2303
2304/// Start following another user in a call.
2305async fn follow(
2306 request: proto::Follow,
2307 response: Response<proto::Follow>,
2308 session: Session,
2309) -> Result<()> {
2310 let room_id = RoomId::from_proto(request.room_id);
2311 let project_id = request.project_id.map(ProjectId::from_proto);
2312 let leader_id = request
2313 .leader_id
2314 .ok_or_else(|| anyhow!("invalid leader id"))?
2315 .into();
2316 let follower_id = session.connection_id;
2317
2318 session
2319 .db()
2320 .await
2321 .check_room_participants(room_id, leader_id, session.connection_id)
2322 .await?;
2323
2324 let response_payload = session
2325 .peer
2326 .forward_request(session.connection_id, leader_id, request)
2327 .await?;
2328 response.send(response_payload)?;
2329
2330 if let Some(project_id) = project_id {
2331 let room = session
2332 .db()
2333 .await
2334 .follow(room_id, project_id, leader_id, follower_id)
2335 .await?;
2336 room_updated(&room, &session.peer);
2337 }
2338
2339 Ok(())
2340}
2341
2342/// Stop following another user in a call.
2343async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2344 let room_id = RoomId::from_proto(request.room_id);
2345 let project_id = request.project_id.map(ProjectId::from_proto);
2346 let leader_id = request
2347 .leader_id
2348 .ok_or_else(|| anyhow!("invalid leader id"))?
2349 .into();
2350 let follower_id = session.connection_id;
2351
2352 session
2353 .db()
2354 .await
2355 .check_room_participants(room_id, leader_id, session.connection_id)
2356 .await?;
2357
2358 session
2359 .peer
2360 .forward_send(session.connection_id, leader_id, request)?;
2361
2362 if let Some(project_id) = project_id {
2363 let room = session
2364 .db()
2365 .await
2366 .unfollow(room_id, project_id, leader_id, follower_id)
2367 .await?;
2368 room_updated(&room, &session.peer);
2369 }
2370
2371 Ok(())
2372}
2373
2374/// Notify everyone following you of your current location.
2375async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2376 let room_id = RoomId::from_proto(request.room_id);
2377 let database = session.db.lock().await;
2378
2379 let connection_ids = if let Some(project_id) = request.project_id {
2380 let project_id = ProjectId::from_proto(project_id);
2381 database
2382 .project_connection_ids(project_id, session.connection_id, true)
2383 .await?
2384 } else {
2385 database
2386 .room_connection_ids(room_id, session.connection_id)
2387 .await?
2388 };
2389
2390 // For now, don't send view update messages back to that view's current leader.
2391 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2392 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2393 _ => None,
2394 });
2395
2396 for connection_id in connection_ids.iter().cloned() {
2397 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2398 session
2399 .peer
2400 .forward_send(session.connection_id, connection_id, request.clone())?;
2401 }
2402 }
2403 Ok(())
2404}
2405
2406/// Get public data about users.
2407async fn get_users(
2408 request: proto::GetUsers,
2409 response: Response<proto::GetUsers>,
2410 session: Session,
2411) -> Result<()> {
2412 let user_ids = request
2413 .user_ids
2414 .into_iter()
2415 .map(UserId::from_proto)
2416 .collect();
2417 let users = session
2418 .db()
2419 .await
2420 .get_users_by_ids(user_ids)
2421 .await?
2422 .into_iter()
2423 .map(|user| proto::User {
2424 id: user.id.to_proto(),
2425 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2426 github_login: user.github_login,
2427 email: user.email_address,
2428 name: user.name,
2429 })
2430 .collect();
2431 response.send(proto::UsersResponse { users })?;
2432 Ok(())
2433}
2434
/// Search for users (to invite) by GitHub login.
2436async fn fuzzy_search_users(
2437 request: proto::FuzzySearchUsers,
2438 response: Response<proto::FuzzySearchUsers>,
2439 session: Session,
2440) -> Result<()> {
2441 let query = request.query;
2442 let users = match query.len() {
2443 0 => vec![],
2444 1 | 2 => session
2445 .db()
2446 .await
2447 .get_user_by_github_login(&query)
2448 .await?
2449 .into_iter()
2450 .collect(),
2451 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2452 };
2453 let users = users
2454 .into_iter()
2455 .filter(|user| user.id != session.user_id())
2456 .map(|user| proto::User {
2457 id: user.id.to_proto(),
2458 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2459 github_login: user.github_login,
2460 name: user.name,
2461 email: user.email_address,
2462 })
2463 .collect();
2464 response.send(proto::UsersResponse { users })?;
2465 Ok(())
2466}
2467
2468/// Send a contact request to another user.
2469async fn request_contact(
2470 request: proto::RequestContact,
2471 response: Response<proto::RequestContact>,
2472 session: Session,
2473) -> Result<()> {
2474 let requester_id = session.user_id();
2475 let responder_id = UserId::from_proto(request.responder_id);
2476 if requester_id == responder_id {
2477 return Err(anyhow!("cannot add yourself as a contact"))?;
2478 }
2479
2480 let notifications = session
2481 .db()
2482 .await
2483 .send_contact_request(requester_id, responder_id)
2484 .await?;
2485
2486 // Update outgoing contact requests of requester
2487 let mut update = proto::UpdateContacts::default();
2488 update.outgoing_requests.push(responder_id.to_proto());
2489 for connection_id in session
2490 .connection_pool()
2491 .await
2492 .user_connection_ids(requester_id)
2493 {
2494 session.peer.send(connection_id, update.clone())?;
2495 }
2496
2497 // Update incoming contact requests of responder
2498 let mut update = proto::UpdateContacts::default();
2499 update
2500 .incoming_requests
2501 .push(proto::IncomingContactRequest {
2502 requester_id: requester_id.to_proto(),
2503 });
2504 let connection_pool = session.connection_pool().await;
2505 for connection_id in connection_pool.user_connection_ids(responder_id) {
2506 session.peer.send(connection_id, update.clone())?;
2507 }
2508
2509 send_notifications(&connection_pool, &session.peer, notifications);
2510
2511 response.send(proto::Ack {})?;
2512 Ok(())
2513}
2514
2515/// Accept or decline a contact request
2516async fn respond_to_contact_request(
2517 request: proto::RespondToContactRequest,
2518 response: Response<proto::RespondToContactRequest>,
2519 session: Session,
2520) -> Result<()> {
2521 let responder_id = session.user_id();
2522 let requester_id = UserId::from_proto(request.requester_id);
2523 let db = session.db().await;
2524 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2525 db.dismiss_contact_notification(responder_id, requester_id)
2526 .await?;
2527 } else {
2528 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2529
2530 let notifications = db
2531 .respond_to_contact_request(responder_id, requester_id, accept)
2532 .await?;
2533 let requester_busy = db.is_user_busy(requester_id).await?;
2534 let responder_busy = db.is_user_busy(responder_id).await?;
2535
2536 let pool = session.connection_pool().await;
2537 // Update responder with new contact
2538 let mut update = proto::UpdateContacts::default();
2539 if accept {
2540 update
2541 .contacts
2542 .push(contact_for_user(requester_id, requester_busy, &pool));
2543 }
2544 update
2545 .remove_incoming_requests
2546 .push(requester_id.to_proto());
2547 for connection_id in pool.user_connection_ids(responder_id) {
2548 session.peer.send(connection_id, update.clone())?;
2549 }
2550
2551 // Update requester with new contact
2552 let mut update = proto::UpdateContacts::default();
2553 if accept {
2554 update
2555 .contacts
2556 .push(contact_for_user(responder_id, responder_busy, &pool));
2557 }
2558 update
2559 .remove_outgoing_requests
2560 .push(responder_id.to_proto());
2561
2562 for connection_id in pool.user_connection_ids(requester_id) {
2563 session.peer.send(connection_id, update.clone())?;
2564 }
2565
2566 send_notifications(&pool, &session.peer, notifications);
2567 }
2568
2569 response.send(proto::Ack {})?;
2570 Ok(())
2571}
2572
2573/// Remove a contact.
2574async fn remove_contact(
2575 request: proto::RemoveContact,
2576 response: Response<proto::RemoveContact>,
2577 session: Session,
2578) -> Result<()> {
2579 let requester_id = session.user_id();
2580 let responder_id = UserId::from_proto(request.user_id);
2581 let db = session.db().await;
2582 let (contact_accepted, deleted_notification_id) =
2583 db.remove_contact(requester_id, responder_id).await?;
2584
2585 let pool = session.connection_pool().await;
2586 // Update outgoing contact requests of requester
2587 let mut update = proto::UpdateContacts::default();
2588 if contact_accepted {
2589 update.remove_contacts.push(responder_id.to_proto());
2590 } else {
2591 update
2592 .remove_outgoing_requests
2593 .push(responder_id.to_proto());
2594 }
2595 for connection_id in pool.user_connection_ids(requester_id) {
2596 session.peer.send(connection_id, update.clone())?;
2597 }
2598
2599 // Update incoming contact requests of responder
2600 let mut update = proto::UpdateContacts::default();
2601 if contact_accepted {
2602 update.remove_contacts.push(requester_id.to_proto());
2603 } else {
2604 update
2605 .remove_incoming_requests
2606 .push(requester_id.to_proto());
2607 }
2608 for connection_id in pool.user_connection_ids(responder_id) {
2609 session.peer.send(connection_id, update.clone())?;
2610 if let Some(notification_id) = deleted_notification_id {
2611 session.peer.send(
2612 connection_id,
2613 proto::DeleteNotification {
2614 notification_id: notification_id.to_proto(),
2615 },
2616 )?;
2617 }
2618 }
2619
2620 response.send(proto::Ack {})?;
2621 Ok(())
2622}
2623
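/// Whether the server should automatically subscribe this client to its channels
/// (true for clients older than 0.139).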
2624fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2625 version.0.minor() < 139
2626}
2627
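/// Send the user's current plan to this connection.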
2628async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
2629 let plan = session.current_plan(&session.db().await).await?;
2630
2631 session
2632 .peer
2633 .send(
2634 session.connection_id,
2635 proto::UpdateUserPlan { plan: plan.into() },
2636 )
2637 .trace_err();
2638
2639 Ok(())
2640}
2641
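/// Subscribe the requesting user to updates for all of their channels.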
2642async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2643 subscribe_user_to_channels(session.user_id(), &session).await?;
2644 Ok(())
2645}
2646
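/// Record the user's channel memberships in the connection pool and send the initial
/// channel state to this connection.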
2647async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2648 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2649 let mut pool = session.connection_pool().await;
2650 for membership in &channels_for_user.channel_memberships {
2651 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2652 }
2653 session.peer.send(
2654 session.connection_id,
2655 build_update_user_channels(&channels_for_user),
2656 )?;
2657 session.peer.send(
2658 session.connection_id,
2659 build_channels_update(channels_for_user),
2660 )?;
2661 Ok(())
2662}
2663
2664/// Creates a new channel.
2665async fn create_channel(
2666 request: proto::CreateChannel,
2667 response: Response<proto::CreateChannel>,
2668 session: Session,
2669) -> Result<()> {
2670 let db = session.db().await;
2671
2672 let parent_id = request.parent_id.map(ChannelId::from_proto);
2673 let (channel, membership) = db
2674 .create_channel(&request.name, parent_id, session.user_id())
2675 .await?;
2676
2677 let root_id = channel.root_id();
2678 let channel = Channel::from_model(channel);
2679
2680 response.send(proto::CreateChannelResponse {
2681 channel: Some(channel.to_proto()),
2682 parent_id: request.parent_id,
2683 })?;
2684
2685 let mut connection_pool = session.connection_pool().await;
2686 if let Some(membership) = membership {
2687 connection_pool.subscribe_to_channel(
2688 membership.user_id,
2689 membership.channel_id,
2690 membership.role,
2691 );
2692 let update = proto::UpdateUserChannels {
2693 channel_memberships: vec![proto::ChannelMembership {
2694 channel_id: membership.channel_id.to_proto(),
2695 role: membership.role.into(),
2696 }],
2697 ..Default::default()
2698 };
2699 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2700 session.peer.send(connection_id, update.clone())?;
2701 }
2702 }
2703
2704 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2705 if !role.can_see_channel(channel.visibility) {
2706 continue;
2707 }
2708
2709 let update = proto::UpdateChannels {
2710 channels: vec![channel.to_proto()],
2711 ..Default::default()
2712 };
2713 session.peer.send(connection_id, update.clone())?;
2714 }
2715
2716 Ok(())
2717}
2718
2719/// Delete a channel
2720async fn delete_channel(
2721 request: proto::DeleteChannel,
2722 response: Response<proto::DeleteChannel>,
2723 session: Session,
2724) -> Result<()> {
2725 let db = session.db().await;
2726
2727 let channel_id = request.channel_id;
2728 let (root_channel, removed_channels) = db
2729 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2730 .await?;
2731 response.send(proto::Ack {})?;
2732
2733 // Notify members of removed channels
2734 let mut update = proto::UpdateChannels::default();
2735 update
2736 .delete_channels
2737 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2738
2739 let connection_pool = session.connection_pool().await;
2740 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2741 session.peer.send(connection_id, update.clone())?;
2742 }
2743
2744 Ok(())
2745}
2746
2747/// Invite someone to join a channel.
2748async fn invite_channel_member(
2749 request: proto::InviteChannelMember,
2750 response: Response<proto::InviteChannelMember>,
2751 session: Session,
2752) -> Result<()> {
2753 let db = session.db().await;
2754 let channel_id = ChannelId::from_proto(request.channel_id);
2755 let invitee_id = UserId::from_proto(request.user_id);
2756 let InviteMemberResult {
2757 channel,
2758 notifications,
2759 } = db
2760 .invite_channel_member(
2761 channel_id,
2762 invitee_id,
2763 session.user_id(),
2764 request.role().into(),
2765 )
2766 .await?;
2767
2768 let update = proto::UpdateChannels {
2769 channel_invitations: vec![channel.to_proto()],
2770 ..Default::default()
2771 };
2772
2773 let connection_pool = session.connection_pool().await;
2774 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2775 session.peer.send(connection_id, update.clone())?;
2776 }
2777
2778 send_notifications(&connection_pool, &session.peer, notifications);
2779
2780 response.send(proto::Ack {})?;
2781 Ok(())
2782}
2783
/// Remove someone from a channel.
2785async fn remove_channel_member(
2786 request: proto::RemoveChannelMember,
2787 response: Response<proto::RemoveChannelMember>,
2788 session: Session,
2789) -> Result<()> {
2790 let db = session.db().await;
2791 let channel_id = ChannelId::from_proto(request.channel_id);
2792 let member_id = UserId::from_proto(request.user_id);
2793
2794 let RemoveChannelMemberResult {
2795 membership_update,
2796 notification_id,
2797 } = db
2798 .remove_channel_member(channel_id, member_id, session.user_id())
2799 .await?;
2800
2801 let mut connection_pool = session.connection_pool().await;
2802 notify_membership_updated(
2803 &mut connection_pool,
2804 membership_update,
2805 member_id,
2806 &session.peer,
2807 );
2808 for connection_id in connection_pool.user_connection_ids(member_id) {
2809 if let Some(notification_id) = notification_id {
2810 session
2811 .peer
2812 .send(
2813 connection_id,
2814 proto::DeleteNotification {
2815 notification_id: notification_id.to_proto(),
2816 },
2817 )
2818 .trace_err();
2819 }
2820 }
2821
2822 response.send(proto::Ack {})?;
2823 Ok(())
2824}
2825
2826/// Toggle the channel between public and private.
/// Care is taken to maintain the invariant that public channels only descend from public channels
/// (though members-only channels can appear at any point in the hierarchy).
2829async fn set_channel_visibility(
2830 request: proto::SetChannelVisibility,
2831 response: Response<proto::SetChannelVisibility>,
2832 session: Session,
2833) -> Result<()> {
2834 let db = session.db().await;
2835 let channel_id = ChannelId::from_proto(request.channel_id);
2836 let visibility = request.visibility().into();
2837
2838 let channel_model = db
2839 .set_channel_visibility(channel_id, visibility, session.user_id())
2840 .await?;
2841 let root_id = channel_model.root_id();
2842 let channel = Channel::from_model(channel_model);
2843
2844 let mut connection_pool = session.connection_pool().await;
2845 for (user_id, role) in connection_pool
2846 .channel_user_ids(root_id)
2847 .collect::<Vec<_>>()
2848 .into_iter()
2849 {
2850 let update = if role.can_see_channel(channel.visibility) {
2851 connection_pool.subscribe_to_channel(user_id, channel_id, role);
2852 proto::UpdateChannels {
2853 channels: vec![channel.to_proto()],
2854 ..Default::default()
2855 }
2856 } else {
2857 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2858 proto::UpdateChannels {
2859 delete_channels: vec![channel.id.to_proto()],
2860 ..Default::default()
2861 }
2862 };
2863
2864 for connection_id in connection_pool.user_connection_ids(user_id) {
2865 session.peer.send(connection_id, update.clone())?;
2866 }
2867 }
2868
2869 response.send(proto::Ack {})?;
2870 Ok(())
2871}
2872
2873/// Alter the role for a user in the channel.
2874async fn set_channel_member_role(
2875 request: proto::SetChannelMemberRole,
2876 response: Response<proto::SetChannelMemberRole>,
2877 session: Session,
2878) -> Result<()> {
2879 let db = session.db().await;
2880 let channel_id = ChannelId::from_proto(request.channel_id);
2881 let member_id = UserId::from_proto(request.user_id);
2882 let result = db
2883 .set_channel_member_role(
2884 channel_id,
2885 session.user_id(),
2886 member_id,
2887 request.role().into(),
2888 )
2889 .await?;
2890
2891 match result {
2892 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2893 let mut connection_pool = session.connection_pool().await;
2894 notify_membership_updated(
2895 &mut connection_pool,
2896 membership_update,
2897 member_id,
2898 &session.peer,
2899 )
2900 }
2901 db::SetMemberRoleResult::InviteUpdated(channel) => {
2902 let update = proto::UpdateChannels {
2903 channel_invitations: vec![channel.to_proto()],
2904 ..Default::default()
2905 };
2906
2907 for connection_id in session
2908 .connection_pool()
2909 .await
2910 .user_connection_ids(member_id)
2911 {
2912 session.peer.send(connection_id, update.clone())?;
2913 }
2914 }
2915 }
2916
2917 response.send(proto::Ack {})?;
2918 Ok(())
2919}
2920
2921/// Change the name of a channel
2922async fn rename_channel(
2923 request: proto::RenameChannel,
2924 response: Response<proto::RenameChannel>,
2925 session: Session,
2926) -> Result<()> {
2927 let db = session.db().await;
2928 let channel_id = ChannelId::from_proto(request.channel_id);
2929 let channel_model = db
2930 .rename_channel(channel_id, session.user_id(), &request.name)
2931 .await?;
2932 let root_id = channel_model.root_id();
2933 let channel = Channel::from_model(channel_model);
2934
2935 response.send(proto::RenameChannelResponse {
2936 channel: Some(channel.to_proto()),
2937 })?;
2938
2939 let connection_pool = session.connection_pool().await;
2940 let update = proto::UpdateChannels {
2941 channels: vec![channel.to_proto()],
2942 ..Default::default()
2943 };
2944 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2945 if role.can_see_channel(channel.visibility) {
2946 session.peer.send(connection_id, update.clone())?;
2947 }
2948 }
2949
2950 Ok(())
2951}
2952
2953/// Move a channel to a new parent.
2954async fn move_channel(
2955 request: proto::MoveChannel,
2956 response: Response<proto::MoveChannel>,
2957 session: Session,
2958) -> Result<()> {
2959 let channel_id = ChannelId::from_proto(request.channel_id);
2960 let to = ChannelId::from_proto(request.to);
2961
2962 let (root_id, channels) = session
2963 .db()
2964 .await
2965 .move_channel(channel_id, to, session.user_id())
2966 .await?;
2967
2968 let connection_pool = session.connection_pool().await;
2969 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2970 let channels = channels
2971 .iter()
2972 .filter_map(|channel| {
2973 if role.can_see_channel(channel.visibility) {
2974 Some(channel.to_proto())
2975 } else {
2976 None
2977 }
2978 })
2979 .collect::<Vec<_>>();
2980 if channels.is_empty() {
2981 continue;
2982 }
2983
2984 let update = proto::UpdateChannels {
2985 channels,
2986 ..Default::default()
2987 };
2988
2989 session.peer.send(connection_id, update.clone())?;
2990 }
2991
2992 response.send(Ack {})?;
2993 Ok(())
2994}
2995
2996/// Get the list of channel members
2997async fn get_channel_members(
2998 request: proto::GetChannelMembers,
2999 response: Response<proto::GetChannelMembers>,
3000 session: Session,
3001) -> Result<()> {
3002 let db = session.db().await;
3003 let channel_id = ChannelId::from_proto(request.channel_id);
3004 let limit = if request.limit == 0 {
3005 u16::MAX as u64
3006 } else {
3007 request.limit
3008 };
3009 let (members, users) = db
3010 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3011 .await?;
3012 response.send(proto::GetChannelMembersResponse { members, users })?;
3013 Ok(())
3014}
3015
3016/// Accept or decline a channel invitation.
3017async fn respond_to_channel_invite(
3018 request: proto::RespondToChannelInvite,
3019 response: Response<proto::RespondToChannelInvite>,
3020 session: Session,
3021) -> Result<()> {
3022 let db = session.db().await;
3023 let channel_id = ChannelId::from_proto(request.channel_id);
3024 let RespondToChannelInvite {
3025 membership_update,
3026 notifications,
3027 } = db
3028 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3029 .await?;
3030
3031 let mut connection_pool = session.connection_pool().await;
3032 if let Some(membership_update) = membership_update {
3033 notify_membership_updated(
3034 &mut connection_pool,
3035 membership_update,
3036 session.user_id(),
3037 &session.peer,
3038 );
3039 } else {
3040 let update = proto::UpdateChannels {
3041 remove_channel_invitations: vec![channel_id.to_proto()],
3042 ..Default::default()
3043 };
3044
3045 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3046 session.peer.send(connection_id, update.clone())?;
3047 }
3048 };
3049
3050 send_notifications(&connection_pool, &session.peer, notifications);
3051
3052 response.send(proto::Ack {})?;
3053
3054 Ok(())
3055}
3056
/// Join the channel's room.
3058async fn join_channel(
3059 request: proto::JoinChannel,
3060 response: Response<proto::JoinChannel>,
3061 session: Session,
3062) -> Result<()> {
3063 let channel_id = ChannelId::from_proto(request.channel_id);
3064 join_channel_internal(channel_id, Box::new(response), session).await
3065}
3066
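/// Abstracts over the response type so that `join_channel_internal` can answer both
/// `JoinChannel` and `JoinRoom` requests.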
3067trait JoinChannelInternalResponse {
3068 fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
3069}
3070impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
3071 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3072 Response::<proto::JoinChannel>::send(self, result)
3073 }
3074}
3075impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
3076 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3077 Response::<proto::JoinRoom>::send(self, result)
3078 }
3079}
3080
3081async fn join_channel_internal(
3082 channel_id: ChannelId,
3083 response: Box<impl JoinChannelInternalResponse>,
3084 session: Session,
3085) -> Result<()> {
3086 let joined_room = {
3087 let mut db = session.db().await;
        // If Zed quits without leaving the room and the user re-opens Zed before
        // RECONNECT_TIMEOUT expires, make sure to kick the user out of the room they
        // were previously in.
3091 if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
3092 tracing::info!(
3093 stale_connection_id = %connection,
3094 "cleaning up stale connection",
3095 );
3096 drop(db);
3097 leave_room_for_session(&session, connection).await?;
3098 db = session.db().await;
3099 }
3100
3101 let (joined_room, membership_updated, role) = db
3102 .join_channel(channel_id, session.user_id(), session.connection_id)
3103 .await?;
3104
3105 let live_kit_connection_info =
3106 session
3107 .app_state
3108 .livekit_client
3109 .as_ref()
3110 .and_then(|live_kit| {
3111 let (can_publish, token) = if role == ChannelRole::Guest {
3112 (
3113 false,
3114 live_kit
3115 .guest_token(
3116 &joined_room.room.livekit_room,
3117 &session.user_id().to_string(),
3118 )
3119 .trace_err()?,
3120 )
3121 } else {
3122 (
3123 true,
3124 live_kit
3125 .room_token(
3126 &joined_room.room.livekit_room,
3127 &session.user_id().to_string(),
3128 )
3129 .trace_err()?,
3130 )
3131 };
3132
3133 Some(LiveKitConnectionInfo {
3134 server_url: live_kit.url().into(),
3135 token,
3136 can_publish,
3137 })
3138 });
3139
3140 response.send(proto::JoinRoomResponse {
3141 room: Some(joined_room.room.clone()),
3142 channel_id: joined_room
3143 .channel
3144 .as_ref()
3145 .map(|channel| channel.id.to_proto()),
3146 live_kit_connection_info,
3147 })?;
3148
3149 let mut connection_pool = session.connection_pool().await;
3150 if let Some(membership_updated) = membership_updated {
3151 notify_membership_updated(
3152 &mut connection_pool,
3153 membership_updated,
3154 session.user_id(),
3155 &session.peer,
3156 );
3157 }
3158
3159 room_updated(&joined_room.room, &session.peer);
3160
3161 joined_room
3162 };
3163
3164 channel_updated(
3165 &joined_room
3166 .channel
3167 .ok_or_else(|| anyhow!("channel not returned"))?,
3168 &joined_room.room,
3169 &session.peer,
3170 &*session.connection_pool().await,
3171 );
3172
3173 update_user_contacts(session.user_id(), &session).await?;
3174 Ok(())
3175}
3176
3177/// Start editing the channel notes
3178async fn join_channel_buffer(
3179 request: proto::JoinChannelBuffer,
3180 response: Response<proto::JoinChannelBuffer>,
3181 session: Session,
3182) -> Result<()> {
3183 let db = session.db().await;
3184 let channel_id = ChannelId::from_proto(request.channel_id);
3185
3186 let open_response = db
3187 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3188 .await?;
3189
3190 let collaborators = open_response.collaborators.clone();
3191 response.send(open_response)?;
3192
3193 let update = UpdateChannelBufferCollaborators {
3194 channel_id: channel_id.to_proto(),
3195 collaborators: collaborators.clone(),
3196 };
3197 channel_buffer_updated(
3198 session.connection_id,
3199 collaborators
3200 .iter()
3201 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3202 &update,
3203 &session.peer,
3204 );
3205
3206 Ok(())
3207}
3208
3209/// Edit the channel notes
3210async fn update_channel_buffer(
3211 request: proto::UpdateChannelBuffer,
3212 session: Session,
3213) -> Result<()> {
3214 let db = session.db().await;
3215 let channel_id = ChannelId::from_proto(request.channel_id);
3216
3217 let (collaborators, epoch, version) = db
3218 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3219 .await?;
3220
3221 channel_buffer_updated(
3222 session.connection_id,
3223 collaborators.clone(),
3224 &proto::UpdateChannelBuffer {
3225 channel_id: channel_id.to_proto(),
3226 operations: request.operations,
3227 },
3228 &session.peer,
3229 );
3230
3231 let pool = &*session.connection_pool().await;
3232
3233 let non_collaborators =
3234 pool.channel_connection_ids(channel_id)
3235 .filter_map(|(connection_id, _)| {
3236 if collaborators.contains(&connection_id) {
3237 None
3238 } else {
3239 Some(connection_id)
3240 }
3241 });
3242
3243 broadcast(None, non_collaborators, |peer_id| {
3244 session.peer.send(
3245 peer_id,
3246 proto::UpdateChannels {
3247 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3248 channel_id: channel_id.to_proto(),
3249 epoch: epoch as u64,
3250 version: version.clone(),
3251 }],
3252 ..Default::default()
3253 },
3254 )
3255 });
3256
3257 Ok(())
3258}
3259
3260/// Rejoin the channel notes after a connection blip
3261async fn rejoin_channel_buffers(
3262 request: proto::RejoinChannelBuffers,
3263 response: Response<proto::RejoinChannelBuffers>,
3264 session: Session,
3265) -> Result<()> {
3266 let db = session.db().await;
3267 let buffers = db
3268 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3269 .await?;
3270
3271 for rejoined_buffer in &buffers {
3272 let collaborators_to_notify = rejoined_buffer
3273 .buffer
3274 .collaborators
3275 .iter()
3276 .filter_map(|c| Some(c.peer_id?.into()));
3277 channel_buffer_updated(
3278 session.connection_id,
3279 collaborators_to_notify,
3280 &proto::UpdateChannelBufferCollaborators {
3281 channel_id: rejoined_buffer.buffer.channel_id,
3282 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3283 },
3284 &session.peer,
3285 );
3286 }
3287
3288 response.send(proto::RejoinChannelBuffersResponse {
3289 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3290 })?;
3291
3292 Ok(())
3293}
3294
3295/// Stop editing the channel notes
3296async fn leave_channel_buffer(
3297 request: proto::LeaveChannelBuffer,
3298 response: Response<proto::LeaveChannelBuffer>,
3299 session: Session,
3300) -> Result<()> {
3301 let db = session.db().await;
3302 let channel_id = ChannelId::from_proto(request.channel_id);
3303
3304 let left_buffer = db
3305 .leave_channel_buffer(channel_id, session.connection_id)
3306 .await?;
3307
3308 response.send(Ack {})?;
3309
3310 channel_buffer_updated(
3311 session.connection_id,
3312 left_buffer.connections,
3313 &proto::UpdateChannelBufferCollaborators {
3314 channel_id: channel_id.to_proto(),
3315 collaborators: left_buffer.collaborators,
3316 },
3317 &session.peer,
3318 );
3319
3320 Ok(())
3321}
3322
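/// Broadcast a channel buffer message to every collaborator except the sender.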
3323fn channel_buffer_updated<T: EnvelopedMessage>(
3324 sender_id: ConnectionId,
3325 collaborators: impl IntoIterator<Item = ConnectionId>,
3326 message: &T,
3327 peer: &Peer,
3328) {
3329 broadcast(Some(sender_id), collaborators, |peer_id| {
3330 peer.send(peer_id, message.clone())
3331 });
3332}
3333
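/// Deliver a batch of notifications to every connection of each recipient, logging any
/// send failures.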
3334fn send_notifications(
3335 connection_pool: &ConnectionPool,
3336 peer: &Peer,
3337 notifications: db::NotificationBatch,
3338) {
3339 for (user_id, notification) in notifications {
3340 for connection_id in connection_pool.user_connection_ids(user_id) {
3341 if let Err(error) = peer.send(
3342 connection_id,
3343 proto::AddNotification {
3344 notification: Some(notification.clone()),
3345 },
3346 ) {
3347 tracing::error!(
3348 "failed to send notification to {:?} {}",
3349 connection_id,
3350 error
3351 );
3352 }
3353 }
3354 }
3355}
3356
3357/// Send a message to the channel
3358async fn send_channel_message(
3359 request: proto::SendChannelMessage,
3360 response: Response<proto::SendChannelMessage>,
3361 session: Session,
3362) -> Result<()> {
3363 // Validate the message body.
3364 let body = request.body.trim().to_string();
3365 if body.len() > MAX_MESSAGE_LEN {
3366 return Err(anyhow!("message is too long"))?;
3367 }
3368 if body.is_empty() {
3369 return Err(anyhow!("message can't be blank"))?;
3370 }
3371
3372 // TODO: adjust mentions if body is trimmed
3373
3374 let timestamp = OffsetDateTime::now_utc();
3375 let nonce = request
3376 .nonce
3377 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3378
3379 let channel_id = ChannelId::from_proto(request.channel_id);
3380 let CreatedChannelMessage {
3381 message_id,
3382 participant_connection_ids,
3383 notifications,
3384 } = session
3385 .db()
3386 .await
3387 .create_channel_message(
3388 channel_id,
3389 session.user_id(),
3390 &body,
3391 &request.mentions,
3392 timestamp,
3393 nonce.clone().into(),
3394 request.reply_to_message_id.map(MessageId::from_proto),
3395 )
3396 .await?;
3397
3398 let message = proto::ChannelMessage {
3399 sender_id: session.user_id().to_proto(),
3400 id: message_id.to_proto(),
3401 body,
3402 mentions: request.mentions,
3403 timestamp: timestamp.unix_timestamp() as u64,
3404 nonce: Some(nonce),
3405 reply_to_message_id: request.reply_to_message_id,
3406 edited_at: None,
3407 };
3408 broadcast(
3409 Some(session.connection_id),
3410 participant_connection_ids.clone(),
3411 |connection| {
3412 session.peer.send(
3413 connection,
3414 proto::ChannelMessageSent {
3415 channel_id: channel_id.to_proto(),
3416 message: Some(message.clone()),
3417 },
3418 )
3419 },
3420 );
3421 response.send(proto::SendChannelMessageResponse {
3422 message: Some(message),
3423 })?;
3424
3425 let pool = &*session.connection_pool().await;
3426 let non_participants =
3427 pool.channel_connection_ids(channel_id)
3428 .filter_map(|(connection_id, _)| {
3429 if participant_connection_ids.contains(&connection_id) {
3430 None
3431 } else {
3432 Some(connection_id)
3433 }
3434 });
3435 broadcast(None, non_participants, |peer_id| {
3436 session.peer.send(
3437 peer_id,
3438 proto::UpdateChannels {
3439 latest_channel_message_ids: vec![proto::ChannelMessageId {
3440 channel_id: channel_id.to_proto(),
3441 message_id: message_id.to_proto(),
3442 }],
3443 ..Default::default()
3444 },
3445 )
3446 });
3447 send_notifications(pool, &session.peer, notifications);
3448
3449 Ok(())
3450}
3451
3452/// Delete a channel message
3453async fn remove_channel_message(
3454 request: proto::RemoveChannelMessage,
3455 response: Response<proto::RemoveChannelMessage>,
3456 session: Session,
3457) -> Result<()> {
3458 let channel_id = ChannelId::from_proto(request.channel_id);
3459 let message_id = MessageId::from_proto(request.message_id);
3460 let (connection_ids, existing_notification_ids) = session
3461 .db()
3462 .await
3463 .remove_channel_message(channel_id, message_id, session.user_id())
3464 .await?;
3465
3466 broadcast(
3467 Some(session.connection_id),
3468 connection_ids,
3469 move |connection| {
3470 session.peer.send(connection, request.clone())?;
3471
3472 for notification_id in &existing_notification_ids {
3473 session.peer.send(
3474 connection,
3475 proto::DeleteNotification {
3476 notification_id: (*notification_id).to_proto(),
3477 },
3478 )?;
3479 }
3480
3481 Ok(())
3482 },
3483 );
3484 response.send(proto::Ack {})?;
3485 Ok(())
3486}
3487
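/// Edit a previously sent channel message, updating or deleting any mention
/// notifications it produced.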
3488async fn update_channel_message(
3489 request: proto::UpdateChannelMessage,
3490 response: Response<proto::UpdateChannelMessage>,
3491 session: Session,
3492) -> Result<()> {
3493 let channel_id = ChannelId::from_proto(request.channel_id);
3494 let message_id = MessageId::from_proto(request.message_id);
3495 let updated_at = OffsetDateTime::now_utc();
3496 let UpdatedChannelMessage {
3497 message_id,
3498 participant_connection_ids,
3499 notifications,
3500 reply_to_message_id,
3501 timestamp,
3502 deleted_mention_notification_ids,
3503 updated_mention_notifications,
3504 } = session
3505 .db()
3506 .await
3507 .update_channel_message(
3508 channel_id,
3509 message_id,
3510 session.user_id(),
3511 request.body.as_str(),
3512 &request.mentions,
3513 updated_at,
3514 )
3515 .await?;
3516
3517 let nonce = request
3518 .nonce
3519 .clone()
3520 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3521
3522 let message = proto::ChannelMessage {
3523 sender_id: session.user_id().to_proto(),
3524 id: message_id.to_proto(),
3525 body: request.body.clone(),
3526 mentions: request.mentions.clone(),
3527 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3528 nonce: Some(nonce),
3529 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3530 edited_at: Some(updated_at.unix_timestamp() as u64),
3531 };
3532
3533 response.send(proto::Ack {})?;
3534
3535 let pool = &*session.connection_pool().await;
3536 broadcast(
3537 Some(session.connection_id),
3538 participant_connection_ids,
3539 |connection| {
3540 session.peer.send(
3541 connection,
3542 proto::ChannelMessageUpdate {
3543 channel_id: channel_id.to_proto(),
3544 message: Some(message.clone()),
3545 },
3546 )?;
3547
3548 for notification_id in &deleted_mention_notification_ids {
3549 session.peer.send(
3550 connection,
3551 proto::DeleteNotification {
3552 notification_id: (*notification_id).to_proto(),
3553 },
3554 )?;
3555 }
3556
3557 for notification in &updated_mention_notifications {
3558 session.peer.send(
3559 connection,
3560 proto::UpdateNotification {
3561 notification: Some(notification.clone()),
3562 },
3563 )?;
3564 }
3565
3566 Ok(())
3567 },
3568 );
3569
3570 send_notifications(pool, &session.peer, notifications);
3571
3572 Ok(())
3573}
3574
3575/// Mark a channel message as read
3576async fn acknowledge_channel_message(
3577 request: proto::AckChannelMessage,
3578 session: Session,
3579) -> Result<()> {
3580 let channel_id = ChannelId::from_proto(request.channel_id);
3581 let message_id = MessageId::from_proto(request.message_id);
3582 let notifications = session
3583 .db()
3584 .await
3585 .observe_channel_message(channel_id, session.user_id(), message_id)
3586 .await?;
3587 send_notifications(
3588 &*session.connection_pool().await,
3589 &session.peer,
3590 notifications,
3591 );
3592 Ok(())
3593}
3594
3595/// Mark a buffer version as synced
3596async fn acknowledge_buffer_version(
3597 request: proto::AckBufferOperation,
3598 session: Session,
3599) -> Result<()> {
3600 let buffer_id = BufferId::from_proto(request.buffer_id);
3601 session
3602 .db()
3603 .await
3604 .observe_buffer_version(
3605 buffer_id,
3606 session.user_id(),
3607 request.epoch as i32,
3608 &request.version,
3609 )
3610 .await?;
3611 Ok(())
3612}
3613
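/// Count the tokens in a language model request on the client's behalf. This is a
/// legacy, staff-only endpoint; only the Google provider is supported here.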
3614async fn count_language_model_tokens(
3615 request: proto::CountLanguageModelTokens,
3616 response: Response<proto::CountLanguageModelTokens>,
3617 session: Session,
3618 config: &Config,
3619) -> Result<()> {
3620 authorize_access_to_legacy_llm_endpoints(&session).await?;
3621
3622 let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3623 proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
3624 proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
3625 };
3626
3627 session
3628 .app_state
3629 .rate_limiter
3630 .check(&*rate_limit, session.user_id())
3631 .await?;
3632
3633 let result = match proto::LanguageModelProvider::from_i32(request.provider) {
3634 Some(proto::LanguageModelProvider::Google) => {
3635 let api_key = config
3636 .google_ai_api_key
3637 .as_ref()
3638 .context("no Google AI API key configured on the server")?;
3639 google_ai::count_tokens(
3640 session.http_client.as_ref(),
3641 google_ai::API_URL,
3642 api_key,
3643 serde_json::from_str(&request.request)?,
3644 )
3645 .await?
3646 }
3647 _ => return Err(anyhow!("unsupported provider"))?,
3648 };
3649
3650 response.send(proto::CountLanguageModelTokensResponse {
3651 token_count: result.total_tokens as u32,
3652 })?;
3653
3654 Ok(())
3655}
3656
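/// Hourly rate limit for `CountLanguageModelTokens` on the Zed Pro plan, overridable
/// via the `COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR` environment variable.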
3657struct ZedProCountLanguageModelTokensRateLimit;
3658
3659impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
3660 fn capacity(&self) -> usize {
3661 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
3662 .ok()
3663 .and_then(|v| v.parse().ok())
3664 .unwrap_or(600) // Picked arbitrarily
3665 }
3666
3667 fn refill_duration(&self) -> chrono::Duration {
3668 chrono::Duration::hours(1)
3669 }
3670
3671 fn db_name(&self) -> &'static str {
3672 "zed-pro:count-language-model-tokens"
3673 }
3674}
3675
3676struct FreeCountLanguageModelTokensRateLimit;
3677
3678impl RateLimit for FreeCountLanguageModelTokensRateLimit {
3679 fn capacity(&self) -> usize {
3680 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
3681 .ok()
3682 .and_then(|v| v.parse().ok())
3683 .unwrap_or(600 / 10) // Picked arbitrarily
3684 }
3685
3686 fn refill_duration(&self) -> chrono::Duration {
3687 chrono::Duration::hours(1)
3688 }
3689
3690 fn db_name(&self) -> &'static str {
3691 "free:count-language-model-tokens"
3692 }
3693}
3694
3695struct ZedProComputeEmbeddingsRateLimit;
3696
3697impl RateLimit for ZedProComputeEmbeddingsRateLimit {
3698 fn capacity(&self) -> usize {
3699 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
3700 .ok()
3701 .and_then(|v| v.parse().ok())
3702 .unwrap_or(5000) // Picked arbitrarily
3703 }
3704
3705 fn refill_duration(&self) -> chrono::Duration {
3706 chrono::Duration::hours(1)
3707 }
3708
3709 fn db_name(&self) -> &'static str {
3710 "zed-pro:compute-embeddings"
3711 }
3712}
3713
3714struct FreeComputeEmbeddingsRateLimit;
3715
3716impl RateLimit for FreeComputeEmbeddingsRateLimit {
3717 fn capacity(&self) -> usize {
3718 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
3719 .ok()
3720 .and_then(|v| v.parse().ok())
3721 .unwrap_or(5000 / 10) // Picked arbitrarily
3722 }
3723
3724 fn refill_duration(&self) -> chrono::Duration {
3725 chrono::Duration::hours(1)
3726 }
3727
3728 fn db_name(&self) -> &'static str {
3729 "free:compute-embeddings"
3730 }
3731}
3732
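/// Compute embeddings for the given texts via OpenAI, cache them in the database keyed
/// by each text's SHA-256 digest, and return them to the client. Legacy, staff-only
/// endpoint.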
3733async fn compute_embeddings(
3734 request: proto::ComputeEmbeddings,
3735 response: Response<proto::ComputeEmbeddings>,
3736 session: Session,
3737 api_key: Option<Arc<str>>,
3738) -> Result<()> {
3739 let api_key = api_key.context("no OpenAI API key configured on the server")?;
3740 authorize_access_to_legacy_llm_endpoints(&session).await?;
3741
3742 let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3743 proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
3744 proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
3745 };
3746
3747 session
3748 .app_state
3749 .rate_limiter
3750 .check(&*rate_limit, session.user_id())
3751 .await?;
3752
3753 let embeddings = match request.model.as_str() {
3754 "openai/text-embedding-3-small" => {
3755 open_ai::embed(
3756 session.http_client.as_ref(),
3757 OPEN_AI_API_URL,
3758 &api_key,
3759 OpenAiEmbeddingModel::TextEmbedding3Small,
3760 request.texts.iter().map(|text| text.as_str()),
3761 )
3762 .await?
3763 }
3764 provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
3765 };
3766
3767 let embeddings = request
3768 .texts
3769 .iter()
3770 .map(|text| {
3771 let mut hasher = sha2::Sha256::new();
3772 hasher.update(text.as_bytes());
3773 let result = hasher.finalize();
3774 result.to_vec()
3775 })
3776 .zip(
3777 embeddings
3778 .data
3779 .into_iter()
3780 .map(|embedding| embedding.embedding),
3781 )
3782 .collect::<HashMap<_, _>>();
3783
3784 let db = session.db().await;
3785 db.save_embeddings(&request.model, &embeddings)
3786 .await
3787 .context("failed to save embeddings")
3788 .trace_err();
3789
3790 response.send(proto::ComputeEmbeddingsResponse {
3791 embeddings: embeddings
3792 .into_iter()
3793 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3794 .collect(),
3795 })?;
3796 Ok(())
3797}
3798
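/// Look up previously computed embeddings by digest. Legacy, staff-only endpoint.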
3799async fn get_cached_embeddings(
3800 request: proto::GetCachedEmbeddings,
3801 response: Response<proto::GetCachedEmbeddings>,
3802 session: Session,
3803) -> Result<()> {
3804 authorize_access_to_legacy_llm_endpoints(&session).await?;
3805
3806 let db = session.db().await;
3807 let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
3808
3809 response.send(proto::GetCachedEmbeddingsResponse {
3810 embeddings: embeddings
3811 .into_iter()
3812 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3813 .collect(),
3814 })?;
3815 Ok(())
3816}
3817
3818/// This is leftover from before the LLM service.
3819///
3820/// The endpoints protected by this check will be moved there eventually.
3821async fn authorize_access_to_legacy_llm_endpoints(session: &Session) -> Result<(), Error> {
3822 if session.is_staff() {
3823 Ok(())
3824 } else {
3825 Err(anyhow!("permission denied"))?
3826 }
3827}
3828
3829/// Get a Supermaven API key for the user
3830async fn get_supermaven_api_key(
3831 _request: proto::GetSupermavenApiKey,
3832 response: Response<proto::GetSupermavenApiKey>,
3833 session: Session,
3834) -> Result<()> {
3835 let user_id: String = session.user_id().to_string();
3836 if !session.is_staff() {
3837 return Err(anyhow!("supermaven not enabled for this account"))?;
3838 }
3839
3840 let email = session
3841 .email()
3842 .ok_or_else(|| anyhow!("user must have an email"))?;
3843
3844 let supermaven_admin_api = session
3845 .supermaven_client
3846 .as_ref()
3847 .ok_or_else(|| anyhow!("supermaven not configured"))?;
3848
3849 let result = supermaven_admin_api
3850 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3851 .await?;
3852
3853 response.send(proto::GetSupermavenApiKeyResponse {
3854 api_key: result.api_key,
3855 })?;
3856
3857 Ok(())
3858}
3859
/// Start receiving chat updates for a channel
async fn join_channel_chat(
    request: proto::JoinChannelChat,
    response: Response<proto::JoinChannelChat>,
    session: Session,
) -> Result<()> {
    let channel_id = ChannelId::from_proto(request.channel_id);

    let db = session.db().await;
    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
        .await?;
    let messages = db
        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
        .await?;
    response.send(proto::JoinChannelChatResponse {
        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
        messages,
    })?;
    Ok(())
}

/// Stop receiving chat updates for a channel
async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
    let channel_id = ChannelId::from_proto(request.channel_id);
    session
        .db()
        .await
        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
        .await?;
    Ok(())
}

/// Retrieve the chat history for a channel
async fn get_channel_messages(
    request: proto::GetChannelMessages,
    response: Response<proto::GetChannelMessages>,
    session: Session,
) -> Result<()> {
    let channel_id = ChannelId::from_proto(request.channel_id);
    let messages = session
        .db()
        .await
        .get_channel_messages(
            channel_id,
            session.user_id(),
            MESSAGE_COUNT_PER_PAGE,
            Some(MessageId::from_proto(request.before_message_id)),
        )
        .await?;
    response.send(proto::GetChannelMessagesResponse {
        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
        messages,
    })?;
    Ok(())
}

/// Retrieve specific chat messages
async fn get_channel_messages_by_id(
    request: proto::GetChannelMessagesById,
    response: Response<proto::GetChannelMessagesById>,
    session: Session,
) -> Result<()> {
    let message_ids = request
        .message_ids
        .iter()
        .map(|id| MessageId::from_proto(*id))
        .collect::<Vec<_>>();
    let messages = session
        .db()
        .await
        .get_channel_messages_by_id(session.user_id(), &message_ids)
        .await?;
    response.send(proto::GetChannelMessagesResponse {
        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
        messages,
    })?;
    Ok(())
}

/// Retrieve the current user's notifications
async fn get_notifications(
    request: proto::GetNotifications,
    response: Response<proto::GetNotifications>,
    session: Session,
) -> Result<()> {
    let notifications = session
        .db()
        .await
        .get_notifications(
            session.user_id(),
            NOTIFICATION_COUNT_PER_PAGE,
            request.before_id.map(db::NotificationId::from_proto),
        )
        .await?;
    response.send(proto::GetNotificationsResponse {
        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
        notifications,
    })?;
    Ok(())
}

/// Mark a notification as read
async fn mark_notification_as_read(
    request: proto::MarkNotificationRead,
    response: Response<proto::MarkNotificationRead>,
    session: Session,
) -> Result<()> {
    let database = &session.db().await;
    let notifications = database
        .mark_notification_as_read_by_id(
            session.user_id(),
            NotificationId::from_proto(request.notification_id),
        )
        .await?;
    send_notifications(
        &*session.connection_pool().await,
        &session.peer,
        notifications,
    );
    response.send(proto::Ack {})?;
    Ok(())
}

/// Get the current user's information
async fn get_private_user_info(
    _request: proto::GetPrivateUserInfo,
    response: Response<proto::GetPrivateUserInfo>,
    session: Session,
) -> Result<()> {
    let db = session.db().await;

    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
    let user = db
        .get_user_by_id(session.user_id())
        .await?
        .ok_or_else(|| anyhow!("user not found"))?;
    let flags = db.get_user_flags(session.user_id()).await?;

    response.send(proto::GetPrivateUserInfoResponse {
        metrics_id,
        staff: user.admin,
        flags,
        accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
    })?;
    Ok(())
}

/// Accept the terms of service (ToS) on behalf of the current user
async fn accept_terms_of_service(
    _request: proto::AcceptTermsOfService,
    response: Response<proto::AcceptTermsOfService>,
    session: Session,
) -> Result<()> {
    let db = session.db().await;

    let accepted_tos_at = Utc::now();
    db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
        .await?;

    response.send(proto::AcceptTermsOfServiceResponse {
        accepted_tos_at: accepted_tos_at.timestamp() as u64,
    })?;
    Ok(())
}

/// The minimum age an account must have in order to use the LLM service.
const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);

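/// Issue an LLM API token for the current user.
///
/// Requires the `language-models` feature flag (or staff), an accepted terms of
/// service, and an account older than `MIN_ACCOUNT_AGE_FOR_LLM_USE` unless the
/// user has an LLM subscription or the `bypass-account-age-check` flag.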
async fn get_llm_api_token(
    _request: proto::GetLlmToken,
    response: Response<proto::GetLlmToken>,
    session: Session,
) -> Result<()> {
    let db = session.db().await;

    let flags = db.get_user_flags(session.user_id()).await?;
    let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
    let has_llm_closed_beta_feature_flag = flags.iter().any(|flag| flag == "llm-closed-beta");
    let has_predict_edits_feature_flag = flags.iter().any(|flag| flag == "predict-edits");

    if !session.is_staff() && !has_language_models_feature_flag {
        Err(anyhow!("permission denied"))?
    }

    let user_id = session.user_id();
    let user = db
        .get_user_by_id(user_id)
        .await?
        .ok_or_else(|| anyhow!("user {} not found", user_id))?;

    if user.accepted_tos_at.is_none() {
        Err(anyhow!("terms of service not accepted"))?
    }

    let has_llm_subscription = session.has_llm_subscription(&db).await?;

    let bypass_account_age_check =
        has_llm_subscription || flags.iter().any(|flag| flag == "bypass-account-age-check");
    if !bypass_account_age_check {
        let mut account_created_at = user.created_at;
        if let Some(github_created_at) = user.github_user_created_at {
            account_created_at = account_created_at.min(github_created_at);
        }
        if Utc::now().naive_utc() - account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE {
            Err(anyhow!("account too young"))?
        }
    }

    let billing_preferences = db.get_billing_preferences(user.id).await?;

    let token = LlmTokenClaims::create(
        &user,
        session.is_staff(),
        billing_preferences,
        has_llm_closed_beta_feature_flag,
        has_predict_edits_feature_flag,
        has_llm_subscription,
        session.current_plan(&db).await?,
        session.system_id.clone(),
        &session.app_state.config,
    )?;
    response.send(proto::GetLlmTokenResponse { token })?;
    Ok(())
}

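/// Convert a `tungstenite` WebSocket message into the equivalent `axum` message type.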
fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
    let message = match message {
        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
            code: frame.code.into(),
            reason: frame.reason,
        })),
        // We should never receive a frame while reading the message, according
        // to the `tungstenite` maintainers:
        //
        // > It cannot occur when you read messages from the WebSocket, but it
        // > can be used when you want to send the raw frames (e.g. you want to
        // > send the frames to the WebSocket without composing the full message first).
        // >
        // > — https://github.com/snapview/tungstenite-rs/issues/268
        TungsteniteMessage::Frame(_) => {
            bail!("received an unexpected frame while reading the message")
        }
    };

    Ok(message)
}

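/// Convert an `axum` WebSocket message into the equivalent `tungstenite` message type.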
fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
    match message {
        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
        AxumMessage::Close(frame) => {
            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
                code: frame.code.into(),
                reason: frame.reason,
            }))
        }
    }
}

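/// Apply a membership change to the connection pool's channel subscriptions and
/// push the resulting `UpdateUserChannels` and `UpdateChannels` messages to all
/// of the user's connections.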
fn notify_membership_updated(
    connection_pool: &mut ConnectionPool,
    result: MembershipUpdated,
    user_id: UserId,
    peer: &Peer,
) {
    for membership in &result.new_channels.channel_memberships {
        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
    }
    for channel_id in &result.removed_channels {
        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
    }

    let user_channels_update = proto::UpdateUserChannels {
        channel_memberships: result
            .new_channels
            .channel_memberships
            .iter()
            .map(|cm| proto::ChannelMembership {
                channel_id: cm.channel_id.to_proto(),
                role: cm.role.into(),
            })
            .collect(),
        ..Default::default()
    };

    let mut update = build_channels_update(result.new_channels);
    update.delete_channels = result
        .removed_channels
        .into_iter()
        .map(|id| id.to_proto())
        .collect();
    update.remove_channel_invitations = vec![result.channel_id.to_proto()];

    for connection_id in connection_pool.user_connection_ids(user_id) {
        peer.send(connection_id, user_channels_update.clone())
            .trace_err();
        peer.send(connection_id, update.clone()).trace_err();
    }
}

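/// Build an `UpdateUserChannels` message describing the user's channel roles and
/// the buffer versions and message ids they have already observed.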
fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
    proto::UpdateUserChannels {
        channel_memberships: channels
            .channel_memberships
            .iter()
            .map(|m| proto::ChannelMembership {
                channel_id: m.channel_id.to_proto(),
                role: m.role.into(),
            })
            .collect(),
        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
        observed_channel_message_id: channels.observed_channel_messages.clone(),
    }
}

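/// Build an `UpdateChannels` message from a user's channel state: the channels
/// themselves, invitations, per-channel participants, and the latest buffer
/// versions and message ids.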
fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
    let mut update = proto::UpdateChannels::default();

    for channel in channels.channels {
        update.channels.push(channel.to_proto());
    }

    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
    update.latest_channel_message_ids = channels.latest_channel_messages;

    for (channel_id, participants) in channels.channel_participants {
        update
            .channel_participants
            .push(proto::ChannelParticipants {
                channel_id: channel_id.to_proto(),
                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
            });
    }

    for channel in channels.invited_channels {
        update.channel_invitations.push(channel.to_proto());
    }

    update
}

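/// Build the initial `UpdateContacts` message for a connection, splitting the
/// user's contacts into accepted contacts and incoming/outgoing requests.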
fn build_initial_contacts_update(
    contacts: Vec<db::Contact>,
    pool: &ConnectionPool,
) -> proto::UpdateContacts {
    let mut update = proto::UpdateContacts::default();

    for contact in contacts {
        match contact {
            db::Contact::Accepted { user_id, busy } => {
                update.contacts.push(contact_for_user(user_id, busy, pool));
            }
            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
            db::Contact::Incoming { user_id } => {
                update
                    .incoming_requests
                    .push(proto::IncomingContactRequest {
                        requester_id: user_id.to_proto(),
                    })
            }
        }
    }

    update
}

fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
    proto::Contact {
        user_id: user_id.to_proto(),
        online: pool.is_user_online(user_id),
        busy,
    }
}

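/// Broadcast the latest room state to every participant in the room.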
fn room_updated(room: &proto::Room, peer: &Peer) {
    broadcast(
        None,
        room.participants
            .iter()
            .filter_map(|participant| Some(participant.peer_id?.into())),
        |peer_id| {
            peer.send(
                peer_id,
                proto::RoomUpdated {
                    room: Some(room.clone()),
                },
            )
        },
    );
}

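/// Notify all connections that can see this channel that its set of call
/// participants changed.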
fn channel_updated(
    channel: &db::channel::Model,
    room: &proto::Room,
    peer: &Peer,
    pool: &ConnectionPool,
) {
    let participants = room
        .participants
        .iter()
        .map(|p| p.user_id)
        .collect::<Vec<_>>();

    broadcast(
        None,
        pool.channel_connection_ids(channel.root_id())
            .filter_map(|(channel_id, role)| {
                role.can_see_channel(channel.visibility)
                    .then_some(channel_id)
            }),
        |peer_id| {
            peer.send(
                peer_id,
                proto::UpdateChannels {
                    channel_participants: vec![proto::ChannelParticipants {
                        channel_id: channel.id.to_proto(),
                        participant_user_ids: participants.clone(),
                    }],
                    ..Default::default()
                },
            )
        },
    );
}

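/// Push `user_id`'s updated online/busy status to every connection belonging to
/// their accepted contacts.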
async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
    let db = session.db().await;

    let contacts = db.get_contacts(user_id).await?;
    let busy = db.is_user_busy(user_id).await?;

    let pool = session.connection_pool().await;
    let updated_contact = contact_for_user(user_id, busy, &pool);
    for contact in contacts {
        if let db::Contact::Accepted {
            user_id: contact_user_id,
            ..
        } = contact
        {
            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
                session
                    .peer
                    .send(
                        contact_conn_id,
                        proto::UpdateContacts {
                            contacts: vec![updated_contact.clone()],
                            remove_contacts: Default::default(),
                            incoming_requests: Default::default(),
                            remove_incoming_requests: Default::default(),
                            outgoing_requests: Default::default(),
                            remove_outgoing_requests: Default::default(),
                        },
                    )
                    .trace_err();
            }
        }
    }
    Ok(())
}

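/// Remove a connection from its current room (if any): detach or unshare its
/// projects, notify callees whose pending calls were canceled, refresh the
/// affected contacts, and update or delete the LiveKit room as needed.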
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}

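/// Drop a connection's channel-buffer collaborations and notify the remaining
/// collaborators on each affected buffer.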
async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
    let left_channel_buffers = session
        .db()
        .await
        .leave_channel_buffers(session.connection_id)
        .await?;

    for left_buffer in left_channel_buffers {
        channel_buffer_updated(
            session.connection_id,
            left_buffer.connections,
            &proto::UpdateChannelBufferCollaborators {
                channel_id: left_buffer.channel_id.to_proto(),
                collaborators: left_buffer.collaborators,
            },
            &session.peer,
        );
    }

    Ok(())
}

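/// Tell the other members of a project that this connection left it, either
/// unsharing the project entirely or removing this peer as a collaborator,
/// depending on `should_unshare`.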
fn project_left(project: &db::LeftProject, session: &Session) {
    for connection_id in &project.connection_ids {
        if project.should_unshare {
            session
                .peer
                .send(
                    *connection_id,
                    proto::UnshareProject {
                        project_id: project.id.to_proto(),
                    },
                )
                .trace_err();
        } else {
            session
                .peer
                .send(
                    *connection_id,
                    proto::RemoveProjectCollaborator {
                        project_id: project.id.to_proto(),
                        peer_id: Some(session.connection_id.into()),
                    },
                )
                .trace_err();
        }
    }
}

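/// Log-and-discard helper for fallible operations whose failure should not
/// abort the surrounding handler (e.g. best-effort `peer.send` calls).
///
/// A minimal usage sketch (illustrative only, mirroring how handlers above use it):
///
/// ```ignore
/// // Logs the error via `tracing::error!` and yields `None` instead of propagating it.
/// peer.send(connection_id, update.clone()).trace_err();
/// ```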
pub trait ResultExt {
    type Ok;

    fn trace_err(self) -> Option<Self::Ok>;
}

impl<T, E> ResultExt for Result<T, E>
where
    E: std::fmt::Debug,
{
    type Ok = T;

    #[track_caller]
    fn trace_err(self) -> Option<T> {
        match self {
            Ok(value) => Some(value),
            Err(error) => {
                tracing::error!("{:?}", error);
                None
            }
        }
    }
}