1mod connection_pool;
2
3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
4use crate::llm::LlmTokenClaims;
5use crate::{
6 AppState, Config, Error, RateLimit, Result, auth,
7 db::{
8 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
9 CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
10 NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
11 RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
12 },
13 executor::Executor,
14};
15use anyhow::{Context as _, anyhow, bail};
16use async_tungstenite::tungstenite::{
17 Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
18};
19use axum::{
20 Extension, Router, TypedHeader,
21 body::Body,
22 extract::{
23 ConnectInfo, WebSocketUpgrade,
24 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
25 },
26 headers::{Header, HeaderName},
27 http::StatusCode,
28 middleware,
29 response::IntoResponse,
30 routing::get,
31};
32use chrono::Utc;
33use collections::{HashMap, HashSet};
34pub use connection_pool::{ConnectionPool, ZedVersion};
35use core::fmt::{self, Debug, Formatter};
36use http_client::HttpClient;
37use open_ai::{OPEN_AI_API_URL, OpenAiEmbeddingModel};
38use reqwest_client::ReqwestClient;
39use rpc::proto::split_repository_update;
40use sha2::Digest;
41use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
42
43use futures::{
44 FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
45 stream::FuturesUnordered,
46};
47use prometheus::{IntGauge, register_int_gauge};
48use rpc::{
49 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
50 proto::{
51 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
52 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
53 },
54};
55use semantic_version::SemanticVersion;
56use serde::{Serialize, Serializer};
57use std::{
58 any::TypeId,
59 future::Future,
60 marker::PhantomData,
61 mem,
62 net::SocketAddr,
63 ops::{Deref, DerefMut},
64 rc::Rc,
65 sync::{
66 Arc, OnceLock,
67 atomic::{AtomicBool, Ordering::SeqCst},
68 },
69 time::{Duration, Instant},
70};
71use time::OffsetDateTime;
72use tokio::sync::{MutexGuard, Semaphore, watch};
73use tower::ServiceBuilder;
74use tracing::{
75 Instrument,
76 field::{self},
77 info_span, instrument,
78};
79
/// Grace period associated with client reconnection (30 seconds).
///
/// NOTE(review): not referenced in this part of the file; presumably how long a
/// dropped client may take to reconnect before its state is released — confirm.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_millis(30_000);

/// Delay before `Server::start` cleans up resources left behind by stale servers.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully; waiting 15s means
/// those pods are gone before their old resources are cleaned up.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_millis(15_000);

/// Page size for message queries (presumably channel chat history — confirm at call sites).
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Upper bound on a message's length (units not visible here — TODO confirm bytes vs chars).
const MAX_MESSAGE_LEN: usize = 1_024;
/// Page size for notification queries (presumably — confirm at call sites).
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
88
/// Type-erased message handler stored in `Server::handlers`, keyed by the `TypeId`
/// of the message type it handles.
///
/// It is given the boxed incoming envelope and the sending connection's `Session`,
/// and returns a boxed `'static` future that performs the handling.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
91
/// Reply handle passed to handlers registered through `Server::add_request_handler`.
///
/// `responded` is a flag shared with the dispatching code, so that after the handler
/// finishes it can check whether `send` was ever called.
struct Response<R> {
    peer: Arc<Peer>,
    receipt: Receipt<R>,
    responded: Arc<AtomicBool>,
}
97
98impl<R: RequestMessage> Response<R> {
99 fn send(self, payload: R::Response) -> Result<()> {
100 self.responded.store(true, SeqCst);
101 self.peer.respond(self.receipt, payload)?;
102 Ok(())
103 }
104}
105
/// The authenticated identity behind a connection.
#[derive(Clone, Debug)]
pub enum Principal {
    /// A user acting as themselves.
    User(User),
    /// `admin` acting as `user`; `admin` is recorded as the impersonator in tracing
    /// spans, and such sessions are treated as staff.
    Impersonated { user: User, admin: User },
}
111
112impl Principal {
113 fn update_span(&self, span: &tracing::Span) {
114 match &self {
115 Principal::User(user) => {
116 span.record("user_id", user.id.0);
117 span.record("login", &user.github_login);
118 }
119 Principal::Impersonated { user, admin } => {
120 span.record("user_id", user.id.0);
121 span.record("login", &user.github_login);
122 span.record("impersonator", &admin.github_login);
123 }
124 }
125 }
126}
127
/// Per-connection state handed to every message handler.
#[derive(Clone)]
struct Session {
    /// Who is on the other end of this connection.
    principal: Principal,
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; acquire it through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Shared with `Server`; lock it through `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Present only when a Supermaven admin API key is configured.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    http_client: Arc<dyn HttpClient>,
    /// The GeoIP country code for the user.
    // NOTE(review): `allow(unused)` suggests nothing currently reads this field; confirm before removing.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    system_id: Option<String>,
    _executor: Executor,
}
144
145impl Session {
146 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
147 #[cfg(test)]
148 tokio::task::yield_now().await;
149 let guard = self.db.lock().await;
150 #[cfg(test)]
151 tokio::task::yield_now().await;
152 guard
153 }
154
155 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
156 #[cfg(test)]
157 tokio::task::yield_now().await;
158 let guard = self.connection_pool.lock();
159 ConnectionPoolGuard {
160 guard,
161 _not_send: PhantomData,
162 }
163 }
164
165 fn is_staff(&self) -> bool {
166 match &self.principal {
167 Principal::User(user) => user.admin,
168 Principal::Impersonated { .. } => true,
169 }
170 }
171
172 pub async fn has_llm_subscription(
173 &self,
174 db: &MutexGuard<'_, DbHandle>,
175 ) -> anyhow::Result<bool> {
176 if self.is_staff() {
177 return Ok(true);
178 }
179
180 let user_id = self.user_id();
181
182 Ok(db.has_active_billing_subscription(user_id).await?)
183 }
184
185 pub async fn current_plan(
186 &self,
187 _db: &MutexGuard<'_, DbHandle>,
188 ) -> anyhow::Result<proto::Plan> {
189 if self.is_staff() {
190 Ok(proto::Plan::ZedPro)
191 } else {
192 Ok(proto::Plan::Free)
193 }
194 }
195
196 fn user_id(&self) -> UserId {
197 match &self.principal {
198 Principal::User(user) => user.id,
199 Principal::Impersonated { user, .. } => user.id,
200 }
201 }
202
203 pub fn email(&self) -> Option<String> {
204 match &self.principal {
205 Principal::User(user) => user.email_address.clone(),
206 Principal::Impersonated { user, .. } => user.email_address.clone(),
207 }
208 }
209}
210
211impl Debug for Session {
212 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
213 let mut result = f.debug_struct("Session");
214 match &self.principal {
215 Principal::User(user) => {
216 result.field("user", &user.github_login);
217 }
218 Principal::Impersonated { user, admin } => {
219 result.field("user", &user.github_login);
220 result.field("impersonator", &admin.github_login);
221 }
222 }
223 result.field("connection_id", &self.connection_id).finish()
224 }
225}
226
227struct DbHandle(Arc<Database>);
228
229impl Deref for DbHandle {
230 type Target = Database;
231
232 fn deref(&self) -> &Self::Target {
233 self.0.as_ref()
234 }
235}
236
/// The collab RPC server: owns the peer, the pool of live connections, and the
/// table of message handlers registered in `Server::new`.
pub struct Server {
    /// This server's id; behind a mutex because the test-only `reset` replaces it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// At most one handler per message type, keyed by the message's `TypeId`.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Set to `true` on teardown; every connection future subscribes to it.
    teardown: watch::Sender<bool>,
}
245
/// A locked `ConnectionPool`.
///
/// The `PhantomData<Rc<()>>` marker makes this guard `!Send` (and `!Sync`).
/// Presumably this is so it cannot be held across an `.await` inside a `Send`
/// future — confirm intent.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
250
/// Serializable view of the server's peer and connection pool.
///
/// The connection pool stays locked for as long as the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    /// Serialized as the `ConnectionPool` the guard derefs to.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
257
258pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
259where
260 S: Serializer,
261 T: Deref<Target = U>,
262 U: Serialize,
263{
264 Serialize::serialize(value.deref(), serializer)
265}
266
267impl Server {
268 pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
269 let mut server = Self {
270 id: parking_lot::Mutex::new(id),
271 peer: Peer::new(id.0 as u32),
272 app_state: app_state.clone(),
273 connection_pool: Default::default(),
274 handlers: Default::default(),
275 teardown: watch::channel(false).0,
276 };
277
278 server
279 .add_request_handler(ping)
280 .add_request_handler(create_room)
281 .add_request_handler(join_room)
282 .add_request_handler(rejoin_room)
283 .add_request_handler(leave_room)
284 .add_request_handler(set_room_participant_role)
285 .add_request_handler(call)
286 .add_request_handler(cancel_call)
287 .add_message_handler(decline_call)
288 .add_request_handler(update_participant_location)
289 .add_request_handler(share_project)
290 .add_message_handler(unshare_project)
291 .add_request_handler(join_project)
292 .add_message_handler(leave_project)
293 .add_request_handler(update_project)
294 .add_request_handler(update_worktree)
295 .add_request_handler(update_repository)
296 .add_request_handler(remove_repository)
297 .add_message_handler(start_language_server)
298 .add_message_handler(update_language_server)
299 .add_message_handler(update_diagnostic_summary)
300 .add_message_handler(update_worktree_settings)
301 .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
302 .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
303 .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
304 .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
305 .add_request_handler(forward_find_search_candidates_request)
306 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
307 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
308 .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
309 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
310 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
311 .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
312 .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
313 .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
314 .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
315 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
316 .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
317 .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
318 .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
319 .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
320 .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
321 .add_request_handler(
322 forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
323 )
324 .add_request_handler(
325 forward_read_only_project_request::<proto::LanguageServerIdForName>,
326 )
327 .add_request_handler(
328 forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
329 )
330 .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
331 .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
332 .add_request_handler(
333 forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
334 )
335 .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
336 .add_request_handler(
337 forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
338 )
339 .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
340 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
341 .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
342 .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
343 .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
344 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
345 .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
346 .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
347 .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
348 .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
349 .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
350 .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
351 .add_request_handler(
352 forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
353 )
354 .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
355 .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
356 .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
357 .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
358 .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
359 .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
360 .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
361 .add_message_handler(create_buffer_for_peer)
362 .add_request_handler(update_buffer)
363 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
364 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
365 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
366 .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
367 .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
368 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
369 .add_request_handler(get_users)
370 .add_request_handler(fuzzy_search_users)
371 .add_request_handler(request_contact)
372 .add_request_handler(remove_contact)
373 .add_request_handler(respond_to_contact_request)
374 .add_message_handler(subscribe_to_channels)
375 .add_request_handler(create_channel)
376 .add_request_handler(delete_channel)
377 .add_request_handler(invite_channel_member)
378 .add_request_handler(remove_channel_member)
379 .add_request_handler(set_channel_member_role)
380 .add_request_handler(set_channel_visibility)
381 .add_request_handler(rename_channel)
382 .add_request_handler(join_channel_buffer)
383 .add_request_handler(leave_channel_buffer)
384 .add_message_handler(update_channel_buffer)
385 .add_request_handler(rejoin_channel_buffers)
386 .add_request_handler(get_channel_members)
387 .add_request_handler(respond_to_channel_invite)
388 .add_request_handler(join_channel)
389 .add_request_handler(join_channel_chat)
390 .add_message_handler(leave_channel_chat)
391 .add_request_handler(send_channel_message)
392 .add_request_handler(remove_channel_message)
393 .add_request_handler(update_channel_message)
394 .add_request_handler(get_channel_messages)
395 .add_request_handler(get_channel_messages_by_id)
396 .add_request_handler(get_notifications)
397 .add_request_handler(mark_notification_as_read)
398 .add_request_handler(move_channel)
399 .add_request_handler(follow)
400 .add_message_handler(unfollow)
401 .add_message_handler(update_followers)
402 .add_request_handler(get_private_user_info)
403 .add_request_handler(get_llm_api_token)
404 .add_request_handler(accept_terms_of_service)
405 .add_message_handler(acknowledge_channel_message)
406 .add_message_handler(acknowledge_buffer_version)
407 .add_request_handler(get_supermaven_api_key)
408 .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
409 .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
410 .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
411 .add_request_handler(forward_mutating_project_request::<proto::Stage>)
412 .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
413 .add_request_handler(forward_mutating_project_request::<proto::Commit>)
414 .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
415 .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
416 .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
417 .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
418 .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
419 .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
420 .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
421 .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
422 .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
423 .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
424 .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
425 .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
426 .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
427 .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
428 .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
429 .add_message_handler(update_context)
430 .add_request_handler({
431 let app_state = app_state.clone();
432 move |request, response, session| {
433 let app_state = app_state.clone();
434 async move {
435 count_language_model_tokens(request, response, session, &app_state.config)
436 .await
437 }
438 }
439 })
440 .add_request_handler(get_cached_embeddings)
441 .add_request_handler({
442 let app_state = app_state.clone();
443 move |request, response, session| {
444 compute_embeddings(
445 request,
446 response,
447 session,
448 app_state.config.openai_api_key.clone(),
449 )
450 }
451 });
452
453 Arc::new(server)
454 }
455
456 pub async fn start(&self) -> Result<()> {
457 let server_id = *self.id.lock();
458 let app_state = self.app_state.clone();
459 let peer = self.peer.clone();
460 let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
461 let pool = self.connection_pool.clone();
462 let livekit_client = self.app_state.livekit_client.clone();
463
464 let span = info_span!("start server");
465 self.app_state.executor.spawn_detached(
466 async move {
467 tracing::info!("waiting for cleanup timeout");
468 timeout.await;
469 tracing::info!("cleanup timeout expired, retrieving stale rooms");
470 if let Some((room_ids, channel_ids)) = app_state
471 .db
472 .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
473 .await
474 .trace_err()
475 {
476 tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
477 tracing::info!(
478 stale_channel_buffer_count = channel_ids.len(),
479 "retrieved stale channel buffers"
480 );
481
482 for channel_id in channel_ids {
483 if let Some(refreshed_channel_buffer) = app_state
484 .db
485 .clear_stale_channel_buffer_collaborators(channel_id, server_id)
486 .await
487 .trace_err()
488 {
489 for connection_id in refreshed_channel_buffer.connection_ids {
490 peer.send(
491 connection_id,
492 proto::UpdateChannelBufferCollaborators {
493 channel_id: channel_id.to_proto(),
494 collaborators: refreshed_channel_buffer
495 .collaborators
496 .clone(),
497 },
498 )
499 .trace_err();
500 }
501 }
502 }
503
504 for room_id in room_ids {
505 let mut contacts_to_update = HashSet::default();
506 let mut canceled_calls_to_user_ids = Vec::new();
507 let mut livekit_room = String::new();
508 let mut delete_livekit_room = false;
509
510 if let Some(mut refreshed_room) = app_state
511 .db
512 .clear_stale_room_participants(room_id, server_id)
513 .await
514 .trace_err()
515 {
516 tracing::info!(
517 room_id = room_id.0,
518 new_participant_count = refreshed_room.room.participants.len(),
519 "refreshed room"
520 );
521 room_updated(&refreshed_room.room, &peer);
522 if let Some(channel) = refreshed_room.channel.as_ref() {
523 channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
524 }
525 contacts_to_update
526 .extend(refreshed_room.stale_participant_user_ids.iter().copied());
527 contacts_to_update
528 .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
529 canceled_calls_to_user_ids =
530 mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
531 livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
532 delete_livekit_room = refreshed_room.room.participants.is_empty();
533 }
534
535 {
536 let pool = pool.lock();
537 for canceled_user_id in canceled_calls_to_user_ids {
538 for connection_id in pool.user_connection_ids(canceled_user_id) {
539 peer.send(
540 connection_id,
541 proto::CallCanceled {
542 room_id: room_id.to_proto(),
543 },
544 )
545 .trace_err();
546 }
547 }
548 }
549
550 for user_id in contacts_to_update {
551 let busy = app_state.db.is_user_busy(user_id).await.trace_err();
552 let contacts = app_state.db.get_contacts(user_id).await.trace_err();
553 if let Some((busy, contacts)) = busy.zip(contacts) {
554 let pool = pool.lock();
555 let updated_contact = contact_for_user(user_id, busy, &pool);
556 for contact in contacts {
557 if let db::Contact::Accepted {
558 user_id: contact_user_id,
559 ..
560 } = contact
561 {
562 for contact_conn_id in
563 pool.user_connection_ids(contact_user_id)
564 {
565 peer.send(
566 contact_conn_id,
567 proto::UpdateContacts {
568 contacts: vec![updated_contact.clone()],
569 remove_contacts: Default::default(),
570 incoming_requests: Default::default(),
571 remove_incoming_requests: Default::default(),
572 outgoing_requests: Default::default(),
573 remove_outgoing_requests: Default::default(),
574 },
575 )
576 .trace_err();
577 }
578 }
579 }
580 }
581 }
582
583 if let Some(live_kit) = livekit_client.as_ref() {
584 if delete_livekit_room {
585 live_kit.delete_room(livekit_room).await.trace_err();
586 }
587 }
588 }
589 }
590
591 app_state
592 .db
593 .delete_stale_servers(&app_state.config.zed_environment, server_id)
594 .await
595 .trace_err();
596 }
597 .instrument(span),
598 );
599 Ok(())
600 }
601
602 pub fn teardown(&self) {
603 self.peer.teardown();
604 self.connection_pool.lock().reset();
605 let _ = self.teardown.send(true);
606 }
607
608 #[cfg(test)]
609 pub fn reset(&self, id: ServerId) {
610 self.teardown();
611 *self.id.lock() = id;
612 self.peer.reset(id.0 as u32);
613 let _ = self.teardown.send(false);
614 }
615
616 #[cfg(test)]
617 pub fn id(&self) -> ServerId {
618 *self.id.lock()
619 }
620
621 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
622 where
623 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
624 Fut: 'static + Send + Future<Output = Result<()>>,
625 M: EnvelopedMessage,
626 {
627 let prev_handler = self.handlers.insert(
628 TypeId::of::<M>(),
629 Box::new(move |envelope, session| {
630 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
631 let received_at = envelope.received_at;
632 tracing::info!("message received");
633 let start_time = Instant::now();
634 let future = (handler)(*envelope, session);
635 async move {
636 let result = future.await;
637 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
638 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
639 let queue_duration_ms = total_duration_ms - processing_duration_ms;
640 let payload_type = M::NAME;
641
642 match result {
643 Err(error) => {
644 tracing::error!(
645 ?error,
646 total_duration_ms,
647 processing_duration_ms,
648 queue_duration_ms,
649 payload_type,
650 "error handling message"
651 )
652 }
653 Ok(()) => tracing::info!(
654 total_duration_ms,
655 processing_duration_ms,
656 queue_duration_ms,
657 "finished handling message"
658 ),
659 }
660 }
661 .boxed()
662 }),
663 );
664 if prev_handler.is_some() {
665 panic!("registered a handler for the same message twice");
666 }
667 self
668 }
669
670 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
671 where
672 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
673 Fut: 'static + Send + Future<Output = Result<()>>,
674 M: EnvelopedMessage,
675 {
676 self.add_handler(move |envelope, session| handler(envelope.payload, session));
677 self
678 }
679
680 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
681 where
682 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
683 Fut: Send + Future<Output = Result<()>>,
684 M: RequestMessage,
685 {
686 let handler = Arc::new(handler);
687 self.add_handler(move |envelope, session| {
688 let receipt = envelope.receipt();
689 let handler = handler.clone();
690 async move {
691 let peer = session.peer.clone();
692 let responded = Arc::new(AtomicBool::default());
693 let response = Response {
694 peer: peer.clone(),
695 responded: responded.clone(),
696 receipt,
697 };
698 match (handler)(envelope.payload, response, session).await {
699 Ok(()) => {
700 if responded.load(std::sync::atomic::Ordering::SeqCst) {
701 Ok(())
702 } else {
703 Err(anyhow!("handler did not send a response"))?
704 }
705 }
706 Err(error) => {
707 let proto_err = match &error {
708 Error::Internal(err) => err.to_proto(),
709 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
710 };
711 peer.respond_with_error(receipt, proto_err)?;
712 Err(error)
713 }
714 }
715 }
716 })
717 }
718
719 pub fn handle_connection(
720 self: &Arc<Self>,
721 connection: Connection,
722 address: String,
723 principal: Principal,
724 zed_version: ZedVersion,
725 geoip_country_code: Option<String>,
726 system_id: Option<String>,
727 send_connection_id: Option<oneshot::Sender<ConnectionId>>,
728 executor: Executor,
729 ) -> impl Future<Output = ()> + use<> {
730 let this = self.clone();
731 let span = info_span!("handle connection", %address,
732 connection_id=field::Empty,
733 user_id=field::Empty,
734 login=field::Empty,
735 impersonator=field::Empty,
736 geoip_country_code=field::Empty
737 );
738 principal.update_span(&span);
739 if let Some(country_code) = geoip_country_code.as_ref() {
740 span.record("geoip_country_code", country_code);
741 }
742
743 let mut teardown = self.teardown.subscribe();
744 async move {
745 if *teardown.borrow() {
746 tracing::error!("server is tearing down");
747 return
748 }
749 let (connection_id, handle_io, mut incoming_rx) = this
750 .peer
751 .add_connection(connection, {
752 let executor = executor.clone();
753 move |duration| executor.sleep(duration)
754 });
755 tracing::Span::current().record("connection_id", format!("{}", connection_id));
756
757 tracing::info!("connection opened");
758
759 let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
760 let http_client = match ReqwestClient::user_agent(&user_agent) {
761 Ok(http_client) => Arc::new(http_client),
762 Err(error) => {
763 tracing::error!(?error, "failed to create HTTP client");
764 return;
765 }
766 };
767
768 let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
769 supermaven_admin_api_key.to_string(),
770 http_client.clone(),
771 )));
772
773 let session = Session {
774 principal: principal.clone(),
775 connection_id,
776 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
777 peer: this.peer.clone(),
778 connection_pool: this.connection_pool.clone(),
779 app_state: this.app_state.clone(),
780 http_client,
781 geoip_country_code,
782 system_id,
783 _executor: executor.clone(),
784 supermaven_client,
785 };
786
787 if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
788 tracing::error!(?error, "failed to send initial client update");
789 return;
790 }
791
792 let handle_io = handle_io.fuse();
793 futures::pin_mut!(handle_io);
794
795 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
796 // This prevents deadlocks when e.g., client A performs a request to client B and
797 // client B performs a request to client A. If both clients stop processing further
798 // messages until their respective request completes, they won't have a chance to
799 // respond to the other client's request and cause a deadlock.
800 //
801 // This arrangement ensures we will attempt to process earlier messages first, but fall
802 // back to processing messages arrived later in the spirit of making progress.
803 let mut foreground_message_handlers = FuturesUnordered::new();
804 let concurrent_handlers = Arc::new(Semaphore::new(256));
805 loop {
806 let next_message = async {
807 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
808 let message = incoming_rx.next().await;
809 (permit, message)
810 }.fuse();
811 futures::pin_mut!(next_message);
812 futures::select_biased! {
813 _ = teardown.changed().fuse() => return,
814 result = handle_io => {
815 if let Err(error) = result {
816 tracing::error!(?error, "error handling I/O");
817 }
818 break;
819 }
820 _ = foreground_message_handlers.next() => {}
821 next_message = next_message => {
822 let (permit, message) = next_message;
823 if let Some(message) = message {
824 let type_name = message.payload_type_name();
825 // note: we copy all the fields from the parent span so we can query them in the logs.
826 // (https://github.com/tokio-rs/tracing/issues/2670).
827 let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
828 user_id=field::Empty,
829 login=field::Empty,
830 impersonator=field::Empty,
831 );
832 principal.update_span(&span);
833 let span_enter = span.enter();
834 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
835 let is_background = message.is_background();
836 let handle_message = (handler)(message, session.clone());
837 drop(span_enter);
838
839 let handle_message = async move {
840 handle_message.await;
841 drop(permit);
842 }.instrument(span);
843 if is_background {
844 executor.spawn_detached(handle_message);
845 } else {
846 foreground_message_handlers.push(handle_message);
847 }
848 } else {
849 tracing::error!("no message handler");
850 }
851 } else {
852 tracing::info!("connection closed");
853 break;
854 }
855 }
856 }
857 }
858
859 drop(foreground_message_handlers);
860 tracing::info!("signing out");
861 if let Err(error) = connection_lost(session, teardown, executor).await {
862 tracing::error!(?error, "error signing out");
863 }
864
865 }.instrument(span)
866 }
867
868 async fn send_initial_client_update(
869 &self,
870 connection_id: ConnectionId,
871 principal: &Principal,
872 zed_version: ZedVersion,
873 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
874 session: &Session,
875 ) -> Result<()> {
876 self.peer.send(
877 connection_id,
878 proto::Hello {
879 peer_id: Some(connection_id.into()),
880 },
881 )?;
882 tracing::info!("sent hello message");
883 if let Some(send_connection_id) = send_connection_id.take() {
884 let _ = send_connection_id.send(connection_id);
885 }
886
887 match principal {
888 Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
889 if !user.connected_once {
890 self.peer.send(connection_id, proto::ShowContacts {})?;
891 self.app_state
892 .db
893 .set_user_connected_once(user.id, true)
894 .await?;
895 }
896
897 update_user_plan(user.id, session).await?;
898
899 let contacts = self.app_state.db.get_contacts(user.id).await?;
900
901 {
902 let mut pool = self.connection_pool.lock();
903 pool.add_connection(connection_id, user.id, user.admin, zed_version);
904 self.peer.send(
905 connection_id,
906 build_initial_contacts_update(contacts, &pool),
907 )?;
908 }
909
910 if should_auto_subscribe_to_channels(zed_version) {
911 subscribe_user_to_channels(user.id, session).await?;
912 }
913
914 if let Some(incoming_call) =
915 self.app_state.db.incoming_call_for_user(user.id).await?
916 {
917 self.peer.send(connection_id, incoming_call)?;
918 }
919
920 update_user_contacts(user.id, session).await?;
921 }
922 }
923
924 Ok(())
925 }
926
927 pub async fn invite_code_redeemed(
928 self: &Arc<Self>,
929 inviter_id: UserId,
930 invitee_id: UserId,
931 ) -> Result<()> {
932 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
933 if let Some(code) = &user.invite_code {
934 let pool = self.connection_pool.lock();
935 let invitee_contact = contact_for_user(invitee_id, false, &pool);
936 for connection_id in pool.user_connection_ids(inviter_id) {
937 self.peer.send(
938 connection_id,
939 proto::UpdateContacts {
940 contacts: vec![invitee_contact.clone()],
941 ..Default::default()
942 },
943 )?;
944 self.peer.send(
945 connection_id,
946 proto::UpdateInviteInfo {
947 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
948 count: user.invite_count as u32,
949 },
950 )?;
951 }
952 }
953 }
954 Ok(())
955 }
956
957 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
958 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
959 if let Some(invite_code) = &user.invite_code {
960 let pool = self.connection_pool.lock();
961 for connection_id in pool.user_connection_ids(user_id) {
962 self.peer.send(
963 connection_id,
964 proto::UpdateInviteInfo {
965 url: format!(
966 "{}{}",
967 self.app_state.config.invite_link_prefix, invite_code
968 ),
969 count: user.invite_count as u32,
970 },
971 )?;
972 }
973 }
974 }
975 Ok(())
976 }
977
978 pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
979 let pool = self.connection_pool.lock();
980 for connection_id in pool.user_connection_ids(user_id) {
981 self.peer
982 .send(connection_id, proto::RefreshLlmToken {})
983 .trace_err();
984 }
985 }
986
987 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
988 ServerSnapshot {
989 connection_pool: ConnectionPoolGuard {
990 guard: self.connection_pool.lock(),
991 _not_send: PhantomData,
992 },
993 peer: &self.peer,
994 }
995 }
996}
997
998impl Deref for ConnectionPoolGuard<'_> {
999 type Target = ConnectionPool;
1000
1001 fn deref(&self) -> &Self::Target {
1002 &self.guard
1003 }
1004}
1005
1006impl DerefMut for ConnectionPoolGuard<'_> {
1007 fn deref_mut(&mut self) -> &mut Self::Target {
1008 &mut self.guard
1009 }
1010}
1011
1012impl Drop for ConnectionPoolGuard<'_> {
1013 fn drop(&mut self) {
1014 #[cfg(test)]
1015 self.check_invariants();
1016 }
1017}
1018
1019fn broadcast<F>(
1020 sender_id: Option<ConnectionId>,
1021 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1022 mut f: F,
1023) where
1024 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1025{
1026 for receiver_id in receiver_ids {
1027 if Some(receiver_id) != sender_id {
1028 if let Err(error) = f(receiver_id) {
1029 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1030 }
1031 }
1032 }
1033}
1034
1035pub struct ProtocolVersion(u32);
1036
1037impl Header for ProtocolVersion {
1038 fn name() -> &'static HeaderName {
1039 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1040 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1041 }
1042
1043 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1044 where
1045 Self: Sized,
1046 I: Iterator<Item = &'i axum::http::HeaderValue>,
1047 {
1048 let version = values
1049 .next()
1050 .ok_or_else(axum::headers::Error::invalid)?
1051 .to_str()
1052 .map_err(|_| axum::headers::Error::invalid())?
1053 .parse()
1054 .map_err(|_| axum::headers::Error::invalid())?;
1055 Ok(Self(version))
1056 }
1057
1058 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1059 values.extend([self.0.to_string().parse().unwrap()]);
1060 }
1061}
1062
1063pub struct AppVersionHeader(SemanticVersion);
1064impl Header for AppVersionHeader {
1065 fn name() -> &'static HeaderName {
1066 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1067 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1068 }
1069
1070 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1071 where
1072 Self: Sized,
1073 I: Iterator<Item = &'i axum::http::HeaderValue>,
1074 {
1075 let version = values
1076 .next()
1077 .ok_or_else(axum::headers::Error::invalid)?
1078 .to_str()
1079 .map_err(|_| axum::headers::Error::invalid())?
1080 .parse()
1081 .map_err(|_| axum::headers::Error::invalid())?;
1082 Ok(Self(version))
1083 }
1084
1085 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1086 values.extend([self.0.to_string().parse().unwrap()]);
1087 }
1088}
1089
/// Builds the router for the RPC WebSocket endpoint and the metrics endpoint.
///
/// The statement order matters. Axum's `Router::layer` only wraps routes that
/// were added before it, so the `auth::validate_header` middleware guards
/// `/rpc` but not `/metrics`.
///
/// Inside that `ServiceBuilder`, the `AppState` extension is layered outside
/// the middleware, so it is available to it. Presumably `validate_header` reads
/// it and supplies the `Principal` that `handle_websocket_request` extracts;
/// confirm in `auth`.
///
/// The final `Extension(server)` layer is added after both routes and applies
/// to both.
pub fn routes(server: Arc<Server>) -> Router<(), Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
1101
1102pub async fn handle_websocket_request(
1103 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1104 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1105 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1106 Extension(server): Extension<Arc<Server>>,
1107 Extension(principal): Extension<Principal>,
1108 country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1109 system_id_header: Option<TypedHeader<SystemIdHeader>>,
1110 ws: WebSocketUpgrade,
1111) -> axum::response::Response {
1112 if protocol_version != rpc::PROTOCOL_VERSION {
1113 return (
1114 StatusCode::UPGRADE_REQUIRED,
1115 "client must be upgraded".to_string(),
1116 )
1117 .into_response();
1118 }
1119
1120 let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1121 return (
1122 StatusCode::UPGRADE_REQUIRED,
1123 "no version header found".to_string(),
1124 )
1125 .into_response();
1126 };
1127
1128 if !version.can_collaborate() {
1129 return (
1130 StatusCode::UPGRADE_REQUIRED,
1131 "client must be upgraded".to_string(),
1132 )
1133 .into_response();
1134 }
1135
1136 let socket_address = socket_address.to_string();
1137 ws.on_upgrade(move |socket| {
1138 let socket = socket
1139 .map_ok(to_tungstenite_message)
1140 .err_into()
1141 .with(|message| async move { to_axum_message(message) });
1142 let connection = Connection::new(Box::pin(socket));
1143 async move {
1144 server
1145 .handle_connection(
1146 connection,
1147 socket_address,
1148 principal,
1149 version,
1150 country_code_header.map(|header| header.to_string()),
1151 system_id_header.map(|header| header.to_string()),
1152 None,
1153 Executor::Production,
1154 )
1155 .await;
1156 }
1157 })
1158}
1159
1160pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1161 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1162 let connections_metric = CONNECTIONS_METRIC
1163 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1164
1165 let connections = server
1166 .connection_pool
1167 .lock()
1168 .connections()
1169 .filter(|connection| !connection.admin)
1170 .count();
1171 connections_metric.set(connections as _);
1172
1173 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1174 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1175 register_int_gauge!(
1176 "shared_projects",
1177 "number of open projects with one or more guests"
1178 )
1179 .unwrap()
1180 });
1181
1182 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1183 shared_projects_metric.set(shared_projects as _);
1184
1185 let encoder = prometheus::TextEncoder::new();
1186 let metric_families = prometheus::gather();
1187 let encoded_metrics = encoder
1188 .encode_to_string(&metric_families)
1189 .map_err(|err| anyhow!("{}", err))?;
1190 Ok(encoded_metrics)
1191}
1192
1193#[instrument(err, skip(executor))]
1194async fn connection_lost(
1195 session: Session,
1196 mut teardown: watch::Receiver<bool>,
1197 executor: Executor,
1198) -> Result<()> {
1199 session.peer.disconnect(session.connection_id);
1200 session
1201 .connection_pool()
1202 .await
1203 .remove_connection(session.connection_id)?;
1204
1205 session
1206 .db()
1207 .await
1208 .connection_lost(session.connection_id)
1209 .await
1210 .trace_err();
1211
1212 futures::select_biased! {
1213 _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1214
1215 log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1216 leave_room_for_session(&session, session.connection_id).await.trace_err();
1217 leave_channel_buffers_for_session(&session)
1218 .await
1219 .trace_err();
1220
1221 if !session
1222 .connection_pool()
1223 .await
1224 .is_user_online(session.user_id())
1225 {
1226 let db = session.db().await;
1227 if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1228 room_updated(&room, &session.peer);
1229 }
1230 }
1231
1232 update_user_contacts(session.user_id(), &session).await?;
1233 },
1234 _ = teardown.changed().fuse() => {}
1235 }
1236
1237 Ok(())
1238}
1239
1240/// Acknowledges a ping from a client, used to keep the connection alive.
1241async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1242 response.send(proto::Ack {})?;
1243 Ok(())
1244}
1245
1246/// Creates a new room for calling (outside of channels)
1247async fn create_room(
1248 _request: proto::CreateRoom,
1249 response: Response<proto::CreateRoom>,
1250 session: Session,
1251) -> Result<()> {
1252 let livekit_room = nanoid::nanoid!(30);
1253
1254 let live_kit_connection_info = util::maybe!(async {
1255 let live_kit = session.app_state.livekit_client.as_ref();
1256 let live_kit = live_kit?;
1257 let user_id = session.user_id().to_string();
1258
1259 let token = live_kit
1260 .room_token(&livekit_room, &user_id.to_string())
1261 .trace_err()?;
1262
1263 Some(proto::LiveKitConnectionInfo {
1264 server_url: live_kit.url().into(),
1265 token,
1266 can_publish: true,
1267 })
1268 })
1269 .await;
1270
1271 let room = session
1272 .db()
1273 .await
1274 .create_room(session.user_id(), session.connection_id, &livekit_room)
1275 .await?;
1276
1277 response.send(proto::CreateRoomResponse {
1278 room: Some(room.clone()),
1279 live_kit_connection_info,
1280 })?;
1281
1282 update_user_contacts(session.user_id(), &session).await?;
1283 Ok(())
1284}
1285
/// Joins a room from an invitation.
///
/// If the room belongs to a channel, the request is delegated to
/// `join_channel_internal`. Otherwise, in order:
/// - the user is added to the room and the room update is broadcast;
/// - every connection of this user is sent `CallCanceled` for the room
///   (failures logged);
/// - the joiner receives the room, plus LiveKit connection info when LiveKit is
///   configured and a token could be minted;
/// - the user's contacts are updated.
async fn join_room(
    request: proto::JoinRoom,
    response: Response<proto::JoinRoom>,
    session: Session,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.id);

    let channel_id = session.db().await.channel_id_for_room(room_id).await?;

    // Rooms that belong to a channel are joined through the channel path.
    if let Some(channel_id) = channel_id {
        return join_channel_internal(channel_id, Box::new(response), session).await;
    }

    // `room` wraps the joined room (hence `into_inner`). The block drops it
    // before the connection pool is awaited below.
    let joined_room = {
        let room = session
            .db()
            .await
            .join_room(room_id, session.user_id(), session.connection_id)
            .await?;
        room_updated(&room.room, &session.peer);
        room.into_inner()
    };

    // Presumably this stops the invitation ringing on the user's other
    // connections; confirm against the client.
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(session.user_id())
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }

    // A token failure is logged and yields no LiveKit info rather than an error.
    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
    {
        live_kit
            .room_token(
                &joined_room.room.livekit_room,
                &session.user_id().to_string(),
            )
            .trace_err()
            .map(|token| proto::LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish: true,
            })
    } else {
        None
    };

    response.send(proto::JoinRoomResponse {
        room: Some(joined_room.room),
        channel_id: None,
        live_kit_connection_info,
    })?;

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1352
/// Restores a participant's room membership after connection errors.
///
/// Inside the inner block:
/// - rejoins in the database;
/// - replies with the room and the reshared/rejoined projects;
/// - broadcasts the room update;
/// - for each project this user reshared:
///   - tells every collaborator the user's new peer id (failures logged);
///   - forwards the project's worktree list to the other collaborators;
/// - streams rejoined-project state to this session via
///   `notify_rejoined_projects`.
///
/// After the block, the room's channel (if any) is updated, then contacts.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    let room;
    let channel;
    // Scoped so `rejoined_room` is dropped (via `into_inner`) before the
    // connection pool is awaited for the channel update below.
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        // Projects this user was hosting: point collaborators at the new peer
        // id, then forward the current worktree list to everyone but us.
        for project in &rejoined_room.reshared_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1444
1445fn notify_rejoined_projects(
1446 rejoined_projects: &mut Vec<RejoinedProject>,
1447 session: &Session,
1448) -> Result<()> {
1449 for project in rejoined_projects.iter() {
1450 for collaborator in &project.collaborators {
1451 session
1452 .peer
1453 .send(
1454 collaborator.connection_id,
1455 proto::UpdateProjectCollaborator {
1456 project_id: project.id.to_proto(),
1457 old_peer_id: Some(project.old_connection_id.into()),
1458 new_peer_id: Some(session.connection_id.into()),
1459 },
1460 )
1461 .trace_err();
1462 }
1463 }
1464
1465 for project in rejoined_projects {
1466 for worktree in mem::take(&mut project.worktrees) {
1467 // Stream this worktree's entries.
1468 let message = proto::UpdateWorktree {
1469 project_id: project.id.to_proto(),
1470 worktree_id: worktree.id,
1471 abs_path: worktree.abs_path.clone(),
1472 root_name: worktree.root_name,
1473 updated_entries: worktree.updated_entries,
1474 removed_entries: worktree.removed_entries,
1475 scan_id: worktree.scan_id,
1476 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1477 updated_repositories: worktree.updated_repositories,
1478 removed_repositories: worktree.removed_repositories,
1479 };
1480 for update in proto::split_worktree_update(message) {
1481 session.peer.send(session.connection_id, update)?;
1482 }
1483
1484 // Stream this worktree's diagnostics.
1485 for summary in worktree.diagnostic_summaries {
1486 session.peer.send(
1487 session.connection_id,
1488 proto::UpdateDiagnosticSummary {
1489 project_id: project.id.to_proto(),
1490 worktree_id: worktree.id,
1491 summary: Some(summary),
1492 },
1493 )?;
1494 }
1495
1496 for settings_file in worktree.settings_files {
1497 session.peer.send(
1498 session.connection_id,
1499 proto::UpdateWorktreeSettings {
1500 project_id: project.id.to_proto(),
1501 worktree_id: worktree.id,
1502 path: settings_file.path,
1503 content: Some(settings_file.content),
1504 kind: Some(settings_file.kind.to_proto().into()),
1505 },
1506 )?;
1507 }
1508 }
1509
1510 for repository in mem::take(&mut project.updated_repositories) {
1511 for update in split_repository_update(repository) {
1512 session.peer.send(session.connection_id, update)?;
1513 }
1514 }
1515
1516 for id in mem::take(&mut project.removed_repositories) {
1517 session.peer.send(
1518 session.connection_id,
1519 proto::RemoveRepository {
1520 project_id: project.id.to_proto(),
1521 id,
1522 },
1523 )?;
1524 }
1525 }
1526
1527 Ok(())
1528}
1529
1530/// leave room disconnects from the room.
1531async fn leave_room(
1532 _: proto::LeaveRoom,
1533 response: Response<proto::LeaveRoom>,
1534 session: Session,
1535) -> Result<()> {
1536 leave_room_for_session(&session, session.connection_id).await?;
1537 response.send(proto::Ack {})?;
1538 Ok(())
1539}
1540
1541/// Updates the permissions of someone else in the room.
1542async fn set_room_participant_role(
1543 request: proto::SetRoomParticipantRole,
1544 response: Response<proto::SetRoomParticipantRole>,
1545 session: Session,
1546) -> Result<()> {
1547 let user_id = UserId::from_proto(request.user_id);
1548 let role = ChannelRole::from(request.role());
1549
1550 let (livekit_room, can_publish) = {
1551 let room = session
1552 .db()
1553 .await
1554 .set_room_participant_role(
1555 session.user_id(),
1556 RoomId::from_proto(request.room_id),
1557 user_id,
1558 role,
1559 )
1560 .await?;
1561
1562 let livekit_room = room.livekit_room.clone();
1563 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1564 room_updated(&room, &session.peer);
1565 (livekit_room, can_publish)
1566 };
1567
1568 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1569 live_kit
1570 .update_participant(
1571 livekit_room.clone(),
1572 request.user_id.to_string(),
1573 livekit_api::proto::ParticipantPermission {
1574 can_subscribe: true,
1575 can_publish,
1576 can_publish_data: can_publish,
1577 hidden: false,
1578 recorder: false,
1579 },
1580 )
1581 .await
1582 .trace_err();
1583 }
1584
1585 response.send(proto::Ack {})?;
1586 Ok(())
1587}
1588
/// Calls another user into the caller's room.
///
/// The callee must be one of the caller's contacts. The call is recorded in the
/// database, the room update is broadcast, and the callee's contacts are
/// updated.
///
/// The incoming call is then sent as a request to every connection of the
/// callee concurrently. The first successful response acknowledges this request
/// and returns; returning drops `calls`, abandoning the remaining in-flight
/// requests. Failed responses are logged.
///
/// If no connection responds successfully, in order:
/// - the call is marked failed in the database;
/// - the room is broadcast again;
/// - the callee's contacts are updated;
/// - an error is returned.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: Session,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    // `mem::take` moves the incoming-call message out of the database result,
    // so it can be used after the block ends.
    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(room, &session.peer);
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    // Take the first connection that responds successfully; log the failures.
    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    // No connection responded successfully: record the failed call.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1657
/// Cancels an outgoing call from this connection to `called_user_id`.
///
/// In order:
/// - the room is updated in the database and the update is broadcast;
/// - every connection of the callee is sent `CallCanceled` (failures logged);
/// - the request is acknowledged;
/// - the callee's contacts are updated.
async fn cancel_call(
    request: proto::CancelCall,
    response: Response<proto::CancelCall>,
    session: Session,
) -> Result<()> {
    let called_user_id = UserId::from_proto(request.called_user_id);
    let room_id = RoomId::from_proto(request.room_id);
    // Scoped so the database result is dropped before the connection pool is awaited.
    {
        let room = session
            .db()
            .await
            .cancel_call(room_id, session.connection_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }
    response.send(proto::Ack {})?;

    update_user_contacts(called_user_id, &session).await?;
    Ok(())
}
1695
/// Declines an incoming call to `room_id` for this user.
///
/// This is a one-way message; there is no response to send. It fails with
/// "failed to decline call" when the database returns no room.
///
/// On success, in order:
/// - the room update is broadcast;
/// - every connection of this user is sent `CallCanceled` (failures logged);
/// - the user's contacts are updated.
async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
    let room_id = RoomId::from_proto(message.room_id);
    // Scoped so the database result is dropped before the connection pool is awaited.
    {
        let room = session
            .db()
            .await
            .decline_call(Some(room_id), session.user_id())
            .await?
            .ok_or_else(|| anyhow!("failed to decline call"))?;
        room_updated(&room, &session.peer);
    }

    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(session.user_id())
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }
    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1727
/// Records this connection's location in the room and broadcasts the updated
/// room to the other participants.
///
/// Fails with "invalid location" when the request carries no location.
///
/// Note: `db` lives until the function returns, so the database handle from
/// `session.db()` is still held while the room is broadcast and the request is
/// acknowledged.
async fn update_participant_location(
    request: proto::UpdateParticipantLocation,
    response: Response<proto::UpdateParticipantLocation>,
    session: Session,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let location = request
        .location
        .ok_or_else(|| anyhow!("invalid location"))?;

    let db = session.db().await;
    let room = db
        .update_room_participant_location(room_id, session.connection_id, location)
        .await?;

    room_updated(&room, &session.peer);
    response.send(proto::Ack {})?;
    Ok(())
}
1748
/// Shares a project from this connection into the given room.
///
/// The project is created in the database from the request's worktrees and
/// `is_ssh_project` flag. The new project id is sent back before the room
/// update is broadcast.
async fn share_project(
    request: proto::ShareProject,
    response: Response<proto::ShareProject>,
    session: Session,
) -> Result<()> {
    // The `&*` borrow keeps the database result alive for the rest of the
    // function, so `project_id` and `room` can be used below.
    let (project_id, room) = &*session
        .db()
        .await
        .share_project(
            RoomId::from_proto(request.room_id),
            session.connection_id,
            &request.worktrees,
            request.is_ssh_project,
        )
        .await?;
    response.send(proto::ShareProjectResponse {
        project_id: project_id.to_proto(),
    })?;
    room_updated(room, &session.peer);

    Ok(())
}
1772
1773/// Unshare a project from the room.
1774async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1775 let project_id = ProjectId::from_proto(message.project_id);
1776 unshare_project_internal(project_id, session.connection_id, &session).await
1777}
1778
/// Stops sharing `project_id` on behalf of `connection_id`.
///
/// - Every guest connection other than `connection_id` is sent
///   `UnshareProject` (failures logged).
/// - The room, if the database returned one, is broadcast.
/// - If the database reports that the project should be deleted, it is deleted
///   afterwards in a separate database call.
async fn unshare_project_internal(
    project_id: ProjectId,
    connection_id: ConnectionId,
    session: &Session,
) -> Result<()> {
    // The block drops `room_guard` before the database is accessed again below
    // for the deletion.
    let delete = {
        let room_guard = session
            .db()
            .await
            .unshare_project(project_id, connection_id)
            .await?;

        let (delete, room, guest_connection_ids) = &*room_guard;

        let message = proto::UnshareProject {
            project_id: project_id.to_proto(),
        };

        broadcast(
            Some(connection_id),
            guest_connection_ids.iter().copied(),
            |conn_id| session.peer.send(conn_id, message.clone()),
        );
        if let Some(room) = room {
            room_updated(room, &session.peer);
        }

        *delete
    };

    if delete {
        let db = session.db().await;
        db.delete_project(project_id).await?;
    }

    Ok(())
}
1816
/// Joins someone else's shared project as a guest.
///
/// Records the join in the database, then passes the resulting project and
/// replica id to `join_project_internal`. That function notifies the other
/// collaborators and streams the project state to this session.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: Session,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    // Release the database handle before the synchronous streaming done by
    // `join_project_internal`. The join result borrowed above stays usable.
    drop(db);
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
1835
/// Something that can deliver a `JoinProjectResponse`, so that
/// `join_project_internal` is not tied to `Response<proto::JoinProject>`.
/// Presumably other responders implement this elsewhere in the file; confirm.
trait JoinProjectInternalResponse {
    /// Consumes the responder and sends `result`.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
/// Delivers the join result as the reply to a `JoinProject` RPC request.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Fully qualified so it is explicit that this forwards to `Response`'s
        // own `send`, not to this trait method.
        Response::<proto::JoinProject>::send(self, result)
    }
}
1844
1845fn join_project_internal(
1846 response: impl JoinProjectInternalResponse,
1847 session: Session,
1848 project: &mut Project,
1849 replica_id: &ReplicaId,
1850) -> Result<()> {
1851 let collaborators = project
1852 .collaborators
1853 .iter()
1854 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1855 .map(|collaborator| collaborator.to_proto())
1856 .collect::<Vec<_>>();
1857 let project_id = project.id;
1858 let guest_user_id = session.user_id();
1859
1860 let worktrees = project
1861 .worktrees
1862 .iter()
1863 .map(|(id, worktree)| proto::WorktreeMetadata {
1864 id: *id,
1865 root_name: worktree.root_name.clone(),
1866 visible: worktree.visible,
1867 abs_path: worktree.abs_path.clone(),
1868 })
1869 .collect::<Vec<_>>();
1870
1871 let add_project_collaborator = proto::AddProjectCollaborator {
1872 project_id: project_id.to_proto(),
1873 collaborator: Some(proto::Collaborator {
1874 peer_id: Some(session.connection_id.into()),
1875 replica_id: replica_id.0 as u32,
1876 user_id: guest_user_id.to_proto(),
1877 is_host: false,
1878 }),
1879 };
1880
1881 for collaborator in &collaborators {
1882 session
1883 .peer
1884 .send(
1885 collaborator.peer_id.unwrap().into(),
1886 add_project_collaborator.clone(),
1887 )
1888 .trace_err();
1889 }
1890
1891 // First, we send the metadata associated with each worktree.
1892 response.send(proto::JoinProjectResponse {
1893 project_id: project.id.0 as u64,
1894 worktrees: worktrees.clone(),
1895 replica_id: replica_id.0 as u32,
1896 collaborators: collaborators.clone(),
1897 language_servers: project.language_servers.clone(),
1898 role: project.role.into(),
1899 })?;
1900
1901 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1902 // Stream this worktree's entries.
1903 let message = proto::UpdateWorktree {
1904 project_id: project_id.to_proto(),
1905 worktree_id,
1906 abs_path: worktree.abs_path.clone(),
1907 root_name: worktree.root_name,
1908 updated_entries: worktree.entries,
1909 removed_entries: Default::default(),
1910 scan_id: worktree.scan_id,
1911 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1912 updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1913 removed_repositories: Default::default(),
1914 };
1915 for update in proto::split_worktree_update(message) {
1916 session.peer.send(session.connection_id, update.clone())?;
1917 }
1918
1919 // Stream this worktree's diagnostics.
1920 for summary in worktree.diagnostic_summaries {
1921 session.peer.send(
1922 session.connection_id,
1923 proto::UpdateDiagnosticSummary {
1924 project_id: project_id.to_proto(),
1925 worktree_id: worktree.id,
1926 summary: Some(summary),
1927 },
1928 )?;
1929 }
1930
1931 for settings_file in worktree.settings_files {
1932 session.peer.send(
1933 session.connection_id,
1934 proto::UpdateWorktreeSettings {
1935 project_id: project_id.to_proto(),
1936 worktree_id: worktree.id,
1937 path: settings_file.path,
1938 content: Some(settings_file.content),
1939 kind: Some(settings_file.kind.to_proto() as i32),
1940 },
1941 )?;
1942 }
1943 }
1944
1945 for repository in mem::take(&mut project.repositories) {
1946 for update in split_repository_update(repository) {
1947 session.peer.send(session.connection_id, update)?;
1948 }
1949 }
1950
1951 for language_server in &project.language_servers {
1952 session.peer.send(
1953 session.connection_id,
1954 proto::UpdateLanguageServer {
1955 project_id: project_id.to_proto(),
1956 language_server_id: language_server.id,
1957 variant: Some(
1958 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1959 proto::LspDiskBasedDiagnosticsUpdated {},
1960 ),
1961 ),
1962 },
1963 )?;
1964 }
1965
1966 Ok(())
1967}
1968
/// Removes this connection from a project it had joined.
///
/// The other collaborators are notified through `project_left`, and the room,
/// if the database returned one, is broadcast.
///
/// Note: `db` lives until the function returns. The database handle from
/// `session.db()` is therefore still held during those notifications.
async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
    let sender_id = session.connection_id;
    let project_id = ProjectId::from_proto(request.project_id);
    let db = session.db().await;

    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
    tracing::info!(
        %project_id,
        "leave project"
    );

    project_left(project, &session);
    if let Some(room) = room {
        room_updated(room, &session.peer);
    }

    Ok(())
}
1988
/// Applies this connection's project update (its worktree list) and forwards
/// the request to the project's other connections.
///
/// Presumably the sender is the host; confirm in `db::update_project`. After
/// forwarding, the room (if the database returned one) is broadcast, and the
/// request is acknowledged. Forwarding failures are logged by `broadcast`.
async fn update_project(
    request: proto::UpdateProject,
    response: Response<proto::UpdateProject>,
    session: Session,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    // The `&*` borrow keeps the database result alive for the rest of the function.
    let (room, guest_connection_ids) = &*session
        .db()
        .await
        .update_project(project_id, session.connection_id, &request.worktrees)
        .await?;
    broadcast(
        Some(session.connection_id),
        guest_connection_ids.iter().copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())
        },
    );
    if let Some(room) = room {
        room_updated(room, &session.peer);
    }
    response.send(proto::Ack {})?;

    Ok(())
}
2017
2018/// Updates other participants with changes to the worktree
2019async fn update_worktree(
2020 request: proto::UpdateWorktree,
2021 response: Response<proto::UpdateWorktree>,
2022 session: Session,
2023) -> Result<()> {
2024 let guest_connection_ids = session
2025 .db()
2026 .await
2027 .update_worktree(&request, session.connection_id)
2028 .await?;
2029
2030 broadcast(
2031 Some(session.connection_id),
2032 guest_connection_ids.iter().copied(),
2033 |connection_id| {
2034 session
2035 .peer
2036 .forward_send(session.connection_id, connection_id, request.clone())
2037 },
2038 );
2039 response.send(proto::Ack {})?;
2040 Ok(())
2041}
2042
2043async fn update_repository(
2044 request: proto::UpdateRepository,
2045 response: Response<proto::UpdateRepository>,
2046 session: Session,
2047) -> Result<()> {
2048 let guest_connection_ids = session
2049 .db()
2050 .await
2051 .update_repository(&request, session.connection_id)
2052 .await?;
2053
2054 broadcast(
2055 Some(session.connection_id),
2056 guest_connection_ids.iter().copied(),
2057 |connection_id| {
2058 session
2059 .peer
2060 .forward_send(session.connection_id, connection_id, request.clone())
2061 },
2062 );
2063 response.send(proto::Ack {})?;
2064 Ok(())
2065}
2066
2067async fn remove_repository(
2068 request: proto::RemoveRepository,
2069 response: Response<proto::RemoveRepository>,
2070 session: Session,
2071) -> Result<()> {
2072 let guest_connection_ids = session
2073 .db()
2074 .await
2075 .remove_repository(&request, session.connection_id)
2076 .await?;
2077
2078 broadcast(
2079 Some(session.connection_id),
2080 guest_connection_ids.iter().copied(),
2081 |connection_id| {
2082 session
2083 .peer
2084 .forward_send(session.connection_id, connection_id, request.clone())
2085 },
2086 );
2087 response.send(proto::Ack {})?;
2088 Ok(())
2089}
2090
2091/// Updates other participants with changes to the diagnostics
2092async fn update_diagnostic_summary(
2093 message: proto::UpdateDiagnosticSummary,
2094 session: Session,
2095) -> Result<()> {
2096 let guest_connection_ids = session
2097 .db()
2098 .await
2099 .update_diagnostic_summary(&message, session.connection_id)
2100 .await?;
2101
2102 broadcast(
2103 Some(session.connection_id),
2104 guest_connection_ids.iter().copied(),
2105 |connection_id| {
2106 session
2107 .peer
2108 .forward_send(session.connection_id, connection_id, message.clone())
2109 },
2110 );
2111
2112 Ok(())
2113}
2114
2115/// Updates other participants with changes to the worktree settings
2116async fn update_worktree_settings(
2117 message: proto::UpdateWorktreeSettings,
2118 session: Session,
2119) -> Result<()> {
2120 let guest_connection_ids = session
2121 .db()
2122 .await
2123 .update_worktree_settings(&message, session.connection_id)
2124 .await?;
2125
2126 broadcast(
2127 Some(session.connection_id),
2128 guest_connection_ids.iter().copied(),
2129 |connection_id| {
2130 session
2131 .peer
2132 .forward_send(session.connection_id, connection_id, message.clone())
2133 },
2134 );
2135
2136 Ok(())
2137}
2138
2139/// Notify other participants that a language server has started.
2140async fn start_language_server(
2141 request: proto::StartLanguageServer,
2142 session: Session,
2143) -> Result<()> {
2144 let guest_connection_ids = session
2145 .db()
2146 .await
2147 .start_language_server(&request, session.connection_id)
2148 .await?;
2149
2150 broadcast(
2151 Some(session.connection_id),
2152 guest_connection_ids.iter().copied(),
2153 |connection_id| {
2154 session
2155 .peer
2156 .forward_send(session.connection_id, connection_id, request.clone())
2157 },
2158 );
2159 Ok(())
2160}
2161
2162/// Notify other participants that a language server has changed.
2163async fn update_language_server(
2164 request: proto::UpdateLanguageServer,
2165 session: Session,
2166) -> Result<()> {
2167 let project_id = ProjectId::from_proto(request.project_id);
2168 let project_connection_ids = session
2169 .db()
2170 .await
2171 .project_connection_ids(project_id, session.connection_id, true)
2172 .await?;
2173 broadcast(
2174 Some(session.connection_id),
2175 project_connection_ids.iter().copied(),
2176 |connection_id| {
2177 session
2178 .peer
2179 .forward_send(session.connection_id, connection_id, request.clone())
2180 },
2181 );
2182 Ok(())
2183}
2184
2185/// forward a project request to the host. These requests should be read only
2186/// as guests are allowed to send them.
2187async fn forward_read_only_project_request<T>(
2188 request: T,
2189 response: Response<T>,
2190 session: Session,
2191) -> Result<()>
2192where
2193 T: EntityMessage + RequestMessage,
2194{
2195 let project_id = ProjectId::from_proto(request.remote_entity_id());
2196 let host_connection_id = session
2197 .db()
2198 .await
2199 .host_for_read_only_project_request(project_id, session.connection_id)
2200 .await?;
2201 let payload = session
2202 .peer
2203 .forward_request(session.connection_id, host_connection_id, request)
2204 .await?;
2205 response.send(payload)?;
2206 Ok(())
2207}
2208
2209async fn forward_find_search_candidates_request(
2210 request: proto::FindSearchCandidates,
2211 response: Response<proto::FindSearchCandidates>,
2212 session: Session,
2213) -> Result<()> {
2214 let project_id = ProjectId::from_proto(request.remote_entity_id());
2215 let host_connection_id = session
2216 .db()
2217 .await
2218 .host_for_read_only_project_request(project_id, session.connection_id)
2219 .await?;
2220 let payload = session
2221 .peer
2222 .forward_request(session.connection_id, host_connection_id, request)
2223 .await?;
2224 response.send(payload)?;
2225 Ok(())
2226}
2227
2228/// forward a project request to the host. These requests are disallowed
2229/// for guests.
2230async fn forward_mutating_project_request<T>(
2231 request: T,
2232 response: Response<T>,
2233 session: Session,
2234) -> Result<()>
2235where
2236 T: EntityMessage + RequestMessage,
2237{
2238 let project_id = ProjectId::from_proto(request.remote_entity_id());
2239
2240 let host_connection_id = session
2241 .db()
2242 .await
2243 .host_for_mutating_project_request(project_id, session.connection_id)
2244 .await?;
2245 let payload = session
2246 .peer
2247 .forward_request(session.connection_id, host_connection_id, request)
2248 .await?;
2249 response.send(payload)?;
2250 Ok(())
2251}
2252
2253/// Notify other participants that a new buffer has been created
2254async fn create_buffer_for_peer(
2255 request: proto::CreateBufferForPeer,
2256 session: Session,
2257) -> Result<()> {
2258 session
2259 .db()
2260 .await
2261 .check_user_is_project_host(
2262 ProjectId::from_proto(request.project_id),
2263 session.connection_id,
2264 )
2265 .await?;
2266 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2267 session
2268 .peer
2269 .forward_send(session.connection_id, peer_id.into(), request)?;
2270 Ok(())
2271}
2272
2273/// Notify other participants that a buffer has been updated. This is
2274/// allowed for guests as long as the update is limited to selections.
2275async fn update_buffer(
2276 request: proto::UpdateBuffer,
2277 response: Response<proto::UpdateBuffer>,
2278 session: Session,
2279) -> Result<()> {
2280 let project_id = ProjectId::from_proto(request.project_id);
2281 let mut capability = Capability::ReadOnly;
2282
2283 for op in request.operations.iter() {
2284 match op.variant {
2285 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2286 Some(_) => capability = Capability::ReadWrite,
2287 }
2288 }
2289
2290 let host = {
2291 let guard = session
2292 .db()
2293 .await
2294 .connections_for_buffer_update(project_id, session.connection_id, capability)
2295 .await?;
2296
2297 let (host, guests) = &*guard;
2298
2299 broadcast(
2300 Some(session.connection_id),
2301 guests.clone(),
2302 |connection_id| {
2303 session
2304 .peer
2305 .forward_send(session.connection_id, connection_id, request.clone())
2306 },
2307 );
2308
2309 *host
2310 };
2311
2312 if host != session.connection_id {
2313 session
2314 .peer
2315 .forward_request(session.connection_id, host, request.clone())
2316 .await?;
2317 }
2318
2319 response.send(proto::Ack {})?;
2320 Ok(())
2321}
2322
2323async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2324 let project_id = ProjectId::from_proto(message.project_id);
2325
2326 let operation = message.operation.as_ref().context("invalid operation")?;
2327 let capability = match operation.variant.as_ref() {
2328 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2329 if let Some(buffer_op) = buffer_op.operation.as_ref() {
2330 match buffer_op.variant {
2331 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2332 Capability::ReadOnly
2333 }
2334 _ => Capability::ReadWrite,
2335 }
2336 } else {
2337 Capability::ReadWrite
2338 }
2339 }
2340 Some(_) => Capability::ReadWrite,
2341 None => Capability::ReadOnly,
2342 };
2343
2344 let guard = session
2345 .db()
2346 .await
2347 .connections_for_buffer_update(project_id, session.connection_id, capability)
2348 .await?;
2349
2350 let (host, guests) = &*guard;
2351
2352 broadcast(
2353 Some(session.connection_id),
2354 guests.iter().chain([host]).copied(),
2355 |connection_id| {
2356 session
2357 .peer
2358 .forward_send(session.connection_id, connection_id, message.clone())
2359 },
2360 );
2361
2362 Ok(())
2363}
2364
2365/// Notify other participants that a project has been updated.
2366async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2367 request: T,
2368 session: Session,
2369) -> Result<()> {
2370 let project_id = ProjectId::from_proto(request.remote_entity_id());
2371 let project_connection_ids = session
2372 .db()
2373 .await
2374 .project_connection_ids(project_id, session.connection_id, false)
2375 .await?;
2376
2377 broadcast(
2378 Some(session.connection_id),
2379 project_connection_ids.iter().copied(),
2380 |connection_id| {
2381 session
2382 .peer
2383 .forward_send(session.connection_id, connection_id, request.clone())
2384 },
2385 );
2386 Ok(())
2387}
2388
2389/// Start following another user in a call.
2390async fn follow(
2391 request: proto::Follow,
2392 response: Response<proto::Follow>,
2393 session: Session,
2394) -> Result<()> {
2395 let room_id = RoomId::from_proto(request.room_id);
2396 let project_id = request.project_id.map(ProjectId::from_proto);
2397 let leader_id = request
2398 .leader_id
2399 .ok_or_else(|| anyhow!("invalid leader id"))?
2400 .into();
2401 let follower_id = session.connection_id;
2402
2403 session
2404 .db()
2405 .await
2406 .check_room_participants(room_id, leader_id, session.connection_id)
2407 .await?;
2408
2409 let response_payload = session
2410 .peer
2411 .forward_request(session.connection_id, leader_id, request)
2412 .await?;
2413 response.send(response_payload)?;
2414
2415 if let Some(project_id) = project_id {
2416 let room = session
2417 .db()
2418 .await
2419 .follow(room_id, project_id, leader_id, follower_id)
2420 .await?;
2421 room_updated(&room, &session.peer);
2422 }
2423
2424 Ok(())
2425}
2426
2427/// Stop following another user in a call.
2428async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2429 let room_id = RoomId::from_proto(request.room_id);
2430 let project_id = request.project_id.map(ProjectId::from_proto);
2431 let leader_id = request
2432 .leader_id
2433 .ok_or_else(|| anyhow!("invalid leader id"))?
2434 .into();
2435 let follower_id = session.connection_id;
2436
2437 session
2438 .db()
2439 .await
2440 .check_room_participants(room_id, leader_id, session.connection_id)
2441 .await?;
2442
2443 session
2444 .peer
2445 .forward_send(session.connection_id, leader_id, request)?;
2446
2447 if let Some(project_id) = project_id {
2448 let room = session
2449 .db()
2450 .await
2451 .unfollow(room_id, project_id, leader_id, follower_id)
2452 .await?;
2453 room_updated(&room, &session.peer);
2454 }
2455
2456 Ok(())
2457}
2458
2459/// Notify everyone following you of your current location.
2460async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2461 let room_id = RoomId::from_proto(request.room_id);
2462 let database = session.db.lock().await;
2463
2464 let connection_ids = if let Some(project_id) = request.project_id {
2465 let project_id = ProjectId::from_proto(project_id);
2466 database
2467 .project_connection_ids(project_id, session.connection_id, true)
2468 .await?
2469 } else {
2470 database
2471 .room_connection_ids(room_id, session.connection_id)
2472 .await?
2473 };
2474
2475 // For now, don't send view update messages back to that view's current leader.
2476 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2477 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2478 _ => None,
2479 });
2480
2481 for connection_id in connection_ids.iter().cloned() {
2482 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2483 session
2484 .peer
2485 .forward_send(session.connection_id, connection_id, request.clone())?;
2486 }
2487 }
2488 Ok(())
2489}
2490
2491/// Get public data about users.
2492async fn get_users(
2493 request: proto::GetUsers,
2494 response: Response<proto::GetUsers>,
2495 session: Session,
2496) -> Result<()> {
2497 let user_ids = request
2498 .user_ids
2499 .into_iter()
2500 .map(UserId::from_proto)
2501 .collect();
2502 let users = session
2503 .db()
2504 .await
2505 .get_users_by_ids(user_ids)
2506 .await?
2507 .into_iter()
2508 .map(|user| proto::User {
2509 id: user.id.to_proto(),
2510 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2511 github_login: user.github_login,
2512 email: user.email_address,
2513 name: user.name,
2514 })
2515 .collect();
2516 response.send(proto::UsersResponse { users })?;
2517 Ok(())
2518}
2519
2520/// Search for users (to invite) buy Github login
2521async fn fuzzy_search_users(
2522 request: proto::FuzzySearchUsers,
2523 response: Response<proto::FuzzySearchUsers>,
2524 session: Session,
2525) -> Result<()> {
2526 let query = request.query;
2527 let users = match query.len() {
2528 0 => vec![],
2529 1 | 2 => session
2530 .db()
2531 .await
2532 .get_user_by_github_login(&query)
2533 .await?
2534 .into_iter()
2535 .collect(),
2536 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2537 };
2538 let users = users
2539 .into_iter()
2540 .filter(|user| user.id != session.user_id())
2541 .map(|user| proto::User {
2542 id: user.id.to_proto(),
2543 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2544 github_login: user.github_login,
2545 name: user.name,
2546 email: user.email_address,
2547 })
2548 .collect();
2549 response.send(proto::UsersResponse { users })?;
2550 Ok(())
2551}
2552
2553/// Send a contact request to another user.
2554async fn request_contact(
2555 request: proto::RequestContact,
2556 response: Response<proto::RequestContact>,
2557 session: Session,
2558) -> Result<()> {
2559 let requester_id = session.user_id();
2560 let responder_id = UserId::from_proto(request.responder_id);
2561 if requester_id == responder_id {
2562 return Err(anyhow!("cannot add yourself as a contact"))?;
2563 }
2564
2565 let notifications = session
2566 .db()
2567 .await
2568 .send_contact_request(requester_id, responder_id)
2569 .await?;
2570
2571 // Update outgoing contact requests of requester
2572 let mut update = proto::UpdateContacts::default();
2573 update.outgoing_requests.push(responder_id.to_proto());
2574 for connection_id in session
2575 .connection_pool()
2576 .await
2577 .user_connection_ids(requester_id)
2578 {
2579 session.peer.send(connection_id, update.clone())?;
2580 }
2581
2582 // Update incoming contact requests of responder
2583 let mut update = proto::UpdateContacts::default();
2584 update
2585 .incoming_requests
2586 .push(proto::IncomingContactRequest {
2587 requester_id: requester_id.to_proto(),
2588 });
2589 let connection_pool = session.connection_pool().await;
2590 for connection_id in connection_pool.user_connection_ids(responder_id) {
2591 session.peer.send(connection_id, update.clone())?;
2592 }
2593
2594 send_notifications(&connection_pool, &session.peer, notifications);
2595
2596 response.send(proto::Ack {})?;
2597 Ok(())
2598}
2599
2600/// Accept or decline a contact request
2601async fn respond_to_contact_request(
2602 request: proto::RespondToContactRequest,
2603 response: Response<proto::RespondToContactRequest>,
2604 session: Session,
2605) -> Result<()> {
2606 let responder_id = session.user_id();
2607 let requester_id = UserId::from_proto(request.requester_id);
2608 let db = session.db().await;
2609 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2610 db.dismiss_contact_notification(responder_id, requester_id)
2611 .await?;
2612 } else {
2613 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2614
2615 let notifications = db
2616 .respond_to_contact_request(responder_id, requester_id, accept)
2617 .await?;
2618 let requester_busy = db.is_user_busy(requester_id).await?;
2619 let responder_busy = db.is_user_busy(responder_id).await?;
2620
2621 let pool = session.connection_pool().await;
2622 // Update responder with new contact
2623 let mut update = proto::UpdateContacts::default();
2624 if accept {
2625 update
2626 .contacts
2627 .push(contact_for_user(requester_id, requester_busy, &pool));
2628 }
2629 update
2630 .remove_incoming_requests
2631 .push(requester_id.to_proto());
2632 for connection_id in pool.user_connection_ids(responder_id) {
2633 session.peer.send(connection_id, update.clone())?;
2634 }
2635
2636 // Update requester with new contact
2637 let mut update = proto::UpdateContacts::default();
2638 if accept {
2639 update
2640 .contacts
2641 .push(contact_for_user(responder_id, responder_busy, &pool));
2642 }
2643 update
2644 .remove_outgoing_requests
2645 .push(responder_id.to_proto());
2646
2647 for connection_id in pool.user_connection_ids(requester_id) {
2648 session.peer.send(connection_id, update.clone())?;
2649 }
2650
2651 send_notifications(&pool, &session.peer, notifications);
2652 }
2653
2654 response.send(proto::Ack {})?;
2655 Ok(())
2656}
2657
2658/// Remove a contact.
2659async fn remove_contact(
2660 request: proto::RemoveContact,
2661 response: Response<proto::RemoveContact>,
2662 session: Session,
2663) -> Result<()> {
2664 let requester_id = session.user_id();
2665 let responder_id = UserId::from_proto(request.user_id);
2666 let db = session.db().await;
2667 let (contact_accepted, deleted_notification_id) =
2668 db.remove_contact(requester_id, responder_id).await?;
2669
2670 let pool = session.connection_pool().await;
2671 // Update outgoing contact requests of requester
2672 let mut update = proto::UpdateContacts::default();
2673 if contact_accepted {
2674 update.remove_contacts.push(responder_id.to_proto());
2675 } else {
2676 update
2677 .remove_outgoing_requests
2678 .push(responder_id.to_proto());
2679 }
2680 for connection_id in pool.user_connection_ids(requester_id) {
2681 session.peer.send(connection_id, update.clone())?;
2682 }
2683
2684 // Update incoming contact requests of responder
2685 let mut update = proto::UpdateContacts::default();
2686 if contact_accepted {
2687 update.remove_contacts.push(requester_id.to_proto());
2688 } else {
2689 update
2690 .remove_incoming_requests
2691 .push(requester_id.to_proto());
2692 }
2693 for connection_id in pool.user_connection_ids(responder_id) {
2694 session.peer.send(connection_id, update.clone())?;
2695 if let Some(notification_id) = deleted_notification_id {
2696 session.peer.send(
2697 connection_id,
2698 proto::DeleteNotification {
2699 notification_id: notification_id.to_proto(),
2700 },
2701 )?;
2702 }
2703 }
2704
2705 response.send(proto::Ack {})?;
2706 Ok(())
2707}
2708
2709fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2710 version.0.minor() < 139
2711}
2712
2713async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
2714 let plan = session.current_plan(&session.db().await).await?;
2715
2716 session
2717 .peer
2718 .send(
2719 session.connection_id,
2720 proto::UpdateUserPlan { plan: plan.into() },
2721 )
2722 .trace_err();
2723
2724 Ok(())
2725}
2726
2727async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2728 subscribe_user_to_channels(session.user_id(), &session).await?;
2729 Ok(())
2730}
2731
2732async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2733 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2734 let mut pool = session.connection_pool().await;
2735 for membership in &channels_for_user.channel_memberships {
2736 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2737 }
2738 session.peer.send(
2739 session.connection_id,
2740 build_update_user_channels(&channels_for_user),
2741 )?;
2742 session.peer.send(
2743 session.connection_id,
2744 build_channels_update(channels_for_user),
2745 )?;
2746 Ok(())
2747}
2748
2749/// Creates a new channel.
2750async fn create_channel(
2751 request: proto::CreateChannel,
2752 response: Response<proto::CreateChannel>,
2753 session: Session,
2754) -> Result<()> {
2755 let db = session.db().await;
2756
2757 let parent_id = request.parent_id.map(ChannelId::from_proto);
2758 let (channel, membership) = db
2759 .create_channel(&request.name, parent_id, session.user_id())
2760 .await?;
2761
2762 let root_id = channel.root_id();
2763 let channel = Channel::from_model(channel);
2764
2765 response.send(proto::CreateChannelResponse {
2766 channel: Some(channel.to_proto()),
2767 parent_id: request.parent_id,
2768 })?;
2769
2770 let mut connection_pool = session.connection_pool().await;
2771 if let Some(membership) = membership {
2772 connection_pool.subscribe_to_channel(
2773 membership.user_id,
2774 membership.channel_id,
2775 membership.role,
2776 );
2777 let update = proto::UpdateUserChannels {
2778 channel_memberships: vec![proto::ChannelMembership {
2779 channel_id: membership.channel_id.to_proto(),
2780 role: membership.role.into(),
2781 }],
2782 ..Default::default()
2783 };
2784 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2785 session.peer.send(connection_id, update.clone())?;
2786 }
2787 }
2788
2789 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2790 if !role.can_see_channel(channel.visibility) {
2791 continue;
2792 }
2793
2794 let update = proto::UpdateChannels {
2795 channels: vec![channel.to_proto()],
2796 ..Default::default()
2797 };
2798 session.peer.send(connection_id, update.clone())?;
2799 }
2800
2801 Ok(())
2802}
2803
2804/// Delete a channel
2805async fn delete_channel(
2806 request: proto::DeleteChannel,
2807 response: Response<proto::DeleteChannel>,
2808 session: Session,
2809) -> Result<()> {
2810 let db = session.db().await;
2811
2812 let channel_id = request.channel_id;
2813 let (root_channel, removed_channels) = db
2814 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2815 .await?;
2816 response.send(proto::Ack {})?;
2817
2818 // Notify members of removed channels
2819 let mut update = proto::UpdateChannels::default();
2820 update
2821 .delete_channels
2822 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2823
2824 let connection_pool = session.connection_pool().await;
2825 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2826 session.peer.send(connection_id, update.clone())?;
2827 }
2828
2829 Ok(())
2830}
2831
2832/// Invite someone to join a channel.
2833async fn invite_channel_member(
2834 request: proto::InviteChannelMember,
2835 response: Response<proto::InviteChannelMember>,
2836 session: Session,
2837) -> Result<()> {
2838 let db = session.db().await;
2839 let channel_id = ChannelId::from_proto(request.channel_id);
2840 let invitee_id = UserId::from_proto(request.user_id);
2841 let InviteMemberResult {
2842 channel,
2843 notifications,
2844 } = db
2845 .invite_channel_member(
2846 channel_id,
2847 invitee_id,
2848 session.user_id(),
2849 request.role().into(),
2850 )
2851 .await?;
2852
2853 let update = proto::UpdateChannels {
2854 channel_invitations: vec![channel.to_proto()],
2855 ..Default::default()
2856 };
2857
2858 let connection_pool = session.connection_pool().await;
2859 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2860 session.peer.send(connection_id, update.clone())?;
2861 }
2862
2863 send_notifications(&connection_pool, &session.peer, notifications);
2864
2865 response.send(proto::Ack {})?;
2866 Ok(())
2867}
2868
2869/// remove someone from a channel
2870async fn remove_channel_member(
2871 request: proto::RemoveChannelMember,
2872 response: Response<proto::RemoveChannelMember>,
2873 session: Session,
2874) -> Result<()> {
2875 let db = session.db().await;
2876 let channel_id = ChannelId::from_proto(request.channel_id);
2877 let member_id = UserId::from_proto(request.user_id);
2878
2879 let RemoveChannelMemberResult {
2880 membership_update,
2881 notification_id,
2882 } = db
2883 .remove_channel_member(channel_id, member_id, session.user_id())
2884 .await?;
2885
2886 let mut connection_pool = session.connection_pool().await;
2887 notify_membership_updated(
2888 &mut connection_pool,
2889 membership_update,
2890 member_id,
2891 &session.peer,
2892 );
2893 for connection_id in connection_pool.user_connection_ids(member_id) {
2894 if let Some(notification_id) = notification_id {
2895 session
2896 .peer
2897 .send(
2898 connection_id,
2899 proto::DeleteNotification {
2900 notification_id: notification_id.to_proto(),
2901 },
2902 )
2903 .trace_err();
2904 }
2905 }
2906
2907 response.send(proto::Ack {})?;
2908 Ok(())
2909}
2910
2911/// Toggle the channel between public and private.
2912/// Care is taken to maintain the invariant that public channels only descend from public channels,
2913/// (though members-only channels can appear at any point in the hierarchy).
2914async fn set_channel_visibility(
2915 request: proto::SetChannelVisibility,
2916 response: Response<proto::SetChannelVisibility>,
2917 session: Session,
2918) -> Result<()> {
2919 let db = session.db().await;
2920 let channel_id = ChannelId::from_proto(request.channel_id);
2921 let visibility = request.visibility().into();
2922
2923 let channel_model = db
2924 .set_channel_visibility(channel_id, visibility, session.user_id())
2925 .await?;
2926 let root_id = channel_model.root_id();
2927 let channel = Channel::from_model(channel_model);
2928
2929 let mut connection_pool = session.connection_pool().await;
2930 for (user_id, role) in connection_pool
2931 .channel_user_ids(root_id)
2932 .collect::<Vec<_>>()
2933 .into_iter()
2934 {
2935 let update = if role.can_see_channel(channel.visibility) {
2936 connection_pool.subscribe_to_channel(user_id, channel_id, role);
2937 proto::UpdateChannels {
2938 channels: vec![channel.to_proto()],
2939 ..Default::default()
2940 }
2941 } else {
2942 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2943 proto::UpdateChannels {
2944 delete_channels: vec![channel.id.to_proto()],
2945 ..Default::default()
2946 }
2947 };
2948
2949 for connection_id in connection_pool.user_connection_ids(user_id) {
2950 session.peer.send(connection_id, update.clone())?;
2951 }
2952 }
2953
2954 response.send(proto::Ack {})?;
2955 Ok(())
2956}
2957
2958/// Alter the role for a user in the channel.
2959async fn set_channel_member_role(
2960 request: proto::SetChannelMemberRole,
2961 response: Response<proto::SetChannelMemberRole>,
2962 session: Session,
2963) -> Result<()> {
2964 let db = session.db().await;
2965 let channel_id = ChannelId::from_proto(request.channel_id);
2966 let member_id = UserId::from_proto(request.user_id);
2967 let result = db
2968 .set_channel_member_role(
2969 channel_id,
2970 session.user_id(),
2971 member_id,
2972 request.role().into(),
2973 )
2974 .await?;
2975
2976 match result {
2977 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2978 let mut connection_pool = session.connection_pool().await;
2979 notify_membership_updated(
2980 &mut connection_pool,
2981 membership_update,
2982 member_id,
2983 &session.peer,
2984 )
2985 }
2986 db::SetMemberRoleResult::InviteUpdated(channel) => {
2987 let update = proto::UpdateChannels {
2988 channel_invitations: vec![channel.to_proto()],
2989 ..Default::default()
2990 };
2991
2992 for connection_id in session
2993 .connection_pool()
2994 .await
2995 .user_connection_ids(member_id)
2996 {
2997 session.peer.send(connection_id, update.clone())?;
2998 }
2999 }
3000 }
3001
3002 response.send(proto::Ack {})?;
3003 Ok(())
3004}
3005
3006/// Change the name of a channel
3007async fn rename_channel(
3008 request: proto::RenameChannel,
3009 response: Response<proto::RenameChannel>,
3010 session: Session,
3011) -> Result<()> {
3012 let db = session.db().await;
3013 let channel_id = ChannelId::from_proto(request.channel_id);
3014 let channel_model = db
3015 .rename_channel(channel_id, session.user_id(), &request.name)
3016 .await?;
3017 let root_id = channel_model.root_id();
3018 let channel = Channel::from_model(channel_model);
3019
3020 response.send(proto::RenameChannelResponse {
3021 channel: Some(channel.to_proto()),
3022 })?;
3023
3024 let connection_pool = session.connection_pool().await;
3025 let update = proto::UpdateChannels {
3026 channels: vec![channel.to_proto()],
3027 ..Default::default()
3028 };
3029 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3030 if role.can_see_channel(channel.visibility) {
3031 session.peer.send(connection_id, update.clone())?;
3032 }
3033 }
3034
3035 Ok(())
3036}
3037
3038/// Move a channel to a new parent.
3039async fn move_channel(
3040 request: proto::MoveChannel,
3041 response: Response<proto::MoveChannel>,
3042 session: Session,
3043) -> Result<()> {
3044 let channel_id = ChannelId::from_proto(request.channel_id);
3045 let to = ChannelId::from_proto(request.to);
3046
3047 let (root_id, channels) = session
3048 .db()
3049 .await
3050 .move_channel(channel_id, to, session.user_id())
3051 .await?;
3052
3053 let connection_pool = session.connection_pool().await;
3054 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3055 let channels = channels
3056 .iter()
3057 .filter_map(|channel| {
3058 if role.can_see_channel(channel.visibility) {
3059 Some(channel.to_proto())
3060 } else {
3061 None
3062 }
3063 })
3064 .collect::<Vec<_>>();
3065 if channels.is_empty() {
3066 continue;
3067 }
3068
3069 let update = proto::UpdateChannels {
3070 channels,
3071 ..Default::default()
3072 };
3073
3074 session.peer.send(connection_id, update.clone())?;
3075 }
3076
3077 response.send(Ack {})?;
3078 Ok(())
3079}
3080
3081/// Get the list of channel members
3082async fn get_channel_members(
3083 request: proto::GetChannelMembers,
3084 response: Response<proto::GetChannelMembers>,
3085 session: Session,
3086) -> Result<()> {
3087 let db = session.db().await;
3088 let channel_id = ChannelId::from_proto(request.channel_id);
3089 let limit = if request.limit == 0 {
3090 u16::MAX as u64
3091 } else {
3092 request.limit
3093 };
3094 let (members, users) = db
3095 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3096 .await?;
3097 response.send(proto::GetChannelMembersResponse { members, users })?;
3098 Ok(())
3099}
3100
3101/// Accept or decline a channel invitation.
3102async fn respond_to_channel_invite(
3103 request: proto::RespondToChannelInvite,
3104 response: Response<proto::RespondToChannelInvite>,
3105 session: Session,
3106) -> Result<()> {
3107 let db = session.db().await;
3108 let channel_id = ChannelId::from_proto(request.channel_id);
3109 let RespondToChannelInvite {
3110 membership_update,
3111 notifications,
3112 } = db
3113 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3114 .await?;
3115
3116 let mut connection_pool = session.connection_pool().await;
3117 if let Some(membership_update) = membership_update {
3118 notify_membership_updated(
3119 &mut connection_pool,
3120 membership_update,
3121 session.user_id(),
3122 &session.peer,
3123 );
3124 } else {
3125 let update = proto::UpdateChannels {
3126 remove_channel_invitations: vec![channel_id.to_proto()],
3127 ..Default::default()
3128 };
3129
3130 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3131 session.peer.send(connection_id, update.clone())?;
3132 }
3133 };
3134
3135 send_notifications(&connection_pool, &session.peer, notifications);
3136
3137 response.send(proto::Ack {})?;
3138
3139 Ok(())
3140}
3141
3142/// Join the channels' room
3143async fn join_channel(
3144 request: proto::JoinChannel,
3145 response: Response<proto::JoinChannel>,
3146 session: Session,
3147) -> Result<()> {
3148 let channel_id = ChannelId::from_proto(request.channel_id);
3149 join_channel_internal(channel_id, Box::new(response), session).await
3150}
3151
/// A response handle that `join_channel_internal` can reply through.
///
/// It lets the same join logic answer either a `JoinChannel` or a `JoinRoom`
/// request. Both are answered with a `proto::JoinRoomResponse`.
trait JoinChannelInternalResponse {
    /// Sends `result` as the reply to the original request.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
/// Replies to a `JoinChannel` request.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Path-qualified call: inherent methods take precedence, so this
        // presumably resolves to `Response::send` rather than recursing into
        // this trait method.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
/// Replies to a `JoinRoom` request.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Path-qualified call: inherent methods take precedence, so this
        // presumably resolves to `Response::send` rather than recursing into
        // this trait method.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3165
/// Shared implementation for joining a channel's room.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, leave that room first.
/// 2. Join the channel's room in the database. This yields the joined room, an optional
///    membership change, and the caller's channel role.
/// 3. If a LiveKit client is configured, mint a token for the room:
///    - guests get a guest token and `can_publish: false`;
///    - everyone else gets a room token and `can_publish: true`.
///    A token-minting error is logged via `trace_err` and yields no LiveKit info rather
///    than failing the request.
/// 4. Reply to the requester, apply and send any membership update, and broadcast the
///    updated room to its participants.
/// 5. Once the database handle and pool guard from the inner block are dropped, broadcast
///    the channel's participant list to connections allowed to see the channel. Then
///    refresh the user's contacts.
///
/// `response` is generic so the same flow can answer any request type that implements
/// `JoinChannelInternalResponse` (see the impls for `JoinChannel` and `JoinRoom`).
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release our database handle while leaving the stale room, then re-acquire it.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured or when minting the token failed
        // (the `?` inside the closure turns a traced error into `None`).
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    // Guests may join the call but are not allowed to publish.
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        // `db` and `connection_pool` go out of scope here.
        joined_room
    };

    // NOTE(review): the response was already sent above, so a missing channel surfaces
    // here as a handler error after the client was told it joined.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3261
3262/// Start editing the channel notes
3263async fn join_channel_buffer(
3264 request: proto::JoinChannelBuffer,
3265 response: Response<proto::JoinChannelBuffer>,
3266 session: Session,
3267) -> Result<()> {
3268 let db = session.db().await;
3269 let channel_id = ChannelId::from_proto(request.channel_id);
3270
3271 let open_response = db
3272 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3273 .await?;
3274
3275 let collaborators = open_response.collaborators.clone();
3276 response.send(open_response)?;
3277
3278 let update = UpdateChannelBufferCollaborators {
3279 channel_id: channel_id.to_proto(),
3280 collaborators: collaborators.clone(),
3281 };
3282 channel_buffer_updated(
3283 session.connection_id,
3284 collaborators
3285 .iter()
3286 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3287 &update,
3288 &session.peer,
3289 );
3290
3291 Ok(())
3292}
3293
3294/// Edit the channel notes
3295async fn update_channel_buffer(
3296 request: proto::UpdateChannelBuffer,
3297 session: Session,
3298) -> Result<()> {
3299 let db = session.db().await;
3300 let channel_id = ChannelId::from_proto(request.channel_id);
3301
3302 let (collaborators, epoch, version) = db
3303 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3304 .await?;
3305
3306 channel_buffer_updated(
3307 session.connection_id,
3308 collaborators.clone(),
3309 &proto::UpdateChannelBuffer {
3310 channel_id: channel_id.to_proto(),
3311 operations: request.operations,
3312 },
3313 &session.peer,
3314 );
3315
3316 let pool = &*session.connection_pool().await;
3317
3318 let non_collaborators =
3319 pool.channel_connection_ids(channel_id)
3320 .filter_map(|(connection_id, _)| {
3321 if collaborators.contains(&connection_id) {
3322 None
3323 } else {
3324 Some(connection_id)
3325 }
3326 });
3327
3328 broadcast(None, non_collaborators, |peer_id| {
3329 session.peer.send(
3330 peer_id,
3331 proto::UpdateChannels {
3332 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3333 channel_id: channel_id.to_proto(),
3334 epoch: epoch as u64,
3335 version: version.clone(),
3336 }],
3337 ..Default::default()
3338 },
3339 )
3340 });
3341
3342 Ok(())
3343}
3344
3345/// Rejoin the channel notes after a connection blip
3346async fn rejoin_channel_buffers(
3347 request: proto::RejoinChannelBuffers,
3348 response: Response<proto::RejoinChannelBuffers>,
3349 session: Session,
3350) -> Result<()> {
3351 let db = session.db().await;
3352 let buffers = db
3353 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3354 .await?;
3355
3356 for rejoined_buffer in &buffers {
3357 let collaborators_to_notify = rejoined_buffer
3358 .buffer
3359 .collaborators
3360 .iter()
3361 .filter_map(|c| Some(c.peer_id?.into()));
3362 channel_buffer_updated(
3363 session.connection_id,
3364 collaborators_to_notify,
3365 &proto::UpdateChannelBufferCollaborators {
3366 channel_id: rejoined_buffer.buffer.channel_id,
3367 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3368 },
3369 &session.peer,
3370 );
3371 }
3372
3373 response.send(proto::RejoinChannelBuffersResponse {
3374 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3375 })?;
3376
3377 Ok(())
3378}
3379
3380/// Stop editing the channel notes
3381async fn leave_channel_buffer(
3382 request: proto::LeaveChannelBuffer,
3383 response: Response<proto::LeaveChannelBuffer>,
3384 session: Session,
3385) -> Result<()> {
3386 let db = session.db().await;
3387 let channel_id = ChannelId::from_proto(request.channel_id);
3388
3389 let left_buffer = db
3390 .leave_channel_buffer(channel_id, session.connection_id)
3391 .await?;
3392
3393 response.send(Ack {})?;
3394
3395 channel_buffer_updated(
3396 session.connection_id,
3397 left_buffer.connections,
3398 &proto::UpdateChannelBufferCollaborators {
3399 channel_id: channel_id.to_proto(),
3400 collaborators: left_buffer.collaborators,
3401 },
3402 &session.peer,
3403 );
3404
3405 Ok(())
3406}
3407
3408fn channel_buffer_updated<T: EnvelopedMessage>(
3409 sender_id: ConnectionId,
3410 collaborators: impl IntoIterator<Item = ConnectionId>,
3411 message: &T,
3412 peer: &Peer,
3413) {
3414 broadcast(Some(sender_id), collaborators, |peer_id| {
3415 peer.send(peer_id, message.clone())
3416 });
3417}
3418
3419fn send_notifications(
3420 connection_pool: &ConnectionPool,
3421 peer: &Peer,
3422 notifications: db::NotificationBatch,
3423) {
3424 for (user_id, notification) in notifications {
3425 for connection_id in connection_pool.user_connection_ids(user_id) {
3426 if let Err(error) = peer.send(
3427 connection_id,
3428 proto::AddNotification {
3429 notification: Some(notification.clone()),
3430 },
3431 ) {
3432 tracing::error!(
3433 "failed to send notification to {:?} {}",
3434 connection_id,
3435 error
3436 );
3437 }
3438 }
3439 }
3440}
3441
3442/// Send a message to the channel
3443async fn send_channel_message(
3444 request: proto::SendChannelMessage,
3445 response: Response<proto::SendChannelMessage>,
3446 session: Session,
3447) -> Result<()> {
3448 // Validate the message body.
3449 let body = request.body.trim().to_string();
3450 if body.len() > MAX_MESSAGE_LEN {
3451 return Err(anyhow!("message is too long"))?;
3452 }
3453 if body.is_empty() {
3454 return Err(anyhow!("message can't be blank"))?;
3455 }
3456
3457 // TODO: adjust mentions if body is trimmed
3458
3459 let timestamp = OffsetDateTime::now_utc();
3460 let nonce = request
3461 .nonce
3462 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3463
3464 let channel_id = ChannelId::from_proto(request.channel_id);
3465 let CreatedChannelMessage {
3466 message_id,
3467 participant_connection_ids,
3468 notifications,
3469 } = session
3470 .db()
3471 .await
3472 .create_channel_message(
3473 channel_id,
3474 session.user_id(),
3475 &body,
3476 &request.mentions,
3477 timestamp,
3478 nonce.clone().into(),
3479 request.reply_to_message_id.map(MessageId::from_proto),
3480 )
3481 .await?;
3482
3483 let message = proto::ChannelMessage {
3484 sender_id: session.user_id().to_proto(),
3485 id: message_id.to_proto(),
3486 body,
3487 mentions: request.mentions,
3488 timestamp: timestamp.unix_timestamp() as u64,
3489 nonce: Some(nonce),
3490 reply_to_message_id: request.reply_to_message_id,
3491 edited_at: None,
3492 };
3493 broadcast(
3494 Some(session.connection_id),
3495 participant_connection_ids.clone(),
3496 |connection| {
3497 session.peer.send(
3498 connection,
3499 proto::ChannelMessageSent {
3500 channel_id: channel_id.to_proto(),
3501 message: Some(message.clone()),
3502 },
3503 )
3504 },
3505 );
3506 response.send(proto::SendChannelMessageResponse {
3507 message: Some(message),
3508 })?;
3509
3510 let pool = &*session.connection_pool().await;
3511 let non_participants =
3512 pool.channel_connection_ids(channel_id)
3513 .filter_map(|(connection_id, _)| {
3514 if participant_connection_ids.contains(&connection_id) {
3515 None
3516 } else {
3517 Some(connection_id)
3518 }
3519 });
3520 broadcast(None, non_participants, |peer_id| {
3521 session.peer.send(
3522 peer_id,
3523 proto::UpdateChannels {
3524 latest_channel_message_ids: vec![proto::ChannelMessageId {
3525 channel_id: channel_id.to_proto(),
3526 message_id: message_id.to_proto(),
3527 }],
3528 ..Default::default()
3529 },
3530 )
3531 });
3532 send_notifications(pool, &session.peer, notifications);
3533
3534 Ok(())
3535}
3536
3537/// Delete a channel message
3538async fn remove_channel_message(
3539 request: proto::RemoveChannelMessage,
3540 response: Response<proto::RemoveChannelMessage>,
3541 session: Session,
3542) -> Result<()> {
3543 let channel_id = ChannelId::from_proto(request.channel_id);
3544 let message_id = MessageId::from_proto(request.message_id);
3545 let (connection_ids, existing_notification_ids) = session
3546 .db()
3547 .await
3548 .remove_channel_message(channel_id, message_id, session.user_id())
3549 .await?;
3550
3551 broadcast(
3552 Some(session.connection_id),
3553 connection_ids,
3554 move |connection| {
3555 session.peer.send(connection, request.clone())?;
3556
3557 for notification_id in &existing_notification_ids {
3558 session.peer.send(
3559 connection,
3560 proto::DeleteNotification {
3561 notification_id: (*notification_id).to_proto(),
3562 },
3563 )?;
3564 }
3565
3566 Ok(())
3567 },
3568 );
3569 response.send(proto::Ack {})?;
3570 Ok(())
3571}
3572
3573async fn update_channel_message(
3574 request: proto::UpdateChannelMessage,
3575 response: Response<proto::UpdateChannelMessage>,
3576 session: Session,
3577) -> Result<()> {
3578 let channel_id = ChannelId::from_proto(request.channel_id);
3579 let message_id = MessageId::from_proto(request.message_id);
3580 let updated_at = OffsetDateTime::now_utc();
3581 let UpdatedChannelMessage {
3582 message_id,
3583 participant_connection_ids,
3584 notifications,
3585 reply_to_message_id,
3586 timestamp,
3587 deleted_mention_notification_ids,
3588 updated_mention_notifications,
3589 } = session
3590 .db()
3591 .await
3592 .update_channel_message(
3593 channel_id,
3594 message_id,
3595 session.user_id(),
3596 request.body.as_str(),
3597 &request.mentions,
3598 updated_at,
3599 )
3600 .await?;
3601
3602 let nonce = request
3603 .nonce
3604 .clone()
3605 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3606
3607 let message = proto::ChannelMessage {
3608 sender_id: session.user_id().to_proto(),
3609 id: message_id.to_proto(),
3610 body: request.body.clone(),
3611 mentions: request.mentions.clone(),
3612 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3613 nonce: Some(nonce),
3614 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3615 edited_at: Some(updated_at.unix_timestamp() as u64),
3616 };
3617
3618 response.send(proto::Ack {})?;
3619
3620 let pool = &*session.connection_pool().await;
3621 broadcast(
3622 Some(session.connection_id),
3623 participant_connection_ids,
3624 |connection| {
3625 session.peer.send(
3626 connection,
3627 proto::ChannelMessageUpdate {
3628 channel_id: channel_id.to_proto(),
3629 message: Some(message.clone()),
3630 },
3631 )?;
3632
3633 for notification_id in &deleted_mention_notification_ids {
3634 session.peer.send(
3635 connection,
3636 proto::DeleteNotification {
3637 notification_id: (*notification_id).to_proto(),
3638 },
3639 )?;
3640 }
3641
3642 for notification in &updated_mention_notifications {
3643 session.peer.send(
3644 connection,
3645 proto::UpdateNotification {
3646 notification: Some(notification.clone()),
3647 },
3648 )?;
3649 }
3650
3651 Ok(())
3652 },
3653 );
3654
3655 send_notifications(pool, &session.peer, notifications);
3656
3657 Ok(())
3658}
3659
3660/// Mark a channel message as read
3661async fn acknowledge_channel_message(
3662 request: proto::AckChannelMessage,
3663 session: Session,
3664) -> Result<()> {
3665 let channel_id = ChannelId::from_proto(request.channel_id);
3666 let message_id = MessageId::from_proto(request.message_id);
3667 let notifications = session
3668 .db()
3669 .await
3670 .observe_channel_message(channel_id, session.user_id(), message_id)
3671 .await?;
3672 send_notifications(
3673 &*session.connection_pool().await,
3674 &session.peer,
3675 notifications,
3676 );
3677 Ok(())
3678}
3679
3680/// Mark a buffer version as synced
3681async fn acknowledge_buffer_version(
3682 request: proto::AckBufferOperation,
3683 session: Session,
3684) -> Result<()> {
3685 let buffer_id = BufferId::from_proto(request.buffer_id);
3686 session
3687 .db()
3688 .await
3689 .observe_buffer_version(
3690 buffer_id,
3691 session.user_id(),
3692 request.epoch as i32,
3693 &request.version,
3694 )
3695 .await?;
3696 Ok(())
3697}
3698
3699async fn count_language_model_tokens(
3700 request: proto::CountLanguageModelTokens,
3701 response: Response<proto::CountLanguageModelTokens>,
3702 session: Session,
3703 config: &Config,
3704) -> Result<()> {
3705 authorize_access_to_legacy_llm_endpoints(&session).await?;
3706
3707 let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3708 proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
3709 proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
3710 };
3711
3712 session
3713 .app_state
3714 .rate_limiter
3715 .check(&*rate_limit, session.user_id())
3716 .await?;
3717
3718 let result = match proto::LanguageModelProvider::from_i32(request.provider) {
3719 Some(proto::LanguageModelProvider::Google) => {
3720 let api_key = config
3721 .google_ai_api_key
3722 .as_ref()
3723 .context("no Google AI API key configured on the server")?;
3724 google_ai::count_tokens(
3725 session.http_client.as_ref(),
3726 google_ai::API_URL,
3727 api_key,
3728 serde_json::from_str(&request.request)?,
3729 )
3730 .await?
3731 }
3732 _ => return Err(anyhow!("unsupported provider"))?,
3733 };
3734
3735 response.send(proto::CountLanguageModelTokensResponse {
3736 token_count: result.total_tokens as u32,
3737 })?;
3738
3739 Ok(())
3740}
3741
3742struct ZedProCountLanguageModelTokensRateLimit;
3743
3744impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
3745 fn capacity(&self) -> usize {
3746 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
3747 .ok()
3748 .and_then(|v| v.parse().ok())
3749 .unwrap_or(600) // Picked arbitrarily
3750 }
3751
3752 fn refill_duration(&self) -> chrono::Duration {
3753 chrono::Duration::hours(1)
3754 }
3755
3756 fn db_name(&self) -> &'static str {
3757 "zed-pro:count-language-model-tokens"
3758 }
3759}
3760
3761struct FreeCountLanguageModelTokensRateLimit;
3762
3763impl RateLimit for FreeCountLanguageModelTokensRateLimit {
3764 fn capacity(&self) -> usize {
3765 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
3766 .ok()
3767 .and_then(|v| v.parse().ok())
3768 .unwrap_or(600 / 10) // Picked arbitrarily
3769 }
3770
3771 fn refill_duration(&self) -> chrono::Duration {
3772 chrono::Duration::hours(1)
3773 }
3774
3775 fn db_name(&self) -> &'static str {
3776 "free:count-language-model-tokens"
3777 }
3778}
3779
3780struct ZedProComputeEmbeddingsRateLimit;
3781
3782impl RateLimit for ZedProComputeEmbeddingsRateLimit {
3783 fn capacity(&self) -> usize {
3784 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
3785 .ok()
3786 .and_then(|v| v.parse().ok())
3787 .unwrap_or(5000) // Picked arbitrarily
3788 }
3789
3790 fn refill_duration(&self) -> chrono::Duration {
3791 chrono::Duration::hours(1)
3792 }
3793
3794 fn db_name(&self) -> &'static str {
3795 "zed-pro:compute-embeddings"
3796 }
3797}
3798
3799struct FreeComputeEmbeddingsRateLimit;
3800
3801impl RateLimit for FreeComputeEmbeddingsRateLimit {
3802 fn capacity(&self) -> usize {
3803 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
3804 .ok()
3805 .and_then(|v| v.parse().ok())
3806 .unwrap_or(5000 / 10) // Picked arbitrarily
3807 }
3808
3809 fn refill_duration(&self) -> chrono::Duration {
3810 chrono::Duration::hours(1)
3811 }
3812
3813 fn db_name(&self) -> &'static str {
3814 "free:compute-embeddings"
3815 }
3816}
3817
3818async fn compute_embeddings(
3819 request: proto::ComputeEmbeddings,
3820 response: Response<proto::ComputeEmbeddings>,
3821 session: Session,
3822 api_key: Option<Arc<str>>,
3823) -> Result<()> {
3824 let api_key = api_key.context("no OpenAI API key configured on the server")?;
3825 authorize_access_to_legacy_llm_endpoints(&session).await?;
3826
3827 let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3828 proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
3829 proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
3830 };
3831
3832 session
3833 .app_state
3834 .rate_limiter
3835 .check(&*rate_limit, session.user_id())
3836 .await?;
3837
3838 let embeddings = match request.model.as_str() {
3839 "openai/text-embedding-3-small" => {
3840 open_ai::embed(
3841 session.http_client.as_ref(),
3842 OPEN_AI_API_URL,
3843 &api_key,
3844 OpenAiEmbeddingModel::TextEmbedding3Small,
3845 request.texts.iter().map(|text| text.as_str()),
3846 )
3847 .await?
3848 }
3849 provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
3850 };
3851
3852 let embeddings = request
3853 .texts
3854 .iter()
3855 .map(|text| {
3856 let mut hasher = sha2::Sha256::new();
3857 hasher.update(text.as_bytes());
3858 let result = hasher.finalize();
3859 result.to_vec()
3860 })
3861 .zip(
3862 embeddings
3863 .data
3864 .into_iter()
3865 .map(|embedding| embedding.embedding),
3866 )
3867 .collect::<HashMap<_, _>>();
3868
3869 let db = session.db().await;
3870 db.save_embeddings(&request.model, &embeddings)
3871 .await
3872 .context("failed to save embeddings")
3873 .trace_err();
3874
3875 response.send(proto::ComputeEmbeddingsResponse {
3876 embeddings: embeddings
3877 .into_iter()
3878 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3879 .collect(),
3880 })?;
3881 Ok(())
3882}
3883
3884async fn get_cached_embeddings(
3885 request: proto::GetCachedEmbeddings,
3886 response: Response<proto::GetCachedEmbeddings>,
3887 session: Session,
3888) -> Result<()> {
3889 authorize_access_to_legacy_llm_endpoints(&session).await?;
3890
3891 let db = session.db().await;
3892 let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
3893
3894 response.send(proto::GetCachedEmbeddingsResponse {
3895 embeddings: embeddings
3896 .into_iter()
3897 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3898 .collect(),
3899 })?;
3900 Ok(())
3901}
3902
3903/// This is leftover from before the LLM service.
3904///
3905/// The endpoints protected by this check will be moved there eventually.
3906async fn authorize_access_to_legacy_llm_endpoints(session: &Session) -> Result<(), Error> {
3907 if session.is_staff() {
3908 Ok(())
3909 } else {
3910 Err(anyhow!("permission denied"))?
3911 }
3912}
3913
3914/// Get a Supermaven API key for the user
3915async fn get_supermaven_api_key(
3916 _request: proto::GetSupermavenApiKey,
3917 response: Response<proto::GetSupermavenApiKey>,
3918 session: Session,
3919) -> Result<()> {
3920 let user_id: String = session.user_id().to_string();
3921 if !session.is_staff() {
3922 return Err(anyhow!("supermaven not enabled for this account"))?;
3923 }
3924
3925 let email = session
3926 .email()
3927 .ok_or_else(|| anyhow!("user must have an email"))?;
3928
3929 let supermaven_admin_api = session
3930 .supermaven_client
3931 .as_ref()
3932 .ok_or_else(|| anyhow!("supermaven not configured"))?;
3933
3934 let result = supermaven_admin_api
3935 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3936 .await?;
3937
3938 response.send(proto::GetSupermavenApiKeyResponse {
3939 api_key: result.api_key,
3940 })?;
3941
3942 Ok(())
3943}
3944
3945/// Start receiving chat updates for a channel
3946async fn join_channel_chat(
3947 request: proto::JoinChannelChat,
3948 response: Response<proto::JoinChannelChat>,
3949 session: Session,
3950) -> Result<()> {
3951 let channel_id = ChannelId::from_proto(request.channel_id);
3952
3953 let db = session.db().await;
3954 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3955 .await?;
3956 let messages = db
3957 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3958 .await?;
3959 response.send(proto::JoinChannelChatResponse {
3960 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3961 messages,
3962 })?;
3963 Ok(())
3964}
3965
3966/// Stop receiving chat updates for a channel
3967async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3968 let channel_id = ChannelId::from_proto(request.channel_id);
3969 session
3970 .db()
3971 .await
3972 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3973 .await?;
3974 Ok(())
3975}
3976
3977/// Retrieve the chat history for a channel
3978async fn get_channel_messages(
3979 request: proto::GetChannelMessages,
3980 response: Response<proto::GetChannelMessages>,
3981 session: Session,
3982) -> Result<()> {
3983 let channel_id = ChannelId::from_proto(request.channel_id);
3984 let messages = session
3985 .db()
3986 .await
3987 .get_channel_messages(
3988 channel_id,
3989 session.user_id(),
3990 MESSAGE_COUNT_PER_PAGE,
3991 Some(MessageId::from_proto(request.before_message_id)),
3992 )
3993 .await?;
3994 response.send(proto::GetChannelMessagesResponse {
3995 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3996 messages,
3997 })?;
3998 Ok(())
3999}
4000
4001/// Retrieve specific chat messages
4002async fn get_channel_messages_by_id(
4003 request: proto::GetChannelMessagesById,
4004 response: Response<proto::GetChannelMessagesById>,
4005 session: Session,
4006) -> Result<()> {
4007 let message_ids = request
4008 .message_ids
4009 .iter()
4010 .map(|id| MessageId::from_proto(*id))
4011 .collect::<Vec<_>>();
4012 let messages = session
4013 .db()
4014 .await
4015 .get_channel_messages_by_id(session.user_id(), &message_ids)
4016 .await?;
4017 response.send(proto::GetChannelMessagesResponse {
4018 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4019 messages,
4020 })?;
4021 Ok(())
4022}
4023
4024/// Retrieve the current users notifications
4025async fn get_notifications(
4026 request: proto::GetNotifications,
4027 response: Response<proto::GetNotifications>,
4028 session: Session,
4029) -> Result<()> {
4030 let notifications = session
4031 .db()
4032 .await
4033 .get_notifications(
4034 session.user_id(),
4035 NOTIFICATION_COUNT_PER_PAGE,
4036 request.before_id.map(db::NotificationId::from_proto),
4037 )
4038 .await?;
4039 response.send(proto::GetNotificationsResponse {
4040 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4041 notifications,
4042 })?;
4043 Ok(())
4044}
4045
4046/// Mark notifications as read
4047async fn mark_notification_as_read(
4048 request: proto::MarkNotificationRead,
4049 response: Response<proto::MarkNotificationRead>,
4050 session: Session,
4051) -> Result<()> {
4052 let database = &session.db().await;
4053 let notifications = database
4054 .mark_notification_as_read_by_id(
4055 session.user_id(),
4056 NotificationId::from_proto(request.notification_id),
4057 )
4058 .await?;
4059 send_notifications(
4060 &*session.connection_pool().await,
4061 &session.peer,
4062 notifications,
4063 );
4064 response.send(proto::Ack {})?;
4065 Ok(())
4066}
4067
4068/// Get the current users information
4069async fn get_private_user_info(
4070 _request: proto::GetPrivateUserInfo,
4071 response: Response<proto::GetPrivateUserInfo>,
4072 session: Session,
4073) -> Result<()> {
4074 let db = session.db().await;
4075
4076 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4077 let user = db
4078 .get_user_by_id(session.user_id())
4079 .await?
4080 .ok_or_else(|| anyhow!("user not found"))?;
4081 let flags = db.get_user_flags(session.user_id()).await?;
4082
4083 response.send(proto::GetPrivateUserInfoResponse {
4084 metrics_id,
4085 staff: user.admin,
4086 flags,
4087 accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4088 })?;
4089 Ok(())
4090}
4091
4092/// Accept the terms of service (tos) on behalf of the current user
4093async fn accept_terms_of_service(
4094 _request: proto::AcceptTermsOfService,
4095 response: Response<proto::AcceptTermsOfService>,
4096 session: Session,
4097) -> Result<()> {
4098 let db = session.db().await;
4099
4100 let accepted_tos_at = Utc::now();
4101 db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4102 .await?;
4103
4104 response.send(proto::AcceptTermsOfServiceResponse {
4105 accepted_tos_at: accepted_tos_at.timestamp() as u64,
4106 })?;
4107 Ok(())
4108}
4109
/// The minimum age (30 days) an account must have in order to use the LLM service.
pub const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4112
4113async fn get_llm_api_token(
4114 _request: proto::GetLlmToken,
4115 response: Response<proto::GetLlmToken>,
4116 session: Session,
4117) -> Result<()> {
4118 let db = session.db().await;
4119
4120 let flags = db.get_user_flags(session.user_id()).await?;
4121 let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
4122
4123 if !session.is_staff() && !has_language_models_feature_flag {
4124 Err(anyhow!("permission denied"))?
4125 }
4126
4127 let user_id = session.user_id();
4128 let user = db
4129 .get_user_by_id(user_id)
4130 .await?
4131 .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4132
4133 if user.accepted_tos_at.is_none() {
4134 Err(anyhow!("terms of service not accepted"))?
4135 }
4136
4137 let has_llm_subscription = session.has_llm_subscription(&db).await?;
4138 let billing_preferences = db.get_billing_preferences(user.id).await?;
4139
4140 let token = LlmTokenClaims::create(
4141 &user,
4142 session.is_staff(),
4143 billing_preferences,
4144 &flags,
4145 has_llm_subscription,
4146 session.current_plan(&db).await?,
4147 session.system_id.clone(),
4148 &session.app_state.config,
4149 )?;
4150 response.send(proto::GetLlmTokenResponse { token })?;
4151 Ok(())
4152}
4153
4154fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4155 let message = match message {
4156 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4157 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4158 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4159 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4160 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4161 code: frame.code.into(),
4162 reason: frame.reason,
4163 })),
4164 // We should never receive a frame while reading the message, according
4165 // to the `tungstenite` maintainers:
4166 //
4167 // > It cannot occur when you read messages from the WebSocket, but it
4168 // > can be used when you want to send the raw frames (e.g. you want to
4169 // > send the frames to the WebSocket without composing the full message first).
4170 // >
4171 // > — https://github.com/snapview/tungstenite-rs/issues/268
4172 TungsteniteMessage::Frame(_) => {
4173 bail!("received an unexpected frame while reading the message")
4174 }
4175 };
4176
4177 Ok(message)
4178}
4179
4180fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4181 match message {
4182 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4183 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4184 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4185 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4186 AxumMessage::Close(frame) => {
4187 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4188 code: frame.code.into(),
4189 reason: frame.reason,
4190 }))
4191 }
4192 }
4193}
4194
4195fn notify_membership_updated(
4196 connection_pool: &mut ConnectionPool,
4197 result: MembershipUpdated,
4198 user_id: UserId,
4199 peer: &Peer,
4200) {
4201 for membership in &result.new_channels.channel_memberships {
4202 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4203 }
4204 for channel_id in &result.removed_channels {
4205 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4206 }
4207
4208 let user_channels_update = proto::UpdateUserChannels {
4209 channel_memberships: result
4210 .new_channels
4211 .channel_memberships
4212 .iter()
4213 .map(|cm| proto::ChannelMembership {
4214 channel_id: cm.channel_id.to_proto(),
4215 role: cm.role.into(),
4216 })
4217 .collect(),
4218 ..Default::default()
4219 };
4220
4221 let mut update = build_channels_update(result.new_channels);
4222 update.delete_channels = result
4223 .removed_channels
4224 .into_iter()
4225 .map(|id| id.to_proto())
4226 .collect();
4227 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4228
4229 for connection_id in connection_pool.user_connection_ids(user_id) {
4230 peer.send(connection_id, user_channels_update.clone())
4231 .trace_err();
4232 peer.send(connection_id, update.clone()).trace_err();
4233 }
4234}
4235
4236fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4237 proto::UpdateUserChannels {
4238 channel_memberships: channels
4239 .channel_memberships
4240 .iter()
4241 .map(|m| proto::ChannelMembership {
4242 channel_id: m.channel_id.to_proto(),
4243 role: m.role.into(),
4244 })
4245 .collect(),
4246 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4247 observed_channel_message_id: channels.observed_channel_messages.clone(),
4248 }
4249}
4250
4251fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4252 let mut update = proto::UpdateChannels::default();
4253
4254 for channel in channels.channels {
4255 update.channels.push(channel.to_proto());
4256 }
4257
4258 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4259 update.latest_channel_message_ids = channels.latest_channel_messages;
4260
4261 for (channel_id, participants) in channels.channel_participants {
4262 update
4263 .channel_participants
4264 .push(proto::ChannelParticipants {
4265 channel_id: channel_id.to_proto(),
4266 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4267 });
4268 }
4269
4270 for channel in channels.invited_channels {
4271 update.channel_invitations.push(channel.to_proto());
4272 }
4273
4274 update
4275}
4276
4277fn build_initial_contacts_update(
4278 contacts: Vec<db::Contact>,
4279 pool: &ConnectionPool,
4280) -> proto::UpdateContacts {
4281 let mut update = proto::UpdateContacts::default();
4282
4283 for contact in contacts {
4284 match contact {
4285 db::Contact::Accepted { user_id, busy } => {
4286 update.contacts.push(contact_for_user(user_id, busy, pool));
4287 }
4288 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4289 db::Contact::Incoming { user_id } => {
4290 update
4291 .incoming_requests
4292 .push(proto::IncomingContactRequest {
4293 requester_id: user_id.to_proto(),
4294 })
4295 }
4296 }
4297 }
4298
4299 update
4300}
4301
4302fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4303 proto::Contact {
4304 user_id: user_id.to_proto(),
4305 online: pool.is_user_online(user_id),
4306 busy,
4307 }
4308}
4309
4310fn room_updated(room: &proto::Room, peer: &Peer) {
4311 broadcast(
4312 None,
4313 room.participants
4314 .iter()
4315 .filter_map(|participant| Some(participant.peer_id?.into())),
4316 |peer_id| {
4317 peer.send(
4318 peer_id,
4319 proto::RoomUpdated {
4320 room: Some(room.clone()),
4321 },
4322 )
4323 },
4324 );
4325}
4326
4327fn channel_updated(
4328 channel: &db::channel::Model,
4329 room: &proto::Room,
4330 peer: &Peer,
4331 pool: &ConnectionPool,
4332) {
4333 let participants = room
4334 .participants
4335 .iter()
4336 .map(|p| p.user_id)
4337 .collect::<Vec<_>>();
4338
4339 broadcast(
4340 None,
4341 pool.channel_connection_ids(channel.root_id())
4342 .filter_map(|(channel_id, role)| {
4343 role.can_see_channel(channel.visibility)
4344 .then_some(channel_id)
4345 }),
4346 |peer_id| {
4347 peer.send(
4348 peer_id,
4349 proto::UpdateChannels {
4350 channel_participants: vec![proto::ChannelParticipants {
4351 channel_id: channel.id.to_proto(),
4352 participant_user_ids: participants.clone(),
4353 }],
4354 ..Default::default()
4355 },
4356 )
4357 },
4358 );
4359}
4360
4361async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4362 let db = session.db().await;
4363
4364 let contacts = db.get_contacts(user_id).await?;
4365 let busy = db.is_user_busy(user_id).await?;
4366
4367 let pool = session.connection_pool().await;
4368 let updated_contact = contact_for_user(user_id, busy, &pool);
4369 for contact in contacts {
4370 if let db::Contact::Accepted {
4371 user_id: contact_user_id,
4372 ..
4373 } = contact
4374 {
4375 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4376 session
4377 .peer
4378 .send(
4379 contact_conn_id,
4380 proto::UpdateContacts {
4381 contacts: vec![updated_contact.clone()],
4382 remove_contacts: Default::default(),
4383 incoming_requests: Default::default(),
4384 remove_incoming_requests: Default::default(),
4385 outgoing_requests: Default::default(),
4386 remove_outgoing_requests: Default::default(),
4387 },
4388 )
4389 .trace_err();
4390 }
4391 }
4392 }
4393 Ok(())
4394}
4395
4396async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
4397 let mut contacts_to_update = HashSet::default();
4398
4399 let room_id;
4400 let canceled_calls_to_user_ids;
4401 let livekit_room;
4402 let delete_livekit_room;
4403 let room;
4404 let channel;
4405
4406 if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
4407 contacts_to_update.insert(session.user_id());
4408
4409 for project in left_room.left_projects.values() {
4410 project_left(project, session);
4411 }
4412
4413 room_id = RoomId::from_proto(left_room.room.id);
4414 canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
4415 livekit_room = mem::take(&mut left_room.room.livekit_room);
4416 delete_livekit_room = left_room.deleted;
4417 room = mem::take(&mut left_room.room);
4418 channel = mem::take(&mut left_room.channel);
4419
4420 room_updated(&room, &session.peer);
4421 } else {
4422 return Ok(());
4423 }
4424
4425 if let Some(channel) = channel {
4426 channel_updated(
4427 &channel,
4428 &room,
4429 &session.peer,
4430 &*session.connection_pool().await,
4431 );
4432 }
4433
4434 {
4435 let pool = session.connection_pool().await;
4436 for canceled_user_id in canceled_calls_to_user_ids {
4437 for connection_id in pool.user_connection_ids(canceled_user_id) {
4438 session
4439 .peer
4440 .send(
4441 connection_id,
4442 proto::CallCanceled {
4443 room_id: room_id.to_proto(),
4444 },
4445 )
4446 .trace_err();
4447 }
4448 contacts_to_update.insert(canceled_user_id);
4449 }
4450 }
4451
4452 for contact_user_id in contacts_to_update {
4453 update_user_contacts(contact_user_id, session).await?;
4454 }
4455
4456 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
4457 live_kit
4458 .remove_participant(livekit_room.clone(), session.user_id().to_string())
4459 .await
4460 .trace_err();
4461
4462 if delete_livekit_room {
4463 live_kit.delete_room(livekit_room).await.trace_err();
4464 }
4465 }
4466
4467 Ok(())
4468}
4469
4470async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4471 let left_channel_buffers = session
4472 .db()
4473 .await
4474 .leave_channel_buffers(session.connection_id)
4475 .await?;
4476
4477 for left_buffer in left_channel_buffers {
4478 channel_buffer_updated(
4479 session.connection_id,
4480 left_buffer.connections,
4481 &proto::UpdateChannelBufferCollaborators {
4482 channel_id: left_buffer.channel_id.to_proto(),
4483 collaborators: left_buffer.collaborators,
4484 },
4485 &session.peer,
4486 );
4487 }
4488
4489 Ok(())
4490}
4491
4492fn project_left(project: &db::LeftProject, session: &Session) {
4493 for connection_id in &project.connection_ids {
4494 if project.should_unshare {
4495 session
4496 .peer
4497 .send(
4498 *connection_id,
4499 proto::UnshareProject {
4500 project_id: project.id.to_proto(),
4501 },
4502 )
4503 .trace_err();
4504 } else {
4505 session
4506 .peer
4507 .send(
4508 *connection_id,
4509 proto::RemoveProjectCollaborator {
4510 project_id: project.id.to_proto(),
4511 peer_id: Some(session.connection_id.into()),
4512 },
4513 )
4514 .trace_err();
4515 }
4516 }
4517}
4518
4519pub trait ResultExt {
4520 type Ok;
4521
4522 fn trace_err(self) -> Option<Self::Ok>;
4523}
4524
4525impl<T, E> ResultExt for Result<T, E>
4526where
4527 E: std::fmt::Debug,
4528{
4529 type Ok = T;
4530
4531 #[track_caller]
4532 fn trace_err(self) -> Option<T> {
4533 match self {
4534 Ok(value) => Some(value),
4535 Err(error) => {
4536 tracing::error!("{:?}", error);
4537 None
4538 }
4539 }
4540 }
4541}