WIP: Not sure I actually want to rip this out

Nathan Sobo created

Change summary

crates/ai/src/ai.rs                           | 241 --------------------
crates/ai/src/assistant.rs                    | 128 ++++++++--
crates/zed/src/languages/markdown/config.toml |   2 
3 files changed, 101 insertions(+), 270 deletions(-)

Detailed changes

crates/ai/src/ai.rs 🔗

@@ -1,21 +1,7 @@
 mod assistant;
 
-use anyhow::{anyhow, Result};
-use assets::Assets;
-use collections::HashMap;
-use editor::Editor;
-use futures::AsyncBufReadExt;
-use futures::{io::BufReader, AsyncReadExt, Stream, StreamExt};
-use gpui::executor::Background;
-use gpui::{actions, AppContext, Task, ViewContext};
-use isahc::prelude::*;
-use isahc::{http::StatusCode, Request};
+use gpui::{actions, AppContext};
 use serde::{Deserialize, Serialize};
-use std::cell::RefCell;
-use std::fs;
-use std::rc::Rc;
-use std::{io, sync::Arc};
-use util::{ResultExt, TryFutureExt};
 
 pub use assistant::AssistantPanel;
 
@@ -89,230 +75,5 @@ struct OpenAIChoice {
 }
 
 pub fn init(cx: &mut AppContext) {
-    // if *RELEASE_CHANNEL == ReleaseChannel::Stable {
-    //     return;
-    // }
-
     assistant::init(cx);
-
-    // let assistant = Rc::new(Assistant::default());
-    // cx.add_action({
-    //     let assistant = assistant.clone();
-    //     move |editor: &mut Editor, _: &Assist, cx: &mut ViewContext<Editor>| {
-    //         assistant.assist(editor, cx).log_err();
-    //     }
-    // });
-    // cx.capture_action({
-    //     let assistant = assistant.clone();
-    //     move |_: &mut Editor, _: &editor::Cancel, cx: &mut ViewContext<Editor>| {
-    //         if !assistant.cancel_last_assist(cx.view_id()) {
-    //             cx.propagate_action();
-    //         }
-    //     }
-    // });
-}
-
-type CompletionId = usize;
-
-#[derive(Default)]
-struct Assistant(RefCell<AssistantState>);
-
-#[derive(Default)]
-struct AssistantState {
-    assist_stacks: HashMap<usize, Vec<(CompletionId, Task<Option<()>>)>>,
-    next_completion_id: CompletionId,
-}
-
-impl Assistant {
-    fn assist(self: &Rc<Self>, editor: &mut Editor, cx: &mut ViewContext<Editor>) -> Result<()> {
-        let api_key = std::env::var("OPENAI_API_KEY")?;
-
-        let selections = editor.selections.all(cx);
-        let (user_message, insertion_site) = editor.buffer().update(cx, |buffer, cx| {
-            // Insert markers around selected text as described in the system prompt above.
-            let snapshot = buffer.snapshot(cx);
-            let mut user_message = String::new();
-            let mut user_message_suffix = String::new();
-            let mut buffer_offset = 0;
-            for selection in selections {
-                if !selection.is_empty() {
-                    if user_message_suffix.is_empty() {
-                        user_message_suffix.push_str("\n\n");
-                    }
-                    user_message_suffix.push_str("[Selected excerpt from above]\n");
-                    user_message_suffix
-                        .extend(snapshot.text_for_range(selection.start..selection.end));
-                    user_message_suffix.push_str("\n\n");
-                }
-
-                user_message.extend(snapshot.text_for_range(buffer_offset..selection.start));
-                user_message.push_str("[SELECTION_START]");
-                user_message.extend(snapshot.text_for_range(selection.start..selection.end));
-                buffer_offset = selection.end;
-                user_message.push_str("[SELECTION_END]");
-            }
-            if buffer_offset < snapshot.len() {
-                user_message.extend(snapshot.text_for_range(buffer_offset..snapshot.len()));
-            }
-            user_message.push_str(&user_message_suffix);
-
-            // Ensure the document ends with 4 trailing newlines.
-            let trailing_newline_count = snapshot
-                .reversed_chars_at(snapshot.len())
-                .take_while(|c| *c == '\n')
-                .take(4);
-            let buffer_suffix = "\n".repeat(4 - trailing_newline_count.count());
-            buffer.edit([(snapshot.len()..snapshot.len(), buffer_suffix)], None, cx);
-
-            let snapshot = buffer.snapshot(cx); // Take a new snapshot after editing.
-            let insertion_site = snapshot.anchor_after(snapshot.len() - 2);
-
-            (user_message, insertion_site)
-        });
-
-        let this = self.clone();
-        let buffer = editor.buffer().clone();
-        let executor = cx.background_executor().clone();
-        let editor_id = cx.view_id();
-        let assist_id = util::post_inc(&mut self.0.borrow_mut().next_completion_id);
-        let assist_task = cx.spawn(|_, mut cx| {
-            async move {
-                // TODO: We should have a get_string method on assets. This is repateated elsewhere.
-                let content = Assets::get("contexts/system.zmd").unwrap();
-                let mut system_message = std::str::from_utf8(content.data.as_ref())
-                    .unwrap()
-                    .to_string();
-
-                if let Ok(custom_system_message_path) =
-                    std::env::var("ZED_ASSISTANT_SYSTEM_PROMPT_PATH")
-                {
-                    system_message.push_str(
-                        "\n\nAlso consider the following user-defined system prompt:\n\n",
-                    );
-                    // TODO: Replace this with our file system trait object.
-                    system_message.push_str(
-                        &cx.background()
-                            .spawn(async move { fs::read_to_string(custom_system_message_path) })
-                            .await?,
-                    );
-                }
-
-                let stream = stream_completion(
-                    api_key,
-                    executor,
-                    OpenAIRequest {
-                        model: "gpt-4".to_string(),
-                        messages: vec![
-                            RequestMessage {
-                                role: Role::System,
-                                content: system_message.to_string(),
-                            },
-                            RequestMessage {
-                                role: Role::User,
-                                content: user_message,
-                            },
-                        ],
-                        stream: false,
-                    },
-                );
-
-                let mut messages = stream.await?;
-                while let Some(message) = messages.next().await {
-                    let mut message = message?;
-                    if let Some(choice) = message.choices.pop() {
-                        buffer.update(&mut cx, |buffer, cx| {
-                            let text: Arc<str> = choice.delta.content?.into();
-                            buffer.edit([(insertion_site.clone()..insertion_site, text)], None, cx);
-                            Some(())
-                        });
-                    }
-                }
-
-                this.0
-                    .borrow_mut()
-                    .assist_stacks
-                    .get_mut(&editor_id)
-                    .unwrap()
-                    .retain(|(id, _)| *id != assist_id);
-
-                anyhow::Ok(())
-            }
-            .log_err()
-        });
-
-        self.0
-            .borrow_mut()
-            .assist_stacks
-            .entry(cx.view_id())
-            .or_default()
-            .push((dbg!(assist_id), assist_task));
-
-        Ok(())
-    }
-
-    fn cancel_last_assist(self: &Rc<Self>, editor_id: usize) -> bool {
-        self.0
-            .borrow_mut()
-            .assist_stacks
-            .get_mut(&editor_id)
-            .and_then(|assists| assists.pop())
-            .is_some()
-    }
-}
-
-async fn stream_completion(
-    api_key: String,
-    executor: Arc<Background>,
-    mut request: OpenAIRequest,
-) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
-    request.stream = true;
-
-    let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
-
-    let json_data = serde_json::to_string(&request)?;
-    let mut response = Request::post("https://api.openai.com/v1/chat/completions")
-        .header("Content-Type", "application/json")
-        .header("Authorization", format!("Bearer {}", api_key))
-        .body(json_data)?
-        .send_async()
-        .await?;
-
-    let status = response.status();
-    if status == StatusCode::OK {
-        executor
-            .spawn(async move {
-                let mut lines = BufReader::new(response.body_mut()).lines();
-
-                fn parse_line(
-                    line: Result<String, io::Error>,
-                ) -> Result<Option<OpenAIResponseStreamEvent>> {
-                    if let Some(data) = line?.strip_prefix("data: ") {
-                        let event = serde_json::from_str(&data)?;
-                        Ok(Some(event))
-                    } else {
-                        Ok(None)
-                    }
-                }
-
-                while let Some(line) = lines.next().await {
-                    if let Some(event) = parse_line(line).transpose() {
-                        tx.unbounded_send(event).log_err();
-                    }
-                }
-
-                anyhow::Ok(())
-            })
-            .detach();
-
-        Ok(rx)
-    } else {
-        let mut body = String::new();
-        response.body_mut().read_to_string(&mut body).await?;
-
-        Err(anyhow!(
-            "Failed to connect to OpenAI API: {} {}",
-            response.status(),
-            body,
-        ))
-    }
 }

crates/ai/src/assistant.rs 🔗

@@ -1,12 +1,14 @@
-use crate::{stream_completion, OpenAIRequest, RequestMessage, Role};
+use crate::{OpenAIRequest, OpenAIResponseStreamEvent, RequestMessage, Role};
+use anyhow::{anyhow, Result};
 use editor::{Editor, MultiBuffer};
-use futures::StreamExt;
+use futures::{io::BufReader, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
 use gpui::{
-    actions, elements::*, Action, AppContext, Entity, ModelHandle, Subscription, Task, View,
-    ViewContext, ViewHandle, WeakViewHandle, WindowContext,
+    actions, elements::*, executor::Background, Action, AppContext, Entity, ModelHandle,
+    Subscription, Task, View, ViewContext, ViewHandle, WeakViewHandle, WindowContext,
 };
+use isahc::{http::StatusCode, Request, RequestExt};
 use language::{language_settings::SoftWrap, Anchor, Buffer};
-use std::sync::Arc;
+use std::{io, sync::Arc};
 use util::{post_inc, ResultExt, TryFutureExt};
 use workspace::{
     dock::{DockPosition, Panel},
@@ -17,8 +19,8 @@ use workspace::{
 actions!(assistant, [NewContext, Assist, CancelLastAssist]);
 
 pub fn init(cx: &mut AppContext) {
-    cx.add_action(ContextEditor::assist);
-    cx.add_action(ContextEditor::cancel_last_assist);
+    cx.add_action(Assistant::assist);
+    cx.capture_action(Assistant::cancel_last_assist);
 }
 
 pub enum AssistantPanelEvent {
@@ -37,9 +39,7 @@ pub struct AssistantPanel {
 
 impl AssistantPanel {
     pub fn new(workspace: &Workspace, cx: &mut ViewContext<Self>) -> Self {
-        let weak_self = cx.weak_handle();
         let pane = cx.add_view(|cx| {
-            let window_id = cx.window_id();
             let mut pane = Pane::new(
                 workspace.weak_handle(),
                 workspace.app_state().background_actions,
@@ -48,16 +48,15 @@ impl AssistantPanel {
             );
             pane.set_can_split(false, cx);
             pane.set_can_navigate(false, cx);
-            pane.on_can_drop(move |_, cx| false);
+            pane.on_can_drop(move |_, _| false);
             pane.set_render_tab_bar_buttons(cx, move |pane, cx| {
-                let this = weak_self.clone();
                 Flex::row()
                     .with_child(Pane::render_tab_bar_button(
                         0,
                         "icons/plus_12.svg",
                         Some(("New Context".into(), Some(Box::new(NewContext)))),
                         cx,
-                        move |_, cx| {},
+                        move |_, _| todo!(),
                         None,
                     ))
                     .with_child(Pane::render_tab_bar_button(
@@ -123,7 +122,7 @@ impl View for AssistantPanel {
 }
 
 impl Panel for AssistantPanel {
-    fn position(&self, cx: &WindowContext) -> DockPosition {
+    fn position(&self, _: &WindowContext) -> DockPosition {
         DockPosition::Right
     }
 
@@ -131,9 +130,11 @@ impl Panel for AssistantPanel {
         matches!(position, DockPosition::Right)
     }
 
-    fn set_position(&mut self, position: DockPosition, cx: &mut ViewContext<Self>) {}
+    fn set_position(&mut self, _: DockPosition, _: &mut ViewContext<Self>) {
+        // TODO!
+    }
 
-    fn size(&self, cx: &WindowContext) -> f32 {
+    fn size(&self, _: &WindowContext) -> f32 {
         self.width.unwrap_or(480.)
     }
 
@@ -164,7 +165,7 @@ impl Panel for AssistantPanel {
                 if let Some(workspace) = this.workspace.upgrade(cx) {
                     workspace.update(cx, |workspace, cx| {
                         let focus = this.pane.read(cx).has_focus();
-                        let editor = Box::new(cx.add_view(|cx| ContextEditor::new(cx)));
+                        let editor = Box::new(cx.add_view(|cx| Assistant::new(cx)));
                         Pane::add_item(workspace, &this.pane, editor, true, focus, None, cx);
                     })
                 }
@@ -180,7 +181,8 @@ impl Panel for AssistantPanel {
         ("Assistant Panel".into(), None)
     }
 
-    fn should_change_position_on_event(event: &Self::Event) -> bool {
+    fn should_change_position_on_event(_: &Self::Event) -> bool {
+        // TODO!
         false
     }
 
@@ -201,7 +203,7 @@ impl Panel for AssistantPanel {
     }
 }
 
-struct ContextEditor {
+struct Assistant {
     messages: Vec<Message>,
     editor: ViewHandle<Editor>,
     completion_count: usize,
@@ -210,10 +212,10 @@ struct ContextEditor {
 
 struct PendingCompletion {
     id: usize,
-    task: Task<Option<()>>,
+    _task: Task<Option<()>>,
 }
 
-impl ContextEditor {
+impl Assistant {
     fn new(cx: &mut ViewContext<Self>) -> Self {
         let messages = vec![Message {
             role: Role::User,
@@ -264,15 +266,26 @@ impl ContextEditor {
 
         if let Some(api_key) = std::env::var("OPENAI_API_KEY").log_err() {
             let stream = stream_completion(api_key, cx.background_executor().clone(), request);
-            let content = cx.add_model(|cx| Buffer::new(0, "", cx));
+            let response_buffer = cx.add_model(|cx| Buffer::new(0, "", cx));
             self.messages.push(Message {
                 role: Role::Assistant,
-                content: content.clone(),
+                content: response_buffer.clone(),
+            });
+            let next_request_buffer = cx.add_model(|cx| Buffer::new(0, "", cx));
+            self.messages.push(Message {
+                role: Role::User,
+                content: next_request_buffer.clone(),
             });
             self.editor.update(cx, |editor, cx| {
                 editor.buffer().update(cx, |multibuffer, cx| {
                     multibuffer.push_excerpts_with_context_lines(
-                        content.clone(),
+                        response_buffer.clone(),
+                        vec![Anchor::MIN..Anchor::MAX],
+                        0,
+                        cx,
+                    );
+                    multibuffer.push_excerpts_with_context_lines(
+                        next_request_buffer,
                         vec![Anchor::MIN..Anchor::MAX],
                         0,
                         cx,
@@ -286,7 +299,7 @@ impl ContextEditor {
                     while let Some(message) = messages.next().await {
                         let mut message = message?;
                         if let Some(choice) = message.choices.pop() {
-                            content.update(&mut cx, |content, cx| {
+                            response_buffer.update(&mut cx, |content, cx| {
                                 let text: Arc<str> = choice.delta.content?.into();
                                 content.edit([(content.len()..content.len(), text)], None, cx);
                                 Some(())
@@ -307,23 +320,23 @@ impl ContextEditor {
 
             self.pending_completions.push(PendingCompletion {
                 id: post_inc(&mut self.completion_count),
-                task,
+                _task: task,
             });
         }
     }
 
-    fn cancel_last_assist(&mut self, _: &CancelLastAssist, cx: &mut ViewContext<Self>) {
+    fn cancel_last_assist(&mut self, _: &editor::Cancel, cx: &mut ViewContext<Self>) {
         if self.pending_completions.pop().is_none() {
             cx.propagate_action();
         }
     }
 }
 
-impl Entity for ContextEditor {
+impl Entity for Assistant {
     type Event = ();
 }
 
-impl View for ContextEditor {
+impl View for Assistant {
     fn ui_name() -> &'static str {
         "ContextEditor"
     }
@@ -338,7 +351,7 @@ impl View for ContextEditor {
     }
 }
 
-impl Item for ContextEditor {
+impl Item for Assistant {
     fn tab_content<V: View>(
         &self,
         _: Option<usize>,
@@ -353,3 +366,60 @@ struct Message {
     role: Role,
     content: ModelHandle<Buffer>,
 }
+
+async fn stream_completion(
+    api_key: String,
+    executor: Arc<Background>,
+    mut request: OpenAIRequest,
+) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
+    request.stream = true;
+
+    let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
+
+    let json_data = serde_json::to_string(&request)?;
+    let mut response = Request::post("https://api.openai.com/v1/chat/completions")
+        .header("Content-Type", "application/json")
+        .header("Authorization", format!("Bearer {}", api_key))
+        .body(json_data)?
+        .send_async()
+        .await?;
+
+    let status = response.status();
+    if status == StatusCode::OK {
+        executor
+            .spawn(async move {
+                let mut lines = BufReader::new(response.body_mut()).lines();
+
+                fn parse_line(
+                    line: Result<String, io::Error>,
+                ) -> Result<Option<OpenAIResponseStreamEvent>> {
+                    if let Some(data) = line?.strip_prefix("data: ") {
+                        let event = serde_json::from_str(&data)?;
+                        Ok(Some(event))
+                    } else {
+                        Ok(None)
+                    }
+                }
+
+                while let Some(line) = lines.next().await {
+                    if let Some(event) = parse_line(line).transpose() {
+                        tx.unbounded_send(event).log_err();
+                    }
+                }
+
+                anyhow::Ok(())
+            })
+            .detach();
+
+        Ok(rx)
+    } else {
+        let mut body = String::new();
+        response.body_mut().read_to_string(&mut body).await?;
+
+        Err(anyhow!(
+            "Failed to connect to OpenAI API: {} {}",
+            response.status(),
+            body,
+        ))
+    }
+}

crates/zed/src/languages/markdown/config.toml 🔗

@@ -1,5 +1,5 @@
 name = "Markdown"
-path_suffixes = ["md", "mdx", "zmd"]
+path_suffixes = ["md", "mdx"]
 brackets = [
     { start = "{", end = "}", close = true, newline = true },
     { start = "[", end = "]", close = true, newline = true },