Add action tracking to enable token healing for Sweep (#46212)

Ben Kunkle created

Closes #ISSUE

Release Notes:

- N/A *or* Added/Fixed/Improved ...

Change summary

crates/edit_prediction/src/edit_prediction.rs       | 117 ++++++++++++--
crates/edit_prediction/src/edit_prediction_tests.rs |   2 
crates/edit_prediction/src/sweep_ai.rs              |  30 +++
3 files changed, 128 insertions(+), 21 deletions(-)

Detailed changes

crates/edit_prediction/src/edit_prediction.rs 🔗

@@ -25,7 +25,7 @@ use gpui::{
     prelude::*,
 };
 use language::language_settings::all_language_settings;
-use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToPoint};
+use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
 use language::{BufferSnapshot, OffsetRangeExt};
 use language_model::{LlmApiToken, RefreshLlmTokenListener};
 use project::{Project, ProjectPath, WorktreeId};
@@ -42,7 +42,7 @@ use std::path::Path;
 use std::rc::Rc;
 use std::str::FromStr as _;
 use std::sync::{Arc, LazyLock};
-use std::time::{Duration, Instant};
+use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
 use std::{env, mem};
 use thiserror::Error;
 use util::{RangeExt as _, ResultExt as _};
@@ -197,6 +197,7 @@ pub struct EditPredictionModelInput {
     trigger: PredictEditsRequestTrigger,
     diagnostic_search_range: Range<Point>,
     debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
+    pub user_actions: Vec<UserActionRecord>,
 }
 
 #[derive(Debug, Clone, PartialEq)]
@@ -243,6 +244,26 @@ pub struct EditPredictionFinishedDebugEvent {
 
 pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
 
+const USER_ACTION_HISTORY_SIZE: usize = 16;
+
+#[derive(Clone, Debug)]
+pub struct UserActionRecord {
+    pub action_type: UserActionType,
+    pub buffer_id: EntityId,
+    pub line_number: u32,
+    pub offset: usize,
+    pub timestamp_epoch_ms: u64,
+}
+
+#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+pub enum UserActionType {
+    InsertChar,
+    InsertSelection,
+    DeleteChar,
+    DeleteSelection,
+    CursorMovement,
+}
+
 /// An event with associated metadata for reconstructing buffer state.
 #[derive(Clone)]
 pub struct StoredEvent {
@@ -263,10 +284,18 @@ struct ProjectState {
     cancelled_predictions: HashSet<usize>,
     context: Entity<RelatedExcerptStore>,
     license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
+    user_actions: VecDeque<UserActionRecord>,
     _subscription: gpui::Subscription,
 }
 
 impl ProjectState {
+    fn record_user_action(&mut self, action: UserActionRecord) {
+        if self.user_actions.len() >= USER_ACTION_HISTORY_SIZE {
+            self.user_actions.pop_front();
+        }
+        self.user_actions.push_back(action);
+    }
+
     pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
         self.events
             .iter()
@@ -803,6 +832,7 @@ impl EditPredictionStore {
                 next_pending_prediction_id: 0,
                 last_prediction_refresh: None,
                 license_detection_watchers: HashMap::default(),
+                user_actions: VecDeque::with_capacity(USER_ACTION_HISTORY_SIZE),
                 _subscription: cx.subscribe(&project, Self::handle_project_event),
             })
     }
@@ -999,7 +1029,51 @@ impl EditPredictionStore {
 
         let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
         let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
-        let edit_range = edited_range(&old_snapshot, &new_snapshot);
+        let mut num_edits = 0usize;
+        let mut total_deleted = 0usize;
+        let mut total_inserted = 0usize;
+        let mut edit_range: Option<Range<Anchor>> = None;
+        let mut last_offset: Option<usize> = None;
+
+        for (edit, anchor_range) in
+            new_snapshot.anchored_edits_since::<usize>(&old_snapshot.version)
+        {
+            num_edits += 1;
+            total_deleted += edit.old.len();
+            total_inserted += edit.new.len();
+            edit_range = Some(match edit_range {
+                None => anchor_range,
+                Some(acc) => acc.start..anchor_range.end,
+            });
+            last_offset = Some(edit.new.end);
+        }
+
+        if num_edits > 0 {
+            let action_type = match (total_deleted, total_inserted, num_edits) {
+                (0, ins, n) if ins == n => UserActionType::InsertChar,
+                (0, _, _) => UserActionType::InsertSelection,
+                (del, 0, n) if del == n => UserActionType::DeleteChar,
+                (_, 0, _) => UserActionType::DeleteSelection,
+                (_, ins, n) if ins == n => UserActionType::InsertChar,
+                (_, _, _) => UserActionType::InsertSelection,
+            };
+
+            if let Some(offset) = last_offset {
+                let point = new_snapshot.offset_to_point(offset);
+                let timestamp_epoch_ms = SystemTime::now()
+                    .duration_since(UNIX_EPOCH)
+                    .map(|d| d.as_millis() as u64)
+                    .unwrap_or(0);
+                project_state.record_user_action(UserActionRecord {
+                    action_type,
+                    buffer_id: buffer.entity_id(),
+                    line_number: point.row,
+                    offset,
+                    timestamp_epoch_ms,
+                });
+            }
+        }
+
         let events = &mut project_state.events;
 
         let now = cx.background_executor().now();
@@ -1615,6 +1689,28 @@ impl EditPredictionStore {
 
         let snapshot = active_buffer.read(cx).snapshot();
         let cursor_point = position.to_point(&snapshot);
+        let current_offset = position.to_offset(&snapshot);
+
+        let mut user_actions: Vec<UserActionRecord> =
+            project_state.user_actions.iter().cloned().collect();
+
+        if let Some(last_action) = user_actions.last() {
+            if last_action.buffer_id == active_buffer.entity_id()
+                && current_offset != last_action.offset
+            {
+                let timestamp_epoch_ms = SystemTime::now()
+                    .duration_since(UNIX_EPOCH)
+                    .map(|d| d.as_millis() as u64)
+                    .unwrap_or(0);
+                user_actions.push(UserActionRecord {
+                    action_type: UserActionType::CursorMovement,
+                    buffer_id: active_buffer.entity_id(),
+                    line_number: cursor_point.row,
+                    offset: current_offset,
+                    timestamp_epoch_ms,
+                });
+            }
+        }
         let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
         let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
         let diagnostic_search_range =
@@ -1637,6 +1733,7 @@ impl EditPredictionStore {
             trigger,
             diagnostic_search_range: diagnostic_search_range.clone(),
             debug_tx,
+            user_actions,
         };
 
         let can_collect_example = snapshot
@@ -2096,20 +2193,6 @@ impl EditPredictionStore {
     }
 }
 
-fn edited_range(
-    old_snapshot: &TextBufferSnapshot,
-    new_snapshot: &TextBufferSnapshot,
-) -> Option<Range<Anchor>> {
-    new_snapshot
-        .anchored_edits_since::<usize>(&old_snapshot.version)
-        .fold(None, |acc, (_, range)| {
-            Some(match acc {
-                None => range,
-                Some(acc) => acc.start..range.end,
-            })
-        })
-}
-
 #[derive(Error, Debug)]
 #[error(
     "You must update to Zed version {minimum_version} or higher to continue using edit predictions."

crates/edit_prediction/src/edit_prediction_tests.rs 🔗

@@ -16,7 +16,7 @@ use gpui::{
     http_client::{FakeHttpClient, Response},
 };
 use indoc::indoc;
-use language::{Point, ToOffset as _};
+use language::Point;
 use lsp::LanguageServerId;
 use open_ai::Usage;
 use parking_lot::Mutex;

crates/edit_prediction/src/sweep_ai.rs 🔗

@@ -1,7 +1,7 @@
 use crate::{
     CurrentEditPrediction, DebugEvent, EditPrediction, EditPredictionFinishedDebugEvent,
     EditPredictionId, EditPredictionModelInput, EditPredictionStartedDebugEvent,
-    EditPredictionStore, prediction::EditPredictionResult,
+    EditPredictionStore, UserActionRecord, UserActionType, prediction::EditPredictionResult,
 };
 use anyhow::{Result, bail};
 use client::Client;
@@ -68,6 +68,7 @@ impl SweepAi {
             .unwrap_or("untitled")
             .into();
         let offset = inputs.position.to_offset(&inputs.snapshot);
+        let buffer_entity_id = inputs.buffer.entity_id();
 
         let recent_buffers = inputs.recent_paths.iter().cloned();
         let http_client = cx.http_client();
@@ -171,6 +172,14 @@ impl SweepAi {
                 });
             }
 
+            let file_path_str = full_path.display().to_string();
+            let recent_user_actions = inputs
+                .user_actions
+                .iter()
+                .filter(|r| r.buffer_id == buffer_entity_id)
+                .map(|r| to_sweep_user_action(r, &file_path_str))
+                .collect();
+
             let request_body = AutocompleteRequest {
                 debug_info,
                 repo_name,
@@ -184,7 +193,7 @@ impl SweepAi {
                 branch: None,
                 file_chunks,
                 retrieval_chunks,
-                recent_user_actions: vec![],
+                recent_user_actions,
                 use_bytes: true,
                 // TODO
                 privacy_mode_enabled: false,
@@ -386,7 +395,6 @@ struct UserAction {
     pub timestamp: u64,
 }
 
-#[allow(dead_code)]
 #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
 #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
 enum ActionType {
@@ -397,6 +405,22 @@ enum ActionType {
     DeleteSelection,
 }
 
+fn to_sweep_user_action(record: &UserActionRecord, file_path: &str) -> UserAction {
+    UserAction {
+        action_type: match record.action_type {
+            UserActionType::InsertChar => ActionType::InsertChar,
+            UserActionType::InsertSelection => ActionType::InsertSelection,
+            UserActionType::DeleteChar => ActionType::DeleteChar,
+            UserActionType::DeleteSelection => ActionType::DeleteSelection,
+            UserActionType::CursorMovement => ActionType::CursorMovement,
+        },
+        line_number: record.line_number as usize,
+        offset: record.offset,
+        file_path: file_path.to_string(),
+        timestamp: record.timestamp_epoch_ms,
+    }
+}
+
 #[derive(Debug, Clone, Deserialize)]
 struct AutocompleteResponse {
     pub autocomplete_id: String,