1mod connection_pool;
2
3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
4use crate::llm::LlmTokenClaims;
5use crate::{
6 auth,
7 db::{
8 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
9 CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
10 NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
11 RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
12 },
13 executor::Executor,
14 AppState, Config, Error, RateLimit, Result,
15};
16use anyhow::{anyhow, bail, Context as _};
17use async_tungstenite::tungstenite::{
18 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
19};
20use axum::{
21 body::Body,
22 extract::{
23 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
24 ConnectInfo, WebSocketUpgrade,
25 },
26 headers::{Header, HeaderName},
27 http::StatusCode,
28 middleware,
29 response::IntoResponse,
30 routing::get,
31 Extension, Router, TypedHeader,
32};
33use chrono::Utc;
34use collections::{HashMap, HashSet};
35pub use connection_pool::{ConnectionPool, ZedVersion};
36use core::fmt::{self, Debug, Formatter};
37use http_client::HttpClient;
38use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
39use reqwest_client::ReqwestClient;
40use sha2::Digest;
41use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
42
43use futures::{
44 channel::oneshot, future::BoxFuture, stream::FuturesUnordered, FutureExt, SinkExt, StreamExt,
45 TryStreamExt,
46};
47use prometheus::{register_int_gauge, IntGauge};
48use rpc::{
49 proto::{
50 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
51 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
52 },
53 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
54};
55use semantic_version::SemanticVersion;
56use serde::{Serialize, Serializer};
57use std::{
58 any::TypeId,
59 future::Future,
60 marker::PhantomData,
61 mem,
62 net::SocketAddr,
63 ops::{Deref, DerefMut},
64 rc::Rc,
65 sync::{
66 atomic::{AtomicBool, Ordering::SeqCst},
67 Arc, OnceLock,
68 },
69 time::{Duration, Instant},
70};
71use time::OffsetDateTime;
72use tokio::sync::{watch, MutexGuard, Semaphore};
73use tower::ServiceBuilder;
74use tracing::{
75 field::{self},
76 info_span, instrument, Instrument,
77};
78
/// Reconnection grace period. Its use is outside this view — presumably how long a
/// dropped client may take to reconnect before its state is released; confirm at the
/// call sites.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can
/// clean up old resources; `Server::start` waits this long before doing so.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

/// Page size for channel-message queries (used outside this view — confirm).
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Maximum chat message length (whether bytes or chars is not visible here — confirm).
const MAX_MESSAGE_LEN: usize = 1024;
/// Page size for notification queries (used outside this view — confirm).
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;

/// A type-erased message handler as stored in [`Server`]'s handler table: it receives the
/// boxed incoming envelope plus the sender's [`Session`] and returns a boxed future that
/// performs the handling.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
90
91struct Response<R> {
92 peer: Arc<Peer>,
93 receipt: Receipt<R>,
94 responded: Arc<AtomicBool>,
95}
96
97impl<R: RequestMessage> Response<R> {
98 fn send(self, payload: R::Response) -> Result<()> {
99 self.responded.store(true, SeqCst);
100 self.peer.respond(self.receipt, payload)?;
101 Ok(())
102 }
103}
104
105#[derive(Clone, Debug)]
106pub enum Principal {
107 User(User),
108 Impersonated { user: User, admin: User },
109}
110
111impl Principal {
112 fn update_span(&self, span: &tracing::Span) {
113 match &self {
114 Principal::User(user) => {
115 span.record("user_id", user.id.0);
116 span.record("login", &user.github_login);
117 }
118 Principal::Impersonated { user, admin } => {
119 span.record("user_id", user.id.0);
120 span.record("login", &user.github_login);
121 span.record("impersonator", &admin.github_login);
122 }
123 }
124 }
125}
126
/// Per-connection state; a clone is handed to every message handler invocation.
#[derive(Clone)]
struct Session {
    /// Who this connection acts as (possibly an admin impersonating a user).
    principal: Principal,
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; acquire it through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// The server-wide connection pool, shared with [`Server`].
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Supermaven admin client; `None` unless a Supermaven admin API key is configured.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    http_client: Arc<dyn HttpClient>,
    /// The GeoIP country code for the user.
    // Nothing in this view reads the field, hence the `allow(unused)`.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    /// System identifier for the connecting client, if known (presumably taken from the
    /// `SystemIdHeader` request header — confirm at the caller).
    system_id: Option<String>,
    // Held for the session's lifetime; never read directly.
    _executor: Executor,
}
143
144impl Session {
145 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
146 #[cfg(test)]
147 tokio::task::yield_now().await;
148 let guard = self.db.lock().await;
149 #[cfg(test)]
150 tokio::task::yield_now().await;
151 guard
152 }
153
154 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
155 #[cfg(test)]
156 tokio::task::yield_now().await;
157 let guard = self.connection_pool.lock();
158 ConnectionPoolGuard {
159 guard,
160 _not_send: PhantomData,
161 }
162 }
163
164 fn is_staff(&self) -> bool {
165 match &self.principal {
166 Principal::User(user) => user.admin,
167 Principal::Impersonated { .. } => true,
168 }
169 }
170
171 pub async fn has_llm_subscription(
172 &self,
173 db: &MutexGuard<'_, DbHandle>,
174 ) -> anyhow::Result<bool> {
175 if self.is_staff() {
176 return Ok(true);
177 }
178
179 let user_id = self.user_id();
180
181 Ok(db.has_active_billing_subscription(user_id).await?)
182 }
183
184 pub async fn current_plan(
185 &self,
186 _db: &MutexGuard<'_, DbHandle>,
187 ) -> anyhow::Result<proto::Plan> {
188 if self.is_staff() {
189 Ok(proto::Plan::ZedPro)
190 } else {
191 Ok(proto::Plan::Free)
192 }
193 }
194
195 fn user_id(&self) -> UserId {
196 match &self.principal {
197 Principal::User(user) => user.id,
198 Principal::Impersonated { user, .. } => user.id,
199 }
200 }
201
202 pub fn email(&self) -> Option<String> {
203 match &self.principal {
204 Principal::User(user) => user.email_address.clone(),
205 Principal::Impersonated { user, .. } => user.email_address.clone(),
206 }
207 }
208}
209
210impl Debug for Session {
211 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
212 let mut result = f.debug_struct("Session");
213 match &self.principal {
214 Principal::User(user) => {
215 result.field("user", &user.github_login);
216 }
217 Principal::Impersonated { user, admin } => {
218 result.field("user", &user.github_login);
219 result.field("impersonator", &admin.github_login);
220 }
221 }
222 result.field("connection_id", &self.connection_id).finish()
223 }
224}
225
226struct DbHandle(Arc<Database>);
227
228impl Deref for DbHandle {
229 type Target = Database;
230
231 fn deref(&self) -> &Self::Target {
232 self.0.as_ref()
233 }
234}
235
/// The collaboration RPC server: owns the [`Peer`], the shared [`ConnectionPool`] and the
/// table of message handlers registered in `Server::new`.
pub struct Server {
    /// This server instance's id; behind a mutex so the test-only `reset` can replace it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Type-erased handlers keyed by the `TypeId` of the message type each one handles.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Set to `true` by `teardown`; every connection task subscribes to it and stops when it
    /// changes. The test-only `reset` sends `false` again.
    teardown: watch::Sender<bool>,
}

/// A lock on the connection pool that cannot be sent across threads.
///
/// `PhantomData<Rc<()>>` makes this type `!Send`, so any future holding it across an
/// `.await` is itself `!Send` — presumably to keep handlers from holding the pool lock
/// while suspended. In test builds, dropping the guard runs `check_invariants`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}

/// A serializable view of the server; the connection pool stays locked for as long as
/// the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    /// Serialized as the underlying [`ConnectionPool`] through `serialize_deref`.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
256
257pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
258where
259 S: Serializer,
260 T: Deref<Target = U>,
261 U: Serialize,
262{
263 Serialize::serialize(value.deref(), serializer)
264}
265
266impl Server {
267 pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
268 let mut server = Self {
269 id: parking_lot::Mutex::new(id),
270 peer: Peer::new(id.0 as u32),
271 app_state: app_state.clone(),
272 connection_pool: Default::default(),
273 handlers: Default::default(),
274 teardown: watch::channel(false).0,
275 };
276
277 server
278 .add_request_handler(ping)
279 .add_request_handler(create_room)
280 .add_request_handler(join_room)
281 .add_request_handler(rejoin_room)
282 .add_request_handler(leave_room)
283 .add_request_handler(set_room_participant_role)
284 .add_request_handler(call)
285 .add_request_handler(cancel_call)
286 .add_message_handler(decline_call)
287 .add_request_handler(update_participant_location)
288 .add_request_handler(share_project)
289 .add_message_handler(unshare_project)
290 .add_request_handler(join_project)
291 .add_message_handler(leave_project)
292 .add_request_handler(update_project)
293 .add_request_handler(update_worktree)
294 .add_message_handler(start_language_server)
295 .add_message_handler(update_language_server)
296 .add_message_handler(update_diagnostic_summary)
297 .add_message_handler(update_worktree_settings)
298 .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
299 .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
300 .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
301 .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
302 .add_request_handler(forward_find_search_candidates_request)
303 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
304 .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
305 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
306 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
307 .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
308 .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
309 .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
310 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
311 .add_request_handler(forward_read_only_project_request::<proto::GitBranches>)
312 .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
313 .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
314 .add_request_handler(
315 forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
316 )
317 .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
318 .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
319 .add_request_handler(
320 forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
321 )
322 .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
323 .add_request_handler(
324 forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
325 )
326 .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
327 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
328 .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
329 .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
330 .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
331 .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
332 .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
333 .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
334 .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
335 .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
336 .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
337 .add_request_handler(
338 forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
339 )
340 .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
341 .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
342 .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
343 .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
344 .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
345 .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
346 .add_message_handler(create_buffer_for_peer)
347 .add_request_handler(update_buffer)
348 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
349 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
350 .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
351 .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
352 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
353 .add_request_handler(get_users)
354 .add_request_handler(fuzzy_search_users)
355 .add_request_handler(request_contact)
356 .add_request_handler(remove_contact)
357 .add_request_handler(respond_to_contact_request)
358 .add_message_handler(subscribe_to_channels)
359 .add_request_handler(create_channel)
360 .add_request_handler(delete_channel)
361 .add_request_handler(invite_channel_member)
362 .add_request_handler(remove_channel_member)
363 .add_request_handler(set_channel_member_role)
364 .add_request_handler(set_channel_visibility)
365 .add_request_handler(rename_channel)
366 .add_request_handler(join_channel_buffer)
367 .add_request_handler(leave_channel_buffer)
368 .add_message_handler(update_channel_buffer)
369 .add_request_handler(rejoin_channel_buffers)
370 .add_request_handler(get_channel_members)
371 .add_request_handler(respond_to_channel_invite)
372 .add_request_handler(join_channel)
373 .add_request_handler(join_channel_chat)
374 .add_message_handler(leave_channel_chat)
375 .add_request_handler(send_channel_message)
376 .add_request_handler(remove_channel_message)
377 .add_request_handler(update_channel_message)
378 .add_request_handler(get_channel_messages)
379 .add_request_handler(get_channel_messages_by_id)
380 .add_request_handler(get_notifications)
381 .add_request_handler(mark_notification_as_read)
382 .add_request_handler(move_channel)
383 .add_request_handler(follow)
384 .add_message_handler(unfollow)
385 .add_message_handler(update_followers)
386 .add_request_handler(get_private_user_info)
387 .add_request_handler(get_llm_api_token)
388 .add_request_handler(accept_terms_of_service)
389 .add_message_handler(acknowledge_channel_message)
390 .add_message_handler(acknowledge_buffer_version)
391 .add_request_handler(get_supermaven_api_key)
392 .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
393 .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
394 .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
395 .add_request_handler(forward_mutating_project_request::<proto::Stage>)
396 .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
397 .add_request_handler(forward_mutating_project_request::<proto::Commit>)
398 .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
399 .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
400 .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
401 .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
402 .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
403 .add_message_handler(update_context)
404 .add_request_handler({
405 let app_state = app_state.clone();
406 move |request, response, session| {
407 let app_state = app_state.clone();
408 async move {
409 count_language_model_tokens(request, response, session, &app_state.config)
410 .await
411 }
412 }
413 })
414 .add_request_handler(get_cached_embeddings)
415 .add_request_handler({
416 let app_state = app_state.clone();
417 move |request, response, session| {
418 compute_embeddings(
419 request,
420 response,
421 session,
422 app_state.config.openai_api_key.clone(),
423 )
424 }
425 });
426
427 Arc::new(server)
428 }
429
    /// Schedules cleanup of old resources (per the `CLEANUP_TIMEOUT` comment, those left
    /// behind by terminated server instances — confirm against `stale_server_resource_ids`).
    ///
    /// Spawns a detached task, instrumented with a "start server" span, that:
    /// 1. waits [`CLEANUP_TIMEOUT`];
    /// 2. fetches stale room and channel ids for this environment relative to this server's id;
    /// 3. for each stale channel, clears stale buffer collaborators and pushes the refreshed
    ///    collaborator list to the remaining connections;
    /// 4. for each stale room, clears stale participants, broadcasts the refreshed room (and its
    ///    channel, if any), sends `CallCanceled` to users whose calls were canceled, refreshes the
    ///    contacts of affected users, and deletes the LiveKit room once it has no participants;
    /// 5. deletes the stale server records. This runs even if step 2 failed.
    ///
    /// Returns `Ok(())` immediately; failures inside the task are only logged (`trace_err`).
    pub async fn start(&self) -> Result<()> {
        // Copy everything the task needs so the `async move` block doesn't borrow `self`.
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // Created here, awaited inside the spawned task.
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Tell the remaining collaborators of each stale channel buffer who is
                    // still present.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Populated only if clearing the room's stale participants succeeds;
                        // otherwise the steps below are no-ops for this room.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            // Both removed participants and callees of canceled calls need
                            // their contact status refreshed.
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // The pool lock is confined to this block; there is no await inside it.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Both lookups are awaited before the pool lock is taken; the lock is
                        // then held only for synchronous sends to accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Only rooms left with no participants are removed from LiveKit.
                        if let Some(live_kit) = livekit_client.as_ref() {
                            if delete_livekit_room {
                                live_kit.delete_room(livekit_room).await.trace_err();
                            }
                        }
                    }
                }

                // Runs regardless of whether the stale-resource lookup above succeeded.
                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
575
576 pub fn teardown(&self) {
577 self.peer.teardown();
578 self.connection_pool.lock().reset();
579 let _ = self.teardown.send(true);
580 }
581
582 #[cfg(test)]
583 pub fn reset(&self, id: ServerId) {
584 self.teardown();
585 *self.id.lock() = id;
586 self.peer.reset(id.0 as u32);
587 let _ = self.teardown.send(false);
588 }
589
590 #[cfg(test)]
591 pub fn id(&self) -> ServerId {
592 *self.id.lock()
593 }
594
595 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
596 where
597 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
598 Fut: 'static + Send + Future<Output = Result<()>>,
599 M: EnvelopedMessage,
600 {
601 let prev_handler = self.handlers.insert(
602 TypeId::of::<M>(),
603 Box::new(move |envelope, session| {
604 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
605 let received_at = envelope.received_at;
606 tracing::info!("message received");
607 let start_time = Instant::now();
608 let future = (handler)(*envelope, session);
609 async move {
610 let result = future.await;
611 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
612 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
613 let queue_duration_ms = total_duration_ms - processing_duration_ms;
614 let payload_type = M::NAME;
615
616 match result {
617 Err(error) => {
618 tracing::error!(
619 ?error,
620 total_duration_ms,
621 processing_duration_ms,
622 queue_duration_ms,
623 payload_type,
624 "error handling message"
625 )
626 }
627 Ok(()) => tracing::info!(
628 total_duration_ms,
629 processing_duration_ms,
630 queue_duration_ms,
631 "finished handling message"
632 ),
633 }
634 }
635 .boxed()
636 }),
637 );
638 if prev_handler.is_some() {
639 panic!("registered a handler for the same message twice");
640 }
641 self
642 }
643
644 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
645 where
646 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
647 Fut: 'static + Send + Future<Output = Result<()>>,
648 M: EnvelopedMessage,
649 {
650 self.add_handler(move |envelope, session| handler(envelope.payload, session));
651 self
652 }
653
654 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
655 where
656 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
657 Fut: Send + Future<Output = Result<()>>,
658 M: RequestMessage,
659 {
660 let handler = Arc::new(handler);
661 self.add_handler(move |envelope, session| {
662 let receipt = envelope.receipt();
663 let handler = handler.clone();
664 async move {
665 let peer = session.peer.clone();
666 let responded = Arc::new(AtomicBool::default());
667 let response = Response {
668 peer: peer.clone(),
669 responded: responded.clone(),
670 receipt,
671 };
672 match (handler)(envelope.payload, response, session).await {
673 Ok(()) => {
674 if responded.load(std::sync::atomic::Ordering::SeqCst) {
675 Ok(())
676 } else {
677 Err(anyhow!("handler did not send a response"))?
678 }
679 }
680 Err(error) => {
681 let proto_err = match &error {
682 Error::Internal(err) => err.to_proto(),
683 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
684 };
685 peer.respond_with_error(receipt, proto_err)?;
686 Err(error)
687 }
688 }
689 }
690 })
691 }
692
    /// Builds the future that drives one client connection from handshake to sign-out.
    ///
    /// When polled, the future:
    /// 1. returns immediately if the server is already tearing down;
    /// 2. registers `connection` with the peer and records its id on the span;
    /// 3. builds the connection's [`Session`] (HTTP client, optional Supermaven client, and a
    ///    per-session database handle);
    /// 4. sends the initial client update (hello, contacts, channels, pending call, ...);
    /// 5. dispatches incoming messages to the registered handlers until the server tears
    ///    down, the I/O task finishes, or the incoming stream ends.
    ///
    /// After an I/O finish or a closed stream it runs `connection_lost`. On teardown it
    /// returns without doing so.
    ///
    /// NOTE(review): the early returns in steps 3-4 happen after `add_connection` (and, for
    /// step 4, possibly after the user was added to the connection pool), yet skip
    /// `connection_lost`. Confirm that this state is cleaned up elsewhere.
    // Every argument is per-connection context supplied by the caller.
    #[allow(clippy::too_many_arguments)]
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> {
        let this = self.clone();
        // Fields declared empty here are filled in below once their values are known.
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            geoip_country_code=field::Empty
        );
        principal.update_span(&span);
        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        // Subscribed before the future runs, so a teardown that happens afterwards is seen.
        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            // Only built when a Supermaven admin API key is configured.
            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(
                |supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
                    supermaven_admin_api_key.to_string(),
                    http_client.clone(),
                )),
            );

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                http_client,
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this
                .send_initial_client_update(
                    connection_id,
                    &principal,
                    zed_version,
                    send_connection_id,
                    &session,
                )
                .await
            {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // At most 256 handlers run concurrently per connection: a permit is acquired
            // before the next message is read and released when its handler completes.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Biased: arms are polled in order, so teardown wins over I/O, which wins over
                // driving in-flight handlers, which wins over reading a new message.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                // The permit travels with the handler and is released when it finishes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background messages run detached; foreground ones are driven
                                // by this loop through `foreground_message_handlers`.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // In-flight foreground handlers are cancelled before signing out.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
842
843 async fn send_initial_client_update(
844 &self,
845 connection_id: ConnectionId,
846 principal: &Principal,
847 zed_version: ZedVersion,
848 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
849 session: &Session,
850 ) -> Result<()> {
851 self.peer.send(
852 connection_id,
853 proto::Hello {
854 peer_id: Some(connection_id.into()),
855 },
856 )?;
857 tracing::info!("sent hello message");
858 if let Some(send_connection_id) = send_connection_id.take() {
859 let _ = send_connection_id.send(connection_id);
860 }
861
862 match principal {
863 Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
864 if !user.connected_once {
865 self.peer.send(connection_id, proto::ShowContacts {})?;
866 self.app_state
867 .db
868 .set_user_connected_once(user.id, true)
869 .await?;
870 }
871
872 update_user_plan(user.id, session).await?;
873
874 let contacts = self.app_state.db.get_contacts(user.id).await?;
875
876 {
877 let mut pool = self.connection_pool.lock();
878 pool.add_connection(connection_id, user.id, user.admin, zed_version);
879 self.peer.send(
880 connection_id,
881 build_initial_contacts_update(contacts, &pool),
882 )?;
883 }
884
885 if should_auto_subscribe_to_channels(zed_version) {
886 subscribe_user_to_channels(user.id, session).await?;
887 }
888
889 if let Some(incoming_call) =
890 self.app_state.db.incoming_call_for_user(user.id).await?
891 {
892 self.peer.send(connection_id, incoming_call)?;
893 }
894
895 update_user_contacts(user.id, session).await?;
896 }
897 }
898
899 Ok(())
900 }
901
902 pub async fn invite_code_redeemed(
903 self: &Arc<Self>,
904 inviter_id: UserId,
905 invitee_id: UserId,
906 ) -> Result<()> {
907 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
908 if let Some(code) = &user.invite_code {
909 let pool = self.connection_pool.lock();
910 let invitee_contact = contact_for_user(invitee_id, false, &pool);
911 for connection_id in pool.user_connection_ids(inviter_id) {
912 self.peer.send(
913 connection_id,
914 proto::UpdateContacts {
915 contacts: vec![invitee_contact.clone()],
916 ..Default::default()
917 },
918 )?;
919 self.peer.send(
920 connection_id,
921 proto::UpdateInviteInfo {
922 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
923 count: user.invite_count as u32,
924 },
925 )?;
926 }
927 }
928 }
929 Ok(())
930 }
931
932 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
933 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
934 if let Some(invite_code) = &user.invite_code {
935 let pool = self.connection_pool.lock();
936 for connection_id in pool.user_connection_ids(user_id) {
937 self.peer.send(
938 connection_id,
939 proto::UpdateInviteInfo {
940 url: format!(
941 "{}{}",
942 self.app_state.config.invite_link_prefix, invite_code
943 ),
944 count: user.invite_count as u32,
945 },
946 )?;
947 }
948 }
949 }
950 Ok(())
951 }
952
953 pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
954 let pool = self.connection_pool.lock();
955 for connection_id in pool.user_connection_ids(user_id) {
956 self.peer
957 .send(connection_id, proto::RefreshLlmToken {})
958 .trace_err();
959 }
960 }
961
962 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
963 ServerSnapshot {
964 connection_pool: ConnectionPoolGuard {
965 guard: self.connection_pool.lock(),
966 _not_send: PhantomData,
967 },
968 peer: &self.peer,
969 }
970 }
971}
972
973impl<'a> Deref for ConnectionPoolGuard<'a> {
974 type Target = ConnectionPool;
975
976 fn deref(&self) -> &Self::Target {
977 &self.guard
978 }
979}
980
981impl<'a> DerefMut for ConnectionPoolGuard<'a> {
982 fn deref_mut(&mut self) -> &mut Self::Target {
983 &mut self.guard
984 }
985}
986
987impl<'a> Drop for ConnectionPoolGuard<'a> {
988 fn drop(&mut self) {
989 #[cfg(test)]
990 self.check_invariants();
991 }
992}
993
994fn broadcast<F>(
995 sender_id: Option<ConnectionId>,
996 receiver_ids: impl IntoIterator<Item = ConnectionId>,
997 mut f: F,
998) where
999 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1000{
1001 for receiver_id in receiver_ids {
1002 if Some(receiver_id) != sender_id {
1003 if let Err(error) = f(receiver_id) {
1004 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1005 }
1006 }
1007 }
1008}
1009
1010pub struct ProtocolVersion(u32);
1011
1012impl Header for ProtocolVersion {
1013 fn name() -> &'static HeaderName {
1014 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1015 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1016 }
1017
1018 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1019 where
1020 Self: Sized,
1021 I: Iterator<Item = &'i axum::http::HeaderValue>,
1022 {
1023 let version = values
1024 .next()
1025 .ok_or_else(axum::headers::Error::invalid)?
1026 .to_str()
1027 .map_err(|_| axum::headers::Error::invalid())?
1028 .parse()
1029 .map_err(|_| axum::headers::Error::invalid())?;
1030 Ok(Self(version))
1031 }
1032
1033 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1034 values.extend([self.0.to_string().parse().unwrap()]);
1035 }
1036}
1037
1038pub struct AppVersionHeader(SemanticVersion);
1039impl Header for AppVersionHeader {
1040 fn name() -> &'static HeaderName {
1041 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1042 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1043 }
1044
1045 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1046 where
1047 Self: Sized,
1048 I: Iterator<Item = &'i axum::http::HeaderValue>,
1049 {
1050 let version = values
1051 .next()
1052 .ok_or_else(axum::headers::Error::invalid)?
1053 .to_str()
1054 .map_err(|_| axum::headers::Error::invalid())?
1055 .parse()
1056 .map_err(|_| axum::headers::Error::invalid())?;
1057 Ok(Self(version))
1058 }
1059
1060 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1061 values.extend([self.0.to_string().parse().unwrap()]);
1062 }
1063}
1064
/// Builds the router for the collaboration endpoints.
///
/// - `GET /rpc` is handled by `handle_websocket_request`. It is wrapped by a
///   layer that inserts the `AppState` extension and runs
///   `auth::validate_header` as middleware.
/// - `GET /metrics` is handled by `handle_metrics`.
///
/// Both routes receive `Arc<Server>` through the final `Extension` layer.
pub fn routes(server: Arc<Server>) -> Router<(), Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        // `Router::layer` only wraps routes registered before the call, so the
        // auth middleware and the `AppState` extension apply to `/rpc` only.
        // Within `ServiceBuilder`, earlier layers wrap later ones, so the
        // `AppState` extension is in place before `validate_header` runs.
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        // Added last so it wraps both `/rpc` and `/metrics`.
        .layer(Extension(server))
}
1076
/// Axum handler for `GET /rpc`: checks the client's version headers and, if
/// the client is compatible, upgrades the request to a websocket whose
/// connection is handed to `Server::handle_connection`.
///
/// Responds with `426 Upgrade Required` when the protocol version header does
/// not equal `rpc::PROTOCOL_VERSION`, when the app version header is absent,
/// or when the app version cannot collaborate (`ZedVersion::can_collaborate`).
#[allow(clippy::too_many_arguments)]
pub async fn handle_websocket_request(
    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
    app_version_header: Option<TypedHeader<AppVersionHeader>>,
    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
    Extension(server): Extension<Arc<Server>>,
    Extension(principal): Extension<Principal>,
    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
    system_id_header: Option<TypedHeader<SystemIdHeader>>,
    ws: WebSocketUpgrade,
) -> axum::response::Response {
    if protocol_version != rpc::PROTOCOL_VERSION {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }

    // The app version header is optional at the extractor level but required
    // here; unwrap the typed header into a `ZedVersion`.
    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "no version header found".to_string(),
        )
            .into_response();
    };

    if !version.can_collaborate() {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }

    // Stringify the peer address up front so the upgrade closure owns it.
    let socket_address = socket_address.to_string();
    ws.on_upgrade(move |socket| {
        // Adapt axum's websocket stream/sink to tungstenite message types,
        // which is what `rpc::Connection` is built from.
        let socket = socket
            .map_ok(to_tungstenite_message)
            .err_into()
            .with(|message| async move { to_axum_message(message) });
        let connection = Connection::new(Box::pin(socket));
        async move {
            server
                .handle_connection(
                    connection,
                    socket_address,
                    principal,
                    version,
                    country_code_header.map(|header| header.to_string()),
                    system_id_header.map(|header| header.to_string()),
                    None,
                    Executor::Production,
                )
                .await;
        }
    })
}
1135
1136pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1137 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1138 let connections_metric = CONNECTIONS_METRIC
1139 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1140
1141 let connections = server
1142 .connection_pool
1143 .lock()
1144 .connections()
1145 .filter(|connection| !connection.admin)
1146 .count();
1147 connections_metric.set(connections as _);
1148
1149 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1150 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1151 register_int_gauge!(
1152 "shared_projects",
1153 "number of open projects with one or more guests"
1154 )
1155 .unwrap()
1156 });
1157
1158 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1159 shared_projects_metric.set(shared_projects as _);
1160
1161 let encoder = prometheus::TextEncoder::new();
1162 let metric_families = prometheus::gather();
1163 let encoded_metrics = encoder
1164 .encode_to_string(&metric_families)
1165 .map_err(|err| anyhow!("{}", err))?;
1166 Ok(encoded_metrics)
1167}
1168
/// Cleans up after a client connection is lost.
///
/// Immediately, the peer connection is disconnected, the connection is removed
/// from the connection pool, and the database is told the connection was lost
/// (a failure there is only passed to `trace_err`).
///
/// It then waits for whichever happens first:
/// - `RECONNECT_TIMEOUT` elapses: the session leaves its room and channel
///   buffers; if the user has no connection left online,
///   `decline_call(None, ..)` is issued and any returned room is broadcast;
///   finally the user's contacts are updated.
/// - `teardown` changes: nothing further is done.
///
/// # Errors
/// Returns an error if removing the connection from the pool fails or if
/// updating contacts fails. Other cleanup failures are only passed to
/// `trace_err`.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls branches in the order written, so the timeout
    // branch is checked first when both are ready. The timeout gives the
    // client a window to reconnect before its resources are released.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // This connection was already removed from the pool above, so this
            // checks whether any *other* connection of the user is online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1215
1216/// Acknowledges a ping from a client, used to keep the connection alive.
1217async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1218 response.send(proto::Ack {})?;
1219 Ok(())
1220}
1221
1222/// Creates a new room for calling (outside of channels)
1223async fn create_room(
1224 _request: proto::CreateRoom,
1225 response: Response<proto::CreateRoom>,
1226 session: Session,
1227) -> Result<()> {
1228 let livekit_room = nanoid::nanoid!(30);
1229
1230 let live_kit_connection_info = util::maybe!(async {
1231 let live_kit = session.app_state.livekit_client.as_ref();
1232 let live_kit = live_kit?;
1233 let user_id = session.user_id().to_string();
1234
1235 let token = live_kit
1236 .room_token(&livekit_room, &user_id.to_string())
1237 .trace_err()?;
1238
1239 Some(proto::LiveKitConnectionInfo {
1240 server_url: live_kit.url().into(),
1241 token,
1242 can_publish: true,
1243 })
1244 })
1245 .await;
1246
1247 let room = session
1248 .db()
1249 .await
1250 .create_room(session.user_id(), session.connection_id, &livekit_room)
1251 .await?;
1252
1253 response.send(proto::CreateRoomResponse {
1254 room: Some(room.clone()),
1255 live_kit_connection_info,
1256 })?;
1257
1258 update_user_contacts(session.user_id(), &session).await?;
1259 Ok(())
1260}
1261
1262/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1263async fn join_room(
1264 request: proto::JoinRoom,
1265 response: Response<proto::JoinRoom>,
1266 session: Session,
1267) -> Result<()> {
1268 let room_id = RoomId::from_proto(request.id);
1269
1270 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1271
1272 if let Some(channel_id) = channel_id {
1273 return join_channel_internal(channel_id, Box::new(response), session).await;
1274 }
1275
1276 let joined_room = {
1277 let room = session
1278 .db()
1279 .await
1280 .join_room(room_id, session.user_id(), session.connection_id)
1281 .await?;
1282 room_updated(&room.room, &session.peer);
1283 room.into_inner()
1284 };
1285
1286 for connection_id in session
1287 .connection_pool()
1288 .await
1289 .user_connection_ids(session.user_id())
1290 {
1291 session
1292 .peer
1293 .send(
1294 connection_id,
1295 proto::CallCanceled {
1296 room_id: room_id.to_proto(),
1297 },
1298 )
1299 .trace_err();
1300 }
1301
1302 let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1303 {
1304 live_kit
1305 .room_token(
1306 &joined_room.room.livekit_room,
1307 &session.user_id().to_string(),
1308 )
1309 .trace_err()
1310 .map(|token| proto::LiveKitConnectionInfo {
1311 server_url: live_kit.url().into(),
1312 token,
1313 can_publish: true,
1314 })
1315 } else {
1316 None
1317 };
1318
1319 response.send(proto::JoinRoomResponse {
1320 room: Some(joined_room.room),
1321 channel_id: None,
1322 live_kit_connection_info,
1323 })?;
1324
1325 update_user_contacts(session.user_id(), &session).await?;
1326 Ok(())
1327}
1328
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// The database call re-associates the user with the room under the new
/// `session.connection_id`. The response lists the room, the reshared
/// projects with their collaborators, and the rejoined projects. Afterwards:
/// - the room update is broadcast;
/// - for each reshared project, its collaborators are sent
///   `UpdateProjectCollaborator` (old peer id -> new peer id) and, except for
///   this connection, an `UpdateProject` with the project's worktrees
///   (presumably these are projects this user hosts — confirm);
/// - `notify_rejoined_projects` resynchronizes the rejoined projects;
/// - the channel (if the room belongs to one) and the user's contacts are
///   refreshed.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    let room;
    let channel;
    // Scoped so only `room` and `channel` leave this block; the database
    // result is fully consumed (via `into_inner`) before the connection pool
    // is locked below.
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell each collaborator about this connection's new peer id.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1420
1421fn notify_rejoined_projects(
1422 rejoined_projects: &mut Vec<RejoinedProject>,
1423 session: &Session,
1424) -> Result<()> {
1425 for project in rejoined_projects.iter() {
1426 for collaborator in &project.collaborators {
1427 session
1428 .peer
1429 .send(
1430 collaborator.connection_id,
1431 proto::UpdateProjectCollaborator {
1432 project_id: project.id.to_proto(),
1433 old_peer_id: Some(project.old_connection_id.into()),
1434 new_peer_id: Some(session.connection_id.into()),
1435 },
1436 )
1437 .trace_err();
1438 }
1439 }
1440
1441 for project in rejoined_projects {
1442 for worktree in mem::take(&mut project.worktrees) {
1443 // Stream this worktree's entries.
1444 let message = proto::UpdateWorktree {
1445 project_id: project.id.to_proto(),
1446 worktree_id: worktree.id,
1447 abs_path: worktree.abs_path.clone(),
1448 root_name: worktree.root_name,
1449 updated_entries: worktree.updated_entries,
1450 removed_entries: worktree.removed_entries,
1451 scan_id: worktree.scan_id,
1452 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1453 updated_repositories: worktree.updated_repositories,
1454 removed_repositories: worktree.removed_repositories,
1455 };
1456 for update in proto::split_worktree_update(message) {
1457 session.peer.send(session.connection_id, update.clone())?;
1458 }
1459
1460 // Stream this worktree's diagnostics.
1461 for summary in worktree.diagnostic_summaries {
1462 session.peer.send(
1463 session.connection_id,
1464 proto::UpdateDiagnosticSummary {
1465 project_id: project.id.to_proto(),
1466 worktree_id: worktree.id,
1467 summary: Some(summary),
1468 },
1469 )?;
1470 }
1471
1472 for settings_file in worktree.settings_files {
1473 session.peer.send(
1474 session.connection_id,
1475 proto::UpdateWorktreeSettings {
1476 project_id: project.id.to_proto(),
1477 worktree_id: worktree.id,
1478 path: settings_file.path,
1479 content: Some(settings_file.content),
1480 kind: Some(settings_file.kind.to_proto().into()),
1481 },
1482 )?;
1483 }
1484 }
1485
1486 for language_server in &project.language_servers {
1487 session.peer.send(
1488 session.connection_id,
1489 proto::UpdateLanguageServer {
1490 project_id: project.id.to_proto(),
1491 language_server_id: language_server.id,
1492 variant: Some(
1493 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1494 proto::LspDiskBasedDiagnosticsUpdated {},
1495 ),
1496 ),
1497 },
1498 )?;
1499 }
1500 }
1501 Ok(())
1502}
1503
1504/// leave room disconnects from the room.
1505async fn leave_room(
1506 _: proto::LeaveRoom,
1507 response: Response<proto::LeaveRoom>,
1508 session: Session,
1509) -> Result<()> {
1510 leave_room_for_session(&session, session.connection_id).await?;
1511 response.send(proto::Ack {})?;
1512 Ok(())
1513}
1514
1515/// Updates the permissions of someone else in the room.
1516async fn set_room_participant_role(
1517 request: proto::SetRoomParticipantRole,
1518 response: Response<proto::SetRoomParticipantRole>,
1519 session: Session,
1520) -> Result<()> {
1521 let user_id = UserId::from_proto(request.user_id);
1522 let role = ChannelRole::from(request.role());
1523
1524 let (livekit_room, can_publish) = {
1525 let room = session
1526 .db()
1527 .await
1528 .set_room_participant_role(
1529 session.user_id(),
1530 RoomId::from_proto(request.room_id),
1531 user_id,
1532 role,
1533 )
1534 .await?;
1535
1536 let livekit_room = room.livekit_room.clone();
1537 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1538 room_updated(&room, &session.peer);
1539 (livekit_room, can_publish)
1540 };
1541
1542 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1543 live_kit
1544 .update_participant(
1545 livekit_room.clone(),
1546 request.user_id.to_string(),
1547 livekit_server::proto::ParticipantPermission {
1548 can_subscribe: true,
1549 can_publish,
1550 can_publish_data: can_publish,
1551 hidden: false,
1552 recorder: false,
1553 },
1554 )
1555 .await
1556 .trace_err();
1557 }
1558
1559 response.send(proto::Ack {})?;
1560 Ok(())
1561}
1562
/// Call someone else into the current room.
///
/// The callee must be a contact of the caller. The call is recorded in the
/// database and the room update broadcast, the callee's contacts are
/// refreshed, and the incoming-call request is sent to every connection of the
/// callee concurrently. The first successful reply completes the request with
/// an `Ack`. If no connection replies successfully, the call is marked failed
/// in the database, the room is broadcast again, the callee's contacts are
/// refreshed, and an error is returned.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: Session,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        // The `?` converts the anyhow error into this module's error type and
        // returns early.
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(room, &session.peer);
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring every connection of the callee at once. The connection-pool guard
    // is a temporary of this statement, so it is released before any await.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    // Stop at the first connection that answers successfully.
    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    // No connection answered: roll the call back.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1631
1632/// Cancel an outgoing call.
1633async fn cancel_call(
1634 request: proto::CancelCall,
1635 response: Response<proto::CancelCall>,
1636 session: Session,
1637) -> Result<()> {
1638 let called_user_id = UserId::from_proto(request.called_user_id);
1639 let room_id = RoomId::from_proto(request.room_id);
1640 {
1641 let room = session
1642 .db()
1643 .await
1644 .cancel_call(room_id, session.connection_id, called_user_id)
1645 .await?;
1646 room_updated(&room, &session.peer);
1647 }
1648
1649 for connection_id in session
1650 .connection_pool()
1651 .await
1652 .user_connection_ids(called_user_id)
1653 {
1654 session
1655 .peer
1656 .send(
1657 connection_id,
1658 proto::CallCanceled {
1659 room_id: room_id.to_proto(),
1660 },
1661 )
1662 .trace_err();
1663 }
1664 response.send(proto::Ack {})?;
1665
1666 update_user_contacts(called_user_id, &session).await?;
1667 Ok(())
1668}
1669
1670/// Decline an incoming call.
1671async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1672 let room_id = RoomId::from_proto(message.room_id);
1673 {
1674 let room = session
1675 .db()
1676 .await
1677 .decline_call(Some(room_id), session.user_id())
1678 .await?
1679 .ok_or_else(|| anyhow!("failed to decline call"))?;
1680 room_updated(&room, &session.peer);
1681 }
1682
1683 for connection_id in session
1684 .connection_pool()
1685 .await
1686 .user_connection_ids(session.user_id())
1687 {
1688 session
1689 .peer
1690 .send(
1691 connection_id,
1692 proto::CallCanceled {
1693 room_id: room_id.to_proto(),
1694 },
1695 )
1696 .trace_err();
1697 }
1698 update_user_contacts(session.user_id(), &session).await?;
1699 Ok(())
1700}
1701
1702/// Updates other participants in the room with your current location.
1703async fn update_participant_location(
1704 request: proto::UpdateParticipantLocation,
1705 response: Response<proto::UpdateParticipantLocation>,
1706 session: Session,
1707) -> Result<()> {
1708 let room_id = RoomId::from_proto(request.room_id);
1709 let location = request
1710 .location
1711 .ok_or_else(|| anyhow!("invalid location"))?;
1712
1713 let db = session.db().await;
1714 let room = db
1715 .update_room_participant_location(room_id, session.connection_id, location)
1716 .await?;
1717
1718 room_updated(&room, &session.peer);
1719 response.send(proto::Ack {})?;
1720 Ok(())
1721}
1722
1723/// Share a project into the room.
1724async fn share_project(
1725 request: proto::ShareProject,
1726 response: Response<proto::ShareProject>,
1727 session: Session,
1728) -> Result<()> {
1729 let (project_id, room) = &*session
1730 .db()
1731 .await
1732 .share_project(
1733 RoomId::from_proto(request.room_id),
1734 session.connection_id,
1735 &request.worktrees,
1736 request.is_ssh_project,
1737 )
1738 .await?;
1739 response.send(proto::ShareProjectResponse {
1740 project_id: project_id.to_proto(),
1741 })?;
1742 room_updated(room, &session.peer);
1743
1744 Ok(())
1745}
1746
1747/// Unshare a project from the room.
1748async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1749 let project_id = ProjectId::from_proto(message.project_id);
1750 unshare_project_internal(project_id, session.connection_id, &session).await
1751}
1752
1753async fn unshare_project_internal(
1754 project_id: ProjectId,
1755 connection_id: ConnectionId,
1756 session: &Session,
1757) -> Result<()> {
1758 let delete = {
1759 let room_guard = session
1760 .db()
1761 .await
1762 .unshare_project(project_id, connection_id)
1763 .await?;
1764
1765 let (delete, room, guest_connection_ids) = &*room_guard;
1766
1767 let message = proto::UnshareProject {
1768 project_id: project_id.to_proto(),
1769 };
1770
1771 broadcast(
1772 Some(connection_id),
1773 guest_connection_ids.iter().copied(),
1774 |conn_id| session.peer.send(conn_id, message.clone()),
1775 );
1776 if let Some(room) = room {
1777 room_updated(room, &session.peer);
1778 }
1779
1780 *delete
1781 };
1782
1783 if delete {
1784 let db = session.db().await;
1785 db.delete_project(project_id).await?;
1786 }
1787
1788 Ok(())
1789}
1790
/// Joins someone else's shared project.
///
/// Registers this connection on the project in the database (which yields the
/// project and the replica id assigned to it), releases the database handle,
/// and delegates the rest of the join to `join_project_internal`.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: Session,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    // Release the database handle before streaming the project to the guest.
    drop(db);
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
1809
/// Abstracts over how a `JoinProjectResponse` is delivered, so that
/// `join_project_internal` is not tied to one concrete response type.
trait JoinProjectInternalResponse {
    /// Sends the join response back to the requester.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
/// A direct `JoinProject` request replies through its own `Response`.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Path call resolves to the inherent `Response::send`, not to this
        // trait method of the same name.
        Response::<proto::JoinProject>::send(self, result)
    }
}
1818
1819fn join_project_internal(
1820 response: impl JoinProjectInternalResponse,
1821 session: Session,
1822 project: &mut Project,
1823 replica_id: &ReplicaId,
1824) -> Result<()> {
1825 let collaborators = project
1826 .collaborators
1827 .iter()
1828 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1829 .map(|collaborator| collaborator.to_proto())
1830 .collect::<Vec<_>>();
1831 let project_id = project.id;
1832 let guest_user_id = session.user_id();
1833
1834 let worktrees = project
1835 .worktrees
1836 .iter()
1837 .map(|(id, worktree)| proto::WorktreeMetadata {
1838 id: *id,
1839 root_name: worktree.root_name.clone(),
1840 visible: worktree.visible,
1841 abs_path: worktree.abs_path.clone(),
1842 })
1843 .collect::<Vec<_>>();
1844
1845 let add_project_collaborator = proto::AddProjectCollaborator {
1846 project_id: project_id.to_proto(),
1847 collaborator: Some(proto::Collaborator {
1848 peer_id: Some(session.connection_id.into()),
1849 replica_id: replica_id.0 as u32,
1850 user_id: guest_user_id.to_proto(),
1851 is_host: false,
1852 }),
1853 };
1854
1855 for collaborator in &collaborators {
1856 session
1857 .peer
1858 .send(
1859 collaborator.peer_id.unwrap().into(),
1860 add_project_collaborator.clone(),
1861 )
1862 .trace_err();
1863 }
1864
1865 // First, we send the metadata associated with each worktree.
1866 response.send(proto::JoinProjectResponse {
1867 project_id: project.id.0 as u64,
1868 worktrees: worktrees.clone(),
1869 replica_id: replica_id.0 as u32,
1870 collaborators: collaborators.clone(),
1871 language_servers: project.language_servers.clone(),
1872 role: project.role.into(),
1873 })?;
1874
1875 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1876 // Stream this worktree's entries.
1877 let message = proto::UpdateWorktree {
1878 project_id: project_id.to_proto(),
1879 worktree_id,
1880 abs_path: worktree.abs_path.clone(),
1881 root_name: worktree.root_name,
1882 updated_entries: worktree.entries,
1883 removed_entries: Default::default(),
1884 scan_id: worktree.scan_id,
1885 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1886 updated_repositories: worktree.repository_entries.into_values().collect(),
1887 removed_repositories: Default::default(),
1888 };
1889 for update in proto::split_worktree_update(message) {
1890 session.peer.send(session.connection_id, update.clone())?;
1891 }
1892
1893 // Stream this worktree's diagnostics.
1894 for summary in worktree.diagnostic_summaries {
1895 session.peer.send(
1896 session.connection_id,
1897 proto::UpdateDiagnosticSummary {
1898 project_id: project_id.to_proto(),
1899 worktree_id: worktree.id,
1900 summary: Some(summary),
1901 },
1902 )?;
1903 }
1904
1905 for settings_file in worktree.settings_files {
1906 session.peer.send(
1907 session.connection_id,
1908 proto::UpdateWorktreeSettings {
1909 project_id: project_id.to_proto(),
1910 worktree_id: worktree.id,
1911 path: settings_file.path,
1912 content: Some(settings_file.content),
1913 kind: Some(settings_file.kind.to_proto() as i32),
1914 },
1915 )?;
1916 }
1917 }
1918
1919 for language_server in &project.language_servers {
1920 session.peer.send(
1921 session.connection_id,
1922 proto::UpdateLanguageServer {
1923 project_id: project_id.to_proto(),
1924 language_server_id: language_server.id,
1925 variant: Some(
1926 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1927 proto::LspDiskBasedDiagnosticsUpdated {},
1928 ),
1929 ),
1930 },
1931 )?;
1932 }
1933
1934 Ok(())
1935}
1936
1937/// Leave someone elses shared project.
1938async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1939 let sender_id = session.connection_id;
1940 let project_id = ProjectId::from_proto(request.project_id);
1941 let db = session.db().await;
1942
1943 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1944 tracing::info!(
1945 %project_id,
1946 "leave project"
1947 );
1948
1949 project_left(project, &session);
1950 if let Some(room) = room {
1951 room_updated(room, &session.peer);
1952 }
1953
1954 Ok(())
1955}
1956
1957/// Updates other participants with changes to the project
1958async fn update_project(
1959 request: proto::UpdateProject,
1960 response: Response<proto::UpdateProject>,
1961 session: Session,
1962) -> Result<()> {
1963 let project_id = ProjectId::from_proto(request.project_id);
1964 let (room, guest_connection_ids) = &*session
1965 .db()
1966 .await
1967 .update_project(project_id, session.connection_id, &request.worktrees)
1968 .await?;
1969 broadcast(
1970 Some(session.connection_id),
1971 guest_connection_ids.iter().copied(),
1972 |connection_id| {
1973 session
1974 .peer
1975 .forward_send(session.connection_id, connection_id, request.clone())
1976 },
1977 );
1978 if let Some(room) = room {
1979 room_updated(room, &session.peer);
1980 }
1981 response.send(proto::Ack {})?;
1982
1983 Ok(())
1984}
1985
1986/// Updates other participants with changes to the worktree
1987async fn update_worktree(
1988 request: proto::UpdateWorktree,
1989 response: Response<proto::UpdateWorktree>,
1990 session: Session,
1991) -> Result<()> {
1992 let guest_connection_ids = session
1993 .db()
1994 .await
1995 .update_worktree(&request, session.connection_id)
1996 .await?;
1997
1998 broadcast(
1999 Some(session.connection_id),
2000 guest_connection_ids.iter().copied(),
2001 |connection_id| {
2002 session
2003 .peer
2004 .forward_send(session.connection_id, connection_id, request.clone())
2005 },
2006 );
2007 response.send(proto::Ack {})?;
2008 Ok(())
2009}
2010
2011/// Updates other participants with changes to the diagnostics
2012async fn update_diagnostic_summary(
2013 message: proto::UpdateDiagnosticSummary,
2014 session: Session,
2015) -> Result<()> {
2016 let guest_connection_ids = session
2017 .db()
2018 .await
2019 .update_diagnostic_summary(&message, session.connection_id)
2020 .await?;
2021
2022 broadcast(
2023 Some(session.connection_id),
2024 guest_connection_ids.iter().copied(),
2025 |connection_id| {
2026 session
2027 .peer
2028 .forward_send(session.connection_id, connection_id, message.clone())
2029 },
2030 );
2031
2032 Ok(())
2033}
2034
2035/// Updates other participants with changes to the worktree settings
2036async fn update_worktree_settings(
2037 message: proto::UpdateWorktreeSettings,
2038 session: Session,
2039) -> Result<()> {
2040 let guest_connection_ids = session
2041 .db()
2042 .await
2043 .update_worktree_settings(&message, session.connection_id)
2044 .await?;
2045
2046 broadcast(
2047 Some(session.connection_id),
2048 guest_connection_ids.iter().copied(),
2049 |connection_id| {
2050 session
2051 .peer
2052 .forward_send(session.connection_id, connection_id, message.clone())
2053 },
2054 );
2055
2056 Ok(())
2057}
2058
2059/// Notify other participants that a language server has started.
2060async fn start_language_server(
2061 request: proto::StartLanguageServer,
2062 session: Session,
2063) -> Result<()> {
2064 let guest_connection_ids = session
2065 .db()
2066 .await
2067 .start_language_server(&request, session.connection_id)
2068 .await?;
2069
2070 broadcast(
2071 Some(session.connection_id),
2072 guest_connection_ids.iter().copied(),
2073 |connection_id| {
2074 session
2075 .peer
2076 .forward_send(session.connection_id, connection_id, request.clone())
2077 },
2078 );
2079 Ok(())
2080}
2081
2082/// Notify other participants that a language server has changed.
2083async fn update_language_server(
2084 request: proto::UpdateLanguageServer,
2085 session: Session,
2086) -> Result<()> {
2087 let project_id = ProjectId::from_proto(request.project_id);
2088 let project_connection_ids = session
2089 .db()
2090 .await
2091 .project_connection_ids(project_id, session.connection_id, true)
2092 .await?;
2093 broadcast(
2094 Some(session.connection_id),
2095 project_connection_ids.iter().copied(),
2096 |connection_id| {
2097 session
2098 .peer
2099 .forward_send(session.connection_id, connection_id, request.clone())
2100 },
2101 );
2102 Ok(())
2103}
2104
2105/// forward a project request to the host. These requests should be read only
2106/// as guests are allowed to send them.
2107async fn forward_read_only_project_request<T>(
2108 request: T,
2109 response: Response<T>,
2110 session: Session,
2111) -> Result<()>
2112where
2113 T: EntityMessage + RequestMessage,
2114{
2115 let project_id = ProjectId::from_proto(request.remote_entity_id());
2116 let host_connection_id = session
2117 .db()
2118 .await
2119 .host_for_read_only_project_request(project_id, session.connection_id)
2120 .await?;
2121 let payload = session
2122 .peer
2123 .forward_request(session.connection_id, host_connection_id, request)
2124 .await?;
2125 response.send(payload)?;
2126 Ok(())
2127}
2128
2129async fn forward_find_search_candidates_request(
2130 request: proto::FindSearchCandidates,
2131 response: Response<proto::FindSearchCandidates>,
2132 session: Session,
2133) -> Result<()> {
2134 let project_id = ProjectId::from_proto(request.remote_entity_id());
2135 let host_connection_id = session
2136 .db()
2137 .await
2138 .host_for_read_only_project_request(project_id, session.connection_id)
2139 .await?;
2140 let payload = session
2141 .peer
2142 .forward_request(session.connection_id, host_connection_id, request)
2143 .await?;
2144 response.send(payload)?;
2145 Ok(())
2146}
2147
2148/// forward a project request to the host. These requests are disallowed
2149/// for guests.
2150async fn forward_mutating_project_request<T>(
2151 request: T,
2152 response: Response<T>,
2153 session: Session,
2154) -> Result<()>
2155where
2156 T: EntityMessage + RequestMessage,
2157{
2158 let project_id = ProjectId::from_proto(request.remote_entity_id());
2159
2160 let host_connection_id = session
2161 .db()
2162 .await
2163 .host_for_mutating_project_request(project_id, session.connection_id)
2164 .await?;
2165 let payload = session
2166 .peer
2167 .forward_request(session.connection_id, host_connection_id, request)
2168 .await?;
2169 response.send(payload)?;
2170 Ok(())
2171}
2172
2173/// Notify other participants that a new buffer has been created
2174async fn create_buffer_for_peer(
2175 request: proto::CreateBufferForPeer,
2176 session: Session,
2177) -> Result<()> {
2178 session
2179 .db()
2180 .await
2181 .check_user_is_project_host(
2182 ProjectId::from_proto(request.project_id),
2183 session.connection_id,
2184 )
2185 .await?;
2186 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2187 session
2188 .peer
2189 .forward_send(session.connection_id, peer_id.into(), request)?;
2190 Ok(())
2191}
2192
2193/// Notify other participants that a buffer has been updated. This is
2194/// allowed for guests as long as the update is limited to selections.
2195async fn update_buffer(
2196 request: proto::UpdateBuffer,
2197 response: Response<proto::UpdateBuffer>,
2198 session: Session,
2199) -> Result<()> {
2200 let project_id = ProjectId::from_proto(request.project_id);
2201 let mut capability = Capability::ReadOnly;
2202
2203 for op in request.operations.iter() {
2204 match op.variant {
2205 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2206 Some(_) => capability = Capability::ReadWrite,
2207 }
2208 }
2209
2210 let host = {
2211 let guard = session
2212 .db()
2213 .await
2214 .connections_for_buffer_update(project_id, session.connection_id, capability)
2215 .await?;
2216
2217 let (host, guests) = &*guard;
2218
2219 broadcast(
2220 Some(session.connection_id),
2221 guests.clone(),
2222 |connection_id| {
2223 session
2224 .peer
2225 .forward_send(session.connection_id, connection_id, request.clone())
2226 },
2227 );
2228
2229 *host
2230 };
2231
2232 if host != session.connection_id {
2233 session
2234 .peer
2235 .forward_request(session.connection_id, host, request.clone())
2236 .await?;
2237 }
2238
2239 response.send(proto::Ack {})?;
2240 Ok(())
2241}
2242
2243async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2244 let project_id = ProjectId::from_proto(message.project_id);
2245
2246 let operation = message.operation.as_ref().context("invalid operation")?;
2247 let capability = match operation.variant.as_ref() {
2248 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2249 if let Some(buffer_op) = buffer_op.operation.as_ref() {
2250 match buffer_op.variant {
2251 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2252 Capability::ReadOnly
2253 }
2254 _ => Capability::ReadWrite,
2255 }
2256 } else {
2257 Capability::ReadWrite
2258 }
2259 }
2260 Some(_) => Capability::ReadWrite,
2261 None => Capability::ReadOnly,
2262 };
2263
2264 let guard = session
2265 .db()
2266 .await
2267 .connections_for_buffer_update(project_id, session.connection_id, capability)
2268 .await?;
2269
2270 let (host, guests) = &*guard;
2271
2272 broadcast(
2273 Some(session.connection_id),
2274 guests.iter().chain([host]).copied(),
2275 |connection_id| {
2276 session
2277 .peer
2278 .forward_send(session.connection_id, connection_id, message.clone())
2279 },
2280 );
2281
2282 Ok(())
2283}
2284
2285/// Notify other participants that a project has been updated.
2286async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2287 request: T,
2288 session: Session,
2289) -> Result<()> {
2290 let project_id = ProjectId::from_proto(request.remote_entity_id());
2291 let project_connection_ids = session
2292 .db()
2293 .await
2294 .project_connection_ids(project_id, session.connection_id, false)
2295 .await?;
2296
2297 broadcast(
2298 Some(session.connection_id),
2299 project_connection_ids.iter().copied(),
2300 |connection_id| {
2301 session
2302 .peer
2303 .forward_send(session.connection_id, connection_id, request.clone())
2304 },
2305 );
2306 Ok(())
2307}
2308
2309/// Start following another user in a call.
2310async fn follow(
2311 request: proto::Follow,
2312 response: Response<proto::Follow>,
2313 session: Session,
2314) -> Result<()> {
2315 let room_id = RoomId::from_proto(request.room_id);
2316 let project_id = request.project_id.map(ProjectId::from_proto);
2317 let leader_id = request
2318 .leader_id
2319 .ok_or_else(|| anyhow!("invalid leader id"))?
2320 .into();
2321 let follower_id = session.connection_id;
2322
2323 session
2324 .db()
2325 .await
2326 .check_room_participants(room_id, leader_id, session.connection_id)
2327 .await?;
2328
2329 let response_payload = session
2330 .peer
2331 .forward_request(session.connection_id, leader_id, request)
2332 .await?;
2333 response.send(response_payload)?;
2334
2335 if let Some(project_id) = project_id {
2336 let room = session
2337 .db()
2338 .await
2339 .follow(room_id, project_id, leader_id, follower_id)
2340 .await?;
2341 room_updated(&room, &session.peer);
2342 }
2343
2344 Ok(())
2345}
2346
2347/// Stop following another user in a call.
2348async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2349 let room_id = RoomId::from_proto(request.room_id);
2350 let project_id = request.project_id.map(ProjectId::from_proto);
2351 let leader_id = request
2352 .leader_id
2353 .ok_or_else(|| anyhow!("invalid leader id"))?
2354 .into();
2355 let follower_id = session.connection_id;
2356
2357 session
2358 .db()
2359 .await
2360 .check_room_participants(room_id, leader_id, session.connection_id)
2361 .await?;
2362
2363 session
2364 .peer
2365 .forward_send(session.connection_id, leader_id, request)?;
2366
2367 if let Some(project_id) = project_id {
2368 let room = session
2369 .db()
2370 .await
2371 .unfollow(room_id, project_id, leader_id, follower_id)
2372 .await?;
2373 room_updated(&room, &session.peer);
2374 }
2375
2376 Ok(())
2377}
2378
2379/// Notify everyone following you of your current location.
2380async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2381 let room_id = RoomId::from_proto(request.room_id);
2382 let database = session.db.lock().await;
2383
2384 let connection_ids = if let Some(project_id) = request.project_id {
2385 let project_id = ProjectId::from_proto(project_id);
2386 database
2387 .project_connection_ids(project_id, session.connection_id, true)
2388 .await?
2389 } else {
2390 database
2391 .room_connection_ids(room_id, session.connection_id)
2392 .await?
2393 };
2394
2395 // For now, don't send view update messages back to that view's current leader.
2396 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2397 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2398 _ => None,
2399 });
2400
2401 for connection_id in connection_ids.iter().cloned() {
2402 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2403 session
2404 .peer
2405 .forward_send(session.connection_id, connection_id, request.clone())?;
2406 }
2407 }
2408 Ok(())
2409}
2410
2411/// Get public data about users.
2412async fn get_users(
2413 request: proto::GetUsers,
2414 response: Response<proto::GetUsers>,
2415 session: Session,
2416) -> Result<()> {
2417 let user_ids = request
2418 .user_ids
2419 .into_iter()
2420 .map(UserId::from_proto)
2421 .collect();
2422 let users = session
2423 .db()
2424 .await
2425 .get_users_by_ids(user_ids)
2426 .await?
2427 .into_iter()
2428 .map(|user| proto::User {
2429 id: user.id.to_proto(),
2430 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2431 github_login: user.github_login,
2432 email: user.email_address,
2433 name: user.name,
2434 })
2435 .collect();
2436 response.send(proto::UsersResponse { users })?;
2437 Ok(())
2438}
2439
2440/// Search for users (to invite) buy Github login
2441async fn fuzzy_search_users(
2442 request: proto::FuzzySearchUsers,
2443 response: Response<proto::FuzzySearchUsers>,
2444 session: Session,
2445) -> Result<()> {
2446 let query = request.query;
2447 let users = match query.len() {
2448 0 => vec![],
2449 1 | 2 => session
2450 .db()
2451 .await
2452 .get_user_by_github_login(&query)
2453 .await?
2454 .into_iter()
2455 .collect(),
2456 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2457 };
2458 let users = users
2459 .into_iter()
2460 .filter(|user| user.id != session.user_id())
2461 .map(|user| proto::User {
2462 id: user.id.to_proto(),
2463 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2464 github_login: user.github_login,
2465 name: user.name,
2466 email: user.email_address,
2467 })
2468 .collect();
2469 response.send(proto::UsersResponse { users })?;
2470 Ok(())
2471}
2472
2473/// Send a contact request to another user.
2474async fn request_contact(
2475 request: proto::RequestContact,
2476 response: Response<proto::RequestContact>,
2477 session: Session,
2478) -> Result<()> {
2479 let requester_id = session.user_id();
2480 let responder_id = UserId::from_proto(request.responder_id);
2481 if requester_id == responder_id {
2482 return Err(anyhow!("cannot add yourself as a contact"))?;
2483 }
2484
2485 let notifications = session
2486 .db()
2487 .await
2488 .send_contact_request(requester_id, responder_id)
2489 .await?;
2490
2491 // Update outgoing contact requests of requester
2492 let mut update = proto::UpdateContacts::default();
2493 update.outgoing_requests.push(responder_id.to_proto());
2494 for connection_id in session
2495 .connection_pool()
2496 .await
2497 .user_connection_ids(requester_id)
2498 {
2499 session.peer.send(connection_id, update.clone())?;
2500 }
2501
2502 // Update incoming contact requests of responder
2503 let mut update = proto::UpdateContacts::default();
2504 update
2505 .incoming_requests
2506 .push(proto::IncomingContactRequest {
2507 requester_id: requester_id.to_proto(),
2508 });
2509 let connection_pool = session.connection_pool().await;
2510 for connection_id in connection_pool.user_connection_ids(responder_id) {
2511 session.peer.send(connection_id, update.clone())?;
2512 }
2513
2514 send_notifications(&connection_pool, &session.peer, notifications);
2515
2516 response.send(proto::Ack {})?;
2517 Ok(())
2518}
2519
2520/// Accept or decline a contact request
2521async fn respond_to_contact_request(
2522 request: proto::RespondToContactRequest,
2523 response: Response<proto::RespondToContactRequest>,
2524 session: Session,
2525) -> Result<()> {
2526 let responder_id = session.user_id();
2527 let requester_id = UserId::from_proto(request.requester_id);
2528 let db = session.db().await;
2529 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2530 db.dismiss_contact_notification(responder_id, requester_id)
2531 .await?;
2532 } else {
2533 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2534
2535 let notifications = db
2536 .respond_to_contact_request(responder_id, requester_id, accept)
2537 .await?;
2538 let requester_busy = db.is_user_busy(requester_id).await?;
2539 let responder_busy = db.is_user_busy(responder_id).await?;
2540
2541 let pool = session.connection_pool().await;
2542 // Update responder with new contact
2543 let mut update = proto::UpdateContacts::default();
2544 if accept {
2545 update
2546 .contacts
2547 .push(contact_for_user(requester_id, requester_busy, &pool));
2548 }
2549 update
2550 .remove_incoming_requests
2551 .push(requester_id.to_proto());
2552 for connection_id in pool.user_connection_ids(responder_id) {
2553 session.peer.send(connection_id, update.clone())?;
2554 }
2555
2556 // Update requester with new contact
2557 let mut update = proto::UpdateContacts::default();
2558 if accept {
2559 update
2560 .contacts
2561 .push(contact_for_user(responder_id, responder_busy, &pool));
2562 }
2563 update
2564 .remove_outgoing_requests
2565 .push(responder_id.to_proto());
2566
2567 for connection_id in pool.user_connection_ids(requester_id) {
2568 session.peer.send(connection_id, update.clone())?;
2569 }
2570
2571 send_notifications(&pool, &session.peer, notifications);
2572 }
2573
2574 response.send(proto::Ack {})?;
2575 Ok(())
2576}
2577
2578/// Remove a contact.
2579async fn remove_contact(
2580 request: proto::RemoveContact,
2581 response: Response<proto::RemoveContact>,
2582 session: Session,
2583) -> Result<()> {
2584 let requester_id = session.user_id();
2585 let responder_id = UserId::from_proto(request.user_id);
2586 let db = session.db().await;
2587 let (contact_accepted, deleted_notification_id) =
2588 db.remove_contact(requester_id, responder_id).await?;
2589
2590 let pool = session.connection_pool().await;
2591 // Update outgoing contact requests of requester
2592 let mut update = proto::UpdateContacts::default();
2593 if contact_accepted {
2594 update.remove_contacts.push(responder_id.to_proto());
2595 } else {
2596 update
2597 .remove_outgoing_requests
2598 .push(responder_id.to_proto());
2599 }
2600 for connection_id in pool.user_connection_ids(requester_id) {
2601 session.peer.send(connection_id, update.clone())?;
2602 }
2603
2604 // Update incoming contact requests of responder
2605 let mut update = proto::UpdateContacts::default();
2606 if contact_accepted {
2607 update.remove_contacts.push(requester_id.to_proto());
2608 } else {
2609 update
2610 .remove_incoming_requests
2611 .push(requester_id.to_proto());
2612 }
2613 for connection_id in pool.user_connection_ids(responder_id) {
2614 session.peer.send(connection_id, update.clone())?;
2615 if let Some(notification_id) = deleted_notification_id {
2616 session.peer.send(
2617 connection_id,
2618 proto::DeleteNotification {
2619 notification_id: notification_id.to_proto(),
2620 },
2621 )?;
2622 }
2623 }
2624
2625 response.send(proto::Ack {})?;
2626 Ok(())
2627}
2628
2629fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2630 version.0.minor() < 139
2631}
2632
2633async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
2634 let plan = session.current_plan(&session.db().await).await?;
2635
2636 session
2637 .peer
2638 .send(
2639 session.connection_id,
2640 proto::UpdateUserPlan { plan: plan.into() },
2641 )
2642 .trace_err();
2643
2644 Ok(())
2645}
2646
2647async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2648 subscribe_user_to_channels(session.user_id(), &session).await?;
2649 Ok(())
2650}
2651
2652async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2653 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2654 let mut pool = session.connection_pool().await;
2655 for membership in &channels_for_user.channel_memberships {
2656 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2657 }
2658 session.peer.send(
2659 session.connection_id,
2660 build_update_user_channels(&channels_for_user),
2661 )?;
2662 session.peer.send(
2663 session.connection_id,
2664 build_channels_update(channels_for_user),
2665 )?;
2666 Ok(())
2667}
2668
2669/// Creates a new channel.
2670async fn create_channel(
2671 request: proto::CreateChannel,
2672 response: Response<proto::CreateChannel>,
2673 session: Session,
2674) -> Result<()> {
2675 let db = session.db().await;
2676
2677 let parent_id = request.parent_id.map(ChannelId::from_proto);
2678 let (channel, membership) = db
2679 .create_channel(&request.name, parent_id, session.user_id())
2680 .await?;
2681
2682 let root_id = channel.root_id();
2683 let channel = Channel::from_model(channel);
2684
2685 response.send(proto::CreateChannelResponse {
2686 channel: Some(channel.to_proto()),
2687 parent_id: request.parent_id,
2688 })?;
2689
2690 let mut connection_pool = session.connection_pool().await;
2691 if let Some(membership) = membership {
2692 connection_pool.subscribe_to_channel(
2693 membership.user_id,
2694 membership.channel_id,
2695 membership.role,
2696 );
2697 let update = proto::UpdateUserChannels {
2698 channel_memberships: vec![proto::ChannelMembership {
2699 channel_id: membership.channel_id.to_proto(),
2700 role: membership.role.into(),
2701 }],
2702 ..Default::default()
2703 };
2704 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2705 session.peer.send(connection_id, update.clone())?;
2706 }
2707 }
2708
2709 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2710 if !role.can_see_channel(channel.visibility) {
2711 continue;
2712 }
2713
2714 let update = proto::UpdateChannels {
2715 channels: vec![channel.to_proto()],
2716 ..Default::default()
2717 };
2718 session.peer.send(connection_id, update.clone())?;
2719 }
2720
2721 Ok(())
2722}
2723
2724/// Delete a channel
2725async fn delete_channel(
2726 request: proto::DeleteChannel,
2727 response: Response<proto::DeleteChannel>,
2728 session: Session,
2729) -> Result<()> {
2730 let db = session.db().await;
2731
2732 let channel_id = request.channel_id;
2733 let (root_channel, removed_channels) = db
2734 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2735 .await?;
2736 response.send(proto::Ack {})?;
2737
2738 // Notify members of removed channels
2739 let mut update = proto::UpdateChannels::default();
2740 update
2741 .delete_channels
2742 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2743
2744 let connection_pool = session.connection_pool().await;
2745 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2746 session.peer.send(connection_id, update.clone())?;
2747 }
2748
2749 Ok(())
2750}
2751
2752/// Invite someone to join a channel.
2753async fn invite_channel_member(
2754 request: proto::InviteChannelMember,
2755 response: Response<proto::InviteChannelMember>,
2756 session: Session,
2757) -> Result<()> {
2758 let db = session.db().await;
2759 let channel_id = ChannelId::from_proto(request.channel_id);
2760 let invitee_id = UserId::from_proto(request.user_id);
2761 let InviteMemberResult {
2762 channel,
2763 notifications,
2764 } = db
2765 .invite_channel_member(
2766 channel_id,
2767 invitee_id,
2768 session.user_id(),
2769 request.role().into(),
2770 )
2771 .await?;
2772
2773 let update = proto::UpdateChannels {
2774 channel_invitations: vec![channel.to_proto()],
2775 ..Default::default()
2776 };
2777
2778 let connection_pool = session.connection_pool().await;
2779 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2780 session.peer.send(connection_id, update.clone())?;
2781 }
2782
2783 send_notifications(&connection_pool, &session.peer, notifications);
2784
2785 response.send(proto::Ack {})?;
2786 Ok(())
2787}
2788
2789/// remove someone from a channel
2790async fn remove_channel_member(
2791 request: proto::RemoveChannelMember,
2792 response: Response<proto::RemoveChannelMember>,
2793 session: Session,
2794) -> Result<()> {
2795 let db = session.db().await;
2796 let channel_id = ChannelId::from_proto(request.channel_id);
2797 let member_id = UserId::from_proto(request.user_id);
2798
2799 let RemoveChannelMemberResult {
2800 membership_update,
2801 notification_id,
2802 } = db
2803 .remove_channel_member(channel_id, member_id, session.user_id())
2804 .await?;
2805
2806 let mut connection_pool = session.connection_pool().await;
2807 notify_membership_updated(
2808 &mut connection_pool,
2809 membership_update,
2810 member_id,
2811 &session.peer,
2812 );
2813 for connection_id in connection_pool.user_connection_ids(member_id) {
2814 if let Some(notification_id) = notification_id {
2815 session
2816 .peer
2817 .send(
2818 connection_id,
2819 proto::DeleteNotification {
2820 notification_id: notification_id.to_proto(),
2821 },
2822 )
2823 .trace_err();
2824 }
2825 }
2826
2827 response.send(proto::Ack {})?;
2828 Ok(())
2829}
2830
2831/// Toggle the channel between public and private.
2832/// Care is taken to maintain the invariant that public channels only descend from public channels,
2833/// (though members-only channels can appear at any point in the hierarchy).
2834async fn set_channel_visibility(
2835 request: proto::SetChannelVisibility,
2836 response: Response<proto::SetChannelVisibility>,
2837 session: Session,
2838) -> Result<()> {
2839 let db = session.db().await;
2840 let channel_id = ChannelId::from_proto(request.channel_id);
2841 let visibility = request.visibility().into();
2842
2843 let channel_model = db
2844 .set_channel_visibility(channel_id, visibility, session.user_id())
2845 .await?;
2846 let root_id = channel_model.root_id();
2847 let channel = Channel::from_model(channel_model);
2848
2849 let mut connection_pool = session.connection_pool().await;
2850 for (user_id, role) in connection_pool
2851 .channel_user_ids(root_id)
2852 .collect::<Vec<_>>()
2853 .into_iter()
2854 {
2855 let update = if role.can_see_channel(channel.visibility) {
2856 connection_pool.subscribe_to_channel(user_id, channel_id, role);
2857 proto::UpdateChannels {
2858 channels: vec![channel.to_proto()],
2859 ..Default::default()
2860 }
2861 } else {
2862 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2863 proto::UpdateChannels {
2864 delete_channels: vec![channel.id.to_proto()],
2865 ..Default::default()
2866 }
2867 };
2868
2869 for connection_id in connection_pool.user_connection_ids(user_id) {
2870 session.peer.send(connection_id, update.clone())?;
2871 }
2872 }
2873
2874 response.send(proto::Ack {})?;
2875 Ok(())
2876}
2877
2878/// Alter the role for a user in the channel.
2879async fn set_channel_member_role(
2880 request: proto::SetChannelMemberRole,
2881 response: Response<proto::SetChannelMemberRole>,
2882 session: Session,
2883) -> Result<()> {
2884 let db = session.db().await;
2885 let channel_id = ChannelId::from_proto(request.channel_id);
2886 let member_id = UserId::from_proto(request.user_id);
2887 let result = db
2888 .set_channel_member_role(
2889 channel_id,
2890 session.user_id(),
2891 member_id,
2892 request.role().into(),
2893 )
2894 .await?;
2895
2896 match result {
2897 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2898 let mut connection_pool = session.connection_pool().await;
2899 notify_membership_updated(
2900 &mut connection_pool,
2901 membership_update,
2902 member_id,
2903 &session.peer,
2904 )
2905 }
2906 db::SetMemberRoleResult::InviteUpdated(channel) => {
2907 let update = proto::UpdateChannels {
2908 channel_invitations: vec![channel.to_proto()],
2909 ..Default::default()
2910 };
2911
2912 for connection_id in session
2913 .connection_pool()
2914 .await
2915 .user_connection_ids(member_id)
2916 {
2917 session.peer.send(connection_id, update.clone())?;
2918 }
2919 }
2920 }
2921
2922 response.send(proto::Ack {})?;
2923 Ok(())
2924}
2925
2926/// Change the name of a channel
2927async fn rename_channel(
2928 request: proto::RenameChannel,
2929 response: Response<proto::RenameChannel>,
2930 session: Session,
2931) -> Result<()> {
2932 let db = session.db().await;
2933 let channel_id = ChannelId::from_proto(request.channel_id);
2934 let channel_model = db
2935 .rename_channel(channel_id, session.user_id(), &request.name)
2936 .await?;
2937 let root_id = channel_model.root_id();
2938 let channel = Channel::from_model(channel_model);
2939
2940 response.send(proto::RenameChannelResponse {
2941 channel: Some(channel.to_proto()),
2942 })?;
2943
2944 let connection_pool = session.connection_pool().await;
2945 let update = proto::UpdateChannels {
2946 channels: vec![channel.to_proto()],
2947 ..Default::default()
2948 };
2949 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2950 if role.can_see_channel(channel.visibility) {
2951 session.peer.send(connection_id, update.clone())?;
2952 }
2953 }
2954
2955 Ok(())
2956}
2957
2958/// Move a channel to a new parent.
2959async fn move_channel(
2960 request: proto::MoveChannel,
2961 response: Response<proto::MoveChannel>,
2962 session: Session,
2963) -> Result<()> {
2964 let channel_id = ChannelId::from_proto(request.channel_id);
2965 let to = ChannelId::from_proto(request.to);
2966
2967 let (root_id, channels) = session
2968 .db()
2969 .await
2970 .move_channel(channel_id, to, session.user_id())
2971 .await?;
2972
2973 let connection_pool = session.connection_pool().await;
2974 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2975 let channels = channels
2976 .iter()
2977 .filter_map(|channel| {
2978 if role.can_see_channel(channel.visibility) {
2979 Some(channel.to_proto())
2980 } else {
2981 None
2982 }
2983 })
2984 .collect::<Vec<_>>();
2985 if channels.is_empty() {
2986 continue;
2987 }
2988
2989 let update = proto::UpdateChannels {
2990 channels,
2991 ..Default::default()
2992 };
2993
2994 session.peer.send(connection_id, update.clone())?;
2995 }
2996
2997 response.send(Ack {})?;
2998 Ok(())
2999}
3000
3001/// Get the list of channel members
3002async fn get_channel_members(
3003 request: proto::GetChannelMembers,
3004 response: Response<proto::GetChannelMembers>,
3005 session: Session,
3006) -> Result<()> {
3007 let db = session.db().await;
3008 let channel_id = ChannelId::from_proto(request.channel_id);
3009 let limit = if request.limit == 0 {
3010 u16::MAX as u64
3011 } else {
3012 request.limit
3013 };
3014 let (members, users) = db
3015 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3016 .await?;
3017 response.send(proto::GetChannelMembersResponse { members, users })?;
3018 Ok(())
3019}
3020
3021/// Accept or decline a channel invitation.
3022async fn respond_to_channel_invite(
3023 request: proto::RespondToChannelInvite,
3024 response: Response<proto::RespondToChannelInvite>,
3025 session: Session,
3026) -> Result<()> {
3027 let db = session.db().await;
3028 let channel_id = ChannelId::from_proto(request.channel_id);
3029 let RespondToChannelInvite {
3030 membership_update,
3031 notifications,
3032 } = db
3033 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3034 .await?;
3035
3036 let mut connection_pool = session.connection_pool().await;
3037 if let Some(membership_update) = membership_update {
3038 notify_membership_updated(
3039 &mut connection_pool,
3040 membership_update,
3041 session.user_id(),
3042 &session.peer,
3043 );
3044 } else {
3045 let update = proto::UpdateChannels {
3046 remove_channel_invitations: vec![channel_id.to_proto()],
3047 ..Default::default()
3048 };
3049
3050 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3051 session.peer.send(connection_id, update.clone())?;
3052 }
3053 };
3054
3055 send_notifications(&connection_pool, &session.peer, notifications);
3056
3057 response.send(proto::Ack {})?;
3058
3059 Ok(())
3060}
3061
3062/// Join the channels' room
3063async fn join_channel(
3064 request: proto::JoinChannel,
3065 response: Response<proto::JoinChannel>,
3066 session: Session,
3067) -> Result<()> {
3068 let channel_id = ChannelId::from_proto(request.channel_id);
3069 join_channel_internal(channel_id, Box::new(response), session).await
3070}
3071
/// A response handle that can be answered with a `proto::JoinRoomResponse`.
///
/// Implemented below for both `Response<proto::JoinChannel>` and
/// `Response<proto::JoinRoom>`, so `join_channel_internal` can reply to either
/// kind of request.
trait JoinChannelInternalResponse {
    /// Send `result` as the reply to the underlying request.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// Each impl delegates to `Response::<R>::send`. That is the `send` defined on
// `Response` itself (presumably an inherent method), not this trait's method,
// so the call does not recurse.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3085
/// Join the room belonging to channel `channel_id` and answer via `response`.
///
/// While holding the database guard it:
/// 1. evicts any stale room connection the database reports for this user;
/// 2. joins the channel's room, getting back the room, an optional membership
///    change, and the user's role in the channel;
/// 3. builds LiveKit connection info when a LiveKit client is configured;
/// 4. replies to the requester, applies any membership change, and broadcasts
///    the updated room.
///
/// After the guard is released it broadcasts the channel update and refreshes
/// the user's contacts. The channel update fails with "channel not returned"
/// if the joined room carries no channel.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database guard while leaving the stale room, then
            // re-acquire it. Presumably `leave_room_for_session` needs the
            // lock itself; TODO confirm.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // Token choice depends on the channel role:
        // - guests get a guest token and `can_publish: false`;
        // - everyone else gets a room token and `can_publish: true`.
        // A token error goes through `trace_err` and yields no LiveKit info.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        // Leaving this block drops both the connection-pool and database guards.
        joined_room
    };

    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3181
3182/// Start editing the channel notes
3183async fn join_channel_buffer(
3184 request: proto::JoinChannelBuffer,
3185 response: Response<proto::JoinChannelBuffer>,
3186 session: Session,
3187) -> Result<()> {
3188 let db = session.db().await;
3189 let channel_id = ChannelId::from_proto(request.channel_id);
3190
3191 let open_response = db
3192 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3193 .await?;
3194
3195 let collaborators = open_response.collaborators.clone();
3196 response.send(open_response)?;
3197
3198 let update = UpdateChannelBufferCollaborators {
3199 channel_id: channel_id.to_proto(),
3200 collaborators: collaborators.clone(),
3201 };
3202 channel_buffer_updated(
3203 session.connection_id,
3204 collaborators
3205 .iter()
3206 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3207 &update,
3208 &session.peer,
3209 );
3210
3211 Ok(())
3212}
3213
3214/// Edit the channel notes
3215async fn update_channel_buffer(
3216 request: proto::UpdateChannelBuffer,
3217 session: Session,
3218) -> Result<()> {
3219 let db = session.db().await;
3220 let channel_id = ChannelId::from_proto(request.channel_id);
3221
3222 let (collaborators, epoch, version) = db
3223 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3224 .await?;
3225
3226 channel_buffer_updated(
3227 session.connection_id,
3228 collaborators.clone(),
3229 &proto::UpdateChannelBuffer {
3230 channel_id: channel_id.to_proto(),
3231 operations: request.operations,
3232 },
3233 &session.peer,
3234 );
3235
3236 let pool = &*session.connection_pool().await;
3237
3238 let non_collaborators =
3239 pool.channel_connection_ids(channel_id)
3240 .filter_map(|(connection_id, _)| {
3241 if collaborators.contains(&connection_id) {
3242 None
3243 } else {
3244 Some(connection_id)
3245 }
3246 });
3247
3248 broadcast(None, non_collaborators, |peer_id| {
3249 session.peer.send(
3250 peer_id,
3251 proto::UpdateChannels {
3252 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3253 channel_id: channel_id.to_proto(),
3254 epoch: epoch as u64,
3255 version: version.clone(),
3256 }],
3257 ..Default::default()
3258 },
3259 )
3260 });
3261
3262 Ok(())
3263}
3264
3265/// Rejoin the channel notes after a connection blip
3266async fn rejoin_channel_buffers(
3267 request: proto::RejoinChannelBuffers,
3268 response: Response<proto::RejoinChannelBuffers>,
3269 session: Session,
3270) -> Result<()> {
3271 let db = session.db().await;
3272 let buffers = db
3273 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3274 .await?;
3275
3276 for rejoined_buffer in &buffers {
3277 let collaborators_to_notify = rejoined_buffer
3278 .buffer
3279 .collaborators
3280 .iter()
3281 .filter_map(|c| Some(c.peer_id?.into()));
3282 channel_buffer_updated(
3283 session.connection_id,
3284 collaborators_to_notify,
3285 &proto::UpdateChannelBufferCollaborators {
3286 channel_id: rejoined_buffer.buffer.channel_id,
3287 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3288 },
3289 &session.peer,
3290 );
3291 }
3292
3293 response.send(proto::RejoinChannelBuffersResponse {
3294 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3295 })?;
3296
3297 Ok(())
3298}
3299
3300/// Stop editing the channel notes
3301async fn leave_channel_buffer(
3302 request: proto::LeaveChannelBuffer,
3303 response: Response<proto::LeaveChannelBuffer>,
3304 session: Session,
3305) -> Result<()> {
3306 let db = session.db().await;
3307 let channel_id = ChannelId::from_proto(request.channel_id);
3308
3309 let left_buffer = db
3310 .leave_channel_buffer(channel_id, session.connection_id)
3311 .await?;
3312
3313 response.send(Ack {})?;
3314
3315 channel_buffer_updated(
3316 session.connection_id,
3317 left_buffer.connections,
3318 &proto::UpdateChannelBufferCollaborators {
3319 channel_id: channel_id.to_proto(),
3320 collaborators: left_buffer.collaborators,
3321 },
3322 &session.peer,
3323 );
3324
3325 Ok(())
3326}
3327
3328fn channel_buffer_updated<T: EnvelopedMessage>(
3329 sender_id: ConnectionId,
3330 collaborators: impl IntoIterator<Item = ConnectionId>,
3331 message: &T,
3332 peer: &Peer,
3333) {
3334 broadcast(Some(sender_id), collaborators, |peer_id| {
3335 peer.send(peer_id, message.clone())
3336 });
3337}
3338
3339fn send_notifications(
3340 connection_pool: &ConnectionPool,
3341 peer: &Peer,
3342 notifications: db::NotificationBatch,
3343) {
3344 for (user_id, notification) in notifications {
3345 for connection_id in connection_pool.user_connection_ids(user_id) {
3346 if let Err(error) = peer.send(
3347 connection_id,
3348 proto::AddNotification {
3349 notification: Some(notification.clone()),
3350 },
3351 ) {
3352 tracing::error!(
3353 "failed to send notification to {:?} {}",
3354 connection_id,
3355 error
3356 );
3357 }
3358 }
3359 }
3360}
3361
3362/// Send a message to the channel
3363async fn send_channel_message(
3364 request: proto::SendChannelMessage,
3365 response: Response<proto::SendChannelMessage>,
3366 session: Session,
3367) -> Result<()> {
3368 // Validate the message body.
3369 let body = request.body.trim().to_string();
3370 if body.len() > MAX_MESSAGE_LEN {
3371 return Err(anyhow!("message is too long"))?;
3372 }
3373 if body.is_empty() {
3374 return Err(anyhow!("message can't be blank"))?;
3375 }
3376
3377 // TODO: adjust mentions if body is trimmed
3378
3379 let timestamp = OffsetDateTime::now_utc();
3380 let nonce = request
3381 .nonce
3382 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3383
3384 let channel_id = ChannelId::from_proto(request.channel_id);
3385 let CreatedChannelMessage {
3386 message_id,
3387 participant_connection_ids,
3388 notifications,
3389 } = session
3390 .db()
3391 .await
3392 .create_channel_message(
3393 channel_id,
3394 session.user_id(),
3395 &body,
3396 &request.mentions,
3397 timestamp,
3398 nonce.clone().into(),
3399 request.reply_to_message_id.map(MessageId::from_proto),
3400 )
3401 .await?;
3402
3403 let message = proto::ChannelMessage {
3404 sender_id: session.user_id().to_proto(),
3405 id: message_id.to_proto(),
3406 body,
3407 mentions: request.mentions,
3408 timestamp: timestamp.unix_timestamp() as u64,
3409 nonce: Some(nonce),
3410 reply_to_message_id: request.reply_to_message_id,
3411 edited_at: None,
3412 };
3413 broadcast(
3414 Some(session.connection_id),
3415 participant_connection_ids.clone(),
3416 |connection| {
3417 session.peer.send(
3418 connection,
3419 proto::ChannelMessageSent {
3420 channel_id: channel_id.to_proto(),
3421 message: Some(message.clone()),
3422 },
3423 )
3424 },
3425 );
3426 response.send(proto::SendChannelMessageResponse {
3427 message: Some(message),
3428 })?;
3429
3430 let pool = &*session.connection_pool().await;
3431 let non_participants =
3432 pool.channel_connection_ids(channel_id)
3433 .filter_map(|(connection_id, _)| {
3434 if participant_connection_ids.contains(&connection_id) {
3435 None
3436 } else {
3437 Some(connection_id)
3438 }
3439 });
3440 broadcast(None, non_participants, |peer_id| {
3441 session.peer.send(
3442 peer_id,
3443 proto::UpdateChannels {
3444 latest_channel_message_ids: vec![proto::ChannelMessageId {
3445 channel_id: channel_id.to_proto(),
3446 message_id: message_id.to_proto(),
3447 }],
3448 ..Default::default()
3449 },
3450 )
3451 });
3452 send_notifications(pool, &session.peer, notifications);
3453
3454 Ok(())
3455}
3456
3457/// Delete a channel message
3458async fn remove_channel_message(
3459 request: proto::RemoveChannelMessage,
3460 response: Response<proto::RemoveChannelMessage>,
3461 session: Session,
3462) -> Result<()> {
3463 let channel_id = ChannelId::from_proto(request.channel_id);
3464 let message_id = MessageId::from_proto(request.message_id);
3465 let (connection_ids, existing_notification_ids) = session
3466 .db()
3467 .await
3468 .remove_channel_message(channel_id, message_id, session.user_id())
3469 .await?;
3470
3471 broadcast(
3472 Some(session.connection_id),
3473 connection_ids,
3474 move |connection| {
3475 session.peer.send(connection, request.clone())?;
3476
3477 for notification_id in &existing_notification_ids {
3478 session.peer.send(
3479 connection,
3480 proto::DeleteNotification {
3481 notification_id: (*notification_id).to_proto(),
3482 },
3483 )?;
3484 }
3485
3486 Ok(())
3487 },
3488 );
3489 response.send(proto::Ack {})?;
3490 Ok(())
3491}
3492
3493async fn update_channel_message(
3494 request: proto::UpdateChannelMessage,
3495 response: Response<proto::UpdateChannelMessage>,
3496 session: Session,
3497) -> Result<()> {
3498 let channel_id = ChannelId::from_proto(request.channel_id);
3499 let message_id = MessageId::from_proto(request.message_id);
3500 let updated_at = OffsetDateTime::now_utc();
3501 let UpdatedChannelMessage {
3502 message_id,
3503 participant_connection_ids,
3504 notifications,
3505 reply_to_message_id,
3506 timestamp,
3507 deleted_mention_notification_ids,
3508 updated_mention_notifications,
3509 } = session
3510 .db()
3511 .await
3512 .update_channel_message(
3513 channel_id,
3514 message_id,
3515 session.user_id(),
3516 request.body.as_str(),
3517 &request.mentions,
3518 updated_at,
3519 )
3520 .await?;
3521
3522 let nonce = request
3523 .nonce
3524 .clone()
3525 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3526
3527 let message = proto::ChannelMessage {
3528 sender_id: session.user_id().to_proto(),
3529 id: message_id.to_proto(),
3530 body: request.body.clone(),
3531 mentions: request.mentions.clone(),
3532 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3533 nonce: Some(nonce),
3534 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3535 edited_at: Some(updated_at.unix_timestamp() as u64),
3536 };
3537
3538 response.send(proto::Ack {})?;
3539
3540 let pool = &*session.connection_pool().await;
3541 broadcast(
3542 Some(session.connection_id),
3543 participant_connection_ids,
3544 |connection| {
3545 session.peer.send(
3546 connection,
3547 proto::ChannelMessageUpdate {
3548 channel_id: channel_id.to_proto(),
3549 message: Some(message.clone()),
3550 },
3551 )?;
3552
3553 for notification_id in &deleted_mention_notification_ids {
3554 session.peer.send(
3555 connection,
3556 proto::DeleteNotification {
3557 notification_id: (*notification_id).to_proto(),
3558 },
3559 )?;
3560 }
3561
3562 for notification in &updated_mention_notifications {
3563 session.peer.send(
3564 connection,
3565 proto::UpdateNotification {
3566 notification: Some(notification.clone()),
3567 },
3568 )?;
3569 }
3570
3571 Ok(())
3572 },
3573 );
3574
3575 send_notifications(pool, &session.peer, notifications);
3576
3577 Ok(())
3578}
3579
3580/// Mark a channel message as read
3581async fn acknowledge_channel_message(
3582 request: proto::AckChannelMessage,
3583 session: Session,
3584) -> Result<()> {
3585 let channel_id = ChannelId::from_proto(request.channel_id);
3586 let message_id = MessageId::from_proto(request.message_id);
3587 let notifications = session
3588 .db()
3589 .await
3590 .observe_channel_message(channel_id, session.user_id(), message_id)
3591 .await?;
3592 send_notifications(
3593 &*session.connection_pool().await,
3594 &session.peer,
3595 notifications,
3596 );
3597 Ok(())
3598}
3599
3600/// Mark a buffer version as synced
3601async fn acknowledge_buffer_version(
3602 request: proto::AckBufferOperation,
3603 session: Session,
3604) -> Result<()> {
3605 let buffer_id = BufferId::from_proto(request.buffer_id);
3606 session
3607 .db()
3608 .await
3609 .observe_buffer_version(
3610 buffer_id,
3611 session.user_id(),
3612 request.epoch as i32,
3613 &request.version,
3614 )
3615 .await?;
3616 Ok(())
3617}
3618
3619async fn count_language_model_tokens(
3620 request: proto::CountLanguageModelTokens,
3621 response: Response<proto::CountLanguageModelTokens>,
3622 session: Session,
3623 config: &Config,
3624) -> Result<()> {
3625 authorize_access_to_legacy_llm_endpoints(&session).await?;
3626
3627 let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3628 proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
3629 proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
3630 };
3631
3632 session
3633 .app_state
3634 .rate_limiter
3635 .check(&*rate_limit, session.user_id())
3636 .await?;
3637
3638 let result = match proto::LanguageModelProvider::from_i32(request.provider) {
3639 Some(proto::LanguageModelProvider::Google) => {
3640 let api_key = config
3641 .google_ai_api_key
3642 .as_ref()
3643 .context("no Google AI API key configured on the server")?;
3644 google_ai::count_tokens(
3645 session.http_client.as_ref(),
3646 google_ai::API_URL,
3647 api_key,
3648 serde_json::from_str(&request.request)?,
3649 )
3650 .await?
3651 }
3652 _ => return Err(anyhow!("unsupported provider"))?,
3653 };
3654
3655 response.send(proto::CountLanguageModelTokensResponse {
3656 token_count: result.total_tokens as u32,
3657 })?;
3658
3659 Ok(())
3660}
3661
3662struct ZedProCountLanguageModelTokensRateLimit;
3663
3664impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
3665 fn capacity(&self) -> usize {
3666 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
3667 .ok()
3668 .and_then(|v| v.parse().ok())
3669 .unwrap_or(600) // Picked arbitrarily
3670 }
3671
3672 fn refill_duration(&self) -> chrono::Duration {
3673 chrono::Duration::hours(1)
3674 }
3675
3676 fn db_name(&self) -> &'static str {
3677 "zed-pro:count-language-model-tokens"
3678 }
3679}
3680
3681struct FreeCountLanguageModelTokensRateLimit;
3682
3683impl RateLimit for FreeCountLanguageModelTokensRateLimit {
3684 fn capacity(&self) -> usize {
3685 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
3686 .ok()
3687 .and_then(|v| v.parse().ok())
3688 .unwrap_or(600 / 10) // Picked arbitrarily
3689 }
3690
3691 fn refill_duration(&self) -> chrono::Duration {
3692 chrono::Duration::hours(1)
3693 }
3694
3695 fn db_name(&self) -> &'static str {
3696 "free:count-language-model-tokens"
3697 }
3698}
3699
3700struct ZedProComputeEmbeddingsRateLimit;
3701
3702impl RateLimit for ZedProComputeEmbeddingsRateLimit {
3703 fn capacity(&self) -> usize {
3704 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
3705 .ok()
3706 .and_then(|v| v.parse().ok())
3707 .unwrap_or(5000) // Picked arbitrarily
3708 }
3709
3710 fn refill_duration(&self) -> chrono::Duration {
3711 chrono::Duration::hours(1)
3712 }
3713
3714 fn db_name(&self) -> &'static str {
3715 "zed-pro:compute-embeddings"
3716 }
3717}
3718
3719struct FreeComputeEmbeddingsRateLimit;
3720
3721impl RateLimit for FreeComputeEmbeddingsRateLimit {
3722 fn capacity(&self) -> usize {
3723 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
3724 .ok()
3725 .and_then(|v| v.parse().ok())
3726 .unwrap_or(5000 / 10) // Picked arbitrarily
3727 }
3728
3729 fn refill_duration(&self) -> chrono::Duration {
3730 chrono::Duration::hours(1)
3731 }
3732
3733 fn db_name(&self) -> &'static str {
3734 "free:compute-embeddings"
3735 }
3736}
3737
3738async fn compute_embeddings(
3739 request: proto::ComputeEmbeddings,
3740 response: Response<proto::ComputeEmbeddings>,
3741 session: Session,
3742 api_key: Option<Arc<str>>,
3743) -> Result<()> {
3744 let api_key = api_key.context("no OpenAI API key configured on the server")?;
3745 authorize_access_to_legacy_llm_endpoints(&session).await?;
3746
3747 let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3748 proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
3749 proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
3750 };
3751
3752 session
3753 .app_state
3754 .rate_limiter
3755 .check(&*rate_limit, session.user_id())
3756 .await?;
3757
3758 let embeddings = match request.model.as_str() {
3759 "openai/text-embedding-3-small" => {
3760 open_ai::embed(
3761 session.http_client.as_ref(),
3762 OPEN_AI_API_URL,
3763 &api_key,
3764 OpenAiEmbeddingModel::TextEmbedding3Small,
3765 request.texts.iter().map(|text| text.as_str()),
3766 )
3767 .await?
3768 }
3769 provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
3770 };
3771
3772 let embeddings = request
3773 .texts
3774 .iter()
3775 .map(|text| {
3776 let mut hasher = sha2::Sha256::new();
3777 hasher.update(text.as_bytes());
3778 let result = hasher.finalize();
3779 result.to_vec()
3780 })
3781 .zip(
3782 embeddings
3783 .data
3784 .into_iter()
3785 .map(|embedding| embedding.embedding),
3786 )
3787 .collect::<HashMap<_, _>>();
3788
3789 let db = session.db().await;
3790 db.save_embeddings(&request.model, &embeddings)
3791 .await
3792 .context("failed to save embeddings")
3793 .trace_err();
3794
3795 response.send(proto::ComputeEmbeddingsResponse {
3796 embeddings: embeddings
3797 .into_iter()
3798 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3799 .collect(),
3800 })?;
3801 Ok(())
3802}
3803
3804async fn get_cached_embeddings(
3805 request: proto::GetCachedEmbeddings,
3806 response: Response<proto::GetCachedEmbeddings>,
3807 session: Session,
3808) -> Result<()> {
3809 authorize_access_to_legacy_llm_endpoints(&session).await?;
3810
3811 let db = session.db().await;
3812 let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
3813
3814 response.send(proto::GetCachedEmbeddingsResponse {
3815 embeddings: embeddings
3816 .into_iter()
3817 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3818 .collect(),
3819 })?;
3820 Ok(())
3821}
3822
3823/// This is leftover from before the LLM service.
3824///
3825/// The endpoints protected by this check will be moved there eventually.
3826async fn authorize_access_to_legacy_llm_endpoints(session: &Session) -> Result<(), Error> {
3827 if session.is_staff() {
3828 Ok(())
3829 } else {
3830 Err(anyhow!("permission denied"))?
3831 }
3832}
3833
3834/// Get a Supermaven API key for the user
3835async fn get_supermaven_api_key(
3836 _request: proto::GetSupermavenApiKey,
3837 response: Response<proto::GetSupermavenApiKey>,
3838 session: Session,
3839) -> Result<()> {
3840 let user_id: String = session.user_id().to_string();
3841 if !session.is_staff() {
3842 return Err(anyhow!("supermaven not enabled for this account"))?;
3843 }
3844
3845 let email = session
3846 .email()
3847 .ok_or_else(|| anyhow!("user must have an email"))?;
3848
3849 let supermaven_admin_api = session
3850 .supermaven_client
3851 .as_ref()
3852 .ok_or_else(|| anyhow!("supermaven not configured"))?;
3853
3854 let result = supermaven_admin_api
3855 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3856 .await?;
3857
3858 response.send(proto::GetSupermavenApiKeyResponse {
3859 api_key: result.api_key,
3860 })?;
3861
3862 Ok(())
3863}
3864
3865/// Start receiving chat updates for a channel
3866async fn join_channel_chat(
3867 request: proto::JoinChannelChat,
3868 response: Response<proto::JoinChannelChat>,
3869 session: Session,
3870) -> Result<()> {
3871 let channel_id = ChannelId::from_proto(request.channel_id);
3872
3873 let db = session.db().await;
3874 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3875 .await?;
3876 let messages = db
3877 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3878 .await?;
3879 response.send(proto::JoinChannelChatResponse {
3880 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3881 messages,
3882 })?;
3883 Ok(())
3884}
3885
3886/// Stop receiving chat updates for a channel
3887async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3888 let channel_id = ChannelId::from_proto(request.channel_id);
3889 session
3890 .db()
3891 .await
3892 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3893 .await?;
3894 Ok(())
3895}
3896
3897/// Retrieve the chat history for a channel
3898async fn get_channel_messages(
3899 request: proto::GetChannelMessages,
3900 response: Response<proto::GetChannelMessages>,
3901 session: Session,
3902) -> Result<()> {
3903 let channel_id = ChannelId::from_proto(request.channel_id);
3904 let messages = session
3905 .db()
3906 .await
3907 .get_channel_messages(
3908 channel_id,
3909 session.user_id(),
3910 MESSAGE_COUNT_PER_PAGE,
3911 Some(MessageId::from_proto(request.before_message_id)),
3912 )
3913 .await?;
3914 response.send(proto::GetChannelMessagesResponse {
3915 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3916 messages,
3917 })?;
3918 Ok(())
3919}
3920
3921/// Retrieve specific chat messages
3922async fn get_channel_messages_by_id(
3923 request: proto::GetChannelMessagesById,
3924 response: Response<proto::GetChannelMessagesById>,
3925 session: Session,
3926) -> Result<()> {
3927 let message_ids = request
3928 .message_ids
3929 .iter()
3930 .map(|id| MessageId::from_proto(*id))
3931 .collect::<Vec<_>>();
3932 let messages = session
3933 .db()
3934 .await
3935 .get_channel_messages_by_id(session.user_id(), &message_ids)
3936 .await?;
3937 response.send(proto::GetChannelMessagesResponse {
3938 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3939 messages,
3940 })?;
3941 Ok(())
3942}
3943
3944/// Retrieve the current users notifications
3945async fn get_notifications(
3946 request: proto::GetNotifications,
3947 response: Response<proto::GetNotifications>,
3948 session: Session,
3949) -> Result<()> {
3950 let notifications = session
3951 .db()
3952 .await
3953 .get_notifications(
3954 session.user_id(),
3955 NOTIFICATION_COUNT_PER_PAGE,
3956 request.before_id.map(db::NotificationId::from_proto),
3957 )
3958 .await?;
3959 response.send(proto::GetNotificationsResponse {
3960 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3961 notifications,
3962 })?;
3963 Ok(())
3964}
3965
3966/// Mark notifications as read
3967async fn mark_notification_as_read(
3968 request: proto::MarkNotificationRead,
3969 response: Response<proto::MarkNotificationRead>,
3970 session: Session,
3971) -> Result<()> {
3972 let database = &session.db().await;
3973 let notifications = database
3974 .mark_notification_as_read_by_id(
3975 session.user_id(),
3976 NotificationId::from_proto(request.notification_id),
3977 )
3978 .await?;
3979 send_notifications(
3980 &*session.connection_pool().await,
3981 &session.peer,
3982 notifications,
3983 );
3984 response.send(proto::Ack {})?;
3985 Ok(())
3986}
3987
3988/// Get the current users information
3989async fn get_private_user_info(
3990 _request: proto::GetPrivateUserInfo,
3991 response: Response<proto::GetPrivateUserInfo>,
3992 session: Session,
3993) -> Result<()> {
3994 let db = session.db().await;
3995
3996 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
3997 let user = db
3998 .get_user_by_id(session.user_id())
3999 .await?
4000 .ok_or_else(|| anyhow!("user not found"))?;
4001 let flags = db.get_user_flags(session.user_id()).await?;
4002
4003 response.send(proto::GetPrivateUserInfoResponse {
4004 metrics_id,
4005 staff: user.admin,
4006 flags,
4007 accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4008 })?;
4009 Ok(())
4010}
4011
4012/// Accept the terms of service (tos) on behalf of the current user
4013async fn accept_terms_of_service(
4014 _request: proto::AcceptTermsOfService,
4015 response: Response<proto::AcceptTermsOfService>,
4016 session: Session,
4017) -> Result<()> {
4018 let db = session.db().await;
4019
4020 let accepted_tos_at = Utc::now();
4021 db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4022 .await?;
4023
4024 response.send(proto::AcceptTermsOfServiceResponse {
4025 accepted_tos_at: accepted_tos_at.timestamp() as u64,
4026 })?;
4027 Ok(())
4028}
4029
/// The minimum account age an account must have in order to use the LLM service.
///
/// `get_llm_api_token` compares this against the older of the Zed account's and
/// the linked GitHub account's creation times. The check is skipped for users
/// with an LLM subscription or the `bypass-account-age-check` feature flag.
const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4032
4033async fn get_llm_api_token(
4034 _request: proto::GetLlmToken,
4035 response: Response<proto::GetLlmToken>,
4036 session: Session,
4037) -> Result<()> {
4038 let db = session.db().await;
4039
4040 let flags = db.get_user_flags(session.user_id()).await?;
4041 let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
4042 let has_llm_closed_beta_feature_flag = flags.iter().any(|flag| flag == "llm-closed-beta");
4043 let has_predict_edits_feature_flag = flags.iter().any(|flag| flag == "predict-edits");
4044
4045 if !session.is_staff() && !has_language_models_feature_flag {
4046 Err(anyhow!("permission denied"))?
4047 }
4048
4049 let user_id = session.user_id();
4050 let user = db
4051 .get_user_by_id(user_id)
4052 .await?
4053 .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4054
4055 if user.accepted_tos_at.is_none() {
4056 Err(anyhow!("terms of service not accepted"))?
4057 }
4058
4059 let has_llm_subscription = session.has_llm_subscription(&db).await?;
4060
4061 let bypass_account_age_check =
4062 has_llm_subscription || flags.iter().any(|flag| flag == "bypass-account-age-check");
4063 if !bypass_account_age_check {
4064 let mut account_created_at = user.created_at;
4065 if let Some(github_created_at) = user.github_user_created_at {
4066 account_created_at = account_created_at.min(github_created_at);
4067 }
4068 if Utc::now().naive_utc() - account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE {
4069 Err(anyhow!("account too young"))?
4070 }
4071 }
4072
4073 let billing_preferences = db.get_billing_preferences(user.id).await?;
4074
4075 let token = LlmTokenClaims::create(
4076 &user,
4077 session.is_staff(),
4078 billing_preferences,
4079 has_llm_closed_beta_feature_flag,
4080 has_predict_edits_feature_flag,
4081 has_llm_subscription,
4082 session.current_plan(&db).await?,
4083 session.system_id.clone(),
4084 &session.app_state.config,
4085 )?;
4086 response.send(proto::GetLlmTokenResponse { token })?;
4087 Ok(())
4088}
4089
4090fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4091 let message = match message {
4092 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4093 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4094 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4095 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4096 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4097 code: frame.code.into(),
4098 reason: frame.reason,
4099 })),
4100 // We should never receive a frame while reading the message, according
4101 // to the `tungstenite` maintainers:
4102 //
4103 // > It cannot occur when you read messages from the WebSocket, but it
4104 // > can be used when you want to send the raw frames (e.g. you want to
4105 // > send the frames to the WebSocket without composing the full message first).
4106 // >
4107 // > — https://github.com/snapview/tungstenite-rs/issues/268
4108 TungsteniteMessage::Frame(_) => {
4109 bail!("received an unexpected frame while reading the message")
4110 }
4111 };
4112
4113 Ok(message)
4114}
4115
4116fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4117 match message {
4118 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4119 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4120 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4121 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4122 AxumMessage::Close(frame) => {
4123 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4124 code: frame.code.into(),
4125 reason: frame.reason,
4126 }))
4127 }
4128 }
4129}
4130
4131fn notify_membership_updated(
4132 connection_pool: &mut ConnectionPool,
4133 result: MembershipUpdated,
4134 user_id: UserId,
4135 peer: &Peer,
4136) {
4137 for membership in &result.new_channels.channel_memberships {
4138 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4139 }
4140 for channel_id in &result.removed_channels {
4141 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4142 }
4143
4144 let user_channels_update = proto::UpdateUserChannels {
4145 channel_memberships: result
4146 .new_channels
4147 .channel_memberships
4148 .iter()
4149 .map(|cm| proto::ChannelMembership {
4150 channel_id: cm.channel_id.to_proto(),
4151 role: cm.role.into(),
4152 })
4153 .collect(),
4154 ..Default::default()
4155 };
4156
4157 let mut update = build_channels_update(result.new_channels);
4158 update.delete_channels = result
4159 .removed_channels
4160 .into_iter()
4161 .map(|id| id.to_proto())
4162 .collect();
4163 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4164
4165 for connection_id in connection_pool.user_connection_ids(user_id) {
4166 peer.send(connection_id, user_channels_update.clone())
4167 .trace_err();
4168 peer.send(connection_id, update.clone()).trace_err();
4169 }
4170}
4171
4172fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4173 proto::UpdateUserChannels {
4174 channel_memberships: channels
4175 .channel_memberships
4176 .iter()
4177 .map(|m| proto::ChannelMembership {
4178 channel_id: m.channel_id.to_proto(),
4179 role: m.role.into(),
4180 })
4181 .collect(),
4182 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4183 observed_channel_message_id: channels.observed_channel_messages.clone(),
4184 }
4185}
4186
4187fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4188 let mut update = proto::UpdateChannels::default();
4189
4190 for channel in channels.channels {
4191 update.channels.push(channel.to_proto());
4192 }
4193
4194 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4195 update.latest_channel_message_ids = channels.latest_channel_messages;
4196
4197 for (channel_id, participants) in channels.channel_participants {
4198 update
4199 .channel_participants
4200 .push(proto::ChannelParticipants {
4201 channel_id: channel_id.to_proto(),
4202 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4203 });
4204 }
4205
4206 for channel in channels.invited_channels {
4207 update.channel_invitations.push(channel.to_proto());
4208 }
4209
4210 update
4211}
4212
4213fn build_initial_contacts_update(
4214 contacts: Vec<db::Contact>,
4215 pool: &ConnectionPool,
4216) -> proto::UpdateContacts {
4217 let mut update = proto::UpdateContacts::default();
4218
4219 for contact in contacts {
4220 match contact {
4221 db::Contact::Accepted { user_id, busy } => {
4222 update.contacts.push(contact_for_user(user_id, busy, pool));
4223 }
4224 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4225 db::Contact::Incoming { user_id } => {
4226 update
4227 .incoming_requests
4228 .push(proto::IncomingContactRequest {
4229 requester_id: user_id.to_proto(),
4230 })
4231 }
4232 }
4233 }
4234
4235 update
4236}
4237
4238fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4239 proto::Contact {
4240 user_id: user_id.to_proto(),
4241 online: pool.is_user_online(user_id),
4242 busy,
4243 }
4244}
4245
4246fn room_updated(room: &proto::Room, peer: &Peer) {
4247 broadcast(
4248 None,
4249 room.participants
4250 .iter()
4251 .filter_map(|participant| Some(participant.peer_id?.into())),
4252 |peer_id| {
4253 peer.send(
4254 peer_id,
4255 proto::RoomUpdated {
4256 room: Some(room.clone()),
4257 },
4258 )
4259 },
4260 );
4261}
4262
4263fn channel_updated(
4264 channel: &db::channel::Model,
4265 room: &proto::Room,
4266 peer: &Peer,
4267 pool: &ConnectionPool,
4268) {
4269 let participants = room
4270 .participants
4271 .iter()
4272 .map(|p| p.user_id)
4273 .collect::<Vec<_>>();
4274
4275 broadcast(
4276 None,
4277 pool.channel_connection_ids(channel.root_id())
4278 .filter_map(|(channel_id, role)| {
4279 role.can_see_channel(channel.visibility)
4280 .then_some(channel_id)
4281 }),
4282 |peer_id| {
4283 peer.send(
4284 peer_id,
4285 proto::UpdateChannels {
4286 channel_participants: vec![proto::ChannelParticipants {
4287 channel_id: channel.id.to_proto(),
4288 participant_user_ids: participants.clone(),
4289 }],
4290 ..Default::default()
4291 },
4292 )
4293 },
4294 );
4295}
4296
4297async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4298 let db = session.db().await;
4299
4300 let contacts = db.get_contacts(user_id).await?;
4301 let busy = db.is_user_busy(user_id).await?;
4302
4303 let pool = session.connection_pool().await;
4304 let updated_contact = contact_for_user(user_id, busy, &pool);
4305 for contact in contacts {
4306 if let db::Contact::Accepted {
4307 user_id: contact_user_id,
4308 ..
4309 } = contact
4310 {
4311 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4312 session
4313 .peer
4314 .send(
4315 contact_conn_id,
4316 proto::UpdateContacts {
4317 contacts: vec![updated_contact.clone()],
4318 remove_contacts: Default::default(),
4319 incoming_requests: Default::default(),
4320 remove_incoming_requests: Default::default(),
4321 outgoing_requests: Default::default(),
4322 remove_outgoing_requests: Default::default(),
4323 },
4324 )
4325 .trace_err();
4326 }
4327 }
4328 }
4329 Ok(())
4330}
4331
4332async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
4333 let mut contacts_to_update = HashSet::default();
4334
4335 let room_id;
4336 let canceled_calls_to_user_ids;
4337 let livekit_room;
4338 let delete_livekit_room;
4339 let room;
4340 let channel;
4341
4342 if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
4343 contacts_to_update.insert(session.user_id());
4344
4345 for project in left_room.left_projects.values() {
4346 project_left(project, session);
4347 }
4348
4349 room_id = RoomId::from_proto(left_room.room.id);
4350 canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
4351 livekit_room = mem::take(&mut left_room.room.livekit_room);
4352 delete_livekit_room = left_room.deleted;
4353 room = mem::take(&mut left_room.room);
4354 channel = mem::take(&mut left_room.channel);
4355
4356 room_updated(&room, &session.peer);
4357 } else {
4358 return Ok(());
4359 }
4360
4361 if let Some(channel) = channel {
4362 channel_updated(
4363 &channel,
4364 &room,
4365 &session.peer,
4366 &*session.connection_pool().await,
4367 );
4368 }
4369
4370 {
4371 let pool = session.connection_pool().await;
4372 for canceled_user_id in canceled_calls_to_user_ids {
4373 for connection_id in pool.user_connection_ids(canceled_user_id) {
4374 session
4375 .peer
4376 .send(
4377 connection_id,
4378 proto::CallCanceled {
4379 room_id: room_id.to_proto(),
4380 },
4381 )
4382 .trace_err();
4383 }
4384 contacts_to_update.insert(canceled_user_id);
4385 }
4386 }
4387
4388 for contact_user_id in contacts_to_update {
4389 update_user_contacts(contact_user_id, session).await?;
4390 }
4391
4392 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
4393 live_kit
4394 .remove_participant(livekit_room.clone(), session.user_id().to_string())
4395 .await
4396 .trace_err();
4397
4398 if delete_livekit_room {
4399 live_kit.delete_room(livekit_room).await.trace_err();
4400 }
4401 }
4402
4403 Ok(())
4404}
4405
4406async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4407 let left_channel_buffers = session
4408 .db()
4409 .await
4410 .leave_channel_buffers(session.connection_id)
4411 .await?;
4412
4413 for left_buffer in left_channel_buffers {
4414 channel_buffer_updated(
4415 session.connection_id,
4416 left_buffer.connections,
4417 &proto::UpdateChannelBufferCollaborators {
4418 channel_id: left_buffer.channel_id.to_proto(),
4419 collaborators: left_buffer.collaborators,
4420 },
4421 &session.peer,
4422 );
4423 }
4424
4425 Ok(())
4426}
4427
4428fn project_left(project: &db::LeftProject, session: &Session) {
4429 for connection_id in &project.connection_ids {
4430 if project.should_unshare {
4431 session
4432 .peer
4433 .send(
4434 *connection_id,
4435 proto::UnshareProject {
4436 project_id: project.id.to_proto(),
4437 },
4438 )
4439 .trace_err();
4440 } else {
4441 session
4442 .peer
4443 .send(
4444 *connection_id,
4445 proto::RemoveProjectCollaborator {
4446 project_id: project.id.to_proto(),
4447 peer_id: Some(session.connection_id.into()),
4448 },
4449 )
4450 .trace_err();
4451 }
4452 }
4453}
4454
4455pub trait ResultExt {
4456 type Ok;
4457
4458 fn trace_err(self) -> Option<Self::Ok>;
4459}
4460
4461impl<T, E> ResultExt for Result<T, E>
4462where
4463 E: std::fmt::Debug,
4464{
4465 type Ok = T;
4466
4467 #[track_caller]
4468 fn trace_err(self) -> Option<T> {
4469 match self {
4470 Ok(value) => Some(value),
4471 Err(error) => {
4472 tracing::error!("{:?}", error);
4473 None
4474 }
4475 }
4476 }
4477}