From f3cb4e1b28202d4c8ec5abd006b9177adc398837 Mon Sep 17 00:00:00 2001 From: Michael Sloan Date: Thu, 4 Sep 2025 16:07:14 -0600 Subject: [PATCH] WIP implementation, pausing work on this for now --- .../cloud_llm_client/src/cloud_llm_client.rs | 151 ++++++++---- crates/zeta/src/training_data_uploader.rs | 226 ++++++++++++++++++ crates/zeta/src/zeta.rs | 31 +++ 3 files changed, 361 insertions(+), 47 deletions(-) create mode 100644 crates/zeta/src/training_data_uploader.rs diff --git a/crates/cloud_llm_client/src/cloud_llm_client.rs b/crates/cloud_llm_client/src/cloud_llm_client.rs index 4d037d3598f50b7dab1801873295b394e6b3191b..c315d49d774ba52b473db70f54616c57c6c2f7b6 100644 --- a/crates/cloud_llm_client/src/cloud_llm_client.rs +++ b/crates/cloud_llm_client/src/cloud_llm_client.rs @@ -2,7 +2,9 @@ use std::str::FromStr; use std::sync::Arc; use anyhow::Context as _; -use serde::{Deserialize, Serialize}; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use serde_json::value::RawValue; +use std::marker::PhantomData; use strum::{Display, EnumIter, EnumString}; use uuid::Uuid; @@ -152,30 +154,10 @@ pub struct PredictEditsBody { pub git_info: Option, } -/// Additional context only stored when can_collect_data is true for the corresponding edit -/// predictions request. -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct PredictEditsAdditionalContext { - /// Path to the file in the repository that contains the input excerpt. - pub input_path: String, - /// Cursor position within the file that contains the input excerpt. - pub cursor_point: Point, - /// Cursor offset in bytes within the file that contains the input excerpt. - pub cursor_offset: usize, - #[serde(flatten)] - pub git_info: PredictEditsGitInfo, - /// Diagnostic near the cursor position. - #[serde(skip_serializing_if = "Vec::is_empty", default)] - pub diagnostic_groups: Vec<(String, Box)>, - /// True if the diagnostics were truncated. - pub diagnostic_groups_truncated: bool, - /// Recently active files that may be within this repository. - #[serde(skip_serializing_if = "Vec::is_empty", default)] - pub recent_files: Vec, -} - #[derive(Debug, Clone, Serialize, Deserialize)] pub struct PredictEditsGitInfo { + /// full_path to the repo (worktree name + relative path to repo) + pub worktree_path: Option, /// SHA of git HEAD commit at time of prediction. #[serde(skip_serializing_if = "Option::is_none", default)] pub head_sha: Option, @@ -187,41 +169,64 @@ pub struct PredictEditsGitInfo { pub remote_upstream_url: Option, } -/// A zero-indexed point in a text buffer consisting of a row and column. #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Point { - pub row: u32, - pub column: u32, +pub struct PredictEditsResponse { + pub request_id: Uuid, + pub output_excerpt: String, } #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct PredictEditsRecentFile { - /// Path to a file within the repository. - pub path: String, - /// Most recent cursor position with the file. - pub cursor_point: Point, - /// Milliseconds between the editor for this file being active and the request time. - pub active_to_now_ms: u32, - /// Number of times the editor for this file was activated. - pub activation_count: u32, - /// Rough estimate of milliseconds the user was editing the file. - pub cumulative_time_editing_ms: u32, - /// Rough estimate of milliseconds the user was navigating within the file. - pub cumulative_time_navigating_ms: u32, - /// Whether the file is a multibuffer. - #[serde(skip_serializing_if = "is_default", default)] - pub is_multibuffer: bool, +pub struct AcceptEditPredictionBody { + pub request_id: Uuid, } #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct PredictEditsResponse { +pub struct PredictEditsTrainingData { pub request_id: Uuid, - pub output_excerpt: String, + /// When true, `request_id` is an ID that corresponds to an edit prediction. + pub has_prediction: bool, + /// State that `events` is based on. Initially this is `GitHead` and subsequent uploads will + /// then be based on the previous upload. + pub diff_base: PredictEditsDiffBase, + /// Fine-grained edit events atop `diff_base`. + #[serde(skip_serializing_if = "Vec::is_empty", default)] + pub events: Vec>, } #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct AcceptEditPredictionBody { - pub request_id: Uuid, +#[serde(rename_all = "snake_case")] +pub enum PredictEditsDiffBase { + GitHead { git_info: PredictEditsGitInfo }, + PreviousUpload { request_id: Uuid }, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PredictEditsEvent { + pub entry_id: usize, + #[serde(skip_serializing_if = "Option::is_none", default)] + pub path: Option, + pub timestamp_ms: u64, + pub data: PredictEditsEventData, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum PredictEditsEventData { + MoveCursor { + offset: usize, + #[serde(skip_serializing_if = "Vec::is_empty", default)] + diagnostic_groups: Vec<(String, Box)>, + #[serde(skip_serializing_if = "is_default", default)] + diagnostic_groups_truncated: bool, + }, + Create { + content: String, + }, + Delete, + Edit { + unified_diff: String, + }, + MarkDiffTooLarge, } #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)] @@ -382,6 +387,58 @@ pub struct UsageData { pub limit: UsageLimit, } +#[derive(Debug, Clone)] +pub struct SerializedJson { + raw: Box, + _phantom: PhantomData, +} + +impl SerializedJson +where + T: Serialize + for<'de> Deserialize<'de>, +{ + pub fn new(value: &T) -> Result { + Ok(SerializedJson { + raw: serde_json::value::to_raw_value(value)?, + _phantom: PhantomData, + }) + } + + pub fn deserialize(&self) -> Result { + serde_json::from_str(self.raw.get()) + } + + pub fn as_raw(&self) -> &RawValue { + &self.raw + } + + pub fn into_raw(self) -> Box { + self.raw + } +} + +impl Serialize for SerializedJson { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + self.raw.serialize(serializer) + } +} + +impl<'de, T> Deserialize<'de> for SerializedJson { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let raw = Box::::deserialize(deserializer)?; + Ok(SerializedJson { + raw, + _phantom: PhantomData, + }) + } +} + fn is_default(value: &T) -> bool { *value == T::default() } diff --git a/crates/zeta/src/training_data_uploader.rs b/crates/zeta/src/training_data_uploader.rs new file mode 100644 index 0000000000000000000000000000000000000000..e7c8e271e7235ae41bdef911908dc2261fcd6583 --- /dev/null +++ b/crates/zeta/src/training_data_uploader.rs @@ -0,0 +1,226 @@ +use std::collections::hash_map; + +use cloud_llm_client::{PredictEditsEvent, PredictEditsGitInfo, SerializedJson}; +use collections::{HashMap, HashSet}; +use fs::MTime; +use gpui::{AppContext as _, Context, Entity, EntityId, Task, WeakEntity}; +use language::{Buffer, BufferEvent}; +use project::{ + Project, ProjectEntryId, ProjectPath, + buffer_store::{BufferStore, BufferStoreEvent}, + git_store::{GitStore, GitStoreEvent, Repository, RepositoryId}, + worktree_store::{WorktreeStore, WorktreeStoreEvent}, +}; +use uuid::Uuid; + +use crate::license_detection::LicenseDetectionWatcher; + +// todos: +// +// * Don't subscribe to all buffers +// +// * Currently MoveCursor event will only happen for edit prediction requests. + +pub struct TrainingDataUploader { + projects: HashMap>, + _upload_task: Task<()>, +} + +struct ZetaProject { + project: WeakEntity, + repositories: HashMap>, + buffers_changed: HashSet>, + project_entries_changed: HashSet, +} + +struct ZetaRepository { + unsent_events: Vec>, + pending_event: Option, + last_snapshot: Option, + license_watcher: LicenseDetectionWatcher, +} + +struct ZetaRepositorySnapshot { + request_id: Uuid, + git_info: PredictEditsGitInfo, + buffers: HashMap, + files: HashMap, +} + +struct ZetaBufferSnapshot { + path: ProjectPath, + text: String, + buffer: WeakEntity, + version: clock::Global, +} + +struct ZetaFileSnapshot { + path: ProjectPath, + text: String, + mtime: MTime, +} + +impl TrainingDataUploader { + pub fn new(cx: &mut Context) -> Self { + let _upload_task = cx.spawn(|this, cx| { + loop { + todo!(); + } + }); + Self { + projects: HashMap::default(), + _upload_task, + } + } + + fn register(&mut self, project: &Entity, path: ProjectPath, cx: &mut Context) { + let project_entity_id = project.entity_id(); + + let zeta_project = match self.projects.entry(project_entity_id) { + hash_map::Entry::Vacant(entry) => { + let zeta_project = cx.new(|cx| ZetaProject::new(project, cx)); + cx.observe_release(project, move |this, project, cx| { + this.projects.remove(&project_entity_id); + }); + entry.insert(zeta_project) + } + hash_map::Entry::Occupied(entry) => entry.into_mut(), + }; + + // todo! + // zeta_project.update(|zeta_project, cx| zeta_project.register(path, cx)); + } +} + +impl ZetaProject { + pub fn new(project: &Entity, cx: &mut Context) -> Self { + cx.subscribe(&project, Self::handle_project_event).detach(); + cx.subscribe( + &project.read(cx).git_store().clone(), + Self::handle_git_store_event, + ) + .detach(); + cx.subscribe( + &project.read(cx).worktree_store(), + Self::handle_worktree_store_event, + ) + .detach(); + + let buffer_store = project.read(cx).buffer_store().clone(); + for buffer in buffer_store.read(cx).buffers().collect::>() { + Self::register_buffer(&buffer, cx); + } + cx.subscribe(&buffer_store, Self::handle_buffer_store_event) + .detach(); + + Self { + project: project.downgrade(), + repositories: HashMap::default(), + buffers_changed: HashSet::default(), + project_entries_changed: HashSet::default(), + } + } + + fn handle_git_store_event( + &mut self, + _git_store: Entity, + event: &GitStoreEvent, + cx: &mut Context, + ) { + use GitStoreEvent::*; + match event { + RepositoryRemoved(repository_id) => { + self.repositories.remove(&repository_id); + } + RepositoryAdded(repository_id) => { + self.repositories + .insert(*repository_id, cx.new(|cx| ZetaRepository::new(cx))); + } + RepositoryUpdated(repository_id, event, is_active) => {} + ActiveRepositoryChanged { .. } + | IndexWriteError { .. } + | JobsUpdated + | ConflictsUpdated => {} + } + } + + fn handle_worktree_store_event( + &mut self, + _worktree_store: Entity, + event: &WorktreeStoreEvent, + cx: &mut Context, + ) { + use WorktreeStoreEvent::*; + match event { + WorktreeAdded(worktree) => {} + WorktreeRemoved(worktree_entity_id, worktree_id) => {} + WorktreeUpdatedEntries(worktree_id, updated_entries_set) => { + for (path, entry_id, _path_change) in updated_entries_set.iter() { + self.project_entries_changed.insert(*entry_id); + } + } + WorktreeUpdatedGitRepositories(worktree_id, updated_git_repositories) => {} + WorktreeDeletedEntry(worktree_id, project_entry_id) => {} + WorktreeReleased { .. } | WorktreeOrderChanged | WorktreeUpdateSent { .. } => {} + } + } + + fn handle_buffer_store_event( + &mut self, + _buffer_store: Entity, + event: &BufferStoreEvent, + cx: &mut Context, + ) { + use BufferStoreEvent::*; + match event { + BufferAdded(buffer) => Self::register_buffer(buffer, cx), + BufferOpened { .. } + | BufferChangedFilePath { .. } + | BufferDropped { .. } + | SharedBufferClosed { .. } => {} + } + } + + fn register_buffer(buffer: &Entity, cx: &mut Context) { + cx.subscribe(buffer, Self::handle_buffer_event); + } + + fn handle_buffer_event( + &mut self, + buffer: Entity, + event: &BufferEvent, + _cx: &mut Context, + ) { + match event { + BufferEvent::Edited => { + self.buffers_changed.insert(buffer.downgrade()); + } + _ => {} + } + } + + fn handle_project_event( + &mut self, + _project: Entity, + event: &project::Event, + cx: &mut Context, + ) { + match event { + project::Event::ActiveEntryChanged(entry_id) => { + todo!() + } + _ => {} + } + } +} + +impl ZetaRepository { + pub fn new(cx: &mut Context) -> Self { + Self { + unsent_events: Vec::new(), + pending_event: None, + last_snapshot: None, + license_watcher: LicenseDetectionWatcher::new(cx), + } + } +} diff --git a/crates/zeta/src/zeta.rs b/crates/zeta/src/zeta.rs index bf4b9706b51a9b2bb538807e8666360a6f4b98da..982c71b0df61076754d258b6cac2d6ce6e5b88aa 100644 --- a/crates/zeta/src/zeta.rs +++ b/crates/zeta/src/zeta.rs @@ -5,6 +5,7 @@ mod license_detection; mod onboarding_modal; mod onboarding_telemetry; mod rate_completion_modal; +mod training_data_uploader; use arrayvec::ArrayVec; pub(crate) use completion_diff_element::*; @@ -60,6 +61,8 @@ use workspace::Workspace; use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification}; use worktree::Worktree; +use crate::training_data_uploader::TrainingDataUploader; + const CURSOR_MARKER: &str = "<|user_cursor_is_here|>"; const START_OF_FILE_MARKER: &str = "<|start_of_file|>"; const EDITABLE_REGION_START_MARKER: &str = "<|editable_region_start|>"; @@ -247,6 +250,8 @@ pub struct Zeta { user_store: Entity, license_detection_watchers: HashMap>, projects: HashMap, + /// todo! document that this should be the only place Entity is stored. + training_data_uploader: Option>, } struct ZetaProject { @@ -377,6 +382,7 @@ impl Zeta { license_detection_watchers: HashMap::default(), user_store, projects: HashMap::default(), + training_data_uploader: None, } } @@ -529,6 +535,31 @@ impl Zeta { can_collect_data, )); + // todo! async + if matches!(can_collect_data, CanCollectData(true)) { + let training_data_uploader = match &self.training_data_uploader { + None => { + let training_data_uploader = cx.new(|cx| TrainingDataUploader::new(cx)); + self.training_data_uploader = Some(training_data_uploader.clone()); + &training_data_uploader + } + Some(training_data_uploader) => training_data_uploader, + }; + + training_data_uploader.update(cx, |training_data_uploader, cx| { + let project = project.read(cx); + let entry = project.entry_for_path(&project_path, cx)?; + if !worktree_entry_is_eligible_for_collection(&entry) { + return None; + } + + let git_store = project.git_store().read(cx); + let (repository, repo_path) = + git_store.repository_and_path_for_project_path(&project_path, cx)?; + let repo_path_string = repo_path.to_str()?.to_string(); + }); + } + cx.spawn(async move |this, cx| { let GatherContextOutput { body,