1mod connection_pool;
2
3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
4use crate::llm::LlmTokenClaims;
5use crate::{
6 auth,
7 db::{
8 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
9 CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
10 NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
11 RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
12 },
13 executor::Executor,
14 AppState, Config, Error, RateLimit, Result,
15};
16use anyhow::{anyhow, bail, Context as _};
17use async_tungstenite::tungstenite::{
18 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
19};
20use axum::{
21 body::Body,
22 extract::{
23 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
24 ConnectInfo, WebSocketUpgrade,
25 },
26 headers::{Header, HeaderName},
27 http::StatusCode,
28 middleware,
29 response::IntoResponse,
30 routing::get,
31 Extension, Router, TypedHeader,
32};
33use chrono::Utc;
34use collections::{HashMap, HashSet};
35pub use connection_pool::{ConnectionPool, ZedVersion};
36use core::fmt::{self, Debug, Formatter};
37use http_client::HttpClient;
38use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
39use reqwest_client::ReqwestClient;
40use sha2::Digest;
41use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
42
43use futures::{
44 channel::oneshot, future::BoxFuture, stream::FuturesUnordered, FutureExt, SinkExt, StreamExt,
45 TryStreamExt,
46};
47use prometheus::{register_int_gauge, IntGauge};
48use rpc::{
49 proto::{
50 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
51 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
52 },
53 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
54};
55use semantic_version::SemanticVersion;
56use serde::{Serialize, Serializer};
57use std::{
58 any::TypeId,
59 future::Future,
60 marker::PhantomData,
61 mem,
62 net::SocketAddr,
63 ops::{Deref, DerefMut},
64 rc::Rc,
65 sync::{
66 atomic::{AtomicBool, Ordering::SeqCst},
67 Arc, OnceLock,
68 },
69 time::{Duration, Instant},
70};
71use time::OffsetDateTime;
72use tokio::sync::{watch, MutexGuard, Semaphore};
73use tower::ServiceBuilder;
74use tracing::{
75 field::{self},
76 info_span, instrument, Instrument,
77};
78
/// Reconnection grace period (30s). Not referenced in this part of the file;
/// NOTE(review): confirm exact semantics at its use sites.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// How long `Server::start` waits before cleaning up rooms, channel buffers and
/// servers left behind by previous server instances.
// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// NOTE(review): the three constants below are not used in this part of the file; the
// descriptions are inferred from their names — confirm at their use sites.
/// Presumably the page size used when fetching channel messages.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Presumably the maximum accepted length of a channel message body.
const MAX_MESSAGE_LEN: usize = 1024;
/// Presumably the page size used when fetching notifications.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;

/// Type-erased handler stored in `Server::handlers`: takes a boxed envelope plus the
/// session it arrived on, and returns a boxed future that processes it.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
90
91struct Response<R> {
92 peer: Arc<Peer>,
93 receipt: Receipt<R>,
94 responded: Arc<AtomicBool>,
95}
96
97impl<R: RequestMessage> Response<R> {
98 fn send(self, payload: R::Response) -> Result<()> {
99 self.responded.store(true, SeqCst);
100 self.peer.respond(self.receipt, payload)?;
101 Ok(())
102 }
103}
104
105#[derive(Clone, Debug)]
106pub enum Principal {
107 User(User),
108 Impersonated { user: User, admin: User },
109}
110
111impl Principal {
112 fn update_span(&self, span: &tracing::Span) {
113 match &self {
114 Principal::User(user) => {
115 span.record("user_id", user.id.0);
116 span.record("login", &user.github_login);
117 }
118 Principal::Impersonated { user, admin } => {
119 span.record("user_id", user.id.0);
120 span.record("login", &user.github_login);
121 span.record("impersonator", &admin.github_login);
122 }
123 }
124 }
125}
126
/// Per-connection state handed to every message handler.
///
/// A clone is passed to each handler invocation (see `handle_connection`), so shared
/// state lives behind `Arc`s.
#[derive(Clone)]
struct Session {
    /// Who is connected: a user, or an admin impersonating a user.
    principal: Principal,
    /// The peer connection this session belongs to.
    connection_id: ConnectionId,
    /// Database handle; handlers acquire it through `Session::db`, which locks this mutex.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// The same pool the owning `Server` holds.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Set only when a Supermaven admin API key is configured.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    /// HTTP client created when this connection was opened.
    http_client: Arc<dyn HttpClient>,
    /// The GeoIP country code for the user.
    // Not read anywhere in this part of the file, hence the lint allowance;
    // NOTE(review): confirm it is still needed.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    /// System id supplied to `handle_connection`, if any.
    system_id: Option<String>,
    /// Executor handle retained by the session (underscore-prefixed: not read directly).
    _executor: Executor,
}
143
144impl Session {
145 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
146 #[cfg(test)]
147 tokio::task::yield_now().await;
148 let guard = self.db.lock().await;
149 #[cfg(test)]
150 tokio::task::yield_now().await;
151 guard
152 }
153
154 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
155 #[cfg(test)]
156 tokio::task::yield_now().await;
157 let guard = self.connection_pool.lock();
158 ConnectionPoolGuard {
159 guard,
160 _not_send: PhantomData,
161 }
162 }
163
164 fn is_staff(&self) -> bool {
165 match &self.principal {
166 Principal::User(user) => user.admin,
167 Principal::Impersonated { .. } => true,
168 }
169 }
170
171 pub async fn has_llm_subscription(
172 &self,
173 db: &MutexGuard<'_, DbHandle>,
174 ) -> anyhow::Result<bool> {
175 if self.is_staff() {
176 return Ok(true);
177 }
178
179 let user_id = self.user_id();
180
181 Ok(db.has_active_billing_subscription(user_id).await?)
182 }
183
184 pub async fn current_plan(
185 &self,
186 _db: &MutexGuard<'_, DbHandle>,
187 ) -> anyhow::Result<proto::Plan> {
188 if self.is_staff() {
189 Ok(proto::Plan::ZedPro)
190 } else {
191 Ok(proto::Plan::Free)
192 }
193 }
194
195 fn user_id(&self) -> UserId {
196 match &self.principal {
197 Principal::User(user) => user.id,
198 Principal::Impersonated { user, .. } => user.id,
199 }
200 }
201
202 pub fn email(&self) -> Option<String> {
203 match &self.principal {
204 Principal::User(user) => user.email_address.clone(),
205 Principal::Impersonated { user, .. } => user.email_address.clone(),
206 }
207 }
208}
209
210impl Debug for Session {
211 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
212 let mut result = f.debug_struct("Session");
213 match &self.principal {
214 Principal::User(user) => {
215 result.field("user", &user.github_login);
216 }
217 Principal::Impersonated { user, admin } => {
218 result.field("user", &user.github_login);
219 result.field("impersonator", &admin.github_login);
220 }
221 }
222 result.field("connection_id", &self.connection_id).finish()
223 }
224}
225
226struct DbHandle(Arc<Database>);
227
228impl Deref for DbHandle {
229 type Target = Database;
230
231 fn deref(&self) -> &Self::Target {
232 self.0.as_ref()
233 }
234}
235
/// The collaboration RPC server. It owns the `Peer`, the pool of connected clients,
/// and the table of per-message-type handlers.
pub struct Server {
    /// Current server id; the test-only `reset` writes a new one through this mutex.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    /// Connected clients; each connection's `Session` shares this same pool.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Registered handlers keyed by the `TypeId` of the message type they handle.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Set to `true` by `teardown`; connection tasks subscribe and stop when it changes.
    teardown: watch::Sender<bool>,
}
244
/// A locked `ConnectionPool`.
///
/// The `PhantomData<Rc<()>>` marker makes this type `!Send`. An `async` block that keeps
/// the guard alive across an `.await` therefore cannot be `Send`, which discourages
/// holding the synchronous pool lock while suspended.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
249
/// Serializable view of the server's peer and connection pool.
///
/// It holds the connection-pool lock for as long as the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized through the guard's `Deref` target, i.e. the pool itself.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
256
257pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
258where
259 S: Serializer,
260 T: Deref<Target = U>,
261 U: Serialize,
262{
263 Serialize::serialize(value.deref(), serializer)
264}
265
266impl Server {
267 pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
268 let mut server = Self {
269 id: parking_lot::Mutex::new(id),
270 peer: Peer::new(id.0 as u32),
271 app_state: app_state.clone(),
272 connection_pool: Default::default(),
273 handlers: Default::default(),
274 teardown: watch::channel(false).0,
275 };
276
277 server
278 .add_request_handler(ping)
279 .add_request_handler(create_room)
280 .add_request_handler(join_room)
281 .add_request_handler(rejoin_room)
282 .add_request_handler(leave_room)
283 .add_request_handler(set_room_participant_role)
284 .add_request_handler(call)
285 .add_request_handler(cancel_call)
286 .add_message_handler(decline_call)
287 .add_request_handler(update_participant_location)
288 .add_request_handler(share_project)
289 .add_message_handler(unshare_project)
290 .add_request_handler(join_project)
291 .add_message_handler(leave_project)
292 .add_request_handler(update_project)
293 .add_request_handler(update_worktree)
294 .add_message_handler(start_language_server)
295 .add_message_handler(update_language_server)
296 .add_message_handler(update_diagnostic_summary)
297 .add_message_handler(update_worktree_settings)
298 .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
299 .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
300 .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
301 .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
302 .add_request_handler(forward_find_search_candidates_request)
303 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
304 .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
305 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
306 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
307 .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
308 .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
309 .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
310 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
311 .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
312 .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
313 .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
314 .add_request_handler(
315 forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
316 )
317 .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
318 .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
319 .add_request_handler(
320 forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
321 )
322 .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
323 .add_request_handler(
324 forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
325 )
326 .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
327 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
328 .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
329 .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
330 .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
331 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
332 .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
333 .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
334 .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
335 .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
336 .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
337 .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
338 .add_request_handler(
339 forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
340 )
341 .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
342 .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
343 .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
344 .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
345 .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
346 .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
347 .add_message_handler(create_buffer_for_peer)
348 .add_request_handler(update_buffer)
349 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
350 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
351 .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
352 .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
353 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
354 .add_request_handler(get_users)
355 .add_request_handler(fuzzy_search_users)
356 .add_request_handler(request_contact)
357 .add_request_handler(remove_contact)
358 .add_request_handler(respond_to_contact_request)
359 .add_message_handler(subscribe_to_channels)
360 .add_request_handler(create_channel)
361 .add_request_handler(delete_channel)
362 .add_request_handler(invite_channel_member)
363 .add_request_handler(remove_channel_member)
364 .add_request_handler(set_channel_member_role)
365 .add_request_handler(set_channel_visibility)
366 .add_request_handler(rename_channel)
367 .add_request_handler(join_channel_buffer)
368 .add_request_handler(leave_channel_buffer)
369 .add_message_handler(update_channel_buffer)
370 .add_request_handler(rejoin_channel_buffers)
371 .add_request_handler(get_channel_members)
372 .add_request_handler(respond_to_channel_invite)
373 .add_request_handler(join_channel)
374 .add_request_handler(join_channel_chat)
375 .add_message_handler(leave_channel_chat)
376 .add_request_handler(send_channel_message)
377 .add_request_handler(remove_channel_message)
378 .add_request_handler(update_channel_message)
379 .add_request_handler(get_channel_messages)
380 .add_request_handler(get_channel_messages_by_id)
381 .add_request_handler(get_notifications)
382 .add_request_handler(mark_notification_as_read)
383 .add_request_handler(move_channel)
384 .add_request_handler(follow)
385 .add_message_handler(unfollow)
386 .add_message_handler(update_followers)
387 .add_request_handler(get_private_user_info)
388 .add_request_handler(get_llm_api_token)
389 .add_request_handler(accept_terms_of_service)
390 .add_message_handler(acknowledge_channel_message)
391 .add_message_handler(acknowledge_buffer_version)
392 .add_request_handler(get_supermaven_api_key)
393 .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
394 .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
395 .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
396 .add_request_handler(forward_mutating_project_request::<proto::Stage>)
397 .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
398 .add_request_handler(forward_mutating_project_request::<proto::Commit>)
399 .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
400 .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
401 .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
402 .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
403 .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
404 .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
405 .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
406 .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
407 .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
408 .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
409 .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
410 .add_message_handler(update_context)
411 .add_request_handler({
412 let app_state = app_state.clone();
413 move |request, response, session| {
414 let app_state = app_state.clone();
415 async move {
416 count_language_model_tokens(request, response, session, &app_state.config)
417 .await
418 }
419 }
420 })
421 .add_request_handler(get_cached_embeddings)
422 .add_request_handler({
423 let app_state = app_state.clone();
424 move |request, response, session| {
425 compute_embeddings(
426 request,
427 response,
428 session,
429 app_state.config.openai_api_key.clone(),
430 )
431 }
432 });
433
434 Arc::new(server)
435 }
436
    /// Schedules cleanup of resources left behind by previous server instances.
    ///
    /// Returns `Ok(())` right after spawning a detached task, which is instrumented with a
    /// `start server` span. That task waits for `CLEANUP_TIMEOUT` and then asks the
    /// database for stale room and channel-buffer ids for this environment and server id.
    /// It then does the following:
    /// - clears stale channel-buffer collaborators and sends the refreshed collaborator
    ///   list to the remaining connections;
    /// - for each stale room, clears stale participants and broadcasts the updated room
    ///   (and its channel, if any). It sends `CallCanceled` to users whose calls were
    ///   canceled, pushes updated contact info to the accepted contacts of affected users,
    ///   and deletes the LiveKit room once no participants remain;
    /// - finally asks the database to delete the stale servers.
    ///
    /// Failures inside the task are logged via `trace_err` rather than returned.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // The sleep future is created here and awaited inside the spawned task.
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            // Tell every remaining collaborator about the pruned list.
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Defaults used when the room could not be refreshed: no contacts to
                        // update, no calls to cancel, and the LiveKit room is left alone.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the pool lock is released before the awaits below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push each affected user's refreshed contact entry (busy state) to
                        // every connection of each of their accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Only delete the LiveKit room when a client is configured and the
                        // refreshed room has no participants left.
                        if let Some(live_kit) = livekit_client.as_ref() {
                            if delete_livekit_room {
                                live_kit.delete_room(livekit_room).await.trace_err();
                            }
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
582
583 pub fn teardown(&self) {
584 self.peer.teardown();
585 self.connection_pool.lock().reset();
586 let _ = self.teardown.send(true);
587 }
588
589 #[cfg(test)]
590 pub fn reset(&self, id: ServerId) {
591 self.teardown();
592 *self.id.lock() = id;
593 self.peer.reset(id.0 as u32);
594 let _ = self.teardown.send(false);
595 }
596
597 #[cfg(test)]
598 pub fn id(&self) -> ServerId {
599 *self.id.lock()
600 }
601
602 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
603 where
604 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
605 Fut: 'static + Send + Future<Output = Result<()>>,
606 M: EnvelopedMessage,
607 {
608 let prev_handler = self.handlers.insert(
609 TypeId::of::<M>(),
610 Box::new(move |envelope, session| {
611 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
612 let received_at = envelope.received_at;
613 tracing::info!("message received");
614 let start_time = Instant::now();
615 let future = (handler)(*envelope, session);
616 async move {
617 let result = future.await;
618 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
619 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
620 let queue_duration_ms = total_duration_ms - processing_duration_ms;
621 let payload_type = M::NAME;
622
623 match result {
624 Err(error) => {
625 tracing::error!(
626 ?error,
627 total_duration_ms,
628 processing_duration_ms,
629 queue_duration_ms,
630 payload_type,
631 "error handling message"
632 )
633 }
634 Ok(()) => tracing::info!(
635 total_duration_ms,
636 processing_duration_ms,
637 queue_duration_ms,
638 "finished handling message"
639 ),
640 }
641 }
642 .boxed()
643 }),
644 );
645 if prev_handler.is_some() {
646 panic!("registered a handler for the same message twice");
647 }
648 self
649 }
650
651 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
652 where
653 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
654 Fut: 'static + Send + Future<Output = Result<()>>,
655 M: EnvelopedMessage,
656 {
657 self.add_handler(move |envelope, session| handler(envelope.payload, session));
658 self
659 }
660
661 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
662 where
663 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
664 Fut: Send + Future<Output = Result<()>>,
665 M: RequestMessage,
666 {
667 let handler = Arc::new(handler);
668 self.add_handler(move |envelope, session| {
669 let receipt = envelope.receipt();
670 let handler = handler.clone();
671 async move {
672 let peer = session.peer.clone();
673 let responded = Arc::new(AtomicBool::default());
674 let response = Response {
675 peer: peer.clone(),
676 responded: responded.clone(),
677 receipt,
678 };
679 match (handler)(envelope.payload, response, session).await {
680 Ok(()) => {
681 if responded.load(std::sync::atomic::Ordering::SeqCst) {
682 Ok(())
683 } else {
684 Err(anyhow!("handler did not send a response"))?
685 }
686 }
687 Err(error) => {
688 let proto_err = match &error {
689 Error::Internal(err) => err.to_proto(),
690 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
691 };
692 peer.respond_with_error(receipt, proto_err)?;
693 Err(error)
694 }
695 }
696 }
697 })
698 }
699
    /// Builds the future that drives one client connection from handshake to sign-out.
    ///
    /// The returned future, instrumented with a `handle connection` span, does the following:
    /// - returns immediately if the server is already tearing down;
    /// - registers the connection with the `Peer` and builds the per-connection `Session`,
    ///   including a fresh HTTP client and, when an admin key is configured, a Supermaven
    ///   admin client;
    /// - sends the initial client state via `send_initial_client_update`. If
    ///   `send_connection_id` is given, the new `ConnectionId` is sent on it there, right
    ///   after the hello message;
    /// - dispatches incoming messages to registered handlers until the I/O task finishes,
    ///   the incoming stream ends, or teardown is signalled;
    /// - finally calls `connection_lost`. This step is skipped when exiting because of
    ///   teardown, since that branch returns directly.
    // The argument list mirrors everything the session needs; hence the lint allowance.
    #[allow(clippy::too_many_arguments)]
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> {
        let this = self.clone();
        // Fields declared `Empty` are filled in once their values are known.
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            geoip_country_code=field::Empty
        );
        principal.update_span(&span);
        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        // Subscribe before the future runs so a teardown that happens later is observed.
        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
                    supermaven_admin_api_key.to_string(),
                    http_client.clone(),
                )));

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                http_client,
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // At most 256 handlers (foreground or background) may be in flight per connection:
            // a permit is taken before reading each message and dropped when its handler ends.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Priority order (biased): teardown, then I/O completion, then progress on
                // in-flight foreground handlers, then accepting a new message.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                            );
                            principal.update_span(&span);
                            // Enter the span only while the handler is invoked synchronously;
                            // the resulting future is instrumented with it below.
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Abandon any foreground handlers still in flight before signing out.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
849
850 async fn send_initial_client_update(
851 &self,
852 connection_id: ConnectionId,
853 principal: &Principal,
854 zed_version: ZedVersion,
855 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
856 session: &Session,
857 ) -> Result<()> {
858 self.peer.send(
859 connection_id,
860 proto::Hello {
861 peer_id: Some(connection_id.into()),
862 },
863 )?;
864 tracing::info!("sent hello message");
865 if let Some(send_connection_id) = send_connection_id.take() {
866 let _ = send_connection_id.send(connection_id);
867 }
868
869 match principal {
870 Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
871 if !user.connected_once {
872 self.peer.send(connection_id, proto::ShowContacts {})?;
873 self.app_state
874 .db
875 .set_user_connected_once(user.id, true)
876 .await?;
877 }
878
879 update_user_plan(user.id, session).await?;
880
881 let contacts = self.app_state.db.get_contacts(user.id).await?;
882
883 {
884 let mut pool = self.connection_pool.lock();
885 pool.add_connection(connection_id, user.id, user.admin, zed_version);
886 self.peer.send(
887 connection_id,
888 build_initial_contacts_update(contacts, &pool),
889 )?;
890 }
891
892 if should_auto_subscribe_to_channels(zed_version) {
893 subscribe_user_to_channels(user.id, session).await?;
894 }
895
896 if let Some(incoming_call) =
897 self.app_state.db.incoming_call_for_user(user.id).await?
898 {
899 self.peer.send(connection_id, incoming_call)?;
900 }
901
902 update_user_contacts(user.id, session).await?;
903 }
904 }
905
906 Ok(())
907 }
908
909 pub async fn invite_code_redeemed(
910 self: &Arc<Self>,
911 inviter_id: UserId,
912 invitee_id: UserId,
913 ) -> Result<()> {
914 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
915 if let Some(code) = &user.invite_code {
916 let pool = self.connection_pool.lock();
917 let invitee_contact = contact_for_user(invitee_id, false, &pool);
918 for connection_id in pool.user_connection_ids(inviter_id) {
919 self.peer.send(
920 connection_id,
921 proto::UpdateContacts {
922 contacts: vec![invitee_contact.clone()],
923 ..Default::default()
924 },
925 )?;
926 self.peer.send(
927 connection_id,
928 proto::UpdateInviteInfo {
929 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
930 count: user.invite_count as u32,
931 },
932 )?;
933 }
934 }
935 }
936 Ok(())
937 }
938
939 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
940 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
941 if let Some(invite_code) = &user.invite_code {
942 let pool = self.connection_pool.lock();
943 for connection_id in pool.user_connection_ids(user_id) {
944 self.peer.send(
945 connection_id,
946 proto::UpdateInviteInfo {
947 url: format!(
948 "{}{}",
949 self.app_state.config.invite_link_prefix, invite_code
950 ),
951 count: user.invite_count as u32,
952 },
953 )?;
954 }
955 }
956 }
957 Ok(())
958 }
959
960 pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
961 let pool = self.connection_pool.lock();
962 for connection_id in pool.user_connection_ids(user_id) {
963 self.peer
964 .send(connection_id, proto::RefreshLlmToken {})
965 .trace_err();
966 }
967 }
968
969 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
970 ServerSnapshot {
971 connection_pool: ConnectionPoolGuard {
972 guard: self.connection_pool.lock(),
973 _not_send: PhantomData,
974 },
975 peer: &self.peer,
976 }
977 }
978}
979
// `ConnectionPoolGuard` wraps the lock guard held in its `guard` field; these
// impls let callers use it as if it were the `ConnectionPool` itself.
impl Deref for ConnectionPoolGuard<'_> {
    type Target = ConnectionPool;

    fn deref(&self) -> &Self::Target {
        &self.guard
    }
}

impl DerefMut for ConnectionPoolGuard<'_> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.guard
    }
}

// In test builds only, the pool's invariants are re-checked every time a guard
// is released. NOTE(review): this presumably panics on a violation so that
// tests catch handlers that leave the pool inconsistent — confirm in
// `ConnectionPool::check_invariants`.
impl Drop for ConnectionPoolGuard<'_> {
    fn drop(&mut self) {
        #[cfg(test)]
        self.check_invariants();
    }
}
1000
1001fn broadcast<F>(
1002 sender_id: Option<ConnectionId>,
1003 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1004 mut f: F,
1005) where
1006 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1007{
1008 for receiver_id in receiver_ids {
1009 if Some(receiver_id) != sender_id {
1010 if let Err(error) = f(receiver_id) {
1011 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1012 }
1013 }
1014 }
1015}
1016
1017pub struct ProtocolVersion(u32);
1018
1019impl Header for ProtocolVersion {
1020 fn name() -> &'static HeaderName {
1021 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1022 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1023 }
1024
1025 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1026 where
1027 Self: Sized,
1028 I: Iterator<Item = &'i axum::http::HeaderValue>,
1029 {
1030 let version = values
1031 .next()
1032 .ok_or_else(axum::headers::Error::invalid)?
1033 .to_str()
1034 .map_err(|_| axum::headers::Error::invalid())?
1035 .parse()
1036 .map_err(|_| axum::headers::Error::invalid())?;
1037 Ok(Self(version))
1038 }
1039
1040 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1041 values.extend([self.0.to_string().parse().unwrap()]);
1042 }
1043}
1044
1045pub struct AppVersionHeader(SemanticVersion);
1046impl Header for AppVersionHeader {
1047 fn name() -> &'static HeaderName {
1048 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1049 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1050 }
1051
1052 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1053 where
1054 Self: Sized,
1055 I: Iterator<Item = &'i axum::http::HeaderValue>,
1056 {
1057 let version = values
1058 .next()
1059 .ok_or_else(axum::headers::Error::invalid)?
1060 .to_str()
1061 .map_err(|_| axum::headers::Error::invalid())?
1062 .parse()
1063 .map_err(|_| axum::headers::Error::invalid())?;
1064 Ok(Self(version))
1065 }
1066
1067 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1068 values.extend([self.0.to_string().parse().unwrap()]);
1069 }
1070}
1071
1072pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1073 Router::new()
1074 .route("/rpc", get(handle_websocket_request))
1075 .layer(
1076 ServiceBuilder::new()
1077 .layer(Extension(server.app_state.clone()))
1078 .layer(middleware::from_fn(auth::validate_header)),
1079 )
1080 .route("/metrics", get(handle_metrics))
1081 .layer(Extension(server))
1082}
1083
1084#[allow(clippy::too_many_arguments)]
1085pub async fn handle_websocket_request(
1086 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1087 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1088 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1089 Extension(server): Extension<Arc<Server>>,
1090 Extension(principal): Extension<Principal>,
1091 country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1092 system_id_header: Option<TypedHeader<SystemIdHeader>>,
1093 ws: WebSocketUpgrade,
1094) -> axum::response::Response {
1095 if protocol_version != rpc::PROTOCOL_VERSION {
1096 return (
1097 StatusCode::UPGRADE_REQUIRED,
1098 "client must be upgraded".to_string(),
1099 )
1100 .into_response();
1101 }
1102
1103 let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1104 return (
1105 StatusCode::UPGRADE_REQUIRED,
1106 "no version header found".to_string(),
1107 )
1108 .into_response();
1109 };
1110
1111 if !version.can_collaborate() {
1112 return (
1113 StatusCode::UPGRADE_REQUIRED,
1114 "client must be upgraded".to_string(),
1115 )
1116 .into_response();
1117 }
1118
1119 let socket_address = socket_address.to_string();
1120 ws.on_upgrade(move |socket| {
1121 let socket = socket
1122 .map_ok(to_tungstenite_message)
1123 .err_into()
1124 .with(|message| async move { to_axum_message(message) });
1125 let connection = Connection::new(Box::pin(socket));
1126 async move {
1127 server
1128 .handle_connection(
1129 connection,
1130 socket_address,
1131 principal,
1132 version,
1133 country_code_header.map(|header| header.to_string()),
1134 system_id_header.map(|header| header.to_string()),
1135 None,
1136 Executor::Production,
1137 )
1138 .await;
1139 }
1140 })
1141}
1142
1143pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1144 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1145 let connections_metric = CONNECTIONS_METRIC
1146 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1147
1148 let connections = server
1149 .connection_pool
1150 .lock()
1151 .connections()
1152 .filter(|connection| !connection.admin)
1153 .count();
1154 connections_metric.set(connections as _);
1155
1156 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1157 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1158 register_int_gauge!(
1159 "shared_projects",
1160 "number of open projects with one or more guests"
1161 )
1162 .unwrap()
1163 });
1164
1165 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1166 shared_projects_metric.set(shared_projects as _);
1167
1168 let encoder = prometheus::TextEncoder::new();
1169 let metric_families = prometheus::gather();
1170 let encoded_metrics = encoder
1171 .encode_to_string(&metric_families)
1172 .map_err(|err| anyhow!("{}", err))?;
1173 Ok(encoded_metrics)
1174}
1175
/// Handles a dropped client connection.
///
/// The connection is immediately:
/// - disconnected from the peer,
/// - removed from the connection pool, and
/// - reported lost to the database (best-effort).
///
/// The function then waits for whichever comes first:
/// - `RECONNECT_TIMEOUT` elapses: the session leaves its room and channel
///   buffers. If the user is no longer online on any connection, a pending
///   call to them is declined. Finally the user's contacts are updated.
/// - `teardown` changes: the function returns without doing that cleanup.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // Best-effort: a database failure here is passed to `trace_err` instead of
    // aborting the rest of the teardown.
    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls branches in the order written, so the timeout
    // branch wins if both are ready at the same time.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1222
1223/// Acknowledges a ping from a client, used to keep the connection alive.
1224async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1225 response.send(proto::Ack {})?;
1226 Ok(())
1227}
1228
1229/// Creates a new room for calling (outside of channels)
1230async fn create_room(
1231 _request: proto::CreateRoom,
1232 response: Response<proto::CreateRoom>,
1233 session: Session,
1234) -> Result<()> {
1235 let livekit_room = nanoid::nanoid!(30);
1236
1237 let live_kit_connection_info = util::maybe!(async {
1238 let live_kit = session.app_state.livekit_client.as_ref();
1239 let live_kit = live_kit?;
1240 let user_id = session.user_id().to_string();
1241
1242 let token = live_kit
1243 .room_token(&livekit_room, &user_id.to_string())
1244 .trace_err()?;
1245
1246 Some(proto::LiveKitConnectionInfo {
1247 server_url: live_kit.url().into(),
1248 token,
1249 can_publish: true,
1250 })
1251 })
1252 .await;
1253
1254 let room = session
1255 .db()
1256 .await
1257 .create_room(session.user_id(), session.connection_id, &livekit_room)
1258 .await?;
1259
1260 response.send(proto::CreateRoomResponse {
1261 room: Some(room.clone()),
1262 live_kit_connection_info,
1263 })?;
1264
1265 update_user_contacts(session.user_id(), &session).await?;
1266 Ok(())
1267}
1268
/// Join a room from an invitation. Equivalent to joining a channel if there is one.
///
/// If the room belongs to a channel, the request is delegated to
/// `join_channel_internal`. Otherwise:
/// 1. The user is added to the room in the database and `room_updated` is
///    invoked with the new room state.
/// 2. Every connection of the user is sent `CallCanceled` for this room
///    (presumably so the user's other devices stop ringing).
/// 3. The response carries the room, plus LiveKit connection info when a
///    LiveKit client is configured.
/// 4. The user's contacts are updated.
async fn join_room(
    request: proto::JoinRoom,
    response: Response<proto::JoinRoom>,
    session: Session,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.id);

    let channel_id = session.db().await.channel_id_for_room(room_id).await?;

    // Rooms that belong to a channel are joined through the channel path.
    if let Some(channel_id) = channel_id {
        return join_channel_internal(channel_id, Box::new(response), session).await;
    }

    // Scoped so that `room` (which appears to be a guard, see `into_inner`) is
    // dropped before the connection pool is locked below.
    let joined_room = {
        let room = session
            .db()
            .await
            .join_room(room_id, session.user_id(), session.connection_id)
            .await?;
        room_updated(&room.room, &session.peer);
        room.into_inner()
    };

    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(session.user_id())
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }

    // `None` if LiveKit is not configured or token creation fails (the error is
    // passed to `trace_err`).
    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
    {
        live_kit
            .room_token(
                &joined_room.room.livekit_room,
                &session.user_id().to_string(),
            )
            .trace_err()
            .map(|token| proto::LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish: true,
            })
    } else {
        None
    };

    response.send(proto::JoinRoomResponse {
        room: Some(joined_room.room),
        channel_id: None,
        live_kit_connection_info,
    })?;

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1335
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// The database rejoins the user to the room and returns two lists:
/// - "reshared" projects, presumably ones the user hosts;
/// - "rejoined" projects, presumably ones the user was a guest in.
///
/// What is sent:
/// - The client gets a `RejoinRoomResponse` describing both lists, and
///   `room_updated` is invoked.
/// - For each reshared project, its collaborators are told the user's peer id
///   changed and are forwarded the project's worktrees.
/// - Rejoined projects are handled by `notify_rejoined_projects`.
///
/// Finally, the room's channel (if any) is refreshed and the user's contacts
/// are updated.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    let room;
    let channel;
    // Scoped so that `rejoined_room` (a guard, see `into_inner` below) is
    // dropped before the connection pool is locked for `channel_updated`.
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell each collaborator that this user's peer id changed from the
            // old connection to the current one.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Re-send the project's worktree list to everyone except this user.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1427
1428fn notify_rejoined_projects(
1429 rejoined_projects: &mut Vec<RejoinedProject>,
1430 session: &Session,
1431) -> Result<()> {
1432 for project in rejoined_projects.iter() {
1433 for collaborator in &project.collaborators {
1434 session
1435 .peer
1436 .send(
1437 collaborator.connection_id,
1438 proto::UpdateProjectCollaborator {
1439 project_id: project.id.to_proto(),
1440 old_peer_id: Some(project.old_connection_id.into()),
1441 new_peer_id: Some(session.connection_id.into()),
1442 },
1443 )
1444 .trace_err();
1445 }
1446 }
1447
1448 for project in rejoined_projects {
1449 for worktree in mem::take(&mut project.worktrees) {
1450 // Stream this worktree's entries.
1451 let message = proto::UpdateWorktree {
1452 project_id: project.id.to_proto(),
1453 worktree_id: worktree.id,
1454 abs_path: worktree.abs_path.clone(),
1455 root_name: worktree.root_name,
1456 updated_entries: worktree.updated_entries,
1457 removed_entries: worktree.removed_entries,
1458 scan_id: worktree.scan_id,
1459 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1460 updated_repositories: worktree.updated_repositories,
1461 removed_repositories: worktree.removed_repositories,
1462 };
1463 for update in proto::split_worktree_update(message) {
1464 session.peer.send(session.connection_id, update.clone())?;
1465 }
1466
1467 // Stream this worktree's diagnostics.
1468 for summary in worktree.diagnostic_summaries {
1469 session.peer.send(
1470 session.connection_id,
1471 proto::UpdateDiagnosticSummary {
1472 project_id: project.id.to_proto(),
1473 worktree_id: worktree.id,
1474 summary: Some(summary),
1475 },
1476 )?;
1477 }
1478
1479 for settings_file in worktree.settings_files {
1480 session.peer.send(
1481 session.connection_id,
1482 proto::UpdateWorktreeSettings {
1483 project_id: project.id.to_proto(),
1484 worktree_id: worktree.id,
1485 path: settings_file.path,
1486 content: Some(settings_file.content),
1487 kind: Some(settings_file.kind.to_proto().into()),
1488 },
1489 )?;
1490 }
1491 }
1492
1493 for language_server in &project.language_servers {
1494 session.peer.send(
1495 session.connection_id,
1496 proto::UpdateLanguageServer {
1497 project_id: project.id.to_proto(),
1498 language_server_id: language_server.id,
1499 variant: Some(
1500 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1501 proto::LspDiskBasedDiagnosticsUpdated {},
1502 ),
1503 ),
1504 },
1505 )?;
1506 }
1507 }
1508 Ok(())
1509}
1510
1511/// leave room disconnects from the room.
1512async fn leave_room(
1513 _: proto::LeaveRoom,
1514 response: Response<proto::LeaveRoom>,
1515 session: Session,
1516) -> Result<()> {
1517 leave_room_for_session(&session, session.connection_id).await?;
1518 response.send(proto::Ack {})?;
1519 Ok(())
1520}
1521
/// Updates the permissions of someone else in the room.
///
/// The database applies the new role for `request.user_id`, with the
/// requesting user as the actor, and `room_updated` is invoked with the
/// result. If a LiveKit client is configured, the participant's LiveKit
/// permissions are updated: publishing (media and data) is allowed only if
/// the new role can use the microphone. A failure of the LiveKit call is only
/// passed to `trace_err`; the request is still acknowledged.
async fn set_room_participant_role(
    request: proto::SetRoomParticipantRole,
    response: Response<proto::SetRoomParticipantRole>,
    session: Session,
) -> Result<()> {
    let user_id = UserId::from_proto(request.user_id);
    let role = ChannelRole::from(request.role());

    // Scoped so the database result (`room`) is dropped before the LiveKit
    // call is awaited.
    let (livekit_room, can_publish) = {
        let room = session
            .db()
            .await
            .set_room_participant_role(
                session.user_id(),
                RoomId::from_proto(request.room_id),
                user_id,
                role,
            )
            .await?;

        let livekit_room = room.livekit_room.clone();
        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
        room_updated(&room, &session.peer);
        (livekit_room, can_publish)
    };

    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .update_participant(
                livekit_room.clone(),
                request.user_id.to_string(),
                livekit_api::proto::ParticipantPermission {
                    can_subscribe: true,
                    can_publish,
                    can_publish_data: can_publish,
                    hidden: false,
                    recorder: false,
                },
            )
            .await
            .trace_err();
    }

    response.send(proto::Ack {})?;
    Ok(())
}
1569
/// Call someone else into the current room
///
/// Only contacts of the caller may be called. Steps:
/// 1. The call is recorded in the database and the called user's contacts
///    are updated.
/// 2. Every connection of the called user is sent the incoming call
///    concurrently. The first connection that responds successfully completes
///    this request with `Ack`.
/// 3. If no connection responds successfully, the call is marked as failed
///    in the database, contacts are updated again, and an error is returned.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: Session,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    // The incoming-call message is moved out of the database result with
    // `mem::take` so the result can be dropped at the end of this block.
    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(room, &session.peer);
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring all of the callee's connections at once.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    // The first successful response wins; failed responses are passed to
    // `trace_err` and waiting continues.
    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    // No connection accepted the call: record the failure.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1638
/// Cancel an outgoing call.
///
/// The database cancels the call from the requesting connection to
/// `called_user_id`, and `room_updated` is invoked with the result. Every
/// connection of the called user is then sent `CallCanceled`; send failures
/// are only passed to `trace_err`. The request is acknowledged before the
/// called user's contacts are updated.
async fn cancel_call(
    request: proto::CancelCall,
    response: Response<proto::CancelCall>,
    session: Session,
) -> Result<()> {
    let called_user_id = UserId::from_proto(request.called_user_id);
    let room_id = RoomId::from_proto(request.room_id);
    // Scoped so the database result is dropped before the connection pool is
    // locked below.
    {
        let room = session
            .db()
            .await
            .cancel_call(room_id, session.connection_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }
    response.send(proto::Ack {})?;

    update_user_contacts(called_user_id, &session).await?;
    Ok(())
}
1676
/// Decline an incoming call.
///
/// The database declines the user's call for `room_id`. If it returns no
/// room, the handler fails with "failed to decline call". On success,
/// `room_updated` is invoked and every connection of the declining user is
/// sent `CallCanceled` (presumably so the user's other devices stop ringing).
/// The user's contacts are then updated.
async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
    let room_id = RoomId::from_proto(message.room_id);
    // Scoped so the database result is dropped before the connection pool is
    // locked below.
    {
        let room = session
            .db()
            .await
            .decline_call(Some(room_id), session.user_id())
            .await?
            .ok_or_else(|| anyhow!("failed to decline call"))?;
        room_updated(&room, &session.peer);
    }

    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(session.user_id())
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }
    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1708
/// Updates other participants in the room with your current location.
///
/// # Errors
/// - Fails with "invalid location" when `request.location` is absent.
/// - Fails when the database update fails.
async fn update_participant_location(
    request: proto::UpdateParticipantLocation,
    response: Response<proto::UpdateParticipantLocation>,
    session: Session,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let location = request
        .location
        .ok_or_else(|| anyhow!("invalid location"))?;

    // `db` stays bound until the end of the function, so whatever
    // `session.db()` holds is kept while the room update and the Ack are sent.
    let db = session.db().await;
    let room = db
        .update_room_participant_location(room_id, session.connection_id, location)
        .await?;

    room_updated(&room, &session.peer);
    response.send(proto::Ack {})?;
    Ok(())
}
1729
/// Share a project into the room.
///
/// The database registers the project (with its worktrees and SSH flag) for
/// the requesting connection and returns the new project id together with
/// the updated room. The id is sent back to the requester before
/// `room_updated` is invoked.
async fn share_project(
    request: proto::ShareProject,
    response: Response<proto::ShareProject>,
    session: Session,
) -> Result<()> {
    // Borrowing through `&*` keeps the database result alive until the end of
    // the function while `project_id` and `room` are in use.
    let (project_id, room) = &*session
        .db()
        .await
        .share_project(
            RoomId::from_proto(request.room_id),
            session.connection_id,
            &request.worktrees,
            request.is_ssh_project,
        )
        .await?;
    response.send(proto::ShareProjectResponse {
        project_id: project_id.to_proto(),
    })?;
    room_updated(room, &session.peer);

    Ok(())
}
1753
1754/// Unshare a project from the room.
1755async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1756 let project_id = ProjectId::from_proto(message.project_id);
1757 unshare_project_internal(project_id, session.connection_id, &session).await
1758}
1759
/// Unshares `project_id` on behalf of `connection_id`.
///
/// Every guest connection reported by the database is sent `UnshareProject`
/// (excluding `connection_id` itself), and `room_updated` is invoked if the
/// project belonged to a room. If the database indicates the project should be
/// deleted, it is deleted afterwards in a separate database access.
async fn unshare_project_internal(
    project_id: ProjectId,
    connection_id: ConnectionId,
    session: &Session,
) -> Result<()> {
    // Scoped so `room_guard` is dropped before the second `session.db()`
    // access below.
    let delete = {
        let room_guard = session
            .db()
            .await
            .unshare_project(project_id, connection_id)
            .await?;

        let (delete, room, guest_connection_ids) = &*room_guard;

        let message = proto::UnshareProject {
            project_id: project_id.to_proto(),
        };

        broadcast(
            Some(connection_id),
            guest_connection_ids.iter().copied(),
            |conn_id| session.peer.send(conn_id, message.clone()),
        );
        if let Some(room) = room {
            room_updated(room, &session.peer);
        }

        *delete
    };

    if delete {
        let db = session.db().await;
        db.delete_project(project_id).await?;
    }

    Ok(())
}
1797
/// Join someone else's shared project.
///
/// The database adds the requesting connection to the project and assigns it
/// a replica id. The rest of the join (notifying collaborators, streaming
/// project state) is done by `join_project_internal`.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: Session,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    // Release the `session.db()` handle before streaming the project. The
    // value returned by `join_project` is still borrowed by `project` and
    // `replica_id`.
    drop(db);
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
1816
/// Sink for the `JoinProjectResponse` produced by `join_project_internal`.
///
/// This lets `join_project_internal` deliver its response through any type
/// implementing this trait, rather than requiring `Response<proto::JoinProject>`
/// directly.
trait JoinProjectInternalResponse {
    /// Delivers `result` to the joining client, consuming the responder.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
/// Delivers the response through the regular RPC responder for `JoinProject`.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinProject>::send(self, result)
    }
}
1825
1826fn join_project_internal(
1827 response: impl JoinProjectInternalResponse,
1828 session: Session,
1829 project: &mut Project,
1830 replica_id: &ReplicaId,
1831) -> Result<()> {
1832 let collaborators = project
1833 .collaborators
1834 .iter()
1835 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1836 .map(|collaborator| collaborator.to_proto())
1837 .collect::<Vec<_>>();
1838 let project_id = project.id;
1839 let guest_user_id = session.user_id();
1840
1841 let worktrees = project
1842 .worktrees
1843 .iter()
1844 .map(|(id, worktree)| proto::WorktreeMetadata {
1845 id: *id,
1846 root_name: worktree.root_name.clone(),
1847 visible: worktree.visible,
1848 abs_path: worktree.abs_path.clone(),
1849 })
1850 .collect::<Vec<_>>();
1851
1852 let add_project_collaborator = proto::AddProjectCollaborator {
1853 project_id: project_id.to_proto(),
1854 collaborator: Some(proto::Collaborator {
1855 peer_id: Some(session.connection_id.into()),
1856 replica_id: replica_id.0 as u32,
1857 user_id: guest_user_id.to_proto(),
1858 is_host: false,
1859 }),
1860 };
1861
1862 for collaborator in &collaborators {
1863 session
1864 .peer
1865 .send(
1866 collaborator.peer_id.unwrap().into(),
1867 add_project_collaborator.clone(),
1868 )
1869 .trace_err();
1870 }
1871
1872 // First, we send the metadata associated with each worktree.
1873 response.send(proto::JoinProjectResponse {
1874 project_id: project.id.0 as u64,
1875 worktrees: worktrees.clone(),
1876 replica_id: replica_id.0 as u32,
1877 collaborators: collaborators.clone(),
1878 language_servers: project.language_servers.clone(),
1879 role: project.role.into(),
1880 })?;
1881
1882 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1883 // Stream this worktree's entries.
1884 let message = proto::UpdateWorktree {
1885 project_id: project_id.to_proto(),
1886 worktree_id,
1887 abs_path: worktree.abs_path.clone(),
1888 root_name: worktree.root_name,
1889 updated_entries: worktree.entries,
1890 removed_entries: Default::default(),
1891 scan_id: worktree.scan_id,
1892 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1893 updated_repositories: worktree.repository_entries.into_values().collect(),
1894 removed_repositories: Default::default(),
1895 };
1896 for update in proto::split_worktree_update(message) {
1897 session.peer.send(session.connection_id, update.clone())?;
1898 }
1899
1900 // Stream this worktree's diagnostics.
1901 for summary in worktree.diagnostic_summaries {
1902 session.peer.send(
1903 session.connection_id,
1904 proto::UpdateDiagnosticSummary {
1905 project_id: project_id.to_proto(),
1906 worktree_id: worktree.id,
1907 summary: Some(summary),
1908 },
1909 )?;
1910 }
1911
1912 for settings_file in worktree.settings_files {
1913 session.peer.send(
1914 session.connection_id,
1915 proto::UpdateWorktreeSettings {
1916 project_id: project_id.to_proto(),
1917 worktree_id: worktree.id,
1918 path: settings_file.path,
1919 content: Some(settings_file.content),
1920 kind: Some(settings_file.kind.to_proto() as i32),
1921 },
1922 )?;
1923 }
1924 }
1925
1926 for language_server in &project.language_servers {
1927 session.peer.send(
1928 session.connection_id,
1929 proto::UpdateLanguageServer {
1930 project_id: project_id.to_proto(),
1931 language_server_id: language_server.id,
1932 variant: Some(
1933 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1934 proto::LspDiskBasedDiagnosticsUpdated {},
1935 ),
1936 ),
1937 },
1938 )?;
1939 }
1940
1941 Ok(())
1942}
1943
/// Leave someone else's shared project.
///
/// The database removes the sending connection from the project. The rest of
/// the cleanup is done by `project_left`, and `room_updated` is invoked when
/// the project belonged to a room.
async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
    let sender_id = session.connection_id;
    let project_id = ProjectId::from_proto(request.project_id);
    // `db` stays bound for the rest of the function; the `&*` borrow below
    // keeps the returned value alive while `room` and `project` are used.
    let db = session.db().await;

    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
    tracing::info!(
        %project_id,
        "leave project"
    );

    project_left(project, &session);
    if let Some(room) = room {
        room_updated(room, &session.peer);
    }

    Ok(())
}
1963
/// Updates other participants with changes to the project
///
/// The database stores the new worktree list for the sending connection and
/// returns the guests to notify. The request is forwarded unchanged to each
/// guest except the sender; forwarding failures are only logged by
/// `broadcast`. `room_updated` is invoked when the project belongs to a room,
/// and the request is then acknowledged.
async fn update_project(
    request: proto::UpdateProject,
    response: Response<proto::UpdateProject>,
    session: Session,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    let (room, guest_connection_ids) = &*session
        .db()
        .await
        .update_project(project_id, session.connection_id, &request.worktrees)
        .await?;
    broadcast(
        Some(session.connection_id),
        guest_connection_ids.iter().copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())
        },
    );
    if let Some(room) = room {
        room_updated(room, &session.peer);
    }
    response.send(proto::Ack {})?;

    Ok(())
}
1992
1993/// Updates other participants with changes to the worktree
1994async fn update_worktree(
1995 request: proto::UpdateWorktree,
1996 response: Response<proto::UpdateWorktree>,
1997 session: Session,
1998) -> Result<()> {
1999 let guest_connection_ids = session
2000 .db()
2001 .await
2002 .update_worktree(&request, session.connection_id)
2003 .await?;
2004
2005 broadcast(
2006 Some(session.connection_id),
2007 guest_connection_ids.iter().copied(),
2008 |connection_id| {
2009 session
2010 .peer
2011 .forward_send(session.connection_id, connection_id, request.clone())
2012 },
2013 );
2014 response.send(proto::Ack {})?;
2015 Ok(())
2016}
2017
2018/// Updates other participants with changes to the diagnostics
2019async fn update_diagnostic_summary(
2020 message: proto::UpdateDiagnosticSummary,
2021 session: Session,
2022) -> Result<()> {
2023 let guest_connection_ids = session
2024 .db()
2025 .await
2026 .update_diagnostic_summary(&message, session.connection_id)
2027 .await?;
2028
2029 broadcast(
2030 Some(session.connection_id),
2031 guest_connection_ids.iter().copied(),
2032 |connection_id| {
2033 session
2034 .peer
2035 .forward_send(session.connection_id, connection_id, message.clone())
2036 },
2037 );
2038
2039 Ok(())
2040}
2041
2042/// Updates other participants with changes to the worktree settings
2043async fn update_worktree_settings(
2044 message: proto::UpdateWorktreeSettings,
2045 session: Session,
2046) -> Result<()> {
2047 let guest_connection_ids = session
2048 .db()
2049 .await
2050 .update_worktree_settings(&message, session.connection_id)
2051 .await?;
2052
2053 broadcast(
2054 Some(session.connection_id),
2055 guest_connection_ids.iter().copied(),
2056 |connection_id| {
2057 session
2058 .peer
2059 .forward_send(session.connection_id, connection_id, message.clone())
2060 },
2061 );
2062
2063 Ok(())
2064}
2065
2066/// Notify other participants that a language server has started.
2067async fn start_language_server(
2068 request: proto::StartLanguageServer,
2069 session: Session,
2070) -> Result<()> {
2071 let guest_connection_ids = session
2072 .db()
2073 .await
2074 .start_language_server(&request, session.connection_id)
2075 .await?;
2076
2077 broadcast(
2078 Some(session.connection_id),
2079 guest_connection_ids.iter().copied(),
2080 |connection_id| {
2081 session
2082 .peer
2083 .forward_send(session.connection_id, connection_id, request.clone())
2084 },
2085 );
2086 Ok(())
2087}
2088
2089/// Notify other participants that a language server has changed.
2090async fn update_language_server(
2091 request: proto::UpdateLanguageServer,
2092 session: Session,
2093) -> Result<()> {
2094 let project_id = ProjectId::from_proto(request.project_id);
2095 let project_connection_ids = session
2096 .db()
2097 .await
2098 .project_connection_ids(project_id, session.connection_id, true)
2099 .await?;
2100 broadcast(
2101 Some(session.connection_id),
2102 project_connection_ids.iter().copied(),
2103 |connection_id| {
2104 session
2105 .peer
2106 .forward_send(session.connection_id, connection_id, request.clone())
2107 },
2108 );
2109 Ok(())
2110}
2111
2112/// forward a project request to the host. These requests should be read only
2113/// as guests are allowed to send them.
2114async fn forward_read_only_project_request<T>(
2115 request: T,
2116 response: Response<T>,
2117 session: Session,
2118) -> Result<()>
2119where
2120 T: EntityMessage + RequestMessage,
2121{
2122 let project_id = ProjectId::from_proto(request.remote_entity_id());
2123 let host_connection_id = session
2124 .db()
2125 .await
2126 .host_for_read_only_project_request(project_id, session.connection_id)
2127 .await?;
2128 let payload = session
2129 .peer
2130 .forward_request(session.connection_id, host_connection_id, request)
2131 .await?;
2132 response.send(payload)?;
2133 Ok(())
2134}
2135
2136async fn forward_find_search_candidates_request(
2137 request: proto::FindSearchCandidates,
2138 response: Response<proto::FindSearchCandidates>,
2139 session: Session,
2140) -> Result<()> {
2141 let project_id = ProjectId::from_proto(request.remote_entity_id());
2142 let host_connection_id = session
2143 .db()
2144 .await
2145 .host_for_read_only_project_request(project_id, session.connection_id)
2146 .await?;
2147 let payload = session
2148 .peer
2149 .forward_request(session.connection_id, host_connection_id, request)
2150 .await?;
2151 response.send(payload)?;
2152 Ok(())
2153}
2154
2155/// forward a project request to the host. These requests are disallowed
2156/// for guests.
2157async fn forward_mutating_project_request<T>(
2158 request: T,
2159 response: Response<T>,
2160 session: Session,
2161) -> Result<()>
2162where
2163 T: EntityMessage + RequestMessage,
2164{
2165 let project_id = ProjectId::from_proto(request.remote_entity_id());
2166
2167 let host_connection_id = session
2168 .db()
2169 .await
2170 .host_for_mutating_project_request(project_id, session.connection_id)
2171 .await?;
2172 let payload = session
2173 .peer
2174 .forward_request(session.connection_id, host_connection_id, request)
2175 .await?;
2176 response.send(payload)?;
2177 Ok(())
2178}
2179
2180/// Notify other participants that a new buffer has been created
2181async fn create_buffer_for_peer(
2182 request: proto::CreateBufferForPeer,
2183 session: Session,
2184) -> Result<()> {
2185 session
2186 .db()
2187 .await
2188 .check_user_is_project_host(
2189 ProjectId::from_proto(request.project_id),
2190 session.connection_id,
2191 )
2192 .await?;
2193 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2194 session
2195 .peer
2196 .forward_send(session.connection_id, peer_id.into(), request)?;
2197 Ok(())
2198}
2199
/// Notify other participants that a buffer has been updated. This is
/// allowed for guests as long as the update is limited to selections.
///
/// The update is broadcast to the guest connections returned by
/// `connections_for_buffer_update` and, unless the sender is the host itself,
/// forwarded to the host as a request. The sender is acknowledged only after
/// that host request has completed.
async fn update_buffer(
    request: proto::UpdateBuffer,
    response: Response<proto::UpdateBuffer>,
    session: Session,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    // Start from the weakest requirement; any operation other than a selection
    // update (or an operation with no variant) requires write access.
    let mut capability = Capability::ReadOnly;

    for op in request.operations.iter() {
        match op.variant {
            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
            Some(_) => capability = Capability::ReadWrite,
        }
    }

    // The inner block bounds the lifetime of `guard`, so it is dropped before
    // the `.await` on the host request below.
    let host = {
        let guard = session
            .db()
            .await
            .connections_for_buffer_update(project_id, session.connection_id, capability)
            .await?;

        let (host, guests) = &*guard;

        broadcast(
            Some(session.connection_id),
            guests.clone(),
            |connection_id| {
                session
                    .peer
                    .forward_send(session.connection_id, connection_id, request.clone())
            },
        );

        *host
    };

    // The host receives the update as a request (not a one-way send), and its
    // reply is awaited before the sender is acknowledged.
    if host != session.connection_id {
        session
            .peer
            .forward_request(session.connection_id, host, request.clone())
            .await?;
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2249
2250async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2251 let project_id = ProjectId::from_proto(message.project_id);
2252
2253 let operation = message.operation.as_ref().context("invalid operation")?;
2254 let capability = match operation.variant.as_ref() {
2255 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2256 if let Some(buffer_op) = buffer_op.operation.as_ref() {
2257 match buffer_op.variant {
2258 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2259 Capability::ReadOnly
2260 }
2261 _ => Capability::ReadWrite,
2262 }
2263 } else {
2264 Capability::ReadWrite
2265 }
2266 }
2267 Some(_) => Capability::ReadWrite,
2268 None => Capability::ReadOnly,
2269 };
2270
2271 let guard = session
2272 .db()
2273 .await
2274 .connections_for_buffer_update(project_id, session.connection_id, capability)
2275 .await?;
2276
2277 let (host, guests) = &*guard;
2278
2279 broadcast(
2280 Some(session.connection_id),
2281 guests.iter().chain([host]).copied(),
2282 |connection_id| {
2283 session
2284 .peer
2285 .forward_send(session.connection_id, connection_id, message.clone())
2286 },
2287 );
2288
2289 Ok(())
2290}
2291
2292/// Notify other participants that a project has been updated.
2293async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2294 request: T,
2295 session: Session,
2296) -> Result<()> {
2297 let project_id = ProjectId::from_proto(request.remote_entity_id());
2298 let project_connection_ids = session
2299 .db()
2300 .await
2301 .project_connection_ids(project_id, session.connection_id, false)
2302 .await?;
2303
2304 broadcast(
2305 Some(session.connection_id),
2306 project_connection_ids.iter().copied(),
2307 |connection_id| {
2308 session
2309 .peer
2310 .forward_send(session.connection_id, connection_id, request.clone())
2311 },
2312 );
2313 Ok(())
2314}
2315
2316/// Start following another user in a call.
2317async fn follow(
2318 request: proto::Follow,
2319 response: Response<proto::Follow>,
2320 session: Session,
2321) -> Result<()> {
2322 let room_id = RoomId::from_proto(request.room_id);
2323 let project_id = request.project_id.map(ProjectId::from_proto);
2324 let leader_id = request
2325 .leader_id
2326 .ok_or_else(|| anyhow!("invalid leader id"))?
2327 .into();
2328 let follower_id = session.connection_id;
2329
2330 session
2331 .db()
2332 .await
2333 .check_room_participants(room_id, leader_id, session.connection_id)
2334 .await?;
2335
2336 let response_payload = session
2337 .peer
2338 .forward_request(session.connection_id, leader_id, request)
2339 .await?;
2340 response.send(response_payload)?;
2341
2342 if let Some(project_id) = project_id {
2343 let room = session
2344 .db()
2345 .await
2346 .follow(room_id, project_id, leader_id, follower_id)
2347 .await?;
2348 room_updated(&room, &session.peer);
2349 }
2350
2351 Ok(())
2352}
2353
2354/// Stop following another user in a call.
2355async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2356 let room_id = RoomId::from_proto(request.room_id);
2357 let project_id = request.project_id.map(ProjectId::from_proto);
2358 let leader_id = request
2359 .leader_id
2360 .ok_or_else(|| anyhow!("invalid leader id"))?
2361 .into();
2362 let follower_id = session.connection_id;
2363
2364 session
2365 .db()
2366 .await
2367 .check_room_participants(room_id, leader_id, session.connection_id)
2368 .await?;
2369
2370 session
2371 .peer
2372 .forward_send(session.connection_id, leader_id, request)?;
2373
2374 if let Some(project_id) = project_id {
2375 let room = session
2376 .db()
2377 .await
2378 .unfollow(room_id, project_id, leader_id, follower_id)
2379 .await?;
2380 room_updated(&room, &session.peer);
2381 }
2382
2383 Ok(())
2384}
2385
2386/// Notify everyone following you of your current location.
2387async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2388 let room_id = RoomId::from_proto(request.room_id);
2389 let database = session.db.lock().await;
2390
2391 let connection_ids = if let Some(project_id) = request.project_id {
2392 let project_id = ProjectId::from_proto(project_id);
2393 database
2394 .project_connection_ids(project_id, session.connection_id, true)
2395 .await?
2396 } else {
2397 database
2398 .room_connection_ids(room_id, session.connection_id)
2399 .await?
2400 };
2401
2402 // For now, don't send view update messages back to that view's current leader.
2403 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2404 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2405 _ => None,
2406 });
2407
2408 for connection_id in connection_ids.iter().cloned() {
2409 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2410 session
2411 .peer
2412 .forward_send(session.connection_id, connection_id, request.clone())?;
2413 }
2414 }
2415 Ok(())
2416}
2417
2418/// Get public data about users.
2419async fn get_users(
2420 request: proto::GetUsers,
2421 response: Response<proto::GetUsers>,
2422 session: Session,
2423) -> Result<()> {
2424 let user_ids = request
2425 .user_ids
2426 .into_iter()
2427 .map(UserId::from_proto)
2428 .collect();
2429 let users = session
2430 .db()
2431 .await
2432 .get_users_by_ids(user_ids)
2433 .await?
2434 .into_iter()
2435 .map(|user| proto::User {
2436 id: user.id.to_proto(),
2437 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2438 github_login: user.github_login,
2439 email: user.email_address,
2440 name: user.name,
2441 })
2442 .collect();
2443 response.send(proto::UsersResponse { users })?;
2444 Ok(())
2445}
2446
2447/// Search for users (to invite) buy Github login
2448async fn fuzzy_search_users(
2449 request: proto::FuzzySearchUsers,
2450 response: Response<proto::FuzzySearchUsers>,
2451 session: Session,
2452) -> Result<()> {
2453 let query = request.query;
2454 let users = match query.len() {
2455 0 => vec![],
2456 1 | 2 => session
2457 .db()
2458 .await
2459 .get_user_by_github_login(&query)
2460 .await?
2461 .into_iter()
2462 .collect(),
2463 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2464 };
2465 let users = users
2466 .into_iter()
2467 .filter(|user| user.id != session.user_id())
2468 .map(|user| proto::User {
2469 id: user.id.to_proto(),
2470 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2471 github_login: user.github_login,
2472 name: user.name,
2473 email: user.email_address,
2474 })
2475 .collect();
2476 response.send(proto::UsersResponse { users })?;
2477 Ok(())
2478}
2479
2480/// Send a contact request to another user.
2481async fn request_contact(
2482 request: proto::RequestContact,
2483 response: Response<proto::RequestContact>,
2484 session: Session,
2485) -> Result<()> {
2486 let requester_id = session.user_id();
2487 let responder_id = UserId::from_proto(request.responder_id);
2488 if requester_id == responder_id {
2489 return Err(anyhow!("cannot add yourself as a contact"))?;
2490 }
2491
2492 let notifications = session
2493 .db()
2494 .await
2495 .send_contact_request(requester_id, responder_id)
2496 .await?;
2497
2498 // Update outgoing contact requests of requester
2499 let mut update = proto::UpdateContacts::default();
2500 update.outgoing_requests.push(responder_id.to_proto());
2501 for connection_id in session
2502 .connection_pool()
2503 .await
2504 .user_connection_ids(requester_id)
2505 {
2506 session.peer.send(connection_id, update.clone())?;
2507 }
2508
2509 // Update incoming contact requests of responder
2510 let mut update = proto::UpdateContacts::default();
2511 update
2512 .incoming_requests
2513 .push(proto::IncomingContactRequest {
2514 requester_id: requester_id.to_proto(),
2515 });
2516 let connection_pool = session.connection_pool().await;
2517 for connection_id in connection_pool.user_connection_ids(responder_id) {
2518 session.peer.send(connection_id, update.clone())?;
2519 }
2520
2521 send_notifications(&connection_pool, &session.peer, notifications);
2522
2523 response.send(proto::Ack {})?;
2524 Ok(())
2525}
2526
2527/// Accept or decline a contact request
2528async fn respond_to_contact_request(
2529 request: proto::RespondToContactRequest,
2530 response: Response<proto::RespondToContactRequest>,
2531 session: Session,
2532) -> Result<()> {
2533 let responder_id = session.user_id();
2534 let requester_id = UserId::from_proto(request.requester_id);
2535 let db = session.db().await;
2536 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2537 db.dismiss_contact_notification(responder_id, requester_id)
2538 .await?;
2539 } else {
2540 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2541
2542 let notifications = db
2543 .respond_to_contact_request(responder_id, requester_id, accept)
2544 .await?;
2545 let requester_busy = db.is_user_busy(requester_id).await?;
2546 let responder_busy = db.is_user_busy(responder_id).await?;
2547
2548 let pool = session.connection_pool().await;
2549 // Update responder with new contact
2550 let mut update = proto::UpdateContacts::default();
2551 if accept {
2552 update
2553 .contacts
2554 .push(contact_for_user(requester_id, requester_busy, &pool));
2555 }
2556 update
2557 .remove_incoming_requests
2558 .push(requester_id.to_proto());
2559 for connection_id in pool.user_connection_ids(responder_id) {
2560 session.peer.send(connection_id, update.clone())?;
2561 }
2562
2563 // Update requester with new contact
2564 let mut update = proto::UpdateContacts::default();
2565 if accept {
2566 update
2567 .contacts
2568 .push(contact_for_user(responder_id, responder_busy, &pool));
2569 }
2570 update
2571 .remove_outgoing_requests
2572 .push(responder_id.to_proto());
2573
2574 for connection_id in pool.user_connection_ids(requester_id) {
2575 session.peer.send(connection_id, update.clone())?;
2576 }
2577
2578 send_notifications(&pool, &session.peer, notifications);
2579 }
2580
2581 response.send(proto::Ack {})?;
2582 Ok(())
2583}
2584
2585/// Remove a contact.
2586async fn remove_contact(
2587 request: proto::RemoveContact,
2588 response: Response<proto::RemoveContact>,
2589 session: Session,
2590) -> Result<()> {
2591 let requester_id = session.user_id();
2592 let responder_id = UserId::from_proto(request.user_id);
2593 let db = session.db().await;
2594 let (contact_accepted, deleted_notification_id) =
2595 db.remove_contact(requester_id, responder_id).await?;
2596
2597 let pool = session.connection_pool().await;
2598 // Update outgoing contact requests of requester
2599 let mut update = proto::UpdateContacts::default();
2600 if contact_accepted {
2601 update.remove_contacts.push(responder_id.to_proto());
2602 } else {
2603 update
2604 .remove_outgoing_requests
2605 .push(responder_id.to_proto());
2606 }
2607 for connection_id in pool.user_connection_ids(requester_id) {
2608 session.peer.send(connection_id, update.clone())?;
2609 }
2610
2611 // Update incoming contact requests of responder
2612 let mut update = proto::UpdateContacts::default();
2613 if contact_accepted {
2614 update.remove_contacts.push(requester_id.to_proto());
2615 } else {
2616 update
2617 .remove_incoming_requests
2618 .push(requester_id.to_proto());
2619 }
2620 for connection_id in pool.user_connection_ids(responder_id) {
2621 session.peer.send(connection_id, update.clone())?;
2622 if let Some(notification_id) = deleted_notification_id {
2623 session.peer.send(
2624 connection_id,
2625 proto::DeleteNotification {
2626 notification_id: notification_id.to_proto(),
2627 },
2628 )?;
2629 }
2630 }
2631
2632 response.send(proto::Ack {})?;
2633 Ok(())
2634}
2635
2636fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2637 version.0.minor() < 139
2638}
2639
2640async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
2641 let plan = session.current_plan(&session.db().await).await?;
2642
2643 session
2644 .peer
2645 .send(
2646 session.connection_id,
2647 proto::UpdateUserPlan { plan: plan.into() },
2648 )
2649 .trace_err();
2650
2651 Ok(())
2652}
2653
2654async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2655 subscribe_user_to_channels(session.user_id(), &session).await?;
2656 Ok(())
2657}
2658
2659async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2660 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2661 let mut pool = session.connection_pool().await;
2662 for membership in &channels_for_user.channel_memberships {
2663 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2664 }
2665 session.peer.send(
2666 session.connection_id,
2667 build_update_user_channels(&channels_for_user),
2668 )?;
2669 session.peer.send(
2670 session.connection_id,
2671 build_channels_update(channels_for_user),
2672 )?;
2673 Ok(())
2674}
2675
2676/// Creates a new channel.
2677async fn create_channel(
2678 request: proto::CreateChannel,
2679 response: Response<proto::CreateChannel>,
2680 session: Session,
2681) -> Result<()> {
2682 let db = session.db().await;
2683
2684 let parent_id = request.parent_id.map(ChannelId::from_proto);
2685 let (channel, membership) = db
2686 .create_channel(&request.name, parent_id, session.user_id())
2687 .await?;
2688
2689 let root_id = channel.root_id();
2690 let channel = Channel::from_model(channel);
2691
2692 response.send(proto::CreateChannelResponse {
2693 channel: Some(channel.to_proto()),
2694 parent_id: request.parent_id,
2695 })?;
2696
2697 let mut connection_pool = session.connection_pool().await;
2698 if let Some(membership) = membership {
2699 connection_pool.subscribe_to_channel(
2700 membership.user_id,
2701 membership.channel_id,
2702 membership.role,
2703 );
2704 let update = proto::UpdateUserChannels {
2705 channel_memberships: vec![proto::ChannelMembership {
2706 channel_id: membership.channel_id.to_proto(),
2707 role: membership.role.into(),
2708 }],
2709 ..Default::default()
2710 };
2711 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2712 session.peer.send(connection_id, update.clone())?;
2713 }
2714 }
2715
2716 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2717 if !role.can_see_channel(channel.visibility) {
2718 continue;
2719 }
2720
2721 let update = proto::UpdateChannels {
2722 channels: vec![channel.to_proto()],
2723 ..Default::default()
2724 };
2725 session.peer.send(connection_id, update.clone())?;
2726 }
2727
2728 Ok(())
2729}
2730
2731/// Delete a channel
2732async fn delete_channel(
2733 request: proto::DeleteChannel,
2734 response: Response<proto::DeleteChannel>,
2735 session: Session,
2736) -> Result<()> {
2737 let db = session.db().await;
2738
2739 let channel_id = request.channel_id;
2740 let (root_channel, removed_channels) = db
2741 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2742 .await?;
2743 response.send(proto::Ack {})?;
2744
2745 // Notify members of removed channels
2746 let mut update = proto::UpdateChannels::default();
2747 update
2748 .delete_channels
2749 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2750
2751 let connection_pool = session.connection_pool().await;
2752 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2753 session.peer.send(connection_id, update.clone())?;
2754 }
2755
2756 Ok(())
2757}
2758
2759/// Invite someone to join a channel.
2760async fn invite_channel_member(
2761 request: proto::InviteChannelMember,
2762 response: Response<proto::InviteChannelMember>,
2763 session: Session,
2764) -> Result<()> {
2765 let db = session.db().await;
2766 let channel_id = ChannelId::from_proto(request.channel_id);
2767 let invitee_id = UserId::from_proto(request.user_id);
2768 let InviteMemberResult {
2769 channel,
2770 notifications,
2771 } = db
2772 .invite_channel_member(
2773 channel_id,
2774 invitee_id,
2775 session.user_id(),
2776 request.role().into(),
2777 )
2778 .await?;
2779
2780 let update = proto::UpdateChannels {
2781 channel_invitations: vec![channel.to_proto()],
2782 ..Default::default()
2783 };
2784
2785 let connection_pool = session.connection_pool().await;
2786 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2787 session.peer.send(connection_id, update.clone())?;
2788 }
2789
2790 send_notifications(&connection_pool, &session.peer, notifications);
2791
2792 response.send(proto::Ack {})?;
2793 Ok(())
2794}
2795
2796/// remove someone from a channel
2797async fn remove_channel_member(
2798 request: proto::RemoveChannelMember,
2799 response: Response<proto::RemoveChannelMember>,
2800 session: Session,
2801) -> Result<()> {
2802 let db = session.db().await;
2803 let channel_id = ChannelId::from_proto(request.channel_id);
2804 let member_id = UserId::from_proto(request.user_id);
2805
2806 let RemoveChannelMemberResult {
2807 membership_update,
2808 notification_id,
2809 } = db
2810 .remove_channel_member(channel_id, member_id, session.user_id())
2811 .await?;
2812
2813 let mut connection_pool = session.connection_pool().await;
2814 notify_membership_updated(
2815 &mut connection_pool,
2816 membership_update,
2817 member_id,
2818 &session.peer,
2819 );
2820 for connection_id in connection_pool.user_connection_ids(member_id) {
2821 if let Some(notification_id) = notification_id {
2822 session
2823 .peer
2824 .send(
2825 connection_id,
2826 proto::DeleteNotification {
2827 notification_id: notification_id.to_proto(),
2828 },
2829 )
2830 .trace_err();
2831 }
2832 }
2833
2834 response.send(proto::Ack {})?;
2835 Ok(())
2836}
2837
2838/// Toggle the channel between public and private.
2839/// Care is taken to maintain the invariant that public channels only descend from public channels,
2840/// (though members-only channels can appear at any point in the hierarchy).
2841async fn set_channel_visibility(
2842 request: proto::SetChannelVisibility,
2843 response: Response<proto::SetChannelVisibility>,
2844 session: Session,
2845) -> Result<()> {
2846 let db = session.db().await;
2847 let channel_id = ChannelId::from_proto(request.channel_id);
2848 let visibility = request.visibility().into();
2849
2850 let channel_model = db
2851 .set_channel_visibility(channel_id, visibility, session.user_id())
2852 .await?;
2853 let root_id = channel_model.root_id();
2854 let channel = Channel::from_model(channel_model);
2855
2856 let mut connection_pool = session.connection_pool().await;
2857 for (user_id, role) in connection_pool
2858 .channel_user_ids(root_id)
2859 .collect::<Vec<_>>()
2860 .into_iter()
2861 {
2862 let update = if role.can_see_channel(channel.visibility) {
2863 connection_pool.subscribe_to_channel(user_id, channel_id, role);
2864 proto::UpdateChannels {
2865 channels: vec![channel.to_proto()],
2866 ..Default::default()
2867 }
2868 } else {
2869 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2870 proto::UpdateChannels {
2871 delete_channels: vec![channel.id.to_proto()],
2872 ..Default::default()
2873 }
2874 };
2875
2876 for connection_id in connection_pool.user_connection_ids(user_id) {
2877 session.peer.send(connection_id, update.clone())?;
2878 }
2879 }
2880
2881 response.send(proto::Ack {})?;
2882 Ok(())
2883}
2884
2885/// Alter the role for a user in the channel.
2886async fn set_channel_member_role(
2887 request: proto::SetChannelMemberRole,
2888 response: Response<proto::SetChannelMemberRole>,
2889 session: Session,
2890) -> Result<()> {
2891 let db = session.db().await;
2892 let channel_id = ChannelId::from_proto(request.channel_id);
2893 let member_id = UserId::from_proto(request.user_id);
2894 let result = db
2895 .set_channel_member_role(
2896 channel_id,
2897 session.user_id(),
2898 member_id,
2899 request.role().into(),
2900 )
2901 .await?;
2902
2903 match result {
2904 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2905 let mut connection_pool = session.connection_pool().await;
2906 notify_membership_updated(
2907 &mut connection_pool,
2908 membership_update,
2909 member_id,
2910 &session.peer,
2911 )
2912 }
2913 db::SetMemberRoleResult::InviteUpdated(channel) => {
2914 let update = proto::UpdateChannels {
2915 channel_invitations: vec![channel.to_proto()],
2916 ..Default::default()
2917 };
2918
2919 for connection_id in session
2920 .connection_pool()
2921 .await
2922 .user_connection_ids(member_id)
2923 {
2924 session.peer.send(connection_id, update.clone())?;
2925 }
2926 }
2927 }
2928
2929 response.send(proto::Ack {})?;
2930 Ok(())
2931}
2932
2933/// Change the name of a channel
2934async fn rename_channel(
2935 request: proto::RenameChannel,
2936 response: Response<proto::RenameChannel>,
2937 session: Session,
2938) -> Result<()> {
2939 let db = session.db().await;
2940 let channel_id = ChannelId::from_proto(request.channel_id);
2941 let channel_model = db
2942 .rename_channel(channel_id, session.user_id(), &request.name)
2943 .await?;
2944 let root_id = channel_model.root_id();
2945 let channel = Channel::from_model(channel_model);
2946
2947 response.send(proto::RenameChannelResponse {
2948 channel: Some(channel.to_proto()),
2949 })?;
2950
2951 let connection_pool = session.connection_pool().await;
2952 let update = proto::UpdateChannels {
2953 channels: vec![channel.to_proto()],
2954 ..Default::default()
2955 };
2956 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2957 if role.can_see_channel(channel.visibility) {
2958 session.peer.send(connection_id, update.clone())?;
2959 }
2960 }
2961
2962 Ok(())
2963}
2964
2965/// Move a channel to a new parent.
2966async fn move_channel(
2967 request: proto::MoveChannel,
2968 response: Response<proto::MoveChannel>,
2969 session: Session,
2970) -> Result<()> {
2971 let channel_id = ChannelId::from_proto(request.channel_id);
2972 let to = ChannelId::from_proto(request.to);
2973
2974 let (root_id, channels) = session
2975 .db()
2976 .await
2977 .move_channel(channel_id, to, session.user_id())
2978 .await?;
2979
2980 let connection_pool = session.connection_pool().await;
2981 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2982 let channels = channels
2983 .iter()
2984 .filter_map(|channel| {
2985 if role.can_see_channel(channel.visibility) {
2986 Some(channel.to_proto())
2987 } else {
2988 None
2989 }
2990 })
2991 .collect::<Vec<_>>();
2992 if channels.is_empty() {
2993 continue;
2994 }
2995
2996 let update = proto::UpdateChannels {
2997 channels,
2998 ..Default::default()
2999 };
3000
3001 session.peer.send(connection_id, update.clone())?;
3002 }
3003
3004 response.send(Ack {})?;
3005 Ok(())
3006}
3007
3008/// Get the list of channel members
3009async fn get_channel_members(
3010 request: proto::GetChannelMembers,
3011 response: Response<proto::GetChannelMembers>,
3012 session: Session,
3013) -> Result<()> {
3014 let db = session.db().await;
3015 let channel_id = ChannelId::from_proto(request.channel_id);
3016 let limit = if request.limit == 0 {
3017 u16::MAX as u64
3018 } else {
3019 request.limit
3020 };
3021 let (members, users) = db
3022 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3023 .await?;
3024 response.send(proto::GetChannelMembersResponse { members, users })?;
3025 Ok(())
3026}
3027
3028/// Accept or decline a channel invitation.
3029async fn respond_to_channel_invite(
3030 request: proto::RespondToChannelInvite,
3031 response: Response<proto::RespondToChannelInvite>,
3032 session: Session,
3033) -> Result<()> {
3034 let db = session.db().await;
3035 let channel_id = ChannelId::from_proto(request.channel_id);
3036 let RespondToChannelInvite {
3037 membership_update,
3038 notifications,
3039 } = db
3040 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3041 .await?;
3042
3043 let mut connection_pool = session.connection_pool().await;
3044 if let Some(membership_update) = membership_update {
3045 notify_membership_updated(
3046 &mut connection_pool,
3047 membership_update,
3048 session.user_id(),
3049 &session.peer,
3050 );
3051 } else {
3052 let update = proto::UpdateChannels {
3053 remove_channel_invitations: vec![channel_id.to_proto()],
3054 ..Default::default()
3055 };
3056
3057 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3058 session.peer.send(connection_id, update.clone())?;
3059 }
3060 };
3061
3062 send_notifications(&connection_pool, &session.peer, notifications);
3063
3064 response.send(proto::Ack {})?;
3065
3066 Ok(())
3067}
3068
3069/// Join the channels' room
3070async fn join_channel(
3071 request: proto::JoinChannel,
3072 response: Response<proto::JoinChannel>,
3073 session: Session,
3074) -> Result<()> {
3075 let channel_id = ChannelId::from_proto(request.channel_id);
3076 join_channel_internal(channel_id, Box::new(response), session).await
3077}
3078
/// A response handle that can answer a channel join with a
/// `proto::JoinRoomResponse`, letting `join_channel_internal` serve both
/// `JoinChannel` and `JoinRoom` requests.
trait JoinChannelInternalResponse {
    /// Sends `result` as the reply, consuming the response handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// Replies to a `JoinChannel` request.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Explicit path to `Response`'s own `send` (presumably an inherent
        // method), making clear this does not call the trait method again.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
// Replies to a `JoinRoom` request.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Same delegation as above, for the `JoinRoom` response handle.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3092
/// Shared implementation for joining a channel's room. It is generic over the
/// response handle, so any `JoinChannelInternalResponse` can receive the
/// resulting `proto::JoinRoomResponse`.
///
/// In order, this:
/// 1. leaves any stale room connection still recorded for the user;
/// 2. joins the channel via the database's `join_channel`;
/// 3. builds LiveKit connection info when a LiveKit client is configured;
/// 4. replies to the caller, then propagates membership changes and the room update;
/// 5. broadcasts the channel update and refreshes the user's contacts.
///
/// # Errors
/// Fails if any database call fails, if sending the reply fails, or if the
/// joined room came back without a channel.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // This block owns the database guard `db`. It ends, releasing the guard,
    // before `channel_updated` and `update_user_contacts` run below.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database guard while leaving the old room, then
            // re-acquire it. Presumably `leave_room_for_session` takes the
            // database lock itself; confirm before reordering.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured, or when token creation
        // fails (the error goes through `trace_err`). Guests get a guest token
        // and cannot publish; every other role gets a room token and can.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        // Reply to the caller before notifying anyone else.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // A missing channel here is treated as an error.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3188
3189/// Start editing the channel notes
3190async fn join_channel_buffer(
3191 request: proto::JoinChannelBuffer,
3192 response: Response<proto::JoinChannelBuffer>,
3193 session: Session,
3194) -> Result<()> {
3195 let db = session.db().await;
3196 let channel_id = ChannelId::from_proto(request.channel_id);
3197
3198 let open_response = db
3199 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3200 .await?;
3201
3202 let collaborators = open_response.collaborators.clone();
3203 response.send(open_response)?;
3204
3205 let update = UpdateChannelBufferCollaborators {
3206 channel_id: channel_id.to_proto(),
3207 collaborators: collaborators.clone(),
3208 };
3209 channel_buffer_updated(
3210 session.connection_id,
3211 collaborators
3212 .iter()
3213 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3214 &update,
3215 &session.peer,
3216 );
3217
3218 Ok(())
3219}
3220
3221/// Edit the channel notes
3222async fn update_channel_buffer(
3223 request: proto::UpdateChannelBuffer,
3224 session: Session,
3225) -> Result<()> {
3226 let db = session.db().await;
3227 let channel_id = ChannelId::from_proto(request.channel_id);
3228
3229 let (collaborators, epoch, version) = db
3230 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3231 .await?;
3232
3233 channel_buffer_updated(
3234 session.connection_id,
3235 collaborators.clone(),
3236 &proto::UpdateChannelBuffer {
3237 channel_id: channel_id.to_proto(),
3238 operations: request.operations,
3239 },
3240 &session.peer,
3241 );
3242
3243 let pool = &*session.connection_pool().await;
3244
3245 let non_collaborators =
3246 pool.channel_connection_ids(channel_id)
3247 .filter_map(|(connection_id, _)| {
3248 if collaborators.contains(&connection_id) {
3249 None
3250 } else {
3251 Some(connection_id)
3252 }
3253 });
3254
3255 broadcast(None, non_collaborators, |peer_id| {
3256 session.peer.send(
3257 peer_id,
3258 proto::UpdateChannels {
3259 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3260 channel_id: channel_id.to_proto(),
3261 epoch: epoch as u64,
3262 version: version.clone(),
3263 }],
3264 ..Default::default()
3265 },
3266 )
3267 });
3268
3269 Ok(())
3270}
3271
3272/// Rejoin the channel notes after a connection blip
3273async fn rejoin_channel_buffers(
3274 request: proto::RejoinChannelBuffers,
3275 response: Response<proto::RejoinChannelBuffers>,
3276 session: Session,
3277) -> Result<()> {
3278 let db = session.db().await;
3279 let buffers = db
3280 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3281 .await?;
3282
3283 for rejoined_buffer in &buffers {
3284 let collaborators_to_notify = rejoined_buffer
3285 .buffer
3286 .collaborators
3287 .iter()
3288 .filter_map(|c| Some(c.peer_id?.into()));
3289 channel_buffer_updated(
3290 session.connection_id,
3291 collaborators_to_notify,
3292 &proto::UpdateChannelBufferCollaborators {
3293 channel_id: rejoined_buffer.buffer.channel_id,
3294 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3295 },
3296 &session.peer,
3297 );
3298 }
3299
3300 response.send(proto::RejoinChannelBuffersResponse {
3301 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3302 })?;
3303
3304 Ok(())
3305}
3306
3307/// Stop editing the channel notes
3308async fn leave_channel_buffer(
3309 request: proto::LeaveChannelBuffer,
3310 response: Response<proto::LeaveChannelBuffer>,
3311 session: Session,
3312) -> Result<()> {
3313 let db = session.db().await;
3314 let channel_id = ChannelId::from_proto(request.channel_id);
3315
3316 let left_buffer = db
3317 .leave_channel_buffer(channel_id, session.connection_id)
3318 .await?;
3319
3320 response.send(Ack {})?;
3321
3322 channel_buffer_updated(
3323 session.connection_id,
3324 left_buffer.connections,
3325 &proto::UpdateChannelBufferCollaborators {
3326 channel_id: channel_id.to_proto(),
3327 collaborators: left_buffer.collaborators,
3328 },
3329 &session.peer,
3330 );
3331
3332 Ok(())
3333}
3334
3335fn channel_buffer_updated<T: EnvelopedMessage>(
3336 sender_id: ConnectionId,
3337 collaborators: impl IntoIterator<Item = ConnectionId>,
3338 message: &T,
3339 peer: &Peer,
3340) {
3341 broadcast(Some(sender_id), collaborators, |peer_id| {
3342 peer.send(peer_id, message.clone())
3343 });
3344}
3345
3346fn send_notifications(
3347 connection_pool: &ConnectionPool,
3348 peer: &Peer,
3349 notifications: db::NotificationBatch,
3350) {
3351 for (user_id, notification) in notifications {
3352 for connection_id in connection_pool.user_connection_ids(user_id) {
3353 if let Err(error) = peer.send(
3354 connection_id,
3355 proto::AddNotification {
3356 notification: Some(notification.clone()),
3357 },
3358 ) {
3359 tracing::error!(
3360 "failed to send notification to {:?} {}",
3361 connection_id,
3362 error
3363 );
3364 }
3365 }
3366 }
3367}
3368
3369/// Send a message to the channel
3370async fn send_channel_message(
3371 request: proto::SendChannelMessage,
3372 response: Response<proto::SendChannelMessage>,
3373 session: Session,
3374) -> Result<()> {
3375 // Validate the message body.
3376 let body = request.body.trim().to_string();
3377 if body.len() > MAX_MESSAGE_LEN {
3378 return Err(anyhow!("message is too long"))?;
3379 }
3380 if body.is_empty() {
3381 return Err(anyhow!("message can't be blank"))?;
3382 }
3383
3384 // TODO: adjust mentions if body is trimmed
3385
3386 let timestamp = OffsetDateTime::now_utc();
3387 let nonce = request
3388 .nonce
3389 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3390
3391 let channel_id = ChannelId::from_proto(request.channel_id);
3392 let CreatedChannelMessage {
3393 message_id,
3394 participant_connection_ids,
3395 notifications,
3396 } = session
3397 .db()
3398 .await
3399 .create_channel_message(
3400 channel_id,
3401 session.user_id(),
3402 &body,
3403 &request.mentions,
3404 timestamp,
3405 nonce.clone().into(),
3406 request.reply_to_message_id.map(MessageId::from_proto),
3407 )
3408 .await?;
3409
3410 let message = proto::ChannelMessage {
3411 sender_id: session.user_id().to_proto(),
3412 id: message_id.to_proto(),
3413 body,
3414 mentions: request.mentions,
3415 timestamp: timestamp.unix_timestamp() as u64,
3416 nonce: Some(nonce),
3417 reply_to_message_id: request.reply_to_message_id,
3418 edited_at: None,
3419 };
3420 broadcast(
3421 Some(session.connection_id),
3422 participant_connection_ids.clone(),
3423 |connection| {
3424 session.peer.send(
3425 connection,
3426 proto::ChannelMessageSent {
3427 channel_id: channel_id.to_proto(),
3428 message: Some(message.clone()),
3429 },
3430 )
3431 },
3432 );
3433 response.send(proto::SendChannelMessageResponse {
3434 message: Some(message),
3435 })?;
3436
3437 let pool = &*session.connection_pool().await;
3438 let non_participants =
3439 pool.channel_connection_ids(channel_id)
3440 .filter_map(|(connection_id, _)| {
3441 if participant_connection_ids.contains(&connection_id) {
3442 None
3443 } else {
3444 Some(connection_id)
3445 }
3446 });
3447 broadcast(None, non_participants, |peer_id| {
3448 session.peer.send(
3449 peer_id,
3450 proto::UpdateChannels {
3451 latest_channel_message_ids: vec![proto::ChannelMessageId {
3452 channel_id: channel_id.to_proto(),
3453 message_id: message_id.to_proto(),
3454 }],
3455 ..Default::default()
3456 },
3457 )
3458 });
3459 send_notifications(pool, &session.peer, notifications);
3460
3461 Ok(())
3462}
3463
3464/// Delete a channel message
3465async fn remove_channel_message(
3466 request: proto::RemoveChannelMessage,
3467 response: Response<proto::RemoveChannelMessage>,
3468 session: Session,
3469) -> Result<()> {
3470 let channel_id = ChannelId::from_proto(request.channel_id);
3471 let message_id = MessageId::from_proto(request.message_id);
3472 let (connection_ids, existing_notification_ids) = session
3473 .db()
3474 .await
3475 .remove_channel_message(channel_id, message_id, session.user_id())
3476 .await?;
3477
3478 broadcast(
3479 Some(session.connection_id),
3480 connection_ids,
3481 move |connection| {
3482 session.peer.send(connection, request.clone())?;
3483
3484 for notification_id in &existing_notification_ids {
3485 session.peer.send(
3486 connection,
3487 proto::DeleteNotification {
3488 notification_id: (*notification_id).to_proto(),
3489 },
3490 )?;
3491 }
3492
3493 Ok(())
3494 },
3495 );
3496 response.send(proto::Ack {})?;
3497 Ok(())
3498}
3499
3500async fn update_channel_message(
3501 request: proto::UpdateChannelMessage,
3502 response: Response<proto::UpdateChannelMessage>,
3503 session: Session,
3504) -> Result<()> {
3505 let channel_id = ChannelId::from_proto(request.channel_id);
3506 let message_id = MessageId::from_proto(request.message_id);
3507 let updated_at = OffsetDateTime::now_utc();
3508 let UpdatedChannelMessage {
3509 message_id,
3510 participant_connection_ids,
3511 notifications,
3512 reply_to_message_id,
3513 timestamp,
3514 deleted_mention_notification_ids,
3515 updated_mention_notifications,
3516 } = session
3517 .db()
3518 .await
3519 .update_channel_message(
3520 channel_id,
3521 message_id,
3522 session.user_id(),
3523 request.body.as_str(),
3524 &request.mentions,
3525 updated_at,
3526 )
3527 .await?;
3528
3529 let nonce = request
3530 .nonce
3531 .clone()
3532 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3533
3534 let message = proto::ChannelMessage {
3535 sender_id: session.user_id().to_proto(),
3536 id: message_id.to_proto(),
3537 body: request.body.clone(),
3538 mentions: request.mentions.clone(),
3539 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3540 nonce: Some(nonce),
3541 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3542 edited_at: Some(updated_at.unix_timestamp() as u64),
3543 };
3544
3545 response.send(proto::Ack {})?;
3546
3547 let pool = &*session.connection_pool().await;
3548 broadcast(
3549 Some(session.connection_id),
3550 participant_connection_ids,
3551 |connection| {
3552 session.peer.send(
3553 connection,
3554 proto::ChannelMessageUpdate {
3555 channel_id: channel_id.to_proto(),
3556 message: Some(message.clone()),
3557 },
3558 )?;
3559
3560 for notification_id in &deleted_mention_notification_ids {
3561 session.peer.send(
3562 connection,
3563 proto::DeleteNotification {
3564 notification_id: (*notification_id).to_proto(),
3565 },
3566 )?;
3567 }
3568
3569 for notification in &updated_mention_notifications {
3570 session.peer.send(
3571 connection,
3572 proto::UpdateNotification {
3573 notification: Some(notification.clone()),
3574 },
3575 )?;
3576 }
3577
3578 Ok(())
3579 },
3580 );
3581
3582 send_notifications(pool, &session.peer, notifications);
3583
3584 Ok(())
3585}
3586
3587/// Mark a channel message as read
3588async fn acknowledge_channel_message(
3589 request: proto::AckChannelMessage,
3590 session: Session,
3591) -> Result<()> {
3592 let channel_id = ChannelId::from_proto(request.channel_id);
3593 let message_id = MessageId::from_proto(request.message_id);
3594 let notifications = session
3595 .db()
3596 .await
3597 .observe_channel_message(channel_id, session.user_id(), message_id)
3598 .await?;
3599 send_notifications(
3600 &*session.connection_pool().await,
3601 &session.peer,
3602 notifications,
3603 );
3604 Ok(())
3605}
3606
3607/// Mark a buffer version as synced
3608async fn acknowledge_buffer_version(
3609 request: proto::AckBufferOperation,
3610 session: Session,
3611) -> Result<()> {
3612 let buffer_id = BufferId::from_proto(request.buffer_id);
3613 session
3614 .db()
3615 .await
3616 .observe_buffer_version(
3617 buffer_id,
3618 session.user_id(),
3619 request.epoch as i32,
3620 &request.version,
3621 )
3622 .await?;
3623 Ok(())
3624}
3625
3626async fn count_language_model_tokens(
3627 request: proto::CountLanguageModelTokens,
3628 response: Response<proto::CountLanguageModelTokens>,
3629 session: Session,
3630 config: &Config,
3631) -> Result<()> {
3632 authorize_access_to_legacy_llm_endpoints(&session).await?;
3633
3634 let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3635 proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
3636 proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
3637 };
3638
3639 session
3640 .app_state
3641 .rate_limiter
3642 .check(&*rate_limit, session.user_id())
3643 .await?;
3644
3645 let result = match proto::LanguageModelProvider::from_i32(request.provider) {
3646 Some(proto::LanguageModelProvider::Google) => {
3647 let api_key = config
3648 .google_ai_api_key
3649 .as_ref()
3650 .context("no Google AI API key configured on the server")?;
3651 google_ai::count_tokens(
3652 session.http_client.as_ref(),
3653 google_ai::API_URL,
3654 api_key,
3655 serde_json::from_str(&request.request)?,
3656 )
3657 .await?
3658 }
3659 _ => return Err(anyhow!("unsupported provider"))?,
3660 };
3661
3662 response.send(proto::CountLanguageModelTokensResponse {
3663 token_count: result.total_tokens as u32,
3664 })?;
3665
3666 Ok(())
3667}
3668
3669struct ZedProCountLanguageModelTokensRateLimit;
3670
3671impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
3672 fn capacity(&self) -> usize {
3673 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
3674 .ok()
3675 .and_then(|v| v.parse().ok())
3676 .unwrap_or(600) // Picked arbitrarily
3677 }
3678
3679 fn refill_duration(&self) -> chrono::Duration {
3680 chrono::Duration::hours(1)
3681 }
3682
3683 fn db_name(&self) -> &'static str {
3684 "zed-pro:count-language-model-tokens"
3685 }
3686}
3687
3688struct FreeCountLanguageModelTokensRateLimit;
3689
3690impl RateLimit for FreeCountLanguageModelTokensRateLimit {
3691 fn capacity(&self) -> usize {
3692 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
3693 .ok()
3694 .and_then(|v| v.parse().ok())
3695 .unwrap_or(600 / 10) // Picked arbitrarily
3696 }
3697
3698 fn refill_duration(&self) -> chrono::Duration {
3699 chrono::Duration::hours(1)
3700 }
3701
3702 fn db_name(&self) -> &'static str {
3703 "free:count-language-model-tokens"
3704 }
3705}
3706
3707struct ZedProComputeEmbeddingsRateLimit;
3708
3709impl RateLimit for ZedProComputeEmbeddingsRateLimit {
3710 fn capacity(&self) -> usize {
3711 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
3712 .ok()
3713 .and_then(|v| v.parse().ok())
3714 .unwrap_or(5000) // Picked arbitrarily
3715 }
3716
3717 fn refill_duration(&self) -> chrono::Duration {
3718 chrono::Duration::hours(1)
3719 }
3720
3721 fn db_name(&self) -> &'static str {
3722 "zed-pro:compute-embeddings"
3723 }
3724}
3725
3726struct FreeComputeEmbeddingsRateLimit;
3727
3728impl RateLimit for FreeComputeEmbeddingsRateLimit {
3729 fn capacity(&self) -> usize {
3730 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
3731 .ok()
3732 .and_then(|v| v.parse().ok())
3733 .unwrap_or(5000 / 10) // Picked arbitrarily
3734 }
3735
3736 fn refill_duration(&self) -> chrono::Duration {
3737 chrono::Duration::hours(1)
3738 }
3739
3740 fn db_name(&self) -> &'static str {
3741 "free:compute-embeddings"
3742 }
3743}
3744
3745async fn compute_embeddings(
3746 request: proto::ComputeEmbeddings,
3747 response: Response<proto::ComputeEmbeddings>,
3748 session: Session,
3749 api_key: Option<Arc<str>>,
3750) -> Result<()> {
3751 let api_key = api_key.context("no OpenAI API key configured on the server")?;
3752 authorize_access_to_legacy_llm_endpoints(&session).await?;
3753
3754 let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3755 proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
3756 proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
3757 };
3758
3759 session
3760 .app_state
3761 .rate_limiter
3762 .check(&*rate_limit, session.user_id())
3763 .await?;
3764
3765 let embeddings = match request.model.as_str() {
3766 "openai/text-embedding-3-small" => {
3767 open_ai::embed(
3768 session.http_client.as_ref(),
3769 OPEN_AI_API_URL,
3770 &api_key,
3771 OpenAiEmbeddingModel::TextEmbedding3Small,
3772 request.texts.iter().map(|text| text.as_str()),
3773 )
3774 .await?
3775 }
3776 provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
3777 };
3778
3779 let embeddings = request
3780 .texts
3781 .iter()
3782 .map(|text| {
3783 let mut hasher = sha2::Sha256::new();
3784 hasher.update(text.as_bytes());
3785 let result = hasher.finalize();
3786 result.to_vec()
3787 })
3788 .zip(
3789 embeddings
3790 .data
3791 .into_iter()
3792 .map(|embedding| embedding.embedding),
3793 )
3794 .collect::<HashMap<_, _>>();
3795
3796 let db = session.db().await;
3797 db.save_embeddings(&request.model, &embeddings)
3798 .await
3799 .context("failed to save embeddings")
3800 .trace_err();
3801
3802 response.send(proto::ComputeEmbeddingsResponse {
3803 embeddings: embeddings
3804 .into_iter()
3805 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3806 .collect(),
3807 })?;
3808 Ok(())
3809}
3810
3811async fn get_cached_embeddings(
3812 request: proto::GetCachedEmbeddings,
3813 response: Response<proto::GetCachedEmbeddings>,
3814 session: Session,
3815) -> Result<()> {
3816 authorize_access_to_legacy_llm_endpoints(&session).await?;
3817
3818 let db = session.db().await;
3819 let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
3820
3821 response.send(proto::GetCachedEmbeddingsResponse {
3822 embeddings: embeddings
3823 .into_iter()
3824 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3825 .collect(),
3826 })?;
3827 Ok(())
3828}
3829
3830/// This is leftover from before the LLM service.
3831///
3832/// The endpoints protected by this check will be moved there eventually.
3833async fn authorize_access_to_legacy_llm_endpoints(session: &Session) -> Result<(), Error> {
3834 if session.is_staff() {
3835 Ok(())
3836 } else {
3837 Err(anyhow!("permission denied"))?
3838 }
3839}
3840
3841/// Get a Supermaven API key for the user
3842async fn get_supermaven_api_key(
3843 _request: proto::GetSupermavenApiKey,
3844 response: Response<proto::GetSupermavenApiKey>,
3845 session: Session,
3846) -> Result<()> {
3847 let user_id: String = session.user_id().to_string();
3848 if !session.is_staff() {
3849 return Err(anyhow!("supermaven not enabled for this account"))?;
3850 }
3851
3852 let email = session
3853 .email()
3854 .ok_or_else(|| anyhow!("user must have an email"))?;
3855
3856 let supermaven_admin_api = session
3857 .supermaven_client
3858 .as_ref()
3859 .ok_or_else(|| anyhow!("supermaven not configured"))?;
3860
3861 let result = supermaven_admin_api
3862 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3863 .await?;
3864
3865 response.send(proto::GetSupermavenApiKeyResponse {
3866 api_key: result.api_key,
3867 })?;
3868
3869 Ok(())
3870}
3871
3872/// Start receiving chat updates for a channel
3873async fn join_channel_chat(
3874 request: proto::JoinChannelChat,
3875 response: Response<proto::JoinChannelChat>,
3876 session: Session,
3877) -> Result<()> {
3878 let channel_id = ChannelId::from_proto(request.channel_id);
3879
3880 let db = session.db().await;
3881 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3882 .await?;
3883 let messages = db
3884 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3885 .await?;
3886 response.send(proto::JoinChannelChatResponse {
3887 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3888 messages,
3889 })?;
3890 Ok(())
3891}
3892
3893/// Stop receiving chat updates for a channel
3894async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3895 let channel_id = ChannelId::from_proto(request.channel_id);
3896 session
3897 .db()
3898 .await
3899 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3900 .await?;
3901 Ok(())
3902}
3903
3904/// Retrieve the chat history for a channel
3905async fn get_channel_messages(
3906 request: proto::GetChannelMessages,
3907 response: Response<proto::GetChannelMessages>,
3908 session: Session,
3909) -> Result<()> {
3910 let channel_id = ChannelId::from_proto(request.channel_id);
3911 let messages = session
3912 .db()
3913 .await
3914 .get_channel_messages(
3915 channel_id,
3916 session.user_id(),
3917 MESSAGE_COUNT_PER_PAGE,
3918 Some(MessageId::from_proto(request.before_message_id)),
3919 )
3920 .await?;
3921 response.send(proto::GetChannelMessagesResponse {
3922 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3923 messages,
3924 })?;
3925 Ok(())
3926}
3927
3928/// Retrieve specific chat messages
3929async fn get_channel_messages_by_id(
3930 request: proto::GetChannelMessagesById,
3931 response: Response<proto::GetChannelMessagesById>,
3932 session: Session,
3933) -> Result<()> {
3934 let message_ids = request
3935 .message_ids
3936 .iter()
3937 .map(|id| MessageId::from_proto(*id))
3938 .collect::<Vec<_>>();
3939 let messages = session
3940 .db()
3941 .await
3942 .get_channel_messages_by_id(session.user_id(), &message_ids)
3943 .await?;
3944 response.send(proto::GetChannelMessagesResponse {
3945 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3946 messages,
3947 })?;
3948 Ok(())
3949}
3950
3951/// Retrieve the current users notifications
3952async fn get_notifications(
3953 request: proto::GetNotifications,
3954 response: Response<proto::GetNotifications>,
3955 session: Session,
3956) -> Result<()> {
3957 let notifications = session
3958 .db()
3959 .await
3960 .get_notifications(
3961 session.user_id(),
3962 NOTIFICATION_COUNT_PER_PAGE,
3963 request.before_id.map(db::NotificationId::from_proto),
3964 )
3965 .await?;
3966 response.send(proto::GetNotificationsResponse {
3967 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3968 notifications,
3969 })?;
3970 Ok(())
3971}
3972
3973/// Mark notifications as read
3974async fn mark_notification_as_read(
3975 request: proto::MarkNotificationRead,
3976 response: Response<proto::MarkNotificationRead>,
3977 session: Session,
3978) -> Result<()> {
3979 let database = &session.db().await;
3980 let notifications = database
3981 .mark_notification_as_read_by_id(
3982 session.user_id(),
3983 NotificationId::from_proto(request.notification_id),
3984 )
3985 .await?;
3986 send_notifications(
3987 &*session.connection_pool().await,
3988 &session.peer,
3989 notifications,
3990 );
3991 response.send(proto::Ack {})?;
3992 Ok(())
3993}
3994
3995/// Get the current users information
3996async fn get_private_user_info(
3997 _request: proto::GetPrivateUserInfo,
3998 response: Response<proto::GetPrivateUserInfo>,
3999 session: Session,
4000) -> Result<()> {
4001 let db = session.db().await;
4002
4003 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4004 let user = db
4005 .get_user_by_id(session.user_id())
4006 .await?
4007 .ok_or_else(|| anyhow!("user not found"))?;
4008 let flags = db.get_user_flags(session.user_id()).await?;
4009
4010 response.send(proto::GetPrivateUserInfoResponse {
4011 metrics_id,
4012 staff: user.admin,
4013 flags,
4014 accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4015 })?;
4016 Ok(())
4017}
4018
4019/// Accept the terms of service (tos) on behalf of the current user
4020async fn accept_terms_of_service(
4021 _request: proto::AcceptTermsOfService,
4022 response: Response<proto::AcceptTermsOfService>,
4023 session: Session,
4024) -> Result<()> {
4025 let db = session.db().await;
4026
4027 let accepted_tos_at = Utc::now();
4028 db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4029 .await?;
4030
4031 response.send(proto::AcceptTermsOfServiceResponse {
4032 accepted_tos_at: accepted_tos_at.timestamp() as u64,
4033 })?;
4034 Ok(())
4035}
4036
/// The minimum account age an account must have in order to use the LLM service.
///
/// `get_llm_api_token` compares this against the earlier of the Zed account
/// and GitHub account creation times. The check is skipped for users with an
/// LLM subscription or the `bypass-account-age-check` feature flag.
const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4039
4040async fn get_llm_api_token(
4041 _request: proto::GetLlmToken,
4042 response: Response<proto::GetLlmToken>,
4043 session: Session,
4044) -> Result<()> {
4045 let db = session.db().await;
4046
4047 let flags = db.get_user_flags(session.user_id()).await?;
4048 let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
4049 let has_llm_closed_beta_feature_flag = flags.iter().any(|flag| flag == "llm-closed-beta");
4050 let has_predict_edits_feature_flag = flags.iter().any(|flag| flag == "predict-edits");
4051
4052 if !session.is_staff() && !has_language_models_feature_flag {
4053 Err(anyhow!("permission denied"))?
4054 }
4055
4056 let user_id = session.user_id();
4057 let user = db
4058 .get_user_by_id(user_id)
4059 .await?
4060 .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4061
4062 if user.accepted_tos_at.is_none() {
4063 Err(anyhow!("terms of service not accepted"))?
4064 }
4065
4066 let has_llm_subscription = session.has_llm_subscription(&db).await?;
4067
4068 let bypass_account_age_check =
4069 has_llm_subscription || flags.iter().any(|flag| flag == "bypass-account-age-check");
4070 if !bypass_account_age_check {
4071 let mut account_created_at = user.created_at;
4072 if let Some(github_created_at) = user.github_user_created_at {
4073 account_created_at = account_created_at.min(github_created_at);
4074 }
4075 if Utc::now().naive_utc() - account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE {
4076 Err(anyhow!("account too young"))?
4077 }
4078 }
4079
4080 let billing_preferences = db.get_billing_preferences(user.id).await?;
4081
4082 let token = LlmTokenClaims::create(
4083 &user,
4084 session.is_staff(),
4085 billing_preferences,
4086 has_llm_closed_beta_feature_flag,
4087 has_predict_edits_feature_flag,
4088 has_llm_subscription,
4089 session.current_plan(&db).await?,
4090 session.system_id.clone(),
4091 &session.app_state.config,
4092 )?;
4093 response.send(proto::GetLlmTokenResponse { token })?;
4094 Ok(())
4095}
4096
4097fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4098 let message = match message {
4099 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4100 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4101 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4102 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4103 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4104 code: frame.code.into(),
4105 reason: frame.reason,
4106 })),
4107 // We should never receive a frame while reading the message, according
4108 // to the `tungstenite` maintainers:
4109 //
4110 // > It cannot occur when you read messages from the WebSocket, but it
4111 // > can be used when you want to send the raw frames (e.g. you want to
4112 // > send the frames to the WebSocket without composing the full message first).
4113 // >
4114 // > — https://github.com/snapview/tungstenite-rs/issues/268
4115 TungsteniteMessage::Frame(_) => {
4116 bail!("received an unexpected frame while reading the message")
4117 }
4118 };
4119
4120 Ok(message)
4121}
4122
4123fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4124 match message {
4125 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4126 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4127 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4128 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4129 AxumMessage::Close(frame) => {
4130 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4131 code: frame.code.into(),
4132 reason: frame.reason,
4133 }))
4134 }
4135 }
4136}
4137
4138fn notify_membership_updated(
4139 connection_pool: &mut ConnectionPool,
4140 result: MembershipUpdated,
4141 user_id: UserId,
4142 peer: &Peer,
4143) {
4144 for membership in &result.new_channels.channel_memberships {
4145 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4146 }
4147 for channel_id in &result.removed_channels {
4148 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4149 }
4150
4151 let user_channels_update = proto::UpdateUserChannels {
4152 channel_memberships: result
4153 .new_channels
4154 .channel_memberships
4155 .iter()
4156 .map(|cm| proto::ChannelMembership {
4157 channel_id: cm.channel_id.to_proto(),
4158 role: cm.role.into(),
4159 })
4160 .collect(),
4161 ..Default::default()
4162 };
4163
4164 let mut update = build_channels_update(result.new_channels);
4165 update.delete_channels = result
4166 .removed_channels
4167 .into_iter()
4168 .map(|id| id.to_proto())
4169 .collect();
4170 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4171
4172 for connection_id in connection_pool.user_connection_ids(user_id) {
4173 peer.send(connection_id, user_channels_update.clone())
4174 .trace_err();
4175 peer.send(connection_id, update.clone()).trace_err();
4176 }
4177}
4178
4179fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4180 proto::UpdateUserChannels {
4181 channel_memberships: channels
4182 .channel_memberships
4183 .iter()
4184 .map(|m| proto::ChannelMembership {
4185 channel_id: m.channel_id.to_proto(),
4186 role: m.role.into(),
4187 })
4188 .collect(),
4189 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4190 observed_channel_message_id: channels.observed_channel_messages.clone(),
4191 }
4192}
4193
4194fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4195 let mut update = proto::UpdateChannels::default();
4196
4197 for channel in channels.channels {
4198 update.channels.push(channel.to_proto());
4199 }
4200
4201 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4202 update.latest_channel_message_ids = channels.latest_channel_messages;
4203
4204 for (channel_id, participants) in channels.channel_participants {
4205 update
4206 .channel_participants
4207 .push(proto::ChannelParticipants {
4208 channel_id: channel_id.to_proto(),
4209 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4210 });
4211 }
4212
4213 for channel in channels.invited_channels {
4214 update.channel_invitations.push(channel.to_proto());
4215 }
4216
4217 update
4218}
4219
4220fn build_initial_contacts_update(
4221 contacts: Vec<db::Contact>,
4222 pool: &ConnectionPool,
4223) -> proto::UpdateContacts {
4224 let mut update = proto::UpdateContacts::default();
4225
4226 for contact in contacts {
4227 match contact {
4228 db::Contact::Accepted { user_id, busy } => {
4229 update.contacts.push(contact_for_user(user_id, busy, pool));
4230 }
4231 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4232 db::Contact::Incoming { user_id } => {
4233 update
4234 .incoming_requests
4235 .push(proto::IncomingContactRequest {
4236 requester_id: user_id.to_proto(),
4237 })
4238 }
4239 }
4240 }
4241
4242 update
4243}
4244
4245fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4246 proto::Contact {
4247 user_id: user_id.to_proto(),
4248 online: pool.is_user_online(user_id),
4249 busy,
4250 }
4251}
4252
4253fn room_updated(room: &proto::Room, peer: &Peer) {
4254 broadcast(
4255 None,
4256 room.participants
4257 .iter()
4258 .filter_map(|participant| Some(participant.peer_id?.into())),
4259 |peer_id| {
4260 peer.send(
4261 peer_id,
4262 proto::RoomUpdated {
4263 room: Some(room.clone()),
4264 },
4265 )
4266 },
4267 );
4268}
4269
4270fn channel_updated(
4271 channel: &db::channel::Model,
4272 room: &proto::Room,
4273 peer: &Peer,
4274 pool: &ConnectionPool,
4275) {
4276 let participants = room
4277 .participants
4278 .iter()
4279 .map(|p| p.user_id)
4280 .collect::<Vec<_>>();
4281
4282 broadcast(
4283 None,
4284 pool.channel_connection_ids(channel.root_id())
4285 .filter_map(|(channel_id, role)| {
4286 role.can_see_channel(channel.visibility)
4287 .then_some(channel_id)
4288 }),
4289 |peer_id| {
4290 peer.send(
4291 peer_id,
4292 proto::UpdateChannels {
4293 channel_participants: vec![proto::ChannelParticipants {
4294 channel_id: channel.id.to_proto(),
4295 participant_user_ids: participants.clone(),
4296 }],
4297 ..Default::default()
4298 },
4299 )
4300 },
4301 );
4302}
4303
4304async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4305 let db = session.db().await;
4306
4307 let contacts = db.get_contacts(user_id).await?;
4308 let busy = db.is_user_busy(user_id).await?;
4309
4310 let pool = session.connection_pool().await;
4311 let updated_contact = contact_for_user(user_id, busy, &pool);
4312 for contact in contacts {
4313 if let db::Contact::Accepted {
4314 user_id: contact_user_id,
4315 ..
4316 } = contact
4317 {
4318 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4319 session
4320 .peer
4321 .send(
4322 contact_conn_id,
4323 proto::UpdateContacts {
4324 contacts: vec![updated_contact.clone()],
4325 remove_contacts: Default::default(),
4326 incoming_requests: Default::default(),
4327 remove_incoming_requests: Default::default(),
4328 outgoing_requests: Default::default(),
4329 remove_outgoing_requests: Default::default(),
4330 },
4331 )
4332 .trace_err();
4333 }
4334 }
4335 }
4336 Ok(())
4337}
4338
/// Removes `connection_id` from the room it is in and notifies everyone
/// affected.
///
/// Returns `Ok(())` without doing anything further when the database's
/// `leave_room` reports no room for this connection. Otherwise it:
/// - notifies the remaining connections of every project the connection left,
/// - broadcasts the updated room and, if the room belongs to a channel, the
///   channel's participant list,
/// - sends `CallCanceled` to every connection of each user whose call was
///   canceled,
/// - refreshes contact entries for this session's user and the canceled
///   callees, and
/// - when a LiveKit client is configured, removes the user from the LiveKit
///   room and deletes that room if the database reported it as deleted.
///
/// # Errors
///
/// Propagates database errors from `leave_room` and from the contact
/// updates. Send failures and LiveKit failures are only logged.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    // Users whose contact entries must be re-broadcast after leaving.
    let mut contacts_to_update = HashSet::default();

    // Each of these is assigned exactly once in the `if let` branch below; the
    // `else` branch returns early.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // The LiveKit room name is taken out before `room` is moved out below.
        // As a result, the `room` broadcast to participants carries a default
        // `livekit_room`.
        // NOTE(review): confirm clients do not rely on `livekit_room` in the
        // `RoomUpdated` sent here.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    {
        // Scope the pool guard so it is released before `update_user_contacts`
        // below, which acquires the connection pool itself.
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    // NOTE(review): because of `?`, a failure while updating one user's
    // contacts skips the remaining updates and the LiveKit cleanup below, even
    // though the database has already recorded the leave. Confirm this is
    // intended.
    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4412
4413async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4414 let left_channel_buffers = session
4415 .db()
4416 .await
4417 .leave_channel_buffers(session.connection_id)
4418 .await?;
4419
4420 for left_buffer in left_channel_buffers {
4421 channel_buffer_updated(
4422 session.connection_id,
4423 left_buffer.connections,
4424 &proto::UpdateChannelBufferCollaborators {
4425 channel_id: left_buffer.channel_id.to_proto(),
4426 collaborators: left_buffer.collaborators,
4427 },
4428 &session.peer,
4429 );
4430 }
4431
4432 Ok(())
4433}
4434
4435fn project_left(project: &db::LeftProject, session: &Session) {
4436 for connection_id in &project.connection_ids {
4437 if project.should_unshare {
4438 session
4439 .peer
4440 .send(
4441 *connection_id,
4442 proto::UnshareProject {
4443 project_id: project.id.to_proto(),
4444 },
4445 )
4446 .trace_err();
4447 } else {
4448 session
4449 .peer
4450 .send(
4451 *connection_id,
4452 proto::RemoveProjectCollaborator {
4453 project_id: project.id.to_proto(),
4454 peer_id: Some(session.connection_id.into()),
4455 },
4456 )
4457 .trace_err();
4458 }
4459 }
4460}
4461
4462pub trait ResultExt {
4463 type Ok;
4464
4465 fn trace_err(self) -> Option<Self::Ok>;
4466}
4467
4468impl<T, E> ResultExt for Result<T, E>
4469where
4470 E: std::fmt::Debug,
4471{
4472 type Ok = T;
4473
4474 #[track_caller]
4475 fn trace_err(self) -> Option<T> {
4476 match self {
4477 Ok(value) => Some(value),
4478 Err(error) => {
4479 tracing::error!("{:?}", error);
4480 None
4481 }
4482 }
4483 }
4484}