diff --git a/Cargo.lock b/Cargo.lock index 6ff72c08b4482a7b5ae5e4abeed90157f6bcd124..430a52d0f3d41523c737dbcd6ecf4e0e5a9424fd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -21533,6 +21533,7 @@ dependencies = [ "serde", "serde_json", "settings", + "telemetry", "text", "ui", "ui_input", diff --git a/assets/keymaps/default-linux.json b/assets/keymaps/default-linux.json index fff7469199f88b88bb02fcf2d595d5ee76628315..51cf0b03a56a03aaaa9cc8ad6550d8debfda0df7 100644 --- a/assets/keymaps/default-linux.json +++ b/assets/keymaps/default-linux.json @@ -1290,5 +1290,13 @@ "home": "settings_editor::FocusFirstNavEntry", "end": "settings_editor::FocusLastNavEntry" } + }, + { + "context": "Zeta2Feedback > Editor", + "bindings": { + "enter": "editor::Newline", + "ctrl-enter up": "dev::Zeta2RatePredictionPositive", + "ctrl-enter down": "dev::Zeta2RatePredictionNegative" + } } ] diff --git a/assets/keymaps/default-macos.json b/assets/keymaps/default-macos.json index 0b4119e95e4bf33d1f19a538fa231cc13ff79419..97846a8edf63cae577eb17d49ee835b43295be35 100644 --- a/assets/keymaps/default-macos.json +++ b/assets/keymaps/default-macos.json @@ -1396,5 +1396,13 @@ "home": "settings_editor::FocusFirstNavEntry", "end": "settings_editor::FocusLastNavEntry" } + }, + { + "context": "Zeta2Feedback > Editor", + "bindings": { + "enter": "editor::Newline", + "cmd-enter up": "dev::Zeta2RatePredictionPositive", + "cmd-enter down": "dev::Zeta2RatePredictionNegative" + } } ] diff --git a/assets/keymaps/default-windows.json b/assets/keymaps/default-windows.json index 39c1b672a4105e9565bbdaded7229402831c702d..02bd2207c2805ef5f3eef1b06378f197595ad4a4 100644 --- a/assets/keymaps/default-windows.json +++ b/assets/keymaps/default-windows.json @@ -1319,5 +1319,13 @@ "home": "settings_editor::FocusFirstNavEntry", "end": "settings_editor::FocusLastNavEntry" } + }, + { + "context": "Zeta2Feedback > Editor", + "bindings": { + "enter": "editor::Newline", + "ctrl-enter up": "dev::Zeta2RatePredictionPositive", + "ctrl-enter down": "dev::Zeta2RatePredictionNegative" + } } ] diff --git a/crates/agent/src/agent.rs b/crates/agent/src/agent.rs index 32dec9f723a6776fd14def29be3be4eb21afa72d..65eb25e6ac9d005fc2e18901a56287e2938e5bb8 100644 --- a/crates/agent/src/agent.rs +++ b/crates/agent/src/agent.rs @@ -48,24 +48,10 @@ use util::rel_path::RelPath; #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct ProjectSnapshot { - pub worktree_snapshots: Vec, + pub worktree_snapshots: Vec, pub timestamp: DateTime, } -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -pub struct WorktreeSnapshot { - pub worktree_path: String, - pub git_state: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -pub struct GitState { - pub remote_url: Option, - pub head_sha: Option, - pub current_branch: Option, - pub diff: Option, -} - const RULES_FILE_NAMES: [&str; 9] = [ ".rules", ".cursorrules", diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index ec9d50ff2f62c5602dd91e5da47593764ea01c85..c89ad1df241c3b9c6e07b9a5433dd964244ba2cb 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -1,9 +1,8 @@ use crate::{ ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DbLanguageModel, DbThread, - DeletePathTool, DiagnosticsTool, EditFileTool, FetchTool, FindPathTool, GitState, GrepTool, + DeletePathTool, DiagnosticsTool, EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool, MovePathTool, NowTool, OpenTool, ProjectSnapshot, ReadFileTool, SystemPromptTemplate, Template, Templates, TerminalTool, ThinkingTool, WebSearchTool, - WorktreeSnapshot, }; use acp_thread::{MentionUri, UserMessageId}; use action_log::ActionLog; @@ -26,7 +25,6 @@ use futures::{ future::Shared, stream::FuturesUnordered, }; -use git::repository::DiffType; use gpui::{ App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity, }; @@ -37,10 +35,7 @@ use language_model::{ LanguageModelToolResultContent, LanguageModelToolSchemaFormat, LanguageModelToolUse, LanguageModelToolUseId, Role, SelectedModel, StopReason, TokenUsage, ZED_CLOUD_PROVIDER_ID, }; -use project::{ - Project, - git_store::{GitStore, RepositoryState}, -}; +use project::Project; use prompt_store::ProjectContext; use schemars::{JsonSchema, Schema}; use serde::{Deserialize, Serialize}; @@ -880,101 +875,17 @@ impl Thread { project: Entity, cx: &mut Context, ) -> Task> { - let git_store = project.read(cx).git_store().clone(); - let worktree_snapshots: Vec<_> = project - .read(cx) - .visible_worktrees(cx) - .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx)) - .collect(); - + let task = project::telemetry_snapshot::TelemetrySnapshot::new(&project, cx); cx.spawn(async move |_, _| { - let worktree_snapshots = futures::future::join_all(worktree_snapshots).await; + let snapshot = task.await; Arc::new(ProjectSnapshot { - worktree_snapshots, + worktree_snapshots: snapshot.worktree_snapshots, timestamp: Utc::now(), }) }) } - fn worktree_snapshot( - worktree: Entity, - git_store: Entity, - cx: &App, - ) -> Task { - cx.spawn(async move |cx| { - // Get worktree path and snapshot - let worktree_info = cx.update(|app_cx| { - let worktree = worktree.read(app_cx); - let path = worktree.abs_path().to_string_lossy().into_owned(); - let snapshot = worktree.snapshot(); - (path, snapshot) - }); - - let Ok((worktree_path, _snapshot)) = worktree_info else { - return WorktreeSnapshot { - worktree_path: String::new(), - git_state: None, - }; - }; - - let git_state = git_store - .update(cx, |git_store, cx| { - git_store - .repositories() - .values() - .find(|repo| { - repo.read(cx) - .abs_path_to_repo_path(&worktree.read(cx).abs_path()) - .is_some() - }) - .cloned() - }) - .ok() - .flatten() - .map(|repo| { - repo.update(cx, |repo, _| { - let current_branch = - repo.branch.as_ref().map(|branch| branch.name().to_owned()); - repo.send_job(None, |state, _| async move { - let RepositoryState::Local { backend, .. } = state else { - return GitState { - remote_url: None, - head_sha: None, - current_branch, - diff: None, - }; - }; - - let remote_url = backend.remote_url("origin"); - let head_sha = backend.head_sha().await; - let diff = backend.diff(DiffType::HeadToWorktree).await.ok(); - - GitState { - remote_url, - head_sha, - current_branch, - diff, - } - }) - }) - }); - - let git_state = match git_state { - Some(git_state) => match git_state.ok() { - Some(git_state) => git_state.await.ok(), - None => None, - }, - None => None, - }; - - WorktreeSnapshot { - worktree_path, - git_state, - } - }) - } - pub fn project_context(&self) -> &Entity { &self.project_context } diff --git a/crates/project/src/project.rs b/crates/project/src/project.rs index 678607b53219992317e1762ff15b57500eb33d79..c0d853966694a68fad9d69ad160071c3d5fca9bf 100644 --- a/crates/project/src/project.rs +++ b/crates/project/src/project.rs @@ -16,6 +16,7 @@ pub mod project_settings; pub mod search; mod task_inventory; pub mod task_store; +pub mod telemetry_snapshot; pub mod terminals; pub mod toolchain_store; pub mod worktree_store; diff --git a/crates/project/src/telemetry_snapshot.rs b/crates/project/src/telemetry_snapshot.rs new file mode 100644 index 0000000000000000000000000000000000000000..79fe2bd8b3f21df03b4cf7a59f73df93b22f3a6c --- /dev/null +++ b/crates/project/src/telemetry_snapshot.rs @@ -0,0 +1,125 @@ +use git::repository::DiffType; +use gpui::{App, Entity, Task}; +use serde::{Deserialize, Serialize}; +use worktree::Worktree; + +use crate::{ + Project, + git_store::{GitStore, RepositoryState}, +}; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct TelemetrySnapshot { + pub worktree_snapshots: Vec, +} + +impl TelemetrySnapshot { + pub fn new(project: &Entity, cx: &mut App) -> Task { + let git_store = project.read(cx).git_store().clone(); + let worktree_snapshots: Vec<_> = project + .read(cx) + .visible_worktrees(cx) + .map(|worktree| TelemetryWorktreeSnapshot::new(worktree, git_store.clone(), cx)) + .collect(); + + cx.spawn(async move |_| { + let worktree_snapshots = futures::future::join_all(worktree_snapshots).await; + + Self { worktree_snapshots } + }) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct TelemetryWorktreeSnapshot { + pub worktree_path: String, + pub git_state: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct GitState { + pub remote_url: Option, + pub head_sha: Option, + pub current_branch: Option, + pub diff: Option, +} + +impl TelemetryWorktreeSnapshot { + fn new( + worktree: Entity, + git_store: Entity, + cx: &App, + ) -> Task { + cx.spawn(async move |cx| { + // Get worktree path and snapshot + let worktree_info = cx.update(|app_cx| { + let worktree = worktree.read(app_cx); + let path = worktree.abs_path().to_string_lossy().into_owned(); + let snapshot = worktree.snapshot(); + (path, snapshot) + }); + + let Ok((worktree_path, _snapshot)) = worktree_info else { + return TelemetryWorktreeSnapshot { + worktree_path: String::new(), + git_state: None, + }; + }; + + let git_state = git_store + .update(cx, |git_store, cx| { + git_store + .repositories() + .values() + .find(|repo| { + repo.read(cx) + .abs_path_to_repo_path(&worktree.read(cx).abs_path()) + .is_some() + }) + .cloned() + }) + .ok() + .flatten() + .map(|repo| { + repo.update(cx, |repo, _| { + let current_branch = + repo.branch.as_ref().map(|branch| branch.name().to_owned()); + repo.send_job(None, |state, _| async move { + let RepositoryState::Local { backend, .. } = state else { + return GitState { + remote_url: None, + head_sha: None, + current_branch, + diff: None, + }; + }; + + let remote_url = backend.remote_url("origin"); + let head_sha = backend.head_sha().await; + let diff = backend.diff(DiffType::HeadToWorktree).await.ok(); + + GitState { + remote_url, + head_sha, + current_branch, + diff, + } + }) + }) + }); + + let git_state = match git_state { + Some(git_state) => match git_state.ok() { + Some(git_state) => git_state.await.ok(), + None => None, + }, + None => None, + }; + + TelemetryWorktreeSnapshot { + worktree_path, + git_state, + } + }) + } +} diff --git a/crates/zeta2/src/zeta2.rs b/crates/zeta2/src/zeta2.rs index 7c945d6ebbe9c994977adbcc72c6d8fc175930d4..42eb565502e6568491e820dfb5c0921e4d56039b 100644 --- a/crates/zeta2/src/zeta2.rs +++ b/crates/zeta2/src/zeta2.rs @@ -11,7 +11,7 @@ use edit_prediction_context::{ DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions, EditPredictionExcerptOptions, EditPredictionScoreOptions, SyntaxIndex, SyntaxIndexState, }; -use feature_flags::FeatureFlag; +use feature_flags::{FeatureFlag, FeatureFlagAppExt as _}; use futures::AsyncReadExt as _; use futures::channel::{mpsc, oneshot}; use gpui::http_client::{AsyncBody, Method}; @@ -32,7 +32,6 @@ use std::sync::Arc; use std::time::{Duration, Instant}; use thiserror::Error; use util::rel_path::RelPathBuf; -use util::some_or_debug_panic; use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification}; mod prediction; @@ -103,12 +102,12 @@ pub struct ZetaOptions { } pub struct PredictionDebugInfo { - pub context: EditPredictionContext, + pub request: predict_edits_v3::PredictEditsRequest, pub retrieval_time: TimeDelta, pub buffer: WeakEntity, pub position: language::Anchor, pub local_prompt: Result, - pub response_rx: oneshot::Receiver>, + pub response_rx: oneshot::Receiver>, } pub type RequestDebugInfo = predict_edits_v3::DebugInfo; @@ -571,6 +570,9 @@ impl Zeta { if path.pop() { Some(path) } else { None } }); + // TODO data collection + let can_collect_data = cx.is_staff(); + let request_task = cx.background_spawn({ let snapshot = snapshot.clone(); let buffer = buffer.clone(); @@ -606,25 +608,22 @@ impl Zeta { options.max_diagnostic_bytes, ); - let debug_context = debug_tx.map(|tx| (tx, context.clone())); - let request = make_cloud_request( excerpt_path, context, events, - // TODO data collection - false, + can_collect_data, diagnostic_groups, diagnostic_groups_truncated, None, - debug_context.is_some(), + debug_tx.is_some(), &worktree_snapshots, index_state.as_deref(), Some(options.max_prompt_bytes), options.prompt_format, ); - let debug_response_tx = if let Some((debug_tx, context)) = debug_context { + let debug_response_tx = if let Some(debug_tx) = &debug_tx { let (response_tx, response_rx) = oneshot::channel(); let local_prompt = PlannedPrompt::populate(&request) @@ -633,7 +632,7 @@ impl Zeta { debug_tx .unbounded_send(PredictionDebugInfo { - context, + request: request.clone(), retrieval_time, buffer: buffer.downgrade(), local_prompt, @@ -660,12 +659,12 @@ impl Zeta { if let Some(debug_response_tx) = debug_response_tx { debug_response_tx - .send(response.as_ref().map_err(|err| err.to_string()).and_then( - |response| match some_or_debug_panic(response.0.debug_info.clone()) { - Some(debug_info) => Ok(debug_info), - None => Err("Missing debug info".to_string()), - }, - )) + .send( + response + .as_ref() + .map_err(|err| err.to_string()) + .map(|response| response.0.clone()), + ) .ok(); } diff --git a/crates/zeta2_tools/Cargo.toml b/crates/zeta2_tools/Cargo.toml index b56b806e783b7e6acc946a9dadb00703e4a7f2c1..edd1b1eb242c6c02001bec53120425f9a05e5d1d 100644 --- a/crates/zeta2_tools/Cargo.toml +++ b/crates/zeta2_tools/Cargo.toml @@ -27,6 +27,7 @@ multi_buffer.workspace = true ordered-float.workspace = true project.workspace = true serde.workspace = true +telemetry.workspace = true text.workspace = true ui.workspace = true ui_input.workspace = true diff --git a/crates/zeta2_tools/src/zeta2_tools.rs b/crates/zeta2_tools/src/zeta2_tools.rs index 0ac4fb2162ca632618df0c2b0d256b2fd7c30742..7b806e2b9a4ba7c7dbda41bb0f5750e5d2b9ff97 100644 --- a/crates/zeta2_tools/src/zeta2_tools.rs +++ b/crates/zeta2_tools/src/zeta2_tools.rs @@ -1,37 +1,38 @@ -use std::{ - cmp::Reverse, collections::hash_map::Entry, path::PathBuf, str::FromStr, sync::Arc, - time::Duration, -}; +use std::{cmp::Reverse, path::PathBuf, str::FromStr, sync::Arc, time::Duration}; use chrono::TimeDelta; use client::{Client, UserStore}; -use cloud_llm_client::predict_edits_v3::{DeclarationScoreComponents, PromptFormat}; +use cloud_llm_client::predict_edits_v3::{ + self, DeclarationScoreComponents, PredictEditsRequest, PredictEditsResponse, PromptFormat, +}; use collections::HashMap; use editor::{Editor, EditorEvent, EditorMode, ExcerptRange, MultiBuffer}; use feature_flags::FeatureFlagAppExt as _; -use futures::{StreamExt as _, channel::oneshot}; +use futures::{FutureExt, StreamExt as _, channel::oneshot, future::Shared}; use gpui::{ - CursorStyle, Entity, EventEmitter, FocusHandle, Focusable, Subscription, Task, WeakEntity, - actions, prelude::*, + CursorStyle, Empty, Entity, EventEmitter, FocusHandle, Focusable, Subscription, Task, + WeakEntity, actions, prelude::*, }; use language::{Buffer, DiskState}; use ordered_float::OrderedFloat; -use project::{Project, WorktreeId}; -use ui::{ContextMenu, ContextMenuEntry, DropdownMenu, prelude::*}; +use project::{Project, WorktreeId, telemetry_snapshot::TelemetrySnapshot}; +use ui::{ButtonLike, ContextMenu, ContextMenuEntry, DropdownMenu, KeyBinding, prelude::*}; use ui_input::SingleLineInput; use util::{ResultExt, paths::PathStyle, rel_path::RelPath}; use workspace::{Item, SplitDirection, Workspace}; use zeta2::{PredictionDebugInfo, Zeta, Zeta2FeatureFlag, ZetaOptions}; -use edit_prediction_context::{ - DeclarationStyle, EditPredictionContextOptions, EditPredictionExcerptOptions, -}; +use edit_prediction_context::{EditPredictionContextOptions, EditPredictionExcerptOptions}; actions!( dev, [ /// Opens the language server protocol logs viewer. - OpenZeta2Inspector + OpenZeta2Inspector, + /// Rate prediction as positive. + Zeta2RatePredictionPositive, + /// Rate prediction as negative. + Zeta2RatePredictionNegative, ] ); @@ -89,16 +90,24 @@ struct LastPrediction { buffer: WeakEntity, position: language::Anchor, state: LastPredictionState, + request: PredictEditsRequest, + project_snapshot: Shared>>, _task: Option>, } +#[derive(Clone, Copy, PartialEq)] +enum Feedback { + Positive, + Negative, +} + enum LastPredictionState { Requested, Success { - inference_time: TimeDelta, - parsing_time: TimeDelta, - prompt_planning_time: TimeDelta, model_response_editor: Entity, + feedback_editor: Entity, + feedback: Option, + response: predict_edits_v3::PredictEditsResponse, }, Failed { message: String, @@ -129,7 +138,7 @@ impl Zeta2Inspector { focus_handle: cx.focus_handle(), project: project.clone(), last_prediction: None, - active_view: ActiveView::Context, + active_view: ActiveView::Inference, max_excerpt_bytes_input: Self::number_input("Max Excerpt Bytes", window, cx), min_excerpt_bytes_input: Self::number_input("Min Excerpt Bytes", window, cx), cursor_context_ratio_input: Self::number_input("Cursor Context Ratio", window, cx), @@ -300,17 +309,23 @@ impl Zeta2Inspector { let language_registry = self.project.read(cx).languages().clone(); async move |this, cx| { let mut languages = HashMap::default(); - for lang_id in prediction - .context - .declarations + for ext in prediction + .request + .referenced_declarations .iter() - .map(|snippet| snippet.declaration.identifier().language_id) - .chain(prediction.context.excerpt_text.language_id) + .filter_map(|snippet| snippet.path.extension()) + .chain(prediction.request.excerpt_path.extension()) { - if let Entry::Vacant(entry) = languages.entry(lang_id) { + if !languages.contains_key(ext) { // Most snippets are gonna be the same language, // so we think it's fine to do this sequentially for now - entry.insert(language_registry.language_for_id(lang_id).await.ok()); + languages.insert( + ext.to_owned(), + language_registry + .language_for_name_or_extension(&ext.to_string_lossy()) + .await + .ok(), + ); } } @@ -333,13 +348,12 @@ impl Zeta2Inspector { let excerpt_buffer = cx.new(|cx| { let mut buffer = - Buffer::local(prediction.context.excerpt_text.body, cx); + Buffer::local(prediction.request.excerpt.clone(), cx); if let Some(language) = prediction - .context - .excerpt_text - .language_id - .as_ref() - .and_then(|id| languages.get(id)) + .request + .excerpt_path + .extension() + .and_then(|ext| languages.get(ext)) { buffer.set_language(language.clone(), cx); } @@ -353,25 +367,18 @@ impl Zeta2Inspector { cx, ); - let mut declarations = prediction.context.declarations.clone(); + let mut declarations = + prediction.request.referenced_declarations.clone(); declarations.sort_unstable_by_key(|declaration| { - Reverse(OrderedFloat( - declaration.score(DeclarationStyle::Declaration), - )) + Reverse(OrderedFloat(declaration.declaration_score)) }); for snippet in &declarations { - let path = this - .project - .read(cx) - .path_for_entry(snippet.declaration.project_entry_id(), cx); - let snippet_file = Arc::new(ExcerptMetadataFile { title: RelPath::unix(&format!( "{} (Score: {})", - path.map(|p| p.path.display(path_style).to_string()) - .unwrap_or_else(|| "".to_string()), - snippet.score(DeclarationStyle::Declaration) + snippet.path.display(), + snippet.declaration_score )) .unwrap() .into(), @@ -380,11 +387,10 @@ impl Zeta2Inspector { }); let excerpt_buffer = cx.new(|cx| { - let mut buffer = - Buffer::local(snippet.declaration.item_text().0, cx); + let mut buffer = Buffer::local(snippet.text.clone(), cx); buffer.file_updated(snippet_file, cx); - if let Some(language) = - languages.get(&snippet.declaration.identifier().language_id) + if let Some(ext) = snippet.path.extension() + && let Some(language) = languages.get(ext) { buffer.set_language(language.clone(), cx); } @@ -399,7 +405,7 @@ impl Zeta2Inspector { let excerpt_id = excerpt_ids.first().unwrap(); excerpt_score_components - .insert(*excerpt_id, snippet.components.clone()); + .insert(*excerpt_id, snippet.score_components.clone()); } multibuffer @@ -431,25 +437,91 @@ impl Zeta2Inspector { if let Some(prediction) = this.last_prediction.as_mut() { prediction.state = match response { Ok(Ok(response)) => { - prediction.prompt_editor.update( - cx, - |prompt_editor, cx| { - prompt_editor.set_text( - response.prompt, - window, + if let Some(debug_info) = &response.debug_info { + prediction.prompt_editor.update( + cx, + |prompt_editor, cx| { + prompt_editor.set_text( + debug_info.prompt.as_str(), + window, + cx, + ); + }, + ); + } + + let feedback_editor = cx.new(|cx| { + let buffer = cx.new(|cx| { + let mut buffer = Buffer::local("", cx); + buffer.set_language( + markdown_language.clone(), cx, ); + buffer + }); + let buffer = + cx.new(|cx| MultiBuffer::singleton(buffer, cx)); + let mut editor = Editor::new( + EditorMode::AutoHeight { + min_lines: 3, + max_lines: None, + }, + buffer, + None, + window, + cx, + ); + editor.set_placeholder_text( + "Write feedback here", + window, + cx, + ); + editor.set_show_line_numbers(false, cx); + editor.set_show_gutter(false, cx); + editor.set_show_scrollbars(false, cx); + editor + }); + + cx.subscribe_in( + &feedback_editor, + window, + |this, editor, ev, window, cx| match ev { + EditorEvent::BufferEdited => { + if let Some(last_prediction) = + this.last_prediction.as_mut() + && let LastPredictionState::Success { + feedback: feedback_state, + .. + } = &mut last_prediction.state + { + if feedback_state.take().is_some() { + editor.update(cx, |editor, cx| { + editor.set_placeholder_text( + "Write feedback here", + window, + cx, + ); + }); + cx.notify(); + } + } + } + _ => {} }, - ); + ) + .detach(); LastPredictionState::Success { - prompt_planning_time: response.prompt_planning_time, - inference_time: response.inference_time, - parsing_time: response.parsing_time, model_response_editor: cx.new(|cx| { let buffer = cx.new(|cx| { let mut buffer = Buffer::local( - response.model_response, + response + .debug_info + .as_ref() + .map(|p| p.model_response.as_str()) + .unwrap_or( + "(Debug info not available)", + ), cx, ); buffer.set_language(markdown_language, cx); @@ -471,6 +543,9 @@ impl Zeta2Inspector { editor.set_show_scrollbars(false, cx); editor }), + feedback_editor, + feedback: None, + response, } } Ok(Err(err)) => { @@ -486,6 +561,8 @@ impl Zeta2Inspector { } }); + let project_snapshot_task = TelemetrySnapshot::new(&this.project, cx); + this.last_prediction = Some(LastPrediction { context_editor, prompt_editor: cx.new(|cx| { @@ -508,6 +585,11 @@ impl Zeta2Inspector { buffer, position, state: LastPredictionState::Requested, + project_snapshot: cx + .foreground_executor() + .spawn(async move { Arc::new(project_snapshot_task.await) }) + .shared(), + request: prediction.request, _task: Some(task), }); cx.notify(); @@ -517,6 +599,103 @@ impl Zeta2Inspector { }); } + fn handle_rate_positive( + &mut self, + _action: &Zeta2RatePredictionPositive, + window: &mut Window, + cx: &mut Context, + ) { + self.handle_rate(Feedback::Positive, window, cx); + } + + fn handle_rate_negative( + &mut self, + _action: &Zeta2RatePredictionNegative, + window: &mut Window, + cx: &mut Context, + ) { + self.handle_rate(Feedback::Negative, window, cx); + } + + fn handle_rate(&mut self, kind: Feedback, window: &mut Window, cx: &mut Context) { + let Some(last_prediction) = self.last_prediction.as_mut() else { + return; + }; + if !last_prediction.request.can_collect_data { + return; + } + + let project_snapshot_task = last_prediction.project_snapshot.clone(); + + cx.spawn_in(window, async move |this, cx| { + let project_snapshot = project_snapshot_task.await; + this.update_in(cx, |this, window, cx| { + let Some(last_prediction) = this.last_prediction.as_mut() else { + return; + }; + + let LastPredictionState::Success { + feedback: feedback_state, + feedback_editor, + model_response_editor, + response, + .. + } = &mut last_prediction.state + else { + return; + }; + + *feedback_state = Some(kind); + let text = feedback_editor.update(cx, |feedback_editor, cx| { + feedback_editor.set_placeholder_text( + "Submitted. Edit or submit again to change.", + window, + cx, + ); + feedback_editor.text(cx) + }); + cx.notify(); + + cx.defer_in(window, { + let model_response_editor = model_response_editor.downgrade(); + move |_, window, cx| { + if let Some(model_response_editor) = model_response_editor.upgrade() { + model_response_editor.focus_handle(cx).focus(window); + } + } + }); + + let kind = match kind { + Feedback::Positive => "positive", + Feedback::Negative => "negative", + }; + + telemetry::event!( + "Zeta2 Prediction Rated", + id = response.request_id, + kind = kind, + text = text, + request = last_prediction.request, + response = response, + project_snapshot = project_snapshot, + ); + }) + .log_err(); + }) + .detach(); + } + + fn focus_feedback(&mut self, window: &mut Window, cx: &mut Context) { + if let Some(last_prediction) = self.last_prediction.as_mut() { + if let LastPredictionState::Success { + feedback_editor, .. + } = &mut last_prediction.state + { + feedback_editor.focus_handle(cx).focus(window); + } + }; + } + fn render_options(&self, window: &mut Window, cx: &mut Context) -> Div { v_flex() .gap_2() @@ -618,8 +797,9 @@ impl Zeta2Inspector { ), ui::ToggleButtonSimple::new( "Inference", - cx.listener(|this, _, _, cx| { + cx.listener(|this, _, window, cx| { this.active_view = ActiveView::Inference; + this.focus_feedback(window, cx); cx.notify(); }), ), @@ -640,21 +820,24 @@ impl Zeta2Inspector { return None; }; - let (prompt_planning_time, inference_time, parsing_time) = match &prediction.state { - LastPredictionState::Success { - inference_time, - parsing_time, - prompt_planning_time, + let (prompt_planning_time, inference_time, parsing_time) = + if let LastPredictionState::Success { + response: + PredictEditsResponse { + debug_info: Some(debug_info), + .. + }, .. - } => ( - Some(*prompt_planning_time), - Some(*inference_time), - Some(*parsing_time), - ), - LastPredictionState::Requested | LastPredictionState::Failed { .. } => { + } = &prediction.state + { + ( + Some(debug_info.prompt_planning_time), + Some(debug_info.inference_time), + Some(debug_info.parsing_time), + ) + } else { (None, None, None) - } - }; + }; Some( v_flex() @@ -690,14 +873,16 @@ impl Zeta2Inspector { }) } - fn render_content(&self, cx: &mut Context) -> AnyElement { + fn render_content(&self, window: &mut Window, cx: &mut Context) -> AnyElement { if !cx.has_flag::() { return Self::render_message("`zeta2` feature flag is not enabled"); } match self.last_prediction.as_ref() { None => Self::render_message("No prediction"), - Some(prediction) => self.render_last_prediction(prediction, cx).into_any(), + Some(prediction) => self + .render_last_prediction(prediction, window, cx) + .into_any(), } } @@ -710,7 +895,12 @@ impl Zeta2Inspector { .into_any() } - fn render_last_prediction(&self, prediction: &LastPrediction, cx: &mut Context) -> Div { + fn render_last_prediction( + &self, + prediction: &LastPrediction, + window: &mut Window, + cx: &mut Context, + ) -> Div { match &self.active_view { ActiveView::Context => div().size_full().child(prediction.context_editor.clone()), ActiveView::Inference => h_flex() @@ -748,24 +938,107 @@ impl Zeta2Inspector { .flex_1() .gap_2() .h_full() - .p_4() - .child(ui::Headline::new("Model Response").size(ui::HeadlineSize::XSmall)) - .child(match &prediction.state { - LastPredictionState::Success { - model_response_editor, - .. - } => model_response_editor.clone().into_any_element(), - LastPredictionState::Requested => v_flex() - .p_4() + .child( + v_flex() + .flex_1() .gap_2() - .child(Label::new("Loading...").buffer_font(cx)) - .into_any(), - LastPredictionState::Failed { message } => v_flex() .p_4() - .gap_2() - .child(Label::new(message.clone()).buffer_font(cx)) - .into_any(), - }), + .child( + ui::Headline::new("Model Response") + .size(ui::HeadlineSize::XSmall), + ) + .child(match &prediction.state { + LastPredictionState::Success { + model_response_editor, + .. + } => model_response_editor.clone().into_any_element(), + LastPredictionState::Requested => v_flex() + .gap_2() + .child(Label::new("Loading...").buffer_font(cx)) + .into_any_element(), + LastPredictionState::Failed { message } => v_flex() + .gap_2() + .max_w_96() + .child(Label::new(message.clone()).buffer_font(cx)) + .into_any_element(), + }), + ) + .child(ui::divider()) + .child( + if prediction.request.can_collect_data + && let LastPredictionState::Success { + feedback_editor, + feedback: feedback_state, + .. + } = &prediction.state + { + v_flex() + .key_context("Zeta2Feedback") + .on_action(cx.listener(Self::handle_rate_positive)) + .on_action(cx.listener(Self::handle_rate_negative)) + .gap_2() + .p_2() + .child(feedback_editor.clone()) + .child( + h_flex() + .justify_end() + .w_full() + .child( + ButtonLike::new("rate-positive") + .when( + *feedback_state == Some(Feedback::Positive), + |this| this.style(ButtonStyle::Filled), + ) + .children( + KeyBinding::for_action( + &Zeta2RatePredictionPositive, + window, + cx, + ) + .map(|k| k.size(TextSize::Small.rems(cx))), + ) + .child(ui::Icon::new(ui::IconName::ThumbsUp)) + .on_click(cx.listener( + |this, _, window, cx| { + this.handle_rate_positive( + &Zeta2RatePredictionPositive, + window, + cx, + ); + }, + )), + ) + .child( + ButtonLike::new("rate-negative") + .when( + *feedback_state == Some(Feedback::Negative), + |this| this.style(ButtonStyle::Filled), + ) + .children( + KeyBinding::for_action( + &Zeta2RatePredictionNegative, + window, + cx, + ) + .map(|k| k.size(TextSize::Small.rems(cx))), + ) + .child(ui::Icon::new(ui::IconName::ThumbsDown)) + .on_click(cx.listener( + |this, _, window, cx| { + this.handle_rate_negative( + &Zeta2RatePredictionNegative, + window, + cx, + ); + }, + )), + ), + ) + .into_any() + } else { + Empty.into_any_element() + }, + ), ), } } @@ -808,7 +1081,7 @@ impl Render for Zeta2Inspector { .child(ui::vertical_divider()) .children(self.render_stats()), ) - .child(self.render_content(cx)) + .child(self.render_content(window, cx)) } }