@@ -1,7 +1,7 @@
use crate::{
assistant_settings::AssistantSettings, humanize_token_count, prompts::PromptBuilder,
- AssistantPanel, AssistantPanelEvent, CharOperation, LineDiff, LineOperation, ModelSelector,
- StreamingDiff,
+ AssistantPanel, AssistantPanelEvent, CharOperation, CycleNextInlineAssist,
+ CyclePreviousInlineAssist, LineDiff, LineOperation, ModelSelector, StreamingDiff,
};
use anyhow::{anyhow, Context as _, Result};
use client::{telemetry::Telemetry, ErrorExt};
@@ -25,13 +25,13 @@ use futures::{
SinkExt, Stream, StreamExt,
};
use gpui::{
- anchored, deferred, point, AppContext, ClickEvent, EventEmitter, FocusHandle, FocusableView,
- FontWeight, Global, HighlightStyle, Model, ModelContext, Subscription, Task, TextStyle,
- UpdateGlobal, View, ViewContext, WeakView, WindowContext,
+ anchored, deferred, point, AnyElement, AppContext, ClickEvent, EventEmitter, FocusHandle,
+ FocusableView, FontWeight, Global, HighlightStyle, Model, ModelContext, Subscription, Task,
+ TextStyle, UpdateGlobal, View, ViewContext, WeakView, WindowContext,
};
use language::{Buffer, IndentKind, Point, Selection, TransactionId};
use language_model::{
- LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
+ LanguageModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
};
use multi_buffer::MultiBufferRow;
use parking_lot::Mutex;
@@ -41,7 +41,7 @@ use smol::future::FutureExt;
use std::{
cmp,
future::{self, Future},
- mem,
+ iter, mem,
ops::{Range, RangeInclusive},
pin::Pin,
sync::Arc,
@@ -85,7 +85,7 @@ pub struct InlineAssistant {
async_watch::Receiver<AssistStatus>,
),
>,
- confirmed_assists: HashMap<InlineAssistId, Model<Codegen>>,
+ confirmed_assists: HashMap<InlineAssistId, Model<CodegenAlternative>>,
prompt_history: VecDeque<String>,
prompt_builder: Arc<PromptBuilder>,
telemetry: Option<Arc<Telemetry>>,
@@ -157,7 +157,7 @@ impl InlineAssistant {
if let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) {
for assist_id in editor_assists.assist_ids.clone() {
let assist = &self.assists[&assist_id];
- if let CodegenStatus::Done = &assist.codegen.read(cx).status {
+ if let CodegenStatus::Done = assist.codegen.read(cx).status(cx) {
self.finish_assist(assist_id, false, cx)
}
}
@@ -553,7 +553,7 @@ impl InlineAssistant {
let assist_range = assist.range.to_offset(&buffer);
if assist_range.contains(&selection.start) && assist_range.contains(&selection.end)
{
- if matches!(assist.codegen.read(cx).status, CodegenStatus::Pending) {
+ if matches!(assist.codegen.read(cx).status(cx), CodegenStatus::Pending) {
self.dismiss_assist(*assist_id, cx);
} else {
self.finish_assist(*assist_id, false, cx);
@@ -671,7 +671,7 @@ impl InlineAssistant {
for assist_id in editor_assists.assist_ids.clone() {
let assist = &self.assists[&assist_id];
if matches!(
- assist.codegen.read(cx).status,
+ assist.codegen.read(cx).status(cx),
CodegenStatus::Error(_) | CodegenStatus::Done
) {
let assist_range = assist.range.to_offset(&snapshot);
@@ -774,7 +774,9 @@ impl InlineAssistant {
if undo {
assist.codegen.update(cx, |codegen, cx| codegen.undo(cx));
} else {
- self.confirmed_assists.insert(assist_id, assist.codegen);
+ let confirmed_alternative = assist.codegen.read(cx).active_alternative().clone();
+ self.confirmed_assists
+ .insert(assist_id, confirmed_alternative);
}
}
@@ -978,12 +980,7 @@ impl InlineAssistant {
assist
.codegen
.update(cx, |codegen, cx| {
- codegen.start(
- assist.range.clone(),
- user_prompt,
- assistant_panel_context,
- cx,
- )
+ codegen.start(user_prompt, assistant_panel_context, cx)
})
.log_err();
@@ -1008,7 +1005,7 @@ impl InlineAssistant {
pub fn assist_status(&self, assist_id: InlineAssistId, cx: &AppContext) -> InlineAssistStatus {
if let Some(assist) = self.assists.get(&assist_id) {
- match &assist.codegen.read(cx).status {
+ match assist.codegen.read(cx).status(cx) {
CodegenStatus::Idle => InlineAssistStatus::Idle,
CodegenStatus::Pending => InlineAssistStatus::Pending,
CodegenStatus::Done => InlineAssistStatus::Done,
@@ -1037,16 +1034,16 @@ impl InlineAssistant {
for assist_id in assist_ids {
if let Some(assist) = self.assists.get(assist_id) {
let codegen = assist.codegen.read(cx);
- let buffer = codegen.buffer.read(cx).read(cx);
- foreground_ranges.extend(codegen.last_equal_ranges().iter().cloned());
+ let buffer = codegen.buffer(cx).read(cx).read(cx);
+ foreground_ranges.extend(codegen.last_equal_ranges(cx).iter().cloned());
let pending_range =
- codegen.edit_position.unwrap_or(assist.range.start)..assist.range.end;
+ codegen.edit_position(cx).unwrap_or(assist.range.start)..assist.range.end;
if pending_range.end.to_offset(&buffer) > pending_range.start.to_offset(&buffer) {
gutter_pending_ranges.push(pending_range);
}
- if let Some(edit_position) = codegen.edit_position {
+ if let Some(edit_position) = codegen.edit_position(cx) {
let edited_range = assist.range.start..edit_position;
if edited_range.end.to_offset(&buffer) > edited_range.start.to_offset(&buffer) {
gutter_transformed_ranges.push(edited_range);
@@ -1054,7 +1051,8 @@ impl InlineAssistant {
}
if assist.decorations.is_some() {
- inserted_row_ranges.extend(codegen.diff.inserted_row_ranges.iter().cloned());
+ inserted_row_ranges
+ .extend(codegen.diff(cx).inserted_row_ranges.iter().cloned());
}
}
}
@@ -1125,9 +1123,9 @@ impl InlineAssistant {
};
let codegen = assist.codegen.read(cx);
- let old_snapshot = codegen.snapshot.clone();
- let old_buffer = codegen.old_buffer.clone();
- let deleted_row_ranges = codegen.diff.deleted_row_ranges.clone();
+ let old_snapshot = codegen.snapshot(cx);
+ let old_buffer = codegen.old_buffer(cx);
+ let deleted_row_ranges = codegen.diff(cx).deleted_row_ranges.clone();
editor.update(cx, |editor, cx| {
let old_blocks = mem::take(&mut decorations.removed_line_block_ids);
@@ -1406,8 +1404,15 @@ impl EventEmitter<PromptEditorEvent> for PromptEditor {}
impl Render for PromptEditor {
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
let gutter_dimensions = *self.gutter_dimensions.lock();
- let status = &self.codegen.read(cx).status;
- let buttons = match status {
+ let codegen = self.codegen.read(cx);
+
+ let mut buttons = Vec::new();
+ if codegen.alternative_count(cx) > 1 {
+ buttons.push(self.render_cycle_controls(cx));
+ }
+
+ let status = codegen.status(cx);
+ buttons.extend(match status {
CodegenStatus::Idle => {
vec![
IconButton::new("cancel", IconName::Close)
@@ -1416,14 +1421,16 @@ impl Render for PromptEditor {
.tooltip(|cx| Tooltip::for_action("Cancel Assist", &menu::Cancel, cx))
.on_click(
cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
- ),
+ )
+ .into_any_element(),
IconButton::new("start", IconName::SparkleAlt)
.icon_color(Color::Muted)
.shape(IconButtonShape::Square)
.tooltip(|cx| Tooltip::for_action("Transform", &menu::Confirm, cx))
.on_click(
cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::StartRequested)),
- ),
+ )
+ .into_any_element(),
]
}
CodegenStatus::Pending => {
@@ -1434,7 +1441,8 @@ impl Render for PromptEditor {
.tooltip(|cx| Tooltip::text("Cancel Assist", cx))
.on_click(
cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
- ),
+ )
+ .into_any_element(),
IconButton::new("stop", IconName::Stop)
.icon_color(Color::Error)
.shape(IconButtonShape::Square)
@@ -1446,9 +1454,8 @@ impl Render for PromptEditor {
cx,
)
})
- .on_click(
- cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::StopRequested)),
- ),
+ .on_click(cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::StopRequested)))
+ .into_any_element(),
]
}
CodegenStatus::Error(_) | CodegenStatus::Done => {
@@ -1459,7 +1466,8 @@ impl Render for PromptEditor {
.tooltip(|cx| Tooltip::for_action("Cancel Assist", &menu::Cancel, cx))
.on_click(
cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
- ),
+ )
+ .into_any_element(),
if self.edited_since_done || matches!(status, CodegenStatus::Error(_)) {
IconButton::new("restart", IconName::RotateCw)
.icon_color(Color::Info)
@@ -1475,6 +1483,7 @@ impl Render for PromptEditor {
.on_click(cx.listener(|_, _, cx| {
cx.emit(PromptEditorEvent::StartRequested);
}))
+ .into_any_element()
} else {
IconButton::new("confirm", IconName::Check)
.icon_color(Color::Info)
@@ -1483,12 +1492,14 @@ impl Render for PromptEditor {
.on_click(cx.listener(|_, _, cx| {
cx.emit(PromptEditorEvent::ConfirmRequested);
}))
+ .into_any_element()
},
]
}
- };
+ });
h_flex()
+ .key_context("PromptEditor")
.bg(cx.theme().colors().editor_background)
.border_y_1()
.border_color(cx.theme().status().info_border)
@@ -1498,6 +1509,8 @@ impl Render for PromptEditor {
.on_action(cx.listener(Self::cancel))
.on_action(cx.listener(Self::move_up))
.on_action(cx.listener(Self::move_down))
+ .capture_action(cx.listener(Self::cycle_prev))
+ .capture_action(cx.listener(Self::cycle_next))
.child(
h_flex()
.w(gutter_dimensions.full_width() + (gutter_dimensions.margin / 2.0))
@@ -1532,7 +1545,7 @@ impl Render for PromptEditor {
),
)
.map(|el| {
- let CodegenStatus::Error(error) = &self.codegen.read(cx).status else {
+ let CodegenStatus::Error(error) = self.codegen.read(cx).status(cx) else {
return el;
};
@@ -1776,7 +1789,7 @@ impl PromptEditor {
}
fn handle_codegen_changed(&mut self, _: Model<Codegen>, cx: &mut ViewContext<Self>) {
- match &self.codegen.read(cx).status {
+ match self.codegen.read(cx).status(cx) {
CodegenStatus::Idle => {
self.editor
.update(cx, |editor, _| editor.set_read_only(false));
@@ -1807,7 +1820,7 @@ impl PromptEditor {
}
fn cancel(&mut self, _: &editor::actions::Cancel, cx: &mut ViewContext<Self>) {
- match &self.codegen.read(cx).status {
+ match self.codegen.read(cx).status(cx) {
CodegenStatus::Idle | CodegenStatus::Done | CodegenStatus::Error(_) => {
cx.emit(PromptEditorEvent::CancelRequested);
}
@@ -1818,7 +1831,7 @@ impl PromptEditor {
}
fn confirm(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
- match &self.codegen.read(cx).status {
+ match self.codegen.read(cx).status(cx) {
CodegenStatus::Idle => {
cx.emit(PromptEditorEvent::StartRequested);
}
@@ -1878,6 +1891,79 @@ impl PromptEditor {
}
}
+ fn cycle_prev(&mut self, _: &CyclePreviousInlineAssist, cx: &mut ViewContext<Self>) {
+ self.codegen
+ .update(cx, |codegen, cx| codegen.cycle_prev(cx));
+ }
+
+ fn cycle_next(&mut self, _: &CycleNextInlineAssist, cx: &mut ViewContext<Self>) {
+ self.codegen
+ .update(cx, |codegen, cx| codegen.cycle_next(cx));
+ }
+
+ fn render_cycle_controls(&self, cx: &ViewContext<Self>) -> AnyElement {
+ let codegen = self.codegen.read(cx);
+ let disabled = matches!(codegen.status(cx), CodegenStatus::Idle);
+
+ h_flex()
+ .child(
+ IconButton::new("previous", IconName::ChevronLeft)
+ .icon_color(Color::Muted)
+ .disabled(disabled)
+ .shape(IconButtonShape::Square)
+ .tooltip({
+ let focus_handle = self.editor.focus_handle(cx);
+ move |cx| {
+ Tooltip::for_action_in(
+ "Previous Alternative",
+ &CyclePreviousInlineAssist,
+ &focus_handle,
+ cx,
+ )
+ }
+ })
+ .on_click(cx.listener(|this, _, cx| {
+ this.codegen
+ .update(cx, |codegen, cx| codegen.cycle_prev(cx))
+ })),
+ )
+ .child(
+ Label::new(format!(
+ "{}/{}",
+ codegen.active_alternative + 1,
+ codegen.alternative_count(cx)
+ ))
+ .size(LabelSize::Small)
+ .color(if disabled {
+ Color::Disabled
+ } else {
+ Color::Muted
+ }),
+ )
+ .child(
+ IconButton::new("next", IconName::ChevronRight)
+ .icon_color(Color::Muted)
+ .disabled(disabled)
+ .shape(IconButtonShape::Square)
+ .tooltip({
+ let focus_handle = self.editor.focus_handle(cx);
+ move |cx| {
+ Tooltip::for_action_in(
+ "Next Alternative",
+ &CycleNextInlineAssist,
+ &focus_handle,
+ cx,
+ )
+ }
+ })
+ .on_click(cx.listener(|this, _, cx| {
+ this.codegen
+ .update(cx, |codegen, cx| codegen.cycle_next(cx))
+ })),
+ )
+ .into_any_element()
+ }
+
fn render_token_count(&self, cx: &mut ViewContext<Self>) -> Option<impl IntoElement> {
let model = LanguageModelRegistry::read_global(cx).active_model()?;
let token_counts = self.token_counts?;
@@ -2124,7 +2210,7 @@ impl InlineAssist {
return;
};
- if let CodegenStatus::Error(error) = &codegen.read(cx).status {
+ if let CodegenStatus::Error(error) = codegen.read(cx).status(cx) {
if assist.decorations.is_none() {
if let Some(workspace) = assist
.workspace
@@ -2185,12 +2271,9 @@ impl InlineAssist {
return future::ready(Err(anyhow!("no user prompt"))).boxed();
};
let assistant_panel_context = self.assistant_panel_context(cx);
- self.codegen.read(cx).count_tokens(
- self.range.clone(),
- user_prompt,
- assistant_panel_context,
- cx,
- )
+ self.codegen
+ .read(cx)
+ .count_tokens(user_prompt, assistant_panel_context, cx)
}
}
@@ -2201,19 +2284,216 @@ struct InlineAssistDecorations {
end_block_id: CustomBlockId,
}
-#[derive(Debug)]
+#[derive(Copy, Clone, Debug)]
pub enum CodegenEvent {
Finished,
Undone,
}
pub struct Codegen {
+ alternatives: Vec<Model<CodegenAlternative>>,
+ active_alternative: usize,
+ subscriptions: Vec<Subscription>,
+ buffer: Model<MultiBuffer>,
+ range: Range<Anchor>,
+ initial_transaction_id: Option<TransactionId>,
+ telemetry: Option<Arc<Telemetry>>,
+ builder: Arc<PromptBuilder>,
+}
+
+impl Codegen {
+ pub fn new(
+ buffer: Model<MultiBuffer>,
+ range: Range<Anchor>,
+ initial_transaction_id: Option<TransactionId>,
+ telemetry: Option<Arc<Telemetry>>,
+ builder: Arc<PromptBuilder>,
+ cx: &mut ModelContext<Self>,
+ ) -> Self {
+ let codegen = cx.new_model(|cx| {
+ CodegenAlternative::new(
+ buffer.clone(),
+ range.clone(),
+ false,
+ telemetry.clone(),
+ builder.clone(),
+ cx,
+ )
+ });
+ let mut this = Self {
+ alternatives: vec![codegen],
+ active_alternative: 0,
+ subscriptions: Vec::new(),
+ buffer,
+ range,
+ initial_transaction_id,
+ telemetry,
+ builder,
+ };
+ this.activate(0, cx);
+ this
+ }
+
+ fn subscribe_to_alternative(&mut self, cx: &mut ModelContext<Self>) {
+ let codegen = self.active_alternative().clone();
+ self.subscriptions.clear();
+ self.subscriptions
+ .push(cx.observe(&codegen, |_, _, cx| cx.notify()));
+ self.subscriptions
+ .push(cx.subscribe(&codegen, |_, _, event, cx| cx.emit(*event)));
+ }
+
+ fn active_alternative(&self) -> &Model<CodegenAlternative> {
+ &self.alternatives[self.active_alternative]
+ }
+
+ fn status<'a>(&self, cx: &'a AppContext) -> &'a CodegenStatus {
+ &self.active_alternative().read(cx).status
+ }
+
+ fn alternative_count(&self, cx: &AppContext) -> usize {
+ LanguageModelRegistry::read_global(cx)
+ .inline_alternative_models()
+ .len()
+ + 1
+ }
+
+ pub fn cycle_prev(&mut self, cx: &mut ModelContext<Self>) {
+ let next_active_ix = if self.active_alternative == 0 {
+ self.alternatives.len() - 1
+ } else {
+ self.active_alternative - 1
+ };
+ self.activate(next_active_ix, cx);
+ }
+
+ pub fn cycle_next(&mut self, cx: &mut ModelContext<Self>) {
+ let next_active_ix = (self.active_alternative + 1) % self.alternatives.len();
+ self.activate(next_active_ix, cx);
+ }
+
+ fn activate(&mut self, index: usize, cx: &mut ModelContext<Self>) {
+ self.active_alternative()
+ .update(cx, |codegen, cx| codegen.set_active(false, cx));
+ self.active_alternative = index;
+ self.active_alternative()
+ .update(cx, |codegen, cx| codegen.set_active(true, cx));
+ self.subscribe_to_alternative(cx);
+ cx.notify();
+ }
+
+ pub fn start(
+ &mut self,
+ user_prompt: String,
+ assistant_panel_context: Option<LanguageModelRequest>,
+ cx: &mut ModelContext<Self>,
+ ) -> Result<()> {
+ let alternative_models = LanguageModelRegistry::read_global(cx)
+ .inline_alternative_models()
+ .to_vec();
+
+ self.active_alternative()
+ .update(cx, |alternative, cx| alternative.undo(cx));
+ self.activate(0, cx);
+ self.alternatives.truncate(1);
+
+ for _ in 0..alternative_models.len() {
+ self.alternatives.push(cx.new_model(|cx| {
+ CodegenAlternative::new(
+ self.buffer.clone(),
+ self.range.clone(),
+ false,
+ self.telemetry.clone(),
+ self.builder.clone(),
+ cx,
+ )
+ }));
+ }
+
+ let primary_model = LanguageModelRegistry::read_global(cx)
+ .active_model()
+ .context("no active model")?;
+
+ for (model, alternative) in iter::once(primary_model)
+ .chain(alternative_models)
+ .zip(&self.alternatives)
+ {
+ alternative.update(cx, |alternative, cx| {
+ alternative.start(
+ user_prompt.clone(),
+ assistant_panel_context.clone(),
+ model.clone(),
+ cx,
+ )
+ })?;
+ }
+
+ Ok(())
+ }
+
+ pub fn stop(&mut self, cx: &mut ModelContext<Self>) {
+ for codegen in &self.alternatives {
+ codegen.update(cx, |codegen, cx| codegen.stop(cx));
+ }
+ }
+
+ pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
+ self.active_alternative()
+ .update(cx, |codegen, cx| codegen.undo(cx));
+
+ self.buffer.update(cx, |buffer, cx| {
+ if let Some(transaction_id) = self.initial_transaction_id.take() {
+ buffer.undo_transaction(transaction_id, cx);
+ buffer.refresh_preview(cx);
+ }
+ });
+ }
+
+ pub fn count_tokens(
+ &self,
+ user_prompt: String,
+ assistant_panel_context: Option<LanguageModelRequest>,
+ cx: &AppContext,
+ ) -> BoxFuture<'static, Result<TokenCounts>> {
+ self.active_alternative()
+ .read(cx)
+ .count_tokens(user_prompt, assistant_panel_context, cx)
+ }
+
+ pub fn buffer(&self, cx: &AppContext) -> Model<MultiBuffer> {
+ self.active_alternative().read(cx).buffer.clone()
+ }
+
+ pub fn old_buffer(&self, cx: &AppContext) -> Model<Buffer> {
+ self.active_alternative().read(cx).old_buffer.clone()
+ }
+
+ pub fn snapshot(&self, cx: &AppContext) -> MultiBufferSnapshot {
+ self.active_alternative().read(cx).snapshot.clone()
+ }
+
+ pub fn edit_position(&self, cx: &AppContext) -> Option<Anchor> {
+ self.active_alternative().read(cx).edit_position
+ }
+
+ fn diff<'a>(&self, cx: &'a AppContext) -> &'a Diff {
+ &self.active_alternative().read(cx).diff
+ }
+
+ pub fn last_equal_ranges<'a>(&self, cx: &'a AppContext) -> &'a [Range<Anchor>] {
+ self.active_alternative().read(cx).last_equal_ranges()
+ }
+}
+
+impl EventEmitter<CodegenEvent> for Codegen {}
+
+pub struct CodegenAlternative {
buffer: Model<MultiBuffer>,
old_buffer: Model<Buffer>,
snapshot: MultiBufferSnapshot,
edit_position: Option<Anchor>,
+ range: Range<Anchor>,
last_equal_ranges: Vec<Range<Anchor>>,
- initial_transaction_id: Option<TransactionId>,
transformation_transaction_id: Option<TransactionId>,
status: CodegenStatus,
generation: Task<()>,
@@ -2221,6 +2501,9 @@ pub struct Codegen {
telemetry: Option<Arc<Telemetry>>,
_subscription: gpui::Subscription,
builder: Arc<PromptBuilder>,
+ active: bool,
+ edits: Vec<(Range<Anchor>, String)>,
+ line_operations: Vec<LineOperation>,
}
enum CodegenStatus {
@@ -2242,13 +2525,13 @@ impl Diff {
}
}
-impl EventEmitter<CodegenEvent> for Codegen {}
+impl EventEmitter<CodegenEvent> for CodegenAlternative {}
-impl Codegen {
+impl CodegenAlternative {
pub fn new(
buffer: Model<MultiBuffer>,
range: Range<Anchor>,
- initial_transaction_id: Option<TransactionId>,
+ active: bool,
telemetry: Option<Arc<Telemetry>>,
builder: Arc<PromptBuilder>,
cx: &mut ModelContext<Self>,
@@ -2287,8 +2570,33 @@ impl Codegen {
diff: Diff::default(),
telemetry,
_subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
- initial_transaction_id,
builder,
+ active,
+ edits: Vec::new(),
+ line_operations: Vec::new(),
+ range,
+ }
+ }
+
+ fn set_active(&mut self, active: bool, cx: &mut ModelContext<Self>) {
+ if active != self.active {
+ self.active = active;
+
+ if self.active {
+ let edits = self.edits.clone();
+ self.apply_edits(edits, cx);
+ if matches!(self.status, CodegenStatus::Pending) {
+ let line_operations = self.line_operations.clone();
+ self.reapply_line_based_diff(line_operations, cx);
+ } else {
+ self.reapply_batch_diff(cx).detach();
+ }
+ } else if let Some(transaction_id) = self.transformation_transaction_id.take() {
+ self.buffer.update(cx, |buffer, cx| {
+ buffer.undo_transaction(transaction_id, cx);
+ buffer.forget_transaction(transaction_id, cx);
+ });
+ }
}
}
@@ -2313,14 +2621,12 @@ impl Codegen {
pub fn count_tokens(
&self,
- edit_range: Range<Anchor>,
user_prompt: String,
assistant_panel_context: Option<LanguageModelRequest>,
cx: &AppContext,
) -> BoxFuture<'static, Result<TokenCounts>> {
if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
- let request =
- self.build_request(user_prompt, assistant_panel_context.clone(), edit_range, cx);
+ let request = self.build_request(user_prompt, assistant_panel_context.clone(), cx);
match request {
Ok(request) => {
let total_count = model.count_tokens(request.clone(), cx);
@@ -2345,39 +2651,31 @@ impl Codegen {
pub fn start(
&mut self,
- edit_range: Range<Anchor>,
user_prompt: String,
assistant_panel_context: Option<LanguageModelRequest>,
+ model: Arc<dyn LanguageModel>,
cx: &mut ModelContext<Self>,
) -> Result<()> {
- let model = LanguageModelRegistry::read_global(cx)
- .active_model()
- .context("no active model")?;
-
if let Some(transformation_transaction_id) = self.transformation_transaction_id.take() {
self.buffer.update(cx, |buffer, cx| {
buffer.undo_transaction(transformation_transaction_id, cx);
});
}
- self.edit_position = Some(edit_range.start.bias_right(&self.snapshot));
+ self.edit_position = Some(self.range.start.bias_right(&self.snapshot));
let telemetry_id = model.telemetry_id();
- let chunks: LocalBoxFuture<Result<BoxStream<Result<String>>>> = if user_prompt
- .trim()
- .to_lowercase()
- == "delete"
- {
- async { Ok(stream::empty().boxed()) }.boxed_local()
- } else {
- let request =
- self.build_request(user_prompt, assistant_panel_context, edit_range.clone(), cx)?;
+ let chunks: LocalBoxFuture<Result<BoxStream<Result<String>>>> =
+ if user_prompt.trim().to_lowercase() == "delete" {
+ async { Ok(stream::empty().boxed()) }.boxed_local()
+ } else {
+ let request = self.build_request(user_prompt, assistant_panel_context, cx)?;
- let chunks =
- cx.spawn(|_, cx| async move { model.stream_completion_text(request, &cx).await });
- async move { Ok(chunks.await?.boxed()) }.boxed_local()
- };
- self.handle_stream(telemetry_id, edit_range, chunks, cx);
+ let chunks = cx
+ .spawn(|_, cx| async move { model.stream_completion_text(request, &cx).await });
+ async move { Ok(chunks.await?.boxed()) }.boxed_local()
+ };
+ self.handle_stream(telemetry_id, chunks, cx);
Ok(())
}
@@ -2385,11 +2683,10 @@ impl Codegen {
&self,
user_prompt: String,
assistant_panel_context: Option<LanguageModelRequest>,
- edit_range: Range<Anchor>,
cx: &AppContext,
) -> Result<LanguageModelRequest> {
let buffer = self.buffer.read(cx).snapshot(cx);
- let language = buffer.language_at(edit_range.start);
+ let language = buffer.language_at(self.range.start);
let language_name = if let Some(language) = language.as_ref() {
if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
None
@@ -2401,8 +2698,8 @@ impl Codegen {
};
let language_name = language_name.as_ref();
- let start = buffer.point_to_buffer_offset(edit_range.start);
- let end = buffer.point_to_buffer_offset(edit_range.end);
+ let start = buffer.point_to_buffer_offset(self.range.start);
+ let end = buffer.point_to_buffer_offset(self.range.end);
let (buffer, range) = if let Some((start, end)) = start.zip(end) {
let (start_buffer, start_buffer_offset) = start;
let (end_buffer, end_buffer_offset) = end;
@@ -2442,16 +2739,15 @@ impl Codegen {
pub fn handle_stream(
&mut self,
model_telemetry_id: String,
- edit_range: Range<Anchor>,
stream: impl 'static + Future<Output = Result<BoxStream<'static, Result<String>>>>,
cx: &mut ModelContext<Self>,
) {
let snapshot = self.snapshot.clone();
let selected_text = snapshot
- .text_for_range(edit_range.start..edit_range.end)
+ .text_for_range(self.range.start..self.range.end)
.collect::<Rope>();
- let selection_start = edit_range.start.to_point(&snapshot);
+ let selection_start = self.range.start.to_point(&snapshot);
// Start with the indentation of the first line in the selection
let mut suggested_line_indent = snapshot
@@ -2462,7 +2758,7 @@ impl Codegen {
// If the first line in the selection does not have indentation, check the following lines
if suggested_line_indent.len == 0 && suggested_line_indent.kind == IndentKind::Space {
- for row in selection_start.row..=edit_range.end.to_point(&snapshot).row {
+ for row in selection_start.row..=self.range.end.to_point(&snapshot).row {
let line_indent = snapshot.indent_size_for_line(MultiBufferRow(row));
// Prefer tabs if a line in the selection uses tabs as indentation
if line_indent.kind == IndentKind::Tab {
@@ -2475,7 +2771,7 @@ impl Codegen {
let telemetry = self.telemetry.clone();
self.diff = Diff::default();
self.status = CodegenStatus::Pending;
- let mut edit_start = edit_range.start.to_offset(&snapshot);
+ let mut edit_start = self.range.start.to_offset(&snapshot);
self.generation = cx.spawn(|codegen, mut cx| {
async move {
let chunks = stream.await;
@@ -2597,68 +2893,42 @@ impl Codegen {
Ok(())
});
- while let Some((char_ops, line_diff)) = diff_rx.next().await {
+ while let Some((char_ops, line_ops)) = diff_rx.next().await {
codegen.update(&mut cx, |codegen, cx| {
codegen.last_equal_ranges.clear();
- let transaction = codegen.buffer.update(cx, |buffer, cx| {
- // Avoid grouping assistant edits with user edits.
- buffer.finalize_last_transaction(cx);
-
- buffer.start_transaction(cx);
- buffer.edit(
- char_ops
- .into_iter()
- .filter_map(|operation| match operation {
- CharOperation::Insert { text } => {
- let edit_start = snapshot.anchor_after(edit_start);
- Some((edit_start..edit_start, text))
- }
- CharOperation::Delete { bytes } => {
- let edit_end = edit_start + bytes;
- let edit_range = snapshot.anchor_after(edit_start)
- ..snapshot.anchor_before(edit_end);
- edit_start = edit_end;
- Some((edit_range, String::new()))
- }
- CharOperation::Keep { bytes } => {
- let edit_end = edit_start + bytes;
- let edit_range = snapshot.anchor_after(edit_start)
- ..snapshot.anchor_before(edit_end);
- edit_start = edit_end;
- codegen.last_equal_ranges.push(edit_range);
- None
- }
- }),
- None,
- cx,
- );
- codegen.edit_position = Some(snapshot.anchor_after(edit_start));
-
- buffer.end_transaction(cx)
- });
+ let edits = char_ops
+ .into_iter()
+ .filter_map(|operation| match operation {
+ CharOperation::Insert { text } => {
+ let edit_start = snapshot.anchor_after(edit_start);
+ Some((edit_start..edit_start, text))
+ }
+ CharOperation::Delete { bytes } => {
+ let edit_end = edit_start + bytes;
+ let edit_range = snapshot.anchor_after(edit_start)
+ ..snapshot.anchor_before(edit_end);
+ edit_start = edit_end;
+ Some((edit_range, String::new()))
+ }
+ CharOperation::Keep { bytes } => {
+ let edit_end = edit_start + bytes;
+ let edit_range = snapshot.anchor_after(edit_start)
+ ..snapshot.anchor_before(edit_end);
+ edit_start = edit_end;
+ codegen.last_equal_ranges.push(edit_range);
+ None
+ }
+ })
+ .collect::<Vec<_>>();
- if let Some(transaction) = transaction {
- if let Some(first_transaction) =
- codegen.transformation_transaction_id
- {
- // Group all assistant edits into the first transaction.
- codegen.buffer.update(cx, |buffer, cx| {
- buffer.merge_transactions(
- transaction,
- first_transaction,
- cx,
- )
- });
- } else {
- codegen.transformation_transaction_id = Some(transaction);
- codegen.buffer.update(cx, |buffer, cx| {
- buffer.finalize_last_transaction(cx)
- });
- }
+ if codegen.active {
+ codegen.apply_edits(edits.iter().cloned(), cx);
+ codegen.reapply_line_based_diff(line_ops.iter().cloned(), cx);
}
-
- codegen.reapply_line_based_diff(edit_range.clone(), line_diff, cx);
+ codegen.edits.extend(edits);
+ codegen.line_operations = line_ops;
+ codegen.edit_position = Some(snapshot.anchor_after(edit_start));
cx.notify();
})?;
@@ -2667,9 +2937,8 @@ impl Codegen {
// Streaming stopped and we have the new text in the buffer, and a line-based diff applied for the whole new buffer.
// That diff is not what a regular diff is and might look unexpected, ergo apply a regular diff.
// It's fine to apply even if the rest of the line diffing fails, as no more hunks are coming through `diff_rx`.
- let batch_diff_task = codegen.update(&mut cx, |codegen, cx| {
- codegen.reapply_batch_diff(edit_range.clone(), cx)
- })?;
+ let batch_diff_task =
+ codegen.update(&mut cx, |codegen, cx| codegen.reapply_batch_diff(cx))?;
let (line_based_stream_diff, ()) =
join!(line_based_stream_diff, batch_diff_task);
line_based_stream_diff?;
@@ -2713,24 +2982,45 @@ impl Codegen {
buffer.undo_transaction(transaction_id, cx);
buffer.refresh_preview(cx);
}
+ });
+ }
- if let Some(transaction_id) = self.initial_transaction_id.take() {
- buffer.undo_transaction(transaction_id, cx);
- buffer.refresh_preview(cx);
- }
+ fn apply_edits(
+ &mut self,
+ edits: impl IntoIterator<Item = (Range<Anchor>, String)>,
+ cx: &mut ModelContext<CodegenAlternative>,
+ ) {
+ let transaction = self.buffer.update(cx, |buffer, cx| {
+ // Avoid grouping assistant edits with user edits.
+ buffer.finalize_last_transaction(cx);
+ buffer.start_transaction(cx);
+ buffer.edit(edits, None, cx);
+ buffer.end_transaction(cx)
});
+
+ if let Some(transaction) = transaction {
+ if let Some(first_transaction) = self.transformation_transaction_id {
+ // Group all assistant edits into the first transaction.
+ self.buffer.update(cx, |buffer, cx| {
+ buffer.merge_transactions(transaction, first_transaction, cx)
+ });
+ } else {
+ self.transformation_transaction_id = Some(transaction);
+ self.buffer
+ .update(cx, |buffer, cx| buffer.finalize_last_transaction(cx));
+ }
+ }
}
fn reapply_line_based_diff(
&mut self,
- edit_range: Range<Anchor>,
- line_operations: Vec<LineOperation>,
+ line_operations: impl IntoIterator<Item = LineOperation>,
cx: &mut ModelContext<Self>,
) {
let old_snapshot = self.snapshot.clone();
- let old_range = edit_range.to_point(&old_snapshot);
+ let old_range = self.range.to_point(&old_snapshot);
let new_snapshot = self.buffer.read(cx).snapshot(cx);
- let new_range = edit_range.to_point(&new_snapshot);
+ let new_range = self.range.to_point(&new_snapshot);
let mut old_row = old_range.start.row;
let mut new_row = new_range.start.row;
@@ -2781,15 +3071,11 @@ impl Codegen {
}
}
- fn reapply_batch_diff(
- &mut self,
- edit_range: Range<Anchor>,
- cx: &mut ModelContext<Self>,
- ) -> Task<()> {
+ fn reapply_batch_diff(&mut self, cx: &mut ModelContext<Self>) -> Task<()> {
let old_snapshot = self.snapshot.clone();
- let old_range = edit_range.to_point(&old_snapshot);
+ let old_range = self.range.to_point(&old_snapshot);
let new_snapshot = self.buffer.read(cx).snapshot(cx);
- let new_range = edit_range.to_point(&new_snapshot);
+ let new_range = self.range.to_point(&new_snapshot);
cx.spawn(|codegen, mut cx| async move {
let (deleted_row_ranges, inserted_row_ranges) = cx
@@ -3073,10 +3359,10 @@ mod tests {
});
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
let codegen = cx.new_model(|cx| {
- Codegen::new(
+ CodegenAlternative::new(
buffer.clone(),
range.clone(),
- None,
+ true,
None,
prompt_builder,
cx,