diff --git a/Cargo.lock b/Cargo.lock
index 1aebd9eb25a927cc4e41818948ac0f27d9be46a1..6bb1c470037484c0ab15b6b8d9f005e86487f1c5 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -223,6 +223,7 @@ name = "anthropic"
version = "0.1.0"
dependencies = [
"anyhow",
+ "chrono",
"futures 0.3.30",
"http_client",
"isahc",
@@ -232,6 +233,7 @@ dependencies = [
"strum",
"thiserror",
"tokio",
+ "util",
]
[[package]]
@@ -5057,9 +5059,9 @@ checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
[[package]]
name = "heed"
-version = "0.20.4"
+version = "0.20.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "620033c8c8edfd2f53e6f99a30565eb56a33b42c468e3ad80e21d85fb93bafb0"
+checksum = "7d4f449bab7320c56003d37732a917e18798e2f1709d80263face2b4f9436ddb"
dependencies = [
"bitflags 2.6.0",
"byteorder",
@@ -6314,9 +6316,9 @@ dependencies = [
[[package]]
name = "lmdb-master-sys"
-version = "0.2.3"
+version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "1de7e761853c15ca72821d9f928d7bb123ef4c05377c4e7ab69fa1c742f91d24"
+checksum = "472c3760e2a8d0f61f322fb36788021bb36d573c502b50fa3e2bcaac3ec326c9"
dependencies = [
"cc",
"doxygen-rs",
@@ -7590,6 +7592,29 @@ version = "2.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e"
+[[package]]
+name = "performance"
+version = "0.1.0"
+dependencies = [
+ "anyhow",
+ "collections",
+ "gpui",
+ "log",
+ "schemars",
+ "serde",
+ "settings",
+ "util",
+ "workspace",
+]
+
+[[package]]
+name = "perplexity"
+version = "0.1.0"
+dependencies = [
+ "serde",
+ "zed_extension_api 0.1.0",
+]
+
[[package]]
name = "pest"
version = "2.7.11"
@@ -13875,6 +13900,7 @@ dependencies = [
"outline_panel",
"parking_lot",
"paths",
+ "performance",
"profiling",
"project",
"project_panel",
diff --git a/Cargo.toml b/Cargo.toml
index 329688b34b946f793e81f26d20b0ccf5e2ff3fdf..3a143f33f175377ee7001fd7e52360f34b6071d7 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -70,6 +70,7 @@ members = [
"crates/outline",
"crates/outline_panel",
"crates/paths",
+ "crates/performance",
"crates/picker",
"crates/prettier",
"crates/project",
@@ -145,6 +146,7 @@ members = [
"extensions/lua",
"extensions/ocaml",
"extensions/php",
+ "extensions/perplexity",
"extensions/prisma",
"extensions/purescript",
"extensions/ruff",
@@ -241,6 +243,7 @@ open_ai = { path = "crates/open_ai" }
outline = { path = "crates/outline" }
outline_panel = { path = "crates/outline_panel" }
paths = { path = "crates/paths" }
+performance = { path = "crates/performance" }
picker = { path = "crates/picker" }
plugin = { path = "crates/plugin" }
plugin_macros = { path = "crates/plugin_macros" }
diff --git a/assets/icons/ai_anthropic_hosted.svg b/assets/icons/ai_anthropic_hosted.svg
new file mode 100644
index 0000000000000000000000000000000000000000..12d731fb0b4438fcf6c263bd6c071bc8873823de
--- /dev/null
+++ b/assets/icons/ai_anthropic_hosted.svg
@@ -0,0 +1,11 @@
+
diff --git a/assets/icons/ellipsis_vertical.svg b/assets/icons/ellipsis_vertical.svg
new file mode 100644
index 0000000000000000000000000000000000000000..077dbe8778f2015bc97a0ad07400585e8a7ea807
--- /dev/null
+++ b/assets/icons/ellipsis_vertical.svg
@@ -0,0 +1 @@
+
diff --git a/assets/icons/magic_wand.svg b/assets/icons/magic_wand.svg
deleted file mode 100644
index cd2194764767ee14fe2c4ac1b5e6c73a463d41a1..0000000000000000000000000000000000000000
--- a/assets/icons/magic_wand.svg
+++ /dev/null
@@ -1,10 +0,0 @@
-
diff --git a/assets/icons/search_code.svg b/assets/icons/search_code.svg
new file mode 100644
index 0000000000000000000000000000000000000000..1cc9affeb80fe8111de1417a0241497de663ee90
--- /dev/null
+++ b/assets/icons/search_code.svg
@@ -0,0 +1 @@
+
diff --git a/assets/icons/slash.svg b/assets/icons/slash.svg
new file mode 100644
index 0000000000000000000000000000000000000000..792c405bb081a99ba456a7cb61ea5fbb5cfcd270
--- /dev/null
+++ b/assets/icons/slash.svg
@@ -0,0 +1 @@
+
diff --git a/assets/icons/slash_square.svg b/assets/icons/slash_square.svg
new file mode 100644
index 0000000000000000000000000000000000000000..8f269ddeb5e879dd0f8488af037dd4f35743a0ea
--- /dev/null
+++ b/assets/icons/slash_square.svg
@@ -0,0 +1 @@
+
diff --git a/assets/prompts/content_prompt.hbs b/assets/prompts/content_prompt.hbs
index cd618a67613e54f63796b62f0b3e3386d258af19..cf4141349b356c8aeb81da9814e7a1c1c27b1f88 100644
--- a/assets/prompts/content_prompt.hbs
+++ b/assets/prompts/content_prompt.hbs
@@ -1,426 +1,61 @@
-You are an expert developer assistant working in an AI-enabled text editor.
-Your task is to rewrite a specific section of the provided document based on a user-provided prompt.
-
-
-1. Scope: Modify only content within tags. Do not alter anything outside these boundaries.
-2. Precision: Make changes strictly necessary to fulfill the given prompt. Preserve all other content as-is.
-3. Seamless integration: Ensure rewritten sections flow naturally with surrounding text and maintain document structure.
-4. Tag exclusion: Never include , , , or tags in the output.
-5. Indentation: Maintain the original indentation level of the file in rewritten sections.
-6. Completeness: Rewrite the entire tagged section, even if only partial changes are needed. Avoid omissions or elisions.
-7. Insertions: Replace tags with appropriate content as specified by the prompt.
-8. Code integrity: Respect existing code structure and functionality when making changes.
-9. Consistency: Maintain a uniform style and tone throughout the rewritten text.
-
-
-
-
-
-
-use std::cell::Cell;
-use std::collections::HashMap;
-use std::cmp;
-
-
-
-
-pub struct LruCache {
- /// The maximum number of items the cache can hold.
- capacity: usize,
- /// The map storing the cached items.
- items: HashMap,
-}
-
-// The rest of the implementation...
-
-
-doc this
-
-
-
-
-/// Represents an Axis-Aligned Bounding Box (AABB) tree data structure.
-///
-/// This structure is used for efficient spatial queries and collision detection.
-/// It organizes objects in a hierarchical tree structure based on their bounding boxes.
-///
-/// # Type Parameters
-///
-/// * `T`: The type of data associated with each node in the tree.
-pub struct AabbTree {
- root: Option,
-
-
-/// Represents an Axis-Aligned Bounding Box (AABB) tree data structure.
-///
-/// This structure is used for efficient spatial queries and collision detection.
-/// It organizes objects in a hierarchical tree structure based on their bounding boxes.
-///
-/// # Type Parameters
-///
-/// * `T`: The type of data associated with each node in the tree.
-
-
-
-
-
-
-import math
-
-def calculate_circle_area(radius):
- """Calculate the area of a circle given its radius."""
- return math.pi * radius ** 2
-
-
-
-
-class Circle:
- def __init__(self, radius):
- self.radius = radius
-
- def area(self):
- return math.pi * self.radius ** 2
-
- def circumference(self):
- return 2 * math.pi * self.radius
-
-# Usage example
-circle = Circle(5)
-print(f"Area: {circle.area():.2f}")
-print(f"Circumference: {circle.circumference():.2f}")
-
-
-write docs
-
-
-
-
-"""
-Represents a circle with methods to calculate its area and circumference.
-
-This class provides a simple way to work with circles in a geometric context.
-It allows for the creation of Circle objects with a specified radius and
-offers methods to compute the circle's area and circumference.
-
-Attributes:
- radius (float): The radius of the circle.
-
-Methods:
- area(): Calculates and returns the area of the circle.
- circumference(): Calculates and returns the circumference of the circle.
-"""
-class Circle:
-
-
-"""
-Represents a circle with methods to calculate its area and circumference.
-
-This class provides a simple way to work with circles in a geometric context.
-It allows for the creation of Circle objects with a specified radius and
-offers methods to compute the circle's area and circumference.
-
-Attributes:
- radius (float): The radius of the circle.
+{{#if language_name}}
+Here's a file of {{language_name}} that I'm going to ask you to make an edit to.
+{{else}}
+Here's a file of text that I'm going to ask you to make an edit to.
+{{/if}}
-Methods:
- area(): Calculates and returns the area of the circle.
- circumference(): Calculates and returns the circumference of the circle.
-"""
-
-
+{{#if is_insert}}
+The point you'll need to insert at is marked with .
+{{else}}
+The section you'll need to rewrite is marked with tags.
+{{/if}}
-
-
-class BankAccount {
- private balance: number;
-
- constructor(initialBalance: number) {
- this.balance = initialBalance;
- }
-
-
-
-
- deposit(amount: number): void {
- if (amount > 0) {
- this.balance += amount;
- }
- }
-
- withdraw(amount: number): boolean {
- if (amount > 0 && this.balance >= amount) {
- this.balance -= amount;
- return true;
- }
- return false;
- }
-
- getBalance(): number {
- return this.balance;
- }
-}
-
-// Usage
-const account = new BankAccount(1000);
-account.deposit(500);
-console.log(account.getBalance()); // 1500
-account.withdraw(200);
-console.log(account.getBalance()); // 1300
+{{{document_content}}}
-
-//
-
-
-
- /**
- * Deposits the specified amount into the bank account.
- *
- * @param amount The amount to deposit. Must be a positive number.
- * @throws Error if the amount is not positive.
- */
- deposit(amount: number): void {
- if (amount > 0) {
- this.balance += amount;
- } else {
- throw new Error("Deposit amount must be positive");
- }
- }
-
-
- /**
- * Deposits the specified amount into the bank account.
- *
- * @param amount The amount to deposit. Must be a positive number.
- * @throws Error if the amount is not positive.
- */
-
-
+{{#if is_truncated}}
+The context around the relevant section has been truncated (possibly in the middle of a line) for brevity.
+{{/if}}
-
-
-
-use std::collections::VecDeque;
+{{#if is_insert}}
+You can't replace {{content_type}}, your answer will be inserted in place of the `` tags. Don't include the insert_here tags in your output.
-pub struct BinaryTree {
- root: Option>,
-}
+Generate {{content_type}} based on the following prompt:
-
-
-
-struct Node {
- value: T,
- left: Option>>,
- right: Option>>,
-}
-
-derive clone
+{{{user_prompt}}}
-
-
-
-#[derive(Clone)]
-
-struct Node {
- value: T,
- left: Option>>,
- right: Option>>,
-}
-
-
-pub struct BinaryTree {
- root: Option>,
-}
+Match the indentation in the original file in the inserted {{content_type}}, don't include any indentation on blank lines.
-#[derive(Clone)]
-
-
-
-#[derive(Clone)]
-struct Node {
- value: T,
- left: Option>>,
- right: Option>>,
-}
-
-impl Node {
- fn new(value: T) -> Self {
- Node {
- value,
- left: None,
- right: None,
- }
- }
-}
-
-
-#[derive(Clone)]
-
-
-
-
-
-
-import math
-
-def calculate_circle_area(radius):
- """Calculate the area of a circle given its radius."""
- return math.pi * radius ** 2
-
-
-
-
-class Circle:
- def __init__(self, radius):
- self.radius = radius
-
- def area(self):
- return math.pi * self.radius ** 2
+Immediately start with the following format with no remarks:
- def circumference(self):
- return 2 * math.pi * self.radius
+```
+\{{INSERTED_CODE}}
+```
+{{else}}
+Edit the section of {{content_type}} in tags based on the following prompt:
-# Usage example
-circle = Circle(5)
-print(f"Area: {circle.area():.2f}")
-print(f"Circumference: {circle.circumference():.2f}")
-
-add dataclass decorator
+{{{user_prompt}}}
-
-
-
-@dataclass
-class Circle:
- radius: float
-
- def __init__(self, radius):
- self.radius = radius
- def area(self):
- return math.pi * self.radius ** 2
-
-
-@dataclass
-
-
-
-
-
-
-interface ShoppingCart {
- items: string[];
- total: number;
-}
+{{#if rewrite_section}}
+And here's the section to rewrite based on that prompt again for reference:
-class ShoppingCartManager {
+{{{rewrite_section}}}
- private cart: ShoppingCart;
-
- constructor() {
- this.cart = { items: [], total: 0 };
- }
-
- addItem(item: string, price: number): void {
- this.cart.items.push(item);
- this.cart.total += price;
- }
-
- getTotal(): number {
- return this.cart.total;
- }
-}
-
-// Usage
-const manager = new ShoppingCartManager();
-manager.addItem("Book", 15.99);
-console.log(manager.getTotal()); // 15.99
-
-
-add readonly modifier
-
-
-
-
-readonly interface ShoppingCart {
- items: string[];
- total: number;
-}
-
-class ShoppingCartManager {
- private readonly cart: ShoppingCart;
-
- constructor() {
- this.cart = { items: [], total: 0 };
- }
-
-
-readonly interface ShoppingCart {
-
-
-
-
-
-With these examples in mind, edit the following file:
-
-
-{{{ document_content }}}
-
-
-{{#if is_truncated}}
-The provided document has been truncated (potentially mid-line) for brevity.
-{{/if}}
-
-
-{{#if has_insertion}}
-Insert text anywhere you see marked with tags. It's CRITICAL that you DO NOT include tags in your output.
-{{/if}}
-{{#if has_replacement}}
-Edit text that you see surrounded with ... tags. It's CRITICAL that you DO NOT include tags in your output.
{{/if}}
-Make no changes to the rewritten content outside these tags.
-
-{{{ rewrite_section_prefix }}}
-
-{{{ rewrite_section_with_edits }}}
-
-{{{ rewrite_section_suffix }}}
-
+Only make changes that are necessary to fulfill the prompt, leave everything else as-is. All surrounding {{content_type}} will be preserved.
-Rewrite the lines enclosed within the tags in accordance with the provided instructions and the prompt below.
-
-
-{{{ user_prompt }}}
-
-
-Do not include or annotations in your output. Here is a clean copy of the snippet without annotations for your reference.
-
-
-{{{ rewrite_section_prefix }}}
-{{{ rewrite_section }}}
-{{{ rewrite_section_suffix }}}
-
-
-
-
-1. Focus on necessary changes: Modify only what's required to fulfill the prompt.
-2. Preserve context: Maintain all surrounding content as-is, ensuring the rewritten section seamlessly integrates with the existing document structure and flow.
-3. Exclude annotation tags: Do not output , , , or tags.
-4. Maintain indentation: Begin at the original file's indentation level.
-5. Complete rewrite: Continue until the entire section is rewritten, even if no further changes are needed.
-6. Avoid elisions: Always write out the full section without unnecessary omissions. NEVER say `// ...` or `// ...existing code` in your output.
-7. Respect content boundaries: Preserve code integrity.
-
+Start at the indentation level in the original file in the rewritten {{content_type}}. Don't stop until you've rewritten the entire section, even if you have no more changes to make, always write out the whole section with no unnecessary elisions.
Immediately start with the following format with no remarks:
```
\{{REWRITTEN_CODE}}
```
+{{/if}}
diff --git a/crates/anthropic/Cargo.toml b/crates/anthropic/Cargo.toml
index 4628d3db809cf8f33672bfb37ad78ca723265aaf..9e48ad0e57d81d1434d3e872e84edcab7f233900 100644
--- a/crates/anthropic/Cargo.toml
+++ b/crates/anthropic/Cargo.toml
@@ -17,6 +17,7 @@ path = "src/anthropic.rs"
[dependencies]
anyhow.workspace = true
+chrono.workspace = true
futures.workspace = true
http_client.workspace = true
isahc.workspace = true
@@ -25,6 +26,7 @@ serde.workspace = true
serde_json.workspace = true
strum.workspace = true
thiserror.workspace = true
+util.workspace = true
[dev-dependencies]
tokio.workspace = true
diff --git a/crates/anthropic/src/anthropic.rs b/crates/anthropic/src/anthropic.rs
index e9f0ea51a99562a149960806573cbbd17678e827..38b4f5466c32c96de2ae1a0d5d31900ea5266c81 100644
--- a/crates/anthropic/src/anthropic.rs
+++ b/crates/anthropic/src/anthropic.rs
@@ -1,14 +1,17 @@
mod supported_countries;
use anyhow::{anyhow, Context, Result};
+use chrono::{DateTime, Utc};
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use isahc::config::Configurable;
+use isahc::http::{HeaderMap, HeaderValue};
use serde::{Deserialize, Serialize};
use std::time::Duration;
use std::{pin::Pin, str::FromStr};
use strum::{EnumIter, EnumString};
use thiserror::Error;
+use util::ResultExt as _;
pub use supported_countries::*;
@@ -38,6 +41,8 @@ pub enum Model {
Custom {
name: String,
max_tokens: usize,
+ /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
+ display_name: Option,
/// Override this model with a different Anthropic model for tool calls.
tool_override: Option,
/// Indicates whether this custom model supports caching.
@@ -77,7 +82,9 @@ impl Model {
Self::Claude3Opus => "Claude 3 Opus",
Self::Claude3Sonnet => "Claude 3 Sonnet",
Self::Claude3Haiku => "Claude 3 Haiku",
- Self::Custom { name, .. } => name,
+ Self::Custom {
+ name, display_name, ..
+ } => display_name.as_ref().unwrap_or(name),
}
}
@@ -191,6 +198,66 @@ pub async fn stream_completion(
request: Request,
low_speed_timeout: Option,
) -> Result>, AnthropicError> {
+ stream_completion_with_rate_limit_info(client, api_url, api_key, request, low_speed_timeout)
+ .await
+ .map(|output| output.0)
+}
+
+/// https://docs.anthropic.com/en/api/rate-limits#response-headers
+#[derive(Debug)]
+pub struct RateLimitInfo {
+ pub requests_limit: usize,
+ pub requests_remaining: usize,
+ pub requests_reset: DateTime,
+ pub tokens_limit: usize,
+ pub tokens_remaining: usize,
+ pub tokens_reset: DateTime,
+}
+
+impl RateLimitInfo {
+ fn from_headers(headers: &HeaderMap) -> Result {
+ let tokens_limit = get_header("anthropic-ratelimit-tokens-limit", headers)?.parse()?;
+ let requests_limit = get_header("anthropic-ratelimit-requests-limit", headers)?.parse()?;
+ let tokens_remaining =
+ get_header("anthropic-ratelimit-tokens-remaining", headers)?.parse()?;
+ let requests_remaining =
+ get_header("anthropic-ratelimit-requests-remaining", headers)?.parse()?;
+ let requests_reset = get_header("anthropic-ratelimit-requests-reset", headers)?;
+ let tokens_reset = get_header("anthropic-ratelimit-tokens-reset", headers)?;
+ let requests_reset = DateTime::parse_from_rfc3339(requests_reset)?.to_utc();
+ let tokens_reset = DateTime::parse_from_rfc3339(tokens_reset)?.to_utc();
+
+ Ok(Self {
+ requests_limit,
+ tokens_limit,
+ requests_remaining,
+ tokens_remaining,
+ requests_reset,
+ tokens_reset,
+ })
+ }
+}
+
+fn get_header<'a>(key: &str, headers: &'a HeaderMap) -> Result<&'a str, anyhow::Error> {
+ Ok(headers
+ .get(key)
+ .ok_or_else(|| anyhow!("missing header `{key}`"))?
+ .to_str()?)
+}
+
+pub async fn stream_completion_with_rate_limit_info(
+ client: &dyn HttpClient,
+ api_url: &str,
+ api_key: &str,
+ request: Request,
+ low_speed_timeout: Option,
+) -> Result<
+ (
+ BoxStream<'static, Result>,
+ Option,
+ ),
+ AnthropicError,
+> {
let request = StreamingRequest {
base: request,
stream: true,
@@ -220,8 +287,9 @@ pub async fn stream_completion(
.await
.context("failed to send request to Anthropic")?;
if response.status().is_success() {
+ let rate_limits = RateLimitInfo::from_headers(response.headers());
let reader = BufReader::new(response.into_body());
- Ok(reader
+ let stream = reader
.lines()
.filter_map(|line| async move {
match line {
@@ -235,7 +303,8 @@ pub async fn stream_completion(
Err(error) => Some(Err(AnthropicError::Other(anyhow!(error)))),
}
})
- .boxed())
+ .boxed();
+ Ok((stream, rate_limits.log_err()))
} else {
let mut body = Vec::new();
response
diff --git a/crates/assistant/src/assistant.rs b/crates/assistant/src/assistant.rs
index 7b26113ea9538b206203de72f22dda8ee7ccfa97..2e7242baf67ac1b3bbf38e6c81ce4c638e01b5db 100644
--- a/crates/assistant/src/assistant.rs
+++ b/crates/assistant/src/assistant.rs
@@ -9,6 +9,7 @@ mod model_selector;
mod prompt_library;
mod prompts;
mod slash_command;
+pub(crate) mod slash_command_picker;
pub mod slash_command_settings;
mod streaming_diff;
mod terminal_inline_assistant;
@@ -33,7 +34,7 @@ use language_model::{
};
pub(crate) use model_selector::*;
pub use prompts::PromptBuilder;
-use prompts::PromptOverrideContext;
+use prompts::PromptLoadingParams;
use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
use serde::{Deserialize, Serialize};
use settings::{update_settings_file, Settings, SettingsStore};
@@ -59,7 +60,6 @@ actions!(
InsertIntoEditor,
ToggleFocus,
InsertActivePrompt,
- ShowConfiguration,
DeployHistory,
DeployPromptLibrary,
ConfirmCommand,
@@ -184,7 +184,7 @@ impl Assistant {
pub fn init(
fs: Arc,
client: Arc,
- dev_mode: bool,
+ stdout_is_a_pty: bool,
cx: &mut AppContext,
) -> Arc {
cx.set_global(Assistant::default());
@@ -223,9 +223,11 @@ pub fn init(
assistant_panel::init(cx);
context_servers::init(cx);
- let prompt_builder = prompts::PromptBuilder::new(Some(PromptOverrideContext {
- dev_mode,
+ let prompt_builder = prompts::PromptBuilder::new(Some(PromptLoadingParams {
fs: fs.clone(),
+ repo_path: stdout_is_a_pty
+ .then(|| std::env::current_dir().log_err())
+ .flatten(),
cx,
}))
.log_err()
diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs
index ac2e54ab12103f27a0aa0c53e89a043017e44868..8e6a7d364e8301c99239818c515b693a02c411b0 100644
--- a/crates/assistant/src/assistant_panel.rs
+++ b/crates/assistant/src/assistant_panel.rs
@@ -9,6 +9,7 @@ use crate::{
file_command::codeblock_fence_for_path,
SlashCommandCompletionProvider, SlashCommandRegistry,
},
+ slash_command_picker,
terminal_inline_assistant::TerminalInlineAssistant,
Assist, ConfirmCommand, Context, ContextEvent, ContextId, ContextStore, CycleMessageRole,
DeployHistory, DeployPromptLibrary, InlineAssist, InlineAssistId, InlineAssistant,
@@ -16,7 +17,7 @@ use crate::{
QuoteSelection, RemoteContextMetadata, SavedContextMetadata, Split, ToggleFocus,
ToggleModelSelector, WorkflowStepResolution, WorkflowStepView,
};
-use crate::{ContextStoreEvent, ModelPickerDelegate, ShowConfiguration};
+use crate::{ContextStoreEvent, ModelPickerDelegate};
use anyhow::{anyhow, Result};
use assistant_slash_command::{SlashCommand, SlashCommandOutputSection};
use client::{proto, Client, Status};
@@ -35,10 +36,10 @@ use fs::Fs;
use gpui::{
canvas, div, img, percentage, point, pulsating_between, size, Action, Animation, AnimationExt,
AnyElement, AnyView, AppContext, AsyncWindowContext, ClipboardEntry, ClipboardItem,
- Context as _, CursorStyle, DismissEvent, Empty, Entity, EntityId, EventEmitter, FocusHandle,
- FocusableView, FontWeight, InteractiveElement, IntoElement, Model, ParentElement, Pixels,
- ReadGlobal, Render, RenderImage, SharedString, Size, StatefulInteractiveElement, Styled,
- Subscription, Task, Transformation, UpdateGlobal, View, VisualContext, WeakView, WindowContext,
+ Context as _, DismissEvent, Empty, Entity, EntityId, EventEmitter, FocusHandle, FocusableView,
+ FontWeight, InteractiveElement, IntoElement, Model, ParentElement, Pixels, ReadGlobal, Render,
+ RenderImage, SharedString, Size, StatefulInteractiveElement, Styled, Subscription, Task,
+ Transformation, UpdateGlobal, View, VisualContext, WeakView, WindowContext,
};
use indexed_docs::IndexedDocsStore;
use language::{
@@ -68,8 +69,8 @@ use ui::TintColor;
use ui::{
prelude::*,
utils::{format_distance_from_now, DateTimeType},
- Avatar, AvatarShape, ButtonLike, ContextMenu, Disclosure, ElevationIndex, KeyBinding, ListItem,
- ListItemSpacing, PopoverMenu, PopoverMenuHandle, Tooltip,
+ Avatar, AvatarShape, ButtonLike, ContextMenu, Disclosure, ElevationIndex, IconButtonShape,
+ KeyBinding, ListItem, ListItemSpacing, PopoverMenu, PopoverMenuHandle, Tooltip,
};
use util::ResultExt;
use workspace::{
@@ -77,7 +78,8 @@ use workspace::{
item::{self, FollowableItem, Item, ItemHandle},
pane::{self, SaveIntent},
searchable::{SearchEvent, SearchableItem},
- Pane, Save, ToggleZoom, ToolbarItemEvent, ToolbarItemLocation, ToolbarItemView, Workspace,
+ Pane, Save, ShowConfiguration, ToggleZoom, ToolbarItemEvent, ToolbarItemLocation,
+ ToolbarItemView, Workspace,
};
use workspace::{searchable::SearchableItemHandle, NewFile};
@@ -119,7 +121,7 @@ struct InlineAssistTabBarButton;
impl Render for InlineAssistTabBarButton {
fn render(&mut self, cx: &mut ViewContext) -> impl IntoElement {
- IconButton::new("terminal_inline_assistant", IconName::MagicWand)
+ IconButton::new("terminal_inline_assistant", IconName::ZedAssistant)
.icon_size(IconSize::Small)
.on_click(cx.listener(|_, _, cx| {
cx.dispatch_action(InlineAssist::default().boxed_clone());
@@ -541,19 +543,18 @@ impl AssistantPanel {
cx.emit(AssistantPanelEvent::ContextEdited);
true
}
-
- pane::Event::RemoveItem { idx } => {
- if self
+ pane::Event::RemovedItem { .. } => {
+ let has_configuration_view = self
.pane
.read(cx)
- .item_for_index(*idx)
- .map_or(false, |item| item.downcast::().is_some())
- {
+ .items_of_type::()
+ .next()
+ .is_some();
+
+ if !has_configuration_view {
self.configuration_subscription = None;
}
- false
- }
- pane::Event::RemovedItem { .. } => {
+
cx.emit(AssistantPanelEvent::ContextEdited);
true
}
@@ -704,7 +705,9 @@ impl AssistantPanel {
self.authenticate_provider_task = Some((
provider.id(),
cx.spawn(|this, mut cx| async move {
- let _ = load_credentials.await;
+ if let Some(future) = load_credentials {
+ let _ = future.await;
+ }
this.update(&mut cx, |this, _cx| {
this.authenticate_provider_task = None;
})
@@ -735,6 +738,7 @@ impl AssistantPanel {
};
let initial_prompt = action.prompt.clone();
+
if assistant_panel.update(cx, |assistant, cx| assistant.is_authenticated(cx)) {
match inline_assist_target {
InlineAssistTarget::Editor(active_editor, include_context) => {
@@ -763,9 +767,27 @@ impl AssistantPanel {
} else {
let assistant_panel = assistant_panel.downgrade();
cx.spawn(|workspace, mut cx| async move {
- assistant_panel
- .update(&mut cx, |assistant, cx| assistant.authenticate(cx))?
- .await?;
+ let Some(task) =
+ assistant_panel.update(&mut cx, |assistant, cx| assistant.authenticate(cx))?
+ else {
+ let answer = cx
+ .prompt(
+ gpui::PromptLevel::Warning,
+ "No language model provider configured",
+ None,
+ &["Configure", "Cancel"],
+ )
+ .await
+ .ok();
+ if let Some(answer) = answer {
+ if answer == 0 {
+ cx.update(|cx| cx.dispatch_action(Box::new(ShowConfiguration)))
+ .ok();
+ }
+ }
+ return Ok(());
+ };
+ task.await?;
if assistant_panel.update(&mut cx, |panel, cx| panel.is_authenticated(cx))? {
cx.update(|cx| match inline_assist_target {
InlineAssistTarget::Editor(active_editor, include_context) => {
@@ -1173,13 +1195,10 @@ impl AssistantPanel {
.map_or(false, |provider| provider.is_authenticated(cx))
}
- fn authenticate(&mut self, cx: &mut ViewContext) -> Task> {
+ fn authenticate(&mut self, cx: &mut ViewContext) -> Option>> {
LanguageModelRegistry::read_global(cx)
.active_provider()
- .map_or(
- Task::ready(Err(anyhow!("no active language model provider"))),
- |provider| provider.authenticate(cx),
- )
+ .map_or(None, |provider| Some(provider.authenticate(cx)))
}
}
@@ -1336,6 +1355,7 @@ struct WorkflowStep {
footer_block_id: CustomBlockId,
resolved_step: Option>>,
assist: Option,
+ auto_apply: bool,
}
impl WorkflowStep {
@@ -1372,13 +1392,16 @@ impl WorkflowStep {
}
}
Some(Err(error)) => WorkflowStepStatus::Error(error.clone()),
- None => WorkflowStepStatus::Resolving,
+ None => WorkflowStepStatus::Resolving {
+ auto_apply: self.auto_apply,
+ },
}
}
}
+#[derive(Clone)]
enum WorkflowStepStatus {
- Resolving,
+ Resolving { auto_apply: bool },
Error(Arc),
Empty,
Idle,
@@ -1455,16 +1478,6 @@ impl WorkflowStepStatus {
.unwrap_or_default()
}
match self {
- WorkflowStepStatus::Resolving => Label::new("Resolving")
- .size(LabelSize::Small)
- .with_animation(
- ("resolving-suggestion-animation", id),
- Animation::new(Duration::from_secs(2))
- .repeat()
- .with_easing(pulsating_between(0.4, 0.8)),
- |label, delta| label.alpha(delta),
- )
- .into_any_element(),
WorkflowStepStatus::Error(error) => Self::render_workflow_step_error(
id,
editor.clone(),
@@ -1477,43 +1490,72 @@ impl WorkflowStepStatus {
step_range.clone(),
"Model was unable to locate the code to edit".to_string(),
),
- WorkflowStepStatus::Idle => Button::new(("transform", id), "Transform")
- .icon(IconName::SparkleAlt)
- .icon_position(IconPosition::Start)
- .icon_size(IconSize::Small)
- .label_size(LabelSize::Small)
- .style(ButtonStyle::Tinted(TintColor::Accent))
- .tooltip({
- let step_range = step_range.clone();
- let editor = editor.clone();
- move |cx| {
- cx.new_view(|cx| {
- let tooltip = Tooltip::new("Transform");
- if display_keybind_in_tooltip(&step_range, &editor, cx) {
- tooltip.key_binding(KeyBinding::for_action_in(
- &Assist,
- &focus_handle,
- cx,
- ))
- } else {
- tooltip
- }
- })
- .into()
- }
- })
- .on_click({
- let editor = editor.clone();
- let step_range = step_range.clone();
- move |_, cx| {
- editor
- .update(cx, |this, cx| {
- this.apply_workflow_step(step_range.clone(), cx)
+ WorkflowStepStatus::Idle | WorkflowStepStatus::Resolving { .. } => {
+ let status = self.clone();
+ Button::new(("transform", id), "Transform")
+ .icon(IconName::SparkleAlt)
+ .icon_position(IconPosition::Start)
+ .icon_size(IconSize::Small)
+ .label_size(LabelSize::Small)
+ .style(ButtonStyle::Tinted(TintColor::Accent))
+ .tooltip({
+ let step_range = step_range.clone();
+ let editor = editor.clone();
+ move |cx| {
+ cx.new_view(|cx| {
+ let tooltip = Tooltip::new("Transform");
+ if display_keybind_in_tooltip(&step_range, &editor, cx) {
+ tooltip.key_binding(KeyBinding::for_action_in(
+ &Assist,
+ &focus_handle,
+ cx,
+ ))
+ } else {
+ tooltip
+ }
})
- .ok();
- }
- })
- .into_any_element(),
+ .into()
+ }
+ })
+ .on_click({
+ let editor = editor.clone();
+ let step_range = step_range.clone();
+ move |_, cx| {
+ if let WorkflowStepStatus::Idle = &status {
+ editor
+ .update(cx, |this, cx| {
+ this.apply_workflow_step(step_range.clone(), cx)
+ })
+ .ok();
+ } else if let WorkflowStepStatus::Resolving { auto_apply: false } =
+ &status
+ {
+ editor
+ .update(cx, |this, _| {
+ if let Some(step) = this.workflow_steps.get_mut(&step_range)
+ {
+ step.auto_apply = true;
+ }
+ })
+ .ok();
+ }
+ }
+ })
+ .map(|this| {
+ if let WorkflowStepStatus::Resolving { auto_apply: true } = &self {
+ this.with_animation(
+ ("resolving-suggestion-animation", id),
+ Animation::new(Duration::from_secs(2))
+ .repeat()
+ .with_easing(pulsating_between(0.4, 0.8)),
+ |label, delta| label.alpha(delta),
+ )
+ .into_any_element()
+ } else {
+ this.into_any_element()
+ }
+ })
+ }
WorkflowStepStatus::Pending => h_flex()
.items_center()
.gap_2()
@@ -1699,6 +1741,8 @@ pub struct ContextEditor {
assistant_panel: WeakView,
error_message: Option,
show_accept_terms: bool,
+ pub(crate) slash_menu_handle:
+ PopoverMenuHandle>,
}
const DEFAULT_TAB_TITLE: &str = "New Context";
@@ -1760,10 +1804,11 @@ impl ContextEditor {
assistant_panel,
error_message: None,
show_accept_terms: false,
+ slash_menu_handle: Default::default(),
};
this.update_message_headers(cx);
this.update_image_blocks(cx);
- this.insert_slash_command_output_sections(sections, cx);
+ this.insert_slash_command_output_sections(sections, false, cx);
this
}
@@ -1776,7 +1821,7 @@ impl ContextEditor {
let command = self.context.update(cx, |context, cx| {
let first_message_id = context.messages(cx).next().unwrap().id;
context.update_metadata(first_message_id, cx, |metadata| {
- metadata.role = Role::System;
+ metadata.role = Role::User;
});
context.reparse_slash_commands(cx);
context.pending_slash_commands()[0].clone()
@@ -1787,6 +1832,7 @@ impl ContextEditor {
&command.name,
&command.arguments,
false,
+ true,
self.workspace.clone(),
cx,
);
@@ -1834,7 +1880,7 @@ impl ContextEditor {
let range = step.range.clone();
match step.status(cx) {
- WorkflowStepStatus::Resolving | WorkflowStepStatus::Pending => true,
+ WorkflowStepStatus::Resolving { .. } | WorkflowStepStatus::Pending => true,
WorkflowStepStatus::Idle => {
self.apply_workflow_step(range, cx);
true
@@ -1988,7 +2034,7 @@ impl ContextEditor {
.collect()
}
- fn insert_command(&mut self, name: &str, cx: &mut ViewContext) {
+ pub fn insert_command(&mut self, name: &str, cx: &mut ViewContext) {
if let Some(command) = SlashCommandRegistry::global(cx).command(name) {
self.editor.update(cx, |editor, cx| {
editor.transact(cx, |editor, cx| {
@@ -2053,6 +2099,7 @@ impl ContextEditor {
&command.name,
&command.arguments,
true,
+ false,
workspace.clone(),
cx,
);
@@ -2061,19 +2108,27 @@ impl ContextEditor {
}
}
+ #[allow(clippy::too_many_arguments)]
pub fn run_command(
&mut self,
command_range: Range,
name: &str,
arguments: &[String],
- insert_trailing_newline: bool,
+ ensure_trailing_newline: bool,
+ expand_result: bool,
workspace: WeakView,
cx: &mut ViewContext,
) {
if let Some(command) = SlashCommandRegistry::global(cx).command(name) {
let output = command.run(arguments, workspace, self.lsp_adapter_delegate.clone(), cx);
self.context.update(cx, |context, cx| {
- context.insert_command_output(command_range, output, insert_trailing_newline, cx)
+ context.insert_command_output(
+ command_range,
+ output,
+ ensure_trailing_newline,
+ expand_result,
+ cx,
+ )
});
}
}
@@ -2159,6 +2214,7 @@ impl ContextEditor {
&command.name,
&command.arguments,
false,
+ false,
workspace.clone(),
cx,
);
@@ -2255,8 +2311,13 @@ impl ContextEditor {
output_range,
sections,
run_commands_in_output,
+ expand_result,
} => {
- self.insert_slash_command_output_sections(sections.iter().cloned(), cx);
+ self.insert_slash_command_output_sections(
+ sections.iter().cloned(),
+ *expand_result,
+ cx,
+ );
if *run_commands_in_output {
let commands = self.context.update(cx, |context, cx| {
@@ -2272,6 +2333,7 @@ impl ContextEditor {
&command.name,
&command.arguments,
false,
+ false,
self.workspace.clone(),
cx,
);
@@ -2288,6 +2350,7 @@ impl ContextEditor {
fn insert_slash_command_output_sections(
&mut self,
sections: impl IntoIterator- >,
+ expand_result: bool,
cx: &mut ViewContext,
) {
self.editor.update(cx, |editor, cx| {
@@ -2342,6 +2405,9 @@ impl ContextEditor {
editor.insert_creases(creases, cx);
+ if expand_result {
+ buffer_rows_to_fold.clear();
+ }
for buffer_row in buffer_rows_to_fold.into_iter().rev() {
editor.fold_at(&FoldAt { buffer_row }, cx);
}
@@ -2517,19 +2583,35 @@ impl ContextEditor {
div().child(step_label)
};
- let step_label = step_label
+ let step_label_element = step_label.into_any_element();
+
+ let step_label = h_flex()
.id("step")
- .cursor(CursorStyle::PointingHand)
- .on_click({
- let this = weak_self.clone();
- let step_range = step_range.clone();
- move |_, cx| {
- this.update(cx, |this, cx| {
- this.open_workflow_step(step_range.clone(), cx);
- })
- .ok();
- }
- });
+ .group("step-label")
+ .items_center()
+ .gap_1()
+ .child(step_label_element)
+ .child(
+ IconButton::new("edit-step", IconName::SearchCode)
+ .size(ButtonSize::Compact)
+ .icon_size(IconSize::Small)
+ .shape(IconButtonShape::Square)
+ .visible_on_hover("step-label")
+ .tooltip(|cx| Tooltip::text("Open Step View", cx))
+ .on_click({
+ let this = weak_self.clone();
+ let step_range = step_range.clone();
+ move |_, cx| {
+ this.update(cx, |this, cx| {
+ this.open_workflow_step(
+ step_range.clone(),
+ cx,
+ );
+ })
+ .ok();
+ }
+ }),
+ );
div()
.w_full()
@@ -2612,11 +2694,17 @@ impl ContextEditor {
footer_block_id: block_ids[1],
resolved_step,
assist: None,
+ auto_apply: false,
},
);
}
self.update_active_workflow_step(cx);
+ if let Some(step) = self.workflow_steps.get_mut(&step_range) {
+ if step.auto_apply && matches!(step.status(cx), WorkflowStepStatus::Idle) {
+ self.apply_workflow_step(step_range, cx);
+ }
+ }
}
fn open_workflow_step(
@@ -2654,14 +2742,25 @@ impl ContextEditor {
fn update_active_workflow_step(&mut self, cx: &mut ViewContext) {
let new_step = self.active_workflow_step_for_cursor(cx);
if new_step.as_ref() != self.active_workflow_step.as_ref() {
+ let mut old_editor = None;
+ let mut old_editor_was_open = None;
if let Some(old_step) = self.active_workflow_step.take() {
- self.hide_workflow_step(old_step.range, cx);
+ (old_editor, old_editor_was_open) =
+ self.hide_workflow_step(old_step.range, cx).unzip();
}
+ let mut new_editor = None;
if let Some(new_step) = new_step {
- self.show_workflow_step(new_step.range.clone(), cx);
+ new_editor = self.show_workflow_step(new_step.range.clone(), cx);
self.active_workflow_step = Some(new_step);
}
+
+ if new_editor != old_editor {
+ if let Some((old_editor, old_editor_was_open)) = old_editor.zip(old_editor_was_open)
+ {
+ self.close_workflow_editor(cx, old_editor, old_editor_was_open)
+ }
+ }
}
}
@@ -2669,15 +2768,15 @@ impl ContextEditor {
&mut self,
step_range: Range,
cx: &mut ViewContext,
- ) {
+ ) -> Option<(View, bool)> {
let Some(step) = self.workflow_steps.get_mut(&step_range) else {
- return;
+ return None;
};
let Some(assist) = step.assist.as_ref() else {
- return;
+ return None;
};
let Some(editor) = assist.editor.upgrade() else {
- return;
+ return None;
};
if matches!(step.status(cx), WorkflowStepStatus::Idle) {
@@ -2687,32 +2786,42 @@ impl ContextEditor {
assistant.finish_assist(assist_id, true, cx)
}
});
-
- self.workspace
- .update(cx, |workspace, cx| {
- if let Some(pane) = workspace.pane_for(&editor) {
- pane.update(cx, |pane, cx| {
- let item_id = editor.entity_id();
- if !assist.editor_was_open && pane.is_active_preview_item(item_id) {
- pane.close_item_by_id(item_id, SaveIntent::Skip, cx)
- .detach_and_log_err(cx);
- }
- });
- }
- })
- .ok();
+ return Some((editor, assist.editor_was_open));
}
+
+ return None;
+ }
+
+ fn close_workflow_editor(
+ &mut self,
+ cx: &mut ViewContext,
+ editor: View,
+ editor_was_open: bool,
+ ) {
+ self.workspace
+ .update(cx, |workspace, cx| {
+ if let Some(pane) = workspace.pane_for(&editor) {
+ pane.update(cx, |pane, cx| {
+ let item_id = editor.entity_id();
+ if !editor_was_open && pane.is_active_preview_item(item_id) {
+ pane.close_item_by_id(item_id, SaveIntent::Skip, cx)
+ .detach_and_log_err(cx);
+ }
+ });
+ }
+ })
+ .ok();
}
fn show_workflow_step(
&mut self,
step_range: Range,
cx: &mut ViewContext,
- ) {
+ ) -> Option> {
let Some(step) = self.workflow_steps.get_mut(&step_range) else {
- return;
+ return None;
};
-
+ let mut editor_to_return = None;
let mut scroll_to_assist_id = None;
match step.status(cx) {
WorkflowStepStatus::Idle => {
@@ -2726,6 +2835,10 @@ impl ContextEditor {
&self.workspace,
cx,
);
+ editor_to_return = step
+ .assist
+ .as_ref()
+ .and_then(|assist| assist.editor.upgrade());
}
}
WorkflowStepStatus::Pending => {
@@ -2747,14 +2860,15 @@ impl ContextEditor {
}
if let Some(assist_id) = scroll_to_assist_id {
- if let Some(editor) = step
+ if let Some(assist_editor) = step
.assist
.as_ref()
.and_then(|assists| assists.editor.upgrade())
{
+ editor_to_return = Some(assist_editor.clone());
self.workspace
.update(cx, |workspace, cx| {
- workspace.activate_item(&editor, false, false, cx);
+ workspace.activate_item(&assist_editor, false, false, cx);
})
.ok();
InlineAssistant::update_global(cx, |assistant, cx| {
@@ -2762,6 +2876,8 @@ impl ContextEditor {
});
}
}
+
+ return editor_to_return;
}
fn open_assists_for_step(
@@ -2805,7 +2921,6 @@ impl ContextEditor {
)
})
.log_err()?;
-
let (&excerpt_id, _, _) = editor
.read(cx)
.buffer()
@@ -3484,7 +3599,8 @@ impl ContextEditor {
};
Some(
h_flex()
- .p_3()
+ .px_3()
+ .py_2()
.border_b_1()
.border_color(cx.theme().colors().border_variant)
.bg(cx.theme().colors().editor_background)
@@ -3520,13 +3636,17 @@ impl ContextEditor {
fn render_send_button(&self, cx: &mut ViewContext) -> impl IntoElement {
let focus_handle = self.focus_handle(cx).clone();
+ let mut should_pulsate = false;
let button_text = match self.active_workflow_step() {
Some(step) => match step.status(cx) {
- WorkflowStepStatus::Resolving => "Resolving Step...",
WorkflowStepStatus::Empty | WorkflowStepStatus::Error(_) => "Retry Step Resolution",
+ WorkflowStepStatus::Resolving { auto_apply } => {
+ should_pulsate = auto_apply;
+ "Transform"
+ }
WorkflowStepStatus::Idle => "Transform",
- WorkflowStepStatus::Pending => "Transforming...",
- WorkflowStepStatus::Done => "Accept Transformation",
+ WorkflowStepStatus::Pending => "Applying...",
+ WorkflowStepStatus::Done => "Accept",
WorkflowStepStatus::Confirmed => "Send",
},
None => "Send",
@@ -3556,12 +3676,12 @@ impl ContextEditor {
let provider = LanguageModelRegistry::read_global(cx).active_provider();
+ let has_configuration_error = configuration_error(cx).is_some();
let needs_to_accept_terms = self.show_accept_terms
&& provider
.as_ref()
.map_or(false, |provider| provider.must_accept_terms(cx));
- let has_active_error = self.error_message.is_some();
- let disabled = needs_to_accept_terms || has_active_error;
+ let disabled = has_configuration_error || needs_to_accept_terms;
ButtonLike::new("send_button")
.disabled(disabled)
@@ -3570,11 +3690,24 @@ impl ContextEditor {
button.tooltip(move |_| tooltip.clone())
})
.layer(ElevationIndex::ModalSurface)
+ .child(Label::new(button_text).map(|this| {
+ if should_pulsate {
+ this.with_animation(
+ "resolving-suggestion-send-button-animation",
+ Animation::new(Duration::from_secs(2))
+ .repeat()
+ .with_easing(pulsating_between(0.4, 0.8)),
+ |label, delta| label.alpha(delta),
+ )
+ .into_any_element()
+ } else {
+ this.into_any_element()
+ }
+ }))
.children(
KeyBinding::for_action_in(&Assist, &focus_handle, cx)
.map(|binding| binding.into_any_element()),
)
- .child(Label::new(button_text))
.on_click(move |_event, cx| {
focus_handle.dispatch_action(&Assist, cx);
})
@@ -3604,7 +3737,13 @@ impl Render for ContextEditor {
} else {
None
};
-
+ let focus_handle = self
+ .workspace
+ .update(cx, |workspace, cx| {
+ Some(workspace.active_item_as::(cx)?.focus_handle(cx))
+ })
+ .ok()
+ .flatten();
v_flex()
.key_context("ContextEditor")
.capture_action(cx.listener(ContextEditor::cancel))
@@ -3627,8 +3766,8 @@ impl Render for ContextEditor {
this.child(
div()
.absolute()
- .right_4()
- .bottom_10()
+ .right_3()
+ .bottom_12()
.max_w_96()
.py_2()
.px_3()
@@ -3642,8 +3781,8 @@ impl Render for ContextEditor {
this.child(
div()
.absolute()
- .right_4()
- .bottom_10()
+ .right_3()
+ .bottom_12()
.max_w_96()
.py_2()
.px_3()
@@ -3654,12 +3793,12 @@ impl Render for ContextEditor {
.gap_0p5()
.child(
h_flex()
- .gap_1()
+ .gap_1p5()
.items_center()
.child(Icon::new(IconName::XCircle).color(Color::Error))
.child(
Label::new("Error interacting with language model")
- .weight(FontWeight::SEMIBOLD),
+ .weight(FontWeight::MEDIUM),
),
)
.child(
@@ -3681,14 +3820,45 @@ impl Render for ContextEditor {
)
})
.child(
- h_flex().flex_none().relative().child(
+ h_flex().w_full().relative().child(
h_flex()
+ .p_2()
.w_full()
- .absolute()
- .right_4()
- .bottom_2()
- .justify_end()
- .child(self.render_send_button(cx)),
+ .border_t_1()
+ .border_color(cx.theme().colors().border_variant)
+ .bg(cx.theme().colors().editor_background)
+ .child(
+ h_flex()
+ .gap_2()
+ .child(render_inject_context_menu(cx.view().downgrade(), cx))
+ .child(
+ IconButton::new("quote-button", IconName::Quote)
+ .icon_size(IconSize::Small)
+ .on_click(|_, cx| {
+ cx.dispatch_action(QuoteSelection.boxed_clone());
+ })
+ .tooltip(move |cx| {
+ cx.new_view(|cx| {
+ Tooltip::new("Insert Selection").key_binding(
+ focus_handle.as_ref().and_then(|handle| {
+ KeyBinding::for_action_in(
+ &QuoteSelection,
+ &handle,
+ cx,
+ )
+ }),
+ )
+ })
+ .into()
+ }),
+ ),
+ )
+ .child(
+ h_flex()
+ .w_full()
+ .justify_end()
+ .child(div().child(self.render_send_button(cx))),
+ ),
),
)
}
@@ -3937,6 +4107,37 @@ pub struct ContextEditorToolbarItem {
model_selector_menu_handle: PopoverMenuHandle>,
}
+fn active_editor_focus_handle(
+ workspace: &WeakView,
+ cx: &WindowContext<'_>,
+) -> Option {
+ workspace.upgrade().and_then(|workspace| {
+ Some(
+ workspace
+ .read(cx)
+ .active_item_as::(cx)?
+ .focus_handle(cx),
+ )
+ })
+}
+
+fn render_inject_context_menu(
+ active_context_editor: WeakView,
+ cx: &mut WindowContext<'_>,
+) -> impl IntoElement {
+ let commands = SlashCommandRegistry::global(cx);
+
+ slash_command_picker::SlashCommandSelector::new(
+ commands.clone(),
+ active_context_editor,
+ IconButton::new("trigger", IconName::SlashSquare)
+ .icon_size(IconSize::Small)
+ .tooltip(|cx| {
+ Tooltip::with_meta("Insert Context", None, "Type / to insert via keyboard", cx)
+ }),
+ )
+}
+
impl ContextEditorToolbarItem {
pub fn new(
workspace: &Workspace,
@@ -3952,70 +4153,6 @@ impl ContextEditorToolbarItem {
}
}
- fn render_inject_context_menu(&self, cx: &mut ViewContext) -> impl Element {
- let commands = SlashCommandRegistry::global(cx);
- let active_editor_focus_handle = self.workspace.upgrade().and_then(|workspace| {
- Some(
- workspace
- .read(cx)
- .active_item_as::(cx)?
- .focus_handle(cx),
- )
- });
- let active_context_editor = self.active_context_editor.clone();
-
- PopoverMenu::new("inject-context-menu")
- .trigger(IconButton::new("trigger", IconName::Quote).tooltip(|cx| {
- Tooltip::with_meta("Insert Context", None, "Type / to insert via keyboard", cx)
- }))
- .menu(move |cx| {
- let active_context_editor = active_context_editor.clone()?;
- ContextMenu::build(cx, |mut menu, _cx| {
- for command_name in commands.featured_command_names() {
- if let Some(command) = commands.command(&command_name) {
- let menu_text = SharedString::from(Arc::from(command.menu_text()));
- menu = menu.custom_entry(
- {
- let command_name = command_name.clone();
- move |_cx| {
- h_flex()
- .gap_4()
- .w_full()
- .justify_between()
- .child(Label::new(menu_text.clone()))
- .child(
- Label::new(format!("/{command_name}"))
- .color(Color::Muted),
- )
- .into_any()
- }
- },
- {
- let active_context_editor = active_context_editor.clone();
- move |cx| {
- active_context_editor
- .update(cx, |context_editor, cx| {
- context_editor.insert_command(&command_name, cx)
- })
- .ok();
- }
- },
- )
- }
- }
-
- if let Some(active_editor_focus_handle) = active_editor_focus_handle.clone() {
- menu = menu
- .context(active_editor_focus_handle)
- .action("Quote Selection", Box::new(QuoteSelection));
- }
-
- menu
- })
- .into()
- })
- }
-
fn render_remaining_tokens(&self, cx: &mut ViewContext) -> Option {
let context = &self
.active_context_editor
@@ -4062,24 +4199,16 @@ impl ContextEditorToolbarItem {
impl Render for ContextEditorToolbarItem {
fn render(&mut self, cx: &mut ViewContext) -> impl IntoElement {
let left_side = h_flex()
+ .pl_1()
.gap_2()
.flex_1()
.min_w(rems(DEFAULT_TAB_TITLE.len() as f32))
.when(self.active_context_editor.is_some(), |left_side| {
- left_side
- .child(
- IconButton::new("regenerate-context", IconName::ArrowCircle)
- .visible_on_hover("toolbar")
- .tooltip(|cx| Tooltip::text("Regenerate Summary", cx))
- .on_click(cx.listener(move |_, _, cx| {
- cx.emit(ContextEditorToolbarItemEvent::RegenerateSummary)
- })),
- )
- .child(self.model_summary_editor.clone())
+ left_side.child(self.model_summary_editor.clone())
});
let active_provider = LanguageModelRegistry::read_global(cx).active_provider();
let active_model = LanguageModelRegistry::read_global(cx).active_model();
-
+ let weak_self = cx.view().downgrade();
let right_side = h_flex()
.gap_2()
.child(
@@ -4100,7 +4229,7 @@ impl Render for ContextEditorToolbarItem {
(Some(provider), Some(model)) => h_flex()
.gap_1()
.child(
- Icon::new(provider.icon())
+ Icon::new(model.icon().unwrap_or_else(|| provider.icon()))
.color(Color::Muted)
.size(IconSize::XSmall),
)
@@ -4129,7 +4258,70 @@ impl Render for ContextEditorToolbarItem {
.with_handle(self.model_selector_menu_handle.clone()),
)
.children(self.render_remaining_tokens(cx))
- .child(self.render_inject_context_menu(cx));
+ .child(
+ PopoverMenu::new("context-editor-popover")
+ .trigger(
+ IconButton::new("context-editor-trigger", IconName::EllipsisVertical)
+ .icon_size(IconSize::Small)
+ .tooltip(|cx| Tooltip::text("Open Context Options", cx)),
+ )
+ .menu({
+ let weak_self = weak_self.clone();
+ move |cx| {
+ let weak_self = weak_self.clone();
+ Some(ContextMenu::build(cx, move |menu, cx| {
+ let context = weak_self
+ .update(cx, |this, cx| {
+ active_editor_focus_handle(&this.workspace, cx)
+ })
+ .ok()
+ .flatten();
+ menu.when_some(context, |menu, context| menu.context(context))
+ .entry("Regenerate Context Title", None, {
+ let weak_self = weak_self.clone();
+ move |cx| {
+ weak_self
+ .update(cx, |_, cx| {
+ cx.emit(ContextEditorToolbarItemEvent::RegenerateSummary)
+ })
+ .ok();
+ }
+ })
+ .custom_entry(
+ |_| {
+ h_flex()
+ .w_full()
+ .justify_between()
+ .gap_2()
+ .child(Label::new("Insert Context"))
+ .child(Label::new("/ command").color(Color::Muted))
+ .into_any()
+ },
+ {
+ let weak_self = weak_self.clone();
+ move |cx| {
+ weak_self
+ .update(cx, |this, cx| {
+ if let Some(editor) =
+ &this.active_context_editor
+ {
+ editor
+ .update(cx, |this, cx| {
+ this.slash_menu_handle
+ .toggle(cx);
+ })
+ .ok();
+ }
+ })
+ .ok();
+ }
+ },
+ )
+ .action("Insert Selection", QuoteSelection.boxed_clone())
+ }))
+ }
+ }),
+ );
h_flex()
.size_full()
@@ -4312,6 +4504,7 @@ impl ConfigurationView {
provider: &Arc,
cx: &mut ViewContext,
) -> Div {
+ let provider_id = provider.id().0.clone();
let provider_name = provider.name().0.clone();
let configuration_view = self.configuration_views.get(&provider.id()).cloned();
@@ -4333,12 +4526,15 @@ impl ConfigurationView {
.when(provider.is_authenticated(cx), move |this| {
this.child(
h_flex().justify_end().child(
- Button::new("new-context", "Open new context")
- .icon_position(IconPosition::Start)
- .icon(IconName::Plus)
- .style(ButtonStyle::Filled)
- .layer(ElevationIndex::ModalSurface)
- .on_click(open_new_context),
+ Button::new(
+ SharedString::from(format!("new-context-{provider_id}")),
+ "Open new context",
+ )
+ .icon_position(IconPosition::Start)
+ .icon(IconName::Plus)
+ .style(ButtonStyle::Filled)
+ .layer(ElevationIndex::ModalSurface)
+ .on_click(open_new_context),
),
)
}),
diff --git a/crates/assistant/src/context.rs b/crates/assistant/src/context.rs
index 5b0209b433671e590cc9d2a9c02d8bf01c414207..f68cdb53eb2e326504218902c251c914eb3e641d 100644
--- a/crates/assistant/src/context.rs
+++ b/crates/assistant/src/context.rs
@@ -295,6 +295,7 @@ pub enum ContextEvent {
output_range: Range,
sections: Vec>,
run_commands_in_output: bool,
+ expand_result: bool,
},
Operation(ContextOperation),
}
@@ -774,6 +775,7 @@ impl Context {
cx.emit(ContextEvent::SlashCommandFinished {
output_range,
sections,
+ expand_result: false,
run_commands_in_output: false,
});
}
@@ -1395,7 +1397,8 @@ impl Context {
&mut self,
command_range: Range,
output: Task>,
- insert_trailing_newline: bool,
+ ensure_trailing_newline: bool,
+ expand_result: bool,
cx: &mut ModelContext,
) {
self.reparse_slash_commands(cx);
@@ -1406,8 +1409,27 @@ impl Context {
let output = output.await;
this.update(&mut cx, |this, cx| match output {
Ok(mut output) => {
- if insert_trailing_newline {
- output.text.push('\n');
+ // Ensure section ranges are valid.
+ for section in &mut output.sections {
+ section.range.start = section.range.start.min(output.text.len());
+ section.range.end = section.range.end.min(output.text.len());
+ while !output.text.is_char_boundary(section.range.start) {
+ section.range.start -= 1;
+ }
+ while !output.text.is_char_boundary(section.range.end) {
+ section.range.end += 1;
+ }
+ }
+
+ // Ensure there is a newline after the last section.
+ if ensure_trailing_newline {
+ let has_newline_after_last_section =
+ output.sections.last().map_or(false, |last_section| {
+ output.text[last_section.range.end..].ends_with('\n')
+ });
+ if !has_newline_after_last_section {
+ output.text.push('\n');
+ }
}
let version = this.version.clone();
@@ -1450,6 +1472,7 @@ impl Context {
output_range,
sections,
run_commands_in_output: output.run_commands_in_text,
+ expand_result,
},
)
});
diff --git a/crates/assistant/src/context/context_tests.rs b/crates/assistant/src/context/context_tests.rs
index 3718c3781266996659b6821cab56e60f63329b8a..4eb7b75a64e54ded5065b6a2bd92bc57d8189fd8 100644
--- a/crates/assistant/src/context/context_tests.rs
+++ b/crates/assistant/src/context/context_tests.rs
@@ -473,7 +473,7 @@ async fn test_slash_commands(cx: &mut TestAppContext) {
}
#[gpui::test]
-async fn test_edit_step_parsing(cx: &mut TestAppContext) {
+async fn test_workflow_step_parsing(cx: &mut TestAppContext) {
cx.update(prompt_library::init);
let settings_store = cx.update(SettingsStore::test);
cx.set_global(settings_store);
@@ -891,6 +891,7 @@ async fn test_random_context_collaboration(cx: &mut TestAppContext, mut rng: Std
run_commands_in_text: false,
})),
true,
+ false,
cx,
);
});
diff --git a/crates/assistant/src/inline_assistant.rs b/crates/assistant/src/inline_assistant.rs
index 533107c1d589b9f8dc95b9a2a3d5ddd72be03219..dbb750f512d08a18c9c0985263eebe26e78c5b50 100644
--- a/crates/assistant/src/inline_assistant.rs
+++ b/crates/assistant/src/inline_assistant.rs
@@ -28,7 +28,7 @@ use gpui::{
FontWeight, Global, HighlightStyle, Model, ModelContext, Subscription, Task, TextStyle,
UpdateGlobal, View, ViewContext, WeakView, WindowContext,
};
-use language::{Buffer, IndentKind, Point, TransactionId};
+use language::{Buffer, IndentKind, Point, Selection, TransactionId};
use language_model::{
LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
};
@@ -38,6 +38,7 @@ use rope::Rope;
use settings::Settings;
use smol::future::FutureExt;
use std::{
+ cmp,
future::{self, Future},
mem,
ops::{Range, RangeInclusive},
@@ -46,7 +47,6 @@ use std::{
task::{self, Poll},
time::{Duration, Instant},
};
-use text::OffsetRangeExt as _;
use theme::ThemeSettings;
use ui::{prelude::*, CheckboxWithLabel, IconButtonShape, Popover, Tooltip};
use util::{RangeExt, ResultExt};
@@ -140,81 +140,66 @@ impl InlineAssistant {
cx: &mut WindowContext,
) {
let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
- struct CodegenRange {
- transform_range: Range,
- selection_ranges: Vec>,
- focus_assist: bool,
- }
- let newest_selection_range = editor.read(cx).selections.newest::(cx).range();
- let mut codegen_ranges: Vec = Vec::new();
-
- let selection_ranges = snapshot
- .split_ranges(editor.read(cx).selections.disjoint_anchor_ranges())
- .map(|range| range.to_point(&snapshot))
- .collect::>>();
-
- for selection_range in selection_ranges {
- let selection_is_newest = newest_selection_range.contains_inclusive(&selection_range);
- let mut transform_range = selection_range.start..selection_range.end;
-
- // Expand the transform range to start/end of lines.
- // If a non-empty selection ends at the start of the last line, clip at the end of the penultimate line.
- transform_range.start.column = 0;
- if transform_range.end.column == 0 && transform_range.end > transform_range.start {
- transform_range.end.row -= 1;
- }
- transform_range.end.column = snapshot.line_len(MultiBufferRow(transform_range.end.row));
- let selection_range =
- selection_range.start..selection_range.end.min(transform_range.end);
-
- // If we intersect the previous transform range,
- if let Some(CodegenRange {
- transform_range: prev_transform_range,
- selection_ranges,
- focus_assist,
- }) = codegen_ranges.last_mut()
- {
- if transform_range.start <= prev_transform_range.end {
- prev_transform_range.end = transform_range.end;
- selection_ranges.push(selection_range);
- *focus_assist |= selection_is_newest;
+ let mut selections = Vec::>::new();
+ let mut newest_selection = None;
+ for mut selection in editor.read(cx).selections.all::(cx) {
+ if selection.end > selection.start {
+ selection.start.column = 0;
+ // If the selection ends at the start of the line, we don't want to include it.
+ if selection.end.column == 0 {
+ selection.end.row -= 1;
+ }
+ selection.end.column = snapshot.line_len(MultiBufferRow(selection.end.row));
+ }
+
+ if let Some(prev_selection) = selections.last_mut() {
+ if selection.start <= prev_selection.end {
+ prev_selection.end = selection.end;
continue;
}
}
- codegen_ranges.push(CodegenRange {
- transform_range,
- selection_ranges: vec![selection_range],
- focus_assist: selection_is_newest,
- })
+ let latest_selection = newest_selection.get_or_insert_with(|| selection.clone());
+ if selection.id > latest_selection.id {
+ *latest_selection = selection.clone();
+ }
+ selections.push(selection);
+ }
+ let newest_selection = newest_selection.unwrap();
+
+ let mut codegen_ranges = Vec::new();
+ for (excerpt_id, buffer, buffer_range) in
+ snapshot.excerpts_in_ranges(selections.iter().map(|selection| {
+ snapshot.anchor_before(selection.start)..snapshot.anchor_after(selection.end)
+ }))
+ {
+ let start = Anchor {
+ buffer_id: Some(buffer.remote_id()),
+ excerpt_id,
+ text_anchor: buffer.anchor_before(buffer_range.start),
+ };
+ let end = Anchor {
+ buffer_id: Some(buffer.remote_id()),
+ excerpt_id,
+ text_anchor: buffer.anchor_after(buffer_range.end),
+ };
+ codegen_ranges.push(start..end);
}
let assist_group_id = self.next_assist_group_id.post_inc();
let prompt_buffer =
cx.new_model(|cx| Buffer::local(initial_prompt.unwrap_or_default(), cx));
let prompt_buffer = cx.new_model(|cx| MultiBuffer::singleton(prompt_buffer, cx));
+
let mut assists = Vec::new();
let mut assist_to_focus = None;
-
- for CodegenRange {
- transform_range,
- selection_ranges,
- focus_assist,
- } in codegen_ranges
- {
- let transform_range = snapshot.anchor_before(transform_range.start)
- ..snapshot.anchor_after(transform_range.end);
- let selection_ranges = selection_ranges
- .iter()
- .map(|range| snapshot.anchor_before(range.start)..snapshot.anchor_after(range.end))
- .collect::>();
-
+ for range in codegen_ranges {
+ let assist_id = self.next_assist_id.post_inc();
let codegen = cx.new_model(|cx| {
Codegen::new(
editor.read(cx).buffer().clone(),
- transform_range.clone(),
- selection_ranges,
+ range.clone(),
None,
self.telemetry.clone(),
self.prompt_builder.clone(),
@@ -222,7 +207,6 @@ impl InlineAssistant {
)
});
- let assist_id = self.next_assist_id.post_inc();
let gutter_dimensions = Arc::new(Mutex::new(GutterDimensions::default()));
let prompt_editor = cx.new_view(|cx| {
PromptEditor::new(
@@ -239,16 +223,23 @@ impl InlineAssistant {
)
});
- if focus_assist {
- assist_to_focus = Some(assist_id);
+ if assist_to_focus.is_none() {
+ let focus_assist = if newest_selection.reversed {
+ range.start.to_point(&snapshot) == newest_selection.start
+ } else {
+ range.end.to_point(&snapshot) == newest_selection.end
+ };
+ if focus_assist {
+ assist_to_focus = Some(assist_id);
+ }
}
let [prompt_block_id, end_block_id] =
- self.insert_assist_blocks(editor, &transform_range, &prompt_editor, cx);
+ self.insert_assist_blocks(editor, &range, &prompt_editor, cx);
assists.push((
assist_id,
- transform_range,
+ range,
prompt_editor,
prompt_block_id,
end_block_id,
@@ -315,7 +306,6 @@ impl InlineAssistant {
Codegen::new(
editor.read(cx).buffer().clone(),
range.clone(),
- vec![range.clone()],
initial_transaction_id,
self.telemetry.clone(),
self.prompt_builder.clone(),
@@ -925,7 +915,12 @@ impl InlineAssistant {
assist
.codegen
.update(cx, |codegen, cx| {
- codegen.start(user_prompt, assistant_panel_context, cx)
+ codegen.start(
+ assist.range.clone(),
+ user_prompt,
+ assistant_panel_context,
+ cx,
+ )
})
.log_err();
@@ -2120,9 +2115,12 @@ impl InlineAssist {
return future::ready(Err(anyhow!("no user prompt"))).boxed();
};
let assistant_panel_context = self.assistant_panel_context(cx);
- self.codegen
- .read(cx)
- .count_tokens(user_prompt, assistant_panel_context, cx)
+ self.codegen.read(cx).count_tokens(
+ self.range.clone(),
+ user_prompt,
+ assistant_panel_context,
+ cx,
+ )
}
}
@@ -2143,8 +2141,6 @@ pub struct Codegen {
buffer: Model,
old_buffer: Model,
snapshot: MultiBufferSnapshot,
- transform_range: Range,
- selected_ranges: Vec>,
edit_position: Option,
last_equal_ranges: Vec>,
initial_transaction_id: Option,
@@ -2154,7 +2150,7 @@ pub struct Codegen {
diff: Diff,
telemetry: Option>,
_subscription: gpui::Subscription,
- prompt_builder: Arc,
+ builder: Arc,
}
enum CodegenStatus {
@@ -2181,8 +2177,7 @@ impl EventEmitter for Codegen {}
impl Codegen {
pub fn new(
buffer: Model,
- transform_range: Range,
- selected_ranges: Vec>,
+ range: Range,
initial_transaction_id: Option,
telemetry: Option>,
builder: Arc,
@@ -2192,7 +2187,7 @@ impl Codegen {
let (old_buffer, _, _) = buffer
.read(cx)
- .range_to_buffer_ranges(transform_range.clone(), cx)
+ .range_to_buffer_ranges(range.clone(), cx)
.pop()
.unwrap();
let old_buffer = cx.new_model(|cx| {
@@ -2223,9 +2218,7 @@ impl Codegen {
telemetry,
_subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
initial_transaction_id,
- prompt_builder: builder,
- transform_range,
- selected_ranges,
+ builder,
}
}
@@ -2250,12 +2243,14 @@ impl Codegen {
pub fn count_tokens(
&self,
+ edit_range: Range,
user_prompt: String,
assistant_panel_context: Option,
cx: &AppContext,
) -> BoxFuture<'static, Result> {
if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
- let request = self.build_request(user_prompt, assistant_panel_context.clone(), cx);
+ let request =
+ self.build_request(user_prompt, assistant_panel_context.clone(), edit_range, cx);
match request {
Ok(request) => {
let total_count = model.count_tokens(request.clone(), cx);
@@ -2280,6 +2275,7 @@ impl Codegen {
pub fn start(
&mut self,
+ edit_range: Range,
user_prompt: String,
assistant_panel_context: Option,
cx: &mut ModelContext,
@@ -2294,20 +2290,24 @@ impl Codegen {
});
}
- self.edit_position = Some(self.transform_range.start.bias_right(&self.snapshot));
+ self.edit_position = Some(edit_range.start.bias_right(&self.snapshot));
let telemetry_id = model.telemetry_id();
- let chunks: LocalBoxFuture>>> =
- if user_prompt.trim().to_lowercase() == "delete" {
- async { Ok(stream::empty().boxed()) }.boxed_local()
- } else {
- let request = self.build_request(user_prompt, assistant_panel_context, cx)?;
+ let chunks: LocalBoxFuture>>> = if user_prompt
+ .trim()
+ .to_lowercase()
+ == "delete"
+ {
+ async { Ok(stream::empty().boxed()) }.boxed_local()
+ } else {
+ let request =
+ self.build_request(user_prompt, assistant_panel_context, edit_range.clone(), cx)?;
- let chunks =
- cx.spawn(|_, cx| async move { model.stream_completion(request, &cx).await });
- async move { Ok(chunks.await?.boxed()) }.boxed_local()
- };
- self.handle_stream(telemetry_id, self.transform_range.clone(), chunks, cx);
+ let chunks =
+ cx.spawn(|_, cx| async move { model.stream_completion(request, &cx).await });
+ async move { Ok(chunks.await?.boxed()) }.boxed_local()
+ };
+ self.handle_stream(telemetry_id, edit_range, chunks, cx);
Ok(())
}
@@ -2315,10 +2315,11 @@ impl Codegen {
&self,
user_prompt: String,
assistant_panel_context: Option,
+ edit_range: Range,
cx: &AppContext,
) -> Result {
let buffer = self.buffer.read(cx).snapshot(cx);
- let language = buffer.language_at(self.transform_range.start);
+ let language = buffer.language_at(edit_range.start);
let language_name = if let Some(language) = language.as_ref() {
if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
None
@@ -2343,9 +2344,9 @@ impl Codegen {
};
let language_name = language_name.as_deref();
- let start = buffer.point_to_buffer_offset(self.transform_range.start);
- let end = buffer.point_to_buffer_offset(self.transform_range.end);
- let (transform_buffer, transform_range) = if let Some((start, end)) = start.zip(end) {
+ let start = buffer.point_to_buffer_offset(edit_range.start);
+ let end = buffer.point_to_buffer_offset(edit_range.end);
+ let (buffer, range) = if let Some((start, end)) = start.zip(end) {
let (start_buffer, start_buffer_offset) = start;
let (end_buffer, end_buffer_offset) = end;
if start_buffer.remote_id() == end_buffer.remote_id() {
@@ -2357,39 +2358,9 @@ impl Codegen {
return Err(anyhow::anyhow!("invalid transformation range"));
};
- let mut transform_context_range = transform_range.to_point(&transform_buffer);
- transform_context_range.start.row = transform_context_range.start.row.saturating_sub(3);
- transform_context_range.start.column = 0;
- transform_context_range.end =
- (transform_context_range.end + Point::new(3, 0)).min(transform_buffer.max_point());
- transform_context_range.end.column =
- transform_buffer.line_len(transform_context_range.end.row);
- let transform_context_range = transform_context_range.to_offset(&transform_buffer);
-
- let selected_ranges = self
- .selected_ranges
- .iter()
- .filter_map(|selected_range| {
- let start = buffer
- .point_to_buffer_offset(selected_range.start)
- .map(|(_, offset)| offset)?;
- let end = buffer
- .point_to_buffer_offset(selected_range.end)
- .map(|(_, offset)| offset)?;
- Some(start..end)
- })
- .collect::>();
-
let prompt = self
- .prompt_builder
- .generate_content_prompt(
- user_prompt,
- language_name,
- transform_buffer,
- transform_range,
- selected_ranges,
- transform_context_range,
- )
+ .builder
+ .generate_content_prompt(user_prompt, language_name, buffer, range)
.map_err(|e| anyhow::anyhow!("Failed to generate content prompt: {}", e))?;
let mut messages = Vec::new();
@@ -2462,19 +2433,84 @@ impl Codegen {
let mut diff = StreamingDiff::new(selected_text.to_string());
let mut line_diff = LineDiff::default();
+ let mut new_text = String::new();
+ let mut base_indent = None;
+ let mut line_indent = None;
+ let mut first_line = true;
+
while let Some(chunk) = chunks.next().await {
if response_latency.is_none() {
response_latency = Some(request_start.elapsed());
}
let chunk = chunk?;
- let char_ops = diff.push_new(&chunk);
- line_diff.push_char_operations(&char_ops, &selected_text);
- diff_tx
- .send((char_ops, line_diff.line_operations()))
- .await?;
+
+ let mut lines = chunk.split('\n').peekable();
+ while let Some(line) = lines.next() {
+ new_text.push_str(line);
+ if line_indent.is_none() {
+ if let Some(non_whitespace_ch_ix) =
+ new_text.find(|ch: char| !ch.is_whitespace())
+ {
+ line_indent = Some(non_whitespace_ch_ix);
+ base_indent = base_indent.or(line_indent);
+
+ let line_indent = line_indent.unwrap();
+ let base_indent = base_indent.unwrap();
+ let indent_delta =
+ line_indent as i32 - base_indent as i32;
+ let mut corrected_indent_len = cmp::max(
+ 0,
+ suggested_line_indent.len as i32 + indent_delta,
+ )
+ as usize;
+ if first_line {
+ corrected_indent_len = corrected_indent_len
+ .saturating_sub(
+ selection_start.column as usize,
+ );
+ }
+
+ let indent_char = suggested_line_indent.char();
+ let mut indent_buffer = [0; 4];
+ let indent_str =
+ indent_char.encode_utf8(&mut indent_buffer);
+ new_text.replace_range(
+ ..line_indent,
+ &indent_str.repeat(corrected_indent_len),
+ );
+ }
+ }
+
+ if line_indent.is_some() {
+ let char_ops = diff.push_new(&new_text);
+ line_diff
+ .push_char_operations(&char_ops, &selected_text);
+ diff_tx
+ .send((char_ops, line_diff.line_operations()))
+ .await?;
+ new_text.clear();
+ }
+
+ if lines.peek().is_some() {
+ let char_ops = diff.push_new("\n");
+ line_diff
+ .push_char_operations(&char_ops, &selected_text);
+ diff_tx
+ .send((char_ops, line_diff.line_operations()))
+ .await?;
+ if line_indent.is_none() {
+ // Don't write out the leading indentation in empty lines on the next line
+ // This is the case where the above if statement didn't clear the buffer
+ new_text.clear();
+ }
+ line_indent = None;
+ first_line = false;
+ }
+ }
}
- let char_ops = diff.finish();
+ let mut char_ops = diff.push_new(&new_text);
+ char_ops.extend(diff.finish());
line_diff.push_char_operations(&char_ops, &selected_text);
line_diff.finish(&selected_text);
diff_tx
@@ -2938,13 +2974,311 @@ fn merge_ranges(ranges: &mut Vec>, buffer: &MultiBufferSnapshot) {
mod tests {
use super::*;
use futures::stream::{self};
+ use gpui::{Context, TestAppContext};
+ use indoc::indoc;
+ use language::{
+ language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, LanguageMatcher,
+ Point,
+ };
+ use language_model::LanguageModelRegistry;
+ use rand::prelude::*;
use serde::Serialize;
+ use settings::SettingsStore;
+ use std::{future, sync::Arc};
#[derive(Serialize)]
pub struct DummyCompletionRequest {
pub name: String,
}
+ #[gpui::test(iterations = 10)]
+ async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
+ cx.set_global(cx.update(SettingsStore::test));
+ cx.update(language_model::LanguageModelRegistry::test);
+ cx.update(language_settings::init);
+
+ let text = indoc! {"
+ fn main() {
+ let x = 0;
+ for _ in 0..10 {
+ x += 1;
+ }
+ }
+ "};
+ let buffer =
+ cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
+ let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
+ let range = buffer.read_with(cx, |buffer, cx| {
+ let snapshot = buffer.snapshot(cx);
+ snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
+ });
+ let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
+ let codegen = cx.new_model(|cx| {
+ Codegen::new(
+ buffer.clone(),
+ range.clone(),
+ None,
+ None,
+ prompt_builder,
+ cx,
+ )
+ });
+
+ let (chunks_tx, chunks_rx) = mpsc::unbounded();
+ codegen.update(cx, |codegen, cx| {
+ codegen.handle_stream(
+ String::new(),
+ range,
+ future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
+ cx,
+ )
+ });
+
+ let mut new_text = concat!(
+ " let mut x = 0;\n",
+ " while x < 10 {\n",
+ " x += 1;\n",
+ " }",
+ );
+ while !new_text.is_empty() {
+ let max_len = cmp::min(new_text.len(), 10);
+ let len = rng.gen_range(1..=max_len);
+ let (chunk, suffix) = new_text.split_at(len);
+ chunks_tx.unbounded_send(chunk.to_string()).unwrap();
+ new_text = suffix;
+ cx.background_executor.run_until_parked();
+ }
+ drop(chunks_tx);
+ cx.background_executor.run_until_parked();
+
+ assert_eq!(
+ buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
+ indoc! {"
+ fn main() {
+ let mut x = 0;
+ while x < 10 {
+ x += 1;
+ }
+ }
+ "}
+ );
+ }
+
+ #[gpui::test(iterations = 10)]
+ async fn test_autoindent_when_generating_past_indentation(
+ cx: &mut TestAppContext,
+ mut rng: StdRng,
+ ) {
+ cx.set_global(cx.update(SettingsStore::test));
+ cx.update(language_settings::init);
+
+ let text = indoc! {"
+ fn main() {
+ le
+ }
+ "};
+ let buffer =
+ cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
+ let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
+ let range = buffer.read_with(cx, |buffer, cx| {
+ let snapshot = buffer.snapshot(cx);
+ snapshot.anchor_before(Point::new(1, 6))..snapshot.anchor_after(Point::new(1, 6))
+ });
+ let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
+ let codegen = cx.new_model(|cx| {
+ Codegen::new(
+ buffer.clone(),
+ range.clone(),
+ None,
+ None,
+ prompt_builder,
+ cx,
+ )
+ });
+
+ let (chunks_tx, chunks_rx) = mpsc::unbounded();
+ codegen.update(cx, |codegen, cx| {
+ codegen.handle_stream(
+ String::new(),
+ range.clone(),
+ future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
+ cx,
+ )
+ });
+
+ cx.background_executor.run_until_parked();
+
+ let mut new_text = concat!(
+ "t mut x = 0;\n",
+ "while x < 10 {\n",
+ " x += 1;\n",
+ "}", //
+ );
+ while !new_text.is_empty() {
+ let max_len = cmp::min(new_text.len(), 10);
+ let len = rng.gen_range(1..=max_len);
+ let (chunk, suffix) = new_text.split_at(len);
+ chunks_tx.unbounded_send(chunk.to_string()).unwrap();
+ new_text = suffix;
+ cx.background_executor.run_until_parked();
+ }
+ drop(chunks_tx);
+ cx.background_executor.run_until_parked();
+
+ assert_eq!(
+ buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
+ indoc! {"
+ fn main() {
+ let mut x = 0;
+ while x < 10 {
+ x += 1;
+ }
+ }
+ "}
+ );
+ }
+
+ #[gpui::test(iterations = 10)]
+ async fn test_autoindent_when_generating_before_indentation(
+ cx: &mut TestAppContext,
+ mut rng: StdRng,
+ ) {
+ cx.update(LanguageModelRegistry::test);
+ cx.set_global(cx.update(SettingsStore::test));
+ cx.update(language_settings::init);
+
+ let text = concat!(
+ "fn main() {\n",
+ " \n",
+ "}\n" //
+ );
+ let buffer =
+ cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
+ let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
+ let range = buffer.read_with(cx, |buffer, cx| {
+ let snapshot = buffer.snapshot(cx);
+ snapshot.anchor_before(Point::new(1, 2))..snapshot.anchor_after(Point::new(1, 2))
+ });
+ let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
+ let codegen = cx.new_model(|cx| {
+ Codegen::new(
+ buffer.clone(),
+ range.clone(),
+ None,
+ None,
+ prompt_builder,
+ cx,
+ )
+ });
+
+ let (chunks_tx, chunks_rx) = mpsc::unbounded();
+ codegen.update(cx, |codegen, cx| {
+ codegen.handle_stream(
+ String::new(),
+ range.clone(),
+ future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
+ cx,
+ )
+ });
+
+ cx.background_executor.run_until_parked();
+
+ let mut new_text = concat!(
+ "let mut x = 0;\n",
+ "while x < 10 {\n",
+ " x += 1;\n",
+ "}", //
+ );
+ while !new_text.is_empty() {
+ let max_len = cmp::min(new_text.len(), 10);
+ let len = rng.gen_range(1..=max_len);
+ let (chunk, suffix) = new_text.split_at(len);
+ chunks_tx.unbounded_send(chunk.to_string()).unwrap();
+ new_text = suffix;
+ cx.background_executor.run_until_parked();
+ }
+ drop(chunks_tx);
+ cx.background_executor.run_until_parked();
+
+ assert_eq!(
+ buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
+ indoc! {"
+ fn main() {
+ let mut x = 0;
+ while x < 10 {
+ x += 1;
+ }
+ }
+ "}
+ );
+ }
+
+ #[gpui::test(iterations = 10)]
+ async fn test_autoindent_respects_tabs_in_selection(cx: &mut TestAppContext) {
+ cx.update(LanguageModelRegistry::test);
+ cx.set_global(cx.update(SettingsStore::test));
+ cx.update(language_settings::init);
+
+ let text = indoc! {"
+ func main() {
+ \tx := 0
+ \tfor i := 0; i < 10; i++ {
+ \t\tx++
+ \t}
+ }
+ "};
+ let buffer = cx.new_model(|cx| Buffer::local(text, cx));
+ let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
+ let range = buffer.read_with(cx, |buffer, cx| {
+ let snapshot = buffer.snapshot(cx);
+ snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(4, 2))
+ });
+ let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
+ let codegen = cx.new_model(|cx| {
+ Codegen::new(
+ buffer.clone(),
+ range.clone(),
+ None,
+ None,
+ prompt_builder,
+ cx,
+ )
+ });
+
+ let (chunks_tx, chunks_rx) = mpsc::unbounded();
+ codegen.update(cx, |codegen, cx| {
+ codegen.handle_stream(
+ String::new(),
+ range.clone(),
+ future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
+ cx,
+ )
+ });
+
+ let new_text = concat!(
+ "func main() {\n",
+ "\tx := 0\n",
+ "\tfor x < 10 {\n",
+ "\t\tx++\n",
+ "\t}", //
+ );
+ chunks_tx.unbounded_send(new_text.to_string()).unwrap();
+ drop(chunks_tx);
+ cx.background_executor.run_until_parked();
+
+ assert_eq!(
+ buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
+ indoc! {"
+ func main() {
+ \tx := 0
+ \tfor x < 10 {
+ \t\tx++
+ \t}
+ }
+ "}
+ );
+ }
+
#[gpui::test]
async fn test_strip_invalid_spans_from_codeblock() {
assert_chunks("Lorem ipsum dolor", "Lorem ipsum dolor").await;
@@ -2984,4 +3318,27 @@ mod tests {
)
}
}
+
+ fn rust_lang() -> Language {
+ Language::new(
+ LanguageConfig {
+ name: "Rust".into(),
+ matcher: LanguageMatcher {
+ path_suffixes: vec!["rs".to_string()],
+ ..Default::default()
+ },
+ ..Default::default()
+ },
+ Some(tree_sitter_rust::language()),
+ )
+ .with_indents_query(
+ r#"
+ (call_expression) @indent
+ (field_expression) @indent
+ (_ "(" ")" @end) @indent
+ (_ "{" "}" @end) @indent
+ "#,
+ )
+ .unwrap()
+ }
}
diff --git a/crates/assistant/src/model_selector.rs b/crates/assistant/src/model_selector.rs
index 957267c5dc4e6115d512af6d60df3ebbb97e719b..514bb3ee870fe2814f03364b990e81455c70a314 100644
--- a/crates/assistant/src/model_selector.rs
+++ b/crates/assistant/src/model_selector.rs
@@ -1,15 +1,15 @@
use feature_flags::ZedPro;
+use gpui::Action;
use gpui::DismissEvent;
use language_model::{LanguageModel, LanguageModelAvailability, LanguageModelRegistry};
use proto::Plan;
+use workspace::ShowConfiguration;
use std::sync::Arc;
use ui::ListItemSpacing;
use crate::assistant_settings::AssistantSettings;
-use crate::ShowConfiguration;
use fs::Fs;
-use gpui::Action;
use gpui::SharedString;
use gpui::Task;
use picker::{Picker, PickerDelegate};
@@ -36,7 +36,7 @@ pub struct ModelPickerDelegate {
#[derive(Clone)]
struct ModelInfo {
model: Arc,
- provider_icon: IconName,
+ icon: IconName,
availability: LanguageModelAvailability,
is_selected: bool,
}
@@ -156,7 +156,7 @@ impl PickerDelegate for ModelPickerDelegate {
.selected(selected)
.start_slot(
div().pr_1().child(
- Icon::new(model_info.provider_icon)
+ Icon::new(model_info.icon)
.color(Color::Muted)
.size(IconSize::Medium),
),
@@ -261,16 +261,17 @@ impl RenderOnce for ModelSelector {
.iter()
.flat_map(|provider| {
let provider_id = provider.id();
- let provider_icon = provider.icon();
+ let icon = provider.icon();
let selected_model = selected_model.clone();
let selected_provider = selected_provider.clone();
provider.provided_models(cx).into_iter().map(move |model| {
let model = model.clone();
+ let icon = model.icon().unwrap_or(icon);
ModelInfo {
model: model.clone(),
- provider_icon,
+ icon,
availability: model.availability(),
is_selected: selected_model.as_ref() == Some(&model.id())
&& selected_provider.as_ref() == Some(&provider_id),
diff --git a/crates/assistant/src/prompts.rs b/crates/assistant/src/prompts.rs
index ed3324f54f7ff310a7adc02e985c740e54fc8695..462f74f5d3e834f1653e344690a93ec1ddf41b4a 100644
--- a/crates/assistant/src/prompts.rs
+++ b/crates/assistant/src/prompts.rs
@@ -1,26 +1,24 @@
+use anyhow::Result;
use assets::Assets;
use fs::Fs;
use futures::StreamExt;
-use handlebars::{Handlebars, RenderError, TemplateError};
+use gpui::AssetSource;
+use handlebars::{Handlebars, RenderError};
use language::BufferSnapshot;
use parking_lot::Mutex;
use serde::Serialize;
-use std::{ops::Range, sync::Arc, time::Duration};
+use std::{ops::Range, path::PathBuf, sync::Arc, time::Duration};
use util::ResultExt;
#[derive(Serialize)]
pub struct ContentPromptContext {
pub content_type: String,
pub language_name: Option,
+ pub is_insert: bool,
pub is_truncated: bool,
pub document_content: String,
pub user_prompt: String,
- pub rewrite_section: String,
- pub rewrite_section_prefix: String,
- pub rewrite_section_suffix: String,
- pub rewrite_section_with_edits: String,
- pub has_insertion: bool,
- pub has_replacement: bool,
+ pub rewrite_section: Option,
}
#[derive(Serialize)]
@@ -42,128 +40,162 @@ pub struct StepResolutionContext {
pub step_to_resolve: String,
}
-pub struct PromptBuilder {
- handlebars: Arc>>,
+pub struct PromptLoadingParams<'a> {
+ pub fs: Arc,
+ pub repo_path: Option,
+ pub cx: &'a gpui::AppContext,
}
-pub struct PromptOverrideContext<'a> {
- pub dev_mode: bool,
- pub fs: Arc,
- pub cx: &'a mut gpui::AppContext,
+pub struct PromptBuilder {
+ handlebars: Arc>>,
}
impl PromptBuilder {
- pub fn new(override_cx: Option) -> Result> {
+ pub fn new(loading_params: Option) -> Result {
let mut handlebars = Handlebars::new();
- Self::register_templates(&mut handlebars)?;
+ Self::register_built_in_templates(&mut handlebars)?;
let handlebars = Arc::new(Mutex::new(handlebars));
- if let Some(override_cx) = override_cx {
- Self::watch_fs_for_template_overrides(override_cx, handlebars.clone());
+ if let Some(params) = loading_params {
+ Self::watch_fs_for_template_overrides(params, handlebars.clone());
}
Ok(Self { handlebars })
}
+ /// Watches the filesystem for changes to prompt template overrides.
+ ///
+ /// This function sets up a file watcher on the prompt templates directory. It performs
+ /// an initial scan of the directory and registers any existing template overrides.
+ /// Then it continuously monitors for changes, reloading templates as they are
+ /// modified or added.
+ ///
+ /// If the templates directory doesn't exist initially, it waits for it to be created.
+ /// If the directory is removed, it restores the built-in templates and waits for the
+ /// directory to be recreated.
+ ///
+ /// # Arguments
+ ///
+ /// * `params` - A `PromptLoadingParams` struct containing the filesystem, repository path,
+ /// and application context.
+ /// * `handlebars` - An `Arc>` for registering and updating templates.
fn watch_fs_for_template_overrides(
- PromptOverrideContext { dev_mode, fs, cx }: PromptOverrideContext,
+ mut params: PromptLoadingParams,
handlebars: Arc>>,
) {
- cx.background_executor()
+ params.repo_path = None;
+ let templates_dir = paths::prompt_overrides_dir(params.repo_path.as_deref());
+ params.cx.background_executor()
.spawn(async move {
- let templates_dir = if dev_mode {
- std::env::current_dir()
- .ok()
- .and_then(|pwd| {
- let pwd_assets_prompts = pwd.join("assets").join("prompts");
- pwd_assets_prompts.exists().then_some(pwd_assets_prompts)
- })
- .unwrap_or_else(|| paths::prompt_overrides_dir().clone())
- } else {
- paths::prompt_overrides_dir().clone()
+ let Some(parent_dir) = templates_dir.parent() else {
+ return;
};
- // Create the prompt templates directory if it doesn't exist
- if !fs.is_dir(&templates_dir).await {
- if let Err(e) = fs.create_dir(&templates_dir).await {
- log::error!("Failed to create prompt templates directory: {}", e);
- return;
+ let mut found_dir_once = false;
+ loop {
+ // Check if the templates directory exists and handle its status
+ // If it exists, log its presence and check if it's a symlink
+ // If it doesn't exist:
+ // - Log that we're using built-in prompts
+ // - Check if it's a broken symlink and log if so
+ // - Set up a watcher to detect when it's created
+ // After the first check, set the `found_dir_once` flag
+ // This allows us to avoid logging when looping back around after deleting the prompt overrides directory.
+ let dir_status = params.fs.is_dir(&templates_dir).await;
+ let symlink_status = params.fs.read_link(&templates_dir).await.ok();
+ if dir_status {
+ let mut log_message = format!("Prompt template overrides directory found at {}", templates_dir.display());
+ if let Some(target) = symlink_status {
+ log_message.push_str(" -> ");
+ log_message.push_str(&target.display().to_string());
+ }
+ log::info!("{}.", log_message);
+ } else {
+ if !found_dir_once {
+ log::info!("No prompt template overrides directory found at {}. Using built-in prompts.", templates_dir.display());
+ if let Some(target) = symlink_status {
+ log::info!("Symlink found pointing to {}, but target is invalid.", target.display());
+ }
+ }
+
+ if params.fs.is_dir(parent_dir).await {
+ let (mut changes, _watcher) = params.fs.watch(parent_dir, Duration::from_secs(1)).await;
+ while let Some(changed_paths) = changes.next().await {
+ if changed_paths.iter().any(|p| p == &templates_dir) {
+ let mut log_message = format!("Prompt template overrides directory detected at {}", templates_dir.display());
+ if let Ok(target) = params.fs.read_link(&templates_dir).await {
+ log_message.push_str(" -> ");
+ log_message.push_str(&target.display().to_string());
+ }
+ log::info!("{}.", log_message);
+ break;
+ }
+ }
+ } else {
+ return;
+ }
}
- }
- // Initial scan of the prompts directory
- if let Ok(mut entries) = fs.read_dir(&templates_dir).await {
- while let Some(Ok(file_path)) = entries.next().await {
- if file_path.to_string_lossy().ends_with(".hbs") {
- if let Ok(content) = fs.load(&file_path).await {
- let file_name = file_path.file_stem().unwrap().to_string_lossy();
+ found_dir_once = true;
- match handlebars.lock().register_template_string(&file_name, content) {
- Ok(_) => {
- log::info!(
- "Successfully registered template override: {} ({})",
- file_name,
- file_path.display()
- );
- },
- Err(e) => {
- log::error!(
- "Failed to register template during initial scan: {} ({})",
- e,
- file_path.display()
- );
- },
+ // Initial scan of the prompt overrides directory
+ if let Ok(mut entries) = params.fs.read_dir(&templates_dir).await {
+ while let Some(Ok(file_path)) = entries.next().await {
+ if file_path.to_string_lossy().ends_with(".hbs") {
+ if let Ok(content) = params.fs.load(&file_path).await {
+ let file_name = file_path.file_stem().unwrap().to_string_lossy();
+ log::info!("Registering prompt template override: {}", file_name);
+ handlebars.lock().register_template_string(&file_name, content).log_err();
}
}
}
}
- }
- // Watch for changes
- let (mut changes, watcher) = fs.watch(&templates_dir, Duration::from_secs(1)).await;
- while let Some(changed_paths) = changes.next().await {
- for changed_path in changed_paths {
- if changed_path.extension().map_or(false, |ext| ext == "hbs") {
- log::info!("Reloading template: {}", changed_path.display());
- if let Some(content) = fs.load(&changed_path).await.log_err() {
- let file_name = changed_path.file_stem().unwrap().to_string_lossy();
- let file_path = changed_path.to_string_lossy();
- match handlebars.lock().register_template_string(&file_name, content) {
- Ok(_) => log::info!(
- "Successfully reloaded template: {} ({})",
- file_name,
- file_path
- ),
- Err(e) => log::error!(
- "Failed to register template: {} ({})",
- e,
- file_path
- ),
+ // Watch both the parent directory and the template overrides directory:
+ // - Monitor the parent directory to detect if the template overrides directory is deleted.
+ // - Monitor the template overrides directory to re-register templates when they change.
+ // Combine both watch streams into a single stream.
+ let (parent_changes, parent_watcher) = params.fs.watch(parent_dir, Duration::from_secs(1)).await;
+ let (changes, watcher) = params.fs.watch(&templates_dir, Duration::from_secs(1)).await;
+ let mut combined_changes = futures::stream::select(changes, parent_changes);
+
+ while let Some(changed_paths) = combined_changes.next().await {
+ if changed_paths.iter().any(|p| p == &templates_dir) {
+ if !params.fs.is_dir(&templates_dir).await {
+ log::info!("Prompt template overrides directory removed. Restoring built-in prompt templates.");
+ Self::register_built_in_templates(&mut handlebars.lock()).log_err();
+ break;
+ }
+ }
+ for changed_path in changed_paths {
+ if changed_path.starts_with(&templates_dir) && changed_path.extension().map_or(false, |ext| ext == "hbs") {
+ log::info!("Reloading prompt template override: {}", changed_path.display());
+ if let Some(content) = params.fs.load(&changed_path).await.log_err() {
+ let file_name = changed_path.file_stem().unwrap().to_string_lossy();
+ handlebars.lock().register_template_string(&file_name, content).log_err();
}
}
}
}
+
+ drop(watcher);
+ drop(parent_watcher);
}
- drop(watcher);
})
.detach();
}
- fn register_templates(handlebars: &mut Handlebars) -> Result<(), Box> {
- let mut register_template = |id: &str| {
- let prompt = Assets::get(&format!("prompts/{}.hbs", id))
- .unwrap_or_else(|| panic!("{} prompt template not found", id))
- .data;
- handlebars
- .register_template_string(id, String::from_utf8_lossy(&prompt))
- .map_err(Box::new)
- };
-
- register_template("content_prompt")?;
- register_template("terminal_assistant_prompt")?;
- register_template("edit_workflow")?;
- register_template("step_resolution")?;
+ fn register_built_in_templates(handlebars: &mut Handlebars) -> Result<()> {
+ for path in Assets.list("prompts")? {
+ if let Some(id) = path.split('/').last().and_then(|s| s.strip_suffix(".hbs")) {
+ if let Some(prompt) = Assets.load(path.as_ref()).log_err().flatten() {
+ log::info!("Registering built-in prompt template: {}", id);
+ handlebars
+ .register_template_string(id, String::from_utf8_lossy(prompt.as_ref()))?
+ }
+ }
+ }
Ok(())
}
@@ -173,9 +205,7 @@ impl PromptBuilder {
user_prompt: String,
language_name: Option<&str>,
buffer: BufferSnapshot,
- transform_range: Range,
- selected_ranges: Vec>,
- transform_context_range: Range,
+ range: Range,
) -> Result {
let content_type = match language_name {
None | Some("Markdown" | "Plain Text") => "text",
@@ -183,20 +213,21 @@ impl PromptBuilder {
};
const MAX_CTX: usize = 50000;
+ let is_insert = range.is_empty();
let mut is_truncated = false;
- let before_range = 0..transform_range.start;
+ let before_range = 0..range.start;
let truncated_before = if before_range.len() > MAX_CTX {
is_truncated = true;
- transform_range.start - MAX_CTX..transform_range.start
+ range.start - MAX_CTX..range.start
} else {
before_range
};
- let after_range = transform_range.end..buffer.len();
+ let after_range = range.end..buffer.len();
let truncated_after = if after_range.len() > MAX_CTX {
is_truncated = true;
- transform_range.end..transform_range.end + MAX_CTX
+ range.end..range.end + MAX_CTX
} else {
after_range
};
@@ -205,74 +236,37 @@ impl PromptBuilder {
for chunk in buffer.text_for_range(truncated_before) {
document_content.push_str(chunk);
}
-
- document_content.push_str("\n");
- for chunk in buffer.text_for_range(transform_range.clone()) {
- document_content.push_str(chunk);
+ if is_insert {
+ document_content.push_str("");
+ } else {
+ document_content.push_str("\n");
+ for chunk in buffer.text_for_range(range.clone()) {
+ document_content.push_str(chunk);
+ }
+ document_content.push_str("\n");
}
- document_content.push_str("\n");
-
for chunk in buffer.text_for_range(truncated_after) {
document_content.push_str(chunk);
}
- let mut rewrite_section = String::new();
- for chunk in buffer.text_for_range(transform_range.clone()) {
- rewrite_section.push_str(chunk);
- }
-
- let mut rewrite_section_prefix = String::new();
- for chunk in buffer.text_for_range(transform_context_range.start..transform_range.start) {
- rewrite_section_prefix.push_str(chunk);
- }
-
- let mut rewrite_section_suffix = String::new();
- for chunk in buffer.text_for_range(transform_range.end..transform_context_range.end) {
- rewrite_section_suffix.push_str(chunk);
- }
-
- let rewrite_section_with_edits = {
- let mut section_with_selections = String::new();
- let mut last_end = 0;
- for selected_range in &selected_ranges {
- if selected_range.start > last_end {
- section_with_selections.push_str(
- &rewrite_section[last_end..selected_range.start - transform_range.start],
- );
- }
- if selected_range.start == selected_range.end {
- section_with_selections.push_str("");
- } else {
- section_with_selections.push_str("");
- section_with_selections.push_str(
- &rewrite_section[selected_range.start - transform_range.start
- ..selected_range.end - transform_range.start],
- );
- section_with_selections.push_str("");
- }
- last_end = selected_range.end - transform_range.start;
- }
- if last_end < rewrite_section.len() {
- section_with_selections.push_str(&rewrite_section[last_end..]);
+ let rewrite_section = if !is_insert {
+ let mut section = String::new();
+ for chunk in buffer.text_for_range(range.clone()) {
+ section.push_str(chunk);
}
- section_with_selections
+ Some(section)
+ } else {
+ None
};
- let has_insertion = selected_ranges.iter().any(|range| range.start == range.end);
- let has_replacement = selected_ranges.iter().any(|range| range.start != range.end);
-
let context = ContentPromptContext {
content_type: content_type.to_string(),
language_name: language_name.map(|s| s.to_string()),
+ is_insert,
is_truncated,
document_content,
user_prompt,
rewrite_section,
- rewrite_section_prefix,
- rewrite_section_suffix,
- rewrite_section_with_edits,
- has_insertion,
- has_replacement,
};
self.handlebars.lock().render("content_prompt", &context)
diff --git a/crates/assistant/src/slash_command.rs b/crates/assistant/src/slash_command.rs
index 667b11c7421dce6c06735399531c0e8f5b677efa..b1a97688b2b46a977654c6ddd6d6389f79638d7d 100644
--- a/crates/assistant/src/slash_command.rs
+++ b/crates/assistant/src/slash_command.rs
@@ -124,6 +124,7 @@ impl SlashCommandCompletionProvider {
&command_name,
&[],
true,
+ false,
workspace.clone(),
cx,
);
@@ -208,6 +209,7 @@ impl SlashCommandCompletionProvider {
&command_name,
&completed_arguments,
true,
+ false,
workspace.clone(),
cx,
);
diff --git a/crates/assistant/src/slash_command/context_server_command.rs b/crates/assistant/src/slash_command/context_server_command.rs
index 95c58be1ee0e00c66b48c9485c045cf6db06dd20..c39f132a7bc2fdece665fcb514504c5742d0ca1f 100644
--- a/crates/assistant/src/slash_command/context_server_command.rs
+++ b/crates/assistant/src/slash_command/context_server_command.rs
@@ -67,7 +67,11 @@ impl SlashCommand for ContextServerSlashCommand {
) -> Task> {
let server_id = self.server_id.clone();
let prompt_name = self.prompt.name.clone();
- let argument = arguments.first().cloned();
+
+ let prompt_args = match prompt_arguments(&self.prompt, arguments) {
+ Ok(args) => args,
+ Err(e) => return Task::ready(Err(e)),
+ };
let manager = ContextServerManager::global(cx);
let manager = manager.read(cx);
@@ -76,10 +80,7 @@ impl SlashCommand for ContextServerSlashCommand {
let Some(protocol) = server.client.read().clone() else {
return Err(anyhow!("Context server not initialized"));
};
-
- let result = protocol
- .run_prompt(&prompt_name, prompt_arguments(&self.prompt, argument)?)
- .await?;
+ let result = protocol.run_prompt(&prompt_name, prompt_args).await?;
Ok(SlashCommandOutput {
sections: vec![SlashCommandOutputSection {
@@ -97,19 +98,27 @@ impl SlashCommand for ContextServerSlashCommand {
}
}
-fn prompt_arguments(
- prompt: &PromptInfo,
- argument: Option,
-) -> Result> {
+fn prompt_arguments(prompt: &PromptInfo, arguments: &[String]) -> Result> {
match &prompt.arguments {
- Some(args) if args.len() >= 2 => Err(anyhow!(
+ Some(args) if args.len() > 1 => Err(anyhow!(
"Prompt has more than one argument, which is not supported"
)),
- Some(args) if args.len() == 1 => match argument {
- Some(value) => Ok(HashMap::from_iter([(args[0].name.clone(), value)])),
- None => Err(anyhow!("Prompt expects argument but none given")),
- },
- Some(_) | None => Ok(HashMap::default()),
+ Some(args) if args.len() == 1 => {
+ if !arguments.is_empty() {
+ let mut map = HashMap::default();
+ map.insert(args[0].name.clone(), arguments.join(" "));
+ Ok(map)
+ } else {
+ Err(anyhow!("Prompt expects argument but none given"))
+ }
+ }
+ Some(_) | None => {
+ if arguments.is_empty() {
+ Ok(HashMap::default())
+ } else {
+ Err(anyhow!("Prompt expects no arguments but some were given"))
+ }
+ }
}
}
diff --git a/crates/assistant/src/slash_command_picker.rs b/crates/assistant/src/slash_command_picker.rs
new file mode 100644
index 0000000000000000000000000000000000000000..4b57dcfb3306c5cc676f8064af0108782d7886a8
--- /dev/null
+++ b/crates/assistant/src/slash_command_picker.rs
@@ -0,0 +1,306 @@
+use std::sync::Arc;
+
+use assistant_slash_command::SlashCommandRegistry;
+use gpui::AnyElement;
+use gpui::DismissEvent;
+use gpui::WeakView;
+use picker::PickerEditorPosition;
+
+use ui::ListItemSpacing;
+
+use gpui::SharedString;
+use gpui::Task;
+use picker::{Picker, PickerDelegate};
+use ui::{prelude::*, ListItem, PopoverMenu, PopoverTrigger};
+
+use crate::assistant_panel::ContextEditor;
+
+#[derive(IntoElement)]
+pub(super) struct SlashCommandSelector {
+ registry: Arc,
+ active_context_editor: WeakView,
+ trigger: T,
+}
+
+#[derive(Clone)]
+struct SlashCommandInfo {
+ name: SharedString,
+ description: SharedString,
+ args: Option,
+}
+
+#[derive(Clone)]
+enum SlashCommandEntry {
+ Info(SlashCommandInfo),
+ Advert {
+ name: SharedString,
+ renderer: fn(&mut WindowContext<'_>) -> AnyElement,
+ on_confirm: fn(&mut WindowContext<'_>),
+ },
+}
+
+impl AsRef for SlashCommandEntry {
+ fn as_ref(&self) -> &str {
+ match self {
+ SlashCommandEntry::Info(SlashCommandInfo { name, .. })
+ | SlashCommandEntry::Advert { name, .. } => name,
+ }
+ }
+}
+
+pub(crate) struct SlashCommandDelegate {
+ all_commands: Vec,
+ filtered_commands: Vec,
+ active_context_editor: WeakView,
+ selected_index: usize,
+}
+
+impl SlashCommandSelector {
+ pub(crate) fn new(
+ registry: Arc,
+ active_context_editor: WeakView,
+ trigger: T,
+ ) -> Self {
+ SlashCommandSelector {
+ registry,
+ active_context_editor,
+ trigger,
+ }
+ }
+}
+
+impl PickerDelegate for SlashCommandDelegate {
+ type ListItem = ListItem;
+
+ fn match_count(&self) -> usize {
+ self.filtered_commands.len()
+ }
+
+ fn selected_index(&self) -> usize {
+ self.selected_index
+ }
+
+ fn set_selected_index(&mut self, ix: usize, cx: &mut ViewContext>) {
+ self.selected_index = ix.min(self.filtered_commands.len().saturating_sub(1));
+ cx.notify();
+ }
+
+ fn placeholder_text(&self, _cx: &mut WindowContext) -> Arc {
+ "Select a command...".into()
+ }
+
+ fn update_matches(&mut self, query: String, cx: &mut ViewContext>) -> Task<()> {
+ let all_commands = self.all_commands.clone();
+ cx.spawn(|this, mut cx| async move {
+ let filtered_commands = cx
+ .background_executor()
+ .spawn(async move {
+ if query.is_empty() {
+ all_commands
+ } else {
+ all_commands
+ .into_iter()
+ .filter(|model_info| {
+ model_info
+ .as_ref()
+ .to_lowercase()
+ .contains(&query.to_lowercase())
+ })
+ .collect()
+ }
+ })
+ .await;
+
+ this.update(&mut cx, |this, cx| {
+ this.delegate.filtered_commands = filtered_commands;
+ this.delegate.set_selected_index(0, cx);
+ cx.notify();
+ })
+ .ok();
+ })
+ }
+
+ fn separators_after_indices(&self) -> Vec {
+ let mut ret = vec![];
+ let mut previous_is_advert = false;
+
+ for (index, command) in self.filtered_commands.iter().enumerate() {
+ if previous_is_advert {
+ if let SlashCommandEntry::Info(_) = command {
+ previous_is_advert = false;
+ debug_assert_ne!(
+ index, 0,
+ "index cannot be zero, as we can never have a separator at 0th position"
+ );
+ ret.push(index - 1);
+ }
+ } else {
+ if let SlashCommandEntry::Advert { .. } = command {
+ previous_is_advert = true;
+ if index != 0 {
+ ret.push(index - 1);
+ }
+ }
+ }
+ }
+ ret
+ }
+ fn confirm(&mut self, _secondary: bool, cx: &mut ViewContext>) {
+ if let Some(command) = self.filtered_commands.get(self.selected_index) {
+ if let SlashCommandEntry::Info(info) = command {
+ self.active_context_editor
+ .update(cx, |context_editor, cx| {
+ context_editor.insert_command(&info.name, cx)
+ })
+ .ok();
+ } else if let SlashCommandEntry::Advert { on_confirm, .. } = command {
+ on_confirm(cx);
+ }
+ cx.emit(DismissEvent);
+ }
+ }
+
+ fn dismissed(&mut self, _cx: &mut ViewContext>) {}
+
+ fn editor_position(&self) -> PickerEditorPosition {
+ PickerEditorPosition::End
+ }
+
+ fn render_match(
+ &self,
+ ix: usize,
+ selected: bool,
+ cx: &mut ViewContext>,
+ ) -> Option {
+ let command_info = self.filtered_commands.get(ix)?;
+
+ match command_info {
+ SlashCommandEntry::Info(info) => Some(
+ ListItem::new(ix)
+ .inset(true)
+ .spacing(ListItemSpacing::Sparse)
+ .selected(selected)
+ .child(
+ h_flex()
+ .group(format!("command-entry-label-{ix}"))
+ .w_full()
+ .min_w(px(220.))
+ .child(
+ v_flex()
+ .child(
+ h_flex()
+ .child(div().font_buffer(cx).child({
+ let mut label = format!("/{}", info.name);
+ if let Some(args) =
+ info.args.as_ref().filter(|_| selected)
+ {
+ label.push_str(&args);
+ }
+ Label::new(label).size(LabelSize::Small)
+ }))
+ .children(info.args.clone().filter(|_| !selected).map(
+ |args| {
+ div()
+ .font_buffer(cx)
+ .child(
+ Label::new(args).size(LabelSize::Small),
+ )
+ .visible_on_hover(format!(
+ "command-entry-label-{ix}"
+ ))
+ },
+ )),
+ )
+ .child(
+ Label::new(info.description.clone())
+ .size(LabelSize::Small)
+ .color(Color::Muted),
+ ),
+ ),
+ ),
+ ),
+ SlashCommandEntry::Advert { renderer, .. } => Some(
+ ListItem::new(ix)
+ .inset(true)
+ .spacing(ListItemSpacing::Sparse)
+ .selected(selected)
+ .child(renderer(cx)),
+ ),
+ }
+ }
+}
+
+impl RenderOnce for SlashCommandSelector {
+ fn render(self, cx: &mut WindowContext) -> impl IntoElement {
+ let all_models = self
+ .registry
+ .featured_command_names()
+ .into_iter()
+ .filter_map(|command_name| {
+ let command = self.registry.command(&command_name)?;
+ let menu_text = SharedString::from(Arc::from(command.menu_text()));
+ let label = command.label(cx);
+ let args = label.filter_range.end.ne(&label.text.len()).then(|| {
+ SharedString::from(
+ label.text[label.filter_range.end..label.text.len()].to_owned(),
+ )
+ });
+ Some(SlashCommandEntry::Info(SlashCommandInfo {
+ name: command_name.into(),
+ description: menu_text,
+ args,
+ }))
+ })
+ .chain([SlashCommandEntry::Advert {
+ name: "create-your-command".into(),
+ renderer: |cx| {
+ v_flex()
+ .child(
+ h_flex()
+ .font_buffer(cx)
+ .items_center()
+ .gap_1()
+ .child(div().font_buffer(cx).child(
+ Label::new("create-your-command").size(LabelSize::Small),
+ ))
+ .child(Icon::new(IconName::ArrowUpRight).size(IconSize::XSmall)),
+ )
+ .child(
+ Label::new("Learn how to create a custom command")
+ .size(LabelSize::Small)
+ .color(Color::Muted),
+ )
+ .into_any_element()
+ },
+ on_confirm: |cx| cx.open_url("https://zed.dev/docs/extensions/slash-commands"),
+ }])
+ .collect::>();
+
+ let delegate = SlashCommandDelegate {
+ all_commands: all_models.clone(),
+ active_context_editor: self.active_context_editor.clone(),
+ filtered_commands: all_models,
+ selected_index: 0,
+ };
+
+ let picker_view = cx.new_view(|cx| {
+ let picker = Picker::uniform_list(delegate, cx).max_height(Some(rems(20.).into()));
+ picker
+ });
+
+ let handle = self
+ .active_context_editor
+ .update(cx, |this, _| this.slash_menu_handle.clone())
+ .ok();
+ PopoverMenu::new("model-switcher")
+ .menu(move |_cx| Some(picker_view.clone()))
+ .trigger(self.trigger)
+ .attach(gpui::AnchorCorner::TopLeft)
+ .anchor(gpui::AnchorCorner::BottomLeft)
+ .offset(gpui::Point {
+ x: px(0.0),
+ y: px(-16.0),
+ })
+ .when_some(handle, |this, handle| this.with_handle(handle))
+ }
+}
diff --git a/crates/assistant/src/using-the-assistant.md b/crates/assistant/src/using-the-assistant.md
deleted file mode 100644
index 23cc3b287c937ebe2e5f705a5c2553c192db3107..0000000000000000000000000000000000000000
--- a/crates/assistant/src/using-the-assistant.md
+++ /dev/null
@@ -1,42 +0,0 @@
-## Assistant Panel
-
-Once you have configured a provider, you can interact with the provider's language models in a context editor.
-
-To create a new context editor, use the menu in the top right of the assistant panel and select the `New Context` option.
-
-In the context editor, select a model from one of the configured providers, type a message in the `You` block, and submit with `cmd-enter` (or `ctrl-enter` on Linux).
-
-### Adding Prompts
-
-You can customize the default prompts used in new context editors by opening the `Prompt Library`.
-
-Open the `Prompt Library` using either the menu in the top right of the assistant panel and choosing the `Prompt Library` option, or by using the `assistant: deploy prompt library` command when the assistant panel is focused.
-
-### Viewing past contexts
-
-You can view all previous contexts by opening the `History` tab in the assistant panel.
-
-Open the `History` using the menu in the top right of the assistant panel and choosing `History`.
-
-### Slash commands
-
-Slash commands enhance the assistant's capabilities. Begin by typing a `/` at the beginning of the line to see a list of available commands:
-
-- default: Inserts the default prompt into the context
-- diagnostics: Injects errors reported by the project's language server into the context
-- fetch: Pulls the content of a webpage and inserts it into the context
-- file: Pulls a single file or a directory of files into the context
-- now: Inserts the current date and time into the context
-- prompt: Adds a custom-configured prompt to the context (see Prompt Library)
-- search: Performs semantic search for content in your project based on natural language
-- symbols: Pulls the current tab's active symbols into the context
-- tab: Pulls in the content of the active tab or all open tabs into the context
-- terminal: Pulls in a select number of lines of output from the terminal
-
-## Inline assistant
-
-You can use `ctrl-enter` to open the inline assistant in both a normal editor and within the assistant panel.
-
-The inline assistant allows you to send the current selection (or the current line) to a language model and modify the selection with the language model's response.
-
-The inline assistant pulls its context from the assistant panel, allowing you to provide additional instructions or rules for code transformations.
diff --git a/crates/assistant/src/workflow/step_view.rs b/crates/assistant/src/workflow/step_view.rs
index 5e75669fc12f2735a364622869378fc6ae6ae445..c8615dc8534088606a4fc2c020bdbe8da6774091 100644
--- a/crates/assistant/src/workflow/step_view.rs
+++ b/crates/assistant/src/workflow/step_view.rs
@@ -273,7 +273,7 @@ impl Item for WorkflowStepView {
}
fn tab_icon(&self, _cx: &WindowContext) -> Option {
- Some(Icon::new(IconName::Pencil))
+ Some(Icon::new(IconName::SearchCode))
}
fn to_item_events(event: &Self::Event, mut f: impl FnMut(item::ItemEvent)) {
diff --git a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql
index e2dd9cdf8608c3b1bd485522231c624aa8c9419f..225971ec324e68d69c1129a23d6c6fdd46ddc759 100644
--- a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql
+++ b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql
@@ -295,7 +295,8 @@ CREATE UNIQUE INDEX "index_channel_buffer_collaborators_on_channel_id_connection
CREATE TABLE "feature_flags" (
"id" INTEGER PRIMARY KEY AUTOINCREMENT,
- "flag" TEXT NOT NULL UNIQUE
+ "flag" TEXT NOT NULL UNIQUE,
+ "enabled_for_all" BOOLEAN NOT NULL DEFAULT false
);
CREATE INDEX "index_feature_flags" ON "feature_flags" ("id");
diff --git a/crates/collab/migrations/20240816181658_add_enabled_for_all_to_feature_flags.sql b/crates/collab/migrations/20240816181658_add_enabled_for_all_to_feature_flags.sql
new file mode 100644
index 0000000000000000000000000000000000000000..a56c87b97a41de260ccb4aa7d44fd35d2c026293
--- /dev/null
+++ b/crates/collab/migrations/20240816181658_add_enabled_for_all_to_feature_flags.sql
@@ -0,0 +1 @@
+alter table feature_flags add column enabled_for_all boolean not null default false;
diff --git a/crates/collab/src/api/events.rs b/crates/collab/src/api/events.rs
index a6afc98bfc94e2e939d2254c2d783589dc07bf1c..9331107259440a8ff671df9d1b3e6e390f1b8552 100644
--- a/crates/collab/src/api/events.rs
+++ b/crates/collab/src/api/events.rs
@@ -1,5 +1,6 @@
use super::ips_file::IpsFile;
use crate::api::CloudflareIpCountryHeader;
+use crate::clickhouse::write_to_table;
use crate::{api::slack, AppState, Error, Result};
use anyhow::{anyhow, Context};
use aws_sdk_s3::primitives::ByteStream;
@@ -529,12 +530,12 @@ struct ToUpload {
impl ToUpload {
pub async fn upload(&self, clickhouse_client: &clickhouse::Client) -> anyhow::Result<()> {
const EDITOR_EVENTS_TABLE: &str = "editor_events";
- Self::upload_to_table(EDITOR_EVENTS_TABLE, &self.editor_events, clickhouse_client)
+ write_to_table(EDITOR_EVENTS_TABLE, &self.editor_events, clickhouse_client)
.await
.with_context(|| format!("failed to upload to table '{EDITOR_EVENTS_TABLE}'"))?;
const INLINE_COMPLETION_EVENTS_TABLE: &str = "inline_completion_events";
- Self::upload_to_table(
+ write_to_table(
INLINE_COMPLETION_EVENTS_TABLE,
&self.inline_completion_events,
clickhouse_client,
@@ -543,7 +544,7 @@ impl ToUpload {
.with_context(|| format!("failed to upload to table '{INLINE_COMPLETION_EVENTS_TABLE}'"))?;
const ASSISTANT_EVENTS_TABLE: &str = "assistant_events";
- Self::upload_to_table(
+ write_to_table(
ASSISTANT_EVENTS_TABLE,
&self.assistant_events,
clickhouse_client,
@@ -552,27 +553,27 @@ impl ToUpload {
.with_context(|| format!("failed to upload to table '{ASSISTANT_EVENTS_TABLE}'"))?;
const CALL_EVENTS_TABLE: &str = "call_events";
- Self::upload_to_table(CALL_EVENTS_TABLE, &self.call_events, clickhouse_client)
+ write_to_table(CALL_EVENTS_TABLE, &self.call_events, clickhouse_client)
.await
.with_context(|| format!("failed to upload to table '{CALL_EVENTS_TABLE}'"))?;
const CPU_EVENTS_TABLE: &str = "cpu_events";
- Self::upload_to_table(CPU_EVENTS_TABLE, &self.cpu_events, clickhouse_client)
+ write_to_table(CPU_EVENTS_TABLE, &self.cpu_events, clickhouse_client)
.await
.with_context(|| format!("failed to upload to table '{CPU_EVENTS_TABLE}'"))?;
const MEMORY_EVENTS_TABLE: &str = "memory_events";
- Self::upload_to_table(MEMORY_EVENTS_TABLE, &self.memory_events, clickhouse_client)
+ write_to_table(MEMORY_EVENTS_TABLE, &self.memory_events, clickhouse_client)
.await
.with_context(|| format!("failed to upload to table '{MEMORY_EVENTS_TABLE}'"))?;
const APP_EVENTS_TABLE: &str = "app_events";
- Self::upload_to_table(APP_EVENTS_TABLE, &self.app_events, clickhouse_client)
+ write_to_table(APP_EVENTS_TABLE, &self.app_events, clickhouse_client)
.await
.with_context(|| format!("failed to upload to table '{APP_EVENTS_TABLE}'"))?;
const SETTING_EVENTS_TABLE: &str = "setting_events";
- Self::upload_to_table(
+ write_to_table(
SETTING_EVENTS_TABLE,
&self.setting_events,
clickhouse_client,
@@ -581,7 +582,7 @@ impl ToUpload {
.with_context(|| format!("failed to upload to table '{SETTING_EVENTS_TABLE}'"))?;
const EXTENSION_EVENTS_TABLE: &str = "extension_events";
- Self::upload_to_table(
+ write_to_table(
EXTENSION_EVENTS_TABLE,
&self.extension_events,
clickhouse_client,
@@ -590,48 +591,22 @@ impl ToUpload {
.with_context(|| format!("failed to upload to table '{EXTENSION_EVENTS_TABLE}'"))?;
const EDIT_EVENTS_TABLE: &str = "edit_events";
- Self::upload_to_table(EDIT_EVENTS_TABLE, &self.edit_events, clickhouse_client)
+ write_to_table(EDIT_EVENTS_TABLE, &self.edit_events, clickhouse_client)
.await
.with_context(|| format!("failed to upload to table '{EDIT_EVENTS_TABLE}'"))?;
const ACTION_EVENTS_TABLE: &str = "action_events";
- Self::upload_to_table(ACTION_EVENTS_TABLE, &self.action_events, clickhouse_client)
+ write_to_table(ACTION_EVENTS_TABLE, &self.action_events, clickhouse_client)
.await
.with_context(|| format!("failed to upload to table '{ACTION_EVENTS_TABLE}'"))?;
const REPL_EVENTS_TABLE: &str = "repl_events";
- Self::upload_to_table(REPL_EVENTS_TABLE, &self.repl_events, clickhouse_client)
+ write_to_table(REPL_EVENTS_TABLE, &self.repl_events, clickhouse_client)
.await
.with_context(|| format!("failed to upload to table '{REPL_EVENTS_TABLE}'"))?;
Ok(())
}
-
- async fn upload_to_table(
- table: &str,
- rows: &[T],
- clickhouse_client: &clickhouse::Client,
- ) -> anyhow::Result<()> {
- if rows.is_empty() {
- return Ok(());
- }
-
- let mut insert = clickhouse_client.insert(table)?;
-
- for event in rows {
- insert.write(event).await?;
- }
-
- insert.end().await?;
-
- let event_count = rows.len();
- log::info!(
- "wrote {event_count} {event_specifier} to '{table}'",
- event_specifier = if event_count == 1 { "event" } else { "events" }
- );
-
- Ok(())
- }
}
pub fn serialize_country_code
(country_code: &str, serializer: S) -> Result
diff --git a/crates/collab/src/clickhouse.rs b/crates/collab/src/clickhouse.rs
new file mode 100644
index 0000000000000000000000000000000000000000..2937116bad552f5561de8fbebcbb783d4500c1ed
--- /dev/null
+++ b/crates/collab/src/clickhouse.rs
@@ -0,0 +1,28 @@
+use serde::Serialize;
+
+/// Writes the given rows to the specified Clickhouse table.
+pub async fn write_to_table(
+ table: &str,
+ rows: &[T],
+ clickhouse_client: &clickhouse::Client,
+) -> anyhow::Result<()> {
+ if rows.is_empty() {
+ return Ok(());
+ }
+
+ let mut insert = clickhouse_client.insert(table)?;
+
+ for event in rows {
+ insert.write(event).await?;
+ }
+
+ insert.end().await?;
+
+ let event_count = rows.len();
+ log::info!(
+ "wrote {event_count} {event_specifier} to '{table}'",
+ event_specifier = if event_count == 1 { "event" } else { "events" }
+ );
+
+ Ok(())
+}
diff --git a/crates/collab/src/db/queries/users.rs b/crates/collab/src/db/queries/users.rs
index 6ae482c36bf7182686e650d0e886128b76801540..251f5bec685d9266261549ef8e1bc537725efaa5 100644
--- a/crates/collab/src/db/queries/users.rs
+++ b/crates/collab/src/db/queries/users.rs
@@ -312,10 +312,11 @@ impl Database {
}
/// Creates a new feature flag.
- pub async fn create_user_flag(&self, flag: &str) -> Result {
+ pub async fn create_user_flag(&self, flag: &str, enabled_for_all: bool) -> Result {
self.transaction(|tx| async move {
let flag = feature_flag::Entity::insert(feature_flag::ActiveModel {
flag: ActiveValue::set(flag.to_string()),
+ enabled_for_all: ActiveValue::set(enabled_for_all),
..Default::default()
})
.exec(&*tx)
@@ -350,7 +351,15 @@ impl Database {
Flag,
}
- let flags = user::Model {
+ let flags_enabled_for_all = feature_flag::Entity::find()
+ .filter(feature_flag::Column::EnabledForAll.eq(true))
+ .select_only()
+ .column(feature_flag::Column::Flag)
+ .into_values::<_, QueryAs>()
+ .all(&*tx)
+ .await?;
+
+ let flags_enabled_for_user = user::Model {
id: user,
..Default::default()
}
@@ -361,7 +370,10 @@ impl Database {
.all(&*tx)
.await?;
- Ok(flags)
+ let mut all_flags = HashSet::from_iter(flags_enabled_for_all);
+ all_flags.extend(flags_enabled_for_user);
+
+ Ok(all_flags.into_iter().collect())
})
.await
}
diff --git a/crates/collab/src/db/tables/feature_flag.rs b/crates/collab/src/db/tables/feature_flag.rs
index 41c1451c648e7115165a2cf3bfc4e84d9ae534a1..5bbfedd71e70b7f1cc58219475c49c28bc62ff3d 100644
--- a/crates/collab/src/db/tables/feature_flag.rs
+++ b/crates/collab/src/db/tables/feature_flag.rs
@@ -8,6 +8,7 @@ pub struct Model {
#[sea_orm(primary_key)]
pub id: FlagId,
pub flag: String,
+ pub enabled_for_all: bool,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
diff --git a/crates/collab/src/db/tests/feature_flag_tests.rs b/crates/collab/src/db/tests/feature_flag_tests.rs
index 5269d5354fd48178d2464ec6df1ac9921ffcbe52..972b45e1bccbf207d9a869a879e5a49a2698a164 100644
--- a/crates/collab/src/db/tests/feature_flag_tests.rs
+++ b/crates/collab/src/db/tests/feature_flag_tests.rs
@@ -2,6 +2,7 @@ use crate::{
db::{Database, NewUserParams},
test_both_dbs,
};
+use pretty_assertions::assert_eq;
use std::sync::Arc;
test_both_dbs!(
@@ -37,22 +38,27 @@ async fn test_get_user_flags(db: &Arc) {
.unwrap()
.user_id;
- const CHANNELS_ALPHA: &str = "channels-alpha";
- const NEW_SEARCH: &str = "new-search";
+ const FEATURE_FLAG_ONE: &str = "brand-new-ux";
+ const FEATURE_FLAG_TWO: &str = "cool-feature";
+ const FEATURE_FLAG_THREE: &str = "feature-enabled-for-everyone";
- let channels_flag = db.create_user_flag(CHANNELS_ALPHA).await.unwrap();
- let search_flag = db.create_user_flag(NEW_SEARCH).await.unwrap();
+ let feature_flag_one = db.create_user_flag(FEATURE_FLAG_ONE, false).await.unwrap();
+ let feature_flag_two = db.create_user_flag(FEATURE_FLAG_TWO, false).await.unwrap();
+ db.create_user_flag(FEATURE_FLAG_THREE, true).await.unwrap();
- db.add_user_flag(user_1, channels_flag).await.unwrap();
- db.add_user_flag(user_1, search_flag).await.unwrap();
+ db.add_user_flag(user_1, feature_flag_one).await.unwrap();
+ db.add_user_flag(user_1, feature_flag_two).await.unwrap();
- db.add_user_flag(user_2, channels_flag).await.unwrap();
+ db.add_user_flag(user_2, feature_flag_one).await.unwrap();
let mut user_1_flags = db.get_user_flags(user_1).await.unwrap();
user_1_flags.sort();
- assert_eq!(user_1_flags, &[CHANNELS_ALPHA, NEW_SEARCH]);
+ assert_eq!(
+ user_1_flags,
+ &[FEATURE_FLAG_ONE, FEATURE_FLAG_TWO, FEATURE_FLAG_THREE]
+ );
let mut user_2_flags = db.get_user_flags(user_2).await.unwrap();
user_2_flags.sort();
- assert_eq!(user_2_flags, &[CHANNELS_ALPHA]);
+ assert_eq!(user_2_flags, &[FEATURE_FLAG_ONE, FEATURE_FLAG_THREE]);
}
diff --git a/crates/collab/src/lib.rs b/crates/collab/src/lib.rs
index 9cae7713dc7708901dd149fa717ee0aebe33fab8..81cc334c43f88172f1aeddb61b78004856aea444 100644
--- a/crates/collab/src/lib.rs
+++ b/crates/collab/src/lib.rs
@@ -1,5 +1,6 @@
pub mod api;
pub mod auth;
+pub mod clickhouse;
pub mod db;
pub mod env;
pub mod executor;
@@ -267,7 +268,7 @@ pub struct AppState {
pub stripe_client: Option>,
pub rate_limiter: Arc,
pub executor: Executor,
- pub clickhouse_client: Option,
+ pub clickhouse_client: Option<::clickhouse::Client>,
pub config: Config,
}
@@ -358,8 +359,8 @@ async fn build_blob_store_client(config: &Config) -> anyhow::Result anyhow::Result {
- Ok(clickhouse::Client::default()
+fn build_clickhouse_client(config: &Config) -> anyhow::Result<::clickhouse::Client> {
+ Ok(::clickhouse::Client::default()
.with_url(
config
.clickhouse_url
diff --git a/crates/collab/src/llm.rs b/crates/collab/src/llm.rs
index 151b6a247ca276212601aca6c115762d222b5c70..3f5c864aae8af5dd8f4d874acc7c61adabe68bd3 100644
--- a/crates/collab/src/llm.rs
+++ b/crates/collab/src/llm.rs
@@ -141,7 +141,8 @@ async fn validate_api_token(mut req: Request, next: Next) -> impl IntoR
tracing::Span::current()
.record("user_id", claims.user_id)
.record("login", claims.github_user_login.clone())
- .record("authn.jti", &claims.jti);
+ .record("authn.jti", &claims.jti)
+ .record("is_staff", &claims.is_staff);
req.extensions_mut().insert(claims);
Ok::<_, Error>(next.run(req).await.into_response())
@@ -169,7 +170,10 @@ async fn perform_completion(
country_code_header: Option>,
Json(params): Json,
) -> Result {
- let model = normalize_model_name(params.provider, params.model);
+ let model = normalize_model_name(
+ state.db.model_names_for_provider(params.provider),
+ params.model,
+ );
authorize_access_to_language_model(
&state.config,
@@ -200,17 +204,21 @@ async fn perform_completion(
let mut request: anthropic::Request =
serde_json::from_str(¶ms.provider_request.get())?;
- // Parse the model, throw away the version that was included, and then set a specific
- // version that we control on the server.
+ // Override the model on the request with the latest version of the model that is
+ // known to the server.
+ //
// Right now, we use the version that's defined in `model.id()`, but we will likely
// want to change this code once a new version of an Anthropic model is released,
// so that users can use the new version, without having to update Zed.
- request.model = match anthropic::Model::from_id(&request.model) {
- Ok(model) => model.id().to_string(),
- Err(_) => request.model,
+ request.model = match model.as_str() {
+ "claude-3-5-sonnet" => anthropic::Model::Claude3_5Sonnet.id().to_string(),
+ "claude-3-opus" => anthropic::Model::Claude3Opus.id().to_string(),
+ "claude-3-haiku" => anthropic::Model::Claude3Haiku.id().to_string(),
+ "claude-3-sonnet" => anthropic::Model::Claude3Sonnet.id().to_string(),
+ _ => request.model,
};
- let chunks = anthropic::stream_completion(
+ let (chunks, rate_limit_info) = anthropic::stream_completion_with_rate_limit_info(
&state.http_client,
anthropic::ANTHROPIC_API_URL,
api_key,
@@ -238,6 +246,19 @@ async fn perform_completion(
anthropic::AnthropicError::Other(err) => Error::Internal(err),
})?;
+ if let Some(rate_limit_info) = rate_limit_info {
+ tracing::info!(
+ target: "upstream rate limit",
+ is_staff = claims.is_staff,
+ provider = params.provider.to_string(),
+ model = model,
+ tokens_remaining = rate_limit_info.tokens_remaining,
+ requests_remaining = rate_limit_info.requests_remaining,
+ requests_reset = ?rate_limit_info.requests_reset,
+ tokens_reset = ?rate_limit_info.tokens_reset,
+ );
+ }
+
chunks
.map(move |event| {
let chunk = event?;
@@ -369,31 +390,13 @@ async fn perform_completion(
})))
}
-fn normalize_model_name(provider: LanguageModelProvider, name: String) -> String {
- let prefixes: &[_] = match provider {
- LanguageModelProvider::Anthropic => &[
- "claude-3-5-sonnet",
- "claude-3-haiku",
- "claude-3-opus",
- "claude-3-sonnet",
- ],
- LanguageModelProvider::OpenAi => &[
- "gpt-3.5-turbo",
- "gpt-4-turbo-preview",
- "gpt-4o-mini",
- "gpt-4o",
- "gpt-4",
- ],
- LanguageModelProvider::Google => &[],
- LanguageModelProvider::Zed => &[],
- };
-
- if let Some(prefix) = prefixes
+fn normalize_model_name(known_models: Vec, name: String) -> String {
+ if let Some(known_model_name) = known_models
.iter()
- .filter(|&&prefix| name.starts_with(prefix))
- .max_by_key(|&&prefix| prefix.len())
+ .filter(|known_model_name| name.starts_with(known_model_name.as_str()))
+ .max_by_key(|known_model_name| known_model_name.len())
{
- prefix.to_string()
+ known_model_name.to_string()
} else {
name
}
@@ -551,33 +554,75 @@ impl Drop for TokenCountingStream {
.await
.log_err();
- if let Some((clickhouse_client, usage)) = state.clickhouse_client.as_ref().zip(usage) {
- report_llm_usage(
- clickhouse_client,
- LlmUsageEventRow {
- time: Utc::now().timestamp_millis(),
- user_id: claims.user_id as i32,
- is_staff: claims.is_staff,
- plan: match claims.plan {
- Plan::Free => "free".to_string(),
- Plan::ZedPro => "zed_pro".to_string(),
+ if let Some(usage) = usage {
+ tracing::info!(
+ target: "user usage",
+ user_id = claims.user_id,
+ login = claims.github_user_login,
+ authn.jti = claims.jti,
+ is_staff = claims.is_staff,
+ requests_this_minute = usage.requests_this_minute,
+ tokens_this_minute = usage.tokens_this_minute,
+ );
+
+ if let Some(clickhouse_client) = state.clickhouse_client.as_ref() {
+ report_llm_usage(
+ clickhouse_client,
+ LlmUsageEventRow {
+ time: Utc::now().timestamp_millis(),
+ user_id: claims.user_id as i32,
+ is_staff: claims.is_staff,
+ plan: match claims.plan {
+ Plan::Free => "free".to_string(),
+ Plan::ZedPro => "zed_pro".to_string(),
+ },
+ model,
+ provider: provider.to_string(),
+ input_token_count: input_token_count as u64,
+ output_token_count: output_token_count as u64,
+ requests_this_minute: usage.requests_this_minute as u64,
+ tokens_this_minute: usage.tokens_this_minute as u64,
+ tokens_this_day: usage.tokens_this_day as u64,
+ input_tokens_this_month: usage.input_tokens_this_month as u64,
+ output_tokens_this_month: usage.output_tokens_this_month as u64,
+ spending_this_month: usage.spending_this_month as u64,
+ lifetime_spending: usage.lifetime_spending as u64,
},
- model,
- provider: provider.to_string(),
- input_token_count: input_token_count as u64,
- output_token_count: output_token_count as u64,
- requests_this_minute: usage.requests_this_minute as u64,
- tokens_this_minute: usage.tokens_this_minute as u64,
- tokens_this_day: usage.tokens_this_day as u64,
- input_tokens_this_month: usage.input_tokens_this_month as u64,
- output_tokens_this_month: usage.output_tokens_this_month as u64,
- spending_this_month: usage.spending_this_month as u64,
- lifetime_spending: usage.lifetime_spending as u64,
- },
- )
- .await
- .log_err();
+ )
+ .await
+ .log_err();
+ }
}
})
}
}
+
+pub fn log_usage_periodically(state: Arc) {
+ state.executor.clone().spawn_detached(async move {
+ loop {
+ state
+ .executor
+ .sleep(std::time::Duration::from_secs(30))
+ .await;
+
+ let Some(usages) = state
+ .db
+ .get_application_wide_usages_by_model(Utc::now())
+ .await
+ .log_err()
+ else {
+ continue;
+ };
+
+ for usage in usages {
+ tracing::info!(
+ target: "computed usage",
+ provider = usage.provider.to_string(),
+ model = usage.model,
+ requests_this_minute = usage.requests_this_minute,
+ tokens_this_minute = usage.tokens_this_minute,
+ );
+ }
+ }
+ })
+}
diff --git a/crates/collab/src/llm/authorization.rs b/crates/collab/src/llm/authorization.rs
index 865ca97f7ab8fce70d21f4c7d499d1f7cf49fe64..0b62dd4e0a9ef9c18216070993b70bcc68dd93ac 100644
--- a/crates/collab/src/llm/authorization.rs
+++ b/crates/collab/src/llm/authorization.rs
@@ -26,9 +26,7 @@ fn authorize_access_to_model(
}
match (provider, model) {
- (LanguageModelProvider::Anthropic, model) if model.starts_with("claude-3-5-sonnet") => {
- Ok(())
- }
+ (LanguageModelProvider::Anthropic, "claude-3-5-sonnet") => Ok(()),
_ => Err(Error::http(
StatusCode::FORBIDDEN,
format!("access to model {model:?} is not included in your plan"),
diff --git a/crates/collab/src/llm/db.rs b/crates/collab/src/llm/db.rs
index b3144eeecd18095983669d0ac6df3d6178cb0ef0..f76a722471e760f976506b715fb836e85f4fd98f 100644
--- a/crates/collab/src/llm/db.rs
+++ b/crates/collab/src/llm/db.rs
@@ -67,6 +67,21 @@ impl LlmDatabase {
Ok(())
}
+ /// Returns the names of the known models for the given [`LanguageModelProvider`].
+ pub fn model_names_for_provider(&self, provider: LanguageModelProvider) -> Vec {
+ self.models
+ .keys()
+ .filter_map(|(model_provider, model_name)| {
+ if model_provider == &provider {
+ Some(model_name)
+ } else {
+ None
+ }
+ })
+ .cloned()
+ .collect::>()
+ }
+
pub fn model(&self, provider: LanguageModelProvider, name: &str) -> Result<&model::Model> {
Ok(self
.models
diff --git a/crates/collab/src/llm/db/queries/usages.rs b/crates/collab/src/llm/db/queries/usages.rs
index adfd55088fdc5aa69e96744633b03bc4f70f5dd7..0bfbb4c1b1ba96eaac2e81164e462de19899cae4 100644
--- a/crates/collab/src/llm/db/queries/usages.rs
+++ b/crates/collab/src/llm/db/queries/usages.rs
@@ -1,5 +1,6 @@
use crate::db::UserId;
use chrono::Duration;
+use futures::StreamExt as _;
use rpc::LanguageModelProvider;
use sea_orm::QuerySelect;
use std::{iter, str::FromStr};
@@ -18,6 +19,14 @@ pub struct Usage {
pub lifetime_spending: usize,
}
+#[derive(Debug, PartialEq, Clone)]
+pub struct ApplicationWideUsage {
+ pub provider: LanguageModelProvider,
+ pub model: String,
+ pub requests_this_minute: usize,
+ pub tokens_this_minute: usize,
+}
+
#[derive(Clone, Copy, Debug, Default)]
pub struct ActiveUserCount {
pub users_in_recent_minutes: usize,
@@ -63,6 +72,71 @@ impl LlmDatabase {
Ok(())
}
+ pub async fn get_application_wide_usages_by_model(
+ &self,
+ now: DateTimeUtc,
+ ) -> Result> {
+ self.transaction(|tx| async move {
+ let past_minute = now - Duration::minutes(1);
+ let requests_per_minute = self.usage_measure_ids[&UsageMeasure::RequestsPerMinute];
+ let tokens_per_minute = self.usage_measure_ids[&UsageMeasure::TokensPerMinute];
+
+ let mut results = Vec::new();
+ for (provider, model) in self.models.keys().cloned() {
+ let mut usages = usage::Entity::find()
+ .filter(
+ usage::Column::Timestamp
+ .gte(past_minute.naive_utc())
+ .and(usage::Column::IsStaff.eq(false))
+ .and(
+ usage::Column::MeasureId
+ .eq(requests_per_minute)
+ .or(usage::Column::MeasureId.eq(tokens_per_minute)),
+ ),
+ )
+ .stream(&*tx)
+ .await?;
+
+ let mut requests_this_minute = 0;
+ let mut tokens_this_minute = 0;
+ while let Some(usage) = usages.next().await {
+ let usage = usage?;
+ if usage.measure_id == requests_per_minute {
+ requests_this_minute += Self::get_live_buckets(
+ &usage,
+ now.naive_utc(),
+ UsageMeasure::RequestsPerMinute,
+ )
+ .0
+ .iter()
+ .copied()
+ .sum::() as usize;
+ } else if usage.measure_id == tokens_per_minute {
+ tokens_this_minute += Self::get_live_buckets(
+ &usage,
+ now.naive_utc(),
+ UsageMeasure::TokensPerMinute,
+ )
+ .0
+ .iter()
+ .copied()
+ .sum::() as usize;
+ }
+ }
+
+ results.push(ApplicationWideUsage {
+ provider,
+ model,
+ requests_this_minute,
+ tokens_this_minute,
+ })
+ }
+
+ Ok(results)
+ })
+ .await
+ }
+
pub async fn get_usage(
&self,
user_id: UserId,
diff --git a/crates/collab/src/llm/telemetry.rs b/crates/collab/src/llm/telemetry.rs
index ac90bd265ab36dce0bd160918d46d5880b95e082..17a2cb9cd3389dbff1bd014ecec7c528dfe88ce4 100644
--- a/crates/collab/src/llm/telemetry.rs
+++ b/crates/collab/src/llm/telemetry.rs
@@ -1,6 +1,8 @@
-use anyhow::Result;
+use anyhow::{Context, Result};
use serde::Serialize;
+use crate::clickhouse::write_to_table;
+
#[derive(Serialize, Debug, clickhouse::Row)]
pub struct LlmUsageEventRow {
pub time: i64,
@@ -40,9 +42,10 @@ pub struct LlmRateLimitEventRow {
}
pub async fn report_llm_usage(client: &clickhouse::Client, row: LlmUsageEventRow) -> Result<()> {
- let mut insert = client.insert("llm_usage_events")?;
- insert.write(&row).await?;
- insert.end().await?;
+ const LLM_USAGE_EVENTS_TABLE: &str = "llm_usage_events";
+ write_to_table(LLM_USAGE_EVENTS_TABLE, &[row], client)
+ .await
+ .with_context(|| format!("failed to upload to table '{LLM_USAGE_EVENTS_TABLE}'"))?;
Ok(())
}
@@ -50,8 +53,9 @@ pub async fn report_llm_rate_limit(
client: &clickhouse::Client,
row: LlmRateLimitEventRow,
) -> Result<()> {
- let mut insert = client.insert("llm_rate_limits")?;
- insert.write(&row).await?;
- insert.end().await?;
+ const LLM_RATE_LIMIT_EVENTS_TABLE: &str = "llm_rate_limit_events";
+ write_to_table(LLM_RATE_LIMIT_EVENTS_TABLE, &[row], client)
+ .await
+ .with_context(|| format!("failed to upload to table '{LLM_RATE_LIMIT_EVENTS_TABLE}'"))?;
Ok(())
}
diff --git a/crates/collab/src/main.rs b/crates/collab/src/main.rs
index 5d4fc2abe06c915bba6cd03fa0625ebd2d52eed1..a946d1662019b70198ca62ef019a36eb2857431b 100644
--- a/crates/collab/src/main.rs
+++ b/crates/collab/src/main.rs
@@ -5,7 +5,7 @@ use axum::{
routing::get,
Extension, Router,
};
-use collab::llm::db::LlmDatabase;
+use collab::llm::{db::LlmDatabase, log_usage_periodically};
use collab::migrations::run_database_migrations;
use collab::{api::billing::poll_stripe_events_periodically, llm::LlmState, ServiceMode};
use collab::{
@@ -95,6 +95,8 @@ async fn main() -> Result<()> {
let state = LlmState::new(config.clone(), Executor::Production).await?;
+ log_usage_periodically(state.clone());
+
app = app
.merge(collab::llm::routes())
.layer(Extension(state.clone()));
@@ -152,7 +154,8 @@ async fn main() -> Result<()> {
matched_path,
user_id = tracing::field::Empty,
login = tracing::field::Empty,
- authn.jti = tracing::field::Empty
+ authn.jti = tracing::field::Empty,
+ is_staff = tracing::field::Empty
)
})
.on_response(
diff --git a/crates/collab/src/seed.rs b/crates/collab/src/seed.rs
index 509b986089185bae9bb4a6ab6354fb635e819164..6b56f46d50e461c9ac2fd2fdba7f7c9a9247c3c0 100644
--- a/crates/collab/src/seed.rs
+++ b/crates/collab/src/seed.rs
@@ -1,7 +1,6 @@
use crate::db::{self, ChannelRole, NewUserParams};
use anyhow::Context;
-use chrono::{DateTime, Utc};
use db::Database;
use serde::{de::DeserializeOwned, Deserialize};
use std::{fmt::Write, fs, path::Path};
@@ -13,7 +12,6 @@ struct GitHubUser {
id: i32,
login: String,
email: Option,
- created_at: DateTime,
}
#[derive(Deserialize)]
@@ -44,6 +42,17 @@ pub async fn seed(config: &Config, db: &Database, force: bool) -> anyhow::Result
let mut first_user = None;
let mut others = vec![];
+ let flag_names = ["remoting", "language-models"];
+ let mut flags = Vec::new();
+
+ for flag_name in flag_names {
+ let flag = db
+ .create_user_flag(flag_name, false)
+ .await
+ .unwrap_or_else(|_| panic!("failed to create flag: '{flag_name}'"));
+ flags.push(flag);
+ }
+
for admin_login in seed_config.admins {
let user = fetch_github::(
&client,
@@ -66,6 +75,15 @@ pub async fn seed(config: &Config, db: &Database, force: bool) -> anyhow::Result
} else {
others.push(user.user_id)
}
+
+ for flag in &flags {
+ db.add_user_flag(user.user_id, *flag)
+ .await
+ .context(format!(
+ "Unable to enable flag '{}' for user '{}'",
+ flag, user.user_id
+ ))?;
+ }
}
for channel in seed_config.channels {
@@ -86,6 +104,7 @@ pub async fn seed(config: &Config, db: &Database, force: bool) -> anyhow::Result
}
}
+ // TODO: Fix this later
if let Some(number_of_users) = seed_config.number_of_users {
// Fetch 100 other random users from GitHub and insert them into the database
// (for testing autocompleters, etc.)
@@ -105,15 +124,23 @@ pub async fn seed(config: &Config, db: &Database, force: bool) -> anyhow::Result
for github_user in users {
last_user_id = Some(github_user.id);
user_count += 1;
- db.get_or_create_user_by_github_account(
- &github_user.login,
- Some(github_user.id),
- github_user.email.as_deref(),
- Some(github_user.created_at),
- None,
- )
- .await
- .expect("failed to insert user");
+ let user = db
+ .get_or_create_user_by_github_account(
+ &github_user.login,
+ Some(github_user.id),
+ github_user.email.as_deref(),
+ None,
+ None,
+ )
+ .await
+ .expect("failed to insert user");
+
+ for flag in &flags {
+ db.add_user_flag(user.id, *flag).await.context(format!(
+ "Unable to enable flag '{}' for user '{}'",
+ flag, user.id
+ ))?;
+ }
}
}
}
@@ -132,9 +159,9 @@ async fn fetch_github(client: &reqwest::Client, url: &str)
.header("user-agent", "zed")
.send()
.await
- .unwrap_or_else(|_| panic!("failed to fetch '{}'", url));
+ .unwrap_or_else(|error| panic!("failed to fetch '{url}': {error}"));
response
.json()
.await
- .unwrap_or_else(|_| panic!("failed to deserialize github user from '{}'", url))
+ .unwrap_or_else(|error| panic!("failed to deserialize github user from '{url}': {error}"))
}
diff --git a/crates/context_servers/src/context_servers.rs b/crates/context_servers/src/context_servers.rs
index 3892adff56f27852c87f3d2232d93028508e7fb0..0dd58de0c8334485e2bbe33827cbb72e1b84aeed 100644
--- a/crates/context_servers/src/context_servers.rs
+++ b/crates/context_servers/src/context_servers.rs
@@ -30,7 +30,9 @@ fn restart_servers(_workspace: &mut Workspace, _action: &Restart, cx: &mut ViewC
let model = ContextServerManager::global(&cx);
cx.update_model(&model, |manager, cx| {
for server in manager.servers() {
- manager.restart_server(&server.id, cx).detach();
+ manager
+ .restart_server(&server.id, cx)
+ .detach_and_log_err(cx);
}
});
}
diff --git a/crates/context_servers/src/manager.rs b/crates/context_servers/src/manager.rs
index 9d7d67a72f893d89ae931690658d83ac46b6c4a5..30164cd5c4f17fbf4016abd93f4761ca60714ce9 100644
--- a/crates/context_servers/src/manager.rs
+++ b/crates/context_servers/src/manager.rs
@@ -266,11 +266,11 @@ pub fn init(cx: &mut AppContext) {
log::trace!("servers_to_add={:?}", servers_to_add);
for config in servers_to_add {
- manager.add_server(config, cx).detach();
+ manager.add_server(config, cx).detach_and_log_err(cx);
}
for id in servers_to_remove {
- manager.remove_server(&id, cx).detach();
+ manager.remove_server(&id, cx).detach_and_log_err(cx);
}
})
})
diff --git a/crates/copilot/src/copilot_chat.rs b/crates/copilot/src/copilot_chat.rs
index a2fca980a5b8b1d838e93502ee43ecf1bae157cf..c69877340d84f592a52e1d022b272ca6229480e1 100644
--- a/crates/copilot/src/copilot_chat.rs
+++ b/crates/copilot/src/copilot_chat.rs
@@ -31,6 +31,8 @@ pub enum Role {
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
#[default]
+ #[serde(alias = "gpt-4o", rename = "gpt-4o-2024-05-13")]
+ Gpt4o,
#[serde(alias = "gpt-4", rename = "gpt-4")]
Gpt4,
#[serde(alias = "gpt-3.5-turbo", rename = "gpt-3.5-turbo")]
@@ -40,6 +42,7 @@ pub enum Model {
impl Model {
pub fn from_id(id: &str) -> Result {
match id {
+ "gpt-4o" => Ok(Self::Gpt4o),
"gpt-4" => Ok(Self::Gpt4),
"gpt-3.5-turbo" => Ok(Self::Gpt3_5Turbo),
_ => Err(anyhow!("Invalid model id: {}", id)),
@@ -50,6 +53,7 @@ impl Model {
match self {
Self::Gpt3_5Turbo => "gpt-3.5-turbo",
Self::Gpt4 => "gpt-4",
+ Self::Gpt4o => "gpt-4o",
}
}
@@ -57,11 +61,13 @@ impl Model {
match self {
Self::Gpt3_5Turbo => "GPT-3.5",
Self::Gpt4 => "GPT-4",
+ Self::Gpt4o => "GPT-4o",
}
}
pub fn max_token_count(&self) -> usize {
match self {
+ Self::Gpt4o => 128000,
Self::Gpt4 => 8192,
Self::Gpt3_5Turbo => 16385,
}
diff --git a/crates/go_to_line/src/cursor_position.rs b/crates/go_to_line/src/cursor_position.rs
index 0f14af3bd11fa1c0a7942ab2a69ebd50eeb19780..f53543ac6bd88e9ecdf628af5a60ab04b33ad784 100644
--- a/crates/go_to_line/src/cursor_position.rs
+++ b/crates/go_to_line/src/cursor_position.rs
@@ -12,11 +12,11 @@ use ui::{
use util::paths::FILE_ROW_COLUMN_DELIMITER;
use workspace::{item::ItemHandle, StatusItemView, Workspace};
-#[derive(Copy, Clone, Default, PartialOrd, PartialEq)]
-struct SelectionStats {
- lines: usize,
- characters: usize,
- selections: usize,
+#[derive(Copy, Clone, Debug, Default, PartialOrd, PartialEq)]
+pub(crate) struct SelectionStats {
+ pub lines: usize,
+ pub characters: usize,
+ pub selections: usize,
}
pub struct CursorPosition {
@@ -44,7 +44,10 @@ impl CursorPosition {
self.selected_count.selections = editor.selections.count();
let mut last_selection: Option> = None;
for selection in editor.selections.all::(cx) {
- self.selected_count.characters += selection.end - selection.start;
+ self.selected_count.characters += buffer
+ .text_for_range(selection.start..selection.end)
+ .map(|t| t.chars().count())
+ .sum::();
if last_selection
.as_ref()
.map_or(true, |last_selection| selection.id > last_selection.id)
@@ -106,6 +109,11 @@ impl CursorPosition {
}
text.push(')');
}
+
+ #[cfg(test)]
+ pub(crate) fn selection_stats(&self) -> &SelectionStats {
+ &self.selected_count
+ }
}
impl Render for CursorPosition {
diff --git a/crates/go_to_line/src/go_to_line.rs b/crates/go_to_line/src/go_to_line.rs
index 155f38d808bbac3ff498cac625f06be0f40983c4..4f3e6194a022eb8ac73ec737e24df54265ebfd40 100644
--- a/crates/go_to_line/src/go_to_line.rs
+++ b/crates/go_to_line/src/go_to_line.rs
@@ -221,6 +221,8 @@ impl Render for GoToLine {
#[cfg(test)]
mod tests {
use super::*;
+ use cursor_position::{CursorPosition, SelectionStats};
+ use editor::actions::SelectAll;
use gpui::{TestAppContext, VisualTestContext};
use indoc::indoc;
use project::{FakeFs, Project};
@@ -335,6 +337,83 @@ mod tests {
assert_single_caret_at_row(&editor, expected_highlighted_row, cx);
}
+ #[gpui::test]
+ async fn test_unicode_characters_selection(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ "/dir",
+ json!({
+ "a.rs": "ēlo"
+ }),
+ )
+ .await;
+
+ let project = Project::test(fs, ["/dir".as_ref()], cx).await;
+ let (workspace, cx) = cx.add_window_view(|cx| Workspace::test_new(project.clone(), cx));
+ workspace.update(cx, |workspace, cx| {
+ let cursor_position = cx.new_view(|_| CursorPosition::new(workspace));
+ workspace.status_bar().update(cx, |status_bar, cx| {
+ status_bar.add_right_item(cursor_position, cx);
+ });
+ });
+
+ let worktree_id = workspace.update(cx, |workspace, cx| {
+ workspace.project().update(cx, |project, cx| {
+ project.worktrees(cx).next().unwrap().read(cx).id()
+ })
+ });
+ let _buffer = project
+ .update(cx, |project, cx| project.open_local_buffer("/dir/a.rs", cx))
+ .await
+ .unwrap();
+ let editor = workspace
+ .update(cx, |workspace, cx| {
+ workspace.open_path((worktree_id, "a.rs"), None, true, cx)
+ })
+ .await
+ .unwrap()
+ .downcast::()
+ .unwrap();
+
+ workspace.update(cx, |workspace, cx| {
+ assert_eq!(
+ &SelectionStats {
+ lines: 0,
+ characters: 0,
+ selections: 1,
+ },
+ workspace
+ .status_bar()
+ .read(cx)
+ .item_of_type::()
+ .expect("missing cursor position item")
+ .read(cx)
+ .selection_stats(),
+ "No selections should be initially"
+ );
+ });
+ editor.update(cx, |editor, cx| editor.select_all(&SelectAll, cx));
+ workspace.update(cx, |workspace, cx| {
+ assert_eq!(
+ &SelectionStats {
+ lines: 1,
+ characters: 3,
+ selections: 1,
+ },
+ workspace
+ .status_bar()
+ .read(cx)
+ .item_of_type::()
+ .expect("missing cursor position item")
+ .read(cx)
+ .selection_stats(),
+ "After selecting a text with multibyte unicode characters, the character count should be correct"
+ );
+ });
+ }
+
fn open_go_to_line_view(
workspace: &View,
cx: &mut VisualTestContext,
diff --git a/crates/gpui/src/app.rs b/crates/gpui/src/app.rs
index 5e3922061de303d95660dbb89a5ce67bb1592c92..6d011a767542847a1979778fc69897b5928ad679 100644
--- a/crates/gpui/src/app.rs
+++ b/crates/gpui/src/app.rs
@@ -6,7 +6,7 @@ use std::{
path::{Path, PathBuf},
rc::{Rc, Weak},
sync::{atomic::Ordering::SeqCst, Arc},
- time::Duration,
+ time::{Duration, Instant},
};
use anyhow::{anyhow, Result};
@@ -142,6 +142,12 @@ impl App {
self
}
+ /// Sets a start time for tracking time to first window draw.
+ pub fn measure_time_to_first_window_draw(self, start: Instant) -> Self {
+ self.0.borrow_mut().time_to_first_window_draw = Some(TimeToFirstWindowDraw::Pending(start));
+ self
+ }
+
/// Start the application. The provided callback will be called once the
/// app is fully launched.
pub fn run(self, on_finish_launching: F)
@@ -247,6 +253,7 @@ pub struct AppContext {
pub(crate) layout_id_buffer: Vec, // We recycle this memory across layout requests.
pub(crate) propagate_event: bool,
pub(crate) prompt_builder: Option,
+ pub(crate) time_to_first_window_draw: Option,
}
impl AppContext {
@@ -300,6 +307,7 @@ impl AppContext {
layout_id_buffer: Default::default(),
propagate_event: true,
prompt_builder: Some(PromptBuilder::Default),
+ time_to_first_window_draw: None,
}),
});
@@ -1302,6 +1310,14 @@ impl AppContext {
(task, is_first)
}
+
+ /// Returns the time to first window draw, if available.
+ pub fn time_to_first_window_draw(&self) -> Option {
+ match self.time_to_first_window_draw {
+ Some(TimeToFirstWindowDraw::Done(duration)) => Some(duration),
+ _ => None,
+ }
+ }
}
impl Context for AppContext {
@@ -1465,6 +1481,15 @@ impl DerefMut for GlobalLease {
}
}
+/// Represents the initialization duration of the application.
+#[derive(Clone, Copy)]
+pub enum TimeToFirstWindowDraw {
+ /// The application is still initializing, and contains the start time.
+ Pending(Instant),
+ /// The application has finished initializing, and contains the total duration.
+ Done(Duration),
+}
+
/// Contains state associated with an active drag operation, started by dragging an element
/// within the window or by dragging into the app from the underlying platform.
pub struct AnyDrag {
diff --git a/crates/gpui/src/platform.rs b/crates/gpui/src/platform.rs
index d19a6b745acb7dc4175ddc6b918ba549414c5508..d10f386419bc0e534d826fffda394fa9c76b7979 100644
--- a/crates/gpui/src/platform.rs
+++ b/crates/gpui/src/platform.rs
@@ -16,6 +16,7 @@ mod blade;
#[cfg(any(test, feature = "test-support"))]
mod test;
+mod fps;
#[cfg(target_os = "windows")]
mod windows;
@@ -51,6 +52,7 @@ use strum::EnumIter;
use uuid::Uuid;
pub use app_menu::*;
+pub use fps::*;
pub use keystroke::*;
#[cfg(target_os = "linux")]
@@ -354,7 +356,7 @@ pub(crate) trait PlatformWindow: HasWindowHandle + HasDisplayHandle {
fn on_should_close(&self, callback: Box bool>);
fn on_close(&self, callback: Box);
fn on_appearance_changed(&self, callback: Box);
- fn draw(&self, scene: &Scene);
+ fn draw(&self, scene: &Scene, on_complete: Option>);
fn completed_frame(&self) {}
fn sprite_atlas(&self) -> Arc;
@@ -379,6 +381,7 @@ pub(crate) trait PlatformWindow: HasWindowHandle + HasDisplayHandle {
}
fn set_client_inset(&self, _inset: Pixels) {}
fn gpu_specs(&self) -> Option;
+ fn fps(&self) -> Option;
#[cfg(any(test, feature = "test-support"))]
fn as_test(&mut self) -> Option<&mut TestWindow> {
diff --git a/crates/gpui/src/platform/blade/blade_renderer.rs b/crates/gpui/src/platform/blade/blade_renderer.rs
index 27e76b977855f4f2b5a9af5b2686014bef4ce21d..afb065895d65dd3ea7c56c3b9a9f3e97273e8ef2 100644
--- a/crates/gpui/src/platform/blade/blade_renderer.rs
+++ b/crates/gpui/src/platform/blade/blade_renderer.rs
@@ -9,6 +9,7 @@ use crate::{
};
use bytemuck::{Pod, Zeroable};
use collections::HashMap;
+use futures::channel::oneshot;
#[cfg(target_os = "macos")]
use media::core_video::CVMetalTextureCache;
#[cfg(target_os = "macos")]
@@ -537,7 +538,12 @@ impl BladeRenderer {
self.gpu.destroy_command_encoder(&mut self.command_encoder);
}
- pub fn draw(&mut self, scene: &Scene) {
+ pub fn draw(
+ &mut self,
+ scene: &Scene,
+ // Required to compile on macOS, but not currently supported.
+ _on_complete: Option>,
+ ) {
self.command_encoder.start();
self.atlas.before_frame(&mut self.command_encoder);
self.rasterize_paths(scene.paths());
@@ -766,4 +772,9 @@ impl BladeRenderer {
self.wait_for_gpu();
self.last_sync_point = Some(sync_point);
}
+
+ /// Required to compile on macOS, but not currently supported.
+ pub fn fps(&self) -> f32 {
+ 0.0
+ }
}
diff --git a/crates/gpui/src/platform/fps.rs b/crates/gpui/src/platform/fps.rs
new file mode 100644
index 0000000000000000000000000000000000000000..9776e0d454da3efdf323092dd7aa4b8ad9ff8587
--- /dev/null
+++ b/crates/gpui/src/platform/fps.rs
@@ -0,0 +1,94 @@
+use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
+use std::sync::Arc;
+
+const NANOS_PER_SEC: u64 = 1_000_000_000;
+const WINDOW_SIZE: usize = 128;
+
+/// Represents a rolling FPS (Frames Per Second) counter.
+///
+/// This struct provides a lock-free mechanism to measure and calculate FPS
+/// continuously, updating with every frame. It uses atomic operations to
+/// ensure thread-safety without the need for locks.
+pub struct FpsCounter {
+ frame_times: [AtomicU64; WINDOW_SIZE],
+ head: AtomicUsize,
+ tail: AtomicUsize,
+}
+
+impl FpsCounter {
+ /// Creates a new `Fps` counter.
+ ///
+ /// Returns an `Arc` for safe sharing across threads.
+ pub fn new() -> Arc {
+ Arc::new(Self {
+ frame_times: std::array::from_fn(|_| AtomicU64::new(0)),
+ head: AtomicUsize::new(0),
+ tail: AtomicUsize::new(0),
+ })
+ }
+
+ /// Increments the FPS counter with a new frame timestamp.
+ ///
+ /// This method updates the internal state to maintain a rolling window
+ /// of frame data for the last second. It uses atomic operations to
+ /// ensure thread-safety.
+ ///
+ /// # Arguments
+ ///
+ /// * `timestamp_ns` - The timestamp of the new frame in nanoseconds.
+ pub fn increment(&self, timestamp_ns: u64) {
+ let mut head = self.head.load(Ordering::Relaxed);
+ let mut tail = self.tail.load(Ordering::Relaxed);
+
+ // Add new timestamp
+ self.frame_times[head].store(timestamp_ns, Ordering::Relaxed);
+ // Increment head and wrap around to 0 if it reaches WINDOW_SIZE
+ head = (head + 1) % WINDOW_SIZE;
+ self.head.store(head, Ordering::Relaxed);
+
+ // Remove old timestamps (older than 1 second)
+ while tail != head {
+ let oldest = self.frame_times[tail].load(Ordering::Relaxed);
+ if timestamp_ns.wrapping_sub(oldest) <= NANOS_PER_SEC {
+ break;
+ }
+ // Increment tail and wrap around to 0 if it reaches WINDOW_SIZE
+ tail = (tail + 1) % WINDOW_SIZE;
+ self.tail.store(tail, Ordering::Relaxed);
+ }
+ }
+
+ /// Calculates and returns the current FPS.
+ ///
+ /// This method computes the FPS based on the frames recorded in the last second.
+ /// It uses atomic loads to ensure thread-safety.
+ ///
+ /// # Returns
+ ///
+ /// The calculated FPS as a `f32`, or 0.0 if no frames have been recorded.
+ pub fn fps(&self) -> f32 {
+ let head = self.head.load(Ordering::Relaxed);
+ let tail = self.tail.load(Ordering::Relaxed);
+
+ if head == tail {
+ return 0.0;
+ }
+
+ let newest =
+ self.frame_times[head.wrapping_sub(1) & (WINDOW_SIZE - 1)].load(Ordering::Relaxed);
+ let oldest = self.frame_times[tail].load(Ordering::Relaxed);
+
+ let time_diff = newest.wrapping_sub(oldest) as f32;
+ if time_diff == 0.0 {
+ return 0.0;
+ }
+
+ let frame_count = if head > tail {
+ head - tail
+ } else {
+ WINDOW_SIZE - tail + head
+ };
+
+ (frame_count as f32 - 1.0) * NANOS_PER_SEC as f32 / time_diff
+ }
+}
diff --git a/crates/gpui/src/platform/linux/wayland/window.rs b/crates/gpui/src/platform/linux/wayland/window.rs
index 2bd57adafcdc4f40ac9612da71798d6ded5ef131..30e5d8a17266cd272ce4c728d25c4a64833fe4c1 100644
--- a/crates/gpui/src/platform/linux/wayland/window.rs
+++ b/crates/gpui/src/platform/linux/wayland/window.rs
@@ -6,7 +6,7 @@ use std::sync::Arc;
use blade_graphics as gpu;
use collections::HashMap;
-use futures::channel::oneshot::Receiver;
+use futures::channel::oneshot;
use raw_window_handle as rwh;
use wayland_backend::client::ObjectId;
@@ -827,7 +827,7 @@ impl PlatformWindow for WaylandWindow {
_msg: &str,
_detail: Option<&str>,
_answers: &[&str],
- ) -> Option> {
+ ) -> Option> {
None
}
@@ -934,9 +934,9 @@ impl PlatformWindow for WaylandWindow {
self.0.callbacks.borrow_mut().appearance_changed = Some(callback);
}
- fn draw(&self, scene: &Scene) {
+ fn draw(&self, scene: &Scene, on_complete: Option>) {
let mut state = self.borrow_mut();
- state.renderer.draw(scene);
+ state.renderer.draw(scene, on_complete);
}
fn completed_frame(&self) {
@@ -1009,6 +1009,10 @@ impl PlatformWindow for WaylandWindow {
fn gpu_specs(&self) -> Option {
self.borrow().renderer.gpu_specs().into()
}
+
+ fn fps(&self) -> Option {
+ None
+ }
}
fn update_window(mut state: RefMut) {
diff --git a/crates/gpui/src/platform/linux/x11/window.rs b/crates/gpui/src/platform/linux/x11/window.rs
index b3c8ea7cc722c96d864a68676cf24a2b86345e5a..eb0d784ca305d79b506514208f8e70387d2fbf8b 100644
--- a/crates/gpui/src/platform/linux/x11/window.rs
+++ b/crates/gpui/src/platform/linux/x11/window.rs
@@ -1,5 +1,3 @@
-use anyhow::Context;
-
use crate::{
platform::blade::{BladeRenderer, BladeSurfaceConfig},
px, size, AnyWindowHandle, Bounds, Decorations, DevicePixels, ForegroundExecutor, GPUSpecs,
@@ -9,7 +7,9 @@ use crate::{
X11ClientStatePtr,
};
+use anyhow::Context;
use blade_graphics as gpu;
+use futures::channel::oneshot;
use raw_window_handle as rwh;
use util::{maybe, ResultExt};
use x11rb::{
@@ -1210,9 +1210,10 @@ impl PlatformWindow for X11Window {
self.0.callbacks.borrow_mut().appearance_changed = Some(callback);
}
- fn draw(&self, scene: &Scene) {
+ // TODO: on_complete not yet supported for X11 windows
+ fn draw(&self, scene: &Scene, on_complete: Option>) {
let mut inner = self.0.state.borrow_mut();
- inner.renderer.draw(scene);
+ inner.renderer.draw(scene, on_complete);
}
fn sprite_atlas(&self) -> Arc {
@@ -1398,4 +1399,8 @@ impl PlatformWindow for X11Window {
fn gpu_specs(&self) -> Option {
self.0.state.borrow().renderer.gpu_specs().into()
}
+
+ fn fps(&self) -> Option {
+ None
+ }
}
diff --git a/crates/gpui/src/platform/mac/metal_renderer.rs b/crates/gpui/src/platform/mac/metal_renderer.rs
index 401734e2536c1bf1faf5eaf61fab77c5c4acf199..e8d92057af11956942f90e00002c05d446b1a560 100644
--- a/crates/gpui/src/platform/mac/metal_renderer.rs
+++ b/crates/gpui/src/platform/mac/metal_renderer.rs
@@ -1,7 +1,7 @@
use super::metal_atlas::MetalAtlas;
use crate::{
point, size, AtlasTextureId, AtlasTextureKind, AtlasTile, Bounds, ContentMask, DevicePixels,
- Hsla, MonochromeSprite, PaintSurface, Path, PathId, PathVertex, PolychromeSprite,
+ FpsCounter, Hsla, MonochromeSprite, PaintSurface, Path, PathId, PathVertex, PolychromeSprite,
PrimitiveBatch, Quad, ScaledPixels, Scene, Shadow, Size, Surface, Underline,
};
use anyhow::{anyhow, Result};
@@ -14,6 +14,7 @@ use cocoa::{
use collections::HashMap;
use core_foundation::base::TCFType;
use foreign_types::ForeignType;
+use futures::channel::oneshot;
use media::core_video::CVMetalTextureCache;
use metal::{CAMetalLayer, CommandQueue, MTLPixelFormat, MTLResourceOptions, NSRange};
use objc::{self, msg_send, sel, sel_impl};
@@ -105,6 +106,7 @@ pub(crate) struct MetalRenderer {
instance_buffer_pool: Arc>,
sprite_atlas: Arc,
core_video_texture_cache: CVMetalTextureCache,
+ fps_counter: Arc,
}
impl MetalRenderer {
@@ -250,6 +252,7 @@ impl MetalRenderer {
instance_buffer_pool,
sprite_atlas,
core_video_texture_cache,
+ fps_counter: FpsCounter::new(),
}
}
@@ -292,7 +295,8 @@ impl MetalRenderer {
// nothing to do
}
- pub fn draw(&mut self, scene: &Scene) {
+ pub fn draw(&mut self, scene: &Scene, on_complete: Option>) {
+ let on_complete = Arc::new(Mutex::new(on_complete));
let layer = self.layer.clone();
let viewport_size = layer.drawable_size();
let viewport_size: Size = size(
@@ -319,13 +323,24 @@ impl MetalRenderer {
Ok(command_buffer) => {
let instance_buffer_pool = self.instance_buffer_pool.clone();
let instance_buffer = Cell::new(Some(instance_buffer));
- let block = ConcreteBlock::new(move |_| {
- if let Some(instance_buffer) = instance_buffer.take() {
- instance_buffer_pool.lock().release(instance_buffer);
- }
- });
- let block = block.copy();
- command_buffer.add_completed_handler(&block);
+ let device = self.device.clone();
+ let fps_counter = self.fps_counter.clone();
+ let completed_handler =
+ ConcreteBlock::new(move |_: &metal::CommandBufferRef| {
+ let mut cpu_timestamp = 0;
+ let mut gpu_timestamp = 0;
+ device.sample_timestamps(&mut cpu_timestamp, &mut gpu_timestamp);
+
+ fps_counter.increment(gpu_timestamp);
+ if let Some(on_complete) = on_complete.lock().take() {
+ on_complete.send(()).ok();
+ }
+ if let Some(instance_buffer) = instance_buffer.take() {
+ instance_buffer_pool.lock().release(instance_buffer);
+ }
+ });
+ let completed_handler = completed_handler.copy();
+ command_buffer.add_completed_handler(&completed_handler);
if self.presents_with_transaction {
command_buffer.commit();
@@ -1117,6 +1132,10 @@ impl MetalRenderer {
}
true
}
+
+ pub fn fps(&self) -> f32 {
+ self.fps_counter.fps()
+ }
}
fn build_pipeline_state(
diff --git a/crates/gpui/src/platform/mac/window.rs b/crates/gpui/src/platform/mac/window.rs
index 0df9f3936e3a4fcd0ac6498315a2d4ecb41a0d97..bc9ace81d37966952e73c71bdc508ceec9aa60ae 100644
--- a/crates/gpui/src/platform/mac/window.rs
+++ b/crates/gpui/src/platform/mac/window.rs
@@ -784,14 +784,14 @@ impl PlatformWindow for MacWindow {
self.0.as_ref().lock().bounds()
}
- fn window_bounds(&self) -> WindowBounds {
- self.0.as_ref().lock().window_bounds()
- }
-
fn is_maximized(&self) -> bool {
self.0.as_ref().lock().is_maximized()
}
+ fn window_bounds(&self) -> WindowBounds {
+ self.0.as_ref().lock().window_bounds()
+ }
+
fn content_size(&self) -> Size {
self.0.as_ref().lock().content_size()
}
@@ -975,8 +975,6 @@ impl PlatformWindow for MacWindow {
}
}
- fn set_app_id(&mut self, _app_id: &str) {}
-
fn set_background_appearance(&self, background_appearance: WindowBackgroundAppearance) {
let mut this = self.0.as_ref().lock();
this.renderer
@@ -1007,30 +1005,6 @@ impl PlatformWindow for MacWindow {
}
}
- fn set_edited(&mut self, edited: bool) {
- unsafe {
- let window = self.0.lock().native_window;
- msg_send![window, setDocumentEdited: edited as BOOL]
- }
-
- // Changing the document edited state resets the traffic light position,
- // so we have to move it again.
- self.0.lock().move_traffic_light();
- }
-
- fn show_character_palette(&self) {
- let this = self.0.lock();
- let window = this.native_window;
- this.executor
- .spawn(async move {
- unsafe {
- let app = NSApplication::sharedApplication(nil);
- let _: () = msg_send![app, orderFrontCharacterPalette: window];
- }
- })
- .detach();
- }
-
fn minimize(&self) {
let window = self.0.lock().native_window;
unsafe {
@@ -1107,18 +1081,48 @@ impl PlatformWindow for MacWindow {
self.0.lock().appearance_changed_callback = Some(callback);
}
- fn draw(&self, scene: &crate::Scene) {
+ fn draw(&self, scene: &crate::Scene, on_complete: Option>) {
let mut this = self.0.lock();
- this.renderer.draw(scene);
+ this.renderer.draw(scene, on_complete);
}
fn sprite_atlas(&self) -> Arc {
self.0.lock().renderer.sprite_atlas().clone()
}
+ fn set_edited(&mut self, edited: bool) {
+ unsafe {
+ let window = self.0.lock().native_window;
+ msg_send![window, setDocumentEdited: edited as BOOL]
+ }
+
+ // Changing the document edited state resets the traffic light position,
+ // so we have to move it again.
+ self.0.lock().move_traffic_light();
+ }
+
+ fn show_character_palette(&self) {
+ let this = self.0.lock();
+ let window = this.native_window;
+ this.executor
+ .spawn(async move {
+ unsafe {
+ let app = NSApplication::sharedApplication(nil);
+ let _: () = msg_send![app, orderFrontCharacterPalette: window];
+ }
+ })
+ .detach();
+ }
+
+ fn set_app_id(&mut self, _app_id: &str) {}
+
fn gpu_specs(&self) -> Option {
None
}
+
+ fn fps(&self) -> Option {
+ Some(self.0.lock().renderer.fps())
+ }
}
impl rwh::HasWindowHandle for MacWindow {
diff --git a/crates/gpui/src/platform/test/window.rs b/crates/gpui/src/platform/test/window.rs
index f79421b12a8a325ded7d4973d71565226135fdad..2680d1ce2380abb35adddd3ee3ebe768be9261f5 100644
--- a/crates/gpui/src/platform/test/window.rs
+++ b/crates/gpui/src/platform/test/window.rs
@@ -251,7 +251,12 @@ impl PlatformWindow for TestWindow {
fn on_appearance_changed(&self, _callback: Box) {}
- fn draw(&self, _scene: &crate::Scene) {}
+ fn draw(
+ &self,
+ _scene: &crate::Scene,
+ _on_complete: Option>,
+ ) {
+ }
fn sprite_atlas(&self) -> sync::Arc {
self.0.lock().sprite_atlas.clone()
@@ -277,6 +282,10 @@ impl PlatformWindow for TestWindow {
fn gpu_specs(&self) -> Option {
None
}
+
+ fn fps(&self) -> Option {
+ None
+ }
}
pub(crate) struct TestAtlasState {
diff --git a/crates/gpui/src/platform/windows/window.rs b/crates/gpui/src/platform/windows/window.rs
index d2db91b5cbd3bbdc60af4c880c7ccac282c06f25..5d1ba28db97304b8a597b8d45fa055b1d54c1877 100644
--- a/crates/gpui/src/platform/windows/window.rs
+++ b/crates/gpui/src/platform/windows/window.rs
@@ -660,8 +660,8 @@ impl PlatformWindow for WindowsWindow {
self.0.state.borrow_mut().callbacks.appearance_changed = Some(callback);
}
- fn draw(&self, scene: &Scene) {
- self.0.state.borrow_mut().renderer.draw(scene)
+ fn draw(&self, scene: &Scene, on_complete: Option>) {
+ self.0.state.borrow_mut().renderer.draw(scene, on_complete)
}
fn sprite_atlas(&self) -> Arc {
@@ -675,6 +675,10 @@ impl PlatformWindow for WindowsWindow {
fn gpu_specs(&self) -> Option {
Some(self.0.state.borrow().renderer.gpu_specs())
}
+
+ fn fps(&self) -> Option {
+ None
+ }
}
#[implement(IDropTarget)]
diff --git a/crates/gpui/src/window.rs b/crates/gpui/src/window.rs
index ae454ae022b2249ada44121e2ddda875dbf0fa1d..7aae3708d13cad0b0bdc2ccd176d26986a360e2c 100644
--- a/crates/gpui/src/window.rs
+++ b/crates/gpui/src/window.rs
@@ -11,9 +11,9 @@ use crate::{
PromptLevel, Quad, Render, RenderGlyphParams, RenderImage, RenderImageParams, RenderSvgParams,
Replay, ResizeEdge, ScaledPixels, Scene, Shadow, SharedString, Size, StrikethroughStyle, Style,
SubscriberSet, Subscription, TaffyLayoutEngine, Task, TextStyle, TextStyleRefinement,
- TransformationMatrix, Underline, UnderlineStyle, View, VisualContext, WeakView,
- WindowAppearance, WindowBackgroundAppearance, WindowBounds, WindowControls, WindowDecorations,
- WindowOptions, WindowParams, WindowTextSystem, SUBPIXEL_VARIANTS,
+ TimeToFirstWindowDraw, TransformationMatrix, Underline, UnderlineStyle, View, VisualContext,
+ WeakView, WindowAppearance, WindowBackgroundAppearance, WindowBounds, WindowControls,
+ WindowDecorations, WindowOptions, WindowParams, WindowTextSystem, SUBPIXEL_VARIANTS,
};
use anyhow::{anyhow, Context as _, Result};
use collections::{FxHashMap, FxHashSet};
@@ -544,6 +544,8 @@ pub struct Window {
hovered: Rc>,
pub(crate) dirty: Rc>,
pub(crate) needs_present: Rc| >,
+ /// We assign this to be notified when the platform graphics backend fires the next completion callback for drawing the window.
+ present_completed: RefCell | | |