1mod project;
2mod project_collaborator;
3mod room;
4mod room_participant;
5mod worktree;
6
7use crate::{Error, Result};
8use anyhow::anyhow;
9use collections::HashMap;
10use dashmap::DashMap;
11use futures::StreamExt;
12use rpc::{proto, ConnectionId};
13use sea_orm::ActiveValue;
14use sea_orm::{
15 entity::prelude::*, ConnectOptions, DatabaseConnection, DatabaseTransaction, DbErr,
16 TransactionTrait,
17};
18use serde::{Deserialize, Serialize};
19use std::ops::{Deref, DerefMut};
20use std::{future::Future, marker::PhantomData, rc::Rc, sync::Arc};
21use tokio::sync::{Mutex, OwnedMutexGuard};
22
23pub struct Database {
24 pool: DatabaseConnection,
25 rooms: DashMap<RoomId, Arc<Mutex<()>>>,
26 #[cfg(test)]
27 background: Option<std::sync::Arc<gpui::executor::Background>>,
28 #[cfg(test)]
29 runtime: Option<tokio::runtime::Runtime>,
30}
31
32impl Database {
33 pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
34 let mut options = ConnectOptions::new(url.into());
35 options.max_connections(max_connections);
36 Ok(Self {
37 pool: sea_orm::Database::connect(options).await?,
38 rooms: DashMap::with_capacity(16384),
39 #[cfg(test)]
40 background: None,
41 #[cfg(test)]
42 runtime: None,
43 })
44 }
45
46 pub async fn share_project(
47 &self,
48 room_id: RoomId,
49 connection_id: ConnectionId,
50 worktrees: &[proto::WorktreeMetadata],
51 ) -> Result<RoomGuard<(ProjectId, proto::Room)>> {
52 self.transact(|tx| async move {
53 let participant = room_participant::Entity::find()
54 .filter(room_participant::Column::AnsweringConnectionId.eq(connection_id.0))
55 .one(&tx)
56 .await?
57 .ok_or_else(|| anyhow!("could not find participant"))?;
58 if participant.room_id != room_id.0 {
59 return Err(anyhow!("shared project on unexpected room"))?;
60 }
61
62 let project = project::ActiveModel {
63 room_id: ActiveValue::set(participant.room_id),
64 host_user_id: ActiveValue::set(participant.user_id),
65 host_connection_id: ActiveValue::set(connection_id.0 as i32),
66 ..Default::default()
67 }
68 .insert(&tx)
69 .await?;
70
71 worktree::Entity::insert_many(worktrees.iter().map(|worktree| worktree::ActiveModel {
72 id: ActiveValue::set(worktree.id as i32),
73 project_id: ActiveValue::set(project.id),
74 abs_path: ActiveValue::set(worktree.abs_path.clone()),
75 root_name: ActiveValue::set(worktree.root_name.clone()),
76 visible: ActiveValue::set(worktree.visible),
77 scan_id: ActiveValue::set(0),
78 is_complete: ActiveValue::set(false),
79 }))
80 .exec(&tx)
81 .await?;
82
83 project_collaborator::ActiveModel {
84 project_id: ActiveValue::set(project.id),
85 connection_id: ActiveValue::set(connection_id.0 as i32),
86 user_id: ActiveValue::set(participant.user_id),
87 replica_id: ActiveValue::set(0),
88 is_host: ActiveValue::set(true),
89 ..Default::default()
90 }
91 .insert(&tx)
92 .await?;
93
94 let room = self.get_room(room_id, &tx).await?;
95 self.commit_room_transaction(room_id, tx, (ProjectId(project.id), room))
96 .await
97 })
98 .await
99 }
100
101 async fn get_room(&self, room_id: RoomId, tx: &DatabaseTransaction) -> Result<proto::Room> {
102 let db_room = room::Entity::find_by_id(room_id.0)
103 .one(tx)
104 .await?
105 .ok_or_else(|| anyhow!("could not find room"))?;
106
107 let mut db_participants = db_room
108 .find_related(room_participant::Entity)
109 .stream(tx)
110 .await?;
111 let mut participants = HashMap::default();
112 let mut pending_participants = Vec::new();
113 while let Some(db_participant) = db_participants.next().await {
114 let db_participant = db_participant?;
115 if let Some(answering_connection_id) = db_participant.answering_connection_id {
116 let location = match (
117 db_participant.location_kind,
118 db_participant.location_project_id,
119 ) {
120 (Some(0), Some(project_id)) => {
121 Some(proto::participant_location::Variant::SharedProject(
122 proto::participant_location::SharedProject {
123 id: project_id as u64,
124 },
125 ))
126 }
127 (Some(1), _) => Some(proto::participant_location::Variant::UnsharedProject(
128 Default::default(),
129 )),
130 _ => Some(proto::participant_location::Variant::External(
131 Default::default(),
132 )),
133 };
134 participants.insert(
135 answering_connection_id,
136 proto::Participant {
137 user_id: db_participant.user_id as u64,
138 peer_id: answering_connection_id as u32,
139 projects: Default::default(),
140 location: Some(proto::ParticipantLocation { variant: location }),
141 },
142 );
143 } else {
144 pending_participants.push(proto::PendingParticipant {
145 user_id: db_participant.user_id as u64,
146 calling_user_id: db_participant.calling_user_id as u64,
147 initial_project_id: db_participant.initial_project_id.map(|id| id as u64),
148 });
149 }
150 }
151
152 let mut db_projects = db_room
153 .find_related(project::Entity)
154 .find_with_related(worktree::Entity)
155 .stream(tx)
156 .await?;
157
158 while let Some(row) = db_projects.next().await {
159 let (db_project, db_worktree) = row?;
160 if let Some(participant) = participants.get_mut(&db_project.host_connection_id) {
161 let project = if let Some(project) = participant
162 .projects
163 .iter_mut()
164 .find(|project| project.id as i32 == db_project.id)
165 {
166 project
167 } else {
168 participant.projects.push(proto::ParticipantProject {
169 id: db_project.id as u64,
170 worktree_root_names: Default::default(),
171 });
172 participant.projects.last_mut().unwrap()
173 };
174
175 if let Some(db_worktree) = db_worktree {
176 project.worktree_root_names.push(db_worktree.root_name);
177 }
178 }
179 }
180
181 Ok(proto::Room {
182 id: db_room.id as u64,
183 live_kit_room: db_room.live_kit_room,
184 participants: participants.into_values().collect(),
185 pending_participants,
186 })
187 }
188
189 async fn commit_room_transaction<T>(
190 &self,
191 room_id: RoomId,
192 tx: DatabaseTransaction,
193 data: T,
194 ) -> Result<RoomGuard<T>> {
195 let lock = self.rooms.entry(room_id).or_default().clone();
196 let _guard = lock.lock_owned().await;
197 tx.commit().await?;
198 Ok(RoomGuard {
199 data,
200 _guard,
201 _not_send: PhantomData,
202 })
203 }
204
205 async fn transact<F, Fut, T>(&self, f: F) -> Result<T>
206 where
207 F: Send + Fn(DatabaseTransaction) -> Fut,
208 Fut: Send + Future<Output = Result<T>>,
209 {
210 let body = async {
211 loop {
212 let tx = self.pool.begin().await?;
213 match f(tx).await {
214 Ok(result) => return Ok(result),
215 Err(error) => match error {
216 Error::Database2(
217 DbErr::Exec(sea_orm::RuntimeErr::SqlxError(error))
218 | DbErr::Query(sea_orm::RuntimeErr::SqlxError(error)),
219 ) if error
220 .as_database_error()
221 .and_then(|error| error.code())
222 .as_deref()
223 == Some("40001") =>
224 {
225 // Retry (don't break the loop)
226 }
227 error @ _ => return Err(error),
228 },
229 }
230 }
231 };
232
233 #[cfg(test)]
234 {
235 if let Some(background) = self.background.as_ref() {
236 background.simulate_random_delay().await;
237 }
238
239 self.runtime.as_ref().unwrap().block_on(body)
240 }
241
242 #[cfg(not(test))]
243 {
244 body.await
245 }
246 }
247}
248
/// Value produced by a room transaction, bundled with that room's lock.
///
/// The room lock acquired when the transaction committed stays held for as long as
/// this guard is alive and is released when it is dropped. The wrapped value is
/// reached through `Deref`/`DerefMut`.
// Field order matters: fields drop in declaration order, so `data` is dropped
// before the lock is released.
pub struct RoomGuard<T> {
    data: T,
    // Keeps the per-room mutex locked until the guard is dropped.
    _guard: OwnedMutexGuard<()>,
    // `Rc` is neither `Send` nor `Sync`, so this marker makes the guard `!Send` and
    // `!Sync`. NOTE(review): presumably meant to stop the lock from being moved to
    // another task/thread while held — confirm intent.
    _not_send: PhantomData<Rc<()>>,
}
254
255impl<T> Deref for RoomGuard<T> {
256 type Target = T;
257
258 fn deref(&self) -> &T {
259 &self.data
260 }
261}
262
263impl<T> DerefMut for RoomGuard<T> {
264 fn deref_mut(&mut self) -> &mut T {
265 &mut self.data
266 }
267}
268
/// Declares a `Copy` newtype around an `i32` database key.
///
/// The generated type is `transparent` for both `sqlx` and `serde`, so it is stored
/// and serialized exactly like the wrapped `i32`. It displays as that integer.
macro_rules! id_type {
    ($name:ident) => {
        #[derive(
            Clone,
            Copy,
            Debug,
            Default,
            PartialEq,
            Eq,
            PartialOrd,
            Ord,
            Hash,
            sqlx::Type,
            Serialize,
            Deserialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            /// The largest id representable by the underlying `i32`.
            #[allow(unused)] // not every generated id type uses every helper
            pub const MAX: Self = $name(i32::MAX);

            /// Converts a protobuf id into this type.
            ///
            /// NOTE(review): this is a plain `as` cast, so values above `i32::MAX`
            /// wrap around instead of being rejected.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                let raw = value as i32;
                $name(raw)
            }

            /// Converts this id into its protobuf representation.
            ///
            /// NOTE(review): negative ids wrap to large `u64` values.
            #[allow(unused)]
            pub fn to_proto(self) -> u64 {
                let $name(raw) = self;
                raw as u64
            }
        }

        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                // Forward to the inner integer so formatter flags (width, fill, ...) apply.
                self.0.fmt(f)
            }
        }
    };
}
311
312id_type!(UserId);
313id_type!(RoomId);
314id_type!(RoomParticipantId);
315id_type!(ProjectId);
316id_type!(WorktreeId);