diff --git a/crates/ai/src/ai.rs b/crates/ai/src/ai.rs index 64d8b6962801cb57a4d2c253e63e3aac05c14a07..89999f26f39d63591cee4c43274d441439ea5f77 100644 --- a/crates/ai/src/ai.rs +++ b/crates/ai/src/ai.rs @@ -1,21 +1,7 @@ mod assistant; -use anyhow::{anyhow, Result}; -use assets::Assets; -use collections::HashMap; -use editor::Editor; -use futures::AsyncBufReadExt; -use futures::{io::BufReader, AsyncReadExt, Stream, StreamExt}; -use gpui::executor::Background; -use gpui::{actions, AppContext, Task, ViewContext}; -use isahc::prelude::*; -use isahc::{http::StatusCode, Request}; +use gpui::{actions, AppContext}; use serde::{Deserialize, Serialize}; -use std::cell::RefCell; -use std::fs; -use std::rc::Rc; -use std::{io, sync::Arc}; -use util::{ResultExt, TryFutureExt}; pub use assistant::AssistantPanel; @@ -89,230 +75,5 @@ struct OpenAIChoice { } pub fn init(cx: &mut AppContext) { - // if *RELEASE_CHANNEL == ReleaseChannel::Stable { - // return; - // } - assistant::init(cx); - - // let assistant = Rc::new(Assistant::default()); - // cx.add_action({ - // let assistant = assistant.clone(); - // move |editor: &mut Editor, _: &Assist, cx: &mut ViewContext| { - // assistant.assist(editor, cx).log_err(); - // } - // }); - // cx.capture_action({ - // let assistant = assistant.clone(); - // move |_: &mut Editor, _: &editor::Cancel, cx: &mut ViewContext| { - // if !assistant.cancel_last_assist(cx.view_id()) { - // cx.propagate_action(); - // } - // } - // }); -} - -type CompletionId = usize; - -#[derive(Default)] -struct Assistant(RefCell); - -#[derive(Default)] -struct AssistantState { - assist_stacks: HashMap>)>>, - next_completion_id: CompletionId, -} - -impl Assistant { - fn assist(self: &Rc, editor: &mut Editor, cx: &mut ViewContext) -> Result<()> { - let api_key = std::env::var("OPENAI_API_KEY")?; - - let selections = editor.selections.all(cx); - let (user_message, insertion_site) = editor.buffer().update(cx, |buffer, cx| { - // Insert markers around selected text as described in the system prompt above. - let snapshot = buffer.snapshot(cx); - let mut user_message = String::new(); - let mut user_message_suffix = String::new(); - let mut buffer_offset = 0; - for selection in selections { - if !selection.is_empty() { - if user_message_suffix.is_empty() { - user_message_suffix.push_str("\n\n"); - } - user_message_suffix.push_str("[Selected excerpt from above]\n"); - user_message_suffix - .extend(snapshot.text_for_range(selection.start..selection.end)); - user_message_suffix.push_str("\n\n"); - } - - user_message.extend(snapshot.text_for_range(buffer_offset..selection.start)); - user_message.push_str("[SELECTION_START]"); - user_message.extend(snapshot.text_for_range(selection.start..selection.end)); - buffer_offset = selection.end; - user_message.push_str("[SELECTION_END]"); - } - if buffer_offset < snapshot.len() { - user_message.extend(snapshot.text_for_range(buffer_offset..snapshot.len())); - } - user_message.push_str(&user_message_suffix); - - // Ensure the document ends with 4 trailing newlines. - let trailing_newline_count = snapshot - .reversed_chars_at(snapshot.len()) - .take_while(|c| *c == '\n') - .take(4); - let buffer_suffix = "\n".repeat(4 - trailing_newline_count.count()); - buffer.edit([(snapshot.len()..snapshot.len(), buffer_suffix)], None, cx); - - let snapshot = buffer.snapshot(cx); // Take a new snapshot after editing. - let insertion_site = snapshot.anchor_after(snapshot.len() - 2); - - (user_message, insertion_site) - }); - - let this = self.clone(); - let buffer = editor.buffer().clone(); - let executor = cx.background_executor().clone(); - let editor_id = cx.view_id(); - let assist_id = util::post_inc(&mut self.0.borrow_mut().next_completion_id); - let assist_task = cx.spawn(|_, mut cx| { - async move { - // TODO: We should have a get_string method on assets. This is repateated elsewhere. - let content = Assets::get("contexts/system.zmd").unwrap(); - let mut system_message = std::str::from_utf8(content.data.as_ref()) - .unwrap() - .to_string(); - - if let Ok(custom_system_message_path) = - std::env::var("ZED_ASSISTANT_SYSTEM_PROMPT_PATH") - { - system_message.push_str( - "\n\nAlso consider the following user-defined system prompt:\n\n", - ); - // TODO: Replace this with our file system trait object. - system_message.push_str( - &cx.background() - .spawn(async move { fs::read_to_string(custom_system_message_path) }) - .await?, - ); - } - - let stream = stream_completion( - api_key, - executor, - OpenAIRequest { - model: "gpt-4".to_string(), - messages: vec![ - RequestMessage { - role: Role::System, - content: system_message.to_string(), - }, - RequestMessage { - role: Role::User, - content: user_message, - }, - ], - stream: false, - }, - ); - - let mut messages = stream.await?; - while let Some(message) = messages.next().await { - let mut message = message?; - if let Some(choice) = message.choices.pop() { - buffer.update(&mut cx, |buffer, cx| { - let text: Arc = choice.delta.content?.into(); - buffer.edit([(insertion_site.clone()..insertion_site, text)], None, cx); - Some(()) - }); - } - } - - this.0 - .borrow_mut() - .assist_stacks - .get_mut(&editor_id) - .unwrap() - .retain(|(id, _)| *id != assist_id); - - anyhow::Ok(()) - } - .log_err() - }); - - self.0 - .borrow_mut() - .assist_stacks - .entry(cx.view_id()) - .or_default() - .push((dbg!(assist_id), assist_task)); - - Ok(()) - } - - fn cancel_last_assist(self: &Rc, editor_id: usize) -> bool { - self.0 - .borrow_mut() - .assist_stacks - .get_mut(&editor_id) - .and_then(|assists| assists.pop()) - .is_some() - } -} - -async fn stream_completion( - api_key: String, - executor: Arc, - mut request: OpenAIRequest, -) -> Result>> { - request.stream = true; - - let (tx, rx) = futures::channel::mpsc::unbounded::>(); - - let json_data = serde_json::to_string(&request)?; - let mut response = Request::post("https://api.openai.com/v1/chat/completions") - .header("Content-Type", "application/json") - .header("Authorization", format!("Bearer {}", api_key)) - .body(json_data)? - .send_async() - .await?; - - let status = response.status(); - if status == StatusCode::OK { - executor - .spawn(async move { - let mut lines = BufReader::new(response.body_mut()).lines(); - - fn parse_line( - line: Result, - ) -> Result> { - if let Some(data) = line?.strip_prefix("data: ") { - let event = serde_json::from_str(&data)?; - Ok(Some(event)) - } else { - Ok(None) - } - } - - while let Some(line) = lines.next().await { - if let Some(event) = parse_line(line).transpose() { - tx.unbounded_send(event).log_err(); - } - } - - anyhow::Ok(()) - }) - .detach(); - - Ok(rx) - } else { - let mut body = String::new(); - response.body_mut().read_to_string(&mut body).await?; - - Err(anyhow!( - "Failed to connect to OpenAI API: {} {}", - response.status(), - body, - )) - } } diff --git a/crates/ai/src/assistant.rs b/crates/ai/src/assistant.rs index 72e4e9fa427b22bcd985535e5ae431b52029f0f5..e2796f4ecfb656a748dc97958476c54f8c113547 100644 --- a/crates/ai/src/assistant.rs +++ b/crates/ai/src/assistant.rs @@ -1,12 +1,14 @@ -use crate::{stream_completion, OpenAIRequest, RequestMessage, Role}; +use crate::{OpenAIRequest, OpenAIResponseStreamEvent, RequestMessage, Role}; +use anyhow::{anyhow, Result}; use editor::{Editor, MultiBuffer}; -use futures::StreamExt; +use futures::{io::BufReader, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt}; use gpui::{ - actions, elements::*, Action, AppContext, Entity, ModelHandle, Subscription, Task, View, - ViewContext, ViewHandle, WeakViewHandle, WindowContext, + actions, elements::*, executor::Background, Action, AppContext, Entity, ModelHandle, + Subscription, Task, View, ViewContext, ViewHandle, WeakViewHandle, WindowContext, }; +use isahc::{http::StatusCode, Request, RequestExt}; use language::{language_settings::SoftWrap, Anchor, Buffer}; -use std::sync::Arc; +use std::{io, sync::Arc}; use util::{post_inc, ResultExt, TryFutureExt}; use workspace::{ dock::{DockPosition, Panel}, @@ -17,8 +19,8 @@ use workspace::{ actions!(assistant, [NewContext, Assist, CancelLastAssist]); pub fn init(cx: &mut AppContext) { - cx.add_action(ContextEditor::assist); - cx.add_action(ContextEditor::cancel_last_assist); + cx.add_action(Assistant::assist); + cx.capture_action(Assistant::cancel_last_assist); } pub enum AssistantPanelEvent { @@ -37,9 +39,7 @@ pub struct AssistantPanel { impl AssistantPanel { pub fn new(workspace: &Workspace, cx: &mut ViewContext) -> Self { - let weak_self = cx.weak_handle(); let pane = cx.add_view(|cx| { - let window_id = cx.window_id(); let mut pane = Pane::new( workspace.weak_handle(), workspace.app_state().background_actions, @@ -48,16 +48,15 @@ impl AssistantPanel { ); pane.set_can_split(false, cx); pane.set_can_navigate(false, cx); - pane.on_can_drop(move |_, cx| false); + pane.on_can_drop(move |_, _| false); pane.set_render_tab_bar_buttons(cx, move |pane, cx| { - let this = weak_self.clone(); Flex::row() .with_child(Pane::render_tab_bar_button( 0, "icons/plus_12.svg", Some(("New Context".into(), Some(Box::new(NewContext)))), cx, - move |_, cx| {}, + move |_, _| todo!(), None, )) .with_child(Pane::render_tab_bar_button( @@ -123,7 +122,7 @@ impl View for AssistantPanel { } impl Panel for AssistantPanel { - fn position(&self, cx: &WindowContext) -> DockPosition { + fn position(&self, _: &WindowContext) -> DockPosition { DockPosition::Right } @@ -131,9 +130,11 @@ impl Panel for AssistantPanel { matches!(position, DockPosition::Right) } - fn set_position(&mut self, position: DockPosition, cx: &mut ViewContext) {} + fn set_position(&mut self, _: DockPosition, _: &mut ViewContext) { + // TODO! + } - fn size(&self, cx: &WindowContext) -> f32 { + fn size(&self, _: &WindowContext) -> f32 { self.width.unwrap_or(480.) } @@ -164,7 +165,7 @@ impl Panel for AssistantPanel { if let Some(workspace) = this.workspace.upgrade(cx) { workspace.update(cx, |workspace, cx| { let focus = this.pane.read(cx).has_focus(); - let editor = Box::new(cx.add_view(|cx| ContextEditor::new(cx))); + let editor = Box::new(cx.add_view(|cx| Assistant::new(cx))); Pane::add_item(workspace, &this.pane, editor, true, focus, None, cx); }) } @@ -180,7 +181,8 @@ impl Panel for AssistantPanel { ("Assistant Panel".into(), None) } - fn should_change_position_on_event(event: &Self::Event) -> bool { + fn should_change_position_on_event(_: &Self::Event) -> bool { + // TODO! false } @@ -201,7 +203,7 @@ impl Panel for AssistantPanel { } } -struct ContextEditor { +struct Assistant { messages: Vec, editor: ViewHandle, completion_count: usize, @@ -210,10 +212,10 @@ struct ContextEditor { struct PendingCompletion { id: usize, - task: Task>, + _task: Task>, } -impl ContextEditor { +impl Assistant { fn new(cx: &mut ViewContext) -> Self { let messages = vec![Message { role: Role::User, @@ -264,15 +266,26 @@ impl ContextEditor { if let Some(api_key) = std::env::var("OPENAI_API_KEY").log_err() { let stream = stream_completion(api_key, cx.background_executor().clone(), request); - let content = cx.add_model(|cx| Buffer::new(0, "", cx)); + let response_buffer = cx.add_model(|cx| Buffer::new(0, "", cx)); self.messages.push(Message { role: Role::Assistant, - content: content.clone(), + content: response_buffer.clone(), + }); + let next_request_buffer = cx.add_model(|cx| Buffer::new(0, "", cx)); + self.messages.push(Message { + role: Role::User, + content: next_request_buffer.clone(), }); self.editor.update(cx, |editor, cx| { editor.buffer().update(cx, |multibuffer, cx| { multibuffer.push_excerpts_with_context_lines( - content.clone(), + response_buffer.clone(), + vec![Anchor::MIN..Anchor::MAX], + 0, + cx, + ); + multibuffer.push_excerpts_with_context_lines( + next_request_buffer, vec![Anchor::MIN..Anchor::MAX], 0, cx, @@ -286,7 +299,7 @@ impl ContextEditor { while let Some(message) = messages.next().await { let mut message = message?; if let Some(choice) = message.choices.pop() { - content.update(&mut cx, |content, cx| { + response_buffer.update(&mut cx, |content, cx| { let text: Arc = choice.delta.content?.into(); content.edit([(content.len()..content.len(), text)], None, cx); Some(()) @@ -307,23 +320,23 @@ impl ContextEditor { self.pending_completions.push(PendingCompletion { id: post_inc(&mut self.completion_count), - task, + _task: task, }); } } - fn cancel_last_assist(&mut self, _: &CancelLastAssist, cx: &mut ViewContext) { + fn cancel_last_assist(&mut self, _: &editor::Cancel, cx: &mut ViewContext) { if self.pending_completions.pop().is_none() { cx.propagate_action(); } } } -impl Entity for ContextEditor { +impl Entity for Assistant { type Event = (); } -impl View for ContextEditor { +impl View for Assistant { fn ui_name() -> &'static str { "ContextEditor" } @@ -338,7 +351,7 @@ impl View for ContextEditor { } } -impl Item for ContextEditor { +impl Item for Assistant { fn tab_content( &self, _: Option, @@ -353,3 +366,60 @@ struct Message { role: Role, content: ModelHandle, } + +async fn stream_completion( + api_key: String, + executor: Arc, + mut request: OpenAIRequest, +) -> Result>> { + request.stream = true; + + let (tx, rx) = futures::channel::mpsc::unbounded::>(); + + let json_data = serde_json::to_string(&request)?; + let mut response = Request::post("https://api.openai.com/v1/chat/completions") + .header("Content-Type", "application/json") + .header("Authorization", format!("Bearer {}", api_key)) + .body(json_data)? + .send_async() + .await?; + + let status = response.status(); + if status == StatusCode::OK { + executor + .spawn(async move { + let mut lines = BufReader::new(response.body_mut()).lines(); + + fn parse_line( + line: Result, + ) -> Result> { + if let Some(data) = line?.strip_prefix("data: ") { + let event = serde_json::from_str(&data)?; + Ok(Some(event)) + } else { + Ok(None) + } + } + + while let Some(line) = lines.next().await { + if let Some(event) = parse_line(line).transpose() { + tx.unbounded_send(event).log_err(); + } + } + + anyhow::Ok(()) + }) + .detach(); + + Ok(rx) + } else { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + + Err(anyhow!( + "Failed to connect to OpenAI API: {} {}", + response.status(), + body, + )) + } +} diff --git a/crates/zed/src/languages/markdown/config.toml b/crates/zed/src/languages/markdown/config.toml index 55204cc7a57ad051004a4fc0d76746057908aa20..2fa3ff3cf2aba297517494cbd1f2e0608daaa402 100644 --- a/crates/zed/src/languages/markdown/config.toml +++ b/crates/zed/src/languages/markdown/config.toml @@ -1,5 +1,5 @@ name = "Markdown" -path_suffixes = ["md", "mdx", "zmd"] +path_suffixes = ["md", "mdx"] brackets = [ { start = "{", end = "}", close = true, newline = true }, { start = "[", end = "]", close = true, newline = true },