1mod connection_pool;
2
3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
4use crate::llm::LlmTokenClaims;
5use crate::{
6 auth,
7 db::{
8 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
9 CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
10 NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
11 RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
12 },
13 executor::Executor,
14 AppState, Config, Error, RateLimit, Result,
15};
16use anyhow::{anyhow, bail, Context as _};
17use async_tungstenite::tungstenite::{
18 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
19};
20use axum::{
21 body::Body,
22 extract::{
23 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
24 ConnectInfo, WebSocketUpgrade,
25 },
26 headers::{Header, HeaderName},
27 http::StatusCode,
28 middleware,
29 response::IntoResponse,
30 routing::get,
31 Extension, Router, TypedHeader,
32};
33use chrono::Utc;
34use collections::{HashMap, HashSet};
35pub use connection_pool::{ConnectionPool, ZedVersion};
36use core::fmt::{self, Debug, Formatter};
37use http_client::HttpClient;
38use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
39use reqwest_client::ReqwestClient;
40use sha2::Digest;
41use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
42
43use futures::{
44 channel::oneshot, future::BoxFuture, stream::FuturesUnordered, FutureExt, SinkExt, StreamExt,
45 TryStreamExt,
46};
47use prometheus::{register_int_gauge, IntGauge};
48use rpc::{
49 proto::{
50 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
51 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
52 },
53 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
54};
55use semantic_version::SemanticVersion;
56use serde::{Serialize, Serializer};
57use std::{
58 any::TypeId,
59 future::Future,
60 marker::PhantomData,
61 mem,
62 net::SocketAddr,
63 ops::{Deref, DerefMut},
64 rc::Rc,
65 sync::{
66 atomic::{AtomicBool, Ordering::SeqCst},
67 Arc, OnceLock,
68 },
69 time::{Duration, Instant},
70};
71use time::OffsetDateTime;
72use tokio::sync::{watch, MutexGuard, Semaphore};
73use tower::ServiceBuilder;
74use tracing::{
75 field::{self},
76 info_span, instrument, Instrument,
77};
78
/// Presumably how long a disconnected client may take to reconnect before its
/// session state is discarded — confirm against call sites.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_millis(30_000);

/// Delay before a freshly started server cleans up resources left behind by
/// previous server instances (see `Server::start`).
///
/// Kubernetes gives terminated pods 10s to shut down gracefully; after they're
/// gone, we can clean up their old resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_millis(15_000);

// Limits whose call sites lie outside this view; the names suggest page sizes
// for channel-message and notification queries and a maximum chat-message
// length — confirm before relying on that.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
const MAX_MESSAGE_LEN: usize = 1 << 10;
87
/// Type-erased handler stored in `Server::handlers`, keyed by the `TypeId` of
/// the message it handles.
///
/// It receives the boxed envelope plus the sender's [`Session`] and returns a
/// boxed `'static` future that performs the actual handling.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
90
91struct Response<R> {
92 peer: Arc<Peer>,
93 receipt: Receipt<R>,
94 responded: Arc<AtomicBool>,
95}
96
97impl<R: RequestMessage> Response<R> {
98 fn send(self, payload: R::Response) -> Result<()> {
99 self.responded.store(true, SeqCst);
100 self.peer.respond(self.receipt, payload)?;
101 Ok(())
102 }
103}
104
105#[derive(Clone, Debug)]
106pub enum Principal {
107 User(User),
108 Impersonated { user: User, admin: User },
109}
110
111impl Principal {
112 fn update_span(&self, span: &tracing::Span) {
113 match &self {
114 Principal::User(user) => {
115 span.record("user_id", user.id.0);
116 span.record("login", &user.github_login);
117 }
118 Principal::Impersonated { user, admin } => {
119 span.record("user_id", user.id.0);
120 span.record("login", &user.github_login);
121 span.record("impersonator", &admin.github_login);
122 }
123 }
124 }
125}
126
/// Per-connection state, cloned and handed to every message handler.
#[derive(Clone)]
struct Session {
    /// Identity of the connected user (possibly an admin impersonating them).
    principal: Principal,
    /// Peer-level id of this connection.
    connection_id: ConnectionId,
    /// Database handle behind a per-connection async mutex; use `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Connection pool shared with the `Server`; use `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Supermaven admin client; `None` when no admin API key is configured.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    http_client: Arc<dyn HttpClient>,
    /// The GeoIP country code for the user.
    // NOTE(review): the `allow(unused)` suggests this field is not read yet —
    // confirm before removing it or the attribute.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    system_id: Option<String>,
    _executor: Executor,
}
143
144impl Session {
145 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
146 #[cfg(test)]
147 tokio::task::yield_now().await;
148 let guard = self.db.lock().await;
149 #[cfg(test)]
150 tokio::task::yield_now().await;
151 guard
152 }
153
154 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
155 #[cfg(test)]
156 tokio::task::yield_now().await;
157 let guard = self.connection_pool.lock();
158 ConnectionPoolGuard {
159 guard,
160 _not_send: PhantomData,
161 }
162 }
163
164 fn is_staff(&self) -> bool {
165 match &self.principal {
166 Principal::User(user) => user.admin,
167 Principal::Impersonated { .. } => true,
168 }
169 }
170
171 pub async fn has_llm_subscription(
172 &self,
173 db: &MutexGuard<'_, DbHandle>,
174 ) -> anyhow::Result<bool> {
175 if self.is_staff() {
176 return Ok(true);
177 }
178
179 let user_id = self.user_id();
180
181 Ok(db.has_active_billing_subscription(user_id).await?)
182 }
183
184 pub async fn current_plan(
185 &self,
186 _db: &MutexGuard<'_, DbHandle>,
187 ) -> anyhow::Result<proto::Plan> {
188 if self.is_staff() {
189 Ok(proto::Plan::ZedPro)
190 } else {
191 Ok(proto::Plan::Free)
192 }
193 }
194
195 fn user_id(&self) -> UserId {
196 match &self.principal {
197 Principal::User(user) => user.id,
198 Principal::Impersonated { user, .. } => user.id,
199 }
200 }
201
202 pub fn email(&self) -> Option<String> {
203 match &self.principal {
204 Principal::User(user) => user.email_address.clone(),
205 Principal::Impersonated { user, .. } => user.email_address.clone(),
206 }
207 }
208}
209
210impl Debug for Session {
211 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
212 let mut result = f.debug_struct("Session");
213 match &self.principal {
214 Principal::User(user) => {
215 result.field("user", &user.github_login);
216 }
217 Principal::Impersonated { user, admin } => {
218 result.field("user", &user.github_login);
219 result.field("impersonator", &admin.github_login);
220 }
221 }
222 result.field("connection_id", &self.connection_id).finish()
223 }
224}
225
226struct DbHandle(Arc<Database>);
227
228impl Deref for DbHandle {
229 type Target = Database;
230
231 fn deref(&self) -> &Self::Target {
232 self.0.as_ref()
233 }
234}
235
/// The RPC server: owns the `Peer`, the shared connection pool, and the table
/// of message handlers.
pub struct Server {
    /// Server id; behind a mutex so tests can replace it via `reset`.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Message handlers keyed by the `TypeId` of the message they handle.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Set to `true` on teardown; connection loops subscribe to it and exit.
    teardown: watch::Sender<bool>,
}
244
/// Lock guard over the shared [`ConnectionPool`].
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send`, so it cannot be
/// held across an `.await` inside a future that must be `Send`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
249
/// Serializable view of server state: the peer plus the locked connection
/// pool. The pool stays locked for as long as the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // The guard itself isn't `Serialize`; serialize the pool it points to.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
256
257pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
258where
259 S: Serializer,
260 T: Deref<Target = U>,
261 U: Serialize,
262{
263 Serialize::serialize(value.deref(), serializer)
264}
265
266impl Server {
267 pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
268 let mut server = Self {
269 id: parking_lot::Mutex::new(id),
270 peer: Peer::new(id.0 as u32),
271 app_state: app_state.clone(),
272 connection_pool: Default::default(),
273 handlers: Default::default(),
274 teardown: watch::channel(false).0,
275 };
276
277 server
278 .add_request_handler(ping)
279 .add_request_handler(create_room)
280 .add_request_handler(join_room)
281 .add_request_handler(rejoin_room)
282 .add_request_handler(leave_room)
283 .add_request_handler(set_room_participant_role)
284 .add_request_handler(call)
285 .add_request_handler(cancel_call)
286 .add_message_handler(decline_call)
287 .add_request_handler(update_participant_location)
288 .add_request_handler(share_project)
289 .add_message_handler(unshare_project)
290 .add_request_handler(join_project)
291 .add_message_handler(leave_project)
292 .add_request_handler(update_project)
293 .add_request_handler(update_worktree)
294 .add_message_handler(start_language_server)
295 .add_message_handler(update_language_server)
296 .add_message_handler(update_diagnostic_summary)
297 .add_message_handler(update_worktree_settings)
298 .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
299 .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
300 .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
301 .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
302 .add_request_handler(forward_find_search_candidates_request)
303 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
304 .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
305 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
306 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
307 .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
308 .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
309 .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
310 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
311 .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
312 .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
313 .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
314 .add_request_handler(
315 forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
316 )
317 .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
318 .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
319 .add_request_handler(
320 forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
321 )
322 .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
323 .add_request_handler(
324 forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
325 )
326 .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
327 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
328 .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
329 .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
330 .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
331 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
332 .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
333 .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
334 .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
335 .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
336 .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
337 .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
338 .add_request_handler(
339 forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
340 )
341 .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
342 .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
343 .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
344 .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
345 .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
346 .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
347 .add_message_handler(create_buffer_for_peer)
348 .add_request_handler(update_buffer)
349 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
350 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
351 .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
352 .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
353 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
354 .add_request_handler(get_users)
355 .add_request_handler(fuzzy_search_users)
356 .add_request_handler(request_contact)
357 .add_request_handler(remove_contact)
358 .add_request_handler(respond_to_contact_request)
359 .add_message_handler(subscribe_to_channels)
360 .add_request_handler(create_channel)
361 .add_request_handler(delete_channel)
362 .add_request_handler(invite_channel_member)
363 .add_request_handler(remove_channel_member)
364 .add_request_handler(set_channel_member_role)
365 .add_request_handler(set_channel_visibility)
366 .add_request_handler(rename_channel)
367 .add_request_handler(join_channel_buffer)
368 .add_request_handler(leave_channel_buffer)
369 .add_message_handler(update_channel_buffer)
370 .add_request_handler(rejoin_channel_buffers)
371 .add_request_handler(get_channel_members)
372 .add_request_handler(respond_to_channel_invite)
373 .add_request_handler(join_channel)
374 .add_request_handler(join_channel_chat)
375 .add_message_handler(leave_channel_chat)
376 .add_request_handler(send_channel_message)
377 .add_request_handler(remove_channel_message)
378 .add_request_handler(update_channel_message)
379 .add_request_handler(get_channel_messages)
380 .add_request_handler(get_channel_messages_by_id)
381 .add_request_handler(get_notifications)
382 .add_request_handler(mark_notification_as_read)
383 .add_request_handler(move_channel)
384 .add_request_handler(follow)
385 .add_message_handler(unfollow)
386 .add_message_handler(update_followers)
387 .add_request_handler(get_private_user_info)
388 .add_request_handler(get_llm_api_token)
389 .add_request_handler(accept_terms_of_service)
390 .add_message_handler(acknowledge_channel_message)
391 .add_message_handler(acknowledge_buffer_version)
392 .add_request_handler(get_supermaven_api_key)
393 .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
394 .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
395 .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
396 .add_request_handler(forward_mutating_project_request::<proto::Stage>)
397 .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
398 .add_request_handler(forward_mutating_project_request::<proto::Commit>)
399 .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
400 .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
401 .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
402 .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
403 .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
404 .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
405 .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
406 .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
407 .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
408 .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
409 .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
410 .add_message_handler(update_context)
411 .add_request_handler({
412 let app_state = app_state.clone();
413 move |request, response, session| {
414 let app_state = app_state.clone();
415 async move {
416 count_language_model_tokens(request, response, session, &app_state.config)
417 .await
418 }
419 }
420 })
421 .add_request_handler(get_cached_embeddings)
422 .add_request_handler({
423 let app_state = app_state.clone();
424 move |request, response, session| {
425 compute_embeddings(
426 request,
427 response,
428 session,
429 app_state.config.openai_api_key.clone(),
430 )
431 }
432 });
433
434 Arc::new(server)
435 }
436
    /// Schedules cleanup of resources left behind by previous server instances.
    ///
    /// Spawns a detached task that waits [`CLEANUP_TIMEOUT`], then asks the
    /// database for stale room and channel-buffer ids (scoped by the configured
    /// `zed_environment` and this server's id). For each one it:
    /// - channel buffers: clears stale collaborators and sends the refreshed
    ///   collaborator list to the remaining connections;
    /// - rooms: clears stale participants, broadcasts room (and channel)
    ///   updates, notifies users whose calls were canceled, re-sends affected
    ///   users' contact entries to their accepted contacts, and deletes the
    ///   LiveKit room when no participants remain.
    /// Finally it calls `delete_stale_servers`. Errors are passed through
    /// `trace_err()` and do not abort the task.
    ///
    /// Returns `Ok(())` immediately; the cleanup runs in the background.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Channel buffers: drop stale collaborators and tell the
                    // remaining connections who is still collaborating.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // These stay empty/false unless clearing the room succeeds.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            // Both stale participants and callees of canceled
                            // calls need their contact entries refreshed.
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Tell every connection of each callee that the call was canceled.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Re-send each affected user's contact entry (including
                        // busy status) to all of their accepted contacts. The
                        // pool is locked only after the awaits complete.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Delete the LiveKit room once nobody is left in it.
                        if let Some(live_kit) = livekit_client.as_ref() {
                            if delete_livekit_room {
                                live_kit.delete_room(livekit_room).await.trace_err();
                            }
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
582
583 pub fn teardown(&self) {
584 self.peer.teardown();
585 self.connection_pool.lock().reset();
586 let _ = self.teardown.send(true);
587 }
588
589 #[cfg(test)]
590 pub fn reset(&self, id: ServerId) {
591 self.teardown();
592 *self.id.lock() = id;
593 self.peer.reset(id.0 as u32);
594 let _ = self.teardown.send(false);
595 }
596
597 #[cfg(test)]
598 pub fn id(&self) -> ServerId {
599 *self.id.lock()
600 }
601
602 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
603 where
604 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
605 Fut: 'static + Send + Future<Output = Result<()>>,
606 M: EnvelopedMessage,
607 {
608 let prev_handler = self.handlers.insert(
609 TypeId::of::<M>(),
610 Box::new(move |envelope, session| {
611 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
612 let received_at = envelope.received_at;
613 tracing::info!("message received");
614 let start_time = Instant::now();
615 let future = (handler)(*envelope, session);
616 async move {
617 let result = future.await;
618 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
619 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
620 let queue_duration_ms = total_duration_ms - processing_duration_ms;
621 let payload_type = M::NAME;
622
623 match result {
624 Err(error) => {
625 tracing::error!(
626 ?error,
627 total_duration_ms,
628 processing_duration_ms,
629 queue_duration_ms,
630 payload_type,
631 "error handling message"
632 )
633 }
634 Ok(()) => tracing::info!(
635 total_duration_ms,
636 processing_duration_ms,
637 queue_duration_ms,
638 "finished handling message"
639 ),
640 }
641 }
642 .boxed()
643 }),
644 );
645 if prev_handler.is_some() {
646 panic!("registered a handler for the same message twice");
647 }
648 self
649 }
650
651 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
652 where
653 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
654 Fut: 'static + Send + Future<Output = Result<()>>,
655 M: EnvelopedMessage,
656 {
657 self.add_handler(move |envelope, session| handler(envelope.payload, session));
658 self
659 }
660
661 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
662 where
663 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
664 Fut: Send + Future<Output = Result<()>>,
665 M: RequestMessage,
666 {
667 let handler = Arc::new(handler);
668 self.add_handler(move |envelope, session| {
669 let receipt = envelope.receipt();
670 let handler = handler.clone();
671 async move {
672 let peer = session.peer.clone();
673 let responded = Arc::new(AtomicBool::default());
674 let response = Response {
675 peer: peer.clone(),
676 responded: responded.clone(),
677 receipt,
678 };
679 match (handler)(envelope.payload, response, session).await {
680 Ok(()) => {
681 if responded.load(std::sync::atomic::Ordering::SeqCst) {
682 Ok(())
683 } else {
684 Err(anyhow!("handler did not send a response"))?
685 }
686 }
687 Err(error) => {
688 let proto_err = match &error {
689 Error::Internal(err) => err.to_proto(),
690 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
691 };
692 peer.respond_with_error(receipt, proto_err)?;
693 Err(error)
694 }
695 }
696 }
697 })
698 }
699
    /// Drives one client connection from handshake to sign-out.
    ///
    /// Returns a future, instrumented with a "handle connection" span that
    /// records address, connection id, user identity and GeoIP country. The
    /// future:
    /// 1. bails out if the server is already tearing down;
    /// 2. registers the connection with the `Peer`;
    /// 3. builds the HTTP client, the optional Supermaven client and the
    ///    [`Session`], then runs `send_initial_client_update`;
    /// 4. dispatches incoming messages to registered handlers until the
    ///    connection closes, I/O fails, or teardown is signaled;
    /// 5. after a close or I/O failure, runs `connection_lost`.
    ///
    /// NOTE(review): if step 3 fails, this returns without calling
    /// `connection_lost`, even though `send_initial_client_update` may already
    /// have added the connection to the pool. Confirm that cleanup happens
    /// elsewhere. On teardown (step 4) it also returns without `connection_lost`.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> {
        let this = self.clone();
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            geoip_country_code=field::Empty
        );
        principal.update_span(&span);
        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            // A fresh HTTP client per connection; failing to build one aborts the connection.
            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            // Only available when a Supermaven admin API key is configured.
            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
                    supermaven_admin_api_key.to_string(),
                    http_client.clone(),
                )));

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                http_client,
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // At most 256 handlers from this connection (foreground or background) run at
            // once: a permit is taken before reading each message and released when its
            // handler finishes.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Biased: teardown first, then I/O, then progress on in-flight
                // foreground handlers, and only then read a new message.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                // The permit is released only once the handler completes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Cancel any foreground handlers that are still in flight before signing out.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
848
849 async fn send_initial_client_update(
850 &self,
851 connection_id: ConnectionId,
852 principal: &Principal,
853 zed_version: ZedVersion,
854 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
855 session: &Session,
856 ) -> Result<()> {
857 self.peer.send(
858 connection_id,
859 proto::Hello {
860 peer_id: Some(connection_id.into()),
861 },
862 )?;
863 tracing::info!("sent hello message");
864 if let Some(send_connection_id) = send_connection_id.take() {
865 let _ = send_connection_id.send(connection_id);
866 }
867
868 match principal {
869 Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
870 if !user.connected_once {
871 self.peer.send(connection_id, proto::ShowContacts {})?;
872 self.app_state
873 .db
874 .set_user_connected_once(user.id, true)
875 .await?;
876 }
877
878 update_user_plan(user.id, session).await?;
879
880 let contacts = self.app_state.db.get_contacts(user.id).await?;
881
882 {
883 let mut pool = self.connection_pool.lock();
884 pool.add_connection(connection_id, user.id, user.admin, zed_version);
885 self.peer.send(
886 connection_id,
887 build_initial_contacts_update(contacts, &pool),
888 )?;
889 }
890
891 if should_auto_subscribe_to_channels(zed_version) {
892 subscribe_user_to_channels(user.id, session).await?;
893 }
894
895 if let Some(incoming_call) =
896 self.app_state.db.incoming_call_for_user(user.id).await?
897 {
898 self.peer.send(connection_id, incoming_call)?;
899 }
900
901 update_user_contacts(user.id, session).await?;
902 }
903 }
904
905 Ok(())
906 }
907
908 pub async fn invite_code_redeemed(
909 self: &Arc<Self>,
910 inviter_id: UserId,
911 invitee_id: UserId,
912 ) -> Result<()> {
913 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
914 if let Some(code) = &user.invite_code {
915 let pool = self.connection_pool.lock();
916 let invitee_contact = contact_for_user(invitee_id, false, &pool);
917 for connection_id in pool.user_connection_ids(inviter_id) {
918 self.peer.send(
919 connection_id,
920 proto::UpdateContacts {
921 contacts: vec![invitee_contact.clone()],
922 ..Default::default()
923 },
924 )?;
925 self.peer.send(
926 connection_id,
927 proto::UpdateInviteInfo {
928 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
929 count: user.invite_count as u32,
930 },
931 )?;
932 }
933 }
934 }
935 Ok(())
936 }
937
938 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
939 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
940 if let Some(invite_code) = &user.invite_code {
941 let pool = self.connection_pool.lock();
942 for connection_id in pool.user_connection_ids(user_id) {
943 self.peer.send(
944 connection_id,
945 proto::UpdateInviteInfo {
946 url: format!(
947 "{}{}",
948 self.app_state.config.invite_link_prefix, invite_code
949 ),
950 count: user.invite_count as u32,
951 },
952 )?;
953 }
954 }
955 }
956 Ok(())
957 }
958
959 pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
960 let pool = self.connection_pool.lock();
961 for connection_id in pool.user_connection_ids(user_id) {
962 self.peer
963 .send(connection_id, proto::RefreshLlmToken {})
964 .trace_err();
965 }
966 }
967
968 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
969 ServerSnapshot {
970 connection_pool: ConnectionPoolGuard {
971 guard: self.connection_pool.lock(),
972 _not_send: PhantomData,
973 },
974 peer: &self.peer,
975 }
976 }
977}
978
979impl Deref for ConnectionPoolGuard<'_> {
980 type Target = ConnectionPool;
981
982 fn deref(&self) -> &Self::Target {
983 &self.guard
984 }
985}
986
987impl DerefMut for ConnectionPoolGuard<'_> {
988 fn deref_mut(&mut self) -> &mut Self::Target {
989 &mut self.guard
990 }
991}
992
/// Releasing a pool guard is the point where, in test builds, the pool is validated.
impl Drop for ConnectionPoolGuard<'_> {
    /// Only compiled with a body in `cfg(test)` builds: runs the pool's `check_invariants` each
    /// time a guard is dropped, so an inconsistent mutation is caught where the lock is released.
    /// In non-test builds this drop does nothing beyond dropping the inner guard.
    fn drop(&mut self) {
        #[cfg(test)]
        self.check_invariants();
    }
}
999
1000fn broadcast<F>(
1001 sender_id: Option<ConnectionId>,
1002 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1003 mut f: F,
1004) where
1005 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1006{
1007 for receiver_id in receiver_ids {
1008 if Some(receiver_id) != sender_id {
1009 if let Err(error) = f(receiver_id) {
1010 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1011 }
1012 }
1013 }
1014}
1015
1016pub struct ProtocolVersion(u32);
1017
1018impl Header for ProtocolVersion {
1019 fn name() -> &'static HeaderName {
1020 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1021 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1022 }
1023
1024 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1025 where
1026 Self: Sized,
1027 I: Iterator<Item = &'i axum::http::HeaderValue>,
1028 {
1029 let version = values
1030 .next()
1031 .ok_or_else(axum::headers::Error::invalid)?
1032 .to_str()
1033 .map_err(|_| axum::headers::Error::invalid())?
1034 .parse()
1035 .map_err(|_| axum::headers::Error::invalid())?;
1036 Ok(Self(version))
1037 }
1038
1039 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1040 values.extend([self.0.to_string().parse().unwrap()]);
1041 }
1042}
1043
1044pub struct AppVersionHeader(SemanticVersion);
1045impl Header for AppVersionHeader {
1046 fn name() -> &'static HeaderName {
1047 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1048 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1049 }
1050
1051 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1052 where
1053 Self: Sized,
1054 I: Iterator<Item = &'i axum::http::HeaderValue>,
1055 {
1056 let version = values
1057 .next()
1058 .ok_or_else(axum::headers::Error::invalid)?
1059 .to_str()
1060 .map_err(|_| axum::headers::Error::invalid())?
1061 .parse()
1062 .map_err(|_| axum::headers::Error::invalid())?;
1063 Ok(Self(version))
1064 }
1065
1066 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1067 values.extend([self.0.to_string().parse().unwrap()]);
1068 }
1069}
1070
1071pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1072 Router::new()
1073 .route("/rpc", get(handle_websocket_request))
1074 .layer(
1075 ServiceBuilder::new()
1076 .layer(Extension(server.app_state.clone()))
1077 .layer(middleware::from_fn(auth::validate_header)),
1078 )
1079 .route("/metrics", get(handle_metrics))
1080 .layer(Extension(server))
1081}
1082
1083pub async fn handle_websocket_request(
1084 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1085 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1086 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1087 Extension(server): Extension<Arc<Server>>,
1088 Extension(principal): Extension<Principal>,
1089 country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1090 system_id_header: Option<TypedHeader<SystemIdHeader>>,
1091 ws: WebSocketUpgrade,
1092) -> axum::response::Response {
1093 if protocol_version != rpc::PROTOCOL_VERSION {
1094 return (
1095 StatusCode::UPGRADE_REQUIRED,
1096 "client must be upgraded".to_string(),
1097 )
1098 .into_response();
1099 }
1100
1101 let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1102 return (
1103 StatusCode::UPGRADE_REQUIRED,
1104 "no version header found".to_string(),
1105 )
1106 .into_response();
1107 };
1108
1109 if !version.can_collaborate() {
1110 return (
1111 StatusCode::UPGRADE_REQUIRED,
1112 "client must be upgraded".to_string(),
1113 )
1114 .into_response();
1115 }
1116
1117 let socket_address = socket_address.to_string();
1118 ws.on_upgrade(move |socket| {
1119 let socket = socket
1120 .map_ok(to_tungstenite_message)
1121 .err_into()
1122 .with(|message| async move { to_axum_message(message) });
1123 let connection = Connection::new(Box::pin(socket));
1124 async move {
1125 server
1126 .handle_connection(
1127 connection,
1128 socket_address,
1129 principal,
1130 version,
1131 country_code_header.map(|header| header.to_string()),
1132 system_id_header.map(|header| header.to_string()),
1133 None,
1134 Executor::Production,
1135 )
1136 .await;
1137 }
1138 })
1139}
1140
1141pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1142 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1143 let connections_metric = CONNECTIONS_METRIC
1144 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1145
1146 let connections = server
1147 .connection_pool
1148 .lock()
1149 .connections()
1150 .filter(|connection| !connection.admin)
1151 .count();
1152 connections_metric.set(connections as _);
1153
1154 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1155 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1156 register_int_gauge!(
1157 "shared_projects",
1158 "number of open projects with one or more guests"
1159 )
1160 .unwrap()
1161 });
1162
1163 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1164 shared_projects_metric.set(shared_projects as _);
1165
1166 let encoder = prometheus::TextEncoder::new();
1167 let metric_families = prometheus::gather();
1168 let encoded_metrics = encoder
1169 .encode_to_string(&metric_families)
1170 .map_err(|err| anyhow!("{}", err))?;
1171 Ok(encoded_metrics)
1172}
1173
/// Cleans up after a client's connection drops.
///
/// Steps:
/// 1. Disconnect the connection from the peer and remove it from the connection pool.
/// 2. Tell the database the connection was lost; failures here are only traced.
/// 3. Wait for whichever happens first: `RECONNECT_TIMEOUT` elapsing, or a change on `teardown`.
///
/// On timeout:
/// - the session leaves its room and its channel buffers;
/// - if the user has no other connection online, their incoming call (if any) is declined and
///   the affected room is re-broadcast;
/// - the user's contacts are refreshed.
///
/// If `teardown` changes first, none of that timeout cleanup runs here. `select_biased!` polls
/// the arms in the order written, so the timeout arm wins if both are ready at once.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            // The client did not come back in time: release everything this connection held.
            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline a pending call when no other connection of this user is still online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1220
1221/// Acknowledges a ping from a client, used to keep the connection alive.
1222async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1223 response.send(proto::Ack {})?;
1224 Ok(())
1225}
1226
1227/// Creates a new room for calling (outside of channels)
1228async fn create_room(
1229 _request: proto::CreateRoom,
1230 response: Response<proto::CreateRoom>,
1231 session: Session,
1232) -> Result<()> {
1233 let livekit_room = nanoid::nanoid!(30);
1234
1235 let live_kit_connection_info = util::maybe!(async {
1236 let live_kit = session.app_state.livekit_client.as_ref();
1237 let live_kit = live_kit?;
1238 let user_id = session.user_id().to_string();
1239
1240 let token = live_kit
1241 .room_token(&livekit_room, &user_id.to_string())
1242 .trace_err()?;
1243
1244 Some(proto::LiveKitConnectionInfo {
1245 server_url: live_kit.url().into(),
1246 token,
1247 can_publish: true,
1248 })
1249 })
1250 .await;
1251
1252 let room = session
1253 .db()
1254 .await
1255 .create_room(session.user_id(), session.connection_id, &livekit_room)
1256 .await?;
1257
1258 response.send(proto::CreateRoomResponse {
1259 room: Some(room.clone()),
1260 live_kit_connection_info,
1261 })?;
1262
1263 update_user_contacts(session.user_id(), &session).await?;
1264 Ok(())
1265}
1266
1267/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1268async fn join_room(
1269 request: proto::JoinRoom,
1270 response: Response<proto::JoinRoom>,
1271 session: Session,
1272) -> Result<()> {
1273 let room_id = RoomId::from_proto(request.id);
1274
1275 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1276
1277 if let Some(channel_id) = channel_id {
1278 return join_channel_internal(channel_id, Box::new(response), session).await;
1279 }
1280
1281 let joined_room = {
1282 let room = session
1283 .db()
1284 .await
1285 .join_room(room_id, session.user_id(), session.connection_id)
1286 .await?;
1287 room_updated(&room.room, &session.peer);
1288 room.into_inner()
1289 };
1290
1291 for connection_id in session
1292 .connection_pool()
1293 .await
1294 .user_connection_ids(session.user_id())
1295 {
1296 session
1297 .peer
1298 .send(
1299 connection_id,
1300 proto::CallCanceled {
1301 room_id: room_id.to_proto(),
1302 },
1303 )
1304 .trace_err();
1305 }
1306
1307 let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1308 {
1309 live_kit
1310 .room_token(
1311 &joined_room.room.livekit_room,
1312 &session.user_id().to_string(),
1313 )
1314 .trace_err()
1315 .map(|token| proto::LiveKitConnectionInfo {
1316 server_url: live_kit.url().into(),
1317 token,
1318 can_publish: true,
1319 })
1320 } else {
1321 None
1322 };
1323
1324 response.send(proto::JoinRoomResponse {
1325 room: Some(joined_room.room),
1326 channel_id: None,
1327 live_kit_connection_info,
1328 })?;
1329
1330 update_user_contacts(session.user_id(), &session).await?;
1331 Ok(())
1332}
1333
/// Rejoins a room after the client's connection was interrupted.
///
/// `rejoin_room` in the database yields:
/// - the room;
/// - "reshared" projects, with their collaborators, old connection id, and worktrees;
/// - "rejoined" projects.
///
/// The handler then:
/// 1. replies with the room and the per-project data;
/// 2. broadcasts the room update;
/// 3. for each reshared project, tells its collaborators that this user's peer id moved from
///    `old_connection_id` to the current connection, and forwards the project's worktrees to them;
/// 4. re-streams every rejoined project's state to the client (`notify_rejoined_projects`).
///
/// The database result is confined to the inner block and consumed there via `into_inner`.
/// This happens before the channel update and contact refresh, which lock the connection pool
/// (and presumably the database) again.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        // Reply first with the room and the per-project state the client needs to resume.
        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Collaborators must learn this user's new peer id; send failures are only traced.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Re-send the project's worktree list to its collaborators as coming from this
            // connection.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        // Stream each rejoined project's worktrees, diagnostics and settings back to the client.
        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    // When the rejoined room is associated with a channel, push the channel update as well.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1425
1426fn notify_rejoined_projects(
1427 rejoined_projects: &mut Vec<RejoinedProject>,
1428 session: &Session,
1429) -> Result<()> {
1430 for project in rejoined_projects.iter() {
1431 for collaborator in &project.collaborators {
1432 session
1433 .peer
1434 .send(
1435 collaborator.connection_id,
1436 proto::UpdateProjectCollaborator {
1437 project_id: project.id.to_proto(),
1438 old_peer_id: Some(project.old_connection_id.into()),
1439 new_peer_id: Some(session.connection_id.into()),
1440 },
1441 )
1442 .trace_err();
1443 }
1444 }
1445
1446 for project in rejoined_projects {
1447 for worktree in mem::take(&mut project.worktrees) {
1448 // Stream this worktree's entries.
1449 let message = proto::UpdateWorktree {
1450 project_id: project.id.to_proto(),
1451 worktree_id: worktree.id,
1452 abs_path: worktree.abs_path.clone(),
1453 root_name: worktree.root_name,
1454 updated_entries: worktree.updated_entries,
1455 removed_entries: worktree.removed_entries,
1456 scan_id: worktree.scan_id,
1457 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1458 updated_repositories: worktree.updated_repositories,
1459 removed_repositories: worktree.removed_repositories,
1460 };
1461 for update in proto::split_worktree_update(message) {
1462 session.peer.send(session.connection_id, update.clone())?;
1463 }
1464
1465 // Stream this worktree's diagnostics.
1466 for summary in worktree.diagnostic_summaries {
1467 session.peer.send(
1468 session.connection_id,
1469 proto::UpdateDiagnosticSummary {
1470 project_id: project.id.to_proto(),
1471 worktree_id: worktree.id,
1472 summary: Some(summary),
1473 },
1474 )?;
1475 }
1476
1477 for settings_file in worktree.settings_files {
1478 session.peer.send(
1479 session.connection_id,
1480 proto::UpdateWorktreeSettings {
1481 project_id: project.id.to_proto(),
1482 worktree_id: worktree.id,
1483 path: settings_file.path,
1484 content: Some(settings_file.content),
1485 kind: Some(settings_file.kind.to_proto().into()),
1486 },
1487 )?;
1488 }
1489 }
1490
1491 for language_server in &project.language_servers {
1492 session.peer.send(
1493 session.connection_id,
1494 proto::UpdateLanguageServer {
1495 project_id: project.id.to_proto(),
1496 language_server_id: language_server.id,
1497 variant: Some(
1498 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1499 proto::LspDiskBasedDiagnosticsUpdated {},
1500 ),
1501 ),
1502 },
1503 )?;
1504 }
1505 }
1506 Ok(())
1507}
1508
1509/// leave room disconnects from the room.
1510async fn leave_room(
1511 _: proto::LeaveRoom,
1512 response: Response<proto::LeaveRoom>,
1513 session: Session,
1514) -> Result<()> {
1515 leave_room_for_session(&session, session.connection_id).await?;
1516 response.send(proto::Ack {})?;
1517 Ok(())
1518}
1519
1520/// Updates the permissions of someone else in the room.
1521async fn set_room_participant_role(
1522 request: proto::SetRoomParticipantRole,
1523 response: Response<proto::SetRoomParticipantRole>,
1524 session: Session,
1525) -> Result<()> {
1526 let user_id = UserId::from_proto(request.user_id);
1527 let role = ChannelRole::from(request.role());
1528
1529 let (livekit_room, can_publish) = {
1530 let room = session
1531 .db()
1532 .await
1533 .set_room_participant_role(
1534 session.user_id(),
1535 RoomId::from_proto(request.room_id),
1536 user_id,
1537 role,
1538 )
1539 .await?;
1540
1541 let livekit_room = room.livekit_room.clone();
1542 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1543 room_updated(&room, &session.peer);
1544 (livekit_room, can_publish)
1545 };
1546
1547 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1548 live_kit
1549 .update_participant(
1550 livekit_room.clone(),
1551 request.user_id.to_string(),
1552 livekit_api::proto::ParticipantPermission {
1553 can_subscribe: true,
1554 can_publish,
1555 can_publish_data: can_publish,
1556 hidden: false,
1557 recorder: false,
1558 },
1559 )
1560 .await
1561 .trace_err();
1562 }
1563
1564 response.send(proto::Ack {})?;
1565 Ok(())
1566}
1567
1568/// Call someone else into the current room
1569async fn call(
1570 request: proto::Call,
1571 response: Response<proto::Call>,
1572 session: Session,
1573) -> Result<()> {
1574 let room_id = RoomId::from_proto(request.room_id);
1575 let calling_user_id = session.user_id();
1576 let calling_connection_id = session.connection_id;
1577 let called_user_id = UserId::from_proto(request.called_user_id);
1578 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1579 if !session
1580 .db()
1581 .await
1582 .has_contact(calling_user_id, called_user_id)
1583 .await?
1584 {
1585 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1586 }
1587
1588 let incoming_call = {
1589 let (room, incoming_call) = &mut *session
1590 .db()
1591 .await
1592 .call(
1593 room_id,
1594 calling_user_id,
1595 calling_connection_id,
1596 called_user_id,
1597 initial_project_id,
1598 )
1599 .await?;
1600 room_updated(room, &session.peer);
1601 mem::take(incoming_call)
1602 };
1603 update_user_contacts(called_user_id, &session).await?;
1604
1605 let mut calls = session
1606 .connection_pool()
1607 .await
1608 .user_connection_ids(called_user_id)
1609 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1610 .collect::<FuturesUnordered<_>>();
1611
1612 while let Some(call_response) = calls.next().await {
1613 match call_response.as_ref() {
1614 Ok(_) => {
1615 response.send(proto::Ack {})?;
1616 return Ok(());
1617 }
1618 Err(_) => {
1619 call_response.trace_err();
1620 }
1621 }
1622 }
1623
1624 {
1625 let room = session
1626 .db()
1627 .await
1628 .call_failed(room_id, called_user_id)
1629 .await?;
1630 room_updated(&room, &session.peer);
1631 }
1632 update_user_contacts(called_user_id, &session).await?;
1633
1634 Err(anyhow!("failed to ring user"))?
1635}
1636
1637/// Cancel an outgoing call.
1638async fn cancel_call(
1639 request: proto::CancelCall,
1640 response: Response<proto::CancelCall>,
1641 session: Session,
1642) -> Result<()> {
1643 let called_user_id = UserId::from_proto(request.called_user_id);
1644 let room_id = RoomId::from_proto(request.room_id);
1645 {
1646 let room = session
1647 .db()
1648 .await
1649 .cancel_call(room_id, session.connection_id, called_user_id)
1650 .await?;
1651 room_updated(&room, &session.peer);
1652 }
1653
1654 for connection_id in session
1655 .connection_pool()
1656 .await
1657 .user_connection_ids(called_user_id)
1658 {
1659 session
1660 .peer
1661 .send(
1662 connection_id,
1663 proto::CallCanceled {
1664 room_id: room_id.to_proto(),
1665 },
1666 )
1667 .trace_err();
1668 }
1669 response.send(proto::Ack {})?;
1670
1671 update_user_contacts(called_user_id, &session).await?;
1672 Ok(())
1673}
1674
1675/// Decline an incoming call.
1676async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1677 let room_id = RoomId::from_proto(message.room_id);
1678 {
1679 let room = session
1680 .db()
1681 .await
1682 .decline_call(Some(room_id), session.user_id())
1683 .await?
1684 .ok_or_else(|| anyhow!("failed to decline call"))?;
1685 room_updated(&room, &session.peer);
1686 }
1687
1688 for connection_id in session
1689 .connection_pool()
1690 .await
1691 .user_connection_ids(session.user_id())
1692 {
1693 session
1694 .peer
1695 .send(
1696 connection_id,
1697 proto::CallCanceled {
1698 room_id: room_id.to_proto(),
1699 },
1700 )
1701 .trace_err();
1702 }
1703 update_user_contacts(session.user_id(), &session).await?;
1704 Ok(())
1705}
1706
1707/// Updates other participants in the room with your current location.
1708async fn update_participant_location(
1709 request: proto::UpdateParticipantLocation,
1710 response: Response<proto::UpdateParticipantLocation>,
1711 session: Session,
1712) -> Result<()> {
1713 let room_id = RoomId::from_proto(request.room_id);
1714 let location = request
1715 .location
1716 .ok_or_else(|| anyhow!("invalid location"))?;
1717
1718 let db = session.db().await;
1719 let room = db
1720 .update_room_participant_location(room_id, session.connection_id, location)
1721 .await?;
1722
1723 room_updated(&room, &session.peer);
1724 response.send(proto::Ack {})?;
1725 Ok(())
1726}
1727
1728/// Share a project into the room.
1729async fn share_project(
1730 request: proto::ShareProject,
1731 response: Response<proto::ShareProject>,
1732 session: Session,
1733) -> Result<()> {
1734 let (project_id, room) = &*session
1735 .db()
1736 .await
1737 .share_project(
1738 RoomId::from_proto(request.room_id),
1739 session.connection_id,
1740 &request.worktrees,
1741 request.is_ssh_project,
1742 )
1743 .await?;
1744 response.send(proto::ShareProjectResponse {
1745 project_id: project_id.to_proto(),
1746 })?;
1747 room_updated(room, &session.peer);
1748
1749 Ok(())
1750}
1751
1752/// Unshare a project from the room.
1753async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1754 let project_id = ProjectId::from_proto(message.project_id);
1755 unshare_project_internal(project_id, session.connection_id, &session).await
1756}
1757
/// Stops sharing `project_id` on behalf of `connection_id`.
///
/// Every guest connection returned by the database (other than `connection_id`) receives
/// `UnshareProject`. The room is re-broadcast when the project belonged to one. If the database
/// reports that the project should be deleted, the deletion is a separate database call. That
/// call happens after the unshare result has been dropped at the end of the inner block.
async fn unshare_project_internal(
    project_id: ProjectId,
    connection_id: ConnectionId,
    session: &Session,
) -> Result<()> {
    let delete = {
        let room_guard = session
            .db()
            .await
            .unshare_project(project_id, connection_id)
            .await?;

        let (delete, room, guest_connection_ids) = &*room_guard;

        let message = proto::UnshareProject {
            project_id: project_id.to_proto(),
        };

        broadcast(
            Some(connection_id),
            guest_connection_ids.iter().copied(),
            |conn_id| session.peer.send(conn_id, message.clone()),
        );
        if let Some(room) = room {
            room_updated(room, &session.peer);
        }

        // Copy the flag out so `room_guard` can be dropped before the delete below.
        *delete
    };

    if delete {
        let db = session.db().await;
        db.delete_project(project_id).await?;
    }

    Ok(())
}
1795
/// Joins a project shared by another participant.
///
/// The database registers this connection on the project and returns the project state and
/// the `ReplicaId` for it. The database handle is dropped explicitly before the project is
/// streamed to the client by [`join_project_internal`]. The project result itself stays
/// borrowed until that call returns.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: Session,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    // Release the database handle; everything below only sends messages.
    drop(db);
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
1814
/// Something that can deliver a `JoinProjectResponse` to the joining client.
///
/// Lets [`join_project_internal`] reply without depending on a concrete response type.
trait JoinProjectInternalResponse {
    /// Sends `result` as the reply to the join request, consuming the responder.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
/// A regular `JoinProject` RPC response can be used directly by [`join_project_internal`].
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Fully qualified so this resolves to the inherent `Response::send` (the one used
        // throughout this file) instead of calling this trait method again.
        Response::<proto::JoinProject>::send(self, result)
    }
}
1823
1824fn join_project_internal(
1825 response: impl JoinProjectInternalResponse,
1826 session: Session,
1827 project: &mut Project,
1828 replica_id: &ReplicaId,
1829) -> Result<()> {
1830 let collaborators = project
1831 .collaborators
1832 .iter()
1833 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1834 .map(|collaborator| collaborator.to_proto())
1835 .collect::<Vec<_>>();
1836 let project_id = project.id;
1837 let guest_user_id = session.user_id();
1838
1839 let worktrees = project
1840 .worktrees
1841 .iter()
1842 .map(|(id, worktree)| proto::WorktreeMetadata {
1843 id: *id,
1844 root_name: worktree.root_name.clone(),
1845 visible: worktree.visible,
1846 abs_path: worktree.abs_path.clone(),
1847 })
1848 .collect::<Vec<_>>();
1849
1850 let add_project_collaborator = proto::AddProjectCollaborator {
1851 project_id: project_id.to_proto(),
1852 collaborator: Some(proto::Collaborator {
1853 peer_id: Some(session.connection_id.into()),
1854 replica_id: replica_id.0 as u32,
1855 user_id: guest_user_id.to_proto(),
1856 is_host: false,
1857 }),
1858 };
1859
1860 for collaborator in &collaborators {
1861 session
1862 .peer
1863 .send(
1864 collaborator.peer_id.unwrap().into(),
1865 add_project_collaborator.clone(),
1866 )
1867 .trace_err();
1868 }
1869
1870 // First, we send the metadata associated with each worktree.
1871 response.send(proto::JoinProjectResponse {
1872 project_id: project.id.0 as u64,
1873 worktrees: worktrees.clone(),
1874 replica_id: replica_id.0 as u32,
1875 collaborators: collaborators.clone(),
1876 language_servers: project.language_servers.clone(),
1877 role: project.role.into(),
1878 })?;
1879
1880 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1881 // Stream this worktree's entries.
1882 let message = proto::UpdateWorktree {
1883 project_id: project_id.to_proto(),
1884 worktree_id,
1885 abs_path: worktree.abs_path.clone(),
1886 root_name: worktree.root_name,
1887 updated_entries: worktree.entries,
1888 removed_entries: Default::default(),
1889 scan_id: worktree.scan_id,
1890 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1891 updated_repositories: worktree.repository_entries.into_values().collect(),
1892 removed_repositories: Default::default(),
1893 };
1894 for update in proto::split_worktree_update(message) {
1895 session.peer.send(session.connection_id, update.clone())?;
1896 }
1897
1898 // Stream this worktree's diagnostics.
1899 for summary in worktree.diagnostic_summaries {
1900 session.peer.send(
1901 session.connection_id,
1902 proto::UpdateDiagnosticSummary {
1903 project_id: project_id.to_proto(),
1904 worktree_id: worktree.id,
1905 summary: Some(summary),
1906 },
1907 )?;
1908 }
1909
1910 for settings_file in worktree.settings_files {
1911 session.peer.send(
1912 session.connection_id,
1913 proto::UpdateWorktreeSettings {
1914 project_id: project_id.to_proto(),
1915 worktree_id: worktree.id,
1916 path: settings_file.path,
1917 content: Some(settings_file.content),
1918 kind: Some(settings_file.kind.to_proto() as i32),
1919 },
1920 )?;
1921 }
1922 }
1923
1924 for language_server in &project.language_servers {
1925 session.peer.send(
1926 session.connection_id,
1927 proto::UpdateLanguageServer {
1928 project_id: project_id.to_proto(),
1929 language_server_id: language_server.id,
1930 variant: Some(
1931 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1932 proto::LspDiskBasedDiagnosticsUpdated {},
1933 ),
1934 ),
1935 },
1936 )?;
1937 }
1938
1939 Ok(())
1940}
1941
1942/// Leave someone elses shared project.
1943async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1944 let sender_id = session.connection_id;
1945 let project_id = ProjectId::from_proto(request.project_id);
1946 let db = session.db().await;
1947
1948 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1949 tracing::info!(
1950 %project_id,
1951 "leave project"
1952 );
1953
1954 project_left(project, &session);
1955 if let Some(room) = room {
1956 room_updated(room, &session.peer);
1957 }
1958
1959 Ok(())
1960}
1961
1962/// Updates other participants with changes to the project
1963async fn update_project(
1964 request: proto::UpdateProject,
1965 response: Response<proto::UpdateProject>,
1966 session: Session,
1967) -> Result<()> {
1968 let project_id = ProjectId::from_proto(request.project_id);
1969 let (room, guest_connection_ids) = &*session
1970 .db()
1971 .await
1972 .update_project(project_id, session.connection_id, &request.worktrees)
1973 .await?;
1974 broadcast(
1975 Some(session.connection_id),
1976 guest_connection_ids.iter().copied(),
1977 |connection_id| {
1978 session
1979 .peer
1980 .forward_send(session.connection_id, connection_id, request.clone())
1981 },
1982 );
1983 if let Some(room) = room {
1984 room_updated(room, &session.peer);
1985 }
1986 response.send(proto::Ack {})?;
1987
1988 Ok(())
1989}
1990
1991/// Updates other participants with changes to the worktree
1992async fn update_worktree(
1993 request: proto::UpdateWorktree,
1994 response: Response<proto::UpdateWorktree>,
1995 session: Session,
1996) -> Result<()> {
1997 let guest_connection_ids = session
1998 .db()
1999 .await
2000 .update_worktree(&request, session.connection_id)
2001 .await?;
2002
2003 broadcast(
2004 Some(session.connection_id),
2005 guest_connection_ids.iter().copied(),
2006 |connection_id| {
2007 session
2008 .peer
2009 .forward_send(session.connection_id, connection_id, request.clone())
2010 },
2011 );
2012 response.send(proto::Ack {})?;
2013 Ok(())
2014}
2015
2016/// Updates other participants with changes to the diagnostics
2017async fn update_diagnostic_summary(
2018 message: proto::UpdateDiagnosticSummary,
2019 session: Session,
2020) -> Result<()> {
2021 let guest_connection_ids = session
2022 .db()
2023 .await
2024 .update_diagnostic_summary(&message, session.connection_id)
2025 .await?;
2026
2027 broadcast(
2028 Some(session.connection_id),
2029 guest_connection_ids.iter().copied(),
2030 |connection_id| {
2031 session
2032 .peer
2033 .forward_send(session.connection_id, connection_id, message.clone())
2034 },
2035 );
2036
2037 Ok(())
2038}
2039
2040/// Updates other participants with changes to the worktree settings
2041async fn update_worktree_settings(
2042 message: proto::UpdateWorktreeSettings,
2043 session: Session,
2044) -> Result<()> {
2045 let guest_connection_ids = session
2046 .db()
2047 .await
2048 .update_worktree_settings(&message, session.connection_id)
2049 .await?;
2050
2051 broadcast(
2052 Some(session.connection_id),
2053 guest_connection_ids.iter().copied(),
2054 |connection_id| {
2055 session
2056 .peer
2057 .forward_send(session.connection_id, connection_id, message.clone())
2058 },
2059 );
2060
2061 Ok(())
2062}
2063
2064/// Notify other participants that a language server has started.
2065async fn start_language_server(
2066 request: proto::StartLanguageServer,
2067 session: Session,
2068) -> Result<()> {
2069 let guest_connection_ids = session
2070 .db()
2071 .await
2072 .start_language_server(&request, session.connection_id)
2073 .await?;
2074
2075 broadcast(
2076 Some(session.connection_id),
2077 guest_connection_ids.iter().copied(),
2078 |connection_id| {
2079 session
2080 .peer
2081 .forward_send(session.connection_id, connection_id, request.clone())
2082 },
2083 );
2084 Ok(())
2085}
2086
2087/// Notify other participants that a language server has changed.
2088async fn update_language_server(
2089 request: proto::UpdateLanguageServer,
2090 session: Session,
2091) -> Result<()> {
2092 let project_id = ProjectId::from_proto(request.project_id);
2093 let project_connection_ids = session
2094 .db()
2095 .await
2096 .project_connection_ids(project_id, session.connection_id, true)
2097 .await?;
2098 broadcast(
2099 Some(session.connection_id),
2100 project_connection_ids.iter().copied(),
2101 |connection_id| {
2102 session
2103 .peer
2104 .forward_send(session.connection_id, connection_id, request.clone())
2105 },
2106 );
2107 Ok(())
2108}
2109
2110/// forward a project request to the host. These requests should be read only
2111/// as guests are allowed to send them.
2112async fn forward_read_only_project_request<T>(
2113 request: T,
2114 response: Response<T>,
2115 session: Session,
2116) -> Result<()>
2117where
2118 T: EntityMessage + RequestMessage,
2119{
2120 let project_id = ProjectId::from_proto(request.remote_entity_id());
2121 let host_connection_id = session
2122 .db()
2123 .await
2124 .host_for_read_only_project_request(project_id, session.connection_id)
2125 .await?;
2126 let payload = session
2127 .peer
2128 .forward_request(session.connection_id, host_connection_id, request)
2129 .await?;
2130 response.send(payload)?;
2131 Ok(())
2132}
2133
2134async fn forward_find_search_candidates_request(
2135 request: proto::FindSearchCandidates,
2136 response: Response<proto::FindSearchCandidates>,
2137 session: Session,
2138) -> Result<()> {
2139 let project_id = ProjectId::from_proto(request.remote_entity_id());
2140 let host_connection_id = session
2141 .db()
2142 .await
2143 .host_for_read_only_project_request(project_id, session.connection_id)
2144 .await?;
2145 let payload = session
2146 .peer
2147 .forward_request(session.connection_id, host_connection_id, request)
2148 .await?;
2149 response.send(payload)?;
2150 Ok(())
2151}
2152
2153/// forward a project request to the host. These requests are disallowed
2154/// for guests.
2155async fn forward_mutating_project_request<T>(
2156 request: T,
2157 response: Response<T>,
2158 session: Session,
2159) -> Result<()>
2160where
2161 T: EntityMessage + RequestMessage,
2162{
2163 let project_id = ProjectId::from_proto(request.remote_entity_id());
2164
2165 let host_connection_id = session
2166 .db()
2167 .await
2168 .host_for_mutating_project_request(project_id, session.connection_id)
2169 .await?;
2170 let payload = session
2171 .peer
2172 .forward_request(session.connection_id, host_connection_id, request)
2173 .await?;
2174 response.send(payload)?;
2175 Ok(())
2176}
2177
2178/// Notify other participants that a new buffer has been created
2179async fn create_buffer_for_peer(
2180 request: proto::CreateBufferForPeer,
2181 session: Session,
2182) -> Result<()> {
2183 session
2184 .db()
2185 .await
2186 .check_user_is_project_host(
2187 ProjectId::from_proto(request.project_id),
2188 session.connection_id,
2189 )
2190 .await?;
2191 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2192 session
2193 .peer
2194 .forward_send(session.connection_id, peer_id.into(), request)?;
2195 Ok(())
2196}
2197
2198/// Notify other participants that a buffer has been updated. This is
2199/// allowed for guests as long as the update is limited to selections.
2200async fn update_buffer(
2201 request: proto::UpdateBuffer,
2202 response: Response<proto::UpdateBuffer>,
2203 session: Session,
2204) -> Result<()> {
2205 let project_id = ProjectId::from_proto(request.project_id);
2206 let mut capability = Capability::ReadOnly;
2207
2208 for op in request.operations.iter() {
2209 match op.variant {
2210 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2211 Some(_) => capability = Capability::ReadWrite,
2212 }
2213 }
2214
2215 let host = {
2216 let guard = session
2217 .db()
2218 .await
2219 .connections_for_buffer_update(project_id, session.connection_id, capability)
2220 .await?;
2221
2222 let (host, guests) = &*guard;
2223
2224 broadcast(
2225 Some(session.connection_id),
2226 guests.clone(),
2227 |connection_id| {
2228 session
2229 .peer
2230 .forward_send(session.connection_id, connection_id, request.clone())
2231 },
2232 );
2233
2234 *host
2235 };
2236
2237 if host != session.connection_id {
2238 session
2239 .peer
2240 .forward_request(session.connection_id, host, request.clone())
2241 .await?;
2242 }
2243
2244 response.send(proto::Ack {})?;
2245 Ok(())
2246}
2247
2248async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2249 let project_id = ProjectId::from_proto(message.project_id);
2250
2251 let operation = message.operation.as_ref().context("invalid operation")?;
2252 let capability = match operation.variant.as_ref() {
2253 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2254 if let Some(buffer_op) = buffer_op.operation.as_ref() {
2255 match buffer_op.variant {
2256 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2257 Capability::ReadOnly
2258 }
2259 _ => Capability::ReadWrite,
2260 }
2261 } else {
2262 Capability::ReadWrite
2263 }
2264 }
2265 Some(_) => Capability::ReadWrite,
2266 None => Capability::ReadOnly,
2267 };
2268
2269 let guard = session
2270 .db()
2271 .await
2272 .connections_for_buffer_update(project_id, session.connection_id, capability)
2273 .await?;
2274
2275 let (host, guests) = &*guard;
2276
2277 broadcast(
2278 Some(session.connection_id),
2279 guests.iter().chain([host]).copied(),
2280 |connection_id| {
2281 session
2282 .peer
2283 .forward_send(session.connection_id, connection_id, message.clone())
2284 },
2285 );
2286
2287 Ok(())
2288}
2289
2290/// Notify other participants that a project has been updated.
2291async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2292 request: T,
2293 session: Session,
2294) -> Result<()> {
2295 let project_id = ProjectId::from_proto(request.remote_entity_id());
2296 let project_connection_ids = session
2297 .db()
2298 .await
2299 .project_connection_ids(project_id, session.connection_id, false)
2300 .await?;
2301
2302 broadcast(
2303 Some(session.connection_id),
2304 project_connection_ids.iter().copied(),
2305 |connection_id| {
2306 session
2307 .peer
2308 .forward_send(session.connection_id, connection_id, request.clone())
2309 },
2310 );
2311 Ok(())
2312}
2313
2314/// Start following another user in a call.
2315async fn follow(
2316 request: proto::Follow,
2317 response: Response<proto::Follow>,
2318 session: Session,
2319) -> Result<()> {
2320 let room_id = RoomId::from_proto(request.room_id);
2321 let project_id = request.project_id.map(ProjectId::from_proto);
2322 let leader_id = request
2323 .leader_id
2324 .ok_or_else(|| anyhow!("invalid leader id"))?
2325 .into();
2326 let follower_id = session.connection_id;
2327
2328 session
2329 .db()
2330 .await
2331 .check_room_participants(room_id, leader_id, session.connection_id)
2332 .await?;
2333
2334 let response_payload = session
2335 .peer
2336 .forward_request(session.connection_id, leader_id, request)
2337 .await?;
2338 response.send(response_payload)?;
2339
2340 if let Some(project_id) = project_id {
2341 let room = session
2342 .db()
2343 .await
2344 .follow(room_id, project_id, leader_id, follower_id)
2345 .await?;
2346 room_updated(&room, &session.peer);
2347 }
2348
2349 Ok(())
2350}
2351
2352/// Stop following another user in a call.
2353async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2354 let room_id = RoomId::from_proto(request.room_id);
2355 let project_id = request.project_id.map(ProjectId::from_proto);
2356 let leader_id = request
2357 .leader_id
2358 .ok_or_else(|| anyhow!("invalid leader id"))?
2359 .into();
2360 let follower_id = session.connection_id;
2361
2362 session
2363 .db()
2364 .await
2365 .check_room_participants(room_id, leader_id, session.connection_id)
2366 .await?;
2367
2368 session
2369 .peer
2370 .forward_send(session.connection_id, leader_id, request)?;
2371
2372 if let Some(project_id) = project_id {
2373 let room = session
2374 .db()
2375 .await
2376 .unfollow(room_id, project_id, leader_id, follower_id)
2377 .await?;
2378 room_updated(&room, &session.peer);
2379 }
2380
2381 Ok(())
2382}
2383
2384/// Notify everyone following you of your current location.
2385async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2386 let room_id = RoomId::from_proto(request.room_id);
2387 let database = session.db.lock().await;
2388
2389 let connection_ids = if let Some(project_id) = request.project_id {
2390 let project_id = ProjectId::from_proto(project_id);
2391 database
2392 .project_connection_ids(project_id, session.connection_id, true)
2393 .await?
2394 } else {
2395 database
2396 .room_connection_ids(room_id, session.connection_id)
2397 .await?
2398 };
2399
2400 // For now, don't send view update messages back to that view's current leader.
2401 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2402 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2403 _ => None,
2404 });
2405
2406 for connection_id in connection_ids.iter().cloned() {
2407 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2408 session
2409 .peer
2410 .forward_send(session.connection_id, connection_id, request.clone())?;
2411 }
2412 }
2413 Ok(())
2414}
2415
2416/// Get public data about users.
2417async fn get_users(
2418 request: proto::GetUsers,
2419 response: Response<proto::GetUsers>,
2420 session: Session,
2421) -> Result<()> {
2422 let user_ids = request
2423 .user_ids
2424 .into_iter()
2425 .map(UserId::from_proto)
2426 .collect();
2427 let users = session
2428 .db()
2429 .await
2430 .get_users_by_ids(user_ids)
2431 .await?
2432 .into_iter()
2433 .map(|user| proto::User {
2434 id: user.id.to_proto(),
2435 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2436 github_login: user.github_login,
2437 email: user.email_address,
2438 name: user.name,
2439 })
2440 .collect();
2441 response.send(proto::UsersResponse { users })?;
2442 Ok(())
2443}
2444
2445/// Search for users (to invite) buy Github login
2446async fn fuzzy_search_users(
2447 request: proto::FuzzySearchUsers,
2448 response: Response<proto::FuzzySearchUsers>,
2449 session: Session,
2450) -> Result<()> {
2451 let query = request.query;
2452 let users = match query.len() {
2453 0 => vec![],
2454 1 | 2 => session
2455 .db()
2456 .await
2457 .get_user_by_github_login(&query)
2458 .await?
2459 .into_iter()
2460 .collect(),
2461 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2462 };
2463 let users = users
2464 .into_iter()
2465 .filter(|user| user.id != session.user_id())
2466 .map(|user| proto::User {
2467 id: user.id.to_proto(),
2468 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2469 github_login: user.github_login,
2470 name: user.name,
2471 email: user.email_address,
2472 })
2473 .collect();
2474 response.send(proto::UsersResponse { users })?;
2475 Ok(())
2476}
2477
2478/// Send a contact request to another user.
2479async fn request_contact(
2480 request: proto::RequestContact,
2481 response: Response<proto::RequestContact>,
2482 session: Session,
2483) -> Result<()> {
2484 let requester_id = session.user_id();
2485 let responder_id = UserId::from_proto(request.responder_id);
2486 if requester_id == responder_id {
2487 return Err(anyhow!("cannot add yourself as a contact"))?;
2488 }
2489
2490 let notifications = session
2491 .db()
2492 .await
2493 .send_contact_request(requester_id, responder_id)
2494 .await?;
2495
2496 // Update outgoing contact requests of requester
2497 let mut update = proto::UpdateContacts::default();
2498 update.outgoing_requests.push(responder_id.to_proto());
2499 for connection_id in session
2500 .connection_pool()
2501 .await
2502 .user_connection_ids(requester_id)
2503 {
2504 session.peer.send(connection_id, update.clone())?;
2505 }
2506
2507 // Update incoming contact requests of responder
2508 let mut update = proto::UpdateContacts::default();
2509 update
2510 .incoming_requests
2511 .push(proto::IncomingContactRequest {
2512 requester_id: requester_id.to_proto(),
2513 });
2514 let connection_pool = session.connection_pool().await;
2515 for connection_id in connection_pool.user_connection_ids(responder_id) {
2516 session.peer.send(connection_id, update.clone())?;
2517 }
2518
2519 send_notifications(&connection_pool, &session.peer, notifications);
2520
2521 response.send(proto::Ack {})?;
2522 Ok(())
2523}
2524
/// Accept, decline, or dismiss a contact request addressed to the current user.
///
/// - `Dismiss`: only calls `dismiss_contact_notification`; no contact updates
///   are pushed to clients.
/// - Any other value: `accept` is true only for `Accept`, so every other value
///   is treated as a decline. Both users' connected clients are then updated:
///   the pending incoming/outgoing request is removed, and on acceptance each
///   side gains the other as a contact, including the other user's busy
///   status. Notifications returned by the database are delivered last.
///
/// The database guard `db` is held for the whole handler, and the connection
/// pool is locked while the updates are sent.
async fn respond_to_contact_request(
    request: proto::RespondToContactRequest,
    response: Response<proto::RespondToContactRequest>,
    session: Session,
) -> Result<()> {
    let responder_id = session.user_id();
    let requester_id = UserId::from_proto(request.requester_id);
    let db = session.db().await;
    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
        db.dismiss_contact_notification(responder_id, requester_id)
            .await?;
    } else {
        let accept = request.response == proto::ContactRequestResponse::Accept as i32;

        let notifications = db
            .respond_to_contact_request(responder_id, requester_id, accept)
            .await?;
        // Busy state is read for both users so each side's new contact entry carries it.
        let requester_busy = db.is_user_busy(requester_id).await?;
        let responder_busy = db.is_user_busy(responder_id).await?;

        let pool = session.connection_pool().await;
        // Update responder with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(requester_id, requester_busy, &pool));
        }
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
        for connection_id in pool.user_connection_ids(responder_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        // Update requester with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(responder_id, responder_busy, &pool));
        }
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());

        for connection_id in pool.user_connection_ids(requester_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        send_notifications(&pool, &session.peer, notifications);
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2582
2583/// Remove a contact.
2584async fn remove_contact(
2585 request: proto::RemoveContact,
2586 response: Response<proto::RemoveContact>,
2587 session: Session,
2588) -> Result<()> {
2589 let requester_id = session.user_id();
2590 let responder_id = UserId::from_proto(request.user_id);
2591 let db = session.db().await;
2592 let (contact_accepted, deleted_notification_id) =
2593 db.remove_contact(requester_id, responder_id).await?;
2594
2595 let pool = session.connection_pool().await;
2596 // Update outgoing contact requests of requester
2597 let mut update = proto::UpdateContacts::default();
2598 if contact_accepted {
2599 update.remove_contacts.push(responder_id.to_proto());
2600 } else {
2601 update
2602 .remove_outgoing_requests
2603 .push(responder_id.to_proto());
2604 }
2605 for connection_id in pool.user_connection_ids(requester_id) {
2606 session.peer.send(connection_id, update.clone())?;
2607 }
2608
2609 // Update incoming contact requests of responder
2610 let mut update = proto::UpdateContacts::default();
2611 if contact_accepted {
2612 update.remove_contacts.push(requester_id.to_proto());
2613 } else {
2614 update
2615 .remove_incoming_requests
2616 .push(requester_id.to_proto());
2617 }
2618 for connection_id in pool.user_connection_ids(responder_id) {
2619 session.peer.send(connection_id, update.clone())?;
2620 if let Some(notification_id) = deleted_notification_id {
2621 session.peer.send(
2622 connection_id,
2623 proto::DeleteNotification {
2624 notification_id: notification_id.to_proto(),
2625 },
2626 )?;
2627 }
2628 }
2629
2630 response.send(proto::Ack {})?;
2631 Ok(())
2632}
2633
2634fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2635 version.0.minor() < 139
2636}
2637
2638async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
2639 let plan = session.current_plan(&session.db().await).await?;
2640
2641 session
2642 .peer
2643 .send(
2644 session.connection_id,
2645 proto::UpdateUserPlan { plan: plan.into() },
2646 )
2647 .trace_err();
2648
2649 Ok(())
2650}
2651
2652async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2653 subscribe_user_to_channels(session.user_id(), &session).await?;
2654 Ok(())
2655}
2656
2657async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2658 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2659 let mut pool = session.connection_pool().await;
2660 for membership in &channels_for_user.channel_memberships {
2661 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2662 }
2663 session.peer.send(
2664 session.connection_id,
2665 build_update_user_channels(&channels_for_user),
2666 )?;
2667 session.peer.send(
2668 session.connection_id,
2669 build_channels_update(channels_for_user),
2670 )?;
2671 Ok(())
2672}
2673
2674/// Creates a new channel.
2675async fn create_channel(
2676 request: proto::CreateChannel,
2677 response: Response<proto::CreateChannel>,
2678 session: Session,
2679) -> Result<()> {
2680 let db = session.db().await;
2681
2682 let parent_id = request.parent_id.map(ChannelId::from_proto);
2683 let (channel, membership) = db
2684 .create_channel(&request.name, parent_id, session.user_id())
2685 .await?;
2686
2687 let root_id = channel.root_id();
2688 let channel = Channel::from_model(channel);
2689
2690 response.send(proto::CreateChannelResponse {
2691 channel: Some(channel.to_proto()),
2692 parent_id: request.parent_id,
2693 })?;
2694
2695 let mut connection_pool = session.connection_pool().await;
2696 if let Some(membership) = membership {
2697 connection_pool.subscribe_to_channel(
2698 membership.user_id,
2699 membership.channel_id,
2700 membership.role,
2701 );
2702 let update = proto::UpdateUserChannels {
2703 channel_memberships: vec![proto::ChannelMembership {
2704 channel_id: membership.channel_id.to_proto(),
2705 role: membership.role.into(),
2706 }],
2707 ..Default::default()
2708 };
2709 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2710 session.peer.send(connection_id, update.clone())?;
2711 }
2712 }
2713
2714 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2715 if !role.can_see_channel(channel.visibility) {
2716 continue;
2717 }
2718
2719 let update = proto::UpdateChannels {
2720 channels: vec![channel.to_proto()],
2721 ..Default::default()
2722 };
2723 session.peer.send(connection_id, update.clone())?;
2724 }
2725
2726 Ok(())
2727}
2728
2729/// Delete a channel
2730async fn delete_channel(
2731 request: proto::DeleteChannel,
2732 response: Response<proto::DeleteChannel>,
2733 session: Session,
2734) -> Result<()> {
2735 let db = session.db().await;
2736
2737 let channel_id = request.channel_id;
2738 let (root_channel, removed_channels) = db
2739 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2740 .await?;
2741 response.send(proto::Ack {})?;
2742
2743 // Notify members of removed channels
2744 let mut update = proto::UpdateChannels::default();
2745 update
2746 .delete_channels
2747 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2748
2749 let connection_pool = session.connection_pool().await;
2750 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2751 session.peer.send(connection_id, update.clone())?;
2752 }
2753
2754 Ok(())
2755}
2756
2757/// Invite someone to join a channel.
2758async fn invite_channel_member(
2759 request: proto::InviteChannelMember,
2760 response: Response<proto::InviteChannelMember>,
2761 session: Session,
2762) -> Result<()> {
2763 let db = session.db().await;
2764 let channel_id = ChannelId::from_proto(request.channel_id);
2765 let invitee_id = UserId::from_proto(request.user_id);
2766 let InviteMemberResult {
2767 channel,
2768 notifications,
2769 } = db
2770 .invite_channel_member(
2771 channel_id,
2772 invitee_id,
2773 session.user_id(),
2774 request.role().into(),
2775 )
2776 .await?;
2777
2778 let update = proto::UpdateChannels {
2779 channel_invitations: vec![channel.to_proto()],
2780 ..Default::default()
2781 };
2782
2783 let connection_pool = session.connection_pool().await;
2784 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2785 session.peer.send(connection_id, update.clone())?;
2786 }
2787
2788 send_notifications(&connection_pool, &session.peer, notifications);
2789
2790 response.send(proto::Ack {})?;
2791 Ok(())
2792}
2793
2794/// remove someone from a channel
2795async fn remove_channel_member(
2796 request: proto::RemoveChannelMember,
2797 response: Response<proto::RemoveChannelMember>,
2798 session: Session,
2799) -> Result<()> {
2800 let db = session.db().await;
2801 let channel_id = ChannelId::from_proto(request.channel_id);
2802 let member_id = UserId::from_proto(request.user_id);
2803
2804 let RemoveChannelMemberResult {
2805 membership_update,
2806 notification_id,
2807 } = db
2808 .remove_channel_member(channel_id, member_id, session.user_id())
2809 .await?;
2810
2811 let mut connection_pool = session.connection_pool().await;
2812 notify_membership_updated(
2813 &mut connection_pool,
2814 membership_update,
2815 member_id,
2816 &session.peer,
2817 );
2818 for connection_id in connection_pool.user_connection_ids(member_id) {
2819 if let Some(notification_id) = notification_id {
2820 session
2821 .peer
2822 .send(
2823 connection_id,
2824 proto::DeleteNotification {
2825 notification_id: notification_id.to_proto(),
2826 },
2827 )
2828 .trace_err();
2829 }
2830 }
2831
2832 response.send(proto::Ack {})?;
2833 Ok(())
2834}
2835
/// Change a channel's visibility (public vs. members-only) and update every
/// connected client whose view of the channel changes as a result.
///
/// Intended invariant: public channels only descend from public channels, while
/// members-only channels can appear at any point in the hierarchy.
/// NOTE(review): this handler does not check that invariant itself; presumably
/// `set_channel_visibility` in the database layer enforces it — confirm.
///
/// For every user that `channel_user_ids` reports for the channel's root:
/// - If their role can see the channel at its new visibility, they are
///   subscribed to it and sent the channel.
/// - Otherwise they are unsubscribed from it and told to delete it.
async fn set_channel_visibility(
    request: proto::SetChannelVisibility,
    response: Response<proto::SetChannelVisibility>,
    session: Session,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let visibility = request.visibility().into();

    let channel_model = db
        .set_channel_visibility(channel_id, visibility, session.user_id())
        .await?;
    let root_id = channel_model.root_id();
    let channel = Channel::from_model(channel_model);

    let mut connection_pool = session.connection_pool().await;
    // Snapshot the (user, role) pairs first: the loop body mutates
    // `connection_pool` (subscribe/unsubscribe), which cannot happen while
    // iterating a borrow of it.
    for (user_id, role) in connection_pool
        .channel_user_ids(root_id)
        .collect::<Vec<_>>()
        .into_iter()
    {
        let update = if role.can_see_channel(channel.visibility) {
            connection_pool.subscribe_to_channel(user_id, channel_id, role);
            proto::UpdateChannels {
                channels: vec![channel.to_proto()],
                ..Default::default()
            }
        } else {
            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
            proto::UpdateChannels {
                delete_channels: vec![channel.id.to_proto()],
                ..Default::default()
            }
        };

        for connection_id in connection_pool.user_connection_ids(user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2882
2883/// Alter the role for a user in the channel.
2884async fn set_channel_member_role(
2885 request: proto::SetChannelMemberRole,
2886 response: Response<proto::SetChannelMemberRole>,
2887 session: Session,
2888) -> Result<()> {
2889 let db = session.db().await;
2890 let channel_id = ChannelId::from_proto(request.channel_id);
2891 let member_id = UserId::from_proto(request.user_id);
2892 let result = db
2893 .set_channel_member_role(
2894 channel_id,
2895 session.user_id(),
2896 member_id,
2897 request.role().into(),
2898 )
2899 .await?;
2900
2901 match result {
2902 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2903 let mut connection_pool = session.connection_pool().await;
2904 notify_membership_updated(
2905 &mut connection_pool,
2906 membership_update,
2907 member_id,
2908 &session.peer,
2909 )
2910 }
2911 db::SetMemberRoleResult::InviteUpdated(channel) => {
2912 let update = proto::UpdateChannels {
2913 channel_invitations: vec![channel.to_proto()],
2914 ..Default::default()
2915 };
2916
2917 for connection_id in session
2918 .connection_pool()
2919 .await
2920 .user_connection_ids(member_id)
2921 {
2922 session.peer.send(connection_id, update.clone())?;
2923 }
2924 }
2925 }
2926
2927 response.send(proto::Ack {})?;
2928 Ok(())
2929}
2930
2931/// Change the name of a channel
2932async fn rename_channel(
2933 request: proto::RenameChannel,
2934 response: Response<proto::RenameChannel>,
2935 session: Session,
2936) -> Result<()> {
2937 let db = session.db().await;
2938 let channel_id = ChannelId::from_proto(request.channel_id);
2939 let channel_model = db
2940 .rename_channel(channel_id, session.user_id(), &request.name)
2941 .await?;
2942 let root_id = channel_model.root_id();
2943 let channel = Channel::from_model(channel_model);
2944
2945 response.send(proto::RenameChannelResponse {
2946 channel: Some(channel.to_proto()),
2947 })?;
2948
2949 let connection_pool = session.connection_pool().await;
2950 let update = proto::UpdateChannels {
2951 channels: vec![channel.to_proto()],
2952 ..Default::default()
2953 };
2954 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2955 if role.can_see_channel(channel.visibility) {
2956 session.peer.send(connection_id, update.clone())?;
2957 }
2958 }
2959
2960 Ok(())
2961}
2962
2963/// Move a channel to a new parent.
2964async fn move_channel(
2965 request: proto::MoveChannel,
2966 response: Response<proto::MoveChannel>,
2967 session: Session,
2968) -> Result<()> {
2969 let channel_id = ChannelId::from_proto(request.channel_id);
2970 let to = ChannelId::from_proto(request.to);
2971
2972 let (root_id, channels) = session
2973 .db()
2974 .await
2975 .move_channel(channel_id, to, session.user_id())
2976 .await?;
2977
2978 let connection_pool = session.connection_pool().await;
2979 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2980 let channels = channels
2981 .iter()
2982 .filter_map(|channel| {
2983 if role.can_see_channel(channel.visibility) {
2984 Some(channel.to_proto())
2985 } else {
2986 None
2987 }
2988 })
2989 .collect::<Vec<_>>();
2990 if channels.is_empty() {
2991 continue;
2992 }
2993
2994 let update = proto::UpdateChannels {
2995 channels,
2996 ..Default::default()
2997 };
2998
2999 session.peer.send(connection_id, update.clone())?;
3000 }
3001
3002 response.send(Ack {})?;
3003 Ok(())
3004}
3005
3006/// Get the list of channel members
3007async fn get_channel_members(
3008 request: proto::GetChannelMembers,
3009 response: Response<proto::GetChannelMembers>,
3010 session: Session,
3011) -> Result<()> {
3012 let db = session.db().await;
3013 let channel_id = ChannelId::from_proto(request.channel_id);
3014 let limit = if request.limit == 0 {
3015 u16::MAX as u64
3016 } else {
3017 request.limit
3018 };
3019 let (members, users) = db
3020 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3021 .await?;
3022 response.send(proto::GetChannelMembersResponse { members, users })?;
3023 Ok(())
3024}
3025
3026/// Accept or decline a channel invitation.
3027async fn respond_to_channel_invite(
3028 request: proto::RespondToChannelInvite,
3029 response: Response<proto::RespondToChannelInvite>,
3030 session: Session,
3031) -> Result<()> {
3032 let db = session.db().await;
3033 let channel_id = ChannelId::from_proto(request.channel_id);
3034 let RespondToChannelInvite {
3035 membership_update,
3036 notifications,
3037 } = db
3038 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3039 .await?;
3040
3041 let mut connection_pool = session.connection_pool().await;
3042 if let Some(membership_update) = membership_update {
3043 notify_membership_updated(
3044 &mut connection_pool,
3045 membership_update,
3046 session.user_id(),
3047 &session.peer,
3048 );
3049 } else {
3050 let update = proto::UpdateChannels {
3051 remove_channel_invitations: vec![channel_id.to_proto()],
3052 ..Default::default()
3053 };
3054
3055 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3056 session.peer.send(connection_id, update.clone())?;
3057 }
3058 };
3059
3060 send_notifications(&connection_pool, &session.peer, notifications);
3061
3062 response.send(proto::Ack {})?;
3063
3064 Ok(())
3065}
3066
3067/// Join the channels' room
3068async fn join_channel(
3069 request: proto::JoinChannel,
3070 response: Response<proto::JoinChannel>,
3071 session: Session,
3072) -> Result<()> {
3073 let channel_id = ChannelId::from_proto(request.channel_id);
3074 join_channel_internal(channel_id, Box::new(response), session).await
3075}
3076
3077trait JoinChannelInternalResponse {
3078 fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
3079}
3080impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
3081 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3082 Response::<proto::JoinChannel>::send(self, result)
3083 }
3084}
3085impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
3086 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3087 Response::<proto::JoinRoom>::send(self, result)
3088 }
3089}
3090
/// Shared implementation of joining a channel's room.
///
/// It is called by `join_channel` and is generic over the response handle via
/// `JoinChannelInternalResponse`.
///
/// Steps:
/// 1. If the user still has a stale room connection, leave that room first.
///    The database guard is dropped around `leave_room_for_session` and
///    re-acquired afterwards.
/// 2. Join the channel's room in the database.
/// 3. If a LiveKit client is configured, mint a token:
///    - guests get a guest token with `can_publish = false`;
///    - everyone else gets a room token with `can_publish = true`.
///
///    A token error goes through `trace_err()` and yields no connection info
///    instead of failing the join.
/// 4. Reply, apply any membership change, and broadcast the updated room. All
///    of this happens while the database guard from step 2 is still held.
/// 5. Once that guard's block has ended, notify the channel's subscribers and
///    refresh the user's contacts.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // Everything that needs the database guard lives in this block.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the guard while leaving the stale room, then take it again.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured or token creation failed.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // The joined room is expected to carry its channel; its absence is an error.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3186
3187/// Start editing the channel notes
3188async fn join_channel_buffer(
3189 request: proto::JoinChannelBuffer,
3190 response: Response<proto::JoinChannelBuffer>,
3191 session: Session,
3192) -> Result<()> {
3193 let db = session.db().await;
3194 let channel_id = ChannelId::from_proto(request.channel_id);
3195
3196 let open_response = db
3197 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3198 .await?;
3199
3200 let collaborators = open_response.collaborators.clone();
3201 response.send(open_response)?;
3202
3203 let update = UpdateChannelBufferCollaborators {
3204 channel_id: channel_id.to_proto(),
3205 collaborators: collaborators.clone(),
3206 };
3207 channel_buffer_updated(
3208 session.connection_id,
3209 collaborators
3210 .iter()
3211 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3212 &update,
3213 &session.peer,
3214 );
3215
3216 Ok(())
3217}
3218
3219/// Edit the channel notes
3220async fn update_channel_buffer(
3221 request: proto::UpdateChannelBuffer,
3222 session: Session,
3223) -> Result<()> {
3224 let db = session.db().await;
3225 let channel_id = ChannelId::from_proto(request.channel_id);
3226
3227 let (collaborators, epoch, version) = db
3228 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3229 .await?;
3230
3231 channel_buffer_updated(
3232 session.connection_id,
3233 collaborators.clone(),
3234 &proto::UpdateChannelBuffer {
3235 channel_id: channel_id.to_proto(),
3236 operations: request.operations,
3237 },
3238 &session.peer,
3239 );
3240
3241 let pool = &*session.connection_pool().await;
3242
3243 let non_collaborators =
3244 pool.channel_connection_ids(channel_id)
3245 .filter_map(|(connection_id, _)| {
3246 if collaborators.contains(&connection_id) {
3247 None
3248 } else {
3249 Some(connection_id)
3250 }
3251 });
3252
3253 broadcast(None, non_collaborators, |peer_id| {
3254 session.peer.send(
3255 peer_id,
3256 proto::UpdateChannels {
3257 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3258 channel_id: channel_id.to_proto(),
3259 epoch: epoch as u64,
3260 version: version.clone(),
3261 }],
3262 ..Default::default()
3263 },
3264 )
3265 });
3266
3267 Ok(())
3268}
3269
3270/// Rejoin the channel notes after a connection blip
3271async fn rejoin_channel_buffers(
3272 request: proto::RejoinChannelBuffers,
3273 response: Response<proto::RejoinChannelBuffers>,
3274 session: Session,
3275) -> Result<()> {
3276 let db = session.db().await;
3277 let buffers = db
3278 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3279 .await?;
3280
3281 for rejoined_buffer in &buffers {
3282 let collaborators_to_notify = rejoined_buffer
3283 .buffer
3284 .collaborators
3285 .iter()
3286 .filter_map(|c| Some(c.peer_id?.into()));
3287 channel_buffer_updated(
3288 session.connection_id,
3289 collaborators_to_notify,
3290 &proto::UpdateChannelBufferCollaborators {
3291 channel_id: rejoined_buffer.buffer.channel_id,
3292 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3293 },
3294 &session.peer,
3295 );
3296 }
3297
3298 response.send(proto::RejoinChannelBuffersResponse {
3299 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3300 })?;
3301
3302 Ok(())
3303}
3304
3305/// Stop editing the channel notes
3306async fn leave_channel_buffer(
3307 request: proto::LeaveChannelBuffer,
3308 response: Response<proto::LeaveChannelBuffer>,
3309 session: Session,
3310) -> Result<()> {
3311 let db = session.db().await;
3312 let channel_id = ChannelId::from_proto(request.channel_id);
3313
3314 let left_buffer = db
3315 .leave_channel_buffer(channel_id, session.connection_id)
3316 .await?;
3317
3318 response.send(Ack {})?;
3319
3320 channel_buffer_updated(
3321 session.connection_id,
3322 left_buffer.connections,
3323 &proto::UpdateChannelBufferCollaborators {
3324 channel_id: channel_id.to_proto(),
3325 collaborators: left_buffer.collaborators,
3326 },
3327 &session.peer,
3328 );
3329
3330 Ok(())
3331}
3332
3333fn channel_buffer_updated<T: EnvelopedMessage>(
3334 sender_id: ConnectionId,
3335 collaborators: impl IntoIterator<Item = ConnectionId>,
3336 message: &T,
3337 peer: &Peer,
3338) {
3339 broadcast(Some(sender_id), collaborators, |peer_id| {
3340 peer.send(peer_id, message.clone())
3341 });
3342}
3343
3344fn send_notifications(
3345 connection_pool: &ConnectionPool,
3346 peer: &Peer,
3347 notifications: db::NotificationBatch,
3348) {
3349 for (user_id, notification) in notifications {
3350 for connection_id in connection_pool.user_connection_ids(user_id) {
3351 if let Err(error) = peer.send(
3352 connection_id,
3353 proto::AddNotification {
3354 notification: Some(notification.clone()),
3355 },
3356 ) {
3357 tracing::error!(
3358 "failed to send notification to {:?} {}",
3359 connection_id,
3360 error
3361 );
3362 }
3363 }
3364 }
3365}
3366
3367/// Send a message to the channel
3368async fn send_channel_message(
3369 request: proto::SendChannelMessage,
3370 response: Response<proto::SendChannelMessage>,
3371 session: Session,
3372) -> Result<()> {
3373 // Validate the message body.
3374 let body = request.body.trim().to_string();
3375 if body.len() > MAX_MESSAGE_LEN {
3376 return Err(anyhow!("message is too long"))?;
3377 }
3378 if body.is_empty() {
3379 return Err(anyhow!("message can't be blank"))?;
3380 }
3381
3382 // TODO: adjust mentions if body is trimmed
3383
3384 let timestamp = OffsetDateTime::now_utc();
3385 let nonce = request
3386 .nonce
3387 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3388
3389 let channel_id = ChannelId::from_proto(request.channel_id);
3390 let CreatedChannelMessage {
3391 message_id,
3392 participant_connection_ids,
3393 notifications,
3394 } = session
3395 .db()
3396 .await
3397 .create_channel_message(
3398 channel_id,
3399 session.user_id(),
3400 &body,
3401 &request.mentions,
3402 timestamp,
3403 nonce.clone().into(),
3404 request.reply_to_message_id.map(MessageId::from_proto),
3405 )
3406 .await?;
3407
3408 let message = proto::ChannelMessage {
3409 sender_id: session.user_id().to_proto(),
3410 id: message_id.to_proto(),
3411 body,
3412 mentions: request.mentions,
3413 timestamp: timestamp.unix_timestamp() as u64,
3414 nonce: Some(nonce),
3415 reply_to_message_id: request.reply_to_message_id,
3416 edited_at: None,
3417 };
3418 broadcast(
3419 Some(session.connection_id),
3420 participant_connection_ids.clone(),
3421 |connection| {
3422 session.peer.send(
3423 connection,
3424 proto::ChannelMessageSent {
3425 channel_id: channel_id.to_proto(),
3426 message: Some(message.clone()),
3427 },
3428 )
3429 },
3430 );
3431 response.send(proto::SendChannelMessageResponse {
3432 message: Some(message),
3433 })?;
3434
3435 let pool = &*session.connection_pool().await;
3436 let non_participants =
3437 pool.channel_connection_ids(channel_id)
3438 .filter_map(|(connection_id, _)| {
3439 if participant_connection_ids.contains(&connection_id) {
3440 None
3441 } else {
3442 Some(connection_id)
3443 }
3444 });
3445 broadcast(None, non_participants, |peer_id| {
3446 session.peer.send(
3447 peer_id,
3448 proto::UpdateChannels {
3449 latest_channel_message_ids: vec![proto::ChannelMessageId {
3450 channel_id: channel_id.to_proto(),
3451 message_id: message_id.to_proto(),
3452 }],
3453 ..Default::default()
3454 },
3455 )
3456 });
3457 send_notifications(pool, &session.peer, notifications);
3458
3459 Ok(())
3460}
3461
3462/// Delete a channel message
3463async fn remove_channel_message(
3464 request: proto::RemoveChannelMessage,
3465 response: Response<proto::RemoveChannelMessage>,
3466 session: Session,
3467) -> Result<()> {
3468 let channel_id = ChannelId::from_proto(request.channel_id);
3469 let message_id = MessageId::from_proto(request.message_id);
3470 let (connection_ids, existing_notification_ids) = session
3471 .db()
3472 .await
3473 .remove_channel_message(channel_id, message_id, session.user_id())
3474 .await?;
3475
3476 broadcast(
3477 Some(session.connection_id),
3478 connection_ids,
3479 move |connection| {
3480 session.peer.send(connection, request.clone())?;
3481
3482 for notification_id in &existing_notification_ids {
3483 session.peer.send(
3484 connection,
3485 proto::DeleteNotification {
3486 notification_id: (*notification_id).to_proto(),
3487 },
3488 )?;
3489 }
3490
3491 Ok(())
3492 },
3493 );
3494 response.send(proto::Ack {})?;
3495 Ok(())
3496}
3497
3498async fn update_channel_message(
3499 request: proto::UpdateChannelMessage,
3500 response: Response<proto::UpdateChannelMessage>,
3501 session: Session,
3502) -> Result<()> {
3503 let channel_id = ChannelId::from_proto(request.channel_id);
3504 let message_id = MessageId::from_proto(request.message_id);
3505 let updated_at = OffsetDateTime::now_utc();
3506 let UpdatedChannelMessage {
3507 message_id,
3508 participant_connection_ids,
3509 notifications,
3510 reply_to_message_id,
3511 timestamp,
3512 deleted_mention_notification_ids,
3513 updated_mention_notifications,
3514 } = session
3515 .db()
3516 .await
3517 .update_channel_message(
3518 channel_id,
3519 message_id,
3520 session.user_id(),
3521 request.body.as_str(),
3522 &request.mentions,
3523 updated_at,
3524 )
3525 .await?;
3526
3527 let nonce = request
3528 .nonce
3529 .clone()
3530 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3531
3532 let message = proto::ChannelMessage {
3533 sender_id: session.user_id().to_proto(),
3534 id: message_id.to_proto(),
3535 body: request.body.clone(),
3536 mentions: request.mentions.clone(),
3537 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3538 nonce: Some(nonce),
3539 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3540 edited_at: Some(updated_at.unix_timestamp() as u64),
3541 };
3542
3543 response.send(proto::Ack {})?;
3544
3545 let pool = &*session.connection_pool().await;
3546 broadcast(
3547 Some(session.connection_id),
3548 participant_connection_ids,
3549 |connection| {
3550 session.peer.send(
3551 connection,
3552 proto::ChannelMessageUpdate {
3553 channel_id: channel_id.to_proto(),
3554 message: Some(message.clone()),
3555 },
3556 )?;
3557
3558 for notification_id in &deleted_mention_notification_ids {
3559 session.peer.send(
3560 connection,
3561 proto::DeleteNotification {
3562 notification_id: (*notification_id).to_proto(),
3563 },
3564 )?;
3565 }
3566
3567 for notification in &updated_mention_notifications {
3568 session.peer.send(
3569 connection,
3570 proto::UpdateNotification {
3571 notification: Some(notification.clone()),
3572 },
3573 )?;
3574 }
3575
3576 Ok(())
3577 },
3578 );
3579
3580 send_notifications(pool, &session.peer, notifications);
3581
3582 Ok(())
3583}
3584
3585/// Mark a channel message as read
3586async fn acknowledge_channel_message(
3587 request: proto::AckChannelMessage,
3588 session: Session,
3589) -> Result<()> {
3590 let channel_id = ChannelId::from_proto(request.channel_id);
3591 let message_id = MessageId::from_proto(request.message_id);
3592 let notifications = session
3593 .db()
3594 .await
3595 .observe_channel_message(channel_id, session.user_id(), message_id)
3596 .await?;
3597 send_notifications(
3598 &*session.connection_pool().await,
3599 &session.peer,
3600 notifications,
3601 );
3602 Ok(())
3603}
3604
3605/// Mark a buffer version as synced
3606async fn acknowledge_buffer_version(
3607 request: proto::AckBufferOperation,
3608 session: Session,
3609) -> Result<()> {
3610 let buffer_id = BufferId::from_proto(request.buffer_id);
3611 session
3612 .db()
3613 .await
3614 .observe_buffer_version(
3615 buffer_id,
3616 session.user_id(),
3617 request.epoch as i32,
3618 &request.version,
3619 )
3620 .await?;
3621 Ok(())
3622}
3623
3624async fn count_language_model_tokens(
3625 request: proto::CountLanguageModelTokens,
3626 response: Response<proto::CountLanguageModelTokens>,
3627 session: Session,
3628 config: &Config,
3629) -> Result<()> {
3630 authorize_access_to_legacy_llm_endpoints(&session).await?;
3631
3632 let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3633 proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
3634 proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
3635 };
3636
3637 session
3638 .app_state
3639 .rate_limiter
3640 .check(&*rate_limit, session.user_id())
3641 .await?;
3642
3643 let result = match proto::LanguageModelProvider::from_i32(request.provider) {
3644 Some(proto::LanguageModelProvider::Google) => {
3645 let api_key = config
3646 .google_ai_api_key
3647 .as_ref()
3648 .context("no Google AI API key configured on the server")?;
3649 google_ai::count_tokens(
3650 session.http_client.as_ref(),
3651 google_ai::API_URL,
3652 api_key,
3653 serde_json::from_str(&request.request)?,
3654 )
3655 .await?
3656 }
3657 _ => return Err(anyhow!("unsupported provider"))?,
3658 };
3659
3660 response.send(proto::CountLanguageModelTokensResponse {
3661 token_count: result.total_tokens as u32,
3662 })?;
3663
3664 Ok(())
3665}
3666
3667struct ZedProCountLanguageModelTokensRateLimit;
3668
3669impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
3670 fn capacity(&self) -> usize {
3671 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
3672 .ok()
3673 .and_then(|v| v.parse().ok())
3674 .unwrap_or(600) // Picked arbitrarily
3675 }
3676
3677 fn refill_duration(&self) -> chrono::Duration {
3678 chrono::Duration::hours(1)
3679 }
3680
3681 fn db_name(&self) -> &'static str {
3682 "zed-pro:count-language-model-tokens"
3683 }
3684}
3685
3686struct FreeCountLanguageModelTokensRateLimit;
3687
3688impl RateLimit for FreeCountLanguageModelTokensRateLimit {
3689 fn capacity(&self) -> usize {
3690 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
3691 .ok()
3692 .and_then(|v| v.parse().ok())
3693 .unwrap_or(600 / 10) // Picked arbitrarily
3694 }
3695
3696 fn refill_duration(&self) -> chrono::Duration {
3697 chrono::Duration::hours(1)
3698 }
3699
3700 fn db_name(&self) -> &'static str {
3701 "free:count-language-model-tokens"
3702 }
3703}
3704
3705struct ZedProComputeEmbeddingsRateLimit;
3706
3707impl RateLimit for ZedProComputeEmbeddingsRateLimit {
3708 fn capacity(&self) -> usize {
3709 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
3710 .ok()
3711 .and_then(|v| v.parse().ok())
3712 .unwrap_or(5000) // Picked arbitrarily
3713 }
3714
3715 fn refill_duration(&self) -> chrono::Duration {
3716 chrono::Duration::hours(1)
3717 }
3718
3719 fn db_name(&self) -> &'static str {
3720 "zed-pro:compute-embeddings"
3721 }
3722}
3723
3724struct FreeComputeEmbeddingsRateLimit;
3725
3726impl RateLimit for FreeComputeEmbeddingsRateLimit {
3727 fn capacity(&self) -> usize {
3728 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
3729 .ok()
3730 .and_then(|v| v.parse().ok())
3731 .unwrap_or(5000 / 10) // Picked arbitrarily
3732 }
3733
3734 fn refill_duration(&self) -> chrono::Duration {
3735 chrono::Duration::hours(1)
3736 }
3737
3738 fn db_name(&self) -> &'static str {
3739 "free:compute-embeddings"
3740 }
3741}
3742
3743async fn compute_embeddings(
3744 request: proto::ComputeEmbeddings,
3745 response: Response<proto::ComputeEmbeddings>,
3746 session: Session,
3747 api_key: Option<Arc<str>>,
3748) -> Result<()> {
3749 let api_key = api_key.context("no OpenAI API key configured on the server")?;
3750 authorize_access_to_legacy_llm_endpoints(&session).await?;
3751
3752 let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3753 proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
3754 proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
3755 };
3756
3757 session
3758 .app_state
3759 .rate_limiter
3760 .check(&*rate_limit, session.user_id())
3761 .await?;
3762
3763 let embeddings = match request.model.as_str() {
3764 "openai/text-embedding-3-small" => {
3765 open_ai::embed(
3766 session.http_client.as_ref(),
3767 OPEN_AI_API_URL,
3768 &api_key,
3769 OpenAiEmbeddingModel::TextEmbedding3Small,
3770 request.texts.iter().map(|text| text.as_str()),
3771 )
3772 .await?
3773 }
3774 provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
3775 };
3776
3777 let embeddings = request
3778 .texts
3779 .iter()
3780 .map(|text| {
3781 let mut hasher = sha2::Sha256::new();
3782 hasher.update(text.as_bytes());
3783 let result = hasher.finalize();
3784 result.to_vec()
3785 })
3786 .zip(
3787 embeddings
3788 .data
3789 .into_iter()
3790 .map(|embedding| embedding.embedding),
3791 )
3792 .collect::<HashMap<_, _>>();
3793
3794 let db = session.db().await;
3795 db.save_embeddings(&request.model, &embeddings)
3796 .await
3797 .context("failed to save embeddings")
3798 .trace_err();
3799
3800 response.send(proto::ComputeEmbeddingsResponse {
3801 embeddings: embeddings
3802 .into_iter()
3803 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3804 .collect(),
3805 })?;
3806 Ok(())
3807}
3808
3809async fn get_cached_embeddings(
3810 request: proto::GetCachedEmbeddings,
3811 response: Response<proto::GetCachedEmbeddings>,
3812 session: Session,
3813) -> Result<()> {
3814 authorize_access_to_legacy_llm_endpoints(&session).await?;
3815
3816 let db = session.db().await;
3817 let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
3818
3819 response.send(proto::GetCachedEmbeddingsResponse {
3820 embeddings: embeddings
3821 .into_iter()
3822 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3823 .collect(),
3824 })?;
3825 Ok(())
3826}
3827
3828/// This is leftover from before the LLM service.
3829///
3830/// The endpoints protected by this check will be moved there eventually.
3831async fn authorize_access_to_legacy_llm_endpoints(session: &Session) -> Result<(), Error> {
3832 if session.is_staff() {
3833 Ok(())
3834 } else {
3835 Err(anyhow!("permission denied"))?
3836 }
3837}
3838
3839/// Get a Supermaven API key for the user
3840async fn get_supermaven_api_key(
3841 _request: proto::GetSupermavenApiKey,
3842 response: Response<proto::GetSupermavenApiKey>,
3843 session: Session,
3844) -> Result<()> {
3845 let user_id: String = session.user_id().to_string();
3846 if !session.is_staff() {
3847 return Err(anyhow!("supermaven not enabled for this account"))?;
3848 }
3849
3850 let email = session
3851 .email()
3852 .ok_or_else(|| anyhow!("user must have an email"))?;
3853
3854 let supermaven_admin_api = session
3855 .supermaven_client
3856 .as_ref()
3857 .ok_or_else(|| anyhow!("supermaven not configured"))?;
3858
3859 let result = supermaven_admin_api
3860 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3861 .await?;
3862
3863 response.send(proto::GetSupermavenApiKeyResponse {
3864 api_key: result.api_key,
3865 })?;
3866
3867 Ok(())
3868}
3869
3870/// Start receiving chat updates for a channel
3871async fn join_channel_chat(
3872 request: proto::JoinChannelChat,
3873 response: Response<proto::JoinChannelChat>,
3874 session: Session,
3875) -> Result<()> {
3876 let channel_id = ChannelId::from_proto(request.channel_id);
3877
3878 let db = session.db().await;
3879 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3880 .await?;
3881 let messages = db
3882 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3883 .await?;
3884 response.send(proto::JoinChannelChatResponse {
3885 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3886 messages,
3887 })?;
3888 Ok(())
3889}
3890
3891/// Stop receiving chat updates for a channel
3892async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3893 let channel_id = ChannelId::from_proto(request.channel_id);
3894 session
3895 .db()
3896 .await
3897 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3898 .await?;
3899 Ok(())
3900}
3901
3902/// Retrieve the chat history for a channel
3903async fn get_channel_messages(
3904 request: proto::GetChannelMessages,
3905 response: Response<proto::GetChannelMessages>,
3906 session: Session,
3907) -> Result<()> {
3908 let channel_id = ChannelId::from_proto(request.channel_id);
3909 let messages = session
3910 .db()
3911 .await
3912 .get_channel_messages(
3913 channel_id,
3914 session.user_id(),
3915 MESSAGE_COUNT_PER_PAGE,
3916 Some(MessageId::from_proto(request.before_message_id)),
3917 )
3918 .await?;
3919 response.send(proto::GetChannelMessagesResponse {
3920 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3921 messages,
3922 })?;
3923 Ok(())
3924}
3925
3926/// Retrieve specific chat messages
3927async fn get_channel_messages_by_id(
3928 request: proto::GetChannelMessagesById,
3929 response: Response<proto::GetChannelMessagesById>,
3930 session: Session,
3931) -> Result<()> {
3932 let message_ids = request
3933 .message_ids
3934 .iter()
3935 .map(|id| MessageId::from_proto(*id))
3936 .collect::<Vec<_>>();
3937 let messages = session
3938 .db()
3939 .await
3940 .get_channel_messages_by_id(session.user_id(), &message_ids)
3941 .await?;
3942 response.send(proto::GetChannelMessagesResponse {
3943 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3944 messages,
3945 })?;
3946 Ok(())
3947}
3948
3949/// Retrieve the current users notifications
3950async fn get_notifications(
3951 request: proto::GetNotifications,
3952 response: Response<proto::GetNotifications>,
3953 session: Session,
3954) -> Result<()> {
3955 let notifications = session
3956 .db()
3957 .await
3958 .get_notifications(
3959 session.user_id(),
3960 NOTIFICATION_COUNT_PER_PAGE,
3961 request.before_id.map(db::NotificationId::from_proto),
3962 )
3963 .await?;
3964 response.send(proto::GetNotificationsResponse {
3965 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3966 notifications,
3967 })?;
3968 Ok(())
3969}
3970
3971/// Mark notifications as read
3972async fn mark_notification_as_read(
3973 request: proto::MarkNotificationRead,
3974 response: Response<proto::MarkNotificationRead>,
3975 session: Session,
3976) -> Result<()> {
3977 let database = &session.db().await;
3978 let notifications = database
3979 .mark_notification_as_read_by_id(
3980 session.user_id(),
3981 NotificationId::from_proto(request.notification_id),
3982 )
3983 .await?;
3984 send_notifications(
3985 &*session.connection_pool().await,
3986 &session.peer,
3987 notifications,
3988 );
3989 response.send(proto::Ack {})?;
3990 Ok(())
3991}
3992
3993/// Get the current users information
3994async fn get_private_user_info(
3995 _request: proto::GetPrivateUserInfo,
3996 response: Response<proto::GetPrivateUserInfo>,
3997 session: Session,
3998) -> Result<()> {
3999 let db = session.db().await;
4000
4001 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4002 let user = db
4003 .get_user_by_id(session.user_id())
4004 .await?
4005 .ok_or_else(|| anyhow!("user not found"))?;
4006 let flags = db.get_user_flags(session.user_id()).await?;
4007
4008 response.send(proto::GetPrivateUserInfoResponse {
4009 metrics_id,
4010 staff: user.admin,
4011 flags,
4012 accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4013 })?;
4014 Ok(())
4015}
4016
4017/// Accept the terms of service (tos) on behalf of the current user
4018async fn accept_terms_of_service(
4019 _request: proto::AcceptTermsOfService,
4020 response: Response<proto::AcceptTermsOfService>,
4021 session: Session,
4022) -> Result<()> {
4023 let db = session.db().await;
4024
4025 let accepted_tos_at = Utc::now();
4026 db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4027 .await?;
4028
4029 response.send(proto::AcceptTermsOfServiceResponse {
4030 accepted_tos_at: accepted_tos_at.timestamp() as u64,
4031 })?;
4032 Ok(())
4033}
4034
/// The minimum account age an account must have in order to use the LLM service.
///
/// `get_llm_api_token` measures age from the earlier of the Zed account's creation
/// time and the linked GitHub account's creation time. The check is skipped for
/// users with an LLM subscription or the `bypass-account-age-check` flag.
const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4037
4038async fn get_llm_api_token(
4039 _request: proto::GetLlmToken,
4040 response: Response<proto::GetLlmToken>,
4041 session: Session,
4042) -> Result<()> {
4043 let db = session.db().await;
4044
4045 let flags = db.get_user_flags(session.user_id()).await?;
4046 let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
4047 let has_llm_closed_beta_feature_flag = flags.iter().any(|flag| flag == "llm-closed-beta");
4048 let has_predict_edits_feature_flag = flags.iter().any(|flag| flag == "predict-edits");
4049
4050 if !session.is_staff() && !has_language_models_feature_flag {
4051 Err(anyhow!("permission denied"))?
4052 }
4053
4054 let user_id = session.user_id();
4055 let user = db
4056 .get_user_by_id(user_id)
4057 .await?
4058 .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4059
4060 if user.accepted_tos_at.is_none() {
4061 Err(anyhow!("terms of service not accepted"))?
4062 }
4063
4064 let has_llm_subscription = session.has_llm_subscription(&db).await?;
4065
4066 let bypass_account_age_check =
4067 has_llm_subscription || flags.iter().any(|flag| flag == "bypass-account-age-check");
4068 if !bypass_account_age_check {
4069 let mut account_created_at = user.created_at;
4070 if let Some(github_created_at) = user.github_user_created_at {
4071 account_created_at = account_created_at.min(github_created_at);
4072 }
4073 if Utc::now().naive_utc() - account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE {
4074 Err(anyhow!("account too young"))?
4075 }
4076 }
4077
4078 let billing_preferences = db.get_billing_preferences(user.id).await?;
4079
4080 let token = LlmTokenClaims::create(
4081 &user,
4082 session.is_staff(),
4083 billing_preferences,
4084 has_llm_closed_beta_feature_flag,
4085 has_predict_edits_feature_flag,
4086 has_llm_subscription,
4087 session.current_plan(&db).await?,
4088 session.system_id.clone(),
4089 &session.app_state.config,
4090 )?;
4091 response.send(proto::GetLlmTokenResponse { token })?;
4092 Ok(())
4093}
4094
4095fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4096 let message = match message {
4097 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4098 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4099 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4100 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4101 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4102 code: frame.code.into(),
4103 reason: frame.reason,
4104 })),
4105 // We should never receive a frame while reading the message, according
4106 // to the `tungstenite` maintainers:
4107 //
4108 // > It cannot occur when you read messages from the WebSocket, but it
4109 // > can be used when you want to send the raw frames (e.g. you want to
4110 // > send the frames to the WebSocket without composing the full message first).
4111 // >
4112 // > — https://github.com/snapview/tungstenite-rs/issues/268
4113 TungsteniteMessage::Frame(_) => {
4114 bail!("received an unexpected frame while reading the message")
4115 }
4116 };
4117
4118 Ok(message)
4119}
4120
4121fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4122 match message {
4123 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4124 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4125 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4126 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4127 AxumMessage::Close(frame) => {
4128 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4129 code: frame.code.into(),
4130 reason: frame.reason,
4131 }))
4132 }
4133 }
4134}
4135
4136fn notify_membership_updated(
4137 connection_pool: &mut ConnectionPool,
4138 result: MembershipUpdated,
4139 user_id: UserId,
4140 peer: &Peer,
4141) {
4142 for membership in &result.new_channels.channel_memberships {
4143 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4144 }
4145 for channel_id in &result.removed_channels {
4146 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4147 }
4148
4149 let user_channels_update = proto::UpdateUserChannels {
4150 channel_memberships: result
4151 .new_channels
4152 .channel_memberships
4153 .iter()
4154 .map(|cm| proto::ChannelMembership {
4155 channel_id: cm.channel_id.to_proto(),
4156 role: cm.role.into(),
4157 })
4158 .collect(),
4159 ..Default::default()
4160 };
4161
4162 let mut update = build_channels_update(result.new_channels);
4163 update.delete_channels = result
4164 .removed_channels
4165 .into_iter()
4166 .map(|id| id.to_proto())
4167 .collect();
4168 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4169
4170 for connection_id in connection_pool.user_connection_ids(user_id) {
4171 peer.send(connection_id, user_channels_update.clone())
4172 .trace_err();
4173 peer.send(connection_id, update.clone()).trace_err();
4174 }
4175}
4176
4177fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4178 proto::UpdateUserChannels {
4179 channel_memberships: channels
4180 .channel_memberships
4181 .iter()
4182 .map(|m| proto::ChannelMembership {
4183 channel_id: m.channel_id.to_proto(),
4184 role: m.role.into(),
4185 })
4186 .collect(),
4187 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4188 observed_channel_message_id: channels.observed_channel_messages.clone(),
4189 }
4190}
4191
4192fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4193 let mut update = proto::UpdateChannels::default();
4194
4195 for channel in channels.channels {
4196 update.channels.push(channel.to_proto());
4197 }
4198
4199 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4200 update.latest_channel_message_ids = channels.latest_channel_messages;
4201
4202 for (channel_id, participants) in channels.channel_participants {
4203 update
4204 .channel_participants
4205 .push(proto::ChannelParticipants {
4206 channel_id: channel_id.to_proto(),
4207 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4208 });
4209 }
4210
4211 for channel in channels.invited_channels {
4212 update.channel_invitations.push(channel.to_proto());
4213 }
4214
4215 update
4216}
4217
4218fn build_initial_contacts_update(
4219 contacts: Vec<db::Contact>,
4220 pool: &ConnectionPool,
4221) -> proto::UpdateContacts {
4222 let mut update = proto::UpdateContacts::default();
4223
4224 for contact in contacts {
4225 match contact {
4226 db::Contact::Accepted { user_id, busy } => {
4227 update.contacts.push(contact_for_user(user_id, busy, pool));
4228 }
4229 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4230 db::Contact::Incoming { user_id } => {
4231 update
4232 .incoming_requests
4233 .push(proto::IncomingContactRequest {
4234 requester_id: user_id.to_proto(),
4235 })
4236 }
4237 }
4238 }
4239
4240 update
4241}
4242
4243fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4244 proto::Contact {
4245 user_id: user_id.to_proto(),
4246 online: pool.is_user_online(user_id),
4247 busy,
4248 }
4249}
4250
4251fn room_updated(room: &proto::Room, peer: &Peer) {
4252 broadcast(
4253 None,
4254 room.participants
4255 .iter()
4256 .filter_map(|participant| Some(participant.peer_id?.into())),
4257 |peer_id| {
4258 peer.send(
4259 peer_id,
4260 proto::RoomUpdated {
4261 room: Some(room.clone()),
4262 },
4263 )
4264 },
4265 );
4266}
4267
4268fn channel_updated(
4269 channel: &db::channel::Model,
4270 room: &proto::Room,
4271 peer: &Peer,
4272 pool: &ConnectionPool,
4273) {
4274 let participants = room
4275 .participants
4276 .iter()
4277 .map(|p| p.user_id)
4278 .collect::<Vec<_>>();
4279
4280 broadcast(
4281 None,
4282 pool.channel_connection_ids(channel.root_id())
4283 .filter_map(|(channel_id, role)| {
4284 role.can_see_channel(channel.visibility)
4285 .then_some(channel_id)
4286 }),
4287 |peer_id| {
4288 peer.send(
4289 peer_id,
4290 proto::UpdateChannels {
4291 channel_participants: vec![proto::ChannelParticipants {
4292 channel_id: channel.id.to_proto(),
4293 participant_user_ids: participants.clone(),
4294 }],
4295 ..Default::default()
4296 },
4297 )
4298 },
4299 );
4300}
4301
4302async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4303 let db = session.db().await;
4304
4305 let contacts = db.get_contacts(user_id).await?;
4306 let busy = db.is_user_busy(user_id).await?;
4307
4308 let pool = session.connection_pool().await;
4309 let updated_contact = contact_for_user(user_id, busy, &pool);
4310 for contact in contacts {
4311 if let db::Contact::Accepted {
4312 user_id: contact_user_id,
4313 ..
4314 } = contact
4315 {
4316 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4317 session
4318 .peer
4319 .send(
4320 contact_conn_id,
4321 proto::UpdateContacts {
4322 contacts: vec![updated_contact.clone()],
4323 remove_contacts: Default::default(),
4324 incoming_requests: Default::default(),
4325 remove_incoming_requests: Default::default(),
4326 outgoing_requests: Default::default(),
4327 remove_outgoing_requests: Default::default(),
4328 },
4329 )
4330 .trace_err();
4331 }
4332 }
4333 }
4334 Ok(())
4335}
4336
4337async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
4338 let mut contacts_to_update = HashSet::default();
4339
4340 let room_id;
4341 let canceled_calls_to_user_ids;
4342 let livekit_room;
4343 let delete_livekit_room;
4344 let room;
4345 let channel;
4346
4347 if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
4348 contacts_to_update.insert(session.user_id());
4349
4350 for project in left_room.left_projects.values() {
4351 project_left(project, session);
4352 }
4353
4354 room_id = RoomId::from_proto(left_room.room.id);
4355 canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
4356 livekit_room = mem::take(&mut left_room.room.livekit_room);
4357 delete_livekit_room = left_room.deleted;
4358 room = mem::take(&mut left_room.room);
4359 channel = mem::take(&mut left_room.channel);
4360
4361 room_updated(&room, &session.peer);
4362 } else {
4363 return Ok(());
4364 }
4365
4366 if let Some(channel) = channel {
4367 channel_updated(
4368 &channel,
4369 &room,
4370 &session.peer,
4371 &*session.connection_pool().await,
4372 );
4373 }
4374
4375 {
4376 let pool = session.connection_pool().await;
4377 for canceled_user_id in canceled_calls_to_user_ids {
4378 for connection_id in pool.user_connection_ids(canceled_user_id) {
4379 session
4380 .peer
4381 .send(
4382 connection_id,
4383 proto::CallCanceled {
4384 room_id: room_id.to_proto(),
4385 },
4386 )
4387 .trace_err();
4388 }
4389 contacts_to_update.insert(canceled_user_id);
4390 }
4391 }
4392
4393 for contact_user_id in contacts_to_update {
4394 update_user_contacts(contact_user_id, session).await?;
4395 }
4396
4397 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
4398 live_kit
4399 .remove_participant(livekit_room.clone(), session.user_id().to_string())
4400 .await
4401 .trace_err();
4402
4403 if delete_livekit_room {
4404 live_kit.delete_room(livekit_room).await.trace_err();
4405 }
4406 }
4407
4408 Ok(())
4409}
4410
4411async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4412 let left_channel_buffers = session
4413 .db()
4414 .await
4415 .leave_channel_buffers(session.connection_id)
4416 .await?;
4417
4418 for left_buffer in left_channel_buffers {
4419 channel_buffer_updated(
4420 session.connection_id,
4421 left_buffer.connections,
4422 &proto::UpdateChannelBufferCollaborators {
4423 channel_id: left_buffer.channel_id.to_proto(),
4424 collaborators: left_buffer.collaborators,
4425 },
4426 &session.peer,
4427 );
4428 }
4429
4430 Ok(())
4431}
4432
4433fn project_left(project: &db::LeftProject, session: &Session) {
4434 for connection_id in &project.connection_ids {
4435 if project.should_unshare {
4436 session
4437 .peer
4438 .send(
4439 *connection_id,
4440 proto::UnshareProject {
4441 project_id: project.id.to_proto(),
4442 },
4443 )
4444 .trace_err();
4445 } else {
4446 session
4447 .peer
4448 .send(
4449 *connection_id,
4450 proto::RemoveProjectCollaborator {
4451 project_id: project.id.to_proto(),
4452 peer_id: Some(session.connection_id.into()),
4453 },
4454 )
4455 .trace_err();
4456 }
4457 }
4458}
4459
4460pub trait ResultExt {
4461 type Ok;
4462
4463 fn trace_err(self) -> Option<Self::Ok>;
4464}
4465
4466impl<T, E> ResultExt for Result<T, E>
4467where
4468 E: std::fmt::Debug,
4469{
4470 type Ok = T;
4471
4472 #[track_caller]
4473 fn trace_err(self) -> Option<T> {
4474 match self {
4475 Ok(value) => Some(value),
4476 Err(error) => {
4477 tracing::error!("{:?}", error);
4478 None
4479 }
4480 }
4481 }
4482}