Start wiring up assistant2

Antonio Scandurra created

Change summary

Cargo.lock                               |   2 
crates/assistant2/src/assistant_panel.rs | 229 ++++++++++++++-----------
crates/assistant2/src/codegen.rs         |   8 
crates/assistant2/src/prompts.rs         |   3 
crates/zed2/Cargo.toml                   |   4 
crates/zed2/src/main.rs                  |  12 
6 files changed, 137 insertions(+), 121 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -11881,6 +11881,7 @@ dependencies = [
  "activity_indicator2",
  "ai2",
  "anyhow",
+ "assistant2",
  "async-compression",
  "async-recursion 0.3.2",
  "async-tar",
@@ -11939,6 +11940,7 @@ dependencies = [
  "rust-embed",
  "schemars",
  "search2",
+ "semantic_index2",
  "serde",
  "serde_derive",
  "serde_json",

crates/assistant2/src/assistant_panel.rs 🔗

@@ -22,16 +22,18 @@ use editor::{
         BlockContext, BlockDisposition, BlockId, BlockProperties, BlockStyle, ToDisplayPoint,
     },
     scroll::autoscroll::{Autoscroll, AutoscrollStrategy},
-    Anchor, Editor, EditorEvent, MoveDown, MoveUp, MultiBufferSnapshot, ToOffset, ToPoint,
+    Anchor, Editor, EditorElement, EditorEvent, EditorStyle, MoveDown, MoveUp, MultiBufferSnapshot,
+    ToOffset, ToPoint,
 };
 use fs::Fs;
 use futures::StreamExt;
 use gpui::{
-    actions, div, point, uniform_list, Action, AnyElement, AppContext, AsyncWindowContext,
-    ClipboardItem, Context, Div, EventEmitter, FocusHandle, Focusable, FocusableView,
-    HighlightStyle, InteractiveElement, IntoElement, Model, ModelContext, ParentElement, Pixels,
-    PromptLevel, Render, StatefulInteractiveElement, Styled, Subscription, Task,
-    UniformListScrollHandle, View, ViewContext, VisualContext, WeakModel, WeakView, WindowContext,
+    actions, div, point, relative, rems, uniform_list, Action, AnyElement, AppContext,
+    AsyncWindowContext, ClipboardItem, Context, Div, EventEmitter, FocusHandle, Focusable,
+    FocusableView, FontStyle, FontWeight, HighlightStyle, InteractiveElement, IntoElement, Model,
+    ModelContext, ParentElement, Pixels, PromptLevel, Render, SharedString,
+    StatefulInteractiveElement, Styled, Subscription, Task, TextStyle, UniformListScrollHandle,
+    View, ViewContext, VisualContext, WeakModel, WeakView, WhiteSpace, WindowContext,
 };
 use language::{language_settings::SoftWrap, Buffer, LanguageRegistry, ToOffset as _};
 use project::Project;
@@ -49,6 +51,7 @@ use std::{
     sync::Arc,
     time::{Duration, Instant},
 };
+use theme::{ActiveTheme, ThemeSettings};
 use ui::{
     h_stack, v_stack, Button, ButtonCommon, ButtonLike, Clickable, Color, Icon, IconButton,
     IconElement, Label, Selectable, Tooltip,
@@ -77,7 +80,7 @@ actions!(
 pub fn init(cx: &mut AppContext) {
     AssistantSettings::register(cx);
     cx.observe_new_views(
-        |workspace: &mut Workspace, cx: &mut ViewContext<Workspace>| {
+        |workspace: &mut Workspace, _cx: &mut ViewContext<Workspace>| {
             workspace
                 .register_action(|workspace, _: &ToggleFocus, cx| {
                     workspace.toggle_panel_focus::<AssistantPanel>(cx);
@@ -122,7 +125,7 @@ impl AssistantPanel {
 
     pub fn load(
         workspace: WeakView<Workspace>,
-        mut cx: AsyncWindowContext,
+        cx: AsyncWindowContext,
     ) -> Task<Result<View<Self>>> {
         cx.spawn(|mut cx| async move {
             let fs = workspace.update(&mut cx, |workspace, _| workspace.app_state().fs.clone())?;
@@ -540,7 +543,7 @@ impl AssistantPanel {
         if let Some(pending_assist) = self.pending_inline_assists.remove(&assist_id) {
             if let hash_map::Entry::Occupied(mut entry) = self
                 .pending_inline_assist_ids_by_editor
-                .entry(pending_assist.editor)
+                .entry(pending_assist.editor.clone())
             {
                 entry.get_mut().retain(|id| *id != assist_id);
                 if entry.get().is_empty() {
@@ -747,7 +750,7 @@ impl AssistantPanel {
                 temperature,
             });
 
-            codegen.update(&mut cx, |codegen, cx| codegen.start(request, cx));
+            codegen.update(&mut cx, |codegen, cx| codegen.start(request, cx))?;
             anyhow::Ok(())
         })
         .detach();
@@ -779,7 +782,7 @@ impl AssistantPanel {
             } else {
                 editor.highlight_background::<PendingInlineAssist>(
                     background_ranges,
-                    |theme| gpui::red(), // todo!("use the appropriate color")
+                    |theme| theme.editor_active_line_background, // todo!("use the appropriate color")
                     cx,
                 );
             }
@@ -1240,7 +1243,7 @@ impl Panel for AssistantPanel {
         }
     }
 
-    fn icon(&self, cx: &WindowContext) -> Option<Icon> {
+    fn icon(&self, _cx: &WindowContext) -> Option<Icon> {
         Some(Icon::Ai)
     }
 
@@ -1862,7 +1865,7 @@ impl Conversation {
                                 .text
                                 .push_str(&text);
                             cx.emit(ConversationEvent::SummaryChanged);
-                        });
+                        })?;
                     }
 
                     this.update(&mut cx, |this, cx| {
@@ -1870,7 +1873,7 @@ impl Conversation {
                             summary.done = true;
                             cx.emit(ConversationEvent::SummaryChanged);
                         }
-                    });
+                    })?;
 
                     anyhow::Ok(())
                 }
@@ -2249,7 +2252,7 @@ impl ConversationEditor {
                     style: BlockStyle::Sticky,
                     render: Arc::new({
                         let conversation = self.conversation.clone();
-                        move |cx| {
+                        move |_cx| {
                             let message_id = message.id;
                             let sender = ButtonLike::new("role")
                                 .child(match message.role {
@@ -2277,16 +2280,18 @@ impl ConversationEditor {
                                 .border_color(gpui::red())
                                 .child(sender)
                                 .child(Label::new(message.sent_at.format("%I:%M%P").to_string()))
-                                .children(if let MessageStatus::Error(error) = &message.status {
-                                    Some(
-                                        div()
-                                            .id("error")
-                                            .tooltip(|cx| Tooltip::text(error, cx))
-                                            .child(IconElement::new(Icon::XCircle)),
-                                    )
-                                } else {
-                                    None
-                                })
+                                .children(
+                                    if let MessageStatus::Error(error) = message.status.clone() {
+                                        Some(
+                                            div()
+                                                .id("error")
+                                                .tooltip(move |cx| Tooltip::text(&error, cx))
+                                                .child(IconElement::new(Icon::XCircle)),
+                                        )
+                                    } else {
+                                        None
+                                    },
+                                )
                                 .into_any_element()
                         }
                     }),
@@ -2602,10 +2607,11 @@ impl Render for InlineAssistant {
                         None
                     })
                     .children(if let Some(error) = self.codegen.read(cx).error() {
+                        let error_message = SharedString::from(error.to_string());
                         Some(
                             div()
                                 .id("error")
-                                .tooltip(|cx| Tooltip::text(error.to_string(), cx))
+                                .tooltip(move |cx| Tooltip::text(error_message.clone(), cx))
                                 .child(IconElement::new(Icon::XCircle).color(Color::Error)),
                         )
                     } else {
@@ -2615,7 +2621,7 @@ impl Render for InlineAssistant {
             .child(
                 div()
                     .ml(measurements.anchor_x - measurements.gutter_width)
-                    .child(self.prompt_editor.clone()),
+                    .child(self.render_prompt_editor(cx)),
             )
             .children(if self.retrieve_context {
                 self.retrieve_context_status(cx)
@@ -2752,24 +2758,14 @@ impl InlineAssistant {
 
     fn handle_codegen_changed(&mut self, _: Model<Codegen>, cx: &mut ViewContext<Self>) {
         let is_read_only = !self.codegen.read(cx).idle();
-        self.prompt_editor.update(cx, |editor, cx| {
+        self.prompt_editor.update(cx, |editor, _cx| {
             let was_read_only = editor.read_only();
             if was_read_only != is_read_only {
                 if is_read_only {
                     editor.set_read_only(true);
-                    editor.set_field_editor_style(
-                        Some(Arc::new(|theme| {
-                            theme.assistant.inline.disabled_editor.clone()
-                        })),
-                        cx,
-                    );
                 } else {
                     self.confirmed = false;
                     editor.set_read_only(false);
-                    editor.set_field_editor_style(
-                        Some(Arc::new(|theme| theme.assistant.inline.editor.clone())),
-                        cx,
-                    );
                 }
             }
         });
@@ -2787,15 +2783,8 @@ impl InlineAssistant {
             report_assistant_event(self.workspace.clone(), None, AssistantKind::Inline, cx);
 
             let prompt = self.prompt_editor.read(cx).text(cx);
-            self.prompt_editor.update(cx, |editor, cx| {
-                editor.set_read_only(true);
-                editor.set_field_editor_style(
-                    Some(Arc::new(|theme| {
-                        theme.assistant.inline.disabled_editor.clone()
-                    })),
-                    cx,
-                );
-            });
+            self.prompt_editor
+                .update(cx, |editor, _cx| editor.set_read_only(true));
             cx.emit(InlineAssistantEvent::Confirmed {
                 prompt,
                 include_conversation: self.include_conversation,
@@ -2827,7 +2816,7 @@ impl InlineAssistant {
         cx.spawn(|this, mut cx| async move {
             // If Necessary prompt user
             if !semantic_permissioned.await.unwrap_or(false) {
-                let mut answer = this.update(&mut cx, |_, cx| {
+                let answer = this.update(&mut cx, |_, cx| {
                     cx.prompt(
                         PromptLevel::Info,
                         prompt_text.as_str(),
@@ -2888,71 +2877,68 @@ impl InlineAssistant {
     }
 
     fn retrieve_context_status(&self, cx: &mut ViewContext<Self>) -> Option<AnyElement> {
-        enum ContextStatusIcon {}
-
         let Some(project) = self.project.upgrade() else {
             return None;
         };
 
-        if let Some(semantic_index) = SemanticIndex::global(cx) {
-            let status = semantic_index.update(cx, |index, _| index.status(&project));
-            match status {
-                SemanticIndexStatus::NotAuthenticated {} => Some(
-                    div()
-                        .id("error")
-                        .tooltip(|cx| Tooltip::text("Not Authenticated. Please ensure you have a valid 'OPENAI_API_KEY' in your environment variables.", cx))
-                        .child(IconElement::new(Icon::XCircle))
-                        .into_any_element()
-                ),
+        let semantic_index = SemanticIndex::global(cx)?;
+        let status = semantic_index.update(cx, |index, _| index.status(&project));
+        match status {
+            SemanticIndexStatus::NotAuthenticated {} => Some(
+                div()
+                    .id("error")
+                    .tooltip(|cx| Tooltip::text("Not Authenticated. Please ensure you have a valid 'OPENAI_API_KEY' in your environment variables.", cx))
+                    .child(IconElement::new(Icon::XCircle))
+                    .into_any_element()
+            ),
 
-                SemanticIndexStatus::NotIndexed {} => Some(
-                    div()
-                        .id("error")
-                        .tooltip(|cx| Tooltip::text("Not Indexed", cx))
-                        .child(IconElement::new(Icon::XCircle))
-                        .into_any_element()
-                ),
-
-                SemanticIndexStatus::Indexing {
-                    remaining_files,
-                    rate_limit_expiry,
-                } => {
-                    let mut status_text = if remaining_files == 0 {
-                        "Indexing...".to_string()
-                    } else {
-                        format!("Remaining files to index: {remaining_files}")
-                    };
+            SemanticIndexStatus::NotIndexed {} => Some(
+                div()
+                    .id("error")
+                    .tooltip(|cx| Tooltip::text("Not Indexed", cx))
+                    .child(IconElement::new(Icon::XCircle))
+                    .into_any_element()
+            ),
 
-                    if let Some(rate_limit_expiry) = rate_limit_expiry {
-                        let remaining_seconds = rate_limit_expiry.duration_since(Instant::now());
-                        if remaining_seconds > Duration::from_secs(0) && remaining_files > 0 {
-                            write!(
-                                status_text,
-                                " (rate limit expires in {}s)",
-                                remaining_seconds.as_secs()
-                            )
-                            .unwrap();
-                        }
+            SemanticIndexStatus::Indexing {
+                remaining_files,
+                rate_limit_expiry,
+            } => {
+                let mut status_text = if remaining_files == 0 {
+                    "Indexing...".to_string()
+                } else {
+                    format!("Remaining files to index: {remaining_files}")
+                };
+
+                if let Some(rate_limit_expiry) = rate_limit_expiry {
+                    let remaining_seconds = rate_limit_expiry.duration_since(Instant::now());
+                    if remaining_seconds > Duration::from_secs(0) && remaining_files > 0 {
+                        write!(
+                            status_text,
+                            " (rate limit expires in {}s)",
+                            remaining_seconds.as_secs()
+                        )
+                        .unwrap();
                     }
-                    Some(
-                        div()
-                            .id("update")
-                            .tooltip(|cx| Tooltip::text(status_text, cx))
-                            .child(IconElement::new(Icon::Update).color(Color::Info))
-                            .into_any_element()
-                    )
                 }
 
-                SemanticIndexStatus::Indexed {} => Some(
+                let status_text = SharedString::from(status_text);
+                Some(
                     div()
-                        .id("check")
-                        .tooltip(|cx| Tooltip::text("Index up to date", cx))
-                        .child(IconElement::new(Icon::Check).color(Color::Success))
+                        .id("update")
+                        .tooltip(move |cx| Tooltip::text(status_text.clone(), cx))
+                        .child(IconElement::new(Icon::Update).color(Color::Info))
                         .into_any_element()
-                ),
+                )
             }
-        } else {
-            None
+
+            SemanticIndexStatus::Indexed {} => Some(
+                div()
+                    .id("check")
+                    .tooltip(|cx| Tooltip::text("Index up to date", cx))
+                    .child(IconElement::new(Icon::Check).color(Color::Success))
+                    .into_any_element()
+            ),
         }
     }
 
@@ -3004,6 +2990,35 @@ impl InlineAssistant {
             });
         });
     }
+
+    fn render_prompt_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
+        let settings = ThemeSettings::get_global(cx);
+        let text_style = TextStyle {
+            color: if self.prompt_editor.read(cx).read_only() {
+                cx.theme().colors().text_disabled
+            } else {
+                cx.theme().colors().text
+            },
+            font_family: settings.ui_font.family.clone(),
+            font_features: settings.ui_font.features,
+            font_size: rems(0.875).into(),
+            font_weight: FontWeight::NORMAL,
+            font_style: FontStyle::Normal,
+            line_height: relative(1.).into(),
+            background_color: None,
+            underline: None,
+            white_space: WhiteSpace::Normal,
+        };
+        EditorElement::new(
+            &self.prompt_editor,
+            EditorStyle {
+                background: cx.theme().colors().editor_background,
+                local_player: cx.theme().players().local(),
+                text: text_style,
+                ..Default::default()
+            },
+        )
+    }
 }
 
 // This wouldn't need to exist if we could pass parameters when rendering child views.
@@ -3052,7 +3067,8 @@ mod tests {
 
     #[gpui::test]
     fn test_inserting_and_removing_messages(cx: &mut AppContext) {
-        cx.set_global(SettingsStore::test(cx));
+        let settings_store = SettingsStore::test(cx);
+        cx.set_global(settings_store);
         init(cx);
         let registry = Arc::new(LanguageRegistry::test());
 
@@ -3183,7 +3199,8 @@ mod tests {
 
     #[gpui::test]
     fn test_message_splitting(cx: &mut AppContext) {
-        cx.set_global(SettingsStore::test(cx));
+        let settings_store = SettingsStore::test(cx);
+        cx.set_global(settings_store);
         init(cx);
         let registry = Arc::new(LanguageRegistry::test());
         let completion_provider = Arc::new(FakeCompletionProvider::new());
@@ -3282,7 +3299,8 @@ mod tests {
 
     #[gpui::test]
     fn test_messages_for_offsets(cx: &mut AppContext) {
-        cx.set_global(SettingsStore::test(cx));
+        let settings_store = SettingsStore::test(cx);
+        cx.set_global(settings_store);
         init(cx);
         let registry = Arc::new(LanguageRegistry::test());
         let completion_provider = Arc::new(FakeCompletionProvider::new());
@@ -3366,7 +3384,8 @@ mod tests {
 
     #[gpui::test]
     fn test_serialization(cx: &mut AppContext) {
-        cx.set_global(SettingsStore::test(cx));
+        let settings_store = SettingsStore::test(cx);
+        cx.set_global(settings_store);
         init(cx);
         let registry = Arc::new(LanguageRegistry::test());
         let completion_provider = Arc::new(FakeCompletionProvider::new());

crates/assistant2/src/codegen.rs 🔗

@@ -181,12 +181,6 @@ impl Codegen {
                     });
 
                     while let Some(hunks) = hunks_rx.next().await {
-                        let this = if let Some(this) = this.upgrade() {
-                            this
-                        } else {
-                            break;
-                        };
-
                         this.update(&mut cx, |this, cx| {
                             this.last_equal_ranges.clear();
 
@@ -243,7 +237,7 @@ impl Codegen {
                             }
 
                             cx.notify();
-                        });
+                        })?;
                     }
 
                     diff.await?;

crates/assistant2/src/prompts.rs 🔗

@@ -227,7 +227,8 @@ pub(crate) mod tests {
 
     #[gpui::test]
     fn test_outline_for_prompt(cx: &mut AppContext) {
-        cx.set_global(SettingsStore::test(cx));
+        let settings_store = SettingsStore::test(cx);
+        cx.set_global(settings_store);
         language_settings::init(cx);
         let text = indoc! {"
             struct X {

crates/zed2/Cargo.toml 🔗

@@ -49,7 +49,7 @@ lsp = { package = "lsp2", path = "../lsp2" }
 menu = { package = "menu2", path = "../menu2" }
 # language_tools = { path = "../language_tools" }
 node_runtime = { path = "../node_runtime" }
-# assistant = { path = "../assistant" }
+assistant = { package = "assistant2", path = "../assistant2" }
 outline = { package = "outline2", path = "../outline2" }
 # plugin_runtime = { path = "../plugin_runtime",optional = true }
 project = { package = "project2", path = "../project2" }
@@ -68,7 +68,7 @@ terminal_view = { package = "terminal_view2", path = "../terminal_view2" }
 theme = { package = "theme2", path = "../theme2" }
 theme_selector = { package = "theme_selector2", path = "../theme_selector2" }
 util = { path = "../util" }
-# semantic_index = { path = "../semantic_index" }
+semantic_index = { package = "semantic_index2", path = "../semantic_index2" }
 # vim = { path = "../vim" }
 workspace = { package = "workspace2", path = "../workspace2" }
 welcome = { package = "welcome2", path = "../welcome2" }

crates/zed2/src/main.rs 🔗

@@ -161,11 +161,11 @@ fn main() {
             node_runtime.clone(),
             cx,
         );
-        // assistant::init(cx);
+        assistant::init(cx);
         // component_test::init(cx);
 
-        // cx.spawn(|_| watch_languages(fs.clone(), languages.clone()))
-        //     .detach();
+        cx.spawn(|_| watch_languages(fs.clone(), languages.clone()))
+            .detach();
         watch_file_types(fs.clone(), cx);
 
         languages.set_theme(cx.theme().clone());
@@ -186,10 +186,10 @@ fn main() {
             .report_app_event(telemetry_settings, event_operation);
 
         let app_state = Arc::new(AppState {
-            languages,
+            languages: languages.clone(),
             client: client.clone(),
             user_store: user_store.clone(),
-            fs,
+            fs: fs.clone(),
             build_window_options,
             workspace_store,
             node_runtime,
@@ -210,7 +210,7 @@ fn main() {
         channel::init(&client, user_store.clone(), cx);
         // diagnostics::init(cx);
         search::init(cx);
-        // semantic_index::init(fs.clone(), http.clone(), languages.clone(), cx);
+        semantic_index::init(fs.clone(), http.clone(), languages.clone(), cx);
         // vim::init(cx);
         terminal_view::init(cx);