@@ -149,6 +149,22 @@ pub struct PredictEditsBody {
pub can_collect_data: bool,
#[serde(skip_serializing_if = "Option::is_none", default)]
pub diagnostic_groups: Option<Vec<(String, serde_json::Value)>>,
+ /// Info about the git repository state, only present when can_collect_data is true.
+ #[serde(skip_serializing_if = "Option::is_none", default)]
+ pub git_info: Option<PredictEditsGitInfo>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct PredictEditsGitInfo {
+ /// SHA of git HEAD commit at time of prediction.
+ #[serde(skip_serializing_if = "Option::is_none", default)]
+ pub head_sha: Option<String>,
+ /// URL of the remote called `origin`.
+ #[serde(skip_serializing_if = "Option::is_none", default)]
+ pub remote_origin_url: Option<String>,
+ /// URL of the remote called `upstream`.
+ #[serde(skip_serializing_if = "Option::is_none", default)]
+ pub remote_upstream_url: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -246,6 +246,8 @@ pub struct RepositorySnapshot {
pub head_commit: Option<CommitDetails>,
pub scan_id: u64,
pub merge: MergeDetails,
+ pub remote_origin_url: Option<String>,
+ pub remote_upstream_url: Option<String>,
}
type JobId = u64;
@@ -2673,6 +2675,8 @@ impl RepositorySnapshot {
head_commit: None,
scan_id: 0,
merge: Default::default(),
+ remote_origin_url: None,
+ remote_upstream_url: None,
}
}
@@ -4818,6 +4822,10 @@ async fn compute_snapshot(
None => None,
};
+ // Used by edit prediction data collection
+ let remote_origin_url = backend.remote_url("origin");
+ let remote_upstream_url = backend.remote_url("upstream");
+
let snapshot = RepositorySnapshot {
id,
statuses_by_path,
@@ -4826,6 +4834,8 @@ async fn compute_snapshot(
branch,
head_commit,
merge: merge_details,
+ remote_origin_url,
+ remote_upstream_url,
};
Ok((snapshot, events))
@@ -19,7 +19,7 @@ use arrayvec::ArrayVec;
use client::{Client, EditPredictionUsage, UserStore};
use cloud_llm_client::{
AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
- PredictEditsBody, PredictEditsResponse, ZED_VERSION_HEADER_NAME,
+ PredictEditsBody, PredictEditsGitInfo, PredictEditsResponse, ZED_VERSION_HEADER_NAME,
};
use collections::{HashMap, HashSet, VecDeque};
use futures::AsyncReadExt;
@@ -34,7 +34,7 @@ use language::{
};
use language_model::{LlmApiToken, RefreshLlmTokenListener};
use postage::watch;
-use project::Project;
+use project::{Project, ProjectPath};
use release_channel::AppVersion;
use settings::WorktreeId;
use std::str::FromStr;
@@ -400,6 +400,14 @@ impl Zeta {
let llm_token = self.llm_token.clone();
let app_version = AppVersion::global(cx);
+ let git_info = if let (true, Some(project), Some(file)) =
+ (can_collect_data, project, snapshot.file())
+ {
+ git_info_for_file(project, &ProjectPath::from_file(file.as_ref(), cx), cx)
+ } else {
+ None
+ };
+
let full_path: Arc<Path> = snapshot
.file()
.map(|f| Arc::from(f.full_path(cx).as_path()))
@@ -415,6 +423,7 @@ impl Zeta {
cursor_point,
make_events_prompt,
can_collect_data,
+ git_info,
cx,
);
@@ -1155,6 +1164,35 @@ fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b:
.sum()
}
+fn git_info_for_file(
+ project: &Entity<Project>,
+ project_path: &ProjectPath,
+ cx: &App,
+) -> Option<PredictEditsGitInfo> {
+ let git_store = project.read(cx).git_store().read(cx);
+ if let Some((repository, _repo_path)) =
+ git_store.repository_and_path_for_project_path(project_path, cx)
+ {
+ let repository = repository.read(cx);
+ let head_sha = repository
+ .head_commit
+ .as_ref()
+ .map(|head_commit| head_commit.sha.to_string());
+ let remote_origin_url = repository.remote_origin_url.clone();
+ let remote_upstream_url = repository.remote_upstream_url.clone();
+ if head_sha.is_none() && remote_origin_url.is_none() && remote_upstream_url.is_none() {
+ return None;
+ }
+ Some(PredictEditsGitInfo {
+ head_sha,
+ remote_origin_url,
+ remote_upstream_url,
+ })
+ } else {
+ None
+ }
+}
+
pub struct GatherContextOutput {
pub body: PredictEditsBody,
pub editable_range: Range<usize>,
@@ -1167,6 +1205,7 @@ pub fn gather_context(
cursor_point: language::Point,
make_events_prompt: impl FnOnce() -> String + Send + 'static,
can_collect_data: bool,
+ git_info: Option<PredictEditsGitInfo>,
cx: &App,
) -> Task<Result<GatherContextOutput>> {
let local_lsp_store =
@@ -1216,6 +1255,7 @@ pub fn gather_context(
outline: Some(input_outline),
can_collect_data,
diagnostic_groups,
+ git_info,
};
Ok(GatherContextOutput {
@@ -172,6 +172,7 @@ async fn get_context(
None => String::new(),
};
let can_collect_data = false;
+ let git_info = None;
cx.update(|cx| {
gather_context(
project.as_ref(),
@@ -180,6 +181,7 @@ async fn get_context(
clipped_cursor,
move || events,
can_collect_data,
+ git_info,
cx,
)
})?