diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs index b69c12a2a328ed8643315f091be11d764dcdc00d..b1c6038602b77465cef3f994b02cdc0635ed6776 100644 --- a/crates/assistant/src/assistant_panel.rs +++ b/crates/assistant/src/assistant_panel.rs @@ -17,7 +17,7 @@ use editor::{ BlockContext, BlockDisposition, BlockId, BlockProperties, BlockStyle, ToDisplayPoint, }, scroll::autoscroll::{Autoscroll, AutoscrollStrategy}, - Anchor, Editor, MoveDown, MoveUp, MultiBufferSnapshot, ToOffset, + Anchor, Editor, MoveDown, MoveUp, MultiBufferSnapshot, ToOffset, ToPoint, }; use fs::Fs; use futures::StreamExt; @@ -278,22 +278,36 @@ impl AssistantPanel { if selection.start.excerpt_id() != selection.end.excerpt_id() { return; } - - let inline_assist_id = post_inc(&mut self.next_inline_assist_id); let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx); - let provider = Arc::new(OpenAICompletionProvider::new( - api_key, - cx.background().clone(), - )); - let codegen_kind = if editor.read(cx).selections.newest::(cx).is_empty() { + + // Extend the selection to the start and the end of the line. + let mut point_selection = selection.map(|selection| selection.to_point(&snapshot)); + if point_selection.end > point_selection.start { + point_selection.start.column = 0; + // If the selection ends at the start of the line, we don't want to include it. + if point_selection.end.column == 0 { + point_selection.end.row -= 1; + } + point_selection.end.column = snapshot.line_len(point_selection.end.row); + } + + let codegen_kind = if point_selection.start == point_selection.end { CodegenKind::Generate { - position: selection.start, + position: snapshot.anchor_after(point_selection.start), } } else { CodegenKind::Transform { - range: selection.start..selection.end, + range: snapshot.anchor_before(point_selection.start) + ..snapshot.anchor_after(point_selection.end), } }; + + let inline_assist_id = post_inc(&mut self.next_inline_assist_id); + let provider = Arc::new(OpenAICompletionProvider::new( + api_key, + cx.background().clone(), + )); + let codegen = cx.add_model(|cx| { Codegen::new(editor.read(cx).buffer().clone(), codegen_kind, provider, cx) }); @@ -319,7 +333,7 @@ impl AssistantPanel { editor.insert_blocks( [BlockProperties { style: BlockStyle::Flex, - position: selection.head().bias_left(&snapshot), + position: snapshot.anchor_before(point_selection.head()), height: 2, render: Arc::new({ let inline_assistant = inline_assistant.clone(); @@ -578,10 +592,7 @@ impl AssistantPanel { let codegen_kind = codegen.read(cx).kind().clone(); let user_prompt = user_prompt.to_string(); - let prompt = cx.background().spawn(async move { - let language_name = language_name.as_deref(); - generate_content_prompt(user_prompt, language_name, &buffer, range, codegen_kind) - }); + let mut messages = Vec::new(); let mut model = settings::get::(cx) .default_open_ai_model @@ -597,6 +608,11 @@ impl AssistantPanel { model = conversation.model.clone(); } + let prompt = cx.background().spawn(async move { + let language_name = language_name.as_deref(); + generate_content_prompt(user_prompt, language_name, &buffer, range, codegen_kind) + }); + cx.spawn(|_, mut cx| async move { let prompt = prompt.await; diff --git a/crates/assistant/src/codegen.rs b/crates/assistant/src/codegen.rs index e956d722606f6db27c73385d7cf54d58bc82958b..b6ef6b5cfa7fef58936828e0f121946290bc8b48 100644 --- a/crates/assistant/src/codegen.rs +++ b/crates/assistant/src/codegen.rs @@ -1,9 +1,7 @@ use crate::streaming_diff::{Hunk, StreamingDiff}; use ai::completion::{CompletionProvider, OpenAIRequest}; use anyhow::Result; -use editor::{ - multi_buffer, Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint, -}; +use editor::{multi_buffer, Anchor, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint}; use futures::{channel::mpsc, SinkExt, Stream, StreamExt}; use gpui::{Entity, ModelContext, ModelHandle, Task}; use language::{Rope, TransactionId}; @@ -40,26 +38,11 @@ impl Entity for Codegen { impl Codegen { pub fn new( buffer: ModelHandle, - mut kind: CodegenKind, + kind: CodegenKind, provider: Arc, cx: &mut ModelContext, ) -> Self { let snapshot = buffer.read(cx).snapshot(cx); - match &mut kind { - CodegenKind::Transform { range } => { - let mut point_range = range.to_point(&snapshot); - point_range.start.column = 0; - if point_range.end.column > 0 || point_range.start.row == point_range.end.row { - point_range.end.column = snapshot.line_len(point_range.end.row); - } - range.start = snapshot.anchor_before(point_range.start); - range.end = snapshot.anchor_after(point_range.end); - } - CodegenKind::Generate { position } => { - *position = position.bias_right(&snapshot); - } - } - Self { provider, buffer: buffer.clone(), @@ -386,7 +369,7 @@ mod tests { let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx)); let range = buffer.read_with(cx, |buffer, cx| { let snapshot = buffer.snapshot(cx); - snapshot.anchor_before(Point::new(1, 4))..snapshot.anchor_after(Point::new(4, 4)) + snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5)) }); let provider = Arc::new(TestCompletionProvider::new()); let codegen = cx.add_model(|cx| { diff --git a/crates/assistant/src/prompts.rs b/crates/assistant/src/prompts.rs index bf041dff523d57d62cfbc3f312a350ad4766d160..d326a7f44547ee977095484006a2867a2546d525 100644 --- a/crates/assistant/src/prompts.rs +++ b/crates/assistant/src/prompts.rs @@ -4,6 +4,7 @@ use std::cmp::{self, Reverse}; use std::fmt::Write; use std::ops::Range; +#[allow(dead_code)] fn summarize(buffer: &BufferSnapshot, selected_range: Range) -> String { #[derive(Debug)] struct Match { @@ -121,6 +122,7 @@ pub fn generate_content_prompt( range: Range, kind: CodegenKind, ) -> String { + let range = range.to_offset(buffer); let mut prompt = String::new(); // General Preamble @@ -130,17 +132,29 @@ pub fn generate_content_prompt( writeln!(prompt, "You're an expert engineer.\n").unwrap(); } - let outline = summarize(buffer, range); + let mut content = String::new(); + content.extend(buffer.text_for_range(0..range.start)); + if range.start == range.end { + content.push_str("<|START|>"); + } else { + content.push_str("<|START|"); + } + content.extend(buffer.text_for_range(range.clone())); + if range.start != range.end { + content.push_str("|END|>"); + } + content.extend(buffer.text_for_range(range.end..buffer.len())); + writeln!( prompt, - "The file you are currently working on has the following outline:" + "The file you are currently working on has the following content:" ) .unwrap(); if let Some(language_name) = language_name { let language_name = language_name.to_lowercase(); - writeln!(prompt, "```{language_name}\n{outline}\n```").unwrap(); + writeln!(prompt, "```{language_name}\n{content}\n```").unwrap(); } else { - writeln!(prompt, "```\n{outline}\n```").unwrap(); + writeln!(prompt, "```\n{content}\n```").unwrap(); } match kind {