Implement `Db::summarize_user_activity`

Antonio Scandurra and Max Brunsfeld created

Co-Authored-By: Max Brunsfeld <max@zed.dev>

Change summary

crates/collab/src/db.rs | 166 +++++++++++++++++++++++++++++++++++++++++-
1 file changed, 162 insertions(+), 4 deletions(-)

Detailed changes

crates/collab/src/db.rs 🔗

@@ -10,7 +10,7 @@ use nanoid::nanoid;
 use serde::Serialize;
 pub use sqlx::postgres::PgPoolOptions as DbOptions;
 use sqlx::{types::Uuid, FromRow, QueryBuilder, Row};
-use time::OffsetDateTime;
+use time::{OffsetDateTime, PrimitiveDateTime};
 
 #[async_trait]
 pub trait Db: Send + Sync {
@@ -77,6 +77,13 @@ pub trait Db: Send + Sync {
         max_user_count: usize,
     ) -> Result<Vec<UserActivitySummary>>;
 
+    /// Get the project activity for the given user and time period.
+    async fn summarize_user_activity(
+        &self,
+        user_id: UserId,
+        time_period: Range<OffsetDateTime>,
+    ) -> Result<Vec<UserActivityDuration>>;
+
     async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>>;
     async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool>;
     async fn send_contact_request(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
@@ -596,7 +603,7 @@ impl Db for PostgresDb {
                 project_durations AS (
                     SELECT user_id, project_id, SUM(duration_millis) AS project_duration
                     FROM project_activity_periods
-                    WHERE $1 <= ended_at AND ended_at <= $2
+                    WHERE $1 < ended_at AND ended_at <= $2
                     GROUP BY user_id, project_id
                 ),
                 user_durations AS (
@@ -641,6 +648,89 @@ impl Db for PostgresDb {
         Ok(result)
     }
 
+    async fn summarize_user_activity(
+        &self,
+        user_id: UserId,
+        time_period: Range<OffsetDateTime>,
+    ) -> Result<Vec<UserActivityDuration>> {
+        const COALESCE_THRESHOLD: Duration = Duration::from_secs(5);
+
+        let query = "
+            SELECT
+                project_activity_periods.ended_at,
+                project_activity_periods.duration_millis,
+                project_activity_periods.project_id,
+                worktree_extensions.extension,
+                worktree_extensions.count
+            FROM project_activity_periods
+            LEFT OUTER JOIN
+                worktree_extensions
+            ON
+                project_activity_periods.project_id = worktree_extensions.project_id
+            WHERE
+                project_activity_periods.user_id = $1 AND
+                $2 < project_activity_periods.ended_at AND
+                project_activity_periods.ended_at <= $3
+            ORDER BY project_activity_periods.id ASC
+        ";
+
+        let mut rows = sqlx::query_as::<
+            _,
+            (
+                PrimitiveDateTime,
+                i32,
+                ProjectId,
+                Option<String>,
+                Option<i32>,
+            ),
+        >(query)
+        .bind(user_id)
+        .bind(time_period.start)
+        .bind(time_period.end)
+        .fetch(&self.pool);
+
+        let mut durations: HashMap<ProjectId, Vec<UserActivityDuration>> = Default::default();
+        while let Some(row) = rows.next().await {
+            let (ended_at, duration_millis, project_id, extension, extension_count) = row?;
+            let ended_at = ended_at.assume_utc();
+            let duration = Duration::from_millis(duration_millis as u64);
+            let started_at = ended_at - duration;
+            let project_durations = durations.entry(project_id).or_default();
+
+            if let Some(prev_duration) = project_durations.last_mut() {
+                if started_at - prev_duration.end <= COALESCE_THRESHOLD {
+                    prev_duration.end = ended_at;
+                } else {
+                    project_durations.push(UserActivityDuration {
+                        project_id,
+                        start: started_at,
+                        end: ended_at,
+                        extensions: Default::default(),
+                    });
+                }
+            } else {
+                project_durations.push(UserActivityDuration {
+                    project_id,
+                    start: started_at,
+                    end: ended_at,
+                    extensions: Default::default(),
+                });
+            }
+
+            if let Some((extension, extension_count)) = extension.zip(extension_count) {
+                project_durations
+                    .last_mut()
+                    .unwrap()
+                    .extensions
+                    .insert(extension, extension_count as usize);
+            }
+        }
+
+        let mut durations = durations.into_values().flatten().collect::<Vec<_>>();
+        durations.sort_unstable_by_key(|duration| duration.start);
+        Ok(durations)
+    }
+
     // contacts
 
     async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
@@ -1172,6 +1262,14 @@ pub struct UserActivitySummary {
     pub project_activity: Vec<(ProjectId, Duration)>,
 }
 
+#[derive(Clone, Debug, PartialEq, Serialize)]
+pub struct UserActivityDuration {
+    project_id: ProjectId,
+    start: OffsetDateTime,
+    end: OffsetDateTime,
+    extensions: HashMap<String, usize>,
+}
+
 id_type!(OrgId);
 #[derive(FromRow)]
 pub struct Org {
@@ -1439,6 +1537,13 @@ pub mod tests {
         let user_2 = db.create_user("user_2", None, false).await.unwrap();
         let user_3 = db.create_user("user_3", None, false).await.unwrap();
         let project_1 = db.register_project(user_1).await.unwrap();
+        db.update_worktree_extensions(
+            project_1,
+            1,
+            HashMap::from_iter([("rs".into(), 5), ("md".into(), 7)]),
+        )
+        .await
+        .unwrap();
         let project_2 = db.register_project(user_2).await.unwrap();
         let t0 = OffsetDateTime::now_utc() - Duration::from_secs(60 * 60);
 
@@ -1492,9 +1597,14 @@ pub mod tests {
             .await
             .unwrap();
 
-        let summary = db.summarize_project_activity(t0..t6, 10).await.unwrap();
+        let t7 = t6 + Duration::from_secs(60);
+        let t8 = t7 + Duration::from_secs(10);
+        db.record_project_activity(t7..t8, &[(user_1, project_1)])
+            .await
+            .unwrap();
+
         assert_eq!(
-            summary,
+            db.summarize_project_activity(t0..t6, 10).await.unwrap(),
             &[
                 UserActivitySummary {
                     id: user_1,
@@ -1516,6 +1626,46 @@ pub mod tests {
                 },
             ]
         );
+        assert_eq!(
+            db.summarize_user_activity(user_1, t3..t6).await.unwrap(),
+            &[
+                UserActivityDuration {
+                    project_id: project_1,
+                    start: t3,
+                    end: t6,
+                    extensions: HashMap::from_iter([("rs".to_string(), 5), ("md".to_string(), 7)]),
+                },
+                UserActivityDuration {
+                    project_id: project_2,
+                    start: t3,
+                    end: t5,
+                    extensions: Default::default(),
+                },
+            ]
+        );
+        assert_eq!(
+            db.summarize_user_activity(user_1, t0..t8).await.unwrap(),
+            &[
+                UserActivityDuration {
+                    project_id: project_2,
+                    start: t2,
+                    end: t5,
+                    extensions: Default::default(),
+                },
+                UserActivityDuration {
+                    project_id: project_1,
+                    start: t3,
+                    end: t6,
+                    extensions: HashMap::from_iter([("rs".to_string(), 5), ("md".to_string(), 7)]),
+                },
+                UserActivityDuration {
+                    project_id: project_1,
+                    start: t7,
+                    end: t8,
+                    extensions: HashMap::from_iter([("rs".to_string(), 5), ("md".to_string(), 7)]),
+                },
+            ]
+        );
     }
 
     #[tokio::test(flavor = "multi_thread")]
@@ -2316,6 +2466,14 @@ pub mod tests {
             unimplemented!()
         }
 
+        async fn summarize_user_activity(
+            &self,
+            _user_id: UserId,
+            _time_period: Range<OffsetDateTime>,
+        ) -> Result<Vec<UserActivityDuration>> {
+            unimplemented!()
+        }
+
         // contacts
 
         async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>> {