1mod connection_pool;
2
3use crate::{
4 auth::{self},
5 db::{
6 self, dev_server, BufferId, Channel, ChannelId, ChannelRole, ChannelsForUser,
7 CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
8 NotificationId, Project, ProjectId, RemoveChannelMemberResult, ReplicaId,
9 RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
10 },
11 executor::Executor,
12 AppState, Error, RateLimit, RateLimiter, Result,
13};
14use anyhow::{anyhow, Context as _};
15use async_tungstenite::tungstenite::{
16 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
17};
18use axum::{
19 body::Body,
20 extract::{
21 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
22 ConnectInfo, WebSocketUpgrade,
23 },
24 headers::{Header, HeaderName},
25 http::StatusCode,
26 middleware,
27 response::IntoResponse,
28 routing::get,
29 Extension, Router, TypedHeader,
30};
31use collections::{HashMap, HashSet};
32pub use connection_pool::{ConnectionPool, ZedVersion};
33use core::fmt::{self, Debug, Formatter};
34
35use futures::{
36 channel::oneshot,
37 future::{self, BoxFuture},
38 stream::FuturesUnordered,
39 FutureExt, SinkExt, StreamExt, TryStreamExt,
40};
41use prometheus::{register_int_gauge, IntGauge};
42use rpc::{
43 proto::{
44 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LanguageModelRole,
45 LiveKitConnectionInfo, RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
46 },
47 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
48};
49use semantic_version::SemanticVersion;
50use serde::{Serialize, Serializer};
51use std::{
52 any::TypeId,
53 future::Future,
54 marker::PhantomData,
55 mem,
56 net::SocketAddr,
57 ops::{Deref, DerefMut},
58 rc::Rc,
59 sync::{
60 atomic::{AtomicBool, Ordering::SeqCst},
61 Arc, OnceLock,
62 },
63 time::{Duration, Instant},
64};
65use time::OffsetDateTime;
66use tokio::sync::{watch, Semaphore};
67use tower::ServiceBuilder;
68use tracing::{
69 field::{self},
70 info_span, instrument, Instrument,
71};
72use util::http::IsahcHttpClient;
73
/// Reconnection grace period (30 seconds).
/// NOTE(review): not referenced in this chunk; presumably how long a dropped client's
/// state is kept so it can reconnect — confirm at the use sites.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can clean up old resources.
/// How long `Server::start` waits before cleaning up resources left behind by stale servers.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// NOTE(review): the following limits are not used in this chunk; their meaning is
// inferred from their names only — confirm at the use sites.
/// Presumably the page size for channel-message history requests.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Presumably the maximum accepted length of a channel message (units unverified).
const MAX_MESSAGE_LEN: usize = 1024;
/// Presumably the page size for notification requests.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;

/// Type-erased message handler stored in `Server::handlers`, keyed by the `TypeId`
/// of the message it handles. It receives the boxed envelope plus the connection's
/// `Session` and returns a boxed `Send` future.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
85
86struct Response<R> {
87 peer: Arc<Peer>,
88 receipt: Receipt<R>,
89 responded: Arc<AtomicBool>,
90}
91
92impl<R: RequestMessage> Response<R> {
93 fn send(self, payload: R::Response) -> Result<()> {
94 self.responded.store(true, SeqCst);
95 self.peer.respond(self.receipt, payload)?;
96 Ok(())
97 }
98}
99
100struct StreamingResponse<R: RequestMessage> {
101 peer: Arc<Peer>,
102 receipt: Receipt<R>,
103}
104
105impl<R: RequestMessage> StreamingResponse<R> {
106 fn send(&self, payload: R::Response) -> Result<()> {
107 self.peer.respond(self.receipt, payload)?;
108 Ok(())
109 }
110}
111
112#[derive(Clone, Debug)]
113pub enum Principal {
114 User(User),
115 Impersonated { user: User, admin: User },
116 DevServer(dev_server::Model),
117}
118
119impl Principal {
120 fn update_span(&self, span: &tracing::Span) {
121 match &self {
122 Principal::User(user) => {
123 span.record("user_id", &user.id.0);
124 span.record("login", &user.github_login);
125 }
126 Principal::Impersonated { user, admin } => {
127 span.record("user_id", &user.id.0);
128 span.record("login", &user.github_login);
129 span.record("impersonator", &admin.github_login);
130 }
131 Principal::DevServer(dev_server) => {
132 span.record("dev_server_id", &dev_server.id.0);
133 }
134 }
135 }
136}
137
/// Per-connection state handed to every message handler.
///
/// A clone is passed to the handler of each incoming message; the `Arc` fields are
/// shared between those clones.
#[derive(Clone)]
struct Session {
    /// Who is on the other end of this connection.
    principal: Principal,
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; acquire it through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Server-wide connection pool (the same `Arc` the `Server` holds).
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
    /// HTTP client created per connection in `handle_connection`.
    http_client: IsahcHttpClient,
    rate_limiter: Arc<RateLimiter>,
    _executor: Executor,
}
150
151impl Session {
152 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
153 #[cfg(test)]
154 tokio::task::yield_now().await;
155 let guard = self.db.lock().await;
156 #[cfg(test)]
157 tokio::task::yield_now().await;
158 guard
159 }
160
161 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
162 #[cfg(test)]
163 tokio::task::yield_now().await;
164 let guard = self.connection_pool.lock();
165 ConnectionPoolGuard {
166 guard,
167 _not_send: PhantomData,
168 }
169 }
170
171 fn for_user(self) -> Option<UserSession> {
172 UserSession::new(self)
173 }
174
175 fn user_id(&self) -> Option<UserId> {
176 match &self.principal {
177 Principal::User(user) => Some(user.id),
178 Principal::Impersonated { user, .. } => Some(user.id),
179 Principal::DevServer(_) => None,
180 }
181 }
182}
183
184impl Debug for Session {
185 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
186 let mut result = f.debug_struct("Session");
187 match &self.principal {
188 Principal::User(user) => {
189 result.field("user", &user.github_login);
190 }
191 Principal::Impersonated { user, admin } => {
192 result.field("user", &user.github_login);
193 result.field("impersonator", &admin.github_login);
194 }
195 Principal::DevServer(dev_server) => {
196 result.field("dev_server", &dev_server.id);
197 }
198 }
199 result.field("connection_id", &self.connection_id).finish()
200 }
201}
202
203struct UserSession(Session);
204
205impl UserSession {
206 pub fn new(s: Session) -> Option<Self> {
207 s.user_id().map(|_| UserSession(s))
208 }
209 pub fn user_id(&self) -> UserId {
210 self.0.user_id().unwrap()
211 }
212}
213
214impl Deref for UserSession {
215 type Target = Session;
216
217 fn deref(&self) -> &Self::Target {
218 &self.0
219 }
220}
221impl DerefMut for UserSession {
222 fn deref_mut(&mut self) -> &mut Self::Target {
223 &mut self.0
224 }
225}
226
227fn user_handler<M: RequestMessage, Fut>(
228 handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
229) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
230where
231 Fut: Send + Future<Output = Result<()>>,
232{
233 let handler = Arc::new(handler);
234 move |message, response, session| {
235 let handler = handler.clone();
236 Box::pin(async move {
237 if let Some(user_session) = session.for_user() {
238 Ok(handler(message, response, user_session).await?)
239 } else {
240 Err(Error::Internal(anyhow!("must be a user")))
241 }
242 })
243 }
244}
245
246fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
247 handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
248) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
249where
250 InnertRetFut: Send + Future<Output = Result<()>>,
251{
252 let handler = Arc::new(handler);
253 move |message, session| {
254 let handler = handler.clone();
255 Box::pin(async move {
256 if let Some(user_session) = session.for_user() {
257 Ok(handler(message, user_session).await?)
258 } else {
259 Err(Error::Internal(anyhow!("must be a user")))
260 }
261 })
262 }
263}
264
265struct DbHandle(Arc<Database>);
266
267impl Deref for DbHandle {
268 type Target = Database;
269
270 fn deref(&self) -> &Self::Target {
271 self.0.as_ref()
272 }
273}
274
/// The collab RPC server: owns the peer, the connection pool and the registered
/// message handlers.
pub struct Server {
    /// This server's id; behind a mutex so the test-only `reset` can replace it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    /// Live connections; the same `Arc` is handed to every `Session`.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Type-erased handlers keyed by the `TypeId` of the message they handle.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Set to `true` by `teardown` (and back to `false` by the test-only `reset`);
    /// connections subscribe to it to notice teardown.
    teardown: watch::Sender<bool>,
}

/// Guard over the shared connection pool.
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send`, so a future that holds it
/// across an `.await` is not `Send`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}

/// Serializable view of the server's peer and connection pool.
///
/// Holds the connection-pool lock for as long as it is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialize the pool itself rather than the guard wrapper.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
295
296pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
297where
298 S: Serializer,
299 T: Deref<Target = U>,
300 U: Serialize,
301{
302 Serialize::serialize(value.deref(), serializer)
303}
304
305impl Server {
    /// Creates a server identified by `id` and registers a handler for every RPC
    /// message it understands.
    ///
    /// Handlers wrapped in `user_handler`/`user_message_handler` fail with
    /// "must be a user" for sessions without a user id (dev-server connections).
    ///
    /// # Panics
    /// Panics if two handlers are registered for the same message type
    /// (checked in `add_handler`).
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            // The initial receiver is dropped; connections call `teardown.subscribe()`.
            teardown: watch::channel(false).0,
        };

        // Registration order does not matter, but each message type may appear only once.
        server
            .add_request_handler(ping)
            .add_request_handler(user_handler(create_room))
            .add_request_handler(user_handler(join_room))
            .add_request_handler(user_handler(rejoin_room))
            .add_request_handler(user_handler(leave_room))
            .add_request_handler(user_handler(set_room_participant_role))
            .add_request_handler(user_handler(call))
            .add_request_handler(user_handler(cancel_call))
            .add_message_handler(user_message_handler(decline_call))
            .add_request_handler(user_handler(update_participant_location))
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(user_handler(join_project))
            .add_request_handler(user_handler(join_hosted_project))
            .add_message_handler(user_message_handler(leave_project))
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
            .add_request_handler(forward_read_only_project_request::<proto::SearchProject>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            .add_request_handler(get_users)
            .add_request_handler(user_handler(fuzzy_search_users))
            .add_request_handler(user_handler(request_contact))
            .add_request_handler(user_handler(remove_contact))
            .add_request_handler(user_handler(respond_to_contact_request))
            .add_request_handler(user_handler(create_channel))
            .add_request_handler(user_handler(delete_channel))
            .add_request_handler(user_handler(invite_channel_member))
            .add_request_handler(user_handler(remove_channel_member))
            .add_request_handler(user_handler(set_channel_member_role))
            .add_request_handler(user_handler(set_channel_visibility))
            .add_request_handler(user_handler(rename_channel))
            .add_request_handler(user_handler(join_channel_buffer))
            .add_request_handler(user_handler(leave_channel_buffer))
            .add_message_handler(user_message_handler(update_channel_buffer))
            .add_request_handler(user_handler(rejoin_channel_buffers))
            .add_request_handler(user_handler(get_channel_members))
            .add_request_handler(user_handler(respond_to_channel_invite))
            .add_request_handler(user_handler(join_channel))
            .add_request_handler(user_handler(join_channel_chat))
            .add_message_handler(user_message_handler(leave_channel_chat))
            .add_request_handler(user_handler(send_channel_message))
            .add_request_handler(user_handler(remove_channel_message))
            .add_request_handler(user_handler(update_channel_message))
            .add_request_handler(user_handler(get_channel_messages))
            .add_request_handler(user_handler(get_channel_messages_by_id))
            .add_request_handler(user_handler(get_notifications))
            .add_request_handler(user_handler(mark_notification_as_read))
            .add_request_handler(user_handler(move_channel))
            .add_request_handler(user_handler(follow))
            .add_message_handler(user_message_handler(unfollow))
            .add_message_handler(user_message_handler(update_followers))
            .add_request_handler(user_handler(get_private_user_info))
            .add_message_handler(user_message_handler(acknowledge_channel_message))
            .add_message_handler(user_message_handler(acknowledge_buffer_version))
            // Streaming completion: receives the configured language-model API keys.
            // Note it is registered without `user_handler`, so it gets the raw `Session`.
            .add_streaming_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    complete_with_language_model(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                        app_state.config.google_ai_api_key.clone(),
                        app_state.config.anthropic_api_key.clone(),
                    )
                }
            })
            .add_request_handler({
                let app_state = app_state.clone();
                user_handler(move |request, response, session| {
                    count_tokens_with_language_model(
                        request,
                        response,
                        session,
                        app_state.config.google_ai_api_key.clone(),
                    )
                })
            });

        Arc::new(server)
    }
441
    /// Schedules cleanup of resources left behind by stale servers, then returns.
    ///
    /// A detached task waits `CLEANUP_TIMEOUT`, then fetches stale room and channel ids
    /// for this environment and `server_id` and, for each:
    /// - channel: clears stale buffer collaborators and sends the refreshed collaborator
    ///   list to the remaining connections;
    /// - room: clears stale participants and broadcasts the room (and its channel, if any).
    ///   It then sends `CallCanceled` to users whose calls were canceled, and pushes an
    ///   updated contact entry for affected users to their accepted contacts' connections.
    ///   Finally it deletes the LiveKit room if no participants remain.
    ///
    /// Finally it calls `delete_stale_servers`.
    ///
    /// Every fallible step goes through `trace_err()`, which yields `None` on failure
    /// (presumably after logging). A failure skips only that step.
    /// Always returns `Ok(())`.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Drop stale collaborators from channel buffers and tell the rest.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Defaults used when clearing the room fails: nothing to notify,
                        // and the LiveKit room is left alone.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the pool lock is released before the awaits below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push each affected user's refreshed contact entry to every
                        // connection of their accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
587
588 pub fn teardown(&self) {
589 self.peer.teardown();
590 self.connection_pool.lock().reset();
591 let _ = self.teardown.send(true);
592 }
593
594 #[cfg(test)]
595 pub fn reset(&self, id: ServerId) {
596 self.teardown();
597 *self.id.lock() = id;
598 self.peer.reset(id.0 as u32);
599 let _ = self.teardown.send(false);
600 }
601
602 #[cfg(test)]
603 pub fn id(&self) -> ServerId {
604 *self.id.lock()
605 }
606
    /// Registers `handler` for message type `M`, type-erasing it into a
    /// [`MessageHandler`] stored under `TypeId::of::<M>()`.
    ///
    /// The stored wrapper downcasts the envelope, logs "message received", runs the
    /// handler, and logs the outcome. Three timing fields, in milliseconds, are logged:
    /// - `total_duration_ms`: time since the envelope's `received_at`;
    /// - `processing_duration_ms`: time since the handler was invoked to build its future;
    /// - `queue_duration_ms`: their difference, i.e. time before the handler was invoked.
    /// Handler errors are only logged here; the future resolves to `()`.
    ///
    /// # Panics
    /// Panics if a handler for `M` is already registered. The wrapper also panics (via
    /// `unwrap` on the downcast) if given an envelope that is not a `TypedEnvelope<M>`.
    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        let prev_handler = self.handlers.insert(
            TypeId::of::<M>(),
            Box::new(move |envelope, session| {
                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
                let received_at = envelope.received_at;
                tracing::info!(
                    "message received"
                );
                // Taken when the handler future is built, not when it is first polled.
                let start_time = Instant::now();
                let future = (handler)(*envelope, session);
                async move {
                    let result = future.await;
                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
                    match result {
                        Err(error) => {
                            tracing::error!(%error, total_duration_ms, processing_duration_ms, queue_duration_ms, "error handling message")
                        }
                        Ok(()) => tracing::info!(total_duration_ms, processing_duration_ms, queue_duration_ms, "finished handling message"),
                    }
                }
                .boxed()
            }),
        );
        if prev_handler.is_some() {
            panic!("registered a handler for the same message twice");
        }
        self
    }
643
644 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
645 where
646 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
647 Fut: 'static + Send + Future<Output = Result<()>>,
648 M: EnvelopedMessage,
649 {
650 self.add_handler(move |envelope, session| handler(envelope.payload, session));
651 self
652 }
653
    /// Registers a request handler for `M`; the handler replies through the
    /// [`Response`] it is given.
    ///
    /// - Returns `Ok(())` after replying: success.
    /// - Returns `Ok(())` without calling `Response::send`: reported as the error
    ///   "handler did not send a response". `add_handler` logs it; no reply is sent here.
    /// - Returns `Err`: the error is sent to the requester and then propagated.
    ///   `Error::Internal` is converted with `to_proto()`. Any other variant becomes
    ///   `ErrorCode::Internal` with the error's `Display` text.
    ///
    /// NOTE(review): if a handler replies and then fails, `respond_with_error` is still
    /// called for the same receipt — confirm the peer tolerates a second response.
    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
        Fut: Send + Future<Output = Result<()>>,
        M: RequestMessage,
    {
        let handler = Arc::new(handler);
        self.add_handler(move |envelope, session| {
            let receipt = envelope.receipt();
            let handler = handler.clone();
            async move {
                let peer = session.peer.clone();
                // Shared with `Response` so we can tell afterwards whether it was used.
                let responded = Arc::new(AtomicBool::default());
                let response = Response {
                    peer: peer.clone(),
                    responded: responded.clone(),
                    receipt,
                };
                match (handler)(envelope.payload, response, session).await {
                    Ok(()) => {
                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
                            Ok(())
                        } else {
                            Err(anyhow!("handler did not send a response"))?
                        }
                    }
                    Err(error) => {
                        let proto_err = match &error {
                            Error::Internal(err) => err.to_proto(),
                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
                        };
                        peer.respond_with_error(receipt, proto_err)?;
                        Err(error)
                    }
                }
            }
        })
    }
692
    /// Registers a streaming request handler for `M`. The handler may send any number
    /// of replies through the [`StreamingResponse`] it is given.
    ///
    /// When the handler returns `Ok(())`, the stream is closed with `end_stream`.
    /// When it returns `Err`, the error is sent to the requester and propagated.
    /// `Error::Internal` is converted with `to_proto()`. Any other variant becomes
    /// `ErrorCode::Internal` with the error's `Display` text.
    fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
        Fut: Send + Future<Output = Result<()>>,
        M: RequestMessage,
    {
        let handler = Arc::new(handler);
        self.add_handler(move |envelope, session| {
            let receipt = envelope.receipt();
            let handler = handler.clone();
            async move {
                let peer = session.peer.clone();
                let response = StreamingResponse {
                    peer: peer.clone(),
                    receipt,
                };
                match (handler)(envelope.payload, response, session).await {
                    Ok(()) => {
                        peer.end_stream(receipt)?;
                        Ok(())
                    }
                    Err(error) => {
                        let proto_err = match &error {
                            Error::Internal(err) => err.to_proto(),
                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
                        };
                        peer.respond_with_error(receipt, proto_err)?;
                        Err(error)
                    }
                }
            }
        })
    }
726
    /// Returns a future that drives one client connection until it ends.
    ///
    /// The future:
    /// 1. returns immediately if the server is already tearing down;
    /// 2. registers `connection` with the peer and records the connection id on the span;
    /// 3. builds the per-connection [`Session`] and sends the initial client update,
    ///    giving up on failure;
    /// 4. dispatches incoming messages to the registered handlers until the I/O task
    ///    finishes, the incoming stream ends, or teardown is signalled;
    /// 5. after the I/O task finishes or the stream ends, calls `connection_lost`.
    ///
    /// NOTE(review): several exits skip `connection_lost`: the early returns in steps 1–3
    /// and the teardown branch in step 4. Confirm this is intended (`teardown` resets the
    /// pool itself).
    #[allow(clippy::too_many_arguments)]
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> {
        let this = self.clone();
        // Empty fields are declared up front so they can be filled via `record` later.
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            dev_server_id=field::Empty
        );
        principal.update_span(&span);

        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));
            tracing::info!("connection opened");

            let http_client = match IsahcHttpClient::new() {
                Ok(http_client) => http_client,
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                http_client,
                rate_limiter: this.app_state.rate_limiter.clone(),
                _executor: executor.clone(),
            };

            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when, e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stopped processing further
            // messages until their respective request completed, neither would get a chance to
            // respond to the other client's request, and they would deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages that arrived later, in the interest of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // At most 256 handlers (foreground and background) are in flight per connection:
            // a permit is taken before reading each message and released when its handler ends.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Biased: branches are checked in the order written below.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                dev_server_id=field::Empty
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                // The permit lives until the handler finishes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Cancel any foreground handlers that are still pending.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
863
864 async fn send_initial_client_update(
865 &self,
866 connection_id: ConnectionId,
867 principal: &Principal,
868 zed_version: ZedVersion,
869 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
870 session: &Session,
871 ) -> Result<()> {
872 self.peer.send(
873 connection_id,
874 proto::Hello {
875 peer_id: Some(connection_id.into()),
876 },
877 )?;
878 tracing::info!("sent hello message");
879
880 let Principal::User(user) = principal else {
881 return Ok(());
882 };
883
884 if let Some(send_connection_id) = send_connection_id.take() {
885 let _ = send_connection_id.send(connection_id);
886 }
887
888 if !user.connected_once {
889 self.peer.send(connection_id, proto::ShowContacts {})?;
890 self.app_state
891 .db
892 .set_user_connected_once(user.id, true)
893 .await?;
894 }
895
896 let (contacts, channels_for_user, channel_invites) = future::try_join3(
897 self.app_state.db.get_contacts(user.id),
898 self.app_state.db.get_channels_for_user(user.id),
899 self.app_state.db.get_channel_invites_for_user(user.id),
900 )
901 .await?;
902
903 {
904 let mut pool = self.connection_pool.lock();
905 pool.add_connection(connection_id, user.id, user.admin, zed_version);
906 for membership in &channels_for_user.channel_memberships {
907 pool.subscribe_to_channel(user.id, membership.channel_id, membership.role)
908 }
909 self.peer.send(
910 connection_id,
911 build_initial_contacts_update(contacts, &pool),
912 )?;
913 self.peer.send(
914 connection_id,
915 build_update_user_channels(&channels_for_user),
916 )?;
917 self.peer.send(
918 connection_id,
919 build_channels_update(channels_for_user, channel_invites),
920 )?;
921 }
922
923 if let Some(incoming_call) = self.app_state.db.incoming_call_for_user(user.id).await? {
924 self.peer.send(connection_id, incoming_call)?;
925 }
926
927 update_user_contacts(user.id, &session).await?;
928 Ok(())
929 }
930
931 pub async fn invite_code_redeemed(
932 self: &Arc<Self>,
933 inviter_id: UserId,
934 invitee_id: UserId,
935 ) -> Result<()> {
936 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
937 if let Some(code) = &user.invite_code {
938 let pool = self.connection_pool.lock();
939 let invitee_contact = contact_for_user(invitee_id, false, &pool);
940 for connection_id in pool.user_connection_ids(inviter_id) {
941 self.peer.send(
942 connection_id,
943 proto::UpdateContacts {
944 contacts: vec![invitee_contact.clone()],
945 ..Default::default()
946 },
947 )?;
948 self.peer.send(
949 connection_id,
950 proto::UpdateInviteInfo {
951 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
952 count: user.invite_count as u32,
953 },
954 )?;
955 }
956 }
957 }
958 Ok(())
959 }
960
961 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
962 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
963 if let Some(invite_code) = &user.invite_code {
964 let pool = self.connection_pool.lock();
965 for connection_id in pool.user_connection_ids(user_id) {
966 self.peer.send(
967 connection_id,
968 proto::UpdateInviteInfo {
969 url: format!(
970 "{}{}",
971 self.app_state.config.invite_link_prefix, invite_code
972 ),
973 count: user.invite_count as u32,
974 },
975 )?;
976 }
977 }
978 }
979 Ok(())
980 }
981
982 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
983 ServerSnapshot {
984 connection_pool: ConnectionPoolGuard {
985 guard: self.connection_pool.lock(),
986 _not_send: PhantomData,
987 },
988 peer: &self.peer,
989 }
990 }
991}
992
993impl<'a> Deref for ConnectionPoolGuard<'a> {
994 type Target = ConnectionPool;
995
996 fn deref(&self) -> &Self::Target {
997 &self.guard
998 }
999}
1000
1001impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1002 fn deref_mut(&mut self) -> &mut Self::Target {
1003 &mut self.guard
1004 }
1005}
1006
1007impl<'a> Drop for ConnectionPoolGuard<'a> {
1008 fn drop(&mut self) {
1009 #[cfg(test)]
1010 self.check_invariants();
1011 }
1012}
1013
1014fn broadcast<F>(
1015 sender_id: Option<ConnectionId>,
1016 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1017 mut f: F,
1018) where
1019 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1020{
1021 for receiver_id in receiver_ids {
1022 if Some(receiver_id) != sender_id {
1023 if let Err(error) = f(receiver_id) {
1024 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1025 }
1026 }
1027 }
1028}
1029
1030pub struct ProtocolVersion(u32);
1031
1032impl Header for ProtocolVersion {
1033 fn name() -> &'static HeaderName {
1034 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1035 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1036 }
1037
1038 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1039 where
1040 Self: Sized,
1041 I: Iterator<Item = &'i axum::http::HeaderValue>,
1042 {
1043 let version = values
1044 .next()
1045 .ok_or_else(axum::headers::Error::invalid)?
1046 .to_str()
1047 .map_err(|_| axum::headers::Error::invalid())?
1048 .parse()
1049 .map_err(|_| axum::headers::Error::invalid())?;
1050 Ok(Self(version))
1051 }
1052
1053 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1054 values.extend([self.0.to_string().parse().unwrap()]);
1055 }
1056}
1057
1058pub struct AppVersionHeader(SemanticVersion);
1059impl Header for AppVersionHeader {
1060 fn name() -> &'static HeaderName {
1061 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1062 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1063 }
1064
1065 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1066 where
1067 Self: Sized,
1068 I: Iterator<Item = &'i axum::http::HeaderValue>,
1069 {
1070 let version = values
1071 .next()
1072 .ok_or_else(axum::headers::Error::invalid)?
1073 .to_str()
1074 .map_err(|_| axum::headers::Error::invalid())?
1075 .parse()
1076 .map_err(|_| axum::headers::Error::invalid())?;
1077 Ok(Self(version))
1078 }
1079
1080 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1081 values.extend([self.0.to_string().parse().unwrap()]);
1082 }
1083}
1084
1085pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1086 Router::new()
1087 .route("/rpc", get(handle_websocket_request))
1088 .layer(
1089 ServiceBuilder::new()
1090 .layer(Extension(server.app_state.clone()))
1091 .layer(middleware::from_fn(auth::validate_header)),
1092 )
1093 .route("/metrics", get(handle_metrics))
1094 .layer(Extension(server))
1095}
1096
1097pub async fn handle_websocket_request(
1098 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1099 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1100 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1101 Extension(server): Extension<Arc<Server>>,
1102 Extension(principal): Extension<Principal>,
1103 ws: WebSocketUpgrade,
1104) -> axum::response::Response {
1105 if protocol_version != rpc::PROTOCOL_VERSION {
1106 return (
1107 StatusCode::UPGRADE_REQUIRED,
1108 "client must be upgraded".to_string(),
1109 )
1110 .into_response();
1111 }
1112
1113 let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1114 return (
1115 StatusCode::UPGRADE_REQUIRED,
1116 "no version header found".to_string(),
1117 )
1118 .into_response();
1119 };
1120
1121 if !version.can_collaborate() {
1122 return (
1123 StatusCode::UPGRADE_REQUIRED,
1124 "client must be upgraded".to_string(),
1125 )
1126 .into_response();
1127 }
1128
1129 let socket_address = socket_address.to_string();
1130 ws.on_upgrade(move |socket| {
1131 let socket = socket
1132 .map_ok(to_tungstenite_message)
1133 .err_into()
1134 .with(|message| async move { Ok(to_axum_message(message)) });
1135 let connection = Connection::new(Box::pin(socket));
1136 async move {
1137 server
1138 .handle_connection(
1139 connection,
1140 socket_address,
1141 principal,
1142 version,
1143 None,
1144 Executor::Production,
1145 )
1146 .await;
1147 }
1148 })
1149}
1150
1151pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1152 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1153 let connections_metric = CONNECTIONS_METRIC
1154 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1155
1156 let connections = server
1157 .connection_pool
1158 .lock()
1159 .connections()
1160 .filter(|connection| !connection.admin)
1161 .count();
1162 connections_metric.set(connections as _);
1163
1164 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1165 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1166 register_int_gauge!(
1167 "shared_projects",
1168 "number of open projects with one or more guests"
1169 )
1170 .unwrap()
1171 });
1172
1173 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1174 shared_projects_metric.set(shared_projects as _);
1175
1176 let encoder = prometheus::TextEncoder::new();
1177 let metric_families = prometheus::gather();
1178 let encoded_metrics = encoder
1179 .encode_to_string(&metric_families)
1180 .map_err(|err| anyhow!("{}", err))?;
1181 Ok(encoded_metrics)
1182}
1183
/// Cleans up after a connection drops.
///
/// Immediately:
/// - disconnects the peer;
/// - removes the connection from the connection pool (an error here is returned);
/// - records the loss via `db.connection_lost` (an error there is only logged).
///
/// Then it waits for whichever of these completes first. `select_biased!` polls
/// the arms in order, so the timeout arm wins when both are ready.
/// - `RECONNECT_TIMEOUT` elapses: if `for_user()` yields a user session, that
///   user leaves their room and channel buffers (errors logged). If the pool
///   no longer reports the user as online, `decline_call(None, ..)` is invoked
///   (presumably declining any pending incoming call — confirm) and the
///   returned room, if any, is broadcast. Finally the user's contacts are
///   updated.
/// - `teardown.changed()` resolves: nothing further is done.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // Best-effort: a database failure here must not prevent the cleanup below.
    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            // The cleanup below only runs for sessions that `for_user()` maps to a user session.
            if let Some(session) = session.for_user() {
                log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
                leave_room_for_session(&session, session.connection_id).await.trace_err();
                leave_channel_buffers_for_session(&session)
                    .await
                    .trace_err();

                // Only touch calls when no other connection keeps the user online.
                if !session
                    .connection_pool()
                    .await
                    .is_user_online(session.user_id())
                {
                    let db = session.db().await;
                    if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                        room_updated(&room, &session.peer);
                    }
                }

                update_user_contacts(session.user_id(), &session).await?;
            }
        }
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1231
1232/// Acknowledges a ping from a client, used to keep the connection alive.
1233async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1234 response.send(proto::Ack {})?;
1235 Ok(())
1236}
1237
1238/// Creates a new room for calling (outside of channels)
1239async fn create_room(
1240 _request: proto::CreateRoom,
1241 response: Response<proto::CreateRoom>,
1242 session: UserSession,
1243) -> Result<()> {
1244 let live_kit_room = nanoid::nanoid!(30);
1245
1246 let live_kit_connection_info = util::maybe!(async {
1247 let live_kit = session.live_kit_client.as_ref();
1248 let live_kit = live_kit?;
1249 let user_id = session.user_id().to_string();
1250
1251 let token = live_kit
1252 .room_token(&live_kit_room, &user_id.to_string())
1253 .trace_err()?;
1254
1255 Some(proto::LiveKitConnectionInfo {
1256 server_url: live_kit.url().into(),
1257 token,
1258 can_publish: true,
1259 })
1260 })
1261 .await;
1262
1263 let room = session
1264 .db()
1265 .await
1266 .create_room(session.user_id(), session.connection_id, &live_kit_room)
1267 .await?;
1268
1269 response.send(proto::CreateRoomResponse {
1270 room: Some(room.clone()),
1271 live_kit_connection_info,
1272 })?;
1273
1274 update_user_contacts(session.user_id(), &session).await?;
1275 Ok(())
1276}
1277
1278/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1279async fn join_room(
1280 request: proto::JoinRoom,
1281 response: Response<proto::JoinRoom>,
1282 session: UserSession,
1283) -> Result<()> {
1284 let room_id = RoomId::from_proto(request.id);
1285
1286 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1287
1288 if let Some(channel_id) = channel_id {
1289 return join_channel_internal(channel_id, Box::new(response), session).await;
1290 }
1291
1292 let joined_room = {
1293 let room = session
1294 .db()
1295 .await
1296 .join_room(room_id, session.user_id(), session.connection_id)
1297 .await?;
1298 room_updated(&room.room, &session.peer);
1299 room.into_inner()
1300 };
1301
1302 for connection_id in session
1303 .connection_pool()
1304 .await
1305 .user_connection_ids(session.user_id())
1306 {
1307 session
1308 .peer
1309 .send(
1310 connection_id,
1311 proto::CallCanceled {
1312 room_id: room_id.to_proto(),
1313 },
1314 )
1315 .trace_err();
1316 }
1317
1318 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1319 if let Some(token) = live_kit
1320 .room_token(
1321 &joined_room.room.live_kit_room,
1322 &session.user_id().to_string(),
1323 )
1324 .trace_err()
1325 {
1326 Some(proto::LiveKitConnectionInfo {
1327 server_url: live_kit.url().into(),
1328 token,
1329 can_publish: true,
1330 })
1331 } else {
1332 None
1333 }
1334 } else {
1335 None
1336 };
1337
1338 response.send(proto::JoinRoomResponse {
1339 room: Some(joined_room.room),
1340 channel_id: None,
1341 live_kit_connection_info,
1342 })?;
1343
1344 update_user_contacts(session.user_id(), &session).await?;
1345 Ok(())
1346}
1347
1348/// Rejoin room is used to reconnect to a room after connection errors.
1349async fn rejoin_room(
1350 request: proto::RejoinRoom,
1351 response: Response<proto::RejoinRoom>,
1352 session: UserSession,
1353) -> Result<()> {
1354 let room;
1355 let channel;
1356 {
1357 let mut rejoined_room = session
1358 .db()
1359 .await
1360 .rejoin_room(request, session.user_id(), session.connection_id)
1361 .await?;
1362
1363 response.send(proto::RejoinRoomResponse {
1364 room: Some(rejoined_room.room.clone()),
1365 reshared_projects: rejoined_room
1366 .reshared_projects
1367 .iter()
1368 .map(|project| proto::ResharedProject {
1369 id: project.id.to_proto(),
1370 collaborators: project
1371 .collaborators
1372 .iter()
1373 .map(|collaborator| collaborator.to_proto())
1374 .collect(),
1375 })
1376 .collect(),
1377 rejoined_projects: rejoined_room
1378 .rejoined_projects
1379 .iter()
1380 .map(|rejoined_project| proto::RejoinedProject {
1381 id: rejoined_project.id.to_proto(),
1382 worktrees: rejoined_project
1383 .worktrees
1384 .iter()
1385 .map(|worktree| proto::WorktreeMetadata {
1386 id: worktree.id,
1387 root_name: worktree.root_name.clone(),
1388 visible: worktree.visible,
1389 abs_path: worktree.abs_path.clone(),
1390 })
1391 .collect(),
1392 collaborators: rejoined_project
1393 .collaborators
1394 .iter()
1395 .map(|collaborator| collaborator.to_proto())
1396 .collect(),
1397 language_servers: rejoined_project.language_servers.clone(),
1398 })
1399 .collect(),
1400 })?;
1401 room_updated(&rejoined_room.room, &session.peer);
1402
1403 for project in &rejoined_room.reshared_projects {
1404 for collaborator in &project.collaborators {
1405 session
1406 .peer
1407 .send(
1408 collaborator.connection_id,
1409 proto::UpdateProjectCollaborator {
1410 project_id: project.id.to_proto(),
1411 old_peer_id: Some(project.old_connection_id.into()),
1412 new_peer_id: Some(session.connection_id.into()),
1413 },
1414 )
1415 .trace_err();
1416 }
1417
1418 broadcast(
1419 Some(session.connection_id),
1420 project
1421 .collaborators
1422 .iter()
1423 .map(|collaborator| collaborator.connection_id),
1424 |connection_id| {
1425 session.peer.forward_send(
1426 session.connection_id,
1427 connection_id,
1428 proto::UpdateProject {
1429 project_id: project.id.to_proto(),
1430 worktrees: project.worktrees.clone(),
1431 },
1432 )
1433 },
1434 );
1435 }
1436
1437 for project in &rejoined_room.rejoined_projects {
1438 for collaborator in &project.collaborators {
1439 session
1440 .peer
1441 .send(
1442 collaborator.connection_id,
1443 proto::UpdateProjectCollaborator {
1444 project_id: project.id.to_proto(),
1445 old_peer_id: Some(project.old_connection_id.into()),
1446 new_peer_id: Some(session.connection_id.into()),
1447 },
1448 )
1449 .trace_err();
1450 }
1451 }
1452
1453 for project in &mut rejoined_room.rejoined_projects {
1454 for worktree in mem::take(&mut project.worktrees) {
1455 #[cfg(any(test, feature = "test-support"))]
1456 const MAX_CHUNK_SIZE: usize = 2;
1457 #[cfg(not(any(test, feature = "test-support")))]
1458 const MAX_CHUNK_SIZE: usize = 256;
1459
1460 // Stream this worktree's entries.
1461 let message = proto::UpdateWorktree {
1462 project_id: project.id.to_proto(),
1463 worktree_id: worktree.id,
1464 abs_path: worktree.abs_path.clone(),
1465 root_name: worktree.root_name,
1466 updated_entries: worktree.updated_entries,
1467 removed_entries: worktree.removed_entries,
1468 scan_id: worktree.scan_id,
1469 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1470 updated_repositories: worktree.updated_repositories,
1471 removed_repositories: worktree.removed_repositories,
1472 };
1473 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1474 session.peer.send(session.connection_id, update.clone())?;
1475 }
1476
1477 // Stream this worktree's diagnostics.
1478 for summary in worktree.diagnostic_summaries {
1479 session.peer.send(
1480 session.connection_id,
1481 proto::UpdateDiagnosticSummary {
1482 project_id: project.id.to_proto(),
1483 worktree_id: worktree.id,
1484 summary: Some(summary),
1485 },
1486 )?;
1487 }
1488
1489 for settings_file in worktree.settings_files {
1490 session.peer.send(
1491 session.connection_id,
1492 proto::UpdateWorktreeSettings {
1493 project_id: project.id.to_proto(),
1494 worktree_id: worktree.id,
1495 path: settings_file.path,
1496 content: Some(settings_file.content),
1497 },
1498 )?;
1499 }
1500 }
1501
1502 for language_server in &project.language_servers {
1503 session.peer.send(
1504 session.connection_id,
1505 proto::UpdateLanguageServer {
1506 project_id: project.id.to_proto(),
1507 language_server_id: language_server.id,
1508 variant: Some(
1509 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1510 proto::LspDiskBasedDiagnosticsUpdated {},
1511 ),
1512 ),
1513 },
1514 )?;
1515 }
1516 }
1517
1518 let rejoined_room = rejoined_room.into_inner();
1519
1520 room = rejoined_room.room;
1521 channel = rejoined_room.channel;
1522 }
1523
1524 if let Some(channel) = channel {
1525 channel_updated(
1526 &channel,
1527 &room,
1528 &session.peer,
1529 &*session.connection_pool().await,
1530 );
1531 }
1532
1533 update_user_contacts(session.user_id(), &session).await?;
1534 Ok(())
1535}
1536
1537/// leave room disconnects from the room.
1538async fn leave_room(
1539 _: proto::LeaveRoom,
1540 response: Response<proto::LeaveRoom>,
1541 session: UserSession,
1542) -> Result<()> {
1543 leave_room_for_session(&session, session.connection_id).await?;
1544 response.send(proto::Ack {})?;
1545 Ok(())
1546}
1547
1548/// Updates the permissions of someone else in the room.
1549async fn set_room_participant_role(
1550 request: proto::SetRoomParticipantRole,
1551 response: Response<proto::SetRoomParticipantRole>,
1552 session: UserSession,
1553) -> Result<()> {
1554 let user_id = UserId::from_proto(request.user_id);
1555 let role = ChannelRole::from(request.role());
1556
1557 let (live_kit_room, can_publish) = {
1558 let room = session
1559 .db()
1560 .await
1561 .set_room_participant_role(
1562 session.user_id(),
1563 RoomId::from_proto(request.room_id),
1564 user_id,
1565 role,
1566 )
1567 .await?;
1568
1569 let live_kit_room = room.live_kit_room.clone();
1570 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1571 room_updated(&room, &session.peer);
1572 (live_kit_room, can_publish)
1573 };
1574
1575 if let Some(live_kit) = session.live_kit_client.as_ref() {
1576 live_kit
1577 .update_participant(
1578 live_kit_room.clone(),
1579 request.user_id.to_string(),
1580 live_kit_server::proto::ParticipantPermission {
1581 can_subscribe: true,
1582 can_publish,
1583 can_publish_data: can_publish,
1584 hidden: false,
1585 recorder: false,
1586 },
1587 )
1588 .await
1589 .trace_err();
1590 }
1591
1592 response.send(proto::Ack {})?;
1593 Ok(())
1594}
1595
1596/// Call someone else into the current room
1597async fn call(
1598 request: proto::Call,
1599 response: Response<proto::Call>,
1600 session: UserSession,
1601) -> Result<()> {
1602 let room_id = RoomId::from_proto(request.room_id);
1603 let calling_user_id = session.user_id();
1604 let calling_connection_id = session.connection_id;
1605 let called_user_id = UserId::from_proto(request.called_user_id);
1606 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1607 if !session
1608 .db()
1609 .await
1610 .has_contact(calling_user_id, called_user_id)
1611 .await?
1612 {
1613 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1614 }
1615
1616 let incoming_call = {
1617 let (room, incoming_call) = &mut *session
1618 .db()
1619 .await
1620 .call(
1621 room_id,
1622 calling_user_id,
1623 calling_connection_id,
1624 called_user_id,
1625 initial_project_id,
1626 )
1627 .await?;
1628 room_updated(&room, &session.peer);
1629 mem::take(incoming_call)
1630 };
1631 update_user_contacts(called_user_id, &session).await?;
1632
1633 let mut calls = session
1634 .connection_pool()
1635 .await
1636 .user_connection_ids(called_user_id)
1637 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1638 .collect::<FuturesUnordered<_>>();
1639
1640 while let Some(call_response) = calls.next().await {
1641 match call_response.as_ref() {
1642 Ok(_) => {
1643 response.send(proto::Ack {})?;
1644 return Ok(());
1645 }
1646 Err(_) => {
1647 call_response.trace_err();
1648 }
1649 }
1650 }
1651
1652 {
1653 let room = session
1654 .db()
1655 .await
1656 .call_failed(room_id, called_user_id)
1657 .await?;
1658 room_updated(&room, &session.peer);
1659 }
1660 update_user_contacts(called_user_id, &session).await?;
1661
1662 Err(anyhow!("failed to ring user"))?
1663}
1664
1665/// Cancel an outgoing call.
1666async fn cancel_call(
1667 request: proto::CancelCall,
1668 response: Response<proto::CancelCall>,
1669 session: UserSession,
1670) -> Result<()> {
1671 let called_user_id = UserId::from_proto(request.called_user_id);
1672 let room_id = RoomId::from_proto(request.room_id);
1673 {
1674 let room = session
1675 .db()
1676 .await
1677 .cancel_call(room_id, session.connection_id, called_user_id)
1678 .await?;
1679 room_updated(&room, &session.peer);
1680 }
1681
1682 for connection_id in session
1683 .connection_pool()
1684 .await
1685 .user_connection_ids(called_user_id)
1686 {
1687 session
1688 .peer
1689 .send(
1690 connection_id,
1691 proto::CallCanceled {
1692 room_id: room_id.to_proto(),
1693 },
1694 )
1695 .trace_err();
1696 }
1697 response.send(proto::Ack {})?;
1698
1699 update_user_contacts(called_user_id, &session).await?;
1700 Ok(())
1701}
1702
1703/// Decline an incoming call.
1704async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1705 let room_id = RoomId::from_proto(message.room_id);
1706 {
1707 let room = session
1708 .db()
1709 .await
1710 .decline_call(Some(room_id), session.user_id())
1711 .await?
1712 .ok_or_else(|| anyhow!("failed to decline call"))?;
1713 room_updated(&room, &session.peer);
1714 }
1715
1716 for connection_id in session
1717 .connection_pool()
1718 .await
1719 .user_connection_ids(session.user_id())
1720 {
1721 session
1722 .peer
1723 .send(
1724 connection_id,
1725 proto::CallCanceled {
1726 room_id: room_id.to_proto(),
1727 },
1728 )
1729 .trace_err();
1730 }
1731 update_user_contacts(session.user_id(), &session).await?;
1732 Ok(())
1733}
1734
1735/// Updates other participants in the room with your current location.
1736async fn update_participant_location(
1737 request: proto::UpdateParticipantLocation,
1738 response: Response<proto::UpdateParticipantLocation>,
1739 session: UserSession,
1740) -> Result<()> {
1741 let room_id = RoomId::from_proto(request.room_id);
1742 let location = request
1743 .location
1744 .ok_or_else(|| anyhow!("invalid location"))?;
1745
1746 let db = session.db().await;
1747 let room = db
1748 .update_room_participant_location(room_id, session.connection_id, location)
1749 .await?;
1750
1751 room_updated(&room, &session.peer);
1752 response.send(proto::Ack {})?;
1753 Ok(())
1754}
1755
1756/// Share a project into the room.
1757async fn share_project(
1758 request: proto::ShareProject,
1759 response: Response<proto::ShareProject>,
1760 session: Session,
1761) -> Result<()> {
1762 let (project_id, room) = &*session
1763 .db()
1764 .await
1765 .share_project(
1766 RoomId::from_proto(request.room_id),
1767 session.connection_id,
1768 &request.worktrees,
1769 )
1770 .await?;
1771 response.send(proto::ShareProjectResponse {
1772 project_id: project_id.to_proto(),
1773 })?;
1774 room_updated(&room, &session.peer);
1775
1776 Ok(())
1777}
1778
1779/// Unshare a project from the room.
1780async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1781 let project_id = ProjectId::from_proto(message.project_id);
1782
1783 let (room, guest_connection_ids) = &*session
1784 .db()
1785 .await
1786 .unshare_project(project_id, session.connection_id)
1787 .await?;
1788
1789 broadcast(
1790 Some(session.connection_id),
1791 guest_connection_ids.iter().copied(),
1792 |conn_id| session.peer.send(conn_id, message.clone()),
1793 );
1794 room_updated(&room, &session.peer);
1795
1796 Ok(())
1797}
1798
1799/// Join someone elses shared project.
1800async fn join_project(
1801 request: proto::JoinProject,
1802 response: Response<proto::JoinProject>,
1803 session: UserSession,
1804) -> Result<()> {
1805 let project_id = ProjectId::from_proto(request.project_id);
1806
1807 tracing::info!(%project_id, "join project");
1808
1809 let (project, replica_id) = &mut *session
1810 .db()
1811 .await
1812 .join_project_in_room(project_id, session.connection_id)
1813 .await?;
1814
1815 join_project_internal(response, session, project, replica_id)
1816}
1817
/// Abstracts over the responders for `JoinProject` and `JoinHostedProject`.
/// Both reply with a `proto::JoinProjectResponse`, so `join_project_internal`
/// can serve either request.
trait JoinProjectInternalResponse {
    /// Sends `result` as the reply to the original request.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
// Forwards to `Response::<proto::JoinProject>::send`.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinProject>::send(self, result)
    }
}
// Forwards to `Response::<proto::JoinHostedProject>::send`.
impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinHostedProject>::send(self, result)
    }
}
1831
1832fn join_project_internal(
1833 response: impl JoinProjectInternalResponse,
1834 session: UserSession,
1835 project: &mut Project,
1836 replica_id: &ReplicaId,
1837) -> Result<()> {
1838 let collaborators = project
1839 .collaborators
1840 .iter()
1841 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1842 .map(|collaborator| collaborator.to_proto())
1843 .collect::<Vec<_>>();
1844 let project_id = project.id;
1845 let guest_user_id = session.user_id();
1846
1847 let worktrees = project
1848 .worktrees
1849 .iter()
1850 .map(|(id, worktree)| proto::WorktreeMetadata {
1851 id: *id,
1852 root_name: worktree.root_name.clone(),
1853 visible: worktree.visible,
1854 abs_path: worktree.abs_path.clone(),
1855 })
1856 .collect::<Vec<_>>();
1857
1858 for collaborator in &collaborators {
1859 session
1860 .peer
1861 .send(
1862 collaborator.peer_id.unwrap().into(),
1863 proto::AddProjectCollaborator {
1864 project_id: project_id.to_proto(),
1865 collaborator: Some(proto::Collaborator {
1866 peer_id: Some(session.connection_id.into()),
1867 replica_id: replica_id.0 as u32,
1868 user_id: guest_user_id.to_proto(),
1869 }),
1870 },
1871 )
1872 .trace_err();
1873 }
1874
1875 // First, we send the metadata associated with each worktree.
1876 response.send(proto::JoinProjectResponse {
1877 project_id: project.id.0 as u64,
1878 worktrees: worktrees.clone(),
1879 replica_id: replica_id.0 as u32,
1880 collaborators: collaborators.clone(),
1881 language_servers: project.language_servers.clone(),
1882 role: project.role.into(), // todo
1883 })?;
1884
1885 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1886 #[cfg(any(test, feature = "test-support"))]
1887 const MAX_CHUNK_SIZE: usize = 2;
1888 #[cfg(not(any(test, feature = "test-support")))]
1889 const MAX_CHUNK_SIZE: usize = 256;
1890
1891 // Stream this worktree's entries.
1892 let message = proto::UpdateWorktree {
1893 project_id: project_id.to_proto(),
1894 worktree_id,
1895 abs_path: worktree.abs_path.clone(),
1896 root_name: worktree.root_name,
1897 updated_entries: worktree.entries,
1898 removed_entries: Default::default(),
1899 scan_id: worktree.scan_id,
1900 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1901 updated_repositories: worktree.repository_entries.into_values().collect(),
1902 removed_repositories: Default::default(),
1903 };
1904 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1905 session.peer.send(session.connection_id, update.clone())?;
1906 }
1907
1908 // Stream this worktree's diagnostics.
1909 for summary in worktree.diagnostic_summaries {
1910 session.peer.send(
1911 session.connection_id,
1912 proto::UpdateDiagnosticSummary {
1913 project_id: project_id.to_proto(),
1914 worktree_id: worktree.id,
1915 summary: Some(summary),
1916 },
1917 )?;
1918 }
1919
1920 for settings_file in worktree.settings_files {
1921 session.peer.send(
1922 session.connection_id,
1923 proto::UpdateWorktreeSettings {
1924 project_id: project_id.to_proto(),
1925 worktree_id: worktree.id,
1926 path: settings_file.path,
1927 content: Some(settings_file.content),
1928 },
1929 )?;
1930 }
1931 }
1932
1933 for language_server in &project.language_servers {
1934 session.peer.send(
1935 session.connection_id,
1936 proto::UpdateLanguageServer {
1937 project_id: project_id.to_proto(),
1938 language_server_id: language_server.id,
1939 variant: Some(
1940 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1941 proto::LspDiskBasedDiagnosticsUpdated {},
1942 ),
1943 ),
1944 },
1945 )?;
1946 }
1947
1948 Ok(())
1949}
1950
1951/// Leave someone elses shared project.
1952async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
1953 let sender_id = session.connection_id;
1954 let project_id = ProjectId::from_proto(request.project_id);
1955 let db = session.db().await;
1956 if db.is_hosted_project(project_id).await? {
1957 let project = db.leave_hosted_project(project_id, sender_id).await?;
1958 project_left(&project, &session);
1959 return Ok(());
1960 }
1961
1962 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1963 tracing::info!(
1964 %project_id,
1965 host_user_id = ?project.host_user_id,
1966 host_connection_id = ?project.host_connection_id,
1967 "leave project"
1968 );
1969
1970 project_left(&project, &session);
1971 room_updated(&room, &session.peer);
1972
1973 Ok(())
1974}
1975
1976async fn join_hosted_project(
1977 request: proto::JoinHostedProject,
1978 response: Response<proto::JoinHostedProject>,
1979 session: UserSession,
1980) -> Result<()> {
1981 let (mut project, replica_id) = session
1982 .db()
1983 .await
1984 .join_hosted_project(
1985 ProjectId(request.project_id as i32),
1986 session.user_id(),
1987 session.connection_id,
1988 )
1989 .await?;
1990
1991 join_project_internal(response, session, &mut project, &replica_id)
1992}
1993
1994/// Updates other participants with changes to the project
1995async fn update_project(
1996 request: proto::UpdateProject,
1997 response: Response<proto::UpdateProject>,
1998 session: Session,
1999) -> Result<()> {
2000 let project_id = ProjectId::from_proto(request.project_id);
2001 let (room, guest_connection_ids) = &*session
2002 .db()
2003 .await
2004 .update_project(project_id, session.connection_id, &request.worktrees)
2005 .await?;
2006 broadcast(
2007 Some(session.connection_id),
2008 guest_connection_ids.iter().copied(),
2009 |connection_id| {
2010 session
2011 .peer
2012 .forward_send(session.connection_id, connection_id, request.clone())
2013 },
2014 );
2015 room_updated(&room, &session.peer);
2016 response.send(proto::Ack {})?;
2017
2018 Ok(())
2019}
2020
2021/// Updates other participants with changes to the worktree
2022async fn update_worktree(
2023 request: proto::UpdateWorktree,
2024 response: Response<proto::UpdateWorktree>,
2025 session: Session,
2026) -> Result<()> {
2027 let guest_connection_ids = session
2028 .db()
2029 .await
2030 .update_worktree(&request, session.connection_id)
2031 .await?;
2032
2033 broadcast(
2034 Some(session.connection_id),
2035 guest_connection_ids.iter().copied(),
2036 |connection_id| {
2037 session
2038 .peer
2039 .forward_send(session.connection_id, connection_id, request.clone())
2040 },
2041 );
2042 response.send(proto::Ack {})?;
2043 Ok(())
2044}
2045
2046/// Updates other participants with changes to the diagnostics
2047async fn update_diagnostic_summary(
2048 message: proto::UpdateDiagnosticSummary,
2049 session: Session,
2050) -> Result<()> {
2051 let guest_connection_ids = session
2052 .db()
2053 .await
2054 .update_diagnostic_summary(&message, session.connection_id)
2055 .await?;
2056
2057 broadcast(
2058 Some(session.connection_id),
2059 guest_connection_ids.iter().copied(),
2060 |connection_id| {
2061 session
2062 .peer
2063 .forward_send(session.connection_id, connection_id, message.clone())
2064 },
2065 );
2066
2067 Ok(())
2068}
2069
2070/// Updates other participants with changes to the worktree settings
2071async fn update_worktree_settings(
2072 message: proto::UpdateWorktreeSettings,
2073 session: Session,
2074) -> Result<()> {
2075 let guest_connection_ids = session
2076 .db()
2077 .await
2078 .update_worktree_settings(&message, session.connection_id)
2079 .await?;
2080
2081 broadcast(
2082 Some(session.connection_id),
2083 guest_connection_ids.iter().copied(),
2084 |connection_id| {
2085 session
2086 .peer
2087 .forward_send(session.connection_id, connection_id, message.clone())
2088 },
2089 );
2090
2091 Ok(())
2092}
2093
2094/// Notify other participants that a language server has started.
2095async fn start_language_server(
2096 request: proto::StartLanguageServer,
2097 session: Session,
2098) -> Result<()> {
2099 let guest_connection_ids = session
2100 .db()
2101 .await
2102 .start_language_server(&request, session.connection_id)
2103 .await?;
2104
2105 broadcast(
2106 Some(session.connection_id),
2107 guest_connection_ids.iter().copied(),
2108 |connection_id| {
2109 session
2110 .peer
2111 .forward_send(session.connection_id, connection_id, request.clone())
2112 },
2113 );
2114 Ok(())
2115}
2116
2117/// Notify other participants that a language server has changed.
2118async fn update_language_server(
2119 request: proto::UpdateLanguageServer,
2120 session: Session,
2121) -> Result<()> {
2122 let project_id = ProjectId::from_proto(request.project_id);
2123 let project_connection_ids = session
2124 .db()
2125 .await
2126 .project_connection_ids(project_id, session.connection_id)
2127 .await?;
2128 broadcast(
2129 Some(session.connection_id),
2130 project_connection_ids.iter().copied(),
2131 |connection_id| {
2132 session
2133 .peer
2134 .forward_send(session.connection_id, connection_id, request.clone())
2135 },
2136 );
2137 Ok(())
2138}
2139
2140/// forward a project request to the host. These requests should be read only
2141/// as guests are allowed to send them.
2142async fn forward_read_only_project_request<T>(
2143 request: T,
2144 response: Response<T>,
2145 session: Session,
2146) -> Result<()>
2147where
2148 T: EntityMessage + RequestMessage,
2149{
2150 let project_id = ProjectId::from_proto(request.remote_entity_id());
2151 let host_connection_id = session
2152 .db()
2153 .await
2154 .host_for_read_only_project_request(project_id, session.connection_id)
2155 .await?;
2156 let payload = session
2157 .peer
2158 .forward_request(session.connection_id, host_connection_id, request)
2159 .await?;
2160 response.send(payload)?;
2161 Ok(())
2162}
2163
2164/// forward a project request to the host. These requests are disallowed
2165/// for guests.
2166async fn forward_mutating_project_request<T>(
2167 request: T,
2168 response: Response<T>,
2169 session: Session,
2170) -> Result<()>
2171where
2172 T: EntityMessage + RequestMessage,
2173{
2174 let project_id = ProjectId::from_proto(request.remote_entity_id());
2175 let host_connection_id = session
2176 .db()
2177 .await
2178 .host_for_mutating_project_request(project_id, session.connection_id)
2179 .await?;
2180 let payload = session
2181 .peer
2182 .forward_request(session.connection_id, host_connection_id, request)
2183 .await?;
2184 response.send(payload)?;
2185 Ok(())
2186}
2187
2188/// Notify other participants that a new buffer has been created
2189async fn create_buffer_for_peer(
2190 request: proto::CreateBufferForPeer,
2191 session: Session,
2192) -> Result<()> {
2193 session
2194 .db()
2195 .await
2196 .check_user_is_project_host(
2197 ProjectId::from_proto(request.project_id),
2198 session.connection_id,
2199 )
2200 .await?;
2201 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2202 session
2203 .peer
2204 .forward_send(session.connection_id, peer_id.into(), request)?;
2205 Ok(())
2206}
2207
2208/// Notify other participants that a buffer has been updated. This is
2209/// allowed for guests as long as the update is limited to selections.
2210async fn update_buffer(
2211 request: proto::UpdateBuffer,
2212 response: Response<proto::UpdateBuffer>,
2213 session: Session,
2214) -> Result<()> {
2215 let project_id = ProjectId::from_proto(request.project_id);
2216 let mut guest_connection_ids;
2217 let mut host_connection_id = None;
2218
2219 let mut requires_write_permission = false;
2220
2221 for op in request.operations.iter() {
2222 match op.variant {
2223 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2224 Some(_) => requires_write_permission = true,
2225 }
2226 }
2227
2228 {
2229 let collaborators = session
2230 .db()
2231 .await
2232 .project_collaborators_for_buffer_update(
2233 project_id,
2234 session.connection_id,
2235 requires_write_permission,
2236 )
2237 .await?;
2238 guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
2239 for collaborator in collaborators.iter() {
2240 if collaborator.is_host {
2241 host_connection_id = Some(collaborator.connection_id);
2242 } else {
2243 guest_connection_ids.push(collaborator.connection_id);
2244 }
2245 }
2246 }
2247 let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
2248
2249 broadcast(
2250 Some(session.connection_id),
2251 guest_connection_ids,
2252 |connection_id| {
2253 session
2254 .peer
2255 .forward_send(session.connection_id, connection_id, request.clone())
2256 },
2257 );
2258 if host_connection_id != session.connection_id {
2259 session
2260 .peer
2261 .forward_request(session.connection_id, host_connection_id, request.clone())
2262 .await?;
2263 }
2264
2265 response.send(proto::Ack {})?;
2266 Ok(())
2267}
2268
2269/// Notify other participants that a project has been updated.
2270async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2271 request: T,
2272 session: Session,
2273) -> Result<()> {
2274 let project_id = ProjectId::from_proto(request.remote_entity_id());
2275 let project_connection_ids = session
2276 .db()
2277 .await
2278 .project_connection_ids(project_id, session.connection_id)
2279 .await?;
2280
2281 broadcast(
2282 Some(session.connection_id),
2283 project_connection_ids.iter().copied(),
2284 |connection_id| {
2285 session
2286 .peer
2287 .forward_send(session.connection_id, connection_id, request.clone())
2288 },
2289 );
2290 Ok(())
2291}
2292
2293/// Start following another user in a call.
2294async fn follow(
2295 request: proto::Follow,
2296 response: Response<proto::Follow>,
2297 session: UserSession,
2298) -> Result<()> {
2299 let room_id = RoomId::from_proto(request.room_id);
2300 let project_id = request.project_id.map(ProjectId::from_proto);
2301 let leader_id = request
2302 .leader_id
2303 .ok_or_else(|| anyhow!("invalid leader id"))?
2304 .into();
2305 let follower_id = session.connection_id;
2306
2307 session
2308 .db()
2309 .await
2310 .check_room_participants(room_id, leader_id, session.connection_id)
2311 .await?;
2312
2313 let response_payload = session
2314 .peer
2315 .forward_request(session.connection_id, leader_id, request)
2316 .await?;
2317 response.send(response_payload)?;
2318
2319 if let Some(project_id) = project_id {
2320 let room = session
2321 .db()
2322 .await
2323 .follow(room_id, project_id, leader_id, follower_id)
2324 .await?;
2325 room_updated(&room, &session.peer);
2326 }
2327
2328 Ok(())
2329}
2330
2331/// Stop following another user in a call.
2332async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
2333 let room_id = RoomId::from_proto(request.room_id);
2334 let project_id = request.project_id.map(ProjectId::from_proto);
2335 let leader_id = request
2336 .leader_id
2337 .ok_or_else(|| anyhow!("invalid leader id"))?
2338 .into();
2339 let follower_id = session.connection_id;
2340
2341 session
2342 .db()
2343 .await
2344 .check_room_participants(room_id, leader_id, session.connection_id)
2345 .await?;
2346
2347 session
2348 .peer
2349 .forward_send(session.connection_id, leader_id, request)?;
2350
2351 if let Some(project_id) = project_id {
2352 let room = session
2353 .db()
2354 .await
2355 .unfollow(room_id, project_id, leader_id, follower_id)
2356 .await?;
2357 room_updated(&room, &session.peer);
2358 }
2359
2360 Ok(())
2361}
2362
2363/// Notify everyone following you of your current location.
2364async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
2365 let room_id = RoomId::from_proto(request.room_id);
2366 let database = session.db.lock().await;
2367
2368 let connection_ids = if let Some(project_id) = request.project_id {
2369 let project_id = ProjectId::from_proto(project_id);
2370 database
2371 .project_connection_ids(project_id, session.connection_id)
2372 .await?
2373 } else {
2374 database
2375 .room_connection_ids(room_id, session.connection_id)
2376 .await?
2377 };
2378
2379 // For now, don't send view update messages back to that view's current leader.
2380 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2381 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2382 _ => None,
2383 });
2384
2385 for connection_id in connection_ids.iter().cloned() {
2386 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2387 session
2388 .peer
2389 .forward_send(session.connection_id, connection_id, request.clone())?;
2390 }
2391 }
2392 Ok(())
2393}
2394
2395/// Get public data about users.
2396async fn get_users(
2397 request: proto::GetUsers,
2398 response: Response<proto::GetUsers>,
2399 session: Session,
2400) -> Result<()> {
2401 let user_ids = request
2402 .user_ids
2403 .into_iter()
2404 .map(UserId::from_proto)
2405 .collect();
2406 let users = session
2407 .db()
2408 .await
2409 .get_users_by_ids(user_ids)
2410 .await?
2411 .into_iter()
2412 .map(|user| proto::User {
2413 id: user.id.to_proto(),
2414 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2415 github_login: user.github_login,
2416 })
2417 .collect();
2418 response.send(proto::UsersResponse { users })?;
2419 Ok(())
2420}
2421
2422/// Search for users (to invite) buy Github login
2423async fn fuzzy_search_users(
2424 request: proto::FuzzySearchUsers,
2425 response: Response<proto::FuzzySearchUsers>,
2426 session: UserSession,
2427) -> Result<()> {
2428 let query = request.query;
2429 let users = match query.len() {
2430 0 => vec![],
2431 1 | 2 => session
2432 .db()
2433 .await
2434 .get_user_by_github_login(&query)
2435 .await?
2436 .into_iter()
2437 .collect(),
2438 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2439 };
2440 let users = users
2441 .into_iter()
2442 .filter(|user| user.id != session.user_id())
2443 .map(|user| proto::User {
2444 id: user.id.to_proto(),
2445 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2446 github_login: user.github_login,
2447 })
2448 .collect();
2449 response.send(proto::UsersResponse { users })?;
2450 Ok(())
2451}
2452
2453/// Send a contact request to another user.
2454async fn request_contact(
2455 request: proto::RequestContact,
2456 response: Response<proto::RequestContact>,
2457 session: UserSession,
2458) -> Result<()> {
2459 let requester_id = session.user_id();
2460 let responder_id = UserId::from_proto(request.responder_id);
2461 if requester_id == responder_id {
2462 return Err(anyhow!("cannot add yourself as a contact"))?;
2463 }
2464
2465 let notifications = session
2466 .db()
2467 .await
2468 .send_contact_request(requester_id, responder_id)
2469 .await?;
2470
2471 // Update outgoing contact requests of requester
2472 let mut update = proto::UpdateContacts::default();
2473 update.outgoing_requests.push(responder_id.to_proto());
2474 for connection_id in session
2475 .connection_pool()
2476 .await
2477 .user_connection_ids(requester_id)
2478 {
2479 session.peer.send(connection_id, update.clone())?;
2480 }
2481
2482 // Update incoming contact requests of responder
2483 let mut update = proto::UpdateContacts::default();
2484 update
2485 .incoming_requests
2486 .push(proto::IncomingContactRequest {
2487 requester_id: requester_id.to_proto(),
2488 });
2489 let connection_pool = session.connection_pool().await;
2490 for connection_id in connection_pool.user_connection_ids(responder_id) {
2491 session.peer.send(connection_id, update.clone())?;
2492 }
2493
2494 send_notifications(&connection_pool, &session.peer, notifications);
2495
2496 response.send(proto::Ack {})?;
2497 Ok(())
2498}
2499
2500/// Accept or decline a contact request
2501async fn respond_to_contact_request(
2502 request: proto::RespondToContactRequest,
2503 response: Response<proto::RespondToContactRequest>,
2504 session: UserSession,
2505) -> Result<()> {
2506 let responder_id = session.user_id();
2507 let requester_id = UserId::from_proto(request.requester_id);
2508 let db = session.db().await;
2509 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2510 db.dismiss_contact_notification(responder_id, requester_id)
2511 .await?;
2512 } else {
2513 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2514
2515 let notifications = db
2516 .respond_to_contact_request(responder_id, requester_id, accept)
2517 .await?;
2518 let requester_busy = db.is_user_busy(requester_id).await?;
2519 let responder_busy = db.is_user_busy(responder_id).await?;
2520
2521 let pool = session.connection_pool().await;
2522 // Update responder with new contact
2523 let mut update = proto::UpdateContacts::default();
2524 if accept {
2525 update
2526 .contacts
2527 .push(contact_for_user(requester_id, requester_busy, &pool));
2528 }
2529 update
2530 .remove_incoming_requests
2531 .push(requester_id.to_proto());
2532 for connection_id in pool.user_connection_ids(responder_id) {
2533 session.peer.send(connection_id, update.clone())?;
2534 }
2535
2536 // Update requester with new contact
2537 let mut update = proto::UpdateContacts::default();
2538 if accept {
2539 update
2540 .contacts
2541 .push(contact_for_user(responder_id, responder_busy, &pool));
2542 }
2543 update
2544 .remove_outgoing_requests
2545 .push(responder_id.to_proto());
2546
2547 for connection_id in pool.user_connection_ids(requester_id) {
2548 session.peer.send(connection_id, update.clone())?;
2549 }
2550
2551 send_notifications(&pool, &session.peer, notifications);
2552 }
2553
2554 response.send(proto::Ack {})?;
2555 Ok(())
2556}
2557
2558/// Remove a contact.
2559async fn remove_contact(
2560 request: proto::RemoveContact,
2561 response: Response<proto::RemoveContact>,
2562 session: UserSession,
2563) -> Result<()> {
2564 let requester_id = session.user_id();
2565 let responder_id = UserId::from_proto(request.user_id);
2566 let db = session.db().await;
2567 let (contact_accepted, deleted_notification_id) =
2568 db.remove_contact(requester_id, responder_id).await?;
2569
2570 let pool = session.connection_pool().await;
2571 // Update outgoing contact requests of requester
2572 let mut update = proto::UpdateContacts::default();
2573 if contact_accepted {
2574 update.remove_contacts.push(responder_id.to_proto());
2575 } else {
2576 update
2577 .remove_outgoing_requests
2578 .push(responder_id.to_proto());
2579 }
2580 for connection_id in pool.user_connection_ids(requester_id) {
2581 session.peer.send(connection_id, update.clone())?;
2582 }
2583
2584 // Update incoming contact requests of responder
2585 let mut update = proto::UpdateContacts::default();
2586 if contact_accepted {
2587 update.remove_contacts.push(requester_id.to_proto());
2588 } else {
2589 update
2590 .remove_incoming_requests
2591 .push(requester_id.to_proto());
2592 }
2593 for connection_id in pool.user_connection_ids(responder_id) {
2594 session.peer.send(connection_id, update.clone())?;
2595 if let Some(notification_id) = deleted_notification_id {
2596 session.peer.send(
2597 connection_id,
2598 proto::DeleteNotification {
2599 notification_id: notification_id.to_proto(),
2600 },
2601 )?;
2602 }
2603 }
2604
2605 response.send(proto::Ack {})?;
2606 Ok(())
2607}
2608
2609/// Creates a new channel.
2610async fn create_channel(
2611 request: proto::CreateChannel,
2612 response: Response<proto::CreateChannel>,
2613 session: UserSession,
2614) -> Result<()> {
2615 let db = session.db().await;
2616
2617 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2618 let (channel, membership) = db
2619 .create_channel(&request.name, parent_id, session.user_id())
2620 .await?;
2621
2622 let root_id = channel.root_id();
2623 let channel = Channel::from_model(channel);
2624
2625 response.send(proto::CreateChannelResponse {
2626 channel: Some(channel.to_proto()),
2627 parent_id: request.parent_id,
2628 })?;
2629
2630 let mut connection_pool = session.connection_pool().await;
2631 if let Some(membership) = membership {
2632 connection_pool.subscribe_to_channel(
2633 membership.user_id,
2634 membership.channel_id,
2635 membership.role,
2636 );
2637 let update = proto::UpdateUserChannels {
2638 channel_memberships: vec![proto::ChannelMembership {
2639 channel_id: membership.channel_id.to_proto(),
2640 role: membership.role.into(),
2641 }],
2642 ..Default::default()
2643 };
2644 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2645 session.peer.send(connection_id, update.clone())?;
2646 }
2647 }
2648
2649 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2650 if !role.can_see_channel(channel.visibility) {
2651 continue;
2652 }
2653
2654 let update = proto::UpdateChannels {
2655 channels: vec![channel.to_proto()],
2656 ..Default::default()
2657 };
2658 session.peer.send(connection_id, update.clone())?;
2659 }
2660
2661 Ok(())
2662}
2663
2664/// Delete a channel
2665async fn delete_channel(
2666 request: proto::DeleteChannel,
2667 response: Response<proto::DeleteChannel>,
2668 session: UserSession,
2669) -> Result<()> {
2670 let db = session.db().await;
2671
2672 let channel_id = request.channel_id;
2673 let (root_channel, removed_channels) = db
2674 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2675 .await?;
2676 response.send(proto::Ack {})?;
2677
2678 // Notify members of removed channels
2679 let mut update = proto::UpdateChannels::default();
2680 update
2681 .delete_channels
2682 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2683
2684 let connection_pool = session.connection_pool().await;
2685 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2686 session.peer.send(connection_id, update.clone())?;
2687 }
2688
2689 Ok(())
2690}
2691
2692/// Invite someone to join a channel.
2693async fn invite_channel_member(
2694 request: proto::InviteChannelMember,
2695 response: Response<proto::InviteChannelMember>,
2696 session: UserSession,
2697) -> Result<()> {
2698 let db = session.db().await;
2699 let channel_id = ChannelId::from_proto(request.channel_id);
2700 let invitee_id = UserId::from_proto(request.user_id);
2701 let InviteMemberResult {
2702 channel,
2703 notifications,
2704 } = db
2705 .invite_channel_member(
2706 channel_id,
2707 invitee_id,
2708 session.user_id(),
2709 request.role().into(),
2710 )
2711 .await?;
2712
2713 let update = proto::UpdateChannels {
2714 channel_invitations: vec![channel.to_proto()],
2715 ..Default::default()
2716 };
2717
2718 let connection_pool = session.connection_pool().await;
2719 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2720 session.peer.send(connection_id, update.clone())?;
2721 }
2722
2723 send_notifications(&connection_pool, &session.peer, notifications);
2724
2725 response.send(proto::Ack {})?;
2726 Ok(())
2727}
2728
2729/// remove someone from a channel
2730async fn remove_channel_member(
2731 request: proto::RemoveChannelMember,
2732 response: Response<proto::RemoveChannelMember>,
2733 session: UserSession,
2734) -> Result<()> {
2735 let db = session.db().await;
2736 let channel_id = ChannelId::from_proto(request.channel_id);
2737 let member_id = UserId::from_proto(request.user_id);
2738
2739 let RemoveChannelMemberResult {
2740 membership_update,
2741 notification_id,
2742 } = db
2743 .remove_channel_member(channel_id, member_id, session.user_id())
2744 .await?;
2745
2746 let mut connection_pool = session.connection_pool().await;
2747 notify_membership_updated(
2748 &mut connection_pool,
2749 membership_update,
2750 member_id,
2751 &session.peer,
2752 );
2753 for connection_id in connection_pool.user_connection_ids(member_id) {
2754 if let Some(notification_id) = notification_id {
2755 session
2756 .peer
2757 .send(
2758 connection_id,
2759 proto::DeleteNotification {
2760 notification_id: notification_id.to_proto(),
2761 },
2762 )
2763 .trace_err();
2764 }
2765 }
2766
2767 response.send(proto::Ack {})?;
2768 Ok(())
2769}
2770
2771/// Toggle the channel between public and private.
2772/// Care is taken to maintain the invariant that public channels only descend from public channels,
2773/// (though members-only channels can appear at any point in the hierarchy).
2774async fn set_channel_visibility(
2775 request: proto::SetChannelVisibility,
2776 response: Response<proto::SetChannelVisibility>,
2777 session: UserSession,
2778) -> Result<()> {
2779 let db = session.db().await;
2780 let channel_id = ChannelId::from_proto(request.channel_id);
2781 let visibility = request.visibility().into();
2782
2783 let channel_model = db
2784 .set_channel_visibility(channel_id, visibility, session.user_id())
2785 .await?;
2786 let root_id = channel_model.root_id();
2787 let channel = Channel::from_model(channel_model);
2788
2789 let mut connection_pool = session.connection_pool().await;
2790 for (user_id, role) in connection_pool
2791 .channel_user_ids(root_id)
2792 .collect::<Vec<_>>()
2793 .into_iter()
2794 {
2795 let update = if role.can_see_channel(channel.visibility) {
2796 connection_pool.subscribe_to_channel(user_id, channel_id, role);
2797 proto::UpdateChannels {
2798 channels: vec![channel.to_proto()],
2799 ..Default::default()
2800 }
2801 } else {
2802 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2803 proto::UpdateChannels {
2804 delete_channels: vec![channel.id.to_proto()],
2805 ..Default::default()
2806 }
2807 };
2808
2809 for connection_id in connection_pool.user_connection_ids(user_id) {
2810 session.peer.send(connection_id, update.clone())?;
2811 }
2812 }
2813
2814 response.send(proto::Ack {})?;
2815 Ok(())
2816}
2817
2818/// Alter the role for a user in the channel.
2819async fn set_channel_member_role(
2820 request: proto::SetChannelMemberRole,
2821 response: Response<proto::SetChannelMemberRole>,
2822 session: UserSession,
2823) -> Result<()> {
2824 let db = session.db().await;
2825 let channel_id = ChannelId::from_proto(request.channel_id);
2826 let member_id = UserId::from_proto(request.user_id);
2827 let result = db
2828 .set_channel_member_role(
2829 channel_id,
2830 session.user_id(),
2831 member_id,
2832 request.role().into(),
2833 )
2834 .await?;
2835
2836 match result {
2837 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2838 let mut connection_pool = session.connection_pool().await;
2839 notify_membership_updated(
2840 &mut connection_pool,
2841 membership_update,
2842 member_id,
2843 &session.peer,
2844 )
2845 }
2846 db::SetMemberRoleResult::InviteUpdated(channel) => {
2847 let update = proto::UpdateChannels {
2848 channel_invitations: vec![channel.to_proto()],
2849 ..Default::default()
2850 };
2851
2852 for connection_id in session
2853 .connection_pool()
2854 .await
2855 .user_connection_ids(member_id)
2856 {
2857 session.peer.send(connection_id, update.clone())?;
2858 }
2859 }
2860 }
2861
2862 response.send(proto::Ack {})?;
2863 Ok(())
2864}
2865
2866/// Change the name of a channel
2867async fn rename_channel(
2868 request: proto::RenameChannel,
2869 response: Response<proto::RenameChannel>,
2870 session: UserSession,
2871) -> Result<()> {
2872 let db = session.db().await;
2873 let channel_id = ChannelId::from_proto(request.channel_id);
2874 let channel_model = db
2875 .rename_channel(channel_id, session.user_id(), &request.name)
2876 .await?;
2877 let root_id = channel_model.root_id();
2878 let channel = Channel::from_model(channel_model);
2879
2880 response.send(proto::RenameChannelResponse {
2881 channel: Some(channel.to_proto()),
2882 })?;
2883
2884 let connection_pool = session.connection_pool().await;
2885 let update = proto::UpdateChannels {
2886 channels: vec![channel.to_proto()],
2887 ..Default::default()
2888 };
2889 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2890 if role.can_see_channel(channel.visibility) {
2891 session.peer.send(connection_id, update.clone())?;
2892 }
2893 }
2894
2895 Ok(())
2896}
2897
2898/// Move a channel to a new parent.
2899async fn move_channel(
2900 request: proto::MoveChannel,
2901 response: Response<proto::MoveChannel>,
2902 session: UserSession,
2903) -> Result<()> {
2904 let channel_id = ChannelId::from_proto(request.channel_id);
2905 let to = ChannelId::from_proto(request.to);
2906
2907 let (root_id, channels) = session
2908 .db()
2909 .await
2910 .move_channel(channel_id, to, session.user_id())
2911 .await?;
2912
2913 let connection_pool = session.connection_pool().await;
2914 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2915 let channels = channels
2916 .iter()
2917 .filter_map(|channel| {
2918 if role.can_see_channel(channel.visibility) {
2919 Some(channel.to_proto())
2920 } else {
2921 None
2922 }
2923 })
2924 .collect::<Vec<_>>();
2925 if channels.is_empty() {
2926 continue;
2927 }
2928
2929 let update = proto::UpdateChannels {
2930 channels,
2931 ..Default::default()
2932 };
2933
2934 session.peer.send(connection_id, update.clone())?;
2935 }
2936
2937 response.send(Ack {})?;
2938 Ok(())
2939}
2940
2941/// Get the list of channel members
2942async fn get_channel_members(
2943 request: proto::GetChannelMembers,
2944 response: Response<proto::GetChannelMembers>,
2945 session: UserSession,
2946) -> Result<()> {
2947 let db = session.db().await;
2948 let channel_id = ChannelId::from_proto(request.channel_id);
2949 let members = db
2950 .get_channel_participant_details(channel_id, session.user_id())
2951 .await?;
2952 response.send(proto::GetChannelMembersResponse { members })?;
2953 Ok(())
2954}
2955
2956/// Accept or decline a channel invitation.
2957async fn respond_to_channel_invite(
2958 request: proto::RespondToChannelInvite,
2959 response: Response<proto::RespondToChannelInvite>,
2960 session: UserSession,
2961) -> Result<()> {
2962 let db = session.db().await;
2963 let channel_id = ChannelId::from_proto(request.channel_id);
2964 let RespondToChannelInvite {
2965 membership_update,
2966 notifications,
2967 } = db
2968 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
2969 .await?;
2970
2971 let mut connection_pool = session.connection_pool().await;
2972 if let Some(membership_update) = membership_update {
2973 notify_membership_updated(
2974 &mut connection_pool,
2975 membership_update,
2976 session.user_id(),
2977 &session.peer,
2978 );
2979 } else {
2980 let update = proto::UpdateChannels {
2981 remove_channel_invitations: vec![channel_id.to_proto()],
2982 ..Default::default()
2983 };
2984
2985 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
2986 session.peer.send(connection_id, update.clone())?;
2987 }
2988 };
2989
2990 send_notifications(&connection_pool, &session.peer, notifications);
2991
2992 response.send(proto::Ack {})?;
2993
2994 Ok(())
2995}
2996
2997/// Join the channels' room
2998async fn join_channel(
2999 request: proto::JoinChannel,
3000 response: Response<proto::JoinChannel>,
3001 session: UserSession,
3002) -> Result<()> {
3003 let channel_id = ChannelId::from_proto(request.channel_id);
3004 join_channel_internal(channel_id, Box::new(response), session).await
3005}
3006
3007trait JoinChannelInternalResponse {
3008 fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
3009}
3010impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
3011 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3012 Response::<proto::JoinChannel>::send(self, result)
3013 }
3014}
3015impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
3016 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3017 Response::<proto::JoinRoom>::send(self, result)
3018 }
3019}
3020
/// Shared implementation for joining a channel's room. `response` may be either
/// of the `JoinChannelInternalResponse` implementors (`JoinChannel` or `JoinRoom`).
///
/// Steps: clean up any stale room connection for this user, join the channel in
/// the database, reply with the room (plus LiveKit credentials when a LiveKit
/// client is configured), apply any membership change, call `room_updated`, and
/// then, after the inner guards are released, call `channel_updated` and
/// refresh the user's contacts.
///
/// # Errors
/// Propagates database/send errors, and fails with "channel not returned" if
/// the joined room has no channel.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: UserSession,
) -> Result<()> {
    // Everything in this block runs under the database guard `db`; the block
    // also scopes the connection-pool guard taken below, so both are released
    // before `channel_updated` re-locks the connection pool.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database guard while leaving the stale room, then
            // re-acquire it. NOTE(review): presumably `leave_room_for_session`
            // takes the database lock itself; confirm.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // Guests get a token from `guest_token` and cannot publish; every other
        // role gets `room_token` and may publish. If token creation fails, the
        // error is passed through `trace_err` and the `?` makes the closure
        // return `None`, so no LiveKit info is sent.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // The joined room is expected to carry its channel; treat a missing one as an error.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3111
3112/// Start editing the channel notes
3113async fn join_channel_buffer(
3114 request: proto::JoinChannelBuffer,
3115 response: Response<proto::JoinChannelBuffer>,
3116 session: UserSession,
3117) -> Result<()> {
3118 let db = session.db().await;
3119 let channel_id = ChannelId::from_proto(request.channel_id);
3120
3121 let open_response = db
3122 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3123 .await?;
3124
3125 let collaborators = open_response.collaborators.clone();
3126 response.send(open_response)?;
3127
3128 let update = UpdateChannelBufferCollaborators {
3129 channel_id: channel_id.to_proto(),
3130 collaborators: collaborators.clone(),
3131 };
3132 channel_buffer_updated(
3133 session.connection_id,
3134 collaborators
3135 .iter()
3136 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3137 &update,
3138 &session.peer,
3139 );
3140
3141 Ok(())
3142}
3143
3144/// Edit the channel notes
3145async fn update_channel_buffer(
3146 request: proto::UpdateChannelBuffer,
3147 session: UserSession,
3148) -> Result<()> {
3149 let db = session.db().await;
3150 let channel_id = ChannelId::from_proto(request.channel_id);
3151
3152 let (collaborators, non_collaborators, epoch, version) = db
3153 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3154 .await?;
3155
3156 channel_buffer_updated(
3157 session.connection_id,
3158 collaborators,
3159 &proto::UpdateChannelBuffer {
3160 channel_id: channel_id.to_proto(),
3161 operations: request.operations,
3162 },
3163 &session.peer,
3164 );
3165
3166 let pool = &*session.connection_pool().await;
3167
3168 broadcast(
3169 None,
3170 non_collaborators
3171 .iter()
3172 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3173 |peer_id| {
3174 session.peer.send(
3175 peer_id,
3176 proto::UpdateChannels {
3177 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3178 channel_id: channel_id.to_proto(),
3179 epoch: epoch as u64,
3180 version: version.clone(),
3181 }],
3182 ..Default::default()
3183 },
3184 )
3185 },
3186 );
3187
3188 Ok(())
3189}
3190
3191/// Rejoin the channel notes after a connection blip
3192async fn rejoin_channel_buffers(
3193 request: proto::RejoinChannelBuffers,
3194 response: Response<proto::RejoinChannelBuffers>,
3195 session: UserSession,
3196) -> Result<()> {
3197 let db = session.db().await;
3198 let buffers = db
3199 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3200 .await?;
3201
3202 for rejoined_buffer in &buffers {
3203 let collaborators_to_notify = rejoined_buffer
3204 .buffer
3205 .collaborators
3206 .iter()
3207 .filter_map(|c| Some(c.peer_id?.into()));
3208 channel_buffer_updated(
3209 session.connection_id,
3210 collaborators_to_notify,
3211 &proto::UpdateChannelBufferCollaborators {
3212 channel_id: rejoined_buffer.buffer.channel_id,
3213 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3214 },
3215 &session.peer,
3216 );
3217 }
3218
3219 response.send(proto::RejoinChannelBuffersResponse {
3220 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3221 })?;
3222
3223 Ok(())
3224}
3225
3226/// Stop editing the channel notes
3227async fn leave_channel_buffer(
3228 request: proto::LeaveChannelBuffer,
3229 response: Response<proto::LeaveChannelBuffer>,
3230 session: UserSession,
3231) -> Result<()> {
3232 let db = session.db().await;
3233 let channel_id = ChannelId::from_proto(request.channel_id);
3234
3235 let left_buffer = db
3236 .leave_channel_buffer(channel_id, session.connection_id)
3237 .await?;
3238
3239 response.send(Ack {})?;
3240
3241 channel_buffer_updated(
3242 session.connection_id,
3243 left_buffer.connections,
3244 &proto::UpdateChannelBufferCollaborators {
3245 channel_id: channel_id.to_proto(),
3246 collaborators: left_buffer.collaborators,
3247 },
3248 &session.peer,
3249 );
3250
3251 Ok(())
3252}
3253
3254fn channel_buffer_updated<T: EnvelopedMessage>(
3255 sender_id: ConnectionId,
3256 collaborators: impl IntoIterator<Item = ConnectionId>,
3257 message: &T,
3258 peer: &Peer,
3259) {
3260 broadcast(Some(sender_id), collaborators, |peer_id| {
3261 peer.send(peer_id, message.clone())
3262 });
3263}
3264
3265fn send_notifications(
3266 connection_pool: &ConnectionPool,
3267 peer: &Peer,
3268 notifications: db::NotificationBatch,
3269) {
3270 for (user_id, notification) in notifications {
3271 for connection_id in connection_pool.user_connection_ids(user_id) {
3272 if let Err(error) = peer.send(
3273 connection_id,
3274 proto::AddNotification {
3275 notification: Some(notification.clone()),
3276 },
3277 ) {
3278 tracing::error!(
3279 "failed to send notification to {:?} {}",
3280 connection_id,
3281 error
3282 );
3283 }
3284 }
3285 }
3286}
3287
3288/// Send a message to the channel
3289async fn send_channel_message(
3290 request: proto::SendChannelMessage,
3291 response: Response<proto::SendChannelMessage>,
3292 session: UserSession,
3293) -> Result<()> {
3294 // Validate the message body.
3295 let body = request.body.trim().to_string();
3296 if body.len() > MAX_MESSAGE_LEN {
3297 return Err(anyhow!("message is too long"))?;
3298 }
3299 if body.is_empty() {
3300 return Err(anyhow!("message can't be blank"))?;
3301 }
3302
3303 // TODO: adjust mentions if body is trimmed
3304
3305 let timestamp = OffsetDateTime::now_utc();
3306 let nonce = request
3307 .nonce
3308 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3309
3310 let channel_id = ChannelId::from_proto(request.channel_id);
3311 let CreatedChannelMessage {
3312 message_id,
3313 participant_connection_ids,
3314 channel_members,
3315 notifications,
3316 } = session
3317 .db()
3318 .await
3319 .create_channel_message(
3320 channel_id,
3321 session.user_id(),
3322 &body,
3323 &request.mentions,
3324 timestamp,
3325 nonce.clone().into(),
3326 match request.reply_to_message_id {
3327 Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
3328 None => None,
3329 },
3330 )
3331 .await?;
3332
3333 let message = proto::ChannelMessage {
3334 sender_id: session.user_id().to_proto(),
3335 id: message_id.to_proto(),
3336 body,
3337 mentions: request.mentions,
3338 timestamp: timestamp.unix_timestamp() as u64,
3339 nonce: Some(nonce),
3340 reply_to_message_id: request.reply_to_message_id,
3341 edited_at: None,
3342 };
3343 broadcast(
3344 Some(session.connection_id),
3345 participant_connection_ids,
3346 |connection| {
3347 session.peer.send(
3348 connection,
3349 proto::ChannelMessageSent {
3350 channel_id: channel_id.to_proto(),
3351 message: Some(message.clone()),
3352 },
3353 )
3354 },
3355 );
3356 response.send(proto::SendChannelMessageResponse {
3357 message: Some(message),
3358 })?;
3359
3360 let pool = &*session.connection_pool().await;
3361 broadcast(
3362 None,
3363 channel_members
3364 .iter()
3365 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3366 |peer_id| {
3367 session.peer.send(
3368 peer_id,
3369 proto::UpdateChannels {
3370 latest_channel_message_ids: vec![proto::ChannelMessageId {
3371 channel_id: channel_id.to_proto(),
3372 message_id: message_id.to_proto(),
3373 }],
3374 ..Default::default()
3375 },
3376 )
3377 },
3378 );
3379 send_notifications(pool, &session.peer, notifications);
3380
3381 Ok(())
3382}
3383
3384/// Delete a channel message
3385async fn remove_channel_message(
3386 request: proto::RemoveChannelMessage,
3387 response: Response<proto::RemoveChannelMessage>,
3388 session: UserSession,
3389) -> Result<()> {
3390 let channel_id = ChannelId::from_proto(request.channel_id);
3391 let message_id = MessageId::from_proto(request.message_id);
3392 let (connection_ids, existing_notification_ids) = session
3393 .db()
3394 .await
3395 .remove_channel_message(channel_id, message_id, session.user_id())
3396 .await?;
3397
3398 broadcast(
3399 Some(session.connection_id),
3400 connection_ids,
3401 move |connection| {
3402 session.peer.send(connection, request.clone())?;
3403
3404 for notification_id in &existing_notification_ids {
3405 session.peer.send(
3406 connection,
3407 proto::DeleteNotification {
3408 notification_id: (*notification_id).to_proto(),
3409 },
3410 )?;
3411 }
3412
3413 Ok(())
3414 },
3415 );
3416 response.send(proto::Ack {})?;
3417 Ok(())
3418}
3419
3420async fn update_channel_message(
3421 request: proto::UpdateChannelMessage,
3422 response: Response<proto::UpdateChannelMessage>,
3423 session: UserSession,
3424) -> Result<()> {
3425 let channel_id = ChannelId::from_proto(request.channel_id);
3426 let message_id = MessageId::from_proto(request.message_id);
3427 let updated_at = OffsetDateTime::now_utc();
3428 let UpdatedChannelMessage {
3429 message_id,
3430 participant_connection_ids,
3431 notifications,
3432 reply_to_message_id,
3433 timestamp,
3434 deleted_mention_notification_ids,
3435 updated_mention_notifications,
3436 } = session
3437 .db()
3438 .await
3439 .update_channel_message(
3440 channel_id,
3441 message_id,
3442 session.user_id(),
3443 request.body.as_str(),
3444 &request.mentions,
3445 updated_at,
3446 )
3447 .await?;
3448
3449 let nonce = request
3450 .nonce
3451 .clone()
3452 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3453
3454 let message = proto::ChannelMessage {
3455 sender_id: session.user_id().to_proto(),
3456 id: message_id.to_proto(),
3457 body: request.body.clone(),
3458 mentions: request.mentions.clone(),
3459 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3460 nonce: Some(nonce),
3461 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3462 edited_at: Some(updated_at.unix_timestamp() as u64),
3463 };
3464
3465 response.send(proto::Ack {})?;
3466
3467 let pool = &*session.connection_pool().await;
3468 broadcast(
3469 Some(session.connection_id),
3470 participant_connection_ids,
3471 |connection| {
3472 session.peer.send(
3473 connection,
3474 proto::ChannelMessageUpdate {
3475 channel_id: channel_id.to_proto(),
3476 message: Some(message.clone()),
3477 },
3478 )?;
3479
3480 for notification_id in &deleted_mention_notification_ids {
3481 session.peer.send(
3482 connection,
3483 proto::DeleteNotification {
3484 notification_id: (*notification_id).to_proto(),
3485 },
3486 )?;
3487 }
3488
3489 for notification in &updated_mention_notifications {
3490 session.peer.send(
3491 connection,
3492 proto::UpdateNotification {
3493 notification: Some(notification.clone()),
3494 },
3495 )?;
3496 }
3497
3498 Ok(())
3499 },
3500 );
3501
3502 send_notifications(pool, &session.peer, notifications);
3503
3504 Ok(())
3505}
3506
3507/// Mark a channel message as read
3508async fn acknowledge_channel_message(
3509 request: proto::AckChannelMessage,
3510 session: UserSession,
3511) -> Result<()> {
3512 let channel_id = ChannelId::from_proto(request.channel_id);
3513 let message_id = MessageId::from_proto(request.message_id);
3514 let notifications = session
3515 .db()
3516 .await
3517 .observe_channel_message(channel_id, session.user_id(), message_id)
3518 .await?;
3519 send_notifications(
3520 &*session.connection_pool().await,
3521 &session.peer,
3522 notifications,
3523 );
3524 Ok(())
3525}
3526
3527/// Mark a buffer version as synced
3528async fn acknowledge_buffer_version(
3529 request: proto::AckBufferOperation,
3530 session: UserSession,
3531) -> Result<()> {
3532 let buffer_id = BufferId::from_proto(request.buffer_id);
3533 session
3534 .db()
3535 .await
3536 .observe_buffer_version(
3537 buffer_id,
3538 session.user_id(),
3539 request.epoch as i32,
3540 &request.version,
3541 )
3542 .await?;
3543 Ok(())
3544}
3545
3546struct CompleteWithLanguageModelRateLimit;
3547
3548impl RateLimit for CompleteWithLanguageModelRateLimit {
3549 fn capacity() -> usize {
3550 std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
3551 .ok()
3552 .and_then(|v| v.parse().ok())
3553 .unwrap_or(120) // Picked arbitrarily
3554 }
3555
3556 fn refill_duration() -> chrono::Duration {
3557 chrono::Duration::hours(1)
3558 }
3559
3560 fn db_name() -> &'static str {
3561 "complete-with-language-model"
3562 }
3563}
3564
3565async fn complete_with_language_model(
3566 request: proto::CompleteWithLanguageModel,
3567 response: StreamingResponse<proto::CompleteWithLanguageModel>,
3568 session: Session,
3569 open_ai_api_key: Option<Arc<str>>,
3570 google_ai_api_key: Option<Arc<str>>,
3571 anthropic_api_key: Option<Arc<str>>,
3572) -> Result<()> {
3573 let Some(session) = session.for_user() else {
3574 return Err(anyhow!("user not found"))?;
3575 };
3576 authorize_access_to_language_models(&session).await?;
3577 session
3578 .rate_limiter
3579 .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
3580 .await?;
3581
3582 if request.model.starts_with("gpt") {
3583 let api_key =
3584 open_ai_api_key.ok_or_else(|| anyhow!("no OpenAI API key configured on the server"))?;
3585 complete_with_open_ai(request, response, session, api_key).await?;
3586 } else if request.model.starts_with("gemini") {
3587 let api_key = google_ai_api_key
3588 .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
3589 complete_with_google_ai(request, response, session, api_key).await?;
3590 } else if request.model.starts_with("claude") {
3591 let api_key = anthropic_api_key
3592 .ok_or_else(|| anyhow!("no Anthropic AI API key configured on the server"))?;
3593 complete_with_anthropic(request, response, session, api_key).await?;
3594 }
3595
3596 Ok(())
3597}
3598
3599async fn complete_with_open_ai(
3600 request: proto::CompleteWithLanguageModel,
3601 response: StreamingResponse<proto::CompleteWithLanguageModel>,
3602 session: UserSession,
3603 api_key: Arc<str>,
3604) -> Result<()> {
3605 const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
3606
3607 let mut completion_stream = open_ai::stream_completion(
3608 &session.http_client,
3609 OPEN_AI_API_URL,
3610 &api_key,
3611 crate::ai::language_model_request_to_open_ai(request)?,
3612 )
3613 .await
3614 .context("open_ai::stream_completion request failed")?;
3615
3616 while let Some(event) = completion_stream.next().await {
3617 let event = event?;
3618 response.send(proto::LanguageModelResponse {
3619 choices: event
3620 .choices
3621 .into_iter()
3622 .map(|choice| proto::LanguageModelChoiceDelta {
3623 index: choice.index,
3624 delta: Some(proto::LanguageModelResponseMessage {
3625 role: choice.delta.role.map(|role| match role {
3626 open_ai::Role::User => LanguageModelRole::LanguageModelUser,
3627 open_ai::Role::Assistant => LanguageModelRole::LanguageModelAssistant,
3628 open_ai::Role::System => LanguageModelRole::LanguageModelSystem,
3629 } as i32),
3630 content: choice.delta.content,
3631 }),
3632 finish_reason: choice.finish_reason,
3633 })
3634 .collect(),
3635 })?;
3636 }
3637
3638 Ok(())
3639}
3640
3641async fn complete_with_google_ai(
3642 request: proto::CompleteWithLanguageModel,
3643 response: StreamingResponse<proto::CompleteWithLanguageModel>,
3644 session: UserSession,
3645 api_key: Arc<str>,
3646) -> Result<()> {
3647 let mut stream = google_ai::stream_generate_content(
3648 &session.http_client,
3649 google_ai::API_URL,
3650 api_key.as_ref(),
3651 crate::ai::language_model_request_to_google_ai(request)?,
3652 )
3653 .await
3654 .context("google_ai::stream_generate_content request failed")?;
3655
3656 while let Some(event) = stream.next().await {
3657 let event = event?;
3658 response.send(proto::LanguageModelResponse {
3659 choices: event
3660 .candidates
3661 .unwrap_or_default()
3662 .into_iter()
3663 .map(|candidate| proto::LanguageModelChoiceDelta {
3664 index: candidate.index as u32,
3665 delta: Some(proto::LanguageModelResponseMessage {
3666 role: Some(match candidate.content.role {
3667 google_ai::Role::User => LanguageModelRole::LanguageModelUser,
3668 google_ai::Role::Model => LanguageModelRole::LanguageModelAssistant,
3669 } as i32),
3670 content: Some(
3671 candidate
3672 .content
3673 .parts
3674 .into_iter()
3675 .filter_map(|part| match part {
3676 google_ai::Part::TextPart(part) => Some(part.text),
3677 google_ai::Part::InlineDataPart(_) => None,
3678 })
3679 .collect(),
3680 ),
3681 }),
3682 finish_reason: candidate.finish_reason.map(|reason| reason.to_string()),
3683 })
3684 .collect(),
3685 })?;
3686 }
3687
3688 Ok(())
3689}
3690
3691async fn complete_with_anthropic(
3692 request: proto::CompleteWithLanguageModel,
3693 response: StreamingResponse<proto::CompleteWithLanguageModel>,
3694 session: UserSession,
3695 api_key: Arc<str>,
3696) -> Result<()> {
3697 let model = anthropic::Model::from_id(&request.model)?;
3698
3699 let mut system_message = String::new();
3700 let messages = request
3701 .messages
3702 .into_iter()
3703 .filter_map(|message| match message.role() {
3704 LanguageModelRole::LanguageModelUser => Some(anthropic::RequestMessage {
3705 role: anthropic::Role::User,
3706 content: message.content,
3707 }),
3708 LanguageModelRole::LanguageModelAssistant => Some(anthropic::RequestMessage {
3709 role: anthropic::Role::Assistant,
3710 content: message.content,
3711 }),
3712 // Anthropic's API breaks system instructions out as a separate field rather
3713 // than having a system message role.
3714 LanguageModelRole::LanguageModelSystem => {
3715 if !system_message.is_empty() {
3716 system_message.push_str("\n\n");
3717 }
3718 system_message.push_str(&message.content);
3719
3720 None
3721 }
3722 })
3723 .collect();
3724
3725 let mut stream = anthropic::stream_completion(
3726 &session.http_client,
3727 "https://api.anthropic.com",
3728 &api_key,
3729 anthropic::Request {
3730 model,
3731 messages,
3732 stream: true,
3733 system: system_message,
3734 max_tokens: 4092,
3735 },
3736 )
3737 .await?;
3738
3739 let mut current_role = proto::LanguageModelRole::LanguageModelAssistant;
3740
3741 while let Some(event) = stream.next().await {
3742 let event = event?;
3743
3744 match event {
3745 anthropic::ResponseEvent::MessageStart { message } => {
3746 if let Some(role) = message.role {
3747 if role == "assistant" {
3748 current_role = proto::LanguageModelRole::LanguageModelAssistant;
3749 } else if role == "user" {
3750 current_role = proto::LanguageModelRole::LanguageModelUser;
3751 }
3752 }
3753 }
3754 anthropic::ResponseEvent::ContentBlockStart { content_block, .. } => {
3755 match content_block {
3756 anthropic::ContentBlock::Text { text } => {
3757 if !text.is_empty() {
3758 response.send(proto::LanguageModelResponse {
3759 choices: vec![proto::LanguageModelChoiceDelta {
3760 index: 0,
3761 delta: Some(proto::LanguageModelResponseMessage {
3762 role: Some(current_role as i32),
3763 content: Some(text),
3764 }),
3765 finish_reason: None,
3766 }],
3767 })?;
3768 }
3769 }
3770 }
3771 }
3772 anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => match delta {
3773 anthropic::TextDelta::TextDelta { text } => {
3774 response.send(proto::LanguageModelResponse {
3775 choices: vec![proto::LanguageModelChoiceDelta {
3776 index: 0,
3777 delta: Some(proto::LanguageModelResponseMessage {
3778 role: Some(current_role as i32),
3779 content: Some(text),
3780 }),
3781 finish_reason: None,
3782 }],
3783 })?;
3784 }
3785 },
3786 anthropic::ResponseEvent::MessageDelta { delta, .. } => {
3787 if let Some(stop_reason) = delta.stop_reason {
3788 response.send(proto::LanguageModelResponse {
3789 choices: vec![proto::LanguageModelChoiceDelta {
3790 index: 0,
3791 delta: None,
3792 finish_reason: Some(stop_reason),
3793 }],
3794 })?;
3795 }
3796 }
3797 anthropic::ResponseEvent::ContentBlockStop { .. } => {}
3798 anthropic::ResponseEvent::MessageStop {} => {}
3799 anthropic::ResponseEvent::Ping {} => {}
3800 }
3801 }
3802
3803 Ok(())
3804}
3805
3806struct CountTokensWithLanguageModelRateLimit;
3807
3808impl RateLimit for CountTokensWithLanguageModelRateLimit {
3809 fn capacity() -> usize {
3810 std::env::var("COUNT_TOKENS_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
3811 .ok()
3812 .and_then(|v| v.parse().ok())
3813 .unwrap_or(600) // Picked arbitrarily
3814 }
3815
3816 fn refill_duration() -> chrono::Duration {
3817 chrono::Duration::hours(1)
3818 }
3819
3820 fn db_name() -> &'static str {
3821 "count-tokens-with-language-model"
3822 }
3823}
3824
3825async fn count_tokens_with_language_model(
3826 request: proto::CountTokensWithLanguageModel,
3827 response: Response<proto::CountTokensWithLanguageModel>,
3828 session: UserSession,
3829 google_ai_api_key: Option<Arc<str>>,
3830) -> Result<()> {
3831 authorize_access_to_language_models(&session).await?;
3832
3833 if !request.model.starts_with("gemini") {
3834 return Err(anyhow!(
3835 "counting tokens for model: {:?} is not supported",
3836 request.model
3837 ))?;
3838 }
3839
3840 session
3841 .rate_limiter
3842 .check::<CountTokensWithLanguageModelRateLimit>(session.user_id())
3843 .await?;
3844
3845 let api_key = google_ai_api_key
3846 .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
3847 let tokens_response = google_ai::count_tokens(
3848 &session.http_client,
3849 google_ai::API_URL,
3850 &api_key,
3851 crate::ai::count_tokens_request_to_google_ai(request)?,
3852 )
3853 .await?;
3854 response.send(proto::CountTokensResponse {
3855 token_count: tokens_response.total_tokens as u32,
3856 })?;
3857 Ok(())
3858}
3859
3860async fn authorize_access_to_language_models(session: &UserSession) -> Result<(), Error> {
3861 let db = session.db().await;
3862 let flags = db.get_user_flags(session.user_id()).await?;
3863 if flags.iter().any(|flag| flag == "language-models") {
3864 Ok(())
3865 } else {
3866 Err(anyhow!("permission denied"))?
3867 }
3868}
3869
3870/// Start receiving chat updates for a channel
3871async fn join_channel_chat(
3872 request: proto::JoinChannelChat,
3873 response: Response<proto::JoinChannelChat>,
3874 session: UserSession,
3875) -> Result<()> {
3876 let channel_id = ChannelId::from_proto(request.channel_id);
3877
3878 let db = session.db().await;
3879 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3880 .await?;
3881 let messages = db
3882 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3883 .await?;
3884 response.send(proto::JoinChannelChatResponse {
3885 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3886 messages,
3887 })?;
3888 Ok(())
3889}
3890
3891/// Stop receiving chat updates for a channel
3892async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
3893 let channel_id = ChannelId::from_proto(request.channel_id);
3894 session
3895 .db()
3896 .await
3897 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3898 .await?;
3899 Ok(())
3900}
3901
3902/// Retrieve the chat history for a channel
3903async fn get_channel_messages(
3904 request: proto::GetChannelMessages,
3905 response: Response<proto::GetChannelMessages>,
3906 session: UserSession,
3907) -> Result<()> {
3908 let channel_id = ChannelId::from_proto(request.channel_id);
3909 let messages = session
3910 .db()
3911 .await
3912 .get_channel_messages(
3913 channel_id,
3914 session.user_id(),
3915 MESSAGE_COUNT_PER_PAGE,
3916 Some(MessageId::from_proto(request.before_message_id)),
3917 )
3918 .await?;
3919 response.send(proto::GetChannelMessagesResponse {
3920 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3921 messages,
3922 })?;
3923 Ok(())
3924}
3925
3926/// Retrieve specific chat messages
3927async fn get_channel_messages_by_id(
3928 request: proto::GetChannelMessagesById,
3929 response: Response<proto::GetChannelMessagesById>,
3930 session: UserSession,
3931) -> Result<()> {
3932 let message_ids = request
3933 .message_ids
3934 .iter()
3935 .map(|id| MessageId::from_proto(*id))
3936 .collect::<Vec<_>>();
3937 let messages = session
3938 .db()
3939 .await
3940 .get_channel_messages_by_id(session.user_id(), &message_ids)
3941 .await?;
3942 response.send(proto::GetChannelMessagesResponse {
3943 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3944 messages,
3945 })?;
3946 Ok(())
3947}
3948
3949/// Retrieve the current users notifications
3950async fn get_notifications(
3951 request: proto::GetNotifications,
3952 response: Response<proto::GetNotifications>,
3953 session: UserSession,
3954) -> Result<()> {
3955 let notifications = session
3956 .db()
3957 .await
3958 .get_notifications(
3959 session.user_id(),
3960 NOTIFICATION_COUNT_PER_PAGE,
3961 request
3962 .before_id
3963 .map(|id| db::NotificationId::from_proto(id)),
3964 )
3965 .await?;
3966 response.send(proto::GetNotificationsResponse {
3967 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3968 notifications,
3969 })?;
3970 Ok(())
3971}
3972
3973/// Mark notifications as read
3974async fn mark_notification_as_read(
3975 request: proto::MarkNotificationRead,
3976 response: Response<proto::MarkNotificationRead>,
3977 session: UserSession,
3978) -> Result<()> {
3979 let database = &session.db().await;
3980 let notifications = database
3981 .mark_notification_as_read_by_id(
3982 session.user_id(),
3983 NotificationId::from_proto(request.notification_id),
3984 )
3985 .await?;
3986 send_notifications(
3987 &*session.connection_pool().await,
3988 &session.peer,
3989 notifications,
3990 );
3991 response.send(proto::Ack {})?;
3992 Ok(())
3993}
3994
3995/// Get the current users information
3996async fn get_private_user_info(
3997 _request: proto::GetPrivateUserInfo,
3998 response: Response<proto::GetPrivateUserInfo>,
3999 session: UserSession,
4000) -> Result<()> {
4001 let db = session.db().await;
4002
4003 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4004 let user = db
4005 .get_user_by_id(session.user_id())
4006 .await?
4007 .ok_or_else(|| anyhow!("user not found"))?;
4008 let flags = db.get_user_flags(session.user_id()).await?;
4009
4010 response.send(proto::GetPrivateUserInfoResponse {
4011 metrics_id,
4012 staff: user.admin,
4013 flags,
4014 })?;
4015 Ok(())
4016}
4017
4018fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
4019 match message {
4020 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4021 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4022 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4023 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4024 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4025 code: frame.code.into(),
4026 reason: frame.reason,
4027 })),
4028 }
4029}
4030
4031fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4032 match message {
4033 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4034 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4035 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4036 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4037 AxumMessage::Close(frame) => {
4038 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4039 code: frame.code.into(),
4040 reason: frame.reason,
4041 }))
4042 }
4043 }
4044}
4045
4046fn notify_membership_updated(
4047 connection_pool: &mut ConnectionPool,
4048 result: MembershipUpdated,
4049 user_id: UserId,
4050 peer: &Peer,
4051) {
4052 for membership in &result.new_channels.channel_memberships {
4053 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4054 }
4055 for channel_id in &result.removed_channels {
4056 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4057 }
4058
4059 let user_channels_update = proto::UpdateUserChannels {
4060 channel_memberships: result
4061 .new_channels
4062 .channel_memberships
4063 .iter()
4064 .map(|cm| proto::ChannelMembership {
4065 channel_id: cm.channel_id.to_proto(),
4066 role: cm.role.into(),
4067 })
4068 .collect(),
4069 ..Default::default()
4070 };
4071
4072 let mut update = build_channels_update(result.new_channels, vec![]);
4073 update.delete_channels = result
4074 .removed_channels
4075 .into_iter()
4076 .map(|id| id.to_proto())
4077 .collect();
4078 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4079
4080 for connection_id in connection_pool.user_connection_ids(user_id) {
4081 peer.send(connection_id, user_channels_update.clone())
4082 .trace_err();
4083 peer.send(connection_id, update.clone()).trace_err();
4084 }
4085}
4086
4087fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4088 proto::UpdateUserChannels {
4089 channel_memberships: channels
4090 .channel_memberships
4091 .iter()
4092 .map(|m| proto::ChannelMembership {
4093 channel_id: m.channel_id.to_proto(),
4094 role: m.role.into(),
4095 })
4096 .collect(),
4097 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4098 observed_channel_message_id: channels.observed_channel_messages.clone(),
4099 }
4100}
4101
4102fn build_channels_update(
4103 channels: ChannelsForUser,
4104 channel_invites: Vec<db::Channel>,
4105) -> proto::UpdateChannels {
4106 let mut update = proto::UpdateChannels::default();
4107
4108 for channel in channels.channels {
4109 update.channels.push(channel.to_proto());
4110 }
4111
4112 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4113 update.latest_channel_message_ids = channels.latest_channel_messages;
4114
4115 for (channel_id, participants) in channels.channel_participants {
4116 update
4117 .channel_participants
4118 .push(proto::ChannelParticipants {
4119 channel_id: channel_id.to_proto(),
4120 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4121 });
4122 }
4123
4124 for channel in channel_invites {
4125 update.channel_invitations.push(channel.to_proto());
4126 }
4127 for project in channels.hosted_projects {
4128 update.hosted_projects.push(project);
4129 }
4130
4131 update
4132}
4133
4134fn build_initial_contacts_update(
4135 contacts: Vec<db::Contact>,
4136 pool: &ConnectionPool,
4137) -> proto::UpdateContacts {
4138 let mut update = proto::UpdateContacts::default();
4139
4140 for contact in contacts {
4141 match contact {
4142 db::Contact::Accepted { user_id, busy } => {
4143 update.contacts.push(contact_for_user(user_id, busy, &pool));
4144 }
4145 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4146 db::Contact::Incoming { user_id } => {
4147 update
4148 .incoming_requests
4149 .push(proto::IncomingContactRequest {
4150 requester_id: user_id.to_proto(),
4151 })
4152 }
4153 }
4154 }
4155
4156 update
4157}
4158
4159fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4160 proto::Contact {
4161 user_id: user_id.to_proto(),
4162 online: pool.is_user_online(user_id),
4163 busy,
4164 }
4165}
4166
4167fn room_updated(room: &proto::Room, peer: &Peer) {
4168 broadcast(
4169 None,
4170 room.participants
4171 .iter()
4172 .filter_map(|participant| Some(participant.peer_id?.into())),
4173 |peer_id| {
4174 peer.send(
4175 peer_id,
4176 proto::RoomUpdated {
4177 room: Some(room.clone()),
4178 },
4179 )
4180 },
4181 );
4182}
4183
4184fn channel_updated(
4185 channel: &db::channel::Model,
4186 room: &proto::Room,
4187 peer: &Peer,
4188 pool: &ConnectionPool,
4189) {
4190 let participants = room
4191 .participants
4192 .iter()
4193 .map(|p| p.user_id)
4194 .collect::<Vec<_>>();
4195
4196 broadcast(
4197 None,
4198 pool.channel_connection_ids(channel.root_id())
4199 .filter_map(|(channel_id, role)| {
4200 role.can_see_channel(channel.visibility).then(|| channel_id)
4201 }),
4202 |peer_id| {
4203 peer.send(
4204 peer_id,
4205 proto::UpdateChannels {
4206 channel_participants: vec![proto::ChannelParticipants {
4207 channel_id: channel.id.to_proto(),
4208 participant_user_ids: participants.clone(),
4209 }],
4210 ..Default::default()
4211 },
4212 )
4213 },
4214 );
4215}
4216
4217async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4218 let db = session.db().await;
4219
4220 let contacts = db.get_contacts(user_id).await?;
4221 let busy = db.is_user_busy(user_id).await?;
4222
4223 let pool = session.connection_pool().await;
4224 let updated_contact = contact_for_user(user_id, busy, &pool);
4225 for contact in contacts {
4226 if let db::Contact::Accepted {
4227 user_id: contact_user_id,
4228 ..
4229 } = contact
4230 {
4231 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4232 session
4233 .peer
4234 .send(
4235 contact_conn_id,
4236 proto::UpdateContacts {
4237 contacts: vec![updated_contact.clone()],
4238 remove_contacts: Default::default(),
4239 incoming_requests: Default::default(),
4240 remove_incoming_requests: Default::default(),
4241 outgoing_requests: Default::default(),
4242 remove_outgoing_requests: Default::default(),
4243 },
4244 )
4245 .trace_err();
4246 }
4247 }
4248 }
4249 Ok(())
4250}
4251
4252async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
4253 let mut contacts_to_update = HashSet::default();
4254
4255 let room_id;
4256 let canceled_calls_to_user_ids;
4257 let live_kit_room;
4258 let delete_live_kit_room;
4259 let room;
4260 let channel;
4261
4262 if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
4263 contacts_to_update.insert(session.user_id());
4264
4265 for project in left_room.left_projects.values() {
4266 project_left(project, session);
4267 }
4268
4269 room_id = RoomId::from_proto(left_room.room.id);
4270 canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
4271 live_kit_room = mem::take(&mut left_room.room.live_kit_room);
4272 delete_live_kit_room = left_room.deleted;
4273 room = mem::take(&mut left_room.room);
4274 channel = mem::take(&mut left_room.channel);
4275
4276 room_updated(&room, &session.peer);
4277 } else {
4278 return Ok(());
4279 }
4280
4281 if let Some(channel) = channel {
4282 channel_updated(
4283 &channel,
4284 &room,
4285 &session.peer,
4286 &*session.connection_pool().await,
4287 );
4288 }
4289
4290 {
4291 let pool = session.connection_pool().await;
4292 for canceled_user_id in canceled_calls_to_user_ids {
4293 for connection_id in pool.user_connection_ids(canceled_user_id) {
4294 session
4295 .peer
4296 .send(
4297 connection_id,
4298 proto::CallCanceled {
4299 room_id: room_id.to_proto(),
4300 },
4301 )
4302 .trace_err();
4303 }
4304 contacts_to_update.insert(canceled_user_id);
4305 }
4306 }
4307
4308 for contact_user_id in contacts_to_update {
4309 update_user_contacts(contact_user_id, &session).await?;
4310 }
4311
4312 if let Some(live_kit) = session.live_kit_client.as_ref() {
4313 live_kit
4314 .remove_participant(live_kit_room.clone(), session.user_id().to_string())
4315 .await
4316 .trace_err();
4317
4318 if delete_live_kit_room {
4319 live_kit.delete_room(live_kit_room).await.trace_err();
4320 }
4321 }
4322
4323 Ok(())
4324}
4325
4326async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4327 let left_channel_buffers = session
4328 .db()
4329 .await
4330 .leave_channel_buffers(session.connection_id)
4331 .await?;
4332
4333 for left_buffer in left_channel_buffers {
4334 channel_buffer_updated(
4335 session.connection_id,
4336 left_buffer.connections,
4337 &proto::UpdateChannelBufferCollaborators {
4338 channel_id: left_buffer.channel_id.to_proto(),
4339 collaborators: left_buffer.collaborators,
4340 },
4341 &session.peer,
4342 );
4343 }
4344
4345 Ok(())
4346}
4347
4348fn project_left(project: &db::LeftProject, session: &UserSession) {
4349 for connection_id in &project.connection_ids {
4350 if project.host_user_id == Some(session.user_id()) {
4351 session
4352 .peer
4353 .send(
4354 *connection_id,
4355 proto::UnshareProject {
4356 project_id: project.id.to_proto(),
4357 },
4358 )
4359 .trace_err();
4360 } else {
4361 session
4362 .peer
4363 .send(
4364 *connection_id,
4365 proto::RemoveProjectCollaborator {
4366 project_id: project.id.to_proto(),
4367 peer_id: Some(session.connection_id.into()),
4368 },
4369 )
4370 .trace_err();
4371 }
4372 }
4373}
4374
4375pub trait ResultExt {
4376 type Ok;
4377
4378 fn trace_err(self) -> Option<Self::Ok>;
4379}
4380
4381impl<T, E> ResultExt for Result<T, E>
4382where
4383 E: std::fmt::Debug,
4384{
4385 type Ok = T;
4386
4387 fn trace_err(self) -> Option<T> {
4388 match self {
4389 Ok(value) => Some(value),
4390 Err(error) => {
4391 tracing::error!("{:?}", error);
4392 None
4393 }
4394 }
4395 }
4396}