1mod connection_pool;
2
3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
4use crate::llm::LlmTokenClaims;
5use crate::{
6 auth,
7 db::{
8 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
9 CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
10 NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
11 RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
12 },
13 executor::Executor,
14 AppState, Config, Error, RateLimit, Result,
15};
16use anyhow::{anyhow, bail, Context as _};
17use async_tungstenite::tungstenite::{
18 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
19};
20use axum::{
21 body::Body,
22 extract::{
23 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
24 ConnectInfo, WebSocketUpgrade,
25 },
26 headers::{Header, HeaderName},
27 http::StatusCode,
28 middleware,
29 response::IntoResponse,
30 routing::get,
31 Extension, Router, TypedHeader,
32};
33use chrono::Utc;
34use collections::{HashMap, HashSet};
35pub use connection_pool::{ConnectionPool, ZedVersion};
36use core::fmt::{self, Debug, Formatter};
37use http_client::HttpClient;
38use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
39use reqwest_client::ReqwestClient;
40use rpc::proto::split_repository_update;
41use sha2::Digest;
42use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
43
44use futures::{
45 channel::oneshot, future::BoxFuture, stream::FuturesUnordered, FutureExt, SinkExt, StreamExt,
46 TryStreamExt,
47};
48use prometheus::{register_int_gauge, IntGauge};
49use rpc::{
50 proto::{
51 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
52 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
53 },
54 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
55};
56use semantic_version::SemanticVersion;
57use serde::{Serialize, Serializer};
58use std::{
59 any::TypeId,
60 future::Future,
61 marker::PhantomData,
62 mem,
63 net::SocketAddr,
64 ops::{Deref, DerefMut},
65 rc::Rc,
66 sync::{
67 atomic::{AtomicBool, Ordering::SeqCst},
68 Arc, OnceLock,
69 },
70 time::{Duration, Instant},
71};
72use time::OffsetDateTime;
73use tokio::sync::{watch, MutexGuard, Semaphore};
74use tower::ServiceBuilder;
75use tracing::{
76 field::{self},
77 info_span, instrument, Instrument,
78};
79
/// Reconnection window for clients. Not referenced in this part of the file;
/// presumably how long a dropped client may take to reconnect before its state
/// is released — confirm at the use site.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can
// clean up their old resources. `Server::start` sleeps for this long before sweeping
// stale rooms and channel buffers.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// NOTE(review): the following limits are used outside this view; the descriptions are
// inferred from their names — confirm at the use sites.
/// Page size when fetching channel chat messages (presumably).
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Maximum accepted channel message length (unit — bytes vs. chars — to be confirmed).
const MAX_MESSAGE_LEN: usize = 1024;
/// Page size when fetching notifications (presumably).
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;

/// Type-erased message handler stored in `Server::handlers`, keyed by message `TypeId`.
/// It receives the boxed envelope plus the connection's `Session` and returns a boxed
/// future that performs the handling.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
91
92struct Response<R> {
93 peer: Arc<Peer>,
94 receipt: Receipt<R>,
95 responded: Arc<AtomicBool>,
96}
97
98impl<R: RequestMessage> Response<R> {
99 fn send(self, payload: R::Response) -> Result<()> {
100 self.responded.store(true, SeqCst);
101 self.peer.respond(self.receipt, payload)?;
102 Ok(())
103 }
104}
105
106#[derive(Clone, Debug)]
107pub enum Principal {
108 User(User),
109 Impersonated { user: User, admin: User },
110}
111
112impl Principal {
113 fn update_span(&self, span: &tracing::Span) {
114 match &self {
115 Principal::User(user) => {
116 span.record("user_id", user.id.0);
117 span.record("login", &user.github_login);
118 }
119 Principal::Impersonated { user, admin } => {
120 span.record("user_id", user.id.0);
121 span.record("login", &user.github_login);
122 span.record("impersonator", &admin.github_login);
123 }
124 }
125 }
126}
127
/// Per-connection state passed to every message handler.
///
/// The connection loop clones a `Session` for each dispatched message.
#[derive(Clone)]
struct Session {
    /// Who is connected (a user, or an admin impersonating one).
    principal: Principal,
    /// This connection's id within `peer`.
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; acquire it through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    /// The RPC peer used to send messages and responses.
    peer: Arc<Peer>,
    /// Connection pool shared with the owning `Server`; lock via `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// `None` when no Supermaven admin API key is configured.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    http_client: Arc<dyn HttpClient>,
    /// The GeoIP country code for the user.
    // NOTE(review): the `allow` suggests nothing reads this field yet; drop it once a reader exists.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    /// Client-supplied system id, if one was provided.
    system_id: Option<String>,
    _executor: Executor,
}
144
145impl Session {
146 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
147 #[cfg(test)]
148 tokio::task::yield_now().await;
149 let guard = self.db.lock().await;
150 #[cfg(test)]
151 tokio::task::yield_now().await;
152 guard
153 }
154
155 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
156 #[cfg(test)]
157 tokio::task::yield_now().await;
158 let guard = self.connection_pool.lock();
159 ConnectionPoolGuard {
160 guard,
161 _not_send: PhantomData,
162 }
163 }
164
165 fn is_staff(&self) -> bool {
166 match &self.principal {
167 Principal::User(user) => user.admin,
168 Principal::Impersonated { .. } => true,
169 }
170 }
171
172 pub async fn has_llm_subscription(
173 &self,
174 db: &MutexGuard<'_, DbHandle>,
175 ) -> anyhow::Result<bool> {
176 if self.is_staff() {
177 return Ok(true);
178 }
179
180 let user_id = self.user_id();
181
182 Ok(db.has_active_billing_subscription(user_id).await?)
183 }
184
185 pub async fn current_plan(
186 &self,
187 _db: &MutexGuard<'_, DbHandle>,
188 ) -> anyhow::Result<proto::Plan> {
189 if self.is_staff() {
190 Ok(proto::Plan::ZedPro)
191 } else {
192 Ok(proto::Plan::Free)
193 }
194 }
195
196 fn user_id(&self) -> UserId {
197 match &self.principal {
198 Principal::User(user) => user.id,
199 Principal::Impersonated { user, .. } => user.id,
200 }
201 }
202
203 pub fn email(&self) -> Option<String> {
204 match &self.principal {
205 Principal::User(user) => user.email_address.clone(),
206 Principal::Impersonated { user, .. } => user.email_address.clone(),
207 }
208 }
209}
210
211impl Debug for Session {
212 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
213 let mut result = f.debug_struct("Session");
214 match &self.principal {
215 Principal::User(user) => {
216 result.field("user", &user.github_login);
217 }
218 Principal::Impersonated { user, admin } => {
219 result.field("user", &user.github_login);
220 result.field("impersonator", &admin.github_login);
221 }
222 }
223 result.field("connection_id", &self.connection_id).finish()
224 }
225}
226
227struct DbHandle(Arc<Database>);
228
229impl Deref for DbHandle {
230 type Target = Database;
231
232 fn deref(&self) -> &Self::Target {
233 self.0.as_ref()
234 }
235}
236
/// The collaboration RPC server: owns the RPC peer, the shared connection pool,
/// and the table of message handlers keyed by message type.
pub struct Server {
    /// This server's id; behind a mutex so the test-only `reset` can replace it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Registered handlers, one per message `TypeId`.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Set to `true` by `teardown` (and back to `false` by `reset`); each connection
    /// loop subscribes to it and stops when it changes.
    teardown: watch::Sender<bool>,
}
245
/// A locked [`ConnectionPool`].
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send`, so holding it across an
/// `.await` inside a future that must be `Send` is a compile error.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
250
/// Serializable view of the server's peer and connection pool.
///
/// The pool lock is held for as long as the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialize the pool itself, not the guard wrapping it.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
257
258pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
259where
260 S: Serializer,
261 T: Deref<Target = U>,
262 U: Serialize,
263{
264 Serialize::serialize(value.deref(), serializer)
265}
266
267impl Server {
268 pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
269 let mut server = Self {
270 id: parking_lot::Mutex::new(id),
271 peer: Peer::new(id.0 as u32),
272 app_state: app_state.clone(),
273 connection_pool: Default::default(),
274 handlers: Default::default(),
275 teardown: watch::channel(false).0,
276 };
277
278 server
279 .add_request_handler(ping)
280 .add_request_handler(create_room)
281 .add_request_handler(join_room)
282 .add_request_handler(rejoin_room)
283 .add_request_handler(leave_room)
284 .add_request_handler(set_room_participant_role)
285 .add_request_handler(call)
286 .add_request_handler(cancel_call)
287 .add_message_handler(decline_call)
288 .add_request_handler(update_participant_location)
289 .add_request_handler(share_project)
290 .add_message_handler(unshare_project)
291 .add_request_handler(join_project)
292 .add_message_handler(leave_project)
293 .add_request_handler(update_project)
294 .add_request_handler(update_worktree)
295 .add_request_handler(update_repository)
296 .add_request_handler(remove_repository)
297 .add_message_handler(start_language_server)
298 .add_message_handler(update_language_server)
299 .add_message_handler(update_diagnostic_summary)
300 .add_message_handler(update_worktree_settings)
301 .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
302 .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
303 .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
304 .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
305 .add_request_handler(forward_find_search_candidates_request)
306 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
307 .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
308 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
309 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
310 .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
311 .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
312 .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
313 .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
314 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
315 .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
316 .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
317 .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
318 .add_request_handler(
319 forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
320 )
321 .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
322 .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
323 .add_request_handler(
324 forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
325 )
326 .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
327 .add_request_handler(
328 forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
329 )
330 .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
331 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
332 .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
333 .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
334 .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
335 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
336 .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
337 .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
338 .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
339 .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
340 .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
341 .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
342 .add_request_handler(
343 forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
344 )
345 .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
346 .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
347 .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
348 .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
349 .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
350 .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
351 .add_message_handler(create_buffer_for_peer)
352 .add_request_handler(update_buffer)
353 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
354 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
355 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
356 .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
357 .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
358 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
359 .add_request_handler(get_users)
360 .add_request_handler(fuzzy_search_users)
361 .add_request_handler(request_contact)
362 .add_request_handler(remove_contact)
363 .add_request_handler(respond_to_contact_request)
364 .add_message_handler(subscribe_to_channels)
365 .add_request_handler(create_channel)
366 .add_request_handler(delete_channel)
367 .add_request_handler(invite_channel_member)
368 .add_request_handler(remove_channel_member)
369 .add_request_handler(set_channel_member_role)
370 .add_request_handler(set_channel_visibility)
371 .add_request_handler(rename_channel)
372 .add_request_handler(join_channel_buffer)
373 .add_request_handler(leave_channel_buffer)
374 .add_message_handler(update_channel_buffer)
375 .add_request_handler(rejoin_channel_buffers)
376 .add_request_handler(get_channel_members)
377 .add_request_handler(respond_to_channel_invite)
378 .add_request_handler(join_channel)
379 .add_request_handler(join_channel_chat)
380 .add_message_handler(leave_channel_chat)
381 .add_request_handler(send_channel_message)
382 .add_request_handler(remove_channel_message)
383 .add_request_handler(update_channel_message)
384 .add_request_handler(get_channel_messages)
385 .add_request_handler(get_channel_messages_by_id)
386 .add_request_handler(get_notifications)
387 .add_request_handler(mark_notification_as_read)
388 .add_request_handler(move_channel)
389 .add_request_handler(follow)
390 .add_message_handler(unfollow)
391 .add_message_handler(update_followers)
392 .add_request_handler(get_private_user_info)
393 .add_request_handler(get_llm_api_token)
394 .add_request_handler(accept_terms_of_service)
395 .add_message_handler(acknowledge_channel_message)
396 .add_message_handler(acknowledge_buffer_version)
397 .add_request_handler(get_supermaven_api_key)
398 .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
399 .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
400 .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
401 .add_request_handler(forward_mutating_project_request::<proto::Stage>)
402 .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
403 .add_request_handler(forward_mutating_project_request::<proto::Commit>)
404 .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
405 .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
406 .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
407 .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
408 .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
409 .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
410 .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
411 .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
412 .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
413 .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
414 .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
415 .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
416 .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
417 .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
418 .add_message_handler(update_context)
419 .add_request_handler({
420 let app_state = app_state.clone();
421 move |request, response, session| {
422 let app_state = app_state.clone();
423 async move {
424 count_language_model_tokens(request, response, session, &app_state.config)
425 .await
426 }
427 }
428 })
429 .add_request_handler(get_cached_embeddings)
430 .add_request_handler({
431 let app_state = app_state.clone();
432 move |request, response, session| {
433 compute_embeddings(
434 request,
435 response,
436 session,
437 app_state.config.openai_api_key.clone(),
438 )
439 }
440 });
441
442 Arc::new(server)
443 }
444
    /// Schedules cleanup of resources the database reports as stale for this
    /// environment. These are presumably rooms and channel buffers owned by servers
    /// other than this one — confirm in `stale_server_resource_ids`.
    ///
    /// Spawns a detached task that first waits [`CLEANUP_TIMEOUT`], then:
    /// 1. clears stale channel-buffer collaborators and pushes the refreshed
    ///    collaborator list to the remaining connections;
    /// 2. for each stale room, clears stale participants, broadcasts the room/channel
    ///    update, sends `CallCanceled` for canceled calls, refreshes affected users'
    ///    contacts, and deletes the LiveKit room if nobody is left;
    /// 3. deletes the stale server records.
    ///
    /// Each fallible step is logged via `trace_err` and the sweep continues. The
    /// method itself returns `Ok(())` as soon as the task is spawned.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Channel buffers: drop stale collaborators and tell the remaining
                    // connections who is still collaborating.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Filled in from the refreshed room below; they stay empty/false if
                        // clearing the room failed.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            // The LiveKit room is only deleted once no participants remain.
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the (synchronous) pool lock is released before the
                        // awaits below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push each affected user's refreshed contact entry (busy state
                        // included) to all of their accepted contacts' connections.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = livekit_client.as_ref() {
                            if delete_livekit_room {
                                live_kit.delete_room(livekit_room).await.trace_err();
                            }
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
590
591 pub fn teardown(&self) {
592 self.peer.teardown();
593 self.connection_pool.lock().reset();
594 let _ = self.teardown.send(true);
595 }
596
597 #[cfg(test)]
598 pub fn reset(&self, id: ServerId) {
599 self.teardown();
600 *self.id.lock() = id;
601 self.peer.reset(id.0 as u32);
602 let _ = self.teardown.send(false);
603 }
604
605 #[cfg(test)]
606 pub fn id(&self) -> ServerId {
607 *self.id.lock()
608 }
609
610 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
611 where
612 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
613 Fut: 'static + Send + Future<Output = Result<()>>,
614 M: EnvelopedMessage,
615 {
616 let prev_handler = self.handlers.insert(
617 TypeId::of::<M>(),
618 Box::new(move |envelope, session| {
619 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
620 let received_at = envelope.received_at;
621 tracing::info!("message received");
622 let start_time = Instant::now();
623 let future = (handler)(*envelope, session);
624 async move {
625 let result = future.await;
626 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
627 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
628 let queue_duration_ms = total_duration_ms - processing_duration_ms;
629 let payload_type = M::NAME;
630
631 match result {
632 Err(error) => {
633 tracing::error!(
634 ?error,
635 total_duration_ms,
636 processing_duration_ms,
637 queue_duration_ms,
638 payload_type,
639 "error handling message"
640 )
641 }
642 Ok(()) => tracing::info!(
643 total_duration_ms,
644 processing_duration_ms,
645 queue_duration_ms,
646 "finished handling message"
647 ),
648 }
649 }
650 .boxed()
651 }),
652 );
653 if prev_handler.is_some() {
654 panic!("registered a handler for the same message twice");
655 }
656 self
657 }
658
659 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
660 where
661 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
662 Fut: 'static + Send + Future<Output = Result<()>>,
663 M: EnvelopedMessage,
664 {
665 self.add_handler(move |envelope, session| handler(envelope.payload, session));
666 self
667 }
668
669 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
670 where
671 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
672 Fut: Send + Future<Output = Result<()>>,
673 M: RequestMessage,
674 {
675 let handler = Arc::new(handler);
676 self.add_handler(move |envelope, session| {
677 let receipt = envelope.receipt();
678 let handler = handler.clone();
679 async move {
680 let peer = session.peer.clone();
681 let responded = Arc::new(AtomicBool::default());
682 let response = Response {
683 peer: peer.clone(),
684 responded: responded.clone(),
685 receipt,
686 };
687 match (handler)(envelope.payload, response, session).await {
688 Ok(()) => {
689 if responded.load(std::sync::atomic::Ordering::SeqCst) {
690 Ok(())
691 } else {
692 Err(anyhow!("handler did not send a response"))?
693 }
694 }
695 Err(error) => {
696 let proto_err = match &error {
697 Error::Internal(err) => err.to_proto(),
698 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
699 };
700 peer.respond_with_error(receipt, proto_err)?;
701 Err(error)
702 }
703 }
704 }
705 })
706 }
707
    /// Builds the future that drives one client connection for its whole lifetime.
    ///
    /// The future registers the connection with the peer, builds the per-connection
    /// [`Session`], and sends the initial client update. It then dispatches incoming
    /// messages to the registered handlers until the connection closes, its I/O
    /// fails, or the server tears down. On close or I/O failure it runs
    /// `connection_lost`; on teardown it returns immediately without doing so.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> {
        let this = self.clone();
        // Empty fields are declared up front so they can be filled in with `record` later.
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            geoip_country_code=field::Empty
        );
        principal.update_span(&span);
        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        let mut teardown = self.teardown.subscribe();
        async move {
            // Refuse new connections once teardown has started.
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            // Only available when a Supermaven admin API key is configured.
            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(
                |supermaven_admin_api_key| {
                    Arc::new(SupermavenAdminApi::new(
                        supermaven_admin_api_key.to_string(),
                        http_client.clone(),
                    ))
                },
            );

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                http_client,
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this
                .send_initial_client_update(
                    connection_id,
                    &principal,
                    zed_version,
                    send_connection_id,
                    &session,
                )
                .await
            {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // At most 256 handlers (foreground and background) may be in flight per
            // connection: a permit is taken before reading each message and released only
            // when that message's handler finishes.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // `select_biased!` polls branches in the order written: teardown first,
                // then connection I/O, then in-flight handlers, then the next message.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                // Hold the semaphore permit until the handler completes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background messages run detached; foreground ones are
                                // polled by this loop (see the comment above).
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Cancel any foreground handlers still in flight before signing out.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
856
857 async fn send_initial_client_update(
858 &self,
859 connection_id: ConnectionId,
860 principal: &Principal,
861 zed_version: ZedVersion,
862 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
863 session: &Session,
864 ) -> Result<()> {
865 self.peer.send(
866 connection_id,
867 proto::Hello {
868 peer_id: Some(connection_id.into()),
869 },
870 )?;
871 tracing::info!("sent hello message");
872 if let Some(send_connection_id) = send_connection_id.take() {
873 let _ = send_connection_id.send(connection_id);
874 }
875
876 match principal {
877 Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
878 if !user.connected_once {
879 self.peer.send(connection_id, proto::ShowContacts {})?;
880 self.app_state
881 .db
882 .set_user_connected_once(user.id, true)
883 .await?;
884 }
885
886 update_user_plan(user.id, session).await?;
887
888 let contacts = self.app_state.db.get_contacts(user.id).await?;
889
890 {
891 let mut pool = self.connection_pool.lock();
892 pool.add_connection(connection_id, user.id, user.admin, zed_version);
893 self.peer.send(
894 connection_id,
895 build_initial_contacts_update(contacts, &pool),
896 )?;
897 }
898
899 if should_auto_subscribe_to_channels(zed_version) {
900 subscribe_user_to_channels(user.id, session).await?;
901 }
902
903 if let Some(incoming_call) =
904 self.app_state.db.incoming_call_for_user(user.id).await?
905 {
906 self.peer.send(connection_id, incoming_call)?;
907 }
908
909 update_user_contacts(user.id, session).await?;
910 }
911 }
912
913 Ok(())
914 }
915
916 pub async fn invite_code_redeemed(
917 self: &Arc<Self>,
918 inviter_id: UserId,
919 invitee_id: UserId,
920 ) -> Result<()> {
921 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
922 if let Some(code) = &user.invite_code {
923 let pool = self.connection_pool.lock();
924 let invitee_contact = contact_for_user(invitee_id, false, &pool);
925 for connection_id in pool.user_connection_ids(inviter_id) {
926 self.peer.send(
927 connection_id,
928 proto::UpdateContacts {
929 contacts: vec![invitee_contact.clone()],
930 ..Default::default()
931 },
932 )?;
933 self.peer.send(
934 connection_id,
935 proto::UpdateInviteInfo {
936 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
937 count: user.invite_count as u32,
938 },
939 )?;
940 }
941 }
942 }
943 Ok(())
944 }
945
946 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
947 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
948 if let Some(invite_code) = &user.invite_code {
949 let pool = self.connection_pool.lock();
950 for connection_id in pool.user_connection_ids(user_id) {
951 self.peer.send(
952 connection_id,
953 proto::UpdateInviteInfo {
954 url: format!(
955 "{}{}",
956 self.app_state.config.invite_link_prefix, invite_code
957 ),
958 count: user.invite_count as u32,
959 },
960 )?;
961 }
962 }
963 }
964 Ok(())
965 }
966
967 pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
968 let pool = self.connection_pool.lock();
969 for connection_id in pool.user_connection_ids(user_id) {
970 self.peer
971 .send(connection_id, proto::RefreshLlmToken {})
972 .trace_err();
973 }
974 }
975
976 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
977 ServerSnapshot {
978 connection_pool: ConnectionPoolGuard {
979 guard: self.connection_pool.lock(),
980 _not_send: PhantomData,
981 },
982 peer: &self.peer,
983 }
984 }
985}
986
987impl Deref for ConnectionPoolGuard<'_> {
988 type Target = ConnectionPool;
989
990 fn deref(&self) -> &Self::Target {
991 &self.guard
992 }
993}
994
995impl DerefMut for ConnectionPoolGuard<'_> {
996 fn deref_mut(&mut self) -> &mut Self::Target {
997 &mut self.guard
998 }
999}
1000
1001impl Drop for ConnectionPoolGuard<'_> {
1002 fn drop(&mut self) {
1003 #[cfg(test)]
1004 self.check_invariants();
1005 }
1006}
1007
1008fn broadcast<F>(
1009 sender_id: Option<ConnectionId>,
1010 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1011 mut f: F,
1012) where
1013 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1014{
1015 for receiver_id in receiver_ids {
1016 if Some(receiver_id) != sender_id {
1017 if let Err(error) = f(receiver_id) {
1018 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1019 }
1020 }
1021 }
1022}
1023
1024pub struct ProtocolVersion(u32);
1025
1026impl Header for ProtocolVersion {
1027 fn name() -> &'static HeaderName {
1028 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1029 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1030 }
1031
1032 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1033 where
1034 Self: Sized,
1035 I: Iterator<Item = &'i axum::http::HeaderValue>,
1036 {
1037 let version = values
1038 .next()
1039 .ok_or_else(axum::headers::Error::invalid)?
1040 .to_str()
1041 .map_err(|_| axum::headers::Error::invalid())?
1042 .parse()
1043 .map_err(|_| axum::headers::Error::invalid())?;
1044 Ok(Self(version))
1045 }
1046
1047 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1048 values.extend([self.0.to_string().parse().unwrap()]);
1049 }
1050}
1051
1052pub struct AppVersionHeader(SemanticVersion);
1053impl Header for AppVersionHeader {
1054 fn name() -> &'static HeaderName {
1055 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1056 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1057 }
1058
1059 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1060 where
1061 Self: Sized,
1062 I: Iterator<Item = &'i axum::http::HeaderValue>,
1063 {
1064 let version = values
1065 .next()
1066 .ok_or_else(axum::headers::Error::invalid)?
1067 .to_str()
1068 .map_err(|_| axum::headers::Error::invalid())?
1069 .parse()
1070 .map_err(|_| axum::headers::Error::invalid())?;
1071 Ok(Self(version))
1072 }
1073
1074 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1075 values.extend([self.0.to_string().parse().unwrap()]);
1076 }
1077}
1078
1079pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1080 Router::new()
1081 .route("/rpc", get(handle_websocket_request))
1082 .layer(
1083 ServiceBuilder::new()
1084 .layer(Extension(server.app_state.clone()))
1085 .layer(middleware::from_fn(auth::validate_header)),
1086 )
1087 .route("/metrics", get(handle_metrics))
1088 .layer(Extension(server))
1089}
1090
1091pub async fn handle_websocket_request(
1092 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1093 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1094 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1095 Extension(server): Extension<Arc<Server>>,
1096 Extension(principal): Extension<Principal>,
1097 country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1098 system_id_header: Option<TypedHeader<SystemIdHeader>>,
1099 ws: WebSocketUpgrade,
1100) -> axum::response::Response {
1101 if protocol_version != rpc::PROTOCOL_VERSION {
1102 return (
1103 StatusCode::UPGRADE_REQUIRED,
1104 "client must be upgraded".to_string(),
1105 )
1106 .into_response();
1107 }
1108
1109 let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1110 return (
1111 StatusCode::UPGRADE_REQUIRED,
1112 "no version header found".to_string(),
1113 )
1114 .into_response();
1115 };
1116
1117 if !version.can_collaborate() {
1118 return (
1119 StatusCode::UPGRADE_REQUIRED,
1120 "client must be upgraded".to_string(),
1121 )
1122 .into_response();
1123 }
1124
1125 let socket_address = socket_address.to_string();
1126 ws.on_upgrade(move |socket| {
1127 let socket = socket
1128 .map_ok(to_tungstenite_message)
1129 .err_into()
1130 .with(|message| async move { to_axum_message(message) });
1131 let connection = Connection::new(Box::pin(socket));
1132 async move {
1133 server
1134 .handle_connection(
1135 connection,
1136 socket_address,
1137 principal,
1138 version,
1139 country_code_header.map(|header| header.to_string()),
1140 system_id_header.map(|header| header.to_string()),
1141 None,
1142 Executor::Production,
1143 )
1144 .await;
1145 }
1146 })
1147}
1148
1149pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1150 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1151 let connections_metric = CONNECTIONS_METRIC
1152 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1153
1154 let connections = server
1155 .connection_pool
1156 .lock()
1157 .connections()
1158 .filter(|connection| !connection.admin)
1159 .count();
1160 connections_metric.set(connections as _);
1161
1162 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1163 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1164 register_int_gauge!(
1165 "shared_projects",
1166 "number of open projects with one or more guests"
1167 )
1168 .unwrap()
1169 });
1170
1171 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1172 shared_projects_metric.set(shared_projects as _);
1173
1174 let encoder = prometheus::TextEncoder::new();
1175 let metric_families = prometheus::gather();
1176 let encoded_metrics = encoder
1177 .encode_to_string(&metric_families)
1178 .map_err(|err| anyhow!("{}", err))?;
1179 Ok(encoded_metrics)
1180}
1181
/// Cleans up after a client connection drops.
///
/// Steps:
/// 1. Immediately disconnects the peer and removes the connection from the pool.
/// 2. Records the loss in the database; a failure here is only logged.
/// 3. Waits for either `RECONNECT_TIMEOUT` to elapse or `teardown` to change. `select_biased!`
///    checks the timeout branch first when both are ready.
///
/// Only when the timeout wins does it do the following:
/// - leave the user's room and channel buffers;
/// - decline any pending call, but only if the user has no other connection online;
/// - refresh the user's contacts.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // Best effort: errors are logged by `trace_err` rather than propagated.
    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        // The client did not come back within the reconnect window: release its resources.
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline the pending call if none of the user's other connections are online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        // `teardown` changed first: skip the cleanup above.
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1228
1229/// Acknowledges a ping from a client, used to keep the connection alive.
1230async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1231 response.send(proto::Ack {})?;
1232 Ok(())
1233}
1234
1235/// Creates a new room for calling (outside of channels)
1236async fn create_room(
1237 _request: proto::CreateRoom,
1238 response: Response<proto::CreateRoom>,
1239 session: Session,
1240) -> Result<()> {
1241 let livekit_room = nanoid::nanoid!(30);
1242
1243 let live_kit_connection_info = util::maybe!(async {
1244 let live_kit = session.app_state.livekit_client.as_ref();
1245 let live_kit = live_kit?;
1246 let user_id = session.user_id().to_string();
1247
1248 let token = live_kit
1249 .room_token(&livekit_room, &user_id.to_string())
1250 .trace_err()?;
1251
1252 Some(proto::LiveKitConnectionInfo {
1253 server_url: live_kit.url().into(),
1254 token,
1255 can_publish: true,
1256 })
1257 })
1258 .await;
1259
1260 let room = session
1261 .db()
1262 .await
1263 .create_room(session.user_id(), session.connection_id, &livekit_room)
1264 .await?;
1265
1266 response.send(proto::CreateRoomResponse {
1267 room: Some(room.clone()),
1268 live_kit_connection_info,
1269 })?;
1270
1271 update_user_contacts(session.user_id(), &session).await?;
1272 Ok(())
1273}
1274
1275/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1276async fn join_room(
1277 request: proto::JoinRoom,
1278 response: Response<proto::JoinRoom>,
1279 session: Session,
1280) -> Result<()> {
1281 let room_id = RoomId::from_proto(request.id);
1282
1283 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1284
1285 if let Some(channel_id) = channel_id {
1286 return join_channel_internal(channel_id, Box::new(response), session).await;
1287 }
1288
1289 let joined_room = {
1290 let room = session
1291 .db()
1292 .await
1293 .join_room(room_id, session.user_id(), session.connection_id)
1294 .await?;
1295 room_updated(&room.room, &session.peer);
1296 room.into_inner()
1297 };
1298
1299 for connection_id in session
1300 .connection_pool()
1301 .await
1302 .user_connection_ids(session.user_id())
1303 {
1304 session
1305 .peer
1306 .send(
1307 connection_id,
1308 proto::CallCanceled {
1309 room_id: room_id.to_proto(),
1310 },
1311 )
1312 .trace_err();
1313 }
1314
1315 let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1316 {
1317 live_kit
1318 .room_token(
1319 &joined_room.room.livekit_room,
1320 &session.user_id().to_string(),
1321 )
1322 .trace_err()
1323 .map(|token| proto::LiveKitConnectionInfo {
1324 server_url: live_kit.url().into(),
1325 token,
1326 can_publish: true,
1327 })
1328 } else {
1329 None
1330 };
1331
1332 response.send(proto::JoinRoomResponse {
1333 room: Some(joined_room.room),
1334 channel_id: None,
1335 live_kit_connection_info,
1336 })?;
1337
1338 update_user_contacts(session.user_id(), &session).await?;
1339 Ok(())
1340}
1341
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Steps:
/// 1. Restores the caller's room membership in the database.
/// 2. Replies with the room, the reshared projects and the rejoined projects, then broadcasts
///    the room.
/// 3. For each reshared project, tells every collaborator that the caller's peer id changed
///    and forwards the project's worktrees to everyone except this connection.
/// 4. Calls `notify_rejoined_projects` for the rejoined projects.
/// 5. Once the database result has been dropped, broadcasts the channel (if any) and
///    refreshes the caller's contacts.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    let room;
    let channel;
    // Scope `rejoined_room` so it is dropped before `channel_updated` and
    // `update_user_contacts` run below.
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Let each collaborator swap the caller's old peer id for the new connection.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Re-send the project's worktrees to every collaborator except this connection.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1433
1434fn notify_rejoined_projects(
1435 rejoined_projects: &mut Vec<RejoinedProject>,
1436 session: &Session,
1437) -> Result<()> {
1438 for project in rejoined_projects.iter() {
1439 for collaborator in &project.collaborators {
1440 session
1441 .peer
1442 .send(
1443 collaborator.connection_id,
1444 proto::UpdateProjectCollaborator {
1445 project_id: project.id.to_proto(),
1446 old_peer_id: Some(project.old_connection_id.into()),
1447 new_peer_id: Some(session.connection_id.into()),
1448 },
1449 )
1450 .trace_err();
1451 }
1452 }
1453
1454 for project in rejoined_projects {
1455 for worktree in mem::take(&mut project.worktrees) {
1456 // Stream this worktree's entries.
1457 let message = proto::UpdateWorktree {
1458 project_id: project.id.to_proto(),
1459 worktree_id: worktree.id,
1460 abs_path: worktree.abs_path.clone(),
1461 root_name: worktree.root_name,
1462 updated_entries: worktree.updated_entries,
1463 removed_entries: worktree.removed_entries,
1464 scan_id: worktree.scan_id,
1465 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1466 updated_repositories: worktree.updated_repositories,
1467 removed_repositories: worktree.removed_repositories,
1468 };
1469 for update in proto::split_worktree_update(message) {
1470 session.peer.send(session.connection_id, update)?;
1471 }
1472
1473 // Stream this worktree's diagnostics.
1474 for summary in worktree.diagnostic_summaries {
1475 session.peer.send(
1476 session.connection_id,
1477 proto::UpdateDiagnosticSummary {
1478 project_id: project.id.to_proto(),
1479 worktree_id: worktree.id,
1480 summary: Some(summary),
1481 },
1482 )?;
1483 }
1484
1485 for settings_file in worktree.settings_files {
1486 session.peer.send(
1487 session.connection_id,
1488 proto::UpdateWorktreeSettings {
1489 project_id: project.id.to_proto(),
1490 worktree_id: worktree.id,
1491 path: settings_file.path,
1492 content: Some(settings_file.content),
1493 kind: Some(settings_file.kind.to_proto().into()),
1494 },
1495 )?;
1496 }
1497 }
1498
1499 for repository in mem::take(&mut project.updated_repositories) {
1500 for update in split_repository_update(repository) {
1501 session.peer.send(session.connection_id, update)?;
1502 }
1503 }
1504
1505 for id in mem::take(&mut project.removed_repositories) {
1506 session.peer.send(
1507 session.connection_id,
1508 proto::RemoveRepository {
1509 project_id: project.id.to_proto(),
1510 id,
1511 },
1512 )?;
1513 }
1514 }
1515
1516 Ok(())
1517}
1518
1519/// leave room disconnects from the room.
1520async fn leave_room(
1521 _: proto::LeaveRoom,
1522 response: Response<proto::LeaveRoom>,
1523 session: Session,
1524) -> Result<()> {
1525 leave_room_for_session(&session, session.connection_id).await?;
1526 response.send(proto::Ack {})?;
1527 Ok(())
1528}
1529
1530/// Updates the permissions of someone else in the room.
1531async fn set_room_participant_role(
1532 request: proto::SetRoomParticipantRole,
1533 response: Response<proto::SetRoomParticipantRole>,
1534 session: Session,
1535) -> Result<()> {
1536 let user_id = UserId::from_proto(request.user_id);
1537 let role = ChannelRole::from(request.role());
1538
1539 let (livekit_room, can_publish) = {
1540 let room = session
1541 .db()
1542 .await
1543 .set_room_participant_role(
1544 session.user_id(),
1545 RoomId::from_proto(request.room_id),
1546 user_id,
1547 role,
1548 )
1549 .await?;
1550
1551 let livekit_room = room.livekit_room.clone();
1552 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1553 room_updated(&room, &session.peer);
1554 (livekit_room, can_publish)
1555 };
1556
1557 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1558 live_kit
1559 .update_participant(
1560 livekit_room.clone(),
1561 request.user_id.to_string(),
1562 livekit_api::proto::ParticipantPermission {
1563 can_subscribe: true,
1564 can_publish,
1565 can_publish_data: can_publish,
1566 hidden: false,
1567 recorder: false,
1568 },
1569 )
1570 .await
1571 .trace_err();
1572 }
1573
1574 response.send(proto::Ack {})?;
1575 Ok(())
1576}
1577
1578/// Call someone else into the current room
1579async fn call(
1580 request: proto::Call,
1581 response: Response<proto::Call>,
1582 session: Session,
1583) -> Result<()> {
1584 let room_id = RoomId::from_proto(request.room_id);
1585 let calling_user_id = session.user_id();
1586 let calling_connection_id = session.connection_id;
1587 let called_user_id = UserId::from_proto(request.called_user_id);
1588 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1589 if !session
1590 .db()
1591 .await
1592 .has_contact(calling_user_id, called_user_id)
1593 .await?
1594 {
1595 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1596 }
1597
1598 let incoming_call = {
1599 let (room, incoming_call) = &mut *session
1600 .db()
1601 .await
1602 .call(
1603 room_id,
1604 calling_user_id,
1605 calling_connection_id,
1606 called_user_id,
1607 initial_project_id,
1608 )
1609 .await?;
1610 room_updated(room, &session.peer);
1611 mem::take(incoming_call)
1612 };
1613 update_user_contacts(called_user_id, &session).await?;
1614
1615 let mut calls = session
1616 .connection_pool()
1617 .await
1618 .user_connection_ids(called_user_id)
1619 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1620 .collect::<FuturesUnordered<_>>();
1621
1622 while let Some(call_response) = calls.next().await {
1623 match call_response.as_ref() {
1624 Ok(_) => {
1625 response.send(proto::Ack {})?;
1626 return Ok(());
1627 }
1628 Err(_) => {
1629 call_response.trace_err();
1630 }
1631 }
1632 }
1633
1634 {
1635 let room = session
1636 .db()
1637 .await
1638 .call_failed(room_id, called_user_id)
1639 .await?;
1640 room_updated(&room, &session.peer);
1641 }
1642 update_user_contacts(called_user_id, &session).await?;
1643
1644 Err(anyhow!("failed to ring user"))?
1645}
1646
1647/// Cancel an outgoing call.
1648async fn cancel_call(
1649 request: proto::CancelCall,
1650 response: Response<proto::CancelCall>,
1651 session: Session,
1652) -> Result<()> {
1653 let called_user_id = UserId::from_proto(request.called_user_id);
1654 let room_id = RoomId::from_proto(request.room_id);
1655 {
1656 let room = session
1657 .db()
1658 .await
1659 .cancel_call(room_id, session.connection_id, called_user_id)
1660 .await?;
1661 room_updated(&room, &session.peer);
1662 }
1663
1664 for connection_id in session
1665 .connection_pool()
1666 .await
1667 .user_connection_ids(called_user_id)
1668 {
1669 session
1670 .peer
1671 .send(
1672 connection_id,
1673 proto::CallCanceled {
1674 room_id: room_id.to_proto(),
1675 },
1676 )
1677 .trace_err();
1678 }
1679 response.send(proto::Ack {})?;
1680
1681 update_user_contacts(called_user_id, &session).await?;
1682 Ok(())
1683}
1684
1685/// Decline an incoming call.
1686async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1687 let room_id = RoomId::from_proto(message.room_id);
1688 {
1689 let room = session
1690 .db()
1691 .await
1692 .decline_call(Some(room_id), session.user_id())
1693 .await?
1694 .ok_or_else(|| anyhow!("failed to decline call"))?;
1695 room_updated(&room, &session.peer);
1696 }
1697
1698 for connection_id in session
1699 .connection_pool()
1700 .await
1701 .user_connection_ids(session.user_id())
1702 {
1703 session
1704 .peer
1705 .send(
1706 connection_id,
1707 proto::CallCanceled {
1708 room_id: room_id.to_proto(),
1709 },
1710 )
1711 .trace_err();
1712 }
1713 update_user_contacts(session.user_id(), &session).await?;
1714 Ok(())
1715}
1716
1717/// Updates other participants in the room with your current location.
1718async fn update_participant_location(
1719 request: proto::UpdateParticipantLocation,
1720 response: Response<proto::UpdateParticipantLocation>,
1721 session: Session,
1722) -> Result<()> {
1723 let room_id = RoomId::from_proto(request.room_id);
1724 let location = request
1725 .location
1726 .ok_or_else(|| anyhow!("invalid location"))?;
1727
1728 let db = session.db().await;
1729 let room = db
1730 .update_room_participant_location(room_id, session.connection_id, location)
1731 .await?;
1732
1733 room_updated(&room, &session.peer);
1734 response.send(proto::Ack {})?;
1735 Ok(())
1736}
1737
1738/// Share a project into the room.
1739async fn share_project(
1740 request: proto::ShareProject,
1741 response: Response<proto::ShareProject>,
1742 session: Session,
1743) -> Result<()> {
1744 let (project_id, room) = &*session
1745 .db()
1746 .await
1747 .share_project(
1748 RoomId::from_proto(request.room_id),
1749 session.connection_id,
1750 &request.worktrees,
1751 request.is_ssh_project,
1752 )
1753 .await?;
1754 response.send(proto::ShareProjectResponse {
1755 project_id: project_id.to_proto(),
1756 })?;
1757 room_updated(room, &session.peer);
1758
1759 Ok(())
1760}
1761
1762/// Unshare a project from the room.
1763async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1764 let project_id = ProjectId::from_proto(message.project_id);
1765 unshare_project_internal(project_id, session.connection_id, &session).await
1766}
1767
/// Unshares `project_id` on behalf of `connection_id`.
///
/// Sends `UnshareProject` to the project's guest connections, except `connection_id`
/// itself, and broadcasts the room update when the database returned a room. If the database
/// reports that the project should be deleted, it is deleted afterwards. This happens in a
/// separate database call, once the guard returned by `unshare_project` has been dropped.
async fn unshare_project_internal(
    project_id: ProjectId,
    connection_id: ConnectionId,
    session: &Session,
) -> Result<()> {
    // Scope `room_guard` so it is dropped before the `delete_project` call below.
    let delete = {
        let room_guard = session
            .db()
            .await
            .unshare_project(project_id, connection_id)
            .await?;

        let (delete, room, guest_connection_ids) = &*room_guard;

        let message = proto::UnshareProject {
            project_id: project_id.to_proto(),
        };

        broadcast(
            Some(connection_id),
            guest_connection_ids.iter().copied(),
            |conn_id| session.peer.send(conn_id, message.clone()),
        );
        if let Some(room) = room {
            room_updated(room, &session.peer);
        }

        *delete
    };

    if delete {
        let db = session.db().await;
        db.delete_project(project_id).await?;
    }

    Ok(())
}
1805
/// Join someone else's shared project.
///
/// Adds this connection to the project in the database, which also yields its replica id.
/// It then releases the database handle and streams the project state to the caller via
/// `join_project_internal`.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: Session,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    // Release the database handle before streaming the project to the client.
    drop(db);
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
1824
/// The reply channel used by `join_project_internal`, so that function does not depend on a
/// concrete `Response` type.
trait JoinProjectInternalResponse {
    /// Sends the join result back to the requester.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
/// Replies to a `JoinProject` request directly.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Delegates to `Response`'s inherent `send`, named with an explicit type path.
        Response::<proto::JoinProject>::send(self, result)
    }
}
1833
1834fn join_project_internal(
1835 response: impl JoinProjectInternalResponse,
1836 session: Session,
1837 project: &mut Project,
1838 replica_id: &ReplicaId,
1839) -> Result<()> {
1840 let collaborators = project
1841 .collaborators
1842 .iter()
1843 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1844 .map(|collaborator| collaborator.to_proto())
1845 .collect::<Vec<_>>();
1846 let project_id = project.id;
1847 let guest_user_id = session.user_id();
1848
1849 let worktrees = project
1850 .worktrees
1851 .iter()
1852 .map(|(id, worktree)| proto::WorktreeMetadata {
1853 id: *id,
1854 root_name: worktree.root_name.clone(),
1855 visible: worktree.visible,
1856 abs_path: worktree.abs_path.clone(),
1857 })
1858 .collect::<Vec<_>>();
1859
1860 let add_project_collaborator = proto::AddProjectCollaborator {
1861 project_id: project_id.to_proto(),
1862 collaborator: Some(proto::Collaborator {
1863 peer_id: Some(session.connection_id.into()),
1864 replica_id: replica_id.0 as u32,
1865 user_id: guest_user_id.to_proto(),
1866 is_host: false,
1867 }),
1868 };
1869
1870 for collaborator in &collaborators {
1871 session
1872 .peer
1873 .send(
1874 collaborator.peer_id.unwrap().into(),
1875 add_project_collaborator.clone(),
1876 )
1877 .trace_err();
1878 }
1879
1880 // First, we send the metadata associated with each worktree.
1881 response.send(proto::JoinProjectResponse {
1882 project_id: project.id.0 as u64,
1883 worktrees: worktrees.clone(),
1884 replica_id: replica_id.0 as u32,
1885 collaborators: collaborators.clone(),
1886 language_servers: project.language_servers.clone(),
1887 role: project.role.into(),
1888 })?;
1889
1890 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1891 // Stream this worktree's entries.
1892 let message = proto::UpdateWorktree {
1893 project_id: project_id.to_proto(),
1894 worktree_id,
1895 abs_path: worktree.abs_path.clone(),
1896 root_name: worktree.root_name,
1897 updated_entries: worktree.entries,
1898 removed_entries: Default::default(),
1899 scan_id: worktree.scan_id,
1900 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1901 updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1902 removed_repositories: Default::default(),
1903 };
1904 for update in proto::split_worktree_update(message) {
1905 session.peer.send(session.connection_id, update.clone())?;
1906 }
1907
1908 // Stream this worktree's diagnostics.
1909 for summary in worktree.diagnostic_summaries {
1910 session.peer.send(
1911 session.connection_id,
1912 proto::UpdateDiagnosticSummary {
1913 project_id: project_id.to_proto(),
1914 worktree_id: worktree.id,
1915 summary: Some(summary),
1916 },
1917 )?;
1918 }
1919
1920 for settings_file in worktree.settings_files {
1921 session.peer.send(
1922 session.connection_id,
1923 proto::UpdateWorktreeSettings {
1924 project_id: project_id.to_proto(),
1925 worktree_id: worktree.id,
1926 path: settings_file.path,
1927 content: Some(settings_file.content),
1928 kind: Some(settings_file.kind.to_proto() as i32),
1929 },
1930 )?;
1931 }
1932 }
1933
1934 for repository in mem::take(&mut project.repositories) {
1935 for update in split_repository_update(repository) {
1936 session.peer.send(session.connection_id, update)?;
1937 }
1938 }
1939
1940 for language_server in &project.language_servers {
1941 session.peer.send(
1942 session.connection_id,
1943 proto::UpdateLanguageServer {
1944 project_id: project_id.to_proto(),
1945 language_server_id: language_server.id,
1946 variant: Some(
1947 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1948 proto::LspDiskBasedDiagnosticsUpdated {},
1949 ),
1950 ),
1951 },
1952 )?;
1953 }
1954
1955 Ok(())
1956}
1957
1958/// Leave someone elses shared project.
1959async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1960 let sender_id = session.connection_id;
1961 let project_id = ProjectId::from_proto(request.project_id);
1962 let db = session.db().await;
1963
1964 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1965 tracing::info!(
1966 %project_id,
1967 "leave project"
1968 );
1969
1970 project_left(project, &session);
1971 if let Some(room) = room {
1972 room_updated(room, &session.peer);
1973 }
1974
1975 Ok(())
1976}
1977
1978/// Updates other participants with changes to the project
1979async fn update_project(
1980 request: proto::UpdateProject,
1981 response: Response<proto::UpdateProject>,
1982 session: Session,
1983) -> Result<()> {
1984 let project_id = ProjectId::from_proto(request.project_id);
1985 let (room, guest_connection_ids) = &*session
1986 .db()
1987 .await
1988 .update_project(project_id, session.connection_id, &request.worktrees)
1989 .await?;
1990 broadcast(
1991 Some(session.connection_id),
1992 guest_connection_ids.iter().copied(),
1993 |connection_id| {
1994 session
1995 .peer
1996 .forward_send(session.connection_id, connection_id, request.clone())
1997 },
1998 );
1999 if let Some(room) = room {
2000 room_updated(room, &session.peer);
2001 }
2002 response.send(proto::Ack {})?;
2003
2004 Ok(())
2005}
2006
2007/// Updates other participants with changes to the worktree
2008async fn update_worktree(
2009 request: proto::UpdateWorktree,
2010 response: Response<proto::UpdateWorktree>,
2011 session: Session,
2012) -> Result<()> {
2013 let guest_connection_ids = session
2014 .db()
2015 .await
2016 .update_worktree(&request, session.connection_id)
2017 .await?;
2018
2019 broadcast(
2020 Some(session.connection_id),
2021 guest_connection_ids.iter().copied(),
2022 |connection_id| {
2023 session
2024 .peer
2025 .forward_send(session.connection_id, connection_id, request.clone())
2026 },
2027 );
2028 response.send(proto::Ack {})?;
2029 Ok(())
2030}
2031
2032async fn update_repository(
2033 request: proto::UpdateRepository,
2034 response: Response<proto::UpdateRepository>,
2035 session: Session,
2036) -> Result<()> {
2037 let guest_connection_ids = session
2038 .db()
2039 .await
2040 .update_repository(&request, session.connection_id)
2041 .await?;
2042
2043 broadcast(
2044 Some(session.connection_id),
2045 guest_connection_ids.iter().copied(),
2046 |connection_id| {
2047 session
2048 .peer
2049 .forward_send(session.connection_id, connection_id, request.clone())
2050 },
2051 );
2052 response.send(proto::Ack {})?;
2053 Ok(())
2054}
2055
2056async fn remove_repository(
2057 request: proto::RemoveRepository,
2058 response: Response<proto::RemoveRepository>,
2059 session: Session,
2060) -> Result<()> {
2061 let guest_connection_ids = session
2062 .db()
2063 .await
2064 .remove_repository(&request, session.connection_id)
2065 .await?;
2066
2067 broadcast(
2068 Some(session.connection_id),
2069 guest_connection_ids.iter().copied(),
2070 |connection_id| {
2071 session
2072 .peer
2073 .forward_send(session.connection_id, connection_id, request.clone())
2074 },
2075 );
2076 response.send(proto::Ack {})?;
2077 Ok(())
2078}
2079
2080/// Updates other participants with changes to the diagnostics
2081async fn update_diagnostic_summary(
2082 message: proto::UpdateDiagnosticSummary,
2083 session: Session,
2084) -> Result<()> {
2085 let guest_connection_ids = session
2086 .db()
2087 .await
2088 .update_diagnostic_summary(&message, session.connection_id)
2089 .await?;
2090
2091 broadcast(
2092 Some(session.connection_id),
2093 guest_connection_ids.iter().copied(),
2094 |connection_id| {
2095 session
2096 .peer
2097 .forward_send(session.connection_id, connection_id, message.clone())
2098 },
2099 );
2100
2101 Ok(())
2102}
2103
2104/// Updates other participants with changes to the worktree settings
2105async fn update_worktree_settings(
2106 message: proto::UpdateWorktreeSettings,
2107 session: Session,
2108) -> Result<()> {
2109 let guest_connection_ids = session
2110 .db()
2111 .await
2112 .update_worktree_settings(&message, session.connection_id)
2113 .await?;
2114
2115 broadcast(
2116 Some(session.connection_id),
2117 guest_connection_ids.iter().copied(),
2118 |connection_id| {
2119 session
2120 .peer
2121 .forward_send(session.connection_id, connection_id, message.clone())
2122 },
2123 );
2124
2125 Ok(())
2126}
2127
2128/// Notify other participants that a language server has started.
2129async fn start_language_server(
2130 request: proto::StartLanguageServer,
2131 session: Session,
2132) -> Result<()> {
2133 let guest_connection_ids = session
2134 .db()
2135 .await
2136 .start_language_server(&request, session.connection_id)
2137 .await?;
2138
2139 broadcast(
2140 Some(session.connection_id),
2141 guest_connection_ids.iter().copied(),
2142 |connection_id| {
2143 session
2144 .peer
2145 .forward_send(session.connection_id, connection_id, request.clone())
2146 },
2147 );
2148 Ok(())
2149}
2150
2151/// Notify other participants that a language server has changed.
2152async fn update_language_server(
2153 request: proto::UpdateLanguageServer,
2154 session: Session,
2155) -> Result<()> {
2156 let project_id = ProjectId::from_proto(request.project_id);
2157 let project_connection_ids = session
2158 .db()
2159 .await
2160 .project_connection_ids(project_id, session.connection_id, true)
2161 .await?;
2162 broadcast(
2163 Some(session.connection_id),
2164 project_connection_ids.iter().copied(),
2165 |connection_id| {
2166 session
2167 .peer
2168 .forward_send(session.connection_id, connection_id, request.clone())
2169 },
2170 );
2171 Ok(())
2172}
2173
2174/// forward a project request to the host. These requests should be read only
2175/// as guests are allowed to send them.
2176async fn forward_read_only_project_request<T>(
2177 request: T,
2178 response: Response<T>,
2179 session: Session,
2180) -> Result<()>
2181where
2182 T: EntityMessage + RequestMessage,
2183{
2184 let project_id = ProjectId::from_proto(request.remote_entity_id());
2185 let host_connection_id = session
2186 .db()
2187 .await
2188 .host_for_read_only_project_request(project_id, session.connection_id)
2189 .await?;
2190 let payload = session
2191 .peer
2192 .forward_request(session.connection_id, host_connection_id, request)
2193 .await?;
2194 response.send(payload)?;
2195 Ok(())
2196}
2197
2198async fn forward_find_search_candidates_request(
2199 request: proto::FindSearchCandidates,
2200 response: Response<proto::FindSearchCandidates>,
2201 session: Session,
2202) -> Result<()> {
2203 let project_id = ProjectId::from_proto(request.remote_entity_id());
2204 let host_connection_id = session
2205 .db()
2206 .await
2207 .host_for_read_only_project_request(project_id, session.connection_id)
2208 .await?;
2209 let payload = session
2210 .peer
2211 .forward_request(session.connection_id, host_connection_id, request)
2212 .await?;
2213 response.send(payload)?;
2214 Ok(())
2215}
2216
2217/// forward a project request to the host. These requests are disallowed
2218/// for guests.
2219async fn forward_mutating_project_request<T>(
2220 request: T,
2221 response: Response<T>,
2222 session: Session,
2223) -> Result<()>
2224where
2225 T: EntityMessage + RequestMessage,
2226{
2227 let project_id = ProjectId::from_proto(request.remote_entity_id());
2228
2229 let host_connection_id = session
2230 .db()
2231 .await
2232 .host_for_mutating_project_request(project_id, session.connection_id)
2233 .await?;
2234 let payload = session
2235 .peer
2236 .forward_request(session.connection_id, host_connection_id, request)
2237 .await?;
2238 response.send(payload)?;
2239 Ok(())
2240}
2241
2242/// Notify other participants that a new buffer has been created
2243async fn create_buffer_for_peer(
2244 request: proto::CreateBufferForPeer,
2245 session: Session,
2246) -> Result<()> {
2247 session
2248 .db()
2249 .await
2250 .check_user_is_project_host(
2251 ProjectId::from_proto(request.project_id),
2252 session.connection_id,
2253 )
2254 .await?;
2255 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2256 session
2257 .peer
2258 .forward_send(session.connection_id, peer_id.into(), request)?;
2259 Ok(())
2260}
2261
2262/// Notify other participants that a buffer has been updated. This is
2263/// allowed for guests as long as the update is limited to selections.
2264async fn update_buffer(
2265 request: proto::UpdateBuffer,
2266 response: Response<proto::UpdateBuffer>,
2267 session: Session,
2268) -> Result<()> {
2269 let project_id = ProjectId::from_proto(request.project_id);
2270 let mut capability = Capability::ReadOnly;
2271
2272 for op in request.operations.iter() {
2273 match op.variant {
2274 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2275 Some(_) => capability = Capability::ReadWrite,
2276 }
2277 }
2278
2279 let host = {
2280 let guard = session
2281 .db()
2282 .await
2283 .connections_for_buffer_update(project_id, session.connection_id, capability)
2284 .await?;
2285
2286 let (host, guests) = &*guard;
2287
2288 broadcast(
2289 Some(session.connection_id),
2290 guests.clone(),
2291 |connection_id| {
2292 session
2293 .peer
2294 .forward_send(session.connection_id, connection_id, request.clone())
2295 },
2296 );
2297
2298 *host
2299 };
2300
2301 if host != session.connection_id {
2302 session
2303 .peer
2304 .forward_request(session.connection_id, host, request.clone())
2305 .await?;
2306 }
2307
2308 response.send(proto::Ack {})?;
2309 Ok(())
2310}
2311
2312async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2313 let project_id = ProjectId::from_proto(message.project_id);
2314
2315 let operation = message.operation.as_ref().context("invalid operation")?;
2316 let capability = match operation.variant.as_ref() {
2317 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2318 if let Some(buffer_op) = buffer_op.operation.as_ref() {
2319 match buffer_op.variant {
2320 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2321 Capability::ReadOnly
2322 }
2323 _ => Capability::ReadWrite,
2324 }
2325 } else {
2326 Capability::ReadWrite
2327 }
2328 }
2329 Some(_) => Capability::ReadWrite,
2330 None => Capability::ReadOnly,
2331 };
2332
2333 let guard = session
2334 .db()
2335 .await
2336 .connections_for_buffer_update(project_id, session.connection_id, capability)
2337 .await?;
2338
2339 let (host, guests) = &*guard;
2340
2341 broadcast(
2342 Some(session.connection_id),
2343 guests.iter().chain([host]).copied(),
2344 |connection_id| {
2345 session
2346 .peer
2347 .forward_send(session.connection_id, connection_id, message.clone())
2348 },
2349 );
2350
2351 Ok(())
2352}
2353
2354/// Notify other participants that a project has been updated.
2355async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2356 request: T,
2357 session: Session,
2358) -> Result<()> {
2359 let project_id = ProjectId::from_proto(request.remote_entity_id());
2360 let project_connection_ids = session
2361 .db()
2362 .await
2363 .project_connection_ids(project_id, session.connection_id, false)
2364 .await?;
2365
2366 broadcast(
2367 Some(session.connection_id),
2368 project_connection_ids.iter().copied(),
2369 |connection_id| {
2370 session
2371 .peer
2372 .forward_send(session.connection_id, connection_id, request.clone())
2373 },
2374 );
2375 Ok(())
2376}
2377
2378/// Start following another user in a call.
2379async fn follow(
2380 request: proto::Follow,
2381 response: Response<proto::Follow>,
2382 session: Session,
2383) -> Result<()> {
2384 let room_id = RoomId::from_proto(request.room_id);
2385 let project_id = request.project_id.map(ProjectId::from_proto);
2386 let leader_id = request
2387 .leader_id
2388 .ok_or_else(|| anyhow!("invalid leader id"))?
2389 .into();
2390 let follower_id = session.connection_id;
2391
2392 session
2393 .db()
2394 .await
2395 .check_room_participants(room_id, leader_id, session.connection_id)
2396 .await?;
2397
2398 let response_payload = session
2399 .peer
2400 .forward_request(session.connection_id, leader_id, request)
2401 .await?;
2402 response.send(response_payload)?;
2403
2404 if let Some(project_id) = project_id {
2405 let room = session
2406 .db()
2407 .await
2408 .follow(room_id, project_id, leader_id, follower_id)
2409 .await?;
2410 room_updated(&room, &session.peer);
2411 }
2412
2413 Ok(())
2414}
2415
2416/// Stop following another user in a call.
2417async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2418 let room_id = RoomId::from_proto(request.room_id);
2419 let project_id = request.project_id.map(ProjectId::from_proto);
2420 let leader_id = request
2421 .leader_id
2422 .ok_or_else(|| anyhow!("invalid leader id"))?
2423 .into();
2424 let follower_id = session.connection_id;
2425
2426 session
2427 .db()
2428 .await
2429 .check_room_participants(room_id, leader_id, session.connection_id)
2430 .await?;
2431
2432 session
2433 .peer
2434 .forward_send(session.connection_id, leader_id, request)?;
2435
2436 if let Some(project_id) = project_id {
2437 let room = session
2438 .db()
2439 .await
2440 .unfollow(room_id, project_id, leader_id, follower_id)
2441 .await?;
2442 room_updated(&room, &session.peer);
2443 }
2444
2445 Ok(())
2446}
2447
2448/// Notify everyone following you of your current location.
2449async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2450 let room_id = RoomId::from_proto(request.room_id);
2451 let database = session.db.lock().await;
2452
2453 let connection_ids = if let Some(project_id) = request.project_id {
2454 let project_id = ProjectId::from_proto(project_id);
2455 database
2456 .project_connection_ids(project_id, session.connection_id, true)
2457 .await?
2458 } else {
2459 database
2460 .room_connection_ids(room_id, session.connection_id)
2461 .await?
2462 };
2463
2464 // For now, don't send view update messages back to that view's current leader.
2465 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2466 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2467 _ => None,
2468 });
2469
2470 for connection_id in connection_ids.iter().cloned() {
2471 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2472 session
2473 .peer
2474 .forward_send(session.connection_id, connection_id, request.clone())?;
2475 }
2476 }
2477 Ok(())
2478}
2479
2480/// Get public data about users.
2481async fn get_users(
2482 request: proto::GetUsers,
2483 response: Response<proto::GetUsers>,
2484 session: Session,
2485) -> Result<()> {
2486 let user_ids = request
2487 .user_ids
2488 .into_iter()
2489 .map(UserId::from_proto)
2490 .collect();
2491 let users = session
2492 .db()
2493 .await
2494 .get_users_by_ids(user_ids)
2495 .await?
2496 .into_iter()
2497 .map(|user| proto::User {
2498 id: user.id.to_proto(),
2499 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2500 github_login: user.github_login,
2501 email: user.email_address,
2502 name: user.name,
2503 })
2504 .collect();
2505 response.send(proto::UsersResponse { users })?;
2506 Ok(())
2507}
2508
2509/// Search for users (to invite) buy Github login
2510async fn fuzzy_search_users(
2511 request: proto::FuzzySearchUsers,
2512 response: Response<proto::FuzzySearchUsers>,
2513 session: Session,
2514) -> Result<()> {
2515 let query = request.query;
2516 let users = match query.len() {
2517 0 => vec![],
2518 1 | 2 => session
2519 .db()
2520 .await
2521 .get_user_by_github_login(&query)
2522 .await?
2523 .into_iter()
2524 .collect(),
2525 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2526 };
2527 let users = users
2528 .into_iter()
2529 .filter(|user| user.id != session.user_id())
2530 .map(|user| proto::User {
2531 id: user.id.to_proto(),
2532 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2533 github_login: user.github_login,
2534 name: user.name,
2535 email: user.email_address,
2536 })
2537 .collect();
2538 response.send(proto::UsersResponse { users })?;
2539 Ok(())
2540}
2541
2542/// Send a contact request to another user.
2543async fn request_contact(
2544 request: proto::RequestContact,
2545 response: Response<proto::RequestContact>,
2546 session: Session,
2547) -> Result<()> {
2548 let requester_id = session.user_id();
2549 let responder_id = UserId::from_proto(request.responder_id);
2550 if requester_id == responder_id {
2551 return Err(anyhow!("cannot add yourself as a contact"))?;
2552 }
2553
2554 let notifications = session
2555 .db()
2556 .await
2557 .send_contact_request(requester_id, responder_id)
2558 .await?;
2559
2560 // Update outgoing contact requests of requester
2561 let mut update = proto::UpdateContacts::default();
2562 update.outgoing_requests.push(responder_id.to_proto());
2563 for connection_id in session
2564 .connection_pool()
2565 .await
2566 .user_connection_ids(requester_id)
2567 {
2568 session.peer.send(connection_id, update.clone())?;
2569 }
2570
2571 // Update incoming contact requests of responder
2572 let mut update = proto::UpdateContacts::default();
2573 update
2574 .incoming_requests
2575 .push(proto::IncomingContactRequest {
2576 requester_id: requester_id.to_proto(),
2577 });
2578 let connection_pool = session.connection_pool().await;
2579 for connection_id in connection_pool.user_connection_ids(responder_id) {
2580 session.peer.send(connection_id, update.clone())?;
2581 }
2582
2583 send_notifications(&connection_pool, &session.peer, notifications);
2584
2585 response.send(proto::Ack {})?;
2586 Ok(())
2587}
2588
2589/// Accept or decline a contact request
2590async fn respond_to_contact_request(
2591 request: proto::RespondToContactRequest,
2592 response: Response<proto::RespondToContactRequest>,
2593 session: Session,
2594) -> Result<()> {
2595 let responder_id = session.user_id();
2596 let requester_id = UserId::from_proto(request.requester_id);
2597 let db = session.db().await;
2598 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2599 db.dismiss_contact_notification(responder_id, requester_id)
2600 .await?;
2601 } else {
2602 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2603
2604 let notifications = db
2605 .respond_to_contact_request(responder_id, requester_id, accept)
2606 .await?;
2607 let requester_busy = db.is_user_busy(requester_id).await?;
2608 let responder_busy = db.is_user_busy(responder_id).await?;
2609
2610 let pool = session.connection_pool().await;
2611 // Update responder with new contact
2612 let mut update = proto::UpdateContacts::default();
2613 if accept {
2614 update
2615 .contacts
2616 .push(contact_for_user(requester_id, requester_busy, &pool));
2617 }
2618 update
2619 .remove_incoming_requests
2620 .push(requester_id.to_proto());
2621 for connection_id in pool.user_connection_ids(responder_id) {
2622 session.peer.send(connection_id, update.clone())?;
2623 }
2624
2625 // Update requester with new contact
2626 let mut update = proto::UpdateContacts::default();
2627 if accept {
2628 update
2629 .contacts
2630 .push(contact_for_user(responder_id, responder_busy, &pool));
2631 }
2632 update
2633 .remove_outgoing_requests
2634 .push(responder_id.to_proto());
2635
2636 for connection_id in pool.user_connection_ids(requester_id) {
2637 session.peer.send(connection_id, update.clone())?;
2638 }
2639
2640 send_notifications(&pool, &session.peer, notifications);
2641 }
2642
2643 response.send(proto::Ack {})?;
2644 Ok(())
2645}
2646
2647/// Remove a contact.
2648async fn remove_contact(
2649 request: proto::RemoveContact,
2650 response: Response<proto::RemoveContact>,
2651 session: Session,
2652) -> Result<()> {
2653 let requester_id = session.user_id();
2654 let responder_id = UserId::from_proto(request.user_id);
2655 let db = session.db().await;
2656 let (contact_accepted, deleted_notification_id) =
2657 db.remove_contact(requester_id, responder_id).await?;
2658
2659 let pool = session.connection_pool().await;
2660 // Update outgoing contact requests of requester
2661 let mut update = proto::UpdateContacts::default();
2662 if contact_accepted {
2663 update.remove_contacts.push(responder_id.to_proto());
2664 } else {
2665 update
2666 .remove_outgoing_requests
2667 .push(responder_id.to_proto());
2668 }
2669 for connection_id in pool.user_connection_ids(requester_id) {
2670 session.peer.send(connection_id, update.clone())?;
2671 }
2672
2673 // Update incoming contact requests of responder
2674 let mut update = proto::UpdateContacts::default();
2675 if contact_accepted {
2676 update.remove_contacts.push(requester_id.to_proto());
2677 } else {
2678 update
2679 .remove_incoming_requests
2680 .push(requester_id.to_proto());
2681 }
2682 for connection_id in pool.user_connection_ids(responder_id) {
2683 session.peer.send(connection_id, update.clone())?;
2684 if let Some(notification_id) = deleted_notification_id {
2685 session.peer.send(
2686 connection_id,
2687 proto::DeleteNotification {
2688 notification_id: notification_id.to_proto(),
2689 },
2690 )?;
2691 }
2692 }
2693
2694 response.send(proto::Ack {})?;
2695 Ok(())
2696}
2697
2698fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2699 version.0.minor() < 139
2700}
2701
2702async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
2703 let plan = session.current_plan(&session.db().await).await?;
2704
2705 session
2706 .peer
2707 .send(
2708 session.connection_id,
2709 proto::UpdateUserPlan { plan: plan.into() },
2710 )
2711 .trace_err();
2712
2713 Ok(())
2714}
2715
2716async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2717 subscribe_user_to_channels(session.user_id(), &session).await?;
2718 Ok(())
2719}
2720
2721async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2722 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2723 let mut pool = session.connection_pool().await;
2724 for membership in &channels_for_user.channel_memberships {
2725 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2726 }
2727 session.peer.send(
2728 session.connection_id,
2729 build_update_user_channels(&channels_for_user),
2730 )?;
2731 session.peer.send(
2732 session.connection_id,
2733 build_channels_update(channels_for_user),
2734 )?;
2735 Ok(())
2736}
2737
2738/// Creates a new channel.
2739async fn create_channel(
2740 request: proto::CreateChannel,
2741 response: Response<proto::CreateChannel>,
2742 session: Session,
2743) -> Result<()> {
2744 let db = session.db().await;
2745
2746 let parent_id = request.parent_id.map(ChannelId::from_proto);
2747 let (channel, membership) = db
2748 .create_channel(&request.name, parent_id, session.user_id())
2749 .await?;
2750
2751 let root_id = channel.root_id();
2752 let channel = Channel::from_model(channel);
2753
2754 response.send(proto::CreateChannelResponse {
2755 channel: Some(channel.to_proto()),
2756 parent_id: request.parent_id,
2757 })?;
2758
2759 let mut connection_pool = session.connection_pool().await;
2760 if let Some(membership) = membership {
2761 connection_pool.subscribe_to_channel(
2762 membership.user_id,
2763 membership.channel_id,
2764 membership.role,
2765 );
2766 let update = proto::UpdateUserChannels {
2767 channel_memberships: vec![proto::ChannelMembership {
2768 channel_id: membership.channel_id.to_proto(),
2769 role: membership.role.into(),
2770 }],
2771 ..Default::default()
2772 };
2773 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2774 session.peer.send(connection_id, update.clone())?;
2775 }
2776 }
2777
2778 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2779 if !role.can_see_channel(channel.visibility) {
2780 continue;
2781 }
2782
2783 let update = proto::UpdateChannels {
2784 channels: vec![channel.to_proto()],
2785 ..Default::default()
2786 };
2787 session.peer.send(connection_id, update.clone())?;
2788 }
2789
2790 Ok(())
2791}
2792
2793/// Delete a channel
2794async fn delete_channel(
2795 request: proto::DeleteChannel,
2796 response: Response<proto::DeleteChannel>,
2797 session: Session,
2798) -> Result<()> {
2799 let db = session.db().await;
2800
2801 let channel_id = request.channel_id;
2802 let (root_channel, removed_channels) = db
2803 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2804 .await?;
2805 response.send(proto::Ack {})?;
2806
2807 // Notify members of removed channels
2808 let mut update = proto::UpdateChannels::default();
2809 update
2810 .delete_channels
2811 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2812
2813 let connection_pool = session.connection_pool().await;
2814 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2815 session.peer.send(connection_id, update.clone())?;
2816 }
2817
2818 Ok(())
2819}
2820
2821/// Invite someone to join a channel.
2822async fn invite_channel_member(
2823 request: proto::InviteChannelMember,
2824 response: Response<proto::InviteChannelMember>,
2825 session: Session,
2826) -> Result<()> {
2827 let db = session.db().await;
2828 let channel_id = ChannelId::from_proto(request.channel_id);
2829 let invitee_id = UserId::from_proto(request.user_id);
2830 let InviteMemberResult {
2831 channel,
2832 notifications,
2833 } = db
2834 .invite_channel_member(
2835 channel_id,
2836 invitee_id,
2837 session.user_id(),
2838 request.role().into(),
2839 )
2840 .await?;
2841
2842 let update = proto::UpdateChannels {
2843 channel_invitations: vec![channel.to_proto()],
2844 ..Default::default()
2845 };
2846
2847 let connection_pool = session.connection_pool().await;
2848 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2849 session.peer.send(connection_id, update.clone())?;
2850 }
2851
2852 send_notifications(&connection_pool, &session.peer, notifications);
2853
2854 response.send(proto::Ack {})?;
2855 Ok(())
2856}
2857
2858/// remove someone from a channel
2859async fn remove_channel_member(
2860 request: proto::RemoveChannelMember,
2861 response: Response<proto::RemoveChannelMember>,
2862 session: Session,
2863) -> Result<()> {
2864 let db = session.db().await;
2865 let channel_id = ChannelId::from_proto(request.channel_id);
2866 let member_id = UserId::from_proto(request.user_id);
2867
2868 let RemoveChannelMemberResult {
2869 membership_update,
2870 notification_id,
2871 } = db
2872 .remove_channel_member(channel_id, member_id, session.user_id())
2873 .await?;
2874
2875 let mut connection_pool = session.connection_pool().await;
2876 notify_membership_updated(
2877 &mut connection_pool,
2878 membership_update,
2879 member_id,
2880 &session.peer,
2881 );
2882 for connection_id in connection_pool.user_connection_ids(member_id) {
2883 if let Some(notification_id) = notification_id {
2884 session
2885 .peer
2886 .send(
2887 connection_id,
2888 proto::DeleteNotification {
2889 notification_id: notification_id.to_proto(),
2890 },
2891 )
2892 .trace_err();
2893 }
2894 }
2895
2896 response.send(proto::Ack {})?;
2897 Ok(())
2898}
2899
2900/// Toggle the channel between public and private.
2901/// Care is taken to maintain the invariant that public channels only descend from public channels,
2902/// (though members-only channels can appear at any point in the hierarchy).
2903async fn set_channel_visibility(
2904 request: proto::SetChannelVisibility,
2905 response: Response<proto::SetChannelVisibility>,
2906 session: Session,
2907) -> Result<()> {
2908 let db = session.db().await;
2909 let channel_id = ChannelId::from_proto(request.channel_id);
2910 let visibility = request.visibility().into();
2911
2912 let channel_model = db
2913 .set_channel_visibility(channel_id, visibility, session.user_id())
2914 .await?;
2915 let root_id = channel_model.root_id();
2916 let channel = Channel::from_model(channel_model);
2917
2918 let mut connection_pool = session.connection_pool().await;
2919 for (user_id, role) in connection_pool
2920 .channel_user_ids(root_id)
2921 .collect::<Vec<_>>()
2922 .into_iter()
2923 {
2924 let update = if role.can_see_channel(channel.visibility) {
2925 connection_pool.subscribe_to_channel(user_id, channel_id, role);
2926 proto::UpdateChannels {
2927 channels: vec![channel.to_proto()],
2928 ..Default::default()
2929 }
2930 } else {
2931 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2932 proto::UpdateChannels {
2933 delete_channels: vec![channel.id.to_proto()],
2934 ..Default::default()
2935 }
2936 };
2937
2938 for connection_id in connection_pool.user_connection_ids(user_id) {
2939 session.peer.send(connection_id, update.clone())?;
2940 }
2941 }
2942
2943 response.send(proto::Ack {})?;
2944 Ok(())
2945}
2946
2947/// Alter the role for a user in the channel.
2948async fn set_channel_member_role(
2949 request: proto::SetChannelMemberRole,
2950 response: Response<proto::SetChannelMemberRole>,
2951 session: Session,
2952) -> Result<()> {
2953 let db = session.db().await;
2954 let channel_id = ChannelId::from_proto(request.channel_id);
2955 let member_id = UserId::from_proto(request.user_id);
2956 let result = db
2957 .set_channel_member_role(
2958 channel_id,
2959 session.user_id(),
2960 member_id,
2961 request.role().into(),
2962 )
2963 .await?;
2964
2965 match result {
2966 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2967 let mut connection_pool = session.connection_pool().await;
2968 notify_membership_updated(
2969 &mut connection_pool,
2970 membership_update,
2971 member_id,
2972 &session.peer,
2973 )
2974 }
2975 db::SetMemberRoleResult::InviteUpdated(channel) => {
2976 let update = proto::UpdateChannels {
2977 channel_invitations: vec![channel.to_proto()],
2978 ..Default::default()
2979 };
2980
2981 for connection_id in session
2982 .connection_pool()
2983 .await
2984 .user_connection_ids(member_id)
2985 {
2986 session.peer.send(connection_id, update.clone())?;
2987 }
2988 }
2989 }
2990
2991 response.send(proto::Ack {})?;
2992 Ok(())
2993}
2994
2995/// Change the name of a channel
2996async fn rename_channel(
2997 request: proto::RenameChannel,
2998 response: Response<proto::RenameChannel>,
2999 session: Session,
3000) -> Result<()> {
3001 let db = session.db().await;
3002 let channel_id = ChannelId::from_proto(request.channel_id);
3003 let channel_model = db
3004 .rename_channel(channel_id, session.user_id(), &request.name)
3005 .await?;
3006 let root_id = channel_model.root_id();
3007 let channel = Channel::from_model(channel_model);
3008
3009 response.send(proto::RenameChannelResponse {
3010 channel: Some(channel.to_proto()),
3011 })?;
3012
3013 let connection_pool = session.connection_pool().await;
3014 let update = proto::UpdateChannels {
3015 channels: vec![channel.to_proto()],
3016 ..Default::default()
3017 };
3018 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3019 if role.can_see_channel(channel.visibility) {
3020 session.peer.send(connection_id, update.clone())?;
3021 }
3022 }
3023
3024 Ok(())
3025}
3026
3027/// Move a channel to a new parent.
3028async fn move_channel(
3029 request: proto::MoveChannel,
3030 response: Response<proto::MoveChannel>,
3031 session: Session,
3032) -> Result<()> {
3033 let channel_id = ChannelId::from_proto(request.channel_id);
3034 let to = ChannelId::from_proto(request.to);
3035
3036 let (root_id, channels) = session
3037 .db()
3038 .await
3039 .move_channel(channel_id, to, session.user_id())
3040 .await?;
3041
3042 let connection_pool = session.connection_pool().await;
3043 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3044 let channels = channels
3045 .iter()
3046 .filter_map(|channel| {
3047 if role.can_see_channel(channel.visibility) {
3048 Some(channel.to_proto())
3049 } else {
3050 None
3051 }
3052 })
3053 .collect::<Vec<_>>();
3054 if channels.is_empty() {
3055 continue;
3056 }
3057
3058 let update = proto::UpdateChannels {
3059 channels,
3060 ..Default::default()
3061 };
3062
3063 session.peer.send(connection_id, update.clone())?;
3064 }
3065
3066 response.send(Ack {})?;
3067 Ok(())
3068}
3069
3070/// Get the list of channel members
3071async fn get_channel_members(
3072 request: proto::GetChannelMembers,
3073 response: Response<proto::GetChannelMembers>,
3074 session: Session,
3075) -> Result<()> {
3076 let db = session.db().await;
3077 let channel_id = ChannelId::from_proto(request.channel_id);
3078 let limit = if request.limit == 0 {
3079 u16::MAX as u64
3080 } else {
3081 request.limit
3082 };
3083 let (members, users) = db
3084 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3085 .await?;
3086 response.send(proto::GetChannelMembersResponse { members, users })?;
3087 Ok(())
3088}
3089
3090/// Accept or decline a channel invitation.
3091async fn respond_to_channel_invite(
3092 request: proto::RespondToChannelInvite,
3093 response: Response<proto::RespondToChannelInvite>,
3094 session: Session,
3095) -> Result<()> {
3096 let db = session.db().await;
3097 let channel_id = ChannelId::from_proto(request.channel_id);
3098 let RespondToChannelInvite {
3099 membership_update,
3100 notifications,
3101 } = db
3102 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3103 .await?;
3104
3105 let mut connection_pool = session.connection_pool().await;
3106 if let Some(membership_update) = membership_update {
3107 notify_membership_updated(
3108 &mut connection_pool,
3109 membership_update,
3110 session.user_id(),
3111 &session.peer,
3112 );
3113 } else {
3114 let update = proto::UpdateChannels {
3115 remove_channel_invitations: vec![channel_id.to_proto()],
3116 ..Default::default()
3117 };
3118
3119 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3120 session.peer.send(connection_id, update.clone())?;
3121 }
3122 };
3123
3124 send_notifications(&connection_pool, &session.peer, notifications);
3125
3126 response.send(proto::Ack {})?;
3127
3128 Ok(())
3129}
3130
3131/// Join the channels' room
3132async fn join_channel(
3133 request: proto::JoinChannel,
3134 response: Response<proto::JoinChannel>,
3135 session: Session,
3136) -> Result<()> {
3137 let channel_id = ChannelId::from_proto(request.channel_id);
3138 join_channel_internal(channel_id, Box::new(response), session).await
3139}
3140
/// A response handle that can complete a channel join.
///
/// It is implemented for the responses of both `proto::JoinChannel` and
/// `proto::JoinRoom`, which both reply with a `proto::JoinRoomResponse`. This
/// lets `join_channel_internal` serve either request.
trait JoinChannelInternalResponse {
    /// Send the join result back to the requester.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    // Delegates to `Response::<proto::JoinChannel>::send`.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    // Delegates to `Response::<proto::JoinRoom>::send`.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3154
/// Shared implementation for joining a channel's room.
///
/// Steps:
/// 1. If the database reports a stale room connection for this user, leave
///    that room first.
/// 2. Join the channel in the database and build LiveKit connection info.
///    Guests get a `guest_token` and `can_publish: false`; every other role
///    gets a `room_token` and `can_publish: true`. The info is `None` when no
///    LiveKit client is configured or token creation fails.
/// 3. Reply through `response`, propagate any membership change, and
///    broadcast the room update.
/// 4. After the inner guards are released, broadcast the channel update and
///    refresh the user's contacts.
///
/// Errors if the database does not return a channel for the joined room.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If Zed quits without leaving the room and the user reopens Zed before
        // RECONNECT_TIMEOUT elapses, we need to make sure that we kick the user out
        // of the previous room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database guard while leaving the stale room, then
            // re-acquire it (presumably because `leave_room_for_session` takes
            // the database lock itself — TODO confirm).
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when there is no LiveKit client, or when token creation fails
        // (`trace_err()?` short-circuits the closure).
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        // Reply before notifying anyone else.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        // `db` and `connection_pool` are dropped at the end of this block,
        // before the pool is locked again below.
        joined_room
    };

    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3250
3251/// Start editing the channel notes
3252async fn join_channel_buffer(
3253 request: proto::JoinChannelBuffer,
3254 response: Response<proto::JoinChannelBuffer>,
3255 session: Session,
3256) -> Result<()> {
3257 let db = session.db().await;
3258 let channel_id = ChannelId::from_proto(request.channel_id);
3259
3260 let open_response = db
3261 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3262 .await?;
3263
3264 let collaborators = open_response.collaborators.clone();
3265 response.send(open_response)?;
3266
3267 let update = UpdateChannelBufferCollaborators {
3268 channel_id: channel_id.to_proto(),
3269 collaborators: collaborators.clone(),
3270 };
3271 channel_buffer_updated(
3272 session.connection_id,
3273 collaborators
3274 .iter()
3275 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3276 &update,
3277 &session.peer,
3278 );
3279
3280 Ok(())
3281}
3282
3283/// Edit the channel notes
3284async fn update_channel_buffer(
3285 request: proto::UpdateChannelBuffer,
3286 session: Session,
3287) -> Result<()> {
3288 let db = session.db().await;
3289 let channel_id = ChannelId::from_proto(request.channel_id);
3290
3291 let (collaborators, epoch, version) = db
3292 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3293 .await?;
3294
3295 channel_buffer_updated(
3296 session.connection_id,
3297 collaborators.clone(),
3298 &proto::UpdateChannelBuffer {
3299 channel_id: channel_id.to_proto(),
3300 operations: request.operations,
3301 },
3302 &session.peer,
3303 );
3304
3305 let pool = &*session.connection_pool().await;
3306
3307 let non_collaborators =
3308 pool.channel_connection_ids(channel_id)
3309 .filter_map(|(connection_id, _)| {
3310 if collaborators.contains(&connection_id) {
3311 None
3312 } else {
3313 Some(connection_id)
3314 }
3315 });
3316
3317 broadcast(None, non_collaborators, |peer_id| {
3318 session.peer.send(
3319 peer_id,
3320 proto::UpdateChannels {
3321 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3322 channel_id: channel_id.to_proto(),
3323 epoch: epoch as u64,
3324 version: version.clone(),
3325 }],
3326 ..Default::default()
3327 },
3328 )
3329 });
3330
3331 Ok(())
3332}
3333
3334/// Rejoin the channel notes after a connection blip
3335async fn rejoin_channel_buffers(
3336 request: proto::RejoinChannelBuffers,
3337 response: Response<proto::RejoinChannelBuffers>,
3338 session: Session,
3339) -> Result<()> {
3340 let db = session.db().await;
3341 let buffers = db
3342 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3343 .await?;
3344
3345 for rejoined_buffer in &buffers {
3346 let collaborators_to_notify = rejoined_buffer
3347 .buffer
3348 .collaborators
3349 .iter()
3350 .filter_map(|c| Some(c.peer_id?.into()));
3351 channel_buffer_updated(
3352 session.connection_id,
3353 collaborators_to_notify,
3354 &proto::UpdateChannelBufferCollaborators {
3355 channel_id: rejoined_buffer.buffer.channel_id,
3356 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3357 },
3358 &session.peer,
3359 );
3360 }
3361
3362 response.send(proto::RejoinChannelBuffersResponse {
3363 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3364 })?;
3365
3366 Ok(())
3367}
3368
3369/// Stop editing the channel notes
3370async fn leave_channel_buffer(
3371 request: proto::LeaveChannelBuffer,
3372 response: Response<proto::LeaveChannelBuffer>,
3373 session: Session,
3374) -> Result<()> {
3375 let db = session.db().await;
3376 let channel_id = ChannelId::from_proto(request.channel_id);
3377
3378 let left_buffer = db
3379 .leave_channel_buffer(channel_id, session.connection_id)
3380 .await?;
3381
3382 response.send(Ack {})?;
3383
3384 channel_buffer_updated(
3385 session.connection_id,
3386 left_buffer.connections,
3387 &proto::UpdateChannelBufferCollaborators {
3388 channel_id: channel_id.to_proto(),
3389 collaborators: left_buffer.collaborators,
3390 },
3391 &session.peer,
3392 );
3393
3394 Ok(())
3395}
3396
3397fn channel_buffer_updated<T: EnvelopedMessage>(
3398 sender_id: ConnectionId,
3399 collaborators: impl IntoIterator<Item = ConnectionId>,
3400 message: &T,
3401 peer: &Peer,
3402) {
3403 broadcast(Some(sender_id), collaborators, |peer_id| {
3404 peer.send(peer_id, message.clone())
3405 });
3406}
3407
3408fn send_notifications(
3409 connection_pool: &ConnectionPool,
3410 peer: &Peer,
3411 notifications: db::NotificationBatch,
3412) {
3413 for (user_id, notification) in notifications {
3414 for connection_id in connection_pool.user_connection_ids(user_id) {
3415 if let Err(error) = peer.send(
3416 connection_id,
3417 proto::AddNotification {
3418 notification: Some(notification.clone()),
3419 },
3420 ) {
3421 tracing::error!(
3422 "failed to send notification to {:?} {}",
3423 connection_id,
3424 error
3425 );
3426 }
3427 }
3428 }
3429}
3430
/// Send a message to the channel
///
/// Handles the request in this order:
/// 1. Trims and validates the body, and requires a nonce.
/// 2. Persists the message.
/// 3. Forwards the message to the other connections returned as chat participants.
/// 4. Replies to the sender with the stored message.
/// 5. Tells the channel's other connections the new latest message id.
/// 6. Delivers the notifications returned by the database.
///
/// # Errors
/// Fails if the trimmed body is longer than `MAX_MESSAGE_LEN` or empty, if
/// the nonce is missing, if the database insert fails, or if the response
/// cannot be sent.
async fn send_channel_message(
    request: proto::SendChannelMessage,
    response: Response<proto::SendChannelMessage>,
    session: Session,
) -> Result<()> {
    // Validate the message body. The limits apply to the trimmed body,
    // which is also what gets stored and echoed back.
    let body = request.body.trim().to_string();
    if body.len() > MAX_MESSAGE_LEN {
        return Err(anyhow!("message is too long"))?;
    }
    if body.is_empty() {
        return Err(anyhow!("message can't be blank"))?;
    }

    // TODO: adjust mentions if body is trimmed
    // NOTE(review): `request.mentions` is stored and echoed unchanged. If its
    // ranges are offsets into the untrimmed body, they are shifted by the
    // amount of leading whitespace removed above.

    let timestamp = OffsetDateTime::now_utc();
    let nonce = request
        .nonce
        .ok_or_else(|| anyhow!("nonce can't be blank"))?;

    let channel_id = ChannelId::from_proto(request.channel_id);
    let CreatedChannelMessage {
        message_id,
        participant_connection_ids,
        notifications,
    } = session
        .db()
        .await
        .create_channel_message(
            channel_id,
            session.user_id(),
            &body,
            &request.mentions,
            timestamp,
            nonce.clone().into(),
            request.reply_to_message_id.map(MessageId::from_proto),
        )
        .await?;

    let message = proto::ChannelMessage {
        sender_id: session.user_id().to_proto(),
        id: message_id.to_proto(),
        body,
        mentions: request.mentions,
        timestamp: timestamp.unix_timestamp() as u64,
        nonce: Some(nonce),
        reply_to_message_id: request.reply_to_message_id,
        edited_at: None,
    };
    // Chat participants (other than the sender) receive the full message.
    broadcast(
        Some(session.connection_id),
        participant_connection_ids.clone(),
        |connection| {
            session.peer.send(
                connection,
                proto::ChannelMessageSent {
                    channel_id: channel_id.to_proto(),
                    message: Some(message.clone()),
                },
            )
        },
    );
    response.send(proto::SendChannelMessageResponse {
        message: Some(message),
    })?;

    // Channel connections outside the chat only learn that a newer message
    // id exists.
    let pool = &*session.connection_pool().await;
    let non_participants =
        pool.channel_connection_ids(channel_id)
            .filter_map(|(connection_id, _)| {
                if participant_connection_ids.contains(&connection_id) {
                    None
                } else {
                    Some(connection_id)
                }
            });
    broadcast(None, non_participants, |peer_id| {
        session.peer.send(
            peer_id,
            proto::UpdateChannels {
                latest_channel_message_ids: vec![proto::ChannelMessageId {
                    channel_id: channel_id.to_proto(),
                    message_id: message_id.to_proto(),
                }],
                ..Default::default()
            },
        )
    });
    send_notifications(pool, &session.peer, notifications);

    Ok(())
}
3525
3526/// Delete a channel message
3527async fn remove_channel_message(
3528 request: proto::RemoveChannelMessage,
3529 response: Response<proto::RemoveChannelMessage>,
3530 session: Session,
3531) -> Result<()> {
3532 let channel_id = ChannelId::from_proto(request.channel_id);
3533 let message_id = MessageId::from_proto(request.message_id);
3534 let (connection_ids, existing_notification_ids) = session
3535 .db()
3536 .await
3537 .remove_channel_message(channel_id, message_id, session.user_id())
3538 .await?;
3539
3540 broadcast(
3541 Some(session.connection_id),
3542 connection_ids,
3543 move |connection| {
3544 session.peer.send(connection, request.clone())?;
3545
3546 for notification_id in &existing_notification_ids {
3547 session.peer.send(
3548 connection,
3549 proto::DeleteNotification {
3550 notification_id: (*notification_id).to_proto(),
3551 },
3552 )?;
3553 }
3554
3555 Ok(())
3556 },
3557 );
3558 response.send(proto::Ack {})?;
3559 Ok(())
3560}
3561
3562async fn update_channel_message(
3563 request: proto::UpdateChannelMessage,
3564 response: Response<proto::UpdateChannelMessage>,
3565 session: Session,
3566) -> Result<()> {
3567 let channel_id = ChannelId::from_proto(request.channel_id);
3568 let message_id = MessageId::from_proto(request.message_id);
3569 let updated_at = OffsetDateTime::now_utc();
3570 let UpdatedChannelMessage {
3571 message_id,
3572 participant_connection_ids,
3573 notifications,
3574 reply_to_message_id,
3575 timestamp,
3576 deleted_mention_notification_ids,
3577 updated_mention_notifications,
3578 } = session
3579 .db()
3580 .await
3581 .update_channel_message(
3582 channel_id,
3583 message_id,
3584 session.user_id(),
3585 request.body.as_str(),
3586 &request.mentions,
3587 updated_at,
3588 )
3589 .await?;
3590
3591 let nonce = request
3592 .nonce
3593 .clone()
3594 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3595
3596 let message = proto::ChannelMessage {
3597 sender_id: session.user_id().to_proto(),
3598 id: message_id.to_proto(),
3599 body: request.body.clone(),
3600 mentions: request.mentions.clone(),
3601 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3602 nonce: Some(nonce),
3603 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3604 edited_at: Some(updated_at.unix_timestamp() as u64),
3605 };
3606
3607 response.send(proto::Ack {})?;
3608
3609 let pool = &*session.connection_pool().await;
3610 broadcast(
3611 Some(session.connection_id),
3612 participant_connection_ids,
3613 |connection| {
3614 session.peer.send(
3615 connection,
3616 proto::ChannelMessageUpdate {
3617 channel_id: channel_id.to_proto(),
3618 message: Some(message.clone()),
3619 },
3620 )?;
3621
3622 for notification_id in &deleted_mention_notification_ids {
3623 session.peer.send(
3624 connection,
3625 proto::DeleteNotification {
3626 notification_id: (*notification_id).to_proto(),
3627 },
3628 )?;
3629 }
3630
3631 for notification in &updated_mention_notifications {
3632 session.peer.send(
3633 connection,
3634 proto::UpdateNotification {
3635 notification: Some(notification.clone()),
3636 },
3637 )?;
3638 }
3639
3640 Ok(())
3641 },
3642 );
3643
3644 send_notifications(pool, &session.peer, notifications);
3645
3646 Ok(())
3647}
3648
3649/// Mark a channel message as read
3650async fn acknowledge_channel_message(
3651 request: proto::AckChannelMessage,
3652 session: Session,
3653) -> Result<()> {
3654 let channel_id = ChannelId::from_proto(request.channel_id);
3655 let message_id = MessageId::from_proto(request.message_id);
3656 let notifications = session
3657 .db()
3658 .await
3659 .observe_channel_message(channel_id, session.user_id(), message_id)
3660 .await?;
3661 send_notifications(
3662 &*session.connection_pool().await,
3663 &session.peer,
3664 notifications,
3665 );
3666 Ok(())
3667}
3668
3669/// Mark a buffer version as synced
3670async fn acknowledge_buffer_version(
3671 request: proto::AckBufferOperation,
3672 session: Session,
3673) -> Result<()> {
3674 let buffer_id = BufferId::from_proto(request.buffer_id);
3675 session
3676 .db()
3677 .await
3678 .observe_buffer_version(
3679 buffer_id,
3680 session.user_id(),
3681 request.epoch as i32,
3682 &request.version,
3683 )
3684 .await?;
3685 Ok(())
3686}
3687
3688async fn count_language_model_tokens(
3689 request: proto::CountLanguageModelTokens,
3690 response: Response<proto::CountLanguageModelTokens>,
3691 session: Session,
3692 config: &Config,
3693) -> Result<()> {
3694 authorize_access_to_legacy_llm_endpoints(&session).await?;
3695
3696 let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3697 proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
3698 proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
3699 };
3700
3701 session
3702 .app_state
3703 .rate_limiter
3704 .check(&*rate_limit, session.user_id())
3705 .await?;
3706
3707 let result = match proto::LanguageModelProvider::from_i32(request.provider) {
3708 Some(proto::LanguageModelProvider::Google) => {
3709 let api_key = config
3710 .google_ai_api_key
3711 .as_ref()
3712 .context("no Google AI API key configured on the server")?;
3713 google_ai::count_tokens(
3714 session.http_client.as_ref(),
3715 google_ai::API_URL,
3716 api_key,
3717 serde_json::from_str(&request.request)?,
3718 )
3719 .await?
3720 }
3721 _ => return Err(anyhow!("unsupported provider"))?,
3722 };
3723
3724 response.send(proto::CountLanguageModelTokensResponse {
3725 token_count: result.total_tokens as u32,
3726 })?;
3727
3728 Ok(())
3729}
3730
3731struct ZedProCountLanguageModelTokensRateLimit;
3732
3733impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
3734 fn capacity(&self) -> usize {
3735 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
3736 .ok()
3737 .and_then(|v| v.parse().ok())
3738 .unwrap_or(600) // Picked arbitrarily
3739 }
3740
3741 fn refill_duration(&self) -> chrono::Duration {
3742 chrono::Duration::hours(1)
3743 }
3744
3745 fn db_name(&self) -> &'static str {
3746 "zed-pro:count-language-model-tokens"
3747 }
3748}
3749
3750struct FreeCountLanguageModelTokensRateLimit;
3751
3752impl RateLimit for FreeCountLanguageModelTokensRateLimit {
3753 fn capacity(&self) -> usize {
3754 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
3755 .ok()
3756 .and_then(|v| v.parse().ok())
3757 .unwrap_or(600 / 10) // Picked arbitrarily
3758 }
3759
3760 fn refill_duration(&self) -> chrono::Duration {
3761 chrono::Duration::hours(1)
3762 }
3763
3764 fn db_name(&self) -> &'static str {
3765 "free:count-language-model-tokens"
3766 }
3767}
3768
3769struct ZedProComputeEmbeddingsRateLimit;
3770
3771impl RateLimit for ZedProComputeEmbeddingsRateLimit {
3772 fn capacity(&self) -> usize {
3773 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
3774 .ok()
3775 .and_then(|v| v.parse().ok())
3776 .unwrap_or(5000) // Picked arbitrarily
3777 }
3778
3779 fn refill_duration(&self) -> chrono::Duration {
3780 chrono::Duration::hours(1)
3781 }
3782
3783 fn db_name(&self) -> &'static str {
3784 "zed-pro:compute-embeddings"
3785 }
3786}
3787
3788struct FreeComputeEmbeddingsRateLimit;
3789
3790impl RateLimit for FreeComputeEmbeddingsRateLimit {
3791 fn capacity(&self) -> usize {
3792 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
3793 .ok()
3794 .and_then(|v| v.parse().ok())
3795 .unwrap_or(5000 / 10) // Picked arbitrarily
3796 }
3797
3798 fn refill_duration(&self) -> chrono::Duration {
3799 chrono::Duration::hours(1)
3800 }
3801
3802 fn db_name(&self) -> &'static str {
3803 "free:compute-embeddings"
3804 }
3805}
3806
3807async fn compute_embeddings(
3808 request: proto::ComputeEmbeddings,
3809 response: Response<proto::ComputeEmbeddings>,
3810 session: Session,
3811 api_key: Option<Arc<str>>,
3812) -> Result<()> {
3813 let api_key = api_key.context("no OpenAI API key configured on the server")?;
3814 authorize_access_to_legacy_llm_endpoints(&session).await?;
3815
3816 let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3817 proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
3818 proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
3819 };
3820
3821 session
3822 .app_state
3823 .rate_limiter
3824 .check(&*rate_limit, session.user_id())
3825 .await?;
3826
3827 let embeddings = match request.model.as_str() {
3828 "openai/text-embedding-3-small" => {
3829 open_ai::embed(
3830 session.http_client.as_ref(),
3831 OPEN_AI_API_URL,
3832 &api_key,
3833 OpenAiEmbeddingModel::TextEmbedding3Small,
3834 request.texts.iter().map(|text| text.as_str()),
3835 )
3836 .await?
3837 }
3838 provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
3839 };
3840
3841 let embeddings = request
3842 .texts
3843 .iter()
3844 .map(|text| {
3845 let mut hasher = sha2::Sha256::new();
3846 hasher.update(text.as_bytes());
3847 let result = hasher.finalize();
3848 result.to_vec()
3849 })
3850 .zip(
3851 embeddings
3852 .data
3853 .into_iter()
3854 .map(|embedding| embedding.embedding),
3855 )
3856 .collect::<HashMap<_, _>>();
3857
3858 let db = session.db().await;
3859 db.save_embeddings(&request.model, &embeddings)
3860 .await
3861 .context("failed to save embeddings")
3862 .trace_err();
3863
3864 response.send(proto::ComputeEmbeddingsResponse {
3865 embeddings: embeddings
3866 .into_iter()
3867 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3868 .collect(),
3869 })?;
3870 Ok(())
3871}
3872
3873async fn get_cached_embeddings(
3874 request: proto::GetCachedEmbeddings,
3875 response: Response<proto::GetCachedEmbeddings>,
3876 session: Session,
3877) -> Result<()> {
3878 authorize_access_to_legacy_llm_endpoints(&session).await?;
3879
3880 let db = session.db().await;
3881 let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
3882
3883 response.send(proto::GetCachedEmbeddingsResponse {
3884 embeddings: embeddings
3885 .into_iter()
3886 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3887 .collect(),
3888 })?;
3889 Ok(())
3890}
3891
3892/// This is leftover from before the LLM service.
3893///
3894/// The endpoints protected by this check will be moved there eventually.
3895async fn authorize_access_to_legacy_llm_endpoints(session: &Session) -> Result<(), Error> {
3896 if session.is_staff() {
3897 Ok(())
3898 } else {
3899 Err(anyhow!("permission denied"))?
3900 }
3901}
3902
3903/// Get a Supermaven API key for the user
3904async fn get_supermaven_api_key(
3905 _request: proto::GetSupermavenApiKey,
3906 response: Response<proto::GetSupermavenApiKey>,
3907 session: Session,
3908) -> Result<()> {
3909 let user_id: String = session.user_id().to_string();
3910 if !session.is_staff() {
3911 return Err(anyhow!("supermaven not enabled for this account"))?;
3912 }
3913
3914 let email = session
3915 .email()
3916 .ok_or_else(|| anyhow!("user must have an email"))?;
3917
3918 let supermaven_admin_api = session
3919 .supermaven_client
3920 .as_ref()
3921 .ok_or_else(|| anyhow!("supermaven not configured"))?;
3922
3923 let result = supermaven_admin_api
3924 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3925 .await?;
3926
3927 response.send(proto::GetSupermavenApiKeyResponse {
3928 api_key: result.api_key,
3929 })?;
3930
3931 Ok(())
3932}
3933
3934/// Start receiving chat updates for a channel
3935async fn join_channel_chat(
3936 request: proto::JoinChannelChat,
3937 response: Response<proto::JoinChannelChat>,
3938 session: Session,
3939) -> Result<()> {
3940 let channel_id = ChannelId::from_proto(request.channel_id);
3941
3942 let db = session.db().await;
3943 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3944 .await?;
3945 let messages = db
3946 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3947 .await?;
3948 response.send(proto::JoinChannelChatResponse {
3949 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3950 messages,
3951 })?;
3952 Ok(())
3953}
3954
3955/// Stop receiving chat updates for a channel
3956async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3957 let channel_id = ChannelId::from_proto(request.channel_id);
3958 session
3959 .db()
3960 .await
3961 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3962 .await?;
3963 Ok(())
3964}
3965
3966/// Retrieve the chat history for a channel
3967async fn get_channel_messages(
3968 request: proto::GetChannelMessages,
3969 response: Response<proto::GetChannelMessages>,
3970 session: Session,
3971) -> Result<()> {
3972 let channel_id = ChannelId::from_proto(request.channel_id);
3973 let messages = session
3974 .db()
3975 .await
3976 .get_channel_messages(
3977 channel_id,
3978 session.user_id(),
3979 MESSAGE_COUNT_PER_PAGE,
3980 Some(MessageId::from_proto(request.before_message_id)),
3981 )
3982 .await?;
3983 response.send(proto::GetChannelMessagesResponse {
3984 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3985 messages,
3986 })?;
3987 Ok(())
3988}
3989
3990/// Retrieve specific chat messages
3991async fn get_channel_messages_by_id(
3992 request: proto::GetChannelMessagesById,
3993 response: Response<proto::GetChannelMessagesById>,
3994 session: Session,
3995) -> Result<()> {
3996 let message_ids = request
3997 .message_ids
3998 .iter()
3999 .map(|id| MessageId::from_proto(*id))
4000 .collect::<Vec<_>>();
4001 let messages = session
4002 .db()
4003 .await
4004 .get_channel_messages_by_id(session.user_id(), &message_ids)
4005 .await?;
4006 response.send(proto::GetChannelMessagesResponse {
4007 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4008 messages,
4009 })?;
4010 Ok(())
4011}
4012
4013/// Retrieve the current users notifications
4014async fn get_notifications(
4015 request: proto::GetNotifications,
4016 response: Response<proto::GetNotifications>,
4017 session: Session,
4018) -> Result<()> {
4019 let notifications = session
4020 .db()
4021 .await
4022 .get_notifications(
4023 session.user_id(),
4024 NOTIFICATION_COUNT_PER_PAGE,
4025 request.before_id.map(db::NotificationId::from_proto),
4026 )
4027 .await?;
4028 response.send(proto::GetNotificationsResponse {
4029 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4030 notifications,
4031 })?;
4032 Ok(())
4033}
4034
4035/// Mark notifications as read
4036async fn mark_notification_as_read(
4037 request: proto::MarkNotificationRead,
4038 response: Response<proto::MarkNotificationRead>,
4039 session: Session,
4040) -> Result<()> {
4041 let database = &session.db().await;
4042 let notifications = database
4043 .mark_notification_as_read_by_id(
4044 session.user_id(),
4045 NotificationId::from_proto(request.notification_id),
4046 )
4047 .await?;
4048 send_notifications(
4049 &*session.connection_pool().await,
4050 &session.peer,
4051 notifications,
4052 );
4053 response.send(proto::Ack {})?;
4054 Ok(())
4055}
4056
4057/// Get the current users information
4058async fn get_private_user_info(
4059 _request: proto::GetPrivateUserInfo,
4060 response: Response<proto::GetPrivateUserInfo>,
4061 session: Session,
4062) -> Result<()> {
4063 let db = session.db().await;
4064
4065 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4066 let user = db
4067 .get_user_by_id(session.user_id())
4068 .await?
4069 .ok_or_else(|| anyhow!("user not found"))?;
4070 let flags = db.get_user_flags(session.user_id()).await?;
4071
4072 response.send(proto::GetPrivateUserInfoResponse {
4073 metrics_id,
4074 staff: user.admin,
4075 flags,
4076 accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4077 })?;
4078 Ok(())
4079}
4080
4081/// Accept the terms of service (tos) on behalf of the current user
4082async fn accept_terms_of_service(
4083 _request: proto::AcceptTermsOfService,
4084 response: Response<proto::AcceptTermsOfService>,
4085 session: Session,
4086) -> Result<()> {
4087 let db = session.db().await;
4088
4089 let accepted_tos_at = Utc::now();
4090 db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4091 .await?;
4092
4093 response.send(proto::AcceptTermsOfServiceResponse {
4094 accepted_tos_at: accepted_tos_at.timestamp() as u64,
4095 })?;
4096 Ok(())
4097}
4098
/// The minimum account age an account must have in order to use the LLM service.
///
/// Currently 30 days.
pub const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4101
4102async fn get_llm_api_token(
4103 _request: proto::GetLlmToken,
4104 response: Response<proto::GetLlmToken>,
4105 session: Session,
4106) -> Result<()> {
4107 let db = session.db().await;
4108
4109 let flags = db.get_user_flags(session.user_id()).await?;
4110 let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
4111
4112 if !session.is_staff() && !has_language_models_feature_flag {
4113 Err(anyhow!("permission denied"))?
4114 }
4115
4116 let user_id = session.user_id();
4117 let user = db
4118 .get_user_by_id(user_id)
4119 .await?
4120 .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4121
4122 if user.accepted_tos_at.is_none() {
4123 Err(anyhow!("terms of service not accepted"))?
4124 }
4125
4126 let has_llm_subscription = session.has_llm_subscription(&db).await?;
4127 let billing_preferences = db.get_billing_preferences(user.id).await?;
4128
4129 let token = LlmTokenClaims::create(
4130 &user,
4131 session.is_staff(),
4132 billing_preferences,
4133 &flags,
4134 has_llm_subscription,
4135 session.current_plan(&db).await?,
4136 session.system_id.clone(),
4137 &session.app_state.config,
4138 )?;
4139 response.send(proto::GetLlmTokenResponse { token })?;
4140 Ok(())
4141}
4142
4143fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4144 let message = match message {
4145 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4146 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4147 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4148 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4149 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4150 code: frame.code.into(),
4151 reason: frame.reason,
4152 })),
4153 // We should never receive a frame while reading the message, according
4154 // to the `tungstenite` maintainers:
4155 //
4156 // > It cannot occur when you read messages from the WebSocket, but it
4157 // > can be used when you want to send the raw frames (e.g. you want to
4158 // > send the frames to the WebSocket without composing the full message first).
4159 // >
4160 // > — https://github.com/snapview/tungstenite-rs/issues/268
4161 TungsteniteMessage::Frame(_) => {
4162 bail!("received an unexpected frame while reading the message")
4163 }
4164 };
4165
4166 Ok(message)
4167}
4168
4169fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4170 match message {
4171 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4172 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4173 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4174 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4175 AxumMessage::Close(frame) => {
4176 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4177 code: frame.code.into(),
4178 reason: frame.reason,
4179 }))
4180 }
4181 }
4182}
4183
4184fn notify_membership_updated(
4185 connection_pool: &mut ConnectionPool,
4186 result: MembershipUpdated,
4187 user_id: UserId,
4188 peer: &Peer,
4189) {
4190 for membership in &result.new_channels.channel_memberships {
4191 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4192 }
4193 for channel_id in &result.removed_channels {
4194 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4195 }
4196
4197 let user_channels_update = proto::UpdateUserChannels {
4198 channel_memberships: result
4199 .new_channels
4200 .channel_memberships
4201 .iter()
4202 .map(|cm| proto::ChannelMembership {
4203 channel_id: cm.channel_id.to_proto(),
4204 role: cm.role.into(),
4205 })
4206 .collect(),
4207 ..Default::default()
4208 };
4209
4210 let mut update = build_channels_update(result.new_channels);
4211 update.delete_channels = result
4212 .removed_channels
4213 .into_iter()
4214 .map(|id| id.to_proto())
4215 .collect();
4216 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4217
4218 for connection_id in connection_pool.user_connection_ids(user_id) {
4219 peer.send(connection_id, user_channels_update.clone())
4220 .trace_err();
4221 peer.send(connection_id, update.clone()).trace_err();
4222 }
4223}
4224
4225fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4226 proto::UpdateUserChannels {
4227 channel_memberships: channels
4228 .channel_memberships
4229 .iter()
4230 .map(|m| proto::ChannelMembership {
4231 channel_id: m.channel_id.to_proto(),
4232 role: m.role.into(),
4233 })
4234 .collect(),
4235 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4236 observed_channel_message_id: channels.observed_channel_messages.clone(),
4237 }
4238}
4239
4240fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4241 let mut update = proto::UpdateChannels::default();
4242
4243 for channel in channels.channels {
4244 update.channels.push(channel.to_proto());
4245 }
4246
4247 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4248 update.latest_channel_message_ids = channels.latest_channel_messages;
4249
4250 for (channel_id, participants) in channels.channel_participants {
4251 update
4252 .channel_participants
4253 .push(proto::ChannelParticipants {
4254 channel_id: channel_id.to_proto(),
4255 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4256 });
4257 }
4258
4259 for channel in channels.invited_channels {
4260 update.channel_invitations.push(channel.to_proto());
4261 }
4262
4263 update
4264}
4265
4266fn build_initial_contacts_update(
4267 contacts: Vec<db::Contact>,
4268 pool: &ConnectionPool,
4269) -> proto::UpdateContacts {
4270 let mut update = proto::UpdateContacts::default();
4271
4272 for contact in contacts {
4273 match contact {
4274 db::Contact::Accepted { user_id, busy } => {
4275 update.contacts.push(contact_for_user(user_id, busy, pool));
4276 }
4277 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4278 db::Contact::Incoming { user_id } => {
4279 update
4280 .incoming_requests
4281 .push(proto::IncomingContactRequest {
4282 requester_id: user_id.to_proto(),
4283 })
4284 }
4285 }
4286 }
4287
4288 update
4289}
4290
4291fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4292 proto::Contact {
4293 user_id: user_id.to_proto(),
4294 online: pool.is_user_online(user_id),
4295 busy,
4296 }
4297}
4298
4299fn room_updated(room: &proto::Room, peer: &Peer) {
4300 broadcast(
4301 None,
4302 room.participants
4303 .iter()
4304 .filter_map(|participant| Some(participant.peer_id?.into())),
4305 |peer_id| {
4306 peer.send(
4307 peer_id,
4308 proto::RoomUpdated {
4309 room: Some(room.clone()),
4310 },
4311 )
4312 },
4313 );
4314}
4315
4316fn channel_updated(
4317 channel: &db::channel::Model,
4318 room: &proto::Room,
4319 peer: &Peer,
4320 pool: &ConnectionPool,
4321) {
4322 let participants = room
4323 .participants
4324 .iter()
4325 .map(|p| p.user_id)
4326 .collect::<Vec<_>>();
4327
4328 broadcast(
4329 None,
4330 pool.channel_connection_ids(channel.root_id())
4331 .filter_map(|(channel_id, role)| {
4332 role.can_see_channel(channel.visibility)
4333 .then_some(channel_id)
4334 }),
4335 |peer_id| {
4336 peer.send(
4337 peer_id,
4338 proto::UpdateChannels {
4339 channel_participants: vec![proto::ChannelParticipants {
4340 channel_id: channel.id.to_proto(),
4341 participant_user_ids: participants.clone(),
4342 }],
4343 ..Default::default()
4344 },
4345 )
4346 },
4347 );
4348}
4349
4350async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4351 let db = session.db().await;
4352
4353 let contacts = db.get_contacts(user_id).await?;
4354 let busy = db.is_user_busy(user_id).await?;
4355
4356 let pool = session.connection_pool().await;
4357 let updated_contact = contact_for_user(user_id, busy, &pool);
4358 for contact in contacts {
4359 if let db::Contact::Accepted {
4360 user_id: contact_user_id,
4361 ..
4362 } = contact
4363 {
4364 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4365 session
4366 .peer
4367 .send(
4368 contact_conn_id,
4369 proto::UpdateContacts {
4370 contacts: vec![updated_contact.clone()],
4371 remove_contacts: Default::default(),
4372 incoming_requests: Default::default(),
4373 remove_incoming_requests: Default::default(),
4374 outgoing_requests: Default::default(),
4375 remove_outgoing_requests: Default::default(),
4376 },
4377 )
4378 .trace_err();
4379 }
4380 }
4381 }
4382 Ok(())
4383}
4384
/// Removes `connection_id` from the room it is currently in and performs the
/// follow-up notifications and cleanup.
///
/// Returns `Ok(())` without doing anything else when `leave_room` returns
/// `None`. Otherwise it:
/// - notifies peers of each project the user left (`project_left`);
/// - broadcasts the updated room, and the channel participants when the room
///   belongs to a channel;
/// - sends `CallCanceled` to every connection of users whose calls were canceled;
/// - refreshes contact status for the leaving user and those users;
/// - removes the user from the LiveKit room, deleting it when `left_room.deleted`
///   is set.
///
/// # Errors
/// Propagates errors from `leave_room` and `update_user_contacts`. An error from
/// the latter skips the remaining contact updates and the LiveKit cleanup.
/// Failures from `peer.send` and LiveKit calls are only logged via `trace_err`.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Assigned exactly once inside the `if let` below; the `else` branch returns.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    // NOTE(review): the value returned by `session.db().await` is a temporary in
    // this `if let` scrutinee, so it stays alive for the whole branch below —
    // relevant if it guards a lock. Rewriting this as `let ... else` would drop
    // it earlier.
    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // `livekit_room` is taken out before `room` itself, so the `room` that is
        // broadcast below carries a default (emptied) `livekit_room`.
        // NOTE(review): confirm peers do not need that field in this update.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scope the pool guard so it is released before `update_user_contacts`,
    // which acquires the connection pool itself.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    // LiveKit cleanup is best-effort: failures are logged, not propagated.
    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4458
4459async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4460 let left_channel_buffers = session
4461 .db()
4462 .await
4463 .leave_channel_buffers(session.connection_id)
4464 .await?;
4465
4466 for left_buffer in left_channel_buffers {
4467 channel_buffer_updated(
4468 session.connection_id,
4469 left_buffer.connections,
4470 &proto::UpdateChannelBufferCollaborators {
4471 channel_id: left_buffer.channel_id.to_proto(),
4472 collaborators: left_buffer.collaborators,
4473 },
4474 &session.peer,
4475 );
4476 }
4477
4478 Ok(())
4479}
4480
4481fn project_left(project: &db::LeftProject, session: &Session) {
4482 for connection_id in &project.connection_ids {
4483 if project.should_unshare {
4484 session
4485 .peer
4486 .send(
4487 *connection_id,
4488 proto::UnshareProject {
4489 project_id: project.id.to_proto(),
4490 },
4491 )
4492 .trace_err();
4493 } else {
4494 session
4495 .peer
4496 .send(
4497 *connection_id,
4498 proto::RemoveProjectCollaborator {
4499 project_id: project.id.to_proto(),
4500 peer_id: Some(session.connection_id.into()),
4501 },
4502 )
4503 .trace_err();
4504 }
4505 }
4506}
4507
4508pub trait ResultExt {
4509 type Ok;
4510
4511 fn trace_err(self) -> Option<Self::Ok>;
4512}
4513
4514impl<T, E> ResultExt for Result<T, E>
4515where
4516 E: std::fmt::Debug,
4517{
4518 type Ok = T;
4519
4520 #[track_caller]
4521 fn trace_err(self) -> Option<T> {
4522 match self {
4523 Ok(value) => Some(value),
4524 Err(error) => {
4525 tracing::error!("{:?}", error);
4526 None
4527 }
4528 }
4529 }
4530}