Fix issues with Claude in Assistant2 (#12619)

Mikayla Maki and Nathan created

Release Notes:

- N/A

---------

Co-authored-by: Nathan <nathan@zed.dev>

Change summary

crates/assistant/src/assistant.rs                     |  28 +
crates/assistant/src/assistant_panel.rs               |  42 +-
crates/assistant/src/assistant_settings.rs            |  43 +-
crates/assistant/src/codegen.rs                       | 166 ++++++------
crates/assistant/src/completion_provider.rs           |  37 +-
crates/assistant/src/completion_provider/anthropic.rs | 103 +++++--
crates/assistant/src/completion_provider/cloud.rs     |  46 +-
crates/assistant/src/completion_provider/open_ai.rs   |   8 
crates/editor/src/editor.rs                           |  14 +
crates/editor/src/element.rs                          |   6 
crates/editor/src/hunk_diff.rs                        |   2 
11 files changed, 279 insertions(+), 216 deletions(-)

Detailed changes

crates/assistant/src/assistant.rs 🔗

@@ -12,7 +12,7 @@ mod streaming_diff;
 
 pub use assistant_panel::AssistantPanel;
 
-use assistant_settings::{AnthropicModel, AssistantSettings, OpenAiModel, ZedDotDevModel};
+use assistant_settings::{AnthropicModel, AssistantSettings, CloudModel, OpenAiModel};
 use assistant_slash_command::SlashCommandRegistry;
 use client::{proto, Client};
 use command_palette_hooks::CommandPaletteFilter;
@@ -87,14 +87,14 @@ impl Display for Role {
 
 #[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
 pub enum LanguageModel {
-    ZedDotDev(ZedDotDevModel),
+    Cloud(CloudModel),
     OpenAi(OpenAiModel),
     Anthropic(AnthropicModel),
 }
 
 impl Default for LanguageModel {
     fn default() -> Self {
-        LanguageModel::ZedDotDev(ZedDotDevModel::default())
+        LanguageModel::Cloud(CloudModel::default())
     }
 }
 
@@ -103,7 +103,7 @@ impl LanguageModel {
         match self {
             LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
             LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()),
-            LanguageModel::ZedDotDev(model) => format!("zed.dev/{}", model.id()),
+            LanguageModel::Cloud(model) => format!("zed.dev/{}", model.id()),
         }
     }
 
@@ -111,7 +111,7 @@ impl LanguageModel {
         match self {
             LanguageModel::OpenAi(model) => model.display_name().into(),
             LanguageModel::Anthropic(model) => model.display_name().into(),
-            LanguageModel::ZedDotDev(model) => model.display_name().into(),
+            LanguageModel::Cloud(model) => model.display_name().into(),
         }
     }
 
@@ -119,7 +119,7 @@ impl LanguageModel {
         match self {
             LanguageModel::OpenAi(model) => model.max_token_count(),
             LanguageModel::Anthropic(model) => model.max_token_count(),
-            LanguageModel::ZedDotDev(model) => model.max_token_count(),
+            LanguageModel::Cloud(model) => model.max_token_count(),
         }
     }
 
@@ -127,7 +127,7 @@ impl LanguageModel {
         match self {
             LanguageModel::OpenAi(model) => model.id(),
             LanguageModel::Anthropic(model) => model.id(),
-            LanguageModel::ZedDotDev(model) => model.id(),
+            LanguageModel::Cloud(model) => model.id(),
         }
     }
 }
@@ -172,6 +172,20 @@ impl LanguageModelRequest {
             tools: Vec::new(),
         }
     }
+
+    /// Before we send the request to the server, we can perform fixups on it appropriate to the model.
+    pub fn preprocess(&mut self) {
+        match &self.model {
+            LanguageModel::OpenAi(_) => {}
+            LanguageModel::Anthropic(_) => {}
+            LanguageModel::Cloud(model) => match model {
+                CloudModel::Claude3Opus | CloudModel::Claude3Sonnet | CloudModel::Claude3Haiku => {
+                    preprocess_anthropic_request(self);
+                }
+                _ => {}
+            },
+        }
+    }
 }
 
 #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]

crates/assistant/src/assistant_panel.rs 🔗

@@ -17,7 +17,7 @@ use anyhow::{anyhow, Result};
 use assistant_slash_command::{SlashCommandOutput, SlashCommandOutputSection};
 use client::telemetry::Telemetry;
 use collections::{hash_map, BTreeSet, HashMap, HashSet, VecDeque};
-use editor::actions::ShowCompletions;
+use editor::{actions::ShowCompletions, GutterDimensions};
 use editor::{
     actions::{FoldAt, MoveDown, MoveToEndOfLine, MoveUp, Newline, UnfoldAt},
     display_map::{
@@ -469,7 +469,7 @@ impl AssistantPanel {
             )
         });
 
-        let measurements = Arc::new(Mutex::new(BlockMeasurements::default()));
+        let measurements = Arc::new(Mutex::new(GutterDimensions::default()));
         let inline_assistant = cx.new_view(|cx| {
             InlineAssistant::new(
                 inline_assist_id,
@@ -492,10 +492,7 @@ impl AssistantPanel {
                     render: Box::new({
                         let inline_assistant = inline_assistant.clone();
                         move |cx: &mut BlockContext| {
-                            *measurements.lock() = BlockMeasurements {
-                                gutter_width: cx.gutter_dimensions.width,
-                                gutter_margin: cx.gutter_dimensions.margin,
-                            };
+                            *measurements.lock() = *cx.gutter_dimensions;
                             inline_assistant.clone().into_any_element()
                         }
                     }),
@@ -583,6 +580,7 @@ impl AssistantPanel {
                 ],
             },
         );
+
         self.pending_inline_assist_ids_by_editor
             .entry(editor.downgrade())
             .or_default()
@@ -810,7 +808,7 @@ impl AssistantPanel {
             codegen.update(&mut cx, |codegen, cx| codegen.start(request, cx))?;
             anyhow::Ok(())
         })
-        .detach();
+        .detach_and_log_err(cx);
     }
 
     fn update_highlights_for_editor(&self, editor: &View<Editor>, cx: &mut ViewContext<Self>) {
@@ -1431,7 +1429,7 @@ impl Panel for AssistantPanel {
             return None;
         }
 
-        Some(IconName::Ai)
+        Some(IconName::ZedAssistant)
     }
 
     fn icon_tooltip(&self, _cx: &WindowContext) -> Option<&'static str> {
@@ -3151,7 +3149,7 @@ impl ConversationEditor {
 
                             h_flex()
                                 .id(("message_header", message_id.0))
-                                .pl(cx.gutter_dimensions.width + cx.gutter_dimensions.margin)
+                                .pl(cx.gutter_dimensions.full_width())
                                 .h_11()
                                 .w_full()
                                 .relative()
@@ -3551,7 +3549,7 @@ struct InlineAssistant {
     prompt_editor: View<Editor>,
     confirmed: bool,
     include_conversation: bool,
-    measurements: Arc<Mutex<BlockMeasurements>>,
+    gutter_dimensions: Arc<Mutex<GutterDimensions>>,
     prompt_history: VecDeque<String>,
     prompt_history_ix: Option<usize>,
     pending_prompt: String,
@@ -3563,7 +3561,8 @@ impl EventEmitter<InlineAssistantEvent> for InlineAssistant {}
 
 impl Render for InlineAssistant {
     fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
-        let measurements = *self.measurements.lock();
+        let gutter_dimensions = *self.gutter_dimensions.lock();
+        let icon_size = IconSize::default();
         h_flex()
             .w_full()
             .py_2()
@@ -3576,14 +3575,20 @@ impl Render for InlineAssistant {
             .on_action(cx.listener(Self::move_down))
             .child(
                 h_flex()
-                    .w(measurements.gutter_width + measurements.gutter_margin)
+                    .w(gutter_dimensions.full_width() + (gutter_dimensions.margin / 2.0))
+                    .pr(gutter_dimensions.fold_area_width())
+                    .justify_end()
                     .children(if let Some(error) = self.codegen.read(cx).error() {
                         let error_message = SharedString::from(error.to_string());
                         Some(
                             div()
                                 .id("error")
                                 .tooltip(move |cx| Tooltip::text(error_message.clone(), cx))
-                                .child(Icon::new(IconName::XCircle).color(Color::Error)),
+                                .child(
+                                    Icon::new(IconName::XCircle)
+                                        .size(icon_size)
+                                        .color(Color::Error),
+                                ),
                         )
                     } else {
                         None
@@ -3603,7 +3608,7 @@ impl InlineAssistant {
     #[allow(clippy::too_many_arguments)]
     fn new(
         id: usize,
-        measurements: Arc<Mutex<BlockMeasurements>>,
+        gutter_dimensions: Arc<Mutex<GutterDimensions>>,
         include_conversation: bool,
         prompt_history: VecDeque<String>,
         codegen: Model<Codegen>,
@@ -3630,7 +3635,7 @@ impl InlineAssistant {
             prompt_editor,
             confirmed: false,
             include_conversation,
-            measurements,
+            gutter_dimensions,
             prompt_history,
             prompt_history_ix: None,
             pending_prompt: String::new(),
@@ -3755,13 +3760,6 @@ impl InlineAssistant {
     }
 }
 
-// This wouldn't need to exist if we could pass parameters when rendering child views.
-#[derive(Copy, Clone, Default)]
-struct BlockMeasurements {
-    gutter_width: Pixels,
-    gutter_margin: Pixels,
-}
-
 struct PendingInlineAssist {
     editor: WeakView<Editor>,
     inline_assistant: Option<(BlockId, View<InlineAssistant>)>,

crates/assistant/src/assistant_settings.rs 🔗

@@ -14,10 +14,10 @@ use serde::{
 use settings::{Settings, SettingsSources};
 use strum::{EnumIter, IntoEnumIterator};
 
-use crate::LanguageModel;
+use crate::{preprocess_anthropic_request, LanguageModel, LanguageModelRequest};
 
 #[derive(Clone, Debug, Default, PartialEq, EnumIter)]
-pub enum ZedDotDevModel {
+pub enum CloudModel {
     Gpt3Point5Turbo,
     Gpt4,
     Gpt4Turbo,
@@ -29,7 +29,7 @@ pub enum ZedDotDevModel {
     Custom(String),
 }
 
-impl Serialize for ZedDotDevModel {
+impl Serialize for CloudModel {
     fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
     where
         S: Serializer,
@@ -38,7 +38,7 @@ impl Serialize for ZedDotDevModel {
     }
 }
 
-impl<'de> Deserialize<'de> for ZedDotDevModel {
+impl<'de> Deserialize<'de> for CloudModel {
     fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
     where
         D: Deserializer<'de>,
@@ -46,7 +46,7 @@ impl<'de> Deserialize<'de> for ZedDotDevModel {
         struct ZedDotDevModelVisitor;
 
         impl<'de> Visitor<'de> for ZedDotDevModelVisitor {
-            type Value = ZedDotDevModel;
+            type Value = CloudModel;
 
             fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                 formatter.write_str("a string for a ZedDotDevModel variant or a custom model")
@@ -56,9 +56,9 @@ impl<'de> Deserialize<'de> for ZedDotDevModel {
             where
                 E: de::Error,
             {
-                let model = ZedDotDevModel::iter()
+                let model = CloudModel::iter()
                     .find(|model| model.id() == value)
-                    .unwrap_or_else(|| ZedDotDevModel::Custom(value.to_string()));
+                    .unwrap_or_else(|| CloudModel::Custom(value.to_string()));
                 Ok(model)
             }
         }
@@ -67,13 +67,13 @@ impl<'de> Deserialize<'de> for ZedDotDevModel {
     }
 }
 
-impl JsonSchema for ZedDotDevModel {
+impl JsonSchema for CloudModel {
     fn schema_name() -> String {
         "ZedDotDevModel".to_owned()
     }
 
     fn json_schema(_generator: &mut schemars::gen::SchemaGenerator) -> Schema {
-        let variants = ZedDotDevModel::iter()
+        let variants = CloudModel::iter()
             .filter_map(|model| {
                 let id = model.id();
                 if id.is_empty() {
@@ -88,7 +88,7 @@ impl JsonSchema for ZedDotDevModel {
             enum_values: Some(variants.iter().map(|s| s.clone().into()).collect()),
             metadata: Some(Box::new(Metadata {
                 title: Some("ZedDotDevModel".to_owned()),
-                default: Some(ZedDotDevModel::default().id().into()),
+                default: Some(CloudModel::default().id().into()),
                 examples: variants.into_iter().map(Into::into).collect(),
                 ..Default::default()
             })),
@@ -97,7 +97,7 @@ impl JsonSchema for ZedDotDevModel {
     }
 }
 
-impl ZedDotDevModel {
+impl CloudModel {
     pub fn id(&self) -> &str {
         match self {
             Self::Gpt3Point5Turbo => "gpt-3.5-turbo",
@@ -133,6 +133,15 @@ impl ZedDotDevModel {
             Self::Custom(_) => 4096, // TODO: Make this configurable
         }
     }
+
+    pub fn preprocess_request(&self, request: &mut LanguageModelRequest) {
+        match self {
+            Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => {
+                preprocess_anthropic_request(request)
+            }
+            _ => {}
+        }
+    }
 }
 
 #[derive(Copy, Clone, Default, Debug, Serialize, Deserialize, JsonSchema)]
@@ -147,7 +156,7 @@ pub enum AssistantDockPosition {
 #[derive(Debug, PartialEq)]
 pub enum AssistantProvider {
     ZedDotDev {
-        model: ZedDotDevModel,
+        model: CloudModel,
     },
     OpenAi {
         model: OpenAiModel,
@@ -175,9 +184,7 @@ impl Default for AssistantProvider {
 #[serde(tag = "name", rename_all = "snake_case")]
 pub enum AssistantProviderContent {
     #[serde(rename = "zed.dev")]
-    ZedDotDev {
-        default_model: Option<ZedDotDevModel>,
-    },
+    ZedDotDev { default_model: Option<CloudModel> },
     #[serde(rename = "openai")]
     OpenAi {
         default_model: Option<OpenAiModel>,
@@ -281,7 +288,7 @@ impl AssistantSettingsContent {
                     Some(AssistantProviderContent::ZedDotDev {
                         default_model: model,
                     }) => {
-                        if let LanguageModel::ZedDotDev(new_model) = new_model {
+                        if let LanguageModel::Cloud(new_model) = new_model {
                             *model = Some(new_model);
                         }
                     }
@@ -302,7 +309,7 @@ impl AssistantSettingsContent {
                         }
                     }
                     provider => match new_model {
-                        LanguageModel::ZedDotDev(model) => {
+                        LanguageModel::Cloud(model) => {
                             *provider = Some(AssistantProviderContent::ZedDotDev {
                                 default_model: Some(model),
                             })
@@ -613,7 +620,7 @@ mod tests {
         assert_eq!(
             AssistantSettings::get_global(cx).provider,
             AssistantProvider::ZedDotDev {
-                model: ZedDotDevModel::Custom("custom".into())
+                model: CloudModel::Custom("custom".into())
             }
         );
     }

crates/assistant/src/codegen.rs 🔗

@@ -11,6 +11,7 @@ use language::{Rope, TransactionId};
 use multi_buffer::MultiBufferRow;
 use std::{cmp, future, ops::Range, sync::Arc, time::Instant};
 
+#[derive(Debug)]
 pub enum Event {
     Finished,
     Undone,
@@ -120,91 +121,98 @@ impl Codegen {
                     let mut edit_start = range.start.to_offset(&snapshot);
 
                     let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
-                    let diff = cx.background_executor().spawn(async move {
-                        let mut response_latency = None;
-                        let request_start = Instant::now();
-                        let diff = async {
-                            let chunks = strip_invalid_spans_from_codeblock(response.await?);
-                            futures::pin_mut!(chunks);
-                            let mut diff = StreamingDiff::new(selected_text.to_string());
-
-                            let mut new_text = String::new();
-                            let mut base_indent = None;
-                            let mut line_indent = None;
-                            let mut first_line = true;
-
-                            while let Some(chunk) = chunks.next().await {
-                                if response_latency.is_none() {
-                                    response_latency = Some(request_start.elapsed());
-                                }
-                                let chunk = chunk?;
-
-                                let mut lines = chunk.split('\n').peekable();
-                                while let Some(line) = lines.next() {
-                                    new_text.push_str(line);
-                                    if line_indent.is_none() {
-                                        if let Some(non_whitespace_ch_ix) =
-                                            new_text.find(|ch: char| !ch.is_whitespace())
-                                        {
-                                            line_indent = Some(non_whitespace_ch_ix);
-                                            base_indent = base_indent.or(line_indent);
-
-                                            let line_indent = line_indent.unwrap();
-                                            let base_indent = base_indent.unwrap();
-                                            let indent_delta =
-                                                line_indent as i32 - base_indent as i32;
-                                            let mut corrected_indent_len = cmp::max(
-                                                0,
-                                                suggested_line_indent.len as i32 + indent_delta,
-                                            )
-                                                as usize;
-                                            if first_line {
-                                                corrected_indent_len = corrected_indent_len
-                                                    .saturating_sub(
-                                                        selection_start.column as usize,
-                                                    );
+                    let diff: Task<anyhow::Result<()>> =
+                        cx.background_executor().spawn(async move {
+                            let mut response_latency = None;
+                            let request_start = Instant::now();
+                            let diff = async {
+                                let chunks = strip_invalid_spans_from_codeblock(response.await?);
+                                futures::pin_mut!(chunks);
+                                let mut diff = StreamingDiff::new(selected_text.to_string());
+
+                                let mut new_text = String::new();
+                                let mut base_indent = None;
+                                let mut line_indent = None;
+                                let mut first_line = true;
+
+                                while let Some(chunk) = chunks.next().await {
+                                    if response_latency.is_none() {
+                                        response_latency = Some(request_start.elapsed());
+                                    }
+                                    let chunk = chunk?;
+
+                                    let mut lines = chunk.split('\n').peekable();
+                                    while let Some(line) = lines.next() {
+                                        new_text.push_str(line);
+                                        if line_indent.is_none() {
+                                            if let Some(non_whitespace_ch_ix) =
+                                                new_text.find(|ch: char| !ch.is_whitespace())
+                                            {
+                                                line_indent = Some(non_whitespace_ch_ix);
+                                                base_indent = base_indent.or(line_indent);
+
+                                                let line_indent = line_indent.unwrap();
+                                                let base_indent = base_indent.unwrap();
+                                                let indent_delta =
+                                                    line_indent as i32 - base_indent as i32;
+                                                let mut corrected_indent_len = cmp::max(
+                                                    0,
+                                                    suggested_line_indent.len as i32 + indent_delta,
+                                                )
+                                                    as usize;
+                                                if first_line {
+                                                    corrected_indent_len = corrected_indent_len
+                                                        .saturating_sub(
+                                                            selection_start.column as usize,
+                                                        );
+                                                }
+
+                                                let indent_char = suggested_line_indent.char();
+                                                let mut indent_buffer = [0; 4];
+                                                let indent_str =
+                                                    indent_char.encode_utf8(&mut indent_buffer);
+                                                new_text.replace_range(
+                                                    ..line_indent,
+                                                    &indent_str.repeat(corrected_indent_len),
+                                                );
                                             }
-
-                                            let indent_char = suggested_line_indent.char();
-                                            let mut indent_buffer = [0; 4];
-                                            let indent_str =
-                                                indent_char.encode_utf8(&mut indent_buffer);
-                                            new_text.replace_range(
-                                                ..line_indent,
-                                                &indent_str.repeat(corrected_indent_len),
-                                            );
                                         }
-                                    }
 
-                                    if line_indent.is_some() {
-                                        hunks_tx.send(diff.push_new(&new_text)).await?;
-                                        new_text.clear();
-                                    }
+                                        if line_indent.is_some() {
+                                            hunks_tx.send(diff.push_new(&new_text)).await?;
+                                            new_text.clear();
+                                        }
 
-                                    if lines.peek().is_some() {
-                                        hunks_tx.send(diff.push_new("\n")).await?;
-                                        line_indent = None;
-                                        first_line = false;
+                                        if lines.peek().is_some() {
+                                            hunks_tx.send(diff.push_new("\n")).await?;
+                                            line_indent = None;
+                                            first_line = false;
+                                        }
                                     }
                                 }
+                                hunks_tx.send(diff.push_new(&new_text)).await?;
+                                hunks_tx.send(diff.finish()).await?;
+
+                                anyhow::Ok(())
+                            };
+
+                            let result = diff.await;
+
+                            let error_message =
+                                result.as_ref().err().map(|error| error.to_string());
+                            if let Some(telemetry) = telemetry {
+                                telemetry.report_assistant_event(
+                                    None,
+                                    telemetry_events::AssistantKind::Inline,
+                                    model_telemetry_id,
+                                    response_latency,
+                                    error_message,
+                                );
                             }
-                            hunks_tx.send(diff.push_new(&new_text)).await?;
-                            hunks_tx.send(diff.finish()).await?;
-
-                            anyhow::Ok(())
-                        };
-
-                        let error_message = diff.await.err().map(|error| error.to_string());
-                        if let Some(telemetry) = telemetry {
-                            telemetry.report_assistant_event(
-                                None,
-                                telemetry_events::AssistantKind::Inline,
-                                model_telemetry_id,
-                                response_latency,
-                                error_message,
-                            );
-                        }
-                    });
+
+                            result?;
+                            Ok(())
+                        });
 
                     while let Some(hunks) = hunks_rx.next().await {
                         this.update(&mut cx, |this, cx| {
@@ -266,7 +274,7 @@ impl Codegen {
                         })?;
                     }
 
-                    diff.await;
+                    diff.await?;
 
                     anyhow::Ok(())
                 };

crates/assistant/src/completion_provider.rs 🔗

@@ -1,14 +1,14 @@
 mod anthropic;
+mod cloud;
 #[cfg(test)]
 mod fake;
 mod open_ai;
-mod zed;
 
 pub use anthropic::*;
+pub use cloud::*;
 #[cfg(test)]
 pub use fake::*;
 pub use open_ai::*;
-pub use zed::*;
 
 use crate::{
     assistant_settings::{AssistantProvider, AssistantSettings},
@@ -25,8 +25,8 @@ use std::time::Duration;
 pub fn init(client: Arc<Client>, cx: &mut AppContext) {
     let mut settings_version = 0;
     let provider = match &AssistantSettings::get_global(cx).provider {
-        AssistantProvider::ZedDotDev { model } => CompletionProvider::ZedDotDev(
-            ZedDotDevCompletionProvider::new(model.clone(), client.clone(), settings_version, cx),
+        AssistantProvider::ZedDotDev { model } => CompletionProvider::Cloud(
+            CloudCompletionProvider::new(model.clone(), client.clone(), settings_version, cx),
         ),
         AssistantProvider::OpenAi {
             model,
@@ -87,14 +87,11 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
                         settings_version,
                     );
                 }
-                (
-                    CompletionProvider::ZedDotDev(provider),
-                    AssistantProvider::ZedDotDev { model },
-                ) => {
+                (CompletionProvider::Cloud(provider), AssistantProvider::ZedDotDev { model }) => {
                     provider.update(model.clone(), settings_version);
                 }
                 (_, AssistantProvider::ZedDotDev { model }) => {
-                    *provider = CompletionProvider::ZedDotDev(ZedDotDevCompletionProvider::new(
+                    *provider = CompletionProvider::Cloud(CloudCompletionProvider::new(
                         model.clone(),
                         client.clone(),
                         settings_version,
@@ -142,7 +139,7 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
 pub enum CompletionProvider {
     OpenAi(OpenAiCompletionProvider),
     Anthropic(AnthropicCompletionProvider),
-    ZedDotDev(ZedDotDevCompletionProvider),
+    Cloud(CloudCompletionProvider),
     #[cfg(test)]
     Fake(FakeCompletionProvider),
 }
@@ -164,9 +161,9 @@ impl CompletionProvider {
                 .available_models()
                 .map(LanguageModel::Anthropic)
                 .collect(),
-            CompletionProvider::ZedDotDev(provider) => provider
+            CompletionProvider::Cloud(provider) => provider
                 .available_models()
-                .map(LanguageModel::ZedDotDev)
+                .map(LanguageModel::Cloud)
                 .collect(),
             #[cfg(test)]
             CompletionProvider::Fake(_) => unimplemented!(),
@@ -177,7 +174,7 @@ impl CompletionProvider {
         match self {
             CompletionProvider::OpenAi(provider) => provider.settings_version(),
             CompletionProvider::Anthropic(provider) => provider.settings_version(),
-            CompletionProvider::ZedDotDev(provider) => provider.settings_version(),
+            CompletionProvider::Cloud(provider) => provider.settings_version(),
             #[cfg(test)]
             CompletionProvider::Fake(_) => unimplemented!(),
         }
@@ -187,7 +184,7 @@ impl CompletionProvider {
         match self {
             CompletionProvider::OpenAi(provider) => provider.is_authenticated(),
             CompletionProvider::Anthropic(provider) => provider.is_authenticated(),
-            CompletionProvider::ZedDotDev(provider) => provider.is_authenticated(),
+            CompletionProvider::Cloud(provider) => provider.is_authenticated(),
             #[cfg(test)]
             CompletionProvider::Fake(_) => true,
         }
@@ -197,7 +194,7 @@ impl CompletionProvider {
         match self {
             CompletionProvider::OpenAi(provider) => provider.authenticate(cx),
             CompletionProvider::Anthropic(provider) => provider.authenticate(cx),
-            CompletionProvider::ZedDotDev(provider) => provider.authenticate(cx),
+            CompletionProvider::Cloud(provider) => provider.authenticate(cx),
             #[cfg(test)]
             CompletionProvider::Fake(_) => Task::ready(Ok(())),
         }
@@ -207,7 +204,7 @@ impl CompletionProvider {
         match self {
             CompletionProvider::OpenAi(provider) => provider.authentication_prompt(cx),
             CompletionProvider::Anthropic(provider) => provider.authentication_prompt(cx),
-            CompletionProvider::ZedDotDev(provider) => provider.authentication_prompt(cx),
+            CompletionProvider::Cloud(provider) => provider.authentication_prompt(cx),
             #[cfg(test)]
             CompletionProvider::Fake(_) => unimplemented!(),
         }
@@ -217,7 +214,7 @@ impl CompletionProvider {
         match self {
             CompletionProvider::OpenAi(provider) => provider.reset_credentials(cx),
             CompletionProvider::Anthropic(provider) => provider.reset_credentials(cx),
-            CompletionProvider::ZedDotDev(_) => Task::ready(Ok(())),
+            CompletionProvider::Cloud(_) => Task::ready(Ok(())),
             #[cfg(test)]
             CompletionProvider::Fake(_) => Task::ready(Ok(())),
         }
@@ -227,7 +224,7 @@ impl CompletionProvider {
         match self {
             CompletionProvider::OpenAi(provider) => LanguageModel::OpenAi(provider.model()),
             CompletionProvider::Anthropic(provider) => LanguageModel::Anthropic(provider.model()),
-            CompletionProvider::ZedDotDev(provider) => LanguageModel::ZedDotDev(provider.model()),
+            CompletionProvider::Cloud(provider) => LanguageModel::Cloud(provider.model()),
             #[cfg(test)]
             CompletionProvider::Fake(_) => LanguageModel::default(),
         }
@@ -241,7 +238,7 @@ impl CompletionProvider {
         match self {
             CompletionProvider::OpenAi(provider) => provider.count_tokens(request, cx),
             CompletionProvider::Anthropic(provider) => provider.count_tokens(request, cx),
-            CompletionProvider::ZedDotDev(provider) => provider.count_tokens(request, cx),
+            CompletionProvider::Cloud(provider) => provider.count_tokens(request, cx),
             #[cfg(test)]
             CompletionProvider::Fake(_) => futures::FutureExt::boxed(futures::future::ready(Ok(0))),
         }
@@ -254,7 +251,7 @@ impl CompletionProvider {
         match self {
             CompletionProvider::OpenAi(provider) => provider.complete(request),
             CompletionProvider::Anthropic(provider) => provider.complete(request),
-            CompletionProvider::ZedDotDev(provider) => provider.complete(request),
+            CompletionProvider::Cloud(provider) => provider.complete(request),
             #[cfg(test)]
             CompletionProvider::Fake(provider) => provider.complete(),
         }

crates/assistant/src/completion_provider/anthropic.rs 🔗

@@ -1,9 +1,9 @@
-use crate::count_open_ai_tokens;
 use crate::{
     assistant_settings::AnthropicModel, CompletionProvider, LanguageModel, LanguageModelRequest,
     Role,
 };
-use anthropic::{stream_completion, Request, RequestMessage, Role as AnthropicRole};
+use crate::{count_open_ai_tokens, LanguageModelRequestMessage};
+use anthropic::{stream_completion, Request, RequestMessage};
 use anyhow::{anyhow, Result};
 use editor::{Editor, EditorElement, EditorStyle};
 use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
@@ -167,58 +167,85 @@ impl AnthropicCompletionProvider {
         .boxed()
     }
 
-    fn to_anthropic_request(&self, request: LanguageModelRequest) -> Request {
+    fn to_anthropic_request(&self, mut request: LanguageModelRequest) -> Request {
+        preprocess_anthropic_request(&mut request);
+
         let model = match request.model {
             LanguageModel::Anthropic(model) => model,
             _ => self.model(),
         };
 
         let mut system_message = String::new();
+        if request
+            .messages
+            .first()
+            .map_or(false, |message| message.role == Role::System)
+        {
+            system_message = request.messages.remove(0).content;
+        }
 
-        let mut messages: Vec<RequestMessage> = Vec::new();
-        for message in request.messages {
-            if message.content.is_empty() {
-                continue;
-            }
+        Request {
+            model,
+            messages: request
+                .messages
+                .iter()
+                .map(|msg| RequestMessage {
+                    role: match msg.role {
+                        Role::User => anthropic::Role::User,
+                        Role::Assistant => anthropic::Role::Assistant,
+                        Role::System => unreachable!("filtered out by preprocess_request"),
+                    },
+                    content: msg.content.clone(),
+                })
+                .collect(),
+            stream: true,
+            system: system_message,
+            max_tokens: 4092,
+        }
+    }
+}
 
-            match message.role {
-                Role::User | Role::Assistant => {
-                    let role = match message.role {
-                        Role::User => AnthropicRole::User,
-                        Role::Assistant => AnthropicRole::Assistant,
-                        _ => unreachable!(),
-                    };
+pub fn preprocess_anthropic_request(request: &mut LanguageModelRequest) {
+    let mut new_messages: Vec<LanguageModelRequestMessage> = Vec::new();
+    let mut system_message = String::new();
 
-                    if let Some(last_message) = messages.last_mut() {
-                        if last_message.role == role {
-                            last_message.content.push_str("\n\n");
-                            last_message.content.push_str(&message.content);
-                            continue;
-                        }
-                    }
+    for message in request.messages.drain(..) {
+        if message.content.is_empty() {
+            continue;
+        }
 
-                    messages.push(RequestMessage {
-                        role,
-                        content: message.content,
-                    });
-                }
-                Role::System => {
-                    if !system_message.is_empty() {
-                        system_message.push_str("\n\n");
+        match message.role {
+            Role::User | Role::Assistant => {
+                if let Some(last_message) = new_messages.last_mut() {
+                    if last_message.role == message.role {
+                        last_message.content.push_str("\n\n");
+                        last_message.content.push_str(&message.content);
+                        continue;
                     }
-                    system_message.push_str(&message.content);
                 }
+
+                new_messages.push(message);
+            }
+            Role::System => {
+                if !system_message.is_empty() {
+                    system_message.push_str("\n\n");
+                }
+                system_message.push_str(&message.content);
             }
         }
+    }
 
-        Request {
-            model,
-            messages,
-            stream: true,
-            system: system_message,
-            max_tokens: 4092,
-        }
+    if !system_message.is_empty() {
+        request.messages.insert(
+            0,
+            LanguageModelRequestMessage {
+                role: Role::System,
+                content: system_message,
+            },
+        );
     }
+
+    request.messages = new_messages;
 }
 
 struct AuthenticationPrompt {

crates/assistant/src/completion_provider/zed.rs → crates/assistant/src/completion_provider/cloud.rs 🔗

@@ -1,5 +1,5 @@
 use crate::{
-    assistant_settings::ZedDotDevModel, count_open_ai_tokens, CompletionProvider, LanguageModel,
+    assistant_settings::CloudModel, count_open_ai_tokens, CompletionProvider, LanguageModel,
     LanguageModelRequest,
 };
 use anyhow::{anyhow, Result};
@@ -10,17 +10,17 @@ use std::{future, sync::Arc};
 use strum::IntoEnumIterator;
 use ui::prelude::*;
 
-pub struct ZedDotDevCompletionProvider {
+pub struct CloudCompletionProvider {
     client: Arc<Client>,
-    model: ZedDotDevModel,
+    model: CloudModel,
     settings_version: usize,
     status: client::Status,
     _maintain_client_status: Task<()>,
 }
 
-impl ZedDotDevCompletionProvider {
+impl CloudCompletionProvider {
     pub fn new(
-        model: ZedDotDevModel,
+        model: CloudModel,
         client: Arc<Client>,
         settings_version: usize,
         cx: &mut AppContext,
@@ -30,7 +30,7 @@ impl ZedDotDevCompletionProvider {
         let maintain_client_status = cx.spawn(|mut cx| async move {
             while let Some(status) = status_rx.next().await {
                 let _ = cx.update_global::<CompletionProvider, _>(|provider, _cx| {
-                    if let CompletionProvider::ZedDotDev(provider) = provider {
+                    if let CompletionProvider::Cloud(provider) = provider {
                         provider.status = status;
                     } else {
                         unreachable!()
@@ -47,20 +47,20 @@ impl ZedDotDevCompletionProvider {
         }
     }
 
-    pub fn update(&mut self, model: ZedDotDevModel, settings_version: usize) {
+    pub fn update(&mut self, model: CloudModel, settings_version: usize) {
         self.model = model;
         self.settings_version = settings_version;
     }
 
-    pub fn available_models(&self) -> impl Iterator<Item = ZedDotDevModel> {
-        let mut custom_model = if let ZedDotDevModel::Custom(custom_model) = self.model.clone() {
+    pub fn available_models(&self) -> impl Iterator<Item = CloudModel> {
+        let mut custom_model = if let CloudModel::Custom(custom_model) = self.model.clone() {
             Some(custom_model)
         } else {
             None
         };
-        ZedDotDevModel::iter().filter_map(move |model| {
-            if let ZedDotDevModel::Custom(_) = model {
-                Some(ZedDotDevModel::Custom(custom_model.take()?))
+        CloudModel::iter().filter_map(move |model| {
+            if let CloudModel::Custom(_) = model {
+                Some(CloudModel::Custom(custom_model.take()?))
             } else {
                 Some(model)
             }
@@ -71,7 +71,7 @@ impl ZedDotDevCompletionProvider {
         self.settings_version
     }
 
-    pub fn model(&self) -> ZedDotDevModel {
+    pub fn model(&self) -> CloudModel {
         self.model.clone()
     }
 
@@ -94,21 +94,19 @@ impl ZedDotDevCompletionProvider {
         cx: &AppContext,
     ) -> BoxFuture<'static, Result<usize>> {
         match request.model {
-            LanguageModel::ZedDotDev(ZedDotDevModel::Gpt4)
-            | LanguageModel::ZedDotDev(ZedDotDevModel::Gpt4Turbo)
-            | LanguageModel::ZedDotDev(ZedDotDevModel::Gpt4Omni)
-            | LanguageModel::ZedDotDev(ZedDotDevModel::Gpt3Point5Turbo) => {
+            LanguageModel::Cloud(CloudModel::Gpt4)
+            | LanguageModel::Cloud(CloudModel::Gpt4Turbo)
+            | LanguageModel::Cloud(CloudModel::Gpt4Omni)
+            | LanguageModel::Cloud(CloudModel::Gpt3Point5Turbo) => {
                 count_open_ai_tokens(request, cx.background_executor())
             }
-            LanguageModel::ZedDotDev(
-                ZedDotDevModel::Claude3Opus
-                | ZedDotDevModel::Claude3Sonnet
-                | ZedDotDevModel::Claude3Haiku,
+            LanguageModel::Cloud(
+                CloudModel::Claude3Opus | CloudModel::Claude3Sonnet | CloudModel::Claude3Haiku,
             ) => {
                 // Can't find a tokenizer for Claude 3, so for now just use the same as OpenAI's as an approximation.
                 count_open_ai_tokens(request, cx.background_executor())
             }
-            LanguageModel::ZedDotDev(ZedDotDevModel::Custom(model)) => {
+            LanguageModel::Cloud(CloudModel::Custom(model)) => {
                 let request = self.client.request(proto::CountTokensWithLanguageModel {
                     model,
                     messages: request
@@ -129,8 +127,10 @@ impl ZedDotDevCompletionProvider {
 
     pub fn complete(
         &self,
-        request: LanguageModelRequest,
+        mut request: LanguageModelRequest,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+        request.preprocess();
+
         let request = proto::CompleteWithLanguageModel {
             model: request.model.id().to_string(),
             messages: request

crates/assistant/src/completion_provider/open_ai.rs 🔗

@@ -1,4 +1,4 @@
-use crate::assistant_settings::ZedDotDevModel;
+use crate::assistant_settings::CloudModel;
 use crate::{
     assistant_settings::OpenAiModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role,
 };
@@ -210,9 +210,9 @@ pub fn count_open_ai_tokens(
 
             match request.model {
                 LanguageModel::Anthropic(_)
-                | LanguageModel::ZedDotDev(ZedDotDevModel::Claude3Opus)
-                | LanguageModel::ZedDotDev(ZedDotDevModel::Claude3Sonnet)
-                | LanguageModel::ZedDotDev(ZedDotDevModel::Claude3Haiku) => {
+                | LanguageModel::Cloud(CloudModel::Claude3Opus)
+                | LanguageModel::Cloud(CloudModel::Claude3Sonnet)
+                | LanguageModel::Cloud(CloudModel::Claude3Haiku) => {
                     // Tiktoken doesn't yet support these models, so we manually use the
                     // same tokenizer as GPT-4.
                     tiktoken_rs::num_tokens_from_messages("gpt-4", &messages)

crates/editor/src/editor.rs 🔗

@@ -554,6 +554,20 @@ pub struct GutterDimensions {
     pub git_blame_entries_width: Option<Pixels>,
 }
 
+impl GutterDimensions {
+    /// The full width of the space taken up by the gutter.
+    pub fn full_width(&self) -> Pixels {
+        self.margin + self.width
+    }
+
+    /// The width of the space reserved for the fold indicators,
+    /// use alongside 'justify_end' and `gutter_width` to
+    /// right align content with the line numbers
+    pub fn fold_area_width(&self) -> Pixels {
+        self.margin + self.right_padding
+    }
+}
+
 impl Default for GutterDimensions {
     fn default() -> Self {
         Self {

crates/editor/src/element.rs 🔗

@@ -1125,9 +1125,7 @@ impl EditorElement {
                     ix as f32 * line_height - (scroll_pixel_position.y % line_height),
                 );
                 let centering_offset = point(
-                    (gutter_dimensions.right_padding + gutter_dimensions.margin
-                        - fold_indicator_size.width)
-                        / 2.,
+                    (gutter_dimensions.fold_area_width() - fold_indicator_size.width) / 2.,
                     (line_height - fold_indicator_size.height) / 2.,
                 );
                 let origin = gutter_hitbox.origin + position + centering_offset;
@@ -4629,7 +4627,7 @@ impl Element for EditorElement {
                             &mut scroll_width,
                             &gutter_dimensions,
                             em_width,
-                            gutter_dimensions.width + gutter_dimensions.margin,
+                            gutter_dimensions.full_width(),
                             line_height,
                             &line_layouts,
                             cx,

crates/editor/src/hunk_diff.rs 🔗

@@ -320,7 +320,7 @@ impl Editor {
                     div()
                         .bg(deleted_hunk_color)
                         .size_full()
-                        .pl(gutter_dimensions.width + gutter_dimensions.margin)
+                        .pl(gutter_dimensions.full_width())
                         .child(editor_with_deleted_text.clone())
                         .into_any_element()
                 }),