Introduce the ability to cycle between alternative inline assists (#18098)

Antonio Scandurra, Nathan, Roy, and Adam created

Release Notes:

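- Added a new `assistant.inline_alternatives` setting to configure
additional models that will be used to perform inline assists in
parallel. While the prompt editor is focused, `ctrl-[` and `ctrl-]`
cycle to the previous and next alternative.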

---------

Co-authored-by: Nathan <nathan@zed.dev>
Co-authored-by: Roy <roy@anthropic.com>
Co-authored-by: Adam <wolffiex@anthropic.com>

Change summary

assets/keymaps/default-linux.json          |   7 
assets/keymaps/default-macos.json          |   7 
crates/assistant/src/assistant.rs          |  13 
crates/assistant/src/assistant_settings.rs |  13 
crates/assistant/src/inline_assistant.rs   | 612 +++++++++++++++++------
crates/language_model/src/registry.rs      |  32 +
crates/multi_buffer/src/multi_buffer.rs    |  20 
docs/src/assistant/configuration.md        |  26 +
8 files changed, 563 insertions(+), 167 deletions(-)

Detailed changes

assets/keymaps/default-linux.json 🔗

@@ -520,6 +520,13 @@
       "alt-enter": "editor::Newline"
     }
   },
+  {
+    "context": "PromptEditor",
+    "bindings": {
+      "ctrl-[": "assistant::CyclePreviousInlineAssist",
+      "ctrl-]": "assistant::CycleNextInlineAssist"
+    }
+  },
   {
     "context": "ProjectSearchBar && !in_replace",
     "bindings": {

assets/keymaps/default-macos.json 🔗

@@ -527,6 +527,13 @@
       "ctrl-enter": "assistant::InlineAssist"
     }
   },
+  {
+    "context": "PromptEditor",
+    "bindings": {
+      "ctrl-[": "assistant::CyclePreviousInlineAssist",
+      "ctrl-]": "assistant::CycleNextInlineAssist"
+    }
+  },
   {
     "context": "ProjectSearchBar && !in_replace",
     "bindings": {

crates/assistant/src/assistant.rs 🔗

@@ -69,6 +69,8 @@ actions!(
         ConfirmCommand,
         NewContext,
         ToggleModelSelector,
+        CycleNextInlineAssist,
+        CyclePreviousInlineAssist
     ]
 );
 
@@ -359,8 +361,19 @@ fn update_active_language_model_from_settings(cx: &mut AppContext) {
     let settings = AssistantSettings::get_global(cx);
     let provider_name = LanguageModelProviderId::from(settings.default_model.provider.clone());
     let model_id = LanguageModelId::from(settings.default_model.model.clone());
+    let inline_alternatives = settings
+        .inline_alternatives
+        .iter()
+        .map(|alternative| {
+            (
+                LanguageModelProviderId::from(alternative.provider.clone()),
+                LanguageModelId::from(alternative.model.clone()),
+            )
+        })
+        .collect::<Vec<_>>();
     LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
         registry.select_active_model(&provider_name, &model_id, cx);
+        registry.select_inline_alternative_models(inline_alternatives, cx);
     });
 }
 

crates/assistant/src/assistant_settings.rs 🔗

@@ -59,6 +59,7 @@ pub struct AssistantSettings {
     pub default_width: Pixels,
     pub default_height: Pixels,
     pub default_model: LanguageModelSelection,
+    pub inline_alternatives: Vec<LanguageModelSelection>,
     pub using_outdated_settings_version: bool,
 }
 
@@ -236,6 +237,7 @@ impl AssistantSettingsContent {
                                 })
                             }
                         }),
+                    inline_alternatives: None,
                 },
                 VersionedAssistantSettingsContent::V2(settings) => settings.clone(),
             },
@@ -254,6 +256,7 @@ impl AssistantSettingsContent {
                         .id()
                         .to_string(),
                 }),
+                inline_alternatives: None,
             },
         }
     }
@@ -369,6 +372,7 @@ impl Default for VersionedAssistantSettingsContent {
             default_width: None,
             default_height: None,
             default_model: None,
+            inline_alternatives: None,
         })
     }
 }
@@ -397,6 +401,8 @@ pub struct AssistantSettingsContentV2 {
     default_height: Option<f32>,
     /// The default model to use when creating new contexts.
     default_model: Option<LanguageModelSelection>,
+    /// Additional models with which to generate alternatives when performing inline assists.
+    inline_alternatives: Option<Vec<LanguageModelSelection>>,
 }
 
 #[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
@@ -517,10 +523,8 @@ impl Settings for AssistantSettings {
                 &mut settings.default_height,
                 value.default_height.map(Into::into),
             );
-            merge(
-                &mut settings.default_model,
-                value.default_model.map(Into::into),
-            );
+            merge(&mut settings.default_model, value.default_model);
+            merge(&mut settings.inline_alternatives, value.inline_alternatives);
             // merge(&mut settings.infer_context, value.infer_context); TODO re-enable this once we ship context inference
         }
 
@@ -574,6 +578,7 @@ mod tests {
                                 provider: "test-provider".into(),
                                 model: "gpt-99".into(),
                             }),
+                            inline_alternatives: None,
                             enabled: None,
                             button: None,
                             dock: None,

crates/assistant/src/inline_assistant.rs 🔗

@@ -1,7 +1,7 @@
 use crate::{
     assistant_settings::AssistantSettings, humanize_token_count, prompts::PromptBuilder,
-    AssistantPanel, AssistantPanelEvent, CharOperation, LineDiff, LineOperation, ModelSelector,
-    StreamingDiff,
+    AssistantPanel, AssistantPanelEvent, CharOperation, CycleNextInlineAssist,
+    CyclePreviousInlineAssist, LineDiff, LineOperation, ModelSelector, StreamingDiff,
 };
 use anyhow::{anyhow, Context as _, Result};
 use client::{telemetry::Telemetry, ErrorExt};
@@ -25,13 +25,13 @@ use futures::{
     SinkExt, Stream, StreamExt,
 };
 use gpui::{
-    anchored, deferred, point, AppContext, ClickEvent, EventEmitter, FocusHandle, FocusableView,
-    FontWeight, Global, HighlightStyle, Model, ModelContext, Subscription, Task, TextStyle,
-    UpdateGlobal, View, ViewContext, WeakView, WindowContext,
+    anchored, deferred, point, AnyElement, AppContext, ClickEvent, EventEmitter, FocusHandle,
+    FocusableView, FontWeight, Global, HighlightStyle, Model, ModelContext, Subscription, Task,
+    TextStyle, UpdateGlobal, View, ViewContext, WeakView, WindowContext,
 };
 use language::{Buffer, IndentKind, Point, Selection, TransactionId};
 use language_model::{
-    LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
+    LanguageModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
 };
 use multi_buffer::MultiBufferRow;
 use parking_lot::Mutex;
@@ -41,7 +41,7 @@ use smol::future::FutureExt;
 use std::{
     cmp,
     future::{self, Future},
-    mem,
+    iter, mem,
     ops::{Range, RangeInclusive},
     pin::Pin,
     sync::Arc,
@@ -85,7 +85,7 @@ pub struct InlineAssistant {
             async_watch::Receiver<AssistStatus>,
         ),
     >,
-    confirmed_assists: HashMap<InlineAssistId, Model<Codegen>>,
+    confirmed_assists: HashMap<InlineAssistId, Model<CodegenAlternative>>,
     prompt_history: VecDeque<String>,
     prompt_builder: Arc<PromptBuilder>,
     telemetry: Option<Arc<Telemetry>>,
@@ -157,7 +157,7 @@ impl InlineAssistant {
                 if let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) {
                     for assist_id in editor_assists.assist_ids.clone() {
                         let assist = &self.assists[&assist_id];
-                        if let CodegenStatus::Done = &assist.codegen.read(cx).status {
+                        if let CodegenStatus::Done = assist.codegen.read(cx).status(cx) {
                             self.finish_assist(assist_id, false, cx)
                         }
                     }
@@ -553,7 +553,7 @@ impl InlineAssistant {
                 let assist_range = assist.range.to_offset(&buffer);
                 if assist_range.contains(&selection.start) && assist_range.contains(&selection.end)
                 {
-                    if matches!(assist.codegen.read(cx).status, CodegenStatus::Pending) {
+                    if matches!(assist.codegen.read(cx).status(cx), CodegenStatus::Pending) {
                         self.dismiss_assist(*assist_id, cx);
                     } else {
                         self.finish_assist(*assist_id, false, cx);
@@ -671,7 +671,7 @@ impl InlineAssistant {
                 for assist_id in editor_assists.assist_ids.clone() {
                     let assist = &self.assists[&assist_id];
                     if matches!(
-                        assist.codegen.read(cx).status,
+                        assist.codegen.read(cx).status(cx),
                         CodegenStatus::Error(_) | CodegenStatus::Done
                     ) {
                         let assist_range = assist.range.to_offset(&snapshot);
@@ -774,7 +774,9 @@ impl InlineAssistant {
             if undo {
                 assist.codegen.update(cx, |codegen, cx| codegen.undo(cx));
             } else {
-                self.confirmed_assists.insert(assist_id, assist.codegen);
+                let confirmed_alternative = assist.codegen.read(cx).active_alternative().clone();
+                self.confirmed_assists
+                    .insert(assist_id, confirmed_alternative);
             }
         }
 
@@ -978,12 +980,7 @@ impl InlineAssistant {
         assist
             .codegen
             .update(cx, |codegen, cx| {
-                codegen.start(
-                    assist.range.clone(),
-                    user_prompt,
-                    assistant_panel_context,
-                    cx,
-                )
+                codegen.start(user_prompt, assistant_panel_context, cx)
             })
             .log_err();
 
@@ -1008,7 +1005,7 @@ impl InlineAssistant {
 
     pub fn assist_status(&self, assist_id: InlineAssistId, cx: &AppContext) -> InlineAssistStatus {
         if let Some(assist) = self.assists.get(&assist_id) {
-            match &assist.codegen.read(cx).status {
+            match assist.codegen.read(cx).status(cx) {
                 CodegenStatus::Idle => InlineAssistStatus::Idle,
                 CodegenStatus::Pending => InlineAssistStatus::Pending,
                 CodegenStatus::Done => InlineAssistStatus::Done,
@@ -1037,16 +1034,16 @@ impl InlineAssistant {
         for assist_id in assist_ids {
             if let Some(assist) = self.assists.get(assist_id) {
                 let codegen = assist.codegen.read(cx);
-                let buffer = codegen.buffer.read(cx).read(cx);
-                foreground_ranges.extend(codegen.last_equal_ranges().iter().cloned());
+                let buffer = codegen.buffer(cx).read(cx).read(cx);
+                foreground_ranges.extend(codegen.last_equal_ranges(cx).iter().cloned());
 
                 let pending_range =
-                    codegen.edit_position.unwrap_or(assist.range.start)..assist.range.end;
+                    codegen.edit_position(cx).unwrap_or(assist.range.start)..assist.range.end;
                 if pending_range.end.to_offset(&buffer) > pending_range.start.to_offset(&buffer) {
                     gutter_pending_ranges.push(pending_range);
                 }
 
-                if let Some(edit_position) = codegen.edit_position {
+                if let Some(edit_position) = codegen.edit_position(cx) {
                     let edited_range = assist.range.start..edit_position;
                     if edited_range.end.to_offset(&buffer) > edited_range.start.to_offset(&buffer) {
                         gutter_transformed_ranges.push(edited_range);
@@ -1054,7 +1051,8 @@ impl InlineAssistant {
                 }
 
                 if assist.decorations.is_some() {
-                    inserted_row_ranges.extend(codegen.diff.inserted_row_ranges.iter().cloned());
+                    inserted_row_ranges
+                        .extend(codegen.diff(cx).inserted_row_ranges.iter().cloned());
                 }
             }
         }
@@ -1125,9 +1123,9 @@ impl InlineAssistant {
         };
 
         let codegen = assist.codegen.read(cx);
-        let old_snapshot = codegen.snapshot.clone();
-        let old_buffer = codegen.old_buffer.clone();
-        let deleted_row_ranges = codegen.diff.deleted_row_ranges.clone();
+        let old_snapshot = codegen.snapshot(cx);
+        let old_buffer = codegen.old_buffer(cx);
+        let deleted_row_ranges = codegen.diff(cx).deleted_row_ranges.clone();
 
         editor.update(cx, |editor, cx| {
             let old_blocks = mem::take(&mut decorations.removed_line_block_ids);
@@ -1406,8 +1404,15 @@ impl EventEmitter<PromptEditorEvent> for PromptEditor {}
 impl Render for PromptEditor {
     fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
         let gutter_dimensions = *self.gutter_dimensions.lock();
-        let status = &self.codegen.read(cx).status;
-        let buttons = match status {
+        let codegen = self.codegen.read(cx);
+
+        let mut buttons = Vec::new();
+        if codegen.alternative_count(cx) > 1 {
+            buttons.push(self.render_cycle_controls(cx));
+        }
+
+        let status = codegen.status(cx);
+        buttons.extend(match status {
             CodegenStatus::Idle => {
                 vec![
                     IconButton::new("cancel", IconName::Close)
@@ -1416,14 +1421,16 @@ impl Render for PromptEditor {
                         .tooltip(|cx| Tooltip::for_action("Cancel Assist", &menu::Cancel, cx))
                         .on_click(
                             cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
-                        ),
+                        )
+                        .into_any_element(),
                     IconButton::new("start", IconName::SparkleAlt)
                         .icon_color(Color::Muted)
                         .shape(IconButtonShape::Square)
                         .tooltip(|cx| Tooltip::for_action("Transform", &menu::Confirm, cx))
                         .on_click(
                             cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::StartRequested)),
-                        ),
+                        )
+                        .into_any_element(),
                 ]
             }
             CodegenStatus::Pending => {
@@ -1434,7 +1441,8 @@ impl Render for PromptEditor {
                         .tooltip(|cx| Tooltip::text("Cancel Assist", cx))
                         .on_click(
                             cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
-                        ),
+                        )
+                        .into_any_element(),
                     IconButton::new("stop", IconName::Stop)
                         .icon_color(Color::Error)
                         .shape(IconButtonShape::Square)
@@ -1446,9 +1454,8 @@ impl Render for PromptEditor {
                                 cx,
                             )
                         })
-                        .on_click(
-                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::StopRequested)),
-                        ),
+                        .on_click(cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::StopRequested)))
+                        .into_any_element(),
                 ]
             }
             CodegenStatus::Error(_) | CodegenStatus::Done => {
@@ -1459,7 +1466,8 @@ impl Render for PromptEditor {
                         .tooltip(|cx| Tooltip::for_action("Cancel Assist", &menu::Cancel, cx))
                         .on_click(
                             cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
-                        ),
+                        )
+                        .into_any_element(),
                     if self.edited_since_done || matches!(status, CodegenStatus::Error(_)) {
                         IconButton::new("restart", IconName::RotateCw)
                             .icon_color(Color::Info)
@@ -1475,6 +1483,7 @@ impl Render for PromptEditor {
                             .on_click(cx.listener(|_, _, cx| {
                                 cx.emit(PromptEditorEvent::StartRequested);
                             }))
+                            .into_any_element()
                     } else {
                         IconButton::new("confirm", IconName::Check)
                             .icon_color(Color::Info)
@@ -1483,12 +1492,14 @@ impl Render for PromptEditor {
                             .on_click(cx.listener(|_, _, cx| {
                                 cx.emit(PromptEditorEvent::ConfirmRequested);
                             }))
+                            .into_any_element()
                     },
                 ]
             }
-        };
+        });
 
         h_flex()
+            .key_context("PromptEditor")
             .bg(cx.theme().colors().editor_background)
             .border_y_1()
             .border_color(cx.theme().status().info_border)
@@ -1498,6 +1509,8 @@ impl Render for PromptEditor {
             .on_action(cx.listener(Self::cancel))
             .on_action(cx.listener(Self::move_up))
             .on_action(cx.listener(Self::move_down))
+            .capture_action(cx.listener(Self::cycle_prev))
+            .capture_action(cx.listener(Self::cycle_next))
             .child(
                 h_flex()
                     .w(gutter_dimensions.full_width() + (gutter_dimensions.margin / 2.0))
@@ -1532,7 +1545,7 @@ impl Render for PromptEditor {
                         ),
                     )
                     .map(|el| {
-                        let CodegenStatus::Error(error) = &self.codegen.read(cx).status else {
+                        let CodegenStatus::Error(error) = self.codegen.read(cx).status(cx) else {
                             return el;
                         };
 
@@ -1776,7 +1789,7 @@ impl PromptEditor {
     }
 
     fn handle_codegen_changed(&mut self, _: Model<Codegen>, cx: &mut ViewContext<Self>) {
-        match &self.codegen.read(cx).status {
+        match self.codegen.read(cx).status(cx) {
             CodegenStatus::Idle => {
                 self.editor
                     .update(cx, |editor, _| editor.set_read_only(false));
@@ -1807,7 +1820,7 @@ impl PromptEditor {
     }
 
     fn cancel(&mut self, _: &editor::actions::Cancel, cx: &mut ViewContext<Self>) {
-        match &self.codegen.read(cx).status {
+        match self.codegen.read(cx).status(cx) {
             CodegenStatus::Idle | CodegenStatus::Done | CodegenStatus::Error(_) => {
                 cx.emit(PromptEditorEvent::CancelRequested);
             }
@@ -1818,7 +1831,7 @@ impl PromptEditor {
     }
 
     fn confirm(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
-        match &self.codegen.read(cx).status {
+        match self.codegen.read(cx).status(cx) {
             CodegenStatus::Idle => {
                 cx.emit(PromptEditorEvent::StartRequested);
             }
@@ -1878,6 +1891,79 @@ impl PromptEditor {
         }
     }
 
+    fn cycle_prev(&mut self, _: &CyclePreviousInlineAssist, cx: &mut ViewContext<Self>) {
+        self.codegen
+            .update(cx, |codegen, cx| codegen.cycle_prev(cx));
+    }
+
+    fn cycle_next(&mut self, _: &CycleNextInlineAssist, cx: &mut ViewContext<Self>) {
+        self.codegen
+            .update(cx, |codegen, cx| codegen.cycle_next(cx));
+    }
+
+    fn render_cycle_controls(&self, cx: &ViewContext<Self>) -> AnyElement {
+        let codegen = self.codegen.read(cx);
+        let disabled = matches!(codegen.status(cx), CodegenStatus::Idle);
+
+        h_flex()
+            .child(
+                IconButton::new("previous", IconName::ChevronLeft)
+                    .icon_color(Color::Muted)
+                    .disabled(disabled)
+                    .shape(IconButtonShape::Square)
+                    .tooltip({
+                        let focus_handle = self.editor.focus_handle(cx);
+                        move |cx| {
+                            Tooltip::for_action_in(
+                                "Previous Alternative",
+                                &CyclePreviousInlineAssist,
+                                &focus_handle,
+                                cx,
+                            )
+                        }
+                    })
+                    .on_click(cx.listener(|this, _, cx| {
+                        this.codegen
+                            .update(cx, |codegen, cx| codegen.cycle_prev(cx))
+                    })),
+            )
+            .child(
+                Label::new(format!(
+                    "{}/{}",
+                    codegen.active_alternative + 1,
+                    codegen.alternative_count(cx)
+                ))
+                .size(LabelSize::Small)
+                .color(if disabled {
+                    Color::Disabled
+                } else {
+                    Color::Muted
+                }),
+            )
+            .child(
+                IconButton::new("next", IconName::ChevronRight)
+                    .icon_color(Color::Muted)
+                    .disabled(disabled)
+                    .shape(IconButtonShape::Square)
+                    .tooltip({
+                        let focus_handle = self.editor.focus_handle(cx);
+                        move |cx| {
+                            Tooltip::for_action_in(
+                                "Next Alternative",
+                                &CycleNextInlineAssist,
+                                &focus_handle,
+                                cx,
+                            )
+                        }
+                    })
+                    .on_click(cx.listener(|this, _, cx| {
+                        this.codegen
+                            .update(cx, |codegen, cx| codegen.cycle_next(cx))
+                    })),
+            )
+            .into_any_element()
+    }
+
     fn render_token_count(&self, cx: &mut ViewContext<Self>) -> Option<impl IntoElement> {
         let model = LanguageModelRegistry::read_global(cx).active_model()?;
         let token_counts = self.token_counts?;
@@ -2124,7 +2210,7 @@ impl InlineAssist {
                                 return;
                             };
 
-                            if let CodegenStatus::Error(error) = &codegen.read(cx).status {
+                            if let CodegenStatus::Error(error) = codegen.read(cx).status(cx) {
                                 if assist.decorations.is_none() {
                                     if let Some(workspace) = assist
                                         .workspace
@@ -2185,12 +2271,9 @@ impl InlineAssist {
             return future::ready(Err(anyhow!("no user prompt"))).boxed();
         };
         let assistant_panel_context = self.assistant_panel_context(cx);
-        self.codegen.read(cx).count_tokens(
-            self.range.clone(),
-            user_prompt,
-            assistant_panel_context,
-            cx,
-        )
+        self.codegen
+            .read(cx)
+            .count_tokens(user_prompt, assistant_panel_context, cx)
     }
 }
 
@@ -2201,19 +2284,216 @@ struct InlineAssistDecorations {
     end_block_id: CustomBlockId,
 }
 
-#[derive(Debug)]
+#[derive(Copy, Clone, Debug)]
 pub enum CodegenEvent {
     Finished,
     Undone,
 }
 
 pub struct Codegen {
+    alternatives: Vec<Model<CodegenAlternative>>,
+    active_alternative: usize,
+    subscriptions: Vec<Subscription>,
+    buffer: Model<MultiBuffer>,
+    range: Range<Anchor>,
+    initial_transaction_id: Option<TransactionId>,
+    telemetry: Option<Arc<Telemetry>>,
+    builder: Arc<PromptBuilder>,
+}
+
+impl Codegen {
+    pub fn new(
+        buffer: Model<MultiBuffer>,
+        range: Range<Anchor>,
+        initial_transaction_id: Option<TransactionId>,
+        telemetry: Option<Arc<Telemetry>>,
+        builder: Arc<PromptBuilder>,
+        cx: &mut ModelContext<Self>,
+    ) -> Self {
+        let codegen = cx.new_model(|cx| {
+            CodegenAlternative::new(
+                buffer.clone(),
+                range.clone(),
+                false,
+                telemetry.clone(),
+                builder.clone(),
+                cx,
+            )
+        });
+        let mut this = Self {
+            alternatives: vec![codegen],
+            active_alternative: 0,
+            subscriptions: Vec::new(),
+            buffer,
+            range,
+            initial_transaction_id,
+            telemetry,
+            builder,
+        };
+        this.activate(0, cx);
+        this
+    }
+
+    fn subscribe_to_alternative(&mut self, cx: &mut ModelContext<Self>) {
+        let codegen = self.active_alternative().clone();
+        self.subscriptions.clear();
+        self.subscriptions
+            .push(cx.observe(&codegen, |_, _, cx| cx.notify()));
+        self.subscriptions
+            .push(cx.subscribe(&codegen, |_, _, event, cx| cx.emit(*event)));
+    }
+
+    fn active_alternative(&self) -> &Model<CodegenAlternative> {
+        &self.alternatives[self.active_alternative]
+    }
+
+    fn status<'a>(&self, cx: &'a AppContext) -> &'a CodegenStatus {
+        &self.active_alternative().read(cx).status
+    }
+
+    fn alternative_count(&self, cx: &AppContext) -> usize {
+        LanguageModelRegistry::read_global(cx)
+            .inline_alternative_models()
+            .len()
+            + 1
+    }
+
+    pub fn cycle_prev(&mut self, cx: &mut ModelContext<Self>) {
+        let next_active_ix = if self.active_alternative == 0 {
+            self.alternatives.len() - 1
+        } else {
+            self.active_alternative - 1
+        };
+        self.activate(next_active_ix, cx);
+    }
+
+    pub fn cycle_next(&mut self, cx: &mut ModelContext<Self>) {
+        let next_active_ix = (self.active_alternative + 1) % self.alternatives.len();
+        self.activate(next_active_ix, cx);
+    }
+
+    fn activate(&mut self, index: usize, cx: &mut ModelContext<Self>) {
+        self.active_alternative()
+            .update(cx, |codegen, cx| codegen.set_active(false, cx));
+        self.active_alternative = index;
+        self.active_alternative()
+            .update(cx, |codegen, cx| codegen.set_active(true, cx));
+        self.subscribe_to_alternative(cx);
+        cx.notify();
+    }
+
+    pub fn start(
+        &mut self,
+        user_prompt: String,
+        assistant_panel_context: Option<LanguageModelRequest>,
+        cx: &mut ModelContext<Self>,
+    ) -> Result<()> {
+        let alternative_models = LanguageModelRegistry::read_global(cx)
+            .inline_alternative_models()
+            .to_vec();
+
+        self.active_alternative()
+            .update(cx, |alternative, cx| alternative.undo(cx));
+        self.activate(0, cx);
+        self.alternatives.truncate(1);
+
+        for _ in 0..alternative_models.len() {
+            self.alternatives.push(cx.new_model(|cx| {
+                CodegenAlternative::new(
+                    self.buffer.clone(),
+                    self.range.clone(),
+                    false,
+                    self.telemetry.clone(),
+                    self.builder.clone(),
+                    cx,
+                )
+            }));
+        }
+
+        let primary_model = LanguageModelRegistry::read_global(cx)
+            .active_model()
+            .context("no active model")?;
+
+        for (model, alternative) in iter::once(primary_model)
+            .chain(alternative_models)
+            .zip(&self.alternatives)
+        {
+            alternative.update(cx, |alternative, cx| {
+                alternative.start(
+                    user_prompt.clone(),
+                    assistant_panel_context.clone(),
+                    model.clone(),
+                    cx,
+                )
+            })?;
+        }
+
+        Ok(())
+    }
+
+    pub fn stop(&mut self, cx: &mut ModelContext<Self>) {
+        for codegen in &self.alternatives {
+            codegen.update(cx, |codegen, cx| codegen.stop(cx));
+        }
+    }
+
+    pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
+        self.active_alternative()
+            .update(cx, |codegen, cx| codegen.undo(cx));
+
+        self.buffer.update(cx, |buffer, cx| {
+            if let Some(transaction_id) = self.initial_transaction_id.take() {
+                buffer.undo_transaction(transaction_id, cx);
+                buffer.refresh_preview(cx);
+            }
+        });
+    }
+
+    pub fn count_tokens(
+        &self,
+        user_prompt: String,
+        assistant_panel_context: Option<LanguageModelRequest>,
+        cx: &AppContext,
+    ) -> BoxFuture<'static, Result<TokenCounts>> {
+        self.active_alternative()
+            .read(cx)
+            .count_tokens(user_prompt, assistant_panel_context, cx)
+    }
+
+    pub fn buffer(&self, cx: &AppContext) -> Model<MultiBuffer> {
+        self.active_alternative().read(cx).buffer.clone()
+    }
+
+    pub fn old_buffer(&self, cx: &AppContext) -> Model<Buffer> {
+        self.active_alternative().read(cx).old_buffer.clone()
+    }
+
+    pub fn snapshot(&self, cx: &AppContext) -> MultiBufferSnapshot {
+        self.active_alternative().read(cx).snapshot.clone()
+    }
+
+    pub fn edit_position(&self, cx: &AppContext) -> Option<Anchor> {
+        self.active_alternative().read(cx).edit_position
+    }
+
+    fn diff<'a>(&self, cx: &'a AppContext) -> &'a Diff {
+        &self.active_alternative().read(cx).diff
+    }
+
+    pub fn last_equal_ranges<'a>(&self, cx: &'a AppContext) -> &'a [Range<Anchor>] {
+        self.active_alternative().read(cx).last_equal_ranges()
+    }
+}
+
+impl EventEmitter<CodegenEvent> for Codegen {}
+
+pub struct CodegenAlternative {
     buffer: Model<MultiBuffer>,
     old_buffer: Model<Buffer>,
     snapshot: MultiBufferSnapshot,
     edit_position: Option<Anchor>,
+    range: Range<Anchor>,
     last_equal_ranges: Vec<Range<Anchor>>,
-    initial_transaction_id: Option<TransactionId>,
     transformation_transaction_id: Option<TransactionId>,
     status: CodegenStatus,
     generation: Task<()>,
@@ -2221,6 +2501,9 @@ pub struct Codegen {
     telemetry: Option<Arc<Telemetry>>,
     _subscription: gpui::Subscription,
     builder: Arc<PromptBuilder>,
+    active: bool,
+    edits: Vec<(Range<Anchor>, String)>,
+    line_operations: Vec<LineOperation>,
 }
 
 enum CodegenStatus {
@@ -2242,13 +2525,13 @@ impl Diff {
     }
 }
 
-impl EventEmitter<CodegenEvent> for Codegen {}
+impl EventEmitter<CodegenEvent> for CodegenAlternative {}
 
-impl Codegen {
+impl CodegenAlternative {
     pub fn new(
         buffer: Model<MultiBuffer>,
         range: Range<Anchor>,
-        initial_transaction_id: Option<TransactionId>,
+        active: bool,
         telemetry: Option<Arc<Telemetry>>,
         builder: Arc<PromptBuilder>,
         cx: &mut ModelContext<Self>,
@@ -2287,8 +2570,33 @@ impl Codegen {
             diff: Diff::default(),
             telemetry,
             _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
-            initial_transaction_id,
             builder,
+            active,
+            edits: Vec::new(),
+            line_operations: Vec::new(),
+            range,
+        }
+    }
+
+    fn set_active(&mut self, active: bool, cx: &mut ModelContext<Self>) {
+        if active != self.active {
+            self.active = active;
+
+            if self.active {
+                let edits = self.edits.clone();
+                self.apply_edits(edits, cx);
+                if matches!(self.status, CodegenStatus::Pending) {
+                    let line_operations = self.line_operations.clone();
+                    self.reapply_line_based_diff(line_operations, cx);
+                } else {
+                    self.reapply_batch_diff(cx).detach();
+                }
+            } else if let Some(transaction_id) = self.transformation_transaction_id.take() {
+                self.buffer.update(cx, |buffer, cx| {
+                    buffer.undo_transaction(transaction_id, cx);
+                    buffer.forget_transaction(transaction_id, cx);
+                });
+            }
         }
     }
 
@@ -2313,14 +2621,12 @@ impl Codegen {
 
     pub fn count_tokens(
         &self,
-        edit_range: Range<Anchor>,
         user_prompt: String,
         assistant_panel_context: Option<LanguageModelRequest>,
         cx: &AppContext,
     ) -> BoxFuture<'static, Result<TokenCounts>> {
         if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
-            let request =
-                self.build_request(user_prompt, assistant_panel_context.clone(), edit_range, cx);
+            let request = self.build_request(user_prompt, assistant_panel_context.clone(), cx);
             match request {
                 Ok(request) => {
                     let total_count = model.count_tokens(request.clone(), cx);
@@ -2345,39 +2651,31 @@ impl Codegen {
 
     pub fn start(
         &mut self,
-        edit_range: Range<Anchor>,
         user_prompt: String,
         assistant_panel_context: Option<LanguageModelRequest>,
+        model: Arc<dyn LanguageModel>,
         cx: &mut ModelContext<Self>,
     ) -> Result<()> {
-        let model = LanguageModelRegistry::read_global(cx)
-            .active_model()
-            .context("no active model")?;
-
         if let Some(transformation_transaction_id) = self.transformation_transaction_id.take() {
             self.buffer.update(cx, |buffer, cx| {
                 buffer.undo_transaction(transformation_transaction_id, cx);
             });
         }
 
-        self.edit_position = Some(edit_range.start.bias_right(&self.snapshot));
+        self.edit_position = Some(self.range.start.bias_right(&self.snapshot));
 
         let telemetry_id = model.telemetry_id();
-        let chunks: LocalBoxFuture<Result<BoxStream<Result<String>>>> = if user_prompt
-            .trim()
-            .to_lowercase()
-            == "delete"
-        {
-            async { Ok(stream::empty().boxed()) }.boxed_local()
-        } else {
-            let request =
-                self.build_request(user_prompt, assistant_panel_context, edit_range.clone(), cx)?;
+        let chunks: LocalBoxFuture<Result<BoxStream<Result<String>>>> =
+            if user_prompt.trim().to_lowercase() == "delete" {
+                async { Ok(stream::empty().boxed()) }.boxed_local()
+            } else {
+                let request = self.build_request(user_prompt, assistant_panel_context, cx)?;
 
-            let chunks =
-                cx.spawn(|_, cx| async move { model.stream_completion_text(request, &cx).await });
-            async move { Ok(chunks.await?.boxed()) }.boxed_local()
-        };
-        self.handle_stream(telemetry_id, edit_range, chunks, cx);
+                let chunks = cx
+                    .spawn(|_, cx| async move { model.stream_completion_text(request, &cx).await });
+                async move { Ok(chunks.await?.boxed()) }.boxed_local()
+            };
+        self.handle_stream(telemetry_id, chunks, cx);
         Ok(())
     }
 
@@ -2385,11 +2683,10 @@ impl Codegen {
         &self,
         user_prompt: String,
         assistant_panel_context: Option<LanguageModelRequest>,
-        edit_range: Range<Anchor>,
         cx: &AppContext,
     ) -> Result<LanguageModelRequest> {
         let buffer = self.buffer.read(cx).snapshot(cx);
-        let language = buffer.language_at(edit_range.start);
+        let language = buffer.language_at(self.range.start);
         let language_name = if let Some(language) = language.as_ref() {
             if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
                 None
@@ -2401,8 +2698,8 @@ impl Codegen {
         };
 
         let language_name = language_name.as_ref();
-        let start = buffer.point_to_buffer_offset(edit_range.start);
-        let end = buffer.point_to_buffer_offset(edit_range.end);
+        let start = buffer.point_to_buffer_offset(self.range.start);
+        let end = buffer.point_to_buffer_offset(self.range.end);
         let (buffer, range) = if let Some((start, end)) = start.zip(end) {
             let (start_buffer, start_buffer_offset) = start;
             let (end_buffer, end_buffer_offset) = end;
@@ -2442,16 +2739,15 @@ impl Codegen {
     pub fn handle_stream(
         &mut self,
         model_telemetry_id: String,
-        edit_range: Range<Anchor>,
         stream: impl 'static + Future<Output = Result<BoxStream<'static, Result<String>>>>,
         cx: &mut ModelContext<Self>,
     ) {
         let snapshot = self.snapshot.clone();
         let selected_text = snapshot
-            .text_for_range(edit_range.start..edit_range.end)
+            .text_for_range(self.range.start..self.range.end)
             .collect::<Rope>();
 
-        let selection_start = edit_range.start.to_point(&snapshot);
+        let selection_start = self.range.start.to_point(&snapshot);
 
         // Start with the indentation of the first line in the selection
         let mut suggested_line_indent = snapshot
@@ -2462,7 +2758,7 @@ impl Codegen {
 
         // If the first line in the selection does not have indentation, check the following lines
         if suggested_line_indent.len == 0 && suggested_line_indent.kind == IndentKind::Space {
-            for row in selection_start.row..=edit_range.end.to_point(&snapshot).row {
+            for row in selection_start.row..=self.range.end.to_point(&snapshot).row {
                 let line_indent = snapshot.indent_size_for_line(MultiBufferRow(row));
                 // Prefer tabs if a line in the selection uses tabs as indentation
                 if line_indent.kind == IndentKind::Tab {
@@ -2475,7 +2771,7 @@ impl Codegen {
         let telemetry = self.telemetry.clone();
         self.diff = Diff::default();
         self.status = CodegenStatus::Pending;
-        let mut edit_start = edit_range.start.to_offset(&snapshot);
+        let mut edit_start = self.range.start.to_offset(&snapshot);
         self.generation = cx.spawn(|codegen, mut cx| {
             async move {
                 let chunks = stream.await;
@@ -2597,68 +2893,42 @@ impl Codegen {
                             Ok(())
                         });
 
-                    while let Some((char_ops, line_diff)) = diff_rx.next().await {
+                    while let Some((char_ops, line_ops)) = diff_rx.next().await {
                         codegen.update(&mut cx, |codegen, cx| {
                             codegen.last_equal_ranges.clear();
 
-                            let transaction = codegen.buffer.update(cx, |buffer, cx| {
-                                // Avoid grouping assistant edits with user edits.
-                                buffer.finalize_last_transaction(cx);
-
-                                buffer.start_transaction(cx);
-                                buffer.edit(
-                                    char_ops
-                                        .into_iter()
-                                        .filter_map(|operation| match operation {
-                                            CharOperation::Insert { text } => {
-                                                let edit_start = snapshot.anchor_after(edit_start);
-                                                Some((edit_start..edit_start, text))
-                                            }
-                                            CharOperation::Delete { bytes } => {
-                                                let edit_end = edit_start + bytes;
-                                                let edit_range = snapshot.anchor_after(edit_start)
-                                                    ..snapshot.anchor_before(edit_end);
-                                                edit_start = edit_end;
-                                                Some((edit_range, String::new()))
-                                            }
-                                            CharOperation::Keep { bytes } => {
-                                                let edit_end = edit_start + bytes;
-                                                let edit_range = snapshot.anchor_after(edit_start)
-                                                    ..snapshot.anchor_before(edit_end);
-                                                edit_start = edit_end;
-                                                codegen.last_equal_ranges.push(edit_range);
-                                                None
-                                            }
-                                        }),
-                                    None,
-                                    cx,
-                                );
-                                codegen.edit_position = Some(snapshot.anchor_after(edit_start));
-
-                                buffer.end_transaction(cx)
-                            });
+                            let edits = char_ops
+                                .into_iter()
+                                .filter_map(|operation| match operation {
+                                    CharOperation::Insert { text } => {
+                                        let edit_start = snapshot.anchor_after(edit_start);
+                                        Some((edit_start..edit_start, text))
+                                    }
+                                    CharOperation::Delete { bytes } => {
+                                        let edit_end = edit_start + bytes;
+                                        let edit_range = snapshot.anchor_after(edit_start)
+                                            ..snapshot.anchor_before(edit_end);
+                                        edit_start = edit_end;
+                                        Some((edit_range, String::new()))
+                                    }
+                                    CharOperation::Keep { bytes } => {
+                                        let edit_end = edit_start + bytes;
+                                        let edit_range = snapshot.anchor_after(edit_start)
+                                            ..snapshot.anchor_before(edit_end);
+                                        edit_start = edit_end;
+                                        codegen.last_equal_ranges.push(edit_range);
+                                        None
+                                    }
+                                })
+                                .collect::<Vec<_>>();
 
-                            if let Some(transaction) = transaction {
-                                if let Some(first_transaction) =
-                                    codegen.transformation_transaction_id
-                                {
-                                    // Group all assistant edits into the first transaction.
-                                    codegen.buffer.update(cx, |buffer, cx| {
-                                        buffer.merge_transactions(
-                                            transaction,
-                                            first_transaction,
-                                            cx,
-                                        )
-                                    });
-                                } else {
-                                    codegen.transformation_transaction_id = Some(transaction);
-                                    codegen.buffer.update(cx, |buffer, cx| {
-                                        buffer.finalize_last_transaction(cx)
-                                    });
-                                }
+                            if codegen.active {
+                                codegen.apply_edits(edits.iter().cloned(), cx);
+                                codegen.reapply_line_based_diff(line_ops.iter().cloned(), cx);
                             }
-
-                            codegen.reapply_line_based_diff(edit_range.clone(), line_diff, cx);
+                            codegen.edits.extend(edits);
+                            codegen.line_operations = line_ops;
+                            codegen.edit_position = Some(snapshot.anchor_after(edit_start));
 
                             cx.notify();
                         })?;
@@ -2667,9 +2937,8 @@ impl Codegen {
                     // Streaming stopped and we have the new text in the buffer, and a line-based diff applied for the whole new buffer.
                     // That diff is not what a regular diff is and might look unexpected, ergo apply a regular diff.
                     // It's fine to apply even if the rest of the line diffing fails, as no more hunks are coming through `diff_rx`.
-                    let batch_diff_task = codegen.update(&mut cx, |codegen, cx| {
-                        codegen.reapply_batch_diff(edit_range.clone(), cx)
-                    })?;
+                    let batch_diff_task =
+                        codegen.update(&mut cx, |codegen, cx| codegen.reapply_batch_diff(cx))?;
                     let (line_based_stream_diff, ()) =
                         join!(line_based_stream_diff, batch_diff_task);
                     line_based_stream_diff?;
@@ -2713,24 +2982,45 @@ impl Codegen {
                 buffer.undo_transaction(transaction_id, cx);
                 buffer.refresh_preview(cx);
             }
+        });
+    }
 
-            if let Some(transaction_id) = self.initial_transaction_id.take() {
-                buffer.undo_transaction(transaction_id, cx);
-                buffer.refresh_preview(cx);
-            }
+    fn apply_edits(
+        &mut self,
+        edits: impl IntoIterator<Item = (Range<Anchor>, String)>,
+        cx: &mut ModelContext<CodegenAlternative>,
+    ) {
+        let transaction = self.buffer.update(cx, |buffer, cx| {
+            // Avoid grouping assistant edits with user edits.
+            buffer.finalize_last_transaction(cx);
+            buffer.start_transaction(cx);
+            buffer.edit(edits, None, cx);
+            buffer.end_transaction(cx)
         });
+
+        if let Some(transaction) = transaction {
+            if let Some(first_transaction) = self.transformation_transaction_id {
+                // Group all assistant edits into the first transaction.
+                self.buffer.update(cx, |buffer, cx| {
+                    buffer.merge_transactions(transaction, first_transaction, cx)
+                });
+            } else {
+                self.transformation_transaction_id = Some(transaction);
+                self.buffer
+                    .update(cx, |buffer, cx| buffer.finalize_last_transaction(cx));
+            }
+        }
     }
 
     fn reapply_line_based_diff(
         &mut self,
-        edit_range: Range<Anchor>,
-        line_operations: Vec<LineOperation>,
+        line_operations: impl IntoIterator<Item = LineOperation>,
         cx: &mut ModelContext<Self>,
     ) {
         let old_snapshot = self.snapshot.clone();
-        let old_range = edit_range.to_point(&old_snapshot);
+        let old_range = self.range.to_point(&old_snapshot);
         let new_snapshot = self.buffer.read(cx).snapshot(cx);
-        let new_range = edit_range.to_point(&new_snapshot);
+        let new_range = self.range.to_point(&new_snapshot);
 
         let mut old_row = old_range.start.row;
         let mut new_row = new_range.start.row;
@@ -2781,15 +3071,11 @@ impl Codegen {
         }
     }
 
-    fn reapply_batch_diff(
-        &mut self,
-        edit_range: Range<Anchor>,
-        cx: &mut ModelContext<Self>,
-    ) -> Task<()> {
+    fn reapply_batch_diff(&mut self, cx: &mut ModelContext<Self>) -> Task<()> {
         let old_snapshot = self.snapshot.clone();
-        let old_range = edit_range.to_point(&old_snapshot);
+        let old_range = self.range.to_point(&old_snapshot);
         let new_snapshot = self.buffer.read(cx).snapshot(cx);
-        let new_range = edit_range.to_point(&new_snapshot);
+        let new_range = self.range.to_point(&new_snapshot);
 
         cx.spawn(|codegen, mut cx| async move {
             let (deleted_row_ranges, inserted_row_ranges) = cx
@@ -3073,10 +3359,10 @@ mod tests {
         });
         let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
         let codegen = cx.new_model(|cx| {
-            Codegen::new(
+            CodegenAlternative::new(
                 buffer.clone(),
                 range.clone(),
-                None,
+                true,
                 None,
                 prompt_builder,
                 cx,

crates/language_model/src/registry.rs 🔗

@@ -76,6 +76,7 @@ impl Global for GlobalLanguageModelRegistry {}
 pub struct LanguageModelRegistry {
     active_model: Option<ActiveModel>,
     providers: BTreeMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
+    inline_alternatives: Vec<Arc<dyn LanguageModel>>,
 }
 
 pub struct ActiveModel {
@@ -229,6 +230,37 @@ impl LanguageModelRegistry {
     pub fn active_model(&self) -> Option<Arc<dyn LanguageModel>> {
         self.active_model.as_ref()?.model.clone()
     }
+
+    /// Replaces the inline alternatives with the models matching the given
+    /// (provider id, model id) pairs; pairs that resolve to no model are skipped.
+    pub fn select_inline_alternative_models(
+        &mut self,
+        alternatives: impl IntoIterator<Item = (LanguageModelProviderId, LanguageModelId)>,
+        cx: &mut ModelContext<Self>,
+    ) {
+        let mut selected_alternatives = Vec::new();
+
+        for (provider_id, model_id) in alternatives {
+            if let Some(provider) = self.providers.get(&provider_id) {
+                if let Some(model) = provider
+                    .provided_models(cx)
+                    .iter()
+                    .find(|m| m.id() == model_id)
+                {
+                    selected_alternatives.push(model.clone());
+                }
+            }
+        }
+
+        self.inline_alternatives = selected_alternatives;
+    }
+
+    /// The additional models configured for inline assists, exactly as stored by
+    /// `select_inline_alternative_models`. The active model is not part of this
+    /// slice; callers that also want it must fetch it via `active_model`.
+    pub fn inline_alternative_models(&self) -> &[Arc<dyn LanguageModel>] {
+        &self.inline_alternatives
+    }
 }
 
 #[cfg(test)]

crates/multi_buffer/src/multi_buffer.rs 🔗

@@ -1106,6 +1106,26 @@ impl MultiBuffer {
         }
     }
 
+    pub fn forget_transaction(
+        &mut self,
+        transaction_id: TransactionId,
+        cx: &mut ModelContext<Self>,
+    ) {
+        if let Some(buffer) = self.as_singleton() {
+            buffer.update(cx, |buffer, _| {
+                buffer.forget_transaction(transaction_id);
+            });
+        } else if let Some(transaction) = self.history.forget(transaction_id) {
+            for (buffer_id, buffer_transaction_id) in transaction.buffer_transactions {
+                if let Some(state) = self.buffers.borrow_mut().get_mut(&buffer_id) {
+                    state.buffer.update(cx, |buffer, _| {
+                        buffer.forget_transaction(buffer_transaction_id);
+                    });
+                }
+            }
+        }
+    }
+
     pub fn stream_excerpts_with_context_lines(
         &mut self,
         buffer: Model<Buffer>,

docs/src/assistant/configuration.md 🔗

@@ -20,6 +20,7 @@ To further customize providers, you can use `settings.json` to do that as follow
 - [Configuring endpoints](#custom-endpoint)
 - [Configuring timeouts](#provider-timeout)
 - [Configuring default model](#default-model)
+- [Configuring alternative models for inline assists](#alternative-assists)
 
 ### Zed AI {#zed-ai}
 
@@ -264,6 +265,31 @@ You can also manually edit the `default_model` object in your settings:
 }
 ```
 
+#### Configuring alternative models for inline assists {#alternative-assists}
+
+You can configure additional models that will be used to perform inline assists in parallel. When you do this,
+the inline assist UI will surface controls to cycle between the alternatives generated by each model. The models
+you specify here are always used in _addition_ to your default model. For example, the following configuration
+will generate two outputs for every assist: one with Claude 3.5 Sonnet and one with GPT-4o.
+
+```json
+{
+  "assistant": {
+    "default_model": {
+      "provider": "zed.dev",
+      "model": "claude-3-5-sonnet"
+    },
+    "inline_alternatives": [
+      {
+        "provider": "zed.dev",
+        "model": "gpt-4o"
+      }
+    ],
+    "version": "2"
+  }
+}
+```
+
 #### Common Panel Settings
 
 | key            | type    | default | description                                                                           |