db2.rs

  1mod project;
  2mod project_collaborator;
  3mod room;
  4mod room_participant;
  5mod worktree;
  6
  7use crate::{Error, Result};
  8use anyhow::anyhow;
  9use collections::HashMap;
 10use dashmap::DashMap;
 11use futures::StreamExt;
 12use rpc::{proto, ConnectionId};
 13use sea_orm::ActiveValue;
 14use sea_orm::{
 15    entity::prelude::*, ConnectOptions, DatabaseConnection, DatabaseTransaction, DbErr,
 16    TransactionTrait,
 17};
 18use serde::{Deserialize, Serialize};
 19use std::ops::{Deref, DerefMut};
 20use std::{future::Future, marker::PhantomData, rc::Rc, sync::Arc};
 21use tokio::sync::{Mutex, OwnedMutexGuard};
 22
 23pub struct Database {
 24    pool: DatabaseConnection,
 25    rooms: DashMap<RoomId, Arc<Mutex<()>>>,
 26    #[cfg(test)]
 27    background: Option<std::sync::Arc<gpui::executor::Background>>,
 28    #[cfg(test)]
 29    runtime: Option<tokio::runtime::Runtime>,
 30}
 31
 32impl Database {
 33    pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
 34        let mut options = ConnectOptions::new(url.into());
 35        options.max_connections(max_connections);
 36        Ok(Self {
 37            pool: sea_orm::Database::connect(options).await?,
 38            rooms: DashMap::with_capacity(16384),
 39            #[cfg(test)]
 40            background: None,
 41            #[cfg(test)]
 42            runtime: None,
 43        })
 44    }
 45
 46    pub async fn share_project(
 47        &self,
 48        room_id: RoomId,
 49        connection_id: ConnectionId,
 50        worktrees: &[proto::WorktreeMetadata],
 51    ) -> Result<RoomGuard<(ProjectId, proto::Room)>> {
 52        self.transact(|tx| async move {
 53            let participant = room_participant::Entity::find()
 54                .filter(room_participant::Column::AnsweringConnectionId.eq(connection_id.0))
 55                .one(&tx)
 56                .await?
 57                .ok_or_else(|| anyhow!("could not find participant"))?;
 58            if participant.room_id != room_id.0 {
 59                return Err(anyhow!("shared project on unexpected room"))?;
 60            }
 61
 62            let project = project::ActiveModel {
 63                room_id: ActiveValue::set(participant.room_id),
 64                host_user_id: ActiveValue::set(participant.user_id),
 65                host_connection_id: ActiveValue::set(connection_id.0 as i32),
 66                ..Default::default()
 67            }
 68            .insert(&tx)
 69            .await?;
 70
 71            worktree::Entity::insert_many(worktrees.iter().map(|worktree| worktree::ActiveModel {
 72                id: ActiveValue::set(worktree.id as i32),
 73                project_id: ActiveValue::set(project.id),
 74                abs_path: ActiveValue::set(worktree.abs_path.clone()),
 75                root_name: ActiveValue::set(worktree.root_name.clone()),
 76                visible: ActiveValue::set(worktree.visible),
 77                scan_id: ActiveValue::set(0),
 78                is_complete: ActiveValue::set(false),
 79            }))
 80            .exec(&tx)
 81            .await?;
 82
 83            project_collaborator::ActiveModel {
 84                project_id: ActiveValue::set(project.id),
 85                connection_id: ActiveValue::set(connection_id.0 as i32),
 86                user_id: ActiveValue::set(participant.user_id),
 87                replica_id: ActiveValue::set(0),
 88                is_host: ActiveValue::set(true),
 89                ..Default::default()
 90            }
 91            .insert(&tx)
 92            .await?;
 93
 94            let room = self.get_room(room_id, &tx).await?;
 95            self.commit_room_transaction(room_id, tx, (ProjectId(project.id), room))
 96                .await
 97        })
 98        .await
 99    }
100
101    async fn get_room(&self, room_id: RoomId, tx: &DatabaseTransaction) -> Result<proto::Room> {
102        let db_room = room::Entity::find_by_id(room_id.0)
103            .one(tx)
104            .await?
105            .ok_or_else(|| anyhow!("could not find room"))?;
106
107        let mut db_participants = db_room
108            .find_related(room_participant::Entity)
109            .stream(tx)
110            .await?;
111        let mut participants = HashMap::default();
112        let mut pending_participants = Vec::new();
113        while let Some(db_participant) = db_participants.next().await {
114            let db_participant = db_participant?;
115            if let Some(answering_connection_id) = db_participant.answering_connection_id {
116                let location = match (
117                    db_participant.location_kind,
118                    db_participant.location_project_id,
119                ) {
120                    (Some(0), Some(project_id)) => {
121                        Some(proto::participant_location::Variant::SharedProject(
122                            proto::participant_location::SharedProject {
123                                id: project_id as u64,
124                            },
125                        ))
126                    }
127                    (Some(1), _) => Some(proto::participant_location::Variant::UnsharedProject(
128                        Default::default(),
129                    )),
130                    _ => Some(proto::participant_location::Variant::External(
131                        Default::default(),
132                    )),
133                };
134                participants.insert(
135                    answering_connection_id,
136                    proto::Participant {
137                        user_id: db_participant.user_id as u64,
138                        peer_id: answering_connection_id as u32,
139                        projects: Default::default(),
140                        location: Some(proto::ParticipantLocation { variant: location }),
141                    },
142                );
143            } else {
144                pending_participants.push(proto::PendingParticipant {
145                    user_id: db_participant.user_id as u64,
146                    calling_user_id: db_participant.calling_user_id as u64,
147                    initial_project_id: db_participant.initial_project_id.map(|id| id as u64),
148                });
149            }
150        }
151
152        let mut db_projects = db_room
153            .find_related(project::Entity)
154            .find_with_related(worktree::Entity)
155            .stream(tx)
156            .await?;
157
158        while let Some(row) = db_projects.next().await {
159            let (db_project, db_worktree) = row?;
160            if let Some(participant) = participants.get_mut(&db_project.host_connection_id) {
161                let project = if let Some(project) = participant
162                    .projects
163                    .iter_mut()
164                    .find(|project| project.id as i32 == db_project.id)
165                {
166                    project
167                } else {
168                    participant.projects.push(proto::ParticipantProject {
169                        id: db_project.id as u64,
170                        worktree_root_names: Default::default(),
171                    });
172                    participant.projects.last_mut().unwrap()
173                };
174
175                if let Some(db_worktree) = db_worktree {
176                    project.worktree_root_names.push(db_worktree.root_name);
177                }
178            }
179        }
180
181        Ok(proto::Room {
182            id: db_room.id as u64,
183            live_kit_room: db_room.live_kit_room,
184            participants: participants.into_values().collect(),
185            pending_participants,
186        })
187    }
188
189    async fn commit_room_transaction<T>(
190        &self,
191        room_id: RoomId,
192        tx: DatabaseTransaction,
193        data: T,
194    ) -> Result<RoomGuard<T>> {
195        let lock = self.rooms.entry(room_id).or_default().clone();
196        let _guard = lock.lock_owned().await;
197        tx.commit().await?;
198        Ok(RoomGuard {
199            data,
200            _guard,
201            _not_send: PhantomData,
202        })
203    }
204
205    async fn transact<F, Fut, T>(&self, f: F) -> Result<T>
206    where
207        F: Send + Fn(DatabaseTransaction) -> Fut,
208        Fut: Send + Future<Output = Result<T>>,
209    {
210        let body = async {
211            loop {
212                let tx = self.pool.begin().await?;
213                match f(tx).await {
214                    Ok(result) => return Ok(result),
215                    Err(error) => match error {
216                        Error::Database2(
217                            DbErr::Exec(sea_orm::RuntimeErr::SqlxError(error))
218                            | DbErr::Query(sea_orm::RuntimeErr::SqlxError(error)),
219                        ) if error
220                            .as_database_error()
221                            .and_then(|error| error.code())
222                            .as_deref()
223                            == Some("40001") =>
224                        {
225                            // Retry (don't break the loop)
226                        }
227                        error @ _ => return Err(error),
228                    },
229                }
230            }
231        };
232
233        #[cfg(test)]
234        {
235            if let Some(background) = self.background.as_ref() {
236                background.simulate_random_delay().await;
237            }
238
239            self.runtime.as_ref().unwrap().block_on(body)
240        }
241
242        #[cfg(not(test))]
243        {
244            body.await
245        }
246    }
247}
248
249pub struct RoomGuard<T> {
250    data: T,
251    _guard: OwnedMutexGuard<()>,
252    _not_send: PhantomData<Rc<()>>,
253}
254
255impl<T> Deref for RoomGuard<T> {
256    type Target = T;
257
258    fn deref(&self) -> &T {
259        &self.data
260    }
261}
262
263impl<T> DerefMut for RoomGuard<T> {
264    fn deref_mut(&mut self) -> &mut T {
265        &mut self.data
266    }
267}
268
269macro_rules! id_type {
270    ($name:ident) => {
271        #[derive(
272            Clone,
273            Copy,
274            Debug,
275            Default,
276            PartialEq,
277            Eq,
278            PartialOrd,
279            Ord,
280            Hash,
281            sqlx::Type,
282            Serialize,
283            Deserialize,
284        )]
285        #[sqlx(transparent)]
286        #[serde(transparent)]
287        pub struct $name(pub i32);
288
289        impl $name {
290            #[allow(unused)]
291            pub const MAX: Self = Self(i32::MAX);
292
293            #[allow(unused)]
294            pub fn from_proto(value: u64) -> Self {
295                Self(value as i32)
296            }
297
298            #[allow(unused)]
299            pub fn to_proto(self) -> u64 {
300                self.0 as u64
301            }
302        }
303
304        impl std::fmt::Display for $name {
305            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
306                self.0.fmt(f)
307            }
308        }
309    };
310}
311
312id_type!(UserId);
313id_type!(RoomId);
314id_type!(RoomParticipantId);
315id_type!(ProjectId);
316id_type!(WorktreeId);