WIP implementation, pausing work on this for now

Michael Sloan created

Change summary

crates/cloud_llm_client/src/cloud_llm_client.rs | 151 ++++++++---
crates/zeta/src/training_data_uploader.rs       | 226 +++++++++++++++++++
crates/zeta/src/zeta.rs                         |  31 ++
3 files changed, 361 insertions(+), 47 deletions(-)

Detailed changes

crates/cloud_llm_client/src/cloud_llm_client.rs 🔗

@@ -2,7 +2,9 @@ use std::str::FromStr;
 use std::sync::Arc;
 
 use anyhow::Context as _;
-use serde::{Deserialize, Serialize};
+use serde::{Deserialize, Deserializer, Serialize, Serializer};
+use serde_json::value::RawValue;
+use std::marker::PhantomData;
 use strum::{Display, EnumIter, EnumString};
 use uuid::Uuid;
 
@@ -152,30 +154,10 @@ pub struct PredictEditsBody {
     pub git_info: Option<PredictEditsGitInfo>,
 }
 
-/// Additional context only stored when can_collect_data is true for the corresponding edit
-/// predictions request.
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct PredictEditsAdditionalContext {
-    /// Path to the file in the repository that contains the input excerpt.
-    pub input_path: String,
-    /// Cursor position within the file that contains the input excerpt.
-    pub cursor_point: Point,
-    /// Cursor offset in bytes within the file that contains the input excerpt.
-    pub cursor_offset: usize,
-    #[serde(flatten)]
-    pub git_info: PredictEditsGitInfo,
-    /// Diagnostic near the cursor position.
-    #[serde(skip_serializing_if = "Vec::is_empty", default)]
-    pub diagnostic_groups: Vec<(String, Box<serde_json::value::RawValue>)>,
-    /// True if the diagnostics were truncated.
-    pub diagnostic_groups_truncated: bool,
-    /// Recently active files that may be within this repository.
-    #[serde(skip_serializing_if = "Vec::is_empty", default)]
-    pub recent_files: Vec<PredictEditsRecentFile>,
-}
-
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct PredictEditsGitInfo {
+    /// Full path of the repository (worktree name plus the repository's path relative to the worktree root).
+    pub worktree_path: Option<String>,
     /// SHA of git HEAD commit at time of prediction.
     #[serde(skip_serializing_if = "Option::is_none", default)]
     pub head_sha: Option<String>,
@@ -187,41 +169,64 @@ pub struct PredictEditsGitInfo {
     pub remote_upstream_url: Option<String>,
 }
 
-/// A zero-indexed point in a text buffer consisting of a row and column.
 #[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct Point {
-    pub row: u32,
-    pub column: u32,
+pub struct PredictEditsResponse {
+    pub request_id: Uuid,
+    pub output_excerpt: String,
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct PredictEditsRecentFile {
-    /// Path to a file within the repository.
-    pub path: String,
-    /// Most recent cursor position with the file.
-    pub cursor_point: Point,
-    /// Milliseconds between the editor for this file being active and the request time.
-    pub active_to_now_ms: u32,
-    /// Number of times the editor for this file was activated.
-    pub activation_count: u32,
-    /// Rough estimate of milliseconds the user was editing the file.
-    pub cumulative_time_editing_ms: u32,
-    /// Rough estimate of milliseconds the user was navigating within the file.
-    pub cumulative_time_navigating_ms: u32,
-    /// Whether the file is a multibuffer.
-    #[serde(skip_serializing_if = "is_default", default)]
-    pub is_multibuffer: bool,
+pub struct AcceptEditPredictionBody {
+    pub request_id: Uuid,
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct PredictEditsResponse {
+pub struct PredictEditsTrainingData {
     pub request_id: Uuid,
-    pub output_excerpt: String,
+    /// When true, `request_id` is an ID that corresponds to an edit prediction.
+    pub has_prediction: bool,
+    /// State that `events` is based on. Initially this is `GitHead` and subsequent uploads will
+    /// then be based on the previous upload.
+    pub diff_base: PredictEditsDiffBase,
+    /// Fine-grained edit events atop `diff_base`.
+    #[serde(skip_serializing_if = "Vec::is_empty", default)]
+    pub events: Vec<SerializedJson<PredictEditsEvent>>,
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct AcceptEditPredictionBody {
-    pub request_id: Uuid,
+#[serde(rename_all = "snake_case")]
+pub enum PredictEditsDiffBase {
+    GitHead { git_info: PredictEditsGitInfo },
+    PreviousUpload { request_id: Uuid },
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct PredictEditsEvent {
+    pub entry_id: usize,
+    #[serde(skip_serializing_if = "Option::is_none", default)]
+    pub path: Option<String>,
+    pub timestamp_ms: u64,
+    pub data: PredictEditsEventData,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "snake_case")]
+pub enum PredictEditsEventData {
+    MoveCursor {
+        offset: usize,
+        #[serde(skip_serializing_if = "Vec::is_empty", default)]
+        diagnostic_groups: Vec<(String, Box<RawValue>)>,
+        #[serde(skip_serializing_if = "is_default", default)]
+        diagnostic_groups_truncated: bool,
+    },
+    Create {
+        content: String,
+    },
+    Delete,
+    Edit {
+        unified_diff: String,
+    },
+    MarkDiffTooLarge,
 }
 
 #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
@@ -382,6 +387,58 @@ pub struct UsageData {
     pub limit: UsageLimit,
 }
 
+#[derive(Debug, Clone)]
+pub struct SerializedJson<T> {
+    raw: Box<RawValue>,
+    _phantom: PhantomData<T>,
+}
+
+impl<T> SerializedJson<T>
+where
+    T: Serialize + for<'de> Deserialize<'de>,
+{
+    pub fn new(value: &T) -> Result<Self, serde_json::Error> {
+        Ok(SerializedJson {
+            raw: serde_json::value::to_raw_value(value)?,
+            _phantom: PhantomData,
+        })
+    }
+
+    pub fn deserialize(&self) -> Result<T, serde_json::Error> {
+        serde_json::from_str(self.raw.get())
+    }
+
+    pub fn as_raw(&self) -> &RawValue {
+        &self.raw
+    }
+
+    pub fn into_raw(self) -> Box<RawValue> {
+        self.raw
+    }
+}
+
+impl<T> Serialize for SerializedJson<T> {
+    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: Serializer,
+    {
+        self.raw.serialize(serializer)
+    }
+}
+
+impl<'de, T> Deserialize<'de> for SerializedJson<T> {
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        let raw = Box::<RawValue>::deserialize(deserializer)?;
+        Ok(SerializedJson {
+            raw,
+            _phantom: PhantomData,
+        })
+    }
+}
+
 fn is_default<T: Default + PartialEq>(value: &T) -> bool {
     *value == T::default()
 }

crates/zeta/src/training_data_uploader.rs 🔗

@@ -0,0 +1,226 @@
+use std::collections::hash_map;
+
+use cloud_llm_client::{PredictEditsEvent, PredictEditsGitInfo, SerializedJson};
+use collections::{HashMap, HashSet};
+use fs::MTime;
+use gpui::{AppContext as _, Context, Entity, EntityId, Task, WeakEntity};
+use language::{Buffer, BufferEvent};
+use project::{
+    Project, ProjectEntryId, ProjectPath,
+    buffer_store::{BufferStore, BufferStoreEvent},
+    git_store::{GitStore, GitStoreEvent, Repository, RepositoryId},
+    worktree_store::{WorktreeStore, WorktreeStoreEvent},
+};
+use uuid::Uuid;
+
+use crate::license_detection::LicenseDetectionWatcher;
+
+// todos:
+//
+// * Don't subscribe to all buffers
+//
+// * Currently a MoveCursor event is only recorded for edit prediction requests.
+
+pub struct TrainingDataUploader {
+    projects: HashMap<EntityId, Entity<ZetaProject>>,
+    _upload_task: Task<()>,
+}
+
+struct ZetaProject {
+    project: WeakEntity<Project>,
+    repositories: HashMap<RepositoryId, Entity<ZetaRepository>>,
+    buffers_changed: HashSet<WeakEntity<Buffer>>,
+    project_entries_changed: HashSet<ProjectEntryId>,
+}
+
+struct ZetaRepository {
+    unsent_events: Vec<SerializedJson<PredictEditsEvent>>,
+    pending_event: Option<PredictEditsEvent>,
+    last_snapshot: Option<ZetaRepositorySnapshot>,
+    license_watcher: LicenseDetectionWatcher,
+}
+
+struct ZetaRepositorySnapshot {
+    request_id: Uuid,
+    git_info: PredictEditsGitInfo,
+    buffers: HashMap<ProjectEntryId, ZetaBufferSnapshot>,
+    files: HashMap<ProjectEntryId, ZetaFileSnapshot>,
+}
+
+struct ZetaBufferSnapshot {
+    path: ProjectPath,
+    text: String,
+    buffer: WeakEntity<Buffer>,
+    version: clock::Global,
+}
+
+struct ZetaFileSnapshot {
+    path: ProjectPath,
+    text: String,
+    mtime: MTime,
+}
+
+impl TrainingDataUploader {
+    pub fn new(cx: &mut Context<Self>) -> Self {
+        let _upload_task = cx.spawn(|this, cx| {
+            loop {
+                todo!();
+            }
+        });
+        Self {
+            projects: HashMap::default(),
+            _upload_task,
+        }
+    }
+
+    fn register(&mut self, project: &Entity<Project>, path: ProjectPath, cx: &mut Context<Self>) {
+        let project_entity_id = project.entity_id();
+
+        let zeta_project = match self.projects.entry(project_entity_id) {
+            hash_map::Entry::Vacant(entry) => {
+                let zeta_project = cx.new(|cx| ZetaProject::new(project, cx));
+                cx.observe_release(project, move |this, project, cx| {
+                    this.projects.remove(&project_entity_id);
+                });
+                entry.insert(zeta_project)
+            }
+            hash_map::Entry::Occupied(entry) => entry.into_mut(),
+        };
+
+        // todo!
+        // zeta_project.update(|zeta_project, cx| zeta_project.register(path, cx));
+    }
+}
+
+impl ZetaProject {
+    pub fn new(project: &Entity<Project>, cx: &mut Context<Self>) -> Self {
+        cx.subscribe(&project, Self::handle_project_event).detach();
+        cx.subscribe(
+            &project.read(cx).git_store().clone(),
+            Self::handle_git_store_event,
+        )
+        .detach();
+        cx.subscribe(
+            &project.read(cx).worktree_store(),
+            Self::handle_worktree_store_event,
+        )
+        .detach();
+
+        let buffer_store = project.read(cx).buffer_store().clone();
+        for buffer in buffer_store.read(cx).buffers().collect::<Vec<_>>() {
+            Self::register_buffer(&buffer, cx);
+        }
+        cx.subscribe(&buffer_store, Self::handle_buffer_store_event)
+            .detach();
+
+        Self {
+            project: project.downgrade(),
+            repositories: HashMap::default(),
+            buffers_changed: HashSet::default(),
+            project_entries_changed: HashSet::default(),
+        }
+    }
+
+    fn handle_git_store_event(
+        &mut self,
+        _git_store: Entity<GitStore>,
+        event: &GitStoreEvent,
+        cx: &mut Context<Self>,
+    ) {
+        use GitStoreEvent::*;
+        match event {
+            RepositoryRemoved(repository_id) => {
+                self.repositories.remove(&repository_id);
+            }
+            RepositoryAdded(repository_id) => {
+                self.repositories
+                    .insert(*repository_id, cx.new(|cx| ZetaRepository::new(cx)));
+            }
+            RepositoryUpdated(repository_id, event, is_active) => {}
+            ActiveRepositoryChanged { .. }
+            | IndexWriteError { .. }
+            | JobsUpdated
+            | ConflictsUpdated => {}
+        }
+    }
+
+    fn handle_worktree_store_event(
+        &mut self,
+        _worktree_store: Entity<WorktreeStore>,
+        event: &WorktreeStoreEvent,
+        cx: &mut Context<Self>,
+    ) {
+        use WorktreeStoreEvent::*;
+        match event {
+            WorktreeAdded(worktree) => {}
+            WorktreeRemoved(worktree_entity_id, worktree_id) => {}
+            WorktreeUpdatedEntries(worktree_id, updated_entries_set) => {
+                for (path, entry_id, _path_change) in updated_entries_set.iter() {
+                    self.project_entries_changed.insert(*entry_id);
+                }
+            }
+            WorktreeUpdatedGitRepositories(worktree_id, updated_git_repositories) => {}
+            WorktreeDeletedEntry(worktree_id, project_entry_id) => {}
+            WorktreeReleased { .. } | WorktreeOrderChanged | WorktreeUpdateSent { .. } => {}
+        }
+    }
+
+    fn handle_buffer_store_event(
+        &mut self,
+        _buffer_store: Entity<BufferStore>,
+        event: &BufferStoreEvent,
+        cx: &mut Context<Self>,
+    ) {
+        use BufferStoreEvent::*;
+        match event {
+            BufferAdded(buffer) => Self::register_buffer(buffer, cx),
+            BufferOpened { .. }
+            | BufferChangedFilePath { .. }
+            | BufferDropped { .. }
+            | SharedBufferClosed { .. } => {}
+        }
+    }
+
+    fn register_buffer(buffer: &Entity<Buffer>, cx: &mut Context<Self>) {
+        cx.subscribe(buffer, Self::handle_buffer_event);
+    }
+
+    fn handle_buffer_event(
+        &mut self,
+        buffer: Entity<Buffer>,
+        event: &BufferEvent,
+        _cx: &mut Context<Self>,
+    ) {
+        match event {
+            BufferEvent::Edited => {
+                self.buffers_changed.insert(buffer.downgrade());
+            }
+            _ => {}
+        }
+    }
+
+    fn handle_project_event(
+        &mut self,
+        _project: Entity<Project>,
+        event: &project::Event,
+        cx: &mut Context<Self>,
+    ) {
+        match event {
+            project::Event::ActiveEntryChanged(entry_id) => {
+                todo!()
+            }
+            _ => {}
+        }
+    }
+}
+
+impl ZetaRepository {
+    pub fn new(cx: &mut Context<Self>) -> Self {
+        Self {
+            unsent_events: Vec::new(),
+            pending_event: None,
+            last_snapshot: None,
+            license_watcher: LicenseDetectionWatcher::new(cx),
+        }
+    }
+}

crates/zeta/src/zeta.rs 🔗

@@ -5,6 +5,7 @@ mod license_detection;
 mod onboarding_modal;
 mod onboarding_telemetry;
 mod rate_completion_modal;
+mod training_data_uploader;
 
 use arrayvec::ArrayVec;
 pub(crate) use completion_diff_element::*;
@@ -60,6 +61,8 @@ use workspace::Workspace;
 use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
 use worktree::Worktree;
 
+use crate::training_data_uploader::TrainingDataUploader;
+
 const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
 const START_OF_FILE_MARKER: &str = "<|start_of_file|>";
 const EDITABLE_REGION_START_MARKER: &str = "<|editable_region_start|>";
@@ -247,6 +250,8 @@ pub struct Zeta {
     user_store: Entity<UserStore>,
     license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
     projects: HashMap<EntityId, ZetaProject>,
+    /// todo! document that this should be the only place Entity<TrainingDataUploader> is stored.
+    training_data_uploader: Option<Entity<TrainingDataUploader>>,
 }
 
 struct ZetaProject {
@@ -377,6 +382,7 @@ impl Zeta {
             license_detection_watchers: HashMap::default(),
             user_store,
             projects: HashMap::default(),
+            training_data_uploader: None,
         }
     }
 
@@ -529,6 +535,31 @@ impl Zeta {
             can_collect_data,
         ));
 
+        // todo! async
+        if matches!(can_collect_data, CanCollectData(true)) {
+            let training_data_uploader = match &self.training_data_uploader {
+                None => {
+                    let training_data_uploader = cx.new(|cx| TrainingDataUploader::new(cx));
+                    self.training_data_uploader = Some(training_data_uploader.clone());
+                    &training_data_uploader
+                }
+                Some(training_data_uploader) => training_data_uploader,
+            };
+
+            training_data_uploader.update(cx, |training_data_uploader, cx| {
+                let project = project.read(cx);
+                let entry = project.entry_for_path(&project_path, cx)?;
+                if !worktree_entry_is_eligible_for_collection(&entry) {
+                    return None;
+                }
+
+                let git_store = project.git_store().read(cx);
+                let (repository, repo_path) =
+                    git_store.repository_and_path_for_project_path(&project_path, cx)?;
+                let repo_path_string = repo_path.to_str()?.to_string();
+            });
+        }
+
         cx.spawn(async move |this, cx| {
             let GatherContextOutput {
                 body,