1mod connection_pool;
2
3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
4use crate::llm::LlmTokenClaims;
5use crate::{
6 auth,
7 db::{
8 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
9 CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
10 NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
11 RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
12 },
13 executor::Executor,
14 AppState, Config, Error, RateLimit, Result,
15};
16use anyhow::{anyhow, bail, Context as _};
17use async_tungstenite::tungstenite::{
18 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
19};
20use axum::{
21 body::Body,
22 extract::{
23 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
24 ConnectInfo, WebSocketUpgrade,
25 },
26 headers::{Header, HeaderName},
27 http::StatusCode,
28 middleware,
29 response::IntoResponse,
30 routing::get,
31 Extension, Router, TypedHeader,
32};
33use chrono::Utc;
34use collections::{HashMap, HashSet};
35pub use connection_pool::{ConnectionPool, ZedVersion};
36use core::fmt::{self, Debug, Formatter};
37use http_client::HttpClient;
38use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
39use reqwest_client::ReqwestClient;
40use sha2::Digest;
41use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
42
43use futures::{
44 channel::oneshot, future::BoxFuture, stream::FuturesUnordered, FutureExt, SinkExt, StreamExt,
45 TryStreamExt,
46};
47use prometheus::{register_int_gauge, IntGauge};
48use rpc::{
49 proto::{
50 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
51 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
52 },
53 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
54};
55use semantic_version::SemanticVersion;
56use serde::{Serialize, Serializer};
57use std::{
58 any::TypeId,
59 future::Future,
60 marker::PhantomData,
61 mem,
62 net::SocketAddr,
63 ops::{Deref, DerefMut},
64 rc::Rc,
65 sync::{
66 atomic::{AtomicBool, Ordering::SeqCst},
67 Arc, OnceLock,
68 },
69 time::{Duration, Instant},
70};
71use time::OffsetDateTime;
72use tokio::sync::{watch, MutexGuard, Semaphore};
73use tower::ServiceBuilder;
74use tracing::{
75 field::{self},
76 info_span, instrument, Instrument,
77};
78
/// Reconnection timeout (30 seconds).
// NOTE(review): the code that consumes this constant is not in this part of the file — confirm
// exactly what is being timed before relying on this description.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// How long `Server::start` sleeps before it begins clearing stale rooms, channel buffers and
/// servers.
// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can clean up old resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// NOTE(review): the three limits below are not referenced in this part of the file; judging by
// their names they are a page size for channel messages, a maximum message length and a page
// size for notifications — confirm against their call sites.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
87
/// Type-erased message handler as stored in `Server::handlers`, keyed by the `TypeId` of the
/// message it handles.
///
/// A handler receives the boxed envelope together with the `Session` of the connection the
/// message arrived on, and returns a boxed future that performs the handling.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
90
91struct Response<R> {
92 peer: Arc<Peer>,
93 receipt: Receipt<R>,
94 responded: Arc<AtomicBool>,
95}
96
97impl<R: RequestMessage> Response<R> {
98 fn send(self, payload: R::Response) -> Result<()> {
99 self.responded.store(true, SeqCst);
100 self.peer.respond(self.receipt, payload)?;
101 Ok(())
102 }
103}
104
105#[derive(Clone, Debug)]
106pub enum Principal {
107 User(User),
108 Impersonated { user: User, admin: User },
109}
110
111impl Principal {
112 fn update_span(&self, span: &tracing::Span) {
113 match &self {
114 Principal::User(user) => {
115 span.record("user_id", user.id.0);
116 span.record("login", &user.github_login);
117 }
118 Principal::Impersonated { user, admin } => {
119 span.record("user_id", user.id.0);
120 span.record("login", &user.github_login);
121 span.record("impersonator", &admin.github_login);
122 }
123 }
124 }
125}
126
127#[derive(Clone)]
128struct Session {
129 principal: Principal,
130 connection_id: ConnectionId,
131 db: Arc<tokio::sync::Mutex<DbHandle>>,
132 peer: Arc<Peer>,
133 connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
134 app_state: Arc<AppState>,
135 supermaven_client: Option<Arc<SupermavenAdminApi>>,
136 http_client: Arc<dyn HttpClient>,
137 /// The GeoIP country code for the user.
138 #[allow(unused)]
139 geoip_country_code: Option<String>,
140 system_id: Option<String>,
141 _executor: Executor,
142}
143
144impl Session {
145 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
146 #[cfg(test)]
147 tokio::task::yield_now().await;
148 let guard = self.db.lock().await;
149 #[cfg(test)]
150 tokio::task::yield_now().await;
151 guard
152 }
153
154 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
155 #[cfg(test)]
156 tokio::task::yield_now().await;
157 let guard = self.connection_pool.lock();
158 ConnectionPoolGuard {
159 guard,
160 _not_send: PhantomData,
161 }
162 }
163
164 fn is_staff(&self) -> bool {
165 match &self.principal {
166 Principal::User(user) => user.admin,
167 Principal::Impersonated { .. } => true,
168 }
169 }
170
171 pub async fn has_llm_subscription(
172 &self,
173 db: &MutexGuard<'_, DbHandle>,
174 ) -> anyhow::Result<bool> {
175 if self.is_staff() {
176 return Ok(true);
177 }
178
179 let user_id = self.user_id();
180
181 Ok(db.has_active_billing_subscription(user_id).await?)
182 }
183
184 pub async fn current_plan(
185 &self,
186 _db: &MutexGuard<'_, DbHandle>,
187 ) -> anyhow::Result<proto::Plan> {
188 if self.is_staff() {
189 Ok(proto::Plan::ZedPro)
190 } else {
191 Ok(proto::Plan::Free)
192 }
193 }
194
195 fn user_id(&self) -> UserId {
196 match &self.principal {
197 Principal::User(user) => user.id,
198 Principal::Impersonated { user, .. } => user.id,
199 }
200 }
201
202 pub fn email(&self) -> Option<String> {
203 match &self.principal {
204 Principal::User(user) => user.email_address.clone(),
205 Principal::Impersonated { user, .. } => user.email_address.clone(),
206 }
207 }
208}
209
210impl Debug for Session {
211 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
212 let mut result = f.debug_struct("Session");
213 match &self.principal {
214 Principal::User(user) => {
215 result.field("user", &user.github_login);
216 }
217 Principal::Impersonated { user, admin } => {
218 result.field("user", &user.github_login);
219 result.field("impersonator", &admin.github_login);
220 }
221 }
222 result.field("connection_id", &self.connection_id).finish()
223 }
224}
225
226struct DbHandle(Arc<Database>);
227
228impl Deref for DbHandle {
229 type Target = Database;
230
231 fn deref(&self) -> &Self::Target {
232 self.0.as_ref()
233 }
234}
235
/// The RPC server: owns the peer, the pool of live connections and the table of message
/// handlers.
pub struct Server {
    // Behind a mutex so the test-only `reset` can replace the id through `&self`.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    // One handler per message type, keyed by the message's `TypeId`.
    handlers: HashMap<TypeId, MessageHandler>,
    // Set to `true` by `teardown`; every connection subscribes to it in `handle_connection`.
    teardown: watch::Sender<bool>,
}
244
/// Lock guard over the connection pool that dereferences to [`ConnectionPool`].
///
/// In test builds, dropping the guard runs the pool's `check_invariants`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    // `Rc<()>` is not `Send`, so this marker keeps the whole guard from being `Send`.
    _not_send: PhantomData<Rc<()>>,
}
249
/// Serializable view of the server: a reference to the peer plus a locked connection pool.
///
/// The pool stays locked for as long as the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized through the guard's `Deref` target, i.e. as the `ConnectionPool` itself.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
256
257pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
258where
259 S: Serializer,
260 T: Deref<Target = U>,
261 U: Serialize,
262{
263 Serialize::serialize(value.deref(), serializer)
264}
265
266impl Server {
/// Creates a server with the given id and registers a handler for every RPC message type
/// it serves.
///
/// # Panics
///
/// Registration goes through `add_handler`, which panics if the same message type is
/// registered twice.
pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
    let mut server = Self {
        id: parking_lot::Mutex::new(id),
        // NOTE(review): `as u32` silently truncates if the server id does not fit in 32 bits.
        peer: Peer::new(id.0 as u32),
        app_state: app_state.clone(),
        connection_pool: Default::default(),
        handlers: Default::default(),
        // Only the sender half is kept; receivers are created later with `subscribe()`.
        teardown: watch::channel(false).0,
    };

    server
        .add_request_handler(ping)
        .add_request_handler(create_room)
        .add_request_handler(join_room)
        .add_request_handler(rejoin_room)
        .add_request_handler(leave_room)
        .add_request_handler(set_room_participant_role)
        .add_request_handler(call)
        .add_request_handler(cancel_call)
        .add_message_handler(decline_call)
        .add_request_handler(update_participant_location)
        .add_request_handler(share_project)
        .add_message_handler(unshare_project)
        .add_request_handler(join_project)
        .add_message_handler(leave_project)
        .add_request_handler(update_project)
        .add_request_handler(update_worktree)
        .add_message_handler(start_language_server)
        .add_message_handler(update_language_server)
        .add_message_handler(update_diagnostic_summary)
        .add_message_handler(update_worktree_settings)
        .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
        .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
        .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
        .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
        .add_request_handler(forward_find_search_candidates_request)
        .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
        .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
        .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
        .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
        .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
        .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
        .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
        .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
        .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
        .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
        .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
        .add_request_handler(
            forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
        )
        .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
        .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
        .add_request_handler(
            forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
        )
        .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
        .add_request_handler(
            forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
        )
        .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
        .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
        .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
        .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
        .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
        .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
        .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
        .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
        .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
        .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
        .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
        .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
        .add_request_handler(
            forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
        )
        .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
        .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
        .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
        .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
        .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
        .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
        .add_message_handler(create_buffer_for_peer)
        .add_request_handler(update_buffer)
        .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
        .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
        .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
        .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
        .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
        .add_request_handler(get_users)
        .add_request_handler(fuzzy_search_users)
        .add_request_handler(request_contact)
        .add_request_handler(remove_contact)
        .add_request_handler(respond_to_contact_request)
        .add_message_handler(subscribe_to_channels)
        .add_request_handler(create_channel)
        .add_request_handler(delete_channel)
        .add_request_handler(invite_channel_member)
        .add_request_handler(remove_channel_member)
        .add_request_handler(set_channel_member_role)
        .add_request_handler(set_channel_visibility)
        .add_request_handler(rename_channel)
        .add_request_handler(join_channel_buffer)
        .add_request_handler(leave_channel_buffer)
        .add_message_handler(update_channel_buffer)
        .add_request_handler(rejoin_channel_buffers)
        .add_request_handler(get_channel_members)
        .add_request_handler(respond_to_channel_invite)
        .add_request_handler(join_channel)
        .add_request_handler(join_channel_chat)
        .add_message_handler(leave_channel_chat)
        .add_request_handler(send_channel_message)
        .add_request_handler(remove_channel_message)
        .add_request_handler(update_channel_message)
        .add_request_handler(get_channel_messages)
        .add_request_handler(get_channel_messages_by_id)
        .add_request_handler(get_notifications)
        .add_request_handler(mark_notification_as_read)
        .add_request_handler(move_channel)
        .add_request_handler(follow)
        .add_message_handler(unfollow)
        .add_message_handler(update_followers)
        .add_request_handler(get_private_user_info)
        .add_request_handler(get_llm_api_token)
        .add_request_handler(accept_terms_of_service)
        .add_message_handler(acknowledge_channel_message)
        .add_message_handler(acknowledge_buffer_version)
        .add_request_handler(get_supermaven_api_key)
        .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
        .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
        .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
        .add_request_handler(forward_mutating_project_request::<proto::Stage>)
        .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
        .add_request_handler(forward_mutating_project_request::<proto::Commit>)
        .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
        .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
        .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
        // NOTE(review): `GitReset` and `GitCheckoutFiles` sound like operations that modify the
        // working tree, yet they are registered through the read-only forwarder — confirm
        // this is intended.
        .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
        .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
        .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
        .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
        .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
        .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
        .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
        .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
        .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
        .add_message_handler(update_context)
        // These two handlers need configuration from `app_state`, so they are wrapped in
        // closures that capture a clone of it.
        .add_request_handler({
            let app_state = app_state.clone();
            move |request, response, session| {
                let app_state = app_state.clone();
                async move {
                    count_language_model_tokens(request, response, session, &app_state.config)
                        .await
                }
            }
        })
        .add_request_handler(get_cached_embeddings)
        .add_request_handler({
            let app_state = app_state.clone();
            move |request, response, session| {
                compute_embeddings(
                    request,
                    response,
                    session,
                    app_state.config.openai_api_key.clone(),
                )
            }
        });

    Arc::new(server)
}
437
/// Spawns a detached background task that waits for `CLEANUP_TIMEOUT` and then clears
/// stale rooms, stale channel-buffer collaborators and stale server records, notifying the
/// affected connections along the way.
///
/// This method itself returns right after spawning the task and always returns `Ok(())`.
/// Inside the task every fallible step goes through `trace_err()`; a step that fails is
/// skipped and the task carries on with the next one.
pub async fn start(&self) -> Result<()> {
    let server_id = *self.id.lock();
    let app_state = self.app_state.clone();
    let peer = self.peer.clone();
    let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
    let pool = self.connection_pool.clone();
    let livekit_client = self.app_state.livekit_client.clone();

    let span = info_span!("start server");
    self.app_state.executor.spawn_detached(
        async move {
            tracing::info!("waiting for cleanup timeout");
            timeout.await;
            tracing::info!("cleanup timeout expired, retrieving stale rooms");
            // NOTE(review): presumably these are the rooms and channel buffers still attached
            // to servers other than `server_id` in this environment — confirm against the
            // database layer.
            if let Some((room_ids, channel_ids)) = app_state
                .db
                .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                .await
                .trace_err()
            {
                tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                tracing::info!(
                    stale_channel_buffer_count = channel_ids.len(),
                    "retrieved stale channel buffers"
                );

                // Clear stale collaborators from each channel buffer and send the refreshed
                // collaborator list to the connections the database reports for it.
                for channel_id in channel_ids {
                    if let Some(refreshed_channel_buffer) = app_state
                        .db
                        .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                        .await
                        .trace_err()
                    {
                        for connection_id in refreshed_channel_buffer.connection_ids {
                            peer.send(
                                connection_id,
                                proto::UpdateChannelBufferCollaborators {
                                    channel_id: channel_id.to_proto(),
                                    collaborators: refreshed_channel_buffer
                                        .collaborators
                                        .clone(),
                                },
                            )
                            .trace_err();
                        }
                    }
                }

                for room_id in room_ids {
                    // These keep their defaults (nothing to notify, nothing to delete) when
                    // clearing the room's stale participants fails.
                    let mut contacts_to_update = HashSet::default();
                    let mut canceled_calls_to_user_ids = Vec::new();
                    let mut livekit_room = String::new();
                    let mut delete_livekit_room = false;

                    if let Some(mut refreshed_room) = app_state
                        .db
                        .clear_stale_room_participants(room_id, server_id)
                        .await
                        .trace_err()
                    {
                        tracing::info!(
                            room_id = room_id.0,
                            new_participant_count = refreshed_room.room.participants.len(),
                            "refreshed room"
                        );
                        room_updated(&refreshed_room.room, &peer);
                        if let Some(channel) = refreshed_room.channel.as_ref() {
                            channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                        }
                        contacts_to_update
                            .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                        contacts_to_update
                            .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                        canceled_calls_to_user_ids =
                            mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                        livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                        // The LiveKit room is only deleted once no participants remain.
                        delete_livekit_room = refreshed_room.room.participants.is_empty();
                    }

                    // Scoped so the pool lock is released before the next `.await`.
                    {
                        let pool = pool.lock();
                        for canceled_user_id in canceled_calls_to_user_ids {
                            for connection_id in pool.user_connection_ids(canceled_user_id) {
                                peer.send(
                                    connection_id,
                                    proto::CallCanceled {
                                        room_id: room_id.to_proto(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    // Push each affected user's refreshed contact entry to the connections of
                    // all of their accepted contacts.
                    for user_id in contacts_to_update {
                        let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                        let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                        if let Some((busy, contacts)) = busy.zip(contacts) {
                            // Locked only after both queries have completed.
                            let pool = pool.lock();
                            let updated_contact = contact_for_user(user_id, busy, &pool);
                            for contact in contacts {
                                if let db::Contact::Accepted {
                                    user_id: contact_user_id,
                                    ..
                                } = contact
                                {
                                    for contact_conn_id in
                                        pool.user_connection_ids(contact_user_id)
                                    {
                                        peer.send(
                                            contact_conn_id,
                                            proto::UpdateContacts {
                                                contacts: vec![updated_contact.clone()],
                                                remove_contacts: Default::default(),
                                                incoming_requests: Default::default(),
                                                remove_incoming_requests: Default::default(),
                                                outgoing_requests: Default::default(),
                                                remove_outgoing_requests: Default::default(),
                                            },
                                        )
                                        .trace_err();
                                    }
                                }
                            }
                        }
                    }

                    if let Some(live_kit) = livekit_client.as_ref() {
                        if delete_livekit_room {
                            live_kit.delete_room(livekit_room).await.trace_err();
                        }
                    }
                }
            }

            // Runs even when retrieving the stale resource ids failed above.
            app_state
                .db
                .delete_stale_servers(&app_state.config.zed_environment, server_id)
                .await
                .trace_err();
        }
        .instrument(span),
    );
    Ok(())
}
583
584 pub fn teardown(&self) {
585 self.peer.teardown();
586 self.connection_pool.lock().reset();
587 let _ = self.teardown.send(true);
588 }
589
/// Test-only: tears the server down, switches it to the new `id`, resets the peer with
/// that id and finally publishes `false` on the teardown channel.
#[cfg(test)]
pub fn reset(&self, id: ServerId) {
    self.teardown();
    {
        let mut current_id = self.id.lock();
        *current_id = id;
    }
    let raw_id = id.0 as u32;
    self.peer.reset(raw_id);
    let _ = self.teardown.send(false);
}
597
/// Test-only: returns a copy of the server's current id.
#[cfg(test)]
pub fn id(&self) -> ServerId {
    let current_id = self.id.lock();
    *current_id
}
602
603 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
604 where
605 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
606 Fut: 'static + Send + Future<Output = Result<()>>,
607 M: EnvelopedMessage,
608 {
609 let prev_handler = self.handlers.insert(
610 TypeId::of::<M>(),
611 Box::new(move |envelope, session| {
612 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
613 let received_at = envelope.received_at;
614 tracing::info!("message received");
615 let start_time = Instant::now();
616 let future = (handler)(*envelope, session);
617 async move {
618 let result = future.await;
619 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
620 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
621 let queue_duration_ms = total_duration_ms - processing_duration_ms;
622 let payload_type = M::NAME;
623
624 match result {
625 Err(error) => {
626 tracing::error!(
627 ?error,
628 total_duration_ms,
629 processing_duration_ms,
630 queue_duration_ms,
631 payload_type,
632 "error handling message"
633 )
634 }
635 Ok(()) => tracing::info!(
636 total_duration_ms,
637 processing_duration_ms,
638 queue_duration_ms,
639 "finished handling message"
640 ),
641 }
642 }
643 .boxed()
644 }),
645 );
646 if prev_handler.is_some() {
647 panic!("registered a handler for the same message twice");
648 }
649 self
650 }
651
652 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
653 where
654 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
655 Fut: 'static + Send + Future<Output = Result<()>>,
656 M: EnvelopedMessage,
657 {
658 self.add_handler(move |envelope, session| handler(envelope.payload, session));
659 self
660 }
661
662 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
663 where
664 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
665 Fut: Send + Future<Output = Result<()>>,
666 M: RequestMessage,
667 {
668 let handler = Arc::new(handler);
669 self.add_handler(move |envelope, session| {
670 let receipt = envelope.receipt();
671 let handler = handler.clone();
672 async move {
673 let peer = session.peer.clone();
674 let responded = Arc::new(AtomicBool::default());
675 let response = Response {
676 peer: peer.clone(),
677 responded: responded.clone(),
678 receipt,
679 };
680 match (handler)(envelope.payload, response, session).await {
681 Ok(()) => {
682 if responded.load(std::sync::atomic::Ordering::SeqCst) {
683 Ok(())
684 } else {
685 Err(anyhow!("handler did not send a response"))?
686 }
687 }
688 Err(error) => {
689 let proto_err = match &error {
690 Error::Internal(err) => err.to_proto(),
691 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
692 };
693 peer.respond_with_error(receipt, proto_err)?;
694 Err(error)
695 }
696 }
697 }
698 })
699 }
700
/// Returns a future that serves `connection` until its I/O finishes, the incoming message
/// stream ends, or the server is torn down.
///
/// The future registers the connection with the peer, builds a `Session`, sends the initial
/// client update, and then dispatches every incoming message to its registered handler.
/// When the loop ends because I/O finished or the stream closed, `connection_lost` is run;
/// on server teardown the future returns immediately instead.
///
/// `send_connection_id`, when provided, is handed on to `send_initial_client_update`, which
/// uses it to report the new connection id.
pub fn handle_connection(
    self: &Arc<Self>,
    connection: Connection,
    address: String,
    principal: Principal,
    zed_version: ZedVersion,
    geoip_country_code: Option<String>,
    system_id: Option<String>,
    send_connection_id: Option<oneshot::Sender<ConnectionId>>,
    executor: Executor,
) -> impl Future<Output = ()> {
    let this = self.clone();
    // Fields declared as `Empty` here are filled in later via `record`.
    let span = info_span!("handle connection", %address,
        connection_id=field::Empty,
        user_id=field::Empty,
        login=field::Empty,
        impersonator=field::Empty,
        geoip_country_code=field::Empty
    );
    principal.update_span(&span);
    if let Some(country_code) = geoip_country_code.as_ref() {
        span.record("geoip_country_code", country_code);
    }

    let mut teardown = self.teardown.subscribe();
    async move {
        // Refuse new connections while the server is tearing down.
        if *teardown.borrow() {
            tracing::error!("server is tearing down");
            return
        }
        let (connection_id, handle_io, mut incoming_rx) = this
            .peer
            .add_connection(connection, {
                let executor = executor.clone();
                move |duration| executor.sleep(duration)
            });
        tracing::Span::current().record("connection_id", format!("{}", connection_id));

        tracing::info!("connection opened");

        // NOTE(review): the two early returns below happen after `add_connection` and skip
        // `connection_lost` — confirm that dropping `handle_io` is sufficient cleanup,
        // particularly if `send_initial_client_update` fails after it has already added the
        // connection to the pool.
        let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
        let http_client = match ReqwestClient::user_agent(&user_agent) {
            Ok(http_client) => Arc::new(http_client),
            Err(error) => {
                tracing::error!(?error, "failed to create HTTP client");
                return;
            }
        };

        // Only constructed when a Supermaven admin API key is configured.
        let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
                supermaven_admin_api_key.to_string(),
                http_client.clone(),
            )));

        let session = Session {
            principal: principal.clone(),
            connection_id,
            db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
            peer: this.peer.clone(),
            connection_pool: this.connection_pool.clone(),
            app_state: this.app_state.clone(),
            http_client,
            geoip_country_code,
            system_id,
            _executor: executor.clone(),
            supermaven_client,
        };

        if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
            tracing::error!(?error, "failed to send initial client update");
            return;
        }

        let handle_io = handle_io.fuse();
        futures::pin_mut!(handle_io);

        // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
        // This prevents deadlocks when e.g., client A performs a request to client B and
        // client B performs a request to client A. If both clients stop processing further
        // messages until their respective request completes, they won't have a chance to
        // respond to the other client's request and cause a deadlock.
        //
        // This arrangement ensures we will attempt to process earlier messages first, but fall
        // back to processing messages that arrived later in the spirit of making progress.
        let mut foreground_message_handlers = FuturesUnordered::new();
        // A permit is acquired before each message is read and released when its handler
        // finishes, so at most 256 handlers are in flight for this connection.
        let concurrent_handlers = Arc::new(Semaphore::new(256));
        loop {
            let next_message = async {
                let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                let message = incoming_rx.next().await;
                (permit, message)
            }.fuse();
            futures::pin_mut!(next_message);
            // `select_biased!` polls its branches in the order they are written.
            futures::select_biased! {
                // Server teardown: return without running `connection_lost`.
                _ = teardown.changed().fuse() => return,
                result = handle_io => {
                    if let Err(error) = result {
                        tracing::error!(?error, "error handling I/O");
                    }
                    break;
                }
                // Drive the foreground handlers that are already running.
                _ = foreground_message_handlers.next() => {}
                next_message = next_message => {
                    let (permit, message) = next_message;
                    if let Some(message) = message {
                        let type_name = message.payload_type_name();
                        // note: we copy all the fields from the parent span so we can query them in the logs.
                        // (https://github.com/tokio-rs/tracing/issues/2670).
                        let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                            user_id=field::Empty,
                            login=field::Empty,
                            impersonator=field::Empty,
                        );
                        principal.update_span(&span);
                        let span_enter = span.enter();
                        if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                            let is_background = message.is_background();
                            let handle_message = (handler)(message, session.clone());
                            drop(span_enter);

                            let handle_message = async move {
                                handle_message.await;
                                drop(permit);
                            }.instrument(span);
                            // Background messages run as detached tasks; foreground messages
                            // are polled by this loop.
                            if is_background {
                                executor.spawn_detached(handle_message);
                            } else {
                                foreground_message_handlers.push(handle_message);
                            }
                        } else {
                            tracing::error!("no message handler");
                        }
                    } else {
                        tracing::info!("connection closed");
                        break;
                    }
                }
            }
        }

        // Drop any foreground handlers that have not finished before signing out.
        drop(foreground_message_handlers);
        tracing::info!("signing out");
        if let Err(error) = connection_lost(session, teardown, executor).await {
            tracing::error!(?error, "error signing out");
        }

    }.instrument(span)
}
849
850 async fn send_initial_client_update(
851 &self,
852 connection_id: ConnectionId,
853 principal: &Principal,
854 zed_version: ZedVersion,
855 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
856 session: &Session,
857 ) -> Result<()> {
858 self.peer.send(
859 connection_id,
860 proto::Hello {
861 peer_id: Some(connection_id.into()),
862 },
863 )?;
864 tracing::info!("sent hello message");
865 if let Some(send_connection_id) = send_connection_id.take() {
866 let _ = send_connection_id.send(connection_id);
867 }
868
869 match principal {
870 Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
871 if !user.connected_once {
872 self.peer.send(connection_id, proto::ShowContacts {})?;
873 self.app_state
874 .db
875 .set_user_connected_once(user.id, true)
876 .await?;
877 }
878
879 update_user_plan(user.id, session).await?;
880
881 let contacts = self.app_state.db.get_contacts(user.id).await?;
882
883 {
884 let mut pool = self.connection_pool.lock();
885 pool.add_connection(connection_id, user.id, user.admin, zed_version);
886 self.peer.send(
887 connection_id,
888 build_initial_contacts_update(contacts, &pool),
889 )?;
890 }
891
892 if should_auto_subscribe_to_channels(zed_version) {
893 subscribe_user_to_channels(user.id, session).await?;
894 }
895
896 if let Some(incoming_call) =
897 self.app_state.db.incoming_call_for_user(user.id).await?
898 {
899 self.peer.send(connection_id, incoming_call)?;
900 }
901
902 update_user_contacts(user.id, session).await?;
903 }
904 }
905
906 Ok(())
907 }
908
909 pub async fn invite_code_redeemed(
910 self: &Arc<Self>,
911 inviter_id: UserId,
912 invitee_id: UserId,
913 ) -> Result<()> {
914 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
915 if let Some(code) = &user.invite_code {
916 let pool = self.connection_pool.lock();
917 let invitee_contact = contact_for_user(invitee_id, false, &pool);
918 for connection_id in pool.user_connection_ids(inviter_id) {
919 self.peer.send(
920 connection_id,
921 proto::UpdateContacts {
922 contacts: vec![invitee_contact.clone()],
923 ..Default::default()
924 },
925 )?;
926 self.peer.send(
927 connection_id,
928 proto::UpdateInviteInfo {
929 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
930 count: user.invite_count as u32,
931 },
932 )?;
933 }
934 }
935 }
936 Ok(())
937 }
938
939 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
940 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
941 if let Some(invite_code) = &user.invite_code {
942 let pool = self.connection_pool.lock();
943 for connection_id in pool.user_connection_ids(user_id) {
944 self.peer.send(
945 connection_id,
946 proto::UpdateInviteInfo {
947 url: format!(
948 "{}{}",
949 self.app_state.config.invite_link_prefix, invite_code
950 ),
951 count: user.invite_count as u32,
952 },
953 )?;
954 }
955 }
956 }
957 Ok(())
958 }
959
960 pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
961 let pool = self.connection_pool.lock();
962 for connection_id in pool.user_connection_ids(user_id) {
963 self.peer
964 .send(connection_id, proto::RefreshLlmToken {})
965 .trace_err();
966 }
967 }
968
969 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
970 ServerSnapshot {
971 connection_pool: ConnectionPoolGuard {
972 guard: self.connection_pool.lock(),
973 _not_send: PhantomData,
974 },
975 peer: &self.peer,
976 }
977 }
978}
979
980impl Deref for ConnectionPoolGuard<'_> {
981 type Target = ConnectionPool;
982
983 fn deref(&self) -> &Self::Target {
984 &self.guard
985 }
986}
987
988impl DerefMut for ConnectionPoolGuard<'_> {
989 fn deref_mut(&mut self) -> &mut Self::Target {
990 &mut self.guard
991 }
992}
993
994impl Drop for ConnectionPoolGuard<'_> {
995 fn drop(&mut self) {
996 #[cfg(test)]
997 self.check_invariants();
998 }
999}
1000
1001fn broadcast<F>(
1002 sender_id: Option<ConnectionId>,
1003 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1004 mut f: F,
1005) where
1006 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1007{
1008 for receiver_id in receiver_ids {
1009 if Some(receiver_id) != sender_id {
1010 if let Err(error) = f(receiver_id) {
1011 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1012 }
1013 }
1014 }
1015}
1016
1017pub struct ProtocolVersion(u32);
1018
1019impl Header for ProtocolVersion {
1020 fn name() -> &'static HeaderName {
1021 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1022 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1023 }
1024
1025 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1026 where
1027 Self: Sized,
1028 I: Iterator<Item = &'i axum::http::HeaderValue>,
1029 {
1030 let version = values
1031 .next()
1032 .ok_or_else(axum::headers::Error::invalid)?
1033 .to_str()
1034 .map_err(|_| axum::headers::Error::invalid())?
1035 .parse()
1036 .map_err(|_| axum::headers::Error::invalid())?;
1037 Ok(Self(version))
1038 }
1039
1040 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1041 values.extend([self.0.to_string().parse().unwrap()]);
1042 }
1043}
1044
1045pub struct AppVersionHeader(SemanticVersion);
1046impl Header for AppVersionHeader {
1047 fn name() -> &'static HeaderName {
1048 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1049 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1050 }
1051
1052 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1053 where
1054 Self: Sized,
1055 I: Iterator<Item = &'i axum::http::HeaderValue>,
1056 {
1057 let version = values
1058 .next()
1059 .ok_or_else(axum::headers::Error::invalid)?
1060 .to_str()
1061 .map_err(|_| axum::headers::Error::invalid())?
1062 .parse()
1063 .map_err(|_| axum::headers::Error::invalid())?;
1064 Ok(Self(version))
1065 }
1066
1067 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1068 values.extend([self.0.to_string().parse().unwrap()]);
1069 }
1070}
1071
1072pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1073 Router::new()
1074 .route("/rpc", get(handle_websocket_request))
1075 .layer(
1076 ServiceBuilder::new()
1077 .layer(Extension(server.app_state.clone()))
1078 .layer(middleware::from_fn(auth::validate_header)),
1079 )
1080 .route("/metrics", get(handle_metrics))
1081 .layer(Extension(server))
1082}
1083
1084pub async fn handle_websocket_request(
1085 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1086 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1087 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1088 Extension(server): Extension<Arc<Server>>,
1089 Extension(principal): Extension<Principal>,
1090 country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1091 system_id_header: Option<TypedHeader<SystemIdHeader>>,
1092 ws: WebSocketUpgrade,
1093) -> axum::response::Response {
1094 if protocol_version != rpc::PROTOCOL_VERSION {
1095 return (
1096 StatusCode::UPGRADE_REQUIRED,
1097 "client must be upgraded".to_string(),
1098 )
1099 .into_response();
1100 }
1101
1102 let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1103 return (
1104 StatusCode::UPGRADE_REQUIRED,
1105 "no version header found".to_string(),
1106 )
1107 .into_response();
1108 };
1109
1110 if !version.can_collaborate() {
1111 return (
1112 StatusCode::UPGRADE_REQUIRED,
1113 "client must be upgraded".to_string(),
1114 )
1115 .into_response();
1116 }
1117
1118 let socket_address = socket_address.to_string();
1119 ws.on_upgrade(move |socket| {
1120 let socket = socket
1121 .map_ok(to_tungstenite_message)
1122 .err_into()
1123 .with(|message| async move { to_axum_message(message) });
1124 let connection = Connection::new(Box::pin(socket));
1125 async move {
1126 server
1127 .handle_connection(
1128 connection,
1129 socket_address,
1130 principal,
1131 version,
1132 country_code_header.map(|header| header.to_string()),
1133 system_id_header.map(|header| header.to_string()),
1134 None,
1135 Executor::Production,
1136 )
1137 .await;
1138 }
1139 })
1140}
1141
1142pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1143 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1144 let connections_metric = CONNECTIONS_METRIC
1145 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1146
1147 let connections = server
1148 .connection_pool
1149 .lock()
1150 .connections()
1151 .filter(|connection| !connection.admin)
1152 .count();
1153 connections_metric.set(connections as _);
1154
1155 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1156 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1157 register_int_gauge!(
1158 "shared_projects",
1159 "number of open projects with one or more guests"
1160 )
1161 .unwrap()
1162 });
1163
1164 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1165 shared_projects_metric.set(shared_projects as _);
1166
1167 let encoder = prometheus::TextEncoder::new();
1168 let metric_families = prometheus::gather();
1169 let encoded_metrics = encoder
1170 .encode_to_string(&metric_families)
1171 .map_err(|err| anyhow!("{}", err))?;
1172 Ok(encoded_metrics)
1173}
1174
1175#[instrument(err, skip(executor))]
1176async fn connection_lost(
1177 session: Session,
1178 mut teardown: watch::Receiver<bool>,
1179 executor: Executor,
1180) -> Result<()> {
1181 session.peer.disconnect(session.connection_id);
1182 session
1183 .connection_pool()
1184 .await
1185 .remove_connection(session.connection_id)?;
1186
1187 session
1188 .db()
1189 .await
1190 .connection_lost(session.connection_id)
1191 .await
1192 .trace_err();
1193
1194 futures::select_biased! {
1195 _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1196
1197 log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1198 leave_room_for_session(&session, session.connection_id).await.trace_err();
1199 leave_channel_buffers_for_session(&session)
1200 .await
1201 .trace_err();
1202
1203 if !session
1204 .connection_pool()
1205 .await
1206 .is_user_online(session.user_id())
1207 {
1208 let db = session.db().await;
1209 if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1210 room_updated(&room, &session.peer);
1211 }
1212 }
1213
1214 update_user_contacts(session.user_id(), &session).await?;
1215 },
1216 _ = teardown.changed().fuse() => {}
1217 }
1218
1219 Ok(())
1220}
1221
1222/// Acknowledges a ping from a client, used to keep the connection alive.
1223async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1224 response.send(proto::Ack {})?;
1225 Ok(())
1226}
1227
1228/// Creates a new room for calling (outside of channels)
1229async fn create_room(
1230 _request: proto::CreateRoom,
1231 response: Response<proto::CreateRoom>,
1232 session: Session,
1233) -> Result<()> {
1234 let livekit_room = nanoid::nanoid!(30);
1235
1236 let live_kit_connection_info = util::maybe!(async {
1237 let live_kit = session.app_state.livekit_client.as_ref();
1238 let live_kit = live_kit?;
1239 let user_id = session.user_id().to_string();
1240
1241 let token = live_kit
1242 .room_token(&livekit_room, &user_id.to_string())
1243 .trace_err()?;
1244
1245 Some(proto::LiveKitConnectionInfo {
1246 server_url: live_kit.url().into(),
1247 token,
1248 can_publish: true,
1249 })
1250 })
1251 .await;
1252
1253 let room = session
1254 .db()
1255 .await
1256 .create_room(session.user_id(), session.connection_id, &livekit_room)
1257 .await?;
1258
1259 response.send(proto::CreateRoomResponse {
1260 room: Some(room.clone()),
1261 live_kit_connection_info,
1262 })?;
1263
1264 update_user_contacts(session.user_id(), &session).await?;
1265 Ok(())
1266}
1267
1268/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1269async fn join_room(
1270 request: proto::JoinRoom,
1271 response: Response<proto::JoinRoom>,
1272 session: Session,
1273) -> Result<()> {
1274 let room_id = RoomId::from_proto(request.id);
1275
1276 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1277
1278 if let Some(channel_id) = channel_id {
1279 return join_channel_internal(channel_id, Box::new(response), session).await;
1280 }
1281
1282 let joined_room = {
1283 let room = session
1284 .db()
1285 .await
1286 .join_room(room_id, session.user_id(), session.connection_id)
1287 .await?;
1288 room_updated(&room.room, &session.peer);
1289 room.into_inner()
1290 };
1291
1292 for connection_id in session
1293 .connection_pool()
1294 .await
1295 .user_connection_ids(session.user_id())
1296 {
1297 session
1298 .peer
1299 .send(
1300 connection_id,
1301 proto::CallCanceled {
1302 room_id: room_id.to_proto(),
1303 },
1304 )
1305 .trace_err();
1306 }
1307
1308 let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1309 {
1310 live_kit
1311 .room_token(
1312 &joined_room.room.livekit_room,
1313 &session.user_id().to_string(),
1314 )
1315 .trace_err()
1316 .map(|token| proto::LiveKitConnectionInfo {
1317 server_url: live_kit.url().into(),
1318 token,
1319 can_publish: true,
1320 })
1321 } else {
1322 None
1323 };
1324
1325 response.send(proto::JoinRoomResponse {
1326 room: Some(joined_room.room),
1327 channel_id: None,
1328 live_kit_connection_info,
1329 })?;
1330
1331 update_user_contacts(session.user_id(), &session).await?;
1332 Ok(())
1333}
1334
1335/// Rejoin room is used to reconnect to a room after connection errors.
1336async fn rejoin_room(
1337 request: proto::RejoinRoom,
1338 response: Response<proto::RejoinRoom>,
1339 session: Session,
1340) -> Result<()> {
1341 let room;
1342 let channel;
1343 {
1344 let mut rejoined_room = session
1345 .db()
1346 .await
1347 .rejoin_room(request, session.user_id(), session.connection_id)
1348 .await?;
1349
1350 response.send(proto::RejoinRoomResponse {
1351 room: Some(rejoined_room.room.clone()),
1352 reshared_projects: rejoined_room
1353 .reshared_projects
1354 .iter()
1355 .map(|project| proto::ResharedProject {
1356 id: project.id.to_proto(),
1357 collaborators: project
1358 .collaborators
1359 .iter()
1360 .map(|collaborator| collaborator.to_proto())
1361 .collect(),
1362 })
1363 .collect(),
1364 rejoined_projects: rejoined_room
1365 .rejoined_projects
1366 .iter()
1367 .map(|rejoined_project| rejoined_project.to_proto())
1368 .collect(),
1369 })?;
1370 room_updated(&rejoined_room.room, &session.peer);
1371
1372 for project in &rejoined_room.reshared_projects {
1373 for collaborator in &project.collaborators {
1374 session
1375 .peer
1376 .send(
1377 collaborator.connection_id,
1378 proto::UpdateProjectCollaborator {
1379 project_id: project.id.to_proto(),
1380 old_peer_id: Some(project.old_connection_id.into()),
1381 new_peer_id: Some(session.connection_id.into()),
1382 },
1383 )
1384 .trace_err();
1385 }
1386
1387 broadcast(
1388 Some(session.connection_id),
1389 project
1390 .collaborators
1391 .iter()
1392 .map(|collaborator| collaborator.connection_id),
1393 |connection_id| {
1394 session.peer.forward_send(
1395 session.connection_id,
1396 connection_id,
1397 proto::UpdateProject {
1398 project_id: project.id.to_proto(),
1399 worktrees: project.worktrees.clone(),
1400 },
1401 )
1402 },
1403 );
1404 }
1405
1406 notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;
1407
1408 let rejoined_room = rejoined_room.into_inner();
1409
1410 room = rejoined_room.room;
1411 channel = rejoined_room.channel;
1412 }
1413
1414 if let Some(channel) = channel {
1415 channel_updated(
1416 &channel,
1417 &room,
1418 &session.peer,
1419 &*session.connection_pool().await,
1420 );
1421 }
1422
1423 update_user_contacts(session.user_id(), &session).await?;
1424 Ok(())
1425}
1426
1427fn notify_rejoined_projects(
1428 rejoined_projects: &mut Vec<RejoinedProject>,
1429 session: &Session,
1430) -> Result<()> {
1431 for project in rejoined_projects.iter() {
1432 for collaborator in &project.collaborators {
1433 session
1434 .peer
1435 .send(
1436 collaborator.connection_id,
1437 proto::UpdateProjectCollaborator {
1438 project_id: project.id.to_proto(),
1439 old_peer_id: Some(project.old_connection_id.into()),
1440 new_peer_id: Some(session.connection_id.into()),
1441 },
1442 )
1443 .trace_err();
1444 }
1445 }
1446
1447 for project in rejoined_projects {
1448 for worktree in mem::take(&mut project.worktrees) {
1449 // Stream this worktree's entries.
1450 let message = proto::UpdateWorktree {
1451 project_id: project.id.to_proto(),
1452 worktree_id: worktree.id,
1453 abs_path: worktree.abs_path.clone(),
1454 root_name: worktree.root_name,
1455 updated_entries: worktree.updated_entries,
1456 removed_entries: worktree.removed_entries,
1457 scan_id: worktree.scan_id,
1458 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1459 updated_repositories: worktree.updated_repositories,
1460 removed_repositories: worktree.removed_repositories,
1461 };
1462 for update in proto::split_worktree_update(message) {
1463 session.peer.send(session.connection_id, update.clone())?;
1464 }
1465
1466 // Stream this worktree's diagnostics.
1467 for summary in worktree.diagnostic_summaries {
1468 session.peer.send(
1469 session.connection_id,
1470 proto::UpdateDiagnosticSummary {
1471 project_id: project.id.to_proto(),
1472 worktree_id: worktree.id,
1473 summary: Some(summary),
1474 },
1475 )?;
1476 }
1477
1478 for settings_file in worktree.settings_files {
1479 session.peer.send(
1480 session.connection_id,
1481 proto::UpdateWorktreeSettings {
1482 project_id: project.id.to_proto(),
1483 worktree_id: worktree.id,
1484 path: settings_file.path,
1485 content: Some(settings_file.content),
1486 kind: Some(settings_file.kind.to_proto().into()),
1487 },
1488 )?;
1489 }
1490 }
1491
1492 for language_server in &project.language_servers {
1493 session.peer.send(
1494 session.connection_id,
1495 proto::UpdateLanguageServer {
1496 project_id: project.id.to_proto(),
1497 language_server_id: language_server.id,
1498 variant: Some(
1499 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1500 proto::LspDiskBasedDiagnosticsUpdated {},
1501 ),
1502 ),
1503 },
1504 )?;
1505 }
1506 }
1507 Ok(())
1508}
1509
1510/// leave room disconnects from the room.
1511async fn leave_room(
1512 _: proto::LeaveRoom,
1513 response: Response<proto::LeaveRoom>,
1514 session: Session,
1515) -> Result<()> {
1516 leave_room_for_session(&session, session.connection_id).await?;
1517 response.send(proto::Ack {})?;
1518 Ok(())
1519}
1520
1521/// Updates the permissions of someone else in the room.
1522async fn set_room_participant_role(
1523 request: proto::SetRoomParticipantRole,
1524 response: Response<proto::SetRoomParticipantRole>,
1525 session: Session,
1526) -> Result<()> {
1527 let user_id = UserId::from_proto(request.user_id);
1528 let role = ChannelRole::from(request.role());
1529
1530 let (livekit_room, can_publish) = {
1531 let room = session
1532 .db()
1533 .await
1534 .set_room_participant_role(
1535 session.user_id(),
1536 RoomId::from_proto(request.room_id),
1537 user_id,
1538 role,
1539 )
1540 .await?;
1541
1542 let livekit_room = room.livekit_room.clone();
1543 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1544 room_updated(&room, &session.peer);
1545 (livekit_room, can_publish)
1546 };
1547
1548 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1549 live_kit
1550 .update_participant(
1551 livekit_room.clone(),
1552 request.user_id.to_string(),
1553 livekit_api::proto::ParticipantPermission {
1554 can_subscribe: true,
1555 can_publish,
1556 can_publish_data: can_publish,
1557 hidden: false,
1558 recorder: false,
1559 },
1560 )
1561 .await
1562 .trace_err();
1563 }
1564
1565 response.send(proto::Ack {})?;
1566 Ok(())
1567}
1568
1569/// Call someone else into the current room
1570async fn call(
1571 request: proto::Call,
1572 response: Response<proto::Call>,
1573 session: Session,
1574) -> Result<()> {
1575 let room_id = RoomId::from_proto(request.room_id);
1576 let calling_user_id = session.user_id();
1577 let calling_connection_id = session.connection_id;
1578 let called_user_id = UserId::from_proto(request.called_user_id);
1579 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1580 if !session
1581 .db()
1582 .await
1583 .has_contact(calling_user_id, called_user_id)
1584 .await?
1585 {
1586 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1587 }
1588
1589 let incoming_call = {
1590 let (room, incoming_call) = &mut *session
1591 .db()
1592 .await
1593 .call(
1594 room_id,
1595 calling_user_id,
1596 calling_connection_id,
1597 called_user_id,
1598 initial_project_id,
1599 )
1600 .await?;
1601 room_updated(room, &session.peer);
1602 mem::take(incoming_call)
1603 };
1604 update_user_contacts(called_user_id, &session).await?;
1605
1606 let mut calls = session
1607 .connection_pool()
1608 .await
1609 .user_connection_ids(called_user_id)
1610 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1611 .collect::<FuturesUnordered<_>>();
1612
1613 while let Some(call_response) = calls.next().await {
1614 match call_response.as_ref() {
1615 Ok(_) => {
1616 response.send(proto::Ack {})?;
1617 return Ok(());
1618 }
1619 Err(_) => {
1620 call_response.trace_err();
1621 }
1622 }
1623 }
1624
1625 {
1626 let room = session
1627 .db()
1628 .await
1629 .call_failed(room_id, called_user_id)
1630 .await?;
1631 room_updated(&room, &session.peer);
1632 }
1633 update_user_contacts(called_user_id, &session).await?;
1634
1635 Err(anyhow!("failed to ring user"))?
1636}
1637
1638/// Cancel an outgoing call.
1639async fn cancel_call(
1640 request: proto::CancelCall,
1641 response: Response<proto::CancelCall>,
1642 session: Session,
1643) -> Result<()> {
1644 let called_user_id = UserId::from_proto(request.called_user_id);
1645 let room_id = RoomId::from_proto(request.room_id);
1646 {
1647 let room = session
1648 .db()
1649 .await
1650 .cancel_call(room_id, session.connection_id, called_user_id)
1651 .await?;
1652 room_updated(&room, &session.peer);
1653 }
1654
1655 for connection_id in session
1656 .connection_pool()
1657 .await
1658 .user_connection_ids(called_user_id)
1659 {
1660 session
1661 .peer
1662 .send(
1663 connection_id,
1664 proto::CallCanceled {
1665 room_id: room_id.to_proto(),
1666 },
1667 )
1668 .trace_err();
1669 }
1670 response.send(proto::Ack {})?;
1671
1672 update_user_contacts(called_user_id, &session).await?;
1673 Ok(())
1674}
1675
1676/// Decline an incoming call.
1677async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1678 let room_id = RoomId::from_proto(message.room_id);
1679 {
1680 let room = session
1681 .db()
1682 .await
1683 .decline_call(Some(room_id), session.user_id())
1684 .await?
1685 .ok_or_else(|| anyhow!("failed to decline call"))?;
1686 room_updated(&room, &session.peer);
1687 }
1688
1689 for connection_id in session
1690 .connection_pool()
1691 .await
1692 .user_connection_ids(session.user_id())
1693 {
1694 session
1695 .peer
1696 .send(
1697 connection_id,
1698 proto::CallCanceled {
1699 room_id: room_id.to_proto(),
1700 },
1701 )
1702 .trace_err();
1703 }
1704 update_user_contacts(session.user_id(), &session).await?;
1705 Ok(())
1706}
1707
1708/// Updates other participants in the room with your current location.
1709async fn update_participant_location(
1710 request: proto::UpdateParticipantLocation,
1711 response: Response<proto::UpdateParticipantLocation>,
1712 session: Session,
1713) -> Result<()> {
1714 let room_id = RoomId::from_proto(request.room_id);
1715 let location = request
1716 .location
1717 .ok_or_else(|| anyhow!("invalid location"))?;
1718
1719 let db = session.db().await;
1720 let room = db
1721 .update_room_participant_location(room_id, session.connection_id, location)
1722 .await?;
1723
1724 room_updated(&room, &session.peer);
1725 response.send(proto::Ack {})?;
1726 Ok(())
1727}
1728
1729/// Share a project into the room.
1730async fn share_project(
1731 request: proto::ShareProject,
1732 response: Response<proto::ShareProject>,
1733 session: Session,
1734) -> Result<()> {
1735 let (project_id, room) = &*session
1736 .db()
1737 .await
1738 .share_project(
1739 RoomId::from_proto(request.room_id),
1740 session.connection_id,
1741 &request.worktrees,
1742 request.is_ssh_project,
1743 )
1744 .await?;
1745 response.send(proto::ShareProjectResponse {
1746 project_id: project_id.to_proto(),
1747 })?;
1748 room_updated(room, &session.peer);
1749
1750 Ok(())
1751}
1752
1753/// Unshare a project from the room.
1754async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1755 let project_id = ProjectId::from_proto(message.project_id);
1756 unshare_project_internal(project_id, session.connection_id, &session).await
1757}
1758
1759async fn unshare_project_internal(
1760 project_id: ProjectId,
1761 connection_id: ConnectionId,
1762 session: &Session,
1763) -> Result<()> {
1764 let delete = {
1765 let room_guard = session
1766 .db()
1767 .await
1768 .unshare_project(project_id, connection_id)
1769 .await?;
1770
1771 let (delete, room, guest_connection_ids) = &*room_guard;
1772
1773 let message = proto::UnshareProject {
1774 project_id: project_id.to_proto(),
1775 };
1776
1777 broadcast(
1778 Some(connection_id),
1779 guest_connection_ids.iter().copied(),
1780 |conn_id| session.peer.send(conn_id, message.clone()),
1781 );
1782 if let Some(room) = room {
1783 room_updated(room, &session.peer);
1784 }
1785
1786 *delete
1787 };
1788
1789 if delete {
1790 let db = session.db().await;
1791 db.delete_project(project_id).await?;
1792 }
1793
1794 Ok(())
1795}
1796
1797/// Join someone elses shared project.
1798async fn join_project(
1799 request: proto::JoinProject,
1800 response: Response<proto::JoinProject>,
1801 session: Session,
1802) -> Result<()> {
1803 let project_id = ProjectId::from_proto(request.project_id);
1804
1805 tracing::info!(%project_id, "join project");
1806
1807 let db = session.db().await;
1808 let (project, replica_id) = &mut *db
1809 .join_project(project_id, session.connection_id, session.user_id())
1810 .await?;
1811 drop(db);
1812 tracing::info!(%project_id, "join remote project");
1813 join_project_internal(response, session, project, replica_id)
1814}
1815
1816trait JoinProjectInternalResponse {
1817 fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
1818}
1819impl JoinProjectInternalResponse for Response<proto::JoinProject> {
1820 fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
1821 Response::<proto::JoinProject>::send(self, result)
1822 }
1823}
1824
1825fn join_project_internal(
1826 response: impl JoinProjectInternalResponse,
1827 session: Session,
1828 project: &mut Project,
1829 replica_id: &ReplicaId,
1830) -> Result<()> {
1831 let collaborators = project
1832 .collaborators
1833 .iter()
1834 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1835 .map(|collaborator| collaborator.to_proto())
1836 .collect::<Vec<_>>();
1837 let project_id = project.id;
1838 let guest_user_id = session.user_id();
1839
1840 let worktrees = project
1841 .worktrees
1842 .iter()
1843 .map(|(id, worktree)| proto::WorktreeMetadata {
1844 id: *id,
1845 root_name: worktree.root_name.clone(),
1846 visible: worktree.visible,
1847 abs_path: worktree.abs_path.clone(),
1848 })
1849 .collect::<Vec<_>>();
1850
1851 let add_project_collaborator = proto::AddProjectCollaborator {
1852 project_id: project_id.to_proto(),
1853 collaborator: Some(proto::Collaborator {
1854 peer_id: Some(session.connection_id.into()),
1855 replica_id: replica_id.0 as u32,
1856 user_id: guest_user_id.to_proto(),
1857 is_host: false,
1858 }),
1859 };
1860
1861 for collaborator in &collaborators {
1862 session
1863 .peer
1864 .send(
1865 collaborator.peer_id.unwrap().into(),
1866 add_project_collaborator.clone(),
1867 )
1868 .trace_err();
1869 }
1870
1871 // First, we send the metadata associated with each worktree.
1872 response.send(proto::JoinProjectResponse {
1873 project_id: project.id.0 as u64,
1874 worktrees: worktrees.clone(),
1875 replica_id: replica_id.0 as u32,
1876 collaborators: collaborators.clone(),
1877 language_servers: project.language_servers.clone(),
1878 role: project.role.into(),
1879 })?;
1880
1881 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1882 // Stream this worktree's entries.
1883 let message = proto::UpdateWorktree {
1884 project_id: project_id.to_proto(),
1885 worktree_id,
1886 abs_path: worktree.abs_path.clone(),
1887 root_name: worktree.root_name,
1888 updated_entries: worktree.entries,
1889 removed_entries: Default::default(),
1890 scan_id: worktree.scan_id,
1891 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1892 updated_repositories: worktree.repository_entries.into_values().collect(),
1893 removed_repositories: Default::default(),
1894 };
1895 for update in proto::split_worktree_update(message) {
1896 session.peer.send(session.connection_id, update.clone())?;
1897 }
1898
1899 // Stream this worktree's diagnostics.
1900 for summary in worktree.diagnostic_summaries {
1901 session.peer.send(
1902 session.connection_id,
1903 proto::UpdateDiagnosticSummary {
1904 project_id: project_id.to_proto(),
1905 worktree_id: worktree.id,
1906 summary: Some(summary),
1907 },
1908 )?;
1909 }
1910
1911 for settings_file in worktree.settings_files {
1912 session.peer.send(
1913 session.connection_id,
1914 proto::UpdateWorktreeSettings {
1915 project_id: project_id.to_proto(),
1916 worktree_id: worktree.id,
1917 path: settings_file.path,
1918 content: Some(settings_file.content),
1919 kind: Some(settings_file.kind.to_proto() as i32),
1920 },
1921 )?;
1922 }
1923 }
1924
1925 for language_server in &project.language_servers {
1926 session.peer.send(
1927 session.connection_id,
1928 proto::UpdateLanguageServer {
1929 project_id: project_id.to_proto(),
1930 language_server_id: language_server.id,
1931 variant: Some(
1932 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1933 proto::LspDiskBasedDiagnosticsUpdated {},
1934 ),
1935 ),
1936 },
1937 )?;
1938 }
1939
1940 Ok(())
1941}
1942
1943/// Leave someone elses shared project.
1944async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1945 let sender_id = session.connection_id;
1946 let project_id = ProjectId::from_proto(request.project_id);
1947 let db = session.db().await;
1948
1949 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1950 tracing::info!(
1951 %project_id,
1952 "leave project"
1953 );
1954
1955 project_left(project, &session);
1956 if let Some(room) = room {
1957 room_updated(room, &session.peer);
1958 }
1959
1960 Ok(())
1961}
1962
1963/// Updates other participants with changes to the project
1964async fn update_project(
1965 request: proto::UpdateProject,
1966 response: Response<proto::UpdateProject>,
1967 session: Session,
1968) -> Result<()> {
1969 let project_id = ProjectId::from_proto(request.project_id);
1970 let (room, guest_connection_ids) = &*session
1971 .db()
1972 .await
1973 .update_project(project_id, session.connection_id, &request.worktrees)
1974 .await?;
1975 broadcast(
1976 Some(session.connection_id),
1977 guest_connection_ids.iter().copied(),
1978 |connection_id| {
1979 session
1980 .peer
1981 .forward_send(session.connection_id, connection_id, request.clone())
1982 },
1983 );
1984 if let Some(room) = room {
1985 room_updated(room, &session.peer);
1986 }
1987 response.send(proto::Ack {})?;
1988
1989 Ok(())
1990}
1991
1992/// Updates other participants with changes to the worktree
1993async fn update_worktree(
1994 request: proto::UpdateWorktree,
1995 response: Response<proto::UpdateWorktree>,
1996 session: Session,
1997) -> Result<()> {
1998 let guest_connection_ids = session
1999 .db()
2000 .await
2001 .update_worktree(&request, session.connection_id)
2002 .await?;
2003
2004 broadcast(
2005 Some(session.connection_id),
2006 guest_connection_ids.iter().copied(),
2007 |connection_id| {
2008 session
2009 .peer
2010 .forward_send(session.connection_id, connection_id, request.clone())
2011 },
2012 );
2013 response.send(proto::Ack {})?;
2014 Ok(())
2015}
2016
2017/// Updates other participants with changes to the diagnostics
2018async fn update_diagnostic_summary(
2019 message: proto::UpdateDiagnosticSummary,
2020 session: Session,
2021) -> Result<()> {
2022 let guest_connection_ids = session
2023 .db()
2024 .await
2025 .update_diagnostic_summary(&message, session.connection_id)
2026 .await?;
2027
2028 broadcast(
2029 Some(session.connection_id),
2030 guest_connection_ids.iter().copied(),
2031 |connection_id| {
2032 session
2033 .peer
2034 .forward_send(session.connection_id, connection_id, message.clone())
2035 },
2036 );
2037
2038 Ok(())
2039}
2040
2041/// Updates other participants with changes to the worktree settings
2042async fn update_worktree_settings(
2043 message: proto::UpdateWorktreeSettings,
2044 session: Session,
2045) -> Result<()> {
2046 let guest_connection_ids = session
2047 .db()
2048 .await
2049 .update_worktree_settings(&message, session.connection_id)
2050 .await?;
2051
2052 broadcast(
2053 Some(session.connection_id),
2054 guest_connection_ids.iter().copied(),
2055 |connection_id| {
2056 session
2057 .peer
2058 .forward_send(session.connection_id, connection_id, message.clone())
2059 },
2060 );
2061
2062 Ok(())
2063}
2064
2065/// Notify other participants that a language server has started.
2066async fn start_language_server(
2067 request: proto::StartLanguageServer,
2068 session: Session,
2069) -> Result<()> {
2070 let guest_connection_ids = session
2071 .db()
2072 .await
2073 .start_language_server(&request, session.connection_id)
2074 .await?;
2075
2076 broadcast(
2077 Some(session.connection_id),
2078 guest_connection_ids.iter().copied(),
2079 |connection_id| {
2080 session
2081 .peer
2082 .forward_send(session.connection_id, connection_id, request.clone())
2083 },
2084 );
2085 Ok(())
2086}
2087
2088/// Notify other participants that a language server has changed.
2089async fn update_language_server(
2090 request: proto::UpdateLanguageServer,
2091 session: Session,
2092) -> Result<()> {
2093 let project_id = ProjectId::from_proto(request.project_id);
2094 let project_connection_ids = session
2095 .db()
2096 .await
2097 .project_connection_ids(project_id, session.connection_id, true)
2098 .await?;
2099 broadcast(
2100 Some(session.connection_id),
2101 project_connection_ids.iter().copied(),
2102 |connection_id| {
2103 session
2104 .peer
2105 .forward_send(session.connection_id, connection_id, request.clone())
2106 },
2107 );
2108 Ok(())
2109}
2110
2111/// forward a project request to the host. These requests should be read only
2112/// as guests are allowed to send them.
2113async fn forward_read_only_project_request<T>(
2114 request: T,
2115 response: Response<T>,
2116 session: Session,
2117) -> Result<()>
2118where
2119 T: EntityMessage + RequestMessage,
2120{
2121 let project_id = ProjectId::from_proto(request.remote_entity_id());
2122 let host_connection_id = session
2123 .db()
2124 .await
2125 .host_for_read_only_project_request(project_id, session.connection_id)
2126 .await?;
2127 let payload = session
2128 .peer
2129 .forward_request(session.connection_id, host_connection_id, request)
2130 .await?;
2131 response.send(payload)?;
2132 Ok(())
2133}
2134
2135async fn forward_find_search_candidates_request(
2136 request: proto::FindSearchCandidates,
2137 response: Response<proto::FindSearchCandidates>,
2138 session: Session,
2139) -> Result<()> {
2140 let project_id = ProjectId::from_proto(request.remote_entity_id());
2141 let host_connection_id = session
2142 .db()
2143 .await
2144 .host_for_read_only_project_request(project_id, session.connection_id)
2145 .await?;
2146 let payload = session
2147 .peer
2148 .forward_request(session.connection_id, host_connection_id, request)
2149 .await?;
2150 response.send(payload)?;
2151 Ok(())
2152}
2153
2154/// forward a project request to the host. These requests are disallowed
2155/// for guests.
2156async fn forward_mutating_project_request<T>(
2157 request: T,
2158 response: Response<T>,
2159 session: Session,
2160) -> Result<()>
2161where
2162 T: EntityMessage + RequestMessage,
2163{
2164 let project_id = ProjectId::from_proto(request.remote_entity_id());
2165
2166 let host_connection_id = session
2167 .db()
2168 .await
2169 .host_for_mutating_project_request(project_id, session.connection_id)
2170 .await?;
2171 let payload = session
2172 .peer
2173 .forward_request(session.connection_id, host_connection_id, request)
2174 .await?;
2175 response.send(payload)?;
2176 Ok(())
2177}
2178
2179/// Notify other participants that a new buffer has been created
2180async fn create_buffer_for_peer(
2181 request: proto::CreateBufferForPeer,
2182 session: Session,
2183) -> Result<()> {
2184 session
2185 .db()
2186 .await
2187 .check_user_is_project_host(
2188 ProjectId::from_proto(request.project_id),
2189 session.connection_id,
2190 )
2191 .await?;
2192 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2193 session
2194 .peer
2195 .forward_send(session.connection_id, peer_id.into(), request)?;
2196 Ok(())
2197}
2198
2199/// Notify other participants that a buffer has been updated. This is
2200/// allowed for guests as long as the update is limited to selections.
2201async fn update_buffer(
2202 request: proto::UpdateBuffer,
2203 response: Response<proto::UpdateBuffer>,
2204 session: Session,
2205) -> Result<()> {
2206 let project_id = ProjectId::from_proto(request.project_id);
2207 let mut capability = Capability::ReadOnly;
2208
2209 for op in request.operations.iter() {
2210 match op.variant {
2211 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2212 Some(_) => capability = Capability::ReadWrite,
2213 }
2214 }
2215
2216 let host = {
2217 let guard = session
2218 .db()
2219 .await
2220 .connections_for_buffer_update(project_id, session.connection_id, capability)
2221 .await?;
2222
2223 let (host, guests) = &*guard;
2224
2225 broadcast(
2226 Some(session.connection_id),
2227 guests.clone(),
2228 |connection_id| {
2229 session
2230 .peer
2231 .forward_send(session.connection_id, connection_id, request.clone())
2232 },
2233 );
2234
2235 *host
2236 };
2237
2238 if host != session.connection_id {
2239 session
2240 .peer
2241 .forward_request(session.connection_id, host, request.clone())
2242 .await?;
2243 }
2244
2245 response.send(proto::Ack {})?;
2246 Ok(())
2247}
2248
2249async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2250 let project_id = ProjectId::from_proto(message.project_id);
2251
2252 let operation = message.operation.as_ref().context("invalid operation")?;
2253 let capability = match operation.variant.as_ref() {
2254 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2255 if let Some(buffer_op) = buffer_op.operation.as_ref() {
2256 match buffer_op.variant {
2257 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2258 Capability::ReadOnly
2259 }
2260 _ => Capability::ReadWrite,
2261 }
2262 } else {
2263 Capability::ReadWrite
2264 }
2265 }
2266 Some(_) => Capability::ReadWrite,
2267 None => Capability::ReadOnly,
2268 };
2269
2270 let guard = session
2271 .db()
2272 .await
2273 .connections_for_buffer_update(project_id, session.connection_id, capability)
2274 .await?;
2275
2276 let (host, guests) = &*guard;
2277
2278 broadcast(
2279 Some(session.connection_id),
2280 guests.iter().chain([host]).copied(),
2281 |connection_id| {
2282 session
2283 .peer
2284 .forward_send(session.connection_id, connection_id, message.clone())
2285 },
2286 );
2287
2288 Ok(())
2289}
2290
2291/// Notify other participants that a project has been updated.
2292async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2293 request: T,
2294 session: Session,
2295) -> Result<()> {
2296 let project_id = ProjectId::from_proto(request.remote_entity_id());
2297 let project_connection_ids = session
2298 .db()
2299 .await
2300 .project_connection_ids(project_id, session.connection_id, false)
2301 .await?;
2302
2303 broadcast(
2304 Some(session.connection_id),
2305 project_connection_ids.iter().copied(),
2306 |connection_id| {
2307 session
2308 .peer
2309 .forward_send(session.connection_id, connection_id, request.clone())
2310 },
2311 );
2312 Ok(())
2313}
2314
2315/// Start following another user in a call.
2316async fn follow(
2317 request: proto::Follow,
2318 response: Response<proto::Follow>,
2319 session: Session,
2320) -> Result<()> {
2321 let room_id = RoomId::from_proto(request.room_id);
2322 let project_id = request.project_id.map(ProjectId::from_proto);
2323 let leader_id = request
2324 .leader_id
2325 .ok_or_else(|| anyhow!("invalid leader id"))?
2326 .into();
2327 let follower_id = session.connection_id;
2328
2329 session
2330 .db()
2331 .await
2332 .check_room_participants(room_id, leader_id, session.connection_id)
2333 .await?;
2334
2335 let response_payload = session
2336 .peer
2337 .forward_request(session.connection_id, leader_id, request)
2338 .await?;
2339 response.send(response_payload)?;
2340
2341 if let Some(project_id) = project_id {
2342 let room = session
2343 .db()
2344 .await
2345 .follow(room_id, project_id, leader_id, follower_id)
2346 .await?;
2347 room_updated(&room, &session.peer);
2348 }
2349
2350 Ok(())
2351}
2352
2353/// Stop following another user in a call.
2354async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2355 let room_id = RoomId::from_proto(request.room_id);
2356 let project_id = request.project_id.map(ProjectId::from_proto);
2357 let leader_id = request
2358 .leader_id
2359 .ok_or_else(|| anyhow!("invalid leader id"))?
2360 .into();
2361 let follower_id = session.connection_id;
2362
2363 session
2364 .db()
2365 .await
2366 .check_room_participants(room_id, leader_id, session.connection_id)
2367 .await?;
2368
2369 session
2370 .peer
2371 .forward_send(session.connection_id, leader_id, request)?;
2372
2373 if let Some(project_id) = project_id {
2374 let room = session
2375 .db()
2376 .await
2377 .unfollow(room_id, project_id, leader_id, follower_id)
2378 .await?;
2379 room_updated(&room, &session.peer);
2380 }
2381
2382 Ok(())
2383}
2384
2385/// Notify everyone following you of your current location.
2386async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2387 let room_id = RoomId::from_proto(request.room_id);
2388 let database = session.db.lock().await;
2389
2390 let connection_ids = if let Some(project_id) = request.project_id {
2391 let project_id = ProjectId::from_proto(project_id);
2392 database
2393 .project_connection_ids(project_id, session.connection_id, true)
2394 .await?
2395 } else {
2396 database
2397 .room_connection_ids(room_id, session.connection_id)
2398 .await?
2399 };
2400
2401 // For now, don't send view update messages back to that view's current leader.
2402 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2403 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2404 _ => None,
2405 });
2406
2407 for connection_id in connection_ids.iter().cloned() {
2408 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2409 session
2410 .peer
2411 .forward_send(session.connection_id, connection_id, request.clone())?;
2412 }
2413 }
2414 Ok(())
2415}
2416
2417/// Get public data about users.
2418async fn get_users(
2419 request: proto::GetUsers,
2420 response: Response<proto::GetUsers>,
2421 session: Session,
2422) -> Result<()> {
2423 let user_ids = request
2424 .user_ids
2425 .into_iter()
2426 .map(UserId::from_proto)
2427 .collect();
2428 let users = session
2429 .db()
2430 .await
2431 .get_users_by_ids(user_ids)
2432 .await?
2433 .into_iter()
2434 .map(|user| proto::User {
2435 id: user.id.to_proto(),
2436 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2437 github_login: user.github_login,
2438 email: user.email_address,
2439 name: user.name,
2440 })
2441 .collect();
2442 response.send(proto::UsersResponse { users })?;
2443 Ok(())
2444}
2445
2446/// Search for users (to invite) buy Github login
2447async fn fuzzy_search_users(
2448 request: proto::FuzzySearchUsers,
2449 response: Response<proto::FuzzySearchUsers>,
2450 session: Session,
2451) -> Result<()> {
2452 let query = request.query;
2453 let users = match query.len() {
2454 0 => vec![],
2455 1 | 2 => session
2456 .db()
2457 .await
2458 .get_user_by_github_login(&query)
2459 .await?
2460 .into_iter()
2461 .collect(),
2462 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2463 };
2464 let users = users
2465 .into_iter()
2466 .filter(|user| user.id != session.user_id())
2467 .map(|user| proto::User {
2468 id: user.id.to_proto(),
2469 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2470 github_login: user.github_login,
2471 name: user.name,
2472 email: user.email_address,
2473 })
2474 .collect();
2475 response.send(proto::UsersResponse { users })?;
2476 Ok(())
2477}
2478
2479/// Send a contact request to another user.
2480async fn request_contact(
2481 request: proto::RequestContact,
2482 response: Response<proto::RequestContact>,
2483 session: Session,
2484) -> Result<()> {
2485 let requester_id = session.user_id();
2486 let responder_id = UserId::from_proto(request.responder_id);
2487 if requester_id == responder_id {
2488 return Err(anyhow!("cannot add yourself as a contact"))?;
2489 }
2490
2491 let notifications = session
2492 .db()
2493 .await
2494 .send_contact_request(requester_id, responder_id)
2495 .await?;
2496
2497 // Update outgoing contact requests of requester
2498 let mut update = proto::UpdateContacts::default();
2499 update.outgoing_requests.push(responder_id.to_proto());
2500 for connection_id in session
2501 .connection_pool()
2502 .await
2503 .user_connection_ids(requester_id)
2504 {
2505 session.peer.send(connection_id, update.clone())?;
2506 }
2507
2508 // Update incoming contact requests of responder
2509 let mut update = proto::UpdateContacts::default();
2510 update
2511 .incoming_requests
2512 .push(proto::IncomingContactRequest {
2513 requester_id: requester_id.to_proto(),
2514 });
2515 let connection_pool = session.connection_pool().await;
2516 for connection_id in connection_pool.user_connection_ids(responder_id) {
2517 session.peer.send(connection_id, update.clone())?;
2518 }
2519
2520 send_notifications(&connection_pool, &session.peer, notifications);
2521
2522 response.send(proto::Ack {})?;
2523 Ok(())
2524}
2525
2526/// Accept or decline a contact request
2527async fn respond_to_contact_request(
2528 request: proto::RespondToContactRequest,
2529 response: Response<proto::RespondToContactRequest>,
2530 session: Session,
2531) -> Result<()> {
2532 let responder_id = session.user_id();
2533 let requester_id = UserId::from_proto(request.requester_id);
2534 let db = session.db().await;
2535 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2536 db.dismiss_contact_notification(responder_id, requester_id)
2537 .await?;
2538 } else {
2539 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2540
2541 let notifications = db
2542 .respond_to_contact_request(responder_id, requester_id, accept)
2543 .await?;
2544 let requester_busy = db.is_user_busy(requester_id).await?;
2545 let responder_busy = db.is_user_busy(responder_id).await?;
2546
2547 let pool = session.connection_pool().await;
2548 // Update responder with new contact
2549 let mut update = proto::UpdateContacts::default();
2550 if accept {
2551 update
2552 .contacts
2553 .push(contact_for_user(requester_id, requester_busy, &pool));
2554 }
2555 update
2556 .remove_incoming_requests
2557 .push(requester_id.to_proto());
2558 for connection_id in pool.user_connection_ids(responder_id) {
2559 session.peer.send(connection_id, update.clone())?;
2560 }
2561
2562 // Update requester with new contact
2563 let mut update = proto::UpdateContacts::default();
2564 if accept {
2565 update
2566 .contacts
2567 .push(contact_for_user(responder_id, responder_busy, &pool));
2568 }
2569 update
2570 .remove_outgoing_requests
2571 .push(responder_id.to_proto());
2572
2573 for connection_id in pool.user_connection_ids(requester_id) {
2574 session.peer.send(connection_id, update.clone())?;
2575 }
2576
2577 send_notifications(&pool, &session.peer, notifications);
2578 }
2579
2580 response.send(proto::Ack {})?;
2581 Ok(())
2582}
2583
2584/// Remove a contact.
2585async fn remove_contact(
2586 request: proto::RemoveContact,
2587 response: Response<proto::RemoveContact>,
2588 session: Session,
2589) -> Result<()> {
2590 let requester_id = session.user_id();
2591 let responder_id = UserId::from_proto(request.user_id);
2592 let db = session.db().await;
2593 let (contact_accepted, deleted_notification_id) =
2594 db.remove_contact(requester_id, responder_id).await?;
2595
2596 let pool = session.connection_pool().await;
2597 // Update outgoing contact requests of requester
2598 let mut update = proto::UpdateContacts::default();
2599 if contact_accepted {
2600 update.remove_contacts.push(responder_id.to_proto());
2601 } else {
2602 update
2603 .remove_outgoing_requests
2604 .push(responder_id.to_proto());
2605 }
2606 for connection_id in pool.user_connection_ids(requester_id) {
2607 session.peer.send(connection_id, update.clone())?;
2608 }
2609
2610 // Update incoming contact requests of responder
2611 let mut update = proto::UpdateContacts::default();
2612 if contact_accepted {
2613 update.remove_contacts.push(requester_id.to_proto());
2614 } else {
2615 update
2616 .remove_incoming_requests
2617 .push(requester_id.to_proto());
2618 }
2619 for connection_id in pool.user_connection_ids(responder_id) {
2620 session.peer.send(connection_id, update.clone())?;
2621 if let Some(notification_id) = deleted_notification_id {
2622 session.peer.send(
2623 connection_id,
2624 proto::DeleteNotification {
2625 notification_id: notification_id.to_proto(),
2626 },
2627 )?;
2628 }
2629 }
2630
2631 response.send(proto::Ack {})?;
2632 Ok(())
2633}
2634
2635fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2636 version.0.minor() < 139
2637}
2638
2639async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
2640 let plan = session.current_plan(&session.db().await).await?;
2641
2642 session
2643 .peer
2644 .send(
2645 session.connection_id,
2646 proto::UpdateUserPlan { plan: plan.into() },
2647 )
2648 .trace_err();
2649
2650 Ok(())
2651}
2652
2653async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2654 subscribe_user_to_channels(session.user_id(), &session).await?;
2655 Ok(())
2656}
2657
2658async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2659 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2660 let mut pool = session.connection_pool().await;
2661 for membership in &channels_for_user.channel_memberships {
2662 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2663 }
2664 session.peer.send(
2665 session.connection_id,
2666 build_update_user_channels(&channels_for_user),
2667 )?;
2668 session.peer.send(
2669 session.connection_id,
2670 build_channels_update(channels_for_user),
2671 )?;
2672 Ok(())
2673}
2674
2675/// Creates a new channel.
2676async fn create_channel(
2677 request: proto::CreateChannel,
2678 response: Response<proto::CreateChannel>,
2679 session: Session,
2680) -> Result<()> {
2681 let db = session.db().await;
2682
2683 let parent_id = request.parent_id.map(ChannelId::from_proto);
2684 let (channel, membership) = db
2685 .create_channel(&request.name, parent_id, session.user_id())
2686 .await?;
2687
2688 let root_id = channel.root_id();
2689 let channel = Channel::from_model(channel);
2690
2691 response.send(proto::CreateChannelResponse {
2692 channel: Some(channel.to_proto()),
2693 parent_id: request.parent_id,
2694 })?;
2695
2696 let mut connection_pool = session.connection_pool().await;
2697 if let Some(membership) = membership {
2698 connection_pool.subscribe_to_channel(
2699 membership.user_id,
2700 membership.channel_id,
2701 membership.role,
2702 );
2703 let update = proto::UpdateUserChannels {
2704 channel_memberships: vec![proto::ChannelMembership {
2705 channel_id: membership.channel_id.to_proto(),
2706 role: membership.role.into(),
2707 }],
2708 ..Default::default()
2709 };
2710 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2711 session.peer.send(connection_id, update.clone())?;
2712 }
2713 }
2714
2715 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2716 if !role.can_see_channel(channel.visibility) {
2717 continue;
2718 }
2719
2720 let update = proto::UpdateChannels {
2721 channels: vec![channel.to_proto()],
2722 ..Default::default()
2723 };
2724 session.peer.send(connection_id, update.clone())?;
2725 }
2726
2727 Ok(())
2728}
2729
2730/// Delete a channel
2731async fn delete_channel(
2732 request: proto::DeleteChannel,
2733 response: Response<proto::DeleteChannel>,
2734 session: Session,
2735) -> Result<()> {
2736 let db = session.db().await;
2737
2738 let channel_id = request.channel_id;
2739 let (root_channel, removed_channels) = db
2740 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2741 .await?;
2742 response.send(proto::Ack {})?;
2743
2744 // Notify members of removed channels
2745 let mut update = proto::UpdateChannels::default();
2746 update
2747 .delete_channels
2748 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2749
2750 let connection_pool = session.connection_pool().await;
2751 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2752 session.peer.send(connection_id, update.clone())?;
2753 }
2754
2755 Ok(())
2756}
2757
2758/// Invite someone to join a channel.
2759async fn invite_channel_member(
2760 request: proto::InviteChannelMember,
2761 response: Response<proto::InviteChannelMember>,
2762 session: Session,
2763) -> Result<()> {
2764 let db = session.db().await;
2765 let channel_id = ChannelId::from_proto(request.channel_id);
2766 let invitee_id = UserId::from_proto(request.user_id);
2767 let InviteMemberResult {
2768 channel,
2769 notifications,
2770 } = db
2771 .invite_channel_member(
2772 channel_id,
2773 invitee_id,
2774 session.user_id(),
2775 request.role().into(),
2776 )
2777 .await?;
2778
2779 let update = proto::UpdateChannels {
2780 channel_invitations: vec![channel.to_proto()],
2781 ..Default::default()
2782 };
2783
2784 let connection_pool = session.connection_pool().await;
2785 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2786 session.peer.send(connection_id, update.clone())?;
2787 }
2788
2789 send_notifications(&connection_pool, &session.peer, notifications);
2790
2791 response.send(proto::Ack {})?;
2792 Ok(())
2793}
2794
2795/// remove someone from a channel
2796async fn remove_channel_member(
2797 request: proto::RemoveChannelMember,
2798 response: Response<proto::RemoveChannelMember>,
2799 session: Session,
2800) -> Result<()> {
2801 let db = session.db().await;
2802 let channel_id = ChannelId::from_proto(request.channel_id);
2803 let member_id = UserId::from_proto(request.user_id);
2804
2805 let RemoveChannelMemberResult {
2806 membership_update,
2807 notification_id,
2808 } = db
2809 .remove_channel_member(channel_id, member_id, session.user_id())
2810 .await?;
2811
2812 let mut connection_pool = session.connection_pool().await;
2813 notify_membership_updated(
2814 &mut connection_pool,
2815 membership_update,
2816 member_id,
2817 &session.peer,
2818 );
2819 for connection_id in connection_pool.user_connection_ids(member_id) {
2820 if let Some(notification_id) = notification_id {
2821 session
2822 .peer
2823 .send(
2824 connection_id,
2825 proto::DeleteNotification {
2826 notification_id: notification_id.to_proto(),
2827 },
2828 )
2829 .trace_err();
2830 }
2831 }
2832
2833 response.send(proto::Ack {})?;
2834 Ok(())
2835}
2836
2837/// Toggle the channel between public and private.
2838/// Care is taken to maintain the invariant that public channels only descend from public channels,
2839/// (though members-only channels can appear at any point in the hierarchy).
2840async fn set_channel_visibility(
2841 request: proto::SetChannelVisibility,
2842 response: Response<proto::SetChannelVisibility>,
2843 session: Session,
2844) -> Result<()> {
2845 let db = session.db().await;
2846 let channel_id = ChannelId::from_proto(request.channel_id);
2847 let visibility = request.visibility().into();
2848
2849 let channel_model = db
2850 .set_channel_visibility(channel_id, visibility, session.user_id())
2851 .await?;
2852 let root_id = channel_model.root_id();
2853 let channel = Channel::from_model(channel_model);
2854
2855 let mut connection_pool = session.connection_pool().await;
2856 for (user_id, role) in connection_pool
2857 .channel_user_ids(root_id)
2858 .collect::<Vec<_>>()
2859 .into_iter()
2860 {
2861 let update = if role.can_see_channel(channel.visibility) {
2862 connection_pool.subscribe_to_channel(user_id, channel_id, role);
2863 proto::UpdateChannels {
2864 channels: vec![channel.to_proto()],
2865 ..Default::default()
2866 }
2867 } else {
2868 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2869 proto::UpdateChannels {
2870 delete_channels: vec![channel.id.to_proto()],
2871 ..Default::default()
2872 }
2873 };
2874
2875 for connection_id in connection_pool.user_connection_ids(user_id) {
2876 session.peer.send(connection_id, update.clone())?;
2877 }
2878 }
2879
2880 response.send(proto::Ack {})?;
2881 Ok(())
2882}
2883
2884/// Alter the role for a user in the channel.
2885async fn set_channel_member_role(
2886 request: proto::SetChannelMemberRole,
2887 response: Response<proto::SetChannelMemberRole>,
2888 session: Session,
2889) -> Result<()> {
2890 let db = session.db().await;
2891 let channel_id = ChannelId::from_proto(request.channel_id);
2892 let member_id = UserId::from_proto(request.user_id);
2893 let result = db
2894 .set_channel_member_role(
2895 channel_id,
2896 session.user_id(),
2897 member_id,
2898 request.role().into(),
2899 )
2900 .await?;
2901
2902 match result {
2903 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2904 let mut connection_pool = session.connection_pool().await;
2905 notify_membership_updated(
2906 &mut connection_pool,
2907 membership_update,
2908 member_id,
2909 &session.peer,
2910 )
2911 }
2912 db::SetMemberRoleResult::InviteUpdated(channel) => {
2913 let update = proto::UpdateChannels {
2914 channel_invitations: vec![channel.to_proto()],
2915 ..Default::default()
2916 };
2917
2918 for connection_id in session
2919 .connection_pool()
2920 .await
2921 .user_connection_ids(member_id)
2922 {
2923 session.peer.send(connection_id, update.clone())?;
2924 }
2925 }
2926 }
2927
2928 response.send(proto::Ack {})?;
2929 Ok(())
2930}
2931
2932/// Change the name of a channel
2933async fn rename_channel(
2934 request: proto::RenameChannel,
2935 response: Response<proto::RenameChannel>,
2936 session: Session,
2937) -> Result<()> {
2938 let db = session.db().await;
2939 let channel_id = ChannelId::from_proto(request.channel_id);
2940 let channel_model = db
2941 .rename_channel(channel_id, session.user_id(), &request.name)
2942 .await?;
2943 let root_id = channel_model.root_id();
2944 let channel = Channel::from_model(channel_model);
2945
2946 response.send(proto::RenameChannelResponse {
2947 channel: Some(channel.to_proto()),
2948 })?;
2949
2950 let connection_pool = session.connection_pool().await;
2951 let update = proto::UpdateChannels {
2952 channels: vec![channel.to_proto()],
2953 ..Default::default()
2954 };
2955 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2956 if role.can_see_channel(channel.visibility) {
2957 session.peer.send(connection_id, update.clone())?;
2958 }
2959 }
2960
2961 Ok(())
2962}
2963
2964/// Move a channel to a new parent.
2965async fn move_channel(
2966 request: proto::MoveChannel,
2967 response: Response<proto::MoveChannel>,
2968 session: Session,
2969) -> Result<()> {
2970 let channel_id = ChannelId::from_proto(request.channel_id);
2971 let to = ChannelId::from_proto(request.to);
2972
2973 let (root_id, channels) = session
2974 .db()
2975 .await
2976 .move_channel(channel_id, to, session.user_id())
2977 .await?;
2978
2979 let connection_pool = session.connection_pool().await;
2980 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2981 let channels = channels
2982 .iter()
2983 .filter_map(|channel| {
2984 if role.can_see_channel(channel.visibility) {
2985 Some(channel.to_proto())
2986 } else {
2987 None
2988 }
2989 })
2990 .collect::<Vec<_>>();
2991 if channels.is_empty() {
2992 continue;
2993 }
2994
2995 let update = proto::UpdateChannels {
2996 channels,
2997 ..Default::default()
2998 };
2999
3000 session.peer.send(connection_id, update.clone())?;
3001 }
3002
3003 response.send(Ack {})?;
3004 Ok(())
3005}
3006
3007/// Get the list of channel members
3008async fn get_channel_members(
3009 request: proto::GetChannelMembers,
3010 response: Response<proto::GetChannelMembers>,
3011 session: Session,
3012) -> Result<()> {
3013 let db = session.db().await;
3014 let channel_id = ChannelId::from_proto(request.channel_id);
3015 let limit = if request.limit == 0 {
3016 u16::MAX as u64
3017 } else {
3018 request.limit
3019 };
3020 let (members, users) = db
3021 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3022 .await?;
3023 response.send(proto::GetChannelMembersResponse { members, users })?;
3024 Ok(())
3025}
3026
3027/// Accept or decline a channel invitation.
3028async fn respond_to_channel_invite(
3029 request: proto::RespondToChannelInvite,
3030 response: Response<proto::RespondToChannelInvite>,
3031 session: Session,
3032) -> Result<()> {
3033 let db = session.db().await;
3034 let channel_id = ChannelId::from_proto(request.channel_id);
3035 let RespondToChannelInvite {
3036 membership_update,
3037 notifications,
3038 } = db
3039 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3040 .await?;
3041
3042 let mut connection_pool = session.connection_pool().await;
3043 if let Some(membership_update) = membership_update {
3044 notify_membership_updated(
3045 &mut connection_pool,
3046 membership_update,
3047 session.user_id(),
3048 &session.peer,
3049 );
3050 } else {
3051 let update = proto::UpdateChannels {
3052 remove_channel_invitations: vec![channel_id.to_proto()],
3053 ..Default::default()
3054 };
3055
3056 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3057 session.peer.send(connection_id, update.clone())?;
3058 }
3059 };
3060
3061 send_notifications(&connection_pool, &session.peer, notifications);
3062
3063 response.send(proto::Ack {})?;
3064
3065 Ok(())
3066}
3067
3068/// Join the channels' room
3069async fn join_channel(
3070 request: proto::JoinChannel,
3071 response: Response<proto::JoinChannel>,
3072 session: Session,
3073) -> Result<()> {
3074 let channel_id = ChannelId::from_proto(request.channel_id);
3075 join_channel_internal(channel_id, Box::new(response), session).await
3076}
3077
3078trait JoinChannelInternalResponse {
3079 fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
3080}
3081impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
3082 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3083 Response::<proto::JoinChannel>::send(self, result)
3084 }
3085}
3086impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
3087 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3088 Response::<proto::JoinRoom>::send(self, result)
3089 }
3090}
3091
3092async fn join_channel_internal(
3093 channel_id: ChannelId,
3094 response: Box<impl JoinChannelInternalResponse>,
3095 session: Session,
3096) -> Result<()> {
3097 let joined_room = {
3098 let mut db = session.db().await;
3099 // If zed quits without leaving the room, and the user re-opens zed before the
3100 // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
3101 // room they were in.
3102 if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
3103 tracing::info!(
3104 stale_connection_id = %connection,
3105 "cleaning up stale connection",
3106 );
3107 drop(db);
3108 leave_room_for_session(&session, connection).await?;
3109 db = session.db().await;
3110 }
3111
3112 let (joined_room, membership_updated, role) = db
3113 .join_channel(channel_id, session.user_id(), session.connection_id)
3114 .await?;
3115
3116 let live_kit_connection_info =
3117 session
3118 .app_state
3119 .livekit_client
3120 .as_ref()
3121 .and_then(|live_kit| {
3122 let (can_publish, token) = if role == ChannelRole::Guest {
3123 (
3124 false,
3125 live_kit
3126 .guest_token(
3127 &joined_room.room.livekit_room,
3128 &session.user_id().to_string(),
3129 )
3130 .trace_err()?,
3131 )
3132 } else {
3133 (
3134 true,
3135 live_kit
3136 .room_token(
3137 &joined_room.room.livekit_room,
3138 &session.user_id().to_string(),
3139 )
3140 .trace_err()?,
3141 )
3142 };
3143
3144 Some(LiveKitConnectionInfo {
3145 server_url: live_kit.url().into(),
3146 token,
3147 can_publish,
3148 })
3149 });
3150
3151 response.send(proto::JoinRoomResponse {
3152 room: Some(joined_room.room.clone()),
3153 channel_id: joined_room
3154 .channel
3155 .as_ref()
3156 .map(|channel| channel.id.to_proto()),
3157 live_kit_connection_info,
3158 })?;
3159
3160 let mut connection_pool = session.connection_pool().await;
3161 if let Some(membership_updated) = membership_updated {
3162 notify_membership_updated(
3163 &mut connection_pool,
3164 membership_updated,
3165 session.user_id(),
3166 &session.peer,
3167 );
3168 }
3169
3170 room_updated(&joined_room.room, &session.peer);
3171
3172 joined_room
3173 };
3174
3175 channel_updated(
3176 &joined_room
3177 .channel
3178 .ok_or_else(|| anyhow!("channel not returned"))?,
3179 &joined_room.room,
3180 &session.peer,
3181 &*session.connection_pool().await,
3182 );
3183
3184 update_user_contacts(session.user_id(), &session).await?;
3185 Ok(())
3186}
3187
3188/// Start editing the channel notes
3189async fn join_channel_buffer(
3190 request: proto::JoinChannelBuffer,
3191 response: Response<proto::JoinChannelBuffer>,
3192 session: Session,
3193) -> Result<()> {
3194 let db = session.db().await;
3195 let channel_id = ChannelId::from_proto(request.channel_id);
3196
3197 let open_response = db
3198 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3199 .await?;
3200
3201 let collaborators = open_response.collaborators.clone();
3202 response.send(open_response)?;
3203
3204 let update = UpdateChannelBufferCollaborators {
3205 channel_id: channel_id.to_proto(),
3206 collaborators: collaborators.clone(),
3207 };
3208 channel_buffer_updated(
3209 session.connection_id,
3210 collaborators
3211 .iter()
3212 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3213 &update,
3214 &session.peer,
3215 );
3216
3217 Ok(())
3218}
3219
3220/// Edit the channel notes
3221async fn update_channel_buffer(
3222 request: proto::UpdateChannelBuffer,
3223 session: Session,
3224) -> Result<()> {
3225 let db = session.db().await;
3226 let channel_id = ChannelId::from_proto(request.channel_id);
3227
3228 let (collaborators, epoch, version) = db
3229 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3230 .await?;
3231
3232 channel_buffer_updated(
3233 session.connection_id,
3234 collaborators.clone(),
3235 &proto::UpdateChannelBuffer {
3236 channel_id: channel_id.to_proto(),
3237 operations: request.operations,
3238 },
3239 &session.peer,
3240 );
3241
3242 let pool = &*session.connection_pool().await;
3243
3244 let non_collaborators =
3245 pool.channel_connection_ids(channel_id)
3246 .filter_map(|(connection_id, _)| {
3247 if collaborators.contains(&connection_id) {
3248 None
3249 } else {
3250 Some(connection_id)
3251 }
3252 });
3253
3254 broadcast(None, non_collaborators, |peer_id| {
3255 session.peer.send(
3256 peer_id,
3257 proto::UpdateChannels {
3258 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3259 channel_id: channel_id.to_proto(),
3260 epoch: epoch as u64,
3261 version: version.clone(),
3262 }],
3263 ..Default::default()
3264 },
3265 )
3266 });
3267
3268 Ok(())
3269}
3270
3271/// Rejoin the channel notes after a connection blip
3272async fn rejoin_channel_buffers(
3273 request: proto::RejoinChannelBuffers,
3274 response: Response<proto::RejoinChannelBuffers>,
3275 session: Session,
3276) -> Result<()> {
3277 let db = session.db().await;
3278 let buffers = db
3279 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3280 .await?;
3281
3282 for rejoined_buffer in &buffers {
3283 let collaborators_to_notify = rejoined_buffer
3284 .buffer
3285 .collaborators
3286 .iter()
3287 .filter_map(|c| Some(c.peer_id?.into()));
3288 channel_buffer_updated(
3289 session.connection_id,
3290 collaborators_to_notify,
3291 &proto::UpdateChannelBufferCollaborators {
3292 channel_id: rejoined_buffer.buffer.channel_id,
3293 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3294 },
3295 &session.peer,
3296 );
3297 }
3298
3299 response.send(proto::RejoinChannelBuffersResponse {
3300 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3301 })?;
3302
3303 Ok(())
3304}
3305
3306/// Stop editing the channel notes
3307async fn leave_channel_buffer(
3308 request: proto::LeaveChannelBuffer,
3309 response: Response<proto::LeaveChannelBuffer>,
3310 session: Session,
3311) -> Result<()> {
3312 let db = session.db().await;
3313 let channel_id = ChannelId::from_proto(request.channel_id);
3314
3315 let left_buffer = db
3316 .leave_channel_buffer(channel_id, session.connection_id)
3317 .await?;
3318
3319 response.send(Ack {})?;
3320
3321 channel_buffer_updated(
3322 session.connection_id,
3323 left_buffer.connections,
3324 &proto::UpdateChannelBufferCollaborators {
3325 channel_id: channel_id.to_proto(),
3326 collaborators: left_buffer.collaborators,
3327 },
3328 &session.peer,
3329 );
3330
3331 Ok(())
3332}
3333
3334fn channel_buffer_updated<T: EnvelopedMessage>(
3335 sender_id: ConnectionId,
3336 collaborators: impl IntoIterator<Item = ConnectionId>,
3337 message: &T,
3338 peer: &Peer,
3339) {
3340 broadcast(Some(sender_id), collaborators, |peer_id| {
3341 peer.send(peer_id, message.clone())
3342 });
3343}
3344
3345fn send_notifications(
3346 connection_pool: &ConnectionPool,
3347 peer: &Peer,
3348 notifications: db::NotificationBatch,
3349) {
3350 for (user_id, notification) in notifications {
3351 for connection_id in connection_pool.user_connection_ids(user_id) {
3352 if let Err(error) = peer.send(
3353 connection_id,
3354 proto::AddNotification {
3355 notification: Some(notification.clone()),
3356 },
3357 ) {
3358 tracing::error!(
3359 "failed to send notification to {:?} {}",
3360 connection_id,
3361 error
3362 );
3363 }
3364 }
3365 }
3366}
3367
3368/// Send a message to the channel
3369async fn send_channel_message(
3370 request: proto::SendChannelMessage,
3371 response: Response<proto::SendChannelMessage>,
3372 session: Session,
3373) -> Result<()> {
3374 // Validate the message body.
3375 let body = request.body.trim().to_string();
3376 if body.len() > MAX_MESSAGE_LEN {
3377 return Err(anyhow!("message is too long"))?;
3378 }
3379 if body.is_empty() {
3380 return Err(anyhow!("message can't be blank"))?;
3381 }
3382
3383 // TODO: adjust mentions if body is trimmed
3384
3385 let timestamp = OffsetDateTime::now_utc();
3386 let nonce = request
3387 .nonce
3388 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3389
3390 let channel_id = ChannelId::from_proto(request.channel_id);
3391 let CreatedChannelMessage {
3392 message_id,
3393 participant_connection_ids,
3394 notifications,
3395 } = session
3396 .db()
3397 .await
3398 .create_channel_message(
3399 channel_id,
3400 session.user_id(),
3401 &body,
3402 &request.mentions,
3403 timestamp,
3404 nonce.clone().into(),
3405 request.reply_to_message_id.map(MessageId::from_proto),
3406 )
3407 .await?;
3408
3409 let message = proto::ChannelMessage {
3410 sender_id: session.user_id().to_proto(),
3411 id: message_id.to_proto(),
3412 body,
3413 mentions: request.mentions,
3414 timestamp: timestamp.unix_timestamp() as u64,
3415 nonce: Some(nonce),
3416 reply_to_message_id: request.reply_to_message_id,
3417 edited_at: None,
3418 };
3419 broadcast(
3420 Some(session.connection_id),
3421 participant_connection_ids.clone(),
3422 |connection| {
3423 session.peer.send(
3424 connection,
3425 proto::ChannelMessageSent {
3426 channel_id: channel_id.to_proto(),
3427 message: Some(message.clone()),
3428 },
3429 )
3430 },
3431 );
3432 response.send(proto::SendChannelMessageResponse {
3433 message: Some(message),
3434 })?;
3435
3436 let pool = &*session.connection_pool().await;
3437 let non_participants =
3438 pool.channel_connection_ids(channel_id)
3439 .filter_map(|(connection_id, _)| {
3440 if participant_connection_ids.contains(&connection_id) {
3441 None
3442 } else {
3443 Some(connection_id)
3444 }
3445 });
3446 broadcast(None, non_participants, |peer_id| {
3447 session.peer.send(
3448 peer_id,
3449 proto::UpdateChannels {
3450 latest_channel_message_ids: vec![proto::ChannelMessageId {
3451 channel_id: channel_id.to_proto(),
3452 message_id: message_id.to_proto(),
3453 }],
3454 ..Default::default()
3455 },
3456 )
3457 });
3458 send_notifications(pool, &session.peer, notifications);
3459
3460 Ok(())
3461}
3462
3463/// Delete a channel message
3464async fn remove_channel_message(
3465 request: proto::RemoveChannelMessage,
3466 response: Response<proto::RemoveChannelMessage>,
3467 session: Session,
3468) -> Result<()> {
3469 let channel_id = ChannelId::from_proto(request.channel_id);
3470 let message_id = MessageId::from_proto(request.message_id);
3471 let (connection_ids, existing_notification_ids) = session
3472 .db()
3473 .await
3474 .remove_channel_message(channel_id, message_id, session.user_id())
3475 .await?;
3476
3477 broadcast(
3478 Some(session.connection_id),
3479 connection_ids,
3480 move |connection| {
3481 session.peer.send(connection, request.clone())?;
3482
3483 for notification_id in &existing_notification_ids {
3484 session.peer.send(
3485 connection,
3486 proto::DeleteNotification {
3487 notification_id: (*notification_id).to_proto(),
3488 },
3489 )?;
3490 }
3491
3492 Ok(())
3493 },
3494 );
3495 response.send(proto::Ack {})?;
3496 Ok(())
3497}
3498
3499async fn update_channel_message(
3500 request: proto::UpdateChannelMessage,
3501 response: Response<proto::UpdateChannelMessage>,
3502 session: Session,
3503) -> Result<()> {
3504 let channel_id = ChannelId::from_proto(request.channel_id);
3505 let message_id = MessageId::from_proto(request.message_id);
3506 let updated_at = OffsetDateTime::now_utc();
3507 let UpdatedChannelMessage {
3508 message_id,
3509 participant_connection_ids,
3510 notifications,
3511 reply_to_message_id,
3512 timestamp,
3513 deleted_mention_notification_ids,
3514 updated_mention_notifications,
3515 } = session
3516 .db()
3517 .await
3518 .update_channel_message(
3519 channel_id,
3520 message_id,
3521 session.user_id(),
3522 request.body.as_str(),
3523 &request.mentions,
3524 updated_at,
3525 )
3526 .await?;
3527
3528 let nonce = request
3529 .nonce
3530 .clone()
3531 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3532
3533 let message = proto::ChannelMessage {
3534 sender_id: session.user_id().to_proto(),
3535 id: message_id.to_proto(),
3536 body: request.body.clone(),
3537 mentions: request.mentions.clone(),
3538 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3539 nonce: Some(nonce),
3540 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3541 edited_at: Some(updated_at.unix_timestamp() as u64),
3542 };
3543
3544 response.send(proto::Ack {})?;
3545
3546 let pool = &*session.connection_pool().await;
3547 broadcast(
3548 Some(session.connection_id),
3549 participant_connection_ids,
3550 |connection| {
3551 session.peer.send(
3552 connection,
3553 proto::ChannelMessageUpdate {
3554 channel_id: channel_id.to_proto(),
3555 message: Some(message.clone()),
3556 },
3557 )?;
3558
3559 for notification_id in &deleted_mention_notification_ids {
3560 session.peer.send(
3561 connection,
3562 proto::DeleteNotification {
3563 notification_id: (*notification_id).to_proto(),
3564 },
3565 )?;
3566 }
3567
3568 for notification in &updated_mention_notifications {
3569 session.peer.send(
3570 connection,
3571 proto::UpdateNotification {
3572 notification: Some(notification.clone()),
3573 },
3574 )?;
3575 }
3576
3577 Ok(())
3578 },
3579 );
3580
3581 send_notifications(pool, &session.peer, notifications);
3582
3583 Ok(())
3584}
3585
3586/// Mark a channel message as read
3587async fn acknowledge_channel_message(
3588 request: proto::AckChannelMessage,
3589 session: Session,
3590) -> Result<()> {
3591 let channel_id = ChannelId::from_proto(request.channel_id);
3592 let message_id = MessageId::from_proto(request.message_id);
3593 let notifications = session
3594 .db()
3595 .await
3596 .observe_channel_message(channel_id, session.user_id(), message_id)
3597 .await?;
3598 send_notifications(
3599 &*session.connection_pool().await,
3600 &session.peer,
3601 notifications,
3602 );
3603 Ok(())
3604}
3605
3606/// Mark a buffer version as synced
3607async fn acknowledge_buffer_version(
3608 request: proto::AckBufferOperation,
3609 session: Session,
3610) -> Result<()> {
3611 let buffer_id = BufferId::from_proto(request.buffer_id);
3612 session
3613 .db()
3614 .await
3615 .observe_buffer_version(
3616 buffer_id,
3617 session.user_id(),
3618 request.epoch as i32,
3619 &request.version,
3620 )
3621 .await?;
3622 Ok(())
3623}
3624
3625async fn count_language_model_tokens(
3626 request: proto::CountLanguageModelTokens,
3627 response: Response<proto::CountLanguageModelTokens>,
3628 session: Session,
3629 config: &Config,
3630) -> Result<()> {
3631 authorize_access_to_legacy_llm_endpoints(&session).await?;
3632
3633 let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3634 proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
3635 proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
3636 };
3637
3638 session
3639 .app_state
3640 .rate_limiter
3641 .check(&*rate_limit, session.user_id())
3642 .await?;
3643
3644 let result = match proto::LanguageModelProvider::from_i32(request.provider) {
3645 Some(proto::LanguageModelProvider::Google) => {
3646 let api_key = config
3647 .google_ai_api_key
3648 .as_ref()
3649 .context("no Google AI API key configured on the server")?;
3650 google_ai::count_tokens(
3651 session.http_client.as_ref(),
3652 google_ai::API_URL,
3653 api_key,
3654 serde_json::from_str(&request.request)?,
3655 )
3656 .await?
3657 }
3658 _ => return Err(anyhow!("unsupported provider"))?,
3659 };
3660
3661 response.send(proto::CountLanguageModelTokensResponse {
3662 token_count: result.total_tokens as u32,
3663 })?;
3664
3665 Ok(())
3666}
3667
3668struct ZedProCountLanguageModelTokensRateLimit;
3669
3670impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
3671 fn capacity(&self) -> usize {
3672 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
3673 .ok()
3674 .and_then(|v| v.parse().ok())
3675 .unwrap_or(600) // Picked arbitrarily
3676 }
3677
3678 fn refill_duration(&self) -> chrono::Duration {
3679 chrono::Duration::hours(1)
3680 }
3681
3682 fn db_name(&self) -> &'static str {
3683 "zed-pro:count-language-model-tokens"
3684 }
3685}
3686
3687struct FreeCountLanguageModelTokensRateLimit;
3688
3689impl RateLimit for FreeCountLanguageModelTokensRateLimit {
3690 fn capacity(&self) -> usize {
3691 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
3692 .ok()
3693 .and_then(|v| v.parse().ok())
3694 .unwrap_or(600 / 10) // Picked arbitrarily
3695 }
3696
3697 fn refill_duration(&self) -> chrono::Duration {
3698 chrono::Duration::hours(1)
3699 }
3700
3701 fn db_name(&self) -> &'static str {
3702 "free:count-language-model-tokens"
3703 }
3704}
3705
3706struct ZedProComputeEmbeddingsRateLimit;
3707
3708impl RateLimit for ZedProComputeEmbeddingsRateLimit {
3709 fn capacity(&self) -> usize {
3710 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
3711 .ok()
3712 .and_then(|v| v.parse().ok())
3713 .unwrap_or(5000) // Picked arbitrarily
3714 }
3715
3716 fn refill_duration(&self) -> chrono::Duration {
3717 chrono::Duration::hours(1)
3718 }
3719
3720 fn db_name(&self) -> &'static str {
3721 "zed-pro:compute-embeddings"
3722 }
3723}
3724
3725struct FreeComputeEmbeddingsRateLimit;
3726
3727impl RateLimit for FreeComputeEmbeddingsRateLimit {
3728 fn capacity(&self) -> usize {
3729 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
3730 .ok()
3731 .and_then(|v| v.parse().ok())
3732 .unwrap_or(5000 / 10) // Picked arbitrarily
3733 }
3734
3735 fn refill_duration(&self) -> chrono::Duration {
3736 chrono::Duration::hours(1)
3737 }
3738
3739 fn db_name(&self) -> &'static str {
3740 "free:compute-embeddings"
3741 }
3742}
3743
3744async fn compute_embeddings(
3745 request: proto::ComputeEmbeddings,
3746 response: Response<proto::ComputeEmbeddings>,
3747 session: Session,
3748 api_key: Option<Arc<str>>,
3749) -> Result<()> {
3750 let api_key = api_key.context("no OpenAI API key configured on the server")?;
3751 authorize_access_to_legacy_llm_endpoints(&session).await?;
3752
3753 let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3754 proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
3755 proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
3756 };
3757
3758 session
3759 .app_state
3760 .rate_limiter
3761 .check(&*rate_limit, session.user_id())
3762 .await?;
3763
3764 let embeddings = match request.model.as_str() {
3765 "openai/text-embedding-3-small" => {
3766 open_ai::embed(
3767 session.http_client.as_ref(),
3768 OPEN_AI_API_URL,
3769 &api_key,
3770 OpenAiEmbeddingModel::TextEmbedding3Small,
3771 request.texts.iter().map(|text| text.as_str()),
3772 )
3773 .await?
3774 }
3775 provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
3776 };
3777
3778 let embeddings = request
3779 .texts
3780 .iter()
3781 .map(|text| {
3782 let mut hasher = sha2::Sha256::new();
3783 hasher.update(text.as_bytes());
3784 let result = hasher.finalize();
3785 result.to_vec()
3786 })
3787 .zip(
3788 embeddings
3789 .data
3790 .into_iter()
3791 .map(|embedding| embedding.embedding),
3792 )
3793 .collect::<HashMap<_, _>>();
3794
3795 let db = session.db().await;
3796 db.save_embeddings(&request.model, &embeddings)
3797 .await
3798 .context("failed to save embeddings")
3799 .trace_err();
3800
3801 response.send(proto::ComputeEmbeddingsResponse {
3802 embeddings: embeddings
3803 .into_iter()
3804 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3805 .collect(),
3806 })?;
3807 Ok(())
3808}
3809
3810async fn get_cached_embeddings(
3811 request: proto::GetCachedEmbeddings,
3812 response: Response<proto::GetCachedEmbeddings>,
3813 session: Session,
3814) -> Result<()> {
3815 authorize_access_to_legacy_llm_endpoints(&session).await?;
3816
3817 let db = session.db().await;
3818 let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
3819
3820 response.send(proto::GetCachedEmbeddingsResponse {
3821 embeddings: embeddings
3822 .into_iter()
3823 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3824 .collect(),
3825 })?;
3826 Ok(())
3827}
3828
3829/// This is leftover from before the LLM service.
3830///
3831/// The endpoints protected by this check will be moved there eventually.
3832async fn authorize_access_to_legacy_llm_endpoints(session: &Session) -> Result<(), Error> {
3833 if session.is_staff() {
3834 Ok(())
3835 } else {
3836 Err(anyhow!("permission denied"))?
3837 }
3838}
3839
3840/// Get a Supermaven API key for the user
3841async fn get_supermaven_api_key(
3842 _request: proto::GetSupermavenApiKey,
3843 response: Response<proto::GetSupermavenApiKey>,
3844 session: Session,
3845) -> Result<()> {
3846 let user_id: String = session.user_id().to_string();
3847 if !session.is_staff() {
3848 return Err(anyhow!("supermaven not enabled for this account"))?;
3849 }
3850
3851 let email = session
3852 .email()
3853 .ok_or_else(|| anyhow!("user must have an email"))?;
3854
3855 let supermaven_admin_api = session
3856 .supermaven_client
3857 .as_ref()
3858 .ok_or_else(|| anyhow!("supermaven not configured"))?;
3859
3860 let result = supermaven_admin_api
3861 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3862 .await?;
3863
3864 response.send(proto::GetSupermavenApiKeyResponse {
3865 api_key: result.api_key,
3866 })?;
3867
3868 Ok(())
3869}
3870
3871/// Start receiving chat updates for a channel
3872async fn join_channel_chat(
3873 request: proto::JoinChannelChat,
3874 response: Response<proto::JoinChannelChat>,
3875 session: Session,
3876) -> Result<()> {
3877 let channel_id = ChannelId::from_proto(request.channel_id);
3878
3879 let db = session.db().await;
3880 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3881 .await?;
3882 let messages = db
3883 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3884 .await?;
3885 response.send(proto::JoinChannelChatResponse {
3886 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3887 messages,
3888 })?;
3889 Ok(())
3890}
3891
3892/// Stop receiving chat updates for a channel
3893async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3894 let channel_id = ChannelId::from_proto(request.channel_id);
3895 session
3896 .db()
3897 .await
3898 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3899 .await?;
3900 Ok(())
3901}
3902
3903/// Retrieve the chat history for a channel
3904async fn get_channel_messages(
3905 request: proto::GetChannelMessages,
3906 response: Response<proto::GetChannelMessages>,
3907 session: Session,
3908) -> Result<()> {
3909 let channel_id = ChannelId::from_proto(request.channel_id);
3910 let messages = session
3911 .db()
3912 .await
3913 .get_channel_messages(
3914 channel_id,
3915 session.user_id(),
3916 MESSAGE_COUNT_PER_PAGE,
3917 Some(MessageId::from_proto(request.before_message_id)),
3918 )
3919 .await?;
3920 response.send(proto::GetChannelMessagesResponse {
3921 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3922 messages,
3923 })?;
3924 Ok(())
3925}
3926
3927/// Retrieve specific chat messages
3928async fn get_channel_messages_by_id(
3929 request: proto::GetChannelMessagesById,
3930 response: Response<proto::GetChannelMessagesById>,
3931 session: Session,
3932) -> Result<()> {
3933 let message_ids = request
3934 .message_ids
3935 .iter()
3936 .map(|id| MessageId::from_proto(*id))
3937 .collect::<Vec<_>>();
3938 let messages = session
3939 .db()
3940 .await
3941 .get_channel_messages_by_id(session.user_id(), &message_ids)
3942 .await?;
3943 response.send(proto::GetChannelMessagesResponse {
3944 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3945 messages,
3946 })?;
3947 Ok(())
3948}
3949
3950/// Retrieve the current users notifications
3951async fn get_notifications(
3952 request: proto::GetNotifications,
3953 response: Response<proto::GetNotifications>,
3954 session: Session,
3955) -> Result<()> {
3956 let notifications = session
3957 .db()
3958 .await
3959 .get_notifications(
3960 session.user_id(),
3961 NOTIFICATION_COUNT_PER_PAGE,
3962 request.before_id.map(db::NotificationId::from_proto),
3963 )
3964 .await?;
3965 response.send(proto::GetNotificationsResponse {
3966 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3967 notifications,
3968 })?;
3969 Ok(())
3970}
3971
3972/// Mark notifications as read
3973async fn mark_notification_as_read(
3974 request: proto::MarkNotificationRead,
3975 response: Response<proto::MarkNotificationRead>,
3976 session: Session,
3977) -> Result<()> {
3978 let database = &session.db().await;
3979 let notifications = database
3980 .mark_notification_as_read_by_id(
3981 session.user_id(),
3982 NotificationId::from_proto(request.notification_id),
3983 )
3984 .await?;
3985 send_notifications(
3986 &*session.connection_pool().await,
3987 &session.peer,
3988 notifications,
3989 );
3990 response.send(proto::Ack {})?;
3991 Ok(())
3992}
3993
3994/// Get the current users information
3995async fn get_private_user_info(
3996 _request: proto::GetPrivateUserInfo,
3997 response: Response<proto::GetPrivateUserInfo>,
3998 session: Session,
3999) -> Result<()> {
4000 let db = session.db().await;
4001
4002 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4003 let user = db
4004 .get_user_by_id(session.user_id())
4005 .await?
4006 .ok_or_else(|| anyhow!("user not found"))?;
4007 let flags = db.get_user_flags(session.user_id()).await?;
4008
4009 response.send(proto::GetPrivateUserInfoResponse {
4010 metrics_id,
4011 staff: user.admin,
4012 flags,
4013 accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4014 })?;
4015 Ok(())
4016}
4017
4018/// Accept the terms of service (tos) on behalf of the current user
4019async fn accept_terms_of_service(
4020 _request: proto::AcceptTermsOfService,
4021 response: Response<proto::AcceptTermsOfService>,
4022 session: Session,
4023) -> Result<()> {
4024 let db = session.db().await;
4025
4026 let accepted_tos_at = Utc::now();
4027 db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4028 .await?;
4029
4030 response.send(proto::AcceptTermsOfServiceResponse {
4031 accepted_tos_at: accepted_tos_at.timestamp() as u64,
4032 })?;
4033 Ok(())
4034}
4035
4036/// The minimum account age an account must have in order to use the LLM service.
4037const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4038
4039async fn get_llm_api_token(
4040 _request: proto::GetLlmToken,
4041 response: Response<proto::GetLlmToken>,
4042 session: Session,
4043) -> Result<()> {
4044 let db = session.db().await;
4045
4046 let flags = db.get_user_flags(session.user_id()).await?;
4047 let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
4048 let has_llm_closed_beta_feature_flag = flags.iter().any(|flag| flag == "llm-closed-beta");
4049 let has_predict_edits_feature_flag = flags.iter().any(|flag| flag == "predict-edits");
4050
4051 if !session.is_staff() && !has_language_models_feature_flag {
4052 Err(anyhow!("permission denied"))?
4053 }
4054
4055 let user_id = session.user_id();
4056 let user = db
4057 .get_user_by_id(user_id)
4058 .await?
4059 .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4060
4061 if user.accepted_tos_at.is_none() {
4062 Err(anyhow!("terms of service not accepted"))?
4063 }
4064
4065 let has_llm_subscription = session.has_llm_subscription(&db).await?;
4066
4067 let bypass_account_age_check =
4068 has_llm_subscription || flags.iter().any(|flag| flag == "bypass-account-age-check");
4069 if !bypass_account_age_check {
4070 let mut account_created_at = user.created_at;
4071 if let Some(github_created_at) = user.github_user_created_at {
4072 account_created_at = account_created_at.min(github_created_at);
4073 }
4074 if Utc::now().naive_utc() - account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE {
4075 Err(anyhow!("account too young"))?
4076 }
4077 }
4078
4079 let billing_preferences = db.get_billing_preferences(user.id).await?;
4080
4081 let token = LlmTokenClaims::create(
4082 &user,
4083 session.is_staff(),
4084 billing_preferences,
4085 has_llm_closed_beta_feature_flag,
4086 has_predict_edits_feature_flag,
4087 has_llm_subscription,
4088 session.current_plan(&db).await?,
4089 session.system_id.clone(),
4090 &session.app_state.config,
4091 )?;
4092 response.send(proto::GetLlmTokenResponse { token })?;
4093 Ok(())
4094}
4095
4096fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4097 let message = match message {
4098 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4099 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4100 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4101 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4102 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4103 code: frame.code.into(),
4104 reason: frame.reason,
4105 })),
4106 // We should never receive a frame while reading the message, according
4107 // to the `tungstenite` maintainers:
4108 //
4109 // > It cannot occur when you read messages from the WebSocket, but it
4110 // > can be used when you want to send the raw frames (e.g. you want to
4111 // > send the frames to the WebSocket without composing the full message first).
4112 // >
4113 // > — https://github.com/snapview/tungstenite-rs/issues/268
4114 TungsteniteMessage::Frame(_) => {
4115 bail!("received an unexpected frame while reading the message")
4116 }
4117 };
4118
4119 Ok(message)
4120}
4121
4122fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4123 match message {
4124 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4125 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4126 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4127 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4128 AxumMessage::Close(frame) => {
4129 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4130 code: frame.code.into(),
4131 reason: frame.reason,
4132 }))
4133 }
4134 }
4135}
4136
4137fn notify_membership_updated(
4138 connection_pool: &mut ConnectionPool,
4139 result: MembershipUpdated,
4140 user_id: UserId,
4141 peer: &Peer,
4142) {
4143 for membership in &result.new_channels.channel_memberships {
4144 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4145 }
4146 for channel_id in &result.removed_channels {
4147 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4148 }
4149
4150 let user_channels_update = proto::UpdateUserChannels {
4151 channel_memberships: result
4152 .new_channels
4153 .channel_memberships
4154 .iter()
4155 .map(|cm| proto::ChannelMembership {
4156 channel_id: cm.channel_id.to_proto(),
4157 role: cm.role.into(),
4158 })
4159 .collect(),
4160 ..Default::default()
4161 };
4162
4163 let mut update = build_channels_update(result.new_channels);
4164 update.delete_channels = result
4165 .removed_channels
4166 .into_iter()
4167 .map(|id| id.to_proto())
4168 .collect();
4169 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4170
4171 for connection_id in connection_pool.user_connection_ids(user_id) {
4172 peer.send(connection_id, user_channels_update.clone())
4173 .trace_err();
4174 peer.send(connection_id, update.clone()).trace_err();
4175 }
4176}
4177
4178fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4179 proto::UpdateUserChannels {
4180 channel_memberships: channels
4181 .channel_memberships
4182 .iter()
4183 .map(|m| proto::ChannelMembership {
4184 channel_id: m.channel_id.to_proto(),
4185 role: m.role.into(),
4186 })
4187 .collect(),
4188 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4189 observed_channel_message_id: channels.observed_channel_messages.clone(),
4190 }
4191}
4192
4193fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4194 let mut update = proto::UpdateChannels::default();
4195
4196 for channel in channels.channels {
4197 update.channels.push(channel.to_proto());
4198 }
4199
4200 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4201 update.latest_channel_message_ids = channels.latest_channel_messages;
4202
4203 for (channel_id, participants) in channels.channel_participants {
4204 update
4205 .channel_participants
4206 .push(proto::ChannelParticipants {
4207 channel_id: channel_id.to_proto(),
4208 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4209 });
4210 }
4211
4212 for channel in channels.invited_channels {
4213 update.channel_invitations.push(channel.to_proto());
4214 }
4215
4216 update
4217}
4218
4219fn build_initial_contacts_update(
4220 contacts: Vec<db::Contact>,
4221 pool: &ConnectionPool,
4222) -> proto::UpdateContacts {
4223 let mut update = proto::UpdateContacts::default();
4224
4225 for contact in contacts {
4226 match contact {
4227 db::Contact::Accepted { user_id, busy } => {
4228 update.contacts.push(contact_for_user(user_id, busy, pool));
4229 }
4230 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4231 db::Contact::Incoming { user_id } => {
4232 update
4233 .incoming_requests
4234 .push(proto::IncomingContactRequest {
4235 requester_id: user_id.to_proto(),
4236 })
4237 }
4238 }
4239 }
4240
4241 update
4242}
4243
4244fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4245 proto::Contact {
4246 user_id: user_id.to_proto(),
4247 online: pool.is_user_online(user_id),
4248 busy,
4249 }
4250}
4251
4252fn room_updated(room: &proto::Room, peer: &Peer) {
4253 broadcast(
4254 None,
4255 room.participants
4256 .iter()
4257 .filter_map(|participant| Some(participant.peer_id?.into())),
4258 |peer_id| {
4259 peer.send(
4260 peer_id,
4261 proto::RoomUpdated {
4262 room: Some(room.clone()),
4263 },
4264 )
4265 },
4266 );
4267}
4268
4269fn channel_updated(
4270 channel: &db::channel::Model,
4271 room: &proto::Room,
4272 peer: &Peer,
4273 pool: &ConnectionPool,
4274) {
4275 let participants = room
4276 .participants
4277 .iter()
4278 .map(|p| p.user_id)
4279 .collect::<Vec<_>>();
4280
4281 broadcast(
4282 None,
4283 pool.channel_connection_ids(channel.root_id())
4284 .filter_map(|(channel_id, role)| {
4285 role.can_see_channel(channel.visibility)
4286 .then_some(channel_id)
4287 }),
4288 |peer_id| {
4289 peer.send(
4290 peer_id,
4291 proto::UpdateChannels {
4292 channel_participants: vec![proto::ChannelParticipants {
4293 channel_id: channel.id.to_proto(),
4294 participant_user_ids: participants.clone(),
4295 }],
4296 ..Default::default()
4297 },
4298 )
4299 },
4300 );
4301}
4302
4303async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4304 let db = session.db().await;
4305
4306 let contacts = db.get_contacts(user_id).await?;
4307 let busy = db.is_user_busy(user_id).await?;
4308
4309 let pool = session.connection_pool().await;
4310 let updated_contact = contact_for_user(user_id, busy, &pool);
4311 for contact in contacts {
4312 if let db::Contact::Accepted {
4313 user_id: contact_user_id,
4314 ..
4315 } = contact
4316 {
4317 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4318 session
4319 .peer
4320 .send(
4321 contact_conn_id,
4322 proto::UpdateContacts {
4323 contacts: vec![updated_contact.clone()],
4324 remove_contacts: Default::default(),
4325 incoming_requests: Default::default(),
4326 remove_incoming_requests: Default::default(),
4327 outgoing_requests: Default::default(),
4328 remove_outgoing_requests: Default::default(),
4329 },
4330 )
4331 .trace_err();
4332 }
4333 }
4334 }
4335 Ok(())
4336}
4337
4338async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
4339 let mut contacts_to_update = HashSet::default();
4340
4341 let room_id;
4342 let canceled_calls_to_user_ids;
4343 let livekit_room;
4344 let delete_livekit_room;
4345 let room;
4346 let channel;
4347
4348 if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
4349 contacts_to_update.insert(session.user_id());
4350
4351 for project in left_room.left_projects.values() {
4352 project_left(project, session);
4353 }
4354
4355 room_id = RoomId::from_proto(left_room.room.id);
4356 canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
4357 livekit_room = mem::take(&mut left_room.room.livekit_room);
4358 delete_livekit_room = left_room.deleted;
4359 room = mem::take(&mut left_room.room);
4360 channel = mem::take(&mut left_room.channel);
4361
4362 room_updated(&room, &session.peer);
4363 } else {
4364 return Ok(());
4365 }
4366
4367 if let Some(channel) = channel {
4368 channel_updated(
4369 &channel,
4370 &room,
4371 &session.peer,
4372 &*session.connection_pool().await,
4373 );
4374 }
4375
4376 {
4377 let pool = session.connection_pool().await;
4378 for canceled_user_id in canceled_calls_to_user_ids {
4379 for connection_id in pool.user_connection_ids(canceled_user_id) {
4380 session
4381 .peer
4382 .send(
4383 connection_id,
4384 proto::CallCanceled {
4385 room_id: room_id.to_proto(),
4386 },
4387 )
4388 .trace_err();
4389 }
4390 contacts_to_update.insert(canceled_user_id);
4391 }
4392 }
4393
4394 for contact_user_id in contacts_to_update {
4395 update_user_contacts(contact_user_id, session).await?;
4396 }
4397
4398 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
4399 live_kit
4400 .remove_participant(livekit_room.clone(), session.user_id().to_string())
4401 .await
4402 .trace_err();
4403
4404 if delete_livekit_room {
4405 live_kit.delete_room(livekit_room).await.trace_err();
4406 }
4407 }
4408
4409 Ok(())
4410}
4411
4412async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4413 let left_channel_buffers = session
4414 .db()
4415 .await
4416 .leave_channel_buffers(session.connection_id)
4417 .await?;
4418
4419 for left_buffer in left_channel_buffers {
4420 channel_buffer_updated(
4421 session.connection_id,
4422 left_buffer.connections,
4423 &proto::UpdateChannelBufferCollaborators {
4424 channel_id: left_buffer.channel_id.to_proto(),
4425 collaborators: left_buffer.collaborators,
4426 },
4427 &session.peer,
4428 );
4429 }
4430
4431 Ok(())
4432}
4433
4434fn project_left(project: &db::LeftProject, session: &Session) {
4435 for connection_id in &project.connection_ids {
4436 if project.should_unshare {
4437 session
4438 .peer
4439 .send(
4440 *connection_id,
4441 proto::UnshareProject {
4442 project_id: project.id.to_proto(),
4443 },
4444 )
4445 .trace_err();
4446 } else {
4447 session
4448 .peer
4449 .send(
4450 *connection_id,
4451 proto::RemoveProjectCollaborator {
4452 project_id: project.id.to_proto(),
4453 peer_id: Some(session.connection_id.into()),
4454 },
4455 )
4456 .trace_err();
4457 }
4458 }
4459}
4460
/// Extension trait for turning a fallible value into an `Option` instead of
/// propagating its error with `?`.
pub trait ResultExt {
    /// The success type carried by the implementor.
    type Ok;

    /// Consumes `self` and returns the success value, if there is one.
    ///
    /// The implementation for `Result` in this file logs the error before returning `None`.
    fn trace_err(self) -> Option<Self::Ok>;
}
4466
4467impl<T, E> ResultExt for Result<T, E>
4468where
4469 E: std::fmt::Debug,
4470{
4471 type Ok = T;
4472
4473 #[track_caller]
4474 fn trace_err(self) -> Option<T> {
4475 match self {
4476 Ok(value) => Some(value),
4477 Err(error) => {
4478 tracing::error!("{:?}", error);
4479 None
4480 }
4481 }
4482 }
4483}