1mod connection_pool;
2
3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
4use crate::llm::LlmTokenClaims;
5use crate::{
6 auth,
7 db::{
8 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
9 CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
10 NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
11 RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
12 },
13 executor::Executor,
14 AppState, Config, Error, RateLimit, Result,
15};
16use anyhow::{anyhow, bail, Context as _};
17use async_tungstenite::tungstenite::{
18 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
19};
20use axum::{
21 body::Body,
22 extract::{
23 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
24 ConnectInfo, WebSocketUpgrade,
25 },
26 headers::{Header, HeaderName},
27 http::StatusCode,
28 middleware,
29 response::IntoResponse,
30 routing::get,
31 Extension, Router, TypedHeader,
32};
33use chrono::Utc;
34use collections::{HashMap, HashSet};
35pub use connection_pool::{ConnectionPool, ZedVersion};
36use core::fmt::{self, Debug, Formatter};
37use http_client::HttpClient;
38use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
39use reqwest_client::ReqwestClient;
40use sha2::Digest;
41use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
42
43use futures::{
44 channel::oneshot, future::BoxFuture, stream::FuturesUnordered, FutureExt, SinkExt, StreamExt,
45 TryStreamExt,
46};
47use prometheus::{register_int_gauge, IntGauge};
48use rpc::{
49 proto::{
50 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
51 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
52 },
53 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
54};
55use semantic_version::SemanticVersion;
56use serde::{Serialize, Serializer};
57use std::{
58 any::TypeId,
59 future::Future,
60 marker::PhantomData,
61 mem,
62 net::SocketAddr,
63 ops::{Deref, DerefMut},
64 rc::Rc,
65 sync::{
66 atomic::{AtomicBool, Ordering::SeqCst},
67 Arc, OnceLock,
68 },
69 time::{Duration, Instant},
70};
71use time::OffsetDateTime;
72use tokio::sync::{watch, MutexGuard, Semaphore};
73use tower::ServiceBuilder;
74use tracing::{
75 field::{self},
76 info_span, instrument, Instrument,
77};
78
79pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
80
// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can clean up old resources.
82pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
83
84const MESSAGE_COUNT_PER_PAGE: usize = 100;
85const MAX_MESSAGE_LEN: usize = 1024;
86const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
87
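/// A boxed, type-erased handler: consumes an incoming envelope together with the session it
/// arrived on and returns a future that completes once the message has been handled.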
88type MessageHandler =
89 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
90
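/// Wraps the receipt of an in-flight request so a handler can reply exactly once; the
/// `responded` flag lets the server detect handlers that finish without sending a response.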
91struct Response<R> {
92 peer: Arc<Peer>,
93 receipt: Receipt<R>,
94 responded: Arc<AtomicBool>,
95}
96
97impl<R: RequestMessage> Response<R> {
98 fn send(self, payload: R::Response) -> Result<()> {
99 self.responded.store(true, SeqCst);
100 self.peer.respond(self.receipt, payload)?;
101 Ok(())
102 }
103}
104
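/// The authenticated identity behind a connection: either a plain user, or a user being
/// impersonated by an admin.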
105#[derive(Clone, Debug)]
106pub enum Principal {
107 User(User),
108 Impersonated { user: User, admin: User },
109}
110
111impl Principal {
112 fn update_span(&self, span: &tracing::Span) {
113 match &self {
114 Principal::User(user) => {
115 span.record("user_id", user.id.0);
116 span.record("login", &user.github_login);
117 }
118 Principal::Impersonated { user, admin } => {
119 span.record("user_id", user.id.0);
120 span.record("login", &user.github_login);
121 span.record("impersonator", &admin.github_login);
122 }
123 }
124 }
125}
126
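/// Per-connection state handed to every message handler, holding the principal, database
/// handle, peer, and connection pool.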
127#[derive(Clone)]
128struct Session {
129 principal: Principal,
130 connection_id: ConnectionId,
131 db: Arc<tokio::sync::Mutex<DbHandle>>,
132 peer: Arc<Peer>,
133 connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
134 app_state: Arc<AppState>,
135 supermaven_client: Option<Arc<SupermavenAdminApi>>,
136 http_client: Arc<dyn HttpClient>,
137 /// The GeoIP country code for the user.
138 #[allow(unused)]
139 geoip_country_code: Option<String>,
140 system_id: Option<String>,
141 _executor: Executor,
142}
143
144impl Session {
145 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
146 #[cfg(test)]
147 tokio::task::yield_now().await;
148 let guard = self.db.lock().await;
149 #[cfg(test)]
150 tokio::task::yield_now().await;
151 guard
152 }
153
154 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
155 #[cfg(test)]
156 tokio::task::yield_now().await;
157 let guard = self.connection_pool.lock();
158 ConnectionPoolGuard {
159 guard,
160 _not_send: PhantomData,
161 }
162 }
163
164 fn is_staff(&self) -> bool {
165 match &self.principal {
166 Principal::User(user) => user.admin,
167 Principal::Impersonated { .. } => true,
168 }
169 }
170
171 pub async fn has_llm_subscription(
172 &self,
173 db: &MutexGuard<'_, DbHandle>,
174 ) -> anyhow::Result<bool> {
175 if self.is_staff() {
176 return Ok(true);
177 }
178
179 let user_id = self.user_id();
180
181 Ok(db.has_active_billing_subscription(user_id).await?)
182 }
183
184 pub async fn current_plan(
185 &self,
186 _db: &MutexGuard<'_, DbHandle>,
187 ) -> anyhow::Result<proto::Plan> {
188 if self.is_staff() {
189 Ok(proto::Plan::ZedPro)
190 } else {
191 Ok(proto::Plan::Free)
192 }
193 }
194
195 fn user_id(&self) -> UserId {
196 match &self.principal {
197 Principal::User(user) => user.id,
198 Principal::Impersonated { user, .. } => user.id,
199 }
200 }
201
202 pub fn email(&self) -> Option<String> {
203 match &self.principal {
204 Principal::User(user) => user.email_address.clone(),
205 Principal::Impersonated { user, .. } => user.email_address.clone(),
206 }
207 }
208}
209
210impl Debug for Session {
211 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
212 let mut result = f.debug_struct("Session");
213 match &self.principal {
214 Principal::User(user) => {
215 result.field("user", &user.github_login);
216 }
217 Principal::Impersonated { user, admin } => {
218 result.field("user", &user.github_login);
219 result.field("impersonator", &admin.github_login);
220 }
221 }
222 result.field("connection_id", &self.connection_id).finish()
223 }
224}
225
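/// Newtype over the shared [`Database`] that dereferences to it; stored behind the session's
/// async mutex.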
226struct DbHandle(Arc<Database>);
227
228impl Deref for DbHandle {
229 type Target = Database;
230
231 fn deref(&self) -> &Self::Target {
232 self.0.as_ref()
233 }
234}
235
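/// The collab RPC server: owns the peer, the connection pool, and the registry of message
/// handlers.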
236pub struct Server {
237 id: parking_lot::Mutex<ServerId>,
238 peer: Arc<Peer>,
239 pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
240 app_state: Arc<AppState>,
241 handlers: HashMap<TypeId, MessageHandler>,
242 teardown: watch::Sender<bool>,
243}
244
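/// Guard over the connection pool lock. The `PhantomData<Rc<()>>` marker makes the guard
/// `!Send`, so it cannot be held across an `.await` in a future that must be `Send`. In tests,
/// dropping the guard checks the pool's invariants.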
245pub(crate) struct ConnectionPoolGuard<'a> {
246 guard: parking_lot::MutexGuard<'a, ConnectionPool>,
247 _not_send: PhantomData<Rc<()>>,
248}
249
250#[derive(Serialize)]
251pub struct ServerSnapshot<'a> {
252 peer: &'a Peer,
253 #[serde(serialize_with = "serialize_deref")]
254 connection_pool: ConnectionPoolGuard<'a>,
255}
256
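/// Serializes a value through its `Deref` target; used to serialize the connection pool from
/// behind its lock guard.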
257pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
258where
259 S: Serializer,
260 T: Deref<Target = U>,
261 U: Serialize,
262{
263 Serialize::serialize(value.deref(), serializer)
264}
265
266impl Server {
267 pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
268 let mut server = Self {
269 id: parking_lot::Mutex::new(id),
270 peer: Peer::new(id.0 as u32),
271 app_state: app_state.clone(),
272 connection_pool: Default::default(),
273 handlers: Default::default(),
274 teardown: watch::channel(false).0,
275 };
276
277 server
278 .add_request_handler(ping)
279 .add_request_handler(create_room)
280 .add_request_handler(join_room)
281 .add_request_handler(rejoin_room)
282 .add_request_handler(leave_room)
283 .add_request_handler(set_room_participant_role)
284 .add_request_handler(call)
285 .add_request_handler(cancel_call)
286 .add_message_handler(decline_call)
287 .add_request_handler(update_participant_location)
288 .add_request_handler(share_project)
289 .add_message_handler(unshare_project)
290 .add_request_handler(join_project)
291 .add_message_handler(leave_project)
292 .add_request_handler(update_project)
293 .add_request_handler(update_worktree)
294 .add_message_handler(start_language_server)
295 .add_message_handler(update_language_server)
296 .add_message_handler(update_diagnostic_summary)
297 .add_message_handler(update_worktree_settings)
298 .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
299 .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
300 .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
301 .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
302 .add_request_handler(forward_find_search_candidates_request)
303 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
304 .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
305 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
306 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
307 .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
308 .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
309 .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
310 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
311 .add_request_handler(forward_read_only_project_request::<proto::GitBranches>)
312 .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
313 .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
314 .add_request_handler(
315 forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
316 )
317 .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
318 .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
319 .add_request_handler(
320 forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
321 )
322 .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
323 .add_request_handler(
324 forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
325 )
326 .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
327 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
328 .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
329 .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
330 .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
331 .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
332 .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
333 .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
334 .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
335 .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
336 .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
337 .add_request_handler(
338 forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
339 )
340 .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
341 .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
342 .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
343 .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
344 .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
345 .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
346 .add_message_handler(create_buffer_for_peer)
347 .add_request_handler(update_buffer)
348 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
349 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
350 .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
351 .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
352 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
353 .add_request_handler(get_users)
354 .add_request_handler(fuzzy_search_users)
355 .add_request_handler(request_contact)
356 .add_request_handler(remove_contact)
357 .add_request_handler(respond_to_contact_request)
358 .add_message_handler(subscribe_to_channels)
359 .add_request_handler(create_channel)
360 .add_request_handler(delete_channel)
361 .add_request_handler(invite_channel_member)
362 .add_request_handler(remove_channel_member)
363 .add_request_handler(set_channel_member_role)
364 .add_request_handler(set_channel_visibility)
365 .add_request_handler(rename_channel)
366 .add_request_handler(join_channel_buffer)
367 .add_request_handler(leave_channel_buffer)
368 .add_message_handler(update_channel_buffer)
369 .add_request_handler(rejoin_channel_buffers)
370 .add_request_handler(get_channel_members)
371 .add_request_handler(respond_to_channel_invite)
372 .add_request_handler(join_channel)
373 .add_request_handler(join_channel_chat)
374 .add_message_handler(leave_channel_chat)
375 .add_request_handler(send_channel_message)
376 .add_request_handler(remove_channel_message)
377 .add_request_handler(update_channel_message)
378 .add_request_handler(get_channel_messages)
379 .add_request_handler(get_channel_messages_by_id)
380 .add_request_handler(get_notifications)
381 .add_request_handler(mark_notification_as_read)
382 .add_request_handler(move_channel)
383 .add_request_handler(follow)
384 .add_message_handler(unfollow)
385 .add_message_handler(update_followers)
386 .add_request_handler(get_private_user_info)
387 .add_request_handler(get_llm_api_token)
388 .add_request_handler(accept_terms_of_service)
389 .add_message_handler(acknowledge_channel_message)
390 .add_message_handler(acknowledge_buffer_version)
391 .add_request_handler(get_supermaven_api_key)
392 .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
393 .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
394 .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
395 .add_request_handler(forward_mutating_project_request::<proto::Stage>)
396 .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
397 .add_request_handler(forward_mutating_project_request::<proto::Commit>)
398 .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
399 .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
400 .add_message_handler(update_context)
401 .add_request_handler({
402 let app_state = app_state.clone();
403 move |request, response, session| {
404 let app_state = app_state.clone();
405 async move {
406 count_language_model_tokens(request, response, session, &app_state.config)
407 .await
408 }
409 }
410 })
411 .add_request_handler(get_cached_embeddings)
412 .add_request_handler({
413 let app_state = app_state.clone();
414 move |request, response, session| {
415 compute_embeddings(
416 request,
417 response,
418 session,
419 app_state.config.openai_api_key.clone(),
420 )
421 }
422 });
423
424 Arc::new(server)
425 }
426
427 pub async fn start(&self) -> Result<()> {
428 let server_id = *self.id.lock();
429 let app_state = self.app_state.clone();
430 let peer = self.peer.clone();
431 let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
432 let pool = self.connection_pool.clone();
433 let livekit_client = self.app_state.livekit_client.clone();
434
435 let span = info_span!("start server");
436 self.app_state.executor.spawn_detached(
437 async move {
438 tracing::info!("waiting for cleanup timeout");
439 timeout.await;
440 tracing::info!("cleanup timeout expired, retrieving stale rooms");
441 if let Some((room_ids, channel_ids)) = app_state
442 .db
443 .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
444 .await
445 .trace_err()
446 {
447 tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
448 tracing::info!(
449 stale_channel_buffer_count = channel_ids.len(),
450 "retrieved stale channel buffers"
451 );
452
453 for channel_id in channel_ids {
454 if let Some(refreshed_channel_buffer) = app_state
455 .db
456 .clear_stale_channel_buffer_collaborators(channel_id, server_id)
457 .await
458 .trace_err()
459 {
460 for connection_id in refreshed_channel_buffer.connection_ids {
461 peer.send(
462 connection_id,
463 proto::UpdateChannelBufferCollaborators {
464 channel_id: channel_id.to_proto(),
465 collaborators: refreshed_channel_buffer
466 .collaborators
467 .clone(),
468 },
469 )
470 .trace_err();
471 }
472 }
473 }
474
475 for room_id in room_ids {
476 let mut contacts_to_update = HashSet::default();
477 let mut canceled_calls_to_user_ids = Vec::new();
478 let mut livekit_room = String::new();
479 let mut delete_livekit_room = false;
480
481 if let Some(mut refreshed_room) = app_state
482 .db
483 .clear_stale_room_participants(room_id, server_id)
484 .await
485 .trace_err()
486 {
487 tracing::info!(
488 room_id = room_id.0,
489 new_participant_count = refreshed_room.room.participants.len(),
490 "refreshed room"
491 );
492 room_updated(&refreshed_room.room, &peer);
493 if let Some(channel) = refreshed_room.channel.as_ref() {
494 channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
495 }
496 contacts_to_update
497 .extend(refreshed_room.stale_participant_user_ids.iter().copied());
498 contacts_to_update
499 .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
500 canceled_calls_to_user_ids =
501 mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
502 livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
503 delete_livekit_room = refreshed_room.room.participants.is_empty();
504 }
505
506 {
507 let pool = pool.lock();
508 for canceled_user_id in canceled_calls_to_user_ids {
509 for connection_id in pool.user_connection_ids(canceled_user_id) {
510 peer.send(
511 connection_id,
512 proto::CallCanceled {
513 room_id: room_id.to_proto(),
514 },
515 )
516 .trace_err();
517 }
518 }
519 }
520
521 for user_id in contacts_to_update {
522 let busy = app_state.db.is_user_busy(user_id).await.trace_err();
523 let contacts = app_state.db.get_contacts(user_id).await.trace_err();
524 if let Some((busy, contacts)) = busy.zip(contacts) {
525 let pool = pool.lock();
526 let updated_contact = contact_for_user(user_id, busy, &pool);
527 for contact in contacts {
528 if let db::Contact::Accepted {
529 user_id: contact_user_id,
530 ..
531 } = contact
532 {
533 for contact_conn_id in
534 pool.user_connection_ids(contact_user_id)
535 {
536 peer.send(
537 contact_conn_id,
538 proto::UpdateContacts {
539 contacts: vec![updated_contact.clone()],
540 remove_contacts: Default::default(),
541 incoming_requests: Default::default(),
542 remove_incoming_requests: Default::default(),
543 outgoing_requests: Default::default(),
544 remove_outgoing_requests: Default::default(),
545 },
546 )
547 .trace_err();
548 }
549 }
550 }
551 }
552 }
553
554 if let Some(live_kit) = livekit_client.as_ref() {
555 if delete_livekit_room {
556 live_kit.delete_room(livekit_room).await.trace_err();
557 }
558 }
559 }
560 }
561
562 app_state
563 .db
564 .delete_stale_servers(&app_state.config.zed_environment, server_id)
565 .await
566 .trace_err();
567 }
568 .instrument(span),
569 );
570 Ok(())
571 }
572
573 pub fn teardown(&self) {
574 self.peer.teardown();
575 self.connection_pool.lock().reset();
576 let _ = self.teardown.send(true);
577 }
578
579 #[cfg(test)]
580 pub fn reset(&self, id: ServerId) {
581 self.teardown();
582 *self.id.lock() = id;
583 self.peer.reset(id.0 as u32);
584 let _ = self.teardown.send(false);
585 }
586
587 #[cfg(test)]
588 pub fn id(&self) -> ServerId {
589 *self.id.lock()
590 }
591
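    /// Registers the handler for a message type, wrapping it with tracing of queueing and
    /// processing durations. Panics if a handler for the same message type is registered twice.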
592 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
593 where
594 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
595 Fut: 'static + Send + Future<Output = Result<()>>,
596 M: EnvelopedMessage,
597 {
598 let prev_handler = self.handlers.insert(
599 TypeId::of::<M>(),
600 Box::new(move |envelope, session| {
601 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
602 let received_at = envelope.received_at;
603 tracing::info!("message received");
604 let start_time = Instant::now();
605 let future = (handler)(*envelope, session);
606 async move {
607 let result = future.await;
608 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
609 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
610 let queue_duration_ms = total_duration_ms - processing_duration_ms;
611 let payload_type = M::NAME;
612
613 match result {
614 Err(error) => {
615 tracing::error!(
616 ?error,
617 total_duration_ms,
618 processing_duration_ms,
619 queue_duration_ms,
620 payload_type,
621 "error handling message"
622 )
623 }
624 Ok(()) => tracing::info!(
625 total_duration_ms,
626 processing_duration_ms,
627 queue_duration_ms,
628 "finished handling message"
629 ),
630 }
631 }
632 .boxed()
633 }),
634 );
635 if prev_handler.is_some() {
636 panic!("registered a handler for the same message twice");
637 }
638 self
639 }
640
641 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
642 where
643 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
644 Fut: 'static + Send + Future<Output = Result<()>>,
645 M: EnvelopedMessage,
646 {
647 self.add_handler(move |envelope, session| handler(envelope.payload, session));
648 self
649 }
650
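    /// Registers a handler for a request message. Returning `Ok` without sending a response is
    /// reported as an error, and handler failures are forwarded to the requesting peer as error
    /// responses.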
651 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
652 where
653 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
654 Fut: Send + Future<Output = Result<()>>,
655 M: RequestMessage,
656 {
657 let handler = Arc::new(handler);
658 self.add_handler(move |envelope, session| {
659 let receipt = envelope.receipt();
660 let handler = handler.clone();
661 async move {
662 let peer = session.peer.clone();
663 let responded = Arc::new(AtomicBool::default());
664 let response = Response {
665 peer: peer.clone(),
666 responded: responded.clone(),
667 receipt,
668 };
669 match (handler)(envelope.payload, response, session).await {
670 Ok(()) => {
671 if responded.load(std::sync::atomic::Ordering::SeqCst) {
672 Ok(())
673 } else {
674 Err(anyhow!("handler did not send a response"))?
675 }
676 }
677 Err(error) => {
678 let proto_err = match &error {
679 Error::Internal(err) => err.to_proto(),
680 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
681 };
682 peer.respond_with_error(receipt, proto_err)?;
683 Err(error)
684 }
685 }
686 }
687 })
688 }
689
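    /// Drives a single client connection: registers it with the peer, builds its [`Session`],
    /// sends the initial client update, then processes incoming messages (foreground messages
    /// concurrently via a `FuturesUnordered`, background messages on detached tasks) until the
    /// connection closes or the server tears down.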
690 #[allow(clippy::too_many_arguments)]
691 pub fn handle_connection(
692 self: &Arc<Self>,
693 connection: Connection,
694 address: String,
695 principal: Principal,
696 zed_version: ZedVersion,
697 geoip_country_code: Option<String>,
698 system_id: Option<String>,
699 send_connection_id: Option<oneshot::Sender<ConnectionId>>,
700 executor: Executor,
701 ) -> impl Future<Output = ()> {
702 let this = self.clone();
703 let span = info_span!("handle connection", %address,
704 connection_id=field::Empty,
705 user_id=field::Empty,
706 login=field::Empty,
707 impersonator=field::Empty,
708 geoip_country_code=field::Empty
709 );
710 principal.update_span(&span);
711 if let Some(country_code) = geoip_country_code.as_ref() {
712 span.record("geoip_country_code", country_code);
713 }
714
715 let mut teardown = self.teardown.subscribe();
716 async move {
717 if *teardown.borrow() {
718 tracing::error!("server is tearing down");
719 return
720 }
721 let (connection_id, handle_io, mut incoming_rx) = this
722 .peer
723 .add_connection(connection, {
724 let executor = executor.clone();
725 move |duration| executor.sleep(duration)
726 });
727 tracing::Span::current().record("connection_id", format!("{}", connection_id));
728
729 tracing::info!("connection opened");
730
731 let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
732 let http_client = match ReqwestClient::user_agent(&user_agent) {
733 Ok(http_client) => Arc::new(http_client),
734 Err(error) => {
735 tracing::error!(?error, "failed to create HTTP client");
736 return;
737 }
738 };
739
740 let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
741 supermaven_admin_api_key.to_string(),
742 http_client.clone(),
743 )));
744
745 let session = Session {
746 principal: principal.clone(),
747 connection_id,
748 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
749 peer: this.peer.clone(),
750 connection_pool: this.connection_pool.clone(),
751 app_state: this.app_state.clone(),
752 http_client,
753 geoip_country_code,
754 system_id,
755 _executor: executor.clone(),
756 supermaven_client,
757 };
758
759 if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
760 tracing::error!(?error, "failed to send initial client update");
761 return;
762 }
763
764 let handle_io = handle_io.fuse();
765 futures::pin_mut!(handle_io);
766
            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when, for example, client A sends a request to client B while
            // client B sends a request to client A: if both clients stopped processing further
            // messages until their own request completed, neither would ever respond to the other,
            // and both requests would deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages that arrived later in the spirit of making progress.
775 let mut foreground_message_handlers = FuturesUnordered::new();
776 let concurrent_handlers = Arc::new(Semaphore::new(256));
777 loop {
778 let next_message = async {
779 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
780 let message = incoming_rx.next().await;
781 (permit, message)
782 }.fuse();
783 futures::pin_mut!(next_message);
784 futures::select_biased! {
785 _ = teardown.changed().fuse() => return,
786 result = handle_io => {
787 if let Err(error) = result {
788 tracing::error!(?error, "error handling I/O");
789 }
790 break;
791 }
792 _ = foreground_message_handlers.next() => {}
793 next_message = next_message => {
794 let (permit, message) = next_message;
795 if let Some(message) = message {
796 let type_name = message.payload_type_name();
797 // note: we copy all the fields from the parent span so we can query them in the logs.
798 // (https://github.com/tokio-rs/tracing/issues/2670).
799 let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
800 user_id=field::Empty,
801 login=field::Empty,
802 impersonator=field::Empty,
803 );
804 principal.update_span(&span);
805 let span_enter = span.enter();
806 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
807 let is_background = message.is_background();
808 let handle_message = (handler)(message, session.clone());
809 drop(span_enter);
810
811 let handle_message = async move {
812 handle_message.await;
813 drop(permit);
814 }.instrument(span);
815 if is_background {
816 executor.spawn_detached(handle_message);
817 } else {
818 foreground_message_handlers.push(handle_message);
819 }
820 } else {
821 tracing::error!("no message handler");
822 }
823 } else {
824 tracing::info!("connection closed");
825 break;
826 }
827 }
828 }
829 }
830
831 drop(foreground_message_handlers);
832 tracing::info!("signing out");
833 if let Err(error) = connection_lost(session, teardown, executor).await {
834 tracing::error!(?error, "error signing out");
835 }
836
837 }.instrument(span)
838 }
839
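    /// Sends the initial handshake to a newly connected client: the `Hello` message, the
    /// contacts list, plan updates, channel subscriptions when the client version calls for it,
    /// and any pending incoming call.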
840 async fn send_initial_client_update(
841 &self,
842 connection_id: ConnectionId,
843 principal: &Principal,
844 zed_version: ZedVersion,
845 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
846 session: &Session,
847 ) -> Result<()> {
848 self.peer.send(
849 connection_id,
850 proto::Hello {
851 peer_id: Some(connection_id.into()),
852 },
853 )?;
854 tracing::info!("sent hello message");
855 if let Some(send_connection_id) = send_connection_id.take() {
856 let _ = send_connection_id.send(connection_id);
857 }
858
859 match principal {
860 Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
861 if !user.connected_once {
862 self.peer.send(connection_id, proto::ShowContacts {})?;
863 self.app_state
864 .db
865 .set_user_connected_once(user.id, true)
866 .await?;
867 }
868
869 update_user_plan(user.id, session).await?;
870
871 let contacts = self.app_state.db.get_contacts(user.id).await?;
872
873 {
874 let mut pool = self.connection_pool.lock();
875 pool.add_connection(connection_id, user.id, user.admin, zed_version);
876 self.peer.send(
877 connection_id,
878 build_initial_contacts_update(contacts, &pool),
879 )?;
880 }
881
882 if should_auto_subscribe_to_channels(zed_version) {
883 subscribe_user_to_channels(user.id, session).await?;
884 }
885
886 if let Some(incoming_call) =
887 self.app_state.db.incoming_call_for_user(user.id).await?
888 {
889 self.peer.send(connection_id, incoming_call)?;
890 }
891
892 update_user_contacts(user.id, session).await?;
893 }
894 }
895
896 Ok(())
897 }
898
899 pub async fn invite_code_redeemed(
900 self: &Arc<Self>,
901 inviter_id: UserId,
902 invitee_id: UserId,
903 ) -> Result<()> {
904 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
905 if let Some(code) = &user.invite_code {
906 let pool = self.connection_pool.lock();
907 let invitee_contact = contact_for_user(invitee_id, false, &pool);
908 for connection_id in pool.user_connection_ids(inviter_id) {
909 self.peer.send(
910 connection_id,
911 proto::UpdateContacts {
912 contacts: vec![invitee_contact.clone()],
913 ..Default::default()
914 },
915 )?;
916 self.peer.send(
917 connection_id,
918 proto::UpdateInviteInfo {
919 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
920 count: user.invite_count as u32,
921 },
922 )?;
923 }
924 }
925 }
926 Ok(())
927 }
928
929 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
930 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
931 if let Some(invite_code) = &user.invite_code {
932 let pool = self.connection_pool.lock();
933 for connection_id in pool.user_connection_ids(user_id) {
934 self.peer.send(
935 connection_id,
936 proto::UpdateInviteInfo {
937 url: format!(
938 "{}{}",
939 self.app_state.config.invite_link_prefix, invite_code
940 ),
941 count: user.invite_count as u32,
942 },
943 )?;
944 }
945 }
946 }
947 Ok(())
948 }
949
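    /// Tells every open connection for the given user to refresh its LLM API token.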
950 pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
951 let pool = self.connection_pool.lock();
952 for connection_id in pool.user_connection_ids(user_id) {
953 self.peer
954 .send(connection_id, proto::RefreshLlmToken {})
955 .trace_err();
956 }
957 }
958
959 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
960 ServerSnapshot {
961 connection_pool: ConnectionPoolGuard {
962 guard: self.connection_pool.lock(),
963 _not_send: PhantomData,
964 },
965 peer: &self.peer,
966 }
967 }
968}
969
970impl<'a> Deref for ConnectionPoolGuard<'a> {
971 type Target = ConnectionPool;
972
973 fn deref(&self) -> &Self::Target {
974 &self.guard
975 }
976}
977
978impl<'a> DerefMut for ConnectionPoolGuard<'a> {
979 fn deref_mut(&mut self) -> &mut Self::Target {
980 &mut self.guard
981 }
982}
983
984impl<'a> Drop for ConnectionPoolGuard<'a> {
985 fn drop(&mut self) {
986 #[cfg(test)]
987 self.check_invariants();
988 }
989}
990
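/// Sends a message to every receiver except the sender, logging individual send failures
/// instead of propagating them.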
991fn broadcast<F>(
992 sender_id: Option<ConnectionId>,
993 receiver_ids: impl IntoIterator<Item = ConnectionId>,
994 mut f: F,
995) where
996 F: FnMut(ConnectionId) -> anyhow::Result<()>,
997{
998 for receiver_id in receiver_ids {
999 if Some(receiver_id) != sender_id {
1000 if let Err(error) = f(receiver_id) {
1001 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1002 }
1003 }
1004 }
1005}
1006
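/// Typed header carrying the client's RPC protocol version (`x-zed-protocol-version`).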
1007pub struct ProtocolVersion(u32);
1008
1009impl Header for ProtocolVersion {
1010 fn name() -> &'static HeaderName {
1011 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1012 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1013 }
1014
1015 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1016 where
1017 Self: Sized,
1018 I: Iterator<Item = &'i axum::http::HeaderValue>,
1019 {
1020 let version = values
1021 .next()
1022 .ok_or_else(axum::headers::Error::invalid)?
1023 .to_str()
1024 .map_err(|_| axum::headers::Error::invalid())?
1025 .parse()
1026 .map_err(|_| axum::headers::Error::invalid())?;
1027 Ok(Self(version))
1028 }
1029
1030 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1031 values.extend([self.0.to_string().parse().unwrap()]);
1032 }
1033}
1034
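/// Typed header carrying the client's app version (`x-zed-app-version`).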
1035pub struct AppVersionHeader(SemanticVersion);
1036impl Header for AppVersionHeader {
1037 fn name() -> &'static HeaderName {
1038 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1039 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1040 }
1041
1042 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1043 where
1044 Self: Sized,
1045 I: Iterator<Item = &'i axum::http::HeaderValue>,
1046 {
1047 let version = values
1048 .next()
1049 .ok_or_else(axum::headers::Error::invalid)?
1050 .to_str()
1051 .map_err(|_| axum::headers::Error::invalid())?
1052 .parse()
1053 .map_err(|_| axum::headers::Error::invalid())?;
1054 Ok(Self(version))
1055 }
1056
1057 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1058 values.extend([self.0.to_string().parse().unwrap()]);
1059 }
1060}
1061
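/// Builds the collab RPC router: `/rpc` (WebSocket upgrade, guarded by header-based auth
/// middleware) and `/metrics`.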
1062pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1063 Router::new()
1064 .route("/rpc", get(handle_websocket_request))
1065 .layer(
1066 ServiceBuilder::new()
1067 .layer(Extension(server.app_state.clone()))
1068 .layer(middleware::from_fn(auth::validate_header)),
1069 )
1070 .route("/metrics", get(handle_metrics))
1071 .layer(Extension(server))
1072}
1073
1074#[allow(clippy::too_many_arguments)]
1075pub async fn handle_websocket_request(
1076 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1077 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1078 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1079 Extension(server): Extension<Arc<Server>>,
1080 Extension(principal): Extension<Principal>,
1081 country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1082 system_id_header: Option<TypedHeader<SystemIdHeader>>,
1083 ws: WebSocketUpgrade,
1084) -> axum::response::Response {
1085 if protocol_version != rpc::PROTOCOL_VERSION {
1086 return (
1087 StatusCode::UPGRADE_REQUIRED,
1088 "client must be upgraded".to_string(),
1089 )
1090 .into_response();
1091 }
1092
1093 let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1094 return (
1095 StatusCode::UPGRADE_REQUIRED,
1096 "no version header found".to_string(),
1097 )
1098 .into_response();
1099 };
1100
1101 if !version.can_collaborate() {
1102 return (
1103 StatusCode::UPGRADE_REQUIRED,
1104 "client must be upgraded".to_string(),
1105 )
1106 .into_response();
1107 }
1108
1109 let socket_address = socket_address.to_string();
1110 ws.on_upgrade(move |socket| {
1111 let socket = socket
1112 .map_ok(to_tungstenite_message)
1113 .err_into()
1114 .with(|message| async move { to_axum_message(message) });
1115 let connection = Connection::new(Box::pin(socket));
1116 async move {
1117 server
1118 .handle_connection(
1119 connection,
1120 socket_address,
1121 principal,
1122 version,
1123 country_code_header.map(|header| header.to_string()),
1124 system_id_header.map(|header| header.to_string()),
1125 None,
1126 Executor::Production,
1127 )
1128 .await;
1129 }
1130 })
1131}
1132
1133pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1134 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1135 let connections_metric = CONNECTIONS_METRIC
1136 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1137
1138 let connections = server
1139 .connection_pool
1140 .lock()
1141 .connections()
1142 .filter(|connection| !connection.admin)
1143 .count();
1144 connections_metric.set(connections as _);
1145
1146 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1147 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1148 register_int_gauge!(
1149 "shared_projects",
1150 "number of open projects with one or more guests"
1151 )
1152 .unwrap()
1153 });
1154
1155 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1156 shared_projects_metric.set(shared_projects as _);
1157
1158 let encoder = prometheus::TextEncoder::new();
1159 let metric_families = prometheus::gather();
1160 let encoded_metrics = encoder
1161 .encode_to_string(&metric_families)
1162 .map_err(|err| anyhow!("{}", err))?;
1163 Ok(encoded_metrics)
1164}
1165
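/// Cleans up after a dropped connection. The connection itself is removed immediately, but room
/// and channel-buffer membership is only released after `RECONNECT_TIMEOUT`, giving a quickly
/// reconnecting client a chance to rejoin its resources.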
1166#[instrument(err, skip(executor))]
1167async fn connection_lost(
1168 session: Session,
1169 mut teardown: watch::Receiver<bool>,
1170 executor: Executor,
1171) -> Result<()> {
1172 session.peer.disconnect(session.connection_id);
1173 session
1174 .connection_pool()
1175 .await
1176 .remove_connection(session.connection_id)?;
1177
1178 session
1179 .db()
1180 .await
1181 .connection_lost(session.connection_id)
1182 .await
1183 .trace_err();
1184
1185 futures::select_biased! {
1186 _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1187
1188 log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1189 leave_room_for_session(&session, session.connection_id).await.trace_err();
1190 leave_channel_buffers_for_session(&session)
1191 .await
1192 .trace_err();
1193
1194 if !session
1195 .connection_pool()
1196 .await
1197 .is_user_online(session.user_id())
1198 {
1199 let db = session.db().await;
1200 if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1201 room_updated(&room, &session.peer);
1202 }
1203 }
1204
1205 update_user_contacts(session.user_id(), &session).await?;
1206 },
1207 _ = teardown.changed().fuse() => {}
1208 }
1209
1210 Ok(())
1211}
1212
1213/// Acknowledges a ping from a client, used to keep the connection alive.
1214async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1215 response.send(proto::Ack {})?;
1216 Ok(())
1217}
1218
/// Creates a new room for calling (outside of channels).
1220async fn create_room(
1221 _request: proto::CreateRoom,
1222 response: Response<proto::CreateRoom>,
1223 session: Session,
1224) -> Result<()> {
1225 let livekit_room = nanoid::nanoid!(30);
1226
1227 let live_kit_connection_info = util::maybe!(async {
1228 let live_kit = session.app_state.livekit_client.as_ref();
1229 let live_kit = live_kit?;
1230 let user_id = session.user_id().to_string();
1231
1232 let token = live_kit
1233 .room_token(&livekit_room, &user_id.to_string())
1234 .trace_err()?;
1235
1236 Some(proto::LiveKitConnectionInfo {
1237 server_url: live_kit.url().into(),
1238 token,
1239 can_publish: true,
1240 })
1241 })
1242 .await;
1243
1244 let room = session
1245 .db()
1246 .await
1247 .create_room(session.user_id(), session.connection_id, &livekit_room)
1248 .await?;
1249
1250 response.send(proto::CreateRoomResponse {
1251 room: Some(room.clone()),
1252 live_kit_connection_info,
1253 })?;
1254
1255 update_user_contacts(session.user_id(), &session).await?;
1256 Ok(())
1257}
1258
/// Joins a room from an invitation. If the room belongs to a channel, this is equivalent to joining that channel.
1260async fn join_room(
1261 request: proto::JoinRoom,
1262 response: Response<proto::JoinRoom>,
1263 session: Session,
1264) -> Result<()> {
1265 let room_id = RoomId::from_proto(request.id);
1266
1267 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1268
1269 if let Some(channel_id) = channel_id {
1270 return join_channel_internal(channel_id, Box::new(response), session).await;
1271 }
1272
1273 let joined_room = {
1274 let room = session
1275 .db()
1276 .await
1277 .join_room(room_id, session.user_id(), session.connection_id)
1278 .await?;
1279 room_updated(&room.room, &session.peer);
1280 room.into_inner()
1281 };
1282
1283 for connection_id in session
1284 .connection_pool()
1285 .await
1286 .user_connection_ids(session.user_id())
1287 {
1288 session
1289 .peer
1290 .send(
1291 connection_id,
1292 proto::CallCanceled {
1293 room_id: room_id.to_proto(),
1294 },
1295 )
1296 .trace_err();
1297 }
1298
1299 let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1300 {
1301 live_kit
1302 .room_token(
1303 &joined_room.room.livekit_room,
1304 &session.user_id().to_string(),
1305 )
1306 .trace_err()
1307 .map(|token| proto::LiveKitConnectionInfo {
1308 server_url: live_kit.url().into(),
1309 token,
1310 can_publish: true,
1311 })
1312 } else {
1313 None
1314 };
1315
1316 response.send(proto::JoinRoomResponse {
1317 room: Some(joined_room.room),
1318 channel_id: None,
1319 live_kit_connection_info,
1320 })?;
1321
1322 update_user_contacts(session.user_id(), &session).await?;
1323 Ok(())
1324}
1325
/// Rejoins a room after a connection error, restoring the caller's shared and joined projects.
1327async fn rejoin_room(
1328 request: proto::RejoinRoom,
1329 response: Response<proto::RejoinRoom>,
1330 session: Session,
1331) -> Result<()> {
1332 let room;
1333 let channel;
1334 {
1335 let mut rejoined_room = session
1336 .db()
1337 .await
1338 .rejoin_room(request, session.user_id(), session.connection_id)
1339 .await?;
1340
1341 response.send(proto::RejoinRoomResponse {
1342 room: Some(rejoined_room.room.clone()),
1343 reshared_projects: rejoined_room
1344 .reshared_projects
1345 .iter()
1346 .map(|project| proto::ResharedProject {
1347 id: project.id.to_proto(),
1348 collaborators: project
1349 .collaborators
1350 .iter()
1351 .map(|collaborator| collaborator.to_proto())
1352 .collect(),
1353 })
1354 .collect(),
1355 rejoined_projects: rejoined_room
1356 .rejoined_projects
1357 .iter()
1358 .map(|rejoined_project| rejoined_project.to_proto())
1359 .collect(),
1360 })?;
1361 room_updated(&rejoined_room.room, &session.peer);
1362
1363 for project in &rejoined_room.reshared_projects {
1364 for collaborator in &project.collaborators {
1365 session
1366 .peer
1367 .send(
1368 collaborator.connection_id,
1369 proto::UpdateProjectCollaborator {
1370 project_id: project.id.to_proto(),
1371 old_peer_id: Some(project.old_connection_id.into()),
1372 new_peer_id: Some(session.connection_id.into()),
1373 },
1374 )
1375 .trace_err();
1376 }
1377
1378 broadcast(
1379 Some(session.connection_id),
1380 project
1381 .collaborators
1382 .iter()
1383 .map(|collaborator| collaborator.connection_id),
1384 |connection_id| {
1385 session.peer.forward_send(
1386 session.connection_id,
1387 connection_id,
1388 proto::UpdateProject {
1389 project_id: project.id.to_proto(),
1390 worktrees: project.worktrees.clone(),
1391 },
1392 )
1393 },
1394 );
1395 }
1396
1397 notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;
1398
1399 let rejoined_room = rejoined_room.into_inner();
1400
1401 room = rejoined_room.room;
1402 channel = rejoined_room.channel;
1403 }
1404
1405 if let Some(channel) = channel {
1406 channel_updated(
1407 &channel,
1408 &room,
1409 &session.peer,
1410 &*session.connection_pool().await,
1411 );
1412 }
1413
1414 update_user_contacts(session.user_id(), &session).await?;
1415 Ok(())
1416}
1417
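/// Notifies other collaborators of the rejoining client's new peer id, then streams worktree
/// entries, diagnostics, settings files, and language-server state back to the rejoining client.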
1418fn notify_rejoined_projects(
1419 rejoined_projects: &mut Vec<RejoinedProject>,
1420 session: &Session,
1421) -> Result<()> {
1422 for project in rejoined_projects.iter() {
1423 for collaborator in &project.collaborators {
1424 session
1425 .peer
1426 .send(
1427 collaborator.connection_id,
1428 proto::UpdateProjectCollaborator {
1429 project_id: project.id.to_proto(),
1430 old_peer_id: Some(project.old_connection_id.into()),
1431 new_peer_id: Some(session.connection_id.into()),
1432 },
1433 )
1434 .trace_err();
1435 }
1436 }
1437
1438 for project in rejoined_projects {
1439 for worktree in mem::take(&mut project.worktrees) {
1440 // Stream this worktree's entries.
1441 let message = proto::UpdateWorktree {
1442 project_id: project.id.to_proto(),
1443 worktree_id: worktree.id,
1444 abs_path: worktree.abs_path.clone(),
1445 root_name: worktree.root_name,
1446 updated_entries: worktree.updated_entries,
1447 removed_entries: worktree.removed_entries,
1448 scan_id: worktree.scan_id,
1449 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1450 updated_repositories: worktree.updated_repositories,
1451 removed_repositories: worktree.removed_repositories,
1452 };
1453 for update in proto::split_worktree_update(message) {
1454 session.peer.send(session.connection_id, update.clone())?;
1455 }
1456
1457 // Stream this worktree's diagnostics.
1458 for summary in worktree.diagnostic_summaries {
1459 session.peer.send(
1460 session.connection_id,
1461 proto::UpdateDiagnosticSummary {
1462 project_id: project.id.to_proto(),
1463 worktree_id: worktree.id,
1464 summary: Some(summary),
1465 },
1466 )?;
1467 }
1468
1469 for settings_file in worktree.settings_files {
1470 session.peer.send(
1471 session.connection_id,
1472 proto::UpdateWorktreeSettings {
1473 project_id: project.id.to_proto(),
1474 worktree_id: worktree.id,
1475 path: settings_file.path,
1476 content: Some(settings_file.content),
1477 kind: Some(settings_file.kind.to_proto().into()),
1478 },
1479 )?;
1480 }
1481 }
1482
1483 for language_server in &project.language_servers {
1484 session.peer.send(
1485 session.connection_id,
1486 proto::UpdateLanguageServer {
1487 project_id: project.id.to_proto(),
1488 language_server_id: language_server.id,
1489 variant: Some(
1490 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1491 proto::LspDiskBasedDiagnosticsUpdated {},
1492 ),
1493 ),
1494 },
1495 )?;
1496 }
1497 }
1498 Ok(())
1499}
1500
/// Leaves the current room, disconnecting the caller from it.
1502async fn leave_room(
1503 _: proto::LeaveRoom,
1504 response: Response<proto::LeaveRoom>,
1505 session: Session,
1506) -> Result<()> {
1507 leave_room_for_session(&session, session.connection_id).await?;
1508 response.send(proto::Ack {})?;
1509 Ok(())
1510}
1511
1512/// Updates the permissions of someone else in the room.
1513async fn set_room_participant_role(
1514 request: proto::SetRoomParticipantRole,
1515 response: Response<proto::SetRoomParticipantRole>,
1516 session: Session,
1517) -> Result<()> {
1518 let user_id = UserId::from_proto(request.user_id);
1519 let role = ChannelRole::from(request.role());
1520
1521 let (livekit_room, can_publish) = {
1522 let room = session
1523 .db()
1524 .await
1525 .set_room_participant_role(
1526 session.user_id(),
1527 RoomId::from_proto(request.room_id),
1528 user_id,
1529 role,
1530 )
1531 .await?;
1532
1533 let livekit_room = room.livekit_room.clone();
1534 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1535 room_updated(&room, &session.peer);
1536 (livekit_room, can_publish)
1537 };
1538
1539 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1540 live_kit
1541 .update_participant(
1542 livekit_room.clone(),
1543 request.user_id.to_string(),
1544 livekit_server::proto::ParticipantPermission {
1545 can_subscribe: true,
1546 can_publish,
1547 can_publish_data: can_publish,
1548 hidden: false,
1549 recorder: false,
1550 },
1551 )
1552 .await
1553 .trace_err();
1554 }
1555
1556 response.send(proto::Ack {})?;
1557 Ok(())
1558}
1559
/// Calls another user into the current room.
1561async fn call(
1562 request: proto::Call,
1563 response: Response<proto::Call>,
1564 session: Session,
1565) -> Result<()> {
1566 let room_id = RoomId::from_proto(request.room_id);
1567 let calling_user_id = session.user_id();
1568 let calling_connection_id = session.connection_id;
1569 let called_user_id = UserId::from_proto(request.called_user_id);
1570 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1571 if !session
1572 .db()
1573 .await
1574 .has_contact(calling_user_id, called_user_id)
1575 .await?
1576 {
1577 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1578 }
1579
1580 let incoming_call = {
1581 let (room, incoming_call) = &mut *session
1582 .db()
1583 .await
1584 .call(
1585 room_id,
1586 calling_user_id,
1587 calling_connection_id,
1588 called_user_id,
1589 initial_project_id,
1590 )
1591 .await?;
1592 room_updated(room, &session.peer);
1593 mem::take(incoming_call)
1594 };
1595 update_user_contacts(called_user_id, &session).await?;
1596
1597 let mut calls = session
1598 .connection_pool()
1599 .await
1600 .user_connection_ids(called_user_id)
1601 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1602 .collect::<FuturesUnordered<_>>();
1603
1604 while let Some(call_response) = calls.next().await {
1605 match call_response.as_ref() {
1606 Ok(_) => {
1607 response.send(proto::Ack {})?;
1608 return Ok(());
1609 }
1610 Err(_) => {
1611 call_response.trace_err();
1612 }
1613 }
1614 }
1615
1616 {
1617 let room = session
1618 .db()
1619 .await
1620 .call_failed(room_id, called_user_id)
1621 .await?;
1622 room_updated(&room, &session.peer);
1623 }
1624 update_user_contacts(called_user_id, &session).await?;
1625
1626 Err(anyhow!("failed to ring user"))?
1627}
1628
1629/// Cancel an outgoing call.
1630async fn cancel_call(
1631 request: proto::CancelCall,
1632 response: Response<proto::CancelCall>,
1633 session: Session,
1634) -> Result<()> {
1635 let called_user_id = UserId::from_proto(request.called_user_id);
1636 let room_id = RoomId::from_proto(request.room_id);
1637 {
1638 let room = session
1639 .db()
1640 .await
1641 .cancel_call(room_id, session.connection_id, called_user_id)
1642 .await?;
1643 room_updated(&room, &session.peer);
1644 }
1645
1646 for connection_id in session
1647 .connection_pool()
1648 .await
1649 .user_connection_ids(called_user_id)
1650 {
1651 session
1652 .peer
1653 .send(
1654 connection_id,
1655 proto::CallCanceled {
1656 room_id: room_id.to_proto(),
1657 },
1658 )
1659 .trace_err();
1660 }
1661 response.send(proto::Ack {})?;
1662
1663 update_user_contacts(called_user_id, &session).await?;
1664 Ok(())
1665}
1666
1667/// Decline an incoming call.
1668async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1669 let room_id = RoomId::from_proto(message.room_id);
1670 {
1671 let room = session
1672 .db()
1673 .await
1674 .decline_call(Some(room_id), session.user_id())
1675 .await?
1676 .ok_or_else(|| anyhow!("failed to decline call"))?;
1677 room_updated(&room, &session.peer);
1678 }
1679
1680 for connection_id in session
1681 .connection_pool()
1682 .await
1683 .user_connection_ids(session.user_id())
1684 {
1685 session
1686 .peer
1687 .send(
1688 connection_id,
1689 proto::CallCanceled {
1690 room_id: room_id.to_proto(),
1691 },
1692 )
1693 .trace_err();
1694 }
1695 update_user_contacts(session.user_id(), &session).await?;
1696 Ok(())
1697}
1698
1699/// Updates other participants in the room with your current location.
1700async fn update_participant_location(
1701 request: proto::UpdateParticipantLocation,
1702 response: Response<proto::UpdateParticipantLocation>,
1703 session: Session,
1704) -> Result<()> {
1705 let room_id = RoomId::from_proto(request.room_id);
1706 let location = request
1707 .location
1708 .ok_or_else(|| anyhow!("invalid location"))?;
1709
1710 let db = session.db().await;
1711 let room = db
1712 .update_room_participant_location(room_id, session.connection_id, location)
1713 .await?;
1714
1715 room_updated(&room, &session.peer);
1716 response.send(proto::Ack {})?;
1717 Ok(())
1718}
1719
1720/// Share a project into the room.
1721async fn share_project(
1722 request: proto::ShareProject,
1723 response: Response<proto::ShareProject>,
1724 session: Session,
1725) -> Result<()> {
1726 let (project_id, room) = &*session
1727 .db()
1728 .await
1729 .share_project(
1730 RoomId::from_proto(request.room_id),
1731 session.connection_id,
1732 &request.worktrees,
1733 request.is_ssh_project,
1734 )
1735 .await?;
1736 response.send(proto::ShareProjectResponse {
1737 project_id: project_id.to_proto(),
1738 })?;
1739 room_updated(room, &session.peer);
1740
1741 Ok(())
1742}
1743
1744/// Unshare a project from the room.
1745async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1746 let project_id = ProjectId::from_proto(message.project_id);
1747 unshare_project_internal(project_id, session.connection_id, &session).await
1748}
1749
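/// Unshares a project on behalf of the given connection, notifying guests, updating the room,
/// and deleting the project row when the database indicates it should be removed.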
1750async fn unshare_project_internal(
1751 project_id: ProjectId,
1752 connection_id: ConnectionId,
1753 session: &Session,
1754) -> Result<()> {
1755 let delete = {
1756 let room_guard = session
1757 .db()
1758 .await
1759 .unshare_project(project_id, connection_id)
1760 .await?;
1761
1762 let (delete, room, guest_connection_ids) = &*room_guard;
1763
1764 let message = proto::UnshareProject {
1765 project_id: project_id.to_proto(),
1766 };
1767
1768 broadcast(
1769 Some(connection_id),
1770 guest_connection_ids.iter().copied(),
1771 |conn_id| session.peer.send(conn_id, message.clone()),
1772 );
1773 if let Some(room) = room {
1774 room_updated(room, &session.peer);
1775 }
1776
1777 *delete
1778 };
1779
1780 if delete {
1781 let db = session.db().await;
1782 db.delete_project(project_id).await?;
1783 }
1784
1785 Ok(())
1786}
1787
/// Joins someone else's shared project.
1789async fn join_project(
1790 request: proto::JoinProject,
1791 response: Response<proto::JoinProject>,
1792 session: Session,
1793) -> Result<()> {
1794 let project_id = ProjectId::from_proto(request.project_id);
1795
1796 tracing::info!(%project_id, "join project");
1797
1798 let db = session.db().await;
1799 let (project, replica_id) = &mut *db
1800 .join_project(project_id, session.connection_id, session.user_id())
1801 .await?;
1802 drop(db);
1803 tracing::info!(%project_id, "join remote project");
1804 join_project_internal(response, session, project, replica_id)
1805}
1806
1807trait JoinProjectInternalResponse {
1808 fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
1809}
1810impl JoinProjectInternalResponse for Response<proto::JoinProject> {
1811 fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
1812 Response::<proto::JoinProject>::send(self, result)
1813 }
1814}
1815
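/// Adds the joining connection as a collaborator (notifying existing collaborators), replies
/// with the project metadata, then streams worktree entries, diagnostics, settings files, and
/// language-server state to the new guest.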
1816fn join_project_internal(
1817 response: impl JoinProjectInternalResponse,
1818 session: Session,
1819 project: &mut Project,
1820 replica_id: &ReplicaId,
1821) -> Result<()> {
1822 let collaborators = project
1823 .collaborators
1824 .iter()
1825 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1826 .map(|collaborator| collaborator.to_proto())
1827 .collect::<Vec<_>>();
1828 let project_id = project.id;
1829 let guest_user_id = session.user_id();
1830
1831 let worktrees = project
1832 .worktrees
1833 .iter()
1834 .map(|(id, worktree)| proto::WorktreeMetadata {
1835 id: *id,
1836 root_name: worktree.root_name.clone(),
1837 visible: worktree.visible,
1838 abs_path: worktree.abs_path.clone(),
1839 })
1840 .collect::<Vec<_>>();
1841
1842 let add_project_collaborator = proto::AddProjectCollaborator {
1843 project_id: project_id.to_proto(),
1844 collaborator: Some(proto::Collaborator {
1845 peer_id: Some(session.connection_id.into()),
1846 replica_id: replica_id.0 as u32,
1847 user_id: guest_user_id.to_proto(),
1848 is_host: false,
1849 }),
1850 };
1851
1852 for collaborator in &collaborators {
1853 session
1854 .peer
1855 .send(
1856 collaborator.peer_id.unwrap().into(),
1857 add_project_collaborator.clone(),
1858 )
1859 .trace_err();
1860 }
1861
1862 // First, we send the metadata associated with each worktree.
1863 response.send(proto::JoinProjectResponse {
1864 project_id: project.id.0 as u64,
1865 worktrees: worktrees.clone(),
1866 replica_id: replica_id.0 as u32,
1867 collaborators: collaborators.clone(),
1868 language_servers: project.language_servers.clone(),
1869 role: project.role.into(),
1870 })?;
1871
1872 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1873 // Stream this worktree's entries.
1874 let message = proto::UpdateWorktree {
1875 project_id: project_id.to_proto(),
1876 worktree_id,
1877 abs_path: worktree.abs_path.clone(),
1878 root_name: worktree.root_name,
1879 updated_entries: worktree.entries,
1880 removed_entries: Default::default(),
1881 scan_id: worktree.scan_id,
1882 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1883 updated_repositories: worktree.repository_entries.into_values().collect(),
1884 removed_repositories: Default::default(),
1885 };
1886 for update in proto::split_worktree_update(message) {
1887 session.peer.send(session.connection_id, update.clone())?;
1888 }
1889
1890 // Stream this worktree's diagnostics.
1891 for summary in worktree.diagnostic_summaries {
1892 session.peer.send(
1893 session.connection_id,
1894 proto::UpdateDiagnosticSummary {
1895 project_id: project_id.to_proto(),
1896 worktree_id: worktree.id,
1897 summary: Some(summary),
1898 },
1899 )?;
1900 }
1901
1902 for settings_file in worktree.settings_files {
1903 session.peer.send(
1904 session.connection_id,
1905 proto::UpdateWorktreeSettings {
1906 project_id: project_id.to_proto(),
1907 worktree_id: worktree.id,
1908 path: settings_file.path,
1909 content: Some(settings_file.content),
1910 kind: Some(settings_file.kind.to_proto() as i32),
1911 },
1912 )?;
1913 }
1914 }
1915
1916 for language_server in &project.language_servers {
1917 session.peer.send(
1918 session.connection_id,
1919 proto::UpdateLanguageServer {
1920 project_id: project_id.to_proto(),
1921 language_server_id: language_server.id,
1922 variant: Some(
1923 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1924 proto::LspDiskBasedDiagnosticsUpdated {},
1925 ),
1926 ),
1927 },
1928 )?;
1929 }
1930
1931 Ok(())
1932}
1933
1934/// Leave someone else's shared project.
1935async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1936 let sender_id = session.connection_id;
1937 let project_id = ProjectId::from_proto(request.project_id);
1938 let db = session.db().await;
1939
1940 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1941 tracing::info!(
1942 %project_id,
1943 "leave project"
1944 );
1945
1946 project_left(project, &session);
1947 if let Some(room) = room {
1948 room_updated(room, &session.peer);
1949 }
1950
1951 Ok(())
1952}
1953
1954/// Updates other participants with changes to the project
1955async fn update_project(
1956 request: proto::UpdateProject,
1957 response: Response<proto::UpdateProject>,
1958 session: Session,
1959) -> Result<()> {
1960 let project_id = ProjectId::from_proto(request.project_id);
1961 let (room, guest_connection_ids) = &*session
1962 .db()
1963 .await
1964 .update_project(project_id, session.connection_id, &request.worktrees)
1965 .await?;
1966 broadcast(
1967 Some(session.connection_id),
1968 guest_connection_ids.iter().copied(),
1969 |connection_id| {
1970 session
1971 .peer
1972 .forward_send(session.connection_id, connection_id, request.clone())
1973 },
1974 );
1975 if let Some(room) = room {
1976 room_updated(room, &session.peer);
1977 }
1978 response.send(proto::Ack {})?;
1979
1980 Ok(())
1981}
1982
1983/// Updates other participants with changes to the worktree
1984async fn update_worktree(
1985 request: proto::UpdateWorktree,
1986 response: Response<proto::UpdateWorktree>,
1987 session: Session,
1988) -> Result<()> {
1989 let guest_connection_ids = session
1990 .db()
1991 .await
1992 .update_worktree(&request, session.connection_id)
1993 .await?;
1994
1995 broadcast(
1996 Some(session.connection_id),
1997 guest_connection_ids.iter().copied(),
1998 |connection_id| {
1999 session
2000 .peer
2001 .forward_send(session.connection_id, connection_id, request.clone())
2002 },
2003 );
2004 response.send(proto::Ack {})?;
2005 Ok(())
2006}
2007
2008/// Updates other participants with changes to the diagnostics
2009async fn update_diagnostic_summary(
2010 message: proto::UpdateDiagnosticSummary,
2011 session: Session,
2012) -> Result<()> {
2013 let guest_connection_ids = session
2014 .db()
2015 .await
2016 .update_diagnostic_summary(&message, session.connection_id)
2017 .await?;
2018
2019 broadcast(
2020 Some(session.connection_id),
2021 guest_connection_ids.iter().copied(),
2022 |connection_id| {
2023 session
2024 .peer
2025 .forward_send(session.connection_id, connection_id, message.clone())
2026 },
2027 );
2028
2029 Ok(())
2030}
2031
2032/// Updates other participants with changes to the worktree settings
2033async fn update_worktree_settings(
2034 message: proto::UpdateWorktreeSettings,
2035 session: Session,
2036) -> Result<()> {
2037 let guest_connection_ids = session
2038 .db()
2039 .await
2040 .update_worktree_settings(&message, session.connection_id)
2041 .await?;
2042
2043 broadcast(
2044 Some(session.connection_id),
2045 guest_connection_ids.iter().copied(),
2046 |connection_id| {
2047 session
2048 .peer
2049 .forward_send(session.connection_id, connection_id, message.clone())
2050 },
2051 );
2052
2053 Ok(())
2054}
2055
2056/// Notify other participants that a language server has started.
2057async fn start_language_server(
2058 request: proto::StartLanguageServer,
2059 session: Session,
2060) -> Result<()> {
2061 let guest_connection_ids = session
2062 .db()
2063 .await
2064 .start_language_server(&request, session.connection_id)
2065 .await?;
2066
2067 broadcast(
2068 Some(session.connection_id),
2069 guest_connection_ids.iter().copied(),
2070 |connection_id| {
2071 session
2072 .peer
2073 .forward_send(session.connection_id, connection_id, request.clone())
2074 },
2075 );
2076 Ok(())
2077}
2078
2079/// Notify other participants that a language server has changed.
2080async fn update_language_server(
2081 request: proto::UpdateLanguageServer,
2082 session: Session,
2083) -> Result<()> {
2084 let project_id = ProjectId::from_proto(request.project_id);
2085 let project_connection_ids = session
2086 .db()
2087 .await
2088 .project_connection_ids(project_id, session.connection_id, true)
2089 .await?;
2090 broadcast(
2091 Some(session.connection_id),
2092 project_connection_ids.iter().copied(),
2093 |connection_id| {
2094 session
2095 .peer
2096 .forward_send(session.connection_id, connection_id, request.clone())
2097 },
2098 );
2099 Ok(())
2100}
2101
2102/// Forward a project request to the host. These requests should be read-only,
2103/// as guests are allowed to send them.
2104async fn forward_read_only_project_request<T>(
2105 request: T,
2106 response: Response<T>,
2107 session: Session,
2108) -> Result<()>
2109where
2110 T: EntityMessage + RequestMessage,
2111{
2112 let project_id = ProjectId::from_proto(request.remote_entity_id());
2113 let host_connection_id = session
2114 .db()
2115 .await
2116 .host_for_read_only_project_request(project_id, session.connection_id)
2117 .await?;
2118 let payload = session
2119 .peer
2120 .forward_request(session.connection_id, host_connection_id, request)
2121 .await?;
2122 response.send(payload)?;
2123 Ok(())
2124}
2125
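/// Forward a search-candidates request to the host. Like the other read-only
/// project requests, guests are allowed to send it.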
2126async fn forward_find_search_candidates_request(
2127 request: proto::FindSearchCandidates,
2128 response: Response<proto::FindSearchCandidates>,
2129 session: Session,
2130) -> Result<()> {
2131 let project_id = ProjectId::from_proto(request.remote_entity_id());
2132 let host_connection_id = session
2133 .db()
2134 .await
2135 .host_for_read_only_project_request(project_id, session.connection_id)
2136 .await?;
2137 let payload = session
2138 .peer
2139 .forward_request(session.connection_id, host_connection_id, request)
2140 .await?;
2141 response.send(payload)?;
2142 Ok(())
2143}
2144
2145/// Forward a project request to the host. These requests are disallowed
2146/// for guests.
2147async fn forward_mutating_project_request<T>(
2148 request: T,
2149 response: Response<T>,
2150 session: Session,
2151) -> Result<()>
2152where
2153 T: EntityMessage + RequestMessage,
2154{
2155 let project_id = ProjectId::from_proto(request.remote_entity_id());
2156
2157 let host_connection_id = session
2158 .db()
2159 .await
2160 .host_for_mutating_project_request(project_id, session.connection_id)
2161 .await?;
2162 let payload = session
2163 .peer
2164 .forward_request(session.connection_id, host_connection_id, request)
2165 .await?;
2166 response.send(payload)?;
2167 Ok(())
2168}
2169
2170/// Notify other participants that a new buffer has been created
2171async fn create_buffer_for_peer(
2172 request: proto::CreateBufferForPeer,
2173 session: Session,
2174) -> Result<()> {
2175 session
2176 .db()
2177 .await
2178 .check_user_is_project_host(
2179 ProjectId::from_proto(request.project_id),
2180 session.connection_id,
2181 )
2182 .await?;
2183 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2184 session
2185 .peer
2186 .forward_send(session.connection_id, peer_id.into(), request)?;
2187 Ok(())
2188}
2189
2190/// Notify other participants that a buffer has been updated. This is
2191/// allowed for guests as long as the update is limited to selections.
2192async fn update_buffer(
2193 request: proto::UpdateBuffer,
2194 response: Response<proto::UpdateBuffer>,
2195 session: Session,
2196) -> Result<()> {
2197 let project_id = ProjectId::from_proto(request.project_id);
2198 let mut capability = Capability::ReadOnly;
2199
2200 for op in request.operations.iter() {
2201 match op.variant {
2202 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2203 Some(_) => capability = Capability::ReadWrite,
2204 }
2205 }
2206
2207 let host = {
2208 let guard = session
2209 .db()
2210 .await
2211 .connections_for_buffer_update(project_id, session.connection_id, capability)
2212 .await?;
2213
2214 let (host, guests) = &*guard;
2215
2216 broadcast(
2217 Some(session.connection_id),
2218 guests.clone(),
2219 |connection_id| {
2220 session
2221 .peer
2222 .forward_send(session.connection_id, connection_id, request.clone())
2223 },
2224 );
2225
2226 *host
2227 };
2228
2229 if host != session.connection_id {
2230 session
2231 .peer
2232 .forward_request(session.connection_id, host, request.clone())
2233 .await?;
2234 }
2235
2236 response.send(proto::Ack {})?;
2237 Ok(())
2238}
2239
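/// Forward a context update to the host and the other guests of the project.
/// Selection-only buffer operations are treated as read-only; any other operation
/// requires write access.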
2240async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2241 let project_id = ProjectId::from_proto(message.project_id);
2242
2243 let operation = message.operation.as_ref().context("invalid operation")?;
2244 let capability = match operation.variant.as_ref() {
2245 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2246 if let Some(buffer_op) = buffer_op.operation.as_ref() {
2247 match buffer_op.variant {
2248 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2249 Capability::ReadOnly
2250 }
2251 _ => Capability::ReadWrite,
2252 }
2253 } else {
2254 Capability::ReadWrite
2255 }
2256 }
2257 Some(_) => Capability::ReadWrite,
2258 None => Capability::ReadOnly,
2259 };
2260
2261 let guard = session
2262 .db()
2263 .await
2264 .connections_for_buffer_update(project_id, session.connection_id, capability)
2265 .await?;
2266
2267 let (host, guests) = &*guard;
2268
2269 broadcast(
2270 Some(session.connection_id),
2271 guests.iter().chain([host]).copied(),
2272 |connection_id| {
2273 session
2274 .peer
2275 .forward_send(session.connection_id, connection_id, message.clone())
2276 },
2277 );
2278
2279 Ok(())
2280}
2281
2282/// Notify other participants that a project has been updated.
2283async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2284 request: T,
2285 session: Session,
2286) -> Result<()> {
2287 let project_id = ProjectId::from_proto(request.remote_entity_id());
2288 let project_connection_ids = session
2289 .db()
2290 .await
2291 .project_connection_ids(project_id, session.connection_id, false)
2292 .await?;
2293
2294 broadcast(
2295 Some(session.connection_id),
2296 project_connection_ids.iter().copied(),
2297 |connection_id| {
2298 session
2299 .peer
2300 .forward_send(session.connection_id, connection_id, request.clone())
2301 },
2302 );
2303 Ok(())
2304}
2305
2306/// Start following another user in a call.
2307async fn follow(
2308 request: proto::Follow,
2309 response: Response<proto::Follow>,
2310 session: Session,
2311) -> Result<()> {
2312 let room_id = RoomId::from_proto(request.room_id);
2313 let project_id = request.project_id.map(ProjectId::from_proto);
2314 let leader_id = request
2315 .leader_id
2316 .ok_or_else(|| anyhow!("invalid leader id"))?
2317 .into();
2318 let follower_id = session.connection_id;
2319
2320 session
2321 .db()
2322 .await
2323 .check_room_participants(room_id, leader_id, session.connection_id)
2324 .await?;
2325
2326 let response_payload = session
2327 .peer
2328 .forward_request(session.connection_id, leader_id, request)
2329 .await?;
2330 response.send(response_payload)?;
2331
2332 if let Some(project_id) = project_id {
2333 let room = session
2334 .db()
2335 .await
2336 .follow(room_id, project_id, leader_id, follower_id)
2337 .await?;
2338 room_updated(&room, &session.peer);
2339 }
2340
2341 Ok(())
2342}
2343
2344/// Stop following another user in a call.
2345async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2346 let room_id = RoomId::from_proto(request.room_id);
2347 let project_id = request.project_id.map(ProjectId::from_proto);
2348 let leader_id = request
2349 .leader_id
2350 .ok_or_else(|| anyhow!("invalid leader id"))?
2351 .into();
2352 let follower_id = session.connection_id;
2353
2354 session
2355 .db()
2356 .await
2357 .check_room_participants(room_id, leader_id, session.connection_id)
2358 .await?;
2359
2360 session
2361 .peer
2362 .forward_send(session.connection_id, leader_id, request)?;
2363
2364 if let Some(project_id) = project_id {
2365 let room = session
2366 .db()
2367 .await
2368 .unfollow(room_id, project_id, leader_id, follower_id)
2369 .await?;
2370 room_updated(&room, &session.peer);
2371 }
2372
2373 Ok(())
2374}
2375
2376/// Notify everyone following you of your current location.
2377async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2378 let room_id = RoomId::from_proto(request.room_id);
2379 let database = session.db.lock().await;
2380
2381 let connection_ids = if let Some(project_id) = request.project_id {
2382 let project_id = ProjectId::from_proto(project_id);
2383 database
2384 .project_connection_ids(project_id, session.connection_id, true)
2385 .await?
2386 } else {
2387 database
2388 .room_connection_ids(room_id, session.connection_id)
2389 .await?
2390 };
2391
2392 // For now, don't send view update messages back to that view's current leader.
2393 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2394 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2395 _ => None,
2396 });
2397
2398 for connection_id in connection_ids.iter().cloned() {
2399 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2400 session
2401 .peer
2402 .forward_send(session.connection_id, connection_id, request.clone())?;
2403 }
2404 }
2405 Ok(())
2406}
2407
2408/// Get public data about users.
2409async fn get_users(
2410 request: proto::GetUsers,
2411 response: Response<proto::GetUsers>,
2412 session: Session,
2413) -> Result<()> {
2414 let user_ids = request
2415 .user_ids
2416 .into_iter()
2417 .map(UserId::from_proto)
2418 .collect();
2419 let users = session
2420 .db()
2421 .await
2422 .get_users_by_ids(user_ids)
2423 .await?
2424 .into_iter()
2425 .map(|user| proto::User {
2426 id: user.id.to_proto(),
2427 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2428 github_login: user.github_login,
2429 email: user.email_address,
2430 name: user.name,
2431 })
2432 .collect();
2433 response.send(proto::UsersResponse { users })?;
2434 Ok(())
2435}
2436
2437/// Search for users (to invite) by GitHub login.
2438async fn fuzzy_search_users(
2439 request: proto::FuzzySearchUsers,
2440 response: Response<proto::FuzzySearchUsers>,
2441 session: Session,
2442) -> Result<()> {
2443 let query = request.query;
2444 let users = match query.len() {
2445 0 => vec![],
2446 1 | 2 => session
2447 .db()
2448 .await
2449 .get_user_by_github_login(&query)
2450 .await?
2451 .into_iter()
2452 .collect(),
2453 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2454 };
2455 let users = users
2456 .into_iter()
2457 .filter(|user| user.id != session.user_id())
2458 .map(|user| proto::User {
2459 id: user.id.to_proto(),
2460 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2461 github_login: user.github_login,
2462 name: user.name,
2463 email: user.email_address,
2464 })
2465 .collect();
2466 response.send(proto::UsersResponse { users })?;
2467 Ok(())
2468}
2469
2470/// Send a contact request to another user.
2471async fn request_contact(
2472 request: proto::RequestContact,
2473 response: Response<proto::RequestContact>,
2474 session: Session,
2475) -> Result<()> {
2476 let requester_id = session.user_id();
2477 let responder_id = UserId::from_proto(request.responder_id);
2478 if requester_id == responder_id {
2479 return Err(anyhow!("cannot add yourself as a contact"))?;
2480 }
2481
2482 let notifications = session
2483 .db()
2484 .await
2485 .send_contact_request(requester_id, responder_id)
2486 .await?;
2487
2488 // Update outgoing contact requests of requester
2489 let mut update = proto::UpdateContacts::default();
2490 update.outgoing_requests.push(responder_id.to_proto());
2491 for connection_id in session
2492 .connection_pool()
2493 .await
2494 .user_connection_ids(requester_id)
2495 {
2496 session.peer.send(connection_id, update.clone())?;
2497 }
2498
2499 // Update incoming contact requests of responder
2500 let mut update = proto::UpdateContacts::default();
2501 update
2502 .incoming_requests
2503 .push(proto::IncomingContactRequest {
2504 requester_id: requester_id.to_proto(),
2505 });
2506 let connection_pool = session.connection_pool().await;
2507 for connection_id in connection_pool.user_connection_ids(responder_id) {
2508 session.peer.send(connection_id, update.clone())?;
2509 }
2510
2511 send_notifications(&connection_pool, &session.peer, notifications);
2512
2513 response.send(proto::Ack {})?;
2514 Ok(())
2515}
2516
2517/// Accept or decline a contact request
2518async fn respond_to_contact_request(
2519 request: proto::RespondToContactRequest,
2520 response: Response<proto::RespondToContactRequest>,
2521 session: Session,
2522) -> Result<()> {
2523 let responder_id = session.user_id();
2524 let requester_id = UserId::from_proto(request.requester_id);
2525 let db = session.db().await;
2526 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2527 db.dismiss_contact_notification(responder_id, requester_id)
2528 .await?;
2529 } else {
2530 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2531
2532 let notifications = db
2533 .respond_to_contact_request(responder_id, requester_id, accept)
2534 .await?;
2535 let requester_busy = db.is_user_busy(requester_id).await?;
2536 let responder_busy = db.is_user_busy(responder_id).await?;
2537
2538 let pool = session.connection_pool().await;
2539 // Update responder with new contact
2540 let mut update = proto::UpdateContacts::default();
2541 if accept {
2542 update
2543 .contacts
2544 .push(contact_for_user(requester_id, requester_busy, &pool));
2545 }
2546 update
2547 .remove_incoming_requests
2548 .push(requester_id.to_proto());
2549 for connection_id in pool.user_connection_ids(responder_id) {
2550 session.peer.send(connection_id, update.clone())?;
2551 }
2552
2553 // Update requester with new contact
2554 let mut update = proto::UpdateContacts::default();
2555 if accept {
2556 update
2557 .contacts
2558 .push(contact_for_user(responder_id, responder_busy, &pool));
2559 }
2560 update
2561 .remove_outgoing_requests
2562 .push(responder_id.to_proto());
2563
2564 for connection_id in pool.user_connection_ids(requester_id) {
2565 session.peer.send(connection_id, update.clone())?;
2566 }
2567
2568 send_notifications(&pool, &session.peer, notifications);
2569 }
2570
2571 response.send(proto::Ack {})?;
2572 Ok(())
2573}
2574
2575/// Remove a contact.
2576async fn remove_contact(
2577 request: proto::RemoveContact,
2578 response: Response<proto::RemoveContact>,
2579 session: Session,
2580) -> Result<()> {
2581 let requester_id = session.user_id();
2582 let responder_id = UserId::from_proto(request.user_id);
2583 let db = session.db().await;
2584 let (contact_accepted, deleted_notification_id) =
2585 db.remove_contact(requester_id, responder_id).await?;
2586
2587 let pool = session.connection_pool().await;
2588 // Update outgoing contact requests of requester
2589 let mut update = proto::UpdateContacts::default();
2590 if contact_accepted {
2591 update.remove_contacts.push(responder_id.to_proto());
2592 } else {
2593 update
2594 .remove_outgoing_requests
2595 .push(responder_id.to_proto());
2596 }
2597 for connection_id in pool.user_connection_ids(requester_id) {
2598 session.peer.send(connection_id, update.clone())?;
2599 }
2600
2601 // Update incoming contact requests of responder
2602 let mut update = proto::UpdateContacts::default();
2603 if contact_accepted {
2604 update.remove_contacts.push(requester_id.to_proto());
2605 } else {
2606 update
2607 .remove_incoming_requests
2608 .push(requester_id.to_proto());
2609 }
2610 for connection_id in pool.user_connection_ids(responder_id) {
2611 session.peer.send(connection_id, update.clone())?;
2612 if let Some(notification_id) = deleted_notification_id {
2613 session.peer.send(
2614 connection_id,
2615 proto::DeleteNotification {
2616 notification_id: notification_id.to_proto(),
2617 },
2618 )?;
2619 }
2620 }
2621
2622 response.send(proto::Ack {})?;
2623 Ok(())
2624}
2625
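/// Whether the server should subscribe this client to its channels automatically.
/// Newer clients request this explicitly via `SubscribeToChannels`.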
2626fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2627 version.0.minor() < 139
2628}
2629
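/// Push the user's current plan (Free or Zed Pro) to their connection.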
2630async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
2631 let plan = session.current_plan(&session.db().await).await?;
2632
2633 session
2634 .peer
2635 .send(
2636 session.connection_id,
2637 proto::UpdateUserPlan { plan: plan.into() },
2638 )
2639 .trace_err();
2640
2641 Ok(())
2642}
2643
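/// Subscribe the requesting user to updates for all of their channels.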
2644async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2645 subscribe_user_to_channels(session.user_id(), &session).await?;
2646 Ok(())
2647}
2648
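/// Register the user's channel memberships with the connection pool and send the
/// initial channel state down their connection.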
2649async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2650 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2651 let mut pool = session.connection_pool().await;
2652 for membership in &channels_for_user.channel_memberships {
2653 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2654 }
2655 session.peer.send(
2656 session.connection_id,
2657 build_update_user_channels(&channels_for_user),
2658 )?;
2659 session.peer.send(
2660 session.connection_id,
2661 build_channels_update(channels_for_user),
2662 )?;
2663 Ok(())
2664}
2665
2666/// Creates a new channel.
2667async fn create_channel(
2668 request: proto::CreateChannel,
2669 response: Response<proto::CreateChannel>,
2670 session: Session,
2671) -> Result<()> {
2672 let db = session.db().await;
2673
2674 let parent_id = request.parent_id.map(ChannelId::from_proto);
2675 let (channel, membership) = db
2676 .create_channel(&request.name, parent_id, session.user_id())
2677 .await?;
2678
2679 let root_id = channel.root_id();
2680 let channel = Channel::from_model(channel);
2681
2682 response.send(proto::CreateChannelResponse {
2683 channel: Some(channel.to_proto()),
2684 parent_id: request.parent_id,
2685 })?;
2686
2687 let mut connection_pool = session.connection_pool().await;
2688 if let Some(membership) = membership {
2689 connection_pool.subscribe_to_channel(
2690 membership.user_id,
2691 membership.channel_id,
2692 membership.role,
2693 );
2694 let update = proto::UpdateUserChannels {
2695 channel_memberships: vec![proto::ChannelMembership {
2696 channel_id: membership.channel_id.to_proto(),
2697 role: membership.role.into(),
2698 }],
2699 ..Default::default()
2700 };
2701 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2702 session.peer.send(connection_id, update.clone())?;
2703 }
2704 }
2705
2706 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2707 if !role.can_see_channel(channel.visibility) {
2708 continue;
2709 }
2710
2711 let update = proto::UpdateChannels {
2712 channels: vec![channel.to_proto()],
2713 ..Default::default()
2714 };
2715 session.peer.send(connection_id, update.clone())?;
2716 }
2717
2718 Ok(())
2719}
2720
2721/// Delete a channel
2722async fn delete_channel(
2723 request: proto::DeleteChannel,
2724 response: Response<proto::DeleteChannel>,
2725 session: Session,
2726) -> Result<()> {
2727 let db = session.db().await;
2728
2729 let channel_id = request.channel_id;
2730 let (root_channel, removed_channels) = db
2731 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2732 .await?;
2733 response.send(proto::Ack {})?;
2734
2735 // Notify members of removed channels
2736 let mut update = proto::UpdateChannels::default();
2737 update
2738 .delete_channels
2739 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2740
2741 let connection_pool = session.connection_pool().await;
2742 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2743 session.peer.send(connection_id, update.clone())?;
2744 }
2745
2746 Ok(())
2747}
2748
2749/// Invite someone to join a channel.
2750async fn invite_channel_member(
2751 request: proto::InviteChannelMember,
2752 response: Response<proto::InviteChannelMember>,
2753 session: Session,
2754) -> Result<()> {
2755 let db = session.db().await;
2756 let channel_id = ChannelId::from_proto(request.channel_id);
2757 let invitee_id = UserId::from_proto(request.user_id);
2758 let InviteMemberResult {
2759 channel,
2760 notifications,
2761 } = db
2762 .invite_channel_member(
2763 channel_id,
2764 invitee_id,
2765 session.user_id(),
2766 request.role().into(),
2767 )
2768 .await?;
2769
2770 let update = proto::UpdateChannels {
2771 channel_invitations: vec![channel.to_proto()],
2772 ..Default::default()
2773 };
2774
2775 let connection_pool = session.connection_pool().await;
2776 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2777 session.peer.send(connection_id, update.clone())?;
2778 }
2779
2780 send_notifications(&connection_pool, &session.peer, notifications);
2781
2782 response.send(proto::Ack {})?;
2783 Ok(())
2784}
2785
2786/// Remove someone from a channel.
2787async fn remove_channel_member(
2788 request: proto::RemoveChannelMember,
2789 response: Response<proto::RemoveChannelMember>,
2790 session: Session,
2791) -> Result<()> {
2792 let db = session.db().await;
2793 let channel_id = ChannelId::from_proto(request.channel_id);
2794 let member_id = UserId::from_proto(request.user_id);
2795
2796 let RemoveChannelMemberResult {
2797 membership_update,
2798 notification_id,
2799 } = db
2800 .remove_channel_member(channel_id, member_id, session.user_id())
2801 .await?;
2802
2803 let mut connection_pool = session.connection_pool().await;
2804 notify_membership_updated(
2805 &mut connection_pool,
2806 membership_update,
2807 member_id,
2808 &session.peer,
2809 );
2810 for connection_id in connection_pool.user_connection_ids(member_id) {
2811 if let Some(notification_id) = notification_id {
2812 session
2813 .peer
2814 .send(
2815 connection_id,
2816 proto::DeleteNotification {
2817 notification_id: notification_id.to_proto(),
2818 },
2819 )
2820 .trace_err();
2821 }
2822 }
2823
2824 response.send(proto::Ack {})?;
2825 Ok(())
2826}
2827
2828/// Toggle the channel between public and private.
2829/// Care is taken to maintain the invariant that public channels only descend from public channels
2830/// (though members-only channels can appear at any point in the hierarchy).
2831async fn set_channel_visibility(
2832 request: proto::SetChannelVisibility,
2833 response: Response<proto::SetChannelVisibility>,
2834 session: Session,
2835) -> Result<()> {
2836 let db = session.db().await;
2837 let channel_id = ChannelId::from_proto(request.channel_id);
2838 let visibility = request.visibility().into();
2839
2840 let channel_model = db
2841 .set_channel_visibility(channel_id, visibility, session.user_id())
2842 .await?;
2843 let root_id = channel_model.root_id();
2844 let channel = Channel::from_model(channel_model);
2845
2846 let mut connection_pool = session.connection_pool().await;
2847 for (user_id, role) in connection_pool
2848 .channel_user_ids(root_id)
2849 .collect::<Vec<_>>()
2850 .into_iter()
2851 {
2852 let update = if role.can_see_channel(channel.visibility) {
2853 connection_pool.subscribe_to_channel(user_id, channel_id, role);
2854 proto::UpdateChannels {
2855 channels: vec![channel.to_proto()],
2856 ..Default::default()
2857 }
2858 } else {
2859 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2860 proto::UpdateChannels {
2861 delete_channels: vec![channel.id.to_proto()],
2862 ..Default::default()
2863 }
2864 };
2865
2866 for connection_id in connection_pool.user_connection_ids(user_id) {
2867 session.peer.send(connection_id, update.clone())?;
2868 }
2869 }
2870
2871 response.send(proto::Ack {})?;
2872 Ok(())
2873}
2874
2875/// Alter the role for a user in the channel.
2876async fn set_channel_member_role(
2877 request: proto::SetChannelMemberRole,
2878 response: Response<proto::SetChannelMemberRole>,
2879 session: Session,
2880) -> Result<()> {
2881 let db = session.db().await;
2882 let channel_id = ChannelId::from_proto(request.channel_id);
2883 let member_id = UserId::from_proto(request.user_id);
2884 let result = db
2885 .set_channel_member_role(
2886 channel_id,
2887 session.user_id(),
2888 member_id,
2889 request.role().into(),
2890 )
2891 .await?;
2892
2893 match result {
2894 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2895 let mut connection_pool = session.connection_pool().await;
2896 notify_membership_updated(
2897 &mut connection_pool,
2898 membership_update,
2899 member_id,
2900 &session.peer,
2901 )
2902 }
2903 db::SetMemberRoleResult::InviteUpdated(channel) => {
2904 let update = proto::UpdateChannels {
2905 channel_invitations: vec![channel.to_proto()],
2906 ..Default::default()
2907 };
2908
2909 for connection_id in session
2910 .connection_pool()
2911 .await
2912 .user_connection_ids(member_id)
2913 {
2914 session.peer.send(connection_id, update.clone())?;
2915 }
2916 }
2917 }
2918
2919 response.send(proto::Ack {})?;
2920 Ok(())
2921}
2922
2923/// Change the name of a channel
2924async fn rename_channel(
2925 request: proto::RenameChannel,
2926 response: Response<proto::RenameChannel>,
2927 session: Session,
2928) -> Result<()> {
2929 let db = session.db().await;
2930 let channel_id = ChannelId::from_proto(request.channel_id);
2931 let channel_model = db
2932 .rename_channel(channel_id, session.user_id(), &request.name)
2933 .await?;
2934 let root_id = channel_model.root_id();
2935 let channel = Channel::from_model(channel_model);
2936
2937 response.send(proto::RenameChannelResponse {
2938 channel: Some(channel.to_proto()),
2939 })?;
2940
2941 let connection_pool = session.connection_pool().await;
2942 let update = proto::UpdateChannels {
2943 channels: vec![channel.to_proto()],
2944 ..Default::default()
2945 };
2946 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2947 if role.can_see_channel(channel.visibility) {
2948 session.peer.send(connection_id, update.clone())?;
2949 }
2950 }
2951
2952 Ok(())
2953}
2954
2955/// Move a channel to a new parent.
2956async fn move_channel(
2957 request: proto::MoveChannel,
2958 response: Response<proto::MoveChannel>,
2959 session: Session,
2960) -> Result<()> {
2961 let channel_id = ChannelId::from_proto(request.channel_id);
2962 let to = ChannelId::from_proto(request.to);
2963
2964 let (root_id, channels) = session
2965 .db()
2966 .await
2967 .move_channel(channel_id, to, session.user_id())
2968 .await?;
2969
2970 let connection_pool = session.connection_pool().await;
2971 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2972 let channels = channels
2973 .iter()
2974 .filter_map(|channel| {
2975 if role.can_see_channel(channel.visibility) {
2976 Some(channel.to_proto())
2977 } else {
2978 None
2979 }
2980 })
2981 .collect::<Vec<_>>();
2982 if channels.is_empty() {
2983 continue;
2984 }
2985
2986 let update = proto::UpdateChannels {
2987 channels,
2988 ..Default::default()
2989 };
2990
2991 session.peer.send(connection_id, update.clone())?;
2992 }
2993
2994 response.send(Ack {})?;
2995 Ok(())
2996}
2997
2998/// Get the list of channel members
2999async fn get_channel_members(
3000 request: proto::GetChannelMembers,
3001 response: Response<proto::GetChannelMembers>,
3002 session: Session,
3003) -> Result<()> {
3004 let db = session.db().await;
3005 let channel_id = ChannelId::from_proto(request.channel_id);
3006 let limit = if request.limit == 0 {
3007 u16::MAX as u64
3008 } else {
3009 request.limit
3010 };
3011 let (members, users) = db
3012 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3013 .await?;
3014 response.send(proto::GetChannelMembersResponse { members, users })?;
3015 Ok(())
3016}
3017
3018/// Accept or decline a channel invitation.
3019async fn respond_to_channel_invite(
3020 request: proto::RespondToChannelInvite,
3021 response: Response<proto::RespondToChannelInvite>,
3022 session: Session,
3023) -> Result<()> {
3024 let db = session.db().await;
3025 let channel_id = ChannelId::from_proto(request.channel_id);
3026 let RespondToChannelInvite {
3027 membership_update,
3028 notifications,
3029 } = db
3030 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3031 .await?;
3032
3033 let mut connection_pool = session.connection_pool().await;
3034 if let Some(membership_update) = membership_update {
3035 notify_membership_updated(
3036 &mut connection_pool,
3037 membership_update,
3038 session.user_id(),
3039 &session.peer,
3040 );
3041 } else {
3042 let update = proto::UpdateChannels {
3043 remove_channel_invitations: vec![channel_id.to_proto()],
3044 ..Default::default()
3045 };
3046
3047 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3048 session.peer.send(connection_id, update.clone())?;
3049 }
3050 };
3051
3052 send_notifications(&connection_pool, &session.peer, notifications);
3053
3054 response.send(proto::Ack {})?;
3055
3056 Ok(())
3057}
3058
3059/// Join the channel's room.
3060async fn join_channel(
3061 request: proto::JoinChannel,
3062 response: Response<proto::JoinChannel>,
3063 session: Session,
3064) -> Result<()> {
3065 let channel_id = ChannelId::from_proto(request.channel_id);
3066 join_channel_internal(channel_id, Box::new(response), session).await
3067}
3068
3069trait JoinChannelInternalResponse {
3070 fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
3071}
3072impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
3073 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3074 Response::<proto::JoinChannel>::send(self, result)
3075 }
3076}
3077impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
3078 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3079 Response::<proto::JoinRoom>::send(self, result)
3080 }
3081}
3082
3083async fn join_channel_internal(
3084 channel_id: ChannelId,
3085 response: Box<impl JoinChannelInternalResponse>,
3086 session: Session,
3087) -> Result<()> {
3088 let joined_room = {
3089 let mut db = session.db().await;
3090        // If Zed quits without leaving the room and the user re-opens it before the
3091        // RECONNECT_TIMEOUT expires, we need to kick the user out of the room they were
3092        // previously in.
3093 if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
3094 tracing::info!(
3095 stale_connection_id = %connection,
3096 "cleaning up stale connection",
3097 );
3098 drop(db);
3099 leave_room_for_session(&session, connection).await?;
3100 db = session.db().await;
3101 }
3102
3103 let (joined_room, membership_updated, role) = db
3104 .join_channel(channel_id, session.user_id(), session.connection_id)
3105 .await?;
3106
3107 let live_kit_connection_info =
3108 session
3109 .app_state
3110 .livekit_client
3111 .as_ref()
3112 .and_then(|live_kit| {
3113 let (can_publish, token) = if role == ChannelRole::Guest {
3114 (
3115 false,
3116 live_kit
3117 .guest_token(
3118 &joined_room.room.livekit_room,
3119 &session.user_id().to_string(),
3120 )
3121 .trace_err()?,
3122 )
3123 } else {
3124 (
3125 true,
3126 live_kit
3127 .room_token(
3128 &joined_room.room.livekit_room,
3129 &session.user_id().to_string(),
3130 )
3131 .trace_err()?,
3132 )
3133 };
3134
3135 Some(LiveKitConnectionInfo {
3136 server_url: live_kit.url().into(),
3137 token,
3138 can_publish,
3139 })
3140 });
3141
3142 response.send(proto::JoinRoomResponse {
3143 room: Some(joined_room.room.clone()),
3144 channel_id: joined_room
3145 .channel
3146 .as_ref()
3147 .map(|channel| channel.id.to_proto()),
3148 live_kit_connection_info,
3149 })?;
3150
3151 let mut connection_pool = session.connection_pool().await;
3152 if let Some(membership_updated) = membership_updated {
3153 notify_membership_updated(
3154 &mut connection_pool,
3155 membership_updated,
3156 session.user_id(),
3157 &session.peer,
3158 );
3159 }
3160
3161 room_updated(&joined_room.room, &session.peer);
3162
3163 joined_room
3164 };
3165
3166 channel_updated(
3167 &joined_room
3168 .channel
3169 .ok_or_else(|| anyhow!("channel not returned"))?,
3170 &joined_room.room,
3171 &session.peer,
3172 &*session.connection_pool().await,
3173 );
3174
3175 update_user_contacts(session.user_id(), &session).await?;
3176 Ok(())
3177}
3178
3179/// Start editing the channel notes
3180async fn join_channel_buffer(
3181 request: proto::JoinChannelBuffer,
3182 response: Response<proto::JoinChannelBuffer>,
3183 session: Session,
3184) -> Result<()> {
3185 let db = session.db().await;
3186 let channel_id = ChannelId::from_proto(request.channel_id);
3187
3188 let open_response = db
3189 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3190 .await?;
3191
3192 let collaborators = open_response.collaborators.clone();
3193 response.send(open_response)?;
3194
3195 let update = UpdateChannelBufferCollaborators {
3196 channel_id: channel_id.to_proto(),
3197 collaborators: collaborators.clone(),
3198 };
3199 channel_buffer_updated(
3200 session.connection_id,
3201 collaborators
3202 .iter()
3203 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3204 &update,
3205 &session.peer,
3206 );
3207
3208 Ok(())
3209}
3210
3211/// Edit the channel notes
3212async fn update_channel_buffer(
3213 request: proto::UpdateChannelBuffer,
3214 session: Session,
3215) -> Result<()> {
3216 let db = session.db().await;
3217 let channel_id = ChannelId::from_proto(request.channel_id);
3218
3219 let (collaborators, epoch, version) = db
3220 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3221 .await?;
3222
3223 channel_buffer_updated(
3224 session.connection_id,
3225 collaborators.clone(),
3226 &proto::UpdateChannelBuffer {
3227 channel_id: channel_id.to_proto(),
3228 operations: request.operations,
3229 },
3230 &session.peer,
3231 );
3232
3233 let pool = &*session.connection_pool().await;
3234
3235 let non_collaborators =
3236 pool.channel_connection_ids(channel_id)
3237 .filter_map(|(connection_id, _)| {
3238 if collaborators.contains(&connection_id) {
3239 None
3240 } else {
3241 Some(connection_id)
3242 }
3243 });
3244
3245 broadcast(None, non_collaborators, |peer_id| {
3246 session.peer.send(
3247 peer_id,
3248 proto::UpdateChannels {
3249 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3250 channel_id: channel_id.to_proto(),
3251 epoch: epoch as u64,
3252 version: version.clone(),
3253 }],
3254 ..Default::default()
3255 },
3256 )
3257 });
3258
3259 Ok(())
3260}
3261
3262/// Rejoin the channel notes after a connection blip
3263async fn rejoin_channel_buffers(
3264 request: proto::RejoinChannelBuffers,
3265 response: Response<proto::RejoinChannelBuffers>,
3266 session: Session,
3267) -> Result<()> {
3268 let db = session.db().await;
3269 let buffers = db
3270 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3271 .await?;
3272
3273 for rejoined_buffer in &buffers {
3274 let collaborators_to_notify = rejoined_buffer
3275 .buffer
3276 .collaborators
3277 .iter()
3278 .filter_map(|c| Some(c.peer_id?.into()));
3279 channel_buffer_updated(
3280 session.connection_id,
3281 collaborators_to_notify,
3282 &proto::UpdateChannelBufferCollaborators {
3283 channel_id: rejoined_buffer.buffer.channel_id,
3284 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3285 },
3286 &session.peer,
3287 );
3288 }
3289
3290 response.send(proto::RejoinChannelBuffersResponse {
3291 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3292 })?;
3293
3294 Ok(())
3295}
3296
3297/// Stop editing the channel notes
3298async fn leave_channel_buffer(
3299 request: proto::LeaveChannelBuffer,
3300 response: Response<proto::LeaveChannelBuffer>,
3301 session: Session,
3302) -> Result<()> {
3303 let db = session.db().await;
3304 let channel_id = ChannelId::from_proto(request.channel_id);
3305
3306 let left_buffer = db
3307 .leave_channel_buffer(channel_id, session.connection_id)
3308 .await?;
3309
3310 response.send(Ack {})?;
3311
3312 channel_buffer_updated(
3313 session.connection_id,
3314 left_buffer.connections,
3315 &proto::UpdateChannelBufferCollaborators {
3316 channel_id: channel_id.to_proto(),
3317 collaborators: left_buffer.collaborators,
3318 },
3319 &session.peer,
3320 );
3321
3322 Ok(())
3323}
3324
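/// Broadcast a channel-buffer message to every collaborator except the sender.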
3325fn channel_buffer_updated<T: EnvelopedMessage>(
3326 sender_id: ConnectionId,
3327 collaborators: impl IntoIterator<Item = ConnectionId>,
3328 message: &T,
3329 peer: &Peer,
3330) {
3331 broadcast(Some(sender_id), collaborators, |peer_id| {
3332 peer.send(peer_id, message.clone())
3333 });
3334}
3335
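/// Deliver a batch of notifications to every connection of each recipient,
/// logging (rather than propagating) any failures to send.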
3336fn send_notifications(
3337 connection_pool: &ConnectionPool,
3338 peer: &Peer,
3339 notifications: db::NotificationBatch,
3340) {
3341 for (user_id, notification) in notifications {
3342 for connection_id in connection_pool.user_connection_ids(user_id) {
3343 if let Err(error) = peer.send(
3344 connection_id,
3345 proto::AddNotification {
3346 notification: Some(notification.clone()),
3347 },
3348 ) {
3349 tracing::error!(
3350 "failed to send notification to {:?} {}",
3351 connection_id,
3352 error
3353 );
3354 }
3355 }
3356 }
3357}
3358
3359/// Send a message to the channel
3360async fn send_channel_message(
3361 request: proto::SendChannelMessage,
3362 response: Response<proto::SendChannelMessage>,
3363 session: Session,
3364) -> Result<()> {
3365 // Validate the message body.
3366 let body = request.body.trim().to_string();
3367 if body.len() > MAX_MESSAGE_LEN {
3368 return Err(anyhow!("message is too long"))?;
3369 }
3370 if body.is_empty() {
3371 return Err(anyhow!("message can't be blank"))?;
3372 }
3373
3374 // TODO: adjust mentions if body is trimmed
3375
3376 let timestamp = OffsetDateTime::now_utc();
3377 let nonce = request
3378 .nonce
3379 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3380
3381 let channel_id = ChannelId::from_proto(request.channel_id);
3382 let CreatedChannelMessage {
3383 message_id,
3384 participant_connection_ids,
3385 notifications,
3386 } = session
3387 .db()
3388 .await
3389 .create_channel_message(
3390 channel_id,
3391 session.user_id(),
3392 &body,
3393 &request.mentions,
3394 timestamp,
3395 nonce.clone().into(),
3396 request.reply_to_message_id.map(MessageId::from_proto),
3397 )
3398 .await?;
3399
3400 let message = proto::ChannelMessage {
3401 sender_id: session.user_id().to_proto(),
3402 id: message_id.to_proto(),
3403 body,
3404 mentions: request.mentions,
3405 timestamp: timestamp.unix_timestamp() as u64,
3406 nonce: Some(nonce),
3407 reply_to_message_id: request.reply_to_message_id,
3408 edited_at: None,
3409 };
3410 broadcast(
3411 Some(session.connection_id),
3412 participant_connection_ids.clone(),
3413 |connection| {
3414 session.peer.send(
3415 connection,
3416 proto::ChannelMessageSent {
3417 channel_id: channel_id.to_proto(),
3418 message: Some(message.clone()),
3419 },
3420 )
3421 },
3422 );
3423 response.send(proto::SendChannelMessageResponse {
3424 message: Some(message),
3425 })?;
3426
3427 let pool = &*session.connection_pool().await;
3428 let non_participants =
3429 pool.channel_connection_ids(channel_id)
3430 .filter_map(|(connection_id, _)| {
3431 if participant_connection_ids.contains(&connection_id) {
3432 None
3433 } else {
3434 Some(connection_id)
3435 }
3436 });
3437 broadcast(None, non_participants, |peer_id| {
3438 session.peer.send(
3439 peer_id,
3440 proto::UpdateChannels {
3441 latest_channel_message_ids: vec![proto::ChannelMessageId {
3442 channel_id: channel_id.to_proto(),
3443 message_id: message_id.to_proto(),
3444 }],
3445 ..Default::default()
3446 },
3447 )
3448 });
3449 send_notifications(pool, &session.peer, notifications);
3450
3451 Ok(())
3452}
3453
3454/// Delete a channel message
3455async fn remove_channel_message(
3456 request: proto::RemoveChannelMessage,
3457 response: Response<proto::RemoveChannelMessage>,
3458 session: Session,
3459) -> Result<()> {
3460 let channel_id = ChannelId::from_proto(request.channel_id);
3461 let message_id = MessageId::from_proto(request.message_id);
3462 let (connection_ids, existing_notification_ids) = session
3463 .db()
3464 .await
3465 .remove_channel_message(channel_id, message_id, session.user_id())
3466 .await?;
3467
3468 broadcast(
3469 Some(session.connection_id),
3470 connection_ids,
3471 move |connection| {
3472 session.peer.send(connection, request.clone())?;
3473
3474 for notification_id in &existing_notification_ids {
3475 session.peer.send(
3476 connection,
3477 proto::DeleteNotification {
3478 notification_id: (*notification_id).to_proto(),
3479 },
3480 )?;
3481 }
3482
3483 Ok(())
3484 },
3485 );
3486 response.send(proto::Ack {})?;
3487 Ok(())
3488}
3489
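/// Edit a previously sent channel message, updating any mention notifications
/// that were created for it.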
3490async fn update_channel_message(
3491 request: proto::UpdateChannelMessage,
3492 response: Response<proto::UpdateChannelMessage>,
3493 session: Session,
3494) -> Result<()> {
3495 let channel_id = ChannelId::from_proto(request.channel_id);
3496 let message_id = MessageId::from_proto(request.message_id);
3497 let updated_at = OffsetDateTime::now_utc();
3498 let UpdatedChannelMessage {
3499 message_id,
3500 participant_connection_ids,
3501 notifications,
3502 reply_to_message_id,
3503 timestamp,
3504 deleted_mention_notification_ids,
3505 updated_mention_notifications,
3506 } = session
3507 .db()
3508 .await
3509 .update_channel_message(
3510 channel_id,
3511 message_id,
3512 session.user_id(),
3513 request.body.as_str(),
3514 &request.mentions,
3515 updated_at,
3516 )
3517 .await?;
3518
3519 let nonce = request
3520 .nonce
3521 .clone()
3522 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3523
3524 let message = proto::ChannelMessage {
3525 sender_id: session.user_id().to_proto(),
3526 id: message_id.to_proto(),
3527 body: request.body.clone(),
3528 mentions: request.mentions.clone(),
3529 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3530 nonce: Some(nonce),
3531 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3532 edited_at: Some(updated_at.unix_timestamp() as u64),
3533 };
3534
3535 response.send(proto::Ack {})?;
3536
3537 let pool = &*session.connection_pool().await;
3538 broadcast(
3539 Some(session.connection_id),
3540 participant_connection_ids,
3541 |connection| {
3542 session.peer.send(
3543 connection,
3544 proto::ChannelMessageUpdate {
3545 channel_id: channel_id.to_proto(),
3546 message: Some(message.clone()),
3547 },
3548 )?;
3549
3550 for notification_id in &deleted_mention_notification_ids {
3551 session.peer.send(
3552 connection,
3553 proto::DeleteNotification {
3554 notification_id: (*notification_id).to_proto(),
3555 },
3556 )?;
3557 }
3558
3559 for notification in &updated_mention_notifications {
3560 session.peer.send(
3561 connection,
3562 proto::UpdateNotification {
3563 notification: Some(notification.clone()),
3564 },
3565 )?;
3566 }
3567
3568 Ok(())
3569 },
3570 );
3571
3572 send_notifications(pool, &session.peer, notifications);
3573
3574 Ok(())
3575}
3576
3577/// Mark a channel message as read
3578async fn acknowledge_channel_message(
3579 request: proto::AckChannelMessage,
3580 session: Session,
3581) -> Result<()> {
3582 let channel_id = ChannelId::from_proto(request.channel_id);
3583 let message_id = MessageId::from_proto(request.message_id);
3584 let notifications = session
3585 .db()
3586 .await
3587 .observe_channel_message(channel_id, session.user_id(), message_id)
3588 .await?;
3589 send_notifications(
3590 &*session.connection_pool().await,
3591 &session.peer,
3592 notifications,
3593 );
3594 Ok(())
3595}
3596
3597/// Mark a buffer version as synced
3598async fn acknowledge_buffer_version(
3599 request: proto::AckBufferOperation,
3600 session: Session,
3601) -> Result<()> {
3602 let buffer_id = BufferId::from_proto(request.buffer_id);
3603 session
3604 .db()
3605 .await
3606 .observe_buffer_version(
3607 buffer_id,
3608 session.user_id(),
3609 request.epoch as i32,
3610 &request.version,
3611 )
3612 .await?;
3613 Ok(())
3614}
3615
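/// Count the tokens in a language model request on behalf of the client.
/// Currently only the Google AI provider is supported, and requests are rate
/// limited per plan.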
3616async fn count_language_model_tokens(
3617 request: proto::CountLanguageModelTokens,
3618 response: Response<proto::CountLanguageModelTokens>,
3619 session: Session,
3620 config: &Config,
3621) -> Result<()> {
3622 authorize_access_to_legacy_llm_endpoints(&session).await?;
3623
3624 let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3625 proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
3626 proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
3627 };
3628
3629 session
3630 .app_state
3631 .rate_limiter
3632 .check(&*rate_limit, session.user_id())
3633 .await?;
3634
3635 let result = match proto::LanguageModelProvider::from_i32(request.provider) {
3636 Some(proto::LanguageModelProvider::Google) => {
3637 let api_key = config
3638 .google_ai_api_key
3639 .as_ref()
3640 .context("no Google AI API key configured on the server")?;
3641 google_ai::count_tokens(
3642 session.http_client.as_ref(),
3643 google_ai::API_URL,
3644 api_key,
3645 serde_json::from_str(&request.request)?,
3646 )
3647 .await?
3648 }
3649 _ => return Err(anyhow!("unsupported provider"))?,
3650 };
3651
3652 response.send(proto::CountLanguageModelTokensResponse {
3653 token_count: result.total_tokens as u32,
3654 })?;
3655
3656 Ok(())
3657}
3658
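/// Hourly rate limit on `CountLanguageModelTokens` requests for Zed Pro users.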
3659struct ZedProCountLanguageModelTokensRateLimit;
3660
3661impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
3662 fn capacity(&self) -> usize {
3663 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
3664 .ok()
3665 .and_then(|v| v.parse().ok())
3666 .unwrap_or(600) // Picked arbitrarily
3667 }
3668
3669 fn refill_duration(&self) -> chrono::Duration {
3670 chrono::Duration::hours(1)
3671 }
3672
3673 fn db_name(&self) -> &'static str {
3674 "zed-pro:count-language-model-tokens"
3675 }
3676}
3677
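/// Hourly rate limit on `CountLanguageModelTokens` requests for free-plan users.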
3678struct FreeCountLanguageModelTokensRateLimit;
3679
3680impl RateLimit for FreeCountLanguageModelTokensRateLimit {
3681 fn capacity(&self) -> usize {
3682 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
3683 .ok()
3684 .and_then(|v| v.parse().ok())
3685 .unwrap_or(600 / 10) // Picked arbitrarily
3686 }
3687
3688 fn refill_duration(&self) -> chrono::Duration {
3689 chrono::Duration::hours(1)
3690 }
3691
3692 fn db_name(&self) -> &'static str {
3693 "free:count-language-model-tokens"
3694 }
3695}
3696
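/// Hourly rate limit on `ComputeEmbeddings` requests for Zed Pro users.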
3697struct ZedProComputeEmbeddingsRateLimit;
3698
3699impl RateLimit for ZedProComputeEmbeddingsRateLimit {
3700 fn capacity(&self) -> usize {
3701 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
3702 .ok()
3703 .and_then(|v| v.parse().ok())
3704 .unwrap_or(5000) // Picked arbitrarily
3705 }
3706
3707 fn refill_duration(&self) -> chrono::Duration {
3708 chrono::Duration::hours(1)
3709 }
3710
3711 fn db_name(&self) -> &'static str {
3712 "zed-pro:compute-embeddings"
3713 }
3714}
3715
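/// Hourly rate limit on `ComputeEmbeddings` requests for free-plan users.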
3716struct FreeComputeEmbeddingsRateLimit;
3717
3718impl RateLimit for FreeComputeEmbeddingsRateLimit {
3719 fn capacity(&self) -> usize {
3720 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
3721 .ok()
3722 .and_then(|v| v.parse().ok())
3723 .unwrap_or(5000 / 10) // Picked arbitrarily
3724 }
3725
3726 fn refill_duration(&self) -> chrono::Duration {
3727 chrono::Duration::hours(1)
3728 }
3729
3730 fn db_name(&self) -> &'static str {
3731 "free:compute-embeddings"
3732 }
3733}
3734
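/// Compute embeddings for the given texts with OpenAI, cache them in the database
/// keyed by the SHA-256 digest of each text, and return them to the client.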
3735async fn compute_embeddings(
3736 request: proto::ComputeEmbeddings,
3737 response: Response<proto::ComputeEmbeddings>,
3738 session: Session,
3739 api_key: Option<Arc<str>>,
3740) -> Result<()> {
3741 let api_key = api_key.context("no OpenAI API key configured on the server")?;
3742 authorize_access_to_legacy_llm_endpoints(&session).await?;
3743
3744 let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3745 proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
3746 proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
3747 };
3748
3749 session
3750 .app_state
3751 .rate_limiter
3752 .check(&*rate_limit, session.user_id())
3753 .await?;
3754
3755 let embeddings = match request.model.as_str() {
3756 "openai/text-embedding-3-small" => {
3757 open_ai::embed(
3758 session.http_client.as_ref(),
3759 OPEN_AI_API_URL,
3760 &api_key,
3761 OpenAiEmbeddingModel::TextEmbedding3Small,
3762 request.texts.iter().map(|text| text.as_str()),
3763 )
3764 .await?
3765 }
3766 provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
3767 };
3768
3769 let embeddings = request
3770 .texts
3771 .iter()
3772 .map(|text| {
3773 let mut hasher = sha2::Sha256::new();
3774 hasher.update(text.as_bytes());
3775 let result = hasher.finalize();
3776 result.to_vec()
3777 })
3778 .zip(
3779 embeddings
3780 .data
3781 .into_iter()
3782 .map(|embedding| embedding.embedding),
3783 )
3784 .collect::<HashMap<_, _>>();
3785
3786 let db = session.db().await;
3787 db.save_embeddings(&request.model, &embeddings)
3788 .await
3789 .context("failed to save embeddings")
3790 .trace_err();
3791
3792 response.send(proto::ComputeEmbeddingsResponse {
3793 embeddings: embeddings
3794 .into_iter()
3795 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3796 .collect(),
3797 })?;
3798 Ok(())
3799}
3800
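/// Look up previously computed embeddings by their digests.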
3801async fn get_cached_embeddings(
3802 request: proto::GetCachedEmbeddings,
3803 response: Response<proto::GetCachedEmbeddings>,
3804 session: Session,
3805) -> Result<()> {
3806 authorize_access_to_legacy_llm_endpoints(&session).await?;
3807
3808 let db = session.db().await;
3809 let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
3810
3811 response.send(proto::GetCachedEmbeddingsResponse {
3812 embeddings: embeddings
3813 .into_iter()
3814 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3815 .collect(),
3816 })?;
3817 Ok(())
3818}
3819
3820/// This is leftover from before the LLM service.
3821///
3822/// The endpoints protected by this check will be moved there eventually.
3823async fn authorize_access_to_legacy_llm_endpoints(session: &Session) -> Result<(), Error> {
3824 if session.is_staff() {
3825 Ok(())
3826 } else {
3827 Err(anyhow!("permission denied"))?
3828 }
3829}
3830
3831/// Get a Supermaven API key for the user
3832async fn get_supermaven_api_key(
3833 _request: proto::GetSupermavenApiKey,
3834 response: Response<proto::GetSupermavenApiKey>,
3835 session: Session,
3836) -> Result<()> {
3837 let user_id: String = session.user_id().to_string();
3838 if !session.is_staff() {
3839 return Err(anyhow!("supermaven not enabled for this account"))?;
3840 }
3841
3842 let email = session
3843 .email()
3844 .ok_or_else(|| anyhow!("user must have an email"))?;
3845
3846 let supermaven_admin_api = session
3847 .supermaven_client
3848 .as_ref()
3849 .ok_or_else(|| anyhow!("supermaven not configured"))?;
3850
3851 let result = supermaven_admin_api
3852 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3853 .await?;
3854
3855 response.send(proto::GetSupermavenApiKeyResponse {
3856 api_key: result.api_key,
3857 })?;
3858
3859 Ok(())
3860}
3861
/// Start receiving chat updates for a channel
async fn join_channel_chat(
    request: proto::JoinChannelChat,
    response: Response<proto::JoinChannelChat>,
    session: Session,
) -> Result<()> {
    let channel_id = ChannelId::from_proto(request.channel_id);

    let db = session.db().await;
    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
        .await?;
    let messages = db
        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
        .await?;
    response.send(proto::JoinChannelChatResponse {
        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
        messages,
    })?;
    Ok(())
}

/// Stop receiving chat updates for a channel
async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
    let channel_id = ChannelId::from_proto(request.channel_id);
    session
        .db()
        .await
        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
        .await?;
    Ok(())
}

/// Retrieve the chat history for a channel
async fn get_channel_messages(
    request: proto::GetChannelMessages,
    response: Response<proto::GetChannelMessages>,
    session: Session,
) -> Result<()> {
    let channel_id = ChannelId::from_proto(request.channel_id);
    let messages = session
        .db()
        .await
        .get_channel_messages(
            channel_id,
            session.user_id(),
            MESSAGE_COUNT_PER_PAGE,
            Some(MessageId::from_proto(request.before_message_id)),
        )
        .await?;
    response.send(proto::GetChannelMessagesResponse {
        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
        messages,
    })?;
    Ok(())
}

/// Retrieve specific chat messages by their IDs
async fn get_channel_messages_by_id(
    request: proto::GetChannelMessagesById,
    response: Response<proto::GetChannelMessagesById>,
    session: Session,
) -> Result<()> {
    let message_ids = request
        .message_ids
        .iter()
        .map(|id| MessageId::from_proto(*id))
        .collect::<Vec<_>>();
    let messages = session
        .db()
        .await
        .get_channel_messages_by_id(session.user_id(), &message_ids)
        .await?;
    response.send(proto::GetChannelMessagesResponse {
        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
        messages,
    })?;
    Ok(())
}

/// Retrieve the current user's notifications
async fn get_notifications(
    request: proto::GetNotifications,
    response: Response<proto::GetNotifications>,
    session: Session,
) -> Result<()> {
    let notifications = session
        .db()
        .await
        .get_notifications(
            session.user_id(),
            NOTIFICATION_COUNT_PER_PAGE,
            request.before_id.map(db::NotificationId::from_proto),
        )
        .await?;
    response.send(proto::GetNotificationsResponse {
        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
        notifications,
    })?;
    Ok(())
}

/// Mark a notification as read
async fn mark_notification_as_read(
    request: proto::MarkNotificationRead,
    response: Response<proto::MarkNotificationRead>,
    session: Session,
) -> Result<()> {
    let database = &session.db().await;
    let notifications = database
        .mark_notification_as_read_by_id(
            session.user_id(),
            NotificationId::from_proto(request.notification_id),
        )
        .await?;
    send_notifications(
        &*session.connection_pool().await,
        &session.peer,
        notifications,
    );
    response.send(proto::Ack {})?;
    Ok(())
}

/// Get the current user's information
async fn get_private_user_info(
    _request: proto::GetPrivateUserInfo,
    response: Response<proto::GetPrivateUserInfo>,
    session: Session,
) -> Result<()> {
    let db = session.db().await;

    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
    let user = db
        .get_user_by_id(session.user_id())
        .await?
        .ok_or_else(|| anyhow!("user not found"))?;
    let flags = db.get_user_flags(session.user_id()).await?;

    response.send(proto::GetPrivateUserInfoResponse {
        metrics_id,
        staff: user.admin,
        flags,
        accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
    })?;
    Ok(())
}

/// Accept the terms of service (ToS) on behalf of the current user
async fn accept_terms_of_service(
    _request: proto::AcceptTermsOfService,
    response: Response<proto::AcceptTermsOfService>,
    session: Session,
) -> Result<()> {
    let db = session.db().await;

    let accepted_tos_at = Utc::now();
    db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
        .await?;

    response.send(proto::AcceptTermsOfServiceResponse {
        accepted_tos_at: accepted_tos_at.timestamp() as u64,
    })?;
    Ok(())
}

/// The minimum age an account must have in order to use the LLM service.
const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);

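/// Issue an LLM API token for the current user, after checking feature flags,
/// terms-of-service acceptance, and the minimum account age.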
async fn get_llm_api_token(
    _request: proto::GetLlmToken,
    response: Response<proto::GetLlmToken>,
    session: Session,
) -> Result<()> {
    let db = session.db().await;

    let flags = db.get_user_flags(session.user_id()).await?;
    let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
    let has_llm_closed_beta_feature_flag = flags.iter().any(|flag| flag == "llm-closed-beta");
    let has_predict_edits_feature_flag = flags.iter().any(|flag| flag == "predict-edits");

    if !session.is_staff() && !has_language_models_feature_flag {
        Err(anyhow!("permission denied"))?
    }

    let user_id = session.user_id();
    let user = db
        .get_user_by_id(user_id)
        .await?
        .ok_or_else(|| anyhow!("user {} not found", user_id))?;

    if user.accepted_tos_at.is_none() {
        Err(anyhow!("terms of service not accepted"))?
    }

    let has_llm_subscription = session.has_llm_subscription(&db).await?;

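    // Users with an LLM subscription or the bypass flag skip the minimum account
    // age check; otherwise the age is measured from the earlier of the Zed and
    // GitHub account creation dates.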
    let bypass_account_age_check =
        has_llm_subscription || flags.iter().any(|flag| flag == "bypass-account-age-check");
    if !bypass_account_age_check {
        let mut account_created_at = user.created_at;
        if let Some(github_created_at) = user.github_user_created_at {
            account_created_at = account_created_at.min(github_created_at);
        }
        if Utc::now().naive_utc() - account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE {
            Err(anyhow!("account too young"))?
        }
    }

    let billing_preferences = db.get_billing_preferences(user.id).await?;

    let token = LlmTokenClaims::create(
        &user,
        session.is_staff(),
        billing_preferences,
        has_llm_closed_beta_feature_flag,
        has_predict_edits_feature_flag,
        has_llm_subscription,
        session.current_plan(&db).await?,
        session.system_id.clone(),
        &session.app_state.config,
    )?;
    response.send(proto::GetLlmTokenResponse { token })?;
    Ok(())
}

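/// Convert a `tungstenite` WebSocket message into its `axum` equivalent.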
fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
    let message = match message {
        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
            code: frame.code.into(),
            reason: frame.reason,
        })),
        // We should never receive a frame while reading the message, according
        // to the `tungstenite` maintainers:
        //
        // > It cannot occur when you read messages from the WebSocket, but it
        // > can be used when you want to send the raw frames (e.g. you want to
        // > send the frames to the WebSocket without composing the full message first).
        // >
        // > — https://github.com/snapview/tungstenite-rs/issues/268
        TungsteniteMessage::Frame(_) => {
            bail!("received an unexpected frame while reading the message")
        }
    };

    Ok(message)
}

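/// Convert an `axum` WebSocket message into its `tungstenite` equivalent.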
fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
    match message {
        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
        AxumMessage::Close(frame) => {
            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
                code: frame.code.into(),
                reason: frame.reason,
            }))
        }
    }
}

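/// Apply a channel membership change to the connection pool and send the
/// updated channel state to all of the affected user's connections.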
fn notify_membership_updated(
    connection_pool: &mut ConnectionPool,
    result: MembershipUpdated,
    user_id: UserId,
    peer: &Peer,
) {
    for membership in &result.new_channels.channel_memberships {
        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
    }
    for channel_id in &result.removed_channels {
        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
    }

    let user_channels_update = proto::UpdateUserChannels {
        channel_memberships: result
            .new_channels
            .channel_memberships
            .iter()
            .map(|cm| proto::ChannelMembership {
                channel_id: cm.channel_id.to_proto(),
                role: cm.role.into(),
            })
            .collect(),
        ..Default::default()
    };

    let mut update = build_channels_update(result.new_channels);
    update.delete_channels = result
        .removed_channels
        .into_iter()
        .map(|id| id.to_proto())
        .collect();
    update.remove_channel_invitations = vec![result.channel_id.to_proto()];

    for connection_id in connection_pool.user_connection_ids(user_id) {
        peer.send(connection_id, user_channels_update.clone())
            .trace_err();
        peer.send(connection_id, update.clone()).trace_err();
    }
}

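/// Build an `UpdateUserChannels` message from the user's channel memberships
/// and their observed buffer versions and message IDs.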
fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
    proto::UpdateUserChannels {
        channel_memberships: channels
            .channel_memberships
            .iter()
            .map(|m| proto::ChannelMembership {
                channel_id: m.channel_id.to_proto(),
                role: m.role.into(),
            })
            .collect(),
        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
        observed_channel_message_id: channels.observed_channel_messages.clone(),
    }
}

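/// Build an `UpdateChannels` message containing the user's channels, invitations,
/// participants, and latest buffer and message versions.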
fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
    let mut update = proto::UpdateChannels::default();

    for channel in channels.channels {
        update.channels.push(channel.to_proto());
    }

    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
    update.latest_channel_message_ids = channels.latest_channel_messages;

    for (channel_id, participants) in channels.channel_participants {
        update
            .channel_participants
            .push(proto::ChannelParticipants {
                channel_id: channel_id.to_proto(),
                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
            });
    }

    for channel in channels.invited_channels {
        update.channel_invitations.push(channel.to_proto());
    }

    update
}

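/// Build the initial `UpdateContacts` message from the user's accepted, incoming,
/// and outgoing contacts.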
fn build_initial_contacts_update(
    contacts: Vec<db::Contact>,
    pool: &ConnectionPool,
) -> proto::UpdateContacts {
    let mut update = proto::UpdateContacts::default();

    for contact in contacts {
        match contact {
            db::Contact::Accepted { user_id, busy } => {
                update.contacts.push(contact_for_user(user_id, busy, pool));
            }
            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
            db::Contact::Incoming { user_id } => {
                update
                    .incoming_requests
                    .push(proto::IncomingContactRequest {
                        requester_id: user_id.to_proto(),
                    })
            }
        }
    }

    update
}

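/// Construct a `proto::Contact` for the given user, including whether they are
/// currently online and busy.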
fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
    proto::Contact {
        user_id: user_id.to_proto(),
        online: pool.is_user_online(user_id),
        busy,
    }
}

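/// Broadcast the latest room state to every participant in the room.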
fn room_updated(room: &proto::Room, peer: &Peer) {
    broadcast(
        None,
        room.participants
            .iter()
            .filter_map(|participant| Some(participant.peer_id?.into())),
        |peer_id| {
            peer.send(
                peer_id,
                proto::RoomUpdated {
                    room: Some(room.clone()),
                },
            )
        },
    );
}

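/// Broadcast the room's current participant list to every connection that can
/// see the channel.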
fn channel_updated(
    channel: &db::channel::Model,
    room: &proto::Room,
    peer: &Peer,
    pool: &ConnectionPool,
) {
    let participants = room
        .participants
        .iter()
        .map(|p| p.user_id)
        .collect::<Vec<_>>();

    broadcast(
        None,
        pool.channel_connection_ids(channel.root_id())
            .filter_map(|(connection_id, role)| {
                role.can_see_channel(channel.visibility)
                    .then_some(connection_id)
            }),
        |peer_id| {
            peer.send(
                peer_id,
                proto::UpdateChannels {
                    channel_participants: vec![proto::ChannelParticipants {
                        channel_id: channel.id.to_proto(),
                        participant_user_ids: participants.clone(),
                    }],
                    ..Default::default()
                },
            )
        },
    );
}

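/// Send the user's updated contact entry, including online and busy status, to
/// every connection of their accepted contacts.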
async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
    let db = session.db().await;

    let contacts = db.get_contacts(user_id).await?;
    let busy = db.is_user_busy(user_id).await?;

    let pool = session.connection_pool().await;
    let updated_contact = contact_for_user(user_id, busy, &pool);
    for contact in contacts {
        if let db::Contact::Accepted {
            user_id: contact_user_id,
            ..
        } = contact
        {
            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
                session
                    .peer
                    .send(
                        contact_conn_id,
                        proto::UpdateContacts {
                            contacts: vec![updated_contact.clone()],
                            remove_contacts: Default::default(),
                            incoming_requests: Default::default(),
                            remove_incoming_requests: Default::default(),
                            outgoing_requests: Default::default(),
                            remove_outgoing_requests: Default::default(),
                        },
                    )
                    .trace_err();
            }
        }
    }
    Ok(())
}

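/// Remove the given connection from its room, notify the remaining participants,
/// channel members, and affected contacts, and clean up the associated LiveKit
/// room if it was deleted.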
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}

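/// Remove this session's connection from every channel buffer it was collaborating
/// in and notify the remaining collaborators.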
async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
    let left_channel_buffers = session
        .db()
        .await
        .leave_channel_buffers(session.connection_id)
        .await?;

    for left_buffer in left_channel_buffers {
        channel_buffer_updated(
            session.connection_id,
            left_buffer.connections,
            &proto::UpdateChannelBufferCollaborators {
                channel_id: left_buffer.channel_id.to_proto(),
                collaborators: left_buffer.collaborators,
            },
            &session.peer,
        );
    }

    Ok(())
}

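/// Notify a project's remaining connections that this session left, either
/// unsharing the project or removing this peer from the collaborator list.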
fn project_left(project: &db::LeftProject, session: &Session) {
    for connection_id in &project.connection_ids {
        if project.should_unshare {
            session
                .peer
                .send(
                    *connection_id,
                    proto::UnshareProject {
                        project_id: project.id.to_proto(),
                    },
                )
                .trace_err();
        } else {
            session
                .peer
                .send(
                    *connection_id,
                    proto::RemoveProjectCollaborator {
                        project_id: project.id.to_proto(),
                        peer_id: Some(session.connection_id.into()),
                    },
                )
                .trace_err();
        }
    }
}

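/// Extension trait for logging an error with `tracing` and discarding it:
/// `trace_err` yields `Some(value)` on success and `None` after logging the error.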
pub trait ResultExt {
    type Ok;

    fn trace_err(self) -> Option<Self::Ok>;
}

impl<T, E> ResultExt for Result<T, E>
where
    E: std::fmt::Debug,
{
    type Ok = T;

    #[track_caller]
    fn trace_err(self) -> Option<T> {
        match self {
            Ok(value) => Some(value),
            Err(error) => {
                tracing::error!("{:?}", error);
                None
            }
        }
    }
}