Detailed changes
@@ -2,7 +2,9 @@ use std::str::FromStr;
use std::sync::Arc;
use anyhow::Context as _;
-use serde::{Deserialize, Serialize};
+use serde::{Deserialize, Deserializer, Serialize, Serializer};
+use serde_json::value::RawValue;
+use std::marker::PhantomData;
use strum::{Display, EnumIter, EnumString};
use uuid::Uuid;
@@ -152,30 +154,10 @@ pub struct PredictEditsBody {
pub git_info: Option<PredictEditsGitInfo>,
}
-/// Additional context only stored when can_collect_data is true for the corresponding edit
-/// predictions request.
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct PredictEditsAdditionalContext {
- /// Path to the file in the repository that contains the input excerpt.
- pub input_path: String,
- /// Cursor position within the file that contains the input excerpt.
- pub cursor_point: Point,
- /// Cursor offset in bytes within the file that contains the input excerpt.
- pub cursor_offset: usize,
- #[serde(flatten)]
- pub git_info: PredictEditsGitInfo,
- /// Diagnostic near the cursor position.
- #[serde(skip_serializing_if = "Vec::is_empty", default)]
- pub diagnostic_groups: Vec<(String, Box<serde_json::value::RawValue>)>,
- /// True if the diagnostics were truncated.
- pub diagnostic_groups_truncated: bool,
- /// Recently active files that may be within this repository.
- #[serde(skip_serializing_if = "Vec::is_empty", default)]
- pub recent_files: Vec<PredictEditsRecentFile>,
-}
-
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictEditsGitInfo {
+ /// full_path to the repo (worktree name + relative path to repo)
+ pub worktree_path: Option<String>,
/// SHA of git HEAD commit at time of prediction.
#[serde(skip_serializing_if = "Option::is_none", default)]
pub head_sha: Option<String>,
@@ -187,41 +169,64 @@ pub struct PredictEditsGitInfo {
pub remote_upstream_url: Option<String>,
}
-/// A zero-indexed point in a text buffer consisting of a row and column.
#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct Point {
- pub row: u32,
- pub column: u32,
+pub struct PredictEditsResponse {
+ pub request_id: Uuid,
+ pub output_excerpt: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct PredictEditsRecentFile {
- /// Path to a file within the repository.
- pub path: String,
- /// Most recent cursor position with the file.
- pub cursor_point: Point,
- /// Milliseconds between the editor for this file being active and the request time.
- pub active_to_now_ms: u32,
- /// Number of times the editor for this file was activated.
- pub activation_count: u32,
- /// Rough estimate of milliseconds the user was editing the file.
- pub cumulative_time_editing_ms: u32,
- /// Rough estimate of milliseconds the user was navigating within the file.
- pub cumulative_time_navigating_ms: u32,
- /// Whether the file is a multibuffer.
- #[serde(skip_serializing_if = "is_default", default)]
- pub is_multibuffer: bool,
+pub struct AcceptEditPredictionBody {
+ pub request_id: Uuid,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct PredictEditsResponse {
+pub struct PredictEditsTrainingData {
pub request_id: Uuid,
- pub output_excerpt: String,
+ /// When true, `request_id` is an ID that corresponds to an edit prediction.
+ pub has_prediction: bool,
+ /// State that `events` is based on. Initially this is `GitHead` and subsequent uploads will
+ /// then be based on the previous upload.
+ pub diff_base: PredictEditsDiffBase,
+ /// Fine-grained edit events atop `diff_base`.
+ #[serde(skip_serializing_if = "Vec::is_empty", default)]
+ pub events: Vec<SerializedJson<PredictEditsEvent>>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct AcceptEditPredictionBody {
- pub request_id: Uuid,
+#[serde(rename_all = "snake_case")]
+pub enum PredictEditsDiffBase {
+ GitHead { git_info: PredictEditsGitInfo },
+ PreviousUpload { request_id: Uuid },
+}
+
+/// A fine-grained edit event recorded for training-data collection.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct PredictEditsEvent {
+    /// Identifier of the entry this event applies to.
+    /// NOTE(review): presumably a project entry id — confirm against the producer.
+    pub entry_id: usize,
+    /// Path associated with the event, when known; omitted from JSON when `None`.
+    #[serde(skip_serializing_if = "Option::is_none", default)]
+    pub path: Option<String>,
+    /// Event time in milliseconds.
+    /// NOTE(review): the epoch/origin of this timestamp is not visible here — confirm.
+    pub timestamp_ms: u64,
+    /// The event payload.
+    pub data: PredictEditsEventData,
+}
+
+/// Payload for [`PredictEditsEvent`], tagged with snake_case variant names.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "snake_case")]
+pub enum PredictEditsEventData {
+    /// The cursor moved to `offset`, optionally carrying nearby diagnostics.
+    MoveCursor {
+        offset: usize,
+        /// Diagnostics as (name, raw JSON) pairs; omitted from JSON when empty.
+        #[serde(skip_serializing_if = "Vec::is_empty", default)]
+        diagnostic_groups: Vec<(String, Box<RawValue>)>,
+        /// True if the diagnostics list was truncated; omitted when false.
+        #[serde(skip_serializing_if = "is_default", default)]
+        diagnostic_groups_truncated: bool,
+    },
+    /// A file was created with the given content.
+    Create {
+        content: String,
+    },
+    /// A file was deleted.
+    Delete,
+    /// A file was edited; the change is recorded as a unified diff.
+    Edit {
+        unified_diff: String,
+    },
+    /// Marks that a diff was too large to record.
+    MarkDiffTooLarge
}
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
@@ -382,6 +387,58 @@ pub struct UsageData {
pub limit: UsageLimit,
}
+/// A value of type `T` retained as already-serialized raw JSON text.
+///
+/// Wraps a `Box<RawValue>` so the payload passes through (de)serialization
+/// without being re-encoded; `PhantomData<T>` records the logical type.
+#[derive(Debug, Clone)]
+pub struct SerializedJson<T> {
+    // The raw, already-serialized JSON text.
+    raw: Box<RawValue>,
+    // Zero-sized marker tying this value to the type it was serialized from.
+    _phantom: PhantomData<T>,
+}
+
+impl<T> SerializedJson<T>
+where
+    T: Serialize + for<'de> Deserialize<'de>,
+{
+    /// Serializes `value` into raw JSON, returning an error if serialization fails.
+    pub fn new(value: &T) -> Result<Self, serde_json::Error> {
+        Ok(SerializedJson {
+            raw: serde_json::value::to_raw_value(value)?,
+            _phantom: PhantomData,
+        })
+    }
+
+    /// Parses the stored raw JSON back into a `T`.
+    pub fn deserialize(&self) -> Result<T, serde_json::Error> {
+        serde_json::from_str(self.raw.get())
+    }
+
+    /// Borrows the underlying raw JSON.
+    pub fn as_raw(&self) -> &RawValue {
+        &self.raw
+    }
+
+    /// Consumes `self`, returning the underlying raw JSON.
+    pub fn into_raw(self) -> Box<RawValue> {
+        self.raw
+    }
+}
+
+/// Pass-through serialization: emits the stored raw JSON verbatim. No
+/// `T: Serialize` bound is needed because the payload is already encoded.
+impl<T> Serialize for SerializedJson<T> {
+    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: Serializer,
+    {
+        self.raw.serialize(serializer)
+    }
+}
+
+/// Captures the input as raw JSON without validating it against `T`;
+/// validation is deferred to [`SerializedJson::deserialize`].
+impl<'de, T> Deserialize<'de> for SerializedJson<T> {
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        let raw = Box::<RawValue>::deserialize(deserializer)?;
+        Ok(SerializedJson {
+            raw,
+            _phantom: PhantomData,
+        })
+    }
+}
+
fn is_default<T: Default + PartialEq>(value: &T) -> bool {
*value == T::default()
}
@@ -0,0 +1,226 @@
+use std::collections::hash_map;
+
+use cloud_llm_client::{PredictEditsEvent, PredictEditsGitInfo, SerializedJson};
+use collections::{HashMap, HashSet};
+use fs::MTime;
+use gpui::{AppContext as _, Context, Entity, EntityId, Task, WeakEntity};
+use language::{Buffer, BufferEvent};
+use project::{
+ Project, ProjectEntryId, ProjectPath,
+ buffer_store::{BufferStore, BufferStoreEvent},
+ git_store::{GitStore, GitStoreEvent, Repository, RepositoryId},
+ worktree_store::{WorktreeStore, WorktreeStoreEvent},
+};
+use uuid::Uuid;
+
+use crate::license_detection::LicenseDetectionWatcher;
+
+// todos:
+//
+// * Don't subscribe to all buffers
+//
+// * Currently MoveCursor event will only happen for edit prediction requests.
+
+/// Collects and periodically uploads edit-prediction training data across projects.
+pub struct TrainingDataUploader {
+    // Per-project collection state, keyed by the project's entity id.
+    projects: HashMap<EntityId, Entity<ZetaProject>>,
+    // Background task driving uploads; kept alive for the uploader's lifetime.
+    _upload_task: Task<()>,
+}
+
+/// Per-project collection state: event subscriptions plus dirty-tracking sets.
+struct ZetaProject {
+    project: WeakEntity<Project>,
+    // Tracked repositories within the project, mirrored from the git store.
+    repositories: HashMap<RepositoryId, Entity<ZetaRepository>>,
+    // Buffers edited since the last snapshot/upload.
+    buffers_changed: HashSet<WeakEntity<Buffer>>,
+    // Project entries changed on disk since the last snapshot/upload.
+    project_entries_changed: HashSet<ProjectEntryId>,
+}
+
+/// Per-repository event queue and snapshot state.
+struct ZetaRepository {
+    // Serialized events not yet uploaded.
+    unsent_events: Vec<SerializedJson<PredictEditsEvent>>,
+    // Event being accumulated before being serialized into `unsent_events`.
+    pending_event: Option<PredictEditsEvent>,
+    // Most recent snapshot, if one has been taken.
+    last_snapshot: Option<ZetaRepositorySnapshot>,
+    // Watches the repository's license.
+    // NOTE(review): gating on the detected license happens elsewhere — confirm.
+    license_watcher: LicenseDetectionWatcher,
+}
+
+/// Snapshot of repository contents associated with an upload request.
+struct ZetaRepositorySnapshot {
+    request_id: Uuid,
+    git_info: PredictEditsGitInfo,
+    // Open-buffer contents at snapshot time.
+    buffers: HashMap<ProjectEntryId, ZetaBufferSnapshot>,
+    // On-disk file contents at snapshot time.
+    files: HashMap<ProjectEntryId, ZetaFileSnapshot>,
+}
+
+/// Captured state of an open buffer.
+struct ZetaBufferSnapshot {
+    path: ProjectPath,
+    text: String,
+    buffer: WeakEntity<Buffer>,
+    // Buffer version vector at capture time.
+    version: clock::Global,
+}
+
+/// Captured state of a file read from disk.
+struct ZetaFileSnapshot {
+    path: ProjectPath,
+    text: String,
+    // File modification time at capture time.
+    mtime: MTime,
+}
+
+impl TrainingDataUploader {
+    pub fn new(cx: &mut Context<Self>) -> Self {
+        // Long-lived background loop that will periodically flush collected
+        // training data; dropped (and thus cancelled) with this entity.
+        // `spawn` takes an async closure — the original non-async closure
+        // would not produce a future.
+        let _upload_task = cx.spawn(async move |_this, _cx| {
+            loop {
+                todo!();
+            }
+        });
+        Self {
+            projects: HashMap::default(),
+            _upload_task,
+        }
+    }
+
+    /// Ensures a `ZetaProject` exists for `project` and (eventually) registers
+    /// `path` with it for data collection.
+    fn register(&mut self, project: &Entity<Project>, path: ProjectPath, cx: &mut Context<Self>) {
+        let project_entity_id = project.entity_id();
+
+        let zeta_project = match self.projects.entry(project_entity_id) {
+            hash_map::Entry::Vacant(entry) => {
+                let zeta_project = cx.new(|cx| ZetaProject::new(project, cx));
+                // Detach the observation: the returned Subscription cancels on
+                // drop, so discarding it (as before) meant this cleanup
+                // callback would never run.
+                cx.observe_release(project, move |this, _project, _cx| {
+                    this.projects.remove(&project_entity_id);
+                })
+                .detach();
+                entry.insert(zeta_project)
+            }
+            hash_map::Entry::Occupied(entry) => entry.into_mut(),
+        };
+
+        // todo!
+        // zeta_project.update(|zeta_project, cx| zeta_project.register(path, cx));
+    }
+}
+
+impl ZetaProject {
+    pub fn new(project: &Entity<Project>, cx: &mut Context<Self>) -> Self {
+        cx.subscribe(&project, Self::handle_project_event).detach();
+        cx.subscribe(
+            &project.read(cx).git_store().clone(),
+            Self::handle_git_store_event,
+        )
+        .detach();
+        cx.subscribe(
+            &project.read(cx).worktree_store(),
+            Self::handle_worktree_store_event,
+        )
+        .detach();
+
+        // Register buffers that already exist; buffers opened later are picked
+        // up via `handle_buffer_store_event`. Collect first so we don't hold a
+        // borrow of the buffer store while subscribing.
+        let buffer_store = project.read(cx).buffer_store().clone();
+        for buffer in buffer_store.read(cx).buffers().collect::<Vec<_>>() {
+            Self::register_buffer(&buffer, cx);
+        }
+        cx.subscribe(&buffer_store, Self::handle_buffer_store_event)
+            .detach();
+
+        Self {
+            project: project.downgrade(),
+            repositories: HashMap::default(),
+            buffers_changed: HashSet::default(),
+            project_entries_changed: HashSet::default(),
+        }
+    }
+
+    /// Keeps `repositories` in sync with the project's git store.
+    fn handle_git_store_event(
+        &mut self,
+        _git_store: Entity<GitStore>,
+        event: &GitStoreEvent,
+        cx: &mut Context<Self>,
+    ) {
+        use GitStoreEvent::*;
+        match event {
+            RepositoryRemoved(repository_id) => {
+                self.repositories.remove(repository_id);
+            }
+            RepositoryAdded(repository_id) => {
+                self.repositories
+                    .insert(*repository_id, cx.new(|cx| ZetaRepository::new(cx)));
+            }
+            RepositoryUpdated(_repository_id, _event, _is_active) => {}
+            ActiveRepositoryChanged { .. }
+            | IndexWriteError { .. }
+            | JobsUpdated
+            | ConflictsUpdated => {}
+        }
+    }
+
+    /// Records which project entries changed on disk.
+    fn handle_worktree_store_event(
+        &mut self,
+        _worktree_store: Entity<WorktreeStore>,
+        event: &WorktreeStoreEvent,
+        _cx: &mut Context<Self>,
+    ) {
+        use WorktreeStoreEvent::*;
+        match event {
+            WorktreeAdded(_worktree) => {}
+            WorktreeRemoved(_worktree_entity_id, _worktree_id) => {}
+            WorktreeUpdatedEntries(_worktree_id, updated_entries_set) => {
+                for (_path, entry_id, _path_change) in updated_entries_set.iter() {
+                    self.project_entries_changed.insert(*entry_id);
+                }
+            }
+            WorktreeUpdatedGitRepositories(_worktree_id, _updated_git_repositories) => {}
+            WorktreeDeletedEntry(_worktree_id, _project_entry_id) => {}
+            WorktreeReleased { .. } | WorktreeOrderChanged | WorktreeUpdateSent { .. } => {}
+        }
+    }
+
+    /// Routes buffer-store events; newly added buffers get an edit subscription.
+    fn handle_buffer_store_event(
+        &mut self,
+        _buffer_store: Entity<BufferStore>,
+        event: &BufferStoreEvent,
+        cx: &mut Context<Self>,
+    ) {
+        use BufferStoreEvent::*;
+        match event {
+            BufferAdded(buffer) => Self::register_buffer(buffer, cx),
+            BufferOpened { .. }
+            | BufferChangedFilePath { .. }
+            | BufferDropped { .. }
+            | SharedBufferClosed { .. } => {}
+        }
+    }
+
+    /// Subscribes to edit events on `buffer`.
+    fn register_buffer(buffer: &Entity<Buffer>, cx: &mut Context<Self>) {
+        // Detach: the Subscription returned by `subscribe` cancels the
+        // subscription when dropped, so discarding it (as before) meant
+        // `handle_buffer_event` would never be called.
+        cx.subscribe(buffer, Self::handle_buffer_event).detach();
+    }
+
+    /// Marks edited buffers dirty for the next snapshot/upload.
+    fn handle_buffer_event(
+        &mut self,
+        buffer: Entity<Buffer>,
+        event: &BufferEvent,
+        _cx: &mut Context<Self>,
+    ) {
+        if let BufferEvent::Edited = event {
+            self.buffers_changed.insert(buffer.downgrade());
+        }
+    }
+
+    fn handle_project_event(
+        &mut self,
+        _project: Entity<Project>,
+        event: &project::Event,
+        _cx: &mut Context<Self>,
+    ) {
+        match event {
+            project::Event::ActiveEntryChanged(_entry_id) => {
+                todo!()
+            }
+            _ => {}
+        }
+    }
+}
+
+impl ZetaRepository {
+    /// Creates an empty tracker — no events queued, no pending event, and no
+    /// snapshot yet — and starts watching the repository license immediately.
+    pub fn new(cx: &mut Context<Self>) -> Self {
+        Self {
+            unsent_events: Vec::new(),
+            pending_event: None,
+            last_snapshot: None,
+            license_watcher: LicenseDetectionWatcher::new(cx),
+        }
+    }
+}
@@ -5,6 +5,7 @@ mod license_detection;
mod onboarding_modal;
mod onboarding_telemetry;
mod rate_completion_modal;
+mod training_data_uploader;
use arrayvec::ArrayVec;
pub(crate) use completion_diff_element::*;
@@ -60,6 +61,8 @@ use workspace::Workspace;
use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
use worktree::Worktree;
+use crate::training_data_uploader::TrainingDataUploader;
+
const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
const START_OF_FILE_MARKER: &str = "<|start_of_file|>";
const EDITABLE_REGION_START_MARKER: &str = "<|editable_region_start|>";
@@ -247,6 +250,8 @@ pub struct Zeta {
user_store: Entity<UserStore>,
license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
projects: HashMap<EntityId, ZetaProject>,
+ /// todo! document that this should be the only place Entity<TrainingDataUploader> is stored.
+ training_data_uploader: Option<Entity<TrainingDataUploader>>,
}
struct ZetaProject {
@@ -377,6 +382,7 @@ impl Zeta {
license_detection_watchers: HashMap::default(),
user_store,
projects: HashMap::default(),
+ training_data_uploader: None,
}
}
@@ -529,6 +535,31 @@ impl Zeta {
can_collect_data,
));
+ // todo! async
+        if matches!(can_collect_data, CanCollectData(true)) {
+            // Lazily create the shared uploader. `get_or_insert_with` avoids
+            // the original match's borrow error (it returned a reference to a
+            // value local to the `None` arm, while also assigning to the field
+            // the match was still borrowing).
+            let training_data_uploader = self
+                .training_data_uploader
+                .get_or_insert_with(|| cx.new(|cx| TrainingDataUploader::new(cx)))
+                .clone();
+
+            training_data_uploader.update(cx, |_training_data_uploader, cx| {
+                let project = project.read(cx);
+                let entry = project.entry_for_path(&project_path, cx)?;
+                if !worktree_entry_is_eligible_for_collection(&entry) {
+                    return None;
+                }
+
+                let git_store = project.git_store().read(cx);
+                let (repository, repo_path) =
+                    git_store.repository_and_path_for_project_path(&project_path, cx)?;
+                let repo_path_string = repo_path.to_str()?.to_string();
+                // todo! register `repository` / `repo_path_string` with the
+                // uploader. The closure uses `?` on Options, so it must end
+                // with `Some(())` (the original ended in a `let` statement,
+                // which type-checks as `()`, not `Option<_>`).
+                Some(())
+            });
+        }
+
+
cx.spawn(async move |this, cx| {
let GatherContextOutput {
body,