@@ -25,7 +25,7 @@ use gpui::{
prelude::*,
};
use language::language_settings::all_language_settings;
-use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToPoint};
+use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
use language::{BufferSnapshot, OffsetRangeExt};
use language_model::{LlmApiToken, RefreshLlmTokenListener};
use project::{Project, ProjectPath, WorktreeId};
@@ -42,7 +42,7 @@ use std::path::Path;
use std::rc::Rc;
use std::str::FromStr as _;
use std::sync::{Arc, LazyLock};
-use std::time::{Duration, Instant};
+use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
use std::{env, mem};
use thiserror::Error;
use util::{RangeExt as _, ResultExt as _};
@@ -197,6 +197,7 @@ pub struct EditPredictionModelInput {
trigger: PredictEditsRequestTrigger,
diagnostic_search_range: Range<Point>,
debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
+ pub user_actions: Vec<UserActionRecord>,
}
#[derive(Debug, Clone, PartialEq)]
@@ -243,6 +244,26 @@ pub struct EditPredictionFinishedDebugEvent {
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
+const USER_ACTION_HISTORY_SIZE: usize = 16;
+
+#[derive(Clone, Debug)]
+pub struct UserActionRecord {
+ pub action_type: UserActionType,
+ pub buffer_id: EntityId,
+ pub line_number: u32,
+ pub offset: usize,
+ pub timestamp_epoch_ms: u64,
+}
+
+#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+pub enum UserActionType {
+ InsertChar,
+ InsertSelection,
+ DeleteChar,
+ DeleteSelection,
+ CursorMovement,
+}
+
/// An event with associated metadata for reconstructing buffer state.
#[derive(Clone)]
pub struct StoredEvent {
@@ -263,10 +284,18 @@ struct ProjectState {
cancelled_predictions: HashSet<usize>,
context: Entity<RelatedExcerptStore>,
license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
+ user_actions: VecDeque<UserActionRecord>,
_subscription: gpui::Subscription,
}
impl ProjectState {
+ fn record_user_action(&mut self, action: UserActionRecord) {
+ if self.user_actions.len() >= USER_ACTION_HISTORY_SIZE {
+ self.user_actions.pop_front();
+ }
+ self.user_actions.push_back(action);
+ }
+
pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
self.events
.iter()
@@ -803,6 +832,7 @@ impl EditPredictionStore {
next_pending_prediction_id: 0,
last_prediction_refresh: None,
license_detection_watchers: HashMap::default(),
+ user_actions: VecDeque::with_capacity(USER_ACTION_HISTORY_SIZE),
_subscription: cx.subscribe(&project, Self::handle_project_event),
})
}
@@ -999,7 +1029,51 @@ impl EditPredictionStore {
let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
- let edit_range = edited_range(&old_snapshot, &new_snapshot);
+ let mut num_edits = 0usize;
+ let mut total_deleted = 0usize;
+ let mut total_inserted = 0usize;
+ let mut edit_range: Option<Range<Anchor>> = None;
+ let mut last_offset: Option<usize> = None;
+
+ for (edit, anchor_range) in
+ new_snapshot.anchored_edits_since::<usize>(&old_snapshot.version)
+ {
+ num_edits += 1;
+ total_deleted += edit.old.len();
+ total_inserted += edit.new.len();
+ edit_range = Some(match edit_range {
+ None => anchor_range,
+ Some(acc) => acc.start..anchor_range.end,
+ });
+ last_offset = Some(edit.new.end);
+ }
+
+ if num_edits > 0 {
+ let action_type = match (total_deleted, total_inserted, num_edits) {
+ (0, ins, n) if ins == n => UserActionType::InsertChar,
+ (0, _, _) => UserActionType::InsertSelection,
+ (del, 0, n) if del == n => UserActionType::DeleteChar,
+ (_, 0, _) => UserActionType::DeleteSelection,
+ (_, ins, n) if ins == n => UserActionType::InsertChar,
+ (_, _, _) => UserActionType::InsertSelection,
+ };
+
+ if let Some(offset) = last_offset {
+ let point = new_snapshot.offset_to_point(offset);
+ let timestamp_epoch_ms = SystemTime::now()
+ .duration_since(UNIX_EPOCH)
+ .map(|d| d.as_millis() as u64)
+ .unwrap_or(0);
+ project_state.record_user_action(UserActionRecord {
+ action_type,
+ buffer_id: buffer.entity_id(),
+ line_number: point.row,
+ offset,
+ timestamp_epoch_ms,
+ });
+ }
+ }
+
let events = &mut project_state.events;
let now = cx.background_executor().now();
@@ -1615,6 +1689,28 @@ impl EditPredictionStore {
let snapshot = active_buffer.read(cx).snapshot();
let cursor_point = position.to_point(&snapshot);
+ let current_offset = position.to_offset(&snapshot);
+
+ let mut user_actions: Vec<UserActionRecord> =
+ project_state.user_actions.iter().cloned().collect();
+
+ if let Some(last_action) = user_actions.last() {
+ if last_action.buffer_id == active_buffer.entity_id()
+ && current_offset != last_action.offset
+ {
+ let timestamp_epoch_ms = SystemTime::now()
+ .duration_since(UNIX_EPOCH)
+ .map(|d| d.as_millis() as u64)
+ .unwrap_or(0);
+ user_actions.push(UserActionRecord {
+ action_type: UserActionType::CursorMovement,
+ buffer_id: active_buffer.entity_id(),
+ line_number: cursor_point.row,
+ offset: current_offset,
+ timestamp_epoch_ms,
+ });
+ }
+ }
let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
let diagnostic_search_range =
@@ -1637,6 +1733,7 @@ impl EditPredictionStore {
trigger,
diagnostic_search_range: diagnostic_search_range.clone(),
debug_tx,
+ user_actions,
};
let can_collect_example = snapshot
@@ -2096,20 +2193,6 @@ impl EditPredictionStore {
}
}
-fn edited_range(
- old_snapshot: &TextBufferSnapshot,
- new_snapshot: &TextBufferSnapshot,
-) -> Option<Range<Anchor>> {
- new_snapshot
- .anchored_edits_since::<usize>(&old_snapshot.version)
- .fold(None, |acc, (_, range)| {
- Some(match acc {
- None => range,
- Some(acc) => acc.start..range.end,
- })
- })
-}
-
#[derive(Error, Debug)]
#[error(
"You must update to Zed version {minimum_version} or higher to continue using edit predictions."
@@ -1,7 +1,7 @@
use crate::{
CurrentEditPrediction, DebugEvent, EditPrediction, EditPredictionFinishedDebugEvent,
EditPredictionId, EditPredictionModelInput, EditPredictionStartedDebugEvent,
- EditPredictionStore, prediction::EditPredictionResult,
+ EditPredictionStore, UserActionRecord, UserActionType, prediction::EditPredictionResult,
};
use anyhow::{Result, bail};
use client::Client;
@@ -68,6 +68,7 @@ impl SweepAi {
.unwrap_or("untitled")
.into();
let offset = inputs.position.to_offset(&inputs.snapshot);
+ let buffer_entity_id = inputs.buffer.entity_id();
let recent_buffers = inputs.recent_paths.iter().cloned();
let http_client = cx.http_client();
@@ -171,6 +172,14 @@ impl SweepAi {
});
}
+ let file_path_str = full_path.display().to_string();
+ let recent_user_actions = inputs
+ .user_actions
+ .iter()
+ .filter(|r| r.buffer_id == buffer_entity_id)
+ .map(|r| to_sweep_user_action(r, &file_path_str))
+ .collect();
+
let request_body = AutocompleteRequest {
debug_info,
repo_name,
@@ -184,7 +193,7 @@ impl SweepAi {
branch: None,
file_chunks,
retrieval_chunks,
- recent_user_actions: vec![],
+ recent_user_actions,
use_bytes: true,
// TODO
privacy_mode_enabled: false,
@@ -386,7 +395,6 @@ struct UserAction {
pub timestamp: u64,
}
-#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
enum ActionType {
@@ -397,6 +405,22 @@ enum ActionType {
DeleteSelection,
}
+fn to_sweep_user_action(record: &UserActionRecord, file_path: &str) -> UserAction {
+ UserAction {
+ action_type: match record.action_type {
+ UserActionType::InsertChar => ActionType::InsertChar,
+ UserActionType::InsertSelection => ActionType::InsertSelection,
+ UserActionType::DeleteChar => ActionType::DeleteChar,
+ UserActionType::DeleteSelection => ActionType::DeleteSelection,
+ UserActionType::CursorMovement => ActionType::CursorMovement,
+ },
+ line_number: record.line_number as usize,
+ offset: record.offset,
+ file_path: file_path.to_string(),
+ timestamp: record.timestamp_epoch_ms,
+ }
+}
+
#[derive(Debug, Clone, Deserialize)]
struct AutocompleteResponse {
pub autocomplete_id: String,