1mod connection_pool;
2
3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
4use crate::llm::LlmTokenClaims;
5use crate::{
6 AppState, Config, Error, RateLimit, Result, auth,
7 db::{
8 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
9 CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
10 NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
11 RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
12 },
13 executor::Executor,
14};
15use anyhow::{Context as _, anyhow, bail};
16use async_tungstenite::tungstenite::{
17 Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
18};
19use axum::{
20 Extension, Router, TypedHeader,
21 body::Body,
22 extract::{
23 ConnectInfo, WebSocketUpgrade,
24 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
25 },
26 headers::{Header, HeaderName},
27 http::StatusCode,
28 middleware,
29 response::IntoResponse,
30 routing::get,
31};
32use chrono::Utc;
33use collections::{HashMap, HashSet};
34pub use connection_pool::{ConnectionPool, ZedVersion};
35use core::fmt::{self, Debug, Formatter};
36use http_client::HttpClient;
37use open_ai::{OPEN_AI_API_URL, OpenAiEmbeddingModel};
38use reqwest_client::ReqwestClient;
39use rpc::proto::split_repository_update;
40use sha2::Digest;
41use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
42
43use futures::{
44 FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
45 stream::FuturesUnordered,
46};
47use prometheus::{IntGauge, register_int_gauge};
48use rpc::{
49 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
50 proto::{
51 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
52 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
53 },
54};
55use semantic_version::SemanticVersion;
56use serde::{Serialize, Serializer};
57use std::{
58 any::TypeId,
59 future::Future,
60 marker::PhantomData,
61 mem,
62 net::SocketAddr,
63 ops::{Deref, DerefMut},
64 rc::Rc,
65 sync::{
66 Arc, OnceLock,
67 atomic::{AtomicBool, Ordering::SeqCst},
68 },
69 time::{Duration, Instant},
70};
71use time::OffsetDateTime;
72use tokio::sync::{MutexGuard, Semaphore, watch};
73use tower::ServiceBuilder;
74use tracing::{
75 Instrument,
76 field::{self},
77 info_span, instrument,
78};
79
80pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
81
// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can clean up old resources.
83pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
84
85const MESSAGE_COUNT_PER_PAGE: usize = 100;
86const MAX_MESSAGE_LEN: usize = 1024;
87const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
88
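/// A boxed handler that takes an incoming envelope and the sender's session and
/// returns a future that completes once the message has been handled.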
89type MessageHandler =
90 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
91
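/// Handle for responding to an in-flight request. The `responded` flag lets the
/// dispatcher detect handlers that return without sending a reply.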
92struct Response<R> {
93 peer: Arc<Peer>,
94 receipt: Receipt<R>,
95 responded: Arc<AtomicBool>,
96}
97
98impl<R: RequestMessage> Response<R> {
99 fn send(self, payload: R::Response) -> Result<()> {
100 self.responded.store(true, SeqCst);
101 self.peer.respond(self.receipt, payload)?;
102 Ok(())
103 }
104}
105
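/// The authenticated identity behind a connection: either a regular user, or a
/// user being impersonated by an admin.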
106#[derive(Clone, Debug)]
107pub enum Principal {
108 User(User),
109 Impersonated { user: User, admin: User },
110}
111
112impl Principal {
113 fn update_span(&self, span: &tracing::Span) {
114 match &self {
115 Principal::User(user) => {
116 span.record("user_id", user.id.0);
117 span.record("login", &user.github_login);
118 }
119 Principal::Impersonated { user, admin } => {
120 span.record("user_id", user.id.0);
121 span.record("login", &user.github_login);
122 span.record("impersonator", &admin.github_login);
123 }
124 }
125 }
126}
127
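/// Per-connection state passed to every message handler: the authenticated
/// principal, handles to the database and connection pool, and the peer used to
/// send messages to other connections.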
128#[derive(Clone)]
129struct Session {
130 principal: Principal,
131 connection_id: ConnectionId,
132 db: Arc<tokio::sync::Mutex<DbHandle>>,
133 peer: Arc<Peer>,
134 connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
135 app_state: Arc<AppState>,
136 supermaven_client: Option<Arc<SupermavenAdminApi>>,
137 http_client: Arc<dyn HttpClient>,
138 /// The GeoIP country code for the user.
139 #[allow(unused)]
140 geoip_country_code: Option<String>,
141 system_id: Option<String>,
142 _executor: Executor,
143}
144
145impl Session {
146 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
147 #[cfg(test)]
148 tokio::task::yield_now().await;
149 let guard = self.db.lock().await;
150 #[cfg(test)]
151 tokio::task::yield_now().await;
152 guard
153 }
154
155 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
156 #[cfg(test)]
157 tokio::task::yield_now().await;
158 let guard = self.connection_pool.lock();
159 ConnectionPoolGuard {
160 guard,
161 _not_send: PhantomData,
162 }
163 }
164
165 fn is_staff(&self) -> bool {
166 match &self.principal {
167 Principal::User(user) => user.admin,
168 Principal::Impersonated { .. } => true,
169 }
170 }
171
172 pub async fn has_llm_subscription(
173 &self,
174 db: &MutexGuard<'_, DbHandle>,
175 ) -> anyhow::Result<bool> {
176 if self.is_staff() {
177 return Ok(true);
178 }
179
180 let user_id = self.user_id();
181
182 Ok(db.has_active_billing_subscription(user_id).await?)
183 }
184
185 pub async fn current_plan(
186 &self,
187 _db: &MutexGuard<'_, DbHandle>,
188 ) -> anyhow::Result<proto::Plan> {
189 if self.is_staff() {
190 Ok(proto::Plan::ZedPro)
191 } else {
192 Ok(proto::Plan::Free)
193 }
194 }
195
196 fn user_id(&self) -> UserId {
197 match &self.principal {
198 Principal::User(user) => user.id,
199 Principal::Impersonated { user, .. } => user.id,
200 }
201 }
202
203 pub fn email(&self) -> Option<String> {
204 match &self.principal {
205 Principal::User(user) => user.email_address.clone(),
206 Principal::Impersonated { user, .. } => user.email_address.clone(),
207 }
208 }
209}
210
211impl Debug for Session {
212 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
213 let mut result = f.debug_struct("Session");
214 match &self.principal {
215 Principal::User(user) => {
216 result.field("user", &user.github_login);
217 }
218 Principal::Impersonated { user, admin } => {
219 result.field("user", &user.github_login);
220 result.field("impersonator", &admin.github_login);
221 }
222 }
223 result.field("connection_id", &self.connection_id).finish()
224 }
225}
226
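/// Wrapper around the shared database handle; derefs to the underlying `Database`.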
227struct DbHandle(Arc<Database>);
228
229impl Deref for DbHandle {
230 type Target = Database;
231
232 fn deref(&self) -> &Self::Target {
233 self.0.as_ref()
234 }
235}
236
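/// The collab RPC server: owns the peer, the connection pool, and the table of
/// message handlers registered in `Server::new`.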
237pub struct Server {
238 id: parking_lot::Mutex<ServerId>,
239 peer: Arc<Peer>,
240 pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
241 app_state: Arc<AppState>,
242 handlers: HashMap<TypeId, MessageHandler>,
243 teardown: watch::Sender<bool>,
244}
245
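/// Guard over the locked connection pool. The `PhantomData<Rc<()>>` makes it
/// `!Send`, so a future holding it across an `.await` will not be `Send`.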
246pub(crate) struct ConnectionPoolGuard<'a> {
247 guard: parking_lot::MutexGuard<'a, ConnectionPool>,
248 _not_send: PhantomData<Rc<()>>,
249}
250
251#[derive(Serialize)]
252pub struct ServerSnapshot<'a> {
253 peer: &'a Peer,
254 #[serde(serialize_with = "serialize_deref")]
255 connection_pool: ConnectionPoolGuard<'a>,
256}
257
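/// Serializes a value through its `Deref` target; used to serialize the
/// connection pool from behind its guard in `ServerSnapshot`.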
258pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
259where
260 S: Serializer,
261 T: Deref<Target = U>,
262 U: Serialize,
263{
264 Serialize::serialize(value.deref(), serializer)
265}
266
267impl Server {
268 pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
269 let mut server = Self {
270 id: parking_lot::Mutex::new(id),
271 peer: Peer::new(id.0 as u32),
272 app_state: app_state.clone(),
273 connection_pool: Default::default(),
274 handlers: Default::default(),
275 teardown: watch::channel(false).0,
276 };
277
278 server
279 .add_request_handler(ping)
280 .add_request_handler(create_room)
281 .add_request_handler(join_room)
282 .add_request_handler(rejoin_room)
283 .add_request_handler(leave_room)
284 .add_request_handler(set_room_participant_role)
285 .add_request_handler(call)
286 .add_request_handler(cancel_call)
287 .add_message_handler(decline_call)
288 .add_request_handler(update_participant_location)
289 .add_request_handler(share_project)
290 .add_message_handler(unshare_project)
291 .add_request_handler(join_project)
292 .add_message_handler(leave_project)
293 .add_request_handler(update_project)
294 .add_request_handler(update_worktree)
295 .add_request_handler(update_repository)
296 .add_request_handler(remove_repository)
297 .add_message_handler(start_language_server)
298 .add_message_handler(update_language_server)
299 .add_message_handler(update_diagnostic_summary)
300 .add_message_handler(update_worktree_settings)
301 .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
302 .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
303 .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
304 .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
305 .add_request_handler(forward_find_search_candidates_request)
306 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
307 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
308 .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
309 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
310 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
311 .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
312 .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
313 .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
314 .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
315 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
316 .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
317 .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
318 .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
319 .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
320 .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
321 .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
322 .add_request_handler(
323 forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
324 )
325 .add_request_handler(
326 forward_read_only_project_request::<proto::LanguageServerIdForName>,
327 )
328 .add_request_handler(
329 forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
330 )
331 .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
332 .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
333 .add_request_handler(
334 forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
335 )
336 .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
337 .add_request_handler(
338 forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
339 )
340 .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
341 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
342 .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
343 .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
344 .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
345 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
346 .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
347 .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
348 .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
349 .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
350 .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
351 .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
352 .add_request_handler(
353 forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
354 )
355 .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
356 .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
357 .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
358 .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
359 .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
360 .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
361 .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
362 .add_message_handler(create_buffer_for_peer)
363 .add_request_handler(update_buffer)
364 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
365 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
366 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
367 .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
368 .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
369 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
370 .add_request_handler(get_users)
371 .add_request_handler(fuzzy_search_users)
372 .add_request_handler(request_contact)
373 .add_request_handler(remove_contact)
374 .add_request_handler(respond_to_contact_request)
375 .add_message_handler(subscribe_to_channels)
376 .add_request_handler(create_channel)
377 .add_request_handler(delete_channel)
378 .add_request_handler(invite_channel_member)
379 .add_request_handler(remove_channel_member)
380 .add_request_handler(set_channel_member_role)
381 .add_request_handler(set_channel_visibility)
382 .add_request_handler(rename_channel)
383 .add_request_handler(join_channel_buffer)
384 .add_request_handler(leave_channel_buffer)
385 .add_message_handler(update_channel_buffer)
386 .add_request_handler(rejoin_channel_buffers)
387 .add_request_handler(get_channel_members)
388 .add_request_handler(respond_to_channel_invite)
389 .add_request_handler(join_channel)
390 .add_request_handler(join_channel_chat)
391 .add_message_handler(leave_channel_chat)
392 .add_request_handler(send_channel_message)
393 .add_request_handler(remove_channel_message)
394 .add_request_handler(update_channel_message)
395 .add_request_handler(get_channel_messages)
396 .add_request_handler(get_channel_messages_by_id)
397 .add_request_handler(get_notifications)
398 .add_request_handler(mark_notification_as_read)
399 .add_request_handler(move_channel)
400 .add_request_handler(follow)
401 .add_message_handler(unfollow)
402 .add_message_handler(update_followers)
403 .add_request_handler(get_private_user_info)
404 .add_request_handler(get_llm_api_token)
405 .add_request_handler(accept_terms_of_service)
406 .add_message_handler(acknowledge_channel_message)
407 .add_message_handler(acknowledge_buffer_version)
408 .add_request_handler(get_supermaven_api_key)
409 .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
410 .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
411 .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
412 .add_request_handler(forward_mutating_project_request::<proto::Stage>)
413 .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
414 .add_request_handler(forward_mutating_project_request::<proto::Commit>)
415 .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
416 .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
417 .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
418 .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
419 .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
420 .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
421 .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
422 .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
423 .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
424 .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
425 .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
426 .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
427 .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
428 .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
429 .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
430 .add_message_handler(update_context)
431 .add_request_handler({
432 let app_state = app_state.clone();
433 move |request, response, session| {
434 let app_state = app_state.clone();
435 async move {
436 count_language_model_tokens(request, response, session, &app_state.config)
437 .await
438 }
439 }
440 })
441 .add_request_handler(get_cached_embeddings)
442 .add_request_handler({
443 let app_state = app_state.clone();
444 move |request, response, session| {
445 compute_embeddings(
446 request,
447 response,
448 session,
449 app_state.config.openai_api_key.clone(),
450 )
451 }
452 });
453
454 Arc::new(server)
455 }
456
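    /// Starts background cleanup: once `CLEANUP_TIMEOUT` elapses, stale rooms and
    /// channel buffers left behind by previous servers in this environment are
    /// cleared, affected clients are notified, and stale server records are deleted.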
457 pub async fn start(&self) -> Result<()> {
458 let server_id = *self.id.lock();
459 let app_state = self.app_state.clone();
460 let peer = self.peer.clone();
461 let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
462 let pool = self.connection_pool.clone();
463 let livekit_client = self.app_state.livekit_client.clone();
464
465 let span = info_span!("start server");
466 self.app_state.executor.spawn_detached(
467 async move {
468 tracing::info!("waiting for cleanup timeout");
469 timeout.await;
470 tracing::info!("cleanup timeout expired, retrieving stale rooms");
471 if let Some((room_ids, channel_ids)) = app_state
472 .db
473 .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
474 .await
475 .trace_err()
476 {
477 tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
478 tracing::info!(
479 stale_channel_buffer_count = channel_ids.len(),
480 "retrieved stale channel buffers"
481 );
482
483 for channel_id in channel_ids {
484 if let Some(refreshed_channel_buffer) = app_state
485 .db
486 .clear_stale_channel_buffer_collaborators(channel_id, server_id)
487 .await
488 .trace_err()
489 {
490 for connection_id in refreshed_channel_buffer.connection_ids {
491 peer.send(
492 connection_id,
493 proto::UpdateChannelBufferCollaborators {
494 channel_id: channel_id.to_proto(),
495 collaborators: refreshed_channel_buffer
496 .collaborators
497 .clone(),
498 },
499 )
500 .trace_err();
501 }
502 }
503 }
504
505 for room_id in room_ids {
506 let mut contacts_to_update = HashSet::default();
507 let mut canceled_calls_to_user_ids = Vec::new();
508 let mut livekit_room = String::new();
509 let mut delete_livekit_room = false;
510
511 if let Some(mut refreshed_room) = app_state
512 .db
513 .clear_stale_room_participants(room_id, server_id)
514 .await
515 .trace_err()
516 {
517 tracing::info!(
518 room_id = room_id.0,
519 new_participant_count = refreshed_room.room.participants.len(),
520 "refreshed room"
521 );
522 room_updated(&refreshed_room.room, &peer);
523 if let Some(channel) = refreshed_room.channel.as_ref() {
524 channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
525 }
526 contacts_to_update
527 .extend(refreshed_room.stale_participant_user_ids.iter().copied());
528 contacts_to_update
529 .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
530 canceled_calls_to_user_ids =
531 mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
532 livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
533 delete_livekit_room = refreshed_room.room.participants.is_empty();
534 }
535
536 {
537 let pool = pool.lock();
538 for canceled_user_id in canceled_calls_to_user_ids {
539 for connection_id in pool.user_connection_ids(canceled_user_id) {
540 peer.send(
541 connection_id,
542 proto::CallCanceled {
543 room_id: room_id.to_proto(),
544 },
545 )
546 .trace_err();
547 }
548 }
549 }
550
551 for user_id in contacts_to_update {
552 let busy = app_state.db.is_user_busy(user_id).await.trace_err();
553 let contacts = app_state.db.get_contacts(user_id).await.trace_err();
554 if let Some((busy, contacts)) = busy.zip(contacts) {
555 let pool = pool.lock();
556 let updated_contact = contact_for_user(user_id, busy, &pool);
557 for contact in contacts {
558 if let db::Contact::Accepted {
559 user_id: contact_user_id,
560 ..
561 } = contact
562 {
563 for contact_conn_id in
564 pool.user_connection_ids(contact_user_id)
565 {
566 peer.send(
567 contact_conn_id,
568 proto::UpdateContacts {
569 contacts: vec![updated_contact.clone()],
570 remove_contacts: Default::default(),
571 incoming_requests: Default::default(),
572 remove_incoming_requests: Default::default(),
573 outgoing_requests: Default::default(),
574 remove_outgoing_requests: Default::default(),
575 },
576 )
577 .trace_err();
578 }
579 }
580 }
581 }
582 }
583
584 if let Some(live_kit) = livekit_client.as_ref() {
585 if delete_livekit_room {
586 live_kit.delete_room(livekit_room).await.trace_err();
587 }
588 }
589 }
590 }
591
592 app_state
593 .db
594 .delete_stale_servers(&app_state.config.zed_environment, server_id)
595 .await
596 .trace_err();
597 }
598 .instrument(span),
599 );
600 Ok(())
601 }
602
603 pub fn teardown(&self) {
604 self.peer.teardown();
605 self.connection_pool.lock().reset();
606 let _ = self.teardown.send(true);
607 }
608
609 #[cfg(test)]
610 pub fn reset(&self, id: ServerId) {
611 self.teardown();
612 *self.id.lock() = id;
613 self.peer.reset(id.0 as u32);
614 let _ = self.teardown.send(false);
615 }
616
617 #[cfg(test)]
618 pub fn id(&self) -> ServerId {
619 *self.id.lock()
620 }
621
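    /// Registers the handler for a message type, wrapping it with logging of queue
    /// and processing durations. Panics if a handler for the same message type is
    /// registered twice.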
622 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
623 where
624 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
625 Fut: 'static + Send + Future<Output = Result<()>>,
626 M: EnvelopedMessage,
627 {
628 let prev_handler = self.handlers.insert(
629 TypeId::of::<M>(),
630 Box::new(move |envelope, session| {
631 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
632 let received_at = envelope.received_at;
633 tracing::info!("message received");
634 let start_time = Instant::now();
635 let future = (handler)(*envelope, session);
636 async move {
637 let result = future.await;
638 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
639 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
640 let queue_duration_ms = total_duration_ms - processing_duration_ms;
641 let payload_type = M::NAME;
642
643 match result {
644 Err(error) => {
645 tracing::error!(
646 ?error,
647 total_duration_ms,
648 processing_duration_ms,
649 queue_duration_ms,
650 payload_type,
651 "error handling message"
652 )
653 }
654 Ok(()) => tracing::info!(
655 total_duration_ms,
656 processing_duration_ms,
657 queue_duration_ms,
658 "finished handling message"
659 ),
660 }
661 }
662 .boxed()
663 }),
664 );
665 if prev_handler.is_some() {
666 panic!("registered a handler for the same message twice");
667 }
668 self
669 }
670
671 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
672 where
673 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
674 Fut: 'static + Send + Future<Output = Result<()>>,
675 M: EnvelopedMessage,
676 {
677 self.add_handler(move |envelope, session| handler(envelope.payload, session));
678 self
679 }
680
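    /// Registers a handler for a request message. A handler that returns `Ok`
    /// without responding is reported as an error; a handler that returns `Err`
    /// has the error forwarded to the requesting peer.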
681 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
682 where
683 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
684 Fut: Send + Future<Output = Result<()>>,
685 M: RequestMessage,
686 {
687 let handler = Arc::new(handler);
688 self.add_handler(move |envelope, session| {
689 let receipt = envelope.receipt();
690 let handler = handler.clone();
691 async move {
692 let peer = session.peer.clone();
693 let responded = Arc::new(AtomicBool::default());
694 let response = Response {
695 peer: peer.clone(),
696 responded: responded.clone(),
697 receipt,
698 };
699 match (handler)(envelope.payload, response, session).await {
700 Ok(()) => {
701 if responded.load(std::sync::atomic::Ordering::SeqCst) {
702 Ok(())
703 } else {
704 Err(anyhow!("handler did not send a response"))?
705 }
706 }
707 Err(error) => {
708 let proto_err = match &error {
709 Error::Internal(err) => err.to_proto(),
710 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
711 };
712 peer.respond_with_error(receipt, proto_err)?;
713 Err(error)
714 }
715 }
716 }
717 })
718 }
719
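    /// Drives a single client connection: registers it with the peer, sends the
    /// initial client update, then dispatches incoming messages to their handlers
    /// (foreground messages with bounded concurrency, background messages on the
    /// executor) until the connection closes or the server tears down.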
720 pub fn handle_connection(
721 self: &Arc<Self>,
722 connection: Connection,
723 address: String,
724 principal: Principal,
725 zed_version: ZedVersion,
726 geoip_country_code: Option<String>,
727 system_id: Option<String>,
728 send_connection_id: Option<oneshot::Sender<ConnectionId>>,
729 executor: Executor,
730 ) -> impl Future<Output = ()> + use<> {
731 let this = self.clone();
732 let span = info_span!("handle connection", %address,
733 connection_id=field::Empty,
734 user_id=field::Empty,
735 login=field::Empty,
736 impersonator=field::Empty,
737 geoip_country_code=field::Empty
738 );
739 principal.update_span(&span);
740 if let Some(country_code) = geoip_country_code.as_ref() {
741 span.record("geoip_country_code", country_code);
742 }
743
744 let mut teardown = self.teardown.subscribe();
745 async move {
746 if *teardown.borrow() {
747 tracing::error!("server is tearing down");
748 return
749 }
750 let (connection_id, handle_io, mut incoming_rx) = this
751 .peer
752 .add_connection(connection, {
753 let executor = executor.clone();
754 move |duration| executor.sleep(duration)
755 });
756 tracing::Span::current().record("connection_id", format!("{}", connection_id));
757
758 tracing::info!("connection opened");
759
760 let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
761 let http_client = match ReqwestClient::user_agent(&user_agent) {
762 Ok(http_client) => Arc::new(http_client),
763 Err(error) => {
764 tracing::error!(?error, "failed to create HTTP client");
765 return;
766 }
767 };
768
            let supermaven_client = this
                .app_state
                .config
                .supermaven_admin_api_key
                .clone()
                .map(|supermaven_admin_api_key| {
                    Arc::new(SupermavenAdminApi::new(
                        supermaven_admin_api_key.to_string(),
                        http_client.clone(),
                    ))
                });
773
774 let session = Session {
775 principal: principal.clone(),
776 connection_id,
777 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
778 peer: this.peer.clone(),
779 connection_pool: this.connection_pool.clone(),
780 app_state: this.app_state.clone(),
781 http_client,
782 geoip_country_code,
783 system_id,
784 _executor: executor.clone(),
785 supermaven_client,
786 };
787
788 if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
789 tracing::error!(?error, "failed to send initial client update");
790 return;
791 }
792
793 let handle_io = handle_io.fuse();
794 futures::pin_mut!(handle_io);
795
            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when, e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stopped processing further
            // messages until their respective request completed, neither would get a chance to
            // respond to the other's request, causing a deadlock.
            //
            // This arrangement ensures we attempt to process earlier messages first, but fall
            // back to processing messages that arrived later in the spirit of making progress.
804 let mut foreground_message_handlers = FuturesUnordered::new();
805 let concurrent_handlers = Arc::new(Semaphore::new(256));
806 loop {
807 let next_message = async {
808 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
809 let message = incoming_rx.next().await;
810 (permit, message)
811 }.fuse();
812 futures::pin_mut!(next_message);
813 futures::select_biased! {
814 _ = teardown.changed().fuse() => return,
815 result = handle_io => {
816 if let Err(error) = result {
817 tracing::error!(?error, "error handling I/O");
818 }
819 break;
820 }
821 _ = foreground_message_handlers.next() => {}
822 next_message = next_message => {
823 let (permit, message) = next_message;
824 if let Some(message) = message {
825 let type_name = message.payload_type_name();
826 // note: we copy all the fields from the parent span so we can query them in the logs.
827 // (https://github.com/tokio-rs/tracing/issues/2670).
828 let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
829 user_id=field::Empty,
830 login=field::Empty,
831 impersonator=field::Empty,
832 );
833 principal.update_span(&span);
834 let span_enter = span.enter();
835 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
836 let is_background = message.is_background();
837 let handle_message = (handler)(message, session.clone());
838 drop(span_enter);
839
840 let handle_message = async move {
841 handle_message.await;
842 drop(permit);
843 }.instrument(span);
844 if is_background {
845 executor.spawn_detached(handle_message);
846 } else {
847 foreground_message_handlers.push(handle_message);
848 }
849 } else {
850 tracing::error!("no message handler");
851 }
852 } else {
853 tracing::info!("connection closed");
854 break;
855 }
856 }
857 }
858 }
859
860 drop(foreground_message_handlers);
861 tracing::info!("signing out");
862 if let Err(error) = connection_lost(session, teardown, executor).await {
863 tracing::error!(?error, "error signing out");
864 }
865
866 }.instrument(span)
867 }
868
869 async fn send_initial_client_update(
870 &self,
871 connection_id: ConnectionId,
872 principal: &Principal,
873 zed_version: ZedVersion,
874 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
875 session: &Session,
876 ) -> Result<()> {
877 self.peer.send(
878 connection_id,
879 proto::Hello {
880 peer_id: Some(connection_id.into()),
881 },
882 )?;
883 tracing::info!("sent hello message");
884 if let Some(send_connection_id) = send_connection_id.take() {
885 let _ = send_connection_id.send(connection_id);
886 }
887
888 match principal {
889 Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
890 if !user.connected_once {
891 self.peer.send(connection_id, proto::ShowContacts {})?;
892 self.app_state
893 .db
894 .set_user_connected_once(user.id, true)
895 .await?;
896 }
897
898 update_user_plan(user.id, session).await?;
899
900 let contacts = self.app_state.db.get_contacts(user.id).await?;
901
902 {
903 let mut pool = self.connection_pool.lock();
904 pool.add_connection(connection_id, user.id, user.admin, zed_version);
905 self.peer.send(
906 connection_id,
907 build_initial_contacts_update(contacts, &pool),
908 )?;
909 }
910
911 if should_auto_subscribe_to_channels(zed_version) {
912 subscribe_user_to_channels(user.id, session).await?;
913 }
914
915 if let Some(incoming_call) =
916 self.app_state.db.incoming_call_for_user(user.id).await?
917 {
918 self.peer.send(connection_id, incoming_call)?;
919 }
920
921 update_user_contacts(user.id, session).await?;
922 }
923 }
924
925 Ok(())
926 }
927
928 pub async fn invite_code_redeemed(
929 self: &Arc<Self>,
930 inviter_id: UserId,
931 invitee_id: UserId,
932 ) -> Result<()> {
933 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
934 if let Some(code) = &user.invite_code {
935 let pool = self.connection_pool.lock();
936 let invitee_contact = contact_for_user(invitee_id, false, &pool);
937 for connection_id in pool.user_connection_ids(inviter_id) {
938 self.peer.send(
939 connection_id,
940 proto::UpdateContacts {
941 contacts: vec![invitee_contact.clone()],
942 ..Default::default()
943 },
944 )?;
945 self.peer.send(
946 connection_id,
947 proto::UpdateInviteInfo {
948 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
949 count: user.invite_count as u32,
950 },
951 )?;
952 }
953 }
954 }
955 Ok(())
956 }
957
958 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
959 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
960 if let Some(invite_code) = &user.invite_code {
961 let pool = self.connection_pool.lock();
962 for connection_id in pool.user_connection_ids(user_id) {
963 self.peer.send(
964 connection_id,
965 proto::UpdateInviteInfo {
966 url: format!(
967 "{}{}",
968 self.app_state.config.invite_link_prefix, invite_code
969 ),
970 count: user.invite_count as u32,
971 },
972 )?;
973 }
974 }
975 }
976 Ok(())
977 }
978
979 pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
980 let pool = self.connection_pool.lock();
981 for connection_id in pool.user_connection_ids(user_id) {
982 self.peer
983 .send(connection_id, proto::RefreshLlmToken {})
984 .trace_err();
985 }
986 }
987
988 pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot {
989 ServerSnapshot {
990 connection_pool: ConnectionPoolGuard {
991 guard: self.connection_pool.lock(),
992 _not_send: PhantomData,
993 },
994 peer: &self.peer,
995 }
996 }
997}
998
999impl Deref for ConnectionPoolGuard<'_> {
1000 type Target = ConnectionPool;
1001
1002 fn deref(&self) -> &Self::Target {
1003 &self.guard
1004 }
1005}
1006
1007impl DerefMut for ConnectionPoolGuard<'_> {
1008 fn deref_mut(&mut self) -> &mut Self::Target {
1009 &mut self.guard
1010 }
1011}
1012
1013impl Drop for ConnectionPoolGuard<'_> {
1014 fn drop(&mut self) {
1015 #[cfg(test)]
1016 self.check_invariants();
1017 }
1018}
1019
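/// Invokes `f` for every receiver except the optional sender, logging (rather than
/// propagating) any per-receiver error.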
1020fn broadcast<F>(
1021 sender_id: Option<ConnectionId>,
1022 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1023 mut f: F,
1024) where
1025 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1026{
1027 for receiver_id in receiver_ids {
1028 if Some(receiver_id) != sender_id {
1029 if let Err(error) = f(receiver_id) {
1030 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1031 }
1032 }
1033 }
1034}
1035
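/// Typed `x-zed-protocol-version` header carrying the client's RPC protocol version.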
1036pub struct ProtocolVersion(u32);
1037
1038impl Header for ProtocolVersion {
1039 fn name() -> &'static HeaderName {
1040 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1041 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1042 }
1043
1044 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1045 where
1046 Self: Sized,
1047 I: Iterator<Item = &'i axum::http::HeaderValue>,
1048 {
1049 let version = values
1050 .next()
1051 .ok_or_else(axum::headers::Error::invalid)?
1052 .to_str()
1053 .map_err(|_| axum::headers::Error::invalid())?
1054 .parse()
1055 .map_err(|_| axum::headers::Error::invalid())?;
1056 Ok(Self(version))
1057 }
1058
1059 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1060 values.extend([self.0.to_string().parse().unwrap()]);
1061 }
1062}
1063
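/// Typed `x-zed-app-version` header carrying the client's semantic app version.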
1064pub struct AppVersionHeader(SemanticVersion);
1065impl Header for AppVersionHeader {
1066 fn name() -> &'static HeaderName {
1067 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1068 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1069 }
1070
1071 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1072 where
1073 Self: Sized,
1074 I: Iterator<Item = &'i axum::http::HeaderValue>,
1075 {
1076 let version = values
1077 .next()
1078 .ok_or_else(axum::headers::Error::invalid)?
1079 .to_str()
1080 .map_err(|_| axum::headers::Error::invalid())?
1081 .parse()
1082 .map_err(|_| axum::headers::Error::invalid())?;
1083 Ok(Self(version))
1084 }
1085
1086 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1087 values.extend([self.0.to_string().parse().unwrap()]);
1088 }
1089}
1090
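/// Builds the collab router: `/rpc` (websocket upgrade, behind header auth) and
/// `/metrics`.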
1091pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1092 Router::new()
1093 .route("/rpc", get(handle_websocket_request))
1094 .layer(
1095 ServiceBuilder::new()
1096 .layer(Extension(server.app_state.clone()))
1097 .layer(middleware::from_fn(auth::validate_header)),
1098 )
1099 .route("/metrics", get(handle_metrics))
1100 .layer(Extension(server))
1101}
1102
1103pub async fn handle_websocket_request(
1104 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1105 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1106 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1107 Extension(server): Extension<Arc<Server>>,
1108 Extension(principal): Extension<Principal>,
1109 country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1110 system_id_header: Option<TypedHeader<SystemIdHeader>>,
1111 ws: WebSocketUpgrade,
1112) -> axum::response::Response {
1113 if protocol_version != rpc::PROTOCOL_VERSION {
1114 return (
1115 StatusCode::UPGRADE_REQUIRED,
1116 "client must be upgraded".to_string(),
1117 )
1118 .into_response();
1119 }
1120
1121 let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1122 return (
1123 StatusCode::UPGRADE_REQUIRED,
1124 "no version header found".to_string(),
1125 )
1126 .into_response();
1127 };
1128
1129 if !version.can_collaborate() {
1130 return (
1131 StatusCode::UPGRADE_REQUIRED,
1132 "client must be upgraded".to_string(),
1133 )
1134 .into_response();
1135 }
1136
1137 let socket_address = socket_address.to_string();
1138 ws.on_upgrade(move |socket| {
1139 let socket = socket
1140 .map_ok(to_tungstenite_message)
1141 .err_into()
1142 .with(|message| async move { to_axum_message(message) });
1143 let connection = Connection::new(Box::pin(socket));
1144 async move {
1145 server
1146 .handle_connection(
1147 connection,
1148 socket_address,
1149 principal,
1150 version,
1151 country_code_header.map(|header| header.to_string()),
1152 system_id_header.map(|header| header.to_string()),
1153 None,
1154 Executor::Production,
1155 )
1156 .await;
1157 }
1158 })
1159}
1160
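/// Prometheus endpoint: refreshes the connection and shared-project gauges, then
/// returns all gathered metrics in the text exposition format.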
1161pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1162 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1163 let connections_metric = CONNECTIONS_METRIC
1164 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1165
1166 let connections = server
1167 .connection_pool
1168 .lock()
1169 .connections()
1170 .filter(|connection| !connection.admin)
1171 .count();
1172 connections_metric.set(connections as _);
1173
1174 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1175 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1176 register_int_gauge!(
1177 "shared_projects",
1178 "number of open projects with one or more guests"
1179 )
1180 .unwrap()
1181 });
1182
1183 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1184 shared_projects_metric.set(shared_projects as _);
1185
1186 let encoder = prometheus::TextEncoder::new();
1187 let metric_families = prometheus::gather();
1188 let encoded_metrics = encoder
1189 .encode_to_string(&metric_families)
1190 .map_err(|err| anyhow!("{}", err))?;
1191 Ok(encoded_metrics)
1192}
1193
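/// Cleans up after a dropped connection: removes it from the pool immediately and,
/// if the user does not reconnect within `RECONNECT_TIMEOUT`, releases their rooms,
/// channel buffers, and pending calls.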
1194#[instrument(err, skip(executor))]
1195async fn connection_lost(
1196 session: Session,
1197 mut teardown: watch::Receiver<bool>,
1198 executor: Executor,
1199) -> Result<()> {
1200 session.peer.disconnect(session.connection_id);
1201 session
1202 .connection_pool()
1203 .await
1204 .remove_connection(session.connection_id)?;
1205
1206 session
1207 .db()
1208 .await
1209 .connection_lost(session.connection_id)
1210 .await
1211 .trace_err();
1212
1213 futures::select_biased! {
1214 _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1215
1216 log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1217 leave_room_for_session(&session, session.connection_id).await.trace_err();
1218 leave_channel_buffers_for_session(&session)
1219 .await
1220 .trace_err();
1221
1222 if !session
1223 .connection_pool()
1224 .await
1225 .is_user_online(session.user_id())
1226 {
1227 let db = session.db().await;
1228 if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1229 room_updated(&room, &session.peer);
1230 }
1231 }
1232
1233 update_user_contacts(session.user_id(), &session).await?;
1234 },
1235 _ = teardown.changed().fuse() => {}
1236 }
1237
1238 Ok(())
1239}
1240
1241/// Acknowledges a ping from a client, used to keep the connection alive.
1242async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1243 response.send(proto::Ack {})?;
1244 Ok(())
1245}
1246
/// Creates a new room for calling (outside of channels).
1248async fn create_room(
1249 _request: proto::CreateRoom,
1250 response: Response<proto::CreateRoom>,
1251 session: Session,
1252) -> Result<()> {
1253 let livekit_room = nanoid::nanoid!(30);
1254
1255 let live_kit_connection_info = util::maybe!(async {
1256 let live_kit = session.app_state.livekit_client.as_ref();
1257 let live_kit = live_kit?;
1258 let user_id = session.user_id().to_string();
1259
        let token = live_kit
            .room_token(&livekit_room, &user_id)
            .trace_err()?;
1263
1264 Some(proto::LiveKitConnectionInfo {
1265 server_url: live_kit.url().into(),
1266 token,
1267 can_publish: true,
1268 })
1269 })
1270 .await;
1271
1272 let room = session
1273 .db()
1274 .await
1275 .create_room(session.user_id(), session.connection_id, &livekit_room)
1276 .await?;
1277
1278 response.send(proto::CreateRoomResponse {
1279 room: Some(room.clone()),
1280 live_kit_connection_info,
1281 })?;
1282
1283 update_user_contacts(session.user_id(), &session).await?;
1284 Ok(())
1285}
1286
/// Joins a room from an invitation. If the room belongs to a channel, this is equivalent to joining that channel.
1288async fn join_room(
1289 request: proto::JoinRoom,
1290 response: Response<proto::JoinRoom>,
1291 session: Session,
1292) -> Result<()> {
1293 let room_id = RoomId::from_proto(request.id);
1294
1295 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1296
1297 if let Some(channel_id) = channel_id {
1298 return join_channel_internal(channel_id, Box::new(response), session).await;
1299 }
1300
1301 let joined_room = {
1302 let room = session
1303 .db()
1304 .await
1305 .join_room(room_id, session.user_id(), session.connection_id)
1306 .await?;
1307 room_updated(&room.room, &session.peer);
1308 room.into_inner()
1309 };
1310
1311 for connection_id in session
1312 .connection_pool()
1313 .await
1314 .user_connection_ids(session.user_id())
1315 {
1316 session
1317 .peer
1318 .send(
1319 connection_id,
1320 proto::CallCanceled {
1321 room_id: room_id.to_proto(),
1322 },
1323 )
1324 .trace_err();
1325 }
1326
1327 let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1328 {
1329 live_kit
1330 .room_token(
1331 &joined_room.room.livekit_room,
1332 &session.user_id().to_string(),
1333 )
1334 .trace_err()
1335 .map(|token| proto::LiveKitConnectionInfo {
1336 server_url: live_kit.url().into(),
1337 token,
1338 can_publish: true,
1339 })
1340 } else {
1341 None
1342 };
1343
1344 response.send(proto::JoinRoomResponse {
1345 room: Some(joined_room.room),
1346 channel_id: None,
1347 live_kit_connection_info,
1348 })?;
1349
1350 update_user_contacts(session.user_id(), &session).await?;
1351 Ok(())
1352}
1353
/// Rejoins a room after a connection error, restoring the caller's shared and previously joined projects.
1355async fn rejoin_room(
1356 request: proto::RejoinRoom,
1357 response: Response<proto::RejoinRoom>,
1358 session: Session,
1359) -> Result<()> {
1360 let room;
1361 let channel;
1362 {
1363 let mut rejoined_room = session
1364 .db()
1365 .await
1366 .rejoin_room(request, session.user_id(), session.connection_id)
1367 .await?;
1368
1369 response.send(proto::RejoinRoomResponse {
1370 room: Some(rejoined_room.room.clone()),
1371 reshared_projects: rejoined_room
1372 .reshared_projects
1373 .iter()
1374 .map(|project| proto::ResharedProject {
1375 id: project.id.to_proto(),
1376 collaborators: project
1377 .collaborators
1378 .iter()
1379 .map(|collaborator| collaborator.to_proto())
1380 .collect(),
1381 })
1382 .collect(),
1383 rejoined_projects: rejoined_room
1384 .rejoined_projects
1385 .iter()
1386 .map(|rejoined_project| rejoined_project.to_proto())
1387 .collect(),
1388 })?;
1389 room_updated(&rejoined_room.room, &session.peer);
1390
1391 for project in &rejoined_room.reshared_projects {
1392 for collaborator in &project.collaborators {
1393 session
1394 .peer
1395 .send(
1396 collaborator.connection_id,
1397 proto::UpdateProjectCollaborator {
1398 project_id: project.id.to_proto(),
1399 old_peer_id: Some(project.old_connection_id.into()),
1400 new_peer_id: Some(session.connection_id.into()),
1401 },
1402 )
1403 .trace_err();
1404 }
1405
1406 broadcast(
1407 Some(session.connection_id),
1408 project
1409 .collaborators
1410 .iter()
1411 .map(|collaborator| collaborator.connection_id),
1412 |connection_id| {
1413 session.peer.forward_send(
1414 session.connection_id,
1415 connection_id,
1416 proto::UpdateProject {
1417 project_id: project.id.to_proto(),
1418 worktrees: project.worktrees.clone(),
1419 },
1420 )
1421 },
1422 );
1423 }
1424
1425 notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;
1426
1427 let rejoined_room = rejoined_room.into_inner();
1428
1429 room = rejoined_room.room;
1430 channel = rejoined_room.channel;
1431 }
1432
1433 if let Some(channel) = channel {
1434 channel_updated(
1435 &channel,
1436 &room,
1437 &session.peer,
1438 &*session.connection_pool().await,
1439 );
1440 }
1441
1442 update_user_contacts(session.user_id(), &session).await?;
1443 Ok(())
1444}
1445
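/// Notifies existing collaborators of the rejoining peer's new connection id, then
/// re-streams worktree entries, diagnostics, settings, and repository state to the
/// rejoining client.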
1446fn notify_rejoined_projects(
1447 rejoined_projects: &mut Vec<RejoinedProject>,
1448 session: &Session,
1449) -> Result<()> {
1450 for project in rejoined_projects.iter() {
1451 for collaborator in &project.collaborators {
1452 session
1453 .peer
1454 .send(
1455 collaborator.connection_id,
1456 proto::UpdateProjectCollaborator {
1457 project_id: project.id.to_proto(),
1458 old_peer_id: Some(project.old_connection_id.into()),
1459 new_peer_id: Some(session.connection_id.into()),
1460 },
1461 )
1462 .trace_err();
1463 }
1464 }
1465
1466 for project in rejoined_projects {
1467 for worktree in mem::take(&mut project.worktrees) {
1468 // Stream this worktree's entries.
1469 let message = proto::UpdateWorktree {
1470 project_id: project.id.to_proto(),
1471 worktree_id: worktree.id,
1472 abs_path: worktree.abs_path.clone(),
1473 root_name: worktree.root_name,
1474 updated_entries: worktree.updated_entries,
1475 removed_entries: worktree.removed_entries,
1476 scan_id: worktree.scan_id,
1477 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1478 updated_repositories: worktree.updated_repositories,
1479 removed_repositories: worktree.removed_repositories,
1480 };
1481 for update in proto::split_worktree_update(message) {
1482 session.peer.send(session.connection_id, update)?;
1483 }
1484
1485 // Stream this worktree's diagnostics.
1486 for summary in worktree.diagnostic_summaries {
1487 session.peer.send(
1488 session.connection_id,
1489 proto::UpdateDiagnosticSummary {
1490 project_id: project.id.to_proto(),
1491 worktree_id: worktree.id,
1492 summary: Some(summary),
1493 },
1494 )?;
1495 }
1496
1497 for settings_file in worktree.settings_files {
1498 session.peer.send(
1499 session.connection_id,
1500 proto::UpdateWorktreeSettings {
1501 project_id: project.id.to_proto(),
1502 worktree_id: worktree.id,
1503 path: settings_file.path,
1504 content: Some(settings_file.content),
1505 kind: Some(settings_file.kind.to_proto().into()),
1506 },
1507 )?;
1508 }
1509 }
1510
1511 for repository in mem::take(&mut project.updated_repositories) {
1512 for update in split_repository_update(repository) {
1513 session.peer.send(session.connection_id, update)?;
1514 }
1515 }
1516
1517 for id in mem::take(&mut project.removed_repositories) {
1518 session.peer.send(
1519 session.connection_id,
1520 proto::RemoveRepository {
1521 project_id: project.id.to_proto(),
1522 id,
1523 },
1524 )?;
1525 }
1526 }
1527
1528 Ok(())
1529}
1530
/// Leaves the current room, disconnecting the caller from it.
1532async fn leave_room(
1533 _: proto::LeaveRoom,
1534 response: Response<proto::LeaveRoom>,
1535 session: Session,
1536) -> Result<()> {
1537 leave_room_for_session(&session, session.connection_id).await?;
1538 response.send(proto::Ack {})?;
1539 Ok(())
1540}
1541
1542/// Updates the permissions of someone else in the room.
1543async fn set_room_participant_role(
1544 request: proto::SetRoomParticipantRole,
1545 response: Response<proto::SetRoomParticipantRole>,
1546 session: Session,
1547) -> Result<()> {
1548 let user_id = UserId::from_proto(request.user_id);
1549 let role = ChannelRole::from(request.role());
1550
1551 let (livekit_room, can_publish) = {
1552 let room = session
1553 .db()
1554 .await
1555 .set_room_participant_role(
1556 session.user_id(),
1557 RoomId::from_proto(request.room_id),
1558 user_id,
1559 role,
1560 )
1561 .await?;
1562
1563 let livekit_room = room.livekit_room.clone();
1564 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1565 room_updated(&room, &session.peer);
1566 (livekit_room, can_publish)
1567 };
1568
1569 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1570 live_kit
1571 .update_participant(
1572 livekit_room.clone(),
1573 request.user_id.to_string(),
1574 livekit_api::proto::ParticipantPermission {
1575 can_subscribe: true,
1576 can_publish,
1577 can_publish_data: can_publish,
1578 hidden: false,
1579 recorder: false,
1580 },
1581 )
1582 .await
1583 .trace_err();
1584 }
1585
1586 response.send(proto::Ack {})?;
1587 Ok(())
1588}
1589
/// Calls another user into the current room.
1591async fn call(
1592 request: proto::Call,
1593 response: Response<proto::Call>,
1594 session: Session,
1595) -> Result<()> {
1596 let room_id = RoomId::from_proto(request.room_id);
1597 let calling_user_id = session.user_id();
1598 let calling_connection_id = session.connection_id;
1599 let called_user_id = UserId::from_proto(request.called_user_id);
1600 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1601 if !session
1602 .db()
1603 .await
1604 .has_contact(calling_user_id, called_user_id)
1605 .await?
1606 {
1607 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1608 }
1609
1610 let incoming_call = {
1611 let (room, incoming_call) = &mut *session
1612 .db()
1613 .await
1614 .call(
1615 room_id,
1616 calling_user_id,
1617 calling_connection_id,
1618 called_user_id,
1619 initial_project_id,
1620 )
1621 .await?;
1622 room_updated(room, &session.peer);
1623 mem::take(incoming_call)
1624 };
1625 update_user_contacts(called_user_id, &session).await?;
1626
1627 let mut calls = session
1628 .connection_pool()
1629 .await
1630 .user_connection_ids(called_user_id)
1631 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1632 .collect::<FuturesUnordered<_>>();
1633
1634 while let Some(call_response) = calls.next().await {
1635 match call_response.as_ref() {
1636 Ok(_) => {
1637 response.send(proto::Ack {})?;
1638 return Ok(());
1639 }
1640 Err(_) => {
1641 call_response.trace_err();
1642 }
1643 }
1644 }
1645
1646 {
1647 let room = session
1648 .db()
1649 .await
1650 .call_failed(room_id, called_user_id)
1651 .await?;
1652 room_updated(&room, &session.peer);
1653 }
1654 update_user_contacts(called_user_id, &session).await?;
1655
1656 Err(anyhow!("failed to ring user"))?
1657}
1658
1659/// Cancel an outgoing call.
1660async fn cancel_call(
1661 request: proto::CancelCall,
1662 response: Response<proto::CancelCall>,
1663 session: Session,
1664) -> Result<()> {
1665 let called_user_id = UserId::from_proto(request.called_user_id);
1666 let room_id = RoomId::from_proto(request.room_id);
1667 {
1668 let room = session
1669 .db()
1670 .await
1671 .cancel_call(room_id, session.connection_id, called_user_id)
1672 .await?;
1673 room_updated(&room, &session.peer);
1674 }
1675
1676 for connection_id in session
1677 .connection_pool()
1678 .await
1679 .user_connection_ids(called_user_id)
1680 {
1681 session
1682 .peer
1683 .send(
1684 connection_id,
1685 proto::CallCanceled {
1686 room_id: room_id.to_proto(),
1687 },
1688 )
1689 .trace_err();
1690 }
1691 response.send(proto::Ack {})?;
1692
1693 update_user_contacts(called_user_id, &session).await?;
1694 Ok(())
1695}
1696
1697/// Decline an incoming call.
1698async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1699 let room_id = RoomId::from_proto(message.room_id);
1700 {
1701 let room = session
1702 .db()
1703 .await
1704 .decline_call(Some(room_id), session.user_id())
1705 .await?
1706 .ok_or_else(|| anyhow!("failed to decline call"))?;
1707 room_updated(&room, &session.peer);
1708 }
1709
1710 for connection_id in session
1711 .connection_pool()
1712 .await
1713 .user_connection_ids(session.user_id())
1714 {
1715 session
1716 .peer
1717 .send(
1718 connection_id,
1719 proto::CallCanceled {
1720 room_id: room_id.to_proto(),
1721 },
1722 )
1723 .trace_err();
1724 }
1725 update_user_contacts(session.user_id(), &session).await?;
1726 Ok(())
1727}
1728
1729/// Updates other participants in the room with your current location.
1730async fn update_participant_location(
1731 request: proto::UpdateParticipantLocation,
1732 response: Response<proto::UpdateParticipantLocation>,
1733 session: Session,
1734) -> Result<()> {
1735 let room_id = RoomId::from_proto(request.room_id);
1736 let location = request
1737 .location
1738 .ok_or_else(|| anyhow!("invalid location"))?;
1739
1740 let db = session.db().await;
1741 let room = db
1742 .update_room_participant_location(room_id, session.connection_id, location)
1743 .await?;
1744
1745 room_updated(&room, &session.peer);
1746 response.send(proto::Ack {})?;
1747 Ok(())
1748}
1749
1750/// Share a project into the room.
1751async fn share_project(
1752 request: proto::ShareProject,
1753 response: Response<proto::ShareProject>,
1754 session: Session,
1755) -> Result<()> {
1756 let (project_id, room) = &*session
1757 .db()
1758 .await
1759 .share_project(
1760 RoomId::from_proto(request.room_id),
1761 session.connection_id,
1762 &request.worktrees,
1763 request.is_ssh_project,
1764 )
1765 .await?;
1766 response.send(proto::ShareProjectResponse {
1767 project_id: project_id.to_proto(),
1768 })?;
1769 room_updated(room, &session.peer);
1770
1771 Ok(())
1772}
1773
1774/// Unshare a project from the room.
1775async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1776 let project_id = ProjectId::from_proto(message.project_id);
1777 unshare_project_internal(project_id, session.connection_id, &session).await
1778}
1779
1780async fn unshare_project_internal(
1781 project_id: ProjectId,
1782 connection_id: ConnectionId,
1783 session: &Session,
1784) -> Result<()> {
1785 let delete = {
1786 let room_guard = session
1787 .db()
1788 .await
1789 .unshare_project(project_id, connection_id)
1790 .await?;
1791
1792 let (delete, room, guest_connection_ids) = &*room_guard;
1793
1794 let message = proto::UnshareProject {
1795 project_id: project_id.to_proto(),
1796 };
1797
1798 broadcast(
1799 Some(connection_id),
1800 guest_connection_ids.iter().copied(),
1801 |conn_id| session.peer.send(conn_id, message.clone()),
1802 );
1803 if let Some(room) = room {
1804 room_updated(room, &session.peer);
1805 }
1806
1807 *delete
1808 };
1809
1810 if delete {
1811 let db = session.db().await;
1812 db.delete_project(project_id).await?;
1813 }
1814
1815 Ok(())
1816}
1817
/// Joins someone else's shared project.
1819async fn join_project(
1820 request: proto::JoinProject,
1821 response: Response<proto::JoinProject>,
1822 session: Session,
1823) -> Result<()> {
1824 let project_id = ProjectId::from_proto(request.project_id);
1825
1826 tracing::info!(%project_id, "join project");
1827
1828 let db = session.db().await;
1829 let (project, replica_id) = &mut *db
1830 .join_project(project_id, session.connection_id, session.user_id())
1831 .await?;
1832 drop(db);
1833 tracing::info!(%project_id, "join remote project");
1834 join_project_internal(response, session, project, replica_id)
1835}
1836
1837trait JoinProjectInternalResponse {
1838 fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
1839}
1840impl JoinProjectInternalResponse for Response<proto::JoinProject> {
1841 fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
1842 Response::<proto::JoinProject>::send(self, result)
1843 }
1844}
1845
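/// Notifies existing collaborators of the new peer and streams the project's current state to it:
/// worktree metadata and entries, diagnostics, settings files, repositories, and language servers.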
1846fn join_project_internal(
1847 response: impl JoinProjectInternalResponse,
1848 session: Session,
1849 project: &mut Project,
1850 replica_id: &ReplicaId,
1851) -> Result<()> {
1852 let collaborators = project
1853 .collaborators
1854 .iter()
1855 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1856 .map(|collaborator| collaborator.to_proto())
1857 .collect::<Vec<_>>();
1858 let project_id = project.id;
1859 let guest_user_id = session.user_id();
1860
1861 let worktrees = project
1862 .worktrees
1863 .iter()
1864 .map(|(id, worktree)| proto::WorktreeMetadata {
1865 id: *id,
1866 root_name: worktree.root_name.clone(),
1867 visible: worktree.visible,
1868 abs_path: worktree.abs_path.clone(),
1869 })
1870 .collect::<Vec<_>>();
1871
1872 let add_project_collaborator = proto::AddProjectCollaborator {
1873 project_id: project_id.to_proto(),
1874 collaborator: Some(proto::Collaborator {
1875 peer_id: Some(session.connection_id.into()),
1876 replica_id: replica_id.0 as u32,
1877 user_id: guest_user_id.to_proto(),
1878 is_host: false,
1879 }),
1880 };
1881
1882 for collaborator in &collaborators {
1883 session
1884 .peer
1885 .send(
1886 collaborator.peer_id.unwrap().into(),
1887 add_project_collaborator.clone(),
1888 )
1889 .trace_err();
1890 }
1891
1892 // First, we send the metadata associated with each worktree.
1893 response.send(proto::JoinProjectResponse {
1894 project_id: project.id.0 as u64,
1895 worktrees: worktrees.clone(),
1896 replica_id: replica_id.0 as u32,
1897 collaborators: collaborators.clone(),
1898 language_servers: project.language_servers.clone(),
1899 role: project.role.into(),
1900 })?;
1901
1902 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1903 // Stream this worktree's entries.
1904 let message = proto::UpdateWorktree {
1905 project_id: project_id.to_proto(),
1906 worktree_id,
1907 abs_path: worktree.abs_path.clone(),
1908 root_name: worktree.root_name,
1909 updated_entries: worktree.entries,
1910 removed_entries: Default::default(),
1911 scan_id: worktree.scan_id,
1912 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1913 updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1914 removed_repositories: Default::default(),
1915 };
1916 for update in proto::split_worktree_update(message) {
1917 session.peer.send(session.connection_id, update.clone())?;
1918 }
1919
1920 // Stream this worktree's diagnostics.
1921 for summary in worktree.diagnostic_summaries {
1922 session.peer.send(
1923 session.connection_id,
1924 proto::UpdateDiagnosticSummary {
1925 project_id: project_id.to_proto(),
1926 worktree_id: worktree.id,
1927 summary: Some(summary),
1928 },
1929 )?;
1930 }
1931
1932 for settings_file in worktree.settings_files {
1933 session.peer.send(
1934 session.connection_id,
1935 proto::UpdateWorktreeSettings {
1936 project_id: project_id.to_proto(),
1937 worktree_id: worktree.id,
1938 path: settings_file.path,
1939 content: Some(settings_file.content),
1940 kind: Some(settings_file.kind.to_proto() as i32),
1941 },
1942 )?;
1943 }
1944 }
1945
1946 for repository in mem::take(&mut project.repositories) {
1947 for update in split_repository_update(repository) {
1948 session.peer.send(session.connection_id, update)?;
1949 }
1950 }
1951
1952 for language_server in &project.language_servers {
1953 session.peer.send(
1954 session.connection_id,
1955 proto::UpdateLanguageServer {
1956 project_id: project_id.to_proto(),
1957 language_server_id: language_server.id,
1958 variant: Some(
1959 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1960 proto::LspDiskBasedDiagnosticsUpdated {},
1961 ),
1962 ),
1963 },
1964 )?;
1965 }
1966
1967 Ok(())
1968}
1969
/// Leave someone else's shared project.
1971async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1972 let sender_id = session.connection_id;
1973 let project_id = ProjectId::from_proto(request.project_id);
1974 let db = session.db().await;
1975
1976 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1977 tracing::info!(
1978 %project_id,
1979 "leave project"
1980 );
1981
1982 project_left(project, &session);
1983 if let Some(room) = room {
1984 room_updated(room, &session.peer);
1985 }
1986
1987 Ok(())
1988}
1989
1990/// Updates other participants with changes to the project
1991async fn update_project(
1992 request: proto::UpdateProject,
1993 response: Response<proto::UpdateProject>,
1994 session: Session,
1995) -> Result<()> {
1996 let project_id = ProjectId::from_proto(request.project_id);
1997 let (room, guest_connection_ids) = &*session
1998 .db()
1999 .await
2000 .update_project(project_id, session.connection_id, &request.worktrees)
2001 .await?;
2002 broadcast(
2003 Some(session.connection_id),
2004 guest_connection_ids.iter().copied(),
2005 |connection_id| {
2006 session
2007 .peer
2008 .forward_send(session.connection_id, connection_id, request.clone())
2009 },
2010 );
2011 if let Some(room) = room {
2012 room_updated(room, &session.peer);
2013 }
2014 response.send(proto::Ack {})?;
2015
2016 Ok(())
2017}
2018
2019/// Updates other participants with changes to the worktree
2020async fn update_worktree(
2021 request: proto::UpdateWorktree,
2022 response: Response<proto::UpdateWorktree>,
2023 session: Session,
2024) -> Result<()> {
2025 let guest_connection_ids = session
2026 .db()
2027 .await
2028 .update_worktree(&request, session.connection_id)
2029 .await?;
2030
2031 broadcast(
2032 Some(session.connection_id),
2033 guest_connection_ids.iter().copied(),
2034 |connection_id| {
2035 session
2036 .peer
2037 .forward_send(session.connection_id, connection_id, request.clone())
2038 },
2039 );
2040 response.send(proto::Ack {})?;
2041 Ok(())
2042}
2043
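/// Updates other participants with changes to a repository in the project.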
2044async fn update_repository(
2045 request: proto::UpdateRepository,
2046 response: Response<proto::UpdateRepository>,
2047 session: Session,
2048) -> Result<()> {
2049 let guest_connection_ids = session
2050 .db()
2051 .await
2052 .update_repository(&request, session.connection_id)
2053 .await?;
2054
2055 broadcast(
2056 Some(session.connection_id),
2057 guest_connection_ids.iter().copied(),
2058 |connection_id| {
2059 session
2060 .peer
2061 .forward_send(session.connection_id, connection_id, request.clone())
2062 },
2063 );
2064 response.send(proto::Ack {})?;
2065 Ok(())
2066}
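/// Notifies other participants that a repository has been removed from the project.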
2067
2068async fn remove_repository(
2069 request: proto::RemoveRepository,
2070 response: Response<proto::RemoveRepository>,
2071 session: Session,
2072) -> Result<()> {
2073 let guest_connection_ids = session
2074 .db()
2075 .await
2076 .remove_repository(&request, session.connection_id)
2077 .await?;
2078
2079 broadcast(
2080 Some(session.connection_id),
2081 guest_connection_ids.iter().copied(),
2082 |connection_id| {
2083 session
2084 .peer
2085 .forward_send(session.connection_id, connection_id, request.clone())
2086 },
2087 );
2088 response.send(proto::Ack {})?;
2089 Ok(())
2090}
2091
2092/// Updates other participants with changes to the diagnostics
2093async fn update_diagnostic_summary(
2094 message: proto::UpdateDiagnosticSummary,
2095 session: Session,
2096) -> Result<()> {
2097 let guest_connection_ids = session
2098 .db()
2099 .await
2100 .update_diagnostic_summary(&message, session.connection_id)
2101 .await?;
2102
2103 broadcast(
2104 Some(session.connection_id),
2105 guest_connection_ids.iter().copied(),
2106 |connection_id| {
2107 session
2108 .peer
2109 .forward_send(session.connection_id, connection_id, message.clone())
2110 },
2111 );
2112
2113 Ok(())
2114}
2115
2116/// Updates other participants with changes to the worktree settings
2117async fn update_worktree_settings(
2118 message: proto::UpdateWorktreeSettings,
2119 session: Session,
2120) -> Result<()> {
2121 let guest_connection_ids = session
2122 .db()
2123 .await
2124 .update_worktree_settings(&message, session.connection_id)
2125 .await?;
2126
2127 broadcast(
2128 Some(session.connection_id),
2129 guest_connection_ids.iter().copied(),
2130 |connection_id| {
2131 session
2132 .peer
2133 .forward_send(session.connection_id, connection_id, message.clone())
2134 },
2135 );
2136
2137 Ok(())
2138}
2139
2140/// Notify other participants that a language server has started.
2141async fn start_language_server(
2142 request: proto::StartLanguageServer,
2143 session: Session,
2144) -> Result<()> {
2145 let guest_connection_ids = session
2146 .db()
2147 .await
2148 .start_language_server(&request, session.connection_id)
2149 .await?;
2150
2151 broadcast(
2152 Some(session.connection_id),
2153 guest_connection_ids.iter().copied(),
2154 |connection_id| {
2155 session
2156 .peer
2157 .forward_send(session.connection_id, connection_id, request.clone())
2158 },
2159 );
2160 Ok(())
2161}
2162
2163/// Notify other participants that a language server has changed.
2164async fn update_language_server(
2165 request: proto::UpdateLanguageServer,
2166 session: Session,
2167) -> Result<()> {
2168 let project_id = ProjectId::from_proto(request.project_id);
2169 let project_connection_ids = session
2170 .db()
2171 .await
2172 .project_connection_ids(project_id, session.connection_id, true)
2173 .await?;
2174 broadcast(
2175 Some(session.connection_id),
2176 project_connection_ids.iter().copied(),
2177 |connection_id| {
2178 session
2179 .peer
2180 .forward_send(session.connection_id, connection_id, request.clone())
2181 },
2182 );
2183 Ok(())
2184}
2185
/// Forward a project request to the host. These requests should be read-only,
/// as guests are allowed to send them.
2188async fn forward_read_only_project_request<T>(
2189 request: T,
2190 response: Response<T>,
2191 session: Session,
2192) -> Result<()>
2193where
2194 T: EntityMessage + RequestMessage,
2195{
2196 let project_id = ProjectId::from_proto(request.remote_entity_id());
2197 let host_connection_id = session
2198 .db()
2199 .await
2200 .host_for_read_only_project_request(project_id, session.connection_id)
2201 .await?;
2202 let payload = session
2203 .peer
2204 .forward_request(session.connection_id, host_connection_id, request)
2205 .await?;
2206 response.send(payload)?;
2207 Ok(())
2208}
2209
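/// Forward a search request to the project host. Like other read-only project requests,
/// guests are allowed to send it.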
2210async fn forward_find_search_candidates_request(
2211 request: proto::FindSearchCandidates,
2212 response: Response<proto::FindSearchCandidates>,
2213 session: Session,
2214) -> Result<()> {
2215 let project_id = ProjectId::from_proto(request.remote_entity_id());
2216 let host_connection_id = session
2217 .db()
2218 .await
2219 .host_for_read_only_project_request(project_id, session.connection_id)
2220 .await?;
2221 let payload = session
2222 .peer
2223 .forward_request(session.connection_id, host_connection_id, request)
2224 .await?;
2225 response.send(payload)?;
2226 Ok(())
2227}
2228
/// Forward a project request to the host. These requests are disallowed
/// for guests.
2231async fn forward_mutating_project_request<T>(
2232 request: T,
2233 response: Response<T>,
2234 session: Session,
2235) -> Result<()>
2236where
2237 T: EntityMessage + RequestMessage,
2238{
2239 let project_id = ProjectId::from_proto(request.remote_entity_id());
2240
2241 let host_connection_id = session
2242 .db()
2243 .await
2244 .host_for_mutating_project_request(project_id, session.connection_id)
2245 .await?;
2246 let payload = session
2247 .peer
2248 .forward_request(session.connection_id, host_connection_id, request)
2249 .await?;
2250 response.send(payload)?;
2251 Ok(())
2252}
2253
2254/// Notify other participants that a new buffer has been created
2255async fn create_buffer_for_peer(
2256 request: proto::CreateBufferForPeer,
2257 session: Session,
2258) -> Result<()> {
2259 session
2260 .db()
2261 .await
2262 .check_user_is_project_host(
2263 ProjectId::from_proto(request.project_id),
2264 session.connection_id,
2265 )
2266 .await?;
2267 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2268 session
2269 .peer
2270 .forward_send(session.connection_id, peer_id.into(), request)?;
2271 Ok(())
2272}
2273
2274/// Notify other participants that a buffer has been updated. This is
2275/// allowed for guests as long as the update is limited to selections.
2276async fn update_buffer(
2277 request: proto::UpdateBuffer,
2278 response: Response<proto::UpdateBuffer>,
2279 session: Session,
2280) -> Result<()> {
2281 let project_id = ProjectId::from_proto(request.project_id);
2282 let mut capability = Capability::ReadOnly;
2283
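    // Guests may only send selection updates; any other operation requires write access to the buffer.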
2284 for op in request.operations.iter() {
2285 match op.variant {
2286 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2287 Some(_) => capability = Capability::ReadWrite,
2288 }
2289 }
2290
2291 let host = {
2292 let guard = session
2293 .db()
2294 .await
2295 .connections_for_buffer_update(project_id, session.connection_id, capability)
2296 .await?;
2297
2298 let (host, guests) = &*guard;
2299
2300 broadcast(
2301 Some(session.connection_id),
2302 guests.clone(),
2303 |connection_id| {
2304 session
2305 .peer
2306 .forward_send(session.connection_id, connection_id, request.clone())
2307 },
2308 );
2309
2310 *host
2311 };
2312
2313 if host != session.connection_id {
2314 session
2315 .peer
2316 .forward_request(session.connection_id, host, request.clone())
2317 .await?;
2318 }
2319
2320 response.send(proto::Ack {})?;
2321 Ok(())
2322}
2323
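/// Updates other participants with changes to a shared context. Selection-only buffer operations
/// are treated as read-only; any other operation requires write access.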
2324async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2325 let project_id = ProjectId::from_proto(message.project_id);
2326
2327 let operation = message.operation.as_ref().context("invalid operation")?;
2328 let capability = match operation.variant.as_ref() {
2329 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2330 if let Some(buffer_op) = buffer_op.operation.as_ref() {
2331 match buffer_op.variant {
2332 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2333 Capability::ReadOnly
2334 }
2335 _ => Capability::ReadWrite,
2336 }
2337 } else {
2338 Capability::ReadWrite
2339 }
2340 }
2341 Some(_) => Capability::ReadWrite,
2342 None => Capability::ReadOnly,
2343 };
2344
2345 let guard = session
2346 .db()
2347 .await
2348 .connections_for_buffer_update(project_id, session.connection_id, capability)
2349 .await?;
2350
2351 let (host, guests) = &*guard;
2352
2353 broadcast(
2354 Some(session.connection_id),
2355 guests.iter().chain([host]).copied(),
2356 |connection_id| {
2357 session
2358 .peer
2359 .forward_send(session.connection_id, connection_id, message.clone())
2360 },
2361 );
2362
2363 Ok(())
2364}
2365
2366/// Notify other participants that a project has been updated.
2367async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2368 request: T,
2369 session: Session,
2370) -> Result<()> {
2371 let project_id = ProjectId::from_proto(request.remote_entity_id());
2372 let project_connection_ids = session
2373 .db()
2374 .await
2375 .project_connection_ids(project_id, session.connection_id, false)
2376 .await?;
2377
2378 broadcast(
2379 Some(session.connection_id),
2380 project_connection_ids.iter().copied(),
2381 |connection_id| {
2382 session
2383 .peer
2384 .forward_send(session.connection_id, connection_id, request.clone())
2385 },
2386 );
2387 Ok(())
2388}
2389
2390/// Start following another user in a call.
2391async fn follow(
2392 request: proto::Follow,
2393 response: Response<proto::Follow>,
2394 session: Session,
2395) -> Result<()> {
2396 let room_id = RoomId::from_proto(request.room_id);
2397 let project_id = request.project_id.map(ProjectId::from_proto);
2398 let leader_id = request
2399 .leader_id
2400 .ok_or_else(|| anyhow!("invalid leader id"))?
2401 .into();
2402 let follower_id = session.connection_id;
2403
2404 session
2405 .db()
2406 .await
2407 .check_room_participants(room_id, leader_id, session.connection_id)
2408 .await?;
2409
2410 let response_payload = session
2411 .peer
2412 .forward_request(session.connection_id, leader_id, request)
2413 .await?;
2414 response.send(response_payload)?;
2415
2416 if let Some(project_id) = project_id {
2417 let room = session
2418 .db()
2419 .await
2420 .follow(room_id, project_id, leader_id, follower_id)
2421 .await?;
2422 room_updated(&room, &session.peer);
2423 }
2424
2425 Ok(())
2426}
2427
2428/// Stop following another user in a call.
2429async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2430 let room_id = RoomId::from_proto(request.room_id);
2431 let project_id = request.project_id.map(ProjectId::from_proto);
2432 let leader_id = request
2433 .leader_id
2434 .ok_or_else(|| anyhow!("invalid leader id"))?
2435 .into();
2436 let follower_id = session.connection_id;
2437
2438 session
2439 .db()
2440 .await
2441 .check_room_participants(room_id, leader_id, session.connection_id)
2442 .await?;
2443
2444 session
2445 .peer
2446 .forward_send(session.connection_id, leader_id, request)?;
2447
2448 if let Some(project_id) = project_id {
2449 let room = session
2450 .db()
2451 .await
2452 .unfollow(room_id, project_id, leader_id, follower_id)
2453 .await?;
2454 room_updated(&room, &session.peer);
2455 }
2456
2457 Ok(())
2458}
2459
2460/// Notify everyone following you of your current location.
2461async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2462 let room_id = RoomId::from_proto(request.room_id);
2463 let database = session.db.lock().await;
2464
2465 let connection_ids = if let Some(project_id) = request.project_id {
2466 let project_id = ProjectId::from_proto(project_id);
2467 database
2468 .project_connection_ids(project_id, session.connection_id, true)
2469 .await?
2470 } else {
2471 database
2472 .room_connection_ids(room_id, session.connection_id)
2473 .await?
2474 };
2475
2476 // For now, don't send view update messages back to that view's current leader.
2477 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2478 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2479 _ => None,
2480 });
2481
2482 for connection_id in connection_ids.iter().cloned() {
2483 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2484 session
2485 .peer
2486 .forward_send(session.connection_id, connection_id, request.clone())?;
2487 }
2488 }
2489 Ok(())
2490}
2491
2492/// Get public data about users.
2493async fn get_users(
2494 request: proto::GetUsers,
2495 response: Response<proto::GetUsers>,
2496 session: Session,
2497) -> Result<()> {
2498 let user_ids = request
2499 .user_ids
2500 .into_iter()
2501 .map(UserId::from_proto)
2502 .collect();
2503 let users = session
2504 .db()
2505 .await
2506 .get_users_by_ids(user_ids)
2507 .await?
2508 .into_iter()
2509 .map(|user| proto::User {
2510 id: user.id.to_proto(),
2511 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2512 github_login: user.github_login,
2513 email: user.email_address,
2514 name: user.name,
2515 })
2516 .collect();
2517 response.send(proto::UsersResponse { users })?;
2518 Ok(())
2519}
2520
/// Search for users (to invite) by GitHub login.
2522async fn fuzzy_search_users(
2523 request: proto::FuzzySearchUsers,
2524 response: Response<proto::FuzzySearchUsers>,
2525 session: Session,
2526) -> Result<()> {
2527 let query = request.query;
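    // Empty queries return nothing, one- and two-character queries do an exact GitHub login lookup,
    // and longer queries use fuzzy search capped at 10 results.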
2528 let users = match query.len() {
2529 0 => vec![],
2530 1 | 2 => session
2531 .db()
2532 .await
2533 .get_user_by_github_login(&query)
2534 .await?
2535 .into_iter()
2536 .collect(),
2537 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2538 };
2539 let users = users
2540 .into_iter()
2541 .filter(|user| user.id != session.user_id())
2542 .map(|user| proto::User {
2543 id: user.id.to_proto(),
2544 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2545 github_login: user.github_login,
2546 name: user.name,
2547 email: user.email_address,
2548 })
2549 .collect();
2550 response.send(proto::UsersResponse { users })?;
2551 Ok(())
2552}
2553
2554/// Send a contact request to another user.
2555async fn request_contact(
2556 request: proto::RequestContact,
2557 response: Response<proto::RequestContact>,
2558 session: Session,
2559) -> Result<()> {
2560 let requester_id = session.user_id();
2561 let responder_id = UserId::from_proto(request.responder_id);
2562 if requester_id == responder_id {
2563 return Err(anyhow!("cannot add yourself as a contact"))?;
2564 }
2565
2566 let notifications = session
2567 .db()
2568 .await
2569 .send_contact_request(requester_id, responder_id)
2570 .await?;
2571
2572 // Update outgoing contact requests of requester
2573 let mut update = proto::UpdateContacts::default();
2574 update.outgoing_requests.push(responder_id.to_proto());
2575 for connection_id in session
2576 .connection_pool()
2577 .await
2578 .user_connection_ids(requester_id)
2579 {
2580 session.peer.send(connection_id, update.clone())?;
2581 }
2582
2583 // Update incoming contact requests of responder
2584 let mut update = proto::UpdateContacts::default();
2585 update
2586 .incoming_requests
2587 .push(proto::IncomingContactRequest {
2588 requester_id: requester_id.to_proto(),
2589 });
2590 let connection_pool = session.connection_pool().await;
2591 for connection_id in connection_pool.user_connection_ids(responder_id) {
2592 session.peer.send(connection_id, update.clone())?;
2593 }
2594
2595 send_notifications(&connection_pool, &session.peer, notifications);
2596
2597 response.send(proto::Ack {})?;
2598 Ok(())
2599}
2600
2601/// Accept or decline a contact request
2602async fn respond_to_contact_request(
2603 request: proto::RespondToContactRequest,
2604 response: Response<proto::RespondToContactRequest>,
2605 session: Session,
2606) -> Result<()> {
2607 let responder_id = session.user_id();
2608 let requester_id = UserId::from_proto(request.requester_id);
2609 let db = session.db().await;
2610 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2611 db.dismiss_contact_notification(responder_id, requester_id)
2612 .await?;
2613 } else {
2614 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2615
2616 let notifications = db
2617 .respond_to_contact_request(responder_id, requester_id, accept)
2618 .await?;
2619 let requester_busy = db.is_user_busy(requester_id).await?;
2620 let responder_busy = db.is_user_busy(responder_id).await?;
2621
2622 let pool = session.connection_pool().await;
2623 // Update responder with new contact
2624 let mut update = proto::UpdateContacts::default();
2625 if accept {
2626 update
2627 .contacts
2628 .push(contact_for_user(requester_id, requester_busy, &pool));
2629 }
2630 update
2631 .remove_incoming_requests
2632 .push(requester_id.to_proto());
2633 for connection_id in pool.user_connection_ids(responder_id) {
2634 session.peer.send(connection_id, update.clone())?;
2635 }
2636
2637 // Update requester with new contact
2638 let mut update = proto::UpdateContacts::default();
2639 if accept {
2640 update
2641 .contacts
2642 .push(contact_for_user(responder_id, responder_busy, &pool));
2643 }
2644 update
2645 .remove_outgoing_requests
2646 .push(responder_id.to_proto());
2647
2648 for connection_id in pool.user_connection_ids(requester_id) {
2649 session.peer.send(connection_id, update.clone())?;
2650 }
2651
2652 send_notifications(&pool, &session.peer, notifications);
2653 }
2654
2655 response.send(proto::Ack {})?;
2656 Ok(())
2657}
2658
2659/// Remove a contact.
2660async fn remove_contact(
2661 request: proto::RemoveContact,
2662 response: Response<proto::RemoveContact>,
2663 session: Session,
2664) -> Result<()> {
2665 let requester_id = session.user_id();
2666 let responder_id = UserId::from_proto(request.user_id);
2667 let db = session.db().await;
2668 let (contact_accepted, deleted_notification_id) =
2669 db.remove_contact(requester_id, responder_id).await?;
2670
2671 let pool = session.connection_pool().await;
2672 // Update outgoing contact requests of requester
2673 let mut update = proto::UpdateContacts::default();
2674 if contact_accepted {
2675 update.remove_contacts.push(responder_id.to_proto());
2676 } else {
2677 update
2678 .remove_outgoing_requests
2679 .push(responder_id.to_proto());
2680 }
2681 for connection_id in pool.user_connection_ids(requester_id) {
2682 session.peer.send(connection_id, update.clone())?;
2683 }
2684
2685 // Update incoming contact requests of responder
2686 let mut update = proto::UpdateContacts::default();
2687 if contact_accepted {
2688 update.remove_contacts.push(requester_id.to_proto());
2689 } else {
2690 update
2691 .remove_incoming_requests
2692 .push(requester_id.to_proto());
2693 }
2694 for connection_id in pool.user_connection_ids(responder_id) {
2695 session.peer.send(connection_id, update.clone())?;
2696 if let Some(notification_id) = deleted_notification_id {
2697 session.peer.send(
2698 connection_id,
2699 proto::DeleteNotification {
2700 notification_id: notification_id.to_proto(),
2701 },
2702 )?;
2703 }
2704 }
2705
2706 response.send(proto::Ack {})?;
2707 Ok(())
2708}
2709
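/// Older clients (before 0.139) expect the server to subscribe them to channels automatically on
/// connect; newer clients send an explicit `SubscribeToChannels` request instead.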
2710fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2711 version.0.minor() < 139
2712}
2713
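/// Sends the user's current plan to this connection.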
2714async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
2715 let plan = session.current_plan(&session.db().await).await?;
2716
2717 session
2718 .peer
2719 .send(
2720 session.connection_id,
2721 proto::UpdateUserPlan { plan: plan.into() },
2722 )
2723 .trace_err();
2724
2725 Ok(())
2726}
2727
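/// Handles an explicit request from the client to start receiving channel updates.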
2728async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2729 subscribe_user_to_channels(session.user_id(), &session).await?;
2730 Ok(())
2731}
2732
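/// Subscribes the user to updates for each of their channel memberships and sends the initial
/// channel state to this connection.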
2733async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2734 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2735 let mut pool = session.connection_pool().await;
2736 for membership in &channels_for_user.channel_memberships {
2737 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2738 }
2739 session.peer.send(
2740 session.connection_id,
2741 build_update_user_channels(&channels_for_user),
2742 )?;
2743 session.peer.send(
2744 session.connection_id,
2745 build_channels_update(channels_for_user),
2746 )?;
2747 Ok(())
2748}
2749
2750/// Creates a new channel.
2751async fn create_channel(
2752 request: proto::CreateChannel,
2753 response: Response<proto::CreateChannel>,
2754 session: Session,
2755) -> Result<()> {
2756 let db = session.db().await;
2757
2758 let parent_id = request.parent_id.map(ChannelId::from_proto);
2759 let (channel, membership) = db
2760 .create_channel(&request.name, parent_id, session.user_id())
2761 .await?;
2762
2763 let root_id = channel.root_id();
2764 let channel = Channel::from_model(channel);
2765
2766 response.send(proto::CreateChannelResponse {
2767 channel: Some(channel.to_proto()),
2768 parent_id: request.parent_id,
2769 })?;
2770
2771 let mut connection_pool = session.connection_pool().await;
2772 if let Some(membership) = membership {
2773 connection_pool.subscribe_to_channel(
2774 membership.user_id,
2775 membership.channel_id,
2776 membership.role,
2777 );
2778 let update = proto::UpdateUserChannels {
2779 channel_memberships: vec![proto::ChannelMembership {
2780 channel_id: membership.channel_id.to_proto(),
2781 role: membership.role.into(),
2782 }],
2783 ..Default::default()
2784 };
2785 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2786 session.peer.send(connection_id, update.clone())?;
2787 }
2788 }
2789
2790 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2791 if !role.can_see_channel(channel.visibility) {
2792 continue;
2793 }
2794
2795 let update = proto::UpdateChannels {
2796 channels: vec![channel.to_proto()],
2797 ..Default::default()
2798 };
2799 session.peer.send(connection_id, update.clone())?;
2800 }
2801
2802 Ok(())
2803}
2804
2805/// Delete a channel
2806async fn delete_channel(
2807 request: proto::DeleteChannel,
2808 response: Response<proto::DeleteChannel>,
2809 session: Session,
2810) -> Result<()> {
2811 let db = session.db().await;
2812
2813 let channel_id = request.channel_id;
2814 let (root_channel, removed_channels) = db
2815 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2816 .await?;
2817 response.send(proto::Ack {})?;
2818
2819 // Notify members of removed channels
2820 let mut update = proto::UpdateChannels::default();
2821 update
2822 .delete_channels
2823 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2824
2825 let connection_pool = session.connection_pool().await;
2826 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2827 session.peer.send(connection_id, update.clone())?;
2828 }
2829
2830 Ok(())
2831}
2832
2833/// Invite someone to join a channel.
2834async fn invite_channel_member(
2835 request: proto::InviteChannelMember,
2836 response: Response<proto::InviteChannelMember>,
2837 session: Session,
2838) -> Result<()> {
2839 let db = session.db().await;
2840 let channel_id = ChannelId::from_proto(request.channel_id);
2841 let invitee_id = UserId::from_proto(request.user_id);
2842 let InviteMemberResult {
2843 channel,
2844 notifications,
2845 } = db
2846 .invite_channel_member(
2847 channel_id,
2848 invitee_id,
2849 session.user_id(),
2850 request.role().into(),
2851 )
2852 .await?;
2853
2854 let update = proto::UpdateChannels {
2855 channel_invitations: vec![channel.to_proto()],
2856 ..Default::default()
2857 };
2858
2859 let connection_pool = session.connection_pool().await;
2860 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2861 session.peer.send(connection_id, update.clone())?;
2862 }
2863
2864 send_notifications(&connection_pool, &session.peer, notifications);
2865
2866 response.send(proto::Ack {})?;
2867 Ok(())
2868}
2869
/// Remove someone from a channel.
2871async fn remove_channel_member(
2872 request: proto::RemoveChannelMember,
2873 response: Response<proto::RemoveChannelMember>,
2874 session: Session,
2875) -> Result<()> {
2876 let db = session.db().await;
2877 let channel_id = ChannelId::from_proto(request.channel_id);
2878 let member_id = UserId::from_proto(request.user_id);
2879
2880 let RemoveChannelMemberResult {
2881 membership_update,
2882 notification_id,
2883 } = db
2884 .remove_channel_member(channel_id, member_id, session.user_id())
2885 .await?;
2886
2887 let mut connection_pool = session.connection_pool().await;
2888 notify_membership_updated(
2889 &mut connection_pool,
2890 membership_update,
2891 member_id,
2892 &session.peer,
2893 );
2894 for connection_id in connection_pool.user_connection_ids(member_id) {
2895 if let Some(notification_id) = notification_id {
2896 session
2897 .peer
2898 .send(
2899 connection_id,
2900 proto::DeleteNotification {
2901 notification_id: notification_id.to_proto(),
2902 },
2903 )
2904 .trace_err();
2905 }
2906 }
2907
2908 response.send(proto::Ack {})?;
2909 Ok(())
2910}
2911
2912/// Toggle the channel between public and private.
/// Care is taken to maintain the invariant that public channels only descend from public channels
/// (though members-only channels can appear at any point in the hierarchy).
2915async fn set_channel_visibility(
2916 request: proto::SetChannelVisibility,
2917 response: Response<proto::SetChannelVisibility>,
2918 session: Session,
2919) -> Result<()> {
2920 let db = session.db().await;
2921 let channel_id = ChannelId::from_proto(request.channel_id);
2922 let visibility = request.visibility().into();
2923
2924 let channel_model = db
2925 .set_channel_visibility(channel_id, visibility, session.user_id())
2926 .await?;
2927 let root_id = channel_model.root_id();
2928 let channel = Channel::from_model(channel_model);
2929
2930 let mut connection_pool = session.connection_pool().await;
2931 for (user_id, role) in connection_pool
2932 .channel_user_ids(root_id)
2933 .collect::<Vec<_>>()
2934 .into_iter()
2935 {
2936 let update = if role.can_see_channel(channel.visibility) {
2937 connection_pool.subscribe_to_channel(user_id, channel_id, role);
2938 proto::UpdateChannels {
2939 channels: vec![channel.to_proto()],
2940 ..Default::default()
2941 }
2942 } else {
2943 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2944 proto::UpdateChannels {
2945 delete_channels: vec![channel.id.to_proto()],
2946 ..Default::default()
2947 }
2948 };
2949
2950 for connection_id in connection_pool.user_connection_ids(user_id) {
2951 session.peer.send(connection_id, update.clone())?;
2952 }
2953 }
2954
2955 response.send(proto::Ack {})?;
2956 Ok(())
2957}
2958
2959/// Alter the role for a user in the channel.
2960async fn set_channel_member_role(
2961 request: proto::SetChannelMemberRole,
2962 response: Response<proto::SetChannelMemberRole>,
2963 session: Session,
2964) -> Result<()> {
2965 let db = session.db().await;
2966 let channel_id = ChannelId::from_proto(request.channel_id);
2967 let member_id = UserId::from_proto(request.user_id);
2968 let result = db
2969 .set_channel_member_role(
2970 channel_id,
2971 session.user_id(),
2972 member_id,
2973 request.role().into(),
2974 )
2975 .await?;
2976
2977 match result {
2978 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2979 let mut connection_pool = session.connection_pool().await;
2980 notify_membership_updated(
2981 &mut connection_pool,
2982 membership_update,
2983 member_id,
2984 &session.peer,
2985 )
2986 }
2987 db::SetMemberRoleResult::InviteUpdated(channel) => {
2988 let update = proto::UpdateChannels {
2989 channel_invitations: vec![channel.to_proto()],
2990 ..Default::default()
2991 };
2992
2993 for connection_id in session
2994 .connection_pool()
2995 .await
2996 .user_connection_ids(member_id)
2997 {
2998 session.peer.send(connection_id, update.clone())?;
2999 }
3000 }
3001 }
3002
3003 response.send(proto::Ack {})?;
3004 Ok(())
3005}
3006
3007/// Change the name of a channel
3008async fn rename_channel(
3009 request: proto::RenameChannel,
3010 response: Response<proto::RenameChannel>,
3011 session: Session,
3012) -> Result<()> {
3013 let db = session.db().await;
3014 let channel_id = ChannelId::from_proto(request.channel_id);
3015 let channel_model = db
3016 .rename_channel(channel_id, session.user_id(), &request.name)
3017 .await?;
3018 let root_id = channel_model.root_id();
3019 let channel = Channel::from_model(channel_model);
3020
3021 response.send(proto::RenameChannelResponse {
3022 channel: Some(channel.to_proto()),
3023 })?;
3024
3025 let connection_pool = session.connection_pool().await;
3026 let update = proto::UpdateChannels {
3027 channels: vec![channel.to_proto()],
3028 ..Default::default()
3029 };
3030 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3031 if role.can_see_channel(channel.visibility) {
3032 session.peer.send(connection_id, update.clone())?;
3033 }
3034 }
3035
3036 Ok(())
3037}
3038
3039/// Move a channel to a new parent.
3040async fn move_channel(
3041 request: proto::MoveChannel,
3042 response: Response<proto::MoveChannel>,
3043 session: Session,
3044) -> Result<()> {
3045 let channel_id = ChannelId::from_proto(request.channel_id);
3046 let to = ChannelId::from_proto(request.to);
3047
3048 let (root_id, channels) = session
3049 .db()
3050 .await
3051 .move_channel(channel_id, to, session.user_id())
3052 .await?;
3053
3054 let connection_pool = session.connection_pool().await;
3055 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3056 let channels = channels
3057 .iter()
3058 .filter_map(|channel| {
3059 if role.can_see_channel(channel.visibility) {
3060 Some(channel.to_proto())
3061 } else {
3062 None
3063 }
3064 })
3065 .collect::<Vec<_>>();
3066 if channels.is_empty() {
3067 continue;
3068 }
3069
3070 let update = proto::UpdateChannels {
3071 channels,
3072 ..Default::default()
3073 };
3074
3075 session.peer.send(connection_id, update.clone())?;
3076 }
3077
3078 response.send(Ack {})?;
3079 Ok(())
3080}
3081
3082/// Get the list of channel members
3083async fn get_channel_members(
3084 request: proto::GetChannelMembers,
3085 response: Response<proto::GetChannelMembers>,
3086 session: Session,
3087) -> Result<()> {
3088 let db = session.db().await;
3089 let channel_id = ChannelId::from_proto(request.channel_id);
3090 let limit = if request.limit == 0 {
3091 u16::MAX as u64
3092 } else {
3093 request.limit
3094 };
3095 let (members, users) = db
3096 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3097 .await?;
3098 response.send(proto::GetChannelMembersResponse { members, users })?;
3099 Ok(())
3100}
3101
3102/// Accept or decline a channel invitation.
3103async fn respond_to_channel_invite(
3104 request: proto::RespondToChannelInvite,
3105 response: Response<proto::RespondToChannelInvite>,
3106 session: Session,
3107) -> Result<()> {
3108 let db = session.db().await;
3109 let channel_id = ChannelId::from_proto(request.channel_id);
3110 let RespondToChannelInvite {
3111 membership_update,
3112 notifications,
3113 } = db
3114 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3115 .await?;
3116
3117 let mut connection_pool = session.connection_pool().await;
3118 if let Some(membership_update) = membership_update {
3119 notify_membership_updated(
3120 &mut connection_pool,
3121 membership_update,
3122 session.user_id(),
3123 &session.peer,
3124 );
3125 } else {
3126 let update = proto::UpdateChannels {
3127 remove_channel_invitations: vec![channel_id.to_proto()],
3128 ..Default::default()
3129 };
3130
3131 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3132 session.peer.send(connection_id, update.clone())?;
3133 }
3134 };
3135
3136 send_notifications(&connection_pool, &session.peer, notifications);
3137
3138 response.send(proto::Ack {})?;
3139
3140 Ok(())
3141}
3142
/// Join the channel's room
3144async fn join_channel(
3145 request: proto::JoinChannel,
3146 response: Response<proto::JoinChannel>,
3147 session: Session,
3148) -> Result<()> {
3149 let channel_id = ChannelId::from_proto(request.channel_id);
3150 join_channel_internal(channel_id, Box::new(response), session).await
3151}
3152
3153trait JoinChannelInternalResponse {
3154 fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
3155}
3156impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
3157 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3158 Response::<proto::JoinChannel>::send(self, result)
3159 }
3160}
3161impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
3162 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3163 Response::<proto::JoinRoom>::send(self, result)
3164 }
3165}
3166
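/// Joins the room associated with a channel, first cleaning up any stale room connection left over
/// from a previous session, and minting a LiveKit token when a LiveKit client is configured
/// (guests receive a subscribe-only token).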
3167async fn join_channel_internal(
3168 channel_id: ChannelId,
3169 response: Box<impl JoinChannelInternalResponse>,
3170 session: Session,
3171) -> Result<()> {
3172 let joined_room = {
3173 let mut db = session.db().await;
3174 // If zed quits without leaving the room, and the user re-opens zed before the
3175 // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
3176 // room they were in.
3177 if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
3178 tracing::info!(
3179 stale_connection_id = %connection,
3180 "cleaning up stale connection",
3181 );
3182 drop(db);
3183 leave_room_for_session(&session, connection).await?;
3184 db = session.db().await;
3185 }
3186
3187 let (joined_room, membership_updated, role) = db
3188 .join_channel(channel_id, session.user_id(), session.connection_id)
3189 .await?;
3190
3191 let live_kit_connection_info =
3192 session
3193 .app_state
3194 .livekit_client
3195 .as_ref()
3196 .and_then(|live_kit| {
3197 let (can_publish, token) = if role == ChannelRole::Guest {
3198 (
3199 false,
3200 live_kit
3201 .guest_token(
3202 &joined_room.room.livekit_room,
3203 &session.user_id().to_string(),
3204 )
3205 .trace_err()?,
3206 )
3207 } else {
3208 (
3209 true,
3210 live_kit
3211 .room_token(
3212 &joined_room.room.livekit_room,
3213 &session.user_id().to_string(),
3214 )
3215 .trace_err()?,
3216 )
3217 };
3218
3219 Some(LiveKitConnectionInfo {
3220 server_url: live_kit.url().into(),
3221 token,
3222 can_publish,
3223 })
3224 });
3225
3226 response.send(proto::JoinRoomResponse {
3227 room: Some(joined_room.room.clone()),
3228 channel_id: joined_room
3229 .channel
3230 .as_ref()
3231 .map(|channel| channel.id.to_proto()),
3232 live_kit_connection_info,
3233 })?;
3234
3235 let mut connection_pool = session.connection_pool().await;
3236 if let Some(membership_updated) = membership_updated {
3237 notify_membership_updated(
3238 &mut connection_pool,
3239 membership_updated,
3240 session.user_id(),
3241 &session.peer,
3242 );
3243 }
3244
3245 room_updated(&joined_room.room, &session.peer);
3246
3247 joined_room
3248 };
3249
3250 channel_updated(
3251 &joined_room
3252 .channel
3253 .ok_or_else(|| anyhow!("channel not returned"))?,
3254 &joined_room.room,
3255 &session.peer,
3256 &*session.connection_pool().await,
3257 );
3258
3259 update_user_contacts(session.user_id(), &session).await?;
3260 Ok(())
3261}
3262
3263/// Start editing the channel notes
3264async fn join_channel_buffer(
3265 request: proto::JoinChannelBuffer,
3266 response: Response<proto::JoinChannelBuffer>,
3267 session: Session,
3268) -> Result<()> {
3269 let db = session.db().await;
3270 let channel_id = ChannelId::from_proto(request.channel_id);
3271
3272 let open_response = db
3273 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3274 .await?;
3275
3276 let collaborators = open_response.collaborators.clone();
3277 response.send(open_response)?;
3278
3279 let update = UpdateChannelBufferCollaborators {
3280 channel_id: channel_id.to_proto(),
3281 collaborators: collaborators.clone(),
3282 };
3283 channel_buffer_updated(
3284 session.connection_id,
3285 collaborators
3286 .iter()
3287 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3288 &update,
3289 &session.peer,
3290 );
3291
3292 Ok(())
3293}
3294
3295/// Edit the channel notes
3296async fn update_channel_buffer(
3297 request: proto::UpdateChannelBuffer,
3298 session: Session,
3299) -> Result<()> {
3300 let db = session.db().await;
3301 let channel_id = ChannelId::from_proto(request.channel_id);
3302
3303 let (collaborators, epoch, version) = db
3304 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3305 .await?;
3306
3307 channel_buffer_updated(
3308 session.connection_id,
3309 collaborators.clone(),
3310 &proto::UpdateChannelBuffer {
3311 channel_id: channel_id.to_proto(),
3312 operations: request.operations,
3313 },
3314 &session.peer,
3315 );
3316
3317 let pool = &*session.connection_pool().await;
3318
3319 let non_collaborators =
3320 pool.channel_connection_ids(channel_id)
3321 .filter_map(|(connection_id, _)| {
3322 if collaborators.contains(&connection_id) {
3323 None
3324 } else {
3325 Some(connection_id)
3326 }
3327 });
3328
3329 broadcast(None, non_collaborators, |peer_id| {
3330 session.peer.send(
3331 peer_id,
3332 proto::UpdateChannels {
3333 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3334 channel_id: channel_id.to_proto(),
3335 epoch: epoch as u64,
3336 version: version.clone(),
3337 }],
3338 ..Default::default()
3339 },
3340 )
3341 });
3342
3343 Ok(())
3344}
3345
3346/// Rejoin the channel notes after a connection blip
3347async fn rejoin_channel_buffers(
3348 request: proto::RejoinChannelBuffers,
3349 response: Response<proto::RejoinChannelBuffers>,
3350 session: Session,
3351) -> Result<()> {
3352 let db = session.db().await;
3353 let buffers = db
3354 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3355 .await?;
3356
3357 for rejoined_buffer in &buffers {
3358 let collaborators_to_notify = rejoined_buffer
3359 .buffer
3360 .collaborators
3361 .iter()
3362 .filter_map(|c| Some(c.peer_id?.into()));
3363 channel_buffer_updated(
3364 session.connection_id,
3365 collaborators_to_notify,
3366 &proto::UpdateChannelBufferCollaborators {
3367 channel_id: rejoined_buffer.buffer.channel_id,
3368 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3369 },
3370 &session.peer,
3371 );
3372 }
3373
3374 response.send(proto::RejoinChannelBuffersResponse {
3375 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3376 })?;
3377
3378 Ok(())
3379}
3380
3381/// Stop editing the channel notes
3382async fn leave_channel_buffer(
3383 request: proto::LeaveChannelBuffer,
3384 response: Response<proto::LeaveChannelBuffer>,
3385 session: Session,
3386) -> Result<()> {
3387 let db = session.db().await;
3388 let channel_id = ChannelId::from_proto(request.channel_id);
3389
3390 let left_buffer = db
3391 .leave_channel_buffer(channel_id, session.connection_id)
3392 .await?;
3393
3394 response.send(Ack {})?;
3395
3396 channel_buffer_updated(
3397 session.connection_id,
3398 left_buffer.connections,
3399 &proto::UpdateChannelBufferCollaborators {
3400 channel_id: channel_id.to_proto(),
3401 collaborators: left_buffer.collaborators,
3402 },
3403 &session.peer,
3404 );
3405
3406 Ok(())
3407}
3408
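/// Broadcasts a channel buffer message to the given collaborators, excluding the sender.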
3409fn channel_buffer_updated<T: EnvelopedMessage>(
3410 sender_id: ConnectionId,
3411 collaborators: impl IntoIterator<Item = ConnectionId>,
3412 message: &T,
3413 peer: &Peer,
3414) {
3415 broadcast(Some(sender_id), collaborators, |peer_id| {
3416 peer.send(peer_id, message.clone())
3417 });
3418}
3419
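/// Delivers a batch of newly created notifications to every connection of each recipient,
/// logging any failures rather than propagating them.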
3420fn send_notifications(
3421 connection_pool: &ConnectionPool,
3422 peer: &Peer,
3423 notifications: db::NotificationBatch,
3424) {
3425 for (user_id, notification) in notifications {
3426 for connection_id in connection_pool.user_connection_ids(user_id) {
3427 if let Err(error) = peer.send(
3428 connection_id,
3429 proto::AddNotification {
3430 notification: Some(notification.clone()),
3431 },
3432 ) {
3433 tracing::error!(
3434 "failed to send notification to {:?} {}",
3435 connection_id,
3436 error
3437 );
3438 }
3439 }
3440 }
3441}
3442
3443/// Send a message to the channel
3444async fn send_channel_message(
3445 request: proto::SendChannelMessage,
3446 response: Response<proto::SendChannelMessage>,
3447 session: Session,
3448) -> Result<()> {
3449 // Validate the message body.
3450 let body = request.body.trim().to_string();
3451 if body.len() > MAX_MESSAGE_LEN {
3452 return Err(anyhow!("message is too long"))?;
3453 }
3454 if body.is_empty() {
3455 return Err(anyhow!("message can't be blank"))?;
3456 }
3457
3458 // TODO: adjust mentions if body is trimmed
3459
3460 let timestamp = OffsetDateTime::now_utc();
3461 let nonce = request
3462 .nonce
3463 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3464
3465 let channel_id = ChannelId::from_proto(request.channel_id);
3466 let CreatedChannelMessage {
3467 message_id,
3468 participant_connection_ids,
3469 notifications,
3470 } = session
3471 .db()
3472 .await
3473 .create_channel_message(
3474 channel_id,
3475 session.user_id(),
3476 &body,
3477 &request.mentions,
3478 timestamp,
3479 nonce.clone().into(),
3480 request.reply_to_message_id.map(MessageId::from_proto),
3481 )
3482 .await?;
3483
3484 let message = proto::ChannelMessage {
3485 sender_id: session.user_id().to_proto(),
3486 id: message_id.to_proto(),
3487 body,
3488 mentions: request.mentions,
3489 timestamp: timestamp.unix_timestamp() as u64,
3490 nonce: Some(nonce),
3491 reply_to_message_id: request.reply_to_message_id,
3492 edited_at: None,
3493 };
3494 broadcast(
3495 Some(session.connection_id),
3496 participant_connection_ids.clone(),
3497 |connection| {
3498 session.peer.send(
3499 connection,
3500 proto::ChannelMessageSent {
3501 channel_id: channel_id.to_proto(),
3502 message: Some(message.clone()),
3503 },
3504 )
3505 },
3506 );
3507 response.send(proto::SendChannelMessageResponse {
3508 message: Some(message),
3509 })?;
3510
3511 let pool = &*session.connection_pool().await;
3512 let non_participants =
3513 pool.channel_connection_ids(channel_id)
3514 .filter_map(|(connection_id, _)| {
3515 if participant_connection_ids.contains(&connection_id) {
3516 None
3517 } else {
3518 Some(connection_id)
3519 }
3520 });
3521 broadcast(None, non_participants, |peer_id| {
3522 session.peer.send(
3523 peer_id,
3524 proto::UpdateChannels {
3525 latest_channel_message_ids: vec![proto::ChannelMessageId {
3526 channel_id: channel_id.to_proto(),
3527 message_id: message_id.to_proto(),
3528 }],
3529 ..Default::default()
3530 },
3531 )
3532 });
3533 send_notifications(pool, &session.peer, notifications);
3534
3535 Ok(())
3536}
3537
3538/// Delete a channel message
3539async fn remove_channel_message(
3540 request: proto::RemoveChannelMessage,
3541 response: Response<proto::RemoveChannelMessage>,
3542 session: Session,
3543) -> Result<()> {
3544 let channel_id = ChannelId::from_proto(request.channel_id);
3545 let message_id = MessageId::from_proto(request.message_id);
3546 let (connection_ids, existing_notification_ids) = session
3547 .db()
3548 .await
3549 .remove_channel_message(channel_id, message_id, session.user_id())
3550 .await?;
3551
3552 broadcast(
3553 Some(session.connection_id),
3554 connection_ids,
3555 move |connection| {
3556 session.peer.send(connection, request.clone())?;
3557
3558 for notification_id in &existing_notification_ids {
3559 session.peer.send(
3560 connection,
3561 proto::DeleteNotification {
3562 notification_id: (*notification_id).to_proto(),
3563 },
3564 )?;
3565 }
3566
3567 Ok(())
3568 },
3569 );
3570 response.send(proto::Ack {})?;
3571 Ok(())
3572}
3573
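/// Edit a previously sent channel message, updating any mention notifications it produced.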
3574async fn update_channel_message(
3575 request: proto::UpdateChannelMessage,
3576 response: Response<proto::UpdateChannelMessage>,
3577 session: Session,
3578) -> Result<()> {
3579 let channel_id = ChannelId::from_proto(request.channel_id);
3580 let message_id = MessageId::from_proto(request.message_id);
3581 let updated_at = OffsetDateTime::now_utc();
3582 let UpdatedChannelMessage {
3583 message_id,
3584 participant_connection_ids,
3585 notifications,
3586 reply_to_message_id,
3587 timestamp,
3588 deleted_mention_notification_ids,
3589 updated_mention_notifications,
3590 } = session
3591 .db()
3592 .await
3593 .update_channel_message(
3594 channel_id,
3595 message_id,
3596 session.user_id(),
3597 request.body.as_str(),
3598 &request.mentions,
3599 updated_at,
3600 )
3601 .await?;
3602
3603 let nonce = request
3604 .nonce
3605 .clone()
3606 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3607
3608 let message = proto::ChannelMessage {
3609 sender_id: session.user_id().to_proto(),
3610 id: message_id.to_proto(),
3611 body: request.body.clone(),
3612 mentions: request.mentions.clone(),
3613 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3614 nonce: Some(nonce),
3615 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3616 edited_at: Some(updated_at.unix_timestamp() as u64),
3617 };
3618
3619 response.send(proto::Ack {})?;
3620
3621 let pool = &*session.connection_pool().await;
3622 broadcast(
3623 Some(session.connection_id),
3624 participant_connection_ids,
3625 |connection| {
3626 session.peer.send(
3627 connection,
3628 proto::ChannelMessageUpdate {
3629 channel_id: channel_id.to_proto(),
3630 message: Some(message.clone()),
3631 },
3632 )?;
3633
3634 for notification_id in &deleted_mention_notification_ids {
3635 session.peer.send(
3636 connection,
3637 proto::DeleteNotification {
3638 notification_id: (*notification_id).to_proto(),
3639 },
3640 )?;
3641 }
3642
3643 for notification in &updated_mention_notifications {
3644 session.peer.send(
3645 connection,
3646 proto::UpdateNotification {
3647 notification: Some(notification.clone()),
3648 },
3649 )?;
3650 }
3651
3652 Ok(())
3653 },
3654 );
3655
3656 send_notifications(pool, &session.peer, notifications);
3657
3658 Ok(())
3659}
3660
3661/// Mark a channel message as read
3662async fn acknowledge_channel_message(
3663 request: proto::AckChannelMessage,
3664 session: Session,
3665) -> Result<()> {
3666 let channel_id = ChannelId::from_proto(request.channel_id);
3667 let message_id = MessageId::from_proto(request.message_id);
3668 let notifications = session
3669 .db()
3670 .await
3671 .observe_channel_message(channel_id, session.user_id(), message_id)
3672 .await?;
3673 send_notifications(
3674 &*session.connection_pool().await,
3675 &session.peer,
3676 notifications,
3677 );
3678 Ok(())
3679}
3680
3681/// Mark a buffer version as synced
3682async fn acknowledge_buffer_version(
3683 request: proto::AckBufferOperation,
3684 session: Session,
3685) -> Result<()> {
3686 let buffer_id = BufferId::from_proto(request.buffer_id);
3687 session
3688 .db()
3689 .await
3690 .observe_buffer_version(
3691 buffer_id,
3692 session.user_id(),
3693 request.epoch as i32,
3694 &request.version,
3695 )
3696 .await?;
3697 Ok(())
3698}
3699
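/// Counts the tokens in a language model request on behalf of the client. This legacy endpoint is
/// staff-only, rate-limited per plan, and currently supports only Google models.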
3700async fn count_language_model_tokens(
3701 request: proto::CountLanguageModelTokens,
3702 response: Response<proto::CountLanguageModelTokens>,
3703 session: Session,
3704 config: &Config,
3705) -> Result<()> {
3706 authorize_access_to_legacy_llm_endpoints(&session).await?;
3707
3708 let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3709 proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
3710 proto::Plan::Free | proto::Plan::ZedProTrial => {
3711 Box::new(FreeCountLanguageModelTokensRateLimit)
3712 }
3713 };
3714
3715 session
3716 .app_state
3717 .rate_limiter
3718 .check(&*rate_limit, session.user_id())
3719 .await?;
3720
3721 let result = match proto::LanguageModelProvider::from_i32(request.provider) {
3722 Some(proto::LanguageModelProvider::Google) => {
3723 let api_key = config
3724 .google_ai_api_key
3725 .as_ref()
3726 .context("no Google AI API key configured on the server")?;
3727 google_ai::count_tokens(
3728 session.http_client.as_ref(),
3729 google_ai::API_URL,
3730 api_key,
3731 serde_json::from_str(&request.request)?,
3732 )
3733 .await?
3734 }
3735 _ => return Err(anyhow!("unsupported provider"))?,
3736 };
3737
3738 response.send(proto::CountLanguageModelTokensResponse {
3739 token_count: result.total_tokens as u32,
3740 })?;
3741
3742 Ok(())
3743}
3744
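// Rate limits for the legacy LLM endpoints below. Capacities can be overridden via the named
// environment variables; the defaults are arbitrary.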
3745struct ZedProCountLanguageModelTokensRateLimit;
3746
3747impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
3748 fn capacity(&self) -> usize {
3749 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
3750 .ok()
3751 .and_then(|v| v.parse().ok())
3752 .unwrap_or(600) // Picked arbitrarily
3753 }
3754
3755 fn refill_duration(&self) -> chrono::Duration {
3756 chrono::Duration::hours(1)
3757 }
3758
3759 fn db_name(&self) -> &'static str {
3760 "zed-pro:count-language-model-tokens"
3761 }
3762}
3763
3764struct FreeCountLanguageModelTokensRateLimit;
3765
3766impl RateLimit for FreeCountLanguageModelTokensRateLimit {
3767 fn capacity(&self) -> usize {
3768 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
3769 .ok()
3770 .and_then(|v| v.parse().ok())
3771 .unwrap_or(600 / 10) // Picked arbitrarily
3772 }
3773
3774 fn refill_duration(&self) -> chrono::Duration {
3775 chrono::Duration::hours(1)
3776 }
3777
3778 fn db_name(&self) -> &'static str {
3779 "free:count-language-model-tokens"
3780 }
3781}
3782
3783struct ZedProComputeEmbeddingsRateLimit;
3784
3785impl RateLimit for ZedProComputeEmbeddingsRateLimit {
3786 fn capacity(&self) -> usize {
3787 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
3788 .ok()
3789 .and_then(|v| v.parse().ok())
3790 .unwrap_or(5000) // Picked arbitrarily
3791 }
3792
3793 fn refill_duration(&self) -> chrono::Duration {
3794 chrono::Duration::hours(1)
3795 }
3796
3797 fn db_name(&self) -> &'static str {
3798 "zed-pro:compute-embeddings"
3799 }
3800}
3801
3802struct FreeComputeEmbeddingsRateLimit;
3803
3804impl RateLimit for FreeComputeEmbeddingsRateLimit {
3805 fn capacity(&self) -> usize {
3806 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
3807 .ok()
3808 .and_then(|v| v.parse().ok())
3809 .unwrap_or(5000 / 10) // Picked arbitrarily
3810 }
3811
3812 fn refill_duration(&self) -> chrono::Duration {
3813 chrono::Duration::hours(1)
3814 }
3815
3816 fn db_name(&self) -> &'static str {
3817 "free:compute-embeddings"
3818 }
3819}
3820
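/// Computes embeddings for the requested texts via OpenAI (staff-only, rate-limited per plan),
/// caches them in the database keyed by each text's SHA-256 digest, and returns the digests with
/// their embeddings.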
3821async fn compute_embeddings(
3822 request: proto::ComputeEmbeddings,
3823 response: Response<proto::ComputeEmbeddings>,
3824 session: Session,
3825 api_key: Option<Arc<str>>,
3826) -> Result<()> {
3827 let api_key = api_key.context("no OpenAI API key configured on the server")?;
3828 authorize_access_to_legacy_llm_endpoints(&session).await?;
3829
3830 let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3831 proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
3832 proto::Plan::Free | proto::Plan::ZedProTrial => Box::new(FreeComputeEmbeddingsRateLimit),
3833 };
3834
3835 session
3836 .app_state
3837 .rate_limiter
3838 .check(&*rate_limit, session.user_id())
3839 .await?;
3840
3841 let embeddings = match request.model.as_str() {
3842 "openai/text-embedding-3-small" => {
3843 open_ai::embed(
3844 session.http_client.as_ref(),
3845 OPEN_AI_API_URL,
3846 &api_key,
3847 OpenAiEmbeddingModel::TextEmbedding3Small,
3848 request.texts.iter().map(|text| text.as_str()),
3849 )
3850 .await?
3851 }
3852 provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
3853 };
3854
3855 let embeddings = request
3856 .texts
3857 .iter()
3858 .map(|text| {
3859 let mut hasher = sha2::Sha256::new();
3860 hasher.update(text.as_bytes());
3861 let result = hasher.finalize();
3862 result.to_vec()
3863 })
3864 .zip(
3865 embeddings
3866 .data
3867 .into_iter()
3868 .map(|embedding| embedding.embedding),
3869 )
3870 .collect::<HashMap<_, _>>();
3871
3872 let db = session.db().await;
3873 db.save_embeddings(&request.model, &embeddings)
3874 .await
3875 .context("failed to save embeddings")
3876 .trace_err();
3877
3878 response.send(proto::ComputeEmbeddingsResponse {
3879 embeddings: embeddings
3880 .into_iter()
3881 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3882 .collect(),
3883 })?;
3884 Ok(())
3885}
3886
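/// Return previously cached embeddings for the requested digests without
/// contacting the embedding provider.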
3887async fn get_cached_embeddings(
3888 request: proto::GetCachedEmbeddings,
3889 response: Response<proto::GetCachedEmbeddings>,
3890 session: Session,
3891) -> Result<()> {
3892 authorize_access_to_legacy_llm_endpoints(&session).await?;
3893
3894 let db = session.db().await;
3895 let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
3896
3897 response.send(proto::GetCachedEmbeddingsResponse {
3898 embeddings: embeddings
3899 .into_iter()
3900 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3901 .collect(),
3902 })?;
3903 Ok(())
3904}
3905
/// Legacy access check left over from before the LLM service existed.
///
/// The endpoints protected by this check will eventually move to the LLM service.
3909async fn authorize_access_to_legacy_llm_endpoints(session: &Session) -> Result<(), Error> {
3910 if session.is_staff() {
3911 Ok(())
3912 } else {
3913 Err(anyhow!("permission denied"))?
3914 }
3915}
3916
3917/// Get a Supermaven API key for the user
3918async fn get_supermaven_api_key(
3919 _request: proto::GetSupermavenApiKey,
3920 response: Response<proto::GetSupermavenApiKey>,
3921 session: Session,
3922) -> Result<()> {
3923 let user_id: String = session.user_id().to_string();
3924 if !session.is_staff() {
3925 return Err(anyhow!("supermaven not enabled for this account"))?;
3926 }
3927
3928 let email = session
3929 .email()
3930 .ok_or_else(|| anyhow!("user must have an email"))?;
3931
3932 let supermaven_admin_api = session
3933 .supermaven_client
3934 .as_ref()
3935 .ok_or_else(|| anyhow!("supermaven not configured"))?;
3936
3937 let result = supermaven_admin_api
3938 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3939 .await?;
3940
3941 response.send(proto::GetSupermavenApiKeyResponse {
3942 api_key: result.api_key,
3943 })?;
3944
3945 Ok(())
3946}
3947
3948/// Start receiving chat updates for a channel
3949async fn join_channel_chat(
3950 request: proto::JoinChannelChat,
3951 response: Response<proto::JoinChannelChat>,
3952 session: Session,
3953) -> Result<()> {
3954 let channel_id = ChannelId::from_proto(request.channel_id);
3955
3956 let db = session.db().await;
3957 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3958 .await?;
3959 let messages = db
3960 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3961 .await?;
3962 response.send(proto::JoinChannelChatResponse {
3963 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3964 messages,
3965 })?;
3966 Ok(())
3967}
3968
3969/// Stop receiving chat updates for a channel
3970async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3971 let channel_id = ChannelId::from_proto(request.channel_id);
3972 session
3973 .db()
3974 .await
3975 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3976 .await?;
3977 Ok(())
3978}
3979
3980/// Retrieve the chat history for a channel
3981async fn get_channel_messages(
3982 request: proto::GetChannelMessages,
3983 response: Response<proto::GetChannelMessages>,
3984 session: Session,
3985) -> Result<()> {
3986 let channel_id = ChannelId::from_proto(request.channel_id);
3987 let messages = session
3988 .db()
3989 .await
3990 .get_channel_messages(
3991 channel_id,
3992 session.user_id(),
3993 MESSAGE_COUNT_PER_PAGE,
3994 Some(MessageId::from_proto(request.before_message_id)),
3995 )
3996 .await?;
3997 response.send(proto::GetChannelMessagesResponse {
3998 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3999 messages,
4000 })?;
4001 Ok(())
4002}
4003
/// Retrieve specific chat messages by their ids
4005async fn get_channel_messages_by_id(
4006 request: proto::GetChannelMessagesById,
4007 response: Response<proto::GetChannelMessagesById>,
4008 session: Session,
4009) -> Result<()> {
4010 let message_ids = request
4011 .message_ids
4012 .iter()
4013 .map(|id| MessageId::from_proto(*id))
4014 .collect::<Vec<_>>();
4015 let messages = session
4016 .db()
4017 .await
4018 .get_channel_messages_by_id(session.user_id(), &message_ids)
4019 .await?;
4020 response.send(proto::GetChannelMessagesResponse {
4021 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4022 messages,
4023 })?;
4024 Ok(())
4025}
4026
/// Retrieve the current user's notifications
4028async fn get_notifications(
4029 request: proto::GetNotifications,
4030 response: Response<proto::GetNotifications>,
4031 session: Session,
4032) -> Result<()> {
4033 let notifications = session
4034 .db()
4035 .await
4036 .get_notifications(
4037 session.user_id(),
4038 NOTIFICATION_COUNT_PER_PAGE,
4039 request.before_id.map(db::NotificationId::from_proto),
4040 )
4041 .await?;
4042 response.send(proto::GetNotificationsResponse {
4043 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4044 notifications,
4045 })?;
4046 Ok(())
4047}
4048
/// Mark a notification as read
4050async fn mark_notification_as_read(
4051 request: proto::MarkNotificationRead,
4052 response: Response<proto::MarkNotificationRead>,
4053 session: Session,
4054) -> Result<()> {
4055 let database = &session.db().await;
4056 let notifications = database
4057 .mark_notification_as_read_by_id(
4058 session.user_id(),
4059 NotificationId::from_proto(request.notification_id),
4060 )
4061 .await?;
4062 send_notifications(
4063 &*session.connection_pool().await,
4064 &session.peer,
4065 notifications,
4066 );
4067 response.send(proto::Ack {})?;
4068 Ok(())
4069}
4070
/// Get the current user's private account information
4072async fn get_private_user_info(
4073 _request: proto::GetPrivateUserInfo,
4074 response: Response<proto::GetPrivateUserInfo>,
4075 session: Session,
4076) -> Result<()> {
4077 let db = session.db().await;
4078
4079 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4080 let user = db
4081 .get_user_by_id(session.user_id())
4082 .await?
4083 .ok_or_else(|| anyhow!("user not found"))?;
4084 let flags = db.get_user_flags(session.user_id()).await?;
4085
4086 response.send(proto::GetPrivateUserInfoResponse {
4087 metrics_id,
4088 staff: user.admin,
4089 flags,
4090 accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4091 })?;
4092 Ok(())
4093}
4094
/// Accept the terms of service (ToS) on behalf of the current user
4096async fn accept_terms_of_service(
4097 _request: proto::AcceptTermsOfService,
4098 response: Response<proto::AcceptTermsOfService>,
4099 session: Session,
4100) -> Result<()> {
4101 let db = session.db().await;
4102
4103 let accepted_tos_at = Utc::now();
4104 db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4105 .await?;
4106
4107 response.send(proto::AcceptTermsOfServiceResponse {
4108 accepted_tos_at: accepted_tos_at.timestamp() as u64,
4109 })?;
4110 Ok(())
4111}
4112
/// The minimum age an account must have in order to use the LLM service.
4114pub const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4115
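/// Issue an LLM service token for the current user.
///
/// Requires the `language-models` feature flag (or staff status) and accepted
/// terms of service; the resulting token encodes the user's plan, feature
/// flags, and billing state as claims.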
4116async fn get_llm_api_token(
4117 _request: proto::GetLlmToken,
4118 response: Response<proto::GetLlmToken>,
4119 session: Session,
4120) -> Result<()> {
4121 let db = session.db().await;
4122
4123 let flags = db.get_user_flags(session.user_id()).await?;
4124 let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
4125
4126 if !session.is_staff() && !has_language_models_feature_flag {
4127 Err(anyhow!("permission denied"))?
4128 }
4129
4130 let user_id = session.user_id();
4131 let user = db
4132 .get_user_by_id(user_id)
4133 .await?
4134 .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4135
4136 if user.accepted_tos_at.is_none() {
4137 Err(anyhow!("terms of service not accepted"))?
4138 }
4139
4140 let has_legacy_llm_subscription = session.has_llm_subscription(&db).await?;
4141 let billing_subscription = db.get_active_billing_subscription(user.id).await?;
4142 let billing_preferences = db.get_billing_preferences(user.id).await?;
4143
4144 let token = LlmTokenClaims::create(
4145 &user,
4146 session.is_staff(),
4147 billing_preferences,
4148 &flags,
4149 has_legacy_llm_subscription,
4150 session.current_plan(&db).await?,
4151 billing_subscription,
4152 session.system_id.clone(),
4153 &session.app_state.config,
4154 )?;
4155 response.send(proto::GetLlmTokenResponse { token })?;
4156 Ok(())
4157}
4158
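// `axum` and `tungstenite` use distinct WebSocket message types; the two
// converters below translate between them so RPC traffic can be carried over
// an axum-upgraded socket.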
4159fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4160 let message = match message {
4161 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
4162 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
4163 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
4164 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
4165 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4166 code: frame.code.into(),
4167 reason: frame.reason.as_str().to_owned().into(),
4168 })),
4169 // We should never receive a frame while reading the message, according
4170 // to the `tungstenite` maintainers:
4171 //
4172 // > It cannot occur when you read messages from the WebSocket, but it
4173 // > can be used when you want to send the raw frames (e.g. you want to
4174 // > send the frames to the WebSocket without composing the full message first).
4175 // >
4176 // > — https://github.com/snapview/tungstenite-rs/issues/268
4177 TungsteniteMessage::Frame(_) => {
4178 bail!("received an unexpected frame while reading the message")
4179 }
4180 };
4181
4182 Ok(message)
4183}
4184
4185fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4186 match message {
4187 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
4188 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
4189 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
4190 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
4191 AxumMessage::Close(frame) => {
4192 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4193 code: frame.code.into(),
4194 reason: frame.reason.as_ref().into(),
4195 }))
4196 }
4197 }
4198}
4199
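/// Apply a membership change to the connection pool (subscribing to new
/// channels and unsubscribing from removed ones) and push the updated channel
/// state to every connection belonging to the affected user.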
4200fn notify_membership_updated(
4201 connection_pool: &mut ConnectionPool,
4202 result: MembershipUpdated,
4203 user_id: UserId,
4204 peer: &Peer,
4205) {
4206 for membership in &result.new_channels.channel_memberships {
4207 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4208 }
4209 for channel_id in &result.removed_channels {
4210 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4211 }
4212
4213 let user_channels_update = proto::UpdateUserChannels {
4214 channel_memberships: result
4215 .new_channels
4216 .channel_memberships
4217 .iter()
4218 .map(|cm| proto::ChannelMembership {
4219 channel_id: cm.channel_id.to_proto(),
4220 role: cm.role.into(),
4221 })
4222 .collect(),
4223 ..Default::default()
4224 };
4225
4226 let mut update = build_channels_update(result.new_channels);
4227 update.delete_channels = result
4228 .removed_channels
4229 .into_iter()
4230 .map(|id| id.to_proto())
4231 .collect();
4232 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4233
4234 for connection_id in connection_pool.user_connection_ids(user_id) {
4235 peer.send(connection_id, user_channels_update.clone())
4236 .trace_err();
4237 peer.send(connection_id, update.clone()).trace_err();
4238 }
4239}
4240
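/// Build an `UpdateUserChannels` message from the user's channel memberships
/// and their observed buffer versions and message ids.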
4241fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4242 proto::UpdateUserChannels {
4243 channel_memberships: channels
4244 .channel_memberships
4245 .iter()
4246 .map(|m| proto::ChannelMembership {
4247 channel_id: m.channel_id.to_proto(),
4248 role: m.role.into(),
4249 })
4250 .collect(),
4251 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4252 observed_channel_message_id: channels.observed_channel_messages.clone(),
4253 }
4254}
4255
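/// Build an `UpdateChannels` message containing the user's channels, the
/// latest buffer versions and message ids, current channel participants, and
/// pending invitations.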
4256fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4257 let mut update = proto::UpdateChannels::default();
4258
4259 for channel in channels.channels {
4260 update.channels.push(channel.to_proto());
4261 }
4262
4263 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4264 update.latest_channel_message_ids = channels.latest_channel_messages;
4265
4266 for (channel_id, participants) in channels.channel_participants {
4267 update
4268 .channel_participants
4269 .push(proto::ChannelParticipants {
4270 channel_id: channel_id.to_proto(),
4271 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4272 });
4273 }
4274
4275 for channel in channels.invited_channels {
4276 update.channel_invitations.push(channel.to_proto());
4277 }
4278
4279 update
4280}
4281
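/// Build the initial `UpdateContacts` message, splitting the stored contacts
/// into accepted contacts, outgoing requests, and incoming requests.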
4282fn build_initial_contacts_update(
4283 contacts: Vec<db::Contact>,
4284 pool: &ConnectionPool,
4285) -> proto::UpdateContacts {
4286 let mut update = proto::UpdateContacts::default();
4287
4288 for contact in contacts {
4289 match contact {
4290 db::Contact::Accepted { user_id, busy } => {
4291 update.contacts.push(contact_for_user(user_id, busy, pool));
4292 }
4293 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4294 db::Contact::Incoming { user_id } => {
4295 update
4296 .incoming_requests
4297 .push(proto::IncomingContactRequest {
4298 requester_id: user_id.to_proto(),
4299 })
4300 }
4301 }
4302 }
4303
4304 update
4305}
4306
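/// Build a `proto::Contact` for the given user, reading their online status
/// from the connection pool.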
4307fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4308 proto::Contact {
4309 user_id: user_id.to_proto(),
4310 online: pool.is_user_online(user_id),
4311 busy,
4312 }
4313}
4314
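/// Broadcast the latest room state to every participant in the room.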
4315fn room_updated(room: &proto::Room, peer: &Peer) {
4316 broadcast(
4317 None,
4318 room.participants
4319 .iter()
4320 .filter_map(|participant| Some(participant.peer_id?.into())),
4321 |peer_id| {
4322 peer.send(
4323 peer_id,
4324 proto::RoomUpdated {
4325 room: Some(room.clone()),
4326 },
4327 )
4328 },
4329 );
4330}
4331
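/// Broadcast the room's participant list to every connection on the channel's
/// root that is allowed to see this channel.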
4332fn channel_updated(
4333 channel: &db::channel::Model,
4334 room: &proto::Room,
4335 peer: &Peer,
4336 pool: &ConnectionPool,
4337) {
4338 let participants = room
4339 .participants
4340 .iter()
4341 .map(|p| p.user_id)
4342 .collect::<Vec<_>>();
4343
4344 broadcast(
4345 None,
4346 pool.channel_connection_ids(channel.root_id())
            .filter_map(|(connection_id, role)| {
                role.can_see_channel(channel.visibility)
                    .then_some(connection_id)
4350 }),
4351 |peer_id| {
4352 peer.send(
4353 peer_id,
4354 proto::UpdateChannels {
4355 channel_participants: vec![proto::ChannelParticipants {
4356 channel_id: channel.id.to_proto(),
4357 participant_user_ids: participants.clone(),
4358 }],
4359 ..Default::default()
4360 },
4361 )
4362 },
4363 );
4364}
4365
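/// Send the user's refreshed contact entry (online and busy state) to every
/// connection of each of their accepted contacts.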
4366async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4367 let db = session.db().await;
4368
4369 let contacts = db.get_contacts(user_id).await?;
4370 let busy = db.is_user_busy(user_id).await?;
4371
4372 let pool = session.connection_pool().await;
4373 let updated_contact = contact_for_user(user_id, busy, &pool);
4374 for contact in contacts {
4375 if let db::Contact::Accepted {
4376 user_id: contact_user_id,
4377 ..
4378 } = contact
4379 {
4380 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4381 session
4382 .peer
4383 .send(
4384 contact_conn_id,
4385 proto::UpdateContacts {
4386 contacts: vec![updated_contact.clone()],
4387 remove_contacts: Default::default(),
4388 incoming_requests: Default::default(),
4389 remove_incoming_requests: Default::default(),
4390 outgoing_requests: Default::default(),
4391 remove_outgoing_requests: Default::default(),
4392 },
4393 )
4394 .trace_err();
4395 }
4396 }
4397 }
4398 Ok(())
4399}
4400
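/// Clean up after a connection leaves its room: notify the remaining
/// participants and the room's channel, cancel outgoing calls, refresh
/// affected contacts, and remove the user from the LiveKit room (deleting the
/// LiveKit room when the call itself has ended).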
4401async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
4402 let mut contacts_to_update = HashSet::default();
4403
4404 let room_id;
4405 let canceled_calls_to_user_ids;
4406 let livekit_room;
4407 let delete_livekit_room;
4408 let room;
4409 let channel;
4410
4411 if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
4412 contacts_to_update.insert(session.user_id());
4413
4414 for project in left_room.left_projects.values() {
4415 project_left(project, session);
4416 }
4417
4418 room_id = RoomId::from_proto(left_room.room.id);
4419 canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
4420 livekit_room = mem::take(&mut left_room.room.livekit_room);
4421 delete_livekit_room = left_room.deleted;
4422 room = mem::take(&mut left_room.room);
4423 channel = mem::take(&mut left_room.channel);
4424
4425 room_updated(&room, &session.peer);
4426 } else {
4427 return Ok(());
4428 }
4429
4430 if let Some(channel) = channel {
4431 channel_updated(
4432 &channel,
4433 &room,
4434 &session.peer,
4435 &*session.connection_pool().await,
4436 );
4437 }
4438
4439 {
4440 let pool = session.connection_pool().await;
4441 for canceled_user_id in canceled_calls_to_user_ids {
4442 for connection_id in pool.user_connection_ids(canceled_user_id) {
4443 session
4444 .peer
4445 .send(
4446 connection_id,
4447 proto::CallCanceled {
4448 room_id: room_id.to_proto(),
4449 },
4450 )
4451 .trace_err();
4452 }
4453 contacts_to_update.insert(canceled_user_id);
4454 }
4455 }
4456
4457 for contact_user_id in contacts_to_update {
4458 update_user_contacts(contact_user_id, session).await?;
4459 }
4460
4461 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
4462 live_kit
4463 .remove_participant(livekit_room.clone(), session.user_id().to_string())
4464 .await
4465 .trace_err();
4466
4467 if delete_livekit_room {
4468 live_kit.delete_room(livekit_room).await.trace_err();
4469 }
4470 }
4471
4472 Ok(())
4473}
4474
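/// Detach this connection from every channel buffer it had open and notify the
/// remaining collaborators in each buffer.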
4475async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4476 let left_channel_buffers = session
4477 .db()
4478 .await
4479 .leave_channel_buffers(session.connection_id)
4480 .await?;
4481
4482 for left_buffer in left_channel_buffers {
4483 channel_buffer_updated(
4484 session.connection_id,
4485 left_buffer.connections,
4486 &proto::UpdateChannelBufferCollaborators {
4487 channel_id: left_buffer.channel_id.to_proto(),
4488 collaborators: left_buffer.collaborators,
4489 },
4490 &session.peer,
4491 );
4492 }
4493
4494 Ok(())
4495}
4496
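/// Notify a project's other connections that this session left, either
/// unsharing the project entirely or removing just this collaborator, as
/// indicated by `should_unshare`.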
4497fn project_left(project: &db::LeftProject, session: &Session) {
4498 for connection_id in &project.connection_ids {
4499 if project.should_unshare {
4500 session
4501 .peer
4502 .send(
4503 *connection_id,
4504 proto::UnshareProject {
4505 project_id: project.id.to_proto(),
4506 },
4507 )
4508 .trace_err();
4509 } else {
4510 session
4511 .peer
4512 .send(
4513 *connection_id,
4514 proto::RemoveProjectCollaborator {
4515 project_id: project.id.to_proto(),
4516 peer_id: Some(session.connection_id.into()),
4517 },
4518 )
4519 .trace_err();
4520 }
4521 }
4522}
4523
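/// Logs an error and discards it instead of propagating it.
///
/// A minimal usage sketch (the inputs are illustrative only):
///
/// ```ignore
/// let ok = "42".parse::<i32>().trace_err();    // Some(42)
/// let err = "nope".parse::<i32>().trace_err(); // logs the error, yields None
/// ```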
4524pub trait ResultExt {
4525 type Ok;
4526
4527 fn trace_err(self) -> Option<Self::Ok>;
4528}
4529
4530impl<T, E> ResultExt for Result<T, E>
4531where
4532 E: std::fmt::Debug,
4533{
4534 type Ok = T;
4535
4536 #[track_caller]
4537 fn trace_err(self) -> Option<T> {
4538 match self {
4539 Ok(value) => Some(value),
4540 Err(error) => {
4541 tracing::error!("{:?}", error);
4542 None
4543 }
4544 }
4545 }
4546}