Detailed changes
@@ -435,7 +435,6 @@ dependencies = [
"rand 0.8.5",
"regex",
"rope",
- "roxmltree 0.20.0",
"schemars",
"search",
"semantic_index",
@@ -2641,7 +2640,9 @@ dependencies = [
"language_model",
"project",
"rand 0.8.5",
+ "schemars",
"serde",
+ "serde_json",
"settings",
"smol",
"text",
@@ -4237,7 +4238,7 @@ version = "0.5.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6a595cb550439a117696039dfc69830492058211b771a2a165379f2a1a53d84d"
dependencies = [
- "roxmltree 0.19.0",
+ "roxmltree",
]
[[package]]
@@ -8918,12 +8919,6 @@ version = "0.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3cd14fd5e3b777a7422cca79358c57a8f6e3a703d9ac187448d0daf220c2407f"
-[[package]]
-name = "roxmltree"
-version = "0.20.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6c20b6793b5c2fa6553b250154b78d6d0db37e72700ae35fad9387a46f487c97"
-
[[package]]
name = "rpc"
version = "0.1.0"
@@ -11878,7 +11873,7 @@ dependencies = [
"kurbo",
"log",
"pico-args",
- "roxmltree 0.19.0",
+ "roxmltree",
"simplecss",
"siphasher 1.0.1",
"strict-num",
@@ -3,7 +3,7 @@ use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, S
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use isahc::config::Configurable;
use serde::{Deserialize, Serialize};
-use std::{convert::TryFrom, time::Duration};
+use std::time::Duration;
use strum::EnumIter;
pub const ANTHROPIC_API_URL: &'static str = "https://api.anthropic.com";
@@ -70,112 +70,53 @@ impl Model {
}
}
-#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
-#[serde(rename_all = "lowercase")]
-pub enum Role {
- User,
- Assistant,
-}
+pub async fn complete(
+ client: &dyn HttpClient,
+ api_url: &str,
+ api_key: &str,
+ request: Request,
+) -> Result<Response> {
+ let uri = format!("{api_url}/v1/messages");
+ let request_builder = HttpRequest::builder()
+ .method(Method::POST)
+ .uri(uri)
+ .header("Anthropic-Version", "2023-06-01")
+ .header("Anthropic-Beta", "tools-2024-04-04")
+ .header("X-Api-Key", api_key)
+ .header("Content-Type", "application/json");
-impl TryFrom<String> for Role {
- type Error = anyhow::Error;
+ let serialized_request = serde_json::to_string(&request)?;
+ let request = request_builder.body(AsyncBody::from(serialized_request))?;
- fn try_from(value: String) -> Result<Self> {
- match value.as_str() {
- "user" => Ok(Self::User),
- "assistant" => Ok(Self::Assistant),
- _ => Err(anyhow!("invalid role '{value}'")),
- }
- }
-}
-
-impl From<Role> for String {
- fn from(val: Role) -> Self {
- match val {
- Role::User => "user".to_owned(),
- Role::Assistant => "assistant".to_owned(),
- }
+ let mut response = client.send(request).await?;
+ if response.status().is_success() {
+ let mut body = Vec::new();
+ response.body_mut().read_to_end(&mut body).await?;
+ let response_message: Response = serde_json::from_slice(&body)?;
+ Ok(response_message)
+ } else {
+ let mut body = Vec::new();
+ response.body_mut().read_to_end(&mut body).await?;
+ let body_str = std::str::from_utf8(&body)?;
+ Err(anyhow!(
+ "Failed to connect to API: {} {}",
+ response.status(),
+ body_str
+ ))
}
}
-#[derive(Debug, Serialize, Deserialize)]
-pub struct Request {
- pub model: String,
- pub messages: Vec<RequestMessage>,
- pub stream: bool,
- pub system: String,
- pub max_tokens: u32,
-}
-
-#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
-pub struct RequestMessage {
- pub role: Role,
- pub content: String,
-}
-
-#[derive(Deserialize, Serialize, Debug)]
-#[serde(tag = "type", rename_all = "snake_case")]
-pub enum ResponseEvent {
- MessageStart {
- message: ResponseMessage,
- },
- ContentBlockStart {
- index: u32,
- content_block: ContentBlock,
- },
- Ping {},
- ContentBlockDelta {
- index: u32,
- delta: TextDelta,
- },
- ContentBlockStop {
- index: u32,
- },
- MessageDelta {
- delta: ResponseMessage,
- usage: Usage,
- },
- MessageStop {},
-}
-
-#[derive(Serialize, Deserialize, Debug)]
-pub struct ResponseMessage {
- #[serde(rename = "type")]
- pub message_type: Option<String>,
- pub id: Option<String>,
- pub role: Option<String>,
- pub content: Option<Vec<String>>,
- pub model: Option<String>,
- pub stop_reason: Option<String>,
- pub stop_sequence: Option<String>,
- pub usage: Option<Usage>,
-}
-
-#[derive(Serialize, Deserialize, Debug)]
-pub struct Usage {
- pub input_tokens: Option<u32>,
- pub output_tokens: Option<u32>,
-}
-
-#[derive(Serialize, Deserialize, Debug)]
-#[serde(tag = "type", rename_all = "snake_case")]
-pub enum ContentBlock {
- Text { text: String },
-}
-
-#[derive(Serialize, Deserialize, Debug)]
-#[serde(tag = "type", rename_all = "snake_case")]
-pub enum TextDelta {
- TextDelta { text: String },
-}
-
pub async fn stream_completion(
client: &dyn HttpClient,
api_url: &str,
api_key: &str,
request: Request,
low_speed_timeout: Option<Duration>,
-) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
+) -> Result<BoxStream<'static, Result<Event>>> {
+ let request = StreamingRequest {
+ base: request,
+ stream: true,
+ };
let uri = format!("{api_url}/v1/messages");
let mut request_builder = HttpRequest::builder()
.method(Method::POST)
@@ -187,7 +128,9 @@ pub async fn stream_completion(
if let Some(low_speed_timeout) = low_speed_timeout {
request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
}
- let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
+ let serialized_request = serde_json::to_string(&request)?;
+ let request = request_builder.body(AsyncBody::from(serialized_request))?;
+
let mut response = client.send(request).await?;
if response.status().is_success() {
let reader = BufReader::new(response.into_body());
@@ -212,7 +155,7 @@ pub async fn stream_completion(
let body_str = std::str::from_utf8(&body)?;
- match serde_json::from_str::<ResponseEvent>(body_str) {
+ match serde_json::from_str::<Event>(body_str) {
Ok(_) => Err(anyhow!(
"Unexpected success response while expecting an error: {}",
body_str,
@@ -227,16 +170,18 @@ pub async fn stream_completion(
}
pub fn extract_text_from_events(
- response: impl Stream<Item = Result<ResponseEvent>>,
+ response: impl Stream<Item = Result<Event>>,
) -> impl Stream<Item = Result<String>> {
response.filter_map(|response| async move {
match response {
Ok(response) => match response {
- ResponseEvent::ContentBlockStart { content_block, .. } => match content_block {
- ContentBlock::Text { text } => Some(Ok(text)),
+ Event::ContentBlockStart { content_block, .. } => match content_block {
+ Content::Text { text } => Some(Ok(text)),
+ _ => None,
},
- ResponseEvent::ContentBlockDelta { delta, .. } => match delta {
- TextDelta::TextDelta { text } => Some(Ok(text)),
+ Event::ContentBlockDelta { delta, .. } => match delta {
+ ContentDelta::TextDelta { text } => Some(Ok(text)),
+ _ => None,
},
_ => None,
},
@@ -245,42 +190,162 @@ pub fn extract_text_from_events(
})
}
-// #[cfg(test)]
-// mod tests {
-// use super::*;
-// use http::IsahcHttpClient;
+#[derive(Debug, Serialize, Deserialize)]
+pub struct Message {
+ pub role: Role,
+ pub content: Vec<Content>,
+}
-// #[tokio::test]
-// async fn stream_completion_success() {
-// let http_client = IsahcHttpClient::new().unwrap();
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "lowercase")]
+pub enum Role {
+ User,
+ Assistant,
+}
-// let request = Request {
-// model: Model::Claude3Opus,
-// messages: vec![RequestMessage {
-// role: Role::User,
-// content: "Ping".to_string(),
-// }],
-// stream: true,
-// system: "Respond to ping with pong".to_string(),
-// max_tokens: 4096,
-// };
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(tag = "type")]
+pub enum Content {
+ #[serde(rename = "text")]
+ Text { text: String },
+ #[serde(rename = "image")]
+ Image { source: ImageSource },
+ #[serde(rename = "tool_use")]
+ ToolUse {
+ id: String,
+ name: String,
+ input: serde_json::Value,
+ },
+ #[serde(rename = "tool_result")]
+ ToolResult {
+ tool_use_id: String,
+ content: String,
+ },
+}
-// let stream = stream_completion(
-// &http_client,
-// "https://api.anthropic.com",
-// &std::env::var("ANTHROPIC_API_KEY").expect("ANTHROPIC_API_KEY not set"),
-// request,
-// )
-// .await
-// .unwrap();
+#[derive(Debug, Serialize, Deserialize)]
+pub struct ImageSource {
+ #[serde(rename = "type")]
+ pub source_type: String,
+ pub media_type: String,
+ pub data: String,
+}
-// stream
-// .for_each(|event| async {
-// match event {
-// Ok(event) => println!("{:?}", event),
-// Err(e) => eprintln!("Error: {:?}", e),
-// }
-// })
-// .await;
-// }
-// }
+#[derive(Debug, Serialize, Deserialize)]
+pub struct Tool {
+ pub name: String,
+ pub description: String,
+ pub input_schema: serde_json::Value,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(tag = "type", rename_all = "lowercase")]
+pub enum ToolChoice {
+ Auto,
+ Any,
+ Tool { name: String },
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct Request {
+ pub model: String,
+ pub max_tokens: u32,
+ pub messages: Vec<Message>,
+ #[serde(default, skip_serializing_if = "Vec::is_empty")]
+ pub tools: Vec<Tool>,
+ #[serde(default, skip_serializing_if = "Option::is_none")]
+ pub tool_choice: Option<ToolChoice>,
+ #[serde(default, skip_serializing_if = "Option::is_none")]
+ pub system: Option<String>,
+ #[serde(default, skip_serializing_if = "Option::is_none")]
+ pub metadata: Option<Metadata>,
+ #[serde(default, skip_serializing_if = "Vec::is_empty")]
+ pub stop_sequences: Vec<String>,
+ #[serde(default, skip_serializing_if = "Option::is_none")]
+ pub temperature: Option<f32>,
+ #[serde(default, skip_serializing_if = "Option::is_none")]
+ pub top_k: Option<u32>,
+ #[serde(default, skip_serializing_if = "Option::is_none")]
+ pub top_p: Option<f32>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+struct StreamingRequest {
+ #[serde(flatten)]
+ pub base: Request,
+ pub stream: bool,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct Metadata {
+ pub user_id: Option<String>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct Usage {
+ #[serde(default, skip_serializing_if = "Option::is_none")]
+ pub input_tokens: Option<u32>,
+ #[serde(default, skip_serializing_if = "Option::is_none")]
+ pub output_tokens: Option<u32>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct Response {
+ pub id: String,
+ #[serde(rename = "type")]
+ pub response_type: String,
+ pub role: Role,
+ pub content: Vec<Content>,
+ pub model: String,
+ #[serde(default, skip_serializing_if = "Option::is_none")]
+ pub stop_reason: Option<String>,
+ #[serde(default, skip_serializing_if = "Option::is_none")]
+ pub stop_sequence: Option<String>,
+ pub usage: Usage,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(tag = "type")]
+pub enum Event {
+ #[serde(rename = "message_start")]
+ MessageStart { message: Response },
+ #[serde(rename = "content_block_start")]
+ ContentBlockStart {
+ index: usize,
+ content_block: Content,
+ },
+ #[serde(rename = "content_block_delta")]
+ ContentBlockDelta { index: usize, delta: ContentDelta },
+ #[serde(rename = "content_block_stop")]
+ ContentBlockStop { index: usize },
+ #[serde(rename = "message_delta")]
+ MessageDelta { delta: MessageDelta, usage: Usage },
+ #[serde(rename = "message_stop")]
+ MessageStop,
+ #[serde(rename = "ping")]
+ Ping,
+ #[serde(rename = "error")]
+ Error { error: ApiError },
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(tag = "type")]
+pub enum ContentDelta {
+ #[serde(rename = "text_delta")]
+ TextDelta { text: String },
+ #[serde(rename = "input_json_delta")]
+ InputJsonDelta { partial_json: String },
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct MessageDelta {
+ pub stop_reason: Option<String>,
+ pub stop_sequence: Option<String>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct ApiError {
+ #[serde(rename = "type")]
+ pub error_type: String,
+ pub message: String,
+}
@@ -75,7 +75,6 @@ util.workspace = true
uuid.workspace = true
workspace.workspace = true
picker.workspace = true
-roxmltree = "0.20.0"
[dev-dependencies]
completion = { workspace = true, features = ["test-support"] }
@@ -1232,12 +1232,16 @@ impl ContextEditor {
fn apply_edit_step(&mut self, cx: &mut ViewContext<Self>) -> bool {
if let Some(step) = self.active_edit_step.as_ref() {
- InlineAssistant::update_global(cx, |assistant, cx| {
- for assist_id in &step.assist_ids {
- assistant.start_assist(*assist_id, cx);
- }
- !step.assist_ids.is_empty()
- })
+ let assist_ids = step.assist_ids.clone();
+ cx.window_context().defer(|cx| {
+ InlineAssistant::update_global(cx, |assistant, cx| {
+ for assist_id in assist_ids {
+ assistant.start_assist(assist_id, cx);
+ }
+ })
+ });
+
+ !step.assist_ids.is_empty()
} else {
false
}
@@ -1286,11 +1290,7 @@ impl ContextEditor {
.collect::<String>()
));
match &step.operations {
- Some(EditStepOperations::Parsed {
- operations,
- raw_output,
- }) => {
- output.push_str(&format!("Raw Output:\n{raw_output}\n"));
+ Some(EditStepOperations::Ready(operations)) => {
output.push_str("Parsed Operations:\n");
for op in operations {
output.push_str(&format!(" {:?}\n", op));
@@ -1794,13 +1794,12 @@ impl ContextEditor {
.anchor_in_excerpt(excerpt_id, suggestion.range.end)
.unwrap()
};
- let initial_text = suggestion.prepend_newline.then(|| "\n".into());
InlineAssistant::update_global(cx, |assistant, cx| {
assist_ids.push(assistant.suggest_assist(
&editor,
range,
description,
- initial_text,
+ suggestion.initial_insertion,
Some(workspace.clone()),
assistant_panel.upgrade().as_ref(),
cx,
@@ -1862,9 +1861,11 @@ impl ContextEditor {
.anchor_in_excerpt(excerpt_id, suggestion.range.end)
.unwrap()
};
- let initial_text =
- suggestion.prepend_newline.then(|| "\n".to_string());
- inline_assist_suggestions.push((range, description, initial_text));
+ inline_assist_suggestions.push((
+ range,
+ description,
+ suggestion.initial_insertion,
+ ));
}
}
}
@@ -1875,12 +1876,12 @@ impl ContextEditor {
.new_view(|cx| Editor::for_multibuffer(multibuffer, Some(project), true, cx))?;
cx.update(|cx| {
InlineAssistant::update_global(cx, |assistant, cx| {
- for (range, description, initial_text) in inline_assist_suggestions {
+ for (range, description, initial_insertion) in inline_assist_suggestions {
assist_ids.push(assistant.suggest_assist(
&editor,
range,
description,
- initial_text,
+ initial_insertion,
Some(workspace.clone()),
assistant_panel.upgrade().as_ref(),
cx,
@@ -2188,7 +2189,7 @@ impl ContextEditor {
let button_text = match self.edit_step_for_cursor(cx) {
Some(edit_step) => match &edit_step.operations {
Some(EditStepOperations::Pending(_)) => "Computing Changes...",
- Some(EditStepOperations::Parsed { .. }) => "Apply Changes",
+ Some(EditStepOperations::Ready(_)) => "Apply Changes",
None => "Send",
},
None => "Send",
@@ -1,6 +1,6 @@
use crate::{
- prompt_library::PromptStore, slash_command::SlashCommandLine, LanguageModelCompletionProvider,
- MessageId, MessageStatus,
+ prompt_library::PromptStore, slash_command::SlashCommandLine, InitialInsertion,
+ LanguageModelCompletionProvider, MessageId, MessageStatus,
};
use anyhow::{anyhow, Context as _, Result};
use assistant_slash_command::{
@@ -18,11 +18,11 @@ use gpui::{AppContext, Context as _, EventEmitter, Model, ModelContext, Subscrip
use language::{
AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, ParseStatus, Point, ToOffset,
};
-use language_model::LanguageModelRequestMessage;
-use language_model::{LanguageModelRequest, Role};
+use language_model::{LanguageModelRequest, LanguageModelRequestMessage, LanguageModelTool, Role};
use open_ai::Model as OpenAiModel;
use paths::contexts_dir;
use project::Project;
+use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::{
cmp,
@@ -352,7 +352,7 @@ pub struct EditSuggestion {
pub range: Range<language::Anchor>,
/// If None, assume this is a suggestion to delete the range rather than transform it.
pub description: Option<String>,
- pub prepend_newline: bool,
+ pub initial_insertion: Option<InitialInsertion>,
}
impl EditStep {
@@ -361,7 +361,7 @@ impl EditStep {
project: &Model<Project>,
cx: &AppContext,
) -> Task<HashMap<Model<Buffer>, Vec<EditSuggestionGroup>>> {
- let Some(EditStepOperations::Parsed { operations, .. }) = &self.operations else {
+ let Some(EditStepOperations::Ready(operations)) = &self.operations else {
return Task::ready(HashMap::default());
};
@@ -471,32 +471,28 @@ impl EditStep {
}
pub enum EditStepOperations {
- Pending(Task<Result<()>>),
- Parsed {
- operations: Vec<EditOperation>,
- raw_output: String,
- },
+ Pending(Task<Option<()>>),
+ Ready(Vec<EditOperation>),
}
impl Debug for EditStepOperations {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
EditStepOperations::Pending(_) => write!(f, "EditStepOperations::Pending"),
- EditStepOperations::Parsed {
- operations,
- raw_output,
- } => f
+ EditStepOperations::Ready(operations) => f
.debug_struct("EditStepOperations::Parsed")
.field("operations", operations)
- .field("raw_output", raw_output)
.finish(),
}
}
}
-#[derive(Clone, Debug, PartialEq, Eq)]
+/// A description of an operation to apply to one location in the codebase.
+#[derive(Clone, Debug, PartialEq, Eq, Deserialize, JsonSchema)]
pub struct EditOperation {
+ /// The path to the file containing the relevant operation
pub path: String,
+ #[serde(flatten)]
pub kind: EditOperationKind,
}
@@ -523,7 +519,7 @@ impl EditOperation {
parse_status.changed().await?;
}
- let prepend_newline = kind.prepend_newline();
+ let initial_insertion = kind.initial_insertion();
let suggestion_range = if let Some(symbol) = kind.symbol() {
let outline = buffer
.update(&mut cx, |buffer, _| buffer.snapshot().outline(None))?
@@ -601,39 +597,61 @@ impl EditOperation {
EditSuggestion {
range: suggestion_range,
description: kind.description().map(ToString::to_string),
- prepend_newline,
+ initial_insertion,
},
))
})
}
}
-#[derive(Clone, Debug, PartialEq, Eq)]
+#[derive(Clone, Debug, PartialEq, Eq, Deserialize, JsonSchema)]
+#[serde(tag = "kind")]
pub enum EditOperationKind {
+ /// Rewrite the specified symbol in its entirely based on the given description.
Update {
+ /// A full path to the symbol to be rewritten from the provided list.
symbol: String,
+ /// A brief one-line description of the change that should be applied.
description: String,
},
+ /// Create a new file with the given path based on the given description.
Create {
+ /// A brief one-line description of the change that should be applied.
description: String,
},
+ /// Insert a new symbol based on the given description before the specified symbol.
InsertSiblingBefore {
+ /// A full path to the symbol to be rewritten from the provided list.
symbol: String,
+ /// A brief one-line description of the change that should be applied.
description: String,
},
+ /// Insert a new symbol based on the given description after the specified symbol.
InsertSiblingAfter {
+ /// A full path to the symbol to be rewritten from the provided list.
symbol: String,
+ /// A brief one-line description of the change that should be applied.
description: String,
},
+ /// Insert a new symbol as a child of the specified symbol at the start.
PrependChild {
+ /// An optional full path to the symbol to be rewritten from the provided list.
+ /// If not provided, the edit should be applied at the top of the file.
symbol: Option<String>,
+ /// A brief one-line description of the change that should be applied.
description: String,
},
+ /// Insert a new symbol as a child of the specified symbol at the end.
AppendChild {
+ /// An optional full path to the symbol to be rewritten from the provided list.
+ /// If not provided, the edit should be applied at the top of the file.
symbol: Option<String>,
+ /// A brief one-line description of the change that should be applied.
description: String,
},
+ /// Delete the specified symbol.
Delete {
+ /// A full path to the symbol to be rewritten from the provided list.
symbol: String,
},
}
@@ -663,13 +681,13 @@ impl EditOperationKind {
}
}
- pub fn prepend_newline(&self) -> bool {
+ pub fn initial_insertion(&self) -> Option<InitialInsertion> {
match self {
- Self::PrependChild { .. }
- | Self::AppendChild { .. }
- | Self::InsertSiblingAfter { .. }
- | Self::InsertSiblingBefore { .. } => true,
- _ => false,
+ EditOperationKind::InsertSiblingBefore { .. } => Some(InitialInsertion::NewlineAfter),
+ EditOperationKind::InsertSiblingAfter { .. } => Some(InitialInsertion::NewlineBefore),
+ EditOperationKind::PrependChild { .. } => Some(InitialInsertion::NewlineAfter),
+ EditOperationKind::AppendChild { .. } => Some(InitialInsertion::NewlineBefore),
+ _ => None,
}
}
}
@@ -1137,18 +1155,15 @@ impl Context {
.timer(Duration::from_millis(200))
.await;
- if let Some(token_count) = cx.update(|cx| {
- LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
- })? {
- let token_count = token_count.await?;
-
- this.update(&mut cx, |this, cx| {
- this.token_count = Some(token_count);
- cx.notify()
- })?;
- }
-
- anyhow::Ok(())
+ let token_count = cx
+ .update(|cx| {
+ LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
+ })?
+ .await?;
+ this.update(&mut cx, |this, cx| {
+ this.token_count = Some(token_count);
+ cx.notify()
+ })
}
.log_err()
});
@@ -1304,7 +1319,24 @@ impl Context {
&self,
edit_step: &EditStep,
cx: &mut ModelContext<Self>,
- ) -> Task<Result<()>> {
+ ) -> Task<Option<()>> {
+ #[derive(Debug, Deserialize, JsonSchema)]
+ struct EditTool {
+ /// A sequence of operations to apply to the codebase.
+ /// When multiple operations are required for a step, be sure to include multiple operations in this list.
+ operations: Vec<EditOperation>,
+ }
+
+ impl LanguageModelTool for EditTool {
+ fn name() -> String {
+ "edit".into()
+ }
+
+ fn description() -> String {
+ "suggest edits to one or more locations in the codebase".into()
+ }
+ }
+
let mut request = self.to_completion_request(cx);
let edit_step_range = edit_step.source_range.clone();
let step_text = self
@@ -1313,160 +1345,41 @@ impl Context {
.text_for_range(edit_step_range.clone())
.collect::<String>();
- cx.spawn(|this, mut cx| async move {
- let prompt_store = cx.update(|cx| PromptStore::global(cx))?.await?;
-
- let mut prompt = prompt_store.operations_prompt();
- prompt.push_str(&step_text);
+ cx.spawn(|this, mut cx| {
+ async move {
+ let prompt_store = cx.update(|cx| PromptStore::global(cx))?.await?;
- request.messages.push(LanguageModelRequestMessage {
- role: Role::User,
- content: prompt,
- });
+ let mut prompt = prompt_store.operations_prompt();
+ prompt.push_str(&step_text);
- let raw_output = cx
- .update(|cx| {
- LanguageModelCompletionProvider::read_global(cx).complete(request, cx)
- })?
- .await?;
+ request.messages.push(LanguageModelRequestMessage {
+ role: Role::User,
+ content: prompt,
+ });
- let operations = Self::parse_edit_operations(&raw_output);
- this.update(&mut cx, |this, cx| {
- let step_index = this
- .edit_steps
- .binary_search_by(|step| {
- step.source_range
- .cmp(&edit_step_range, this.buffer.read(cx))
- })
- .map_err(|_| anyhow!("edit step not found"))?;
- if let Some(edit_step) = this.edit_steps.get_mut(step_index) {
- edit_step.operations = Some(EditStepOperations::Parsed {
- operations,
- raw_output,
- });
- cx.emit(ContextEvent::EditStepsChanged);
- }
- anyhow::Ok(())
- })?
- })
- }
+ let tool_use = cx
+ .update(|cx| {
+ LanguageModelCompletionProvider::read_global(cx)
+ .use_tool::<EditTool>(request, cx)
+ })?
+ .await?;
- fn parse_edit_operations(xml: &str) -> Vec<EditOperation> {
- let Some(start_ix) = xml.find("<operations>") else {
- return Vec::new();
- };
- let Some(end_ix) = xml[start_ix..].find("</operations>") else {
- return Vec::new();
- };
- let end_ix = end_ix + start_ix + "</operations>".len();
-
- let doc = roxmltree::Document::parse(&xml[start_ix..end_ix]).log_err();
- doc.map_or(Vec::new(), |doc| {
- doc.root_element()
- .children()
- .map(|node| {
- let tag_name = node.tag_name().name();
- let path = node
- .attribute("path")
- .with_context(|| {
- format!("invalid node {node:?}, missing attribute 'path'")
- })?
- .to_string();
- let kind = match tag_name {
- "update" => EditOperationKind::Update {
- symbol: node
- .attribute("symbol")
- .with_context(|| {
- format!("invalid node {node:?}, missing attribute 'symbol'")
- })?
- .to_string(),
- description: node
- .attribute("description")
- .with_context(|| {
- format!(
- "invalid node {node:?}, missing attribute 'description'"
- )
- })?
- .to_string(),
- },
- "create" => EditOperationKind::Create {
- description: node
- .attribute("description")
- .with_context(|| {
- format!(
- "invalid node {node:?}, missing attribute 'description'"
- )
- })?
- .to_string(),
- },
- "insert_sibling_after" => EditOperationKind::InsertSiblingAfter {
- symbol: node
- .attribute("symbol")
- .with_context(|| {
- format!("invalid node {node:?}, missing attribute 'symbol'")
- })?
- .to_string(),
- description: node
- .attribute("description")
- .with_context(|| {
- format!(
- "invalid node {node:?}, missing attribute 'description'"
- )
- })?
- .to_string(),
- },
- "insert_sibling_before" => EditOperationKind::InsertSiblingBefore {
- symbol: node
- .attribute("symbol")
- .with_context(|| {
- format!("invalid node {node:?}, missing attribute 'symbol'")
- })?
- .to_string(),
- description: node
- .attribute("description")
- .with_context(|| {
- format!(
- "invalid node {node:?}, missing attribute 'description'"
- )
- })?
- .to_string(),
- },
- "prepend_child" => EditOperationKind::PrependChild {
- symbol: node.attribute("symbol").map(String::from),
- description: node
- .attribute("description")
- .with_context(|| {
- format!(
- "invalid node {node:?}, missing attribute 'description'"
- )
- })?
- .to_string(),
- },
- "append_child" => EditOperationKind::AppendChild {
- symbol: node.attribute("symbol").map(String::from),
- description: node
- .attribute("description")
- .with_context(|| {
- format!(
- "invalid node {node:?}, missing attribute 'description'"
- )
- })?
- .to_string(),
- },
- "delete" => EditOperationKind::Delete {
- symbol: node
- .attribute("symbol")
- .with_context(|| {
- format!("invalid node {node:?}, missing attribute 'symbol'")
- })?
- .to_string(),
- },
- _ => return Err(anyhow!("invalid node {node:?}")),
- };
- anyhow::Ok(EditOperation { path, kind })
- })
- .filter_map(|op| op.log_err())
- .collect()
+ this.update(&mut cx, |this, cx| {
+ let step_index = this
+ .edit_steps
+ .binary_search_by(|step| {
+ step.source_range
+ .cmp(&edit_step_range, this.buffer.read(cx))
+ })
+ .map_err(|_| anyhow!("edit step not found"))?;
+ if let Some(edit_step) = this.edit_steps.get_mut(step_index) {
+ edit_step.operations = Some(EditStepOperations::Ready(tool_use.operations));
+ cx.emit(ContextEvent::EditStepsChanged);
+ }
+ anyhow::Ok(())
+ })?
+ }
+ .log_err()
})
}
@@ -3083,55 +2996,6 @@ mod tests {
}
}
- #[test]
- fn test_parse_edit_operations() {
- let operations = indoc! {r#"
- Here are the operations to make all fields of the Canvas struct private:
-
- <operations>
- <update path="font-kit/src/canvas.rs" symbol="pub struct Canvas pub pixels" description="Remove pub keyword from pixels field" />
- <update path="font-kit/src/canvas.rs" symbol="pub struct Canvas pub size" description="Remove pub keyword from size field" />
- <update path="font-kit/src/canvas.rs" symbol="pub struct Canvas pub stride" description="Remove pub keyword from stride field" />
- <update path="font-kit/src/canvas.rs" symbol="pub struct Canvas pub format" description="Remove pub keyword from format field" />
- </operations>
- "#};
-
- let parsed_operations = Context::parse_edit_operations(operations);
- assert_eq!(
- parsed_operations,
- vec![
- EditOperation {
- path: "font-kit/src/canvas.rs".to_string(),
- kind: EditOperationKind::Update {
- symbol: "pub struct Canvas pub pixels".to_string(),
- description: "Remove pub keyword from pixels field".to_string(),
- },
- },
- EditOperation {
- path: "font-kit/src/canvas.rs".to_string(),
- kind: EditOperationKind::Update {
- symbol: "pub struct Canvas pub size".to_string(),
- description: "Remove pub keyword from size field".to_string(),
- },
- },
- EditOperation {
- path: "font-kit/src/canvas.rs".to_string(),
- kind: EditOperationKind::Update {
- symbol: "pub struct Canvas pub stride".to_string(),
- description: "Remove pub keyword from stride field".to_string(),
- },
- },
- EditOperation {
- path: "font-kit/src/canvas.rs".to_string(),
- kind: EditOperationKind::Update {
- symbol: "pub struct Canvas pub format".to_string(),
- description: "Remove pub keyword from format field".to_string(),
- },
- },
- ]
- );
- }
-
#[gpui::test]
async fn test_serialization(cx: &mut TestAppContext) {
let settings_store = cx.update(SettingsStore::test);
@@ -17,7 +17,7 @@ use editor::{
use fs::Fs;
use futures::{
channel::mpsc,
- future::LocalBoxFuture,
+ future::{BoxFuture, LocalBoxFuture},
stream::{self, BoxStream},
SinkExt, Stream, StreamExt,
};
@@ -36,7 +36,7 @@ use similar::TextDiff;
use smol::future::FutureExt;
use std::{
cmp,
- future::Future,
+ future::{self, Future},
mem,
ops::{Range, RangeInclusive},
pin::Pin,
@@ -46,7 +46,7 @@ use std::{
};
use theme::ThemeSettings;
use ui::{prelude::*, IconButtonShape, Tooltip};
-use util::RangeExt;
+use util::{RangeExt, ResultExt};
use workspace::{notifications::NotificationId, Toast, Workspace};
pub fn init(fs: Arc<dyn Fs>, telemetry: Arc<Telemetry>, cx: &mut AppContext) {
@@ -187,7 +187,13 @@ impl InlineAssistant {
let [prompt_block_id, end_block_id] =
self.insert_assist_blocks(editor, &range, &prompt_editor, cx);
- assists.push((assist_id, prompt_editor, prompt_block_id, end_block_id));
+ assists.push((
+ assist_id,
+ range,
+ prompt_editor,
+ prompt_block_id,
+ end_block_id,
+ ));
}
let editor_assists = self
@@ -195,7 +201,7 @@ impl InlineAssistant {
.entry(editor.downgrade())
.or_insert_with(|| EditorInlineAssists::new(&editor, cx));
let mut assist_group = InlineAssistGroup::new();
- for (assist_id, prompt_editor, prompt_block_id, end_block_id) in assists {
+ for (assist_id, range, prompt_editor, prompt_block_id, end_block_id) in assists {
self.assists.insert(
assist_id,
InlineAssist::new(
@@ -206,6 +212,7 @@ impl InlineAssistant {
&prompt_editor,
prompt_block_id,
end_block_id,
+ range,
prompt_editor.read(cx).codegen.clone(),
workspace.clone(),
cx,
@@ -227,7 +234,7 @@ impl InlineAssistant {
editor: &View<Editor>,
mut range: Range<Anchor>,
initial_prompt: String,
- initial_insertion: Option<String>,
+ initial_insertion: Option<InitialInsertion>,
workspace: Option<WeakView<Workspace>>,
assistant_panel: Option<&View<AssistantPanel>>,
cx: &mut WindowContext,
@@ -239,22 +246,30 @@ impl InlineAssistant {
let assist_id = self.next_assist_id.post_inc();
let buffer = editor.read(cx).buffer().clone();
- let prepend_transaction_id = initial_insertion.and_then(|initial_insertion| {
- buffer.update(cx, |buffer, cx| {
- buffer.start_transaction(cx);
- buffer.edit([(range.start..range.start, initial_insertion)], None, cx);
- buffer.end_transaction(cx)
- })
- });
+ {
+ let snapshot = buffer.read(cx).read(cx);
+
+ let mut point_range = range.to_point(&snapshot);
+ if point_range.is_empty() {
+ point_range.start.column = 0;
+ point_range.end.column = 0;
+ } else {
+ point_range.start.column = 0;
+ if point_range.end.row > point_range.start.row && point_range.end.column == 0 {
+ point_range.end.row -= 1;
+ }
+ point_range.end.column = snapshot.line_len(MultiBufferRow(point_range.end.row));
+ }
- range.start = range.start.bias_left(&buffer.read(cx).read(cx));
- range.end = range.end.bias_right(&buffer.read(cx).read(cx));
+ range.start = snapshot.anchor_before(point_range.start);
+ range.end = snapshot.anchor_after(point_range.end);
+ }
let codegen = cx.new_model(|cx| {
Codegen::new(
editor.read(cx).buffer().clone(),
range.clone(),
- prepend_transaction_id,
+ initial_insertion,
self.telemetry.clone(),
cx,
)
@@ -295,6 +310,7 @@ impl InlineAssistant {
&prompt_editor,
prompt_block_id,
end_block_id,
+ range,
prompt_editor.read(cx).codegen.clone(),
workspace.clone(),
cx,
@@ -445,7 +461,7 @@ impl InlineAssistant {
let buffer = editor.buffer().read(cx).snapshot(cx);
for assist_id in &editor_assists.assist_ids {
let assist = &self.assists[assist_id];
- let assist_range = assist.codegen.read(cx).range.to_offset(&buffer);
+ let assist_range = assist.range.to_offset(&buffer);
if assist_range.contains(&selection.start) && assist_range.contains(&selection.end)
{
if matches!(assist.codegen.read(cx).status, CodegenStatus::Pending) {
@@ -473,7 +489,7 @@ impl InlineAssistant {
let buffer = editor.buffer().read(cx).snapshot(cx);
for assist_id in &editor_assists.assist_ids {
let assist = &self.assists[assist_id];
- let assist_range = assist.codegen.read(cx).range.to_offset(&buffer);
+ let assist_range = assist.range.to_offset(&buffer);
if assist.decorations.is_some()
&& assist_range.contains(&selection.start)
&& assist_range.contains(&selection.end)
@@ -551,7 +567,7 @@ impl InlineAssistant {
assist.codegen.read(cx).status,
CodegenStatus::Error(_) | CodegenStatus::Done
) {
- let assist_range = assist.codegen.read(cx).range.to_offset(&snapshot);
+ let assist_range = assist.range.to_offset(&snapshot);
if edited_ranges
.iter()
.any(|range| range.overlaps(&assist_range))
@@ -721,7 +737,7 @@ impl InlineAssistant {
});
}
- let position = assist.codegen.read(cx).range.start;
+ let position = assist.range.start;
editor.update(cx, |editor, cx| {
editor.change_selections(None, cx, |selections| {
selections.select_anchor_ranges([position..position])
@@ -740,8 +756,7 @@ impl InlineAssistant {
.0 as f32;
} else {
let snapshot = editor.snapshot(cx);
- let codegen = assist.codegen.read(cx);
- let start_row = codegen
+ let start_row = assist
.range
.start
.to_display_point(&snapshot.display_snapshot)
@@ -829,11 +844,7 @@ impl InlineAssistant {
return;
}
- let Some(user_prompt) = assist
- .decorations
- .as_ref()
- .map(|decorations| decorations.prompt_editor.read(cx).prompt(cx))
- else {
+ let Some(user_prompt) = assist.user_prompt(cx) else {
return;
};
@@ -843,139 +854,19 @@ impl InlineAssistant {
self.prompt_history.pop_front();
}
- let codegen = assist.codegen.clone();
- let telemetry_id = LanguageModelCompletionProvider::read_global(cx)
- .active_model()
- .map(|m| m.telemetry_id())
- .unwrap_or_default();
- let chunks: LocalBoxFuture<Result<BoxStream<Result<String>>>> =
- if user_prompt.trim().to_lowercase() == "delete" {
- async { Ok(stream::empty().boxed()) }.boxed_local()
- } else {
- let request = self.request_for_inline_assist(assist_id, cx);
- let mut cx = cx.to_async();
- async move {
- let request = request.await?;
- let chunks = cx
- .update(|cx| {
- LanguageModelCompletionProvider::read_global(cx)
- .stream_completion(request, cx)
- })?
- .await?;
- Ok(chunks.boxed())
- }
- .boxed_local()
- };
- codegen.update(cx, |codegen, cx| {
- codegen.start(telemetry_id, chunks, cx);
- });
- }
-
- fn request_for_inline_assist(
- &self,
- assist_id: InlineAssistId,
- cx: &mut WindowContext,
- ) -> Task<Result<LanguageModelRequest>> {
- cx.spawn(|mut cx| async move {
- let (user_prompt, context_request, project_name, buffer, range) =
- cx.read_global(|this: &InlineAssistant, cx: &WindowContext| {
- let assist = this.assists.get(&assist_id).context("invalid assist")?;
- let decorations = assist.decorations.as_ref().context("invalid assist")?;
- let editor = assist.editor.upgrade().context("invalid assist")?;
- let user_prompt = decorations.prompt_editor.read(cx).prompt(cx);
- let context_request = if assist.include_context {
- assist.workspace.as_ref().and_then(|workspace| {
- let workspace = workspace.upgrade()?.read(cx);
- let assistant_panel = workspace.panel::<AssistantPanel>(cx)?;
- Some(
- assistant_panel
- .read(cx)
- .active_context(cx)?
- .read(cx)
- .to_completion_request(cx),
- )
- })
- } else {
- None
- };
- let project_name = assist.workspace.as_ref().and_then(|workspace| {
- let workspace = workspace.upgrade()?;
- Some(
- workspace
- .read(cx)
- .project()
- .read(cx)
- .worktree_root_names(cx)
- .collect::<Vec<&str>>()
- .join("/"),
- )
- });
- let buffer = editor.read(cx).buffer().read(cx).snapshot(cx);
- let range = assist.codegen.read(cx).range.clone();
- anyhow::Ok((user_prompt, context_request, project_name, buffer, range))
- })??;
-
- let language = buffer.language_at(range.start);
- let language_name = if let Some(language) = language.as_ref() {
- if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
- None
- } else {
- Some(language.name())
- }
- } else {
- None
- };
+ let assistant_panel_context = assist.assistant_panel_context(cx);
- // Higher Temperature increases the randomness of model outputs.
- // If Markdown or No Language is Known, increase the randomness for more creative output
- // If Code, decrease temperature to get more deterministic outputs
- let temperature = if let Some(language) = language_name.clone() {
- if language.as_ref() == "Markdown" {
- 1.0
- } else {
- 0.5
- }
- } else {
- 1.0
- };
-
- let prompt = cx
- .background_executor()
- .spawn(async move {
- let language_name = language_name.as_deref();
- let start = buffer.point_to_buffer_offset(range.start);
- let end = buffer.point_to_buffer_offset(range.end);
- let (buffer, range) = if let Some((start, end)) = start.zip(end) {
- let (start_buffer, start_buffer_offset) = start;
- let (end_buffer, end_buffer_offset) = end;
- if start_buffer.remote_id() == end_buffer.remote_id() {
- (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
- } else {
- return Err(anyhow!("invalid transformation range"));
- }
- } else {
- return Err(anyhow!("invalid transformation range"));
- };
- generate_content_prompt(user_prompt, language_name, buffer, range, project_name)
- })
- .await?;
-
- let mut messages = Vec::new();
- if let Some(context_request) = context_request {
- messages = context_request.messages;
- }
-
- messages.push(LanguageModelRequestMessage {
- role: Role::User,
- content: prompt,
- });
-
- Ok(LanguageModelRequest {
- messages,
- stop: vec!["|END|>".to_string()],
- temperature,
+ assist
+ .codegen
+ .update(cx, |codegen, cx| {
+ codegen.start(
+ assist.range.clone(),
+ user_prompt,
+ assistant_panel_context,
+ cx,
+ )
})
- })
+ .log_err();
}
pub fn stop_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
@@ -1006,12 +897,11 @@ impl InlineAssistant {
let codegen = assist.codegen.read(cx);
foreground_ranges.extend(codegen.last_equal_ranges().iter().cloned());
- if codegen.edit_position != codegen.range.end {
- gutter_pending_ranges.push(codegen.edit_position..codegen.range.end);
- }
+ gutter_pending_ranges
+ .push(codegen.edit_position.unwrap_or(assist.range.start)..assist.range.end);
- if codegen.range.start != codegen.edit_position {
- gutter_transformed_ranges.push(codegen.range.start..codegen.edit_position);
+ if let Some(edit_position) = codegen.edit_position {
+ gutter_transformed_ranges.push(assist.range.start..edit_position);
}
if assist.decorations.is_some() {
@@ -1268,6 +1158,12 @@ fn build_assist_editor_renderer(editor: &View<PromptEditor>) -> RenderBlock {
})
}
+#[derive(Copy, Clone, Debug, Eq, PartialEq)]
+pub enum InitialInsertion {
+ NewlineBefore,
+ NewlineAfter,
+}
+
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
pub struct InlineAssistId(usize);
@@ -1629,24 +1525,20 @@ impl PromptEditor {
let assist_id = self.id;
self.pending_token_count = cx.spawn(|this, mut cx| async move {
cx.background_executor().timer(Duration::from_secs(1)).await;
- let request = cx
+ let token_count = cx
.update_global(|inline_assistant: &mut InlineAssistant, cx| {
- inline_assistant.request_for_inline_assist(assist_id, cx)
- })?
+ let assist = inline_assistant
+ .assists
+ .get(&assist_id)
+ .context("assist not found")?;
+ anyhow::Ok(assist.count_tokens(cx))
+ })??
.await?;
- if let Some(token_count) = cx.update(|cx| {
- LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
- })? {
- let token_count = token_count.await?;
-
- this.update(&mut cx, |this, cx| {
- this.token_count = Some(token_count);
- cx.notify();
- })
- } else {
- Ok(())
- }
+ this.update(&mut cx, |this, cx| {
+ this.token_count = Some(token_count);
+ cx.notify();
+ })
})
}
@@ -1855,6 +1747,7 @@ impl PromptEditor {
struct InlineAssist {
group_id: InlineAssistGroupId,
+ range: Range<Anchor>,
editor: WeakView<Editor>,
decorations: Option<InlineAssistDecorations>,
codegen: Model<Codegen>,
@@ -1873,6 +1766,7 @@ impl InlineAssist {
prompt_editor: &View<PromptEditor>,
prompt_block_id: CustomBlockId,
end_block_id: CustomBlockId,
+ range: Range<Anchor>,
codegen: Model<Codegen>,
workspace: Option<WeakView<Workspace>>,
cx: &mut WindowContext,
@@ -1888,6 +1782,7 @@ impl InlineAssist {
removed_line_block_ids: HashSet::default(),
end_block_id,
}),
+ range,
codegen: codegen.clone(),
workspace: workspace.clone(),
_subscriptions: vec![
@@ -1963,6 +1858,41 @@ impl InlineAssist {
],
}
}
+
+ fn user_prompt(&self, cx: &AppContext) -> Option<String> {
+ let decorations = self.decorations.as_ref()?;
+ Some(decorations.prompt_editor.read(cx).prompt(cx))
+ }
+
+ fn assistant_panel_context(&self, cx: &WindowContext) -> Option<LanguageModelRequest> {
+ if self.include_context {
+ let workspace = self.workspace.as_ref()?;
+ let workspace = workspace.upgrade()?.read(cx);
+ let assistant_panel = workspace.panel::<AssistantPanel>(cx)?;
+ Some(
+ assistant_panel
+ .read(cx)
+ .active_context(cx)?
+ .read(cx)
+ .to_completion_request(cx),
+ )
+ } else {
+ None
+ }
+ }
+
+ pub fn count_tokens(&self, cx: &WindowContext) -> BoxFuture<'static, Result<usize>> {
+ let Some(user_prompt) = self.user_prompt(cx) else {
+ return future::ready(Err(anyhow!("no user prompt"))).boxed();
+ };
+ let assistant_panel_context = self.assistant_panel_context(cx);
+ self.codegen.read(cx).count_tokens(
+ self.range.clone(),
+ user_prompt,
+ assistant_panel_context,
+ cx,
+ )
+ }
}
struct InlineAssistDecorations {
@@ -1982,16 +1912,15 @@ pub struct Codegen {
buffer: Model<MultiBuffer>,
old_buffer: Model<Buffer>,
snapshot: MultiBufferSnapshot,
- range: Range<Anchor>,
- edit_position: Anchor,
+ edit_position: Option<Anchor>,
last_equal_ranges: Vec<Range<Anchor>>,
- prepend_transaction_id: Option<TransactionId>,
- generation_transaction_id: Option<TransactionId>,
+ transaction_id: Option<TransactionId>,
status: CodegenStatus,
generation: Task<()>,
diff: Diff,
telemetry: Option<Arc<Telemetry>>,
_subscription: gpui::Subscription,
+ initial_insertion: Option<InitialInsertion>,
}
enum CodegenStatus {
@@ -2015,7 +1944,7 @@ impl Codegen {
pub fn new(
buffer: Model<MultiBuffer>,
range: Range<Anchor>,
- prepend_transaction_id: Option<TransactionId>,
+ initial_insertion: Option<InitialInsertion>,
telemetry: Option<Arc<Telemetry>>,
cx: &mut ModelContext<Self>,
) -> Self {
@@ -2044,17 +1973,16 @@ impl Codegen {
Self {
buffer: buffer.clone(),
old_buffer,
- edit_position: range.start,
- range,
+ edit_position: None,
snapshot,
last_equal_ranges: Default::default(),
- prepend_transaction_id,
- generation_transaction_id: None,
+ transaction_id: None,
status: CodegenStatus::Idle,
generation: Task::ready(()),
diff: Diff::default(),
telemetry,
_subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
+ initial_insertion,
}
}
@@ -2065,13 +1993,8 @@ impl Codegen {
cx: &mut ModelContext<Self>,
) {
if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
- if self.generation_transaction_id == Some(*transaction_id) {
- self.generation_transaction_id = None;
- self.generation = Task::ready(());
- cx.emit(CodegenEvent::Undone);
- } else if self.prepend_transaction_id == Some(*transaction_id) {
- self.prepend_transaction_id = None;
- self.generation_transaction_id = None;
+ if self.transaction_id == Some(*transaction_id) {
+ self.transaction_id = None;
self.generation = Task::ready(());
cx.emit(CodegenEvent::Undone);
}
@@ -2082,19 +2005,152 @@ impl Codegen {
&self.last_equal_ranges
}
+ pub fn count_tokens(
+ &self,
+ edit_range: Range<Anchor>,
+ user_prompt: String,
+ assistant_panel_context: Option<LanguageModelRequest>,
+ cx: &AppContext,
+ ) -> BoxFuture<'static, Result<usize>> {
+ let request = self.build_request(user_prompt, assistant_panel_context, edit_range, cx);
+ LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
+ }
+
pub fn start(
&mut self,
- telemetry_id: String,
+ mut edit_range: Range<Anchor>,
+ user_prompt: String,
+ assistant_panel_context: Option<LanguageModelRequest>,
+ cx: &mut ModelContext<Self>,
+ ) -> Result<()> {
+ self.undo(cx);
+
+ // Handle initial insertion
+ self.transaction_id = if let Some(initial_insertion) = self.initial_insertion {
+ self.buffer.update(cx, |buffer, cx| {
+ buffer.start_transaction(cx);
+ let offset = edit_range.start.to_offset(&self.snapshot);
+ let edit_position;
+ match initial_insertion {
+ InitialInsertion::NewlineBefore => {
+ buffer.edit([(offset..offset, "\n\n")], None, cx);
+ self.snapshot = buffer.snapshot(cx);
+ edit_position = self.snapshot.anchor_after(offset + 1);
+ }
+ InitialInsertion::NewlineAfter => {
+ buffer.edit([(offset..offset, "\n")], None, cx);
+ self.snapshot = buffer.snapshot(cx);
+ edit_position = self.snapshot.anchor_after(offset);
+ }
+ }
+ self.edit_position = Some(edit_position);
+ edit_range = edit_position.bias_left(&self.snapshot)..edit_position;
+ buffer.end_transaction(cx)
+ })
+ } else {
+ self.edit_position = Some(edit_range.start.bias_right(&self.snapshot));
+ None
+ };
+
+ let model_telemetry_id = LanguageModelCompletionProvider::read_global(cx)
+ .active_model_telemetry_id()
+ .context("no active model")?;
+
+ let chunks: LocalBoxFuture<Result<BoxStream<Result<String>>>> = if user_prompt
+ .trim()
+ .to_lowercase()
+ == "delete"
+ {
+ async { Ok(stream::empty().boxed()) }.boxed_local()
+ } else {
+ let request =
+ self.build_request(user_prompt, assistant_panel_context, edit_range.clone(), cx);
+ let chunks =
+ LanguageModelCompletionProvider::read_global(cx).stream_completion(request, cx);
+ async move { Ok(chunks.await?.boxed()) }.boxed_local()
+ };
+ self.handle_stream(model_telemetry_id, edit_range, chunks, cx);
+ Ok(())
+ }
+
+ fn build_request(
+ &self,
+ user_prompt: String,
+ assistant_panel_context: Option<LanguageModelRequest>,
+ edit_range: Range<Anchor>,
+ cx: &AppContext,
+ ) -> LanguageModelRequest {
+ let buffer = self.buffer.read(cx).snapshot(cx);
+ let language = buffer.language_at(edit_range.start);
+ let language_name = if let Some(language) = language.as_ref() {
+ if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
+ None
+ } else {
+ Some(language.name())
+ }
+ } else {
+ None
+ };
+
+ // Higher Temperature increases the randomness of model outputs.
+ // If Markdown or No Language is Known, increase the randomness for more creative output
+ // If Code, decrease temperature to get more deterministic outputs
+ let temperature = if let Some(language) = language_name.clone() {
+ if language.as_ref() == "Markdown" {
+ 1.0
+ } else {
+ 0.5
+ }
+ } else {
+ 1.0
+ };
+
+ let language_name = language_name.as_deref();
+ let start = buffer.point_to_buffer_offset(edit_range.start);
+ let end = buffer.point_to_buffer_offset(edit_range.end);
+ let (buffer, range) = if let Some((start, end)) = start.zip(end) {
+ let (start_buffer, start_buffer_offset) = start;
+ let (end_buffer, end_buffer_offset) = end;
+ if start_buffer.remote_id() == end_buffer.remote_id() {
+ (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
+ } else {
+ panic!("invalid transformation range");
+ }
+ } else {
+ panic!("invalid transformation range");
+ };
+ let prompt = generate_content_prompt(user_prompt, language_name, buffer, range);
+
+ let mut messages = Vec::new();
+ if let Some(context_request) = assistant_panel_context {
+ messages = context_request.messages;
+ }
+
+ messages.push(LanguageModelRequestMessage {
+ role: Role::User,
+ content: prompt,
+ });
+
+ LanguageModelRequest {
+ messages,
+ stop: vec!["|END|>".to_string()],
+ temperature,
+ }
+ }
+
+ pub fn handle_stream(
+ &mut self,
+ model_telemetry_id: String,
+ edit_range: Range<Anchor>,
stream: impl 'static + Future<Output = Result<BoxStream<'static, Result<String>>>>,
cx: &mut ModelContext<Self>,
) {
- let range = self.range.clone();
let snapshot = self.snapshot.clone();
let selected_text = snapshot
- .text_for_range(range.start..range.end)
+ .text_for_range(edit_range.start..edit_range.end)
.collect::<Rope>();
- let selection_start = range.start.to_point(&snapshot);
+ let selection_start = edit_range.start.to_point(&snapshot);
// Start with the indentation of the first line in the selection
let mut suggested_line_indent = snapshot
@@ -2105,7 +2161,7 @@ impl Codegen {
// If the first line in the selection does not have indentation, check the following lines
if suggested_line_indent.len == 0 && suggested_line_indent.kind == IndentKind::Space {
- for row in selection_start.row..=range.end.to_point(&snapshot).row {
+ for row in selection_start.row..=edit_range.end.to_point(&snapshot).row {
let line_indent = snapshot.indent_size_for_line(MultiBufferRow(row));
// Prefer tabs if a line in the selection uses tabs as indentation
if line_indent.kind == IndentKind::Tab {
@@ -2116,19 +2172,13 @@ impl Codegen {
}
let telemetry = self.telemetry.clone();
- self.edit_position = range.start;
self.diff = Diff::default();
self.status = CodegenStatus::Pending;
- if let Some(transaction_id) = self.generation_transaction_id.take() {
- self.buffer
- .update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx));
- }
+ let mut edit_start = edit_range.start.to_offset(&snapshot);
self.generation = cx.spawn(|this, mut cx| {
async move {
let chunks = stream.await;
let generate = async {
- let mut edit_start = range.start.to_offset(&snapshot);
-
let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
let diff: Task<anyhow::Result<()>> =
cx.background_executor().spawn(async move {
@@ -2218,7 +2268,7 @@ impl Codegen {
telemetry.report_assistant_event(
None,
telemetry_events::AssistantKind::Inline,
- telemetry_id,
+ model_telemetry_id,
response_latency,
error_message,
);
@@ -2262,13 +2312,13 @@ impl Codegen {
None,
cx,
);
- this.edit_position = snapshot.anchor_after(edit_start);
+ this.edit_position = Some(snapshot.anchor_after(edit_start));
buffer.end_transaction(cx)
});
if let Some(transaction) = transaction {
- if let Some(first_transaction) = this.generation_transaction_id {
+ if let Some(first_transaction) = this.transaction_id {
// Group all assistant edits into the first transaction.
this.buffer.update(cx, |buffer, cx| {
buffer.merge_transactions(
@@ -2278,14 +2328,14 @@ impl Codegen {
)
});
} else {
- this.generation_transaction_id = Some(transaction);
+ this.transaction_id = Some(transaction);
this.buffer.update(cx, |buffer, cx| {
buffer.finalize_last_transaction(cx)
});
}
}
- this.update_diff(cx);
+ this.update_diff(edit_range.clone(), cx);
cx.notify();
})?;
}
@@ -2321,27 +2371,22 @@ impl Codegen {
}
pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
- if let Some(transaction_id) = self.prepend_transaction_id.take() {
- self.buffer
- .update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx));
- }
-
- if let Some(transaction_id) = self.generation_transaction_id.take() {
+ if let Some(transaction_id) = self.transaction_id.take() {
self.buffer
.update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx));
}
}
- fn update_diff(&mut self, cx: &mut ModelContext<Self>) {
+ fn update_diff(&mut self, edit_range: Range<Anchor>, cx: &mut ModelContext<Self>) {
if self.diff.task.is_some() {
self.diff.should_update = true;
} else {
self.diff.should_update = false;
let old_snapshot = self.snapshot.clone();
- let old_range = self.range.to_point(&old_snapshot);
+ let old_range = edit_range.to_point(&old_snapshot);
let new_snapshot = self.buffer.read(cx).snapshot(cx);
- let new_range = self.range.to_point(&new_snapshot);
+ let new_range = edit_range.to_point(&new_snapshot);
self.diff.task = Some(cx.spawn(|this, mut cx| async move {
let (deleted_row_ranges, inserted_row_ranges) = cx
@@ -2422,7 +2467,7 @@ impl Codegen {
this.diff.inserted_row_ranges = inserted_row_ranges;
this.diff.task = None;
if this.diff.should_update {
- this.update_diff(cx);
+ this.update_diff(edit_range, cx);
}
cx.notify();
})
@@ -2629,12 +2674,14 @@ mod tests {
let snapshot = buffer.snapshot(cx);
snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
});
- let codegen = cx.new_model(|cx| Codegen::new(buffer.clone(), range, None, None, cx));
+ let codegen =
+ cx.new_model(|cx| Codegen::new(buffer.clone(), range.clone(), None, None, cx));
let (chunks_tx, chunks_rx) = mpsc::unbounded();
codegen.update(cx, |codegen, cx| {
- codegen.start(
+ codegen.handle_stream(
String::new(),
+ range,
future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
cx,
)
@@ -2690,12 +2737,14 @@ mod tests {
let snapshot = buffer.snapshot(cx);
snapshot.anchor_before(Point::new(1, 6))..snapshot.anchor_after(Point::new(1, 6))
});
- let codegen = cx.new_model(|cx| Codegen::new(buffer.clone(), range, None, None, cx));
+ let codegen =
+ cx.new_model(|cx| Codegen::new(buffer.clone(), range.clone(), None, None, cx));
let (chunks_tx, chunks_rx) = mpsc::unbounded();
codegen.update(cx, |codegen, cx| {
- codegen.start(
+ codegen.handle_stream(
String::new(),
+ range.clone(),
future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
cx,
)
@@ -2755,12 +2804,14 @@ mod tests {
let snapshot = buffer.snapshot(cx);
snapshot.anchor_before(Point::new(1, 2))..snapshot.anchor_after(Point::new(1, 2))
});
- let codegen = cx.new_model(|cx| Codegen::new(buffer.clone(), range, None, None, cx));
+ let codegen =
+ cx.new_model(|cx| Codegen::new(buffer.clone(), range.clone(), None, None, cx));
let (chunks_tx, chunks_rx) = mpsc::unbounded();
codegen.update(cx, |codegen, cx| {
- codegen.start(
+ codegen.handle_stream(
String::new(),
+ range.clone(),
future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
cx,
)
@@ -2819,12 +2870,14 @@ mod tests {
let snapshot = buffer.snapshot(cx);
snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(4, 2))
});
- let codegen = cx.new_model(|cx| Codegen::new(buffer.clone(), range, None, None, cx));
+ let codegen =
+ cx.new_model(|cx| Codegen::new(buffer.clone(), range.clone(), None, None, cx));
let (chunks_tx, chunks_rx) = mpsc::unbounded();
codegen.update(cx, |codegen, cx| {
- codegen.start(
+ codegen.handle_stream(
String::new(),
+ range.clone(),
future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
cx,
)
@@ -734,29 +734,27 @@ impl PromptLibrary {
const DEBOUNCE_TIMEOUT: Duration = Duration::from_secs(1);
cx.background_executor().timer(DEBOUNCE_TIMEOUT).await;
- if let Some(token_count) = cx.update(|cx| {
- LanguageModelCompletionProvider::read_global(cx).count_tokens(
- LanguageModelRequest {
- messages: vec![LanguageModelRequestMessage {
- role: Role::System,
- content: body.to_string(),
- }],
- stop: Vec::new(),
- temperature: 1.,
- },
- cx,
- )
- })? {
- let token_count = token_count.await?;
+ let token_count = cx
+ .update(|cx| {
+ LanguageModelCompletionProvider::read_global(cx).count_tokens(
+ LanguageModelRequest {
+ messages: vec![LanguageModelRequestMessage {
+ role: Role::System,
+ content: body.to_string(),
+ }],
+ stop: Vec::new(),
+ temperature: 1.,
+ },
+ cx,
+ )
+ })?
+ .await?;
- this.update(&mut cx, |this, cx| {
- let prompt_editor = this.prompt_editors.get_mut(&prompt_id).unwrap();
- prompt_editor.token_count = Some(token_count);
- cx.notify();
- })
- } else {
- Ok(())
- }
+ this.update(&mut cx, |this, cx| {
+ let prompt_editor = this.prompt_editors.get_mut(&prompt_id).unwrap();
+ prompt_editor.token_count = Some(token_count);
+ cx.notify();
+ })
}
.log_err()
});
@@ -6,8 +6,7 @@ pub fn generate_content_prompt(
language_name: Option<&str>,
buffer: BufferSnapshot,
range: Range<usize>,
- _project_name: Option<String>,
-) -> anyhow::Result<String> {
+) -> String {
let mut prompt = String::new();
let content_type = match language_name {
@@ -15,14 +14,16 @@ pub fn generate_content_prompt(
writeln!(
prompt,
"Here's a file of text that I'm going to ask you to make an edit to."
- )?;
+ )
+ .unwrap();
"text"
}
Some(language_name) => {
writeln!(
prompt,
"Here's a file of {language_name} that I'm going to ask you to make an edit to."
- )?;
+ )
+ .unwrap();
"code"
}
};
@@ -70,7 +71,7 @@ pub fn generate_content_prompt(
write!(prompt, "</document>\n\n").unwrap();
if is_truncated {
- writeln!(prompt, "The context around the relevant section has been truncated (possibly in the middle of a line) for brevity.\n")?;
+ writeln!(prompt, "The context around the relevant section has been truncated (possibly in the middle of a line) for brevity.\n").unwrap();
}
if range.is_empty() {
@@ -107,7 +108,7 @@ pub fn generate_content_prompt(
prompt.push_str("\n\nImmediately start with the following format with no remarks:\n\n```\n{{REWRITTEN_CODE}}\n```");
}
- Ok(prompt)
+ prompt
}
pub fn generate_terminal_assistant_prompt(
@@ -707,18 +707,15 @@ impl PromptEditor {
inline_assistant.request_for_inline_assist(assist_id, cx)
})??;
- if let Some(token_count) = cx.update(|cx| {
- LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
- })? {
- let token_count = token_count.await?;
-
- this.update(&mut cx, |this, cx| {
- this.token_count = Some(token_count);
- cx.notify();
- })
- } else {
- Ok(())
- }
+ let token_count = cx
+ .update(|cx| {
+ LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
+ })?
+ .await?;
+ this.update(&mut cx, |this, cx| {
+ this.token_count = Some(token_count);
+ cx.notify();
+ })
})
}
@@ -10,7 +10,7 @@ use crate::{
ServerId, UpdatedChannelMessage, User, UserId,
},
executor::Executor,
- AppState, Error, RateLimit, RateLimiter, Result,
+ AppState, Config, Error, RateLimit, RateLimiter, Result,
};
use anyhow::{anyhow, bail, Context as _};
use async_tungstenite::tungstenite::{
@@ -605,17 +605,39 @@ impl Server {
))
.add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
.add_message_handler(update_context)
+ .add_request_handler({
+ let app_state = app_state.clone();
+ move |request, response, session| {
+ let app_state = app_state.clone();
+ async move {
+ complete_with_language_model(request, response, session, &app_state.config)
+ .await
+ }
+ }
+ })
.add_streaming_request_handler({
let app_state = app_state.clone();
move |request, response, session| {
- complete_with_language_model(
- request,
- response,
- session,
- app_state.config.openai_api_key.clone(),
- app_state.config.google_ai_api_key.clone(),
- app_state.config.anthropic_api_key.clone(),
- )
+ let app_state = app_state.clone();
+ async move {
+ stream_complete_with_language_model(
+ request,
+ response,
+ session,
+ &app_state.config,
+ )
+ .await
+ }
+ }
+ })
+ .add_request_handler({
+ let app_state = app_state.clone();
+ move |request, response, session| {
+ let app_state = app_state.clone();
+ async move {
+ count_language_model_tokens(request, response, session, &app_state.config)
+ .await
+ }
}
})
.add_request_handler({
@@ -4503,103 +4525,119 @@ impl RateLimit for CompleteWithLanguageModelRateLimit {
}
async fn complete_with_language_model(
- query: proto::QueryLanguageModel,
- response: StreamingResponse<proto::QueryLanguageModel>,
+ request: proto::CompleteWithLanguageModel,
+ response: Response<proto::CompleteWithLanguageModel>,
session: Session,
- open_ai_api_key: Option<Arc<str>>,
- google_ai_api_key: Option<Arc<str>>,
- anthropic_api_key: Option<Arc<str>>,
+ config: &Config,
) -> Result<()> {
let Some(session) = session.for_user() else {
return Err(anyhow!("user not found"))?;
};
authorize_access_to_language_models(&session).await?;
- match proto::LanguageModelRequestKind::from_i32(query.kind) {
- Some(proto::LanguageModelRequestKind::Complete) => {
- session
- .rate_limiter
- .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
- .await?;
- }
- Some(proto::LanguageModelRequestKind::CountTokens) => {
- session
- .rate_limiter
- .check::<CountTokensWithLanguageModelRateLimit>(session.user_id())
- .await?;
+ session
+ .rate_limiter
+ .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
+ .await?;
+
+ let result = match proto::LanguageModelProvider::from_i32(request.provider) {
+ Some(proto::LanguageModelProvider::Anthropic) => {
+ let api_key = config
+ .anthropic_api_key
+ .as_ref()
+ .context("no Anthropic AI API key configured on the server")?;
+ anthropic::complete(
+ session.http_client.as_ref(),
+ anthropic::ANTHROPIC_API_URL,
+ api_key,
+ serde_json::from_str(&request.request)?,
+ )
+ .await?
}
- None => Err(anyhow!("unknown request kind"))?,
- }
+ _ => return Err(anyhow!("unsupported provider"))?,
+ };
+
+ response.send(proto::CompleteWithLanguageModelResponse {
+ completion: serde_json::to_string(&result)?,
+ })?;
+
+ Ok(())
+}
- match proto::LanguageModelProvider::from_i32(query.provider) {
+async fn stream_complete_with_language_model(
+ request: proto::StreamCompleteWithLanguageModel,
+ response: StreamingResponse<proto::StreamCompleteWithLanguageModel>,
+ session: Session,
+ config: &Config,
+) -> Result<()> {
+ let Some(session) = session.for_user() else {
+ return Err(anyhow!("user not found"))?;
+ };
+ authorize_access_to_language_models(&session).await?;
+
+ session
+ .rate_limiter
+ .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
+ .await?;
+
+ match proto::LanguageModelProvider::from_i32(request.provider) {
Some(proto::LanguageModelProvider::Anthropic) => {
- let api_key =
- anthropic_api_key.context("no Anthropic AI API key configured on the server")?;
+ let api_key = config
+ .anthropic_api_key
+ .as_ref()
+ .context("no Anthropic AI API key configured on the server")?;
let mut chunks = anthropic::stream_completion(
session.http_client.as_ref(),
anthropic::ANTHROPIC_API_URL,
- &api_key,
- serde_json::from_str(&query.request)?,
+ api_key,
+ serde_json::from_str(&request.request)?,
None,
)
.await?;
- while let Some(chunk) = chunks.next().await {
- let chunk = chunk?;
- response.send(proto::QueryLanguageModelResponse {
- response: serde_json::to_string(&chunk)?,
+ while let Some(event) = chunks.next().await {
+ let chunk = event?;
+ response.send(proto::StreamCompleteWithLanguageModelResponse {
+ event: serde_json::to_string(&chunk)?,
})?;
}
}
Some(proto::LanguageModelProvider::OpenAi) => {
- let api_key = open_ai_api_key.context("no OpenAI API key configured on the server")?;
- let mut chunks = open_ai::stream_completion(
+ let api_key = config
+ .openai_api_key
+ .as_ref()
+ .context("no OpenAI API key configured on the server")?;
+ let mut events = open_ai::stream_completion(
session.http_client.as_ref(),
open_ai::OPEN_AI_API_URL,
- &api_key,
- serde_json::from_str(&query.request)?,
+ api_key,
+ serde_json::from_str(&request.request)?,
None,
)
.await?;
- while let Some(chunk) = chunks.next().await {
- let chunk = chunk?;
- response.send(proto::QueryLanguageModelResponse {
- response: serde_json::to_string(&chunk)?,
+ while let Some(event) = events.next().await {
+ let event = event?;
+ response.send(proto::StreamCompleteWithLanguageModelResponse {
+ event: serde_json::to_string(&event)?,
})?;
}
}
Some(proto::LanguageModelProvider::Google) => {
- let api_key =
- google_ai_api_key.context("no Google AI API key configured on the server")?;
-
- match proto::LanguageModelRequestKind::from_i32(query.kind) {
- Some(proto::LanguageModelRequestKind::Complete) => {
- let mut chunks = google_ai::stream_generate_content(
- session.http_client.as_ref(),
- google_ai::API_URL,
- &api_key,
- serde_json::from_str(&query.request)?,
- )
- .await?;
- while let Some(chunk) = chunks.next().await {
- let chunk = chunk?;
- response.send(proto::QueryLanguageModelResponse {
- response: serde_json::to_string(&chunk)?,
- })?;
- }
- }
- Some(proto::LanguageModelRequestKind::CountTokens) => {
- let tokens_response = google_ai::count_tokens(
- session.http_client.as_ref(),
- google_ai::API_URL,
- &api_key,
- serde_json::from_str(&query.request)?,
- )
- .await?;
- response.send(proto::QueryLanguageModelResponse {
- response: serde_json::to_string(&tokens_response)?,
- })?;
- }
- None => Err(anyhow!("unknown request kind"))?,
+ let api_key = config
+ .google_ai_api_key
+ .as_ref()
+ .context("no Google AI API key configured on the server")?;
+ let mut events = google_ai::stream_generate_content(
+ session.http_client.as_ref(),
+ google_ai::API_URL,
+ api_key,
+ serde_json::from_str(&request.request)?,
+ )
+ .await?;
+ while let Some(event) = events.next().await {
+ let event = event?;
+ response.send(proto::StreamCompleteWithLanguageModelResponse {
+ event: serde_json::to_string(&event)?,
+ })?;
}
}
None => return Err(anyhow!("unknown provider"))?,
@@ -4608,11 +4646,51 @@ async fn complete_with_language_model(
Ok(())
}
-struct CountTokensWithLanguageModelRateLimit;
+async fn count_language_model_tokens(
+ request: proto::CountLanguageModelTokens,
+ response: Response<proto::CountLanguageModelTokens>,
+ session: Session,
+ config: &Config,
+) -> Result<()> {
+ let Some(session) = session.for_user() else {
+ return Err(anyhow!("user not found"))?;
+ };
+ authorize_access_to_language_models(&session).await?;
+
+ session
+ .rate_limiter
+ .check::<CountLanguageModelTokensRateLimit>(session.user_id())
+ .await?;
+
+ let result = match proto::LanguageModelProvider::from_i32(request.provider) {
+ Some(proto::LanguageModelProvider::Google) => {
+ let api_key = config
+ .google_ai_api_key
+ .as_ref()
+ .context("no Google AI API key configured on the server")?;
+ google_ai::count_tokens(
+ session.http_client.as_ref(),
+ google_ai::API_URL,
+ api_key,
+ serde_json::from_str(&request.request)?,
+ )
+ .await?
+ }
+ _ => return Err(anyhow!("unsupported provider"))?,
+ };
+
+ response.send(proto::CountLanguageModelTokensResponse {
+ token_count: result.total_tokens as u32,
+ })?;
+
+ Ok(())
+}
+
+struct CountLanguageModelTokensRateLimit;
-impl RateLimit for CountTokensWithLanguageModelRateLimit {
+impl RateLimit for CountLanguageModelTokensRateLimit {
fn capacity() -> usize {
- std::env::var("COUNT_TOKENS_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
+ std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(600) // Picked arbitrarily
@@ -4623,7 +4701,7 @@ impl RateLimit for CountTokensWithLanguageModelRateLimit {
}
fn db_name() -> &'static str {
- "count-tokens-with-language-model"
+ "count-language-model-tokens"
}
}
@@ -26,7 +26,9 @@ anyhow.workspace = true
futures.workspace = true
gpui.workspace = true
language_model.workspace = true
+schemars.workspace = true
serde.workspace = true
+serde_json.workspace = true
settings.workspace = true
smol.workspace = true
ui.workspace = true
@@ -3,10 +3,13 @@ use futures::{future::BoxFuture, stream::BoxStream, StreamExt};
use gpui::{AppContext, Global, Model, ModelContext, Task};
use language_model::{
LanguageModel, LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry,
- LanguageModelRequest,
+ LanguageModelRequest, LanguageModelTool,
};
-use smol::lock::{Semaphore, SemaphoreGuardArc};
-use std::{pin::Pin, sync::Arc, task::Poll};
+use smol::{
+ future::FutureExt,
+ lock::{Semaphore, SemaphoreGuardArc},
+};
+use std::{future, pin::Pin, sync::Arc, task::Poll};
use ui::Context;
pub fn init(cx: &mut AppContext) {
@@ -143,11 +146,11 @@ impl LanguageModelCompletionProvider {
&self,
request: LanguageModelRequest,
cx: &AppContext,
- ) -> Option<BoxFuture<'static, Result<usize>>> {
+ ) -> BoxFuture<'static, Result<usize>> {
if let Some(model) = self.active_model() {
- Some(model.count_tokens(request, cx))
+ model.count_tokens(request, cx)
} else {
- None
+ future::ready(Err(anyhow!("no active model"))).boxed()
}
}
@@ -183,6 +186,29 @@ impl LanguageModelCompletionProvider {
Ok(completion)
})
}
+
+ pub fn use_tool<T: LanguageModelTool>(
+ &self,
+ request: LanguageModelRequest,
+ cx: &AppContext,
+ ) -> Task<Result<T>> {
+ if let Some(language_model) = self.active_model() {
+ cx.spawn(|cx| async move {
+ let schema = schemars::schema_for!(T);
+ let schema_json = serde_json::to_value(&schema).unwrap();
+ let request =
+ language_model.use_tool(request, T::name(), T::description(), schema_json, &cx);
+ let response = request.await?;
+ Ok(serde_json::from_value(response)?)
+ })
+ } else {
+ Task::ready(Err(anyhow!("No active model set")))
+ }
+ }
+
+ pub fn active_model_telemetry_id(&self) -> Option<String> {
+ self.active_model.as_ref().map(|m| m.telemetry_id())
+ }
}
#[cfg(test)]
@@ -16,6 +16,8 @@ pub use model::*;
pub use registry::*;
pub use request::*;
pub use role::*;
+use schemars::JsonSchema;
+use serde::de::DeserializeOwned;
pub fn init(client: Arc<Client>, cx: &mut AppContext) {
settings::init(cx);
@@ -42,6 +44,20 @@ pub trait LanguageModel: Send + Sync {
request: LanguageModelRequest,
cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
+
+ fn use_tool(
+ &self,
+ request: LanguageModelRequest,
+ name: String,
+ description: String,
+ schema: serde_json::Value,
+ cx: &AsyncAppContext,
+ ) -> BoxFuture<'static, Result<serde_json::Value>>;
+}
+
+pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
+ fn name() -> String;
+ fn description() -> String;
}
pub trait LanguageModelProvider: 'static {
@@ -1,5 +1,9 @@
-use anthropic::stream_completion;
-use anyhow::{anyhow, Result};
+use crate::{
+ settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
+ LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
+ LanguageModelProviderState, LanguageModelRequest, Role,
+};
+use anyhow::{anyhow, Context as _, Result};
use collections::BTreeMap;
use editor::{Editor, EditorElement, EditorStyle};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
@@ -15,12 +19,6 @@ use theme::ThemeSettings;
use ui::prelude::*;
use util::ResultExt;
-use crate::{
- settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
- LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
- LanguageModelProviderState, LanguageModelRequest, Role,
-};
-
const PROVIDER_ID: &str = "anthropic";
const PROVIDER_NAME: &str = "Anthropic";
@@ -188,6 +186,61 @@ pub fn count_anthropic_tokens(
.boxed()
}
+impl AnthropicModel {
+ fn request_completion(
+ &self,
+ request: anthropic::Request,
+ cx: &AsyncAppContext,
+ ) -> BoxFuture<'static, Result<anthropic::Response>> {
+ let http_client = self.http_client.clone();
+
+ let Ok((api_key, api_url)) = cx.read_model(&self.state, |state, cx| {
+ let settings = &AllLanguageModelSettings::get_global(cx).anthropic;
+ (state.api_key.clone(), settings.api_url.clone())
+ }) else {
+ return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
+ };
+
+ async move {
+ let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
+ anthropic::complete(http_client.as_ref(), &api_url, &api_key, request).await
+ }
+ .boxed()
+ }
+
+ fn stream_completion(
+ &self,
+ request: anthropic::Request,
+ cx: &AsyncAppContext,
+ ) -> BoxFuture<'static, Result<BoxStream<'static, Result<anthropic::Event>>>> {
+ let http_client = self.http_client.clone();
+
+ let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, cx| {
+ let settings = &AllLanguageModelSettings::get_global(cx).anthropic;
+ (
+ state.api_key.clone(),
+ settings.api_url.clone(),
+ settings.low_speed_timeout,
+ )
+ }) else {
+ return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
+ };
+
+ async move {
+ let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
+ let request = anthropic::stream_completion(
+ http_client.as_ref(),
+ &api_url,
+ &api_key,
+ request,
+ low_speed_timeout,
+ );
+ request.await
+ }
+ .boxed()
+ }
+}
+
impl LanguageModel for AnthropicModel {
fn id(&self) -> LanguageModelId {
self.id.clone()
@@ -227,34 +280,53 @@ impl LanguageModel for AnthropicModel {
cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
let request = request.into_anthropic(self.model.id().into());
-
- let http_client = self.http_client.clone();
-
- let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, cx| {
- let settings = &AllLanguageModelSettings::get_global(cx).anthropic;
- (
- state.api_key.clone(),
- settings.api_url.clone(),
- settings.low_speed_timeout,
- )
- }) else {
- return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
- };
-
+ let request = self.stream_completion(request, cx);
async move {
- let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
- let request = stream_completion(
- http_client.as_ref(),
- &api_url,
- &api_key,
- request,
- low_speed_timeout,
- );
let response = request.await?;
Ok(anthropic::extract_text_from_events(response).boxed())
}
.boxed()
}
+
+ fn use_tool(
+ &self,
+ request: LanguageModelRequest,
+ tool_name: String,
+ tool_description: String,
+ input_schema: serde_json::Value,
+ cx: &AsyncAppContext,
+ ) -> BoxFuture<'static, Result<serde_json::Value>> {
+ let mut request = request.into_anthropic(self.model.id().into());
+ request.tool_choice = Some(anthropic::ToolChoice::Tool {
+ name: tool_name.clone(),
+ });
+ request.tools = vec![anthropic::Tool {
+ name: tool_name.clone(),
+ description: tool_description,
+ input_schema,
+ }];
+
+ let response = self.request_completion(request, cx);
+ async move {
+ let response = response.await?;
+ response
+ .content
+ .into_iter()
+ .find_map(|content| {
+ if let anthropic::Content::ToolUse { name, input, .. } = content {
+ if name == tool_name {
+ Some(input)
+ } else {
+ None
+ }
+ } else {
+ None
+ }
+ })
+ .context("tool not used")
+ }
+ .boxed()
+ }
}
struct AuthenticationPrompt {
@@ -4,7 +4,7 @@ use crate::{
LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest,
};
-use anyhow::Result;
+use anyhow::{anyhow, Context as _, Result};
use client::Client;
use collections::BTreeMap;
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
@@ -12,7 +12,7 @@ use gpui::{AnyView, AppContext, AsyncAppContext, Subscription, Task};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
-use std::sync::Arc;
+use std::{future, sync::Arc};
use strum::IntoEnumIterator;
use ui::prelude::*;
@@ -234,15 +234,13 @@ impl LanguageModel for CloudLanguageModel {
};
async move {
let request = serde_json::to_string(&request)?;
- let response = client.request(proto::QueryLanguageModel {
- provider: proto::LanguageModelProvider::Google as i32,
- kind: proto::LanguageModelRequestKind::CountTokens as i32,
- request,
- });
- let response = response.await?;
- let response =
- serde_json::from_str::<google_ai::CountTokensResponse>(&response.response)?;
- Ok(response.total_tokens)
+ let response = client
+ .request(proto::CountLanguageModelTokens {
+ provider: proto::LanguageModelProvider::Google as i32,
+ request,
+ })
+ .await?;
+ Ok(response.token_count as usize)
}
.boxed()
}
@@ -260,14 +258,14 @@ impl LanguageModel for CloudLanguageModel {
let request = request.into_anthropic(model.id().into());
async move {
let request = serde_json::to_string(&request)?;
- let response = client.request_stream(proto::QueryLanguageModel {
- provider: proto::LanguageModelProvider::Anthropic as i32,
- kind: proto::LanguageModelRequestKind::Complete as i32,
- request,
- });
- let chunks = response.await?;
+ let stream = client
+ .request_stream(proto::StreamCompleteWithLanguageModel {
+ provider: proto::LanguageModelProvider::Anthropic as i32,
+ request,
+ })
+ .await?;
Ok(anthropic::extract_text_from_events(
- chunks.map(|chunk| Ok(serde_json::from_str(&chunk?.response)?)),
+ stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
)
.boxed())
}
@@ -278,14 +276,14 @@ impl LanguageModel for CloudLanguageModel {
let request = request.into_open_ai(model.id().into());
async move {
let request = serde_json::to_string(&request)?;
- let response = client.request_stream(proto::QueryLanguageModel {
- provider: proto::LanguageModelProvider::OpenAi as i32,
- kind: proto::LanguageModelRequestKind::Complete as i32,
- request,
- });
- let chunks = response.await?;
+ let stream = client
+ .request_stream(proto::StreamCompleteWithLanguageModel {
+ provider: proto::LanguageModelProvider::OpenAi as i32,
+ request,
+ })
+ .await?;
Ok(open_ai::extract_text_from_events(
- chunks.map(|chunk| Ok(serde_json::from_str(&chunk?.response)?)),
+ stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
)
.boxed())
}
@@ -296,14 +294,14 @@ impl LanguageModel for CloudLanguageModel {
let request = request.into_google(model.id().into());
async move {
let request = serde_json::to_string(&request)?;
- let response = client.request_stream(proto::QueryLanguageModel {
- provider: proto::LanguageModelProvider::Google as i32,
- kind: proto::LanguageModelRequestKind::Complete as i32,
- request,
- });
- let chunks = response.await?;
+ let stream = client
+ .request_stream(proto::StreamCompleteWithLanguageModel {
+ provider: proto::LanguageModelProvider::Google as i32,
+ request,
+ })
+ .await?;
Ok(google_ai::extract_text_from_events(
- chunks.map(|chunk| Ok(serde_json::from_str(&chunk?.response)?)),
+ stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
)
.boxed())
}
@@ -311,6 +309,63 @@ impl LanguageModel for CloudLanguageModel {
}
}
}
+
+ fn use_tool(
+ &self,
+ request: LanguageModelRequest,
+ tool_name: String,
+ tool_description: String,
+ input_schema: serde_json::Value,
+ _cx: &AsyncAppContext,
+ ) -> BoxFuture<'static, Result<serde_json::Value>> {
+ match &self.model {
+ CloudModel::Anthropic(model) => {
+ let client = self.client.clone();
+ let mut request = request.into_anthropic(model.id().into());
+ request.tool_choice = Some(anthropic::ToolChoice::Tool {
+ name: tool_name.clone(),
+ });
+ request.tools = vec![anthropic::Tool {
+ name: tool_name.clone(),
+ description: tool_description,
+ input_schema,
+ }];
+
+ async move {
+ let request = serde_json::to_string(&request)?;
+ let response = client
+ .request(proto::CompleteWithLanguageModel {
+ provider: proto::LanguageModelProvider::Anthropic as i32,
+ request,
+ })
+ .await?;
+ let response: anthropic::Response = serde_json::from_str(&response.completion)?;
+ response
+ .content
+ .into_iter()
+ .find_map(|content| {
+ if let anthropic::Content::ToolUse { name, input, .. } = content {
+ if name == tool_name {
+ Some(input)
+ } else {
+ None
+ }
+ } else {
+ None
+ }
+ })
+ .context("tool not used")
+ }
+ .boxed()
+ }
+ CloudModel::OpenAi(_) => {
+ future::ready(Err(anyhow!("tool use not implemented for OpenAI"))).boxed()
+ }
+ CloudModel::Google(_) => {
+ future::ready(Err(anyhow!("tool use not implemented for Google AI"))).boxed()
+ }
+ }
+ }
}
struct AuthenticationPrompt {
@@ -1,15 +1,17 @@
-use std::sync::{Arc, Mutex};
-
-use collections::HashMap;
-use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
-
use crate::{
LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelRequest,
};
+use anyhow::anyhow;
+use collections::HashMap;
+use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{AnyView, AppContext, AsyncAppContext, Task};
use http_client::Result;
+use std::{
+ future,
+ sync::{Arc, Mutex},
+};
use ui::WindowContext;
pub fn language_model_id() -> LanguageModelId {
@@ -170,4 +172,15 @@ impl LanguageModel for FakeLanguageModel {
.insert(serde_json::to_string(&request).unwrap(), tx);
async move { Ok(rx.map(Ok).boxed()) }.boxed()
}
+
+ fn use_tool(
+ &self,
+ _request: LanguageModelRequest,
+ _name: String,
+ _description: String,
+ _schema: serde_json::Value,
+ _cx: &AsyncAppContext,
+ ) -> BoxFuture<'static, Result<serde_json::Value>> {
+ future::ready(Err(anyhow!("not implemented"))).boxed()
+ }
}
@@ -9,7 +9,7 @@ use gpui::{
};
use http_client::HttpClient;
use settings::{Settings, SettingsStore};
-use std::{sync::Arc, time::Duration};
+use std::{future, sync::Arc, time::Duration};
use strum::IntoEnumIterator;
use theme::ThemeSettings;
use ui::prelude::*;
@@ -238,6 +238,17 @@ impl LanguageModel for GoogleLanguageModel {
}
.boxed()
}
+
+ fn use_tool(
+ &self,
+ _request: LanguageModelRequest,
+ _name: String,
+ _description: String,
+ _schema: serde_json::Value,
+ _cx: &AsyncAppContext,
+ ) -> BoxFuture<'static, Result<serde_json::Value>> {
+ future::ready(Err(anyhow!("not implemented"))).boxed()
+ }
}
struct AuthenticationPrompt {
@@ -6,7 +6,7 @@ use ollama::{
get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest,
};
use settings::{Settings, SettingsStore};
-use std::{sync::Arc, time::Duration};
+use std::{future, sync::Arc, time::Duration};
use ui::{prelude::*, ButtonLike, ElevationIndex};
use crate::{
@@ -298,6 +298,17 @@ impl LanguageModel for OllamaLanguageModel {
}
.boxed()
}
+
+ fn use_tool(
+ &self,
+ _request: LanguageModelRequest,
+ _name: String,
+ _description: String,
+ _schema: serde_json::Value,
+ _cx: &AsyncAppContext,
+ ) -> BoxFuture<'static, Result<serde_json::Value>> {
+ future::ready(Err(anyhow!("not implemented"))).boxed()
+ }
}
struct DownloadOllamaMessage {
@@ -9,7 +9,7 @@ use gpui::{
use http_client::HttpClient;
use open_ai::stream_completion;
use settings::{Settings, SettingsStore};
-use std::{sync::Arc, time::Duration};
+use std::{future, sync::Arc, time::Duration};
use strum::IntoEnumIterator;
use theme::ThemeSettings;
use ui::prelude::*;
@@ -225,6 +225,17 @@ impl LanguageModel for OpenAiLanguageModel {
}
.boxed()
}
+
+ fn use_tool(
+ &self,
+ _request: LanguageModelRequest,
+ _name: String,
+ _description: String,
+ _schema: serde_json::Value,
+ _cx: &AsyncAppContext,
+ ) -> BoxFuture<'static, Result<serde_json::Value>> {
+ future::ready(Err(anyhow!("not implemented"))).boxed()
+ }
}
pub fn count_open_ai_tokens(
@@ -106,19 +106,27 @@ impl LanguageModelRequest {
messages: new_messages
.into_iter()
.filter_map(|message| {
- Some(anthropic::RequestMessage {
+ Some(anthropic::Message {
role: match message.role {
Role::User => anthropic::Role::User,
Role::Assistant => anthropic::Role::Assistant,
Role::System => return None,
},
- content: message.content,
+ content: vec![anthropic::Content::Text {
+ text: message.content,
+ }],
})
})
.collect(),
- stream: true,
max_tokens: 4092,
- system: system_message,
+ system: Some(system_message),
+ tools: Vec::new(),
+ tool_choice: None,
+ metadata: None,
+ stop_sequences: Vec::new(),
+ temperature: None,
+ top_k: None,
+ top_p: None,
}
}
}
@@ -194,8 +194,12 @@ message Envelope {
JoinHostedProject join_hosted_project = 164;
- QueryLanguageModel query_language_model = 224;
- QueryLanguageModelResponse query_language_model_response = 225; // current max
+ CompleteWithLanguageModel complete_with_language_model = 226;
+ CompleteWithLanguageModelResponse complete_with_language_model_response = 227;
+ StreamCompleteWithLanguageModel stream_complete_with_language_model = 228;
+ StreamCompleteWithLanguageModelResponse stream_complete_with_language_model_response = 229;
+ CountLanguageModelTokens count_language_model_tokens = 230;
+ CountLanguageModelTokensResponse count_language_model_tokens_response = 231; // current max
GetCachedEmbeddings get_cached_embeddings = 189;
GetCachedEmbeddingsResponse get_cached_embeddings_response = 190;
ComputeEmbeddings compute_embeddings = 191;
@@ -267,6 +271,7 @@ message Envelope {
reserved 158 to 161;
reserved 166 to 169;
+ reserved 224 to 225;
}
// Messages
@@ -2050,25 +2055,37 @@ enum LanguageModelRole {
reserved 3;
}
-message QueryLanguageModel {
+message CompleteWithLanguageModel {
LanguageModelProvider provider = 1;
- LanguageModelRequestKind kind = 2;
- string request = 3;
+ string request = 2;
}
-enum LanguageModelProvider {
- Anthropic = 0;
- OpenAI = 1;
- Google = 2;
+message CompleteWithLanguageModelResponse {
+ string completion = 1;
+}
+
+message StreamCompleteWithLanguageModel {
+ LanguageModelProvider provider = 1;
+ string request = 2;
+}
+
+message StreamCompleteWithLanguageModelResponse {
+ string event = 1;
+}
+
+message CountLanguageModelTokens {
+ LanguageModelProvider provider = 1;
+ string request = 2;
}
-enum LanguageModelRequestKind {
- Complete = 0;
- CountTokens = 1;
+message CountLanguageModelTokensResponse {
+ uint32 token_count = 1;
}
-message QueryLanguageModelResponse {
- string response = 1;
+enum LanguageModelProvider {
+ Anthropic = 0;
+ OpenAI = 1;
+ Google = 2;
}
message GetCachedEmbeddings {
@@ -294,8 +294,12 @@ messages!(
(PrepareRename, Background),
(PrepareRenameResponse, Background),
(ProjectEntryResponse, Foreground),
- (QueryLanguageModel, Background),
- (QueryLanguageModelResponse, Background),
+ (CompleteWithLanguageModel, Background),
+ (CompleteWithLanguageModelResponse, Background),
+ (StreamCompleteWithLanguageModel, Background),
+ (StreamCompleteWithLanguageModelResponse, Background),
+ (CountLanguageModelTokens, Background),
+ (CountLanguageModelTokensResponse, Background),
(RefreshInlayHints, Foreground),
(RejoinChannelBuffers, Foreground),
(RejoinChannelBuffersResponse, Foreground),
@@ -463,7 +467,12 @@ request_messages!(
(PerformRename, PerformRenameResponse),
(Ping, Ack),
(PrepareRename, PrepareRenameResponse),
- (QueryLanguageModel, QueryLanguageModelResponse),
+ (CompleteWithLanguageModel, CompleteWithLanguageModelResponse),
+ (
+ StreamCompleteWithLanguageModel,
+ StreamCompleteWithLanguageModelResponse
+ ),
+ (CountLanguageModelTokens, CountLanguageModelTokensResponse),
(RefreshInlayHints, Ack),
(RejoinChannelBuffers, RejoinChannelBuffersResponse),
(RejoinRoom, RejoinRoomResponse),