Use tool calling instead of XML parsing to generate edit operations (#15385)

Created by Antonio Scandurra and Nathan

Release Notes:

- N/A

---------

Co-authored-by: Nathan <nathan@zed.dev>
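
This replaces the previous flow, in which the model emitted an `<operations>` XML block that we parsed with roxmltree, with Anthropic's native tool calling: we advertise an `edit` tool whose input schema is derived from `EditOperation` via schemars, force the model to invoke it, and receive the operations back as structured JSON. As a minimal sketch of the new request shape, using the types added in crates/anthropic (the model id and inline schema below are placeholders; in the real code the schema is generated from `EditTool`):

use anthropic::{complete, Content, Message, Request, Role, Tool, ToolChoice, ANTHROPIC_API_URL};
use http_client::HttpClient;

async fn request_edits(client: &dyn HttpClient, api_key: &str) -> anyhow::Result<()> {
    let request = Request {
        model: "claude-3-5-sonnet-20240620".into(), // placeholder model id
        max_tokens: 4096,
        messages: vec![Message {
            role: Role::User,
            content: vec![Content::Text {
                text: "Make all fields of the Canvas struct private".into(),
            }],
        }],
        // Constrain the reply to a call of the `edit` tool, so the response
        // arrives as JSON matching the schema instead of free-form text.
        tools: vec![Tool {
            name: "edit".into(),
            description: "suggest edits to one or more locations in the codebase".into(),
            input_schema: serde_json::json!({ "type": "object" }), // placeholder schema
        }],
        tool_choice: Some(ToolChoice::Tool { name: "edit".into() }),
        system: None,
        metadata: None,
        stop_sequences: Vec::new(),
        temperature: None,
        top_k: None,
        top_p: None,
    };

    let response = complete(client, ANTHROPIC_API_URL, api_key, request).await?;
    for content in response.content {
        if let Content::ToolUse { input, .. } = content {
            // `input` is a serde_json::Value; the assistant deserializes it
            // into EditTool { operations: Vec<EditOperation> }.
            println!("{input}");
        }
    }
    Ok(())
}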

Change summary

Cargo.lock                                        |  13 
crates/anthropic/src/anthropic.rs                 | 341 ++++++----
crates/assistant/Cargo.toml                       |   1 
crates/assistant/src/assistant_panel.rs           |  39 
crates/assistant/src/context.rs                   | 340 +++-------
crates/assistant/src/inline_assistant.rs          | 511 +++++++++-------
crates/assistant/src/prompt_library.rs            |  42 
crates/assistant/src/prompts.rs                   |  13 
crates/assistant/src/terminal_inline_assistant.rs |  21 
crates/collab/src/rpc.rs                          | 240 +++++--
crates/completion/Cargo.toml                      |   2 
crates/completion/src/completion.rs               |  38 +
crates/language_model/src/language_model.rs       |  16 
crates/language_model/src/provider/anthropic.rs   | 132 +++
crates/language_model/src/provider/cloud.rs       | 119 ++-
crates/language_model/src/provider/fake.rs        |  23 
crates/language_model/src/provider/google.rs      |  13 
crates/language_model/src/provider/ollama.rs      |  13 
crates/language_model/src/provider/open_ai.rs     |  13 
crates/language_model/src/request.rs              |  16 
crates/proto/proto/zed.proto                      |  45 +
crates/proto/src/proto.rs                         |  15 
22 files changed, 1,154 insertions(+), 852 deletions(-)

Detailed changes

Cargo.lock

@@ -435,7 +435,6 @@ dependencies = [
  "rand 0.8.5",
  "regex",
  "rope",
- "roxmltree 0.20.0",
  "schemars",
  "search",
  "semantic_index",
@@ -2641,7 +2640,9 @@ dependencies = [
  "language_model",
  "project",
  "rand 0.8.5",
+ "schemars",
  "serde",
+ "serde_json",
  "settings",
  "smol",
  "text",
@@ -4237,7 +4238,7 @@ version = "0.5.6"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "6a595cb550439a117696039dfc69830492058211b771a2a165379f2a1a53d84d"
 dependencies = [
- "roxmltree 0.19.0",
+ "roxmltree",
 ]
 
 [[package]]
@@ -8918,12 +8919,6 @@ version = "0.19.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "3cd14fd5e3b777a7422cca79358c57a8f6e3a703d9ac187448d0daf220c2407f"
 
-[[package]]
-name = "roxmltree"
-version = "0.20.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6c20b6793b5c2fa6553b250154b78d6d0db37e72700ae35fad9387a46f487c97"
-
 [[package]]
 name = "rpc"
 version = "0.1.0"
@@ -11878,7 +11873,7 @@ dependencies = [
  "kurbo",
  "log",
  "pico-args",
- "roxmltree 0.19.0",
+ "roxmltree",
  "simplecss",
  "siphasher 1.0.1",
  "strict-num",

crates/anthropic/src/anthropic.rs

@@ -3,7 +3,7 @@ use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, S
 use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
 use isahc::config::Configurable;
 use serde::{Deserialize, Serialize};
-use std::{convert::TryFrom, time::Duration};
+use std::time::Duration;
 use strum::EnumIter;
 
 pub const ANTHROPIC_API_URL: &'static str = "https://api.anthropic.com";
@@ -70,112 +70,53 @@ impl Model {
     }
 }
 
-#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
-#[serde(rename_all = "lowercase")]
-pub enum Role {
-    User,
-    Assistant,
-}
+pub async fn complete(
+    client: &dyn HttpClient,
+    api_url: &str,
+    api_key: &str,
+    request: Request,
+) -> Result<Response> {
+    let uri = format!("{api_url}/v1/messages");
+    let request_builder = HttpRequest::builder()
+        .method(Method::POST)
+        .uri(uri)
+        .header("Anthropic-Version", "2023-06-01")
+        .header("Anthropic-Beta", "tools-2024-04-04")
+        .header("X-Api-Key", api_key)
+        .header("Content-Type", "application/json");
 
-impl TryFrom<String> for Role {
-    type Error = anyhow::Error;
+    let serialized_request = serde_json::to_string(&request)?;
+    let request = request_builder.body(AsyncBody::from(serialized_request))?;
 
-    fn try_from(value: String) -> Result<Self> {
-        match value.as_str() {
-            "user" => Ok(Self::User),
-            "assistant" => Ok(Self::Assistant),
-            _ => Err(anyhow!("invalid role '{value}'")),
-        }
-    }
-}
-
-impl From<Role> for String {
-    fn from(val: Role) -> Self {
-        match val {
-            Role::User => "user".to_owned(),
-            Role::Assistant => "assistant".to_owned(),
-        }
+    let mut response = client.send(request).await?;
+    if response.status().is_success() {
+        let mut body = Vec::new();
+        response.body_mut().read_to_end(&mut body).await?;
+        let response_message: Response = serde_json::from_slice(&body)?;
+        Ok(response_message)
+    } else {
+        let mut body = Vec::new();
+        response.body_mut().read_to_end(&mut body).await?;
+        let body_str = std::str::from_utf8(&body)?;
+        Err(anyhow!(
+            "Failed to connect to API: {} {}",
+            response.status(),
+            body_str
+        ))
     }
 }
 
-#[derive(Debug, Serialize, Deserialize)]
-pub struct Request {
-    pub model: String,
-    pub messages: Vec<RequestMessage>,
-    pub stream: bool,
-    pub system: String,
-    pub max_tokens: u32,
-}
-
-#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
-pub struct RequestMessage {
-    pub role: Role,
-    pub content: String,
-}
-
-#[derive(Deserialize, Serialize, Debug)]
-#[serde(tag = "type", rename_all = "snake_case")]
-pub enum ResponseEvent {
-    MessageStart {
-        message: ResponseMessage,
-    },
-    ContentBlockStart {
-        index: u32,
-        content_block: ContentBlock,
-    },
-    Ping {},
-    ContentBlockDelta {
-        index: u32,
-        delta: TextDelta,
-    },
-    ContentBlockStop {
-        index: u32,
-    },
-    MessageDelta {
-        delta: ResponseMessage,
-        usage: Usage,
-    },
-    MessageStop {},
-}
-
-#[derive(Serialize, Deserialize, Debug)]
-pub struct ResponseMessage {
-    #[serde(rename = "type")]
-    pub message_type: Option<String>,
-    pub id: Option<String>,
-    pub role: Option<String>,
-    pub content: Option<Vec<String>>,
-    pub model: Option<String>,
-    pub stop_reason: Option<String>,
-    pub stop_sequence: Option<String>,
-    pub usage: Option<Usage>,
-}
-
-#[derive(Serialize, Deserialize, Debug)]
-pub struct Usage {
-    pub input_tokens: Option<u32>,
-    pub output_tokens: Option<u32>,
-}
-
-#[derive(Serialize, Deserialize, Debug)]
-#[serde(tag = "type", rename_all = "snake_case")]
-pub enum ContentBlock {
-    Text { text: String },
-}
-
-#[derive(Serialize, Deserialize, Debug)]
-#[serde(tag = "type", rename_all = "snake_case")]
-pub enum TextDelta {
-    TextDelta { text: String },
-}
-
 pub async fn stream_completion(
     client: &dyn HttpClient,
     api_url: &str,
     api_key: &str,
     request: Request,
     low_speed_timeout: Option<Duration>,
-) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
+) -> Result<BoxStream<'static, Result<Event>>> {
+    let request = StreamingRequest {
+        base: request,
+        stream: true,
+    };
     let uri = format!("{api_url}/v1/messages");
     let mut request_builder = HttpRequest::builder()
         .method(Method::POST)
@@ -187,7 +128,9 @@ pub async fn stream_completion(
     if let Some(low_speed_timeout) = low_speed_timeout {
         request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
     }
-    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
+    let serialized_request = serde_json::to_string(&request)?;
+    let request = request_builder.body(AsyncBody::from(serialized_request))?;
+
     let mut response = client.send(request).await?;
     if response.status().is_success() {
         let reader = BufReader::new(response.into_body());
@@ -212,7 +155,7 @@ pub async fn stream_completion(
 
         let body_str = std::str::from_utf8(&body)?;
 
-        match serde_json::from_str::<ResponseEvent>(body_str) {
+        match serde_json::from_str::<Event>(body_str) {
             Ok(_) => Err(anyhow!(
                 "Unexpected success response while expecting an error: {}",
                 body_str,
@@ -227,16 +170,18 @@ pub async fn stream_completion(
 }
 
 pub fn extract_text_from_events(
-    response: impl Stream<Item = Result<ResponseEvent>>,
+    response: impl Stream<Item = Result<Event>>,
 ) -> impl Stream<Item = Result<String>> {
     response.filter_map(|response| async move {
         match response {
             Ok(response) => match response {
-                ResponseEvent::ContentBlockStart { content_block, .. } => match content_block {
-                    ContentBlock::Text { text } => Some(Ok(text)),
+                Event::ContentBlockStart { content_block, .. } => match content_block {
+                    Content::Text { text } => Some(Ok(text)),
+                    _ => None,
                 },
-                ResponseEvent::ContentBlockDelta { delta, .. } => match delta {
-                    TextDelta::TextDelta { text } => Some(Ok(text)),
+                Event::ContentBlockDelta { delta, .. } => match delta {
+                    ContentDelta::TextDelta { text } => Some(Ok(text)),
+                    _ => None,
                 },
                 _ => None,
             },
@@ -245,42 +190,162 @@ pub fn extract_text_from_events(
     })
 }
 
-// #[cfg(test)]
-// mod tests {
-//     use super::*;
-//     use http::IsahcHttpClient;
+#[derive(Debug, Serialize, Deserialize)]
+pub struct Message {
+    pub role: Role,
+    pub content: Vec<Content>,
+}
 
-//     #[tokio::test]
-//     async fn stream_completion_success() {
-//         let http_client = IsahcHttpClient::new().unwrap();
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "lowercase")]
+pub enum Role {
+    User,
+    Assistant,
+}
 
-//         let request = Request {
-//             model: Model::Claude3Opus,
-//             messages: vec![RequestMessage {
-//                 role: Role::User,
-//                 content: "Ping".to_string(),
-//             }],
-//             stream: true,
-//             system: "Respond to ping with pong".to_string(),
-//             max_tokens: 4096,
-//         };
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(tag = "type")]
+pub enum Content {
+    #[serde(rename = "text")]
+    Text { text: String },
+    #[serde(rename = "image")]
+    Image { source: ImageSource },
+    #[serde(rename = "tool_use")]
+    ToolUse {
+        id: String,
+        name: String,
+        input: serde_json::Value,
+    },
+    #[serde(rename = "tool_result")]
+    ToolResult {
+        tool_use_id: String,
+        content: String,
+    },
+}
 
-//         let stream = stream_completion(
-//             &http_client,
-//             "https://api.anthropic.com",
-//             &std::env::var("ANTHROPIC_API_KEY").expect("ANTHROPIC_API_KEY not set"),
-//             request,
-//         )
-//         .await
-//         .unwrap();
+#[derive(Debug, Serialize, Deserialize)]
+pub struct ImageSource {
+    #[serde(rename = "type")]
+    pub source_type: String,
+    pub media_type: String,
+    pub data: String,
+}
 
-//         stream
-//             .for_each(|event| async {
-//                 match event {
-//                     Ok(event) => println!("{:?}", event),
-//                     Err(e) => eprintln!("Error: {:?}", e),
-//                 }
-//             })
-//             .await;
-//     }
-// }
+#[derive(Debug, Serialize, Deserialize)]
+pub struct Tool {
+    pub name: String,
+    pub description: String,
+    pub input_schema: serde_json::Value,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(tag = "type", rename_all = "lowercase")]
+pub enum ToolChoice {
+    Auto,
+    Any,
+    Tool { name: String },
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct Request {
+    pub model: String,
+    pub max_tokens: u32,
+    pub messages: Vec<Message>,
+    #[serde(default, skip_serializing_if = "Vec::is_empty")]
+    pub tools: Vec<Tool>,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub tool_choice: Option<ToolChoice>,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub system: Option<String>,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub metadata: Option<Metadata>,
+    #[serde(default, skip_serializing_if = "Vec::is_empty")]
+    pub stop_sequences: Vec<String>,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub temperature: Option<f32>,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub top_k: Option<u32>,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub top_p: Option<f32>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+struct StreamingRequest {
+    #[serde(flatten)]
+    pub base: Request,
+    pub stream: bool,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct Metadata {
+    pub user_id: Option<String>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct Usage {
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub input_tokens: Option<u32>,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub output_tokens: Option<u32>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct Response {
+    pub id: String,
+    #[serde(rename = "type")]
+    pub response_type: String,
+    pub role: Role,
+    pub content: Vec<Content>,
+    pub model: String,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub stop_reason: Option<String>,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub stop_sequence: Option<String>,
+    pub usage: Usage,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(tag = "type")]
+pub enum Event {
+    #[serde(rename = "message_start")]
+    MessageStart { message: Response },
+    #[serde(rename = "content_block_start")]
+    ContentBlockStart {
+        index: usize,
+        content_block: Content,
+    },
+    #[serde(rename = "content_block_delta")]
+    ContentBlockDelta { index: usize, delta: ContentDelta },
+    #[serde(rename = "content_block_stop")]
+    ContentBlockStop { index: usize },
+    #[serde(rename = "message_delta")]
+    MessageDelta { delta: MessageDelta, usage: Usage },
+    #[serde(rename = "message_stop")]
+    MessageStop,
+    #[serde(rename = "ping")]
+    Ping,
+    #[serde(rename = "error")]
+    Error { error: ApiError },
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(tag = "type")]
+pub enum ContentDelta {
+    #[serde(rename = "text_delta")]
+    TextDelta { text: String },
+    #[serde(rename = "input_json_delta")]
+    InputJsonDelta { partial_json: String },
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct MessageDelta {
+    pub stop_reason: Option<String>,
+    pub stop_sequence: Option<String>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct ApiError {
+    #[serde(rename = "type")]
+    pub error_type: String,
+    pub message: String,
+}
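
When streaming, tool arguments don't arrive in one piece: `content_block_delta` events carry `input_json_delta` fragments that must be concatenated before the JSON can be parsed. A minimal sketch of that accumulation (this helper is illustrative, not part of the change):

use anthropic::{ContentDelta, Event};
use futures::{Stream, StreamExt};

// Concatenate the partial JSON fragments for a tool call into a single
// string that can then be handed to serde_json. Assumes the response
// contains a single tool_use content block.
async fn collect_tool_input(
    mut events: impl Stream<Item = anyhow::Result<Event>> + Unpin,
) -> anyhow::Result<String> {
    let mut json = String::new();
    while let Some(event) = events.next().await {
        match event? {
            Event::ContentBlockDelta {
                delta: ContentDelta::InputJsonDelta { partial_json },
                ..
            } => json.push_str(&partial_json),
            Event::MessageStop => break,
            _ => {}
        }
    }
    Ok(json)
}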

crates/assistant/Cargo.toml

@@ -75,7 +75,6 @@ util.workspace = true
 uuid.workspace = true
 workspace.workspace = true
 picker.workspace = true
-roxmltree = "0.20.0"
 
 [dev-dependencies]
 completion = { workspace = true, features = ["test-support"] }

crates/assistant/src/assistant_panel.rs

@@ -1232,12 +1232,16 @@ impl ContextEditor {
 
     fn apply_edit_step(&mut self, cx: &mut ViewContext<Self>) -> bool {
         if let Some(step) = self.active_edit_step.as_ref() {
-            InlineAssistant::update_global(cx, |assistant, cx| {
-                for assist_id in &step.assist_ids {
-                    assistant.start_assist(*assist_id, cx);
-                }
-                !step.assist_ids.is_empty()
-            })
+            let assist_ids = step.assist_ids.clone();
+            cx.window_context().defer(|cx| {
+                InlineAssistant::update_global(cx, |assistant, cx| {
+                    for assist_id in assist_ids {
+                        assistant.start_assist(assist_id, cx);
+                    }
+                })
+            });
+
+            !step.assist_ids.is_empty()
         } else {
             false
         }
@@ -1286,11 +1290,7 @@ impl ContextEditor {
                     .collect::<String>()
             ));
             match &step.operations {
-                Some(EditStepOperations::Parsed {
-                    operations,
-                    raw_output,
-                }) => {
-                    output.push_str(&format!("Raw Output:\n{raw_output}\n"));
+                Some(EditStepOperations::Ready(operations)) => {
                     output.push_str("Parsed Operations:\n");
                     for op in operations {
                         output.push_str(&format!("  {:?}\n", op));
@@ -1794,13 +1794,12 @@ impl ContextEditor {
                                     .anchor_in_excerpt(excerpt_id, suggestion.range.end)
                                     .unwrap()
                         };
-                        let initial_text = suggestion.prepend_newline.then(|| "\n".into());
                         InlineAssistant::update_global(cx, |assistant, cx| {
                             assist_ids.push(assistant.suggest_assist(
                                 &editor,
                                 range,
                                 description,
-                                initial_text,
+                                suggestion.initial_insertion,
                                 Some(workspace.clone()),
                                 assistant_panel.upgrade().as_ref(),
                                 cx,
@@ -1862,9 +1861,11 @@ impl ContextEditor {
                                             .anchor_in_excerpt(excerpt_id, suggestion.range.end)
                                             .unwrap()
                                 };
-                                let initial_text =
-                                    suggestion.prepend_newline.then(|| "\n".to_string());
-                                inline_assist_suggestions.push((range, description, initial_text));
+                                inline_assist_suggestions.push((
+                                    range,
+                                    description,
+                                    suggestion.initial_insertion,
+                                ));
                             }
                         }
                     }
@@ -1875,12 +1876,12 @@ impl ContextEditor {
                     .new_view(|cx| Editor::for_multibuffer(multibuffer, Some(project), true, cx))?;
                 cx.update(|cx| {
                     InlineAssistant::update_global(cx, |assistant, cx| {
-                        for (range, description, initial_text) in inline_assist_suggestions {
+                        for (range, description, initial_insertion) in inline_assist_suggestions {
                             assist_ids.push(assistant.suggest_assist(
                                 &editor,
                                 range,
                                 description,
-                                initial_text,
+                                initial_insertion,
                                 Some(workspace.clone()),
                                 assistant_panel.upgrade().as_ref(),
                                 cx,
@@ -2188,7 +2189,7 @@ impl ContextEditor {
         let button_text = match self.edit_step_for_cursor(cx) {
             Some(edit_step) => match &edit_step.operations {
                 Some(EditStepOperations::Pending(_)) => "Computing Changes...",
-                Some(EditStepOperations::Parsed { .. }) => "Apply Changes",
+                Some(EditStepOperations::Ready(_)) => "Apply Changes",
                 None => "Send",
             },
             None => "Send",

crates/assistant/src/context.rs

@@ -1,6 +1,6 @@
 use crate::{
-    prompt_library::PromptStore, slash_command::SlashCommandLine, LanguageModelCompletionProvider,
-    MessageId, MessageStatus,
+    prompt_library::PromptStore, slash_command::SlashCommandLine, InitialInsertion,
+    LanguageModelCompletionProvider, MessageId, MessageStatus,
 };
 use anyhow::{anyhow, Context as _, Result};
 use assistant_slash_command::{
@@ -18,11 +18,11 @@ use gpui::{AppContext, Context as _, EventEmitter, Model, ModelContext, Subscrip
 use language::{
     AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, ParseStatus, Point, ToOffset,
 };
-use language_model::LanguageModelRequestMessage;
-use language_model::{LanguageModelRequest, Role};
+use language_model::{LanguageModelRequest, LanguageModelRequestMessage, LanguageModelTool, Role};
 use open_ai::Model as OpenAiModel;
 use paths::contexts_dir;
 use project::Project;
+use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use std::{
     cmp,
@@ -352,7 +352,7 @@ pub struct EditSuggestion {
     pub range: Range<language::Anchor>,
     /// If None, assume this is a suggestion to delete the range rather than transform it.
     pub description: Option<String>,
-    pub prepend_newline: bool,
+    pub initial_insertion: Option<InitialInsertion>,
 }
 
 impl EditStep {
@@ -361,7 +361,7 @@ impl EditStep {
         project: &Model<Project>,
         cx: &AppContext,
     ) -> Task<HashMap<Model<Buffer>, Vec<EditSuggestionGroup>>> {
-        let Some(EditStepOperations::Parsed { operations, .. }) = &self.operations else {
+        let Some(EditStepOperations::Ready(operations)) = &self.operations else {
             return Task::ready(HashMap::default());
         };
 
@@ -471,32 +471,28 @@ impl EditStep {
 }
 
 pub enum EditStepOperations {
-    Pending(Task<Result<()>>),
-    Parsed {
-        operations: Vec<EditOperation>,
-        raw_output: String,
-    },
+    Pending(Task<Option<()>>),
+    Ready(Vec<EditOperation>),
 }
 
 impl Debug for EditStepOperations {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         match self {
             EditStepOperations::Pending(_) => write!(f, "EditStepOperations::Pending"),
-            EditStepOperations::Parsed {
-                operations,
-                raw_output,
-            } => f
+            EditStepOperations::Ready(operations) => f
                 .debug_struct("EditStepOperations::Parsed")
                 .field("operations", operations)
-                .field("raw_output", raw_output)
                 .finish(),
         }
     }
 }
 
-#[derive(Clone, Debug, PartialEq, Eq)]
+/// A description of an operation to apply to one location in the codebase.
+#[derive(Clone, Debug, PartialEq, Eq, Deserialize, JsonSchema)]
 pub struct EditOperation {
+    /// The path to the file containing the relevant operation
     pub path: String,
+    #[serde(flatten)]
     pub kind: EditOperationKind,
 }
 
@@ -523,7 +519,7 @@ impl EditOperation {
                 parse_status.changed().await?;
             }
 
-            let prepend_newline = kind.prepend_newline();
+            let initial_insertion = kind.initial_insertion();
             let suggestion_range = if let Some(symbol) = kind.symbol() {
                 let outline = buffer
                     .update(&mut cx, |buffer, _| buffer.snapshot().outline(None))?
@@ -601,39 +597,61 @@ impl EditOperation {
                 EditSuggestion {
                     range: suggestion_range,
                     description: kind.description().map(ToString::to_string),
-                    prepend_newline,
+                    initial_insertion,
                 },
             ))
         })
     }
 }
 
-#[derive(Clone, Debug, PartialEq, Eq)]
+#[derive(Clone, Debug, PartialEq, Eq, Deserialize, JsonSchema)]
+#[serde(tag = "kind")]
 pub enum EditOperationKind {
+    /// Rewrite the specified symbol in its entirety based on the given description.
     Update {
+        /// A full path to the symbol to be rewritten from the provided list.
         symbol: String,
+        /// A brief one-line description of the change that should be applied.
         description: String,
     },
+    /// Create a new file with the given path based on the given description.
     Create {
+        /// A brief one-line description of the change that should be applied.
         description: String,
     },
+    /// Insert a new symbol based on the given description before the specified symbol.
     InsertSiblingBefore {
+        /// A full path to the symbol to be rewritten from the provided list.
         symbol: String,
+        /// A brief one-line description of the change that should be applied.
         description: String,
     },
+    /// Insert a new symbol based on the given description after the specified symbol.
     InsertSiblingAfter {
+        /// A full path to the symbol to be rewritten from the provided list.
         symbol: String,
+        /// A brief one-line description of the change that should be applied.
         description: String,
     },
+    /// Insert a new symbol as a child of the specified symbol at the start.
     PrependChild {
+        /// An optional full path to the symbol to be rewritten from the provided list.
+        /// If not provided, the edit should be applied at the top of the file.
         symbol: Option<String>,
+        /// A brief one-line description of the change that should be applied.
         description: String,
     },
+    /// Insert a new symbol as a child of the specified symbol at the end.
     AppendChild {
+        /// An optional full path to the symbol to be rewritten from the provided list.
+        /// If not provided, the edit should be applied at the end of the file.
         symbol: Option<String>,
+        /// A brief one-line description of the change that should be applied.
         description: String,
     },
+    /// Delete the specified symbol.
     Delete {
+        /// A full path to the symbol to be rewritten from the provided list.
         symbol: String,
     },
 }
@@ -663,13 +681,13 @@ impl EditOperationKind {
         }
     }
 
-    pub fn prepend_newline(&self) -> bool {
+    pub fn initial_insertion(&self) -> Option<InitialInsertion> {
         match self {
-            Self::PrependChild { .. }
-            | Self::AppendChild { .. }
-            | Self::InsertSiblingAfter { .. }
-            | Self::InsertSiblingBefore { .. } => true,
-            _ => false,
+            EditOperationKind::InsertSiblingBefore { .. } => Some(InitialInsertion::NewlineAfter),
+            EditOperationKind::InsertSiblingAfter { .. } => Some(InitialInsertion::NewlineBefore),
+            EditOperationKind::PrependChild { .. } => Some(InitialInsertion::NewlineAfter),
+            EditOperationKind::AppendChild { .. } => Some(InitialInsertion::NewlineBefore),
+            _ => None,
         }
     }
 }
@@ -1137,18 +1155,15 @@ impl Context {
                     .timer(Duration::from_millis(200))
                     .await;
 
-                if let Some(token_count) = cx.update(|cx| {
-                    LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
-                })? {
-                    let token_count = token_count.await?;
-
-                    this.update(&mut cx, |this, cx| {
-                        this.token_count = Some(token_count);
-                        cx.notify()
-                    })?;
-                }
-
-                anyhow::Ok(())
+                let token_count = cx
+                    .update(|cx| {
+                        LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
+                    })?
+                    .await?;
+                this.update(&mut cx, |this, cx| {
+                    this.token_count = Some(token_count);
+                    cx.notify()
+                })
             }
             .log_err()
         });
@@ -1304,7 +1319,24 @@ impl Context {
         &self,
         edit_step: &EditStep,
         cx: &mut ModelContext<Self>,
-    ) -> Task<Result<()>> {
+    ) -> Task<Option<()>> {
+        #[derive(Debug, Deserialize, JsonSchema)]
+        struct EditTool {
+            /// A sequence of operations to apply to the codebase.
+            /// When multiple operations are required for a step, be sure to include multiple operations in this list.
+            operations: Vec<EditOperation>,
+        }
+
+        impl LanguageModelTool for EditTool {
+            fn name() -> String {
+                "edit".into()
+            }
+
+            fn description() -> String {
+                "suggest edits to one or more locations in the codebase".into()
+            }
+        }
+
         let mut request = self.to_completion_request(cx);
         let edit_step_range = edit_step.source_range.clone();
         let step_text = self
@@ -1313,160 +1345,41 @@ impl Context {
             .text_for_range(edit_step_range.clone())
             .collect::<String>();
 
-        cx.spawn(|this, mut cx| async move {
-            let prompt_store = cx.update(|cx| PromptStore::global(cx))?.await?;
-
-            let mut prompt = prompt_store.operations_prompt();
-            prompt.push_str(&step_text);
+        cx.spawn(|this, mut cx| {
+            async move {
+                let prompt_store = cx.update(|cx| PromptStore::global(cx))?.await?;
 
-            request.messages.push(LanguageModelRequestMessage {
-                role: Role::User,
-                content: prompt,
-            });
+                let mut prompt = prompt_store.operations_prompt();
+                prompt.push_str(&step_text);
 
-            let raw_output = cx
-                .update(|cx| {
-                    LanguageModelCompletionProvider::read_global(cx).complete(request, cx)
-                })?
-                .await?;
+                request.messages.push(LanguageModelRequestMessage {
+                    role: Role::User,
+                    content: prompt,
+                });
 
-            let operations = Self::parse_edit_operations(&raw_output);
-            this.update(&mut cx, |this, cx| {
-                let step_index = this
-                    .edit_steps
-                    .binary_search_by(|step| {
-                        step.source_range
-                            .cmp(&edit_step_range, this.buffer.read(cx))
-                    })
-                    .map_err(|_| anyhow!("edit step not found"))?;
-                if let Some(edit_step) = this.edit_steps.get_mut(step_index) {
-                    edit_step.operations = Some(EditStepOperations::Parsed {
-                        operations,
-                        raw_output,
-                    });
-                    cx.emit(ContextEvent::EditStepsChanged);
-                }
-                anyhow::Ok(())
-            })?
-        })
-    }
+                let tool_use = cx
+                    .update(|cx| {
+                        LanguageModelCompletionProvider::read_global(cx)
+                            .use_tool::<EditTool>(request, cx)
+                    })?
+                    .await?;
 
-    fn parse_edit_operations(xml: &str) -> Vec<EditOperation> {
-        let Some(start_ix) = xml.find("<operations>") else {
-            return Vec::new();
-        };
-        let Some(end_ix) = xml[start_ix..].find("</operations>") else {
-            return Vec::new();
-        };
-        let end_ix = end_ix + start_ix + "</operations>".len();
-
-        let doc = roxmltree::Document::parse(&xml[start_ix..end_ix]).log_err();
-        doc.map_or(Vec::new(), |doc| {
-            doc.root_element()
-                .children()
-                .map(|node| {
-                    let tag_name = node.tag_name().name();
-                    let path = node
-                        .attribute("path")
-                        .with_context(|| {
-                            format!("invalid node {node:?}, missing attribute 'path'")
-                        })?
-                        .to_string();
-                    let kind = match tag_name {
-                        "update" => EditOperationKind::Update {
-                            symbol: node
-                                .attribute("symbol")
-                                .with_context(|| {
-                                    format!("invalid node {node:?}, missing attribute 'symbol'")
-                                })?
-                                .to_string(),
-                            description: node
-                                .attribute("description")
-                                .with_context(|| {
-                                    format!(
-                                        "invalid node {node:?}, missing attribute 'description'"
-                                    )
-                                })?
-                                .to_string(),
-                        },
-                        "create" => EditOperationKind::Create {
-                            description: node
-                                .attribute("description")
-                                .with_context(|| {
-                                    format!(
-                                        "invalid node {node:?}, missing attribute 'description'"
-                                    )
-                                })?
-                                .to_string(),
-                        },
-                        "insert_sibling_after" => EditOperationKind::InsertSiblingAfter {
-                            symbol: node
-                                .attribute("symbol")
-                                .with_context(|| {
-                                    format!("invalid node {node:?}, missing attribute 'symbol'")
-                                })?
-                                .to_string(),
-                            description: node
-                                .attribute("description")
-                                .with_context(|| {
-                                    format!(
-                                        "invalid node {node:?}, missing attribute 'description'"
-                                    )
-                                })?
-                                .to_string(),
-                        },
-                        "insert_sibling_before" => EditOperationKind::InsertSiblingBefore {
-                            symbol: node
-                                .attribute("symbol")
-                                .with_context(|| {
-                                    format!("invalid node {node:?}, missing attribute 'symbol'")
-                                })?
-                                .to_string(),
-                            description: node
-                                .attribute("description")
-                                .with_context(|| {
-                                    format!(
-                                        "invalid node {node:?}, missing attribute 'description'"
-                                    )
-                                })?
-                                .to_string(),
-                        },
-                        "prepend_child" => EditOperationKind::PrependChild {
-                            symbol: node.attribute("symbol").map(String::from),
-                            description: node
-                                .attribute("description")
-                                .with_context(|| {
-                                    format!(
-                                        "invalid node {node:?}, missing attribute 'description'"
-                                    )
-                                })?
-                                .to_string(),
-                        },
-                        "append_child" => EditOperationKind::AppendChild {
-                            symbol: node.attribute("symbol").map(String::from),
-                            description: node
-                                .attribute("description")
-                                .with_context(|| {
-                                    format!(
-                                        "invalid node {node:?}, missing attribute 'description'"
-                                    )
-                                })?
-                                .to_string(),
-                        },
-                        "delete" => EditOperationKind::Delete {
-                            symbol: node
-                                .attribute("symbol")
-                                .with_context(|| {
-                                    format!("invalid node {node:?}, missing attribute 'symbol'")
-                                })?
-                                .to_string(),
-                        },
-                        _ => return Err(anyhow!("invalid node {node:?}")),
-                    };
-                    anyhow::Ok(EditOperation { path, kind })
-                })
-                .filter_map(|op| op.log_err())
-                .collect()
+                this.update(&mut cx, |this, cx| {
+                    let step_index = this
+                        .edit_steps
+                        .binary_search_by(|step| {
+                            step.source_range
+                                .cmp(&edit_step_range, this.buffer.read(cx))
+                        })
+                        .map_err(|_| anyhow!("edit step not found"))?;
+                    if let Some(edit_step) = this.edit_steps.get_mut(step_index) {
+                        edit_step.operations = Some(EditStepOperations::Ready(tool_use.operations));
+                        cx.emit(ContextEvent::EditStepsChanged);
+                    }
+                    anyhow::Ok(())
+                })?
+            }
+            .log_err()
         })
     }
 
@@ -3083,55 +2996,6 @@ mod tests {
         }
     }
 
-    #[test]
-    fn test_parse_edit_operations() {
-        let operations = indoc! {r#"
-            Here are the operations to make all fields of the Canvas struct private:
-
-            <operations>
-                <update path="font-kit/src/canvas.rs" symbol="pub struct Canvas pub pixels" description="Remove pub keyword from pixels field" />
-                <update path="font-kit/src/canvas.rs" symbol="pub struct Canvas pub size" description="Remove pub keyword from size field" />
-                <update path="font-kit/src/canvas.rs" symbol="pub struct Canvas pub stride" description="Remove pub keyword from stride field" />
-                <update path="font-kit/src/canvas.rs" symbol="pub struct Canvas pub format" description="Remove pub keyword from format field" />
-            </operations>
-        "#};
-
-        let parsed_operations = Context::parse_edit_operations(operations);
-        assert_eq!(
-            parsed_operations,
-            vec![
-                EditOperation {
-                    path: "font-kit/src/canvas.rs".to_string(),
-                    kind: EditOperationKind::Update {
-                        symbol: "pub struct Canvas pub pixels".to_string(),
-                        description: "Remove pub keyword from pixels field".to_string(),
-                    },
-                },
-                EditOperation {
-                    path: "font-kit/src/canvas.rs".to_string(),
-                    kind: EditOperationKind::Update {
-                        symbol: "pub struct Canvas pub size".to_string(),
-                        description: "Remove pub keyword from size field".to_string(),
-                    },
-                },
-                EditOperation {
-                    path: "font-kit/src/canvas.rs".to_string(),
-                    kind: EditOperationKind::Update {
-                        symbol: "pub struct Canvas pub stride".to_string(),
-                        description: "Remove pub keyword from stride field".to_string(),
-                    },
-                },
-                EditOperation {
-                    path: "font-kit/src/canvas.rs".to_string(),
-                    kind: EditOperationKind::Update {
-                        symbol: "pub struct Canvas pub format".to_string(),
-                        description: "Remove pub keyword from format field".to_string(),
-                    },
-                },
-            ]
-        );
-    }
-
     #[gpui::test]
     async fn test_serialization(cx: &mut TestAppContext) {
         let settings_store = cx.update(SettingsStore::test);
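
The `use_tool::<EditTool>` call above relies on plumbing in the completion crate that isn't shown in this summary, but the shape is simple: derive a JSON schema from the tool's input type, advertise it to the model, and deserialize the returned `tool_use` payload back into that type. A rough sketch, assuming the `LanguageModelTool` trait also requires `Deserialize` and `JsonSchema` (which the derives on `EditTool` suggest):

use schemars::{schema_for, JsonSchema};
use serde::de::DeserializeOwned;

// Mirrors the trait introduced in this change.
pub trait LanguageModelTool: DeserializeOwned + JsonSchema {
    fn name() -> String;
    fn description() -> String;
}

// Build the tool definition that is sent along with the request.
fn tool_definition<T: LanguageModelTool>() -> anyhow::Result<anthropic::Tool> {
    Ok(anthropic::Tool {
        name: T::name(),
        description: T::description(),
        input_schema: serde_json::to_value(schema_for!(T))?,
    })
}

// Parse the model's tool_use arguments back into the tool's input type.
fn parse_tool_use<T: LanguageModelTool>(input: serde_json::Value) -> anyhow::Result<T> {
    Ok(serde_json::from_value(input)?)
}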

crates/assistant/src/inline_assistant.rs

@@ -17,7 +17,7 @@ use editor::{
 use fs::Fs;
 use futures::{
     channel::mpsc,
-    future::LocalBoxFuture,
+    future::{BoxFuture, LocalBoxFuture},
     stream::{self, BoxStream},
     SinkExt, Stream, StreamExt,
 };
@@ -36,7 +36,7 @@ use similar::TextDiff;
 use smol::future::FutureExt;
 use std::{
     cmp,
-    future::Future,
+    future::{self, Future},
     mem,
     ops::{Range, RangeInclusive},
     pin::Pin,
@@ -46,7 +46,7 @@ use std::{
 };
 use theme::ThemeSettings;
 use ui::{prelude::*, IconButtonShape, Tooltip};
-use util::RangeExt;
+use util::{RangeExt, ResultExt};
 use workspace::{notifications::NotificationId, Toast, Workspace};
 
 pub fn init(fs: Arc<dyn Fs>, telemetry: Arc<Telemetry>, cx: &mut AppContext) {
@@ -187,7 +187,13 @@ impl InlineAssistant {
             let [prompt_block_id, end_block_id] =
                 self.insert_assist_blocks(editor, &range, &prompt_editor, cx);
 
-            assists.push((assist_id, prompt_editor, prompt_block_id, end_block_id));
+            assists.push((
+                assist_id,
+                range,
+                prompt_editor,
+                prompt_block_id,
+                end_block_id,
+            ));
         }
 
         let editor_assists = self
@@ -195,7 +201,7 @@ impl InlineAssistant {
             .entry(editor.downgrade())
             .or_insert_with(|| EditorInlineAssists::new(&editor, cx));
         let mut assist_group = InlineAssistGroup::new();
-        for (assist_id, prompt_editor, prompt_block_id, end_block_id) in assists {
+        for (assist_id, range, prompt_editor, prompt_block_id, end_block_id) in assists {
             self.assists.insert(
                 assist_id,
                 InlineAssist::new(
@@ -206,6 +212,7 @@ impl InlineAssistant {
                     &prompt_editor,
                     prompt_block_id,
                     end_block_id,
+                    range,
                     prompt_editor.read(cx).codegen.clone(),
                     workspace.clone(),
                     cx,
@@ -227,7 +234,7 @@ impl InlineAssistant {
         editor: &View<Editor>,
         mut range: Range<Anchor>,
         initial_prompt: String,
-        initial_insertion: Option<String>,
+        initial_insertion: Option<InitialInsertion>,
         workspace: Option<WeakView<Workspace>>,
         assistant_panel: Option<&View<AssistantPanel>>,
         cx: &mut WindowContext,
@@ -239,22 +246,30 @@ impl InlineAssistant {
         let assist_id = self.next_assist_id.post_inc();
 
         let buffer = editor.read(cx).buffer().clone();
-        let prepend_transaction_id = initial_insertion.and_then(|initial_insertion| {
-            buffer.update(cx, |buffer, cx| {
-                buffer.start_transaction(cx);
-                buffer.edit([(range.start..range.start, initial_insertion)], None, cx);
-                buffer.end_transaction(cx)
-            })
-        });
+        {
+            let snapshot = buffer.read(cx).read(cx);
+
+            let mut point_range = range.to_point(&snapshot);
+            if point_range.is_empty() {
+                point_range.start.column = 0;
+                point_range.end.column = 0;
+            } else {
+                point_range.start.column = 0;
+                if point_range.end.row > point_range.start.row && point_range.end.column == 0 {
+                    point_range.end.row -= 1;
+                }
+                point_range.end.column = snapshot.line_len(MultiBufferRow(point_range.end.row));
+            }
 
-        range.start = range.start.bias_left(&buffer.read(cx).read(cx));
-        range.end = range.end.bias_right(&buffer.read(cx).read(cx));
+            range.start = snapshot.anchor_before(point_range.start);
+            range.end = snapshot.anchor_after(point_range.end);
+        }
 
         let codegen = cx.new_model(|cx| {
             Codegen::new(
                 editor.read(cx).buffer().clone(),
                 range.clone(),
-                prepend_transaction_id,
+                initial_insertion,
                 self.telemetry.clone(),
                 cx,
             )
@@ -295,6 +310,7 @@ impl InlineAssistant {
                 &prompt_editor,
                 prompt_block_id,
                 end_block_id,
+                range,
                 prompt_editor.read(cx).codegen.clone(),
                 workspace.clone(),
                 cx,
@@ -445,7 +461,7 @@ impl InlineAssistant {
             let buffer = editor.buffer().read(cx).snapshot(cx);
             for assist_id in &editor_assists.assist_ids {
                 let assist = &self.assists[assist_id];
-                let assist_range = assist.codegen.read(cx).range.to_offset(&buffer);
+                let assist_range = assist.range.to_offset(&buffer);
                 if assist_range.contains(&selection.start) && assist_range.contains(&selection.end)
                 {
                     if matches!(assist.codegen.read(cx).status, CodegenStatus::Pending) {
@@ -473,7 +489,7 @@ impl InlineAssistant {
             let buffer = editor.buffer().read(cx).snapshot(cx);
             for assist_id in &editor_assists.assist_ids {
                 let assist = &self.assists[assist_id];
-                let assist_range = assist.codegen.read(cx).range.to_offset(&buffer);
+                let assist_range = assist.range.to_offset(&buffer);
                 if assist.decorations.is_some()
                     && assist_range.contains(&selection.start)
                     && assist_range.contains(&selection.end)
@@ -551,7 +567,7 @@ impl InlineAssistant {
                         assist.codegen.read(cx).status,
                         CodegenStatus::Error(_) | CodegenStatus::Done
                     ) {
-                        let assist_range = assist.codegen.read(cx).range.to_offset(&snapshot);
+                        let assist_range = assist.range.to_offset(&snapshot);
                         if edited_ranges
                             .iter()
                             .any(|range| range.overlaps(&assist_range))
@@ -721,7 +737,7 @@ impl InlineAssistant {
             });
         }
 
-        let position = assist.codegen.read(cx).range.start;
+        let position = assist.range.start;
         editor.update(cx, |editor, cx| {
             editor.change_selections(None, cx, |selections| {
                 selections.select_anchor_ranges([position..position])
@@ -740,8 +756,7 @@ impl InlineAssistant {
                     .0 as f32;
             } else {
                 let snapshot = editor.snapshot(cx);
-                let codegen = assist.codegen.read(cx);
-                let start_row = codegen
+                let start_row = assist
                     .range
                     .start
                     .to_display_point(&snapshot.display_snapshot)
@@ -829,11 +844,7 @@ impl InlineAssistant {
             return;
         }
 
-        let Some(user_prompt) = assist
-            .decorations
-            .as_ref()
-            .map(|decorations| decorations.prompt_editor.read(cx).prompt(cx))
-        else {
+        let Some(user_prompt) = assist.user_prompt(cx) else {
             return;
         };
 
@@ -843,139 +854,19 @@ impl InlineAssistant {
             self.prompt_history.pop_front();
         }
 
-        let codegen = assist.codegen.clone();
-        let telemetry_id = LanguageModelCompletionProvider::read_global(cx)
-            .active_model()
-            .map(|m| m.telemetry_id())
-            .unwrap_or_default();
-        let chunks: LocalBoxFuture<Result<BoxStream<Result<String>>>> =
-            if user_prompt.trim().to_lowercase() == "delete" {
-                async { Ok(stream::empty().boxed()) }.boxed_local()
-            } else {
-                let request = self.request_for_inline_assist(assist_id, cx);
-                let mut cx = cx.to_async();
-                async move {
-                    let request = request.await?;
-                    let chunks = cx
-                        .update(|cx| {
-                            LanguageModelCompletionProvider::read_global(cx)
-                                .stream_completion(request, cx)
-                        })?
-                        .await?;
-                    Ok(chunks.boxed())
-                }
-                .boxed_local()
-            };
-        codegen.update(cx, |codegen, cx| {
-            codegen.start(telemetry_id, chunks, cx);
-        });
-    }
-
-    fn request_for_inline_assist(
-        &self,
-        assist_id: InlineAssistId,
-        cx: &mut WindowContext,
-    ) -> Task<Result<LanguageModelRequest>> {
-        cx.spawn(|mut cx| async move {
-            let (user_prompt, context_request, project_name, buffer, range) =
-                cx.read_global(|this: &InlineAssistant, cx: &WindowContext| {
-                    let assist = this.assists.get(&assist_id).context("invalid assist")?;
-                    let decorations = assist.decorations.as_ref().context("invalid assist")?;
-                    let editor = assist.editor.upgrade().context("invalid assist")?;
-                    let user_prompt = decorations.prompt_editor.read(cx).prompt(cx);
-                    let context_request = if assist.include_context {
-                        assist.workspace.as_ref().and_then(|workspace| {
-                            let workspace = workspace.upgrade()?.read(cx);
-                            let assistant_panel = workspace.panel::<AssistantPanel>(cx)?;
-                            Some(
-                                assistant_panel
-                                    .read(cx)
-                                    .active_context(cx)?
-                                    .read(cx)
-                                    .to_completion_request(cx),
-                            )
-                        })
-                    } else {
-                        None
-                    };
-                    let project_name = assist.workspace.as_ref().and_then(|workspace| {
-                        let workspace = workspace.upgrade()?;
-                        Some(
-                            workspace
-                                .read(cx)
-                                .project()
-                                .read(cx)
-                                .worktree_root_names(cx)
-                                .collect::<Vec<&str>>()
-                                .join("/"),
-                        )
-                    });
-                    let buffer = editor.read(cx).buffer().read(cx).snapshot(cx);
-                    let range = assist.codegen.read(cx).range.clone();
-                    anyhow::Ok((user_prompt, context_request, project_name, buffer, range))
-                })??;
-
-            let language = buffer.language_at(range.start);
-            let language_name = if let Some(language) = language.as_ref() {
-                if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
-                    None
-                } else {
-                    Some(language.name())
-                }
-            } else {
-                None
-            };
+        let assistant_panel_context = assist.assistant_panel_context(cx);
 
-            // Higher Temperature increases the randomness of model outputs.
-            // If Markdown or No Language is Known, increase the randomness for more creative output
-            // If Code, decrease temperature to get more deterministic outputs
-            let temperature = if let Some(language) = language_name.clone() {
-                if language.as_ref() == "Markdown" {
-                    1.0
-                } else {
-                    0.5
-                }
-            } else {
-                1.0
-            };
-
-            let prompt = cx
-                .background_executor()
-                .spawn(async move {
-                    let language_name = language_name.as_deref();
-                    let start = buffer.point_to_buffer_offset(range.start);
-                    let end = buffer.point_to_buffer_offset(range.end);
-                    let (buffer, range) = if let Some((start, end)) = start.zip(end) {
-                        let (start_buffer, start_buffer_offset) = start;
-                        let (end_buffer, end_buffer_offset) = end;
-                        if start_buffer.remote_id() == end_buffer.remote_id() {
-                            (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
-                        } else {
-                            return Err(anyhow!("invalid transformation range"));
-                        }
-                    } else {
-                        return Err(anyhow!("invalid transformation range"));
-                    };
-                    generate_content_prompt(user_prompt, language_name, buffer, range, project_name)
-                })
-                .await?;
-
-            let mut messages = Vec::new();
-            if let Some(context_request) = context_request {
-                messages = context_request.messages;
-            }
-
-            messages.push(LanguageModelRequestMessage {
-                role: Role::User,
-                content: prompt,
-            });
-
-            Ok(LanguageModelRequest {
-                messages,
-                stop: vec!["|END|>".to_string()],
-                temperature,
+        assist
+            .codegen
+            .update(cx, |codegen, cx| {
+                codegen.start(
+                    assist.range.clone(),
+                    user_prompt,
+                    assistant_panel_context,
+                    cx,
+                )
             })
-        })
+            .log_err();
     }
 
     pub fn stop_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
@@ -1006,12 +897,11 @@ impl InlineAssistant {
                 let codegen = assist.codegen.read(cx);
                 foreground_ranges.extend(codegen.last_equal_ranges().iter().cloned());
 
-                if codegen.edit_position != codegen.range.end {
-                    gutter_pending_ranges.push(codegen.edit_position..codegen.range.end);
-                }
+                gutter_pending_ranges
+                    .push(codegen.edit_position.unwrap_or(assist.range.start)..assist.range.end);
 
-                if codegen.range.start != codegen.edit_position {
-                    gutter_transformed_ranges.push(codegen.range.start..codegen.edit_position);
+                if let Some(edit_position) = codegen.edit_position {
+                    gutter_transformed_ranges.push(assist.range.start..edit_position);
                 }
 
                 if assist.decorations.is_some() {
@@ -1268,6 +1158,12 @@ fn build_assist_editor_renderer(editor: &View<PromptEditor>) -> RenderBlock {
     })
 }
 
+#[derive(Copy, Clone, Debug, Eq, PartialEq)]
+pub enum InitialInsertion {
+    NewlineBefore,
+    NewlineAfter,
+}
+
 #[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
 pub struct InlineAssistId(usize);
 
@@ -1629,24 +1525,20 @@ impl PromptEditor {
         let assist_id = self.id;
         self.pending_token_count = cx.spawn(|this, mut cx| async move {
             cx.background_executor().timer(Duration::from_secs(1)).await;
-            let request = cx
+            let token_count = cx
                 .update_global(|inline_assistant: &mut InlineAssistant, cx| {
-                    inline_assistant.request_for_inline_assist(assist_id, cx)
-                })?
+                    let assist = inline_assistant
+                        .assists
+                        .get(&assist_id)
+                        .context("assist not found")?;
+                    anyhow::Ok(assist.count_tokens(cx))
+                })??
                 .await?;
 
-            if let Some(token_count) = cx.update(|cx| {
-                LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
-            })? {
-                let token_count = token_count.await?;
-
-                this.update(&mut cx, |this, cx| {
-                    this.token_count = Some(token_count);
-                    cx.notify();
-                })
-            } else {
-                Ok(())
-            }
+            this.update(&mut cx, |this, cx| {
+                this.token_count = Some(token_count);
+                cx.notify();
+            })
         })
     }
 
@@ -1855,6 +1747,7 @@ impl PromptEditor {
 
 struct InlineAssist {
     group_id: InlineAssistGroupId,
+    range: Range<Anchor>,
     editor: WeakView<Editor>,
     decorations: Option<InlineAssistDecorations>,
     codegen: Model<Codegen>,
@@ -1873,6 +1766,7 @@ impl InlineAssist {
         prompt_editor: &View<PromptEditor>,
         prompt_block_id: CustomBlockId,
         end_block_id: CustomBlockId,
+        range: Range<Anchor>,
         codegen: Model<Codegen>,
         workspace: Option<WeakView<Workspace>>,
         cx: &mut WindowContext,
@@ -1888,6 +1782,7 @@ impl InlineAssist {
                 removed_line_block_ids: HashSet::default(),
                 end_block_id,
             }),
+            range,
             codegen: codegen.clone(),
             workspace: workspace.clone(),
             _subscriptions: vec![
@@ -1963,6 +1858,41 @@ impl InlineAssist {
             ],
         }
     }
+
+    fn user_prompt(&self, cx: &AppContext) -> Option<String> {
+        let decorations = self.decorations.as_ref()?;
+        Some(decorations.prompt_editor.read(cx).prompt(cx))
+    }
+
+    fn assistant_panel_context(&self, cx: &WindowContext) -> Option<LanguageModelRequest> {
+        if self.include_context {
+            let workspace = self.workspace.as_ref()?;
+            let workspace = workspace.upgrade()?.read(cx);
+            let assistant_panel = workspace.panel::<AssistantPanel>(cx)?;
+            Some(
+                assistant_panel
+                    .read(cx)
+                    .active_context(cx)?
+                    .read(cx)
+                    .to_completion_request(cx),
+            )
+        } else {
+            None
+        }
+    }
+
+    pub fn count_tokens(&self, cx: &WindowContext) -> BoxFuture<'static, Result<usize>> {
+        let Some(user_prompt) = self.user_prompt(cx) else {
+            return future::ready(Err(anyhow!("no user prompt"))).boxed();
+        };
+        let assistant_panel_context = self.assistant_panel_context(cx);
+        self.codegen.read(cx).count_tokens(
+            self.range.clone(),
+            user_prompt,
+            assistant_panel_context,
+            cx,
+        )
+    }
 }
 
 struct InlineAssistDecorations {
@@ -1982,16 +1912,15 @@ pub struct Codegen {
     buffer: Model<MultiBuffer>,
     old_buffer: Model<Buffer>,
     snapshot: MultiBufferSnapshot,
-    range: Range<Anchor>,
-    edit_position: Anchor,
+    edit_position: Option<Anchor>,
     last_equal_ranges: Vec<Range<Anchor>>,
-    prepend_transaction_id: Option<TransactionId>,
-    generation_transaction_id: Option<TransactionId>,
+    transaction_id: Option<TransactionId>,
     status: CodegenStatus,
     generation: Task<()>,
     diff: Diff,
     telemetry: Option<Arc<Telemetry>>,
     _subscription: gpui::Subscription,
+    initial_insertion: Option<InitialInsertion>,
 }
 
 enum CodegenStatus {
@@ -2015,7 +1944,7 @@ impl Codegen {
     pub fn new(
         buffer: Model<MultiBuffer>,
         range: Range<Anchor>,
-        prepend_transaction_id: Option<TransactionId>,
+        initial_insertion: Option<InitialInsertion>,
         telemetry: Option<Arc<Telemetry>>,
         cx: &mut ModelContext<Self>,
     ) -> Self {
@@ -2044,17 +1973,16 @@ impl Codegen {
         Self {
             buffer: buffer.clone(),
             old_buffer,
-            edit_position: range.start,
-            range,
+            edit_position: None,
             snapshot,
             last_equal_ranges: Default::default(),
-            prepend_transaction_id,
-            generation_transaction_id: None,
+            transaction_id: None,
             status: CodegenStatus::Idle,
             generation: Task::ready(()),
             diff: Diff::default(),
             telemetry,
             _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
+            initial_insertion,
         }
     }
 
@@ -2065,13 +1993,8 @@ impl Codegen {
         cx: &mut ModelContext<Self>,
     ) {
         if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
-            if self.generation_transaction_id == Some(*transaction_id) {
-                self.generation_transaction_id = None;
-                self.generation = Task::ready(());
-                cx.emit(CodegenEvent::Undone);
-            } else if self.prepend_transaction_id == Some(*transaction_id) {
-                self.prepend_transaction_id = None;
-                self.generation_transaction_id = None;
+            if self.transaction_id == Some(*transaction_id) {
+                self.transaction_id = None;
                 self.generation = Task::ready(());
                 cx.emit(CodegenEvent::Undone);
             }
@@ -2082,19 +2005,152 @@ impl Codegen {
         &self.last_equal_ranges
     }
 
+    pub fn count_tokens(
+        &self,
+        edit_range: Range<Anchor>,
+        user_prompt: String,
+        assistant_panel_context: Option<LanguageModelRequest>,
+        cx: &AppContext,
+    ) -> BoxFuture<'static, Result<usize>> {
+        let request = self.build_request(user_prompt, assistant_panel_context, edit_range, cx);
+        LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
+    }
+
     pub fn start(
         &mut self,
-        telemetry_id: String,
+        mut edit_range: Range<Anchor>,
+        user_prompt: String,
+        assistant_panel_context: Option<LanguageModelRequest>,
+        cx: &mut ModelContext<Self>,
+    ) -> Result<()> {
+        self.undo(cx);
+
+        // Handle initial insertion
+        self.transaction_id = if let Some(initial_insertion) = self.initial_insertion {
+            self.buffer.update(cx, |buffer, cx| {
+                buffer.start_transaction(cx);
+                let offset = edit_range.start.to_offset(&self.snapshot);
+                let edit_position;
+                match initial_insertion {
+                    InitialInsertion::NewlineBefore => {
+                        buffer.edit([(offset..offset, "\n\n")], None, cx);
+                        self.snapshot = buffer.snapshot(cx);
+                        edit_position = self.snapshot.anchor_after(offset + 1);
+                    }
+                    InitialInsertion::NewlineAfter => {
+                        buffer.edit([(offset..offset, "\n")], None, cx);
+                        self.snapshot = buffer.snapshot(cx);
+                        edit_position = self.snapshot.anchor_after(offset);
+                    }
+                }
+                self.edit_position = Some(edit_position);
+                edit_range = edit_position.bias_left(&self.snapshot)..edit_position;
+                buffer.end_transaction(cx)
+            })
+        } else {
+            self.edit_position = Some(edit_range.start.bias_right(&self.snapshot));
+            None
+        };
+
+        let model_telemetry_id = LanguageModelCompletionProvider::read_global(cx)
+            .active_model_telemetry_id()
+            .context("no active model")?;
+
+        let chunks: LocalBoxFuture<Result<BoxStream<Result<String>>>> = if user_prompt
+            .trim()
+            .to_lowercase()
+            == "delete"
+        {
+            async { Ok(stream::empty().boxed()) }.boxed_local()
+        } else {
+            let request =
+                self.build_request(user_prompt, assistant_panel_context, edit_range.clone(), cx);
+            let chunks =
+                LanguageModelCompletionProvider::read_global(cx).stream_completion(request, cx);
+            async move { Ok(chunks.await?.boxed()) }.boxed_local()
+        };
+        self.handle_stream(model_telemetry_id, edit_range, chunks, cx);
+        Ok(())
+    }
+
+    fn build_request(
+        &self,
+        user_prompt: String,
+        assistant_panel_context: Option<LanguageModelRequest>,
+        edit_range: Range<Anchor>,
+        cx: &AppContext,
+    ) -> LanguageModelRequest {
+        let buffer = self.buffer.read(cx).snapshot(cx);
+        let language = buffer.language_at(edit_range.start);
+        let language_name = if let Some(language) = language.as_ref() {
+            if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
+                None
+            } else {
+                Some(language.name())
+            }
+        } else {
+            None
+        };
+
+        // A higher temperature increases the randomness of model outputs.
+        // If the language is Markdown or unknown, raise the temperature for more creative output;
+        // for code, lower it to get more deterministic output.
+        let temperature = if let Some(language) = language_name.clone() {
+            if language.as_ref() == "Markdown" {
+                1.0
+            } else {
+                0.5
+            }
+        } else {
+            1.0
+        };
+
+        let language_name = language_name.as_deref();
+        let start = buffer.point_to_buffer_offset(edit_range.start);
+        let end = buffer.point_to_buffer_offset(edit_range.end);
+        let (buffer, range) = if let Some((start, end)) = start.zip(end) {
+            let (start_buffer, start_buffer_offset) = start;
+            let (end_buffer, end_buffer_offset) = end;
+            if start_buffer.remote_id() == end_buffer.remote_id() {
+                (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
+            } else {
+                panic!("invalid transformation range");
+            }
+        } else {
+            panic!("invalid transformation range");
+        };
+        let prompt = generate_content_prompt(user_prompt, language_name, buffer, range);
+
+        let mut messages = Vec::new();
+        if let Some(context_request) = assistant_panel_context {
+            messages = context_request.messages;
+        }
+
+        messages.push(LanguageModelRequestMessage {
+            role: Role::User,
+            content: prompt,
+        });
+
+        LanguageModelRequest {
+            messages,
+            stop: vec!["|END|>".to_string()],
+            temperature,
+        }
+    }
+
+    pub fn handle_stream(
+        &mut self,
+        model_telemetry_id: String,
+        edit_range: Range<Anchor>,
         stream: impl 'static + Future<Output = Result<BoxStream<'static, Result<String>>>>,
         cx: &mut ModelContext<Self>,
     ) {
-        let range = self.range.clone();
         let snapshot = self.snapshot.clone();
         let selected_text = snapshot
-            .text_for_range(range.start..range.end)
+            .text_for_range(edit_range.start..edit_range.end)
             .collect::<Rope>();
 
-        let selection_start = range.start.to_point(&snapshot);
+        let selection_start = edit_range.start.to_point(&snapshot);
 
         // Start with the indentation of the first line in the selection
         let mut suggested_line_indent = snapshot
@@ -2105,7 +2161,7 @@ impl Codegen {
 
         // If the first line in the selection does not have indentation, check the following lines
         if suggested_line_indent.len == 0 && suggested_line_indent.kind == IndentKind::Space {
-            for row in selection_start.row..=range.end.to_point(&snapshot).row {
+            for row in selection_start.row..=edit_range.end.to_point(&snapshot).row {
                 let line_indent = snapshot.indent_size_for_line(MultiBufferRow(row));
                 // Prefer tabs if a line in the selection uses tabs as indentation
                 if line_indent.kind == IndentKind::Tab {
@@ -2116,19 +2172,13 @@ impl Codegen {
         }
 
         let telemetry = self.telemetry.clone();
-        self.edit_position = range.start;
         self.diff = Diff::default();
         self.status = CodegenStatus::Pending;
-        if let Some(transaction_id) = self.generation_transaction_id.take() {
-            self.buffer
-                .update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx));
-        }
+        let mut edit_start = edit_range.start.to_offset(&snapshot);
         self.generation = cx.spawn(|this, mut cx| {
             async move {
                 let chunks = stream.await;
                 let generate = async {
-                    let mut edit_start = range.start.to_offset(&snapshot);
-
                     let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
                     let diff: Task<anyhow::Result<()>> =
                         cx.background_executor().spawn(async move {
@@ -2218,7 +2268,7 @@ impl Codegen {
                                 telemetry.report_assistant_event(
                                     None,
                                     telemetry_events::AssistantKind::Inline,
-                                    telemetry_id,
+                                    model_telemetry_id,
                                     response_latency,
                                     error_message,
                                 );
@@ -2262,13 +2312,13 @@ impl Codegen {
                                     None,
                                     cx,
                                 );
-                                this.edit_position = snapshot.anchor_after(edit_start);
+                                this.edit_position = Some(snapshot.anchor_after(edit_start));
 
                                 buffer.end_transaction(cx)
                             });
 
                             if let Some(transaction) = transaction {
-                                if let Some(first_transaction) = this.generation_transaction_id {
+                                if let Some(first_transaction) = this.transaction_id {
                                     // Group all assistant edits into the first transaction.
                                     this.buffer.update(cx, |buffer, cx| {
                                         buffer.merge_transactions(
@@ -2278,14 +2328,14 @@ impl Codegen {
                                         )
                                     });
                                 } else {
-                                    this.generation_transaction_id = Some(transaction);
+                                    this.transaction_id = Some(transaction);
                                     this.buffer.update(cx, |buffer, cx| {
                                         buffer.finalize_last_transaction(cx)
                                     });
                                 }
                             }
 
-                            this.update_diff(cx);
+                            this.update_diff(edit_range.clone(), cx);
                             cx.notify();
                         })?;
                     }
@@ -2321,27 +2371,22 @@ impl Codegen {
     }
 
     pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
-        if let Some(transaction_id) = self.prepend_transaction_id.take() {
-            self.buffer
-                .update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx));
-        }
-
-        if let Some(transaction_id) = self.generation_transaction_id.take() {
+        if let Some(transaction_id) = self.transaction_id.take() {
             self.buffer
                 .update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx));
         }
     }
 
-    fn update_diff(&mut self, cx: &mut ModelContext<Self>) {
+    fn update_diff(&mut self, edit_range: Range<Anchor>, cx: &mut ModelContext<Self>) {
         if self.diff.task.is_some() {
             self.diff.should_update = true;
         } else {
             self.diff.should_update = false;
 
             let old_snapshot = self.snapshot.clone();
-            let old_range = self.range.to_point(&old_snapshot);
+            let old_range = edit_range.to_point(&old_snapshot);
             let new_snapshot = self.buffer.read(cx).snapshot(cx);
-            let new_range = self.range.to_point(&new_snapshot);
+            let new_range = edit_range.to_point(&new_snapshot);
 
             self.diff.task = Some(cx.spawn(|this, mut cx| async move {
                 let (deleted_row_ranges, inserted_row_ranges) = cx
@@ -2422,7 +2467,7 @@ impl Codegen {
                     this.diff.inserted_row_ranges = inserted_row_ranges;
                     this.diff.task = None;
                     if this.diff.should_update {
-                        this.update_diff(cx);
+                        this.update_diff(edit_range, cx);
                     }
                     cx.notify();
                 })
@@ -2629,12 +2674,14 @@ mod tests {
             let snapshot = buffer.snapshot(cx);
             snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
         });
-        let codegen = cx.new_model(|cx| Codegen::new(buffer.clone(), range, None, None, cx));
+        let codegen =
+            cx.new_model(|cx| Codegen::new(buffer.clone(), range.clone(), None, None, cx));
 
         let (chunks_tx, chunks_rx) = mpsc::unbounded();
         codegen.update(cx, |codegen, cx| {
-            codegen.start(
+            codegen.handle_stream(
                 String::new(),
+                range,
                 future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
                 cx,
             )
@@ -2690,12 +2737,14 @@ mod tests {
             let snapshot = buffer.snapshot(cx);
             snapshot.anchor_before(Point::new(1, 6))..snapshot.anchor_after(Point::new(1, 6))
         });
-        let codegen = cx.new_model(|cx| Codegen::new(buffer.clone(), range, None, None, cx));
+        let codegen =
+            cx.new_model(|cx| Codegen::new(buffer.clone(), range.clone(), None, None, cx));
 
         let (chunks_tx, chunks_rx) = mpsc::unbounded();
         codegen.update(cx, |codegen, cx| {
-            codegen.start(
+            codegen.handle_stream(
                 String::new(),
+                range.clone(),
                 future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
                 cx,
             )
@@ -2755,12 +2804,14 @@ mod tests {
             let snapshot = buffer.snapshot(cx);
             snapshot.anchor_before(Point::new(1, 2))..snapshot.anchor_after(Point::new(1, 2))
         });
-        let codegen = cx.new_model(|cx| Codegen::new(buffer.clone(), range, None, None, cx));
+        let codegen =
+            cx.new_model(|cx| Codegen::new(buffer.clone(), range.clone(), None, None, cx));
 
         let (chunks_tx, chunks_rx) = mpsc::unbounded();
         codegen.update(cx, |codegen, cx| {
-            codegen.start(
+            codegen.handle_stream(
                 String::new(),
+                range.clone(),
                 future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
                 cx,
             )
@@ -2819,12 +2870,14 @@ mod tests {
             let snapshot = buffer.snapshot(cx);
             snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(4, 2))
         });
-        let codegen = cx.new_model(|cx| Codegen::new(buffer.clone(), range, None, None, cx));
+        let codegen =
+            cx.new_model(|cx| Codegen::new(buffer.clone(), range.clone(), None, None, cx));
 
         let (chunks_tx, chunks_rx) = mpsc::unbounded();
         codegen.update(cx, |codegen, cx| {
-            codegen.start(
+            codegen.handle_stream(
                 String::new(),
+                range.clone(),
                 future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
                 cx,
             )

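For intuition, here is a minimal sketch of where each `InitialInsertion` variant parks the generation cursor before streaming begins; a plain `String` stands in for the real `MultiBuffer`, so everything below is illustrative rather than the code above:

    enum InitialInsertion {
        NewlineBefore,
        NewlineAfter,
    }

    fn apply(text: &mut String, offset: usize, insertion: InitialInsertion) -> usize {
        match insertion {
            // Mirrors `anchor_after(offset + 1)`: generate between the two newlines.
            InitialInsertion::NewlineBefore => {
                text.insert_str(offset, "\n\n");
                offset + 1
            }
            // Mirrors `anchor_after(offset)`: generate before the inserted newline.
            InitialInsertion::NewlineAfter => {
                text.insert_str(offset, "\n");
                offset
            }
        }
    }

    fn main() {
        let mut text = String::from("}");
        let cursor = apply(&mut text, 0, InitialInsertion::NewlineBefore);
        assert_eq!(text, "\n\n}");
        assert_eq!(cursor, 1); // a fresh blank line, just above the original text
    }
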
crates/assistant/src/prompt_library.rs 🔗

@@ -734,29 +734,27 @@ impl PromptLibrary {
                     const DEBOUNCE_TIMEOUT: Duration = Duration::from_secs(1);
 
                     cx.background_executor().timer(DEBOUNCE_TIMEOUT).await;
-                    if let Some(token_count) = cx.update(|cx| {
-                        LanguageModelCompletionProvider::read_global(cx).count_tokens(
-                            LanguageModelRequest {
-                                messages: vec![LanguageModelRequestMessage {
-                                    role: Role::System,
-                                    content: body.to_string(),
-                                }],
-                                stop: Vec::new(),
-                                temperature: 1.,
-                            },
-                            cx,
-                        )
-                    })? {
-                        let token_count = token_count.await?;
+                    let token_count = cx
+                        .update(|cx| {
+                            LanguageModelCompletionProvider::read_global(cx).count_tokens(
+                                LanguageModelRequest {
+                                    messages: vec![LanguageModelRequestMessage {
+                                        role: Role::System,
+                                        content: body.to_string(),
+                                    }],
+                                    stop: Vec::new(),
+                                    temperature: 1.,
+                                },
+                                cx,
+                            )
+                        })?
+                        .await?;
 
-                        this.update(&mut cx, |this, cx| {
-                            let prompt_editor = this.prompt_editors.get_mut(&prompt_id).unwrap();
-                            prompt_editor.token_count = Some(token_count);
-                            cx.notify();
-                        })
-                    } else {
-                        Ok(())
-                    }
+                    this.update(&mut cx, |this, cx| {
+                        let prompt_editor = this.prompt_editors.get_mut(&prompt_id).unwrap();
+                        prompt_editor.token_count = Some(token_count);
+                        cx.notify();
+                    })
                 }
                 .log_err()
             });

crates/assistant/src/prompts.rs 🔗

@@ -6,8 +6,7 @@ pub fn generate_content_prompt(
     language_name: Option<&str>,
     buffer: BufferSnapshot,
     range: Range<usize>,
-    _project_name: Option<String>,
-) -> anyhow::Result<String> {
+) -> String {
     let mut prompt = String::new();
 
     let content_type = match language_name {
@@ -15,14 +14,16 @@ pub fn generate_content_prompt(
             writeln!(
                 prompt,
                 "Here's a file of text that I'm going to ask you to make an edit to."
-            )?;
+            )
+            .unwrap();
             "text"
         }
         Some(language_name) => {
             writeln!(
                 prompt,
                 "Here's a file of {language_name} that I'm going to ask you to make an edit to."
-            )?;
+            )
+            .unwrap();
             "code"
         }
     };
@@ -70,7 +71,7 @@ pub fn generate_content_prompt(
     write!(prompt, "</document>\n\n").unwrap();
 
     if is_truncated {
-        writeln!(prompt, "The context around the relevant section has been truncated (possibly in the middle of a line) for brevity.\n")?;
+        writeln!(prompt, "The context around the relevant section has been truncated (possibly in the middle of a line) for brevity.\n").unwrap();
     }
 
     if range.is_empty() {
@@ -107,7 +108,7 @@ pub fn generate_content_prompt(
         prompt.push_str("\n\nImmediately start with the following format with no remarks:\n\n```\n{{REWRITTEN_CODE}}\n```");
     }
 
-    Ok(prompt)
+    prompt
 }
 
 pub fn generate_terminal_assistant_prompt(

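Dropping the `anyhow::Result` return type works because `fmt::Write` for `String` is infallible, so the `.unwrap()` calls above can never panic. A tiny illustration:

    use std::fmt::Write as _;

    fn demo() -> String {
        let mut prompt = String::new();
        // String's write_str never returns Err, so this unwrap cannot fire.
        writeln!(prompt, "Here's a file of text that I'm going to ask you to make an edit to.").unwrap();
        prompt
    }
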
crates/assistant/src/terminal_inline_assistant.rs 🔗

@@ -707,18 +707,15 @@ impl PromptEditor {
                     inline_assistant.request_for_inline_assist(assist_id, cx)
                 })??;
 
-            if let Some(token_count) = cx.update(|cx| {
-                LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
-            })? {
-                let token_count = token_count.await?;
-
-                this.update(&mut cx, |this, cx| {
-                    this.token_count = Some(token_count);
-                    cx.notify();
-                })
-            } else {
-                Ok(())
-            }
+            let token_count = cx
+                .update(|cx| {
+                    LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
+                })?
+                .await?;
+            this.update(&mut cx, |this, cx| {
+                this.token_count = Some(token_count);
+                cx.notify();
+            })
         })
     }
 

crates/collab/src/rpc.rs 🔗

@@ -10,7 +10,7 @@ use crate::{
         ServerId, UpdatedChannelMessage, User, UserId,
     },
     executor::Executor,
-    AppState, Error, RateLimit, RateLimiter, Result,
+    AppState, Config, Error, RateLimit, RateLimiter, Result,
 };
 use anyhow::{anyhow, bail, Context as _};
 use async_tungstenite::tungstenite::{
@@ -605,17 +605,39 @@ impl Server {
             ))
             .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
             .add_message_handler(update_context)
+            .add_request_handler({
+                let app_state = app_state.clone();
+                move |request, response, session| {
+                    let app_state = app_state.clone();
+                    async move {
+                        complete_with_language_model(request, response, session, &app_state.config)
+                            .await
+                    }
+                }
+            })
             .add_streaming_request_handler({
                 let app_state = app_state.clone();
                 move |request, response, session| {
-                    complete_with_language_model(
-                        request,
-                        response,
-                        session,
-                        app_state.config.openai_api_key.clone(),
-                        app_state.config.google_ai_api_key.clone(),
-                        app_state.config.anthropic_api_key.clone(),
-                    )
+                    let app_state = app_state.clone();
+                    async move {
+                        stream_complete_with_language_model(
+                            request,
+                            response,
+                            session,
+                            &app_state.config,
+                        )
+                        .await
+                    }
+                }
+            })
+            .add_request_handler({
+                let app_state = app_state.clone();
+                move |request, response, session| {
+                    let app_state = app_state.clone();
+                    async move {
+                        count_language_model_tokens(request, response, session, &app_state.config)
+                            .await
+                    }
                 }
             })
             .add_request_handler({
@@ -4503,103 +4525,119 @@ impl RateLimit for CompleteWithLanguageModelRateLimit {
 }
 
 async fn complete_with_language_model(
-    query: proto::QueryLanguageModel,
-    response: StreamingResponse<proto::QueryLanguageModel>,
+    request: proto::CompleteWithLanguageModel,
+    response: Response<proto::CompleteWithLanguageModel>,
     session: Session,
-    open_ai_api_key: Option<Arc<str>>,
-    google_ai_api_key: Option<Arc<str>>,
-    anthropic_api_key: Option<Arc<str>>,
+    config: &Config,
 ) -> Result<()> {
     let Some(session) = session.for_user() else {
         return Err(anyhow!("user not found"))?;
     };
     authorize_access_to_language_models(&session).await?;
 
-    match proto::LanguageModelRequestKind::from_i32(query.kind) {
-        Some(proto::LanguageModelRequestKind::Complete) => {
-            session
-                .rate_limiter
-                .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
-                .await?;
-        }
-        Some(proto::LanguageModelRequestKind::CountTokens) => {
-            session
-                .rate_limiter
-                .check::<CountTokensWithLanguageModelRateLimit>(session.user_id())
-                .await?;
+    session
+        .rate_limiter
+        .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
+        .await?;
+
+    let result = match proto::LanguageModelProvider::from_i32(request.provider) {
+        Some(proto::LanguageModelProvider::Anthropic) => {
+            let api_key = config
+                .anthropic_api_key
+                .as_ref()
+                .context("no Anthropic AI API key configured on the server")?;
+            anthropic::complete(
+                session.http_client.as_ref(),
+                anthropic::ANTHROPIC_API_URL,
+                api_key,
+                serde_json::from_str(&request.request)?,
+            )
+            .await?
         }
-        None => Err(anyhow!("unknown request kind"))?,
-    }
+        _ => return Err(anyhow!("unsupported provider"))?,
+    };
+
+    response.send(proto::CompleteWithLanguageModelResponse {
+        completion: serde_json::to_string(&result)?,
+    })?;
+
+    Ok(())
+}
 
-    match proto::LanguageModelProvider::from_i32(query.provider) {
+async fn stream_complete_with_language_model(
+    request: proto::StreamCompleteWithLanguageModel,
+    response: StreamingResponse<proto::StreamCompleteWithLanguageModel>,
+    session: Session,
+    config: &Config,
+) -> Result<()> {
+    let Some(session) = session.for_user() else {
+        return Err(anyhow!("user not found"))?;
+    };
+    authorize_access_to_language_models(&session).await?;
+
+    session
+        .rate_limiter
+        .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
+        .await?;
+
+    match proto::LanguageModelProvider::from_i32(request.provider) {
         Some(proto::LanguageModelProvider::Anthropic) => {
-            let api_key =
-                anthropic_api_key.context("no Anthropic AI API key configured on the server")?;
+            let api_key = config
+                .anthropic_api_key
+                .as_ref()
+                .context("no Anthropic AI API key configured on the server")?;
             let mut chunks = anthropic::stream_completion(
                 session.http_client.as_ref(),
                 anthropic::ANTHROPIC_API_URL,
-                &api_key,
-                serde_json::from_str(&query.request)?,
+                api_key,
+                serde_json::from_str(&request.request)?,
                 None,
             )
             .await?;
-            while let Some(chunk) = chunks.next().await {
-                let chunk = chunk?;
-                response.send(proto::QueryLanguageModelResponse {
-                    response: serde_json::to_string(&chunk)?,
+            while let Some(event) = chunks.next().await {
+                let chunk = event?;
+                response.send(proto::StreamCompleteWithLanguageModelResponse {
+                    event: serde_json::to_string(&chunk)?,
                 })?;
             }
         }
         Some(proto::LanguageModelProvider::OpenAi) => {
-            let api_key = open_ai_api_key.context("no OpenAI API key configured on the server")?;
-            let mut chunks = open_ai::stream_completion(
+            let api_key = config
+                .openai_api_key
+                .as_ref()
+                .context("no OpenAI API key configured on the server")?;
+            let mut events = open_ai::stream_completion(
                 session.http_client.as_ref(),
                 open_ai::OPEN_AI_API_URL,
-                &api_key,
-                serde_json::from_str(&query.request)?,
+                api_key,
+                serde_json::from_str(&request.request)?,
                 None,
             )
             .await?;
-            while let Some(chunk) = chunks.next().await {
-                let chunk = chunk?;
-                response.send(proto::QueryLanguageModelResponse {
-                    response: serde_json::to_string(&chunk)?,
+            while let Some(event) = events.next().await {
+                let event = event?;
+                response.send(proto::StreamCompleteWithLanguageModelResponse {
+                    event: serde_json::to_string(&event)?,
                 })?;
             }
         }
         Some(proto::LanguageModelProvider::Google) => {
-            let api_key =
-                google_ai_api_key.context("no Google AI API key configured on the server")?;
-
-            match proto::LanguageModelRequestKind::from_i32(query.kind) {
-                Some(proto::LanguageModelRequestKind::Complete) => {
-                    let mut chunks = google_ai::stream_generate_content(
-                        session.http_client.as_ref(),
-                        google_ai::API_URL,
-                        &api_key,
-                        serde_json::from_str(&query.request)?,
-                    )
-                    .await?;
-                    while let Some(chunk) = chunks.next().await {
-                        let chunk = chunk?;
-                        response.send(proto::QueryLanguageModelResponse {
-                            response: serde_json::to_string(&chunk)?,
-                        })?;
-                    }
-                }
-                Some(proto::LanguageModelRequestKind::CountTokens) => {
-                    let tokens_response = google_ai::count_tokens(
-                        session.http_client.as_ref(),
-                        google_ai::API_URL,
-                        &api_key,
-                        serde_json::from_str(&query.request)?,
-                    )
-                    .await?;
-                    response.send(proto::QueryLanguageModelResponse {
-                        response: serde_json::to_string(&tokens_response)?,
-                    })?;
-                }
-                None => Err(anyhow!("unknown request kind"))?,
+            let api_key = config
+                .google_ai_api_key
+                .as_ref()
+                .context("no Google AI API key configured on the server")?;
+            let mut events = google_ai::stream_generate_content(
+                session.http_client.as_ref(),
+                google_ai::API_URL,
+                api_key,
+                serde_json::from_str(&request.request)?,
+            )
+            .await?;
+            while let Some(event) = events.next().await {
+                let event = event?;
+                response.send(proto::StreamCompleteWithLanguageModelResponse {
+                    event: serde_json::to_string(&event)?,
+                })?;
             }
         }
         None => return Err(anyhow!("unknown provider"))?,
@@ -4608,11 +4646,51 @@ async fn complete_with_language_model(
     Ok(())
 }
 
-struct CountTokensWithLanguageModelRateLimit;
+async fn count_language_model_tokens(
+    request: proto::CountLanguageModelTokens,
+    response: Response<proto::CountLanguageModelTokens>,
+    session: Session,
+    config: &Config,
+) -> Result<()> {
+    let Some(session) = session.for_user() else {
+        return Err(anyhow!("user not found"))?;
+    };
+    authorize_access_to_language_models(&session).await?;
+
+    session
+        .rate_limiter
+        .check::<CountLanguageModelTokensRateLimit>(session.user_id())
+        .await?;
+
+    let result = match proto::LanguageModelProvider::from_i32(request.provider) {
+        Some(proto::LanguageModelProvider::Google) => {
+            let api_key = config
+                .google_ai_api_key
+                .as_ref()
+                .context("no Google AI API key configured on the server")?;
+            google_ai::count_tokens(
+                session.http_client.as_ref(),
+                google_ai::API_URL,
+                api_key,
+                serde_json::from_str(&request.request)?,
+            )
+            .await?
+        }
+        _ => return Err(anyhow!("unsupported provider"))?,
+    };
+
+    response.send(proto::CountLanguageModelTokensResponse {
+        token_count: result.total_tokens as u32,
+    })?;
+
+    Ok(())
+}
+
+struct CountLanguageModelTokensRateLimit;
 
-impl RateLimit for CountTokensWithLanguageModelRateLimit {
+impl RateLimit for CountLanguageModelTokensRateLimit {
     fn capacity() -> usize {
-        std::env::var("COUNT_TOKENS_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
+        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
             .ok()
             .and_then(|v| v.parse().ok())
             .unwrap_or(600) // Picked arbitrarily
@@ -4623,7 +4701,7 @@ impl RateLimit for CountTokensWithLanguageModelRateLimit {
     }
 
     fn db_name() -> &'static str {
-        "count-tokens-with-language-model"
+        "count-language-model-tokens"
     }
 }
 

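In comment form, a sketch of the resulting server surface (message shapes from the proto changes below; provider support as wired above):

    // CompleteWithLanguageModel -> CompleteWithLanguageModelResponse { completion: String /* JSON-encoded response */ }
    //     non-streaming; Anthropic only for now; this is what backs tool calls
    // StreamCompleteWithLanguageModel -> repeated StreamCompleteWithLanguageModelResponse { event: String /* JSON-encoded event */ }
    //     streaming; Anthropic, OpenAI, and Google
    // CountLanguageModelTokens -> CountLanguageModelTokensResponse { token_count: u32 }
    //     Google only; replaces the CountTokens kind of the old QueryLanguageModel message
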
crates/completion/Cargo.toml 🔗

@@ -26,7 +26,9 @@ anyhow.workspace = true
 futures.workspace = true
 gpui.workspace = true
 language_model.workspace = true
+schemars.workspace = true
 serde.workspace = true
+serde_json.workspace = true
 settings.workspace = true
 smol.workspace = true
 ui.workspace = true

crates/completion/src/completion.rs 🔗

@@ -3,10 +3,13 @@ use futures::{future::BoxFuture, stream::BoxStream, StreamExt};
 use gpui::{AppContext, Global, Model, ModelContext, Task};
 use language_model::{
     LanguageModel, LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry,
-    LanguageModelRequest,
+    LanguageModelRequest, LanguageModelTool,
 };
-use smol::lock::{Semaphore, SemaphoreGuardArc};
-use std::{pin::Pin, sync::Arc, task::Poll};
+use smol::{
+    future::FutureExt,
+    lock::{Semaphore, SemaphoreGuardArc},
+};
+use std::{future, pin::Pin, sync::Arc, task::Poll};
 use ui::Context;
 
 pub fn init(cx: &mut AppContext) {
@@ -143,11 +146,11 @@ impl LanguageModelCompletionProvider {
         &self,
         request: LanguageModelRequest,
         cx: &AppContext,
-    ) -> Option<BoxFuture<'static, Result<usize>>> {
+    ) -> BoxFuture<'static, Result<usize>> {
         if let Some(model) = self.active_model() {
-            Some(model.count_tokens(request, cx))
+            model.count_tokens(request, cx)
         } else {
-            None
+            future::ready(Err(anyhow!("no active model"))).boxed()
         }
     }
 
@@ -183,6 +186,29 @@ impl LanguageModelCompletionProvider {
             Ok(completion)
         })
     }
+
+    pub fn use_tool<T: LanguageModelTool>(
+        &self,
+        request: LanguageModelRequest,
+        cx: &AppContext,
+    ) -> Task<Result<T>> {
+        if let Some(language_model) = self.active_model() {
+            cx.spawn(|cx| async move {
+                let schema = schemars::schema_for!(T);
+                let schema_json = serde_json::to_value(&schema).unwrap();
+                let request =
+                    language_model.use_tool(request, T::name(), T::description(), schema_json, &cx);
+                let response = request.await?;
+                Ok(serde_json::from_value(response)?)
+            })
+        } else {
+            Task::ready(Err(anyhow!("No active model set")))
+        }
+    }
+
+    pub fn active_model_telemetry_id(&self) -> Option<String> {
+        self.active_model.as_ref().map(|m| m.telemetry_id())
+    }
 }
 
 #[cfg(test)]

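To make the new entry point concrete, a hedged sketch of defining and invoking a tool through `use_tool`; the `EditOperations` type, its field, and the commented call site are hypothetical, not part of this change:

    use schemars::JsonSchema;
    use serde::Deserialize;

    // Hypothetical tool payload: the model fills this in through a forced tool
    // call instead of emitting XML that has to be parsed out of a completion.
    #[derive(Debug, Deserialize, JsonSchema)]
    struct EditOperations {
        /// One description per edit the model wants to make.
        operations: Vec<String>,
    }

    impl language_model::LanguageModelTool for EditOperations {
        fn name() -> String {
            "edit_operations".into()
        }

        fn description() -> String {
            "Structured edit operations to apply to the buffer".into()
        }
    }

    // At a call site holding a LanguageModelRequest and an &AppContext:
    //
    //     let task: Task<Result<EditOperations>> = LanguageModelCompletionProvider::read_global(cx)
    //         .use_tool::<EditOperations>(request, cx);
    //
    // use_tool derives the JSON schema from the type, forces the model to call
    // the tool, and deserializes the tool's input back into EditOperations.
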
crates/language_model/src/language_model.rs 🔗

@@ -16,6 +16,8 @@ pub use model::*;
 pub use registry::*;
 pub use request::*;
 pub use role::*;
+use schemars::JsonSchema;
+use serde::de::DeserializeOwned;
 
 pub fn init(client: Arc<Client>, cx: &mut AppContext) {
     settings::init(cx);
@@ -42,6 +44,20 @@ pub trait LanguageModel: Send + Sync {
         request: LanguageModelRequest,
         cx: &AsyncAppContext,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
+
+    fn use_tool(
+        &self,
+        request: LanguageModelRequest,
+        name: String,
+        description: String,
+        schema: serde_json::Value,
+        cx: &AsyncAppContext,
+    ) -> BoxFuture<'static, Result<serde_json::Value>>;
+}
+
+pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
+    fn name() -> String;
+    fn description() -> String;
 }
 
 pub trait LanguageModelProvider: 'static {

crates/language_model/src/provider/anthropic.rs 🔗

@@ -1,5 +1,9 @@
-use anthropic::stream_completion;
-use anyhow::{anyhow, Result};
+use crate::{
+    settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
+    LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
+    LanguageModelProviderState, LanguageModelRequest, Role,
+};
+use anyhow::{anyhow, Context as _, Result};
 use collections::BTreeMap;
 use editor::{Editor, EditorElement, EditorStyle};
 use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
@@ -15,12 +19,6 @@ use theme::ThemeSettings;
 use ui::prelude::*;
 use util::ResultExt;
 
-use crate::{
-    settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
-    LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
-    LanguageModelProviderState, LanguageModelRequest, Role,
-};
-
 const PROVIDER_ID: &str = "anthropic";
 const PROVIDER_NAME: &str = "Anthropic";
 
@@ -188,6 +186,61 @@ pub fn count_anthropic_tokens(
         .boxed()
 }
 
+impl AnthropicModel {
+    fn request_completion(
+        &self,
+        request: anthropic::Request,
+        cx: &AsyncAppContext,
+    ) -> BoxFuture<'static, Result<anthropic::Response>> {
+        let http_client = self.http_client.clone();
+
+        let Ok((api_key, api_url)) = cx.read_model(&self.state, |state, cx| {
+            let settings = &AllLanguageModelSettings::get_global(cx).anthropic;
+            (state.api_key.clone(), settings.api_url.clone())
+        }) else {
+            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
+        };
+
+        async move {
+            let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
+            anthropic::complete(http_client.as_ref(), &api_url, &api_key, request).await
+        }
+        .boxed()
+    }
+
+    fn stream_completion(
+        &self,
+        request: anthropic::Request,
+        cx: &AsyncAppContext,
+    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<anthropic::Event>>>> {
+        let http_client = self.http_client.clone();
+
+        let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, cx| {
+            let settings = &AllLanguageModelSettings::get_global(cx).anthropic;
+            (
+                state.api_key.clone(),
+                settings.api_url.clone(),
+                settings.low_speed_timeout,
+            )
+        }) else {
+            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
+        };
+
+        async move {
+            let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
+            let request = anthropic::stream_completion(
+                http_client.as_ref(),
+                &api_url,
+                &api_key,
+                request,
+                low_speed_timeout,
+            );
+            request.await
+        }
+        .boxed()
+    }
+}
+
 impl LanguageModel for AnthropicModel {
     fn id(&self) -> LanguageModelId {
         self.id.clone()
@@ -227,34 +280,53 @@ impl LanguageModel for AnthropicModel {
         cx: &AsyncAppContext,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
         let request = request.into_anthropic(self.model.id().into());
-
-        let http_client = self.http_client.clone();
-
-        let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, cx| {
-            let settings = &AllLanguageModelSettings::get_global(cx).anthropic;
-            (
-                state.api_key.clone(),
-                settings.api_url.clone(),
-                settings.low_speed_timeout,
-            )
-        }) else {
-            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
-        };
-
+        let request = self.stream_completion(request, cx);
         async move {
-            let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
-            let request = stream_completion(
-                http_client.as_ref(),
-                &api_url,
-                &api_key,
-                request,
-                low_speed_timeout,
-            );
             let response = request.await?;
             Ok(anthropic::extract_text_from_events(response).boxed())
         }
         .boxed()
     }
+
+    fn use_tool(
+        &self,
+        request: LanguageModelRequest,
+        tool_name: String,
+        tool_description: String,
+        input_schema: serde_json::Value,
+        cx: &AsyncAppContext,
+    ) -> BoxFuture<'static, Result<serde_json::Value>> {
+        let mut request = request.into_anthropic(self.model.id().into());
+        request.tool_choice = Some(anthropic::ToolChoice::Tool {
+            name: tool_name.clone(),
+        });
+        request.tools = vec![anthropic::Tool {
+            name: tool_name.clone(),
+            description: tool_description,
+            input_schema,
+        }];
+
+        let response = self.request_completion(request, cx);
+        async move {
+            let response = response.await?;
+            response
+                .content
+                .into_iter()
+                .find_map(|content| {
+                    if let anthropic::Content::ToolUse { name, input, .. } = content {
+                        if name == tool_name {
+                            Some(input)
+                        } else {
+                            None
+                        }
+                    } else {
+                        None
+                    }
+                })
+                .context("tool not used")
+        }
+        .boxed()
+    }
 }
 
 struct AuthenticationPrompt {

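For orientation, roughly what the forced-tool exchange looks like on the wire, assuming the standard Anthropic Messages API encoding (not verified against this crate's serde attributes):

    // The request gains:
    //   "tool_choice": { "type": "tool", "name": "edit_operations" },
    //   "tools": [{ "name": "edit_operations", "description": "...", "input_schema": { ... } }]
    // and the response content should then contain a block like:
    //   { "type": "tool_use", "id": "toolu_...", "name": "edit_operations", "input": { ... } }
    // use_tool scans response.content for the first ToolUse whose name matches
    // and hands back its `input` as a serde_json::Value.
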
crates/language_model/src/provider/cloud.rs 🔗

@@ -4,7 +4,7 @@ use crate::{
     LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
     LanguageModelProviderState, LanguageModelRequest,
 };
-use anyhow::Result;
+use anyhow::{anyhow, Context as _, Result};
 use client::Client;
 use collections::BTreeMap;
 use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
@@ -12,7 +12,7 @@ use gpui::{AnyView, AppContext, AsyncAppContext, Subscription, Task};
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use settings::{Settings, SettingsStore};
-use std::sync::Arc;
+use std::{future, sync::Arc};
 use strum::IntoEnumIterator;
 use ui::prelude::*;
 
@@ -234,15 +234,13 @@ impl LanguageModel for CloudLanguageModel {
                 };
                 async move {
                     let request = serde_json::to_string(&request)?;
-                    let response = client.request(proto::QueryLanguageModel {
-                        provider: proto::LanguageModelProvider::Google as i32,
-                        kind: proto::LanguageModelRequestKind::CountTokens as i32,
-                        request,
-                    });
-                    let response = response.await?;
-                    let response =
-                        serde_json::from_str::<google_ai::CountTokensResponse>(&response.response)?;
-                    Ok(response.total_tokens)
+                    let response = client
+                        .request(proto::CountLanguageModelTokens {
+                            provider: proto::LanguageModelProvider::Google as i32,
+                            request,
+                        })
+                        .await?;
+                    Ok(response.token_count as usize)
                 }
                 .boxed()
             }
@@ -260,14 +258,14 @@ impl LanguageModel for CloudLanguageModel {
                 let request = request.into_anthropic(model.id().into());
                 async move {
                     let request = serde_json::to_string(&request)?;
-                    let response = client.request_stream(proto::QueryLanguageModel {
-                        provider: proto::LanguageModelProvider::Anthropic as i32,
-                        kind: proto::LanguageModelRequestKind::Complete as i32,
-                        request,
-                    });
-                    let chunks = response.await?;
+                    let stream = client
+                        .request_stream(proto::StreamCompleteWithLanguageModel {
+                            provider: proto::LanguageModelProvider::Anthropic as i32,
+                            request,
+                        })
+                        .await?;
                     Ok(anthropic::extract_text_from_events(
-                        chunks.map(|chunk| Ok(serde_json::from_str(&chunk?.response)?)),
+                        stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
                     )
                     .boxed())
                 }
@@ -278,14 +276,14 @@ impl LanguageModel for CloudLanguageModel {
                 let request = request.into_open_ai(model.id().into());
                 async move {
                     let request = serde_json::to_string(&request)?;
-                    let response = client.request_stream(proto::QueryLanguageModel {
-                        provider: proto::LanguageModelProvider::OpenAi as i32,
-                        kind: proto::LanguageModelRequestKind::Complete as i32,
-                        request,
-                    });
-                    let chunks = response.await?;
+                    let stream = client
+                        .request_stream(proto::StreamCompleteWithLanguageModel {
+                            provider: proto::LanguageModelProvider::OpenAi as i32,
+                            request,
+                        })
+                        .await?;
                     Ok(open_ai::extract_text_from_events(
-                        chunks.map(|chunk| Ok(serde_json::from_str(&chunk?.response)?)),
+                        stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
                     )
                     .boxed())
                 }
@@ -296,14 +294,14 @@ impl LanguageModel for CloudLanguageModel {
                 let request = request.into_google(model.id().into());
                 async move {
                     let request = serde_json::to_string(&request)?;
-                    let response = client.request_stream(proto::QueryLanguageModel {
-                        provider: proto::LanguageModelProvider::Google as i32,
-                        kind: proto::LanguageModelRequestKind::Complete as i32,
-                        request,
-                    });
-                    let chunks = response.await?;
+                    let stream = client
+                        .request_stream(proto::StreamCompleteWithLanguageModel {
+                            provider: proto::LanguageModelProvider::Google as i32,
+                            request,
+                        })
+                        .await?;
                     Ok(google_ai::extract_text_from_events(
-                        chunks.map(|chunk| Ok(serde_json::from_str(&chunk?.response)?)),
+                        stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
                     )
                     .boxed())
                 }
@@ -311,6 +309,63 @@ impl LanguageModel for CloudLanguageModel {
             }
         }
     }
+
+    fn use_tool(
+        &self,
+        request: LanguageModelRequest,
+        tool_name: String,
+        tool_description: String,
+        input_schema: serde_json::Value,
+        _cx: &AsyncAppContext,
+    ) -> BoxFuture<'static, Result<serde_json::Value>> {
+        match &self.model {
+            CloudModel::Anthropic(model) => {
+                let client = self.client.clone();
+                let mut request = request.into_anthropic(model.id().into());
+                request.tool_choice = Some(anthropic::ToolChoice::Tool {
+                    name: tool_name.clone(),
+                });
+                request.tools = vec![anthropic::Tool {
+                    name: tool_name.clone(),
+                    description: tool_description,
+                    input_schema,
+                }];
+
+                async move {
+                    let request = serde_json::to_string(&request)?;
+                    let response = client
+                        .request(proto::CompleteWithLanguageModel {
+                            provider: proto::LanguageModelProvider::Anthropic as i32,
+                            request,
+                        })
+                        .await?;
+                    let response: anthropic::Response = serde_json::from_str(&response.completion)?;
+                    response
+                        .content
+                        .into_iter()
+                        .find_map(|content| {
+                            if let anthropic::Content::ToolUse { name, input, .. } = content {
+                                if name == tool_name {
+                                    Some(input)
+                                } else {
+                                    None
+                                }
+                            } else {
+                                None
+                            }
+                        })
+                        .context("tool not used")
+                }
+                .boxed()
+            }
+            CloudModel::OpenAi(_) => {
+                future::ready(Err(anyhow!("tool use not implemented for OpenAI"))).boxed()
+            }
+            CloudModel::Google(_) => {
+                future::ready(Err(anyhow!("tool use not implemented for Google AI"))).boxed()
+            }
+        }
+    }
 }
 
 struct AuthenticationPrompt {

crates/language_model/src/provider/fake.rs 🔗

@@ -1,15 +1,17 @@
-use std::sync::{Arc, Mutex};
-
-use collections::HashMap;
-use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
-
 use crate::{
     LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
     LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
     LanguageModelRequest,
 };
+use anyhow::anyhow;
+use collections::HashMap;
+use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
 use gpui::{AnyView, AppContext, AsyncAppContext, Task};
 use http_client::Result;
+use std::{
+    future,
+    sync::{Arc, Mutex},
+};
 use ui::WindowContext;
 
 pub fn language_model_id() -> LanguageModelId {
@@ -170,4 +172,15 @@ impl LanguageModel for FakeLanguageModel {
             .insert(serde_json::to_string(&request).unwrap(), tx);
         async move { Ok(rx.map(Ok).boxed()) }.boxed()
     }
+
+    fn use_tool(
+        &self,
+        _request: LanguageModelRequest,
+        _name: String,
+        _description: String,
+        _schema: serde_json::Value,
+        _cx: &AsyncAppContext,
+    ) -> BoxFuture<'static, Result<serde_json::Value>> {
+        future::ready(Err(anyhow!("not implemented"))).boxed()
+    }
 }

crates/language_model/src/provider/google.rs 🔗

@@ -9,7 +9,7 @@ use gpui::{
 };
 use http_client::HttpClient;
 use settings::{Settings, SettingsStore};
-use std::{sync::Arc, time::Duration};
+use std::{future, sync::Arc, time::Duration};
 use strum::IntoEnumIterator;
 use theme::ThemeSettings;
 use ui::prelude::*;
@@ -238,6 +238,17 @@ impl LanguageModel for GoogleLanguageModel {
         }
         .boxed()
     }
+
+    fn use_tool(
+        &self,
+        _request: LanguageModelRequest,
+        _name: String,
+        _description: String,
+        _schema: serde_json::Value,
+        _cx: &AsyncAppContext,
+    ) -> BoxFuture<'static, Result<serde_json::Value>> {
+        future::ready(Err(anyhow!("not implemented"))).boxed()
+    }
 }
 
 struct AuthenticationPrompt {

crates/language_model/src/provider/ollama.rs 🔗

@@ -6,7 +6,7 @@ use ollama::{
     get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest,
 };
 use settings::{Settings, SettingsStore};
-use std::{sync::Arc, time::Duration};
+use std::{future, sync::Arc, time::Duration};
 use ui::{prelude::*, ButtonLike, ElevationIndex};
 
 use crate::{
@@ -298,6 +298,17 @@ impl LanguageModel for OllamaLanguageModel {
         }
         .boxed()
     }
+
+    fn use_tool(
+        &self,
+        _request: LanguageModelRequest,
+        _name: String,
+        _description: String,
+        _schema: serde_json::Value,
+        _cx: &AsyncAppContext,
+    ) -> BoxFuture<'static, Result<serde_json::Value>> {
+        future::ready(Err(anyhow!("not implemented"))).boxed()
+    }
 }
 
 struct DownloadOllamaMessage {

crates/language_model/src/provider/open_ai.rs 🔗

@@ -9,7 +9,7 @@ use gpui::{
 use http_client::HttpClient;
 use open_ai::stream_completion;
 use settings::{Settings, SettingsStore};
-use std::{sync::Arc, time::Duration};
+use std::{future, sync::Arc, time::Duration};
 use strum::IntoEnumIterator;
 use theme::ThemeSettings;
 use ui::prelude::*;
@@ -225,6 +225,17 @@ impl LanguageModel for OpenAiLanguageModel {
         }
         .boxed()
     }
+
+    fn use_tool(
+        &self,
+        _request: LanguageModelRequest,
+        _name: String,
+        _description: String,
+        _schema: serde_json::Value,
+        _cx: &AsyncAppContext,
+    ) -> BoxFuture<'static, Result<serde_json::Value>> {
+        future::ready(Err(anyhow!("not implemented"))).boxed()
+    }
 }
 
 pub fn count_open_ai_tokens(

crates/language_model/src/request.rs 🔗

@@ -106,19 +106,27 @@ impl LanguageModelRequest {
             messages: new_messages
                 .into_iter()
                 .filter_map(|message| {
-                    Some(anthropic::RequestMessage {
+                    Some(anthropic::Message {
                         role: match message.role {
                             Role::User => anthropic::Role::User,
                             Role::Assistant => anthropic::Role::Assistant,
                             Role::System => return None,
                         },
-                        content: message.content,
+                        content: vec![anthropic::Content::Text {
+                            text: message.content,
+                        }],
                     })
                 })
                 .collect(),
-            stream: true,
             max_tokens: 4092,
-            system: system_message,
+            system: Some(system_message),
+            tools: Vec::new(),
+            tool_choice: None,
+            metadata: None,
+            stop_sequences: Vec::new(),
+            temperature: None,
+            top_k: None,
+            top_p: None,
         }
     }
 }
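
The conversion above reflects the reshaped `anthropic::Request`: message content is now a list of typed blocks rather than a bare string, and the request carries the full set of sampling and tool fields so one type serves both plain completions and tool use (streaming moved out of the request body, per the removed `stream` field). A hedged sketch of constructing such a request by hand, assuming the field names in the diff above plus a `model` field not shown here:

```rust
// Assumes the anthropic crate types as changed in this PR; field names
// follow the diff above and may differ in later versions of the crate.
fn minimal_request(model: String, prompt: String) -> anthropic::Request {
    anthropic::Request {
        model,
        messages: vec![anthropic::Message {
            role: anthropic::Role::User,
            content: vec![anthropic::Content::Text { text: prompt }],
        }],
        max_tokens: 4092,
        system: None,
        tools: Vec::new(),
        tool_choice: None,
        metadata: None,
        stop_sequences: Vec::new(),
        temperature: None,
        top_k: None,
        top_p: None,
    }
}
```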

crates/proto/proto/zed.proto 🔗

@@ -194,8 +194,12 @@ message Envelope {
 
         JoinHostedProject join_hosted_project = 164;
 
-        QueryLanguageModel query_language_model = 224;
-        QueryLanguageModelResponse query_language_model_response = 225; // current max
+        CompleteWithLanguageModel complete_with_language_model = 226;
+        CompleteWithLanguageModelResponse complete_with_language_model_response = 227;
+        StreamCompleteWithLanguageModel stream_complete_with_language_model = 228;
+        StreamCompleteWithLanguageModelResponse stream_complete_with_language_model_response = 229;
+        CountLanguageModelTokens count_language_model_tokens = 230;
+        CountLanguageModelTokensResponse count_language_model_tokens_response = 231; // current max
         GetCachedEmbeddings get_cached_embeddings = 189;
         GetCachedEmbeddingsResponse get_cached_embeddings_response = 190;
         ComputeEmbeddings compute_embeddings = 191;
@@ -267,6 +271,7 @@ message Envelope {
 
     reserved 158 to 161;
     reserved 166 to 169;
+    reserved 224 to 225;
 }
 
 // Messages
@@ -2050,25 +2055,37 @@ enum LanguageModelRole {
     reserved 3;
 }
 
-message QueryLanguageModel {
+message CompleteWithLanguageModel {
     LanguageModelProvider provider = 1;
-    LanguageModelRequestKind kind = 2;
-    string request = 3;
+    string request = 2;
 }
 
-enum LanguageModelProvider {
-    Anthropic = 0;
-    OpenAI = 1;
-    Google = 2;
+message CompleteWithLanguageModelResponse {
+    string completion = 1;
+}
+
+message StreamCompleteWithLanguageModel {
+    LanguageModelProvider provider = 1;
+    string request = 2;
+}
+
+message StreamCompleteWithLanguageModelResponse {
+    string event = 1;
+}
+
+message CountLanguageModelTokens {
+    LanguageModelProvider provider = 1;
+    string request = 2;
 }
 
-enum LanguageModelRequestKind {
-    Complete = 0;
-    CountTokens = 1;
+message CountLanguageModelTokensResponse {
+    uint32 token_count = 1;
 }
 
-message QueryLanguageModelResponse {
-    string response = 1;
+enum LanguageModelProvider {
+    Anthropic = 0;
+    OpenAI = 1;
+    Google = 2;
 }
 
 message GetCachedEmbeddings {
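
The old `QueryLanguageModel` message multiplexed completion and token counting behind a `LanguageModelRequestKind` enum; this change splits it into three typed request/response pairs and reserves the retired envelope tags 224 and 225 so an older peer can never misread a reused field number. For the streaming pair, each response frame carries one serialized provider event. A rough client-side sketch, where `request_stream` is an assumed rpc helper rather than something shown in this diff:

```rust
use anyhow::Result;
use futures::StreamExt;
use std::sync::Arc;

// Hypothetical consumption of the streaming RPC; `client::Client` and
// `request_stream` are assumptions, not confirmed by this diff.
async fn stream_completion(
    client: &Arc<client::Client>,
    request_json: String,
) -> Result<()> {
    let mut events = client
        .request_stream(proto::StreamCompleteWithLanguageModel {
            provider: proto::LanguageModelProvider::Anthropic as i32,
            request: request_json,
        })
        .await?;
    // Each frame's `event` field holds one serialized provider event.
    while let Some(response) = events.next().await {
        println!("{}", response?.event);
    }
    Ok(())
}
```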

crates/proto/src/proto.rs 🔗

@@ -294,8 +294,12 @@ messages!(
     (PrepareRename, Background),
     (PrepareRenameResponse, Background),
     (ProjectEntryResponse, Foreground),
-    (QueryLanguageModel, Background),
-    (QueryLanguageModelResponse, Background),
+    (CompleteWithLanguageModel, Background),
+    (CompleteWithLanguageModelResponse, Background),
+    (StreamCompleteWithLanguageModel, Background),
+    (StreamCompleteWithLanguageModelResponse, Background),
+    (CountLanguageModelTokens, Background),
+    (CountLanguageModelTokensResponse, Background),
     (RefreshInlayHints, Foreground),
     (RejoinChannelBuffers, Foreground),
     (RejoinChannelBuffersResponse, Foreground),
@@ -463,7 +467,12 @@ request_messages!(
     (PerformRename, PerformRenameResponse),
     (Ping, Ack),
     (PrepareRename, PrepareRenameResponse),
-    (QueryLanguageModel, QueryLanguageModelResponse),
+    (CompleteWithLanguageModel, CompleteWithLanguageModelResponse),
+    (
+        StreamCompleteWithLanguageModel,
+        StreamCompleteWithLanguageModelResponse
+    ),
+    (CountLanguageModelTokens, CountLanguageModelTokensResponse),
     (RefreshInlayHints, Ack),
     (RejoinChannelBuffers, RejoinChannelBuffersResponse),
     (RejoinRoom, RejoinRoomResponse),
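
`request_messages!` pairs each request type with its response type, which is what lets a plain `request(..)` call resolve to the right message. A minimal usage sketch for the new token-count RPC, assuming Zed's `client::Client::request` helper:

```rust
use anyhow::Result;
use std::sync::Arc;

// Sketch only: `client::Client::request` is assumed to resolve to the
// response type paired in `request_messages!` above.
async fn count_tokens(client: &Arc<client::Client>, request_json: String) -> Result<u32> {
    let response = client
        .request(proto::CountLanguageModelTokens {
            provider: proto::LanguageModelProvider::Anthropic as i32,
            request: request_json,
        })
        .await?;
    Ok(response.token_count)
}
```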